In [1]:
# Core Libraries
import numpy as np
import pandas as pd
import pickle

#Visualization
import matplotlib.pyplot as plt
import seaborn as sns

#Machine learning, metrics and image processing
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
from sklearn.metrics import f1_score, roc_auc_score
from xgboost import XGBClassifier
from scipy.ndimage import label, rotate
from skimage import measure

#Medical Imaging
import SimpleITK as sitk
from monai.networks.nets import SwinUNETR
import nibabel as nib

#DL libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [2]:
from src.image_preprocessing import(
    preprocess_and_rotate_images_and_masks,
    keep_largest_blob, 
    process_test_set_masks, 
    crop_and_pad_mask, 
    process_and_stack_masks, 
    separate_blobs,
    MaskAlignerAllAngles)    
    
from src.utils import(
    extract_number, 
    get_file_paths, 
    load_and_sort_files, load_tables,
    process_files, 
    inference, 
    process_mismatched_data,
    generate_outputs, 
    extract_latent_representations,
    create_features_dataframe, 
    plot_correlation_matrix,
    process_metadata_and_filter)
    
from src.data_loader import(
    setup_training_pipeline, 
    extract_test_set, 
    create_dataloader_from_masks)    

from src.model_training import(
    train_and_validate,
    initialize_training, 
    train_one_epoch, 
    train_model, load_model, 
    cross_validation_model_evaluation)

from src.compute_geometry import(
    compute_angles_for_masks, 
    rotate_masks_sequentially, 
    compute_disc_dimensions)

from src.image_generation import process_latent_features_and_generate_images
from src.config import config
from src.autoencoder import Autoencoder
from src.metrics import calculate_iou_list, get_iou_train

## 1) Segmentation model

In [3]:
# Resolve the image/mask directories and the two table paths from the
# configured base path.
(
    images_dir,
    masks_dir,
    rad_grad_path,
    overview_path,
) = get_file_paths(config["curr_path"])

# List the image and mask files. `extract_number` is handed to the helper
# (presumably as the sort key -- confirm in src.utils).
(
    sorted_files_img,
    sorted_files_msk,
    sorted_files_img_we,
    sorted_files_msk_we,
) = load_and_sort_files(images_dir, masks_dir, extract_number)

# Load the overview table and the `rad_grad` table.
overview_table, rad_grad_table = load_tables(overview_path, rad_grad_path)

# Read the files at the configured target resolution; the helper returns
# separate t1 / t2 images, disc masks, file names and label counts.
(
    img_t1,
    img_t2,
    msk_t1_disk,
    msk_t2_disk,
    filename_t1,
    filename_t2,
    num_disc_table_t1,
    num_disc_table_t2,
    num_labels_real_t1_list,
    num_labels_real_t2_list,
) = process_files(
    sorted_files_img,
    sorted_files_img_we,
    images_dir,
    masks_dir,
    overview_table,
    config["target_resolution"],
)

# Apply the same preprocessing/rotation helper to the t1 and the t2 data.
img_t1_MRI, msk_t1_MRI = preprocess_and_rotate_images_and_masks(
    img_t1,
    msk_t1_disk,
)
img_t2_MRI, msk_t2_MRI = preprocess_and_rotate_images_and_masks(
    img_t2,
    msk_t2_disk,
)


# Create the training/testing loaders, the model, the optimizer and the loss
# function in a single call on the t1 images and masks. All hyper-parameters
# come from `config`; only `shuffle` is fixed in the notebook.
(
    training_loader,
    testing_loader,
    model,
    optimizer,
    loss_fn,
) = setup_training_pipeline(
    img_t1_MRI,
    msk_t1_MRI,
    # network geometry
    img_size=config["img_size"],
    in_channels=config["in_channels"],
    out_channels=config["out_channels"],
    feature_size=config["feature_size"],
    # optimisation and data split
    batch_size=config["batch_size"],
    lr=config["lr"],
    test_size=config["test_size"],
    # NOTE(review): shuffling is off, presumably so the split keeps the
    # original sample order that later cells rely on -- confirm.
    shuffle=False,
)

# Hand the pipeline objects to the training helper and keep its result;
# later code reads the trained network from training_results["model"].
# NOTE(review): load_weights=True is hard-coded here (the autoencoder cell
# reads its equivalent flag from `config`). Presumably it restores the
# weights at config["path_weights"] instead of training -- confirm in
# src.model_training.
training_results = train_and_validate(
    model=model,
    optimizer=optimizer,
    criterion=loss_fn,
    training_loader=training_loader,
    testing_loader=testing_loader,
    num_epochs=config["num_epochs"],
    device=config["device"],
    save_path=config["path_weights"],
    load_weights=True,
)


# Run the trained model over the training loader and over the testing loader.
# Each call returns the outputs together with an IoU list and a Dice list.
# The last argument is 1 - test_size for the training call and test_size for
# the testing call (presumably the fraction of samples in that split --
# confirm in src.utils.inference).
train_output, iou_list_train, dice_list_train = inference(
    training_loader,
    training_results["model"],
    config["batch_size"],
    img_t1_MRI,
    1 - config["test_size"],
)
test_output, iou_list_test, dice_list_test = inference(
    testing_loader,
    training_results["model"],
    config["batch_size"],
    img_t1_MRI,
    config["test_size"],
)

