In [1]:
%pip install datasets detoxify

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import pandas as pd
from datasets import load_dataset, concatenate_datasets


dataset = load_dataset("OxAISH-AL-LLM/wiki_toxic", split = "train")

ds = load_dataset("textdetox/multilingual_toxic_spans", split="en")

dataset = dataset.rename_columns({"comment_text":"text"}).remove_columns(["id"])

# Filter the dataset by label
pos_samples = dataset.filter(lambda x: x["label"] == 1)
neg_samples = dataset.filter(lambda x: x["label"] == 0)

# Determine the minimum count between the two classes
min_count = min(len(pos_samples), len(neg_samples))

# Shuffle and select an equal number of examples from each subset
balanced_pos = pos_samples.shuffle(seed=42).select(range(min_count))
balanced_neg = neg_samples.shuffle(seed=42).select(range(min_count))

# Concatenate the two balanced subsets and shuffle the final dataset
balanced_dataset = concatenate_datasets([balanced_pos, balanced_neg, ds]).shuffle(seed=42)

dataset = balanced_dataset.shuffle(seed=42)




In [18]:
dataset = dataset.remove_columns(["Sentence", "Negative Connotations"])

In [20]:
dataset

Dataset({
    features: ['text', 'label'],
    num_rows: 25868
})

In [21]:
dataset = dataset.filter(lambda example: example["text"] is not None)
dataset

Filter:   0%|          | 0/25868 [00:00<?, ? examples/s]

Dataset({
    features: ['text', 'label'],
    num_rows: 25868
})

In [22]:
import re
import unicodedata
from datasets import load_dataset

# Define the cleaning function to apply on each example
def clean_text(example):
    # Adjust "text" to match the key in your dataset
    text = example["text"]
    
    # Remove extra whitespace and newlines
    text = re.sub(r'\s+', ' ', text)
    
    # Normalize Unicode characters (decomposes combined characters)
    text = unicodedata.normalize('NFKD', text)
    
    # Optionally remove non-ASCII characters (e.g., unusual symbols)
    text = text.encode('ascii', 'ignore').decode('utf-8')
    
    # Trim any leading/trailing whitespace
    example["text"] = text.strip()
    return example

dataset = dataset.map(clean_text)

# Check a cleaned example
(dataset[90])


Map:   0%|          | 0/25868 [00:00<?, ? examples/s]

{'text': 'WTF???? WHAT DID I DO??? FEMALE EJACULATION IS FUCKING GROSS!',
 'label': 1}

In [23]:
from detoxify import Detoxify
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"


toxicity_model = Detoxify("original-small", device = device)

def compute_reward(examples):
    text = examples["text"]
    results = toxicity_model.predict(text)
    toxicity_score = results["toxicity"]

    reward = abs(toxicity_score)
    examples["reward"] = reward
    return examples

dataset = dataset.map(compute_reward, batched=False)

Map:   0%|          | 0/25868 [00:00<?, ? examples/s]

In [24]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def tokenize_fn(examples):
    return tokenizer(examples["text"], truncation=True, padding = "max_length", max_length=128,)

data = dataset.map(tokenize_fn, batched=True)

Map:   0%|          | 0/25868 [00:00<?, ? examples/s]

In [25]:
# Add RTG as a separate feature with proper shape
def add_rtg(examples):
    batch_size = len(examples["input_ids"])
    examples["rtg"] = []

    for i in range(batch_size):
        # Create a tensor of the same reward value for each token position
        seq_length = len(examples["input_ids"][i])
        reward_value = examples["reward"][i]
        examples["rtg"].append([reward_value] * seq_length)

    return examples

data = data.map(add_rtg, batched=True)

Map:   0%|          | 0/25868 [00:00<?, ? examples/s]

In [26]:
from transformers import GPT2LMHeadModel, GPT2Config
import torch

class DecisionTransformer(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)
        # RTG embedding layer to convert scalar rewards to embedding dimension
        self.rtg_embedding = torch.nn.Linear(1, config.n_embd)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        rtg=None,
        labels=None,
        return_dict=None,
        **kwargs
    ):
        if rtg is None:
            raise ValueError("RTG values must be provided")

        # Remove unexpected keyword arguments before passing to the transformer
        kwargs.pop("num_items_in_batch", None)

        # Reshape RTG for embedding: (batch_size, seq_len) -> (batch_size, seq_len, 1)
        rtg = rtg.unsqueeze(-1).float()

        # Embed RTG values
        rtg_emb = self.rtg_embedding(rtg)

        # Get token embeddings
        token_emb = self.transformer.wte(input_ids)

        # Combine token embeddings with RTG embeddings
        inputs_embeds = token_emb + rtg_emb

        # Use the combined embeddings for the transformer
        outputs = self.transformer(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            # Calculate loss using cross entropy
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        if return_dict:
            return {"loss": loss, "logits": lm_logits}
        else:
            return (loss, lm_logits) if loss is not None else lm_logits



In [27]:
class RTGDataCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        input_ids = torch.stack([torch.tensor(x["input_ids"]) for x in batch])
        attention_mask = torch.stack([torch.tensor(x["attention_mask"]) for x in batch])
        rtg = torch.stack([torch.tensor(x["rtg"]) for x in batch])
        labels = input_ids.clone()

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "rtg": rtg,
            "labels": labels,
        }

In [28]:
#train-test split
data = data.train_test_split(test_size=0.1)

In [71]:
from transformers import TrainingArguments, Trainer

