## Data

In [None]:
import numpy as np
from spect_cardiac.src.tools.manip.manip import normalize_volume

# data fetching and handling
from spect_cardiac.data.check_database import load_remote_data
from spect_cardiac.data.fetch_data import fetch_data
from spect_cardiac.src.tools.data.loadvolumes import LoadVolumes

import matplotlib.pyplot as plt
import torch

import pandas as pd
import nrrd

use_cuda = torch.cuda.is_available()
# Prefer the second GPU (cuda:1) when the host has more than one; fall back to cuda:0 on
# single-GPU hosts, where cuda:1 would fail later with an invalid-device-ordinal error.
# The CPU branch also yields a torch.device so `device` always has the same type.
if use_cuda:
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [None]:
dicom_loader = LoadVolumes()

# Initialize data fetching from remote; the configuration lives in data/remote.yml.
data_loaded = False
url, datasets = load_remote_data()

# Read all filenames from the remote directory listings (every <a href> on each page).
from bs4 import BeautifulSoup
import requests

page = requests.get(url + '/recon/' + 'spie_2024/' + 'misc/' + 'label/')
# Fail loudly on an HTTP error instead of scraping links out of an error page.
page.raise_for_status()
soup = BeautifulSoup(page.content, 'html.parser')
label_names = [label_ref.get('href') for label_ref in soup.find_all('a')]

page = requests.get(url + '/recon/' + 'spie_2024/' + 'misc/' + 'data/')
page.raise_for_status()
soup = BeautifulSoup(page.content, 'html.parser')
data_names = [data_ref.get('href') for data_ref in soup.find_all('a')]

subjects = []
subjects_data = []
# One load flag per patient: `data_loaded` alone is overwritten on every iteration,
# so checking only it would validate just the last patient.
load_flags = []

# Fetch every patient. Data and label files are paired purely by their position in the
# two listings, so both listings must enumerate patients in the same order.
for index in range(len(data_names)):

    dicom_name = data_names[index]
    label_name = label_names[index]
    data_url = url + '/recon/' + 'spie_2024/' + 'misc/' + 'data/' + dicom_name
    label_url = url + '/recon/' + 'spie_2024/' + 'misc/' + 'label/' + label_name

    # Fetch the data from remote.
    data = fetch_data(data_url)
    lab = fetch_data(label_url)

    # Load the SPECT volume with the DICOM loader and the label map from NRRD.
    volume, data_loaded = dicom_loader.LoadSinglePatient(data)
    load_flags.append(data_loaded)
    header = nrrd.read_header(lab)
    labels = nrrd.read_data(header, lab)

    # The label export is inconsistent about which value marks the structure of interest,
    # so keep whichever of the values 1 and 2 occurs less often (ties keep 1) as a 0/1 mask.
    # The [2, 1, 0] transpose reverses the axis order (presumably to match the SPECT volume
    # layout — TODO confirm); it is computed once and reused.
    labels_t = np.transpose(labels, [2, 1, 0])
    prob_val_1 = np.sum(labels_t == 1)
    prob_val_2 = np.sum(labels_t == 2)
    lv_value = 2 if prob_val_1 > prob_val_2 else 1
    labels = np.where(labels_t == lv_value, 1, 0)

    subject = {
        'spect' : volume,
        'left_ventricle' : labels
    }
    subjects.append(subject)

    age, gender, weight, height = dicom_loader.CalculatePatientStatistics()
    subject_data = {
        'age' : age,
        'gender' : gender,
        'weight' : weight,
        'height' : height
    }
    subjects_data.append(subject_data)

    # normalize_volume's return value is ignored, so this relies on it normalising `volume`
    # in place (the subject dict above holds the same array object) — TODO confirm.
    normalize_volume(volume)

# True only if at least one patient was processed and every single load succeeded.
data_loaded = bool(load_flags) and all(load_flags)
assert (data_loaded)

In [None]:
dicom_loader = LoadVolumes()

# Initialize data fetching from remote; the configuration lives in data/remote.yml.
data_loaded = False
url, datasets = load_remote_data()

# Read all filenames from the remote directory listings (every <a href> on each page).
from bs4 import BeautifulSoup
import requests

page = requests.get(url + '/recon/' + 'spie_2024/' + 'bela/' + 'label/')
# Fail loudly on an HTTP error instead of scraping links out of an error page.
page.raise_for_status()
soup = BeautifulSoup(page.content, 'html.parser')
label_names = [label_ref.get('href') for label_ref in soup.find_all('a')]

page = requests.get(url + '/recon/' + 'spie_2024/' + 'bela/' + 'data/')
page.raise_for_status()
soup = BeautifulSoup(page.content, 'html.parser')
data_names = [data_ref.get('href') for data_ref in soup.find_all('a')]

subjects_bela = []
subjects_bela_data = []
# One load flag per patient: `data_loaded` alone is overwritten on every iteration,
# so checking only it would validate just the last patient.
load_flags = []

