In [172]:
import sys

# Step 1: Data Pipeline, Preprocessing the data
from torchvision import transforms
from torchvision.datasets import CIFAR100
from pathlib import Path
from torch.utils.data import DataLoader

# Resolve the dataset directory relative to the notebook's working directory:
# the notebook is expected to live one level below the project root.
project_root = Path.cwd().parent
data_root = project_root / "data"

# Training-time preprocessing. Compose takes a list of transform instances and
# applies them in the listed order.
train_tf = transforms.Compose([
    transforms.Resize(224),  # upsample CIFAR from 32x32 to 224x224, the input size EfficientNet-B0 expects
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    # NOTE(review): ColorJitter() with no arguments appears to use zero jitter for
    # brightness/contrast/saturation/hue, i.e. it may be a no-op — confirm and pass
    # explicit ranges if colour augmentation is actually wanted.
    transforms.ColorJitter(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet normalization
])

# Dataset wrapping the CIFAR-100 training split.
train_ds = CIFAR100(
    root=data_root,         # directory that stores the actual data files
    train=True,             # build the dataset from the training set (False -> test set)
    download=True,          # fetch the archive on first run so a fresh checkout works; skipped when files already exist
    transform=train_tf
)

from torch.utils.data import Subset
import random
# Tunables for the quick-iteration subset and the loader.
SUBSET_SIZE = 10000   # number of training images kept for fast experiments
SUBSET_SEED = 42      # fixed seed so the same subset is drawn on every run
BATCH_SIZE = 8

# Draw the subset from a dedicated, seeded RNG: the selection is reproducible
# across kernel restarts and does not disturb the global `random` state.
subset_rng = random.Random(SUBSET_SEED)
indices = subset_rng.sample(range(len(train_ds)), min(SUBSET_SIZE, len(train_ds)))
train_ds_small = Subset(train_ds, indices)

# Swap `train_ds_small` for `train_ds` to train on the full training split.
train_loader = DataLoader(
    train_ds_small,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,       # MPS and multiprocessing don't work well together, might slow down, use 0
    pin_memory=False     # page-locked host memory is a CUDA transfer optimisation; keep it off for MPS/CPU
)

print(f"[ALL]Training dataset size: {len(train_ds)}")
print(f"[SMALL]Training dataset size: {len(train_ds_small)}")

[ALL]Training dataset size: 50000
[SMALL]Training dataset size: 10000


In [173]:
# Set up loggings
import logging, time

# Log line layout. Named LOG_FORMAT (not `format`) so the builtin format() is not shadowed
# for the rest of the notebook.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(filename)s - PID:%(process)d - TID:%(thread)d - %(message)s'

# Reuse one named logger and reset its handlers on every run of this cell. Re-running the
# cell therefore neither duplicates each log line nor registers a brand-new logger object
# with the logging module each time.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# A logger obtained via getLogger() also forwards records to the root logger by default;
# propagate = False keeps the output limited to the handler attached below.
logger.propagate = False
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter(LOG_FORMAT))
logger.addHandler(handler)

logger.debug("Logger Initialization Completed")

2025-06-22 21:57:17,036 - DEBUG - 2310805002.py - PID:61797 - TID:8386830080 - Logger Initialization Completed


In [179]:
import torch.backends.mps

# Step 2: Model definition
import torchvision.models as models
from torch import nn, optim
from torchvision.models import EfficientNet_B0_Weights
import mlflow
from datetime import datetime
from mlflow.models import infer_signature

