# Example of VRO-GREEN

In [1]:
import re
import pickle
import torch
import torch.nn as nn
from green_score import GREEN
from transformers import AutoModelForCausalLM, AutoTokenizer
from green_score.utils import process_responses, make_prompt, tokenize_batch_as_chat, truncate_to_max_len
# Module-level memo shared by every GREEN instance: maps a (reference, hypothesis)
# pair to the (score, output_ids) tuple that was computed for it.
pair_to_reward_dict = {}
class GREENModel(nn.Module):
    """
    GREENModel wraps a causal language model that judges radiology reports.

    The model generates a textual error analysis for each prompt; the analysis is
    then parsed into error counts from which the GREEN score is derived.

    Args:
        cuda (bool): If true, load the model on the current CUDA device and move
            inputs there in ``forward``; otherwise load it on the CPU.
        model_id_or_path (str): Path or identifier of the pretrained model.
        do_sample (bool): Whether to sample during generation.
        batch_size (int): Stored on the instance; not read by any method of this class.
        return_0_if_no_green_score (bool): When a category section cannot be found
            in a response, report zero counts (True) or ``None`` counts (False).

    Attributes:
        model: Pretrained model for causal language modeling (in eval mode).
        tokenizer: Tokenizer associated with the model (left padding, pad == eos).
        categories (list): List of evaluation categories.
        sub_categories (list): List of subcategories for error evaluation.
    """

    def __init__(
        self,
        cuda,
        model_id_or_path,
        do_sample=False,
        batch_size=4,
        return_0_if_no_green_score=True,
    ):
        super().__init__()
        # NOTE: this instance attribute shadows nn.Module.cuda(); it is kept because
        # forward() and external callers read it as a flag.
        self.cuda = cuda
        self.do_sample = do_sample
        self.batch_size = batch_size
        self.return_0_if_no_green_score = return_0_if_no_green_score

        if cuda:
            device_map = {"": "cuda:{}".format(torch.cuda.current_device())}
        else:
            device_map = "cpu"
        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_id_or_path,
            trust_remote_code=True,
            device_map=device_map,
            torch_dtype=torch.float16,
        )
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=model_id_or_path,
            add_eos_token=True,
            use_fast=True,
            trust_remote_code=True,
            padding_side="left",
        )
        # The eos token doubles as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.chat_template = "{% for message in messages %}\n{% if message['from'] == 'human' %}\n{{ '<|user|>\n' + message['value'] + eos_token }}\n{% elif message['from'] == 'system' %}\n{{ '<|system|>\n' + message['value'] + eos_token }}\n{% elif message['from'] == 'gpt' %}\n{{ '<|assistant|>\n'  + message['value'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"

        self.categories = [
            "Clinically Significant Errors",
            "Clinically Insignificant Errors",
            "Matched Findings",
        ]

        self.sub_categories = [
            "(a) False report of a finding in the candidate",
            "(b) Missing a finding present in the reference",
            "(c) Misidentification of a finding's anatomic location/position",
            "(d) Misassessment of the severity of a finding",
            "(e) Mentioning a comparison that isn't in the reference",
            "(f) Omitting a comparison detailing a change from a prior study",
        ]

    def get_response(self, input_ids, attention_mask):
        """
        Generate one response per input row and post-process the decoded text.

        Args:
            input_ids (Tensor): Input IDs for the model.
            attention_mask (Tensor): Attention mask for the input IDs.

        Returns:
            tuple: ``(response_list, outputs)`` where ``response_list`` is the
            result of ``process_responses`` on the decoded text and ``outputs``
            is the raw tensor returned by ``generate``.
        """
        outputs = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            do_sample=self.do_sample,
            max_length=2048,
            temperature=None,
            top_p=None,
        )

        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return process_responses(decoded), outputs

    def parse_error_counts(self, text, category, for_reward=False):
        """
        Parse the error counts of one category out of a generated response.

        Args:
            text (str): Text to parse for error counts.
            category (str): Category of errors to parse; must be in ``self.categories``.
            for_reward (bool): Accepted for interface compatibility; currently unused.

        Returns:
            tuple: ``(total, sub_counts)``. ``sub_counts`` has one entry per
            subcategory. For "Matched Findings" only ``total`` is filled in.
            If the category section is missing, both are zero-valued when
            ``return_0_if_no_green_score`` is set, otherwise ``None``-valued.

        Raises:
            ValueError: If ``category`` is not one of ``self.categories``.
        """
        if category not in self.categories:
            raise ValueError(
                f"Category {category} is not a valid category. Please choose from {self.categories}."
            )

        n_sub = len(self.sub_categories)

        # The section runs from its "[Category]:" header to the next blank line or
        # the end of the text. The name is escaped so it is always matched literally.
        pattern = rf"\[{re.escape(category)}\]:\s*(.*?)(?:\n\s*\n|\Z)"
        section = re.search(pattern, text, re.DOTALL)

        if section is None:
            if self.return_0_if_no_green_score:
                return 0, [0] * n_sub
            return None, [None] * n_sub

        body = section.group(1)
        sub_counts = [0] * n_sub

        # A section beginning with "No" (e.g. "No ... errors") means nothing to count.
        if body.startswith("No"):
            return 0, sub_counts

        if category == "Matched Findings":
            # The count is the leading integer directly followed by a period.
            leading = re.findall(r"^\b\d+\b(?=\.)", body)
            total = int(leading[0]) if leading else 0
            return total, sub_counts

        # Error categories: one "(x) <description>: <count>. ..." line per subcategory.
        labels = [s.split(" ", 1)[0] + " " for s in self.sub_categories]
        lines = sorted(re.findall(r"\([a-f]\) .*", body))

        if not lines:
            # Fall back to numeric labels "(1)".."(6)" when no lettered ones are found.
            lines = sorted(re.findall(r"\([1-6]\) .*", body))
            labels = [f"({i}) " for i in range(1, n_sub + 1)]

        for position, label in enumerate(labels):
            for line in lines:
                if not line.startswith(label):
                    continue
                count = re.findall(r"(?<=: )\b\d+\b(?=\.)", line)
                if count:
                    sub_counts[position] = int(count[0])
        return sum(sub_counts), sub_counts

    def compute_green(self, response):
        """
        Compute the green score from significant errors and matched findings.

        Args:
            response (str): Generated response to evaluate.

        Returns:
            float | int | None: ``0`` when there are no matched findings; ``None``
            when either count could not be parsed (only possible when
            ``return_0_if_no_green_score`` is False); otherwise
            ``matched / (matched + significant_errors)``.
        """
        sig_present, sig_errors = self.parse_error_counts(response, self.categories[0])
        matched_findings, _ = self.parse_error_counts(response, self.categories[2])

        # Checked before the None test on purpose: zero matches gives a score of 0
        # regardless of the error counts (None == 0 is False, so None falls through).
        if matched_findings == 0:
            return 0

        if sig_present is None or matched_findings is None:
            return None

        return matched_findings / (matched_findings + sum(sig_errors))

    def forward(self, input_ids, attention_mask):
        """
        Generate responses for a batch and turn them into green scores.

        Args:
            input_ids (Tensor): Input IDs for the model.
            attention_mask (Tensor): Attention mask for the input IDs.

        Returns:
            tuple: Float tensor of green scores and the generated output IDs.
            Responses whose score is ``None`` are dropped, so the score tensor
            can be shorter than the batch.
        """
        if self.cuda:
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()

        reward_model_responses, output_ids = self.get_response(input_ids, attention_mask)

        scores = (self.compute_green(response) for response in reward_model_responses)
        greens = [score for score in scores if score is not None]
        return torch.tensor(greens, dtype=torch.float), output_ids
