In [None]:
from utils import *

In [None]:
# specify the model, task, and prompt
pretrains = ['google/flan-t5-large']
datasets = ['cola']
prompt_types = ['standard_a']
seeds = [2266]
alphas = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# Models loaded with T5ForConditionalGeneration; their generated sequences are
# decoded whole, while causal-LM outputs have the prompt tokens sliced off.
SEQ2SEQ_MODELS = ('google/flan-t5-large', 'google/flan-t5-xxl')
# Number of synthetic variants of each item that are scored alongside the original.
N_SYNTHETIC = 4
BATCH_SIZE = 16
# First-step SAD predictions treated as "no answer"; decoding then retries on step 2.
BLANK_PREDICTIONS = ('', ' ', '\n')


def sad_token_id(step_scores, a):
    """Return the sensitivity-aware-decoding (SAD) token id for one decoding step.

    Args:
        step_scores: scores for a single generation step. Row 0 belongs to the
            original input, rows 1: to its synthetic variants (presumably shaped
            (n_inputs, vocab_size), as produced by ``generate_with_scores``).
        a: weight in [0, 1] trading the original input's scores (``a``) against
            the cross-variant variance penalty (``1 - a``).

    Returns:
        The argmax token id of ``a * original - (1 - a) * var(synthetic)`` as an int.
    """
    original = step_scores[0]
    synthetic = step_scores[1:]
    if synthetic.shape[0] < 2:
        # The unbiased variance of fewer than two rows is NaN, which would make the
        # argmax meaningless; with no measurable sensitivity, apply no penalty.
        penalty = torch.zeros_like(original)
    else:
        penalty = torch.var(synthetic, dim=0)
    return int(torch.argmax(a * original - (1 - a) * penalty))


def sad_prediction(decode, step_scores, a):
    """Decode the SAD prediction, falling back to step 2 if step 1 is blank.

    Args:
        decode: callable mapping a token id to text (e.g. ``tokenizer.decode``).
        step_scores: sequence of per-step score tensors (see ``sad_token_id``).
        a: SAD weight, see ``sad_token_id``.

    Returns:
        The decoded prediction. It stays blank if step 1 is blank and generation
        produced no second step.
    """
    prediction = decode(sad_token_id(step_scores[0], a))
    if prediction in BLANK_PREDICTIONS and len(step_scores) > 1:
        prediction = decode(sad_token_id(step_scores[1], a))
    return prediction


def generate_with_scores(model, tokenizer, texts, *, device, batch_size,
                         max_new_tokens, pad_token_id, is_seq2seq):
    """Greedy-generate for ``texts`` and collect decoded outputs and per-step scores.

    All texts are tokenized together (one shared padded length) and generated in
    slices of ``batch_size``. The attention mask is passed explicitly so padded
    positions are ignored even when the pad id equals the eos id.

    Returns:
        (decoded, step_scores). ``decoded`` is a list of strings in input order.
        ``step_scores`` is a list with one tensor per generation step (up to the
        fewest steps any batch produced), each holding every input's scores
        concatenated across batches.
    """
    encoded = tokenizer(texts, padding=True, return_tensors='pt').to(device)
    prompt_length = encoded.input_ids.shape[1]
    decoded, scores_per_batch = [], []
    for start in range(0, encoded.input_ids.shape[0], batch_size):
        stop = start + batch_size
        with torch.no_grad():
            generated = model.generate(
                encoded.input_ids[start:stop],
                attention_mask=encoded.attention_mask[start:stop],
                max_new_tokens=max_new_tokens,
                pad_token_id=pad_token_id,
                return_dict_in_generate=True,
                output_scores=True,
            )
        sequences = generated.sequences if is_seq2seq else generated.sequences[:, prompt_length:]
        decoded.extend(tokenizer.batch_decode(sequences))
        scores_per_batch.append(generated.scores)
    n_steps = min(len(batch) for batch in scores_per_batch)
    step_scores = [torch.cat([batch[k] for batch in scores_per_batch]) for k in range(n_steps)]
    return decoded, step_scores


