# Wav2Rec

Operating System: Ubuntu 18.04.5 LTS. <br>
Python Version: v3.7.11

---

### Fetch the FMA Datasets (Small & Medium) + Metadata

In [None]:
# The `7z` binary used by the next cell ships in p7zip-full; the plain p7zip package only provides `7zr`.
# -y keeps apt-get from blocking on a confirmation prompt inside the notebook.
!apt-get install -y p7zip-full

In [None]:
%%shell
# Download each FMA archive into /content, extract it there, then delete the
# archive. The "already done" check looks for the extracted directory: the .zip
# is removed after extraction, so checking for the .zip would re-download
# everything each time this cell is re-run.
set -eu

fetch_fma() {
  name="$1"   # archive base name, e.g. fma_small
  label="$2"  # human-readable name used in progress messages
  if [ -d "/content/${name}" ]; then
    echo "${label} ready."
    return 0
  fi
  echo "Getting ${label}..."
  wget "https://os.unil.cloud.switch.ch/fma/${name}.zip" -q --show-progress -O "/content/${name}.zip"
  echo "Unzipping..."
  # -o: extract into /content regardless of the cwd; -y: never block on overwrite prompts.
  7z x "/content/${name}.zip" -o/content -y
  echo "Unzipping Complete"
  rm "/content/${name}.zip"
}

fetch_fma fma_metadata "Metadata"
fetch_fma fma_small "Small FMA Audio Data"
fetch_fma fma_medium "Medium FMA Audio Data"

---

In [None]:
import warnings
from datetime import datetime
from multiprocessing import cpu_count
from pathlib import Path
from typing import Any, List, Optional, Tuple

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm

from experiments.fma.data.dataset import FmaDataset
from wav2rec.nn.audionets import AudioResnet50
from wav2rec.nn.lightening import Wav2RecNet

In [None]:
# Hide DeprecationWarning messages for the rest of the session.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)

**Note**: the import cell above requires the `experiments` and `wav2rec` packages to be importable. Either (a) launch this notebook from the root of the repository, or (b) change the working directory to the repository root (e.g. with `%cd`) before running that cell.

### Globals

In [None]:
# One fewer than the number of CPUs (presumably to leave a core for the main
# process), but never fewer than one.
EFFECTIVE_CPU_COUNT = max(cpu_count() - 1, 1)

In [None]:
# Extracted FMA metadata and the root directory that holds the audio folders.
METADATA_PATH = Path("/content", "fma_metadata").absolute()
AUDIO_PATH = Path("/content").absolute()

# Model checkpoints are written here; create the directory up front.
MODEL_CHECKPOINT_PATH = Path("/content", "checkpoints").absolute()
MODEL_CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)

## Data Preparation

In [None]:
# Build the dataset over the downloaded audio, then scan the disk for files.
# duration=10 is presumably the clip length in seconds; min_listens=0 appears
# to disable listen-count filtering (TODO confirm against FmaDataset).
fma_dataset = FmaDataset(
    audio_path=AUDIO_PATH,
    metadata_path=METADATA_PATH,
    duration=10,
    min_listens=0,
)
fma_dataset = fma_dataset.scan()

n_files = len(fma_dataset.files)
print(f"\n\nCollected {n_files} files")

**Note**: by default, FmaDataset() only loads audio with a Public Domain license. See `experiments.fma.data.meta.FmaMetadata`.

In [None]:
# Ask the dataset to cache its scanned files (see FmaDataset.cache_all).
# NOTE(review): `lazy=True` presumably defers loading audio until an item is
# first accessed — confirm against FmaDataset.
fma_dataset.cache_all(lazy=True)

In [None]:
# Fail fast with an actionable message if nothing made it into the cache
# (e.g. wrong paths, or FmaDataset's default license filter excluded everything).
# An explicit raise is used instead of `assert`, which is stripped under `python -O`.
if len(fma_dataset) == 0:
    raise RuntimeError(
        f"No files in cache. Check AUDIO_PATH ({AUDIO_PATH}) and "
        f"METADATA_PATH ({METADATA_PATH}) and the FmaDataset filters."
    )

