In [1]:
import datetime
import math
import os
import time
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchinfo
import wandb
from torch.utils.data import DataLoader

from CoRe_Dataloader_ECSG import load_pth_file, load_raw_from_pth_file
# trainds = get_dataset()
# train_dl = DataLoader(trainds,batch_size=6,shuffle = True,)
# test_dl = DataLoader(trainds,batch_size=16*2,shuffle = True,)

In [2]:
# Load the train/test loaders and the raw train/test datasets via the project helpers in
# CoRe_Dataloader_ECSG. NOTE(review): presumably both read the same saved .pth file and
# `train_dl_batch_size` only affects train_dl -- confirm against the helper module.
# raw_train_ds / raw_test_ds are not referenced by any cell visible in this notebook.
train_dl, test_dl = load_pth_file(train_dl_batch_size=8)
raw_train_ds, raw_test_ds = load_raw_from_pth_file()

In [3]:
# Peek at the targets (element [1]) of one test batch. The training/evaluation code drops
# column 0 and regresses on the remaining columns (`params[:, 1:]`).
next(iter(test_dl))[1]

tensor([[12.0000,  1.3500,  1.3500],
        [16.0000,  1.3506,  1.3506],
        [16.0000,  1.3750,  1.3750],
        [14.0000,  1.3509,  1.3509],
        [ 3.0000,  1.4000,  1.4000],
        [10.0000,  1.3754,  1.3754],
        [16.0000,  1.6201,  1.0801],
        [ 8.0000,  1.5150,  1.5150],
        [ 8.0000,  1.5150,  1.5150],
        [14.0000,  1.3500,  1.3500],
        [14.0000,  1.3504,  1.3504],
        [14.0000,  1.3500,  1.3500],
        [16.0000,  1.3505,  1.3505],
        [10.0000,  1.3754,  1.3754],
        [16.0000,  1.8002,  0.9001],
        [16.0000,  1.8002,  0.9001],
        [16.0000,  1.6500,  1.0979],
        [18.0000,  1.3500,  1.3500],
        [ 2.0000,  1.3510,  1.3510],
        [10.0000,  1.6500,  1.1000],
        [14.0000,  1.5000,  1.5000],
        [ 2.0000,  1.1000,  1.4000],
        [16.0000,  1.3501,  1.3501],
        [14.0000,  1.3509,  1.3509],
        [16.0000,  1.3500,  1.3500],
        [10.0000,  1.5000,  1.5000],
        [16.0000,  1.3501,  1.3501],
 

In [4]:
import torchmetrics as metrics
# torchmetrics objects, all moved to the same GPU the model lives on.
# acc / auroc are 19-class classification metrics (not referenced by the cells visible
# here); l1e / l2e are the regression metrics consumed by new_accuracy().
_metric_device = "cuda:0"
acc = metrics.Accuracy(num_classes=19, task="multiclass").to(_metric_device)
auroc = metrics.AUROC(num_classes=19, task="multiclass").to(_metric_device)
l1e = metrics.MeanAbsoluteError().to(_metric_device)  # mean absolute error
l2e = metrics.MeanSquaredError().to(_metric_device)  # mean squared error

In [5]:
import vit
import vit_pytorch as nl_vit

In [6]:
# Vision-transformer regressor: 400x400 single-channel input, 20x20 patches, and a
# two-value output head (num_classes=2 is used here as the number of regression outputs).
model = vit.ViT(
    image_size=400,
    patch_size=20,
    num_classes=2,
    dim=1024 // 2,
    depth=2,
    heads=8,
    mlp_dim=2048 // 2,
    channels=1,
).to("cuda:0")

# Smoke test: push one random image through the network and show the raw prediction.
img = torch.randn(1, 1, 400, 400).to("cuda:0")
preds = model(img)  # shape (1, 2): one row per image, one column per output
print(preds)

tensor([[-0.2071,  0.4427]], device='cuda:0', grad_fn=<AddmmBackward0>)


In [7]:
# Optimisation setup.
# All three schedulers below are attached to `optimizer` (Adam). `optimizer1` (NAdam)
# shares the same parameters but has no scheduler, so its learning rate stays at `startlr`.
startlr = 3e-5
optimizer = optim.Adam(params=model.parameters(), lr=startlr)
optimizer1 = optim.NAdam(params=model.parameters(), lr=startlr)
# Four one-off 10% cuts, applied at epochs 5, 15, 45 and 135.
step_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[5, 15, 45, 135], gamma=0.9)
# Slow geometric decay: x0.986 every 5 epochs. On its own this takes 3e-5 to roughly
# 5.5e-6 after 600 epochs; the other schedulers reduce the learning rate further.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=5, gamma=0.986)
# The quantity passed to scheduler_pl.step() during training is the test *error* (lower is
# better), so the plateau detector must run in 'min' mode. With mode='max' it would treat
# every improvement as "no progress" and cut the LR while the model is still getting better.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; drop it if the
# constructor rejects it in your environment.
scheduler_pl = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=0.7, patience=30, verbose=True)
# Training loss: mean squared error between predictions and targets.
lossfn = nn.MSELoss()


