<a href="https://colab.research.google.com/github/NXdevansh/healX-/blob/main/AIheal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install necessary Python libraries
# NOTE: the `!pip ...` lines are IPython/Colab shell escapes, not Python syntax;
# this cell only runs inside a notebook kernel. (torch_geometric and transformers
# are installed here but are not used by the cells visible in this notebook chunk.)
!pip install torch torchvision torchaudio
!pip install torch_geometric
!pip install transformers
!pip install networkx
!pip install numpy pandas scikit-learn
!pip install matplotlib seaborn plotly
!pip install tqdm

# Check if GPU is available
# Report the installed PyTorch version and whether a CUDA device is visible.
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Name of the first CUDA device (index 0).
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

# Set random seeds for reproducibility
import random
import numpy as np
import torch

def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's global RNGs with ``seed``.

    When CUDA is available, every CUDA device is seeded as well.
    A confirmation line is printed afterwards.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    print(f"Random seed set to {seed}")

set_seed()

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
import matplotlib.pyplot as plt

class SmartWatchDataGenerator:
    """Generate synthetic smartwatch sensor streams for one simulated user.

    All random values are drawn from NumPy's global RNG (``np.random``), so seeding
    it beforehand makes them repeatable. Timestamps are anchored to
    ``datetime.now()`` at construction time and are therefore not repeatable.
    """

    def __init__(self, n_days=7, sample_rate_minutes=1):
        """
        Generate synthetic smartwatch data streams

        Args:
            n_days: Number of days of data to simulate
            sample_rate_minutes: Data frequency in minutes
        """
        self.n_days = n_days
        self.sample_rate_minutes = sample_rate_minutes
        self.n_samples = int((n_days * 24 * 60) / sample_rate_minutes)
        self.start_time = datetime.now() - timedelta(days=n_days)

    def generate_timestamps(self):
        """Return ``n_samples`` datetimes, ``sample_rate_minutes`` apart, starting at ``start_time``."""
        return [
            self.start_time + timedelta(minutes=i * self.sample_rate_minutes)
            for i in range(self.n_samples)
        ]

    def _hours_of_day(self):
        """Return the hour of day (0-23) of every sample as an integer array."""
        return np.array([t.hour for t in self.generate_timestamps()])

    def _generate_heart_data(self, activity_levels):
        """Generate heart rate (bpm) plus one PPG and one ECG value per sample.

        Args:
            activity_levels: Array of ``n_samples`` activity values in [0, 1].

        Returns:
            ``(hr, ppg, ecg)``, each a 1-D array of length ``n_samples``.
        """
        # Base heart rate rises linearly with activity: 60 bpm at rest, 100 bpm at full effort.
        base_hr = 60 + 40 * activity_levels

        # Circadian rhythm: +10 bpm peak at 14:00, -10 bpm trough at 02:00 (lower at night).
        # The former `-10 * cos(...)` had the sign inverted and made HR highest at night.
        hours = self._hours_of_day()
        circadian = 10 * np.cos((hours - 14) * 2 * np.pi / 24)

        hr = base_hr + circadian + np.random.normal(0, 5, self.n_samples)

        # About 0.5% of samples get a +/-20 bpm jump to mimic irregular beats.
        arrhythmia_mask = np.random.random(self.n_samples) < 0.005
        hr[arrhythmia_mask] += np.random.choice([-20, 20], size=int(arrhythmia_mask.sum()))

        # Toy PPG/ECG: build one beat's waveform per sample and keep only its first point.
        # NOTE(review): at minute-level sampling points_per_cycle is tiny and index 0 is
        # always phase t=0, so ppg ~ 0.6 + noise and ecg ~ 0.016 + noise. Demonstration only.
        ppg = []
        ecg = []
        for h in hr:
            points_per_cycle = int(60 / (h * self.sample_rate_minutes)) + 1
            t = np.linspace(0, 2 * np.pi, points_per_cycle)

            ppg_cycle = np.sin(t) * 0.4 + 0.6 + np.random.normal(0, 0.05, len(t))
            ppg.append(ppg_cycle[0])

            p_wave = 0.2 * np.exp(-(t - 0.5) ** 2 / 0.1)
            qrs = (0.8 * np.exp(-(t - np.pi) ** 2 / 0.02)
                   - 0.3 * np.exp(-(t - (np.pi + 0.2)) ** 2 / 0.02))
            t_wave = 0.3 * np.exp(-(t - (np.pi + 1.0)) ** 2 / 0.1)
            ecg_cycle = p_wave + qrs + t_wave + np.random.normal(0, 0.05, len(t))
            ecg.append(ecg_cycle[0])

        return hr, np.array(ppg), np.array(ecg)

    def _generate_spo2_data(self, activity_levels, hr):
        """Generate SpO2 (%) per sample, clipped to [70, 100].

        ``hr`` is accepted for interface symmetry but is not used.
        """
        # Baseline centred on 97% (typical healthy range 95-99%).
        baseline = 97 + np.random.normal(0, 1, self.n_samples)

        # Intense activity (> 0.7) lowers SpO2 by up to ~1 point.
        activity_effect = -1 * (activity_levels > 0.7) * activity_levels

        # ~1% of night-time samples (hours 22-23 and 0-6) get a 5-14 point hypoxic dip.
        hours = self._hours_of_day()
        is_night = (hours >= 22) | (hours <= 6)

        hypoxic_mask = (np.random.random(self.n_samples) < 0.01) & is_night
        hypoxic_events = np.zeros(self.n_samples)
        hypoxic_events[hypoxic_mask] = -np.random.randint(5, 15, size=int(hypoxic_mask.sum()))

        spo2 = baseline + activity_effect + hypoxic_events

        # Cap values to realistic range
        return np.clip(spo2, 70, 100)

    def _generate_activity_data(self):
        """Generate activity levels in [0, 1], per-sample step counts and 3-axis acceleration.

        Returns:
            ``(activity_levels, steps, accel)`` where ``accel`` has shape ``(n_samples, 3)``.
        """
        hours = self._hours_of_day()

        # Piecewise daily routine: higher during the day, peaks morning and evening (exercise).
        activity_base = np.zeros(self.n_samples)
        activity_base[(hours >= 6) & (hours < 8)] = 0.7     # morning activity
        activity_base[(hours >= 8) & (hours < 18)] = 0.4    # daytime
        activity_base[(hours >= 18) & (hours < 20)] = 0.8   # evening exercise
        activity_base[(hours >= 20) & (hours < 23)] = 0.3   # wind down
        activity_base[(hours >= 23) | (hours < 6)] = 0.05   # night sleep

        activity_levels = activity_base + np.random.normal(0, 0.1, self.n_samples)
        activity_levels = np.clip(activity_levels, 0, 1)

        # Poisson step counts with mean up to ~120 per sample at full activity.
        steps_per_minute = activity_levels * 120
        steps = np.random.poisson(steps_per_minute)

        # Activity-scaled noise on each axis; z additionally carries -9.8 m/s^2 gravity.
        accel_x = activity_levels * np.random.normal(0, 1, self.n_samples)
        accel_y = activity_levels * np.random.normal(0, 1, self.n_samples)
        accel_z = activity_levels * np.random.normal(-1, 1, self.n_samples) - 9.8

        return activity_levels, steps, np.column_stack((accel_x, accel_y, accel_z))

    def _generate_sleep_data(self, activity_levels):
        """Assign a sleep stage to every sample (0=wake, 1=light, 2=deep, 3=REM).

        Samples with activity below 0.1 are treated as asleep and get a random
        stage with probabilities light 0.5 / deep 0.3 / REM 0.2; all others are wake.
        """
        is_likely_sleeping = activity_levels < 0.1

        sleep_stage = np.zeros(self.n_samples)
        sleep_stage[is_likely_sleeping] = np.random.choice(
            [1, 2, 3],
            size=int(is_likely_sleeping.sum()),
            p=[0.5, 0.3, 0.2]
        )

        return sleep_stage

    def _generate_stress_hrv(self, activity_levels, hr):
        """Generate HRV (clipped to [10, 100]) and a stress score in [0, 100].

        ``hr`` is accepted for interface symmetry but is not used.

        Returns:
            ``(stress, hrv)``.
        """
        # HRV is higher at rest and lower with activity.
        base_hrv = 50 - 30 * activity_levels

        # Circadian component peaking at 02:00, plus noise.
        hours = self._hours_of_day()
        circadian = 10 * np.cos((hours - 2) * 2 * np.pi / 24)

        hrv = base_hrv + circadian + np.random.normal(0, 5, self.n_samples)
        hrv = np.clip(hrv, 10, 100)

        # Stress is HRV mapped linearly from [10, 100] onto [100, 0].
        stress = 100 - (hrv - 10) / 0.9
        stress = np.clip(stress, 0, 100)

        return stress, hrv

    def _generate_user_meta(self):
        """Generate mock user metadata (age, gender, height, weight, BMI, region, conditions)."""
        age = np.random.randint(25, 65)
        gender = np.random.choice(['M', 'F'])
        height = np.random.normal(170, 10) if gender == 'M' else np.random.normal(165, 8)
        weight = np.random.normal(75, 10) if gender == 'M' else np.random.normal(65, 8)
        bmi = weight / ((height / 100) ** 2)

        region_choices = ['US-East', 'US-West', 'Europe', 'Asia']
        region = np.random.choice(region_choices)

        # Each pre-existing condition is drawn independently.
        conditions = []
        if np.random.random() < 0.1:
            conditions.append('hypertension')
        if np.random.random() < 0.08:
            conditions.append('diabetes')
        if np.random.random() < 0.05:
            conditions.append('asthma')

        return {
            'age': age,
            'gender': gender,
            'height': height,
            'weight': weight,
            'bmi': bmi,
            'region': region,
            'conditions': conditions
        }

    def generate_data(self):
        """Generate the full dataset.

        Returns:
            ``(df, user_meta)``: one DataFrame row per sample, plus the metadata dict.
        """
        # Activity first: the physiological signals are derived from it.
        activity_levels, steps, accel = self._generate_activity_data()

        hr, ppg, ecg = self._generate_heart_data(activity_levels)
        spo2 = self._generate_spo2_data(activity_levels, hr)
        sleep_stage = self._generate_sleep_data(activity_levels)
        stress, hrv = self._generate_stress_hrv(activity_levels, hr)

        timestamps = self.generate_timestamps()

        df = pd.DataFrame({
            'timestamp': timestamps,
            'hr': hr,
            'ppg': ppg,  # one scalar per sample
            'ecg': ecg,  # one scalar per sample
            'spo2': spo2,
            'activity_level': activity_levels,
            'steps': steps,
            'accel_x': accel[:, 0],
            'accel_y': accel[:, 1],
            'accel_z': accel[:, 2],
            'sleep_stage': sleep_stage,
            'stress': stress,
            'hrv': hrv
        })

        user_meta = self._generate_user_meta()

        return df, user_meta

    def plot_sample_data(self, df, figsize=(15, 20)):
        """Plot the first simulated day of ``df`` on five stacked panels and return the figure."""
        fig, axes = plt.subplots(5, 1, figsize=figsize)

        # One day's worth of rows at this generator's sampling rate
        # (1440 at 1-minute sampling; previously hard-coded to 1440).
        samples_per_day = max(1, int(round(24 * 60 / self.sample_rate_minutes)))
        sample = df.iloc[:samples_per_day]

        # HR and SpO2
        ax1 = axes[0]
        ax1.plot(sample['timestamp'], sample['hr'], 'r-', label='Heart Rate')
        ax1.set_ylabel('Heart Rate (bpm)')
        ax1.set_title('Heart Rate over Time')
        ax1.legend(loc='upper left')

        ax1b = ax1.twinx()
        ax1b.plot(sample['timestamp'], sample['spo2'], 'b-', alpha=0.7, label='SpO2')
        ax1b.set_ylabel('SpO2 (%)')
        ax1b.legend(loc='upper right')

        # Activity and Steps
        ax2 = axes[1]
        ax2.plot(sample['timestamp'], sample['activity_level'], 'g-', label='Activity Level')
        ax2.set_ylabel('Activity Level (0-1)')
        ax2.set_title('Activity Level and Steps')
        ax2.legend(loc='upper left')

        ax2b = ax2.twinx()
        ax2b.plot(sample['timestamp'], sample['steps'], 'k-', alpha=0.3, label='Steps')
        ax2b.set_ylabel('Steps per minute')
        ax2b.legend(loc='upper right')

        # Sleep Stage
        ax3 = axes[2]
        ax3.plot(sample['timestamp'], sample['sleep_stage'], 'b-', drawstyle='steps-post')
        ax3.set_ylabel('Sleep Stage')
        ax3.set_yticks([0, 1, 2, 3])
        ax3.set_yticklabels(['Wake', 'Light', 'Deep', 'REM'])
        ax3.set_title('Sleep Stages')

        # Stress and HRV
        ax4 = axes[3]
        ax4.plot(sample['timestamp'], sample['stress'], 'r-', label='Stress')
        ax4.set_ylabel('Stress (0-100)')
        ax4.set_title('Stress and HRV')
        ax4.legend(loc='upper left')

        ax4b = ax4.twinx()
        ax4b.plot(sample['timestamp'], sample['hrv'], 'g-', alpha=0.7, label='HRV')
        ax4b.set_ylabel('HRV (ms)')
        ax4b.legend(loc='upper right')

        # Accelerometer
        ax5 = axes[4]
        ax5.plot(sample['timestamp'], sample['accel_x'], 'r-', alpha=0.5, label='X')
        ax5.plot(sample['timestamp'], sample['accel_y'], 'g-', alpha=0.5, label='Y')
        ax5.plot(sample['timestamp'], sample['accel_z'], 'b-', alpha=0.5, label='Z')
        ax5.set_ylabel('Acceleration (m/s²)')
        ax5.set_title('Accelerometer Data')
        ax5.legend()

        plt.tight_layout()
        return fig

# Generate and save data
import json

# Two weeks of minute-level samples for one synthetic user.
generator = SmartWatchDataGenerator(n_days=14)
data, user_meta = generator.generate_data()

print(f"Generated {len(data)} data points")
print(f"User metadata: {user_meta}")

# Visual sanity check of the first simulated day.
fig = generator.plot_sample_data(data)
plt.show()

# Persist samples and metadata for the downstream cells.
data.to_pickle('smartwatch_data.pkl')
with open('user_meta.json', 'w') as f:
    json.dump(user_meta, f)

print("Data saved to 'smartwatch_data.pkl' and 'user_meta.json'")

# Display first few rows
data.head()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from tqdm.notebook import tqdm

# Load the raw samples saved by the data-generation cell.
data = pd.read_pickle('smartwatch_data.pkl')
print("Loaded {} data points".format(len(data)))

class SmartWatchDataset(Dataset):
    """Sliding windows over smartwatch data for masked-reconstruction pretraining.

    Each item is one window of ``window_size`` consecutive rows. A random 10-20%
    of its standardized feature values are zeroed; the model sees the masked
    values plus an elapsed-time channel and must reconstruct the originals.
    """

    def __init__(self, df, window_size=60, stride=10):
        """
        Create windowed smartwatch dataset for self-supervised learning

        Args:
            df: DataFrame with a 'timestamp' column and every column in
                ``self.features``. It is copied, so the caller's frame is not
                standardized in place.
            window_size: Window length in rows (minutes at 1-minute sampling)
            stride: Step between consecutive window starts, in rows
        """
        # Copy first: the standardization below assigns columns and would
        # otherwise silently rewrite the caller's DataFrame.
        self.df = df.copy()
        self.window_size = window_size
        self.stride = stride

        self.window_indices = self._make_window_indices(len(self.df), window_size, stride)

        print(f"Created {len(self.window_indices)} windows from {len(df)} data points")

        # Features to use (excluding 'ppg' and 'ecg' which need special handling)
        self.features = ['hr', 'spo2', 'activity_level', 'steps',
                         'accel_x', 'accel_y', 'accel_z', 'sleep_stage',
                         'stress', 'hrv']

        # Standardize each feature independently; the fitted scalers are kept so
        # values can be mapped back later.
        self.scalers = {}
        for feature in self.features:
            self.scalers[feature] = StandardScaler()
            self.df[feature] = self.scalers[feature].fit_transform(self.df[[feature]])

    @staticmethod
    def _make_window_indices(n_rows, window_size, stride):
        """Return ``(start, end)`` row pairs for every complete window.

        The ``+ 1`` keeps the window that ends exactly at ``n_rows``; the former
        ``range(0, n_rows - window_size, stride)`` dropped it. Returns an empty
        list when ``n_rows < window_size``.
        """
        return [(start, start + window_size)
                for start in range(0, n_rows - window_size + 1, stride)]

    def __len__(self):
        return len(self.window_indices)

    def __getitem__(self, idx):
        """Return one randomly masked window.

        Returns:
            dict with
            - 'masked_data': float32 ``[window_size, n_features + 1]`` — masked
              features followed by hours elapsed since the window's first timestamp.
            - 'original_data': float32 ``[window_size, n_features]`` unmasked targets.
            - 'mask': bool ``[window_size, n_features]``, True where a value was zeroed.
            - 'start_idx' / 'end_idx': row range of the window in ``self.df``.
        """
        start_idx, end_idx = self.window_indices[idx]
        window_data = self.df.iloc[start_idx:end_idx]

        feature_data = torch.tensor(
            window_data[self.features].values,
            dtype=torch.float32
        )

        # Randomly mask 10-20% of all values. The rate comes from torch, not NumPy:
        # DataLoader worker processes start with identical NumPy RNG state, whereas
        # PyTorch seeds each worker's torch RNG differently.
        mask_rate = torch.empty(()).uniform_(0.1, 0.2).item()
        mask = torch.rand_like(feature_data) < mask_rate

        masked_data = feature_data.clone()
        masked_data[mask] = 0.0

        # Hours since the window start. Uses the pandas .dt accessor: iterating
        # `.values` yields numpy.timedelta64, which has no total_seconds() method.
        timestamps = pd.to_datetime(window_data['timestamp'])
        hours = (timestamps - timestamps.iloc[0]).dt.total_seconds().to_numpy() / 3600.0
        time_tensor = torch.tensor(hours, dtype=torch.float32).unsqueeze(1)

        # Time is appended after masking, so it is never masked itself.
        features_with_time = torch.cat([masked_data, time_tensor], dim=1)

        return {
            'masked_data': features_with_time,  # Input: masked data with time
            'original_data': feature_data,  # Target: original data to reconstruct
            'mask': mask,  # The mask that was applied
            'start_idx': start_idx,  # Starting index in original DataFrame
            'end_idx': end_idx  # Ending index in original DataFrame
        }

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first input.

    Even feature columns hold ``sin(pos * freq)``, odd columns ``cos(pos * freq)``.
    ``forward`` expects ``x`` shaped ``[seq_len, batch, d_model]`` (it slices the
    buffer by ``x.size(0)``) and requires ``seq_len <= max_len``.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column; slicing
        # div_term avoids the shape-mismatch error the unsliced assignment raised.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Shape [max_len, 1, d_model] so it broadcasts over the batch dimension.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(0)`` positions."""
        return x + self.pe[:x.size(0), :]

