In [None]:
import os
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image

from sklearn.model_selection import train_test_split

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy
from torchvision import models, transforms

In [None]:
def get_abs_path(n_parent: int = 0) -> Path:
    """Return the absolute path of an ancestor of the current working directory.

    Args:
        n_parent: number of levels to climb; 0 (the default) yields the
            working directory itself, 1 its parent, and so on.

    Returns:
        The resolved, absolute ``Path``.
    """
    # One '..' component per level; with no components Path() is '.', i.e. the cwd.
    upward_steps = ['..'] * n_parent
    return Path(*upward_steps).resolve()

In [None]:
# The dataset is expected in a `data` directory one level above the current
# working directory, with one sub-directory per class.
path = get_abs_path(1)
data_path = path / 'data'

# glob() and iterdir() yield entries in arbitrary, filesystem-dependent order.
# Sorting makes both the image list and the class-name -> index mapping stable
# across runs and machines, so a label index always means the same class.
images_paths = sorted(data_path.glob('**/*.png'))

class_names = sorted(d.name for d in data_path.iterdir() if d.is_dir())
class_labels = {name: index for index, name in enumerate(class_names)}
print('Labels:', class_labels)

In [None]:
# Fixed seed: for a given ordering of `images_paths` the train/val/test
# membership is identical on every run, so results are comparable across
# kernel restarts.
SPLIT_SEED = 42

# Hold out 10% of the images, then halve the hold-out into validation and test
# (roughly a 90/5/5 split).
train_paths, holdout_paths = train_test_split(
    images_paths, test_size=0.1, shuffle=True, random_state=SPLIT_SEED
)
val_paths, test_paths = train_test_split(
    holdout_paths, test_size=0.5, shuffle=True, random_state=SPLIT_SEED
)
print('train len: %d val len: %d test len: %d' % (len(train_paths), len(val_paths), len(test_paths)))

In [None]:
def _resize_and_normalize():
    """Build the steps every pipeline ends with.

    Returns a fresh list on each call (resize to 224x224, convert to tensor,
    channel-wise normalization) so that no transform object is shared between
    pipelines.
    """
    # NOTE(review): these values match the ImageNet mean/std commonly used with
    # torchvision's pretrained models -- presumably chosen for that reason.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return [
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]


# One preprocessing pipeline per split. Only 'train' is randomized: the image
# is enlarged to 400x400 and a random 300x300 crop is taken before the common
# resize/normalize steps. 'validate' and 'test' are deterministic.
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((400, 400)),
        transforms.RandomCrop((300, 300)),
        *_resize_and_normalize(),
    ]),
    'validate': transforms.Compose(_resize_and_normalize()),
    'test': transforms.Compose(_resize_and_normalize()),
}

In [None]:
def get_label_from_filename(class_labels, filename):
    """Return the label of the class that an image path belongs to.

    The path is matched against the class names in two passes:

    1. Exact match on a path component, deepest component first. This stops a
       class name that is merely a prefix of another (``Niger`` / ``Nigeria``),
       or that equals some parent directory, from shadowing the real class
       directory.
    2. Fallback: substring match anywhere in the path, longest class name
       first, for layouts where the class name is embedded in the file name.

    Args:
        class_labels: mapping from class name (``str``) to label.
        filename: path of the image, as ``str`` or ``os.PathLike``.

    Returns:
        The label of the matching class, or ``None`` if no class name occurs
        in the path.
    """
    path_text = str(filename)

    # Pass 1: a directory (or file) named exactly like a class.
    for component in reversed(Path(path_text).parts):
        if component in class_labels:
            return class_labels[component]

    # Pass 2: substring match; longest names first so that the most specific
    # class wins (sorted() is stable, so ties keep the mapping's own order).
    for country_name in sorted(class_labels, key=len, reverse=True):
        if country_name in path_text:
            return class_labels[country_name]

    return None


class CountryDataset(Dataset):
    """Dataset of images whose class is encoded in their file path.

    Each item is a tuple ``(image, label, filename)``: the transformed image,
    the label looked up from the path, and the path itself as a string.
    """

    def __init__(self, images_paths, class_labels, transform):
        """Store the paths and resolve every label up front.

        Args:
            images_paths: sequence of image paths (``str`` or ``Path``).
            class_labels: mapping from class name to label, used to derive
                each image's label from its path.
            transform: callable applied to the loaded PIL image.

        Raises:
            ValueError: if a path contains none of the class names. Failing
                here gives a clear message instead of an obscure error later,
                when a ``None`` label reaches batching or indexing code.
        """
        self.images_paths = images_paths
        self.transform = transform

        self.labels = []
        for image_path in self.images_paths:
            image_path = str(image_path)
            label = get_label_from_filename(class_labels, image_path)
            if label is None:
                raise ValueError(
                    'No class name from class_labels found in image path: %s' % image_path
                )
            self.labels.append(label)

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images_paths)

    def __getitem__(self, idx):
        """Load, convert and transform the image at position ``idx``.

        Returns:
            ``(x, label, filename)`` where ``x`` is the output of the
            transform, ``label`` the precomputed label and ``filename`` the
            image path as ``str``.
        """
        # Return the path as a plain string: DataLoader's default collate
        # function can batch str values but raises TypeError on pathlib.Path.
        filename = str(self.images_paths[idx])

        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to be read, and the context manager then releases the
        # file handle. Converting to RGB guarantees three channels for PNGs
        # stored as RGBA, palette or grayscale images.
        with Image.open(filename) as image:
            x = image.convert('RGB')

        x = self.transform(x)
        label = self.labels[idx]
        return x, label, filename

In [None]:
batch_size = 16

# One dataset per split, each with its own preprocessing pipeline.
train_dataset = CountryDataset(train_paths, class_labels, data_transforms['train'])
val_dataset = CountryDataset(val_paths, class_labels, data_transforms['validate'])
test_dataset = CountryDataset(test_paths, class_labels, data_transforms['test'])

# Each loader must wrap the dataset of its own split; otherwise validation and
# test metrics would be computed on (augmented) training images.
# Only the training loader shuffles: evaluation does not benefit from it, and a
# fixed order keeps per-sample outputs comparable between runs.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

In [None]:
# Sanity check: show the first training image, read straight from disk (i.e.
# without the training transform), titled with the class name of its label.
for i in range(1):
    x, label, filename = train_dataset[i]
    name = class_names[label]
    img = mpimg.imread(filename)

    # Reuse the current axes, wiping whatever was drawn on them before.
    ax = plt.gca()
    ax.clear()
    ax.set_title(name)
    plt.imshow(img)
    plt.show()

In [None]:
# Pick the compute device name: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using {0} device'.format(device))