In [1]:
import torch as t


In [2]:
a = t.cuda.device_count()
a

1

In [3]:
# !pip install datasets rouge_score torchmetrics 

In [4]:
from datasets import *
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForLanguageModeling, Seq2SeqTrainer, Seq2SeqTrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from datasets import load_dataset
import torch
from torch.utils.data import Dataset

class MedicalDialogueDataset(Dataset):
    def __init__(self, split, percent=100, seed=42):  # 添加seed参数，默认值为42
        # 加载数据集
        ds = load_dataset("omi-health/medical-dialogue-to-soap-summary", split=split)
        
        # 移除不需要的列
        columns_to_remove = ['messages', 'prompt']
        ds = ds.remove_columns(columns_to_remove)
        
        # 替换换行符并重命名列
        ds = ds.rename_column('soap', 'summary')
        
        # 添加ID和格式化摘要
        ds = ds.map(self.add_id, with_indices=True)
        ds = ds.map(self.format_summary)
        
        # 如果需要子集，先随机打乱，再选择对应的百分比
        if percent < 100:
            ds = ds.shuffle(seed=seed).select(range(int(percent / 100.0 * len(ds))))
        
        self.data = ds
    
    def add_id(self, example, idx):
        example['id'] = str(idx)
        return example
    
    def format_summary(self, example):
        example['summary'] = example['summary'].replace('S: ', 'Subjective: ')
        example['summary'] = example['summary'].replace('O: ', 'Objective: ')
        example['summary'] = example['summary'].replace('A: ', 'Assessment: ')
        example['summary'] = example['summary'].replace('P: ', 'Plan: ')
        return example

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]  # 获取索引对应的数据项
        ordered_item = {'id': item['id']}  # 创建一个新字典，并首先加入'id'
        ordered_item.update({k: item[k] for k in item if k != 'id'})  # 添加其他字段，排除'id'
        # return item['dialogue'], item['summary'] 
        return ordered_item

# 创建不同百分比的训练数据集实例
# train_data = MedicalDialogueDataset('train', percent=20, seed=42)
train_data = MedicalDialogueDataset('train', percent=50, seed=42)
# train_data = MedicalDialogueDataset('train', percent=100, seed=42)  

valid_data = MedicalDialogueDataset('validation')
test_data = MedicalDialogueDataset('test')


In [6]:
# # 示例初始化
# train_data = MedicalDialogueDataset('train')
# valid_data = MedicalDialogueDataset('validation')
# test_data = MedicalDialogueDataset('test')

In [7]:
print(f'train set size: {len(train_data)}')
print(f'valid set size: {len(valid_data)}')
print(f'test set size: {len(test_data)}')
print(next(iter(train_data)))

train set size: 4625
valid set size: 500
test set size: 250
{'id': '8647', 'dialogue': "Doctor: Good morning, how can I help you today?\nPatient: Hi doctor, I recently underwent an abdominal ultrasonography (USG) for my bilateral renal nephrolithiasis.\nDoctor: I see. Tell me about your general health. How is your blood biochemistry, and do you have any cardiovascular or hormonal disorders?\nPatient: My blood biochemistry is normal, and I don't have any cardiovascular or hormonal disorders. I had an operation 17 years ago to repair my extrophic bladder, and they created an Indiana pouch for me.\nDoctor: Alright. Can you tell me about your weight and body mass index (BMI)?\nPatient: My weight is 85 kg, and my BMI is 28.7 kg/m2.\nDoctor: Thank you for the information. Now, let's talk about your USG results. It showed a hyperechogenic lesion at the fat intensity, filling out your right renal sinus completely. A computerized tomography (CT) scan confirmed the presence of a fatty mass that 

In [8]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("gauravkoradiya/T5-Finetuned-Summarization-DialogueDataset")



In [9]:
dialogue = "Doctor: What brings you back into the clinic today, miss? Patient: I came in for a refill of my blood pressure medicine. Doctor: It looks like Doctor Kumar followed up with you last time regarding your hypertension, osteoarthritis, osteoporosis, hypothyroidism, allergic rhinitis and kidney stones.  Have you noticed any changes or do you have any concerns regarding these issues? Patient: No. Doctor: Have you had any fever or chills, cough, congestion, nausea, vomiting, chest pain, chest pressure?Patient: No. Doctor: Great. Also, for our records, how old are you and what race do you identify yourself as?Patient: I am seventy six years old and identify as a white female."
inputs = tokenizer(dialogue, return_tensors="pt", max_length=1024, truncation=True, padding="max_length")
summary = "The patient is a 76-year-old white female who presents to the clinic today originally for hypertension and a med check.  She has a history of hypertension, osteoarthritis, osteoporosis, hypothyroidism, allergic rhinitis and kidney stones.  Since her last visit she has been followed by Dr. Kumar.  Those issues are stable.  She has had no fever or chills, cough, congestion, nausea, vomiting, chest pain, chest pressure."
# 对目标摘要进行编码
targets = tokenizer(summary, return_tensors="pt", max_length=1024, truncation=True, padding="max_length")
# 打印输入的令牌ID和对应的文本表示
print('Token IDs:', inputs['input_ids'])
print('Tokens:', tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist()))

