In [1]:
# !pip install torchmetrics

In [2]:
import torch 
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
import h5py
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import transforms
import torch.optim as optim
from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy

In [3]:
# clearing cuda cache memory
import gc
torch.cuda.empty_cache()
gc.collect()

0

In [4]:
os.listdir("../dataset")

['QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272',
 'QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272.test.snappy.parquet',
 'QCDToGGQQ_IMGjet_RH1all_jet0_run1_n47540',
 'QCDToGGQQ_IMGjet_RH1all_jet0_run1_n47540.test.snappy.parquet',
 'QCDToGGQQ_IMGjet_RH1all_jet0_run2_n55494',
 'QCDToGGQQ_IMGjet_RH1all_jet0_run2_n55494.test.snappy.parquet',
 'quark-gluon_data-set_n139306.hdf5',
 'SingleElectronPt50_IMGCROPS_n249k_RHv1.hdf5',
 'SinglePhotonPt50_IMGCROPS_n249k_RHv1.hdf5']

In [5]:
# import dataset
electron_dataset = h5py.File("../dataset/SingleElectronPt50_IMGCROPS_n249k_RHv1.hdf5","r")
electron_imgs=np.array(electron_dataset["X"])
electron_labels=np.array(electron_dataset["y"],dtype=np.int64)

photon_dataset = h5py.File("../dataset/SinglePhotonPt50_IMGCROPS_n249k_RHv1.hdf5","r")
photon_imgs=np.array(photon_dataset["X"])
photon_labels=np.array(photon_dataset["y"],dtype=np.int64)

In [6]:
img_arrs = torch.Tensor(np.vstack((photon_imgs,electron_imgs)))
labels = torch.Tensor(np.hstack((photon_labels,electron_labels))).to(torch.int64)