# Fetch every patient. Data and label files are paired purely by their position in the
# two listings, so both listings must enumerate patients in the same order.
for index in range(len(data_names)):

    dicom_name = data_names[index]
    label_name = label_names[index]
    data_url = url + '/recon/' + 'spie_2024/' + 'bela/' + 'data/' + dicom_name
    label_url = url + '/recon/' + 'spie_2024/' + 'bela/' + 'label/' + label_name

    # Fetch the data from remote.
    data = fetch_data(data_url)
    lab = fetch_data(label_url)

    # Load the SPECT volume with the DICOM loader and the label map from NRRD.
    volume, data_loaded = dicom_loader.LoadSinglePatient(data)
    load_flags.append(data_loaded)
    header = nrrd.read_header(lab)
    labels = nrrd.read_data(header, lab)

    # The label export is inconsistent about which value marks the structure of interest,
    # so keep whichever of the values 1 and 2 occurs less often (ties keep 1) as a 0/1 mask.
    # The [2, 1, 0] transpose reverses the axis order (presumably to match the SPECT volume
    # layout — TODO confirm); it is computed once and reused.
    labels_t = np.transpose(labels, [2, 1, 0])
    prob_val_1 = np.sum(labels_t == 1)
    prob_val_2 = np.sum(labels_t == 2)
    lv_value = 2 if prob_val_1 > prob_val_2 else 1
    labels = np.where(labels_t == lv_value, 1, 0)

    subject = {
        'spect' : volume,
        'left_ventricle' : labels
    }
    subjects_bela.append(subject)

    age, gender, weight, height = dicom_loader.CalculatePatientStatistics()
    subject_data = {
        'age' : age,
        'gender' : gender,
        'weight' : weight,
        'height' : height
    }
    subjects_bela_data.append(subject_data)

    # normalize_volume's return value is ignored, so this relies on it normalising `volume`
    # in place (the subject dict above holds the same array object) — TODO confirm.
    normalize_volume(volume)

# True only if at least one patient was processed and every single load succeeded.
data_loaded = bool(load_flags) and all(load_flags)
assert (data_loaded)

In [None]:
import torch
import matplotlib.pyplot as plt
from monai.losses import DiceCELoss
from monai.data import DataLoader, Dataset
from monai.config import print_config
from monai.transforms import (
    Compose,
    CopyItemsd,
    SpatialPadd,
    EnsureChannelFirstd,
    Spacingd
)

In [None]:
# Merge the "bela" cohort into the main lists. Membership is checked by identity so that
# re-running this cell does not append the same subjects a second time, and the
# demographics list is merged too so subjects[i] and subjects_data[i] stay aligned.
for subject in subjects_bela:
    if not any(subject is existing for existing in subjects):
        subjects.append(subject)
for subject_data in subjects_bela_data:
    if not any(subject_data is existing for existing in subjects_data):
        subjects_data.append(subject_data)

In [None]:
import random

# Seed Python's global RNG, then draw a full-length sample without replacement —
# a shuffled copy that leaves `subjects` itself in its original order.
random.seed(42)
shuffled_subjects = random.sample(population=subjects, k=len(subjects))

In [None]:
from monai.utils import set_determinism, first

# Split sizes: the first train_size shuffled subjects are used for training, the next
# val_size for validation, and up to test_size of the remaining ones are held out.
train_size = 60
val_size = 14
test_size = 10

# Fail with a clear message rather than an IndexError part-way through the split.
assert len(shuffled_subjects) >= train_size + val_size, (
    f"need at least {train_size + val_size} subjects for the train/val split, "
    f"got {len(shuffled_subjects)}"
)

train_data = shuffled_subjects[:train_size]
val_data = shuffled_subjects[train_size:train_size + val_size]
# Held-out subjects (at most test_size of them); not consumed by the visible cells.
test_data = shuffled_subjects[train_size + val_size:train_size + val_size + test_size]

# Set determinism (per MONAI docs this seeds the Python, NumPy and torch RNGs).
set_determinism(seed=123)

In [None]:
# Define training transforms (also reused for the validation datasets below).
train_transforms = Compose(
    [
        # The arrays are stored without a channel axis; prepend one for MONAI.
        EnsureChannelFirstd(keys=["spect", 'left_ventricle'], channel_dim='no_channel'),
        # Resample to 2 mm voxels. The label map must use nearest-neighbour interpolation:
        # bilinear resampling blends 0/1 voxels into fractional values, which breaks the
        # integer class labels DiceCELoss(to_onehot_y=True) expects.
        Spacingd(keys=["spect", 'left_ventricle'], pixdim=(2.0, 2.0, 2.0), mode=("bilinear", "nearest")),
        # Pads up to 64^3 only; volumes larger than 64 along an axis are not cropped.
        SpatialPadd(keys=["spect", 'left_ventricle'], spatial_size=(64, 64, 64)),
        # RandSpatialCropSamplesd(keys=["spect", 'left_ventricle'], roi_size=(64, 64, 64), random_size=False, num_samples=2),
        # NOTE(review): the copies made here are stored under new keys that the training loops never read.
        CopyItemsd(keys=["spect", 'left_ventricle'], allow_missing_keys=False),
    ]
)

## nnFormer

In [None]:
from nnFormer.nnformer.network_architecture.nnFormer_synapse import nnFormer
from torch import nn

