In [1]:
import random
import torch 
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
import h5py
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import transforms
import torch.optim as optim
from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy

In [2]:
# Free memory left over from earlier work in this kernel: ask PyTorch to
# release its cached CUDA memory, then run Python's garbage collector.
import gc
torch.cuda.empty_cache()
# gc.collect() is the last expression of the cell, so its return value is what
# the notebook displays as the cell output.
gc.collect()

0

In [3]:
# Show which files are available in the dataset directory (relative path).
os.listdir("../dataset")

['QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272',
 'QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272.test.snappy.parquet',
 'QCDToGGQQ_IMGjet_RH1all_jet0_run1_n47540',
 'QCDToGGQQ_IMGjet_RH1all_jet0_run1_n47540.test.snappy.parquet',
 'QCDToGGQQ_IMGjet_RH1all_jet0_run2_n55494',
 'QCDToGGQQ_IMGjet_RH1all_jet0_run2_n55494.test.snappy.parquet',
 'quark-gluon_data-set_n139306.hdf5',
 'SingleElectronPt50_IMGCROPS_n249k_RHv1.hdf5',
 'SinglePhotonPt50_IMGCROPS_n249k_RHv1.hdf5']

In [4]:
# Open the quark/gluon HDF5 file read-only. The handle is kept open because
# later cells read the "X_jets" dataset from it.
QuarkGluon_dataset = h5py.File("../dataset/quark-gluon_data-set_n139306.hdf5","r")
# Work on a random subset of the events instead of the full file.
# Seeding Python's `random` module makes the drawn indices repeatable.
random.seed(42)
sample_size = 75000
# Draw `sample_size` distinct indices from [0, 139306).
# NOTE(review): the upper bound is hard-coded; it presumably equals the number
# of entries in the file (the file name ends in "n139306") -- confirm against
# QuarkGluon_dataset["X_jets"].shape[0].
random_trunc = random.sample(range(0, 139306), sample_size)

In [5]:
# Materialise "X_jets" as a NumPy array and keep only the sampled rows.
# NOTE(review): np.array(...) is applied to the whole dataset before the
# subset is taken, so the full array is held in memory at this point.
QuarkGluon_imgs=np.array(QuarkGluon_dataset["X_jets"])[random_trunc]

In [6]:
# Convert the sampled images to a torch tensor. `img_arrs` is a notebook-level
# variable that QuarkGluonDataset reads when it is constructed.
img_arrs = torch.Tensor(QuarkGluon_imgs)

In [7]:
class QuarkGluonDataset(Dataset):
    """Dataset over a subset of an in-memory stack of jet images.

    Each item is an ``(image, label)`` pair. The image is returned in
    ``(channels, height, width)`` layout; the label is the constant ``-1``
    placeholder (no class label is used by this notebook).

    Args:
        split_inx: Index (sequence of ints, slice, ...) selecting which rows of
            the image stack belong to this dataset.
        transform: Optional callable applied to each image after it has been
            rearranged to channels-first layout.
        target_transform: Optional callable applied to the placeholder label.
        images: Optional image stack indexed as ``(sample, height, width,
            channel)``; a torch tensor or anything ``torch.as_tensor`` accepts
            (e.g. a NumPy array). When omitted, the notebook-level ``img_arrs``
            variable is used, which keeps existing calls working.
    """

    def __init__(self, split_inx, transform=None, target_transform=None, images=None):
        # Resolve the global lazily so the class can also be used (and tested)
        # with an explicitly supplied image stack.
        source = img_arrs if images is None else images
        # No copy is made for a float32 tensor; NumPy arrays are converted so
        # that tensor methods such as permute() are available in __getitem__.
        source = torch.as_tensor(source, dtype=torch.float32)
        self.img_arrs_split = source[split_inx]
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.img_arrs_split.shape[0]

    def __getitem__(self, idx):
        image = self.img_arrs_split[idx, :, :, :]
        # (height, width, channels) -> (channels, height, width).
        # permute(2, 1, 0) would also swap height and width, i.e. hand a
        # transposed image to the model.
        image = image.permute(2, 0, 1)
        if self.transform is not None:
            image = self.transform(image)
        label = -1
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

