# Preprocess Dataset

## Imports

In [None]:
# Default to local paths; flipped to True below only if the Colab setup succeeds.
use_google_drive = False

try:
  # Presumably these imports only succeed inside a Google Colab runtime.
  import google.colab
  from google.colab import drive
  # IPython shell escape: install webdataset into the current runtime.
  !pip install webdataset
  use_google_drive = True
except Exception:
  # Best-effort detection: on any failure keep use_google_drive = False.
  pass


In [None]:
import matplotlib.pyplot as plt
import webdataset as wds
from torchvision import transforms
import csv
from torchvision.transforms.functional import pad
import os


## Definitions and Parameters

In [None]:
# Switches read by write2TARs: write the combined archive and/or the
# training/validation/test archives.
write_fulldataset = True
write_splitdatasets = True
# Maximum number of samples write2TARs processes; 0 means no limit.
max_image_count = 0 #set 0 for all


In [None]:
# Resolve the metadata CSV, the source tar URL and the output folder for the
# current environment.
if use_google_drive:
  metadata_source = "/content/gdrive/MyDrive/ColabData/amazon/metadata.csv"
  tar_dataset = "file:///content/gdrive/MyDrive/ColabData/amazon/abo-images-small.tar"
  dataset_folder = "/content/gdrive/MyDrive/ColabData/amazon/"

  drive.mount("/content/gdrive")
else:
  # Build local paths with os.path.join so they use the platform separator:
  # identical to the former hard-coded backslash paths on Windows, and valid
  # relative paths on POSIX systems as well.
  metadata_source = os.path.join("metadata", "metadata.csv")
  # The URL form always uses forward slashes, whatever os.getcwd() returns.
  tar_dataset = f"file://{os.getcwd()}/dataset/abo-images-small.tar".replace('\\', '/')
  # Joining with "" keeps the trailing separator the original value had.
  dataset_folder = os.path.join("dataset", "")


## Helper Functions

In [None]:
# Color labels; a label's position in this list is the integer class id
# returned by tocls and accepted by fromcls.
classes = ["black", "white", "gray", "red", "green", "blue", "orange", "purple", "yellow", "pink", "brown", "multicolor"]

def fromcls(cls):
  """Return the color name stored at index ``cls`` of ``classes``."""
  label = classes[cls]
  return label

def tocls(cls):
  """Return the index of color name ``cls`` in ``classes``, or None if it is not listed."""
  if cls not in classes:
    return None
  return classes.index(cls)


# Smoke-test the label helpers: for each color print its class id, then the
# name obtained by mapping that id back.
for color in ("green", "red"):
  print(tocls(color))
  print(fromcls(tocls(color)))


In [None]:
def match_metadata(sample: dict):
  """Attach the metadata row matching ``sample["__key__"]`` to the sample.

  The row is looked up in the module-level ``metadata`` dict. Columns 0-6 are
  copied into the double-underscore fields, column 7 is stored as an int under
  ``"sort"``, and ``"cls"`` is the class id of the valid color (column 3).
  When no row exists, the fields are set to None, the valid color to
  'not-valid' and ``"sort"`` to 99.

  The sample dict is modified in place and returned.
  """
  # Metadata columns 0..6, in row order.
  text_fields = (
      "__item_id__",
      "__color__",
      "__extracted_color__",
      "__valid_color__",
      "__product_type__",
      "__image_id__",
      "__country__",
  )

  key = sample["__key__"]

  if key not in metadata:
    # Unknown key: fill in placeholders.
    for field in text_fields:
      sample[field] = None
    sample["__valid_color__"] = 'not-valid'
    sample["sort"] = 99
    sample["cls"] = None
    return sample

  row = metadata[key]
  for column, field in enumerate(text_fields):
    sample[field] = row[column]
  sample["sort"] = int(row[7])
  sample["cls"] = tocls(row[3])
  return sample


In [None]:
class SquarePad(object):
  """Transform that pads an image to a square, keeping the content centered.

  Args:
    fill: fill value handed to ``pad``; the string 'auto' uses the color of
      the image's top-left pixel instead.
    padding_mode: padding mode handed to ``pad``.
  """

  def __init__(self, fill=0, padding_mode='constant'):
    self.fill = fill
    self.padding_mode = padding_mode

  def __call__(self, img):
    """Return ``img`` padded on its shorter dimension so that it is square."""
    # 'auto' samples the top-left pixel as the background color.
    if self.fill == 'auto':
      fill = img.getpixel((0, 0))
    else:
      fill = self.fill
    return pad(img, self.get_padding(img), fill, self.padding_mode)

  def __repr__(self):
    # Three placeholders, three arguments (class name, fill, padding_mode).
    return '{0}(fill={1}, padding_mode={2})'.format(
        self.__class__.__name__, self.fill, self.padding_mode)

  def get_padding(self, image):
    """Compute the padding that makes ``image`` square.

    Args:
      image: object exposing ``size`` as ``(width, height)``.

    Returns:
      A 4-tuple ``(left, top, right, bottom)``. When the width/height
      difference is odd, the extra pixel goes to the right/bottom side so the
      padded image is exactly square.
    """
    (w, h) = image.size
    missing = abs(w - h)
    before = missing // 2
    after = missing - before  # receives the extra pixel for odd differences
    if w >= h:
      # Wider than tall (or already square): pad top and bottom.
      return (0, before, 0, after)
    # Taller than wide: pad left and right.
    return (before, 0, after, 0)


