In [1]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import wandb
from torch.utils.data import Dataset
from torchvision import datasets, transforms

In [2]:
#wandb.login()
import numpy as np
import os
import Cropper

# Resolve the notebook's own directory. Jupyter does not define __file__, in which
# case the current working directory is used instead.
if "__file__" in locals():
    notebook_dir = os.path.abspath(os.path.dirname(__file__))
else:
    notebook_dir = os.path.abspath(".")

# Step one level up from the notebook directory (the equivalent of `cd ..`).
parent_dir = os.path.abspath(os.path.join(notebook_dir, os.pardir))

# Image directory inside the fgvc-aircraft-2013b data tree, and the directory
# that will receive the resized copies.
dataset_dir = os.path.join(parent_dir, 'fgvc-aircraft-2013b/data/images')
resized_dir = os.path.join(parent_dir, 'fgvc-aircraft-2013b/data/Images_resized')

# The pair returned by the project-local Cropper helper is stored as the
# average height and width; its exact semantics live in Cropper.images_sizes.
avg_height, avg_width = Cropper.images_sizes(dataset_dir)

In [5]:
# Ask the project-local Cropper helper to resize-with-padding the images from
# dataset_dir into resized_dir, using the average size truncated to whole pixels.
Cropper.resize_with_padding(
    dataset_dir,
    resized_dir,
    int(avg_height),
    int(avg_width),
)

In [None]:
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    """Map-style dataset that pairs image files with labels from an annotations table.

    Row ``idx`` of the table supplies the image file name (column 0, relative to
    ``img_dir``) and its label (column 1).
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        """
        Args:
            annotations_file: A pandas DataFrame used as-is, or a path (``str`` /
                ``os.PathLike``) to a CSV file loaded with ``pandas.read_csv``
                defaults (so the CSV's first row is treated as the header).
            img_dir: Directory that the file names in column 0 are relative to.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each label.
        """
        if isinstance(annotations_file, (str, os.PathLike)):
            # Storing a path verbatim would make len() count its characters and
            # .iloc fail, so load the CSV into a DataFrame instead.
            import pandas as pd

            annotations_file = pd.read_csv(annotations_file)
        self.img_labels = annotations_file
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx`` with the optional transforms applied.

        The image is decoded from ``img_dir/<column 0>`` with torchvision's ``read_image``.
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

In [None]:
def _compute_mean_std(dataset):
    """Return the scalar mean and unbiased standard deviation of all pixel values in ``dataset``.

    ``dataset`` must yield ``(image_tensor, label)`` pairs. Sums are accumulated
    image by image in float64, so images may differ in spatial size and the whole
    dataset never has to be held in memory at once.

    Raises:
        ValueError: if the dataset yields fewer than two pixel values in total.
    """
    total = torch.zeros((), dtype=torch.float64)
    total_sq = torch.zeros((), dtype=torch.float64)
    count = 0
    for image, _ in dataset:
        values = image.to(torch.float64)
        total += values.sum()
        total_sq += values.square().sum()
        count += values.numel()
    if count < 2:
        raise ValueError("need at least two pixel values to estimate mean and std")
    mean = total / count
    # Unbiased (N - 1) variance, matching torch.std's default correction.
    var = (total_sq - count * mean * mean) / (count - 1)
    return mean.item(), var.clamp_min(0.0).sqrt().item()


def get_data(batch_size, test_batch_size=256, num_workers=4, mean=None, std=None,
             resized_dir=None, root="./data", seed=None):
    """Build train / validation / test DataLoaders for the FGVC-Aircraft dataset.

    The ``trainval`` split is randomly divided into a training subset of
    ``int(0.7 * n + 1)`` samples (capped at ``n``) and a validation subset with the
    rest. The ``test`` split is used as-is. All three share the same
    ``ToTensor`` + ``Normalize`` transform.

    Args:
        batch_size: Batch size of the (shuffled) training loader.
        test_batch_size: Batch size of the (unshuffled) validation and test loaders.
        num_workers: Worker processes for each DataLoader.
        mean, std: Scalar normalisation statistics. If either is None, both are
            computed over every pixel of the un-normalised ``trainval`` split.
        resized_dir: Currently unused. NOTE(review): presumably meant to point the
            loaders at the padded/resized images — TODO wire up or remove.
        root: Directory torchvision downloads FGVC-Aircraft to / reads it from.
        seed: Optional seed for the train/validation split. With None, the split
            uses torch's global RNG.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    if mean is None or std is None:
        # Statistics must be measured on un-normalised tensors.
        raw_train = datasets.FGVCAircraft(root, split="trainval", download=True,
                                          transform=transforms.ToTensor())
        mean, std = _compute_mean_std(raw_train)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[mean], std=[std]),
    ])

    # Train/validation and test data share the same normalisation.
    # NOTE(review): no resize/crop is applied here; if the images differ in size the
    # default DataLoader collate cannot batch them — TODO confirm.
    full_train_data = datasets.FGVCAircraft(root, split="trainval", download=True, transform=transform)
    test_data = datasets.FGVCAircraft(root, split="test", download=True, transform=transform)

    num_samples = len(full_train_data)
    # min() keeps the lengths valid for an empty dataset (int(0 * 0.7 + 1) == 1).
    train_samples = min(int(num_samples * 0.7 + 1), num_samples)
    validation_samples = num_samples - train_samples
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_data, validation_data = torch.utils.data.random_split(
        full_train_data, [train_samples, validation_samples], **split_kwargs)

    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(
        validation_data, batch_size=test_batch_size, shuffle=False, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=test_batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader

# mean/std are not passed, so get_data derives them from the training split;
# the remaining loader settings use get_data's defaults.
train_loader, val_loader, test_loader = get_data(64, resized_dir=resized_dir)

In [None]:
class CNN(nn.Module):
    """Convolutional network that registers three 5x5 conv layers (3 -> 6 -> 16 -> 32
    channels) interleaved with two 2x2, stride-2 max-pool layers.

    NOTE(review): ``pool3`` holds a Conv2d despite its name; it is kept as-is because
    the attribute name appears in the module's state_dict keys. No ``forward`` is
    visible here, and an nn.Module without one raises NotImplementedError when called.
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in the order conv1, pool1, conv2, pool2, pool3.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)

