In [1]:
import json
import numpy as np
import os
import random
import torch
import datasets
from tqdm import tqdm
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.util import ngrams
from transformers import BitsAndBytesConfig
from transformers import GPT2Tokenizer ,  GPT2Model, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, LogitsProcessorList, T5ForConditionalGeneration, T5Tokenizer, MT5ForConditionalGeneration, M2M100ForConditionalGeneration, M2M100Tokenizer
import os
import pandas as pd
import json
import matplotlib.pyplot as plt
import time
from typing import List


In [2]:
def load_hf_data_set(split, dataset_name, dataset_subname):
    """Open one split of a Hugging Face dataset in streaming mode.

    Args:
        split: Name of the split to stream (e.g. ``"validation"``).
        dataset_name: Dataset identifier forwarded to ``datasets.load_dataset``.
        dataset_subname: Configuration / subset name (e.g. ``"de-en"``).

    Returns:
        Whatever ``datasets.load_dataset`` returns for the requested split
        with ``streaming=True``.
    """
    # Forward the caller's split instead of pinning it to a fixed value, so
    # the `split` argument actually selects what is loaded.
    return datasets.load_dataset(
        dataset_name,
        dataset_subname,
        split=split,
        trust_remote_code=True,
        streaming=True,
    )


def ele_dist_k_from_idx(lst, start_index, k):
    """Return every ``k``-th element of ``lst``, beginning at ``start_index``."""
    stride = slice(start_index, None, k)
    return lst[stride]

In [3]:
# Sampling configuration: number of records to pull, a batch size, and the
# seed for Python's `random` module.
samplesize = 1000
batch = 10
random.seed(41)

# Stream the first `samplesize` records of the WMT19 de-en validation split
# and keep only each record's "translation" entry.
data = [
    record["translation"]
    for record in load_hf_data_set('validation','wmt19','de-en').take(samplesize)
]

In [4]:
data[0]

{'de': 'München 1856: Vier Karten, die Ihren Blick auf die Stadt verändern',
 'en': 'Munich 1856: Four maps that will change your view of the city'}

In [5]:
batch

10

In [6]:
# Use the GPU when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Pass the EOS token id as the PAD token id to avoid warnings, then move the
# model onto the selected device.
model = AutoModelForCausalLM.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
model = model.to(torch_device)




In [35]:
# Loader defaults; the branches below swap in model-specific classes.
cls_tokenizer = AutoTokenizer
cls_model = AutoModelForCausalLM
tokenizer_args = {}
device_map="auto"
torch_dtype=torch.float16
load_in_8bit = False
model_name = "gpt2"
# Cache directory shared by both model-loading paths below.
cache_dir = '/work/pi_dhruveshpate_umass_edu/aamballa_umass_edu/models/.cache'

# Resolve the short `model_name` to a checkpoint id plus the classes that load
# it. The checkpoint id lives in its own variable (`model_id`) so that `model`
# only ever refers to a loaded model and never to a stale object left behind
# by an earlier cell.
if model_name == "flan-t5":
    model_id = "google/flan-t5-large"
    cls_model = T5ForConditionalGeneration
    cls_tokenizer = T5Tokenizer

elif model_name == "m2m":
    model_id = "facebook/m2m100_418M"
    cls_model = M2M100ForConditionalGeneration
    cls_tokenizer = M2M100Tokenizer

elif model_name == "mt0":
    model_id = "bigscience/mt0-large"
    cls_model = AutoModelForSeq2SeqLM
    cls_tokenizer = AutoTokenizer

else:
    # Any other name (e.g. "gpt2") is used directly as the checkpoint id and
    # loaded with the default Auto classes set above.
    model_id = model_name


tokenizer = cls_tokenizer.from_pretrained(model_id, **tokenizer_args)

if model_name == "m2m":
    # The data loaded in this notebook is the wmt19 de-en pair, so the source
    # language handed to the M2M100 tokenizer is German.
    tokenizer.src_lang = "de"

# Keyword arguments common to the quantized and non-quantized loading paths.
load_kwargs = dict(
    device_map=device_map,
    cache_dir=cache_dir,
    trust_remote_code=True,
)

if load_in_8bit:
    # 8-bit path: quantization is requested through a BitsAndBytesConfig.
    bnb_config = BitsAndBytesConfig(load_in_8bit=True,)
    model = cls_model.from_pretrained(model_id,
                                      torch_dtype=torch.bfloat16,
                                      quantization_config=bnb_config,
                                      **load_kwargs)

else:
    # Non-quantized path: load the weights in `torch_dtype`.
    model = cls_model.from_pretrained(model_id,
                                      torch_dtype=torch_dtype,
                                      **load_kwargs)

# tokenizer.pad_token =  tokenizer.eos_token
# model.config.pad_token_id = model.config.eos_token_id
# model.eval()

    


In [7]:
from serialize import serialize_tree, deserialize_tree

In [9]:
# Normalise over the last axis (the candidate axis of the top-k scores).
log_softmax = torch.nn.LogSoftmax(dim=-1)


