# TRANSFER LEARNING FOR IMAGE CLASSIFICATION OF COSTA RICAN DISHES USING RESNET50

## Setup for Image Classification with Transfer Learning using ResNet50

In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
import os
import matplotlib.pyplot as plt

# ImageNet per-channel statistics. The pre-trained ResNet50 weights expect inputs
# normalized with these values, so every preprocessing pipeline must use the same ones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Transformations for the training data (augmentation + normalization):
# 1. RandomResizedCrop: randomly crops a region and resizes it to 224x224, the input size used here.
# 2. RandomHorizontalFlip: randomly mirrors the image left-right for augmentation.
# 3. RandomRotation: rotates the image by a random angle in [-25, 25] degrees.
# 4. ColorJitter: randomly varies brightness and contrast for variety in image appearance.
# 5. ToTensor: converts the PIL image to a PyTorch tensor, the format required for model input.
# 6. Normalize: standardizes each channel with the ImageNet mean/std defined above.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=25),
    transforms.ColorJitter(brightness=0.5, contrast=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])


# Transformations for the validation data (deterministic, no augmentation):
# 1. Resize(256): scales the image so its SHORTER side is 256 pixels, keeping the aspect ratio.
# 2. CenterCrop(224): takes the central 224x224 region, matching the training input size.
# 3. ToTensor: converts the image to a PyTorch tensor, suitable for model input.
# 4. Normalize: uses the same ImageNet mean/std as the training data.
validation_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

# Load the datasets. ImageFolder expects one sub-directory per class and assigns
# integer labels from the sorted sub-directory names of EACH folder independently.
# NOTE(review): the training directory is spelled "trainig-data"; confirm it matches the folder on disk.
path_traning_data = "../dataset/trainig-data"
path_validation_data = "../dataset/test-data"
train_dataset = ImageFolder(root=path_traning_data, transform=train_transforms)
validation_dataset = ImageFolder(root=path_validation_data, transform=validation_transforms)

# Because labels are derived per folder, both folders must contain exactly the same
# classes. The classifier head is sized from the training classes, so a mismatch
# would silently misalign validation labels instead of failing.
if train_dataset.classes != validation_dataset.classes:
    raise ValueError(
        "Training and validation class folders differ: "
        f"{train_dataset.classes} vs {validation_dataset.classes}"
    )

# Data loaders: shuffle the training data each epoch; keep the validation order fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=32, shuffle=False)

# Load ResNet50 with ImageNet pre-trained weights as the starting point.
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# Replace the final fully connected layer with one that outputs a logit per class
# found in the training folder, keeping the pre-trained input feature size.
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, len(train_dataset.classes))

# Use the fastest available device: NVIDIA GPU (CUDA), then Apple Silicon (MPS), else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model = model.to(device)

# Cross-entropy loss for multi-class classification. Adam's weight_decay adds
# L2 regularization to ALL parameters (model.parameters()), including the
# pre-trained backbone, which is fine-tuned rather than frozen.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

# Number of full passes over the training data.
num_epochs = 13

## Training and Validation Phases

In [None]:
# Per-epoch history; the visualization cell plots these lists.
training_losses = []
validation_losses = []
validation_accuracies = []

for epoch in range(num_epochs):
    # ---------------- training pass ----------------
    model.train()  # put layers such as batch-norm into training mode
    running_train_loss = 0.0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()              # drop gradients from the previous step
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights

        # CrossEntropyLoss averages over the batch by default; re-weight by the
        # batch size so the epoch value below is a per-sample mean.
        running_train_loss += loss.item() * images.size(0)

    training_loss = running_train_loss / len(train_loader.dataset)

    # ---------------- validation pass ----------------
    model.eval()
    running_val_loss = 0.0
    correct_predictions = 0

    # No autograd graph is needed for evaluation, which saves memory.
    with torch.no_grad():
        for images, labels in validation_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            running_val_loss += loss.item() * images.size(0)

            predicted = torch.max(outputs, 1)[1]  # index of the highest logit per sample
            correct_predictions += (predicted == labels).sum().item()

    num_validation_samples = len(validation_loader.dataset)
    validation_loss = running_val_loss / num_validation_samples
    validation_accuracy = correct_predictions / num_validation_samples

    training_losses.append(training_loss)
    validation_losses.append(validation_loss)
    validation_accuracies.append(validation_accuracy)

    print(f'Epoch: {epoch+1}/{num_epochs}, Training Loss: {training_loss:.4f}, Validation Loss: {validation_loss:.4f}, Validation Accuracy: {validation_accuracy:.4f}')

