<a href="https://colab.research.google.com/github/aleks-tu/XAI4CV-Projekt/blob/main/Fine_Tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Base: https://medium.com/@engr.akhtar.awan/how-to-fine-tune-the-resnet-50-model-on-your-target-dataset-using-pytorch-187abdb9beeb

import torch
import torchvision
from torchvision import datasets, transforms
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn
import torch.optim as optim

In [2]:
# ----------------------------
# Download Dataset
# ----------------------------

# Kaggle "Fruit and Vegetable Image Recognition" dataset, 36 classes:
# https://www.kaggle.com/datasets/kritikseth/fruit-and-vegetable-image-recognition?resource=download
import kagglehub

KAGGLE_DATASET = "kritikseth/fruit-and-vegetable-image-recognition"

# Fetch the latest version; the returned value is the directory containing the dataset files
path = kagglehub.dataset_download(KAGGLE_DATASET)
print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/fruit-and-vegetable-image-recognition


In [3]:
# ----------------------------
# Data Preparation
# ----------------------------

# Preprocessing has to match what the pretrained checkpoint was trained with, so it is taken
# from the weights enum rather than hand-writing Resize / CenterCrop / Normalize.
# See https://docs.pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html
weights = ResNet50_Weights.IMAGENET1K_V2
transform = weights.transforms()

# Define a specific image loader to deal with images containing transparent pixels (make them white)
from PIL import Image
from typing import Union
from pathlib import Path

# Known-corrupt training images. They are keyed by (split, class folder, file name) so the lookup
# works wherever the dataset is stored. Each maps to a neighbouring file in the SAME class folder,
# so ImageFolder's sample list and labels stay unchanged.
CORRUPT_IMAGE_REPLACEMENTS = {
    ("train", "bell pepper", "Image_56.jpg"): "Image_55.jpg",
    ("train", "potato", "Image_69.png"): "Image_68.png",
    ("train", "carrot", "Image_68.png"): "Image_67.png",
    ("train", "soy beans", "Image_13.png"): "Image_12.png",
    ("train", "paprika", "Image_26.png"): "Image_25.png",
}


def rgba_pil_loader(path: Union[str, Path]) -> Image.Image:
    """Load an image as RGB, compositing any transparency onto a white background.

    Meant to be passed as ``loader=`` to ``torchvision.datasets.ImageFolder``.

    Args:
        path: Image file to load. Files listed in ``CORRUPT_IMAGE_REPLACEMENTS`` are
            redirected to their replacement in the same folder.

    Returns:
        A fully loaded ``"RGB"`` image. Pixels that were (partially) transparent are
        blended onto white instead of exposing whatever colour sits under the alpha.
    """
    file_path = Path(path)
    replacement = CORRUPT_IMAGE_REPLACEMENTS.get(file_path.parts[-3:])
    if replacement is not None:
        file_path = file_path.with_name(replacement)

    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(file_path, "rb") as f:
        img = Image.open(f)
        # Palette / RGB / L images can carry transparency in `info` instead of an alpha band
        has_transparency = img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info
        if not has_transparency:
            # convert() reads the pixel data, so the result is usable after the file closes
            return img.convert("RGB")
        rgba = img.convert("RGBA")

    white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(white, rgba).convert("RGB")



# # Load the data
# train_data = torchvision.datasets.ImageFolder(root="/kaggle/input/fruit-and-vegetable-image-recognition/train", loader=rgba_pil_loader, transform=transform)
# test_data = torchvision.datasets.ImageFolder(root="/kaggle/input/fruit-and-vegetable-image-recognition/validation", loader=rgba_pil_loader, transform=transform)

# Load the data from the directory kagglehub returned in the download cell (`path`). The previous
# hard-coded "/kaggle/input/..." root only exists when the dataset is mounted on Kaggle.
dataset_root = Path(path)
train_data = torchvision.datasets.ImageFolder(root=dataset_root / "train", transform=transform)
# NOTE: the dataset's "validation" split is what this notebook reports as "test" below
test_data = torchvision.datasets.ImageFolder(root=dataset_root / "validation", transform=transform)

# Define the dataloaders (shuffle only the training split; evaluation order does not matter)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False, num_workers=2)

In [4]:
# ----------------------------
# Modify Model
# ----------------------------

# Start from the ImageNet-pretrained ResNet-50 whose preprocessing (`transform`) is used above
model = resnet50(weights=weights)