In [8]:
class ResidualUnitEnc(nn.Module):
    """Residual block built from two 3x3 convolutions (encoder side).

    When ``in_channels`` equals ``out_channels`` the block keeps the spatial
    size (stride 1, "same" padding) and the shortcut is the unchanged input.
    Otherwise the first convolution uses stride 2 and the shortcut is a
    stride-2 1x1 convolution followed by batch norm, so both branches end up
    with ``out_channels`` channels and matching spatial size.
    """

    def __init__(self,in_channels, out_channels, **kwargs):
        super().__init__(**kwargs)
        downsample = in_channels != out_channels
        strides = 2 if downsample else 1
        pad = 1 if downsample else "same"

        # One in-place ReLU instance is shared by the main branch and the
        # final activation.
        self.relu = nn.ReLU(inplace=True)
        self.main_layers = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, 3, strides, padding=pad, bias=False),
            nn.BatchNorm2d(out_channels),
            self.relu,
            nn.Conv2d(out_channels, out_channels, 3, stride=1, padding="same", bias=False),
            nn.BatchNorm2d(out_channels),
        ])

        # Empty list => identity shortcut.
        self.skip_layers = []
        if downsample:
            self.skip_layers = nn.ModuleList([
                nn.Conv2d(in_channels, out_channels, 1, strides, padding=0, bias=False),
                nn.BatchNorm2d(out_channels),
            ])

    def forward(self,x):
        # Main branch.
        out = x
        for op in self.main_layers:
            out = op(out)
        # Shortcut branch (identity when no skip layers were created).
        shortcut = x
        for op in self.skip_layers:
            shortcut = op(shortcut)
        return self.relu(out + shortcut)
    
class ResidualUnitDec(nn.Module):
    """Residual block built from two 3x3 transposed convolutions (decoder side).

    When ``in_channels`` equals ``out_channels`` the block uses stride 1 and
    the shortcut is the unchanged input. Otherwise the first transposed
    convolution uses stride 2 with ``output_padding=1`` and the shortcut is a
    stride-2 1x1 transposed convolution followed by batch norm, so both
    branches produce ``out_channels`` channels with matching spatial size.
    """

    def __init__(self,in_channels, out_channels, **kwargs):
        super().__init__(**kwargs)
        upsample = in_channels != out_channels
        strides = 2 if upsample else 1
        pad = 1
        out_pad = 1 if upsample else 0

        # One in-place ReLU instance is shared by the main branch and the
        # final activation.
        self.relu = nn.ReLU(inplace=True)
        self.main_layers = nn.ModuleList([
            nn.ConvTranspose2d(in_channels, out_channels, 3, strides,
                               padding=pad, output_padding=out_pad, bias=False),
            nn.BatchNorm2d(out_channels),
            self.relu,
            nn.ConvTranspose2d(out_channels, out_channels, 3, stride=1, padding=pad, bias=False),
            nn.BatchNorm2d(out_channels),
        ])

        # Empty list => identity shortcut.
        self.skip_layers = []
        if upsample:
            self.skip_layers = nn.ModuleList([
                nn.ConvTranspose2d(in_channels, out_channels, 1, strides,
                                   padding=0, output_padding=out_pad, bias=False),
                nn.BatchNorm2d(out_channels),
            ])

    def forward(self,x):
        # Main branch.
        out = x
        for op in self.main_layers:
            out = op(out)
        # Shortcut branch (identity when no skip layers were created).
        shortcut = x
        for op in self.skip_layers:
            shortcut = op(shortcut)
        return self.relu(out + shortcut)

