# Todo
 - Masked transformers
 - Any further data reductions?
 

In [None]:
from CoReDataLoader import  dataset,dataloader, maxlen
# Number of target classes, taken as the number of entries in dataset.eoss.
# NOTE(review): presumably one entry per class label — confirm against CoReDataLoader.
numclasses = len(dataset.eoss)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchmetrics as metrics

In [None]:
# Every metric below is configured for the same multiclass problem, so the
# shared keyword arguments are defined once.
_metric_kwargs = {"task": "multiclass", "num_classes": numclasses}

# torchmetrics metrics are nn.Modules whose internal state starts on the CPU.
# They are moved to the device the network runs on so that they can be fed the
# network outputs directly (a no-op when dataset.device is the CPU).
acc = metrics.Accuracy(**_metric_kwargs).to(dataset.device)
auroc = metrics.AUROC(**_metric_kwargs).to(dataset.device)
prec = metrics.Precision(**_metric_kwargs).to(dataset.device)
f1score = metrics.F1Score(**_metric_kwargs).to(dataset.device)
avgprec = metrics.AveragePrecision(**_metric_kwargs).to(dataset.device)
precrecall = metrics.PrecisionRecallCurve(**_metric_kwargs).to(dataset.device)
recall = metrics.Recall(**_metric_kwargs).to(dataset.device)
roc = metrics.ROC(**_metric_kwargs).to(dataset.device)
conf = metrics.ConfusionMatrix(**_metric_kwargs).to(dataset.device)

In [None]:
class classifier(nn.Module):
    """Fully connected classifier that maps one feature vector to class logits.

    Layer widths are ``input_length -> 4096 -> 2056 -> 1024 -> 1024 ->
    numclasses``. Batch normalisation follows the first two linear stages and
    GELU is used between the hidden layers.

    Args:
        input_length: Number of features per sample. Defaults to
            ``dataset.output_length``.
    """

    def __init__(self, input_length=dataset.output_length):
        super().__init__()
        self.inlayer = nn.Linear(input_length, 4096)
        # forward() needs two batch-norm layers with different widths; the
        # previous single ``bnorm`` attribute did not match the names used
        # there and made every forward pass fail with AttributeError.
        self.bnorm1 = nn.BatchNorm1d(4096)
        self.bnorm2 = nn.BatchNorm1d(2056)
        self.silu = nn.SiLU()
        self.inter1 = nn.Linear(4096, 2056)
        self.inter2 = nn.Linear(2056, 1024)
        self.inter3 = nn.Linear(1024, 1024)
        self.output = nn.Linear(1024, numclasses)
        # Explicit dim: the implicit-dimension form of Softmax is deprecated.
        # Not applied in forward() because nn.CrossEntropyLoss expects logits.
        self.softmax = nn.Softmax(dim=1)
        self.activation = nn.GELU()

    def forward(self, inp):
        """Return unnormalised class scores for a batch.

        Args:
            inp: Batch of samples; it is cast to float32 before use.
                Assumed to be shaped (batch, input_length) — TODO confirm
                against the data loader.

        Returns:
            Tensor with one row of ``numclasses`` logits per sample.
        """
        itn = inp.to(torch.float32)
        # NOTE(review): the second positional argument of F.normalize is the
        # norm order ``p``, not ``dim``. The original call ``F.normalize(itn, 0)``
        # therefore divides each row by its number of non-zero entries. That
        # behaviour is kept here and only made explicit; confirm whether
        # ``dim=0`` or a plain L2 normalisation was intended.
        itn = F.normalize(itn, p=0)
        itn = self.inlayer(itn)
        itn = self.bnorm1(itn)
        itn = self.inter1(itn)
        itn = self.activation(itn)
        itn = self.bnorm2(itn)
        itn = self.inter2(itn)
        itn = self.activation(itn)
        # The remaining layers were defined but never applied, so the network
        # returned 1024 features instead of one score per class.
        itn = self.inter3(itn)
        itn = self.activation(itn)
        return self.output(itn)


net = classifier().to(dataset.device)

In [None]:
import wandb

In [None]:
def get_accuracy():
    """Return the accuracy of ``net`` over ``dataset``, one sample at a time.

    Each sample is flattened to a single-row batch and passed through the
    network in evaluation mode with gradients disabled. The first entry of
    each sample's ``params`` is used as its class label.

    Returns:
        The value of the module-level ``acc`` metric as a Python float.

    Raises:
        ValueError: If ``dataset`` yields no samples.
    """
    device = dataset.device  # same device the network was moved to
    logits = []
    labels = []
    net.eval()
    try:
        with torch.no_grad():
            for ts, params in dataset:
                logits.append(net(ts.view(1, -1).to(device)))
                labels.append(params.view(1, -1)[:, 0])
    finally:
        # Always hand the network back in training mode, even on failure.
        net.train()
    if not logits:
        raise ValueError("dataset is empty; accuracy is undefined")
    # Concatenating the per-sample results keeps the output width tied to the
    # network itself rather than to a hard-coded class count.
    outputs = torch.cat(logits)
    # Class labels must be integer indices on the same device as the scores.
    corrects = torch.cat(labels).to(outputs.device, torch.long)
    return acc(outputs, corrects).item()

