In [None]:
import torch

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


In [None]:
!pip install datasets

LOADING THE DATASET

In [None]:
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import torch

# Pretrained WordPiece tokenizer (uncased BERT vocabulary).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# IMDb dataset, fetched by name through the `datasets` library.
dataset = load_dataset("imdb")

# Tokenize the dataset
def tokenize_function(example, max_length=256):
    """Tokenize the ``'text'`` entry of an example with the module-level tokenizer.

    Every sequence is truncated and padded to exactly ``max_length`` tokens
    (``padding="max_length"``, ``truncation=True``), so all rows end up with
    the same length.

    Args:
        example: Mapping with a ``'text'`` entry. This notebook calls the
            function through ``dataset.map(..., batched=True)``, so the entry
            is presumably a list of strings rather than a single string.
        max_length: Target number of tokens per sequence. Defaults to 256,
            the value this notebook has always used. Keep it at or below the
            ``max_len`` of ``PositionalEncoding`` (512 by default here).

    Returns:
        Whatever ``tokenizer(...)`` returns for these inputs.
    """
    return tokenizer(
        example['text'],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

# Run the tokenizer over the dataset, several examples per call.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Return only the model inputs and the target, as torch tensors.
tensor_columns = ['input_ids', 'attention_mask', 'label']
tokenized_datasets.set_format(type='torch', columns=tensor_columns)

# Batched iterators; only the training split is shuffled.
train_split = tokenized_datasets['train']
test_split = tokenized_datasets['test']
train_loader = DataLoader(train_split, batch_size=32, shuffle=True)
test_loader = DataLoader(test_split, batch_size=32)

# Peek at the first training batch as a sanity check, then stop iterating.
for batch in train_loader:
    for field in ('input_ids', 'attention_mask'):
        print(batch[field].shape)
    print(batch['label'])
    break


POSITIONAL ENCODING

In [None]:
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    A table of shape ``[1, max_len, embed_dim]`` is precomputed once: even
    feature columns hold ``sin(pos * div_term)`` and odd columns hold
    ``cos(pos * div_term)``. The table is registered as a non-persistent
    buffer, so it follows the module through ``.to(device)`` / ``.cuda()``
    without being added to ``state_dict()``.

    Args:
        embed_dim: Size of the last dimension of the inputs. Odd values are
            supported.
        max_len: Longest sequence length (dimension 1 of the input) that can
            be encoded. Defaults to 512.
    """

    def __init__(self, embed_dim, max_len=512):
        super().__init__()
        pos = torch.arange(0, max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * -(math.log(10000.0) / embed_dim))
        angles = pos * div_term  # [max_len, ceil(embed_dim / 2)]

        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embed_dim there is one fewer odd column than even
        # column, so trim the angle table to the number of odd columns.
        pe[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])

        # Buffer instead of plain attribute: moved together with the module,
        # but kept out of checkpoints (persistent=False).
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # [1, max_len, embed_dim]

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence position and whose
                last dimension is ``embed_dim``.

        Raises:
            RuntimeError: If ``x.size(1)`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise RuntimeError(
                f"Sequence length {seq_len} exceeds the positional-encoding "
                f"table length {self.pe.size(1)}; increase max_len."
            )
        # .to() is a no-op once the module already lives on x's device.
        return x + self.pe[:, :seq_len].to(x.device)


ENCODER

In [None]:
class TransformerEncoderBlock(nn.Module):
    """One encoder layer: self-attention and a feed-forward net, each followed
    by dropout, a residual addition and layer normalisation.

    Args:
        embed_dim: Feature size of the inputs and outputs.
        heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability, used inside the attention module and on
            both residual branches.
    """

    def __init__(self, embed_dim, heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim,
            heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        expand = nn.Linear(embed_dim, ff_dim)
        project = nn.Linear(ff_dim, embed_dim)
        self.ff = nn.Sequential(expand, nn.ReLU(), project)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask=None):
        """Apply the layer to ``x``.

        ``src_mask`` is forwarded to the attention module as its
        ``key_padding_mask``.
        """
        # Self-attention: x is query, key and value; [0] drops the weights.
        attended = self.attn(x, x, x, key_padding_mask=src_mask)[0]
        hidden = self.norm1(x + self.dropout(attended))
        # Feed-forward branch with its own residual connection and norm.
        return self.norm2(hidden + self.dropout(self.ff(hidden)))


TRANSFORMER CLASSIFIER

In [None]:
class TransformerClassifier(nn.Module):
    """Sequence classifier: embedding -> positional encoding -> encoder
    blocks -> masked mean pooling -> linear head.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding / model width.
        num_heads: Attention heads per encoder block.
        ff_dim: Hidden width of each block's feed-forward network.
        num_layers: Number of stacked encoder blocks.
        num_classes: Number of output logits.
        pad_idx: Token id used for padding (passed to ``nn.Embedding`` as
            ``padding_idx``).
    """

    def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers, num_classes, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.pos_encoding = PositionalEncoding(embed_dim)
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, ff_dim)
            for _ in range(num_layers)
        ])
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, input_ids, attention_mask=None):
        """Compute class logits for a batch of token-id sequences.

        Args:
            input_ids: Integer tensor of token ids, ``[batch, seq_len]``.
            attention_mask: Optional tensor shaped like ``input_ids`` with
                1 for real tokens and 0 for padding. When omitted, every
                position is treated as a real token.

        Returns:
            Tensor of logits, ``[batch, num_classes]``.
        """
        x = self.embedding(input_ids)
        x = self.pos_encoding(x)

        if attention_mask is None:
            real_tokens = torch.ones_like(input_ids, dtype=torch.bool)
            key_padding_mask = None
        else:
            real_tokens = attention_mask.bool()
            # Convert attention_mask (1: real, 0: pad) to key_padding_mask (True: pad, False: real)
            key_padding_mask = ~real_tokens

        for layer in self.layers:
            x = layer(x, src_mask=key_padding_mask)

        # Masked average pooling: only real tokens contribute. A plain
        # x.mean(dim=1) would also average the outputs at padding positions,
        # diluting short sequences that were padded to a fixed length.
        keep = real_tokens.unsqueeze(-1)  # [batch, seq_len, 1]
        summed = x.masked_fill(~keep, 0.0).sum(dim=1)
        # clamp avoids a division by zero for a row with no real tokens.
        counts = keep.sum(dim=1).clamp(min=1).to(x.dtype)
        return self.classifier(summed / counts)


TRAINING THE MODEL

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

SEED = 42
BATCH_SIZE = 32
NUM_EPOCHS = 10

# Seed PyTorch so weight initialisation and shuffling are repeatable.
torch.manual_seed(SEED)

# A torch.device (not a bare string), matching the object built in the first cell.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = TransformerClassifier(
    vocab_size=tokenizer.vocab_size,
    embed_dim=128,
    num_heads=4,
    ff_dim=256,
    num_layers=2,
    num_classes=2,
    pad_idx=tokenizer.pad_token_id
).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# For tracking: one mean-loss-per-example value per epoch.
train_losses = []
val_losses = []

# Hold out 10% of the training split for validation.
from torch.utils.data import random_split

train_size = int(0.9 * len(tokenized_datasets['train']))
val_size = len(tokenized_datasets['train']) - train_size
train_dataset, val_dataset = random_split(
    tokenized_datasets['train'],
    [train_size, val_size],
    # Dedicated generator: the split stays the same across runs.
    generator=torch.Generator().manual_seed(SEED),
)
# Built once; a shuffling DataLoader reshuffles each time it is iterated.
train_split_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the mean loss per example.

    With an ``optimizer`` the model is put in train mode and updated after
    every batch; without one it is put in eval mode and gradients are
    disabled. Batch losses are weighted by batch size, so the result does
    not depend on how many batches the loader yields.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum = 0.0
    n_examples = 0
    with torch.set_grad_enabled(training):
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # criterion averages over the batch; undo that to sum per example.
            batch_n = labels.size(0)
            loss_sum += loss.item() * batch_n
            n_examples += batch_n
    return loss_sum / max(n_examples, 1)


for epoch in range(NUM_EPOCHS):
    train_loss = _run_epoch(model, train_split_loader, criterion, device, optimizer)
    val_loss = _run_epoch(model, val_loader, criterion, device)

    # Means (not sums over differently sized splits) so the two curves are
    # on the same scale when plotted together.
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")


VISUALIZING TRAINING AND VALIDATION LOSS

In [None]:
# One figure, both per-epoch loss histories on the same axes.
fig, ax = plt.subplots()
for history, legend_name in (
    (train_losses, 'Training Loss'),
    (val_losses, 'Validation Loss'),
):
    ax.plot(history, label=legend_name)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training vs. Validation Loss")
ax.legend()
plt.show()


TEST ACCURACY

In [None]:
# Score the model on the test loader: fraction of argmax predictions that
# match the labels.
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in test_loader:
        input_ids, attention_mask, labels = (
            batch[key].to(device)
            for key in ('input_ids', 'attention_mask', 'label')
        )

        logits = model(input_ids, attention_mask)
        # Predicted class = index of the largest logit in each row.
        hits = logits.argmax(dim=1).eq(labels)
        correct += int(hits.sum())
        total += hits.size(0)

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")
