In [12]:
import torch
from torch import nn
import numpy as np
from sklearn.model_selection import train_test_split
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt 
from torch.distributions import Independent, Normal

In [13]:
batch_size = 32
device="cuda" if torch.cuda.is_available() else "cpu"
train_data = datasets.MNIST(root="../data", train=True, 
    download=True,
    transform=ToTensor(),
    target_transform=None
)

test_data = datasets.MNIST(root="../data", train=False,
    download=True,
    transform=ToTensor(),
    target_transform=None)

from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)

In [14]:
import torch.nn.functional as F

In [79]:
class Encoder(nn.Module):
    def __init__(self, dim_x, dim_y, dim_z):

        super().__init__()
         # Encoder layers
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=1,padding='same')
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=2,padding=0)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1,padding='same')
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, stride=2,padding=0)
        self.conv5 = nn.Conv2d(in_channels=64, out_channels=80, kernel_size=7, stride=1,padding='valid')
        self.lin1 = nn.Linear(in_features=90, out_features=20)
        self.lin2 = nn.Linear(in_features=90, out_features=20)

        # reparameterization
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    def forward(self, inputs):
        x = inputs[0].to(device)#.unsqueeze(dim=0)
        y = inputs[1].to(device)
        y = F.one_hot(y, 10)
        print(f"img shape: {x.shape}, labels shape: {y.shape}")
        x = F.leaky_relu(self.conv1(x))
        print(x.shape)
        # 32, 28, 28
        x = F.pad(x, (0,3,0,3))
        print(x.shape)
        # 32, 31, 31
        x = F.leaky_relu(self.conv2(x))
        print(x.shape)
        # 32, 14, 14
        x = F.leaky_relu(self.conv3(x))
        print(x.shape)
        # 64, 14, 14
        x = F.pad(x, (0,3,0,3))
        print(x.shape)
        # 64, 17, 17
        x = F.leaky_relu(self.conv4(x))
        print(x.shape)
        # 64, 7, 7
        x = F.leaky_relu(self.conv5(x))
        print(x.shape)
        # 80, 1, 1
        x = torch.flatten(x, start_dim=1, end_dim=-1)
        print(f"After flatten shape: {x.shape}")
        # 80
        concat = torch.cat([x, y], dim=1)
        print(f"After concatenation shape: {concat.shape}")
        # 90
        # loc=torch.zeros(mu_logvar.shape)
        # scale=torch.ones(mu_logvar.shape)
        # diagn = Independent(Normal(loc, scale), 1)
        mu = self.lin1(concat)
        print(f"mu shape: {mu.shape}")
        # 20
        logvar = self.lin2(concat)
        # print(f"logvar shape: {logvar.shape}")
        z = self.reparameterize(mu, logvar)
        print(f"Returning shape {z.shape}")
        return  mu, logvar, z

In [83]:
class Decoder(nn.Module):
    def __init__(self, dim_y, dim_z):
        super().__init__()
        self.dim_z = dim_z
        self.dim_y = dim_y
        self.deconv1 = nn.ConvTranspose2d(in_channels=30, out_channels=64, kernel_size=7, stride=1, padding=0) # valid means no pad
        self.deconv2 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.deconv3 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=5, stride=2, padding=2, output_padding=1) # pad operation added in forward
        self.deconv4 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.deconv5 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=5, stride=2, padding=2, output_padding=1) # pad operation added in forward
        self.deconv6 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.conv = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=5, stride=1,padding='same')
        

    def forward(self, inputs):
        x = inputs[0].to(device)#.unsqueeze(dim=0)
        y = inputs[1].to(device)
        y = F.one_hot(y, 10)
        print(f"latent space shape: {x.shape}, labels shape: {y.shape}")
        x = torch.cat([x, y], dim=1)
        x = torch.reshape(x, (batch_size, self.dim_z+self.dim_y, 1, 1))
        print(f"After concatenation shape: {x.shape}")
        x = F.leaky_relu(self.deconv1(x))
        print(f"ConvTrans1 output shape: {x.shape}")
        x = F.leaky_relu(self.deconv2(x))
        print(f"ConvTrans2 output shape: {x.shape}")
        x = F.pad(x, (0,0,0,0))
        x = F.leaky_relu(self.deconv3(x))
        print(f"ConvTrans3 output shape: {x.shape}")
        x = F.leaky_relu(self.deconv4(x))
        print(f"ConvTrans4 output shape: {x.shape}")
        # x = F.pad(x, (0,3,0,3))
        x = F.leaky_relu(self.deconv5(x))
        print(f"ConvTrans5 output shape: {x.shape}")
        x = F.leaky_relu(self.deconv6(x))
        print(f"ConvTrans6 output shape: {x.shape}")
        x = self.conv(x)
        print(f"Conv output shape: {x.shape}")
        # x = torch.flatten(x, start_dim=1, end_dim=-1)
        # print(f"After flatten shape: {x.shape}")
        return 0

In [84]:
enc = Encoder(dim_x=28, dim_y=10, dim_z=20).to(device)     
dec = Decoder(dim_y=10, dim_z=20).to(device)

i=50
# img=train_data[:10][0]
# lbl=train_data[:10][1]
# plt.imshow(img.squeeze())
# print(lbl)

for batch, (img, label) in enumerate(train_dataloader):
    _,_,out = enc((img, label))
    print(f"Encoder returned shape: {out.shape}\n------------\n---------------\n------------")
    out2 = dec((out, label))
    break



# inputs=(img, lbl)
# out = model(inputs)
#
# lblon

img shape: torch.Size([32, 1, 28, 28]), labels shape: torch.Size([32, 10])
torch.Size([32, 32, 28, 28])
torch.Size([32, 32, 31, 31])
torch.Size([32, 32, 14, 14])
torch.Size([32, 64, 14, 14])
torch.Size([32, 64, 17, 17])
torch.Size([32, 64, 7, 7])
torch.Size([32, 80, 1, 1])
After flatten shape: torch.Size([32, 80])
After concatenation shape: torch.Size([32, 90])
mu shape: torch.Size([32, 20])
Returning shape torch.Size([32, 20])
Encoder returned shape: torch.Size([32, 20])
------------
---------------
------------
latent space shape: torch.Size([32, 20]), labels shape: torch.Size([32, 10])
After concatenation shape: torch.Size([32, 30, 1, 1])
ConvTrans1 output shape: torch.Size([32, 64, 7, 7])
ConvTrans2 output shape: torch.Size([32, 64, 7, 7])
ConvTrans3 output shape: torch.Size([32, 64, 14, 14])
ConvTrans4 output shape: torch.Size([32, 32, 14, 14])
ConvTrans5 output shape: torch.Size([32, 32, 28, 28])
ConvTrans6 output shape: torch.Size([32, 32, 28, 28])
Conv output shape: torch.Size(