In [1]:
%load_ext autoreload
%autoreload 2

import json
import os
import pickle
from datetime import datetime

import evaluate
import torch
from tqdm import tqdm

from eval import *
from superposed.llama.metrics import *
from superposed.llama.generation import Llama
from superposed.llama.superposed_generation import SuperposedLlama
from superposed.llama.tokenizer import Tokenizer
from superposed.ngrams.ngram_models import make_models

  from .autonotebook import tqdm as notebook_tqdm
2024-05-27 20:15:02.245095: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-05-27 20:15:02.294117: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Setup

In [2]:
# Load the WebText test split: one JSON object per line; keep only each record's "text" field.
with open("../gpt-2-output-dataset/data/webtext.test.jsonl", "r") as f:
    dataset = []
    for raw_line in f:
        record = json.loads(raw_line)
        dataset.append(record["text"])

In [3]:
# Default decoding hyperparameters shared by every experiment below, read from a JSON params file.
param_file = "./params/p15_d3_mixed.json"
with open(param_file, "r") as f:
    params = json.load(f)
print(f"Parameters: {params}")

# Unpack the keys the cells below use directly (same lookup order as the keys are listed).
alpha, temp, n_drafts = params["alpha"], params["temp"], params["n_drafts"]
prompt_len, n_token_sample = params["prompt_len"], params["n_token_sample"]
i_weights, i_length = params["i_weights"], params["i_length"]

Parameters: {'alpha': 0.54, 'temp': 0.06, 'n_drafts': 3, 'prompt_len': 15, 'n_token_sample': 9, 'n_token_consider': 32000, 'mixing_method': 'sample_new_weights_with_score', 'smoothing': 'geom', 'sample_tokens': 0, 'sample_beams': 0, 'i_weights': [0.01, 0.04, 0.15, 0.18, 0.12], 'i_length': [1, 2, 3, 4, 5]}


In [4]:
# Build the n-gram models from ./ckpts-200k: orders 2 through 6 enabled, 7-gram disabled.
ngrams = make_models(
    "./ckpts-200k",
    bigram=True,
    trigram=True,
    fourgram=True,
    fivegram=True,
    sixgram=True,
    sevengram=False,
)

Making bigram...
1310800
Making trigram...
671088728
Making fourgram...
2684354648
Making fivegram...
5368709200
Making sixgram...
5368709200


In [5]:
# The superposed model goes on the first GPU and the baseline Llama on the second,
# presumably so both can stay loaded at once.
sup_device, reg_device = torch.device("cuda:0"), torch.device("cuda:1")
tokenizer = Tokenizer('./7B/tokenizer.model')

# Mixed: superposed decoding with n-gram mixing

In [6]:
# Single-process distributed settings; the model builders below report initializing
# model parallel / ddp / pipeline with size 1.
os.environ.update({
    "RANK": "0",
    "WORLD_SIZE": "1",
    "MASTER_ADDR": "127.0.0.1",
    "MASTER_PORT": "10302",
})

In [7]:
# Superposed-decoding model built from the checkpoint in ./7B/, placed on `sup_device`.
weight_path = "./7B/"
model = SuperposedLlama.build(
    ckpt_dir=weight_path,
    tokenizer_path='./7B/tokenizer.model',
    max_seq_len=100,
    max_batch_size=32,
    device=sup_device,
    model_parallel_size=1,
)

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


  _C._set_default_tensor_type(t)


Loaded in 18.50 seconds
cuda:0


In [8]:
# Run superposed decoding over the whole dataset, timed end to end.
# All hyperparameters come from the params file loaded above, including `smoothing`,
# so switching param files cannot leave a stale hard-coded value behind.
start_time = datetime.now()
sup_sequences, sup_ppl = evaluate_mixed_losses(data=dataset,
                                               model=model,
                                               smoothing=params["smoothing"],
                                               tokenizer=tokenizer,
                                               prompt_len=prompt_len,
                                               max_gen_len=10,
                                               alpha=alpha,
                                               temp=temp,
                                               n_drafts=n_drafts,
                                               n_token_sample=n_token_sample,
                                               # matches max_batch_size=32 used when building `model`
                                               bsz=32,
                                               i_weights=i_weights,
                                               i_length=i_length,
                                               ngrams=ngrams,
                                               get_time=False,
                                               # NOTE(review): magic constant forwarded to evaluate_mixed_losses; document its meaning
                                               penalty=200,
                                               marker=True)
