In [None]:
import datetime
import math
import os
import time
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchinfo
import wandb
from torch.utils.data import DataLoader

from CoRe_Dataloader_ECSG import load_pth_file,load_raw_from_pth_file
# trainds = get_dataset()
# train_dl = DataLoader(trainds,batch_size=6,shuffle = True,)
# test_dl = DataLoader(trainds,batch_size=16*2,shuffle = True,)

In [None]:
# Batched train/test DataLoaders; only the training loader's batch size is set here.
TRAIN_BATCH_SIZE = 8
train_dl, test_dl = load_pth_file(train_dl_batch_size=TRAIN_BATCH_SIZE)

# Raw train/test datasets from the same .pth source.
# NOTE(review): not referenced anywhere below in this notebook.
raw_train_ds, raw_test_ds = load_raw_from_pth_file()

In [None]:
import torchmetrics as metrics

# Shared GPU metric objects used by both the training loop and `new_accuracy`.
# num_classes must match the ViT head built below (num_classes=19).
_metric_kwargs = {"task": "multiclass", "num_classes": 19}
acc = metrics.Accuracy(**_metric_kwargs).to("cuda:0")
auroc = metrics.AUROC(**_metric_kwargs).to("cuda:0")


In [None]:
def new_accuracy(model: torch.nn.Module, dl: DataLoader, device: str = "cuda:0"):
    """Evaluate ``model`` on every batch of ``dl`` and return ``(accuracy, auroc)``.

    Relies on the module-level torchmetrics objects ``acc`` and ``auroc``.
    Calling a torchmetrics metric also adds the batch to the metric's internal
    state, and AUROC keeps every prediction it has seen. Both metrics are
    therefore reset afterwards, so their state does not grow on every call
    (this function runs once per epoch).

    Args:
        model: classifier whose output has one row of class scores per input.
        dl: yields ``(sg, params)``. ``sg`` is reshaped to ``(B, 1, H, W)`` and
            ``params[:, 0]`` is used as the integer class label.
        device: device the batches are moved to. It must match where ``model``,
            ``acc`` and ``auroc`` live.

    Returns:
        ``(accuracy, auc)`` as returned by ``acc`` and ``auroc``.

    Raises:
        ValueError: if ``dl`` yields no batches.
    """
    was_training = model.training
    model.eval()
    logits_per_batch = []
    labels_per_batch = []
    try:
        with torch.no_grad():
            for sg, params in dl:
                sg = sg.to(device).to(torch.float)
                # assumes sg is (B, H, W); insert the single channel dimension the model expects
                sg = sg.view(sg.shape[0], 1, sg.shape[1], sg.shape[2])
                labels = params[:, 0].to(device).to(torch.long)
                logits_per_batch.append(model(sg))
                labels_per_batch.append(labels)
    finally:
        # Restore whatever mode the caller had, even if evaluation failed.
        model.train(was_training)

    if not logits_per_batch:
        raise ValueError("new_accuracy: dataloader yielded no batches")

    output = torch.cat(logits_per_batch, dim=0)
    targets = torch.cat(labels_per_batch, dim=0)
    accuracy = acc(output, targets)
    auc = auroc(output, targets)
    # Drop the accumulated state (AUROC stores all preds/targets) so memory stays flat across epochs.
    acc.reset()
    auroc.reset()
    return accuracy, auc


In [None]:
import vit
import vit_pytorch as nl_vit

In [None]:
# Single-channel 400x400 ViT classifier with a 19-class head (matches the metrics' num_classes).
model = vit.ViT(image_size=400,
                patch_size=20,
                num_classes=19,
                dim=int(1024/2),
                depth=2,
                heads=8,
                mlp_dim=int(2048/2),
                channels=1).to("cuda:0")

# Smoke test: push one random image through the freshly built model.
img = torch.randn(1, 1, 400, 400, device="cuda:0")
with torch.no_grad():  # no autograd graph needed for a shape check
    preds = model(img)  # (1, 19): one score per class
assert preds.shape == (1, 19), f"unexpected ViT output shape {tuple(preds.shape)}"
print(preds)

