In [2]:
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from transformers import (
    BertTokenizer, 
    BertForSequenceClassification,
    BertConfig,
    AdamW,
    get_linear_schedule_with_warmup
)
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import wandb
import numpy as np
from tqdm.auto import tqdm
from dataclasses import dataclass
from typing import Optional, Dict, Any, List
import logging
import gc
import psutil
from sklearn.metrics import accuracy_score, f1_score
import torch.nn.utils.prune as prune
from dataclasses import asdict

# Root logging setup: INFO and above, each record prefixed with a timestamp
# and its level. `logger` is the module-level logger used by the code below.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

@dataclass
class TrainingConfig:
    """Hyperparameters shared by every training stage in this notebook.

    All fields have defaults, so ``TrainingConfig()`` gives the settings that
    ``main()`` uses; override individual values by keyword.
    """
    # Examples per DataLoader batch (one micro-batch).
    batch_size: int = 32
    # Micro-batches whose gradients are accumulated before one optimizer step.
    accumulation_steps: int = 2
    # Maximum number of passes over the training loader per trainer.
    epochs: int = 3
    # Token length every review is truncated/padded to by the tokenizer.
    max_length: int = 128  # Reduced from 256
    # Peak learning rate handed to AdamW.
    learning_rate: float = 2e-5
    # Optimizer steps of linear warmup before the linear decay starts.
    # NOTE(review): the logged runs have 782 batches/epoch, i.e. roughly
    # 1173 optimizer steps in total, so 1000 warmup steps cover most of the
    # schedule -- confirm this is intended.
    warmup_steps: int = 1000
    # Weight decay applied to parameters other than biases / LayerNorm weights.
    weight_decay: float = 0.01
    # Gradient-norm clipping threshold used before each optimizer step.
    max_grad_norm: float = 1.0
    # Consecutive epochs without validation-loss improvement before stopping.
    early_stopping_patience: int = 2
    # Worker processes used by each DataLoader.
    num_workers: int = 4

