In [1]:
from argparse import Namespace
import pandas as pd
from datasets import load_dataset, Dataset
from pipeline_v2_utils import load_model, add_idx, parse_tokenize_prompts, parse_input_check, parse_output_check, parse_gen_completions, parse_result_ppl_eval, perform_detection
import yaml
import torch
from statistics import mean

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'transformers'

In [2]:
# Experiment configuration, written as a plain dict and then wrapped in an
# argparse.Namespace so the pipeline helpers can read it like parsed CLI args.
arg_dict = {
    'result_dir': '',

    # --- watermark / generation model ---
    'wm_model_name': 'facebook/opt-1.3B',                           # watermark model name
    'seeding_scheme': 'simple_1',
    'gamma': 0.5,
    'delta': 2.0,
    'generation_seed': 123,
    'use_sampling': True,
    'sampling_temp': 0.7,                                           # only used when sampling
    'n_beams': 1,                                                   # only used when not sampling
    'no_repeat_ngram_size': 0,                                      # only used when n_beams > 1
    'early_stopping': False,                                        # only used for beam search
    'normalizers': '',                                              # comma-separated; '' means none
    'z_threshold': 4.0,
    'ignore_repeated_bigrams': False,                               # not used yet

    # --- oracle (perplexity) model ---
    'oracle_model_name': 'facebook/opt-2.7b',                       # oracle model name

    # --- dataset and prompt/completion filtering ---
    'dataset_name': 'c4',
    'dataset_config_name': 'realnewslike',

    'limit_indices': 20,
    'input_truncation_strategy': 'completion_length',
    'input_filtering_strategy': 'prompt_and_completion_length',
    'min_input_encoded_length': 50,
    'min_untruncated_input_encoded_length': 0,
    'output_filtering_strategy': 'max_new_tokens',
    'prompt_max_length': None,
    'max_new_tokens': 200,
}
args = Namespace(**arg_dict)
# Turn the comma-separated normalizer string into a list (empty string -> []).
args.normalizers = [] if not args.normalizers else args.normalizers.split(",")
print(args)

Namespace(result_dir='', wm_model_name='facebook/opt-1.3B', seeding_scheme='simple_1', gamma=0.5, delta=2.0, generation_seed=123, use_sampling=True, sampling_temp=0.7, n_beams=1, no_repeat_ngram_size=0, early_stopping=False, normalizers=[], z_threshold=4.0, ignore_repeated_bigrams=False, oracle_model_name='facebook/opt-2.7b', dataset_name='c4', dataset_config_name='realnewslike', limit_indices=20, input_truncation_strategy='completion_length', input_filtering_strategy='prompt_and_completion_length', min_input_encoded_length=50, min_untruncated_input_encoded_length=0, output_filtering_strategy='max_new_tokens', prompt_max_length=None, max_new_tokens=200)


In [3]:
# Load the watermark (generation) model. `tokenizer` and `device` are reused by
# the dataset pipeline and by the detection step later in the notebook.
model, tokenizer, device = load_model(args.wm_model_name)
print('wm_model {} loaded on {}'.format(args.wm_model_name, device))

wm_model facebook/opt-1.3B loaded on mps


In [4]:
# Build the per-example pipeline over the streaming train split. With
# streaming=True the map/filter steps are registered on an iterable dataset and
# run as examples are pulled in the next cell (including generation).
per_example = dict(batched=False, with_indices=True)

dataset = load_dataset(args.dataset_name, args.dataset_config_name, split="train", streaming=True)
dataset = dataset.map(add_idx, **per_example)

tokenize_prompts = parse_tokenize_prompts(args, model, tokenizer)
dataset = dataset.map(tokenize_prompts, **per_example)

input_check = parse_input_check(args)
dataset = dataset.filter(input_check, **per_example)

gen_completions = parse_gen_completions(args, model, tokenizer)
# After generating completions, drop the encoded prompt columns and, for C4,
# the raw source fields.
columns_to_remove = ["input_encoded", "untruncated_input_encoded"] + (
    ["text", "timestamp", "url"] if "c4" in args.dataset_name else []
)
dataset = dataset.map(gen_completions, remove_columns=columns_to_remove, **per_example)



In [5]:
# Pull examples from the lazy pipeline until `args.limit_indices` of them pass
# the output-length check. Every pulled example is kept in `result_list`,
# including ones that fail the check, so later cells score those too.
output_check = parse_output_check(args)

result_list = []
ds_iterator = iter(dataset)
i = 0  # number of rows that passed `output_check` so far
while i < args.limit_indices:
    # A sentinel instead of a bare `next` lets an exhausted stream end the loop
    # with a message instead of an uncaught StopIteration.
    ex = next(ds_iterator, None)
    if ex is None:
        print(
            f"\nDataset exhausted after {len(result_list)} rows:",
            f"only {i} of {args.limit_indices} requested rows were satisfactory."
        )
        break
    yaml_output = yaml.dump(ex, sort_keys=False, default_flow_style=False)
    print(yaml_output)
    result_list.append(ex)
    if output_check(ex):
        i += 1
    else:
        print(
            "\nGeneration too short, saving outputs, but not incrementing counter...\n",
            f"{i} of {len(result_list)} rows were satisfactory so far",
            # `i+1` avoids a division by zero before the first satisfactory row.
            f"current generation overhead ratio: {round(len(result_list)/(i+1), 3)}",
            f"completed {round(i/args.limit_indices, 2)} of total"
        )

# Report against the number of rows that actually passed. This equals
# args.limit_indices unless the stream ran out; max(..., 1) guards against zero.
print(
    "#"*80,
    f"\nGeneration output length check overhead was num rows processed={len(result_list)}",
    f"for {i} samples. Ratio: {round(len(result_list)/max(i, 1), 3)}"
)

