# Fast Detect GPT Perturb Project

In [None]:
# Clone the upstream Fast-DetectGPT repository; the next cell switches into it.
!git clone https://github.com/baoguangsheng/fast-detect-gpt

Cloning into 'fast-detect-gpt'...
remote: Enumerating objects: 175, done.[K
remote: Counting objects: 100% (175/175), done.[K
remote: Compressing objects: 100% (84/84), done.[K
remote: Total 175 (delta 119), reused 139 (delta 86), pack-reused 0[K
Receiving objects: 100% (175/175), 3.37 MiB | 10.54 MiB/s, done.
Resolving deltas: 100% (119/119), done.


In [None]:
# Work from inside the cloned repo so relative paths such as ./local_infer_ref resolve.
%cd fast-detect-gpt

/content/fast-detect-gpt


In [None]:
!pip install datasets  stable_baselines3 transformers

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting stable_baselines3
  Downloading stable_baselines3-2.2.1-py3-none-any.whl (181 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m181.7/181.7 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
Collecting gymnasium<0.30,>=0.28.

## Fast Detect GPT

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import os

def from_pretrained(cls, model_name, kwargs, cache_dir):
    """Load ``cls`` from a local snapshot when one exists, else from the hub (and snapshot it).

    The local snapshot lives at ``<cache_dir>/local.<model_name with '/' replaced by '_'>``.

    :param cls: class exposing ``from_pretrained``/``save_pretrained`` (e.g. AutoModelForCausalLM)
    :param model_name: hub id to download when no usable local snapshot exists
    :param kwargs: extra keyword arguments forwarded to ``cls.from_pretrained``
    :param cache_dir: download cache directory; also the parent of the local snapshot
    :return: the loaded object
    """
    local_path = os.path.join(cache_dir, 'local.' + model_name.replace("/", "_"))
    # Only try the snapshot if it exists: a missing path like "cache/local.gpt2" would
    # otherwise be interpreted by from_pretrained as a hub repo id.
    if os.path.isdir(local_path):
        try:
            return cls.from_pretrained(local_path, **kwargs)
        except Exception as ex:
            # Incomplete or corrupt snapshot: fall back to downloading and overwrite it.
            print(f'Could not load local snapshot {local_path}: {ex}')
    obj = cls.from_pretrained(model_name, **kwargs, cache_dir=cache_dir)
    obj.save_pretrained(local_path)
    return obj

# Short model aliases -> Hugging Face hub ids. Names not listed are used verbatim.
model_fullnames = {
    'gpt2': 'gpt2',
    'gpt2-xl': 'gpt2-xl',
    'opt-2.7b': 'facebook/opt-2.7b',
    'gpt-neo-2.7B': 'EleutherAI/gpt-neo-2.7B',
    'gpt-j-6B': 'EleutherAI/gpt-j-6B',
    'gpt-neox-20b': 'EleutherAI/gpt-neox-20b',
    'mgpt': 'sberbank-ai/mGPT',
    'pubmedgpt': 'stanford-crfm/pubmedgpt',
    'mt5-xl': 'google/mt5-xl',
    'llama-13b': 'huggyllama/llama-13b',
    'llama2-13b': 'TheBloke/Llama-2-13B-fp16',
    'bloom-7b1': 'bigscience/bloom-7b1',
    'opt-13b': 'facebook/opt-13b',
}
# Aliases that load_model loads with torch_dtype=torch.float16.
float16_models = ['gpt-j-6B', 'gpt-neox-20b', 'llama-13b', 'llama2-13b', 'bloom-7b1', 'opt-13b']


def get_model_fullname(model_name):
    """Resolve a short alias to its hub id; unknown names are returned unchanged."""
    return model_fullnames.get(model_name, model_name)

def load_model(model_name, device, cache_dir):
    """Load a causal LM by alias or hub id and move it to ``device``.

    Aliases listed in ``float16_models`` are loaded with ``torch_dtype=torch.float16``;
    GPT-J names additionally request the ``float16`` revision.

    :param model_name: short alias (see ``model_fullnames``) or a hub id
    :param device: torch device the model is moved to
    :param cache_dir: cache directory forwarded to ``from_pretrained``
    :return: the loaded model
    """
    model_fullname = get_model_fullname(model_name)
    print(f'Loading model {model_fullname}...')
    model_kwargs = {}
    if model_name in float16_models:
        model_kwargs['torch_dtype'] = torch.float16
    if 'gpt-j' in model_name:
        model_kwargs['revision'] = 'float16'
    model = from_pretrained(AutoModelForCausalLM, model_fullname, model_kwargs, cache_dir)
    # Report the actual target device rather than assuming a GPU.
    print(f'Moving model to {device}...', end='', flush=True)
    start = time.time()
    model.to(device)
    print(f'DONE ({time.time() - start:.2f}s)')
    return model

def load_tokenizer(model_name, for_dataset, cache_dir):
    """Load the tokenizer for ``model_name`` and make sure it has a pad token.

    :param model_name: short alias (see ``model_fullnames``) or a hub id
    :param for_dataset: dataset name; 'pubmed' pads on the left, everything else on the right
    :param cache_dir: cache directory forwarded to ``from_pretrained``
    :return: the tokenizer, with ``pad_token_id`` set (EOS, or 0 for '13b' models)
    """
    model_fullname = get_model_fullname(model_name)
    optional_tok_kwargs = {}
    if "facebook/opt-" in model_fullname:
        print("Using non-fast tokenizer for OPT")
        # AutoTokenizer's switch is `use_fast`; a `fast` key is not recognised by it.
        optional_tok_kwargs['use_fast'] = False
    optional_tok_kwargs['padding_side'] = 'left' if for_dataset in ['pubmed'] else 'right'
    base_tokenizer = from_pretrained(AutoTokenizer, model_fullname, optional_tok_kwargs, cache_dir=cache_dir)
    if base_tokenizer.pad_token_id is None:
        base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
        if '13b' in model_fullname:
            base_tokenizer.pad_token_id = 0
    return base_tokenizer

In [None]:

import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, precision_recall_curve, auc

# Okabe-Ito colour-blind-safe palette: 7 hues plus black (8 distinct colours),
# cycled to 15 entries so up to 15 series can be indexed directly.
_OKABE_ITO_HUES = ["#0072B2", "#009E73", "#D55E00", "#CC79A7", "#F0E442", "#56B4E9", "#E69F00"]
COLORS = _OKABE_ITO_HUES + ["#000000"] + _OKABE_ITO_HUES


def get_roc_metrics(real_preds, sample_preds):
    """ROC curve and AUC, treating human-written scores as class 0 and machine scores as class 1.

    :param real_preds: iterable of detector scores for human-written texts (list or 1-D array)
    :param sample_preds: iterable of detector scores for machine-generated texts
    :return: (fpr list, tpr list, roc_auc float)
    """
    # Materialise as lists so arrays are concatenated, not added element-wise by `+`.
    real_preds, sample_preds = list(real_preds), list(sample_preds)
    labels = [0] * len(real_preds) + [1] * len(sample_preds)
    fpr, tpr, _ = roc_curve(labels, real_preds + sample_preds)
    roc_auc = auc(fpr, tpr)
    return fpr.tolist(), tpr.tolist(), float(roc_auc)


def get_precision_recall_metrics(real_preds, sample_preds):
    """Precision-recall curve and its area, with machine-generated texts as the positive class.

    :param real_preds: iterable of detector scores for human-written texts (list or 1-D array)
    :param sample_preds: iterable of detector scores for machine-generated texts
    :return: (precision list, recall list, pr_auc float)
    """
    # Materialise as lists so arrays are concatenated, not added element-wise by `+`.
    real_preds, sample_preds = list(real_preds), list(sample_preds)
    labels = [0] * len(real_preds) + [1] * len(sample_preds)
    precision, recall, _ = precision_recall_curve(labels, real_preds + sample_preds)
    pr_auc = auc(recall, precision)
    return precision.tolist(), recall.tolist(), float(pr_auc)


In [None]:
import random

import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import argparse
import json

def get_samples(logits, labels):
    """Draw 10000 token sequences from the categorical distribution defined by ``logits``.

    :param logits: tensor [1, seq_len, vocab_size] (batch size must be 1)
    :param labels: tensor with batch size 1; only its batch size is checked
    :return: long tensor [1, seq_len, 10000] of sampled token ids
    """
    assert logits.shape[0] == 1
    assert labels.shape[0] == 1
    n_draws = 10000
    dist = torch.distributions.Categorical(logits=torch.log_softmax(logits, dim=-1))
    draws = dist.sample((n_draws,))   # [n_draws, 1, seq_len]
    return draws.permute(1, 2, 0)

def get_likelihood(logits, labels):
    """Mean log-probability of ``labels`` under ``logits``, averaged over dim 1 (the sequence).

    :param logits: tensor [1, seq_len, vocab_size]
    :param labels: tensor [1, seq_len] (one label per position) or [1, seq_len, n] (n per position)
    :return: tensor [1, 1] for 2-D labels, [1, n] for 3-D labels
    """
    assert logits.shape[0] == 1
    assert labels.shape[0] == 1
    if labels.ndim == logits.ndim - 1:
        labels = labels.unsqueeze(-1)
    token_lprobs = torch.log_softmax(logits, dim=-1).gather(dim=-1, index=labels)
    return token_lprobs.mean(dim=1)

def get_sampling_discrepancy(logits_ref, logits_score, labels):
    """Monte-Carlo Fast-DetectGPT criterion for a single sequence.

    Samples sequences from the reference logits (via ``get_samples``), scores them and the
    observed ``labels`` with the scoring logits, and returns the z-score of the observed
    mean log-likelihood against the sampled ones.

    :param logits_ref: tensor [1, seq_len, vocab_ref]
    :param logits_score: tensor [1, seq_len, vocab_score]
    :param labels: tensor [1, seq_len] of observed token ids
    :return: Python float criterion
    """
    for t in (logits_ref, logits_score, labels):
        assert t.shape[0] == 1
    ref_vocab, score_vocab = logits_ref.size(-1), logits_score.size(-1)
    if ref_vocab != score_vocab:
        # Align vocabularies by dropping the tail of the larger one.
        shared = min(ref_vocab, score_vocab)
        logits_ref = logits_ref[:, :, :shared]
        logits_score = logits_score[:, :, :shared]

    samples = get_samples(logits_ref, labels)
    ll_observed = get_likelihood(logits_score, labels).squeeze(-1)
    ll_samples = get_likelihood(logits_score, samples)
    z_score = (ll_observed - ll_samples.mean(dim=-1)) / ll_samples.std(dim=-1)
    return z_score.item()

def get_sampling_discrepancy_analytic(logits_ref, logits_score, labels):
    """Analytic Fast-DetectGPT criterion, computed per batch row.

    Uses the reference distribution's exact mean and variance of the scoring model's
    log-probabilities instead of Monte-Carlo sampling.

    :param logits_ref: tensor [batch, seq_len, vocab_ref]
    :param logits_score: tensor [batch, seq_len, vocab_score]
    :param labels: tensor [batch, seq_len] of observed token ids
    :return: tensor [batch] of criteria
    """
    n_ref, n_score = logits_ref.size(-1), logits_score.size(-1)
    if n_ref != n_score:
        # Align vocabularies by dropping the tail of the larger one.
        shared = min(n_ref, n_score)
        logits_ref = logits_ref[:, :, :shared]
        logits_score = logits_score[:, :, :shared]
    if labels.ndim == logits_score.ndim - 1:
        labels = labels.unsqueeze(-1)

    score_lprobs = torch.log_softmax(logits_score, dim=-1)
    ref_probs = torch.softmax(logits_ref, dim=-1)
    observed = score_lprobs.gather(dim=-1, index=labels).squeeze(-1)          # per token
    expected = (ref_probs * score_lprobs).sum(dim=-1)                          # per token
    variance = (ref_probs * score_lprobs.square()).sum(dim=-1) - expected.square()
    return (observed.sum(dim=-1) - expected.sum(dim=-1)) / variance.sum(dim=-1).sqrt()

In [None]:

import random

import numpy as np
import torch
import os
import glob
import argparse
import json
import transformers
import datasets



# Larger alternative configuration:
# reference_model_name = "gpt-j-6B"
# scoring_model_name = "gpt-neo-2.7B"

# Reference (sampling) model and scoring model. When the names are equal, FastDetectGPT
# loads a single model and reuses its logits for both roles.
reference_model_name = "gpt2"
scoring_model_name = "gpt2"

# Passed to load_tokenizer, where it selects the padding side.
dataset = "xsum"
# Directory of reference-criterion JSON files read by ProbEstimator.
ref_path = "./local_infer_ref"
# Prefer the GPU but keep the notebook runnable on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face cache; local.<model> snapshots of loaded models/tokenizers are saved here too.
cache_dir = "../cache"

class ProbEstimator:
    """Map a Fast-DetectGPT criterion to an estimated probability that the text is machine-generated.

    Reference criteria are read from every ``*.json`` file in a directory; each file must
    contain ``res['predictions']['real']`` (human-written) and
    ``res['predictions']['samples']`` (machine-generated) lists. For a query criterion,
    the window radius is the distance to its (NEIGHBOR_INDEX + 1)-th nearest reference.
    The estimate is the fraction of references strictly inside that window that are
    machine-generated.
    """

    # 0-based rank of the nearest reference whose distance sets the window radius.
    NEIGHBOR_INDEX = 100

    def __init__(self, ref_dir=None, target_device=None):
        """
        :param ref_dir: directory of reference JSON files (defaults to the notebook's ``ref_path``)
        :param target_device: device for the reference tensors (defaults to the notebook's ``device``)
        """
        ref_dir = ref_path if ref_dir is None else ref_dir
        target_device = device if target_device is None else target_device
        self.real_crits = []
        self.fake_crits = []
        # Sorted for a reproducible load order (the estimate itself is order-independent).
        for result_file in sorted(glob.glob(os.path.join(ref_dir, '*.json'))):
            with open(result_file, 'r') as fin:
                res = json.load(fin)
            self.real_crits.extend(res['predictions']['real'])
            self.fake_crits.extend(res['predictions']['samples'])
        print(f'ProbEstimator: total {len(self.real_crits) + len(self.fake_crits)} samples.')
        self.real_crits_tensor = torch.tensor(self.real_crits).to(target_device)
        self.fake_crits_tensor = torch.tensor(self.fake_crits).to(target_device)

    def crit_to_prob(self, crit):
        """Estimate P(machine-generated) for one or more criterion values.

        :param crit: tensor of criteria (e.g. shape [batch]) or a scalar (tensor or float)
        :return: float tensor of probabilities with the same shape as ``crit``
        :raises ValueError: if no reference criteria were loaded
        """
        real_ref = self.real_crits_tensor.unsqueeze(dim=1)   # [n_real, 1]
        fake_ref = self.fake_crits_tensor.unsqueeze(dim=1)   # [n_fake, 1]
        all_ref = torch.cat([real_ref, fake_ref])            # [n_ref, 1]
        n_ref = all_ref.shape[0]
        if n_ref == 0:
            raise ValueError('ProbEstimator has no reference criteria; check the reference directory.')

        crit = torch.as_tensor(crit, device=all_ref.device)
        orig_shape = crit.shape
        crit = crit.reshape(-1)                                # [batch]

        dists = torch.abs(all_ref - crit.unsqueeze(dim=0))     # [n_ref, batch]
        # kthvalue is 1-based: k-th smallest == sort(...)[k - 1]. Clamp so small
        # reference sets fall back to the farthest reference instead of indexing past the end.
        k = min(self.NEIGHBOR_INDEX, n_ref - 1) + 1
        offset = torch.kthvalue(dists, k, dim=0).values        # [batch]

        lower = (crit - offset).unsqueeze(dim=0)
        upper = (crit + offset).unsqueeze(dim=0)
        cnt_real = ((real_ref > lower) & (real_ref < upper)).sum(dim=0).float()
        cnt_fake = ((fake_ref > lower) & (fake_ref < upper)).sum(dim=0).float()
        prob = cnt_fake / (cnt_real + cnt_fake)
        return prob.reshape(orig_shape)




In [None]:


class FastDetectGPT:
    """Fast-DetectGPT detector configured from the notebook-level settings.

    Loads the scoring model (plus a separate reference model when its name differs)
    and a ``ProbEstimator``. Text is scored with the analytic sampling-discrepancy
    criterion, which is then mapped to a probability of being machine-generated.
    """

    def __init__(self):
        self.device = device
        self.reference_model_name = reference_model_name
        self.scoring_model_name = scoring_model_name

        self.scoring_tokenizer = load_tokenizer(scoring_model_name, dataset, cache_dir)
        self.scoring_model = load_model(scoring_model_name, device, cache_dir)
        self.scoring_model.eval()
        if self._uses_separate_reference():
            self.reference_tokenizer = load_tokenizer(self.reference_model_name, dataset, cache_dir)
            self.reference_model = load_model(self.reference_model_name, device, cache_dir)
            self.reference_model.eval()

        self.criterion_name = "sampling_discrepancy_analytic"
        self.criterion_fn = get_sampling_discrepancy_analytic
        self.prob_estimator = ProbEstimator()
        print('Local demo for Fast-DetectGPT, where the longer text has more reliable result.')
        print('')

    def _uses_separate_reference(self):
        """True when a distinct reference model was configured."""
        return self.reference_model_name != self.scoring_model_name

    def _tokenize(self, tokenizer, text):
        """Tokenise ``text`` (str or list of str) with padding and move it to the detector device."""
        return tokenizer(text, return_tensors="pt", padding=True, return_token_type_ids=False).to(self.device)

    def infer(self, text):
        """Return the estimated probability that ``text`` is machine-generated (one per input row)."""
        scoring_inputs = self._tokenize(self.scoring_tokenizer, text)
        labels = scoring_inputs.input_ids[:, 1:]   # next-token targets
        with torch.no_grad():
            logits_score = self.scoring_model(**scoring_inputs).logits[:, :-1]
            if not self._uses_separate_reference():
                logits_ref = logits_score
            else:
                ref_inputs = self._tokenize(self.reference_tokenizer, text)
                assert torch.all(ref_inputs.input_ids[:, 1:] == labels), "Tokenizer is mismatch."
                logits_ref = self.reference_model(**ref_inputs).logits[:, :-1]
            crit = self.criterion_fn(logits_ref, logits_score, labels)
        return self.prob_estimator.crit_to_prob(crit)


# Build the detector once: this loads the configured model(s) and the reference criteria.
detector = FastDetectGPT()



from typing import List, Set

_HF_MODEL_ALIASES = {
    "bert-tiny": "prajjwal1/bert-tiny",
    "bert-med": "prajjwal1/bert-medium",
    "small": "gpt2",
    "med": "gpt2-medium",
    "large": "gpt2-large",
    "full": "gpt2-xl",
    "gpt2-sm": "gpt2",
    "gpt2-med": "gpt2-medium",
    "gpt2-lg": "gpt2-large",
    "gpt2": "gpt2-xl",
    "neo": "EleutherAI/gpt-neo-2.7B",
}


def model2hfname(model: str) -> str:
    """Translate a short model alias into its Hugging Face hub id.

    Note that the bare alias ``"gpt2"`` maps to ``gpt2-xl``; use ``"gpt2-sm"`` for the
    ``gpt2`` checkpoint. Raises KeyError for unknown aliases.
    """
    return _HF_MODEL_ALIASES[model]

def get_model_and_tokenizer(model: str, Cls = transformers.AutoModelForCausalLM, **model_kwargs):
    """Load a model and tokenizer by short alias (see ``model2hfname``) and ensure a pad token.

    :param model: short alias resolved through ``model2hfname``
    :param Cls: transformers auto-class used to load the model
    :param model_kwargs: forwarded to ``Cls.from_pretrained``
    :return: (model, tokenizer)
    """
    hf_model_name = model2hfname(model)

    m = Cls.from_pretrained(hf_model_name, **model_kwargs)
    if isinstance(m, transformers.GPT2LMHeadModel):
        m.transformer.gradient_checkpointing_enable()

    tok = transformers.AutoTokenizer.from_pretrained(hf_model_name)

    if tok.pad_token_id is None:
        if Cls == transformers.AutoModelForCausalLM:
            # Causal LMs can reuse EOS as padding.
            tok.pad_token = tok.eos_token
        else:
            print("Adding pad token to tokenizer")
            tok.add_special_tokens({"pad_token": "[PAD]"})
            tok.pad_token = "[PAD]"
            # The new pad id lies past the end of the embedding matrix; grow it so padded
            # batches do not index out of range, and tell the model which id is padding.
            m.resize_token_embeddings(len(tok))
            m.config.pad_token_id = tok.pad_token_id
    return m, tok


def _token_ids_decoding_to(tokenizer, strings) -> List[int]:
    """Return every vocabulary id whose single-token decoding is exactly one of ``strings``."""
    return [idx for idx in range(len(tokenizer)) if tokenizer.decode(idx) in strings]


def stop_tokens(tokenizer, stop_strings: Set[str] = frozenset()) -> List[int]:
    """Ids of tokens whose decoding is in ``stop_strings`` (none by default)."""
    tokens = _token_ids_decoding_to(tokenizer, stop_strings)
    print("Stop tokens:", tokens)
    return tokens


def ignore_tokens(tokenizer, stop_strings: Set[str] = frozenset({"\n"})) -> List[int]:
    """Ids of tokens whose decoding is in ``stop_strings`` (a bare newline by default)."""
    tokens = _token_ids_decoding_to(tokenizer, stop_strings)
    print("Ignore tokens:", tokens)
    return tokens


def ignore_tokens_replace(tokenizer, stop_strings: Set[str] = frozenset({" "})) -> int:
    """First id whose decoding is in ``stop_strings`` (a single space by default).

    :raises ValueError: if no token in the vocabulary matches
    """
    tokens = _token_ids_decoding_to(tokenizer, stop_strings)
    print("Ignore tokens replaced by:", tokens)
    if not tokens:
        raise ValueError(f"No token in the vocabulary decodes to any of {sorted(stop_strings)!r}")
    return tokens[0]

def top_k_logits(logits, k):
    """Keep the ``k`` largest logits along the last dim and push the rest to -1e10.

    :param logits: tensor [..., vocab_size], e.g. [batch_size, vocab_size]
    :param k: number of logits kept per row; 0 disables filtering
    :return: tensor with the same shape and dtype as ``logits`` (``logits`` itself if k == 0)
    """
    if k == 0:
        return logits
    values, _ = torch.topk(logits, k)
    # Keep a trailing singleton dim so the threshold broadcasts per row,
    # not along the vocabulary axis (which broke for batch_size > 1).
    min_values = values[..., -1:]
    return torch.where(logits < min_values, torch.full_like(logits, -1e10), logits)


Repo id must be in the form 'repo_name' or 'namespace/repo_name': '../cache/local.gpt2'. Use `repo_type` argument if needed.


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loading model gpt2...
../cache/local.gpt2 does not appear to have a file named config.json. Checkout 'https://huggingface.co/../cache/local.gpt2/None' for available files.


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Moving model to GPU...DONE (0.51s)
ProbEstimator: total 1800 samples.
Local demo for Fast-DetectGPT, where the longer text has more reliable result.



In [None]:

from stable_baselines3.common.monitor import Monitor

import gymnasium as gym
import tqdm


## Our LMEnv / Datasets

In [None]:

class LMEnv(gym.Env):
    """Gymnasium environment in which an agent makes a binary choice at each generated token.

    All configuration comes from ``args`` (an argparse-style namespace). Sampling,
    perturbation, reset and step logic are attached by the later ``class LMEnv(LMEnv)``
    cells of this notebook.
    """
    # NOTE: [CHANGE!!!] n_train was changed from 8 to 1
    # NOTE: [CHANGE!!!] sampling_mode was changed from "likelihood" to "argmax"
    def __init__(self, args):
        self.args = args
        dev = args.env_device

        # --- Dataset selection ------------------------------------------------
        self.batch_size = args.batch_size
        item_ids = [int(tok) for tok in args.data_items.split(',')]
        self.data_items = torch.tensor(item_ids).to(dev)
        assert len(self.data_items) == self.batch_size
        self._seed = args.random_seed
        self.dataset = args.dataset
        self.n_train = args.n_train
        self._load_datasets()

        # --- Language model ---------------------------------------------------
        self.max_sample_tokens = args.max_sample_tokens
        self.model, self.tok = get_model_and_tokenizer(args.model_name)
        assert isinstance(self.model, transformers.GPT2LMHeadModel)
        self.model.to(dev)
        self.stop_tokens = stop_tokens(self.tok)
        self.ignore_tokens = ignore_tokens(self.tok)
        self.ignore_tokens_replace = ignore_tokens_replace(self.tok)
        self.vocab_size = len(self.tok)
        self.topK_logistics = args.topK_logistics
        self.sampling_mode = args.sampling_mode  # "likelihood" or "argmax"

        # --- Per-episode state; populated by reset() --------------------------
        for attr in ("num_perturb", "past_obs",
                     "input_ids", "attention_mask", "output_mask", "input_mask",
                     "past_kvs", "last_logits", "last_logits_unperturbed",
                     "input_ids_unperturbed", "past_kvs_unperturbed", "sample_done"):
            setattr(self, attr, None)

        # --- RL spaces ----------------------------------------------------------
        # Two discrete actions (perturb or not); how they are applied is decided in the step logic.
        # Observations are the top-k normalised logits over the last obs_dim positions.
        self.obs_dim = args.obs_dim
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.obs_dim, self.topK_logistics),
            dtype=np.float32,
        )

        self.zero_tensor = torch.zeros(self.batch_size).to(dev)
        self.one_tensor = torch.ones(self.batch_size).to(dev)

        self.reset(random=args.random)

    def _load_datasets(self):
        """Load the first ``n_train`` documents of fewer than 100 words from a shuffled dataset."""
        print("Dataset:", self.dataset)
        if self.dataset != "xsum":
            raise NotImplementedError
        shuffled = datasets.load_dataset(self.dataset, split="train").shuffle(seed=self._seed)
        short_only = shuffled.filter(
            lambda rows: [len(doc.split(" ")) < 100 for doc in rows["document"]],
            batched=True,
            batch_size=None,
        )
        self.data = short_only["document"][:self.n_train]

    def _get_new_input(self, items):
        """Return the documents at indices ``items`` with newlines flattened to spaces."""
        return [self.data[idx].replace('\n', ' ') for idx in items]


  and should_run_async(code)


### Sampling & Perturb

In [None]:
class LMEnv(LMEnv):
    def _feedforward(self, cur_input, attention_mask, past_kvs=None):
        """Run the LM over the full prefix and return the logits of the last ``obs_dim`` positions.

        :param cur_input: tensor [batch_size, seq_len] of token ids (the whole sequence is
            always fed in, since KV caching is disabled)
        :param attention_mask: tensor [batch_size, seq_len]
        :param past_kvs: unused; kept for API compatibility
        :return returned_logits: tensor [batch_size, obs_dim, vocab_size]; if seq_len < obs_dim,
            the leading (earliest) positions are filled with ones
        :return new_past_kvs: always None while caching is disabled
        """
        # TODO: speed up by feeding only the newest token together with past_kvs.
        with torch.inference_mode():
            outputs = self.model(cur_input, attention_mask=attention_mask, use_cache=False)
            all_logits = outputs.logits
            B, S, V = all_logits.shape
            if S < self.obs_dim:
                returned_logits = torch.ones(B, self.obs_dim, V, dtype=torch.float32,
                                             device=self.args.env_device)
                returned_logits[:, self.obs_dim - S:, :] = all_logits
            else:
                returned_logits = all_logits[:, S - self.obs_dim:, :]
            new_past_kvs = None
            return returned_logits, new_past_kvs

    def _cat_new_word(self, sampled_token, input_ids):
        """Append one token per row: [batch_size, L] + [batch_size] (or [batch_size, 1]) -> [batch_size, L + 1]."""
        new_column = sampled_token.clone().detach().long().view(-1, 1)
        return torch.cat((input_ids, new_column), dim=1)

    def _sample_tokens(self, local_logits, input_ids, attention_mask):
        """Choose the next token for every row (per ``self.sampling_mode``) and append it.

        Tokens listed in ``self.ignore_tokens`` are replaced by ``self.ignore_tokens_replace``.

        :param local_logits: tensor [batch_size, vocab_size], logits at the last position
        :param input_ids: tensor [batch_size, seq_len]
        :param attention_mask: tensor [batch_size, seq_len]
        :return new_token: tensor [batch_size, 1]
        :return new_input_ids: tensor [batch_size, seq_len + 1]
        :return new_attention_mask: ``attention_mask`` with a column of ones appended
        :raises NotImplementedError: for an unknown sampling mode
        """
        if self.sampling_mode == "argmax":
            sampled_token = torch.argmax(local_logits, dim=-1)
        elif self.sampling_mode == "likelihood":
            sampled_token = torch.multinomial(F.softmax(local_logits, dim=-1), num_samples=1).squeeze(dim=1)
        else:
            raise NotImplementedError

        # Compare [batch_size, 1] against [n_ignore] -> [batch_size, n_ignore], then reduce per
        # row; comparing [batch_size] directly with [n_ignore] broadcasts wrongly when batch_size > 1.
        ignore_ids = torch.tensor(self.ignore_tokens, dtype=torch.long, device=self.args.env_device)
        is_ignored = torch.eq(sampled_token.unsqueeze(-1), ignore_ids).any(dim=-1)
        sampled_token[is_ignored] = self.ignore_tokens_replace

        new_token = sampled_token.view(-1, 1)
        new_input_ids = self._cat_new_word(new_token, input_ids)
        new_attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
        )
        return new_token, new_input_ids, new_attention_mask

    def _perturb_tokens(self, local_logits, perturb_mode="chosen", perturb_ranking=-1):
        """Pick an alternative (non-greedy) next token and append it to ``self.input_ids``.

        :param local_logits: tensor [batch_size, vocab_size], logits at the last position
        :param perturb_mode: "chosen" takes the ``perturb_ranking``-th most likely token;
            any other value takes one random rank in the top 10 (the same rank for every row)
        :param perturb_ranking: 1-based rank used in "chosen" mode
        :return new_token: tensor [batch_size]
        :return new_input_ids: ``self.input_ids`` with ``new_token`` appended
        :raises ValueError: if ``perturb_ranking`` < 1 in "chosen" mode
        """
        if perturb_mode == "chosen":
            if perturb_ranking < 1:
                raise ValueError(f"perturb_ranking must be >= 1 in 'chosen' mode, got {perturb_ranking}")
            _, topk_indices = torch.topk(local_logits, perturb_ranking, dim=1)
            new_token = topk_indices[:, -1]
        else:
            _, topk_indices = torch.topk(local_logits, 10, dim=1)
            new_token = topk_indices[:, random.randint(0, 9)]
        new_input_ids = self._cat_new_word(new_token, self.input_ids)
        return new_token, new_input_ids


### Reset

In [None]:
class LMEnv(LMEnv):
    def _obs_wrapper(self, all_logits):
        """Convert raw logits into the agent's observation.

        :param all_logits: tensor shape [batch_size, seq_len, vocab_size]
        :return topk_values: numpy array [batch_size, seq_len, topK_logistics], the top-k
            logits of each position re-normalised with a softmax over those k values
        :return topk_indices: numpy array [batch_size, seq_len, topK_logistics]
        """
        topk_values, topk_indices = torch.topk(all_logits, self.topK_logistics, dim=-1)
        topk_values = F.softmax(topk_values, dim=-1)
        return topk_values.detach().cpu().numpy(), topk_indices.detach().cpu().numpy()

    def _sample_done(self):
        """Per-row termination flags for the current batch.

        A row is done when its latest token is one of ``self.stop_tokens``; every row is
        done once the sequence length reaches ``self.max_sample_tokens``.

        :return: bool tensor [batch_size]
        NOTE(review): not exercised by the visible cells; validate before relying on it.
        """
        last_tokens = self.input_ids[:, -1].unsqueeze(dim=-1)                    # [B, 1]
        stop_ids = torch.tensor(self.stop_tokens, dtype=torch.long,
                                device=self.args.env_device).view(1, -1)          # [1, K]
        # Reduce over the stop-token axis only, so one row hitting a stop token does
        # not mark the whole batch as done.
        hit_stop = torch.eq(last_tokens, stop_ids).any(dim=-1)                    # [B]
        too_long = self.input_ids.shape[1] >= self.max_sample_tokens
        too_long = torch.tensor(too_long).repeat(self.batch_size).to(self.args.env_device)
        return hit_stop | too_long

    def _masked(self, a, b, mask):
        """Row-wise merge: return ``a`` with the entries selected by ``mask`` taken from ``b``.

        If ``a`` is None, ``b`` is returned unchanged. Otherwise the tensor with the shorter
        last dimension is right-padded with zeros to the longer length (when lengths are
        equal, ``a`` is copied) before ``a[mask] = b[mask]``; this may write into ``a`` in place.

        NOTE(review): zero-padding token ids yields id 0 rather than the pad id — validate
        before relying on this with sequences of different lengths.
        """
        if a is None:
            return b
        else:
            if a.shape[-1] > b.shape[-1]:
                b_zeros = torch.zeros_like(a)
                b_zeros[...,:b.shape[-1]] = b
                b = b_zeros
            else:
                a_zeros = torch.zeros_like(b)
                a_zeros[...,:a.shape[-1]] = a
                a = a_zeros
            a[mask] = b[mask]
            return a

    def _random_items(self):
        """Draw ``batch_size`` dataset indices uniformly from [0, n_train) on the env device."""
        return torch.randint(low=0, high=self.n_train, size=(self.batch_size,)).to(self.args.env_device)

    def _encode_masked(self, mask):
        """Tokenise the current ``data_items`` and merge them into the ``mask``-selected rows of input_ids/attention_mask."""
        encoded = self.tok(self._get_new_input(self.data_items),
                           return_tensors="pt", padding=True, return_attention_mask=True)
        self.input_ids = self._masked(self.input_ids, encoded["input_ids"].to(self.args.env_device), mask)
        self.attention_mask = self._masked(self.attention_mask, encoded["attention_mask"].to(self.args.env_device), mask)

    def _encode_all(self):
        """Tokenise the current ``data_items`` and replace input_ids/attention_mask wholesale."""
        encoded = self.tok(self._get_new_input(self.data_items), return_tensors="pt", padding=True)
        self.input_ids = encoded["input_ids"].to(self.args.env_device)
        self.attention_mask = encoded["attention_mask"].to(self.args.env_device)

    def _prime_first_token(self):
        """Run the LM on the prompts, append one sampled token, and snapshot the unperturbed state.

        :return: the [batch_size, obs_dim, vocab_size] logits window of the prompt
        """
        all_logits, new_past_kvs = self._feedforward(self.input_ids, self.attention_mask)
        local_logits = all_logits[:, -1, :]
        self.last_logits = local_logits
        self.past_kvs = new_past_kvs
        self.sample_done = self.zero_tensor.clone().detach().bool()

        _, self.input_ids, self.attention_mask = self._sample_tokens(
            local_logits, self.input_ids, self.attention_mask)

        # Allow at least 20 tokens beyond the current sequence, whatever args.max_sample_tokens says.
        self.max_sample_tokens = max(self.args.max_sample_tokens, self.input_ids.shape[-1] + 20)

        self.last_logits_unperturbed = self.last_logits
        self.past_kvs_unperturbed = self.past_kvs
        self.input_ids_unperturbed = self.input_ids
        self.attention_mask_unperturbed = self.attention_mask
        return all_logits

    def _fresh_reset_info(self):
        """Info dict returned by a reset: the current data items plus zeroed per-row statistics."""
        return {"TimeLimit.truncated": self.zero_tensor.clone().detach().bool(),
                "DataItem": self.data_items,
                "F_GPT_Score_drop": self.zero_tensor.clone().detach(),
                "RL_num_perturb": self.zero_tensor.clone().detach(),
                "last_reward": self.zero_tensor.clone().detach(),
                }

    def _reset(self, random=True, mask=None):
        """Reset only the batch rows selected by ``mask``.

        :param random: draw new dataset items for the selected rows; otherwise reuse
            ``self.data_items``
        :param mask: bool tensor [batch_size]; None selects every row
        :return: (obs, info); (previous obs, None) when no row is selected
        NOTE(review): partial resets rely on _masked's zero-padding; validate before use.
        """
        print("Reset begins...")
        if mask is None:
            mask = self.one_tensor.bool().detach()
        if not torch.any(mask):
            return self.past_obs, None
        if random:
            self.data_items = self._masked(self.data_items, self._random_items(), mask)
        print("Data Items:", self.data_items, "Mask:", mask)

        self._encode_masked(mask)
        self.num_perturb = self._masked(self.num_perturb, self.zero_tensor.clone().detach(), mask)
        while random and self.input_ids.shape[-1] == 0:
            self.data_items = self._masked(self.data_items, self._random_items(), mask)
            self._encode_masked(mask)

        all_logits = self._prime_first_token()
        obs, _ = self._obs_wrapper(all_logits)
        self.past_obs = obs
        reset_info = self._fresh_reset_info()
        print("Reset ends!")
        return obs, reset_info

    def _reset_all(self, random=True):
        """Reset every row of the batch.

        :param random: draw new dataset items; otherwise reuse ``self.data_items``
        :return: (obs, info)
        """
        if random:
            self.data_items = self._random_items()
        print("Reset All begins...", random, self.data_items)

        self._encode_all()
        self.num_perturb = self.zero_tensor.clone().detach()
        while random and self.input_ids.shape[-1] == 0:
            self.data_items = self._random_items()
            self._encode_all()

        all_logits = self._prime_first_token()
        self.output_mask = self.attention_mask
        self.input_mask = self.attention_mask
        print("Max tokens:", self.max_sample_tokens)

        obs, _ = self._obs_wrapper(all_logits)
        self.past_obs = obs
        reset_info = self._fresh_reset_info()
        print("Reset All ends!")
        return obs, reset_info

    def reset(self, seed: int = None, random=True, mask=None, options=None):
        """Gymnasium reset entry point.

        :param seed: accepted for Gymnasium API compatibility; currently unused
        :param random: draw new dataset items at random
        :param mask: None resets every row; otherwise only rows where ``mask`` is True
        :param options: accepted for Gymnasium API compatibility; currently unused
        :return: (obs, info)
        """
        if mask is None:
            return self._reset_all(random=random)
        return self._reset(random=random, mask=mask)
        # return obs

### Step

In [None]:
class LMEnv(LMEnv):
    """Extends the base ``LMEnv`` with text decoding and the RL ``step`` logic.

    Every ``step`` advances two sequences side by side: the policy-controlled one
    in ``self.input_ids`` (a perturbed token is chosen wherever ``action`` is True)
    and a free-running reference in ``self.input_ids_unperturbed``. Once every batch
    element is done, both are scored with the global ``detector`` and the reward is
    based on how much the perturbations lowered that score.
    """

    def _decode_masked(self, input_ids, mask=None):
        """Decode a batch of token ids, blanking out masked positions first.

        :param input_ids: token id tensor; a detached copy is edited, the argument
            itself is not modified
        :param mask: optional tensor used as a boolean index into ``input_ids``;
            positions where ``mask == 1`` are set to ``self.tok.pad_token_id``
            before decoding with ``skip_special_tokens=True``
        :return texts: str list [batch_size]
        """
        input_ids = input_ids.clone().detach()
        if mask is not None:
            input_ids[mask == 1] = self.tok.pad_token_id
        return self.tok.batch_decode(input_ids, skip_special_tokens=True)

    def get_texts(self, mask=None):
        """Decode the policy-controlled (perturbed) sequences.

        :return texts: str list [batch_size]
        """
        return self._decode_masked(self.input_ids, mask)

    def get_texts_unperturbed(self, mask=None):
        """Decode the reference (unperturbed) sequences.

        :return texts: str list [batch_size]
        """
        return self._decode_masked(self.input_ids_unperturbed, mask)

    def _step_sample(self, perturb):
        """Append one token to the policy-controlled sequence.

        A regular sample and a perturbed candidate are both derived from
        ``self.last_logits``; ``perturb`` selects between them per batch element.
        ``self.past_kvs`` is reset to None before the forward pass, so (assuming
        ``_feedforward`` treats None as "no cache") the full sequence is re-encoded.

        :param perturb: boolean tensor of shape [batch_size]
        :return obs: tensor of shape [batch_size, obs_dim, topk]
        :return done: bool tensor of shape [batch_size]; True where any candidate
            token returned by ``_obs_wrapper`` for the last position is a stop token
        """
        _, sampled_output, sampled_attention_mask = self._sample_tokens(
            self.last_logits, self.input_ids, self.attention_mask)

        _, perturbed_output = self._perturb_tokens(
            self.last_logits, perturb_mode="chosen", perturb_ranking=self.args.perturb_ranking)

        self.input_ids = torch.where(perturb.unsqueeze(dim=-1), perturbed_output, sampled_output)
        self.attention_mask = sampled_attention_mask

        cur_input = self.input_ids
        self.past_kvs = None

        ## GET NEW OBS
        all_logits, new_past_kvs = self._feedforward(cur_input, self.attention_mask, self.past_kvs)
        self.last_logits = all_logits[:, -1, :]
        self.past_kvs = new_past_kvs

        obs, token = self._obs_wrapper(all_logits)
        # as_tensor avoids the copy-construct warning if these are already tensors
        token = torch.as_tensor(token)[:, -1, :].unsqueeze(dim=-1).to(self.args.env_device)
        stop_tokens = torch.as_tensor(self.stop_tokens).view(1, 1, -1).to(self.args.env_device)

        # [B, K, 1] vs [1, 1, S] -> any over stop tokens, then any over candidates
        done = torch.any(torch.eq(token, stop_tokens), dim=-1)
        done = torch.any(done, dim=-1)

        return obs, done

    def _step_sample_unperturbed(self):
        """Append one freely sampled token to the unperturbed reference sequence.

        Runs alongside ``_step_sample`` so the final detector scores can be
        compared against a sequence generated without any perturbation.
        """
        _, sampled_output, sampled_attention_mask = self._sample_tokens(
            self.last_logits_unperturbed, self.input_ids_unperturbed, self.attention_mask_unperturbed)
        # NOTE(review): the full sequence is fed together with the cached
        # past_kvs_unperturbed, whereas _step_sample clears its cache before
        # re-encoding the full sequence. Confirm _feedforward handles a non-None
        # cache combined with the full input sequence.
        cur_input = sampled_output
        self.input_ids_unperturbed = sampled_output
        self.attention_mask_unperturbed = sampled_attention_mask

        ## GET NEW OBS
        all_logits, new_past_kvs = self._feedforward(
            cur_input, self.attention_mask_unperturbed, self.past_kvs_unperturbed)
        self.last_logits_unperturbed = all_logits[:, -1, :]
        self.past_kvs_unperturbed = new_past_kvs

    def step(self, action):
        """Advance both sequences by one token and compute the reward.

        :param action: bool tensor of shape [batch_size]; True perturbs the next
            token of the corresponding sequence
        :return: ``(obs, reward, terminated, truncated, info)``; ``terminated`` is
            ``self.sample_done`` and ``truncated`` is its bool copy
        """
        reward = self.zero_tensor.clone().detach()
        F_GPT_Score_drop = self.zero_tensor.clone().detach().float()
        perturbed_score = self.zero_tensor.clone().detach().float()
        unperturbed_score = self.zero_tensor.clone().detach().float()
        RL_num_perturb = self.zero_tensor.clone().detach().long()

        # Advance the policy-controlled sequence according to the action ...
        obs, done = self._step_sample(perturb=action)
        # ... and the unperturbed reference sequence in parallel.
        self._step_sample_unperturbed()

        print("Step:", self.input_ids.shape[-1])

        rule_based_penalty = 1 if self.args.rule_based_penalty else 0

        # Rule-based shaping: penalize perturbing when obs[:, -1, 0] <= 0.55 and
        # not perturbing when obs[:, -1, 0] > 0.55.
        # NOTE(review): applied to every env, including ones already finished.
        obs_head = obs[:, -1, 0]
        low_obs = torch.as_tensor(obs_head <= 0.55).to(self.args.env_device)
        high_obs = torch.as_tensor(obs_head > 0.55).to(self.args.env_device)
        penalized_low = (action & low_obs).bool()
        penalized_high = (torch.logical_not(action) & high_obs).bool()
        reward[penalized_low] -= rule_based_penalty
        reward[penalized_high] -= rule_based_penalty

        ## NOTE: save the past obs
        self.past_obs = obs

        # Once a sequence is done it stays done; all are done at the token budget.
        self.sample_done = self.sample_done | done
        if self.input_ids.shape[1] >= self.max_sample_tokens:
            # ones_like keeps sample_done a bool tensor on the same device, so
            # the `|` above and the masks below keep working
            self.sample_done = torch.ones_like(self.sample_done)
        print("Done:", self.sample_done)

        # Only tokens appended while a sequence is still running are kept by
        # output_mask (and thus scored), so only count perturbations on those.
        still_running = torch.logical_not(self.sample_done)
        self.num_perturb = self.num_perturb + torch.where(action & still_running, 1, 0)

        self.output_mask = torch.cat(
            [self.output_mask, still_running.unsqueeze(dim=1)],
            dim=-1)
        self.input_mask = torch.cat(
            [self.input_mask, torch.zeros_like(self.sample_done).int().unsqueeze(dim=1)],
            dim=-1)

        if torch.all(self.sample_done):
            # Blank out input positions and anything generated after a sequence
            # finished, then score both versions with the global detector.
            mask = self.input_mask | torch.logical_not(self.output_mask)
            perturbed_score = detector.infer(self.get_texts(mask))

            RL_num_perturb = self.num_perturb.clone().detach()

            unperturbed_score = detector.infer(self.get_texts_unperturbed(mask))

            F_GPT_Score_drop = 100. * (unperturbed_score - perturbed_score)

            # Reward: scaled score drop minus a quadratic cost on #perturbations
            reward += 100 * F_GPT_Score_drop
            reward -= 0.01 * RL_num_perturb * RL_num_perturb / 2

        info = {"TimeLimit.truncated": self.zero_tensor.clone().detach().bool().to(self.args.env_device),
                "F_GPT_Score_drop": F_GPT_Score_drop,
                "last_perturbed_score": perturbed_score,
                "last_unperturbed_score": unperturbed_score,
                "RL_num_perturb": RL_num_perturb,
                "last_reward": reward,
                }

        # If your environment does not have a concept of truncation, you can set truncated to the same value as done
        truncated = self.sample_done.bool()
        return obs, reward, self.sample_done, truncated, info

    def seed(self, seed=None):
        """Store ``seed`` (legacy Gym API); it is not used elsewhere in this class."""
        self._seed = seed


### Manual RL Policy

In [None]:
def manual_policy(env: LMEnv, threshold = 0.45, num_samples = 1):
    """Evaluate a hand-written threshold policy on ``env``.

    At every step, batch elements whose observation feature ``obs[:, 0, 0]``
    exceeds ``threshold`` are perturbed. Each of the ``num_samples`` episodes
    starts with ``env.reset(random=False)``. Per-step rewards are summed over the
    episode, and the batch mean of that sum is recorded.

    :param env: environment to roll out
    :param threshold: perturb where ``obs[:, 0, 0] > threshold``
    :param num_samples: number of episodes to run
    :return: list of per-episode mean rewards as Python floats
    """
    # Use the env's own configuration rather than the notebook-global `args`.
    device = env.args.env_device
    rewards = []

    pbar = tqdm.tqdm(range(num_samples))
    for _ in pbar:
        done = False
        reward = 0.
        obs, _ = env.reset(random=False)
        while not done:
            perturb = torch.as_tensor(obs[:, 0, 0] > threshold).bool().to(device)
            obs, local_reward, local_done, _, _ = env.step(perturb)
            done = torch.all(local_done)
            reward += local_reward

        pbar.set_description(f"Reward: {reward}")
        # .item() so numpy statistics operate on plain floats, not 0-d tensors
        rewards.append(reward.mean().item())
    print("Rewards Mean: ", np.mean(rewards), "Std: ", np.std(rewards))
    return rewards



In [None]:
# import sys
# sys.argv[0] = "first_arg"

In [None]:
import argparse
parser = argparse.ArgumentParser()


parser.add_argument("--max_sample_tokens", default=150, type=int)

parser.add_argument("--total_timesteps", default=1.5E5, type=int)
#1.5E5

parser.add_argument("--data_items", default="127,733,55,953,469,628,793,511", type=str)

parser.add_argument("--batch_size", default=8, type=int)

parser.add_argument("--random_seed", default=42, type=int)

parser.add_argument("--dataset", default="xsum", type=str)

parser.add_argument("--n_train", default=1000, type=int)

parser.add_argument("--topK_logistics", default=10, type=int)
parser.add_argument("--perturb_ranking", default=3, type=int)
parser.add_argument("--sampling_mode", default="likelihood", type=str)
parser.add_argument("--obs_dim", default=1, type=int)
parser.add_argument("--random", action='store_true', default=False)

parser.add_argument("--model_name", default="med", type=str)
parser.add_argument("--env_device", default="cuda", type=str)

parser.add_argument("--algorithm", default="PPO", type=str)

parser.add_argument("--tb_folder", default="./tensorboard_log", type=str)


parser.add_argument("--inference", default=False, type=bool)
parser.add_argument("--save", default=True, type=bool)


parser.add_argument("--rule_based_penalty", action='store_true', default=False)

parser.add_argument("--RL_model_name", default="raw", type=str)

parser.add_argument("--retrain_from", default="", type=str)
parser.add_argument("--cross_Sentence", action='store_true', default=False)
parser.add_argument('-f')

args = parser.parse_args()



In [None]:
# env = LMEnv(args=args)


In [None]:
# env.input_mask.shape
# env.input_mask[0]

In [None]:
# env.get_texts_unperturbed()

In [None]:
# manual_policy(env)


In [None]:
# env.num_perturb

In [None]:
# mask = env.input_mask | torch.logical_not(env.output_mask)
# mask.shape

In [None]:
# input_ids = env.input_ids.clone().detach()
# input_ids.shape

In [None]:
# mask

In [None]:
# env.data_items

In [None]:
# input_ids[mask==1] = env.tok.pad_token_id
# input_ids

In [None]:
# env.tok.batch_decode(env.input_ids, skip_special_tokens=True)

In [None]:
# sampled_texts = env.get_texts(mask=mask)

In [None]:
# perturbed_texts = env.get_texts_unperturbed(mask=env.input_mask | torch.logical_not(env.output_mask))

In [None]:
# sampled_texts

In [None]:
# detector.infer(perturbed_texts[3])

In [None]:
# perturbed_texts

In [None]:
# env.attention_mask[0]

In [None]:
# env.input_mask[0]

In [None]:
# env.output_mask[0]

In [None]:
# env.tok.batch_decode(env.input_ids, skip_special_tokens=True)[0]

In [None]:
# sampled_texts[0]

In [None]:
# perturbed_texts[0]

In [None]:
# env.input_ids_unperturbed[0]

In [None]:
# env.attention_mask_unperturbed[0]

In [None]:
# env.sample_done

## RL Vectorized Environment

In [None]:
import warnings
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Sequence, Type

import gymnasium as gym
import numpy as np

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
from stable_baselines3.common.vec_env.patch_gym import _patch_env
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info


class MyVecEnv(VecEnv):
    """Stable-Baselines3 ``VecEnv`` adapter around one internally batched ``LMEnv``.

    ``LMEnv`` advances ``batch_size`` sequences per call, so each batch element is
    exposed to SB3 as one sub-environment. Actions are converted to a bool tensor
    on ``args.env_device``. Rewards and dones returned by ``LMEnv.step`` are moved
    to NumPy, and its batched info dict is split into one dict per element. The
    whole batch is reset together, and only once every element is done.

    :param lm_env: the batched language-model environment to wrap
    :param random: forwarded as ``random=`` to ``lm_env.reset`` by :meth:`reset`
    """

    actions: np.ndarray

    def __init__(self, lm_env:LMEnv, random):
        # self.env must be set before VecEnv.__init__, which may query the wrapped
        # env through get_attr (e.g. for render_mode).
        self.env = lm_env
        super().__init__(self.env.batch_size,
                         self.env.observation_space,
                         self.env.action_space)
        self.num_envs = self.env.batch_size
        obs_space = self.env.observation_space
        self.keys, shapes, dtypes = obs_space_info(obs_space)

        self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys])
        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
        self.buf_infos: List[Dict[str, Any]] = [{} for _ in range(self.num_envs)]
        self.metadata = self.env.metadata
        self.counter = 0  # number of step_wait calls; used for logging only
        self.random = random

    def step_async(self, actions: np.ndarray) -> None:
        """Store the actions applied by the next step_wait / step_without_reset call."""
        self.actions = actions

    def _env_step(self):
        """Apply ``self.actions`` to the wrapped env and fill the reward/done/info buffers.

        :return: the raw observation returned by ``LMEnv.step``
        """
        obs, self.buf_rews, terminated, truncated, infos = self.env.step(
            torch.tensor(self.actions).bool().to(self.env.args.env_device)
        )
        # LMEnv.step returns truncated with the same values as terminated, so
        # terminated alone marks an element as done.
        self.buf_dones = terminated
        # Gym >= 0.26 style API (https://github.com/openai/gym/issues/3102):
        # split the batched info dict into one dict per sub-environment.
        for i in range(self.env.batch_size):
            self.buf_infos[i] = {key: value[i] for key, value in infos.items()}
        return obs

    def _step_result(self) -> VecEnvStepReturn:
        """Return the buffered (obs, rewards, dones, infos) as NumPy copies for SB3."""
        return (self._obs_from_buf(),
                np.copy(self.buf_rews.cpu()),
                np.copy(self.buf_dones.bool().cpu()),
                deepcopy(self.buf_infos))

    def step_without_reset(self):
        """Step the wrapped env but never auto-reset, even when every element is done.

        Intended for evaluation rollouts whose final sequences must stay
        inspectable after the episode ends.
        """
        obs = self._env_step()
        self._save_obs(obs)
        return self._step_result()

    def step_wait(self) -> VecEnvStepReturn:
        """Step the wrapped env. When every element is done, record terminal
        observations and reset the whole batch."""
        self.counter += 1
        print("VecEnv Step: ", self.counter)
        obs = self._env_step()

        if torch.all(self.buf_dones):
            # save final observation where user can get it, then reset
            print("Resetting 1")
            for i in range(self.env.batch_size):
                self.buf_infos[i]["terminal_observation"] = obs[i]
            obs, self.reset_infos = self.env.reset(random=False)

            print(np.copy(self.buf_dones.bool().cpu()))
        self._save_obs(obs)
        return self._step_result()

    def reset(self) -> VecEnvObs:
        """Reset the whole batch with ``random=self.random`` and return the first observation."""
        print("Resetting 2")
        obs, self.reset_infos = self.env.reset(seed=self._seeds,
                                               random=self.random)
        print(obs.shape)
        self._save_obs(obs)

        # Seeds and options are only used once
        self._reset_seeds()
        self._reset_options()
        return self._obs_from_buf()

    def close(self) -> None:
        """Close the wrapped env."""
        print("Close")
        self.env.close()

    def get_images(self) -> Sequence[Optional[np.ndarray]]:
        """Return rendered images; one ``None`` per sub-env unless render_mode is ``rgb_array``."""
        print("Get images")
        if self.render_mode != "rgb_array":
            warnings.warn(
                f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images."
            )
            # batch_size is an int: iterate over a range, not the int itself
            return [None for _ in range(self.env.batch_size)]
        # the wrapped env is self.env; a bare `env` is not defined in this scope
        return [self.env.render()]  # type: ignore[misc]

    def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
        """
        Gym environment rendering. If there are multiple environments then
        they are tiled together in one image via ``BaseVecEnv.render()``.

        :param mode: The rendering type.
        """
        print("Render")
        return super().render(mode=mode)

    def _save_obs(self, obs: VecEnvObs) -> None:
        """Copy each batch element of ``obs`` into the observation buffers."""
        for key in self.keys:
            for dim in range(self.env.batch_size):
                if key is None:
                    self.buf_obs[key][dim] = obs[dim]
                else:
                    self.buf_obs[key][dim] = obs[key][dim]  # type: ignore[call-overload]

    def _obs_from_buf(self) -> VecEnvObs:
        """Return a copy of the buffered observations in the observation-space layout."""
        return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        """Return attribute from vectorized environment (see base class).

        All sub-envs share the single wrapped env, so the same value is repeated.
        """
        print("get_attr ", attr_name, indices)
        return [getattr(self.env, attr_name) for _ in self._get_indices(indices)]

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        """Set attribute inside vectorized environments (see base class).

        ``indices`` is ignored: the single wrapped env is always updated.
        """
        print("set_attr ", attr_name, value, indices)
        setattr(self.env, attr_name, value)

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
        """Call instance methods of vectorized environments.

        The method is invoked on the single wrapped env once per selected index.
        """
        print("env_method ", method_name, indices)
        return [getattr(self.env, method_name)(*method_args, **method_kwargs) for _ in self._get_indices(indices)]

    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
        """Check if worker environments are wrapped with a given wrapper"""
        # Import here to avoid a circular import
        from stable_baselines3.common import env_util
        return [env_util.is_wrapped(self.env, wrapper_class) for _ in self._get_indices(indices)]


