In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import pandas as pd
from enum import Enum

# --- small‐config for quick debugging ---
class SmallConfig:
    """Compact hyper-parameter set intended for quick debugging runs.

    All values are plain instance attributes assigned in ``__init__``.
    ``d_inner`` is derived as ``expand * d_model``.
    """

    def __init__(self):
        width, expansion = 64, 2

        # Sequence-model dimensions.
        self.d_model = width
        self.n_layers = 2
        self.d_state = 16
        self.d_conv = 3
        self.expand = expansion
        self.dt_rank = 4
        self.d_inner = expansion * width

        # Layer options.
        self.bias = False
        self.conv_bias = True
        self.dropout = 0.1
        self.max_sequence_length = 256

        # Feature switches.
        self.safety_focus = True
        self.crisis_detection = True
        self.bidirectional = False

        # Category counts for the categorical inputs/outputs.
        self.num_hospitals = 6
        self.num_crisis_levels = 5
        self.num_acuity_levels = 5

# --- full HamburgMambaConfig inherits SmallConfig ---
class HamburgMambaConfig(SmallConfig):
    """Full-size configuration: the SmallConfig defaults with a wider, deeper model.

    After the overrides, ``dt_rank`` is resolved when it holds the string
    ``"auto"`` and ``d_inner`` is recomputed from the new ``d_model``.
    """

    def __init__(self):
        super().__init__()
        self.d_model, self.n_layers = 512, 12
        # Placeholder kept from the original: any remaining full-config
        # overrides go here, before the derived values below are computed.
        wants_auto_rank = self.dt_rank == "auto"
        if wants_auto_rank:
            self.dt_rank = math.ceil(self.d_model / 16)
        # The parent computed d_inner from its own d_model, so refresh it.
        self.d_inner = int(self.expand * self.d_model)

# --- your ClinicalEncoder and HybridSafetyModel classes go here ---
# (Paste the exact definitions you already have; no emojis in code!)
# e.g. class HamburgClinicalEncoder(nn.Module): … 
#      class HybridSafetyModel(nn.Module): …
# and your HamburgGroundTruthGenerator + make_real_hamburg_df()


In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import pandas as pd
from enum import Enum


