# AML - 2025 : Feather in Focus - The Baseline

In [1]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
from google.colab import drive

In [2]:
# 1. SETUP: Mount Google Drive
# ---------------------------------------------------------
# Expose Google Drive inside the Colab filesystem at /content/drive.
drive.mount("/content/drive")

# Project folder on the mounted drive, and the dataset folder inside it.
BASE_PATH = "/content/drive/MyDrive/AML2025"
DATA_PATH = os.path.join(BASE_PATH, "Dataset")

Mounted at /content/drive


In [None]:
# 2. DEFINE THE DATASET CLASS
# ---------------------------------------------------------
class BirdDataset(Dataset):
    """Image dataset backed by a CSV of image paths and class labels.

    The CSV is loaded with pandas. Its first column is used as the image
    path, interpreted relative to ``root_dir`` (leading slashes are
    ignored). Its ``label`` column is shifted down by one on load, because
    the labels in the CSV are 1-based while PyTorch expects 0-based class
    indices.

    Args:
        csv_file: Path to the CSV file listing images and labels.
        root_dir: Directory the image paths in the CSV are relative to.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

        # The CSV labels are 1-200, but PyTorch needs 0-199,
        # so every label is shifted down by one.
        self.data['label'] = self.data['label'] - 1

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is loaded as RGB and passed through ``transform`` when one
        was given. If the image file does not exist, a message is printed
        and a black 224x224 placeholder is used instead, so a single missing
        file does not abort iteration. The label is returned as a 0-d
        ``torch.long`` tensor.
        """
        # Path from the CSV (e.g. "/train_images/1.jpg"). Leading slashes are
        # stripped because os.path.join discards root_dir when the second
        # component is absolute.
        relative_path = self.data.iloc[idx, 0].lstrip("/")

        # Full path: <root_dir>/train_images/1.jpg
        full_path = os.path.join(self.root_dir, relative_path)

        try:
            # The context manager releases the file handle as soon as the
            # pixel data has been copied into the converted image.
            with Image.open(full_path) as source:
                image = source.convert("RGB")
        except FileNotFoundError:
            print(f"MISSING IMAGE: {full_path}")
            # Black placeholder so a missing file does not crash the loader.
            image = Image.new('RGB', (224, 224))

        # Read the label by name so it is always the column that __init__
        # shifted, regardless of the column order in the CSV.
        label = int(self.data['label'].iloc[idx])

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

# 3. CREATE DATA LOADERS
# ---------------------------------------------------------
# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalize each channel (these mean/std values are the commonly
# used ImageNet channel statistics).
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Initialize Dataset
# 'train_images.csv' lives directly inside DATA_PATH, which is also the root
# that the image paths in the CSV are resolved against.
dataset = BirdDataset(
    csv_file=os.path.join(DATA_PATH, "train_images.csv"),
    root_dir=DATA_PATH,
    transform=data_transforms,
)

# Split: 80% Train, 20% Validation
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so the same images land in train/validation on every run;
# otherwise validation metrics are not comparable between runs.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create the Loaders (The final delivery)
# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

print(" Data Loaded Successfully!")
print(f"Training Images: {len(train_dataset)}")
print(f"Validation Images: {len(val_dataset)}")