In [None]:
# nnFormer (Synapse variant) for two-class segmentation of single-channel 64^3 inputs.
# Deep supervision is enabled; the training loops feed only outputs[0] to the loss.
nnformer_model = nnFormer(
    input_channels=1,
    num_classes=2,
    crop_size=[64, 64, 64],
    embedding_dim=192,
    conv_op=nn.Conv3d,
    depths=[2, 2, 2, 2],
    num_heads=[6, 12, 24, 48],
    patch_size=[2, 4, 4],
    window_size=[4, 4, 8, 4],
    deep_supervision=True,
)

In [None]:
model = nnformer_model.to(device)

# Hyper-parameters for the training loop
max_epochs = 100
val_interval = 1
batch_size = 6
lr = 1e-5
epoch_loss_values = []
step_loss_values = []
val_loss_values = []
# Start at +inf so the first validation pass always produces a checkpoint.
best_val_loss = float("inf")

# Loss function and optimizer
loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Plain (uncached) MONAI Datasets: the transforms re-run on every access.
# NOTE(review): a CacheDataset would avoid recomputing these deterministic transforms each epoch.
train_ds = Dataset(data=train_data, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

val_ds = Dataset(data=val_data, transform=train_transforms)
# The averaged validation loss does not depend on sample order, so no shuffling.
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)

