In [1]:
import numpy as np
import pandas as pd
import wfdb
import os
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import seaborn as snss
from pprint import pprint
from tqdm import tqdm
from pathlib import Path
import sys


In [2]:
# Locations of the train/val manifest CSVs consumed by the rest of the notebook.
train_csv_path = 'datasets/pretrain/train.csv'
val_csv_path = 'datasets/pretrain/val.csv'

print(f'Train CSV: {train_csv_path}\nVal CSV: {val_csv_path}')

Train CSV: datasets/pretrain/train.csv
Val CSV: datasets/pretrain/val.csv


In [6]:
# Dump the first training ECG (row 0 of the manifest's `path` column) to ecg.csv for a manual look.
data = pd.read_csv(train_csv_path, low_memory=False)
ecg = np.load(data.loc[0, 'path'])
# np.load yields a numpy array here (assuming the paths are .npy files), which has
# no to_csv method, so wrap it in a DataFrame before writing.
pd.DataFrame(ecg).to_csv('ecg.csv', index=False)


In [None]:
# Inspect the training manifest first: its column names and its row count.
train_df = pd.read_csv(train_csv_path, low_memory=False)
print('Train CSV columns:', train_df.columns.tolist(), sep='\n')
print(f'\nTrain rows: {train_df.shape[0]}')

Index(['subject_id', 'study_id', 'cart_id', 'ecg_time', 'report_0', 'report_1',
       'report_2', 'report_3', 'report_4', 'report_5', 'report_6', 'report_7',
       'report_8', 'report_9', 'report_10', 'report_11', 'report_12',
       'report_13', 'report_14', 'report_15', 'report_16', 'report_17',
       'bandwidth', 'filtering', 'rr_interval', 'p_onset', 'p_end',
       'qrs_onset', 'qrs_end', 't_end', 'p_axis', 'qrs_axis', 't_axis',
       'report_length', 'total_report', 'path'],
      dtype='object')

In [None]:
# Preview the first 200 characters of the first training report.
print('Train report示例:', train_df.loc[0, 'total_report'][:200], sep='\n')

'atrial fibrillation with rapid ventricular response.. cannot rule out anteroseptal infarct - age undetermined. possible left ventricular hypertrophy. inferior/lateral st-t changes are probably due to ventricular hypertrophy. abnormal ecg.'

In [24]:
from transformers import AutoTokenizer
import torch
import torch.nn as nn

class TextEncoder(nn.Module):
    """Thin wrapper around a Hugging Face tokenizer for ECG report text.

    Only the tokenizer is loaded; no text-model weights or projection head are
    built yet. ``free_layers``, ``proj_hidden`` and ``proj_out`` are therefore
    accepted for interface compatibility but are currently unused.

    Args:
        text_model: Model id or local path passed to ``AutoTokenizer.from_pretrained``.
        free_layers: Intended number of leading layers to freeze (unused for now).
        proj_hidden: Intended projection hidden size (unused for now).
        proj_out: Intended projection output size (unused for now).
        max_length: Length every report is truncated / padded to.
    """

    def __init__(self,
                 text_model='ncbi/MedCPT-Query-Encoder',
                 free_layers=6,           # intended: freeze the first N layers (unused)
                 proj_hidden=256,         # intended: projection hidden dim (unused)
                 proj_out=256,            # intended: projection output dim (unused)
                 max_length=256):         # fixed token length per report
        super().__init__()
        self.max_length = max_length

        # ========== 1. Text Encoder (tokenizer only for now) ==========
        self.tokenizer = AutoTokenizer.from_pretrained(text_model)

    def _tokenize(self, text, device=None):
        """Tokenize reports into fixed-length PyTorch tensors.

        Args:
            text: A list of report strings (a single string is also accepted).
            device: Optional torch device. When given, the returned tensors are
                moved there; when ``None`` they are returned as produced by the
                tokenizer.

        Returns:
            The tokenizer's batch encoding (``input_ids``, ``attention_mask`` and
            any other fields the tokenizer emits), with every sequence padded or
            truncated to ``self.max_length``.
        """
        # Calling the tokenizer directly replaces the deprecated batch_encode_plus.
        tokenizer_output = self.tokenizer(
            text,
            add_special_tokens=True,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',   # constant length so every saved sample has the same shape
            return_tensors='pt',
        )
        if device is not None:
            tokenizer_output = tokenizer_output.to(device)
        return tokenizer_output

    def forward(self, report):
        """Tokenize ``report`` via ``_tokenize``; no model forward pass is run yet."""
        return self._tokenize(report)

text_encoder = TextEncoder()

# Single output directory shared by the tokenized reports of every split.
tokenize_dir = Path('datasets/pretrain/report_tokenize')
tokenize_dir.mkdir(parents=True, exist_ok=True)

# Splits to process; each split's CSV is rewritten in place with a new
# `report_tokenize_path` column.
datasets = ['train', 'val']