class MyEfficientNet(nn.Module):
    """
    ImageNet-pretrained EfficientNet-B0 with its classifier head replaced, plus a small
    training loop and optional MLflow tracking.

    Complexity is lowest for B0
    "Model","phi","Resolution","Parameters (M)","FLOPs (B)","Top-1 Accuracy (ImageNet)"
    "EfficientNet-B0","0","224x224","5.3","0.39","~77.1%"
    "EfficientNet-B1","1","240x240","7.8","0.70","~79.1%"
    "EfficientNet-B2","2","260x260","9.2","1.0","~80.1%"
    "EfficientNet-B3","3","300x300","12","1.8","~81.6%"
    "EfficientNet-B4","4","380x380","19","4.2","~82.9%"
    "EfficientNet-B5","5","456x456","30","9.9","~83.7%"
    "EfficientNet-B6","6","528x528","43","19","~84.0%"
    "EfficientNet-B7","7","600x600","66","37","~84.3%"
    # fine-tuning with pretrained weights, use lr=1e-4, too big is not good for fine-tuning

    Args:
        batch_size: fallback batch size recorded in MLflow; replaced by the data loader's
            own ``batch_size`` when training starts, if the loader exposes one.
        num_epochs: number of passes over the training loader.
        num_classes: output size of the replaced classifier head.
        learning_rate: Adam learning rate.
        use_mlflow: when True, params/metrics/best model are logged to MLflow.
    """
    def __init__(self, batch_size=32, num_epochs=10, num_classes=100, learning_rate=1e-4, use_mlflow=True):
        super().__init__()
        self._num_classes = num_classes
        self._use_mlflow = use_mlflow
        self._best_acc = -1     # below any real accuracy, so the first epoch always counts as "best"

        # Model Skeleton
        # train from scratch: no pretrained weights
        # models.efficientnet_b0(num_classes=self._num_classes, weights=None)
        # Here we use pretrained weights as it's too costly and slow to train this, no num classes needed
        self._model = models.efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)

        # IMAGENET classifier
        # Sequential((0): Dropout(p=0.2, inplace=True)
        #            (1): Linear(in_features=1280, out_features=1000, bias=True))
        # Swap the final Linear for one sized by num_classes; the input width is read from
        # the existing layer instead of being hard-coded.
        head_in_features = self._model.classifier[1].in_features
        self._model.classifier[1] = nn.Linear(head_in_features, self._num_classes)
        self._device = self._set_device()
        # Freshly built weights live on the CPU; move them to the device the batches are sent to
        self._model.to(self._device)

        # Hyperparameters
        self._batch_size = batch_size
        self._num_epochs = num_epochs
        self._learning_rate = learning_rate

        # Loss function & Optimizer
        self._criterion = self._set_criterion()
        self._optimizer = self._set_optimizer()

    def forward(self, images):
        """Run the wrapped network on a batch of images and return its raw outputs (logits)."""
        return self._model(images)

    def _set_device(self):
        """Pick the compute device: MPS if available, else CUDA, else CPU."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        elif torch.cuda.is_available():
            return torch.device("cuda")
        else:
            return torch.device("cpu")

    def _set_optimizer(self):
        """Build the Adam optimizer over the wrapped model's parameters."""
        return optim.Adam(self._model.parameters(), lr=self._learning_rate)

    def _set_criterion(self):
        """Build the loss function (cross-entropy over class logits)."""
        return nn.CrossEntropyLoss()

    # Train Epoch for 1 Loop
    def train_epoch(self, data_loader, epoch_idx):
        """
        Run one optimisation pass over ``data_loader``.

        Per batch: move data to the device, reset gradients, forward, compute the loss,
        backpropagate, step the optimizer. The predicted class of each sample is the argmax
        over the class dimension of the logits; targets are integer class ids, not one-hot.
        The accuracy printed every 10 batches is cumulative since the start of the epoch,
        while the printed loss is the mean over the last 10 batches only.

        Args:
            data_loader: pytorch DataLoader instance for image iteration
            epoch_idx: zero-based epoch index, used only for log messages
        Return:
            epoch_loss: the average loss for current epoch: total loss / total batches
            epoch_acc: the average accuracy for current epoch: total correct prediction / total true labels
        Raises:
            ValueError: if the loader yields no samples, so no metrics can be computed
        """
        logger.info(f"Starting epoch {epoch_idx+1}, total batches: {len(data_loader)}")

        # Set model mode to train, so the Dropout, BatchNorm are up to work
        self._model.train()
        epoch_total_loss = 0.0  # Track full epoch
        running_loss = 0.0      # Track 10-batch average
        running_correct = 0
        running_total = 0

        for idx, (images, targets) in enumerate(data_loader):
            images = images.to(self._device)
            # labels represented in int, not one-hot encoding
            targets = targets.to(self._device)
            # Clears(Reset) accumulated gradients from the previous iteration
            self._optimizer.zero_grad()
            # inference result from the model, it will auto call forward(images)
            logits = self._model(images)
            # Calculate the loss, compare Inference Prediction vs. True Labels
            loss = self._criterion(logits, targets)
            # backpropagate to compute the gradients
            loss.backward()
            # Update weights and parameters using latest gradients from backward()
            self._optimizer.step()

            predictions = logits.argmax(dim=1)

            # how many predictions match true labels
            running_correct += (predictions == targets).sum().item()
            # total number of the true labels
            running_total += targets.size(0)
            batch_loss = loss.item()
            running_loss += batch_loss
            epoch_total_loss += batch_loss

            # print out progress every 10 batches
            if idx % 10 == 9:
                avg_loss = running_loss / 10
                curr_acc = running_correct / running_total
                logger.info(f"Epoch {epoch_idx + 1}| Batch {idx + 1} | Avg Loss {avg_loss: .4f} | Accuracy: {curr_acc:.4f}")
                running_loss = 0.0      # reset for every 10 batches

        # An empty loader would otherwise surface as an unexplained ZeroDivisionError below
        if running_total == 0:
            raise ValueError("data_loader produced no samples; cannot compute epoch metrics")

        # return loss and accuracy for current training epoch
        epoch_loss = epoch_total_loss / len(data_loader)     # total loss / total batches
        epoch_acc = running_correct / running_total     # total correct prediction / total true labels

        logger.info(f"Epoch loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")
        return epoch_loss, epoch_acc

    # MLFlow Setup
    def _setup_mlflow(self):
        """Initialize Mlflow experiment, start a run and log the hyperparameters"""
        mlflow.set_tracking_uri("http://127.0.0.1:5555")
        # setup name of experiment
        mlflow.set_experiment("EfficientNet-B0-CIFAR100")
        # this starts the globally active run that all later mlflow.log_*() calls attach to
        mlflow.start_run(run_name=datetime.now().strftime("%Y%m%d_%H%M%S"))

        mlflow.log_params({
            "model": "EfficientNet-B0",
            "dataset": "CIFAR-100",
            "batch_size": self._batch_size,
            "learning_rate": self._learning_rate,
            "epochs": self._num_epochs,
            "device": self._device.type,
            "optimizer": "Adam",
            "num_classes": self._num_classes
        })

    def _log_metrics(self, epoch, epoch_loss, epoch_acc):
        """Log the epoch's loss, accuracy and current learning rate to Mlflow (no-op when disabled)"""
        if not self._use_mlflow:
            return

        metrics = {
            "train_loss": epoch_loss,
            "train_accuracy": epoch_acc,
            "learning_rate": self._optimizer.param_groups[0]["lr"]
        }

        mlflow.log_metrics(metrics, step=epoch)

    def _log_best_model(self, epoch_acc):
        """Log the model to Mlflow if ``epoch_acc`` beats the best accuracy seen so far"""
        if not self._use_mlflow:
            return

        # Only log current model if the accuracy is better
        if epoch_acc <= self._best_acc:
            return
        # update to better accuracy
        self._best_acc = epoch_acc

        # A dummy batch is pushed through the model so infer_signature can record the
        # input/output array shapes and dtypes alongside the logged model.
        sample_input = torch.randn(1, 3, 224, 224).to(self._device)
        # Run the probe in eval mode without autograd: in train mode this random-noise
        # batch would update BatchNorm running statistics and be subject to dropout.
        was_training = self._model.training
        self._model.eval()
        try:
            with torch.no_grad():
                sample_output = self._model(sample_input)
        finally:
            # restore whatever mode the model was in before the probe
            self._model.train(was_training)
        signature = infer_signature(
                sample_input.cpu().numpy(),
                sample_output.cpu().numpy()
        )

        # Log the model, the first one will always log, then keep overriding for better ones
        mlflow.pytorch.log_model(
            self._model,
            artifact_path="best_model",
            signature=signature
        )

    def train(self, train_loader=True):
        """
        Run the training loop, or switch train/eval mode when given a bool.

        This method has the same name as ``nn.Module.train(mode)``, which ``eval()`` and
        parent modules call with a bool. A bool argument is therefore forwarded to the
        standard mode switch (returning ``self``); any other argument is treated as the
        training data loader and handed to :meth:`fit`.
        """
        if isinstance(train_loader, bool):
            return super().train(train_loader)
        self.fit(train_loader)

    def fit(self, train_loader):
        """
        Main training loop with Mlflow integration.

        Errors raised during training are logged (with traceback) rather than re-raised,
        and the Mlflow run is always closed, with a status reflecting how training ended.
        """
        # Assume an interrupt until proven otherwise, so e.g. a KeyboardInterrupt that
        # bypasses `except Exception` is not recorded as a successful run.
        run_status = "KILLED"
        try:
            logger.info(f"Training {self._num_epochs} epochs on {self._device}")

            # Record the batch size that is actually used: the loader's, when it has one,
            # takes precedence over the constructor value.
            loader_batch_size = getattr(train_loader, "batch_size", None)
            if loader_batch_size is not None:
                self._batch_size = loader_batch_size

            # init mlflow, new run & experiment
            if self._use_mlflow:
                self._setup_mlflow()

            for epoch in range(self._num_epochs):
                epoch_loss, epoch_acc = self.train_epoch(train_loader, epoch)
                logger.info(f"Epoch {epoch + 1}/{self._num_epochs}: Avg Loss: {epoch_loss:.4f} | Accuracy: {epoch_acc:.4f}")

                # log metrics to mlflow
                if self._use_mlflow:
                    self._log_metrics(epoch, epoch_loss, epoch_acc)

                # upload the current model when it is the best so far
                self._log_best_model(epoch_acc)

            run_status = "FINISHED"
        except Exception as e:
            run_status = "FAILED"
            logger.error(f"Train process failed, error: {e}", exc_info=True)
        finally:
            if self._use_mlflow:
                logger.debug("Mlflow: Ending current run...")
                mlflow.end_run(status=run_status)
                logger.debug("Mlflow: current run ended successfully")


