In [3]:
from transformers import pipeline, set_seed, AutoConfig
from sklearn.feature_extraction.text import CountVectorizer
from scipy.stats import binned_statistic
from scipy.optimize import curve_fit
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
config = AutoConfig.from_pretrained('gpt2')
generator = pipeline('text-generation', model='gpt2', config=config)
tokenizer = generator.tokenizer
model = generator.model
set_seed(42)
T=1.0

In [5]:
def get_avg(M, h, thr=50, bins = 35):
    bin_means, _, _ = binned_statistic(M, h, statistic='mean', bins=bins)
    bin_std, _, _ = binned_statistic(M, h, statistic='std', bins=bins)
    bin_counts, _, _ = binned_statistic(M, h, statistic='count', bins=bins)
    return bin_means[bin_counts>thr], (bin_std**2)[bin_counts>thr]

def data_to_obs(data, doc_param = {"min":10, "step":5}):
    vec = CountVectorizer()
    vec.fit(data)
    X = pd.DataFrame(data=vec.fit_transform(data).toarray().T, index=vec.get_feature_names_out())
    f = X.divide(X.sum(0),1).mean(1).sort_values(ascending=False)
    M = X.sum(0)
    h = X.apply(lambda x: (x!=0).sum(), axis=0)
    N = len(f)
    doc_M = list(M.values)
    doc_h = list(h.values)
    for doc in data:
        _doc = doc.split(" ")
        for _M in range(doc_param["min"],len(_doc), doc_param["step"]):
            doc_M.append(_M)
            _h = (vec.transform([" ".join(_doc[:_M])])>0).sum()
            doc_h.append(_h)
    return N, f, M, h, doc_M, doc_h

def get_input_from_f(f):
    tokens = np.array([x[0] for x in tokenizer([str(w) for w in f.index.values])["input_ids"]])
    return ' '.join([tokenizer.decode(token) for token in np.random.choice(tokens, size=50, p=f.values/f.values.sum())])

In [6]:
from nltk.corpus import gutenberg
import string
gutenberg.raw('austen-emma.txt')

def get_prompt(N):
    text = gutenberg.raw('austen-emma.txt')
    text = ''.join(ch for ch in text if ch not in string.punctuation)
    text = text.lower().split()
    M = N
    i = np.random.randint(0, len(text)-M)
    return ' '.join(text[i:i+M])

get_prompt(5)

'i must i willi will'

In [8]:
data = [text['generated_text'] for text in generator(get_prompt(5), max_length=250, truncation=False, num_return_sequences=25, temperature = T)]
N, f, M, h, doc_M, doc_h = data_to_obs(data)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [7]:
Ns = np.logspace(np.log10(100), np.log10(2000), 25, dtype=int)
dN = []
Np = []
io=open("prompts.txt", "w")
for N in Ns:
    print(N)
    prompt = get_input_from_f(f.sort_values(ascending=False)[:N])
    # prompt = get_prompt(N)
    io.write(prompt+"\n")
io.close()

100
113
128
145
164
186
211
239
271
307
348
394
447
506
574
650
736
834
945
1071
1213
1375
1558
1765
2000


In [10]:
dN = []
Np = []
M = 20
for _ in range(25):
    prompt = get_input_from_f(f.sort_values(ascending=False)[:N])
    #prompt = get_prompt(M)
    text = generator(prompt, max_new_tokens=250, num_return_sequences=15, temperature = float(T))
    dN.append([len(np.unique(t['generated_text'].split(" "))) for t in text])
    Np.append(len(np.unique(prompt.split(" "))))
    break

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [None]:
Np = []
dN = []
with open("prompts.txt", "r") as file:
    for prompt in file.readlines():
        Np.append(len(np.unique(prompt.split(" "))))

with open("results.txt", "r") as file:
    for result in file.readline().split("=========\n"):
        dN.append([len(np.unique(result.split(" ")))])

In [20]:
# with open("NDN.pkl", "wb") as file:
#     pickle.dump([Ns, dN, Np], file)

with open("NDN.pkl", "rb") as file:
    Ns, dN, Np = pickle.load(file)

In [42]:
fig=go.Figure()


for M in [20, 100]:
    with open(f"NDN_M{M}.pkl", "rb") as file:
        Ns, dN, Np = pickle.load(file)
        
    # with open(f"NDN.pkl", "rb") as file:
    #     Ns, dN, Np = pickle.load(file)
        
    data = []
    for _x, _y in zip(Np, dN):
        for i in _y:
            # data.append((_x, (i-_x)/(1e-20+250.))) #(N_text-N_prompt)/DM
            data.append((_x, i))
            
    x = np.array([_x for _x,_y in data])
    y = np.array([_y for _x,_y in data])
    fit_func =lambda x, A, B: A+B*x
    popt, pcov = curve_fit(fit_func, x, y)
    print(popt[1]*250)
    bin_means, bin_edges, binnumber = binned_statistic(x, y, statistic="median")
    df = pd.DataFrame(data=data, columns=["Np", "dN"])
    x = np.array([_x for _x,_y in data])
    y = np.array([_y for _x,_y in data])

    fig.add_scatter(
        x=(bin_edges[1:] + bin_edges[:-1])/2,
        y=bin_means,
        mode="lines",
        name="medians",
        line_width=8,
        line_color="red"
    )

    fig.add_scatter(
        x=x,
        y=fit_func(x, *popt),
        mode="lines",
        name="{:.2f}+{:.2f}x".format(*popt),
        line_width=8,
        line_color="orange"
    )

    for x in df["Np"].unique():
        y = df[df["Np"]==x]["dN"].values
        fig.add_trace(go.Scatter(x=np.repeat(x, len(y)), 
                                y=y, 
                                name=str(x), 
                                mode="markers", 
                                marker={"color":"gray", "opacity": 0.2, "size": 10})
                    )
        # fig.add_trace(go.Scatter(x=[x], 
        #                         y=[np.median(y)], 
        #                         mode="markers", 
        #                         marker_color="red",
        #                         error_y={
        #                             "type": "data",
        #                             "array": [np.std(y)],
        #                             "thickness": 5,
        #                             "width": 5
        #                         },
        #                         marker_size=10,
        #                         line_width=3)
        #             )
    


fig.update_layout(
    {
        "showlegend": False,
        "legend": {"font":{"size":20}},
        "xaxis":{
            "title": "N_prompt",
            "type": "linear",
            "tickfont": {"size": 20},
            "titlefont": {"size": 25},
            "linewidth": 1,
            "linecolor": "gray"
        },
        "yaxis":{
            # "title": "P_new [(N_gen-N_prompt)/M]",
            "title": "N_gen",
            "type": "linear",
            "tickfont": {"size": 20},
            "titlefont": {"size": 25},
            "linewidth": 1,
            "linecolor": "gray"
        },
        "plot_bgcolor": "rgba(0,0,0,0)",
        "width":800,
        "height":700
    }
)


fig.show()
fig.write_image("NDN_austin.pdf")
fig.write_image("NDN_austin.png")
fig.write_image("NDN_austin.svg")

277.7903469080571
488.03349530082505
