In [3]:
import torch
from torch import nn
import pytorch_lightning as pl
from torchmetrics import Accuracy
from transformers import DistilBertTokenizer, DistilBertModel
from torch.utils.data import DataLoader, Dataset, random_split
import pandas as pd

In [15]:
class ReviewDataset(Dataset):
    def __init__(self, reviews, ratings):
        self.reviews = reviews
        self.ratings = ratings
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    def __len__(self):
        return len(self.reviews)

    def __getitem__(self, idx):
        review = self.reviews[idx]
        rating = self.ratings[idx]
        encoding = self.tokenizer(review, return_tensors='pt', padding='max_length', truncation=True, max_length=512)
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)
        return input_ids, attention_mask, torch.tensor(rating)

class ReviewDataModule(pl.LightningDataModule):
    def __init__(self, data_file, batch_size=8, val_split=0.2, num_workers=7):
        super().__init__()
        self.data_file = data_file
        self.batch_size = batch_size
        self.val_split = val_split
        self.num_workers = num_workers

    def prepare_data(self):
        # Load data from CSV file
        data = pd.read_csv(self.data_file)
        self.reviews = data['review'].tolist()
        self.ratings = data['rating'].tolist()

    def setup(self, stage=None):
        dataset = ReviewDataset(self.reviews, self.ratings)
        val_size = int(len(dataset) * self.val_split)
        train_size = len(dataset) - val_size
        self.train_dataset, self.val_dataset = random_split(dataset, [train_size, val_size])

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

In [9]:
class ReviewRatingModel(pl.LightningModule):
    def __init__(self, num_classes: int, learning_rate: float = 1e-4):
        super().__init__()
        self.save_hyperparameters()
        
        self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = Accuracy(task='multiclass', num_classes=num_classes)
        self.learning_rate = learning_rate

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_output)
        return logits

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, labels = batch
        logits = self(input_ids, attention_mask)
        loss = self.criterion(logits, labels)
        self.log('train_loss', loss)
        self.log('train_acc', self.accuracy(logits, labels), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, labels = batch
        logits = self(input_ids, attention_mask)
        loss = self.criterion(logits, labels)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.accuracy(logits, labels), prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

In [17]:
dm = ReviewDataModule("data/train_data.csv", batch_size=4, num_workers=16)
model = ReviewRatingModel(num_classes=5)
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type               | Params
--------------------------------------------------
0 | bert       | DistilBertModel    | 66.4 M
1 | classifier | Linear             | 3.8 K 
2 | criterion  | CrossEntropyLoss   | 0     
3 | accuracy   | MulticlassAccuracy | 0     
--------------------------------------------------
66.4 M    Trainable params
0         Non-trainable params
66.4 M    Total params
265.467   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]