# Expert Decision

## Imports

In [None]:
# Detect Google Colab by trying to import `google.colab`. On success the cell
# installs webdataset (IPython `!` shell escape, so this cell only runs under
# IPython/Jupyter), keeps `drive` for mounting Google Drive later, and sets
# `use_google_drive` so the path configuration cell picks the Drive paths.
use_google_drive = False

try:
  import google.colab
  from google.colab import drive
  !pip install webdataset
  use_google_drive = True
except Exception:
  # Not on Colab (presumably the google.colab import failed): keep local paths.
  pass


In [None]:
import matplotlib.pyplot as plt
import webdataset as wds
from itertools import islice
import os


## Definitions and Parameters

In [None]:
# Which derived datasets to produce (4: expert-accepted samples only,
# 5: accepted vs. rejected samples with a binary label).
write_dataset_4, write_dataset_5 = True, True

# Which archives to produce per dataset: the full archive and/or the
# 60/20/20 training/validation/test split.
write_fulldataset, write_splitdatasets = True, True

# Upper bound on samples written per dataset; 0 means no limit.
max_image_count = 0


In [None]:
# Resolve where the source archive, the expert-decision text files and the
# generated archives live. The folder values keep a trailing separator
# because they are concatenated directly with file names elsewhere.
if use_google_drive:
  dataset_file = "file:///content/gdrive/MyDrive/ColabData/amazon/shoes-224-full-3.tar"
  metadata_folder = "/content/gdrive/MyDrive/ColabData/amazon/"
  dataset_folder = "/content/gdrive/MyDrive/ColabData/amazon/"

  drive.mount("/content/gdrive")
else:
  # The archive is addressed with a file:// URL, so Windows separators
  # from os.getcwd() are normalised to "/".
  dataset_file = f"file://{os.getcwd()}/dataset/shoes-224-full-3.tar".replace('\\', '/')
  # os.path.join(name, "") appends the platform separator: "metadata\\" on
  # Windows (as before) and "metadata/" on Linux/macOS, where a hard-coded
  # backslash would become part of the file name instead of a directory.
  metadata_folder = os.path.join("metadata", "")
  dataset_folder = os.path.join("dataset", "")


In [None]:
# Colour labels; a sample's integer "cls" value indexes into this list.
classes = [
  "black", "white", "gray", "red", "green", "blue",
  "orange", "purple", "yellow", "pink", "brown", "multicolor",
]


## Helper Functions

In [None]:
def torch_imadd(index, f, img, label, color):
  """Draw `img` into cell `index` of a 10x10 subplot grid on figure `f`.

  The new subplot becomes the current pyplot axes. Axis decorations are
  hidden, and `label` is written left-aligned above the image in the text
  colour `color`.
  """
  grid_rows, grid_cols = 10, 10
  f.add_subplot(grid_rows, grid_cols, index)

  plt.imshow(img)
  plt.axis('off')
  title_font = {"color": color}
  plt.title(label, loc='left', fontdict=title_font)


In [None]:
def draw_canvas(dataset, include=None, exclude=None, show_class=0):
  """Plot up to 100 images of one class on a 10x10 grid for manual review.

  Note: updates the global matplotlib rcParams (font size, line width).

  Args:
    dataset: iterable of ``(jpg, cls, key)`` tuples.
    include: optional collection of keys. When non-empty, only these keys
      are drawn; None or empty means "draw everything not excluded".
    exclude: optional collection of keys that are never drawn.
    show_class: integer class whose images are drawn; others are skipped.

  Returns:
    The keys of the drawn images in drawing order. List position ``i``
    corresponds to the label ``i + 1`` shown above each image.
  """
  max_images = 100  # the canvas is a 10x10 grid

  # None instead of [] avoids shared mutable defaults; sets make the
  # per-sample membership checks O(1) instead of scanning a list.
  include = set(include) if include else None
  exclude = set(exclude) if exclude else set()

  f = plt.figure(figsize=(10, 10), dpi=200)
  plt.rcParams.update({'font.size': 4})
  plt.rcParams.update({'lines.linewidth': 0.5})
  plt.subplots_adjust(wspace=-0.6, hspace=0.6)
  images = []

  for jpg, cls, key in dataset:
    if int(cls) != show_class or key in exclude:
      continue
    if include is not None and key not in include:
      continue

    index = len(images) + 1  # 1-based grid position, also used as the label
    torch_imadd(index, f, jpg, f"{index}", "black")
    images.append(key)

    # Once the grid is full nothing else can be drawn, so stop reading the
    # archive instead of scanning the remaining samples.
    if len(images) >= max_images:
      break

  return images


