In [10]:
import os
import requests

# --- Download the model definition and pretrained checkpoints from Hugging Face ---
HF_BASE_URL = "https://huggingface.co/cindyhfls/fcMRI-VAE/resolve/main/"
CHECKPOINT_DIR = "Checkpoint"


def _download_file(url, dest, timeout=60, chunk_size=1 << 20):
    """Stream the resource at ``url`` into the local file ``dest``.

    The body is written to ``dest + '.part'`` and renamed to ``dest`` only after it
    has been fully received, so an interrupted download never leaves a truncated
    file behind that the existence checks below would then skip on every later run.

    Args:
        url: Remote URL to fetch.
        dest: Local destination path.
        timeout: Seconds to wait for the connection / each read before giving up.
        chunk_size: Bytes per chunk; streaming avoids holding large checkpoints in memory.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    tmp_path = dest + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()  # Raise an error if the request fails
        with open(tmp_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    os.replace(tmp_path, dest)  # atomic rename: dest is either absent or complete


if not os.path.exists("fMRIVAE_Model.py"):
    print("Downloading model file...")
    _download_file(HF_BASE_URL + "fMRIVAE_Model.py", "fMRIVAE_Model.py")
    print("Download complete.")
else:
    print("fMRIVAE_Model.py already exists. Skipping download.")

checkpoint_filenames = [
    "checkpoint49_2024-03-28_Zdim_2_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar",
    "checkpoint49_2024-06-21_Zdim_4_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar",
    "checkpoint49_2024-11-28_Zdim_3_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar"
]

# On a fresh runtime the local checkpoint directory does not exist yet and open() would fail.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for filename in checkpoint_filenames:
    local_path = os.path.join(CHECKPOINT_DIR, filename)
    if not os.path.exists(local_path):
        print(f"Downloading checkpoint: {filename}")
        _download_file(HF_BASE_URL + "Checkpoint/" + filename, local_path)

Downloading model file...
Download complete.
Downloading checkpoint: checkpoint49_2024-03-28_Zdim_2_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar
Downloading checkpoint: checkpoint49_2024-06-21_Zdim_4_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar
Downloading checkpoint: checkpoint49_2024-11-28_Zdim_3_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar


In [11]:
import torch
import scipy.io as io
# import argparse
import logging
from utils import load_dataset_test, save_image_mat
from fMRIVAE_Model import BetaVAE

In [None]:
# --- Parameters (set these manually or through UI) ---
batch_size = 16  # How many samples per saved file
seed = 42
zdim = 2 # latent dimension
data_path = '/content/fmri_data/'        # Customize this path
z_path = '/content/vae_output/latent/'
resume = 'Checkpoint/checkpoint49_2024-03-28_Zdim_2_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar' # should correspond with zdim
img_path = '/content/vae_output/recon/'
mode = 'encode'  # 'encode', 'decode', or 'both'
debug = True

# Catch typos early: any other value would make the processing cell silently do nothing.
VALID_MODES = ('encode', 'decode', 'both')
if mode not in VALID_MODES:
    raise ValueError(f"mode must be one of {VALID_MODES}, got {mode!r}")

# --- Logging Setup ---
logging_level = logging.DEBUG if debug else logging.INFO
# force=True replaces handlers already attached to the root logger (e.g. from a previous
# run of this cell or an earlier logging call). Without it basicConfig is a silent no-op,
# and changing `debug` would have no effect on re-run.
logging.basicConfig(level=logging_level, format='%(asctime)s - %(levelname)s - %(message)s', force=True)

# The provided checkpoint filenames encode their latent size as "Zdim_<n>_"; warn early
# on an obvious mismatch instead of failing later with a shape error in load_state_dict.
if f"Zdim_{zdim}_" not in os.path.basename(resume):
    logging.warning(f"Checkpoint name {os.path.basename(resume)!r} does not mention Zdim_{zdim}; "
                    "check that `resume` matches `zdim`.")

# --- Create directories if needed ---
os.makedirs(z_path, exist_ok=True)
if mode != 'encode':
    os.makedirs(img_path, exist_ok=True)

In [None]:
# --- Reproducibility and compute device ---
# Seed torch's RNG first, then run on the GPU if one is visible, otherwise on the CPU.
torch.manual_seed(seed)

if torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
device = torch.device(device_name)

logging.info(f"Using device: {device}")

In [None]:
# --- Load the pretrained VAE ---
# Fail fast before building the model (and moving it to the device) if the file is missing.
if not os.path.isfile(resume):
    raise RuntimeError(f"Checkpoint not found at {resume}")

model = BetaVAE(z_dim=zdim, nc=1).to(device)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
# On newer PyTorch releases the default weights_only=True may reject checkpoints that
# store non-tensor objects — TODO confirm these checkpoints load cleanly.
checkpoint = torch.load(resume, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
# This notebook only runs inference; eval() puts any dropout / batch-norm layers into
# their deterministic inference behaviour.
model.eval()
logging.info("Checkpoint loaded successfully.")

In [None]:
# --- Encode test data to latents and/or decode saved latents back to images ---
import re  # only this cell needs it

# Inference only: ensure any dropout / batch-norm layers are in eval mode (idempotent).
model.eval()

if mode in ['encode', 'both']:
    logging.info("Starting encoding process...")
    test_loader = load_dataset_test(data_path, batch_size)
    # No gradients are needed for a forward pass, so skip building the autograd graph.
    with torch.no_grad():
        for batch_idx, (xL, xR) in enumerate(test_loader):
            xL = xL.to(device)
            xR = xR.to(device)
            z_distribution = model._encode(xL, xR)
            save_data = {'z_distribution': z_distribution.detach().cpu().numpy()}
            io.savemat(os.path.join(z_path, f'save_z{batch_idx}.mat'), save_data)
            logging.debug(f"Encoded batch {batch_idx}")

if mode in ['decode', 'both']:
    logging.info("Starting decoding process...")
    # Recover each file's batch index from its name and sort numerically. A plain string
    # sort orders save_z10.mat before save_z2.mat, which would save reconstructions
    # under the wrong batch index.
    # NOTE(review): files left in z_path by an earlier, longer encode run are decoded too.
    z_file_pattern = re.compile(r'save_z(\d+)\.mat')
    indexed_files = []
    for filename in os.listdir(z_path):
        match = z_file_pattern.fullmatch(filename)
        if match:
            indexed_files.append((int(match.group(1)), filename))
    indexed_files.sort()

    with torch.no_grad():
        for batch_idx, filename in indexed_files:
            logging.debug(f"Decoding file {filename}")
            z_dist = io.loadmat(os.path.join(z_path, filename))['z_distribution']
            # Assumes _encode returns the mean in its first zdim columns (e.g. [mu, logvar])
            # — TODO confirm against fMRIVAE_Model.
            mu = z_dist[:, :zdim]
            # Pin dtype to float32 so a float64 .mat array cannot clash with float32 weights.
            z = torch.tensor(mu, dtype=torch.float32, device=device)
            x_recon_L, x_recon_R = model._decode(z)
            # NOTE(review): R is passed before L; confirm this matches save_image_mat's signature.
            save_image_mat(x_recon_R, x_recon_L, img_path, batch_idx)
            logging.debug(f"Decoded and saved batch {batch_idx}")