# dataset preparation

# 训练longformer

In [None]:
import torch
from torch import nn
from transformers import LongformerForSequenceClassification, LongformerConfig, LongformerTokenizer

# Load the pretrained checkpoint: config first (with the classification head
# size set), then tokenizer, then the sequence-classification model.
model_name = "yikuan8/Clinical-Longformer"  # alternative: "allenai/longformer-base-4096"

config = LongformerConfig.from_pretrained(model_name)
config.num_labels = 4  # number of target classes -> output size of the classifier head

tokenizer = LongformerTokenizer.from_pretrained(model_name)

# Pretrained encoder plus a classification head sized by config.num_labels.
model = LongformerForSequenceClassification.from_pretrained(model_name, config=config)

In [None]:
# Dump every parameter with its trainability flag and shape (before freezing).
print("All model parameters:")
for pname, tensor in model.named_parameters():
    print(f"{pname:80} requires_grad={tensor.requires_grad}, shape={tensor.shape}")

In [None]:
# Freeze the whole network; selected parameters are re-enabled in a later cell.
for p in model.parameters():
    p.requires_grad_(False)

In [None]:
# Dump the parameters again to confirm everything is now frozen.
print("All model parameters:")
for param_name, weight in model.named_parameters():
    print(f"{param_name:80} requires_grad={weight.requires_grad}, shape={weight.shape}")

In [None]:
# Re-enable training for encoder layers 10 and 11 (the last two in a 12-layer
# base model; confirm for other checkpoints), the classifier head, and every
# LayerNorm anywhere in the network.
for name, param in model.named_parameters():
    in_top_block = name.startswith((
        "longformer.encoder.layer.10",
        "longformer.encoder.layer.11",
        "classifier",
    ))
    if in_top_block or "LayerNorm" in name:
        param.requires_grad = True

In [None]:
# Dump the parameters a final time to verify which ones will be trained.
print("All model parameters:")
for pname, p in model.named_parameters():
    print(f"{pname:80} requires_grad={p.requires_grad}, shape={p.shape}")

In [None]:
# Hand only the still-trainable parameters to AdamW.
trainable_params = (p for p in model.parameters() if p.requires_grad)
optimizer = torch.optim.AdamW(trainable_params, weight_decay=0.01, lr=2e-5)

In [None]:
import pandas as pd
from tqdm import tqdm
from utils import AcuteAbdominalDiagnosisDataset, get_dataloader
from sklearn.model_selection import train_test_split

df = pd.read_csv("/media/luzhenyang/project/agent_graph_diag/lm_classification/ab_cls_dataset_v2_complete_info.csv")
# Class names in order of first appearance in the CSV. These are used as metric
# target names, where name i labels class id i.
# NOTE(review): confirm AcuteAbdominalDiagnosisDataset maps diagnoses to ids in this same order.
label_list = df['diagnosis'].unique().tolist()
print("label_list: ", label_list)

# Tokenize each context separately, padding/truncating to exactly 4096 tokens.
# (Passing return_tensors='pt' here would give input_ids of shape
#  torch.Size([4, 1, 4096]) after batching, so plain lists are kept.)
texts = df['context'].tolist()
encoded = []
for text in tqdm(texts):
    encoded.append(tokenizer(text, padding='max_length', truncation=True, max_length=4096))

df['input_ids'] = [enc['input_ids'] for enc in encoded]
df['attention_mask'] = [enc['attention_mask'] for enc in encoded]
# Untruncated token count (tokenize() adds no special tokens).
df['input_length'] = df['context'].map(lambda text: len(tokenizer.tokenize(text)))

print(df.shape)

# Stratified 80/20 train/validation split.
train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['diagnosis'], random_state=7)

print(train_df['diagnosis'].value_counts())
print(val_df['diagnosis'].value_counts())

# The datasets read columns such as ['context', 'diagnosis'] (defined in utils).
# use_weighted is enabled only for the training loader; see utils.get_dataloader.
train_dataset = AcuteAbdominalDiagnosisDataset(train_df)
val_dataset = AcuteAbdominalDiagnosisDataset(val_df)
train_loader = get_dataloader(train_dataset, use_weighted=True)
val_loader = get_dataloader(val_dataset, use_weighted=False)

In [None]:
# ---- 在优化器之后，添加 Scheduler 设置 ----
from transformers import get_linear_schedule_with_warmup