# Move the watermark model off the accelerator before the oracle model loads.
# NOTE(review): `gen_completions` presumably still references this model, so
# `del` alone may not release it; the `.to(cpu)` moves the module in place.
# `tokenizer` and `device` are intentionally kept for the detection cell.
model = model.to(torch.device("cpu"))
del model

  test_elements = torch.tensor(test_elements)


idx: 1
untruncated_input_encoded_length: 509
input_encoded_length: 309
input_decoded: "\"Whoever gets him, they'll be getting a good one,\" David Montgomery\
  \ said.\nINDIANAPOLIS \u2014 Hakeem Butler has been surrounded by some of the best\
  \ wide receivers on the planet this week at the NFL Scouting Combine.\nIt\u2019\
  s an experience that might humble some. But for Butler, it has only enhanced his\
  \ confidence.\nAs it stands, 22-year-old Butler is not regarded as the best wide\
  \ receiver in this year\u2019s NFL Draft. He\u2019s projected by some experts to\
  \ go as late as the third round. But when wide receivers were measured Thursday,\
  \ Butler gained some attention: He led all receivers in height (6-foot-5 3/8), arm\
  \ length (35 1/4 inches) and wingspan (83 7/8 inches).\nOn Thursday, running back\
  \ David Montgomery, who played with Butler at Iowa State, captured the general vibe\
  \ surrounding Butler here.\nButler says he\u2019s met with every NFL team on 

In [6]:
# Score every collected row with the oracle LM and report average perplexity and
# loss for each completion type (real, without watermark, with watermark).
result_dataset = Dataset.from_list(result_list)

# Read the oracle from the config cell instead of re-hardcoding it, so editing
# 'oracle_model_name' in arg_dict actually changes the model used here.
oracle_model_name = args.oracle_model_name
oracle_model, oracle_tokenizer, oracle_device = load_model(oracle_model_name)
print(f'oracle_model {oracle_model_name} loaded on {oracle_device}')

result_ppl_eval = parse_result_ppl_eval(oracle_model, oracle_tokenizer)
result_dataset = result_dataset.map(result_ppl_eval, batched=False, with_indices=True)

# Move the oracle off the accelerator.
# NOTE(review): `result_ppl_eval` presumably still references the model, so
# `del` alone may not release it; the `.to(cpu)` moves the module in place.
oracle_model = oracle_model.to(torch.device("cpu"))
del oracle_model

print("#"*80)
# Same six lines, in the same order, as printing each metric by hand.
for completion_type in ("real", "wo_wm", "w_wm"):
    print(f"{completion_type} avg PPL: {mean(result_dataset[f'{completion_type}_ppl'])}")
    print(f"{completion_type} avg loss: {mean(result_dataset[f'{completion_type}_loss'])}")

oracle_model facebook/opt-2.7b loaded on mps


Map: 100%|██████████| 32/32 [01:52<00:00,  3.51s/ examples]


################################################################################
real avg PPL: 10.649542339146137
real avg loss: 2.257085919380188
wo_wm avg PPL: 4.578849408775568
wo_wm avg loss: 1.44712308421731
w_wm avg PPL: 5.725208610296249
w_wm avg loss: 1.713696587830782


In [7]:
# Reshape the wide results into a long frame and run watermark detection.
# In the wide frame, one row holds the real, wo_wm and w_wm completions side by
# side. In the long frame, each completion gets its own row plus a `class`
# column. Detection then runs on every completion and the result goes to CSV.
result_wide_df = result_dataset.to_pandas()

# class label -> (encoded-length column, decoded-text column) in the wide frame.
# Loss/PPL columns follow the uniform '<class>_loss' / '<class>_ppl' pattern.
COMPLETION_COLUMNS = {
    'real': ('real_completion_encoded_length', 'real_completion_decoded'),
    'wo_wm': ('completion_wo_wm_encoded_length', 'completion_wo_wm_decoded'),
    'w_wm': ('completion_w_wm_encoded_length', 'completion_w_wm_decoded'),
}


def completion_view(wide_df: pd.DataFrame, cls: str) -> pd.DataFrame:
    """Select one completion type from the wide frame under shared column names.

    Args:
        wide_df: frame produced by `result_dataset.to_pandas()`.
        cls: a key of COMPLETION_COLUMNS ('real', 'wo_wm' or 'w_wm').

    Returns:
        A new frame with columns idx, input_encoded_length, input_decoded,
        completion_encoded_length, completion_decoded, loss, ppl, class, where
        `class` is the constant `cls`.
    """
    length_col, text_col = COMPLETION_COLUMNS[cls]
    renames = {
        length_col: 'completion_encoded_length',
        text_col: 'completion_decoded',
        f'{cls}_loss': 'loss',
        f'{cls}_ppl': 'ppl',
    }
    return (
        wide_df[['idx', 'input_encoded_length', 'input_decoded', *renames]]
        .rename(columns=renames)
        .assign(**{'class': cls})  # 'class' is a keyword, hence the dict unpack
    )


result_df = pd.concat(
    [completion_view(result_wide_df, cls) for cls in COMPLETION_COLUMNS],
    axis=0,
    ignore_index=True,
)
result_df = result_df.sort_values(by=['idx', 'class'])
# Empty placeholder columns for the detection outputs.
result_df = result_df.assign(
    prediction=None,
    confidence=None,
    p_value=None,
    z_score=None,
    green_fraction=None,
    num_green_tokens=None,
    num_tokens_scored=None
)
result_df = perform_detection(args, tokenizer, device, result_df)
result_df.to_csv("result_dataset.csv", index=False)