In [8]:
!pip install openai
!pip install transformers



In [1]:
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import os
import json
import openai
from transformers import RobertaTokenizer, RobertaModel, AutoTokenizer, AutoModel
from tqdm import tqdm
import time
import math
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import log_loss
import re

In [2]:
tokenizer = AutoTokenizer.from_pretrained("monologg/biobert_v1.1_pubmed")
model = AutoModel.from_pretrained("monologg/biobert_v1.1_pubmed")

Some weights of the model checkpoint at monologg/biobert_v1.1_pubmed were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
class Biview_MultiSent(Dataset):

    def __init__(self, phase, dataset_dir, folds, report_fname, label_fname):
        self.phase = phase
        self.case_list = []
        for fold in list(folds):
            with open(os.path.join(dataset_dir, 'fold{}.txt'.format(fold))) as f:
                self.case_list += f.read().splitlines()
        with open(report_fname) as f:
            self.reports = json.load(f)
        with open(label_fname) as f:
            self.label_dict = json.load(f)
        self.transform = transforms.Compose([
            transforms.RandomCrop((512, 512), pad_if_needed=True),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.case_list)

    def __getitem__(self, idx):
        caseid, img1_path, img2_path = self.case_list[idx].split()
        
        label = self.label_dict[caseid]
        label = torch.tensor(label, dtype=torch.float)

        report = self.reports[caseid]
        text = ''
        if report['impression'] is not None:
            text += report['impression']
        text += '. '
        if report['findings'] is not None:
            text += report['findings']
        sents = text.lower().split('.')
        sents = [sent for sent in sents if len(sent.strip()) > 1]
        
        return label, text, caseid

In [4]:
dataset_dir = "data"
report_path = 'data/reports.json'
label_path = 'data/label_dict.json'
train_set = Biview_MultiSent('train', dataset_dir, '0234', report_path,
                                 label_path)
test_set = Biview_MultiSent('validation', dataset_dir, '1', report_path,
                                 label_path)

In [5]:
openai.api_key = 'sk-'

### In-context Samples Candidate Pool

In [6]:
candidate_examples = {
    "1": {
        "question": "Below is the medical report \n\n[No acute cardiopulmonary abnormality... Cardiomediastinal silhouette is within normal limits for size, with redemonstration of tortuous and atherosclerotic calcified thoracic aorta. No focal consolidation, effusion, or pneumothorax identified. Eventration of the right hemidiaphragm is stable compared to prior examination. Multilevel degenerative disc disease and thoracolumbar spine again noted without acute osseous abnormality.]",
        "answer": "The report mentions that 'Multilevel degenerative disc disease and thoracolumbar spine', suggesting the scoliosis / degenerative. Therefore, the output is [the disease indices are: (3)]"            
    },
    
    "2": {
        "question": "Below is the medical report \n\n[No acute disease.. The heart is normal in size. The mediastinum is unremarkable. The lungs are clear.]",
        "answer": "For all the sentences in the report, there is no evidence of any potential disease. Besides, according to the rule 1, 'A report must not be classified into 'normal (1)' and disease labels 2-20 simultaneously'. Therefore, the output is [the disease indices are: (1)]"
    },
    
    "3": {
        "question": "Below is the medical report \n\n[1. No acute cardiopulmonary abnormalities. 2. Emphysema and chronic bony abnormalities are unchanged from prior exams. .. The trachea is midline. The cardiomediastinal silhouette is normal. The superior thoracic spine is again noted, unchanged from prior. Lucent pulmonary parenchyma is consistent appearance with emphysema and appears unchanged from prior examinations. No evidence of pneumothorax. No focal airspace disease or pleural effusion. Vague density in the medial right lung apex most XXXX representing overlying shadows of bony structures, which is stable.]",
        "answer": "The report mentions that 'Worsening bibasilar subpleural interstitial opacities', suggesting opacity(15); The report mentions 'Lung volumes are low', suggesting hypoinflation / hyperdistention(18); The report mentions 'There calcifications of the thoracic aorta.', suggesting calcinosis(9). Therefore, the output is [the disease indices are: (9, 15, 18)]"
    },
    
    "4": {
        "question": "Below is the medical report \n\n[Left midlung opacity may be secondary to acute infectious process or developing mass lesion. Followup to resolution is recommended.. The heart is normal in size. The mediastinal contours are stable. Aortic calcifications are noted. There are small calcified lymph XXXX. Emphysema and chronic changes are identified. There is XXXX opacity in the left perihilar upper lobe. There is questionable XXXX extension to the pleural surface. This may represent acute infiltrate or developing density. There is no pleural effusion or pneumothorax.]",
        "answer": "The report mentions that 'Aortic calcifications are noted. There are small calcified lymph XXXX.', suggesting that calcinosis exist; the report mentions that 'Emphysema and chronic changes are identified.', suggesting emphysema / pulmonary emphysema(10); the report mentions that 'Left midlung opacity may be secondary to acute infectious process or developing mass lesion', suggesting opacity(15). Therefore, the answer is [the disease indices are: (9, 10, 15)]"
    },
    
    "5": {
        "question": "Below is the medical report \n\n[No evidence of active disease.. Lungs are clear bilaterally. Cardiac and mediastinal silhouettes are normal. Pulmonary vasculature is normal. No pneumothorax or pleural effusion. No acute bony abnormality. The distal tip of a right IJ dual-lumen central venous catheter is at the XXXX which junction.]",
        "answer": "The report mentions that 'venous catheter', suggesting catheters indwelling / surgical instruments / tube inserted / medical device(19)"
    },
    
    "6": {
        "question": "Below is the medical report \n\n[1. No acute findings. 2. Emphysema. 3. Scattered XXXX of scarring, most notably in the left upper lobe.. The lungs are hyperexpanded. There are stable scattered XXXX bilateral opacities, most notable in the left upper lobe, XXXX scarring. No focal airspace consolidation to suggest pneumonia. No large pleural effusion. No pneumothorax. Heart size is normal. Thoracic aorta is mildly tortuous and demonstrates atherosclerotic vascular calcification. There are degenerative changes of the spine.]",
        "answer": "The report mentions that '2. Emphysema', suggesting emphysema / pulmonary emphysema(10); the report mentions that '3. Scattered XXXX of scarring, most notably in the left upper lobe.', suggesting cicatrix(14); the report mentions that 'The lungs are hyperexpanded.', suggesting hypoinflation / hyperdistention(18); the report mentions that 'There are stable scattered XXXX bilateral opacities', suggesting opacity(15); the report mentions that 'Thoracic aorta is mildly tortuous', suggesting scoliosis / degenerative(3). Therefore, the answer is [the disease indices are: (3, 10, 14, 15, 18)]"
    },
    
    "7": {
        "question": "Below is the medical report \n\n[1. No acute cardiopulmonary findings.. The heart size and mediastinal contours appear within normal limits. Atherosclerotic calcification of the aorta. No focal airspace consolidation, pleural effusions or pneumothorax. Questionable thin-walled cavitary lesion in the right lower lobe, only seen on the AP view and may represent artifact. No acute bony abnormalities.]",
        "answer": "The report mentions 'Atherosclerotic calcification of the aorta', but this disease does not belong to any of the potential labels from 2 to 19. Besides, according to the rule 2, 'A report must not be classified into 'other findings (20)' and disease labels 1-19 simultaneously'. Therefore, the output is [the disease indices are: (20)]"
    },
    
    "8": {
        "question": "Below is the medical report \n\n[Normal exam. Normal heart size. Normal mediastinal silhouette. No pneumothorax, pleural effusion or suspicious focal air space opacity.]",
        "answer": "For all the sentences in the report, there is no evidence of any potential disease. Besides, according to the rule 1, 'A report must not be classified into 'normal (1)' and disease labels 2-20 simultaneously'. Therefore, the output is [the disease indices are: (1)]"
    },
    
    "9": {
        "question": "Below is the medical report \n\n[No acute cardiopulmonary findings. Heart size within normal limits. No focal alveolar consolidation, no definite pleural effusion seen. No typical findings of pulmonary edema. Mediastinal calcification and dense right upper lung nodule suggest a previous granulomatous process.]",
        "answer": "The report mentions that 'Mediastinal calcification and dense right upper lung nodule suggest a previous granulomatous process.', suggesting calcinosis(9) and nodule / mass(16). Therefore, the output is [the disease indices are: (9, 16)]"
    },
    
    "10": {
        "question": "Below is the medical report \n\n[No acute disease. Retrocardiac density XXXX corresponding to known hiatal hernia.. The heart is normal in size. The mediastinum is within normal limits. There is retrocardiac density which XXXX corresponds to patient's known hiatal hernia. The lungs are hypoinflated.]",
        "answer": "The report mentions that 'Retrocardiac density XXXX corresponding to known hiatal hernia..', suggesting hernia hiatal(8); the report suggest that 'The lungs are hypoinflated.', suggesting hypoinflation / hyperdistention(18). Therefore, the output is [the disease indices are: (8, 18)]"
    },
    
#     "11": {
#         "question": "Below is the medical report \n\n[Chest radiograph. 1. No acute radiographic cardiopulmonary process.. XXXX sternotomy XXXX are in XXXX and intact. Normal cardiomediastinal silhouette. The bilateral costophrenic XXXX are excluded from the image on the PA view. Lungs are clear without focal areas of consolidation, pleural effusion, or pneumothorax. XXXX XXXX are intact without acute osseous abnormality. Mild degenerative changes throughout the thoracic spine.]",
#         "answer": "The report mentions that 'Mild degenerative changes throughout the thoracic spine', suggesting scoliosis / degenerative(3). Therefore, the output is [the disease indices are: (3)]"
#     },
    
#     "12": {
#         "question": "Below is the medical report \n\n[Mild cardiomegaly.. Mild cardiomegaly. Normal pulmonary vascularity. Tortuosity of the descending aorta. No focal infiltrate, pneumothorax or pleural effusion.]",
#         "answer": "The report mentions that 'Mild cardiomegaly', suggesting cardiomegaly(2). Therefore, the output is [the disease indices are: (2)]"
#     },
    
#     "13": {
#         "question": "Below is the medical report \n\n[No acute findings.. Cardiac and mediastinal contours are within normal limits. The lungs are clear. Bony structures are intact.]",
#         "answer": "For all the sentences in the report, there is no evidence of any potential disease. Besides, according to the rule 1, 'A report must not be classified into 'normal (1)' and disease labels 2-20 simultaneously'. Therefore, the output is [the disease indices are: (1)]"
#     },
    
#     "14": {
#         "question": "Below is the medical report \n\n[No evidence of active disease. The lungs are clear. There is no focal airspace consolidation. No pleural effusion or pneumothorax. Heart size is within normal limits. Right paratracheal density is stable from prior radiographs and may reflect tortuous vasculature. There is aortic atherosclerotic vascular calcification. There are mild degenerative changes of the spine. Surgical clips are noted in the region of the left breast. There is mild diaphragm eventration.]",
#         "answer": "The report mentions that 'There are mild degenerative changes of the spine', suggesting scoliosis / degenerative(3); The report mentions that 'Surgical clips are noted in the region of the left breast.', suggesting catheters indwelling / surgical instruments / tube inserted / medical device(19). Therefore, the output is [the disease indices are: (3, 19)]"
#     },
    
#     "15": {
#         "question": "Below is the medical report \n\n[No acute cardiopulmonary findings.. Heart size is normal. Lungs are clear. Low lung volumes. There is no pneumothorax or large pleural effusion.]",
#         "answer": "The report mentions that 'Low lung volumes', suggesting hypoinflation / hyperdistention(18). Therefore, the output is [the disease indices are: (18)]"
#     },
    
#     "16": {
#         "question": "Below is the medical report \n\n[No acute disease.. The heart is normal in size. The mediastinum is unremarkable. Mild pectus excavatum deformity is noted. The lungs are clear.]",
#         "answer": "The report mentions that 'Mild pectus excavatum deformity is noted.', suggesting a disease corresponding to pectus excavatum, but this disease does not belong to any disease label 2-20. Therefore, the output is [the disease indices are: (20)]"
#     },
    
#     "17": {
#         "question": "Below is the medical report \n\n[No acute cardiopulmonary abnormality.. Normal heart size. Density surrounding superior mediastinum reflex combination of vascular, osseous common pleural structures. No focal airspace consolidation. Moderate degenerative disc disease with osteophyte formation bridging.]",
#         "answer": "The report mentions that 'Moderate degenerative changes of the thoracic spine.', suggesting scoliosis / degenerative(3). Therefore, the output is [the disease indices are: (3)]"
#     },
    
#     "18": {
#         "question": "Below is the medical report \n\n[No acute pulmonary findings.. Cardiac and mediastinal contours are within normal limits. Large calcified granulomas in the right hilum. The lungs are otherwise clear. Prior anterior cervical fusion.]",
#         "answer": "The report mentions that 'Large calcified granulomas in the right hilum', which is related to a right hilum. Calcinosis(9) is also about calcified disease, but according to the rules, calcinosis is related to Mediastinum. Therefore, this disease does not belong to any of the labels 2-20. Therefore, the output is [the disease indices are (20)]"
#     },
    
#     "19": {
#         "question": "Below is the medical report \n\n[No acute cardiopulmonary findings. Heart size within normal limits. No focal alveolar consolidation, no definite pleural effusion seen. No typical findings of pulmonary edema. Mediastinal calcification and dense right upper lung nodule suggest a previous granulomatous process.]",
#         "answer": "The report mentions that 'Mediastinal calcification', suggesting calcinosis(9); The report mentions that 'dense right upper lung nodule', suggesting nodule / mass(16). Therefore, the output is [the disease indices are (9, 16)]"
#     },
    
#     "20": {
#         "question": "Below is the medical report \n\n[Bilateral large pleural effusion, possibly from pleuritis or sympathetic from the known pancreatitis.. One XXXX are low. Both costophrenic XXXX are blunted. Pulmonary XXXX are normal. No visible infiltrates in the aerated lungs.]",
#         "answer": "The report mentions that 'Bilateral large pleural effusion', suggesting pleural effusion(5)."
#     },
    
#     "21": {
#         "question": "Below is the medical report \n\n[1. Mild stable cardiomegaly and central vascular congestion. 2. Low lung volumes with elevated left hemidiaphragm and basilar subsegmental atelectasis. 3. Extensive bilateral shoulder degenerative changes with subluxation/dislocation left shoulder, possibly chronic. Suggest clinical correlation.. The heart is again mildly enlarged. Mediastinal contours are stable. Patient is somewhat rotated. The lungs are hypoinflated with elevated left hemidiaphragm. XXXX XXXX opacities compatible with atelectasis. No large effusion is seen. There is no focal consolidation. Pulmonary vascularity is mildly accentuated. There are bilateral degenerative changes of the XXXX with probable chronic dislocation of the left humerus. Correlate clinically.]",
#         "answer": "The report mentions that 'Mild stable cardiomegaly and central vascular congestion', suggesting cardiomegaly(2); The report mentions that 'Low lung volumes with elevated left hemidiaphragm and basilar subsegmental atelectasis.', suggesting pulmonary atelectasis(13); The report mentions that 'Extensive bilateral shoulder degenerative changes with subluxation/dislocation left shoulder', suggesting scoliosis / degenerative(3); The report mentions  that 'The heart is again mildly enlarged', suggesting hypoinflation / hyperdistention(18); The report mentions that 'opacities compatible with atelectasis.', suggesting opacity(15). Therefore, the output is [the disease indices are (2,3,13,15,18)]"
#     },
    
#     "22": {
#         "question": "Below is the medical report \n\n[No acute cardiopulmonary abnormality.. There are no focal areas of consolidation. No suspicious pulmonary opacities. Heart size within normal limits. No pleural effusions. There is no evidence of pneumothorax. Degenerative changes of the thoracic spine.]",
#         "answer": "The report mentions that 'Degenerative changes of the thoracic spine.', suggesting scoliosis / degenerative(3). Therefore, the output is [the disease indices are (3)]"
#     },
    
#     "23": {
#         "question": "Below is the medical report \n\n[No acute disease. Retrocardiac density XXXX corresponding to known hiatal hernia.. The heart is normal in size. The mediastinum is within normal limits. There is retrocardiac density which XXXX corresponds to patient's known hiatal hernia. The lungs are hypoinflated. No focal consolidation is seen.]",
#         "answer": "The report mentions that 'Retrocardiac density XXXX corresponding to known hiatal hernia..', suggesting hernia hiatal(8); The report mentions that 'The lungs are hypoinflated', suggesting hypoinflation / hyperdistention(18). Therefore, the output is [the disease indices are (8, 18)]."
#     },
    
#     "24": {
#         "question": "Below is the medical report \n\n[1. Emphysematous changes. 2. Resolution of prior right midlung infiltrate.. Previous sulcal is normal in size and contour. Lungs are clear. No focal consolidation, pneumothorax, or pleural effusion. Interval resolution of previously described right midlung opacity suggesting resolved inflammatory/infectious process. Lungs are hyperexpanded with flattened diaphragms. XXXX and soft tissue are unremarkable.]",
#         "answer": "The report mentions that 'Lungs are hyperexpanded with flattened diaphragms', suggesting hypoinflation / hyperdistention(18); The report mentions that 'Emphysematous changes', suggesting emphysema / pulmonary emphysema(10)"
#     },
    
#     "25": {
#         "question": "Below is the medical report \n\n[Right lower lobe pneumonia.. Heart size is within normal limits. Tortuous thoracic aorta. There is patchy right base airspace disease. No pneumothorax or pleural effusion. There mild degenerative changes throughout the thoracic spine.]",
#         "answer": "The report mentions that 'Right lower lobe pneumonia', suggesting pneumonia / infiltrate / consolidation(11); The report mentions that 'There is patchy right base airspace disease', suggesting airspace disease(17); The report mentions that 'There mild degenerative changes throughout the thoracic spine.', suggesting scoliosis / degenerative(3). Therefore, the output is [the disease indices are (3,11,17)]"
#     }

}

In [7]:
import json

out_file = open("candidate_examples.json", "w")
  
json.dump(candidate_examples, out_file)

out_file.close()

In [8]:
cand_examples = []
cand_answers = []
for item in candidate_examples.items():
    cand_examples.append(item[1]["question"])
    cand_answers.append(item[1]["answer"])

In [9]:
# hyperparameters for learning policy models
gpu = "0"
lr = 0.001
batch_size = 20
epochs = 30
shot_number = 5
training_sample_number = 80

ckpt_path = "./results"

In [10]:
import copy
def get_gpt3_output(text):
    max_tries = 3
    for try_number in range(max_tries):
        
        try:
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
#                 model="gpt-4-0314",
                messages= text
            )
            return response['choices'][0]['message']['content']

        except openai.error.APIError as e:
            if try_number == max_tries - 1:
                print('APIError')
                return '\n'
            else:
                time.sleep(0.1)
        except openai.error.Timeout as e:
            if try_number == max_tries - 1:
                print('Timeout')
                return '\n'
            else:
                time.sleep(0.1)
        except openai.error.APIConnectionError as e:
            if try_number == max_tries - 1:
                print('APIConnectionError')
                return '\n'
            else:
                time.sleep(0.1)
        except openai.error.RateLimitError as e:
            if try_number == max_tries - 1:
                print('RateLimitError')
                return 'error'
            else:
                time.sleep(0.1)

