In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, models
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm
import random

# ===========================
# 1. Set Random Seeds for Reproducibility
# ===========================
def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), then configures cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels alone are not sufficient: if benchmark mode was
    # switched on earlier in the session, cuDNN may auto-tune to a different
    # algorithm on each run, so pin it off explicitly.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [None]:
# ===========================
# 2. Define Custom Dataset Class
# ===========================
class ImageDataset(torch.utils.data.Dataset):
    """
    Custom Dataset for loading images from 'real' and 'synthetic' directories.
    Assigns label 0 for 'real' and 1 for 'synthetic'.

    Both directories are searched recursively. Within each class the file
    paths are sorted, so sample order does not depend on the order in which
    the filesystem happens to list directory entries. All 'real' samples come
    before all 'synthetic' samples.

    An image that cannot be opened is reported on stdout and replaced by a
    blank 224x224 RGB placeholder carrying the sentinel label -1; consumers
    should skip samples whose label is negative.
    """

    # File extensions (compared case-insensitively) that count as images.
    SUPPORTED_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif', '.gif')

    def __init__(self, real_dir, synthetic_dir, transform=None):
        """Index every supported image under the two directories.

        Args:
            real_dir: Root directory of real images (label 0).
            synthetic_dir: Root directory of synthetic images (label 1).
            transform: Optional callable applied to each PIL image in
                ``__getitem__``.
        """
        self.image_paths = []
        self.labels = []
        self.transform = transform

        for directory, label in ((real_dir, 0), (synthetic_dir, 1)):
            paths = self._collect_image_paths(directory)
            self.image_paths.extend(paths)
            self.labels.extend([label] * len(paths))

        assert len(self.image_paths) == len(self.labels), "Mismatch between images and labels"

    @classmethod
    def _collect_image_paths(cls, directory):
        """Return the sorted paths of all supported images below ``directory``."""
        found = []
        for root, _, files in os.walk(directory):
            for file in files:
                if file.lower().endswith(cls.SUPPORTED_EXTENSIONS):
                    found.append(os.path.join(root, file))
        return sorted(found)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A ``(image, label)`` pair. ``image`` is the RGB image after
            ``transform`` (if one was given); ``label`` is 0, 1, or -1 when
            the file could not be read.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            # The context manager releases the file handle promptly; convert()
            # loads the pixel data and returns an image independent of the file.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            # Best effort: keep the batch shape intact and flag the sample.
            print(f"Error loading image {img_path}: {e}")
            image = Image.new('RGB', (224, 224))
            label = -1

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
# ===========================
# 3. Configuration and Hyperparameters
# ===========================
# Root of the dataset split; test images are read from <root>/<class>/test.
data_dir = "data/stanford-cars/split_random/"  # Update this path
real_test_dir, synthetic_test_dir = (
    os.path.join(data_dir, kind, 'test') for kind in ('real', 'synthetic')
)

# Fail early, naming the first directory that is missing.
for candidate in (real_test_dir, synthetic_test_dir):
    if os.path.isdir(candidate):
        continue
    raise ValueError(f"Directory does not exist: {candidate}")

batch_size = 32
num_classes = 2

# Device preference: Apple MPS first, then the first CUDA GPU, else CPU.
if torch.backends.mps.is_available():
    device_name = "mps"
elif torch.cuda.is_available():
    device_name = "cuda:0"
else:
    device_name = "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

In [None]:
# ===========================
# 4. Data Transforms
# ===========================
# Per-channel normalization statistics (RGB order), also reused later to
# undo the normalization when displaying images.
imagenet_mean, imagenet_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Test-time preprocessing: fixed-size resize, tensor conversion, normalization.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
test_transform = transforms.Compose(preprocessing_steps)

In [None]:
# ===========================
# 5. Load Test Dataset and DataLoader
# ===========================
# Position in this list is the integer label: 0 -> 'real', 1 -> 'synthetic'.
class_names = ['real', 'synthetic']

test_dataset = ImageDataset(
    real_dir=real_test_dir,
    synthetic_dir=synthetic_test_dir,
    transform=test_transform,
)

# Evaluation only: keep the dataset order fixed (no shuffling).
loader_options = dict(batch_size=batch_size, shuffle=False, num_workers=4)
dataloader_test = DataLoader(test_dataset, **loader_options)
print(f"Test dataset size: {len(test_dataset)}")

In [None]:
# ===========================
# 6. Initialize the Model and Load Saved Weights
# ===========================
# Check for the checkpoint first so a missing file fails before any model
# construction work is done.
model_path = 'models/resnet18_finetuned.pth'
if not os.path.isfile(model_path):
    raise FileNotFoundError(f"Saved model not found at {model_path}")

