# In [None]:
import os
import torch
import pandas as pd
import numpy as np
from sklearn.preprocessing import RobustScaler, OneHotEncoder, StandardScaler, MultiLabelBinarizer
from sklearn.metrics.pairwise import cosine_similarity 
#from sklearn.neighbors import kneighbors_graph
import torch_geometric.data as geo_data
from tqdm import tqdm 

# --- Input / output locations (as provided) ---
# All paths are relative, so they resolve against the current working directory.
clinical_csv_path = '../kyanh/data/clinical_features_tmb.csv'
# Despite the "_csv_" in the variable name, this file is Parquet.
radiology_csv_path = '../kyanh/data/radiology_node_features.parquet'
pathology_glcm_csv_path = '../kyanh/data/patient_glcm_features.csv'

# Destination of the serialized graph object.
# NOTE(review): the name says "thresh_08" but SIMILARITY_THRESHOLD below is 0.7 — confirm which is intended.
output_multi_view_data_path = 'multi_view_pdl1_data_lesions_366_thresh_08_robust.pt'

# --- Clinical feature definitions (as provided) ---
# Every clinical column that may feed the patient feature vector.
patient_core_features = [
    'albumin',
    'smoking_status',
    'pack_years',
    'dnlr',
    'age',
    'sex',
    'TMB',
    'histo',
    'ecog',
    'io_drug',
    'pdl1_tiss_site',
    'clinical_pdl1_score',
    'tumor_burden',
]
# Subset handled as continuous values.
numerical_cols_patient = [
    'albumin',
    'pack_years',
    'dnlr',
    'age',
    'TMB',
    'clinical_pdl1_score',
    'tumor_burden',
]
# Subset handled as categories.
categorical_cols_patient = [
    'smoking_status',
    'sex',
    'histo',
    'ecog',
    'pdl1_tiss_site',
    'io_drug',
]

# --- Patient similarity parameters ---
# Threshold handed to the similarity-edge construction for each view.
SIMILARITY_THRESHOLD = 0.7

# --- 1. Load Data ---
print("--- 1. Loading Data ---")
# Clinical table, keyed by the 'main_index' column.
df_clinical_raw = pd.read_csv(clinical_csv_path, index_col='main_index')
#df_clinical_raw = df_clinical_raw[df_clinical_raw['TMB'].notna()].copy()

# Radiology table (one row per lesion is assumed by the later steps — TODO confirm).
# reset_index() turns whatever index the Parquet file carried into ordinary
# columns, so the patient key can be looked up by column name below.
# NOTE(review): if the stored index was an unnamed RangeIndex, reset_index()
# adds an 'index' column that the radiology step will treat as a feature.
df_radiology_lesions_raw = pd.read_parquet(radiology_csv_path).reset_index()
if 'main_index' in df_radiology_lesions_raw.columns:
    df_radiology_lesions_raw = df_radiology_lesions_raw.set_index('main_index')
else:
    # Fall back to the first alternative patient-ID column that exists.
    fallback_pid_cols = [col for col in ('patient_id', 'dmp_pt_id') if col in df_radiology_lesions_raw.columns]
    if not fallback_pid_cols:
        raise ValueError("Radiology data needs a 'main_index' or identifiable patient ID column to be set as index.")
    print(f"   Radiology: Setting index to '{fallback_pid_cols[0]}'. Ensure this is patient ID.")
    df_radiology_lesions_raw = df_radiology_lesions_raw.set_index(fallback_pid_cols[0])


df_path_glcm_raw = pd.read_csv(pathology_glcm_csv_path, index_col='main_index')

# Build the canonical patient order from the clinical table.
# Selecting with .loc[unique_ids] would return *every* row of a duplicated ID,
# leaving df_clinical non-unique and breaking later reindex() calls, so
# duplicates are dropped explicitly (first row per ID wins).
duplicate_clinical_rows = df_clinical_raw.index.duplicated(keep='first')
if duplicate_clinical_rows.any():
    print("Warning: Clinical data index is not unique. Using unique values for order.")
    print(f"   Dropping {int(duplicate_clinical_rows.sum())} duplicate clinical row(s); keeping the first row per ID.")