In [None]:
def new_accuracy():
    """Return the accuracy (in percent) of ``net`` on the whole dataset.

    The dataset is loaded as one batch whose size is ``len(dataset)`` and
    passed through the network in evaluation mode with gradients disabled.
    The first column of ``params`` is used as the class label.

    Returns:
        The value of the module-level ``acc`` metric multiplied by 100
        (a tensor, as returned by the metric).
    """
    net.eval()
    try:
        with torch.no_grad():
            tss, params = next(iter(DataLoader(dataset, batch_size=len(dataset))))
            # The forward pass has to happen while the network is still in
            # eval mode and inside no_grad; otherwise BatchNorm uses and
            # updates batch statistics and an autograd graph is built.
            outputs = net(tss.to(dataset.device))
            targets = params[:, 0].to(outputs.device, torch.long)
    finally:
        # Leave the network in training mode, as callers expect.
        net.train()
    return acc(outputs, targets) * 100


# Run the cells below this point to start a new training run

In [None]:
# Print the value returned by new_accuracy() before the training cells below are run.
print("Original Accuracy: ", new_accuracy())

In [None]:
# Training hyperparameters.
epochs = 600
lr = 5e-5
amsgrad = False

# Loss applied to the network outputs and the integer class labels.
criterion = nn.CrossEntropyLoss()

# Two optimisers are built over the same parameters; only the one bound to
# `optimizer` is stepped by the training loop.
adam = optim.Adam(
    net.parameters(),
    lr=lr,
    amsgrad=amsgrad,
)
sgd = optim.SGD(
    net.parameters(),
    lr=lr,
    momentum=0.8,
    nesterov=True,
)
optimizer = adam


In [None]:
from collections import OrderedDict
# Container for the best model's state dict. It starts empty and is later
# passed to torch.save and net.load_state_dict.
best = OrderedDict()

In [None]:
# Train `net` for `epochs` passes over `dataloader`, logging per-batch metrics
# to Weights & Biases and keeping a copy of the weights from the epoch with the
# highest training accuracy in `best`.
wandb.init(project="simple ann classifier")
losses = []
accuracies = [0]  # seeded with 0 so max(accuracies) is always defined
sacc = []
accuracy = 0
bestepoch = 0
for epoch in range(epochs):
    epoch_correct = 0
    epoch_seen = 0
    for batch, (ts, params) in enumerate(dataloader):
        # First column of params is the class label; the loss needs integer
        # labels on the same device as the network outputs.
        params = params[:, 0].to(dataset.device, torch.long)
        optimizer.zero_grad()
        outputs = net(ts.to(dataset.device))
        loss = criterion(outputs, params)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        # Metrics are for reporting only, so no autograd graph is needed.
        with torch.no_grad():
            # Computed once per batch: every call to a torchmetrics metric also
            # accumulates into its running state, so calling it again for the
            # printout would count the batch twice.
            batch_acc = acc(outputs, params)
            epoch_correct += (outputs.argmax(dim=1) == params).sum().item()
            epoch_seen += params.numel()
            wandb.log({"loss":loss.item(),"batch_acc":batch_acc})

            print(f"{epoch+1}/{epochs} {batch+1}/{len(dataloader)} loss = {loss.item()} accuracy = {batch_acc*100}% {auroc(outputs,params) = } {prec(outputs,params) = } {f1score(outputs,params) = } {avgprec(outputs,params) = } {recall(outputs,params) = } \n",end = "\r",flush = True)

    # Training accuracy over the whole epoch (0.0 if the loader was empty).
    accuracy = epoch_correct / epoch_seen if epoch_seen else 0.0
    if accuracy > max(accuracies):
        bestepoch = epoch
        # state_dict() returns references to the live parameters, so the
        # tensors are cloned; otherwise `best` would follow later updates.
        best.clear()
        best.update(
            (name, tensor.detach().clone())
            for name, tensor in net.state_dict().items()
        )
    accuracies.append(accuracy)
    # wandb.log requires a dict of values; the bare call raised TypeError.
    wandb.log({"epoch": epoch + 1, "epoch_acc": accuracy})
    print(f"Epoch finished: {epoch+1}\n")

In [None]:
import os
import time

# `stime` (the run tag at the start of the file name) is not defined in any
# visible cell; fall back to the current timestamp instead of raising NameError.
try:
    stime
except NameError:
    stime = time.strftime("%Y%m%d-%H%M%S")

# torch.save does not create missing directories.
os.makedirs("./saved_models", exist_ok=True)

# The file name records the epoch the stored weights belong to (`bestepoch`),
# not the index of the last epoch that happened to run.
torch.save(best, f"./saved_models/{stime}_Best_Model_Epoch_{bestepoch}_Acc_{max(accuracies)}_{lr}_{amsgrad}.pt")

In [None]:
# Restore the best recorded weights before evaluating. An empty `best` means no
# checkpoint was captured; load_state_dict would then fail on missing keys, so
# the current weights are evaluated instead.
if best:
    net.load_state_dict(best)
else:
    print("No best checkpoint recorded; evaluating the current weights.")
new_accuracy()

In [None]:
# Dump the contents of `best` (the mapping passed to net.load_state_dict) for inspection.
print(best)