In [None]:
import pandas as pd
import numpy as np
import torch
from torch import nn
import seaborn as sns
import plotly.express as px
import json
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import os
from torch.utils.data import TensorDataset,DataLoader
from tqdm import tqdm
from sklearn.metrics import confusion_matrix

# Train Model

In [None]:
# Get cpu or gpu device for training
# Preference order: CUDA first, then the MPS backend, then plain CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Define Model
n_hl = 10  # default width of the hidden layer


class MLP(nn.Module):
    """Single-hidden-layer perceptron that returns raw logits.

    The network is ``Linear -> ReLU -> Linear``. No sigmoid is applied here;
    the output is meant to be consumed as logits.

    Args:
        in_features: Size of each input vector. Defaults to 300.
        hidden_size: Width of the hidden layer. ``None`` (the default) falls
            back to the module-level ``n_hl`` at construction time.
        out_features: Number of logits produced per input row. Defaults to 1.
    """

    def __init__(self, in_features=300, hidden_size=None, out_features=1):
        super().__init__()
        # Resolve the default lazily so a later change to ``n_hl`` is still
        # picked up by ``MLP()``, exactly as before.
        if hidden_size is None:
            hidden_size = n_hl
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, out_features),
        )

    def forward(self, x):
        """Return the raw logits for the input batch ``x``."""
        return self.linear_relu_stack(x)

# Build the network and place its parameters on the selected device.
model = MLP()
model = model.to(device)

# Binary cross-entropy computed directly on logits, optimised with Adam.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# Define training parameters and prepare dataloaders from saved datasets

epochs = 25
batch_size = 64

# Per-batch training losses are appended here by the training cell.
losses = []
test_losses = []

# NOTE(review): these files appear to hold whole pickled dataset objects
# (they are passed straight to DataLoader), not plain tensors/state dicts.
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses to
# unpickle such objects, so the flag is set explicitly.
# SECURITY: weights_only=False runs the full unpickler and can execute
# arbitrary code -- only load files produced by your own pipeline.
train_dataset = torch.load('pipeline/datasets/train_dataset.pt', weights_only=False)
test_dataset = torch.load('pipeline/datasets/test_dataset.pt', weights_only=False)

train_length = len(train_dataset)
test_length = len(test_dataset)

# Reshuffle the training samples every epoch so batches are not identical
# (and identically ordered) across epochs. The test loader stays unshuffled
# so predictions line up with the dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=10000)    # batches for memory

n_train_batches = len(train_dataloader)
n_test_batches = len(test_dataloader)

In [None]:
# Train
# Put the model in training mode once, then sweep the training loader
# `epochs` times, recording every batch loss in `losses`.
model.train()
for epoch in range(epochs):
    for X_train, y_train in train_dataloader:
        # Move the current mini-batch onto the compute device.
        X_train, y_train = X_train.to(device), y_train.to(device)

        # Forward pass: logits, then the loss against the targets.
        logits = model(X_train)
        loss = criterion(logits, y_train)

        # Backward pass: clear stale gradients, backpropagate, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    # Report the mean of the batch losses recorded during this epoch.
    print(f'Epoch {epoch} - loss: {sum(losses[-n_train_batches:])/n_train_batches}')

In [None]:
# Plot loss curve at specified resolution
res = 250

# Average the per-batch losses over consecutive windows of `res` batches.
# A trailing partial window is dropped.
n_full = len(losses) - len(losses) % res
if n_full:
    curve = torch.tensor(losses[:n_full]).view(-1, res).mean(1)
    x_label = f'Window of {res} batches'
else:
    # Fewer than `res` losses were recorded: the slice would be empty and
    # view(-1, res) on an empty tensor raises, so plot the raw batch losses.
    curve = torch.tensor(losses)
    x_label = 'Batch'

plt.plot(curve)
plt.xlabel(x_label)
plt.ylabel('Training loss')
plt.savefig('loss.jpg')

In [None]:
# Test on test dataset
model.eval()

preds = []      # flat list of 0/1 predictions, in test-loader order
n_correct = 0   # number of predictions equal to their target
loss = 0.0      # sum of per-sample losses, averaged after the loop

# Gradients are not needed for evaluation; disabling autograd avoids
# building a graph for the large test batches.
with torch.no_grad():
    for X_test, y_test in test_dataloader:
        X_test = X_test.to(device)
        y_test = y_test.to(device)

        logits = model(X_test)
        # Threshold the sigmoid probabilities to hard 0/1 predictions.
        pred = torch.round(torch.sigmoid(logits))

        n_correct += (pred == y_test).sum().item()
        preds += pred.flatten().tolist()
        # criterion returns the batch *mean*; weight it by the batch size so
        # that dividing by test_length below yields a true per-sample average
        # (the last batch may be smaller than the others).
        loss += criterion(logits, y_test).item() * len(y_test)

y_pred = np.array(preds).reshape(-1,1)
accuracy = n_correct / test_length
loss = loss / test_length
print(f'Accuracy: {100*accuracy:.4}%')
print(f'Loss: {loss:.5}')

In [None]:
# Heatmap

# Targets for the whole test set, in the same (unshuffled) order that the
# predictions in y_pred were collected.
y_true = test_dataset[:][1]

# Three views of the same confusion matrix (rows = true class,
# columns = predicted class): row-normalised, column-normalised, raw counts.
fig,axes = plt.subplots(1,3,sharey=True,figsize=(10,5))
sns.heatmap(confusion_matrix(y_true=y_true,y_pred=y_pred,normalize='true'),annot=True,ax=axes[0],cbar=False,fmt='.2f')
sns.heatmap(confusion_matrix(y_true=y_true,y_pred=y_pred,normalize='pred'),annot=True,ax=axes[1],cbar=False,fmt='.2f')
# Raw counts are integers, so annotate them as such rather than as "123.00".
sns.heatmap(confusion_matrix(y_true=y_true,y_pred=y_pred),annot=True,ax=axes[2],cbar=False,fmt='d')
axes[0].set_title('Recall')
axes[1].set_title('Precision')
axes[2].set_title('Count')
for ax in axes:
    ax.set_xlabel('Predicted label')
axes[0].set_ylabel('True label')  # the y-axis is shared across panels
plt.savefig('cm.jpg',dpi=200,bbox_inches='tight')

In [None]:
# Save model

# Create the output directory without shelling out; exist_ok makes the cell
# safe to re-run when the directory is already there.
os.makedirs('model', exist_ok=True)
# Only the parameters (state_dict) are saved, not the pickled module object.
torch.save(model.state_dict(), 'model/model.pt')