In [None]:
# Install and authenticate the Kaggle API (Google Colab cell).
# NOTE(review): the kaggle package is not version-pinned; pin it for reproducible runs.
!pip install -q kaggle

# Upload your kaggle.json (API token file) through the Colab upload widget.
# The uploaded file lands in the current working directory (the cp below reads it from there).
from google.colab import files
files.upload()   # Upload kaggle.json from your Kaggle account

# Make a directory for Kaggle and install the token into it.
# chmod 600 makes the credential readable by the owner only.
# NOTE(review): a copy of kaggle.json stays in the working directory; delete it if
# the runtime or its files are shared.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json


In [None]:
# Download the dataset from Kaggle
!kaggle datasets download -d akhatova/pcb-defects

# Unzip dataset
!unzip -q pcb-defects.zip -d ./pcb_defects


In [None]:
# Importing the libraries used in the exploration section
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import patches
import seaborn as sns
import os
import random
import re
import shutil
sns.set_style('darkgrid')
sns.set_palette('pastel')

import warnings
# NOTE(review): this hides *all* warnings (including deprecation notices from
# pandas/sklearn/torch); consider filtering specific categories instead.
warnings.filterwarnings('ignore')

# Seed the Python and NumPy RNGs (e.g. the random.shuffle of sample images below)
# so every Restart & Run All shows the same samples.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)


# **Understanding the dataset**

In [None]:
# Root of the extracted dataset, relative to the working directory it was unzipped into.
input_dir= "pcb_defects/PCB_DATASET"
# Show the top-level folders (order is whatever the filesystem returns).
os.listdir(input_dir)

* **PCB_USED folder** – contains the 12 template images used to create the dataset.
* **images folder** – contains the PCB images, grouped into one subfolder per defect type.
* **rotation folder** – contains rotated PCB images, grouped by defect type, plus text files listing each image's rotation angle.
* **Annotations folder** – contains the bounding-box annotations (XML) for each image.





**Analyzing the PCB_USED folder**

In [None]:
# Folder holding the PCB template images.
template_dir=os.path.join(input_dir,'PCB_USED')
template_dir

