In [1]:
import torch
import numpy as np
from torch.utils.data import DataLoader
import matplotlib
import matplotlib.pyplot as plt
import import_ipynb
import gibbs_sampler_poise
import kl_divergence_calculator
import data_preprocessing
from torchvision.utils import save_image
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
import torchvision.transforms as transforms
from torch.nn import functional as F  #for the activation function
from torchviz import make_dot
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
import torchvision
import random
random.seed(30)

importing Jupyter notebook from gibbs_sampler_poise.ipynb
importing Jupyter notebook from kl_divergence_calculator.ipynb
importing Jupyter notebook from data_preprocessing.ipynb


In [2]:
# learning parameters
batch_size = 10
lr = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device=torch.device('cpu')
tx = transforms.ToTensor()

In [3]:
## Importing MNIST and SVHN datasets
joint_dataset_train=data_preprocessing.JointDataset(mnist_pt_path="/home/achint/Practice_code/VAE/MNIST/MNIST/processed/training.pt",
                             svhn_mat_path="/home/achint/Practice_code/VAE/SVHN/train_32x32.mat")
joint_dataset_test = data_preprocessing.JointDataset(mnist_pt_path="/home/achint/Practice_code/VAE/MNIST/MNIST/processed/test.pt",
                             svhn_mat_path="/home/achint/Practice_code/VAE/SVHN/test_32x32.mat")

joint_dataset_train_loader = DataLoader(
    joint_dataset_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True
)
joint_dataset_test_loader = DataLoader(
    joint_dataset_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True
)


In [4]:
writer = SummaryWriter('/home/achint/Practice_code/logs')
examples= iter(joint_dataset_train_loader)
example1_data,example2_data,example1_labels,example2_labels=examples.next()
img1_grid= torchvision.utils.make_grid(example1_data)
img2_grid= torchvision.utils.make_grid(example2_data)
writer.add_image('mnist',img1_grid)
writer.add_image('svhn',img2_grid)


TypeError: Cannot handle this data type: (1, 1, 10), |u1

In [19]:
example2_labels

tensor([1])

