In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import data_utils
import train_utils
import models

%reload_ext autoreload
%autoreload 2

In [6]:
# Build one dataset + DataLoader per COAD split. Every loader uses one item
# per batch with pinned host memory (presumably because items are variable-size
# bags of tiles — TODO confirm in data_utils). Only the validation split is
# iterated in a fixed order.
def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a batch_size=1, pin_memory DataLoader."""
    return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=shuffle, pin_memory=True)

# Construction order (dataset, then its loader, split by split) matches the
# original cell exactly.
dev = data_utils.COAD_dataset(data_utils.COAD_DEV)
dev_loader = _make_loader(dev, shuffle=True)

train = data_utils.COAD_dataset(data_utils.COAD_TRAIN)
train_loader = _make_loader(train, shuffle=True)

valid = data_utils.COAD_dataset(data_utils.COAD_VALID)
valid_loader = _make_loader(valid, shuffle=False)

In [7]:
# ---- Generator hyper-parameters ----
# Passed positionally, in the order models.Generator is called with.
n_conv_layers, n_rnn_layers = 2, 2
kernel_size, n_conv_filters = [4, 3], [36, 48]
hidden_size = 512
dropout = 0.5
gen = models.Generator(n_conv_layers, kernel_size, n_conv_filters, hidden_size, n_rnn_layers, dropout=dropout)
gen.cuda()

# ---- Encoder (ConvNet) hyper-parameters ----
# NOTE(review): these rebind the generator's names; `hidden_size` changes from
# an int to a per-FC-layer list here. Fresh list objects are created on purpose
# so nothing is shared with the generator's arguments.
n_conv_layers, n_fc_layers = 2, 2
kernel_size, n_conv_filters = [4, 3], [36, 48]
hidden_size = [512, 512]
dropout = 0.5
enc = models.ConvNet(n_conv_layers, n_fc_layers, kernel_size, n_conv_filters, hidden_size, dropout=dropout)
enc.cuda()

# ---- Loss, regularisation weights, temperature and optimisation ----
lamb1 = 0          # increased each epoch by the training-loop cell (capped at 1.0)
lamb2 = 0          # passed to the training loop unchanged
temp = 10          # temperature for the "_GS" training helpers; decreased by the loop (floor 1)
xent = nn.CrossEntropyLoss()
learning_rate = 1e-4
# Encoder parameters first, then generator parameters — one Adam optimizer for both.
params = [*enc.parameters(), *gen.parameters()]
optimizer = torch.optim.Adam(params, lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, min_lr=1e-6)

In [8]:
def pool_fn(x, dim=0):
    """Pool a stack of per-item outputs into a single result by averaging.

    Passed to the rationale training/validation loops as their aggregation
    step (presumably over the tiles of one sample — TODO confirm in train_utils).

    Args:
        x: tensor whose items are stacked along ``dim``.
        dim: axis to average over; defaults to 0 (the original behaviour).

    Returns:
        ``x`` averaged along ``dim`` (that axis is removed).
    """
    # Mean pooling; a max-pool alternative would be ``torch.max(x, dim).values``.
    return torch.mean(x, dim)

In [14]:
# Annealed rationale training with early stopping.
# NOTE(review): this cell resumes a previous run (epochs start at 800) and
# relies on `lamb1`, `temp`, the optimizer and the scheduler carrying over from
# earlier executions in the same kernel — the config cell sets lamb1 = 0, yet
# the logged lambda at epoch 800 is ~0.40. A fresh Restart & Run All will not
# reproduce these logs.
START_EPOCH, END_EPOCH = 800, 1500   # half-open epoch range for this run
WARMUP_EPOCHS = 30                   # no annealing while e <= WARMUP_EPOCHS
LAMBDA_STEP, LAMBDA_MAX = 0.001, 1.0 # lamb1 grows by LAMBDA_STEP per epoch, capped
TEMP_STEP, TEMP_MIN = 0.25, 1.0      # temp shrinks by TEMP_STEP per epoch, floored
LOG_EVERY = 5                        # print schedule state every N epochs
MIN_FRAC_TILES = 0.9                 # stop once validation fraction of tiles drops below this

for e in range(START_EPOCH, END_EPOCH):
    train_utils.rationales_training_loop_GS(e, train_loader, gen, enc, pool_fn, lamb1, lamb2, xent, learning_rate, optimizer, temp)

    if e > WARMUP_EPOCHS:
        lamb1 += LAMBDA_STEP
        temp -= TEMP_STEP
    # Scalar clamps; float() keeps both values floating-point (as np.max/np.min did).
    temp = float(max(temp, TEMP_MIN))
    lamb1 = float(min(lamb1, LAMBDA_MAX))

    if e % LOG_EVERY == 0:
        # Read the LR directly instead of packing the full optimizer state_dict.
        current_lr = optimizer.param_groups[0]['lr']
        print('Lambda: {0:0.5f}, LR: {1:0.7f}, Temperature: {2:0.2f}'.format(lamb1, current_lr, temp))

    frac_tiles = train_utils.rationales_validation_loop_GS(e, valid_loader, gen, enc, pool_fn, xent, scheduler)
    if frac_tiles < MIN_FRAC_TILES:
        break

Epoch: 800, Train Loss: 13.519963, Train Omega: 19.3704, Fraction of Tiles: 0.9923
Lambda: 0.40100, LR: 0.0000010, Temperature: 1.00
Epoch: 800, Val Loss: 8.7677, Val Acc: 1.0000, Fraction of Tiles: 0.9918
Epoch: 801, Train Loss: 13.527605, Train Omega: 19.3962, Fraction of Tiles: 0.9921
Epoch: 801, Val Loss: 8.7641, Val Acc: 1.0000, Fraction of Tiles: 0.9918
Epoch: 802, Train Loss: 13.464378, Train Omega: 19.4592, Fraction of Tiles: 0.9925
Epoch: 802, Val Loss: 8.7693, Val Acc: 1.0000, Fraction of Tiles: 0.9918
Epoch: 803, Train Loss: 12.891794, Train Omega: 19.5746, Fraction of Tiles: 0.9925
Epoch: 803, Val Loss: 8.7464, Val Acc: 1.0000, Fraction of Tiles: 0.9918
Epoch: 804, Train Loss: 13.644829, Train Omega: 19.5732, Fraction of Tiles: 0.9924
Epoch: 804, Val Loss: 8.7415, Val Acc: 1.0000, Fraction of Tiles: 0.9918
Epoch: 805, Train Loss: 13.619876, Train Omega: 19.6381, Fraction of Tiles: 0.9924
Lambda: 0.40600, LR: 0.0000010, Temperature: 1.00
Epoch: 805, Val Loss: 8.7656, Val Acc

Epoch: 849, Val Loss: 8.7244, Val Acc: 0.9796, Fraction of Tiles: 0.9920
Epoch: 850, Train Loss: 12.970561, Train Omega: 21.7599, Fraction of Tiles: 0.9920
Lambda: 0.45100, LR: 0.0000010, Temperature: 1.00
Epoch: 850, Val Loss: 8.7033, Val Acc: 1.0000, Fraction of Tiles: 0.9919
Epoch: 851, Train Loss: 12.921928, Train Omega: 21.8763, Fraction of Tiles: 0.9921
Epoch: 851, Val Loss: 8.7109, Val Acc: 0.9796, Fraction of Tiles: 0.9919
Epoch: 852, Train Loss: 13.586502, Train Omega: 21.8891, Fraction of Tiles: 0.9922
Epoch: 852, Val Loss: 8.6923, Val Acc: 0.9796, Fraction of Tiles: 0.9919
Epoch: 853, Train Loss: 13.723776, Train Omega: 21.9342, Fraction of Tiles: 0.9922
Epoch: 853, Val Loss: 8.7261, Val Acc: 0.9796, Fraction of Tiles: 0.9919
Epoch: 854, Train Loss: 13.730603, Train Omega: 22.0097, Fraction of Tiles: 0.9925
Epoch: 854, Val Loss: 8.7227, Val Acc: 0.9796, Fraction of Tiles: 0.9920
Epoch: 855, Train Loss: 13.187430, Train Omega: 22.1124, Fraction of Tiles: 0.9924
Lambda: 0.4560

Epoch: 899, Train Loss: 13.539145, Train Omega: 24.2120, Fraction of Tiles: 0.9921
Epoch: 899, Val Loss: 8.7218, Val Acc: 0.9592, Fraction of Tiles: 0.9916
Epoch: 900, Train Loss: 13.077727, Train Omega: 24.2614, Fraction of Tiles: 0.9925
Lambda: 0.50100, LR: 0.0000010, Temperature: 1.00
Epoch: 900, Val Loss: 8.6974, Val Acc: 0.9796, Fraction of Tiles: 0.9916
Epoch: 901, Train Loss: 13.262290, Train Omega: 24.2364, Fraction of Tiles: 0.9922
Epoch: 901, Val Loss: 8.6748, Val Acc: 1.0000, Fraction of Tiles: 0.9915
Epoch: 902, Train Loss: 12.500979, Train Omega: 24.3815, Fraction of Tiles: 0.9923
Epoch: 902, Val Loss: 8.6918, Val Acc: 0.9796, Fraction of Tiles: 0.9914
Epoch: 903, Train Loss: 13.189882, Train Omega: 24.3101, Fraction of Tiles: 0.9923
Epoch: 903, Val Loss: 8.6327, Val Acc: 1.0000, Fraction of Tiles: 0.9914
Epoch: 904, Train Loss: 12.862827, Train Omega: 24.4640, Fraction of Tiles: 0.9925
Epoch: 904, Val Loss: 8.6456, Val Acc: 1.0000, Fraction of Tiles: 0.9914
Epoch: 905, Tr

Epoch: 948, Val Loss: 8.6345, Val Acc: 0.9796, Fraction of Tiles: 0.9913
Epoch: 949, Train Loss: 12.790853, Train Omega: 26.5210, Fraction of Tiles: 0.9920
Epoch: 949, Val Loss: 8.6543, Val Acc: 0.9796, Fraction of Tiles: 0.9910
Epoch: 950, Train Loss: 13.464854, Train Omega: 26.5755, Fraction of Tiles: 0.9919
Lambda: 0.55100, LR: 0.0000010, Temperature: 1.00
Epoch: 950, Val Loss: 8.6555, Val Acc: 0.9796, Fraction of Tiles: 0.9910
Epoch: 951, Train Loss: 13.027010, Train Omega: 26.6988, Fraction of Tiles: 0.9919
Epoch: 951, Val Loss: 8.6135, Val Acc: 1.0000, Fraction of Tiles: 0.9908
Epoch: 952, Train Loss: 12.894250, Train Omega: 26.6979, Fraction of Tiles: 0.9923
Epoch: 952, Val Loss: 8.5935, Val Acc: 1.0000, Fraction of Tiles: 0.9908
Epoch: 953, Train Loss: 13.793112, Train Omega: 26.7980, Fraction of Tiles: 0.9925
Epoch: 953, Val Loss: 8.6271, Val Acc: 1.0000, Fraction of Tiles: 0.9908
Epoch: 954, Train Loss: 12.906662, Train Omega: 26.7895, Fraction of Tiles: 0.9920
Epoch: 954, Va

Epoch: 998, Train Loss: 13.447588, Train Omega: 28.9318, Fraction of Tiles: 0.9920
Epoch: 998, Val Loss: 8.6650, Val Acc: 0.9796, Fraction of Tiles: 0.9907
Epoch: 999, Train Loss: 13.625788, Train Omega: 28.9619, Fraction of Tiles: 0.9918
Epoch: 999, Val Loss: 8.6879, Val Acc: 0.9796, Fraction of Tiles: 0.9907
Epoch: 1000, Train Loss: 12.701841, Train Omega: 28.9781, Fraction of Tiles: 0.9921
Lambda: 0.60100, LR: 0.0000010, Temperature: 1.00
Epoch: 1000, Val Loss: 8.6605, Val Acc: 0.9796, Fraction of Tiles: 0.9907
Epoch: 1001, Train Loss: 13.246277, Train Omega: 28.9992, Fraction of Tiles: 0.9917
Epoch: 1001, Val Loss: 8.6682, Val Acc: 0.9796, Fraction of Tiles: 0.9907
Epoch: 1002, Train Loss: 13.271617, Train Omega: 29.0649, Fraction of Tiles: 0.9917
Epoch: 1002, Val Loss: 8.6875, Val Acc: 0.9796, Fraction of Tiles: 0.9907
Epoch: 1003, Train Loss: 13.149003, Train Omega: 29.1830, Fraction of Tiles: 0.9922
Epoch: 1003, Val Loss: 8.6446, Val Acc: 0.9796, Fraction of Tiles: 0.9907
Epoch:

Epoch: 1047, Train Loss: 12.859121, Train Omega: 31.3052, Fraction of Tiles: 0.9916
Epoch: 1047, Val Loss: 8.6358, Val Acc: 0.9592, Fraction of Tiles: 0.9904
Epoch: 1048, Train Loss: 12.942370, Train Omega: 31.2724, Fraction of Tiles: 0.9915
Epoch: 1048, Val Loss: 8.6500, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1049, Train Loss: 13.615425, Train Omega: 31.3648, Fraction of Tiles: 0.9917
Epoch: 1049, Val Loss: 8.6469, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1050, Train Loss: 13.465986, Train Omega: 31.3544, Fraction of Tiles: 0.9914
Lambda: 0.65100, LR: 0.0000010, Temperature: 1.00
Epoch: 1050, Val Loss: 8.6215, Val Acc: 0.9796, Fraction of Tiles: 0.9903
Epoch: 1051, Train Loss: 13.020616, Train Omega: 31.4817, Fraction of Tiles: 0.9918
Epoch: 1051, Val Loss: 8.6017, Val Acc: 0.9796, Fraction of Tiles: 0.9903
Epoch: 1052, Train Loss: 13.380962, Train Omega: 31.5709, Fraction of Tiles: 0.9917
Epoch: 1052, Val Loss: 8.6100, Val Acc: 0.9796, Fraction of Tiles: 0.9903
Ep

Epoch: 1096, Train Loss: 12.943488, Train Omega: 33.5467, Fraction of Tiles: 0.9910
Epoch: 1096, Val Loss: 8.6468, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1097, Train Loss: 13.495075, Train Omega: 33.5629, Fraction of Tiles: 0.9913
Epoch: 1097, Val Loss: 8.6653, Val Acc: 0.9796, Fraction of Tiles: 0.9903
Epoch: 1098, Train Loss: 12.859590, Train Omega: 33.6897, Fraction of Tiles: 0.9917
Epoch: 1098, Val Loss: 8.6997, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1099, Train Loss: 12.958753, Train Omega: 33.7902, Fraction of Tiles: 0.9917
Epoch: 1099, Val Loss: 8.6547, Val Acc: 0.9796, Fraction of Tiles: 0.9903
Epoch: 1100, Train Loss: 13.185972, Train Omega: 33.7207, Fraction of Tiles: 0.9913
Lambda: 0.70100, LR: 0.0000010, Temperature: 1.00
Epoch: 1100, Val Loss: 8.6534, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1101, Train Loss: 12.994440, Train Omega: 33.8365, Fraction of Tiles: 0.9914
Epoch: 1101, Val Loss: 8.6774, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Ep

Epoch: 1145, Val Loss: 8.7029, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1146, Train Loss: 12.804286, Train Omega: 35.9561, Fraction of Tiles: 0.9910
Epoch: 1146, Val Loss: 8.6263, Val Acc: 0.9796, Fraction of Tiles: 0.9903
Epoch: 1147, Train Loss: 12.515285, Train Omega: 36.0756, Fraction of Tiles: 0.9913
Epoch: 1147, Val Loss: 8.6631, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1148, Train Loss: 12.242423, Train Omega: 36.1092, Fraction of Tiles: 0.9915
Epoch: 1148, Val Loss: 8.6451, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1149, Train Loss: 12.818867, Train Omega: 36.1397, Fraction of Tiles: 0.9911
Epoch: 1149, Val Loss: 8.6066, Val Acc: 0.9796, Fraction of Tiles: 0.9903
Epoch: 1150, Train Loss: 13.033856, Train Omega: 36.1411, Fraction of Tiles: 0.9912
Lambda: 0.75100, LR: 0.0000010, Temperature: 1.00
Epoch: 1150, Val Loss: 8.6028, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1151, Train Loss: 13.346181, Train Omega: 36.1776, Fraction of Tiles: 0.9908
Ep

Epoch: 1195, Train Loss: 12.553575, Train Omega: 38.3312, Fraction of Tiles: 0.9909
Lambda: 0.79600, LR: 0.0000010, Temperature: 1.00
Epoch: 1195, Val Loss: 8.6591, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1196, Train Loss: 12.931449, Train Omega: 38.4479, Fraction of Tiles: 0.9910
Epoch: 1196, Val Loss: 8.6971, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1197, Train Loss: 13.026729, Train Omega: 38.3606, Fraction of Tiles: 0.9908
Epoch: 1197, Val Loss: 8.6678, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1198, Train Loss: 12.283889, Train Omega: 38.3150, Fraction of Tiles: 0.9905
Epoch: 1198, Val Loss: 8.6909, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1199, Train Loss: 12.604963, Train Omega: 38.4711, Fraction of Tiles: 0.9906
Epoch: 1199, Val Loss: 8.6651, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1200, Train Loss: 13.179156, Train Omega: 38.5144, Fraction of Tiles: 0.9909
Lambda: 0.80100, LR: 0.0000010, Temperature: 1.00
Epoch: 1200, Val Loss: 8.7

Epoch: 1244, Train Loss: 12.931780, Train Omega: 40.4986, Fraction of Tiles: 0.9902
Epoch: 1244, Val Loss: 8.6957, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1245, Train Loss: 12.567958, Train Omega: 40.3809, Fraction of Tiles: 0.9897
Lambda: 0.84600, LR: 0.0000010, Temperature: 1.00
Epoch: 1245, Val Loss: 8.7072, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1246, Train Loss: 13.103680, Train Omega: 40.6134, Fraction of Tiles: 0.9901
Epoch: 1246, Val Loss: 8.7481, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1247, Train Loss: 13.045317, Train Omega: 40.5680, Fraction of Tiles: 0.9900
Epoch: 1247, Val Loss: 8.7528, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1248, Train Loss: 12.713680, Train Omega: 40.6120, Fraction of Tiles: 0.9897
Epoch: 1248, Val Loss: 8.7433, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Epoch: 1249, Train Loss: 12.835609, Train Omega: 40.7443, Fraction of Tiles: 0.9893
Epoch: 1249, Val Loss: 8.7731, Val Acc: 0.9592, Fraction of Tiles: 0.9903
Ep

Epoch: 1293, Train Loss: 12.822269, Train Omega: 42.3637, Fraction of Tiles: 0.9877
Epoch: 1293, Val Loss: 8.9651, Val Acc: 0.9592, Fraction of Tiles: 0.9886
Epoch: 1294, Train Loss: 12.613388, Train Omega: 42.4381, Fraction of Tiles: 0.9878
Epoch: 1294, Val Loss: 8.9412, Val Acc: 0.9592, Fraction of Tiles: 0.9883
Epoch: 1295, Train Loss: 12.292401, Train Omega: 42.4217, Fraction of Tiles: 0.9878
Lambda: 0.89600, LR: 0.0000010, Temperature: 1.00
Epoch: 1295, Val Loss: 8.8821, Val Acc: 0.9796, Fraction of Tiles: 0.9881
Epoch: 1296, Train Loss: 12.814820, Train Omega: 42.5498, Fraction of Tiles: 0.9877
Epoch: 1296, Val Loss: 8.8962, Val Acc: 0.9796, Fraction of Tiles: 0.9881
Epoch: 1297, Train Loss: 12.198462, Train Omega: 42.5353, Fraction of Tiles: 0.9877
Epoch: 1297, Val Loss: 8.8988, Val Acc: 0.9796, Fraction of Tiles: 0.9880
Epoch: 1298, Train Loss: 12.226636, Train Omega: 42.6511, Fraction of Tiles: 0.9877
Epoch: 1298, Val Loss: 8.9247, Val Acc: 0.9796, Fraction of Tiles: 0.9880
Ep

Epoch: 1342, Train Loss: 12.839834, Train Omega: 44.3855, Fraction of Tiles: 0.9868
Epoch: 1342, Val Loss: 9.2085, Val Acc: 0.9796, Fraction of Tiles: 0.9867
Epoch: 1343, Train Loss: 12.601609, Train Omega: 44.4657, Fraction of Tiles: 0.9863
Epoch: 1343, Val Loss: 9.1573, Val Acc: 0.9796, Fraction of Tiles: 0.9867
Epoch: 1344, Train Loss: 12.607564, Train Omega: 44.4492, Fraction of Tiles: 0.9860
Epoch: 1344, Val Loss: 9.2399, Val Acc: 0.9796, Fraction of Tiles: 0.9867
Epoch: 1345, Train Loss: 12.244568, Train Omega: 44.5469, Fraction of Tiles: 0.9867
Lambda: 0.94600, LR: 0.0000010, Temperature: 1.00
Epoch: 1345, Val Loss: 9.1852, Val Acc: 0.9796, Fraction of Tiles: 0.9867
Epoch: 1346, Train Loss: 12.277535, Train Omega: 44.4334, Fraction of Tiles: 0.9861
Epoch: 1346, Val Loss: 9.1798, Val Acc: 0.9796, Fraction of Tiles: 0.9867
Epoch: 1347, Train Loss: 12.510987, Train Omega: 44.6471, Fraction of Tiles: 0.9867
Epoch: 1347, Val Loss: 9.2101, Val Acc: 0.9796, Fraction of Tiles: 0.9867
Ep

Epoch: 1391, Train Loss: 12.548576, Train Omega: 46.4852, Fraction of Tiles: 0.9848
Epoch: 1391, Val Loss: 9.5325, Val Acc: 0.9592, Fraction of Tiles: 0.9853
Epoch: 1392, Train Loss: 12.574088, Train Omega: 46.1790, Fraction of Tiles: 0.9839
Epoch: 1392, Val Loss: 9.5285, Val Acc: 0.9592, Fraction of Tiles: 0.9850
Epoch: 1393, Train Loss: 12.255608, Train Omega: 46.2087, Fraction of Tiles: 0.9840
Epoch: 1393, Val Loss: 9.5920, Val Acc: 0.9592, Fraction of Tiles: 0.9848
Epoch: 1394, Train Loss: 12.420140, Train Omega: 46.2224, Fraction of Tiles: 0.9844
Epoch: 1394, Val Loss: 9.5909, Val Acc: 0.9592, Fraction of Tiles: 0.9846
Epoch: 1395, Train Loss: 13.034850, Train Omega: 46.3795, Fraction of Tiles: 0.9840
Lambda: 0.99600, LR: 0.0000010, Temperature: 1.00
Epoch: 1395, Val Loss: 9.5419, Val Acc: 0.9592, Fraction of Tiles: 0.9845
Epoch: 1396, Train Loss: 12.824154, Train Omega: 46.4426, Fraction of Tiles: 0.9840
Epoch: 1396, Val Loss: 9.6415, Val Acc: 0.9592, Fraction of Tiles: 0.9846
Ep

Epoch: 1440, Train Loss: 13.140446, Train Omega: 46.1155, Fraction of Tiles: 0.9821
Lambda: 1.00000, LR: 0.0000010, Temperature: 1.00
Epoch: 1440, Val Loss: 10.1101, Val Acc: 0.9592, Fraction of Tiles: 0.9828
Epoch: 1441, Train Loss: 12.801949, Train Omega: 45.8217, Fraction of Tiles: 0.9812
Epoch: 1441, Val Loss: 10.1206, Val Acc: 0.9388, Fraction of Tiles: 0.9826
Epoch: 1442, Train Loss: 12.915710, Train Omega: 45.8653, Fraction of Tiles: 0.9817
Epoch: 1442, Val Loss: 10.1891, Val Acc: 0.9388, Fraction of Tiles: 0.9826
Epoch: 1443, Train Loss: 12.845046, Train Omega: 45.7421, Fraction of Tiles: 0.9809
Epoch: 1443, Val Loss: 10.1996, Val Acc: 0.9388, Fraction of Tiles: 0.9825
Epoch: 1444, Train Loss: 13.044149, Train Omega: 45.8608, Fraction of Tiles: 0.9815
Epoch: 1444, Val Loss: 10.2575, Val Acc: 0.9388, Fraction of Tiles: 0.9826
Epoch: 1445, Train Loss: 12.869024, Train Omega: 45.7423, Fraction of Tiles: 0.9817
Lambda: 1.00000, LR: 0.0000010, Temperature: 1.00
Epoch: 1445, Val Loss

Epoch: 1488, Val Loss: 11.2200, Val Acc: 0.9184, Fraction of Tiles: 0.9739
Epoch: 1489, Train Loss: 13.597029, Train Omega: 43.2579, Fraction of Tiles: 0.9720
Epoch: 1489, Val Loss: 11.3494, Val Acc: 0.8980, Fraction of Tiles: 0.9739
Epoch: 1490, Train Loss: 13.178107, Train Omega: 43.2588, Fraction of Tiles: 0.9716
Lambda: 1.00000, LR: 0.0000010, Temperature: 1.00
Epoch: 1490, Val Loss: 11.3596, Val Acc: 0.8980, Fraction of Tiles: 0.9734
Epoch: 1491, Train Loss: 13.291368, Train Omega: 43.0748, Fraction of Tiles: 0.9713
Epoch: 1491, Val Loss: 11.3858, Val Acc: 0.8980, Fraction of Tiles: 0.9727
Epoch: 1492, Train Loss: 13.467972, Train Omega: 43.0901, Fraction of Tiles: 0.9712
Epoch: 1492, Val Loss: 11.4804, Val Acc: 0.8980, Fraction of Tiles: 0.9729
Epoch: 1493, Train Loss: 13.921758, Train Omega: 43.0752, Fraction of Tiles: 0.9710
Epoch: 1493, Val Loss: 11.4774, Val Acc: 0.8980, Fraction of Tiles: 0.9733
Epoch: 1494, Train Loss: 13.273874, Train Omega: 42.8866, Fraction of Tiles: 0.9

In [88]:
# Evaluate the current models on the *training* split (e.g. to compare with the
# validation fraction of tiles). The helper labels its printout "Val ..."
# regardless of which loader it is given.
# The helper also receives the LR scheduler (presumably it calls
# scheduler.step() on the loss — TODO confirm in train_utils). Feeding it the
# training loss would disturb ReduceLROnPlateau's plateau tracking and could
# lower the optimizer's LR, so snapshot both and restore them afterwards.
import copy

_scheduler_snapshot = copy.deepcopy(scheduler.state_dict())
_lr_snapshot = [group['lr'] for group in optimizer.param_groups]
try:
    train_frac_tiles = train_utils.rationales_validation_loop_GS(e, train_loader, gen, enc, pool_fn, xent, scheduler)
finally:
    scheduler.load_state_dict(_scheduler_snapshot)
    for group, saved_lr in zip(optimizer.param_groups, _lr_snapshot):
        group['lr'] = saved_lr
train_frac_tiles

Epoch: 399, Val Loss: 16.3430, Val Acc: 0.8049, Fraction of Tiles: 0.0228


In [15]:
# Persist only the model weights (generator first, then encoder).
# NOTE(review): optimizer/scheduler state, lamb1, temp and the last epoch are
# not saved, so training cannot be resumed exactly from these files alone.
for module, path in ((gen, 'generator.pt'), (enc, 'encoder.pt')):
    torch.save(module.state_dict(), path)