In [1]:
import dataloader as dl
import torch
import GPVarInf as GPVI

In [2]:
patients = dl.load_patients_from_csv_files("Data/data_table.csv", "Data/drug_mapping_table.csv", "Data/drug_condition_atc_table.csv")



For now only get a subset

In [3]:
torch.stack([patient.covariates for patient in patients], dim=0)

tensor([[46.0476,  1.0000],
        [63.4286,  1.0000],
        [44.0000,  1.0000],
        ...,
        [59.5000,  1.0000],
        [64.0000,  1.0000],
        [66.0000,  1.0000]])

Now get the covariates X and the A matrix. First clean the drugs

In [4]:
for patient in patients:
    for visit in patient.visits:
        # Reassign the drugs list to only those drugs that have a non-empty atcs attribute.
        visit["drugs"] = [drug for drug in visit["drugs"] 
                          if hasattr(drug, "atcs") and drug.atcs is not None and len(drug.atcs) > 0]


In [5]:
A, cond_to_index = dl.construct_A_matrix(patients)

X_cov = dl.construct_covariate_matrix(patients)

In [9]:
X_cov

tensor([[46.0476,  1.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [63.4286,  1.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [44.0000,  1.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [59.5000,  1.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [64.0000,  1.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [66.0000,  1.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])

In [6]:
condition_list = dl.get_all_conditions_from_drugs(patients)

In [9]:
A.shape

torch.Size([6491, 43, 31])

In [7]:
# Create a grid of condition indices and continuous features
n_conditions = len(condition_list)
n_patients = X_cov.shape[0]

# Create condition indices grid
condition_indices = torch.arange(n_conditions).repeat(n_patients, 1).T.reshape(-1, 1).float()
continuous_features = X_cov[:,0].repeat(n_conditions, 1).reshape(-1, 1)

# Combine into input tensor x with shape (n_conditions * n_patients, 2)
x = torch.cat([condition_indices, continuous_features], dim=1)

# Create inducing points by randomly selecting a subset of x
num_inducing = 50 # Choose a reasonable number of inducing points
inducing_indices = torch.randperm(x.shape[0])[:num_inducing]
inducing_points = x[inducing_indices]


In [8]:
# Initialize the variational GP model
model = GPVI.KNNVariationalGP(inducing_points, condition_list, k_neighbors=5)

# Train the model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

n_epochs = 100
# A has shape [6491, 43, 31] based on earlier output
# A[:,:,1] selects all patients (6491), all conditions (43), and the 2nd feature (index 1)
# .view(-1,1) reshapes it into a single column vector of length 6491*43
y = A[:,:,1].view(-1,1)  # Shape: [279113, 1]

# Track training progress
best_elbo = float('-inf')
patience = 10
patience_counter = 0

print("Starting training...")
initial_elbo = model.elbo(x, y).item()

for epoch in range(n_epochs):
    optimizer.zero_grad()
    
    # Compute ELBO
    current_elbo = model.elbo(x, y)
    loss = -current_elbo
    
    # Backprop and optimize
    loss.backward()
    optimizer.step()
    
    # Early stopping check
    if current_elbo.item() > best_elbo:
        best_elbo = current_elbo.item()
        patience_counter = 0
    else:
        patience_counter += 1
        
    if patience_counter >= patience:
        print(f"Early stopping at epoch {epoch}")
        break
        
    if epoch % 5 == 0:
        print(f'Epoch {epoch}: ELBO = {current_elbo.item():.3f}')

final_elbo = model.elbo(x, y).item()
print(f"ELBO improved from {initial_elbo:.3f} to {final_elbo:.3f}")

# Get predictions
with torch.no_grad():
    f_pred = model(x)

Starting training...


KeyboardInterrupt: 

In [10]:
x.size()

torch.Size([220374000, 2])

In [None]:
X_cov