In [8]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
import numpy as np
from tqdm import tqdm
import math
import gc
import os
from sklearn.model_selection import train_test_split

# PART 1: Data Preparation
class IMDBDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=256):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

def prepare_data(batch_size=2, max_length=256):
    # Load the IMDB dataset
    dataset = load_dataset("stanfordnlp/imdb")
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    
    # Split the training data into training and validation manually using train_test_split
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        dataset['train']['text'],  # Use a smaller subset
        dataset['train']['label'],
        test_size=0.2,  # 80% training, 20% validation
        random_state=42
    )
    
    # Prepare training and validation datasets
    train_dataset = IMDBDataset(
        train_texts,
        train_labels,
        tokenizer,
        max_length=max_length
    )
    
    val_dataset = IMDBDataset(
        val_texts,
        val_labels,
        tokenizer,
        max_length=max_length
    )
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    
    return train_loader, val_loader, tokenizer

# PART 2: Model Definition
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=256):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, d_model=64, nhead=4, num_layers=4, dim_feedforward=256, dropout=0.1):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len=256)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            norm_first=True
        )
        
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        self.classifier = nn.Sequential(
            nn.Linear(d_model, 32),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(32, 2)
        )
        
        self.init_weights()
    
    def init_weights(self):
        initrange = 0.02
        nn.init.uniform_(self.embedding.weight.data, -initrange, initrange)
        for layer in self.classifier:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight.data)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias.data)
    
    def forward(self, x, attention_mask=None):
        x = self.embedding(x)
        x = self.pos_encoder(x)
        
        if attention_mask is not None:
            attention_mask = attention_mask.bool()
            x = x.masked_fill(~attention_mask.unsqueeze(-1), 0)
        
        x = self.transformer_encoder(x)
        
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()
            x = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        else:
            x = x.mean(dim=1)
        
        x = self.classifier(x)
        return x

