In [1]:
import random
import numpy as np
import matplotlib.pyplot as plt

In [2]:
!wget https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip

--2021-06-05 16:10:02--  https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip
Resolving github.com (github.com)... 192.30.255.112
Connecting to github.com (github.com)|192.30.255.112|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip [following]
--2021-06-05 16:10:02--  https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9464212 (9.0M) [application/zip]
Saving to: ‘images_background.zip’


2021-06-05 16:10:04 (28.2 MB/s) - ‘images_background.zip’ saved [9464212/9464212]



In [3]:
!wget https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip

--2021-06-05 16:10:25--  https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip
Resolving github.com (github.com)... 192.30.255.113
Connecting to github.com (github.com)|192.30.255.113|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip [following]
--2021-06-05 16:10:25--  https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6462886 (6.2M) [application/zip]
Saving to: ‘images_evaluation.zip’


2021-06-05 16:10:26 (23.9 MB/s) - ‘images_evaluation.zip’ saved [6462886/6462886]



In [26]:
from torchvision.datasets import ImageFolder, DatasetFolder, Omniglot

In [58]:
from torch.utils.data import DataLoader

In [60]:
import torchvision.transforms as transforms

In [61]:
# Recorded statistics for this split: mean, std = (tensor(0.9221), tensor(0.2681))
# NOTE(review): presumably the pixel mean/std of the images after ToTensor(),
# kept here for a later transforms.Normalize — confirm how they were computed.
#
# Build the dataset from the "background" set (the default when `background`
# is not passed); each image is converted to a tensor by ToTensor().
train = Omniglot(root="", download=True, transform=transforms.ToTensor())

Files already downloaded and verified


In [62]:
# Display the dataset summary (number of datapoints, root location, transform).
train

Dataset Omniglot
    Number of datapoints: 19280
    Root location: omniglot-py
    StandardTransform
Transform: ToTensor()

In [63]:
# Recorded statistics for this split: mean, std = (tensor(0.9157), tensor(0.2778))
# NOTE(review): presumably the pixel mean/std of the images after ToTensor(),
# kept here for a later transforms.Normalize — confirm how they were computed.
#
# background=False selects the "evaluation" set instead of the "background" set;
# each image is converted to a tensor by ToTensor().
test = Omniglot(root="", background=False, download=True, transform=transforms.ToTensor())

Files already downloaded and verified


In [64]:
# Display the dataset summary (number of datapoints, root location, transform).
test

Dataset Omniglot
    Number of datapoints: 13180
    Root location: omniglot-py
    StandardTransform
Transform: ToTensor()

In [7]:
from PIL import Image
from os.path import join
from typing import Any, Callable, List, Optional, Tuple

In [72]:
from torchvision.datasets import VisionDataset

In [6]:
from torchvision.datasets.utils import download_and_extract_archive, check_integrity, list_dir, list_files

In [None]:
class Omniglot(VisionDataset):
    """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.

    Each sample is an ``(image, target)`` pair: ``image`` is the character
    drawing opened as a single-channel ("L" mode) PIL image, and ``target`` is
    the integer index of its character class, i.e. its position in the list of
    ``<alphabet>/<character>`` folders found under the selected image set.

    Args:
        root (string): Root directory of dataset where directory
            ``omniglot-py`` exists.
        background (bool, optional): If True, creates dataset from the "background" set, otherwise
            creates from the "evaluation" set. This terminology is defined by the authors.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset zip files from the internet and
            puts it in root directory. If the zip files are already downloaded, they are not
            downloaded again.

    Raises:
        RuntimeError: If the zip archive of the selected set is missing from
            ``<root>/omniglot-py`` or fails its MD5 check.
    """
    # Sub-directory of ``root`` that holds the archives and extracted images.
    folder = 'omniglot-py'
    download_url_prefix = 'https://github.com/brendenlake/omniglot/raw/master/python'
    # Expected MD5 checksum of each archive, keyed by archive name without ".zip".
    zips_md5 = {
        'images_background': '68d2efa1b9178cc56df9314c21c6e718',
        'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
    }

    def __init__(
            self,
            root: str,
            background: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:
        # Zero-argument super() does not look up the global name ``Omniglot``,
        # so construction keeps working even if that name is later rebound
        # (e.g. by re-importing torchvision's class of the same name).
        super().__init__(join(root, self.folder), transform=transform,
                         target_transform=target_transform)
        self.background = background

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        self.target_folder = join(self.root, self._get_target_folder())
        self._alphabets = list_dir(self.target_folder)
        # "<alphabet>/<character>" relative paths; a character's position in
        # this list is its class index.
        self._characters: List[str] = [
            join(alphabet, character)
            for alphabet in self._alphabets
            for character in list_dir(join(self.target_folder, alphabet))
        ]
        # One list per character, holding (image file name, class index) pairs.
        self._character_images = [
            [(image, idx) for image in list_files(join(self.target_folder, character), '.png')]
            for idx, character in enumerate(self._characters)
        ]
        # Flattened with a comprehension rather than sum(lists, []), which
        # re-copies the accumulated list on every addition (quadratic time).
        self._flat_character_images: List[Tuple[str, int]] = [
            entry for images in self._character_images for entry in images
        ]

    def __len__(self) -> int:
        """Return the total number of images across all character classes."""
        return len(self._flat_character_images)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target character class.
        """
        image_name, character_class = self._flat_character_images[index]
        image_path = join(self.target_folder, self._characters[character_class], image_name)
        # Convert to single-channel 8-bit grayscale ("L" mode).
        image = Image.open(image_path, mode='r').convert('L')

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            character_class = self.target_transform(character_class)

        return image, character_class

    def _check_integrity(self) -> bool:
        """Return True if the selected set's zip exists in ``self.root`` and matches its MD5.

        Only the archive is checked, not the extracted image folder.
        """
        zip_filename = self._get_target_folder()
        return check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename])

    def download(self) -> None:
        """Download and extract the selected set's archive unless a valid copy already exists."""
        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        filename = self._get_target_folder()
        zip_filename = filename + '.zip'
        url = self.download_url_prefix + '/' + zip_filename
        download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])

    def _get_target_folder(self) -> str:
        """Return the folder/archive base name for the selected set."""
        return 'images_background' if self.background else 'images_evaluation'

