In [7]:
pip install torch transformers datasets wandb tqdm scikit-learn psutil

Note: you may need to restart the kernel to use updated packages.


In [5]:
pip install --upgrade pandas scikit-learn transformers


Note: you may need to restart the kernel to use updated packages.


In [4]:
import copy
import gc
import logging
import math
import os
from typing import Dict, Any

import numpy as np
import psutil
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import wandb
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from torch.quantization import quantize_dynamic
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig

# Send INFO-and-above records through the root logger; code in this cell logs via `logger`.
logging.basicConfig(
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

class IMDBDataset(Dataset):
    """Map-style dataset that tokenizes one IMDB review per lookup.

    Args:
        split: Key used to select a split from ``load_dataset('imdb')``.
        max_length: Length every encoded review is padded or truncated to.
    """

    def __init__(self, split='train', max_length=256):
        self.dataset = load_dataset('imdb')[split]
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.max_length = max_length

    def __len__(self):
        """Return the number of examples in the selected split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and return the tensors the model consumes.

        Returns:
            Dict with ``'input_ids'`` and ``'attention_mask'`` flattened to one
            dimension, and ``'labels'`` holding the example's label as a tensor.
        """
        example = self.dataset[idx]
        tokenized = self.tokenizer(
            example['text'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # The tokenizer is asked for 'pt' tensors; flatten them so each field is 1-D.
        sample = {
            field: tokenized[field].reshape(-1)
            for field in ('input_ids', 'attention_mask')
        }
        # The key is 'labels' (plural) because batches are later passed as model(**batch).
        sample['labels'] = torch.tensor(example['label'])
        return sample

class EarlyStopping:
    """Signal when validation loss has stopped improving.

    Call the instance once per epoch with the validation loss. An epoch counts
    as an improvement when the loss is at most ``best_loss - min_delta``; any
    other epoch increments ``counter``. When ``counter`` reaches ``patience``,
    ``should_stop`` becomes True and stays True.

    Args:
        patience: Number of non-improving epochs tolerated before stopping.
        min_delta: Minimum decrease in loss that counts as an improvement.
    """

    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.should_stop = False

    def __call__(self, val_loss):
        """Record one epoch's validation loss and update ``should_stop``."""
        # Every ordering comparison with NaN is False, so without this check a
        # NaN loss would fall into the "improved" branch, reset the counter and
        # overwrite best_loss with NaN. Count it as a non-improving epoch instead.
        loss_is_nan = math.isnan(val_loss)

        if self.best_loss is None and not loss_is_nan:
            # First usable observation becomes the baseline.
            self.best_loss = val_loss
        elif loss_is_nan or val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

class MemoryEfficientTrainer:
    """Train and evaluate a classification model using gradient accumulation.

    The model is invoked as ``model(**batch)`` and must return an object that
    exposes ``loss`` and ``logits``. Batches are dicts of tensors that include
    a ``'labels'`` entry.

    Args:
        model: Module to train; moved to CUDA when available, otherwise CPU.
        train_loader: Iterable of training batches; must support ``len()``.
        val_loader: Iterable of validation batches; must support ``len()``.
        epochs: Maximum number of epochs; also the length of the cosine schedule.
        accumulation_steps: Number of batches whose gradients are accumulated
            before each optimizer step.
        early_stopping_patience: Forwarded to ``EarlyStopping`` as ``patience``.
        checkpoint_path: File that receives the best-accuracy checkpoint.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
    """

    def __init__(self, model, train_loader, val_loader, epochs=3,
                 accumulation_steps=4, early_stopping_patience=3,
                 checkpoint_path='best_model.pt'):
        if accumulation_steps < 1:
            raise ValueError("accumulation_steps must be >= 1")

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epochs = epochs
        self.accumulation_steps = accumulation_steps
        self.early_stopping = EarlyStopping(patience=early_stopping_patience)
        self.checkpoint_path = checkpoint_path

        # AdamW with decoupled weight decay. Gradient clipping is not an
        # optimizer option; it is applied in train_epoch() before each step.
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=2e-5,
            weight_decay=0.01,
            eps=1e-8
        )

        # Cosine annealing over `epochs` scheduler steps. train() steps it once
        # per epoch; there is no warmup phase.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=epochs
        )

    def log_memory(self):
        """Log the current process's resident set size (MB) to the logger and wandb."""
        memory = psutil.Process().memory_info().rss / 1024 / 1024  # MB
        logger.info(f"Memory Usage: {memory:.2f} MB")
        wandb.log({'memory_usage_mb': memory})

    def train_epoch(self):
        """Run one pass over ``train_loader``.

        Returns:
            Mean per-batch training loss, unscaled so it is directly comparable
            with the loss returned by ``validate()``.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = len(self.train_loader)
        self.optimizer.zero_grad()

        for i, batch in enumerate(tqdm(self.train_loader, desc="Training")):
            # Move batch to device
            batch = {k: v.to(self.device) for k, v in batch.items()}

            # Clear memory periodically
            if i % 50 == 0:
                gc.collect()
                if self.device.type == 'cuda':
                    torch.cuda.empty_cache()
                self.log_memory()

            loss = self.process_batch(batch)
            # process_batch() returns the loss divided by accumulation_steps;
            # undo that scaling so the reported value is the real batch loss.
            total_loss += loss.item() * self.accumulation_steps

            # Step after every full accumulation window, and also after the
            # final batch so a trailing partial window is not thrown away by
            # the next epoch's zero_grad().
            is_last_batch = (i + 1) == num_batches
            if (i + 1) % self.accumulation_steps == 0 or is_last_batch:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                self.optimizer.zero_grad()

        return total_loss / num_batches

    def process_batch(self, batch):
        """Forward and backward one batch.

        Returns:
            The batch loss divided by ``accumulation_steps`` (the value that was
            back-propagated).
        """
        outputs = self.model(**batch)
        loss = outputs.loss / self.accumulation_steps
        loss.backward()
        return loss

    def validate(self):
        """Evaluate on ``val_loader`` without tracking gradients.

        Returns:
            Tuple ``(mean per-batch loss, accuracy)``.
        """
        self.model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validating"):
                batch = {k: v.to(self.device) for k, v in batch.items()}
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()

                preds = torch.argmax(outputs.logits, dim=-1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(batch['labels'].cpu().numpy())

        accuracy = accuracy_score(all_labels, all_preds)
        avg_loss = total_loss / len(self.val_loader)

        return avg_loss, accuracy

    def train(self):
        """Train for up to ``epochs`` epochs with checkpointing and early stopping.

        After each epoch the metrics are logged, the model is checkpointed to
        ``checkpoint_path`` when validation accuracy improves, and the loop
        exits early once ``EarlyStopping`` reports that it should stop.
        """
        best_accuracy = 0

        for epoch in range(self.epochs):
            logger.info(f"Epoch {epoch + 1}/{self.epochs}")

            train_loss = self.train_epoch()
            val_loss, accuracy = self.validate()

            # Read the learning rate before stepping the scheduler so the value
            # logged for this epoch is the one that was actually used.
            epoch_lr = self.scheduler.get_last_lr()[0]
            self.scheduler.step()

            metrics = {
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'accuracy': accuracy,
                'lr': epoch_lr
            }

            wandb.log(metrics)
            logger.info(f"Metrics: {metrics}")

            # Save best model
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                self.save_checkpoint(self.checkpoint_path, metrics)

            # Early stopping check
            self.early_stopping(val_loss)
            if self.early_stopping.should_stop:
                logger.info("Early stopping triggered")
                break

    def save_checkpoint(self, filename: str, metrics: Dict[str, Any]):
        """Write model weights, optimizer state and ``metrics`` to ``filename``."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'metrics': metrics
        }, filename)

