In [0]:
%pip install -r requirements.txt
%restart_python

In [0]:
%run ../setup/00_setup

In [0]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import DataLoader
from accelerate import Accelerator
from accelerate.utils import set_seed
import mlflow
import logging
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from IPython.display import display, HTML
import json 
import tempfile

# Point the Hugging Face datasets cache at `cifar_cache`
# (not defined in this notebook; presumably provided by the setup notebook run above — confirm).
os.environ['HF_DATASETS_CACHE'] = cifar_cache

# Timestamped INFO-level logging for the whole notebook
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Fixed seed so repeated runs are comparable
set_seed(42)

# Report the runtime environment; query CUDA availability once and reuse the answer
_cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_available}")
if _cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

In [0]:
# Training hyperparameters and run settings read by the cells below
config = dict(
    batch_size=128,
    num_epochs=1,  # Reduced for demonstration
    learning_rate=0.001,
    weight_decay=1e-4,
    num_classes=10,
    save_every=1,
    experiment_name="resnet50_cifar10",
)

# CIFAR-10 class names; the plotting code indexes this tuple with the integer label
classes = tuple("plane car bird cat deer dog frog horse ship truck".split())

# Helper used to restrict MLflow setup/logging to a single process
def is_main_process():
    """Return True when the current process is the main process (rank 0).

    Resolution order:
      1. The notebook-level `accelerator` object, if one has been created and
         exposes `is_main_process`.
      2. `torch.distributed`, if it is available and a process group is initialized.
      3. Otherwise the run is treated as single-process and True is returned.

    Returns:
        bool: whether this process should perform main-process-only work.
    """
    # Look the accelerator up lazily: this function is defined before the
    # `accelerator` global is assigned, so a direct reference would raise
    # NameError if it were called too early.
    active_accelerator = globals().get("accelerator")
    if active_accelerator is not None and hasattr(active_accelerator, "is_main_process"):
        return active_accelerator.is_main_process
    # For distributed training with torch.distributed; is_available() guards
    # builds of PyTorch that ship without distributed support.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank() == 0
    # Default case (not distributed)
    return True

# Create the Accelerator up front so that is_main_process() can read its
# `is_main_process` attribute in the cells that follow.
accelerator = Accelerator()

# # Set up MLflow only on the main process
# if is_main_process():
#     mlflow.set_experiment(experiment_path)
#     print(f"MLflow experiment set to: {experiment_path}")
# else:
#     print("This is not the main process. MLflow logging will be disabled.")

# # Save the run id
# run_id = None

# if is_main_process():
#     mlflow_run = mlflow.start_run()
#     run_id = mlflow_run.info.run_id

In [0]:
%sh nvidia-smi


The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images. The dataset is divided into five training batches and one test batch, each with 10000 images. The test batch contains exactly 1000 randomly-selected images from each class. The training batches contain the remaining images in random order, but some training batches may contain more images from one class than another. Between them, the training batches contain exactly 5000 images from each class.

## Data Splits

#### Total Rows: 60000


| Split       | # of examples |
|-------------|---------------|
| Train       | 50,000    |
| Test        | 10,000        |

In [0]:
from utils import hf_dataset_utilities as hf_util

# Fetch the CIFAR-10 dataset through the project helper, using the cache directory
# configured in HF_DATASETS_CACHE. The result is indexed by split name further down
# ('train' and 'test').
# NOTE(review): trust_remote_code=True is forwarded to the helper; presumably this lets
# the dataset's own loading script execute — confirm this is intended for this source.
cifar_dataset = hf_util.hfds_download_volume(
  hf_cache = os.environ['HF_DATASETS_CACHE'],
  dataset_path= 'uoft-cs/cifar10',
  trust_remote_code = True, 
  disable_progress = False, 
)

In [0]:
# Build a dataset class from the project helper, telling it which record fields hold
# the image ("img") and the label ("label"). The class is instantiated below with a
# dataset split and a `transform` argument.
CIFARDataset = hf_util.create_torch_image_dataset(
  image_key="img",
  label_key="label"
)

In [0]:
# Per-channel normalization shared by the training and evaluation pipelines
_cifar_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training pipeline: random crop (with 4px padding) and horizontal flip, then tensor + normalize
ds_train_transforms = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        _cifar_normalize,
    ]
)

# Evaluation pipeline: no augmentation, only tensor conversion + normalize
ds_test_transforms = transforms.Compose([transforms.ToTensor(), _cifar_normalize])

