# Model development

In [1]:
import logging
import os
import sys
import wandb

sys.path.append("../")  # go to parent dir

from pprint import pprint
from types import SimpleNamespace

import torch

from lightning.pytorch import seed_everything, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
from lightning.pytorch.loggers import WandbLogger
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary

from core.config import (
    BATCH_SIZE,
    LR,
    N_EPOCHS,
    RANDOM_STATE,
    SAMPLE_RATE,
    WEIGHT_DECAY,
)
from core.data import load_data
from core.downstream import LightningMLP
from core.transforms import load_transforms
from core.upstream import load_feature_extractor, load_processor

# Let float32 matmuls use the "high" precision setting (a speed/precision
# trade-off applied process-wide by PyTorch).
torch.set_float32_matmul_precision("high")

# Module-level logger for this notebook.
logger = logging.getLogger(__name__)

## Training script

We will dissect the `core.train` script, using the tiny Whisper model as an example.

### Parameters

In [3]:
# Run configuration, gathered in a namespace so the cells below can read
# every setting as `args.<name>`.
args = SimpleNamespace(
    # Name of the dataset
    dataset="nort3160",
    # Name of the backbone model
    model_id="whisper_tiny",
    # Data type
    dtype="speech",
    # Max audio duration
    max_duration=10,
    # Whether to transform the data using Voice Activity Detection or not
    transform=True,
    # Device on which a torch.Tensor is or will be allocated
    device="cuda:0",
    # Learning rate
    lr=LR,
    # Weight decay
    weight_decay=WEIGHT_DECAY,
    # Batch size
    batch_size=BATCH_SIZE,
    # Number of training epochs
    n_epochs=N_EPOCHS,
    # Path to the main data folder
    data_dir="data",  # Change that folder if needed!
    # Path where cached models are stored
    cache_dir="/home/common/speech_phylo/models",  # Change that folder if needed!
    # Random seed
    seed=RANDOM_STATE,
)

# Fall back to the CPU unless a CUDA device was requested AND is available.
# The result is written back to `args.device` (not to a stray local) because
# the later cells read the device from `args`.
if not (torch.cuda.is_available() and "cuda" in args.device):
    args.device = "cpu"

# Printed after the fallback so the listing shows the effective device.
print("Configuration:")
for k, v in vars(args).items():
    print(f"\t{k}: {v}")

Configuration:
	dataset: nort3160
	model_id: whisper_tiny
	dtype: speech
	max_duration: 10
	transform: True
	device: cuda:0
	lr: 0.00025
	weight_decay: 0.001
	batch_size: 64
	n_epochs: 3
	data_dir: data
	cache_dir: /home/common/speech_phylo/models
	seed: 42


In [4]:
# Seed the random number generators with the configured seed so the run is
# repeatable. `workers=True` is forwarded to Lightning; per its documentation
# this also derives seeds for DataLoader worker processes.
seed_everything(args.seed, workers=True)

Seed set to 42


42

In [5]:
print("Loading feature extractor...")

# Backbone selected by `args.model_id`, loaded from the local model cache
# onto the configured device.
feature_extractor = load_feature_extractor(
    device=args.device, cache_dir=args.cache_dir, model_id=args.model_id
)

# Layer-by-layer summary for a dummy input of shape (1, SAMPLE_RATE);
# presumably a single one-second waveform, to be confirmed against the
# extractor's expected input.
summary(feature_extractor, input_size=(1, SAMPLE_RATE))

Loading feature extractor...


Layer (type:depth-idx)                             Output Shape              Param #
WhisperFeatureExtractor                            [1, 1, 384]               --
├─Whisper: 1-1                                     --                        --
│    └─AudioEncoder: 2-1                           --                        --
│    │    └─Conv1d: 3-1                            [1, 384, 3000]            92,544
│    │    └─Conv1d: 3-2                            [1, 384, 1500]            442,752
│    │    └─ModuleList: 3-3                        --                        7,096,320
│    │    └─LayerNorm: 3-4                         [1, 1500, 384]            768
│    └─TextDecoder: 2-2                            --                        172,032
│    │    └─Embedding: 3-5                         [1, 1, 384]               19,916,160
│    │    └─ModuleList: 3-6                        --                        9,463,296
│    │    └─LayerNorm: 3-7                         [1, 1, 384]               7

In [6]:
print("Preparing data...")

# Keyword arguments shared by the three DataLoaders built below.
loader_args = dict(num_workers=4, batch_size=args.batch_size, pin_memory=True)

# Keyword arguments for `load_data`; extended below with the speech-specific
# pieces.
dataset_args = dict(dataset=args.dataset, data_dir=args.data_dir)

# Only speech input is handled; anything else is rejected.
if args.dtype != "speech":
    raise NotImplementedError

print("Loading processor...")
processor = load_processor(
    model_id=args.model_id, sr=SAMPLE_RATE, cache_dir=args.cache_dir
)

# Waveform transform pipeline; `vad` is switched by `args.transform`.
transform = load_transforms(
    sr=SAMPLE_RATE,
    max_duration=args.max_duration,
    vad=args.transform,
)

dataset_args.update(
    processor=processor,
    max_duration=args.max_duration,
    transform=transform,
)

print("Dataset args:")
pprint(dataset_args)
print("Data loader args:")
pprint(loader_args)

print("Loading data...")
train_dataset, valid_dataset, test_dataset = load_data(**dataset_args)

# Only the training split is shuffled; validation and test keep their order.
train_loader, valid_loader, test_loader = (
    DataLoader(split, shuffle=is_train, **loader_args)
    for split, is_train in (
        (train_dataset, True),
        (valid_dataset, False),
        (test_dataset, False),
    )
)

