In [2]:

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
import timm
import torch.nn as nn
import torch.optim as optim
import time
import copy

# --- Device selection & reproducibility ---
# Prefer Apple-silicon MPS, then an NVIDIA GPU, then the CPU, and pick a DataLoader
# worker count per backend (0 = load batches in the main process).
SEED = 42
# Seed torch's global RNG (all devices) so torch-driven randomness — data
# augmentation, DataLoader shuffling, initialisation of the new classifier head —
# is repeatable between runs. With num_workers=0 no per-worker seeding is needed.
torch.manual_seed(SEED)

if torch.backends.mps.is_available():
    device = torch.device("mps")
    # num_workers > 0 rarely helps with MPS and can cause hangs; keep it at 0.
    num_workers_dataloader = 0
elif torch.cuda.is_available():  # NVIDIA GPU (not expected on Apple-silicon laptops)
    device = torch.device("cuda:0")
    num_workers_dataloader = 4  # tune for the host CPU / GPU setup
else:
    device = torch.device("cpu")
    # 0 workers is the most robust choice on CPU; 2-4 workers may shorten data
    # loading once batches are flowing, at some risk of hangs.
    num_workers_dataloader = 0

print(f"Using device: {device}")
print(f"DataLoader num_workers set to: {num_workers_dataloader}")

# --- Step 3: Data Preparation (CIFAR-10) ---
# Per-channel (R, G, B) mean and std used to normalise CIFAR-10 images.
# NOTE(review): this std triple is the widely copied one; the std computed over all
# training pixels is usually quoted as (0.2470, 0.2435, 0.2616). Kept as-is so
# results stay comparable with earlier runs — verify before changing.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

IMAGE_SIZE = 224  # network input resolution after cropping
BATCH_SIZE = 32

# CIFAR-10 images are 32x32 and are upscaled to IMAGE_SIZE. Training applies random
# crop + horizontal-flip augmentation; validation is deterministic (resize the
# shorter side to 256, then centre-crop IMAGE_SIZE).
# NOTE(review): RandomResizedCrop's default scale=(0.08, 1.0) can keep as little as
# 8% of an already tiny 32x32 image; consider a larger minimum scale.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ]),
}

# Load CIFAR-10 (download=True fetches it into data_root_path if it is not there yet).
data_root_path = './data'
os.makedirs(data_root_path, exist_ok=True)

print("Loading train_dataset...")
train_dataset = datasets.CIFAR10(root=data_root_path, train=True, download=True, transform=data_transforms['train'])
print("train_dataset loaded.")

print("Loading val_dataset...")
val_dataset = datasets.CIFAR10(root=data_root_path, train=False, download=True, transform=data_transforms['val'])
print("val_dataset loaded.")

image_datasets = {'train': train_dataset, 'val': val_dataset}

# Pinned (page-locked) host memory only speeds up host-to-device copies for CUDA.
# On MPS, PyTorch's DataLoader has reported pin_memory as unsupported and ignored it.
use_pin_memory = device.type == 'cuda'

print("Creating DataLoaders...")
dataloaders = {
    phase: DataLoader(
        image_datasets[phase],
        batch_size=BATCH_SIZE,
        shuffle=(phase == 'train'),  # shuffle only the training split
        num_workers=num_workers_dataloader,
        pin_memory=use_pin_memory,
    )
    for phase in ['train', 'val']
}
print("DataLoaders created.")

# Dataset sizes and class names (CIFAR10 exposes its label names via .classes).
dataset_sizes = {phase: len(image_datasets[phase]) for phase in ['train', 'val']}
class_names = train_dataset.classes

print(f"Dataset sizes: {dataset_sizes}")
print(f"Class names: {class_names}")
num_classes = len(class_names)

# --- Step 4: Model Loading and Modification (timm) ---
print("Creating model...")
model_name = 'resnet18'
# Ask timm for a classifier sized for this dataset directly: with pretrained=True
# and num_classes different from the pretrained head, timm loads the backbone
# weights and re-initialises only the classification layer, for any architecture.
# This replaces hand-probing fc/classifier/head, which silently kept a 1000-class
# head when `head` was neither an nn.Linear nor an nn.Sequential ending in one.
model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
# Width of the pooled features that feed the classifier (512 for resnet18).
num_ftrs = model.num_features
print("Model created and modified.")

