In [14]:
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_metric

import torch
from transformers import GPT2Tokenizer, GPTNeoForCausalLM, AutoModelForCausalLM
from datasets import load_dataset
# from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser, pipeline

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
from trl.core import LengthSampler

import torch
import wandb
import time
import os
import statistics
from tqdm import tqdm
import numpy as np
import pandas as pd
tqdm.pandas()
from multiprocessing import Pool

from datasets import load_dataset, Dataset

from transformers import AutoTokenizer, pipeline
import argparse
from evaluate import load

from tqdm import tqdm
from tqdm.notebook import tqdm_notebook
tqdm_notebook.pandas()
from torchmetrics.functional import accuracy
from accelerate.tracking import GeneralTracker, on_main_process
from typing import Optional

import wandb

In [15]:
# Custom Accelerate tracker around wandb so that run parameters such as
# group, run name, entity and project can be specified explicitly.
class WandbEd(GeneralTracker):
    """Accelerate tracker that opens a wandb run with explicit run metadata.

    The run is created in ``__init__`` (on the main process only) and kept on
    the instance so that ``tracker``, ``store_init_configuration`` and ``log``
    all operate on that same run.
    """

    name = "wandb"
    requires_logging_directory = False

    @on_main_process
    def __init__(self, project: str, run_name, group: str, entity: str, config: dict):
        """Start the wandb run.

        Args:
            project: wandb project name.
            run_name: display name of the run.
            group: wandb group the run belongs to.
            entity: wandb entity (user or team) owning the project.
            config: dictionary stored as the run configuration.
        """
        self.group = group
        self.run_name = run_name
        self.project = project
        self.config = config
        self.entity = entity
        # Keep the run handle on the instance; it was previously bound to a
        # local variable, which made the `tracker` property raise AttributeError.
        self.run = wandb.init(
            group=self.group,
            name=self.run_name,
            entity=self.entity,
            project=self.project,
            config=config,
        )

    @property
    def tracker(self):
        """Return the wandb run created in ``__init__``."""
        return self.run

    @on_main_process
    def store_init_configuration(self, values: dict):
        """Merge ``values`` into the configuration of the active run."""
        # Update the run's config in place instead of rebinding the
        # module-level `wandb.config` attribute (which stored nothing).
        self.run.config.update(values, allow_val_change=True)

    @on_main_process
    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        """Log ``values`` to the run, forwarding ``step`` when one is given."""
        self.run.log(values, step=step, **kwargs)

In [3]:
# Experiment settings for this notebook (model names, reward choice and
# PPO-style hyper-parameters).
config_sc = dict(
    lm_name='facebook/opt-2.7b',      # model name
    alpha_bleu=0.5,                   # BLEU weight for multi-objective optimization
    beta_ppl=0.5,                     # perplexity weight for multi-objective optimization
    reward_type='bert',               # reward function type
    ref_lm_name='facebook/opt-2.7b',  # reference model name, same as the model name
    cls_model_name="null",
    tk_name="lm_extraction",
    reward_fn='bert',                 # reward function type
    steps=2,
    batch_size=32,
    forward_batch_size=8,
    ppo_epochs=2,
    lr=1.41e-5,
    init_kl_coef=0.2,
    target=6,
    horizon=10000,
    gamma=1,
    lam=0.95,
    cliprange=0.2,
    cliprange_value=0.2,
    vf_coef=0.1,
)

In [5]:
import random


def seed_everything(seed=1234):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Exported for child processes; Python reads PYTHONHASHSEED at interpreter
    # start-up, so this does not alter hashing in the already-running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one: the models in this
    # notebook are loaded with device_map="auto" and may span several devices.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuned kernel selection can vary between runs; keep it disabled.
    torch.backends.cudnn.benchmark = False


seed_everything()

In [None]:
# Execution mode for the notebook: 'train' or 'inference'.
mode = 'train'

if mode == 'inference':
    # config_sc defines no 'dataset_type' entry, so indexing it directly raised
    # KeyError here; use the key when present and a generic suffix otherwise.
    run_name = config_sc['lm_name'] + '_' + config_sc.get('dataset_type', 'inference')
else:
    run_name = config_sc['lm_name'] + '_' + 'training'

