# Steps to Implement Transfer Learning for Image Classification in PyTorch

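Transfer learning for image classification reuses a neural network that was pre-trained on a large dataset (such as ImageNet) as the starting point for a classification task on a different dataset. Because the network has already learned general visual features, this usually needs less data and training time than training from scratch. Follow these steps to implement transfer learning for image classification: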

1. Choose a pre-trained model (ResNet, VGG, etc.) based on your task.
2. Modify the model, typically by replacing the final classification layer so that its output size matches the number of classes in your new dataset.
3. Freeze the pre-trained layers (make their weights non-trainable) to prevent them from being updated during training on the new dataset. This is especially useful when you have a small dataset.
4. Preprocess your data, including resizing images and normalization.
5. Optionally, perform data augmentation to increase the size and diversity of your dataset.
6. Define the new model architecture by adding the new classifier on top of the pre-trained model.
7. Set up training by specifying the loss function, optimizer, and metrics (the equivalent of "compiling" a model in Keras).
8. Train the model on your new dataset. With the pre-trained layers frozen, training typically needs fewer epochs than training from scratch.
9. Fine-tuning: You can further train the model by unfreezing some or all of the pre-trained layers.
10. Evaluate the model’s performance on a validation or test dataset to assess its accuracy and generalization capabilities.


In [None]:
# Import necessary libraries

import torch
from torchsummary import summary
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor, Normalize
from torchvision.datasets import MNIST, CIFAR10
from torch.utils.data import DataLoader

In [None]:
import torchvision.models as models

#### Step 1: Choose a pre-trained model

In [None]:
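# Load ResNet-50 with ImageNet pre-trained weights (torchvision >= 0.13;
# older versions use models.resnet50(pretrained=True))
pretrained_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)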

#### Step 2: Modify the pre-trained model

In [None]:
class ModifiedResNet(nn.Module):
    def __init__(self, num_classes=10):  # MNIST has 10 classes
        super(ModifiedResNet, self).__init__()
        # Load ResNet-50 with ImageNet weights (torchvision >= 0.13 "weights" API)
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        # Replace the 1000-class ImageNet head with a new fully connected layer for num_classes
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        return self.resnet(x)

model = ModifiedResNet()

#### Step 3: Freeze the pre-trained model's weights

In [None]:
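# Freeze all pre-trained weights...
for param in model.resnet.parameters():
    param.requires_grad = False
# ...but keep the new classification layer trainable
for param in model.resnet.fc.parameters():
    param.requires_grad = True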

#### Step 4: Data preprocessing

In [None]:
transform_train = transforms.Compose([
    transforms.RandomRotation(10),  # Augmentation: rotate by a random angle between -10 and +10 degrees
    transforms.Grayscale(num_output_channels=3),  # Convert to RGB format
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Scale to [-1, 1]; the single value applies to all 3 channels
])

transform_test = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert to RGB format
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

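#### Step 5: Data augmentation (optional) and data loading

The augmentation here is the `RandomRotation(10)` in `transform_train`: a new random angle is drawn each time a training image is loaded, so the model sees slightly different versions of each digit in every epoch.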

In [None]:
train_dataset = MNIST(root='./data', train=True, download=True, transform=transform_train)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

#### Step 6: Define the model architecture
The code below builds the ResNet model used for transfer learning. Here is a breakdown:

1. `super(CustomResNet, self).__init__()`: Calls the constructor of the parent class (`nn.Module`) to initialize the `CustomResNet` instance, so that the submodules assigned afterwards are registered with PyTorch.

2. `self.resnet = pretrained_model`: Stores the ResNet-50 loaded with ImageNet weights in Step 1 as the backbone. Keeping it under the name `resnet` leaves its stages addressable by name, for example `self.resnet.layer4`, which is unfrozen later for fine-tuning.

3. `self.resnet.fc = nn.Identity()`: Replaces the original 1000-class ImageNet head with a layer that passes its input through unchanged, so the backbone now outputs the pooled, flattened feature vector. This has the same effect as rebuilding the network from all of its child modules except the last one, `nn.Sequential(*list(pretrained_model.children())[:-1])`, and flattening the result. The input size of the old head is saved in `in_features` before it is replaced.

4. `self.classifier = nn.Linear(in_features, num_classes)`: Creates a new fully connected layer, `classifier`, whose number of inputs equals the number of features fed into ResNet-50's original final layer (2048) and whose number of outputs is 10, one per MNIST digit class.

In short, the pre-trained ResNet layers act as a feature extractor, and a new classification layer added on top performs the task-specific classification. After the model is created, the backbone is frozen so that, at first, only the new classifier is trained.

In [None]:
class CustomResNet(nn.Module):
    def __init__(self, pretrained_model, num_classes=10):
        super(CustomResNet, self).__init__()
        # Pre-trained ResNet-50 backbone loaded in Step 1
        self.resnet = pretrained_model
        # Save the size of the features fed into the original head (2048 for ResNet-50)
        in_features = self.resnet.fc.in_features
        # Drop the 1000-class ImageNet head; the backbone now returns pooled, flattened features
        self.resnet.fc = nn.Identity()
        # New task-specific classifier
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        x = self.resnet(x)        # (batch, 2048) feature vectors
        x = self.classifier(x)    # (batch, num_classes) logits
        return x

model = CustomResNet(pretrained_model)

# Freeze the pre-trained backbone so that, at first, only the new classifier is trained
for param in model.resnet.parameters():
    param.requires_grad = False

#### Step 7: Set up the loss function and optimizer

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
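# Print a layer-by-layer summary for one 3x28x28 input (MNIST after preprocessing);
# torchsummary defaults to device="cuda", so pass "cpu" for a model on the CPU
summary(model, (3, 28, 28), device="cpu")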
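Unlike Keras's `compile`, PyTorch has no `metrics` argument: the loss and optimizer are plain objects, and metrics are usually computed by hand inside the training loop (as the loop in this notebook does with `torch.max`). A minimal sketch, with `accuracy` as a hypothetical helper and toy logits for three classes:

```python
import torch
import torch.nn as nn

def accuracy(logits, labels):
    """Hypothetical metric helper: fraction of samples whose arg-max logit matches the label."""
    predicted = logits.argmax(dim=1)
    return (predicted == labels).float().mean().item()

criterion = nn.CrossEntropyLoss()  # expects raw (unnormalized) logits and integer class labels

logits = torch.tensor([[2.0, 0.5, 0.1],
                       [0.2, 1.5, 0.3],
                       [0.9, 0.1, 0.4]])
labels = torch.tensor([0, 1, 2])
acc = accuracy(logits, labels)  # about 0.667: the third sample is predicted as class 0
loss = criterion(logits, labels).item()
print(acc, loss)
```

Because `nn.CrossEntropyLoss` applies log-softmax internally, the model's final layer should output raw scores, with no softmax of its own.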

#### Step 8: Train the model

In [None]:
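# Unfreeze the last residual stage (layer4) so it is fine-tuned together with the new classifier
for param in model.resnet.layer4.parameters():
    param.requires_grad = True

# Switch to training mode (batch-norm layers use batch statistics); the training loop follows
model.train()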


#### Step 9: Fine-tune the model

In [None]:
# Fine-tuning: update every parameter that currently has requires_grad=True
num_epochs = 10
train_losses = []
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    train_correct = 0  # reset the accuracy counters every epoch
    train_total = 0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Count correct predictions for this epoch's training accuracy
        _, predicted = torch.max(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(train_loader)  # mean of the per-batch losses
    train_losses.append(epoch_loss)
    train_accuracy = train_correct / train_total
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {train_accuracy:.4f}')

print(f'Finished fine-tuning with {train_accuracy:.4f} training accuracy')

#### Step 10: Evaluate the model

In [None]:
test_dataset = MNIST(root='./data', train=False, download=True, transform=transform_test)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)  # Don't shuffle test data

# Switch batch-norm layers to inference mode so they use their running statistics
model.eval()

# Initialize variables for tracking performance
running_loss = 0.0
correct = 0
total = 0

# A single pass over the test set is enough; gradients are not needed
with torch.no_grad():
    for images, labels in test_loader:
        # Forward pass
        outputs = model(images)

        # Accumulate the loss
        loss = criterion(outputs, labels)
        running_loss += loss.item()

        # Count correct predictions (index of the maximum logit)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Average loss per batch and overall accuracy on the test set
test_loss = running_loss / len(test_loader)
test_accuracy = correct / total
print(f'Test loss: {test_loss:.4f}')
print(f'Accuracy of the model on the test set: {test_accuracy:.4f}')
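A single overall accuracy can hide weak classes. One way to break the result down, assuming you collect `predicted` and `labels` for the whole test set (for example by concatenating them inside the evaluation loop), is per-class accuracy. `per_class_accuracy` is a hypothetical helper, shown here on toy tensors for three classes.

```python
import torch

def per_class_accuracy(predicted, labels, num_classes):
    """Hypothetical helper: accuracy for each class, from predicted and true label tensors."""
    correct = torch.bincount(labels[predicted == labels], minlength=num_classes)
    total = torch.bincount(labels, minlength=num_classes)
    # A class with no samples yields 0/0 = NaN rather than raising an error
    return correct.float() / total.float()

# Toy predictions and labels for three classes
predicted = torch.tensor([0, 1, 1, 2, 2, 0])
labels = torch.tensor([0, 1, 2, 2, 2, 1])
acc = per_class_accuracy(predicted, labels, num_classes=3)
print(acc)  # class 0: 1/1, class 1: 1/2, class 2: 2/3
```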