<a href="https://colab.research.google.com/github/AlexHeyman/FewShotGANTraining/blob/main/dataloader.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q torch==1.10.0 torchvision

In [None]:
!pip install pillow



In [None]:
# Colab-only: mount the user's Google Drive at /content/gdrive. The dataset
# archive path configured later in this notebook lives under this mount point.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [1]:
import os
import zipfile
import urllib.request
from typing import Any, Dict

import torch
import torchvision
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from copy import deepcopy
import shutil
import json

In [2]:
def copy_Generated_parameters(model):
    """Return an independent snapshot of every parameter tensor of ``model``.

    The ``.data`` tensor of each parameter is collected in
    ``model.parameters()`` order and the resulting list is deep-copied, so
    later in-place updates to the model do not alter the snapshot.
    """
    snapshot = [param.data for param in model.parameters()]
    return deepcopy(snapshot)

In [4]:
def load_parameters(model, new_parameter):
    """Overwrite the parameters of ``model`` in place with ``new_parameter``.

    Tensors are matched positionally against ``model.parameters()``. The
    pairing uses ``zip``, so it stops at the shorter of the two sequences.
    """
    pairs = zip(model.parameters(), new_parameter)
    for target, source in pairs:
        target.data.copy_(source)

In [5]:
def get_directory(args):
    """Create the output folders for a training run and snapshot its setup.

    Creates ``train_results/<args.name>/models`` and
    ``train_results/<args.name>/images``, copies every ``.py`` file found in
    the current working directory into the run folder, and writes
    ``args.__dict__`` as JSON to ``args.txt`` in the run folder.

    Args:
        args: Object with a ``name`` attribute (used as the run folder name)
            and a JSON-serialisable ``__dict__``.

    Returns:
        Tuple ``(saved_model_folder, saved_image_folder)``.
    """
    folder_name = 'train_results/' + args.name
    saved_model_folder = os.path.join(folder_name, 'models')
    saved_image_folder = os.path.join(folder_name, 'images')

    os.makedirs(saved_model_folder, exist_ok=True)
    os.makedirs(saved_image_folder, exist_ok=True)

    # Match on the real extension (a substring test would also pick up
    # ``.pyc`` / ``.py.bak``) and skip directories, which shutil.copy rejects.
    for entry in os.listdir('./'):
        if entry.endswith('.py') and os.path.isfile(entry):
            shutil.copy(entry, os.path.join(folder_name, entry))

    with open(os.path.join(folder_name, 'args.txt'), 'w') as args_file:
        json.dump(args.__dict__, args_file, indent=2)

    return saved_model_folder, saved_image_folder

In [None]:
# Where the few-shot image archive comes from and where it ends up.
#   url         - source location of the zip (here a path on the mounted Drive)
#   archive     - file name the zip is stored under inside the data root
#   destination - folder name expected inside the data root after extraction
dataset = dict(
    url='/content/gdrive/MyDrive/DATASETS/few-shot-image-datasets.zip',
    archive='few-shot-image-datasets.zip',
    destination='few-shot-image-datasets',
)

In [None]:
def extract_dataset(root: str,
                    url: str,
                    archive: str,
                    destination: str):
    """Fetch the dataset archive into ``root`` and unpack it if needed.

    Does nothing when ``root/destination`` already exists as a directory.
    Otherwise the archive is placed at ``root/archive`` and, if its name ends
    in ``.zip``, extracted into ``root``.

    Args:
        root: Directory that holds the archive and the extracted data; it is
            created if missing.
        url: Source of the archive. May be a path to an existing local file
            (e.g. on a mounted drive), which is copied, or a URL, which is
            downloaded with ``urllib``.
        archive: File name to store the archive under inside ``root``.
        destination: Directory name expected inside ``root`` once the data
            has been extracted.
    """
    destination_path = os.path.join(root, destination)
    if os.path.isdir(destination_path):
        return

    if root:
        os.makedirs(root, exist_ok=True)
    archive_path = os.path.join(root, archive)

    if os.path.isfile(url):
        # urlretrieve rejects plain filesystem paths ("unknown url type"),
        # so local archives are copied instead of downloaded.
        already_in_place = (os.path.exists(archive_path)
                            and os.path.samefile(url, archive_path))
        if not already_in_place:
            shutil.copyfile(url, archive_path)
    else:
        urllib.request.urlretrieve(url, archive_path)

    if archive_path.endswith('.zip'):
        with zipfile.ZipFile(archive_path, 'r') as archive_file:
            archive_file.extractall(root)

In [None]:
class MergeFewShotImageDatasets:
    """Base class that makes sure the few-shot image archive is unpacked.

    Instantiating it calls ``extract_dataset`` with the module-level
    ``dataset`` configuration, so subclasses can rely on the extracted files
    being present under ``root``.
    """

    def __init__(self, root: str):
        """Fetch and extract the dataset archive into ``root`` if required.

        Args:
            root: Directory that holds (or will hold) the dataset.
        """
        # The configuration dict is defined as ``dataset``; the previous
        # ``_dataset`` spelling did not exist and raised NameError.
        extract_dataset(
            root=root,
            url=dataset['url'],
            archive=dataset['archive'],
            destination=dataset['destination'],
        )


