In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import pickle
import glob
import os
from sklearn.preprocessing import StandardScaler

In [2]:
class MultiModalEncoder(nn.Module):
    """Two-branch encoder that fuses a multi-channel signal with tabular metadata.

    * ``cnn_branch``: 1D-CNN over an 18-channel signal, pooled down to 64 features.
    * ``mlp_branch``: small MLP over 3 metadata values, producing 16 features.

    ``forward`` concatenates both outputs into a single 80-dimensional vector
    per sample.
    """

    def __init__(self):
        super().__init__()

        # Signal branch: conv -> pool -> conv -> global average pool.
        # The layer order (and therefore each Sequential index) is kept stable
        # so that previously saved state_dict keys still match.
        signal_layers = [
            nn.Conv1d(18, 32, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.MaxPool1d(4),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
        ]
        self.cnn_branch = nn.Sequential(*signal_layers)

        # Metadata branch: two fully connected layers with ReLU activations.
        meta_layers = [
            nn.Linear(3, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
        ]
        self.mlp_branch = nn.Sequential(*meta_layers)

    def forward(self, signal, meta):
        """Encode one batch.

        Args:
            signal: tensor of shape (batch, 18, length).
            meta: tensor of shape (batch, 3).

        Returns:
            Tensor of shape (batch, 80): 64 signal features followed by
            16 metadata features.
        """
        pooled = self.cnn_branch(signal)              # (batch, 64, 1)
        signal_features = torch.squeeze(pooled, dim=-1)  # (batch, 64)
        meta_features = self.mlp_branch(meta)         # (batch, 16)

        # Fuse both modalities into a single feature vector per sample.
        return torch.cat([signal_features, meta_features], dim=1)

In [3]:
# Every signal is cropped or zero-padded to this many samples before encoding.
MAX_LEN = 2000

# Metadata columns, in the order they are handed to the scaler and the encoder.
meta_cols = [
    'Sex',
    'Age',
    'Waist_Circum_mean',
]

# Folder with the raw sensor CSV recordings (note the trailing separator).
SENSOR_DIR = "C:\\DumbStuff\\epf study\\Meta-Elasto\\els\\meta\\Elastography_rawdata\\oldcode\\"

In [10]:
def predict_single_patient(
    patient_csv_path,
    patient_meta_dict,
    kmeans_model_path="C:\\DumbStuff\\epf study\\Meta-Elasto\\els\\80\\kmeans_model.pkl",
    encoder_path='C:\\DumbStuff\\epf study\\Meta-Elasto\\els\\80\\multi_modal_encoder.pth',
    scaler_path='C:\\DumbStuff\\epf study\\Meta-Elasto\\els\\80\\meta_scaler.pkl',
):
    """Predict the risk label for a single patient.

    The patient's signal CSV and metadata are encoded into an 80-dimensional
    latent vector, the vector is assigned to a KMeans cluster, and the cluster
    is mapped to a risk label.

    Relies on the module-level ``MultiModalEncoder``, ``MAX_LEN`` and
    ``meta_cols``.

    Args:
        patient_csv_path: path to the patient's signal CSV. Columns whose
            names start with "F" followed by digits are used as channels.
        patient_meta_dict: mapping with one value per entry of ``meta_cols``,
            e.g. ``{'Sex': 0, 'Age': 45, 'Waist_Circum_mean': 85}``.
        kmeans_model_path: path to the pickled KMeans model. If falsy, no
            clustering is performed (see the note in step 5).
        encoder_path: path to the saved encoder ``state_dict``.
        scaler_path: path to the pickled metadata scaler.

    Returns:
        str: the high-risk label when the predicted cluster is the one with
        the largest entry in the hard-coded cluster means, otherwise the
        low-risk label.

    Raises:
        KeyError: if ``patient_meta_dict`` lacks one of ``meta_cols``.
        OSError: if any of the model/scaler/CSV files cannot be opened.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # 1. Load encoder + scaler
    encoder = MultiModalEncoder()
    encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    encoder.to(device)
    encoder.eval()

    # SECURITY: pickle can execute arbitrary code while loading. Only point
    # scaler_path / kmeans_model_path at artifacts you produced yourself.
    with open(scaler_path, 'rb') as f:
        scaler = pickle.load(f)

    # 2. Load and preprocess signal
    df_sig = pd.read_csv(patient_csv_path)
    # Raw string: a plain '\d' is an invalid escape sequence and triggers a warning.
    signal = df_sig.filter(regex=r'^F\d+').fillna(0).values.T  # (channels, T)

    # Crop to MAX_LEN, then zero-pad on the right if the recording is shorter.
    signal = signal[:, :MAX_LEN]
    missing = MAX_LEN - signal.shape[1]
    if missing > 0:
        signal = np.pad(signal, ((0, 0), (0, missing)), 'constant')

    # Normalize with the global mean/std over all channels; epsilon avoids
    # division by zero for a constant signal.
    signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-8)
    signal_tensor = torch.tensor(signal, dtype=torch.float32).unsqueeze(0)  # (1, channels, MAX_LEN)

    # 3. Preprocess metadata
    meta_df = pd.DataFrame([patient_meta_dict])
    meta_array = meta_df[meta_cols].values.astype(np.float32)
    meta_scaled = scaler.transform(meta_array)
    meta_tensor = torch.tensor(meta_scaled, dtype=torch.float32)  # (1, len(meta_cols))

    # 4. Get latent vector
    with torch.no_grad():
        latent_vector = encoder(signal_tensor.to(device), meta_tensor.to(device))
        latent_vector = latent_vector.cpu().numpy().flatten()  # (80,)

    # 5. Predict cluster
    if kmeans_model_path:
        # Context manager so the file handle is closed even if unpickling fails.
        with open(kmeans_model_path, 'rb') as f:
            kmeans = pickle.load(f)
        cluster_id = kmeans.predict([latent_vector])[0]
        print(f"ðŸŽ¯ Cluster ID: {cluster_id}")
    else:
        # NOTE(review): with no KMeans model the placeholder below never equals
        # a cluster id, so the function reports the low-risk label even though
        # nothing was predicted. Kept for backward compatibility — confirm
        # whether callers rely on it.
        cluster_id = "KMeans not loaded"

    # 6. Risk prediction
    # Hard-coded per-cluster means; the cluster with the largest value is
    # treated as the "sick" one. TODO: persist this mapping with the model
    # instead of hard-coding it here.
    cluster_means = {0: 1.64, 1: 1.50, 2: 1.70, 3: 1.53, 4: 1.87}
    sick_cluster = max(cluster_means, key=cluster_means.get)
    risk_pred = 1 if cluster_id == sick_cluster else 0
    # NOTE(review): the glyphs in these labels (and in the print above) look
    # mis-encoded; left byte-for-byte because callers may compare against them.
    risk_label = "HIGH RISK ðŸš¨" if risk_pred == 1 else "LOW RISK âœ…"

    return risk_label



  signal = df_sig.filter(regex='^F\d+').fillna(0).values.T  # (18, T)


In [11]:
# Patient metadata (keys must cover every entry of meta_cols)
patient_info = {
    'Sex': 1,  # 0=Male, 1=Female
    'Age': 52,
    'Waist_Circum_mean': 91.5,
}

# Path to patient's sensor CSV.
# Built from SENSOR_DIR rather than repeating the folder: the previous literal
# combined an r"" prefix with doubled backslashes, which put literal double
# separators into the path.
patient_csv = os.path.join(SENSOR_DIR, "in_test_0001_A303_02_24_08_44_22.csv")

# Prediction
result = predict_single_patient(patient_csv, patient_info)
print("Prediction result:", result)



ðŸŽ¯ Cluster ID: 0
Prediction result: LOW RISK âœ…
