In [None]:
from CoRe_Dataloader_ECSG import load_pth_file
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader,TensorDataset
import math
import torchinfo
import time
import numpy as np
import wandb
import datetime
from collections import OrderedDict

In [None]:
import torchmetrics as metrics
import pandas as pd
# Number of target classes; must match `num_classes` of the classifier built in init_model().
NUM_CLASSES = 19
_metric_kwargs = dict(task="multiclass", num_classes=NUM_CLASSES)

# `acc` is kept as its own name because the training loop also calls it directly for
# per-batch accuracy, in addition to it being part of `combined`.
acc = metrics.Accuracy(**_metric_kwargs)
combined = metrics.MetricCollection([
    acc,
    metrics.AUROC(**_metric_kwargs),
    metrics.Precision(**_metric_kwargs),
    metrics.Recall(**_metric_kwargs),
    metrics.F1Score(**_metric_kwargs),
    # NOTE(review): FBetaScore's default beta is 1.0, which makes it identical to F1Score;
    # pass beta=... if a different precision/recall trade-off is wanted.
    metrics.FBetaScore(**_metric_kwargs),
]).to("cuda:0")


def get_df_from_rdict(rdict):
    """Convert a {metric_name: scalar_tensor} dict into a one-row DataFrame.

    Every value is unwrapped with ``.item()``; the result has one column per
    metric (in dict order) and a single row with index 0.
    """
    scalars = pd.Series({name: value.item() for name, value in rdict.items()})
    return scalars.to_frame().T


In [None]:
def calc_metrics(model: torch.nn.Module, dl: DataLoader, device: str = "cuda:0", verbose: bool = False):
    """Run `model` over every batch of `dl` and return the `combined` metric dict.

    Args:
        model: classifier; called as ``model(x)`` with x of shape (B, 1, H, W).
        dl: DataLoader yielding ``(sg, params)`` pairs; ``sg`` is reshaped to
            (B, 1, sg.shape[1], sg.shape[2]) and ``params[:, 0]`` is used as the
            integer class target.
        device: device the inputs/targets are moved to (must match the model's device).
        verbose: print the batch index for every evaluated batch.

    Returns:
        The dict produced by ``combined(logits, targets)``: metric class name
        (e.g. "MulticlassAccuracy") -> scalar tensor, computed over all batches.

    Raises:
        ValueError: if `dl` yields no batches.
    """
    was_training = model.training
    model.eval()
    logits, targets = [], []
    try:
        with torch.no_grad():
            for batch, (sg, params) in enumerate(dl):
                sg = sg.to(device).to(torch.float)
                # Add the channel axis expected by the single-channel ViT.
                sg = sg.view(sg.shape[0], 1, sg.shape[1], sg.shape[2])
                logits.append(model(sg))
                targets.append(params[:, 0].to(device).to(torch.long))
                if verbose:
                    print(batch)
    finally:
        # Restore whatever mode the caller had the model in, even if evaluation fails.
        model.train(was_training)

    if not logits:
        raise ValueError("calc_metrics: dataloader yielded no batches")

    output = torch.vstack(logits)
    parameters = torch.hstack(targets)
    result = combined(output, parameters)
    # The collection accumulates state on every call (AUROC keeps all preds/targets),
    # so clear it to keep memory bounded across repeated evaluations.
    combined.reset()
    return result


In [None]:
# import vit
# import vit_pytorch
from vit_pytorch import vit_for_small_dataset as vit_sd
from vit_pytorch import vit as simple_vit
from vit_pytorch.deepvit import DeepViT

In [None]:
def init_model(device="cuda:0"):
    """Build a fresh DeepViT classifier for single-channel 400x400 inputs on `device`.

    Architecture: 400x400 image split into 20x20-pixel patches (a 20x20 grid),
    19 output classes, dim 1024, depth 6, 12 heads, MLP dim 1024, dropout 0.1.

    Alternatives tried earlier in this notebook (same image/patch/class settings):
      * simple_vit.ViT: dim=512, depth=2, heads=8, mlp_dim=1024
      * vit_sd.ViT:     dim=1024, depth=4, heads=16, mlp_dim=1024, dropout=0.1, emb_dropout=0

    Args:
        device: where to place the model (default keeps the original "cuda:0").
    """
    return DeepViT(
        image_size=400,
        patch_size=20,
        num_classes=19,
        dim=1024,
        depth=6,
        heads=12,
        mlp_dim=1024,
        dropout=0.1,
        emb_dropout=0,
        channels=1,
    ).to(device)


