In [None]:
import os
# Number GPUs by PCI bus ID (the order nvidia-smi shows) and expose only GPU 1.
# These must be set before torch initialises CUDA in this kernel.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="1",
)

In [None]:
# Run main.py as a script in a separate process; it inherits the CUDA_* variables set above.
# NOTE(review): the next cells also import `main` and call its stages in-process; if main.py's
# script entry point runs the same stages, executing both repeats the work.
!python main.py

In [None]:
import main

In [None]:
# Stage 1: train the AAE via main.train_AAE().
# Presumably writes config.CHECKPOINT_AAE, which the evaluation cells below load — TODO confirm in main.py.
main.train_AAE()

In [None]:
# Stage 2: encode volumes with the trained AAE via main.encoding().
# Presumably writes the latents read later from config.latent_Abeta — TODO confirm in main.py.
main.encoding()

In [None]:
# Stage 3: train the latent diffusion model via main.train_LDM().
main.train_LDM()

# **Checking the performance of AAE in the SPECT domain**

In [None]:
import os, nibabel as nib, numpy as np, torch
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import config
from model import AAE
from utils import load_checkpoint
from dataset import center_crop

device = torch.device(config.device)
out_dir = os.path.join("result", str(config.exp), "aae_recon_test")
os.makedirs(out_dir, exist_ok=True)


def _strip_nifti_ext(fname):
    """Return `fname` without its NIfTI extension ('.nii.gz' or '.nii').

    os.path.splitext alone would turn 'x.nii.gz' into 'x.nii'.
    """
    for ext in (".nii.gz", ".nii"):  # longest suffix first
        if fname.endswith(ext):
            return fname[: -len(ext)]
    return os.path.splitext(fname)[0]


def _find_nifti(root, stem):
    """Return the path of `<root>/<stem>.nii` or `<root>/<stem>.nii.gz`.

    Raises:
        FileNotFoundError: if neither file exists.
    """
    for ext in (".nii", ".nii.gz"):
        candidate = os.path.join(root, stem + ext)
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(f"No .nii/.nii.gz file for '{stem}' in {root}")


# Pick a latent deterministically (sorted, NIfTI files only).
# Change SAMPLE_INDEX to inspect another subject.
latent_dir = config.latent_Abeta
SAMPLE_INDEX = 0
latent_files = sorted(f for f in os.listdir(latent_dir) if f.endswith((".nii", ".nii.gz")))
if not latent_files:
    raise FileNotFoundError(f"No NIfTI latents found in {latent_dir}")
sample_id = latent_files[SAMPLE_INDEX]
latent_path = os.path.join(latent_dir, sample_id)

lat = nib.load(latent_path).get_fdata().astype(np.float32)
x_lat = torch.tensor(lat[None, None, ...], device=device)  # add batch and channel axes

aae = AAE().to(device)
# load_checkpoint takes an optimizer argument; it is not used for inference here.
opt = torch.optim.Adam(aae.parameters(), lr=config.learning_rate)
load_checkpoint(config.CHECKPOINT_AAE, aae, opt, config.learning_rate)
aae.eval()

# Decode the stored latent directly (no encoder pass).
with torch.no_grad():
    recon = aae.decoder(x_lat)
recon = torch.clamp(recon, 0, 1).cpu().numpy().squeeze().astype(np.float32)

# Load the matching original volume.
# Preprocess it like the batch evaluation: center crop, then min-max to [0, 1].
orig_id = _strip_nifti_ext(sample_id)
orig_img = nib.load(_find_nifti(config.whole_Abeta, orig_id))
orig = center_crop(orig_img.get_fdata().astype(np.float32), config.crop_size)
vmin, vmax = orig.min(), orig.max()
orig = np.zeros_like(orig) if vmax - vmin < 1e-8 else (orig - vmin) / (vmax - vmin)

if recon.shape != orig.shape:
    raise ValueError(f"{orig_id}: shape mismatch, orig {orig.shape} vs recon {recon.shape}")

