In [1]:
import math
import random
import pickle
import json
import re
import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer

In [2]:
ul2_device_map={'shared': 0,
 'lm_head': 0,
 'encoder': 0,
 'decoder.embed_tokens': 0,
 'decoder.block.0': 0,
 'decoder.block.1': 0,
 'decoder.block.2': 1,
 'decoder.block.3': 1,
 'decoder.block.4': 1,
 'decoder.block.5': 1,
 'decoder.block.6': 1,
 'decoder.block.7': 1,
 'decoder.block.8': 1,
 'decoder.block.9': 1,
 'decoder.block.10': 1,
 'decoder.block.11': 1,
 'decoder.block.12': 1,
 'decoder.block.13': 1,
 'decoder.block.14': 1,
 'decoder.block.15': 1,
 'decoder.block.16': 1,
 'decoder.block.17': 1,
 'decoder.block.18': 1,
 'decoder.block.19': 1,
 'decoder.block.20': 1,
 'decoder.block.21': 1,
 'decoder.block.22': 1,
 'decoder.block.23': 1,
 'decoder.block.24': 1,
 'decoder.block.25': 1,
 'decoder.block.26': 1,
 'decoder.block.27': 1,
 'decoder.block.28': 1,
 'decoder.block.29': 1,
 'decoder.block.30': 1,
 'decoder.block.31': 1,
 'decoder.final_layer_norm': 1,
 'decoder.dropout': 1}

In [3]:
# Load the UL2 20B model. Should output some text if it is working right.

model = T5ForConditionalGeneration.from_pretrained("google/ul2", device_map=ul2_device_map, torch_dtype=torch.bfloat16, load_in_8bit=False)                                                                                       
tokenizer = AutoTokenizer.from_pretrained("google/ul2/")

input_string = "[NLG] Q: Who was the king of england when the United States revolted? A:<extra_id_0>"                                           
inputs = tokenizer(input_string, return_tensors="pt").input_ids.to("cuda")
outputs = model.generate(inputs, max_length=35)

print(tokenizer.decode(outputs[0]))

<pad><extra_id_0> King George III. He was the king of England when the United States revolted. Q: Who was the king of England when the United States revolt


In [4]:
# Text to exclude from replacement. Select things you beleive are very likely to be correct in original dataset.

excluded=['1','2','3','4','5','6','7','8','9','10',' one',' two',' three',' four',' five',' six',' seven',' eight',' nine',' ten']

# Function returns list of non-overlapping, unique, three word replacement candidates for a string in a list.

def get_replace(text,excluded):
    good_threes=[]
    words=text.split(" ")
    threes=[]
    for i in range(0,(len(words)-5),3):
        threes.append(words[i]+" "+words[i+1]+" "+words[i+2])
    for three in threes:
        try:
            failed=False
            matches = re.findall(three, text)
            if len(matches) > 1:
                failed=True
            if any(substring in three for substring in excluded):
                failed=True
            if not failed:
                good_threes.append(three)
        except:
            pass
    return good_threes

# Sort_Tuple sorts a list of tuples
# by the second element.

def Sort_Tuple(tup):
    tup.sort(key=lambda x: x[0], reverse=False)
    return tup

# Basic LM deterministic call. Returns logit and answer.

def ask_flan_T5D(input_text):
    inputs = tokenizer([sentence for sentence in input_text], return_tensors="pt", padding=True)
    outputs = model.generate(
        input_ids=inputs["input_ids"].to(0),
        attention_mask=inputs["attention_mask"].to(0),
        do_sample=False,
        eos_token_id=0,
        max_new_tokens=20,
        bos_token_id=0,
        return_dict_in_generate=True,
        output_scores=True,
    )
    output_tuple=[]
    probs = torch.stack(outputs.scores, dim=1).softmax(-1)
    for z,i in enumerate(outputs.sequences):
        out_text = tokenizer.decode(i, skip_special_tokens=False)
        logprobs = 0
        counter = 0
        for k in i[1:]:
            word_prob = (round(probs[z][counter][k.item()].item(), 2)) + 0.001
            logprobs = logprobs + math.log(word_prob)
            counter += 1
        out_tuple = (out_text, round(logprobs, 2))
        output_tuple.append(out_tuple)
    return output_tuple

# Check if entry was properly replaced. If not properly replaced then return original sequence.

def check_replace(original,replacement):
    matches = re.findall(r"extra_id_0", original)
    if len(matches) > 1:
        return original
    else:
        if "extra_id_0" in replacement[0] and "extra_id_1" in replacement[0]:
            replace=replacement[0].split("<extra_id_0>")[1].split("<extra_id_1")[0]
            if any(substring in replace for substring in excluded):
                return original
            else:
                print("replacing: "+replace)
                updated=original.replace("<extra_id_0>",replace).replace("  "," ").replace(" .", ".")
                return updated
        else:
            return original

In [5]:
prefix="""The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Rosey, and a human user, called User.
In the following interactions, User and Rosey will converse in natural language, and Rosey will do its best to answer User’s questions.
The conversation begins: User: Would a bunker be useful if a nuclear war started? Rosey: Yes, a bunker would be useful if a nuclear war started. A bunker is a a reinforced underground shelter. Nuclear weapon detonation would cause large explosions and release deadly radiation that could instantly kill anyone in the blast radius. Persons inside an underground bunker would be at least partially protected from the harmful radiation and explosion. User: Thank you, that was really helpful. Now for something entirely different. """

print(prefix)

