In [8]:
import os 
import gc
import json
import pandas as pd
from IPython.display import display, HTML
from Roberta_SA import *
from GPT2_SA import *
from T5_SA import *
from LLaMA2_SA import *

def underline_words_in_red(text, words_to_underline):
    """Wrap every occurrence of each word in ``words_to_underline`` in a red underline.

    Only the visible text of ``text`` is searched: anything inside an HTML tag
    (``<...>``) is left untouched. This means a word such as ``"span"`` or
    ``"color"`` cannot corrupt existing markup. The ``<u>`` tags inserted here are
    never re-scanned for later words either.

    Args:
        text: HTML string to decorate.
        words_to_underline: iterable of plain-text words to underline. Empty
            strings are ignored.

    Returns:
        The decorated HTML string (``text`` unchanged when there is nothing to underline).
    """
    import re

    # Longest words first so that, where words overlap, the longer one wins.
    words = sorted({word for word in words_to_underline if word}, key=len, reverse=True)
    if not words:
        return text
    word_pattern = re.compile('|'.join(re.escape(word) for word in words))

    def _underline(match):
        # Keep the text black (as before) but draw the underline itself in red.
        return f'<u style="color: black; text-decoration-color: red;">{match.group(0)}</u>'

    # The capturing group keeps tags in the result: text runs land at even
    # indices, tags at odd indices. Only the text runs are rewritten.
    parts = re.split(r'(<[^>]*>)', text)
    for i in range(0, len(parts), 2):
        parts[i] = word_pattern.sub(_underline, parts[i])
    return ''.join(parts)

def visualize_importance(input_items, normalized_importance, words_to_underline=()):
    """Render items (words, tokens or sentences) as HTML shaded by importance.

    Each item gets a light-blue background whose rgba alpha is
    ``weight / max_alpha``. Items whose weight is ``None`` are rendered unshaded.
    Words in ``words_to_underline`` are additionally underlined via
    ``underline_words_in_red``. The HTML is shown with IPython ``display``;
    nothing is returned.

    Args:
        input_items: sequence of strings to render, in order, joined by spaces.
        normalized_importance: weights (or ``None``) indexed in step with
            ``input_items``; extra trailing weights are ignored.
        words_to_underline: optional iterable of plain-text words to underline.
    """
    import html

    # Weights are divided by max_alpha, so any weight >= max_alpha is drawn fully
    # opaque (browsers clamp rgba alpha to 1).
    max_alpha = 0.5
    highlighted_text = []

    for i, item in enumerate(input_items):
        weight = normalized_importance[i]
        item = item.replace('Ġ', '').replace('▁', '')  # 'Ġ' roberta; '▁' T5
        # Escape so tokens such as '<s>', '</s>' or '<br />' are shown as text
        # instead of being parsed as HTML tags by the notebook renderer.
        item = html.escape(item, quote=False)
        if weight is not None:
            highlighted_item = f'<span style="background-color:rgba(135,206,250,{weight / max_alpha});">{item}</span>'
        else:
            highlighted_item = item
        highlighted_text.append(highlighted_item)

    combined_text = ' '.join(highlighted_text)
    # Words must be escaped the same way as the items so they still match the rendered text.
    escaped_words = [html.escape(word, quote=False) for word in words_to_underline]
    combined_text = underline_words_in_red(combined_text, escaped_words)
    display(HTML(combined_text))
    
def get_rlf_index(w_list, rlf_word):
    """Locate ``rlf_word`` in a list of words/tokens.

    An item matches when it *contains* ``rlf_word`` as a substring, so a token
    carrying a prefix marker (e.g. ``'Ġgood'``) still matches ``'good'``.
    Every item is checked, and the position of the LAST match is returned.

    Returns:
        Index of the last matching item, or -1 when no item matches.
    """
    hits = [position for position, token in enumerate(w_list) if rlf_word in token]
    return hits[-1] if hits else -1

def get_Sexp(w_list, rlf_word, wis_list):
    """Return the importance score (Sexp) assigned to ``rlf_word`` in one sentence.

    Args:
        w_list: list of words/tokens of a single sentence.
        rlf_word: word to look up. Matching is done by ``get_rlf_index``
            (substring match, last match wins).
        wis_list: importance weights indexed in step with ``w_list``.

    Returns:
        ``wis_list[k]`` for the matching index ``k``. If ``rlf_word`` is not found,
        returns ``1 / len(w_list)``, the share each word would get under a
        uniform importance distribution.

    Raises:
        ValueError: if ``w_list`` is empty (no fallback share can be computed).
    """
    if len(w_list) == 0:
        raise ValueError('w_list is empty; cannot compute Sexp')
    rlf_word_index = get_rlf_index(w_list, rlf_word)
    if rlf_word_index == -1:
        # Divide by the number of words. w_list[0] would be the first word's
        # character count, which is not a meaningful fallback.
        return 1 / len(w_list)
    return wis_list[rlf_word_index]

In [9]:
# Pick one example from the sample set. The cells below use its sentence
# (rlf_sent), gold label and `rlf` word (whose importance is reported as Sexp).
df_sample = pd.read_csv('../data/sample_data.csv')
index = 101
sample_row = df_sample.iloc[index]
rlf_sent = sample_row['rlf_sent']
label = int(sample_row['label'])
rlf_word = sample_row['rlf']

## Zero-shot RoBERTa

In [10]:
# Zero-shot RoBERTa: the pre-trained sentiment model is used as-is.
# To evaluate a fine-tuned RoBERTa instead, set output_dir to the checkpoint folder path and load_best = True.
zeroshot_roberta = Roberta_SA(
                output_dir = '',
                load_best = False
            )