print(f"Files in cache: {len(fma_dataset)}.")

In [None]:
def split_dataset(
    dataset: Dataset,
    train_prop: float = 0.95,
    seed: int = 42,
) -> Tuple[Dataset, Dataset]:
    """Randomly split ``dataset`` into a training and a validation subset.

    Args:
        dataset: any sized, indexable dataset.
        train_prop: fraction of items assigned to the training subset; must be
            in ``[0, 1]``. The training size is rounded down, so the validation
            subset receives the remainder.
        seed: seed for the generator that drives the permutation, which makes
            the split reproducible.

    Returns:
        ``(train_subset, val_subset)``.

    Raises:
        ValueError: if ``train_prop`` is outside ``[0, 1]``. Without this check
            an out-of-range value yields a negative length whose sum still
            equals ``len(dataset)``, which ``random_split`` may not reject,
            producing silently wrong subsets.
    """
    if not 0.0 <= train_prop <= 1.0:
        raise ValueError(f"train_prop must be in [0, 1], got {train_prop!r}")

    train_size = int(len(dataset) * train_prop)
    train_dataset, val_dataset = random_split(
        dataset,
        lengths=[train_size, len(dataset) - train_size],
        generator=torch.Generator().manual_seed(seed),
    )
    return train_dataset, val_dataset


def dataset2loader(
    dataset: Dataset,
    batch_size: int = 16,
    shuffle: bool = False,
    **kwargs: Any,
) -> DataLoader:
    """Wrap ``dataset`` in a ``DataLoader``.

    ``num_workers`` defaults to ``EFFECTIVE_CPU_COUNT`` unless supplied in
    ``kwargs``; all other keyword arguments are forwarded to ``DataLoader``.
    """
    options = dict(kwargs)
    workers = options.pop("num_workers", EFFECTIVE_CPU_COUNT)
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=workers,
        **options,
    )

#### Datasets & Dataloaders

In [None]:
# Hold out 5% of the tracks for validation (split_dataset is seeded, so this is reproducible).
train_dataset, val_dataset = split_dataset(fma_dataset, train_prop=0.95)

# Shuffle only the training batches; validation order stays fixed.
train_dataloader, val_dataloader = (
    dataset2loader(train_dataset, batch_size=32, shuffle=True),
    dataset2loader(val_dataset, batch_size=32, shuffle=False),
)

---

## Train

Now we can finally train a model!

In [None]:
def do_training(
    wav2rec_model: Wav2RecNet,
    max_epochs: int = 10,
    train_loader: Optional[DataLoader] = None,
    val_loader: Optional[DataLoader] = None,
) -> Tuple[Wav2RecNet, str]:
    """Fit ``wav2rec_model`` with PyTorch Lightning on a single GPU.

    Args:
        wav2rec_model: the model to train. It is trained in place and returned.
        max_epochs: upper bound on the number of training epochs.
        train_loader: training batches. Defaults to the notebook-level
            ``train_dataloader``.
        val_loader: validation batches. Defaults to the notebook-level
            ``val_dataloader``.

    Returns:
        ``(model, start_time)``, where ``start_time`` is the UTC start time as an
        ISO-8601 string. It also names this run's checkpoint directory,
        ``MODEL_CHECKPOINT_PATH/wav2rec/<start_time>``.
    """
    if train_loader is None:
        train_loader = train_dataloader
    if val_loader is None:
        val_loader = val_dataloader

    start_time = datetime.utcnow().isoformat()
    print("Start time:", start_time)

    checkpoint = ModelCheckpoint(
        dirpath=str(MODEL_CHECKPOINT_PATH.joinpath(f"wav2rec/{start_time}")),
        monitor="loss",
        verbose=True,
        save_top_k=-1,  # keep a checkpoint for every epoch, not just the best k
        filename="wav2rec-{epoch:02d}-{loss:.5f}",
        save_last=True,
    )
    # NOTE(review): checkpoints monitor "loss" while early stopping monitors
    # "val_loss"; confirm Wav2RecNet logs both metrics.
    early_stopping = EarlyStopping(monitor="val_loss")

    trainer = pl.Trainer(
        gpus=1,
        max_epochs=max_epochs,
        callbacks=[checkpoint, early_stopping],
        stochastic_weight_avg=True,
        logger=TensorBoardLogger("logs", name="wav2rec"),
    )
    trainer.fit(wav2rec_model, train_loader, val_loader)
    return wav2rec_model, start_time

