In [1]:
import numpy as np
from sklearn.model_selection import StratifiedKFold
from data_loader import load
from torch import nn
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler

In [2]:
#X_train, y_train: use for training and validating
#stratify_criterion: used to make sure each fold in cross validation has the same annotator agreement distribution
#X_test, y_test: only use for estimating final performance, prof really emphasized this (weird kink but ok)
X_train, y_train, stratify_criterion, X_test, y_test = load()

print(X_train.shape, y_train.shape, stratify_criterion.shape, X_test.shape, y_test.shape)

(96000, 548) (96000,) (96000,) (24000, 548) (24000,)


In [6]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y.argmax(1)).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

skf = StratifiedKFold(3)
scaler = StandardScaler()
device = 'cuda' if torch.cuda.is_available() else 'cpu' 
for i, (train_index, validate_index) in enumerate(skf.split(X_train, stratify_criterion)):
    
    print(f"===============================\nFold {i+1}\n===============================")

    X_fold = X_train[train_index]
    scaler.fit(X_fold)
    X_fold = scaler.transform(X_fold)
    X_fold = torch.Tensor(X_fold).to(device)

    y_fold = y_train[train_index]
    y_fold = np.eye(7)[y_fold]
    y_fold = torch.Tensor(y_fold).to(device)

    X_validate = X_train[validate_index]
    X_validate = scaler.transform(X_validate)
    X_validate = torch.Tensor(X_validate).to(device)

    y_validate = y_train[validate_index]
    y_validate = np.eye(7)[y_validate]
    y_validate = torch.Tensor(y_validate).to(device)

    model = nn.Sequential(
        nn.Linear(548, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, 7)
    ).to(device)

    train_dataloader = DataLoader(TensorDataset(X_fold, y_fold), batch_size=64)
    test_dataloader = DataLoader(TensorDataset(X_validate, y_validate))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss().to(device)
    epochs = 15
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        test(test_dataloader, model, loss_fn)
    

#estimate actual performance on best classifier with X_test and y_test


Fold 1
Epoch 1
-------------------------------
loss: 1.904940  [   64/64000]
loss: 0.253511  [ 6464/64000]
loss: 0.444626  [12864/64000]
loss: 0.180425  [19264/64000]
loss: 0.267905  [25664/64000]
loss: 0.021022  [32064/64000]
loss: 0.972238  [38464/64000]
loss: 0.143013  [44864/64000]
loss: 0.007163  [51264/64000]
loss: 0.385664  [57664/64000]
Test Error: 
 Accuracy: 67.0%, Avg loss: 1.216759 

Epoch 2
-------------------------------
loss: 4.567226  [   64/64000]
loss: 0.130201  [ 6464/64000]
loss: 0.358334  [12864/64000]
loss: 0.127221  [19264/64000]
loss: 0.179912  [25664/64000]
loss: 0.012853  [32064/64000]
loss: 0.027096  [38464/64000]
loss: 0.181500  [44864/64000]
loss: 0.004888  [51264/64000]
loss: 0.177110  [57664/64000]
Test Error: 
 Accuracy: 71.9%, Avg loss: 0.990567 

Epoch 3
-------------------------------
loss: 4.320066  [   64/64000]
loss: 0.085550  [ 6464/64000]
loss: 0.280293  [12864/64000]
loss: 0.130426  [19264/64000]
loss: 0.107230  [25664/64000]
loss: 0.006911  [32