In [None]:
from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torchvision.models as models
import numpy as np
import math
from scipy import linalg
import matplotlib.pyplot as plt

In [None]:
cudnn.benchmark = True

In [None]:
#set manual seed to a constant get a consistent output
manualSeed =999
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

Random Seed:  999


<torch._C.Generator at 0x79e566ea5210>

In [None]:
#loading the dataset
# dataset = dset.CIFAR10(root="./data", download=True,
#                            transform=transforms.Compose([
#                                transforms.Resize(64),
#                                transforms.ToTensor(),
#                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#                            ]))
dataset = dset.FashionMNIST(root="./data", download=True,
                           transform=transforms.Compose([
                               transforms.Resize(64),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5,), (0.5, )),
                           ]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                         shuffle=True, num_workers=2)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:01<00:00, 16854826.89it/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 265708.42it/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 5000026.99it/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 6420540.29it/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw






In [None]:
nc=1

#checking the availability of cuda devices
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# number of gpu's available
ngpu = 1
# input noise dimension
nz = 100
# number of generator filters
ngf = 64

#number of discriminator filters
ndf = 64


In [None]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 and classname.find('Prune') != -1:
        m.conv.weight.data.normal_(0.0, 0.02)
    elif classname.find('Linear') != -1 and classname.find('Prune') != -1:
        m.linear.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


