In [None]:
import operator
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.metrics import confusion_matrix
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Train Model

### Dataloader

In [None]:
WIN_SIZE = 500

class SmokingDataset(Dataset):
    """Map-style dataset over a directory of pre-windowed samples.

    Sample ``i`` is stored as ``<dir>/<i>.pt`` and holds an ``(X, y)`` pair
    written with ``torch.save``. ``X`` is flattened on load; slicing packs the
    selected samples into an ``X`` of shape ``(n, 3 * WIN_SIZE)`` and a ``y`` of
    shape ``(n, 1)`` (presumably 3 sensor axes x WIN_SIZE timesteps -- TODO
    confirm against the preprocessing script).

    Args:
        dir: directory holding files named ``0.pt`` ... ``{n-1}.pt``.
    """

    def __init__(self, dir):
        self.dir = dir

    def __len__(self):
        # Count only sample files so stray entries (e.g. macOS ``.DS_Store``)
        # don't inflate the length and make DataLoader ask for a missing index.
        return sum(1 for name in os.listdir(self.dir) if name.endswith('.pt'))

    def __getitem__(self, key):
        """Return ``(X, y)`` for an integer index or a stacked batch for a slice.

        Raises:
            TypeError: if ``key`` is neither a slice nor integer-like.
            IndexError: if a negative index is out of range.
        """
        if isinstance(key, slice):
            # slice.indices returns (start, stop, step) clamped to len(self).
            start, stop, step = key.indices(len(self))
            indices = range(start, stop, step)

            X = torch.zeros([len(indices), 3*WIN_SIZE])
            y = torch.zeros([len(indices), 1])
            for j, i in enumerate(indices):
                X[j], y[j] = self[i]
            return (X, y)

        # Accept any integer-like key (int, numpy/torch integer scalars), not
        # just builtin int; previously such keys silently returned None.
        try:
            index = operator.index(key)
        except TypeError:
            raise TypeError(
                f'{type(self).__name__} indices must be integers or slices, '
                f'not {type(key).__name__}'
            ) from None

        # Only list the directory for negative keys so the DataLoader hot path
        # (non-negative ints) stays a single file load per sample.
        if index < 0:
            n = len(self)
            index += n
            if index < 0:
                raise IndexError(f'index {key} out of range for dataset of length {n}')

        X, y = torch.load(os.path.join(self.dir, f'{index}.pt'))
        return (X.flatten(), y)

In [None]:
# Root of the preprocessed dataset. Named DATA_DIR rather than `dir` so the
# builtin `dir()` keeps working for the rest of the kernel session.
DATA_DIR = '../data/working-dataset'
train_dataset = SmokingDataset(os.path.join(DATA_DIR, '4_all', 'train'))
test_dataset = SmokingDataset(os.path.join(DATA_DIR, '4_all', 'test'))

train_length = len(train_dataset)
test_length = len(test_dataset)

# An empty split would otherwise only surface later as a ZeroDivisionError
# when averaging losses / computing accuracy.
assert train_length > 0, f'no training samples found in {train_dataset.dir}'
assert test_length > 0, f'no test samples found in {test_dataset.dir}'

In [None]:
# Pick the training device: CUDA first, then Apple MPS, then CPU.
# On multi-GPU hosts use GPU index CUDA_INDEX (1, as originally chosen); on
# hosts with fewer GPUs clamp to the last available one instead of failing
# with "invalid device ordinal" at the first `.to(device)`.
CUDA_INDEX = 1
if torch.cuda.is_available():
    device = f"cuda:{min(CUDA_INDEX, torch.cuda.device_count() - 1)}"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Define Model
n_hl = 10  # hidden-layer width

class MLP(nn.Module):
    """One-hidden-layer perceptron that outputs a single raw logit per sample.

    No sigmoid is applied, so train with ``nn.BCEWithLogitsLoss`` and apply
    ``torch.sigmoid`` to obtain probabilities.

    Args:
        in_features: length of each flattened input vector. Defaults to
            ``3 * WIN_SIZE``, read when the model is constructed.
        n_hidden: hidden-layer width. Defaults to the module-level ``n_hl``,
            read when the model is constructed.
    """

    def __init__(self, in_features=None, n_hidden=None):
        super().__init__()
        # Resolve defaults at construction time (not at class definition) so
        # edits to WIN_SIZE / n_hl in earlier cells are still honoured.
        if in_features is None:
            in_features = WIN_SIZE * 3
        if n_hidden is None:
            n_hidden = n_hl
        # Attribute name and layer order are unchanged so previously saved
        # state_dicts still load.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, 1)
        )

    def forward(self, x):
        """Map a ``(..., in_features)`` tensor to ``(..., 1)`` logits."""
        logits = self.linear_relu_stack(x)
        return logits