# Tracker instance later handed to PPOConfig(log_with=...).
wandbed_obj = WandbEd(
    group=config_sc['lm_name'],
    project='lm_extraction_defence_exps',
    entity="thesis_projects",
    run_name=run_name,
    config=config_sc,
)

In [None]:
# PPO trainer settings, collected first and then expanded into PPOConfig.
# NOTE(review): ppo_epochs is 8 here while config_sc['ppo_epochs'] is 2 --
# confirm which value is intended.
_ppo_config_kwargs = {
    "model_name": '/models--facebook--opt-2.7b/snapshots/397f71a473a150c00f0fe3fc4a2f78ff3ccaf82d',
    "learning_rate": 1.41e-5,
    "batch_size": 32,
    "forward_batch_size": 8,
    "ppo_epochs": 8,
    "optimize_cuda_cache": True,
    "remove_unused_columns": False,
    "log_with": wandbed_obj,  # custom wandb tracker instance
}
config = PPOConfig(**_ppo_config_kwargs)

In [None]:
# `datasets.load_metric` is deprecated in favour of `evaluate.load`, which this
# notebook already imports as `load`; load both metrics through it.
sacrebleu = load('sacrebleu')
bertscore = load("bertscore")

In [None]:
def sacrebleu_fn(label, response):
    """Return 100 minus the SacreBLEU score of ``response`` against ``label``.

    ``response`` is scored as the single prediction and ``label`` as its
    single reference.
    """
    result = sacrebleu.compute(references=[[label]], predictions=[response])
    return 100 - result['score']

def calculatePerplexity(sentence):
    """Return the perplexity of ``sentence`` under ``ppl_model``.

    The sentence is encoded with the notebook-level ``tokenizer``, fed to
    ``ppl_model`` with the inputs reused as labels, and the resulting loss is
    exponentiated: perplexity = exp(loss).

    Args:
        sentence: text to score.

    Returns:
        A tensor holding exp(loss).
    """
    input_ids = torch.tensor(tokenizer.encode(sentence)).unsqueeze(0)
    # Follow the model's own device instead of assuming 'cuda', so scoring
    # still works wherever device_map="auto" placed the model.
    input_ids = input_ids.to(ppl_model.device)
    with torch.no_grad():
        outputs = ppl_model(input_ids, labels=input_ids)
    # First output is the loss; the logits are not needed here.
    loss = outputs[0]
    return torch.exp(loss)

def perplexity_fn(text):
    """Score every entry of ``text`` with ``calculatePerplexity``.

    Returns:
        A pair ``(scores, mean)``: the per-entry perplexities concatenated
        into one tensor, and the mean of that tensor.
    """
    per_item = [calculatePerplexity(entry).unsqueeze(0) for entry in text]
    ppl_tns = torch.cat(per_item)
    return ppl_tns, ppl_tns.mean()

In [None]:
def reward_fn_comp(bleu_score, label, response):
    """Compute the reward selected by ``config_sc['reward_type']``.

    Args:
        bleu_score: precomputed BLEU-based scores (not used by this function;
            kept for interface compatibility).
        label: reference texts.
        response: generated texts.

    Returns:
        The string ``'bleu'`` when the reward type is ``'bleu'`` (the caller
        then substitutes its own BLEU scores), or, for ``'bert'``, a list with
        the negated absolute BERTScore F1 of each response/label pair.

    Raises:
        ValueError: if the configured reward type is neither ``'bleu'`` nor
            ``'bert'`` (previously this surfaced as an UnboundLocalError).
    """
    reward_type = config_sc['reward_type']

    if reward_type == 'bleu':
        return 'bleu'

    if reward_type == 'bert':
        f1_scores = bertscore.compute(
            predictions=response,
            references=label,
            model_type="microsoft/deberta-large",
            # Use the GPU when there is one instead of assuming it exists.
            device='cuda' if torch.cuda.is_available() else 'cpu',
        )['f1']
        # Higher similarity to the label yields a more negative reward.
        return [-abs(f1) for f1 in f1_scores]

    raise ValueError(
        f"Unsupported reward_type {reward_type!r}; expected 'bleu' or 'bert'."
    )