In [9]:
class ResNet18Enc(nn.Module):
    """ResNet-style convolutional encoder producing Gaussian latent parameters.

    A stride-2 stem convolution and max-pool are followed by eight
    ``ResidualUnitEnc`` blocks (64, 64, 128, 128, 256, 256, 512, 512 output
    channels), global average pooling, and fully connected heads.

    Args:
        latent_dim: Size of the latent vector.
    """

    def __init__(self,latent_dim):
        super().__init__()

        self.latent_dim = latent_dim
        self.relu = nn.ReLU(inplace=True)
        self.bn1 = nn.BatchNorm2d(64)

        # Stem: 3 input channels -> 64 feature maps, halving the resolution twice.
        self.conv1 = nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False)
        self.max_pool = nn.MaxPool2d(3, 2, 1)

        # (in_channels, out_channels) of consecutive residual units: each unit
        # consumes the channel count produced by the one before it.
        out_plan = [64, 64, 128, 128, 256, 256, 512, 512]
        in_plan = [64] + out_plan[:-1]
        self.res_unit_list = nn.ModuleList(
            [ResidualUnitEnc(c_in, c_out) for c_in, c_out in zip(in_plan, out_plan)]
        )

        self.fc1 = nn.Linear(512, latent_dim)
        self.muFC = nn.Linear(latent_dim, latent_dim)
        self.logVarFC = nn.Linear(latent_dim, latent_dim)

    def forward(self,x):
        """Encode a batch of images.

        Returns:
            Tuple ``(features, mu, sigma)``: the ReLU-activated ``fc1`` output,
            the latent mean, and the latent standard deviation.
        """
        h = self.max_pool(self.relu(self.bn1(self.conv1(x))))
        for block in self.res_unit_list:
            h = block(h)
        # Global average pool to 1x1, then drop the spatial dimensions.
        h = F.adaptive_avg_pool2d(h, (1, 1))
        h = h.view(h.size(0), -1)

        features = F.relu(self.fc1(h))
        mu = self.muFC(features)
        # The head predicts log(sigma**2); convert it to sigma.
        log_var = self.logVarFC(features)
        sigma = torch.exp(0.5 * log_var)

        return features, mu, sigma

    def __str__(self):
        return "ResNet18Enc"
    
class ResNet18Dec(nn.Module):
    """ResNet-style decoder mapping a latent vector back to a 3-channel image.

    The latent vector is projected to 512 features, reshaped to a 1x1 map and
    upsampled to 4x4. Eight ``ResidualUnitDec`` blocks (512, 512, 256, 256,
    128, 128, 64, 64 output channels) follow, then two stride-2 transposed
    convolutions. No batch norm is applied on the output path; the result is
    passed through a sigmoid, so every output value lies in (0, 1).

    Args:
        latent_dim: Size of the latent vector.
    """

    def __init__(self,latent_dim):
        super().__init__()

        self.latent_dim = latent_dim
        self.relu = nn.ReLU(inplace=True)

        self.fc = nn.Linear(latent_dim, 512)

        # (in_channels, out_channels) of consecutive residual units: each unit
        # consumes the channel count produced by the one before it.
        out_plan = [512, 512, 256, 256, 128, 128, 64, 64]
        in_plan = [512] + out_plan[:-1]
        self.dec_res_units = nn.ModuleList([])
        for c_in, c_out in zip(in_plan, out_plan):
            self.dec_res_units.append(ResidualUnitDec(c_in, c_out))

        # Depthwise (groups=64) stride-2 transposed conv used in place of a
        # max-unpooling step.
        self.convTranspose1 = nn.ConvTranspose2d(
            64, 64, 3, stride=2, padding=1, output_padding=1, groups=64, bias=False
        )
        # Final stride-2 transposed conv back to 3 image channels.
        self.convTranspose2 = nn.ConvTranspose2d(
            64, 3, 3, stride=2, padding=1, output_padding=1, bias=False
        )

    def forward(self,z):
        """Decode a batch of latent vectors ``z`` into images."""
        batch = z.size(0)
        h = self.fc(z).view(batch, 512, 1, 1)
        h = F.interpolate(h, scale_factor=4)
        for block in self.dec_res_units:
            h = block(h)
        h = self.convTranspose2(self.convTranspose1(h))
        return torch.sigmoid(h)

    def __str__(self):
        return "ResNet18Dec"

