In [1]:
import torch
from playgenie.model import VAE
from playgenie.loss import generative_loss
import config
from tqdm import tqdm
from playgenie.data.dataset import get_dataloader

In [2]:
DATA_PATH = '../data/spotify.safetensors'
MODEL_SAVE_PATH = '../data/model.pt'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
EPOCHS = 1000
SEED = 42

# Seed torch before building the model so weight initialisation (and any later
# torch-driven randomness) is reproducible across Restart & Run All.
torch.manual_seed(SEED)

# The training loop sends every batch to DEVICE, so the model must live there too;
# move it *before* building the optimizer so Adam tracks the on-device parameters.
model = VAE(
    input_size=config.INPUT_DIM,
    hidden_size=config.HIDDEN_DIM,
    latent_size=config.LATENT_DIM,
).to(DEVICE)
optim = torch.optim.Adam(model.parameters())

In [3]:

train_dataloader = get_dataloader(path=DATA_PATH, mode='train', batch_size=64)
validation_dataloader = get_dataloader(path=DATA_PATH, mode='validation', batch_size=64)

# Epoch losses are averaged over the number of batches; fail loudly on an empty
# split instead of hitting a ZeroDivisionError mid-training.
if len(train_dataloader) == 0 or len(validation_dataloader) == 0:
    raise ValueError(f"Empty train/validation dataloader for {DATA_PATH!r}")

# Batches are moved to DEVICE below, so the parameters must be there as well.
# Module.to() moves parameters in place, so `optim` keeps tracking the same tensors.
model.to(DEVICE)

best_loss = float('inf')

epoch_bar = tqdm(range(EPOCHS), desc='Training model', leave=False)
for epoch in epoch_bar:
    # The validation pass below switches the model to eval(); switch back every
    # epoch so dropout / batch-norm (if the VAE uses them) behave as in training.
    model.train()
    avg_loss = 0.0
    avg_val_loss = 0.0

    for inputs, targets in train_dataloader:
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        preds, mu, log_var = model(inputs)

        optim.zero_grad()
        loss = generative_loss(preds=preds, targets=targets, mu=mu, log_var=log_var)
        loss.backward()
        optim.step()

        avg_loss += loss.item()

    # Mean of per-batch losses (not weighted by batch size).
    avg_loss /= len(train_dataloader)

    model.eval()
    with torch.no_grad():
        for val_inputs, val_targets in validation_dataloader:
            val_inputs, val_targets = val_inputs.to(DEVICE), val_targets.to(DEVICE)
            val_preds, val_mu, val_log_var = model(val_inputs)

            val_loss = generative_loss(preds=val_preds, targets=val_targets, mu=val_mu, log_var=val_log_var)
            avg_val_loss += val_loss.item()

    avg_val_loss /= len(validation_dataloader)
    epoch_bar.set_postfix(Epoch=(epoch+1), Train_Loss=avg_loss, Validation_loss=avg_val_loss)

    # Checkpoint only when validation loss improves, so MODEL_SAVE_PATH always
    # holds the best model seen so far.
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
    if (epoch+1) % 100 == 0:
        tqdm.write(f"Epoch {epoch+1}: Train Loss={avg_loss:.4f}, Val Loss={avg_val_loss:.4f}")

Training model:  11%|█         | 108/1000 [00:02<00:20, 43.03it/s, Epoch=109, Train_Loss=0.205, Validation_loss=0.344]

Epoch 100: Train Loss=0.2085, Val Loss=0.3424


Training model:  20%|██        | 205/1000 [00:05<00:17, 45.22it/s, Epoch=209, Train_Loss=0.18, Validation_loss=0.35]  

Epoch 200: Train Loss=0.1813, Val Loss=0.3503


Training model:  31%|███       | 308/1000 [00:07<00:13, 49.99it/s, Epoch=309, Train_Loss=0.17, Validation_loss=0.355] 

Epoch 300: Train Loss=0.1709, Val Loss=0.3551


Training model:  40%|████      | 404/1000 [00:10<00:15, 38.87it/s, Epoch=407, Train_Loss=0.166, Validation_loss=0.356]

Epoch 400: Train Loss=0.1663, Val Loss=0.3562


Training model:  51%|█████     | 506/1000 [00:13<00:11, 43.23it/s, Epoch=508, Train_Loss=0.164, Validation_loss=0.353]

Epoch 500: Train Loss=0.1641, Val Loss=0.3529


Training model:  60%|██████    | 604/1000 [00:15<00:13, 28.98it/s, Epoch=605, Train_Loss=0.166, Validation_loss=0.355]

Epoch 600: Train Loss=0.1663, Val Loss=0.3488


Training model:  70%|███████   | 703/1000 [00:18<00:07, 40.31it/s, Epoch=708, Train_Loss=0.163, Validation_loss=0.34] 

Epoch 700: Train Loss=0.1686, Val Loss=0.3447


Training model:  81%|████████  | 808/1000 [00:21<00:04, 44.73it/s, Epoch=811, Train_Loss=0.162, Validation_loss=0.334]

Epoch 800: Train Loss=0.1617, Val Loss=0.3332


Training model:  91%|█████████ | 907/1000 [00:23<00:01, 46.64it/s, Epoch=909, Train_Loss=0.161, Validation_loss=0.336]

Epoch 900: Train Loss=0.1614, Val Loss=0.3349


                                                                                                                        

Epoch 1000: Train Loss=0.1612, Val Loss=0.3222