# Pass the t1 images, masks, file names, label counts and the `rad_grad`
# table through `process_mismatched_data`. The `_new` masks and file names
# are the versions used by the feature-extraction cell below.
# NOTE(review): what counts as "mismatched" is decided inside the helper;
# presumably samples whose label count disagrees with the table -- confirm.
(
    img_t1_MRI_new,
    msk_t1_MRI_new,
    filename_t1_new,
    num_disc_table_t1_new,
    num_labels_real_t1_list_new,
    rad_grad_table_fin,
) = process_mismatched_data(
    img_t1_MRI=img_t1_MRI,
    msk_t1_MRI=msk_t1_MRI,
    filename_t1=filename_t1,
    num_labels_real_t1_list=num_labels_real_t1_list,
    num_disc_table_t1=num_disc_table_t1,
    rad_grad_table=rad_grad_table,
)

## 2) Feature extraction and disc narrowing predictions

In [4]:
# Prepare the autoencoder input by chaining the mask helpers:
# test split -> per-mask processing -> stacking -> data loader.
# NOTE(review): the filtered masks (`msk_t1_MRI_new`) are paired with the
# unfiltered images (`img_t1_MRI`), not `img_t1_MRI_new` -- confirm this is
# intended.
msk_t1_MRI_new_test = extract_test_set(
    msk_t1_MRI_new,
    img_t1_MRI,
    config["test_size"],
)
msk_t1_MRI_single = process_test_set_masks(msk_t1_MRI_new_test)
msk_t1_MRI_single_def = process_and_stack_masks(msk_t1_MRI_single)
training_img_msk_real = create_dataloader_from_masks(
    msk_t1_MRI_single_def,
    config["batch_size_ae"],
)


# Instantiate the autoencoder, then either train it on the mask loader or
# load saved weights, depending on config["load_weights_ae"]. Both branches
# use config["weights_path_ae"]: as `save_path` when training, as the
# load location otherwise.
autoencoder = Autoencoder()
if not config["load_weights_ae"]:
    ae_model = train_model(
        autoencoder,
        training_img_msk_real,
        num_epochs=config["num_epochs_ae"],
        lr=config["learning_rate_ae"],
        save_path=config["weights_path_ae"],
    )
    print("Training complete.")
else:
    ae_model = load_model(autoencoder, config["weights_path_ae"])
    print("Model loaded. Training skipped.")

# Run the autoencoder over the mask loader twice: once for its outputs and
# once for the latent representations.
output_num = generate_outputs(
    ae_model,
    training_img_msk_real,
    config["device"],
)
latent_representations = extract_latent_representations(
    ae_model,
    training_img_msk_real,
    config["device"],
)

# Compare the stacked input masks (converted to a tensor) with the
# autoencoder outputs using the `get_iou_train` metric.
iou_list = calculate_iou_list(
    torch.tensor(msk_t1_MRI_single_def),
    output_num,
    get_iou_train,
)

# Geometric descriptors of the stacked masks. The three angles are computed
# on the masks as they are; the disc dimensions are computed on the output of
# `rotate_masks_sequentially`.
sag_angle, trasv_angle, front_angle = compute_angles_for_masks(
    msk_t1_MRI_single_def
)
rotated_masks_t1_MRI = rotate_masks_sequentially(msk_t1_MRI_single_def)
disc_height_list, ap_width_list, lat_width_list = compute_disc_dimensions(
    rotated_masks_t1_MRI
)

# Assemble the geometric values and the latent representations into three
# frames: geometric only, latent only, and both together.
(
    geom_features_df,
    latent_features_df,
    all_features_df,
) = create_features_dataframe(
    disc_height_list,
    ap_width_list,
    lat_width_list,
    sag_angle,
    trasv_angle,
    front_angle,
    latent_representations,
)


# Filter the annotation metadata down to the retained t1 file names.
# NOTE(review): the workbook path and test_size=42 are hard-coded, whereas
# other calls read config["test_size"]. Presumably 42 is an absolute sample
# count here -- confirm in src.utils.process_metadata_and_filter.
filtered_metadata = process_metadata_and_filter(
    'annotations.xlsx',
    filename_t1_new,
    test_size=42,
)

# A single metadata column, selected by position (index 10).
# NOTE(review): despite the name `feature_columns`, this is presumably the
# disc-narrowing target column -- confirm against the workbook layout.
feature_columns = [filtered_metadata.columns[10]]

# Cross-validated evaluation. Only the geometric feature frame is passed in,
# not the latent or combined frames.
f1_scores_tot, accuracy_scores_tot = cross_validation_model_evaluation(
    geom_features_df,
    filtered_metadata,
    feature_columns,
    config["f1_score_type"],
    random_state=config["random_state"],
)


## 3) Feature interpretability

In [5]:
# Plot the correlation matrix of the combined (geometric + latent) frame.
plot_correlation_matrix(all_features_df)

# Generate images from the latent representations with the autoencoder's
# decoder, for the "sagittal" view.
# NOTE(review): numb_plots=6 and latent_dim=4 are hard-coded; latent_dim
# presumably has to match the Autoencoder's latent size -- confirm in
# src.autoencoder.
(
    images_arr,
    latent_features_list_fin,
) = process_latent_features_and_generate_images(
    latent_representations=latent_representations,
    decoder=ae_model.decoder,
    numb_plots=6,
    latent_dim=4,
    view="sagittal",
)