# 📦 Week 09-10 · Notebook 06 · Mini-batch Training & Curriculum Scheduling

Design batching strategies and curricula to stabilize training on skewed manufacturing corpora.

## 🎯 Learning Objectives
- **Implement Custom Data Loaders:** Build a PyTorch `Dataset` and `DataLoader` that can handle our manufacturing data with its different severity levels.
- **Design a Curriculum Strategy:** Implement a curriculum learning approach where the model first learns from common, "easy" examples (routine maintenance) before being shown rare, "hard" examples (critical failures).
- **Balance Training Parameters:** Analyze the trade-offs between batch size, training throughput, and GPU memory constraints; getting this balance right is critical in a production environment with shared resources.
- **Document Operational Procedures:** Draft a Standard Operating Procedure (SOP) for handing off long-running training jobs between shifts, ensuring continuity and accountability.

## 🧩 Scenario
Only 3% of maintenance logs capture critical failures, yet they matter most. You need a training curriculum that ramps from routine notes → cautionary signals → critical incidents.

In [None]:
import math
import random
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd

# Seed both random number generators used in this notebook: Python's `random`
# drives severity sampling and curriculum shuffling, while torch generates the
# embeddings. Seeding only torch would leave the dataset non-reproducible.
random.seed(909)
torch.manual_seed(909)

## 🗂️ Curriculum Tags
Each sample is assigned a severity level. Critical incidents are oversampled later in the curriculum.

In [None]:
def create_curriculum_dataset(num_samples=2000, seed=None):
    """
    Create synthetic maintenance-log records with three severity levels.

    Each record is a dict with the keys 'id' (e.g. ``LOG-0000``), 'severity'
    ('routine', 'warning' or 'critical') and 'embedding' (a 128-element
    tensor). 'critical' and 'warning' samples get a distinct feature offset
    so that they are learnable.

    Args:
        num_samples: Number of records to generate.
        seed: Optional integer seed. When given, severities and embeddings are
            drawn from private generators seeded with this value, so repeated
            calls return identical data and the global RNG state is left
            untouched. When None (the default), the global `random` and torch
            generators are used.

    Returns:
        A list of ``num_samples`` record dicts.
    """
    severities = ['routine', 'warning', 'critical']
    # Skewed distribution: most logs are routine.
    severity_weights = [0.85, 0.12, 0.03]

    if seed is None:
        # Fall back to the module-level generators (global state).
        py_rng = random
        torch_rng = None
    else:
        py_rng = random.Random(seed)
        torch_rng = torch.Generator().manual_seed(seed)

    records = []
    for i in range(num_samples):
        level = py_rng.choices(severities, weights=severity_weights)[0]
        # Base embedding for all samples
        embedding = torch.randn(128, generator=torch_rng)

        # Add a signal for more severe events
        if level == 'critical':
            embedding[:32] += 2.5  # Strong signal on the first 32 features
        elif level == 'warning':
            embedding[32:64] += 1.5  # Moderate signal on the next 32 features

        records.append({
            'id': f'LOG-{i:04d}',
            'severity': level,
            'embedding': embedding,
        })

    return records

# Build the synthetic corpus once; later cells reuse `curriculum_records`.
curriculum_records = create_curriculum_dataset()

# Collect the severity tag of every record into a pandas Series.
df_severity = pd.Series([record['severity'] for record in curriculum_records])

# Show each severity's share of the corpus, rounded to three decimals.
print("Dataset Severity Distribution:")
print(df_severity.value_counts(normalize=True).round(3))

In [None]:
class MaintenanceLogDataset(Dataset):
    """
    PyTorch Dataset wrapping a list of maintenance log records.

    Indexing yields an ``(embedding, label)`` pair, where ``label`` is the
    record's severity encoded as an integer tensor.
    """

    def __init__(self, records):
        # Integer class id for each severity string.
        self.severity_map = {'routine': 0, 'warning': 1, 'critical': 2}
        self.records = records

    def __len__(self):
        """Return the number of wrapped records."""
        return len(self.records)

    def __getitem__(self, idx):
        """Return the embedding and long-typed severity label at ``idx``."""
        entry = self.records[idx]
        class_id = self.severity_map[entry['severity']]
        return entry['embedding'], torch.tensor(class_id, dtype=torch.long)