In [None]:
# Train a Wav2RecNet built on the AudioResnet50 encoder for at most 25 epochs
# (the early-stopping callback in do_training may end it sooner).
wav2rec_net = Wav2RecNet(encoder=AudioResnet50())
model, start_time = do_training(wav2rec_net, max_epochs=25)

**Note**: try experimenting with other encoders such as `wav2rec.nn.audionets.AudioVit`, or building your own using the `wav2rec.nn.audionets.AudioImageNetwork` base model.

## Collect Projections

With the model trained, we can go ahead and get some predictions (projections) from it. <br>

Specifically, the code below will use the model trained above to compute projections for all songs in the training and validation datasets.

In [None]:
def _to_dataframe(all_track_ids: List[int], arrays: List[np.ndarray]) -> pd.DataFrame:
    df = pd.DataFrame(np.concatenate(arrays, axis=0)).assign(track_id=all_track_ids)
    df.columns = [f"feature_{i}" if isinstance(i, int) else i for i in df.columns]
    return df


def _get_projection(x: torch.Tensor, net: Wav2RecNet) -> np.ndarray:
    with torch.no_grad():
        proj = net.learner.encoder(x.unsqueeze(0).cuda())
    return proj.detach().cpu().numpy()


def _get_projections_from_dataset(
    dataset: torch.utils.data.Dataset,
    net: Wav2RecNet,
) -> pd.DataFrame:
    """Project every ``(file, x)`` item of ``dataset`` and collect the results.

    Each row's track id is the item's file name without directory or extension.
    """
    track_ids = []
    projections = []
    # NOTE: processes one item at a time; batching through a DataLoader would be faster.
    for file_path, example in tqdm(dataset):
        track_ids.append(Path(file_path).stem)
        projections.append(_get_projection(example, net=net))
    return _to_dataframe(track_ids, arrays=projections)


def get_all_projections(
    net: Wav2RecNet,
    train_dataset: Dataset, 
    val_dataset: Dataset,
    verbose: bool = True,
) -> pd.DataFrame:
    """Project the validation and then the training dataset with ``net``.

    Returns one DataFrame with a fresh integer index and a ``stage`` column
    set to ``"val"`` or ``"train"`` for each row.
    """
    def _stage_frame(stage, dataset):
        # Announce the stage before doing the (slow) projection work.
        if verbose:
            print(f"Working on stage '{stage}'...")
        return _get_projections_from_dataset(dataset, net=net).assign(stage=stage)

    stages = (("val", val_dataset), ("train", train_dataset))
    return pd.concat([_stage_frame(name, ds) for name, ds in stages], ignore_index=True)

In [None]:
# Project every track with the trained model on the GPU and in eval mode.
proj_df = get_all_projections(
    model.cuda().eval(),
    val_dataset=val_dataset,
    train_dataset=train_dataset,
)

In [None]:
# Persist the projections (feature columns, track_id, stage) without the pandas index.
proj_df.to_csv(path_or_buf="wav2rec_projections.csv", index=False)

## Next

If you are interested in seeing some examples of how the projections computed above can be used, check out [inference.ipynb](inference.ipynb).