In [None]:
class PruneLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(PruneLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(in_features, out_features)
        self.mask = np.ones([self.out_features, self.in_features])
        m = self.in_features
        n = self.out_features
        self.sparsity = 1.0
        # Initailization
        self.linear.weight.data.normal_(0, math.sqrt(2. / (m+n)))

    def forward(self, x):
        out = self.linear(x)
        return out
        pass

    def prune_by_percentage(self, q=5.0):
        """
        Pruning the weight paramters by threshold.
        :param q: pruning percentile. 'q' percent of the least
        significant weight parameters will be pruned.
        """
        """
        Prune the weight connections by percentage. Calculate the sparisty after
        pruning and store it into 'self.sparsity'.
        Store the pruning pattern in 'self.mask' for further fine-tuning process
        with pruned connections.
        --------------Your Code---------------------
        """

        current_weights = self.linear.weight.data.view(-1).cpu().numpy()
        threshold = np.percentile(np.abs(current_weights), q)

        self.mask = np.abs(current_weights) >= threshold

        mask_as_tensor = torch.from_numpy(self.mask).float().to(self.linear.weight.device)
        mask_as_tensor = mask_as_tensor.view_as(self.linear.weight.data)

        self.linear.weight.data *= mask_as_tensor

        self.sparsity = 1 - (np.sum(self.mask) / len(current_weights))


    def prune_by_std(self, s=0.25):
        """
        Pruning by a factor of the standard deviation value.
        :param std: (scalar) factor of the standard deviation value.
        Weight magnitude below np.std(weight)*std
        will be pruned.
        """

        """
        Prune the weight connections by standarad deviation.
        Calculate the sparisty after pruning and store it into 'self.sparsity'.
        Store the pruning pattern in 'self.mask' for further fine-tuning process
        with pruned connections.
        --------------Your Code---------------------
        """

        current_weights = self.linear.weight.data.view(-1).cpu().numpy()
        threshold = np.std(current_weights) * s

        self.mask = np.abs(current_weights) >= threshold

        mask_as_tensor = torch.from_numpy(self.mask).float().to(self.linear.weight.device)
        mask_as_tensor = mask_as_tensor.view_as(self.linear.weight.data)

        self.linear.weight.data *= mask_as_tensor

        self.sparsity = 1 - (np.sum(self.mask) / len(current_weights))



class PrunedConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
        super(PrunedConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)

        # Expand and Transpose to match the dimension
        self.mask = np.ones_like([out_channels, in_channels, kernel_size, kernel_size])

        # Initialization
        n = self.kernel_size * self.kernel_size * self.out_channels
        m = self.kernel_size * self.kernel_size * self.in_channels
        self.conv.weight.data.normal_(0, math.sqrt(2. / (n+m) ))
        self.sparsity = 1.0

    def forward(self, x):
        out = self.conv(x)
        return out

    def prune_by_percentage(self, q=5.0):
        """
        Pruning by a factor of the standard deviation value.
        :param s: (scalar) factor of the standard deviation value.
        Weight magnitude below np.std(weight)*std
        will be pruned.
        """

        """
        Prune the weight connections by percentage. Calculate the sparisty after
        pruning and store it into 'self.sparsity'.
        Store the pruning pattern in 'self.mask' for further fine-tuning process
        with pruned connections.
        --------------Your Code---------------------
        """

        current_weights = self.conv.weight.data.view(-1).cpu().numpy()
        threshold = np.percentile(np.abs(current_weights), q)

        self.mask = np.abs(current_weights) >= threshold

        mask_as_tensor = torch.from_numpy(self.mask).float().to(self.conv.weight.device)
        mask_as_tensor = mask_as_tensor.view_as(self.conv.weight.data)

        self.conv.weight.data *= mask_as_tensor

        self.sparsity = 1 - (np.sum(self.mask) / len(current_weights))

    def prune_by_std(self, s=0.25):
        """
        Pruning by a factor of the standard deviation value.
        :param s: (scalar) factor of the standard deviation value.
        Weight magnitude below np.std(weight)*std
        will be pruned.
        """

        """
        Prune the weight connections by standarad deviation.
        Calculate the sparisty after pruning and store it into 'self.sparsity'.
        Store the pruning pattern in 'self.mask' for further fine-tuning process
        with pruned connections.
        --------------Your Code---------------------
        """

        current_weights = self.conv.weight.data.view(-1).cpu().numpy()
        threshold = np.std(current_weights) * s

        self.mask = np.abs(current_weights) >= threshold

        mask_as_tensor = torch.from_numpy(self.mask).float().to(self.conv.weight.device)
        mask_as_tensor = mask_as_tensor.view_as(self.conv.weight.data)

        self.conv.weight.data *= mask_as_tensor

        self.sparsity = 1 - (np.sum(self.mask) / len(current_weights))

class PrunedConvTrans(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
        super(PrunedConvTrans, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)

        # Expand and Transpose to match the dimension
        self.mask = np.ones_like([out_channels, in_channels, kernel_size, kernel_size])

        # Initialization
        n = self.kernel_size * self.kernel_size * self.out_channels
        m = self.kernel_size * self.kernel_size * self.in_channels
        self.conv.weight.data.normal_(0, math.sqrt(2. / (n+m) ))
        self.sparsity = 1.0

    def forward(self, x):
        out = self.conv(x)
        return out

    def prune_by_percentage(self, q=5.0):
        """
        Pruning by a factor of the standard deviation value.
        :param s: (scalar) factor of the standard deviation value.
        Weight magnitude below np.std(weight)*std
        will be pruned.
        """

        """
        Prune the weight connections by percentage. Calculate the sparisty after
        pruning and store it into 'self.sparsity'.
        Store the pruning pattern in 'self.mask' for further fine-tuning process
        with pruned connections.
        --------------Your Code---------------------
        """

        current_weights = self.conv.weight.data.view(-1).cpu().numpy()
        threshold = np.percentile(np.abs(current_weights), q)

        self.mask = np.abs(current_weights) >= threshold

        mask_as_tensor = torch.from_numpy(self.mask).float().to(self.conv.weight.device)
        mask_as_tensor = mask_as_tensor.view_as(self.conv.weight.data)

        self.conv.weight.data *= mask_as_tensor

        self.sparsity = 1 - (np.sum(self.mask) / len(current_weights))

    def prune_by_std(self, s=0.25):
        """
        Pruning by a factor of the standard deviation value.
        :param s: (scalar) factor of the standard deviation value.
        Weight magnitude below np.std(weight)*std
        will be pruned.
        """

        """
        Prune the weight connections by standarad deviation.
        Calculate the sparisty after pruning and store it into 'self.sparsity'.
        Store the pruning pattern in 'self.mask' for further fine-tuning process
        with pruned connections.
        --------------Your Code---------------------
        """

        current_weights = self.conv.weight.data.view(-1).cpu().numpy()
        threshold = np.std(current_weights) * s

        self.mask = np.abs(current_weights) >= threshold

        mask_as_tensor = torch.from_numpy(self.mask).float().to(self.conv.weight.device)
        mask_as_tensor = mask_as_tensor.view_as(self.conv.weight.data)

        self.conv.weight.data *= mask_as_tensor

        self.sparsity = 1 - (np.sum(self.mask) / len(current_weights))


In [None]:
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            PrunedConvTrans(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            PrunedConvTrans(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            PrunedConvTrans(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            PrunedConvTrans(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            PrunedConvTrans(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
            return output

In [None]:
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            PrunedConv(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            PrunedConv(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            PrunedConv(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            PrunedConv(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            PrunedConv(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output.view(-1, 1).squeeze(1)


In [None]:
netG = Generator(ngpu).to(device)
netG.apply(weights_init)
#load weights to test the model
#netG.load_state_dict(torch.load('weights/netG_epoch_24.pth'))
print(netG)


Generator(
  (main): Sequential(
    (0): PrunedConvTrans(
      (conv): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    )
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): PrunedConvTrans(
      (conv): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    )
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): PrunedConvTrans(
      (conv): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    )
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): PrunedConvTrans(
      (conv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    )
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)


In [None]:
netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
#load weights to test the model
#netD.load_state_dict(torch.load('weights/netD_epoch_24.pth'))
print(netD)

Discriminator(
  (main): Sequential(
    (0): PrunedConv(
      (conv): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    )
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): PrunedConv(
      (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    )
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): PrunedConv(
      (conv): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    )
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): PrunedConv(
      (conv): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    )
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)


In [None]:
criterion = nn.BCELoss()

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0003, betas=(0.5, 0.999))

fixed_noise = torch.randn(128, nz, 1, 1, device=device)
real_label = 1
fake_label = 0


In [None]:
niter = 15
g_loss = []
d_loss = []

for epoch in range(niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)

        output = netD(real_cpu)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' % (epoch, niter, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        #save the output
        if i % 100 == 0:
            print('saving the output')
            vutils.save_image(real_cpu,'output/real_samples.png',normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),'output/fake_samples_epoch_%03d.png' % (epoch),normalize=True)

    # Check pointing for every epoch
    torch.save(netG.state_dict(), 'weights/netG_epoch_%d.pth' % (epoch))
    torch.save(netD.state_dict(), 'weights/netD_epoch_%d.pth' % (epoch))

[0/15][0/469] Loss_D: 3.9052 Loss_G: 0.7918 D(x): 0.9681 D(G(z)): 0.9663 / 0.5112
saving the output
[0/15][1/469] Loss_D: 3.6212 Loss_G: 1.5575 D(x): 0.9673 D(G(z)): 0.9545 / 0.2752
[0/15][2/469] Loss_D: 2.8868 Loss_G: 2.6249 D(x): 0.9415 D(G(z)): 0.9113 / 0.1060
[0/15][3/469] Loss_D: 1.9875 Loss_G: 3.8496 D(x): 0.8975 D(G(z)): 0.7787 / 0.0400
[0/15][4/469] Loss_D: 1.5935 Loss_G: 4.4465 D(x): 0.8262 D(G(z)): 0.6658 / 0.0208
[0/15][5/469] Loss_D: 1.7009 Loss_G: 4.6070 D(x): 0.7487 D(G(z)): 0.6774 / 0.0149
[0/15][6/469] Loss_D: 1.7460 Loss_G: 4.9374 D(x): 0.7873 D(G(z)): 0.6930 / 0.0122
[0/15][7/469] Loss_D: 1.6108 Loss_G: 5.5265 D(x): 0.7890 D(G(z)): 0.6674 / 0.0066
[0/15][8/469] Loss_D: 1.8009 Loss_G: 5.8860 D(x): 0.7778 D(G(z)): 0.7090 / 0.0042
[0/15][9/469] Loss_D: 1.4667 Loss_G: 6.6301 D(x): 0.7948 D(G(z)): 0.6427 / 0.0021
[0/15][10/469] Loss_D: 1.2307 Loss_G: 6.9334 D(x): 0.8093 D(G(z)): 0.5594 / 0.0016
[0/15][11/469] Loss_D: 1.0704 Loss_G: 6.8640 D(x): 0.7551 D(G(z)): 0.4627 / 0.0

KeyboardInterrupt: ignored

In [None]:
# !zip -r /content/outputs_fashmnist.zip /content/output

In [None]:
def summary(net):
    assert isinstance(net, nn.Module)
    print("Layer id\tType\t\tParameter\tNon-zero parameter\tSparsity(\%)")
    layer_id = 0
    num_total_params = 0
    num_total_nonzero_params = 0
    for n, m in net.named_modules():
        if isinstance(m, PruneLinear):
            weight = m.linear.weight.data.cpu().numpy()
            weight = weight.flatten()
            num_parameters = weight.shape[0]
            num_nonzero_parameters = (weight != 0).sum()
            sparisty = 1 - num_nonzero_parameters / num_parameters
            layer_id += 1
            print("%d\t\tLinear\t\t%d\t\t%d\t\t\t%f" %(layer_id, num_parameters, num_nonzero_parameters, sparisty))
            num_total_params += num_parameters
            num_total_nonzero_params += num_nonzero_parameters
        elif isinstance(m, PrunedConv) or isinstance(m, PrunedConvTrans):
            weight = m.conv.weight.data.cpu().numpy()
            weight = weight.flatten()
            num_parameters = weight.shape[0]
            print(num_parameters)
            num_nonzero_parameters = (weight != 0).sum()
            sparisty = 1 - num_nonzero_parameters / num_parameters
            layer_id += 1
            print("%d\t\tConvolutional\t%d\t\t%d\t\t\t%f" % (layer_id, num_parameters, num_nonzero_parameters, sparisty))
            num_total_params += num_parameters
            num_total_nonzero_params += num_nonzero_parameters
        elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
            layer_id += 1
            print("%d\t\tBatchNorm\tN/A\t\tN/A\t\t\tN/A" % (layer_id))
        elif isinstance(m, nn.ReLU):
            layer_id += 1
            print("%d\t\tReLU\t\tN/A\t\tN/A\t\t\tN/A" % (layer_id))

    print("Total nonzero parameters: %d" %num_total_nonzero_params)
    print("Total parameters: %d" %num_total_params)
    total_sparisty = 1. - num_total_nonzero_params / num_total_params
    print("Total sparsity: %f" %total_sparisty)

    return total_sparisty
#####


In [None]:
def prune(net, method='std', q=5.0, s=0.25):
    # Before training started, please generate the mask
    assert isinstance(net, nn.Module)
    for n, m in net.named_modules():
        if isinstance(m, PrunedConv) or isinstance(m, PruneLinear) or isinstance(m, PrunedConvTrans):
            if method == 'percentage':
                m.prune_by_percentage(q)
            elif method == 'std':
                m.prune_by_std(s)


In [None]:
netD = Discriminator(ngpu).to(device)
netD.load_state_dict(torch.load("netD_epoch_14.pth"))

netG = Generator(ngpu).to(device)
netG.load_state_dict(torch.load("netG_epoch_14.pth"))

<All keys matched successfully>

In [None]:
test_noise = torch.randn(256, nz, 1, 1, device=device)
fake_images = netG(test_noise)

In [None]:
vutils.save_image(fake_images.detach(),'fake_samples_before_pruning.png', normalize=True)

In [None]:
summary(netG)

Layer id	Type		Parameter	Non-zero parameter	Sparsity(\%)
819200
1		Convolutional	819200		819200			0.000000
2		BatchNorm	N/A		N/A			N/A
3		ReLU		N/A		N/A			N/A
2097152
4		Convolutional	2097152		2097152			0.000000
5		BatchNorm	N/A		N/A			N/A
6		ReLU		N/A		N/A			N/A
524288
7		Convolutional	524288		524288			0.000000
8		BatchNorm	N/A		N/A			N/A
9		ReLU		N/A		N/A			N/A
131072
10		Convolutional	131072		131072			0.000000
11		BatchNorm	N/A		N/A			N/A
12		ReLU		N/A		N/A			N/A
1024
13		Convolutional	1024		1024			0.000000
Total nonzero parameters: 3572736
Total parameters: 3572736
Total sparsity: 0.000000


0.0

In [None]:
summary(netD)

Layer id	Type		Parameter	Non-zero parameter	Sparsity(\%)
1024
1		Convolutional	1024		1024			0.000000
131072
2		Convolutional	131072		131072			0.000000
3		BatchNorm	N/A		N/A			N/A
524288
4		Convolutional	524288		524288			0.000000
5		BatchNorm	N/A		N/A			N/A
2097152
6		Convolutional	2097152		2097152			0.000000
7		BatchNorm	N/A		N/A			N/A
8192
8		Convolutional	8192		8192			0.000000
Total nonzero parameters: 2761728
Total parameters: 2761728
Total sparsity: 0.000000


0.0

In [None]:
prune(netD, method='percentage', q=50)
prune(netG, method='percentage', q=50)

In [None]:
summary(netG)

Layer id	Type		Parameter	Non-zero parameter	Sparsity(\%)
819200
1		Convolutional	819200		409600			0.500000
2		BatchNorm	N/A		N/A			N/A
3		ReLU		N/A		N/A			N/A
2097152
4		Convolutional	2097152		1048576			0.500000
5		BatchNorm	N/A		N/A			N/A
6		ReLU		N/A		N/A			N/A
524288
7		Convolutional	524288		262144			0.500000
8		BatchNorm	N/A		N/A			N/A
9		ReLU		N/A		N/A			N/A
131072
10		Convolutional	131072		65536			0.500000
11		BatchNorm	N/A		N/A			N/A
12		ReLU		N/A		N/A			N/A
1024
13		Convolutional	1024		512			0.500000
Total nonzero parameters: 1786368
Total parameters: 3572736
Total sparsity: 0.500000


0.5

In [None]:
summary(netD)

Layer id	Type		Parameter	Non-zero parameter	Sparsity(\%)
1024
1		Convolutional	1024		512			0.500000
131072
2		Convolutional	131072		65536			0.500000
3		BatchNorm	N/A		N/A			N/A
524288
4		Convolutional	524288		262144			0.500000
5		BatchNorm	N/A		N/A			N/A
2097152
6		Convolutional	2097152		1048576			0.500000
7		BatchNorm	N/A		N/A			N/A
8192
8		Convolutional	8192		4096			0.500000
Total nonzero parameters: 1380864
Total parameters: 2761728
Total sparsity: 0.500000


0.5

In [None]:
# test_noise = torch.randn(256, nz, 1, 1, device=device)
pruned_fake_images = netG(test_noise)

In [None]:
vutils.save_image(pruned_fake_images.detach(),'fake_samples_after_pruning.png', normalize=True)

In [None]:
!pip install sewar

Collecting sewar
  Downloading sewar-0.4.6.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: sewar
  Building wheel for sewar (setup.py) ... [?25l[?25hdone
  Created wheel for sewar: filename=sewar-0.4.6-py3-none-any.whl size=11420 sha256=9a833b64fad34c3c7440e0c3e8a8193efbcdb5a87136e8feb78b8dd85bca95bc
  Stored in directory: /root/.cache/pip/wheels/3f/af/02/9c6556ba287b62a945d737def09b8b8c35c9e1d82b9dfae84c
Successfully built sewar
Installing collected packages: sewar
Successfully installed sewar-0.4.6


In [None]:
from sewar.full_ref import ssim, msssim, uqi

In [None]:
uqi_list = []
ssim_list = []
msssim_list = []

for idx in range(256):
    fake_img = fake_images[idx].detach().cpu().numpy().squeeze()
    fake_img = (fake_img + 1) * 255 / 2
    fake_img = fake_img.astype(int)

    fake_img_pruned = pruned_fake_images[idx].detach().cpu().numpy().squeeze()
    fake_img_pruned = (fake_img_pruned + 1) * 255 / 2
    fake_img_pruned = fake_img_pruned.astype(int)

    uqi_list.append(uqi(fake_img_pruned, fake_img))
    ssim_list.append(msssim(fake_img_pruned, fake_img))
    msssim_list.append(ssim(fake_img_pruned, fake_img))

In [None]:
np.mean(uqi_list), np.mean(ssim_list), np.mean(msssim_list)

(0.9220898153613754, (1+0j), 1.0)

In [None]:
import pandas as pd

In [None]:
stats = pd.DataFrame(columns = ["UQI", "SSIM", "MSSSIM"], index = ["Mean", "Median", "Mode", "Max", "Min", "Variance", "Std"])

In [None]:
(_, idx, counts) = np.unique(uqi_list, return_index=True, return_counts=True)
index = idx[np.argmax(counts)]
uqi_mode = uqi_list[index]
(_, idx, counts) = np.unique(ssim_list, return_index=True, return_counts=True)
index = idx[np.argmax(counts)]
ssim_mode = ssim_list[index]
(_, idx, counts) = np.unique(msssim_list, return_index=True, return_counts=True)
index = idx[np.argmax(counts)]
msssim_mode = msssim_list[index]

stats["UQI"] = [np.mean(uqi_list), np.median(uqi_list), uqi_mode, np.max(uqi_list), np.min(uqi_list), np.var(uqi_list), np.std(uqi_list)]
stats["SSIM"] = [np.mean(ssim_list), np.median(ssim_list), ssim_mode, np.max(ssim_list), np.min(ssim_list), np.var(ssim_list), np.std(ssim_list)]
stats["MSSSIM"] = [np.mean(msssim_list), np.median(msssim_list), msssim_mode, np.max(msssim_list), np.min(msssim_list), np.var(msssim_list), np.std(msssim_list)]

In [None]:
stats

Unnamed: 0,UQI,SSIM,MSSSIM
Mean,0.950989,1.0+0.0j,1.0
Median,0.954953,1.0+0.0j,1.0
Mode,0.859114,1.0+0.0j,"(1.0, 1.0)"
Max,0.993308,1.0+0.0j,1.0
Min,0.859114,1.0+0.0j,1.0
Variance,0.000772,0.0+0.0j,0.0
Std,0.02778,0.0+0.0j,0.0


In [None]:
q_values = [20, 40, 60, 80]
test_noise = torch.randn(256, nz, 1, 1, device=device)
fake_images = netG(test_noise)
stat_lists = []

for q_val in q_values:
    netG = Generator(ngpu).to(device)
    netG.load_state_dict(torch.load("netG_epoch_14.pth"))
    prune(netG, method='percentage', q=q_val)
    summary(netG)
    pruned_fake_images = netG(test_noise)

    uqi_list = []
    ssim_list = []
    msssim_list = []

    for idx in range(256):
        fake_img = fake_images[idx].detach().cpu().numpy().squeeze()
        fake_img = (fake_img + 1) * 255 / 2
        fake_img = fake_img.astype(int)

        fake_img_pruned = pruned_fake_images[idx].detach().cpu().numpy().squeeze()
        fake_img_pruned = (fake_img_pruned + 1) * 255 / 2
        fake_img_pruned = fake_img_pruned.astype(int)

        uqi_list.append(uqi(fake_img_pruned, fake_img))
        ssim_list.append(msssim(fake_img_pruned, fake_img))
        msssim_list.append(ssim(fake_img_pruned, fake_img))

    stats = pd.DataFrame(columns = ["UQI", "SSIM", "MSSSIM"], index = ["Mean", "Median", "Mode", "Max", "Min", "Variance", "Std"])
    (_, idx, counts) = np.unique(uqi_list, return_index=True, return_counts=True)
    index = idx[np.argmax(counts)]
    uqi_mode = uqi_list[index]
    (_, idx, counts) = np.unique(ssim_list, return_index=True, return_counts=True)
    index = idx[np.argmax(counts)]
    ssim_mode = ssim_list[index]
    (_, idx, counts) = np.unique(msssim_list, return_index=True, return_counts=True)
    index = idx[np.argmax(counts)]
    msssim_mode = msssim_list[index]

    stats["UQI"] = [np.mean(uqi_list), np.median(uqi_list), uqi_mode, np.max(uqi_list), np.min(uqi_list), np.var(uqi_list), np.std(uqi_list)]
    stats["SSIM"] = [np.mean(ssim_list), np.median(ssim_list), ssim_mode, np.max(ssim_list), np.min(ssim_list), np.var(ssim_list), np.std(ssim_list)]
    stats["MSSSIM"] = [np.mean(msssim_list), np.median(msssim_list), msssim_mode, np.max(msssim_list), np.min(msssim_list), np.var(msssim_list), np.std(msssim_list)]


    stat_lists.append(stats)


Layer id	Type		Parameter	Non-zero parameter	Sparsity(\%)
819200
1		Convolutional	819200		655360			0.200000
2		BatchNorm	N/A		N/A			N/A
3		ReLU		N/A		N/A			N/A
2097152
4		Convolutional	2097152		1677721			0.200000
5		BatchNorm	N/A		N/A			N/A
6		ReLU		N/A		N/A			N/A
524288
7		Convolutional	524288		419430			0.200001
8		BatchNorm	N/A		N/A			N/A
9		ReLU		N/A		N/A			N/A
131072
10		Convolutional	131072		104857			0.200005
11		BatchNorm	N/A		N/A			N/A
12		ReLU		N/A		N/A			N/A
1024
13		Convolutional	1024		819			0.200195
Total nonzero parameters: 2858187
Total parameters: 3572736
Total sparsity: 0.200001
Layer id	Type		Parameter	Non-zero parameter	Sparsity(\%)
819200
1		Convolutional	819200		491520			0.400000
2		BatchNorm	N/A		N/A			N/A
3		ReLU		N/A		N/A			N/A
2097152
4		Convolutional	2097152		1258291			0.400000
5		BatchNorm	N/A		N/A			N/A
6		ReLU		N/A		N/A			N/A
524288
7		Convolutional	524288		314573			0.400000
8		BatchNorm	N/A		N/A			N/A
9		ReLU		N/A		N/A			N/A
131072
10		Convolutional	131072		7

In [None]:
stat_lists[3]

Unnamed: 0,UQI,SSIM,MSSSIM
Mean,0.704175,1.0+0.0j,1.0
Median,0.70774,1.0+0.0j,1.0
Mode,0.398456,1.0+0.0j,"(1.0, 1.0)"
Max,0.92548,1.0+0.0j,1.0
Min,0.398456,1.0+0.0j,1.0
Variance,0.010979,0.0+0.0j,0.0
Std,0.104783,0.0+0.0j,0.0


#### Extra experimentation for further analysis in the future

In [None]:
class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,   # First max pooling features
        192: 1,  # Second max pooling featurs
        768: 2,  # Pre-aux classifier features
        2048: 3  # Final average pooling features
    }

    def __init__(self,
                 output_blocks=[DEFAULT_BLOCK_INDEX],
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False):

        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, \
            'Last possible output block index is 3'

        self.blocks = nn.ModuleList()


        inception = models.inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps
        Parameters
        ----------
        inp : torch.autograd.Variable
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)
        Returns
        -------
        List of torch.autograd.Variable, corresponding to the selected output
        block, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = torch.nn.functional.interpolate(x,
                                                size=(299, 299),
                                                mode='bilinear',
                                                align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            if idx == self.last_needed_block:
                break

        return outp

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3([block_idx])
model=model.cuda()

Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth
100%|██████████| 104M/104M [00:00<00:00, 146MB/s] 


In [None]:
def calculate_activation_statistics(images,model,batch_size=128, dims=10,
                    cuda=False):
    model.eval()
    act=np.empty((len(images), dims))

    if cuda:
        batch=images.cuda()
    else:
        batch=images
    pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
    if pred.size(2) != 1 or pred.size(3) != 1:
        pred = torch.nn.functional.adaptive_avg_pool2d(pred, output_size=(1, 1))

    act= pred.cpu().data.numpy().reshape(pred.size(0), -1)

    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma

In [None]:
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2


    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))


    if np.iscomplexobj(covmean):
        # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            # m = np.max(np.abs(covmean.imag))
            # raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)

In [None]:
def calculate_fretchet(images_real,images_fake,model):
     mu_1,std_1=calculate_activation_statistics(images_real,model,cuda=True)
     mu_2,std_2=calculate_activation_statistics(images_fake,model,cuda=True)

     """get fretched distance"""
     fid_value = calculate_frechet_distance(mu_1, std_1, mu_2, std_2)
     return fid_value

In [None]:
test_dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                              shuffle=True, num_workers=2)

In [None]:
fid_scores = []
for img_batch in test_dataloader:
    test_noisenoise = torch.randn(128, nz, 1, 1, device=device)
    fake_images = netG(test_noisenoise).detach().cpu().numpy()
    rgb_fake_images = np.repeat(fake_images[..., np.newaxis], 3, 1).squeeze()
    reshaped_img_batch = img_batch[0].detach().cpu().numpy()
    rgb_reshaped_img_batch = np.repeat(reshaped_img_batch[..., np.newaxis], 3, 1).squeeze()
    fid_scores.append(calculate_fretchet(torch.Tensor(rgb_reshaped_img_batch), torch.Tensor(rgb_fake_images), model))

# real = next(test_dataloader)

KeyboardInterrupt: ignored

Collecting sewar
  Downloading sewar-0.4.6.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: sewar
  Building wheel for sewar (setup.py) ... [?25l[?25hdone
  Created wheel for sewar: filename=sewar-0.4.6-py3-none-any.whl size=11420 sha256=0c0b892659315592b2ee60cc60db87be6218282583a62c8e9bea142cfbfa835d
  Stored in directory: /root/.cache/pip/wheels/3f/af/02/9c6556ba287b62a945d737def09b8b8c35c9e1d82b9dfae84c
Successfully built sewar
Installing collected packages: sewar
Successfully installed sewar-0.4.6


In [None]:
from sewar.full_ref import mse, rmse, psnr, uqi, ssim, ergas, scc, rase, sam, msssim, vifp

In [None]:
import torchvision.models as models

## Custom resnet for calculating FID

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
from torchvision import transforms
import time
from tqdm.autonotebook import tqdm
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import inspect

import matplotlib.pyplot as plt
import numpy as np

In [None]:
class MnistResNet(nn.Module):
  def __init__(self, in_channels=1):
    super(MnistResNet, self).__init__()

    # Load a pretrained resnet model from torchvision.models in Pytorch
    self.model = models.resnet50(pretrained=True)

    # Change the input layer to take Grayscale image, instead of RGB images.
    # Hence in_channels is set as 1 or 3 respectively
    # original definition of the first layer on the ResNet class
    # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
    self.model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

    # Change the output layer to output 10 classes instead of 1000 classes
    num_ftrs = self.model.fc.in_features
    self.model.fc = nn.Linear(num_ftrs, 10)

  def forward(self, x):
    return self.model(x)


In [None]:
my_resnet = MnistResNet()

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 158MB/s]


In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
def get_data_loaders(train_batch_size, val_batch_size):
    fashion_mnist = torchvision.datasets.FashionMNIST(download=True, train=True, root=".").train_data.float()

    data_transform = transforms.Compose([ transforms.Resize((224, 224)),
                                         transforms.ToTensor(),
                                         transforms.Normalize((fashion_mnist.mean()/255,), (fashion_mnist.std()/255,))])

    train_loader = DataLoader(torchvision.datasets.FashionMNIST(download=True, root=".", transform=data_transform, train=True),
                              batch_size=train_batch_size, shuffle=True)

    val_loader = DataLoader(torchvision.datasets.FashionMNIST(download=True, root=".", transform=data_transform, train=False),
                            batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader

In [None]:
def calculate_metric(metric_fn, true_y, pred_y):
    if "average" in inspect.getfullargspec(metric_fn).args:
        return metric_fn(true_y, pred_y, average="macro")
    else:
        return metric_fn(true_y, pred_y)

def print_scores(p, r, f1, a, batch_size):
    for name, scores in zip(("precision", "recall", "F1", "accuracy"), (p, r, f1, a)):
        print(f"\t{name.rjust(14, ' ')}: {sum(scores)/batch_size:.4f}")

In [None]:
# model:
model = MnistResNet().to(device)

# params you need to specify:
epochs = 1
batch_size = 128

# Dataloaders
train_loader, val_loader = get_data_loaders(batch_size, batch_size)

# loss function and optimiyer
loss_function = nn.CrossEntropyLoss() # your loss function, cross entropy works well for multi-class problems

# optimizer, I've used Adadelta, as it wokrs well without any magic numbers
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) # Using Karpathy's learning rate constant

start_ts = time.time()

losses = []
batches = len(train_loader)
val_batches = len(val_loader)

# loop for every epoch (training + evaluation)
for epoch in range(epochs):
    total_loss = 0

    # progress bar (works in Jupyter notebook too!)
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)

    # ----------------- TRAINING  --------------------
    # set model to training
    model.train()

    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)

        # training step for single batch
        model.zero_grad()
        outputs = model(X)
        loss = loss_function(outputs, y)
        loss.backward()
        optimizer.step()

        # getting training quality data
        current_loss = loss.item()
        total_loss += current_loss

        # updating progress bar
        progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))

    # releasing unceseccary memory in GPU
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ----------------- VALIDATION  -----------------
    val_losses = 0
    precision, recall, f1, accuracy = [], [], [], []

    # set model to evaluating (testing)
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            X, y = data[0].to(device), data[1].to(device)

            outputs = model(X) # this get's the prediction from the network

            val_losses += loss_function(outputs, y)

            predicted_classes = torch.max(outputs, 1)[1] # get class from network's prediction

            # calculate P/R/F1/A metrics for batch
            for acc, metric in zip((precision, recall, f1, accuracy),
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(
                    calculate_metric(metric, y.cpu(), predicted_classes.cpu())
                )

    print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}")
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss/batches) # for plotting learning curve
print(f"Training time: {time.time()-start_ts}s")




Loss:   0%|          | 0/469 [00:00<?, ?it/s]

ValueError: ignored

In [None]:
# torch.save(model.state_dict(), "MnistResNet.pt")
custom_model = MnistResNet()
custom_model.load_state_dict(torch.load("MnistResNet.pt"))
custom_model.to(device)



MnistResNet(
  (model): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
        

In [None]:
test_dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                              shuffle=True, num_workers=2)

In [None]:
fid_scores = []
p = transforms.Compose([transforms.Resize((224,224))])

for img_batch in test_dataloader:

    test_noisenoise = torch.randn(128, nz, 1, 1, device=device)
    fake_images = netG(test_noisenoise)
    fake_images = p(fake_images)

    real_images = p(img_batch[0])



    # rgb_reshaped_img_batch = np.repeat(reshaped_img_batch[..., np.newaxis], 3, 1).squeeze()

    fid_scores.append(calculate_fretchet(real_images, fake_images, custom_model))
    break


In [None]:
p = transforms.Compose([transforms.Resize((224,224))])

In [None]:
test_noisenoise = torch.randn(128, nz, 1, 1, device=device)
fake_images = netG(test_noisenoise)
fake_images = p(fake_images)