# Move the model to the chosen device (MPS, CUDA or CPU).
model = model.to(device)

# Fine-tune all parameters with Adam; StepLR multiplies the learning rate by 0.1
# every 7 scheduler steps (train_model steps it once per training epoch).
LEARNING_RATE = 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# --- Step 5: Training Loop ---
def train_model(model, criterion, optimizer, scheduler, num_epochs=10, *,
                loaders=None, device=None, log_every=100):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a 'train' phase (train mode, gradients on, one optimizer step
    per batch, one ``scheduler.step()`` at the end of the phase) followed by a
    'val' phase (eval mode, no gradients). The state_dict from the epoch with the
    highest validation accuracy is restored before returning.

    Args:
        model: the ``nn.Module`` to train; it is updated in place.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of train+val epochs.
        loaders: mapping with 'train' and 'val' keys; each value must support
            ``len()`` and yield ``(inputs, labels)`` batches. Defaults to the
            module-level ``dataloaders``.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if it has none).
        log_every: print a progress line every this many batches, plus the
            first and last batch of each phase.

    Returns:
        The same ``model`` object with the best validation weights loaded.
    """
    if loaders is None:
        loaders = dataloaders  # module-level DataLoaders built earlier in the notebook
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is equivalent to eval()

            running_loss = 0.0
            running_corrects = 0
            n_samples = 0

            print(f"Starting {phase} phase for Epoch {epoch}...")
            total_batches = len(loaders[phase])

            for i, (inputs, labels) in enumerate(loaders[phase], start=1):
                if i == 1 or i % log_every == 0 or i == total_batches:
                    print(f"   {phase.capitalize()} Batch [{i}/{total_batches}]")

                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    optimizer.zero_grad()
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    preds = outputs.argmax(dim=1)
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                batch_size = inputs.size(0)
                # .item() converts to Python numbers, so the epoch statistics are
                # plain floats rather than tensors living on the training device.
                running_loss += loss.item() * batch_size
                running_corrects += (preds == labels).sum().item()
                n_samples += batch_size

            print(f"Finished {phase} phase for Epoch {epoch}.")

            if is_train:
                scheduler.step()

            # Average over the samples actually seen (the full dataset for a
            # DataLoader without drop_last); an empty loader yields 0.0.
            epoch_loss = running_loss / n_samples if n_samples else 0.0
            epoch_acc = running_corrects / n_samples if n_samples else 0.0

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model

# Train and persist the best weights. In a notebook cell __name__ is '__main__', so
# this runs when the cell is executed; it is skipped when the file is imported.
if __name__ == '__main__':
    # --- Step 6: Saving the Trained Model ---
    model_save_path = 'app/image_classification_model.pth'
    # Create the output directory *before* the long training run so a bad path or
    # missing permission fails immediately rather than after every epoch has run.
    # (os.makedirs('') would raise, so skip it when the path has no directory part.)
    save_dir = os.path.dirname(model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model_trained = train_model(model, criterion, optimizer, scheduler, num_epochs=10)

    # Save only the state_dict (weights) of the best-performing model.
    torch.save(model_trained.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

Using device: mps
DataLoader num_workers set to: 0
Loading train_dataset...
train_dataset loaded.
Loading val_dataset...
val_dataset loaded.
Creating DataLoaders...
DataLoaders created.
Dataset sizes: {'train': 50000, 'val': 10000}
Class names: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
Creating model...
Model created and modified.

Epoch 0/9
----------
Starting train phase for Epoch 0...
   Train Batch [1/1563]
   Train Batch [100/1563]
   Train Batch [200/1563]
   Train Batch [300/1563]
   Train Batch [400/1563]
   Train Batch [500/1563]
   Train Batch [600/1563]
   Train Batch [700/1563]
   Train Batch [800/1563]
   Train Batch [900/1563]
   Train Batch [1000/1563]
   Train Batch [1100/1563]
   Train Batch [1200/1563]
   Train Batch [1300/1563]
   Train Batch [1400/1563]
   Train Batch [1500/1563]
   Train Batch [1563/1563]
Finished train phase for Epoch 0.
train Loss: 0.9862 Acc: 0.6561
Starting val phase for Epoch 0...
   Val Batch [