# Transfer-learning setup: only the last residual stage ("layer4") keeps trainable weights.
# (To train only the new classifier head instead, freeze every parameter here.)
# NOTE(review): requires_grad=False freezes weights only; BatchNorm running statistics in the
# frozen stages still update whenever the model is in train() mode — confirm this is intended.
for name, param in model.named_parameters():
    if "layer4" in name:
        continue
    param.requires_grad = False

# Swap the ImageNet classifier for a fresh head with one output per dataset class. It is created
# after the freezing loop, so its parameters remain trainable.
num_features = model.fc.in_features
num_classes = len(train_data.classes)
model.fc = nn.Linear(num_features, num_classes)
print(f"Number of classes: {num_classes}")

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 189MB/s]


Number of classes: 36


In [5]:
# ----------------------------
# Prepare Training
# ----------------------------

# Move the model to the device *before* constructing the optimizer, as the torch.optim docs advise
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Define the loss function and optimizer.
# Only the parameters left trainable by the freezing step (layer4 + new fc head) go to the
# optimizer; frozen ones never receive gradients, so listing them only obscures intent.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.001, momentum=0.9
)

# ----------------------------
# Train
# ----------------------------

num_epochs = 10

for epoch in range(num_epochs):
    # --- Training pass ---
    model.train()
    train_loss = 0.0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion averages over the batch; re-weight by batch size so the per-sample
        # average computed after the epoch is exact even when the last batch is smaller.
        train_loss += loss.item() * inputs.size(0)

    # --- Evaluation pass on the held-out ("validation" folder) split ---
    model.eval()
    test_loss = 0.0
    correct = 0  # plain int count, so test_acc below is an ordinary Python float
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            test_loss += loss.item() * inputs.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()

    # Per-sample averages over each split
    train_loss /= len(train_data)
    test_loss /= len(test_data)
    test_acc = correct / len(test_data)
    print(f"Epoch [{epoch + 1}/{num_epochs}] Train Loss: {train_loss:.4f} Test Loss: {test_loss:.4f} Test Acc: {test_acc:.4f}")


torch.save(model.state_dict(), "fine-tuned-resnet.pt")



Epoch [1/10] Train Loss: 3.3280 Test Loss: 2.8332 Test Acc: 0.6980




Epoch [2/10] Train Loss: 2.4415 Test Loss: 1.5215 Test Acc: 0.8063




Epoch [3/10] Train Loss: 1.4620 Test Loss: 0.7536 Test Acc: 0.8632




Epoch [4/10] Train Loss: 0.9428 Test Loss: 0.4841 Test Acc: 0.8917




Epoch [5/10] Train Loss: 0.7214 Test Loss: 0.3698 Test Acc: 0.9060




Epoch [6/10] Train Loss: 0.5970 Test Loss: 0.2978 Test Acc: 0.9117




Epoch [7/10] Train Loss: 0.5031 Test Loss: 0.2674 Test Acc: 0.9174




Epoch [8/10] Train Loss: 0.4360 Test Loss: 0.2392 Test Acc: 0.9231




Epoch [9/10] Train Loss: 0.3891 Test Loss: 0.2130 Test Acc: 0.9259




Epoch [10/10] Train Loss: 0.3608 Test Loss: 0.2028 Test Acc: 0.9288


Fine-tuning all layers (excerpt from a 40-epoch run):
Epoch [9/40] Train Loss: 0.2780 Test Loss: 0.1734 Test Acc: 0.9487
Epoch [11/40] Train Loss: 0.2054 Test Loss: 0.1482 Test Acc: 0.9487

Training only the final FC layer (excerpt from a 40-epoch run):
Epoch [15/40] Train Loss: 0.8744 Test Loss: 0.6045 Test Acc: 0.8832
Epoch [16/40] Train Loss: 0.8436 Test Loss: 0.5920 Test Acc: 0.8917
Epoch [17/40] Train Loss: 0.8195 Test Loss: 0.5519 Test Acc: 0.9060
Epoch [18/40] Train Loss: 0.7959 Test Loss: 0.5375 Test Acc: 0.8974

Training the last sequential block (layer4) and the FC layer (final epochs of the 10-epoch run above):
Epoch [8/10] Train Loss: 0.4360 Test Loss: 0.2392 Test Acc: 0.9231
Epoch [9/10] Train Loss: 0.3891 Test Loss: 0.2130 Test Acc: 0.9259
Epoch [10/10] Train Loss: 0.3608 Test Loss: 0.2028 Test Acc: 0.9288