In [None]:
startlr = 3e-5
# Adam: all three LR schedulers below are attached to this optimizer.
optimizer = optim.Adam(params=model.parameters(), lr=startlr)
# NAdam over the *same* parameters. No scheduler is attached to it, so its
# learning rate stays at `startlr` for the entire run.
optimizer1 = optim.NAdam(params=model.parameters(), lr=startlr)

# The three schedulers all rescale `optimizer`'s lr and chain multiplicatively:
#   * MultiStepLR       : x0.9 at scheduler steps 5, 15, 45 and 135
#   * StepLR            : x0.986 every 5 steps
#   * ReduceLROnPlateau : x0.7 once the monitored value (mode='max') has not
#                         improved for `patience=30` steps
# With one step per epoch, after 600 epochs StepLR and MultiStepLR alone give
# 3e-5 * 0.986**120 * 0.9**4 ~= 3.6e-6, before any ReduceLROnPlateau cuts.
step_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[5, 15, 45, 135], gamma=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=5, gamma=0.986)
scheduler_pl = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='max', factor=0.7, patience=30, verbose=True)
lossfn = nn.CrossEntropyLoss()


In [None]:
def _snapshot_state_dict(module: torch.nn.Module) -> "OrderedDict[str, torch.Tensor]":
    """Return a detached CPU copy of ``module``'s parameters and buffers.

    ``Module.state_dict()`` returns references to the live tensors. Storing it
    directly would keep following later optimizer updates instead of freezing
    the weights of the best epoch.
    """
    return OrderedDict(
        (name, tensor.detach().cpu().clone())
        for name, tensor in module.state_dict().items()
    )


def _save_best_checkpoint(state, config, max_acc, auc, save_dir) -> bool:
    """Best-effort ``torch.save`` of ``state`` into ``save_dir`` (created if missing).

    Filesystem/serialization errors are reported rather than raised, so a failed
    save does not abort a long training run. Returns True if the file was written.
    """
    # float() turns the 0-d metric tensors into plain numbers. Their str() would
    # otherwise put "tensor(...)" and "device='cuda:0'" (a colon) into the filename.
    filename = (
        f"best_model_state_dict_at_for{config.run_name}"
        f"_stime_{config.start_time.replace(':', '-')}"
        f"__acc_{float(max_acc):.5f}__auc_{float(auc):.5f}.pt"
    )
    path = os.path.join(save_dir, filename)
    try:
        os.makedirs(save_dir, exist_ok=True)
        torch.save(state, path)
    except (OSError, RuntimeError) as err:
        print(f"\nCould not save checkpoint to {path}: {err}")
        return False
    print("\nSAVING MODEL")
    return True