In [10]:
class VAE_ResNet18(nn.Module):
    """Variational auto-encoder assembled from ResNet18Enc and ResNet18Dec.

    ``forward`` encodes the input, draws a latent sample with the
    reparameterisation trick, decodes it, and stores the KL term of that pass
    in ``self.kl`` so the training code can add it to the reconstruction loss.

    Args:
        latent_dim: Size of the latent vector (default 100).
    """

    def __init__(self,latent_dim=100):
        super().__init__()

        self.encoder = ResNet18Enc(latent_dim)
        self.decoder = ResNet18Dec(latent_dim)

        # KL term of the most recent forward pass (0 until forward is called).
        self.kl = 0

    def forward(self,x):
        """Return the reconstruction of ``x``; updates ``self.kl`` as a side effect."""
        features, mu, sigma = self.encoder(x)
        z = self.reparameterize(features, mu, sigma)
        # Store the KL value for this batch on the module.
        self.calculateKL(z, mu, sigma)
        return self.decoder(z)

    def calculateKL(self,z,mu,sigma):
        """Set ``self.kl`` to KL(N(mu, sigma^2) || N(0, 1)) summed over all elements.

        ``z`` is accepted for interface compatibility and is not used.
        """
        self.kl = 0.5*((sigma**2) + (mu**2) - torch.log(sigma**2) - 1).sum()

    def reparameterize(self,x,mu,sigma):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, 1)``.

        The noise is created with ``torch.randn_like`` so it lives on the same
        device and has the same dtype as ``mu``; the model therefore works on
        CPU as well as on GPU instead of requiring CUDA at construction time.
        ``x`` is accepted for interface compatibility and is not used.
        """
        eps = torch.randn_like(mu)
        return mu + sigma * eps

    def __str__(self):
        return "VAE_ResNet18"

In [11]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [12]:
# Number of passes over the training data.
epochs = 5

# Build the VAE and move its parameters to the selected device.
model = VAE_ResNet18(latent_dim=100)
model = model.to(device)

# Adam over all model parameters.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Reconstruction term of the loss: binary cross-entropy between the decoder
# output and the input image.
reconLoss = nn.BCELoss()

In [13]:
# Per-image preprocessing: resize to 128 pixels. The Normalize step uses
# mean 0 and std 1, i.e. it leaves the values unchanged.
preprocess = transforms.Compose([
    transforms.Resize(128),
    transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
])

# Reproducible 80/20 split of the sampled image indices.
train_inx, test_inx = random_split(range(QuarkGluon_imgs.shape[0]),[0.8,0.2],generator=torch.Generator()
                                            .manual_seed(42))

train_data = QuarkGluonDataset(split_inx=train_inx,transform = preprocess)
test_data = QuarkGluonDataset(split_inx=test_inx,transform = preprocess)

# Shuffle only the training batches. The evaluation loader keeps a fixed
# order so the batch composition (and therefore the per-batch averaged
# metrics) does not change between runs.
train_dataloader = DataLoader(train_data,batch_size = 64, shuffle = True)
test_dataloader = DataLoader(test_data,batch_size = 64, shuffle = False)

In [14]:
def train(model, device, loader, optimizer, recon_loss_fn=None):
    """Run one training epoch and return the average per-batch loss.

    The minimised quantity is the negative ELBO: the KL term that the model
    stores in ``model.kl`` during its forward pass plus the reconstruction
    loss between the model output and the input.

    Args:
        model: Module whose forward pass returns a reconstruction and sets
            ``model.kl``.
        device: Device the input batches are moved to.
        loader: Iterable of ``(x, label)`` batches; the label is ignored.
        optimizer: Optimizer updating ``model``'s parameters.
        recon_loss_fn: Optional reconstruction loss ``f(x_hat, x)``. Defaults
            to the notebook-level ``reconLoss``.

    Returns:
        float: Mean of the per-batch losses over the epoch (also printed).

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    if recon_loss_fn is None:
        recon_loss_fn = reconLoss

    model.train()

    loss_accum = 0.0
    num_batches = 0
    for x, _ in tqdm(loader, desc="Iteration"):
        x = x.to(device)
        optimizer.zero_grad()
        x_hat = model(x)
        # -ELBO = KL + reconstruction loss; the KL term is non-negative.
        # NOTE(review): model.kl is whatever the forward pass stored; confirm
        # its reduction matches the scale intended for recon_loss_fn.
        loss = model.kl + recon_loss_fn(x_hat, x)
        loss.backward()
        optimizer.step()
        loss_accum += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train() received an empty loader")

    avg_loss = loss_accum / num_batches
    print('Average training loss: {}'.format(avg_loss))
    return avg_loss