# Training hyper-parameters.
# NOTE(review): MAX_LENGTH and BATCH_SIZE are not passed to the tokenizer or
# get_dataloader calls above; keep them in sync by hand.
NUM_EPOCHS = 3
BATCH_SIZE = 4
MAX_LENGTH = 4096

# One optimizer/scheduler step per batch; the first 10% of steps are warmup.
num_training_steps = NUM_EPOCHS * len(train_loader)
num_warmup_steps = int(0.1 * num_training_steps)
print(num_warmup_steps)

In [None]:
# Hugging Face linear schedule: warm up linearly, then decay linearly to 0.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

In [None]:
from utils import *

# Select the GPU only when CUDA exists. The original unconditional
# torch.cuda.set_device(2) crashed on CPU-only hosts, so the CPU fallback
# was never reachable.
GPU_INDEX = 2  # GPU used for these experiments; change to match the host machine
if torch.cuda.is_available():
    torch.cuda.set_device(GPU_INDEX)
    device = torch.device("cuda")  # un-indexed "cuda" means the current device set above
else:
    device = torch.device("cpu")
seed_torch(device=device)
model.to(device)

# ---------- Training function: one epoch, scheduler stepped after every batch ----------
def train_epoch(model, loader, optimizer, scheduler, device=None):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Args:
        model: sequence-classification model. It is called with ``labels=`` and
            its output's ``.loss`` is backpropagated.
        loader: iterable of batches, each a mapping with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        optimizer: stepped once per batch.
        scheduler: learning-rate scheduler, stepped once per batch after the optimizer.
        device: where batches are moved. Defaults to the device of the model's
            first parameter, so batches always land next to the model.

    Returns:
        Mean training loss over the batches seen, or 0.0 if ``loader`` yields nothing.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in tqdm(loader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Global attention on the first token (<s>/CLS) only.
        global_attention_mask = torch.zeros_like(input_ids)
        global_attention_mask[:, 0] = 1

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()  # update the learning rate
        optimizer.zero_grad()
        total_loss += loss.item()
        num_batches += 1
    # Counting batches avoids ZeroDivisionError on an empty loader and works for
    # iterables without __len__.
    return total_loss / num_batches if num_batches else 0.0


In [None]:
import os
from tqdm import tqdm
from datetime import datetime
from sklearn.metrics import classification_report

log_dir = "/media/luzhenyang/project/agent_graph_diag/lm_classification/training_logs_longformer"
os.makedirs(log_dir, exist_ok=True)

# One text log and one CSV log per run, both named after the run's start time.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_txt_path, log_csv_path = (
    os.path.join(log_dir, f"log_{timestamp}.{ext}") for ext in ("txt", "csv")
)

# Start the CSV log with its header row; rows are appended once per epoch.
csv_headers = ["epoch", "train_loss", "micro_f1", "macro_f1", "accuracy"]
with open(log_csv_path, "w") as csv_file:
    csv_file.write(",".join(csv_headers) + "\n")


from sklearn.metrics import f1_score, accuracy_score, confusion_matrix

def per_class_accuracy(y_true, y_pred, class_names):
    """Return the one-vs-rest accuracy of every class.

    For class ``i`` (integer label ``i``, named ``class_names[i]``), the task is
    collapsed to "i vs. not i". The value is the fraction of all samples
    classified correctly under that binary view, (TP + TN) / N. Note that this
    is not per-class recall.

    Returns:
        dict mapping each class name to its one-vs-rest accuracy.
    """
    cm = confusion_matrix(y_true, y_pred, labels=range(len(class_names)))
    total = cm.sum()
    true_counts = cm.sum(axis=1)  # row sums: samples whose true label is i
    pred_counts = cm.sum(axis=0)  # column sums: samples predicted as i
    hits = cm.diagonal()
    # TP + TN = N - FN - FP, with FN = row_i - TP and FP = col_i - TP.
    return {
        cls: (total - (true_counts[i] - hits[i]) - (pred_counts[i] - hits[i])) / total
        for i, cls in enumerate(class_names)
    }

