# In [None]:
# Training function
def fit_model(model_class, traindata, testdata=None, N=1, beta=1, batch_size=128, epochs=450, logpx=None, cuda=False, seed=0,
              log_interval=10, learning_rate=1e-4, prior_sdy=0.90, update_sdy=True, preload=False, warming_up=False, verbose=False, debug=False,
              beta_warmup_epochs=100):
    """Train ``model_class(N)`` with Adam and return per-epoch scores.

    Each row of ``traindata`` / ``testdata`` is one sample (the DataLoader
    iterates over the first axis). Column 1 of each batch is used as the
    target ``y``, and the whole batch is passed to the model, which must return
    ``(yhat, mu, logvar)``. These, together with ``sdy``, the current beta and
    ``logpx``, are passed to the module-level ``loss_function``.

    Args:
        model_class: Callable building the model from ``N``; the result must
            support ``.to``, ``.train``, ``.eval`` and ``.parameters``.
        traindata: 2-D numpy array, one sample per row, at least 2 columns.
        testdata: Optional array with the same layout, evaluated every epoch.
        N: Passed to ``model_class``.
        beta: Target beta passed to ``loss_function``. For positive ``beta``
            it is ramped linearly from ``beta / beta_warmup_epochs`` up to
            ``beta`` over the first ``beta_warmup_epochs`` epochs.
        batch_size: Mini-batch size for both loaders.
        epochs: Number of training epochs; must be >= 1.
        logpx: Constant passed to ``loss_function``. When None, it is
            estimated as the mean log Gaussian-KDE density of column 0 of
            ``traindata`` evaluated at the training points.
        cuda: Use the GPU if one is available.
        seed: Seed for the torch and numpy RNGs.
        log_interval: When ``verbose``, print losses every this many epochs.
        learning_rate: Adam learning rate.
        prior_sdy: Initial value of the scalar ``sdy`` tensor.
        update_sdy: If True, ``sdy`` is optimised jointly with the model;
            otherwise it stays fixed at ``prior_sdy``.
        preload: Currently unused (kept for API compatibility).
        warming_up: Currently unused (kept for API compatibility).
            NOTE(review): the beta ramp always runs; confirm whether it was
            meant to be gated by this flag.
        verbose: Print progress.
        debug: Also return the trained model under ``'model'``.
        beta_warmup_epochs: Length of the linear beta ramp; <= 0 disables it.

    Returns:
        dict with ``'train_likelihood'`` (last epoch's score),
        ``'train_score'`` (per-epoch list), ``'sdy'`` (numpy array); when
        ``testdata`` is given also ``'test_likelihood'`` and ``'test_score'``;
        and ``'model'`` when ``debug`` is True.

    Raises:
        ValueError: If ``epochs < 1`` (no score would be produced).
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    torch.manual_seed(seed)
    np.random.seed(seed)
    cuda = cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    # DataLoader setup
    kwargs = {'num_workers': 1, 'pin_memory': False}
    if logpx is None:
        # Samples are rows, so the KDE must be fit on a column (presumably the
        # input x in column 0), not on row 0, which is a single sample.
        x_train = traindata[:, 0]
        pde = gaussian_kde(x_train)
        logpx = np.log(pde(x_train)).mean()

    traindata = torch.from_numpy(traindata).float()
    train_loader = torch.utils.data.DataLoader(traindata, batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = None
    if testdata is not None:
        testdata = torch.from_numpy(testdata).float()
        # Evaluation sums the loss over all batches, so order is irrelevant.
        test_loader = torch.utils.data.DataLoader(testdata, batch_size=batch_size, shuffle=False, **kwargs)

    model = model_class(N).to(device)
    # sdy only receives gradients / optimizer updates when update_sdy is set.
    sdy = torch.tensor([prior_sdy], device=device, dtype=torch.float, requires_grad=update_sdy)
    param_groups = [{'params': model.parameters()}]
    if update_sdy:
        param_groups.append({'params': [sdy]})
    optimizer = optim.Adam(param_groups, lr=learning_rate)

    train_scores, test_scores = [], []

    for epoch in range(1, epochs + 1):
        # Gradually increase beta over the first beta_warmup_epochs epochs.
        if beta_warmup_epochs > 0:
            beta_epoch = min(beta, beta * (epoch / beta_warmup_epochs))
        else:
            beta_epoch = beta
        log_now = verbose and epoch % log_interval == 0

        model.train()
        train_loss = 0.0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            y = data[:, 1].view(-1, 1)
            yhat, mu, logvar = model(data)
            loss = loss_function(y, yhat, mu, logvar, sdy, beta_epoch, logpx)
            loss.backward()
            # Gradient clipping is applied to the model parameters only; sdy is not clipped.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()

        # Score = negated loss per sample. This assumes loss_function returns a
        # sum over the batch (TODO confirm), since it is divided by dataset size.
        avg_train_loss = -train_loss / len(train_loader.dataset)
        train_scores.append(avg_train_loss)
        if log_now:
            print(f'Epoch {epoch}, Avg Train Loss: {avg_train_loss:.4f}')

        if test_loader is not None:
            model.eval()
            test_loss = 0.0
            with torch.no_grad():
                for data in test_loader:
                    data = data.to(device)
                    y = data[:, 1].view(-1, 1)
                    yhat, mu, logvar = model(data)
                    test_loss += loss_function(y, yhat, mu, logvar, sdy, beta_epoch, logpx).item()
            avg_test_loss = -test_loss / len(test_loader.dataset)
            test_scores.append(avg_test_loss)
            if log_now:
                print(f'Epoch {epoch}, Avg Test Loss: {avg_test_loss:.4f}')

    output = {
        'train_likelihood': avg_train_loss,
        'train_score': train_scores,
        'sdy': sdy.detach().cpu().numpy()
    }
    if test_loader is not None:
        output['test_likelihood'] = avg_test_loss
        output['test_score'] = test_scores
    if debug:
        output['model'] = model

    return output
