In [88]:
# Load the raw training file as a list of lines (newline characters still attached).
# The `with` block closes the file on exit, so no explicit close() is needed.
with open('data/train.txt', 'r', encoding='utf-8') as f:
    content = f.readlines()

In [92]:
# Trim leading/trailing whitespace (including the trailing newline) from every line
content = [line.strip() for line in content]

In [93]:
data = content[0]

In [94]:
data

'{"context": "问题：患者反复出现反酸、烧心等症状，考虑为Barrett食管，需要注意哪些并发症？\\n回答: ", "target": "根据知识，Barrett食管的并发症包括消化性溃疡、反流食管炎、胃肠道出血、贫血、肿瘤等，需要引起注意。"}'

In [22]:
import json

In [95]:
# Decode every JSON-encoded line into a Python dict
content = [json.loads(line) for line in content]

In [96]:
type(content[0])

dict

In [97]:
content[0]

{'context': '问题：患者反复出现反酸、烧心等症状，考虑为Barrett食管，需要注意哪些并发症？\n回答: ',
 'target': '根据知识，Barrett食管的并发症包括消化性溃疡、反流食管炎、胃肠道出血、贫血、肿瘤等，需要引起注意。'}

In [98]:
# Instruction text that add_instruction() places in front of every sample's context
instruction = '下面是一个医学领域的问题，请根据你对医学领域的了解，严谨细致的回答，如果不清楚就回答不清楚，不许胡编乱造。'

In [99]:
# Prepend the instruction to a sample's context
def add_instruction(data, instruction_text=None):
    """Return a copy of ``data`` whose ``'context'`` is prefixed with the instruction.

    The input dict is left untouched, so the function can be re-applied to the
    original records without stacking instructions onto already-processed ones.

    Args:
        data: Mapping with at least a ``'context'`` key holding the prompt text.
        instruction_text: Instruction to prepend. When ``None`` (the default),
            the notebook-level ``instruction`` variable is used.

    Returns:
        A new dict with the same keys as ``data``; ``'context'`` is replaced by
        ``"指令：<instruction>\\n<original context>"``.
    """
    if instruction_text is None:
        # Fall back to the notebook-level default defined in an earlier cell.
        instruction_text = instruction
    inputs = "指令：{instruction}\n{query}".format(instruction=instruction_text, query=data['context'])
    # Build a new dict instead of assigning into `data`, so the caller's record is not mutated.
    return {**data, 'context': inputs}

In [101]:
content = list(map(add_instruction, content))

In [109]:
content[0]

{'context': '指令：下面是一个医学领域的问题，请根据你对医学领域的了解，严谨细致的回答，如果不清楚就回答不清楚，不许胡编乱造。\n问题：患者反复出现反酸、烧心等症状，考虑为Barrett食管，需要注意哪些并发症？\n回答: ',
 'target': '根据知识，Barrett食管的并发症包括消化性溃疡、反流食管炎、胃肠道出血、贫血、肿瘤等，需要引起注意。'}

In [108]:
# Store the dataset as JSON Lines: one JSON object per line, non-ASCII text kept as-is.
with open('data/train.jsonl', 'w', encoding='utf-8') as f:
    # The generator-scoped `record` keeps this cell from rebinding the
    # notebook-level `data` variable assigned in an earlier cell.
    f.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in content)

In [37]:
from datasets import load_dataset, load_from_disk

In [10]:
Med_dataset = load_dataset('json', data_files='./data/train.jsonl')

Downloading and preparing dataset json/default to /Users/shuaiqiduan/.cache/huggingface/datasets/json/default-a1c08c7776f433c3/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96...


Downloading data files: 100%|███████████████████| 1/1 [00:00<00:00, 4032.98it/s]
Extracting data files: 100%|█████████████████████| 1/1 [00:00<00:00, 460.41it/s]
                                                        

Dataset json downloaded and prepared to /Users/shuaiqiduan/.cache/huggingface/datasets/json/default-a1c08c7776f433c3/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96. Subsequent calls will reuse this data.


100%|████████████████████████████████████████████| 1/1 [00:00<00:00, 596.29it/s]


