In [1]:
%%capture
!pip install datasets detoxify

In [12]:
from datasets import load_dataset, Dataset
import pandas as pd

dataset = load_dataset("heegyu/toxic_conversations_balanced", split = "train")
dataset = dataset.shuffle(seed=42)
dataset = dataset.select(range(100000))




Repo card metadata block was not found. Setting CardData to empty.


In [13]:
dataset

Dataset({
    features: ['text', 'label', 'label_text'],
    num_rows: 100000
})

In [16]:
dataset[90]

{'text': 'is it going to be a marijuana grow op, homeless shelter, low income housing or maintenance facility for EMX?',
 'label': 0,
 'label_text': 'not toxic'}

In [17]:
dataset = dataset.filter(lambda example: example["text"] is not None)
dataset

Filter:   0%|          | 0/100000 [00:00<?, ? examples/s]

Dataset({
    features: ['text', 'label', 'label_text'],
    num_rows: 100000
})

In [18]:
from detoxify import Detoxify
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"


toxicity_model = Detoxify("original-small", device = device)

def compute_reward(examples):
    text = examples["text"]
    results = toxicity_model.predict(text)
    toxicity_score = results["toxicity"]

    reward = abs(toxicity_score)
    examples["reward"] = reward
    return examples

dataset = dataset.map(compute_reward, batched=False)

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

Downloading: "https://github.com/unitaryai/detoxify/releases/download/v0.1.2/original-albert-0e1d6498.ckpt" to /root/.cache/torch/hub/checkpoints/original-albert-0e1d6498.ckpt
100%|██████████| 44.6M/44.6M [00:00<00:00, 115MB/s] 


config.json:   0%|          | 0.00/684 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/760k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

In [19]:
dataset = dataset.remove_columns(["label_text"])

In [20]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def tokenize_fn(examples):
    return tokenizer(examples["text"], truncation=True, padding = "max_length", max_length=128,)

data = dataset.map(tokenize_fn, batched=True)

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

In [21]:
# Add RTG as a separate feature with proper shape
def add_rtg(examples):
    batch_size = len(examples["input_ids"])
    examples["rtg"] = []

    for i in range(batch_size):
        # Create a tensor of the same reward value for each token position
        seq_length = len(examples["input_ids"][i])
        reward_value = examples["reward"][i]
        examples["rtg"].append([reward_value] * seq_length)

    return examples

data = data.map(add_rtg, batched=True)

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

In [26]:
from transformers import GPT2LMHeadModel, GPT2Config
import torch

class DecisionTransformer(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)
        # RTG embedding layer to convert scalar rewards to embedding dimension
        self.rtg_embedding = torch.nn.Linear(1, config.n_embd)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        rtg=None,
        labels=None,
        return_dict=None,
        **kwargs
    ):
        if rtg is None:
            raise ValueError("RTG values must be provided")

        # Remove unexpected keyword arguments before passing to the transformer
        kwargs.pop("num_items_in_batch", None)

        # Reshape RTG for embedding: (batch_size, seq_len) -> (batch_size, seq_len, 1)
        rtg = rtg.unsqueeze(-1).float()

        # Embed RTG values
        rtg_emb = self.rtg_embedding(rtg)

        # Get token embeddings
        token_emb = self.transformer.wte(input_ids)

        # Combine token embeddings with RTG embeddings
        inputs_embeds = token_emb + rtg_emb

        # Use the combined embeddings for the transformer
        outputs = self.transformer(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            # Calculate loss using cross entropy
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        if return_dict:
            return {"loss": loss, "logits": lm_logits}
        else:
            return (loss, lm_logits) if loss is not None else lm_logits



In [27]:
class RTGDataCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        input_ids = torch.stack([torch.tensor(x["input_ids"]) for x in batch])
        attention_mask = torch.stack([torch.tensor(x["attention_mask"]) for x in batch])
        rtg = torch.stack([torch.tensor(x["rtg"]) for x in batch])
        labels = input_ids.clone()

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "rtg": rtg,
            "labels": labels,
        }

In [24]:
#train-test split
data = data.train_test_split(test_size=0.2)

In [28]:
from transformers import TrainingArguments, Trainer

