In [3]:
import numpy as np
import sys
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
from torch.utils.data import Dataset, DataLoader, TensorDataset

In [4]:
# Number of distinct token ids; this is the width of every one-hot input row
# built by the dataset builder.
vocabulary_size = 2000
# Number of sequences handed to the model per DataLoader batch.
BATCH_SIZE = 64

In [14]:
# Each element of X is consumed downstream as a variable-length Python list of
# token ids, so the file has to be a pickled object array. NumPy >= 1.16.3
# refuses to unpickle unless allow_pickle=True is passed explicitly.
# SECURITY: unpickling can execute arbitrary code -- only load files you
# produced yourself or otherwise trust.
X = np.load('../data/X_train.npy', allow_pickle=True)
# NOTE(review): y is not used by any cell visible here and its layout is not
# shown; if it is also an object array it needs allow_pickle=True as well.
y = np.load('../data/y_train.npy')

In [5]:
class RedditDataSet(Dataset):
    """Map-style dataset that pairs each sequence with its label.

    Every element of ``lines`` and ``labels`` is stored as its own tensor, so
    sequences may have different lengths. Indexing returns a
    ``(line, label)`` tuple.

    Args:
        lines: iterable of sequences (tensors or anything ``torch.tensor``
            accepts).
        labels: iterable of labels, one per entry of ``lines``.

    Raises:
        ValueError: if ``lines`` and ``labels`` have different lengths.
    """

    def __init__(self, lines, labels):
        self.lines = [self._to_tensor(line) for line in lines]
        self.labels = [self._to_tensor(label) for label in labels]
        # Compare after materialising so one-shot iterables are supported too.
        if len(self.lines) != len(self.labels):
            raise ValueError(
                "lines and labels must have the same length, got {} and {}".format(
                    len(self.lines), len(self.labels)
                )
            )

    @staticmethod
    def _to_tensor(value):
        """Return an independent tensor copy of ``value``."""
        # torch.tensor(existing_tensor) works but emits a copy-construct
        # UserWarning for every item; clone().detach() is the supported way
        # to get the same independent copy.
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        return torch.tensor(value)

    def __getitem__(self, i):
        return self.lines[i], self.labels[i]

    def __len__(self):
        return len(self.lines)

def collate(seq_list):
    """Collate ``(input, label)`` pairs into two lists, longest input first.

    The batch is reordered by descending input length; pairs whose inputs have
    equal length keep their original relative order. Nothing is padded or
    stacked.

    Args:
        seq_list: sequence of ``(input, label)`` pairs.

    Returns:
        ``(inputs, labels)`` -- two lists in the same, length-sorted order.
    """
    inputs, labels = zip(*seq_list)
    # Sort the pairs themselves (stable, keyed on input length only) rather
    # than sorting an index permutation and gathering afterwards.
    by_length = sorted(
        zip(inputs, labels),
        key=lambda pair: len(pair[0]),
        reverse=True,
    )
    sorted_inputs = [seq for seq, _ in by_length]
    sorted_labels = [lab for _, lab in by_length]
    return sorted_inputs, sorted_labels

