In [None]:
# Training function for CANM and TransformerVAE
def fit_model(model_class, traindata, N=1, beta=0.1, batch_size=128, epochs=450, learning_rate=1e-5, prior_sdy=0.95, update_sdy=True, verbose=False):
    """Build ``model_class(N)`` and train it on ``traindata`` with AdamW.

    Args:
        model_class: Callable invoked as ``model_class(N)`` that returns a
            ``torch.nn.Module``. Calling the module on a batch must return
            three values, unpacked here as ``(yhat, mu, logvar)``.
        traindata: 2-D NumPy array with at least two columns. Column 1 of
            each batch is used as the target ``y``; the whole batch is fed
            to the model.
        N: Single argument forwarded to ``model_class`` (presumably a
            latent dimension -- confirm against the model definitions).
        beta: Forwarded unchanged to ``loss_function``.
        batch_size: Mini-batch size for the shuffled ``DataLoader``.
        epochs: Number of full passes over ``traindata``.
        learning_rate: AdamW learning rate.
        prior_sdy: Initial value of the one-element ``sdy`` tensor passed to
            ``loss_function``.
        update_sdy: If true, ``sdy`` is optimised together with the model
            parameters; otherwise it stays at ``prior_sdy``.
        verbose: If true, print one line per epoch.

    Returns:
        ``(model, score)`` where ``score`` is a list with one entry per
        epoch: the *negated* mean per-batch loss of that epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model_class(N).to(device)
    traindata = torch.from_numpy(traindata).float()
    train_loader = torch.utils.data.DataLoader(traindata, batch_size=batch_size, shuffle=True)

    # `sdy` must live on the same device as the model so the loss can mix them.
    sdy = torch.tensor([prior_sdy], requires_grad=True, device=device)
    if update_sdy:
        # Second parameter group lets AdamW update `sdy` alongside the model.
        params = [{'params': model.parameters()}, {'params': [sdy]}]
    else:
        params = model.parameters()
    optimizer = optim.AdamW(params, lr=learning_rate)

    score = []
    for epoch in range(1, epochs + 1):
        train_loss = 0
        for data in train_loader:
            # Batches come off the DataLoader on the CPU; move them to the
            # model's device, otherwise the forward pass fails on CUDA.
            data = data.to(device)
            optimizer.zero_grad()
            y = data[:, 1].view(-1, 1)
            yhat, mu, logvar = model(data)
            loss = loss_function(y, yhat, mu, logvar, sdy, beta)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        # The sign is flipped: the recorded score is the negative mean loss
        # (higher is better), even though the log line labels it "Avg Loss".
        avg_loss = -train_loss / len(train_loader)
        score.append(avg_loss)
        if verbose:
            print(f'Epoch {epoch} - Avg Loss: {avg_loss:.4f}')
    return model, score