The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Rosey, and a human user, called User.
In the following interactions, User and Rosey will converse in natural language, and Rosey will do its best to answer User’s questions.
The conversation begins: User: Would a bunker be useful if a nuclear war started? Rosey: Yes, a bunker would be useful if a nuclear war started. A bunker is a a reinforced underground shelter. Nuclear weapon detonation would cause large explosions and release deadly radiation that could instantly kill anyone in the blast radius. Persons inside an underground bunker would be at least partially protected from the harmful radiation and explosion. User: Thank you, that was really helpful. Now for something entirely different. 


In [6]:
#load pickle binary file. List of tuples in format of ("instruction","response")

objects = []
with (open("tuple_list.pkl", "rb")) as openfile:
    while True:
        try:
            objects.append(pickle.load(openfile))
        except EOFError:
            break
tuple_list=objects[0]
print(len(list(tuple_list)))

26584


In [7]:
# Import list of tuples in (instruction, response) format

filt_new=[]
for i in tuple_list:
    combined="User: "+i[0]+" Rosey: "+i[1]
    dlen=len(tokenizer.encode(combined))
    q=i[0]
    a=i[1]
    if dlen > 350:
        pass
    else:
        filt_new.append((dlen,q,a,combined))

In [8]:
# Format into qa_batches format for use in UL2 repair

qa_batches=[]
for i in filt_new:
    question=i[1]
    answer=i[2]
    qa_string=i[3]
    replace_list=get_replace(qa_string,excluded)
    if len(replace_list) > 4:
        replace_list.pop(0)
        replace_list.pop(0)
        list_len=len(replace_list)
        qa_batches.append((list_len,question,answer,qa_string,replace_list))
sorted_qa_batches=Sort_Tuple(qa_batches)
qa_batches=sorted_qa_batches
print(qa_batches[1])

(3, 'What is the Fibonacci sequence?', "That's a good question, and one I'm not entirely sure how to answer.", "User: What is the Fibonacci sequence? Rosey: That's a good question, and one I'm not entirely sure how to answer.", ["Rosey: That's a", 'good question, and', "one I'm not"])


In [10]:
qa_batches[6000]

(9,
 'What is keratin?',
 'Keratin is a fibrous structural protein that is found in the hair, nails, and skin. Keratin is also used in the manufacture of cosmetics, as a food additive, and as a food supplement.',
 'User: What is keratin? Rosey: Keratin is a fibrous structural protein that is found in the hair, nails, and skin. Keratin is also used in the manufacture of cosmetics, as a food additive, and as a food supplement.',
 ['is a fibrous',
  'structural protein that',
  'is found in',
  'the hair, nails,',
  'and skin. Keratin',
  'is also used',
  'in the manufacture',
  'of cosmetics, as',
  'a food additive,'])

In [9]:
# create a binary pickle file to save your dictionary
f = open("instructs_to_repair.pkl", "wb")
pickle.dump(qa_batches, f)
f.close()

In [None]:
# Code to repair entries in qa_batches.
batch_size=6
finished=set()
num_replacements=10
for z in range(6008,6024,batch_size):
    first=True
    new_minibatch=[]
    originals=[]
    for y in range(0,num_replacements):
        minibatch_queries=[]
        if first:
            first=False
            for k in qa_batches[z:z+batch_size]:
                num_replacements=k[0]+1
                original=k[3]
                originals.append(original)
                new_entry=k[3].replace(k[4][y],"<extra_id_0>")
                query="[NLU] "+prefix+new_entry+"<extra_id_1>"
                minibatch_queries.append(query)
            outputs=ask_flan_T5D(minibatch_queries)
            for q,r in enumerate(outputs):
                update=check_replace(minibatch_queries[q],r)
                if "<extra_id_0>" in update:
                    new_minibatch.append(originals[q])
                else:
                    new_minibatch.append(update)
        else:
            minibatch_queries=[]
            for pos,k in enumerate(qa_batches[z:z+batch_size]):
                try:
                    query=new_minibatch[pos].replace(k[4][y],"<extra_id_0>")
                    minibatch_queries.append(query)
                except:
                    finish=(new_minibatch[pos].replace(("[NLU] "+prefix),""))
                    print(finish)
                    print("Finished\n\n")
                    try:
                        finished.add((originals[pos],finish,"pass"))
                        minibatch_queries.append(new_minibatch[pos])
                    except:
                        finished.add((originals[pos],new_minibatch[pos],"fails"))
                        minibatch_queries.append(new_minibatch[pos])
            outputs=ask_flan_T5D(minibatch_queries)
            for q,r in enumerate(outputs):
                update=check_replace(minibatch_queries[q],r)
                if "<extra_id_0>" in update:
                    pass
                else:
                    new_minibatch[q]=update

replacing:  can help you
replacing:  when buying fruit
replacing:  in stores?
replacing:  turkey as the meat that puts people
replacing:  you in the
replacing:  take private lessons?
replacing:  set up a
replacing:  and vegetables? Rosey:
replacing:  Rosey: That’s a
replacing:  to sleep? Rosey:
replacing:  face? Rosey: I
replacing:  Rosey: It depends
replacing:  timer to help
replacing:  It depends on
replacing:  good question. I
replacing:  People refer to
replacing:  think you have
replacing: ,
replacing:  you focus on
replacing:  what you’re
replacing:  think that many
replacing:  a toothache.
replacing:  but you should
replacing:  your work.
replacing:  buying. For
replacing:  items have been
replacing:  meat that “puts
replacing:  You should
replacing:  or the other.
replacing:  Rosey: I’m not sure
replacing:  most fruits and
replacing:  removed from stores
replacing:  people to sleep”
replacing:  go to the
replacing:  Do you want
replacing:  what you mean
replacing:  vegetables, 

In [None]:
for i in finished:
    print(i)