In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, random_split

from transformers import BertForSequenceClassification, BertTokenizer

import pytorch_lightning as pl
import torchmetrics

import pandas as pd

In [2]:
# Read the raw IMDB review CSV from the notebook's working directory.
csv_path = "IMDB Dataset.csv"
df = pd.read_csv(csv_path)

In [3]:
# Preview the first five rows to check the columns that were loaded.
df.head(n=5)

Unnamed: 0,review,sentiment
0,One of the other reviewers has mentioned that ...,positive
1,A wonderful little production. <br /><br />The...,positive
2,I thought this was a wonderful way to spend ti...,positive
3,Basically there's a family where a little boy ...,negative
4,"Petter Mattei's ""Love in the Time of Money"" is...",positive


In [4]:
# Encode the string labels as integers ("positive" -> 1, anything else -> 0).
# Values that are already 0/1 pass through unchanged, so re-running this cell
# does not collapse every label to 0.
df["sentiment"] = df["sentiment"].map(
    lambda label: label if label in (0, 1) else int(label == "positive")
)

## The Dataset

In [5]:
class ImdbDataset(Dataset):
    """Map-style dataset that tokenizes IMDB reviews one item at a time.

    Args:
        df: DataFrame with a ``review`` column (text) and a ``sentiment``
            column (numeric label). Rows are addressed positionally.
        tokenizer: Optional tokenizer object exposing ``encode_plus``. When
            omitted, the ``bert-base-uncased`` ``BertTokenizer`` is loaded,
            which matches the previous behaviour.
        max_length: Length every sequence is padded/truncated to
            (default 128, the value that used to be hard-coded).
    """

    def __init__(self, df, tokenizer=None, max_length=128):
        self.df = df
        self.max_length = max_length
        # Only load the pretrained tokenizer when the caller did not supply
        # one, so a tokenizer can be shared between datasets or stubbed out.
        if tokenizer is None:
            tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        """Return the tokenized review and its label at position ``idx``.

        Returns:
            dict with ``input_ids``, ``attention_mask`` and ``token_type_ids``
            as ``torch.long`` tensors and ``label`` as a ``torch.float``
            scalar tensor.
        """
        text = self.df["review"].iloc[idx]
        label = self.df["sentiment"].iloc[idx]

        tokenized_text = self.tokenizer.encode_plus(
            text=text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=True,
        )

        return {
            "input_ids": torch.tensor(tokenized_text["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(
                tokenized_text["attention_mask"], dtype=torch.long
            ),
            "token_type_ids": torch.tensor(
                tokenized_text["token_type_ids"], dtype=torch.long
            ),
            # Float label: the Lightning module in this notebook feeds it to
            # BCEWithLogitsLoss, which requires float targets.
            "label": torch.tensor(label, dtype=torch.float),
        }

    def __len__(self):
        """Number of rows in the underlying DataFrame."""
        return len(self.df)

## The DataModule

In [8]:
class ImdbDataModule(pl.LightningDataModule):
    """LightningDataModule that splits the IMDB data into train/validation.

    Args:
        df: DataFrame handed to ``ImdbDataset``.
        batch_size: Batch size for both loaders (default 8, as before).
        train_fraction: Share of the samples used for training
            (default 0.8); the remainder becomes the validation set.
    """

    def __init__(self, df, batch_size=8, train_fraction=0.8):
        super().__init__()
        self.dataset = ImdbDataset(df)
        self.batch_size = batch_size
        self.train_fraction = train_fraction

    def setup(self, stage=None) -> None:
        """Create the random train/validation split for the ``fit`` stage."""
        if stage == "fit" or stage is None:
            n_total = len(self.dataset)
            n_train = int(n_total * self.train_fraction)
            # Use the remainder for validation: two independently truncated
            # sizes can sum to less than n_total, which random_split rejects.
            n_val = n_total - n_train
            self.train_data, self.val_data = random_split(
                self.dataset, [n_train, n_val]
            )

    def train_dataloader(self):
        """Training loader; reshuffled every epoch."""
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Validation loader; fixed order."""
        return DataLoader(self.val_data, batch_size=self.batch_size)

## The LightningModule

In [9]:
class LitImdb(pl.LightningModule):
    """BERT-based binary sentiment classifier.

    Wraps ``BertForSequenceClassification`` with a single output logit and
    trains it with ``BCEWithLogitsLoss``.

    Args:
        fine_tune: When True (default), every parameter whose name does not
            contain ``"classifier"`` is frozen, so only the classification
            head is trained. When False, all weights stay trainable.
        lr: Learning rate passed to Adam. Defaults to 1e-3, Adam's own
            default, which is what the optimizer used before.
    """

    def __init__(self, fine_tune=True, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=1
        )
        if fine_tune:
            self.freeze()
        self.loss = nn.BCEWithLogitsLoss()
        # Separate metric objects so train and validation state never mix.
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()

    def freeze(self):
        """Disable gradients for every parameter outside the classifier head.

        NOTE(review): ``LightningModule`` already defines ``freeze()``; this
        override replaces it. The name is kept so existing callers keep
        working — consider renaming in a follow-up.
        """
        for name, param in self.model.named_parameters():
            if "classifier" not in name:
                param.requires_grad = False

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        return optim.Adam(self.model.parameters(), lr=self.lr)

    def forward(self, input_ids, attention_masks, token_type_ids):
        """Run the wrapped BERT classifier and return its output object."""
        # Keyword arguments make the mapping independent of the positional
        # order of the underlying model's forward signature.
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_masks,
            token_type_ids=token_type_ids,
        )

    def _shared_step(self, batch):
        """Compute loss, predicted probabilities and int targets for a batch."""
        targets = batch["label"]
        outputs = self(
            batch["input_ids"], batch["attention_mask"], batch["token_type_ids"]
        )
        logits = outputs["logits"].view(-1)
        loss = self.loss(input=logits, target=targets)
        # Accuracy expects probabilities (or labels), not raw logits, so map
        # the logits through a sigmoid before handing them to the metric.
        probs = torch.sigmoid(logits)
        return loss, probs, targets.int()

    def training_step(self, batch, batch_idx):
        """Log training loss/accuracy and return the loss for backprop."""
        loss, probs, targets = self._shared_step(batch)
        acc = self.train_acc(probs, targets)

        self.log("train_loss", loss)
        self.log("train_acc", acc, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy and return the loss."""
        loss, probs, targets = self._shared_step(batch)
        acc = self.val_acc(probs, targets)

        self.log("val_loss", loss)
        self.log("val_acc", acc, prog_bar=True)

        return loss

In [None]:
# Seed the RNGs before building the model and the data split so the
# classifier-head initialisation and the random train/validation split are
# reproducible across Restart & Run All.
pl.seed_everything(42)

model = LitImdb()
dm = ImdbDataModule(df)

# Checkpointing is enabled by default, so the explicit `checkpoint_callback`
# flag is omitted; it was deprecated and then removed in later PyTorch
# Lightning releases.
trainer = pl.Trainer(
    logger=True,
    max_epochs=3,
)

trainer.fit(model, datamodule=dm)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: -1it [00:00, ?it/s]