In [None]:
from torchvision.datasets import FGVCAircraft
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

class RemoveCopyrightBanner:
    """Crop the copyright banner off the bottom edge of an FGVC-Aircraft image.

    The banner is unrelated to the aircraft class, so it is cut away before resizing.

    Args:
        banner_height: number of pixel rows to drop from the bottom of the image
            (default 20, the value this notebook has always used).

    Raises:
        ValueError: if ``banner_height`` is negative.
    """

    def __init__(self, banner_height=20):
        if banner_height < 0:
            raise ValueError(f"banner_height must be non-negative, got {banner_height}")
        self.banner_height = banner_height

    def __call__(self, img):
        """Return ``img`` with its bottom ``banner_height`` rows removed.

        ``img`` is expected to be a PIL image; anything exposing ``.size`` as
        ``(width, height)`` and a ``.crop((left, upper, right, lower))`` method works.
        """
        width, height = img.size
        return img.crop((0, 0, width, height - self.banner_height))

    def __repr__(self):
        # Makes `print(transform)` show this step's configuration like the torchvision transforms.
        return f"{self.__class__.__name__}(banner_height={self.banner_height})"

# Input resolution and per-channel normalization statistics (the standard ImageNet values).
IMAGE_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Shared preprocessing for every split (no random augmentation): crop off the copyright
# banner, resize to a fixed square, convert to a float tensor, then normalize per channel.
transform = transforms.Compose([
    RemoveCopyrightBanner(),
    # A (h, w) tuple resizes to exactly that size, so the aspect ratio is not preserved.
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Settings shared by every FGVC-Aircraft split.
DATA_ROOT = './data'
# Options: 'variant', 'family', 'manufacturer'. The task grouping later in the notebook
# documents labels 0-99 (Task 9: 90-99), which matches the 'variant' level.
ANNOTATION_LEVEL = 'variant'


def load_fgvc_split(split):
    """Create the FGVC-Aircraft dataset for ``split`` with the shared settings above.

    Args:
        split: one of 'train', 'val', 'trainval', 'test'.

    Every split uses the same ``transform``; ``download=True`` lets torchvision fetch
    the archive into ``DATA_ROOT`` when it is not already there.
    """
    return FGVCAircraft(
        root=DATA_ROOT,
        split=split,
        annotation_level=ANNOTATION_LEVEL,
        transform=transform,
        download=True,
    )


# Create the FGVC Aircraft dataset instances
train_dataset = load_fgvc_split('train')
val_dataset = load_fgvc_split('val')
test_dataset = load_fgvc_split('test')

In [None]:
# function to show images
def show_images(train_dataset, num_images=5):
    fig, axes = plt.subplots(1, num_images, figsize=(15, 5))
    for i in range(num_images):
        image, label = train_dataset[i]
        image = image.permute(1, 2, 0)  # convert from CxHxW to HxWxC
        axes[i].imshow(image)
        axes[i].set_title(f'Label: {label}')
        axes[i].axis('off')
    plt.show()

# Preview the first five training samples as a quick sanity check of loading and preprocessing.
show_images(train_dataset, num_images=5)

# Create Task Splits and Data Loaders

In [None]:
from collections import defaultdict
import torch

def _iter_labels(dataset):
    """Return the integer label of every sample in ``dataset``, in index order."""
    raw_labels = getattr(dataset, "_labels", None)
    # torchvision's FGVCAircraft keeps its labels in the private `_labels` list. Reading it
    # avoids decoding and transforming every image just to get its label. It only matches
    # `dataset[i][1]` when no `target_transform` is set. This is a private attribute, so
    # re-check it if torchvision is upgraded. Other datasets fall back to indexing.
    if raw_labels is not None and getattr(dataset, "target_transform", None) is None:
        return list(raw_labels)
    return [dataset[i][1] for i in range(len(dataset))]


def group_task_indices(dataset, classes_per_task=10):
    """Split sample indices into disjoint tasks by consecutive label ranges.

    Task t holds every sample whose label lies in
    ``[t * classes_per_task, (t + 1) * classes_per_task)``. With the default of 10 that is
    Task 0: 0-9, Task 1: 10-19, ..., Task 9: 90-99. Each sample belongs to exactly one task.

    Args:
        dataset: indexable dataset with ``len()`` whose items are ``(image, label)`` pairs
            with integer labels.
        classes_per_task: number of consecutive class ids per task; must be positive.

    Returns:
        ``defaultdict(list)`` mapping task index -> ascending list of sample indices.
        For example, ``task_dict[0]`` contains the indices of images with labels 0-9.
        Looking up a task that has no samples yields an empty list.

    Raises:
        ValueError: if ``classes_per_task`` is not positive.
    """
    if classes_per_task <= 0:
        raise ValueError(f"classes_per_task must be positive, got {classes_per_task}")
    task_dict = defaultdict(list)
    for idx, label in enumerate(_iter_labels(dataset)):
        # Exactly one task per sample. It must not be added to every task 0..label // size,
        # which would make task 0 contain the whole dataset.
        task_dict[label // classes_per_task].append(idx)
    return task_dict
    
# Map task id -> sample indices for each split (Task 0: labels 0-9, Task 1: 10-19, ...).
train_task_idxs = group_task_indices(train_dataset)
val_task_idxs = group_task_indices(val_dataset)
test_task_idxs = group_task_indices(test_dataset)
# Backward-compatible alias for the old, inconsistently named variable; cells that still
# reference `test__idxs` keep working.
test__idxs = test_task_idxs

In [None]:
from torch.utils.data import Subset
# Which task the loaders below are built for, plus shared loader settings.
CURRENT_TASK = 0
BATCH_SIZE = 32
NUM_WORKERS = 4


def make_task_loader(dataset, task_idxs, task_id, shuffle,
                     batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):
    """Restrict ``dataset`` to the samples of one task and wrap it in a DataLoader.

    Args:
        dataset: the full split, e.g. ``train_dataset``.
        task_idxs: mapping task id -> list of sample indices, as produced by
            ``group_task_indices``.
        task_id: which task's samples to keep.
        shuffle: reshuffle every epoch (True for training, False for evaluation).
        batch_size, num_workers: forwarded to ``torch.utils.data.DataLoader``.

    Returns:
        ``(subset, loader)``, so the Subset stays accessible next to its loader.

    Raises:
        ValueError: if the task has no samples. Without this check, an empty subset would
            only fail later with an opaque sampler error.
    """
    indices = task_idxs[task_id]
    if len(indices) == 0:
        raise ValueError(f"task {task_id} has no samples in this split")
    subset = Subset(dataset, indices)
    loader = torch.utils.data.DataLoader(
        subset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )
    return subset, loader


# Initialize the train / val / test loaders for the current task.
train_subset, train_loader = make_task_loader(
    train_dataset, train_task_idxs, CURRENT_TASK, shuffle=True
)
val_subset, val_loader = make_task_loader(
    val_dataset, val_task_idxs, CURRENT_TASK, shuffle=False
)
test_subset, test_loader = make_task_loader(
    test_dataset, test__idxs, CURRENT_TASK, shuffle=False
)