In [None]:
import numpy as np
from torch import nn
import torch
from utils.vit import TransformerEncoder
from utils.data_loaders import *
from utils.utils import *
from easydict import EasyDict as edict

In [None]:
# Build the three data loaders (train / validation / test) with the project
# helper; `sampler=False` is forwarded to it unchanged.
(
    train_loader,
    val_loader,
    test_loader,
) = get_loaders_IHM(sampler=False)

In [None]:
# Resolve the compute device through the project helper and report it.
device = get_device()
print(device)

# Container for the experiment configuration.
cf = edict()

# Keyword arguments handed verbatim to TransformerEncoder below.
cf.vit_args = dict(
    num_classes=1,
    dim=76,
    depth=1,
    # ViTs are transformers that use the class token as the sequence representation.
    use_class_token=True,
    heads=16,                    # number of attention heads
    ff_dim=16,                   # width of the transformer's feed-forward (MLP) layer
    mlp_head_hidden_dim=[128],   # hidden layer sizes of the MLP head
    dim_head=256,                # per-head dimension inside the attention module
    pool="cls",                  # pooling mode: "cls" or "mean"
    dropout=0.5,                 # dropout rate inside the transformer
    mlp_head_dropout=0.5,        # dropout rate inside the MLP head
    pe_method='origin',          # positional-embedding method
    pe_max_len=48,               # maximum sequence length for the positional embedding
    activation="gelu",           # activation: "gelu", "prelu" or "relu"
)

# Instantiate the encoder from the config and move it onto the chosen device.
# The trailing expression is left bare so the notebook still shows the model summary.
model = TransformerEncoder(**cf.vit_args)
model.to(device)




In [None]:
# Adam over all model parameters with a fixed learning rate of 1e-4.
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

# Binary cross-entropy loss, moved to the same device as the model.
# nn.BCELoss operates on probabilities, so the model output is presumably
# already squashed into [0, 1] — TODO confirm against TransformerEncoder.
loss_fn = nn.BCELoss()
loss_fn = loss_fn.to(device)

# Highest validation AUROC seen so far; the training loop compares against it
# to decide when to write a checkpoint.
best = 0

In [None]:
# NOTE(review): tqdm is not used in this cell; the import is kept in case other cells rely on it.
from tqdm import tqdm

# Per-epoch history: mean training loss, validation loss, validation AUROC.
TL=[]
VL=[]
VA=[]

for epoch in range(0,50):
    # Switch to training mode once per epoch instead of once per batch.
    # prediction_binary (defined elsewhere) may leave the model in eval mode,
    # so the mode is re-asserted at the start of every epoch.
    model.train()

    # Accumulate the loss as a plain Python float: .item() avoids keeping a
    # tensor per step and makes the TL history a list of floats.
    train_loss=0.0
    for inputs,label in train_loader:
        opt.zero_grad()

        # Cast to float32 and move the batch to the training device.
        inputs=inputs.to(torch.float32).to(device)
        label=label.to(torch.float32).to(device)

        pred=model(inputs)
        # Column 0 of the model output is scored against the labels.
        loss=loss_fn(pred[:,0],label)

        loss.backward()
        opt.step()
        train_loss+=loss.item()

    # Mean of the per-batch losses for this epoch (computed once, reused below).
    mean_train_loss=train_loss/len(train_loader)

    val_loss,auc=prediction_binary(model,val_loader,loss_fn,device)

    # Checkpoint the whole model whenever the validation AUROC improves.
    if auc>best:
        best=auc
        torch.save(model,'./VIT')

    print('Epoch : {:.1f} Train Loss {:.4f} Val Loss {:.4f} Val AUROC {:.4f}'.format(epoch,mean_train_loss,val_loss,auc))
    TL.append(mean_train_loss)
    VL.append(val_loss)
    VA.append(auc)

In [None]:

# Reload the checkpoint written by the training loop (the model with the best
# validation AUROC). The file holds the whole pickled module, not a state_dict.
#  - map_location=device lets a checkpoint saved on one device (e.g. GPU) be
#    restored onto whatever device this session is using.
#  - weights_only=False is needed because PyTorch >= 2.6 defaults to
#    weights_only=True, which refuses to unpickle a full nn.Module.
# NOTE(review): full unpickling can execute arbitrary code — only load a
# './VIT' file that this notebook itself produced.
model=torch.load('./VIT',map_location=device,weights_only=False)

# Evaluate the restored model on the held-out test loader and report its AUROC.
loss,auc=prediction_binary(model,test_loader,loss_fn,device)
print(auc)

In [None]:
# Count the trainable parameters of the model.
# A list (rather than a one-shot `filter` iterator) keeps `model_parameters`
# reusable after the sum has consumed it.
model_parameters = [p for p in model.parameters() if p.requires_grad]
# numel() yields an exact Python int per tensor; np.prod(p.size()) would return
# the float 1.0 for a 0-dim parameter and turn the whole total into a float.
params = sum(p.numel() for p in model_parameters)