
# پیاده‌سازی گام‌به‌گام MRTF (بر اساس متن روش پیشنهادی)

این نوت‌بوک یک پیاده‌سازی **قابل‌اجرا (Reference Implementation)** از چارچوب پیشنهادی **MRTF** برای تشخیص زودهنگام پارکینسون است.  
در این پیاده‌سازی، بخش‌های زیر مطابق دیاگرام و متن روش پیشنهادی ساخته می‌شوند:

1. **پیش‌پردازش سه مدالیته**: صدا، MRI، و سنسور پوشیدنی  
2. **استخراج ویژگی مدالیته‌محور**:
   - Voice Encoder: **CNN + BiLSTM**
   - MRI Encoder: **Vision Transformer (ViT)**
   - Sensor Encoder: **TCN + Temporal Transformer**
3. **همجوشی چندمدالیته** با **Cross-Attention Fusion Transformer (CAFT)** (سه‌جهته)
4. **لایه توضیح‌پذیری**: SHAP (تقریبی) + کانترفکتوال (gradient-based)
5. **RASL** (Reinforcement-Assisted Self-Learning): به‌روزرسانی سیاست/آستانه/وزن مدالیته‌ها با سیگنال پاداش
6. **آموزش مشترک**: زیان کل `L_total = L_CE + λ1 L_RL + λ2 L_reg`

> نکته مهم: چون در این محیط دیتاست واقعی در اختیار نیست، یک دیتاست شبیه‌سازی‌شده (Dummy Dataset) ساخته می‌شود تا کل pipeline از ابتدا تا انتها قابل اجرا باشد.  
> کافیست بعداً جای Dummy Dataset را با دیتاست واقعی خودتان (Voice/MRI/Sensor) جایگزین کنید.

---

## پیش‌نیازها
- PyTorch
- torchvision (برای ViT؛ اگر نصب نباشد، به‌جای آن یک encoder سادهٔ CNN برای MRI استفاده می‌شود)
- numpy
- (اختیاری) shap — در این نوت‌بوک استفاده نمی‌شود؛ attribution با Integrated Gradients تقریب زده می‌شود (بخش 10).


## 1) ایمپورت کتابخانه‌ها

In [None]:

import os
import math
import numpy as np
from dataclasses import dataclass
from typing import Dict, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

try:
    import torchvision
    from torchvision.models import vit_b_16, ViT_B_16_Weights
    _has_torchvision = True
except Exception as e:
    _has_torchvision = False
    print("torchvision/ViT not available -> fallback to simple CNN encoder for MRI. Error:", e)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


## 2) تنظیمات کلی و هایپرپارامترها

In [None]:

@dataclass
class Config:
    # Central bundle of hyperparameters read by the dummy dataset, the model
    # constructor and the training loop in this notebook.

    # Embedding dim (per the method text: d=512)
    d: int = 512

    # Voice spectrogram size (dummy)
    voice_freq_bins: int = 128
    voice_time_steps: int = 256

    # MRI image size (per the method text: 224x224)
    mri_size: int = 224

    # Sensor window (dummy)
    sensor_channels: int = 6   # accel(3)+gyro(3)
    sensor_len: int = 200      # ~ 5s * 40Hz (example)

    # Training
    batch_size: int = 8
    lr: float = 1e-4
    weight_decay: float = 1e-4
    epochs: int = 3  # kept small for the demo
    grad_clip: float = 1.0

    # Loss weights
    lambda_rl: float = 0.5
    lambda_reg: float = 1e-4

    # RL reward weights (Eq.20)
    alpha1_acc: float = 1.0
    alpha2_ecs: float = 0.5
    alpha3_err: float = 1.0

    gamma: float = 0.95  # discount

cfg = Config()
cfg


## 3) دیتاست نمونه (Dummy) — برای اجرای کامل Pipeline

In [None]:

