In [None]:
from CoRe_Dataloader_ECSG import load_pth_file,load_raw_from_pth_file
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader,TensorDataset
import math
import torchinfo
import time
import numpy as np
import wandb
import datetime
from collections import OrderedDict
from sklearn.model_selection import train_test_split

In [None]:
raw_sgram_ds, raw_param_ds = load_raw_from_pth_file()
def get_new_test_train(p = 0.3):
    assert p < 1
    xtrain,xtest,ytrain,ytest = train_test_split(raw_sgram_ds.cpu().numpy(),raw_param_ds.cpu().numpy(),test_size = p)
    train_dataset = TensorDataset(torch.tensor(xtrain),torch.tensor(ytrain))
    if p == 0:
        test_dataset = TensorDataset(torch.tensor(xtrain),torch.tensor(ytrain))
    else:
        test_dataset = TensorDataset(torch.tensor(xtest), torch.tensor(ytest))
    return DataLoader(train_dataset,batch_size = 8, shuffle = True), DataLoader(test_dataset,batch_size = 64,shuffle = True)


In [None]:
import torchmetrics as metrics
acc = metrics.Accuracy(task="multiclass",num_classes=19).to("cuda:0")
auroc = metrics.AUROC(task = "multiclass",num_classes=19).to("cuda:0")


In [None]:
def new_accuracy(model:torch.nn.Module,dl:DataLoader):
    model.eval()
    raw_output = []
    parameters = []
    with torch.no_grad():
        for batch,(sg,params) in enumerate(dl):
            sg = sg.to("cuda:0").to(torch.float)
            sgsh = sg.shape
            sg = sg.view(sgsh[0], 1, sgsh[1], sgsh[2])

            params = params[:,0].to("cuda:0").to(torch.long)
            
            raw_output.append(model(sg).detach())
            parameters.append(params)
            
    model.train()
    output = torch.vstack(raw_output)
    parameters = torch.hstack(parameters)
    accuracy = acc(output,parameters)
    auc = auroc(output,parameters)
    return accuracy,auc


In [None]:
# import vit
# import vit_pytorch
from vit_pytorch import vit_for_small_dataset as vit

In [12]:
def init_model():
    return vit.ViT(image_size=400,
                   patch_size=20,
                   num_classes=19,
                   dim=1024,
                   depth=2,
                   heads=16,
                   mlp_dim=int(2048/2),
                   dropout = 0.1,
                   emb_dropout = 0.1,
                   channels=1).to("cuda:0")


In [13]:
model = init_model()


img = torch.randn(1,1, 400,400).to("cuda:0")

preds = model(img)  # (1, 1000)
print(preds)
torchinfo.summary(model, input_size=(1,1, 400, 40))

tensor([[ 0.3697, -0.1882, -0.3220,  0.2802,  0.6521, -0.9650, -0.1339, -0.7992,
          0.4796,  0.1362,  0.6529,  0.8297,  0.6034,  0.7285,  0.7743,  0.2377,
         -0.7944,  0.6078,  1.1449]], device='cuda:0',
       grad_fn=<AddmmBackward0>)




Layer (type:depth-idx)                             Output Shape              Param #
ViT                                                [1, 19]                   411,648
├─SPT: 1-1                                         [1, 40, 1024]             --
│    └─Sequential: 2-1                             [1, 40, 1024]             --
│    │    └─Rearrange: 3-1                         [1, 40, 2000]             --
│    │    └─LayerNorm: 3-2                         [1, 40, 2000]             4,000
│    │    └─Linear: 3-3                            [1, 40, 1024]             2,049,024
├─Dropout: 1-2                                     [1, 41, 1024]             --
├─Transformer: 1-3                                 [1, 41, 1024]             --
│    └─ModuleList: 2-2                             --                        --
│    │    └─ModuleList: 3-4                        --                        6,298,625
│    │    └─ModuleList: 3-5                        --                        6,298,625
├─Iden

In [14]:
startlr = 3e-5
optimizer = optim.Adam(params=model.parameters(), lr=startlr)
optimizer1 = optim.NAdam(params=model.parameters(), lr=startlr)
step_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[250], gamma=0.9)
# at the end of 600 epochs, the learning rate is 0.000,002,62
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=2, gamma=0.99)
scheduler_pl = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='max', factor=0.7, patience=35, verbose=True)
lossfn = nn.CrossEntropyLoss()