In [11]:
latent_dim1 = 1
latent_dim2 = 1
dim_MNIST   = 784
class VAE(nn.Module):
    def __init__(self):
        super(VAE,self).__init__()
        self.gibbs                   = gibbs_sampler_poise.gibbs_sampler()  
        self.kl_div                  = kl_divergence_calculator.kl_divergence()
        ## Encoder set1(MNIST)
        self.set1_enc1 = nn.Linear(in_features = dim_MNIST,out_features = 512)
        self.set1_enc2 = nn.Linear(in_features = 512,out_features = 128)
        self.set1_enc3 = nn.Linear(in_features = 128,out_features = 2*latent_dim1)
        
        ## Decoder set1(MNIST)
        self.set1_dec1 = nn.Linear(in_features = latent_dim1,out_features = 128)
        self.set1_dec2 = nn.Linear(in_features = 128,out_features = 512)
        self.set1_dec3 = nn.Linear(in_features = 512,out_features = dim_MNIST)
        
        ## Encoder set2(SVHN)
        # input size: 3 x 32 x 32
        self.set2_enc1 = nn.Conv2d(in_channels=3, out_channels=2*latent_dim2, kernel_size=4, stride=2, padding=1)
        # size: 32 x 16 x 16
        self.set2_enc2 = nn.Conv2d(in_channels=2*latent_dim2, out_channels=2*latent_dim2, kernel_size=4, stride=2, padding=1)
        # size: 32 x 8 x 8
        self.set2_enc3 = nn.Conv2d(in_channels=2*latent_dim2, out_channels=latent_dim2, kernel_size=4, stride=2, padding=1)
        # size: 16 x 4 x 4    

        ## Decoder set2(SVHN)
        # input size: 16x1x1
        self.set2_dec0 = nn.ConvTranspose2d(in_channels=latent_dim2,out_channels=latent_dim2, kernel_size=4, stride=1, padding=0)
        # input size: 16x4x4
        self.set2_dec1 = nn.ConvTranspose2d(in_channels=latent_dim2,out_channels=2*latent_dim2, kernel_size=3, stride=1, padding=1)
        # size: 32 x 4 x 4
        self.set2_dec2 = nn.ConvTranspose2d(in_channels=2*latent_dim2,out_channels=2*latent_dim2, kernel_size=5, stride=1, padding=0)
        # size: 32 x 8 x 8
        self.set2_dec3 = nn.ConvTranspose2d(in_channels=2*latent_dim2,out_channels=2*latent_dim2, kernel_size=4, stride=2, padding=1)
        # size: 32 x 16 x 16
        self.set2_dec4 = nn.ConvTranspose2d(in_channels=2*latent_dim2,out_channels=3, kernel_size=4, stride=2, padding=1)
        # size: 3 x 32 x 32
        
        self.SVHNc1 = nn.Conv2d(latent_dim2, latent_dim2, 4, 1, 0)
        # size: 16 x 1 x 1
        self.SVHNc2 = nn.Conv2d(latent_dim2, latent_dim2, 4, 1, 0)
        # size: 16 x 1 x 1
        self.register_parameter(name='g11', param = nn.Parameter(torch.randn(latent_dim1,latent_dim2)))
        self.register_parameter(name='g22', param = nn.Parameter(torch.randn(latent_dim1,latent_dim2)))
        self.flag_initialize= 1
        self.g12= torch.zeros(latent_dim1,latent_dim2).to(device)
    def forward(self,x1,x2):
        data1    = x1 #MNIST
        data2    = x2 #SVHN
        # Modality 1 (MNIST)
        x1       = F.relu(self.set1_enc1(x1))
        x1       = F.relu(self.set1_enc2(x1))  
        x1       = self.set1_enc3(x1).view(-1,2,latent_dim1)  # ->[128,2,32]
        mu1      = x1[:,0,:] # ->[128,32]
        log_var1 = x1[:,1,:] # ->[128,32]
        var1     = -torch.exp(log_var1)           #lambdap_2<0
        # Modality 2 (SVHN)
        x2 = x2.view(-1,3, 32,32) 
        x2 = F.relu(self.set2_enc1(x2))
        x2 = F.relu(self.set2_enc2(x2))
        x2 = F.relu(self.set2_enc3(x2))
        # get 'mu' and 'log_var' for SVHN
        mu2 = (self.SVHNc1(x2).squeeze(3)).squeeze(2)
        log_var2 = (self.SVHNc2(x2).squeeze(3)).squeeze(2)
        var2 = -torch.exp(log_var2)       
        g22      = -torch.exp(self.g22)     
        g11_copy = self.g11.detach()
        g22_copy = g22.detach()
        mu1_copy = mu1.detach()
        mu2_copy = mu2.detach()
        var1_copy=var1.detach()
        var2_copy=var2.detach()   
        
        if self.flag_initialize==1:
            self.flag_initialize=0
            z1_prior,z2_prior = self.gibbs.initialize_prior_sample(g11_copy,g22_copy)
            z1_posterior,z2_posterior = self.gibbs.initialize_posterior_sample(g11_copy,g22_copy,mu1_copy,var1_copy,mu2_copy,var2_copy)
            self.z1_prior =z1_prior
            self.z2_prior =z2_prior
            self.z1_posterior=z1_posterior
            self.z2_posterior=z2_posterior

        z1_prior     = self.z1_prior.detach()
        z2_prior     = self.z2_prior.detach()
        z1_posterior = self.z1_posterior.detach()
        z2_posterior = self.z2_posterior.detach()
        self.z1_gibbs_prior,self.z2_gibbs_prior         = self.gibbs.prior_sample(z1_prior,z2_prior,self.g11,g22)
        self.z1_gibbs_posterior,self.z2_gibbs_posterior = self.gibbs.posterior_sample(z1_posterior,z2_posterior,self.g11,g22,mu1,var1,mu2,var2)
        G1 = torch.cat((self.g11,self.g12),0)
        G2 = torch.cat((self.g12,g22),0)
        G  = torch.cat((G1,G2),1)

        self.z2_gibbs_posterior = self.z2_gibbs_posterior.unsqueeze(2)
        self.z2_gibbs_posterior = self.z2_gibbs_posterior.unsqueeze(3)
        
        # decoding for MNIST
        x1 = F.relu(self.set1_dec1(self.z1_gibbs_posterior))
        x1 = self.set1_dec2(x1)
        reconstruction1=torch.sigmoid(self.set1_dec3(x1))
        # decoding for SVHN
        x2 = F.relu(self.set2_dec0(self.z2_gibbs_posterior))
        x2 = F.relu(self.set2_dec1(x2))
        x2 = F.relu(self.set2_dec2(x2))
        x2 = F.relu(self.set2_dec3(x2))
        x2 = (self.set2_dec4(x2)).view(-1,3072)
        reconstruction2 = torch.sigmoid(x2)
        # calculating loss
        lambda2 = torch.cat((mu2,var2),1)                         # Output of encoder for set2 
        part_fun0,part_fun1,part_fun2 = self.kl_div.calc(G,self.z1_posterior,self.z2_posterior,self.z1_prior,self.z2_prior,mu1,var1,mu2,var2)
        print(torch.sum(part_fun0))
        print(torch.sum(part_fun1))
        print(torch.sum(part_fun2))
        bce_loss = nn.BCELoss(reduction='sum')
        MSE1 = bce_loss(reconstruction1, data1)
        MSE2 = bce_loss(reconstruction2, data2)
        KLD  = part_fun0+part_fun1+part_fun2

        print(torch.sum(MSE2))
        print(torch.sum(MSE1))

        loss = MSE1+MSE2+KLD
        return reconstruction1,reconstruction2,mu1,var1,mu2,var2,loss

In [12]:
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(),lr=lr)
for name, para in model.named_parameters():
    print(name)