# ----------------------------
# Simplified clinical encoder + hybrid model
# ----------------------------
class HamburgClinicalEncoder(nn.Module):
    """Fuse per-field patient, crew and system inputs into a ``d_model`` vector.

    Continuous feature groups are stacked on a new last axis and fed through a
    dedicated ``nn.Linear``; categorical fields go through ``nn.Embedding``
    tables.  The eleven resulting feature blocks are concatenated on the last
    axis and projected to ``config.d_model``.

    Args:
        config: object exposing ``num_hospitals``, ``num_acuity_levels``,
            ``num_crisis_levels`` and ``d_model``.
    """

    def __init__(self, config):
        super().__init__()
        # Sub-modules are created in a fixed order so that seeded
        # initialisation and the state-dict layout stay stable.
        self.vital_encoder = nn.Linear(7, 32)
        self.demographic_encoder = nn.Linear(2, 16)
        # Each embedding table has one more row than the configured count.
        self.hospital_embedding = nn.Embedding(config.num_hospitals + 1, 16)
        self.acuity_embedding = nn.Embedding(config.num_acuity_levels + 1, 16)
        self.crisis_embedding = nn.Embedding(config.num_crisis_levels + 1, 16)
        self.crew_state_encoder = nn.Linear(4, 32)
        self.system_stress_encoder = nn.Linear(3, 16)
        self.safety_encoder = nn.Linear(3, 16)
        self.time_encoder = nn.Linear(3, 16)
        self.environment_encoder = nn.Linear(2, 16)
        self.complaint_embedding = nn.Embedding(20, 16)

        # Width of the concatenated feature vector, read off the branches.
        branches = (
            self.vital_encoder, self.demographic_encoder,
            self.hospital_embedding, self.acuity_embedding,
            self.crisis_embedding, self.crew_state_encoder,
            self.system_stress_encoder, self.safety_encoder,
            self.time_encoder, self.environment_encoder,
            self.complaint_embedding,
        )
        fused_width = sum(
            m.out_features if isinstance(m, nn.Linear) else m.embedding_dim
            for m in branches
        )
        self.final_projection = nn.Linear(fused_width, config.d_model)

    @staticmethod
    def _columns(*tensors):
        """Stack same-shaped tensors along a new trailing feature axis."""
        return torch.stack(tensors, dim=-1)

    def forward(self, patient_data):
        """Encode a dict of per-field tensors.

        Args:
            patient_data: mapping from field name to tensor.  Fields that are
                stacked together must share one shape; the four embedded
                fields must hold integer indices.

        Returns:
            Tensor with the inputs' shape plus a trailing ``config.d_model``
            axis.
        """
        p = patient_data
        col = self._columns

        vitals_vec = self.vital_encoder(col(
            p['vital_heart_rate'],
            p['vital_bp_systolic'],
            p['vital_bp_diastolic'],
            p['vital_respiratory_rate'],
            p['vital_oxygen_saturation'],
            p['vital_temperature'],
            p['vital_gcs'],
        ))

        demo_vec = self.demographic_encoder(col(p['age'] / 100.0, p['gender']))

        hospital_vec = self.hospital_embedding(p['hospital_destination'])
        acuity_vec = self.acuity_embedding(p['acuity_level'])
        crisis_vec = self.crisis_embedding(p['system_crisis_level'])

        crew_vec = self.crew_state_encoder(col(
            p['crew_calls_today'] / 25.0,
            p['crew_hours_on_shift'] / 12.0,
            p['crew_fatigue_level'] / 2.0,
            p['burnout_risk_score'],
        ))

        # The scaled response delay feeds both the stress and safety branches.
        stress_vec = self.system_stress_encoder(col(
            p['available_crews'] / 80.0,
            p['calls_waiting'] / 10.0,
            p['response_delay_minutes'] / 30.0,
        ))
        safety_vec = self.safety_encoder(col(
            p['response_delay_minutes'] / 30.0,
            p['handoff_quality_score'],
            p['documentation_completeness'],
        ))

        time_vec = self.time_encoder(col(
            p['hour'] / 24.0,
            p['day_of_week'] / 7.0,
            p['shift_change_stress'],
        ))

        env_vec = self.environment_encoder(col(
            p['weather_impact'],
            p['tourism_factor'],
        ))

        complaint_vec = self.complaint_embedding(p['chief_complaint_encoded'])

        fused = torch.cat(
            [
                vitals_vec, demo_vec, hospital_vec, acuity_vec, crisis_vec,
                crew_vec, stress_vec, safety_vec, time_vec, env_vec,
                complaint_vec,
            ],
            dim=-1,
        )
        return self.final_projection(fused)