In [None]:
def addtolist(inputlist: list, *items, source=None):
  """Append the keys shown at the given canvas positions to `inputlist`.

  Args:
    inputlist: list extended in place (e.g. `invalid` or `misclassified`).
    *items: 1-based positions, as labelled above each image on the canvas.
    source: keys in canvas order. Defaults to the module-level variable
      `list` assigned by the drawing cell (note: that name shadows the
      builtin `list`).

  Keys already present in `inputlist` are not added again.

  Raises:
    RuntimeError: if no `source` is given and no drawn key list exists yet.
    IndexError: if a position is outside ``1..len(source)``.
  """
  if source is None:
    # Looked up at call time because the notebook rebinds `list` each time
    # a canvas is drawn. Without this check the builtin `list` type would be
    # indexed and a `list[N]` type alias appended instead of a key.
    source = globals().get("list")
    if source is None:
      raise RuntimeError("no drawn image list found; draw a canvas first")

  for position in items:
    # Position 0 would otherwise silently select the last key via index -1.
    if not 1 <= position <= len(source):
      raise IndexError(f"canvas position {position} out of range 1..{len(source)}")
    key = source[position - 1]
    if key not in inputlist:
      inputlist.append(key)


def savelisttofile(inputlist: list, filename):
  """Write one item per line to ``metadata_folder + filename``, overwriting it."""
  with open(f'{metadata_folder}{filename}', 'w', encoding='utf-8') as fp:
    fp.writelines(f"{item}\n" for item in inputlist)


def loadlistfromfile(outputlist: list, filename):
  """Append each non-blank line of ``metadata_folder + filename`` to `outputlist`.

  Trailing whitespace (including the newline) is stripped. Blank lines are
  skipped so a stray empty line does not turn into an empty key.
  """
  with open(f'{metadata_folder}{filename}', 'r', encoding='utf-8') as fp:
    outputlist.extend(line.rstrip() for line in fp if line.strip())


## Load Dataset

In [None]:
# Review pipeline over the source archive. Only the "jpg" field is decoded
# (to a PIL image); each sample is yielded as an (image, cls, key) tuple.
dataset = wds.WebDataset(dataset_file)
dataset = dataset.decode("pil", only="jpg")
dataset = dataset.to_tuple("jpg", "cls", "__key__")


## Draw On Canvas and Verify Classification

In [None]:
# Step 1 - run at the start of every class: choose the class to review and
# clear the per-class decision lists.
cls = 11
invalid, misclassified = [], []
# NOTE: `list` shadows the builtin, but addtolist() and step 4 read the drawn
# keys under this name, so it is kept.
list = []


In [None]:
# Step 2 - draw the next 100 images of class `cls`, skipping every key that
# is already flagged. If all of them are good go to step 4, otherwise go to
# step 3.
exclude = [*invalid, *misclassified]
list = draw_canvas(dataset, exclude=exclude, show_class=cls)


In [None]:
# Step 3 - flag bad images by their canvas position (the number drawn above
# each image), then go back to step 2. Usage examples:
#   addtolist(invalid, 100)
#   addtolist(misclassified, 100)
for position in (98, 100):
  addtolist(misclassified, position)


In [None]:
# Step 4 - persist this class's decisions to the metadata folder and finish.
for decision, keys in (("valid", list), ("invalid", invalid), ("misclassified", misclassified)):
  savelisttofile(keys, f"expert_{cls}_{decision}.txt")