In [None]:
class OmniglotDataset(VisionDataset):
    """Omniglot character images read directly from ``root``.

    Unlike a layout that nests the data under an ``omniglot-py`` sub-directory,
    this class expects (or downloads) the ``images_background`` /
    ``images_evaluation`` archive and its extracted folder directly inside
    ``root``.

    Each sample is an ``(image, target)`` pair: ``image`` is the drawing opened
    as a single-channel ("L" mode) PIL image, and ``target`` is the integer
    index of its character class, i.e. its position in the list of
    ``<alphabet>/<character>`` folders found under the selected image set.

    Args:
        root (string): Directory that contains (or will receive) the zip
            archive and the extracted image folder.
        background (bool, optional): If True, use the "background" set,
            otherwise use the "evaluation" set.
        transform (callable, optional): Applied to the PIL image before it is
            returned.
        target_transform (callable, optional): Applied to the integer class
            index before it is returned.
        download (bool, optional): If True, download and extract the archive
            into ``root`` unless a valid copy is already there.

    Raises:
        RuntimeError: If the zip archive of the selected set is missing from
            ``root`` or fails its MD5 check.
    """

    download_url_prefix = 'https://github.com/brendenlake/omniglot/raw/master/python'
    # Expected MD5 checksum of each archive, keyed by archive name without ".zip".
    zips_md5 = {
        'images_background': '68d2efa1b9178cc56df9314c21c6e718',
        'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
    }

    def __init__(
        self,
        root: str,
        background: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.background = background

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        self.target_folder = join(self.root, self._get_target_folder())
        self._alphabets = list_dir(self.target_folder)
        # "<alphabet>/<character>" relative paths; a character's position in
        # this list is its class index.
        self._characters: List[str] = [
            join(alphabet, character)
            for alphabet in self._alphabets
            for character in list_dir(join(self.target_folder, alphabet))
        ]
        # One list per character, holding (image file name, class index) pairs.
        # (The original statement was split after "=" without a continuation,
        # which is a SyntaxError.)
        self._character_images = [
            [(image, idx) for image in list_files(join(self.target_folder, character), '.png')]
            for idx, character in enumerate(self._characters)
        ]
        # Flattened with a comprehension rather than sum(lists, []), which
        # re-copies the accumulated list on every addition (quadratic time).
        self._flat_character_images: List[Tuple[str, int]] = [
            entry for images in self._character_images for entry in images
        ]

    def __len__(self) -> int:
        """Return the total number of images across all character classes."""
        return len(self._flat_character_images)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load one sample.

        Args:
            index (int): Position in the flattened image list.

        Returns:
            tuple: (image, target) where target is the index of the character class.
        """
        image_name, character_class = self._flat_character_images[index]
        image_path = join(self.target_folder, self._characters[character_class], image_name)
        # Convert to single-channel 8-bit grayscale ("L" mode).
        image = Image.open(image_path, mode='r').convert('L')

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            character_class = self.target_transform(character_class)

        return image, character_class

    def _check_integrity(self) -> bool:
        """Return True if the selected set's zip exists in ``self.root`` and matches its MD5.

        Only the archive is checked, not the extracted image folder.
        """
        zip_filename = self._get_target_folder()
        return check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename])

    def download(self) -> None:
        """Download and extract the selected set's archive unless a valid copy already exists."""
        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        filename = self._get_target_folder()
        zip_filename = filename + '.zip'
        url = self.download_url_prefix + '/' + zip_filename
        download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])

    def _get_target_folder(self) -> str:
        """Return the folder/archive base name for the selected set."""
        return 'images_background' if self.background else 'images_evaluation'

In [19]:
# Sanity check of the flattening idiom: summing lists onto an empty-list start
# value concatenates them in order (duplicates are kept).
a = sum(
    [
        [1, 2, 3],
        [3, 4, 5],
        [6, 7, 8],
    ],
    [],
)

In [20]:
# Show the flattened result.
a

[1, 2, 3, 3, 4, 5, 6, 7, 8]