In [2]:
import sys
sys.path.append("/Users/claire/Desktop/Stuff-/codes/dissertation/cxr/rgrg/src/full_model")
sys.path.append("/Users/claire/Desktop/Stuff-/codes/dissertation/cxr/rgrg/src")
# import sys
sys.path.append("/Users/claire/Desktop/Stuff-/codes/dissertation/cxr/rgrg/")

from full_model.train_full_model import *
from full_model.generate_reports_for_images import *
from full_model.test_set_evaluation import *

In [6]:
raw_test_dataset, raw_test_2_dataset = get_dataset()

def get_tokenized_dataset(tokenizer, raw_test_2_dataset):
    def tokenize_function(example):
        phrases = example["bbox_phrases"]  # List[str]
        bos_token = "<|endoftext|>"  # note: in the GPT2 tokenizer, bos_token = eos_token = "<|endoftext|>"
        eos_token = "<|endoftext|>"

        phrases_with_special_tokens = [bos_token + phrase + eos_token for phrase in phrases]

        # the tokenizer will return input_ids of type List[List[int]] and attention_mask of type List[List[int]]
        return tokenizer(phrases_with_special_tokens, truncation=True, max_length=1024)

    tokenized_test_2_dataset = raw_test_2_dataset.map(tokenize_function)

    # tokenized datasets will consist of the columns
    #   - mimic_image_file_path (str)
    #   - bbox_coordinates (List[List[int]])
    #   - bbox_labels (List[int])
    #   - bbox_phrases (List[str])
    #   - input_ids (List[List[int]])
    #   - attention_mask (List[List[int]])
    #   - bbox_phrase_exists (List[bool])
    #   - bbox_is_abnormal (List[bool])
    #   - reference_report (str)

    return tokenized_test_2_dataset

# note that we don't actually need to tokenize anything (i.e. we don't need the input ids and attention mask),
# because we evaluate the language model on it's generation capabilities (for which we only need the input images)
# but since the custom dataset and collator are build in a way that they expect input ids and attention mask
# (as they were originally made for training the model),
# it's better to just leave it as it is instead of adding unnecessary complexity
tokenizer = get_tokenizer()
tokenized_test_2_dataset = get_tokenized_dataset(tokenizer,  raw_test_2_dataset)

test_transforms = get_transforms()
# model = get_model()





# obj_detector_scores, region_selection_scores, region_abnormal_scores = evaluate_obj_detector_and_binary_classifiers_on_test_set(model, test_loader, test_2_loader)
# evaluate_language_model_on_test_set(model, test_loader, test_2_loader, tokenizer)

  0%|          | 0/1173 [00:00<?, ?ex/s]

In [16]:
import cv2
import torch
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    def __init__(self, dataset_name: str, tokenized_dataset, transforms, log):
        super().__init__()
        self.dataset_name = dataset_name
        self.tokenized_dataset = tokenized_dataset
        self.transforms = transforms
        self.log = log

    def __len__(self):
        return len(self.tokenized_dataset)

    def __getitem__(self, index):
        # get the image_path for potential logging in except block
        image_path = self.tokenized_dataset[index]["mimic_image_file_path"]

        # if something in __get__item fails, then return None
        # collate_fn in dataloader filters out None values
        try:
            bbox_coordinates = self.tokenized_dataset[index]["bbox_coordinates"]  # List[List[int]]
            bbox_labels = self.tokenized_dataset[index]["bbox_labels"]  # List[int]
            input_ids = self.tokenized_dataset[index]["input_ids"]  # List[List[int]]
            attention_mask = self.tokenized_dataset[index]["attention_mask"]  # List[List[int]]
            bbox_phrase_exists = self.tokenized_dataset[index]["bbox_phrase_exists"]  # List[bool]
            bbox_is_abnormal = self.tokenized_dataset[index]["bbox_is_abnormal"]  # List[bool]

            if self.dataset_name != "train":
                # we only need the reference phrases during evaluation when computing scores for metrics
                bbox_phrases = self.tokenized_dataset[index]["bbox_phrases"]  # List[str]

                # same for the reference_report
                reference_report = self.tokenized_dataset[index]["reference_report"]  # str

            # cv2.imread by default loads an image with 3 channels
            # since we have grayscale images, we only have 1 channel and thus use cv2.IMREAD_UNCHANGED to read in the 1 channel
            image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)[:,:,0]
            image = cv2.resize(image, (512, 512))


            # apply transformations to image, bbox_coordinates and bbox_labels
            transformed = self.transforms(image=image, bboxes=bbox_coordinates, class_labels=bbox_labels)

            transformed_image = transformed["image"]

            transformed_bbox_coordinates = transformed["bboxes"]
            transformed_bbox_labels = transformed["class_labels"]
            print(transformed_bbox_coordinates)
            transformed_bbox_coordinates = [[x * 2 for x in bbox] for bbox in transformed_bbox_coordinates]
            print(transformed_bbox_coordinates)
            sample = {
                "image": transformed_image,
                "bbox_coordinates": torch.tensor(transformed_bbox_coordinates, dtype=torch.float),
                "bbox_labels": torch.tensor(transformed_bbox_labels, dtype=torch.int64),
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "bbox_phrase_exists": torch.tensor(bbox_phrase_exists, dtype=torch.bool),
                "bbox_is_abnormal": torch.tensor(bbox_is_abnormal, dtype=torch.bool),
            }

            if self.dataset_name != "train":
                sample["bbox_phrases"] = bbox_phrases
                sample["reference_report"] = reference_report

        except Exception as e:
            self.log.error(f"__getitem__ failed for: {image_path}")
            self.log.error(f"Reason: {e}")
            return None

        return sample, transformed_bbox_coordinates


In [19]:
dataset = CustomDataset("test", tokenized_test_2_dataset.select(range(20)), test_transforms, log)
a, b = dataset.__getitem__(0)
np.max(b)/2

[(42.0, 3.0, 137.0, 130.0), (59.0, 3.0, 133.0, 65.0), (30.0, 114.0, 53.0, 145.0), (46.0, 102.0, 129.0, 137.0), (142.0, 4.0, 219.0, 113.0), (144.0, 9.0, 214.0, 66.0), (142.0, 66.0, 218.0, 83.0), (144.0, 83.0, 219.0, 113.0), (142.0, 62.0, 180.0, 86.0), (148.0, 4.0, 204.0, 38.0), (208.0, 99.0, 230.0, 122.0), (142.0, 90.0, 219.0, 121.0), (123.0, 19.0, 147.0, 75.0), (126.0, 0.0, 154.0, 254.0), (77.0, 5.0, 126.0, 34.0), (149.0, 5.0, 198.0, 34.0), (147.0, 41.0, 170.0, 66.0), (96.0, 14.0, 192.0, 126.0), (114.0, 24.0, 171.0, 69.0), (114.0, 41.0, 147.0, 69.0), (96.0, 70.0, 192.0, 126.0), (96.0, 70.0, 139.0, 89.0), (96.0, 89.0, 139.0, 126.0), (130.0, 67.0, 138.0, 75.0), (32.0, 129.0, 225.0, 254.0)]
[[84.0, 6.0, 274.0, 260.0], [118.0, 6.0, 266.0, 130.0], [60.0, 228.0, 106.0, 290.0], [92.0, 204.0, 258.0, 274.0], [284.0, 8.0, 438.0, 226.0], [288.0, 18.0, 428.0, 132.0], [284.0, 132.0, 436.0, 166.0], [288.0, 166.0, 438.0, 226.0], [284.0, 124.0, 360.0, 172.0], [296.0, 8.0, 408.0, 76.0], [416.0, 198.0, 

254.0