rng = max(orig.max() - orig.min(), 1e-8)
print("PSNR", psnr(orig, recon, data_range=rng), "SSIM", ssim(orig, recon, data_range=rng))

# NOTE(review): the uncropped volume's affine is reused for the cropped reconstruction.
# Unless the crop is a no-op, world coordinates are offset by the crop.
nib.save(nib.Nifti1Image(recon, orig_img.affine),
         os.path.join(out_dir, f"{orig_id}_decoder_only.nii.gz"))

In [None]:
import numpy as np, nibabel as nib, csv, torch
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import config
from dataset import resolve_nifti, center_crop  # usa center_crop
from model import AAE

import os

device = torch.device(config.device)
model = AAE().to(device)
ckpt = torch.load(config.CHECKPOINT_AAE, map_location=device)
model.load_state_dict(ckpt["state_dict"])
model.eval()

# Test-split subject IDs, one per line; blank lines are ignored.
with open(config.test) as id_file:
    ids = [line.strip() for line in id_file if line.strip()]

csv_path = "result/exp_1/aae_recon_test.csv"
os.makedirs(os.path.dirname(csv_path), exist_ok=True)

psnr_sum = ssim_sum = 0.0
n_evaluated = 0  # volumes actually scored; shape mismatches are skipped

with open(csv_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["ID", "PSNR", "SSIM"])
    for bid in ids:
        path = resolve_nifti(config.whole_Abeta, bid)
        img = nib.load(path)
        data = img.get_fdata().astype(np.float32)
        # Same preprocessing as the decoder-only check: center crop, then min-max to [0, 1].
        data = center_crop(data, config.crop_size)
        vmin, vmax = data.min(), data.max()
        data = np.zeros_like(data) if vmax - vmin < 1e-8 else (data - vmin) / (vmax - vmin)

        x = torch.tensor(data[None, None, ...], device=device)  # add batch and channel axes
        with torch.no_grad():
            recon = model(x)  # assumes AAE.forward returns the reconstruction tensor — TODO confirm
        recon = torch.clamp(recon, 0, 1).cpu().numpy().squeeze().astype(np.float32)

        # Make sure the shapes match before computing the metrics.
        if recon.shape != data.shape:
            print(f"{bid} shapes mismatch: data {data.shape}, recon {recon.shape}")
            continue

        rng = max(data.max() - data.min(), 1e-8)
        ps = psnr(data, recon, data_range=rng)
        ss = ssim(data, recon, data_range=rng)
        writer.writerow([bid, ps, ss])
        psnr_sum += ps
        ssim_sum += ss
        n_evaluated += 1

# Average only over scored volumes. Dividing by len(ids) would bias the mean low
# when IDs are skipped, and would crash when ids is empty.
if n_evaluated:
    print("Promedio PSNR", psnr_sum / n_evaluated, "SSIM", ssim_sum / n_evaluated)
else:
    print("No volumes evaluated")

In [None]:
# Run aae_eval.py as a script in a separate process.
# It inherits the CUDA_* variables set in the first cell.
!python aae_eval.py

In [None]:
import os, nibabel as nib, numpy as np, torch
gen_path = "/workspace/projects/T1-SPECT-translation/IL-CLDM/data/latent_SPECT"
# Expected latent size.
# NOTE(review): presumably equals config.latent_shape — keep in sync.
EXPECTED_LATENT_SHAPE = (40, 48, 40)

# Only NIfTI files, in a deterministic order.
files = sorted(f for f in os.listdir(gen_path) if f.endswith((".nii", ".nii.gz")))

for file in files:
    # The shape comes from the header, so the voxel data never needs to be loaded.
    latent_shape = nib.load(os.path.join(gen_path, file)).shape
    if latent_shape != EXPECTED_LATENT_SHAPE:
        print(f"File: {file}, Shape: {latent_shape}")

# **Checking the latent MRI representation**

