In [1]:
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import BertForSequenceClassification, BertTokenizer

In [None]:
# Load the raw reviews; later cells read its 'review' and 'sentiment' columns.
IMDB_CSV_PATH = 'IMDB Dataset.csv'
df = pd.read_csv(IMDB_CSV_PATH)

In [None]:
# Peek at the first rows; the bare last expression is rendered as a table.
df.head(n=5)

In [None]:
# Encode sentiment as 1 (positive) / 0 (negative).
# Guarded so the cell is idempotent: re-mapping the already-encoded 0/1 integers
# with `== 'positive'` would silently turn every label into 0.
if not pd.api.types.is_numeric_dtype(df['sentiment']):
    # Anything other than these two labels would otherwise be silently coerced to 0.
    unexpected_labels = set(df['sentiment'].unique()) - {'positive', 'negative'}
    assert not unexpected_labels, f"unexpected sentiment labels: {unexpected_labels}"
    df['sentiment'] = (df['sentiment'] == 'positive').astype('int64')

## The Dataset

In [None]:
class ImdbDataset(Dataset):
    """Map-style dataset that tokenizes IMDB reviews on the fly for BERT.

    Each item is a dict with fixed-length ``input_ids``, ``attention_mask`` and
    ``token_type_ids`` (all ``torch.long``) plus a scalar ``label`` stored as
    ``torch.float``, which is the target dtype that
    ``binary_cross_entropy_with_logits`` needs.

    Args:
        df: DataFrame with a ``'review'`` text column and a numeric
            ``'sentiment'`` column (0/1 after the label-encoding cell).
        max_length: Length every review is padded / truncated to.
        pretrained_name: Hugging Face checkpoint whose tokenizer is loaded.
    """

    def __init__(self, df, max_length=128, pretrained_name='bert-base-uncased'):
        self.df = df
        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_name)

    def __getitem__(self, idx):
        # `iloc` is positional, so this works regardless of the frame's index labels.
        text = self.df['review'].iloc[idx]
        label = self.df['sentiment'].iloc[idx]

        encoded = self.tokenizer.encode_plus(
            text=text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=True,
        )

        return {
            'input_ids': torch.tensor(encoded['input_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(encoded['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(encoded['token_type_ids'], dtype=torch.long),
            'label': torch.tensor(label, dtype=torch.float),
        }

    def __len__(self):
        return len(self.df)
        

## The DataModule

In [None]:
class ImdbDataModule(pl.LightningDataModule):
    """Split a single ImdbDataset into train / validation loaders.

    Args:
        df: DataFrame passed straight to ``ImdbDataset``.
        batch_size: Batch size for both dataloaders.
        num_workers: Worker processes per dataloader.
        val_fraction: Fraction of rows held out for validation.
        seed: Seed for the split, so re-running ``setup`` (or the whole
            notebook) yields the same partition.
    """

    def __init__(self, df, batch_size=8, num_workers=8, val_fraction=0.2, seed=42):
        super().__init__()
        self.dataset = ImdbDataset(df)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_fraction = val_fraction
        self.seed = seed

    def setup(self, stage=None) -> None:
        """Create ``train_data`` / ``val_data`` for the fit and validate stages."""
        if stage in ("fit", "validate", None):
            n_total = len(self.dataset)
            n_val = int(n_total * self.val_fraction)
            # Derive the train size by subtraction so the lengths always sum to
            # n_total; int(0.8 * n) + int(0.2 * n) can fall short, and then
            # random_split raises.
            n_train = n_total - n_val
            generator = torch.Generator().manual_seed(self.seed)
            self.train_data, self.val_data = random_split(
                self.dataset, [n_train, n_val], generator=generator
            )

    def train_dataloader(self):
        # shuffle=True reorders the training subset every epoch.
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

## The LightningModule

In [None]:
class LitImdb(pl.LightningModule):
    """Binary sentiment classifier: pretrained BERT with a single-logit linear head.

    Args:
        fine_tune: If True, freeze every BERT parameter outside the new
            classification head, so only the head is trained.
    """

    def __init__(self, fine_tune=True):
        super().__init__()
        self.model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
        # Swap in a one-logit head to pair with BCE-with-logits. Read the width
        # from the config instead of hard-coding bert-base's 768.
        self.model.classifier = nn.Linear(
            in_features=self.model.config.hidden_size,
            out_features=1,
        )
        if fine_tune:
            self.freeze()

    def freeze(self):
        """Disable gradients for every parameter outside the classification head.

        NOTE: this overrides ``LightningModule.freeze``. The Lightning version
        freezes all parameters and switches to eval mode; this one only freezes
        the backbone.
        """
        for name, param in self.model.named_parameters():
            # Must match the attribute name `classifier` exactly. A misspelling
            # here freezes the head too and leaves nothing trainable.
            if 'classifier' not in name:
                param.requires_grad = False

    def configure_optimizers(self):
        # Only give the optimizer parameters that will actually receive gradients.
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        return optim.Adam(trainable_params)

    def forward(self, x, attention_mask=None, token_type_ids=None):
        """Run the wrapped BERT model.

        Args:
            x: Token ids (``input_ids``).
            attention_mask: Optional padding mask.
            token_type_ids: Optional segment ids.

        Returns:
            The Hugging Face model output; raw scores are in ``.logits``.
        """
        return self.model(
            input_ids=x,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

    def _shared_step(self, batch):
        """Compute the BCE-with-logits loss for one batch (shared by train/val)."""
        output = self(
            batch['input_ids'],
            attention_mask=batch['attention_mask'],
            token_type_ids=batch['token_type_ids'],
        )
        # The 1-unit head yields (batch, 1) logits, while targets are (batch,).
        # BCE requires matching shapes, so drop the trailing axis.
        logits = output.logits.squeeze(-1)
        return F.binary_cross_entropy_with_logits(input=logits, target=batch['label'])

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        # Use a separate key so validation does not overwrite the training metric.
        self.log("val_loss", loss)
        return loss

In [None]:
# Seed Python / NumPy / torch RNGs before building the model so the new
# classifier head's initialization (and dropout) is reproducible.
pl.seed_everything(42)

model = LitImdb()
dm = ImdbDataModule(df)

trainer = pl.Trainer(
    logger=True,
    enable_checkpointing=True,  # replaces the removed `checkpoint_callback` flag
    accelerator="gpu",          # `gpus=1` was removed; same intent: one GPU
    devices=1,
    max_epochs=3,
)

trainer.fit(model, datamodule=dm)