In [1]:
import torch
from playgenie.model import VAE
from playgenie.loss import generative_loss
import config
from tqdm import tqdm
from playgenie.data.dataset import get_dataloader

In [3]:
# Seed torch before building the model so weight initialisation (and any later
# torch-driven randomness such as shuffling) is reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

model = VAE(
    input_size=384,      # width of each input feature vector; the encoder's first Linear layers take 384 features
    hidden_size=10,
    latent_size=2,       # mean / log-variance heads output 2 values each
    encoder_n_heads=2,
    decoder_n_la=2,      # NOTE(review): abbreviated keyword, presumably the number of decoder layers — confirm against VAE
)
optim = torch.optim.Adam(model.parameters())

In [4]:
# Display the module tree to sanity-check layer sizes (encoder 384 -> hidden 10 -> latent 2).
model

VAE(
  (encoder): Encoder(
    (attention_block): AttentionBlock(
      (mha): MultiHeadAttention(
        (q_proj): Linear(in_features=384, out_features=384, bias=True)
        (k_proj): Linear(in_features=384, out_features=384, bias=True)
        (v_proj): Linear(in_features=384, out_features=384, bias=True)
        (out_proj): Linear(in_features=384, out_features=384, bias=True)
      )
      (ff): Linear(in_features=384, out_features=384, bias=True)
      (relu): ReLU()
    )
    (to_hidden): Linear(in_features=384, out_features=10, bias=True)
    (output_mean): Linear(in_features=10, out_features=2, bias=True)
    (output_logvar): Linear(in_features=10, out_features=2, bias=True)
    (relu): ReLU()
  )
  (decoder): Decoder(
    (input_layer): Linear(in_features=2, out_features=10, bias=True)
    (mid_layers): Sequential(
      (0): Linear(in_features=10, out_features=10, bias=True)
      (1): ReLU()
      (2): Linear(in_features=10, out_features=10, bias=True)
      (3): ReLU()
  

In [5]:
# Smoke-test one forward pass on an all-zeros dummy batch of shape (1, 10, 384).
# Assumes the model takes (batch, seq_len, features) with features == input_size
# — TODO confirm against VAE.forward.
# The tensor is created on the model's current device so this cell keeps working
# after the model has been moved (e.g. to config.DEVICE by the training cell).
# no_grad: this is only a shape check, so there is no need to build an autograd graph.
with torch.no_grad():
    dummy_inputs = torch.zeros(1, 10, 384, device=next(model.parameters()).device)
    preds, mu, log_var = model(dummy_inputs)

In [7]:
# Latent mean shape for the dummy batch: expect (batch, latent_size).
mu.size()

torch.Size([1, 2])

In [6]:

train_dataloader = get_dataloader(path=config.DATA_PATH, mode='train', batch_size=64)
validation_dataloader = get_dataloader(path=config.DATA_PATH, mode='validation', batch_size=64)

# Per-epoch averages divide by the number of batches. Fail fast with a clear
# message instead of hitting a ZeroDivisionError after a full pass.
assert len(train_dataloader) > 0, "train dataloader yielded no batches"
assert len(validation_dataloader) > 0, "validation dataloader yielded no batches"

LOG_EVERY = 100  # epochs between permanent summary lines written below the progress bar

# Batches are moved to config.DEVICE below, so the model has to live there too.
# Module.to() modifies the module in place, so `optim` (built earlier from
# model.parameters()) keeps tracking the same parameters.
model.to(config.DEVICE)

best_loss = float('inf')

epoch_bar = tqdm(range(config.EPOCHS), desc='Training model', leave=False)
for epoch in epoch_bar:
    # The validation pass below switches the model to eval mode. Switch back at
    # the start of every epoch, otherwise all epochs after the first would train
    # in eval mode.
    model.train()
    avg_loss = 0.0
    avg_val_loss = 0.0

    for inputs, targets in train_dataloader:
        inputs, targets = inputs.to(config.DEVICE), targets.to(config.DEVICE)
        preds, mu, log_var = model(inputs)

        optim.zero_grad()
        loss = generative_loss(preds=preds, targets=targets, mu=mu, log_var=log_var)
        loss.backward()
        optim.step()

        avg_loss += loss.item()

    avg_loss /= len(train_dataloader)  # mean loss per batch

    model.eval()
    with torch.no_grad():
        for val_inputs, val_targets in validation_dataloader:
            val_inputs, val_targets = val_inputs.to(config.DEVICE), val_targets.to(config.DEVICE)
            val_preds, val_mu, val_log_var = model(val_inputs)

            val_loss = generative_loss(preds=val_preds, targets=val_targets, mu=val_mu, log_var=val_log_var)
            avg_val_loss += val_loss.item()

    avg_val_loss /= len(validation_dataloader)  # mean loss per batch
    epoch_bar.set_postfix(Epoch=(epoch + 1), Train_Loss=avg_loss, Validation_loss=avg_val_loss)

    # Checkpoint only when the validation loss improves on the best seen so far.
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        torch.save(model.state_dict(), config.MODEL_SAVE_PATH)
    if (epoch + 1) % LOG_EVERY == 0:
        tqdm.write(f"Epoch {epoch+1}: Train Loss={avg_loss:.4f}, Val Loss={avg_val_loss:.4f}")

FileNotFoundError: No such file or directory: ./data/spotify.safetensors