In [18]:
"""
1. Transform Configuration
"""
from torchvision import transforms

# Training-time preprocessing: light augmentation first, then tensor conversion
# and per-channel standardisation.
train_tf = transforms.Compose(
    [
        # pad every side by 4 px, then cut a random 32x32 window out of the padded image
        transforms.RandomCrop(size=32, padding=4),
        # mirror the image left-right with probability 0.5
        transforms.RandomHorizontalFlip(p=0.5),
        # convert the image into a tensor
        transforms.ToTensor(),
        # standardise each channel; the display cell undoes this with the same mean/std values
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


In [None]:
"""
2. Load Data & Analysis & Display
"""
from torchvision.datasets import CIFAR100
from torch.utils.data import Subset
from pathlib import Path
import matplotlib.pyplot as plt
import torch

# Dataset location: <parent of the current working directory>/data
project_root = Path.cwd().parent
data_root = project_root / "data"

train_ds = CIFAR100(
    root=data_root,
    train=True,     # create dataset from train set, otherwise will be test set
    transform=train_tf
)

# use a small slice (samples 100..199) to display image content
train_ds_small = Subset(train_ds, range(100, 200))

fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(32, 32))

# Normalisation statistics, built once instead of on every loop iteration.
# They must mirror the mean/std handed to transforms.Normalize in train_tf.
norm_mean = torch.tensor([0.485, 0.456, 0.406])
norm_std = torch.tensor([0.229, 0.224, 0.225])

# each item: image in tensor form | class it belongs to
for idx, (img_tensor, cls_num) in enumerate(train_ds_small):
    img = img_tensor.permute(1, 2, 0)       # (C, H, W) -> (H, W, C), the layout imshow plots
    # Undo Normalize, then clamp: float round-off can leave values marginally
    # outside [0, 1], which makes imshow emit a clipping warning per image.
    img = (img * norm_std + norm_mean).clamp(0.0, 1.0)
    row, col = idx // 10, idx % 10
    axes[row, col].imshow(img)
    axes[row, col].set_title(f"Class-{cls_num}")



In [None]:
"""
3. DataLoader for torch
"""

from torch.utils.data import DataLoader

# Shuffled, batched iterator over the complete training set.
train_loader = DataLoader(
    dataset=train_ds,
    shuffle=True,
    batch_size=128,         # IMPORTANT: keep equal to the batch_size handed to the model
    pin_memory=False,       # left off for MPS; NOTE(review): pinned host memory is presumably only a CUDA transfer optimisation - confirm
    num_workers=0,          # load in the main process; MPS and multiprocessing don't work well together
)
print(f"[ALL]Training dataset size: {len(train_ds)}")

[ALL]Training dataset size: 50000


In [None]:
# Set up loggings
import logging, time, sys

# Log line layout: timestamp, level, source file, process id, thread id, message.
# Named LOG_FORMAT (not `format`) so the built-in format() stays usable in later cells.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(filename)s - PID:%(process)d - TID:%(thread)d - %(message)s'

# A stable logger name plus a handler reset keeps this cell safe to re-run:
# handlers from a previous execution are dropped instead of a brand-new logger
# being registered (and never released) on every run.
logger = logging.getLogger(__name__)
logger.handlers.clear()
logger.setLevel(logging.DEBUG)
# Do not forward records to the root logger, otherwise each message could be printed twice.
logger.propagate = False
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter(LOG_FORMAT))
logger.addHandler(handler)

logger.debug("Logger Initialization Completed")

2025-06-30 01:17:50,566 - DEBUG - 1454778220.py - PID:3660 - TID:8784977664 - Logger Initialization Completed


In [None]:
"""
4. Define Model
"""

from torch import nn, optim
import mlflow
from datetime import datetime
import torchvision
from torchvision.models import MobileNet_V2_Weights
from mlflow.models import infer_signature


