## Imports

In [None]:
from model import ReVisionProcessor, ReVisionForConditionalGeneration

from datautils import RevisionRewriteDataset
# from datautils import RevisionRewriteDatasetWithMetadata
from transformers import BitsAndBytesConfig
from torch.utils.data import DataLoader
import pandas as pd
import torch
import os
from tqdm import tqdm

import evaluate
import nltk

nltk.download("wordnet")
nltk.download("punkt")
nltk.download("omw-1.4")

from dotenv import load_dotenv

load_dotenv(".env")

from prettytable import PrettyTable

## Model/Data Initialization

Below are the evaluation parameters/variables.

- `DEVICE` determines the compute device the tensors are sent to.
- `MODEL_ID` is the repository name for the model that you are evaluating
- `HF_TOKEN` reads the HuggingFace authentication token from the environment. If the token is not set, `os.getenv` returns `None` and authentication may fail when the model or dataset is downloaded; in that case, either add `HF_TOKEN` to the environment or create a `.env` file containing `HF_TOKEN`. See the imports for `load_dotenv`.
- `BATCH_SIZE` is the number of examples run at a time. Higher values run faster, but a batch size of 1 gives the most accurate results (this is the value used for the paper).
- `DATASET_SUFFIX` is the suffix in the dataset file name on our HuggingFace repo.
- `save_location` is the local file path where the generated completions and their corresponding reference texts are stored as paired rows.

Note that you may get an error related to the `images` folder. If so, you may need to edit the images folder path in the Dataset object defined in `datautils.py`. For reasons that are unclear, the folder is sometimes `images`, sometimes `images/images`, and sometimes `images/images/images`.

In [None]:
# Compute device that the model and the input tensors are sent to.
DEVICE = "cuda"
# HuggingFace repository name of the model being evaluated.
MODEL_ID = "anonymoususerrevision/ReVision-250M-256-16-baseline"
# Authentication token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# Number of examples generated per batch.
BATCH_SIZE = 4

# Quantization
USE_8BIT = True
USE_16BIT = False

# Can only use 1 type of quantization. Raise explicitly instead of using
# `assert`, which is removed when Python runs with optimizations (-O).
if USE_8BIT and USE_16BIT:
    raise ValueError(
        "USE_8BIT and USE_16BIT are mutually exclusive; enable at most one."
    )

# Prefix already set as "test"
# DATASET_SUFFIX = "_with_metadata"
# Local TSV file that receives the prediction/reference pairs.
save_location = "results_baseline_8bit.tsv"

In [None]:
def set_seed(seed):
    """Seed PyTorch's random number generators and pin cuDNN to deterministic mode.

    Args:
        seed: Integer seed passed to the CPU and CUDA seeding functions.
    """
    # Seed the default generator, the current CUDA device, then all CUDA devices.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Request deterministic cuDNN kernels and turn off the benchmark autotuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)


def collate_fn(examples, processor, to_bf16=False):
    """Turn a list of dataset examples into one processed batch.

    Args:
        examples: Sequence of items indexed positionally as
            (image, prompt text, response text).
        processor: Callable that tokenizes the prompts and preprocesses the
            images; it is invoked with keyword arguments only.
        to_bf16: When true, the processor output is converted with
            ``.to(torch.bfloat16)`` before the labels are attached.

    Returns:
        The processor output with an extra ``"labels"`` entry holding the
        newline-stripped response strings.
    """

    def _drop_newlines(text):
        return text.replace("\n", "")

    rgb_images = [sample[0].convert("RGB") for sample in examples]
    prompts = [_drop_newlines(sample[1]) for sample in examples]
    responses = [_drop_newlines(sample[2]) for sample in examples]

    # The responses are deliberately NOT passed as `suffix`: the model must
    # not be given the answer it is being evaluated on.
    processor_kwargs = {
        "text": prompts,
        "images": rgb_images,
        "return_tensors": "pt",
        "padding": "max_length",
        "max_length": 1024,
        "tokenize_newline_separately": False,
    }
    batch = processor(**processor_kwargs)

    if to_bf16:
        batch = batch.to(torch.bfloat16)

    # Carry the reference strings along with the batch for later scoring.
    batch["labels"] = responses
    return batch

In [None]:
processor = ReVisionProcessor.from_pretrained(MODEL_ID, use_auth_token=HF_TOKEN)

# Arguments shared by every precision mode; the selected mode adds its own.
load_kwargs = {"use_auth_token": HF_TOKEN, "device_map": DEVICE}

if USE_8BIT:
    print("Using 8-bit quantization")
    # "vision_tower.vision_model" is listed in the modules to skip for int8.
    config = BitsAndBytesConfig(
        load_in_8bit=True, llm_int8_skip_modules=["vision_tower.vision_model"]
    )
    load_kwargs["quantization_config"] = config
elif USE_16BIT:
    load_kwargs["torch_dtype"] = torch.float16
# Otherwise no dtype or quantization argument is passed.

model = ReVisionForConditionalGeneration.from_pretrained(MODEL_ID, **load_kwargs)
model = model.eval()

print(f"Model is running on FPType: {model.dtype}")

# Uncomment to inspect the dtype of each layer:
# for name, param in model.named_parameters():
#     print(f"Layer: {name}, dtype: {param.dtype}")

# ! silences warning:"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation"
model.generation_config.pad_token_id = processor.tokenizer.pad_token_id

