# Todo
 - Masked transformers
 - Any further data reductions?
 

In [None]:
from CoReDataLoader import  dataset, dataloader

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchmetrics as metrics
import wandb

In [None]:
class classifier(nn.Module):
    """Fully connected network producing two outputs per sample.

    Layer widths are ``input_length -> 12288 -> 8192 -> 4096 -> 2``. A
    ``BatchNorm1d`` sits in front of every projection except the first, and
    GELU follows every projection except the first (including the last one).

    Args:
        input_length: Number of input features per sample. Defaults to
            ``dataset.output_length``, read once when the class is defined.
    """

    def __init__(self,input_length = dataset.output_length):
        super().__init__()
        self.inlayer = nn.Linear(input_length, 4096 * 3)
        self.bnorm1 = nn.BatchNorm1d(4096 * 3)
        self.inter1 = nn.Linear(4096 * 3, 4096 * 2)
        self.bnorm2 = nn.BatchNorm1d(4096 * 2)
        self.inter2 = nn.Linear(4096 * 2, 4096)
        self.bnorm3 = nn.BatchNorm1d(4096)
        self.inter3 = nn.Linear(4096, 2)
        # Not used by forward(); kept so the attribute stays available.
        # The dimension is given explicitly because building Softmax without
        # one relies on deprecated implicit-dimension selection.
        self.softmax = nn.Softmax(dim=1)
        self.activation = nn.GELU()

    def forward(self, inp):
        """Run the network on a batch.

        Args:
            inp: Input tensor; cast to float32 before use.
                Assumes shape (batch, input_length) -- TODO confirm.

        Returns:
            The output of the last linear layer (2 features) after GELU.
        """
        itn = inp.to(torch.float32)
        # The second positional parameter of F.normalize is `p`, not `dim`,
        # so the norm order is 0 and the normalised dimension is the default
        # dim=1. Both are spelled out here to keep that behaviour visible.
        # NOTE(review): confirm p=0 is intended rather than dim=0 or an L2 norm.
        itn = F.normalize(itn, p=0, dim=1)

        # First projection: no activation is applied before the batch norm.
        itn = self.inlayer(itn)

        itn = self.bnorm1(itn)
        itn = self.inter1(itn)
        itn = self.activation(itn)

        itn = self.bnorm2(itn)
        itn = self.inter2(itn)
        itn = self.activation(itn)

        itn = self.bnorm3(itn)
        itn = self.inter3(itn)
        itn = self.activation(itn)

        return itn


# Single shared model instance, placed on the dataset's device.
net = classifier().to(dataset.device)

In [None]:
# Individual regression metrics (default device).
mse = metrics.MeanSquaredError()
mae = metrics.MeanAbsoluteError()
msle = metrics.MeanSquaredLogError()
mape = metrics.MeanAbsolutePercentageError()
smape = metrics.SymmetricMeanAbsolutePercentageError()
wmape = metrics.WeightedMeanAbsolutePercentageError()
lce = metrics.LogCoshError(num_outputs=2)

# All of the above bundled so that a single call evaluates every metric.
collection = metrics.MetricCollection(
    [wmape, mse, mae, mape, smape, lce, msle]
)

# Separate WMAPE instance kept on the dataset's device.
gwmape = metrics.WeightedMeanAbsolutePercentageError().to(dataset.device)


In [None]:
def ensemble_error(metric : metrics.Metric = collection):
    """Score ``net`` on the entire dataset loaded as one batch.

    Args:
        metric: Callable metric (or metric collection) invoked as
            ``metric(predictions, targets)`` with both tensors on the CPU.
            Defaults to the module-level ``collection``.

    Returns:
        Whatever ``metric`` returns for the predictions against
        ``params[:, 1:]``.
    """
    # Score in eval mode so BatchNorm uses its running statistics and does not
    # update them; the previous mode is restored even if scoring fails.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            tss, params = next(iter(DataLoader(dataset, batch_size=len(dataset))))
            # Column 0 is not a regression target; the rest are.
            targets = params[:, 1:].to("cpu")
            # .to() is a no-op when the batch is already on the model's device.
            preds = net(tss.to(dataset.device)).to("cpu")
            return metric(preds, targets)
    finally:
        net.train(was_training)

