In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
device = 'cuda'
audio_encoder_pretrain = "facebook/wav2vec2-xls-r-300m"
clip_pretrain = "openai/clip-vit-base-patch32"

In [3]:
from audio import AudioEncoder
from text import TextEncoder

audio_encoder = AudioEncoder(pretrain_name=audio_encoder_pretrain).to(device)
text_encoder = TextEncoder(pretrain_name=clip_pretrain).to(device)
# audio_encoder = AudioEncoder(pretrain_name=audio_encoder_pretrain)
# text_encoder = TextEncoder(pretrain_name=clip_pretrain)

In [4]:
from dataset import AudiocapsDataset


val = AudiocapsDataset(
    dataset_dir='data/audiocaps',
    part='val',
    audio_encoder=audio_encoder,
    text_encoder=text_encoder,
)

train = AudiocapsDataset(
    dataset_dir='data/audiocaps',
    part='train',
    audio_encoder=audio_encoder,
    text_encoder=text_encoder,
)


Audio embeddings exists.
Text embeddings exists.
Audio embeddings exists.
Text embeddings exists.


In [23]:
from typing import Any

import torch
import numpy as np
from torch.utils.data import DataLoader


def collate(items: list[dict, Any]) -> dict[str, Any]:
    audio = [item['audio'] for item in items]
    text = [item['text'] for item in items]
    path = [item['path'] for item in items]
    audio_embedding_lenght = torch.tensor([item['audio_embedding'].shape[0] for item in items])
    audio_embedding = []
    for i, item in enumerate(items):
        audio_embedding.append(torch.nn.functional.pad(torch.tensor(item['audio_embedding']), (0, 0, 0, max(audio_embedding_lenght) - audio_embedding_lenght[i])))
    audio_embedding = torch.stack(audio_embedding)

    text_embedding = torch.tensor(np.array([item['text_embedding'] for item in items]))
    return dict(
        audio=audio,
        text=text,
        path=path,
        audio_embedding=audio_embedding,
        audio_embedding_lenght=audio_embedding_lenght,
        text_embedding=text_embedding
    )


batch_size = 64

train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=6, collate_fn=collate)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=6, collate_fn=collate)

In [31]:
import torch
from torch.nn.modules.loss import _Loss
from torch.nn import functional as F


class ContrastiveLoss(_Loss):
    def __init__(self, t: float = 0.1, reduction: str = "mean") -> None:
        super().__init__(reduction=reduction)
        self.t = t

    def forward(self, audio: torch.Tensor, text: torch.Tensor) -> torch.Tensor:
        audio = F.normalize(audio, p=2, dim=-1)
        text = F.normalize(text, p=2, dim=-1)
        logits = (audio @ text.T) / self.t

        target = torch.arange(audio.size(0), device=audio.device)
        loss_audio = F.cross_entropy(logits, target, reduction=self.reduction)
        loss_text = F.cross_entropy(logits.T, target, reduction=self.reduction)
        return (loss_audio + loss_text) / 2

In [33]:
example_audio = torch.randn(batch_size, 512)
example_text = torch.randn(batch_size, 512)

loss = ContrastiveLoss()
print(loss(example_audio, example_text).item())
print(loss(example_text, example_text).item())

4.270759582519531
0.0031483927741646767


In [34]:
from torch.optim import Optimizer
from torch.utils.data import Dataset
from tqdm import tqdm

from audio import AudioProjector
from metric import KNNMetric