Token IDs: tensor([[7582,   10,  363,  ...,    0,    0,    0]])
Tokens: ['▁Doctor', ':', '▁What', '▁brings', '▁you', '▁back', '▁into', '▁the', '▁clinic', '▁today', ',', '▁miss', '?', '▁Patient', ':', '▁I', '▁came', '▁in', '▁for', '▁', 'a', '▁refill', '▁of', '▁my', '▁blood', '▁pressure', '▁medicine', '.', '▁Doctor', ':', '▁It', '▁looks', '▁like', '▁Doctor', '▁Kumar', '▁followed', '▁up', '▁with', '▁you', '▁last', '▁time', '▁regarding', '▁your', '▁hyper', 'tension', ',', '▁osteo', 'arth', 'riti', 's', ',', '▁osteo', 'p', 'o', 'ros', 'is', ',', '▁hypo', 't', 'hyroid', 'is', 'm', ',', '▁allergic', '▁', 'r', 'hin', 'it', 'is', '▁and', '▁kidney', '▁stones', '.', '▁Have', '▁you', '▁noticed', '▁any', '▁changes', '▁or', '▁do', '▁you', '▁have', '▁any', '▁concerns', '▁regarding', '▁these', '▁issues', '?', '▁Patient', ':', '▁No', '.', '▁Doctor', ':', '▁Have', '▁you', '▁had', '▁any', '▁fever', '▁or', '▁chill', 's', ',', '▁cough', ',', '▁congestion', ',', '▁nausea', ',', '▁vomiting', ',', '▁chest', '

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

Using cuda device


In [11]:
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM, AdamW

max_input_length = 512
max_target_length = 256
model = AutoModelForSeq2SeqLM.from_pretrained("gauravkoradiya/T5-Finetuned-Summarization-DialogueDataset").half()
model = model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)




In [12]:
def collote_fn(batch_samples):
    batch_inputs, batch_targets = [], []
    for sample in batch_samples:
        batch_inputs.append(sample['dialogue'])
        batch_targets.append(sample['summary'])
    batch_data = tokenizer(
        batch_inputs, 
        padding=True, 
        max_length=max_input_length,
        truncation=True, 
        return_tensors="pt"
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch_targets, 
            padding=True, 
            max_length=max_target_length,
            truncation=True, 
            return_tensors="pt"
        )["input_ids"]
        batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels)
        end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
        for idx, end_idx in enumerate(end_token_index):
            labels[idx][end_idx+1:] = -100
        batch_data['labels'] = labels
    return batch_data

In [13]:
train_dataloader = DataLoader(train_data, batch_size=4, shuffle=True, collate_fn=collote_fn)
valid_dataloader = DataLoader(valid_data, batch_size=4, shuffle=False, collate_fn=collote_fn)
test_dataloader = DataLoader(test_data,batch_size=4,shuffle=True)