In [70]:
class DataSetCreater:
    """Turn token-id sequences into one-hot inputs with next-token targets.

    For a sequence ``[t0, ..., tn]`` the end token is appended, the inputs are
    the one-hot rows of ``t0 .. tn`` and the targets are ``t1 .. tn, end``.

    Args:
        x_train: iterable of token-id sequences.
        vocab_size: width of the one-hot rows. Defaults to the notebook-level
            ``vocabulary_size`` constant.
        end_token: id appended to every sequence (presumably an
            end-of-sequence marker -- confirm against the preprocessing).
    """

    def __init__(self, x_train, vocab_size=None, end_token=1):
        self.x = x_train
        # Only touch the global when the caller did not supply a size.
        self.vocab_size = vocabulary_size if vocab_size is None else vocab_size
        self.end_token = end_token

        self.x_data = None
        self.y_data = None

    def build(self):
        """Encode every sequence and wrap the result in a ``RedditDataSet``."""
        self.x_data = []
        self.y_data = []
        print("Building...")
        for data in self.x:
            tmp_x, tmp_y = self.build_helper(data)
            self.x_data.append(tmp_x)
            self.y_data.append(tmp_y)

        print("Finishing Building")
        return RedditDataSet(self.x_data, self.y_data)

    def build_helper(self, data):
        """Encode one sequence.

        Args:
            data: sequence of integer token ids. It is NOT modified.

        Returns:
            ``(inputs, targets)``: a float32 tensor of shape
            ``(len(data), vocab_size)`` and an int32 tensor of shape
            ``(len(data),)``.

        Raises:
            IndexError: if a token id does not fit in ``vocab_size``.
        """
        # Work on a copy: appending to the caller's list in place would grow
        # the raw data by one token every time the build cell is re-run.
        tokens = list(data) + [self.end_token]
        # float32 is the dtype the tensor ends up in anyway; allocating it
        # directly halves the size of the temporary matrix.
        one_hot = np.zeros((len(tokens), self.vocab_size), dtype=np.float32)
        one_hot[np.arange(len(tokens)), tokens] = 1

        # Inputs drop the appended end token, targets drop the first token.
        return torch.FloatTensor(one_hot[:-1]), torch.IntTensor(tokens[1:])

In [71]:
# Training split: everything before the 90% mark of X.
d = DataSetCreater(
    x_train=X[: int(len(X) * 0.9)],
)
train_dataset = d.build()

Building...
Finishing Building


In [72]:
# Validation split: everything from the 90% mark of X onwards.
d = DataSetCreater(
    x_train=X[int(len(X) * 0.9) :],
)
val_dataset = d.build()

Building...
Finishing Building


In [73]:
# Both loaders use `collate`, which returns each batch as two lists ordered by
# descending sequence length.
# NOTE(review): the training loader does not shuffle, so every epoch sees the
# batches in the same order -- confirm this is intended.
train_data_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate,
)
val_data_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate,
)

In [74]:
class RNN(nn.Module):
    """Recurrent sequence model over a batch of variable-length sequences.

    A stack of LSTM or GRU layers is followed by a linear projection that maps
    each hidden state to ``output_size`` scores.

    Args:
        nlayers: number of stacked recurrent layers.
        input_size: feature size of every time step.
        hidden_size: size of the recurrent hidden state.
        output_size: number of scores produced per time step.
        rnn_type: ``"LSTM"`` selects ``nn.LSTM``; any other value selects
            ``nn.GRU``.
    """

    def __init__(self, nlayers, input_size, hidden_size, output_size, rnn_type='LSTM'):
        super(RNN, self).__init__()
        self.nlayers = nlayers
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        if rnn_type == "LSTM":
            self.cell = nn.LSTM
        else:
            self.cell = nn.GRU

        # Build the layer from the selected cell type and the requested depth
        # and width; previously these three arguments were silently ignored.
        self.rnn = self.cell(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.nlayers,
        )
        # Recurrent outputs are bounded in (-1, 1) and have hidden_size
        # features; the projection yields unbounded scores of output_size.
        self.projection = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, seq_list):
        """Run the model on a batch of sequences.

        Args:
            seq_list: list of ``(length_i, input_size)`` tensors ordered by
                descending length (the order ``collate`` produces).

        Returns:
            ``(output, lens)`` where ``output`` is time-major with shape
            ``(max_len, batch, output_size)`` and ``lens`` lists the sequence
            lengths in batch order. Positions past a sequence's length hold
            the projection of zero padding and should be sliced off using
            ``lens``.

        Raises:
            ValueError: if ``seq_list`` is empty.
        """
        lens = [len(s) for s in seq_list]
        if not lens:
            raise ValueError("seq_list must contain at least one sequence")

        # pack_sequence pads and packs in one step.
        input_pack = rnn.pack_sequence(seq_list)

        output_pack, _ = self.rnn(input_pack)

        # No batch_first here: callers index the result as output[t, batch].
        output, _ = rnn.pad_packed_sequence(output_pack)

        return self.projection(output), lens

