In [1]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import wandb as wb
from torch import optim
from torch.utils.data import DataLoader

from clap import Clap, ClapDataset, ClapTrainer, SymmetricCrossEntropyLoss, get_target_device, load_config

In [2]:
# Experiment configuration. The YAML file is parsed once here, and the same
# path is also handed to each dataset split in the next cell.
config_path = "clap/configs/clap_htsat-tiny_gpt2.yml"
device = get_target_device()
config = load_config(config_path)

In [9]:
# Random seed for this run. With None, ClapTrainer.set_random_seed picks one
# and prints it (see the output below). To reproduce an earlier run, set SEED
# to that printed value. NOTE(review): assumes set_random_seed accepts an int
# and returns the seed it used — confirm against ClapTrainer.
SEED = None
seed = ClapTrainer.set_random_seed(SEED)

# One dataset per split, all built from the same YAML config file.
train_dataset = ClapDataset(config=config_path, kind="train")
val_dataset = ClapDataset(config=config_path, kind="val")
test_dataset = ClapDataset(config=config_path, kind="test")

Random seed set as 2831676980


In [4]:
# Authenticate with Weights & Biases before starting a run. The cell displays
# the returned value (True in the output below).
wb.login()

[34m[1mwandb[0m: Currently logged in as: [33mleonakkad[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
# Start a W&B run and log the loaded YAML config as its hyperparameters.
wb.init(
    project="Custom-CLAP",  # W&B project this run is logged under
    name="Test run with hdf5 file",
    config=config,
)
# From here on, `config` is the W&B run config (still indexable by key), and
# it is what the loaders, model and trainer below read from.
config = wb.config
# Record the RNG seed with the run. Otherwise it only appears in the notebook's
# printed output, and the run cannot be reproduced from the W&B dashboard.
config.update({"seed": seed})

In [10]:
# define data loaders
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"])
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"])

In [11]:
# define model, optimize and loss function
model = Clap(config).to(device)
print(f"Number of parameters to train: {sum(p.numel() for p in model.parameters())}")
optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])
loss_fn = SymmetricCrossEntropyLoss()
trainer = ClapTrainer(train_loader, val_loader, test_loader, model, optimizer, loss_fn, config["epochs"])



Number of parameters to train: 167871888


In [12]:
# Run the trainer's train/eval loop. One epoch takes hours (see the progress
# bar below), so make sure the checkpoint directory exists up front. It is not
# visible here whether train_and_eval creates it itself.
checkpoint_path = Path("checkpoints/test.ckpt")
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
# NOTE(review): ClapTrainer.train_and_eval defines what the positional
# arguments `None` and `False` mean — pass them by keyword once confirmed.
metrics = trainer.train_and_eval(str(checkpoint_path), None, False)


Starting to train Model


Training epoch 1 (lr=array([0.0001])):  31%|███▏      | 314/1002 [48:39<1:46:36,  9.30s/it]


KeyboardInterrupt: 

In [13]:
# Close the W&B run so the logged metrics are uploaded (see the output below).
# Run this cell even if training was interrupted, as it was here.
wb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/a2t/batch recall@1,▁▁▁▁▂▂▂▁▃▂▂▂▃▃▂▃▅▄▄▅▅▆▃▆▆▆▆▇▆▅▅▆▅▅▇█▄▇█▆
train/a2t/batch recall@10,▁▁▁▁▂▂▄▁▃▃▄▄▄▅▄▅▆▆▆▅▅▅▆▇▆▆▆▆▇▆▆▇▇▅▇▇█▇█▇
train/a2t/batch recall@5,▁▁▂▁▂▂▃▁▃▃▃▃▃▅▄▄▅▅▅▅▅▅▅▆▆▅▆▅▇▇▆▇▆▅▇▇▇▇█▇
train/batch loss,████▇▇▆█▆▇▆▅▆▅▅▅▄▄▄▅▄▄▄▄▃▄▃▃▂▃▃▂▃▅▂▃▂▂▁▂
train/step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/t2a/batch recall@1,▁▁▁▁▁▁▂▁▂▃▃▄▃▄▃▃▄▃▃▄▄▄▃▃▆▆▇▄▇▆▅▆▅▃▆▆▅▇▇█
train/t2a/batch recall@10,▁▂▁▁▂▃▄▁▃▄▄▄▄▅▅▄▅▆▅▅▆▅▆▆▆▆▇▅▇▇▇▇▇▅▇▆▇▇█▇
train/t2a/batch recall@5,▁▂▁▂▂▃▃▁▂▃▃▃▃▄▄▃▅▆▅▄▅▅▅▆▆▅▇▅▇▅▇▆▇▅▇▇▇▇█▇

0,1
train/a2t/batch recall@1,0.23438
train/a2t/batch recall@10,0.79688
train/a2t/batch recall@5,0.64062
train/batch loss,2.6378
train/step,313.0
train/t2a/batch recall@1,0.29688
train/t2a/batch recall@10,0.79688
train/t2a/batch recall@5,0.625
