In this section of the notebook, we fine-tune a pre-trained ResNet model on our own dataset. Fine-tuning means taking a model that has already been trained on a large dataset and continuing to train it on a new, usually smaller, dataset. The idea is to reuse the features the pre-trained model has already learned and adapt them to the new task. For this purpose we use a ResNet-50 model pre-trained on the ImageNet dataset. ImageNet is a large and diverse dataset covering a wide variety of categories, which makes it a good starting point for many vision tasks. During fine-tuning, we replace the model's final fully connected layer so that it fits our task. In the code below, the new layer outputs 3 × 224 × 224 values, which are reshaped into an image. The model is then trained with a mean-squared-error loss to reconstruct its own input image (an autoencoder-style objective). We also have to decide whether to freeze the initial layers or allow some or all of them to change; the code below leaves every layer trainable. By doing this, the model can learn patterns that are specific to our dataset, which could improve its performance.

In [None]:
# Import os for filesystem checks on the image paths
import os

# Import torch for PyTorch functionality
import torch

# Import PIL's Image module for image manipulation
from PIL import Image

# Import transforms and models from torchvision
# Transforms module provides common image transformations, 
# Models provides access to a variety of pre-trained models 
import torchvision.transforms as transforms
import torchvision.models as models

# Import DataLoader and Dataset from torch.utils.data
# DataLoader wraps an iterable around the Dataset to enable easy access to the samples
# Dataset is an abstract class representing a dataset which other datasets should subclass
from torch.utils.data import DataLoader, Dataset

# Import optim module from torch for optimization algorithms
# Optim module has various optimization algorithm implementations like Adam, SGD etc.
from torch import optim

# Import nn module from torch for all neural network modules
import torch.nn as nn