# Alternative dataset variant that also takes a file-name suffix:
# test_dataset = RevisionRewriteDatasetWithMetadata(
#     split="test",
#     filename_suffix=DATASET_SUFFIX,
#     use_auth_token=HF_TOKEN,
#     processor=processor,
# )
test_dataset = RevisionRewriteDataset(
    split="test", use_auth_token=HF_TOKEN, processor=processor
)

# shuffle=False keeps the examples in dataset order.
test_dataloader = DataLoader(
    test_dataset,
    collate_fn=lambda x: collate_fn(x, processor),
    batch_size=BATCH_SIZE,
    shuffle=False,
)

## Predicted Completions

### Calculate Completions From Model

In [None]:
def generate_predictions(test_dataloader, model, processor, device):
    """Generate a completion for every example yielded by the dataloader.

    Args:
        test_dataloader: Iterable of batches; each batch supports ``.items()``
            and has a ``"labels"`` entry plus an ``"input_ids"`` entry.
        model: Object exposing ``generate``.
        processor: Object exposing ``batch_decode``.
        device: Device that every non-label batch entry is moved to.

    Returns:
        A ``(predictions, references)`` pair of lists, in dataloader order.
    """
    predictions = []
    references = []

    for batch in tqdm(test_dataloader):
        # "labels" carries the references and is not passed to the model.
        model_inputs = {}
        for key, value in batch.items():
            if key == "labels":
                continue
            model_inputs[key] = value.to(device)

        prompt_length = model_inputs["input_ids"].shape[-1]

        with torch.no_grad():
            generated = model.generate(
                **model_inputs,
                max_new_tokens=256,
                do_sample=False,
                repetition_penalty=1.5,
            )

        # Skip the first `prompt_length` positions so only the tokens after
        # the input are decoded.
        completions = processor.batch_decode(
            generated[:, prompt_length:], skip_special_tokens=True
        )

        # remove "assistant\n"
        predictions.extend(
            completion.removeprefix("assistant\n") for completion in completions
        )
        references.extend(batch["labels"])

    return predictions, references


def save_to_tsv(predictions, references, file_name="results.tsv"):
    """Write predictions and references side by side to a tab-separated file.

    Args:
        predictions: Values for the ``Prediction`` column.
        references: Values for the ``Reference`` column.
        file_name: Destination path of the TSV file.
    """
    frame = pd.DataFrame({"Prediction": predictions, "Reference": references})

    # Tab-separated, without the DataFrame index column.
    frame.to_csv(file_name, sep="\t", index=False)
    print(f"Results saved to: {file_name}")

In [None]:
# Generate completions for the whole test dataloader, then write the
# prediction/reference pairs to `save_location`.
predictions, references = generate_predictions(
    test_dataloader, model, processor, DEVICE
)
save_to_tsv(predictions, references, save_location)

### OR Load Completions From Local Storage

If you have already generated and saved the predictions, you can start from this cell and run everything below it. Make sure you run the imports section first, as well as the parameters cell that defines `save_location`.

In [None]:
def get_predictions_references_from_tsv(file_name: str) -> tuple[list[str], list[str]]:
    """Load prediction/reference pairs from a tab-separated file.

    The file must contain ``Prediction`` and ``Reference`` columns.

    Args:
        file_name: Path of the TSV file to read.

    Returns:
        A ``(predictions, references)`` pair of string lists in file order.
    """
    # Read every cell as a plain string. With pandas' defaults, an empty
    # completion or text such as "NA"/"null" is parsed as float NaN and
    # numeric-looking text is converted to numbers, which would corrupt the
    # text pairs handed to the metrics.
    df = pd.read_csv(file_name, sep="\t", dtype=str, keep_default_na=False)
    return df["Prediction"].to_list(), df["Reference"].to_list()

In [None]:
# Read the prediction/reference pairs back from `save_location`.
predictions, references = get_predictions_references_from_tsv(save_location)

## Calculate Evaluation Metrics

In [None]:
def get_evaluations(predictions, references):
    """Compute BLEU, ROUGE and METEOR and print them raw and as a table.

    Args:
        predictions: Generated texts.
        references: Reference texts, aligned with ``predictions``.
    """
    metric_names = ("bleu", "rouge", "meteor")

    # Load all metrics first, then score them in the same order.
    loaded = {name: evaluate.load(name) for name in metric_names}
    scores = {
        name: metric.compute(predictions=predictions, references=references)
        for name, metric in loaded.items()
    }

    # Raw result dictionaries, one line per metric.
    for label, name in (("BLEU", "bleu"), ("ROUGE", "rouge"), ("METEOR", "meteor")):
        print(f"{label} Score:", scores[name])

    rouge = scores["rouge"]
    rouge_details = ", ".join(
        f"{label}: {rouge[key]:.6f}"
        for label, key in (
            ("ROUGE-1", "rouge1"),
            ("ROUGE-2", "rouge2"),
            ("ROUGE-L", "rougeL"),
        )
    )

    table = PrettyTable()
    table.field_names = ["Metric", "Score Details"]
    summary_rows = (
        ["BLEU", f"{scores['bleu']['bleu']:.6f}"],
        ["ROUGE", rouge_details],
        ["METEOR", f"{scores['meteor']['meteor']:.6f}"],
    )
    for row in summary_rows:
        table.add_row(row)

    print("Evaluation Metrics Report")
    print(table)

In [None]:
# Compute and print BLEU, ROUGE and METEOR for the prediction/reference pairs.
get_evaluations(predictions, references)