In [1]:
# !pip install torchmetrics

In [2]:
import torch 
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
import h5py
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import transforms
import torch.optim as optim
from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# clearing cuda cache memory
import gc
torch.cuda.empty_cache()
gc.collect()

0

In [4]:
os.listdir("../dataset")

['SingleElectronPt50_IMGCROPS_n249k_RHv1.hdf5',
 'SinglePhotonPt50_IMGCROPS_n249k_RHv1.hdf5']

In [5]:
# import dataset
electron_dataset = h5py.File("../dataset/SingleElectronPt50_IMGCROPS_n249k_RHv1.hdf5","r")
electron_imgs=np.array(electron_dataset["X"])
electron_labels=np.array(electron_dataset["y"],dtype=np.int64)

photon_dataset = h5py.File("../dataset/SinglePhotonPt50_IMGCROPS_n249k_RHv1.hdf5","r")
photon_imgs=np.array(photon_dataset["X"])
photon_labels=np.array(photon_dataset["y"],dtype=np.int64)

In [6]:
# GoogLeNet
# Xception
# SENet

In [7]:
img_arrs = torch.Tensor(np.vstack((photon_imgs,electron_imgs)))
labels = torch.Tensor(np.hstack((photon_labels,electron_labels))).to(torch.int64)

In [8]:
class SingleElectronPhotonDataset(Dataset):
    def __init__(self,split_inx, transform=None,target_transform= None):
        self.img_arrs_split = img_arrs[split_inx]
        self.labels_split = labels[split_inx]
        self.transform = transform
        self.target_transform = target_transform
    def __len__(self):
        return self.labels_split.shape[0]
    def __getitem__(self,idx):
        image=self.img_arrs_split[idx,:,:,:]
        # changing the dim of image to channels, height, width by transposing the
        # original image tensor.
        image = image.permute(2,1,0)
        label = self.labels_split[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image,label

In [9]:
class ResidualUnit(nn.Module):
    def __init__(self,in_channels, out_channels, **kwargs):
        """ Constructor
        Args:
            isSkipLearnable: learnalbe weights are associated with skip connection.
        """
        super().__init__(**kwargs)
        strides = 1
        if in_channels == out_channels:
            strides = 1
            pad = "same"
        else:
            strides = 2
            pad = 1
        self.relu = nn.ReLU(inplace=True)
        self.main_layers = nn.ModuleList([
            nn.Conv2d(in_channels,out_channels,3,strides,padding=pad,bias=False),
            nn.BatchNorm2d(out_channels),
            self.relu,
            nn.Conv2d(out_channels,out_channels,3,stride=1,padding="same",bias=False),
            nn.BatchNorm2d(out_channels)
        ])
        
        self.skip_layers =[]
        if strides > 1 :
            self.skip_layers = nn.ModuleList([
                nn.Conv2d(in_channels,out_channels,1,strides,padding=0,bias=False),
                nn.BatchNorm2d(out_channels)
            ])
    def forward(self,x):
        Z = x 
        for layer in self.main_layers:
            Z = layer(Z)
        skip_z = x
        for layer in self.skip_layers:
            skip_z= layer(skip_z)
        return self.relu(Z + skip_z)  

class ResNet18(nn.Module):
    def __init__(self,num_classes):
        super(ResNet18, self).__init__()

        self.num_classes = num_classes
        self.relu = nn.ReLU(inplace=True)
        
        self.conv1 = nn.Conv2d(2,64,3,stride=2,padding=3,bias=False)
        self.bn1= nn.BatchNorm2d(64)
        self.max_pool = nn.MaxPool2d(1,1) 
        prev_filters = 64
        self.res_unit_list = nn.ModuleList([ResidualUnit(prev_filters,prev_filters)])
        for filters in [64]*1+[128]*2 + [256]*2 +[512]*2:
            self.res_unit_list.append(ResidualUnit(prev_filters,filters))
            prev_filters = filters 
#         self.fc = nn.LazyLinear(num_classes)
        self.fc = nn.Linear(512, num_classes)
        
    def forward(self,x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.max_pool(x)
        for res_unit in self.res_unit_list:
            x = res_unit(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return F.softmax(x,dim=1)
    
    def __str__(self):
        return "ResNet18"
    
    
class ResNet34(nn.Module):
    def __init__(self,num_classes):
        super(ResNet34, self).__init__()

        self.num_classes = num_classes
        self.relu = nn.ReLU(inplace=True)
        
        self.conv1 = nn.Conv2d(2,64,7,stride=2,padding=3,bias=False)
        self.bn1= nn.BatchNorm2d(64)
        self.max_pool = nn.MaxPool2d(1,1) 
        prev_filters = 64
        self.res_unit_list = nn.ModuleList([ResidualUnit(prev_filters,prev_filters)])
        for filters in [64]*2+[128]*4 + [256]*6 +[512]*3:
            self.res_unit_list.append(ResidualUnit(prev_filters,filters))
            prev_filters = filters 
#         self.fc = nn.LazyLinear(num_classes)
        self.fc = nn.Linear(512, num_classes)
        
    def forward(self,x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.max_pool(x)
        for res_unit in self.res_unit_list:
            x = res_unit(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return F.softmax(x,dim=1)
    
    def __str__(self):
        return "ResNet34"

In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else torch.device("cpu"))

In [11]:
model = ResNet18(num_classes=2).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
multicls_criterion = torch.nn.CrossEntropyLoss()

epochs = 20

In [12]:
preprocess = transforms.Compose([
#     transforms.Resize(224),
    transforms.Resize(32), # multiply of 16
    transforms.Normalize(mean=[0.5, 0.5], std=[0.5, 0.5]),
])

train_inx, valid_inx, test_inx = random_split(range(labels.shape[0]),[0.7,0.2,0.1],generator=torch.Generator()
                                            .manual_seed(42))

train_data = SingleElectronPhotonDataset(split_inx=train_inx,transform = preprocess)
valid_data = SingleElectronPhotonDataset(split_inx=valid_inx,transform = preprocess)
test_data = SingleElectronPhotonDataset(split_inx=test_inx,transform = preprocess)
# dataset = SingleElectronPhotonDataset()

train_dataloader = DataLoader(train_data,batch_size = 64, shuffle = True)
valid_dataloader = DataLoader(valid_data,batch_size = 64, shuffle = True)
test_dataloader = DataLoader(test_data,batch_size = 64, shuffle = True)

In [13]:
def train(model, device, loader, optimizer):
    model.train()

    loss_accum = 0
    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        output = model(inputs)
        loss= 0 
        optimizer.zero_grad()
        loss += multicls_criterion(output, labels)
        loss.backward()
        optimizer.step()

        loss_accum += loss.item()

    print('Average training loss: {}'.format(loss_accum / (step + 1))) 

In [14]:
def evaluate(model, device, loader,evaluator= "roauc"):
    model.eval()
    
    preds_list = []
    target_list = []
    for step, batch in enumerate(loader):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            output = model(inputs)
            preds_list.extend(output.tolist())
        target_list += batch[1].tolist()
    if evaluator == "roauc":   
        metric = MulticlassAUROC(num_classes=2, average="macro", thresholds=None)
    if evaluator == "acc":
        metric = MulticlassAccuracy(num_classes=2, average="macro")
    # print("AUC-ROC metric score : ",metric(torch.Tensor(preds_list),torch.Tensor(target_list)).item())
    return metric(torch.Tensor(preds_list),torch.Tensor(target_list).to(torch.int64)).item()

In [15]:
checkpoints_path = "../models"
checkpoints = os.listdir(checkpoints_path)
checkpoint_path = list(filter(lambda i : str(model) in i, checkpoints))

In [16]:
train_curves = []
valid_curves = []

starting_epoch = 1
if len(checkpoint_path)>0:
    checkpoint = torch.load(f"{checkpoints_path}/{checkpoint_path[0]}")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    starting_epoch = checkpoint['epoch']+1

for epoch in range(starting_epoch, epochs + 1):
    print("=====Epoch {}".format(epoch))
    print('Training...')
    train(model, device, train_dataloader, optimizer)
    
    print("Evaluating...")
    train_perf_roauc = evaluate(model,device,train_dataloader)
    valid_perf_roauc = evaluate(model,device,valid_dataloader)
    test_perf_roauc = evaluate(model,device,test_dataloader)
    train_perf_acc = evaluate(model,device,train_dataloader, evaluator = "acc")
    valid_perf_acc = evaluate(model,device,valid_dataloader,evaluator = "acc")
    test_perf_acc = evaluate(model,device,test_dataloader,evaluator = "acc")
    
    train_curves.append([train_perf_acc,train_perf_roauc])
    valid_curves.append([valid_perf_acc,valid_perf_roauc])
    
    print('ROAUC scores: ',{'Train': train_perf_roauc, 'Validation': valid_perf_roauc, "Test": test_perf_roauc}, '\nAccuracy scores: ',
         {'Train': train_perf_acc, 'Validation': valid_perf_acc, "Test": test_perf_acc})
    
    # save checkpoint of current epoch
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, f"{checkpoints_path}/{str(model)}-{epoch}.pt")
    
    # delete checkpoint of previous epoch
    if epoch>1:
        os.remove(f"{checkpoints_path}/{str(model)}-{epoch-1}.pt")

print('\nFinished training!')
print('\nROAUC Test score: {}'.format(evaluate(model,device,test_dataloader)))

=====Epoch 3
Training...


Iteration: 100%|██████████| 5447/5447 [05:22<00:00, 16.87it/s]


Average training loss: 0.5891124870890059
Evaluating...
ROAUC scores:  {'Train': 0.6952402591705322, 'Validation': 0.6961061954498291, 'Test': 0.6968065500259399} 
Accuracy scores:  {'Train': 0.515848696231842, 'Validation': 0.5164598822593689, 'Test': 0.5161611437797546}
=====Epoch 4
Training...


Iteration: 100%|██████████| 5447/5447 [05:36<00:00, 16.20it/s]


Average training loss: 0.5833955890151455
Evaluating...
ROAUC scores:  {'Train': 0.5817879438400269, 'Validation': 0.5790382623672485, 'Test': 0.5763914585113525} 
Accuracy scores:  {'Train': 0.5199189782142639, 'Validation': 0.5207341313362122, 'Test': 0.5201530456542969}
=====Epoch 5
Training...


Iteration: 100%|██████████| 5447/5447 [05:27<00:00, 16.65it/s]


Average training loss: 0.5788430177675205
Evaluating...
ROAUC scores:  {'Train': 0.45132356882095337, 'Validation': 0.45053091645240784, 'Test': 0.44606348872184753} 
Accuracy scores:  {'Train': 0.5154888033866882, 'Validation': 0.5158680081367493, 'Test': 0.5160356760025024}
=====Epoch 6
Training...


Iteration: 100%|██████████| 5447/5447 [05:26<00:00, 16.68it/s]


Average training loss: 0.5756558832147044
Evaluating...
ROAUC scores:  {'Train': 0.7816985845565796, 'Validation': 0.7787708044052124, 'Test': 0.7801327705383301} 
Accuracy scores:  {'Train': 0.72030109167099, 'Validation': 0.716518759727478, 'Test': 0.7181563377380371}
=====Epoch 7
Training...


Iteration: 100%|██████████| 5447/5447 [05:29<00:00, 16.56it/s]


Average training loss: 0.5732525385244628
Evaluating...
ROAUC scores:  {'Train': 0.7827376127243042, 'Validation': 0.7809115648269653, 'Test': 0.7817345857620239} 
Accuracy scores:  {'Train': 0.7186840176582336, 'Validation': 0.717162013053894, 'Test': 0.718307614326477}
=====Epoch 8
Training...


Iteration: 100%|██████████| 5447/5447 [05:30<00:00, 16.50it/s]


Average training loss: 0.5713702193559671
Evaluating...
ROAUC scores:  {'Train': 0.769572913646698, 'Validation': 0.7648710012435913, 'Test': 0.7638487815856934} 
Accuracy scores:  {'Train': 0.6571156978607178, 'Validation': 0.6563001871109009, 'Test': 0.654026210308075}
=====Epoch 9
Training...


Iteration: 100%|██████████| 5447/5447 [05:28<00:00, 16.60it/s]


Average training loss: 0.5691806004258368
Evaluating...
ROAUC scores:  {'Train': 0.787178099155426, 'Validation': 0.7846541404724121, 'Test': 0.7845445275306702} 
Accuracy scores:  {'Train': 0.7262298464775085, 'Validation': 0.7234315872192383, 'Test': 0.7229331731796265}
=====Epoch 10
Training...


Iteration: 100%|██████████| 5447/5447 [05:30<00:00, 16.49it/s]


Average training loss: 0.5682199489901826
Evaluating...
ROAUC scores:  {'Train': 0.7734286189079285, 'Validation': 0.7712681293487549, 'Test': 0.7717280983924866} 
Accuracy scores:  {'Train': 0.7140264511108398, 'Validation': 0.7115604877471924, 'Test': 0.7126092314720154}
=====Epoch 11
Training...


Iteration: 100%|██████████| 5447/5447 [05:30<00:00, 16.48it/s]


Average training loss: 0.5667472504269471
Evaluating...
ROAUC scores:  {'Train': 0.7779282331466675, 'Validation': 0.7721612453460693, 'Test': 0.7705116868019104} 
Accuracy scores:  {'Train': 0.6911587715148926, 'Validation': 0.6878855228424072, 'Test': 0.6867059469223022}
=====Epoch 12
Training...


Iteration: 100%|██████████| 5447/5447 [05:31<00:00, 16.43it/s]


Average training loss: 0.5656084841285164
Evaluating...
ROAUC scores:  {'Train': 0.7940727472305298, 'Validation': 0.7908050417900085, 'Test': 0.7909234762191772} 
Accuracy scores:  {'Train': 0.7306987643241882, 'Validation': 0.7266595959663391, 'Test': 0.725690484046936}
=====Epoch 13
Training...


Iteration: 100%|██████████| 5447/5447 [05:29<00:00, 16.55it/s]


Average training loss: 0.5645784845708028
Evaluating...
ROAUC scores:  {'Train': 0.7992649078369141, 'Validation': 0.7935996055603027, 'Test': 0.7932800054550171} 
Accuracy scores:  {'Train': 0.7347122430801392, 'Validation': 0.7285760641098022, 'Test': 0.7269856929779053}
=====Epoch 14
Training...


Iteration: 100%|██████████| 5447/5447 [05:30<00:00, 16.50it/s]


Average training loss: 0.5636255293236493
Evaluating...
ROAUC scores:  {'Train': 0.6935126185417175, 'Validation': 0.6834774017333984, 'Test': 0.6842156648635864} 
Accuracy scores:  {'Train': 0.5668786764144897, 'Validation': 0.5633842945098877, 'Test': 0.562129020690918}
=====Epoch 15
Training...


Iteration: 100%|██████████| 5447/5447 [05:27<00:00, 16.65it/s]


Average training loss: 0.5627577040903236
Evaluating...
ROAUC scores:  {'Train': 0.800524115562439, 'Validation': 0.7935436964035034, 'Test': 0.793082594871521} 
Accuracy scores:  {'Train': 0.734512984752655, 'Validation': 0.7273541688919067, 'Test': 0.7265720367431641}
=====Epoch 16
Training...


Iteration: 100%|██████████| 5447/5447 [05:27<00:00, 16.62it/s]


Average training loss: 0.5615552340338202
Evaluating...
ROAUC scores:  {'Train': 0.7384127378463745, 'Validation': 0.72806316614151, 'Test': 0.7294089794158936} 
Accuracy scores:  {'Train': 0.6226247549057007, 'Validation': 0.6169329881668091, 'Test': 0.6170749068260193}
=====Epoch 17
Training...


Iteration: 100%|██████████| 5447/5447 [05:29<00:00, 16.55it/s]


Average training loss: 0.5611241156789651
Evaluating...
ROAUC scores:  {'Train': 0.8055920600891113, 'Validation': 0.7976040840148926, 'Test': 0.7973026633262634} 
Accuracy scores:  {'Train': 0.7411123514175415, 'Validation': 0.7320608496665955, 'Test': 0.7336180210113525}
=====Epoch 18
Training...


Iteration: 100%|██████████| 5447/5447 [05:26<00:00, 16.66it/s]


Average training loss: 0.5594348060837443
Evaluating...
ROAUC scores:  {'Train': 0.7960606813430786, 'Validation': 0.7886292934417725, 'Test': 0.7897617816925049} 
Accuracy scores:  {'Train': 0.7364866733551025, 'Validation': 0.7292419672012329, 'Test': 0.7307008504867554}
=====Epoch 19
Training...


Iteration: 100%|██████████| 5447/5447 [05:26<00:00, 16.66it/s]


Average training loss: 0.5590682475965615
Evaluating...
ROAUC scores:  {'Train': 0.7948096990585327, 'Validation': 0.7880959510803223, 'Test': 0.7884796857833862} 
Accuracy scores:  {'Train': 0.7308984994888306, 'Validation': 0.7238062620162964, 'Test': 0.7233740091323853}
=====Epoch 20
Training...


Iteration: 100%|██████████| 5447/5447 [05:27<00:00, 16.61it/s]


Average training loss: 0.5578431189169638
Evaluating...
ROAUC scores:  {'Train': 0.8067057132720947, 'Validation': 0.7964015007019043, 'Test': 0.7967956066131592} 
Accuracy scores:  {'Train': 0.7422732710838318, 'Validation': 0.7326356172561646, 'Test': 0.7332702875137329}

Finished training!

ROAUC Test score: 0.7967956066131592
