In [None]:
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

import NMRAux as nmr
import Layers as ly

from torchinfo import summary

import torch as th
import torch.nn as nn
from torch import Tensor
from torch.utils.data import Dataset,DataLoader
#from torchvision.models import googlenet

#from collections.abc import Callable
from tqdm.notebook import trange

In [None]:
# Preview: generate the spectrum for seed 0 and mark, in orange, the indices
# that the generator returns alongside it.
yy, res = nmr.generateRandomSpectrum(0)
true_trace = yy["true"]
marker_heights = [true_trace[i] for i in res]

fig, ax = plt.subplots(figsize=(8, 2))
ax.plot(true_trace)
# zorder=10 keeps the markers drawn on top of the line.
ax.scatter(res, marker_heights, c="orange", zorder=10)

In [None]:
class NMRDataset(Dataset):
    """Map-style dataset that generates one synthetic spectrum per index.

    Nothing is stored: item ``idx`` is produced on demand by calling
    ``nmr.generateRandomSpectrum(idx + startSeed)``, so two datasets whose
    seed ranges do not overlap never request the same seed.

    Args:
        maxLen: Number of items the dataset reports through ``len()``.
        startSeed: Offset added to the item index to form the generator seed.
    """

    def __init__(self, maxLen = 250000, startSeed = 0):
        self.maxLen = maxLen
        self.startSeed = startSeed

    def __len__(self):
        """Return the configured number of items."""
        return self.maxLen

    def __getitem__(self, idx):
        """Generate the sample for ``idx``.

        Returns:
            A ``(spectrum, is_peak)`` pair of float32 tensors, each reshaped to
            ``(1, -1)``. ``is_peak`` is 1.0 at every index listed by the
            generator and 0.0 elsewhere.
        """
        yy, res = nmr.generateRandomSpectrum(idx + self.startSeed)

        # Copy into float32 so inputs and labels share one dtype; the label
        # mask previously inherited whatever dtype the generator produced.
        spectrum = np.array(yy["true"], dtype=np.float32)
        isPk = np.zeros(spectrum.shape, dtype=np.float32)

        # One indexed assignment replaces the per-index Python loop. The
        # explicit integer dtype keeps an empty index list valid.
        isPk[np.asarray(res, dtype=np.intp)] = 1.0

        return (
            th.from_numpy(spectrum.reshape(1, -1)),
            th.from_numpy(isPk.reshape(1, -1)),
        )

In [None]:
# Sizes of the two splits and the mini-batch size shared by both loaders.
ML = 10000
ML_test = 500
batch_size = 256

# The test split starts at seed ML, directly after the training seeds
# 0 .. ML-1, so the two splits never request the same generator seed.
train_set = NMRDataset(maxLen=ML)
test_set = NMRDataset(startSeed=ML, maxLen=ML_test)

# Both loaders are configured identically: fixed order, two worker processes.
# NOTE(review): the training loader is not shuffled, so every epoch sees the
# same batches in the same order - confirm this is intended.
_loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=2)

train_loader: DataLoader = DataLoader(dataset=train_set, **_loader_kwargs)
test_loader: DataLoader = DataLoader(dataset=test_set, **_loader_kwargs)

In [None]:
# Sanity check: draw the input tensor of the first training sample.
# It is stored as a single row, so it is transposed before plotting.
first_spectrum = train_set[0][0]

fig, ax = plt.subplots(figsize=(8, 2))
ax.plot(first_spectrum.T)

In [None]:
def NMRSeq() -> th.nn.Sequential:
    """Build the per-point peak-detection network.

    Pipeline: an inception-style feature extractor, a point-wise MLP
    (136 -> 64 -> 32), a bidirectional LSTM along the spectrum
    (2 x 16 = 32 features out), and a point-wise MLP head (32 -> 32 -> 16 -> 1).
    The two ``TransposeLayer(-1, -2)`` calls presumably swap the last two axes
    so that the ``Linear``/``LSTM`` layers see features last, and the result is
    handed back in the layout of the input - TODO confirm in ``Layers``.

    Returns:
        A freshly initialised ``th.nn.Sequential``.
    """
    return th.nn.Sequential(
        # Feature extractor; presumably emits 136 channels, since the Linear
        # that follows the transpose expects in_features=136 - TODO confirm.
        ly.Inception_variant(1),

        # Move the feature axis last for the Linear/LSTM layers.
        ly.TransposeLayer(-1,-2),

        th.nn.Linear(in_features=136, out_features=64, bias=True),
        th.nn.ReLU(),

        th.nn.Linear(in_features=64, out_features=32, bias=True),
        th.nn.ReLU(),

        # batch_first=True is required here: the tensor arriving at this point
        # has the batch on dim 0. With the default (batch_first=False) nn.LSTM
        # treats dim 0 as the time axis and would recur across the samples of
        # the mini-batch instead of along the spectrum.
        nn.LSTM(32, 16, bidirectional=True, batch_first=True),
        # nn.LSTM returns (output, (h_n, c_n)); extract_tensor presumably
        # unwraps the output tensor - TODO confirm in ``Layers``.
        ly.extract_tensor(),

        th.nn.ReLU(),

        th.nn.Linear(in_features=32, out_features=32, bias=True),
        th.nn.ReLU(),

        th.nn.Linear(in_features=32, out_features=16, bias=True),
        th.nn.ReLU(),

        th.nn.Linear(in_features=16, out_features=1, bias=True),
        # NOTE(review): this final ReLU clamps the single output score to be
        # non-negative; if the score is consumed as a logit, consider dropping it.
        th.nn.ReLU(),

        # Restore the axis order that was swapped after the feature extractor.
        ly.TransposeLayer(-1,-2),
    )