In [None]:
# Smoke test: push one random single-channel 400x400 image through a fresh model,
# then summarise the network using the same input shape it is built for.
model = init_model()
img = torch.randn(1, 1, 400, 400).to("cuda:0")
with torch.no_grad():
    preds = model(img)  # (1, 19): one logit per class
print(preds)
torchinfo.summary(model, input_size=(1, 1, 400, 400))

In [None]:
# Optimisation setup. These objects hold references to the parameters of the `model`
# that exists when this cell runs; they must be rebuilt if `model` is replaced.
startlr = 3e-5

# Two optimizers over the same parameters; train_eval_model's `adam` / `nadam` flags
# decide which of them call .step() after each backward pass.
optimizer = optim.Adam(params=model.parameters(), lr=startlr)
optimizer1 = optim.NAdam(params=model.parameters(), lr=startlr)

# All three schedulers act on `optimizer` (Adam) only; `optimizer1` (NAdam) keeps lr=startlr.
# Their effects compound: MultiStepLR multiplies lr by 0.9 at each of steps 1-4, and StepLR
# multiplies it by 0.9 on every step (stepped once per epoch), so lr <= 3e-5 * 0.9**n after
# n epochs -- below 1e-9 by epoch 100 -- before any ReduceLROnPlateau reductions.
step_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[1, 2, 3, 4],
    gamma=0.9,
)
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=1,
    gamma=0.9,
)
# Cuts lr by 0.7 when the monitored (maximised) value has not improved for 35 steps.
scheduler_pl = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='max',
    factor=0.7,
    patience=35,
    verbose=True,
)

lossfn = nn.CrossEntropyLoss()


In [None]:
# Load the train / test / validation loaders and score the freshly initialised model
# on the test split as a baseline (the metric dict is the cell's displayed output).
(train_dl, test_dl, valid_dl) = load_pth_file()
calc_metrics(model=model, dl=test_dl)

In [None]:
def train_eval_model(config, train_dl, test_dl, adam=True, nadam=False):
    """Train the notebook-level `model`, evaluating on `test_dl` after every epoch.

    Relies on the globals `model`, `optimizer` (Adam), `optimizer1` (NAdam), `lossfn`,
    `acc`, `scheduler`, `step_scheduler` and `scheduler_pl`.

    Args:
        config: wandb config; reads `epochs`, `run_name` and `start_time`.
        train_dl: DataLoader of ``(sg, params)`` training batches.
        test_dl: DataLoader used for per-epoch evaluation and checkpoint selection.
        adam: call `optimizer.step()` after each backward pass.
        nadam: call `optimizer1.step()` after each backward pass.
            NOTE: if both flags are True, both optimizers update the same
            parameters on every batch.

    Returns:
        (max_acc, max_auc): best test-set MulticlassAccuracy / MulticlassAUROC over
        all epochs (0, 0 if `config.epochs` is 0).

    Side effects:
        Logs to wandb and writes ``model.state_dict()`` under ./saved_models/ViT/
        whenever an epoch reaches a new best test accuracy.
    """
    import os

    save_dir = "./saved_models/ViT"
    ldl = len(train_dl)
    max_acc, max_auc = 0, 0
    results = pd.DataFrame()
    for epoch in range(1, config.epochs + 1):
        etime = time.time()
        for batch, (sg, params) in enumerate(train_dl):
            stime = time.time()
            sgsh = sg.shape
            # Add a channel axis for the single-channel ViT (assumes sg is (B, H, W)).
            sg = sg.to("cuda:0").to(torch.float).view(sgsh[0], 1, sgsh[1], sgsh[2])
            # Column 0 of `params` is the integer class target for CrossEntropyLoss.
            params = params[:, 0].to("cuda:0").to(torch.long)

            optimizer.zero_grad()  # clears the shared parameter grads for both optimizers
            outputs = model(sg)
            loss = lossfn(outputs, params)
            loss.backward()
            if adam:
                optimizer.step()
            if nadam:
                optimizer1.step()
            # NOTE(review): emptying the CUDA cache every batch costs throughput;
            # presumably kept to limit fragmentation/OOM -- confirm it is still needed.
            torch.cuda.empty_cache()

            # Evaluate the batch metric once: every call to `acc` also updates its state.
            batch_acc = acc(outputs, params).item()
            lr = scheduler.get_last_lr()[0]
            wandb.log({"loss": loss.item(), "batch_accuracy": batch_acc, "lr": lr, "epoch": epoch})
            print(f"{epoch:5}/{config.epochs:5} // {batch:5}/{ldl:5} | Loss: {loss.item():2.4},batch_accuracy:{batch_acc:3.4}, lr:{lr:1.5},Time per Batch: {time.time()-stime:1.2} seconds     ", end="\r", flush=True)

        scheduler.step()
        step_scheduler.step()

        epoch_results = calc_metrics(model, test_dl)
        epoch_acc = epoch_results["MulticlassAccuracy"].item()
        # Compare against the best of the *previous* epochs, so a checkpoint is only
        # written when the current weights are the new best.
        improved = epoch_acc > max_acc
        results = pd.concat([results, get_df_from_rdict(epoch_results)])
        max_acc = results["MulticlassAccuracy"].max()
        max_auc = results["MulticlassAUROC"].max()
        # ReduceLROnPlateau(mode='max') tracks its own best value; feed it this epoch's
        # accuracy after evaluation rather than a value from before the epoch was scored.
        scheduler_pl.step(epoch_acc)

        if improved:
            try:
                os.makedirs(save_dir, exist_ok=True)
                torch.save(model.state_dict(), f"{save_dir}/best_model_state_dict_ViT_for{config.run_name}_stime_{config.start_time.replace(':', '-')}__acc_{max_acc}__auc_{max_auc}.pt")
                print("\nSAVING MODEL")
            except Exception as exc:
                # wandb.alert requires `text`, and `level` must be an AlertLevel (INFO/WARN/ERROR).
                wandb.alert(title="Checkpoint save failed", text=repr(exc), level=wandb.AlertLevel.WARN)

        wandb.log({"epoch": epoch, "lr": scheduler.get_last_lr()[0]} | epoch_results | {"MaximumMulticlassAccuracy": max_acc, "MaximumMulticlassAUROC": max_auc} | {"EpochTime": time.time() - etime})

    return max_acc, max_auc

