In [1]:
import numpy as np
import matplotlib.pyplot as plt

%reload_ext autoreload
%autoreload 2

from nmrnet2 import *
from nmrdataset import *
import nmrMod as nmr

import torch
import torch as th
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader
from safetensors.torch import load_model
from torchinfo import summary

#from collections.abc import Callable
from tqdm.notebook import trange, tqdm

In [3]:
# Build the network on the GPU when one is available (CPU otherwise) and print a
# layer-by-layer summary for a single input of length nmr.nPts.
# torch's RNGs are seeded first so the model's initial weights are reproducible
# across kernel restarts (torch.manual_seed seeds the CPU and all CUDA devices).
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = NMRNet().to(device)
summary(model, input_size=(1, nmr.nPts))

Layer (type:depth-idx)                   Output Shape              Param #
NMRNet                                   [1, 1, 1024]              --
├─ParallelCNNBlock: 1-1                  [1, 136, 1024]            --
│    └─ModuleList: 2-1                   --                        --
│    │    └─Sequential: 3-1              [1, 16, 1024]             51
│    │    └─Sequential: 3-2              [1, 32, 1024]             165
│    │    └─Sequential: 3-3              [1, 64, 1024]             585
│    │    └─Sequential: 3-4              [1, 16, 1024]             561
│    │    └─Sequential: 3-5              [1, 8, 1024]              585
├─Conv1d: 1-2                            [1, 64, 1024]             8,768
├─Conv1d: 1-3                            [1, 32, 1024]             2,080
├─LSTM: 1-4                              [1, 1024, 32]             6,400
├─Conv1d: 1-5                            [1, 32, 1024]             1,056
├─Conv1d: 1-6                            [1, 16, 1024]             52

In [9]:
# Synthetic NMR datasets. The test set is built with startSeed=ML so its samples
# come from seeds after the training set's (assumes the training set's default
# startSeed is 0 - TODO confirm in nmrdataset).
ML = 10000        # training-set size, passed as maxLen
ML_test = 500     # test-set size, passed as maxLen
batch_size = 32

train_set = NMRDataset(maxLen = ML, mode = "wide")
# NOTE(review): the test set does not pass mode="wide", so it may be generated
# with a different mode than the training set - confirm this is intended.
test_set = NMRDataset(maxLen = ML_test, startSeed = ML)

# Shuffle the training data so each epoch sees a different batch composition and
# order; the test loader keeps a fixed order because it is only evaluated.
train_loader: DataLoader = DataLoader(
    dataset=train_set, batch_size=batch_size, shuffle=True,
    num_workers=4
)
test_loader: DataLoader = DataLoader(
    dataset=test_set,  batch_size=batch_size, shuffle=False,
    num_workers=4
)

In [10]:
# 1. Define Loss
# BCEWithLogitsLoss applies the sigmoid internally, so the model must return raw
# logits; targets must match the output shape and hold float values in [0, 1]
# (presumably what NMRDataset yields - TODO confirm).
criterion = nn.BCEWithLogitsLoss()

# 2. Define Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 3. Define Number of Epochs
NUM_EPOCHS = 100

# 4. Arrays to store loss history (per-sample mean loss, one entry per epoch)
train_losses = []
test_losses = []

# NOTE: `model` is created in an earlier cell, so re-running this cell continues
# training the existing weights (with a fresh optimizer) instead of restarting.

# Real number of batches per epoch, ceil(len(dataset) / batch_size); correct for
# any ML, unlike ML // batch_size + 1, which overcounts when ML divides evenly.
n_train_batches = len(train_loader)

for epoch in (bar := trange(NUM_EPOCHS, desc="Training   | Training epoch", 
                            bar_format="{desc}:{percentage:3.0f}%|{bar:50}{r_bar}")):
    # --- Training Phase ---
    model.train()
    running_train_loss = 0.0
    
    for i, (inputs, targets) in enumerate(train_loader):
        # Display a 1-based batch counter so the last batch reads N/N.
        bar.set_description_str(f"Training   - Batch no {i + 1:04}/{n_train_batches:04} | Training epoch")
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        # --- Forward Pass ---
        # The model returns raw logits
        outputs = model(inputs)
        
        # --- Calculate Loss ---
        # criterion compares the raw logits (outputs) with the targets
        loss = criterion(outputs, targets)
        
        # --- Backward Pass and Optimization ---
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # loss is a batch mean; weight by batch size so the epoch average is
        # exact even when the last batch is smaller.
        running_train_loss += loss.item() * inputs.size(0)

    # Calculate average training loss for the epoch
    epoch_train_loss = running_train_loss / len(train_loader.dataset)
    train_losses.append(epoch_train_loss)

    # --- Evaluation (Test) Phase ---
    model.eval()
    running_test_loss = 0.0
    
    # Use torch.no_grad() to disable gradient calculations
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            
            # Forward pass
            outputs = model(inputs)
            
            # Calculate loss
            loss = criterion(outputs, targets)
            
            running_test_loss += loss.item() * inputs.size(0)

    # Calculate average test loss for the epoch
    epoch_test_loss = running_test_loss / len(test_loader.dataset)
    test_losses.append(epoch_test_loss)

    # Surface the latest epoch losses next to the progress bar ({r_bar} includes the postfix).
    bar.set_postfix(train_loss=f"{epoch_train_loss:.4f}", test_loss=f"{epoch_test_loss:.4f}")
    

Training   | Training epoch:  0%|                                                  | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 