In [1]:
import os
from zipfile import ZipFile, Path
import csv
import pandas as pd
import numpy as np
from PIL import Image
# from skimage
from contextlib import contextmanager


import torch
from torch import nn
from torch.utils import data as tdata
import webdataset as wds

In [2]:
# Dataset archive, relative to the notebook's working directory.
zip_path = 'archive.zip'

# Folder *inside* the archive that holds the training images; each immediate
# sub-folder is one class. Keep the trailing '/': zipfile.Path only treats
# slash-terminated paths as directories.
train_path = 'asl_alphabet_train/asl_alphabet_train/'

# CSV index file holding one `path,target` row per training image.
info_path = './info.csv'

def get_dirs_and_count(zip_path, dir_path, info_path):
  """Index the class sub-folders of `dir_path` inside a zip archive.

  Every immediate sub-folder of `dir_path` is treated as one class, and its
  position in the archive listing is its class index. Every file directly
  inside a class folder is mapped to that index. If `info_path` is given, the
  mapping is also written there as header-less `path,target` CSV rows.

  Args:
      zip_path (str): path of the zip archive on disk.
      dir_path (str): folder inside the archive whose sub-folders are the
          classes, e.g. ``'asl_alphabet_train/asl_alphabet_train/'``. A
          missing trailing ``'/'`` is added automatically.
      info_path (str | None): CSV file to create. Pass ``None`` to skip
          writing the file.

  Returns:
      `Tuple[List[str], Dict[str, int], Dict[str, int]]`: \n
      1. class folder names, in archive listing order;
      2. mapping from each folder name to its index in that list;
      3. mapping from each file's path inside the archive to its class index.

  Raises:
      FileExistsError: if `info_path` is not None and already exists. The
          check runs before the archive is read, so nothing is overwritten.
  """
  # Only check for an existing file when we are actually going to write one;
  # os.path.exists(None) would raise TypeError.
  if info_path is not None and os.path.exists(info_path):
    raise FileExistsError("Can't write to info file. It already exists. Delete it if you want it to be overwritten.")

  # zipfile.Path only treats slash-terminated paths as directories, so
  # normalise here instead of failing inside iterdir().
  if dir_path and not dir_path.endswith('/'):
    dir_path += '/'

  with ZipFile(zip_path, 'r') as archive:
    class_dirs = [p for p in Path(archive, dir_path).iterdir() if p.is_dir()]

    # maps each file's path inside the archive to its class index
    path_to_class = {}
    for idx, class_dir in enumerate(class_dirs):
      for entry in class_dir.iterdir():
        # Nested folders cannot be opened as images later, so skip them.
        if entry.is_file():
          # Keys are built from the `dir_path` argument so they are valid
          # member names for whichever folder was indexed.
          path_to_class[f"{dir_path}{class_dir.name}/{entry.name}"] = idx

    dirs = [p.name for p in class_dirs]

  # is like classes_to_idx
  dir_to_i = {d: i for i, d in enumerate(dirs)}

  if info_path is not None:
    # newline='' is required by the csv module to avoid blank rows on Windows.
    with open(info_path, 'w', newline='') as out:
      csv.writer(out).writerows(path_to_class.items())

  return dirs, dir_to_i, path_to_class


# Build the CSV index once; on later runs the file already exists and the
# helper refuses to overwrite it.
try:
  index_info = get_dirs_and_count(zip_path, train_path, info_path)
except FileExistsError:
  print("File exists")
else:
  # Only show the first 1000 characters of the (large) result.
  print(str(index_info)[:1000])


File exists


In [9]:
class _ImageZipDataset(tdata.Dataset):
  """Map-style dataset that decodes images directly out of an open zip.

  Args:
      zip_file (ZipFile): an already-open archive; this class never closes it.
      samples: indexable sequence of `(path_in_zip, target)` pairs.
      transform (callable, optional): applied to each decoded RGB image.

  NOTE(review): every item read goes through the same ZipFile handle; confirm
  this is safe before using it with a multi-worker DataLoader.
  """

  def __init__(self, zip_file: ZipFile, samples, transform=None) -> None:
    self.zip_file, self.samples, self.transform = zip_file, samples, transform

  def __getitem__(self, index):
    """Return `(image, target)` for sample `index`, transformed if configured."""
    member, label = self.samples[index]
    with self.zip_file.open(member) as handle:
      # `.convert` produces the decoded image while the member is still open.
      image = Image.open(handle).convert('RGB')
    return (image if self.transform is None else self.transform(image)), label

  def __len__(self):
    """Number of `(path, target)` samples."""
    return len(self.samples)


