In [10]:
import os
import pickle

import pandas as pd
import torch
from datasets import load_dataset, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
# Open the English Wikipedia dump (20220301 snapshot) as a streaming dataset.
dataset = load_dataset(
    "wikipedia",
    "20220301.en",
    split="train",
    streaming=True,
)
# NOTE(review): `shuffled_dataset` is not read by any other visible cell; the
# sampling cell takes from `dataset` directly. Confirm which one is intended.
shuffled_dataset = dataset.shuffle(buffer_size=10_000, seed=42)

In [3]:
# Materialize the first 30,000 streamed articles and wrap them in an in-memory Dataset.
# NOTE(review): this takes from `dataset`, not `shuffled_dataset` — confirm intended.
nd = [article for article in dataset.take(30000)]
ndata = Dataset.from_list(nd)

In [None]:
# Print the raw text of one sampled article.
# `dset` is not defined anywhere in this notebook; `ndata` is the materialized
# Wikipedia sample that carries the 'text' column.
print(ndata[6]['text'])

In [181]:
# Load the jstet/quotes-500k dataset; no split is requested here, the 'train'
# split is picked out by name in the subsampling cell.
quotesdata = load_dataset("jstet/quotes-500k")

Found cached dataset csv (/u/prasanns/.cache/huggingface/datasets/jstet___csv/jstet--quotes-500k-ede96e03d28fbb72/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 220.75it/s]


In [184]:
# Keep a fixed-seed random subset of TOTDATA quotes from the train split.
# NOTE(review): this rebinds `quotesdata` from the loaded object to the
# subsampled train split, so re-running this cell without re-running the load
# cell will not behave the same way.
TOTDATA = 60000
quotesdata = quotesdata['train'].shuffle(seed=0).select(range(TOTDATA))

Loading cached shuffled indices for dataset at /u/prasanns/.cache/huggingface/datasets/jstet___csv/jstet--quotes-500k-ede96e03d28fbb72/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-602235f45088ee8c.arrow


In [11]:
# Show one record to inspect its fields.
quotesdata[1]

{'quote': "The rose, however, made us girls somewhat fainthearted, because it really was something we felt mattered, the white bridal dream with the wedding bouquet and the kiss from the man who was to be ours forever. But then Laura said that the lady who had given it to us had gotten divorced only five years later. And since many of our parents were also divorce, if indeed they had ever been married at all, that dream clearly wasn't worth our time.",
 'author': 'Janne Teller',
 'category': 'marriage, nothing'}

In [4]:
# Hub identifier used to load both the model and its tokenizer.
modelname = "facebook/opt-350m"

In [5]:
# Load the causal LM in eval mode together with a left-padding tokenizer.
# NOTE(review): device_map=2 is kept as written; presumably it places the model
# on GPU index 2 — confirm the target machine has that device.
model = AutoModelForCausalLM.from_pretrained(modelname, device_map=2)
model.eval()
toker = AutoTokenizer.from_pretrained(modelname, padding_side='left')

# `model_max_length` is the tokenizer attribute consulted when truncation is
# requested without an explicit max_length; a plain `max_length` attribute is
# not read by the tokenizer.
toker.model_max_length = 512
# The attribute is `padding_side` (the previous `padding_size` was a typo that
# only created an unused attribute).
toker.padding_side = 'left'
# Reuse EOS as the padding token so batched prompts can be padded.
toker.pad_token = toker.eos_token

In [6]:
def proc_quotes(exs):
    """Format quote records as "author: quote" strings, keeping only short ones.

    Args:
        exs: iterable of mappings with string values under 'author' and 'quote'.

    Returns:
        List of "<author>: <quote>" strings whose total length is below 200
        characters, in input order.
    """
    formatted = (rec['author'] + ": " + rec['quote'] for rec in exs)
    return [line for line in formatted if len(line) < 200]

