In [None]:
# Environment switches are assigned first so they are already in place when
# torch and the Hugging Face libraries are imported further down.
import os
# Expose only GPU index 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Put the Hugging Face datasets / hub clients in offline mode.
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"
# CUDA caching-allocator option for PyTorch.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# NOTE(review): `random` and `math` are not referenced by any cell visible in
# this notebook — confirm whether they are still needed.
import pandas as pd
import numpy as np
import random
import math
import ast

import torch
from torch.utils.data import DataLoader, TensorDataset
# Project-local helpers: ECG processing, the ECG tokenizer and the prompt templates.
import utils.my_ecg_process as ecg
from utils.my_tokenizer import Tokenizer
from utils.my_templates import Sentences, Choices, Reports, Predict

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, set_seed
from safetensors.torch import load_file
from peft import LoraConfig, TaskType
from trl import SFTTrainer
from datasets import Dataset

# Fix the random seed once, before any model or data object is created.
set_seed(42)

In [None]:
# Report which accelerator this kernel can see.
# The per-device queries are guarded: `torch.cuda.current_device()` raises when
# no CUDA device is usable, while the rest of the notebook explicitly supports
# falling back to CPU.
print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
print(f"PyTorch sees {torch.cuda.device_count()} GPU(s)")
if torch.cuda.is_available():
    current_idx = torch.cuda.current_device()
    print(f"Current device index: {current_idx}")
    print(f"Device name: {torch.cuda.get_device_name(current_idx)}")
else:
    print("No CUDA device available; the notebook will run on CPU")

In [None]:
# ---- ECG signal / tokenizer hyper-parameters ----
sampling_rate = 500
batch_size = 32          # batch size of the DataLoader used for ECG tokenization
seq_length = 500         # number of samples (last axis) cropped as tokenizer input
patch_size = 25
latent_ratio = 0.5       # extra fraction of seq_length cropped after the input as `predict`
channels = 12
codebook_size = 256      # also the number of <|ecg_i|> tokens added to the text vocabulary
residual_levels = 2

# ---- Runtime / paths ----
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Named `tokenizer_tag` rather than `dir` so the builtin `dir()` stays usable
# in the kernel after this cell has run.
tokenizer_tag = f"level_{residual_levels}_code_{codebook_size}_len_{seq_length}_ratio_{latent_ratio}"
ecg_tokenizer_path = f"tokenizer/ecg_tokenizer_{tokenizer_tag}.pth"
model_path = 'phi-3'
num = 10                 # run identifier used in the output directories below
trainer_path = f"training/{num}"
merged_model_path = f"phi-3-ecg/{num}"

In [None]:
# Build the ECG tokenizer from the configuration constants and move it to the
# target device.
vq_kwargs = {'residual_levels': residual_levels}
ecg_tokenizer = Tokenizer(
    seq_length=seq_length,
    patch_size=patch_size,
    latent_ratio=latent_ratio,
    channels=channels,
    codebook_size=codebook_size,
    vq_kwargs=vq_kwargs
).to(device).eval()
# `.eval()`: the tokenizer is only used for inference in this notebook, so any
# train-only behaviour it may contain is switched off before tokenization.
# (Assumes Tokenizer is a torch.nn.Module, as its .to()/.state_dict() usage suggests.)

# NOTE(review): loading the trained checkpoint is commented out, so the
# tokenizer keeps whatever weights its constructor initialised — confirm this
# is intentional before trusting the produced tokens.
# ecg_weights = torch.load(ecg_tokenizer_path, weights_only=False)
# ecg_tokenizer.load_state_dict(ecg_weights['model_state_dict'])

In [None]:
# Snapshot the tokenizer's current weights to disk so the exact tokenizer used
# for this run can be restored later.
tok = dict(model_state_dict=ecg_tokenizer.state_dict())
torch.save(tok, "tokenizer/tok_2.pth")

In [None]:
# Load the model configuration and the text tokenizer from the local checkpoint
# directory, without executing any custom code shipped with the checkpoint.
config = AutoConfig.from_pretrained(model_path, trust_remote_code=False)
text_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False)

# Pad on the left and reuse the unknown token as the padding token.
text_tokenizer.padding_side = 'left'
text_tokenizer.pad_token = text_tokenizer.unk_token
text_tokenizer.pad_token_id = text_tokenizer.convert_tokens_to_ids(text_tokenizer.pad_token)

In [None]:
# Instantiate the base causal LM in half precision with SDPA attention
# (device_map="auto" is intentionally not used), then move it to `device`.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    config=config,
    trust_remote_code=False,
    torch_dtype=torch.float16,
    attn_implementation='sdpa',
)
model = model.to(device)

# Read both safetensors shards and push their tensors into the model again.
# NOTE(review): from_pretrained() was already pointed at `model_path`, so this
# explicit reload looks redundant; with strict=False, key mismatches do not
# raise and are only visible in the value this cell displays — confirm.
weight1 = load_file(f"{model_path}/model-00001-of-00002.safetensors")
weight2 = load_file(f"{model_path}/model-00002-of-00002.safetensors")