In [8]:
def new_accuracy(model:torch.nn.Module,dl:DataLoader):
    """Evaluate `model` on every batch of `dl` and return its regression errors.

    Despite the name, this returns error metrics, not an accuracy: the predictions for the
    whole loader are concatenated and scored once with the module-level torchmetrics
    objects `l1e` (mean absolute error) and `l2e` (mean squared error).

    Args:
        model: network to evaluate; called on inputs moved to "cuda:0".
        dl: loader yielding `(sg, params)` pairs. `sg` must be 3-D and gets a singleton
            channel axis inserted at position 1; column 0 of `params` is dropped and the
            remaining columns are the regression targets.

    Returns:
        Tuple `(mae, mse)` as returned by `l1e` and `l2e`.

    Raises:
        ValueError: if `dl` yields no batches.
    """
    # Remember the caller's mode so evaluation does not silently flip a model that was
    # deliberately left in eval mode.
    was_training = model.training
    model.eval()
    predictions = []
    targets = []
    try:
        with torch.no_grad():
            for sg, params in dl:
                sg = sg.to("cuda:0").to(torch.float)
                # Unpacking enforces a 3-D input before the channel axis is added.
                n_items, height, width = sg.shape
                sg = sg.view(n_items, 1, height, width)

                params = params[:, 1:].to("cuda:0").to(torch.float)

                predictions.append(model(sg).detach())
                targets.append(params)
    finally:
        # Restore the original mode even if a batch raised.
        model.train(was_training)

    if not predictions:
        raise ValueError("new_accuracy: the dataloader yielded no batches")

    output = torch.concat(predictions, dim=0)
    expected = torch.concat(targets, dim=0)
    mae = l1e(output, expected)
    mse = l2e(output, expected)
    return mae, mse


In [9]:
# Baseline before any training: mean absolute error and mean squared error of the freshly
# initialised model on the test loader.
tot_err, mse = new_accuracy(model=model, dl=test_dl)

In [10]:
def train_eval_model(config,adam = True,nadam = False):
    """Train the notebook-level `model` on `train_dl`, evaluating on `test_dl` every epoch.

    Relies on notebook globals: `model`, `train_dl`, `test_dl`, `optimizer`, `optimizer1`,
    `lossfn`, `scheduler`, `step_scheduler`, `scheduler_pl`, `new_accuracy`, and an active
    wandb run.

    Args:
        config: wandb-config-like object providing `epochs`, `run_name` and `start_time`
            (a string); `best_model` is written to it whenever the test error improves.
        adam: if true, step `optimizer` (Adam) after each backward pass.
        nadam: if true, step `optimizer1` (NAdam) after each backward pass.
            NOTE(review): the flags are independent, so `nadam=True` with the default
            `adam=True` applies BOTH updates to the same parameters every batch.

    Side effects:
        Logs per-batch and per-epoch values to wandb, prints progress, and writes a
        checkpoint under ./saved_models/cnns/ each time the test error reaches a new best.
    """
    def evaluate():
        # Convert the metric tensors to plain floats once, so comparisons, formatting,
        # logging and the checkpoint file name all work on Python numbers.
        err, sq_err = new_accuracy(model=model, dl=test_dl)
        return float(err), float(sq_err)

    def current_lr():
        # Read the LR actually in use by the scheduled optimizer. A single scheduler's
        # get_last_lr() can be stale because three schedulers modify the same optimizer.
        # NOTE(review): `optimizer1` has no scheduler, so this is not its learning rate.
        return optimizer.param_groups[0]["lr"]

    tot_err, mse = evaluate()
    best_err = float("inf")
    best_mse = float("inf")
    n_batches = len(train_dl)
    # Exactly `config.epochs` epochs, numbered 1..epochs.
    for epoch in range(1, config.epochs + 1):
        epoch_start = time.time()
        for batch, (sg, params) in enumerate(train_dl):
            batch_start = time.time()
            sg = sg.to("cuda:0").to(torch.float)
            # Insert a singleton channel axis; unpacking enforces a 3-D input.
            n_items, height, width = sg.shape
            sg = sg.view(n_items, 1, height, width)
            # Column 0 of the targets is not a regression target and is dropped.
            params = params[:, 1:].to("cuda:0").to(torch.float)
            # Both optimizers wrap the same parameters, so this clears the gradients
            # regardless of which optimizer steps below.
            optimizer.zero_grad()
            outputs = model(sg)
            loss = lossfn(outputs, params)
            loss.backward()
            if adam:
                optimizer.step()
            if nadam:
                optimizer1.step()
            lr = current_lr()
            wandb.log({"loss": loss.item(), "lr": lr, "epoch": epoch})
            print(f"{epoch:5}/{config.epochs:5} // {batch:5}/{n_batches:5} | Loss: {loss.item():2.4}, last_total_error: {tot_err}, Best Error {best_err} last MSE {mse} Best MSE {best_mse} lr:{lr:1.5},Time per Batch: {time.time()-batch_start:1.2} seconds     ", end="\r", flush=True)
            torch.cuda.empty_cache()

        tot_err, mse = evaluate()
        scheduler.step()
        step_scheduler.step()
        # NOTE(review): the monitored value is an error (lower is better), so
        # `scheduler_pl` should be constructed with mode='min'.
        scheduler_pl.step(tot_err)

        if tot_err < best_err:
            best_err = tot_err
            # state_dict() hands back references to the live parameters; clone them so the
            # stored "best" weights are not overwritten by subsequent training steps.
            best_state = OrderedDict(
                (name, tensor.detach().clone())
                for name, tensor in model.state_dict().items()
            )
            config.best_model = best_state
            # NOTE(review): the "cnns" directory and the acc/auc labels look like leftovers
            # from a classifier notebook; the values written are the best MAE and the MSE.
            save_path = f"./saved_models/cnns/best_model_state_dict_at_for{config.run_name}_stime_{config.start_time.replace(':', '-')}__acc_{best_err}__auc_{mse}.pt"
            try:
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                torch.save(best_state, save_path)
                print("\nSAVING MODEL")
            except OSError as exc:
                # Checkpointing is best effort: report the failure and keep training.
                print(f"\nCould not save checkpoint to {save_path}: {exc}")
        best_mse = min(best_mse, mse)
        print(f"\nEpoch {epoch}/{config.epochs} finished. Total error: {tot_err:3.5} MSE: {mse} Time per Epoch: {time.time()-epoch_start:1.5}")

        # Metric keys are kept as-is so existing wandb dashboards keep working, although
        # "max_accuracy" / "MaxMSE" actually hold the lowest error / MSE seen so far.
        wandb.log({"epoch": epoch, "error": tot_err, "max_accuracy": best_err, "lr": current_lr(), "MSE": mse, "MaxMSE": best_mse})