def proc_wiki(exs):
    """Collect up to three medium-length sentences from each article.

    Each record's 'text' is split on "."; pieces whose raw (unstripped) length
    is strictly between 20 and 200 characters are kept, stripped of surrounding
    whitespace, and the first three per article are collected.

    Args:
        exs: iterable of mappings with a string under 'text'.

    Returns:
        Flat list of the selected sentences, in article order. The total count
        is also printed as a side effect.
    """
    inps = []
    for article in exs:
        candidates = [
            piece.strip()
            for piece in article['text'].split(".")
            if 20 < len(piece) < 200
        ]
        inps.extend(candidates[:3])
    print(len(inps))
    return inps
        
def generate_trunc(inputs, trunc, model, mbatch_size=4, top_p=0.9, temp=0.4):
    """Cut the last ``trunc + 1`` tokens off each input and resample them.

    Each input string is tokenized with the module-level ``toker``. The final
    ``trunc + 1`` token ids are decoded as the reference ending; the remaining
    prefix is decoded back to text and used as the prompt. Prompts are then
    sampled from in mini-batches with nucleus sampling.

    Args:
        inputs: list of input strings.
        trunc: ``trunc + 1`` trailing tokens are removed from every input and
            the same number of new tokens is requested from the model.
        model: object exposing ``.device`` and ``.generate(...)``.
        mbatch_size: number of prompts per ``generate`` call.
        top_p: forwarded to ``generate`` as ``top_p``.
        temp: forwarded to ``generate`` as ``temperature``.

    Returns:
        A tuple ``(generations, corrgens)``: the batch-decoded ``generate``
        outputs (one per input; in the recorded outputs of this notebook these
        contain the prompt followed by the sampled continuation) and the
        decoded reference endings.

    Note:
        An input with at most ``trunc + 1`` tokens yields an empty prompt,
        because the prefix slice is then empty.
    """
    cut = trunc + 1
    newinps, corrgens = [], []
    for inp in inputs:
        # Tokenize once and reuse the ids for both the prompt and the reference.
        ids = toker(inp).input_ids
        newinps.append(toker.decode(ids[:-cut], skip_special_tokens=True))
        corrgens.append(toker.decode(ids[-cut:], skip_special_tokens=True))

    newgens = []
    for start in tqdm(range(0, len(newinps), mbatch_size)):
        batch = toker(
            newinps[start:start + mbatch_size],
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(model.device)
        # extend() appends one generated sequence per prompt in the batch.
        newgens.extend(
            model.generate(
                **batch,
                max_new_tokens=cut,
                do_sample=True,
                top_p=top_p,
                temperature=temp,
            )
        )
    return toker.batch_decode(newgens, skip_special_tokens=True), corrgens

In [7]:
# Extract candidate sentences from every sampled article. The dataset is passed
# directly: selecting range(len(ndata)) is an identity selection and only adds
# an index mapping on top of the same rows.
procd = proc_wiki(ndata)
# procd = proc_quotes(quotesdata.select(range(100)))

84757


In [26]:
# Reload the per-chunk generation results from distilgen_tmp/.
# `interv` is declared here because the only other definition lives in the
# generation cell further down; it must equal the chunk size used when the
# files were written, since each file is named after its chunk start offset.
interv = 1000
gts, gols = [], []
for i in range(0, len(procd), interv):
    # NOTE: pickle.load can execute arbitrary code; only read files that this
    # notebook itself produced.
    with open("distilgen_tmp/"+str(i), "rb") as f:
        a, b = pickle.load(f)
    gts.extend(a)
    gols.extend(b)


In [29]:
# Build the pairwise dataset: an empty question, the generated text as
# 'response_j' and the original sentence as 'response_k'.
pair_columns = {
    'question': [""] * len(gts),
    'response_j': gts,
    'response_k': procd,
}
newdset = Dataset.from_pandas(pd.DataFrame(pair_columns))

In [31]:
# Spot-check the first generated/original pair.
newdset[0]

{'question': '',
 'response_j': 'Anarchism is a political philosophy and movement that is sceptical of authority and is concerned with the preservation of individual rights',
 'response_k': 'Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy'}

In [11]:
# Generate truncated continuations chunk by chunk and checkpoint every chunk to
# disk under its start offset, so results can be reloaded without regenerating.
interv =1000
# The output directory is not created anywhere else; open(..., "wb") would fail
# on a fresh checkout without it.
os.makedirs("distilgen_tmp", exist_ok=True)
for i in range(0, len(procd), interv):
    # Inference only: no gradients are needed for sampling.
    with torch.no_grad():
        gtrunc, golds = generate_trunc(procd[i:i+interv], 8, model, 64, 0.9, 0.4)
    with open("distilgen_tmp/"+str(i), "wb") as f:
        pickle.dump((gtrunc, golds), f)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:04<00:00,  3.85it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:04<00:00,  3.49it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:04<00:00,  3.51it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:04<00:00,  3.50it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:04<00:00,  3.75it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:0

In [22]:
# Show only the first few sentences; displaying the whole list floods the output.
procd[:8]

['Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy',
 'Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful',
 'Humans lived in societies without formal hierarchies long before the establishment of formal states, realms, or empires',
 'Autism is a neurodevelopmental disorder characterized by difficulties with social interaction and communication, and by restricted and repetitive behavior',
 "Parents often notice signs during the first three years of their child's life",
 'These signs often develop gradually, though some autistic children experience regression in their communication and social skills after reaching developmental milestones at a normal pace',
 'Surface albedo is defined as the ratio of radiosity Je to the irradiance Ee (flux per unit area) received by a surface',
 "The proportion reflected is not only determined by properties of th

In [8]:
# Put the extracted sentences into a single-column frame named 'outputs'.
datadf = pd.DataFrame({'outputs':procd})

In [9]:
# Convert the frame into a Dataset so it can be saved to disk.
wdata = Dataset.from_pandas(datadf)

In [13]:
# Persist the dataset; the path is relative to the notebook's working directory.
wdata.save_to_disk("../../data/wikidatasft")

                                                                                                                                                                         

In [187]:
# Display the dataset summary (features and row count).
quotesdata

Dataset({
    features: ['quote', 'author', 'category'],
    num_rows: 60000
})

In [198]:
# Compare one original sentence with its reference ending and the generation.
# After the generation loop, `golds` and `gtrunc` hold only the LAST chunk, so
# the matching original sits at that chunk's start offset plus `ind`. With a
# single chunk the offset is 0 and this is plain procd[ind].
ind = 9
chunk_start = (len(procd) - 1) // interv * interv
print("ORIGINAL: "+procd[chunk_start + ind])
print('GOLD: '+golds[ind])
print("GEN: "+gtrunc[ind])

ORIGINAL:  a poet arranges meaning in the sounds.: A versifier arranges sounds
GOLD: ifier arranges sounds
GEN:  a poet arranges meaning in the sounds.: A versification of the poetry


In [17]:
# Force a fresh download of the imdb dataset.
# NOTE(review): the recorded run of this cell ended with
# ExpectedMoreSplits: {'unsupervised'}, and `ds` is not used by any other
# visible cell — confirm whether this cell is still needed.
ds = load_dataset("imdb",download_mode="force_redownload")

Downloading readme: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 7.81k/7.81k [00:00<00:00, 13.8MB/s]


Downloading and preparing dataset None/plain_text to /u/prasanns/.cache/huggingface/datasets/parquet/plain_text-730f38e7e31e7fd3/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7...


Downloading data files:   0%|                                                                                                                      | 0/2 [00:00<?, ?it/s]
Downloading data:   0%|                                                                                                                      | 0.00/21.0M [00:00<?, ?B/s][A
Downloading data:  31%|█████████████████████████████████▉                                                                           | 6.53M/21.0M [00:00<00:00, 65.2MB/s][A
Downloading data: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21.0M/21.0M [00:00<00:00, 96.2MB/s][A
Downloading data files:  50%|███████████████████████████████████████████████████████                                                       | 1/2 [00:00<00:00,  1.36it/s]
Downloading data:   0%|                                                                                                                      

ExpectedMoreSplits: {'unsupervised'}