In [11]:
def build_prompt(test_example, cot_examples, additional=""):
    prompt_query = "Below is the medical report \n\n[{}] \n\nBelow is a set of strict rules to follow, which are important and must be considered in the classification and generating the answers: \n\n 1. A report must not be classified into 'normal (1)' and disease labels 2-20 simultaneously! \n 2. A report must not be classified into 'other findings(20)' and disease labels 1-19 simultaneously. \n3. Cardiomegaly (2) is related to heart disease. \n 4. scoliosis / degenerative (3) is related to spine disease.\n 5. fractures bone (4) is related to the bone disease. \n 6. pleural effusion(5) thickening(6), and pneumothorax(7) are all related to the pleural disease\n 7. hernia hiatal (8) and calcinosis(9) are both related to the Mediastinum disease. \n 8. emphysema / pulmonary emphysema(10) pneumonia / infiltrate / consolidation(11) pulmonary edema(12) pulmonary atelectasis (13) cicatrix(14) opacity(15), and nodule / mass(16) are all related to lung disease. \n 9. airspace disease(17), and hypoinflation / hyperdistention(18) are both related to airspace disease. \n\n Please strictly output with the following format: [the disease indices are: ()], where the bracket should be filled with a number from 1 to 20. Strictly follow the rules and the format when generating the answers.".format(test_example)
    messages = [{"role": "system",
      "content": "You are a helpful assistant that read X-ray report and conduct multi-label classification task. Below is the list of all the potential disease labels. \n\n 1. normal / no indexing \n\n 2. cardiomegaly  \n\n 3. scoliosis / degenerative \n\n 4. fractures bone \n\n 5. pleural effusion \n\n 6. thickening \n\n 7. pneumothorax\n\n 8. hernia hiatal \n\n 9. calcinosis \n\n 10. emphysema / pulmonary emphysema\n\n 11. pneumonia / infiltrate / consolidation\n\n 12. pulmonary edema\n\n 13. pulmonary atelectasis \n\n 14. cicatrix \n\n 15. opacity \n\n 16. nodule / mass \n\n 17. airspace disease: \n \n 18. hypoinflation / hyperdistention\n\n 19. catheters indwelling / surgical instruments / tube inserted / medical device \n\n 20. other findings"}] + [i for i in cot_examples] + [{
        "role": "user",
        "content": prompt_query
    }]
    return messages