finish_time = datetime.now()

100%|████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:35<00:00,  2.20s/it]


In [59]:
# Report wall-clock cost of the superposed run: the total and the average per dataset entry.
# Run this immediately after the superposed cell: the nucleus cell below overwrites
# `start_time` / `finish_time`.
elapsed_start, elapsed_end = start_time, finish_time
duration = elapsed_end - elapsed_start
print(f"Time: {duration}, Average Time: {duration / len(dataset)}")

Time: 0:03:48.817127, Average Time: 0:00:00.045763


In [9]:
# Save results into a file.
# grader.py and diversity_grader.py use this file for perplexity evaluation.
# The name is derived from the params file, so runs with different settings do not
# overwrite each other. Point the graders at the path printed below.
file_name = f"sup_sequences_{os.path.splitext(os.path.basename(param_file))[0]}.pkl"
with open(file_name, "wb") as f:
    pickle.dump(sup_sequences, f)
print(f"Saved superposed sequences to {file_name}")

# Nucleus: baseline Llama generation

In [29]:
# Baseline (non-superposed) Llama built from the same ./7B/ checkpoint, placed on `reg_device`.
reg_model = Llama.build(
    ckpt_dir="./7B/",
    tokenizer_path='./7B/tokenizer.model',
    max_seq_len=100,
    max_batch_size=32,
    device=reg_device,
    model_parallel_size=1,
)

0
Loaded in 7.42 seconds


In [33]:
# Baseline run with evaluate_nucleus_losses, timed end to end (temp=0.6, batch size 32).
start_time = datetime.now()
nucleus_sequences, nucleus_ppl = evaluate_nucleus_losses(
    data=dataset,
    model=reg_model,
    tokenizer=tokenizer,
    prompt_len=prompt_len,
    max_gen_len=10,
    temp=0.6,
    bsz=32,
)
finish_time = datetime.now()
duration = finish_time - start_time

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:43<00:00,  3.63it/s]

Time: 0:00:43.230412, Average Time: 0:00:00.008646





In [36]:
# Add a singleton draft axis (one draft per prompt), presumably mirroring the per-prompt
# draft layout of `sup_sequences`; then report nucleus timing.
n_prompts = len(dataset)
nucleus_sequences = nucleus_sequences.reshape(n_prompts, 1, -1)
print(f"Time: {duration}, Average Time: {duration / len(dataset)}")

Time: 0:00:43.230412, Average Time: 0:00:00.008646


# Evaluation

In [37]:
# Reset the global default dtype to 32-bit float before evaluation. The model build output
# shows a `_set_default_tensor_type` warning, so the builders presumably changed it.
torch.set_default_dtype(torch.float)

In [9]:
def decode(tokenizer, encoding):
    """
    Decode a 1-D tensor of token ids to text, truncating at the first EOS token.

    Args:
        tokenizer (Any): Tokenizer exposing ``eos_id`` and ``decode(list_of_ints) -> str``
        encoding (torch.Tensor): 1-D tensor of token ids (one row of a predictions matrix)
    Returns:
        decoding (str): Text for the tokens before the first EOS, or all tokens if there is no EOS
    """
    # nonzero() returns an (n_matches, ndim) index tensor; for 1-D input that is (n_matches, 1).
    eos_locs = (encoding == tokenizer.eos_id).nonzero()
    if len(eos_locs) > 0:
        # Cut at the first EOS position, taken as a plain int rather than a 1-element tensor.
        encoding = encoding[:eos_locs[0, 0].item()]
    return tokenizer.decode(encoding.to(torch.int32).tolist())
    
