In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (including the project helpers imported from src/ below) are reloaded
# before each cell executes instead of requiring a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable

import math

import sys
sys.path.append('src/')
from train_utils import train
from get_model import Network
from get_data import get_data
from get_data_wrapper import TripleDataset

import matplotlib.pyplot as plt
%matplotlib inline

# torch.cuda.is_available()
# torch.backends.cudnn.benchmark = True

# Load normalized data

In [None]:
# Unpack the six arrays returned by the project helper get_data().
# NOTE(review): the (test, val, train) ordering is taken as-is from this
# call and cannot be checked from this notebook - confirm against get_data.
# X_test / Y_test are not used by any cell visible here.
X_test, Y_test, X_val, Y_val, X_train, Y_train = get_data()

# Create iterators

In [None]:
# Wrap each split in a TripleDataset: float features, integer labels, and a
# column of ones with one row per label (presumably the initial per-sample
# weights consumed by the loss - confirm against train_utils).
train_data = TripleDataset(
    torch.FloatTensor(X_train),
    torch.LongTensor(Y_train),
    torch.ones(len(Y_train), 1),
)

val_data = TripleDataset(
    torch.FloatTensor(X_val),
    torch.LongTensor(Y_val),
    torch.ones(len(Y_val), 1),
)

In [None]:
# Mini-batch size for the training DataLoader; also used further down to
# derive the number of training batches per epoch (n_batches).
batch_size = 128

In [None]:
# Training batches are reshuffled every epoch; validation batches keep a
# fixed order and a fixed batch size of 128.
train_iterator = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
val_iterator = DataLoader(dataset=val_data, batch_size=128, shuffle=False)

# Size of the training set, echoed as the cell output.
n_train_samples = len(train_data)
n_train_samples

# Model

In [None]:
class weighted_loss(nn.Module):
    """Log loss that is returned both unweighted and weighted per sample."""

    def __init__(self):
        super().__init__()

    def forward(self, logits, targets, sample_weights):
        """Compute the plain and the sample-weighted log loss for one batch.

        Args:
            logits: float tensor of shape (batch, num_classes) with
                unnormalised class scores.
            targets: integer tensor of class indices, one per sample; it is
                flattened to a (batch, 1) column before use.
            sample_weights: tensor multiplied element-wise with the
                per-sample log-probabilities; this notebook passes a
                (batch, 1) column.

        Returns:
            A tuple ``(loss, weighted_loss)``. Both are averaged over the
            batch dimension, so with (batch, 1) weights each has shape (1,).
        """
        # Normalise over the class dimension explicitly. Omitting `dim`
        # relies on deprecated implicit-dimension behaviour and emits a
        # warning on every call; for 2-D logits the result is the same.
        log_probs = F.log_softmax(logits, dim=1)

        # Keep only the log-probability of the target class: (batch, 1).
        target_log_probs = torch.gather(log_probs, 1, targets.view(-1, 1))

        # Usual log loss and weighted log loss.
        loss = -target_log_probs.mean(0)
        weighted = -(target_log_probs * sample_weights).mean(0)
        return loss, weighted

In [None]:
# Build the classifier: 54 input features, 7 output classes, two hidden
# layers of width 100, with a dropout entry of 0.1 for each hidden layer.
model = Network(
    input_dim=54,
    num_classes=7,
    architecture=[100, 100],
    dropout=[0.1, 0.1],
)
# GPU use is currently disabled; to enable it: model.cuda()

In [None]:
# Sort the model's parameters into groups; they are turned into optimizer
# parameter groups in a later cell.
# NOTE(review): a parameter that matches none of the filters below is never
# handed to the optimizer - confirm Network has no such parameters.

# Every two-dimensional parameter tensor (selected purely by rank).
weights = [
    param for _, param in model.named_parameters()
    if param.dim() == 2
]

# Only the bias of model.classifier[1] is collected here.
biases = [model.classifier[1].bias]

# Parameters whose qualified name contains 'bn.weight' / 'bn.bias'.
bn_weights = [
    param for name, param in model.named_parameters()
    if 'bn.weight' in name
]
bn_biases = [
    param for name, param in model.named_parameters()
    if 'bn.bias' in name
]

In [None]:
criterion = weighted_loss()

# Two Adam parameter groups: the 2-D weights get weight decay of 1e-4, while
# the bias and batch-norm parameters fall back to the optimizer defaults.
params = [
    dict(params=weights, weight_decay=1e-4),
    dict(params=[*biases, *bn_weights, *bn_biases]),
]
optimizer = optim.Adam(params, lr=1e-3)

# Train

In [None]:
# Training schedule constants.
# NOTE(review): validation_step and reweight_epoch are not read by any cell
# visible in this notebook - confirm whether they are still needed.
n_epochs = 50
validation_step = 100
reweight_epoch = 6

# Ceiling division: total number of batches in the train set, counting a
# trailing partial batch as one full step.
n_batches = -(-n_train_samples // batch_size)

n_batches

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Multiply the learning rate by 0.1 once the monitored quantity has stopped
# improving for longer than `patience` scheduler steps. mode='max' means
# larger values count as better, and an improvement must exceed the best
# value by an absolute 0.01. Which quantity is monitored is decided by the
# code that calls lr_scheduler.step().
lr_scheduler = ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.1,
    patience=10,
    threshold=0.01,
    threshold_mode='abs',
    verbose=True,
)

In [None]:
# Run training via the project helper and keep the history it returns.
# NOTE(review): reweight_epoch is hard-coded to 100 here, so the
# `reweight_epoch = 6` variable defined earlier is ignored; 100 also exceeds
# n_epochs (50), so reweighting presumably never starts - confirm intent.
# NOTE(review): n_validation_batches=28 is a literal; confirm it matches the
# number of batches val_iterator actually yields.
all_losses = train(
    model,
    criterion,
    optimizer,
    train_iterator,
    n_epochs,
    steps_per_epoch=n_batches,
    val_iterator=val_iterator,
    n_validation_batches=28,
    reweight_epoch=100,
    patience=10,
    threshold=0.01,
    lr_scheduler=lr_scheduler,
)

# Loss/epoch

In [None]:
# Loss curves: entries 0 and 1 of every history record, one point per record.
plt.plot([record[0] for record in all_losses], label='train')
plt.plot([record[1] for record in all_losses], label='test')
plt.legend()
plt.xlabel('epoch')
# The trailing semicolon keeps the cell from echoing the last call's repr.
plt.ylabel('loss');

In [None]:
# Accuracy curves: entries 4 and 5 of every history record, one point per record.
plt.plot([record[4] for record in all_losses], label='train')
plt.plot([record[5] for record in all_losses], label='test')
plt.legend()
plt.xlabel('epoch')
# The trailing semicolon keeps the cell from echoing the last call's repr.
plt.ylabel('accuracy');