In [20]:
"""
PYTORCH WRAPPER CLASSES - COMPLETE TUTORIAL
============================================
Learn to write clean, reusable wrappers with examples
"""

# Part 1 banner
print("="*80)
print("PART 1: Understanding Inheritance & Wrappers")
print("="*80)

# Conceptual overview, printed as plain text (the class code below is only
# shown, not executed).
print("""
INHERITANCE BASICS:
-------------------

class Parent:
    def method(self):
        return "Parent method"

class Child(Parent):  # Child inherits from Parent
    def method(self):  # Override parent method
        return "Child method"

WRAPPER CONCEPT:
----------------
- Extend existing class functionality
- Keep original behavior
- Add custom features
- Maintain compatibility

PyTorch Common Wrappers:
------------------------
1. Dataset → Custom Dataset (your BilingualDataset)
2. nn.Module → Custom Models
3. Optimizer → Custom Optimizers
4. Loss Functions → Custom Losses
5. Transforms → Custom Transforms
""")

PART 1: Understanding Inheritance & Wrappers

INHERITANCE BASICS:
-------------------

class Parent:
    def method(self):
        return "Parent method"

class Child(Parent):  # Child inherits from Parent
    def method(self):  # Override parent method
        return "Child method"

WRAPPER CONCEPT:
----------------
- Extend existing class functionality
- Keep original behavior
- Add custom features
- Maintain compatibility

PyTorch Common Wrappers:
------------------------
1. Dataset → Custom Dataset (your BilingualDataset)
2. nn.Module → Custom Models
3. Optimizer → Custom Optimizers
4. Loss Functions → Custom Losses
5. Transforms → Custom Transforms



In [21]:
# ============================================================================
# EXAMPLE 1: Simple Wrapper - Understanding the Pattern
# ============================================================================

# Section banner: blank line, rule, title, rule
print(
    "\n" + "=" * 80,
    "EXAMPLE 1: Simple Dataset Wrapper",
    "=" * 80,
    sep="\n",
)

from torch.utils.data import Dataset
import torch

class SimpleDataset(Dataset):
    """
    Basic wrapper around Dataset class

    Pattern:
        1. Inherit from base class
        2. Initialize with super().__init__()
        3. Implement required methods
        4. Add custom functionality

    Args:
        data: indexable sequence of per-sample feature values (anything
              ``torch.tensor`` can convert to float32).
        labels: indexable sequence of integer labels, one per entry in
                ``data``.

    Raises:
        ValueError: if ``data`` and ``labels`` have different lengths.
    """

    def __init__(self, data, labels):
        """
        Step 1: Initialize parent class
        Step 2: Validate and store your custom data
        """
        super().__init__()  # Call parent __init__

        # Fail fast: a mismatch would otherwise surface later as an
        # IndexError in __getitem__, or extra labels would be ignored.
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length "
                f"(got {len(data)} and {len(labels)})"
            )

        # Custom initialization
        self.data = data
        self.labels = labels
        print(f"✓ SimpleDataset initialized with {len(data)} samples")

    def __len__(self):
        """
        Required method: Return dataset size (length of ``data``)
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Required method: Return one sample as a dict with keys
            'data'  -> float32 tensor of the features
            'label' -> torch.long scalar tensor
            'idx'   -> the requested index, unchanged (custom addition)
        """
        x = self.data[idx]
        y = self.labels[idx]

        # Convert to tensors (custom behavior)
        return {
            'data': torch.tensor(x, dtype=torch.float32),
            'label': torch.tensor(y, dtype=torch.long),
            'idx': idx,  # Extra field (custom addition)
        }



EXAMPLE 1: Simple Dataset Wrapper


In [22]:
# Test it: three toy samples with 3 features each and one label apiece
data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
labels = [0, 1, 2]

dataset = SimpleDataset(data, labels)
print(f"Dataset length: {len(dataset)}")
for sample_idx in (0, 1):
    print(f"Sample {sample_idx}: {dataset[sample_idx]}")

✓ SimpleDataset initialized with 3 samples
Dataset length: 3
Sample 0: {'data': tensor([1., 2., 3.]), 'label': tensor(0), 'idx': 0}
Sample 1: {'data': tensor([4., 5., 6.]), 'label': tensor(1), 'idx': 1}