def print_results(tokenizer, predictions, n_drafts=n_drafts, n_show=16):
    """
    Print the first few decoded generations, grouped by prompt.

    Rows are assumed to be ordered prompt-major: rows ``k * n_drafts`` through
    ``(k + 1) * n_drafts - 1`` hold the drafts for prompt ``k``. Each prompt is recovered
    from the first ``prompt_len`` tokens (notebook-level global) of its group's first row.

    Args:
        tokenizer (Any): Tokenizer
        predictions (torch.Tensor): Tokens of predicted sequences, flattened to (n_prompts * n_drafts, gen_len)
        n_drafts (int): Drafts per prompt; defaults to the notebook-level ``n_drafts`` when this cell was run
        n_show (int): Maximum number of rows (sequences) to print; defaults to 16
    Returns:
        None. Results are printed only.
    """
    # Only iterate over (and decode) the rows that will actually be shown.
    n_rows = min(n_show, len(predictions))
    for i in tqdm(range(n_rows)):
        draft_idx = i % n_drafts
        if draft_idx == 0:
            # First draft of a new prompt: print a separator and the shared prompt text.
            print("---------------")
            prompt = decode(tokenizer, predictions[i][:prompt_len])
            print(f"prompt: {prompt}")
        print(f"{draft_idx}: {decode(tokenizer, predictions[i])}")

In [10]:
# Show a few superposed generations, one row per (prompt, draft) pair.
sup_rows = sup_sequences.reshape(len(dataset) * n_drafts, -1)
print_results(tokenizer, predictions=sup_rows, n_drafts=n_drafts)

  1%|▉                                                                                           | 16/1500 [00:00<00:00, 4323.47it/s]

---------------
prompt: Is this restaurant family-friendly ? Yes No Unsure

0: Is this restaurant family-friendly ? Yes No Unsure
I'm a big fan of the food and
1: Is this restaurant family-friendly ? Yes No Unsure
I'm a big fan of the food,
2: Is this restaurant family-friendly ? Yes No Unsure
I'm a big fan of the food here
---------------
prompt: Clinton talks about her time of 'reflection' during sick
0: Clinton talks about her time of 'reflection' during sick leave
Clinton talks about her time of
1: Clinton talks about her time of 'reflection' during sickness
Clinton talks about her time of
2: Clinton talks about her time of 'reflection' during sick leave
Clinton talks about her health of
---------------
prompt: House Majority Whip Steve Scalise has been discharged
0: House Majority Whip Steve Scalise has been discharged from the hospital after being shot at a congression
1: House Majority Whip Steve Scalise has been discharged from the hospital after being shot in a congression
2: 




In [43]:
# Show a few nucleus generations; there is exactly one draft per prompt.
nucleus_rows = nucleus_sequences.reshape(len(dataset), -1)
print_results(tokenizer, predictions=nucleus_rows, n_drafts=1)

  0%|▎                                                                                           | 16/5000 [00:00<00:01, 2567.29it/s]

---------------
prompt: Is this restaurant family-friendly ? Yes No Unsure

0: Is this restaurant family-friendly ? Yes No Unsure
10160 Cedar Ave
---------------
prompt: Clinton talks about her time of 'reflection' during sick
0: Clinton talks about her time of 'reflection' during sickness
Clinton talks about her time of
---------------
prompt: House Majority Whip Steve Scalise has been discharged
0: House Majority Whip Steve Scalise has been discharged from the hospital, his office announced Wednesday
---------------
prompt: Insight Course: Lesson 14

Control of
0: Insight Course: Lesson 14

Control of the Body

<p align="right">
---------------
prompt: BY JENNIE MCNULTY

Lesbian.
0: BY JENNIE MCNULTY

Lesbian. Bisexual. Queer. Transgender.
---------------
prompt: The Buddha's Teaching As It Is

In
0: The Buddha's Teaching As It Is

In the year 1962, I had
---------------
prompt: As part of a broad initiative to combat sexual harassment and
0: As part of a broad initiative to combat s


