# Group 74 - Project 4

# Notebook for reproducing main results
### This notebook trains the 2LSTM_Attention model (class `LSTM2_attention`) for `EPOCHS` epochs; the default is 2 epochs.
This notebook requires the synthetic data file "golf_spectrograms_tensor.pt" to be downloaded and placed in the same directory as this notebook. Synthetic data is used because the original data cannot be shared publicly. <br>
Since the data is synthetic, the training and test sets are built from the same data; the results are not intended to show the actual model performance.

In [31]:
# Import packages
import pandas as pd
import torch.nn as nn
import torch
import numpy as np

from typing import Iterable
from torch.utils.data import Dataset, DataLoader
from torch.nn import Module
from torchvision.transforms.functional import resize
from torchvision.transforms import transforms
from torch.nn.functional import mse_loss as torch_mse_loss
from numpy import log10
from pathlib import Path


### First, global constants are defined.

In [32]:
# Number of passes over the training data performed by the training loop.
EPOCHS = 2 # the model converges in test performance after ~250-300 epochs
# Learning rate handed to the optimizer.
LEARNING_RATE = 10**-4
# NOTE(review): not referenced in the visible code (the optimizer is built
# with `lr` only) -- confirm whether weight decay was meant to be applied.
WEIGHT_DECAY = 10**-3
# Batch size of the training DataLoader.
BATCH_SIZE = 10
# NOTE(review): not referenced in the visible code (both DataLoaders are
# created with num_workers=0) -- confirm before relying on it.
NUM_WORKERS = 10
# Optimizer class; it is instantiated later with the model parameters.
OPTIMIZER = torch.optim.Adam
# Device on which the data tensor, the model and the batches are placed.
DEVICE = "cpu"

# NOTE(review): the three constants below are not referenced in the visible
# code; presumably they describe how the spectrograms were generated
# (FFT length and crop windows) -- confirm against the data pipeline.
NFFT = 512
TS_CROPTWIDTH = (-150, 200)
VR_CROPTWIDTH = (-60, 15)

### Below, some necessary functions for data pre-processing and prediction are defined.

In [33]:
def print_model_complexity(model: "type[Module]", return_params: bool = False) -> "int | None":
    """Print the number of parameters of a freshly constructed model.

    Args:
        model: Pytorch model class (any zero-argument callable that has a
            ``__name__`` and returns a ``Module`` works as well). It is
            instantiated here with its default arguments.
        return_params: When true, the parameter count is also returned.

    Returns:
        The total number of parameters if ``return_params`` is true,
        otherwise ``None``.
    """
    # The count covers every parameter, trainable or not, of a default instance.
    total_params = sum(p.numel() for p in model().parameters())

    print(f"Number of parameters in model {model.__name__}: {total_params} = {total_params:.2e}")

    if return_params:
        return total_params
    return None

def mse_loss(output: torch.Tensor,
             target: torch.Tensor) -> torch.Tensor:
    """Return the mean squared error between ``output`` and ``target``.

    Thin wrapper around torch's functional ``mse_loss`` called with its
    default arguments, so the model class can expose it as ``loss_fn``.
    """
    loss = torch_mse_loss(input=output, target=target)
    return loss

def weights_init_uniform_rule(m):
    """Initialise a linear layer with the uniform rule; ignore other modules.

    Meant to be passed to ``Module.apply``, which calls it once for every
    submodule. For an ``nn.Linear`` layer with ``n`` input features the
    weights are drawn from U(-1/sqrt(n), 1/sqrt(n)) and the bias, if the
    layer has one, is set to zero.

    Args:
        m: Any module; only ``nn.Linear`` instances (and subclasses) are
            modified, in place.
    """
    # A type check instead of a class-name substring match: names such as
    # "Bilinear" contain "Linear" but have no `in_features` attribute.
    if not isinstance(m, nn.Linear):
        return

    # get the number of the inputs
    n = m.in_features
    y = 1.0 / n ** .5
    nn.init.uniform_(m.weight, -y, y)

    # Layers created with bias=False have `bias is None`.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

