In [None]:
import os
import random

import matplotlib
from IPython.display import display, HTML
from scipy import stats

from utils import *
from lm_saliency import *

## calculate saliency scores

In [None]:
# specify the model, task, and prompt
pretrains = ['google/flan-t5-large']
datasets = ['cola']
prompt_types = ['standard_b']
seeds = [2266]

# checkpoints that are loaded with T5ForConditionalGeneration; every other
# checkpoint is loaded with AutoModelForCausalLM
seq2seq_pretrains = ['google/flan-t5-large']

# number of synthetic variants scored per item, in addition to the original
max_synthetic = 4

# create the output directory up front so that a missing folder cannot make the
# final write fail after all saliency scores have already been computed
os.makedirs('results_saliency', exist_ok=True)

for dataset in datasets:
    for prompt_type in prompt_types:
        for n_o_s in number_of_shots[dataset]:
            for seed in seeds:
                for pretrained in pretrains:
                    is_seq2seq = pretrained in seq2seq_pretrains

                    tokenizer = AutoTokenizer.from_pretrained(pretrained, padding_side='left')
                    tokenizer.add_special_tokens({'pad_token': pad_tokens[pretrained]})
                    if is_seq2seq:
                        model = T5ForConditionalGeneration.from_pretrained(pretrained).to(device)
                    else:
                        model = AutoModelForCausalLM.from_pretrained(pretrained).to(device)

                    # one row of config_reference_saliency.txt; the key order
                    # is the column order of that file
                    config = {
                        'experiment_id': time.strftime('%Y%m%d_%H%M%S', time.localtime()),
                        'dataset': dataset,
                        'number_of_data': None,
                        'model': pretrained,
                        'prompt_type': prompt_type,
                        'number_of_shots': n_o_s,
                        'temperature': 0.8,
                        'max_tokens': max_tokens[prompt_type],
                        'batch_size': 16,
                        'eos_token_id': eos_ids[pretrained],
                        'seed': seed,
                        # get_device_name() raises without a CUDA device, so
                        # fall back to a plain label on CPU-only machines
                        'device': (
                            torch.cuda.get_device_name(torch.cuda.current_device())
                            if torch.cuda.is_available() else 'cpu'
                        ),
                    }

                    torch.cuda.manual_seed_all(config['seed'])

                    # read in the prompt
                    with open(f'prompts/{dataset}/{prompt_type}-{n_o_s}.txt') as file:
                        prefix = file.read()

                    # read in the data
                    with open(f'datasets/{dataset}.json') as file:
                        data = json.load(file)
                    config['number_of_data'] = len(data)

                    # calculate the saliency scores: for every item, entry 0 is
                    # the original text and the following entries are its first
                    # `max_synthetic` synthetic variants
                    l1_norms, input_x_gradients = defaultdict(list), defaultdict(list)
                    with tqdm(total=config['number_of_data']) as progress:
                        for item_id, item in data.items():
                            texts = [item['original']]
                            texts.extend(list(item['synthetic'].values())[:max_synthetic])
                            inputs = [
                                concatenate(dataset, prompt_type, prefix, text, item)
                                for text in texts
                            ]
                            tokenized = tokenizer(inputs)
                            for input_tokens, attention_ids in zip(tokenized['input_ids'],
                                                                   tokenized['attention_mask']):
                                base_saliency_matrix, base_embd_matrix = saliency(
                                    model, input_tokens, attention_ids, pretrained)
                                l1_scores = l1_grad_norm(base_saliency_matrix, normalize=True)
                                ixg_scores = input_x_gradient(
                                    base_saliency_matrix, base_embd_matrix, normalize=True)
                                if is_seq2seq:
                                    # only element [1] of the seq2seq output is kept
                                    # NOTE(review): which part of the model element [1]
                                    # corresponds to is decided in lm_saliency - confirm
                                    l1_scores, ixg_scores = l1_scores[1], ixg_scores[1]
                                # plain Python floats so the scores are JSON-serialisable
                                l1_norms[item_id].append([float(x) for x in l1_scores])
                                input_x_gradients[item_id].append([float(x) for x in ixg_scores])
                            progress.update(1)

                    # save the results; the score files are written before the
                    # config row so the reference file never lists an experiment
                    # whose results are missing
                    with open('results_saliency/l1_norms-{}.json'.format(config['experiment_id']), 'w') as file:
                        json.dump(l1_norms, file, indent=4, ensure_ascii=False)
                    with open('results_saliency/input_x_gradients-{}.json'.format(config['experiment_id']), 'w') as file:
                        json.dump(input_x_gradients, file, indent=4, ensure_ascii=False)
                    with open('config_reference_saliency.txt', 'a') as file:
                        file.write('\t'.join(str(value) for value in config.values()) + '\n')

                    # release the model before the next one is loaded
                    del model
                    torch.cuda.empty_cache()