df_clinical = df_clinical_raw[~duplicate_clinical_rows].copy() # Main df_clinical is ordered and unique
patient_order = df_clinical.index.tolist()


def process_features_dataframe(df_features, numerical_cols, categorical_cols, patient_order_ref, scaler_type=RobustScaler):
    """Impute, scale and encode a feature table, then align it to a patient order.

    Numerical columns are median-imputed and scaled with ``scaler_type``.
    The ``'io_drug'`` column (comma-separated names) is multi-hot encoded.
    Every other categorical column is mode-imputed and one-hot encoded.
    Listed columns that are absent from ``df_features`` are ignored.

    Args:
        df_features: Feature table whose index holds the IDs used in
            ``patient_order_ref``. It is not modified.
        numerical_cols: Names of the columns to impute and scale.
        categorical_cols: Names of the categorical columns. ``'io_drug'`` is
            treated as multi-label, all others as single-label.
        patient_order_ref: IDs defining the row order of the result.
        scaler_type: Zero-argument callable returning an object with a
            ``fit_transform`` method (defaults to ``RobustScaler``).

    Returns:
        A DataFrame with one row per entry of ``patient_order_ref``. Rows for
        IDs missing from ``df_features``, and any value still missing after
        processing, are filled with 0. If no column could be processed the
        result has that index and no columns.
    """
    df_copy = df_features.copy()
    df_processed_list = []

    # De-duplicate the requested names (keeping order) and keep only those present.
    present_numerical_cols = [col for col in dict.fromkeys(numerical_cols) if col in df_copy.columns]
    single_label_categorical_cols = [
        col for col in dict.fromkeys(categorical_cols)
        if col != 'io_drug' and col in df_copy.columns
    ]

    # Imputation: Numerical (median).
    # Assign the result back instead of calling fillna(inplace=True) on
    # df_copy[col]: the chained in-place form does not update the frame when
    # pandas Copy-on-Write is active.
    for col in present_numerical_cols:
        if df_copy[col].isnull().any():
            df_copy[col] = df_copy[col].fillna(df_copy[col].median())

    # Imputation: Categorical (single-label, most frequent valid value).
    for col in single_label_categorical_cols:
        # Capture real missing values (NaN/None/NA) before stringifying, so
        # they cannot turn into literal categories such as 'None' or '<NA>'.
        is_missing = df_copy[col].isna()
        as_text = df_copy[col].astype(str)
        valid_entries_mask = ~(is_missing | (as_text.str.lower() == 'nan') | (as_text.str.strip() == ''))
        mode_val = as_text[valid_entries_mask].mode()
        fill_val = mode_val.iloc[0] if not mode_val.empty else 'Unknown'
        df_copy[col] = as_text.where(valid_entries_mask, fill_val)

    # Scaling of the numerical block.
    df_num_subset = df_copy[present_numerical_cols]
    if not df_num_subset.empty:
        scaler = scaler_type()
        scaled_num = scaler.fit_transform(df_num_subset)
        df_scaled_num = pd.DataFrame(scaled_num, columns=df_num_subset.columns, index=df_copy.index)
        df_processed_list.append(df_scaled_num)

    # Multi-hot encoding of the multi-label 'io_drug' column.
    if 'io_drug' in categorical_cols and 'io_drug' in df_copy.columns:
        def split_drugs(drug_value):
            """Split a comma-separated cell into stripped names; missing -> []."""
            if pd.isna(drug_value): return []
            s_drug_value = str(drug_value).strip()
            if not s_drug_value or s_drug_value.lower() == 'nan': return []
            return [drug.strip() for drug in s_drug_value.split(',')]

        io_drug_labels = df_copy['io_drug'].apply(split_drugs)
        if any(io_drug_labels):
            mlb = MultiLabelBinarizer(sparse_output=False)
            encoded_io_drug = mlb.fit_transform(io_drug_labels)
            if mlb.classes_.size > 0:
                # Spaces become underscores; commas and parentheses are removed.
                name_cleanup = str.maketrans({' ': '_', ',': None, '(': None, ')': None})
                sanitized_classes = [cls.translate(name_cleanup) for cls in mlb.classes_]
                df_encoded_io_drug = pd.DataFrame(
                    encoded_io_drug,
                    columns=[f'io_drug_{cls}' for cls in sanitized_classes],
                    index=df_copy.index,
                )
                df_processed_list.append(df_encoded_io_drug)
            else:
                print("   Warning: 'io_drug' binarization no features.")
        else:
            print("   Warning: 'io_drug' column no parsable drug names.")

    # One-hot encoding of the remaining categorical columns.
    if single_label_categorical_cols:
        df_other_cat_subset = df_copy[single_label_categorical_cols]
        encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
        encoded_cat = encoder.fit_transform(df_other_cat_subset)
        # `sparse_output=` above already requires a scikit-learn release that
        # provides get_feature_names_out, so no legacy fallback is needed.
        feature_names = encoder.get_feature_names_out(df_other_cat_subset.columns)
        df_encoded_cat = pd.DataFrame(encoded_cat, columns=feature_names, index=df_other_cat_subset.index)
        df_processed_list.append(df_encoded_cat)

    if not df_processed_list:
        return pd.DataFrame(index=patient_order_ref)

    # Align to the reference order; IDs without data become all-zero rows.
    df_final_processed = pd.concat(df_processed_list, axis=1)
    return df_final_processed.reindex(patient_order_ref).fillna(0)

