In [55]:
## code based on https://sidthoviti.com/fine-tuning-resnet50-pretrained-on-imagenet-for-CINIC-10/

In [56]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
import random


In [57]:
import albumentations as A
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np


class CustomAugmentations:
    '''
    Build an albumentations pipeline from a dictionary of per-augmentation probabilities.

    Parameters:
    p_dict: dict
        Maps an augmentation name (one of the keys of ``self.mapping``) to the
        probability, in [0, 1], with which it is applied. Entries whose
        probability is 0 are left out of the pipeline.

    Attributes:
    transform: list
        The instantiated albumentations transforms, in ``p_dict`` order.
    pipeline: A.Compose
        ``transform`` wrapped into a single callable; this is what ``augment`` runs.

    Raises:
    ValueError
        If a key is unknown or a probability lies outside [0, 1].
    '''
    def __init__(self, p_dict):
        self.p_dict = p_dict
        self.mapping = {
            "flip": A.Flip,
            "transpose": A.Transpose,
            "gauss_noise": A.GaussNoise,
            "blur": A.OneOf,
            "shift_scale_rotate": A.ShiftScaleRotate,
            "distortion": A.OneOf,
            "brightness_contrast": A.OneOf,
            "hue_saturation_value": A.HueSaturationValue,
            "perspective": A.Perspective,
            "rotate": A.Rotate
        }
        if not self.__check_p_dict__():
            raise ValueError("The p_dict is not valid. Please check the values")

        transform = []
        for k, v in self.p_dict.items():
            if v > 0:
                # The three "OneOf" groups pick a single member per call; the
                # group as a whole fires with probability v, hence p=1 inside.
                if k == "blur":
                    transform.append(self.mapping[k]([
                        A.MotionBlur(p=1),
                        A.MedianBlur(blur_limit=3, p=1),
                        A.Blur(blur_limit=3, p=1)
                    ], p=v))
                elif k == "distortion":
                    transform.append(self.mapping[k]([
                        A.OpticalDistortion(p=1),
                        A.GridDistortion(p=1)
                    ], p=v))
                elif k == "brightness_contrast":
                    transform.append(self.mapping[k]([
                        A.CLAHE(clip_limit=2, p=1),
                        A.RandomBrightnessContrast(p=1)
                    ], p=v))
                else:
                    transform.append(self.mapping[k](p=v))
        # Kept as a plain list because callers read ``.transform`` directly.
        self.transform = transform
        # Composed once here so augment() does not rebuild it on every call.
        self.pipeline = A.Compose(transform)

    def __check_p_dict__(self):
        '''
        This method is used to check if the p_dict is valid

        Returns:
        bool
            True if every key is a known augmentation and every probability
            lies in [0, 1], False otherwise
        '''
        for k, v in self.p_dict.items():
            if v < 0 or v > 1 or k not in self.mapping.keys():
                return False
        return True

    def augment(self, image=None):
        '''
        This method is used to apply the augmentations to an image

        Parameters:
        image: np.ndarray
            The image to be augmented

        Returns:
        np.ndarray
            The augmented image

        Raises:
        ValueError
            If no image is supplied
        '''
        if image is None:
            raise ValueError("augment() needs the image to transform")
        return self.pipeline(image=image)['image']

In [58]:
# Probability with which each augmentation is applied; 0.0 switches it off.
# CustomAugmentations walks this mapping in order, so the key order is also
# the order of the resulting transforms.
p_dict = dict(
    flip=0.2,
    transpose=0.0,
    gauss_noise=0.0,
    blur=0.0,
    shift_scale_rotate=0.0,
    distortion=0.0,
    brightness_contrast=0.0,
    hue_saturation_value=0.0,
    perspective=0.0,
    rotate=0.1,
)

In [59]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Placeholder; rebound to the dataset's class names once load_dataset() has run.
classes = list()

In [60]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

train_path = "../data/CINIC10/train"
test_path = "../data/CINIC10/test"
val_path = "../data/CINIC10/valid"