def item_records(item, inputs, outputs, step_scores, decode, alphas):
    """Build one results record per alpha for a single dataset item.

    Only ``prediction_sad`` differs between alphas; generation is shared.
    """
    synthetic = list(item['synthetic'].values())[:N_SYNTHETIC]
    return {
        a: {
            'original': item['original'],
            'label': item['label'],
            'input': inputs[0],
            'prediction_original': outputs[:1],
            'prediction_synthetic': outputs[1:],
            'prediction_sad': sad_prediction(decode, step_scores, a),  ## the prediction obtained using sensitivity-aware decoding
            'synthetic': synthetic,
            'input_synthetic': inputs[1:N_SYNTHETIC + 1],
        }
        for a in alphas
    }


# Guarded so the helpers above can be imported without running the experiment;
# a Jupyter kernel runs cells with __name__ == '__main__', so the cell still executes.
if __name__ == '__main__':
    for pretrained in pretrains:
        is_seq2seq = pretrained in SEQ2SEQ_MODELS
        tokenizer = AutoTokenizer.from_pretrained(pretrained, padding_side='left')
        tokenizer.add_special_tokens({'pad_token': pad_tokens[pretrained]})
        model_class = T5ForConditionalGeneration if is_seq2seq else AutoModelForCausalLM
        model = model_class.from_pretrained(pretrained).to(device)
        device_name = (torch.cuda.get_device_name(torch.cuda.current_device())
                       if torch.cuda.is_available() else 'cpu')
        for dataset in datasets:
            # read in the data (once per dataset; it is not modified below)
            with open(f'datasets/{dataset}.json', encoding='utf-8') as file:
                data = json.load(file)
            for prompt_type in prompt_types:
                for n_o_s in number_of_shots[dataset]:
                    # read in the prompt
                    with open(f'prompts/{dataset}/{prompt_type}-{n_o_s}.txt', encoding='utf-8') as file:
                        prefix = file.read()
                    for seed in seeds:
                        random.seed(seed)
                        torch.manual_seed(seed)
                        torch.cuda.manual_seed_all(seed)

                        # All alphas are written in the same second, so the alpha is
                        # appended to keep experiment ids (and result filenames) unique.
                        run_stamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
                        configs = {
                            a: {
                                'experiment_id': f'{run_stamp}_a{a}',
                                'dataset': dataset,
                                'number_of_data': len(data),
                                'model': pretrained,
                                'prompt_type': prompt_type,
                                'number_of_shots': n_o_s,
                                'a': a,
                                'max_tokens': max_tokens[prompt_type],
                                'batch_size': BATCH_SIZE,
                                'pad_token_id': pad_ids[pretrained],
                                'eos_token_id': eos_ids[pretrained],
                                'seed': seed,
                                'device': device_name,
                                'note': 'sad',
                            }
                            for a in alphas
                        }

                        # inference: generation does not depend on alpha, so run it
                        # once and derive every alpha's SAD prediction from its scores
                        results_by_alpha = {a: defaultdict(dict) for a in alphas}
                        for i, item in tqdm(data.items(), total=len(data)):
                            texts = [item['original'], *list(item['synthetic'].values())[:N_SYNTHETIC]]
                            inputs = [concatenate(dataset, prompt_type, prefix, text, item) for text in texts]
                            outputs, step_scores = generate_with_scores(
                                model, tokenizer, inputs,
                                device=device,
                                batch_size=BATCH_SIZE,
                                max_new_tokens=max_tokens[prompt_type],
                                pad_token_id=pad_ids[pretrained],
                                is_seq2seq=is_seq2seq,
                            )
                            records = item_records(item, inputs, outputs, step_scores, tokenizer.decode, alphas)
                            for a, record in records.items():
                                results_by_alpha[a][i] = record

                        # save the results: one config line and one results file per alpha
                        for a in alphas:
                            config = configs[a]
                            with open('config_reference_sad.txt', 'a', encoding='utf-8') as file:
                                file.write('\t'.join(str(value) for value in config.values()) + '\n')
                            with open(f"results_sad/results-{config['experiment_id']}.json", 'w', encoding='utf-8') as file:
                                json.dump(results_by_alpha[a], file, indent=4, ensure_ascii=False)

        del model
        torch.cuda.empty_cache()