def log_metrics(epoch, train_loss, y_true, y_pred, log_txt_path, log_csv_path, class_names=None):
    """Compute one epoch's validation metrics, print a summary and append it to both logs.

    Args:
        epoch: zero-based epoch index (logged as ``epoch + 1``).
        train_loss: mean training loss of the epoch.
        y_true, y_pred: integer class ids; id ``i`` corresponds to ``class_names[i]``.
        log_txt_path: human-readable log file, appended to.
        log_csv_path: CSV log whose header is written before training; one row is appended.
        class_names: display names of the classes. Defaults to the global ``label_list``.
    """
    if class_names is None:
        class_names = label_list
    # Explicit ids keep the report aligned with class_names even when a class is
    # absent from both y_true and y_pred. Without them, sklearn raises a
    # "does not match size of target_names" ValueError.
    label_ids = list(range(len(class_names)))

    micro_f1 = f1_score(y_true, y_pred, average='micro')
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)
    # One-vs-rest accuracy of each class.
    per_class_acc = per_class_accuracy(y_true, y_pred, class_names)
    report = classification_report(
        y_true, y_pred, labels=label_ids, target_names=class_names, digits=4
    )

    # Console summary, so progress is visible without opening the log files.
    print(f"Accuracy: {acc:.4f} | Micro F1: {micro_f1:.4f} | Macro F1: {macro_f1:.4f}")

    # Text log
    with open(log_txt_path, "a", encoding="utf-8") as f:
        f.write(f"\nEpoch {epoch+1}\n")
        f.write(f"Train Loss: {train_loss:.4f}\n")
        f.write(f"Accuracy: {acc:.4f}\n")
        f.write(f"Micro F1: {micro_f1:.4f} | Macro F1: {macro_f1:.4f}\n")
        for cls, acc_cls in per_class_acc.items():
            f.write(f"Accuracy for {cls}: {acc_cls:.4f}\n")
        f.write(f"{report}\n")

    # CSV log
    with open(log_csv_path, "a", encoding="utf-8") as f:
        f.write(f"{epoch+1},{train_loss:.4f},{micro_f1:.4f},{macro_f1:.4f},{acc:.4f}\n")