In [None]:
# Define a custom dataset for handling images
class ImageDataset(Dataset):
    """Dataset that loads images from a list of file paths.

    Each item is the image stored at the corresponding path, converted to RGB
    and, if a transform was given, passed through that transform. No labels
    are returned.
    """

    def __init__(self, image_paths, transform=None):
        """
        Constructor for the ImageDataset class.
        Params:
            image_paths: A list of paths to images.
            transform: Optional callable (e.g. torchvision transforms) applied to
                each loaded RGB PIL image.
        """
        self.image_paths = image_paths  # Store the list of image paths
        self.transform = transform  # Store the transform (if any)

    def __len__(self):
        """
        Returns the length of the dataset, i.e., the number of image paths.
        """
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load the image at position ``idx``.
        Params:
            idx: An index into the image list.
        Returns:
            The image converted to RGB: a PIL image when no transform was given,
            otherwise whatever ``self.transform`` returns for that image.
        """
        img_path = self.image_paths[idx]
        # Open inside a context manager so the file handle is always released;
        # Pillow leaves the file open after decoding for some formats (e.g.
        # multi-frame images). ``convert`` returns a new, fully loaded image,
        # so it stays valid after the source image is closed.
        with Image.open(img_path) as source:
            img = source.convert("RGB")
        if self.transform:  # Apply the optional transform
            img = self.transform(img)
        return img


# Define the preprocessing pipeline applied to every loaded image
transform = transforms.Compose([
    # Resize the image to exactly 224x224 pixels (the aspect ratio is not preserved)
    transforms.Resize((224, 224)),

    # Convert the PIL image (H x W x C, values 0-255) into a float tensor of
    # shape C x H x W with values scaled to the range [0, 1]
    transforms.ToTensor(),

    # Normalize each channel by subtracting the given mean and dividing by the
    # given standard deviation. These are the per-channel ImageNet statistics
    # that the pre-trained torchvision models expect their inputs to be normalized with.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Define the list of image paths you want to process.
# Replace the '' with the actual paths to your images. You can add multiple paths separated by commas.
image_paths = ['']

# Fail fast with a clear message. Otherwise a leftover placeholder or a typo in
# a path only surfaces inside the training loop, after the pre-trained model has
# been downloaded. An empty list would fail inside the DataLoader's sampler instead.
if not image_paths:
    raise ValueError("image_paths is empty; add at least one image path.")
_missing_paths = [p for p in image_paths if not os.path.isfile(p)]
if _missing_paths:
    raise FileNotFoundError(
        f"These image paths do not point to existing files "
        f"(replace the '' placeholder with real paths): {_missing_paths!r}"
    )

# Create the custom dataset using the list of image paths and the defined transformation.
# The transformation is applied lazily, each time an image is loaded.
dataset = ImageDataset(image_paths, transform=transform)

# Wrap the dataset in a DataLoader that yields batches of up to 4 images and
# reshuffles the order at the start of every epoch.
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Define the device. If CUDA (a GPU acceleration library) is available, use it. Otherwise, use the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ResNet-50 with ImageNet weights.
# torchvision >= 0.13 provides the `weights=` API and deprecates `pretrained=True`.
# Older releases only accept `pretrained=True`. For resnet50, `pretrained=True`
# corresponds to ResNet50_Weights.IMAGENET1K_V1, so both branches load the same weights.
_resnet50_weights = getattr(models, "ResNet50_Weights", None)
if _resnet50_weights is not None:
    model = models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)

# Replace the classification head with a fully connected layer that outputs one
# value per pixel and channel of a 3 x 224 x 224 image. The flat output is
# reshaped into an image during training.
# 'num_ftrs' is the number of input features of the original head (2048 for ResNet-50).
# NOTE: with 2048 inputs this layer has roughly 308M parameters (~1.2 GB in
# float32), far more than the rest of the network. Gradients and the SGD
# momentum buffer each need about as much memory again.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 3 * 224 * 224)

# Put the model in training mode: layers such as BatchNorm behave differently
# during training and evaluation.
model.train()

# Move all model parameters and buffers to the selected device.
model = model.to(device)

# Loss: mean squared error averaged over every element of the compared tensors.
# Here it scores how closely the model's output image matches the input image.
criterion = nn.MSELoss(reduction="mean")

# Optimizer: stochastic gradient descent with momentum over every model
# parameter (no layer is frozen). The learning rate and momentum are
# hyperparameters that usually need tuning for a new dataset or model.
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)


# Start the training process.
# The model learns to reproduce its own input image (autoencoder-style), so
# each input batch also serves as the regression target.
num_epochs = 2     # number of full passes over the dataset
log_every = 2000   # number of mini-batches between intermediate loss reports

for epoch in range(num_epochs):
    running_loss = 0.0  # loss summed since the last intermediate report
    epoch_loss = 0.0    # loss summed over the whole epoch
    num_batches = 0
    for i, inputs in enumerate(dataloader):
        # Move the batch of images to the same device as the model.
        inputs = inputs.to(device)

        # Gradients accumulate across backward passes by default, so clear them first.
        optimizer.zero_grad()

        # Forward pass. The head produces a flat vector per image; reshape it to
        # the input's shape so it can be compared with the input pixel by pixel.
        outputs = model(inputs).view_as(inputs)

        # Reconstruction loss between the predicted image and the original input.
        loss = criterion(outputs, inputs)

        # Backward pass, then a single parameter update.
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        num_batches += 1

        # Intermediate report: average loss over the last `log_every` mini-batches.
        if (i + 1) % log_every == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}')
            running_loss = 0.0

    # Always report the epoch average. An epoch with fewer than `log_every`
    # mini-batches (typical for small fine-tuning datasets) would otherwise print nothing.
    if num_batches:
        print(f'[{epoch + 1}] epoch average loss: {epoch_loss / num_batches:.3f}')


# Print out that we have finished the training process.
print('Finished Training')

# Persist the trained weights to 'fine_tuned_model.pth'.
# state_dict() maps each parameter/buffer name to its tensor. To reuse the
# weights, build the same architecture (ResNet-50 with the replaced fc layer)
# and call model.load_state_dict(torch.load('fine_tuned_model.pth')).
# If training ran on a GPU, the saved tensors are CUDA tensors; pass
# map_location='cpu' to torch.load when loading on a CPU-only machine.
torch.save(obj=model.state_dict(), f='fine_tuned_model.pth')