In [3]:
import os
import sys

# Make the project packages (`experiments`, `enf`) importable.
# The checkout location can be overridden with the ENF_BIOBANK_ROOT environment
# variable so the notebook is not tied to one user's home directory; when the
# variable is unset the original path is used, so existing setups keep working.
sys.path.insert(
    0,
    os.environ.get("ENF_BIOBANK_ROOT", "/home/jwiers/deeprisk/new_codebase/enf-biobank/"),
)

import ml_collections
from ml_collections import config_flags
from absl import app

import jax
import jax.numpy as jnp
import optax
import logging
from tqdm import tqdm
import matplotlib.pyplot as plt
import itertools
import wandb
import time
import os
import pickle
# Custom imports
from experiments.datasets import get_dataloaders
from enf.model import EquivariantNeuralField
from enf.bi_invariants import TranslationBI
from enf.utils import create_coordinate_grid, initialize_latents

from experiments.downstream_models.transformer_enf import TransformerClassifier

# Request the highest matmul precision as JAX's global default for this kernel.
# NOTE(review): presumably chosen for numerical accuracy at the cost of speed on
# accelerators — confirm it is still needed.
jax.config.update("jax_default_matmul_precision", "highest")

In [5]:
# Reconstruction model (equivariant neural field).
# NOTE: the hyperparameters are written out literally here and appear again in
# `config.recon_enf` further down (there as num_hidden, num_heads, att_dim,
# num_out, freq_mult, k_nearest, gaussian_window) — keep the two in sync.
recon_enf = EquivariantNeuralField(
    # Geometry / locality settings
    bi_invariant=TranslationBI(),
    nearest_k=4,
    gaussian_window=True,
    emb_freq=(30.0, 60.0),
    # Network sizes
    num_hidden=128,
    att_dim=64,
    num_heads=3,
    num_out=1,
)

In [13]:
# Grid extents along the four coordinate axes (Z, T, H, W).
Z, T, H, W = 2, 5, 60, 60

# Coordinate grid for a batch of 4 samples. Each point carries num_in = 4
# coordinates (one per axis above); the trailing 1 in img_shape is the channel dim.
x = create_coordinate_grid(batch_size=4, img_shape=(Z, T, H, W, 1), num_in=4)

In [14]:
# Sanity check: expect (batch, Z * T * H * W, num_in) = (4, 2 * 5 * 60 * 60, 4) = (4, 36000, 4).
print(x.shape)

(4, 36000, 4)


In [18]:
# Experiment configuration: one ConfigDict with sub-configs for the
# reconstruction model, the dataset, the optimizer and the training loop.
config = ml_collections.ConfigDict()
config.seed = 68
config.debug = False
config.run_name = "biobank_reconstruction_3d_lvef_autodecoding"
config.exp_name = "test"

# Reconstruction model
# NOTE: num_hidden / num_heads / att_dim / num_out / freq_mult / k_nearest /
# gaussian_window repeat the literals used to build `recon_enf` in an earlier cell.
config.recon_enf = ml_collections.ConfigDict()
config.recon_enf.num_hidden = 128
config.recon_enf.num_heads = 3
config.recon_enf.att_dim = 64
config.recon_enf.num_in = 4  # coordinate dimensions; same value as num_in passed to create_coordinate_grid
config.recon_enf.num_out = 1  # same value as num_out of the recon_enf model
config.recon_enf.freq_mult = (30.0, 60.0)
config.recon_enf.k_nearest = 4
config.recon_enf.latent_noise = True

config.recon_enf.num_latents = 625 # NOTE(review): previous comment read "64 x z", which does not match 625 (= 25 * 25) — confirm the intended latent layout
config.recon_enf.latent_dim = 32
config.recon_enf.z_positions = 2  # presumably the number of z-slices (Z = 2 in the grid cell) — TODO confirm
config.recon_enf.even_sampling = True
config.recon_enf.gaussian_window = True

# Dataset config
config.dataset = ml_collections.ConfigDict()
config.dataset.num_workers = 0
config.dataset.num_patients_train = 50 
config.dataset.num_patients_test = 10
config.dataset.z_indices = (0, 1)  # Which z-slices to use

# Optimizer config
config.optim = ml_collections.ConfigDict()
config.optim.lr_enf = 5e-4  # learning rate of the Adam optimizer for the ENF backbone
config.optim.inner_lr = (0., 60., 0.) # (pose, context, window), originally (2., 30., 0.) # NOTE: Try 1e-3 


# Training config
config.train = ml_collections.ConfigDict()
config.train.batch_size = 4
config.train.noise_scale = 1e-1  # Noise added to latents to prevent overfitting
config.train.num_epochs_train = 100
config.train.log_interval = 50
config.train.num_subsampled_points = None  # Will be set based on image shape

In [19]:
# Seed the PRNG from the config so there is a single source of truth for the
# seed instead of a second hard-coded value.
key = jax.random.PRNGKey(config.seed)

# Create dummy latents for model init
key, subkey = jax.random.split(key)
temp_z = initialize_latents(
    batch_size=1,  # Only need one example for initialization
    num_latents=config.recon_enf.num_latents,
    latent_dim=config.recon_enf.latent_dim,
    data_dim=config.recon_enf.num_in,
    bi_invariant_cls=TranslationBI,
    key=subkey,
    noise_scale=config.train.noise_scale,
    even_sampling=config.recon_enf.even_sampling,
    latent_noise=config.recon_enf.latent_noise,
    z_positions=config.recon_enf.z_positions,
)

# Init the model.
# Use a dedicated subkey for parameter initialisation (JAX keys should be
# consumed only once), and slice the coordinate grid to a single sample so its
# batch dimension matches the batch_size=1 dummy latents.
key, init_key = jax.random.split(key)
recon_enf_params = recon_enf.init(init_key, x[:1], *temp_z)

# Define optimizer for the ENF backbone
enf_opt = optax.adam(learning_rate=config.optim.lr_enf)
recon_enf_opt_state = enf_opt.init(recon_enf_params)

In [22]:
# Draw a fresh subkey for these latents: the earlier `subkey` was already
# consumed when creating the dummy init latents, and JAX keys must not be reused.
key, latent_key = jax.random.split(key)

# Take the batch size from the coordinate grid so latents and coordinates
# always agree, instead of repeating the literal 4.
num_samples = x.shape[0]

z = initialize_latents(
    batch_size=num_samples,
    num_latents=config.recon_enf.num_latents,
    latent_dim=config.recon_enf.latent_dim,
    data_dim=config.recon_enf.num_in,
    bi_invariant_cls=TranslationBI,
    key=latent_key,
    noise_scale=config.train.noise_scale,
    even_sampling=config.recon_enf.even_sampling,
    latent_noise=config.recon_enf.latent_noise,
    z_positions=config.recon_enf.z_positions,
)

# Run the (untrained) model on the full grid and fold the flat point axis back
# into (Z, T, H, W) with a single output channel.
img_r = recon_enf.apply(recon_enf_params, x, *z).reshape((num_samples, Z, T, H, W, 1))
img_r.shape

(4, 2, 5, 60, 60, 1)