In [None]:
# ---------- Validation function ----------
def evaluate(model, loader, device=None):
    """Run ``model`` over ``loader`` without gradients and return a classification report.

    Args:
        model: sequence-classification model whose output has ``.logits``.
        loader: iterable of batches with ``'input_ids'``, ``'attention_mask'``
            and ``'labels'`` tensors.
        device: where batches are moved. Defaults to the model's parameter device.

    Returns:
        The sklearn ``classification_report`` string, with class id ``i``
        reported under the global ``label_list[i]``.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # Same global-attention pattern as training: first token only.
            global_attention_mask = torch.zeros_like(input_ids)
            global_attention_mask[:, 0] = 1

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                global_attention_mask=global_attention_mask,
            )
            preds = torch.argmax(outputs.logits, dim=-1)

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(batch['labels'].cpu().tolist())

    # Explicit labels= avoids sklearn's target_names size-mismatch ValueError
    # when a class is missing from both predictions and ground truth.
    return classification_report(
        all_labels,
        all_preds,
        labels=list(range(len(label_list))),
        target_names=label_list,
        digits=4,
    )

In [None]:
# ---------- Main training loop ----------
for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    train_loss = train_epoch(model, train_loader, optimizer, scheduler)
    print(f"Train loss: {train_loss:.4f}")

    # Validate and collect predictions and labels.
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Pass the same global-attention mask used in training (first token
            # only) explicitly, rather than relying on whatever default the
            # model applies when it is omitted.
            global_attention_mask = torch.zeros_like(input_ids)
            global_attention_mask[:, 0] = 1

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                global_attention_mask=global_attention_mask,
            )
            preds = torch.argmax(outputs.logits, dim=-1)

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Print a summary and write the logs.
    log_metrics(epoch, train_loss, all_labels, all_preds, log_txt_path, log_csv_path)

    # Save the checkpoint next to this run's logs. The run timestamp keeps
    # reruns from overwriting earlier checkpoints.
    save_path = os.path.join(log_dir, f"checkpoint_{timestamp}_epoch{epoch+1}.pt")
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")


### 对齐AGAP的实验设置，验证集80例

In [None]:
from utils import *

# ---------- Training function: one epoch, scheduler stepped after every batch ----------
def train_epoch(model, loader, optimizer, scheduler, device=None):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Args:
        model: sequence-classification model. It is called with ``labels=`` and
            its output's ``.loss`` is backpropagated.
        loader: iterable of batches, each a mapping with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        optimizer: stepped once per batch.
        scheduler: learning-rate scheduler, stepped once per batch after the optimizer.
        device: where batches are moved. Defaults to the device of the model's
            first parameter, so batches follow the model even when it was placed
            by ``setup_model_and_optimizer``.

    Returns:
        Mean training loss over the batches seen, or 0.0 if ``loader`` yields nothing.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in tqdm(loader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Global attention on the first token (<s>/CLS) only.
        global_attention_mask = torch.zeros_like(input_ids)
        global_attention_mask[:, 0] = 1

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()  # update the learning rate
        optimizer.zero_grad()
        total_loss += loss.item()
        num_batches += 1
    # Counting batches avoids ZeroDivisionError on an empty loader.
    return total_loss / num_batches if num_batches else 0.0


In [None]:
import os
from tqdm import tqdm
from datetime import datetime
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix

# Separate log/checkpoint directory for the multi-seed experiments.
log_dir = "/media/luzhenyang/project/agent_graph_diag/lm_classification/training_logs_longformer_random_seeds"
if not os.path.isdir(log_dir):
    os.makedirs(log_dir, exist_ok=True)




def per_class_accuracy(y_true, y_pred, class_names):
    """Return ``{class_name: one-vs-rest accuracy}`` for every class.

    Class ``i`` is integer label ``i``. Its score is (TP + TN) / N, the accuracy
    of the binary "class i vs. everything else" view of the predictions. This is
    not per-class recall.
    """
    matrix = confusion_matrix(y_true, y_pred, labels=range(len(class_names)))
    n_samples = matrix.sum()
    scores = {}
    for idx, cls_name in enumerate(class_names):
        tp = matrix[idx, idx]
        fp = matrix[:, idx].sum() - tp  # predicted idx, truly something else
        fn = matrix[idx, :].sum() - tp  # truly idx, predicted something else
        tn = n_samples - tp - fp - fn
        scores[cls_name] = (tp + tn) / n_samples
    return scores

def log_metrics(epoch, train_loss, y_true, y_pred, log_txt_path, log_csv_path, class_names=None):
    """Compute one epoch's validation metrics, print a summary and append it to both logs.

    Args:
        epoch: zero-based epoch index (logged as ``epoch + 1``).
        train_loss: mean training loss of the epoch.
        y_true, y_pred: integer class ids; id ``i`` corresponds to ``class_names[i]``.
        log_txt_path: human-readable log file, appended to.
        log_csv_path: CSV log whose header is written before training; one row is appended.
        class_names: display names of the classes. Defaults to the global ``label_list``.
    """
    if class_names is None:
        class_names = label_list
    # Explicit ids keep the report aligned with class_names even when a class is
    # absent from both y_true and y_pred. Without them, sklearn raises a
    # "does not match size of target_names" ValueError. This matters for small
    # validation sets.
    label_ids = list(range(len(class_names)))

    micro_f1 = f1_score(y_true, y_pred, average='micro')
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)
    # One-vs-rest accuracy of each class.
    per_class_acc = per_class_accuracy(y_true, y_pred, class_names)
    report = classification_report(
        y_true, y_pred, labels=label_ids, target_names=class_names, digits=4
    )

    # Console summary, so progress is visible without opening the log files.
    print(f"Accuracy: {acc:.4f} | Micro F1: {micro_f1:.4f} | Macro F1: {macro_f1:.4f}")

    # Text log
    with open(log_txt_path, "a", encoding="utf-8") as f:
        f.write(f"\nEpoch {epoch+1}\n")
        f.write(f"Train Loss: {train_loss:.4f}\n")
        f.write(f"Accuracy: {acc:.4f}\n")
        f.write(f"Micro F1: {micro_f1:.4f} | Macro F1: {macro_f1:.4f}\n")
        for cls, acc_cls in per_class_acc.items():
            f.write(f"Accuracy for {cls}: {acc_cls:.4f}\n")
        f.write(f"{report}\n")

    # CSV log
    with open(log_csv_path, "a", encoding="utf-8") as f:
        f.write(f"{epoch+1},{train_loss:.4f},{micro_f1:.4f},{macro_f1:.4f},{acc:.4f}\n")

In [None]:
# ---------- Validation function ----------
def evaluate(model, loader, device=None):
    """Run ``model`` over ``loader`` without gradients and return a classification report.

    Args:
        model: sequence-classification model whose output has ``.logits``.
        loader: iterable of batches with ``'input_ids'``, ``'attention_mask'``
            and ``'labels'`` tensors.
        device: where batches are moved. Defaults to the model's parameter device.

    Returns:
        The sklearn ``classification_report`` string, with class id ``i``
        reported under the global ``label_list[i]``.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # Same global-attention pattern as training: first token only.
            global_attention_mask = torch.zeros_like(input_ids)
            global_attention_mask[:, 0] = 1

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                global_attention_mask=global_attention_mask,
            )
            preds = torch.argmax(outputs.logits, dim=-1)

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(batch['labels'].cpu().tolist())

    # Explicit labels= avoids sklearn's target_names size-mismatch ValueError
    # when a class is missing from both predictions and ground truth.
    return classification_report(
        all_labels,
        all_preds,
        labels=list(range(len(label_list))),
        target_names=label_list,
        digits=4,
    )