g11
g22
set1_enc1.weight
set1_enc1.bias
set1_enc2.weight
set1_enc2.bias
set1_enc3.weight
set1_enc3.bias
set1_dec1.weight
set1_dec1.bias
set1_dec2.weight
set1_dec2.bias
set1_dec3.weight
set1_dec3.bias
set2_enc1.weight
set2_enc1.bias
set2_enc2.weight
set2_enc2.bias
set2_enc3.weight
set2_enc3.bias
set2_dec0.weight
set2_dec0.bias
set2_dec1.weight
set2_dec1.bias
set2_dec2.weight
set2_dec2.bias
set2_dec3.weight
set2_dec3.bias
set2_dec4.weight
set2_dec4.bias
SVHNc1.weight
SVHNc1.bias
SVHNc2.weight
SVHNc2.bias


In [13]:
def train(model,joint_dataloader):
    model.train()
    running_loss = 0.0
    for i,joint_data in enumerate(joint_dataloader):
        data1    = joint_data[0]
        data1    = data1.float()
        data2   = joint_data[1]
        data2   = data2.float()
        data1    = data1.to(device)
        data2   = data2.to(device)
        data1    = data1.view(data1.size(0), -1)
        data2   = data2.view(data2.size(0), -1)
        optimizer.zero_grad()
        reconstruction1,reconstruction2,mu1,var1,mu2,var2,loss       = model(data1,data2)  
        running_loss += loss.item()          #.item converts tensor with one element to number
        loss.backward()                      #.backward
        optimizer.step()                     #.step one learning step
    train_loss = running_loss/(len(joint_dataloader.dataset))
    return train_loss
    
def test(model,joint_dataloader):

    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for i,joint_data in enumerate(joint_dataloader):
            data1   = joint_data[0]
            data1   = data1.float()

            data2  =joint_data[1]
            data2 = data2.float()

            label1  =joint_data[2]
            label2  =joint_data[3]
            
            data1 = data1.to(device)
            data2 = data2.to(device)
            data1 = data1.view(data1.size(0), -1)
            data2 = data2.view(data2.size(0), -1)
            reconstruction1,reconstruction2,mu1,var1,mu2,var2,loss = model(data1,data2)  
            label_mnist.append(label1)
            label_svhn.append(label2)
            running_loss += loss.item()
            #save the last batch input and output of every epoch
            if i == int(len(joint_dataloader.dataset)/joint_dataloader.batch_size) - 1:
                num_rows = 8
                both = torch.cat((data.view(batch_size, 1, 28, 28)[:8], 
                                  reconstruction.view(batch_size, 1, 28, 28)[:8]))
                bothp = torch.cat((datap.view(batch_size, 3, 32, 32)[:8], 
                                  reconstructionp.view(batch_size, 3, 32, 32)[:8]))
                save_image(both.cpu(), f"/home/achint/Practice_code/Synthetic_dataset/POISE_VAE_SVHN_MNIST/reconstructions/1_outputMNIST_{epoch}.png", nrow=num_rows)
                save_image(bothp.cpu(), f"/home/achint/Practice_code/Synthetic_dataset/POISE_VAE_SVHN_MNIST/reconstructions/2_outputSVHN_{epoch}.png", nrow=num_rows)
#     latent_repMNIST = torch.vstack(latent_repMNIST).cpu().numpy()
#     latent_repSVHN  = torch.vstack(latent_repSVHN).cpu().numpy()
#     label_mnist = torch.hstack(label_mnist).cpu().numpy()
#     label_svhn = torch.hstack(label_svhn).cpu().numpy()
    test_loss = running_loss/len(joint_dataloader.dataset)
    return test_loss

In [14]:
train_loss = []
test_loss = []
epochs = 1
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss = train(model,joint_dataset_train_loader)
    test_epoch_loss = test(model,joint_dataset_test_loader)
    train_loss.append(train_epoch_loss)
    test_loss.append(test_epoch_loss)     
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Test Loss: {test_epoch_loss:.4f}")

Epoch 1 of 1
4
5
6
7
8
9
10
11
T1_prior tensor([[[-0.0531],
         [ 0.0028]]])
T2_prior tensor([[[-0.3200,  0.1024]]])
T1_post tensor([[[-5.8471e-05],
         [ 3.4189e-09]]])
T2_post tensor([[[3.8482e-03, 1.4809e-05]]])
tensor(-0.0173, grad_fn=<SumBackward0>)
tensor(0.0173, grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
tensor(-27829.4473, grad_fn=<SumBackward0>)
tensor(1133.8669, grad_fn=<SumBackward0>)
4
5
6
7
8
9
10
11
T1_prior tensor([[[-0.0531],
         [ 0.0028]]])
T2_prior tensor([[[-0.3200,  0.1024]]])
T1_post tensor([[[-5.8471e-05],
         [ 3.4189e-09]]])
T2_post tensor([[[3.8482e-03, 1.4809e-05]]])
tensor(-0.9562, grad_fn=<SumBackward0>)
tensor(0.9562, grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
tensor(-14789.9736, grad_fn=<SumBackward0>)
tensor(894.1980, grad_fn=<SumBackward0>)
4
5
6
7
8
9
10
11
T1_prior tensor([[[-0.0531],
         [ 0.0028]]])
T2_prior tensor([[[-0.3200,  0.1024]]])
T1_post tensor([[[-5.8471e-05],
         [ 3.4189e-09]]

RuntimeError: all elements of input should be between 0 and 1