def load_dataset(batch_size=64):
    """
    Build the CINIC-10 training and test datasets and their dataloaders.

    The augmentations configured in the module-level ``p_dict`` are applied to
    training images only; test images are just converted to tensors.

    Parameters:
    batch_size: int
        Number of samples per batch for both dataloaders (default 64).

    Returns:
    tuple
        ``(train_dataset, train_loader, test_dataset, test_loader, classes)``
        where ``classes`` is the list of class names of the training folder.
    """
    augmentation = CustomAugmentations(p_dict).transform
    print(augmentation)
    augmentation_pipeline = A.Compose(augmentation)

    def apply_augmentation(pil_image):
        # albumentations works on numpy arrays, while ImageFolder hands over
        # PIL images. ascontiguousarray guards against flipped views with
        # negative strides, which ToTensor cannot convert.
        augmented = augmentation_pipeline(image=np.array(pil_image))['image']
        return np.ascontiguousarray(augmented)

    transform_train = transforms.Compose([
        transforms.Lambda(apply_augmentation),  # Random augmentations (training only)
        transforms.ToTensor(),  # Convert images to PyTorch tensors
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),  # Convert images to PyTorch tensors
    ])

    train_dataset = ImageFolder(root=train_path, transform=transform_train)
    test_dataset = ImageFolder(root=test_path, transform=transform_test)
    # val_dataset = ImageFolder(root=val_path, transform=transform)

    # Define dataloaders; only the training data is shuffled
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Class names for CINIC-10 dataset
    classes = train_dataset.classes

    return train_dataset, train_loader, test_dataset, test_loader, classes

In [61]:
def train(model, trainloader, criterion, optimizer, device):
    """
    Train ``model`` for one pass over ``trainloader``.

    Parameters:
    model: torch.nn.Module
        The network to optimise; it is switched to train mode.
    trainloader: iterable
        Yields ``(inputs, labels)`` batches.
    criterion: callable
        Loss function called as ``criterion(outputs, labels)``.
    optimizer: torch.optim.Optimizer
        Optimiser stepped once per batch.
    device: torch.device or str
        Device the batches are moved to.

    Returns:
    tuple
        ``(model, train_loss, train_accuracy)``: the average loss per sample
        and the accuracy in percent over the samples actually seen. Both are
        0.0 when the loader yields nothing.
    """
    train_loss = 0.0
    train_total = 0
    train_correct = 0

    # Switch to train mode
    model.train()

    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so a smaller final batch
        # does not skew the average
        batch_size = inputs.size(0)
        train_loss += loss.item() * batch_size

        # Compute training accuracy
        _, predicted = torch.max(outputs, 1)
        train_total += batch_size
        train_correct += (predicted == labels).sum().item()

    # Nothing was iterated: report zeros instead of dividing by zero
    if train_total == 0:
        return model, 0.0, 0.0

    # Average over the samples actually seen, which stays correct even when
    # the loader does not visit every dataset item (sampler, drop_last)
    train_loss = train_loss / train_total
    train_accuracy = 100.0 * train_correct / train_total

    return model, train_loss, train_accuracy

