# Todo
 - Try masked transformers
 - Investigate whether any further data reductions are possible
 

In [None]:
from CoReDataLoader import  dataset, dataloader, maxlen
# Number of target labels; sizes the classifier output layer and every metric
# (presumably one entry of `dataset.eoss` per class — confirm in CoReDataLoader).
numclasses: int = len(dataset.eoss)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchmetrics as metrics

In [None]:
# Shared configuration for every torchmetrics object below: multiclass over
# `numclasses` labels, placed on the same device as the dataset.
_metric_kwargs = dict(task="multiclass", num_classes=numclasses)
_metric_device = dataset.device

acc = metrics.Accuracy(**_metric_kwargs).to(_metric_device)
auroc = metrics.AUROC(**_metric_kwargs).to(_metric_device)
prec = metrics.Precision(**_metric_kwargs).to(_metric_device)
f1score = metrics.F1Score(**_metric_kwargs).to(_metric_device)
avgprec = metrics.AveragePrecision(**_metric_kwargs).to(_metric_device)
precrecall = metrics.PrecisionRecallCurve(**_metric_kwargs).to(_metric_device)
recall = metrics.Recall(**_metric_kwargs).to(_metric_device)
roc = metrics.ROC(**_metric_kwargs).to(_metric_device)
conf = metrics.ConfusionMatrix(**_metric_kwargs).to(_metric_device)

In [None]:
class classifier(nn.Module):
    """Fully connected multiclass classifier over flattened input vectors.

    Layout: Linear(input_length -> 3*4096) -> BatchNorm -> GELU
            -> Linear(3*4096 -> 2*4096) -> GELU -> BatchNorm
            -> Linear(2*4096 -> num_classes).

    ``forward`` returns raw logits of shape ``(batch, num_classes)``, which is
    what ``nn.CrossEntropyLoss`` expects (it applies log-softmax itself).
    Parameter/buffer names are unchanged, so existing state_dicts still load.

    Args:
        input_length: features per sample; defaults to ``maxlen`` from
            CoReDataLoader.
        num_classes: size of the output layer; defaults to ``numclasses``.
    """

    def __init__(self, input_length=maxlen, num_classes=numclasses):
        super().__init__()
        self.inlayer = nn.Linear(input_length, 4096 * 3)
        self.bnorm1 = nn.BatchNorm1d(4096 * 3)
        self.inter1 = nn.Linear(4096 * 3, 4096 * 2)
        self.bnorm2 = nn.BatchNorm1d(4096 * 2)
        self.inter2 = nn.Linear(4096 * 2, num_classes)
        # Not used by forward(); kept (it has no parameters) for any code that
        # turns logits into probabilities via net.softmax. dim=1 is the class
        # axis of the (batch, num_classes) logits.
        self.softmax = nn.Softmax(dim=1)
        self.activation = nn.GELU()

    def forward(self, inp):
        """Map a batch of inputs (assumed ``(batch, input_length)``) to logits."""
        itn = inp.to(torch.float32)
        # L2-normalise each sample over its feature axis. Normalising over
        # dim 0 (the batch) made every prediction depend on the rest of the
        # batch, and for a batch of one reduced each feature to its sign.
        itn = F.normalize(itn, dim=-1)
        # Nonlinearity after the first block: without it, inlayer -> bnorm1 ->
        # inter1 is a composition of affine maps (BatchNorm is affine in eval).
        itn = self.activation(self.bnorm1(self.inlayer(itn)))
        itn = self.activation(self.inter1(itn))
        itn = self.bnorm2(itn)
        # No activation on the output: CrossEntropyLoss expects unbounded
        # logits, and GELU would floor negative logits at about -0.17.
        return self.inter2(itn)


net = classifier().to(dataset.device)

In [None]:
import wandb

In [None]:
def get_accuracy():
    """Return the accuracy of ``net`` over every sample of ``dataset``.

    Samples are fed one at a time (batch size 1) with the network in eval mode
    and gradients disabled. The network is put back into train mode afterwards,
    even if evaluation raises, because the training loop calls this between
    epochs.

    Returns:
        float: the ``acc`` metric over all logits vs. labels.
    """
    device = dataset.device
    n_samples = len(dataset)
    # Size from numclasses / dataset.device rather than a hard-coded 19 and
    # "cuda:0", so this works for any label count and on CPU-only machines.
    outputs = torch.zeros((n_samples, numclasses), device=device)
    # Integer labels: torchmetrics multiclass metrics expect int/long targets.
    corrects = torch.zeros(n_samples, dtype=torch.long, device=device)
    net.eval()
    try:
        with torch.no_grad():
            for ctr, (ts, params) in enumerate(dataset):
                outputs[ctr] = net(ts.view(1, -1).to(device))
                # First element of params is the label (params[:, 0] in training).
                corrects[ctr] = params.reshape(-1)[0]
    finally:
        net.train()
    return acc(outputs, corrects).item()