In [None]:

from stable_baselines3 import PPO, DQN
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, BaseCallback
from stable_baselines3.common.utils import obs_as_tensor, safe_mean, set_random_seed
from stable_baselines3.common.monitor import Monitor

from stable_baselines3.common.vec_env.subproc_vec_env import  SubprocVecEnv, _flatten_obs
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv

from stable_baselines3.common.env_checker import check_env

def init_env_for_agent_training(args):
    """Create a batched ``LMEnv`` from ``args`` and wrap it for SB3 training.

    :param args: parsed command-line namespace; ``args.random`` is forwarded to
        the ``MyVecEnv`` wrapper
    :return: a ``MyVecEnv`` around a freshly constructed ``LMEnv``
    """
    return MyVecEnv(LMEnv(args=args), random=args.random)

############################################

class TensorboardCallback(BaseCallback):
    """Log episode-level custom metrics to TensorBoard under ``rollout/<key>``.

    Each key in ``LOGGED_KEYS`` is averaged with ``safe_mean`` over the entries of
    ``self.model.ep_info_buffer`` that contain it. Entries lacking a key, such as
    the standard ``Monitor`` episode dicts, are skipped instead of raising KeyError.

    NOTE(review): SB3 fills ``ep_info_buffer`` from ``info["episode"]``; confirm
    these custom keys actually reach that dict, otherwise nothing is logged.
    """

    # info keys that are averaged and logged as rollout/<key>
    LOGGED_KEYS = ("F_GPT_Score_drop", "RL_num_perturb", "last_reward")

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        ep_infos = self.model.ep_info_buffer
        if len(ep_infos) > 0 and len(ep_infos[0]) > 0:
            for key in self.LOGGED_KEYS:
                values = [ep_info[key] for ep_info in ep_infos if key in ep_info]
                if values:
                    self.logger.record(f"rollout/{key}", safe_mean(values))
        # returning True keeps training running
        return True

