In [None]:
# %load_ext autoreload
# %autoreload 2
import wandb
from torch.utils.data import DataLoader
from torch import optim
from clap import Clap
from clap.datasets import ClapDataset
from clap.training import ClapTrainer, create_scheduler, SymmetricCrossEntropyLoss
from clap.utils import get_target_device, load_clap_config, set_random_seed

# Stage 1: Train CLAP on audio captioning datasets AudioCaps and ClothoV2 (this notebook run loads only Clotho)

In [None]:
# Stage 1 setup: load the CLAP configuration and pick the device to train on.
# NOTE(review): both paths are absolute paths on one specific Windows machine;
# adjust them before running this notebook anywhere else.
config_path = r"C:\Users\leon\Documents\ML_Projects\Custom-CLAP\clap\configs\clap_cnn14_distilroberta-base_vTestDistillation.yml"
# Checkpoint file handed to train_and_eval at the end of Stage 1.
ckpt_path = r"C:\Users\leon\Documents\ML_Projects\Custom-CLAP\clap\checkpoints\clap_cnn14_distilroberta-base_vStage1_Clotho.ckpt"
# Load the configuration from the .yml file above.
config = load_clap_config(config_path)
# Device the model is moved to when it is created.
device = get_target_device()

In [None]:
# Log in to Weights & Biases so the Stage 1 run can be logged.
# If wandb logging is not wanted, skip this cell and set enable_wandb_logging to False
# when the trainer is created.
# NOTE(review): the wandb.init cell below also defines config_stage1 and sets the seed,
# so those two lines still have to run even when logging is disabled.
wandb.login()

In [None]:
# Start the wandb run for Stage 1 and track the loaded config as its hyperparameters.
wandb.init(
    project="CLAP-Training",  # wandb project this run is logged to
    name="Stage 1 only on Clotho",
    config=config,  # hyperparameters to track
)
# From here on `config` refers to the config object held by the wandb run.
config = wandb.config

# Stage 1 training hyperparameters; the seed is applied before any data or model is created.
config_stage1 = config["training"]["stage1"]
set_random_seed(config_stage1["seed"])

In [None]:
# Build the Clotho train / val / test datasets from the local dataset folder.
# The three datasets differ only in the requested split, so they are created in one pass
# (in the order train, val, test).
dataset_paths = [r"C:\Users\leon\Documents\ML_Projects\Custom-CLAP\clap\datasets\clotho"]
train_dataset, val_dataset, test_dataset = (
    ClapDataset(config=config, kinds=[split], datasets=["Clotho"], datasets_paths=dataset_paths)
    for split in ("train", "val", "test")
)

In [None]:
# Wrap every split in a DataLoader using the Stage 1 batch size.
# Only the training loader shuffles; validation and test keep the dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=config_stage1["batch_size"],
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config_stage1["batch_size"],
)
test_loader = DataLoader(
    test_dataset,
    batch_size=config_stage1["batch_size"],
)

In [None]:
# Build the CLAP model from the config and move it to the target device.
clap = Clap(config).to(device)
# Count only parameters with requires_grad=True: frozen weights receive no gradient and are
# not updated by the optimizer, so they must not be reported as "parameters to train".
print(f"Number of parameters to train: {sum(p.numel() for p in clap.parameters() if p.requires_grad)}")

In [None]:
# Optimizer: AdamW over all model parameters, configured from the Stage 1 hyperparameters.
optimizer = optim.AdamW(
    clap.parameters(),
    lr=config_stage1["learning_rate"],
    betas=config_stage1["betas"],
    weight_decay=config_stage1["weight_decay"],
)
# Scheduler: warmup and annealing lengths are given in epochs in the config and multiplied by
# len(train_loader) (batches per epoch) -- presumably because the scheduler is stepped once per batch.
# Note that T_max evaluates as (batches_per_epoch * annealing_epochs) - 1.
scheduler = create_scheduler(
    optimizer,
    warmup_steps=len(train_loader) * config_stage1["warmup_epochs"],
    T_max=len(train_loader) * config_stage1["annealing_epochs"] - 1,
    min_lr=1e-6,
)
# Training objective.
loss_fn = SymmetricCrossEntropyLoss()

In [None]:
# Assemble the Stage 1 trainer from the model, optimisation objects and data loaders.
stage1_trainer = ClapTrainer(
    # What is trained and how
    model=clap,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    # Data
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    # Run settings
    epochs=config_stage1["epochs"],
    enable_wandb_logging=True,  # set to False to train without wandb
)

In [None]:
# Run Stage 1 training and evaluation with early stopping disabled.
# The three results are unpacked as train, validation and test metrics.
# NOTE(review): ckpt_path is presumably where the trainer saves the model -- Stage 2 loads
# its models from this same path; confirm in ClapTrainer.train_and_eval.
stage1_train_metrics, stage1_val_metrics, stage1_test_metrics = stage1_trainer.train_and_eval(ckpt_path, early_stopping=False)

In [None]:
# Close the Stage 1 wandb run so that Stage 2 is logged as a separate run.
wandb.finish()

# Stage 2: Continue training by distilling soft-targets from pre-trained CLAP models