In [15]:
def evaluate(model, device, loader, recon_loss_fn=None):
    """Compute average reconstruction loss and negative ELBO over a loader.

    The model is put in eval mode and gradients are disabled. Each batch goes
    through the model exactly once; the KL term is read from ``model.kl``,
    which the model sets during that forward pass.

    Args:
        model: Module whose forward pass returns a reconstruction and sets
            ``model.kl``.
        device: Device the input batches are moved to.
        loader: Iterable of ``(x, label)`` batches; the label is ignored.
        recon_loss_fn: Optional reconstruction loss ``f(x_hat, x)``. Defaults
            to the notebook-level ``reconLoss``.

    Returns:
        tuple[float, float]: ``(mean reconstruction loss, mean negative
        ELBO)``, each averaged over batches.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    if recon_loss_fn is None:
        recon_loss_fn = reconLoss

    model.eval()

    elbo_loss_accum = 0.0
    recon_loss_accum = 0.0
    num_batches = 0
    with torch.no_grad():
        for x, _ in loader:
            x = x.to(device)
            # A single full forward pass both reconstructs x and refreshes
            # model.kl; no separate encoder call is needed.
            x_hat = model(x)
            recon_loss = recon_loss_fn(x_hat, x)
            elbo_loss = recon_loss + model.kl
            elbo_loss_accum += elbo_loss.item()
            recon_loss_accum += recon_loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate() received an empty loader")

    return recon_loss_accum / num_batches, elbo_loss_accum / num_batches

In [16]:
def load_latest_model(model,optimizer,checkpoints_path):
    """Restore the most recent checkpoint for ``model`` if one exists.

    Checkpoint files are those in ``checkpoints_path`` whose name contains
    ``str(model)``. Among them the file with the highest numeric
    ``-<epoch>`` suffix (before the extension) is loaded; names without such
    a suffix rank lowest.

    Args:
        model: Module to restore; ``str(model)`` identifies its checkpoints.
        optimizer: Optimizer whose state is restored alongside the model.
        checkpoints_path: Directory holding the checkpoint files.

    Returns:
        tuple: ``(model, optimizer, starting_epoch)``. ``starting_epoch`` is
        the stored epoch plus one, or 1 when the directory is missing or has
        no matching checkpoint.
    """
    # Default when there is nothing to resume from.
    starting_epoch = 1
    if not os.path.isdir(checkpoints_path):
        return model, optimizer, starting_epoch

    model_name = str(model)
    candidates = [name for name in os.listdir(checkpoints_path) if model_name in name]
    if not candidates:
        return model, optimizer, starting_epoch

    def _epoch_of(file_name):
        # "<model>-<epoch>.pt" -> epoch; -1 when the suffix is not a number.
        suffix = os.path.splitext(file_name)[0].rsplit("-", 1)[-1]
        return int(suffix) if suffix.isdigit() else -1

    latest = max(candidates, key=_epoch_of)
    # Load onto the CPU so a checkpoint written on a GPU machine can be read
    # anywhere; load_state_dict then copies the values into the existing
    # parameters on whatever device they live on.
    checkpoint = torch.load(os.path.join(checkpoints_path, latest), map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    starting_epoch = checkpoint['epoch'] + 1
    return model, optimizer, starting_epoch

In [17]:
starting_epoch = 1

checkpoints_path = "../models"
# Create the checkpoint directory on a first run so that both the resume
# lookup and the torch.save below have somewhere to work.
os.makedirs(checkpoints_path, exist_ok=True)
# Resume from an existing checkpoint of this model, if there is one.
model,optimizer,starting_epoch = load_latest_model(model,optimizer,checkpoints_path)


for epoch in range(starting_epoch, epochs + 1):
    print("=====Epoch {}".format(epoch))
    print('Training...')
    train(model, device, train_dataloader, optimizer)
    
    print("Evaluating...")
    train_recon_loss, train_elbo_loss = evaluate(model,device,train_dataloader)
    
    print('Recon. losses: ',{'Train': train_recon_loss},
          '\nELBO(neg): ',{'Train': train_elbo_loss})
    
    # save checkpoint of current epoch
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, f"{checkpoints_path}/{str(model)}-{epoch}.pt")

    # Keep only the newest checkpoint: drop the previous epoch's file, but
    # only if it is actually there.
    previous_checkpoint = f"{checkpoints_path}/{str(model)}-{epoch-1}.pt"
    if epoch>1 and os.path.exists(previous_checkpoint):
        os.remove(previous_checkpoint)

print('\nFinished training!')
# Final evaluation on the held-out 20% split.
test_recon_loss, test_elbo_loss= evaluate(model,device,test_dataloader)
print('\nTest recon. loss: {}, Test ELBO(neg) {}'.format(test_recon_loss,test_elbo_loss))

=====Epoch 1
Training...


Iteration: 100%|██████████| 938/938 [04:23<00:00,  3.56it/s]


Average training loss: 1.1246144923823091
Evaluating...
Recon. losses:  {'Train': 0.0006282743218460126} 
ELBO(neg):  {'Train': 0.026484536273387083}
=====Epoch 2
Training...


Iteration: 100%|██████████| 938/938 [05:46<00:00,  2.71it/s]


Average training loss: 0.008417858551482239
Evaluating...
Recon. losses:  {'Train': 0.0005163971761572681} 
ELBO(neg):  {'Train': 0.12569585414854054}
=====Epoch 3
Training...


Iteration: 100%|██████████| 938/938 [05:44<00:00,  2.72it/s]


Average training loss: 0.0038019351641099446
Evaluating...
Recon. losses:  {'Train': 0.0004975309617557268} 
ELBO(neg):  {'Train': 0.07768016608421013}
=====Epoch 4
Training...


Iteration: 100%|██████████| 938/938 [05:49<00:00,  2.69it/s]


Average training loss: 0.0018241143105789891
Evaluating...
Recon. losses:  {'Train': 0.0004916695451839512} 
ELBO(neg):  {'Train': 0.034036114611831716}
=====Epoch 5
Training...


Iteration: 100%|██████████| 938/938 [05:44<00:00,  2.73it/s]


Average training loss: 0.0010141205617975294
Evaluating...
Recon. losses:  {'Train': 0.0004904386723962134} 
ELBO(neg):  {'Train': 0.0008961843993283435}

Finished training!

Test recon. loss: 0.0004902180527000034, Test ELBO(neg): 0.0008199529209610154


In [None]:
# Indices of the events that were not drawn into the sampled subset, in
# ascending order so the selection is deterministic.
unseen_data_inx = sorted(set(range(0, 139306)) - set(random_trunc))

del QuarkGluon_imgs # delete this set to save memory

QuarkGluon_test_imgs=np.array(QuarkGluon_dataset["X_jets"])[unseen_data_inx]

print("\nunseen data size: ",QuarkGluon_test_imgs.shape)

# QuarkGluonDataset reads the notebook-level `img_arrs` and calls tensor
# methods on each image, so it must be a torch tensor, not a NumPy array.
img_arrs = torch.Tensor(QuarkGluon_test_imgs)
# slice(None) selects every row of `img_arrs` without making another copy.
test_unseen_data = QuarkGluonDataset(split_inx=slice(None),transform = preprocess)
# Evaluation only: keep a fixed batch order.
unseen_test_dataloader = DataLoader(test_unseen_data,batch_size = 64, shuffle = False)

print("\nEvaluating on unseen data...")
unseenTest_recon_loss, unseenTest_elbo_loss = evaluate(model,device,unseen_test_dataloader)

# Label the numbers as unseen-data results, not training results.
print('\nRecon. losses: ',{'Unseen': unseenTest_recon_loss},
          '\nELBO(neg): ',{'Unseen': unseenTest_elbo_loss})