In [14]:
import json
from tqdm import tqdm
from rouge_score import rouge_scorer

# 初始化 ROUGE 评分器
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

def predict_summary(dialogue, model, tokenizer):
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(dialogue, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        outputs = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=64)
        summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return summary

results = []
rouge_scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
for batch in tqdm(test_dataloader, desc="Generating Summaries"):
    dialogues = batch['dialogue']
    reference_summaries = batch['summary']
    
    for dialogue, reference_summary in zip(dialogues, reference_summaries):
        predicted_summary = predict_summary(dialogue, model, tokenizer)
        scores = scorer.score(reference_summary, predicted_summary)
        results.append({
            "Dialogue": dialogue,
            "Reference Summary": reference_summary,
            "Predicted Summary": predicted_summary,
            "ROUGE-1": scores['rouge1'].fmeasure,
            "ROUGE-2": scores['rouge2'].fmeasure,
            "ROUGE-L": scores['rougeL'].fmeasure
        })

        # Accumulate scores
        rouge_scores['rouge1'].append(scores['rouge1'].fmeasure)
        rouge_scores['rouge2'].append(scores['rouge2'].fmeasure)
        rouge_scores['rougeL'].append(scores['rougeL'].fmeasure)

# 计算平均 ROUGE 分数
average_scores = {key: sum(values) / len(values) for key, values in rouge_scores.items()}
print("Average ROUGE Scores:", average_scores)

# 保存到 JSON 文件
with open('pre_T5_finetuned_summarization_DialogueDataset.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=4)

print("Pre-training predictions and ROUGE scores saved successfully.")

Generating Summaries: 100%|██████████| 63/63 [00:59<00:00,  1.07it/s]

Average ROUGE Scores: {'rouge1': 0.14346195300443879, 'rouge2': 0.06140384692675334, 'rougeL': 0.1049545161266254}
Pre-training predictions and ROUGE scores saved successfully.





In [15]:
from datasets import load_dataset
import torch
from torch.utils.data import Dataset

class MedicalDialogueDataset(Dataset):
    def __init__(self, split, percent=100, seed=42):  # 添加seed参数，默认值为42
        # 加载数据集
        ds = load_dataset("omi-health/medical-dialogue-to-soap-summary", split=split)
        
        # 移除不需要的列
        columns_to_remove = ['messages', 'prompt']
        ds = ds.remove_columns(columns_to_remove)
        
        # 替换换行符并重命名列
        ds = ds.rename_column('soap', 'summary')
        
        # 添加ID和格式化摘要
        ds = ds.map(self.add_id, with_indices=True)
        ds = ds.map(self.format_summary)
        
        # 如果需要子集，先随机打乱，再选择对应的百分比
        if percent < 100:
            ds = ds.shuffle(seed=seed).select(range(int(percent / 100.0 * len(ds))))
        
        self.data = ds
    
    def add_id(self, example, idx):
        example['id'] = str(idx)
        return example
    
    def format_summary(self, example):
        example['summary'] = example['summary'].replace('S: ', 'Subjective: ')
        example['summary'] = example['summary'].replace('O: ', 'Objective: ')
        example['summary'] = example['summary'].replace('A: ', 'Assessment: ')
        example['summary'] = example['summary'].replace('P: ', 'Plan: ')
        return example

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]  # 获取索引对应的数据项
        ordered_item = {'id': item['id']}  # 创建一个新字典，并首先加入'id'
        ordered_item.update({k: item[k] for k in item if k != 'id'})  # 添加其他字段，排除'id'
        return item['dialogue'], item['summary'] 
        # return ordered_item

# 创建不同百分比的训练数据集实例
# train_data = MedicalDialogueDataset('train', percent=20, seed=42)
train_data = MedicalDialogueDataset('train', percent=50, seed=42)
# train_data = MedicalDialogueDataset('train', percent=100, seed=42)  