In [None]:
# Stage 2 setup: reload the CLAP configuration and pick the device to train on.
# Both paths are the same files that Stage 1 uses.
# NOTE(review): both paths are absolute paths on one specific Windows machine;
# adjust them before running this notebook anywhere else.
config_path = r"C:\Users\leon\Documents\ML_Projects\Custom-CLAP\clap\configs\clap_cnn14_distilroberta-base_vTestDistillation.yml"
# Checkpoint that the Stage 2 models are loaded from (and that is handed to train_and_eval).
ckpt_path = r"C:\Users\leon\Documents\ML_Projects\Custom-CLAP\clap\checkpoints\clap_cnn14_distilroberta-base_vStage1_Clotho.ckpt"
# Load the configuration from the .yml file above.
config = load_clap_config(config_path)
# Device the models are moved to when they are created.
device = get_target_device()

In [None]:
# Log in to Weights & Biases so the Stage 2 run can be logged.
# If wandb logging is not wanted, skip this cell and set enable_wandb_logging to False
# when the trainer is created.
# NOTE(review): the wandb.init cell below also defines config_stage2 and sets the seed,
# so those two lines still have to run even when logging is disabled.
wandb.login()

In [None]:
# Start the wandb run for Stage 2 and track the loaded config as its hyperparameters.
wandb.init(
    project="CLAP-Training",  # wandb project this run is logged to
    name="Stage 2 with distillation",
    config=config,  # hyperparameters to track
)
# From here on `config` refers to the config object held by the wandb run.
config = wandb.config

# Stage 2 training hyperparameters; the seed is applied before any data or model is created.
config_stage2 = config["training"]["stage2"]
set_random_seed(config_stage2["seed"])

In [None]:
# Build the Clotho train / val / test datasets from the local dataset folder.
# The three datasets differ only in the requested split, so they are created in one pass
# (in the order train, val, test).
dataset_paths = [r"C:\Users\leon\Documents\ML_Projects\Custom-CLAP\clap\datasets\clotho"]
train_dataset, val_dataset, test_dataset = (
    ClapDataset(config=config, kinds=[split], datasets=["Clotho"], datasets_paths=dataset_paths)
    for split in ("train", "val", "test")
)

In [None]:
# Wrap every split in a DataLoader using the Stage 2 batch size.
# Only the training loader shuffles; validation and test keep the dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=config_stage2["batch_size"],
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config_stage2["batch_size"],
)
test_loader = DataLoader(
    test_dataset,
    batch_size=config_stage2["batch_size"],
)

In [None]:
# Student model for Stage 2: restored from the checkpoint at ckpt_path (using the config file
# at config_path) and moved to the target device.
clap = Clap.from_ckpt(
    config_path=config_path,
    ckpt_path=ckpt_path,
).to(device)

In [None]:
# Optimizer: AdamW over all model parameters, configured from the Stage 2 hyperparameters.
optimizer = optim.AdamW(
    clap.parameters(),
    lr=config_stage2["learning_rate"],
    betas=config_stage2["betas"],
    weight_decay=config_stage2["weight_decay"],
)
# Scheduler: warmup and annealing lengths are given in epochs in the config and multiplied by
# len(train_loader) (batches per epoch) -- presumably because the scheduler is stepped once per batch.
# Note that T_max evaluates as (batches_per_epoch * annealing_epochs) - 1.
scheduler = create_scheduler(
    optimizer,
    warmup_steps=len(train_loader) * config_stage2["warmup_epochs"],
    T_max=len(train_loader) * config_stage2["annealing_epochs"] - 1,
    min_lr=1e-6,
)
# Training objective.
loss_fn = SymmetricCrossEntropyLoss()

In [None]:
# Teacher model(s) to distill from, and the weight handed to the trainer as distill_weight.
# Each teacher has its encoders frozen and is put into eval mode before use.
# NOTE(review): the teacher is loaded from the same config and checkpoint as the student
# model `clap` -- confirm this is intended and not a placeholder for another pre-trained CLAP.
distill_model1 = Clap.from_ckpt(config_path=config_path, ckpt_path=ckpt_path).to(device)
distill_model1.freeze_encoders()
distill_model1.eval()
distill_models = [distill_model1]

distill_from = distill_models
distill_weight = 1

In [None]:
# Assemble the Stage 2 trainer: same ingredients as a plain training run, plus the
# teacher models and the distillation weight.
stage2_trainer = ClapTrainer(
    # What is trained and how
    model=clap,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    # Data
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    # Distillation
    distill_from=distill_from,
    distill_weight=distill_weight,
    # Run settings
    epochs=config_stage2["epochs"],
    enable_wandb_logging=True,  # set to False to train without wandb
)

In [None]:
# Run Stage 2 training and evaluation with early stopping enabled.
# The three results are unpacked as train, validation and test metrics.
# NOTE(review): ckpt_path here is the same file the Stage 2 student and teacher were loaded
# from; if train_and_eval saves to this path, the Stage 1 checkpoint is overwritten -- confirm.
stage2_train_metrics, stage2_val_metrics, stage2_test_metrics = stage2_trainer.train_and_eval(ckpt_path, early_stopping=True)

In [None]:
# Close the Stage 2 wandb run.
wandb.finish()