In [11]:
# Start a Weights & Biases run and record the run settings on its config object.
wandb.init(project = "simple_vision_transformer_regressor")
config = wandb.config
config.run_name = wandb.run.id  # public accessor instead of the private `_run_id`
config.epochs = 1000
# NOTE(review): presumably the input image width/height (matches image_size=400) -- confirm.
config.inx = 400
config.iny = 400
config.lr = startlr
# Placeholder; train_eval_model overwrites it each time the test error improves.
config.best_model = OrderedDict()
# ISO timestamp; ':' is replaced with '-' wherever it is embedded in a file name.
config.start_time = datetime.datetime.now().isoformat()
config.savename = f"best_model_state_dict_at_for{config.run_name}_stime_{config.start_time.replace(':', '-')}__acc_max_acc__auc_auc.pt"

# NOTE(review): `adam` defaults to True, so this call steps BOTH Adam and NAdam on every
# batch. Pass adam=False as well if only NAdam is intended.
train_eval_model(config,nadam=True)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33maashraychegu[0m ([33malabs[0m). Use [1m`wandb login --relogin`[0m to force relogin


    1/ 1000 //   295/  296 | Loss: 0.03619, last_total_error: 1.207710862159729, Maximum Accuracy inf last MSE 1.5725191831588745 Max MSE inf lr:3e-05,Time per Batch: 0.027 seconds       
SAVING MODEL

Epoch 2/1000 finished. Total error: 0.15745 MSE: 0.034464865922927856 Time per Epoch: 20.676
    2/ 1000 //   295/  296 | Loss: 0.02543, last_total_error: 0.1574515998363495, Maximum Accuracy inf last MSE 0.034464865922927856 Max MSE 0.034464865922927856 lr:3e-05,Time per Batch: 0.025 seconds      
SAVING MODEL

Epoch 3/1000 finished. Total error: 0.11566 MSE: 0.021198676899075508 Time per Epoch: 16.315
    3/ 1000 //   295/  296 | Loss: 0.01799, last_total_error: 0.11566206067800522, Maximum Accuracy inf last MSE 0.021198676899075508 Max MSE 0.021198676899075508 lr:3e-05,Time per Batch: 0.024 seconds      
SAVING MODEL

Epoch 4/1000 finished. Total error: 0.098527 MSE: 0.016187112778425217 Time per Epoch: 16.423
    4/ 1000 //   295/  296 | Loss: 0.01163, last_total_error: 0.09852689504

KeyboardInterrupt: 

In [None]:
# Persist the best weights recorded on the config during training.
best_model_path = f"./saved_models/ViT/best_model_state_dict_at_for{config.run_name}_stime_{config.start_time.replace(':', '-')}_BEST_MODEL.pt"
# torch.save does not create missing parent directories, so make sure the target exists.
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)
torch.save(config.best_model, best_model_path)