In [None]:
for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0

    for batch_data in train_loader:
        step += 1

        inputs, labels = (
            batch_data["spect"].to(device),
            batch_data["left_ventricle"].to(device)
        )

        optimizer.zero_grad()
        outputs = model(inputs)

        # Deep supervision is enabled in the model; only outputs[0] contributes to the loss
        # (presumably the full-resolution head — TODO confirm).
        total_loss = loss_fn(outputs[0], labels)

        total_loss.backward()
        optimizer.step()
        epoch_loss += total_loss.item()
        step_loss_values.append(total_loss.item())

        # len(train_loader) also counts a final partial batch, unlike len(train_ds) // batch_size.
        print(
            f"{step}/{len(train_loader)}, "
            f"train_loss: {total_loss.item():.4f}, "
        )

    epoch_loss /= step

    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if epoch % val_interval == 0:
        print("Entering Validation for epoch: {}".format(epoch + 1))
        total_val_loss = 0
        val_step = 0
        model.eval()
        # Validation needs no gradients; without no_grad every forward pass would build
        # and hold an autograd graph.
        with torch.no_grad():
            for val_batch in val_loader:
                val_step += 1
                inputs, labels = (
                    val_batch["spect"].to(device),
                    val_batch['left_ventricle'].to(device),
                )
                outputs = model(inputs)
                val_loss = loss_fn(outputs[0], labels)
                total_val_loss += val_loss.item()

        total_val_loss /= val_step
        val_loss_values.append(total_val_loss)
        print(f"epoch {epoch + 1} Validation avg loss: {total_val_loss:.4f}")

        if total_val_loss < best_val_loss:
            print(f"Saving new model based on validation loss {total_val_loss:.4f}")
            best_val_loss = total_val_loss
            # Record the epoch these weights come from, not the planned total (max_epochs).
            checkpoint = {"epoch": epoch + 1, "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
            torch.save(checkpoint, "spect_cardiac/nbs/batch_study_files/nnSynapse_100E_30D_6B.pt")
print("Done")

In [None]:
# Per-epoch mean training loss (DiceCE) of the run above.
fig, ax = plt.subplots()
ax.plot(epoch_loss_values)
ax.set(title="Training loss", xlabel="Epochs", ylabel="Mean DiceCE loss")
plt.show()

In [None]:
# Mean validation loss (DiceCE), one point per validation pass.
fig, ax = plt.subplots()
ax.plot(val_loss_values)
ax.set(title="Validation loss", xlabel="Epochs", ylabel="Mean DiceCE loss")
plt.show()

## Shape prior

In [None]:
import copy as cp
import numpy as np
from scipy.spatial.distance import cdist
import torch
import point_cloud_utils as pcu

from scipy.spatial.transform import Rotation as R
import mcubes
from simpleicp import PointCloud, SimpleICP

from spect_cardiac.src.algs.arm import lv_indicator
from spect_cardiac.src.tools.recon.projector import forward_projector, backward_projector
from spect_cardiac.src.tools.manip.manip import normalize_volume

# data fetching and handling
from spect_cardiac.data.check_database import load_remote_data
from spect_cardiac.data.fetch_data import fetch_data
from spect_cardiac.src.tools.data.loadvolumes import LoadVolumes

In [None]:
dicom_loader = LoadVolumes()

# Remote access is configured in data/remote.yml.
data_loaded = False
url, datasets = load_remote_data()

# Pick patient #10 from the raw turkey_par collection and build its URL.
dicom_name = datasets['raw/']['turkey_par/'][10]
data_url = ''.join([url, '/raw/', 'turkey_par/', dicom_name])

# Download and decode the DICOM projection frames.
data = fetch_data(data_url)
frames, data_loaded = dicom_loader.LoadSinglePatient(data)

# normalize_volume's result is discarded, so it presumably normalises in place (TODO confirm);
# the +1 shift then produces a new array.
normalize_volume(frames)
frames = frames + 1

assert data_loaded

In [None]:
# Unpack the stack dimensions (this also requires `frames` to be 3-D).
(num_frames, width, height) = frames.shape

# Back-project the frame stack into a volume for the shape-prior experiments.
bprojectpor = backward_projector()
lv_volume = bprojectpor(frames)

In [None]:
# Build a bank of synthetic shape priors via lv_indicator with randomly drawn shape parameters.
normalize_volume(lv_volume)

num_prior = 9
shape_priors = np.zeros([num_prior, *lv_volume.shape])

# Dedicated, seeded generator: the prior bank is reproducible and no longer depends on
# how much of NumPy's global RNG earlier cells happened to consume.
prior_rng = np.random.default_rng(42)
wall_thickness = prior_rng.uniform(0.3, 1.0, num_prior)
rot_angles = prior_rng.uniform(0, 2 * np.pi, num_prior)
curvature = prior_rng.uniform(1.5, 3, num_prior)
# Same (-1, -0.5) range as before, now written low-to-high.
sigmas = prior_rng.uniform(-1.0, -0.5, num_prior)

for i in range(num_prior):
    volume = np.zeros([*lv_volume.shape])
    params = dict(a=wall_thickness[i], c=curvature[i], sigma=sigmas[i])
    # NOTE(review): rot_mx is computed but never used — transform_params passes an identity
    # matrix, so rot_angles currently has no effect on the priors. Confirm whether
    # rot_mx.as_matrix() was meant to go into transform_params.
    rot_mx = R.from_quat([0, 0, np.sin(rot_angles[i]), np.cos(rot_angles[i])])

    # NOTE(review): layout presumably [rotation, translation, scale] — confirm against lv_indicator.
    transform_params = [np.eye(3, 3), [16, 16, 0], 1.5]
    shape_priors[i] = lv_indicator(volume, params, transform_params, a_plot=False)

    recon_mode = 'basic'
    fprojector = forward_projector(recon_mode)

    # NOTE(review): this overwrites the patient `frames` loaded earlier, so the stand-alone
    # sample_prior call further down operates on the projection of the last synthetic prior.
    frames = fprojector(shape_priors[i])

In [None]:
def nonlinear_shape_prior(shape_priors, kernel, sigma, centering_point):
    """
    Nonlinear statistical shape prior based on kernel density estimation in feature space.
        [1] Shape statistics in kernel space for variational image segmentation - Daniel Cremers, Timo Kohlberger,
                                                                                  Christoph Schnoerr
        [2] Active Shape Models - Their Training and Application - T. F. Cootes, C. J. Taylor, D. H. Cooper, J. Graham

    Each prior volume is converted to a surface point cloud with marching cubes (iso-level 0.0).
    The cloud is scaled so its largest vertex-to-vertex distance is 1, then translated so its
    centroid lies on ``centering_point``. The double-centred Gram matrix of ``kernel`` over all
    pairs of clouds is eigen-decomposed to build the regularised covariance of [1].

    Args:
        shape_priors: ``m`` 3-D volumes (``shape_priors.shape[0] == m``; ``shape_priors[i]`` is one volume).
        kernel: callable ``kernel(a, b, sigma)`` returning a scalar similarity between two point clouds.
        sigma: kernel parameter, forwarded to ``kernel`` and returned unchanged.
        centering_point: length-3 array; every point cloud's centroid is moved onto it.

    Returns:
        Tuple ``(z_i, Sigma_ort_inv, L, V, sigma_ort, sigma, first_cplx, min_shape_face_count,
        mean_shape, mean_shape_face, K_sum, K)``:
          - z_i: list of ``m`` float64 point-cloud tensors.
          - Sigma_ort_inv: inverse of the regularised (m, m) covariance.
          - L, V: eigenvalues (descending) / eigenvectors of the centred Gram matrix; entries
            from ``first_cplx`` on are zeroed.
          - sigma_ort: half the smallest retained eigenvalue (1 if nothing was truncated).
          - first_cplx: index of the first eigenvalue <= 1e-6, or -1 if there is none.
          - min_shape_face_count: smallest triangle count over all prior meshes (0-dim int32 tensor).
          - mean_shape, mean_shape_face: vertices / faces of the middle prior, used as the mean shape.
          - K_sum, K: sum of the Gram matrix and the (m, m) Gram matrix itself.

    Raises:
        ValueError: if a prior yields fewer than two surface vertices, or if every eigenvalue
            of the centred Gram matrix is <= 1e-6 (no shape mode can be retained).
    """
    from scipy.spatial.distance import pdist

    m = shape_priors.shape[0]

    E = (1 / m) * torch.ones([m, m], dtype=torch.float64)
    K = torch.zeros([m, m], dtype=torch.float64)

    # Unpacking also checks that each prior is a 3-D volume; vertices are scaled by the last axis.
    height, width, depth = shape_priors[0].shape
    z_i = []
    shape_face_count = torch.zeros([m], dtype=torch.int32)
    shape_faces = []
    for i in range(m):
        verts_shape, tri_shape = mcubes.marching_cubes(shape_priors[i], 0.0)
        if verts_shape.shape[0] < 2:
            raise ValueError(f"shape prior {i} produced fewer than two surface vertices")
        cur_prior_shape = verts_shape / depth

        # Scale the mesh to unit diameter and move its centroid onto the centering point.
        # pdist yields the same maximum as the full cdist matrix with half the memory.
        verts_scaled = cur_prior_shape * 1.0 / pdist(cur_prior_shape, 'euclidean').max()
        verts_scaled_translation = centering_point - verts_scaled.mean(axis=0)
        verts_translated = verts_scaled + verts_scaled_translation

        z_i.append(torch.from_numpy(verts_translated))
        shape_faces.append(tri_shape)
        shape_face_count[i] = tri_shape.shape[0]

    min_shape_face_count = shape_face_count.min()
    # The middle sample stands in for the mean shape (a Wasserstein barycenter would be the
    # principled choice); its faces are kept as well.
    mean_shape = z_i[m // 2]
    mean_shape_face = shape_faces[m // 2]

    for i in range(m):
        for j in range(m):
            K[i, j] = kernel(z_i[i], z_i[j], sigma)

    # Double-centred Gram matrix: K~ = K - KE - EK + EKE.
    K_til = K - K @ E - E @ K + E @ K @ E

    # K_til is symmetric, so eigh returns real eigenpairs in ascending order; flip to descending.
    L, V = torch.linalg.eigh(K_til)
    L = torch.flip(L, [0])
    V = torch.fliplr(V)

    limit_val = 1e-6
    if (L <= limit_val).any():
        # First (numerically) non-positive eigenvalue: this and all later modes are dropped.
        first_cplx = torch.where(L <= limit_val)[0][0]
        if first_cplx == 0:
            # Otherwise L[first_cplx - 1] would silently wrap around to L[-1].
            raise ValueError(
                f"all eigenvalues of the centred Gram matrix are <= {limit_val}; "
                "the prior shapes are indistinguishable under the given kernel"
            )
        sigma_ort = L[first_cplx - 1] / 2.0

        L[first_cplx:] = 0.0
        V[:, first_cplx:] = 0.0
        reg_mx = torch.eye(K.shape[0], dtype=K.dtype)

        # Regularised covariance of [1]: retained modes plus sigma_ort on the orthogonal complement.
        Sigma_ort = V @ torch.diag(L) @ V.t() + sigma_ort * (reg_mx - V @ V.t())
    else:
        # NOTE(review): centring should always leave a ~0 eigenvalue, so this branch means the
        # covariance is left unregularised.
        first_cplx = -1
        sigma_ort = 1
        Sigma_ort = V @ torch.diag(L) @ V.t()

    return z_i, torch.linalg.inv(Sigma_ort), L, V, sigma_ort, sigma, first_cplx, min_shape_face_count, mean_shape, mean_shape_face, K.sum(), K

In [None]:
def k_til(k, sigma, x, y, z_i, m):
    """
    Centred kernel value k~(x, y) with respect to the first ``m`` training samples ``z_i``:

        k(x, y) - (1/m) * sum_i [k(x, z_i) + k(y, z_i)] + (1/m^2) * sum_{i,j} k(z_i, z_j)

    The terms are accumulated in exactly that order.
    """
    samples = [z_i[i] for i in range(m)]

    total = 0
    for z_a in samples:
        total -= (1 / m) * (k(x, z_a, sigma) + k(y, z_a, sigma))

    total += k(x, y, sigma)

    for z_a in samples:
        for z_b in samples:
            total += (1 / (m ** 2)) * k(z_a, z_b, sigma)

    return total

In [None]:
# alpha_i needs some rescaling, named V[k, i] here
def E_phi_grad_opt(V, kernel, k_m, k_matrix_sum, sigma, z_i, z, L, L_ort, r, m):
    """
    Gradient of the kernel-space shape energy of [1] (Cremers et al.) with respect to ``z``.

    Args:
        V: (m, m) eigenvectors of the centred Gram matrix (columns are modes). Not modified.
        kernel: callable ``kernel(a, b, sigma)`` returning a scalar tensor, differentiable in ``b``.
        k_m: (m, m) Gram matrix of the training shapes.
        k_matrix_sum: sum of all entries of ``k_m``.
        sigma: kernel parameter.
        z_i: sequence of the ``m`` training point clouds.
        z: point cloud to evaluate at; must have ``requires_grad=True`` (gradients are taken w.r.t. it).
        L: eigenvalues matching the columns of ``V`` (descending).
        L_ort: variance assigned to the orthogonal complement (sigma_ort).
        r: number of leading modes to use.
            NOTE(review): when r == -1 (no truncation upstream) the mode loop is empty, but
            alpha[:, :-1] is still rescaled.
        m: number of training shapes.

    Returns:
        (grad, k_til): ``grad`` has ``z.shape`` (torch default dtype); ``k_til`` is the length-m
        vector of centred kernel values k~(z_i, z).
    """
    grad = torch.zeros(z.shape)

    # Gradient and value of k(z_i, z) for every training shape. Each kernel is evaluated once.
    # retain_graph keeps its graph alive so k_til stays differentiable w.r.t. z, as before.
    par_z = torch.zeros([m, *z.shape])
    kernel_ = torch.zeros([m])
    for i in range(m):
        k_val = kernel(z_i[i], z, sigma)
        par_z[i] = torch.autograd.grad(k_val, [z], retain_graph=True)[0]
        kernel_[i] = k_val

    # Centred kernel vector: k~_i = k(z_i, z) - mean_j k(z_j, z) - mean_j k(z_i, z_j) + mean_{jl} k(z_j, z_l).
    k_til = kernel_ - (1 / m) * kernel_.sum(dim=0) - (1 / m) * k_m.sum(dim=1) + (1 / (m ** 2)) * k_matrix_sum

    # Mean gradient; grad k~_i = par_z[i] - par_z_sum.
    par_z_sum = (1 / m) * par_z.sum(dim=0)

    # clone(), not copy.copy: a shallow copy of a tensor shares its storage, so the in-place
    # scaling below would silently rescale the caller's V on every call.
    alpha = V.clone()
    alpha[:, :r] *= torch.sqrt(L[:r])[None, :]

    for k in range(r):
        # NOTE(review): the energy in [1] differentiates to
        # (sum_i alpha_ik k~_i) * (sum_i alpha_ik grad k~_i); this accumulates
        # sum_i of the per-sample products instead — confirm which is intended.
        for i in range(m):
            grad += (alpha[i, k] * k_til[i]) * (alpha[i, k] * (par_z[i] - par_z_sum)) * (L[k] ** (-1) - L_ort ** (-1))

    # Orthogonal-complement term: -(1/m) * sum_i grad k(z_i, z), weighted by 1 / sigma_ort.
    grad += (L_ort ** (-1)) * (-par_z_sum)

    return 2.0 * grad, k_til

In [None]:
# might need to try with different models from CV, e.g.: MS, Potts
from geomloss import SamplesLoss

# Entropic blur of the Sinkhorn solver and the kernel parameter.
eps = 5 * 1e-3
# NOTE(review): no `reach` is passed, so despite the name this should be geomloss's balanced
# Sinkhorn divergence — confirm.
loss_unbalanced = SamplesLoss(loss='sinkhorn', p=2, blur=eps, scaling=0.95)
sigma = 5 * 1e0


# Alternative: k = lambda x, y, sigma : torch.exp(-loss(x, y) ** 2 / (2 * sigma ** 2))
def k(x, y, sigma):
    """Exponential kernel on the Sinkhorn divergence between point clouds: exp(-sigma * S(x, y))."""
    return torch.exp(-sigma * loss_unbalanced(x, y))


centering_point = np.array([0.45, 0.45, 0.45])

(z_i, sigma_inv, L, V, sigma_ort, sigma, first_cplx, min_shape_face_count,
 mean_shape, mean_shape_face, k_matrix_sum, k_matrix) = nonlinear_shape_prior(
    shape_priors, kernel=k, sigma=sigma, centering_point=centering_point)

In [None]:
def sample_prior(frames, min_shape_face_count, V, k, k_matrix, k_matrix_sum, sigma, z_i, L, sigma_ort, first_cplx):
    """
    Shape-prior gradient for the surface extracted from ``frames``.

    Pipeline:
      1. back-project ``frames`` to a volume and normalise it;
      2. extract the 0.5 iso-surface and decimate it with a face-count target of
         min(min_shape_face_count, #faces);
      3. scale the vertices to unit diameter and move their centroid onto ``centering_point``;
      4. rigidly align them to ``mean_shape`` with ICP;
      5. evaluate ``E_phi_grad_opt`` at the aligned cloud.

    NOTE(review): reads the module-level globals ``centering_point`` and ``mean_shape``
    (set by the nonlinear_shape_prior cell), not parameters.

    Args:
        frames: 3-D stack accepted by backward_projector.
        min_shape_face_count: face-count target from nonlinear_shape_prior (int or 0-dim tensor).
        V, L, sigma_ort, first_cplx, k_matrix, k_matrix_sum, z_i: outputs of nonlinear_shape_prior.
        k: kernel callable ``k(a, b, sigma)``.
        sigma: kernel parameter.

    Returns:
        (grad_E, k_til) as returned by E_phi_grad_opt.

    Raises:
        ValueError: if ``frames`` is not 3-D.
    """
    if len(frames.shape) != 3:
        raise ValueError(f"frames must be 3-D, got shape {tuple(frames.shape)}")

    bprojectpor = backward_projector()
    lv_volume = bprojectpor(frames)

    # normalize_volume's result is discarded, so it presumably works in place — TODO confirm.
    normalize_volume(lv_volume)
    verts, faces = mcubes.marching_cubes(lv_volume, 0.5)
    # int(): min_shape_face_count comes out of nonlinear_shape_prior as a 0-dim torch tensor.
    target_faces = int(min(min_shape_face_count, faces.shape[0]))
    v_decimate, f_decimate, v_correspondence, f_correspondence = pcu.decimate_triangle_mesh(
        verts, faces.astype(np.int32), target_faces)

    m = len(z_i)
    _, cols, _ = lv_volume.shape
    z = v_decimate / cols

    # Renormalise to unit diameter and translate the centroid onto centering_point.
    z_scaled = torch.from_numpy(z * (1.0 / cdist(z, z, 'euclidean').max()))
    z_translation = z_scaled.mean(axis=0) - torch.from_numpy(centering_point)

    # Project the current shape onto the mean shape as in [1] (rigid ICP alignment).
    pc_fix = PointCloud(mean_shape.detach().numpy(), columns=["x", "y", "z"])
    pc_mov = PointCloud((z_scaled - z_translation).detach().numpy(), columns=["x", "y", "z"])
    icp = SimpleICP()
    icp.add_point_clouds(pc_fix, pc_mov)
    H, proj_mean_icp, rigid_body_transformation_params, distance_residuals = icp.run(max_overlap_distance=1)

    # The aligned cloud is the point at which the energy gradient is taken.
    proj_mean = torch.from_numpy(proj_mean_icp)
    proj_mean.requires_grad = True
    grad_E, k_til = E_phi_grad_opt(V, k, k_matrix, k_matrix_sum, sigma, z_i, proj_mean, L, sigma_ort, first_cplx, m)
    return grad_E, k_til

In [None]:
# Stand-alone evaluation of the prior gradient on the current `frames`. The centred kernel
# vector is stored as k_til_values so the k_til helper function is not replaced by a tensor.
grad_E, k_til_values = sample_prior(frames, min_shape_face_count, V, k, k_matrix, k_matrix_sum, sigma, z_i, L, sigma_ort, first_cplx)

## nnFormer + shape prior (SP)

In [None]:
from nnFormer.nnformer.network_architecture.nnFormer_synapse import nnFormer
import torch.nn as nn

In [None]:
# Fresh nnFormer (Synapse variant) for the shape-prior run: two-class segmentation of
# single-channel 64^3 inputs, deep supervision enabled (the loops use outputs[0]).
nnformer_model = nnFormer(
    input_channels=1,
    num_classes=2,
    crop_size=[64, 64, 64],
    embedding_dim=192,
    conv_op=nn.Conv3d,
    depths=[2, 2, 2, 2],
    num_heads=[6, 12, 24, 48],
    patch_size=[2, 4, 4],
    window_size=[4, 4, 8, 4],
    deep_supervision=True,
)

In [None]:
model = nnformer_model.to(device)
# Hyper-parameters for the training loop
max_epochs = 50
val_interval = 1
# Must stay 1: the training/validation loops squeeze the batch axis away to hand a single
# volume to sample_prior.
batch_size = 1
lr = 1e-5
epoch_loss_values = []
step_loss_values = []
val_loss_values = []
# Start at +inf so the first validation pass always produces a checkpoint.
best_val_loss = float("inf")

# Loss function and optimizer
loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Plain (uncached) MONAI Datasets: the transforms re-run on every access.
# NOTE(review): a CacheDataset would avoid recomputing these deterministic transforms each epoch.
train_ds = Dataset(data=train_data, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

val_ds = Dataset(data=val_data, transform=train_transforms)
# The averaged validation loss does not depend on sample order, so no shuffling.
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4)


In [None]:
# sample_prior handles one volume at a time (batch and channel axes are squeezed away
# below), so both loaders must yield single-sample batches.
assert train_loader.batch_size == 1 and val_loader.batch_size == 1, "nnFormer + SP loops require batch_size == 1"

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0

    for batch_data in train_loader:
        step += 1

        # Drop the batch and channel axes to get one volume for the shape prior.
        frames = batch_data['spect'].squeeze(0).squeeze(0)

        grad_E, _ = sample_prior(frames, min_shape_face_count, V, k, k_matrix, k_matrix_sum, sigma, z_i, L, sigma_ort, first_cplx)

        inputs, labels = (
            batch_data['spect'].to(device),
            batch_data["left_ventricle"].to(device)
        )
        # The prior gradient is an input feature, not something to train through. Detach it
        # so backward() does not walk into the kernel/Sinkhorn graph built by sample_prior.
        grad_E = grad_E.detach().to(device)

        optimizer.zero_grad()
        outputs = model(inputs, grad_E)

        # Deep supervision is enabled in the model; only outputs[0] contributes to the loss.
        total_loss = loss_fn(outputs[0], labels)

        total_loss.backward()
        optimizer.step()
        epoch_loss += total_loss.item()
        step_loss_values.append(total_loss.item())

        print(
            f"{step}/{len(train_loader)}, "
            f"train_loss: {total_loss.item():.4f}, "
        )

    epoch_loss /= step

    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if epoch % val_interval == 0:
        print("Entering Validation for epoch: {}".format(epoch + 1))
        total_val_loss = 0
        val_step = 0
        # Evaluation mode, consistent with the other validation loops in this notebook.
        model.eval()
        for val_batch in val_loader:
            val_step += 1
            frames = val_batch['spect'].squeeze(0).squeeze(0)
            # sample_prior differentiates through the kernel internally, so it must run with
            # autograd enabled; only the model forward goes under no_grad.
            grad_E, _ = sample_prior(frames, min_shape_face_count, V, k, k_matrix, k_matrix_sum, sigma, z_i, L, sigma_ort, first_cplx)

            inputs, labels = (
                val_batch["spect"].to(device),
                val_batch['left_ventricle'].to(device),
            )
            with torch.no_grad():
                outputs = model(inputs, grad_E.detach().to(device))
                val_loss = loss_fn(outputs[0], labels)
            total_val_loss += val_loss.item()

        total_val_loss /= val_step
        val_loss_values.append(total_val_loss)
        print(f"epoch {epoch + 1} Validation avg loss: {total_val_loss:.4f}")

        if total_val_loss < best_val_loss:
            print(f"Saving new model based on validation loss {total_val_loss:.4f}")
            best_val_loss = total_val_loss
            # Record the epoch these weights come from, not the planned total (max_epochs).
            checkpoint = {"epoch": epoch + 1, "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
            torch.save(checkpoint, "spect_cardiac/nbs/batch_study_files/nnSynapse_SP_50_60_1.pt")
print("Done")

In [None]:
# Per-epoch mean training loss (DiceCE) of the nnFormer + SP run.
fig, ax = plt.subplots()
ax.plot(epoch_loss_values)
ax.set(title="Training loss", xlabel="Epochs", ylabel="Mean DiceCE loss")
plt.show()

In [None]:
# Mean validation loss (DiceCE) of the nnFormer + SP run, one point per validation pass.
fig, ax = plt.subplots()
ax.plot(val_loss_values)
ax.set(title="Validation loss", xlabel="Epochs", ylabel="Mean DiceCE loss")
plt.show()

## SwinUNETR

In [None]:
from monai.networks.nets import SwinUNETR

In [None]:
# SwinUNETR baseline: single-channel 64^3 input, two output channels (binary segmentation),
# 48-dimensional feature embedding.
swin_model = SwinUNETR(
    in_channels=1,
    out_channels=2,
    feature_size=48,
    img_size=(64, 64, 64),
)


In [None]:
model = swin_model.to(device)

# Hyper-parameters for the training loop
max_epochs = 50
val_interval = 1
batch_size = 2
lr = 1e-5
epoch_loss_values = []
step_loss_values = []
val_loss_values = []
# Start at +inf so the first validation pass always produces a checkpoint.
best_val_loss = float("inf")

# Loss function and optimizer
loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Plain (uncached) MONAI Datasets: the transforms re-run on every access.
# NOTE(review): a CacheDataset would avoid recomputing these deterministic transforms each epoch.
train_ds = Dataset(data=train_data, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

val_ds = Dataset(data=val_data, transform=train_transforms)
# The averaged validation loss does not depend on sample order, so no shuffling.
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)

In [None]:
for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0

    for batch_data in train_loader:
        step += 1

        inputs, labels = (
            batch_data["spect"].to(device),
            batch_data["left_ventricle"].to(device)
        )

        optimizer.zero_grad()
        outputs = model(inputs)

        # SwinUNETR's output is used directly (no deep-supervision list here).
        total_loss = loss_fn(outputs, labels)

        total_loss.backward()
        optimizer.step()
        epoch_loss += total_loss.item()
        step_loss_values.append(total_loss.item())

        # len(train_loader) also counts a final partial batch, unlike len(train_ds) // batch_size.
        print(
            f"{step}/{len(train_loader)}, "
            f"train_loss: {total_loss.item():.4f}, "
        )

    epoch_loss /= step

    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if epoch % val_interval == 0:
        print("Entering Validation for epoch: {}".format(epoch + 1))
        total_val_loss = 0
        val_step = 0
        model.eval()
        # Validation needs no gradients; without no_grad every forward pass would build
        # and hold an autograd graph.
        with torch.no_grad():
            for val_batch in val_loader:
                val_step += 1
                inputs, labels = (
                    val_batch["spect"].to(device),
                    val_batch['left_ventricle'].to(device),
                )
                outputs = model(inputs)
                val_loss = loss_fn(outputs, labels)
                total_val_loss += val_loss.item()

        total_val_loss /= val_step
        val_loss_values.append(total_val_loss)
        print(f"epoch {epoch + 1} Validation avg loss: {total_val_loss:.4f}")

        if total_val_loss < best_val_loss:
            print(f"Saving new model based on validation loss {total_val_loss:.4f}")
            best_val_loss = total_val_loss
            # Record the epoch these weights come from, not the planned total (max_epochs).
            checkpoint = {"epoch": epoch + 1, "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
            torch.save(checkpoint, "spect_cardiac/nbs/batch_study_files/SwinUNetR_50E_60D_2B.pt")
print("Done")

In [None]:
# Per-epoch mean training loss (DiceCE) of the SwinUNETR run.
fig, ax = plt.subplots()
ax.plot(epoch_loss_values)
ax.set(title="Training loss", xlabel="Epochs", ylabel="Mean DiceCE loss")
plt.show()

In [None]:
# Mean validation loss (DiceCE) of the SwinUNETR run, one point per validation pass.
fig, ax = plt.subplots()
ax.plot(val_loss_values)
ax.set(title="Validation loss", xlabel="Epochs", ylabel="Mean DiceCE loss")
plt.show()