<a href="https://colab.research.google.com/github/ChakesWu/parkinson-predict/blob/main/Parkinson_Rehabilitation_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -*- coding: utf-8 -*-
"""Parkinson Rehabilitation System.ipynb

Automatically generated by Colaboratory.

Original file is located at:
    https://colab.research.google.com/drive/your-drive-link
"""

# ==================== 环境配置 ====================
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# 创建项目目录结构
!mkdir -p "/content/drive/MyDrive/Parkinson_Project/{data,models,results}"

# 安装依赖库（修复版本冲突）
!pip install torch==2.0.1 scipy==1.10.1 neurokit2==0.2.4 matplotlib==3.7.1 -q

# ==================== 数据生成模块 ====================
import numpy as np
import pandas as pd
from scipy import signal
import os
import time

def generate_base_dataset(samples=30000, output_path=None):
    """Generate the synthetic baseline clinical dataset and save it to CSV.

    The signals span 0-300 s, so the default ``samples=30000`` corresponds to
    roughly 100 Hz sampling (5 minutes of data).

    Args:
        samples: Number of time steps to generate.
        output_path: Destination CSV path. Defaults to
            ``/content/drive/MyDrive/Parkinson_Project/data/base_data.csv``.
            Missing parent directories are created.

    Returns:
        DataFrame with columns ``timestamp``, ``finger_angle``,
        ``acceleration``, ``emg`` and ``parkinson_label``.

    Notes:
        Seeds NumPy's global RNG with 42, so the output is reproducible but the
        global random state is reset as a side effect. ``parkinson_label`` is a
        single 0/1 value for the whole recording, derived from whole-series
        statistics.
    """
    if output_path is None:
        output_path = "/content/drive/MyDrive/Parkinson_Project/data/base_data.csv"

    np.random.seed(42)
    t = np.linspace(0, 300, samples)

    # Finger angle: sawtooth tremor around 90 degrees plus sensor noise.
    tremor_freq = 4 + np.random.normal(0, 0.5)
    finger_angle = 90 + 10 * signal.sawtooth(2 * np.pi * tremor_freq * t)
    finger_angle += np.random.normal(0, 2, samples)

    # Acceleration: slow decaying oscillation plus noise.
    acceleration = 0.8 * np.sin(2 * np.pi * 0.3 * t) * np.exp(-0.005 * t)
    acceleration += 0.1 * np.random.randn(samples)

    # EMG: Hilbert-envelope noise plus a 500-sample burst every 2000 samples.
    burst_len = 500
    emg_bursts = np.zeros(samples)
    for start in range(0, samples, 2000):
        burst = 0.5 * np.abs(signal.hilbert(np.random.randn(burst_len)))
        stop = min(start + burst_len, samples)
        # The final burst may run past the end of the signal; truncate it
        # rather than fail with a shape mismatch.
        emg_bursts[start:stop] = burst[:stop - start]
    emg = 0.4 * np.abs(signal.hilbert(np.random.randn(samples))) + emg_bursts

    # One label for the whole recording, from whole-series statistics.
    label = int(
        np.std(finger_angle) > 8
        and np.mean(emg) > 0.45
        and np.max(acceleration) < 1.2
    )

    df = pd.DataFrame({
        'timestamp': t,
        'finger_angle': finger_angle,
        'acceleration': acceleration,
        'emg': emg,
        'parkinson_label': label,
    })
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    df.to_csv(output_path, index=False)
    print("基准数据集已生成，包含样本数:", len(df))
    return df

