In [None]:
import json
import bigbench.models.huggingface_models as huggingface_models
from captum.attr import Lime
from captum.attr import visualization as viz
import torch
from transformers import GPT2Tokenizer
from IPython.core.display import HTML, display

def show_text_attr(tokens, attrs):
    """Display tokens highlighted by attribution score and return the markup.

    Each token is wrapped in a ``<mark>`` element whose background is red
    when its score is negative and green otherwise, with an alpha channel of
    ``abs(score) ** 0.5``. The marks are joined with single spaces, shown in
    the notebook inside a ``<p>`` element, and returned without the wrapper.

    Args:
        tokens: Iterable of tokens to show; each is converted with ``str``.
        attrs: Object with a ``tolist()`` method that yields one number per
            token. Tokens and scores are paired with ``zip``, so pairing
            stops at the shorter of the two.

    Returns:
        The space-joined ``<mark>`` elements as a single string.
    """
    # Function-scope import: the file's import block does not provide `html`.
    import html

    def rgb(score):
        return '255,0,0' if score < 0 else '0,255,0'

    def alpha(score):
        return abs(score) ** 0.5

    token_marks = [
        f'<mark style="background-color:rgba({rgb(attr)},{alpha(attr)})">'
        # Escape the token so characters such as <, & or quotes are shown
        # literally instead of being parsed as markup.
        f'{html.escape(str(token))}</mark>'
        for token, attr in zip(tokens, attrs.tolist())
    ]
    marked_up = ' '.join(token_marks)

    display(HTML('<p>' + marked_up + '</p>'))
    return marked_up

def summarize_attributions(attributions):
    """Collapse the last dimension of an attribution tensor and L2-normalize.

    Args:
        attributions: Tensor whose last dimension is summed away; a leading
            dimension of size 1, if present after the sum, is squeezed out.

    Returns:
        The summed tensor divided by its norm. When the norm is exactly zero
        (every summed attribution is zero) the summed tensor is returned
        unchanged, because dividing would turn every entry into NaN.
    """
    summed = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(summed)
    if norm == 0:
        return summed
    return summed / norm


def load_data(path):
    """Read the JSON file at *path* and return the parsed object.

    Args:
        path: Location of the JSON file; passed straight to ``open``.

    Returns:
        Whatever ``json.load`` produces for the file's contents.
    """
    # Pin the encoding: without it open() uses the platform's locale
    # encoding, which can mis-decode non-ASCII JSON text.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
    

def predict(inputs, target="He's recovering"):
    """Score ``target`` as a continuation of each decoded input.

    Relies on the module-level ``tokenizer`` and ``model``.

    Args:
        inputs: Iterable of token-id sequences; each one is decoded to text
            with ``tokenizer.decode``.
        target: Continuation whose log probability is scored. Defaults to
            the answer string this function originally hard-coded.

    Returns:
        A tensor with one row per input and a single column holding the
        score ``model.cond_log_prob`` assigns to ``target``.
    """
    texts = [tokenizer.decode(token_ids) for token_ids in inputs]
    # Per the BIG-bench Model API, `targets` is a list of candidate strings;
    # a bare string is iterated character by character, so column 0 would be
    # the score of the first character only. With a single candidate the
    # default (relative) normalization is always 0, hence
    # absolute_normalization=True.
    return torch.tensor([
        model.cond_log_prob(text, [target], absolute_normalization=True)
        for text in texts
    ])

# Prompt to be explained: a narrative passage followed, on its own line, by
# the question put to the model.
# NOTE(review): the name `input` shadows the builtin `input()` for the rest
# of this module.
input = """CHAPTER XXIX \n\nFrona had gone at once to her father's side, but he was already recovering. Courbertin was brought forward with a scratched face, sprained wrist, and an insubordinate tongue. To prevent discussion and to save time, Bill Brown claimed the floor. \n\n\"Mr. Chairman, while we condemn the attempt on the part of Jacob Welse, Frona Welse, and Baron Courbertin to rescue the prisoner and thwart justice, we cannot, under the circumstances, but sympathize with them. There is no need that I should go further into this matter. You all know, and doubtless, under a like situation, would have done the same. And so, in order that we may expeditiously finish the business, I make a motion to disarm the three prisoners and let them go.\" \n\nThe motion was carried, and the two men searched for weapons. Frona was saved this by giving her word that she was no longer armed. The meeting then resolved itself into a hanging committee, and began to file out of the cabin. \n\n\"Sorry I had to do it,\" the chairman said, half-apologetically, half-defiantly. \n\nJacob Welse smiled. \"You took your chance,\" he answered, \"and I can't blame you. I only wish I'd got you, though.\" \n\nExcited voices arose from across the cabin. \"Here, you! Leggo!\" \"Step on his fingers, Tim!\" \"Break that grip!\" \"Ouch! Ow!\" \"Pry his mouth open!\" \n\nFrona saw a knot of struggling men about St. Vincent, and ran over. He had thrown himself down on the floor and, tooth and nail, was fighting like a madman. Tim Dugan, a stalwart Celt, had come to close quarters with him, and St. Vincent's teeth were sunk in the man's arm.
How is Frona's father doing?
"""

# `predict` reads this module-level `model` (and `tokenizer` below) each time
# LIME calls it.
model = huggingface_models.BIGBenchHFModel('gpt2-large')

# Tokenize the prompt; with return_tensors='pt' the ids come back with a
# leading batch dimension.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
input_ids = tokenizer.encode(input, return_tensors='pt')

lime = Lime(predict)

# Keep the batch dimension for Captum: it treats dimension 0 as the example
# axis, so a squeezed 1-D tensor is explained as many independent one-token
# examples rather than as one sequence. Drop the batch dimension afterwards
# so `explanation` holds one score per token.
explanation = lime.attribute(
    input_ids, target=0, show_progress=True
).squeeze(0)

# Decode the ids one at a time so tokens[i] is the text of the token that
# explanation[i] scores. Splitting the decoded text on whitespace yields
# words, which do not line up with the model's tokens.
tokens = [tokenizer.decode([token_id]) for token_id in input_ids[0].tolist()]

# Show each token highlighted by its attribution.
show_text_attr(tokens, explanation)