In [None]:
# Score the network in its current state; the returned metrics are shown as
# the cell's output.
ensemble_error()

# Run the cells below this point to test a new run

In [None]:
"""
 ▄▄▄▄▄▄   ▄▄   ▄▄ ▄▄    ▄    ▄▄▄▄▄▄▄ ▄▄▄▄▄▄   ▄▄▄▄▄▄▄ ▄▄   ▄▄
█   ▄  █ █  █ █  █  █  █ █  █       █   ▄  █ █       █  █▄█  █
█  █ █ █ █  █ █  █   █▄█ █  █    ▄▄▄█  █ █ █ █   ▄   █       █
█   █▄▄█▄█  █▄█  █       █  █   █▄▄▄█   █▄▄█▄█  █ █  █       █
█    ▄▄  █       █  ▄    █  █    ▄▄▄█    ▄▄  █  █▄█  █       █
█   █  █ █       █ █ █   █  █   █   █   █  █ █       █ ██▄██ █
█▄▄▄█  █▄█▄▄▄▄▄▄▄█▄█  █▄▄█  █▄▄▄█   █▄▄▄█  █▄█▄▄▄▄▄▄▄█▄█   █▄█

"""
# Report the full-dataset metrics of the network as it stands at this point
# (presumably before the training cell is run -- depends on execution order).
print("Original error: ", ensemble_error())

In [None]:
# Loss terms; the training loop sums the two.
criterion = nn.MSELoss()
criterion2 = nn.L1Loss()

# Run hyperparameters.
epochs = 1200
lr = 3e-5
amsgrad = False

# Optimiser: Adam is the active choice, the SGD variant is kept for reference.
# sgd = optim.SGD(net.parameters(), lr=lr,momentum = 0.8,nesterov = True)
adam = optim.Adam(net.parameters(), lr=lr, amsgrad=amsgrad)
optimizer = adam

# Plateau scheduler tracking a quantity that should decrease.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=50,
)


In [None]:
from datetime import datetime

# Timestamp tag for this session: "<ISO date>-<hour>-<minute>".
# Hour and minute are not zero-padded.
ctime = datetime.now()
stime = "-".join(
    [ctime.date().isoformat(), str(ctime.hour), str(ctime.minute)]
)
stime

In [None]:
from collections import OrderedDict
# Placeholder for the best model state dict; empty until the training loop
# assigns one.
best = OrderedDict()

In [None]:
import pandas as pd

In [67]:
# Starts empty; the training loop concatenates each epoch's metrics onto it.
error_df = pd.DataFrame()

In [None]:
import copy
import os

# Start a Weights & Biases run and record the fixed hyperparameters.
wandb.init(project="M1M2regressor")
wandb.watch(net)
wandb.log({"lr":lr,"amsgrad":amsgrad})

# torch.save does not create missing directories.
os.makedirs("./saved_models", exist_ok=True)

losses = []  # training loss of every batch
errors = []  # full-dataset MeanAbsoluteError after every epoch (floats)
dee = ensemble_error()
bestepoch = 0
for epoch in range(epochs):
    for ts, params in dataloader:
        # Column 0 is not a regression target; train on the remaining columns.
        params = params[:,1:].to(dataset.device)
        optimizer.zero_grad()
        outputs = net(ts.to(dataset.device))

        # Objective: MSE + L1.
        loss = criterion(outputs,params) + criterion2(outputs,params)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        # The LR scheduler is currently not stepped.
        # NOTE(review): if re-enabled, call scheduler.step(error) once per
        # epoch after the evaluation below, where `error` is defined.
        # scheduler.step(error)
        # Detach so the logged metric does not hold on to the autograd graph.
        wandb.log({"loss":loss.item(),"batch_err":gwmape(outputs.detach(),params)})

    # Full-dataset evaluation after each epoch.
    dee = ensemble_error()
    error_df = pd.concat([error_df,pd.DataFrame(dee)],axis = 1)
    print(error_df)
    # Stored as a plain float so comparisons and filenames do not carry
    # tensor objects.
    error = float(dee["MeanAbsoluteError"])
    errors.append(error)
    # `error` has just been appended, so this is true exactly when the current
    # epoch ties or beats every earlier one.
    if error <= min(errors):
        # state_dict() returns references to the live parameters; copy them so
        # later optimiser steps do not overwrite the snapshot.
        best = copy.deepcopy(net.state_dict())
        torch.save(best, f"./saved_models/INTRAIN_{stime}_wmape_{lr}_{amsgrad}.pt")
        bestepoch = epoch
        print(f"MODEL SAVED AT MAE error: {error}, MSE error {float(dee['MeanSquaredError'])}, and EPOCH {epoch}")
    wandb.log(dee | {"epoch":epoch})