In [None]:
# Pick the compute device: a CUDA GPU when PyTorch reports one, else the CPU.
# (Other back-ends such as mps or tpu have to be selected here by hand.)
_backend_name = "cuda" if th.cuda.is_available() else "cpu"
device: th.device = th.device(_backend_name)

In [None]:
# Instantiate the network, move it to the selected device, then print a
# layer-by-layer summary for one full mini-batch of single-channel spectra.
model: th.nn.Module = NMRSeq()
model = model.to(device)

summary_input_shape = (batch_size, 1, nmr.nPts)
summary(model, input_size=summary_input_shape)

In [None]:
# Adam with a fixed learning rate and no weight decay.
optimizer: th.optim.Optimizer = th.optim.Adam(
    params=model.parameters(), lr=0.001, weight_decay=0
)

# Per-point binary objective. The network ends in Linear(out_features=1), so
# its output has a single channel, and the labels are a 0/1 mask in the same
# layout. nn.CrossEntropyLoss normalises over dim 1; with only one entry there
# the log-softmax is always 0, so the loss is identically zero and yields no
# gradient. BCEWithLogitsLoss instead scores every point independently.
lossCriterion = nn.BCEWithLogitsLoss()

In [None]:
EPOCHS = 100
# NOTE(review): BATCH_SIZE is not read anywhere in this cell; kept only so
# that code referring to it elsewhere keeps working.
BATCH_SIZE = 32

# Per-epoch history: mean training loss, training accuracy, test accuracy.
eval_losses = []
eval_acc = []
test_acc = []


def _predict_labels(scores: Tensor) -> Tensor:
    """Convert raw model scores into hard predictions.

    With several entries on dim 1 the arg-max class index is returned. With a
    single entry an arg-max would always be 0, so the score is treated as a
    logit instead and a point is labelled 1 when the score is positive.
    """
    if scores.shape[1] > 1:
        return th.argmax(scores, dim=1, keepdim=True)
    return (scores > 0).to(scores.dtype)


def _evaluate(net: nn.Module, loader: DataLoader, criterion: nn.Module,
              dev: th.device):
    """Evaluate ``net`` on every batch of ``loader`` without tracking gradients.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` weights each batch loss by its
        batch size (assumes the criterion returns a batch mean).
        ``accuracy`` is the fraction of predicted values equal to their label.
    """
    loss_sum = 0.0
    n_correct = 0
    n_samples = 0
    n_points = 0
    with th.no_grad():
        for x_e, y_e in loader:
            x_e, y_e = x_e.to(dev), y_e.to(dev)
            scores = net(x_e)
            if y_e.is_floating_point():
                # Float targets must share the dtype of the model output.
                y_e = y_e.to(scores.dtype)
            batch_n = x_e.shape[0]
            loss_sum += criterion(scores, y_e).item() * batch_n
            pred = _predict_labels(scores)
            n_correct += pred.eq(y_e.view_as(pred)).sum().item()
            n_samples += batch_n
            n_points += pred.numel()
    # max(..., 1) guards against an empty loader.
    return loss_sum / max(n_samples, 1), n_correct / max(n_points, 1)


# Loop over epochs
epoch_bar = trange(EPOCHS, desc="Training epoch")
for epoch in epoch_bar:

    print (f"Epoch #{epoch}")

    model.train()  # training mode before the optimisation pass
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        # Forward pass + loss computation
        yhat = model(x)
        if y.is_floating_point():
            y = y.to(yhat.dtype)
        loss = lossCriterion(yhat, y)

        # Zero-out past gradients, back-propagate, update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()  # evaluation mode before measuring metrics
    print (f"Eval")

    # Re-score the training data with the updated weights, then the held-out
    # test split.
    train_loss, train_accuracy = _evaluate(model, train_loader, lossCriterion, device)
    test_loss, test_accuracy = _evaluate(model, test_loader, lossCriterion, device)

    eval_losses.append(train_loss)
    eval_acc.append(train_accuracy)
    test_acc.append(test_accuracy)

    # Show the latest metrics on the progress bar instead of printing per batch.
    epoch_bar.set_postfix(
        train_loss=train_loss, train_acc=train_accuracy, test_acc=test_accuracy
    )