In [1]:
import torch
import os
import numpy as np
from torch.utils.data import DataLoader
import matplotlib
import matplotlib.pyplot as plt
import import_ipynb
import gibbs_sampler_poise
import kl_divergence_calculator
import data_preprocessing
from torchvision.utils import save_image
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
import torchvision.transforms as transforms
from torch.nn import functional as F  #for the activation function
from torchviz import make_dot
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
import torchvision
import umap
import random
import shutil

importing Jupyter notebook from gibbs_sampler_poise.ipynb
importing Jupyter notebook from kl_divergence_calculator.ipynb
importing Jupyter notebook from data_preprocessing.ipynb


In [6]:
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Compute device: the GPU when CUDA reports one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tx = transforms.ToTensor()

# Model / optimiser hyper-parameters.
latent_dim1, latent_dim2 = 32, 16   # latent widths handed to the VAE constructor
batch_size = 10
lr = 1e-4

# Flattened sample sizes of the two modalities.
dim_VIDEO = 160 * 120               # = 19200 values per VIDEO sample
dim_WFIELD = 2 * 135 * 160          # 2 channels of 135 x 160 per WFIELD sample

# Input data (NumPy arrays on disk).
VIDEO_TEST_PATH = "/hdd/achint_files/musall_behavior/video_test_data.npy"
WFIELD_TEST_PATH = "/hdd/achint_files/wfield_data/DOWNSAMPLED_wfield_test_data.npy"

# Checkpoint to restore, TensorBoard log directory, and image output folders.
PATH = "/home/achint/Practice_code/Synthetic_dataset/POISE_VIDEO_WFIELD/untitled.txt"
SUMMARY_WRITER_PATH = "/home/achint/Practice_code/logs"
CROSS_RECONSTRUCTION_PATH = "/home/achint/Practice_code/Synthetic_dataset/POISE_VIDEO_WFIELD/reconstructions_experiments/cross_generation/"
JOINT_RECONSTRUCTION_PATH = "/home/achint/Practice_code/Synthetic_dataset/POISE_VIDEO_WFIELD/reconstructions_experiments/joint_generation/"

In [7]:
# Paired VIDEO / WFIELD test set, built by the project-local dataset class
# from the two .npy files configured above.
joint_dataset_test = data_preprocessing.EvilMouDataSet(
    video_dir=VIDEO_TEST_PATH,
    wfield_dir=WFIELD_TEST_PATH,
)

# Fixed iteration order (no shuffling); a trailing batch smaller than
# `batch_size` is dropped so every batch has the same size.
joint_dataset_test_loader = DataLoader(
    joint_dataset_test, batch_size=batch_size, shuffle=False, drop_last=True
)

