# This Notebook looks at false predictions our models make and compares them to find commonalities among them. 

We also look specifically at cases where the text-only model makes mistakes that our multimodal MMBT model corrects. Even though the MMBT model has slightly lower accuracy than the text-only model, there are some cases where it predicts correctly while the text-only model does not. Since the text remains constant in these cases, we hypothesize that cross-modal interactions explain this behaviour.
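The comparison described above comes down to set operations over misclassified dataset indices. A minimal sketch, assuming each model's errors are available as a dictionary keyed by dataset index; the variable names and the values are hypothetical and are not results from this notebook:

```python
# Hypothetical error dictionaries: {dataset index: (prediction, target)}.
# The entries are made up for illustration only.
text_only_errors = {5: (0, 1), 10: (0, 1), 49: (1, 0)}
mmbt_errors = {10: (0, 1), 77: (1, 0)}

text_only_idx = set(text_only_errors)
mmbt_idx = set(mmbt_errors)

# Misclassified by the text-only model but classified correctly by MMBT.
corrected_by_mmbt = sorted(text_only_idx - mmbt_idx)
# Misclassified by both models.
shared_errors = sorted(text_only_idx & mmbt_idx)
# Classified correctly by the text-only model but misclassified by MMBT.
introduced_by_mmbt = sorted(mmbt_idx - text_only_idx)

print(corrected_by_mmbt, shared_errors, introduced_by_mmbt)
```

The first list is the one of interest for the cross-modal hypothesis, since the text input is identical for both models on those examples.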

## Skip Unless running in Google Colab

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%pwd

'/content'

In [2]:
%cd /content/drive/MyDrive/LAP_MMBT

/content/drive/.shortcut-targets-by-id/1gwgx4ZApTKz5fN6SG9YkiVjVCZ0WNGeH/LAP_MMBT


In [3]:
%ls

baseline_experiments_results.ipynb  mmbt_output_major_findings_10epochs/
data/                               __pycache__/
false_preds.ipynb                   results/
image_submodel.ipynb                run_bert_text_only.ipynb
integrated_gradients/               run_mmbt.ipynb
MMBT/                               run_mmbt_masked_text_eval.ipynb
mmbt_experiment_logs_results/       runs/
mmbt_output_findings_10epochs/      textBert_utils.py
mmbt_output_findings_10epochs_n/    text_only_experiment_logs_results/
mmbt_output_findings_4epochs/       wandb/
mmbt_output_major_8epochs/


## Configure device

In [4]:
import torch

# If there's a GPU available...
if torch.cuda.is_available():    

    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())

    print('We will use the GPU:', torch.cuda.get_device_name(0))

# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) available.
We will use the GPU: Tesla T4


## Install Transformers library

In [5]:
!pip install transformers

Collecting transformers
  Downloading https://files.pythonhosted.org/packages/ed/d5/f4157a376b8a79489a76ce6cfe147f4f3be1e029b7144fa7b8432e8acb26/transformers-4.4.2-py3-none-any.whl (2.0MB)
     |████████████████████████████████| 2.0MB 18.0MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading https://files.pythonhosted.org/packages/71/23/2ddc317b2121117bf34dd00f5b0de194158f2a44ee2bf5e47c7166878a97/tokenizers-0.10.1-cp37-cp37m-manylinux2010_x86_64.whl (3.2MB)
     |████████████████████████████████| 3.2MB 47.8MB/s 
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
     |████████████████████████████████| 890kB 51.1MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... done
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp37-none-any.whl size=893262 sha256=a562

## Resolve current directory

In [6]:
from pathlib import Path
import os

try:
    FILE_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    print('__file__ does not exist for notebook, use current directory instead')
    FILE_DIR = Path().resolve()
    
print(f'current directory is: {FILE_DIR}')

__file__ does not exist for notebook, use current directory instead
current directory is: /content/drive/.shortcut-targets-by-id/1gwgx4ZApTKz5fN6SG9YkiVjVCZ0WNGeH/LAP_MMBT


## Load Data

In [7]:
import pandas as pd
pd.set_option('display.max_colwidth', None)  # None = no truncation; -1 is deprecated in recent pandas
pd.options.display.max_rows = 580

  


In [8]:
DATA_DIR = os.path.join(FILE_DIR, 'data')

CSV_PATH = os.path.join(DATA_DIR, 'csv')

major_csv = os.path.join(CSV_PATH, 'image_labels_major_findings_frontal_test.csv')
baseline_csv = os.path.join(CSV_PATH, 'image_labels_findings_frontal_test.csv')

major = pd.read_csv(major_csv)
baseline = pd.read_csv(baseline_csv)

# all the test sets are of equal length
print(f"There are {len(baseline)} items in the dataset.")

There are 570 items in the dataset.


# sentences and labels

In [11]:
major_sentences = major.text.values
major_labels = major.label.values

base_sentences = baseline.text.values
base_labels = baseline.label.values

# Max sentence length (training)

In [12]:
from transformers import BertTokenizer

# Load the BERT tokenizer.
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

Loading BERT tokenizer...



In [13]:
max_len = 0

# For every sentence...
for sent in major_sentences:

    # Tokenize the text and add `[CLS]` and `[SEP]` tokens.
    input_ids = tokenizer.encode(sent, add_special_tokens=True)

    # Update the maximum sentence length.
    max_len = max(max_len, len(input_ids))

print('Max sentence length: ', max_len)

Max sentence length:  165


In [15]:
#default sentence length
MAX_SENT_LEN = max_len # default: 256

## Helper functions for labeling and multiclass criterion

# Tokenize and Encode data

In [16]:
def tokenize_and_encode_data(sentences_iterable, tokenizer_encoder, max_sent_len, labels_iterable):
    # Tokenize all of the sentences and map the tokens to their word IDs.
    input_ids = []
    attention_masks = []

    # For every sentence...
    for sent, label in zip(sentences_iterable, labels_iterable):
        # `encode_plus` will:
        #   (1) Tokenize the sentence.
        #   (2) Prepend the `[CLS]` token to the start.
        #   (3) Append the `[SEP]` token to the end.
        #   (4) Map tokens to their IDs.
        #   (5) Pad or truncate the sentence to `max_length`
        #   (6) Create attention masks for [PAD] tokens.
        encoded_dict = tokenizer_encoder.encode_plus(
                            sent,                      # Sentence to encode.
                            add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                            max_length = max_sent_len, # Pad & truncate all sentences.
                            padding = 'max_length',
                            truncation = True,         # Explicitly enable truncation to max_length.
                            return_attention_mask = True,   # Construct attn. masks.
                            return_tensors = 'pt',     # Return pytorch tensors.
                    )
        
        # Add the encoded sentence to the list.    
        input_ids.append(encoded_dict['input_ids'])
        
        # And its attention mask (simply differentiates padding from non-padding).
        attention_masks.append(encoded_dict['attention_mask'])
        

    # Convert the lists into tensors.
    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)
    input_labels = torch.tensor(labels_iterable)

    # Print sentence 0, now as a list of IDs.
    print('Original: ', sentences_iterable[0])
    print('Token IDs:', input_ids[0])

    return input_ids, attention_masks, input_labels
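To make steps (2), (3), (5) and (6) above concrete, here is a pure-Python sketch of what happens to a list of token IDs. It is an illustration only, not the tokenizer's actual implementation; `pad_and_mask` is a hypothetical helper, and the special-token IDs (101 for `[CLS]`, 102 for `[SEP]`, 0 for `[PAD]`) are the ones visible in the token ID output of this notebook.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # bert-base-uncased special-token IDs

def pad_and_mask(token_ids, max_length):
    """Wrap token IDs in [CLS]/[SEP], then truncate or pad to max_length.

    Assumes max_length >= 2 so that both special tokens fit.
    """
    # Truncate the content first so the special tokens always survive.
    ids = [CLS_ID] + list(token_ids)[: max_length - 2] + [SEP_ID]
    mask = [1] * len(ids)            # 1 marks real tokens
    n_pad = max_length - len(ids)
    return ids + [PAD_ID] * n_pad, mask + [0] * n_pad  # 0 marks padding

ids, mask = pad_and_mask([8948, 2024, 3154], max_length=8)
print(ids)   # [101, 8948, 2024, 3154, 102, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

The attention mask lets the model ignore the padded positions, which is why every sequence can share one fixed length.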

In [17]:
major_input_ids, major_attention_masks, major_labels_tensors = tokenize_and_encode_data(major_sentences, tokenizer, MAX_SENT_LEN, major_labels)
base_input_ids, base_attention_masks, base_labels_tensors = tokenize_and_encode_data(base_sentences, tokenizer, MAX_SENT_LEN, base_labels) 

Original:   Lungs are clear without focal consolidation, effusion or pneumothorax. Normal heart size. Bony thorax and soft tissues unremarkable
Token IDs: tensor([  101,  8948,  2024,  3154,  2302, 15918, 17439,  1010,  1041,  4246,
        14499,  2030,  1052,  2638,  2819, 29288,  2527,  2595,  1012,  3671,
         2540,  2946,  1012, 22678, 15321,  8528,  1998,  3730, 14095,  4895,
        28578, 17007,  3085,   102,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,  

# Torch dataset and dataloader

In [18]:
from torch.utils.data import TensorDataset

major_dataset = TensorDataset(major_input_ids, major_attention_masks, major_labels_tensors)
base_dataset = TensorDataset(base_input_ids, base_attention_masks, base_labels_tensors)

Create an iterator for the dataset using the torch DataLoader class. 

In [19]:
from torch.utils.data import DataLoader, SequentialSampler

# The DataLoader needs to know the batch size. This notebook only evaluates,
# so this is an inference batch size rather than a training one.
batch_size = 32



# Dataloader for the major label based dataset
major_dataloader = DataLoader(
            major_dataset,  
            sampler = SequentialSampler(major_dataset), # Iterate in dataset order so positions map to row indices
            batch_size = batch_size # Evaluates with this batch size.
            )

# Dataloader for the impression label based dataset
base_dataloader = DataLoader(
            base_dataset,  
            sampler = SequentialSampler(base_dataset), # Iterate in dataset order so positions map to row indices
            batch_size = batch_size # Evaluates with this batch size.
            )



# Fine Tune BERT for Classification

## Imports

In [20]:
import os
import random
import sys
import numpy as np
import transformers
from transformers import (
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
    get_linear_schedule_with_warmup,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process

from tqdm import tqdm, trange

## Model Parameters

In [21]:
model_name = "bert-base-uncased"
tokenizer_name = "bert-base-uncased"
max_seq_length = max_len

config_name = None

n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0

my_seed = 42

eval_batch_size = 32

## Evaluation function

This function is modified from the original eval function used in the text_only notebook. Instead of returning the eval results, it returns a dictionary of misclassified examples: each key is the index of an example in the dataset whose prediction did not match the target, and each value is the tuple (prediction, target) for that example.
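The core of that return value can be sketched without any model. The helper below is hypothetical (`collect_mismatches` is not part of the notebook) and uses made-up predictions; it only shows how positions in the prediction list become dictionary keys.

```python
def collect_mismatches(preds, targets):
    """Map each misclassified position to its (prediction, target) pair."""
    return {
        i: (pred, target)
        for i, (pred, target) in enumerate(zip(preds, targets))
        if pred != target
    }

# Toy predictions and targets, for illustration only.
preds = [0, 1, 1, 0]
targets = [0, 0, 1, 1]
print(collect_mismatches(preds, targets))  # {1: (1, 0), 3: (0, 1)}
```

The keys are only meaningful as dataset indices if predictions are collected in dataset order, which is why the dataloaders use a sequential sampler.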

In [22]:
def evaluate(model, tokenizer, dataloader):
   
    eval_loss = 0.0
    nb_eval_steps = 0
    preds = []
    out_label_ids = []
    for batch in tqdm(dataloader, desc="Batch Evaluating"):
        model.eval()
        #batch = tuple(t.to(device) for t in batch)

        with torch.no_grad():
            batch = tuple(t.to(device) for t in batch)
            b_input_ids = batch[0]
            b_input_mask = batch[1]
            b_labels = batch[2]
            
            result = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask,
                            labels=b_labels,
                            return_dict=True)

            logits = result.logits  # return_dict=True exposes outputs as named fields (loss, logits)
            tmp_eval_loss = result.loss

            eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1
        # Move logits and labels to CPU
        pred = torch.nn.functional.softmax(logits, dim=1).argmax(dim=1).cpu().detach().numpy()
        
        out_label_id = b_labels.detach().cpu().numpy()
        preds.append(pred)
        out_label_ids.append(out_label_id)
        
    eval_loss = eval_loss / nb_eval_steps

    
    preds = [l for sl in preds for l in sl]
    out_label_ids = [l for sl in out_label_ids for l in sl]
    
    idxs = {}
    for i, (pred, label) in enumerate(zip(preds, out_label_ids)):    
      if pred != label:
        idxs.update({i: (pred, label)})

    return idxs

## Set random seed

In [23]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

set_seed(my_seed)

# Evaluation on Test set

## Basic model configurations

+ path to pretrained model weights and dataloaders



In [24]:
MODELS_DIR = os.path.join(DATA_DIR, 'models')
TEXT_ONLY_DIR = os.path.join(MODELS_DIR, 'text_only')

models = ["text_only_findings.bin", "text_only_major_findings.bin"]
dataloaders = [base_dataloader, major_dataloader]

transformer_config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name if tokenizer_name else model_name,
        do_lower_case=True,
        cache_dir=None,
    )
transformer_model = AutoModelForSequenceClassification.from_pretrained(model_name, config=transformer_config)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

## Run base model

Loads the saved model weights and evaluates the model on the test set.

In [25]:
base_checkpoint = os.path.join(TEXT_ONLY_DIR, models[0])
transformer_model.load_state_dict(torch.load(base_checkpoint, map_location=device))  # map_location also allows loading on CPU-only machines
transformer_model.to(device)
base_result = evaluate(transformer_model, tokenizer, dataloaders[0])
base_idxs = list(base_result.keys())

Batch Evaluating: 100%|██████████| 18/18 [00:05<00:00,  3.36it/s]


## False predictions

- Key: index in dataset
- Value: tuple (prediction, target)
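Since each value is a (prediction, target) pair, the errors can be split by direction. A small sketch using only the first seven entries of `base_result`, and assuming the label convention used later in this notebook (0 = normal, 1 = abnormal):

```python
from collections import Counter

# Excerpt: the first seven entries of base_result, not the full dictionary.
errors = {5: (0, 1), 10: (0, 1), 15: (0, 1), 19: (0, 1), 22: (0, 1), 49: (1, 0), 54: (1, 0)}

counts = Counter(errors.values())
false_negatives = counts[(0, 1)]  # predicted normal, target abnormal
false_positives = counts[(1, 0)]  # predicted abnormal, target normal

print(false_negatives, false_positives)  # 5 2
```

Applying the same counting to the full dictionary gives a quick view of whether a model leans towards missing abnormal cases or towards flagging normal ones.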

In [26]:
base_result

{5: (0, 1),
 10: (0, 1),
 15: (0, 1),
 19: (0, 1),
 22: (0, 1),
 49: (1, 0),
 54: (1, 0),
 56: (1, 0),
 80: (1, 0),
 86: (0, 1),
 88: (0, 1),
 91: (0, 1),
 98: (0, 1),
 99: (0, 1),
 100: (1, 0),
 106: (0, 1),
 107: (0, 1),
 108: (0, 1),
 113: (0, 1),
 115: (0, 1),
 128: (1, 0),
 136: (1, 0),
 140: (0, 1),
 143: (0, 1),
 156: (0, 1),
 185: (0, 1),
 204: (1, 0),
 213: (1, 0),
 219: (0, 1),
 224: (0, 1),
 229: (0, 1),
 241: (0, 1),
 242: (1, 0),
 260: (1, 0),
 271: (0, 1),
 273: (1, 0),
 291: (0, 1),
 296: (0, 1),
 297: (0, 1),
 305: (1, 0),
 309: (1, 0),
 313: (0, 1),
 337: (1, 0),
 357: (0, 1),
 382: (1, 0),
 410: (0, 1),
 421: (0, 1),
 422: (1, 0),
 434: (0, 1),
 437: (0, 1),
 440: (0, 1),
 442: (1, 0),
 443: (0, 1),
 458: (1, 0),
 485: (1, 0),
 487: (1, 0),
 488: (0, 1),
 493: (1, 0),
 498: (0, 1),
 505: (0, 1),
 520: (1, 0),
 526: (0, 1),
 530: (1, 0),
 531: (0, 1),
 533: (0, 1),
 542: (0, 1),
 548: (0, 1),
 553: (1, 0),
 564: (0, 1),
 567: (0, 1)}

In [27]:
len(base_idxs)

70

## Visualize false predictions from the dataset

label = target label

In [28]:
base_false_preds = [baseline.iloc[idx] for idx in base_idxs]
base_false_preds = pd.DataFrame(base_false_preds)
base_false_preds

Unnamed: 0.1,Unnamed: 0,img,label,text
5,2672,CXR3205_IM-1513-1001.png,1,No pneumothorax. Heart size is normal. No large pleural effusions. No focal airspace opacities. No definite visualized rib fractures.
10,2029,CXR299_IM-1377-1001.png,1,Cardiac and mediastinal contours are within normal limits. The lungs are clear. Bony structures are intact. Small hiatal hernia.
15,2280,CXR1439_IM-0282-1001.png,1,Lungs are XXXX. XXXX opacities are present in the left lung base. Heart size normal. Mediastinum normal.
19,2380,CXR1911_IM-0593-1001.png,1,"Cardiac and mediastinal contours are unremarkable. Pulmonary vascularity is within normal limits. No focal air space opacities, pleural effusion, or pneumothorax. XXXX are grossly unremarkable."
22,2775,CXR3695_IM-1845-1001.png,1,"The cardiac silhouette size is at the upper limits of normal. Central vascular markings are mildly prominent. The lungs are normally inflated with no focal airspace disease, pleural effusion, or pneumothorax. No acute bony abnormality."
49,1628,CXR3355_IM-1609-1001.png,0,"Cardiomediastinal silhouettes are within normal limits. Lungs are hyperexpanded. Lungs are clear without focal consolidation, pneumothorax, or pleural effusion. Bony thorax is unremarkable."
54,1707,CXR3494_IM-1699-1001.png,0,There are low lung volumes with bronchovascular crowding. There is no focal consolidation. No visualized pneumothorax. Heart size is within normal limits. The cardiomediastinal contours is grossly normal in size and contour.
56,50,CXR102_IM-0016-1001.png,0,"Normal heart size. Clear, hyperaerated lungs. No pneumothorax. No pleural effusion. XXXX substernal density may be related to a pectus deformity."
80,330,CXR655_IM-2231-1001.png,0,Normal heart size and mediastinal contours. The lungs are hyperinflated but clear. No pneumothorax or pleural effusion. No acute bony abnormalities.
86,2493,CXR2390_IM-0944-1001.png,1,The heart is borderline in size. The mediastinum is stable with changes of XXXX sternotomy and bypass graft. Aorta is atherosclerotic. There are postsurgical changes of the left hemithorax with mild left-sided volume loss as evidenced by diaphragm elevation. Left post thoracotomy rib changes are noted. The right lung is clear. There is no pleural effusion.


## Run major model

Loads the saved model weights and evaluates the model on the test set.

In [29]:
major_checkpoint = os.path.join(TEXT_ONLY_DIR, models[1])
transformer_model.load_state_dict(torch.load(major_checkpoint, map_location=device))  # map_location also allows loading on CPU-only machines
transformer_model.to(device)
major_result = evaluate(transformer_model, tokenizer, dataloaders[1])
major_idxs = list(major_result.keys())

Batch Evaluating: 100%|██████████| 18/18 [00:05<00:00,  3.48it/s]


## False predictions

- Key: index in dataset
- Value: tuple (prediction, target)

In [30]:
major_result

{48: (0, 1),
 63: (1, 0),
 135: (1, 0),
 166: (1, 0),
 185: (0, 1),
 277: (0, 1),
 282: (0, 1),
 291: (1, 0),
 306: (1, 0),
 340: (1, 0),
 403: (0, 1),
 424: (1, 0),
 431: (1, 0),
 448: (1, 0),
 473: (0, 1),
 479: (0, 1),
 547: (1, 0)}

In [31]:
len(major_idxs)

17

## Visualize false predictions from the dataset

label = target label

In [32]:
major_false_preds = [major.iloc[idx] for idx in major_idxs]
major_false_preds = pd.DataFrame(major_false_preds)
major_false_preds

Unnamed: 0.1,Unnamed: 0,img,label,text
48,1909,CXR3898_IM-1978-1001.png,1,"The cardiomediastinal silhouette is normal in size and contour. No focal consolidation, pneumothorax or large pleural effusion. Normal XXXX."
63,1417,CXR2942_IM-1343-1001.png,0,"Postsurgical changes of XXXX sternotomy with screw fixation of anterior XXXX plates. Heart size and cardiomediastinal silhouette are normal. No focal consolidation, suspicious bony opacity, pneumothorax, or pleural effusion. No acute osseous abnormality."
135,107,CXR226_IM-0851-1001.png,0,The cardiac contours are normal. The lungs are clear. Thoracic spondylosis.
166,1135,CXR2339_IM-0905-1001.png,0,The heart size and pulmonary vascularity appear within normal limits. The lungs are free of focal airspace disease. No pleural effusion or pneumothorax is seen. Multiple XXXX-filled loops of bowel are present. Gastrostomy is noted.
185,2508,CXR2448_IM-0983-1001.png,1,Mediastinal contours are normal. No significant change in pneumothorax or right pleural fluid..
277,1843,CXR3751_IM-1875-1001.png,1,The lungs are clear without evidence of focal airspace disease. There is no evidence of pneumothorax or large pleural effusion. The cardiac and mediastinal contours are within normal limits. The XXXX are unremarkable.
282,683,CXR1392_IM-0251-1001.png,1,"The aortic XXXX, cardiac apex, and stomach are left-sided. Cardiomediastinal silhouette is within normal limits in overall size and appearance. Pulmonary vascular markings are symmetric and within normal limits. The lungs are normally inflated with no focal airspace disease, pleural effusion, or pneumothorax. No acute bony abnormality."
291,2520,CXR2487_IM-1014-1001.png,0,"Chest: The cardiomediastinal silhouette is within normal limits for size and contour. The lungs are normally inflated without evidence of focal airspace disease, pleural effusion, or pneumothorax. Thoracic spine: Mild dextro curvature the upper thoracic spine. Evaluation of the upper thoracic bodies is limited secondary to osseous overlap. Vertebral body XXXX and disc spaces are maintained. Mild degenerative endplate changes. Lumbar spine: There are 5 nonrib-bearing lumbar type vertebral bodies. Alignment is within normal limits. Vertebral body XXXX and disc spaces are maintained. Mild degenerative change without acute displaced fracture or dislocation. Moderate amount of stool.."
306,490,CXR965_IM-2455-1001.png,0,The lungs appear clear. Heart and pulmonary XXXX appear normal. Mediastinal contours are normal. Pleural spaces are clear. There appears to the contrast XXXX within small colonic diverticula in the splenic flexure region.
340,546,CXR1091_IM-0062-1001.png,0,The heart and lungs have XXXX XXXX in the interval. Both lungs are clear and expanded. Heart and mediastinum normal.


## MMBT

## Configure Imports

In [33]:
import json
import os
from collections import Counter
import numpy as np
from matplotlib.pyplot import imshow

import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset

In [34]:
from transformers import (
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

## Files and Paths

In [35]:
img_data_dir = os.path.join(DATA_DIR, "NLMCXR_png_frontal")
jsonl_data_dir = os.path.join(DATA_DIR, "json")
saved_chexnet = "saved_chexnet.pt"
major_file = "image_labels_major_findings_frontal_test.jsonl"
baseline_file = "image_labels_findings_frontal_test.jsonl"

## Define Image Encoder

In [36]:
# mapping number of image embeddings to AdaptiveAvgPool2d output size
POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}

class ImageEncoderDenseNet(nn.Module):
    def __init__(self, num_image_embeds, saved_model=True, path=os.path.join(MODELS_DIR, saved_chexnet)):
        super().__init__()
        if saved_model:
            print(f"Loading a saved model from {path}")
            model = torch.load(path)
        else:
            print("saved_model=False: loading an ImageNet-pretrained DenseNet121 model from torchvision.")
            model = torchvision.models.densenet121(pretrained=True)

        # DenseNet architecture last layer is the classifier; we only want everything before that
        modules = list(model.children())[:-1] 
        self.model = nn.Sequential(*modules)
        # self.model same as original DenseNet self.features part of the forward function
        self.pool = nn.AdaptiveAvgPool2d(POOLING_BREAKDOWN[num_image_embeds])

    def forward(self, input_modal):
        # Bx3x224x224 -> Bx1024x7x7 -> Bx1024xN -> BxNx1024
        features = self.model(input_modal)
        out = F.relu(features, inplace=True)
        out = self.pool(out)
        out = torch.flatten(out, start_dim=2)
        out = out.transpose(1, 2).contiguous()

        return out  # BxNx1024
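The shape comment in `forward` can be traced with plain tuples. The sketch below is a hypothetical helper (`trace_shapes` is not part of the notebook); it assumes 1024 feature channels, as stated in the shape comment, and only does the bookkeeping without running any model.

```python
# Same mapping as in the cell above: number of image embeddings -> pooled grid size.
POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}

def trace_shapes(batch_size, num_image_embeds, channels=1024):
    """Return the tensor shapes after pooling, flattening and transposing."""
    h, w = POOLING_BREAKDOWN[num_image_embeds]
    pooled = (batch_size, channels, h, w)       # after AdaptiveAvgPool2d
    flattened = (batch_size, channels, h * w)   # after flatten(start_dim=2)
    transposed = (batch_size, h * w, channels)  # after transpose(1, 2)
    return pooled, flattened, transposed

# Every grid has exactly as many cells as the requested number of embeddings.
grids_consistent = all(h * w == n for n, (h, w) in POOLING_BREAKDOWN.items())

print(grids_consistent)
print(trace_shapes(batch_size=32, num_image_embeds=3))
```

With `num_image_embeds=3`, each image therefore contributes three 1024-dimensional vectors, one per pooled grid cell.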

## Define custom JsonlDataset

In [37]:
class JsonlDataset(Dataset):
    def __init__(self, jsonl_data_path, img_dir, tokenizer, transforms, labels, max_seq_length):
        # Read one JSON object per line and close the file afterwards.
        with open(jsonl_data_path) as f:
            self.data = [json.loads(line) for line in f]
        self.img_data_dir = img_dir
        self.tokenizer = tokenizer
        self.labels = labels
        self.n_classes = len(labels)
        self.max_seq_length = max_seq_length

        # for image normalization for DenseNet
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sentence = torch.LongTensor(self.tokenizer.encode(self.data[index]["text"], add_special_tokens=True))
        start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1]
        sentence = sentence[:self.max_seq_length]
        label = torch.LongTensor([self.labels.index(self.data[index]["label"])])

        image = Image.open(os.path.join(self.img_data_dir, self.data[index]["img"])).convert("RGB")
        image = self.transforms(image)

        return {
            "image_start_token": start_token,
            "image_end_token": end_token,
            "sentence": sentence,
            "image": image,
            "label": label,
        }


def collate_fn(batch):
    """
    Specify batching for the torch Dataloader function

    :param batch: each batch of the JsonlDataset
    :return: text tensor, attention mask tensor, img tensor, modal start token, modal end token, label
    """
    lens = [len(row["sentence"]) for row in batch]
    bsz, max_seq_len = len(batch), max(lens)

    mask_tensor = torch.zeros(bsz, max_seq_len, dtype=torch.long)
    text_tensor = torch.zeros(bsz, max_seq_len, dtype=torch.long)

    for i_batch, (input_row, length) in enumerate(zip(batch, lens)):
        text_tensor[i_batch, :length] = input_row["sentence"]
        mask_tensor[i_batch, :length] = 1

    img_tensor = torch.stack([row["image"] for row in batch])
    tgt_tensor = torch.stack([row["label"] for row in batch])
    img_start_token = torch.stack([row["image_start_token"] for row in batch])
    img_end_token = torch.stack([row["image_end_token"] for row in batch])

    return text_tensor, mask_tensor, img_tensor, img_start_token, img_end_token, tgt_tensor
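Unlike the fixed-length padding used for the text-only model, `collate_fn` pads each batch only up to its longest sentence. A pure-Python sketch of that padding and masking step, with a hypothetical helper name and toy token IDs:

```python
def pad_batch(sequences, pad_value=0):
    """Pad every sequence to the longest one in the batch and build a mask."""
    max_len = max(len(seq) for seq in sequences)
    text = [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]
    mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
    return text, mask

# Toy batch of two tokenized sentences of different lengths.
text, mask = pad_batch([[7, 8, 9], [4]])
print(text)  # [[7, 8, 9], [4, 0, 0]]
print(mask)  # [[1, 1, 1], [1, 0, 0]]
```

Padding per batch keeps short batches small, at the cost of sequence lengths that vary from one batch to the next.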

## Define other helper functions

In [41]:
def get_labels():
    """
    0: normal
    1: abnormal

    :return: label classes
    """

    return [0, 1]

# ImageNet mean and std, which the pretrained DenseNet expects
def get_image_transforms():
    return transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225]
            )
        ]
    )

def load_examples(tokenizer, max_seq_len, num_image_embeds, test_file, evaluate=False, test=False, data_dir=jsonl_data_dir, img_dir=img_data_dir):
    # NOTE: `val_file` and `train_file` are not defined anywhere above, so only
    # the test branch (evaluate=True, test=True) can be used here.
    if evaluate and not test:
        path = os.path.join(data_dir, val_file)
    elif evaluate and test:
        path = os.path.join(data_dir, test_file)
    elif not evaluate and not test:
        path = os.path.join(data_dir, train_file)
    else:
        # test=True without evaluate=True is not a supported combination
        raise ValueError("invalid data file option!!")

    img_transforms = get_image_transforms()
    labels = get_labels()
    dataset = JsonlDataset(path, img_dir, tokenizer, img_transforms, labels, max_seq_len - num_image_embeds - 2)
    return dataset
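The last argument passed to `JsonlDataset` is the number of text tokens that remain once the image embeddings and two further positions are reserved. A sketch of that arithmetic follows; the helper name is hypothetical, and reading the `- 2` as the start and end tokens that `JsonlDataset` splits off the sentence is an assumption.

```python
def text_token_budget(max_seq_len, num_image_embeds, num_reserved=2):
    """Text tokens left after reserving image embeddings and special tokens."""
    return max_seq_len - num_image_embeds - num_reserved

# With max_seq_len=256 and num_image_embeds=3:
print(text_token_budget(256, 3))  # 251
```

Sentences longer than this budget are cut off in `__getitem__`, so raising `num_image_embeds` leaves slightly less room for text.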

## Create Datasets

In [42]:
img_transforms = get_image_transforms()
labels = get_labels()
num_labels = len(labels)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True, cache_dir=None)

major_dataset = load_examples(tokenizer, max_seq_len=256, num_image_embeds=3, test_file=major_file, evaluate=True, test=True)
base_dataset = load_examples(tokenizer, max_seq_len=256, num_image_embeds=3, test_file=baseline_file, evaluate=True, test=True)

## Modifying ModalEmbeddings, MMBTModel and MMBTForClassification classes


In [43]:
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput
from transformers.modeling_utils import ModuleUtilsMixin
from transformers.models.bert.modeling_bert import BertPreTrainedModel
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from transformers import MMBTConfig

## ModalEmbeddings Modification 

The ModalEmbeddings class needs **modal_hidden_size**, which is an attribute of MMBTConfig (and of MMBTConfig2). Originally this attribute was passed as part of the 'config' argument, but that argument now holds only the BertConfig attributes and not the MMBTConfig ones.

So the modified ModalEmbeddings2 keeps the BertConfig (self.config) and receives modal_hidden_size as a separate argument.

In [44]:
_CONFIG_FOR_DOC = "MMBTConfig"


class ModalEmbeddings2(nn.Module):
    """Generic Modal Embeddings which takes in an encoder, and a transformer embedding."""

    def __init__(self, config, modal_hidden_size, encoder, embeddings):
        super().__init__()
        self.config = config
        self.encoder = encoder
        self.proj_embeddings = nn.Linear(modal_hidden_size, config.hidden_size)
        self.position_embeddings = embeddings.position_embeddings
        self.token_type_embeddings = embeddings.token_type_embeddings
        self.word_embeddings = embeddings.word_embeddings
        self.LayerNorm = embeddings.LayerNorm
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)

    def forward(self, input_modal, start_token=None, end_token=None, position_ids=None, token_type_ids=None):
        token_embeddings = self.proj_embeddings(self.encoder(input_modal))
        seq_length = token_embeddings.size(1)

        if start_token is not None:
            start_token_embeds = self.word_embeddings(start_token)
            seq_length += 1
            token_embeddings = torch.cat([start_token_embeds.unsqueeze(1), token_embeddings], dim=1)

        if end_token is not None:
            end_token_embeds = self.word_embeddings(end_token)
            seq_length += 1
            token_embeddings = torch.cat([token_embeddings, end_token_embeds.unsqueeze(1)], dim=1)

        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_modal.device)
            position_ids = position_ids.unsqueeze(0).expand(input_modal.size(0), seq_length)

        if token_type_ids is None:
            token_type_ids = torch.zeros(
                (input_modal.size(0), seq_length), dtype=torch.long, device=input_modal.device
            )

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = token_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


MMBT_START_DOCSTRING = r"""
    MMBT model was proposed in `Supervised Multimodal Bitransformers for Classifying Images and Text
    <https://github.com/facebookresearch/mmbt>`__ by Douwe Kiela, Suvrat Bhooshan, Hamed Firooz, Davide Testuggine.
    It's a supervised multimodal bitransformer model that fuses information from text and other image encoders, and
    obtains state-of-the-art performance on various multimodal classification benchmark tasks.
    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)
    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
    general usage and behavior.
    Parameters:
        config (:class:`~transformers.MMBTConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration.
        transformer (:class: `~nn.Module`): A text transformer that is used by MMBT.
            It should have embeddings, encoder, and pooler attributes.
        encoder (:class: `~nn.Module`): Encoder for the second modality.
            It should take in a batch of modal inputs and return k, n dimension embeddings.
"""

MMBT_INPUTS_DOCSTRING = r"""
    Args:
        input_modal (``torch.FloatTensor`` of shape ``(batch_size, ***)``):
            The other modality data. It will be the shape that the encoder for that type expects. e.g. With an Image
            Encoder, the shape would be (batch_size, channels, height, width)
        input_ids (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``):
            Indices of input sequence tokens in the vocabulary. It does not expect [CLS] token to be added as it's
            appended to the end of other modality embeddings. Indices can be obtained using
            :class:`~transformers.BertTokenizer`. See :meth:`transformers.PreTrainedTokenizer.encode` and
            :meth:`transformers.PreTrainedTokenizer.__call__` for details.
            `What are input IDs? <../glossary.html#input-ids>`__
        modal_start_tokens (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
            Optional start token to be added to Other Modality Embedding. [CLS] Most commonly used for classification
            tasks.
        modal_end_tokens (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
            Optional end token to be added to Other Modality Embedding. [SEP] Most commonly used.
        attention_mask (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
            1]``:
            - 0 corresponds to a `sentence A` token,
            - 1 corresponds to a `sentence B` token.
            `What are token type IDs? <../glossary.html#token-type-ids>`_
        modal_token_type_ids (`optional`) ``torch.LongTensor`` of shape ``(batch_size, modal_sequence_length)``:
            Segment token indices to indicate different portions of the non-text modality. The embeddings from these
            tokens will be summed with the respective token embeddings for the non-text modality.
        position_ids (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
            config.max_position_embeddings - 1]``.
            `What are position IDs? <../glossary.html#position-ids>`__
        modal_position_ids (``torch.LongTensor`` of shape ``(batch_size, modal_sequence_length)``, `optional`):
            Indices of positions of each input sequence tokens in the position embeddings for the non-text modality.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
            `What are position IDs? <../glossary.html#position-ids>`__
        head_mask (``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        encoder_hidden_states (``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""

In [45]:
@add_start_docstrings(
    "The bare MMBT Model outputting raw hidden-states without any specific head on top.",
    MMBT_START_DOCSTRING,
)
class MMBTModel2(nn.Module, ModuleUtilsMixin):
    def __init__(self, config, mmbt_config):
        super().__init__()
        self.config = config
        self.transformer = mmbt_config.transformer
        self.modal_encoder = ModalEmbeddings2(config, 
                                             mmbt_config.modal_hidden_size, 
                                             mmbt_config.encoder, 
                                             mmbt_config.transformer.embeddings)

    @add_start_docstrings_to_model_forward(MMBT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_modal,
        input_ids=None,
        modal_start_tokens=None,
        modal_end_tokens=None,
        attention_mask=None,
        token_type_ids=None,
        modal_token_type_ids=None,
        position_ids=None,
        modal_position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Returns:
        Examples::
            # For example purposes. Not runnable.
            transformer = BertModel.from_pretrained('bert-base-uncased')
            encoder = ImageEncoder(args)
            mmbt = MMBTModel(config, transformer, encoder)
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_txt_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_txt_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        modal_embeddings = self.modal_encoder(
            input_modal,
            start_token=modal_start_tokens,
            end_token=modal_end_tokens,
            position_ids=modal_position_ids,
            token_type_ids=modal_token_type_ids,
        )

        input_modal_shape = modal_embeddings.size()[:-1]

        if token_type_ids is None:
            token_type_ids = torch.ones(input_txt_shape, dtype=torch.long, device=device)

        txt_embeddings = self.transformer.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        embedding_output = torch.cat([modal_embeddings, txt_embeddings], 1)

        input_shape = embedding_output.size()[:-1]

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        else:
            attention_mask = torch.cat(
                [torch.ones(input_modal_shape, device=device, dtype=torch.long), attention_mask], dim=1
            )
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(input_shape, device=device)
        else:
            encoder_attention_mask = torch.cat(
                [torch.ones(input_modal_shape, device=device), encoder_attention_mask], dim=1
            )

        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
        encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.transformer.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.transformer.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

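    def get_input_embeddings(self):
        # The word embeddings live on the wrapped text transformer; this class has no `embeddings` attribute.
        return self.transformer.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.transformer.embeddings.word_embeddings = value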
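
The mask handling in `forward` above prepends one always-visible position per modal embedding to the text attention mask, so the mask covers the concatenated sequence. A minimal sketch of that bookkeeping follows, using plain Python lists instead of tensors. `extend_attention_mask` is a hypothetical helper written for illustration; it is not part of this notebook or of the transformers library.

```python
def extend_attention_mask(text_mask, n_modal):
    """Prepend n_modal ones to every row of a (batch, seq_len) 0/1 mask."""
    return [[1] * n_modal + list(row) for row in text_mask]


# Toy batch: two texts padded to length 4 (1 = real token, 0 = padding).
text_mask = [[1, 1, 1, 0],
             [1, 1, 0, 0]]

# Assumption for the example: 3 image embeddings plus one start and one end
# token, which gives 5 modal positions in front of the text.
full_mask = extend_attention_mask(text_mask, n_modal=3 + 2)
print(full_mask)
```

Each row grows from 4 to 9 entries, and the padding zeros stay at the end because the modal positions are concatenated in front of the text.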

In [46]:
class MMBTConfig2(object):
    """
    This is the configuration class to store the configuration of a :class:`MMBTModel2`. It bundles the text
    transformer, the encoder for the second modality, and the sizes needed to instantiate the model.
    Args:
        transformer (:class:`~nn.Module`):
            Text transformer used by MMBT. It should have embeddings, encoder, and pooler attributes.
        encoder (:class:`~nn.Module`):
            Encoder for the second modality.
        num_labels (:obj:`int`, `optional`):
            Size of final Linear layer for classification.
        modal_hidden_size (:obj:`int`, `optional`, defaults to 1024):
            Embedding dimension of the non-text modality encoder.
    """

    def __init__(self, transformer, encoder, num_labels=None, modal_hidden_size=1024):
        self.transformer = transformer
        self.encoder = encoder
        self.modal_hidden_size = modal_hidden_size
        if num_labels:
            self.num_labels = num_labels

In [48]:


@add_start_docstrings(
    """
    MMBT Model with a sequence classification/regression head on top (a linear layer on top of the pooled output)
    """,
    MMBT_START_DOCSTRING,
    MMBT_INPUTS_DOCSTRING,
)
class MMBTForClassification2(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the sequence classification/regression loss. Indices should be in ``[0, ...,
            config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
    Returns: `Tuple` comprising various elements depending on the configuration (config) and inputs: **loss**:
    (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: Classification (or
    regression if config.num_labels==1) loss. **logits**: ``torch.FloatTensor`` of shape ``(batch_size,
    config.num_labels)`` Classification (or regression if config.num_labels==1) scores (before SoftMax).
    **hidden_states**: (`optional`, returned when ``output_hidden_states=True``) list of ``torch.FloatTensor`` (one for
    the output of each layer + the output of the embeddings) of shape ``(batch_size, sequence_length, hidden_size)``:
    Hidden-states of the model at the output of each layer plus the initial embedding outputs. **attentions**:
    (`optional`, returned when ``output_attentions=True``) list of ``torch.FloatTensor`` (one for each layer) of shape
    ``(batch_size, num_heads, sequence_length, sequence_length)``: Attentions weights after the attention softmax, used
    to compute the weighted average in the self-attention heads.
    Examples::
        # For example purposes. Not runnable.
        transformer = BertModel.from_pretrained('bert-base-uncased')
        encoder = ImageEncoder(args)
        model = MMBTForClassification(config, transformer, encoder)
        outputs = model(input_modal, input_ids, labels=labels)
        loss, logits = outputs[:2]
    """
    def __init__(self, config, mmbt_config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.mmbt = MMBTModel2(config, mmbt_config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_modal,
        input_ids=None,
        modal_start_tokens=None,
        modal_end_tokens=None,
        attention_mask=None,
        token_type_ids=None,
        modal_token_type_ids=None,
        position_ids=None,
        modal_position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.mmbt(
            input_modal=input_modal,
            input_ids=input_ids,
            modal_start_tokens=modal_start_tokens,
            modal_end_tokens=modal_end_tokens,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            modal_token_type_ids=modal_token_type_ids,
            position_ids=position_ids,
            modal_position_ids=modal_position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            return_dict=return_dict,
            output_attentions=True
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

## Model Hyperparameters

In [49]:
model_name = "bert-base-uncased"
tokenizer_name = "bert-base-uncased"
max_seq_length = 256
num_img_embeddings = 3
batch_size = 32

config_name = None

## Evaluation function

This function is modified from the original eval function used in the MMBT notebook. Instead of returning evaluation metrics, it returns a dictionary of the misclassified examples. Each key is the index of an example in the dataset whose prediction did not match the target, and each value is the tuple (prediction, target) for that example.

In [50]:
def mmbt_evaluate(model, tokenizer, dataloader):
    eval_loss = 0.0
    nb_eval_steps = 0
    preds = []
    out_label_ids = []
    model.eval()  # evaluation mode (disables dropout); set once rather than on every batch
    for batch in tqdm(dataloader, desc="Evaluating"):

        with torch.no_grad():
            batch = tuple(t.to(device) for t in batch)
            labels = batch[5]
            input_ids = batch[0]
            input_modal = batch[2]
            attention_mask = batch[1]
            modal_start_tokens = batch[3]
            modal_end_tokens = batch[4]
                   
            outputs = model(
                input_modal,
                input_ids=input_ids,
                modal_start_tokens=modal_start_tokens,
                modal_end_tokens=modal_end_tokens,
                attention_mask=attention_mask,
                token_type_ids=None,
                modal_token_type_ids=None,
                position_ids=None,
                modal_position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=labels,
                return_dict=True
            )
           
            logits = outputs.logits
            tmp_eval_loss = outputs.loss              
            eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1
        # Predicted class per example, moved to CPU together with the labels.
        # The argmax of the logits equals the argmax of their softmax.
        pred = logits.argmax(dim=1).detach().cpu().numpy()
        out_label_id = labels.detach().cpu().numpy()
        preds.append(pred)
        out_label_ids.append(out_label_id)

    eval_loss = eval_loss / nb_eval_steps

    preds = [l for sl in preds for l in sl]
    out_label_ids = [l for sl in out_label_ids for l in sl]

    # Map dataset index -> (prediction, target) for every misclassified example
    idxs = {}
    for i, (pred, label) in enumerate(zip(preds, out_label_ids)):
        if pred != label:
            idxs[i] = (pred, label)

    return idxs
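
The final step of `mmbt_evaluate` can be isolated as a small standalone sketch. `find_mismatches` is a hypothetical name, and the predictions and labels below are made-up toy values, not results from the models. In the notebook the labels are size-1 NumPy arrays, while plain integers are used here.

```python
def find_mismatches(preds, labels):
    """Return {dataset index: (prediction, target)} for every wrong prediction."""
    return {i: (p, t) for i, (p, t) in enumerate(zip(preds, labels)) if p != t}


# Toy data for illustration only.
preds = [0, 1, 1, 0, 1]
labels = [1, 1, 0, 0, 1]

errors = find_mismatches(preds, labels)
print(errors)  # {0: (0, 1), 2: (1, 0)}
```

The keys are only meaningful as dataset indices if the dataloader iterates in dataset order, which is why the dataloaders in this notebook use a sequential sampler.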

## Configure model and create dataloaders

In [53]:
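from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm

# Dataloader for the major dataset. SequentialSampler keeps dataset order,
# so positions in the predictions map back to rows of the dataset.
major_dataloader = DataLoader(
    major_dataset,
    sampler=SequentialSampler(major_dataset),
    batch_size=batch_size,
    collate_fn=collate_fn,
)

# Dataloader for the base dataset, also in dataset order.
base_dataloader = DataLoader(
    base_dataset,
    sampler=SequentialSampler(base_dataset),
    batch_size=batch_size,
    collate_fn=collate_fn,
)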

transformer_config = AutoConfig.from_pretrained(config_name if config_name else model_name)
tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name if tokenizer_name else model_name,
        do_lower_case=True,
        cache_dir=None,
    )
transformer = AutoModel.from_pretrained(model_name, config=transformer_config, cache_dir=None)
img_encoder = ImageEncoderDenseNet(num_image_embeds=num_img_embeddings)
multimodal_config = MMBTConfig2(transformer, img_encoder, num_labels=num_labels, modal_hidden_size=1024)
mmbt_model = MMBTForClassification2(transformer_config, multimodal_config)

mmbt_model.to(device)

MODELS_DIR = os.path.join(DATA_DIR, 'models')
MODELS_MMBT_DIR = os.path.join(MODELS_DIR, 'mmbt')

models = ["mmbt_findings.bin", "mmbt_major.bin"]
dataloaders = [base_dataloader, major_dataloader]

Loading a saved model from /content/drive/.shortcut-targets-by-id/1gwgx4ZApTKz5fN6SG9YkiVjVCZ0WNGeH/LAP_MMBT/data/models/saved_chexnet.pt


## Run base MMBT model

Loads the saved checkpoint weights for the base MMBT model and evaluates it on the base dataloader.

In [54]:
mmbt_base_checkpoint = os.path.join(MODELS_MMBT_DIR, models[0])
mmbt_model.load_state_dict(torch.load(mmbt_base_checkpoint, map_location=device))

mmbt_base_result = mmbt_evaluate(mmbt_model, tokenizer, dataloaders[0])
mmbt_base_idxs = list(mmbt_base_result.keys())

Evaluating: 100%|██████████| 18/18 [03:10<00:00, 10.59s/it]


## False predictions

- Key: index in dataset
- Value: tuple (prediction, target)

In [55]:
mmbt_base_result

{2: (0, array([1])),
 5: (0, array([1])),
 10: (0, array([1])),
 15: (0, array([1])),
 19: (0, array([1])),
 22: (0, array([1])),
 26: (1, array([0])),
 35: (0, array([1])),
 52: (1, array([0])),
 90: (0, array([1])),
 91: (0, array([1])),
 98: (0, array([1])),
 99: (0, array([1])),
 100: (1, array([0])),
 104: (0, array([1])),
 106: (0, array([1])),
 107: (0, array([1])),
 108: (0, array([1])),
 113: (0, array([1])),
 124: (0, array([1])),
 140: (0, array([1])),
 150: (1, array([0])),
 156: (0, array([1])),
 160: (0, array([1])),
 161: (1, array([0])),
 178: (0, array([1])),
 200: (0, array([1])),
 204: (1, array([0])),
 208: (0, array([1])),
 211: (0, array([1])),
 218: (0, array([1])),
 219: (0, array([1])),
 221: (1, array([0])),
 224: (0, array([1])),
 229: (0, array([1])),
 242: (1, array([0])),
 247: (0, array([1])),
 248: (1, array([0])),
 260: (1, array([0])),
 263: (0, array([1])),
 271: (0, array([1])),
 273: (1, array([0])),
 291: (0, array([1])),
 296: (0, array([1])),
 29

In [56]:
len(mmbt_base_idxs)

84

In [57]:
len(base_idxs)

70

## Investigate if there are cases where MMBT corrects our unimodal model.



In [58]:
mmbt_base_corrected = []
for idx in base_idxs:
  if idx not in mmbt_base_idxs:
    mmbt_base_corrected.append(idx)
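
The loop above keeps every index the text-only model got wrong and MMBT did not. The same comparison can be sketched as a set difference. `corrected_by` is a hypothetical helper and the index lists are toy values. Unlike the loop, this version returns the indices sorted rather than in the order of the first list.

```python
def corrected_by(text_only_errors, multimodal_errors):
    """Indices the text-only model got wrong but the multimodal model got right."""
    return sorted(set(text_only_errors) - set(multimodal_errors))


# Toy error indices for illustration only.
text_only_toy = [5, 49, 54, 107]
multimodal_toy = [5, 26, 107]

print(corrected_by(text_only_toy, multimodal_toy))  # [49, 54]
```

Set membership tests are constant time on average, so this avoids the repeated list scans done by `idx not in mmbt_base_idxs`.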

## Visualize those cases where MMBT corrected.

label = target label

In [59]:
# Select the corrected rows directly by position
df_mmbt_base_corrected = baseline.iloc[mmbt_base_corrected]
df_mmbt_base_corrected

Unnamed: 0.1,Unnamed: 0,img,label,text
49,1628,CXR3355_IM-1609-1001.png,0,"Cardiomediastinal silhouettes are within normal limits. Lungs are hyperexpanded. Lungs are clear without focal consolidation, pneumothorax, or pleural effusion. Bony thorax is unremarkable."
54,1707,CXR3494_IM-1699-1001.png,0,There are low lung volumes with bronchovascular crowding. There is no focal consolidation. No visualized pneumothorax. Heart size is within normal limits. The cardiomediastinal contours is grossly normal in size and contour.
56,50,CXR102_IM-0016-1001.png,0,"Normal heart size. Clear, hyperaerated lungs. No pneumothorax. No pleural effusion. XXXX substernal density may be related to a pectus deformity."
80,330,CXR655_IM-2231-1001.png,0,Normal heart size and mediastinal contours. The lungs are hyperinflated but clear. No pneumothorax or pleural effusion. No acute bony abnormalities.
86,2493,CXR2390_IM-0944-1001.png,1,The heart is borderline in size. The mediastinum is stable with changes of XXXX sternotomy and bypass graft. Aorta is atherosclerotic. There are postsurgical changes of the left hemithorax with mild left-sided volume loss as evidenced by diaphragm elevation. Left post thoracotomy rib changes are noted. The right lung is clear. There is no pleural effusion.
88,2473,CXR2314_IM-0889-1001.png,1,The lungs are XXXX. XXXX opacities are present in the right costophrenic XXXX. No focal infiltrates. Heart size normal.
115,2836,CXR3961_IM-2026-1001.png,1,The heart size is normal. The mediastinal contour is within normal limits. There is a streaky opacity within the right upper lobe. There are no nodules or masses. No visible pneumothorax. No visible pleural fluid. The XXXX are grossly normal. There is no visible free intraperitoneal air under the diaphragm.
128,1817,CXR3708_IM-1852-1001.png,0,"Lung volumes are decreased from XXXX, and there is resultant bronchovascular crowding. No evidence of focal airspace disease. No definite pleural effusion or pneumothorax. Cardiomediastinal silhouette is within normal limits given the low lung volumes. No free subdiaphragmatic air. Grossly stable mild degenerative changes of the right lower thoracic spine."
136,6,CXR14_IM-0256-1001.png,0,"Heart size within normal limits, stable mediastinal and hilar contours. Mild hyperinflation appears similar to prior. No focal alveolar consolidation, no definite pleural effusion seen. Scattered chronic appearing irregular interstitial markings, no typical findings of pulmonary edema."
143,2622,CXR3005_IM-1388-1001.png,1,"The heart is normal in size. The mediastinum is stable. Aorta is tortuous. Calcified lymph XXXX are again identified. There is mild prominence of the right paratracheal soft tissues, stable in appearance from prior studies. There is no acute infiltrate or pleural effusion. Osteopenia and degenerative changes are identified. XXXX deformity of T9 appears worse than prior study."


In [60]:
len(df_mmbt_base_corrected)

27

## Run major MMBT model

Loads the saved checkpoint weights for the major MMBT model and evaluates it on the major dataloader.

In [61]:
mmbt_major_checkpoint = os.path.join(MODELS_MMBT_DIR, models[1])
mmbt_model.load_state_dict(torch.load(mmbt_major_checkpoint, map_location=device))

mmbt_major_result = mmbt_evaluate(mmbt_model, tokenizer, dataloaders[1])
mmbt_major_idxs = list(mmbt_major_result.keys())

Evaluating: 100%|██████████| 18/18 [00:10<00:00,  1.79it/s]


## False predictions

- Key: index in dataset
- Value: tuple (prediction, target)

In [62]:
mmbt_major_result

{5: (1, array([0])),
 27: (0, array([1])),
 48: (0, array([1])),
 63: (1, array([0])),
 107: (1, array([0])),
 135: (1, array([0])),
 166: (1, array([0])),
 193: (0, array([1])),
 234: (0, array([1])),
 260: (1, array([0])),
 277: (0, array([1])),
 283: (0, array([1])),
 291: (1, array([0])),
 303: (1, array([0])),
 306: (1, array([0])),
 340: (1, array([0])),
 369: (1, array([0])),
 378: (0, array([1])),
 403: (0, array([1])),
 407: (0, array([1])),
 424: (1, array([0])),
 431: (1, array([0])),
 448: (1, array([0])),
 463: (1, array([0])),
 479: (0, array([1])),
 547: (1, array([0]))}
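
A dictionary of this shape can be summarized by error direction. The sketch below assumes a binary task in which label 1 is the positive class. That is an assumption made for illustration, since the notebook only shows that the labels are 0 and 1. `error_types` is a hypothetical helper and the dictionary is a toy example in the same (prediction, target) format.

```python
from collections import Counter


def error_types(false_preds, positive_label=1):
    """Count false positives and false negatives in {index: (pred, target)}."""
    counts = Counter()
    for pred, _target in false_preds.values():
        # Every entry is a misclassification, so in a binary task the
        # prediction alone determines the direction of the error.
        kind = "false_positive" if int(pred) == positive_label else "false_negative"
        counts[kind] += 1
    return counts


# Toy dictionary for illustration only.
toy_result = {5: (1, 0), 27: (0, 1), 63: (1, 0)}
toy_counts = error_types(toy_result)
print(toy_counts["false_positive"], toy_counts["false_negative"])  # 2 1
```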

In [63]:
len(mmbt_major_idxs)

26

In [64]:
len(major_idxs)

17

## Investigate if there are cases where MMBT corrects our unimodal model.

In [65]:
mmbt_major_corrected = []
for idx in major_idxs:
  if idx not in mmbt_major_idxs:
    mmbt_major_corrected.append(idx)

## Visualize those cases where MMBT corrected.

label = target label


In [66]:
# Select the corrected rows directly by position
df_mmbt_major_corrected = major.iloc[mmbt_major_corrected]
df_mmbt_major_corrected

Unnamed: 0.1,Unnamed: 0,img,label,text
185,2508,CXR2448_IM-0983-1001.png,1,Mediastinal contours are normal. No significant change in pneumothorax or right pleural fluid..
282,683,CXR1392_IM-0251-1001.png,1,"The aortic XXXX, cardiac apex, and stomach are left-sided. Cardiomediastinal silhouette is within normal limits in overall size and appearance. Pulmonary vascular markings are symmetric and within normal limits. The lungs are normally inflated with no focal airspace disease, pleural effusion, or pneumothorax. No acute bony abnormality."
473,1171,CXR2406_IM-0954-1001.png,1,Both lungs remain clear and expanded. Heart and pulmonary XXXX are normal. No change in the large hiatus hernia.


In [None]:
len(df_mmbt_major_corrected)

3

## Check if there are cases of false predictions that overlap between different models.

In [67]:
set(mmbt_base_idxs) & set(mmbt_major_idxs)

{5, 107, 260, 291, 378}

In [68]:
set(base_idxs) & set(major_idxs)

{185, 291}

In [69]:
set(base_idxs) & set(mmbt_base_idxs)

{5,
 10,
 15,
 19,
 22,
 91,
 98,
 99,
 100,
 106,
 107,
 108,
 113,
 140,
 156,
 204,
 219,
 224,
 229,
 242,
 260,
 271,
 273,
 291,
 296,
 297,
 313,
 410,
 421,
 434,
 437,
 443,
 458,
 485,
 487,
 488,
 498,
 520,
 530,
 531,
 533,
 542,
 564}

In [70]:
set(major_idxs) & set(mmbt_major_idxs)

{48, 63, 135, 166, 277, 291, 306, 340, 403, 424, 431, 448, 479, 547}
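
The pairwise intersections above can be computed in one pass and normalized so that models with different error counts are easier to compare. The sketch below reports, for each pair, the number of shared error indices and their Jaccard index (shared errors divided by the union of errors). The Jaccard index is an addition for illustration and is not used elsewhere in this notebook. `pairwise_overlap` and the model names are hypothetical, and the index lists are toy values.

```python
from itertools import combinations


def pairwise_overlap(error_sets):
    """Return {(name_a, name_b): (n_shared, jaccard)} for every pair of models."""
    out = {}
    for (name_a, a), (name_b, b) in combinations(error_sets.items(), 2):
        a, b = set(a), set(b)
        shared = a & b
        union = a | b
        jaccard = len(shared) / len(union) if union else 0.0
        out[(name_a, name_b)] = (len(shared), jaccard)
    return out


# Toy error indices for illustration only.
toy_errors = {
    "model_a": [1, 2, 3, 4],
    "model_b": [3, 4, 5],
    "model_c": [9],
}

overlap = pairwise_overlap(toy_errors)
for pair, (n_shared, jaccard) in overlap.items():
    print(pair, n_shared, round(jaccard, 2))
```

Note that indices are only comparable between models evaluated on the same dataset in the same order.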

In [71]:
drive.flush_and_unmount()
print('All changes made in this colab session should now be visible in Drive.')

All changes made in this colab session should now be visible in Drive.