def train_eval_model(config, adam=True, nadam=False, save_dir="./saved_models/ViT"):
    """Train the global ``model`` for ``config.epochs`` epochs, evaluating after each.

    Uses the notebook globals ``model``, ``train_dl``, ``test_dl``, ``optimizer``
    (Adam), ``optimizer1`` (NAdam), ``scheduler``, ``step_scheduler``,
    ``scheduler_pl``, ``lossfn``, ``acc``, ``new_accuracy`` and ``wandb``.

    Per batch, it logs loss, batch accuracy and the learning rate to wandb. Per
    epoch, it evaluates on ``test_dl``, steps the three schedulers and logs
    accuracy/AUROC. Whenever test accuracy improves, it snapshots the weights
    into ``config.best_model`` and writes them to ``save_dir``.

    Args:
        config: wandb config providing ``epochs``, ``run_name`` and ``start_time``.
        adam: step the Adam optimizer (the one the schedulers control).
        nadam: step the NAdam optimizer. Both optimizers share the model's
            parameters, so with ``adam`` and ``nadam`` both True every batch
            receives two updates computed from the same gradients.
        save_dir: directory for the per-improvement checkpoints.

    Returns:
        A CPU copy of the best state dict, or None if no epoch ran.
    """
    tot_acc, auc = new_accuracy(model=model, dl=test_dl)
    max_acc = -1
    max_auc = -1
    best_state = None
    n_batches = len(train_dl)
    for epoch in range(1, config.epochs + 1):
        epoch_start = time.time()
        for batch, (sg, params) in enumerate(train_dl):
            batch_start = time.time()
            sg = sg.to("cuda:0").to(torch.float)
            sgsh = sg.shape
            # assumes sg is (B, H, W); add the single channel dimension
            sg = sg.view(sgsh[0], 1, sgsh[1], sgsh[2])
            params = params[:, 0].to("cuda:0").to(torch.long)
            # Both optimizers hold the same parameters, so one zero_grad clears them for both.
            optimizer.zero_grad()
            outputs = model(sg)
            loss = lossfn(outputs, params)
            loss.backward()
            if adam:
                optimizer.step()
            if nadam:
                optimizer1.step()
            batch_acc = acc(outputs, params)
            # Read the lr actually in effect: with three chained schedulers,
            # StepLR.get_last_lr() can lag behind the others' changes.
            lr = optimizer.param_groups[0]["lr"]
            wandb.log({"loss": loss.item(), "batch_accuracy": batch_acc, "lr": lr, "epoch": epoch})
            print(f"{epoch:5}/{config.epochs:5} // {batch:5}/{n_batches:5} | Loss: {loss.item():2.4},batch_accuracy:{batch_acc:3.4}, last_total_accuracy: {tot_acc}, Maximum Accuracy {max_acc} last AUROC {auc} Max AUC {max_auc} lr:{lr:1.5},Time per Batch: {time.time()-batch_start:1.2} seconds     ", end="\r", flush=True)
            torch.cuda.empty_cache()

        tot_acc, auc = new_accuracy(model=model, dl=test_dl)
        scheduler.step()
        step_scheduler.step()
        scheduler_pl.step(tot_acc)
        if tot_acc > max_acc:
            max_acc = tot_acc
            best_state = _snapshot_state_dict(model)
            # Kept for the final save cell, which reads config.best_model.
            # NOTE(review): wandb.config may convert stored values for JSON syncing;
            # confirm it still holds tensors. The .pt file and the return value are
            # the reliable copies.
            config.best_model = best_state
            _save_best_checkpoint(best_state, config, max_acc, auc, save_dir)
        max_auc = max(max_auc, auc)
        lr = optimizer.param_groups[0]["lr"]
        print(f"\nEpoch {epoch}/{config.epochs} finished. Total accuracy: {tot_acc:3.5} AUROC: {auc} Time per Epoch: {time.time()-epoch_start:1.5}")

        wandb.log({"epoch": epoch, "accuracy": tot_acc, "max_accuracy": max_acc, "lr": lr, "auroc": auc})
    return best_state

In [None]:
wandb.init(project="simple_vision_transformer")
config = wandb.config
config.run_name = wandb.run.id  # public accessor for the run id
config.epochs = 1000
config.inx = 400
config.iny = 400
config.lr = startlr
config.best_model = OrderedDict()  # overwritten by train_eval_model whenever test accuracy improves
config.start_time = datetime.datetime.now().isoformat()
# NOTE: "max_acc" / "auc" in this name are literal placeholder text, not the metric values.
config.savename = f"best_model_state_dict_at_for{config.run_name}_stime_{config.start_time.replace(':', '-')}__acc_max_acc__auc_auc.pt"

# Both optimizers step every batch (Adam is the default; NAdam is enabled here).
# NOTE(review): they share the same parameters, so each batch gets two updates — confirm this is intended.
train_eval_model(config, adam=True, nadam=True)

In [None]:
# Persist the best weights recorded by train_eval_model in config.best_model.
# The directory is created first so a long run does not end in FileNotFoundError.
best_model_dir = "./saved_models/ViT"
os.makedirs(best_model_dir, exist_ok=True)
torch.save(config.best_model,
           f"{best_model_dir}/best_model_state_dict_at_for{config.run_name}_stime_{config.start_time.replace(':', '-')}_BEST_MODEL.pt")
