In [None]:
# Uncomment to install the metrics dependency into the kernel's environment:
# %pip install torchmetrics

In [None]:
import torch 
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
import h5py
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import transforms
import torch.optim as optim
from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy

In [None]:
# clearing cuda cache memory
import gc
# Collect unreachable Python objects first: any tensors they still hold are
# only returned to PyTorch's caching allocator once they are destroyed, and
# empty_cache() can release nothing but unoccupied cached blocks.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# List the entries of the ../dataset directory (displayed as the cell output).
os.listdir("../dataset")

In [None]:
# import dataset
# Each HDF5 file is opened in a `with` block so the file handle is closed as
# soon as the "X" (images) and "y" (labels) datasets have been copied into
# in-memory NumPy arrays. Labels are cast to int64 on load.
with h5py.File("../dataset/SingleElectronPt50_IMGCROPS_n249k_RHv1.hdf5", "r") as electron_dataset:
    electron_imgs = np.array(electron_dataset["X"])
    electron_labels = np.array(electron_dataset["y"], dtype=np.int64)

with h5py.File("../dataset/SinglePhotonPt50_IMGCROPS_n249k_RHv1.hdf5", "r") as photon_dataset:
    photon_imgs = np.array(photon_dataset["X"])
    photon_labels = np.array(photon_dataset["y"], dtype=np.int64)

In [None]:
# GoogLeNet
# Xception
# SENet

In [None]:
# Stack photons first, then electrons, along the sample axis.
# torch.from_numpy wraps the freshly stacked array without copying it again;
# .float() keeps the float32 dtype the legacy torch.Tensor(...) constructor
# produced (and is a no-op when the data is already float32).
img_arrs = torch.from_numpy(np.vstack((photon_imgs, electron_imgs))).float()
# Labels were loaded as int64, so they are wrapped directly instead of being
# routed through a float32 tensor and cast back.
labels = torch.from_numpy(np.hstack((photon_labels, electron_labels))).to(torch.int64)

