In [None]:
# %load_ext autoreload
# %autoreload 2
import wandb
from torch.utils.data import DataLoader
from torch import optim
from clap import Clap
from clap.datasets import ClapDataset
from clap.training import ClapTrainer, create_scheduler, SymmetricCrossEntropyLoss
from clap.utils import get_target_device, load_clap_config, set_random_seed

# Train CLAP on the AudioCaps and Clotho audio-captioning datasets

In [None]:
# Pick the audio/text encoder pair and the config/checkpoint versions for this run,
# load the matching CLAP config, then resolve the device to train on.
audio_encoder, text_encoder = "htsat-tiny", "gpt2"
cfg_version = 1   # version of the CLAP config passed to load_clap_config
ckpt_version = 1  # checkpoint version; NOTE(review): the training cell should pass this as `version=`
config = load_clap_config(
    audio_encoder=audio_encoder,
    text_encoder=text_encoder,
    version=cfg_version,
)
device = get_target_device()

In [None]:
# Seed the RNGs. With None, set_random_seed presumably draws a fresh seed and returns it
# (kept in `seed`) — TODO confirm against clap.utils.
seed = set_random_seed(None)

# One ClapDataset per split, all drawn from the same caption corpora.
# Each split gets its own fresh list of dataset names; splits are built in train -> val -> test order.
dataset_names = ("AudioCaps", "Clotho")
train_dataset, val_dataset, test_dataset = (
    ClapDataset(config=config, kinds=[split], datasets=list(dataset_names))
    for split in ("train", "val", "test")
)

In [None]:
# Authenticate with Weights & Biases so the run started by wandb.init in the next cell can log.
wandb.login()

In [None]:
# Start a W&B run in the "CLAP-Training" project and track the loaded CLAP config as hyperparameters.
wandb.init(
    project="CLAP-Training",
    name="First correct run",
    config=config,
)
# `seed` (from the dataset cell) is otherwise not recorded anywhere, which makes the run
# unreproducible. Store it under a dedicated key (assumed not to exist in the CLAP config).
wandb.config.update({"random_seed": seed})
# Continue with wandb.config, presumably so values W&B overrides (e.g. sweep parameters)
# reach the loaders, model and trainer below.
# NOTE(review): re-running this cell passes the previous run's wandb.config to wandb.init.
config = wandb.config

In [None]:
# All three loaders share one batch size; only the training loader shuffles (reshuffled every epoch).
batch_size = config["training"]["batch_size"]
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(eval_split, batch_size=batch_size) for eval_split in (val_dataset, test_dataset)
)

In [None]:
# Build the CLAP model on the target device, plus its optimizer, LR schedule,
# symmetric contrastive loss, and the trainer that drives training and evaluation.
training_cfg = config["training"]
# LR warm-up length, also used as the scheduler's single milestone. Presumably counted in
# scheduler steps, since T_max below is measured in batches (len(train_loader) * epochs).
WARMUP_STEPS = 1000

clap = Clap(config).to(device)
# Count only parameters that will receive gradients; frozen (requires_grad=False) weights
# would otherwise inflate the "to train" figure.
n_trainable = sum(p.numel() for p in clap.parameters() if p.requires_grad)
print(f"Number of parameters to train: {n_trainable}")

optimizer = optim.Adam(clap.parameters(), lr=training_cfg["learning_rate"])
scheduler = create_scheduler(
    optimizer,
    warmup_steps=WARMUP_STEPS,
    T_max=len(train_loader) * training_cfg["epochs"],
    milestones=[WARMUP_STEPS],
)
loss_fn = SymmetricCrossEntropyLoss()
trainer = ClapTrainer(
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    model=clap,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    epochs=training_cfg["epochs"],
)

In [None]:
# Train and evaluate; the result is unpacked into train/val/test metrics.
# Use the `ckpt_version` chosen in the config cell rather than a hard-coded 1, so the two cannot drift apart.
train_metrics, val_metrics, test_metrics = trainer.train_and_eval(
    audio_encoder=audio_encoder,
    text_encoder=text_encoder,
    version=ckpt_version,
    early_stopping=False,
)

In [None]:
# Finish the W&B run started by wandb.init above.
wandb.finish()