class SmartWatchTransformer(nn.Module):
    """Transformer encoder producing per-step embeddings plus a reconstruction of the inputs.

    Input is sequence-first, ``[seq_len, batch_size, input_dim]``, where the last
    input channel is the time offset. The reconstruction head predicts every
    channel except that last one (``input_dim - 1`` outputs).
    """

    def __init__(self, input_dim, hidden_dim=128, nhead=8,
                 num_layers=4, dropout=0.1, output_dim=200):
        super().__init__()

        self.input_projection = nn.Linear(input_dim, hidden_dim)
        self.pos_encoder = PositionalEncoding(hidden_dim)

        layer_template = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=4 * hidden_dim,
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer_template, num_layers=num_layers)

        # Project encoder states to the output_dim-sized latent embedding.
        self.embedding_projection = nn.Linear(hidden_dim, output_dim)

        # Predict the original (non-time) channels from the embedding.
        n_targets = input_dim - 1
        self.reconstruction_head = nn.Linear(output_dim, n_targets)

    def forward(self, src):
        """Return ``(embeddings, reconstructed)`` for ``src`` of shape [seq_len, batch, input_dim]."""
        hidden = self.transformer_encoder(self.pos_encoder(self.input_projection(src)))
        embeddings = self.embedding_projection(hidden)
        return embeddings, self.reconstruction_head(embeddings)

class TemporalContrastiveLoss(nn.Module):
    """Contrastive loss pulling adjacent time steps together and pushing distant ones apart.

    * Positives: the same sequence at steps ``i`` and ``i + 1``.
    * Negatives: for every (step, sequence) anchor, up to ``num_negatives`` rows
      sampled uniformly without replacement from all positions (any sequence in the
      batch) more than ``min_gap`` steps away.

    Loss = ``-mean(pos_sim) + logsumexp(neg_sim)``, with cosine similarities
    divided by ``temperature``.
    """

    def __init__(self, temperature=0.5, min_gap=10, num_negatives=10):
        super().__init__()
        self.temperature = temperature
        # A candidate is a negative when |step_j - step_i| > min_gap.
        self.min_gap = min_gap
        self.num_negatives = num_negatives
        self.similarity = nn.CosineSimilarity(dim=-1)

    def forward(self, embeddings):
        """
        Args:
            embeddings: tensor of shape [seq_len, batch_size, embed_dim].

        Returns:
            Scalar loss tensor. A term with no available pairs (positives when
            seq_len < 2, negatives when seq_len <= min_gap + 1) contributes 0
            instead of raising, as torch.stack on an empty list used to.

        Note:
            Negative sampling allocates a (seq_len*batch_size)^2 random matrix.
        """
        seq_len, batch_size, embed_dim = embeddings.shape
        device = embeddings.device
        loss = embeddings.new_zeros(())

        # Positive pairs: adjacent steps of the same sequence, all at once.
        if seq_len > 1:
            pos_sim = self.similarity(embeddings[:-1], embeddings[1:]) / self.temperature
            loss = loss - pos_sim.mean()

        # Rows are laid out as r = i * batch_size + b, matching [seq, batch] order.
        flat_embeddings = embeddings.reshape(-1, embed_dim)
        n_rows = flat_embeddings.size(0)
        steps = torch.arange(seq_len, device=device).repeat_interleave(batch_size)
        far = (steps.unsqueeze(1) - steps.unsqueeze(0)).abs() > self.min_gap

        k = min(self.num_negatives, n_rows)
        if k > 0 and bool(far.any()):
            # Uniform scores in [0, 1) on valid candidates and -1 elsewhere; taking the
            # top-k samples min(k, #valid) negatives per anchor without replacement.
            scores = torch.rand(n_rows, n_rows, device=device).masked_fill(~far, -1.0)
            top_scores, top_idx = scores.topk(k, dim=1)
            valid = top_scores >= 0
            anchor_idx = torch.arange(n_rows, device=device).unsqueeze(1).expand(-1, k)[valid]
            neg_idx = top_idx[valid]
            neg_sim = self.similarity(flat_embeddings[anchor_idx],
                                      flat_embeddings[neg_idx]) / self.temperature
            # Numerically stable log(sum(exp(neg_sim))).
            loss = loss + torch.logsumexp(neg_sim, dim=0)

        return loss

def create_model_and_dataloader(data, batch_size=32):
    """Build the windowed dataset, its DataLoader, the transformer and the contrastive loss.

    Args:
        data: DataFrame passed to ``SmartWatchDataset`` (60-row windows, stride 10).
        batch_size: DataLoader batch size; batches are shuffled and loaded by 2 workers.

    Returns:
        ``(model, dataloader, device, contrastive_loss)``. Model and loss are moved
        to CUDA when available, otherwise to CPU.
    """
    dataset = SmartWatchDataset(data, window_size=60, stride=10)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    # Model input = every dataset feature plus the appended time-offset channel.
    n_inputs = len(dataset.features) + 1

    net = SmartWatchTransformer(
        input_dim=n_inputs,
        hidden_dim=128,
        nhead=8,
        num_layers=4,
        dropout=0.1,
        output_dim=200,  # 200-dim embedding as specified
    )
    criterion = TemporalContrastiveLoss(temperature=0.5)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return net.to(device), loader, device, criterion.to(device)

model, dataloader, device, contrastive_loss = create_model_and_dataloader(data)

print("Model created with device: {}".format(device))
print("Dataloader created with {} batches".format(len(dataloader)))

In [None]:
import networkx as nx
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
from datetime import timedelta
import pickle

# Inputs for graph construction: the embeddings table plus the raw sample table.
embedding_df = pd.read_pickle('smartwatch_embeddings.pkl')
data = pd.read_pickle('smartwatch_data.pkl')
print("Loaded {} embeddings".format(len(embedding_df)))

class TemporalPhenotypeGraph:
    """Directed graph over embedding rows with 'sequential' and 'similarity' edges.

    Nodes are keyed by each ``embedding_df`` row's ``'index'`` value.
    * 'sequential' edges (weight 1.0) link each node to the next node in timestamp order.
    * 'similarity' edges link each node to up to ``k_neighbors`` most cosine-similar
      other nodes whose timestamps are at most ``max_time_diff`` hours apart.
    """

    def __init__(self, embedding_df, data, k_neighbors=5, max_time_diff=24):
        """
        Construct the temporal phenotype graph

        Args:
            embedding_df: DataFrame whose 'index', 'timestamp' and 'embedding'
                columns are read; 'index' is used as a positional row into ``data``.
            data: Original smartwatch data
            k_neighbors: Number of similarity neighbors considered per node
            max_time_diff: Max time difference (hours) for similarity edges
        """
        self.embedding_df = embedding_df
        self.data = data
        self.k_neighbors = k_neighbors
        self.max_time_diff = max_time_diff

        self.G = nx.DiGraph()

        self._add_nodes()
        self._add_sequential_edges()
        self._add_similarity_edges()

        print(f"Graph constructed with {len(self.G.nodes)} nodes and {len(self.G.edges)} edges")

    def _add_nodes(self):
        """Add one node per embedding row, with its embedding and raw signals as attributes."""
        print("Adding nodes...")
        for _, row in tqdm(self.embedding_df.iterrows(), total=len(self.embedding_df)):
            idx = row['index']

            # Raw sample at this (positional) index
            orig_data = self.data.iloc[idx]

            attrs = {
                'timestamp': row['timestamp'],
                'embedding': row['embedding'],
                'hr': orig_data['hr'],
                'spo2': orig_data['spo2'],
                'activity_level': orig_data['activity_level'],
                'sleep_stage': orig_data['sleep_stage'],
                'stress': orig_data['stress'],
                'hrv': orig_data['hrv'],
            }

            self.G.add_node(idx, **attrs)

    def _add_sequential_edges(self):
        """Link every node to its successor in timestamp order."""
        print("Adding sequential edges...")
        sorted_nodes = sorted(self.G.nodes(), key=lambda n: self.G.nodes[n]['timestamp'])

        for i in range(len(sorted_nodes) - 1):
            self.G.add_edge(
                sorted_nodes[i],
                sorted_nodes[i + 1],
                edge_type='sequential',
                weight=1.0
            )

    @staticmethod
    def _top_k_neighbors(similarity_row, self_pos, k):
        """Return positions of the ``k`` largest entries of ``similarity_row``, excluding ``self_pos``.

        Ties are broken by position (lower first). ``self_pos`` is removed
        explicitly rather than by skipping the first sorted entry: with duplicate
        embeddings another node ties the self-similarity and could be ranked
        ahead of it, leaving a self-match in the result.
        """
        order = np.argsort(-np.asarray(similarity_row), kind='stable')
        return order[order != self_pos][:k]

    def _add_similarity_edges(self):
        """Add edges from each node to its most cosine-similar nodes within the time window."""
        print("Adding similarity edges...")
        indices = list(self.G.nodes())
        embeddings = np.array([self.G.nodes[idx]['embedding'] for idx in indices])
        timestamps = np.array([self.G.nodes[idx]['timestamp'] for idx in indices])

        # All pairwise cosine similarities (n x n)
        similarities = cosine_similarity(embeddings)

        for i, idx in enumerate(tqdm(indices)):
            neighbor_similarities = similarities[i]
            for j in self._top_k_neighbors(neighbor_similarities, i, self.k_neighbors):
                neighbor_idx = indices[j]

                # Keep an existing sequential edge intact: DiGraph.add_edge on an
                # existing edge updates its attributes, which would relabel it.
                if self.G.has_edge(idx, neighbor_idx):
                    continue

                time_diff = abs((timestamps[i] - timestamps[j]).total_seconds() / 3600)
                if time_diff <= self.max_time_diff:
                    self.G.add_edge(
                        idx,
                        neighbor_idx,
                        edge_type='similarity',
                        weight=neighbor_similarities[j],
                        time_diff=time_diff
                    )

    def plot_graph_sample(self, n_nodes=100, figsize=(12, 12)):
        """Plot the first ``n_nodes`` nodes (insertion order): x = time, y = activity level."""
        sample_nodes = list(self.G.nodes())[:n_nodes]
        G_sample = self.G.subgraph(sample_nodes)

        # Position nodes by POSIX timestamp (x) and activity level (y)
        pos = {}
        for node in G_sample.nodes():
            timestamp = G_sample.nodes[node]['timestamp']
            activity = G_sample.nodes[node]['activity_level']
            pos[node] = (timestamp.timestamp(), activity)

        # Edge style by type; iterated in the same order draw_networkx draws edges
        edge_colors = []
        widths = []
        for _, _, edge_attrs in G_sample.edges(data=True):
            if edge_attrs['edge_type'] == 'sequential':
                edge_colors.append('blue')
                widths.append(1)
            else:  # similarity
                edge_colors.append('red')
                widths.append(0.5)

        # Node colour by sleep stage
        node_colors = []
        for node in G_sample.nodes():
            sleep = G_sample.nodes[node]['sleep_stage']
            if sleep == 0:
                node_colors.append('lightyellow')  # Awake
            elif sleep == 1:
                node_colors.append('lightblue')  # Light sleep
            elif sleep == 2:
                node_colors.append('blue')  # Deep sleep
            else:
                node_colors.append('purple')  # REM

        plt.figure(figsize=figsize)
        nx.draw_networkx(
            G_sample, pos=pos,
            with_labels=False,
            node_size=30,
            node_color=node_colors,
            edge_color=edge_colors,
            width=widths,
            alpha=0.7
        )

        plt.title('Temporal Phenotype Graph (Sample)')
        plt.xlabel('Time')
        plt.ylabel('Activity Level')

        # Legend proxies with empty data: plotting a real point at (0, 0) would
        # stretch the epoch-seconds x-axis down to 0 and squash the graph.
        plt.plot([], [], color='blue', label='Sequential Edge')
        plt.plot([], [], color='red', label='Similarity Edge')
        plt.plot([], [], marker='o', color='lightyellow', label='Awake', linestyle='')
        plt.plot([], [], marker='o', color='lightblue', label='Light Sleep', linestyle='')
        plt.plot([], [], marker='o', color='blue', label='Deep Sleep', linestyle='')
        plt.plot([], [], marker='o', color='purple', label='REM Sleep', linestyle='')

        plt.legend()
        plt.tight_layout()
        plt.show()

    def save_graph(self, filename='temporal_phenotype_graph.pkl'):
        """Pickle the graph to ``filename``."""
        with open(filename, 'wb') as f:
            pickle.dump(self.G, f)
        print(f"Graph saved to '{filename}'")

    @classmethod
    def load_graph(cls, filename='temporal_phenotype_graph.pkl'):
        """Load a pickled graph into a new instance without rebuilding it.

        Only ``G`` is restored; the construction inputs are set to None.
        NOTE: pickle.load can execute arbitrary code — only load trusted files.
        """
        with open(filename, 'rb') as f:
            G = pickle.load(f)

        instance = cls.__new__(cls)
        instance.G = G
        instance.embedding_df = None
        instance.data = None
        instance.k_neighbors = None
        instance.max_time_diff = None

        print(f"Loaded graph with {len(G.nodes)} nodes and {len(G.edges)} edges")
        return instance

# Build the graph (5 similarity neighbours per node, at most 24 h apart),
# preview its first 100 nodes, then persist it for the simulator cell.
graph_builder = TemporalPhenotypeGraph(
    embedding_df,
    data,
    k_neighbors=5,
    max_time_diff=24,
)
graph_builder.plot_graph_sample(n_nodes=100)
graph_builder.save_graph()

In [None]:
import networkx as nx
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import pickle
from datetime import timedelta
import random
from scipy.interpolate import interp1d

import json

# Load graph and data
# NOTE: pickle.load can execute arbitrary code; only load graph files this notebook wrote.
with open('temporal_phenotype_graph.pkl', 'rb') as f:
    G = pickle.load(f)
data = pd.read_pickle('smartwatch_data.pkl')
# user_meta.json was written with json.dump from a flat dict; read it back as a dict.
# pd.read_json would build a DataFrame (row count following the 'conditions' list),
# which breaks the scalar comparisons in PhysiologicalSimulator._adjust_baselines.
with open('user_meta.json', encoding='utf-8') as f:
    user_meta = json.load(f)

print(f"Loaded graph with {len(G.nodes)} nodes and {len(G.edges)} edges")

