In [None]:
import torch
import numpy as np
# ^^^ pyforest auto-imports - don't write above this line
# Imported explicitly (below the pyforest marker) so this cell also runs on a
# kernel without pyforest's lazy imports, where `sys` / `Path` would otherwise
# raise NameError.
import sys
from pathlib import Path

# Put the directory two levels above the working directory at the front of the
# module search path (presumably the project root holding the modules imported
# in the next cell -- confirm).
_project_root = str(Path("../../").resolve())
if not sys.path or sys.path[0] != _project_root:  # do not stack duplicates on re-run
    sys.path.insert(0, _project_root)

# Load IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell runs, so edits to the project sources are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from data_reader import CINC2022Reader, CINC2016Reader, EPHNOGRAMReader
from dataset import CinC2022Dataset
from models import (
    CRNN_CINC2022,
    SEQ_LAB_NET_CINC2022,
    UNET_CINC2022,
    Wav2Vec2_CINC2022,
    HFWav2Vec2_CINC2022,
)
from cfg import TrainCfg, ModelCfg
from trainer import CINC2022Trainer, _MODEL_MAP, _set_task, collate_fn
from utils.plot import plot_spectrogram

from tqdm.auto import tqdm
import torchaudio
from copy import deepcopy

from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP
from torch.utils.data import DataLoader

# Switch off the __DEBUG__ class flag on every model / dataset class listed here.
for _cls in (
    CRNN_CINC2022,
    Wav2Vec2_CINC2022,
    HFWav2Vec2_CINC2022,
    CinC2022Dataset,
):
    _cls.__DEBUG__ = False

# NOTE(review): the autoreload extension is already loaded in the first cell;
# repeating it here looks intended to let this cell be run on its own --
# confirm, otherwise these two lines are redundant.
%load_ext autoreload
%autoreload 2

In [None]:
# Data directory (assigned to train_config.db_dir further down). Set the
# CINC2022_DB_DIR environment variable to point at a local copy instead of
# editing the notebook; the previous hard-coded path remains the default.
import os

db_dir = os.environ.get("CINC2022_DB_DIR", "/data1/Jupyter-Data/CinC2022/")  # replace with the data directory

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep torch's default floating dtype and the numpy dtype used in this notebook
# in step with the precision requested by the model configuration.
if ModelCfg.torch_dtype == torch.float64:
    # torch.set_default_tensor_type is deprecated (PyTorch >= 2.1);
    # set_default_dtype is the supported way to make float64 the default.
    torch.set_default_dtype(torch.float64)
    DTYPE = np.float64
else:
    DTYPE = np.float32

In [None]:
# Task to train; "classification" is the alternative that was tried.
task = "multi_task"
# task = "classification"

# Work on a private copy so the imported TrainCfg defaults stay untouched.
train_config = deepcopy(TrainCfg)
train_config.debug = True
train_config.db_dir = db_dir

# Optional overrides, currently left at their TrainCfg defaults:
# train_config.model_dir = model_folder
# train_config.final_model_filename = _ModelFilename
# train_config.n_epochs = 100
# train_config.batch_size = 24  # 16G (Tesla T4)
# train_config.log_step = 20
# train_config.max_lr = 1.5e-3
# train_config.early_stopping.patience = 20

# Architecture choices for the selected task.
_task_train_cfg = train_config[task]
_task_train_cfg.model_name = "crnn"  # "wav2vec2_hf"
_task_train_cfg.cnn_name = "tresnetF"  # "resnet_nature_comm_bottle_neck_se"
# _task_train_cfg.rnn_name = "none"  # "none", "lstm"
# _task_train_cfg.attn_name = "se"  # "none", "se", "gc", "nl"

# Let the trainer module apply the task-specific settings to the config.
_set_task(task, train_config)

# Private copy of the model defaults for the chosen task.
model_config = deepcopy(ModelCfg[task])

# Mirror the architecture choices from the training config into the model
# config; only the sub-sections present for the selected model are touched.
model_config.model_name = train_config[task].model_name
_arch_cfg = model_config[model_config.model_name]
if "cnn" in _arch_cfg:
    _arch_cfg.cnn.name = train_config[task].cnn_name
if "rnn" in _arch_cfg:
    _arch_cfg.rnn.name = train_config[task].rnn_name
if "attn" in _arch_cfg:
    _arch_cfg.attn.name = train_config[task].attn_name

# Alternative settings when using the wav2vec2 model:
# model_config.wav2vec2.cnn.name = "resnet_nature_comm_bottle_neck_se"
# model_config.wav2vec2.encoder.name = "wav2vec2_nano"