In [8]:
class VAE(nn.Module):
    """Two-modality VAE (VIDEO + WFIELD) whose latents come from a Gibbs sampler.

    VIDEO is handled by a fully connected encoder/decoder (``set1_*``).
    WFIELD is handled by a convolutional encoder/decoder (``set2_*``) plus the
    two 1x1-output heads ``VIDEOc1`` / ``VIDEOc2``; despite their names these
    heads read the WFIELD feature map (the names are kept because they are
    checkpoint keys).

    Latent samples are produced by ``gibbs_sampler_poise.gibbs_sampler`` and
    the regulariser by ``kl_divergence_calculator.kl_divergence``; both are
    defined outside this notebook, so their exact semantics are not described
    here.

    The module is stateful: the first ``forward`` call runs two 5000-iteration
    sampler calls and caches the results; every call (including the first)
    then runs two 5-iteration sampler calls seeded with the cached samples and
    caches the new ones for the next call.

    Relies on the module-level constants ``dim_VIDEO`` and ``dim_WFIELD``.

    Args:
        latent_dim1: width of the VIDEO latent.
        latent_dim2: width of the WFIELD latent (also the channel count of the
            WFIELD convolutions).
        batch_size: forwarded to the sampler and KL helper constructors.
        use_mse_loss: if True the reconstruction terms are sum-reduced MSE on
            raw decoder outputs; otherwise sum-reduced BCE on sigmoid outputs.
    """

    def __init__(self, latent_dim1, latent_dim2, batch_size, use_mse_loss=True):
        super(VAE, self).__init__()
        self.latent_dim1 = latent_dim1
        self.latent_dim2 = latent_dim2
        self.batch_size = batch_size
        self.use_mse_loss = use_mse_loss
        self.gibbs = gibbs_sampler_poise.gibbs_sampler(self.latent_dim1, self.latent_dim2, self.batch_size)
        self.kl_div = kl_divergence_calculator.kl_divergence(self.latent_dim1, self.latent_dim2, self.batch_size)

        ## Encoder set1 (VIDEO): dim_VIDEO -> 512 -> 128 -> 2*latent_dim1
        self.set1_enc1 = nn.Linear(in_features=dim_VIDEO, out_features=512)
        self.set1_enc2 = nn.Linear(in_features=512, out_features=128)
        self.set1_enc3 = nn.Linear(in_features=128, out_features=2 * latent_dim1)
        ## Decoder set1 (VIDEO): latent_dim1 -> 128 -> 512 -> dim_VIDEO
        self.set1_dec1 = nn.Linear(in_features=latent_dim1, out_features=128)
        self.set1_dec2 = nn.Linear(in_features=128, out_features=512)
        self.set1_dec3 = nn.Linear(in_features=512, out_features=dim_VIDEO)

        ## Encoder set2 (WFIELD); sizes below are per sample, C x H x W
        # input: 2 x 135 x 160
        self.set2_enc1 = nn.Conv2d(in_channels=2, out_channels=latent_dim2, kernel_size=4, stride=2, padding=1)
        # -> latent_dim2 x 67 x 80
        self.set2_enc2 = nn.Conv2d(in_channels=latent_dim2, out_channels=latent_dim2, kernel_size=4, stride=2, padding=1)
        # -> latent_dim2 x 33 x 40
        self.set2_enc3 = nn.Conv2d(in_channels=latent_dim2, out_channels=latent_dim2, kernel_size=10, stride=(6, 8), padding=1)
        # -> latent_dim2 x 5 x 5

        ## Decoder set2 (WFIELD)
        # input: latent_dim2 x 1 x 1
        self.set2_dec0 = nn.ConvTranspose2d(in_channels=latent_dim2, out_channels=latent_dim2, kernel_size=5, stride=2, padding=0)
        # -> latent_dim2 x 5 x 5
        self.set2_dec1 = nn.ConvTranspose2d(in_channels=latent_dim2, out_channels=latent_dim2, kernel_size=10, stride=(6, 8), padding=1, output_padding=(1, 0))
        # -> latent_dim2 x 33 x 40
        self.set2_dec2 = nn.ConvTranspose2d(in_channels=latent_dim2, out_channels=2, kernel_size=6, stride=4, padding=1, output_padding=(3, 0))
        # -> 2 x 135 x 160

        # Heads on the 5 x 5 WFIELD feature map; each yields latent_dim2 x 1 x 1.
        self.VIDEOc1 = nn.Conv2d(latent_dim2, latent_dim2, 4, 2, 0)   # -> mu2
        self.VIDEOc2 = nn.Conv2d(latent_dim2, latent_dim2, 4, 2, 0)   # -> log_var2

        # Learnable coupling blocks used to build G and passed to the sampler.
        self.register_parameter(name='g11', param=nn.Parameter(torch.randn(latent_dim1, latent_dim2)))
        self.register_parameter(name='g22', param=nn.Parameter(torch.randn(latent_dim1, latent_dim2)))
        self.flag_initialize = 1
        # Constant zero block of G.  Registered as a non-persistent buffer so
        # it follows the module across devices (``.to()`` / ``.cuda()``) while
        # staying out of ``state_dict`` -- checkpoint keys are unchanged.
        self.register_buffer('g12', torch.zeros(latent_dim1, latent_dim2), persistent=False)

    def forward(self, x1, x2):
        """Encode both inputs, refresh the Gibbs samples, decode and score.

        Args:
            x1: VIDEO batch of shape [batch, dim_VIDEO].
            x2: WFIELD batch with 2*135*160 values per sample; it is viewed as
                [batch, 2, 135, 160] for the encoder, so a flat or an
                image-shaped tensor both work.

        Returns:
            Tuple ``(z1_posterior, z2_posterior, reconstruction1,
            reconstruction2, mu1, var1, mu2, var2, loss, MSE1, MSE2, KLD)``.
            The two posterior samples are detached.  ``reconstruction2`` is
            flat ([batch, dim_WFIELD]) with ``use_mse_loss=True`` and
            [batch, 2, 135, 160] otherwise.  ``MSE1``/``MSE2`` hold the BCE
            terms when ``use_mse_loss`` is False.
        """
        data1 = x1  # VIDEO reconstruction target
        data2 = x2  # WFIELD reconstruction target

        # ---- Modality 1 (VIDEO) encoder
        x1 = F.relu(self.set1_enc1(x1))
        x1 = F.relu(self.set1_enc2(x1))
        # Use the instance's latent width, not a notebook-level global.
        x1 = self.set1_enc3(x1).view(-1, 2, self.latent_dim1)  # [batch, 2, latent_dim1]
        mu1 = x1[:, 0, :]        # [batch, latent_dim1]
        log_var1 = x1[:, 1, :]   # [batch, latent_dim1]
        var1 = -torch.exp(log_var1)           # lambdap_1 < 0 by construction

        # ---- Modality 2 (WFIELD) encoder
        x2 = x2.view(-1, 2, 135, 160)
        x2 = F.relu(self.set2_enc1(x2))
        x2 = F.relu(self.set2_enc2(x2))
        x2 = F.relu(self.set2_enc3(x2))
        # Drop only the two 1-sized spatial axes; the batch axis must survive.
        mu2 = (self.VIDEOc1(x2).squeeze(3)).squeeze(2)
        log_var2 = (self.VIDEOc2(x2).squeeze(3)).squeeze(2)
        var2 = -torch.exp(log_var2)           # lambdap_2 < 0 by construction
        g22 = -torch.exp(self.g22)            # negative by construction

        # ---- One-off long sampler run on the first call
        if self.flag_initialize == 1:
            z1_prior, z2_prior = self.gibbs.gibbs_sample(self.flag_initialize,
                                                         torch.zeros_like(mu1),
                                                         torch.zeros_like(mu2),
                                                         self.g11,
                                                         g22,
                                                         torch.zeros_like(mu1),
                                                         torch.zeros_like(var1),
                                                         torch.zeros_like(mu2),
                                                         torch.zeros_like(var2),
                                                         n_iterations=5000)
            z1_posterior, z2_posterior = self.gibbs.gibbs_sample(self.flag_initialize,
                                                                 torch.zeros_like(mu1),
                                                                 torch.zeros_like(mu2),
                                                                 self.g11,
                                                                 g22,
                                                                 mu1,
                                                                 var1,
                                                                 mu2,
                                                                 var2,
                                                                 n_iterations=5000)
            self.z1_prior = z1_prior
            self.z2_prior = z2_prior
            self.z1_posterior = z1_posterior
            self.z2_posterior = z2_posterior
            self.flag_initialize = 0

        # ---- Short sampler run seeded with the cached (detached) samples
        z1_prior = self.z1_prior.detach()
        z2_prior = self.z2_prior.detach()
        z1_posterior = self.z1_posterior.detach()
        z2_posterior = self.z2_posterior.detach()
        self.z1_gibbs_prior, self.z2_gibbs_prior = self.gibbs.gibbs_sample(self.flag_initialize,
                                                                           z1_prior,
                                                                           z2_prior,
                                                                           self.g11,
                                                                           g22,
                                                                           torch.zeros_like(mu1),
                                                                           torch.zeros_like(var1),
                                                                           torch.zeros_like(mu2),
                                                                           torch.zeros_like(var2),
                                                                           n_iterations=5)
        self.z1_gibbs_posterior, self.z2_gibbs_posterior = self.gibbs.gibbs_sample(self.flag_initialize,
                                                                                   z1_posterior,
                                                                                   z2_posterior,
                                                                                   self.g11,
                                                                                   g22,
                                                                                   mu1,
                                                                                   var1,
                                                                                   mu2,
                                                                                   var2,
                                                                                   n_iterations=5)
        # Cache detached copies for the next call.
        self.z1_posterior = self.z1_gibbs_posterior.detach()
        self.z2_posterior = self.z2_gibbs_posterior.detach()
        self.z1_prior = self.z1_gibbs_prior.detach()
        self.z2_prior = self.z2_gibbs_prior.detach()

        # Block matrix G = [[g11, g12], [g12, g22]] handed to the KL helper.
        G1 = torch.cat((self.g11, self.g12), 0)
        G2 = torch.cat((self.g12, g22), 0)
        G = torch.cat((G1, G2), 1)

        # The transposed-conv decoder needs [batch, latent_dim2, 1, 1].  A
        # local view is used so ``self.z2_gibbs_posterior`` stays 2-D; the
        # earlier unsqueeze / bare ``squeeze()`` round trip also removed the
        # batch axis whenever the batch held a single sample.
        z2_spatial = self.z2_gibbs_posterior.unsqueeze(2).unsqueeze(3)

        # ---- Decoders
        x1 = F.relu(self.set1_dec1(self.z1_gibbs_posterior))
        # NOTE(review): no activation between set1_dec2 and set1_dec3; left
        # unchanged so existing checkpoints keep their meaning -- confirm intent.
        x1 = self.set1_dec2(x1)
        x2 = F.relu(self.set2_dec0(z2_spatial))
        x2 = F.relu(self.set2_dec1(x2))
        if self.use_mse_loss:
            reconstruction1 = self.set1_dec3(x1)
            reconstruction2 = self.set2_dec2(x2).view(-1, dim_WFIELD)
        else:
            reconstruction1 = torch.sigmoid(self.set1_dec3(x1))
            reconstruction2 = torch.sigmoid(self.set2_dec2(x2))

        # ---- Loss
        part_fun0, part_fun1, part_fun2 = self.kl_div.calc(G, self.z1_gibbs_posterior, self.z2_gibbs_posterior, self.z1_gibbs_prior, self.z2_gibbs_prior, mu1, var1, mu2, var2)
        # Align the WFIELD target with the reconstruction layout.  Without
        # this the BCE branch compared a 4-D output with a flat target, which
        # BCELoss rejects; for targets already in the right shape it is a no-op.
        target2 = data2.reshape(reconstruction2.shape)
        if self.use_mse_loss:
            criterion = nn.MSELoss(reduction='sum')
        else:
            criterion = nn.BCELoss(reduction='sum')
        MSE1 = criterion(reconstruction1, data1)
        MSE2 = criterion(reconstruction2, target2)

        KLD = part_fun0 + part_fun1 + part_fun2
        loss = MSE1 + MSE2 + KLD
        return self.z1_posterior, self.z2_posterior, reconstruction1, reconstruction2, mu1, var1, mu2, var2, loss, MSE1, MSE2, KLD