print("Done")

Preparing data...
Loading processor...
Dataset args:
{'data_dir': 'data',
 'dataset': 'nort3160',
 'max_duration': 10,
 'processor': <core.upstream.processing.WhisperProcessor object at 0x7faffc288dc0>,
 'transform': Sequential(
  (0): Vad()
  (1): Trim()
  (2): Pad()
)}
Data loader args:
{'batch_size': 64, 'num_workers': 4, 'pin_memory': True}
Loading data...
Done


In [7]:
print("Preparing downstream classifier...")

# Lightning module wrapping the feature extractor and a classification head
# with one output per label in the training set's label encoder.
# NLLLoss expects log-probabilities; the head printed below ends in a
# LogSoftmax layer.
lit_mlp = LightningMLP(
    feature_extractor=feature_extractor,
    num_classes=len(train_dataset.label_encoder),
    loss_fn=nn.NLLLoss(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# `torch.compile` does not modify its argument; it returns an optimized
# wrapper. The result must be rebound, otherwise the call has no effect and
# the uncompiled module is what gets trained.
lit_mlp = torch.compile(lit_mlp)

print(lit_mlp.classifier)

Preparing downstream classifier...


MLP(
  (classifier): Linear(in_features=384, out_features=5, bias=True)
  (log_softmax): LogSoftmax(dim=-1)
)


In [8]:
# Directory that receives the Weights & Biases run files.
eval_dir = f"{args.data_dir}/eval"
os.makedirs(eval_dir, exist_ok=True)

# One W&B project per dataset.
wandb_logger = WandbLogger(
    save_dir=eval_dir,
    project=f"mlops_project_eval_{args.dataset}",
)

# Record the settings that identify this experiment in the run config.
wandb_logger.experiment.config.update(
    {key: getattr(args, key) for key in ("model_id", "max_duration", "transform")}
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mneclow[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
# Translate the torch-style device string into Lightning's
# (accelerator, devices) pair.
requested_device = torch.device(args.device)
if requested_device.type == "cuda" and torch.cuda.is_available():
    accelerator = "gpu"
    # "cuda:1" -> [1]. A bare "cuda" carries no index (parsing it with int()
    # would raise), so ask Lightning for a single GPU instead.
    devices = [requested_device.index] if requested_device.index is not None else 1
else:
    # Also reached when CUDA was requested but is not available on this machine.
    accelerator = "cpu"
    devices = "auto"

print("Start training!")
trainer = Trainer(
    accelerator=accelerator,
    devices=devices,
    max_epochs=args.n_epochs,
    enable_model_summary=True,
    callbacks=[
        # Track the checkpoint with the lowest "valid_loss" and also keep the
        # last one.
        ModelCheckpoint(monitor="valid_loss", mode="min", save_last=True),
        TQDMProgressBar(),
    ],
    logger=wandb_logger,
)

trainer.fit(
    model=lit_mlp,
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name              | Type                    | Params | Mode 
----------------------------------------------------------------------
0 | feature_extractor | WhisperFeatureExtractor | 37.2 M | eval 
1 | classifier        | MLP                     | 1.9 K  | train
2 | loss_fn           | NLLLoss                 | 0      | train
3 | train_metric      | MulticlassAccuracy      | 0      | train
4 | valid_metric      | MulticlassAccuracy      | 0      | train
5 | test_metric       | MulticlassAccuracy      | 0      | train
----------------------------------------------------------------------
37.2 M    Trainable params
0         Non-trainable params
37.2 M    Total params
148.746   Total estimated model params size (MB)
7         Modules in train mode
131       Modules in eval mode


Start training!


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [10]:
# Evaluate on the held-out test split; `ckpt_path="best"` makes the trainer
# restore weights from a saved checkpoint before testing.
trainer.test(lit_mlp, dataloaders=test_loader, ckpt_path="best")

# Close the Weights & Biases run so its summary is uploaded.
wandb.finish()

Restoring states from the checkpoint path at data/eval/mlops_project_eval_nort3160/7zi2am3w/checkpoints/epoch=2-step=1386.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loaded model weights from the checkpoint at data/eval/mlops_project_eval_nort3160/7zi2am3w/checkpoints/epoch=2-step=1386.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▃▃▃▆▆▆▆▆▆▆▆▆▆▆█
test_metric,▁
train_loss_epoch,█▂▁
train_loss_step,█▆▃▄▃▂▄▂▂▃▃▁▂▂▃▂▂▂▃▂▂▁▂▂▁▂▁
train_metric_epoch,▁▇█
train_metric_step,▁▄▅▄▆▆▅█▇▆▆█▆▇▆▇▇▆▇▇▆▇▇▇█▅█
trainer/global_step,▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇█████
valid_loss,█▄▁
valid_metric,▁▂█

0,1
epoch,3.0
test_metric,0.87911
train_loss_epoch,0.25082
train_loss_step,0.12757
train_metric_epoch,0.90983
train_metric_step,0.95312
trainer/global_step,1386.0
valid_loss,0.39403
valid_metric,0.86456


## Logging Artifacts

In [None]:
# The public W&B client is `wandb.Api` (capital "A"); `wandb.api` is not the
# public client constructor.
api = wandb.Api()

# `Api.run` takes a run path ("entity/project/run_id"); reuse the path of the
# run created by the logger above.
# NOTE(review): assumes the finished run object still exposes `.path` - confirm.
run = api.run(wandb_logger.experiment.path)