def train_projector(
    model: AudioProjector,
    optimizer: Optimizer,
    criterion: ContrastiveLoss,
    train_loader: DataLoader,
    val_loader: DataLoader,
    metric_data: Dataset,
    metric: KNNMetric,
    epochs: int,
    device: str
) -> None:
    for epoch in range(1, epochs + 1):
        model.train()
        avg_train_loss = 0.
        avg_val_loss = 0.
        for batch in tqdm(train_loader):
            audio_embedding = batch['audio_embedding'].to(device)
            audio_embedding_lenght = batch['audio_embedding_lenght'].to(device)
            text_embedding = batch['text_embedding'].to(device)
            optimizer.zero_grad()
            audio_emb = model(audio_embedding, audio_embedding_lenght)
            loss = criterion(audio_emb, text_embedding)
            loss.backward()
            optimizer.step()
            avg_train_loss += loss.item()
        avg_train_loss /= len(train_loader)
        model.eval()
        with torch.no_grad():
            for batch in val_loader:
                audio_embedding = batch['audio_embedding'].to(device)
                audio_embedding_lenght = batch['audio_embedding_lenght'].to(device)
                text_embedding = batch['text_embedding'].to(device)
                audio_emb = model(audio_embedding, audio_embedding_lenght)
                loss = criterion(audio_emb, text_embedding)
                avg_val_loss += loss.item()
            avg_val_loss /= len(val_loader)
        
        with torch.no_grad():
            query: list[tuple[int, torch.Tensor]] = []
            key: list[tuple[int, torch.Tensor]] = []
            for item in metric_data:
                audio_embedding = torch.tensor(item['audio_embedding'], device=device).unsqueeze(0)
                audio_embedding_lenght = torch.tensor([audio_embedding.shape[1]], device=device)
                audio_emb = model(audio_embedding, audio_embedding_lenght).squeeze(0)

                audio_id = int(item['path'].split('/')[-1].split('.')[0])
                query.append(
                    (
                        audio_id,
                        audio_emb.detach().cpu()
                    )
                )
                key.append(
                    (
                        audio_id,
                        torch.tensor(item['text_embedding']).squeeze(0)
                    )
                )
            metric_value = metric(query, key)
        print(f'Epoch {epoch} train loss: {avg_train_loss:.5f}')
        print(f'Epoch {epoch} val loss: {avg_val_loss:.5f}')
        print(f'Epoch {epoch} metric: {metric_value:.5f}')

In [35]:
small_val = AudiocapsDataset(
    dataset_dir='data/audiocaps',
    part='val',
    audio_encoder=audio_encoder,
    text_encoder=text_encoder,
)

small_val.df = small_val.df[small_val.df['uniq_id'] <= 100]

metric = KNNMetric(k=4)

Audio embeddings exists.
Text embeddings exists.


In [37]:
projector = AudioProjector(in_features=512, out_features=512).to(device)
optimizer = torch.optim.Adam(projector.parameters(), lr=3e-4)
loss = ContrastiveLoss().to(device=device)

train_projector(projector, optimizer, loss, train_loader, val_loader, small_val, metric, epochs=60, device=device)

  0%|          | 0/774 [00:00<?, ?it/s]

100%|██████████| 774/774 [01:11<00:00, 10.78it/s]


Epoch 1 train loss: 3.22229
Epoch 1 val loss: 2.98363
Epoch 1 metric: 0.18069


100%|██████████| 774/774 [01:12<00:00, 10.72it/s]


Epoch 2 train loss: 2.96706
Epoch 2 val loss: 2.86507
Epoch 2 metric: 0.21040


100%|██████████| 774/774 [01:12<00:00, 10.70it/s]


Epoch 3 train loss: 2.86555
Epoch 3 val loss: 2.81672
Epoch 3 metric: 0.23762


100%|██████████| 774/774 [01:11<00:00, 10.78it/s]


Epoch 4 train loss: 2.79790
Epoch 4 val loss: 2.78240
Epoch 4 metric: 0.24505


100%|██████████| 774/774 [01:12<00:00, 10.72it/s]


Epoch 5 train loss: 2.74765
Epoch 5 val loss: 2.77954
Epoch 5 metric: 0.23267


100%|██████████| 774/774 [01:11<00:00, 10.80it/s]


Epoch 6 train loss: 2.71030
Epoch 6 val loss: 2.76489
Epoch 6 metric: 0.22525


100%|██████████| 774/774 [01:10<00:00, 10.99it/s]


Epoch 7 train loss: 2.67243
Epoch 7 val loss: 2.74403
Epoch 7 metric: 0.23020


100%|██████████| 774/774 [01:11<00:00, 10.79it/s]


Epoch 8 train loss: 2.64227
Epoch 8 val loss: 2.74259
Epoch 8 metric: 0.24257


100%|██████████| 774/774 [01:11<00:00, 10.85it/s]


Epoch 9 train loss: 2.61148
Epoch 9 val loss: 2.72603
Epoch 9 metric: 0.25000


100%|██████████| 774/774 [01:12<00:00, 10.74it/s]


Epoch 10 train loss: 2.58825
Epoch 10 val loss: 2.71967
Epoch 10 metric: 0.24257


100%|██████████| 774/774 [01:12<00:00, 10.74it/s]


Epoch 11 train loss: 2.56394
Epoch 11 val loss: 2.73639
Epoch 11 metric: 0.25495