# Build the bare architecture without pretrained weights: the strict
# load_state_dict() below replaces every parameter and buffer, so fetching the
# ImageNet weights would only add a pointless download (and a network
# dependency) to the evaluation.
model = models.resnet18(weights=None)
num_ftrs = model.fc.in_features
# Replace the classification head so its shape matches the saved checkpoint.
model.fc = nn.Linear(num_ftrs, num_classes)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source; on PyTorch versions that
# support it, consider passing weights_only=True.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()  # inference mode: fixed batch-norm statistics, no dropout
print("Model loaded successfully.")

In [None]:
# ===========================
# 8. Evaluation on Test Set
# ===========================
criterion = nn.CrossEntropyLoss()

running_loss = 0.0
running_corrects = 0
num_evaluated = 0   # samples that actually contributed to the metrics
num_skipped = 0     # samples the dataset flagged as unreadable
all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in tqdm(dataloader_test, desc="Testing Phase"):
        # ImageDataset marks unreadable images with label -1. Such a target is
        # not a valid class index for CrossEntropyLoss and would corrupt the
        # metrics, so drop those samples before the forward pass.
        valid = labels >= 0
        num_skipped += int((~valid).sum())
        if not valid.any():
            continue

        inputs = inputs[valid].to(device)
        labels = labels[valid].to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)

        # The loss is a batch mean; weight it by the batch size so the final
        # average is per sample even when the last batch is smaller.
        num_in_batch = inputs.size(0)
        running_loss += loss.item() * num_in_batch
        running_corrects += (preds == labels).sum().item()
        num_evaluated += num_in_batch

        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

if num_skipped:
    print(f"Skipped {num_skipped} unreadable image(s).")
if num_evaluated == 0:
    raise RuntimeError("No valid test samples were evaluated; check the test directories.")

test_loss = running_loss / num_evaluated
test_acc = running_corrects / num_evaluated
print(f'Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.4f}')

In [None]:
# ===========================
# 9. Confusion Matrix and Classification Report
# ===========================
# Pin the label set explicitly. Otherwise scikit-learn infers it from the data,
# and a class that never occurs would shrink the matrix (mismatching the tick
# labels) and make classification_report reject the two target_names.
label_ids = list(range(len(class_names)))

cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names,
            ax=ax)
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')
ax.set_title('Confusion Matrix on Test Set')
fig.tight_layout()
plt.show()

report = classification_report(all_labels, all_preds,
                               labels=label_ids,
                               target_names=class_names)
print("Classification Report:")
print(report)

In [None]:
# ===========================
# 10. Visualize Sample Predictions
# ===========================
def imshow(inp, title=None, ax=None):
    """Display a normalized image tensor after undoing the normalization.

    Args:
        inp: Image tensor in channel-first layout, normalized with
            ``imagenet_mean`` / ``imagenet_std``. May live on any device.
        title: Optional title for the image.
        ax: Optional matplotlib axes to draw on. When omitted, the image is
            drawn on the current pyplot axes and the figure is refreshed with
            a short ``plt.pause``, as before.
    """
    # detach().cpu() makes this safe for GPU/MPS tensors as well as CPU ones.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    # Invert Normalize (x * std + mean) and clamp to the displayable range.
    img = np.clip(img * np.array(imagenet_std) + np.array(imagenet_mean), 0, 1)
    if ax is None:
        plt.imshow(img)
        if title:
            plt.title(title)
        plt.pause(0.001)
    else:
        ax.imshow(img)
        if title:
            ax.set_title(title)


def _label_name(label_id):
    """Map an integer label to its class name.

    Ids outside ``class_names`` (such as the -1 used for unreadable images)
    map to 'unknown' rather than wrapping around via negative indexing.
    """
    if 0 <= label_id < len(class_names):
        return class_names[label_id]
    return 'unknown'


# Take the first batch of the (unshuffled) test loader as the sample to show.
test_iter = iter(dataloader_test)
inputs, classes = next(test_iter)
inputs = inputs.to(device)
classes = classes.to(device)

with torch.no_grad():
    outputs = model(inputs)
    _, preds = torch.max(outputs, 1)

# Plain Python ints are simpler and safer to index with than 0-d tensors.
pred_ids = preds.cpu().tolist()
true_ids = classes.cpu().tolist()

num_shown = min(6, inputs.size(0))
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
for i, ax in enumerate(axes.flat):
    ax.axis('off')  # also blanks grid cells left empty by a short batch
    if i < num_shown:
        imshow(
            inputs[i],
            title=f'Predicted: {_label_name(pred_ids[i])}\nTrue: {_label_name(true_ids[i])}',
            ax=ax,
        )
fig.tight_layout()
plt.show()