def train_one_epoch(loss_fn, model, train_data_loader, optimizer, log_every=None):
    """Run one optimisation pass over the training data.

    Args:
        loss_fn: Callable ``loss_fn(predictions, target)`` returning a scalar
            tensor.
        model: Model to train; it is called on each batch moved to ``DEVICE``.
        train_data_loader: Iterable of ``(data, target)`` batches exposing a
            ``batch_size`` attribute (e.g. a ``DataLoader``).
        optimizer: Optimizer holding the model parameters.
        log_every: Print the mean loss of the last ``log_every`` batches every
            ``log_every`` batches. Defaults to the loader's ``batch_size``,
            which is the interval this function has always used. If both are
            ``None`` nothing is printed.

    Returns:
        The mean of the per-batch losses over the epoch, as a float.

    Raises:
        ValueError: If the loader yields no batches.
    """
    if log_every is None:
        log_every = train_data_loader.batch_size

    running_loss = 0.
    total_loss = 0.
    num_batches = 0

    for num_batches, (data, target) in enumerate(train_data_loader, start=1):
        spectrogram = data.to(DEVICE)
        target = target.to(DEVICE)

        optimizer.zero_grad()

        outputs = model(spectrogram)

        # A bare squeeze() also removes the batch axis when the batch holds a
        # single sample, which leaves a 0-d prediction against a 1-d target.
        # Restore the target's shape whenever the element counts agree.
        predictions = outputs.squeeze()
        if predictions.shape != target.shape and predictions.numel() == target.numel():
            predictions = predictions.reshape(target.shape)

        # Compute the loss and its gradients
        loss = loss_fn(predictions, target)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        batch_loss = loss.item()
        running_loss += batch_loss
        total_loss += batch_loss
        if log_every and num_batches % log_every == 0:
            last_loss = running_loss / log_every  # mean loss per batch
            print('  batch {} loss: {}'.format(num_batches, last_loss))
            running_loss = 0.

    if num_batches == 0:
        raise ValueError("train_data_loader yielded no batches")

    return total_loss / num_batches

### Below, necessary classes are defined.