In [None]:
# Creating a function to visualize images
def visualize_img(dir_name, nos_):
    """Show up to ``nos_`` images from ``dir_name`` in a two-column grid.

    Only ``.jpg``/``.jpeg``/``.png`` files (case-insensitive) are considered.
    They are taken in sorted filename order, so the selection is reproducible.
    Each image is labelled with its filename.

    Args:
        dir_name: Directory containing the images.
        nos_: Maximum number of images to display.
    """
    # Filter first, then slice: slicing the raw listing first let non-image
    # entries consume display slots.
    image_files = sorted(
        name for name in os.listdir(dir_name)
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    )[:nos_]
    if not image_files:
        print(f'No images found in {dir_name}')
        return

    n_cols = 2
    # Ceil division: just enough rows for the selected images. This avoids an
    # empty trailing row and a zero-height figure when only one image is shown.
    n_rows = -(-len(image_files) // n_cols)
    plt.figure(figsize=(8, n_rows * 6))
    for k, filename in enumerate(image_files, start=1):
        ax = plt.subplot(n_rows, n_cols, k)
        img = plt.imread(os.path.join(dir_name, filename))
        ax.imshow(img)
        ax.set_xlabel(filename)
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.tight_layout()
    plt.show()

In [None]:
# Preview a few templates, then report how many entries the template folder holds.
visualize_img(template_dir, nos_=4)

n_template_entries = len(os.listdir(template_dir))
print(f'No of template images:{n_template_entries}')

**Analyzing the images folder**

In [None]:
# Directory with one subfolder of images per defect type
img_dir = os.path.join(input_dir, 'images')

# Defect types = the subfolder names. Sorted so the order (and everything derived
# from it, e.g. img_path_list) is reproducible across filesystems; stray files are
# ignored because the next cells list each entry as a directory.
types_defect = sorted(
    name for name in os.listdir(img_dir)
    if os.path.isdir(os.path.join(img_dir, name))
)
types_defect

In [None]:
# Flat list of every image path (all defect types), for quick reference later.
img_path_list = [
    os.path.join(img_dir, defect_type, image_name)
    for defect_type in types_defect
    for image_name in os.listdir(os.path.join(img_dir, defect_type))
]


In [None]:
# Show two sample images per defect type and tally how many images each type has.
df_defect = pd.DataFrame(columns=['No of defect'])  # one row per defect type
for defect_type in types_defect:
    type_dir = os.path.join(img_dir, defect_type)
    visualize_img(type_dir, nos_=2)

    n_images = len(os.listdir(type_dir))
    print(f'No of {defect_type} images:{n_images}')
    df_defect.loc[defect_type] = n_images


In [None]:
# Number of images per defect type (one row per images/ subfolder).
df_defect

**Analyzing the rotation folder**

In [None]:
# Folder with the rotated variants of the defect images, plus the rotation .txt files.
rotated_dir=os.path.join(input_dir,'rotation')
os.listdir(rotated_dir)

In [None]:
# The .txt files in rotation/ hold "<image name> <angle>" lines (parsed further below).
rotated_angle_list = [name for name in os.listdir(rotated_dir) if name.endswith('.txt')]
rotated_angle_list

In [None]:
# Defect-type subfolders of rotation/: every directory entry, sorted for a stable order.
# (Filtering on "is a directory" also ignores stray non-.txt files.)
types_defect_rotated = sorted(
    name for name in os.listdir(rotated_dir)
    if os.path.isdir(os.path.join(rotated_dir, name))
)
types_defect_rotated

In [None]:
# Show two rotated samples per defect type and tally the rotated images per type.
df_defect_rotated = pd.DataFrame(columns=['No of defect'])
for defect_type in types_defect_rotated:
    type_dir = os.path.join(rotated_dir, defect_type)
    visualize_img(type_dir, nos_=2)

    n_images = len(os.listdir(type_dir))
    print(f'No of {defect_type} images:{n_images}')
    df_defect_rotated.loc[defect_type] = n_images

In [None]:
# Number of rotated images per defect type.
df_defect_rotated

In [None]:
# Read every rotation .txt file into one table of (Line, Angle) pairs.
# Rows are collected in a list and turned into a DataFrame once. Concatenating
# inside the loop was quadratic and gave every row the index 0.
rotation_records = []
for filename in rotated_angle_list:
    with open(os.path.join(rotated_dir, filename), 'r') as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue  # tolerate blank lines (e.g. a trailing empty line)
            text, angle = parts
            rotation_records.append({'Line': text, 'Angle': angle})

df_rotation_angle = pd.DataFrame(rotation_records, columns=['Line', 'Angle'])


In [None]:
# Rotation angle per rotated image, as read from the .txt files.
df_rotation_angle

**Analyzing the Annotations folder**

In [None]:
# Folder holding the XML bounding-box annotations.
annote_dir=os.path.join(input_dir,'Annotations')
annote_dir

In [None]:
# Annotation subfolder names (presumably one per defect type, mirroring images/).
type_annot=os.listdir(annote_dir)
type_annot

In [None]:
# Count the annotation files in each annotation subfolder.
df_annot_nos = pd.DataFrame(columns=['No of annotations'])
for annot_type in type_annot:
    annot_type_dir = os.path.join(annote_dir, annot_type)
    df_annot_nos.loc[annot_type] = len(os.listdir(annot_type_dir))
df_annot_nos

We see that there is an annotation file for each image.

In [None]:
# Peek at a few annotation filenames to see which file format they use.
mouse_bite_dir = os.path.join(annote_dir, 'Mouse_bite')
file_list = os.listdir(mouse_bite_dir)
file_list[:5]

All annotation files are in XML format, so they need to be parsed before use.

In [None]:
#importing xml ET to parse xml file
import xml.etree.ElementTree as ET



In [None]:
# Parse one sample annotation to inspect its XML structure.
sample_xml_path = os.path.join(annote_dir, 'Mouse_bite', '01_mouse_bite_11.xml')
tree = ET.parse(sample_xml_path)
root = tree.getroot()

In [None]:
# Print the raw XML of the sample annotation to see its structure.
# tostring(..., encoding='utf8') returns bytes, hence the decode for printing.
print(ET.tostring(root, encoding='utf8').decode('utf8'))

In [None]:
# Parsing an annotation XML file into bounding-box records
def parse_xml(xml_file):
    """Return one record per <object> element of an annotation XML file.

    Each record is a dict with the image-level ``filename``, ``width`` and
    ``height`` (from <filename> and <size>). It also holds the object's ``class``
    (from <name>) and its integer box corners ``xmin``, ``ymin``, ``xmax`` and
    ``ymax`` (from <bndbox>). A file without <object> elements yields ``[]``.

    Args:
        xml_file: Path (or file object) accepted by ``ET.parse``.
    """
    root = ET.parse(xml_file).getroot()

    image_info = {
        'filename': root.find('filename').text,
        'width': int(root.find('size/width').text),
        'height': int(root.find('size/height').text),
    }

    def box_coord(obj, tag):
        # Integer value of one <bndbox> child of an <object>.
        return int(obj.find(f'bndbox/{tag}').text)

    return [
        {
            **image_info,
            'class': obj.find('name').text,
            'xmin': box_coord(obj, 'xmin'),
            'ymin': box_coord(obj, 'ymin'),
            'xmax': box_coord(obj, 'xmax'),
            'ymax': box_coord(obj, 'ymax'),
        }
        for obj in root.findall('object')
    ]

In [None]:
# Parse every annotation XML file into one flat list of per-object records.
# Folders and files are visited in sorted order. This makes df_annot, and the
# filename-based train/val/test split derived from it later, reproducible
# whatever order os.listdir returns. Non-XML files are skipped.
all_data = []
for annot_type in sorted(type_annot):
    annot_type_dir = os.path.join(annote_dir, annot_type)
    for xml_name in sorted(os.listdir(annot_type_dir)):
        if not xml_name.lower().endswith('.xml'):
            continue
        all_data.extend(parse_xml(os.path.join(annot_type_dir, xml_name)))


In [None]:
# One row per annotated defect; columns follow parse_xml's record keys
# (filename, width, height, class, xmin, ymin, xmax, ymax).
df_annot=pd.DataFrame(all_data)
df_annot

In [None]:
import os
import shutil

# Absolute Colab path of the images folder unzipped by the download cell.
img_dir = "/content/pcb_defects/PCB_DATASET/images"

# Source directories: one per defect category inside img_dir.
source_dirs = [
    os.path.join(img_dir, "Missing_hole"),
    os.path.join(img_dir, "Mouse_bite"),
    os.path.join(img_dir, "Open_circuit"),
    os.path.join(img_dir, "Short"),
    os.path.join(img_dir, "Spur"),
    os.path.join(img_dir, "Spurious_copper")
]

# Destination: a single flat, writable folder. The crop Dataset later looks up
# images by bare filename inside this folder.
destination_dir = "/content/images_combined"
os.makedirs(destination_dir, exist_ok=True)

# Copy every file into destination_dir. Because the folder is flat, two categories
# containing the same filename would silently overwrite each other, so track
# where each name came from and report collisions.
copied_from = {}  # filename -> source directory it was copied from in this run
for source_dir in source_dirs:
    if not os.path.exists(source_dir):
        print(f"Directory {source_dir} does not exist.")
        continue
    for file_name in os.listdir(source_dir):
        file_path = os.path.join(source_dir, file_name)
        if not os.path.isfile(file_path):
            continue
        if file_name in copied_from:
            print(f"Warning: {file_name} exists in both {copied_from[file_name]} and "
                  f"{source_dir}; the later copy overwrites the earlier one.")
        shutil.copy(file_path, destination_dir)
        copied_from[file_name] = source_dir

print(f"Number of files copied: {len(copied_from)}")
# The destination can hold more files than were copied now (e.g. left over from earlier runs).
files_in_combined = os.listdir(destination_dir)
print(f"Number of files in {destination_dir}: {len(files_in_combined)}")


In [None]:
# Flat image folder that the model section reads the images from.
destination_dir

In [None]:
# Visualize how many annotated defects each PCB image contains.
# The value_counts Series is named 'count' explicitly. pandas < 2.0 names it
# 'filename', and the x='count' lookup would then fail.
df_multiple_defects = df_annot['filename'].value_counts().rename('count').to_frame()
sns.countplot(data=df_multiple_defects, x='count')
plt.xlabel('No of defects in one PCB')
plt.ylabel('No of PCB images')
plt.title('Defects per PCB image');

In [None]:
# Defining a function to view an image along with its bounding boxes

def draw_bounding_boxes(image_path, bounding_boxes, annotation):
    """Show an image with its bounding boxes drawn in red and labelled.

    Args:
        image_path: Path to the image file.
        bounding_boxes: Iterable of ``(min_x, min_y, max_x, max_y)`` boxes in pixel
            coordinates. Any iterable works (e.g. a ``zip``); it is consumed once.
        annotation: Label text. A single string (or other non-iterable) is used for
            every box. An iterable such as a list gives one label per box, in the
            same order as ``bounding_boxes``.

    Raises:
        ValueError: if per-box labels are given and their count differs from the
            number of boxes.
    """
    img = plt.imread(image_path)
    fig, ax = plt.subplots(figsize=(15, 10))
    ax.imshow(img)

    boxes = list(bounding_boxes)
    if isinstance(annotation, str) or not hasattr(annotation, '__iter__'):
        labels = [annotation] * len(boxes)
    else:
        labels = list(annotation)
        if len(labels) != len(boxes):
            raise ValueError(f"Got {len(labels)} labels for {len(boxes)} bounding boxes")

    for (min_x, min_y, max_x, max_y), label in zip(boxes, labels):
        rect = patches.Rectangle((min_x, min_y), max_x - min_x, max_y - min_y,
                                 linewidth=1, edgecolor='red', facecolor='none')
        ax.add_patch(rect)

        # xy is the box centroid. The text is anchored 20 px beyond the box's
        # max corner; no arrow is drawn because arrowprops is not set.
        centroid = ((min_x + max_x) / 2, (min_y + max_y) / 2)
        ax.annotate(label, centroid, (max_x + 20, max_y + 20),
                    fontsize=10, color='white',
                    horizontalalignment='right', verticalalignment='top')

    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.show()


In [None]:
# Get the bare filename from a full image path (the key used in df_annot['filename']).
# os.path.basename works for any extension; the previous regex only handled ".jpg".
filepath = img_path_list[0]
filename = os.path.basename(filepath)
filename

In [None]:
def visualize_annotations(list_image_path, df):
    """Draw the annotated bounding boxes for each image path in ``list_image_path``.

    Boxes are looked up in ``df`` by bare filename. The class column may be named
    ``class`` (as produced by parse_xml) or ``class_name`` (after the rename done
    before model training). Every box of an image is labelled with the class of
    that image's first annotation row. Images without rows in ``df`` are reported
    and skipped.

    Raises:
        KeyError: if ``df`` has neither a ``class_name`` nor a ``class`` column.
    """
    if 'class_name' in df.columns:
        class_col = 'class_name'
    elif 'class' in df.columns:
        class_col = 'class'
    else:
        raise KeyError("No class column found in dataframe")

    for filepath in list_image_path:
        filename = os.path.basename(filepath)
        df_selected = df[df['filename'] == filename]
        if df_selected.empty:
            # Previously this crashed with IndexError on class_name[0].
            print(f"No annotations found for {filename}; skipping.")
            continue

        bbox = zip(df_selected['xmin'].values, df_selected['ymin'].values,
                   df_selected['xmax'].values, df_selected['ymax'].values)
        draw_bounding_boxes(filepath, bbox, df_selected[class_col].values[0])


In [None]:
# Shuffle a copy so img_path_list keeps its original order. A plain assignment
# would alias the list, and random.shuffle would then reorder img_path_list in place.
image_path_shuffle = list(img_path_list)
random.shuffle(image_path_shuffle)

visualize_annotations(image_path_shuffle[0:5], df_annot)

In [None]:
# ===================================================================
# 1. IMPORTING LIBRARIES AND CONFIGURATION
# ===================================================================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms, utils

from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import time
import copy
import os

# Basic configuration
print("PyTorch Version:", torch.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed the Python, NumPy and PyTorch RNGs used from here on. This covers
# classifier-head init, DataLoader shuffling and the random augmentations, so
# runs are repeatable. Bitwise-identical GPU results would additionally need
# deterministic cuDNN settings.
import random
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


# ===================================================================
# 2. DATA PREPARATION (using df_annot and destination_dir)
# ===================================================================
# Requires 'df_annot' (annotation DataFrame) and 'destination_dir' (flat image
# folder) from the earlier cells.

img_dir = destination_dir

# 'class' is a Python keyword (df.class is a syntax error), so rename the column.
# The guard keeps the cell re-runnable.
if 'class' in df_annot.columns:
    df_annot = df_annot.rename(columns={'class': 'class_name'})

print("Preview of the provided annotation DataFrame:")
print(df_annot.head())

# Get the list of classes and number of classes
class_names = sorted(df_annot['class_name'].unique())
num_classes = len(class_names)
print(f"\n{num_classes} classes detected in the DataFrame: {class_names}")

# --- Train/Validation/Test split (70/15/15) based on IMAGE FILES ---
# Splitting by filename keeps all defects of one image in the same subset, which
# avoids leakage. Each image is labelled with its first annotation's class, and
# the split is stratified on it so every defect type appears in every subset.
# groupby sorts the filenames, so the split does not depend on df_annot's row order.
image_labels = df_annot.groupby('filename')['class_name'].first()
train_files, test_val_files = train_test_split(
    image_labels.index.to_numpy(), test_size=0.3, random_state=42,
    stratify=image_labels.to_numpy())
val_files, test_files = train_test_split(
    test_val_files, test_size=0.5, random_state=42,
    stratify=image_labels.loc[test_val_files].to_numpy())  # 0.3 * 0.5 = 0.15

# Create DataFrames for each dataset
train_df = df_annot[df_annot['filename'].isin(train_files)].reset_index(drop=True)
val_df = df_annot[df_annot['filename'].isin(val_files)].reset_index(drop=True)
test_df = df_annot[df_annot['filename'].isin(test_files)].reset_index(drop=True)

# Sanity checks: the subsets are disjoint and together cover every image.
train_set, val_set, test_set = set(train_files), set(val_files), set(test_files)
assert train_set.isdisjoint(val_set) and train_set.isdisjoint(test_set) and val_set.isdisjoint(test_set), \
    "an image appears in more than one split"
assert len(train_set) + len(val_set) + len(test_set) == len(image_labels)

dataset_sizes = {'train': len(train_df), 'val': len(val_df), 'test': len(test_df)}
print(f"Number of defects - Training: {dataset_sizes['train']}, Validation: {dataset_sizes['val']}, Test: {dataset_sizes['test']}")

# --- Data Augmentation and Normalization ---
# ImageNet channel statistics, matching the ImageNet-pretrained ResNet18 weights.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_PATCH_SIZE = (224, 224)  # every cropped patch is resized to this

# Training pipeline: resize, then random flip/rotation/colour jitter, then normalize.
_train_steps = [
    transforms.Resize(_PATCH_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
]
# Evaluation pipeline (validation and test): resize and normalize only.
_eval_steps = [
    transforms.Resize(_PATCH_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
]

data_transforms = {
    'train': transforms.Compose(_train_steps),
    'val': transforms.Compose(_eval_steps),
}

# --- Custom Dataset Class for Cropping Images ---
class PCBCropDataset(Dataset):
    """Dataset of defect patches: one sample per annotation row.

    Each item crops the row's bounding box out of its source image and returns
    ``(patch, label_index)``. The label index follows the order of ``class_names``.

    Args:
        dataframe: Annotation rows with ``filename``, ``xmin``, ``ymin``, ``xmax``,
            ``ymax`` and ``class_name`` columns.
        image_dir: Folder containing the images referenced by ``filename``.
        class_names: Ordered class names; position i becomes label index i.
        transform: Optional callable applied to the cropped PIL image.
    """

    def __init__(self, dataframe, image_dir, class_names, transform=None):
        self.df = dataframe
        self.image_dir = image_dir
        self.transform = transform
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(class_names)}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(patch, label_index)`` for annotation row ``idx``.

        Raises:
            FileNotFoundError: if the row's image file does not exist.
        """
        row = self.df.iloc[idx]
        img_path = os.path.join(self.image_dir, row['filename'])
        # Fail loudly on a missing image. The previous fallback returned label -1,
        # which nn.CrossEntropyLoss rejects as an out-of-range target and which
        # matches no class in the evaluation reports.
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image not found for annotation row {idx}: {img_path}")

        box = (row['xmin'], row['ymin'], row['xmax'], row['ymax'])
        label_idx = self.class_to_idx[row['class_name']]

        # The context manager closes the underlying file promptly.
        with Image.open(img_path) as image:
            cropped_image = image.convert('RGB').crop(box)

        if self.transform:
            cropped_image = self.transform(cropped_image)

        return cropped_image, label_idx

# --- Creating Datasets and DataLoaders ---
train_dataset = PCBCropDataset(train_df, img_dir, class_names, transform=data_transforms['train'])
val_dataset = PCBCropDataset(val_df, img_dir, class_names, transform=data_transforms['val'])
test_dataset = PCBCropDataset(test_df, img_dir, class_names, transform=data_transforms['val'])

# Use a smaller batch size to start and avoid memory errors
batch_size = 16

# os.cpu_count() may return None, which DataLoader rejects. One worker per core is
# also wasteful on large hosts, so cap the worker count.
MAX_WORKERS = 4
num_workers = min(MAX_WORKERS, os.cpu_count() or 1)
# Pinned host memory speeds up host-to-GPU copies; it is pointless on CPU.
pin_memory = torch.cuda.is_available()

dataloaders = {
    'train': DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers, pin_memory=pin_memory),
    'val': DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                      num_workers=num_workers, pin_memory=pin_memory),
    'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                       num_workers=num_workers, pin_memory=pin_memory),
}

# --- Visualizing a batch of cropped defect patches ---
print("\nVisualizing a batch of cropped defect patches...")
# One shuffled training batch; the patches already went through the training
# transforms (augmentation + normalization).
inputs, classes_idx = next(iter(dataloaders['train']))
# Tile the batch into a single image tensor for display.
out = utils.make_grid(inputs)

# Reverse normalization for display
def imshow(inp, title=None):
    """Display a normalized 3-channel CHW image tensor (e.g. a make_grid output).

    Undoes the ImageNet mean/std normalization applied by the data transforms and
    clips the values to [0, 1] before plotting.

    Args:
        inp: Tensor of shape (3, H, W); it may live on any device.
        title: Optional title. A list/tuple of labels is joined with ", " so it
            reads as text rather than as a Python list repr.
    """
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC for matplotlib
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = np.clip(std * img + mean, 0, 1)

    fig, ax = plt.subplots(figsize=(15, 8))
    ax.imshow(img)
    if title is not None:
        if isinstance(title, (list, tuple)):
            title = ', '.join(str(t) for t in title)
        ax.set_title(title, wrap=True)
    ax.axis('off')
    plt.show()

# Title: the class name of each patch, in grid order.
imshow(out, title=[class_names[x] for x in classes_idx])


# ===================================================================
# 3. PRETRAINED MODEL (ResNet18 used as a frozen feature extractor)
# ===================================================================
def get_pretrained_model(num_classes):
    """Return an ImageNet-pretrained ResNet18 prepared for feature extraction.

    Every pretrained parameter is frozen. Only the freshly created final linear
    layer (``num_classes`` outputs) is trainable. The model is moved to the
    global ``device``.
    """
    backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    backbone.requires_grad_(False)  # freeze all pretrained weights in one call

    # The new head is created after freezing, so its parameters stay trainable.
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone.to(device)



# ===================================================================
# 4. TRAINING AND VALIDATION FUNCTION
# ===================================================================
def train_model(model, criterion, optimizer, num_epochs=25, patience=5,
                loaders=None, device=None):
    """Train ``model`` with per-epoch validation and early stopping on val loss.

    Args:
        model: Module to train (updated in place).
        criterion: Loss function, called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the trainable parameters.
        num_epochs: Maximum number of epochs.
        patience: Stop after this many consecutive epochs without a
            validation-loss improvement. Epoch 1 always improves on the initial
            ``inf``, so stopping can only trigger when ``patience < num_epochs``.
        loaders: Dict with ``'train'`` and ``'val'`` DataLoaders. Defaults to the
            notebook-level ``dataloaders``.
        device: Device that batches are moved to. Defaults to the device of the
            model's first parameter.

    Returns:
        ``(model, history)``. ``model`` is loaded with the weights of the epoch with
        the lowest validation loss. ``history`` maps ``train_loss``, ``train_acc``,
        ``val_loss`` and ``val_acc`` to per-epoch lists of floats.
    """
    if loaders is None:
        loaders = dataloaders
    if device is None:
        device = next(model.parameters()).device

    since = time.time()
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}' + ' | ' + '-'*10)
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)
            running_loss, running_corrects, n_seen = 0.0, 0, 0

            for inputs, labels in loaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                if is_train:
                    optimizer.zero_grad()
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    preds = outputs.argmax(dim=1)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                n_batch = inputs.size(0)
                running_loss += loss.item() * n_batch
                running_corrects += (preds == labels).sum().item()
                n_seen += n_batch

            # Normalize by the number of samples actually seen (the dataset size),
            # guarding against an empty loader.
            epoch_loss = running_loss / max(n_seen, 1)
            epoch_acc = running_corrects / max(n_seen, 1)
            print(f'{phase.capitalize()} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)

            if not is_train:
                if epoch_loss < best_loss:
                    print(f"Validation loss improved ({best_loss:.4f} -> {epoch_loss:.4f}). Saving model...")
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"\nEarly stopping triggered after {patience} epochs with no improvement.")
            break
        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed//60:.0f}m {time_elapsed%60:.0f}s')
    # '.4f' (4 decimals); the previous '4f' meant width 4 with 6 decimals.
    print(f'Best Validation Loss: {best_loss:.4f}')

    model.load_state_dict(best_model_wts)
    return model, history


# ===================================================================
# 5. TRAINING EXECUTION
# ===================================================================
# Release cached GPU memory left over from earlier cells before training.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Instantiate the model
resnet_model = get_pretrained_model(num_classes)

# Define the loss function and optimizer.
criterion = nn.CrossEntropyLoss()
# The backbone is frozen, so only the new head's parameters go to the optimizer.
optimizer = optim.Adam((p for p in resnet_model.parameters() if p.requires_grad), lr=0.001)

# Start training
best_resnet_model, history = train_model(resnet_model, criterion, optimizer, num_epochs=20, patience=5)


# ===================================================================
# 6. ANALYSIS AND EVALUATION OF RESULTS
# ===================================================================
# --- Display learning curves ---
def plot_history(history, model_name):
    """Plot per-epoch train/validation accuracy (left) and loss (right).

    ``history`` is the dict returned by ``train_model`` (keys ``train_acc``,
    ``val_acc``, ``train_loss``, ``val_loss``).
    """
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(18, 6))
    fig.suptitle(f"Learning curves for {model_name}", fontsize=16)

    panels = [
        (acc_ax, 'Accuracy', ('train_acc', 'Train Acc'), ('val_acc', 'Val Acc')),
        (loss_ax, 'Loss', ('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')),
    ]
    for ax, panel_title, *curves in panels:
        for history_key, curve_label in curves:
            ax.plot(history[history_key], label=curve_label)
        ax.set_title(panel_title)
        ax.set_xlabel('Epoch')
        ax.legend()
        ax.grid(True)
    plt.show()

plot_history(history, "ResNet18 on cropped patches")

# --- Evaluation on the test set ---
def evaluate_model(model, dataloader, model_name, target_names=None):
    """Print a classification report and plot the confusion matrix on ``dataloader``.

    Args:
        model: Trained classifier returning per-class scores; inputs are moved to
            the device of its parameters.
        dataloader: Yields ``(inputs, labels)`` batches.
        model_name: Name used in the printed header and the plot title.
        target_names: Ordered class names (index i = label i). Defaults to the
            notebook-level ``class_names``.
    """
    names = class_names if target_names is None else target_names
    # Pin the full label set so the report and the matrix always cover every class,
    # even one that is absent from this split or never predicted. Without it,
    # classification_report raises when the observed class count differs from
    # len(target_names), and the matrix rows could misalign with the tick labels.
    label_ids = list(range(len(names)))
    model_device = next(model.parameters()).device

    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs.to(model_device))
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(outputs.argmax(dim=1).cpu().numpy())

    print(f"\n--- Final evaluation of model '{model_name}' on the Test Set ---\n")
    print(classification_report(y_true, y_pred, labels=label_ids, target_names=names,
                                digits=4, zero_division=0))

    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=names, yticklabels=names, ax=ax)
    ax.set_title(f'Confusion Matrix - {model_name}', fontsize=16)
    ax.set_xlabel('Predictions')
    ax.set_ylabel('True Labels')
    plt.show()

# Test-set report and confusion matrix for the feature-extraction ResNet18.
evaluate_model(best_resnet_model, dataloaders['test'], "ResNet18")


In [None]:
# ===================================================================
# 7. SAVE THE TRAINED RESNET18 MODEL (SAFE WAY)
# ===================================================================

# Path to save inside Google Colab /content directory
# NOTE(review): /content does not persist after the Colab runtime is recycled;
# download the file or copy it to Drive to keep it.
RESNET_MODEL_PATH = "/content/resnet18_trained.pth"

# Save only the weights (state_dict) - safest method.
# To reload: rebuild the same architecture first (e.g. get_pretrained_model(num_classes)),
# then call load_state_dict on it.
torch.save(best_resnet_model.state_dict(), RESNET_MODEL_PATH)

print(f"✅ Trained ResNet18 weights saved at: {RESNET_MODEL_PATH}")


In [None]:
# ===================================================================
# 5. SAVE THE TRAINED CNN MODEL "FROM SCRATCH"
# ===================================================================
# NOTE: nothing earlier in this notebook defines `best_cnn_scratch_model`. The cell
# saves it only when it exists (e.g. trained in this session by another cell);
# otherwise it skips, instead of breaking Restart & Run All with a NameError.

# Paths to save inside Google Colab /content directory
MODEL_PATH = "/content/cnn_scratch_model.pth"
MODEL_FULL_PATH = "/content/cnn_scratch_model_full.pth"

if 'best_cnn_scratch_model' in globals():
    # Save the model's state_dict (recommended way in PyTorch)
    torch.save(best_cnn_scratch_model.state_dict(), MODEL_PATH)
    print(f"✅ Model saved successfully at: {MODEL_PATH}")

    # Entire pickled module; reloading it requires the model class to be importable.
    torch.save(best_cnn_scratch_model, MODEL_FULL_PATH)
    print(f"✅ Full model saved at: {MODEL_FULL_PATH}")
else:
    print("⚠️ best_cnn_scratch_model is not defined in this session; skipping the CNN save.")


In [None]:
print("\n--- Start of Fine-Tuning ResNet18 ---")

# Clear the GPU cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# 1. Load a fresh ImageNet-pretrained ResNet18 (independent of the
#    feature-extraction model trained above).
finetune_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# 2. Keep only the last two residual stages (layer3 and layer4) trainable.
#    Pretrained parameters are trainable by default, so this loop explicitly
#    freezes everything else. The old fc is replaced below, so its flag is irrelevant.
for name, param in finetune_model.named_parameters():
    param.requires_grad = name.startswith(("layer3.", "layer4."))

# 3. Replace the classification head; a newly created nn.Linear is trainable.
num_ftrs = finetune_model.fc.in_features
finetune_model.fc = nn.Linear(num_ftrs, num_classes)
finetune_model = finetune_model.to(device)

# NOTE(review): BatchNorm layers in the frozen stages still update their running
# statistics while the model is in train() mode. Put them in eval() mode if the
# pretrained statistics should stay fixed.

# 4. Discriminative learning rates: a higher lr for the new head, a very low lr
#    for the unfrozen pretrained stages.
optimizer_finetune = optim.Adam([
    {'params': finetune_model.fc.parameters(), 'lr': 1e-3},
    {'params': finetune_model.layer4.parameters(), 'lr': 1e-5},  # Very low lr
    {'params': finetune_model.layer3.parameters(), 'lr': 1e-5}   # Very low lr
])

print("ResNet18 model prepared for fine-tuning.")

# 5. Launch training with the same function
# (criterion, train_model, dataloaders, etc. come from the cells above).
FINETUNE_EPOCHS = 5    # short run; increase it (e.g. to 20) if val loss is still improving
FINETUNE_PATIENCE = 5  # early stopping can only trigger when patience < num_epochs
best_finetune_model, history_finetune = train_model(
    finetune_model,
    criterion,
    optimizer_finetune,
    num_epochs=FINETUNE_EPOCHS,
    patience=FINETUNE_PATIENCE
)

# 6. Evaluate the fine-tuned model
print("\n--- Evaluation of the Fine-Tuned ResNet18 Model ---")
evaluate_model(best_finetune_model, dataloaders['test'], "ResNet18 (Fine-Tuned)")


In [None]:
# ===================================================================
# 7. SAVE THE FINE-TUNED RESNET18 MODEL
# ===================================================================

# Path to save inside Google Colab /content directory
FINETUNE_MODEL_PATH = "/content/resnet18_finetuned.pth"

# Save the model's state_dict (recommended); reload it into a ResNet18 whose fc
# has been replaced with nn.Linear(num_ftrs, num_classes), as in the fine-tuning cell.
torch.save(best_finetune_model.state_dict(), FINETUNE_MODEL_PATH)
print(f"✅ Fine-tuned ResNet18 model weights saved at: {FINETUNE_MODEL_PATH}")

# Optionally, save the entire model (architecture + weights) as a pickled module.
# NOTE(review): loading a pickled module may require torch.load(..., weights_only=False)
# on newer PyTorch versions; only load such files from trusted sources.
FINETUNE_FULL_MODEL_PATH = "/content/resnet18_finetuned_full.pth"
torch.save(best_finetune_model, FINETUNE_FULL_MODEL_PATH)
print(f"✅ Full fine-tuned ResNet18 model saved at: {FINETUNE_FULL_MODEL_PATH}")


In [None]:
# ===================================================================
# 7. SAVE THE FINE-TUNED RESNET18 MODEL (RECOMMENDED WAY)
# ===================================================================
# NOTE(review): this cell duplicates the state_dict save of the previous cell (same
# path, same model), so running it only rewrites the same file; consider deleting it.

# Path to save inside Google Colab /content directory
FINETUNE_MODEL_PATH = "/content/resnet18_finetuned.pth"

# Save only the model weights (state_dict)
torch.save(best_finetune_model.state_dict(), FINETUNE_MODEL_PATH)
print(f"✅ Fine-tuned ResNet18 model weights saved at: {FINETUNE_MODEL_PATH}")


In [None]:
# ===================================================================
# 1. RETRIEVAL OF METRICS FOR THE 3 MODELS
# ===================================================================
# Standalone evaluation helper that returns the metrics instead of printing and
# plotting them (unlike evaluate_model).


def get_final_metrics(model, dataloader, device, class_names):
    """Evaluate ``model`` on ``dataloader`` and return ``(accuracy, weighted_f1)``.

    Args:
        model: Classifier returning per-class scores.
        dataloader: Yields ``(inputs, labels)`` batches; empty batches are skipped.
        device: Device the inputs are moved to.
        class_names: Ordered class names (index i = label i).

    Returns:
        ``(accuracy, weighted-average F1)`` as floats. The accuracy is 0.0 when no
        samples were seen.
    """
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            # Skip empty batches.
            if inputs.size(0) == 0:
                continue

            outputs = model(inputs.to(device))
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(outputs.argmax(dim=1).cpu().numpy())

    # Accuracy is computed directly. When `labels` is passed and some class never
    # occurs, classification_report reports 'micro avg' instead of an 'accuracy' key.
    n_correct = sum(int(t == p) for t, p in zip(y_true, y_pred))
    accuracy = n_correct / len(y_true) if y_true else 0.0

    # A fixed label set keeps the report valid when a class is absent from the
    # split or never predicted (otherwise the class count != len(target_names)).
    report = classification_report(
        y_true, y_pred, labels=list(range(len(class_names))),
        target_names=class_names, output_dict=True, zero_division=0)
    f1_score = report['weighted avg']['f1-score']

    return accuracy, f1_score

print("Retrieving final metrics for each model on the test set...")

# (chart label, printed name, notebook variable). A model not created in this
# session, such as the from-scratch CNN that no earlier cell defines, is skipped
# instead of raising NameError.
candidate_models = [
    ('ResNet\n(Feature Extract)', 'ResNet (Feature Extract)', 'best_resnet_model'),
    ('CNN\n(From Scratch)', 'CNN (From Scratch)', 'best_cnn_scratch_model'),
    ('ResNet\n(Fine-Tuned)', 'ResNet (Fine-Tuned)', 'best_finetune_model'),
]

model_labels, printed_names, accuracies, f1_scores = [], [], [], []
for chart_label, printed_name, var_name in candidate_models:
    model_obj = globals().get(var_name)
    if model_obj is None:
        print(f"Skipping {printed_name}: `{var_name}` is not defined in this session.")
        continue
    acc, f1 = get_final_metrics(model_obj, dataloaders['test'], device, class_names)
    model_labels.append(chart_label)
    printed_names.append(printed_name)
    accuracies.append(acc)
    f1_scores.append(f1)

print("\n--- FINAL RESULTS ON THE TEST SET ---")
for printed_name, acc, f1 in zip(printed_names, accuracies, f1_scores):
    print(f"{printed_name + ':':<26}Accuracy = {acc:.4f}, F1-Score = {f1:.4f}")


# ===================================================================
# 2. COMPARATIVE VISUALIZATION WITH BAR CHARTS
# ===================================================================
if not model_labels:
    print("No trained models available to compare.")
else:
    x = np.arange(len(model_labels))  # Positions of labels on the x-axis
    width = 0.35  # Width of the bars

    fig, ax = plt.subplots(figsize=(14, 8))

    # Grouped bars: accuracy on the left, weighted F1 on the right of each tick.
    rects1 = ax.bar(x - width/2, accuracies, width, label='Accuracy', color='cornflowerblue')
    rects2 = ax.bar(x + width/2, f1_scores, width, label='Weighted F1-Score', color='lightcoral')

    ax.set_ylabel('Scores', fontsize=14)
    ax.set_title('Final Comparison of Model Performances', fontsize=18)
    ax.set_xticks(x)
    ax.set_xticklabels(model_labels, fontsize=12)
    ax.legend(fontsize=12)
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Value labels on each bar
    ax.bar_label(rects1, padding=3, fmt='%.4f')
    ax.bar_label(rects2, padding=3, fmt='%.4f')

    # Leave headroom above the tallest bar for the value labels.
    ax.set_ylim(0, max(accuracies + f1_scores) * 1.15)

    fig.tight_layout()
    plt.show()
