In [13]:
from CoReDataLoader import  dataset,dataloader, maxlen
numclasses = len(dataset.eoss)

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics as metrics

In [15]:
acc = metrics.Accuracy(task = "multiclass",num_classes = numclasses)
auroc = metrics.AUROC(task = "multiclass",num_classes=numclasses)
prec = metrics.Precision(task="multiclass",num_classes=numclasses)
f1score = metrics.F1Score(task = "multiclass",num_classes=numclasses)
avgprec = metrics.AveragePrecision(task = "multiclass",num_classes=numclasses)
precrecall = metrics.PrecisionRecallCurve(task = "multiclass",num_classes = numclasses)
recall = metrics.Recall(task = "multiclass",num_classes=numclasses)
roc = metrics.ROC(task = "multiclass",num_classes=numclasses)
conf = metrics.ConfusionMatrix(task = "multiclass",num_classes=numclasses)

In [None]:
class classifier(nn.Module):
    def __init__(self,input_length = maxlen):
        super().__init__()
        self.inlayer = nn.Linear(input_length,4096)
        self.bnorm = nn.BatchNorm1d(4096)
        self.silu = nn.SiLU()
        self.inter1 = nn.Linear(4096,2056)
        self.inter2 = nn.Linear(2056,1024)
        self.inter3 = nn.Linear(1024,1024)
        self.output = nn.Linear(1024,numclasses)
        self.softmax = nn.Softmax()

        
    def forward(self,inp):
        inp = inp.to(torch.float32)
        itn = self.inlayer(inp)
        itn = self.bnorm(itn)
        itn = self.silu(itn)
        itn = self.silu(self.inter1(itn))
        itn = self.silu(self.inter2(itn))
        itn = self.silu(self.inter3(itn))
        itn = self.silu(self.output(itn))
        return itn
net = classifier()

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(),lr = 1,betas = (0.9,0.999))
epochs = 100
net.to("cpu")

classifier(
  (inlayer): Linear(in_features=40817, out_features=4096, bias=True)
  (bnorm): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (silu): SiLU()
  (inter1): Linear(in_features=4096, out_features=2056, bias=True)
  (inter2): Linear(in_features=2056, out_features=1024, bias=True)
  (inter3): Linear(in_features=1024, out_features=1024, bias=True)
  (output): Linear(in_features=1024, out_features=19, bias=True)
  (softmax): Softmax(dim=None)
)

In [16]:
losses = []
correct = 0
for epoch in range(epochs):
    for batch, (ts,params) in enumerate(dataloader):
        params = params[:,0].to(torch.long)
        optimizer.zero_grad()
        outputs = net(ts)
        # print(outputs.shape,params[:,0].to(torch.long).shape)
        loss = criterion(outputs,params.to(torch.long))
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        print(f"{epoch+1}/{epochs} {batch+1}/{len(dataloader)} loss = {loss.item()} accuracy = {acc(outputs,params)*100}% {auroc(outputs,params) = } {prec(outputs,params) = } {f1score(outputs,params) = } {avgprec(outputs,params) = } {recall(outputs,params) = } \n",end = "\r",flush = True)
        correct = 0
    print(f"Epoch finished: {epoch+1}\n")

1/100 1/148 loss = 2.944438934326172 accuracy = 0.0% auroc(outputs,params) = tensor(0.1053) prec(outputs,params) = tensor(0.) f1score(outputs,params) = tensor(0.) avgprec(outputs,params) = tensor(0.2500) recall(outputs,params) = tensor(0.) 
1/100 2/148 loss = 2.944438934326172 accuracy = 0.0% auroc(outputs,params) = tensor(0.2368) prec(outputs,params) = tensor(0.) f1score(outputs,params) = tensor(0.) avgprec(outputs,params) = tensor(0.1111) recall(outputs,params) = tensor(0.) 
1/100 3/148 loss = 2.944438934326172 accuracy = 0.0% auroc(outputs,params) = tensor(0.1316) prec(outputs,params) = tensor(0.) f1score(outputs,params) = tensor(0.) avgprec(outputs,params) = tensor(0.2000) recall(outputs,params) = tensor(0.) 
1/100 4/148 loss = 2.944438934326172 accuracy = 0.0% auroc(outputs,params) = tensor(0.1842) prec(outputs,params) = tensor(0.) f1score(outputs,params) = tensor(0.) avgprec(outputs,params) = tensor(0.1429) recall(outputs,params) = tensor(0.) 
1/100 5/148 loss = 2.944438934326172

In [None]:
import matplotlib.pyplot as plt
plt.plot(losses)
plt.show()