In [163]:
from rnn_data import load_imdb
import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


In [164]:
(x_train, y_train), (x_val, y_val), (i2w, w2i), numcls = load_imdb(final=False)

In [165]:
class IMDBDataset(Dataset):
    def __init__(self, x, y, padding_value=0) -> None:
        super().__init__()

        x = [torch.tensor(xi) for xi in x]
        self.x = pad_sequence(x, batch_first=True, padding_value=padding_value)

        self.y = torch.tensor(y)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, index):
        return self.x[index], self.y[index]

In [166]:
# Handle & convert to Torch datasets
train_data = IMDBDataset(x_train, y_train)
val_data = IMDBDataset(x_val, y_val)

# Dataloaders
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
val_loader = DataLoader(val_data, batch_size=128, shuffle=True)

In [167]:
class GlobalMaxPool(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return torch.max(x, dim=self.dim)[0]

In [169]:
network = nn.Sequential(
    nn.Embedding(len(i2w),300, 0),
    nn.Linear(300 ,300),
    nn.ReLU(),
    GlobalMaxPool(1),
    nn.Linear(300, 2),
    nn.Softmax()
)

In [173]:
optimizer = torch.optim.Adam(network.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(5000):
    losses = []
    for batch_x, batch_y in DataLoader(train_data, batch_size=128, shuffle=True):
        pred_y = network(batch_x)
        
        loss = loss_fn(pred_y, batch_y)
        t_loss = loss.item()
        losses.append(t_loss)

        network.zero_grad()
        loss.backward()

        optimizer.step()
    
    val_losses = []
    for batch_x, batch_y in DataLoader(val_data, batch_size=128, shuffle=True):
        pred_y = network(batch_x)
        
        loss = loss_fn(pred_y, batch_y)
        val_losses.append(loss.item())

        print(loss.item())

    print('Epoch:', epoch, 'Loss:', np.mean(losses), 'Val Loss:', np.mean(val_losses))


  input = module(input)


KeyboardInterrupt: 