In [None]:
%load_ext autoreload
%autoreload 2
import wandb as wb
from clap import Clap, ClapDataset, ClapTrainer, SymmetricCrossEntropyLoss, get_target_device, load_config
from torch.utils.data import DataLoader
from torch import optim

In [None]:
# Path to the `.yml` experiment configuration; reused below when building the datasets.
config_path = "clap/configs/clap_htsat-tiny_gpt2.yml"
# Load the configuration through the project's `load_config` helper.
config = load_config(config_path)
# Target device returned by the project's `get_target_device` helper; the model is moved to it later.
device = get_target_device()

In [None]:
# Seed the run. Passing None presumably lets the trainer choose the seed itself
# (TODO confirm) -- replace it with a fixed integer if a repeatable run is needed.
seed = ClapTrainer.set_random_seed(None)

# Build one dataset per split, in train/val/test order, all from the same config file.
train_dataset, val_dataset, test_dataset = (
    ClapDataset(config=config_path, kind=split)
    for split in ("train", "val", "test")
)

In [None]:
# Authenticate this session with Weights & Biases before a run is started.
wb.login()

In [None]:
# Start a Weights & Biases run in the 'Custom-CLAP' project and hand it the
# loaded configuration so the hyperparameters are tracked with the run.
wb.init(project='Custom-CLAP', name="Test run with hdf5 file", config=config)

# From here on `config` refers to the run's config object rather than the
# value returned by load_config.
config = wb.config

In [None]:
# Wrap each split in a DataLoader; only the training split is shuffled.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
)
val_loader = DataLoader(dataset=val_dataset, batch_size=config["batch_size"], shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=config["batch_size"], shuffle=False)

In [None]:
# Define the model, optimizer and loss function.
model = Clap(config).to(device)

# Report only the parameters that will actually receive gradient updates:
# frozen parameters (requires_grad=False) are excluded from the count.
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of parameters to train: {num_trainable}")

optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])
loss_fn = SymmetricCrossEntropyLoss()
# Bundle the loaders, model, optimizer, loss and epoch count into the trainer.
trainer = ClapTrainer(train_loader, val_loader, test_loader, model, optimizer, loss_fn, config["epochs"])

In [None]:
# Run training and evaluation and keep the returned value in `metrics`.
# NOTE(review): "checkpoints/test.ckpt" is presumably the checkpoint path; the meaning of the
# positional None and False is not visible here -- confirm against ClapTrainer.train_and_eval.
metrics = trainer.train_and_eval("checkpoints/test.ckpt", None, False)

In [None]:
# Close the active Weights & Biases run.
wb.finish()