In [1]:
from dataloader import create_split_loaders
from model import GenericRNN
import torch

In [2]:
import torch.nn as nn
import copy
# Settings read by train(); every key below is looked up by name there.
config = {
    # the size of every chunk
    'chunk_size': 100,
    # number of types (a constant of the data set)
    'type_number': 93,
    # number of features in the hidden layer
    'hidden': 100,
    # learning rate handed to the optimizer
    'learning_rate': 0.001,
    # whether early stopping is used
    'early_stop': False,
    # how many consecutive validation-loss increases stop the training
    'increase_limit': 3,
    # number of epochs
    'epoch_num': 1,
    # record the training loss every N batches
    'N': 50,
    # examine the validation loss every M batches
    'M': 100,
}

def _run_chunk(net, batch, chunk_size, type_number, teacher_forcing):
    """Feed one chunk through ``net`` step by step and collect predictions/targets.

    Args:
        net: object exposing ``predict(step_input)`` and
            ``predict(step_input, teacher)``.
        batch: indexable pair; ``batch[0][ii]`` is the input and ``batch[1][ii]``
            the target of step ``ii``.
        chunk_size: number of steps read from ``batch``.
        type_number: width of every input/target/prediction row.
        teacher_forcing: when True, every step after the first also passes the
            previous step's target to ``net.predict``.

    Returns:
        ``(predict_all, target_all)``, two ``(chunk_size, type_number)`` tensors.
    """
    predict_all = torch.zeros(chunk_size, type_number)
    target_all = torch.zeros(chunk_size, type_number)
    for ii in range(chunk_size):
        step_input = torch.zeros(1, 1, type_number)
        step_input[0] = batch[0][ii]
        if teacher_forcing and ii > 0:
            # teacher forcing: hand the previous ground-truth step to the network
            teacher = torch.ones(1, 1, type_number)
            teacher[0] = batch[1][ii - 1]
            predict = net.predict(step_input, teacher)
        else:
            predict = net.predict(step_input)
        predict_all[ii] = predict
        target_all[ii] = batch[1][ii]
    return predict_all, target_all


def _validation_loss(net, loader, criterion, chunk_size, type_number):
    """Return the mean per-batch loss of ``net`` over ``loader`` as a float.

    Runs under ``torch.no_grad()`` so no autograd graph is built. Returns NaN
    when ``loader`` yields no batches (instead of dividing by zero).
    """
    total = 0.0
    batches = 0
    with torch.no_grad():
        for batch in loader:
            predicted, target = _run_chunk(
                net, batch, chunk_size, type_number, teacher_forcing=False)
            total += criterion(predicted, target).item()
            batches += 1
    return total / batches if batches else float('nan')


def train(config, history=None):
    """Train a ``GenericRNN`` on the loaders from ``create_split_loaders``.

    Args:
        config: dict with the keys ``chunk_size``, ``type_number``, ``hidden``,
            ``learning_rate``, ``early_stop``, ``increase_limit``, ``epoch_num``,
            ``N`` (record the training loss every N batches) and ``M`` (run a
            validation pass every M batches when early stopping is enabled).
        history: optional dict. When given, ``history['training']`` and
            ``history['validation']`` are bound to the loss records before the
            first batch, so they stay readable even if training is interrupted.

    Returns:
        The trained network. With early stopping enabled this is the copy with
        the lowest validation loss seen; if no validation pass ever ran, the
        network in its final state is returned.
    """
    # the size of every chunk
    chunk_size = config['chunk_size']
    # number of types, it is a constant
    type_number = config['type_number']
    # number of features in hidden layer
    hidden = config['hidden']
    # learning rate
    learning_rate = config['learning_rate']
    # whether we use early stop
    early_stop = config['early_stop']
    # after validation loss increase how many times do we stop training
    increase_limit = config['increase_limit']
    # number of epoch
    epoch_num = config['epoch_num']
    # after how many batches do we record the training loss
    N = config['N']
    # after how many batches do we examine whether validation loss increase
    M = config['M']

    # train / validation loaders; the remaining three return values are unused
    # here. Named *_loader so the local no longer shadows this function's name.
    train_loader, valid_loader, _test_loader, _c_to, _one_to = create_split_loaders(chunk_size)
    # construct network -- presumably (input size, hidden size, output size);
    # TODO confirm against model.GenericRNN
    net = GenericRNN(type_number, hidden, type_number)
    # binary cross entropy: expects predictions and targets in [0, 1]
    criterion = nn.BCELoss()
    # Using Adam
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    # keep track of the training / validation loss
    training_record = []
    validation_record = []
    if history is not None:
        history['training'] = training_record
        history['validation'] = validation_record

    last_valid = float('inf')
    # store best loss and the network that achieved it
    best_loss = float('inf')
    best_net = None
    # number of consecutive validation passes whose loss went up
    increasement = 0
    stop_training = False

    for _ in range(epoch_num):
        for count, minibatch in enumerate(train_loader, start=1):
            optimizer.zero_grad()
            predict_all, target_all = _run_chunk(
                net, minibatch, chunk_size, type_number, teacher_forcing=True)
            # calculate loss
            loss = criterion(predict_all, target_all)
            loss.backward()
            optimizer.step()
            if count % N == 0:
                training_record.append(loss.item())

            # validation and early stop (uses its own counter, so the training
            # batch count that drives the N / M schedules is never disturbed)
            if early_stop and count % M == 0:
                loss_val = _validation_loss(
                    net, valid_loader, criterion, chunk_size, type_number)
                validation_record.append(loss_val)
                if loss_val > last_valid:
                    increasement += 1
                else:
                    increasement = 0
                if loss_val < best_loss:
                    best_loss = loss_val
                    best_net = copy.deepcopy(net)
                last_valid = loss_val
                if increasement >= increase_limit:
                    stop_training = True
                    break
        if stop_training:
            break

    # fall back to the live network when no validation pass produced a best copy
    if early_stop and best_net is not None:
        net = best_net
    return net

In [3]:
# Keep a handle on the returned network so later cells do not depend on `_` / Out[n].
trained_net = train(config)
# Last expression of the cell: still displays the network, as the bare call did.
trained_net

0.059479132294654846
0.05934030935168266
0.059313517063856125
0.05922171473503113
0.05927309766411781
0.0592011958360672
0.05923322215676308
0.059060920029878616
0.059357188642024994
0.05927383154630661
0.05931169539690018
0.05945882946252823
0.059438858181238174
0.059393420815467834
0.059231050312519073
0.05929296836256981
0.05922679603099823
0.059009380638599396
0.05899554491043091
0.0589991994202137
0.059305012226104736
0.059058114886283875
0.05919048190116882
0.05907919630408287
0.058892153203487396
0.059055425226688385
0.05888177827000618
0.05883494392037392
0.058776021003723145
0.058827437460422516
0.058665599673986435
0.05865735933184624
0.058535136282444
0.058328110724687576
0.05822577327489853
0.058402158319950104
0.058595843613147736
0.05845176801085472
0.05838976055383682
0.058516036719083786
0.05869276449084282
0.05852101370692253
0.058448418974876404
0.058007463812828064
0.05794146656990051
0.057804711163043976
0.05756378546357155
0.057471390813589096
0.05780388414859772
0

KeyboardInterrupt: 