In [180]:
class Trainer:
    """Train and evaluate an ``RNN`` as a next-token predictor.

    Args:
        input_size: feature size of every input time step.
        output_size: number of classes; also used as the hidden size.
        train_data_loader: iterable of ``(inputs, targets)`` batches, each a
            pair of lists of per-sequence tensors.
        val_data_loader: same layout, used by ``validate``.
        nlayers: number of recurrent layers requested from the model.
        cell_type: forwarded to ``RNN`` as ``rnn_type``.
        lr: SGD learning rate.
    """

    def __init__(self, input_size, output_size, train_data_loader, val_data_loader, nlayers, cell_type='LSTM', lr=0.00002):
        self.model = RNN(input_size=input_size, hidden_size=output_size, output_size=output_size, nlayers=nlayers, rnn_type=cell_type)

        self.train_data_loader = train_data_loader
        self.val_data_loader = val_data_loader
        self.optimizer = torch.optim.SGD(params=self.model.parameters(), lr=lr)
        self.criterion = torch.nn.CrossEntropyLoss()

    def train(self, epoches, val_every=0):
        """Run ``epoches`` passes over the training loader.

        Args:
            epoches: number of epochs.
            val_every: when positive, report validation accuracy on every
                epoch whose index is a multiple of it; ``0`` disables it.

        Returns:
            List with one entry per epoch: the sum of the batch losses.
        """
        print("Start Training...")
        history = []
        for epoch in range(epoches):
            # validate() switches to eval mode, so re-enter train mode here.
            self.model.train()
            loss_rec = 0
            for data, target in self.train_data_loader:
                self.optimizer.zero_grad()
                output, lens = self.model(data)

                # output is time-major: keep only the valid steps of each
                # sequence and stack them into one (total_steps, classes) matrix.
                preds = torch.cat([output[:length, i] for i, length in enumerate(lens)], dim=0)

                # CrossEntropyLoss needs int64 class indices.
                flat_target = torch.cat(list(target), dim=0).long()
                loss = self.criterion(preds, flat_target)
                loss_rec += loss.item()
                loss.backward()
                # Without this call the gradients are computed but the
                # weights never change.
                self.optimizer.step()

            print("At epoch {} the loss is {}".format(str(epoch), str(loss_rec)))
            history.append(loss_rec)

            if val_every and epoch % val_every == 0:
                accuracy = self.validate()
                print("At epoch {}, the accuracy is {}".format(str(epoch), str(accuracy)))

        return history

    def validate(self):
        """Return the per-token accuracy over the validation loader.

        Returns:
            Fraction of time steps whose arg-max prediction equals the target,
            or ``0.0`` when the loader yields no time steps.
        """
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in self.val_data_loader:
                output, lens = self.model(data)
                for i, length in enumerate(lens):
                    predicted = torch.argmax(output[:length, i], dim=1)
                    # Targets may be int32; compare in int64 like argmax.
                    correct += (predicted == target[i].long()).sum().item()
                    total += length

        return correct / total if total else 0.0

In [181]:
# The inputs are one-hot rows over the vocabulary and the model scores every
# vocabulary entry, so both sizes are taken from `vocabulary_size` instead of
# repeating the literal.
trainer = Trainer(
    input_size=vocabulary_size,
    output_size=vocabulary_size,
    train_data_loader=train_data_loader,
    val_data_loader=val_data_loader,
    nlayers=2,
)

In [None]:
# Run 10 epochs over the training loader; the loss is reported once per epoch.
trainer.train(10)

Start Training...