In [12]:
def extract_prediction(output):
    labels = re.findall(r"the disease indices are\s*[:\-\s]*\(\s*(\d+(\s*,\s*\d+)*)\s*\)", str(output))
    pattern = r"\d+"
    matches = re.findall(pattern, str(labels))
    labels = [eval(i) for i in matches]
    
    if labels == []:
        labels = [1]
    vector = np.zeros((20, ))
    for i in labels:
        if i > 20:
            continue
        vector[i-1] = 1
        
    return vector

In [13]:
def get_batch_reward_loss(scores, cand_examples, cand_answers, train_batch, label_batch):

    batch_loss = 0
    batch_reward = 0

    ## loop over the training examples
    for i in range(len(scores)):

        # interact with the environment to get rewards, which in our case is to feed the prompt into GPT-3 and evaluate the prediction
        cand_prob = scores[i, :].clone().detach()
        cand_prob = cand_prob.cpu().numpy()
        cand_prob = np.nan_to_num(cand_prob, nan=0.000001)  # replace np.nan with 0
        cand_prob /= cand_prob.sum()  # make probabilities sum to 1
        # print(f"cand_prob: {cand_prob}")

        # sample shot_pids from the cand_prob distribution
        cids = np.random.choice(range(len(cand_examples)), shot_number, p=cand_prob, replace=False)

        # reverse shot_pids so more relevant prompt will be put closer to the question
        cids = cids[::-1]
        
        
        # construct prompts
        cot_examples = []
        for idx in cids:
                cot_examples_temp = [{"role": "user",
                     "content": str(cand_examples[idx])},
                    {"role": "assistant",
                    "content": str(cand_answers[idx])}]
                cot_examples += cot_examples_temp

        # generate the prompt input
        prompt = build_prompt(train_batch[i], cot_examples)
        
        

        # get the output from GPT-3
        output = get_gpt3_output(prompt)

        # extract the prediction from the output
        prediction = extract_prediction(output)

        loss = log_loss(label_batch[i], prediction)
        
        count_correct = np.sum(np.array(label_batch[i]) == np.array(prediction))
        
        count_incorrect = -10 * np.sum(np.array(label_batch[i]) != np.array(prediction))
        
        _reward = (count_correct + count_incorrect) / len(label_batch[i])

        log_prob = 0
        for cid in cids:
            log_prob += torch.log(scores[i, cid])
        print(f"reward: {_reward}")

        batch_reward += _reward
        batch_loss -= _reward * log_prob

    return cids, batch_reward, batch_loss