# Initialize model
config = GPT2Config.from_pretrained("gpt2")
model = DecisionTransformer(config)
model.resize_token_embeddings(len(tokenizer))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Training arguments
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=4,
    logging_steps=100,
    save_steps=10000,
    eval_strategy="steps",
    eval_steps=1000,
    load_best_model_at_end=True,
    remove_unused_columns=False,  # Important to keep 'rtg'
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=data["train"],
    eval_dataset=data["test"],
    data_collator=RTGDataCollator(tokenizer),
)


# Train!
trainer.train()

Step,Training Loss,Validation Loss
1000,3.1673,3.223577
2000,3.2312,3.152618
3000,3.0977,3.089689
4000,2.9192,3.047583
5000,2.7822,3.006578
6000,2.9421,2.960296
7000,2.7154,2.933096
8000,2.7195,2.90502
9000,2.7612,2.887602
10000,2.7831,2.865292


Could not locate the best model at ./output/checkpoint-23000/pytorch_model.bin, if you are running a distributed training on multiple nodes, you should activate `--save_on_each_node`.


TrainOutput(global_step=23284, training_loss=2.83184582202874, metrics={'train_runtime': 3138.7269, 'train_samples_per_second': 29.669, 'train_steps_per_second': 7.418, 'total_flos': 6083249650532352.0, 'train_loss': 2.83184582202874, 'epoch': 4.0})

In [74]:
def generate_conditioned_text(model, tokenizer, prompt, target_rtg, max_length=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Create RTG tensor with the target value
    rtg = torch.tensor([[target_rtg] * input_ids.shape[1]], dtype=torch.float).to(device)

    # Generate text
    seq_length = input_ids.shape[1]
    for _ in range(max_length):
        # Forward pass
        with torch.no_grad():
            rtg_current = rtg[:, :seq_length]
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                rtg=rtg_current,
                return_dict=True
            )

        # Get next token prediction
        next_token_logits = outputs["logits"][:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

        # Append prediction to input
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

        new_rtg = torch.tensor([[target_rtg]], dtype=torch.float).to(device)
        rtg = torch.cat([rtg, new_rtg], dim=1)

        # Stop if EOS token
        if next_token.item() == tokenizer.eos_token_id:
            break

        seq_length += 1

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


prompt = "You are"
less_toxic_text = generate_conditioned_text(model, tokenizer, prompt, target_rtg=0.9)
more_toxic_text = generate_conditioned_text(model, tokenizer, prompt, target_rtg=0.0)

print("More Toxic Text:", less_toxic_text)
print("Less Toxic Text:", more_toxic_text)

More Toxic Text: You are a life. I'm a idiot.
Less Toxic Text: You are a article. I'm a article to be a article.


In [75]:
import torch.nn.functional as F

def generate_conditioned_text2(model, tokenizer, prompt, target_rtg, max_length=50, temperature=1.0, top_k=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Create RTG tensor with the target value for each token in the prompt
    rtg = torch.tensor([[target_rtg] * input_ids.shape[1]], dtype=torch.float).to(device)

    seq_length = input_ids.shape[1]
    for _ in range(max_length):
        with torch.no_grad():
            # Slice rtg to match current sequence length
            rtg_current = rtg[:, :seq_length]
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                rtg=rtg_current,
                return_dict=True
            )

        # Get next token logits and apply temperature scaling
        next_token_logits = outputs["logits"][:, -1, :] / temperature

        # Apply top-k filtering
        top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k)
        probabilities = F.softmax(top_k_logits, dim=-1)
        next_token = top_k_indices[0, torch.multinomial(probabilities, num_samples=1)]

        # Append the predicted token to input_ids and update attention mask

        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

        # Append the target reward for the new token
        new_rtg = torch.tensor([[target_rtg]], dtype=torch.float).to(device)
        rtg = torch.cat([rtg, new_rtg], dim=1)

        # Stop if EOS token is generated
        if next_token.item() == tokenizer.eos_token_id:
            break

        seq_length += 1

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

less_toxic_text = generate_conditioned_text2(model, tokenizer, prompt, target_rtg=1)
more_toxic_text = generate_conditioned_text2(model, tokenizer, prompt, target_rtg=0.0)
avg_toxic = generate_conditioned_text2(model,tokenizer, prompt, target_rtg=0.5 )

print("More Toxic Text:", less_toxic_text)
print("Less Toxic Text:", more_toxic_text)
print("Avg Toxic Text:", avg_toxic)


More Toxic Text: You are an fucking fucking asshole. Go I will going off that you are a dumb to, you like shit, you are my dumb to get a life.
Less Toxic Text: You are. So an problem, I have a page with good article was not be a page in a edit (for source is a problem of being the talk. I hope I would be not find the lot of the encyclopedia, I see I hope the link
Avg Toxic Text: You are a page about I am about me now so the article is my edits on a user, it isnipedes on me.


In [76]:
model.save_pretrained("Ashed00/Controlled_toxicity")

In [78]:
tokenizer.save_pretrained("Ashed00/Controlled_toxicity")

('Ashed00/Controlled_toxicity/tokenizer_config.json',
 'Ashed00/Controlled_toxicity/special_tokens_map.json',
 'Ashed00/Controlled_toxicity/vocab.json',
 'Ashed00/Controlled_toxicity/merges.txt',
 'Ashed00/Controlled_toxicity/added_tokens.json',
 'Ashed00/Controlled_toxicity/tokenizer.json')