def generate_custom_data(samples=3000, seed=None, output_path=None):
    """Generate a synthetic "user device" recording and save it to CSV.

    Samples are spread evenly over 0-300 s, so the effective sampling rate is
    about ``samples / 300`` Hz (~10 Hz for the default 3000 samples).

    Args:
        samples: Number of time steps to generate.
        seed: Seed for a private ``numpy.random.Generator``. ``None`` (the
            default) draws fresh OS entropy, so every call differs.
        output_path: Destination CSV path. Defaults to
            ``/content/drive/MyDrive/Parkinson_Project/data/custom_data.csv``.
            Missing parent directories are created.

    Returns:
        DataFrame with columns ``timestamp``, ``finger_angle``,
        ``acceleration``, ``emg`` and ``parkinson_label`` (always 1).
    """
    if output_path is None:
        output_path = "/content/drive/MyDrive/Parkinson_Project/data/custom_data.csv"

    # A private Generator leaves NumPy's global RNG state untouched (no
    # wall-clock reseeding of global state as a side effect).
    rng = np.random.default_rng(seed)

    t = np.linspace(0, 300, samples)
    finger_angle = 85 + 12 * signal.sawtooth(2 * np.pi * 5.5 * t)
    finger_angle += rng.normal(0, 3, samples)
    acceleration = 0.6 * np.sin(2 * np.pi * 0.25 * t) * np.exp(-0.004 * t)
    acceleration += 0.15 * rng.standard_normal(samples)
    emg = 0.5 * np.abs(signal.hilbert(rng.standard_normal(samples)))
    # Up to 50 isolated EMG spikes; capped at `samples` so short recordings do
    # not make choice(..., replace=False) fail.
    spike_indices = rng.choice(samples, min(50, samples), replace=False)
    emg[spike_indices] += 0.8

    df = pd.DataFrame({
        'timestamp': t,
        'finger_angle': finger_angle,
        'acceleration': acceleration,
        'emg': emg,
        'parkinson_label': 1,
    })
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    df.to_csv(output_path, index=False)
    print("自定义数据集已生成，包含样本数:", len(df))
    return df

# ==================== 特征工程模块 ====================
def kinematic_feature_engineering(df, fs=100):
    """Derive the model's kinematic/EMG features from a raw recording.

    Produces the 8 model input features plus ``timestamp`` and
    ``parkinson_label`` (10 columns). The input DataFrame is not modified.

    Args:
        df: DataFrame with at least ``timestamp``, ``finger_angle``,
            ``acceleration``, ``emg`` and ``parkinson_label`` columns.
        fs: Sampling frequency (Hz) used for the Welch PSD estimate of EMG.

    Returns:
        New DataFrame with columns, in order: the 7 z-scored features
        (``finger_angle``, ``acceleration``, ``emg``, ``angle_velocity``,
        ``angle_acceleration``, ``emg_peak_freq``, ``emg_psd_ratio``), then
        ``rolling_angle_var``, ``timestamp`` and ``parkinson_label``.
    """
    df = df.copy()  # never mutate the caller's frame

    df['angle_velocity'] = np.gradient(df['finger_angle'], df['timestamp'])
    df['angle_acceleration'] = np.gradient(df['angle_velocity'], df['timestamp'])

    # Whole-recording EMG spectrum. scipy clips nperseg to the signal length
    # anyway (with a warning); do it explicitly.
    freqs, psd = signal.welch(df['emg'], fs=fs, nperseg=min(512, len(df)))
    # These two are scalars broadcast over every row (constant columns).
    df['emg_peak_freq'] = freqs[np.argmax(psd)]
    df['emg_psd_ratio'] = psd[(freqs > 10) & (freqs < 35)].sum() / psd.sum()

    features = [
        'finger_angle', 'acceleration', 'emg',
        'angle_velocity', 'angle_acceleration',
        'emg_peak_freq', 'emg_psd_ratio',
    ]
    for feat in features:
        df[feat] = df[feat].replace([np.inf, -np.inf], np.nan)
        df[feat] = df[feat].fillna(df[feat].mean())
    # z-score each feature; a column with exactly zero std (e.g. a constant
    # column) yields NaN here, which is then replaced by 0.
    df[features] = (df[features] - df[features].mean()) / df[features].std()
    df[features] = df[features].replace([np.inf, -np.inf], np.nan).fillna(0)

    # Computed on the already z-scored finger angle; the centred window leaves
    # NaN at both edges, filled with 0.
    df['rolling_angle_var'] = df['finger_angle'].rolling(window=100, center=True).var().fillna(0)

    return df[features + ['rolling_angle_var', 'timestamp', 'parkinson_label']]

# ==================== 模型架构 ====================
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class PretrainedBioEncoder(nn.Module):
    """CNN + bidirectional-LSTM encoder over multichannel sequences.

    Input:  ``[batch, in_channels, time]``.
    Output: ``[batch, 192]`` — 64 CNN features (global max-pooled over time)
    concatenated with 128 BiLSTM features (64 forward + 64 backward final
    states of the top LSTM layer).

    Args:
        in_channels: Number of input feature channels (default 8, matching the
            8 engineered features).
    """

    def __init__(self, in_channels=8):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(in_channels, 32, 5, padding=2),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(32, 64, 3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.AdaptiveMaxPool1d(1)
        )
        self.lstm = nn.LSTM(
            input_size=in_channels,
            hidden_size=64,
            bidirectional=True,
            num_layers=2,
            batch_first=True
        )

    def forward(self, x):
        # CNN branch on [batch, channels, time] -> [batch, 64]
        cnn_feat = self.cnn(x).squeeze(-1)
        # The LSTM expects [batch, time, features].
        _, (h_n, _) = self.lstm(x.permute(0, 2, 1))
        # h_n is [num_layers * 2, batch, hidden]; its last two entries are the
        # top layer's forward and backward final states. Reading
        # lstm_out[:, -1, :] instead would give the backward direction's state
        # after seeing only the final time step.
        lstm_feat = torch.cat([h_n[-2], h_n[-1]], dim=1)  # [batch, 128]
        return torch.cat([cnn_feat, lstm_feat], dim=1)