In [0]:
# Wrap each split with its transform pipeline: augmentation for 'train', plain
# tensor conversion + normalization for 'test'.
train_dataset = CIFARDataset(cifar_dataset['train'], transform=ds_train_transforms)
test_dataset = CIFARDataset(cifar_dataset['test'], transform=ds_test_transforms)

In [0]:
# Spark's default parallelism, used below as the core count when sizing DataLoader workers
num_cores = spark.sparkContext.defaultParallelism
num_cores

# Set `num_executors` below by hand so that `num_workers` (DataLoader subprocesses) is sized correctly; reading `spark.executor.instances` from the Spark config may not work

In [0]:
# num_executors = int(spark.sparkContext.getConf().get("spark.executor.instances", "1"))
# Set by hand (see the commented-out Spark lookup above)
num_executors = 2

# Whole number of DataLoader workers: available cores split across executors
num_workers = num_cores // num_executors

In [0]:
# Settings common to both loaders
_loader_kwargs = dict(
    batch_size=config["batch_size"],
    num_workers=num_workers,
    pin_memory=True,
)

# Training batches are reshuffled; evaluation batches keep dataset order
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [0]:
# Display some sample images
def show_images(dataset, num_images=5):
    """Plot the first `num_images` samples of `dataset` in a single row.

    Each sample is read as `(img, label)`; `img` is converted with `.numpy()` and
    transposed from channel-first to channel-last, then the normalization applied
    by the transform pipelines is undone before display.

    Args:
        dataset: indexable dataset yielding `(image_tensor, label)` pairs.
        num_images (int): how many leading samples to show (must be >= 1).
    """
    # Same per-channel statistics used by transforms.Normalize above;
    # hoisted out of the loop because they never change.
    mean = np.array([0.4914, 0.4822, 0.4465])
    std = np.array([0.2023, 0.1994, 0.2010])

    # squeeze=False keeps `axes` a 2-D array even for a single subplot; without it
    # plt.subplots(1, 1) returns a bare Axes and indexing it fails.
    fig, axes = plt.subplots(1, num_images, figsize=(15, 3), squeeze=False)
    for i, ax in enumerate(axes[0]):
        img, label = dataset[i]
        # Denormalize the image
        img = img.numpy().transpose((1, 2, 0))
        img = np.clip(std * img + mean, 0, 1)

        ax.imshow(img)
        ax.set_title(f"{classes[label]}")
        ax.axis('off')
    plt.tight_layout()
    plt.show()

show_images(train_dataset)

In [0]:
# Build the model, loss, optimizer and scheduler, then hand everything to Accelerate.
print("Initializing ResNet50 model...")
# Start from torchvision's default pretrained ResNet50 weights
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# Replace the final fully connected layer so the output size matches config["num_classes"]
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, config["num_classes"])

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(), 
    lr=config["learning_rate"], 
    weight_decay=config["weight_decay"]
)
# Cosine schedule spanning the configured number of epochs (stepped once per epoch in train_model)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["num_epochs"])

# Enable MLflow autologging for PyTorch only on the main process
if is_main_process():
    mlflow.pytorch.autolog(log_models=True, log_every_n_epoch=1)

# Configure accelerator with MLflow tracking, but make it log only from the main process.
# This rebinds the notebook-level `accelerator` that was created in the configuration cell.
accelerator = Accelerator(log_with="mlflow" if is_main_process() else None)

# Record the id of an already-active MLflow run, or None when there is none
run_id = mlflow.active_run().info.run_id if mlflow.active_run() else None

# Initialize trackers only on the main process and only when no run is active yet.
# `experiment_path` is not defined in this notebook; presumably it comes from the setup run — confirm.
if is_main_process() and not run_id:
    accelerator.init_trackers(experiment_path, config=config)

# NOTE(review): the loaders, model, optimizer and scheduler names are rebound to their prepared
# versions here, so re-running this cell passes already-prepared loaders to prepare() again.
model, optimizer, train_loader, test_loader, scheduler = accelerator.prepare(
    model, optimizer, train_loader, test_loader, scheduler
)

print("Model and training environment prepared!")
print(f"Using {accelerator.num_processes} GPU(s)")
print(f"This process has rank: {accelerator.process_index}")
print(f"Is main process: {accelerator.is_main_process}")

In [0]:
# Capture the id of the MLflow run that is active at this point (None if there is none)
_active_mlflow_run = mlflow.active_run()
run_id = _active_mlflow_run.info.run_id if _active_mlflow_run else None
run_id