In [23]:
# ============================================================================
# EXAMPLE 2: Dataset with Transforms (Adding Features)
# ============================================================================

# Section banner: blank line, rule, title, rule
print(
    "\n" + "=" * 80,
    "EXAMPLE 2: Dataset with Transform Wrapper",
    "=" * 80,
    sep="\n",
)

class TransformDataset(Dataset):
    """
    Wrapper that adds transform functionality

    New Feature: Apply transformations to data

    Args:
        data: indexable sequence of per-sample feature values.
        labels: indexable sequence of integer labels, one per sample.
        transform: optional callable applied to each sample's float32
                   feature tensor; ``None`` disables it.

    Raises:
        ValueError: if ``data`` and ``labels`` have different lengths.
    """

    def __init__(self, data, labels, transform=None):
        """
        Add transform parameter - new feature!
        """
        super().__init__()
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length "
                f"(got {len(data)} and {len(labels)})"
            )
        self.data = data
        self.labels = labels
        self.transform = transform  # NEW: Optional transform
        print("✓ TransformDataset initialized")
        print(f"  - Samples: {len(data)}")
        print(f"  - Transform: {transform is not None}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'data': float32 tensor (transformed), 'label': long tensor}``."""
        x = torch.tensor(self.data[idx], dtype=torch.float32)
        y = torch.tensor(self.labels[idx], dtype=torch.long)

        # NEW: Apply transform if provided. Compare with None rather than
        # relying on truthiness, so a callable object that defines
        # __len__/__bool__ is not silently skipped.
        if self.transform is not None:
            x = self.transform(x)

        return {'data': x, 'label': y}

# Custom transform function
def normalize_transform(x, eps=1e-8):
    """
    Normalize to 0-1 range (min-max over all elements of ``x``).

    Args:
        x: tensor to normalize; min and max are taken over every element.
        eps: small constant added to the denominator so a constant tensor
             maps to zeros instead of dividing by zero.

    Returns:
        ``(x - x.min()) / (x.max() - x.min() + eps)``. An empty tensor is
        returned unchanged, because min/max are undefined for it.
    """
    if x.numel() == 0:
        return x
    x_min = x.min()
    return (x - x_min) / (x.max() - x_min + eps)


EXAMPLE 2: Dataset with Transform Wrapper


In [24]:
# Test with transform: same data, once raw and once min-max normalized
dataset_no_transform = TransformDataset(data, labels)
dataset_with_transform = TransformDataset(data, labels, transform=normalize_transform)

for heading, demo_ds in (
    ("\nWithout transform:", dataset_no_transform),
    ("\nWith transform:", dataset_with_transform),
):
    print(heading)
    print(f"  Sample 0: {demo_ds[0]['data']}")

✓ TransformDataset initialized
  - Samples: 3
  - Transform: False
✓ TransformDataset initialized
  - Samples: 3
  - Transform: True

Without transform:
  Sample 0: tensor([1., 2., 3.])

With transform:
  Sample 0: tensor([0.0000, 0.5000, 1.0000])


In [25]:
# ============================================================================
# EXAMPLE 3: Caching Dataset Wrapper (Performance Optimization)
# ============================================================================

# Section banner: blank line, rule, title, rule
print(
    "\n" + "=" * 80,
    "EXAMPLE 3: Caching Dataset Wrapper",
    "=" * 80,
    sep="\n",
)

class CachedDataset(Dataset):
    """
    Wrapper that caches loaded samples

    Feature: Loads data once, stores in memory
    Use case: When data loading is expensive

    Args:
        data: indexable sequence of per-sample feature values.
        labels: indexable sequence of integer labels, one per sample.
        enable_cache: if True, each built sample dict is stored in
                      ``self.cache`` under its ``idx`` key and returned
                      on later lookups with the same key.

    Raises:
        ValueError: if ``data`` and ``labels`` have different lengths.

    Notes:
        A cache hit returns the *same* dict object that was stored, so
        mutating a returned sample also mutates the cached entry.
        NOTE(review): the cache is an ordinary per-instance dict; with
        multi-process DataLoader workers it is presumably not shared
        between workers — confirm if that matters.
    """

    def __init__(self, data, labels, enable_cache=True):
        super().__init__()
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length "
                f"(got {len(data)} and {len(labels)})"
            )
        self.data = data
        self.labels = labels
        self.enable_cache = enable_cache

        # NEW: Cache storage (None when caching is disabled)
        self.cache = {} if enable_cache else None
        print("✓ CachedDataset initialized")
        print(f"  - Caching: {enable_cache}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'data': float32 tensor, 'label': long tensor}``, cached if enabled."""
        # Check cache first (short-circuits before touching a None cache)
        if self.enable_cache and idx in self.cache:
            print(f"  ⚡ Cache hit for idx {idx}")
            return self.cache[idx]

        # Load data (expensive operation simulated)
        print(f"  💾 Loading idx {idx} from disk")
        sample = {
            'data': torch.tensor(self.data[idx], dtype=torch.float32),
            'label': torch.tensor(self.labels[idx], dtype=torch.long),
        }

        # Store in cache
        if self.enable_cache:
            self.cache[idx] = sample

        return sample


EXAMPLE 3: Caching Dataset Wrapper


In [26]:
# Test caching: the first read loads, the second read is served from cache
dataset = CachedDataset(data, labels, enable_cache=True)

for heading in ("\nFirst access:", "\nSecond access (should be cached):"):
    print(heading)
    _ = dataset[0]


✓ CachedDataset initialized
  - Caching: True

First access:
  💾 Loading idx 0 from disk

Second access (should be cached):
  ⚡ Cache hit for idx 0


In [27]:
# ============================================================================
# EXAMPLE 4: BilingualDataset - Our Real Example
# ============================================================================

# Section banner
print("\n" + "="*80)
print("EXAMPLE 4: BilingualDataset ( Code Explained)")
print("="*80)

# Annotated walkthrough of the BilingualDataset wrapper, printed as text only.
print("""
class BilingualDataset(Dataset):
    '''
    Wrapper for translation data
    
    What it wraps: torch.utils.data.Dataset
    What it adds:
        - Tokenization
        - Padding to fixed length
        - Mask creation
        - Teacher forcing setup
    '''
    
    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        # Step 1: Call parent init
        super().__init__()
        
        # Step 2: Store configuration
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len
        
        # Step 3: Pre-compute reusable values
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")])
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")])
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")])
    
    def __len__(self):
        # Return size of wrapped dataset
        return len(self.ds)
    
    def __getitem__(self, idx):
        # Step 1: Get raw data
        pair = self.ds[idx]
        src_text = pair['translation'][self.src_lang]
        tgt_text = pair['translation'][self.tgt_lang]
        
        # Step 2: Apply custom processing
        # - Tokenization
        # - Padding
        # - Mask creation
        
        # Step 3: Return processed sample
        return {
            "encoder_input": encoder_input,
            "decoder_input": decoder_input,
            "label": label,
            # ... etc
        }

KEY WRAPPER FEATURES:
---------------------
✓ Wraps raw dataset (ds)
✓ Adds tokenization logic
✓ Adds padding logic
✓ Adds mask creation
✓ Returns training-ready tensors
""")



EXAMPLE 4: BilingualDataset ( Code Explained)

class BilingualDataset(Dataset):
    '''
    Wrapper for translation data
    
    What it wraps: torch.utils.data.Dataset
    What it adds:
        - Tokenization
        - Padding to fixed length
        - Mask creation
        - Teacher forcing setup
    '''
    
    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        # Step 1: Call parent init
        super().__init__()
        
        # Step 2: Store configuration
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len
        
        # Step 3: Pre-compute reusable values
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")])
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")])
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")])
  

In [28]:
# ============================================================================
# EXAMPLE 5: Model Wrapper (nn.Module)
# ============================================================================

# Section banner: blank line, rule, title, rule
print(
    "\n" + "=" * 80,
    "EXAMPLE 5: Custom Model Wrapper",
    "=" * 80,
    sep="\n",
)

import torch.nn as nn

class SimpleMLPWrapper(nn.Module):
    """
    Wrapper around nn.Module

    Pattern: Same as Dataset
        1. Inherit from nn.Module
        2. Call super().__init__()
        3. Define layers in __init__
        4. Implement forward()

    Architecture:
        Linear(input_size -> hidden_size) -> ReLU -> Dropout
        -> Linear(hidden_size -> output_size)

    Args:
        input_size: size of the last dimension of the input.
        hidden_size: width of the hidden layer.
        output_size: size of the last dimension of the output.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout=0.1):
        """
        Initialize custom model
        """
        super().__init__()  # REQUIRED: Call parent init

        # Define layers (attribute names become state_dict keys)
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_size, output_size)

        print("✓ SimpleMLPWrapper initialized")
        print(f"  - Input: {input_size}")
        print(f"  - Hidden: {hidden_size}")
        print(f"  - Output: {output_size}")

    def forward(self, x):
        """
        Define forward pass (REQUIRED): fc1 -> ReLU -> dropout -> fc2.
        """
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


EXAMPLE 5: Custom Model Wrapper


In [29]:
# Test model wrapper on a random batch: 5 samples, 10 features each
model = SimpleMLPWrapper(input_size=10, hidden_size=20, output_size=2)
x = torch.randn(5, 10)
output = model(x)
print(f"\nInput shape: {x.shape}", f"Output shape: {output.shape}", sep="\n")

✓ SimpleMLPWrapper initialized
  - Input: 10
  - Hidden: 20
  - Output: 2

Input shape: torch.Size([5, 10])
Output shape: torch.Size([5, 2])


In [30]:
# ============================================================================
# EXAMPLE 6: Advanced Model Wrapper with Custom Features
# ============================================================================

# Section banner: blank line, rule, title, rule
print(
    "\n" + "=" * 80,
    "EXAMPLE 6: Advanced Model Wrapper",
    "=" * 80,
    sep="\n",
)

class ResidualBlock(nn.Module):
    """
    Reusable block computing ``x + ReLU(Linear(x))``.

    The linear layer maps ``dim -> dim``, so its output can be added back
    onto the input (the residual / skip connection).
    """

    def __init__(self, dim):
        super().__init__()
        # Attribute names are kept stable: they form the state_dict keys.
        self.fc = nn.Linear(dim, dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        branch = self.fc(x)
        branch = self.relu(branch)
        return x + branch

class AdvancedModelWrapper(nn.Module):
    """
    Advanced wrapper with multiple features

    Architecture:
        input_proj  : Linear(input_dim -> hidden_dim)
        blocks      : num_blocks x ResidualBlock(hidden_dim)
        output_proj : Linear(hidden_dim -> output_dim)

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        num_blocks: number of residual blocks stacked in between.
        hidden_dim: width of the residual stack (defaults to 128, the
                    value previously hard-coded).
    """

    def __init__(self, input_dim, output_dim, num_blocks=3, hidden_dim=128):
        super().__init__()

        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # Create multiple residual blocks. nn.ModuleList (not a plain list)
        # registers their parameters with this module.
        self.blocks = nn.ModuleList(
            [ResidualBlock(hidden_dim) for _ in range(num_blocks)]
        )

        self.output_proj = nn.Linear(hidden_dim, output_dim)

        print("✓ AdvancedModelWrapper initialized")
        print(f"  - Residual blocks: {num_blocks}")

    def forward(self, x):
        """Project in, run every residual block in order, project out."""
        x = self.input_proj(x)

        # Pass through residual blocks
        for block in self.blocks:
            x = block(x)

        return self.output_proj(x)


EXAMPLE 6: Advanced Model Wrapper


In [31]:
# Same shape check for the deeper residual model (batch of 5, 10 features)
model = AdvancedModelWrapper(input_dim=10, output_dim=2, num_blocks=3)
x = torch.randn(5, 10)
output = model(x)
print(f"\nInput shape: {x.shape}", f"Output shape: {output.shape}", sep="\n")

✓ AdvancedModelWrapper initialized
  - Residual blocks: 3

Input shape: torch.Size([5, 10])
Output shape: torch.Size([5, 2])


In [32]:
# ============================================================================
# EXAMPLE 7: Loss Function Wrapper
# ============================================================================

# Section banner: blank line, rule, title, rule
print(
    "\n" + "=" * 80,
    "EXAMPLE 7: Custom Loss Function Wrapper",
    "=" * 80,
    sep="\n",
)

class WeightedLoss(nn.Module):
    """
    Wrapper around standard loss with custom weighting

    Computes ``base_loss(pred, target) * weight_factor``.

    Args:
        base_loss: loss callable taking ``(pred, target)``. It can be an
                   nn.Module instance such as ``nn.MSELoss()`` or a plain
                   function such as ``torch.nn.functional.mse_loss``.
        weight_factor: scalar multiplier applied to the base loss.
    """

    def __init__(self, base_loss, weight_factor=1.0):
        super().__init__()
        self.base_loss = base_loss
        self.weight_factor = weight_factor

        # Plain functions report their own name. Module instances have no
        # __name__, so fall back to the class name (e.g. "MSELoss").
        loss_name = getattr(base_loss, "__name__", type(base_loss).__name__)
        print("✓ WeightedLoss initialized")
        print(f"  - Base loss: {loss_name}")
        print(f"  - Weight factor: {weight_factor}")

    def forward(self, pred, target):
        """Return the base loss scaled by ``weight_factor``."""
        loss = self.base_loss(pred, target)
        return loss * self.weight_factor



EXAMPLE 7: Custom Loss Function Wrapper


In [33]:

# Test custom loss: the weighted loss should be exactly 2x the base loss
base_criterion = nn.MSELoss()
custom_criterion = WeightedLoss(base_criterion, weight_factor=2.0)

pred, target = torch.randn(5, 10), torch.randn(5, 10)

base_loss = base_criterion(pred, target)
custom_loss = custom_criterion(pred, target)

base_value, custom_value = base_loss.item(), custom_loss.item()
print(f"\nBase loss: {base_value:.4f}")
print(f"Custom loss: {custom_value:.4f}")
print(f"Ratio: {custom_value / base_value:.1f}x")

✓ WeightedLoss initialized
  - Base loss: MSELoss
  - Weight factor: 2.0

Base loss: 2.1309
Custom loss: 4.2618
Ratio: 2.0x


In [34]:
# ============================================================================
# EXAMPLE 8: Optimizer Wrapper (Advanced)
# ============================================================================

# Section banner: blank line, rule, title, rule
print(
    "\n" + "=" * 80,
    "EXAMPLE 8: Optimizer Wrapper (Advanced)",
    "=" * 80,
    sep="\n",
)

class WarmupOptimizer:
    """
    Wrapper around optimizer with linear learning rate warmup

    NOT inheriting from nn.Module (optimizers don't inherit from it)

    On step k (1-based), every param group's LR is set to
    ``group_base_lr * min(1, k / warmup_steps)`` before the wrapped
    optimizer's ``step()`` runs. The LR therefore ramps linearly, reaches
    each group's base LR exactly at ``k == warmup_steps``, and stays there.

    Args:
        optimizer: optimizer whose ``param_groups`` carry an ``'lr'`` entry.
        warmup_steps: number of steps to reach the base LR; values <= 0
                      disable warmup.
    """

    def __init__(self, optimizer, warmup_steps):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.step_count = 0

        # Remember each group's own starting LR, so groups with different
        # LRs are each warmed up to their own target.
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]
        # Kept for backward compatibility: LR of the first param group.
        self.base_lr = self.base_lrs[0]

        print("✓ WarmupOptimizer initialized")
        print(f"  - Warmup steps: {warmup_steps}")
        print(f"  - Base LR: {self.base_lr}")

    def _lr_scale(self):
        """Linear warmup factor in [0, 1] for the current step_count."""
        if self.warmup_steps <= 0:
            return 1.0
        return min(1.0, self.step_count / self.warmup_steps)

    def step(self):
        """Advance the schedule, write the LR into every group, then step."""
        self.step_count += 1
        scale = self._lr_scale()

        # Always write the LR, not only during warmup. Otherwise the last
        # warmup value would persist and the base LR would never be restored.
        for group, base in zip(self.optimizer.param_groups, self.base_lrs):
            group['lr'] = base * scale

        lr = self.base_lr * scale
        phase = "warmup" if scale < 1.0 else "normal"
        print(f"  Step {self.step_count}: LR = {lr:.6f} ({phase})")

        # Call wrapped optimizer step
        self.optimizer.step()

    def zero_grad(self, *args, **kwargs):
        """Delegate to wrapped optimizer (arguments are passed through)."""
        self.optimizer.zero_grad(*args, **kwargs)


EXAMPLE 8: Optimizer Wrapper (Advanced)


In [35]:
# Test warmup optimizer on a throwaway linear model
dummy_model = nn.Linear(10, 2)
base_optimizer = torch.optim.Adam(dummy_model.parameters(), lr=0.001)
wrapped_optimizer = WarmupOptimizer(base_optimizer, warmup_steps=5)

print("\nSimulating training steps:")
for i in range(8):
    wrapped_optimizer.zero_grad()
    # No backward pass is run here, so the parameters have no gradients;
    # each step() call only exercises the learning-rate schedule.
    wrapped_optimizer.step()

✓ WarmupOptimizer initialized
  - Warmup steps: 5
  - Base LR: 0.001

Simulating training steps:
  Step 1: LR = 0.000200 (warmup)
  Step 2: LR = 0.000400 (warmup)
  Step 3: LR = 0.000600 (warmup)
  Step 4: LR = 0.000800 (warmup)
  Step 5: LR = 0.001000 (normal)
  Step 6: LR = 0.001000 (normal)
  Step 7: LR = 0.001000 (normal)
  Step 8: LR = 0.001000 (normal)


In [36]:
# ============================================================================
# PATTERN SUMMARY
# ============================================================================

# Section banner
print("\n" + "="*80)
print("WRAPPER CLASS PATTERN SUMMARY")
print("="*80)

# Generic wrapper template and checklist, printed as text only.
print("""
UNIVERSAL WRAPPER PATTERN:
--------------------------

class CustomWrapper(BaseClass):
    '''
    Step 0: Choose base class to wrap
    - Dataset
    - nn.Module
    - Loss function
    - etc.
    '''
    
    def __init__(self, *args, **kwargs):
        '''
        Step 1: Always call super().__init__()
        '''
        super().__init__()
        
        '''
        Step 2: Store configuration
        '''
        self.config = kwargs
        
        '''
        Step 3: Initialize custom components
        '''
        self.custom_component = ...
    
    def required_method(self, *args):
        '''
        Step 4: Implement required methods
        - Dataset: __len__, __getitem__
        - nn.Module: forward
        '''
        # Your custom logic
        pass
    
    def optional_custom_method(self):
        '''
        Step 5: Add optional custom methods
        '''
        pass

CHECKLIST:
----------
✓ Inherit from base class
✓ Call super().__init__()
✓ Implement required methods
✓ Add custom features
✓ Document your additions
✓ Test your wrapper

COMMON BASES TO WRAP:
---------------------
1. torch.utils.data.Dataset
   - Required: __init__, __len__, __getitem__

2. torch.nn.Module
   - Required: __init__, forward

3. torch.optim.Optimizer
   - Wrap existing optimizer
   - Add custom step logic

4. Functions
   - Create wrapper functions
   - Add logging, timing, validation
""")


WRAPPER CLASS PATTERN SUMMARY

UNIVERSAL WRAPPER PATTERN:
--------------------------

class CustomWrapper(BaseClass):
    '''
    Step 0: Choose base class to wrap
    - Dataset
    - nn.Module
    - Loss function
    - etc.
    '''
    
    def __init__(self, *args, **kwargs):
        '''
        Step 1: Always call super().__init__()
        '''
        super().__init__()
        
        '''
        Step 2: Store configuration
        '''
        self.config = kwargs
        
        '''
        Step 3: Initialize custom components
        '''
        self.custom_component = ...
    
    def required_method(self, *args):
        '''
        Step 4: Implement required methods
        - Dataset: __len__, __getitem__
        - nn.Module: forward
        '''
        # Your custom logic
        pass
    
    def optional_custom_method(self):
        '''
        Step 5: Add optional custom methods
        '''
        pass

CHECKLIST:
----------
✓ Inherit from base class
✓ Call super(

In [37]:

# ============================================================================
# PRACTICAL EXAMPLE: Complete Custom Dataset
# ============================================================================

# Section banner: blank line, rule, title, rule
print(
    "\n" + "=" * 80,
    "PRACTICAL: Complete Custom Dataset Template",
    "=" * 80,
    sep="\n",
)

class MyCustomDataset(Dataset):
    """
    Template for any custom dataset
    Copy this and modify for your needs!

    Caching stores the *raw* sample returned by ``_load_sample``, and the
    transform runs on every access. A random augmentation therefore still
    produces a fresh result each time a cached index is read.
    """

    def __init__(self,
                 data_path,
                 transform=None,
                 max_length=512,
                 cache=True):
        """
        Initialize dataset

        Args:
            data_path: Path to data (stored only; the dummy loaders in this
                       template do not read it)
            transform: Optional callable applied to each sample dict
            max_length: Maximum sequence length (stored only; not yet
                        enforced by this template)
            cache: Whether to cache raw (pre-transform) samples
        """
        super().__init__()

        # Store config
        self.data_path = data_path
        self.transform = transform
        self.max_length = max_length
        self.cache_enabled = cache

        # Load metadata (not full data!)
        self.metadata = self._load_metadata()

        # Initialize cache (None when caching is disabled)
        self.cache = {} if cache else None

        print("✓ MyCustomDataset initialized")
        print(f"  - Samples: {len(self.metadata)}")
        print(f"  - Max length: {max_length}")
        print(f"  - Caching: {cache}")

    def _load_metadata(self):
        """Helper: Load file paths or indices"""
        # In real code: load file list, database entries, etc.
        return list(range(100))  # Dummy metadata

    def __len__(self):
        """Return dataset size"""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load (or fetch the cached raw) sample ``idx``, then apply the transform."""
        if self.cache_enabled and idx in self.cache:
            raw = self.cache[idx]
        else:
            raw = self._load_sample(idx)
            if self.cache_enabled:
                self.cache[idx] = raw

        if self.transform is None:
            return raw

        # Give the transform a shallow copy of dict samples, so a transform
        # that reassigns keys cannot corrupt the cached raw sample.
        sample = dict(raw) if isinstance(raw, dict) else raw
        return self.transform(sample)

    def _load_sample(self, idx):
        """Helper: Load one sample"""
        # In real code: load from disk, database, etc.
        return {
            'data': torch.randn(10),
            'label': idx % 2,
            'idx': idx
        }

# Test template: build the dataset from an explicit config mapping
template_config = dict(
    data_path="./data",
    transform=None,
    max_length=512,
    cache=True,
)
dataset = MyCustomDataset(**template_config)

print(f"\nDataset length: {len(dataset)}")
print(f"Sample 0: {dataset[0]}")

# Closing banner
print("\n" + "="*80)
print("✅ WRAPPER CLASSES TUTORIAL COMPLETE!")
print("="*80)

# Recap and exercises, printed as text only.
print("""
KEY TAKEAWAYS:
--------------
1. Wrappers extend existing classes
2. Always call super().__init__()
3. Implement required methods
4. Add custom features incrementally
5. Keep code modular and reusable

PRACTICE EXERCISE:
------------------
Write a wrapper for:
- Dataset that applies random augmentation
- Model that adds dropout layers
- Loss that combines multiple losses


""")



PRACTICAL: Complete Custom Dataset Template
✓ MyCustomDataset initialized
  - Samples: 100
  - Max length: 512
  - Caching: True

Dataset length: 100
Sample 0: {'data': tensor([-1.0317,  1.2901,  1.1302, -0.6877, -0.0383, -0.4107, -1.0900,  0.2785,
        -0.6659,  1.9071]), 'label': 0, 'idx': 0}

✅ WRAPPER CLASSES TUTORIAL COMPLETE!

KEY TAKEAWAYS:
--------------
1. Wrappers extend existing classes
2. Always call super().__init__()
3. Implement required methods
4. Add custom features incrementally
5. Keep code modular and reusable

PRACTICE EXERCISE:
------------------
Write a wrapper for:
- Dataset that applies random augmentation
- Model that adds dropout layers
- Loss that combines multiple losses