In [14]:
def policy_gradient_train(policy_model, train_set, cand_examples, cand_answers):
    # REINFORCE

    optimizer = torch.optim.Adam(policy_model.parameters(), lr=lr)

    train_samples, train_labels = [], []
    for idx, item in enumerate(train_set):
        if idx >= training_sample_number:
            break
        label, caption, caseid = item
        train_samples.append(caption)
        train_labels.append(label)
        

    num_batch = math.ceil(len(train_samples) / batch_size)

    reward_history = []
    loss_history = []

    total_reward_history = []  # epoch based
    total_loss_history = []  # epoch based

    STOP_FLAG = False

    for epoch in range(epochs):
        print("Epoch: {epoch}")

        total_train_reward = 0
        total_train_loss = 0

        # We can simply set the batch_size to len(train_data) in few-shot setting.
        for batch_i in range(num_batch):
            print(f"Batch: {batch_i}")
            train_batch = train_samples[batch_i * batch_size:(batch_i + 1) * batch_size]
            label_batch = train_labels[batch_i * batch_size:(batch_i + 1) * batch_size]

            # We need to encode cands again every time we update the network
            embedding_cands = policy_model(cand_examples)  # len(cand_examples) x embedding_size
            embedding_ctxt = policy_model(train_batch)  # len(train_batch) x embedding_size
            
            

            scores = torch.mm(embedding_ctxt, embedding_cands.t())  # len(train_batch) x len(cand_examples)

            scores = F.softmax(scores, dim=1)  # len(train_batch) x len(cand_examples)

            cids, reward, loss = get_batch_reward_loss(scores, cand_examples, cand_answers, train_batch, label_batch)

            print(f"cids for sample[-1] in batch: {cids}")
            print(f"Cand prob for sample[-1] in batch: {[round(x,5) for x in scores[-1, :].tolist()]}")
            print(f"### reward for the batch: {reward}")
            print(f"### loss for the batch: {loss}\n")

            # linear layer has Weight and bias
            # prev_param = list(policy_model.linear.parameters())[0].clone()
            # print(f"prev_param: {prev_param.data}")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # for each iteration/batch
            total_train_reward += reward
            total_train_loss += loss.item()

            reward_history.append(reward)
            loss_history.append(loss.item())

            if np.isnan(loss.item()):
                STOP_FLAG = True
                break

        # for each epoch
        total_reward_history.append(total_train_reward)
        total_loss_history.append(total_train_loss)

        best_reward = max(total_reward_history)
        best_loss = min(total_loss_history)

        best_reward_epoch = total_reward_history.index(best_reward)
        best_loss_epoch = total_loss_history.index(best_loss)

        print("============================================")
        print(f"### Epoch: {epoch} / {epochs}")
        print(f"### Total reward: {total_train_reward}, " + f"Total loss: {round(total_train_loss,5)}, " +
                     f"Best reward: {best_reward} at epoch {best_reward_epoch}, " +
                     f"Best loss: {round(best_loss, 5)} at epoch {best_loss_epoch}\n")

        # save every epoch
        ckpt_file = os.path.join(ckpt_path, f"ckpt_{epoch}.pt")
        torch.save(policy_model.linear.state_dict(), ckpt_file)
        print(f"saved the ckpt to {ckpt_file}")

        # save best epoch
        if epoch == best_reward_epoch:
            ckpt_file = os.path.join(ckpt_path, "{}_{}_ckpt_best_reward.pt".format(str(shot_number), str(training_sample_number)))
            torch.save(policy_model.linear.state_dict(), ckpt_file)
            print(f"saved the best reward ckpt to {ckpt_file}")

        if epoch == best_loss_epoch:
            ckpt_file = os.path.join(ckpt_path, "{}_{}_ckpt_best_loss.pt".format(str(shot_number), str(training_sample_number)))
            torch.save(policy_model.linear.state_dict(), ckpt_file)
            print(f"saved the best loss ckpt to {ckpt_file}")

        # save reward and loss history
        history = {
            "reward_history": reward_history,
            "loss_history": loss_history,
            "total_reward_history": total_reward_history,
            "total_loss_history": total_loss_history,
        }
        history_file = os.path.join(ckpt_path, "{}_{}_history.json".format(str(shot_number), str(training_sample_number)))
        with open(history_file, 'w') as f:
            json.dump(history, f, indent=2, separators=(',', ': '))

        if STOP_FLAG:
            break

    # save in the end
    ckpt_file = os.path.join(ckpt_path, "{}_{}_ckpt_final.pt".format(str(shot_number), str(training_sample_number)))
    torch.save(policy_model.linear.state_dict(), ckpt_file)