class PhysiologicalSimulator:
    """Rule-based physiological simulator used to augment the phenotype graph.

    Baseline heart rate (HR), blood-oxygen saturation (SpO2) and heart-rate
    variability (HRV) are taken as medians of the recorded data and nudged
    according to the user's metadata. The simulator can then synthesise short
    "virtual" episodes (exercise bouts, hypoxic events). These are spliced
    into a copy of the graph as extra nodes marked ``is_virtual=True``.
    """

    def __init__(self, graph, data, user_meta):
        """
        Simple physiological simulator for digital twin augmentation

        Args:
            graph: NetworkX graph whose nodes carry 'timestamp' and
                'embedding' attributes (and optionally 'sleep_stage')
            data: Original data; must provide 'hr', 'spo2' and 'hrv' columns
            user_meta: User metadata with optional 'age', 'bmi' and
                'conditions' entries. Either a mapping (dict / pandas Series)
                or a pandas DataFrame, e.g. the result of ``pd.read_json``
        """
        self.G = graph
        self.data = data
        self.user_meta = user_meta

        # Medians are robust to artefacts/outliers in the recordings
        self.baseline_hr = np.median(data['hr'])
        self.baseline_spo2 = np.median(data['spo2'])
        self.baseline_hrv = np.median(data['hrv'])

        # Adjust baselines based on user metadata
        self._adjust_baselines()

        print(f"Initialized physiological simulator with baselines:")
        print(f"  HR: {self.baseline_hr:.1f} bpm")
        print(f"  SpO2: {self.baseline_spo2:.1f} %")
        print(f"  HRV: {self.baseline_hrv:.1f} ms")

    # ------------------------------------------------------------------
    # Metadata handling
    # ------------------------------------------------------------------
    def _meta_value(self, key, default=None):
        """Return one scalar metadata entry, whatever container user_meta is.

        For a DataFrame (what ``pd.read_json`` returns), ``meta['age'] > 50``
        would be an ambiguous Series comparison. The first non-null value of
        the column is used instead.
        """
        meta = self.user_meta
        if meta is None:
            return default
        if isinstance(meta, pd.DataFrame):
            if key not in meta.columns:
                return default
            column = meta[key].dropna()
            return column.iloc[0] if len(column) else default
        getter = getattr(meta, 'get', None)  # dict or pandas Series
        return getter(key, default) if getter is not None else default

    def _meta_conditions(self):
        """Return the pre-existing conditions as a container supporting ``in``."""
        meta = self.user_meta
        if isinstance(meta, pd.DataFrame):
            if 'conditions' not in meta.columns:
                return []
            # A list in the JSON may be spread over several rows or stored
            # as list-valued cells; flatten both layouts
            conditions = []
            for value in meta['conditions'].dropna():
                if isinstance(value, (list, tuple, set)):
                    conditions.extend(value)
                else:
                    conditions.append(value)
            return conditions
        conditions = self._meta_value('conditions', [])
        # Strings keep substring-membership semantics; None/NaN mean "none"
        return conditions if hasattr(conditions, '__contains__') else []

    def _adjust_baselines(self):
        """Adjust baseline values based on user metadata

        A missing 'age' or 'bmi' entry simply skips the corresponding rule.
        """
        # Age effect: older people tend to have lower HRV
        age = self._meta_value('age')
        if age is not None and age > 50:
            self.baseline_hrv *= 0.9

        # BMI effect: obesity -> higher HR, slightly lower SpO2 baseline
        bmi = self._meta_value('bmi')
        if bmi is not None and bmi > 30:
            self.baseline_hr *= 1.05
            self.baseline_spo2 *= 0.98

        # Pre-existing conditions
        conditions = self._meta_conditions()
        if 'hypertension' in conditions:
            self.baseline_hr *= 1.1
        if 'asthma' in conditions:
            self.baseline_spo2 *= 0.97

    # ------------------------------------------------------------------
    # Shared simulation helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _response_curve(baseline, extreme, onset_frac, recovery_frac):
        """Quadratic interpolant over normalised time [0, 1] through 5 knots.

        The knots are baseline -> onset -> extreme (plateau) -> recovery ->
        baseline. Onset/recovery sit the given fraction of the way from
        baseline towards the extreme.
        """
        t_points = np.linspace(0, 1, 5)
        knots = [
            baseline,
            baseline + onset_frac * (extreme - baseline),
            extreme,
            baseline + recovery_frac * (extreme - baseline),
            baseline,
        ]
        return interp1d(t_points, knots, kind='quadratic')

    def _template_embedding(self, node_attrs=None, sample_size=10):
        """Mean embedding of up to *sample_size* randomly drawn nodes.

        Args:
            node_attrs: Node-attribute dicts to sample from. Falls back to
                all graph nodes when None or empty.
            sample_size: Maximum number of nodes averaged (capped at the
                number available, so small graphs do not crash).
        """
        if not node_attrs:
            node_attrs = [attrs for _, attrs in self.G.nodes(data=True)]
        if not node_attrs:
            raise ValueError("Cannot build a template embedding from an empty graph")
        chosen = random.sample(node_attrs, min(sample_size, len(node_attrs)))
        return np.mean([attrs['embedding'] for attrs in chosen], axis=0)

    # ------------------------------------------------------------------
    # Event simulators
    # ------------------------------------------------------------------
    def simulate_exercise(self, duration_minutes=30, intensity=0.8):
        """
        Simulate exercise response

        Args:
            duration_minutes: Duration of exercise in minutes (one node per minute)
            intensity: Exercise intensity (0-1)

        Returns:
            List of simulated (node_id, attrs) tuples. IDs are always -1, -2, ...;
            callers inserting several events must re-number them
            (augment_graph does).
        """
        print(f"Simulating exercise: {duration_minutes} minutes, intensity {intensity}")

        # Anchor the bout at a random awake node; fall back to any node when
        # the graph has no awake periods (mirrors simulate_hypoxia)
        awake_nodes = [n for n, attrs in self.G.nodes(data=True)
                       if attrs.get('sleep_stage', 0) == 0]
        if not awake_nodes:
            awake_nodes = list(self.G.nodes())
        start_node = random.choice(awake_nodes)
        start_time = self.G.nodes[start_node]['timestamp']

        timestamps = [start_time + timedelta(minutes=i) for i in range(duration_minutes)]

        # Peak deviations scale with intensity
        max_hr = self.baseline_hr * (1 + 0.7 * intensity)
        min_spo2 = self.baseline_spo2 * (1 - 0.03 * intensity)
        min_hrv = self.baseline_hrv * (1 - 0.5 * intensity)

        # HR rises quickly, plateaus and declines; SpO2 dips slightly;
        # HRV is suppressed during exercise
        hr_interp = self._response_curve(self.baseline_hr, max_hr, 0.7, 0.3)
        spo2_interp = self._response_curve(self.baseline_spo2, min_spo2, 0.5, 0.2)
        hrv_interp = self._response_curve(self.baseline_hrv, min_hrv, 0.7, 0.4)

        # Sample the curves once per minute and add measurement noise
        t_values = np.linspace(0, 1, duration_minutes)
        hr_values = hr_interp(t_values) + np.random.normal(0, 2, duration_minutes)
        spo2_values = spo2_interp(t_values) + np.random.normal(0, 0.5, duration_minutes)
        hrv_values = hrv_interp(t_values) + np.random.normal(0, 2, duration_minutes)

        # Ensure physiological ranges
        spo2_values = np.clip(spo2_values, 85, 100)
        hr_values = np.clip(hr_values, 40, 200)
        hrv_values = np.clip(hrv_values, 5, 100)

        template_embedding = self._template_embedding()

        simulated_nodes = []
        for i in range(duration_minutes):
            # Perturb the template so virtual nodes are not identical
            perturbed_embedding = template_embedding + 0.1 * np.random.randn(len(template_embedding))

            # Higher activity during exercise
            activity = 0.7 + 0.3 * intensity + 0.1 * np.random.random()

            attrs = {
                'timestamp': timestamps[i],
                'embedding': perturbed_embedding,
                'hr': hr_values[i],
                'spo2': spo2_values[i],
                'activity_level': activity,
                'sleep_stage': 0,  # Awake
                'stress': 50 + 20 * intensity + np.random.normal(0, 5),  # Higher stress during exercise
                'hrv': hrv_values[i],
                'is_virtual': True,  # Mark as virtual node
                'virtual_type': 'exercise'
            }
            # Negative IDs keep virtual nodes apart from (non-negative) real ones
            simulated_nodes.append((-1 - i, attrs))

        return simulated_nodes

    def simulate_hypoxia(self, duration_minutes=60, severity=0.5):
        """
        Simulate hypoxic events (e.g., sleep apnea episodes)

        Args:
            duration_minutes: Duration of event in minutes (one node per minute)
            severity: Severity of hypoxia (0-1)

        Returns:
            List of simulated (node_id, attrs) tuples. IDs are always
            -1000, -1001, ...; callers inserting several events must
            re-number them (augment_graph does).
        """
        print(f"Simulating hypoxia: {duration_minutes} minutes, severity {severity}")

        # Anchor the event at a random sleep node (any node if none sleep)
        sleep_nodes = [n for n, attrs in self.G.nodes(data=True)
                       if attrs.get('sleep_stage', 0) > 0]
        if not sleep_nodes:
            sleep_nodes = list(self.G.nodes())
        start_node = random.choice(sleep_nodes)
        start_time = self.G.nodes[start_node]['timestamp']

        timestamps = [start_time + timedelta(minutes=i) for i in range(duration_minutes)]

        # Plateau extremes: SpO2 falls by up to 15 percentage points at
        # severity 1.0 (4.5-12 points for the 0.3-0.8 range augment_graph
        # draws). HR rises by up to 20 bpm and HRV falls by up to 15 ms.
        min_spo2 = self.baseline_spo2 - 15 * severity
        max_hr = self.baseline_hr + 20 * severity
        min_hrv = self.baseline_hrv - 15 * severity

        spo2_interp = self._response_curve(self.baseline_spo2, min_spo2, 0.7, 0.3)
        hr_interp = self._response_curve(self.baseline_hr, max_hr, 0.5, 0.3)
        hrv_interp = self._response_curve(self.baseline_hrv, min_hrv, 0.6, 0.3)

        t_values = np.linspace(0, 1, duration_minutes)
        spo2_values = spo2_interp(t_values) + np.random.normal(0, 0.7, duration_minutes)

        # HR and HRV respond after the desaturation. Their curves are
        # evaluated at slightly earlier normalised times (5% / 3% of the event).
        t_hr = np.clip(t_values - 0.05, 0, 1)
        hr_values = hr_interp(t_hr) + np.random.normal(0, 3, duration_minutes)
        t_hrv = np.clip(t_values - 0.03, 0, 1)
        hrv_values = hrv_interp(t_hrv) + np.random.normal(0, 2, duration_minutes)

        # Ensure physiological ranges
        spo2_values = np.clip(spo2_values, 70, 100)
        hr_values = np.clip(hr_values, 40, 200)
        hrv_values = np.clip(hrv_values, 5, 100)

        # Template from sleeping nodes; falls back to all nodes if none
        sleep_attrs = [attrs for _, attrs in self.G.nodes(data=True)
                       if attrs.get('sleep_stage', 0) > 0]
        template_embedding = self._template_embedding(sleep_attrs)

        simulated_nodes = []
        for i in range(duration_minutes):
            perturbed_embedding = template_embedding + 0.1 * np.random.randn(len(template_embedding))

            # Sleep stage - mostly deep or REM
            sleep_stage = np.random.choice([2, 3], p=[0.3, 0.7])

            attrs = {
                'timestamp': timestamps[i],
                'embedding': perturbed_embedding,
                'hr': hr_values[i],
                'spo2': spo2_values[i],
                'activity_level': 0.05 + 0.05 * np.random.random(),  # Very low during sleep
                'sleep_stage': sleep_stage,
                'stress': 30 + 30 * severity + np.random.normal(0, 5),  # Stress increases during hypoxia
                'hrv': hrv_values[i],
                'is_virtual': True,  # Mark as virtual node
                'virtual_type': 'hypoxia'
            }
            simulated_nodes.append((-1000 - i, attrs))

        return simulated_nodes

    # ------------------------------------------------------------------
    # Graph augmentation
    # ------------------------------------------------------------------
    @staticmethod
    def _add_virtual_event(G_aug, event_nodes, next_id):
        """Insert one simulated event into *G_aug* under fresh negative IDs.

        The simulators always number their nodes from the same start
        (-1 / -1000). Inserting several events verbatim would therefore
        overwrite earlier ones. Each node instead receives the next negative
        ID not already present, and consecutive nodes are chained with
        virtual sequential edges.

        Returns:
            (number of nodes added, next candidate ID)
        """
        assigned = []
        for _, attrs in event_nodes:
            while next_id in G_aug:
                next_id -= 1
            G_aug.add_node(next_id, **attrs)
            assigned.append(next_id)
            next_id -= 1

        for src, dst in zip(assigned, assigned[1:]):
            G_aug.add_edge(
                src,
                dst,
                edge_type='sequential',
                weight=1.0,
                is_virtual=True
            )
        return len(assigned), next_id

    def augment_graph(self, num_exercise=3, num_hypoxia=2):
        """
        Augment the graph with simulated nodes and edges

        Args:
            num_exercise: Number of exercise events to simulate
            num_hypoxia: Number of hypoxic events to simulate

        Returns:
            Augmented copy of the graph (self.G is left untouched). Virtual
            nodes get unique negative integer IDs.
        """
        G_aug = self.G.copy()
        added_nodes = 0
        next_id = -1

        # Exercise events: 20-45 minutes, intensity 0.6-0.9
        for _ in range(num_exercise):
            duration = np.random.randint(20, 46)
            intensity = 0.6 + 0.3 * np.random.random()
            exercise_nodes = self.simulate_exercise(duration, intensity)
            count, next_id = self._add_virtual_event(G_aug, exercise_nodes, next_id)
            added_nodes += count

        # Hypoxic events: 15-90 minutes, severity 0.3-0.8
        for _ in range(num_hypoxia):
            duration = np.random.randint(15, 91)
            severity = 0.3 + 0.5 * np.random.random()
            hypoxia_nodes = self.simulate_hypoxia(duration, severity)
            count, next_id = self._add_virtual_event(G_aug, hypoxia_nodes, next_id)
            added_nodes += count

        print(f"Added {added_nodes} virtual nodes to the graph")

        # Add similarity edges between virtual and real nodes
        self._connect_virtual_nodes(G_aug)

        return G_aug

    def _connect_virtual_nodes(self, G_aug, k=3):
        """
        Connect virtual nodes to similar real nodes

        Each virtual node is linked, in both directions, to its k most
        cosine-similar real nodes; the similarity is stored as edge weight.

        Args:
            G_aug: Augmented graph (modified in place)
            k: Number of connections per virtual node
        """
        print("Connecting virtual nodes to similar real nodes...")

        real_nodes = [(n, attrs['embedding']) for n, attrs in G_aug.nodes(data=True)
                      if not attrs.get('is_virtual', False)]
        virtual_nodes = [(n, attrs['embedding']) for n, attrs in G_aug.nodes(data=True)
                         if attrs.get('is_virtual', False)]

        # k <= 0: argsort()[-0:] would otherwise select *every* real node
        if not real_nodes or not virtual_nodes or k <= 0:
            return

        real_indices = [n for n, _ in real_nodes]
        real_embeddings = np.array([emb for _, emb in real_nodes])
        real_norms = np.linalg.norm(real_embeddings, axis=1)  # loop-invariant

        for v_node, v_embedding in tqdm(virtual_nodes):
            # Cosine similarity; the floor avoids 0/0 for all-zero embeddings
            denom = np.maximum(real_norms * np.linalg.norm(v_embedding), 1e-12)
            similarities = np.dot(real_embeddings, v_embedding) / denom

            # Top-k most similar real nodes
            for idx in np.argsort(similarities)[-k:]:
                real_node = real_indices[idx]
                similarity = float(similarities[idx])

                # Add bidirectional similarity edges
                G_aug.add_edge(
                    v_node,
                    real_node,
                    edge_type='similarity',
                    weight=similarity,
                    is_virtual=True
                )
                G_aug.add_edge(
                    real_node,
                    v_node,
                    edge_type='similarity',
                    weight=similarity,
                    is_virtual=True
                )

    def plot_augmented_graph_sample(self, G_aug, n_nodes=150, figsize=(14, 10)):
        """
        Plot a sample of the augmented graph

        Nodes are placed at (timestamp in epoch seconds, activity level).

        Args:
            G_aug: Augmented graph
            n_nodes: Number of nodes to show (up to a third of them virtual)
            figsize: Figure size
        """
        virtual_nodes = [n for n, attrs in G_aug.nodes(data=True)
                         if attrs.get('is_virtual', False)]
        real_nodes = [n for n, attrs in G_aug.nodes(data=True)
                      if not attrs.get('is_virtual', False)]

        # Up to a third virtual nodes, the rest a random sample of real ones.
        # random.sample keeps the node-ID objects unchanged (np.random.choice
        # would coerce them into a NumPy array).
        selected_virtual = virtual_nodes[:min(len(virtual_nodes), n_nodes // 3)]
        selected_real = random.sample(
            real_nodes, min(len(real_nodes), n_nodes - len(selected_virtual))
        )

        G_sample = G_aug.subgraph(list(selected_virtual) + list(selected_real))

        # x = epoch seconds, y = activity level
        pos = {
            node: (G_sample.nodes[node]['timestamp'].timestamp(),
                   G_sample.nodes[node]['activity_level'])
            for node in G_sample.nodes()
        }

        # (is_virtual, is_sequential) -> (colour, width)
        edge_style = {
            (True, True): ('purple', 1),    # virtual sequential
            (True, False): ('orange', 0.5),  # virtual similarity
            (False, True): ('blue', 1),     # real sequential
            (False, False): ('red', 0.5),    # real similarity
        }
        edge_colors = []
        widths = []
        for _, _, attrs in G_sample.edges(data=True):
            key = (bool(attrs.get('is_virtual', False)),
                   attrs.get('edge_type') == 'sequential')
            color, width = edge_style[key]
            edge_colors.append(color)
            widths.append(width)

        sleep_colors = {0: 'lightyellow', 1: 'lightblue', 2: 'blue'}  # other stages: REM
        node_colors = []
        node_sizes = []
        for node in G_sample.nodes():
            attrs = G_sample.nodes[node]
            if attrs.get('is_virtual', False):
                node_colors.append('lime' if attrs.get('virtual_type') == 'exercise' else 'magenta')
                node_sizes.append(50)
            else:
                node_colors.append(sleep_colors.get(attrs.get('sleep_stage', 0), 'purple'))
                node_sizes.append(30)

        plt.figure(figsize=figsize)
        nx.draw_networkx(
            G_sample, pos=pos,
            with_labels=False,
            node_size=node_sizes,
            node_color=node_colors,
            edge_color=edge_colors,
            width=widths,
            alpha=0.7
        )

        plt.title('Augmented Temporal Phenotype Graph (Sample)')
        plt.xlabel('Time')
        plt.ylabel('Activity Level')

        legend_entries = [
            dict(color='blue', label='Real Sequential Edge'),
            dict(color='red', label='Real Similarity Edge'),
            dict(color='purple', label='Virtual Sequential Edge'),
            dict(color='orange', label='Virtual Similarity Edge'),
            dict(marker='o', color='lightyellow', label='Awake', linestyle=''),
            dict(marker='o', color='lightblue', label='Light Sleep', linestyle=''),
            dict(marker='o', color='blue', label='Deep Sleep', linestyle=''),
            dict(marker='o', color='purple', label='REM Sleep', linestyle=''),
            dict(marker='o', color='lime', label='Virtual Exercise', linestyle='', markersize=10),
            dict(marker='o', color='magenta', label='Virtual Hypoxia', linestyle='', markersize=10),
        ]
        for entry in legend_entries:
            # Empty data: the proxy appears in the legend without adding a
            # point at x=0, which would stretch the epoch-seconds time axis
            # and collapse every real node into one vertical line
            plt.plot([], [], **entry)

        plt.legend()
        plt.tight_layout()
        plt.show()

    def save_augmented_graph(self, G_aug, filename='augmented_phenotype_graph.pkl'):
        """Save the augmented graph to a file (pickle)"""
        with open(filename, 'wb') as f:
            pickle.dump(G_aug, f)
        print(f"Augmented graph saved to '{filename}'")

# Build the simulator from the reloaded graph, recordings and user metadata
simulator = PhysiologicalSimulator(G, data, user_meta)

# Splice three virtual exercise bouts and two hypoxic episodes into a copy of the graph
G_aug = simulator.augment_graph(num_hypoxia=2, num_exercise=3)
print(f"Augmented graph has {len(G_aug.nodes)} nodes and {len(G_aug.edges)} edges")

# Visual sanity check on a 150-node sample, then persist for the GNN stage
simulator.plot_augmented_graph_sample(G_aug, n_nodes=150)
simulator.save_augmented_graph(G_aug)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, global_max_pool
from torch_geometric.data import Data, Batch
import networkx as nx
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import pickle
import matplotlib.pyplot as plt

# Reload the graph written by the augmentation cell (real + virtual nodes)
with open('augmented_phenotype_graph.pkl', mode='rb') as f:
    G_aug = pickle.load(f)

print(f"Loaded augmented graph with {len(G_aug.nodes)} nodes and {len(G_aug.edges)} edges")

class GraphDatasetGenerator:
    """Slice a temporal phenotype graph into fixed-size node windows.

    Nodes are ordered by timestamp and cut into overlapping windows. Each
    window becomes a PyTorch Geometric ``Data`` object and receives a
    synthetic, rule-based disease label.
    """

    def __init__(self, G, window_size=60, stride=10):
        """
        Create a dataset of graph windows from the temporal graph

        Args:
            G: NetworkX graph (temporal phenotype graph); every node needs
                'timestamp' and 'embedding' attributes
            window_size: Window size in nodes
            stride: Stride for window creation

        Raises:
            ValueError: if G has no nodes
        """
        self.G = G
        self.window_size = window_size
        self.stride = stride

        # Nodes in chronological order (virtual nodes interleave with real ones)
        self.sorted_nodes = sorted(G.nodes(), key=lambda n: G.nodes[n]['timestamp'])

        # One window per start index with a *full* window available. The
        # "+ 1" keeps the final window ending exactly on the last node
        # (previously dropped; a graph of exactly window_size nodes gave none).
        self.windows = [
            self.sorted_nodes[i:i + window_size]
            for i in range(0, len(self.sorted_nodes) - window_size + 1, stride)
        ]

        print(f"Created {len(self.windows)} graph windows")

        if not self.sorted_nodes:
            raise ValueError("Cannot build graph windows from an empty graph")

        # Node feature dimension
        self.embedding_dim = len(G.nodes[self.sorted_nodes[0]]['embedding'])

        # Scalar node attributes appended to the embedding (missing -> 0)
        self.node_features = [
            'hr', 'spo2', 'activity_level', 'sleep_stage',
            'stress', 'hrv'
        ]

    def get_window_data(self, window_idx):
        """
        Extract a PyTorch Geometric Data object for a window

        Node features are the embedding concatenated with the scalar
        attributes in ``self.node_features`` (z-scored within the window).
        Edge attributes are [is_sequential, is_virtual, weight].

        Args:
            window_idx: Index of the window

        Returns:
            PyTorch Geometric Data object
        """
        window_nodes = self.windows[window_idx]
        subgraph = self.G.subgraph(window_nodes)

        # Global node ID -> local row index (for edge_index)
        node_mapping = {node: i for i, node in enumerate(window_nodes)}

        node_attrs = [self.G.nodes[node] for node in window_nodes]
        node_embeddings = torch.tensor(
            np.array([attrs['embedding'] for attrs in node_attrs]), dtype=torch.float
        )
        node_features = torch.tensor(
            np.array([[attrs.get(feat, 0) for feat in self.node_features] for attrs in node_attrs]),
            dtype=torch.float,
        )

        # Per-window z-scoring; a single-node window has no spread (its
        # unbiased std would be NaN), so it is only centred
        feat_mean = node_features.mean(dim=0)
        if node_features.shape[0] > 1:
            feat_std = node_features.std(dim=0)
        else:
            feat_std = torch.zeros_like(feat_mean)
        node_features = (node_features - feat_mean) / (feat_std + 1e-6)

        x = torch.cat([node_embeddings, node_features], dim=1)

        edge_index = []
        edge_attr = []
        for u, v, attrs in subgraph.edges(data=True):
            edge_index.append([node_mapping[u], node_mapping[v]])
            edge_attr.append([
                1.0 if attrs.get('edge_type') == 'sequential' else 0.0,
                1.0 if attrs.get('is_virtual', False) else 0.0,
                attrs.get('weight', 1.0),
            ])

        if edge_index:
            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
            edge_attr = torch.tensor(edge_attr, dtype=torch.float)
        else:
            # torch.tensor([]) would be 1-D; PyG expects shapes (2, 0) and (0, 3)
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_attr = torch.empty((0, 3), dtype=torch.float)

        return Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            num_nodes=len(window_nodes)
        )

    def _has_virtual_hypoxia(self, window_nodes):
        """True if any node in the window is a simulated hypoxia node."""
        return any(
            self.G.nodes[n].get('is_virtual', False)
            and self.G.nodes[n].get('virtual_type') == 'hypoxia'
            for n in window_nodes
        )

    @staticmethod
    def _diabetes_risk(mean_rest_hr, mean_activity, mean_stress):
        """At least 2 of: resting HR > 75, mean activity < 0.25, mean stress > 60.

        Each criterion is converted to a Python bool before counting: adding
        NumPy bools is a logical OR, which caps the score at 1.
        """
        criteria = (mean_rest_hr > 75, mean_activity < 0.25, mean_stress > 60)
        return sum(bool(c) for c in criteria) >= 2

    def generate_afib_labels(self):
        """
        Generate synthetic AFib labels for windows

        A window is positive when HR std > 15, mean HRV < 30 and it contains
        at least one virtual hypoxia node.

        Returns:
            Tensor of binary labels for each window
        """
        labels = []
        for window_nodes in tqdm(self.windows, desc="Generating AFib labels"):
            attrs = [self.G.nodes[n] for n in window_nodes]
            hr_std = np.std([a['hr'] for a in attrs])
            hrv_mean = np.mean([a['hrv'] for a in attrs])

            is_afib = (hr_std > 15) and (hrv_mean < 30) and self._has_virtual_hypoxia(window_nodes)
            labels.append(1.0 if is_afib else 0.0)

        return torch.tensor(labels, dtype=torch.float)

    def generate_diabetes_labels(self):
        """
        Generate synthetic diabetes risk labels for windows

        Resting HR is the mean HR over nodes with activity_level < 0.2.
        Windows without such nodes are labelled 0. See ``_diabetes_risk``
        for the criteria.

        Returns:
            Tensor of binary labels for each window
        """
        labels = []
        for window_nodes in tqdm(self.windows, desc="Generating diabetes labels"):
            attrs = [self.G.nodes[n] for n in window_nodes]

            hr_rest = [a['hr'] for a in attrs if a['activity_level'] < 0.2]
            if not hr_rest:  # No resting periods in this window
                labels.append(0.0)
                continue

            is_risk = self._diabetes_risk(
                np.mean(hr_rest),
                np.mean([a['activity_level'] for a in attrs]),
                np.mean([a['stress'] for a in attrs]),
            )
            labels.append(1.0 if is_risk else 0.0)

        return torch.tensor(labels, dtype=torch.float)

    def generate_copd_labels(self):
        """
        Generate synthetic COPD risk labels for windows

        A window is positive when min SpO2 < 90 and mean SpO2 < 95, and either
        the mean HR over low-activity nodes (activity < 0.3) exceeds 80 or
        the window contains a virtual hypoxia node. Windows with no
        low-activity nodes are labelled 0.

        Returns:
            Tensor of binary labels for each window
        """
        labels = []
        for window_nodes in tqdm(self.windows, desc="Generating COPD labels"):
            attrs = [self.G.nodes[n] for n in window_nodes]

            spo2_values = [a['spo2'] for a in attrs]
            min_spo2 = np.min(spo2_values)
            mean_spo2 = np.mean(spo2_values)

            rest_hr = [a['hr'] for a in attrs if a['activity_level'] < 0.3]
            if not rest_hr:  # No low activity periods
                labels.append(0.0)
                continue
            mean_rest_hr = np.mean(rest_hr)

            is_copd_risk = (
                (min_spo2 < 90)
                and (mean_spo2 < 95)
                and (mean_rest_hr > 80 or self._has_virtual_hypoxia(window_nodes))
            )
            labels.append(1.0 if is_copd_risk else 0.0)

        return torch.tensor(labels, dtype=torch.float)

    def create_dataset(self, disease_type='afib'):
        """
        Create a complete dataset for a disease type

        Args:
            disease_type: 'afib', 'diabetes', or 'copd' (case-insensitive)

        Returns:
            List of Data objects (each with ``y`` set), labels tensor

        Raises:
            ValueError: for an unknown disease type
        """
        label_generators = {
            'afib': self.generate_afib_labels,
            'diabetes': self.generate_diabetes_labels,
            'copd': self.generate_copd_labels,
        }
        try:
            generate_labels = label_generators[disease_type.lower()]
        except KeyError:
            raise ValueError(f"Unknown disease type: {disease_type}") from None
        labels = generate_labels()

        dataset = []
        for i in tqdm(range(len(self.windows)), desc=f"Creating {disease_type} dataset"):
            data = self.get_window_data(i)
            data.y = labels[i]
            dataset.append(data)

        return dataset, labels

class DiseaseMotifGNN(nn.Module):
    """Graph-level binary classifier for disease "motifs" in a graph window.

    Node features are projected to ``hidden_dim`` and passed through
    ``num_layers`` message-passing layers (GCN on even indices, GAT on odd
    ones). They are then pooled per graph with mean and max pooling and
    mapped to one probability per graph.
    """

    def __init__(self, input_dim, hidden_dim=64, num_layers=3, dropout=0.2):
        """
        GNN for detecting disease motifs in the graph

        Args:
            input_dim: Dimension of input node features
            hidden_dim: Hidden dimension
            num_layers: Number of GNN layers
            dropout: Dropout probability applied after every hidden layer
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Input projection
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # Alternating GCN (even index) and GAT (odd index) layers
        self.convs = nn.ModuleList(
            GCNConv(hidden_dim, hidden_dim) if i % 2 == 0 else GATConv(hidden_dim, hidden_dim)
            for i in range(num_layers)
        )

        # Mean- and max-pooled vectors are concatenated, hence 2 * hidden_dim
        self.pool_proj = nn.Linear(hidden_dim * 2, hidden_dim)

        # Output head
        self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc2 = nn.Linear(hidden_dim // 2, 1)

        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """Forward pass.

        Args:
            data: PyG ``Data``/``Batch`` with ``x``, ``edge_index`` and
                ``batch``. ``edge_attr`` is currently not used by the layers.

        Returns:
            Tensor of shape (num_graphs,) with probabilities in (0, 1).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.dropout(F.relu(self.input_proj(x)))

        for conv in self.convs:
            x = self.dropout(F.relu(conv(x, edge_index)))

        # Global pooling (combine mean and max)
        x = torch.cat([global_mean_pool(x, batch), global_max_pool(x, batch)], dim=1)

        x = self.dropout(F.relu(self.pool_proj(x)))
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)

        # squeeze(-1) rather than squeeze(): a batch holding a single graph
        # must keep shape (1,) to match batch.y. A 0-d output makes BCELoss
        # raise a target/input size mismatch.
        return torch.sigmoid(x).squeeze(-1)