In [9]:
# Restore the trained model and its optimiser from the checkpoint at PATH.
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU run still loads when this notebook falls back to CPU.
state = torch.load(PATH, map_location=device)
model = VAE(latent_dim1, latent_dim2, batch_size, use_mse_loss=True).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
# List the parameter names as a quick check of what the model exposes.
for name, para in model.named_parameters():
    print(name)

g11
g22
set1_enc1.weight
set1_enc1.bias
set1_enc2.weight
set1_enc2.bias
set1_enc3.weight
set1_enc3.bias
set1_dec1.weight
set1_dec1.bias
set1_dec2.weight
set1_dec2.bias
set1_dec3.weight
set1_dec3.bias
set2_enc1.weight
set2_enc1.bias
set2_enc2.weight
set2_enc2.bias
set2_enc3.weight
set2_enc3.bias
set2_dec0.weight
set2_dec0.bias
set2_dec1.weight
set2_dec1.bias
set2_dec2.weight
set2_dec2.bias
VIDEOc1.weight
VIDEOc1.bias
VIDEOc2.weight
VIDEOc2.bias


In [32]:
def test(model,joint_dataloader,epoch,cross_generation,joint_generation,summary_writer=None):
    """Run the model over a paired VIDEO/WFIELD loader without gradients.

    Args:
        model: module called as ``model(data1, data2)`` and returning the
            12-tuple produced by ``VAE.forward``.
        joint_dataloader: yields batches whose element 0 is VIDEO and
            element 1 is WFIELD.
        epoch: step index for the logged scalars and suffix of the image name.
        cross_generation: if True the WFIELD input is replaced by zeros and a
            snapshot is written under ``CROSS_RECONSTRUCTION_PATH``.
        joint_generation: checked only when ``cross_generation`` is False; if
            True both inputs are replaced by zeros and a snapshot is written
            under ``JOINT_RECONSTRUCTION_PATH``.
        summary_writer: optional TensorBoard writer.  When omitted, the
            module-level ``writer`` is used, which must exist before the call.

    Returns:
        ``(test_loss, latent_repVIDEO, latent_repWFIELD)``: the mean loss per
        evaluated sample and two lists with one latent row (NumPy array) per
        evaluated sample.
    """
    model.eval()
    latent_repVIDEO  = []
    latent_repWFIELD = []
    running_mse1 = 0.0
    running_mse2 = 0.0
    running_kld  = 0.0
    running_loss = 0.0
    samples_seen = 0

    # Where (and whether) to save the input/reconstruction snapshot.
    if cross_generation:
        image_path = os.path.join(CROSS_RECONSTRUCTION_PATH, f"cross_generation_{epoch}.png")
    elif joint_generation:
        image_path = os.path.join(JOINT_RECONSTRUCTION_PATH, f"joint_generation_{epoch}.png")
    else:
        image_path = None
    # Ask the loader for its batch count instead of re-deriving it from the
    # dataset length, which is off by one batch when drop_last is False.
    last_batch = len(joint_dataloader) - 1

    with torch.no_grad():
        for i,joint_data in enumerate(joint_dataloader):
            data1 = joint_data[0].float().to(device)
            data2 = joint_data[1].float().to(device)
            data1 = data1.view(data1.size(0), -1)
            data2 = data2.view(data2.size(0), -1)
            if cross_generation:
                data2 = torch.zeros_like(data2)
            elif joint_generation:
                data1 = torch.zeros_like(data1)
                data2 = torch.zeros_like(data2)
            z1_posterior,z2_posterior,reconstruction1,reconstruction2,mu1,var1,mu2,var2,loss, MSE1, MSE2, KLD = model(data1,data2)
            latent_repVIDEO.extend(z1_posterior.cpu().numpy())
            latent_repWFIELD.extend(z2_posterior.cpu().numpy())
            running_loss += loss.item()
            running_mse1 += MSE1.item()
            running_mse2 += MSE2.item()
            running_kld  += KLD.item()
            samples_seen += data1.size(0)

            # Save inputs and reconstructions of the last batch.
            if image_path is not None and i == last_batch:
                num_rows = 8
                # Size of this batch, not the notebook-level batch_size.
                n = data1.size(0)
                trans = transforms.Compose([transforms.Resize((135,160))])
                both  = torch.cat((data1.view(n, 1, 160, 120)[:num_rows],
                                   reconstruction1.view(n, 1, 160, 120)[:num_rows]))
                bothp = torch.cat((data2.view(n, 2, 135, 160)[:num_rows],
                                   reconstruction2.view(n, 2, 135, 160)[:num_rows]))
                # Duplicate the single VIDEO channel and resize so VIDEO and
                # WFIELD images can share one 2-channel, 135 x 160 grid.
                both = trans(torch.cat((both,both),dim=1))
                both_single = torch.cat((both,bothp),0)
                # save_image does not create missing folders.
                os.makedirs(os.path.dirname(image_path), exist_ok=True)
                save_image(both_single.cpu(), image_path, nrow=num_rows)

    # Average over the samples actually evaluated; dividing by the dataset
    # length under-reports the mean when drop_last discards a partial batch.
    denom = max(samples_seen, 1)
    test_loss = running_loss / denom
    mse1_loss = running_mse1 / denom
    mse2_loss = running_mse2 / denom
    kld_loss  = running_kld / denom

    log = writer if summary_writer is None else summary_writer
    log.add_scalar("validation/loss", test_loss, epoch)
    log.add_scalar("validation/MSE1", mse1_loss, epoch)
    log.add_scalar("validation/MSE2", mse2_loss, epoch)
    log.add_scalar("validation/KLD", kld_loss, epoch)
    return test_loss,latent_repVIDEO,latent_repWFIELD

In [33]:
# Evaluate the restored model on the test loader in cross-generation mode
# (WFIELD input zeroed inside `test`) and collect the latent samples.
test_loss = []
epochs = 1
writer=SummaryWriter(SUMMARY_WRITER_PATH)
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    test_epoch_loss,latent_repVIDEO,latent_repWFIELD = test(model,
                                                           joint_dataset_test_loader,
                                                           epoch,
                                                           cross_generation=True,
                                                           joint_generation=False)
    test_loss.append(test_epoch_loss)     
    print(f"Test Loss: {test_epoch_loss:.4f}")
# Push pending scalars to disk now; the writer stays open for later cells.
writer.flush()

Epoch 1 of 1
Test Loss: 5404.1009


In [37]:
# Turn the per-sample latent rows collected by `test` into two NumPy arrays
# (VIDEO latents in z1, WFIELD latents in z2).
z1, z2 = (np.asarray(latent_rows) for latent_rows in (latent_repVIDEO, latent_repWFIELD))

In [38]:
# Shape of the WFIELD latent array.
z2.shape

(7820, 16)