100%|██████████| 774/774 [01:10<00:00, 10.94it/s]


Epoch 12 train loss: 2.54293
Epoch 12 val loss: 2.75354
Epoch 12 metric: 0.26980


100%|██████████| 774/774 [01:11<00:00, 10.77it/s]


Epoch 13 train loss: 2.52424
Epoch 13 val loss: 2.75189
Epoch 13 metric: 0.25000


100%|██████████| 774/774 [01:11<00:00, 10.82it/s]


Epoch 14 train loss: 2.50320
Epoch 14 val loss: 2.74787
Epoch 14 metric: 0.25000


100%|██████████| 774/774 [01:11<00:00, 10.75it/s]


Epoch 15 train loss: 2.48541
Epoch 15 val loss: 2.76978
Epoch 15 metric: 0.25248


100%|██████████| 774/774 [01:12<00:00, 10.75it/s]


Epoch 16 train loss: 2.46785
Epoch 16 val loss: 2.73419
Epoch 16 metric: 0.25495


100%|██████████| 774/774 [01:11<00:00, 10.86it/s]


Epoch 17 train loss: 2.45368
Epoch 17 val loss: 2.76071
Epoch 17 metric: 0.24752


100%|██████████| 774/774 [01:11<00:00, 10.75it/s]


Epoch 18 train loss: 2.43637
Epoch 18 val loss: 2.74535
Epoch 18 metric: 0.25495


100%|██████████| 774/774 [01:12<00:00, 10.70it/s]


Epoch 19 train loss: 2.42232
Epoch 19 val loss: 2.74528
Epoch 19 metric: 0.26733


100%|██████████| 774/774 [01:12<00:00, 10.67it/s]


Epoch 20 train loss: 2.40952
Epoch 20 val loss: 2.75399
Epoch 20 metric: 0.26238


100%|██████████| 774/774 [01:11<00:00, 10.87it/s]


Epoch 21 train loss: 2.39459
Epoch 21 val loss: 2.74914
Epoch 21 metric: 0.26238


100%|██████████| 774/774 [01:11<00:00, 10.89it/s]


Epoch 22 train loss: 2.37925
Epoch 22 val loss: 2.76878
Epoch 22 metric: 0.26980


100%|██████████| 774/774 [01:09<00:00, 11.17it/s]


Epoch 23 train loss: 2.36860
Epoch 23 val loss: 2.75646
Epoch 23 metric: 0.25743


100%|██████████| 774/774 [01:11<00:00, 10.76it/s]


Epoch 24 train loss: 2.35459
Epoch 24 val loss: 2.76626
Epoch 24 metric: 0.24010


100%|██████████| 774/774 [01:11<00:00, 10.83it/s]


Epoch 25 train loss: 2.34439
Epoch 25 val loss: 2.76217
Epoch 25 metric: 0.25743


100%|██████████| 774/774 [01:11<00:00, 10.80it/s]


Epoch 26 train loss: 2.33408
Epoch 26 val loss: 2.78210
Epoch 26 metric: 0.25990


100%|██████████| 774/774 [01:11<00:00, 10.85it/s]


Epoch 27 train loss: 2.32150
Epoch 27 val loss: 2.76912
Epoch 27 metric: 0.24010


100%|██████████| 774/774 [01:12<00:00, 10.64it/s]


Epoch 28 train loss: 2.31224
Epoch 28 val loss: 2.78944
Epoch 28 metric: 0.26733


100%|██████████| 774/774 [01:12<00:00, 10.63it/s]


Epoch 29 train loss: 2.29977
Epoch 29 val loss: 2.78950
Epoch 29 metric: 0.25248


100%|██████████| 774/774 [01:13<00:00, 10.59it/s]