In [15]:
def train_eval_model(config,train_dl,test_dl,adam = True,nadam = False):
    tot_acc,auc = new_accuracy(model=model,dl = test_dl)
    max_acc = -1
    max_auc = -1
    for epoch in range(1,config.epochs+1):
        btime = time.time()
        ldl = len(train_dl)
        for batch,(sg,params) in enumerate(train_dl):
            stime = time.time()
            sg = sg.to("cuda:0").to(torch.float)
            sgsh = sg.shape
            sg = sg.view(sgsh[0],1,sgsh[1],sgsh[2])
            params = params[:,0].to("cuda:0").to(torch.long)
            optimizer.zero_grad()
            outputs = model(sg)
            loss = lossfn(outputs,params)
            loss.backward()
            optimizer.step() if adam else None
            optimizer1.step() if nadam else None
            wandb.log({"loss":loss.item(),"batch_accuracy":acc(outputs,params),"lr":scheduler.get_last_lr()[0],"epoch":epoch})
            print(f"{epoch:5}/{config.epochs:5} // {batch:5}/{ldl:5} | Loss: {loss.item():2.4},batch_accuracy:{acc(outputs,params):3.4}, last_total_accuracy: {tot_acc}, Maximum Accuracy {max_acc} last AUROC {auc} Max AUC {max_auc} lr:{scheduler.get_last_lr()[0]:1.5},Time per Batch: {time.time()-stime:1.2} seconds     ",end = "\r",flush=True)
            torch.cuda.empty_cache()
        tot_acc, auc = new_accuracy(model=model, dl=test_dl)
        scheduler.step()
        step_scheduler.step()
        scheduler_pl.step(tot_acc)
        if(tot_acc > max_acc):
            max_acc = tot_acc
            config.best_model = model.state_dict()
            try:
                torch.save(config.best_model, f"./saved_models/cnns/best_model_state_dict_at_for{config.run_name}_stime_{config.start_time.replace(':', '-')}__acc_{max_acc}__auc_{auc}.pt")
            except:  pass    
            print("\nSAVING MODEL")
        max_auc = max(max_auc,auc)
        print(f"\nEpoch {epoch+1}/{config.epochs} finished. Total accuracy: {tot_acc:3.5} AUROC: {auc} Time per Epoch: {time.time()-btime:1.5}")

        wandb.log({"epoch":epoch,"accuracy":tot_acc,"max_accuracy":max_acc,"lr":scheduler.get_last_lr()[0],"auroc":auc})
    return max_acc,max_auc

In [16]:
results = []
trials = 5

In [17]:
for i in range(trials):
    wandb.init(project="simple_vision_transformer_forsmalldatasets_validation_test")
    config = wandb.config
    config.run_name = wandb.run._run_id
    config = wandb.config
    config.epochs = 500
    config.inx = 400
    config.iny = 400
    config.lr = startlr     
    config.trial = i+1
    config.total_trials = trials
    config.best_model = OrderedDict()
    config.start_time = datetime.datetime.now().isoformat()
    config.savename = f"best_model_state_dict_at_for{config.run_name}_stime_{config.start_time.replace(':', '-')}__acc_max_acc__auc_auc.pt"
    train_dl, test_dl = get_new_test_train(.05)
    results.append(train_eval_model(wandb.config,train_dl,test_dl,nadam=True))
    model = init_model()

0,1
accuracy,▁▁▃▄▅▆▆▇▇▇▇▇▇▇█▇▇██▇█▇███▇▇▇█████
auroc,▁▁▅▇▇▇█████████████████████▇█████
batch_accuracy,▁▃▃▂▅▅▅█▅▇▆▇▆▆▇▇▇▅▇▇▆██▇████████▆▇▇███▇█
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,█▆█▅▄▆▃▂▄▂▃▂▃▃▂▃▂▄▂▁▂▁▂▂▁▁▁▁▁▁▁▁▄▁▂▁▁▁▂▁
lr,█████▇▇▇▇▇▆▆▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▁▁▁▁
max_accuracy,▁▁▃▄▅▆▆▇▇▇▇▇▇▇███████████████████

0,1
accuracy,0.83966
auroc,0.75977
batch_accuracy,1.0
epoch,34.0
loss,0.00637
lr,3e-05
max_accuracy,0.85654


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

In [None]:
a = []
for i in results:
    a.append([i[0].cpu().item(),i[1].cpu().item()])

In [None]:
results = np.array(a)
average_accuracy = np.average(results[:,0])
average_auroc = np.average(results[:,1])
print(average_accuracy,average_auroc)

In [None]:
torch.save(config.best_model,
           f"./saved_models/ViT/best_model_state_dict_at_for{config.run_name}_stime_{config.start_time.replace(':', '-')}_BEST_MODEL.pt")