def _run_epoch(model, loader, criterion, device, optimizer=None, desc=''):
    """
    Run one pass over ``loader``; update weights only when ``optimizer`` is given.

    Returns:
        (mean loss per sample, accuracy at a 0.5 threshold). Both are NaN for
        an empty loader instead of raising ZeroDivisionError.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        for batch in tqdm(loader, desc=desc):
            batch = batch.to(device)

            # Flatten both sides so a batch with a single graph still yields
            # matching (1,) shapes for BCELoss; .float() is a no-op for float labels
            outputs = model(batch).view(-1)
            targets = batch.y.view(-1).float()
            loss = criterion(outputs, targets)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            n = targets.size(0)
            total_loss += loss.item() * n
            preds = (outputs > 0.5).float()
            correct += (preds == targets).sum().item()
            total += n

    if total == 0:
        return float('nan'), float('nan')
    return total_loss / total, correct / total


def train_model(model, train_loader, val_loader, num_epochs=50, lr=0.001,
                patience=10, checkpoint_path='best_model.pt'):
    """
    Train a disease motif GNN with Adam, BCE loss and early stopping.

    The weights with the lowest validation loss are saved to
    ``checkpoint_path`` and also kept in memory; they are restored into
    ``model`` before returning. Restoring from memory means that a stale
    checkpoint file from an earlier run can never be loaded by mistake. If no
    epoch improved (e.g. ``num_epochs == 0`` or NaN validation loss), the
    final weights are kept.

    Args:
        model: Model to train; its forward must return probabilities in [0, 1]
            because ``nn.BCELoss`` is applied directly.
        train_loader: Training data loader (batches expose ``y``)
        val_loader: Validation data loader
        num_epochs: Maximum number of training epochs
        lr: Learning rate
        patience: Epochs without validation-loss improvement before stopping
        checkpoint_path: File the best ``state_dict`` is written to

    Returns:
        Trained model, training history dict with 'train_loss', 'train_acc',
        'val_loss' and 'val_acc' lists (one entry per completed epoch).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCELoss()

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    # Early stopping state
    best_val_loss = float('inf')
    best_state = None
    counter = 0

    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer,
            desc=f"Epoch {epoch+1}/{num_epochs} - Training")
        val_loss, val_acc = _run_epoch(
            model, val_loader, criterion, device,
            desc=f"Epoch {epoch+1}/{num_epochs} - Validation")

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # NaN never compares less-than, so a NaN epoch counts as no improvement
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            counter = 0

            # Snapshot (clone, not reference) the best weights and persist them
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(model.state_dict(), checkpoint_path)
        else:
            counter += 1
            if counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Restore the best weights seen during this run
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, history

def plot_training_history(history):
    """
    Draw per-epoch train/validation curves: loss on the left, accuracy on the right.

    The figure is written to 'training_history.png' and then shown.

    Args:
        history: Mapping with 'train_loss', 'val_loss', 'train_acc' and
            'val_acc' sequences (as returned by train_model).
    """
    _, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # (axis, train key, val key, train label, val label, title/y-label)
    panels = (
        (loss_ax, 'train_loss', 'val_loss', 'Train Loss', 'Val Loss', 'Loss'),
        (acc_ax, 'train_acc', 'val_acc', 'Train Acc', 'Val Acc', 'Accuracy'),
    )
    for ax, train_key, val_key, train_label, val_label, heading in panels:
        ax.plot(history[train_key], label=train_label)
        ax.plot(history[val_key], label=val_label)
        ax.set_title(heading)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(heading)
        ax.legend()

    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.show()

# Create dataset generator
# NOTE(review): GraphDatasetGenerator and G_aug come from earlier cells; window_size
# and stride presumably count nodes, as in SlidingWindowInference — confirm.
dataset_gen = GraphDatasetGenerator(G_aug, window_size=60, stride=10)

# Create datasets for different diseases
# Each call returns (samples, labels): samples are indexed below and expose an
# `.x` node-feature matrix; labels are used for stratification and summed.
afib_dataset, afib_labels = dataset_gen.create_dataset(disease_type='afib')
diabetes_dataset, diabetes_labels = dataset_gen.create_dataset(disease_type='diabetes')
copd_dataset, copd_labels = dataset_gen.create_dataset(disease_type='copd')

# `.item()` assumes summing the labels yields a tensor (i.e. labels are tensor
# elements) — TODO confirm against GraphDatasetGenerator.create_dataset
print(f"AFib dataset: {len(afib_dataset)} samples, {sum(afib_labels).item()} positive")
print(f"Diabetes dataset: {len(diabetes_dataset)} samples, {sum(diabetes_labels).item()} positive")
print(f"COPD dataset: {len(copd_dataset)} samples, {sum(copd_labels).item()} positive")

# Let's train a model for AFib detection as an example
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split

# Split dataset into train/val: 80/20, stratified on the labels so both splits
# keep the positive rate; fixed random_state makes the split reproducible
train_indices, val_indices = train_test_split(
    range(len(afib_dataset)),
    test_size=0.2,
    stratify=afib_labels,
    random_state=42
)

train_dataset = [afib_dataset[i] for i in train_indices]
val_dataset = [afib_dataset[i] for i in val_indices]

# Create data loaders (only the training loader is shuffled)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Create and train model
input_dim = afib_dataset[0].x.shape[1]  # Node feature dimension
model = DiseaseMotifGNN(input_dim=input_dim, hidden_dim=64, num_layers=3)

# train_model also writes its intermediate best checkpoint to 'best_model.pt'
trained_model, history = train_model(
    model,
    train_loader,
    val_loader,
    num_epochs=30,
    lr=0.001
)

# Plot training history
plot_training_history(history)

# Save trained model (state_dict only; the loader must rebuild the architecture)
torch.save(trained_model.state_dict(), 'afib_gnn_model.pt')
print("AFib GNN model saved to 'afib_gnn_model.pt'")

# Optionally train models for other diseases
# For this example, we'll skip training the other models
# But you could uncomment these sections to train them
# The string literal below is an inert, disabled copy of the same recipe for
# the diabetes and COPD datasets (it reuses the AFib `input_dim`).

'''
# Diabetes model
diabetes_train_indices, diabetes_val_indices = train_test_split(
    range(len(diabetes_dataset)),
    test_size=0.2,
    stratify=diabetes_labels,
    random_state=42
)

diabetes_train_dataset = [diabetes_dataset[i] for i in diabetes_train_indices]
diabetes_val_dataset = [diabetes_dataset[i] for i in diabetes_val_indices]

diabetes_train_loader = DataLoader(diabetes_train_dataset, batch_size=32, shuffle=True)
diabetes_val_loader = DataLoader(diabetes_val_dataset, batch_size=32, shuffle=False)

diabetes_model = DiseaseMotifGNN(input_dim=input_dim, hidden_dim=64, num_layers=3)

diabetes_trained_model, diabetes_history = train_model(
    diabetes_model,
    diabetes_train_loader,
    diabetes_val_loader,
    num_epochs=30,
    lr=0.001
)

torch.save(diabetes_trained_model.state_dict(), 'diabetes_gnn_model.pt')
print("Diabetes GNN model saved to 'diabetes_gnn_model.pt'")

# COPD model
copd_train_indices, copd_val_indices = train_test_split(
    range(len(copd_dataset)),
    test_size=0.2,
    stratify=copd_labels,
    random_state=42
)

copd_train_dataset = [copd_dataset[i] for i in copd_train_indices]
copd_val_dataset = [copd_dataset[i] for i in copd_val_indices]

copd_train_loader = DataLoader(copd_train_dataset, batch_size=32, shuffle=True)
copd_val_loader = DataLoader(copd_val_dataset, batch_size=32, shuffle=False)

copd_model = DiseaseMotifGNN(input_dim=input_dim, hidden_dim=64, num_layers=3)

copd_trained_model, copd_history = train_model(
    copd_model,
    copd_train_loader,
    copd_val_loader,
    num_epochs=30,
    lr=0.001
)

torch.save(copd_trained_model.state_dict(), 'copd_gnn_model.pt')
print("COPD GNN model saved to 'copd_gnn_model.pt'")
'''

# Save all dataset information: disease name -> (samples, labels)
# NOTE(review): relies on `pickle` having been imported by an earlier cell
with open('disease_datasets.pkl', 'wb') as f:
    pickle.dump({
        'afib': (afib_dataset, afib_labels),
        'diabetes': (diabetes_dataset, diabetes_labels),
        'copd': (copd_dataset, copd_labels)
    }, f)

print("All disease datasets saved to 'disease_datasets.pkl'")

In [None]:
import torch
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
from tqdm.notebook import tqdm
import pickle
from torch_geometric.data import Data, Batch
import os

# Load the graph, data, and models
with open('augmented_phenotype_graph.pkl', 'rb') as f:
    G_aug = pickle.load(f)

# Load trained model.
# map_location='cpu' lets a checkpoint written on a GPU runtime load on a
# CPU-only one; SlidingWindowInference moves the model to the active device.
afib_state = torch.load('afib_gnn_model.pt', map_location='cpu')

# Recover layer sizes from the checkpoint instead of hard-coding them:
# input_proj is nn.Linear(input_dim, hidden_dim), whose weight has shape
# (hidden_dim, input_dim). This keeps the loader in sync with whatever
# feature width the training cell actually used.
ckpt_hidden_dim, ckpt_input_dim = afib_state['input_proj.weight'].shape
model = DiseaseMotifGNN(input_dim=ckpt_input_dim, hidden_dim=ckpt_hidden_dim, num_layers=3)
model.load_state_dict(afib_state)
model.eval()

print(f"Loaded augmented graph with {len(G_aug.nodes)} nodes and {len(G_aug.edges)} edges")
print("Loaded AFib detection model")

class SlidingWindowInference:
    def __init__(self, graph, model, window_size=60, stride=1):
        """
        Sliding window inference for real-time disease detection

        Args:
            graph: NetworkX graph whose nodes carry a 'timestamp', an
                'embedding', and (optionally) the scalar attributes listed in
                ``self.node_features``
            model: Trained GNN model returning one risk probability per graph
            window_size: Window size in nodes
            stride: Step, in nodes, between consecutive window starts
        """
        self.G = graph
        self.model = model
        self.window_size = window_size
        self.stride = stride

        # Nodes in chronological order; windows are contiguous slices of this list
        self.sorted_nodes = sorted(
            list(self.G.nodes()),
            key=lambda n: self.G.nodes[n]['timestamp']
        )

        # Scalar node attributes appended to each node's embedding
        self.node_features = [
            'hr', 'spo2', 'activity_level', 'sleep_stage',
            'stress', 'hrv'
        ]

        # Result history (initialised here so plot_risk_scores works before
        # run_inference has been called)
        self.timestamps = []
        self.risk_scores = []
        self.alerts = []

        # Device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)

    def get_window_data(self, start_idx):
        """
        Build a PyTorch Geometric Data object for one window.

        Node features are each node's 'embedding' concatenated with the
        scalar attributes in ``self.node_features`` (missing ones default to
        0), the scalars z-scored within the window. Edge attributes are
        [is_sequential, is_virtual, weight].

        Args:
            start_idx: Index into ``self.sorted_nodes`` of the window's first node

        Returns:
            PyTorch Geometric Data object, or None if the window would run
            past the last node
        """
        if start_idx + self.window_size > len(self.sorted_nodes):
            return None

        window_nodes = self.sorted_nodes[start_idx:start_idx + self.window_size]
        subgraph = self.G.subgraph(window_nodes)

        # Global node id -> local row index (used to build edge_index)
        node_mapping = {node: i for i, node in enumerate(window_nodes)}

        node_embeddings = []
        node_features = []
        for node in window_nodes:
            attrs = self.G.nodes[node]
            node_embeddings.append(attrs['embedding'])
            node_features.append([attrs.get(feat, 0) for feat in self.node_features])

        node_embeddings = torch.tensor(np.array(node_embeddings), dtype=torch.float)
        node_features = torch.tensor(np.array(node_features), dtype=torch.float)

        # Per-window z-score; the epsilon avoids division by zero for constant
        # features. NOTE(review): torch.std is unbiased, so window_size == 1
        # would produce NaN features.
        node_features = (node_features - node_features.mean(dim=0)) / (node_features.std(dim=0) + 1e-6)

        x = torch.cat([node_embeddings, node_features], dim=1)

        edge_index = []
        edge_attr = []
        for u, v, edata in subgraph.edges(data=True):
            edge_index.append([node_mapping[u], node_mapping[v]])

            edge_type = 1.0 if edata.get('edge_type') == 'sequential' else 0.0
            is_virtual = 1.0 if edata.get('is_virtual', False) else 0.0
            weight = edata.get('weight', 1.0)
            edge_attr.append([edge_type, is_virtual, weight])

        if edge_index:
            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
            edge_attr = torch.tensor(edge_attr, dtype=torch.float)
        else:
            # torch.tensor([]).t() would have shape (0,); message-passing
            # layers expect edge_index of shape (2, num_edges)
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_attr = torch.empty((0, 3), dtype=torch.float)

        return Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            num_nodes=len(window_nodes)
        )

    def _results_frame(self):
        """Collect the stored history into a results DataFrame."""
        return pd.DataFrame({
            'timestamp': self.timestamps,
            'risk_score': self.risk_scores,
            'alert': self.alerts
        })

    def run_inference(self, num_windows=None, threshold=0.5):
        """
        Run inference on sliding windows

        Args:
            num_windows: Exclusive upper bound on the window start indices
                scanned (None = every start that fits a full window). Windows
                begin at 0, stride, 2*stride, ... below this bound, so with
                stride > 1 fewer than ``num_windows`` windows are processed.
            threshold: Risk score above which a window is flagged as an alert

        Returns:
            DataFrame with columns 'timestamp' (window's middle node),
            'risk_score' and 'alert'
        """
        # Reset history
        self.timestamps = []
        self.risk_scores = []
        self.alerts = []

        # Number of valid start positions: the last full window starts at
        # len(sorted_nodes) - window_size, hence the +1
        max_windows = max(0, len(self.sorted_nodes) - self.window_size + 1)
        if num_windows is None:
            num_windows = max_windows
        else:
            num_windows = min(num_windows, max_windows)

        for i in tqdm(range(0, num_windows, self.stride), desc="Running sliding window inference"):
            data = self.get_window_data(i)
            if data is None:
                continue

            # Each window is timestamped by its middle node
            mid_idx = i + self.window_size // 2
            timestamp = self.G.nodes[self.sorted_nodes[mid_idx]]['timestamp']

            with torch.no_grad():
                data = data.to(self.device)
                batch = Batch.from_data_list([data])
                risk_score = self.model(batch).item()

            is_alert = risk_score > threshold

            self.timestamps.append(timestamp)
            self.risk_scores.append(risk_score)
            self.alerts.append(is_alert)

        return self._results_frame()

    def plot_risk_scores(self, results=None, figsize=(12, 6), threshold=0.5):
        """
        Plot risk scores over time and save the figure to 'risk_scores.png'

        Args:
            results: DataFrame with results (if None, use stored results)
            figsize: Figure size
            threshold: Risk score threshold line to draw
        """
        if results is None:
            results = self._results_frame()

        plt.figure(figsize=figsize)

        # Plot risk scores
        plt.plot(results['timestamp'], results['risk_score'], 'b-', label='Risk Score')

        # Plot threshold
        plt.axhline(y=threshold, color='r', linestyle='--', label='Alert Threshold')

        # Highlight alerts; astype(bool) keeps the mask boolean even when the
        # column is empty (object dtype)
        alerts = results[results['alert'].astype(bool)]
        if len(alerts) > 0:
            plt.scatter(alerts['timestamp'], alerts['risk_score'],
                       color='red', s=50, label='Alert')

        plt.title('Disease Risk Score Over Time')
        plt.xlabel('Time')
        plt.ylabel('Risk Score')
        plt.ylim(0, 1)
        plt.legend()
        plt.grid(alpha=0.3)
        plt.savefig('risk_scores.png')
        plt.show()

    def save_results(self, results, filename='risk_scores.csv'):
        """Save inference results to CSV (without the index column)"""
        results.to_csv(filename, index=False)
        print(f"Results saved to '{filename}'")