class GREEN(nn.Module):
    """
    GREEN is a wrapper around GREENModel that scores (reference, hypothesis) pairs.

    Scores and generated output IDs are memoised per (reference, hypothesis) pair
    in the module-level ``pair_to_reward_dict``, so a pair is only sent through
    the model the first time it is seen.

    Args:
        cuda (bool): Whether to use CUDA for GPU acceleration.
        max_len (int): Limit passed to ``truncate_to_max_len`` for both report lists.
        **kwargs: Forwarded unchanged to ``GREENModel``.

    Attributes:
        model: GREENModel instance for evaluation.
        tokenizer: Tokenizer associated with the model.
    """

    def __init__(self, cuda, max_len=200, **kwargs):
        super().__init__()
        # NOTE: this instance attribute shadows nn.Module.cuda(); kept as a flag
        # because callers may read it.
        self.cuda = cuda
        self.max_len = max_len
        self.model = GREENModel(cuda, **kwargs)
        self.tokenizer = self.model.tokenizer
        if self.cuda:
            # Informational only: the model is not wrapped for multi-GPU execution here.
            print("Using {} GPUs!".format(torch.cuda.device_count()))

    def forward(self, refs, hyps):
        """
        Compute green scores for pairs of reference and hypothesis reports.

        Args:
            refs (list): List of reference reports.
            hyps (list): List of hypothesis reports, same length as ``refs``.

        Returns:
            tuple: Mean green score, tensor of green scores, and list of processed responses.

        Raises:
            RuntimeError: If the model returns a different number of scores than
                pairs submitted, which makes it impossible to tell which score
                belongs to which pair.
        """
        assert len(refs) == len(hyps), "refs and hyps must have the same length"

        refs = truncate_to_max_len(refs, self.max_len)
        hyps = truncate_to_max_len(hyps, self.max_len)

        with torch.no_grad():
            pairs_to_process = []
            final_scores = torch.zeros(len(refs))
            output_ids_dict = {}

            # Serve cached pairs directly; queue the rest for the model.
            for i, (ref, hyp) in enumerate(zip(refs, hyps)):
                cached = pair_to_reward_dict.get((ref, hyp))
                if cached is not None:
                    final_scores[i], output_ids_dict[i] = cached
                else:
                    pairs_to_process.append((ref, hyp, i))

            if pairs_to_process:
                prompts = [make_prompt(ref, hyp) for ref, hyp, _ in pairs_to_process]
                batch = [
                    [{"from": "human", "value": prompt}, {"from": "gpt", "value": ""}]
                    for prompt in prompts
                ]
                batch = tokenize_batch_as_chat(self.tokenizer, batch)

                greens_tensor, output_ids = self.model(batch['input_ids'], batch['attention_mask'])

                if len(greens_tensor) != len(pairs_to_process):
                    # GREENModel.forward drops responses without a parseable score, so
                    # the remaining scores can no longer be matched to their pairs.
                    # Fail here with a clear message instead of hitting a KeyError
                    # when the responses are collected below.
                    raise RuntimeError(
                        "An inconsistency was detected in processing pairs: "
                        f"got {len(greens_tensor)} scores for {len(pairs_to_process)} pairs."
                    )

                for (ref, hyp, idx), score, out_id in zip(pairs_to_process, greens_tensor, output_ids):
                    # Keep cached IDs on the CPU so the global cache does not hold on
                    # to device memory; decoding does not need them on the GPU.
                    out_id = out_id.cpu()
                    pair_to_reward_dict[(ref, hyp)] = (score, out_id)
                    final_scores[idx] = score
                    output_ids_dict[idx] = out_id

            responses = [output_ids_dict[i] for i in range(len(refs))]
            responses = self.tokenizer.batch_decode(responses, skip_special_tokens=True)

            mean_green = final_scores.mean()
            return mean_green, final_scores, process_responses(responses)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# NOTE: pickle.load executes arbitrary code from the file; only use trusted files.