class TransferLearningModel(nn.Module):
    """Three-way classifier: a PretrainedBioEncoder followed by an MLP head.

    Encoder weights are loaded from ``pretrained_path`` when possible. Any
    loading error is printed and the encoder's Conv1d / BatchNorm1d layers are
    re-initialised instead. Either way, the first four encoder parameter
    tensors (the first Conv1d's weight/bias and the following BatchNorm1d's
    weight/bias) are frozen.

    Args:
        pretrained_path: Path to a ``state_dict`` file for the encoder.
    """

    def __init__(self, pretrained_path):
        super().__init__()
        self.encoder = PretrainedBioEncoder()
        try:
            encoder_state = torch.load(
                pretrained_path, map_location='cpu', weights_only=True
            )
            self.encoder.load_state_dict(encoder_state)
            print("预训练权重加载成功")
        except Exception as e:
            # Deliberate best-effort: a missing or incompatible checkpoint
            # falls back to fresh initialisation rather than aborting.
            print(f"权重加载失败: {str(e)}")
            self._initialize_weights()

        n_frozen = 4
        for position, tensor in enumerate(self.encoder.parameters()):
            if position == n_frozen:
                break
            tensor.requires_grad = False

        hidden_units = 64
        self.adapter = nn.Sequential(
            nn.Linear(192, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, 3),
        )

    def _initialize_weights(self):
        """Kaiming-init Conv1d weights; reset BatchNorm1d to weight 1, bias 0."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv1d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(layer, nn.BatchNorm1d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return logits with a final dimension of 3 (presumably ``[batch, 3]``)."""
        return self.adapter(self.encoder(x))

# ==================== 数据预处理 ====================
class ParkinsonDataset(Dataset):
    """Split a feature DataFrame into consecutive, non-overlapping windows.

    Every column except ``timestamp`` and ``parkinson_label`` is a model
    feature, and exactly 8 are required (otherwise ``ValueError``). Each item
    is a ``(FloatTensor[8, seq_length], long scalar tensor)`` pair. The label
    is 1 when more than half of the window's rows are labelled positive.
    Trailing rows that do not fill a whole window are ignored.
    """

    def __init__(self, df, seq_length=3000):
        feature_frame = df.drop(columns=['timestamp', 'parkinson_label'])
        self.data = feature_frame.values
        self.labels = df['parkinson_label'].values
        self.seq_length = seq_length
        if self.data.shape[1] != 8:
            raise ValueError(f"输入特征数应为8，当前为{self.data.shape[1]}")

    def __len__(self):
        # Number of complete windows only.
        return self.data.shape[0] // self.seq_length

    def __getitem__(self, idx):
        window = slice(idx * self.seq_length, (idx + 1) * self.seq_length)
        features = self.data[window].T  # -> [n_features, seq_length]
        is_positive = self.labels[window].mean() > 0.5  # majority vote
        return torch.FloatTensor(features), torch.tensor(int(is_positive), dtype=torch.long)