In [0]:
def _training_run_params():
    """Return the hyperparameters recorded on the MLflow run for this training job."""
    return {
        "model_type": "ResNet50",
        "batch_size": config["batch_size"],
        "epochs": config["num_epochs"],
        "learning_rate": config["learning_rate"],
        "weight_decay": config["weight_decay"],
        "optimizer": "Adam",
        "scheduler": "CosineAnnealingLR",
        "num_gpus": accelerator.num_processes,
    }


def _sum_across_processes(value):
    """Sum a Python scalar over all Accelerate processes and return a Python number."""
    local = torch.tensor([value], device=accelerator.device)
    return accelerator.gather(local).sum().item()


def train_model(run_id=None):
    """
    Train the model and log metrics, checkpoints and the best model with MLflow.

    Uses the notebook-level `model`, `optimizer`, `scheduler`, `criterion`,
    `train_loader`, `test_loader`, `accelerator`, `config` and `classes`.

    Args:
        run_id (str, optional): Existing MLflow run ID to continue tracking.
                               If None, the active run is reused when there is one,
                               otherwise a new run is created.

    Returns:
        tuple: (history dict, mlflow run_id). The history lists are only filled
        on the main process; the run_id is broadcast to every process.
    """
    print("Starting training...")
    best_accuracy = 0.0
    history = {
        'train_loss': [], 'train_acc': [],
        'test_loss': [], 'test_acc': [],
        'lr': []
    }

    # Attach to (or create) the MLflow run only on the main process
    if accelerator.is_main_process:
        active_run = mlflow.active_run()
        if active_run is not None:
            # A run is already active (e.g. opened by the Accelerate tracker): everything
            # logged through the fluent API goes there, so report that run's id.
            if run_id and active_run.info.run_id != run_id:
                print(f"Requested run {run_id} is not the active MLflow run; "
                      f"logging to active run {active_run.info.run_id} instead")
            run_id = active_run.info.run_id
            print(f"Resuming MLflow run with ID: {run_id}")
        elif run_id:
            # Re-open the requested run so that logging really lands on it
            mlflow.start_run(run_id=run_id)
            print(f"Resuming MLflow run with ID: {run_id}")
        else:
            run_id = mlflow.start_run().info.run_id
            print(f"New MLflow run started with ID: {run_id}")

        mlflow.log_params(_training_run_params())

    # Make sure all processes are synced (no-op when running on a single process)
    accelerator.wait_for_everyone()

    for epoch in range(config["num_epochs"]):
        # ---- Training phase ----
        model.train()
        loss_sum = 0.0  # sum of per-sample losses seen by this process
        correct = 0
        total = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']}")
        for inputs, targets in progress_bar:
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            accelerator.backward(loss)
            optimizer.step()

            # `criterion` returns the batch mean, so weight it by the batch size to get
            # a quantity that can be averaged correctly over samples and processes.
            batch_size = targets.size(0)
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += outputs.argmax(dim=1).eq(targets).sum().item()

            progress_bar.set_postfix({
                'loss': loss_sum / total,
                'acc': 100. * correct / total
            })

        # Combine the per-process sums; max(..., 1) guards against an empty loader
        loss_sum = _sum_across_processes(loss_sum)
        total = _sum_across_processes(total)
        correct = _sum_across_processes(correct)

        train_accuracy = 100. * correct / max(total, 1)
        train_loss = loss_sum / max(total, 1)

        # ---- Evaluation phase ----
        # NOTE(review): with several processes the prepared loader may repeat samples to
        # even out the last batch; accelerator.gather_for_metrics would drop them — confirm.
        model.eval()
        loss_sum = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in test_loader:
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                batch_size = targets.size(0)
                loss_sum += loss.item() * batch_size
                total += batch_size
                correct += outputs.argmax(dim=1).eq(targets).sum().item()

        loss_sum = _sum_across_processes(loss_sum)
        total = _sum_across_processes(total)
        correct = _sum_across_processes(correct)

        test_accuracy = 100. * correct / max(total, 1)
        test_loss = loss_sum / max(total, 1)

        # Read the learning rate used during this epoch before stepping the scheduler
        current_lr = scheduler.get_last_lr()[0]
        scheduler.step()

        epoch_metrics = {
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "test_loss": test_loss,
            "test_accuracy": test_accuracy,
            "learning_rate": current_lr
        }

        # Store history and log to MLflow (only on main process to avoid duplicates)
        if accelerator.is_main_process:
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_accuracy)
            history['test_loss'].append(test_loss)
            history['test_acc'].append(test_accuracy)
            history['lr'].append(current_lr)

            mlflow.log_metrics(epoch_metrics, step=epoch)

        # Log with accelerator (it handles the main process check internally); pass the
        # same step as above so both logging paths agree on the epoch index.
        accelerator.log(epoch_metrics, step=epoch)

        # Print progress (from all processes for debugging, but could limit to main)
        print(f"[Rank {accelerator.process_index}] Epoch {epoch+1}/{config['num_epochs']} - "
              f"Train Loss: {train_loss:.4f}, "
              f"Train Acc: {train_accuracy:.2f}%, "
              f"Test Loss: {test_loss:.4f}, "
              f"Test Acc: {test_accuracy:.2f}%")

        # Make sure all processes are synced before saving checkpoints
        accelerator.wait_for_everyone()

        # Save checkpoint using MLflow (only on main process)
        is_last_epoch = epoch == config["num_epochs"] - 1
        checkpoint_due = (epoch + 1) % config["save_every"] == 0 or is_last_epoch
        if accelerator.is_main_process and checkpoint_due:
            unwrapped_model = accelerator.unwrap_model(model)

            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': unwrapped_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'test_accuracy': test_accuracy,
            }

            mlflow.pytorch.log_state_dict(checkpoint, f"checkpoints/epoch_{epoch+1}")
            print(f"Checkpoint saved to MLflow run {run_id}, epoch {epoch+1}")

        # Save best model using MLflow (only on main process)
        if accelerator.is_main_process and test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            unwrapped_model = accelerator.unwrap_model(model)

            mlflow.pytorch.log_model(
                unwrapped_model,
                "best_model"
            )

            # Also log model metadata next to the model artifact
            model_info = {
                'epoch': epoch + 1,
                'test_accuracy': test_accuracy,
                'config': config,
                'classes': classes
            }
            mlflow.log_dict(model_info, "best_model/metadata.json")

            print(f"New best model saved with accuracy: {best_accuracy:.2f}%")

    # Log the final history while the run is still open (only on main process)
    if accelerator.is_main_process:
        mlflow.log_dict(history, "training_history.json")
        print(f"Training completed. Best accuracy: {best_accuracy:.2f}%")

    # Let Accelerate close its trackers
    accelerator.end_training()

    # Make sure all processes are synced before returning
    accelerator.wait_for_everyone()

    # Share the run id chosen on rank 0 with every other process. Broadcasting the
    # Python object avoids hand-encoding the string into a fixed-size tensor.
    if accelerator.num_processes > 1:
        payload = [run_id]
        torch.distributed.broadcast_object_list(payload, src=0)
        run_id = payload[0]

    # Close the MLflow run (only on main process)
    if accelerator.is_main_process:
        mlflow.end_run()

    # Return the history and run_id to all processes
    return history, run_id

