In [None]:
# default_exp dirty_mnist

# DirtyMNIST Dataset

> Ready-to-use PyTorch datasets for DirtyMNIST: MNIST concatenated with AmbiguousMNIST.

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# exports
import os
from typing import IO, Any, Callable, Dict, List, Optional, Tuple, Union
from urllib.error import URLError

import torch
from torchvision.datasets.mnist import MNIST, VisionDataset
from torchvision.datasets.utils import (
    download_and_extract_archive,
    download_url,
    extract_archive,
    verify_str_arg,
)
from torchvision.transforms import Compose, Normalize, ToTensor

# based on torchvision.datasets.mnist.py (https://github.com/pytorch/vision/blob/37eb37a836fbc2c26197dfaf76d2a3f4f39f15df/torchvision/datasets/mnist.py)

# Single-channel normalization with the mean/std values commonly quoted for MNIST
# pixel intensities scaled to [0, 1] (i.e. after ``ToTensor``).
MNIST_NORMALIZATION = Normalize(mean=(0.1307,), std=(0.3081,))


class AmbiguousMNIST(VisionDataset):
    """AmbiguousMNIST loaded from pre-built ``.pt`` tensor files.

    Every stored sample carries several labels. On load, each image is repeated
    once per label, so the dataset behaves like an ordinary single-label
    dataset of ``(image, label)`` pairs.

    Args:
        root: Root directory; files live in ``root/AmbiguousMNIST``.
        train: If True, keep the first 60000 stored samples, otherwise the rest.
        transform: Optional callable applied to each image tensor.
        target_transform: Optional callable applied to each integer label.
        download: If True, download the files unless they already exist.
        device: Passed as ``map_location`` when loading the image tensor.
    """

    # NOTE(review): a previous run of this notebook got HTTP 404 for files under
    # this release tag -- confirm the tag/assets exist.
    mirrors = ["https://github.com/BlackHC/ddu_dirty_mnist/releases/download/data-v0.5.0/"]

    # resource name -> (file name, md5). md5 is None, so downloads are not integrity-checked.
    resources = dict(data=("amnist_samples.pt", None), targets=("amnist_labels.pt", None))

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
        device=None,
    ):
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        # The split is taken over stored samples, before they are repeated per label.
        data_range = slice(None, 60000) if self.train else slice(60000, None)

        # NOTE(review): torch.load unpickles the file -- only load files from a trusted source.
        self.data = torch.load(self.resource_path("data"), map_location=device)[data_range]
        self.targets = torch.load(self.resource_path("targets"), map_location="cpu")[data_range]

        # targets is assumed to be (num_samples, num_multi_labels) and data
        # (num_samples, 1, 28, 28) -- TODO confirm against the released files.
        # Repeat each image once per label: row i*K + k pairs image i with label k.
        num_multi_labels = self.targets.shape[1]
        self.data = self.data.expand(-1, num_multi_labels, 28, 28).reshape(-1, 1, 28, 28)
        self.targets = self.targets.reshape(-1)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where image is a (1, 28, 28) tensor and target
            is the index of the target class (after the optional transforms).
        """
        # self.data is already (N, 1, 28, 28), so indexing yields a channel-first image.
        img, target = self.data[index], int(self.targets[index])

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def data_folder(self) -> str:
        """Directory holding this dataset's files: ``root/<class name>``."""
        return os.path.join(self.root, self.__class__.__name__)

    def resource_path(self, name):
        """Return the local file path for the resource called ``name`` (a key of ``resources``)."""
        return os.path.join(self.data_folder, self.resources[name][0])

    def _check_exists(self) -> bool:
        """Return True if every resource file is present on disk."""
        return all(os.path.exists(self.resource_path(name)) for name in self.resources)

    def download(self) -> None:
        """Download the resource files into ``data_folder`` unless they already exist.

        Each file is tried against every mirror in turn; a ``URLError`` (which
        includes HTTP errors such as 404) moves on to the next mirror.

        Raises:
            RuntimeError: If a file could not be downloaded from any mirror.
        """
        if self._check_exists():
            return

        os.makedirs(self.data_folder, exist_ok=True)

        # download files
        for filename, md5 in self.resources.values():
            for mirror in self.mirrors:
                url = "{}{}".format(mirror, filename)
                try:
                    print("Downloading {}".format(url))
                    # The resources are plain .pt tensors, not archives, so they must
                    # not go through download_and_extract_archive (extraction would fail).
                    download_url(url, root=self.data_folder, filename=filename, md5=md5)
                except URLError as error:
                    print("Failed to download (trying next):\n{}".format(error))
                    continue
                finally:
                    print()
                break
            else:
                raise RuntimeError("Error downloading {}".format(filename))

        print("Done!")
        print("Done!")


def DirtyMNIST(
    root: str,
    train: bool = True,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    download: bool = False,
):
    """Build DirtyMNIST: the MNIST dataset concatenated with AmbiguousMNIST.

    Args:
        root: Root directory shared by both datasets.
        train: Select the training split (True) or the test split (False).
        transform: Optional callable applied to the image tensors of both datasets.
            MNIST PIL images are converted with ``ToTensor`` first; AmbiguousMNIST
            images are passed to it directly (assumed already scaled like
            ``ToTensor`` output -- TODO confirm).
        target_transform: Optional callable applied to the integer labels.
        download: Download missing files for both datasets.

    Returns:
        A ``torch.utils.data.ConcatDataset`` yielding all MNIST items first,
        followed by all AmbiguousMNIST items.
    """
    # MNIST yields PIL images, so convert to tensors before the user transform.
    mnist_transform = Compose([ToTensor(), transform]) if transform is not None else ToTensor()
    mnist_dataset = MNIST(
        root=root, train=train, transform=mnist_transform, target_transform=target_transform, download=download
    )
    amnist_dataset = AmbiguousMNIST(
        root=root, train=train, transform=transform, target_transform=target_transform, download=download
    )

    # ConcatDataset takes a single sequence of datasets, not varargs.
    return torch.utils.data.ConcatDataset([mnist_dataset, amnist_dataset])

Let's look at the dataset:

In [None]:
# Training split of DirtyMNIST under the current directory, fetching any missing files.
dirty_mnist_train = DirtyMNIST(
    root=".",
    train=True,
    transform=MNIST_NORMALIZATION,
    download=True,
)

0
data
0
Downloading http://github.com/BlackHC/ddu_dirty_mnist/releases/download/data-v0.5.0/amnist_labels.pt
Failed to download (trying next):
HTTP Error 404: Not Found



RuntimeError: Error downloading amnist_labels.pt