# refs: the 5 reference (original) reports.
with open("prediction_5.pkl", "rb") as refs_file:
    refs = pickle.load(refs_file)

# hyps: for each of the 5 original reports, its 10 sampled (hypothesis) reports.
with open("samples_5.pkl", "rb") as hyps_file:
    hyps = pickle.load(hyps_file)

In [3]:
# Show the first reference report followed by each of its sampled reports.
print("********Original report:********")
print(refs[0])
print("********Sampled reports:********")
for sampled_report in hyps[0]:
    print(sampled_report)

********Original report:********
The patient is status post coronary artery bypass graft surgery.  The cardiac, mediastinal and hilar contours appear unchanged.  There is no pleural effusion or pneumothorax.  The lungs appear clear.  A nodular opacity projecting over the right upper lobe appears unchanged and may be a nipple shadow.  There has been no significant change.
********Sampled reports:********
The patient is status post sternotomy.  Allowing for differences in technique, the cardiac, mediastinal, and hilar contours appear unchanged.  The lungs appear clear.  There are no pleural effusions or pneumothorax.  There has been no significant change.  A previously described nodular opacity in the lingula appears to be a nipple shadow on this examination.
The patient is status post coronary artery bypass graft surgery.  The cardiac, mediastinal and hilar contours appear unchanged.  There is no pleural effusion or pneumothorax.  A nodular density projecting over the left upper lobe is

In [4]:
# Build the GREEN scorer on the GPU; everything except `cuda` is forwarded to GREENModel.
model = GREEN(
    cuda=True,
    model_id_or_path="StanfordAIMI/GREEN-radllama2-7b",
    do_sample=False,  # should always be False
    batch_size=10,
    return_0_if_no_green_score=True,
)

Loading checkpoint shards: 100%|██████████| 3/3 [00:08<00:00,  2.79s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Using 2 GPUs!


In [6]:
# Score every sampled report against its reference (repeated once per sample) and
# report 1 - mean GREEN as the uncertainty for that reference.
for i, samples in enumerate(hyps):
    repeated_ref = [refs[i]] * len(samples)
    mean_green, greens, text = model(repeated_ref, samples)
    print(f'Green uncertainty for {i}-th sample is {1 - mean_green}')

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


Green uncertainty for 0-th sample is 0.5963419675827026


A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


Green uncertainty for 1-th sample is 0.5768253803253174


A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


Green uncertainty for 2-th sample is 0.657012939453125


A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


Green uncertainty for 3-th sample is 0.3086904287338257
Green uncertainty for 4-th sample is 0.503333330154419