### 对齐ADW的实验设置，验证集80例

In [None]:
import torch
from torch import nn
from transformers import (
    LongformerForSequenceClassification,
    LongformerConfig,
    LongformerTokenizer,
    get_linear_schedule_with_warmup
)


def setup_model_and_optimizer(train_loader,
                              model_name="yikuan8/Clinical-Longformer",
                              num_labels=4,
                              lr=2e-5,
                              weight_decay=0.01,
                              num_epochs=3,
                              freeze_except_last=True,
                              device=None,
                              warmup_ratio=0.1,
                              trainable_prefixes=("longformer.encoder.layer.10",
                                                  "longformer.encoder.layer.11",
                                                  "classifier")):
    """Build a Longformer classifier, its AdamW optimizer and a linear-warmup scheduler.

    Args:
        train_loader: training loader. Only ``len(train_loader)`` is used, to size
            the schedule (one scheduler step per batch).
        model_name: Hugging Face checkpoint for tokenizer, config and weights.
        num_labels: output size of the classification head.
        lr, weight_decay: AdamW hyper-parameters.
        num_epochs: number of epochs the LR schedule is planned for.
        freeze_except_last: if True, only parameters whose name starts with one
            of ``trainable_prefixes`` or contains ``"LayerNorm"`` stay trainable.
        device: target device. Defaults to CUDA when available, else CPU.
        warmup_ratio: fraction of the total steps used for linear warmup.
        trainable_prefixes: parameter-name prefixes kept trainable when
            ``freeze_except_last`` is True. The defaults select encoder layers
            10-11 and the classifier head.

    Returns:
        ``(model, tokenizer, optimizer, scheduler)``, with the model already on ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Tokenizer and config (head size set through num_labels).
    tokenizer = LongformerTokenizer.from_pretrained(model_name)
    config = LongformerConfig.from_pretrained(model_name)
    config.num_labels = num_labels

    # 2. Model.
    model = LongformerForSequenceClassification.from_pretrained(model_name, config=config)

    # 3. Optional partial freezing. A single pass is equivalent to
    #    "freeze everything, then unfreeze the selected parameters".
    if freeze_except_last:
        prefixes = tuple(trainable_prefixes)
        for name, param in model.named_parameters():
            param.requires_grad = name.startswith(prefixes) or "LayerNorm" in name

    model.to(device)

    # 4. Optimizer over the trainable parameters only. A list (not a one-shot
    #    filter) so it can also be reused for the parameter count below.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=weight_decay)

    # 5. Linear warmup followed by linear decay, stepped once per batch.
    num_training_steps = len(train_loader) * num_epochs
    num_warmup_steps = int(warmup_ratio * num_training_steps)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    print(f"[✓] Loaded model: {model_name}")
    print(f"[✓] Trainable parameters: {sum(p.numel() for p in trainable_params)}")
    print(f"[✓] Total steps: {num_training_steps}, Warmup: {num_warmup_steps}")

    return model, tokenizer, optimizer, scheduler


In [None]:
import pandas as pd
from tqdm import tqdm
from utils import AcuteAbdominalDiagnosisDataset, get_dataloader
from sklearn.model_selection import train_test_split

# Hyper-parameters shared by the multi-seed runs.
MAX_LENGTH = 4096
BATCH_SIZE = 4
NUM_EPOCHS = 3

model_name = "yikuan8/Clinical-Longformer"  # alternative: "allenai/longformer-base-4096"
tokenizer = LongformerTokenizer.from_pretrained(model_name)

df = pd.read_csv("/media/luzhenyang/project/agent_graph_diag/lm_classification/ab_cls_dataset_v2_complete_info.csv")
# Class names in order of first appearance; name i labels class id i in the metrics.
label_list = df['diagnosis'].unique().tolist()
print("label_list: ", label_list)

# Pad/truncate every context to MAX_LENGTH tokens, one call per text. Using
# return_tensors='pt' here would yield input_ids of shape [4, 1, 4096] after batching.
encoded = [
    tokenizer(context, truncation=True, padding='max_length', max_length=MAX_LENGTH)
    for context in tqdm(df['context'].tolist())
]
for column in ('input_ids', 'attention_mask'):
    df[column] = [enc[column] for enc in encoded]
# Untruncated token count (tokenize() adds no special tokens).
df['input_length'] = df['context'].apply(lambda x: len(tokenizer.tokenize(x)))

print(df.shape)


# Split train/validation using the fixed validation-ID list of each random seed.
ran_seeds = [1, 4, 7, 9, 10, 20, 23, 42, 71, 96]

for seed in ran_seeds:
    print(f"\n====== Seed: {seed} ======")
    val_ids = pd.read_csv(f"/media/luzhenyang/project/agent_graph_diag/subset_ids_{seed}.csv")
    in_val = df['hadm_id'].isin(val_ids['hadm_id'].values)
    # .copy() so the datasets receive independent frames rather than views of df.
    val_df = df[in_val].copy()
    train_df = df[~in_val].copy()

    print(train_df['diagnosis'].value_counts())
    print(val_df['diagnosis'].value_counts())

    # The datasets read columns such as ['context', 'diagnosis'] (defined in utils).
    train_dataset = AcuteAbdominalDiagnosisDataset(train_df)
    val_dataset = AcuteAbdominalDiagnosisDataset(val_df)
    train_loader = get_dataloader(train_dataset, use_weighted=True)
    val_loader = get_dataloader(val_dataset, use_weighted=False)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_txt_path = os.path.join(log_dir, f"log_{timestamp}_{seed}.txt")
    log_csv_path = os.path.join(log_dir, f"log_{timestamp}_{seed}.csv")

    csv_headers = ["epoch", "train_loss", "micro_f1", "macro_f1", "accuracy"]
    with open(log_csv_path, "w") as f:
        f.write(",".join(csv_headers) + "\n")

    # ---------- Main training loop ----------
    # Drop the previous run's model/optimizer/scheduler before loading a fresh
    # model, so the old one can be garbage-collected first (assuming nothing
    # else still references it) instead of both sitting on the GPU.
    model = optimizer = scheduler = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Pass NUM_EPOCHS so the LR schedule length always matches the loop below.
    model, tokenizer, optimizer, scheduler = setup_model_and_optimizer(
        train_loader, num_epochs=NUM_EPOCHS
    )
    # Evaluate on whatever device setup_model_and_optimizer placed the model.
    model_device = next(model.parameters()).device

    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}_{seed}")
        train_loss = train_epoch(model, train_loader, optimizer, scheduler)
        print(f"Train loss: {train_loss:.4f}")

        # Validate and collect predictions and labels.
        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Evaluating"):
                input_ids = batch['input_ids'].to(model_device)
                attention_mask = batch['attention_mask'].to(model_device)
                labels = batch['labels'].to(model_device)

                # Same global-attention mask as training (first token only),
                # passed explicitly instead of relying on the model's default.
                global_attention_mask = torch.zeros_like(input_ids)
                global_attention_mask[:, 0] = 1

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    global_attention_mask=global_attention_mask,
                )
                preds = torch.argmax(outputs.logits, dim=-1)

                all_preds.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())

        # Print a summary and write the logs.
        log_metrics(epoch, train_loss, all_labels, all_preds, log_txt_path, log_csv_path)

        # Save the checkpoint next to this seed's logs.
        save_path = f"{log_dir}/checkpoint_epoch{epoch+1}_{seed}.pt"
        torch.save(model.state_dict(), save_path)
        print(f"Model saved to {save_path}")


In [None]:
# Load ab_cls_dataset.csv (presumably an earlier version of the v2 dataset built
# below; confirm) and display its columns as the cell output.
data = pd.read_csv('/media/luzhenyang/project/agent_graph_diag/lm_classification/ab_cls_dataset.csv')
data.columns

In [None]:
# Display (rows, columns) of the table loaded into `data`.
data.shape

In [None]:
# Load the AGAP full-dataset results table; its 'hadm_id' and 'diagnosis'
# columns are the starting point of df_dataset. Display its columns.
inter_data = pd.read_csv('/media/luzhenyang/project/agent_graph_diag/AGAP_full_dataset_results/AGAP_full_dataset_own.csv')
inter_data.columns

In [None]:
import os
import sys
sys.path.append('/media/luzhenyang/project/agent_graph_diag/AGAP')

from extract_patient_info import Template_customized

# Dataset root and the per-admission CSV tables handed to the extractor.
patient_info_path = '/media/luzhenyang/project/datasets/mimic_iv_ext_clinical_decision_abdominal/clinical_decision_making_for_abdominal_pathologies_1.1'
patient_info_file_names = ['history_of_present_illness.csv', 'microbiology.csv', 'laboratory_tests.csv', 'radiology_reports.csv']

patients_info = Template_customized(file_names=patient_info_file_names, base_path=patient_info_path)

In [None]:
# Start the new dataset from the admission ids and diagnosis labels of
# inter_data (a copy, so later column additions do not touch inter_data).
df_dataset = inter_data[['hadm_id', 'diagnosis']].copy()
df_dataset.shape

In [None]:
def extract_all(row):
    """Extract the four text sections of one admission.

    Args:
        row: a df_dataset row; only ``row['hadm_id']`` is read.

    Returns:
        pd.Series with keys 'hpi', 'pe', 'lab', 'ima', produced by the matching
        ``patients_info`` extractors (extract_hpi, extract_pe,
        laboratory_test_mapping_v2_llm, extract_rr) for that admission.
    """
    hadm_id = row['hadm_id']
    return pd.Series({
        'hpi': patients_info.extract_hpi(hadm_id),
        'pe': patients_info.extract_pe(hadm_id),
        'lab': patients_info.laboratory_test_mapping_v2_llm(hadm_id),
        'ima': patients_info.extract_rr(hadm_id),
    })


# Row-wise extraction. The returned Series become four new columns.
df_dataset[['hpi', 'pe', 'lab', 'ima']] = df_dataset.apply(extract_all, axis=1)

In [None]:
# Display (rows, columns) after the section columns were added.
df_dataset.shape

In [None]:
# Inspect the lab section of the first admission.
print(df_dataset['lab'].iat[0])

In [None]:
# Join the four sections of each admission into one newline-separated text.
cols_to_concat = ['hpi', 'pe', 'lab', 'ima']

# fillna('') first: otherwise astype(str) turns missing sections (None/NaN)
# into the literal words "None"/"nan" inside the model's input text.
df_dataset['context'] = (
    df_dataset[cols_to_concat].fillna('').astype(str).agg('\n'.join, axis=1)
)

In [None]:
# Inspect the assembled context of the first admission.
print(df_dataset['context'].iat[0])

## 统计平均tokens数

In [None]:
import torch
from torch import nn
from transformers import LongformerForSequenceClassification, LongformerConfig, LongformerTokenizer

# Use the same tokenizer as the training cells, so the counts reflect what the
# model actually sees (this cell previously used "allenai/longformer-base-4096").
model_name = "yikuan8/Clinical-Longformer"
tokenizer = LongformerTokenizer.from_pretrained(model_name)

MAX_INPUT_TOKENS = 4096  # max_length used when tokenizing for training

# Untruncated token count per context (tokenize() adds no special tokens).
df_dataset['token_count'] = df_dataset['context'].apply(lambda x: len(tokenizer.tokenize(x)))

avg_len = df_dataset['token_count'].mean()
max_len = df_dataset['token_count'].max()
percentiles = df_dataset['token_count'].quantile([0.5, 0.9, 0.95, 0.99])

# Text tokens that fit once the tokenizer's special tokens are added; longer
# contexts get truncated at training time.
text_budget = MAX_INPUT_TOKENS - tokenizer.num_special_tokens_to_add(pair=False)
n_truncated = int((df_dataset['token_count'] > text_budget).sum())

print(f"平均 token 数量: {avg_len:.2f}")
print(f"最大 token 数量: {max_len}")
print("分布分位数:")
print(percentiles)
print(f"Truncated at {MAX_INPUT_TOKENS} tokens: {n_truncated} / {len(df_dataset)}")


In [None]:
# Persist the assembled dataset. The training cells read this same path back.
df_dataset.to_csv(path_or_buf='/media/luzhenyang/project/agent_graph_diag/lm_classification/ab_cls_dataset_v2_complete_info.csv', index=False)