In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import data_utils
import train_utils
import models

%reload_ext autoreload
%autoreload 2

In [6]:
def _build_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader that yields one item per batch.

    Every split uses the same settings (``batch_size=1``, ``pin_memory=True``);
    only the ``shuffle`` flag differs between them.
    """
    return torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=shuffle, pin_memory=True
    )


# Development split: shuffled.
dev = data_utils.COAD_dataset(data_utils.COAD_DEV)
dev_loader = _build_loader(dev, shuffle=True)

# Training split: shuffled.
train = data_utils.COAD_dataset(data_utils.COAD_TRAIN)
train_loader = _build_loader(train, shuffle=True)

# Validation split: kept in dataset order (no shuffling).
valid = data_utils.COAD_dataset(data_utils.COAD_VALID)
valid_loader = _build_loader(valid, shuffle=False)

In [7]:
# ---- Generator ----------------------------------------------------------
n_conv_layers = 2
kernel_size = [4, 3]
n_conv_filters = [36, 48]
hidden_size = 512  # a single int for the generator
n_rnn_layers = 2
dropout = 0.5
gen = models.Generator(
    n_conv_layers,
    kernel_size,
    n_conv_filters,
    hidden_size,
    n_rnn_layers,
    dropout=dropout,
)
gen.cuda()

# ---- Encoder ------------------------------------------------------------
# n_conv_layers (2) and dropout (0.5) keep the values assigned above.
# kernel_size / n_conv_filters are rebuilt as fresh lists so the two models
# never share a list object.
n_fc_layers = 2
kernel_size = [4, 3]
n_conv_filters = [36, 48]
hidden_size = [512, 512]  # NOTE: same name as above, but a list here
enc = models.ConvNet(
    n_conv_layers,
    n_fc_layers,
    kernel_size,
    n_conv_filters,
    hidden_size,
    dropout=dropout,
)
enc.cuda()

# ---- Loss, schedule start values and optimisation -----------------------
xent = nn.CrossEntropyLoss()
lamb1 = 0  # starting value; increased epoch by epoch in the training cell
lamb2 = 0
temp = 10  # starting temperature; decreased epoch by epoch in the training cell
learning_rate = 1e-4

# One optimizer updates the encoder and the generator jointly.
params = [*enc.parameters(), *gen.parameters()]
optimizer = torch.optim.Adam(params, lr=learning_rate)
# Lower the learning rate when the monitored metric plateaus, never below 1e-6.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=5, min_lr=1e-6
)

In [8]:
def pool_fn(x):
    """Collapse ``x`` along its first dimension by averaging.

    Args:
        x: a ``torch.Tensor``; dimension 0 is the one that gets reduced.

    Returns:
        The mean of ``x`` over dimension 0 (one dimension fewer than ``x``).

    Note:
        An element-wise maximum over the same dimension
        (``torch.max(x, 0)``) is the alternative pooling that was
        previously left here as commented-out code.
    """
    pooled = x.mean(dim=0)
    return pooled

In [None]:
# Joint training of the generator and encoder with per-epoch annealing.
# NOTE(review): the epoch counter starts at 800 and relies on lamb1 / temp
# already living in the kernel, so this presumably resumes an earlier run --
# confirm before executing on a fresh kernel.
for e in range(800,1500):
    train_utils.rationales_training_loop_GS(e, train_loader, gen, enc, pool_fn, lamb1, lamb2, xent, learning_rate, optimizer,temp)

    # Annealing: past epoch 30, raise lamb1 by 0.001 and lower temp by 0.25
    # after every epoch.
    if e > 30:
        lamb1 += 0.001
        temp -= 0.25

    # Clamp with the built-ins instead of np.max / np.min over throwaway
    # lists: temp never drops below 1 and lamb1 never exceeds 1, and both
    # remain plain Python numbers rather than NumPy scalars.
    temp = max(temp, 1.0)
    lamb1 = min(lamb1, 1.0)

    # Report the current schedule state every fifth epoch.
    if e % 5 == 0:
        print('Lambda: {0:0.5f}, LR: {1:0.7f}, Temperature: {2:0.2f}'.format(lamb1, optimizer.state_dict()['param_groups'][0]['lr'],temp))

    # The validation pass receives the LR scheduler and returns the value
    # stored in frac_tiles; training stops as soon as it falls below 0.9.
    frac_tiles = train_utils.rationales_validation_loop_GS(e, valid_loader, gen, enc, pool_fn, xent, scheduler)
    if frac_tiles < 0.9:
        break

Epoch: 800, Train Loss: 13.519963, Train Omega: 19.3704, Fraction of Tiles: 0.9923
Lambda: 0.40100, LR: 0.0000010, Temperature: 1.00
Epoch: 800, Val Loss: 8.7677, Val Acc: 1.0000, Fraction of Tiles: 0.9918
Epoch: 801, Train Loss: 13.527605, Train Omega: 19.3962, Fraction of Tiles: 0.9921
Epoch: 801, Val Loss: 8.7641, Val Acc: 1.0000, Fraction of Tiles: 0.9918
Epoch: 802, Train Loss: 13.464378, Train Omega: 19.4592, Fraction of Tiles: 0.9925
Epoch: 802, Val Loss: 8.7693, Val Acc: 1.0000, Fraction of Tiles: 0.9918
Epoch: 803, Train Loss: 12.891794, Train Omega: 19.5746, Fraction of Tiles: 0.9925
Epoch: 803, Val Loss: 8.7464, Val Acc: 1.0000, Fraction of Tiles: 0.9918
Epoch: 804, Train Loss: 13.644829, Train Omega: 19.5732, Fraction of Tiles: 0.9924
Epoch: 804, Val Loss: 8.7415, Val Acc: 1.0000, Fraction of Tiles: 0.9918
Epoch: 805, Train Loss: 13.619876, Train Omega: 19.6381, Fraction of Tiles: 0.9924
Lambda: 0.40600, LR: 0.0000010, Temperature: 1.00
Epoch: 805, Val Loss: 8.7656, Val Acc

In [88]:
# Run the validation-style evaluation pass over the *training* loader.
# `e` is whatever value the training cell's loop variable last held.
# NOTE(review): the shared `scheduler` is passed in here as well; if the
# validation loop steps it, this call feeds a training-set metric to
# ReduceLROnPlateau -- confirm that is intended.
train_utils.rationales_validation_loop_GS(
    e,
    train_loader,
    gen,
    enc,
    pool_fn,
    xent,
    scheduler,
)

Epoch: 399, Val Loss: 16.3430, Val Acc: 0.8049, Fraction of Tiles: 0.0228


In [11]:
# Persist only the learned weights (state_dict) of each network; the file
# names are relative, so they are written to the current working directory.
for _net, _filename in ((gen, 'generator.pt'), (enc, 'encoder.pt')):
    torch.save(_net.state_dict(), _filename)