class MyMobileNetV2(nn.Module):
    """Pretrained torchvision MobileNetV2 with a replaced classifier head and a training loop.

    The wrapped network lives in ``self._model``. Training progress is written to the
    notebook-level ``logger`` and, when ``use_mlflow`` is True, to an MLflow tracking server.
    """

    def __init__(self, num_epochs=10, batch_size=32, num_classes=100, learning_rate=5e-3, use_mlflow=True):
        """
        Args:
            num_epochs: number of passes over the data made by ``fit`` / ``train(loader)``
            batch_size: only recorded as an MLflow parameter; the effective batch size
                is whatever the DataLoader passed to ``fit`` uses
            num_classes: output size of the replaced classifier layer (100 for CIFAR100)
            learning_rate: learning rate given to the SGD optimizer
            use_mlflow: when False, nothing is sent to MLflow
        """
        super().__init__()
        self._num_classes = num_classes     # by default 100 for CIFAR100
        self._use_mlflow = use_mlflow
        self._best_acc = -1     # best epoch accuracy seen so far; -1 so the first epoch always wins

        # Model skeleton
        self._model = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)

        # IMAGENET classifier (input layer, output layer)
        # Sequential((0): Dropout(p=0.2, inplace=True)
        #            (1): Linear(in_features=1280, out_features=1000, bias=True))
        # the Linear layer is swapped for one with `num_classes` outputs
        self._replace_classifier_head(self._model, num_classes)

        self._device = self._set_device()
        # model weights are created on CPU by default; move them to the selected device
        self._model.to(self._device)

        # Hyperparameters (must be assigned before the optimizer is built, it reads the learning rate)
        self._batch_size = batch_size
        self._num_epochs = num_epochs
        self._learning_rate = learning_rate

        # Loss function & Optimizer
        self._criterion = self._set_criterion()
        self._optimizer = self._set_optimizer()

    @staticmethod
    def _replace_classifier_head(model, num_classes):
        """Swap ``model.classifier[1]`` for a fresh Linear layer with ``num_classes`` outputs.

        The input width is read from the layer being replaced instead of being hard-coded.
        Returns the same (mutated) model.
        """
        in_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(in_features, num_classes)
        return model

    @staticmethod
    def _dataset_name(data_loader):
        """Class name of the dataset behind ``data_loader``, or "unknown" if it exposes none.

        Wrappers that expose the wrapped dataset as ``.dataset`` (e.g. ``Subset``) are unwrapped.
        """
        dataset = getattr(data_loader, "dataset", None)
        while hasattr(dataset, "dataset"):
            dataset = dataset.dataset
        return "unknown" if dataset is None else dataset.__class__.__name__

    def _set_device(self):
        """Pick the compute device: MPS first, then CUDA, otherwise CPU."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        elif torch.cuda.is_available():
            return torch.device("cuda")
        else:
            return torch.device("cpu")

    def _set_optimizer(self):
        """Build the SGD optimizer over the wrapped model, using the configured learning rate."""
        # TODO: SGD understanding, along with Adam
        return optim.SGD(self._model.parameters(), lr=self._learning_rate, momentum=0.9, weight_decay=4e-5)

    def _set_criterion(self):
        """Build the loss function (cross entropy over class logits)."""
        return nn.CrossEntropyLoss()

    @property
    def optimizer(self):
        """Class name of the optimizer in use (a string, not the optimizer object)."""
        return self._optimizer.__class__.__name__

    # Train Epoch for 1 Loop
    def train_epoch(self, data_loader, epoch_idx):
        """
        Run one optimisation pass over ``data_loader``.

        For every batch: forward pass, loss, backward pass, optimizer step. The model
        output is treated as logits of shape (batch, classes); ``argmax(dim=1)`` turns
        it into one predicted class index per sample, which is compared with the
        integer targets to count correct predictions.

        Args:
            data_loader: pytorch DataLoader instance for image iteration
            epoch_idx: current epoch number (0-based; logged as epoch_idx + 1)
        Return:
            epoch_loss: the average loss for current epoch: total loss / total batches
            epoch_acc: the average accuracy for current epoch: total correct prediction / total true labels
        Raises:
            ValueError: if ``data_loader`` yields no batches
        """
        num_batches = len(data_loader)
        if num_batches == 0:
            # fail with a clear message instead of a ZeroDivisionError at the end of the epoch
            raise ValueError("data_loader is empty: cannot train on zero batches")

        logger.info(f"Starting epoch {epoch_idx+1}, total batches: {num_batches}")

        # Set model mode to train, so the Dropout, BatchNorm are up to work
        self._model.train()
        epoch_total_loss = 0.0  # summed batch losses of the whole epoch
        running_loss = 0.0      # summed batch losses of the current 10-batch window
        running_correct = 0     # correct predictions so far in this epoch (never reset)
        running_total = 0       # samples seen so far in this epoch (never reset)

        for idx, (images, targets) in enumerate(data_loader):
            images = images.to(self._device)
            # labels represented in int, not one-hot encoding
            targets = targets.to(self._device)
            # Clears(Reset) accumulated gradients from the previous iteration
            self._optimizer.zero_grad()
            # inference result from the model, it will auto call forward(images)
            logits = self._model(images)
            # Calculate the loss, compare Inference Prediction vs. True Labels
            loss = self._criterion(logits, targets)
            # backpropagate to compute the gradients
            loss.backward()
            # Update weights and parameters using latest gradients from backward()
            self._optimizer.step()

            predictions = logits.argmax(dim=1)
            batch_loss = loss.item()

            # how many predictions match true labels
            running_correct += (predictions == targets).sum().item()
            # total number of the true labels
            running_total += targets.size(0)
            running_loss += batch_loss
            epoch_total_loss += batch_loss

            # print out progress every 10 batches
            if idx % 10 == 9:
                avg_loss = running_loss / 10
                curr_acc = running_correct / running_total
                logger.info(f"Epoch {epoch_idx + 1}| Batch {idx + 1} | Avg Loss {avg_loss: .4f} | Accuracy: {curr_acc:.4f}")
                running_loss = 0.0      # reset for every 10 batches

        # return loss and accuracy for current training epoch
        epoch_loss = epoch_total_loss / num_batches     # total loss / total batches
        epoch_acc = running_correct / running_total     # total correct prediction / total true labels

        logger.info(f"Epoch loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")
        return epoch_loss, epoch_acc

    # MLFlow Setup
    def _setup_mlflow(self, dataset_name=None):
        """Initialize Mlflow experiment, start a run and log the hyperparameters.

        Args:
            dataset_name: name recorded for the training dataset; "unknown" when omitted.
                Passed in by the caller so this class does not depend on notebook globals.
        """
        if dataset_name is None:
            dataset_name = "unknown"
        model_name = self._model.__class__.__name__

        mlflow.set_tracking_uri("http://127.0.0.1:5555")
        # setup name of experiment
        mlflow.set_experiment(f"{model_name}-{dataset_name}")
        # this starts the active run used by all following mlflow.log_*() calls
        mlflow.start_run(run_name=datetime.now().strftime("%Y%m%d_%H%M%S"))

        mlflow.log_params({
            "model": model_name,
            "dataset": dataset_name,
            "batch_size": self._batch_size,
            "learning_rate": self._learning_rate,
            "epochs": self._num_epochs,
            "device": self._device.type,
            "optimizer": self._optimizer.__class__.__name__,
            "num_classes": self._num_classes
        })

    def _log_metrics(self, epoch, epoch_loss, epoch_acc):
        """Log the epoch's loss, accuracy and current learning rate to Mlflow (no-op when disabled)."""
        if not self._use_mlflow:
            return

        metrics = {
            "train_loss": epoch_loss,
            "train_accuracy": epoch_acc,
            "learning_rate": self._optimizer.param_groups[0]["lr"]
        }

        mlflow.log_metrics(metrics, step=epoch)

    def _log_best_model(self, epoch_acc):
        """Log the model to Mlflow if ``epoch_acc`` beats the best accuracy so far (no-op when disabled)."""
        if not self._use_mlflow:
            return

        # Only log current model if the accuracy is better
        if epoch_acc <= self._best_acc:
            return
        # update to better accuracy
        self._best_acc = epoch_acc

        # The signature describes the model's input/output schema and is inferred from
        # one sample forward pass. Eval mode + no_grad keep that pass from using
        # training-time Dropout/BatchNorm behaviour and from tracking gradients.
        # NOTE(review): the sample is 224x224 although the notebook trains on 32x32 crops -
        # confirm which input shape the logged signature should advertise.
        self._model.eval()
        with torch.no_grad():
            sample_input = torch.randn(1, 3, 224, 224).to(self._device)
            signature = infer_signature(
                    sample_input.cpu().numpy(),
                    self._model(sample_input).detach().cpu().numpy()
            )
        self._model.train()   # Return to training mode

        # Log the model, the first one will always log, then keep overriding for better ones
        mlflow.pytorch.log_model(
            self._model,
            artifact_path="best_model",
            signature=signature
        )

    def fit(self, train_loader):
        """Main training loop with Mlflow integration.

        Runs ``num_epochs`` epochs over ``train_loader``. Errors are logged with their
        traceback and not re-raised; the Mlflow run is always closed, marked FAILED
        when an error occurred.
        """
        run_status = "FINISHED"
        try:
            logger.info(f"Training {self._num_epochs} epochs on {self._device}")
            # init mlflow, new run & experiment
            if self._use_mlflow:
                self._setup_mlflow(self._dataset_name(train_loader))

            for epoch in range(self._num_epochs):
                epoch_loss, epoch_acc = self.train_epoch(train_loader, epoch)
                logger.info(f"Epoch {epoch + 1}/{self._num_epochs}: Avg Loss: {epoch_loss:.4f} | Accuracy: {epoch_acc:.4f}")

                # both helpers return immediately when mlflow is disabled
                self._log_metrics(epoch, epoch_loss, epoch_acc)
                self._log_best_model(epoch_acc)

        except Exception as e:
            run_status = "FAILED"
            # logger.exception keeps the traceback, which a plain error message would drop
            logger.exception(f"Train process failed, error: {e}")
        finally:
            if self._use_mlflow:
                logger.debug("Mlflow: Ending current run...")
                mlflow.end_run(status=run_status)
                logger.debug("Mlflow: current run ended successfully")

    def train(self, train_loader=True):
        """Train on a DataLoader, or switch train/eval mode when given a bool.

        This method shadows ``nn.Module.train(mode)``, which ``nn.Module.eval()`` and
        parent modules call with a bool. A bool argument is therefore forwarded to the
        base implementation (returning ``self``); anything else is treated as the
        training DataLoader and handed to ``fit`` (returning None, as before).
        """
        if isinstance(train_loader, bool):
            return super().train(train_loader)
        self.fit(train_loader)