def create_similarity_edges(df_features_for_sim, patient_idx_map, threshold, edge_type_name, data_obj):
    """Add patient-patient edges for pairs whose cosine similarity exceeds a threshold.

    Rows that are entirely NaN, or whose index label is not a key of
    ``patient_idx_map``, are ignored; remaining NaNs are treated as 0. The
    full pairwise matrix is scanned with self-pairs excluded, so a qualifying
    pair produces an edge in each direction.

    On success the function sets, on
    ``data_obj['patient', edge_type_name, 'patient']``:
      * ``edge_index``: long tensor of shape (2, num_edges) holding the values
        of ``patient_idx_map`` for source and destination;
      * ``edge_attr``: float32 tensor of shape (num_edges, 1) holding the
        similarity of each edge.

    Nothing is written when there are no features, fewer than two usable
    rows, no pair above the threshold, or when the computation raises (the
    error is printed together with its traceback).

    Args:
        df_features_for_sim: Feature table indexed by patient ID.
        patient_idx_map: Mapping from patient ID to integer node index.
        threshold: Strict lower bound on cosine similarity for an edge.
        edge_type_name: Relation name used for the edge type and in messages.
        data_obj: Container indexed by the (source, relation, destination)
            triple; this script passes its HeteroData object.

    Returns:
        None.
    """
    if df_features_for_sim.empty or df_features_for_sim.shape[1] == 0:
        print(f"   Skipping {edge_type_name} similarity: no features.")
        return

    # Select rows with a boolean mask rather than .loc[list_of_ids]: with
    # duplicate IDs in the index, label-based selection would return each
    # duplicated row several times.
    has_any_value = df_features_for_sim.notna().any(axis=1).to_numpy()
    is_known_patient = np.array([pid in patient_idx_map for pid in df_features_for_sim.index], dtype=bool)
    keep_rows = has_any_value & is_known_patient

    if not keep_rows.any():
        print(f"   Skipping {edge_type_name} similarity: no relevant patients with non-NaN features after filtering.")
        return

    df_subset_features = df_features_for_sim[keep_rows].fillna(0)

    if df_subset_features.shape[0] <= 1: # Need at least 2 patients to form an edge
        print(f"   Skipping {edge_type_name} similarity: not enough samples ({df_subset_features.shape[0]}) to form edges.")
        return

    # Node index of every kept row, in row order, so that local matrix
    # positions can be translated with a single vectorized gather.
    row_to_global_idx = np.array([patient_idx_map[pid] for pid in df_subset_features.index], dtype=np.int64)

    print(f"   Calculating {edge_type_name} patient similarity (Cosine Sim > {threshold}) for {df_subset_features.shape[0]} patients with {df_subset_features.shape[1]} features...")
    try:
        pairwise_sim_matrix = cosine_similarity(df_subset_features.values) # Returns a NumPy array

        # Exclude self-similarity so no self-loop can pass the threshold.
        np.fill_diagonal(pairwise_sim_matrix, -np.inf)

        src_nodes_local_indices, dst_nodes_local_indices = np.where(pairwise_sim_matrix > threshold)

        if src_nodes_local_indices.size == 0:
            print(f"   No {edge_type_name} edges found above similarity threshold {threshold}.")
            return

        # Similarity score of each kept edge.
        edge_attr_scores = pairwise_sim_matrix[src_nodes_local_indices, dst_nodes_local_indices]

        filtered_edge_index_global = torch.from_numpy(
            np.stack([
                row_to_global_idx[src_nodes_local_indices],
                row_to_global_idx[dst_nodes_local_indices],
            ])
        )
        filtered_edge_attr_subset = torch.tensor(edge_attr_scores, dtype=torch.float32).unsqueeze(1)

        num_kept_edges = filtered_edge_index_global.shape[1]
        print(f"   Created {num_kept_edges} {edge_type_name} edges based on similarity > {threshold}.")

        data_obj['patient', edge_type_name, 'patient'].edge_index = filtered_edge_index_global
        data_obj['patient', edge_type_name, 'patient'].edge_attr = filtered_edge_attr_subset

    except Exception as e:
        # Best effort: a failing view must not abort the rest of the pipeline.
        print(f"   Error calculating similarity for {edge_type_name}: {e}")
        import traceback
        traceback.print_exc()


