In [None]:
!pip install pytorch-lightning
!pip install torchtext

In [1]:
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger
from torch.optim import Adam
from torchtext.data import BucketIterator
from torchtext.data import Field
from torchtext.datasets import IMDB
from torchtext.vocab import FastText

In [3]:
# Text field: sequential input with a fixed length of 200 tokens.
# include_lengths=True makes `batch.text` a (tokens, lengths) pair, which is
# why the model later reads `batch.text[0]`.
text_field = Field(
    sequential=True,
    fix_length=200,
    include_lengths=True,
)
# Label field: one non-sequential value per review.
label_field = Field(sequential=False)

# IMDB train/test splits, processed with the two fields above.
train, test = IMDB.splits(text_field=text_field, label_field=label_field)


In [4]:
# Load the pre-trained FastText vectors first, then build both vocabularies
# from the training split only; the text vocabulary gets the vectors attached.
fasttext_vectors = FastText(language='simple')
text_field.build_vocab(train, vectors=fasttext_vectors)
label_field.build_vocab(train)

In [5]:
# Seed every RNG in play *before* the iterators and the model are created, so
# batch shuffling and weight initialisation are reproducible under
# Restart Kernel -> Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Batches are placed on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 32

# One iterator per split, same batch size and device for both.
train_iter, test_iter = BucketIterator.splits(
    (train, test),
    batch_size=batch_size,
    device=device
)

In [8]:
class MyModel(LightningModule):
    """LSTM classifier on top of a fixed (non-trainable) word-vector table.

    Pipeline: vector lookup -> LSTM -> ELU -> linear layer applied at every
    time step -> sum of the per-step logits over the sequence axis.

    The train/validation dataloaders are the notebook-level ``train_iter`` and
    ``test_iter`` objects, so those must exist before ``Trainer.fit`` is called.

    Args:
        embedding: 2-D tensor of word vectors, indexed by token id.
        lstm_input_size: Width of one word vector (must match ``embedding``).
        lstm_hidden_size: Hidden size of the LSTM.
        output_size: Number of logits produced per example. The default of 3
            presumably matches the label vocabulary built from
            ``Field(sequential=False)`` (two sentiment labels plus an
            unknown-token entry) -- TODO confirm.
    """

    def __init__(self, embedding, lstm_input_size=300, lstm_hidden_size=100, output_size=3):
        super().__init__()
        # Register the table as a buffer so it follows the module across
        # devices (`.cuda()` / `.to()`) instead of staying wherever the caller
        # created it. It is a buffer, not a Parameter, so it is never trained.
        try:
            # Non-persistent: keeps the (large) table out of checkpoints and
            # leaves the state_dict keys unchanged.
            self.register_buffer("embedding", embedding, persistent=False)
        except TypeError:
            # Older torch releases do not accept the `persistent` flag.
            self.register_buffer("embedding", embedding)
        # batch_first=True lets the LSTM consume (batch, seq, feature) input
        # directly, so no permutes are needed in forward().
        self.lstm = nn.LSTM(lstm_input_size, lstm_hidden_size, batch_first=True)
        self.lin = nn.Linear(lstm_hidden_size, output_size)
        self.loss_function = nn.CrossEntropyLoss()

    def forward(self, X: torch.Tensor):
        """Return one row of logits per example.

        Args:
            X: Integer token ids, batch-first, i.e. (batch, seq_len).

        Returns:
            Tensor of shape (batch, output_size): the per-time-step logits
            summed over the sequence axis. Sequence lengths are not used, so
            every position of the fixed-length input contributes to the sum.
        """
        # Make sure the indices live on the same device as the table.
        token_ids = X.to(self.embedding.device)
        word_vectors = self.embedding[token_ids]
        hidden_states, _ = self.lstm(word_vectors)
        step_logits = self.lin(F.elu(hidden_states))
        return step_logits.sum(dim=1)

    def _shared_step(self, batch):
        """Compute the cross-entropy loss for one batch (train or validation)."""
        # `batch.text` is a (tokens, lengths) pair; only the tokens are used.
        # The transpose turns them into the batch-first layout forward() expects.
        token_ids = batch.text[0].T
        return self.loss_function(self(token_ids), batch.label)

    def training_step(self, batch, batch_idx):
        """Return the training loss and log it as ``train_loss``."""
        loss = self._shared_step(batch)
        return dict(
            loss=loss,
            log=dict(
                train_loss=loss
            )
        )

    def configure_optimizers(self):
        """Adam over all trainable parameters with a learning rate of 0.01."""
        return Adam(self.parameters(), lr=0.01)

    def train_dataloader(self):
        """Return the notebook-level training iterator."""
        return train_iter

    def validation_step(self, batch, batch_idx):
        """Return the loss of one validation batch under ``validation_loss``."""
        loss = self._shared_step(batch)
        return dict(
            validation_loss=loss,
            log=dict(
                val_loss=loss
            )
        )

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses and log them as ``val_loss``.

        Without this hook the dicts returned by ``validation_step`` are never
        aggregated, so no validation metric reaches the logger.

        Args:
            outputs: List of the dicts returned by ``validation_step``.

        Returns:
            Dict with the mean ``val_loss`` (also under ``log``), or an empty
            dict when no validation batch was run.
        """
        if not outputs:
            return {}
        mean_loss = torch.stack([out["validation_loss"] for out in outputs]).mean()
        return dict(
            val_loss=mean_loss,
            log=dict(
                val_loss=mean_loss
            )
        )

    def val_dataloader(self):
        """Return the notebook-level ``test_iter`` as the validation iterator."""
        return test_iter

In [9]:
# Build the classifier on top of the word vectors attached to the text vocabulary.
model = MyModel(embedding=text_field.vocab.vectors)

In [10]:
# Log to ./tb_logs/my_model so the TensorBoard cell can pick the run up.
logger = TensorBoardLogger('tb_logs', name='my_model')
# Request one GPU only when CUDA is actually present; otherwise train on CPU.
# This mirrors how the batch iterators choose their device, so the cell no
# longer fails on a machine without a GPU.
trainer = Trainer(
    gpus=1 if torch.cuda.is_available() else None,
    logger=logger,
    max_epochs=3
)
trainer.fit(model)

GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
-----------------------------------------------
0 | lstm          | LSTM             | 160 K 
1 | lin           | Linear           | 303   
2 | loss_function | CrossEntropyLoss | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [11]:
# Start TensorBoard inline, pointed at the directory the TensorBoardLogger
# above writes to ('tb_logs').
# %reload_ext is used so the cell can be re-run in a kernel where the
# extension is already loaded.
%reload_ext tensorboard
%tensorboard --logdir tb_logs/

Reusing TensorBoard on port 6006 (pid 14308), started 3 days, 11:25:52 ago. (Use '!kill 14308' to kill it.)