# Preprocess Dataset

## Imports

In [None]:
use_google_drive = False

try:
  import google.colab
  from google.colab import drive
  !pip install webdataset
  use_google_drive = True
except Exception:
  pass


In [None]:
import matplotlib.pyplot as plt
import webdataset as wds
from torchvision import transforms
import os
import torch


## Definitions and Parameters

In [None]:
# Stop after this many valid images have been written; 0 disables the limit.
max_image_count = 0
# Which outputs write2TARs produces: the single "full" archive and/or the
# training/validation/test split archives.
write_fulldataset, write_splitdatasets = True, True

# torchvision architecture handed to torch.hub.load; it must match the
# checkpoint in model_load_file (other options: "resnet18", "vgg11_bn").
model_name = "alexnet"


In [None]:
# Resolve where the trained model and the tar datasets live: on the mounted
# Google Drive when running on Colab, otherwise relative to the working dir.
if use_google_drive:
  model_load_file = "/content/gdrive/MyDrive/ColabData/amazon/shoes-model-5-singlecolor-rgb.model"
  dataset_file = "file:///content/gdrive/MyDrive/ColabData/amazon/shoes-224-full-3.tar"
  dataset_folder = "/content/gdrive/MyDrive/ColabData/amazon/"

  drive.mount("/content/gdrive")
else:
  model_load_file = "model/shoes-model-5-singlecolor-rgb.model"
  # The dataset is addressed as a file:// URL, so Windows backslashes in the
  # working directory are converted to forward slashes.
  dataset_file = f"file://{os.getcwd()}/dataset/shoes-224-full-3.tar".replace('\\', '/')
  # Output directory with the platform's own trailing separator ("dataset\\"
  # on Windows, "dataset/" on POSIX). A hard-coded backslash would make
  # os.path.join create a directory literally named "dataset\" on POSIX.
  dataset_folder = os.path.join("dataset", "")

# Run inference on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"


In [None]:
# Colour labels; the integer in a sample's "cls" entry indexes into this list.
classes = """black white gray red green blue orange purple yellow pink
brown multicolor""".split()


## Load Model and Dataset


In [None]:
# Build the torchvision architecture named by `model_name` with a 2-way head
# (class 0 is treated as invalid and class 1 as valid further below), then load
# the trained weights and switch to inference mode.
model = torch.hub.load('pytorch/vision:v0.10.0', model_name, pretrained=False, num_classes=2)
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only one.
model.load_state_dict(torch.load(model_load_file, map_location=device))
model.eval()
convnet = model.to(device)

# Smoke test: one random 3x224x224 input should produce a (1, 2) output.
# no_grad avoids building an autograd graph for this throwaway pass.
test_input = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
  test_output = convnet(test_input)
print(test_output.shape)


In [None]:
# Built once and reused for every sample instead of per call.
_to_tensor = transforms.ToTensor()


def match_metadata(sample):
  """Run the validation model on one sample and attach its verdict.

  Adds two entries to *sample* and returns it:
    * ``validation``: the predicted class index (argmax of the logits),
      encoded with ``bin`` -- e.g. ``"0b1"``; read back via ``int(v, 2)``.
    * ``validation_perc``: the softmax probability of the top class in
      whole percent, also ``bin``-encoded.

  Args:
    sample: webdataset sample whose ``jpg`` entry is a decoded PIL image
      (assumed 3-channel RGB at the model's input size -- confirm).
  """
  # ToTensor yields (C, H, W); add the batch dimension the model expects.
  image_tensor = _to_tensor(sample['jpg']).unsqueeze(0).to(device)

  with torch.no_grad():
    image_output = convnet(image_tensor)

  # .item() yields plain Python int/float, so bin() and round() behave the
  # same regardless of the installed numpy version.
  image_validation = image_output.argmax(dim=1)[0].item()
  validation_perc = image_output.softmax(dim=1).max().item()

  sample["validation"] = bin(image_validation)
  sample["validation_perc"] = bin(round(validation_perc * 100))
  return sample