class DummyParkinsonMultimodalDataset(Dataset):
    """Synthetic three-modality dataset used to exercise the pipeline.

    Every item is a dict with:
    - "voice": spectrogram tensor [1, F, T]
    - "mri": image tensor [3, mri_size, mri_size]
    - "sensor": time series [C, L]
    - "y": label in {0, 1}, returned as a float32 scalar tensor
    """

    def __init__(self, n=200, cfg: Config = cfg, seed=42):
        self.n = n
        self.cfg = cfg
        gen = np.random.default_rng(seed)

        # Draw order (labels, voice, MRI, sensor) determines the random stream.
        self.y = gen.integers(0, 2, size=n).astype(np.int64)

        voice_shape = (n, 1, cfg.voice_freq_bins, cfg.voice_time_steps)
        mri_shape = (n, 3, cfg.mri_size, cfg.mri_size)
        sensor_shape = (n, cfg.sensor_channels, cfg.sensor_len)
        self.voice = gen.normal(0, 1, size=voice_shape).astype(np.float32)
        self.mri = gen.normal(0, 1, size=mri_shape).astype(np.float32)
        self.sensor = gen.normal(0, 1, size=sensor_shape).astype(np.float32)

        # Shift the positive class slightly so the data is (barely) learnable.
        positives = self.y == 1
        for modality, shift in ((self.voice, 0.15), (self.mri, 0.05), (self.sensor, 0.10)):
            modality[positives] += shift

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        sample = {
            name: torch.from_numpy(getattr(self, name)[idx])
            for name in ("voice", "mri", "sensor")
        }
        sample["y"] = torch.tensor(self.y[idx], dtype=torch.float32)
        return sample

ds = DummyParkinsonMultimodalDataset(n=120, cfg=cfg)
dl = DataLoader(ds, batch_size=cfg.batch_size, shuffle=True)
batch = next(iter(dl))
{k: v.shape for k,v in batch.items()}


## 4) ماژول‌های پیش‌پردازش (اسکلت قابل جایگزینی با پیاده‌سازی واقعی)

In [None]:

# در این نوت‌بوک، پیش‌پردازش واقعی (MFCC/STFT, skull stripping, wavelet denoise, ...) را
# به صورت اسکلت/Placeholder می‌گذاریم چون به دیتای خام و کتابخانه‌های تخصصی نیاز دارد.
# شما کافیست در پروژه‌تان این توابع را با پیاده‌سازی واقعی جایگزین کنید.

def preprocess_voice(voice_wave_or_spec: torch.Tensor) -> torch.Tensor:
    """Standardize the input over its last two axes.

    The input is assumed to already be a spectrogram; a real pipeline would
    perform STFT/MFCC extraction and normalization here.
    """
    plane_axes = (-1, -2)
    center = voice_wave_or_spec.mean(dim=plane_axes, keepdim=True)
    spread = voice_wave_or_spec.std(dim=plane_axes, keepdim=True)
    return (voice_wave_or_spec - center) / (spread + 1e-6)

def preprocess_mri(mri_img: torch.Tensor) -> torch.Tensor:
    """Standardize each image plane (last two axes) to zero mean / unit std.

    A real pipeline would do skull stripping, bias correction, resizing and
    intensity normalization here.
    """
    plane_axes = (-1, -2)
    center = mri_img.mean(dim=plane_axes, keepdim=True)
    spread = mri_img.std(dim=plane_axes, keepdim=True)
    return (mri_img - center) / (spread + 1e-6)

def preprocess_sensor(sensor_ts: torch.Tensor) -> torch.Tensor:
    """Standardize each series along its last axis.

    A real pipeline would do windowing, wavelet denoising, FFT features and
    normalization here.
    """
    center = sensor_ts.mean(dim=-1, keepdim=True)
    spread = sensor_ts.std(dim=-1, keepdim=True)
    return (sensor_ts - center) / (spread + 1e-6)


## 5) Voice Encoder: CNN + BiLSTM (Section 3.2.1)

In [None]:

class VoiceCNNBiLSTM(nn.Module):
    """Voice encoder: 2-D CNN over the spectrogram followed by a BiLSTM.

    Args:
        d: size of the returned embedding.
        voice_freq_bins: number of frequency bins F of the input; the CNN
            halves F twice, so the per-frame feature size is 128 * (F // 4).

    Input:  x of shape [B, 1, F, T].
    Output: embedding of shape [B, d].
    """

    def __init__(self, d=512, voice_freq_bins=128):
        super().__init__()
        self.voice_freq_bins = voice_freq_bins
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2,2)),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2,2)),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.proj = nn.Linear(128 * (voice_freq_bins//4), 256)
        self.lstm = nn.LSTM(
            input_size=256,
            hidden_size=256,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
        self.out = nn.Linear(512, d)  # 256*2 -> d

    def forward(self, x):  # x: [B,1,F,T]
        feats = self.cnn(x)  # [B,128,F/4,T/4]
        # Local names avoid shadowing the module-level `F` (torch.nn.functional).
        batch, channels, freq, steps = feats.shape
        # One feature vector per time frame: [B,T,C*F].
        seq = feats.permute(0, 3, 1, 2).reshape(batch, steps, channels * freq)
        seq = self.proj(seq)  # [B,T,256]
        _, (h_n, _) = self.lstm(seq)  # h_n: [num_layers*2, B, 256]
        # Use the final hidden state of each direction. Taking out[:, -1, :]
        # would pair the forward summary with a backward state that has only
        # seen the last frame, discarding the backward pass over the sequence.
        summary = torch.cat([h_n[-2], h_n[-1]], dim=-1)  # [B,512]
        return self.out(summary)  # [B,d]


## 6) MRI Encoder: ViT (Section 3.2.2) + fallback

In [None]:

class MRIEncoderViT(nn.Module):
    """MRI encoder producing a d-dimensional embedding.

    When torchvision imported successfully (module-level `_has_torchvision`),
    a `vit_b_16` backbone with its default weights is used and its
    classification head is replaced by identity. Otherwise a small 3-stage CNN
    with global average pooling is used. Either way a linear layer projects
    the backbone features to `d`.
    """

    def __init__(self, d=512):
        super().__init__()
        if not _has_torchvision:
            # Fallback backbone; stored under the same attribute name.
            conv_stack = [
                nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
                nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
                nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d((1,1)),
            ]
            self.vit = nn.Sequential(*conv_stack)
            self.proj = nn.Linear(128, d)
            return

        self.vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        backbone_dim = self.vit.heads.head.in_features
        self.vit.heads = nn.Identity()  # expose backbone features instead of class logits
        self.proj = nn.Linear(backbone_dim, d)

    def forward(self, x):  # [B,3,224,224]
        feats = self.vit(x)
        if not _has_torchvision:
            # The CNN fallback ends in a [B,128,1,1] map; flatten it to [B,128].
            feats = feats.flatten(1)
        return self.proj(feats)


## 7) Sensor Encoder: TCN + Temporal Transformer (Section 3.2.3)

In [None]:

class TCNBlock(nn.Module):
    """Dilated 1-D convolution block with a residual connection.

    The convolution pads by (kernel - 1) * dilation and the output is trimmed
    back to the input length, so output step t only depends on inputs <= t.
    A 1x1 convolution matches channel counts on the residual path when needed.
    """

    def __init__(self, in_ch, out_ch, kernel=3, dilation=1, dropout=0.1):
        super().__init__()
        self.conv = nn.Conv1d(
            in_ch,
            out_ch,
            kernel_size=kernel,
            dilation=dilation,
            padding=(kernel - 1) * dilation,
        )
        self.dropout = nn.Dropout(dropout)
        if in_ch == out_ch:
            self.res = nn.Identity()
        else:
            self.res = nn.Conv1d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):  # [B,C,L]
        seq_len = x.shape[-1]
        # Drop the extra trailing steps produced by the padding.
        activated = F.relu(self.conv(x)[..., :seq_len])
        return self.dropout(activated) + self.res(x)

class SensorTemporalEncoder(nn.Module):
    """Sensor encoder: three TCN blocks, a Transformer encoder, mean pooling.

    Input:  [B, sensor_channels, L].
    Output: [B, d].
    A learned positional embedding of length `sensor_len` is added before the
    Transformer; it is sliced to the actual sequence length.
    """

    def __init__(self, d=512, sensor_len=200, sensor_channels=6, n_heads=8, n_layers=2):
        super().__init__()
        self.tcn = nn.Sequential(
            TCNBlock(sensor_channels, 64, dilation=1),
            TCNBlock(64, 128, dilation=2),
            TCNBlock(128, 256, dilation=4),
        )
        self.pos_emb = nn.Parameter(torch.randn(1, sensor_len, 256) * 0.01)
        self.tr = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=256, nhead=n_heads, dim_feedforward=512, batch_first=True
            ),
            num_layers=n_layers,
        )
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.out = nn.Linear(256, d)

    def forward(self, x):  # [B,C,L]
        tokens = self.tcn(x).permute(0, 2, 1)                 # [B,L,256]
        n_steps = tokens.shape[1]
        tokens = self.tr(tokens + self.pos_emb[:, :n_steps, :])  # [B,L,256]
        # Average over time: pool expects channels-first [B,256,L].
        pooled = self.pool(tokens.permute(0, 2, 1)).squeeze(-1)  # [B,256]
        return self.out(pooled)                               # [B,d]


