<a href="https://colab.research.google.com/github/AxelAllen/Pre-trained-Multimodal-Text-Image-Classifier-in-a-Sparse-Data-Application/blob/master/run_mmbt_masked_text_eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Evaluating MMBT on Masked Text Test Partition

This notebook reuses our end-to-end pipeline for fine-tuning a pre-trained MMBT model for multimodal (text and image) classification on our dataset; here it is used only to evaluate the fine-tuned models.

Parts of this pipeline are adapted from the
Huggingface `run_mmimdb.py` script, which runs the MMBT model. This code can
be accessed [here](https://github.com/huggingface/transformers/blob/8ea412a86faa8e9edeeb6b5c46b08def06aa03ea/examples/research_projects/mm-imdb/run_mmimdb.py#L305).

The code is slightly modified from the `run_mmbt.ipynb` notebook to evaluate trained MMBT models on test data with all text inputs masked, i.e. to test how well a multimodally fine-tuned MMBT performs at image classification.

## Skip unless on Google Colab

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%cd /content/drive/MyDrive/LAP_MMBT
%pwd

/content/drive/MyDrive/LAP_MMBT


'/content/drive/MyDrive/LAP_MMBT'

Before running the cell below, make sure to select the 'GPU' runtime type (Runtime > Change runtime type).

In [3]:
import torch

# If there's a GPU available...
if torch.cuda.is_available():    

    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())

    print('We will use the GPU:', torch.cuda.get_device_name(0))

# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) available.
We will use the GPU: Tesla T4


## Install Huggingface Library

These should have been installed during your environment set-up; you only need to run these cells in Google Colab.

In [4]:
!pip install transformers

Collecting transformers
  Downloading https://files.pythonhosted.org/packages/81/91/61d69d58a1af1bd81d9ca9d62c90a6de3ab80d77f27c5df65d9a2c1f5626/transformers-4.5.0-py3-none-any.whl (2.1MB)
     |████████████████████████████████| 2.2MB 7.9MB/s
Collecting tokenizers<0.11,>=0.10.1
  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d2672c02ff3c3fb4c688ca46467a/tokenizers-0.10.2-cp37-cp37m-manylinux2010_x86_64.whl (3.3MB)
     |████████████████████████████████| 3.3MB 47.9MB/s
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/08/cd/342e584ee544d044fb573ae697404ce22ede086c9e87ce5960772084cad0/sacremoses-0.0.44.tar.gz (862kB)
     |████████████████████████████████| 870kB 42.6MB/s
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... done
  Created wheel for sacremoses: filename=sacremoses-0.0.44-cp37-none-any.whl size=886084 sha256=a80b2f231ea

# Data directories and file paths

Options for the data file paths are provided in the following cell. Uncomment the train/val/test partitions that match the desired labeling scheme:

- filenames with 'major' are labeled with the 'major' metadata column text
- filenames without 'major' are labeled with the 'impression' metadata column text
- filenames with 'multi' are labeled for multiclass classification
- filenames without 'multi' are labeled for binary classification
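The naming rules above can be expressed as a small helper. This is a hypothetical sketch (`describe_partition` is not part of the project code) that applies the bullet rules literally. The rules do not single out the plain `findings` files listed in the next cell, so for those it only reports what the rules imply.

```python
import os


def describe_partition(filename: str) -> dict:
    """Infer the labeling scheme of a partition file from its name (hypothetical helper).

    Applies the naming rules literally:
    'major' -> labels from the 'major' column, otherwise from 'impression';
    'multi' -> multiclass labels, otherwise binary labels.
    """
    stem = os.path.splitext(os.path.basename(filename))[0]
    label_source = "major" if "major" in stem else "impression"
    task = "multiclass" if "multi" in stem else "binary"
    split = stem.rsplit("_", 1)[-1]  # trailing train/val/test tag
    return {"label_source": label_source, "task": task, "split": split}


print(describe_partition("image_multi_labels_major_findings_frontal_train.jsonl"))
# -> {'label_source': 'major', 'task': 'multiclass', 'split': 'train'}
```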


In [5]:
#train_file = "image_labels_impression_frontal_train.jsonl"
#val_file = "image_labels_impression_frontal_val.jsonl"
#test_file = "image_labels_impression_frontal_test.jsonl"

train_file = "image_multi_labels_major_findings_frontal_train.jsonl"
val_file = "image_multi_labels_major_findings_frontal_val.jsonl"
test_file = "image_multi_labels_major_findings_frontal_test.jsonl"


#train_file = "image_labels_major_findings_frontal_train.jsonl"
#val_file = "image_labels_major_findings_frontal_val.jsonl"
#test_file = "image_labels_major_findings_frontal_test.jsonl"


#train_file = "image_labels_findings_frontal_train.jsonl"
#val_file = "image_labels_findings_frontal_val.jsonl"
#test_file = "image_labels_findings_frontal_test.jsonl"
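Only one set of partitions should be active, so a quick existence check can catch a missing file or a wrong uncomment before the long-running cells. This is a minimal sketch that assumes the `.jsonl` files sit directly under `data/json`, the `--data_dir` default set further down. `find_missing_partitions` is a hypothetical helper.

```python
import os
from typing import List


def find_missing_partitions(data_dir: str, filenames: List[str]) -> List[str]:
    """Return the partition files that do not exist under data_dir (hypothetical helper)."""
    return [name for name in filenames if not os.path.isfile(os.path.join(data_dir, name))]


# Usage in the notebook, after the cell above has set train_file/val_file/test_file:
# missing = find_missing_partitions("data/json", [train_file, val_file, test_file])
# assert not missing, f"Missing partition files: {missing}"
```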

## Import Required Modules

In [6]:
from textBert_utils import set_seed
from MMBT.image import ImageEncoderDenseNet
from MMBT.mmbt_config import MMBTConfig
from MMBT.mmbt import MMBTForClassification

In [7]:
from MMBT.mmbt_utils import (
    JsonlDataset,
    get_image_transforms,
    get_labels,
    load_examples,
    collate_fn,
    get_multiclass_labels,
    collate_fn_mask_all_text,
    get_multiclass_criterion,
)

In [8]:
import argparse

In [9]:
import glob
import logging
import random
import json
import os
from collections import Counter
import numpy as np
from matplotlib.pyplot import imshow

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

In [10]:
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm, trange

from transformers import (
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter


# Set-up Experiment Hyperparameters and Arguments

The training, validation, and test files for the experiment are selected in the data file cell above. In this notebook, the multiclass 'major' findings partitions are active.

To re-make the training, validation, and test data, please refer to the information in the **data/** directory.  

To change a hyperparameter, edit the `default` value of the corresponding `parser.add_argument` call in the following cell; otherwise the defaults are used.
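Instead of editing the defaults, values can also be overridden by passing flag/value pairs to `parse_args`. The cell below calls it with an empty string, which is equivalent to passing no arguments. The following minimal sketch uses a cut-down copy of three of the arguments (`demo_parser` is illustrative only). It also shows a caveat of `type=bool`: argparse calls `bool()` on the raw string, and any non-empty string, including `"False"`, is truthy.

```python
import argparse

# Cut-down parser mirroring three of the arguments defined in the next cell (illustrative only)
demo_parser = argparse.ArgumentParser()
demo_parser.add_argument("--eval_batch_size", default=16, type=int)
demo_parser.add_argument("--output_dir", default="mmbt_output_findings_multi_major", type=str)
demo_parser.add_argument("--do_eval", default=True, type=bool)

# An empty argument list keeps every default, as parse_args("") does in the notebook
defaults = demo_parser.parse_args([])

# Override values without touching the defaults
custom = demo_parser.parse_args(["--eval_batch_size", "32", "--output_dir", "mmbt_output_run2"])

# Caveat: bool("False") is True, so this flag cannot be switched off this way
pitfall = demo_parser.parse_args(["--do_eval", "False"])

print(defaults.eval_batch_size, custom.eval_batch_size, custom.output_dir, pitfall.do_eval)
# -> 16 32 mmbt_output_run2 True
```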

For multiple experiment runs, please make sure to change the `output_dir` argument so that new results don't overwrite existing ones.

The arguments specified here are the same as in the `run_mmimdb.py` file 
in the [Huggingface example implementation of MMBT.](https://github.com/huggingface/transformers/blob/8ea412a86faa8e9edeeb6b5c46b08def06aa03ea/examples/research_projects/mm-imdb/run_mmimdb.py#L305)

In [12]:
parser = argparse.ArgumentParser(description="Project hyperparameters and other configurations")

# Required parameters
parser.add_argument(
    "--data_dir",
    default="data/json",
    type=str,
    help="The input data dir. Should contain the .jsonl files.",
)
parser.add_argument(
    "--model_name",
    default="bert-base-uncased",
    type=str,
    help="model identifier from huggingface.co/models",
)
parser.add_argument(
    "--output_dir",
    default="mmbt_output_findings_multi_major",
    type=str,
    help="The output directory where the model predictions and checkpoints will be written.",
)

    
parser.add_argument(
    "--config_name", default="bert-base-uncased", type=str, help="Pretrained config name if not the same as model_name"
)
parser.add_argument(
    "--tokenizer_name",
    default="bert-base-uncased",
    type=str,
    help="Pretrained tokenizer name or path if not the same as model_name",
)

parser.add_argument("--train_batch_size", default=16, type=int, help="Batch size for training.")
parser.add_argument(
    "--eval_batch_size", default=16, type=int, help="Batch size for evaluation."
)
parser.add_argument(
    "--max_seq_length",
    default=300,
    type=int,
    help="The maximum total input sequence length after tokenization. Sequences longer "
    "than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
    "--num_image_embeds", default=3, type=int, help="Number of Image Embeddings from the Image Encoder"
)
parser.add_argument("--do_train", default=True, type=bool, help="Whether to run training.")
parser.add_argument("--do_eval", default=True, type=bool, help="Whether to run eval on the dev set.")
parser.add_argument(
    "--evaluate_during_training", default=True, type=bool, help="Run evaluation during training at each logging step."
)


parser.add_argument(
    "--gradient_accumulation_steps",
    type=int,
    default=1,
    help="Number of update steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.1, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
    "--num_train_epochs", default=6.0, type=float, help="Total number of training epochs to perform."
)
parser.add_argument("--patience", default=5, type=int, help="Patience for Early Stopping.")
parser.add_argument(
    "--max_steps",
    default=-1,
    type=int,
    help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

parser.add_argument("--logging_steps", type=int, default=25, help="Log every X update steps.")
parser.add_argument("--save_steps", type=int, default=25, help="Save checkpoint every X update steps.")
parser.add_argument(
    "--eval_all_checkpoints",
    default=True, type=bool,
    help="Evaluate all checkpoints starting with the same prefix as model_name and ending with a step number",
)

parser.add_argument("--num_workers", type=int, default=8, help="number of worker threads for dataloading")

parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")


args = parser.parse_args("")

# Set up CUDA and GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0
args.device = device

# for multiclass labeling
args.multiclass = True

In [13]:
# Setup Train/Val/Test filenames
args.train_file = train_file
args.val_file = val_file
args.test_file = test_file

## Showing a sample from JsonlDataset

Indexing the dataset (i.e. calling `__getitem__`) returns one example as a dictionary.

Note:
- `image_start_token` is the BERT token id for [CLS].
- `image_end_token` is the BERT token id for [SEP].
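In the `JsonlDataset` of the Huggingface `run_mmimdb.py` script, which this notebook adapts, the text is encoded with special tokens. Its first and last ids ([CLS] and [SEP]) are split off to serve as the image start and end tokens, and the ids in between are kept and truncated to a maximum length. The pure-Python sketch below shows that split. `split_special_tokens` is hypothetical, 101/102 are the [CLS]/[SEP] ids of the `bert-base-uncased` vocabulary, and the project's own `MMBT/mmbt_utils.py` may differ in detail.

```python
from typing import List, Tuple

CLS_ID = 101  # [CLS] in the bert-base-uncased vocabulary
SEP_ID = 102  # [SEP] in the bert-base-uncased vocabulary


def split_special_tokens(encoded: List[int], max_len: int) -> Tuple[int, List[int], int]:
    """Split an encoded sequence into (image_start_token, sentence, image_end_token).

    Hypothetical helper mirroring run_mmimdb.py: the leading [CLS] and trailing [SEP]
    are reused as the image start/end tokens, and the remaining ids are truncated to max_len.
    """
    start_token, sentence, end_token = encoded[0], encoded[1:-1], encoded[-1]
    return start_token, sentence[:max_len], end_token


# [CLS], three arbitrary word-piece ids, [SEP]
start, sentence, end = split_special_tokens([CLS_ID, 2023, 2003, 2307, SEP_ID], max_len=2)
print(start, sentence, end)  # -> 101 [2023, 2003] 102
```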


In [14]:
tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name,
        do_lower_case=True,
        cache_dir=None,
    )
train_dataset = load_examples(tokenizer, args)


In [15]:
train_dataset[0]

{'image': tensor([[[-0.7650, -0.7479, -0.7308,  ..., -0.3541, -0.3369, -0.3198],
          [-0.7137, -0.7137, -0.6794,  ..., -0.2171, -0.1828, -0.1999],
          [-0.6109, -0.6109, -0.6109,  ..., -0.1143, -0.0801, -0.0801],
          ...,
          [ 1.8722,  1.9064,  1.9064,  ...,  1.6324,  1.6667,  1.7523],
          [ 1.8893,  1.9064,  1.9407,  ...,  1.6153,  1.6838,  1.7523],
          [ 1.8722,  1.9064,  1.9407,  ...,  1.6324,  1.7180,  1.7694]],
 
         [[-0.6527, -0.6352, -0.6176,  ..., -0.2325, -0.2150, -0.1975],
          [-0.6001, -0.6001, -0.5651,  ..., -0.0924, -0.0574, -0.0749],
          [-0.4951, -0.4951, -0.4951,  ...,  0.0126,  0.0476,  0.0476],
          ...,
          [ 2.0434,  2.0784,  2.0784,  ...,  1.7983,  1.8333,  1.9209],
          [ 2.0609,  2.0784,  2.1134,  ...,  1.7808,  1.8508,  1.9209],
          [ 2.0434,  2.0784,  2.1134,  ...,  1.7983,  1.8859,  1.9384]],
 
         [[-0.4275, -0.4101, -0.3927,  ..., -0.0092,  0.0082,  0.0256],
          [-0.3753,


### Evaluation Function

The `evaluate` function below masks all text input during testing: its DataLoader uses `collate_fn_mask_all_text` as the collate function.
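The implementation of `collate_fn_mask_all_text` lives in `MMBT/mmbt_utils.py` and is not shown here. One plausible way to mask all text is to replace every non-padding text token id with BERT's [MASK] id (103 in `bert-base-uncased`), leaving padding and the attention mask untouched. The sketch below illustrates that approach; it is an assumption, not the project's method, and `mask_all_text` is a hypothetical name. Zeroing the text part of the attention mask would be another option.

```python
import torch

MASK_ID = 103  # [MASK] in the bert-base-uncased vocabulary (assumption for this sketch)


def mask_all_text(input_ids: torch.Tensor, attention_mask: torch.Tensor, mask_id: int = MASK_ID) -> torch.Tensor:
    """Return a copy of input_ids in which every non-padding position holds mask_id."""
    is_padding = attention_mask == 0
    return torch.where(is_padding, input_ids, torch.full_like(input_ids, mask_id))


# Two padded text sequences (arbitrary ids, 0 = [PAD]) and their attention masks
ids = torch.tensor([[2023, 2003, 2307], [7592, 0, 0]])
mask = torch.tensor([[1, 1, 1], [1, 0, 0]])
print(mask_all_text(ids, mask))
```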

In [16]:
def evaluate(args, model, tokenizer, evaluate=True, test=False, prefix=""):

    if test:
        # start a separate tensorboard run to track the masked-text test results
        comment = f"masked_text_test_{args.output_dir}_{args.eval_batch_size}"
        tb_writer = SummaryWriter(comment=comment)

    eval_output_dir = args.output_dir
    eval_dataset = load_examples(tokenizer, args, evaluate=evaluate, test=test)

    if not os.path.exists(eval_output_dir):
        os.makedirs(eval_output_dir)

    # mask all text input: collate_fn_mask_all_text masks the text part of every batch
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate_fn_mask_all_text
    )

    # the multiclass criterion depends only on the dataset, so build it once instead of per batch
    if args.multiclass:
        criterion = get_multiclass_criterion(eval_dataset)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    preds = []
    out_label_ids = []
    model.eval()
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)
        input_ids = batch[0]
        attention_mask = batch[1]
        input_modal = batch[2]
        modal_start_tokens = batch[3]
        modal_end_tokens = batch[4]
        labels = batch[5]

        with torch.no_grad():
            outputs = model(
                input_modal,
                input_ids=input_ids,
                modal_start_tokens=modal_start_tokens,
                modal_end_tokens=modal_end_tokens,
                attention_mask=attention_mask,
                token_type_ids=None,
                modal_token_type_ids=None,
                position_ids=None,
                modal_position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                # multiclass: loss is computed below with the dataset criterion;
                # binary: the model computes the loss from the labels
                labels=None if args.multiclass else labels,
                return_dict=True,
            )
            logits = outputs.logits
            if args.multiclass:
                tmp_eval_loss = criterion(logits, labels)
            else:
                tmp_eval_loss = outputs.loss
            eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1
        # Move predictions and labels to CPU
        if args.multiclass:
            pred = torch.sigmoid(logits).detach().cpu().numpy() > 0.5
        else:
            pred = torch.nn.functional.softmax(logits, dim=1).argmax(dim=1).detach().cpu().numpy()
        out_label_id = labels.detach().cpu().numpy()
        preds.append(pred)
        out_label_ids.append(out_label_id)

    eval_loss = eval_loss / nb_eval_steps

    result = {"loss": eval_loss}

    if args.multiclass:
        tgts = np.vstack(out_label_ids)
        preds = np.vstack(preds)
        result["macro_f1"] = f1_score(tgts, preds, average="macro")
        result["micro_f1"] = f1_score(tgts, preds, average="micro")
    else:
        preds = [l for sl in preds for l in sl]
        out_label_ids = [l for sl in out_label_ids for l in sl]
        result["accuracy"] = accuracy_score(out_label_ids, preds)

    output_eval_file = os.path.join(eval_output_dir, prefix, "masked_text_eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))
            if test:
                tb_writer.add_scalar(f'eval_{key}', result[key], nb_eval_steps)

    if test:
        tb_writer.close()

    return result
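For the multiclass setting, `evaluate` thresholds the sigmoid of each logit at 0.5 and reports macro and micro F1 over the stacked label matrix. The toy example below uses made-up numbers to show how the two averages differ. Macro F1 averages the per-label F1 scores, while micro F1 first pools true positives, false positives, and false negatives across all labels.

```python
import numpy as np
from sklearn.metrics import f1_score

# Made-up multi-label targets and logits: 4 examples, 3 labels
toy_targets = np.array([[1, 0, 0], [0, 1, 0], [1, 0, 1], [0, 0, 1]])
toy_logits = np.array([[2.0, -1.0, -3.0], [0.5, 1.5, -2.0], [1.0, -0.5, -0.2], [-1.0, -2.0, 0.3]])

# Same decision rule as evaluate(): sigmoid(logit) > 0.5, which is equivalent to logit > 0
toy_probs = 1.0 / (1.0 + np.exp(-toy_logits))
toy_preds = (toy_probs > 0.5).astype(int)

macro = f1_score(toy_targets, toy_preds, average="macro")  # per-label F1 (0.8, 1.0, 0.667), averaged
micro = f1_score(toy_targets, toy_preds, average="micro")  # pooled counts: TP=4, FP=1, FN=1
print(round(macro, 4), round(micro, 4))  # -> 0.8222 0.8
```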


## Setting up the MMBT Model 

Set up logging and the MMBT model. As with the text-only model, checkpoints were saved during training at a customizable interval (`--save_steps`); the cells below load and evaluate them.



In [17]:
# Setup logging
logger = logging.getLogger(__name__)
if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
                    datefmt="%m/%d/%Y %H:%M:%S",
                    filename=os.path.join(args.output_dir, f"masked_text_{os.path.splitext(args.test_file)[0]}_logging.txt"),
                    level=logging.INFO)
logger.warning("device: %s, n_gpu: %s",
        args.device,
        args.n_gpu
)

# Set seed
set_seed(args)

In [18]:
# Setup model
if args.multiclass:
    labels = get_multiclass_labels()
    num_labels = len(labels)
else:
    labels = get_labels()
    num_labels = len(labels)
transformer_config = AutoConfig.from_pretrained(args.config_name if args.config_name else args.model_name, num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
transformer = AutoModel.from_pretrained(args.model_name, config=transformer_config, cache_dir=None)
img_encoder = ImageEncoderDenseNet(num_image_embeds=args.num_image_embeds)
multimodal_config = MMBTConfig(transformer, img_encoder, num_labels=num_labels, modal_hidden_size=1024)

logger.info(f"Evaluation parameters: {args}")


## Evaluating saved model checkpoints on the Test Set

During this evaluation, all text inputs of the test set are masked (via `collate_fn_mask_all_text` in `evaluate`), so the model has to classify from the image embedding tokens alone.
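The evaluation cell below finds checkpoints with `glob` and `recursive=False`. In that mode `**` behaves like `*` and matches exactly one directory level, so the weights file in the parent `output_dir` is skipped. With `recursive=True`, `**` can also match zero directories and would include it. The global step is then taken from the text after the last `-` in the directory name. The sketch below runs this on a throw-away directory tree. The `checkpoint-<step>` folder names are an assumption based on the `checkpoint` prefix test in that cell.

```python
import glob
import os
import tempfile

WEIGHTS = "pytorch_model.bin"  # value of transformers.WEIGHTS_NAME

with tempfile.TemporaryDirectory() as tmp_dir:
    # Fake layout: final weights in the parent plus two step checkpoints
    for sub in ["", "checkpoint-25", "checkpoint-100"]:
        os.makedirs(os.path.join(tmp_dir, sub), exist_ok=True)
        open(os.path.join(tmp_dir, sub, WEIGHTS), "w").close()

    flat = sorted(glob.glob(tmp_dir + "/**/" + WEIGHTS, recursive=False))
    deep = sorted(glob.glob(tmp_dir + "/**/" + WEIGHTS, recursive=True))

    found_dirs = [os.path.dirname(c) for c in flat]
    steps = [d.split("-")[-1] for d in found_dirs]
    print(len(flat), len(deep), steps)  # -> 2 3 ['100', '25']
```

The string sort also explains the order of the result keys printed below (`loss_100` before `loss_25`).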


In [19]:
%pdb on
# Evaluation
results = {}
if args.do_eval:
    checkpoints = [args.output_dir]
    if args.eval_all_checkpoints:
        checkpoints = list(os.path.dirname(c) 
        for c in sorted(glob.glob(args.output_dir + "/**/" + 
                                  WEIGHTS_NAME, recursive=False)))
        # recursive=False because otherwise the parent directory gets included,
        # which is not what we want; only subdirectories

    logger.info("Evaluate the following checkpoints: %s", checkpoints)

    for checkpoint in checkpoints:
        global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
        prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
        model = MMBTForClassification(transformer_config, multimodal_config)
        checkpoint = os.path.join(checkpoint, 'pytorch_model.bin')
        model.load_state_dict(torch.load(checkpoint, map_location=args.device))
        model.to(args.device)
        result = evaluate(args, model, tokenizer, evaluate=True, test=True, prefix=prefix)
        result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
        results.update(result)

results.keys()

Automatic pdb calling has been turned ON


Evaluating: 100%|██████████| 36/36 [02:40<00:00,  4.45s/it]
Evaluating: 100%|██████████| 36/36 [00:10<00:00,  3.54it/s]
Evaluating: 100%|██████████| 36/36 [00:10<00:00,  3.55it/s]
Evaluating: 100%|██████████| 36/36 [00:10<00:00,  3.49it/s]
Evaluating: 100%|██████████| 36/36 [00:10<00:00,  3.50it/s]
Evaluating: 100%|██████████| 36/36 [00:10<00:00,  3.45it/s]
Evaluating: 100%|██████████| 36/36 [00:10<00:00,  3.46it/s]
Evaluating: 100%|██████████| 36/36 [00:10<00:00,  3.38it/s]
Evaluating: 100%|██████████| 36/36 [00:10<00:00,  3.43it/s]
Evaluating: 100%|██████████| 36/36 [00:10<00:00,  3.48it/s]
Evaluating: 100%|██████████| 36/36 [00:10<00:00,  3.44it/s]
Evaluating: 100%|██████████| 36/36 [00:10<00:00,  3.42it/s]
Evaluating: 100%|██████████| 36/36 [00:10<00:00,  3.39it/s]
Evaluating: 100%|██████████| 36/36 [00:10<00:00,  3.45it/s]
Evaluating: 100%|██████████| 36/36 [00:10<00:00,  3.45it/s]
Evaluating: 100%|██████████| 36/36 [00:10<00:00,  3.45it/s]
Evaluating: 100%|██████████| 36/36 [00:1

dict_keys(['loss_100', 'macro_f1_100', 'micro_f1_100', 'loss_125', 'macro_f1_125', 'micro_f1_125', 'loss_150', 'macro_f1_150', 'micro_f1_150', 'loss_175', 'macro_f1_175', 'micro_f1_175', 'loss_200', 'macro_f1_200', 'micro_f1_200', 'loss_225', 'macro_f1_225', 'micro_f1_225', 'loss_25', 'macro_f1_25', 'micro_f1_25', 'loss_250', 'macro_f1_250', 'micro_f1_250', 'loss_275', 'macro_f1_275', 'micro_f1_275', 'loss_300', 'macro_f1_300', 'micro_f1_300', 'loss_325', 'macro_f1_325', 'micro_f1_325', 'loss_350', 'macro_f1_350', 'micro_f1_350', 'loss_375', 'macro_f1_375', 'micro_f1_375', 'loss_400', 'macro_f1_400', 'micro_f1_400', 'loss_425', 'macro_f1_425', 'micro_f1_425', 'loss_450', 'macro_f1_450', 'micro_f1_450', 'loss_475', 'macro_f1_475', 'micro_f1_475', 'loss_50', 'macro_f1_50', 'micro_f1_50', 'loss_500', 'macro_f1_500', 'micro_f1_500', 'loss_525', 'macro_f1_525', 'micro_f1_525', 'loss_550', 'macro_f1_550', 'micro_f1_550', 'loss_575', 'macro_f1_575', 'micro_f1_575', 'loss_600', 'macro_f1_600',

In [20]:
results

{'loss_100': 1.658436440759235,
 'loss_125': 1.7200494209925334,
 'loss_150': 1.6319334639443293,
 'loss_175': 1.91707839568456,
 'loss_200': 2.07403427362442,
 'loss_225': 2.3141628238889904,
 'loss_25': 1.1056627333164215,
 'loss_250': 2.4645030399163566,
 'loss_275': 2.7903796103265552,
 'loss_300': 2.0998321142461567,
 'loss_325': 2.337274544768863,
 'loss_350': 2.8250158958964877,
 'loss_375': 2.8545848296748266,
 'loss_400': 2.4531858232286243,
 'loss_425': 2.341753406657113,
 'loss_450': 2.415829267766741,
 'loss_475': 3.101353704929352,
 'loss_50': 1.1591884659396277,
 'loss_500': 2.997980647616916,
 'loss_525': 3.041702992386288,
 'loss_550': 3.06148824095726,
 'loss_575': 3.090206411149767,
 'loss_600': 3.1117879615889654,
 'loss_625': 3.13253300719791,
 'loss_75': 1.1982646253373888,
 'macro_f1_100': 0.20829822623667935,
 'macro_f1_125': 0.20422464638160143,
 'macro_f1_150': 0.39636227492773307,
 'macro_f1_175': 0.2538169308633725,
 'macro_f1_200': 0.2401805037528204,
 'macr
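Both the dictionary keys and the checkpoint list above are ordered as strings, so `loss_100` comes before `loss_25`. To read the results as a curve over training steps, the keys can be regrouped by their numeric step. This is a minimal sketch with a hypothetical helper and made-up values, not the numbers printed above. It assumes every key ends in a step number, which holds when more than one checkpoint is evaluated.

```python
from collections import defaultdict
from typing import Dict


def results_by_step(results: Dict[str, float]) -> Dict[int, Dict[str, float]]:
    """Turn {'<metric>_<step>': value} into {step: {metric: value}}, ordered by numeric step."""
    grouped = defaultdict(dict)
    for key, value in results.items():
        metric, step = key.rsplit("_", 1)  # 'macro_f1_150' -> ('macro_f1', '150')
        grouped[int(step)][metric] = value
    return dict(sorted(grouped.items()))


# Made-up values in the same key format as `results`
demo = {"loss_100": 1.7, "macro_f1_100": 0.21, "loss_25": 1.1, "macro_f1_25": 0.30}
table = results_by_step(demo)
print(list(table))  # -> [25, 100]
print(table[25])  # -> {'loss': 1.1, 'macro_f1': 0.3}
```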

## Saving Test Eval Results

`evaluate` already saved the result of each checkpoint in that checkpoint's folder (`masked_text_eval_results.txt`). The next cell collects all of them in a single file in `args.output_dir`.

In [21]:
with open(os.path.join(args.output_dir, f"{os.path.splitext(args.test_file)[0]}_masked_text_eval_results.txt"), mode='w', encoding='utf-8') as out_f:
    print(results, file=out_f)
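`print(results, file=out_f)` writes the dictionary's Python repr. If the file needs to be machine-readable later, `json.dump` is an alternative. The sketch below uses a temporary directory and made-up values rather than `args.output_dir`, and the `.json` filename is hypothetical. The values produced by `evaluate` are Python or NumPy floats, and `numpy.float64` subclasses `float`, so `json` can serialize them.

```python
import json
import os
import tempfile

demo_results = {"loss_25": 1.1, "macro_f1_25": 0.3, "micro_f1_25": 0.5}  # made-up values

with tempfile.TemporaryDirectory() as out_dir:
    path = os.path.join(out_dir, "masked_text_eval_results.json")
    with open(path, mode="w", encoding="utf-8") as out_f:
        json.dump(demo_results, out_f, indent=2, sort_keys=True)

    # Reading the file back gives an equal dictionary
    with open(path, encoding="utf-8") as in_f:
        reloaded = json.load(in_f)

print(reloaded == demo_results)  # -> True
```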