# Initialize model
config = GPT2Config.from_pretrained("gpt2")
model = DecisionTransformer(config)
model.resize_token_embeddings(len(tokenizer))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Training arguments
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    #num_train_epochs=1,
    max_steps = 7000, #kaggle out of storage
    logging_steps=100,
    save_steps=5000,
    eval_strategy="steps",
    eval_steps=1000,
    load_best_model_at_end=True,
    remove_unused_columns=False,  # Important to keep 'rtg'
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=data["train"],
    eval_dataset=data["test"],
    data_collator=RTGDataCollator(tokenizer),
)


# Train!
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbt22ece049[0m ([33mbt21ece003-nit-nagpur[0m). Use [1m`wandb login --relogin`[0m to force relogin




Step,Training Loss,Validation Loss
700,3.2713,3.26656
1400,3.2597,3.164512
2100,3.143,3.092297
2800,3.1424,3.051057
3500,3.0424,3.02236
4200,2.9958,2.989415
4900,3.0125,2.956045
5600,3.0257,2.932941
6300,2.9529,2.919227
7000,2.9419,2.914364


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


TrainOutput(global_step=7000, training_loss=3.096241934640067, metrics={'train_runtime': 4530.3703, 'train_samples_per_second': 12.361, 'train_steps_per_second': 1.545, 'total_flos': 3658154508288000.0, 'train_loss': 3.096241934640067, 'epoch': 0.7})

In [42]:
def generate_conditioned_text(model, tokenizer, prompt, target_rtg, max_length=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Create RTG tensor with the target value
    rtg = torch.tensor([[target_rtg] * input_ids.shape[1]], dtype=torch.float).to(device)

    # Generate text
    seq_length = input_ids.shape[1]
    for _ in range(max_length):
        # Forward pass
        with torch.no_grad():
            rtg_current = rtg[:, :seq_length]
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                rtg=rtg_current,
                return_dict=True
            )

        # Get next token prediction
        next_token_logits = outputs["logits"][:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

        # Append prediction to input
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

        new_rtg = torch.tensor([[target_rtg]], dtype=torch.float).to(device)
        rtg = torch.cat([rtg, new_rtg], dim=1)

        # Stop if EOS token
        if next_token.item() == tokenizer.eos_token_id:
            break

        seq_length += 1

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


prompt = "You are a "
less_toxic_text = generate_conditioned_text(model, tokenizer, prompt, target_rtg=0.9)
more_toxic_text = generate_conditioned_text(model, tokenizer, prompt, target_rtg=0.0)

print("More Toxic Text:", less_toxic_text)
print("Less Toxic Text:", more_toxic_text)

More Toxic Text: You are a                                                   
Less Toxic Text: You are a                                                   


In [None]:
import torch.nn.functional as F

def generate_conditioned_text2(model, tokenizer, prompt, target_rtg, max_length=50, temperature=1.0, top_k=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Create RTG tensor with the target value for each token in the prompt
    rtg = torch.tensor([[target_rtg] * input_ids.shape[1]], dtype=torch.float).to(device)

    seq_length = input_ids.shape[1]
    for _ in range(max_length):
        with torch.no_grad():
            # Slice rtg to match current sequence length
            rtg_current = rtg[:, :seq_length]
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                rtg=rtg_current,
                return_dict=True
            )

        # Get next token logits and apply temperature scaling
        next_token_logits = outputs["logits"][:, -1, :] / temperature

        # Apply top-k filtering
        top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
        probabilities = F.softmax(top_k_logits, dim=-1)
        next_token = top_k_indices[0, torch.multinomial(probabilities, num_samples=1)]

        # Append the predicted token to input_ids and update attention mask

        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

        # Append the target reward for the new token
        new_rtg = torch.tensor([[target_rtg]], dtype=torch.float).to(device)
        rtg = torch.cat([rtg, new_rtg], dim=1)

        # Stop if EOS token is generated
        if next_token.item() == tokenizer.eos_token_id:
            break

        seq_length += 1

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

less_toxic_text = generate_conditioned_text2(model, tokenizer, prompt, target_rtg=1)
more_toxic_text = generate_conditioned_text2(model, tokenizer, prompt, target_rtg=0.0)

print("More Toxic Text:", less_toxic_text)
print("Less Toxic Text:", more_toxic_text)