In [None]:
class FewShotImageDataset(MergeFewShotImageDatasets, Dataset):
    """Map-style dataset over the image files of one dataset subdirectory.

    Each item is a dict ``{'image': tensor}`` where the tensor is the RGB
    image in channel-first layout, resized to 1024x1024 and randomly flipped
    horizontally with probability 0.5.
    """

    def __init__(self, root: str,
                       subdirectory: str):
        """Prepare the dataset rooted at ``root/subdirectory``.

        Args:
            root: Directory holding the extracted dataset; passed to the base
                class, which extracts the archive there if needed.
            subdirectory: Path, relative to ``root``, of the folder whose
                files are served as images.
        """
        super().__init__(root)
        self._root = os.path.join(root, subdirectory)
        # os.listdir returns entries in arbitrary order; sort so an index
        # always refers to the same file across runs.
        self._files = sorted(os.listdir(self._root))

        self._transforms = torchvision.transforms.Compose([
                torchvision.transforms.Resize((1024, 1024)),
                torchvision.transforms.RandomHorizontalFlip(p=0.5),
            ])

    def __len__(self) -> int:
        """Return the number of files in the dataset directory."""
        return len(self._files)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Load, convert and transform the image at ``index``.

        Raises:
            IndexError: If ``index`` is outside the file list.
        """
        image_path = os.path.join(self._root, self._files[index])
        # Close the source file as soon as the RGB copy has been made.
        with Image.open(image_path) as source:
            rgb = source.convert('RGB')
        image = torch.from_numpy(np.array(rgb))
        # (H, W, C) -> (C, H, W), the layout the tensor transforms operate on.
        image = image.permute(2, 0, 1)
        image = self._transforms(image)
        return {'image': image}

In [None]:
def _length_(self):
    """Return how many entries ``self._files`` holds.

    NOTE(review): this is defined at module level in its own notebook cell,
    so as written it is a plain function taking the object as ``self``.
    """
    file_count = len(self._files)
    return file_count

In [None]:
def retrieveimage(self, index: int) -> Dict[str, Any]:
    """Load the image at position ``index`` and return it as a tensor dict.

    Reads ``self._files[index]`` from ``self._root``, converts it to RGB,
    rearranges it to channel-first layout and applies ``self._transforms``.

    NOTE(review): this is defined at module level in its own notebook cell,
    so as written it is a plain function taking the object as ``self``.

    Args:
        self: Object exposing ``_root``, ``_files`` and ``_transforms``.
        index: Position of the file within ``self._files``.

    Returns:
        ``{'image': tensor}`` holding the transformed image.
    """
    image_path = os.path.join(self._root, self._files[index])
    # Use a context manager so the underlying file is closed right after the
    # RGB copy is made instead of whenever the image object is collected.
    with Image.open(image_path) as source:
        rgb = source.convert('RGB')
    image = torch.from_numpy(np.array(rgb))
    # (H, W, C) -> (C, H, W)
    image = image.permute(2, 0, 1)
    image = self._transforms(image)
    return {'image': image}

In [None]:
class NoiseDataset(Dataset):
    """Map-style dataset that yields standard-normal noise tensors.

    Every item is a freshly sampled tensor of shape ``(channels, 1, 1)``; the
    index only bounds the dataset length and does not select a stored value.
    """

    def __init__(self, size: int, channels: int):
        """Store the dataset length and the noise dimensionality.

        Args:
            size: Number of items the dataset reports.
            channels: Size of the first dimension of each noise tensor.
        """
        self._size = size
        self._channels = channels

    def __len__(self) -> int:
        """Return the configured number of items."""
        return self._size

    def __getitem__(self, index: int) -> torch.Tensor:
        """Return a new noise sample for ``index``.

        Raises:
            IndexError: If ``index`` is outside ``[-size, size)``. Without
                this, plain iteration over the dataset would never stop.
        """
        if not -self._size <= index < self._size:
            raise IndexError(
                f'index {index} is out of range for dataset of size {self._size}'
            )
        return self.retrieveimage(index)

    # The two names below predate the protocol methods above and are kept so
    # existing callers continue to work.
    def _length_(self):
        """Return the configured number of items (same as ``len(self)``)."""
        return self._size

    def retrieveimage(self, index: int) -> torch.Tensor:
        """Sample a ``(channels, 1, 1)`` tensor from N(0, 1); ``index`` is unused."""
        return torch.zeros(self._channels, 1, 1).normal_(0.0, 1.0)

Reference: https://github.com/silentz/Towards-Faster-And-Stabilized-GAN-Training-For-High-Fidelity-Few-Shot-Image-Synthesis
