In [None]:
# Widen the notebook container to 98% of the browser window.
# `display`/`HTML` are imported from the public IPython.display module; importing
# `display` from IPython.core.display is deprecated since IPython 7.14.
from IPython.display import display, HTML
display(HTML("<style>.container { width:98% !important; }</style>"))

In [None]:
import numpy as np
from tqdm import tqdm as tqdm
#from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import os
import optparse
import os.path as osp
import math
import torch_geometric
import torch
import sys
from torch_geometric.data import Data
import glob


In [None]:
from torch.utils.data import Dataset

class ClassificationDataset(Dataset):
    """In-memory dataset of zero-padded per-event feature tensors sharing one label.

    Every ``data*.pt`` file found under ``datapath`` is read with ``torch.load``.
    For each entry, its ``clus2d_feat`` tensor is truncated / zero-padded to
    ``numLC`` rows and paired with a label of 1.0 (``signal`` truthy) or 0.0.

    Args:
        datapath: Directory that contains the ``data*.pt`` files.
        numLC: Number of rows kept per event; shorter events are zero-padded.
        signal: Truthy -> every event is labelled 1, falsy -> labelled 0.
        min_layers: Events with fewer distinct values than this in column
            ``layer_column`` of ``clus2d_feat`` are skipped.
        layer_column: Column of ``clus2d_feat`` whose distinct values are
            counted (presumably the layer index -- TODO confirm with the
            producer of the ntuples).

    Attributes:
        data: Tensor of shape ``(n_events, numLC, n_features)``.
        labels: Double tensor of shape ``(n_events, 1)``.

    Raises:
        RuntimeError: If no file matches the pattern or no event passes the
            selection.
    """

    def __init__(self, datapath, numLC, signal, min_layers=15, layer_column=5):
        self.datapath_ = datapath
        self.numLC_ = numLC
        self.signal_ = signal
        self.min_layers_ = min_layers
        self.layer_column_ = layer_column
        # Kept as an attribute for backward compatibility: (list of tensors, list of labels).
        self.process = self.load_datafiles(self.datapath_)
        tensors, labels = self.process
        if not tensors:
            # torch.stack on an empty list fails with an unhelpful message.
            raise RuntimeError(
                "no event under %r has at least %d distinct values in column %d"
                % (datapath, min_layers, layer_column)
            )
        self.data = torch.stack(tensors, dim=0)
        self.labels = torch.stack(labels, dim=0)

    def load_datafiles(self, filepath):
        """Load and preprocess every ``data*.pt`` file in ``filepath``.

        ``filepath`` is a directory such as
        ``/grid_mnt/data__data.polcms/cms/sghosh/NEWPID_DATA/ntup_pho_frac0p8/``
        (a trailing separator is optional).

        Returns:
            tuple: ``(output_tensor_list, label)`` -- two parallel lists with one
            padded feature tensor and one shape-``(1,)`` double label per
            selected event.

        Raises:
            RuntimeError: If no file matches ``data*.pt`` in ``filepath``.
        """
        # glob returns files in arbitrary order; sorting keeps the event order
        # (and any index-based split built on top of it) reproducible.
        filelist = sorted(glob.glob(os.path.join(filepath, 'data*.pt')))
        if not filelist:
            raise RuntimeError("no 'data*.pt' files found under %r" % (filepath,))

        label_value = 1.0 if self.signal_ else 0.0
        output_tensor_list = []
        label = []
        for fname in tqdm(filelist):
            # NOTE: torch.load unpickles arbitrary objects -- only use trusted files.
            for event in torch.load(fname):
                feats = event.clus2d_feat
                nlay = len(np.unique(feats[:, self.layer_column_]))
                if nlay < self.min_layers_:
                    continue
                output_tensor_list.append(self.transform_data(feats, self.numLC_))
                label.append(torch.full((1,), label_value, dtype=torch.double))

        return output_tensor_list, label

    def transform_data(self, tensor, nLC):
        """Return the first ``nLC`` rows of ``tensor``, zero-padded to ``nLC`` rows.

        The result is allocated with ``torch.zeros`` and therefore has torch's
        default dtype, whatever the dtype of ``tensor``.
        """
        target = torch.zeros(nLC, tensor.shape[1])
        source = tensor[:nLC, :]
        target[:source.shape[0], :] = source
        return target

    def __len__(self):
        """Number of selected events."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        return self.data[idx], self.labels[idx]

In [None]:
# Locations of the two samples; the "pho" sample is loaded as signal, the "pi" sample as background.
pho_path = '/grid_mnt/data__data.polcms/cms/sghosh/NEWPID_DATA/ntup_pho_frac0p8/'
pi_path = '/grid_mnt/data__data.polcms/cms/sghosh/NEWPID_DATA/ntup_pi_frac0p8/'

# Each event is truncated / zero-padded to 150 rows.
pho_dataset = ClassificationDataset(pho_path, numLC=150, signal=True)
pi_dataset = ClassificationDataset(pi_path, numLC=150, signal=False)

print("lenpho:", len(pho_dataset), " lenpi:", len(pi_dataset))

In [None]:
# Use the same number of events from both classes (the size of the smaller sample),
# then cut the shared index range 80% / 20% into training and testing indices.
nsamp = min(len(pho_dataset), len(pi_dataset))
idxs = np.arange(nsamp)
tridx, tsidx = np.split(idxs, [int(0.8 * nsamp)])

# The same index ranges are applied to both samples, so each split is class-balanced.
photr, phots = (torch.utils.data.Subset(pho_dataset, part) for part in (tridx, tsidx))
pitr, pits = (torch.utils.data.Subset(pi_dataset, part) for part in (tridx, tsidx))

ntrainbatch = 200
trainloader = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset([photr, pitr]),
    batch_size=ntrainbatch,
    shuffle=True,
)
testloader = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset([phots, pits]),
    batch_size=ntrainbatch,
    shuffle=True,
)

# Total number of training samples seen per epoch (both classes together).
epoch_size = len(trainloader.dataset)


In [None]:
## Define model
## refer https://github.com/ludovicobuizza/HAR-Transformer/blob/main/src/transformer/transformer.py for batchnorm
import torch.nn as nn
import torch.nn.functional as F

class MHA_model(nn.Module):
    """Transformer-encoder classifier over fixed-length sequences.

    Pipeline: per-element MLP embedding -> stack of transformer encoder layers
    -> flatten all sequence positions -> two-layer head -> log-softmax.

    Args:
        input_dim: Number of features per sequence element.
        dim_emb: Embedding width used throughout the encoder.
        n_heads: Number of attention heads per encoder layer.
        num_enclayers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each encoder layer's feed-forward part.
        n_seq: Sequence length; the head's first linear layer expects
            ``n_seq * dim_emb`` inputs, so inputs must have exactly this length.
        output_dim: Number of output classes.
    """

    def __init__(
        self,
        input_dim = 6,
        dim_emb = 64,
        n_heads = 16,
        num_enclayers = 6,
        dim_feedforward = 64,
        n_seq = 150,
        output_dim = 2,
    ):
        super().__init__()

        # Shared hyper-parameters of every dropout / LeakyReLU in the network.
        p_drop = 0.2
        slope = 0.4

        # Element-wise embedding: input_dim -> dim_emb -> dim_emb.
        self.emb_inp = nn.Sequential(
            nn.Linear(input_dim, dim_emb),
            nn.Dropout(p_drop),
            nn.LeakyReLU(negative_slope=slope),
            nn.Linear(dim_emb, dim_emb),
            nn.Dropout(p_drop),
            nn.LeakyReLU(negative_slope=slope),
        )

        # batch_first=True: tensors are laid out as (batch, seq, feat).
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim_emb,
                nhead=n_heads,
                dim_feedforward=dim_feedforward,
                dropout=p_drop,
                activation=nn.LeakyReLU(negative_slope=slope),
                batch_first=True,
            ),
            num_layers=num_enclayers,
        )

        # Classification head applied to the flattened encoder output.
        half_emb = dim_emb // 2
        self.output_layer = nn.Sequential(
            nn.Linear(n_seq * dim_emb, half_emb),
            nn.LeakyReLU(negative_slope=slope),
            nn.Dropout(p_drop),
            nn.Linear(half_emb, output_dim),
        )

    def forward(self, X):
        """Map ``X`` of shape (batch, n_seq, input_dim) to log-probabilities (batch, output_dim)."""
        embedded = self.emb_inp(X)                         # (batch, seq, dim_emb)
        encoded = self.transformer_encoder(embedded)       # (batch, seq, dim_emb)
        flat = encoded.reshape(encoded.shape[0], -1)       # (batch, seq * dim_emb)
        logits = self.output_layer(flat)                   # (batch, output_dim)
        return F.log_softmax(logits, dim=-1)

In [None]:
### define training and testing functions
from lr_modulation import *
# Device selection: keep the original preference for the second GPU ('cuda:1'),
# but fall back to the first GPU or the CPU so the notebook also runs on hosts
# that do not have two GPUs.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# NOTE(review): dim_emb == n_heads here, i.e. each attention head works on a
# single embedding feature -- confirm this is intended.
model = MHA_model(n_seq=150,
        input_dim = 6,
        dim_emb = 16,
        n_heads = 16,
        num_enclayers = 6,
        dim_feedforward = 64,
        output_dim = 2)
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)
# The scheduler is given the batch size and the number of training samples per epoch.
scheduler = CyclicLRWithRestarts(optimizer, ntrainbatch, epoch_size, restart_period=10, t_mult=1.2, policy="cosine")
#scheduler = CyclicLRWithRestarts(optimizer, 200, 200, restart_period=30, t_mult=1.2, policy="cosine")

#lossfunction = nn.BCELoss()




def train(epoch):
    """Run one optimisation pass over ``trainloader``.

    Relies on the notebook-level ``model``, ``optimizer``, ``scheduler``,
    ``trainloader`` and ``device``.

    Args:
        epoch: Epoch index. Accepted for the caller's convenience; not used here.

    Returns:
        tuple: ``(mean batch loss, accuracy)`` where accuracy is a 0-d numpy
        array holding the fraction of correctly classified samples.
    """
    model.train()
    loss = []
    true_preds, num_preds = 0, 0
    scheduler.step()  # once per epoch, before the first batch
    for data in tqdm(trainloader):
        datax, label = data[0].to(device), data[1].to(device)
        # Labels come in as (batch, 1). reshape(-1) always yields (batch,),
        # whereas a bare squeeze() would turn a final batch of size 1 into a
        # 0-d tensor and break nll_loss / label.shape[0].
        label = label.to(torch.long).reshape(-1)

        optimizer.zero_grad()
        result = model(datax)  # log-probabilities, as expected by nll_loss
        lossc = F.nll_loss(result, label)
        loss.append(lossc.item())

        lossc.backward()
        optimizer.step()
        scheduler.batch_step()  # once per optimiser step

        pred_labels = torch.argmax(result, dim=-1)
        true_preds += (pred_labels == label).sum().item()
        num_preds += label.shape[0]

    acc = true_preds / num_preds
    print( 'batches for train:',len(loss))
    print('train loss:',np.mean(np.array(loss)))
    print('train acc:',np.mean(np.array(acc)))
    return np.mean(np.array(loss)), np.array(acc)



def test(epoch):
    model.eval()
    loss = []
    true_preds, num_preds = 0., 0.
    for data in tqdm(testloader): ### change to loaderDNN for DNN
            datax, label = data[0].to(device), data[1].to(device)
            label = label.to(torch.long).squeeze()
            result = model(datax)
            lossc = F.nll_loss(result, label)
            loss.append(lossc.item())
            
            pred_labels = torch.argmax(result.cpu(), dim=-1) 
            true_preds += (pred_labels == label.cpu()).sum()
            num_preds += label.shape[0]
            
            
    acc = true_preds / num_preds
    print( 'batches for train:',len(loss)) 
    print('test loss:',np.mean(np.array(loss)))
    print('test acc:',np.mean(np.array(acc)))
    return np.mean(np.array(loss)), np.array(acc)

In [None]:
from tqdm import tqdm as tqdm
# Train / evaluate for a fixed number of epochs, keeping the per-epoch history
# and redrawing the loss and accuracy curves every 10 epochs.
nepoch = 200
best_loss = 99999999
losst, lossv = [], []   # per-epoch training / testing loss
acct, accv = [], []     # per-epoch training / testing accuracy
epochs = []
for epoch in range(nepoch):
    print('epoch:', epoch)

    trainl = train(epoch)
    losst.append(trainl[0])
    acct.append(trainl[1])

    testl = test(epoch)
    loss_epoch = testl[0]
    lossv.append(loss_epoch)
    accv.append(testl[1])
    epochs.append(epoch)

    # Track the lowest test loss seen so far.
    if loss_epoch < best_loss:
        best_loss = loss_epoch
        print('new best test loss:', best_loss)

    # Only plot after every 10th epoch.
    if (epoch + 1) % 10 != 0:
        continue

    for title, train_hist, test_hist in (("loss", losst, lossv), ("accuracy", acct, accv)):
        plt.plot(np.array(epochs), np.array(train_hist), c='b', label='training')
        plt.plot(np.array(epochs), np.array(test_hist), c='r', label='testing')
        plt.legend()
        plt.title(title)
        plt.show()

In [None]:
# Sanity check: number of training samples per epoch (both training subsets together),
# i.e. the value that was passed to the LR scheduler as its epoch size.
print(epoch_size)