# Lesson 4 - Transfer Learning with Fashion-MNIST

## Goal

In the previous lesson, we learned how to use a pretrained convolutional neural network to perform inference on images. We treated the model as a fixed function: an image goes in, a prediction comes out, and no learning happens during this process. This allowed us to focus on forward passes, logits, probabilities, and interpreting model outputs.

However, pretrained inference comes with an important limitation. The model can only predict categories it was originally trained on. For example, a ResNet trained on ImageNet can only choose from its fixed set of 1,000 labels, even if our real-world problem looks very different.

This lesson answers the natural next question: what if we want the model to learn **our own categories using our own data?**

In this notebook, we introduce **Transfer Learning**, a technique that allows us to reuse a pretrained model while retraining part of it for a new task. Using the **Fashion-MNIST** dataset, we build our first complete deep learning workflow—from data loading and model adaptation to training, validation, and evaluation. This marks the transition from using models as black boxes to actively shaping them for specific problems.

To run the code, experiment with the models, and explore outputs interactively, open this lesson on **Kaggle**: 

[Open in Kaggle](https://www.kaggle.com/code/niharikamatcha/04-transfer-learning-fashion-mnist)

## 1. What is Transfer Learning?

Transfer learning is the practice of starting from a model that has already learned useful representations and adapting it to a new task. Instead of training a neural network entirely from scratch, we reuse most of an existing model and only retrain the parts that need to change.

This idea works especially well for convolutional neural networks. The early layers of a CNN tend to learn general visual features such as edges, textures, and simple shapes. These features are useful across many image tasks, regardless of the specific dataset. By keeping these layers and modifying only the final classification layers, we can usually train faster, use less data, and often achieve better performance than training from scratch.

Transfer learning is not an advanced trick. It is the standard approach used in real-world deep learning systems.


## 2. Training vs Validation Data

In the inference-only lesson, we never worried about model performance over time because the model was already trained and fixed. We simply passed images through the network and observed the predictions. Once we start training a model ourselves, this changes. We now need a reliable way to measure whether the model is actually learning something useful.

To do this, we split our dataset into **two separate parts: Training data and Validation data.**

The **Training Data** is the portion of the dataset that the model actively learns from. These are the images that are passed through the network during training, where errors are computed, gradients are calculated, and model weights are updated. In short, this is the data the model uses to improve itself.

The **Validation Data** is kept separate and is never used to update the model’s weights. Instead, it is used only for evaluation. After or during training, we run the model on the validation data to see how well it performs on examples it has not learned from directly.

This separation is critical. If we evaluate the model on the same data it was trained on, we may get overly optimistic results. The model could simply memorize the training examples without learning patterns that generalize to new data. High performance on training data alone does not guarantee that the model will perform well in real-world scenarios.

Validation data acts as a reality check. It helps us understand whether improvements in training loss actually translate into better generalization. When training loss decreases but validation performance stops improving or gets worse, it is often a sign of overfitting. Throughout this lesson, we will interpret the model’s behavior by looking at both training and validation performance together, rather than relying on a single metric.


## 3. Dataset: Fashion-MNIST

To demonstrate transfer learning in a controlled and accessible way, we use the **Fashion-MNIST** dataset. This dataset is provided directly by `torchvision` and is designed as a drop-in replacement for the original MNIST digits dataset, but with more visually meaningful content.

Fashion-MNIST consists of 28×28 grayscale images, where each image contains a single clothing item such as a shirt, shoe, or bag. Every image belongs to exactly one of 10 clothing categories, and the labels correspond directly to what we want the model to predict. Unlike ImageNet, there is no mismatch between the dataset labels and our task, which keeps the learning objective clear and focused.

This simplicity makes Fashion-MNIST ideal for understanding the mechanics of transfer learning. The images are small, the categories are well-defined, and the problem setup avoids unnecessary complexity. As a result, the focus of this lesson is not on achieving state-of-the-art accuracy, but on clearly seeing how a pretrained model can be adapted, trained, and evaluated on a new dataset.


## 4. Setup

Now that we understand what transfer learning is and why we are using Fashion-MNIST, we begin by setting up the tools needed for this lesson. This includes importing PyTorch for building and training neural networks, torchvision for datasets, models, and transforms, and a few utility libraries for visualization and numerical operations.

We also import a pretrained ResNet-18 model and its associated weights. This connects directly back to the previous lesson, where we used a pretrained CNN purely for inference. In this lesson, we will reuse the same idea, but instead of stopping at prediction, we will modify and train part of the model.

Before doing any computation, we decide where the model and data will live. Deep learning workloads can run on either the CPU or the GPU. If a GPU is available, we use it to significantly speed up training. Otherwise, everything runs on the CPU. From this point forward, both the model and all tensors must be moved to this selected device to avoid runtime errors.

This setup step mirrors what you would see in real-world deep learning projects and establishes a consistent environment for the rest of the lesson.

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

from torchvision.datasets import FashionMNIST
from torchvision.transforms import v2
from torchvision.models import resnet18, ResNet18_Weights
from torch.utils.data import DataLoader

Choose the GPU if one is available, otherwise fall back to the CPU:

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


## 5. Where the dataset lives and why it matters

Before loading any data, it is important to understand where datasets are stored and how they are managed. When we load Fashion-MNIST for the first time with `download=True`, torchvision *downloads the dataset* and saves it to disk in the directory we pass as `root`. In this case, the data is stored inside a local folder named `data/`.

On subsequent runs, torchvision finds the files already on disk and skips the download. This behavior reflects how most real-world datasets are handled, where data is downloaded or collected once and reused across experiments.

Understanding this mechanism is important because it helps explain why datasets sometimes load instantly and other times take longer. It also prepares you for working with custom datasets later, where you will explicitly control how and where your data is stored.

With this in mind, we are now ready to look at the data itself.


## 6. Inspecting the Raw Data (before Transforms)

Before applying any preprocessing or transformations, it is useful to examine the raw dataset exactly as it is stored. To do this, we load the Fashion-MNIST training split without applying any transformations. This allows us to see the original image format and understand what the model will receive before preprocessing.

Each image in Fashion-MNIST is a 28×28 grayscale image. This means it has only one color channel and a very small spatial resolution. At this stage, the images are represented as PIL images, not tensors, and their pixel values have not been normalized or resized.

We also define the class names manually. Fashion-MNIST provides labels as integers from 0 to 9, but these numbers are not meaningful on their own. Mapping each numeric label to a human-readable class name allows us to interpret predictions later and makes visualizations much easier to understand.

By displaying a single raw image along with its label, we can clearly see what the dataset looks like before any preprocessing is applied. This step is intentional and highlights an important difference from the previous inference lesson.

In the inference lesson, we worked with a pretrained model whose input requirements were already fixed. Because of that, we moved directly to transformed images and focused only on running a forward pass to obtain predictions. There was no need to closely inspect the raw dataset, since we were not training or adapting the model in any way.

In this transfer learning lesson, the situation is different. We are adapting a pretrained CNN to a new dataset, which means we must first understand the data itself before deciding how to prepare it. Examining the raw images helps us understand their resolution, color format, and overall structure.


In [None]:
raw_dataset = FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=None
)

In [None]:
class_names = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

In [None]:
raw_image, raw_label = raw_dataset[0]

plt.figure(figsize=(2, 2))
plt.imshow(raw_image, cmap="gray")
plt.title(f"Original image (label={raw_label}: {class_names[raw_label]})")
plt.axis("off")
plt.show()

From the displayed image, we can see that Fashion-MNIST consists of low-resolution 28×28 grayscale images. Pretrained convolutional neural networks are typically trained on larger, three-channel RGB images. This difference in format explains why additional preprocessing steps such as resizing and channel conversion will be required before we can use a pretrained model.

Now that we understand what the raw data looks like, we can move on to preparing it so that it is compatible with a pretrained CNN.



## 7. Image transforms

In the previous sections, we examined the raw Fashion-MNIST data and identified a clear mismatch between the dataset format and the input requirements of a pretrained CNN. At this point, the goal is no longer to understand what the data looks like, but to systematically adapt it so that it can flow correctly through a pretrained ResNet during training.

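Pretrained ResNet models are trained on ImageNet, which consists of RGB images with three color channels, typically cropped to a spatial resolution of 224×224. Fashion-MNIST images, on the other hand, are much smaller 28×28 grayscale images with only a single channel. If we were to feed these images directly into a pretrained ResNet, the first convolution would raise an error because it expects three input channels. The small resolution would not cause an error by itself, since ResNet ends with adaptive pooling, but the pretrained filters were learned on much larger images, so resizing keeps the inputs closer to what the network saw during pretraining.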
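
The channel mismatch can be reproduced on a small scale. The sketch below is only an illustration: it builds a stand-alone convolution with the same configuration as ResNet-18's first layer (constructed by hand here, not taken from the pretrained model) and feeds it a random tensor standing in for a Fashion-MNIST image. The names `first_conv`, `gray_batch`, and `rgb_like_batch` are made up for this example.

```python
import torch
import torch.nn as nn

# Stand-in for ResNet-18's first convolution: it expects 3 input channels.
first_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Random tensor standing in for one grayscale image: (batch, channels, height, width).
gray_batch = torch.rand(1, 1, 28, 28)

try:
    first_conv(gray_batch)
    channel_error = None
except RuntimeError as err:
    channel_error = err  # raised because the input has 1 channel, not 3

# Repeating the single channel three times gives the layout the layer expects.
rgb_like_batch = gray_batch.repeat(1, 3, 1, 1)
features = first_conv(rgb_like_batch)

print("Error on 1-channel input:", channel_error is not None)
print("Output shape on 3-channel input:", tuple(features.shape))
```

The single-channel input is rejected, while the three-channel version passes through even at 28×28, which suggests that the channel count, not the resolution, is what triggers the error.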

Image transforms are the bridge between our dataset and the pretrained model. They control how raw images are resized, converted into tensors, normalized, and, in the case of training data, augmented. These transformations ensure that every image entering the network has the correct shape, data type, and statistical properties.

This section also marks an important structural change in the workflow. From here onward, preprocessing is no longer uniform across all data. We deliberately introduce two separate transform pipelines: one for training and one for validation. The training pipeline is allowed to modify images stochastically, while the validation pipeline must remain stable and repeatable. This separation ensures that learning and evaluation serve different purposes without interfering with each other.

### Training Transforms (with data augmentation)

The **Training Transform** pipeline does more than simply prepare images for the model. It intentionally introduces controlled randomness through data augmentation. Operations such as random horizontal flips, small rotations, and brightness or contrast changes create slightly different versions of the same image every time it is loaded.

This randomness helps prevent the model from memorizing exact pixel patterns in the training data. Instead, the model is encouraged to learn more general and robust visual features that transfer better to unseen data. Although Fashion-MNIST images are simple, this step becomes increasingly important when working with small datasets or complex models like ResNet.

Several transformations here are essential rather than optional. Resizing converts the *28×28 images to 224×224* so they match the ResNet input size. Converting images to tensors and scaling their values ensures numerical stability during training. Repeating the single grayscale channel three times creates a synthetic RGB image so that it matches the expected input format of ImageNet-pretrained weights. Normalization aligns the input statistics with the distribution the pretrained model was originally trained on, which significantly improves transfer learning performance.


In [None]:
train_transform = v2.Compose([
    v2.Resize((224, 224)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Lambda(lambda x: x.repeat(3, 1, 1)),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomRotation(degrees=20),
    v2.ColorJitter(brightness=0.3, contrast=0.3),
    v2.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
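
To make the normalization step concrete, here is a small worked example in plain Python. It applies the same per-channel formula, `(value - mean) / std`, to a single gray value that has been copied into three channels. The helper `normalize_pixel` is a hypothetical name used only for this sketch.

```python
# ImageNet channel statistics, matching the values passed to v2.Normalize.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def normalize_pixel(value, mean, std):
    """Apply (value - mean) / std per channel to one gray value copied into 3 channels."""
    return [(value - m) / s for m, s in zip(mean, std)]


# Black, mid-gray, and white after scaling pixel values to [0, 1].
for gray_value in (0.0, 0.5, 1.0):
    normalized = normalize_pixel(gray_value, mean, std)
    print(gray_value, [round(v, 3) for v in normalized])
```

Two things are worth noticing in the printed values: the three channels no longer hold identical numbers even though they started as copies, and the results fall well outside the [0, 1] range, so a normalized tensor has to be un-normalized before it can be displayed faithfully as an image.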

### Validation Transforms (no augmentation)

**Validation Transforms** intentionally exclude any form of randomness. During validation, our goal is not to help the model learn but to measure how well it performs on data it has not seen before. If we applied random augmentations to validation images, the evaluation would become noisy and inconsistent.

The validation pipeline still includes resizing, tensor conversion, channel expansion, and normalization, because the model’s input requirements do not change. The key difference is that every validation image is processed in exactly the same way every time. This ensures that changes in validation performance reflect actual learning rather than randomness introduced by preprocessing.

This separation between training and validation transforms is a major conceptual shift from the inference lesson. Once we start training models, preprocessing becomes part of the learning process rather than just a formatting step.

In [None]:
val_transform = v2.Compose([
    v2.Resize((224, 224)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Lambda(lambda x: x.repeat(3, 1, 1)),
    v2.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


### Data Augmentation: Transforms as Part of Training

When training a neural network, our goal is not just to fit the training images, but to help the model learn general visual patterns that will transfer to new, unseen data. If the model sees the exact same images in the exact same form every time, it can begin to memorize details instead of learning robust features.

**Data Augmentation** is one way we encourage this kind of generalization. Instead of always showing the model the same image with the same orientation and appearance, we apply small random transformations each time the image is used during training. These transformations simulate natural variations that might occur in real-world data, such as slight rotations, changes in brightness, or horizontal flips.

In this code cell, we repeatedly apply the training transform to the same raw image and visualize the results. Each call to the transform produces a slightly different version of the image, even though the underlying label remains the same. This demonstrates that augmentation is not about changing what the image represents, but about exposing the model to many plausible variations of the same object.

The plotting logic simply arranges multiple augmented versions of the image in a grid so we can compare them side by side. The key idea to focus on is not the plotting itself, but the fact that a single input image can generate many distinct training examples through augmentation.

In [None]:
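# Undo the ImageNet normalization so pixel values are back in [0, 1] for display;
# otherwise imshow clips the out-of-range values and the images look distorted.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

plt.figure(figsize=(8, 4))
for i in range(6):
    augmented = train_transform(raw_image)  # new random augmentation on every call
    display_image = (augmented * std + mean).clamp(0, 1)
    plt.subplot(2, 3, i + 1)
    plt.imshow(display_image.permute(1, 2, 0))  # (C, H, W) -> (H, W, C) for matplotlib
    plt.title(f"Augmented #{i+1}")
    plt.axis("off")

plt.tight_layout()
plt.show()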

Looking at the augmented images, we can see that each version still clearly represents the same object, but with small differences in orientation and appearance. Some images are slightly rotated, some are flipped, and others have minor brightness or contrast changes. These variations are intentional and carefully controlled so that they do not alter the semantic meaning of the image.

This is an important contrast with the previous inference lesson. During inference, consistency is critical: the same input should always produce the same output. During training, controlled randomness is beneficial. By seeing many slightly different versions of the same object, the model learns to focus on essential visual features rather than memorizing exact pixel patterns.

It is also important to note that these random transformations are applied only to the training data. Validation images are processed deterministically, without augmentation. This ensures that validation performance reflects the model’s true ability to generalize, rather than its ability to adapt to randomness during training.

At this point, we have completed the full preparation of our image data. We understand what the raw images look like, how they are transformed to match a pretrained CNN, and how augmentation helps improve generalization during training. With this foundation in place, we can now move on to organizing the data into datasets and batches so it can be efficiently fed into the model.


## 8. Datasets (what is happening here?)

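A **Dataset** in PyTorch represents a collection of examples and defines how to *retrieve a single example at a time*. In general, a dataset does not need to load every image into memory upfront. Fashion-MNIST is small enough that torchvision reads the raw pixel arrays when the dataset object is created, but the transform pipeline is applied only when an individual example is requested.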

When we create the training and validation dataset objects, we attach different transform pipelines to each one. This means the same underlying Fashion-MNIST images can behave differently depending on whether they are used for training or validation. The **training dataset** applies *random augmentations*, while the **validation dataset** applies only *deterministic preprocessing*.

This design is intentional and efficient. It allows us to reuse the same dataset files while enforcing different behaviors depending on context. It also ensures that data augmentation is applied dynamically during training rather than being permanently stored.

This level of control did not matter in the inference lesson because the dataset was only used to feed images into a trained model. In training, however, datasets become an active component of the learning process.

In [None]:
train_dataset = FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=train_transform
)

val_dataset = FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=val_transform
)

## 9. Batches and DataLoaders

Neural networks are not trained by processing one image at a time. Instead, they operate on batches of images. A **Batch** is a fixed number of examples that are processed together during a single forward and backward pass. After each batch, the model updates its weights and moves on to the next batch.

The **DataLoader** is responsible for creating these batches and serving them to the training loop. It handles batching, optional shuffling, and efficient iteration over the dataset. *Shuffling* is enabled for the training DataLoader so that the model does not see images in the same order every epoch, which helps prevent learning order-specific patterns. For validation, shuffling is disabled to ensure consistent evaluation.

By grouping images into batches, training becomes significantly more *memory-efficient* and computationally practical. This is especially important when working with large models like ResNet and when using GPUs.

The number of batches tells us how many weight updates occur during one epoch. Here the training DataLoader yields 938 batches. Multiplying 938 by the batch size of 64 gives 60,032, slightly more than the 60,000 training images, because the final batch contains only 32 images. Working through all of these batches once is one epoch: one complete pass through the training data.

In [None]:
batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
num_batches = len(train_loader)
num_batches

In [None]:
938 * 64  # 60,032: an upper bound, since the last of the 938 batches holds only 32 images (60,000 images per epoch)
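
The batch count can also be derived by hand. The short sketch below assumes the DataLoader default of keeping a partial final batch (`drop_last=False`) and uses only plain Python arithmetic; `last_batch_size` is a name introduced for this example.

```python
import math

num_images = 60_000  # size of the Fashion-MNIST training split
batch_size = 64

# With the partial last batch kept, the number of batches is a ceiling division.
num_batches = math.ceil(num_images / batch_size)

# Every batch is full except possibly the last one.
last_batch_size = num_images - (num_batches - 1) * batch_size

print(num_batches, last_batch_size)
```

This gives 938 batches, of which 937 are full and the last holds 32 images, so one epoch corresponds to 938 weight updates over exactly 60,000 images.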

## 10. Load a pretrained ResNet and Adapt it

Up to this point, we have focused on preparing the data so that it can be consumed by a pretrained CNN. Now we shift our attention to the model itself. Transfer learning works by reusing a network that has already learned rich visual features and adapting it to a new task.

Here, we load a **ResNet-18** model that was pretrained on ImageNet. This pretrained network has already learned how to detect edges, textures, shapes, and object parts from millions of natural images. Rather than training an entire CNN from scratch, we take advantage of this prior knowledge.

The first important step in the code is freezing the pretrained layers. This is done by setting `requires_grad = False` for all existing parameters in the model. In PyTorch, `requires_grad` controls whether gradients are computed for a parameter during backpropagation; without gradients, the optimizer has nothing to update. By setting it to `False`, we tell PyTorch to keep these weights fixed. This means every pretrained layer, not just the early and middle ones, acts purely as a feature extractor.

Next, we replace the final fully connected layer of the network. In ResNet, this layer is stored as `model.fc`. The original layer outputs 1,000 values, one for each ImageNet class. Since Fashion-MNIST has only 10 categories, we create a new `nn.Linear` layer with 10 outputs and assign it to `model.fc`. This new layer is randomly initialized.

By default, parameters of newly created layers have `requires_grad = True`. This means the new fully connected layer is trainable, while the rest of the network remains frozen. During training, only this final layer will learn, adjusting how the extracted features are mapped to Fashion-MNIST classes.

This combination, **freezing the pretrained backbone and training a new classifier head**, is the *core idea of Transfer Learning*. We keep what the model already knows about visual structure and retrain only the part that is specific to our new task.

In [None]:
weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights)

for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)

Once the model architecture is defined and adapted, we need a way to measure how well it is performing and a mechanism to improve it during training.


## 11. Loss function and Optimizer

Training a neural network requires two essential components: a **Loss Function** and an **Optimizer**. The loss function tells us how far the model’s predictions are from the true labels, while the optimizer decides how the model’s parameters should be updated to reduce that error.

Since this is a multi-class classification problem, we use **Cross-Entropy loss**, which is standard for tasks where exactly one class is correct. For optimization, we use **Adam**, a popular adaptive optimizer that works well in practice with minimal tuning.

Notice that the optimizer is configured to update only the parameters of the newly added fully connected layer. This ensures that the pretrained backbone remains frozen while the classifier learns to adapt to Fashion-MNIST.

In [None]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)
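
To see what cross-entropy measures, it can help to compute it by hand for a single example. The sketch below uses plain Python and made-up scores; `cross_entropy` here is a hypothetical helper that mirrors the per-example formula, not the PyTorch implementation (`nn.CrossEntropyLoss` additionally averages over the batch by default).

```python
import math


def cross_entropy(logits, true_class):
    """Cross-entropy for one example: -log(softmax(logits)[true_class])."""
    max_logit = max(logits)  # subtracting the max keeps exp() numerically stable
    exps = [math.exp(z - max_logit) for z in logits]
    prob_true = exps[true_class] / sum(exps)
    return -math.log(prob_true)


logits = [2.0, 1.0, 0.1]  # illustrative scores for a 3-class problem

loss_when_correct = cross_entropy(logits, true_class=0)  # true class has the highest score
loss_when_wrong = cross_entropy(logits, true_class=2)    # true class has the lowest score
loss_uniform = cross_entropy([0.0] * 10, true_class=3)   # 10 equal scores

print(round(loss_when_correct, 3), round(loss_when_wrong, 3), round(loss_uniform, 3))
```

The loss is small when the true class receives the highest score and large when it receives the lowest. With 10 classes and equal scores the loss equals ln(10) ≈ 2.30, so a freshly initialized classifier head typically starts near that value.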

With the model, loss function, and optimizer in place, we are now ready to train the network and observe how learning actually happens.

## 12. Training and validation loop

Training does not happen on the entire dataset at once. Instead, the dataset is loaded using a DataLoader, which splits the data into small batches. The training loop runs once per batch, not once per dataset. Each iteration of the loop processes a single batch of images and labels, computes predictions, calculates the loss, and updates the model’s trainable parameters.

An epoch represents one full pass through the entire training dataset. Because we use batches, the model does not see the whole dataset at once. Instead, the dataset is divided into many batches, and the model processes them one batch at a time until all batches have been seen. Only after this full pass is completed do we move on to the next epoch.

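During each epoch, the model alternates between **training mode** and **evaluation mode**. Calling `model.train()` puts layers such as batch normalization into their training behavior; in this phase gradients are computed and the trainable weights are updated. Calling `model.eval()` switches those layers to their inference behavior, and wrapping the validation pass in `torch.inference_mode()` disables gradient tracking, so we can measure validation performance without affecting the learned weights. Note that in training mode ResNet's batch-normalization layers still update their running statistics, even though their weights are frozen.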

Separating training and validation is crucial. The training loss tells us how well the model is fitting the data it sees during learning, while the validation loss indicates how well the model generalizes to unseen data. 

This comparison helps us identify two common problems: **Underfitting** and **Overfitting**. Underfitting occurs when the *model is too simple or not trained enough to capture the patterns in the data*. In this case, both training loss and validation loss remain **high**, indicating that the model is struggling even on the training data. Overfitting occurs when the *model learns the training data too well, including noise or details that do not generalize*. In this situation, training loss continues to **decrease**, but validation loss stops improving or starts increasing.

By tracking both losses across epochs, we can observe how learning progresses and whether the model is improving in a healthy way. A good training process typically shows both training and validation loss decreasing together, or at least staying close to each other. Large gaps between them are a warning sign that the model may not generalize well.

In [None]:
%%time
epochs = 3
display_period = 100

train_losses, val_losses = [], []

# epoch: one full pass through entire training dataset
for epoch in range(epochs):
    print(f"Starting epoch {epoch+1}/{epochs}")
    model.train()
    train_loss = 0.0
    total = 0

    for batch_num, (images, labels) in enumerate(train_loader):
        # Each iteration processes one batch and updates the model
        if np.mod(batch_num, display_period) == 0:
            print(f"\tRunning batch {batch_num}/{num_batches}")
            
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = loss_fn(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * images.size(0)
        total += labels.size(0)

    train_losses.append(train_loss / total)

    model.eval()
    val_loss = 0.0
    val_total = 0

    with torch.inference_mode():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            loss = loss_fn(logits, labels)

            val_loss += loss.item() * images.size(0)
            val_total += labels.size(0)

    val_losses.append(val_loss / val_total)

print("Done training!")

Once training is complete, raw numbers alone are not very informative. Visualizing how the losses evolve over time makes the learning process much easier to interpret.


## 13. Inspecting training vs validation loss

Here, we plot the training and validation losses across epochs. A decreasing training loss indicates that the model is learning from the data. A validation loss that decreases alongside it suggests that the model is generalizing well.

If the training loss continues to drop while the validation loss stagnates or increases, it may indicate overfitting. In this example, both curves decrease steadily, which suggests that freezing the backbone and training only the classifier is an effective strategy for this dataset.

This visualization provides a simple but powerful diagnostic tool for evaluating training behavior. 

In [None]:
plt.plot(train_losses, marker=".", label="Training Loss")
plt.plot(val_losses, marker=".", label="Validation Loss")
plt.legend()
plt.show()

After plotting the losses, it is important to pause and interpret what this visualization is telling us. The training loss consistently decreases across epochs, which means the model is successfully learning patterns from the training data. At the same time, the validation loss also decreases, indicating that the model is not merely memorizing the training images but is learning features that generalize to unseen data.

The gap between training and validation loss remains relatively small. This is a good sign and suggests that freezing the pretrained backbone and training only the final classification layer is an appropriate strategy for this task. If the validation loss had started increasing while the training loss kept decreasing, it would have been a warning sign of overfitting.
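
The overfitting pattern described above can also be checked programmatically. The following is a minimal sketch, not part of the original notebook: `first_overfitting_epoch` is a hypothetical helper, and the loss values are made up purely for illustration.

```python
def first_overfitting_epoch(train_losses, val_losses):
    """Return the first epoch index where training loss fell but validation loss rose, else None."""
    for epoch in range(1, len(train_losses)):
        train_improved = train_losses[epoch] < train_losses[epoch - 1]
        val_worsened = val_losses[epoch] > val_losses[epoch - 1]
        if train_improved and val_worsened:
            return epoch
    return None


# Made-up loss values for illustration only (not results from this notebook).
healthy_train, healthy_val = [0.90, 0.60, 0.50], [0.85, 0.62, 0.55]
overfit_train, overfit_val = [0.90, 0.50, 0.30, 0.15], [0.85, 0.60, 0.66, 0.80]

healthy_result = first_overfitting_epoch(healthy_train, healthy_val)
overfit_result = first_overfitting_epoch(overfit_train, overfit_val)

print(healthy_result, overfit_result)
```

In the healthy case no epoch is flagged, while in the second case the check flags epoch index 2, where validation loss turns upward. A single noisy epoch can trigger a check this simple, so in practice it is usually combined with some patience over several epochs.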

Loss curves like this act as a diagnostic tool. They help us decide whether to train longer, stop early, adjust learning rates, or change model capacity. In our case, the behavior of both curves suggests that training is stable and effective. With a trained model and reasonable validation behavior, the next step is to preserve this work so it can be reused later.

Now that training is complete and we have verified that the model learned successfully, we save the trained model to disk.


## 14. Saving the trained model

Once training is finished, we save the model’s parameters using `torch.save`. Specifically, we save the model’s `state_dict`, which contains all of the model’s parameters and buffers: the frozen backbone weights, the batch-normalization statistics, and the newly trained classifier weights and biases.

Saving the `state_dict` instead of the entire model is a common PyTorch practice. It stores only tensors rather than pickling the whole Python object, so the checkpoint does not depend on the exact class definition used at save time. The same architecture can be recreated later and loaded with these trained weights. The filename `"fashion_mnist_resnet.pth"` simply serves as a checkpoint that we can load again for inference or further training.

This step is important because training deep learning models can be time-consuming. By saving the model, we avoid having to retrain from scratch and can directly move to evaluation or deployment.


In [None]:
torch.save(model.state_dict(), "fashion_mnist_resnet.pth")
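
To reuse the checkpoint later, the usual pattern is to rebuild the same architecture and then load the saved weights into it. The sketch below shows that round trip with a tiny stand-in network so that it runs on its own. `build_model` and the temporary file path are hypothetical; in the notebook you would instead rebuild the adapted ResNet exactly as it was constructed before training and load `"fashion_mnist_resnet.pth"`.

```python
import os
import tempfile

import torch
from torch import nn


def build_model():
    # Hypothetical stand-in for the notebook's adapted ResNet.
    # The architecture must match the one that produced the checkpoint.
    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))


trained_model = build_model()

with tempfile.TemporaryDirectory() as tmp_dir:
    checkpoint_path = os.path.join(tmp_dir, "stand_in_checkpoint.pth")
    torch.save(trained_model.state_dict(), checkpoint_path)

    # Rebuild the architecture, then load the saved parameters into it.
    restored_model = build_model()
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    restored_model.load_state_dict(state_dict)

restored_model.eval()  # switch to evaluation mode before inference

# Both models should now produce the same outputs for the same input.
example_input = torch.randn(1, 1, 28, 28)
with torch.inference_mode():
    outputs_match = torch.allclose(
        trained_model(example_input), restored_model(example_input)
    )
print(outputs_match)
```

Passing `map_location="cpu"` lets a checkpoint saved on a GPU be loaded on a machine without one; the model can be moved to another device afterwards with `.to(device)`.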

## 15. Inference on new data ("test" data)

After training, we want to verify that the model can make sensible predictions on unseen images. To do this, we run inference on samples from the validation dataset, which were never used to update the model’s weights.

The helper function `predict_fashion_image` takes a dataset and an index, retrieves a single image and its true label, and prepares the image for the model. The call to `unsqueeze(0)` adds a batch dimension, since PyTorch models expect inputs in batch form, even if the batch contains only one image.

Inference is performed inside `torch.inference_mode()`, which disables gradient computation. This makes inference faster and ensures that no gradients are accidentally stored. The model outputs logits, which are raw, unnormalized scores for each class. These logits are then passed through `torch.softmax` to convert them into probabilities that are easier to interpret.

The `topk` function is used to retrieve the top 5 predicted classes along with their probabilities. This allows us to inspect not only the most confident prediction but also alternative guesses the model considered.
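
The shape change and the logits-to-probabilities step can be checked on a small hand-written tensor. The following sketch uses made-up logits for a hypothetical four-class problem, not real model output, and asks for the top 2 instead of the top 5 only because there are so few classes.

```python
import torch

# A single fake "image" with shape (channels, height, width).
image = torch.zeros(3, 8, 8)
image_batch = image.unsqueeze(0)  # shape becomes (1, 3, 8, 8)
print(image_batch.shape)

# Made-up logits for one image and four classes: shape (1, 4).
logits = torch.tensor([[2.0, 0.5, -1.0, 0.1]])

# Softmax over the class dimension turns logits into probabilities.
probs = torch.softmax(logits, dim=1)

# topk returns the k largest probabilities and their class indices.
top_probs, top_idxs = probs.topk(2, dim=1)
print(top_probs.squeeze(0).tolist(), top_idxs.squeeze(0).tolist())
```

Softmax preserves the ordering of the logits, so the class with the largest logit is always the class with the largest probability.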

Finally, we switch the model to evaluation mode using `model.eval()` and randomly sample a few images from the validation set. This mirrors real-world usage, where a trained model is loaded once and then used repeatedly to make predictions. This final step completes the full transfer learning pipeline: preparing data, adapting a pretrained model, training a new classifier, evaluating performance, saving the model, and running inference on new data.

In [None]:
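def predict_fashion_image(dataset, index):
    """Return the image, its true label, and the top-5 predictions.

    Relies on the notebook-level `model` and `device` defined earlier.
    """
    image, true_label = dataset[index]

    # Add a batch dimension, (C, H, W) -> (1, C, H, W), and move to the model's device.
    image_batch = image.unsqueeze(0).to(device)

    # No gradients are needed for prediction.
    with torch.inference_mode():
        logits = model(image_batch)
        probs = torch.softmax(logits, dim=1)  # logits -> class probabilities

    # Five highest probabilities and their class indices.
    top_probs, top_idxs = probs.topk(5, dim=1)

    return (
        image.cpu(),
        true_label,
        top_probs.squeeze().cpu(),
        top_idxs.squeeze().cpu(),
    )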

In [None]:
import random

In [None]:
model.eval()

random_indices = random.sample(range(len(val_dataset)), 5)

for idx in random_indices:
    image, true_label, top_probs, top_idxs = predict_fashion_image(val_dataset, idx)

    predicted_labels = [class_names[i] for i in top_idxs.tolist()]
    probabilities = top_probs.tolist()

    plt.figure(figsize=(6, 3))

    plt.subplot(1, 2, 1)
    plt.imshow(image.permute(1, 2, 0))
    plt.title(f"True label: {class_names[true_label]}")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.barh(predicted_labels[::-1], probabilities[::-1])
    plt.xlabel("Probability")
    plt.title("Top 5 Predictions")

    plt.tight_layout()
    plt.show()


## Interpretation of the final model output

This section visualizes the final performance of the trained transfer learning model on unseen validation images. As mentioned earlier, the model is switched to evaluation mode using `model.eval()`. At this stage, it is important to understand what this change means internally. Evaluation mode disables training-specific behaviors such as dropout and ensures that batch normalization layers use learned statistics instead of updating them. This guarantees that the predictions shown here reflect the model’s true inference behavior rather than its training behavior.
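
The effect of evaluation mode is easiest to see on a dropout layer in isolation. This is a minimal sketch on a stand-alone `nn.Dropout` module, not on the notebook's model (which may or may not contain dropout layers).

```python
import torch
from torch import nn

dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

dropout.train()            # training mode: elements are randomly zeroed
train_output = dropout(x)  # surviving elements are scaled by 1 / (1 - p)

dropout.eval()             # evaluation mode: dropout passes inputs through unchanged
eval_output = dropout(x)

print(train_output)
print(eval_output)
```

In training mode the output varies from call to call, while in evaluation mode it always equals the input. Calling `model.eval()` on a full model applies the same switch to every submodule.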

A small subset of images is randomly selected from the validation dataset to qualitatively evaluate how well the model generalizes. For each selected image, the model predicts class probabilities and identifies the top five most likely classes. These predictions help us understand not only whether the model is correct, but also how confident it is and which alternative classes it considers plausible.

For each sample, the left plot displays the input image along with its true label. The right plot shows a horizontal bar chart of the top five predicted classes and their corresponding probabilities. When the correct class appears as the top prediction with a high probability, this indicates strong model confidence and good learning. In cases where the correct label is not ranked first or has a lower probability, it reveals areas where the model is uncertain or confused between visually similar classes, such as shirts versus coats or sneakers versus ankle boots.

The probability distribution itself is very informative. Sharp distributions, where one class dominates, indicate confident predictions. Flatter distributions suggest ambiguity, often caused by overlapping visual features between clothing categories in Fashion-MNIST. This qualitative analysis complements numerical metrics like accuracy by showing *how* the model distributes its confidence, not just whether it is right or wrong.
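
One way to put a number on how sharp or flat a prediction is, offered here as an illustration and not something the notebook computes, is the entropy of the probability vector. Values near zero mean one class dominates, and larger values mean the probability is spread across several classes. The probabilities below are made up.

```python
import math


def entropy(probabilities):
    """Shannon entropy in nats; lower means a sharper distribution."""
    return -sum(p * math.log(p) for p in probabilities if p > 0)


# Made-up probability vectors over five classes (each sums to 1).
sharp = [0.96, 0.01, 0.01, 0.01, 0.01]
flat = [0.30, 0.25, 0.20, 0.15, 0.10]

print(f"sharp: {entropy(sharp):.3f} nats")
print(f"flat:  {entropy(flat):.3f} nats")
```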

During visualization, a warning appears stating that the input data is being clipped to a valid range for display. This happens because the images were normalized during preprocessing using the mean and standard deviation values expected by the pretrained model. As a result, pixel values no longer lie in the `[0, 1]` range that `imshow` expects for floating-point RGB data, and Matplotlib clips them for visualization. This warning does **not** indicate an error in training or prediction and does **not** affect model performance. It only affects how the image is displayed, not how the model processes it.
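
If the clipping warning is distracting, one option is to undo the normalization before calling `imshow`. The sketch below assumes the standard ImageNet mean and standard deviation were used during preprocessing; the notebook's actual values should be substituted if they differ. `denormalize` is a hypothetical helper.

```python
import torch

# Assumed preprocessing statistics (standard ImageNet values).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def denormalize(image, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Invert per-channel normalization on a (C, H, W) tensor for display."""
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)
    # Normalization computed (x - mean) / std, so the inverse is x * std + mean.
    restored = image * std_t + mean_t
    # Clamp to guard against small floating-point overshoots outside [0, 1].
    return restored.clamp(0.0, 1.0)


# Round-trip check on a random image with values in [0, 1].
original = torch.rand(3, 4, 4)
mean_t = torch.tensor(IMAGENET_MEAN).view(-1, 1, 1)
std_t = torch.tensor(IMAGENET_STD).view(-1, 1, 1)
normalized = (original - mean_t) / std_t

restored = denormalize(normalized)
print(torch.allclose(restored, original, atol=1e-5))
```

In the plotting loop, this would mean passing `denormalize(image).permute(1, 2, 0)` to `plt.imshow` instead of `image.permute(1, 2, 0)`.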

Overall, the visual results show that the **transfer learning model has successfully learned meaningful representations from the data**. The model correctly identifies many validation samples with high confidence, demonstrating effective feature reuse from the pretrained backbone. Occasional misclassifications are expected due to class similarity and limited image resolution, but the probability rankings still reflect reasonable semantic understanding.

## Reflection: what would you try next?

This reflection step encourages us to think beyond a single training run and understand that building good machine learning models is an **iterative process**. A model rarely performs perfectly on the first attempt, and small changes in training strategy can have a significant impact on performance.

If the validation loss stops improving, it usually means the model has reached the limit of what it can learn under the current setup. Reducing the number of epochs can help prevent overfitting, especially if the model begins to memorize training data instead of learning general patterns. Adjusting data augmentation can either make the task slightly easier (by reducing randomness) or help the model generalize better (by adding more realistic variation). 

Changing the batch size or learning rate affects how the model updates its weights. A smaller batch size can introduce more variability that sometimes improves generalization, while a different learning rate can help the model converge more smoothly or escape poor solutions. Fine-tuning more layers of the pretrained network allows the model to adapt deeper feature representations to the new dataset, but this should be done carefully to avoid overfitting.
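
Fine-tuning more layers usually means unfreezing the last part of the backbone and giving it a smaller learning rate than the new classifier head. The sketch below shows the mechanics on a small stand-in network. The class `TinyNet`, the module names `early`, `late`, and `head`, and both learning rates are hypothetical. With a torchvision ResNet, the corresponding pieces would typically be the final residual stage (`layer4`) and the `fc` layer.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pretrained backbone plus a new head."""

    def __init__(self):
        super().__init__()
        self.early = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.late = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.head = nn.Linear(16, 10)

    def forward(self, x):
        x = torch.relu(self.early(x))
        x = torch.relu(self.late(x))
        x = x.mean(dim=(2, 3))  # global average pooling -> (N, 16)
        return self.head(x)


net = TinyNet()

# Start from a fully frozen network, as in feature extraction.
for param in net.parameters():
    param.requires_grad = False

# Unfreeze the last backbone block and the classifier head.
for module in (net.late, net.head):
    for param in module.parameters():
        param.requires_grad = True

# Smaller learning rate for the pretrained block, larger for the new head.
optimizer = torch.optim.Adam([
    {"params": net.late.parameters(), "lr": 1e-4},
    {"params": net.head.parameters(), "lr": 1e-3},
])

trainable = [name for name, p in net.named_parameters() if p.requires_grad]
print(trainable)
```

Keeping the earliest layers frozen preserves the general edge and texture features, while the lower learning rate on the unfrozen block limits how far its pretrained weights can drift.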

This reflection also opens the door to broader computer vision tasks. In this lesson, we focused on **Image Classification**, where the model predicts a single label for an entire image. In **Object Detection**, the model would locate and classify multiple objects within an image. In **Image Segmentation**, the model would assign a class label to each individual pixel. Generative models such as GANs take a different approach entirely, learning to generate new images that resemble the training data rather than simply labeling them.

Together, these ideas highlight that transfer learning is not a one-off trick but a foundation that can be extended to many real-world vision problems. By understanding how to interpret results, reflect on model behavior, and decide what to try next, you complete the transition from *running a model* to *thinking like a machine learning engineer*.