In [None]:
class SingleElectronPhotonDataset(Dataset):
    """Map-style dataset over one split of the stacked image/label tensors.

    Parameters
    ----------
    split_inx :
        Indices of the samples that belong to this split. Anything accepted
        by tensor indexing works (list, tensor, slice, ...). An object that
        exposes both ``dataset`` and ``indices`` (e.g. a ``Subset`` returned
        by ``random_split``) is resolved to an explicit index tensor first.
    transform : callable, optional
        Applied to each image after it has been permuted to channels-first.
    target_transform : callable, optional
        Applied to each label.
    images, targets : torch.Tensor, optional
        Source tensors to slice. When omitted, the notebook-level ``img_arrs``
        and ``labels`` tensors are used, which is the original behaviour.
        ``images`` is indexed as a 4-D tensor; samples are assumed to be
        stored as (height, width, channels) - TODO confirm against the data.
    """

    def __init__(self, split_inx, transform=None, target_transform=None,
                 images=None, targets=None):
        if images is None:
            images = img_arrs
        if targets is None:
            targets = labels
        # A Subset stores positions into its underlying dataset; resolve them
        # to the actual index values once, instead of letting tensor indexing
        # iterate the Subset object element by element.
        if hasattr(split_inx, "indices") and hasattr(split_inx, "dataset"):
            split_inx = torch.as_tensor(
                [split_inx.dataset[i] for i in split_inx.indices],
                dtype=torch.long,
            )
        self.img_arrs_split = images[split_inx]
        self.labels_split = targets[split_inx]
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.labels_split.shape[0]

    def __getitem__(self, idx):
        image = self.img_arrs_split[idx, :, :, :]
        # Move the trailing channel axis to the front: (H, W, C) -> (C, H, W).
        # permute(2, 1, 0) would also swap height and width, i.e. return a
        # spatially transposed image.
        image = image.permute(2, 0, 1)
        label = self.labels_split[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [None]:
class ResidualUnit(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a shortcut.

    When ``in_channels`` equals ``out_channels`` the first convolution uses
    stride 1 with "same" padding and the shortcut is the identity. Otherwise
    the first convolution uses stride 2 (padding 1) and the shortcut becomes
    a stride-2 1x1 convolution followed by batch-norm, so both branches end
    up with matching shapes before they are summed.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__(**kwargs)
        changes_width = in_channels != out_channels
        strides = 2 if changes_width else 1
        pad = 1 if changes_width else "same"

        # One in-place ReLU instance is shared by the main branch and the
        # final activation.
        self.relu = nn.ReLU(inplace=True)

        first_conv = nn.Conv2d(in_channels, out_channels, 3, strides,
                               padding=pad, bias=False)
        first_norm = nn.BatchNorm2d(out_channels)
        second_conv = nn.Conv2d(out_channels, out_channels, 3, stride=1,
                                padding="same", bias=False)
        second_norm = nn.BatchNorm2d(out_channels)
        self.main_layers = nn.ModuleList(
            [first_conv, first_norm, self.relu, second_conv, second_norm]
        )

        if strides > 1:
            # Projection shortcut that matches the main branch's output shape.
            self.skip_layers = nn.ModuleList([
                nn.Conv2d(in_channels, out_channels, 1, strides,
                          padding=0, bias=False),
                nn.BatchNorm2d(out_channels),
            ])
        else:
            # Identity shortcut: nothing to apply.
            self.skip_layers = []

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        main_out = x
        for op in self.main_layers:
            main_out = op(main_out)

        shortcut = x
        for op in self.skip_layers:
            shortcut = op(shortcut)

        return self.relu(main_out + shortcut)
    
    
class ResNet34(nn.Module):
    """ResNet-34 style classifier for 2-channel images.

    Layout: 7x7 stride-2 convolution -> batch-norm -> ReLU -> adaptive max
    pool -> 16 ``ResidualUnit`` blocks (3 x 64, 4 x 128, 6 x 256, 3 x 512
    filters) -> global average pool -> linear layer.

    Parameters
    ----------
    num_classes : int
        Size of the output layer.

    ``forward`` returns raw, unnormalised logits of shape
    ``(batch, num_classes)``. ``nn.CrossEntropyLoss`` applies log-softmax
    itself, so feeding it softmax probabilities would normalise twice and
    flatten the gradients. Apply ``F.softmax(logits, dim=1)`` explicitly
    wherever probabilities are needed.
    """

    def __init__(self, num_classes):
        super(ResNet34, self).__init__()

        self.num_classes = num_classes
        self.relu = nn.ReLU(inplace=True)

        self.conv1 = nn.Conv2d(2, 64, 7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # NOTE(review): the pool output is pinned to 112x112. If the inputs
        # are 224x224, conv1 already produces 112x112 maps, so this stage
        # would not downsample - confirm that this is intended.
        self.max_pool = nn.AdaptiveMaxPool2d((112, 112)) # same padding
        prev_filters = 64
        self.res_unit_list = nn.ModuleList([ResidualUnit(prev_filters, prev_filters)])
        # A change in filter count makes the ResidualUnit downsample (stride 2).
        for filters in [64]*2 + [128]*4 + [256]*6 + [512]*3:
            self.res_unit_list.append(ResidualUnit(prev_filters, filters))
            prev_filters = filters
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.max_pool(x)
        for res_unit in self.res_unit_list:
            x = res_unit(x)
        # Global average pool to one value per channel, then flatten.
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        # Logits only: the loss is responsible for the (log-)softmax.
        return self.fc(x)

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Cross-entropy criterion used for the classification loss.
multicls_criterion = nn.CrossEntropyLoss()

In [None]:
# Number of passes over the training split.
epochs = 2

# Two-class network, moved to the selected device before the optimizer
# captures its parameters.
model = ResNet34(num_classes=2)
model = model.to(device)

# Adam over all model parameters with a learning rate of 1e-2.
optimizer = optim.Adam(params=model.parameters(), lr=1e-2)

In [None]:
# Per-sample preprocessing: resize to 224, then normalise each of the two
# channels with mean 0.5 and std 0.5.
preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.Normalize(mean=[0.5, 0.5], std=[0.5, 0.5]),
])

# Reproducible 70% / 20% / 10% split of the sample indices.
split_generator = torch.Generator().manual_seed(42)
train_inx, valid_inx, test_inx = random_split(
    range(labels.shape[0]), [0.7, 0.2, 0.1], generator=split_generator
)

train_data = SingleElectronPhotonDataset(split_inx=train_inx, transform=preprocess)
valid_data = SingleElectronPhotonDataset(split_inx=valid_inx, transform=preprocess)
test_data = SingleElectronPhotonDataset(split_inx=test_inx, transform=preprocess)

# Only the training loader is shuffled. Validation and test are read in a
# fixed order so evaluation passes are deterministic and skip the per-epoch
# random permutation.
train_dataloader = DataLoader(train_data, batch_size=8, shuffle=True)
valid_dataloader = DataLoader(valid_data, batch_size=8, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=8, shuffle=False)

In [None]:
def train(model, device, loader, optimizer):
    """Run one training epoch and report the mean batch loss.

    Parameters
    ----------
    model : torch.nn.Module
        Network to optimise; switched to training mode.
    device : torch.device
        Device each batch is moved to before the forward pass.
    loader : iterable
        Yields ``(inputs, labels)`` batches.
    optimizer : torch.optim.Optimizer
        Optimizer stepped once per batch.

    Returns
    -------
    float
        Average of the per-batch losses (also printed).

    Raises
    ------
    ValueError
        If ``loader`` yields no batches, so no average can be computed.

    The loss is computed with the notebook-level ``multicls_criterion``.
    """
    model.train()

    loss_accum = 0.0
    num_batches = 0
    for inputs, labels in tqdm(loader, desc="Iteration"):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        output = model(inputs)
        loss = multicls_criterion(output, labels)
        loss.backward()
        optimizer.step()

        loss_accum += loss.item()
        num_batches += 1

    # Counting batches explicitly avoids relying on a loop variable that is
    # never bound when the loader is empty.
    if num_batches == 0:
        raise ValueError("train() received a loader that yielded no batches")

    average_loss = loss_accum / num_batches
    print('Average training loss: {}'.format(average_loss))
    return average_loss

In [None]:
def evaluate(model, device, loader, evaluator="roauc", num_classes=2):
    """Score ``model`` on every batch of ``loader`` with a single metric.

    Parameters
    ----------
    model : torch.nn.Module
        Network to score; switched to evaluation mode.
    device : torch.device
        Device the inputs are moved to for the forward pass.
    loader : iterable
        Yields ``(inputs, labels)`` batches of tensors.
    evaluator : {"roauc", "acc"}
        ``"roauc"`` selects macro-averaged ``MulticlassAUROC``; ``"acc"``
        selects macro-averaged ``MulticlassAccuracy``.
    num_classes : int, default 2
        Number of classes passed to the metric.

    Returns
    -------
    float
        Metric value computed over all samples of ``loader``.

    Raises
    ------
    ValueError
        If ``evaluator`` is not a supported name, or ``loader`` yields no
        batches.
    """
    # Validate before doing any work: an unknown name would otherwise only
    # surface after the whole inference pass, as an unbound `metric`.
    if evaluator not in ("roauc", "acc"):
        raise ValueError(
            "evaluator must be 'roauc' or 'acc', got {!r}".format(evaluator)
        )

    model.eval()

    pred_batches = []
    target_batches = []
    with torch.no_grad():
        for inputs, labels in loader:
            output = model(inputs.to(device))
            # Keep results as CPU tensors; the metric is evaluated on the CPU
            # and the labels are never needed on the device.
            pred_batches.append(output.detach().cpu())
            target_batches.append(labels.cpu())

    if not pred_batches:
        raise ValueError("evaluate() received a loader that yielded no batches")

    preds = torch.cat(pred_batches).float()
    targets = torch.cat(target_batches).to(torch.int64)

    if evaluator == "roauc":
        metric = MulticlassAUROC(num_classes=num_classes, average="macro", thresholds=None)
    else:
        metric = MulticlassAccuracy(num_classes=num_classes, average="macro")
    return metric(preds, targets).item()

In [None]:
# One [accuracy, ROC-AUC] pair is appended to each list per epoch.
train_curves, valid_curves = [], []

for epoch in range(1, epochs + 1):
    print("=====Epoch {}".format(epoch))
    print('Training...')
    train(model, device, train_dataloader, optimizer)

    print("Evaluating...")
    # ROC-AUC first (the default evaluator), then accuracy, on both splits.
    train_perf_roauc = evaluate(model, device, train_dataloader)
    valid_perf_roauc = evaluate(model, device, valid_dataloader)
    train_perf_acc = evaluate(model, device, train_dataloader, evaluator="acc")
    valid_perf_acc = evaluate(model, device, valid_dataloader, evaluator="acc")

    train_curves.append([train_perf_acc, train_perf_roauc])
    valid_curves.append([valid_perf_acc, valid_perf_roauc])

    roauc_scores = {'Train': train_perf_roauc, 'Validation': valid_perf_roauc}
    acc_scores = {'Train': train_perf_acc, 'Validation': valid_perf_acc}
    print('ROAUC scores: ', roauc_scores, '\nAccuracy scores: ', acc_scores)

print('\nFinished training!')
# The held-out test split is scored once, after all epochs.
test_perf_roauc = evaluate(model, device, test_dataloader)
print('\nROAUC Test score: {}'.format(test_perf_roauc))