# Module-level callback instance; the training cell creates a fresh one before calling learn().
cust_callback = TensorboardCallback()


In [None]:


import datetime

############################################
if __name__ == "__main__":
    timestamp = datetime.datetime.now().strftime("%m%d_%H%M%S")

    # Run name encodes sentence mode, data items, reward shaping and start time.
    reward_describe = f"RP_{args.rule_based_penalty}"
    prefix = "Cross_CT_" if args.cross_Sentence else "CT_"

    args.RL_model_name = f"{prefix}S_{args.data_items}_{reward_describe}_{timestamp}"

    tb_log_name = f"{args.algorithm}/{args.RL_model_name}"

    # SB3 writes TensorBoard runs to "<tb_log_name>_<run index>"; a fresh
    # timestamped name gets index 1.
    cpt_save_path = os.path.join(args.tb_folder, tb_log_name + "_1", "model_checkpoints/")

    # save_freq counts vec_env.step calls (each call advances every batch element).
    checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=cpt_save_path)

    cust_callback = TensorboardCallback()

    # argparse does not apply `type=int` to non-string defaults, so coerce here.
    total_timesteps = int(args.total_timesteps)

    ###########################################

    vec_env = init_env_for_agent_training(args=args)

    if args.algorithm == "PPO":
        if args.retrain_from != "":
            print(args.retrain_from)
            # PPO.load is a classmethod; env= attaches the new vec_env directly,
            # so no throwaway model has to be built first.
            model = PPO.load(args.retrain_from, env=vec_env)
            print("Reload model success")
        else:
            model = PPO("MlpPolicy", vec_env, verbose=1,
                        tensorboard_log=args.tb_folder)

        model.learn(total_timesteps=total_timesteps,
                    tb_log_name=tb_log_name,
                    callback=[cust_callback, checkpoint_callback])
        if args.save:
            model.save(f"{args.algorithm}/{args.RL_model_name}_T_{total_timesteps}.pt")
    else:
        raise NotImplementedError

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Done: tensor([False, False, False, False, False, False, False, False],
       device='cuda:0')