## Training Dynamics Visualization

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl

# Hide FutureWarning messages so the cell output stays readable.
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

# Make grid lines very faint.
mpl.rcParams['grid.alpha'] = 0.03

# Themed colors for each curve.
training_loss_color = "#008080"
validation_loss_color = "#C0C0C0"
validation_accuracy_color = "#C0C0C0"

# Set the aesthetic style of the plots.
sns.set_style("whitegrid")

# Derive the x axis from the recorded history (not num_epochs) so the plot still
# works if training was interrupted; number epochs from 1 to match the printed summary.
epochs = list(range(1, len(training_losses) + 1))

plt.figure(figsize=(15, 5))

# Left panel: training vs. validation loss.
plt.subplot(1, 2, 1)
sns.lineplot(x=epochs, y=training_losses, label='Training Loss',
             color=training_loss_color, linewidth=3.5)
sns.lineplot(x=epochs, y=validation_losses, label='Validation Loss',
             color=validation_loss_color, linewidth=3.5)
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Right panel: validation accuracy collected during training.
plt.subplot(1, 2, 2)
sns.lineplot(x=epochs, y=validation_accuracies, label='Validation Accuracy',
             color=validation_accuracy_color, linewidth=3.5)
plt.title('Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()



## Saving the model

In [None]:
# Save only the learned parameters (state_dict), not the whole model object.
# torch.save does not create missing parent directories, so create it first.
os.makedirs('models-CNN', exist_ok=True)
torch.save(model.state_dict(), 'models-CNN/model.pth')

## Using the Fine-Tuned ResNet50 Model for Dish Recognition

In [None]:
# Rebuild the ResNet50 architecture used for training. No ImageNet weights are
# downloaded (weights=None): load_state_dict below overwrites every parameter and
# buffer, so fetching them would be wasted work. This also replaces the
# deprecated `pretrained=` argument with the same `weights=` API used for training.
model = resnet50(weights=None)

# Resize the classifier head exactly as during training so the checkpoint's
# fc layer shapes match.
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, len(train_dataset.classes))

# Load the fine-tuned weights. map_location="cpu" lets a checkpoint saved from a
# CUDA/MPS device load on machines without that device; the model is moved to
# the target device below.
model.load_state_dict(torch.load('models-CNN/model.pth', map_location="cpu"))

# Evaluation mode: batch-norm layers use their stored running statistics.
model.eval()

# Move the model to the selected device (CUDA/MPS/CPU).
model = model.to(device)

from PIL import Image

# Prepare the image
def process_image(image_path):
    """Load an image file and turn it into a normalized one-image batch.

    Applies the same deterministic preprocessing as the validation data:
    resize the shorter side to 256, center-crop 224x224, convert to a tensor,
    and normalize with the ImageNet mean/std.

    Args:
        image_path: path to an image file readable by PIL.

    Returns:
        A float tensor of shape (1, 3, 224, 224).
    """
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Force 3 channels. Grayscale, palette, or RGBA images would otherwise
    # yield 1 or 4 channels and fail the 3-channel Normalize. ImageFolder's
    # default loader also converts to RGB, so this matches what training saw.
    # The `with` block closes the underlying file; convert() returns an
    # independent, fully loaded copy.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")

    return transform(rgb_image).unsqueeze(0)  # add a batch dimension

# Make a prediction
# Map class index -> dish name using the class order that ImageFolder used to
# build the training labels. This keeps predictions aligned with the trained
# classifier head even if dish folders are added or renamed.
# NOTE(review): for the current dataset this is expected to be
# {0: 'arroz-con-pollo', 1: 'chifrijo', 2: 'gallo-pinto', 3: 'tamales'} — confirm.
idx2label = dict(enumerate(train_dataset.classes))


def predict_image(image_path):
    """Classify the dish shown in an image file.

    Args:
        image_path: path to the image to classify.

    Returns:
        The predicted class name, looked up in idx2label.
    """
    image = process_image(image_path).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output, 1)  # index of the highest logit
        predicted_class = idx2label[predicted.item()]

    return predicted_class

# Example usage: point image_path at a photo of a dish and classify it,
# e.g. '../dataset/dishes_example/IMG_20230610_172940.jpg'.
image_path = 'your_image_path'
predicted_class = predict_image(image_path)
print(f'Predicted class for Costa Rica dish: {predicted_class}')

# Write the predicted dish name to a UTF-8 text file, replacing any earlier result.
text_to_write = predicted_class
with open('predicted-dish.txt', mode='w', encoding='utf-8') as output_file:
    output_file.write(text_to_write)