In [11]:
Med_dataset

DatasetDict({
    train: Dataset({
        features: ['context', 'target'],
        num_rows: 7622
    })
})

In [15]:
Med_dataset['train']['context'][0]

'指令：下面是一个医学领域的问题，请根据你对医学领域的了解，严谨细致的回答，如果不清楚就回答不清楚，不许胡编乱造。\n问题：患者反复出现反酸、烧心等症状，考虑为Barrett食管，需要注意哪些并发症？\n回答: '

In [16]:
seed = 42

In [19]:
Med_dataset = Med_dataset.shuffle(seed=seed)

Loading cached shuffled indices for dataset at /Users/shuaiqiduan/.cache/huggingface/datasets/json/default-a1c08c7776f433c3/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96/cache-b4ea38f397688be9.arrow


In [20]:
Med_dataset['train']['context'][0]

'指令：下面是一个医学领域的问题，请根据你对医学领域的了解，严谨细致的回答，如果不清楚就回答不清楚，不许胡编乱造。\n问题：一名60岁女性患者出现外阴肿物和瘙痒，经过病理检查发现是基底细胞癌，请问有哪些并发症？\n回答: '

In [25]:
# Split into train and validation sets.
# Only the hold-out size is fixed; the train split is whatever remains, so the
# cell keeps working if the number of rows in the source file changes.
split_dataset = Med_dataset['train'].train_test_split(test_size=1000, seed=seed)

In [26]:
split_dataset

DatasetDict({
    train: Dataset({
        features: ['context', 'target'],
        num_rows: 6622
    })
    test: Dataset({
        features: ['context', 'target'],
        num_rows: 1000
    })
})

In [27]:
split_dataset['valid'] =split_dataset['test']

In [28]:
split_dataset

DatasetDict({
    train: Dataset({
        features: ['context', 'target'],
        num_rows: 6622
    })
    test: Dataset({
        features: ['context', 'target'],
        num_rows: 1000
    })
    valid: Dataset({
        features: ['context', 'target'],
        num_rows: 1000
    })
})

In [29]:
#remove test dataset
split_dataset.pop('test')

Dataset({
    features: ['context', 'target'],
    num_rows: 1000
})

In [30]:
split_dataset

DatasetDict({
    train: Dataset({
        features: ['context', 'target'],
        num_rows: 6622
    })
    valid: Dataset({
        features: ['context', 'target'],
        num_rows: 1000
    })
})

In [35]:
# save datasets to disk
split_dataset.save_to_disk('data/Med_datasets')

                                                                                

In [39]:
# Reload from the same directory that save_to_disk() wrote to ('data/Med_datasets').
datasets = load_from_disk(dataset_path='data/Med_datasets')

In [40]:
datasets

DatasetDict({
    train: Dataset({
        features: ['context', 'target'],
        num_rows: 6622
    })
    valid: Dataset({
        features: ['context', 'target'],
        num_rows: 1000
    })
})

In [42]:
datasets['train'][0]

{'context': '指令：下面是一个医学领域的问题，请根据你对医学领域的了解，严谨细致的回答，如果不清楚就回答不清楚，不许胡编乱造。\n问题：一个患者开始感觉口腔黏膜出现不适感，经过检查发现口腔黏膜破溃，出现了一些渗出物，还有口腔结节，辅助检查中X线片检查发现病情较重，想请问这是什么疾病？\n回答: ',
 'target': '根据口腔黏膜破溃、渗出物、口腔结节等症状，结合X线片检查结果，可疑为口腔结核。'}

In [43]:
from transformers import AutoTokenizer

In [46]:
tokenizer = AutoTokenizer.from_pretrained('THUDM/chatglm2-6b', trust_remote_code=True)

Downloading tokenizer.model: 100%|█████████| 1.02M/1.02M [00:00<00:00, 3.19MB/s]


In [294]:
from typing import List, Union, Dict, Tuple
import torch
from transformers import DataCollatorForSeq2Seq