In [34]:
class SpectrogramDataset(Dataset):
    """
    Class for the synthetic dataset

    Every sample is the same spectrogram, paired with a target drawn
    uniformly from [0, 1). The dataset therefore only exercises the
    training pipeline; it carries no learnable signal.

    Args:
        data: 3-D tensor whose last axis is moved to the front (presumably
            (height, width, channels) -> (channels, height, width);
            confirm against the data file).
        data_size: Number of samples the dataset reports.
    """
    def __init__(self, data: torch.Tensor, data_size: int):
        sample = data.permute(2, 0, 1)
        # expand() yields a (data_size, C, H, W) view onto the single
        # spectrogram instead of materialising data_size identical copies.
        # All samples share storage, so treat `self.data` as read-only.
        self.data = sample.unsqueeze(0).expand(data_size, -1, -1, -1)
        self.target = torch.rand(data_size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.target[idx]

### Define the model.

In [35]:
class LSTM2_attention(nn.Module):
    """CNN feature extractor + bidirectional LSTM + attention pooling regressor.

    The forward pass maps a batch of 6-channel 2-D inputs to one scalar per
    sample:

    1. ``cnn_features`` produces a ``[batch, 96, 8, 25]`` feature map (the
       final adaptive pooling fixes the spatial size to 8 x 25).
    2. The 8 rows of the feature map are treated as a sequence of length 8,
       each step carrying ``96 * 25`` features.
    3. A 3-layer bidirectional LSTM (hidden size 192) yields
       ``[batch, 8, 384]``.
    4. Attention weights over the 8 steps collapse the sequence into a
       ``[batch, 384]`` context vector.
    5. ``regressor`` maps the context vector to ``[batch, 1]``.

    Args:
        dropout_rate: Dropout probability used in the CNN blocks, between
            LSTM layers and in the regression head (the last head dropout
            uses half of it).
    """
    # Loss function associated with this model (module-level `mse_loss`).
    loss_fn = mse_loss
    # Dataset class associated with this model.
    dataset = SpectrogramDataset

    def __init__(self, dropout_rate: float = 0.2):
        super().__init__()

        # CNN Feature Extractor
        self.cnn_features = nn.Sequential(
            # First Conv Block
            nn.Conv2d(in_channels=6, out_channels=48, kernel_size=5, stride=2),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(2),
            nn.Dropout2d(dropout_rate),

            # Second Conv Block: two stacked 3x3 convolutions.
            # NOTE: this is a plain nn.Sequential, so no residual (skip)
            # connection is actually applied here.
            nn.Conv2d(in_channels=48, out_channels=48, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1),
            nn.Conv2d(in_channels=48, out_channels=48, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(2),
            nn.Dropout2d(dropout_rate),

            # Third Conv Block
            nn.Conv2d(in_channels=48, out_channels=96, kernel_size=3, stride=2),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1),
            # Fixes the output to 8 x 25 regardless of the input size; the 8
            # becomes the LSTM sequence length and the 25 enters its feature size.
            nn.AdaptiveAvgPool2d((8, 25))
        )

        # Features per sequence step: 96 channels x 25 columns.
        self.lstm_input_size = 96 * 25

        # LSTM
        self.lstm = nn.LSTM(
            input_size=self.lstm_input_size,
            hidden_size=192,
            num_layers=3,
            batch_first=True,
            dropout=dropout_rate,
            bidirectional=True
        )

        # Attention mechanism: scores each sequence step with a single scalar.
        self.attention = nn.Sequential(
            nn.Linear(384, 96),  # 384 from bidirectional LSTM (192*2)
            nn.Tanh(),
            nn.Linear(96, 1)
        )

        # Regression Head
        self.regressor = nn.Sequential(
            nn.Linear(384, 128),
            nn.LeakyReLU(0.1),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 32),
            nn.LeakyReLU(0.1),
            nn.Dropout(dropout_rate/2),
            nn.Linear(32, 1)
        )

    def attention_net(self, lstm_output: torch.Tensor) -> torch.Tensor:
        """Pool the LSTM outputs over the sequence axis with learned weights.

        Args:
            lstm_output: LSTM outputs, ``[batch, seq_len, 384]``.

        Returns:
            The attention-weighted sum over ``seq_len``, ``[batch, 384]``.
        """
        # One score per step: [batch, seq_len, 1]
        attention_weights = self.attention(lstm_output)
        # Normalise the scores across the sequence axis (dim=1).
        attention_weights = torch.softmax(attention_weights, dim=1)
        # The weights broadcast over the feature axis before summing over steps.
        context = torch.sum(attention_weights * lstm_output, dim=1)
        return context

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute one prediction per sample.

        Args:
            x: Input batch with 6 channels, ``[batch, 6, H, W]``.

        Returns:
            Predictions of shape ``[batch, 1]``.
        """
        batch_size = x.size(0)

        # CNN Feature Extraction -> [batch, 96, 8, 25]
        cnn_out = self.cnn_features(x)

        # Reshape for LSTM
        cnn_out = cnn_out.permute(0, 2, 1, 3)  # [batch, height, channels, width]
        # The 8 here must match the first value of the adaptive pooling size.
        cnn_out = cnn_out.reshape(batch_size, 8, self.lstm_input_size)  # [batch, seq_len, features]

        # LSTM processing -> [batch, 8, 384]; the final hidden/cell states are unused
        lstm_out, _ = self.lstm(cnn_out)

        # Apply attention
        context = self.attention_net(lstm_out)

        # Final regression
        return self.regressor(context)

### Run the training.

In [36]:
# Path of the synthetic spectrogram tensor; the file must sit next to this notebook.
data_dir = "golf_spectrograms_tensor.pt"
data_tensor = torch.load(data_dir, map_location=DEVICE, weights_only=True)
MODEL = LSTM2_attention

print(f"Using {DEVICE} device")

# Both datasets are built from the same spectrogram; only their sizes and
# their random targets differ.
dataset_train = SpectrogramDataset(data_tensor, data_size=400)
dataset_test = SpectrogramDataset(data_tensor, data_size=100)

train_data_loader = DataLoader(dataset_train,
                               batch_size=BATCH_SIZE,
                               shuffle=True,
                               num_workers=0)
test_data_loader = DataLoader(dataset_test,
                              batch_size=3,
                              shuffle=False,
                              num_workers=0)

model = MODEL().to(DEVICE)
model.apply(weights_init_uniform_rule)

# NOTE(review): WEIGHT_DECAY is defined with the other constants but is not
# passed to the optimizer -- confirm whether that is intended.
optimizer = OPTIMIZER(model.parameters(), lr=LEARNING_RATE)
model_name = f"model_{MODEL.__name__}"

## TRAINING LOOP
epoch_number = 0
best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on
    model.train(True)

    # Do a pass over the training data and get the average training MSE loss
    avg_loss = train_one_epoch(MODEL.loss_fn, model, train_data_loader, optimizer)

    # Calculate the root mean squared error: This gives
    # us the opportunity to evaluate the loss as an error
    # in natural units of the ball velocity (m/s)
    rmse = avg_loss**(1/2)

    # Take the log as well for easier tracking of the
    # development of the loss.
    log_rmse = log10(rmse)

    # Reset test loss
    running_test_loss = 0.
    num_test_batches = 0

    # Set the model to evaluation mode
    model.eval()

    # Disable gradient computation and evaluate the test data
    with torch.no_grad():
        for num_test_batches, (data, target) in enumerate(test_data_loader, start=1):
            # Get data and targets
            spectrogram = data.to(DEVICE)
            target = target.to(DEVICE)

            # Get model outputs
            test_outputs = model(spectrogram)

            # A bare squeeze() also removes the batch axis when a batch holds
            # a single sample (the last test batch does: 100 % 3 == 1), which
            # makes the loss warn about mismatched sizes. Restore the
            # target's shape whenever the element counts agree.
            predictions = test_outputs.squeeze()
            if predictions.shape != target.shape and predictions.numel() == target.numel():
                predictions = predictions.reshape(target.shape)

            # Calculate the loss
            test_loss = MODEL.loss_fn(predictions, target)

            # Add loss to running loss as a plain float, as the training pass does
            running_test_loss += test_loss.item()

    # Calculate average test loss (mean of the per-batch losses)
    avg_test_loss = running_test_loss / num_test_batches

    # Calculate the RMSE for the test predictions
    test_rmse = avg_test_loss**(1/2)

    # Take the log as well for visualisation
    log_test_rmse = log10(test_rmse)

    print('LOSS train {} ; LOSS test {}'.format(avg_loss, avg_test_loss))

    # Keep the weights of the epoch with the lowest test loss so far
    if avg_test_loss < best_vloss:
        best_vloss = avg_test_loss
        torch.save(model.state_dict(), "model.pth")

    epoch_number += 1

Using cpu device
EPOCH 1:
  batch 10 loss: 0.2696193292737007
  batch 20 loss: 0.2471896454691887
  batch 30 loss: 0.10568164624273776
  batch 40 loss: 0.10395607501268386


  return torch_mse_loss(output, target)


LOSS train 0.18161167399957776 ; LOSS test 0.07967961579561234
EPOCH 2:
  batch 10 loss: 0.08612769506871701
  batch 20 loss: 0.08793904446065426
  batch 30 loss: 0.09118521362543106
  batch 40 loss: 0.08546863384544849
LOSS train 0.0876801467500627 ; LOSS test 0.08101930469274521