# Create sliding window inference: 60-node windows advanced 5 nodes at a time
inference = SlidingWindowInference(G_aug, model, window_size=60, stride=5)

# Run inference; windows with risk_score > 0.5 are flagged as alerts
results = inference.run_inference(threshold=0.5)
print(f"Generated {len(results)} risk assessments")

# Plot risk scores (also written to 'risk_scores.png')
inference.plot_risk_scores(results)

# Save results to the default 'risk_scores.csv', which the next cell reads back
inference.save_results(results)

# Count alerts (the boolean 'alert' column sums to the number of True values)
num_alerts = results['alert'].sum()
print(f"Detected {num_alerts} potential disease events")

# Process the highest risk windows for analysis
high_risk_windows = results.nlargest(5, 'risk_score')
print("Top 5 highest risk windows:")
print(high_risk_windows)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
from scipy.stats import beta
import pickle

# Load the risk scores written by `inference.save_results(results)` above
# (CSV columns: timestamp, risk_score, alert). Values come back from CSV, so
# timestamps are whatever pandas parses them as, not the original objects.
risk_scores = pd.read_csv('risk_scores.csv')
print(f"Loaded {len(risk_scores)} risk assessments")

class BayesianPersonalization:
    def __init__(self, initial_prevalence=0.05, prior_strength=10.0):
        """
        Bayesian personalization for disease risk calibration.

        Keeps a Beta(alpha, beta) prior over the user's disease prevalence,
        updates it from alert feedback, and uses its mean to recalibrate
        model risk scores.

        Args:
            initial_prevalence: Initial disease prevalence (prior mean); must
                lie strictly between 0 and 1.
            prior_strength: Minimum pseudo-count alpha + beta of the prior.
                It is raised automatically when needed so that both shape
                parameters are at least 1, which keeps the prior density
                finite on [0, 1] (relevant for plot_beta_distribution).

        Raises:
            ValueError: If initial_prevalence is not strictly inside (0, 1).
        """
        if not 0 < initial_prevalence < 1:
            raise ValueError(
                f"initial_prevalence must be in (0, 1), got {initial_prevalence}"
            )

        self.prevalence = initial_prevalence

        # Beta(alpha, beta) has mean alpha / (alpha + beta). Splitting the
        # total pseudo-count in proportion p : (1 - p) makes that mean equal
        # the prevalence exactly. Adding 1 to each shape instead would pull
        # the mean toward 0.5, which makes the first update jump the
        # prevalence upward even after a false-positive feedback.
        strength = max(prior_strength, 1 / initial_prevalence, 1 / (1 - initial_prevalence))
        self.alpha = initial_prevalence * strength
        self.beta = (1 - initial_prevalence) * strength

        # History of parameter updates
        self.alpha_history = [self.alpha]
        self.beta_history = [self.beta]
        self.prevalence_history = [self.prevalence]

        # User feedback history: (risk_score, user_confirmed) tuples
        self.feedback_history = []

        print(f"Initialized Bayesian personalization with prevalence {initial_prevalence:.4f}")
        print(f"Initial beta parameters: alpha={self.alpha:.4f}, beta={self.beta:.4f}")

    def update_prior(self, risk_score, user_confirmed):
        """
        Update the prior from one piece of user feedback.

        A confirmed alert adds ``risk_score`` to alpha; an unconfirmed one
        adds it to beta. The prevalence becomes the new Beta mean.

        Args:
            risk_score: Model's risk score (0-1), used as the update weight
            user_confirmed: Whether user confirmed the alert (True/False)
        """
        self.feedback_history.append((risk_score, user_confirmed))

        if user_confirmed:
            # True positive: more evidence for the disease
            self.alpha += risk_score
        else:
            # False positive: more evidence against the disease
            self.beta += risk_score

        self.prevalence = self.alpha / (self.alpha + self.beta)

        self.alpha_history.append(self.alpha)
        self.beta_history.append(self.beta)
        self.prevalence_history.append(self.prevalence)

        print(f"Updated prevalence to {self.prevalence:.4f}")
        print(f"New beta parameters: alpha={self.alpha:.4f}, beta={self.beta:.4f}")

    def calibrate_risk(self, risk_score):
        """
        Calibrate a risk score with the current prevalence prior.

        The (clipped) score's odds s / (1 - s) are treated as a likelihood
        ratio and combined with the prevalence via Bayes' rule:
        p * lr / (p * lr + (1 - p)). This is a simplified approach. Works
        elementwise on NumPy arrays as well as on scalars.

        Args:
            risk_score: Model's raw risk score

        Returns:
            Calibrated risk score
        """
        # Clip away from 0 and 1 so the odds stay finite
        risk_score = np.clip(risk_score, 0.01, 0.99)

        lr = risk_score / (1 - risk_score)
        calibrated = (self.prevalence * lr) / (self.prevalence * lr + (1 - self.prevalence))

        return calibrated

    def plot_beta_distribution(self, figsize=(10, 6)):
        """Plot the current and initial beta distributions; saves 'beta_distribution.png'"""
        x = np.linspace(0, 1, 1000)
        y = beta.pdf(x, self.alpha, self.beta)

        plt.figure(figsize=figsize)
        plt.plot(x, y, 'b-', lw=2, label='Current Beta Distribution')

        # Plot initial distribution
        y_initial = beta.pdf(x, self.alpha_history[0], self.beta_history[0])
        plt.plot(x, y_initial, 'r--', lw=2, label='Initial Beta Distribution')

        # Highlight the mean (prevalence)
        plt.axvline(x=self.prevalence, color='g', linestyle='-',
                   label=f'Current Mean: {self.prevalence:.4f}')
        plt.axvline(x=self.prevalence_history[0], color='orange', linestyle='--',
                   label=f'Initial Mean: {self.prevalence_history[0]:.4f}')

        plt.title('Personalized Disease Prevalence Prior')
        plt.xlabel('Prevalence')
        plt.ylabel('Density')
        plt.legend()
        plt.grid(alpha=0.3)
        plt.savefig('beta_distribution.png')
        plt.show()

    def plot_prevalence_history(self, figsize=(10, 6)):
        """Plot the history of prevalence updates; saves 'prevalence_history.png'"""
        plt.figure(figsize=figsize)
        plt.plot(range(len(self.prevalence_history)), self.prevalence_history, 'b-', marker='o')

        plt.title('Personalized Disease Prevalence History')
        plt.xlabel('Update')
        plt.ylabel('Prevalence')
        plt.grid(alpha=0.3)
        plt.savefig('prevalence_history.png')
        plt.show()

    def save_state(self, filename='bayesian_personalization.pkl'):
        """Pickle the current parameters and histories to ``filename``"""
        state = {
            'alpha': self.alpha,
            'beta': self.beta,
            'prevalence': self.prevalence,
            'alpha_history': self.alpha_history,
            'beta_history': self.beta_history,
            'prevalence_history': self.prevalence_history,
            'feedback_history': self.feedback_history
        }

        with open(filename, 'wb') as f:
            pickle.dump(state, f)

        print(f"Bayesian personalization state saved to '{filename}'")

    @classmethod
    def load_state(cls, filename='bayesian_personalization.pkl'):
        """
        Rebuild an instance from a file written by save_state.

        Bypasses __init__ (no re-initialisation, no validation).
        NOTE: pickle.load can execute arbitrary code — only load trusted files.
        """
        with open(filename, 'rb') as f:
            state = pickle.load(f)

        instance = cls.__new__(cls)

        instance.alpha = state['alpha']
        instance.beta = state['beta']
        instance.prevalence = state['prevalence']
        instance.alpha_history = state['alpha_history']
        instance.beta_history = state['beta_history']
        instance.prevalence_history = state['prevalence_history']
        instance.feedback_history = state['feedback_history']

        print(f"Loaded Bayesian personalization state with prevalence {instance.prevalence:.4f}")

        return instance

# Simulate user feedback for demonstration
def simulate_user_feedback(risk_scores, true_positive_rate=0.7, false_positive_rate=0.2):
    """
    Return a copy of ``risk_scores`` with a synthetic 'user_confirmed' column.

    One uniform number is drawn per row from NumPy's global RNG, in row
    order. Rows with risk_score > 0.5 are confirmed with probability
    ``true_positive_rate``; every other row with ``false_positive_rate``.

    Args:
        risk_scores: DataFrame with a 'risk_score' column (left unmodified)
        true_positive_rate: Confirmation probability for high-risk rows
        false_positive_rate: Confirmation probability for low-risk rows

    Returns:
        New DataFrame with a boolean 'user_confirmed' column added
    """
    simulated = risk_scores.copy()

    scores = simulated['risk_score'].to_numpy()
    confirm_prob = np.where(scores > 0.5, true_positive_rate, false_positive_rate)
    draws = np.random.random(len(scores))

    # .tolist() stores plain Python bools, matching a per-row list build
    simulated['user_confirmed'] = (draws < confirm_prob).tolist()
    return simulated

# Initial disease prevalence (we'll use approximate real-world values)
# AFib prevalence in general population ~2%
afib_prevalence = 0.02

# Create Bayesian personalizer
personalizer = BayesianPersonalization(initial_prevalence=afib_prevalence)

# Plot initial distribution
personalizer.plot_beta_distribution()

# Simulate user feedback for every risk assessment
feedback_data = simulate_user_feedback(risk_scores)

# Process some user feedback: a random subset of at most 10 rows.
# DataFrame.sample raises ValueError when asked for more rows than exist
# (without replacement), so cap the size for short inference runs.
num_feedback = min(10, len(feedback_data))
for i, (_, row) in enumerate(feedback_data.sample(num_feedback).iterrows()):
    print(f"\nProcessing feedback {i+1}:")
    print(f"  Risk score: {row['risk_score']:.4f}")
    print(f"  User confirmed: {row['user_confirmed']}")

    # Update prior based on feedback
    personalizer.update_prior(row['risk_score'], row['user_confirmed'])

# Plot updated distribution
personalizer.plot_beta_distribution()

# Plot prevalence history
personalizer.plot_prevalence_history()

# Calibrate all risk scores with the updated prevalence
calibrated_scores = []
for score in risk_scores['risk_score']:
    calibrated = personalizer.calibrate_risk(score)
    calibrated_scores.append(calibrated)

risk_scores['calibrated_risk'] = calibrated_scores

# Compare original vs. calibrated scores
plt.figure(figsize=(12, 6))
plt.plot(risk_scores['timestamp'], risk_scores['risk_score'], 'b-', label='Original Risk Score')
plt.plot(risk_scores['timestamp'], risk_scores['calibrated_risk'], 'g-', label='Calibrated Risk Score')

plt.title('Original vs. Calibrated Risk Scores')
plt.xlabel('Time')
plt.ylabel('Risk Score')
plt.legend()
plt.grid(alpha=0.3)
plt.savefig('original_vs_calibrated.png')
plt.show()

# Save personalizer state
personalizer.save_state()

# Save updated risk scores
risk_scores.to_csv('calibrated_risk_scores.csv', index=False)
print("Calibrated risk scores saved to 'calibrated_risk_scores.csv'")

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import seaborn as sns
import networkx as nx
import pickle
from datetime import datetime, timedelta
from matplotlib.colors import LinearSegmentedColormap
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

# Load all necessary data: the calibrated scores CSV and personalizer pickle
# written by the previous cell, plus the augmented phenotype graph.
# Note: bayes_state is the plain dict written by BayesianPersonalization.save_state
# (keys such as 'alpha', 'beta', 'prevalence'), not a BayesianPersonalization
# instance. pickle.load executes code from the file, so only load trusted files.
risk_scores = pd.read_csv('calibrated_risk_scores.csv')
with open('augmented_phenotype_graph.pkl', 'rb') as f:
    G_aug = pickle.load(f)
with open('bayesian_personalization.pkl', 'rb') as f:
    bayes_state = pickle.load(f)

print(f"Loaded {len(risk_scores)} risk assessments")
print(f"Loaded graph with {len(G_aug.nodes)} nodes and {len(G_aug.edges)} edges")