class HybridSafetyModel(nn.Module):
    """Clinical encoder + GRU backbone with crisis, staff and safety heads.

    Predictions are made from the backbone output at the final time step.

    Args:
        config: object exposing ``d_model`` and ``num_crisis_levels`` (plus
            whatever ``HamburgClinicalEncoder`` reads from it).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.d_model

        self.encoder = HamburgClinicalEncoder(config)
        self.backbone = nn.GRU(width, width, batch_first=True)

        # Only the crisis head carries dropout.
        self.crisis_head = nn.Sequential(
            nn.Linear(width, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, config.num_crisis_levels),
        )
        self.staff_head = nn.Sequential(
            nn.Linear(width, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
        )
        self.safety_head = nn.Sequential(
            nn.Linear(width, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
        )

    def forward(self, patient_data):
        """Run the model on a dict of per-field tensors.

        Returns:
            dict with ``crisis_logits`` of shape (B, num_crisis_levels) and six
            (B,) tensors: ``burnout``, ``retention``, ``quality`` and
            ``adverse_risk`` pass through a sigmoid, ``fatigue`` is a sigmoid
            scaled by 2, and ``response_delay`` passes through a ReLU.
        """
        encoded = self.encoder(patient_data)        # (B, L, d_model)
        hidden, _ = self.backbone(encoded)          # (B, L, d_model)
        final_step = hidden[:, -1, :]               # (B, d_model)

        crisis_logits = self.crisis_head(final_step)
        burnout_raw, fatigue_raw, retention_raw = self.staff_head(final_step).unbind(dim=-1)
        delay_raw, quality_raw, adverse_raw = self.safety_head(final_step).unbind(dim=-1)

        outputs = {"crisis_logits": crisis_logits}
        outputs["burnout"] = torch.sigmoid(burnout_raw)
        outputs["fatigue"] = torch.sigmoid(fatigue_raw) * 2.0
        outputs["retention"] = torch.sigmoid(retention_raw)
        outputs["response_delay"] = F.relu(delay_raw)
        outputs["quality"] = torch.sigmoid(quality_raw)
        outputs["adverse_risk"] = torch.sigmoid(adverse_raw)
        return outputs

# ----------------------------
# Simplified ground-truth generator
# ----------------------------
class HamburgCrisisLevel(Enum):
    """System crisis levels with integer values 0 (NORMAL) to 4 (STAFF_EXODUS)."""

    NORMAL = 0
    STRESSED = 1
    CRISIS = 2
    BREAKDOWN = 3
    STAFF_EXODUS = 4

class HamburgGroundTruthGenerator:
    """Rule-based generator of synthetic supervision targets."""

    def __init__(self):
        self.total_ambulances = 80

    def generate_complete_ground_truth(self,
                                       crew_calls_today:int,
                                       crew_hours_worked:float,
                                       system_calls_today:int,
                                       available_crews:int,
                                       calls_waiting:int,
                                       hour:int):
        """Derive all target values for one observation.

        Only ``crew_calls_today`` and ``crew_hours_worked`` influence the
        result; ``system_calls_today``, ``available_crews``, ``calls_waiting``
        and ``hour`` are accepted but not read.

        Returns:
            dict with the keys ``crisis_level``, ``burnout_risk``,
            ``staff_retention_probability``, ``response_delay_minutes``,
            ``handoff_quality_score``, ``documentation_completeness``,
            ``clinical_care_quality``, ``adverse_event_risk`` and
            ``crew_fatigue_level``.
        """
        # Call-volume thresholds, checked from the most severe downwards.
        ladder = (
            (20, HamburgCrisisLevel.BREAKDOWN),
            (15, HamburgCrisisLevel.CRISIS),
            (10, HamburgCrisisLevel.STRESSED),
        )
        level = HamburgCrisisLevel.NORMAL
        for threshold, candidate in ladder:
            if crew_calls_today >= threshold:
                level = candidate
                break

        burnout_risk = min(1.0, crew_calls_today/25 + crew_hours_worked/24)
        adverse_risk = min(0.8, 0.5*burnout_risk)

        truth = {}
        truth['crisis_level'] = level.value
        truth['burnout_risk'] = burnout_risk
        truth['staff_retention_probability'] = max(0.1, 1.0 - 0.5*burnout_risk)
        # Base delay of 8 minutes, growing 30% per crisis level.
        truth['response_delay_minutes'] = 8.0*(1 + 0.3*level.value)
        truth['handoff_quality_score'] = max(0.5, 1.0 - 0.3*burnout_risk)
        truth['documentation_completeness'] = max(0.5, 1.0 - 0.2*burnout_risk)
        truth['clinical_care_quality'] = 1.0 - adverse_risk
        truth['adverse_event_risk'] = adverse_risk
        truth['crew_fatigue_level'] = min(2.0, crew_calls_today/12 + crew_hours_worked/12)
        return truth

# ----------------------------
# Build a “real” hamburg_df
# ----------------------------
def make_real_hamburg_df(num_patients=100, seq_len=16, seed=42):
    """Build a synthetic per-time-step DataFrame of patient/crew/system records.

    For every patient a set of baseline values is drawn once; each of the
    ``seq_len`` time steps then jitters those baselines, asks
    ``HamburgGroundTruthGenerator`` for the derived targets and appends one
    record.

    Args:
        num_patients: number of patient sequences to generate.
        seq_len: number of consecutive time steps per patient.
        seed: seed for ``numpy.random.default_rng``; the output is
            deterministic for a given seed.

    Returns:
        ``pandas.DataFrame`` with ``num_patients * seq_len`` rows, ordered by
        patient and then by time step.
    """
    rng = np.random.default_rng(seed)
    gen = HamburgGroundTruthGenerator()
    records = []
    for pid in range(num_patients):
        # Per-patient baselines.  Note that ``Generator.integers`` excludes
        # its upper bound.
        c0 = int(rng.integers(5, 25))
        h0 = float(rng.uniform(6, 12))
        s0 = int(rng.integers(800, 1600))
        a0 = int(rng.integers(20, 80))
        w0 = int(rng.integers(0, 12))
        # Upper bound 24 (not 23) so that every start hour 0..23 can occur.
        hr0 = int(rng.integers(0, 24))
        for t in range(seq_len):
            # Each step is an independent jitter around the baseline, not a
            # cumulative random walk.
            c = max(0, c0 + int(rng.integers(-1, 2)))
            h = min(16, max(6, h0 + rng.normal(0, 0.5)))
            s = max(500, s0 + int(rng.integers(-50, 50)))
            a = max(5, min(80, a0 + int(rng.integers(-3, 3))))
            w = max(0, w0 + int(rng.integers(-1, 2)))
            hr = (hr0 + t) % 24
            gt = gen.generate_complete_ground_truth(
                crew_calls_today=c,
                crew_hours_worked=h,
                system_calls_today=s,
                available_crews=a,
                calls_waiting=w,
                hour=hr
            )
            # Vitals, demographics and the remaining categorical fields are
            # drawn afresh at every time step.
            records.append({
                'patient_id': pid,
                'time_step': t,
                'vital_heart_rate': float(rng.normal(80, 10)),
                'vital_bp_systolic': float(rng.normal(120, 15)),
                'vital_bp_diastolic': float(rng.normal(80, 10)),
                'vital_respiratory_rate': float(rng.normal(16, 2)),
                'vital_oxygen_saturation': float(np.clip(rng.normal(98, 1), 90, 100)),
                'vital_temperature': float(rng.normal(37, 0.5)),
                'vital_gcs': float(np.clip(rng.normal(15, 1), 3, 15)),
                'age': int(rng.integers(20, 90)),
                'gender': int(rng.integers(0, 2)),
                'hospital_destination': int(rng.integers(0, 6)),
                'acuity_level': int(rng.integers(1, 6)),
                'system_crisis_level': gt['crisis_level'],
                'crew_calls_today': c,
                'crew_hours_on_shift': h,
                'crew_fatigue_level': gt['crew_fatigue_level'],
                'burnout_risk_score': gt['burnout_risk'],
                'available_crews': a,
                'calls_waiting': w,
                'response_delay_minutes': gt['response_delay_minutes'],
                'handoff_quality_score': gt['handoff_quality_score'],
                'documentation_completeness': gt['documentation_completeness'],
                'hour': hr,
                # Upper bound 7 (not 6) so that all seven weekdays 0..6 occur.
                'day_of_week': int(rng.integers(0, 7)),
                'shift_change_stress': float(rng.uniform(0, 1)),
                'weather_impact': float(rng.uniform(0.8, 1.2)),
                'tourism_factor': float(rng.uniform(0.9, 1.1)),
                'chief_complaint_encoded': int(rng.integers(0, 20)),
                'crisis_level': gt['crisis_level'],
                'burnout_risk': gt['burnout_risk'],
                'staff_retention_probability': gt['staff_retention_probability']
            })
    return pd.DataFrame.from_records(records)

# Build the default synthetic frame, then report its size and first rows.
hamburg_df = make_real_hamburg_df(100, 16)
print("hamburg_df shape:", hamburg_df.shape)
print(hamburg_df.head(n=5))


hamburg_df shape: (1600, 32)
   patient_id  time_step  vital_heart_rate  vital_bp_systolic  \
0           0          0         76.837574         119.747983   
1           0          1         78.454705         113.575083   
2           0          2         86.505928         131.148813   
3           0          3         73.611222         115.872866   
4           0          4         68.667128         106.208216   

   vital_bp_diastolic  vital_respiratory_rate  vital_oxygen_saturation  \
0           71.469561               17.758796                98.777792   
1           76.478664               17.064618                98.365444   
2           85.431543               14.668981                98.232161   
3           94.949413               14.268338                98.968278   
4           84.971607               16.284851                98.690485   

   vital_temperature  vital_gcs  age  ...  documentation_completeness  hour  \
0          37.033015  15.000000   73  ...               

In [30]:
# Forward-pass sanity check on a tiny random batch.
# NOTE(review): `config` and `model` are not defined in this cell; they are
# presumably created in an earlier cell -- confirm before running in isolation.
B = 2
L = 16


def rand(B, L):
    """Return a (B, L) float tensor of uniform samples from [0, 1)."""
    return torch.rand(B, L)


def randint(B, L, H):
    """Return a (B, L) integer tensor with values drawn from [0, H)."""
    return torch.randint(0, H, (B, L))


# (field name, number of categories) in the order the tensors are drawn;
# ``None`` marks a continuous field filled with uniform noise.
_field_specs = [
    ('vital_heart_rate', None),
    ('vital_bp_systolic', None),
    ('vital_bp_diastolic', None),
    ('vital_respiratory_rate', None),
    ('vital_oxygen_saturation', None),
    ('vital_temperature', None),
    ('vital_gcs', None),
    ('age', None),
    ('gender', None),
    ('hospital_destination', config.num_hospitals+1),
    ('acuity_level', config.num_acuity_levels+1),
    ('system_crisis_level', config.num_crisis_levels+1),
    ('crew_calls_today', None),
    ('crew_hours_on_shift', None),
    ('crew_fatigue_level', None),
    ('burnout_risk_score', None),
    ('available_crews', None),
    ('calls_waiting', None),
    ('response_delay_minutes', None),
    ('handoff_quality_score', None),
    ('documentation_completeness', None),
    ('hour', None),
    ('day_of_week', None),
    ('shift_change_stress', None),
    ('weather_impact', None),
    ('tourism_factor', None),
    ('chief_complaint_encoded', 20),
]

batch = {
    field: rand(B, L) if n_categories is None else randint(B, L, n_categories).long()
    for field, n_categories in _field_specs
}

# Run the model without building an autograd graph.
with torch.no_grad():
    out = model(batch)

# Report the shape of every output tensor.
print("✅ Forward pass shapes:")
for k, v in out.items():
    print(f"  {k:16s} → {tuple(v.shape)}")


✅ Forward pass shapes:
  crisis_logits    → (2, 5)
  burnout          → (2,)
  fatigue          → (2,)
  retention        → (2,)
  response_delay   → (2,)
  quality          → (2,)
  adverse_risk     → (2,)


In [26]:
# Cell 3: Prepare PyTorch Dataset + DataLoader for your real Hamburg data

import torch
from torch.utils.data import Dataset, DataLoader, random_split

class HamburgSafetyDataset(Dataset):
    """Serve fixed-length windows of a per-time-step DataFrame.

    The frame is cut into consecutive, non-overlapping chunks of ``seq_len``
    rows; trailing rows that do not fill a whole chunk are ignored.  This
    assumes the rows of each sequence are contiguous and exactly ``seq_len``
    long -- verify against how the frame was built.

    Args:
        df: frame holding every column listed in ``self.features`` plus the
            target columns read in ``__getitem__``.
        seq_len: number of rows per sequence; must be at least 1.
        add_batch_dim: when True (the default, matching the original
            behaviour) every input tensor is returned as ``(1, seq_len)``.
            With PyTorch's default collate function a DataLoader stacks
            samples on a new leading axis, which turns such items into
            ``(batch, 1, seq_len)``; pass False to get ``(seq_len,)`` items
            and therefore ``(batch, seq_len)`` batches.

    Raises:
        ValueError: if ``seq_len`` is smaller than 1.
    """

    # Fields consumed by nn.Embedding layers must be integer (long) tensors.
    _LONG_FEATURES = frozenset((
        'hospital_destination', 'acuity_level',
        'system_crisis_level', 'chief_complaint_encoded',
    ))

    def __init__(self, df, seq_len=16, add_batch_dim=True):
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.df = df.reset_index(drop=True)
        self.seq_len = seq_len
        self.add_batch_dim = add_batch_dim
        # these are all the keys your model.encoder.forward expects:
        self.features = [
            'vital_heart_rate','vital_bp_systolic','vital_bp_diastolic',
            'vital_respiratory_rate','vital_oxygen_saturation','vital_temperature',
            'vital_gcs','age','gender','hospital_destination','acuity_level',
            'system_crisis_level','crew_calls_today','crew_hours_on_shift',
            'crew_fatigue_level','burnout_risk_score','available_crews',
            'calls_waiting','response_delay_minutes','handoff_quality_score',
            'documentation_completeness','hour','day_of_week',
            'shift_change_stress','weather_impact','tourism_factor',
            'chief_complaint_encoded'
        ]
        # Source columns of the supervision targets taken at the last time
        # step.  Note that __getitem__ exposes staff_retention_probability
        # under the shorter key 'retention_probability'.
        self.targets = ['crisis_level','burnout_risk','staff_retention_probability','response_delay_minutes']

    def __len__(self):
        """Number of complete ``seq_len``-row sequences in the frame."""
        return len(self.df) // self.seq_len

    def __getitem__(self, idx):
        """Return ``(sample, target)`` for sequence ``idx``.

        ``sample`` maps each feature name to a tensor (long for embedded
        fields, float32 otherwise); ``target`` holds scalar tensors taken from
        the last row of the sequence.  Negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        count = len(self)
        if idx < 0:
            idx += count
        if not 0 <= idx < count:
            raise IndexError(
                f"sequence index out of range for dataset of length {count}"
            )

        start = idx * self.seq_len
        chunk = self.df.iloc[start:start + self.seq_len]

        sample = {}
        for feat in self.features:
            dtype = torch.long if feat in self._LONG_FEATURES else torch.float32
            tensor = torch.tensor(chunk[feat].values, dtype=dtype)
            # (seq_len,) -> (1, seq_len) unless the caller opted out.
            sample[feat] = tensor.unsqueeze(0) if self.add_batch_dim else tensor

        # extract last-step targets
        last = chunk.iloc[-1]
        target = {
            'crisis_level': torch.tensor(int(last.crisis_level), dtype=torch.long),
            'burnout_risk': torch.tensor(float(last.burnout_risk), dtype=torch.float32),
            'retention_probability': torch.tensor(float(last.staff_retention_probability), dtype=torch.float32),
            'response_delay_minutes': torch.tensor(float(last.response_delay_minutes), dtype=torch.float32),
        }
        return sample, target

# Hold out 20% of the sequences for validation; the rest is used for training.
dataset = HamburgSafetyDataset(hamburg_df, seq_len=16)
n_sequences = len(dataset)
val_size = int(0.2 * n_sequences)
train_size = n_sequences - val_size

# A fixed generator seed keeps the train/val partition reproducible.
split_rng = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(
    dataset,
    [train_size, val_size],
    generator=split_rng,
)

# Mini-batches of four sequences; only the training loader shuffles.
# NOTE(review): each dataset item already has a leading dim of 1, so batches
# presumably come out as (4, 1, seq_len) -- confirm the model input handling.
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=4, shuffle=False)

print(f"✔️  Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")


✔️  Train batches: 20, Val batches: 5
