In [19]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50, ResNet50_Weights
import split_data

def train_image_classifier(
    data_dir='/Users/nataliamarko/Documents/GitHub/mnist-classifier-oop/task_2/data/',
    output_dir="./models/image_classifier",
    num_classes=10,
    batch_size=64,
    epochs=4,
    learning_rate=1e-3,
    val_ratio=0.3
):
    """Fine-tune an ImageNet-pretrained ResNet50 as an image classifier.

    Images are resized to 128x128. Training images additionally receive a
    random horizontal flip and AutoAugment (IMAGENET policy); both splits are
    normalized with the mean/std values listed in the transforms below.

    If the ``train`` or ``val`` folder is missing, the dataset in ``data_dir``
    is first split with ``split_data.split_dataset``. After every epoch the
    model is evaluated on the validation split, and its weights are written to
    ``<output_dir>/best_model.pth`` whenever the validation accuracy strictly
    improves on the best value seen so far.

    Args:
        data_dir: Source directory handed to ``split_data.split_dataset``.
        output_dir: Directory that receives ``best_model.pth``; created on
            demand.
        num_classes: Number of outputs of the replaced final layer.
        batch_size: Batch size for both data loaders.
        epochs: Number of training epochs.
        learning_rate: Learning rate for the Adam optimizer.
        val_ratio: Validation ratio forwarded to ``split_data.split_dataset``.

    Returns:
        The best validation accuracy observed (0.0 if no epoch improved on
        the initial value).

    Raises:
        ValueError: If the training folder contains more classes than
            ``num_classes``.
    """
    # NOTE(review): the train/val location is hard-coded and does not follow
    # ``data_dir``; a caller passing another ``data_dir`` still trains on the
    # folders below if they already exist. Confirm whether that is intended.
    base_directory = '/Users/nataliamarko/Documents/GitHub/mnist-classifier-oop/task_2/'
    train_dir = os.path.join(base_directory, 'train')
    val_dir = os.path.join(base_directory, 'val')

    if not os.path.exists(train_dir) or not os.path.exists(val_dir):
        print("Splitting dataset into train/val sets...")
        split_data.split_dataset(source_dir=data_dir, base_dir=base_directory, val_ratio=val_ratio)

    # Both pipelines share the resize and normalization steps; only the
    # training pipeline applies random augmentation.
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    train_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(policy=torchvision.transforms.AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])
    val_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    train_dataset = ImageFolder(train_dir, transform=train_transform)
    val_dataset = ImageFolder(val_dir, transform=val_transform)

    # Fail early with a clear message: labels >= num_classes would otherwise
    # only surface as an out-of-range target error inside the loss.
    found_classes = len(train_dataset.classes)
    if found_classes > num_classes:
        raise ValueError(
            f"Found {found_classes} classes in {train_dir!r}, "
            f"but num_classes={num_classes}."
        )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Start from ImageNet weights and replace the final layer so that it
    # produces ``num_classes`` outputs.
    model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_val_acc = 0.0
    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        total_loss, total_correct, total_samples = 0, 0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The criterion returns the batch mean, so weight it by the batch
            # size to get a correct per-sample average over the epoch.
            total_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            total_correct += (preds == labels).sum().item()
            total_samples += images.size(0)

        avg_train_loss = total_loss / total_samples
        train_acc = total_correct / total_samples

        # ---- validation pass (no gradients) ----
        model.eval()
        val_correct, val_samples = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                preds = model(images).argmax(dim=1)
                val_correct += (preds == labels).sum().item()
                val_samples += images.size(0)
        val_acc = val_correct / val_samples

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

        # Keep only the checkpoint with the best validation accuracy so far.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            os.makedirs(output_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))

    print("Training complete. Best val accuracy:", best_val_acc)
    return best_val_acc

# Run a training session with the default arguments when this code is executed
# directly (as a script or notebook cell) rather than imported as a module.
if __name__ == "__main__":
    train_image_classifier()


Epoch [1/5], Loss: 1.0393, Train Acc: 0.6580, Val Acc: 0.7514
Epoch [2/5], Loss: 0.7130, Train Acc: 0.7647, Val Acc: 0.8276
Epoch [3/5], Loss: 0.5973, Train Acc: 0.8031, Val Acc: 0.8251
Epoch [4/5], Loss: 0.5055, Train Acc: 0.8332, Val Acc: 0.8542
Epoch [5/5], Loss: 0.4603, Train Acc: 0.8489, Val Acc: 0.8446
Training complete. Best val accuracy: 0.8541945346837378