class HealthMonitor:
    """Visualization and monitoring dashboard for GET-Phen.

    Pairs the per-window risk scores with the physiological attributes stored
    on the phenotype graph's nodes and renders static (matplotlib/seaborn),
    interactive (plotly) and standalone-HTML views of them.
    """

    # Risk level drawn as the "Alert Threshold" reference line in every view.
    # NOTE(review): the 'alert' column itself is computed upstream; confirm it
    # uses the same threshold.
    ALERT_THRESHOLD = 0.5

    # Display names for the integer sleep-stage codes (list index == code).
    SLEEP_STAGE_LABELS = ['Wake', 'Light', 'Deep', 'REM']

    def __init__(self, risk_scores, graph, bayes_state=None):
        """
        Visualization and monitoring dashboard for GET-Phen

        Args:
            risk_scores: DataFrame with risk scores. The plotting methods read
                the 'timestamp', 'risk_score', 'calibrated_risk' and 'alert'
                columns. The caller's DataFrame is not modified.
            graph: NetworkX graph whose nodes carry 'timestamp', 'hr', 'spo2',
                'activity_level', 'stress', 'hrv', 'sleep_stage' and,
                optionally, 'is_virtual'.
            bayes_state: Bayesian personalization state (stored on the
                instance; not read by the methods of this class)
        """
        # Work on a copy and parse timestamps: when risk_scores comes from
        # read_csv they are strings, which cannot share an x-axis with (or be
        # merged against) the datetime timestamps extracted from the graph.
        self.risk_scores = risk_scores.copy()
        if 'timestamp' in self.risk_scores.columns:
            self.risk_scores['timestamp'] = pd.to_datetime(self.risk_scores['timestamp'])
        self.G = graph
        self.bayes_state = bayes_state

        # Extract physiological data from graph
        self.extract_physiological_data()

    def _alert_mask(self):
        """Return a boolean Series selecting rows whose 'alert' flag is True.

        ``eq(True)`` maps missing values to False instead of failing the
        boolean indexing that a NaN-containing column would cause.
        """
        return self.risk_scores['alert'].eq(True)

    def extract_physiological_data(self):
        """Build ``self.physio_data`` from the graph's node attributes.

        Nodes are ordered by their 'timestamp' attribute and each becomes one
        row. Nodes without an 'is_virtual' attribute are treated as real
        (``is_virtual=False``).
        """
        # Output column -> node attribute it is read from
        attribute_map = {
            'timestamp': 'timestamp',
            'hr': 'hr',
            'spo2': 'spo2',
            'activity': 'activity_level',
            'stress': 'stress',
            'hrv': 'hrv',
            'sleep': 'sleep_stage',
        }

        sorted_nodes = sorted(self.G.nodes(), key=lambda n: self.G.nodes[n]['timestamp'])

        rows = []
        for node in sorted_nodes:
            data = self.G.nodes[node]
            row = {column: data[attr] for column, attr in attribute_map.items()}
            row['is_virtual'] = data.get('is_virtual', False)
            rows.append(row)

        # Explicit columns keep the schema even when the graph has no nodes
        self.physio_data = pd.DataFrame(rows, columns=[*attribute_map, 'is_virtual'])

        print(f"Extracted physiological data with {len(self.physio_data)} time points")

    def _merged_resampled(self, freq='5min'):
        """Align risk scores and physiological data on a common time grid.

        Both frames are indexed by timestamp, averaged into ``freq`` bins
        (numeric columns only, so string/object columns cannot break the
        aggregation) and linearly interpolated. They are then outer-joined on
        the bin timestamps and interpolated again to fill gaps between them.

        Args:
            freq: pandas offset alias for the bin width (default 5 minutes).

        Returns:
            DataFrame indexed by a DatetimeIndex named 'timestamp'.
        """
        def _resample(df):
            df = df.copy()
            df['timestamp'] = pd.to_datetime(df['timestamp'])
            return (
                df.set_index('timestamp')
                .resample(freq)
                .mean(numeric_only=True)
                .interpolate()
            )

        merged = pd.merge(
            _resample(self.risk_scores), _resample(self.physio_data),
            left_index=True, right_index=True,
            how='outer',
        ).interpolate()
        merged.index.name = 'timestamp'
        return merged

    def plot_risk_timeline(self, figsize=(15, 8)):
        """Plot risk scores and physiological data over time.

        Four stacked panels share the time axis: risk (original, calibrated
        and alert points), HR/SpO2, activity/stress and HRV/sleep stage.
        The figure is saved to 'risk_timeline.png' and shown.
        """
        fig, axes = plt.subplots(4, 1, figsize=figsize, sharex=True,
                                 gridspec_kw={'height_ratios': [2, 1, 1, 1]})
        risk = self.risk_scores
        physio = self.physio_data

        # Risk score timeline
        ax1 = axes[0]
        ax1.plot(risk['timestamp'], risk['risk_score'], 'b-', label='Original Risk')
        ax1.plot(risk['timestamp'], risk['calibrated_risk'], 'g-', label='Calibrated Risk')
        ax1.axhline(y=self.ALERT_THRESHOLD, color='r', linestyle='--', label='Alert Threshold')

        alerts = risk[self._alert_mask()]
        if len(alerts) > 0:
            ax1.scatter(alerts['timestamp'], alerts['calibrated_risk'],
                        color='red', s=50, label='Alert')

        ax1.set_title('Disease Risk Timeline')
        ax1.set_ylabel('Risk Score')
        ax1.set_ylim(0, 1)
        ax1.legend()
        ax1.grid(alpha=0.3)

        # HR (left axis) and SpO2 (right axis)
        ax2 = axes[1]
        ax2.plot(physio['timestamp'], physio['hr'], 'r-', label='Heart Rate')
        ax2.set_ylabel('Heart Rate (bpm)')
        ax2.legend(loc='upper left')

        ax2b = ax2.twinx()
        ax2b.plot(physio['timestamp'], physio['spo2'], 'b-', alpha=0.7, label='SpO2')
        ax2b.set_ylabel('SpO2 (%)')
        ax2b.set_ylim(90, 100)
        ax2b.legend(loc='upper right')
        ax2.grid(alpha=0.3)

        # Activity (left axis) and Stress (right axis)
        ax3 = axes[2]
        ax3.plot(physio['timestamp'], physio['activity'], 'g-', label='Activity')
        ax3.set_ylabel('Activity (0-1)')
        ax3.set_ylim(0, 1)
        ax3.legend(loc='upper left')

        ax3b = ax3.twinx()
        ax3b.plot(physio['timestamp'], physio['stress'], 'orange', alpha=0.7, label='Stress')
        ax3b.set_ylabel('Stress (0-100)')
        ax3b.set_ylim(0, 100)
        ax3b.legend(loc='upper right')
        ax3.grid(alpha=0.3)

        # HRV (left axis) and sleep stage (right axis, step function)
        ax4 = axes[3]
        ax4.plot(physio['timestamp'], physio['hrv'], 'purple', label='HRV')
        ax4.set_ylabel('HRV (ms)')
        ax4.legend(loc='upper left')
        ax4.set_xlabel('Time')

        ax4b = ax4.twinx()
        ax4b.step(physio['timestamp'], physio['sleep'],
                  'c-', where='post', alpha=0.7, label='Sleep Stage')
        ax4b.set_ylabel('Sleep Stage')
        ax4b.set_yticks([0, 1, 2, 3])
        ax4b.set_yticklabels(self.SLEEP_STAGE_LABELS)
        ax4b.legend(loc='upper right')
        ax4.grid(alpha=0.3)

        fig.autofmt_xdate()

        plt.tight_layout()
        plt.savefig('risk_timeline.png', dpi=300)
        plt.show()

    def plot_risk_distribution(self, figsize=(12, 6)):
        """Plot histograms (with KDE) of original and calibrated risk scores.

        The figure is saved to 'risk_distribution.png' and shown.
        """
        fig, axes = plt.subplots(1, 2, figsize=figsize)

        panels = [
            ('risk_score', 'Original Risk Score Distribution'),
            ('calibrated_risk', 'Calibrated Risk Score Distribution'),
        ]
        for ax, (column, title) in zip(axes, panels):
            sns.histplot(self.risk_scores[column], bins=20, kde=True, ax=ax)
            ax.set_title(title)
            ax.set_xlabel('Risk Score')
            ax.set_ylabel('Frequency')
            ax.axvline(x=self.ALERT_THRESHOLD, color='r', linestyle='--', label='Alert Threshold')
            ax.legend()

        plt.tight_layout()
        plt.savefig('risk_distribution.png', dpi=300)
        plt.show()

    def plot_snapshots(self, timestamps, figsize=(15, 12), window_size=60):
        """
        Plot graph snapshots at specified timestamps

        Each snapshot is the subgraph of up to ``window_size`` consecutive
        nodes (in timestamp order) centred, where possible, on the node
        closest to the requested time. Nodes are placed at (hours relative to
        the requested time, activity level). Node colour encodes sleep stage
        for real nodes and the virtual type for virtual nodes; edge colour
        encodes edge type and real/virtual status.

        Args:
            timestamps: List of datetime-like timestamps to plot
            figsize: Figure size
            window_size: Number of nodes to include in each snapshot
        """
        n_plots = len(timestamps)
        if n_plots == 0:
            return  # nothing to draw; plt.subplots(1, 0) would raise

        # squeeze=False always returns a 2-D array, so axes[0] is iterable
        fig, axes = plt.subplots(1, n_plots, figsize=figsize, squeeze=False)
        axes = axes[0]

        # Colours for sleep stages 0-3 (Wake, Light, Deep, REM)
        sleep_colors = ['lightyellow', 'lightblue', 'blue', 'purple']

        sorted_nodes = sorted(self.G.nodes(), key=lambda n: self.G.nodes[n]['timestamp'])
        node_times = [self.G.nodes[n]['timestamp'] for n in sorted_nodes]

        for i, (ax, timestamp) in enumerate(zip(axes, timestamps)):
            # Index of the node closest in time (first one on ties)
            closest_idx = min(
                range(len(node_times)),
                key=lambda k: abs((node_times[k] - timestamp).total_seconds())
            )

            # Centre the window on the closest node, shifting it back near the
            # end of the series so it still holds window_size nodes if possible.
            start_idx = max(0, min(closest_idx - window_size // 2,
                                   len(sorted_nodes) - window_size))
            window_nodes = sorted_nodes[start_idx:start_idx + window_size]

            subgraph = self.G.subgraph(window_nodes)

            # x: hours relative to the requested time, y: activity level
            pos = {
                node: (
                    (subgraph.nodes[node]['timestamp'] - timestamp).total_seconds() / 3600,
                    subgraph.nodes[node]['activity_level'],
                )
                for node in subgraph.nodes()
            }

            # Edge colour/width by edge type; non-'sequential' edges are drawn
            # as similarity edges.
            edge_colors = []
            widths = []
            for _, _, data in subgraph.edges(data=True):
                virtual = data.get('is_virtual', False)
                if data['edge_type'] == 'sequential':
                    edge_colors.append('purple' if virtual else 'blue')
                    widths.append(1)
                else:
                    edge_colors.append('orange' if virtual else 'red')
                    widths.append(0.5)

            # Node colour/size by real/virtual status and sleep stage
            node_colors = []
            node_sizes = []
            for node in subgraph.nodes():
                data = subgraph.nodes[node]
                if data.get('is_virtual', False):
                    is_exercise = data.get('virtual_type') == 'exercise'
                    node_colors.append('lime' if is_exercise else 'magenta')
                    node_sizes.append(50)
                else:
                    node_colors.append(sleep_colors[int(data['sleep_stage'])])
                    node_sizes.append(30)

            nx.draw_networkx(
                subgraph, pos=pos,
                with_labels=False,
                node_size=node_sizes,
                node_color=node_colors,
                edge_color=edge_colors,
                width=widths,
                alpha=0.7,
                ax=ax
            )

            ax.set_title(f'Graph Snapshot at {timestamp.strftime("%Y-%m-%d %H:%M")}')
            ax.set_xlabel('Time Difference (hours)')
            ax.set_ylabel('Activity Level')
            ax.set_xlim(-12, 12)  # ±12 hours

            # Legend on the first panel only
            if i == 0:
                handles = [
                    plt.Line2D([0], [0], color='blue', label='Sequential Edge'),
                    plt.Line2D([0], [0], color='red', label='Similarity Edge'),
                    plt.Line2D([0], [0], color='purple', label='Virtual Sequential Edge'),
                    plt.Line2D([0], [0], color='orange', label='Virtual Similarity Edge'),
                    plt.Line2D([0], [0], marker='o', color='lightyellow', label='Awake', linestyle=''),
                    plt.Line2D([0], [0], marker='o', color='lightblue', label='Light Sleep', linestyle=''),
                    plt.Line2D([0], [0], marker='o', color='blue', label='Deep Sleep', linestyle=''),
                    plt.Line2D([0], [0], marker='o', color='purple', label='REM Sleep', linestyle=''),
                    plt.Line2D([0], [0], marker='o', color='lime', label='Exercise', linestyle=''),
                    plt.Line2D([0], [0], marker='o', color='magenta', label='Hypoxia', linestyle='')
                ]
                ax.legend(handles=handles, loc='upper right')

        plt.tight_layout()
        plt.savefig('graph_snapshots.png', dpi=300)
        plt.show()

    def create_interactive_dashboard(self):
        """Create an interactive dashboard using Plotly.

        Risk and physiological series are aligned on a 5-minute grid (see
        ``_merged_resampled``). Paired metrics get their own left/right
        y-axes, mirroring ``plot_risk_timeline``. The figure is written to
        'interactive_dashboard.html'.

        Returns:
            The plotly Figure.
        """
        merged_data = self._merged_resampled().reset_index()
        x = merged_data['timestamp']

        # Rows 2-4 plot two metrics with different scales, so each gets a
        # secondary y-axis (setting yaxis='y2' on a trace is overridden by
        # row=/col= and would not create one).
        fig = make_subplots(
            rows=4, cols=1,
            shared_xaxes=True,
            vertical_spacing=0.05,
            subplot_titles=('Disease Risk', 'Heart Rate & SpO2', 'Activity & Stress', 'HRV & Sleep'),
            row_heights=[0.35, 0.25, 0.2, 0.2],
            specs=[[{}], [{'secondary_y': True}], [{'secondary_y': True}], [{'secondary_y': True}]],
        )

        def add_line(column, name, color, row, secondary_y=False, **line_kwargs):
            """Add one line trace of merged_data[column] to the given row."""
            fig.add_trace(
                go.Scatter(x=x, y=merged_data[column], mode='lines', name=name,
                           line=dict(color=color, **line_kwargs)),
                row=row, col=1, secondary_y=secondary_y,
            )

        # Risk scores and threshold
        add_line('risk_score', 'Original Risk', 'blue', row=1)
        add_line('calibrated_risk', 'Calibrated Risk', 'green', row=1)
        fig.add_trace(go.Scatter(
            x=x,
            y=[self.ALERT_THRESHOLD] * len(merged_data),
            mode='lines',
            name='Alert Threshold',
            line=dict(color='red', dash='dash')
        ), row=1, col=1)

        # Physiological pairs (left axis, right axis)
        add_line('hr', 'Heart Rate', 'red', row=2)
        add_line('spo2', 'SpO2', 'cyan', row=2, secondary_y=True)
        add_line('activity', 'Activity', 'green', row=3)
        add_line('stress', 'Stress', 'orange', row=3, secondary_y=True)
        add_line('hrv', 'HRV', 'purple', row=4)
        add_line('sleep', 'Sleep Stage', 'darkblue', row=4, secondary_y=True, shape='hv')

        # Axis titles and ranges (same scales as plot_risk_timeline)
        fig.update_yaxes(title_text='Risk Score', range=[0, 1], row=1, col=1)
        fig.update_yaxes(title_text='Heart Rate (bpm)', row=2, col=1, secondary_y=False)
        fig.update_yaxes(title_text='SpO2 (%)', range=[90, 100], row=2, col=1, secondary_y=True)
        fig.update_yaxes(title_text='Activity (0-1)', range=[0, 1], row=3, col=1, secondary_y=False)
        fig.update_yaxes(title_text='Stress (0-100)', range=[0, 100], row=3, col=1, secondary_y=True)
        fig.update_yaxes(title_text='HRV (ms)', row=4, col=1, secondary_y=False)
        fig.update_yaxes(title_text='Sleep Stage', range=[-0.5, 3.5],
                         tickvals=[0, 1, 2, 3], ticktext=self.SLEEP_STAGE_LABELS,
                         row=4, col=1, secondary_y=True)
        # Shared x-axis: label only the bottom row
        fig.update_xaxes(title_text='Time', row=4, col=1)

        fig.update_layout(
            height=900,
            width=1000,
            title='GET-Phen Health Monitoring Dashboard',
            legend=dict(x=1.1, y=1, orientation='v'),
            hovermode='x unified'
        )

        fig.write_html('interactive_dashboard.html')

        return fig

    def plot_disease_motifs(self, num_motifs=3, figsize=(15, 15)):
        """
        Plot graph snapshots around the highest calibrated-risk windows

        Args:
            num_motifs: Number of highest-risk windows to plot
            figsize: Figure size
        """
        high_risk_windows = self.risk_scores.nlargest(num_motifs, 'calibrated_risk')
        # Build the list directly instead of assigning into the nlargest
        # result (avoids chained-assignment warnings).
        timestamps = pd.to_datetime(high_risk_windows['timestamp']).tolist()

        self.plot_snapshots(timestamps, figsize=figsize)

    def create_correlation_matrix(self, figsize=(10, 8)):
        """Plot and return the correlation matrix of metrics and risk scores.

        Correlations are computed on the 5-minute aligned data from
        ``_merged_resampled`` (numeric columns only). The heatmap is saved to
        'correlation_matrix.png' and shown.

        Returns:
            DataFrame of pairwise Pearson correlations.
        """
        merged_data = self._merged_resampled()
        numeric_data = merged_data.select_dtypes(include='number')
        corr_matrix = numeric_data.corr()

        plt.figure(figsize=figsize)
        sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', center=0, fmt='.2f')
        plt.title('Correlation Matrix of Health Metrics and Risk Scores')
        plt.tight_layout()
        plt.savefig('correlation_matrix.png', dpi=300)
        plt.show()

        return corr_matrix

    @staticmethod
    def _current_figure_base64(dpi=100):
        """Render the current matplotlib figure to a base64 PNG and close it."""
        import base64
        from io import BytesIO

        buf = BytesIO()
        plt.savefig(buf, format='png', dpi=dpi, bbox_inches='tight')
        plt.close()
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    @staticmethod
    def _summary_header_html(generated_on):
        """Return the opening HTML (head, CSS and page header) of the summary.

        The CSS is deliberately kept out of any ``str.format`` call: its
        literal ``{``/``}`` braces would be parsed as replacement fields and
        raise ``KeyError``.

        Args:
            generated_on: Text shown after "Generated on" in the page header.
        """
        css = """
                body {
                    font-family: Arial, sans-serif;
                    margin: 20px;
                    line-height: 1.6;
                }
                .header {
                    background-color: #4e73df;
                    color: white;
                    padding: 20px;
                    text-align: center;
                }
                .section {
                    margin: 20px 0;
                    padding: 15px;
                    border: 1px solid #ddd;
                    border-radius: 5px;
                }
                .figure {
                    margin: 20px 0;
                    text-align: center;
                }
                .stats {
                    display: flex;
                    justify-content: space-around;
                    flex-wrap: wrap;
                }
                .stat-box {
                    background-color: #f8f9fc;
                    border-left: 4px solid #4e73df;
                    padding: 15px;
                    margin: 10px;
                    width: 200px;
                }
                table {
                    border-collapse: collapse;
                    width: 100%;
                    margin: 20px 0;
                }
                th, td {
                    border: 1px solid #ddd;
                    padding: 8px;
                    text-align: left;
                }
                th {
                    background-color: #f2f2f2;
                }
            """
        return (
            """
        <!DOCTYPE html>
        <html>
        <head>
            <meta charset="utf-8">
            <title>GET-Phen Health Monitoring Summary</title>
            <style>"""
            + css
            + """</style>
        </head>
        <body>
            <div class="header">
                <h1>GET-Phen Health Monitoring Summary</h1>
                <p>Generated on """
            + str(generated_on)
            + """</p>
            </div>
        """
        )

    def save_monitoring_summary(self, filename='monitoring_summary.html'):
        """Save a comprehensive monitoring summary as HTML

        The page contains risk statistics, an embedded risk-timeline image,
        the ten highest-risk alerts, a correlation heatmap of non-virtual
        measurements against calibrated risk (exact-timestamp matches only),
        physiological means and a short textual summary. Figures are embedded
        as base64 PNGs, so the file is self-contained.

        Args:
            filename: Output path for the HTML file (written as UTF-8).
        """
        risk = self.risk_scores
        alert_mask = self._alert_mask()
        # Only nodes not flagged is_virtual contribute to the physiology stats
        real_physio = self.physio_data[~self.physio_data['is_virtual'].astype(bool)]

        html_output = self._summary_header_html(datetime.now().strftime("%Y-%m-%d %H:%M"))

        # Risk statistics section
        html_output += """
            <div class="section">
                <h2>Risk Assessment Summary</h2>
                <div class="stats">
                    <div class="stat-box">
                        <h3>Mean Risk Score</h3>
                        <p>{mean_risk:.3f}</p>
                    </div>
                    <div class="stat-box">
                        <h3>Max Risk Score</h3>
                        <p>{max_risk:.3f}</p>
                    </div>
                    <div class="stat-box">
                        <h3>Total Alerts</h3>
                        <p>{total_alerts}</p>
                    </div>
                    <div class="stat-box">
                        <h3>Alert Rate</h3>
                        <p>{alert_rate:.1f}%</p>
                    </div>
                </div>
            </div>
        """.format(
            mean_risk=risk['calibrated_risk'].mean(),
            max_risk=risk['calibrated_risk'].max(),
            total_alerts=int(alert_mask.sum()),
            alert_rate=alert_mask.mean() * 100
        )

        # Risk timeline figure
        plt.figure(figsize=(12, 6))
        plt.plot(risk['timestamp'], risk['calibrated_risk'], 'g-', label='Calibrated Risk')
        plt.axhline(y=self.ALERT_THRESHOLD, color='r', linestyle='--', label='Alert Threshold')
        plt.title('Disease Risk Timeline')
        plt.xlabel('Time')
        plt.ylabel('Risk Score')
        plt.legend()
        plt.grid(alpha=0.3)
        img_str = self._current_figure_base64()

        html_output += """
            <div class="section">
                <h2>Risk Timeline</h2>
                <div class="figure">
                    <img src="data:image/png;base64,{img_data}" alt="Risk Timeline">
                </div>
            </div>
        """.format(img_data=img_str)

        # Top-10 alerts table
        alerts = risk[alert_mask].copy()
        alerts['timestamp'] = pd.to_datetime(alerts['timestamp']).dt.strftime("%Y-%m-%d %H:%M")
        alerts = alerts.sort_values('calibrated_risk', ascending=False).head(10)

        if len(alerts) > 0:
            alert_rows = "".join(
                """
                    <tr>
                        <td>{timestamp}</td>
                        <td>{risk:.3f}</td>
                        <td>{cal_risk:.3f}</td>
                    </tr>
                """.format(
                    timestamp=row['timestamp'],
                    risk=row['risk_score'],
                    cal_risk=row['calibrated_risk']
                )
                for _, row in alerts.iterrows()
            )
            alert_table = """
                <table>
                    <tr>
                        <th>Timestamp</th>
                        <th>Risk Score</th>
                        <th>Calibrated Risk</th>
                    </tr>
            """ + alert_rows + "</table>"

            html_output += """
                <div class="section">
                    <h2>Top 10 Highest Risk Alerts</h2>
                    {table}
                </div>
            """.format(table=alert_table)

        # Correlation matrix of non-virtual measurements vs. calibrated risk
        merged_data = pd.merge(
            risk[['timestamp', 'calibrated_risk']],
            real_physio[['timestamp', 'hr', 'spo2', 'activity', 'stress', 'hrv']],
            on='timestamp',
            how='inner'
        )

        if len(merged_data) >= 2:
            plt.figure(figsize=(10, 8))
            corr_matrix = merged_data.select_dtypes(include='number').corr()
            sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', center=0, fmt='.2f')
            plt.title('Correlation of Health Metrics with Risk Scores')
            img_str = self._current_figure_base64()

            html_output += """
            <div class="section">
                <h2>Correlation Analysis</h2>
                <div class="figure">
                    <img src="data:image/png;base64,{img_data}" alt="Correlation Matrix">
                </div>
                <p>This matrix shows how different health metrics correlate with disease risk scores.</p>
            </div>
        """.format(img_data=img_str)
        else:
            # Fewer than two matched time points: correlations are undefined
            html_output += """
            <div class="section">
                <h2>Correlation Analysis</h2>
                <p>Not enough time points shared between risk scores and measurements to compute correlations.</p>
            </div>
        """

        # Physiological statistics (non-virtual nodes only)
        html_output += """
            <div class="section">
                <h2>Physiological Statistics</h2>
                <div class="stats">
                    <div class="stat-box">
                        <h3>Mean Heart Rate</h3>
                        <p>{mean_hr:.1f} bpm</p>
                    </div>
                    <div class="stat-box">
                        <h3>Mean SpO2</h3>
                        <p>{mean_spo2:.1f}%</p>
                    </div>
                    <div class="stat-box">
                        <h3>Mean HRV</h3>
                        <p>{mean_hrv:.1f} ms</p>
                    </div>
                    <div class="stat-box">
                        <h3>Mean Activity</h3>
                        <p>{mean_activity:.2f}</p>
                    </div>
                </div>
            </div>
        """.format(
            mean_hr=real_physio['hr'].mean(),
            mean_spo2=real_physio['spo2'].mean(),
            mean_hrv=real_physio['hrv'].mean(),
            mean_activity=real_physio['activity'].mean()
        )

        # Closing summary and end of document
        timestamps = pd.to_datetime(risk['timestamp'])
        html_output += """
            <div class="section">
                <h2>Analysis Summary</h2>
                <p>This report shows the results of GET-Phen disease detection from smartwatch data using graph neural networks.</p>
                <p>The system has detected {alert_count} potential disease events that exceed the risk threshold.</p>
                <p>The mean calibrated risk score is {mean_risk:.3f}.</p>
                <p>The monitoring period spans from {start_date} to {end_date}.</p>
            </div>

            <div class="section">
                <h2>Next Steps</h2>
                <ul>
                    <li>Continue data collection for longer-term trends</li>
                    <li>Refine the risk threshold based on validation data</li>
                    <li>Add more disease-specific models</li>
                    <li>Integrate with healthcare systems for alerts</li>
                </ul>
            </div>
        </body>
        </html>
        """.format(
            alert_count=int(alert_mask.sum()),
            mean_risk=risk['calibrated_risk'].mean(),
            start_date=timestamps.min().strftime("%Y-%m-%d %H:%M"),
            end_date=timestamps.max().strftime("%Y-%m-%d %H:%M")
        )

        with open(filename, 'w', encoding='utf-8') as f:
            f.write(html_output)

        print(f"Monitoring summary saved to '{filename}'")