In [62]:
def test(model, testloader, criterion, device):
    test_loss = 0.0
    test_total = 0
    test_correct = 0
    class_correct = [0] * len(testloader.dataset.classes)
    class_total = [0] * len(testloader.dataset.classes)

    # Switch to evaluation mode
    model.eval()

    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Update test loss
            test_loss += loss.item() * inputs.size(0)

            # Compute test accuracy
            _, predicted = torch.max(outputs, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

            # Compute accuracy per class
            for i in range(len(labels)):
                label = labels[i]
                prediction = predicted[i]
                if label == prediction:
                    class_correct[label] += 1
                class_total[label] += 1

    # Compute average test loss and accuracy
    test_loss = test_loss / len(testloader.dataset)
    test_accuracy = 100.0 * test_correct / test_total

    # Compute accuracy per class
    class_accuracy = [100.0 * class_correct[i] / class_total[i] for i in range(len(class_correct))]

    return test_loss, test_accuracy, class_accuracy


In [63]:
def train_epochs(model, trainloader, testloader, criterion, optimizer, device, num_epochs, save_interval=5, scheduler=None):
    """
    Train and evaluate ``model`` for ``num_epochs`` epochs, checkpointing periodically.

    Parameters:
    model: torch.nn.Module
        The network to train.
    trainloader, testloader: torch.utils.data.DataLoader
        Loaders handed to ``train`` and ``test`` respectively.
    criterion: callable
        Loss function.
    optimizer: torch.optim.Optimizer
        Optimiser used by ``train``.
    device: torch.device or str
        Device the batches are moved to.
    num_epochs: int
        Number of epochs to run.
    save_interval: int
        Write the weights and the metric history every this many epochs;
        0 or None disables the periodic checkpoints.
    scheduler: optional
        Learning-rate scheduler; when given, ``scheduler.step()`` is called
        once at the end of every epoch.

    Returns:
    tuple
        ``(model, train_losses, train_accuracies, test_losses, test_accuracies)``
        with one list entry per epoch.
    """
    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        model, train_loss, train_accuracy = train(model, trainloader, criterion, optimizer, device)
        # test() also returns the per-class accuracies; they are not tracked
        # here, so swallow anything after the first two values
        test_loss, test_accuracy, *_ = test(model, testloader, criterion, device)

        if scheduler is not None:
            scheduler.step()

        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)

        print(f'Train Loss: {train_loss:.4f} - Train Accuracy: {train_accuracy:.2f}%')
        print(f'Test Loss: {test_loss:.4f} - Test Accuracy: {test_accuracy:.2f}%')
        print()

        # The truthiness check keeps save_interval=0 from raising on the modulo
        if save_interval and (epoch + 1) % save_interval == 0:
            # Save the model and variables
            torch.save(model.state_dict(), f'resnet50_CINIC10_{epoch+1}.pth')
            checkpoint = {
                'epoch': epoch + 1,
                'train_losses': train_losses,
                'train_accuracies': train_accuracies,
                'test_losses': test_losses,
                'test_accuracies': test_accuracies,
                # NOTE: `classes` is read from module scope, not passed in
                'classes': classes
            }
            torch.save(checkpoint, f'resnet50_CINIC10_variables_{epoch+1}.pth')

    return model, train_losses, train_accuracies, test_losses, test_accuracies


In [64]:
def plot_loss(train_losses, test_losses):
    """
    Draw the per-epoch training and validation loss curves on one figure.

    The figure is written to 'loss_plot.png' and then shown.

    Parameters:
    train_losses: sequence of float
        Training loss, one value per epoch.
    test_losses: sequence of float
        Validation loss, one value per epoch.
    """
    curves = (
        (train_losses, 'Training Loss'),
        (test_losses, 'Validation Loss'),
    )

    plt.figure()
    for values, legend_name in curves:
        # x axis is the zero-based epoch index
        epochs = range(len(values))
        plt.plot(epochs, values, label=legend_name)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.savefig('loss_plot.png')
    plt.show()

def plot_accuracy(train_accuracies, test_accuracies):
    """
    Draw the per-epoch training and validation accuracy curves on one figure.

    The figure is written to 'accuracy_plot.png' and then shown.

    Parameters:
    train_accuracies: sequence of float
        Training accuracy, one value per epoch.
    test_accuracies: sequence of float
        Validation accuracy, one value per epoch.
    """
    curves = (
        (train_accuracies, 'Training Accuracy'),
        (test_accuracies, 'Validation Accuracy'),
    )

    plt.figure()
    for values, legend_name in curves:
        # x axis is the zero-based epoch index
        epochs = range(len(values))
        plt.plot(epochs, values, label=legend_name)

    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.savefig('accuracy_plot.png')
    plt.show()