state_dict = dict(weight1)
state_dict.update(weight2)
model.load_state_dict(state_dict, strict=False)

In [None]:
# Trainer hyper-parameters, grouped by concern.
args = TrainingArguments(
    # -- outputs and logging --
    output_dir=trainer_path,
    logging_dir=f"runs/{trainer_path}",
    report_to="tensorboard",
    log_level="warning",
    logging_steps=100,
    # -- checkpointing: every 500 steps, keep at most 2 --
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,
    # -- evaluation: same cadence as checkpointing --
    do_eval=True,
    eval_strategy="steps",
    eval_steps=500,
    per_device_eval_batch_size=4,
    # -- optimisation: 4 samples x 8 accumulation steps per optimizer update --
    optim="adamw_torch",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    num_train_epochs=1,
    # -- precision --
    fp16=True,
    bf16=False,
    seed=42,
)

# LoRA adapters on the attention and MLP projection layers.
# NOTE(review): the embedding and output layers are not listed here, while the
# notebook later adds new tokens and resizes the embeddings — confirm that the
# new token rows are meant to stay untrained by the adapter.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=64,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
)

In [None]:
# Load the raw array and crop a centred window of `total_length` samples along
# axis 2: the first `seq_length` samples become `signal`, the rest `predict`.
# NOTE(review): this cell is repeated verbatim further down the notebook.
X = np.load("data/full_data.npy")
total_length = int(seq_length * (1 + latent_ratio))
start = int((X.shape[2] - total_length) / 2)
signal_end = start + seq_length
window_end = start + total_length
signal = X[:, :, start:signal_end]
predict = X[:, :, signal_end:window_end]