In [None]:
def reward_sacrbleu(response, label):
    """Score each (label, response) pair with ``sacrebleu_fn``.

    Args:
        response: generated texts.
        label: reference texts, aligned with ``response``.

    Returns:
        List with one ``sacrebleu_fn`` result per pair, in input order.
    """
    # The pool is now shut down when the block exits; previously each call
    # created a new pool that was never closed, leaking worker processes.
    # `pool.apply` blocks until its result is ready, so all results are
    # collected before the pool is torn down.
    with Pool() as pool:
        return [
            pool.apply(sacrebleu_fn, args=(true, pred))
            for true, pred in zip(label, response)
        ]

In [None]:

# Sequential implementation
def reward_fn(response, label, label_texts, generated_texts):
    """Compute rewards and diagnostics for one batch.

    Args:
        response: generated continuations.
        label: reference continuations.
        label_texts: texts scored for the reference perplexity.
        generated_texts: texts scored for the generated perplexity.

    Returns:
        Tuple ``(score, bleu_score, ppl_gen_mean, ppl_lb_mean, ppl_lb_tns,
        ppl_gen_tns)`` where ``score`` is the BLEU-based score when
        ``reward_fn_comp`` answers ``'bleu'`` and its own result otherwise.
    """
    bleu_values = reward_sacrbleu(response, label)
    label_ppl, label_ppl_mean = perplexity_fn(label_texts)
    generated_ppl, generated_ppl_mean = perplexity_fn(generated_texts)

    selected = reward_fn_comp(bleu_values, label, response)
    final_score = bleu_values if selected == 'bleu' else selected

    return (
        final_score,
        bleu_values,
        generated_ppl_mean,
        label_ppl_mean,
        label_ppl,
        generated_ppl,
    )

In [None]:
# Read the training examples from CSV and wrap them in a `datasets.Dataset`.
# Later cells read the 'prefix' and 'suffix' columns of this data.
lm_data = pd.read_csv('all_train.csv')
ds = Dataset.from_pandas(lm_data)

In [9]:
# Load the tokenizer from the checkpoint named in the PPO config.
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
# Use the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
def tokenize(sample):
    """Add ``tokens`` and ``query`` fields derived from ``sample['prefix']``.

    ``tokens`` holds the encoded ids of the prefix; ``query`` is the text
    obtained by decoding those ids again.
    """
    token_ids = tokenizer.encode(sample["prefix"])
    sample["tokens"] = token_ids
    sample["query"] = tokenizer.decode(token_ids)
    return sample


# Apply the tokenisation example by example.
ds = ds.map(tokenize, batched=False)

In [None]:
# Sampling settings forwarded to generation during training.
gen_kwargs = dict(
    min_length=-1,
    top_k=0.0,
    top_p=1.0,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)

In [11]:
def collater(data):
    """Turn a list of example dicts into a dict of lists.

    The keys are taken from the first example; each value is the list of that
    key's entries across all examples, in order.
    """
    field_names = data[0]
    return {name: [example[name] for example in data] for name in field_names}

In [12]:
# Load the policy model according to `mode`.
if mode == 'inference':
    # Inference: model and tokenizer both come from the local save directory.
    model = AutoModelForCausalLMWithValueHead.from_pretrained('saved_models/', device_map="auto")

    tokenizer = AutoTokenizer.from_pretrained('saved_models/')
    tokenizer.pad_token = tokenizer.eos_token
elif mode == 'train':
    # Training: the policy and the reference model start from the same checkpoint.
    model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name, device_map="auto")
    model_ref = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name, device_map="auto")

# Plain causal LM used by calculatePerplexity; loaded in both modes.
ppl_model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto")

