# In [None]:
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms as T

from GSVCitiesDataset import GSVCitiesDataset
from SanFranciscoDataset import SanFranciscoValidationDataset, SanFranciscoTestDataset
from TokyoDataset import TokyoTestDataset


from prettytable import PrettyTable

# Per-channel normalization statistics. Each mapping exposes a 'mean' and a
# 'std' entry with one value per image channel (three channels here).
# The names suggest ImageNet statistics and a ViT-style 0.5/0.5 normalization.
IMAGENET_MEAN_STD = dict(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

VIT_MEAN_STD = dict(
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
)

# City identifiers used as the default `cities` argument of the data module.
# Abbreviated entries: 'Osl' refers to Oslo, 'Trt' to Toronto,
# 'Prg' to Prague and 'Prs' to Paris.
TRAIN_CITIES = [
    'Bangkok', 'Buenosaires', 'Losangeles', 'Mexicocity', 'Osl', 'Rome',
    'Barcelona', 'Chicago', 'Madrid', 'Miami', 'Phoenix', 'Trt',
    'Boston', 'Lisbon', 'Medellin', 'Minneapolis', 'Prg', 'Washingtondc',
    'Brussels', 'London', 'Melbourne', 'Osaka', 'Prs',
]


class GSVCitiesDataModule(pl.LightningDataModule):
    """Lightning data module: GSV-Cities for training, San Francisco / Tokyo for evaluation.

    The training set is a ``GSVCitiesDataset``. Validation sets are built from
    ``SanFranciscoValidationDataset`` and test sets from
    ``SanFranciscoTestDataset`` / ``TokyoTestDataset``, selected by name.

    Args:
        batch_size: Batch size given to every ``DataLoader``. The stats table
            reports it as ``P`` in "PxK", with ``img_per_place`` as ``K``.
        img_per_place: Forwarded to ``GSVCitiesDataset``.
        min_img_per_place: Forwarded to ``GSVCitiesDataset``.
        shuffle_all: ``shuffle`` flag of the training ``DataLoader``.
        image_size: Target size passed to ``T.Resize`` for both the training
            and the evaluation transforms.
        num_workers: Worker count of the training loader; evaluation loaders
            use ``num_workers // 2``.
        show_data_stats: When true, ``setup`` prints summary tables.
        cities: City identifiers forwarded to ``GSVCitiesDataset``.
        mean_std: Mapping with ``'mean'`` and ``'std'`` entries used by
            ``T.Normalize``.
        batch_sampler: Stored on the instance; no loader built by this class
            currently consumes it.
        random_sample_from_each_place: Forwarded to ``GSVCitiesDataset``.
        val_set_names: Names of the validation sets. Each name must contain
            ``'sanfrancisco'`` (case-insensitive).
        test_set_names: Names of the test sets. Each name must contain
            ``'sanfrancisco_xs'`` or ``'tokyo_xs'`` (case-insensitive).
    """

    def __init__(self,
                 batch_size=32,
                 img_per_place=4,
                 min_img_per_place=4,
                 shuffle_all=False,
                 image_size=(480, 640),
                 num_workers=4,
                 show_data_stats=True,
                 cities=TRAIN_CITIES,
                 mean_std=IMAGENET_MEAN_STD,
                 batch_sampler=None,
                 random_sample_from_each_place=True,
                 val_set_names=['SanFrancisco_xs'],
                 test_set_names=['SanFrancisco_xs', 'Tokyo_xs']
                 ):
        super().__init__()
        self.batch_size = batch_size
        self.img_per_place = img_per_place
        self.min_img_per_place = min_img_per_place
        self.shuffle_all = shuffle_all
        self.image_size = image_size
        self.num_workers = num_workers
        self.batch_sampler = batch_sampler
        self.show_data_stats = show_data_stats
        self.cities = cities
        self.mean_dataset = mean_std['mean']
        self.std_dataset = mean_std['std']
        self.random_sample_from_each_place = random_sample_from_each_place
        self.val_set_names = val_set_names
        self.test_set_names = test_set_names
        # Record the constructor arguments as hyperparameters (PyTorch Lightning).
        self.save_hyperparameters()

        # Training pipeline: resize, random augmentation, tensor conversion,
        # then per-channel normalization.
        self.train_transform = T.Compose([
            T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR),
            T.RandAugment(num_ops=3, interpolation=T.InterpolationMode.BILINEAR),
            T.ToTensor(),
            T.Normalize(mean=self.mean_dataset, std=self.std_dataset),
        ])

        # Evaluation pipeline: identical to training minus the augmentation.
        self.valid_transform = T.Compose([
            T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR),
            T.ToTensor(),
            T.Normalize(mean=self.mean_dataset, std=self.std_dataset),
        ])

        self.train_loader_config = {
            'batch_size': self.batch_size,
            'num_workers': self.num_workers,
            'drop_last': False,
            'pin_memory': True,
            'shuffle': self.shuffle_all}

        # Evaluation loaders never shuffle and use half of the workers.
        self.valid_loader_config = {
            'batch_size': self.batch_size,
            'num_workers': self.num_workers//2,
            'drop_last': False,
            'pin_memory': True,
            'shuffle': False}

    @staticmethod
    def _new_stats_table():
        """Return an empty, header-less, left-aligned two-column PrettyTable."""
        table = PrettyTable()
        table.field_names = ['Data', 'Value']
        table.align['Data'] = "l"
        table.align['Value'] = "l"
        table.header = False
        return table

    def setup(self, stage):
        """Build the datasets required for ``stage``.

        Only ``'fit'`` and ``'test'`` trigger any work; other values are ignored.

        Args:
            stage: Stage name supplied by the trainer.

        Raises:
            NotImplementedError: If a configured validation or test set name
                is not recognised.
        """
        if stage == 'fit':
            # Build the training dataset through the shared reload routine.
            self.reload()

            # Build the validation sets listed in ``val_set_names``.
            self.val_datasets = []
            for val_set_name in self.val_set_names:
                if 'sanfrancisco' not in val_set_name.lower():
                    message = f'Validation set {val_set_name} does not exist or has not been implemented yet'
                    print(message)
                    raise NotImplementedError(message)
                self.val_datasets.append(SanFranciscoValidationDataset(
                    input_transform=self.valid_transform))

            if self.show_data_stats:
                self.print_stats(stage)

        if stage == "test":
            # Build the test sets listed in ``test_set_names``.
            self.test_datasets = []
            for test_set_name in self.test_set_names:
                lowered_name = test_set_name.lower()
                if 'sanfrancisco_xs' in lowered_name:
                    self.test_datasets.append(
                        SanFranciscoTestDataset(input_transform=self.valid_transform))
                elif 'tokyo_xs' in lowered_name:
                    self.test_datasets.append(
                        TokyoTestDataset(input_transform=self.valid_transform))
                else:
                    message = f'Test set {test_set_name} does not exist or has not been implemented yet'
                    print(message)
                    raise NotImplementedError(message)

            if self.show_data_stats:
                self.print_stats(stage)

    def reload(self):
        """(Re)create ``self.train_dataset`` from the current configuration."""
        self.train_dataset = GSVCitiesDataset(
            cities=self.cities,
            img_per_place=self.img_per_place,
            min_img_per_place=self.min_img_per_place,
            random_sample_from_each_place=self.random_sample_from_each_place,
            transform=self.train_transform)

    def train_dataloader(self):
        """Return the training ``DataLoader``.

        The training dataset is rebuilt on every call, so each returned loader
        wraps a freshly constructed ``GSVCitiesDataset``.
        """
        self.reload()
        return DataLoader(dataset=self.train_dataset, **self.train_loader_config)

    def val_dataloader(self):
        """Return one ``DataLoader`` per validation dataset, in configured order."""
        return [
            DataLoader(dataset=val_dataset, **self.valid_loader_config)
            for val_dataset in self.val_datasets
        ]

    def test_dataloader(self):
        """Return one ``DataLoader`` per test dataset, in configured order."""
        return [
            DataLoader(dataset=test_dataset, **self.valid_loader_config)
            for test_dataset in self.test_datasets
        ]

    def print_stats(self, stage):
        """Print summary tables for ``stage`` (``'fit'`` or ``'test'``).

        For ``'fit'`` this prints the training dataset summary, the validation
        set names and the training configuration. For ``'test'`` it prints the
        test set names. Other stages print nothing.
        """
        if stage == 'fit':
            print()  # print a new line
            table = self._new_stats_table()
            # Count the cities actually configured on this instance, not the
            # module-level default list.
            table.add_row(["# of cities", f"{len(self.cities)}"])
            table.add_row(["# of places", f'{len(self.train_dataset)}'])
            table.add_row(["# of images", f'{self.train_dataset.total_nb_images}'])
            print(table.get_string(title="Training Dataset"))
            print()

            table = self._new_stats_table()
            for i, val_set_name in enumerate(self.val_set_names):
                table.add_row([f"Validation set {i+1}", f"{val_set_name}"])
            print(table.get_string(title="Validation Datasets"))
            print()

            table = self._new_stats_table()
            table.add_row(
                ["Batch size (PxK)", f"{self.batch_size}x{self.img_per_place}"])
            table.add_row(
                ["# of iterations", f"{len(self.train_dataset)//self.batch_size}"])
            table.add_row(["Image size", f"{self.image_size}"])
            print(table.get_string(title="Training config"))

        if stage == 'test':
            table = self._new_stats_table()
            for i, test_set_name in enumerate(self.test_set_names):
                table.add_row([f"Test set {i+1}", f"{test_set_name}"])
            print(table.get_string(title="Test Datasets"))
            print()