In [None]:
import os, nibabel as nib, torch
from dataset import resolve_nifti, center_crop, z_score_norm
from model import UNet
import config
import numpy as np

from dataset import resample_to_shape  # imported once, not on every loop iteration

device = torch.device(config.device)

# Build the UNet with the latent image_size from the config.
# NOTE(review): this cell loads no checkpoint. Unless UNet's constructor loads weights
# itself, `unet.cond` runs with randomly initialised weights. The saved latents then only
# check shapes and plumbing, not the trained conditioning.
# TODO: load the trained weights before trusting the values.
unet = UNet(in_channel=2, out_channel=1, image_size=config.latent_shape[0]).to(device)
unet.eval()

N_IDS = 3  # number of training subjects to inspect
with open(config.train) as id_file:
    ids = [line.strip() for line in id_file if line.strip()][:N_IDS]
out_dir = "result/exp_1/mri_latent_cond"
os.makedirs(out_dir, exist_ok=True)

with torch.no_grad():
    for bid in ids:
        mri_path = resolve_nifti(config.whole_MRI, bid)
        img = nib.load(mri_path)
        mri = img.get_fdata().astype(np.float32)
        # Same preprocessing as in TwoDataset: resample, center crop, then z-score.
        mri = resample_to_shape(mri, config.target_shape)
        mri = center_crop(mri, config.crop_size)
        mri = z_score_norm(mri)

        x = torch.tensor(mri[None, None, ...], device=device)  # add batch and channel axes
        cond_latent = unet.cond(x).cpu().numpy().squeeze()  # expected: 40x48x40 with the current config

        print(bid, "cond shape", cond_latent.shape, "min/max", cond_latent.min(), cond_latent.max())
        # NOTE(review): the raw MRI affine is reused for a resampled/cropped/encoded volume,
        # so the world coordinates of the saved file are not meaningful.
        nib.save(nib.Nifti1Image(cond_latent, img.affine),
                 os.path.join(out_dir, f"{bid}_mri_cond.nii.gz"))


# **Checking the data input**

In [None]:
import nibabel as nib
import numpy as np
import os

base_path = "/workspace/projects/T1-SPECT-translation/IL-CLDM/data/whole_SPECT"

# Scan both .nii and .nii.gz files (this folder holds .nii.gz files too), in sorted order.
files = sorted(os.listdir(base_path))
nan_values = []  # (file name, number of NaN voxels) for every file that contains NaNs

for file in files:
    if not file.endswith((".nii", ".nii.gz")):
        continue
    filepath = os.path.join(base_path, file)
    img = nib.load(filepath)
    data = img.get_fdata()
    print("dimansions of the file:", data.shape)
    num_nans = int(np.isnan(data).sum())  # one pass over the volume, reused below
    if num_nans:
        print(f"El archivo {file} contiene NaNs.")
        print(f"Número de NaNs en el archivo {file}: {num_nans}")
        nan_values.append((file, num_nans))
    else:
        print(f"El archivo {file} no contiene NaNs.")

# Report the files with the fewest and the most NaNs.
if nan_values:
    nan_values.sort(key=lambda x: x[1])
    print(f"Archivo con menos NaNs: {nan_values[0][0]} con {nan_values[0][1]} NaNs")
    print(f"Archivo con más NaNs: {nan_values[-1][0]} con {nan_values[-1][1]} NaNs")

In [None]:
# Inspect one volume by name.
# Relies on `base_path` and `os` from the NaN-check cell above.
file = "3863.nii"
filepath = os.path.join(base_path, file)
img = nib.load(filepath)
data = img.get_fdata()  # loads the full voxel array; img.shape alone would give the shape
print("dimansions of the file:", data.shape)

In [None]:
import nibabel as nib, numpy as np
from pathlib import Path

def resolve(root, bid):
    """Return the path of `<root>/<bid>.nii` or `<root>/<bid>.nii.gz` (checked in that order).

    Raises:
        FileNotFoundError: if neither file exists.
    """
    for ext in ('.nii', '.nii.gz'):
        candidate = Path(root) / f"{bid}{ext}"
        if candidate.exists():
            return candidate
    raise FileNotFoundError(f"{bid} (.nii/.nii.gz) not found in {root}")


