In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(123)

def show_tensor_images(image_tensor, num_images=25, size=(3, 64, 64)):
    """Plot the first ``num_images`` images of a batch as a 5-column grid.

    Args:
        image_tensor: images with values in [-1, 1] (e.g. the Tanh output of
            ``Decoder``). Either an ``(N, C, H, W)`` batch, or any tensor that
            can be viewed as ``(-1, *size)`` (e.g. flattened images).
        num_images: number of images from the start of the batch to show.
        size: per-image ``(C, H, W)`` shape, only used to un-flatten inputs
            that are not already a 4-D batch.
    """
    # Map from the Tanh range [-1, 1] to [0, 1] for display.
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    if image_unflat.dim() != 4:
        # Only reshape flat inputs; a 4-D batch already has its real image
        # shape, which may differ from the default ``size``.
        image_unflat = image_unflat.view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    # make_grid returns (C, H, W); matplotlib expects (H, W, C).
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()
    
def make_grad_hook():
    """Create a gradient-collecting hook for convolution layers.

    Returns:
        A ``(grads, grad_hook)`` pair. ``grads`` is a list that starts empty.
        ``grad_hook(m)`` appends ``m.weight.grad`` to ``grads`` whenever ``m``
        is an ``nn.Conv2d`` or ``nn.ConvTranspose2d``, and ignores any other
        module. The appended value is ``None`` if no backward pass has
        populated the gradient yet. Presumably intended for use with
        ``model.apply(grad_hook)`` — confirm at the call site.
    """
    collected = []

    def grad_hook(m):
        if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            return
        collected.append(m.weight.grad)

    return collected, grad_hook

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'matplotlib'

In [3]:
!pip install torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1


