In [None]:
import wandb
import torch
from torch import optim
from torch.utils.data import DataLoader
from clap import Clap, ClapAudioClassifier
from clap.training import create_scheduler, ClapFinetuner
from clap.datasets import ClapDataset
from clap.utils import get_target_device, load_clap_config, set_random_seed

# Fine-tune ClapAudioClassifier on ESC-50

In [None]:
# Locations of the model config and checkpoints. Everything lives under the
# project's `clap` package directory, so only `clap_dir` needs editing when the
# notebook is moved to another machine.
# NOTE(review): these are absolute, machine-specific Windows paths.
clap_dir = r"C:\Users\leon\Documents\ML_Projects\Custom-CLAP\clap"
config_path = clap_dir + r"\configs\clap_htsat-tiny_gpt2_v1.yml"
clap_ckpt_path = clap_dir + r"\checkpoints\clap_htsat-tiny_gpt2_v1.ckpt"  # pretrained CLAP weights
clf_ckpt_path = clap_dir + r"\checkpoints\clf_htsat-tiny_gpt2_v1.ckpt"    # fine-tuned classifier checkpoint

# Parse the YAML config (audio processing + "fine-tuning" hyperparameters)
# and select the device the model will be moved to.
config = load_clap_config(config_path=config_path)
device = get_target_device()

In [None]:
# Seed the RNGs before any data is built.
# NOTE(review): passing None presumably makes set_random_seed pick a fresh seed
# (returned as `seed`), so runs are not reproducible unless a fixed integer is
# given here — confirm against set_random_seed.
seed = set_random_seed(None)

# One ClapDataset per ESC-50 split, all read from the same local dataset folder.
dataset_paths = [r"C:\Users\leon\Documents\ML_Projects\Custom-CLAP\clap\datasets\esc50"]


def _esc50_split(kind):
    """Build the ESC-50 ClapDataset for a single split name ("train", "val" or "test")."""
    return ClapDataset(config=config, kinds=[kind], datasets=["ESC50"], datasets_paths=dataset_paths)


train_dataset = _esc50_split("train")
val_dataset = _esc50_split("val")
test_dataset = _esc50_split("test")

In [None]:
# Log in to Weights & Biases for experiment tracking.
# To train without wandb: skip this cell and the `wandb.init` cell, and make sure
# the ClapFinetuner is created with enable_wandb_logging=False.
wandb.login()

In [None]:
# Open a Weights & Biases run and record the loaded config as its hyperparameters.
run = wandb.init(
    project='CLAP-Fine-tuning',  # wandb project that groups these runs
    name="First fine-tuning run",
    config=config,
)
# Continue with the run's config object (the same object exposed as `wandb.config`);
# later cells still index it dict-style, e.g. config["fine-tuning"]["batch_size"].
config = run.config

In [None]:
# Wrap each split in a DataLoader with the configured batch size.
# Only the training loader shuffles; val/test keep DataLoader's default
# (no shuffling), so they are iterated in a fixed order.
batch_size = config["fine-tuning"]["batch_size"]
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# Define model, optimizer, scheduler and loss function.
clap = Clap.from_ckpt(config_path=config_path, ckpt_path=clap_ckpt_path)
clap_clf = ClapAudioClassifier(clap=clap, config=config).to(device)

# Optimize the trainable parameters of the *whole* classifier, not just `clap`.
# Building the optimizer from `clap.parameters()` would leave any parameters that
# ClapAudioClassifier adds on top of the CLAP backbone (e.g. a classification head)
# un-optimized. Parameters with requires_grad=False receive no gradients, so they
# are excluded from both the count and the optimizer.
trainable_params = [p for p in clap_clf.parameters() if p.requires_grad]
print(f"Number of parameters to train: {sum(p.numel() for p in trainable_params)}")
optimizer = optim.Adam(trainable_params, lr=config["fine-tuning"]["learning_rate"])

# LR schedule: warmup, then a schedule over T_max = batches per epoch * epochs.
# NOTE(review): this assumes ClapFinetuner steps the scheduler once per batch,
# and the warmup length 31 is carried over unexplained — confirm both.
warmup_steps = 31
total_steps = len(train_loader) * config["fine-tuning"]["epochs"]
scheduler = create_scheduler(optimizer, warmup_steps=warmup_steps, T_max=total_steps, milestones=[warmup_steps])
loss_fn = torch.nn.CrossEntropyLoss()

trainer = ClapFinetuner(
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    model=clap_clf,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    epochs=config["fine-tuning"]["epochs"],
    # Log to wandb only if a run is active, i.e. the wandb.init cell was executed.
    enable_wandb_logging=wandb.run is not None,
)

In [None]:
# Fine-tune, then evaluate; the returned triple holds train/val/test metrics.
finetune_results = trainer.finetune_and_eval(
    ckpt_path=clf_ckpt_path,  # classifier checkpoint path
    early_stopping=False,     # early stopping disabled
)
train_metrics, val_metrics, test_metrics = finetune_results

In [None]:
# Close the active wandb run now that fine-tuning and evaluation are done.
wandb.finish()