def stats(path):
    """Return (min, max, mean, std) of a NIfTI volume, with NaN and ±inf replaced by 0."""
    data = nib.load(str(path)).get_fdata()
    data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)
    return data.min(), data.max(), data.mean(), data.std()


root_abeta = Path('data/whole_SPECT')
root_mri = Path('data/whole_MRI')
# Skip blank or whitespace-only lines so resolve() never receives an empty ID.
all_ids = [line.strip() for line in Path('data_info/train.txt').read_text().splitlines() if line.strip()]
# NOTE(review): `[:-10]` keeps every ID except the last 10, not the first 10.
# The original note ("change [:3] for more") suggests a small subset was intended. TODO confirm.
train_ids = all_ids[:-10]
print('IDs', train_ids)
for bid in train_ids:
    pA = resolve(root_abeta, bid)
    pM = resolve(root_mri, bid)
    amin, amax, amean, astd = stats(pA)
    mmin, mmax, mmean, mstd = stats(pM)
    print(f"{bid} Abeta: min {amin:.4f} max {amax:.4f} mean {amean:.4f} std {astd:.4f}")
    print(f"{bid} MRI  : min {mmin:.4f} max {mmax:.4f} mean {mmean:.4f} std {mstd:.4f}")

# **Checking the PSNR and SSIM metrics**

In [None]:
import os
# Limit the OpenMP and MKL runtimes to one thread each.
# NOTE(review): these variables are only read when those runtimes are first loaded.
# Earlier cells in this kernel already import numpy/torch, so in a top-to-bottom run
# this likely has no effect. Set them in the first cell, before any numeric import.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from dataset import OneDataset
import config

NOISE_STD = 0.01  # std of the additive Gaussian noise used to check that both metrics drop
noise_rng = np.random.default_rng(0)  # seeded so the "noisy" PSNR/SSIM are reproducible

ds = OneDataset(root_Abeta=config.whole_Abeta, task=config.validation, stage="validation")
for i in range(3):  # first 3 validation volumes
    vol, name = ds[i]
    vol = vol.astype(np.float32)
    rng = max(vol.max() - vol.min(), 1e-8)
    # Sanity check: a volume compared with itself should give PSNR = inf and SSIM = 1.
    print(name, "range", vol.min(), vol.max(),
          "PSNR self", psnr(vol, vol, data_range=rng),
          "SSIM self", ssim(vol, vol, data_range=rng))
    # Add small noise to confirm both metrics decrease.
    noise = noise_rng.normal(0, NOISE_STD, size=vol.shape).astype(np.float32)
    vol_noisy = np.clip(vol + noise, 0, 1)
    print("  noisy PSNR", psnr(vol, vol_noisy, data_range=rng),
          "noisy SSIM", ssim(vol, vol_noisy, data_range=rng))

In [None]:
import os, nibabel as nib, numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import config

# Fixed crop of 160 x 192 x 160 voxels.
# NOTE(review): presumably the same region as center_crop(config.crop_size) — TODO confirm.
CROP = (slice(10, 170), slice(18, 210), slice(10, 170))
N_CHECK = 10  # number of validation IDs to inspect

with open(config.validation) as id_file:
    ids = id_file.read().split()

for bid in ids[:N_CHECK]:
    # Accept either extension, preferring .nii.
    path = None
    for ext in (".nii", ".nii.gz"):
        candidate = os.path.join(config.whole_Abeta, bid + ext)
        if os.path.exists(candidate):
            path = candidate
            break
    if not path:
        print("no file for", bid)
        continue
    img = nib.load(path).get_fdata().astype(np.float32)
    vol = img[CROP]
    vmin, vmax = vol.min(), vol.max()
    vol = np.zeros_like(vol) if vmax - vmin < 1e-8 else (vol - vmin) / (vmax - vmin)
    rng = max(vol.max() - vol.min(), 1e-8)
    psnr_self = psnr(vol, vol, data_range=rng)
    ssim_self = ssim(vol, vol, data_range=rng)
    print(bid, "range", vol.min(), vol.max(), "psnr_self", psnr_self, "ssim_self", ssim_self)
    # Stop at the first volume whose self-SSIM is finite and high.
    if np.isfinite(ssim_self) and ssim_self > 0.9:
        break


