In [None]:
import requests, re
import torch
import pathlib
from transformers import GPTNeoXTokenizerFast
import os

In [None]:
# Turn on IPython's autoreload extension so that edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from refuge._vendor.mkultra.tuning import GPTNeoXPromptTuningLM
from refuge._vendor.dolly.instruct_pipeline import InstructionTextGenerationPipeline
from refuge._vendor.mkultra.soft_prompt import SoftPrompt

In [None]:
# Identifier handed to `from_pretrained` for both the tokenizer and the model;
# its last path component is also reused when building checkpoint file names.
model_name = "databricks/dolly-v2-3b"

In [None]:
# Tokenizer for `model_name`, configured to pad on the left.
tokenizer = GPTNeoXTokenizerFast.from_pretrained(
    model_name,
    padding_side="left",
)

# Prompt-tuning model wrapper for the same checkpoint; device placement is
# delegated to `device_map="auto"`.
model = GPTNeoXPromptTuningLM.from_pretrained(
    model_name,
    device_map="auto",
)

In [None]:
# Text prompt that the generation cells tokenize and pass to `model.generate`.
prompt = "Alice sipped her tea as the white rabbit gloated about his vast collection of pocket watches"

In [None]:
# Tokenize the prompt into PyTorch tensors and pull out the two pieces that
# `generate` needs.
model_inputs = tokenizer(prompt, return_tensors="pt")
input_ids, attention_mask = (
    model_inputs["input_ids"],
    model_inputs["attention_mask"],
)

# Sampling configuration: nucleus sampling, up to 256 new tokens.
sampling_options = dict(
    max_new_tokens=256,
    top_p=0.92,
    do_sample=True,
)

# Generate a continuation with both tensors moved to the model's device.
generated_sequence_tensor = model.generate(
    input_ids=input_ids.to(model.device),
    attention_mask=attention_mask.to(model.device),
    pad_token_id=tokenizer.pad_token_id,
    **sampling_options,
)

# Convert to nested Python lists and print the decoded first sequence.
generated_sequence = generated_sequence_tensor.cpu().numpy().tolist()
first_sequence = generated_sequence[0]
print(tokenizer.decode(first_sequence))

In [None]:
sp_name = 'alice-cyclic-dropout-2'
model_base_name = model_name.split("/")[-1]

# Root folder that holds one sub-directory per soft-prompt project. The
# default is the original hard-coded location; set MKULTRA_SOFT_PROMPT_DIR to
# run the notebook on another machine without editing this cell.
soft_prompt_root = os.environ.get(
    "MKULTRA_SOFT_PROMPT_DIR", "/home/simon/.mkultra/soft_prompts"
)
# The trailing "" keeps the trailing path separator the original string had.
project_dir = os.path.join(soft_prompt_root, f"{sp_name}-{model_base_name}", "")


def filename_for_checkpoint(step):
    """Return the checkpoint file name (not the full path) for `step`.

    The name is built from the current values of `sp_name` and
    `model_base_name`, looked up at call time.
    """
    return f"{sp_name}-{model_base_name}-step-{step}.json"

In [None]:
# Load the step-620 checkpoint file from the project directory and attach it
# to the model as its soft prompt.
sp = SoftPrompt.from_file( os.path.join(project_dir, filename_for_checkpoint(620)) )
model.set_soft_prompt(sp)

# Move the token ids to whatever device the model lives on instead of
# hard-coding CUDA, matching the baseline generation cell and keeping the
# notebook runnable on hosts without a GPU.
call = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

# NOTE(review): no attention mask is passed here, as in the original cell;
# whether the soft-prompt wrapper expects one is not visible from this
# notebook, so confirm before adding it.
basic_output = model.generate(
    input_ids=call,
    pad_token_id=tokenizer.pad_token_id,
    max_new_tokens=256,
    top_p=0.92,
    do_sample=True
)
print(tokenizer.decode(basic_output[0]))

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cbook as cbook
import numpy as np

loss_log_path = os.path.join(project_dir,"loss_log.csv")

# Read the CSV directly. `cbook.get_sample_data` is meant for matplotlib's
# bundled sample files and only worked here because the path is absolute.
# `ndmin=2` keeps the result two-dimensional even when the log has a single
# row, so the column slicing below cannot fail with an IndexError.
array = np.loadtxt(loss_log_path, delimiter=",", ndmin=2)

# Plot the second CSV column against the first.
# NOTE(review): presumably column 0 is the training step and column 1 the
# loss; confirm against the code that writes loss_log.csv before labelling axes.
fig = plt.figure()
plt.plot(array[:, 0], array[:, 1])