In [None]:
from torchvision import transforms
from torch.utils import data
import numpy as np

In [None]:
# --- Dataset location and train/validation split -----------------------------
# Root directory handed to the CUB dataset loader.
DATASET_ROOT = './CUB/DATASET/CUB_200_2011'
# Split ratio handed to CUB for the 'train' and 'valid' splits
# (presumably the training fraction -- confirm in DatasetLoader.cub).
DATASET_SPLIT_RATIO = 0.9
# Seed handed to CUB together with the split ratio, so that 'train' and
# 'valid' are built from the same settings.
DATASET_VALIDATION_RANDOM_SEED = 123

# --- DataLoader settings -----------------------------------------------------
# Training batch size; the validation and test loaders use twice this value.
BATCH_SIZE = 6
# Passed as `num_workers` to the DataLoaders.
DATASET_WORK_NUMBER = 8

## Dataset Loading

In [None]:
from DatasetLoader.cub import CUB


# Per-channel mean/std applied by Normalize in both pipelines
# (the values match the widely used ImageNet statistics).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline.
# Augmentation is ours to choose; candidate transforms are kept here, disabled:
#     transforms.RandomHorizontalFlip(p=0.5),
#     transforms.RandomRotation(30),
# NOTE: a resize should not be needed here (?) -- still an open question:
#     transforms.RandomResizedCrop(224, scale=(0.7, 1), ratio=(3/4, 4/3)),
trans_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Evaluation pipeline, used for both the validation and the test split.
# NOTE: a resize should not be needed here (?) -- still an open question:
#     transforms.Resize(224),
#     transforms.CenterCrop(224),
trans_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Build the three splits. 'train' and 'valid' receive the same split ratio and
# seed; 'test' is given 0 for both arguments.
train_data = CUB(
    DATASET_ROOT,
    'train',
    DATASET_SPLIT_RATIO,
    DATASET_VALIDATION_RANDOM_SEED,
    transform=trans_train,
)
valid_data = CUB(
    DATASET_ROOT,
    'valid',
    DATASET_SPLIT_RATIO,
    DATASET_VALIDATION_RANDOM_SEED,
    transform=trans_test,
)
test_data = CUB(DATASET_ROOT, 'test', 0, 0, transform=trans_test)

# Report how many samples ended up in each split.
print("Train: {}".format(len(train_data)))
print("Valid: {}".format(len(valid_data)))
print("Test: {}".format(len(test_data)))

In [None]:
from matplotlib import pyplot as plt

def imshow(image, label, ax=None, normalize=True):
    """Draw a single image, titled with its label, on a matplotlib axes.

    Args:
        image: image tensor laid out as (channels, height, width). It is
            converted with ``.numpy()`` and transposed to
            (height, width, channels) for display. Tensors that track
            gradients or live on an accelerator are detached and copied to
            host memory first.
        label: value used as the axes title. A 0-d tensor/array label is
            unwrapped to a plain Python scalar so the title reads ``3``
            rather than ``tensor(3)``.
        ax (optional): axes to draw on; a new figure and axes are created
            when omitted.
        normalize (bool, optional): when True, undo the mean/std channel
            normalization (this requires 3 channels) and clip the result
            to [0, 1].

    Returns:
        The axes the image was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    # A bare .numpy() raises for tensors that require grad or are not on the
    # CPU; objects without .detach() take the original path unchanged.
    if hasattr(image, 'detach'):
        image = image.detach().cpu()
    image = image.numpy().transpose((1, 2, 0))

    if normalize:
        # Invert Normalize(mean, std): x * std + mean, broadcast over the
        # trailing channel axis.
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = np.clip(std * image + mean, 0, 1)

    # Labels taken from a batched label tensor arrive as 0-d tensors.
    if getattr(label, 'ndim', None) == 0 and hasattr(label, 'item'):
        label = label.item()

    ax.imshow(image)
    # Strip the frame, tick marks and tick labels so only the picture and its
    # title remain.
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.tick_params(axis='both', length=0)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(label)

    return ax


def show_samples(images, labels, nrows=2, ncols=3, title=None, normalize=True):
    """Show several samples on an ``nrows`` x ``ncols`` grid of axes.

    Samples are paired with grid cells in order; pairing stops at the shortest
    of (grid cells, images, labels).

    Args:
        images: iterable of images accepted by ``imshow``.
        labels: iterable of labels, one per image.
        nrows (int, optional): number of grid rows.
        ncols (int, optional): number of grid columns.
        title (str, optional): figure-level title; skipped when None.
        normalize (bool, optional): whether the images are normalized
            (forwarded to ``imshow``).
    """
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, where
    # plt.subplots would otherwise return a bare Axes that has no `.flat`.
    fig, axes = plt.subplots(nrows, ncols, facecolor='#ffffff', dpi=100,
                             squeeze=False)

    # .flat: map the flat sequence of samples onto the 2-D grid of axes
    for ax, image, label in zip(axes.flat, images, labels):
        imshow(image, label, ax, normalize)

    if title is not None:
        fig.suptitle(title)

    # tight_layout is a method: assigning True to it (as the code used to do)
    # only shadowed the method and never adjusted the layout. Call it first,
    # then reserve headroom for the suptitle and set the row spacing.
    fig.tight_layout()
    fig.subplots_adjust(top=0.85, hspace=0.3)
    plt.show()
    

In [None]:
# Preview one shuffled training batch on a 2x3 grid.
# `num_workers` is left at its default here; `train_loader` is re-created with
# worker processes in the next cell.
train_loader = data.DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=False,
)
images, labels = next(iter(train_loader))
show_samples(images[:6], labels[:6], nrows=2, ncols=3, title='SHOW SOME SAMPLES')

In [None]:
# Final data loaders. Only the training loader shuffles; the validation and
# test loaders iterate in order with twice the training batch size.
train_loader = data.DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=False,
    num_workers=DATASET_WORK_NUMBER,
)
valid_loader = data.DataLoader(
    valid_data,
    batch_size=2 * BATCH_SIZE,
    shuffle=False,
    num_workers=DATASET_WORK_NUMBER,
)
test_loader = data.DataLoader(
    test_data,
    batch_size=2 * BATCH_SIZE,
    shuffle=False,
    num_workers=DATASET_WORK_NUMBER,
)