In [None]:
"""
5. Train
"""

# The model only records batch_size (it is logged as a run parameter); read it from
# the loader so the recorded value cannot drift from the one actually used.
m = MyMobileNetV2(batch_size=train_loader.batch_size, num_epochs=10)
# runs the full training loop; progress is reported through the notebook logger
m.train(train_loader)


2025-06-30 01:17:58,729 - INFO - 1561975117.py - PID:3660 - TID:8784977664 - Training 10 epochs on mps
2025-06-30 01:17:58,770 - INFO - 1561975117.py - PID:3660 - TID:8784977664 - Starting epoch 1, total batches: 391
2025-06-30 01:18:00,645 - INFO - 1561975117.py - PID:3660 - TID:8784977664 - Epoch 1| Batch 10 | Avg Loss  4.6819 | Accuracy: 0.0141
2025-06-30 01:18:01,100 - INFO - 1561975117.py - PID:3660 - TID:8784977664 - Epoch 1| Batch 20 | Avg Loss  4.6430 | Accuracy: 0.0141
2025-06-30 01:18:01,552 - INFO - 1561975117.py - PID:3660 - TID:8784977664 - Epoch 1| Batch 30 | Avg Loss  4.5995 | Accuracy: 0.0133
2025-06-30 01:18:02,009 - INFO - 1561975117.py - PID:3660 - TID:8784977664 - Epoch 1| Batch 40 | Avg Loss  4.5748 | Accuracy: 0.0146
2025-06-30 01:18:02,465 - INFO - 1561975117.py - PID:3660 - TID:8784977664 - Epoch 1| Batch 50 | Avg Loss  4.5199 | Accuracy: 0.0208
2025-06-30 01:18:02,919 - INFO - 1561975117.py - PID:3660 - TID:8784977664 - Epoch 1| Batch 60 | Avg Loss  4.5120 | Ac