In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, models, transforms

from google.colab import drive  # Colab-only: used to mount Google Drive

# --- Data location (Colab) ---
# Mount Google Drive so the dataset folder below is reachable from the runtime.
drive.mount('/content/drive')
base_dir = "/content/drive/MyDrive/Gen_Ai_Lab/Cat Species"

# Standard torchvision ImageNet normalisation statistics (per RGB channel),
# shared by both pipelines so they can never drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (224, 224)

# Evaluation pipeline: deterministic resize, tensor conversion, normalisation.
test_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Training pipeline: the same resize/normalise steps with random flips,
# rotations, shifts and colour jitter in between (augmentation against
# overfitting).
train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# --- Dataset and reproducible 80/20 split ---
SEED = 42
TRAIN_FRACTION = 0.8
BATCH_SIZE = 16
# Seeds the global torch RNG, which later cells use implicitly
# (new layer initialisation, DataLoader shuffling).
torch.manual_seed(SEED)

# Untransformed view of every image: provides the class names and the
# index space that the split is drawn from.
full_dataset = datasets.ImageFolder(root=base_dir)
num_classes = len(full_dataset.classes)
print(f"Detected {num_classes} classes.")

train_size = int(TRAIN_FRACTION * len(full_dataset))
test_size = len(full_dataset) - train_size
train_split, test_split = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SEED),
)

# Both Subsets returned by random_split point at the SAME parent dataset, so
# a transform set through one of them silently applies to the other as well.
# Give each split its own ImageFolder (with its own transform) and reuse the
# split indices on it.
train_view = datasets.ImageFolder(root=base_dir, transform=train_transform)
test_view = datasets.ImageFolder(root=base_dir, transform=test_transform)
# The indices are only meaningful if all three scans saw the same files.
assert train_view.samples == full_dataset.samples == test_view.samples, \
    "dataset folder changed while building the splits"
train_data = Subset(train_view, train_split.indices)
test_data = Subset(test_view, test_split.indices)

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Load an ImageNet-pretrained ResNet-152 and adapt its classifier head.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# `pretrained=True` is deprecated since torchvision 0.13; IMAGENET1K_V1 is the
# exact checkpoint that flag loaded. (ResNet152_Weights.DEFAULT would select
# the newer IMAGENET1K_V2 checkpoint instead: a deliberate change, not a
# drop-in replacement.)
model = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)

# Replace the final fully-connected layer with a fresh one sized to the
# number of classes detected in the dataset.
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

# All parameters are optimised (full fine-tuning), hence the small learning
# rate; weight_decay adds mild L2 regularisation.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

In [None]:
def _train_one_epoch(net, loader, loss_fn, opt, dev):
    """Run one optimisation pass over `loader`.

    Returns:
        (mean per-sample training loss, training accuracy in %). The loss is
        averaged by weighting each batch's loss by its size, which assumes
        `loss_fn` uses reduction='mean' (the nn.CrossEntropyLoss default).
        Accuracy is accumulated on the fly in train mode while the weights
        are still changing, so it is only a rough signal.
    """
    net.train()
    loss_sum, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(dev), labels.to(dev)
        opt.zero_grad()
        outputs = net(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        opt.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        seen += batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
    return loss_sum / seen, 100 * correct / seen


def _evaluate_accuracy(net, loader, dev):
    """Return top-1 accuracy (%) of `net` over `loader`, in eval mode and without gradients."""
    net.eval()
    correct, seen = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(dev), labels.to(dev)
            correct += (net(images).argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    return 100 * correct / seen


def train_and_evaluate(epochs=15, *, net=None, train_dl=None, test_dl=None,
                       loss_fn=None, opt=None, dev=None):
    """Train for `epochs` epochs, measuring held-out accuracy after each one.

    By default this uses the notebook globals `model`, `train_loader`,
    `test_loader`, `criterion`, `optimizer` and `device`. Each can be
    overridden with the matching keyword argument (e.g. to sanity-check the
    loop on a toy model).

    Args:
        epochs: number of full passes over the training loader.

    Returns:
        One dict per epoch with keys "epoch" (1-based), "train_loss",
        "train_acc" and "test_acc" (accuracies in %).

    Raises:
        ZeroDivisionError: if either loader yields no samples.
    """
    # Conditional expressions only read a global when no override is given.
    net = model if net is None else net
    train_dl = train_loader if train_dl is None else train_dl
    test_dl = test_loader if test_dl is None else test_dl
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt
    dev = device if dev is None else dev

    history = []
    for epoch in range(epochs):
        train_loss, train_acc = _train_one_epoch(net, train_dl, loss_fn, opt, dev)
        test_acc = _evaluate_accuracy(net, test_dl, dev)
        history.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "test_acc": test_acc,
        })
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}%")
    return history

# Bind the result to a name so Jupyter does not echo it as the cell output and later cells can inspect it.
training_history = train_and_evaluate(epochs=15)

Epoch 1: Train Acc: 19.05% | Test Acc: 26.19%
Epoch 2: Train Acc: 80.95% | Test Acc: 32.14%
Epoch 3: Train Acc: 94.64% | Test Acc: 55.95%


In [None]:
# Per-class accuracy of the trained model on the held-out split.
model.eval()
class_correct = [0.] * num_classes
class_total = [0.] * num_classes

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        # Convert each batch to Python lists once. Do not `.squeeze()` the
        # comparison: for a batch of one image it becomes a 0-d tensor, and
        # per-sample indexing then raises IndexError.
        hits = (predicted == labels).tolist()
        for label, hit in zip(labels.tolist(), hits):
            class_correct[label] += hit
            class_total[label] += 1

print("\n--- Final Class-Wise Test Accuracy ---")
for class_name, correct, total in zip(full_dataset.classes, class_correct, class_total):
    # Classes absent from the test split are reported as 0% rather than dividing by zero.
    acc = 100 * correct / total if total > 0 else 0
    print(f"{class_name:<20}: {acc:.2f}%")