# Todo
 - Masked transformers
 - any more data reductions?
 

In [None]:
from CoReDataLoader import  dataset, dataloader

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchmetrics as metrics

In [None]:
acc = metrics.Accuracy(task = "multiclass",num_classes = dataset.numeoss).to(dataset.device)
auroc = metrics.AUROC(task = "multiclass",num_classes=dataset.numeoss).to(dataset.device)
prec = metrics.Precision(task="multiclass",num_classes=dataset.numeoss).to(dataset.device)
f1score = metrics.F1Score(task = "multiclass",num_classes=dataset.numeoss).to(dataset.device)
avgprec = metrics.AveragePrecision(task = "multiclass",num_classes=dataset.numeoss).to(dataset.device)
precrecall =metrics.PrecisionRecallCurve(task = "multiclass",num_classes=dataset.numeoss).to(dataset.device)
recall = metrics.Recall(task = "multiclass",num_classes=dataset.numeoss).to(dataset.device)
roc = metrics.ROC(task = "multiclass",num_classes=dataset.numeoss).to(dataset.device)
conf = metrics.ConfusionMatrix(task = "multiclass",num_classes=dataset.numeoss).to(dataset.device)

In [None]:
class classifier(nn.Module):
    def __init__(self,input_length = dataset.output_length):
        super().__init__()
        self.inlayer = nn.Linear(input_length,4096*3)
        self.bnorm1 = nn.BatchNorm1d(4096*3)
        self.inter1 = nn.Linear(4096*3,4096*2)
        self.bnorm2 = nn.BatchNorm1d(4096*2)
        self.inter2 = nn.Linear(4096*2,dataset.numeoss)
        self.softmax = nn.Softmax()
        self.activation = nn.GELU()

    def forward(self,inp):
        itn = inp.to(torch.float32)
        itn = F.normalize(itn, 0)
        itn = self.inlayer(itn)
        itn = self.bnorm1(itn)
        itn = self.inter1(itn)
        itn = self.activation(itn)
        itn = self.bnorm2(itn)
        itn = self.inter2(itn)
        itn = self.activation(itn)
        return itn
        
net = classifier().to(dataset.device)

In [None]:
import wandb

In [None]:
def get_accuracy():
    with torch.no_grad():
        outputs = torch.zeros((len(dataset),19)).to("cuda:0")
        corrects = torch.zeros(len(dataset)).to("cuda:0")
        net.eval()
        for ctr, (ts,params) in enumerate(dataset):
            ts = ts.view(1,-1).to("cuda:0")
            params = params.view(1,-1)
            params = params[:,0]
            results = net(ts)
            outputs[ctr] = results
            corrects[ctr] = params
    net.train()
    return acc(outputs,corrects).item()

In [None]:
def new_accuracy():
    with torch.no_grad():
        net.eval()
        tss,params = next(iter(DataLoader(dataset, batch_size=len(dataset))))
        net.train()
    return acc(net(tss),params[:,0])*100


# Run from below here to test new runs

In [None]:
print("Original Accuracy: ", new_accuracy())

In [None]:
criterion = nn.CrossEntropyLoss()
lr = 5e-5
amsgrad = False
sgd = optim.SGD(net.parameters(), lr=lr,momentum = 0.8,nesterov = True)
adam = optim.Adam(net.parameters(),lr = lr,amsgrad=amsgrad)
optimizer = adam
epochs = 600


In [32]:
from datetime import datetime
ctime = datetime.now()
stime = f"{ctime.date().isoformat()}-{ctime.time().hour}-{ctime.time().minute}"
stime

'2023-01-15-16-30'

In [33]:
from collections import OrderedDict
best = OrderedDict()

In [34]:
wandb.init(project="simple ann classifier")
losses = []
accuracies = [0]
sacc = []
accuracy = 0
bestepoch = 0
for epoch in range(epochs):
    for batch, (ts,params) in enumerate(dataloader):
        params = params[:,0].to(torch.long).to(dataset.device)
        optimizer.zero_grad()
        outputs = net(ts.to(dataset.device))
        # print(outputs.shape,params[:,0].to(torch.long).shape)
        loss = criterion(outputs,params)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        # print(outputs.shape,params.shape)
        wandb.log({"loss":loss.item(),"batch_acc":acc(outputs,params)})

    accuracy = new_accuracy()
    accuracies.append(accuracy)

    if accuracy >= max(accuracies) or (accuracy >= max(accuracies) and loss.item() <= min(losses)):
        net.state_dict(destination = best)
        best = net.state_dict()
        torch.save(best, f"./saved_models/{stime}_Acc_{lr}_{amsgrad}.pt")
        bestepoch = epoch
        print(f"MODEL SAVED AT ACCURACY: {accuracy} and EPOCH {epoch}")

    
    wandb.log({"epoch":epoch,"acc":accuracy,"max_accuracy":max(accuracies)})

