In [11]:
# List the contents of ../../easy-to-hard-data (path is relative to the notebook's working directory).
!ls ../../easy-to-hard-data

alto-boot	   learning_to_retrieve_reasoning_paths
cheaters_quizbowl  multihop_dense_retrieval
cl1-hw		   pinafore-papers
DPR		   qanta-codalab
easy-to-hard	   qb
easy-to-hard-data  retrieval-based-baselines
golden-retriever   sifcrypto
gold-rtvr-2	   squad_project
haystack	   stanza


In [16]:
import os
import torch

# from easy_to_hard_data import *

import errno
import os
import os.path
import tarfile
import urllib.request as ur
from typing import Optional, Callable

import numpy as np
import torch
from tqdm import tqdm

# Number of bytes in 2**30 (one GiB), as a float; download_url divides the
# downloaded byte count by this to label its progress bar in "GB".
GBFACTOR = float(1 << 30)


def extract_zip(path, folder):
    """Extract the tar archive at ``path`` into ``folder``.

    Despite the name, the archive is opened with ``tarfile`` (compression is
    auto-detected by ``tarfile.open``), not ``zipfile``.

    Args:
        path: Path to the tar archive on disk.
        folder: Directory the archive members are extracted into.

    Raises:
        tarfile.TarError: If ``path`` is not a readable tar archive.
        OSError: If the archive cannot be opened or members cannot be written.
    """
    # The context manager guarantees the archive handle is closed, even when
    # extraction fails part-way (the previous code referenced ``file.close``
    # without calling it, so the handle was never explicitly closed).
    # NOTE(review): extractall trusts member paths in the archive; only use
    # this on archives from a trusted source.
    with tarfile.open(path) as archive:
        archive.extractall(folder)


def download_url(url, folder):
    """Download ``url`` into ``folder`` and return the path of the local file.

    The local file name is the part of ``url`` after its last ``/``. If a
    non-empty file with that name already exists in ``folder`` it is reused
    and nothing is downloaded.

    Args:
        url: Address of the file to fetch.
        folder: Directory to store the file in; created if necessary.

    Returns:
        Path of the downloaded (or already existing) file.

    Raises:
        RuntimeError: If reading or writing fails or is interrupted after the
            connection was opened; the partial file is removed first and the
            original exception is attached as ``__cause__``.
        urllib.error.URLError: If the connection itself cannot be opened.
    """
    filename = url.rpartition('/')[2]
    path = os.path.join(folder, filename)

    if os.path.exists(path) and os.path.getsize(path) > 0:
        print('Using existing file', filename)
        return path
    print('Downloading', url)
    makedirs(folder)

    chunk_size = 1024*1024
    downloaded_size = 0

    # The response is a context manager; closing it releases the connection.
    with ur.urlopen(url) as response:
        headers = response.info()
        print(headers)
        # Content-Length is optional (e.g. chunked transfer encoding); without
        # it the progress bar simply has no total instead of crashing.
        try:
            total_size = int(headers["Content-Length"])
        except (TypeError, ValueError):
            total_size = None

        try:
            with open(path, 'wb') as f, \
                    tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
                # Read until EOF rather than for a chunk count derived from
                # the header, so a wrong or missing size cannot truncate the
                # file and no empty trailing chunks are written.
                while True:
                    chunk = response.read(chunk_size)
                    if not chunk:
                        break
                    f.write(chunk)
                    downloaded_size += len(chunk)
                    pbar.update(len(chunk))
                    pbar.set_description("Downloaded {:.2f} GB".format(float(downloaded_size)/GBFACTOR))
        except BaseException as exc:
            # BaseException on purpose: a Ctrl-C / kernel interrupt must also
            # remove the partial file so a later call does not reuse it.
            if os.path.exists(path):
                os.remove(path)
            raise RuntimeError('Stopped downloading due to interruption.') from exc

    return path


def makedirs(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    ``path`` is normalized and ``~`` is expanded before creation. An already
    existing directory is not an error.

    Args:
        path: Directory to create.

    Raises:
        OSError: If the directory cannot be created, e.g. a regular file
            already occupies ``path`` or one of its parents, or permissions
            are insufficient.
    """
    # exist_ok=True tolerates only "already exists as a directory". The old
    # hand-written errno check was inverted and silently swallowed every real
    # failure (file in the way, ENOTDIR, permission denied, ...).
    os.makedirs(os.path.expanduser(os.path.normpath(path)), exist_ok=True)
            
class NoisyImageDataset(torch.utils.data.Dataset):
    """Dataset backed by two ``.pth`` files loaded with ``torch.load``.

    Files are looked up under ``<root>/noisy_image_data/`` as
    ``data_<num_bits>.pth`` (inputs) and ``targets_<num_bits>.pth`` (targets).
    Indexing returns the ``(input, target)`` pair at that position.
    """

    # Sub-directory of ``root`` that holds the ``.pth`` files.
    base_folder = "noisy_image_data"
    # Archive fetched by ``download()``.
    url = "https://www.dropbox.com/s/gamc8j5vqbvushj/noisy_image_data.tar.gz"
    # Values for which ``download_list`` expects a data/targets file pair.
    lengths = [0.1,0.2,0.3,0.4,0.5]
    # Every file that must exist for ``_check_integrity`` to succeed.
    download_list = [f"data_{l}.pth" for l in lengths] + [f"targets_{l}.pth" for l in lengths]

    def __init__(self, root: str, num_bits: float = 0.1, download: bool = True):
        """Load the inputs/targets saved for ``num_bits`` from below ``root``.

        Args:
            root: Directory containing the ``noisy_image_data`` folder.
            num_bits: Selects which ``data_*.pth`` / ``targets_*.pth`` pair to load.
            download: Accepted but currently unused; the automatic download
                is disabled, so the files must already be on disk.
        """
        self.root = root

        print(f"Loading data with {num_bits} bits.")

        data_dir = os.path.join(root, self.base_folder)
        # NOTE(review): torch.load unpickles its input; only load trusted files.
        self.inputs = torch.load(os.path.join(data_dir, f"data_{num_bits}.pth"))
        self.targets = torch.load(os.path.join(data_dir, f"targets_{num_bits}.pth"))

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair stored at ``index``."""
        sample = self.inputs[index]
        label = self.targets[index]
        return sample, label

    def __len__(self):
        """Return the size of the first dimension of the loaded inputs."""
        return self.inputs.size(0)

    def _check_integrity(self) -> bool:
        """Return True only if every file in ``download_list`` exists on disk."""
        data_dir = os.path.join(self.root, self.base_folder)
        return all(
            os.path.exists(os.path.join(data_dir, name))
            for name in self.download_list
        )

    def download(self) -> None:
        """Fetch and unpack the archive unless all expected files are present.

        The downloaded archive is deleted after extraction.
        """
        if not self._check_integrity():
            archive_path = download_url(self.url, self.root)
            extract_zip(archive_path, self.root)
            os.unlink(archive_path)
            return
        print('Files already downloaded and verified')
        
        
# Load the default (num_bits=0.1) files from the local "data" directory.
# The files under data/noisy_image_data/ must already exist: the download
# step in NoisyImageDataset.__init__ is commented out.
data = NoisyImageDataset("data")
data

Loading data with 0.1 bits.


<__main__.NoisyImageDataset at 0x7f6a5ddc0f90>

In [17]:
# Number of samples in the loaded dataset (first dimension of its inputs).
len(data)

10000