# Seed torch's global RNG before building the model so weight initialisation
# is reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = MLP().to(device)

# The model emits raw logits, so use the numerically stable fused
# sigmoid + binary-cross-entropy loss.
criterion = nn.BCEWithLogitsLoss()
LEARNING_RATE = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Define training parameters and prepare dataloaders from saved datasets

epochs = 1
batch_size = 64

losses = []
test_losses = []

# Shuffle training samples every epoch (seeded for reproducibility) so the
# optimiser doesn't see them in on-disk file order, where neighbouring
# windows may be strongly correlated. Evaluation order doesn't matter.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)
test_dataloader = DataLoader(test_dataset, batch_size=10000)    # evaluate in chunks to bound memory

n_train_batches = len(train_dataloader)
n_test_batches = len(test_dataloader)

# Timestamped output directories for this run. os.makedirs avoids shelling
# out to `mkdir -p`, which is not available on every platform.
date = datetime.now().strftime("%m.%d.%y-%H-%M-%S")
os.makedirs(f'results/{date}/training', exist_ok=True)
os.makedirs(f'results/{date}/model', exist_ok=True)

In [None]:
# Train: for each epoch, one optimisation pass over the training set, then an
# evaluation pass over the test set and a model checkpoint. Per-epoch mean
# losses are appended to `losses` / `test_losses` and saved at the end.
for epoch in range(epochs):

    print(f'Epoch {epoch} - Training')
    model.train()
    train_loss_sum = 0.0

    for X_train, y_train in tqdm(train_dataloader):

        X_train = X_train.to(device)
        y_train = y_train.to(device)

        # Forward Pass
        logits = model(X_train)
        loss = criterion(logits, y_train)

        # Backward Pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the batch *mean*; weight by batch size so a
        # smaller final batch doesn't skew the epoch average.
        train_loss_sum += loss.item() * len(y_train)

    losses.append(train_loss_sum / train_length)
    print(f'\tloss: {losses[-1]}')


    ## Test
    print(f'Epoch {epoch} - Testing')
    model.eval()

    preds = []
    n_correct = 0
    test_loss_sum = 0.0

    # Evaluation needs no gradients; skipping autograd bookkeeping keeps
    # memory bounded for the large test batches.
    with torch.no_grad():
        for X_test, y_test in tqdm(test_dataloader):
            X_test = X_test.to(device)
            y_test = y_test.to(device)

            logits = model(X_test)
            # Threshold the predicted probability at 0.5 to get 0/1 labels.
            pred = torch.round(torch.sigmoid(logits))

            # Vectorised count; builtin sum() would iterate row by row in Python.
            n_correct += (pred == y_test).sum().item()
            preds += pred.flatten().tolist()
            loss = criterion(logits, y_test)
            test_loss_sum += loss.item() * len(y_test)

    test_losses.append(test_loss_sum / test_length)
    accuracy = n_correct / test_length
    print(f'\tTest Accuracy: {100*accuracy:.4}%')
    print(f'\tTest Loss: {test_losses[-1]}')

    torch.save(model.state_dict(), f'results/{date}/model/model-epoch-{epoch}.pt')

torch.save(losses, f'results/{date}/training/train_losses.pt')
torch.save(test_losses, f'results/{date}/training/test_losses.pt')

In [None]:
# Plot train/test loss curves for up to `epoch_toplot` epochs, read back from
# the saved loss files, and save the figure under this run's figures/ dir.
figures_dir = f'results/{date}/figures'
os.makedirs(figures_dir, exist_ok=True)

epoch_toplot = 10
losses_toplot = torch.load(f'results/{date}/training/train_losses.pt')[:epoch_toplot]
test_losses_toplot = torch.load(f'results/{date}/training/test_losses.pt')[:epoch_toplot]
# Number of epochs actually plotted (can be fewer than epoch_toplot); used
# instead of the 0-based loop variable, which is off by one.
n_plotted = len(losses_toplot)

fig, ax = plt.subplots(1)
# Markers keep the curves visible even when only a single epoch was run.
ax.plot(losses_toplot, marker='o', label='Train Loss')
ax.plot(test_losses_toplot, marker='o', label='Test Loss')
ax.set_ylabel("Loss")
ax.set_xlabel('Epochs')
ax.legend()
ax.set_title(f'Train and Test Loss over {n_plotted} Epochs')
fig.set_size_inches(16, 9)
fig.savefig(f'{figures_dir}/train_test_loss-{n_plotted}-epochs.jpg', dpi=400)