In [8]:
import torch
from torchvision import models, datasets
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset

from PIL import Image
import pandas as pd
import os

In [53]:
class WaterbirdsFullData(Dataset):
    """Images of one metadata split, labelled by class folder, with a 'place' stratum.

    ``root`` must contain one sub-directory per class. An image file is kept
    when its ``<class_dir>/<file>`` relative path (built with ``os.path.join``)
    appears in the ``img_filename`` column of the metadata rows belonging to
    the requested split.

    Args:
        root: Directory holding the class sub-directories.
        metadata_csv: Path to a CSV with at least the columns ``split``,
            ``img_filename`` and ``place``.
        split: ``'train'`` (split code 0), ``'val'`` (code 1) or ``'test'``
            (code 2). Any other value still selects code 2, but a warning is
            emitted so that a typo does not silently return the test split.
        transform: Optional callable applied to each RGB PIL image.

    Attributes set by the constructor:
        classes: Sorted list of class directory names found under ``root``.
        class_to_idx / idx_to_class: Mappings between class name and its
            position in ``classes``.
        metadata: Metadata rows restricted to the requested split.
        samples: List of ``(image_path, class_idx, place)`` tuples.
    """

    def __init__(self, root, metadata_csv, split, transform=None):
        self.root = root
        self.transform = transform

        # Translate the split name into the integer code used in the CSV.
        if split == 'train':
            self.split = 0
        elif split == 'val':
            self.split = 1
        else:
            if split != 'test':
                # Local import: the notebook's import cell does not provide it.
                import warnings
                warnings.warn(
                    f"Unknown split {split!r}; falling back to 'test'. "
                    "Expected 'train', 'val' or 'test'.",
                    stacklevel=2,
                )
            self.split = 2

        all_metadata = pd.read_csv(metadata_csv)
        self.metadata = all_metadata[all_metadata['split'] == self.split]

        # Sort so that class indices do not depend on the order in which the
        # filesystem happens to list directories.
        self.classes = sorted(
            d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
        )
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}
        self.idx_to_class = {i: name for i, name in enumerate(self.classes)}

        # Build the filename -> place table once instead of scanning the whole
        # frame for every image. keep='first' mirrors the previous behaviour of
        # taking the first matching row when a filename is listed twice.
        first_rows = self.metadata.drop_duplicates(subset='img_filename', keep='first')
        place_by_filename = dict(
            zip(first_rows['img_filename'].values, first_rows['place'].values)
        )

        # Build the full samples list
        self.samples = []
        for class_name, class_idx in self.class_to_idx.items():
            class_folder = os.path.join(root, class_name)
            # Sorted for a reproducible sample order.
            for img_filename in sorted(os.listdir(class_folder)):
                # NOTE(review): os.path.join uses the OS separator, so on
                # Windows this key contains a backslash — confirm it matches
                # the separator used in the CSV's img_filename column.
                metadata_img_filename = os.path.join(class_name, img_filename)
                if metadata_img_filename in place_by_filename:
                    full_img_path = os.path.join(class_folder, img_filename)
                    strata = place_by_filename[metadata_img_filename]
                    self.samples.append((full_img_path, class_idx, strata))

    def __len__(self):
        """Return the number of images retained for this split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Returns:
            tuple: (image, class_label, strata)
        """
        img_path, label, strata = self.samples[idx]

        # Load the image; the context manager closes the underlying file as
        # soon as the RGB copy has been produced.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')

        # Apply transformations if provided
        if self.transform is not None:
            img = self.transform(img)

        return img, label, strata

In [55]:
# Dataset directory and metadata file shared by all three splits.
dataset_root = 'waterbird_complete95_forest2water2'
metadata_path = 'waterbird_complete95_forest2water2/metadata.csv'

# Build the train / val / test datasets in that order (no transform applied).
trainset, valset, testset = (
    WaterbirdsFullData(dataset_root, metadata_path, split_name)
    for split_name in ('train', 'val', 'test')
)