## 8) CAFT: Cross-Attention Fusion Transformer (Section 3.3)

In [None]:

class CrossAttention(nn.Module):
    """Post-norm cross-attention block: MHA + residual + LN, then FFN + residual + LN.

    forward(q, k, v) returns (output shaped like q, head-averaged attention weights).
    """

    def __init__(self, d=512, n_heads=8, dropout=0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(
            embed_dim=d, num_heads=n_heads, dropout=dropout, batch_first=True
        )
        self.ln = nn.LayerNorm(d)
        self.ff = nn.Sequential(
            nn.Linear(d, 2*d),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(2*d, d),
        )
        self.ln2 = nn.LayerNorm(d)

    def forward(self, q, k, v):
        attended, weights = self.mha(
            q, k, v, need_weights=True, average_attn_weights=True
        )
        hidden = self.ln(q + attended)
        return self.ln2(hidden + self.ff(hidden)), weights

class CAFT(nn.Module):
    """Three-way cross-attention fusion of voice / MRI / sensor embeddings.

    Each modality embedding queries the other two through its own
    CrossAttention block (six blocks in total). The six results are mixed with
    softmax-normalised learnable weights `omega`, stacked as three tokens,
    passed through a Transformer encoder and LayerNorm, and mean-pooled.

    forward(Ev, Em, Es) takes three [B, d] tensors and returns
    (fused [B, d], dict of detached attention weights and mixing weights).
    """

    def __init__(self, d=512, n_heads=8, n_layers=2, dropout=0.1):
        super().__init__()
        # cross_<query><key>: v = voice, m = MRI, s = sensor.
        for pair in ("vm", "vs", "mv", "ms", "sv", "sm"):
            setattr(self, f"cross_{pair}", CrossAttention(d, n_heads, dropout))

        self.omega = nn.Parameter(torch.ones(6))

        self.tr = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d, nhead=n_heads, dim_feedforward=4*d, dropout=dropout, batch_first=True
            ),
            num_layers=n_layers,
        )
        self.ln = nn.LayerNorm(d)

    def forward(self, Ev, Em, Es):
        # Treat each embedding as a length-1 token sequence: [B,1,d].
        tok_v, tok_m, tok_s = Ev.unsqueeze(1), Em.unsqueeze(1), Es.unsqueeze(1)

        v_m, w_vm = self.cross_vm(tok_v, tok_m, tok_m)
        v_s, w_vs = self.cross_vs(tok_v, tok_s, tok_s)

        m_v, w_mv = self.cross_mv(tok_m, tok_v, tok_v)
        m_s, w_ms = self.cross_ms(tok_m, tok_s, tok_s)

        s_v, w_sv = self.cross_sv(tok_s, tok_v, tok_v)
        s_m, w_sm = self.cross_sm(tok_s, tok_m, tok_m)

        mix = F.softmax(self.omega, dim=0)

        tokens = torch.cat(
            [
                mix[0]*v_m + mix[1]*v_s,
                mix[2]*m_v + mix[3]*m_s,
                mix[4]*s_v + mix[5]*s_m,
            ],
            dim=1,
        )  # [B,3,d]
        H = self.ln(self.tr(tokens))

        attn_info = {
            "w_vm": w_vm.detach(), "w_vs": w_vs.detach(),
            "w_mv": w_mv.detach(), "w_ms": w_ms.detach(),
            "w_sv": w_sv.detach(), "w_sm": w_sm.detach(),
            "omega": mix.detach()
        }
        return H.mean(dim=1), attn_info  # fused: [B,d]


## 9) Classifier + Fusion Regularization (Eq.16-17)