## Load Metadata and Dataset

In [None]:
# Index every CSV row by its column 5, the value match_metadata compares
# against a sample's "__key__".
metadata = {}

with open(metadata_source, newline='') as metadatacsv:
  metadata.update(
      (row[5], row) for row in csv.reader(metadatacsv, delimiter=',')
  )


In [None]:
# Image preprocessing: pad to a square using the top-left pixel color as the
# background, then apply transforms.Resize(224).
preproc = transforms.Compose([
    SquarePad(fill='auto', padding_mode='constant'),
    transforms.Resize(224),
])

# Build the sample pipeline one stage at a time.
dataset = wds.WebDataset(tar_dataset)
# Attach the metadata fields to every raw sample.
dataset = dataset.map(match_metadata)
# Keep only shoe products whose color was marked valid.
dataset = dataset.select(predicate=lambda r: r["__product_type__"] == 'SHOES' and r["__valid_color__"] != 'not-valid')
# Decode the "jpg" entry as a PIL image.
dataset = dataset.decode("pil", only="jpg")
dataset = dataset.to_tuple("__key__", "jpg", "cls", "sort")
# Leave the key as-is and run the preprocessing on the image.
dataset = dataset.map_tuple(lambda a: a, preproc)


In [None]:
# Pull only the first sample from the pipeline as a quick sanity check.
for key, jpg, cls, sort in dataset:
  break

# Show that sample's color label and its preprocessed image.
print(classes[cls])
plt.imshow(jpg)


## Prepare Dataset(s)

In [None]:
def write2TARs(dataset, folder, dataset_id=3):
  """Write the samples of ``dataset`` to WebDataset tar archives in ``folder``.

  The module-level settings control what is written:
  ``write_fulldataset`` enables ``shoes-224-full-<id>.tar`` (every sample),
  ``write_splitdatasets`` enables the training/validation/test archives
  (samples are distributed 6/2/2 out of every 10 by position), and
  ``max_image_count`` caps the number of processed samples (0 = no cap).

  Only the archives that are enabled get opened, so a disabled output never
  creates or truncates a file left over from an earlier run.

  Args:
    dataset: iterable of sequences ``(key, image, cls, sort)``.
    folder: directory that receives the tar files.
    dataset_id: identifier embedded in the archive file names.

  Returns:
    Tuple ``(full_size, training_size, validation_size, test_size)`` with the
    number of samples written to each archive.
  """
  from contextlib import ExitStack
  from itertools import islice

  def open_writer(stack, split):
    # Open one archive and register it so the ExitStack closes it on exit.
    path = os.path.join(folder, f"shoes-224-{split}-{dataset_id}.tar")
    return stack.enter_context(wds.TarWriter(path))

  # A non-positive max_image_count means "process everything".
  limit = max_image_count if max_image_count > 0 else None

  full_size, training_size, validation_size, test_size = 0, 0, 0, 0
  with ExitStack() as stack:
    full = open_writer(stack, "full") if write_fulldataset else None
    train = validation = test = None
    if write_splitdatasets:
      train = open_writer(stack, "training")
      validation = open_writer(stack, "validation")
      test = open_writer(stack, "test")

    for i, sample in enumerate(islice(dataset, limit)):
      item = {
          "__key__": sample[0],
          "jpg": sample[1],
          "cls": sample[2],
          # Stored as a binary-literal string (e.g. '0b11'), as before.
          "sort": bin(sample[3])
      }

      if full is not None:
        full.write(item)
        full_size += 1

      if train is not None:
        # Positions 0-5 of every 10 -> training, 6-7 -> validation, 8-9 -> test.
        bucket = i % 10
        if bucket < 6:
          train.write(item)
          training_size += 1
        elif bucket < 8:
          validation.write(item)
          validation_size += 1
        else:
          test.write(item)
          test_size += 1

  return full_size, training_size, validation_size, test_size


In [None]:
# Run the export and report how many samples landed in each archive.
results = write2TARs(dataset, dataset_folder)

print("""Dataset is written to files:

# of Total images: {0}

# of Training images: {1}
# of Validation images: {2}
# of Test images: {3}""".format(*results))