In [None]:
# Look up the model class registered under the configured model name and
# switch off its __DEBUG__ class flag before the model is instantiated.
model_cls = _MODEL_MAP[model_config.model_name]
model_cls.__DEBUG__ = False

In [None]:
# Build the model from the task-specific model config.
model = model_cls(config=model_config)
# When more than one GPU is visible, wrap the model in DataParallel.
if torch.cuda.device_count() > 1:
    model = DP(model)
    # model = DDP(model)
# Move the model to the selected device; the trailing ";" suppresses the
# cell's output.
model.to(device=DEVICE);

In [None]:
# Show the model size attributes. `model` is only wrapped in DataParallel when
# more than one GPU is visible, so unwrap conditionally instead of assuming
# `.module` exists (which raises AttributeError on CPU / single-GPU machines).
_core_model = model.module if isinstance(model, (DP, DDP)) else model
_core_model.module_size, _core_model.module_size_

In [None]:
# Display the model's repr as the cell output.
model

In [None]:
# Build the training and test datasets for the chosen task. Both are created
# with lazy=True; their data is loaded explicitly in the following cells.
ds_train = CinC2022Dataset(train_config, task, training=True, lazy=True)
ds_test = CinC2022Dataset(train_config, task, training=False, lazy=True)

In [None]:
# Trigger loading for the training set, which was constructed with lazy=True.
# NOTE(review): relies on a private method of CinC2022Dataset; presumably it
# reads every record into memory -- confirm.
ds_train._load_all_data()

In [None]:
# Trigger loading for the test set, which was constructed with lazy=True.
# NOTE(review): relies on a private method of CinC2022Dataset; presumably it
# reads every record into memory -- confirm.
ds_test._load_all_data()

In [None]:
# Create the trainer for the model built above.
# NOTE(review): lazy=True presumably makes the trainer skip building its own
# dataloaders (they are attached explicitly in the next cell) -- confirm
# against CINC2022Trainer.
trainer = CINC2022Trainer(
    model=model,
    model_config=model_config,
    train_config=train_config,
    device=DEVICE,
    lazy=True,
)

In [None]:
# Hand the already-loaded train / test datasets to the trainer.
# NOTE(review): this calls a private method of CINC2022Trainer.
trainer._setup_dataloaders(ds_train, ds_test)

In [None]:
# Run training and keep the return value.
# NOTE(review): presumably the state dict of the best-scoring model -- confirm
# against CINC2022Trainer.train.
best_state_dict = trainer.train()

## Inspect trained models

In [None]:
from models import Wav2Vec2_CINC2022, CRNN_CINC2022

# NOTE(review): the autoreload extension is already loaded in the first cell;
# repeating it here looks intended to let this section be run on its own --
# confirm, otherwise these two lines are redundant.
%load_ext autoreload
%autoreload 2

In [None]:
# Checkpoint to inspect -- replace with a saved model.
_ckpt_path = "./saved_models/BestModel_task-multi_task_CRNN_CINC2022_epoch41_08-11_02-38_metric_-16272.44.pth.tar"

# Element 0 of the result is the restored model (it is used that way in the
# cells below).
ckpt = CRNN_CINC2022.from_checkpoint(_ckpt_path)

In [None]:
# Display the configuration stored on the restored model (element 0 of the
# value returned by from_checkpoint).
ckpt[0].config

In [None]:
# Keep a direct handle on the restored model.
best_model = ckpt[0]

In [None]:
# Move the restored model to the CPU before it is called on a batch below.
best_model = best_model.to("cpu")

In [None]:
# Small shuffled loader over the training set, used only to pull a batch for
# the forward-pass check below.
_loader_opts = dict(
    batch_size=4,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=False,
    collate_fn=collate_fn,
)
dl = DataLoader(ds_train, **_loader_opts)

In [None]:
# Pull a single batch for the forward-pass check below. next(iter(...)) fails
# loudly (StopIteration) on an empty loader, instead of silently leaving
# `labels` / `waveforms` undefined or stale from an earlier run.
batch = next(iter(dl))
# `labels` is the same object as `batch`; popping "waveforms" leaves only the
# remaining entries in it (presumably the targets -- confirm against collate_fn).
labels = batch
waveforms = labels.pop("waveforms")

In [None]:
# Call the restored model on one batch, passing the remaining batch entries as
# the second argument, and display the result.
# NOTE(review): this runs with autograd enabled and in whatever train/eval mode
# from_checkpoint left the model -- confirm that is intended for inspection.
best_model(waveforms, labels)