In [None]:
# %load_ext autoreload
# %autoreload 2
import wandb
from torch.utils.data import DataLoader
from torch import optim
from clap import Clap
from clap.datasets import ClapDataset
from clap.training import ClapTrainer, create_scheduler, SymmetricCrossEntropyLoss
from clap.utils import get_target_device, load_clap_config, set_random_seed

# Stage 1: Train CLAP on audio captioning datasets AudioCaps and ClothoV2 (note: the cells below load only Clotho)

In [None]:
# Choose the encoder pair and the config version, then load the matching CLAP config
audio_encoder, text_encoder = "htsat-tiny", "gpt2"
cfg_version = "TestDistillation"
config = load_clap_config(
    audio_encoder=audio_encoder,
    text_encoder=text_encoder,
    version=cfg_version,
)

# Apply the seed stored in the config before any dataset or model is created
set_random_seed(config["training"]["seed"])

# Device that the model is moved to further below
device = get_target_device()

In [None]:
# Build the train / val / test splits of Clotho, all from the same config
train_dataset, val_dataset, test_dataset = (
    ClapDataset(config=config, kinds=[split], datasets=["Clotho"])
    for split in ("train", "val", "test")
)

In [None]:
# Log in to Weights & Biases for experiment logging.
# If logging is not wanted, just skip the wandb cells and set enable_wandb_logging
# to False in the trainer cell.
wandb.login()

In [None]:
# Open the wandb run for this stage in the 'CLAP-Training' project and hand it the
# loaded config so the hyperparameters are tracked with the run
wandb.init(project='CLAP-Training', name="Stage 1 only on Clotho", config=config)
# From here on `config` refers to wandb.config rather than the object returned by
# load_clap_config; the later cells index it the same way (config["training"][...])
config = wandb.config

In [None]:
# Wrap each split in a DataLoader with the configured batch size;
# only the training split is shuffled
batch_size = config["training"]["batch_size"]
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# Build the model from the config and move it to the target device
clap = Clap(config).to(device)

# Count only parameters with requires_grad=True so that the printed number matches its
# "to train" label even if part of the model is frozen (identical to the total otherwise)
print(f"Number of parameters to train: {sum(p.numel() for p in clap.parameters() if p.requires_grad)}")

In [None]:
# Optimizer: AdamW over all model parameters, learning rate from the config, no weight decay
optimizer = optim.AdamW(
    clap.parameters(),
    lr=config["training"]["learning_rate"],
    betas=(0.9, 0.999),
    weight_decay=0,
)

# Scheduler: T_max is (batches per epoch) x (stage-1 epochs); the warmup length and the
# single milestone share the same value of 300
warmup_steps = 300
total_steps = len(train_loader) * config["training"]["stage1_epochs"]
scheduler = create_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    T_max=total_steps,
    milestones=[warmup_steps],
)

# Training loss
loss_fn = SymmetricCrossEntropyLoss()

In [None]:
# Assemble the stage-1 trainer from the model, optimization objects and loaders defined above
stage1_trainer = ClapTrainer(
    model=clap,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=config["training"]["stage1_epochs"],
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    enable_wandb_logging=True,
)

In [None]:
# Run stage-1 training and evaluation (early stopping switched off) and keep the
# three returned metric objects for train / val / test
stage1_train_metrics, stage1_val_metrics, stage1_test_metrics = stage1_trainer.train_and_eval(
    audio_encoder=audio_encoder,
    text_encoder=text_encoder,
    version="Stage1_Clotho_test",
    early_stopping=False,
)

In [None]:
# Close the wandb run that was opened for stage 1
wandb.finish()

# Stage 2: Continue training by distilling soft targets from pre-trained CLAP models

In [None]:
# Reload the CLAP config for stage 2 (in stage 1 `config` was rebound to wandb.config),
# using the same encoder pair and config version
audio_encoder, text_encoder = "htsat-tiny", "gpt2"
cfg_version = "TestDistillation"
config = load_clap_config(
    audio_encoder=audio_encoder,
    text_encoder=text_encoder,
    version=cfg_version,
)

# Re-apply the seed stored in the config before datasets and models are created again
set_random_seed(config["training"]["seed"])