# --- Global Patient Index Mapping ---
# A patient's position in `patient_order` is used as that patient's node index.
num_total_patients = len(patient_order)
patient_to_global_idx = dict(zip(patient_order, range(num_total_patients)))

# --- Initialize HeteroData ---
# Start from an empty heterogeneous graph and declare the patient node count up front.
data = geo_data.HeteroData()
data['patient'].num_nodes = num_total_patients

# --- 2. Clinical View ---
print("\n--- 2. Preparing Clinical View ---")
# Restrict the configured feature lists to the columns the clinical table actually has.
available_patient_core_features = [
    feature for feature in patient_core_features if feature in df_clinical.columns
]
current_numerical_cols_patient = [
    feature for feature in numerical_cols_patient if feature in available_patient_core_features
]
current_categorical_cols_patient = [
    feature for feature in categorical_cols_patient if feature in available_patient_core_features
]

# Impute / scale / encode, with rows aligned to `patient_order`.
df_patient_clinical_processed = process_features_dataframe(
    df_clinical[available_patient_core_features].copy(),
    current_numerical_cols_patient,
    current_categorical_cols_patient,
    patient_order,
)

if df_patient_clinical_processed.empty or df_patient_clinical_processed.shape[1] == 0:
    # Nothing usable: store a zero-width matrix with one row per patient.
    print("   No clinical features. Clinical view features will be empty.")
    data['patient'].x_clinical = torch.empty((num_total_patients, 0), dtype=torch.float32)
else:
    data['patient'].x_clinical = torch.tensor(
        df_patient_clinical_processed.values, dtype=torch.float32
    )
    print(f"   Assigned 'patient.x_clinical' features: {data['patient'].x_clinical.shape}")
    # Connect patients whose processed clinical vectors are similar.
    create_similarity_edges(
        df_patient_clinical_processed,
        patient_to_global_idx,
        SIMILARITY_THRESHOLD,
        'similar_to_clinical',
        data,
    )


# --- 3. Pathology View ---
print("\n--- 3. Preparing Pathology View ---")
# Drop the secondary ID column if present and keep only rows whose index is in `patient_order`.
df_path_glcm_subset = (
    df_path_glcm_raw
    .drop(columns=['dmp_pt_id'], errors='ignore')
    .loc[lambda frame: frame.index.isin(patient_order)]
)

if df_path_glcm_subset.empty:
    print("   No Pathology GLCM features.")
    data['patient'].x_pathology = torch.zeros((num_total_patients, 0), dtype=torch.float32)
    data['patient'].pathology_mask = torch.zeros(num_total_patients, dtype=torch.bool)