def plot_image(dataset, model, classes):
    """
    Show one randomly chosen sample together with the model's prediction.

    The figure is written to 'predicted_image.png' and shown, and the
    predicted and actual labels are printed.

    Parameters:
    dataset: indexable
        ``dataset[i]`` must return an ``(image_tensor, label_index)`` pair,
        with the image laid out channels-first.
    model: torch.nn.Module
        Classifier, expected to already live on the module-level ``device``.
    classes: sequence of str
        Class names indexed by label.
    """
    # randrange excludes the upper bound; randint(0, len(dataset)) could
    # return len(dataset) itself, one past the last valid index
    idx = random.randrange(len(dataset))
    # Fetch the sample once so image and label come from the same read
    image, label = dataset[idx]
    img = image.unsqueeze(0).to(device)  # Add a batch dimension and move to the model's device
    model.eval()
    #model.to(device)  # Move the model to the GPU
    with torch.no_grad():  # Inference only; no autograd graph needed
        output = model(img)
    _, predicted = torch.max(output, 1)
    predicted_label = predicted[0].item()
    # Convert the image and show it
    img = img.squeeze(0).permute(1, 2, 0).cpu()  # Drop the batch dimension, go channels-last, back to the CPU
    plt.imshow(img)
    plt.axis('off')
    plt.title(f'Predicted: {classes[predicted_label]}, True: {classes[label]}')
    plt.savefig('predicted_image.png')
    plt.show()
    print("Predicted label: ", classes[predicted_label])
    print("Actual label: ", classes[label])



In [65]:
# Flag to control whether to run training or use saved fine-tuned model.
train_model = True

# Training schedule. Defined before the branch below because the file names
# used for saving and for loading are both derived from these values.
num_epochs = 60
save_interval = 5
final_model_path = f'resnet50_CINIC10_final_model_epochs_{num_epochs}.pth'

# Set random seed for reproducibility
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

# Number of classes
num_classes = 10

# Import ResNet50 model pretrained on ImageNet
model = models.resnet50(pretrained=True)
# print("Network before modifying conv1:")
# print(model)

# Modify conv1 to suit CINIC-10 (the replacement layer is freshly
# initialised, so its weights are not the pretrained ones)
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

# Modify the final fully connected layer according to the number of classes
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)
# print("Network after modifying conv1:")
# print(model)

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
# NOTE(review): this scheduler is not passed to train_epochs below, so the
# learning rate stays constant unless it is stepped somewhere.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Load the dataset
trainset, trainloader, testset, testloader, classes = load_dataset()

if train_model:
    # Train the model for num_epochs epochs, saving every save_interval epochs
    model, train_losses, train_accuracies, test_losses, test_accuracies = train_epochs(
        model, trainloader, testloader, criterion, optimizer, device,
        num_epochs, save_interval)

    # Save the final trained model
    torch.save(model.state_dict(), final_model_path)

    # Plot and save the loss and accuracy plots
    plot_loss(train_losses, test_losses)
    plot_accuracy(train_accuracies, test_accuracies)
else:
    # Load the fine-tuned weights written by the training branch;
    # map_location lets a checkpoint saved on the GPU load on a CPU-only machine
    model.load_state_dict(torch.load(final_model_path, map_location=device))
    # Load the variables from the most recent periodic checkpoint, whose file
    # name carries the epoch at which it was written
    last_saved_epoch = (num_epochs // save_interval) * save_interval
    checkpoint = torch.load(
        f'resnet50_CINIC10_variables_{last_saved_epoch}.pth', map_location=device)
    epoch = checkpoint['epoch']
    train_losses = checkpoint['train_losses']
    train_accuracies = checkpoint['train_accuracies']
    test_losses = checkpoint['test_losses']
    test_accuracies = checkpoint['test_accuracies']
    classes = checkpoint['classes']
    model.to(device)
    model.eval()

# Plot and save an example image
plot_image(testset, model, classes)

[Flip(always_apply=False, p=0.2), Rotate(always_apply=False, p=0.1, limit=(-90, 90), interpolation=1, border_mode=4, value=None, mask_value=None, rotate_method='largest_box', crop_border=False)]
Epoch 1/60
