# In [None]:
import os

import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.io import read_image

# Seed PyTorch's global RNG so that torch-driven randomness later in the
# notebook (e.g. random_split, which is imported above) is reproducible.
# `torch` must be imported at the top of the file; it previously was not,
# so this line raised NameError.
torch.manual_seed(42)

class BirdDataset(Dataset):
    """Map-style dataset of images listed in a CSV annotations file.

    Column 0 of the CSV holds the image path relative to ``img_dir`` and
    column 1 holds the label; any further columns are ignored.

    Args:
        annotations_file: Path or file-like object accepted by
            ``pandas.read_csv``.
        img_dir: Directory that the CSV image paths are resolved against.
        transform: Optional callable applied to the image tensor returned
            by ``read_image``.
        target_transform: Optional callable applied to the label.
    """

    def __init__(
        self, annotations_file, img_dir, transform=None, target_transform=None
    ):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows (samples) in the annotations file."""
        return len(self.img_labels)

    def _resolve_path(self, idx):
        """Return the path of sample ``idx`` joined under ``img_dir``."""
        # str() guards against pandas parsing the column as non-string.
        rel_path = str(self.img_labels.iloc[idx, 0])
        # The CSV may use Windows-style separators and/or a leading
        # separator. Normalise to "/" and strip *every* leading separator:
        # os.path.join discards img_dir when the second component is
        # absolute, so leaving even one (e.g. from "//x.jpg") would
        # silently escape img_dir.
        rel_path = rel_path.replace("\\", "/").lstrip("/")
        return os.path.join(self.img_dir, rel_path)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        ``image`` is the tensor produced by ``read_image`` (passed through
        ``transform`` if one was given); ``label`` is the raw value from
        CSV column 1 (passed through ``target_transform`` if given).
        """
        # NOTE(review): read_image keeps the file's own channel count by
        # default; if the data mixes grayscale/RGBA/RGB files, batching may
        # fail — consider mode=ImageReadMode.RGB. Confirm against the data.
        image = read_image(self._resolve_path(idx))
        label = self.img_labels.iloc[idx, 1]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

# Default preprocessing for images decoded by read_image: the tensor is
# converted to a PIL image, resized to 224x224, and converted back with
# ToTensor (per torchvision docs, a float tensor scaled to [0, 1]).
standard_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)