In [None]:

class MRTFClassifierHead(nn.Module):
    """Two-layer MLP head mapping a fused embedding [B, d] to a probability [B]."""

    def __init__(self, d=512, dropout=0.3):
        super().__init__()
        hidden = d//2
        self.net = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(self, H):
        logits = self.net(H)
        # Sigmoid first, then drop the trailing singleton dimension.
        return torch.sigmoid(logits).squeeze(-1)

def fusion_regularization(Ev, Em, Es):
    """Return the sum of the three pairwise mean-squared differences between embeddings."""
    voice_vs_mri = (Ev - Em).pow(2).mean()
    voice_vs_sensor = (Ev - Es).pow(2).mean()
    mri_vs_sensor = (Em - Es).pow(2).mean()
    return voice_vs_mri + voice_vs_sensor + mri_vs_sensor


## 10) Explainability Layer: SHAP (تقریبی) + Counterfactual + ECS (Section 3.5)

In [None]:

def compute_shap_like_attribution(model_fn, x: torch.Tensor, baseline: Optional[torch.Tensor]=None, steps: int = 16):
    """Lightweight SHAP-style attribution via Integrated Gradients.

    Args:
        model_fn: callable mapping a [N, d] tensor to one score per row.
        x: inputs whose last dimension is the feature dimension (e.g. [B, d]).
        baseline: reference point with the same shape as `x`; zeros if omitted.
        steps: number of points on the straight path from baseline to x
            (endpoints included).

    Returns:
        Detached tensor with the shape of `x`:
        (x - baseline) * mean gradient along the path.

    Raises:
        ValueError: if `steps` is smaller than 1.
    """
    if steps < 1:
        raise ValueError("steps must be >= 1")

    x = x.detach()
    baseline = torch.zeros_like(x) if baseline is None else baseline.detach()
    delta = x - baseline

    # One leading "path" axis, broadcast over however many dims x has.
    alphas = torch.linspace(0, 1, steps, device=x.device).view(steps, *([1] * x.dim()))

    # Gradients are required here even when the caller is inside
    # torch.no_grad() (e.g. an evaluation loop), so enable them locally.
    with torch.enable_grad():
        x_interp = (baseline.unsqueeze(0) + alphas * delta.unsqueeze(0)).detach()
        x_interp.requires_grad_(True)  # [steps, *x.shape]

        y = model_fn(x_interp.reshape(-1, x.shape[-1]))  # one score per interpolated row
        # Differentiate w.r.t. the path points only; model parameters' .grad is untouched.
        grads = torch.autograd.grad(y.sum(), x_interp)[0]  # [steps, *x.shape]

    phi = delta * grads.mean(dim=0)
    return phi.detach()

def find_counterfactual(model_fn, x: torch.Tensor, target_flip: float=0.5, steps=50, lr=1e-1, l2_weight=0.01):
    """Gradient-based counterfactual search around `x`.

    Runs `steps` plain gradient-descent updates on a copy of `x`, minimising
    mean((model_fn(x_cf) - target_flip)^2) + l2_weight * mean((x_cf - x)^2).

    Args:
        model_fn: differentiable callable mapping the candidate to scores.
        x: starting point; it is not modified.
        target_flip: score the candidate is pulled towards.
        steps: number of gradient steps.
        lr: step size.
        l2_weight: weight of the proximity penalty.

    Returns:
        Detached tensor with the shape of `x`.
    """
    x0 = x.detach()
    x_cf = x0.clone().requires_grad_(True)

    with torch.enable_grad():
        for _ in range(steps):
            y = model_fn(x_cf)
            loss_flip = (y - target_flip).pow(2).mean()
            loss_l2 = l2_weight * (x_cf - x0).pow(2).mean()
            # Differentiate w.r.t. the candidate only. loss.backward() would
            # also accumulate into the .grad of every parameter used by
            # model_fn, leaking gradients into the caller's model.
            (grad,) = torch.autograd.grad(loss_flip + loss_l2, x_cf)
            with torch.no_grad():
                x_cf -= lr * grad  # same update as SGD without momentum
    return x_cf.detach()