# try/finally: always release the model, even if a step below raises, so a
# failed cell does not keep it in memory while the next model is loaded.
try:
    pred_y_list = zeroshot_roberta.get_sentiment([rlf_sent])
    w_list, wis_list = zeroshot_roberta.get_text_list_w_imp([rlf_sent], [label])
    print('zero_shot RoBERTa: ')
    print('predict sentiment label: ', pred_y_list[0])
    Sexp = get_Sexp(w_list[0], rlf_word, wis_list[0])
    print('Sexp = {}'.format(Sexp))
    visualize_importance(w_list[0], wis_list[0])
finally:
    del zeroshot_roberta
    gc.collect()

Loading pre-trained model:  siebert/sentiment-roberta-large-english
get_sentiment: 0/0
zero_shot RoBERTa: 
predict sentiment label:  1
Sexp = 0.0


## Zero-shot GPT2

In [11]:
# Zero-shot GPT-2: the pre-trained sentiment model is used as-is.
# To evaluate a fine-tuned GPT-2 instead, set output_dir to the checkpoint folder path and load_best = True.
zeroshot_gpt2 = GPT2_SA(
                output_dir = '',
                load_best = False
            )
# try/finally: always release the model, even if a step below raises, so a
# failed cell does not keep it in memory while the next model is loaded.
try:
    pred_y_list = zeroshot_gpt2.get_sentiment([rlf_sent])
    w_list, wis_list = zeroshot_gpt2.get_text_list_w_imp([rlf_sent], [label])
    print('zero_shot GPT2: ')
    print('predict sentiment label: ', pred_y_list[0])
    Sexp = get_Sexp(w_list[0], rlf_word, wis_list[0])
    print('Sexp = {}'.format(Sexp))
    visualize_importance(w_list[0], wis_list[0])
finally:
    del zeroshot_gpt2
    gc.collect()

Loading pre-trained model:  michelecafagna26/gpt2-medium-finetuned-sst2-sentiment
zero_shot GPT2: 
predict sentiment label:  1
Sexp = 0.04761912425369596


## Zero-shot T5

In [12]:
# Zero-shot T5: the pre-trained sentiment model is used as-is.
# To evaluate a fine-tuned T5 instead, set output_dir to the checkpoint folder path and load_best = True.
zeroshot_t5 = T5_SA(
                output_dir = '',
                load_best = False
            )
# try/finally: always release the model, even if a step below raises, so a
# failed cell does not keep it in memory while the next model is loaded.
try:
    pred_y_list = zeroshot_t5.get_sentiment([rlf_sent])
    w_list, wis_list = zeroshot_t5.get_text_list_w_imp([rlf_sent], [label])
    print('zero_shot T5: ')
    print('predict sentiment label: ', pred_y_list[0])
    Sexp = get_Sexp(w_list[0], rlf_word, wis_list[0])
    print('Sexp = {}'.format(Sexp))
    visualize_importance(w_list[0], wis_list[0])
finally:
    del zeroshot_t5
    gc.collect()

Loading pre-trained model:  mrm8488/t5-base-finetuned-imdb-sentiment
zero_shot T5: 
predict sentiment label:  1
Sexp = 0.0718627542656454




## Zero-shot LLaMA2

In [13]:
# Zero-shot LLaMA2: no LoRA adapter is loaded (lora_model_path = '', load_best = False).
zeroshot_llama2 = LLaMA2_SA(
                lora_model_path = '',
                load_best = False
            )
# try/finally: always release the model, even if a step below raises, so a
# failed cell does not keep it in memory while the next model is loaded.
try:
    pred_y_list = zeroshot_llama2.get_sentiment([rlf_sent])
    # Unlike the other wrappers, LLaMA2_SA.get_text_list_w_imp is called without labels.
    w_list, wis_list = zeroshot_llama2.get_text_list_w_imp([rlf_sent])
    print('zero_shot LLaMA2: ')
    print('predict sentiment label: ', pred_y_list[0])
    Sexp = get_Sexp(w_list[0], rlf_word, wis_list[0])
    print('Sexp = {}'.format(Sexp))
    visualize_importance(w_list[0], wis_list[0])
finally:
    del zeroshot_llama2
    gc.collect()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

zero_shot LLaMA2: 
predict sentiment label:  1
Sexp = 0.38709677419354843


## ExpInstruct (LLaMA2 with fine-tuned LoRA adapter)

In [14]:
# ExpInstruct: LLaMA2 with the LoRA adapter loaded from the given checkpoint folder.
ExpInstruct_llama2 = LLaMA2_SA(
                lora_model_path = '../ft_model/llama2/folder_0/checkpoint-6000',
                load_best = True
            )
# try/finally: always release the model, even if a step below raises, so a
# failed cell does not keep it in memory.
try:
    pred_y_list = ExpInstruct_llama2.get_sentiment([rlf_sent])
    w_list, wis_list = ExpInstruct_llama2.get_text_list_w_imp([rlf_sent])
    print('ExpInstruct LLaMA2: ')
    print('predict sentiment label: ', pred_y_list[0])
    Sexp = get_Sexp(w_list[0], rlf_word, wis_list[0])
    print('Sexp = {}'.format(Sexp))
    visualize_importance(w_list[0], wis_list[0])
finally:
    del ExpInstruct_llama2
    gc.collect()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

ExpInstruct LLaMA2: 
predict sentiment label:  1
Sexp = 0.34285714285714286