In [4]:
class Encoder(nn.Module):
    """VGG16-style convolutional encoder mapping an image to a ``z_dim`` vector.

    Thirteen 3x3 conv + LeakyReLU blocks with five 2x2 max-pools are followed by
    a 7x7 adaptive average pool and a three-layer MLP head. Inputs must be at
    least 32x32 so that the five pools leave a non-empty map. At 224x224 (the
    resolution used with ``summary`` in this notebook), the pooled map is
    already 7x7 and the adaptive pool is an identity.

    Args:
        z_dim: size of the output latent vector.
        im_chan: number of input image channels (the first conv now uses it;
            the default of 3 matches the previously hard-coded value).
        hidden_dim: not used; kept so existing call sites keep working.
        fc_dim: width of the two hidden fully-connected layers (4096 as in VGG16).
    """

    # Output channels of each conv block in order; "M" marks a 2x2 max-pool.
    _CONV_CFG = (64, 64, "M",
                 128, 128, "M",
                 256, 256, 256, "M",
                 512, 512, 512, "M",
                 512, 512, 512, "M")

    def __init__(self, z_dim=256, im_chan=3, hidden_dim=224, fc_dim=4096):
        super().__init__()
        self.z_dim = z_dim
        layers = []
        in_channels = im_chan
        for spec in self._CONV_CFG:
            if spec == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.append(self.make_encoder_block(in_channels, spec, kernel_size=3))
                in_channels = spec
        layers += [
            # Fixes the flattened size at 7*7*512 regardless of input resolution.
            nn.AdaptiveAvgPool2d((7, 7)),
            nn.Flatten(),
            # Activations between the Linear layers: without them the three
            # Linears compose into a single affine map.
            nn.Linear(7 * 7 * 512, fc_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(fc_dim, fc_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(fc_dim, z_dim),
        ]
        self.encoder = nn.Sequential(*layers)

    def make_encoder_block(self, input_channels, output_channels, kernel_size=3):
        """Return a conv + LeakyReLU(0.2) block that preserves spatial size.

        Padding is ``kernel_size // 2``, which is 1 for the default 3x3 kernel
        and keeps "same" output size for any odd kernel.
        """
        return nn.Sequential(
            nn.Conv2d(input_channels, output_channels,
                      kernel_size=kernel_size, padding=kernel_size // 2),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, image):
        """Encode a ``(N, im_chan, H, W)`` image batch into ``(N, z_dim)``."""
        return self.encoder(image)

In [8]:
from torchsummary import summary

# Fall back to CPU so the notebook also runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = Encoder()
encoder = encoder.to(device)
# torchsummary's `device` argument defaults to 'cuda'; pass the device the
# model actually lives on so the dummy input matches it.
summary(encoder, (3, 224, 224), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
         LeakyReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
         LeakyReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
         LeakyReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
         LeakyReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
        LeakyReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
        LeakyReLU-14          [-1, 256,

In [71]:
class Decoder(nn.Module):
    """Decoder mapping a ``z_dim`` latent vector back to an image.

    An MLP expands ``z`` to a 512x7x7 feature map. Ten transposed-conv blocks
    follow. The five blocks built with a truthy ``output_padding`` double the
    spatial size (7 -> 224), and the rest keep it. A final transposed conv plus
    Tanh produces ``im_chan`` channels with values in [-1, 1].

    Args:
        z_dim: size of the input latent vector.
        im_chan: number of channels of the generated image.
        hidden_dim: accepted for signature parity with ``Encoder``; not used.
    """

    def __init__(self, z_dim=256, im_chan=3, hidden_dim=224):
        super().__init__()
        self.z_dim = z_dim
        # (in_channels, out_channels, output_padding) for each decoder block;
        # a truthy output_padding selects a stride-2 (upsampling) block.
        block_specs = [
            (512, 512, 1),
            (512, 512, None),
            (512, 512, None),
            (512, 512, 1),
            (512, 512, None),
            (512, 512, None),
            (512, 256, 1),
            (256, 256, None),
            (256, 128, 1),
            (128, 64, 1),
        ]
        # Modules are created strictly in forward order so that parameter
        # initialisation consumes the RNG in the same sequence.
        layers = [
            nn.Linear(z_dim, 4096),
            nn.LeakyReLU(0.2),
            nn.Linear(4096, 25088),
            nn.LeakyReLU(0.2),
            nn.Unflatten(1, (512, 7, 7)),
        ]
        layers.extend(
            self.make_decoder_block(c_in, c_out, output_padding=pad)
            for c_in, c_out, pad in block_specs
        )
        layers.append(nn.ConvTranspose2d(64, im_chan, kernel_size=3, stride=1, padding=1))
        layers.append(nn.Tanh())
        self.decoder = nn.Sequential(*layers)

    def make_decoder_block(self, input_dim, output_dim, kernel_size=3, output_padding=None):
        """Return a ConvTranspose2d + BatchNorm2d + LeakyReLU(0.2) block.

        When ``output_padding`` is truthy, the transposed conv uses stride 2
        with that output padding (doubling H and W for a 3x3 kernel). Otherwise
        it uses stride 1 with no output padding, which keeps the size.
        """
        upsample = bool(output_padding)
        deconv = nn.ConvTranspose2d(
            input_dim,
            output_dim,
            kernel_size=kernel_size,
            stride=2 if upsample else 1,
            padding=1,
            output_padding=output_padding if upsample else 0,
        )
        return nn.Sequential(deconv, nn.BatchNorm2d(output_dim), nn.LeakyReLU(0.2))

    def forward(self, z):
        """Decode a ``(N, z_dim)`` latent batch into an image batch."""
        return self.decoder(z)

In [97]:
# Move the encoder off the GPU (presumably to free device memory) and
# inspect the freshly built decoder on CPU.
encoder = encoder.cpu()
decoder = Decoder().cpu()
summary(decoder, (decoder.z_dim,), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 4096]       1,052,672
         LeakyReLU-2                 [-1, 4096]               0
            Linear-3                [-1, 25088]     102,785,536
         LeakyReLU-4                [-1, 25088]               0
         Unflatten-5            [-1, 512, 7, 7]               0
   ConvTranspose2d-6          [-1, 512, 14, 14]       2,359,808
       BatchNorm2d-7          [-1, 512, 14, 14]           1,024
         LeakyReLU-8          [-1, 512, 14, 14]               0
   ConvTranspose2d-9          [-1, 512, 14, 14]       2,359,808
      BatchNorm2d-10          [-1, 512, 14, 14]           1,024
        LeakyReLU-11          [-1, 512, 14, 14]               0
  ConvTranspose2d-12          [-1, 512, 14, 14]       2,359,808
      BatchNorm2d-13          [-1, 512, 14, 14]           1,024
        LeakyReLU-14          [-1, 512,

In [94]:
# Ask PyTorch's CUDA allocator to release cached GPU memory, presumably to make
# room after models were moved to CPU. NOTE(review): this cell ran as In [94],
# before In [97] above — re-check ordering under Restart & Run All.
torch.cuda.empty_cache()