# Extract class labels for PlacesAudio from Places205 image paths

The audio-visual embedding models of Harwath et al. that use PlacesAudio

- [NIPS 2016 model](https://papers.nips.cc/paper/6186-unsupervised-learning-of-spoken-language-with-visual-context.pdf)
- [ACL 2017 model](https://arxiv.org/pdf/1701.07481.pdf)
- [DAVEnet model](https://github.com/dharwath/DAVEnet-pytorch)

are difficult to train from scratch. An initial warm-up could help, but the audio data has no labels. However, the images paired with the audio captions are organized into classes that are visible in their paths (e.g. `c/cottage_garden/gsun_c43911d6f8ff4efb5e99dc6ac7e47a8e.jpg`). The aim of this notebook is to extract the classes from the image paths and use them as classification labels for the audio. Each audio caption has exactly one corresponding image and thus one label.
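
As a quick illustration of the idea, the class name can be pulled out of the example path above with a regular expression. This is a minimal sketch that uses the same pattern as the extraction function in this notebook; the variable names `path` and `label` are only illustrative.

```python
import re

# Example image path quoted in the text above
path = "c/cottage_garden/gsun_c43911d6f8ff4efb5e99dc6ac7e47a8e.jpg"

# Capture everything between the first "/" and the "/gsun" that starts the file name
match = re.search(r"[a-z]?/(.+)/gsun", path)
label = match.group(1) if match else None

print(label)  # cottage_garden
```

Because `(.+)` is greedy, a path with more than one directory level between the leading letter and the file name would be captured as a whole, slashes included.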

First, we define the function that performs the extraction.

In [2]:
import json
import re           # Regexps for pattern matching
import warnings     # Place warnings if any anomalies encountered

# Function that does the extraction from an input file and writes the result to an output file
# The 205 classes are supplied in the classes parameter
def extract_classes(input_file_path, output_file_path, classes): 
    
    print("Input json is {}".format(input_file_path))
    with open(input_file_path) as f:
        inputs = json.load(f)
        
    for i in range(len(inputs["data"])):
        # Replace key "image" with "label"
        inputs["data"][i]["label"] = inputs["data"][i].pop("image")
        # Find the class from the image path using regexp
        match = re.search(r'[a-z]?/(.+)/gsun', inputs["data"][i]["label"])
        if match:
            if match.group(1) in classes:
                # Use an index number instead of the word label
                inputs["data"][i]["label"] = classes.index(match.group(1))
            else:
                warnings.warn("Did not find label '{}' among Places205 classes.".format(match.group(1)))
        else:
            warnings.warn("Matching regexp to '{}' failed".format(inputs["data"][i]["label"]))
    
    print("Writing output to {}\n".format(output_file_path))
    with open(output_file_path, 'w') as f:
        json.dump(inputs, f, indent=4)
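
Note that `classes.index(...)` scans the list on every call. With 205 classes this is cheap, but a dictionary built once returns the same index in constant time. The following is a minimal sketch of that alternative, not part of the function above; the three-entry `classes` list is a hypothetical stand-in for the contents of the real class file, and `class_to_index` is an illustrative name.

```python
# Hypothetical short class list standing in for the 205 Places205 classes
classes = ["abbey", "airport_terminal", "cottage_garden"]

# Build the name -> index mapping once, before looping over the samples
class_to_index = {name: idx for idx, name in enumerate(classes)}

# The mapping agrees with list.index for known classes
assert class_to_index["cottage_garden"] == classes.index("cottage_garden")

# Unknown names return None instead of raising ValueError
missing = class_to_index.get("not_a_class")
print(missing)  # None
```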

Then we run the extraction on different subsets of the PlacesAudio400k dataset. Each paper has its own training subset, and there is a shared validation set:
 - NIPS, ~116k samples, `nips_train.json`
 - ACL, ~214k samples, `acl_train.json`
 - DAVEnet, full PlacesAudio dataset of ~402k samples, `train.json`
 - Validation data, 1k samples, `val.json`

In [3]:
data_path = "/teamwork/t40511_asr/c/PlacesAudio400k/PlacesAudio_400k_distro/"
#input_and_output_files = [("metadata/nips_train.json", "metadata/classification_nips_train.json"),
#                          ("metadata/nips_val.json", "metadata/classification_nips_val.json"),
#                          ("metadata/acl_train.json", "metadata/classification_acl_train.json"),
#                          ("metadata/acl_val.json", "metadata/classification_acl_val.json"),
#                          ("metadata/train.json", "metadata/classification_train.json"),
#                          ("metadata/val.json", "metadata/classification_val.json")]
input_and_output_files = [("metadata/train1k.json", "metadata/classification_train1k.json")]


# Get the 205 class names from a file, one class per line
classes_file = "metadata/Places205_classes.txt"

with open(data_path + classes_file) as f:
    classes = f.read().splitlines()

for (input_file, output_file) in input_and_output_files:
    extract_classes(data_path + input_file,
                    data_path + output_file,
                    classes)

Input json is /teamwork/t40511_asr/c/PlacesAudio400k/PlacesAudio_400k_distro/metadata/train1k.json
Writing output to /teamwork/t40511_asr/c/PlacesAudio400k/PlacesAudio_400k_distro/metadata/classification_train1k.json
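
When a path fails to match or the class is not found, `extract_classes` only emits a warning and leaves the original path string in `label`. A simple check after the run is therefore to look for labels that are not integers. The sketch below assumes a small in-memory dictionary in place of a real `classification_*.json` file; the `wav` key and all entry values are hypothetical.

```python
from collections import Counter

# Hypothetical stand-in for the contents of a classification_*.json file
outputs = {"data": [
    {"wav": "a.wav", "label": 2},
    {"wav": "b.wav", "label": 2},
    {"wav": "c.wav", "label": "x/unknown_class/gsun_0.jpg"},  # left unresolved
]}

labels = [entry["label"] for entry in outputs["data"]]

# Entries whose label is still a path string were not mapped to a class index
unresolved = [label for label in labels if not isinstance(label, int)]

# Number of samples per class index among the resolved entries
counts = Counter(label for label in labels if isinstance(label, int))

print("Unresolved entries:", len(unresolved))
print("Samples per class index:", dict(counts))
```

For a real file, `outputs` would instead be loaded with `json.load` from the output path written above.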

