In [1]:
import pickle
import random
import numpy as np
import torch
import sklearn.model_selection
from tqdm import tqdm

In [2]:
random.seed(0)
SKLEARN_RANDOM_SEED = 0
torch.manual_seed(0)

device = torch.device("cpu")  # torch.device("mps")
BATCH_SIZE = 128
WEIGHT_DECAY = 1e-2
NUM_EPOCHS = 10

# Process data

## Load and sample data from disk

In [3]:
with open('audio_embeddings.pickle', 'rb') as aud_embedding_picklefile:
    unsorted_aud_embeddings = pickle.load(aud_embedding_picklefile)
with open('image_embeddings.pickle', 'rb') as img_embedding_picklefile:
    unsorted_img_embeddings = pickle.load(img_embedding_picklefile)

In [4]:
sample_aud_embedding = unsorted_aud_embeddings[list(unsorted_aud_embeddings.keys())[1]]
sample_img_embedding = unsorted_img_embeddings[list(unsorted_img_embeddings.keys())[1]]
aud_embed_dim = sample_aud_embedding.shape[0]
img_embed_dim = sample_img_embedding.shape[0]

In [20]:
len(unsorted_img_embeddings)

7129

## Sort data

In [5]:
class Character:
    def __init__(self):
        self.audios = np.ndarray((0, sample_aud_embedding.shape[0]))
        self.imgs = np.ndarray((0, sample_img_embedding.shape[0]))

get_name = lambda key: key.split('/')[0]
aud_names = set(get_name(key) for key in unsorted_aud_embeddings)
img_names = set(get_name(key) for key in unsorted_img_embeddings)
characters = {name: Character() for name in aud_names if name in img_names}

train_char_names, test_char_names = sklearn.model_selection.train_test_split(
    list(characters.keys()),
    test_size=0.15, 
    train_size=1-0.15, 
    random_state=SKLEARN_RANDOM_SEED
)
train_characters = {key: val for key, val in characters.items() if key in train_char_names}
test_characters = {key: val for key, val in characters.items() if key in test_char_names}
# Use "not in test_characters" instead of "in train_characters" to allow for images with no matching audio
train_imgs = {key: val for key, val in unsorted_img_embeddings.items() if get_name(key) not in test_characters}
test_imgs = {key: val for key, val in unsorted_img_embeddings.items() if get_name(key) in test_characters}

In [6]:
def np_append(arr1, arr2):
    return np.concatenate((arr1, arr2[np.newaxis, :]), axis=0)

for key, embedding in unsorted_aud_embeddings.items():
    try:
        characters[get_name(key)].audios = np_append(characters[get_name(key)].audios, embedding)
    except KeyError:  # Character for whom there are no images
        pass

for key, embedding in unsorted_img_embeddings.items():
    try:
        characters[get_name(key)].imgs = np_append(characters[get_name(key)].imgs, embedding)
    except KeyError:  # Character for whom there are no audios
        pass

 # Modelling

In [7]:
class NegativeImgFactory:
    def __init__(self, img_embeddings):
        self.img_titles = list(img_embeddings.keys())
        self.img_embeddings = img_embeddings

    def __call__(self, name):
        img_name = random.choice(self.img_titles)
        while get_name(img_name) == name:
            img_name = random.choice(self.img_titles)
        return self.img_embeddings[img_name]

## Data

In [8]:
class TripletIterator:
    def __init__(self, characters, get_negative_img):
        self.characters = characters
        self.character_names = iter(characters.keys())
        self.get_negative_img = get_negative_img

    def __iter__(self):
        return self

    def __next__(self):
        character_name = next(self.character_names)
        character = self.characters[character_name]
        char_aud = character.audios[torch.randint(character.audios.shape[0], (1,))[0]]
        char_img = character.imgs[torch.randint(character.imgs.shape[0], (1,))[0]]
        neg_img = self.get_negative_img(character_name)
        pos_is_first = torch.randint(2, (1,))[0]
        triplet = (
            char_aud, 
            char_img if pos_is_first else neg_img, 
            neg_img if pos_is_first else char_img
        )
        return tuple(map(torch.tensor, triplet)), (1.0-pos_is_first)

class VoiceImgsTripletDataset(torch.utils.data.IterableDataset):
    def __init__(self, characters, get_negative_img):
        self.characters = characters
        self.get_negative_img = get_negative_img
    
    def __iter__(self):
        return TripletIterator(self.characters, self.get_negative_img)

    def __len__(self):
        return len(self.characters.keys())

    @staticmethod
    def get_sample_size():
        aud_size = sample_aud_embedding.shape[0]
        img_size = sample_img_embedding.shape[0]
        return aud_size + img_size*2

In [9]:
train_dataset = VoiceImgsTripletDataset(characters=train_characters, get_negative_img=NegativeImgFactory(train_imgs))
test_dataset = VoiceImgsTripletDataset(characters=test_characters, get_negative_img=NegativeImgFactory(test_imgs))

In [10]:
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE)
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)

## Model

In [11]:
class VoiceToImageClassifier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.aud_embedder = torch.nn.Linear(aud_embed_dim, img_embed_dim)

    def forward(self, triplet):
        aud, img1, img2 = triplet
        aud_embed = self.aud_embedder(aud)

        return torch.norm(aud_embed-img1, dim=1) - torch.norm(aud_embed-img2, dim=1)

model = VoiceToImageClassifier().to(device)

In [12]:
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.NAdam(model.parameters(), weight_decay=WEIGHT_DECAY, decoupled_weight_decay=True)

## Training

In [13]:
def to_mps_device(tensor):
    # For compatibility with Apple Metal MPS devices, we quantize signal to float32
    return tensor.to(torch.float32).to(device)

In [14]:
def eval_model(model, test_dl):
    num_tests = 0
    num_correct = 0
    
    with torch.no_grad():
        for batch in train_dl:
            triplet_cpu, target_cpu = batch
            triplet, target = tuple(map(to_mps_device, triplet_cpu)), to_mps_device(target_cpu)
            num_tests += len(target)
            pred = torch.nn.functional.sigmoid(model(triplet)).round()
            num_correct += torch.sum(pred == target)
        return num_correct / num_tests

In [15]:
for epoch in (pbar := tqdm(range(NUM_EPOCHS))):
    for batch in train_dl:
        triplet_cpu, target_cpu = batch
        triplet, target = tuple(map(to_mps_device, triplet_cpu)), to_mps_device(target_cpu)
        prediction = model(triplet)
        loss = loss_fn(prediction, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        acc = round(eval_model(model, test_dl).item(), 2)
    pbar.set_description(f"acc={round(acc, 2)}, loss={round(loss.item(), 3)}")

acc=0.91, loss=0.56: 100%|██████████████████████| 10/10 [00:17<00:00,  1.74s/it]


In [17]:
print(f"Final identification accuracy: {acc}")

Final identification accuracy: 0.91