torch.save(best, f"./saved_models/{stime}_Best_Model_Epoch_{bestepoch}_Acc_{max(accuracies)}_{lr}_{amsgrad}.pt")

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
batch_acc,▅▇▃▅▇▆▅▅▅▃█▄▆▄▅▃▃▇▁▅▁▄▃▆▅▄▆▄▄▃▃▇▃▅▄▃▅▅▄▅
loss,▅▂▇▃▂▂▃▄▄▅▁▄▂▅▄▆▆▂█▅█▄▅▂▃▆▃▅▄▄▆▂▅▄▄▄▄▄▄▃

0,1
batch_acc,0.57143
loss,1.21762


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016916666666656966, max=1.0…

MODEL SAVED AT ACCURACY: 63.37028884887695 and EPOCH 0
MODEL SAVED AT ACCURACY: 63.547672271728516 and EPOCH 115
MODEL SAVED AT ACCURACY: 65.94235229492188 and EPOCH 116
MODEL SAVED AT ACCURACY: 68.42572021484375 and EPOCH 118
MODEL SAVED AT ACCURACY: 72.23947143554688 and EPOCH 121
MODEL SAVED AT ACCURACY: 72.2838134765625 and EPOCH 141
MODEL SAVED AT ACCURACY: 73.88026428222656 and EPOCH 148
MODEL SAVED AT ACCURACY: 75.65409851074219 and EPOCH 149
MODEL SAVED AT ACCURACY: 76.23059844970703 and EPOCH 152
MODEL SAVED AT ACCURACY: 77.07316589355469 and EPOCH 153
MODEL SAVED AT ACCURACY: 77.60531616210938 and EPOCH 154
MODEL SAVED AT ACCURACY: 79.51219940185547 and EPOCH 158
MODEL SAVED AT ACCURACY: 79.5565414428711 and EPOCH 171
MODEL SAVED AT ACCURACY: 80.57649993896484 and EPOCH 173
MODEL SAVED AT ACCURACY: 81.50775909423828 and EPOCH 195
MODEL SAVED AT ACCURACY: 81.68514251708984 and EPOCH 196
MODEL SAVED AT ACCURACY: 82.1729507446289 and EPOCH 231
MODEL SAVED AT ACCURACY: 89.8004379

In [35]:
torch.save(best, f"./saved_models/{stime}_Best_Model_Epoch_{epoch}_Acc_{max(accuracies)}_{lr}_{amsgrad}.pt")

In [36]:
net.load_state_dict(best)
new_accuracy()

tensor(58.8470, device='cuda:0')

In [37]:
print(best)

OrderedDict([('inlayer.weight', tensor([[-0.0271, -0.0271, -0.0179,  ..., -0.0148,  0.0136, -0.0085],
        [ 0.0214,  0.0172, -0.0002,  ...,  0.0060, -0.0062,  0.0109],
        [-0.0569, -0.0381, -0.0543,  ...,  0.0042, -0.0101, -0.0135],
        ...,
        [-0.0094, -0.0139, -0.0167,  ...,  0.0125,  0.0010, -0.0057],
        [-0.0234, -0.0040, -0.0083,  ...,  0.0060,  0.0123, -0.0084],
        [-0.0131,  0.0063, -0.0042,  ...,  0.0107, -0.0050,  0.0125]],
       device='cuda:0')), ('inlayer.bias', tensor([-0.0351,  0.0239, -0.0089,  ...,  0.0101,  0.0078, -0.0089],
       device='cuda:0')), ('bnorm1.weight', tensor([1.0736, 1.0731, 1.0695,  ..., 1.0613, 1.0492, 1.0012], device='cuda:0')), ('bnorm1.bias', tensor([-0.0025,  0.0017, -0.0004,  ..., -0.0056, -0.0011,  0.0115],
       device='cuda:0')), ('bnorm1.running_mean', tensor([-0.0351,  0.0239, -0.0089,  ...,  0.0101,  0.0078, -0.0089],
       device='cuda:0')), ('bnorm1.running_var', tensor([9.1924e-08, 1.9659e-08, 4.2868e-08,