valid_data = MedicalDialogueDataset('validation')
test_data = MedicalDialogueDataset('test')


In [16]:

import pandas as pd
model_name = 'pre_T5_finetuned_summarization_DialogueDataset'

results = []
for idx in range(len(test_data)):
    try:
        dialogue, reference_summary = test_data[idx]
        predicted_summary = predict_summary(dialogue, model, tokenizer)
        results.append({
            "Dialogue": dialogue,
            "Reference Summary": reference_summary,
            "Predicted Summary": predicted_summary
        })
    except ValueError as e:
        print(f"Error at index {idx}: {e}")


# 保存到CSV
df = pd.DataFrame(results)
df.to_csv(f"{model_name}.csv", index=False)

In [17]:
from datasets import load_dataset
import torch
from torch.utils.data import Dataset

class MedicalDialogueDataset(Dataset):
    def __init__(self, split, percent=100, seed=42):  # 添加seed参数，默认值为42
        # 加载数据集
        ds = load_dataset("omi-health/medical-dialogue-to-soap-summary", split=split)
        
        # 移除不需要的列
        columns_to_remove = ['messages', 'prompt']
        ds = ds.remove_columns(columns_to_remove)
        
        # 替换换行符并重命名列
        ds = ds.rename_column('soap', 'summary')
        
        # 添加ID和格式化摘要
        ds = ds.map(self.add_id, with_indices=True)
        ds = ds.map(self.format_summary)
        
        # 如果需要子集，先随机打乱，再选择对应的百分比
        if percent < 100:
            ds = ds.shuffle(seed=seed).select(range(int(percent / 100.0 * len(ds))))
        
        self.data = ds
    
    def add_id(self, example, idx):
        example['id'] = str(idx)
        return example
    
    def format_summary(self, example):
        example['summary'] = example['summary'].replace('S: ', 'Subjective: ')
        example['summary'] = example['summary'].replace('O: ', 'Objective: ')
        example['summary'] = example['summary'].replace('A: ', 'Assessment: ')
        example['summary'] = example['summary'].replace('P: ', 'Plan: ')
        return example

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]  # 获取索引对应的数据项
        ordered_item = {'id': item['id']}  # 创建一个新字典，并首先加入'id'
        ordered_item.update({k: item[k] for k in item if k != 'id'})  # 添加其他字段，排除'id'
        # return item['dialogue'], item['summary'] 
        return ordered_item

# 创建不同百分比的训练数据集实例
# train_data = MedicalDialogueDataset('train', percent=20, seed=42)
train_data = MedicalDialogueDataset('train', percent=50, seed=42)
# train_data = MedicalDialogueDataset('train', percent=100, seed=42)  

valid_data = MedicalDialogueDataset('validation')
test_data = MedicalDialogueDataset('test')


In [18]:
batch = next(iter(train_dataloader))
print(batch.keys())
print('batch shape:', {k: v.shape for k, v in batch.items()})
print(batch)