VecEnv Step:  19286
Step: 140
Done: tensor([False, False, False, False, False, False, False, False],
       device='cuda:0')
VecEnv Step:  19287
Step: 141
Done: tensor([False, False, False, False, False, False, False, False],
       device='cuda:0')
VecEnv Step:  19288
Step: 142
Done: tensor([False, False, False, False, False, False, False, False],
       device='cuda:0')
VecEnv Step:  19289
Step: 143
Done: tensor([False, False, False, False, False, False, False, False],
       device='cuda:0')
VecEnv Step:  19290
Step: 144
Done: tensor([False, False, False, False, False, False, False, False],
       device='cuda:0')
VecEnv Step:  19291
Step: 145
Done: tensor([False, False, False, False, False, False, False, False],
       device='cuda:0')
VecEnv Step:  19292
Step: 146
Done: tensor([False, False, False, False, False, False, Fals

In [None]:
# Save the trained policy again under a name without the timestep suffix
# (SB3 appends ".zip"); the evaluation cell below loads it from this path.
model.save(f"{args.algorithm}/{args.RL_model_name}")

In [None]:
# Roll out one evaluation episode with the saved policy. step_without_reset is
# used so the finished sequences stay in vec_env.env for inspection afterwards.
model = PPO.load(f"{args.algorithm}/{args.RL_model_name}")
obs = vec_env.reset()
while True:
    action, _ = model.predict(obs)
    vec_env.step_async(action)
    obs, reward, done, info = vec_env.step_without_reset()
    print(done)
    done = done.all()
    if done:
        break


Resetting 2
Reset All begins... False tensor([127, 733,  55, 953, 469, 628, 793, 511], device='cuda:0')
Max tokens: 150
Reset All ends!
(8, 1, 10)
Step: 127
Done: tensor([False, False, False, False, False, False, False, False],
       device='cuda:0')
[False False False False False False False False]
Step: 128
Done: tensor([False, False, False, False, False, False, False, False],
       device='cuda:0')
[False False False False False False False False]
Step: 129
Done: tensor([False, False, False, False, False, False, False, False],
       device='cuda:0')
[False False False False False False False False]
Step: 130
Done: tensor([False, False, False, False, False, False, False, False],
       device='cuda:0')
[False False False False False False False False]
Step: 131
Done: tensor([False, False, False, False, False, False, False, False],
       device='cuda:0')
[False False False False False False False False]
Step: 132
Done: tensor([False, False, False, False, False, False, False, False

In [None]:

lm_env = vec_env.env
# Positions to blank before decoding: those marked in input_mask, plus those
# generated after a sequence was done (output_mask is False there).
scoring_mask = lm_env.input_mask | torch.logical_not(lm_env.output_mask)
texts = lm_env.get_texts(mask=scoring_mask)
texts_unperturbed = lm_env.get_texts_unperturbed(mask=scoring_mask)

for i, text in enumerate(texts):
    print(i, text)
print("======")
for i, text in enumerate(texts_unperturbed):
    print(i, text)

0 This photo was taken just outside Glascoil in Belfast's West Belfast on Sunday night. The underneath photo is from
1  The children's father - Darren, a dad for 54 child years has died + Neither failed to put on a
2 This road, near the 4tents on suburb of Middlegate, also iced west-to west between Fel
3 Udia Estebans was picked as replacement. London Irish suffered a major scalp yesterday, losing 23 of the
4 S United strangled City for two goals in the first 10-odd mins!!!!Please enable javascript in the
5 I wouldn't be here if I walked away from all this today …… *******************If you are interested in
6 Mr Harrington got his green card and now he lives with his ex girlfriend Sandra  street. Her attempts to persuade
7 This enquired into the issue and was unable or not interested to comment and Ã¢¬©"'s
0 AvenueGreater Belfast are working under tight security banning anyone under 17 from seeing videos of planned events. The 
1  Wiltshire Fire and Rescue Service Loads of evidence fo

In [None]:
from google.colab import drive
drive.mount('/content/drive')

!mv -r PPO /content/drive/MyDrive/PPO
!mv -r tensorboard_log /content/drive/MyDrive/tensorboard_log

Mounted at /content/drive
mv: invalid option -- 'r'
Try 'mv --help' for more information.
mv: invalid option -- 'r'
Try 'mv --help' for more information.


In [None]:
# Detector scores of the perturbed vs. unperturbed texts decoded above.
perturbed_scores = detector.infer(texts)
unperturbed_scores = detector.infer(texts_unperturbed)
print(perturbed_scores, unperturbed_scores)

tensor([0.0000, 0.0495, 0.0000, 0.0400, 0.4300, 0.0200, 0.0000, 0.0000],
       device='cuda:0') tensor([0.0300, 0.0400, 0.0700, 0.0000, 0.0000, 0.0200, 0.0400, 0.0000],
       device='cuda:0')


In [None]:
# mask = env.input_mask | torch.logical_not(env.output_mask)
# input_ids = env.input_ids_unperturbed.clone().detach()
# # print(input_ids)
# if mask is not None:
#   input_ids[mask == 1] = env.tok.pad_token_id
# input_ids[0]

In [None]:
# env.tok("Hello,I have looked very much forward, Do you hand me the gun this evening? I have no money to hand.They tried and tried to have him arrested for a murder so presumably the PCSO-ELMUC have")

In [None]:
# val = [220,  4841,   834,   220,
#          4841,  1427,  5211]
# env.tok.decode(val[6])

In [None]:
# Sanity check of the rule-based penalty masks used in LMEnv.step, on a fixed toy
# observation. Uses demo_* names so the rollout's `action`/`obs` are not clobbered.
PENALTY_THRESHOLD = 0.55
demo_action = torch.tensor([True, False, True, False])
demo_obs = torch.tensor([0.78, 0.90, 0.30, 0.40]).view(4, 1, 1)
demo_head = demo_obs[:, -1, 0]
# Same comparisons as LMEnv.step: perturbing at <= threshold is penalized, and
# not perturbing at > threshold is penalized.
demo_penalized_low = demo_action & (demo_head <= PENALTY_THRESHOLD)
demo_penalized_high = torch.logical_not(demo_action) & (demo_head > PENALTY_THRESHOLD)
print(demo_head)
demo_penalized_low, demo_penalized_high

tensor([ 0.7782,  1.5660, -1.0054])


tensor([ True, False, False])