# ==================== 训练流程 ====================
def _evaluate_accuracy(model, loader, device):
    """Return the accuracy of ``model`` on ``loader`` as a percentage.

    Puts the model in eval mode. Returns 0.0 for a loader that yields no
    samples instead of raising ZeroDivisionError.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else 0.0


def _print_plan(plan):
    """Print a plan dict; list values are printed as indented bullet items."""
    for key, value in plan.items():
        if isinstance(value, list):
            print(f"- {key}:")
            for item in value:
                print(f"  * {item}")
        else:
            print(f"- {key}: {value}")


def main():
    """Run the full pipeline: generate data, build features, train, demo a plan.

    Writes the generated CSVs and the best checkpoint under
    ``/content/drive/MyDrive/Parkinson_Project`` and prints progress.
    """
    project_dir = "/content/drive/MyDrive/Parkinson_Project"
    models_dir = os.path.join(project_dir, "models")
    pretrained_path = os.path.join(models_dir, "pretrained_bio_model.pth")
    best_model_path = os.path.join(models_dir, "best_model.pth")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    print("\n===== 正在生成数据 =====")
    base_df = generate_base_dataset()
    custom_df = generate_custom_data()

    print("\n===== 正在处理特征 =====")
    processed_base = kinematic_feature_engineering(base_df)
    processed_custom = kinematic_feature_engineering(custom_df)
    processed_df = pd.concat([processed_base, processed_custom])
    print("处理后的特征维度:", processed_df.shape)
    if 'parkinson_label' not in processed_df.columns:
        raise KeyError("parkinson_label 列在特征工程后丢失！")
    print("标签列 'parkinson_label' 已成功保留")

    print("\n===== 正在划分数据集 =====")
    # Positional split: with the default generator sizes, the first 24000 rows
    # come from the base recording; the remainder (base tail + custom data)
    # is used for validation.
    train_dataset = ParkinsonDataset(processed_df.iloc[:24000])
    val_dataset = ParkinsonDataset(processed_df.iloc[24000:])
    print(f"训练集样本数: {len(train_dataset)}")
    print(f"验证集样本数: {len(val_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)

    print("\n===== 正在初始化模型 =====")
    model = TransferLearningModel(pretrained_path).to(device)

    # Hand *all* parameters to the optimizer. Parameters whose .grad is None
    # (the frozen ones) are skipped by the update step, and once they are
    # unfrozen below they start receiving updates. Filtering on requires_grad
    # here would leave them out of the optimizer for the whole run.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=1e-4,
        weight_decay=1e-5
    )
    criterion = nn.CrossEntropyLoss()

    os.makedirs(models_dir, exist_ok=True)

    print("\n===== 开始训练 =====")
    num_epochs = 20
    # Start below any achievable accuracy so the first epoch always writes a
    # checkpoint, even at 0% validation accuracy.
    best_acc = float('-inf')
    for epoch in range(num_epochs):
        if epoch == 8:
            print("解冻CNN深层参数")
            for param in model.encoder.cnn.parameters():
                param.requires_grad = True
        if epoch == 12:
            # NOTE(review): TransferLearningModel appears to freeze only the
            # first four encoder tensors, so the LSTM is presumably already
            # trainable here — confirm before relying on this schedule.
            print("解冻LSTM参数")
            for param in model.encoder.lstm.parameters():
                param.requires_grad = True

        model.train()
        total_loss = 0.0
        n_batches = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if torch.isnan(loss):
                print("警告：损失值为 nan，检查输入数据")
                print(f"inputs: {inputs}")
                print(f"outputs: {outputs}")
                break
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1

        val_acc = _evaluate_accuracy(model, val_loader, device)
        # Average over the batches actually trained (a NaN break stops early).
        avg_loss = total_loss / max(n_batches, 1)
        print(f"Epoch {epoch+1}/{num_epochs} | 平均损失: {avg_loss:.4f} | 验证准确率: {val_acc:.2f}%")
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc
            }, best_model_path)
            print("新的最佳模型已保存")

    print("\n===== 生成康复方案测试 =====")
    test_data = {
        'timestamp': np.linspace(0, 300, 3000),
        'finger_angle': 85 + 10 * np.sin(2 * np.pi * 5 * np.linspace(0, 1, 3000)),
        'acceleration': 0.6 * np.exp(-0.005 * np.linspace(0, 300, 3000)),
        'emg': 0.7 * np.abs(np.random.randn(3000)),
        'parkinson_label': 1
    }
    plan = predict_rehabilitation_plan(test_data, device)
    print("\n生成的帕金森手部训练方案：")
    _print_plan(plan)

# ==================== 推理模块 ====================
def predict_rehabilitation_plan(input_data, device):
    """Generate a rehabilitation training plan from one raw recording.

    Args:
        input_data: A dict (or anything ``pd.DataFrame`` accepts) with
            ``timestamp``, ``finger_angle``, ``acceleration``, ``emg`` and
            ``parkinson_label`` columns. It needs at least one full window
            (3000 rows with ParkinsonDataset's default).
        device: torch device to run inference on.

    Returns:
        Plan dict mapping section names to lists of instruction strings. The
        training intensity is the softmax probability of class 1 for the first
        window, as a percentage. On any failure, the error is printed and
        ``{"error": "无法生成训练方案"}`` is returned.
    """
    try:
        model = TransferLearningModel(
            "/content/drive/MyDrive/Parkinson_Project/models/pretrained_bio_model.pth"
        ).to(device)

        checkpoint_path = "/content/drive/MyDrive/Parkinson_Project/models/best_model.pth"
        if os.path.exists(checkpoint_path):
            # SECURITY: weights_only=False unpickles arbitrary Python objects.
            # Only load checkpoints written by this pipeline; never point this
            # at an untrusted file.
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
            model.load_state_dict(checkpoint['model_state_dict'])
            print("最佳模型权重加载成功")
        else:
            print("未找到最佳模型权重，使用随机初始化的模型")

        model.eval()
        processed_data = kinematic_feature_engineering(pd.DataFrame(input_data))
        dataset = ParkinsonDataset(processed_data)
        if len(dataset) == 0:
            # Without a full window there is nothing to classify; report it
            # clearly instead of surfacing a message-less StopIteration.
            raise ValueError(
                f"输入数据不足: 至少需要 {dataset.seq_length} 行，当前为 {len(processed_data)} 行"
            )

        with torch.no_grad():
            # Classify only the first window, as a batch of one.
            inputs, _ = dataset[0]
            inputs = inputs.unsqueeze(0)
            print(f"inputs shape: {inputs.shape}")
            inputs = inputs.to(device)
            outputs = model(inputs)
            print(f"outputs shape: {outputs.shape}")
            probabilities = torch.softmax(outputs, dim=1).cpu().numpy()
            print(f"probabilities shape: {probabilities.shape}")
            print(f"probabilities: {probabilities}")

        intensity = int(probabilities[0][1] * 100)
        return {
            '基础训练': [
                f"手指伸展训练: {intensity}% 强度 (每日3组，每组10次)",
                f"握力强化训练: {intensity}% 强度 (每日2组，每组8次)"
            ],
            '高级训练': [
                f"协调性训练: {intensity}% 强度 (每日1组，每组5分钟)",
                "使用压力球进行精细动作练习"
            ],
            '注意事项': [
                "训练前后进行10分钟热敷/冷敷",
                "每个动作间隔休息2分钟",
                "如出现疼痛或疲劳立即停止"
            ]
        }

    except Exception as e:
        # Best-effort boundary: report the failure and return an error marker
        # rather than propagating it to the caller.
        print(f"生成方案时出错: {str(e)}")
        return {"error": "无法生成训练方案"}

# ==================== Run the main program ====================
# Run the full pipeline only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

Mounted at /content/drive
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.9/58.9 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m619.9/619.9 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.1/34.1 MB[0m [31m43.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m50.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m66.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.1/317.1 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.8/11.8 MB[0m [31m88.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.0/21.0 MB[0m [31m67.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━

  return self.fget.__get__(instance, owner)()


权重加载失败: Error(s) in loading state_dict for PretrainedBioEncoder:
	size mismatch for cnn.0.weight: copying a param with shape torch.Size([32, 9, 5]) from checkpoint, the shape in current model is torch.Size([32, 8, 5]).
	size mismatch for lstm.weight_ih_l0: copying a param with shape torch.Size([256, 9]) from checkpoint, the shape in current model is torch.Size([256, 8]).
	size mismatch for lstm.weight_ih_l0_reverse: copying a param with shape torch.Size([256, 9]) from checkpoint, the shape in current model is torch.Size([256, 8]).

===== 开始训练 =====
Epoch 1/20 | 平均损失: 1.6606 | 验证准确率: 0.00%
Epoch 2/20 | 平均损失: 1.3542 | 验证准确率: 0.00%
Epoch 3/20 | 平均损失: 1.2315 | 验证准确率: 0.00%
Epoch 4/20 | 平均损失: 1.4171 | 验证准确率: 0.00%
Epoch 5/20 | 平均损失: 1.3272 | 验证准确率: 0.00%
Epoch 6/20 | 平均损失: 1.6265 | 验证准确率: 0.00%
Epoch 7/20 | 平均损失: 1.2544 | 验证准确率: 0.00%
Epoch 8/20 | 平均损失: 1.1448 | 验证准确率: 0.00%
解冻CNN深层参数
Epoch 9/20 | 平均损失: 1.2337 | 验证准确率: 0.00%
Epoch 10/20 | 平均损失: 0.9555 | 验证准确率: 0.00%
Epoch 11/20 | 平均损失: 1.02