In [None]:
# Step 5 - draw the final invalid and misclassified lists for a last check.
# draw_canvas treats an empty `include` as "draw everything", so an empty
# list is reported instead of drawn; otherwise the canvas would show the
# first 100 images of the class as if they had been flagged.
for name, keys in (("invalid", invalid), ("misclassified", misclassified)):
  if not keys:
    print(f"no {name} images for class {cls}")
    continue
  _ = draw_canvas(dataset, include=keys, show_class=cls)


## Combine Outputs

In [None]:
# Merge the per-class decision files written in step 4 into three global
# lists plus a key -> decision map, and save the merged lists.
combined = {}
valids = []
invalids = []
misclassifieds = []

for cls in range(len(classes)):
  loadlistfromfile(valids, f"expert_{cls}_valid.txt")
  loadlistfromfile(invalids, f"expert_{cls}_invalid.txt")
  loadlistfromfile(misclassifieds, f"expert_{cls}_misclassified.txt")

# The lists above are cumulative, so the map only needs building once after
# loading. (Rebuilding it inside the loop just repeated the same work.)
# Later assignments win: misclassified overrides invalid, which overrides valid.
for item in valids:
  combined[item] = 'valid'

for item in invalids:
  combined[item] = 'invalid'

for item in misclassifieds:
  combined[item] = 'misclassified'


savelisttofile(valids, "expert_all_valid.txt")
savelisttofile(invalids, "expert_all_invalid.txt")
savelisttofile(misclassifieds, "expert_all_misclassified.txt")


## Load Previously Saved Decisions

In [None]:
# Reload the merged decision lists so the dataset preparation below can run
# without redoing the per-class review.
combined = {}
valids = []
invalids = []
misclassifieds = []

loadlistfromfile(valids, "expert_all_valid.txt")
loadlistfromfile(invalids, "expert_all_invalid.txt")
loadlistfromfile(misclassifieds, "expert_all_misclassified.txt")

# Applied in this order so that, for a key in several lists, misclassified
# beats invalid and invalid beats valid.
for decision, keys in (('valid', valids), ('invalid', invalids), ('misclassified', misclassifieds)):
  combined.update(dict.fromkeys(keys, decision))


## Prepare Dataset(s)

In [None]:
def match_metadata(sample: dict):
  """Attach the expert decision for this sample's key.

  Sets ``sample["__decision__"]`` to the value stored for ``sample["__key__"]``
  in the module-level `combined` map, or None when the key has no recorded
  decision. Mutates and returns `sample`.
  """
  sample["__decision__"] = combined.get(sample["__key__"])
  return sample


def convert_decision_to_cls(sample: dict):
  """Replace ``sample["cls"]`` with a binary validity label.

  'invalid' becomes 0 and 'valid' becomes 1. Mutates and returns `sample`.

  Raises:
    ValueError: for any other decision (e.g. 'misclassified' or None). It is
      still an ``Exception`` subclass, so existing broad handlers keep working.
  """
  decision = sample["__decision__"]
  if decision == 'invalid':
    sample["cls"] = 0
  elif decision == 'valid':
    sample["cls"] = 1
  else:
    raise ValueError(f"decision not defined: {decision!r}")

  return sample


In [None]:
def _reviewed_samples():
  """Return a fresh pipeline over the source archive with decisions attached."""
  return wds.WebDataset(dataset_file).map(match_metadata)


# NOTE(review): the shuffle below is presumably unseeded, so sample order (and
# therefore write2TARs' positional train/validation/test split) may differ
# between runs. Confirm this if reproducible splits are required.

# Dataset 4: expert-accepted samples only, keeping their original cls label.
dataset_4 = (
  _reviewed_samples()
  .select(predicate=lambda r: r["__decision__"] == 'valid')
  .decode("pil", only="jpg")
  .to_tuple("__key__", "jpg", "cls", "sort")
  .shuffle(100)
)

# Dataset 5: accepted and rejected samples, relabelled 1 (valid) / 0 (invalid).
dataset_5 = (
  _reviewed_samples()
  .select(predicate=lambda r: r["__decision__"] in ('invalid', 'valid'))
  .map(convert_decision_to_cls)
  .decode("pil", only="jpg")
  .to_tuple("__key__", "jpg", "cls", "sort")
  .shuffle(100)
)