dict_keys(['input_ids', 'attention_mask', 'decoder_input_ids', 'labels'])
batch shape: {'input_ids': torch.Size([4, 512]), 'attention_mask': torch.Size([4, 512]), 'decoder_input_ids': torch.Size([4, 256]), 'labels': torch.Size([4, 256])}
{'input_ids': tensor([[7582,   10, 8774,  ...,   84, 3217,    1],
        [7582,   10, 8774,  ...,  120,    6,    1],
        [7582,   10, 8774,  ...,   87,   26,    1],
        [7582,   10, 8774,  ...,  134, 9021,    1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'decoder_input_ids': tensor([[    0, 19237,   757,  ..., 12103, 21740,    13],
        [    0, 19237,   757,  ...,   228,   560,   119],
        [    0, 19237,   757,  ...,    41, 10925,   747],
        [    0, 19237,   757,  ...,     5, 21028, 16464]]), 'labels': tensor([[19237,   757,    10,  ..., 21740,    13,     1],
        [19237,   757,    10,  ...,   560,   119,     1],
 



In [19]:
from tqdm.auto import tqdm

def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f'loss: {0:>7f}')
    finish_batch_num = (epoch-1) * len(dataloader)
    
    model.train()
    for batch, batch_data in enumerate(dataloader, start=1):
        batch_data = batch_data.to(device)
        outputs = model(**batch_data)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()
        progress_bar.set_description(f'loss: {total_loss/(finish_batch_num + batch):>7f}')
        progress_bar.update(1)
    return total_loss

In [20]:
import numpy as np
from rouge import Rouge

rouge = Rouge()

def test_loop(dataloader, model):

    preds, labels = [], []
    
    model.eval()
    for batch_data in tqdm(dataloader):
        # batch_data = batch_data.to(device)
        batch_data = {k: v.to(device) for k, v in batch_data.items()}
        with torch.no_grad():
            generated_tokens = model.generate(
                batch_data["input_ids"],
                attention_mask=batch_data["attention_mask"],
                max_length=max_target_length,
                num_beams=4,
                no_repeat_ngram_size=2,
            ).cpu().numpy()
        if isinstance(generated_tokens, tuple):
            generated_tokens = generated_tokens[0]
        label_tokens = batch_data["labels"].cpu().numpy()

        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

        preds += [' '.join(pred.strip()) for pred in decoded_preds]
        labels += [' '.join(label.strip()) for label in decoded_labels]
    scores = rouge.get_scores(hyps=preds, refs=labels, avg=True)
    result = {key: value['f'] * 100 for key, value in scores.items()}
    result['avg'] = np.mean(list(result.values()))
    print(f"Rouge1: {result['rouge-1']:>0.2f} Rouge2: {result['rouge-2']:>0.2f} RougeL: {result['rouge-l']:>0.2f}\n")
    return result

In [21]:
from transformers import AdamW, get_scheduler

learning_rate = 1e-7
epoch_num = 10

optimizer = AdamW(model.parameters(), lr=learning_rate)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epoch_num*len(train_dataloader),
)

total_loss = 0.
best_avg_rouge = 0.
for t in range(epoch_num):
    print(f"Epoch {t+1}/{epoch_num}\n-------------------------------")
    total_loss = train_loop(train_dataloader, model, optimizer, lr_scheduler, t+1, total_loss)
    valid_rouge = test_loop(valid_dataloader, model)
    print(valid_rouge)
    rouge_avg = valid_rouge['avg']
    if rouge_avg > best_avg_rouge:
        best_avg_rouge = rouge_avg
        print('saving new weights...\n')
        torch.save(model.state_dict(), f'epoch_{t+1}_valid_rouge_{rouge_avg:0.4f}_model_weights.bin')
print("Done!")

Epoch 1/10
-------------------------------


loss: 3.093403: 100%|██████████| 1157/1157 [01:31<00:00, 12.59it/s]
100%|██████████| 125/125 [01:41<00:00,  1.23it/s]


Rouge1: 78.65 Rouge2: 60.15 RougeL: 73.72

{'rouge-1': 78.65075048609447, 'rouge-2': 60.14838342714769, 'rouge-l': 73.72054509137938, 'avg': 70.83989300154052}
saving new weights...

Epoch 2/10
-------------------------------


loss: 2.739942: 100%|██████████| 1157/1157 [01:26<00:00, 13.34it/s]
100%|██████████| 125/125 [01:51<00:00,  1.12it/s]


Rouge1: 79.85 Rouge2: 61.90 RougeL: 74.90

{'rouge-1': 79.85282473922697, 'rouge-2': 61.90149552146561, 'rouge-l': 74.9026972480516, 'avg': 72.21900583624806}
saving new weights...

Epoch 3/10
-------------------------------


loss: 2.583174: 100%|██████████| 1157/1157 [01:26<00:00, 13.34it/s]
100%|██████████| 125/125 [02:21<00:00,  1.13s/it]


Rouge1: 80.42 Rouge2: 62.62 RougeL: 75.41

{'rouge-1': 80.41611582344136, 'rouge-2': 62.622901971771036, 'rouge-l': 75.41272680216073, 'avg': 72.81724819912438}
saving new weights...

Epoch 4/10
-------------------------------


loss: 2.490022: 100%|██████████| 1157/1157 [01:26<00:00, 13.36it/s]
100%|██████████| 125/125 [02:48<00:00,  1.34s/it]


Rouge1: 80.76 Rouge2: 62.89 RougeL: 75.78

{'rouge-1': 80.75956740574847, 'rouge-2': 62.8874541757246, 'rouge-l': 75.78141944350571, 'avg': 73.14281367499292}
saving new weights...

Epoch 5/10
-------------------------------


loss: 2.427917: 100%|██████████| 1157/1157 [01:26<00:00, 13.32it/s]
100%|██████████| 125/125 [02:56<00:00,  1.41s/it]


Rouge1: 81.22 Rouge2: 63.24 RougeL: 76.32

{'rouge-1': 81.22147355047535, 'rouge-2': 63.24098226132965, 'rouge-l': 76.32105451113375, 'avg': 73.59450344097958}
saving new weights...

Epoch 6/10
-------------------------------


loss: 2.383520: 100%|██████████| 1157/1157 [01:27<00:00, 13.30it/s]
100%|██████████| 125/125 [03:04<00:00,  1.47s/it]


Rouge1: 81.29 Rouge2: 63.41 RougeL: 76.32

{'rouge-1': 81.28881483340462, 'rouge-2': 63.411318687135456, 'rouge-l': 76.32466074522112, 'avg': 73.6749314219204}
saving new weights...

Epoch 7/10
-------------------------------


loss: 2.350214: 100%|██████████| 1157/1157 [01:26<00:00, 13.34it/s]
100%|██████████| 125/125 [03:06<00:00,  1.50s/it]


Rouge1: 81.44 Rouge2: 63.55 RougeL: 76.65

{'rouge-1': 81.43733322320819, 'rouge-2': 63.5488901229759, 'rouge-l': 76.6489930362526, 'avg': 73.87840546081223}
saving new weights...

Epoch 8/10
-------------------------------


loss: 2.324793: 100%|██████████| 1157/1157 [01:27<00:00, 13.29it/s]
100%|██████████| 125/125 [03:08<00:00,  1.51s/it]


Rouge1: 81.59 Rouge2: 63.68 RougeL: 76.86

{'rouge-1': 81.59297746010449, 'rouge-2': 63.676628927783206, 'rouge-l': 76.85611279105026, 'avg': 74.04190639297933}
saving new weights...

Epoch 9/10
-------------------------------


loss: 2.304610: 100%|██████████| 1157/1157 [01:26<00:00, 13.34it/s]
100%|██████████| 125/125 [03:10<00:00,  1.52s/it]


Rouge1: 81.62 Rouge2: 63.68 RougeL: 76.99

{'rouge-1': 81.61780801468355, 'rouge-2': 63.67992513873939, 'rouge-l': 76.98649411427598, 'avg': 74.09474242256631}
saving new weights...

Epoch 10/10
-------------------------------


loss: 2.288461: 100%|██████████| 1157/1157 [01:27<00:00, 13.29it/s]
100%|██████████| 125/125 [03:09<00:00,  1.52s/it]


Rouge1: 81.56 Rouge2: 63.67 RougeL: 76.90

{'rouge-1': 81.56434676475007, 'rouge-2': 63.66639350423407, 'rouge-l': 76.89882627012727, 'avg': 74.04318884637047}
Done!


In [22]:
# !pip install datasets


## 推理

In [23]:
test_dataloader = DataLoader(test_data, batch_size = 4, shuffle = False, collate_fn = collote_fn)

In [25]:
model.load_state_dict(torch.load('epoch_9_valid_rouge_74.0947_model_weights.bin'))

<All keys matched successfully>

In [26]:
import json
model.eval()

with torch.no_grad():
    print('evaluating on test set...')
    sources, preds, labels = [], [], []
    for batch_data in tqdm(test_dataloader):
        batch_data = batch_data.to(device)
        generated_tokens = model.generate(
            batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            max_length=max_target_length,
            num_beams=4,
            no_repeat_ngram_size=2,
        )

        generated_tokens = generated_tokens.cpu().numpy()
        label_tokens = batch_data["labels"].cpu().numpy()

        decoded_sources = tokenizer.batch_decode(
            batch_data["input_ids"].cpu().numpy(), 
            skip_special_tokens=True
        )
        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

        sources += [source.strip() for source in decoded_sources]
        preds += [pred.strip() for pred in decoded_preds]
        labels += [label.strip() for label in decoded_labels]
    scores = rouge.get_scores(
        hyps=[' '.join(pred) for pred in preds], 
        refs=[' '.join(label) for label in labels], 
        avg=True
    )
    rouges = {key: value['f'] * 100 for key, value in scores.items()}
    rouges['avg'] = np.mean(list(rouges.values()))
    print(f"Test Rouge1: {rouges['rouge-1']:>0.2f} Rouge2: {rouges['rouge-2']:>0.2f} RougeL: {rouges['rouge-l']:>0.2f}\n")
    results = []
    print('saving predicted results...')
    for source, pred, label in zip(sources, preds, labels):
        results.append({
            "document": source, 
            "prediction": pred, 
            "summarization": label
        })
    with open('T5_test_data_pred.json', 'wt', encoding='utf-8') as f:
        for exapmle_result in results:
            f.write(json.dumps(exapmle_result, ensure_ascii=False) + '\n')

evaluating on test set...


100%|██████████| 63/63 [01:42<00:00,  1.63s/it]


Test Rouge1: 80.99 Rouge2: 63.86 RougeL: 76.65

saving predicted results...


In [27]:
def predict_summary(input_text, model, tokenizer, device='cuda'):
    model.to(device)
    model.eval()

    inputs = tokenizer(input_text, return_tensors="pt", max_length=1024, truncation=True, padding="max_length")
    inputs = inputs.to(device)

    outputs = model.generate(
        inputs["input_ids"],
        max_length=512,
        num_beams=10,
        no_repeat_ngram_size=2,
        early_stopping=False
    )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)


    return summary

# 使用函数
input_text = "Doctor: Hello, I remember you had an emergency caesarean delivery at 39 weeks due to fetal distress. How have you been since then? Any postpartum complications? Patient: Hi, Doctor. I've been doing well since the delivery. No complications, thankfully. Doctor: That's good to hear. As part of our ongoing study on 'Vaginal delivery after caesarean section', you underwent a saline contrast sonohysterography 6 months after the caesarean section. The results showed a small indentation in your caesarean scar, and the remaining myometrium over the defect was 7.5 mm (Fig. ). Patient: Oh, I see. What does that mean for my current pregnancy? Doctor: At around 11 weeks, you had a dating scan with no remarks. Then, you came for a transvaginal ultrasound examination at around 13 weeks asc part of our study. The scan revealed a duplex pregnancy with one viable intrauterine fetus with normal anatomy and placenta located high on the anterior wall. A small gestational sac (8 mm) with a yolk sac without an embryo was located in the caesarean scar (Fig. ). There was no extensive vascularity surrounding the sac, and you were asymptomatic. Patient: Yes, that's right. I didn't feel any discomfort or symptoms. Doctor: We informed you that there wasn't enough evidence to advise a specific management for this condition. After discussion with you and your husband, expectant management was chosen with a new ultrasound examination scheduled after 5 weeks. Patient: Yes, we decided to wait and see how things would progress. Doctor: You came to our ultrasound department at 18 weeks, 22 weeks, and 30 weeks of gestation. Throughout this time, you remained asymptomatic. The ectopic gestational sac was not visualized with transvaginal or transabdominal scans at the 18 weeks examination (Fig. ). The niche in the scar and the thickness of the thinnest part of the remaining myometrium appeared unchanged at all visits. Patient: That's a relief. How's the intrauterine pregnancy developing? Doctor: The intrauterine pregnancy developed normally with no signs of abnormal placentation. At 30 weeks of gestation, the ultrasound appearance of the scar area did not indicate any contraindications for vaginal delivery. The thickness of the lower uterine segment (LUS) was 4.9 mm (Fig. ). Patient: So, I can have a vaginal delivery this time? Doctor: Yes, in agreement with you, we've planned for a vaginal delivery. The staff of the labor ward has been fully informed and prepared for your case. Patient: That's great news! Thank you, Doctor. Doctor: You're welcome. You'll be admitted to the labor ward when the time comes. Please continue to monitor your symptoms and reach out if you have any concerns. Good luck with the rest of your pregnancy. Patient: Thank you so much, Doctor. I appreciate your help and guidance throughout this process."
summary = predict_summary(input_text, model, tokenizer, device)
print(summary)

As part of our ongoing study on 'Vaginal delivery after caesarean section', the patient underwent saline contrast sonohysterography at around 13 weeks due to fetal distress. Results showed a small indentation in the scar area, and the remaining myometrium over the defect was 7.5 mm (Fig. 2). The patient has been doing well since the delivery and is expected to be admitted to the labor ward when the time comes.


In [28]:
from datasets import load_dataset
import torch
from torch.utils.data import Dataset

class MedicalDialogueDataset(Dataset):
    def __init__(self, split, percent=100, seed=42):  # 添加seed参数，默认值为42
        # 加载数据集
        ds = load_dataset("omi-health/medical-dialogue-to-soap-summary", split=split)
        
        # 移除不需要的列
        columns_to_remove = ['messages', 'prompt']
        ds = ds.remove_columns(columns_to_remove)
        
        # 替换换行符并重命名列
        ds = ds.rename_column('soap', 'summary')
        
        # 添加ID和格式化摘要
        ds = ds.map(self.add_id, with_indices=True)
        ds = ds.map(self.format_summary)
        
        # 如果需要子集，先随机打乱，再选择对应的百分比
        if percent < 100:
            ds = ds.shuffle(seed=seed).select(range(int(percent / 100.0 * len(ds))))
        
        self.data = ds
    
    def add_id(self, example, idx):
        example['id'] = str(idx)
        return example
    
    def format_summary(self, example):
        example['summary'] = example['summary'].replace('S: ', 'Subjective: ')
        example['summary'] = example['summary'].replace('O: ', 'Objective: ')
        example['summary'] = example['summary'].replace('A: ', 'Assessment: ')
        example['summary'] = example['summary'].replace('P: ', 'Plan: ')
        return example

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]  # 获取索引对应的数据项
        ordered_item = {'id': item['id']}  # 创建一个新字典，并首先加入'id'
        ordered_item.update({k: item[k] for k in item if k != 'id'})  # 添加其他字段，排除'id'
        return item['dialogue'], item['summary'] 
        # return ordered_item

# 创建不同百分比的训练数据集实例
# train_data = MedicalDialogueDataset('train', percent=20, seed=42)
train_data = MedicalDialogueDataset('train', percent=50, seed=42)
# train_data = MedicalDialogueDataset('train', percent=100, seed=42)  

valid_data = MedicalDialogueDataset('validation')
test_data = MedicalDialogueDataset('test')


In [29]:
import pandas as pd
model_name = 'T5-Finetuned-Summarization-DialogueDataset'
results = []
for idx in range(len(test_data)):  # 遍历整个测试集
    dialogue, reference_summary = test_data[idx]
    predicted_summary = predict_summary(dialogue, model, tokenizer)
    results.append({
        "Dialogue": dialogue,
        "Reference Summary": reference_summary,
        "Predicted Summacccry": predicted_summary
    })

# 保存到CSV
df = pd.DataFrame(results)
df.to_csv(f"post_{model_name}.csv", index=False)