@torch.no_grad()
def modelrun(model, input_ids, max_len, k):
    """Expand the top-``k`` continuations of a prefix into a flat, pre-order tree.

    At every level the model is called on the current prefix, the ``k``
    highest-scoring next tokens are kept, and each of them is expanded
    recursively until ``max_len`` levels have been produced.

    Args:
        model: Callable such that ``model(input_ids).logits`` can be indexed
            as ``[:, -1, :]`` (batch, seq len, vocab size).
        input_ids: 2-D tensor of token ids. Only row 0 is expanded.
        max_len: Number of tree levels to generate; must be >= 1.
        k: Branching factor (number of candidates kept per node).

    Returns:
        A list of ``(log_prob, token_id, is_leaf)`` tuples in pre-order, with
        higher-scoring tokens visited first. ``log_prob`` is the log-softmax
        taken over the ``k`` retained logits only (not the full vocabulary).
        ``is_leaf`` is 1 for nodes on the last level and 0 otherwise.

    Raises:
        ValueError: If ``max_len`` is smaller than 1 (the recursion would
            otherwise never reach its base case).
    """
    if max_len < 1:
        raise ValueError(f"max_len must be >= 1, got {max_len}")

    outputs = model(input_ids)
    next_token_logits = outputs.logits[:, -1, :]   # (batch, seq len, vocab size)

    # Select the top k logits, then renormalise among those k candidates.
    topk_logits, token_indices = torch.topk(next_token_logits, k, dim=-1)
    log_topk_probs = log_softmax(topk_logits)

    # Nodes on the last level are flagged as leaves and are not expanded.
    is_leaf = 1 if max_len == 1 else 0

    # Pre-order traversal such that top tokens are visited first.
    out = []
    for i in range(k):
        out.append((log_topk_probs[0][i].item(), token_indices[0][i].item(), is_leaf))
        if is_leaf:
            continue
        # Append the chosen token to row 0 of the prefix, staying on whatever
        # device the prefix already lives on.
        next_token = token_indices[:1, i:i + 1].to(input_ids.device)
        extended = torch.cat([input_ids[:1], next_token], dim=-1)
        out.extend(modelrun(model, extended, max_len - 1, k))

    return out
    
            

In [None]:
# Prompt pieces; they are joined with single spaces around the German source
# sentence to form the model input.
default_fwd_instruction = "Translate the following German sentence to an English sentence."
default_fwd_input_prefix = "German sentence: "
default_fwd_target_prefix = ". English sentence: "

# Tree shape: depth (max_len) and branching factor (k). They are the same for
# every sample, so they are set once instead of on every iteration.
max_len = 5
k = 5

# Trees are written under data/; create the directory up front in case
# serialize_tree does not create it itself.
os.makedirs("data", exist_ok=True)

for idx, d in enumerate(tqdm(data, desc="Predicting")):

    prompt_arr = [
        default_fwd_instruction,
        default_fwd_input_prefix,
        d['de'],
        default_fwd_target_prefix,
    ]
    input_prompt = (' ').join(prompt_arr)  # join the sentences

    # Put the prompt on the same device as the model rather than assuming a
    # GPU is present.
    input_ids = tokenizer(input_prompt, return_tensors='pt').input_ids.to(model.device)

    nodes = modelrun(model, input_ids, max_len, k)
    # Add a placeholder entry for the prefix (root) in front of the generated nodes.
    nodes.insert(0, (1,1,0))

    # Serialize the tree, one file per sample.
    serialize_tree(k, max_len, nodes, f"data/sample{idx}.bin")



Predicting:   2%|███▍                                                                                                                                                       | 22/1000 [02:24<1:47:33,  6.60s/it]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
Predicting:  17%|██████████████████████████▏                                                                                                                               | 170/1000 [18:42<1:30:59,  6.58s/it]

In [125]:
# Deserialize 

In [127]:
# Read back one tree file per element of `data`: the prediction loop writes
# data/sample{idx}.bin for every index of `data`, so the count is derived from
# it rather than hard-coded. The value returned by deserialize_tree is not kept.
for i in range(len(data)):
    deserialize_tree(k, max_len, f"data/sample{i}.bin")

[(1.0, 1, 0),
 (-0.53369140625, 11, 0),
 (-0.67529296875, 475, 0),
 (-0.1954345703125, 314, 0),
 (-0.53125, 1101, 0),
 (-0.33251953125, 407, 1),
 (-1.2626953125, 635, 1),
 (-0.88623046875, 836, 0),
 (-0.0001493692398071289, 470, 1),
 (-8.8125, 6, 1),
 (-1.728515625, 340, 0),
 (-0.1771240234375, 338, 0),
 (-0.424560546875, 407, 1),
 (-1.0615234375, 1327, 1),
 (-1.8173828125, 318, 0),
 (-0.6884765625, 407, 1),
 (-0.69775390625, 257, 1),
 (-0.71142578125, 290, 0),
 (-0.1597900390625, 314, 0),
 (-0.53955078125, 1842, 0),
 (-0.346435546875, 284, 1),
 (-1.228515625, 262, 1),
 (-0.875, 1101, 0),
 (-0.57373046875, 1464, 1),
 (-0.82861328125, 1654, 1),
 (-1.9130859375, 356, 0),
 (-0.55419921875, 423, 0),
 (-0.265869140625, 257, 1),
 (-1.4541015625, 587, 1),
 (-0.8544921875, 821, 0),
 (-0.5908203125, 1464, 1),
 (-0.80712890625, 1111, 1),
 (-0.8828125, 13, 0),
 (-0.44140625, 314, 0),
 (-0.46728515625, 1842, 0),
 (-0.338623046875, 284, 0),
 (-0.250732421875, 711, 1),
 (-1.5068359375, 2342, 1),
 (-