for dataset_name in datasets:
    print(f'\n{"="*80}')
    print(f'开始处理 {dataset_name} 数据集')
    print(f'{"="*80}')

    csv_path = f'datasets/pretrain/{dataset_name}.csv'
    df = pd.read_csv(csv_path, low_memory=False)

    tokenize_paths = []
    for i, report_text in enumerate(tqdm(df['total_report'], desc=f'Tokenizing {dataset_name}')):
        # All splits share tokenize_dir, so the split name MUST be part of the
        # filename (train_000000.npz / val_000000.npz). Bare indices would let
        # val overwrite the first len(val) train files.
        # NOTE: files written by an earlier unprefixed run (e.g. 000000.npz) are
        # not removed here; delete them once both CSVs have been regenerated.
        npz_filename = f'{dataset_name}_{i:06d}.npz'
        npz_path = tokenize_dir / npz_filename

        # The tokenizer expects a batch, so wrap the single report in a list.
        tokenized = text_encoder._tokenize([report_text])

        # Batch size is 1: keep row 0 so each file stores 1-D arrays.
        np.savez(
            npz_path,
            input_ids=tokenized['input_ids'][0].numpy(),
            attention_mask=tokenized['attention_mask'][0].numpy(),
        )
        # as_posix keeps forward slashes in the CSV regardless of OS.
        tokenize_paths.append(npz_path.as_posix())

    # Assign the whole column at once instead of one .loc write per row.
    df['report_tokenize_path'] = tokenize_paths

    # Persist the updated manifest (overwrites the input CSV).
    df.to_csv(csv_path, index=False)
    print(f'{dataset_name} 完成！处理了 {len(df)} 条记录')

print(f'\n{"="*80}')
print('全部完成！')



开始处理 train 数据集


Tokenizing train: 100%|██████████| 745405/745405 [06:19<00:00, 1961.76it/s]


train 完成！处理了 745405 条记录

开始处理 val 数据集


Tokenizing val: 100%|██████████| 15213/15213 [00:08<00:00, 1863.11it/s]


val 完成！处理了 15213 条记录

全部完成！


In [None]:
# Sanity check: open the first tokenized report of each split and show its arrays.
for dataset_name in ['train', 'val']:
    print(f'\n{"="*60}')
    print(f'{dataset_name.upper()} 数据集验证')
    print(f'{"="*60}')

    df = pd.read_csv(f'datasets/pretrain/{dataset_name}.csv', low_memory=False)

    test_idx = 0
    test_path = df.loc[test_idx, 'report_tokenize_path']
    print(f'NPZ路径: {test_path}')

    # np.load on an .npz returns an NpzFile that holds the archive open;
    # the context manager closes it.
    with np.load(test_path) as loaded:
        print(f'\nNPZ包含的数组:')
        for key in loaded.files:
            arr = loaded[key]  # each loaded[key] access re-reads from the archive, so read once
            print(f'  {key}: shape={arr.shape}, dtype={arr.dtype}')

        # The arrays were saved per report as 1-D (batch dim already dropped),
        # so slice directly. Indexing [0] first would yield a scalar and raise IndexError.
        print(f'\ninput_ids前10个token: {loaded["input_ids"][:10]}')


In [None]:
# Spot-check the report <-> token-file pairing on the first two rows of each split.
for dataset_name in ['train', 'val']:
    print('\n' + '=' * 80)
    print(f'{dataset_name.upper()} 验证配对关系（前2行）:')
    print('=' * 80)

    df = pd.read_csv(f'datasets/pretrain/{dataset_name}.csv', low_memory=False)

    # head(2) copes with splits shorter than two rows.
    for i, row in df.head(2).iterrows():
        print(f'\n第{i}行:')
        print(f'  Report片段: {row["total_report"][:50]}...')
        print(f'  Tokenize路径: {row["report_tokenize_path"]}')


In [None]:
# Summarize each split's CSV after tokenization and flag rows whose token path is missing.
banner = '=' * 80
print('处理结果统计:')
print(banner)

for dataset_name in ['train', 'val']:
    df = pd.read_csv(f'datasets/pretrain/{dataset_name}.csv', low_memory=False)

    print(f'\n{dataset_name.upper()} 数据集:')
    print(f'  列名: {df.columns.tolist()}')
    print(f'  总行数: {len(df)}')

    path_col = df["report_tokenize_path"]
    print(f'  report_tokenize_path列非空数: {path_col.notna().sum()}')

    missing = path_col.isna().sum()
    if missing > 0:
        print(f'  ⚠️  警告：有 {missing} 行路径为空')

print('\n' + banner)
print('✓ 所有数据处理完成！')


In [None]:
# List the tokenized-report directory and count files per split prefix.
tokenize_dir = Path('datasets/pretrain/report_tokenize')
if tokenize_dir.is_dir():
    # Sorted so the example filenames below are deterministic; ignore subdirectories.
    files = sorted(p.name for p in tokenize_dir.iterdir() if p.is_file())
    train_files = [f for f in files if f.startswith('train_')]
    val_files = [f for f in files if f.startswith('val_')]
    # Anything without a split prefix (e.g. bare NNNNNN.npz from an unprefixed
    # naming scheme) cannot be told apart by split and may have been overwritten.
    other_files = [f for f in files if not f.startswith(('train_', 'val_'))]

    print(f'Tokenize目录: {tokenize_dir}')
    print(f'  Train文件数: {len(train_files)}')
    print(f'  Val文件数: {len(val_files)}')
    print(f'  总文件数: {len(files)}')
    if other_files:
        print(f'  ⚠️  无train_/val_前缀的文件数: {len(other_files)}')

    if train_files or val_files:
        print(f'\n示例文件名:')
        if train_files:
            print(f'  Train: {train_files[0]}')
        if val_files:
            print(f'  Val: {val_files[0]}')
else:
    print(f'目录不存在: {tokenize_dir}')