In [462]:
class DataCollatorForChatGlm2():
    """Collate tokenized samples into one left-padded batch of tensors.

    Each sample is padded on the left up to the longest ``input_ids`` in the
    batch (dynamic padding); ``attention_mask`` and ``position_ids`` are
    generated to match.

    Args:
        pad_token_id: Id used to fill the padded positions of ``input_ids``.
        label_pad_id: Id used to fill the padded positions of ``labels``
            (default -100; presumably the value the loss ignores; confirm
            against the training code).
    """

    def __init__(self, pad_token_id: int, label_pad_id: int=-100):
        self.pad_token_id = pad_token_id
        self.label_pad_id = label_pad_id

    def __call__(self, batch_data: List[Dict[str, List]]) -> Dict[str, torch.Tensor]:
        """Pad and stack a list of samples.

        Args:
            batch_data: Samples with ``'input_ids'`` and ``'labels'`` entries,
                given either as 1-D integer tensors or as plain lists of ints.
                ``labels`` is expected to have the same length as ``input_ids``
                within a sample (both receive the same amount of padding).

        Returns:
            Dict of tensors shaped ``(batch, max_len)``:
            ``input_ids`` (int64), ``attention_mask`` (float32, 0 on padding and
            1 on real tokens), ``position_ids`` (int64, 0 on padding followed by
            0..len-1) and ``labels`` (int64).

        Raises:
            ValueError: If ``batch_data`` is empty.
        """
        if not batch_data:
            raise ValueError("batch_data must contain at least one sample")

        # dynamic padding: pad only up to the longest sample of this batch
        len_list = [len(data['input_ids']) for data in batch_data]
        batch_max_len = max(len_list)
        input_ids = []
        attention_mask = []
        position_ids = []
        labels_list = []
        for len_ids, data in zip(len_list, batch_data):
            # left padding
            pad_len = batch_max_len - len_ids
            # as_tensor accepts both tensors and lists and pins the dtype to int64.
            ids = torch.as_tensor(data['input_ids'], dtype=torch.long)
            label = torch.as_tensor(data['labels'], dtype=torch.long)
            # torch.full with an explicit dtype stays int64 even when pad_len == 0;
            # torch.tensor([]) would be float32 and could promote the whole batch to float.
            ids_pad = torch.full((pad_len,), self.pad_token_id, dtype=torch.long)
            label_pad = torch.full((pad_len,), self.label_pad_id, dtype=torch.long)
            input_ids.append(torch.cat((ids_pad, ids)))
            labels_list.append(torch.cat((label_pad, label)))
            # attention mask stays float32, as before
            attention_mask.append(torch.cat((torch.zeros(pad_len), torch.ones(len_ids))))
            # position ids are indices, so they must be integer typed (was FloatTensor)
            position_ids.append(
                torch.cat((torch.zeros(pad_len, dtype=torch.long), torch.arange(len_ids)))
            )
        return {'input_ids': torch.stack(input_ids),
                'attention_mask': torch.stack(attention_mask),
                'position_ids': torch.stack(position_ids),
                'labels': torch.stack(labels_list)}

In [None]:
datacollatorforchatglm2 = DataCollatorForChatGlm2(tokenizer.pad_token_id)

In [385]:
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader

In [387]:
tokenized_datasets_train.set_format('torch')

In [472]:
tokenized_datasets_train['labels'][0].dtype

torch.int64

In [465]:
train_dataloader = DataLoader(tokenized_datasets_train, batch_size=8, collate_fn=datacollatorforchatglm2, shuffle=True)

In [466]:
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x2a1014410>

In [471]:
# Pull a single batch from the loader to inspect what the collator produces
for data in train_dataloader:
    # shapes of the input_ids rows at index 1 and 2 of the batch
    print(data['input_ids'][1].shape)
    print(data['input_ids'][2].shape)
    # decode the labels row at index 1 back to text
    print(tokenizer.decode(data['labels'][1]))
    # stop after the first batch
    break

torch.Size([160])
torch.Size([160])
该患者可能患有副流行性感冒，且病情较为严重。建议进行病毒分离鉴定、血清学检查等实验室检查，以确定诊断。治疗上可使用干扰素、克仑特罗、利巴韦林等药物。建议前往呼吸科或感染内科进行诊治。同时要注意预防并发症，如肺炎。