In [None]:
# Run several independent training trials, each logged as its own wandb run.
results = []
trials = 5
for i in range(trials):
    # train_eval_model steps the global optimizers, which hold references to the parameters
    # of the model they were built with -- so rebuild the model, the optimizers and the
    # schedulers together for every trial (hyper-parameters mirror the optimizer cell above).
    model = init_model()
    optimizer = optim.Adam(params=model.parameters(), lr=startlr)
    optimizer1 = optim.NAdam(params=model.parameters(), lr=startlr)
    step_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1, 2, 3, 4], gamma=0.9)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
    scheduler_pl = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='max', factor=0.7, patience=35, verbose=True)

    wandb.init(project="simple_vision_transformer_with_noise_classifier")
    config = wandb.config
    config.run_name = wandb.run.id
    config.epochs = 500
    config.inx = 400
    config.iny = 400
    config.lr = startlr
    config.trial = i + 1
    config.total_trials = trials
    config.best_model = OrderedDict()
    config.start_time = datetime.datetime.now().isoformat()
    config.savename = f"best_model_state_dict_at_for{config.run_name}_stime_{config.start_time.replace(':', '-')}__acc_max_acc__auc_auc.pt"

    train_dl, test_dl, valid_dl = load_pth_file()
    # NOTE(review): `adam` defaults to True, so passing only nadam=True makes BOTH Adam and
    # NAdam step every batch. Pass adam=False if NAdam alone is intended.
    results.append(train_eval_model(wandb.config, train_dl, test_dl, nadam=True))
    print(f"\n\nValidation Metrics: { calc_metrics(model = model, dl = valid_dl)}\n\n")
    wandb.finish()

In [None]:
# One [best_accuracy, best_auroc] row per trial.
# train_eval_model returns maxima of pandas columns (numpy scalars, or 0 if no epochs ran),
# which have no `.cpu()`; float() accepts those as well as 0-d torch tensors.
a = [[float(max_acc), float(max_auc)] for max_acc, max_auc in results]

In [None]:
# Rows = trials; columns = (best accuracy, best AUROC). Average each column across trials.
results = np.array(a)
average_accuracy, average_auroc = (np.average(results[:, col]) for col in (0, 1))
print(average_accuracy, average_auroc)

In [None]:
# Persist the last run's `config.best_model`. It is initialised to an empty OrderedDict in
# the trials loop and no visible code fills it, so avoid writing an empty "BEST_MODEL" file.
best_model_path = f"./saved_models/ViT/best_model_state_dict_at_for{config.run_name}_stime_{config.start_time.replace(':', '-')}_BEST_MODEL.pt"
if config.best_model:
    torch.save(config.best_model, best_model_path)
else:
    print(f"config.best_model is empty; nothing saved to {best_model_path}")
