# 1. Imports

In [None]:
import re
import yaml
import torch
import librosa
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
from scipy.io import wavfile
import matplotlib.pyplot as plt

from kaituoxu.conv_tasnet import ConvTasNet
from asteroid.data import MUSDB18Dataset

In [None]:
DATA_DIR = Path("musdb_data")

# Mixture used for the qualitative separation test at the end of the notebook.
TEST_SONG = DATA_DIR / "test/Al James - Schoolboy Facination.stem.mp4"

# Files produced by the training run being evaluated.
TRAINING_DIR = Path("weights/training_20220308-170315")
cfg_path = TRAINING_DIR / "cfg.yaml"
history_path = TRAINING_DIR / "history.csv"
best_model_path = TRAINING_DIR / "model.pth"

# Fail fast, before reading anything, if a required artifact is missing.
missing_artifacts = [p for p in (cfg_path, history_path, best_model_path) if not p.exists()]
if missing_artifacts:
    raise FileNotFoundError(f"Missing training artifacts: {[str(p) for p in missing_artifacts]}")


def _checkpoint_epoch(path):
    """Return the epoch number encoded in a checkpoint's file name.

    Only ``path.name`` is inspected, so the digits in the run directory name
    (``training_20220308-170315``) do not leak into the sort key.
    A name without any digit maps to -1, so it ranks as the oldest.
    """
    digits = "".join(re.findall(r"\d", path.name))
    return int(digits) if digits else -1


# Latest per-epoch checkpoint, or None when the run saved none.
last_model_path = max(TRAINING_DIR.glob("model_epoch*"), key=_checkpoint_epoch, default=None)

# NOTE(review): FullLoader can construct some Python objects; acceptable for a
# config written by our own training run, but prefer yaml.safe_load if the
# file only holds plain scalars/lists/dicts.
with cfg_path.open("r") as file:
    CFG = yaml.load(file, Loader=yaml.FullLoader)

## Hyperparameters

In [None]:
# Echo every training hyperparameter from the loaded config, one per line.
for hp_name, hp_value in CFG.items():
    line = f">>> {hp_name.upper()} -> {hp_value}"
    print(line)

# 2. Learning Curves

In [None]:
# Per-epoch training log; the row position is used as the epoch axis.
df = pd.read_csv(history_path)
epochs = df.index.to_numpy()
train_loss, val_loss = df["train_loss"], df["val_loss"]

In [None]:
# Train vs. validation loss per epoch.
plt.figure(figsize=(10, 7))
plt.plot(epochs, train_loss, label="train loss")
plt.plot(epochs, val_loss, label="val loss")
plt.title("Learning curves")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
# Render the figure; the original `plt.plot()` with no data drew nothing.
plt.show()

## Observations:
- The model does not learn. Is training unstable?
  - Possible remedies:
    - Reduce the learning rate
    - Increase the batch size
    - Record the learning rate at each epoch
    - Record additional metrics
    - Record gradient norms

<br/>

- Consider using a smaller network.

# 3. Load model

In [None]:
# Separation targets and audio settings from the training config.
TARGETS = CFG["targets"]
SAMPLE_RATE = CFG["sample_rate"]

# Optimisation settings: not needed for inference, kept for reference.
LR = CFG["learning_rate"]
N_EPOCHS = CFG["n_epochs"]

TRAIN_BATCH_SIZE = CFG["train_batch_size"]
TEST_BATCH_SIZE = CFG["test_batch_size"]

# Conv-TasNet architecture hyperparameters; the network predicts one source per target.
N_SRC = len(TARGETS)
X = CFG["X"]
R = CFG["R"]
B = CFG["B"]
H = CFG["H"]
Sc = CFG["Sc"]  # NOTE(review): read but not passed to ConvTasNet below — confirm it is unused
P = CFG["P"]
L = CFG["L"]
N = CFG["N"]
STRIDE = CFG["stride"]  # NOTE(review): read but not passed to ConvTasNet below — confirm it is unused
CLIP = CFG["gradient_clipping"]  # training-only setting

model = ConvTasNet(
    C=N_SRC,
    X=X,
    R=R,
    B=B,
    H=H,
    P=P,
    L=L,
    N=N,
    mask_nonlinear="softmax",
)

# The checkpoint may have been saved from a GPU run, while the separation test
# in this notebook runs on CPU (outputs are converted with .numpy()), so remap
# every stored tensor to CPU to make loading work on CPU-only machines.
checkpoint = torch.load(best_model_path, map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
# Put any train/eval-dependent layers into inference mode before separating audio.
model.eval()

In [None]:
# Display the separation target names from the training config; the export
# cell enumerates the model's output sources in this same order.
TARGETS

# 4. Separation test

In [None]:
# Decode the test mixture, resampled to the sample rate from the training config
# (librosa.load downmixes to mono by default).
sound, sr = librosa.load(str(TEST_SONG), sr=SAMPLE_RATE)

In [None]:
# Only the first 10 seconds of the mixture are separated.
excerpt_samples = 10 * SAMPLE_RATE
sound = torch.tensor(sound[:excerpt_samples])

In [None]:
# Inference only: disable autograd so no computation graph is retained for the
# whole excerpt, which would otherwise waste memory.
with torch.no_grad():
    pred = model(sound.view(1, -1))  # add a leading batch axis of size 1

In [None]:
# Inspect the model output shape; presumably (batch=1, n_sources, n_samples)
# for the kaituoxu ConvTasNet — TODO confirm.
pred.shape

In [None]:
# kaituoxu's ConvTasNet.forward returns [batch, n_sources, n_samples]; the
# mixture was fed as a batch of one, so drop that axis before indexing sources.
# (Indexing pred[i] directly would select batch items, not sources.)
separated = pred.detach().cpu().squeeze(0).numpy()
for i, name in enumerate(TARGETS):
    # Write one mono WAV per target into the current working directory.
    wavfile.write(f"./{name}.wav", SAMPLE_RATE, separated[i])