In [19]:
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
from PIL import Image
from sklearn.metrics import confusion_matrix, f1_score
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Put the parent directory on the import path so `models.models_v1` resolves.
sys.path.append("..")
from models.models_v1 import SimpleCNN


In [20]:
# Root folder of the test images and the square side length they are resized to.
test_dir = "../data/test_set"
img_size = 128

# Preprocessing applied to every test image, in order: resize, convert to a
# tensor, then normalize with mean 0.5 / std 0.5.
# NOTE(review): this should mirror the preprocessing used at training time —
# confirm against the training notebook.
preprocessing_steps = [
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
transform = transforms.Compose(preprocessing_steps)

test_data = datasets.ImageFolder(root=test_dir, transform=transform)

# shuffle=False keeps the batch order fixed from run to run.
test_loader = DataLoader(dataset=test_data, batch_size=32, shuffle=False)

# Class labels as discovered by ImageFolder; displayed as the cell output.
class_names = test_data.classes
class_names


['cats', 'dogs']

In [21]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

model = SimpleCNN().to(device)

# NOTE(review): unlike the "../data" and "../results" paths used elsewhere in
# this notebook, this path is relative to the current working directory —
# confirm that is where the checkpoint is meant to live.
model_path = "models/final_model.pth"

# map_location remaps the stored tensors onto the selected device.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)

# Switch the module to evaluation mode before running inference.
model.eval()
print("Model loaded successfully!")


Using device: cpu
Model loaded successfully!


In [22]:
# Per-sample predictions and ground-truth labels, collected for the F1 score.
all_preds, all_labels = [], []

# Running tallies used for plain accuracy.
correct, total = 0, 0

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        # The index of the largest output along dim 1 is the predicted class.
        predicted = torch.max(outputs, dim=1).indices

        matches = predicted == labels
        total += labels.size(0)
        correct += matches.sum().item()

        # Move results to host memory before accumulating them.
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

accuracy = correct / total
f1 = f1_score(all_labels, all_preds, average="weighted")

print(f"Test Accuracy: {accuracy:.4f}")
print(f"Test F1 Score: {f1:.4f}")


Test Accuracy: 0.7756
Test F1 Score: 0.7754


In [None]:
# Persist the headline test metrics as JSON so this run can be compared with
# other runs. Values are cast to built-in float/int before serialization.
results = {
    "accuracy": float(accuracy),
    "f1_score": float(f1),
    "total_samples": int(total)
}

# Single source of truth for the output location: used for both the write and
# the confirmation message, so the two can never disagree.
results_path = "../results/test_v1_user2.json"

# Create the output directory if it is missing so open() does not fail.
os.makedirs(os.path.dirname(results_path), exist_ok=True)

with open(results_path, "w") as f:
    json.dump(results, f, indent=4)

# Report the file that was actually written (the previous message named
# "results/test_results.json", which is not the file this cell writes).
print(f"Saved → {results_path}")


Saved → results/test_results.json