In [None]:
# Preview one sample of dataset 4: its colour class, its "sort" field parsed
# as a base-2 number, and the image. next(..., None) replaces a `for ...: break`
# loop, which on an empty dataset would leave `cls` holding a stale value from
# an earlier cell and fail later on an unbound `jpg`.
sample_4 = next(iter(dataset_4), None)
if sample_4 is None:
  print("dataset 4 is empty")
else:
  key, jpg, cls, sort = sample_4
  print(classes[int(cls)])
  print(f"sort:{int(sort, 2)}")
  plt.imshow(jpg)


In [None]:
# Preview one sample of dataset 5 with its binary label. next(..., None)
# replaces a `for ...: break` loop, which on an empty dataset would silently
# reuse `jpg`/`cls` left over from the dataset 4 preview.
sample_5 = next(iter(dataset_5), None)
if sample_5 is None:
  print("dataset 5 is empty")
else:
  key, jpg, cls, sort = sample_5
  print("valid" if int(cls) == 1 else "invalid")
  plt.imshow(jpg)


## Write Dataset(s)

In [None]:
def write2TARs(dataset_id, dataset, folder):
  """Write `dataset` to the full and/or training/validation/test tar archives.

  Archives are named ``shoes-224-{full,training,validation,test}-{dataset_id}.tar``
  inside `folder`. In every run of ten consecutive samples, six go to
  training, two to validation and two to test (60/20/20).

  Controlled by the module-level settings `write_fulldataset`,
  `write_splitdatasets` and `max_image_count` (0 = no limit). Dataset 5 never
  gets a full archive because it is only used for the CNN.

  Only archives that are actually written get opened. Opening a TarWriter
  creates/truncates its file, so opening all four unconditionally would
  overwrite existing archives with empty files whenever a setting is off.

  Args:
    dataset_id: number used in the archive file names.
    dataset: iterable of ``(key, jpg, cls, sort)`` tuples.
    folder: output directory.

  Returns:
    ``(full_size, training_size, validation_size, test_size)`` sample counts.
  """
  from contextlib import ExitStack

  def archive_path(part):
    return os.path.join(folder, f"shoes-224-{part}-{dataset_id}.tar")

  write_full = write_fulldataset and dataset_id != 5  # dataset 5 is only used for CNN
  write_split = write_splitdatasets

  # islice pulls exactly max_image_count samples and never one more.
  samples = islice(dataset, max_image_count) if max_image_count > 0 else dataset

  full_size, training_size, validation_size, test_size = 0, 0, 0, 0
  with ExitStack() as stack:
    full = stack.enter_context(wds.TarWriter(archive_path("full"))) if write_full else None
    if write_split:
      train = stack.enter_context(wds.TarWriter(archive_path("training")))
      validation = stack.enter_context(wds.TarWriter(archive_path("validation")))
      test = stack.enter_context(wds.TarWriter(archive_path("test")))

    for i, (key, jpg, label, sort) in enumerate(samples):
      item = {"__key__": key, "jpg": jpg, "cls": label, "sort": sort}

      if write_full:
        full.write(item)
        full_size += 1

      if write_split:
        slot = i % 10
        if slot < 6:
          train.write(item)
          training_size += 1
        elif slot < 8:
          validation.write(item)
          validation_size += 1
        else:
          test.write(item)
          test_size += 1

  return full_size, training_size, validation_size, test_size


In [None]:
# Write every requested dataset and report how many samples each archive got.
if write_dataset_4:
  results = write2TARs(4, dataset_4, dataset_folder)
  full_count, training_count, validation_count, test_count = results

  print(f"""Dataset 4 is written to files:

  # of Total images: {full_count}

  # of Training images: {training_count}
  # of Validation images: {validation_count}
  # of Test images: {test_count}
  """)


if write_dataset_5:
  # Dataset 5 gets no full archive, so only the split counts are reported.
  results = write2TARs(5, dataset_5, dataset_folder)
  full_count, training_count, validation_count, test_count = results

  print(f"""Dataset 5 is written to files:

  # of Training images: {training_count}
  # of Validation images: {validation_count}
  # of Test images: {test_count}
  """)
