# Training Operation for ResNet on CIFAR-10 (PyTorch)

This notebook contains the PyTorch training logic that will be embedded into the KFP pipeline.

## Workflow
1. Edit and test your training code in this notebook
2. Build the pipeline using `./build.sh` (reads directly from this notebook)
3. Submit the generated `outputs/pipeline.yaml` to Vertex AI

## Key Function
The main `trainOp()` function is what gets embedded into the KFP component.

In [None]:
from typing import List

In [None]:
def trainOp(
    model_relative_path: str = 'model', 
    model_name: str = 'resnet_pytorch', 
    prefix: str = '1711140944', 
    epochs: int = 50, 
    networks: int = 3, 
    batch_size: int = 128,
    classes_list: List[str] = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
    model_version: str = '@auto-timestamp',
    deploy_to_watsonxai: bool = False,
    kfp_output_path: str = None,  # KFP output path (for KFP 1.8 compatibility)
):
    """Train a CIFAR-10 ResNet with PyTorch and export it as a Triton model repository.

    The function is self-contained (all imports and helper classes live inside
    the body). It:
    - downloads CIFAR-10 into ``/home/jovyan/data``
    - trains a ResNet of depth ``6 * networks + 2`` with SGD
    - keeps an in-memory snapshot of the weights with the best test accuracy
      and restores it before export, so the exported model matches the
      ``best_accuracy`` recorded in ``metadata.json``
    - writes ``<output>/<model_name>/<model_version>/model.pt`` (TorchScript)
      plus ``config.pbtxt``, ``labels.txt`` and ``metadata.json``

    Args:
        model_relative_path: Output directory relative to ``/home/jovyan``;
            only used when ``kfp_output_path`` is not given (default: 'model').
        model_name: Name of the model and of its directory (default: 'resnet_pytorch').
        prefix: Data prefix (legacy, not used for CIFAR-10).
        epochs: Number of training epochs (default: 50). The learning rate is
            divided by 10 at 40%, 60%, 80% and 90% of the run.
        networks: Number of residual blocks per stage; must be >= 1 (default: 3).
        batch_size: Batch size for training/evaluation; also written as
            ``max_batch_size`` in the Triton config (default: 128).
        classes_list: Class names written to ``labels.txt`` and the metadata.
        model_version: Name of the version directory. The placeholder
            '@auto-timestamp' (or an empty value) is replaced by the current
            Unix timestamp, because Triton only loads numerically named
            version directories.
        deploy_to_watsonxai: Whether to deploy to watsonx.ai (currently not
            used by this function).
        kfp_output_path: If set, used as the output directory instead of the
            local path (KFP artifact output).

    Raises:
        ValueError: If ``networks`` is smaller than 1.
    """
    # Fail fast on an invalid depth, before the heavy imports and the dataset
    # download: with zero blocks per stage the channel count never grows to the
    # 64 features the final linear layer expects.
    n = int(networks)
    if n < 1:
        raise ValueError(f"networks must be >= 1, got {networks!r}")

    import os
    import pathlib
    import shutil
    import json
    import time
    import numpy as np
    
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    from PIL import Image
    
    # Resolve the version placeholder when the caller did not substitute it.
    # A literal '@auto-timestamp' directory would not be a numeric version
    # directory and Triton would not load it.
    if model_version in (None, '', '@auto-timestamp'):
        model_version = str(int(time.time()))
    model_version = str(model_version)
    
    home = '/home/jovyan'
    
    raw_data_dir = os.path.join(home, "data")
    
    # Use KFP output path if provided (for artifact output)
    # Otherwise use local path (for local testing/runs)
    if kfp_output_path:
        output_model_dir = kfp_output_path
        print(f"Using KFP Output path: {output_model_dir}")
        print("Model will be saved as KFP artifact")
    else:
        # Save model directly under model_relative_path, not model_relative_path/model_name
        # Triton expects: /mnt/models/resnet_pytorch/, not /mnt/models/model/resnet_pytorch/
        output_model_dir = os.path.join(home, model_relative_path)
        print(f"Using local path: {output_model_dir}")
    
    print("PyTorch version:", torch.__version__)
    print("NumPy version:", np.__version__)
    print("Output directory:", output_model_dir)
    print("Model name:", model_name)
    print("Model version:", model_version)
    print("Classes:", classes_list)
    
    pathlib.Path(raw_data_dir).mkdir(parents=True, exist_ok=True)
    pathlib.Path(output_model_dir).mkdir(parents=True, exist_ok=True)
    
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # ResNet Block for CIFAR-10
    class BasicBlock(nn.Module):
        """Two 3x3 conv + batch-norm layers with a residual shortcut."""
        expansion = 1
        
        def __init__(self, in_planes, planes, stride=1):
            super().__init__()
            self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(planes)
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(planes)
            
            # Identity shortcut unless the spatial size or channel count changes,
            # in which case a 1x1 conv + batch-norm projects the input.
            self.shortcut = nn.Sequential()
            if stride != 1 or in_planes != self.expansion * planes:
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )
        
        def forward(self, x):
            out = torch.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            out += self.shortcut(x)
            out = torch.relu(out)
            return out
    
    # ResNet Model
    class ResNet(nn.Module):
        """Three-stage ResNet (16/32/64 channels) with a linear classifier."""

        def __init__(self, block, num_blocks, num_classes=10):
            super().__init__()
            self.in_planes = 16
            
            self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(16)
            self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
            self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
            self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
            self.linear = nn.Linear(64 * block.expansion, num_classes)
        
        def _make_layer(self, block, planes, num_blocks, stride):
            # Only the first block of a stage uses the given stride; the rest use 1.
            strides = [stride] + [1] * (num_blocks - 1)
            layers = []
            for block_stride in strides:
                layers.append(block(self.in_planes, planes, block_stride))
                self.in_planes = planes * block.expansion
            return nn.Sequential(*layers)
        
        def forward(self, x):
            out = torch.relu(self.bn1(self.conv1(x)))
            out = self.layer1(out)
            out = self.layer2(out)
            out = self.layer3(out)
            # Use adaptive_avg_pool2d for TorchScript compatibility
            out = torch.nn.functional.adaptive_avg_pool2d(out, (1, 1))
            out = out.view(out.size(0), -1)
            out = self.linear(out)
            return out
    
    def ResNetCIFAR(num_blocks_per_layer, num_classes=10):
        """Create ResNet for CIFAR-10.
        
        Args:
            num_blocks_per_layer: Number of blocks per layer (e.g., 3 for ResNet-20)
            num_classes: Number of output classes
        """
        return ResNet(BasicBlock, [num_blocks_per_layer] * 3, num_classes)
    
    # Data loading
    print("Loading CIFAR-10 dataset...")
    
    # Custom ToTensor that avoids NumPy compatibility issues
    class PILToTensor:
        """Convert PIL Image to tensor without NumPy conversion."""
        def __call__(self, pic):
            if not isinstance(pic, Image.Image):
                raise TypeError(f'pic should be PIL Image. Got {type(pic)}')
            
            # Convert PIL Image to tensor manually
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
            # Put channels first and convert to float
            img = img.permute((2, 0, 1)).contiguous()
            return img.to(dtype=torch.float32).div(255)
    
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        PILToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    transform_test = transforms.Compose([
        PILToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    trainset = torchvision.datasets.CIFAR10(
        root=raw_data_dir, train=True, download=True, transform=transform_train
    )
    # Use num_workers=0 to avoid multiprocessing issues with custom transforms
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)
    
    testset = torchvision.datasets.CIFAR10(
        root=raw_data_dir, train=False, download=True, transform=transform_test
    )
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=0)
    
    print(f"Training samples: {len(trainset)}")
    print(f"Test samples: {len(testset)}")
    
    # Model setup
    depth = n * 6 + 2  # For CIFAR-10: n=3 -> ResNet-20
    model_type = f'ResNet{depth}'
    print(f"Building model: {model_type}")
    
    model = ResNetCIFAR(num_blocks_per_layer=n, num_classes=10)
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    # Fixed milestones [80, 120, 160, 180] are never reached with the default
    # 50 epochs, so the learning rate would never decay. Place them at
    # 40/60/80/90% of the run instead, which reproduces the original values for
    # a 200-epoch run. Zero and duplicate milestones (very short runs) are dropped.
    lr_milestones = sorted({
        m for m in (epochs * 2 // 5, epochs * 3 // 5, epochs * 4 // 5, epochs * 9 // 10) if m > 0
    })
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=0.1)
    
    # Training function
    def train_epoch(epoch):
        """Run one training pass; return (mean batch loss, accuracy in percent)."""
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch} [{batch_idx}/{len(trainloader)}] '
                      f'Loss: {train_loss/(batch_idx+1):.3f} | '
                      f'Acc: {100.*correct/total:.2f}% ({correct}/{total})')
        
        return train_loss / len(trainloader), 100. * correct / total
    
    # Testing function
    def test_epoch(epoch):
        """Evaluate on the test set; return (mean batch loss, accuracy in percent)."""
        model.eval()
        test_loss = 0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        
        acc = 100. * correct / total
        print(f'Test Epoch: {epoch} | Loss: {test_loss/len(testloader):.3f} | Acc: {acc:.2f}% ({correct}/{total})')
        return test_loss / len(testloader), acc
    
    # Training loop
    print(f"\nStarting training for {epochs} epochs...")
    best_acc = 0
    best_state = None  # snapshot of the weights that produced best_acc
    
    for epoch in range(epochs):
        print(f"\n=== Epoch {epoch+1}/{epochs} ===")
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        
        train_loss, train_acc = train_epoch(epoch)
        test_loss, test_acc = test_epoch(epoch)
        scheduler.step()
        
        print(f"epoch={epoch}")
        print(f"Train-Accuracy={train_acc:.6f}")
        print(f"Train-Loss={train_loss:.6f}")
        print(f"Validation-Accuracy={test_acc:.6f}")
        print(f"Validation-Loss={test_loss:.6f}")
        
        # Save best model
        if test_acc > best_acc:
            best_acc = test_acc
            print(f"Saving best model (accuracy: {best_acc:.2f}%)...")
            # state_dict() returns references to the live tensors, so clone them;
            # otherwise later epochs would overwrite the snapshot.
            best_state = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }
    
    print(f"\nTraining complete! Best accuracy: {best_acc:.2f}%")
    
    # Export the best weights rather than whatever the last epoch produced, so
    # the saved model matches the best_accuracy written to metadata.json.
    if best_state is not None:
        model.load_state_dict(best_state)
        print("Restored best model weights for export")
    
    # Save model in Triton-compatible structure
    # Triton structure: <model_repository>/<model_name>/<version>/model.pt
    # where model_repository is output_model_dir
    model_root = os.path.join(output_model_dir, model_name)
    model_version_dir = os.path.join(model_root, model_version)
    
    # Clear existing version directory if it exists
    if os.path.isdir(model_version_dir):
        shutil.rmtree(model_version_dir)
    
    pathlib.Path(model_version_dir).mkdir(parents=True, exist_ok=True)
    
    # Save TorchScript model (Triton PyTorch backend requires TorchScript)
    print("\nSaving model in Triton format...")
    model.eval()
    example_input = torch.randn(1, 3, 32, 32).to(device)
    traced_model = torch.jit.trace(model, example_input)
    model_path = os.path.join(model_version_dir, 'model.pt')
    traced_model.save(model_path)
    print(f"Model saved to: {model_path}")
    
    # Create Triton config.pbtxt
    config_content = f'''name: "{model_name}"
platform: "pytorch_libtorch"
max_batch_size: {batch_size}

input [
  {{
    name: "INPUT__0"
    data_type: TYPE_FP32
    dims: [ 3, 32, 32 ]
  }}
]

output [
  {{
    name: "OUTPUT__0"
    data_type: TYPE_FP32
    dims: [ 10 ]
  }}
]

version_policy: {{ all {{ }} }}
'''
    
    config_path = os.path.join(model_root, 'config.pbtxt')
    with open(config_path, 'w') as f:
        f.write(config_content)
    print(f"Triton config saved to: {config_path}")
    
    # Save labels.txt at model root (for Triton)
    labels_path = os.path.join(model_root, 'labels.txt')
    with open(labels_path, 'w') as f:
        f.write('\n'.join(classes_list))
    print(f"Labels saved to: {labels_path}")
    
    # Save metadata.json for reference
    metadata = {
        'model_name': model_name,
        'model_type': model_type,
        'framework': 'pytorch',
        'platform': 'pytorch_libtorch',
        'num_classes': 10,
        'input_shape': [3, 32, 32],
        'epochs': epochs,
        'best_accuracy': float(best_acc),
        'classes': classes_list,
        'version': model_version,
    }
    
    metadata_path = os.path.join(model_root, 'metadata.json')
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)
    print(f"Metadata saved to: {metadata_path}")
    
    print("\nTriton model structure:")
    print(f"  {model_root}/")
    print("  ├── config.pbtxt")
    print("  ├── labels.txt")
    print("  ├── metadata.json")
    print(f"  └── {model_version}/")
    print("      └── model.pt")
    print(f"\nFull path: {output_model_dir}/{model_name}")
    
    print("\nDone!")

## Testing

Uncomment and run the cell below to test the `trainOp()` function locally (the sample call uses 2 epochs for a quick smoke test):

In [None]:
# Test trainOp locally (uncomment to run)
# trainOp(
#     model_relative_path='model',
#     model_name='resnet_pytorch',
#     prefix='testprefix',
#     epochs=2,  # Use 2 epochs for quick testing
#     networks=3,  # ResNet-20
#     batch_size=128,
# )