# Create health monitor
monitor = HealthMonitor(risk_scores, G_aug, bayes_state)

# Static visualizations (each is saved as a PNG and shown inline)
monitor.plot_risk_timeline()
monitor.plot_risk_distribution()

# Graph snapshots around the highest-risk windows
monitor.plot_disease_motifs(num_motifs=3)

# Correlation matrix of resampled metrics and risk scores
corr_matrix = monitor.create_correlation_matrix()

# Interactive dashboard; the method also writes 'interactive_dashboard.html',
# which can be downloaded and opened in a browser.
dashboard = monitor.create_interactive_dashboard()
# Render inline. An IFrame pointing at the local HTML file is not used: in
# Colab the output frame is served from a different origin than the kernel's
# filesystem, so the relative URL cannot be resolved.
dashboard.show()

# Save monitoring summary
monitor.save_monitoring_summary()
print("All visualizations and monitoring outputs have been generated.")

In [None]:
import torch
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import pickle
from datetime import datetime, timedelta
from tqdm.notebook import tqdm
from IPython.display import display, HTML, clear_output
import os
import time
import warnings
warnings.filterwarnings('ignore')

# Artifacts written by the earlier modules of the pipeline
required_files = [
    'smartwatch_data.pkl',
    'user_meta.json',
    'smartwatch_transformer.pt',
    'smartwatch_embeddings.pkl',
    'augmented_phenotype_graph.pkl',
    'afib_gnn_model.pt',
    'calibrated_risk_scores.csv',
    'bayesian_personalization.pkl'
]

missing_files = list(filter(lambda path: not os.path.exists(path), required_files))

if not missing_files:
    print("All required files are present. Ready to run the GET-Phen framework.")
else:
    print("WARNING: The following required files are missing:")
    print("\n".join(f"  - {path}" for path in missing_files))
    print("\nPlease run the previous modules to generate these files.")

class GETPhen:
    """
    Graph-Enabled Temporal Phenotyping (GET-Phen) Framework
    Integration of all components for health monitoring from smartwatch data

    Loaded artifacts are kept in three dictionaries:
        models:     trained networks keyed by name (currently ``'afib'``)
        data:       datasets keyed by name (``'smartwatch_data'``, ``'user_meta'``,
                    ``'embeddings'``, ``'graph'``, ``'risk_scores'``)
        components: helper objects keyed by name (currently ``'bayes'``)
    """

    # Calibrated risk strictly above this value is flagged as an alert.
    alert_threshold = 0.5

    def __init__(self, load_pretrained=True):
        """
        Initialize the GET-Phen framework

        Args:
            load_pretrained: Whether to load pretrained models and data
        """
        self.models = {}
        self.data = {}
        self.components = {}

        if load_pretrained:
            self.load_pretrained_components()

    def load_pretrained_components(self):
        """Load all pretrained models and data from the working directory.

        Loading is best-effort: any exception is reported and swallowed, so
        the instance may end up only partially populated. Check ``self.data``,
        ``self.models`` and ``self.components`` before relying on an entry.
        """
        print("Loading pretrained GET-Phen components...")

        try:
            import json

            # Load smartwatch data
            self.data['smartwatch_data'] = pd.read_pickle('smartwatch_data.pkl')

            # Load user metadata
            with open('user_meta.json', 'r', encoding='utf-8') as f:
                self.data['user_meta'] = json.load(f)

            # Load embeddings
            self.data['embeddings'] = pd.read_pickle('smartwatch_embeddings.pkl')

            # Load graph. NOTE: unpickling executes code; only load trusted files.
            with open('augmented_phenotype_graph.pkl', 'rb') as f:
                self.data['graph'] = pickle.load(f)

            # Load disease models
            from disease_motif_gnn import DiseaseMotifGNN

            # AFib model. input_dim=206 presumably = 200-dim transformer output
            # + the 6 physiological features concatenated in process_new_data.
            afib_model = DiseaseMotifGNN(input_dim=206, hidden_dim=64, num_layers=3)
            # map_location lets GPU-saved weights load on CPU-only machines.
            afib_model.load_state_dict(torch.load('afib_gnn_model.pt', map_location='cpu'))
            afib_model.eval()
            self.models['afib'] = afib_model

            # Load risk scores
            self.data['risk_scores'] = pd.read_csv('calibrated_risk_scores.csv')

            # Load Bayesian personalization
            from bayesian_updating import BayesianPersonalization
            self.components['bayes'] = BayesianPersonalization.load_state('bayesian_personalization.pkl')

            print("All components loaded successfully!")

            # Print data summary
            self._print_data_summary()

        except Exception as e:
            print(f"Error loading components: {e}")
            print("Please make sure all required modules have been run.")

    def _print_data_summary(self):
        """Print a one-line summary for every loaded dataset, model and component."""
        print("\n=== GET-Phen Framework Data Summary ===")
        print(f"Smartwatch data: {len(self.data['smartwatch_data'])} time points")
        print(f"User metadata: {self.data['user_meta']}")
        print(f"Embeddings: {len(self.data['embeddings'])} vectors")
        print(f"Graph: {len(self.data['graph'].nodes)} nodes, {len(self.data['graph'].edges)} edges")
        print(f"Risk scores: {len(self.data['risk_scores'])} predictions")
        print(f"Models: {list(self.models.keys())}")
        print(f"Components: {list(self.components.keys())}")

    def _plot_risk_panels(self, timestamps, risk_scores, calibrated_scores,
                          hrs, spo2s, alert_flags, title):
        """Draw the two-panel risk/physiology figure used by the demo and return it."""
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), sharex=True,
                                       gridspec_kw={'height_ratios': [2, 1]})

        # Plot risk scores
        ax1.plot(timestamps, risk_scores, 'b-', label='Original Risk')
        ax1.plot(timestamps, calibrated_scores, 'g-', label='Calibrated Risk')
        ax1.axhline(y=self.alert_threshold, color='r', linestyle='--', label='Alert Threshold')

        # Highlight alerts
        alert_points = [(ts, score) for ts, score, flag
                        in zip(timestamps, calibrated_scores, alert_flags) if flag]
        if alert_points:
            alert_times, alert_scores = zip(*alert_points)
            ax1.scatter(alert_times, alert_scores, color='red', s=80, marker='o', label='Alert')

        ax1.set_title(title)
        ax1.set_ylabel('Risk Score')
        ax1.set_ylim(0, 1)
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Plot physiological data
        ax2.plot(timestamps, hrs, 'r-', label='Heart Rate')
        ax2.set_ylabel('Heart Rate (bpm)')
        ax2.set_xlabel('Time')
        ax2.legend(loc='upper left')

        ax2b = ax2.twinx()
        ax2b.plot(timestamps, spo2s, 'b-', alpha=0.7, label='SpO2')
        ax2b.set_ylabel('SpO2 (%)')
        ax2b.set_ylim(90, 100)
        ax2b.legend(loc='upper right')

        fig.tight_layout()
        return fig

    def run_end_to_end_demo(self, demo_length_hours=24, real_time=False):
        """
        Run an end-to-end demonstration of the GET-Phen framework

        Stored data is replayed in 30-row windows. For each window the stored
        risk score nearest the window centre is calibrated with the Bayesian
        personalization component and flagged when above ``alert_threshold``.

        Args:
            demo_length_hours: Number of hours of data to process
                (assumes one row per minute, i.e. ``60 * hours`` rows)
            real_time: Whether to simulate real-time processing with live plots

        Returns:
            DataFrame with one row per window: timestamp, risk_score,
            calibrated_risk, hr, spo2, activity and alert.
        """
        print("\n=== Running GET-Phen End-to-End Demo ===")

        # Prepare demo data - use a slice of the existing data
        data = self.data['smartwatch_data']
        data = data.iloc[:min(len(data), demo_length_hours * 60)]  # 1 minute intervals
        stored_risk = self.data['risk_scores']
        bayes = self.components['bayes']

        # Storage for demo results
        timestamps, risk_scores, calibrated_scores = [], [], []
        hrs, spo2s, activities, alerts = [], [], [], []

        if real_time:
            plt.ion()  # Enable interactive mode

        window_size = 30  # Process in 30-minute windows
        # "+ 1" keeps the final full window (the previous bound dropped it
        # whenever the row count was a multiple of the window size).
        window_starts = range(0, len(data) - window_size + 1, window_size)
        last_start = window_starts[-1] if len(window_starts) > 0 else None

        print(f"Processing {len(data)} data points ({demo_length_hours} hours)...")

        for start in tqdm(window_starts):
            last_row = data.iloc[start + window_size - 1]

            # Simulated GNN prediction: stored risk score nearest the window centre.
            risk_idx = min(start + window_size // 2, len(stored_risk) - 1)
            risk_score = stored_risk.iloc[risk_idx]['risk_score']

            # Calibrate score with Bayesian personalization
            calibrated_score = bayes.calibrate_risk(risk_score)

            timestamps.append(last_row['timestamp'])
            risk_scores.append(risk_score)
            calibrated_scores.append(calibrated_score)
            hrs.append(last_row['hr'])
            spo2s.append(last_row['spo2'])
            activities.append(last_row['activity_level'])
            alerts.append(calibrated_score > self.alert_threshold)

            # Live plotting every other window and on the final window
            if real_time and (start % (window_size * 2) == 0 or start == last_start):
                clear_output(wait=True)
                live_fig = self._plot_risk_panels(
                    timestamps, risk_scores, calibrated_scores, hrs, spo2s, alerts,
                    'GET-Phen Real-Time Disease Risk Monitoring')
                plt.show()
                plt.close(live_fig)  # otherwise one figure accumulates per refresh
                time.sleep(1)  # Pause to simulate real-time processing

        # Create final results dataframe
        results = pd.DataFrame({
            'timestamp': timestamps,
            'risk_score': risk_scores,
            'calibrated_risk': calibrated_scores,
            'hr': hrs,
            'spo2': spo2s,
            'activity': activities,
            'alert': alerts
        })

        # Final visualization
        plt.ioff()  # Disable interactive mode
        final_fig = self._plot_risk_panels(
            timestamps, risk_scores, calibrated_scores, hrs, spo2s, alerts,
            'GET-Phen End-to-End Demo Results')
        final_fig.savefig('getphen_demo_results.png', dpi=300)
        plt.show()
        plt.close(final_fig)

        # Print summary
        print("\n=== GET-Phen Demo Results ===")
        print(f"Processed {len(results)} time points over {demo_length_hours} hours")
        print(f"Detected {sum(results['alert'])} potential disease events")
        print(f"Mean risk score: {results['risk_score'].mean():.4f}")
        print(f"Mean calibrated score: {results['calibrated_risk'].mean():.4f}")

        return results

    def process_new_data(self, new_data, window_size=60):
        """
        Process new smartwatch data through the full GET-Phen pipeline

        Args:
            new_data: DataFrame with new smartwatch data (needs timestamp, hr,
                spo2, activity_level, sleep_stage, stress and hrv columns)
            window_size: Size of processing window (stride is half of it)

        Returns:
            DataFrame with timestamp, risk_score, calibrated_risk and alert per
            scored window, or None if any step raised (the error is printed).
        """
        print(f"Processing {len(new_data)} new data points...")

        # Setup components
        from self_supervised_model import SmartWatchTransformer, SmartWatchDataset

        try:
            # Step 1: Generate embeddings
            model = SmartWatchTransformer(input_dim=11, hidden_dim=128, output_dim=200)
            model.load_state_dict(torch.load('smartwatch_transformer.pt', map_location='cpu'))
            model.eval()

            dataset = SmartWatchDataset(new_data, window_size=window_size, stride=window_size // 2)

            embeddings = []
            indices = []
            with torch.no_grad():
                for i in range(len(dataset)):
                    sample = dataset[i]
                    # unsqueeze(1) inserts the batch axis at dim 1; presumably the
                    # transformer is sequence-first (seq, batch, feat) - TODO confirm.
                    masked_data = sample['masked_data'].unsqueeze(1)
                    window_embeddings, _ = model(masked_data)
                    # Mean over dim 0 (the sequence axis under the assumption above).
                    embeddings.append(window_embeddings.mean(dim=0).numpy())
                    # int() accepts both Python ints and 0-d tensors for iloc below.
                    indices.append(int((sample['start_idx'] + sample['end_idx']) // 2))

            # Step 2: Temporary graph with one node per window, chained i -> i+1
            G_temp = nx.DiGraph()
            for node_id, (row_idx, embedding) in enumerate(zip(indices, embeddings)):
                row = new_data.iloc[row_idx]
                G_temp.add_node(
                    node_id,
                    timestamp=row['timestamp'],
                    embedding=embedding[0],  # First element (batch dimension)
                    hr=row['hr'],
                    spo2=row['spo2'],
                    activity_level=row['activity_level'],
                    sleep_stage=row['sleep_stage'],
                    stress=row['stress'],
                    hrv=row['hrv'],
                )
            for node_id in range(len(embeddings) - 1):
                G_temp.add_edge(node_id, node_id + 1, edge_type='sequential', weight=1.0)

            # Step 3: Run GNN inference
            from torch_geometric.data import Data, Batch

            afib_model = self.models['afib']
            afib_model.eval()
            feature_keys = ('hr', 'spo2', 'activity_level', 'sleep_stage', 'stress', 'hrv')

            risk_scores = []
            timestamps = []
            for node_idx in G_temp.nodes():
                # Use predecessors AND successors: with only i -> i+1 edges,
                # successors alone give at most one neighbour, so the 3-node
                # minimum below could never be met and nothing was scored.
                neighbors = [n for n in dict.fromkeys(
                    list(G_temp.predecessors(node_idx)) + list(G_temp.successors(node_idx)))
                    if n != node_idx]
                subgraph_nodes = [node_idx] + neighbors
                if len(subgraph_nodes) < 3:
                    continue

                subgraph = G_temp.subgraph(subgraph_nodes)
                node_mapping = {node: pos for pos, node in enumerate(subgraph_nodes)}

                # Node features: embedding followed by the physiological readings
                node_feats = []
                for node in subgraph_nodes:
                    attrs = G_temp.nodes[node]
                    node_feats.append(np.concatenate(
                        [attrs['embedding'], [attrs[key] for key in feature_keys]]))
                x = torch.tensor(np.array(node_feats), dtype=torch.float)

                edge_index = []
                edge_attr = []
                for u, v, edge in subgraph.edges(data=True):
                    edge_index.append([node_mapping[u], node_mapping[v]])
                    edge_attr.append([
                        1.0 if edge.get('edge_type') == 'sequential' else 0.0,
                        1.0 if edge.get('is_virtual', False) else 0.0,
                        edge.get('weight', 1.0),
                    ])

                graph_data = Data(
                    x=x,
                    edge_index=torch.tensor(edge_index, dtype=torch.long).t().contiguous(),
                    edge_attr=torch.tensor(edge_attr, dtype=torch.float),
                    num_nodes=len(subgraph_nodes),
                )

                with torch.no_grad():
                    risk = afib_model(Batch.from_data_list([graph_data])).item()

                risk_scores.append(risk)
                timestamps.append(G_temp.nodes[node_idx]['timestamp'])

            # Step 4: Apply Bayesian personalization
            bayes = self.components['bayes']
            calibrated_scores = [bayes.calibrate_risk(score) for score in risk_scores]

            # Step 5: Create results dataframe
            results = pd.DataFrame({
                'timestamp': timestamps,
                'risk_score': risk_scores,
                'calibrated_risk': calibrated_scores,
                'alert': [score > self.alert_threshold for score in calibrated_scores]
            })

            print(f"Processed {len(results)} windows, detected {sum(results['alert'])} potential issues.")
            return results

        except Exception as e:
            print(f"Error processing new data: {e}")
            return None

    def simulate_real_time_monitoring(self, duration_minutes=60, interval_seconds=5):
        """
        Simulate real-time monitoring with the GET-Phen framework

        Replays ``duration_minutes`` rows of stored data, starting shortly
        before the first stored risk score above 0.3 (or at the beginning if
        there is none), and live-plots the pre-computed calibrated risk next
        to heart rate and SpO2.

        Args:
            duration_minutes: Duration of simulation in minutes (one row per minute assumed)
            interval_seconds: Wall-clock pause between replayed rows in seconds
        """
        print(f"\n=== Simulating Real-Time Monitoring for {duration_minutes} minutes ===")

        data = self.data['smartwatch_data']
        risk_scores = self.data['risk_scores']

        # Determine starting point - use a period with some risk
        elevated = risk_scores[risk_scores['risk_score'] > 0.3]
        if len(elevated) > 0:
            interesting_timestamp = pd.to_datetime(elevated.iloc[0]['timestamp'])
            start_idx = int(data['timestamp'].searchsorted(interesting_timestamp))
            start_idx = max(0, start_idx - 30)  # Back up a bit
        else:
            print("No stored risk score above 0.3; starting from the first data point.")
            start_idx = 0

        data_slice = data.iloc[start_idx:start_idx + duration_minutes]
        if len(data_slice) < duration_minutes:
            print(f"Warning: Not enough data for {duration_minutes} minutes. Using {len(data_slice)} points.")

        # Refresh roughly every 5 s of wall-clock time; a non-positive interval
        # would divide by zero (or break time.sleep), so refresh every row.
        refresh_every = max(1, int(5 / interval_seconds)) if interval_seconds > 0 else 1
        pause_seconds = max(0, interval_seconds)

        plt.ion()
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), gridspec_kw={'height_ratios': [2, 1]})
        ax2b = None  # twin SpO2 axis; recreated on every refresh

        live_timestamps, live_hrs, live_spo2s, live_risk_scores = [], [], [], []
        max_points = 100  # Keep data series to a reasonable length
        window_size = 10
        latest_risk = None
        processed = 0

        try:
            for i, (_, row) in enumerate(data_slice.iterrows()):
                processed += 1

                # Once a full window has arrived, look up the pre-computed risk.
                # (A real implementation would call process_new_data here.)
                if i + 1 >= window_size:
                    mini_idx = start_idx + i - window_size // 2
                    if 0 <= mini_idx < len(risk_scores):
                        latest_risk = risk_scores.iloc[mini_idx]['calibrated_risk']
                    else:
                        latest_risk = 0.1 + 0.1 * np.random.random()

                if i % refresh_every == 0 or i == len(data_slice) - 1:
                    live_timestamps.append(row['timestamp'])
                    live_hrs.append(row['hr'])
                    live_spo2s.append(row['spo2'])
                    live_risk_scores.append(0 if latest_risk is None else latest_risk)

                    if len(live_timestamps) > max_points:
                        del live_timestamps[:-max_points]
                        del live_hrs[:-max_points]
                        del live_spo2s[:-max_points]
                        del live_risk_scores[:-max_points]

                    ax1.clear()
                    ax2.clear()
                    if ax2b is not None:
                        # ax2.clear() leaves twin axes in place; remove the old
                        # one so a new twin is not stacked on every refresh.
                        ax2b.remove()
                    ax2b = ax2.twinx()

                    # Risk score plot
                    ax1.plot(live_timestamps, live_risk_scores, 'g-')
                    ax1.set_title('GET-Phen Real-Time Health Monitoring')
                    ax1.set_ylabel('Disease Risk Score')
                    ax1.set_ylim(0, 1)
                    ax1.axhline(y=self.alert_threshold, color='r', linestyle='--', label='Alert Threshold')

                    alert_points = [(ts, score) for ts, score in zip(live_timestamps, live_risk_scores)
                                    if score > self.alert_threshold]
                    if alert_points:
                        alert_times, alert_scores = zip(*alert_points)
                        ax1.scatter(alert_times, alert_scores, color='red', s=80, marker='o', label='Alert')

                    ax1.legend()
                    ax1.grid(True, alpha=0.3)

                    # Physio plot
                    ax2.plot(live_timestamps, live_hrs, 'r-', label='Heart Rate')
                    ax2.set_ylabel('Heart Rate (bpm)')
                    ax2.set_xlabel('Time')
                    ax2.legend(loc='upper left')

                    ax2b.plot(live_timestamps, live_spo2s, 'b-', alpha=0.7, label='SpO2')
                    ax2b.set_ylabel('SpO2 (%)')
                    ax2b.set_ylim(90, 100)
                    ax2b.legend(loc='upper right')

                    fig.tight_layout()
                    fig.canvas.draw_idle()
                    plt.pause(0.001)

                # Pause to simulate real-time
                time.sleep(pause_seconds)

        except KeyboardInterrupt:
            print("\nSimulation stopped by user.")

        finally:
            plt.ioff()
            plt.show()

            print("\n=== Real-Time Simulation Summary ===")
            print(f"Processed {processed} data points")
            if live_risk_scores:
                print(f"Mean risk score: {np.mean(live_risk_scores):.4f}")
                print(f"Max risk score: {np.max(live_risk_scores):.4f}")
                print(f"Detected {sum(1 for score in live_risk_scores if score > self.alert_threshold)} potential issues")

    def generate_summary_report(self, filename='getphen_report.txt'):
        """Print a summary report of all loaded data and findings and save it.

        Every section printed to stdout is also written to ``filename``.

        Args:
            filename: Path of the plain-text report (default ``'getphen_report.txt'``).
        """
        print("\n=== GET-Phen Summary Report ===")
        report_lines = []

        def emit(line):
            # Print one report line and keep it for the saved file.
            print(line)
            report_lines.append(line)

        # User information
        emit("\n== User Information ==")
        user = self.data['user_meta']
        conditions = user.get('conditions') or []
        emit(f"Age: {user['age']}")
        emit(f"Gender: {user['gender']}")
        emit(f"BMI: {user['bmi']:.1f}")
        emit(f"Pre-existing conditions: {', '.join(conditions) if conditions else 'None'}")

        # Data summary
        emit("\n== Data Summary ==")
        data = self.data['smartwatch_data']
        first_ts, last_ts = data['timestamp'].min(), data['timestamp'].max()
        emit(f"Monitoring period: {first_ts} to {last_ts}")
        emit(f"Total duration: {(last_ts - first_ts).total_seconds() / 3600:.1f} hours")
        emit(f"Data points: {len(data)}")

        # Risk assessment summary
        emit("\n== Risk Assessment ==")
        risk_scores = self.data['risk_scores']
        alert_mask = risk_scores['alert']
        emit(f"Mean risk score: {risk_scores['calibrated_risk'].mean():.4f}")
        emit(f"Potential disease events detected: {sum(alert_mask)}")

        # Health metrics summary
        emit("\n== Health Metrics Summary ==")
        emit(f"Mean heart rate: {data['hr'].mean():.1f} bpm")
        emit(f"Mean SpO2: {data['spo2'].mean():.1f}%")
        emit(f"Mean HRV: {data['hrv'].mean():.1f} ms")
        emit(f"Mean stress level: {data['stress'].mean():.1f}")

        # Sleep summary (one row per minute assumed, as elsewhere in this class)
        sleep_stats = data['sleep_stage'].value_counts()
        total_sleep = sum(sleep_stats[stage] for stage in (1, 2, 3) if stage in sleep_stats.index)
        total_time = len(data)

        def share(minutes):
            # Percentage of all rows; 0 for empty data instead of ZeroDivisionError.
            return 100 * minutes / total_time if total_time else 0.0

        emit("\n== Sleep Summary ==")
        emit(f"Total sleep time: {total_sleep} minutes ({total_sleep/60:.1f} hours)")
        emit(f"Percentage of time asleep: {share(total_sleep):.1f}%")
        for stage, label in zip([0, 1, 2, 3], ['Awake', 'Light sleep', 'Deep sleep', 'REM sleep']):
            if stage in sleep_stats.index:
                emit(f"{label}: {sleep_stats[stage]} minutes ({share(sleep_stats[stage]):.1f}%)")

        # High risk periods
        emit("\n== High Risk Periods ==")
        # Label runs of consecutive rows on the FULL series; labelling after
        # filtering to alert rows would merge every alert into one group.
        run_id = (alert_mask != alert_mask.shift()).cumsum()
        high_risk = risk_scores[alert_mask].copy()

        if len(high_risk) > 0:
            high_risk['alert_group'] = run_id[alert_mask]
            for n, (_, group) in enumerate(high_risk.groupby('alert_group'), start=1):
                group_times = pd.to_datetime(group['timestamp'])
                start_time = group_times.min()
                end_time = group_times.max()
                duration = (end_time - start_time).total_seconds() / 60
                max_risk = group['calibrated_risk'].max()

                emit(f"Alert {n}: {start_time.strftime('%Y-%m-%d %H:%M')} - {end_time.strftime('%H:%M')}")
                emit(f"  Duration: {duration:.1f} minutes")
                emit(f"  Maximum risk score: {max_risk:.4f}")
        else:
            emit("No high risk periods detected.")

        # Save the full report (every emitted section) to file
        with open(filename, 'w', encoding='utf-8') as f:
            f.write("GET-Phen Framework Summary Report\n")
            f.write("================================\n")
            f.write("\n".join(report_lines))
            f.write("\n")

        print(f"\nFull report saved to '{filename}'")

# Create GET-Phen framework
getphen = GETPhen(load_pretrained=True)

# load_pretrained_components() reports failures instead of raising, so only run
# the demo (and claim success) when the pieces it needs were actually loaded.
if all(key in getphen.data for key in ('smartwatch_data', 'user_meta', 'risk_scores')) \
        and 'bayes' in getphen.components:
    # Run demo
    demo_results = getphen.run_end_to_end_demo(demo_length_hours=24, real_time=False)

    # Generate summary report
    getphen.generate_summary_report()

    # For interactive demonstration, uncomment:
    # getphen.simulate_real_time_monitoring(duration_minutes=10, interval_seconds=0.5)

    print("\n=== GET-Phen Framework Demo Completed ===")
    print("All modules have been successfully integrated and demonstrated.")
else:
    demo_results = None
    print("\nGET-Phen demo skipped: pretrained components could not be loaded.")

In [None]:
import os
import time
import sys
from tqdm.notebook import tqdm
from IPython.display import clear_output, HTML, display
import warnings
warnings.filterwarnings('ignore')

def print_header():
    """Write the GET-Phen ASCII-art banner and tagline to standard output."""
    banner = """
    ██████╗ ███████╗████████╗   ██████╗ ██╗  ██╗███████╗███╗   ██╗
    ██╔════╝ ██╔════╝╚══██╔══╝   ██╔══██╗██║  ██║██╔════╝████╗  ██║
    ██║  ███╗█████╗     ██║█████╗██████╔╝███████║█████╗  ██╔██╗ ██║
    ██║   ██║██╔══╝     ██║╚════╝██╔═══╝ ██╔══██║██╔══╝  ██║╚██╗██║
    ╚██████╔╝███████╗   ██║      ██║     ██║  ██║███████╗██║ ╚████║
    ╚═════╝ ╚══════╝   ╚═╝      ╚═╝     ╚═╝  ╚═╝╚══════╝╚═╝  ╚═══╝

    Graph-Enabled Temporal Phenotyping Framework
    ------------------------------------------
    A framework for disease detection from smartwatch data
    """
    print(banner)

def check_dependencies(packages=None, *, install=True):
    """Check that the GET-Phen dependencies are importable and install any that are missing.

    Args:
        packages: pip distribution names to check. Defaults to the framework's
            full dependency list.
        install: When True (the default), missing packages are installed with
            pip. When False, they are only reported.

    Returns:
        List of the packages that were missing when the check ran.

    Raises:
        subprocess.CalledProcessError: If a pip installation fails.
    """
    print("Checking dependencies...")

    if packages is None:
        packages = [
            'torch', 'torch_geometric', 'transformers', 'networkx',
            'numpy', 'pandas', 'scikit-learn', 'matplotlib', 'seaborn',
            'plotly', 'tqdm'
        ]

    # pip name -> import name where they differ. Without this,
    # __import__('scikit-learn') always failed and sklearn was reported missing.
    import_names = {'scikit-learn': 'sklearn'}

    missing_packages = []
    for package in packages:
        try:
            __import__(import_names.get(package, package))
        except ImportError:
            missing_packages.append(package)
            print(f"✗ {package}")
        else:
            print(f"✓ {package}")

    if not missing_packages:
        print("\nAll dependencies are already installed!")
        return missing_packages

    if not install:
        print(f"\nMissing packages: {', '.join(missing_packages)}")
        return missing_packages

    print("\nInstalling missing packages...")
    import subprocess

    for package in missing_packages:
        print(f"Installing {package}...")
        if package == 'torch_geometric':
            # Install the core package only, matching the notebook's setup cell.
            # torch-scatter/sparse/cluster need a wheel index matching the
            # installed torch build and otherwise try to compile from source.
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'torch-geometric'])
        else:
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])

    print("\nAll dependencies installed!")
    return missing_packages