# PART 3: Training Procedure
def train_model(batch_size=64, num_epochs=20, learning_rate=0.0001):
    # Prepare data
    train_loader, val_loader, tokenizer = prepare_data(batch_size=batch_size)
    
    # Safe CUDA initialization
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Initialize model and move to device
    model = TransformerClassifier(vocab_size=tokenizer.vocab_size)
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    
    best_val_accuracy = 0
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        model.train()
        
        train_loss = 0
        train_correct = 0
        train_total = 0
        
        for batch in tqdm(train_loader, desc="Training"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()
            
            train_loss += loss.item()
            predictions = torch.argmax(outputs, dim=1)
            train_correct += (predictions == labels).sum().item()
            train_total += labels.size(0)
            
            # Clear memory periodically
            if device.type == 'cuda':
                del outputs, loss
                gc.collect()
        
        avg_train_loss = train_loss / len(train_loader)
        train_accuracy = train_correct / train_total
        print(f"Training Loss: {avg_train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}")
        
        # Validation phase (optional)
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                
                outputs = model(input_ids, attention_mask)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                predictions = torch.argmax(outputs, dim=1)
                val_correct += (predictions == labels).sum().item()
                val_total += labels.size(0)
        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = val_correct / val_total
        print(f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")
    save_path = "transformer_classifier.pth"
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")    

if __name__ == "__main__":
    train_model()
    





Using device: cuda

Epoch 1/20


Training: 100%|██████████| 313/313 [01:17<00:00,  4.02it/s]


Training Loss: 0.7053, Training Accuracy: 0.5078


Validation: 100%|██████████| 79/79 [00:06<00:00, 11.99it/s]


Validation Loss: 0.6911, Validation Accuracy: 0.4972

Epoch 2/20


Training: 100%|██████████| 313/313 [01:19<00:00,  3.92it/s]


Training Loss: 0.6942, Training Accuracy: 0.5071


Validation: 100%|██████████| 79/79 [00:06<00:00, 11.61it/s]


Validation Loss: 0.6906, Validation Accuracy: 0.5274

Epoch 3/20


Training: 100%|██████████| 313/313 [01:18<00:00,  3.98it/s]


Training Loss: 0.6930, Training Accuracy: 0.5152


Validation: 100%|██████████| 79/79 [00:06<00:00, 11.59it/s]


Validation Loss: 0.6897, Validation Accuracy: 0.5342

Epoch 4/20


Training: 100%|██████████| 313/313 [01:25<00:00,  3.64it/s]


Training Loss: 0.6920, Training Accuracy: 0.5185


Validation: 100%|██████████| 79/79 [00:06<00:00, 11.43it/s]


Validation Loss: 0.6905, Validation Accuracy: 0.5186

Epoch 5/20


Training: 100%|██████████| 313/313 [01:19<00:00,  3.92it/s]


Training Loss: 0.6918, Training Accuracy: 0.5305


Validation: 100%|██████████| 79/79 [00:06<00:00, 12.71it/s]


Validation Loss: 0.6763, Validation Accuracy: 0.5274

Epoch 6/20


Training: 100%|██████████| 313/313 [01:18<00:00,  3.99it/s]


Training Loss: 0.5897, Training Accuracy: 0.6842


Validation: 100%|██████████| 79/79 [00:06<00:00, 12.33it/s]


Validation Loss: 0.4640, Validation Accuracy: 0.7848

Epoch 7/20


Training: 100%|██████████| 313/313 [01:16<00:00,  4.09it/s]


Training Loss: 0.5045, Training Accuracy: 0.7584


Validation: 100%|██████████| 79/79 [00:06<00:00, 11.92it/s]


Validation Loss: 0.4271, Validation Accuracy: 0.8036

Epoch 8/20


Training: 100%|██████████| 313/313 [01:25<00:00,  3.66it/s]


Training Loss: 0.4706, Training Accuracy: 0.7800


Validation: 100%|██████████| 79/79 [00:06<00:00, 12.55it/s]


Validation Loss: 0.4649, Validation Accuracy: 0.7674

Epoch 9/20


Training: 100%|██████████| 313/313 [01:16<00:00,  4.10it/s]


Training Loss: 0.4420, Training Accuracy: 0.7959


Validation: 100%|██████████| 79/79 [00:06<00:00, 12.98it/s]


Validation Loss: 0.3651, Validation Accuracy: 0.8356

Epoch 10/20


Training: 100%|██████████| 313/313 [01:15<00:00,  4.15it/s]


Training Loss: 0.4061, Training Accuracy: 0.8183


Validation: 100%|██████████| 79/79 [00:06<00:00, 12.54it/s]


Validation Loss: 0.3587, Validation Accuracy: 0.8462

Epoch 11/20


Training: 100%|██████████| 313/313 [01:19<00:00,  3.94it/s]


Training Loss: 0.3869, Training Accuracy: 0.8296


Validation: 100%|██████████| 79/79 [00:05<00:00, 13.36it/s]


Validation Loss: 0.3312, Validation Accuracy: 0.8576

Epoch 12/20


Training: 100%|██████████| 313/313 [01:23<00:00,  3.73it/s]


Training Loss: 0.3609, Training Accuracy: 0.8436


Validation: 100%|██████████| 79/79 [00:06<00:00, 12.88it/s]


Validation Loss: 0.3369, Validation Accuracy: 0.8614

Epoch 13/20


Training: 100%|██████████| 313/313 [01:18<00:00,  4.01it/s]


Training Loss: 0.3567, Training Accuracy: 0.8469


Validation: 100%|██████████| 79/79 [00:06<00:00, 13.02it/s]


Validation Loss: 0.4113, Validation Accuracy: 0.8292

Epoch 14/20


Training: 100%|██████████| 313/313 [01:14<00:00,  4.19it/s]


Training Loss: 0.3594, Training Accuracy: 0.8453


Validation: 100%|██████████| 79/79 [00:06<00:00, 12.92it/s]


Validation Loss: 0.4946, Validation Accuracy: 0.8058

Epoch 15/20


Training: 100%|██████████| 313/313 [01:23<00:00,  3.76it/s]


Training Loss: 0.3347, Training Accuracy: 0.8580


Validation: 100%|██████████| 79/79 [00:06<00:00, 11.47it/s]


Validation Loss: 0.3673, Validation Accuracy: 0.8560

Epoch 16/20


Training: 100%|██████████| 313/313 [01:22<00:00,  3.79it/s]


Training Loss: 0.3097, Training Accuracy: 0.8705


Validation: 100%|██████████| 79/79 [00:06<00:00, 11.32it/s]


Validation Loss: 0.3372, Validation Accuracy: 0.8686

Epoch 17/20


Training: 100%|██████████| 313/313 [01:14<00:00,  4.22it/s]


Training Loss: 0.3116, Training Accuracy: 0.8738


Validation: 100%|██████████| 79/79 [00:06<00:00, 13.07it/s]


Validation Loss: 0.4740, Validation Accuracy: 0.8018

Epoch 18/20


Training: 100%|██████████| 313/313 [01:14<00:00,  4.21it/s]


Training Loss: 0.2948, Training Accuracy: 0.8779


Validation: 100%|██████████| 79/79 [00:06<00:00, 13.08it/s]


Validation Loss: 0.3441, Validation Accuracy: 0.8552

Epoch 19/20


Training: 100%|██████████| 313/313 [01:14<00:00,  4.22it/s]


Training Loss: 0.2817, Training Accuracy: 0.8849


Validation: 100%|██████████| 79/79 [00:06<00:00, 12.91it/s]


Validation Loss: 0.3144, Validation Accuracy: 0.8722

Epoch 20/20


Training: 100%|██████████| 313/313 [01:13<00:00,  4.23it/s]


Training Loss: 0.2663, Training Accuracy: 0.8918


Validation: 100%|██████████| 79/79 [00:06<00:00, 13.17it/s]

Validation Loss: 0.4746, Validation Accuracy: 0.8210
Model saved to transformer_classifier.pth





In [13]:
import torch
from transformers import AutoTokenizer

# Load the tokenizer and model configuration
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model and load the saved state dictionary
model = TransformerClassifier(vocab_size=tokenizer.vocab_size)
model.load_state_dict(torch.load("transformer_classifier.pth", map_location=device))
model.to(device)
model.eval()

def predict(text):
    # Tokenize the input text
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    # Run inference
    with torch.no_grad():
        output = model(input_ids, attention_mask)
        prediction = torch.argmax(output, dim=1).item()

    # Map prediction to label (0 or 1 for binary classification in IMDB dataset)
    label_map = {0: "Negative", 1: "Positive"}
    return label_map[prediction]

# Example usage
text = "The actors did amazing acting"
print("Prediction:", predict(text))


Prediction: Positive