def compute_ecs(phi: torch.Tensor, tau: float=0.01, clinical_mask: Optional[torch.Tensor]=None):
    """Score attributions: mean over the last axis of (|phi| > tau) * clinical_mask.

    With no mask given, every feature counts (mask of ones), so the result is
    the fraction of features whose absolute attribution exceeds `tau`.
    """
    mask = torch.ones_like(phi) if clinical_mask is None else clinical_mask
    salient = phi.abs().gt(tau).float()
    return (salient * mask).mean(dim=-1)


## 11) RASL (نسخه عملیاتی ساده)

In [None]:

class RASLModule(nn.Module):
    """Holds the decision threshold (as a logit parameter) and a running reward baseline.

    Args:
        init_threshold: initial threshold in (0, 1); stored as its logit.
        beta: decay factor of the exponential moving average of rewards.
    """

    def __init__(self, init_threshold=0.5, beta=0.9):
        super().__init__()
        initial_logit = math.log(init_threshold/(1-init_threshold))
        self.threshold_logit = nn.Parameter(torch.tensor([initial_logit], dtype=torch.float32))
        self.register_buffer("reward_baseline", torch.tensor(0.0))
        self.beta = beta

    def threshold(self):
        """Return the current threshold as a Python float."""
        return self.threshold_logit.sigmoid().item()

    @torch.no_grad()
    def update_baseline(self, r: torch.Tensor):
        """Fold the mean of `r` into the moving-average baseline."""
        batch_mean = r.mean().detach()
        decayed = self.beta * self.reward_baseline
        self.reward_baseline = decayed + (1 - self.beta) * batch_mean

    def rl_loss(self, rewards: torch.Tensor):
        """Return the negated mean advantage (rewards minus baseline)."""
        advantage = rewards - self.reward_baseline
        return -(advantage.mean())


## 12) مدل کامل MRTF

In [None]:

class MRTFModel(nn.Module):
    """Full MRTF network: per-modality encoders, LayerNorm, CAFT fusion, classifier.

    forward(voice, mri, sensor) returns
    (yhat, (Ev, Em, Es, H, attn)): the predicted probability per sample, the
    three normalised modality embeddings, the fused embedding and the
    attention info dict produced by CAFT.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        d = cfg.d

        self.voice_enc = VoiceCNNBiLSTM(d=d, voice_freq_bins=cfg.voice_freq_bins)
        self.mri_enc = MRIEncoderViT(d=d)
        self.sensor_enc = SensorTemporalEncoder(
            d=d, sensor_len=cfg.sensor_len, sensor_channels=cfg.sensor_channels
        )
        self.caft = CAFT(d=d)
        self.clf = MRTFClassifierHead(d=d, dropout=0.3)
        self.rasl = RASLModule(init_threshold=0.5)

        # One LayerNorm per modality embedding.
        self.ln_v = nn.LayerNorm(d)
        self.ln_m = nn.LayerNorm(d)
        self.ln_s = nn.LayerNorm(d)

    def forward(self, voice, mri, sensor):
        # Preprocess -> encode -> normalise, one modality at a time.
        Ev = self.ln_v(self.voice_enc(preprocess_voice(voice)))
        Em = self.ln_m(self.mri_enc(preprocess_mri(mri)))
        Es = self.ln_s(self.sensor_enc(preprocess_sensor(sensor)))

        H, attn = self.caft(Ev, Em, Es)
        return self.clf(H), (Ev, Em, Es, H, attn)

model = MRTFModel(cfg).to(device)
sum(p.numel() for p in model.parameters())/1e6


## 13) آموزش مشترک (L_total)

In [None]:

def binary_cross_entropy(yhat, y):
    """Mean binary cross-entropy of probabilities `yhat` against targets `y`.

    Probabilities are clamped to [1e-7, 1 - 1e-7] before taking logs.
    """
    p = yhat.clamp(min=1e-7, max=1-1e-7)
    positive_term = y * p.log()
    negative_term = (1 - y) * (1 - p).log()
    return -(positive_term + negative_term).mean()

def model_fn_on_fused(model: MRTFModel, H_flat: torch.Tensor):
    """Apply the classifier MLP of `model` to fused embeddings and return probabilities."""
    logits = model.clf.net(H_flat).squeeze(-1)
    return torch.sigmoid(logits)

optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

def train_one_epoch(model, dl):
    """Run one training epoch and return per-batch averages of the tracked metrics.

    Uses the module-level `optimizer`, `cfg` and `device`.

    Per batch: forward pass, attribution + ECS on the detached fused
    embedding, a short counterfactual search on the first sample, reward
    computation, then one optimisation step on
    L_CE + cfg.lambda_rl * L_RL + cfg.lambda_reg * L_fuse.

    Args:
        model: the MRTF model (must expose `.rasl` and `.clf`).
        dl: iterable of batches with keys "voice", "mri", "sensor", "y".

    Returns:
        dict with the mean of loss, ce, fuse, rl, reward, acc and ecs over the
        batches (zeros if `dl` is empty).
    """
    model.train()
    total = {"loss":0.0, "ce":0.0, "fuse":0.0, "rl":0.0, "reward":0.0, "acc":0.0, "ecs":0.0}
    n_batches = 0

    def fused_fn(z):
        # Classifier applied directly on the fused embedding space.
        return model_fn_on_fused(model, z)

    for batch in dl:
        voice = batch["voice"].to(device)
        mri = batch["mri"].to(device)
        sensor = batch["sensor"].to(device)
        y = batch["y"].to(device)

        yhat, (Ev, Em, Es, H, _) = model(voice, mri, sensor)

        # XAI on a detached copy so it never contributes to the training graph.
        H_detached = H.detach().requires_grad_(True)
        phi = compute_shap_like_attribution(fused_fn, H_detached)
        ecs = compute_ecs(phi, tau=0.01)

        # (optional) lightweight counterfactual for a single sample
        _ = find_counterfactual(fused_fn, H_detached[:1], steps=10, lr=0.2)

        # reward
        thr = model.rasl.threshold()
        pred = (yhat >= thr).float()
        acc = (pred == y).float().mean()
        err = 1.0 - acc
        R = cfg.alpha1_acc*acc + cfg.alpha2_ecs*ecs.mean() - cfg.alpha3_err*err

        # losses
        L_CE = binary_cross_entropy(yhat, y)
        L_fuse = fusion_regularization(Ev, Em, Es)
        # Advantage is measured against the baseline built from previous
        # batches; the baseline is updated only afterwards so the current
        # reward does not leak into its own baseline.
        L_RL = model.rasl.rl_loss(R.view(1))
        model.rasl.update_baseline(R.detach())
        # Regulariser weight comes from the config (was a hard-coded 0.1 that
        # ignored cfg.lambda_reg).
        loss = L_CE + cfg.lambda_rl*L_RL + cfg.lambda_reg*L_fuse

        # zero_grad right before backward also clears anything the XAI
        # helpers may have left in parameter .grad fields.
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        optimizer.step()

        total["loss"] += loss.item()
        total["ce"] += L_CE.item()
        total["fuse"] += L_fuse.item()
        total["rl"] += L_RL.item()
        total["reward"] += float(R.detach().cpu())
        total["acc"] += float(acc.detach().cpu())
        total["ecs"] += float(ecs.mean().detach().cpu())
        n_batches += 1

    denom = max(1, n_batches)  # avoid division by zero on an empty loader
    return {name: value / denom for name, value in total.items()}

# ساخت دیتالودر دمو
class DummyParkinsonMultimodalDataset(Dataset):
    """Synthetic voice / MRI / sensor dataset for the demo training run.

    Re-declared in this cell with the same construction as the earlier
    definition of the same name in this notebook. Items are dicts with keys
    "voice", "mri", "sensor" (float32 tensors) and "y" (float32 scalar label).
    """

    def __init__(self, n=200, cfg: Config = cfg, seed=42):
        rng = np.random.default_rng(seed)
        self.n, self.cfg = n, cfg
        self.y = rng.integers(0, 2, size=n).astype(np.int64)

        def gaussian(*shape):
            # Standard-normal block with n leading samples, stored as float32.
            return rng.normal(0, 1, size=(n, *shape)).astype(np.float32)

        # Call order matters: it fixes which random numbers each modality gets.
        self.voice = gaussian(1, cfg.voice_freq_bins, cfg.voice_time_steps)
        self.mri = gaussian(3, cfg.mri_size, cfg.mri_size)
        self.sensor = gaussian(cfg.sensor_channels, cfg.sensor_len)

        # Small per-modality offset on positive samples.
        is_positive = self.y == 1
        self.voice[is_positive] += 0.15
        self.mri[is_positive] += 0.05
        self.sensor[is_positive] += 0.10

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        voice = torch.from_numpy(self.voice[idx])
        mri = torch.from_numpy(self.mri[idx])
        sensor = torch.from_numpy(self.sensor[idx])
        label = torch.tensor(self.y[idx], dtype=torch.float32)
        return {"voice": voice, "mri": mri, "sensor": sensor, "y": label}

ds = DummyParkinsonMultimodalDataset(n=120, cfg=cfg)
dl = DataLoader(ds, batch_size=cfg.batch_size, shuffle=True)

for ep in range(cfg.epochs):
    stats = train_one_epoch(model, dl)
    print(f"epoch={ep+1} | loss={stats['loss']:.4f} | CE={stats['ce']:.4f} | fuse={stats['fuse']:.4f} | RL={stats['rl']:.4f} | R={stats['reward']:.4f} | acc={stats['acc']:.4f} | ECS={stats['ecs']:.4f} | thr={model.rasl.threshold():.3f}")


## 14) اینفرنس + گزارش توضیح‌پذیری برای یک نمونه

In [None]:

# Inference on one batch, followed by an explainability report for the first two samples.
model.eval()
batch = next(iter(dl))
voice = batch["voice"].to(device)
mri = batch["mri"].to(device)
sensor = batch["sensor"].to(device)
y = batch["y"].to(device)

with torch.no_grad():
    yhat, (Ev, Em, Es, H, attn) = model(voice, mri, sensor)

thr = model.rasl.threshold()
pred = (yhat >= thr).float()

# Attribution / counterfactual need gradients, so they run outside no_grad
# on a detached copy of the fused embeddings.
H2 = H[:2].detach().requires_grad_(True)
phi2 = compute_shap_like_attribution(lambda z: model_fn_on_fused(model, z), H2)
ecs2 = compute_ecs(phi2, tau=0.01)
cf2 = find_counterfactual(lambda z: model_fn_on_fused(model, z), H2, steps=20, lr=0.2)

print("yhat[:5] =", yhat[:5].detach().cpu().numpy())
print("pred[:5] =", pred[:5].detach().cpu().numpy(), " (thr=", thr, ")")
print("true[:5] =", y[:5].detach().cpu().numpy())
print("ECS(first2) =", ecs2.detach().cpu().numpy())

abs_phi = phi2[0].abs().detach().cpu().numpy()
# [::-1] yields a negative-stride view, which PyTorch cannot convert into an
# index tensor; copy() materialises a contiguous array before it is used below.
top_idx = abs_phi.argsort()[-10:][::-1].copy()
print("Top-10 fused feature indices:", top_idx.tolist())
print("Top-10 attribution values:", phi2[0][top_idx].detach().cpu().numpy())

dist = (cf2 - H2).pow(2).sum(dim=-1).sqrt().detach().cpu().numpy()
print("Counterfactual L2 distance (first2):", dist)



## 15) جایگزینی با دیتای واقعی

### Voice
- ورودی مدل در این نوت‌بوک `spectrogram` با شکل `[B,1,F,T]` است.
- در پیاده‌سازی واقعی مطابق متن روش:
  - پیش‌تأکید (Eq.1)
  - STFT (Eq.2) و/یا MFCC
  - نرمال‌سازی

### MRI
- ورودی `[B,3,224,224]` است.
- اگر MRI تک‌کاناله است، یا 3 بار تکرار کنید یا encoder را برای 1 کانال تغییر دهید.
- skull stripping / bias correction / MNI normalization را پیش از ورود به مدل انجام دهید (Eq.3).

### Sensor
- ورودی `[B,C,L]` است (acc+gyro).
- windowing 5 ثانیه، wavelet denoise، FFT و نرمال‌سازی (Eq.4).

### نکته درباره SHAP واقعی
در این نوت‌بوک برای سبک بودن، از **Integrated Gradients** به عنوان تقریب attribution استفاده شد.
در پروژه واقعی می‌توانید SHAP یا Captum را اضافه کنید و ECS را مطابق Eq.22 محاسبه کنید.