def run_module(module_number, module_name, file_path):
    """Run one GET-Phen module script and report whether it succeeded.

    The script runs in its own namespace, seeded with a copy of this module's
    globals, so its top-level names are visible to the functions it defines.
    A bare ``exec`` inside a function stores them in a throwaway locals dict,
    which made such lookups fail with NameError.

    Args:
        module_number: Module index, used only in messages.
        module_name: Human-readable module name, used only in messages.
        file_path: Path to the Python script to execute.

    Returns:
        True if the script ran without raising an ``Exception``, otherwise False.
    """
    print(f"\n=== Running Module {module_number}: {module_name} ===")

    # Check if file exists
    if not os.path.exists(file_path):
        print(f"Error: File {file_path} not found!")
        return False

    try:
        start_time = time.perf_counter()

        # Read with a context manager so the file handle is always closed.
        with open(file_path, encoding='utf-8') as f:
            source = f.read()

        # NOTE: executes arbitrary code; only run the project's own module files.
        namespace = dict(globals())
        namespace['__file__'] = file_path
        exec(compile(source, file_path, 'exec'), namespace)

        duration = time.perf_counter() - start_time
        print(f"\n✓ Module {module_number} completed successfully! ({duration:.2f} seconds)")
        return True
    except Exception as e:
        print(f"\n✗ Error running module {module_number}: {e}")
        return False

def run_pipeline(start_module=0, end_module=10):
    """Run every GET-Phen module numbered in the inclusive range [start_module, end_module].

    Modules are executed in order through ``run_module``. A failing module is
    recorded but does not stop the ones after it. A per-module status summary
    is printed at the end.
    """
    catalogue = (
        (0, "Environment Setup", "setup_environment.py"),
        (1, "Data Ingestion", "data_ingestion.py"),
        (2, "Self-Supervised Pretraining Model", "self_supervised_model.py"),
        (3, "Generate Embeddings", "generate_embeddings.py"),
        (4, "Graph Construction", "graph_construction.py"),
        (5, "Digital Twin Augmentation", "digital_twin.py"),
        (6, "Disease-Motif GNN", "disease_motif_gnn.py"),
        (7, "Online Inference & Sliding Window", "online_inference.py"),
        (8, "Bayesian Updating", "bayesian_updating.py"),
        (9, "Visualization and Monitoring", "visualization_monitoring.py"),
        (10, "Final Integration", "final_integration.py"),
    )

    selected = [entry for entry in catalogue if start_module <= entry[0] <= end_module]
    print(f"Running {len(selected)} modules...\n")

    outcomes = []
    for number, title, script in selected:
        outcomes.append((number, title, run_module(number, title, script)))

    print("\n=== Pipeline Execution Summary ===")
    for number, title, ok in outcomes:
        marker = "✓" if ok else "✗"
        print(f"{marker} Module {number}: {title}")

    if all(ok for _, _, ok in outcomes):
        print("\n✓ All modules completed successfully!")
    else:
        print("\n✗ Some modules failed. Check the errors above.")

def display_menu():
    """Run the interactive GET-Phen launcher until the user chooses to exit.

    Each iteration prints the header and the main menu, reads a choice with
    ``input()`` and dispatches it:

    * ``1`` -- check dependencies, then run the full pipeline.
    * ``2`` -- list the modules, ask for a module number and run that module.
    * ``3`` -- import ``GETPhen`` from ``final_integration.py`` and run either
      the standard demo or the real-time simulation.
    * ``4`` -- check dependencies only.
    * ``5`` -- print a goodbye message and return.

    After any other choice the user is asked to press Enter, the output is
    cleared and the menu is shown again. This is done with a loop rather than
    a recursive self-call so a long session cannot hit the recursion limit.
    """
    # (number, display name, script file). The file names must be the real
    # script names (same as in run_pipeline); deriving them from the display
    # names produced non-existent paths such as "environment_setup.py".
    modules = [
        (0, "Environment Setup", "setup_environment.py"),
        (1, "Data Ingestion", "data_ingestion.py"),
        (2, "Self-Supervised Pretraining Model", "self_supervised_model.py"),
        (3, "Generate Embeddings", "generate_embeddings.py"),
        (4, "Graph Construction", "graph_construction.py"),
        (5, "Digital Twin Augmentation", "digital_twin.py"),
        (6, "Disease-Motif GNN", "disease_motif_gnn.py"),
        (7, "Online Inference & Sliding Window", "online_inference.py"),
        (8, "Bayesian Updating", "bayesian_updating.py"),
        (9, "Visualization and Monitoring", "visualization_monitoring.py"),
        (10, "Final Integration", "final_integration.py"),
    ]

    menu = """
    === Main Menu ===
    1. Run Full Pipeline
    2. Run Specific Module
    3. Run Demo (requires all modules to be run first)
    4. Check Dependencies
    5. Exit
    """

    while True:
        print_header()

        print("\nWelcome to the GET-Phen Framework!")
        print("This launcher helps you run the various components of the framework.\n")

        print(menu)

        choice = input("Enter your choice (1-5): ")

        if choice == '1':
            clear_output()
            print_header()
            print("\n=== Running Full Pipeline ===")
            check_dependencies()
            run_pipeline()
        elif choice == '2':
            clear_output()
            print_header()

            print("\n=== Available Modules ===")
            for num, name, _ in modules:
                print(f"{num}. {name}")

            module = input("\nEnter module number to run: ")
            # Only the int() conversion belongs in the try: a ValueError from
            # check_dependencies/run_module must not be reported as bad input.
            try:
                module_num = int(module)
            except ValueError:
                print("Please enter a valid number!")
            else:
                if 0 <= module_num < len(modules):
                    _, module_name, module_path = modules[module_num]
                    clear_output()
                    print_header()
                    check_dependencies()
                    run_module(module_num, module_name, module_path)
                else:
                    print("Invalid module number!")
        elif choice == '3':
            clear_output()
            print_header()
            print("\n=== Running GET-Phen Demo ===")

            # Check if final_integration.py exists
            if not os.path.exists('final_integration.py'):
                print("Error: final_integration.py not found! Please run the full pipeline first.")
            else:
                # Broad catch is deliberate: this is the interactive top-level
                # boundary and any demo failure should return to the menu.
                try:
                    from final_integration import GETPhen

                    print("Creating GET-Phen framework...")
                    getphen = GETPhen(load_pretrained=True)

                    print("\nChoose demo type:")
                    print("1. Standard Demo")
                    print("2. Real-time Simulation")

                    demo_choice = input("Enter choice (1-2): ")

                    if demo_choice == '1':
                        # Run standard demo
                        getphen.run_end_to_end_demo(demo_length_hours=24, real_time=False)
                    elif demo_choice == '2':
                        # Run real-time simulation
                        getphen.simulate_real_time_monitoring(duration_minutes=10, interval_seconds=0.5)
                    else:
                        print("Invalid choice!")
                except Exception as e:
                    print(f"Error running demo: {e}")
                    print("Please make sure all modules have been run successfully first.")
        elif choice == '4':
            clear_output()
            print_header()
            check_dependencies()
        elif choice == '5':
            print("\nExiting GET-Phen Framework. Goodbye!")
            return
        else:
            print("Invalid choice, please try again!")

        # Return to menu after action completes
        input("\nPress Enter to return to the main menu...")
        clear_output()

# Entry point: start the interactive launcher when this code is executed
# directly. Inside a Jupyter/Colab cell ``__name__`` is also "__main__", so
# running the cell opens the menu; importing the file as a module does not.
if __name__ == "__main__":
    display_menu()