In [13]:
# The optimizer and PPO trainer are only needed (and only built) for training.
if mode == 'train':
    # Optimise only the parameters that require gradients.
    trainable_params = (p for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.Adam(trainable_params, lr=config.learning_rate)
    # Build the PPOTrainer from the model, the reference model and the tokenizer.
    ppo_trainer = PPOTrainer(
        config,
        model,
        ref_model=model_ref,
        tokenizer=tokenizer,
        dataset=ds,
        data_collator=collater,
        optimizer=optimizer,
    )

In [15]:
def train(total_ppo_epochs=2, gen_len=55):
    """Run the PPO loop over the first ``total_ppo_epochs`` batches.

    For every batch taken from ``ppo_trainer.dataloader`` the loop samples one
    continuation per query, scores the batch with ``reward_fn``, performs one
    ``ppo_trainer.step`` and logs rewards, perplexities, BLEU-based scores and
    timings through ``ppo_trainer.accelerator``.

    Args:
        total_ppo_epochs: maximum number of batches to process; the loop also
            ends when the dataloader runs out of batches.
        gen_len: maximum number of new tokens sampled per query.
    """
    for _batch_idx, batch in tqdm(zip(range(total_ppo_epochs), iter(ppo_trainer.dataloader))):

        logs, timing = dict(), dict()
        t0 = time.time()
        query_tensors = [torch.tensor(t).long().cuda() for t in batch["tokens"]]

        #### Get response from lm
        t = time.time()
        gen_kwargs["max_new_tokens"] = gen_len
        response_tensors = []
        # Iterate over the queries actually present in the batch rather than
        # indexing up to the configured batch size.
        for query in query_tensors:
            output = ppo_trainer.generate(query, **gen_kwargs)
            # Drop the prompt by its own length. Taking the last `gen_len`
            # tokens pulled prompt tokens into the response whenever sampling
            # stopped before `gen_len` new tokens were produced.
            response_tensors.append(output.squeeze()[query.shape[0]:])

        batch['response'] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
        timing['time/get_response'] = time.time()-t

        #### Compute reward score
        t = time.time()
        responses_batch = batch['response']
        label_batch = batch['suffix']

        label_texts = [q + r for q, r in zip(batch['query'], batch['suffix'])]
        generated_texts = [q + r for q, r in zip(batch['query'], batch['response'])]

        (reward_scores, bleu_score, mean_ppl_gen, mean_ppl_label,
         perplexity_scores_label, perplexity_scores_generated) = reward_fn(
            responses_batch, label_batch, label_texts, generated_texts)

        # One float64 CPU tensor for logging, plus a list of per-sample scalar
        # tensors on the GPU for the PPO step. Cloning the elements avoids
        # calling torch.tensor() on values that are already tensors.
        reward_tensor = torch.tensor(reward_scores, dtype=torch.float64)
        rewards = [score.clone().detach() for score in reward_tensor.cuda()]

        # Key name kept unchanged so existing logged series stay comparable.
        timing['time/get_sentiment_preds'] = time.time()-t
        print('finished reward', rewards)

        #### Run PPO step
        t = time.time()

        # NOTE(review): both settings are loop-invariant; they are left here so
        # that, as before, they only take effect after the first batch has
        # been generated. Confirm before hoisting them out of the loop.
        model.gradient_checkpointing_enable()
        model.pretrained_model.config.use_cache = False

        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)

        timing['time/optimization'] = time.time()-t

        #### Log everything
        timing['time/epoch'] = time.time()-t0
        table_rows = [list(r) for r in zip(batch['query'], batch['response'], reward_tensor.tolist(), batch['suffix'],
                                           perplexity_scores_label.cpu().tolist(),
                                           perplexity_scores_generated.cpu().tolist(), bleu_score)]

        logs.update({'game_log': wandb.Table(columns=['query', 'pred', 'reward','label','perplexity_label',
                                                     'perplexity_response', 'bleu_score'], rows=table_rows)})
        logs.update(timing)
        logs['env/reward_mean'] = torch.mean(reward_tensor).numpy()
        logs['env/reward_std'] = torch.std(reward_tensor).numpy()
        logs['env/perplexity_gen'] = mean_ppl_gen.cpu().numpy()
        logs['env/perplexity_lab'] = mean_ppl_label.cpu().numpy()
        logs['env/bleu'] = statistics.mean(bleu_score)
        logs['env/reward_dist'] = reward_tensor.numpy()

        ppo_trainer.accelerator.log(logs)

In [None]:
# Run the PPO loop for a single batch of the trainer's dataloader.
train(total_ppo_epochs=1)

In [None]:
# Persist the policy together with its tokenizer: the inference branch of this
# notebook loads both from 'saved_models/', but only the model was being saved.
model.save_pretrained('saved_models/', max_shard_size='20GB')
tokenizer.save_pretrained('saved_models/')