# Persist the best snapshot, tagged with the error it achieved (the minimum).
# Skipped when no epoch ran, because there is then nothing to save.
if errors:
    torch.save(best, f"./saved_models/{stime}_Best_Model_Epoch_{bestepoch}_err_{min(errors)}_{lr}_{amsgrad}.pt")

In [None]:
# Write the retained `best` state dict to disk, tagging the filename with the
# best epoch and the lowest recorded error.
torch.save(best, f"./saved_models/{stime}_Best_Model_Epoch_{bestepoch}_wmape_{min(errors)}_{lr}_{amsgrad}.pt")

In [None]:
# Load the retained `best` weights back into the network.
net.load_state_dict(best)

In [None]:
# Dump the retained state dict for inspection.
print(best)

In [None]:
# Number of samples to spot-check; set to len(dataset) to cover everything.
lends = 5
tss,params = next(iter(DataLoader(dataset, batch_size=lends,shuffle=True)))

# Predict in eval mode without autograd: with so small a batch, BatchNorm's
# batch statistics would otherwise distort the outputs and update its running
# statistics.
was_training = net.training
net.eval()
try:
    with torch.no_grad():
        outs = net(tss.to(dataset.device))
finally:
    net.train(was_training)

# One list of signed errors (prediction - target) per entry of dataset.eoss;
# column 0 of `params` is used as the index into that list.
errors = [[] for _ in dataset.eoss]
print(dataset.eoss)
# Go sample by sample so every (prediction, target) pair is filed under its own
# sample's label. Zipping the flattened outputs against the label column would
# stop after the first half of the pairs and attach the wrong labels.
for pred_row, target_row, c in zip(outs, params[:,1:], params[:,0]):
    for a, b in zip(pred_row, target_row):
        print(f"{a.item()},{b.item()} | {a.item()-b.item()} | {c.item()}")
        errors[int(c.item())].append(a.item()-b.item())
errors

In [None]:
# Display the current contents of `errors` as the cell output.
errors

In [None]:
from scipy.stats import describe

# Summary statistics for each group of errors.
for i, v in enumerate(errors):
    if not v:
        # describe() rejects empty input, and a small sample can easily leave
        # a group without any entries.
        print(i, "no samples")
        continue
    print(i, describe(v))

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Line plot of the grouped errors.
# NOTE(review): `errors` is a list of lists whose lengths can differ; confirm
# plt.plot accepts that, or plot each group separately.
plt.plot(errors)

In [None]:
# Histogram of the grouped errors (presumably one data set per inner list --
# confirm against the matplotlib hist documentation).
plt.hist(errors)

In [None]:
import torcheval as te
# Import the function from its submodule: `import torcheval` on its own does
# not guarantee that `te.metrics` has been loaded.
from torcheval.metrics.functional import r2_score

tss,params = next(iter(DataLoader(dataset, batch_size=len(dataset))))

# Predict in eval mode without autograd so BatchNorm uses (and does not update)
# its running statistics; the previous mode is restored afterwards.
was_training = net.training
net.eval()
try:
    with torch.no_grad():
        preds = net(tss.to(dataset.device))
finally:
    net.train(was_training)

# R^2 of the predictions against the target columns (all but column 0),
# computed on the device the predictions live on.
r2 = r2_score(preds, params[:,1:].to(preds.device))
r2