In [4]:
class ImageZipDatasetWrapper(tdata.Dataset):
  """Holds the metadata of an image dataset stored inside a zip archive.

  The sample table is read from a header-less ``path,target`` CSV (the
  format written by `get_dirs_and_count`). The archive itself is only opened
  inside `dataset()`, which yields an indexable dataset bound to the open
  file.

  Note: this class defines `__len__` but not `__getitem__`; index the object
  yielded by `dataset()` instead.
  """

  def __init__(self, zip_path, dir_path, info_path, transform=None) -> None:
    """
    Args:
        zip_path (str): path of the zip archive on disk.
        dir_path (str): folder inside the archive whose sub-folders are the
            classes.
        info_path (str): CSV index. Reused if it already exists, otherwise
            created via `get_dirs_and_count`.
        transform (callable, optional): forwarded to the yielded dataset.

    Raises:
        FileNotFoundError: if the zip archive is missing, or the CSV index is
            still missing after attempting to build it.
    """
    if not os.path.exists(zip_path):
      raise FileNotFoundError("Provided zip file does not exist at path")

    self.zip_path = zip_path

    if os.path.exists(info_path):
      # Reuse the cached index. `get_dirs_and_count` raises FileExistsError
      # for an existing CSV, so only the class list is rebuilt here.
      # NOTE(review): cached targets are only valid if the archive's class
      # folders have not changed since the CSV was written.
      self.classes = self._list_class_dirs(zip_path, dir_path)
      self.classes_to_idx = {c: i for i, c in enumerate(self.classes)}
    else:
      self.classes, self.classes_to_idx, _ = get_dirs_and_count(
          zip_path, dir_path, info_path)

    if not os.path.exists(info_path):
      raise FileNotFoundError("CSV File does not exist")

    self.samples = self.load_csv(info_path)
    self.transform = transform

  @staticmethod
  def _list_class_dirs(zip_path, dir_path):
    """Return the immediate sub-folder names of `dir_path` in the archive.

    Uses the same `iterdir()` + `is_dir()` listing as `get_dirs_and_count`,
    so the order (and therefore the class indices) matches.
    """
    # zipfile.Path only treats slash-terminated paths as directories.
    if dir_path and not dir_path.endswith('/'):
      dir_path += '/'
    with ZipFile(zip_path, 'r') as archive:
      return [p.name for p in Path(archive, dir_path).iterdir() if p.is_dir()]

  def load_csv(self, info_path):
    """Load the header-less ``path,target`` CSV into an ``(N, 2)`` array."""
    try:
      # header=None: the index file has no header row, and the default would
      # silently consume the first sample as column names.
      return pd.read_csv(info_path, header=None).to_numpy()
    except pd.errors.EmptyDataError:
      # An archive without any image files yields an empty CSV.
      return np.empty((0, 2), dtype=object)

  @contextmanager
  def dataset(self):
    """Open the archive and yield an `_ImageZipDataset` over `self.samples`.

    The archive is closed when the `with` block exits, so the yielded dataset
    must not be indexed afterwards.
    """
    with ZipFile(self.zip_path, 'r') as z:
      res = _ImageZipDataset(
          zip_file=z,
          samples=self.samples,
          transform=self.transform
      )
      yield res

  def __len__(self):
    """Number of samples in the loaded CSV index."""
    return len(self.samples)


In [11]:
# Wrap the training split; reads the class list and the CSV sample index.
train_dataset_wrapper = ImageZipDatasetWrapper(
    zip_path=zip_path, dir_path=train_path, info_path=info_path)


In [12]:
with train_dataset_wrapper.dataset() as dataset:
  # Index while the archive is open; it is closed when this block exits.
  first_sample = dataset[0]
  print(first_sample)


(<PIL.Image.Image image mode=RGB size=200x200 at 0x26C5A20A4F0>, 0)