Epoch 30 train loss: 2.28861
Epoch 30 val loss: 2.79972
Epoch 30 metric: 0.27475


 82%|████████▏ | 636/774 [00:59<00:12, 10.86it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e544c1c8ea0>
Traceback (most recent call last):
  File "/home/anuiel/Remote/Anuiel/multimodal-dz1/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/home/anuiel/Remote/Anuiel/multimodal-dz1/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
 82%|████████▏ | 638/774 [00:59<00:11, 11.82it/s]  ^^^^^^^^^^^^
  File "/home/anuiel/.pyenv/versions/3.12.7/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e544c1c8ea0>
Traceback (most recent call last):
  File "/home/anuiel/R

Epoch 31 train loss: 2.28197
Epoch 31 val loss: 2.79704
Epoch 31 metric: 0.26733


  2%|▏         | 14/774 [00:02<01:05, 11.54it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e544c1c8ea0>
Traceback (most recent call last):
  File "/home/anuiel/Remote/Anuiel/multimodal-dz1/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/home/anuiel/Remote/Anuiel/multimodal-dz1/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/home/anuiel/.pyenv/versions/3.12.7/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
  2%|▏         | 16/774 [00:02<01:01, 12.33it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7e544c1c8ea0>
Traceback (most recent call last):
  File "/home/anuie

Epoch 32 train loss: 2.27214
Epoch 32 val loss: 2.80397
Epoch 32 metric: 0.27228


100%|██████████| 774/774 [01:10<00:00, 11.00it/s]


Epoch 33 train loss: 2.26378
Epoch 33 val loss: 2.81001
Epoch 33 metric: 0.22525


100%|██████████| 774/774 [01:10<00:00, 10.99it/s]


Epoch 34 train loss: 2.25277
Epoch 34 val loss: 2.82002
Epoch 34 metric: 0.22030


100%|██████████| 774/774 [01:10<00:00, 11.00it/s]


Epoch 35 train loss: 2.24274
Epoch 35 val loss: 2.83186
Epoch 35 metric: 0.23267


100%|██████████| 774/774 [01:09<00:00, 11.13it/s]


Epoch 36 train loss: 2.23673
Epoch 36 val loss: 2.83014
Epoch 36 metric: 0.23515


100%|██████████| 774/774 [01:09<00:00, 11.17it/s]


Epoch 37 train loss: 2.23083
Epoch 37 val loss: 2.82729
Epoch 37 metric: 0.23762


100%|██████████| 774/774 [01:09<00:00, 11.12it/s]


Epoch 38 train loss: 2.22261
Epoch 38 val loss: 2.82005
Epoch 38 metric: 0.23020


100%|██████████| 774/774 [01:08<00:00, 11.25it/s]


Epoch 39 train loss: 2.20895
Epoch 39 val loss: 2.84312
Epoch 39 metric: 0.21535


100%|██████████| 774/774 [01:09<00:00, 11.09it/s]


Epoch 40 train loss: 2.20151
Epoch 40 val loss: 2.84216
Epoch 40 metric: 0.23267


100%|██████████| 774/774 [01:09<00:00, 11.19it/s]


Epoch 41 train loss: 2.19491
Epoch 41 val loss: 2.83950
Epoch 41 metric: 0.26238


100%|██████████| 774/774 [01:09<00:00, 11.17it/s]


Epoch 42 train loss: 2.19080
Epoch 42 val loss: 2.83329
Epoch 42 metric: 0.24010


100%|██████████| 774/774 [01:08<00:00, 11.28it/s]


Epoch 43 train loss: 2.18157
Epoch 43 val loss: 2.85607
Epoch 43 metric: 0.18069


100%|██████████| 774/774 [01:09<00:00, 11.10it/s]


Epoch 44 train loss: 2.17642
Epoch 44 val loss: 2.82367
Epoch 44 metric: 0.23515


100%|██████████| 774/774 [01:09<00:00, 11.06it/s]


Epoch 45 train loss: 2.16770
Epoch 45 val loss: 2.83741
Epoch 45 metric: 0.23020


100%|██████████| 774/774 [01:09<00:00, 11.14it/s]


Epoch 46 train loss: 2.15845
Epoch 46 val loss: 2.83600
Epoch 46 metric: 0.22525


100%|██████████| 774/774 [01:09<00:00, 11.16it/s]


Epoch 47 train loss: 2.15405
Epoch 47 val loss: 2.85254
Epoch 47 metric: 0.23762


100%|██████████| 774/774 [01:09<00:00, 11.17it/s]


Epoch 48 train loss: 2.14804
Epoch 48 val loss: 2.85434
Epoch 48 metric: 0.21782


100%|██████████| 774/774 [01:08<00:00, 11.24it/s]


Epoch 49 train loss: 2.13635
Epoch 49 val loss: 2.85807
Epoch 49 metric: 0.23020


100%|██████████| 774/774 [01:08<00:00, 11.23it/s]


Epoch 50 train loss: 2.13799
Epoch 50 val loss: 2.84906
Epoch 50 metric: 0.23762


100%|██████████| 774/774 [01:08<00:00, 11.25it/s]


Epoch 51 train loss: 2.12632
Epoch 51 val loss: 2.84486
Epoch 51 metric: 0.22030


100%|██████████| 774/774 [01:09<00:00, 11.10it/s]


Epoch 52 train loss: 2.12281
Epoch 52 val loss: 2.86263
Epoch 52 metric: 0.23515


100%|██████████| 774/774 [01:08<00:00, 11.27it/s]


Epoch 53 train loss: 2.11831
Epoch 53 val loss: 2.87189
Epoch 53 metric: 0.21535


100%|██████████| 774/774 [01:09<00:00, 11.20it/s]


Epoch 54 train loss: 2.11196
Epoch 54 val loss: 2.89498
Epoch 54 metric: 0.21040


100%|██████████| 774/774 [01:10<00:00, 10.99it/s]


Epoch 55 train loss: 2.10463
Epoch 55 val loss: 2.88399
Epoch 55 metric: 0.21287


100%|██████████| 774/774 [01:11<00:00, 10.85it/s]


Epoch 56 train loss: 2.09869
Epoch 56 val loss: 2.89213
Epoch 56 metric: 0.23020


100%|██████████| 774/774 [01:10<00:00, 10.93it/s]


Epoch 57 train loss: 2.09511
Epoch 57 val loss: 2.88719
Epoch 57 metric: 0.23515


100%|██████████| 774/774 [01:10<00:00, 11.00it/s]


Epoch 58 train loss: 2.08712
Epoch 58 val loss: 2.88628
Epoch 58 metric: 0.22525


100%|██████████| 774/774 [01:10<00:00, 10.96it/s]


Epoch 59 train loss: 2.08982
Epoch 59 val loss: 2.88453
Epoch 59 metric: 0.22277


100%|██████████| 774/774 [01:09<00:00, 11.11it/s]


Epoch 60 train loss: 2.08140
Epoch 60 val loss: 2.89051
Epoch 60 metric: 0.20792


In [3]:
from queue import PriorityQueue

import torch
from IPython.display import Audio

from dataset import AudiocapsDataset
from text import TextEncoder
from audio import AudioProjector, AudioEncoder

device = 'cuda' if torch.cuda.is_available() else 'cpu'
audio_encoder_pretrain = "facebook/wav2vec2-xls-r-300m"
clip_pretrain = "openai/clip-vit-base-patch32"


text_encoder = TextEncoder(clip_pretrain).to(device)
audio_encoder = AudioEncoder(audio_encoder_pretrain)
projector = AudioProjector(in_features=512, out_features=512).to(device)
projector.load_state_dict(torch.load('projector.pth'))

val = AudiocapsDataset(
    dataset_dir='data/audiocaps',
    part='val',
    audio_encoder=audio_encoder,
    text_encoder=text_encoder,
)

Audio embeddings exists.
Text embeddings exists.


  projector.load_state_dict(torch.load('projector.pth'))


In [6]:
from torch.nn import functional as F

def search_for_audio(text: str, data: AudiocapsDataset, top_k: int = 4):
    text_embedding = text_encoder(text)
    text_embedding = text_embedding.squeeze(0)
    text_embedding = F.normalize(text_embedding, p=2, dim=-1).detach().cpu()

    queue = PriorityQueue(maxsize=top_k)

    # Fix for many audio in data
    path_set = set()
    for item in data:
        if item['path'] in path_set:
            continue
        path_set.add(item['path'])
        audio_embedding = torch.tensor(item['audio_embedding']).to(device).unsqueeze(0)
        audio_embedding_lenght = torch.tensor([audio_embedding.shape[1]]).to(device)
        with torch.no_grad():
            audio_emb = projector(audio_embedding, audio_embedding_lenght)
        audio_emb = F.normalize(audio_emb, p=2, dim=-1).squeeze(0).detach().cpu()
        score = torch.dot(audio_emb, text_embedding).item()
        if queue.full():
            if score > queue.queue[0][0]:
                queue.get()
                queue.put((score, item['path']))
        else:
            queue.put((score, item['path']))

    print(f'For text "{text}" closest audio:')
    for _ in range(queue.qsize()):
        score, path = queue.get()
        # Show audio
        print(f"Simularity score: {score:.4f}")
        display(Audio(path, autoplay=True))


search_for_audio("Bird", val, top_k=3)

For text "Bird" closest audio:
Simularity score: 0.2634


Simularity score: 0.2688


Simularity score: 0.2743