In [16]:
# Learn Policy Network
from model import policy_network


## policy network
policy_model = policy_network(model_config="monologg/biobert_v1.1_pubmed",
                                add_linear=True,
                                embedding_size=128,
                                freeze_encoder=True)

device = torch.device("cuda:" + gpu if torch.cuda.is_available() else "cpu")  # one GPU
policy_model = policy_model.to(device)

policy_gradient_train(policy_model, train_set, cand_examples, cand_answers)

model_config: monologg/biobert_v1.1_pubmed


Some weights of the model checkpoint at monologg/biobert_v1.1_pubmed were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at 

In [None]:
# Policy Network Inference Stage

ckpt_file = os.path.join(ckpt_path, "{}_{}_ckpt_best_reward.pt".format(str(shot_number), str(training_sample_number)))
# ckpt_file = os.path.join(ckpt_path, "5_ckpt_best_reward.pt")
if os.path.exists(ckpt_file):
    policy_model.linear.load_state_dict(torch.load(ckpt_file))
else:
    print("Wrong===========")
policy_model.eval()

answers = []
true_ids = []
with torch.no_grad():
    # Calculate the embeddings for candidate examples only one time!
    cand_embedding = policy_model(cand_examples)
    
    for t_id, item in enumerate(test_set):
        if t_id >= 199:
            break
        label, caption, caseid = item
        ctxt_embedding = policy_model([caption])
        # print("ctxt_embedding:", ctxt_embedding.shape)  # [1 x emb_size]

        scores = F.softmax(torch.mm(ctxt_embedding, cand_embedding.t()), dim=1)[0]
        scores = scores.cpu().detach().numpy().tolist()
        
        cids = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:shot_number]
        
        
        # construct prompts
        cot_examples = []
        for idx in cids:
                cot_examples_temp = [{"role": "user",
                     "content": str(cand_examples[idx])},
                    {"role": "assistant",
                    "content": str(cand_answers[idx])}]
                cot_examples += cot_examples_temp
        
        print(caption, label)
        print()
        print(cot_examples)
        
        input()
                
        prompt = build_prompt(caption, cot_examples)
        # get the output from GPT-3
        output = get_gpt3_output(prompt)
        
        print(output)
        print(label)
        
        print("==================Index {}==================".format(t_id))

        answers.append(output)
        true_ids.append(list(label.numpy()))