In [None]:
def new_accuracy():
    """Return the accuracy of ``net`` on the whole dataset, in percent.

    The dataset is loaded as a single batch. The forward pass runs in eval mode
    under ``torch.no_grad()``, so BatchNorm uses its running statistics and does
    not update them from the evaluation data. No autograd graph is built over
    the full dataset. The previous train/eval mode is restored afterwards.

    Returns:
        A 0-d tensor: the ``acc`` metric multiplied by 100.
    """
    device = dataset.device
    tss, params = next(iter(DataLoader(dataset, batch_size=len(dataset))))
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            logits = net(tss.to(device))
    finally:
        net.train(was_training)
    # Column 0 of params is the label, as in the training loop; metrics want long targets.
    targets = params[:, 0].to(device=device, dtype=torch.long)
    return acc(logits, targets) * 100


# Re-run the cells from here downward to test a new training run

In [None]:
# Report the accuracy (percent) of `net` before the training cell runs.
print("Original Accuracy: ", new_accuracy(), sep=" ")

In [None]:
# Training configuration: number of passes over the dataloader, the loss on
# the classifier's logits, and an AMSGrad variant of Adam.
epochs = 500
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=2e-3, amsgrad=True)

In [None]:
from datetime import datetime
# Local run date as YYYY-MM-DD; used as the prefix of the saved-model filename.
stime = datetime.today().strftime("%Y-%m-%d")

In [None]:
from collections import OrderedDict
# Model weights (a state_dict) for the best epoch; the training loop fills this
# same object in place, and later cells save and print it.
best: OrderedDict = OrderedDict()

In [None]:
wandb.init(project="simple ann classifier")
losses = []
# Seeded with 0 so max(accuracies) is defined; accuracies[i] belongs to epoch i - 1.
accuracies = [0]
sacc = []
accuracy = 0
for epoch in range(epochs):
    for ts, params in dataloader:
        # Column 0 of params is the target class; CrossEntropyLoss needs long targets.
        params = params[:, 0].to(torch.long).to(dataset.device)
        optimizer.zero_grad()
        outputs = net(ts.to(dataset.device))
        loss = criterion(outputs, params)
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        losses.append(loss_value)
        # The metric only needs values, not the autograd graph; .item() logs a
        # plain float instead of a (possibly GPU) tensor.
        wandb.log({"loss": loss_value, "batch_acc": acc(outputs.detach(), params).item()})
    accuracy = get_accuracy()
    accuracies.append(accuracy)
    # Evaluated after the append, so this fires when the epoch ties or beats the best so far.
    if accuracy >= max(accuracies):
        print(f"MODEL SAVED AT ACCURACY: {accuracy} and EPOCH {epoch}")
        # net.state_dict() returns tensors that share storage with the live
        # parameters and buffers. Storing them directly (as
        # state_dict(destination=best) did) let later optimizer steps overwrite
        # the "best" weights, so `best` always ended up holding the final model.
        # Snapshot detached copies instead. `best` is refilled in place so later
        # cells that save or print it still see the same object.
        best.clear()
        best.update(
            (name, tensor.detach().clone()) for name, tensor in net.state_dict().items()
        )
        wandb.alert(text=f"MODEL SAVED AT ACCURACY: {accuracy} and EPOCH {epoch}", title="Model Saved")
    wandb.log({"epoch": epoch, "acc": accuracy, "max_accuracy": max(accuracies)})
    

In [None]:
import os

# Label the file with the epoch that actually reached the best accuracy, not
# with the last epoch that ran. accuracies[0] is the 0 seed, so index i is
# epoch i - 1. The training loop saves on ties (>=), so take the LAST index
# reaching the maximum.
best_acc = max(accuracies)
best_epoch = max(i for i, a in enumerate(accuracies) if a == best_acc) - 1
# torch.save does not create missing directories.
os.makedirs("./saved_models", exist_ok=True)
torch.save(best, f"./saved_models/{stime}_Best_Model_Epoch_{best_epoch}_Acc_{best_acc}.pt")

In [None]:
# Show the stored best-model state_dict (OrderedDict's str() is its repr()).
print(repr(best))