In [None]:
# Streaming pipeline over the source tar, assembled stage by stage:
# decode only the JPEG entries to PIL images, annotate every sample with the
# model's verdict, then shuffle through a 100-sample buffer.
dataset = wds.WebDataset(dataset_file)
dataset = dataset.decode("pil", only="jpg")
dataset = dataset.map(match_metadata)
dataset = dataset.shuffle(100)


In [None]:
# Pull a single sample to sanity-check decoding and the attached metadata.
# next(..., None) makes an empty dataset fail loudly instead of leaving
# `item` undefined.
item = next(iter(dataset), None)
if item is None:
  raise RuntimeError(f"No samples could be read from {dataset_file}")

print(f"""{item['__key__']}
class: {classes[int(item['cls'])]}
validation: {'valid' if int(item['validation'],2)==1 else 'invalid'} {int(item['validation_perc'],2)}%
image sort: {int(item['sort'],2)}""")

plt.imshow(item['jpg'])


## Prepare Dataset(s)

In [None]:
def write2TARs(dataset, folder, dataset_id=6):
  """Write the model-validated samples of *dataset* to webdataset tar files.

  Each sample's ``validation`` entry (a base-2 string such as ``"0b1"``, as
  set by ``match_metadata``) decides its fate: 0 counts as invalid and the
  sample is dropped; 1 counts as valid. Valid samples are written to
  ``shoes-224-full-<id>.tar`` when the global ``write_fulldataset`` is set,
  and/or -- by their position among valid samples -- 60% / 20% / 20% into the
  training / validation / test tars when ``write_splitdatasets`` is set.
  Only enabled archives are opened, so disabled outputs are neither created
  nor truncated. Stops after ``max_image_count`` valid samples when that
  global is positive.

  Args:
    dataset: iterable of sample dicts carrying a ``validation`` entry.
    folder: directory the tar files are written into.
    dataset_id: number embedded in the output file names (default 6).

  Returns:
    ``(invalids, valids, training_size, validation_size, test_size)``;
    ``valids`` counts every valid sample, whether or not the full archive
    is written.
  """
  import contextlib

  def open_writer(stack, split, enabled):
    # Disabled outputs get None so their file on disk is never touched.
    if not enabled:
      return None
    path = os.path.join(folder, f"shoes-224-{split}-{dataset_id}.tar")
    return stack.enter_context(wds.TarWriter(path))

  invalids, valids, training_size, validation_size, test_size = 0, 0, 0, 0, 0
  with contextlib.ExitStack() as stack:
    full = open_writer(stack, "full", write_fulldataset)
    train = open_writer(stack, "training", write_splitdatasets)
    validation = open_writer(stack, "validation", write_splitdatasets)
    test = open_writer(stack, "test", write_splitdatasets)

    for item in dataset:
      result = int(item['validation'], 2)

      if result == 0:
        invalids += 1

      elif result == 1:
        # 0-based position among valid samples; drives the 6/2/2 split.
        position = valids
        valids += 1

        if full is not None:
          full.write(item)

        if write_splitdatasets:
          if position % 10 < 6:
            train.write(item)
            training_size += 1
          elif position % 10 < 8:
            validation.write(item)
            validation_size += 1
          else:
            test.write(item)
            test_size += 1

        if max_image_count > 0 and valids >= max_image_count:
          break

  return invalids, valids, training_size, validation_size, test_size


In [None]:
# Filter the annotated stream into the enabled archives, then report counts.
results = write2TARs(dataset, dataset_folder)
n_invalid, n_valid, n_training, n_validation, n_test = results

print(f"""Validated dataset is written to files:
# of Invalid images: {n_invalid}

# of Valid images: {n_valid}
# of Training images: {n_training}
# of Validation images: {n_validation}
# of Test images: {n_test}
""")