# **Plotting curves of the results**

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
gen_path = "/workspace/projects/T1-SPECT-translation/IL-CLDM/result/exp_1"
file1, file2 = "loss_curve.csv", "validation.csv"

# Per-epoch loss log and validation metrics, presumably written during training.
loss_df, similarity_df = (pd.read_csv(os.path.join(gen_path, name)) for name in (file1, file2))

In [None]:
# Preview the first rows of the loss log (the plot below uses Epoch, recon_loss and disc_loss_epoch).
loss_df.head()

In [None]:
# Preview the first rows of the validation metrics (the plot below uses Epoch, SSIM and PSNR).
similarity_df.head()

In [None]:
# Plot the reconstruction and discriminator losses per epoch.
# NOTE(review): the data comes from loss_curve.csv; confirm these are validation losses,
# as the title says, rather than training losses.
fig, ax = plt.subplots(figsize=(10, 5))
for column, label in (('recon_loss', 'Reconstruction Loss'),
                      ('disc_loss_epoch', 'Discriminator Loss')):
    ax.plot(loss_df['Epoch'], loss_df[column], label=label)
ax.set(xlabel='Epoch', ylabel='Loss', title='Validation Loss Curve')
ax.legend()
ax.grid()
plt.show()

In [None]:
# Plot validation SSIM and PSNR per epoch.
# SSIM is at most 1, while PSNR is in the tens of dB. On a shared y-axis the SSIM curve
# collapses to a flat line near zero, so each metric gets its own axis.
fig, ax_ssim = plt.subplots(figsize=(10, 5))
ax_psnr = ax_ssim.twinx()
line_ssim, = ax_ssim.plot(similarity_df['Epoch'], similarity_df['SSIM'],
                          color='tab:blue', label='SSIM')
line_psnr, = ax_psnr.plot(similarity_df['Epoch'], similarity_df['PSNR'],
                          color='tab:orange', label='PSNR')
ax_ssim.set_xlabel('Epoch')
ax_ssim.set_ylabel('SSIM')
ax_psnr.set_ylabel('PSNR (dB)')
ax_ssim.set_title('Validation Similarity Metrics')
# One legend for both axes.
ax_ssim.legend(handles=[line_ssim, line_psnr], loc='best')
ax_ssim.grid()
plt.show()

# **Converting .nii.gz files into uncompressed .nii files**

In [None]:
import nibabel as nib
import gzip
import shutil
from pathlib import Path

In [None]:
input_folder = "/workspace/projects/T1-SPECT-translation/IL-CLDM/data/whole_SPECT"
output_folder = "/workspace/projects/T1-SPECT-translation/IL-CLDM/data/whole_SPECT2"
Path(output_folder).mkdir(parents=True, exist_ok=True)

# A .nii.gz file is a gzip-compressed .nii.
# Decompressing copies the header and voxel data byte-for-byte, with no decode/re-encode
# through nibabel, so nothing in the header can be rewritten.
for f in sorted(Path(input_folder).glob("*.nii.gz")):
    out_path = Path(output_folder) / f.with_suffix('').name  # drop ".gz": name.nii.gz -> name.nii
    with gzip.open(f, "rb") as src, open(out_path, "wb") as dst:
        shutil.copyfileobj(src, dst)
    print("Guardado:", out_path)

# **Testing the trained model**

In [None]:
# Run infer.py as a script in a separate process.
# It inherits the CUDA_* variables set in the first cell.
!python infer.py