### Import libraries

In [None]:
import pandas as pd
import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.nn import PyroModule, PyroSample
from pyro.infer.autoguide import AutoDiagonalNormal
import torch.nn.functional as F
from tqdm import tqdm

### Load data

In [None]:
# Fix the random seed so repeated runs are reproducible
pyro.set_rng_seed(42)

# Locations of the train / validation / test CSV chunks
train_path = "/work3/s214806/working_chunk_train.csv"
val_path = "/work3/s214806/working_chunk_val.csv"
test_path = "/work3/s214806/working_chunk_test.csv"

# Integer class id used to encode each 'Cancer' label
LABEL_MAPPING = dict(AML=0, ALL=1, Normal=2)

In [None]:
# Load and preprocess data
def load_and_preprocess(path, label_mapping=None):
    """Load an expression CSV and return aligned features and integer labels.

    The CSV is read with its first column as the index and then transposed,
    so every row of the working frame is one sample and every column one
    feature (metadata fields or genes).

    Samples whose ``Cancer`` value is not a key of the mapping are discarded
    *before* any column filtering, so features and labels always stay
    row-aligned and an unlabelled sample cannot cause a gene to be dropped.

    Args:
        path: File path or file-like object accepted by ``pd.read_csv``.
        label_mapping: Optional dict mapping ``Cancer`` values to integer
            class ids. Defaults to the module-level ``LABEL_MAPPING``.

    Returns:
        A tuple ``(x, y)`` where ``x`` is a float32 array of shape
        ``(n_samples, n_genes)`` and ``y`` is an int64 array of shape
        ``(n_samples,)`` holding the class id of each row of ``x``.

    Raises:
        KeyError: If the data has no ``Cancer`` field.
    """
    if label_mapping is None:
        label_mapping = LABEL_MAPPING

    df = pd.read_csv(path, index_col=0, low_memory=False).T
    print(f"Gene expression data BEFORE removing non-numeric columns: {df.shape[1]}")

    # Unmapped labels become NaN; build one positional mask and apply it to
    # both labels and features so the two can never get out of sync.
    labels = df['Cancer'].map(label_mapping)
    has_label = labels.notna().to_numpy()
    labels = labels[has_label]

    gene_expressions = df.loc[has_label].drop(
        ['Cancer', 'Preservation Method', 'Tissue Type', 'Tumor Descriptor', 'Specimen Type'],
        axis=1, errors='ignore')

    # Non-numeric entries are coerced to NaN, then every column containing a
    # NaN is removed so only fully numeric genes remain.
    gene_expressions = gene_expressions.apply(pd.to_numeric, errors='coerce')
    gene_expressions = gene_expressions.dropna(axis=1)
    print(f"Gene expression data AFTER removing non-numeric columns: {gene_expressions.shape[1]}")

    gene_expressions = gene_expressions.astype(np.float32)
    # ``map`` yields floats whenever a NaN was present; cast back to integers.
    return gene_expressions.values, labels.astype(np.int64).to_numpy()

In [None]:
# Load data
# x_train_np, y_train_np = load_and_preprocess(train_path)
x_val_np, y_val_np = load_and_preprocess(val_path)

# # Standardize input features
# NOTE(review): the commented-out training code below reduces over axis=1
# (one statistic per sample), unlike the validation code, which reduces over
# axis=0 (one statistic per gene) — confirm the intended axis before
# re-enabling it.
# x_train_mean = x_train_np.mean(axis=1)
# x_train_std = x_train_np.std(axis=1)
# x_train_np = (x_train_np - x_train_mean) / x_train_std

# Standardize every gene (column) to zero mean and unit variance.
x_val_mean = x_val_np.mean(axis=0)
x_val_std = x_val_np.std(axis=0)
# A gene that is constant across all samples has zero std; dividing by it
# would put NaN/inf into the tensors. Use 1.0 so such genes become all zeros.
x_val_std[x_val_std == 0] = 1.0
x_val_np = (x_val_np - x_val_mean) / x_val_std


# # Convert to tensors
# x_train = torch.tensor(x_train_np, dtype=torch.float32)
# y_train = torch.tensor(y_train_np, dtype=torch.long)

x_val = torch.tensor(x_val_np, dtype=torch.float32)
y_val = torch.tensor(y_val_np, dtype=torch.long)

In [None]:
# print(f"x_train shape: ", x_train.shape)
# print(f"y_train shape: ", y_train.shape)