### Extract output from the ChatGPT (regular expression)

No acute cardiopulmonary findings.. The cardiomediastinal silhouette and pulmonary vasculature are within normal limits in size. The lungs are clear of focal airspace disease, pneumothorax, or pleural effusion. There are no acute bony findings. tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.])

[{'role': 'user', 'content': 'Below is the medical report \n\n[1. No acute cardiopulmonary findings.. The heart size and mediastinal contours appear within normal limits. Atherosclerotic calcification of the aorta. No focal airspace consolidation, pleural effusions or pneumothorax. Questionable thin-walled cavitary lesion in the right lower lobe, only seen on the AP view and may represent artifact. No acute bony abnormalities.]'}, {'role': 'assistant', 'content': "The report mentions 'Atherosclerotic calcification of the aorta', but this disease does not belong to any of the potential labels from 2 to 19. Besides, according to the rule 2, 'A report 


The report mentions that 'XXXX right apical pneumothorax measuring approximately 5 mm in thickness', suggesting pneumothorax(7); the report mentions that 'Multiple right-sided rib fractures involving at XXXX the right anterior 5th through 9th ribs with mild displacement', suggesting fractures bone(4); the report mentions that 'Mild right basilar airspace disease, atelectasis versus contusion', suggesting airspace disease(17) and pulmonary atelectasis(13); the report mentions that 'There is extensive subcutaneous emphysema in the right chest wall and neck', suggesting emphysema / pulmonary emphysema(10); the report mentions that 'There is mild streaky airspace disease in the right lung base', suggesting airspace disease(17). According to the rule 6, 'pleural effusion(5), thickening(6), and pneumothorax(7) are all related to pleural disease'. Therefore, the output is [the disease indices are: (4, 7, 10, 13, 17)]
tensor([0., 0., 0., 1., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 1., 0., 1.,


The report mentions that 'Suspected mild patchy right upper lobe pneumonia', suggesting pneumonia / infiltrate / consolidation(11); the report mentions that 'Stable mild background chronic interstitial changes', suggesting that there might exist one or more related lung diseases from the list of [emphysema / pulmonary emphysema(10), pulmonary edema(12), pulmonary atelectasis (13), cicatrix(14), opacity(15), and nodule / mass(16)]. However, according to rule 8, a report must not be classified into two or more disease labels from this group simultaneously. Therefore, the only output is [the disease indices are: (11)].
tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0.])
No acute cardiopulmonary finding.. The heart size and cardiomediastinal silhouette are normal. There is no focal airspace opacity, pleural effusion or pneumothorax. The osseous structures are intact. tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [38]:
import numpy as np
import torch

import re

predictions_results = []
for item in answers:
    print(item)
    labels = re.findall(r"the disease indices are\s*[:\-\s]*\(\s*(\d+(\s*,\s*\d+)*)\s*\)", str(item))
    pattern = r"\d+"
    matches = re.findall(pattern, str(labels))
    labels = [eval(i) for i in matches]
    print(labels)
    if labels == []:
        labels = [1]
    vector = np.zeros((20, ))
    for i in labels:
        if i > 20:
            continue
        vector[i-1] = 1
    predictions_results.append(vector)

predictions_results = torch.Tensor(predictions_results)
true_ids = torch.Tensor(true_ids)

The report mentions that 'No acute cardiopulmonary findings', suggesting normal/ no indexing(1). Therefore, the output is [the disease indices are: (1)]
[1]
The report mentions 'Right apical pneumothorax', suggesting pneumothorax(7); The report mentions 'Multiple right-sided rib fractures involving at least the right anterior 5th through 9th ribs', suggesting fractures bone(4); The report mentions 'mild streaky airspace disease in the right lung base', suggesting airspace disease(17); The report mentions 'small hiatal hernia', suggesting hernia hiatal(8); The report mentions 'extensive subcutaneous emphysema in the right chest wall and neck', suggesting emphysema / pulmonary emphysema(10); The report mentions 'intrathecal catheter terminating in the lower thoracic spine', suggesting catheters indwelling / surgical instruments / tube inserted / medical device(19). Therefore, the output is [the disease indices are: (4, 7, 8, 10, 17, 19)].
[4, 7, 8, 10, 17, 19, 19]
The report mentions tha

### Statistics on the Results

In [39]:
statistics = []
for j in range(true_ids.shape[1]):
    TP_j = ((predictions_results[:, j] == 1) & (true_ids[:, j] == 1)).sum().item()
    TN_j = ((predictions_results[:, j] == 0) & (true_ids[:, j] == 0)).sum().item()
    FP_j = ((predictions_results[:, j] == 1) & (true_ids[:, j] == 0)).sum().item()
    FN_j = ((predictions_results[:, j] == 0) & (true_ids[:, j] == 1)).sum().item()
    accuracy_j = (TP_j + TN_j) / (TP_j + TN_j + FP_j + FN_j)
    print(f"Accuracy of label {j}: {accuracy_j:.4f}")
    precision_j = TP_j / (TP_j + FP_j) if TP_j + FP_j > 0 else 0
    recall_j = TP_j / (TP_j + FN_j) if TP_j + FN_j > 0 else 0
    print(f"Label {j} - Precision: {precision_j:.4f}, Recall: {recall_j:.4f}")
    f1_j = 2 * precision_j * recall_j / (precision_j + recall_j) if precision_j + recall_j > 0 else 0

    
    statistics.append([precision_j, recall_j, f1_j])
    print("")

Accuracy of label 0: 0.8593
Label 0 - Precision: 0.7664, Recall: 0.9647

Accuracy of label 1: 0.9749
Label 1 - Precision: 0.8333, Recall: 0.8824

Accuracy of label 2: 0.9196
Label 2 - Precision: 0.7097, Recall: 0.7586

Accuracy of label 3: 0.9598
Label 3 - Precision: 0.3636, Recall: 0.8000

Accuracy of label 4: 0.9849
Label 4 - Precision: 0.7143, Recall: 0.8333

Accuracy of label 5: 0.9899
Label 5 - Precision: 1.0000, Recall: 0.3333

Accuracy of label 6: 0.9950
Label 6 - Precision: 1.0000, Recall: 0.5000

Accuracy of label 7: 0.9849
Label 7 - Precision: 0.6000, Recall: 0.7500

Accuracy of label 8: 0.9347
Label 8 - Precision: 0.3077, Recall: 0.5000

Accuracy of label 9: 0.9849
Label 9 - Precision: 0.5714, Recall: 1.0000

Accuracy of label 10: 0.9799
Label 10 - Precision: 0.5714, Recall: 0.8000

Accuracy of label 11: 1.0000
Label 11 - Precision: 1.0000, Recall: 1.0000

Accuracy of label 12: 0.9849
Label 12 - Precision: 0.9091, Recall: 0.8333

Accuracy of label 13: 0.9648
Label 13 - Preci

In [40]:
import pandas as pd
pd.DataFrame(statistics, columns=["Precision", "Recall", "F1 Score"])

Unnamed: 0,Precision,Recall,F1 Score
0,0.766355,0.964706,0.854167
1,0.833333,0.882353,0.857143
2,0.709677,0.758621,0.733333
3,0.363636,0.8,0.5
4,0.714286,0.833333,0.769231
5,1.0,0.333333,0.5
6,1.0,0.5,0.666667
7,0.6,0.75,0.666667
8,0.307692,0.5,0.380952
9,0.571429,1.0,0.727273


In [41]:
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, hamming_loss
from sklearn.metrics import classification_report

# assume y_true and y_pred are binary label vectors of size (N, K)
# where N is the number of examples and K is the number of labels

# compute overall precision, recall, and accuracy
overall_precision, overall_recall, overall_f1_score, _  = precision_recall_fscore_support(true_ids, predictions_results, average="weighted")
overall_accuracy = accuracy_score(true_ids, predictions_results,  normalize=True, sample_weight=None)

# print the results
print("Overall precision: {}".format(overall_precision))
print("Overall recall: {}".format(overall_recall))
print("Overall F1-score: {}".format(overall_f1_score))
print("Hamming Loss {}".format(hamming_loss(true_ids, predictions_results)))
print("Exact Matching: {}".format(overall_accuracy))

print(classification_report(true_ids, predictions_results))


Overall precision: 0.6477885857293144
Overall recall: 0.6834532374100719
Overall F1-score: 0.6461481105617842
Hamming Loss 0.04472361809045226
Exact Matching: 0.5477386934673367
              precision    recall  f1-score   support

           0       0.77      0.96      0.85        85
           1       0.83      0.88      0.86        17
           2       0.71      0.76      0.73        29
           3       0.36      0.80      0.50         5
           4       0.71      0.83      0.77         6
           5       1.00      0.33      0.50         3
           6       1.00      0.50      0.67         2
           7       0.60      0.75      0.67         4
           8       0.31      0.50      0.38         8
           9       0.57      1.00      0.73         4
          10       0.57      0.80      0.67         5
          11       1.00      1.00      1.00         2
          12       0.91      0.83      0.87        12
          13       0.75      0.33      0.46         9
          1