In [0]:
# Run the training process, continuing the run captured in `run_id` above when there is one.
# `run_id` is rebound to the id returned by train_model; the model-loading cell below uses it.
history, run_id = train_model(run_id)

print(f"MLflow run ID: {run_id}")

In [0]:
import mlflow
import mlflow.pytorch

# Set the registry URI to ensure correct path resolution
mlflow.set_registry_uri("databricks-uc")

# Load the "best_model" artifact that train_model logged on the run identified by run_id.
# NOTE(review): this rebinds `model`, which train_model reads as a global; re-running the
# training cell after this one would train the reloaded model rather than the prepared one.
model_uri = f"runs:/{run_id}/best_model"
model = mlflow.pytorch.load_model(model_uri)

# Display the model architecture
print(model)

In [0]:
import torch
import matplotlib.pyplot as plt

# Get one image (and its label) from the first test batch
test_iter = iter(test_loader)
images, labels = next(test_iter)
image = images[0].unsqueeze(0)  # Add batch dimension
actual_idx = int(labels[0])

# The reloaded model is not guaranteed to sit on the same device as the batches
# produced by the prepared loader, so move the input to wherever the model lives.
model_device = next(model.parameters()).device

# Run inference
model.eval()
with torch.no_grad():
    output = model(image.to(model_device))
    _, predicted = torch.max(output, 1)
predicted_idx = int(predicted.item())

# Undo the normalization applied by the test transform so the colours render correctly
mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)
display_img = (image.squeeze(0).cpu() * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()

# Display the image with human-readable class names instead of raw tensor values
plt.imshow(display_img)
plt.title(f"Predicted: {classes[predicted_idx]}, Actual: {classes[actual_idx]}")
plt.axis('off')
plt.show()

### Clear GPU Memory

In [0]:
%restart_python