else:
    # Every remaining column is handled as a numerical feature.
    glcm_numerical_cols = list(df_path_glcm_subset.columns)

    # Node features: aligned to all patients; rows without pathology data become zeros.
    df_path_glcm_processed_full = process_features_dataframe(
        df_path_glcm_subset.copy(),
        glcm_numerical_cols,
        [],
        patient_order,
    )
    data['patient'].x_pathology = torch.tensor(df_path_glcm_processed_full.values, dtype=torch.float32)

    # Boolean mask marking which patient nodes have a pathology row.
    pathology_node_ids = [
        patient_to_global_idx[orig_id]
        for orig_id in df_path_glcm_subset.index
        if orig_id in patient_to_global_idx
    ]
    pathology_present_mask = torch.zeros(num_total_patients, dtype=torch.bool)
    pathology_present_mask[torch.tensor(pathology_node_ids, dtype=torch.long)] = True
    data['patient'].pathology_mask = pathology_present_mask
    print(f"   Assigned 'patient.x_pathology' (zero-padded for all patients): {data['patient'].x_pathology.shape}")
    print(f"   Number of patients with actual pathology data: {pathology_present_mask.sum().item()}")

    # Similarity edges are computed only among patients that have pathology rows,
    # so zero-padded rows do not take part in the comparison.
    df_path_glcm_scaled_for_sim = process_features_dataframe(
        df_path_glcm_subset.copy(),
        glcm_numerical_cols,
        [],
        df_path_glcm_subset.index.tolist(),
    )
    create_similarity_edges(
        df_path_glcm_scaled_for_sim,
        patient_to_global_idx,
        SIMILARITY_THRESHOLD,
        'similar_to_pathology',
        data,
    )


# --- 4. Radiology View (Lesion-Level) ---
# Builds 'lesion' nodes (one per radiology row), ('patient', 'has_lesion',
# 'lesion') edges, a per-patient radiology mask, and patient-patient
# similarity edges from per-patient aggregates of the lesion features.
print("\n--- 4. Preparing Radiology View (Lesion-Level) ---")
exclude_rad_cols = ['radiology_accession_number', 'job_tag', 'dmp_pt_id']
# Every column not explicitly excluded is treated as a radiomics feature.
all_radiomics_feature_cols = [col for col in df_radiology_lesions_raw.columns if col not in exclude_rad_cols]

df_radiology_lesions_filtered = df_radiology_lesions_raw[df_radiology_lesions_raw.index.isin(patient_order)].copy()

