In [1]:
import os
import pandas as pd
import numpy as np
import torch
import torchvision.transforms as transforms
from tqdm.auto import tqdm
from PIL import Image
from collections import defaultdict
import math
from pathlib import Path

In [2]:
image_folder = "./"
csvfile = "./dev.csv"

df = pd.read_csv(csvfile)

In [3]:
id_to_labels = {str(row['fname']): row['labels'].split(',') for _, row in df.iterrows()}
unique_labels = sorted(set(label for labels in id_to_labels.values() for label in labels))
label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
num_classes = len(label_to_idx)

In [4]:
def pad_num(n, width):
  return str(n).zfill(width)

In [5]:
file_paths = list(Path(image_folder).rglob("*.png"))
max_images = 1000
shard_size = 100
image_tensor = []
label_tensor = []
processed_examples = 0
processed_shards = 0
rel_idxs = []
max_images = min(max_images, len(file_paths))
num_shards = (max_images + shard_size - 1) // shard_size
pad_size = math.log10(num_shards) + 1

In [6]:
def save_chunk(imgs, lbls, idx):
  shard_data = {"imgs": imgs, "lbls": lbls}
  padded_num = pad_num(idx, int(pad_size))
  torch.save(shard_data, os.path.join(".", "processed", f"fsd_{padded_num}.pth"))

In [9]:
transform = transforms.Compose([
  transforms.Resize((224, 224)),
  transforms.ToTensor(),
])

In [10]:
# Given a shard, it returns 2 tensors, one for image and another for the label vector which it also creates 
def process_files(shard_paths):
  img_tensor = []
  lbls_tensor = []
  for filename in shard_paths:
    img_id = filename.stem # get file name
    img_path = os.path.join(image_folder, filename)

    if img_id in id_to_labels:
      img = Image.open(img_path).convert("RGB")
      img = transform(img)
      img_tensor.append(img)

      label_vector = torch.zeros(num_classes, dtype=torch.float32)
      for label in id_to_labels[img_id]:
        if label in label_to_idx:
          label_vector[label_to_idx[label]] = 1
      lbls_tensor.append(label_vector)
  img_tensor = torch.stack(img_tensor)
  lbls_tensor = torch.stack(lbls_tensor)

  class_idxs = [torch.where(lbls_tensor[:, c] == 1)[0] for c in range(num_classes)]
  
  return img_tensor, lbls_tensor, class_idxs

In [11]:
acc_shard_class_idxs = []
class_idxs = []

print(f"Processing {max_images} files")
for shard_idx in tqdm(range(0, max_images, shard_size), desc="Processing chunks"):
  shard_paths = file_paths[shard_idx: min(shard_idx + shard_size, max_images)]
  imgs, lbls, shard_class_idxs = process_files(shard_paths)
  save_chunk(imgs, lbls, processed_shards)
  acc_shard_class_idxs.append(shard_class_idxs) # store a list of precalculated indices which are part of each class
  
  processed_shards += 1 # for shard naming scheme
  
  processed_examples += len(imgs) # update total examples for tracking in rel_idxs
  rel_idxs.append(processed_examples - 1)  # store index of the latest example of this chunk so we can identify which chunk an index is from
  
class_idxs = [torch.cat(shard_class_idxs) for shard_class_idxs in zip(*acc_shard_class_idxs)]
torch.save({
  'rel_idxs': rel_idxs,
  'class_idxs': class_idxs
}, os.path.join("./processed", f"fsd_meta.pth"))

Processing 1000 files


Processing chunks:   0%|          | 0/10 [00:00<?, ?it/s]

In [12]:
print(unique_labels)

['Accelerating_and_revving_and_vroom', 'Accordion', 'Acoustic_guitar', 'Aircraft', 'Alarm', 'Animal', 'Applause', 'Bark', 'Bass_drum', 'Bass_guitar', 'Bathtub_(filling_or_washing)', 'Bell', 'Bicycle', 'Bicycle_bell', 'Bird', 'Bird_vocalization_and_bird_call_and_bird_song', 'Boat_and_Water_vehicle', 'Boiling', 'Boom', 'Bowed_string_instrument', 'Brass_instrument', 'Breathing', 'Burping_and_eructation', 'Bus', 'Buzz', 'Camera', 'Car', 'Car_passing_by', 'Cat', 'Chatter', 'Cheering', 'Chewing_and_mastication', 'Chicken_and_rooster', 'Child_speech_and_kid_speaking', 'Chime', 'Chink_and_clink', 'Chirp_and_tweet', 'Chuckle_and_chortle', 'Church_bell', 'Clapping', 'Clock', 'Coin_(dropping)', 'Computer_keyboard', 'Conversation', 'Cough', 'Cowbell', 'Crack', 'Crackle', 'Crash_cymbal', 'Cricket', 'Crow', 'Crowd', 'Crumpling_and_crinkling', 'Crushing', 'Crying_and_sobbing', 'Cupboard_open_or_close', 'Cutlery_and_silverware', 'Cymbal', 'Dishes_and_pots_and_pans', 'Dog', 'Domestic_animals_and_pets',

In [13]:
meta_file_test = torch.load("./processed/fsd_meta.pth")

  meta_file_test = torch.load("./processed/fsd_meta.pth")
