In [1]:
import pandas as pd
import numpy as np
import torch
from torch import nn
import seaborn as sns
import plotly.express as px
import json
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import os
from torch.utils.data import TensorDataset,DataLoader
from tqdm import tqdm

# Train Model

In [None]:
# Pick the compute backend: prefer CUDA, then Apple MPS, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Define Model
class MLP(nn.Module):
    """Two-layer perceptron (Linear -> ReLU -> Linear) that returns raw logits.

    Args:
        in_features: Length of each input vector. Defaults to 300.
        hidden_features: Width of the single hidden layer. Defaults to 10.
        out_features: Number of logits produced per sample. Defaults to 1.

    Calling ``MLP()`` with no arguments builds the same 300 -> 10 -> 1 network
    as before. The submodule keeps the name ``linear_relu_stack`` so the keys
    of ``state_dict()`` are unchanged.
    """

    def __init__(self, in_features=300, hidden_features=10, out_features=1):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, x):
        """Run ``x`` through the stack and return the logits.

        No activation is applied to the output layer, so the result is an
        un-normalised score of shape ``(..., out_features)``.
        """
        return self.linear_relu_stack(x)

# Build the network, then move its parameters onto the selected device
model = MLP()
model.to(device)

# Binary cross-entropy computed on raw logits, optimised with Adam
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Define training parameters and prepare dataloaders from the saved datasets

# Hyperparameters
epochs = 10
batch_size = 64

# Loss histories; `losses` receives one entry per training batch
losses, test_losses = [], []

# Load the serialized datasets from disk and record how many samples each holds
# NOTE(review): recent PyTorch releases may default torch.load to weights_only=True,
# which can reject pickled dataset objects — confirm against the installed version
train_dataset = torch.load('pipeline/datasets/train_dataset.pt')
train_length = len(train_dataset)

test_dataset = torch.load('pipeline/datasets/test_dataset.pt')
test_length = len(test_dataset)

# Only the training loader gets an explicit batch size; the test loader uses DataLoader's default
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size)
test_dataloader = DataLoader(dataset=test_dataset)

In [None]:
# Train
#
# Runs `epochs` passes over `train_dataloader`. Every batch loss is appended to
# the global `losses` list (used later for plotting), and the mean batch loss
# of each epoch is printed once the epoch finishes.

for epoch in range(epochs):
    epoch_losses = []  # batch losses of the current epoch only

    for X_train_batch, y_train_batch in train_dataloader:

        X_train_batch = X_train_batch.to(device)
        y_train_batch = y_train_batch.to(device)

        # Forward Pass
        logits = model(X_train_batch)

        # NOTE(review): BCEWithLogitsLoss expects targets shaped like the logits —
        # assumes y_train_batch is a float tensor of matching shape; TODO confirm
        loss = criterion(logits, y_train_batch)

        # Backward Pass
        # zero_grad must actually be *called*; a bare attribute reference is a
        # no-op and lets gradients from earlier batches accumulate.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_losses.append(batch_loss)
        losses.append(batch_loss)

    # Average over the batches seen this epoch; guard against an empty dataloader
    mean_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('nan')
    print(f'Epoch {epoch} - loss: {mean_loss:.4f}')

In [None]:
# Plot loss curve at specified resolution

res = 100

# Average the per-batch losses over consecutive windows of `res` batches.
# If fewer than `res` losses were recorded, shrink the window so the curve is
# not silently empty.
window = max(1, min(res, len(losses)))
n_used = len(losses) - len(losses) % window  # drop the trailing partial window
smoothed = torch.tensor(losses[:n_used]).view(-1, window).mean(1)

fig, ax = plt.subplots()
ax.plot(smoothed)
ax.set(
    xlabel=f'Window index ({window} batches per window)',
    ylabel='Mean training loss',
    title='Training loss',
)
plt.show()

In [None]:
# Save model

# Create the output directory without going through a shell; exist_ok makes
# the cell safe to re-run when the directory is already there.
os.makedirs('model', exist_ok=True)

# Persist only the learned parameters (state dict), not the whole module object
torch.save(model.state_dict(), 'model/model.pt')