# Wrap every record in a Dataset; a custom sampler consumes it further down.
full_dataset = MaintenanceLogDataset(curriculum_records)
print(f"Created a dataset with {len(full_dataset)} records.")

# Sanity check: pull the first (embedding, label) pair and inspect it.
first_item = full_dataset[0]
embedding, label = first_item
print(f"First item's embedding shape: {embedding.shape}, Label: {label}")

## 🗜️ Adaptive Batch Sampler
Start with routine samples, then gradually mix in higher-severity items.

In [None]:
from torch.utils.data import Sampler
import numpy as np

class CurriculumSampler(Sampler):
    """
    A custom sampler that implements curriculum learning.

    - Epoch 0-2: Only 'routine' samples.
    - Epoch 3-5: 'routine' and 'warning' samples.
    - Epoch 6+: All samples, with 'warning' indices repeated 5x and
      'critical' indices repeated 15x so the rare classes are seen often.

    The active phase is chosen from ``current_epoch``, which the training
    loop must update through ``set_epoch`` before each epoch.

    Args:
        dataset: Object exposing a ``records`` sequence of dicts that each
            carry a 'severity' key.
        batch_size: Stored on the instance; the sampler itself yields single
            indices and does not use it.
    """

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.current_epoch = 0

        # Group dataset indices by severity level in a single pass.
        # Records with any other severity value are left out.
        self.indices_by_severity = {'routine': [], 'warning': [], 'critical': []}
        for index, record in enumerate(dataset.records):
            bucket = self.indices_by_severity.get(record['severity'])
            if bucket is not None:
                bucket.append(index)

    def _repeat_factors(self):
        """Return {severity: repeat count} for the current curriculum phase."""
        if self.current_epoch < 3:  # Phase 1: Easy samples
            return {'routine': 1}
        if self.current_epoch < 6:  # Phase 2: Medium samples
            return {'routine': 1, 'warning': 1}
        # Phase 3: Hard samples, oversampling the rare classes
        return {'routine': 1, 'warning': 5, 'critical': 15}

    def __iter__(self):
        """Yield the current phase's indices in a freshly shuffled order."""
        # Always build a new list so that shuffling never reorders the
        # per-severity index lists stored on the instance.
        indices = []
        for severity, repeat in self._repeat_factors().items():
            indices.extend(self.indices_by_severity[severity] * repeat)

        random.shuffle(indices)
        return iter(indices)

    def __len__(self):
        """Return how many indices one pass yields in the current phase."""
        return sum(
            len(self.indices_by_severity[severity]) * repeat
            for severity, repeat in self._repeat_factors().items()
        )

    def set_epoch(self, epoch):
        """Called by the training loop to advance the curriculum."""
        self.current_epoch = epoch

# --- Dummy Model and Training Loop to Demonstrate the Sampler ---
# Imported here because this demo is the only code that needs torch.nn.
from torch import nn

model = nn.Linear(128, 3)  # 128 features, 3 severity classes
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

num_epochs = 10
batch_size = 64
sampler = CurriculumSampler(full_dataset, batch_size)
# Note: a custom sampler cannot be combined with `shuffle=True`, so `shuffle`
# is left at its default (False); the sampler shuffles the indices itself.
data_loader = DataLoader(full_dataset, batch_size=batch_size, sampler=sampler)

print("--- Starting Training with Curriculum Sampler ---")
for epoch in range(num_epochs):
    # IMPORTANT: update the sampler's epoch before the loader is measured or
    # iterated, otherwise the previous phase's indices would be used.
    sampler.set_epoch(epoch)

    # Human-readable phase name; thresholds mirror the sampler's schedule.
    if epoch < 3:
        phase = "Easy (Routine)"
    elif epoch < 6:
        phase = "Medium (Routine + Warning)"
    else:
        phase = "Hard (All, Oversampled)"

    # len(data_loader) is derived from len(sampler), so it tracks the phase.
    print(f"\nEpoch {epoch+1}/{num_epochs} | Phase: {phase} | Batches: {len(data_loader)}")

    for i, (embeddings, labels) in enumerate(data_loader):
        optimizer.zero_grad()
        outputs = model(embeddings)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i + 1) % 20 == 0:
            # Get the label distribution of the current batch for inspection
            label_counts = pd.Series(labels.numpy()).value_counts().to_dict()
            print(f'  Batch {i+1}, Loss: {loss.item():.4f}, Batch Dist: {label_counts}')

print("\n--- Training Complete ---")