In [180]:
# Pass the loader's real batch size to the model: the constructor value is what gets
# logged to MLflow as "batch_size", and its default (32) does not match this loader.
m = MyEfficientNet(batch_size=train_loader.batch_size, num_epochs=15)
m.train(train_loader)

2025-06-22 22:01:23,460 - INFO - 2456644747.py - PID:61797 - TID:8386830080 - Training 15 epochs on mps
2025-06-22 22:01:23,500 - INFO - 2456644747.py - PID:61797 - TID:8386830080 - Starting epoch 1, total batches: 1250
2025-06-22 22:02:16,209 - INFO - 2456644747.py - PID:61797 - TID:8386830080 - Epoch 1| Batch 10 | Avg Loss  4.5875 | Accuracy: 0.0000
2025-06-22 22:02:17,404 - INFO - 2456644747.py - PID:61797 - TID:8386830080 - Epoch 1| Batch 20 | Avg Loss  4.5722 | Accuracy: 0.0063
2025-06-22 22:02:18,606 - INFO - 2456644747.py - PID:61797 - TID:8386830080 - Epoch 1| Batch 30 | Avg Loss  4.5749 | Accuracy: 0.0125
2025-06-22 22:02:19,792 - INFO - 2456644747.py - PID:61797 - TID:8386830080 - Epoch 1| Batch 40 | Avg Loss  4.5565 | Accuracy: 0.0094
2025-06-22 22:02:20,971 - INFO - 2456644747.py - PID:61797 - TID:8386830080 - Epoch 1| Batch 50 | Avg Loss  4.5723 | Accuracy: 0.0100
2025-06-22 22:02:22,123 - INFO - 2456644747.py - PID:61797 - TID:8386830080 - Epoch 1| Batch 60 | Avg Loss  4.