if not df_radiology_lesions_filtered.empty and all_radiomics_feature_cols:
    # Coerce every feature to numeric; unparsable entries become NaN.
    df_lesion_features_to_scale = df_radiology_lesions_filtered[all_radiomics_feature_cols].apply(
        pd.to_numeric, errors='coerce'
    )

    # Median-impute per column. The frame is reassigned instead of calling
    # fillna(inplace=True) on a column selection, which does not update the
    # frame when pandas Copy-on-Write is active.
    df_lesion_features_to_scale = df_lesion_features_to_scale.fillna(df_lesion_features_to_scale.median())

    # A column without a single numeric value has a NaN median and would carry
    # NaN through scaling into the lesion tensor; fall back to 0 for those.
    cols_without_numeric_values = df_lesion_features_to_scale.columns[
        df_lesion_features_to_scale.isnull().any()
    ].tolist()
    if cols_without_numeric_values:
        print(f"   Warning: {len(cols_without_numeric_values)} radiomics column(s) contain no numeric values and are filled with 0: {cols_without_numeric_values}")
        df_lesion_features_to_scale = df_lesion_features_to_scale.fillna(0.0)

    scaler_lesions = RobustScaler()
    scaled_lesion_features_values = scaler_lesions.fit_transform(df_lesion_features_to_scale)
    df_radiology_scaled_lesions = pd.DataFrame(
        scaled_lesion_features_values,
        columns=all_radiomics_feature_cols,
        index=df_lesion_features_to_scale.index
    )
    print(f"   Scaled {df_radiology_scaled_lesions.shape[1]} features for {df_radiology_scaled_lesions.shape[0]} lesions.")

    all_lesion_x_list = []
    patient_lesion_edge_src = []
    patient_lesion_edge_dst = []
    current_lesion_global_idx = 0
    radiology_present_mask_np = np.zeros(num_total_patients, dtype=bool)
    temp_aggregated_radiology_features = {}

    # Split the lesion table by patient once, instead of a label lookup on a
    # non-unique index for every patient. Row order inside each group is kept.
    lesions_by_patient = {
        pid: group for pid, group in df_radiology_scaled_lesions.groupby(level=0, sort=False)
    }

    # Walk patients in canonical order so that lesion node indices are
    # assigned contiguously per patient.
    for patient_orig_id in tqdm(patient_order, desc="Processing lesions per patient"):
        lesions_for_patient_df = lesions_by_patient.get(patient_orig_id)
        if lesions_for_patient_df is None:
            continue

        patient_global_idx = patient_to_global_idx[patient_orig_id]
        num_lesions_this_patient = lesions_for_patient_df.shape[0]

        radiology_present_mask_np[patient_global_idx] = True
        all_lesion_x_list.append(torch.tensor(lesions_for_patient_df.values, dtype=torch.float32))

        # One patient -> lesion edge per lesion row.
        patient_lesion_edge_src.extend([patient_global_idx] * num_lesions_this_patient)
        patient_lesion_edge_dst.extend(range(current_lesion_global_idx, current_lesion_global_idx + num_lesions_this_patient))
        current_lesion_global_idx += num_lesions_this_patient

        # Per-patient summary statistics used only for the similarity graph.
        patient_lesion_mean = lesions_for_patient_df.mean(axis=0)
        if num_lesions_this_patient > 1:
            patient_lesion_std = lesions_for_patient_df.std(axis=0)
        else:
            # The sample standard deviation is undefined for one lesion; use 0.
            patient_lesion_std = pd.Series(0.0, index=patient_lesion_mean.index, dtype=float)
        patient_lesion_min = lesions_for_patient_df.min(axis=0)
        patient_lesion_max = lesions_for_patient_df.max(axis=0)

        # Layout: [all means | all stds | all mins | all maxs]
        temp_aggregated_radiology_features[patient_orig_id] = pd.concat([
            patient_lesion_mean,
            patient_lesion_std,
            patient_lesion_min,
            patient_lesion_max
        ]).values

    if all_lesion_x_list:
        data['lesion'].x = torch.cat(all_lesion_x_list, dim=0)
        data['lesion'].num_nodes = data['lesion'].x.shape[0]
        edge_index_patient_lesion = torch.tensor([patient_lesion_edge_src, patient_lesion_edge_dst], dtype=torch.long)
        data['patient', 'has_lesion', 'lesion'].edge_index = edge_index_patient_lesion
        print(f"   Assigned 'lesion.x' features: {data['lesion'].x.shape}")
        print(f"   Created ('patient', 'has_lesion', 'lesion') edges: {edge_index_patient_lesion.shape}")
    else:
        print("   No lesions found for any patient in the order.")
        data['lesion'].x = torch.empty((0, len(all_radiomics_feature_cols)), dtype=torch.float32)
        data['lesion'].num_nodes = 0
        data['patient', 'has_lesion', 'lesion'].edge_index = torch.empty((2,0), dtype=torch.long)

    data['patient'].radiology_mask = torch.from_numpy(radiology_present_mask_np)
    print(f"   Number of patients with actual radiology (lesion) data: {data['patient'].radiology_mask.sum().item()}")

    if temp_aggregated_radiology_features:
        # Column names follow the concatenation layout used in the loop above.
        new_radiology_agg_cols = [
            f"{col}_{stat_suffix}"
            for stat_suffix in ['mean', 'std', 'min', 'max']
            for col in all_radiomics_feature_cols
        ]

        df_temp_aggregated_radiology = pd.DataFrame.from_dict(
            temp_aggregated_radiology_features,
            orient='index',
            columns=new_radiology_agg_cols
        )

        if not df_temp_aggregated_radiology.empty and df_temp_aggregated_radiology.shape[1] > 0:
            print(f"   Aggregated radiology features for similarity (mean, std, min, max) shape: {df_temp_aggregated_radiology.shape}")
            create_similarity_edges(df_temp_aggregated_radiology, patient_to_global_idx, SIMILARITY_THRESHOLD, 'similar_to_radiology', data)
        else:
            print("   Aggregated radiology DataFrame (mean, std, min, max) is empty or has no columns. Skipping similarity graph.")
    else:
        print("   No temporary aggregated radiology features to build similarity graph.")

