# In[ ]:
def fit(traindata, model_type='canm', testdata=None, N=1, beta=0.1, batch_size=64,
        epochs=100, logpx=None, cuda=False, seed=0, log_interval=10,
        learning_rate=1e-4, prior_sdy=0.80, update_sdy=True,
        preload=False, warming_up=False, verbose=False, debug=False):
    """Train a CANM or TransformerVAE model on (x, y) data and report its fit.

    For every epoch, the per-batch loss is ``loss_fn(...) - logpx * batch_size``.
    These are summed over the epoch, divided by the dataset size, and negated
    to give a score (larger is better).

    Args:
        traindata: Training data. Without ``preload``, a NumPy array whose
            first axis indexes samples. Column 0 is used for the KDE estimate
            of ``logpx`` and column 1 is the target ``y``. With ``preload``,
            an already-built loader that yields tensors and exposes ``.dataset``.
        model_type: ``'canm'`` builds ``CANM(N)``. Any other value builds
            ``TransformerVAE(latent_dim=N, confounding_dim=1)``.
        testdata: Optional held-out data, in the same format as ``traindata``.
        N: Latent dimension passed to the model constructor.
        beta: Weight passed to the loss function. It is divided by the epoch
            number when ``warming_up`` is set.
        batch_size: Mini-batch size of the DataLoaders built when not preloading.
        epochs: Number of training epochs. The result dict needs at least one.
        logpx: Average log-density of x. If ``None``, it is estimated with a
            Gaussian KDE on column 0 of ``traindata`` (requires ``preload=False``).
        cuda: Use the GPU when one is available.
        seed: Seed for ``random``, NumPy and torch (CPU and CUDA).
        log_interval: Print a batch line every this many batches when verbose.
        learning_rate: Adam learning rate.
        prior_sdy: Initial value of ``sdy``, which is passed to the loss
            function (presumably the noise std of y; confirm in the loss code).
        update_sdy: If True, ``sdy`` is optimised jointly with the model and
            kept from dropping below 0.05.
        preload: Use ``traindata`` / ``testdata`` directly as loaders.
        warming_up: Use ``beta / epoch`` instead of ``beta``.
        verbose: Print progress.
        debug: Include the trained model in the output.

    Returns:
        dict with these keys:
        - ``train_likelihood``: last-epoch score.
        - ``train_score``: per-epoch scores.
        - ``sdy``: a NumPy array.
        - ``test_likelihood`` and ``test_score``: only when ``testdata`` is given.
        - ``model``: only when ``debug`` is set.

    Raises:
        ValueError: If ``logpx`` is None while ``preload`` is True, because the
            KDE needs the raw training array.
    """
    torch.set_num_threads(1)
    cuda = cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    # Seed every RNG *before* constructing the model, so that weight
    # initialisation (not only batch shuffling) is reproducible.
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)

    if model_type == 'canm':
        model = CANM(N).to(device)
        loss_fn = canm_loss_function
    else:
        model = TransformerVAE(latent_dim=N, confounding_dim=1).to(device)
        loss_fn = transformer_loss_function

    if logpx is None:
        if preload:
            raise ValueError("logpx must be provided when preload=True "
                             "(the KDE estimate needs the raw training array)")
        # Axis 0 indexes samples and column 0 holds x (column 1 is y, as used
        # in _model_loss below), so the KDE must be fit on a column, not a row.
        x = traindata[:, 0]
        logpx = np.log(gaussian_kde(x)(x)).mean()

    kwargs = {'num_workers': 1, 'pin_memory': False}
    if preload:
        train_loader = traindata
        test_loader = testdata
    else:
        train_loader = torch.utils.data.DataLoader(
            torch.from_numpy(traindata).float(),
            batch_size=batch_size, shuffle=True, **kwargs)
        test_loader = None
        if testdata is not None:
            test_loader = torch.utils.data.DataLoader(
                torch.from_numpy(testdata).float(),
                batch_size=batch_size, shuffle=True, **kwargs)

    sdy = torch.tensor([prior_sdy], device=device, dtype=torch.float,
                       requires_grad=bool(update_sdy))
    param_groups = [{'params': model.parameters()}]
    if update_sdy:
        param_groups.append({'params': [sdy]})
    optimizer = optim.Adam(param_groups, lr=learning_rate)

    def _model_loss(data, wu_beta):
        """Run the model on one batch and return the raw loss_fn output."""
        y = data[:, 1].view(-1, 1)
        # Both model types return the same 5-tuple; only the loss-function
        # argument order differs between them.
        yhat, mu, logvar, conf_mu, conf_logvar = model(data)
        if model_type == 'canm':
            return loss_fn(y, yhat, mu, logvar, sdy, wu_beta, conf_mu, conf_logvar)
        return loss_fn(y, yhat, mu, logvar, conf_mu, conf_logvar, sdy, wu_beta)

    score = []
    score_test = []
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0
        wu_beta = beta / epoch if warming_up else beta

        for batch_idx, data in enumerate(train_loader):
            data = data.to(device)
            optimizer.zero_grad()
            loss = _model_loss(data, wu_beta) - logpx * len(data)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            # Keep sdy away from zero. Shift the leaf tensor in place, without
            # recording it in autograd, so the tensor used in the loss remains
            # the very tensor the optimizer updates (rebinding would not).
            if update_sdy and sdy.item() < 0.05:
                with torch.no_grad():
                    sdy.add_(0.05)

            if verbose and batch_idx % log_interval == 0:
                print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item() / len(data):.6f}')

        train_loss /= len(train_loader.dataset)
        score.append(-train_loss)

        if verbose:
            print(f'====> Epoch: {epoch} Average loss: {train_loss:.4f}')

        if testdata is not None:
            model.eval()
            test_loss = 0
            with torch.no_grad():
                for data in test_loader:
                    data = data.to(device)
                    test_loss += _model_loss(data, wu_beta).item() - logpx * len(data)

                test_loss /= len(test_loader.dataset)
                score_test.append(-test_loss)

                if verbose:
                    print(f'====> Test set loss: {test_loss:.4f}')

    output = {
        'train_likelihood': -float(train_loss),
        'train_score': score,
        # .cpu() is required because Tensor.numpy() rejects CUDA tensors.
        'sdy': sdy.detach().cpu().numpy()
    }

    if testdata is not None:
        output.update({
            'test_likelihood': -float(test_loss),
            'test_score': score_test
        })

    if debug:
        output['model'] = model

    return output