# Device that the student and distillation models are moved to further below
device = get_target_device()

In [None]:
# Rebuild the train / val / test splits of Clotho from the freshly loaded config
train_dataset, val_dataset, test_dataset = (
    ClapDataset(config=config, kinds=[split], datasets=["Clotho"])
    for split in ("train", "val", "test")
)

In [None]:
# Log in to Weights & Biases for experiment logging (the stage-1 section calls this too).
# If logging is not wanted, just skip the wandb cells and set enable_wandb_logging
# to False in the trainer cell.
wandb.login()

In [None]:
# Open a separate wandb run for stage 2 in the 'CLAP-Training' project and hand it the
# reloaded config so the hyperparameters are tracked with the run
wandb.init(project='CLAP-Training', name="Stage 2 with distillation", config=config)
# From here on `config` refers to wandb.config rather than the object returned by
# load_clap_config; the later cells index it the same way (config["training"][...])
config = wandb.config

In [None]:
# Wrap each split in a DataLoader with the configured batch size;
# only the training split is shuffled
batch_size = config["training"]["batch_size"]
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# Student model: restored from a checkpoint instead of being built from scratch,
# then moved to the target device
# NOTE(review): the stage-1 cell calls train_and_eval(version="Stage1_Clotho_test"),
# while this loads ckpt_version="Stage1_Clotho" -- confirm this is the intended
# checkpoint and not a stale one from an earlier run
clap = Clap.from_ckpt(
    audio_encoder=audio_encoder,
    text_encoder=text_encoder,
    cfg_version=cfg_version,
    ckpt_version="Stage1_Clotho",
).to(device)

In [None]:
# Optimizer: AdamW over all student parameters, learning rate from the config, no weight decay
optimizer = optim.AdamW(
    clap.parameters(),
    lr=config["training"]["learning_rate"],
    betas=(0.9, 0.999),
    weight_decay=0,
)

# Scheduler: T_max is (batches per epoch) x (stage-2 epochs); the warmup length and the
# single milestone share the same value of 300
warmup_steps = 300
total_steps = len(train_loader) * config["training"]["stage2_epochs"]
scheduler = create_scheduler(
    optimizer,
    warmup_steps=warmup_steps,
    T_max=total_steps,
    milestones=[warmup_steps],
)

# Training loss
loss_fn = SymmetricCrossEntropyLoss()

In [None]:
# Models handed to the trainer as `distill_from`, and the weight handed over as `distill_weight`
# NOTE(review): this model is loaded with exactly the same arguments as the student `clap`
# (same encoders, cfg_version and ckpt_version) -- confirm that is intended.
# NOTE(review): the stage-1 cell calls train_and_eval(version="Stage1_Clotho_test"), while
# ckpt_version here is "Stage1_Clotho" -- confirm the checkpoint name.
distill_model1 = Clap.from_ckpt(
    audio_encoder=audio_encoder,
    text_encoder=text_encoder,
    cfg_version=cfg_version,
    ckpt_version="Stage1_Clotho",
).to(device)

# Freeze its encoders and switch it to eval mode before it is used for distillation
distill_model1.freeze_encoders()
distill_model1.eval()

# Currently a single distillation model; `distill_from` is an alias of the same list
distill_models = [distill_model1]
distill_from = distill_models
distill_weight = 1

In [None]:
# Assemble the stage-2 trainer: same wiring as stage 1 plus the distillation
# models and the weight of the distillation term
stage2_trainer = ClapTrainer(
    model=clap,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=config["training"]["stage2_epochs"],
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    distill_from=distill_from,
    distill_weight=distill_weight,
    enable_wandb_logging=True,
)

In [None]:
# Run stage-2 training and evaluation (early stopping switched on) and keep the
# three returned metric objects for train / val / test
stage2_train_metrics, stage2_val_metrics, stage2_test_metrics = stage2_trainer.train_and_eval(
    audio_encoder=audio_encoder,
    text_encoder=text_encoder,
    version="Stage2_Distillation_New",
    early_stopping=True,
)

In [None]:
# Close the wandb run that was opened for stage 2
wandb.finish()