else:
    # No usable radiology input: register empty lesion nodes/edges and an all-False mask.
    print("   Radiology lesion data is empty or no radiomics feature columns identified.")
    data['lesion'].x = torch.empty((0, 0), dtype=torch.float32)
    data['lesion'].num_nodes = 0
    data['patient', 'has_lesion', 'lesion'].edge_index = torch.empty((2,0), dtype=torch.long)
    data['patient'].radiology_mask = torch.zeros(num_total_patients, dtype=torch.bool)


# --- 5. Add Patient Labels (Common to all patients) ---
# Copies label columns from the clinical table onto the patient nodes.
# Keys are clinical column names, values the attribute names on data['patient'].
print("\n--- 5. Adding Patient Labels ---")
label_cols_to_add = {'pfs': 'y', 'pfs_censor': 'event', 'label': 'binary_label'}
df_labels_ordered = df_clinical.reindex(patient_order)
for col_name, data_attr_name in label_cols_to_add.items():
    if col_name not in df_labels_ordered.columns:
        print(f"   Warning: Label column '{col_name}' not found.")
        continue

    label_series = df_labels_ordered[col_name]
    num_missing = int(label_series.isnull().sum())
    if num_missing == len(label_series):
        print(f"   Warning: Label column '{col_name}' is all NaN. Skipping.")
        continue

    try:
        if data_attr_name == 'y':
            # Stored as float32; missing values stay NaN.
            tensor_data = torch.tensor(label_series.values, dtype=torch.float32)
        else:
            # Integer labels cannot hold NaN, so missing entries are stored as 0.
            # Report how many, since they are indistinguishable from real zeros afterwards.
            if num_missing:
                print(f"   Warning: {num_missing} missing value(s) in '{col_name}' are stored as 0 in '{data_attr_name}'.")
            filled_series = label_series.fillna(0)
            tensor_data = torch.tensor(filled_series.values, dtype=torch.long)

        data['patient'][data_attr_name] = tensor_data
        print(f"   Added label '{data_attr_name}' from column '{col_name}'. Shape: {tensor_data.shape}")
    except Exception as e:
        # Best effort: a label that cannot be converted is reported and skipped.
        print(f"   Warning: Could not process or add label '{data_attr_name}': {e}")


# --- Final Checks and Save ---
print(f"\n--- Multi-View HeteroData Summary ---")
print(data)

# Validation is informational only: a failure is reported and the script carries on.
try:
    data.validate(raise_on_error=True)
except Exception as e:
    print(f"Multi-View Data validation FAILED: {e}")
else:
    print("Multi-View Data validation successful!")

# Per node type, list every stored attribute (tensor shapes, or the raw value otherwise).
print("\nNode Types and Features:")
for node_type in data.node_types:
    print(f"  Node type: {node_type}")
    for key, value in data[node_type].items():
        if not isinstance(value, torch.Tensor):
            print(f"    {key}: {value}")
            continue
        leading_dim = value.shape[0] if value.dim() > 0 else 'N/A'
        print(f"    {key}: shape {value.shape}, num_nodes (implicit from dim 0): {leading_dim}")


# Per edge type (sorted by its string form), report index shape and, if present, attribute shape.
print("\nEdge Types and Indices:")
if not hasattr(data, 'edge_types'):
    print("  No edge types defined in data object.")
else:
    for edge_type in sorted(data.edge_types, key=str):
        edge_index = data[edge_type].get('edge_index', None)
        if edge_index is None:
            shape_str = 'N/A'
            num_edges_str = '0'
        else:
            shape_str = str(edge_index.shape)
            num_edges_str = str(edge_index.shape[1]) if edge_index.dim() == 2 else '0'

        edge_attr = data[edge_type].get('edge_attr', None)
        attr_suffix = "" if edge_attr is None else f", edge_attr shape: {edge_attr.shape}"
        print(f"  {str(edge_type)}: edge_index shape: {shape_str} ({num_edges_str} edges){attr_suffix}")


# Persist the assembled graph; a failure is reported rather than raised.
try:
    torch.save(data, output_multi_view_data_path)
except Exception as e:
    print(f"\nError saving Multi-View HeteroData object: {e}")
else:
    print(f"\nSaved Multi-View HeteroData object to {output_multi_view_data_path}")