In [91]:
class SingleElectronPhotonDataset(Dataset):
    def __init__(self,split_inx, transform=None,target_transform= None):
        self.img_arrs_split = img_arrs[split_inx]
        self.labels_split = labels[split_inx]
        self.transform = transform
        self.target_transform = target_transform
    def __len__(self):
        return self.labels_split.shape[0]
    def __getitem__(self,idx):
        image=self.img_arrs_split[idx,:,:,:]
        # changing the dim of image to channels, height, width by transposing the
        # original image tensor.
        image = image.permute(2,1,0)
        label = self.labels_split[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image,label

In [92]:
class ResidualUnitEnc(nn.Module):
    def __init__(self,in_channels, out_channels, **kwargs):
        super().__init__(**kwargs)
        strides = 1
        if in_channels == out_channels:
            strides = 1
            pad = "same"
        else:
            strides = 2
            pad = 1
        self.relu = nn.ReLU(inplace=True)
        self.main_layers = nn.ModuleList([
            nn.Conv2d(in_channels,out_channels,3,strides,padding=pad,bias=False),
            nn.BatchNorm2d(out_channels),
            self.relu,
            nn.Conv2d(out_channels,out_channels,3,stride=1,padding="same",bias=False),
            nn.BatchNorm2d(out_channels)
        ])
        
        self.skip_layers =[]
        if strides > 1 :
            self.skip_layers = nn.ModuleList([
                nn.Conv2d(in_channels,out_channels,1,strides,padding=0,bias=False),
                nn.BatchNorm2d(out_channels)
            ])
    def forward(self,x):
        Z = x 
        for layer in self.main_layers:
            Z = layer(Z)
        skip_z = x
        for layer in self.skip_layers:
            skip_z= layer(skip_z)
        return self.relu(Z + skip_z)  
    
class ResidualUnitDec(nn.Module):
    def __init__(self,in_channels, out_channels, **kwargs):
        super().__init__(**kwargs)
        strides = 1
        if in_channels == out_channels:
            strides = 1
            pad = 1
            out_pad =0 
        else:
            strides = 2
            pad = 1
            out_pad  =1 
            
        self.relu = nn.ReLU(inplace=True)    
        self.main_layers = nn.ModuleList([
            nn.ConvTranspose2d(in_channels,out_channels,3,strides,padding=pad,output_padding=out_pad,bias=False),
            nn.BatchNorm2d(out_channels),
            self.relu,
            nn.ConvTranspose2d(out_channels,out_channels,3,stride=1,padding=pad,bias=False),
            nn.BatchNorm2d(out_channels)
        ])
        
        self.skip_layers =[]
        if strides > 1 :
            self.skip_layers = nn.ModuleList([
                nn.ConvTranspose2d(in_channels,out_channels,1,strides,padding=0,output_padding=out_pad,bias=False),
                nn.BatchNorm2d(out_channels)
            ])
            
    def forward(self,x):
        Z = x 
        for layer in self.main_layers:
            Z = layer(Z)
        skip_z = x
        for layer in self.skip_layers:
            skip_z= layer(skip_z)
        print("after res unit", Z.shape)
        return self.relu(Z + skip_z)  

In [93]:
class ResNet18Enc(nn.Module):
    def __init__(self,latent_dim):
        super(ResNet18Enc, self).__init__()

        self.latent_dim = latent_dim
        self.relu = nn.ReLU(inplace=True)
        self.bn1= nn.BatchNorm2d(64)
        
        self.conv1 = nn.Conv2d(2,64,3,stride=2,padding=1,bias=False)
        self.max_pool = nn.MaxPool2d(3,2,1,return_indices=True) 
        prev_filters = 64
        self.res_unit_list = nn.ModuleList([ResidualUnitEnc(prev_filters,prev_filters)])
        for filters in [64]*1+[128]*2 + [256]*2 +[512]*2:
            self.res_unit_list.append(ResidualUnitEnc(prev_filters,filters))
            prev_filters = filters 
            
        self.fc1 = nn.Linear(512, latent_dim)
        self.muFC = nn.Linear(latent_dim, latent_dim)
        self.sigmaFC = nn.Linear(latent_dim, latent_dim) 
        
    def forward(self,x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x,max_indices = self.max_pool(x)
        for res_unit in self.res_unit_list:
            x = res_unit(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)

        x = F.relu(self.fc1(x))
        mu =  self.muFC(x)
        sigma = torch.exp(self.sigmaFC(x))

        return x,mu,sigma,max_indices      
    
    def __str__(self):
        return "ResNet18Enc"
    
class ResNet18Dec(nn.Module):
    def __init__(self,latent_dim):
        super(ResNet18Dec, self).__init__()
        
        self.latent_dim = latent_dim
        self.relu = nn.ReLU(inplace=True)
        # removed batch normalization
        # self.bn1= nn.BatchNorm2d(3)
        
        self.fc = nn.Linear(latent_dim, 512)
        self.dec_res_units = nn.ModuleList([])
        prev_filters = 512
        for filters in [512]*2+ [256]*2+[128]*2+[64]*2:
            self.dec_res_units.append(ResidualUnitDec(prev_filters,filters))
            prev_filters = filters 
        self.max_unpool = nn.MaxUnpool2d(3,2)
        self.convTranspose1 = nn.ConvTranspose2d(64,2,3,stride=2,padding=1,output_padding =1,bias=False)

    def forward(self,z,mp_indices): 
        x = self.fc(z)
        x = x.view(z.size(0), 512, 1, 1)
        x = F.interpolate(x, scale_factor=4)
        for dec_res in self.dec_res_units:
            x = dec_res(x)
        print("before unpooling", x.shape)
        x = self.max_unpool(x,mp_indices)
        x = self.convTranspose1(x)
        # x = self.bn1(x)
        return torch.sigmoid(x)
    
    def __str__(self):
        return "ResNet18Dec"

In [94]:
class VAE_ResNet18(nn.Module):
    def __init__(self,latent_dim=100):
        super(VAE_ResNet18, self).__init__()
        
        self.encoder = ResNet18Enc(latent_dim)
        self.decoder = ResNet18Dec(latent_dim)
        
        self.N = torch.distributions.Normal(0, 1)
        self.N.loc = self.N.loc.cuda() # hack to get sampling on the GPU
        self.N.scale = self.N.scale.cuda()
        self.kl = 0
        
    def forward(self,x):
        x,mu,sigma,mp_indices = self.encoder(x)
        z = self.reparameterize(x,mu,sigma)
        
#         z_with_mpIndices = dict()
#         z_with_mpIndices.z = z 
#         z_with_mpIndices.mp_indices = mp_indices
        x = self.decoder(z,mp_indices)
        
        return x 

    def calculateKL(self,z,mu,sigma):
        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum()
    
    def reparameterize(self,x,mu,sigma):
        z = mu + sigma*self.N.sample(mu.shape)
        return z 
        
    def __str__(self):
        return "VAE_ResNet18"

In [95]:
device = torch.device("cuda:0" if torch.cuda.is_available() else torch.device("cpu"))

In [96]:
model = VAE_ResNet18(latent_dim=100).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
reconLoss = torch.nn.L1Loss() # can use MSELoss 

epochs = 25

In [97]:
preprocess = transforms.Compose([
    transforms.Resize(128), # multiply of 16
    transforms.Normalize(mean=[0.5, 0.5], std=[0.5, 0.5])
#     transforms.Normalize(mean=[0.5, 0.5,0.5], std=[0.5, 0.5,0.5]),
])

train_inx, valid_inx, test_inx = random_split(range(labels.shape[0]),[0.7,0.2,0.1],generator=torch.Generator()
                                            .manual_seed(42))

train_data = SingleElectronPhotonDataset(split_inx=train_inx,transform = preprocess)
valid_data = SingleElectronPhotonDataset(split_inx=valid_inx,transform = preprocess)
test_data = SingleElectronPhotonDataset(split_inx=test_inx,transform = preprocess)
# dataset = SingleElectronPhotonDataset()

train_dataloader = DataLoader(train_data,batch_size = 64, shuffle = True)
valid_dataloader = DataLoader(valid_data,batch_size = 64, shuffle = True)
test_dataloader = DataLoader(test_data,batch_size = 64, shuffle = True)

RuntimeError: [enforce fail at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\c10\core\impl\alloc_cpu.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 2855731200 bytes.

In [None]:
def train(model, device, loader, optimizer):
    model.train()

    loss_accum = 0
    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        x, _ = batch
        x = x.to(device)
        x_hat = model(x)
        optimizer.zero_grad()
        loss = reconLoss(x_hat,x) + model.kl
        loss.backward()
        optimizer.step()
        loss_accum += loss.item()

    print('Average training loss: {}'.format(loss_accum / (step + 1))) 

In [None]:
def evaluate(model, device, loader):
    model.eval()
    model.eval()
    elbo_loss_accum = 0.0
    recon_loss_accum = 0.0
    with torch.no_grad():
        for step, batch in enumerate(loader):
            x = x.to(device)
            # Encode data
            encoded_data = model.encoder(x)
            # Decode data
            x_hat = model(x)
            recon_loss = reconLoss(x_hat,x)  
            elbo_loss = recon_loss + model.kl
            elbo_loss_accum += elbo_loss
            recon_loss_accum += recon_loss
        return recon_loss_accum/(step+1), elbo_loss_accum/(step+1)

In [None]:
checkpoints_path = "../models"
checkpoints = os.listdir(checkpoints_path)
checkpoint_path = list(filter(lambda i : str(model) in i, checkpoints))

In [None]:
starting_epoch = 1
if len(checkpoint_path)>0:
    checkpoint = torch.load(f"{checkpoints_path}/{checkpoint_path[0]}")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    starting_epoch = checkpoint['epoch']+1

for epoch in range(starting_epoch, epochs + 1):
    print("=====Epoch {}".format(epoch))
    print('Training...')
    train(model, device, train_dataloader, optimizer)
    
    print("Evaluating...")
    train_recon_loss, train_elbo_loss = evaluate(model,device,train_dataloader)
    val_recon_loss, val_elbo_loss = evaluate(model,device,valid_dataloader)

    print('Recon. losses: ',{'Train': train_recon_loss, 'Validation': val_recon_loss},
          '\nELBO losses: ',{'Train': train_elbo_loss, 'Validation': val_elbo_loss})
    
    # save checkpoint of current epoch
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, f"{checkpoints_path}/{str(model)}-{epoch}.pt")

    if epoch>1:
        os.remove(f"{checkpoints_path}/{str(model)}-{epoch-1}.pt")

print('\nFinished training!')
print('\nTest recon. loss: {}, Test ELBO. loss'.format(evaluate(model,device,test_dataloader)[0],
                                                      evaluate(model,device,test_dataloader)[1]))