## inspection

In [None]:
# specify the model, task, prompt, and seed
model = 'togethercomputer/GPT-JT-6B-v1'
dataset = 'cola'
prompt_type = 'standard_b'
seed = 2266

# read in the data
with open(f'datasets/{dataset}.json') as file:
    data = json.load(file)

# specify the number of instances to show
instances = [str(_) for _ in range(5)]

tokenizer = AutoTokenizer.from_pretrained(model, padding_side='left')
tokenizer.add_special_tokens({'pad_token': pad_tokens[model]})


def _display_start(token_ids, sep, nth_from_end, offset):
    """Return the index of the first token to display.

    The index is `offset` positions after the `nth_from_end`-th occurrence
    (a negative index, e.g. -3 for third from last) of the token id `sep`.

    Raises:
        ValueError: if `token_ids` contains fewer separators than required.
    """
    sep_positions = [pos for pos, tok in enumerate(token_ids) if tok == sep]
    if len(sep_positions) < -nth_from_end:
        raise ValueError(
            f'expected at least {-nth_from_end} separator tokens (id {sep}) '
            f'in the prompt, found {len(sep_positions)}'
        )
    return sep_positions[nth_from_end] + offset


# Where the displayed span starts, as (separator index from the end, offset).
# A combination without a rule yields None and is rejected below instead of
# falling through with an undefined (or stale, from an earlier run) start_id.
if model not in ['google/flan-t5-large']:
    sep = 198  # presumably the newline token of this tokenizer - TODO confirm
    score_scale = 10  # scores of these models are multiplied by 10 for display
    start_rule = {
        'standard_a': (-2, 1),
        'standard_b': (-3, 1),
        'CoT': (-3, 1),
        'context faithful prompting': (-3, 1),
        'zero_sensitivity_b': (-3, 2),
        'gkp': (-4, 1),
    }.get(prompt_type)
else:
    sep = 10  # NOTE(review): separator id for the T5 tokenizer - confirm which token
    score_scale = 1
    start_rule = None
    if prompt_type in ['zero_sensitivity_b']:
        start_rule = {
            'cola': (-3, 2),
            'sst2': (-3, 3),
            'csqa': (-4, 2),
            'mnli': (-4, 2),
            'rte': (-4, 2),
        }.get(dataset)

# normalize the scores; the fixed range does not depend on the instance, so the
# normalizer is built once. It is referenced through the `matplotlib` name this
# notebook imports explicitly, not through an `mpl` alias that none of the
# visible cells define.
norm = matplotlib.colors.Normalize(vmin=-0.98, vmax=1)

for experiment_id in config_df_saliency.experiment_id.tolist():
    config = {k: v[0] for k, v in config_df_saliency[config_df_saliency.experiment_id == experiment_id].to_dict(orient='list').items()}
    is_selected = (
        config['dataset'] == dataset
        and config['model'] == model
        and int(config['seed']) == seed
        and config['prompt_type'] == prompt_type
    )
    if not is_selected:
        continue

    if start_rule is None:
        raise ValueError(
            f'no display-start rule for model={model!r}, '
            f'prompt_type={prompt_type!r}, dataset={dataset!r}'
        )
    nth_from_end, offset = start_rule

    # read in the saliency scores
    with open(f'results_saliency/l1_norms-{experiment_id}.json') as file:
        target = json.load(file)

    # read in the prompt
    # NOTE(review): this always uses the first shot count of the dataset, and the
    # filter above does not check the experiment's own number of shots - confirm
    # that every stored experiment was run with this prompt
    n_o_s = number_of_shots[dataset][0]
    with open(f'prompts/{dataset}/{prompt_type}-{n_o_s}.txt') as file:
        prefix = file.read()

    random.seed(seed)
    for i in instances:
        item = data[i]
        inputs = [concatenate(dataset, prompt_type, prefix, item['original'], item)]
        for synthetic in list(item['synthetic'].values())[:4]:
            inputs.append(concatenate(dataset, prompt_type, prefix, synthetic, item))
        tokenized = tokenizer(inputs)
        all_input_tokens, all_attention_ids = tokenized['input_ids'], tokenized['attention_mask']

        # only index 0 (the original text, which is appended first) is shown
        for j in range(0, 1):
            input_tokens = all_input_tokens[j]
            start_id = _display_start(input_tokens, sep, nth_from_end, offset)

            # rescale into a new list so the loaded scores stay untouched
            scores = [score * score_scale for score in target[i][j]]

            # display a heatmap over the last len(tokens) scores
            tokens = [tokenizer.decode(_) for _ in input_tokens[start_id:]]
            attention = np.array(scores[len(scores) - (len(input_tokens) - start_id):])
            attention = norm(attention)
            s = colorize(tokens, attention)
            display(HTML(s))