# Report the tensor shapes as a quick sanity check on the loaded data
print("x_val shape: ", x_val.shape)
print("y_val shape: ", y_val.shape)

### Linear Classification

### Define Classifier

In [None]:
# class HDPMMClassifier(PyroModule):
#     def __init__(self, num_genes, num_classes=3, num_clusters=20):
#         super().__init__()
#         self.K = num_clusters
#         self.num_genes = num_genes
#         self.num_classes = num_classes

#         # Using lambdas to return a distribution
#         self.cluster_means = PyroSample(
#             lambda self: dist.Normal(0., 1.).expand([self.K, self.num_genes]).to_event(2)
#         )
#         self.cluster_scales = PyroSample(
#             lambda self: dist.HalfCauchy(1.0).expand([self.K, self.num_genes]).to_event(2)
#         )
#         self.classifier_weights = PyroSample(
#             lambda self: dist.Normal(0., 1.).expand([self.num_classes, self.K]).to_event(2)
#         )

#         self.alpha = 1.0  # Concentration parameter for stick-breaking

#     def model(self, x, y=None):
#         N = x.shape[0]

#         cluster_means = pyro.sample("cluster_means", self.cluster_means)
#         cluster_scales = pyro.sample("cluster_scales", self.cluster_scales)
#         classifier_weights = pyro.sample("classifier_weights", self.classifier_weights)

#         beta = pyro.sample("beta", dist.Beta(1, self.alpha).expand([self.K - 1]))
#         weights = self.stick_breaking(beta)  # shape: [K]

#         with pyro.plate("samples", N):
#             z = pyro.sample("z", dist.Categorical(probs=weights.expand([N, self.K])))
#             mu = cluster_means[z]
#             sigma = cluster_scales[z]

#             pyro.sample("obs", dist.Normal(mu, sigma).to_event(1), obs=x)

#             z_onehot = F.one_hot(z, num_classes=self.K).float()
#             logits = torch.matmul(z_onehot, classifier_weights.T)

#             pyro.sample("label", dist.Categorical(logits=logits), obs=y)

#     def guide(self, x, y=None):
#         # Variational distributions over global parameters
#         cluster_means_loc = pyro.param("cluster_means_loc", torch.randn(self.K, self.num_genes))
#         cluster_means_scale = pyro.param("cluster_means_scale", torch.ones(self.K, self.num_genes), constraint=dist.constraints.positive)

#         cluster_scales_loc = pyro.param("cluster_scales_loc", torch.ones(self.K, self.num_genes), constraint=dist.constraints.positive)

#         classifier_weights_loc = pyro.param("classifier_weights_loc", torch.randn(self.num_classes, self.K))
#         classifier_weights_scale = pyro.param("classifier_weights_scale", torch.ones(self.num_classes, self.K), constraint=dist.constraints.positive)

#         beta_conc1 = pyro.param("beta_conc1", torch.ones(self.K - 1), constraint=dist.constraints.positive)
#         beta_conc0 = pyro.param("beta_conc0", torch.ones(self.K - 1), constraint=dist.constraints.positive)

#         pyro.sample("cluster_means", dist.Normal(cluster_means_loc, cluster_means_scale).to_event(2))
#         pyro.sample("cluster_scales", dist.HalfCauchy(cluster_scales_loc).to_event(2))
#         pyro.sample("classifier_weights", dist.Normal(classifier_weights_loc, classifier_weights_scale).to_event(2))
#         pyro.sample("beta", dist.Beta(beta_conc1, beta_conc0))

#     def stick_breaking(self, beta):
#         remaining_stick = torch.cumprod(1 - beta, dim=-1)
#         remaining_stick = torch.cat([torch.tensor([1.0], device=beta.device), remaining_stick], dim=0)
#         weights = beta * remaining_stick[:-1]
#         final_cluster = 1 - weights.sum(dim=-1, keepdim=True)
#         weights = torch.cat([weights, final_cluster], dim=-1)
#         return weights

### Initialize model

In [None]:
# # Initialize model
# hdpmm_model = HDPMMClassifier(num_genes=x_train.shape[1])
# guide = AutoDiagonalNormal(hdpmm_model.model)
# svi = SVI(hdpmm_model.model, guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())

### Train loop

In [None]:
# # Training loop
# num_steps = 500
# for step in tqdm(range(num_steps)):
#     loss = svi.step(x_train, y_train)
#     if step % 50 == 0:
#         print(f"Step {step} - Loss: {loss:.2f}")