def main():
    """Fine-tune BERT on IMDB, then produce student, pruned and quantized variants.

    Stages, in order:
      1. Fine-tune ``bert-base-uncased`` (base model).
      2. Train a 4-layer BERT from a fresh config (student model).
      3. Prune 30% of every ``nn.Linear`` weight in a *copy* of the base model
         and fine-tune it for one epoch.
      4. Dynamically quantize the base model's ``nn.Linear`` layers to int8.

    Each artifact is written to disk as soon as its stage finishes, so an
    interruption in a later stage does not lose earlier results.
    """
    # Initialize wandb
    wandb.init(project="model-compression-t2xlarge")

    # Seed torch so weight initialisation and DataLoader shuffling are repeatable.
    torch.manual_seed(42)

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    # Set batch sizes based on available memory
    BATCH_SIZE = 16

    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only hosts.
    use_pinned_memory = device.type == 'cuda'

    # Load datasets
    # NOTE(review): the IMDB 'test' split is used for validation, checkpoint
    # selection and early stopping, so reported accuracy is not a held-out estimate.
    train_dataset = IMDBDataset('train', max_length=256)
    val_dataset = IMDBDataset('test', max_length=256)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=3,
        pin_memory=use_pinned_memory
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        num_workers=3,
        pin_memory=use_pinned_memory
    )

    try:
        # NOTE(review): every MemoryEfficientTrainer run below writes its best
        # checkpoint to the same default file ('best_model.pt'), so each stage
        # overwrites the previous stage's checkpoint.

        # Train base model
        logger.info("Training base model...")
        base_model = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased',
            num_labels=2,
            problem_type="single_label_classification"
        )

        trainer = MemoryEfficientTrainer(
            base_model,
            train_loader,
            val_loader,
            epochs=3,
            accumulation_steps=4
        )
        trainer.train()
        torch.save(base_model.state_dict(), 'base_model.pth')

        # Knowledge Distillation
        # NOTE(review): the student is trained from a fresh config on the hard
        # labels only; no teacher logits are used, so this is a small-model
        # baseline rather than distillation in the usual sense.
        logger.info("Training distilled model...")
        small_config = BertConfig(
            hidden_size=768,
            num_hidden_layers=4,
            num_attention_heads=12,
            intermediate_size=2048,
            num_labels=2,
            problem_type="single_label_classification"
        )

        student_model = BertForSequenceClassification(small_config)
        trainer = MemoryEfficientTrainer(
            student_model,
            train_loader,
            val_loader,
            epochs=3,
            accumulation_steps=4
        )
        trainer.train()
        torch.save(student_model.state_dict(), 'distilled_model.pth')

        # Model Pruning
        logger.info("Pruning model...")
        # Prune a deep copy: pruning mutates modules in place, and the dense
        # base model is still needed for quantization and as its own artifact.
        pruned_model = copy.deepcopy(base_model)
        for module in pruned_model.modules():
            if isinstance(module, nn.Linear):
                prune.l1_unstructured(module, name='weight', amount=0.3)

        trainer = MemoryEfficientTrainer(
            pruned_model,
            train_loader,
            val_loader,
            epochs=1,
            accumulation_steps=4
        )
        trainer.train()

        # Fold the masks into the weights so the saved state dict uses plain
        # 'weight' keys instead of 'weight_orig' / 'weight_mask'.
        for module in pruned_model.modules():
            if isinstance(module, nn.Linear):
                prune.remove(module, 'weight')
        torch.save(pruned_model.state_dict(), 'pruned_model.pth')

        # Quantization
        logger.info("Quantizing model...")
        quantized_model = quantize_dynamic(
            base_model.cpu(),  # Move to CPU for quantization
            {nn.Linear},
            dtype=torch.qint8
        )
        torch.save(quantized_model.state_dict(), 'quantized_model.pth')

        logger.info("Saving models...")

    except Exception as e:
        logger.error(f"An error occurred: {str(e)}")
        raise
    finally:
        # Cleanup
        gc.collect()
        torch.cuda.empty_cache()
        # Close the wandb run so it is flushed and marked finished even on failure.
        wandb.finish()

if __name__ == "__main__":
    main()

INFO:__main__:Using device: cpu
INFO:__main__:Training base model...
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:__main__:Epoch 1/3
INFO:__main__:Memory Usage: 773.18 MB:00<?, ?it/s]
INFO:__main__:Memory Usage: 8336.94 MB:17<4:04:38,  9.70s/it]
INFO:__main__:Memory Usage: 7908.70 MB6:19<4:00:10,  9.85s/it]
INFO:__main__:Memory Usage: 7969.35 MB4:28<3:52:24,  9.87s/it]
INFO:__main__:Memory Usage: 8330.05 MB2:35<3:42:00,  9.77s/it]
INFO:__main__:Memory Usage: 8402.38 MB0:38<3:32:38,  9.72s/it]
INFO:__main__:Memory Usage: 8125.68 MB8:39<3:22:49,  9.64s/it]
INFO:__main__:Memory Usage: 8483.95 MB6:41<3:14:09,  9.60s/it]
INFO:__main__:Memory Usage: 8246.66 MB:04:38<3:08:54,  9.75s/it]
INFO:__main__:Memory Usage: 8459.96 MB:34:04<11:11:32, 36.

KeyboardInterrupt: 