In [None]:
# Wrap the cropped signals for batched tokenization.
# shuffle stays False on purpose: the tokens produced from `df_loader` are
# attached positionally to `combined_df` further down, so batches must come out
# in the same order as the rows of `signal`. Shuffling would pair each record's
# metadata with another record's ECG tokens.
df_dataset = TensorDataset(torch.tensor(signal, dtype=torch.float32))
df_loader = DataLoader(df_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

In [None]:
# Tokenize every batch without tracking gradients and collect the codes as
# plain Python lists.
# NOTE(review): an identical loop further down rebuilds `tokens` from scratch
# before it is used, so this pass appears to be redundant work; `new` is never
# filled in the visible cells.
tokens = []
new = []
with torch.no_grad():
    for batch in df_loader:
        batch_codes = ecg_tokenizer.tokenize(batch[0].to(device))
        tokens += batch_codes.cpu().detach().tolist()

In [None]:
# Load the per-record table. Later cells read its `age`, `sex`, `diagnostic`,
# `form`, `rhythm` and `strat_fold` columns and add a `tokens` column to it.
combined_df = pd.read_parquet("data/full_data.parquet")

In [None]:
# Load the raw array and crop a centred window of `total_length` samples along
# axis 2: the first `seq_length` samples become `signal`, the rest `predict`.
# NOTE(review): identical to an earlier cell; X / signal / predict are simply
# recomputed here.
X = np.load("data/full_data.npy")
total_length = int(seq_length * (1 + latent_ratio))
start = int((X.shape[2] - total_length) / 2)
signal_end = start + seq_length
window_end = start + total_length
signal = X[:, :, start:signal_end]
predict = X[:, :, signal_end:window_end]

In [None]:
# Wrap the cropped signals for batched tokenization.
# shuffle must be False: the tokens collected from `df_loader` are written to
# `combined_df["tokens"]` by position, so batches have to come out in the same
# order as the rows of `signal`. With shuffling enabled every record would be
# paired with the ECG tokens of some other, random record.
df_dataset = TensorDataset(torch.tensor(signal, dtype=torch.float32))
df_loader = DataLoader(df_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

In [None]:
# Tokenize every batch without tracking gradients and collect the codes as
# plain Python lists (`new` is never filled in the visible cells).
tokens = []
new = []
with torch.no_grad():
    for batch in df_loader:
        batch_codes = ecg_tokenizer.tokenize(batch[0].to(device))
        tokens += batch_codes.cpu().detach().tolist()

# Positional assignment: this is only correct when `df_loader` yields samples
# in the same order as the rows of `combined_df`.
combined_df["tokens"] = tokens

In [None]:
# Extend the text vocabulary with the section delimiters plus one token per
# codebook entry (numbered from 1), then grow the model's embedding table to
# the new vocabulary size.
delimiter_tokens = [
    "<|ecg_start|>", "<|ecg_end|>",
    "<|report_start|>", "<|report_end|>",
    "<|pred_start|>", "<|pred_end|>",
]
code_tokens = [f"<|ecg_{i+1}|>" for i in range(codebook_size)]
new_tokens = delimiter_tokens + code_tokens
text_tokenizer.add_tokens(new_tokens)
model.resize_token_embeddings(len(text_tokenizer))

In [None]:
def message_form(template, tokenizer=None):
    """Render one prompt template as a single chat-formatted training sample.

    Args:
        template: Indexable with at least three string items. Items 0 and 1
            are joined by a newline to form the user turn; item 2 becomes the
            assistant turn.
        tokenizer: Object exposing ``apply_chat_template``. When omitted, the
            notebook-level ``text_tokenizer`` is used, so existing
            single-argument calls behave as before.

    Returns:
        A dict with the single key ``"text"`` holding the chat-template string
        (not tokenized, no generation prompt appended).
    """
    if tokenizer is None:
        # Resolved at call time so the function can be defined before, and
        # tested without, the notebook-level tokenizer.
        tokenizer = text_tokenizer

    conversation = [
        {"content": f"{template[0]}\n{template[1]}", "role": "user"},
        {"content": template[2], "role": "assistant"},
    ]
    rendered = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=False, tokenize=False
    )
    return {"text": rendered}

In [None]:
# Columns of the metadata table that hold the label dictionaries as strings.
symptoms = ["diagnostic", "form", "rhythm"]

def create_message_column(row):
    """Build every chat training sample for one metadata row.

    Reads ``row.tokens``, ``row.age``, ``row.sex`` and the three columns named
    in ``symptoms`` (each parsed with ``ast.literal_eval``). Only the first
    code of each token position is used; it is shifted by one to match the
    ``<|ecg_N|>`` vocabulary, which is numbered from 1.

    Returns:
        A list of ``{"text": ...}`` dicts, one per template produced by the
        Sentences, Choices, Reports and Predict template classes, in that order.
    """
    ecg_body = " ".join(f"<|ecg_{codes[0]+1}|>" for codes in row.tokens)
    ecg_input = "<|ecg_start|> " + ecg_body + " <|ecg_end|>"
    age = int(row.age)
    sex = row.sex

    symptom_dict_list = [ast.literal_eval(row[column]) for column in symptoms]
    templates = (
        Sentences.get_template(age, sex, ecg_input, symptom_dict_list)
        + Choices.get_template(age, sex, ecg_input, symptom_dict_list)
        + Reports.get_template(age, sex, ecg_input, symptom_dict_list)
        + Predict.get_template(age, sex, ecg_input)
    )
    return [message_form(template) for template in templates]

In [None]:
# Row-wise apply: each element of `dataset_chatml` is the list of chat samples
# that create_message_column() built for the corresponding row of combined_df.
dataset_chatml = combined_df.apply(create_message_column, axis=1)

In [None]:
# Hold out one stratified fold for validation and one for testing; every other
# fold is used for training.
valid_fold = 9
test_fold = 10

fold_ids = combined_df.strat_fold
in_valid = fold_ids == valid_fold
in_test = fold_ids == test_fold

train_list = dataset_chatml[~(in_valid | in_test)]
valid_list = dataset_chatml[in_valid]
test_list = dataset_chatml[in_test]


def _flatten(per_row_samples):
    """Concatenate the per-row sample lists into one flat list."""
    return [item for sublist in per_row_samples for item in sublist]


# Only the training split is shuffled (fixed seed).
train_dataset = Dataset.from_list(_flatten(train_list)).shuffle(seed=42)
valid_dataset = Dataset.from_list(_flatten(valid_list))
test_dataset = Dataset.from_list(_flatten(test_list))

In [None]:
# Supervised fine-tuning trainer: base model + LoRA config, trained on the
# chat-formatted train split and evaluated on the validation split.
trainer = SFTTrainer(
    model=model,
    processing_class=text_tokenizer,
    peft_config=peft_config,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
)

In [None]:
# Run fine-tuning with the arguments configured above, then write the trained
# model to the trainer's output directory (a later cell reloads it from
# `args.output_dir`).
trainer.train()
trainer.save_model()

In [None]:
# Drop the training objects and release as much host/GPU memory as possible
# before the adapter is reloaded and merged.
import gc

# pop() rather than `del`: re-running this cell (or running it after a failed
# training cell) must not raise NameError when the names are already gone.
for _name in ("model", "trainer"):
    globals().pop(_name, None)

# Collect garbage before and after emptying the CUDA cache so tensors that were
# only kept alive by reference cycles are returned as well.
gc.collect()
gc.collect()
torch.cuda.empty_cache()
gc.collect()

if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    device = torch.device('cuda')
    torch.cuda.reset_peak_memory_stats(device)
    print("CUDA 设备已重置")
else:
    print("CUDA 不可用")

gc.collect()

In [None]:
from peft import AutoPeftModelForCausalLM

# Reload the trained adapter from the trainer output directory.
# trust_remote_code is False here as well, matching every load used during
# training, so the merged model is built from the same model implementation
# that was fine-tuned and no checkpoint-supplied code is executed.
new_model = AutoPeftModelForCausalLM.from_pretrained(
    args.output_dir,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    trust_remote_code=False
)

# Fold the adapter weights into the base model and export a standalone
# checkpoint (safetensors) together with the extended tokenizer.
merged_model = new_model.merge_and_unload()
merged_model.save_pretrained(merged_model_path, safe_serialization=True)
text_tokenizer.save_pretrained(merged_model_path)