class IMDBDataset(Dataset):
    """Map-style dataset that tokenizes IMDB reviews on access.

    Every example is encoded with the ``bert-base-uncased`` tokenizer and
    truncated or padded to ``config.max_length`` tokens.
    """

    def __init__(self, split: str, config: TrainingConfig):
        """Load one split of the ``imdb`` dataset and the BERT tokenizer.

        Args:
            split: Key of the split to select from the loaded dataset
                (``main()`` passes ``'train'`` and ``'test'``).
            config: Supplies ``max_length`` for tokenization.
        """
        imdb_splits = load_dataset('imdb')
        self.dataset = imdb_splits[split]
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.max_length = config.max_length

    def __len__(self):
        """Return the number of examples in the selected split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and return the tensors the model consumes.

        Returns:
            A dict with flattened (1-D) ``input_ids`` and ``attention_mask``
            tensors and a ``labels`` tensor built from the example's label.
        """
        example = self.dataset[idx]
        tokens = self.tokenizer(
            example['text'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # The tokenizer returns batched tensors; flatten drops the batch axis.
        input_ids = tokens['input_ids'].flatten()
        attention_mask = tokens['attention_mask'].flatten()
        label = torch.tensor(example['label'])

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': label,
        }

class EfficientTrainer:
    """Trainer with mixed precision, gradient accumulation and early stopping.

    The trainer owns the optimizer (AdamW with weight decay disabled for bias
    and LayerNorm weights), a linear warmup/decay learning-rate schedule and a
    gradient scaler. Mixed precision is only enabled when the model runs on
    CUDA; on CPU the scaler and autocast are switched off.

    Args:
        model: Model whose forward returns an object with ``loss`` and
            ``logits`` when called with the batch dict as keyword arguments.
        train_loader: Loader yielding dicts of tensors for training.
        val_loader: Loader yielding dicts of tensors (including ``labels``)
            for validation.
        config: Hyperparameters (see ``TrainingConfig``).
        checkpoint_path: File the best checkpoint is written to. Defaults to
            ``'best_model.pt'``; pass a distinct path per trainer to avoid
            one run overwriting another's checkpoint.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: "TrainingConfig",
        checkpoint_path: str = 'best_model.pt',
    ):
        self.config = config
        self.checkpoint_path = checkpoint_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader

        # fp16 autocast / loss scaling only make sense on CUDA.
        self.use_amp = self.device.type == 'cuda'

        # Initialize optimizer with weight decay (none for bias / LayerNorm)
        no_decay = ['bias', 'LayerNorm.weight']
        decay_params = []
        no_decay_params = []
        for name, param in self.model.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            {'params': decay_params, 'weight_decay': config.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]

        self.optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=config.learning_rate,
            eps=1e-8
        )

        # Calculate total steps for scheduler. A trailing partial accumulation
        # window also produces an optimizer step, hence the ceiling division.
        num_update_steps_per_epoch = self._updates_per_epoch(
            len(train_loader), config.accumulation_steps
        )
        num_training_steps = num_update_steps_per_epoch * config.epochs

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=config.warmup_steps,
            num_training_steps=num_training_steps
        )

        self.scaler = GradScaler(enabled=self.use_amp)
        self.best_val_loss = float('inf')
        self.patience_counter = 0

    @staticmethod
    def _updates_per_epoch(num_batches: int, accumulation_steps: int) -> int:
        """Return how many optimizer steps one epoch of ``num_batches`` takes."""
        accumulation_steps = max(1, accumulation_steps)
        return (num_batches + accumulation_steps - 1) // accumulation_steps

    @staticmethod
    def _window_size(batch_index: int, num_batches: int, accumulation_steps: int) -> int:
        """Return the size of the accumulation window containing ``batch_index``.

        Every window holds ``accumulation_steps`` batches except possibly the
        last one of the epoch, which holds the remainder.
        """
        accumulation_steps = max(1, accumulation_steps)
        window_start = (batch_index // accumulation_steps) * accumulation_steps
        return min(accumulation_steps, num_batches - window_start)

    def train_epoch(self):
        """Run one pass over ``train_loader`` and return the mean batch loss.

        Gradients are accumulated over ``config.accumulation_steps`` batches.
        The final window of the epoch is flushed even when it is shorter, so
        no batch's gradient is discarded.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = len(self.train_loader)
        accumulation_steps = max(1, self.config.accumulation_steps)
        self.optimizer.zero_grad()

        with tqdm(total=num_batches, desc="Training") as pbar:
            for i, batch in enumerate(self.train_loader):
                batch = {k: v.to(self.device) for k, v in batch.items()}

                # Divide by the real window length so a short final window is
                # averaged like the full ones.
                window = self._window_size(i, num_batches, accumulation_steps)

                # Automatic mixed precision training (no-op when disabled)
                with torch.autocast(device_type=self.device.type, enabled=self.use_amp):
                    outputs = self.model(**batch)
                    loss = outputs.loss / window

                # Scale loss and backward pass
                self.scaler.scale(loss).backward()

                is_last_batch = (i + 1) == num_batches
                if (i + 1) % accumulation_steps == 0 or is_last_batch:
                    # Unscale gradients for clipping
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config.max_grad_norm
                    )

                    # Optimizer step with scaling
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.scheduler.step()
                    self.optimizer.zero_grad()

                # Track and display the un-divided per-batch loss.
                batch_loss = outputs.loss.item()
                total_loss += batch_loss
                pbar.update(1)
                pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

                # Periodic memory cleanup
                if i % 50 == 0:
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

        return total_loss / max(1, num_batches)

    @torch.no_grad()
    def validate(self):
        """Evaluate on ``val_loader`` without gradient tracking.

        Returns:
            Dict with ``val_loss`` (mean per-batch loss), ``accuracy`` and
            weighted ``f1`` computed over all validation predictions.
        """
        self.model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []

        for batch in tqdm(self.val_loader, desc="Validating"):
            batch = {k: v.to(self.device) for k, v in batch.items()}

            outputs = self.model(**batch)
            total_loss += outputs.loss.item()

            # Predicted class = index of the largest logit.
            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch['labels'].cpu().numpy())

        metrics = {
            'val_loss': total_loss / len(self.val_loader),
            'accuracy': accuracy_score(all_labels, all_preds),
            'f1': f1_score(all_labels, all_preds, average='weighted')
        }

        return metrics

    def train(self):
        """Train for up to ``config.epochs`` epochs with early stopping.

        After each epoch the metrics are logged to wandb and the logger. A
        checkpoint is written to ``checkpoint_path`` whenever the validation
        loss improves; training stops once it has failed to improve for
        ``config.early_stopping_patience`` consecutive epochs.
        """
        for epoch in range(self.config.epochs):
            logger.info(f"Epoch {epoch + 1}/{self.config.epochs}")

            # Training phase
            train_loss = self.train_epoch()

            # Validation phase
            val_metrics = self.validate()

            # Log metrics
            metrics = {
                'epoch': epoch + 1,
                'train_loss': train_loss,
                **val_metrics,
                'learning_rate': self.scheduler.get_last_lr()[0]
            }
            wandb.log(metrics)
            logger.info(f"Metrics: {metrics}")

            # Early stopping check
            if val_metrics['val_loss'] < self.best_val_loss:
                self.best_val_loss = val_metrics['val_loss']
                self.save_checkpoint(self.checkpoint_path, metrics)
                self.patience_counter = 0
            else:
                self.patience_counter += 1
                if self.patience_counter >= self.config.early_stopping_patience:
                    logger.info("Early stopping triggered")
                    break

    def save_checkpoint(self, filename: str, metrics: Dict[str, Any]):
        """Write model, optimizer and scheduler state plus ``metrics`` to ``filename``."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'metrics': metrics
        }, filename)

def create_small_config() -> BertConfig:
    """Build the BertConfig of the compact two-label student classifier."""
    student_settings = {
        'hidden_size': 384,  # Half of BERT-base
        'num_hidden_layers': 4,  # Third of BERT-base
        'num_attention_heads': 6,
        'intermediate_size': 1536,
        'hidden_dropout_prob': 0.1,
        'attention_probs_dropout_prob': 0.1,
        'num_labels': 2,
    }
    return BertConfig(**student_settings)

def main():
    """Run the full compression pipeline end to end.

    Stages, in order:
      1. fine-tune a pretrained ``bert-base-uncased`` classifier ("teacher");
      2. train a smaller BERT classifier built from ``create_small_config()``
         (no pretrained weights are loaded for it) ("student");
      3. apply 30% L1 unstructured pruning to every ``nn.Linear`` of the
         student, fine-tune it again, then make the pruning permanent;
      4. dynamically quantize the student's linear layers to int8.

    Side effects: starts and finishes a Weights & Biases run and writes
    ``teacher_model.pt``, ``compressed_model.pt`` and ``quantized_model.pt``
    to the working directory (each training stage additionally writes its
    best checkpoint through ``EfficientTrainer``).

    Raises:
        Exception: any error raised by a stage is logged and re-raised.
    """
    # Initialize wandb
    wandb.init(project="efficient-bert-compression")
    exit_code = 0

    try:
        # Configuration
        config = TrainingConfig()
        wandb.config.update(asdict(config))

        # Load and prepare datasets
        logger.info("Loading datasets...")
        train_dataset = IMDBDataset('train', config)
        # NOTE(review): the 'test' split is used for validation and early
        # stopping, so the reported metrics are not from held-out data.
        val_dataset = IMDBDataset('test', config)

        loader_kwargs = dict(
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            # Pinned host memory only helps when batches are copied to a GPU.
            pin_memory=torch.cuda.is_available(),
        )
        train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
        val_loader = DataLoader(val_dataset, **loader_kwargs)

        def run_stage(model):
            """Train ``model`` on the shared loaders with a fresh trainer."""
            EfficientTrainer(model, train_loader, val_loader, config).train()

        # Train teacher model
        logger.info("Training teacher model...")
        teacher_model = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased',
            num_labels=2
        )
        run_stage(teacher_model)

        # Persist the teacher now: the later stages write their checkpoints
        # to the same default file, which would otherwise be the only copy.
        torch.save(teacher_model.state_dict(), 'teacher_model.pt')

        # The teacher is not used again, so release its memory before the
        # student is trained.
        # NOTE(review): the student below learns from the labels only; the
        # teacher's outputs are never consumed, so no distillation happens.
        del teacher_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Train smaller student model
        logger.info("Training student model...")
        student_config = create_small_config()
        student_model = BertForSequenceClassification(student_config)
        run_stage(student_model)

        # Apply pruning to student model
        logger.info("Applying pruning...")
        pruned_layers = [
            module for module in student_model.modules()
            if isinstance(module, nn.Linear)
        ]
        for layer in pruned_layers:
            prune.l1_unstructured(layer, name='weight', amount=0.3)

        # Fine-tune pruned model (the masks keep pruned weights at zero)
        run_stage(student_model)

        # Make the pruning permanent: this folds each mask into a plain
        # `weight` parameter, so the saved state dict has the standard keys
        # instead of `weight_orig` / `weight_mask` pairs.
        for layer in pruned_layers:
            prune.remove(layer, 'weight')

        # Quantization
        logger.info("Quantizing model...")
        student_model.cpu()
        quantized_model = torch.quantization.quantize_dynamic(
            student_model,
            {nn.Linear},
            dtype=torch.qint8
        )

        # Save final models
        torch.save(student_model.state_dict(), 'compressed_model.pt')
        torch.save(quantized_model.state_dict(), 'quantized_model.pt')

    except Exception as e:
        exit_code = 1
        logger.error(f"An error occurred: {str(e)}")
        raise
    finally:
        # Close the wandb run so a later cell / re-run starts a fresh one.
        wandb.finish(exit_code=exit_code)

        # Cleanup
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Entry point: run the whole pipeline when this code is executed as the main
# program (not when it is imported).
if __name__ == "__main__":
    main()

2025-01-19 15:59:05,951 - INFO - Loading datasets...
2025-01-19 15:59:08,677 - INFO - Training teacher model...
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  self.scaler = GradScaler()
2025-01-19 15:59:09,226 - INFO - Epoch 1/3


Training:   0%|          | 0/782 [00:00<?, ?it/s]

  with autocast():


Validating:   0%|          | 0/782 [00:00<?, ?it/s]

2025-01-19 18:32:08,512 - INFO - Metrics: {'epoch': 1, 'train_loss': 0.5162448072258163, 'val_loss': 0.3262274710656813, 'accuracy': 0.86184, 'f1': 0.8615977149527893, 'learning_rate': 7.820000000000001e-06}
2025-01-19 18:32:18,356 - INFO - Epoch 2/3


Training:   0%|          | 0/782 [00:00<?, ?it/s]

  with autocast():


Validating:   0%|          | 0/782 [00:00<?, ?it/s]

2025-01-19 21:06:18,722 - INFO - Metrics: {'epoch': 2, 'train_loss': 0.2970118322278685, 'val_loss': 0.3012382892767906, 'accuracy': 0.87356, 'f1': 0.8731432015356452, 'learning_rate': 1.5640000000000003e-05}
2025-01-19 21:06:50,077 - INFO - Epoch 3/3


Training:   0%|          | 0/782 [00:00<?, ?it/s]

  with autocast():


Validating:   0%|          | 0/782 [00:00<?, ?it/s]

2025-01-20 02:51:13,414 - INFO - Metrics: {'epoch': 3, 'train_loss': 0.220952901600972, 'val_loss': 0.27243215005006405, 'accuracy': 0.89084, 'f1': 0.8908355843430524, 'learning_rate': 0.0}
2025-01-20 02:51:44,034 - INFO - Training student model...
  self.scaler = GradScaler()
2025-01-20 02:51:44,436 - INFO - Epoch 1/3


Training:   0%|          | 0/782 [00:00<?, ?it/s]

  with autocast():


Validating:   0%|          | 0/782 [00:00<?, ?it/s]

2025-01-20 03:58:40,538 - INFO - Metrics: {'epoch': 1, 'train_loss': 0.6945401575711682, 'val_loss': 0.6919617516457882, 'accuracy': 0.5006, 'f1': 0.3348068834040917, 'learning_rate': 7.820000000000001e-06}
2025-01-20 03:58:41,987 - INFO - Epoch 2/3


Training:   0%|          | 0/782 [00:00<?, ?it/s]

  with autocast():
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7316935a8180>Exception ignored in: 
Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7316935a8180>Exception ignored in:   File "/home/ubuntu/jupyter_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__

<function _MultiProcessingDataLoaderIter.__del__ at 0x7316935a8180>
Traceback (most recent call last):
      File "/home/ubuntu/jupyter_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()Traceback (most recent call last):
  File "/home/ubuntu/jupyter_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
self._shutdown_workers()
    
Exception ignored in:   File "/home/ubuntu/jupyter_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
  File "/home/ubuntu/jupyter_env/lib/python3

Validating:   0%|          | 0/782 [00:00<?, ?it/s]

2025-01-20 05:06:42,617 - INFO - Metrics: {'epoch': 2, 'train_loss': 0.6752846569508848, 'val_loss': 0.5729781811118431, 'accuracy': 0.70224, 'f1': 0.7013948758333788, 'learning_rate': 1.5640000000000003e-05}
2025-01-20 05:06:43,849 - INFO - Epoch 3/3


Training:   0%|          | 0/782 [00:00<?, ?it/s]

  with autocast():
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7316935a8180>
Traceback (most recent call last):
  File "/home/ubuntu/jupyter_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/jupyter_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7316935a8180>
Traceback (most recent call last):
  File "/home/ubuntu/jupyter_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/jupyte

Validating:   0%|          | 0/782 [00:00<?, ?it/s]

2025-01-20 06:14:53,802 - INFO - Metrics: {'epoch': 3, 'train_loss': 0.45204234228033546, 'val_loss': 0.435336763282185, 'accuracy': 0.79724, 'f1': 0.7971066159415434, 'learning_rate': 0.0}
2025-01-20 06:14:55,086 - INFO - Applying pruning...
  self.scaler = GradScaler()
2025-01-20 06:14:55,527 - INFO - Epoch 1/3


Training:   0%|          | 0/782 [00:00<?, ?it/s]

  with autocast():
Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7316935a8180>
Traceback (most recent call last):
  File "/home/ubuntu/jupyter_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/jupyter_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x7316935a8180>
Traceback (most recent call last):
    if w.is_alive():
  File "/home/ubuntu/jupyter_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
         Exception ignored in: self._shutdown_workers()<function _MultiProcessingDataLoaderIter.__del__ at 0x7316935a8180>  
^^
^^  File "/home/ubuntu/jupyter_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
^    ^if w.is_alive():^
Traceback (most recent call last):
^  F

Validating:   0%|          | 0/782 [00:00<?, ?it/s]

2025-01-20 07:24:35,256 - INFO - Metrics: {'epoch': 1, 'train_loss': 0.37206131815338683, 'val_loss': 0.43944935735953433, 'accuracy': 0.80208, 'f1': 0.8015950360483053, 'learning_rate': 7.820000000000001e-06}
2025-01-20 07:24:56,754 - INFO - Epoch 2/3


Training:   0%|          | 0/782 [00:00<?, ?it/s]

  with autocast():


Validating:   0%|          | 0/782 [00:00<?, ?it/s]

2025-01-20 08:33:20,380 - INFO - Metrics: {'epoch': 2, 'train_loss': 0.3575733512392282, 'val_loss': 0.4440631276887396, 'accuracy': 0.799, 'f1': 0.7981743551832109, 'learning_rate': 1.5640000000000003e-05}
2025-01-20 08:33:20,381 - INFO - Epoch 3/3


Training:   0%|          | 0/782 [00:00<?, ?it/s]

  with autocast():
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7316935a8180>
Traceback (most recent call last):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7316935a8180>  File "/home/ubuntu/jupyter_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__

Traceback (most recent call last):
  File "/home/ubuntu/jupyter_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7316935a8180>        self._shutdown_workers()self._shutdown_workers()Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7316935a8180>

  File "/home/ubuntu/jupyter_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
Traceback (most recent call last):

  File "/home/ubuntu/jupyter_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", li

Validating:   0%|          | 0/782 [00:00<?, ?it/s]

2025-01-20 09:41:46,217 - INFO - Metrics: {'epoch': 3, 'train_loss': 0.33169591484014943, 'val_loss': 0.43250956128129875, 'accuracy': 0.80532, 'f1': 0.8047632707523465, 'learning_rate': 0.0}
2025-01-20 09:42:07,687 - INFO - Quantizing model...


In [3]:
# NOTE(review): looks like a leftover scratch cell with no role in the
# pipeline -- consider removing it before sharing the notebook.
print("HO")

HO
