In [None]:
### IMPORT LIBRARIES ###


# Import necessary libraries
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from IPython.display import display, clear_output
import seaborn as sns

import torch
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
from torchvision.transforms import Resize, CenterCrop, ToTensor, Normalize
from torchvision import models
from torchvision.models import efficientnet_b0
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
### DATASET AND DATALOADER SETUP ###


# Custom dataset class for loading and transforming images
class CustomDataset(Dataset):
    """Image dataset driven by a DataFrame.

    Every row is read for two columns: 'Image path', which is opened with
    PIL and converted to RGB, and 'Label', which is returned unchanged
    alongside the image.
    """

    def __init__(self, dataframe, transform=None):
        """Store the backing DataFrame and the optional image transform."""
        self.transform = transform
        self.dataframe = dataframe

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        Rows are looked up positionally (``.iloc``), not by index label.
        """
        row = self.dataframe.iloc[idx]
        path = row['Image path']
        target = row['Label']
        picture = Image.open(path).convert('RGB')  # force 3-channel RGB

        # Without a (truthy) transform the raw RGB image is handed back.
        if not self.transform:
            return picture, target

        return self.transform(picture), target
    

# Load the CSV file; its first column becomes the DataFrame index.
df = pd.read_csv('dino.csv', index_col=[0])
type_to_label = {'ankylosaurus': 0, 'brontosaurus': 1, 'pterodactyl': 2, 'trex': 3, 'triceratops': 4}
# A bare `df.head()` in the middle of a cell is silently discarded, so show
# the preview explicitly.
display(df.head())

# Data transformations: resize, centre-crop to 224x224, convert to a tensor
# and normalise each channel with the mean/std values given below.
transform = transforms.Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Seed for the evaluation split. Set to None to draw a different random
# subset on every run.
SPLIT_SEED = 42

# Initialize your Dataset and DataLoader
custom_dataset = CustomDataset(dataframe=df, transform=transform)

# Split over row *positions*, not index labels: Subset forwards these values
# to CustomDataset.__getitem__, which looks rows up with `.iloc`. The index
# comes from the CSV's first column and is not guaranteed to be 0..n-1, so
# using `df.index` here could select the wrong rows or raise IndexError.
_, random_indices = train_test_split(
    list(range(len(df))),
    test_size=0.2,
    stratify=df['Label'].values,
    random_state=SPLIT_SEED,
)
random_test_dataset = Subset(custom_dataset, random_indices)

# Creating DataLoaders
batch_size = 4
# Evaluation metrics do not depend on sample order, so shuffling is left off
# to keep runs deterministic.
dataloader = DataLoader(random_test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
### LOADS THE BEST MODEL ###

# Define the path where you expect the model to be saved
model_save_path = './best_model.pth'

# Fail fast when the checkpoint is missing: continuing would leave
# `model_clone` undefined and surface later as a confusing NameError.
if not os.path.exists(model_save_path):
    raise FileNotFoundError(f"Model file not found: {model_save_path}")
print(f"Model file found: {model_save_path}")

try:
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code -- only load checkpoints from a trusted source.
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # machine; the model is moved to the target device afterwards.
    state_dict = torch.load(model_save_path, map_location='cpu')
    print("Model state dictionary loaded successfully.")

    # Rebuild the architecture the checkpoint is loaded into: a ResNet-18
    # whose final layer has one output per class. Calling resnet18() with no
    # arguments yields randomly initialised weights and avoids the deprecated
    # `pretrained=` keyword.
    model_clone = models.resnet18()
    num_ftrs = model_clone.fc.in_features
    model_clone.fc = nn.Linear(num_ftrs, len(type_to_label))  # one logit per class
    model_clone.load_state_dict(state_dict)
    print("Model loaded successfully with the state dictionary.")
except Exception as e:
    # Report the problem, then re-raise so the notebook stops here instead of
    # failing in a later cell because the model was never created.
    print(f"Error loading the model: {e}")
    raise

In [None]:
### RANDOM TESTING TO SEE MODEL ACCURACY ###


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_clone = model_clone.to(device)

# Ensure model_clone is in evaluation mode before making predictions
model_clone.eval()

# Fixed class order shared by the report and the confusion matrix.
# Assumes the values in df['Label'] use the integer ids of `type_to_label`
# -- TODO confirm against how dino.csv was generated.
class_names = list(type_to_label.keys())
class_ids = list(type_to_label.values())

all_predictions, all_labels = [], []

with torch.no_grad():
    for inputs, labels in dataloader:
        outputs = model_clone(inputs.to(device))
        # Index of the largest logit per sample is the predicted class.
        predicted = outputs.argmax(dim=1)

        # Collect all predictions and labels for further analysis. Labels are
        # only compared on the host, so they are never moved to the device.
        all_predictions.extend(predicted.cpu().numpy())
        all_labels.extend(labels.numpy())

# Calculate accuracy
accuracy = 100 * np.mean(np.array(all_predictions) == np.array(all_labels))
print(f'Accuracy on the dataset: {accuracy:.2f}%')

# Calculate additional metrics. Passing `labels=` pins the class set so the
# report still works (and the names stay aligned) when a class happens to be
# absent from the sampled subset.
print("\nDetailed classification report:")
print(classification_report(all_labels, all_predictions,
                            labels=class_ids, target_names=class_names))

# Plot confusion matrix. `labels=` keeps it square over all classes so the
# tick labels always line up with the rows/columns.
conf_matrix = confusion_matrix(all_labels, all_predictions, labels=class_ids)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_title('Confusion Matrix')
ax.set_xlabel('Predicted Labels')
ax.set_ylabel('True Labels')
plt.show()