In [None]:
# Install PEFT into the local `venv` virtualenv; PeftModel is used below to attach the fine-tuned adapter.
# NOTE(review): version is unpinned — pin it (e.g. peft==X.Y.Z) so re-runs are reproducible.
!venv/bin/pip install peft

In [None]:
# bitsandbytes backs the 4-bit BitsAndBytesConfig load below; accelerate is needed for device_map="auto".
# NOTE(review): `datasets` is not used in the visible cells, and no versions are pinned.
!venv/bin/pip install bitsandbytes datasets accelerate

In [None]:
import os

from huggingface_hub import login

# Read the Hugging Face Hub token from the environment instead of hardcoding it:
# a literal token would be saved in the .ipynb file (and possibly its outputs).
# If HF_TOKEN is unset, login() prompts for the token interactively.
login(token=os.environ.get("HF_TOKEN"))

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from transformers import BitsAndBytesConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import os
import torch
# Identifiers and paths used by the cells below.
base_model_name = "CohereLabs/aya-expanse-8b"  # base checkpoint on the Hugging Face Hub
checkpoint_dir = "iproskurina/AYA-8b-expanse-tuned-all-data"  # adapter loaded with PeftModel.from_pretrained below
folder = "./drive/MyDrive/debias/"  # NOTE(review): not referenced in the visible cells
os.makedirs("offload", exist_ok=True)  # only the commented-out offload_folder variants use this directory

# 4-bit NF4 weights with nested (double) quantization; compute runs in float16.
quant_settings = dict(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
bnb_config = BitsAndBytesConfig(**quant_settings)

# device_map="auto" lets transformers/accelerate place the quantized layers.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
    quantization_config=bnb_config,
)


Loading checkpoint shards: 100%|██████████████████| 4/4 [00:30<00:00,  7.58s/it]


In [None]:
# Attach the fine-tuned LoRA adapter from `checkpoint_dir` to the 4-bit base model.
model = PeftModel.from_pretrained(base_model, checkpoint_dir)
# torch.compile is intentionally not applied. This notebook only calls model.generate(),
# and on a compiled OptimizedModule that attribute is forwarded to the original,
# uncompiled module. The wrapper therefore gave no speed-up, only an extra layer.
model.eval()


OptimizedModule(
  (_orig_mod): PeftModelForCausalLM(
    (base_model): LoraModel(
      (model): CohereForCausalLM(
        (model): CohereModel(
          (embed_tokens): Embedding(256000, 4096, padding_idx=0)
          (layers): ModuleList(
            (0-31): 32 x CohereDecoderLayer(
              (self_attn): CohereAttention(
                (q_proj): lora.Linear4bit(
                  (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=4096, out_features=8, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=8, out_features=4096, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
                  (lora_embedding_B): ParameterDict()
        

In [None]:
# Decoder-only model: pad on the LEFT so that, in a batch, every prompt ends exactly
# where generation starts (right padding would put pad tokens between prompt and answer).
tokenizer = AutoTokenizer.from_pretrained(base_model_name, padding_side="left")

In [None]:
def detoxify(input_text, max_new_tokens=100):
    """Rewrite one toxic text into a neutral one using the notebook's global `model`/`tokenizer`.

    Args:
        input_text: The toxic text to rewrite.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated continuation only (the prompt is not included), stripped of
        surrounding whitespace.
    """
    # Same instruction wording as the batched path in detoxify_file, so both paths
    # prompt the model identically.
    prompt = f"""
You are a text detoxification assistant. Your task is to rewrite toxic, offensive, or harmful text to make it non-toxic, respectful, and safe for all audiences.\n
Instructions:\n
Keep the original meaning and intent of the message.\n
Maintain the original language (e.g., English, French, etc.).\n
Make only the minimal necessary changes to remove any toxic, abusive, offensive, or inappropriate language.\n
You cannot hallucinate\n
Stop generating once the detoxified version is complete and neutral — do not add extra commentary, explanation, or continuation.\n
Toxic Text:\n
{input_text}\n
Detoxified Version:\n
"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Deterministic beam search. `temperature` was dropped: without do_sample=True it is
    # ignored by generate() and only triggers a warning.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            num_beams=5,
            early_stopping=True,
        )

    # generate() returns prompt + continuation. Slice off the prompt tokens instead of
    # splitting on a marker string: the previous marker ("## Detoxified text:") never
    # appeared in the prompt, so the whole prompt was returned as the answer.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()


In [None]:
import torch

def batch_encode(tokenizer, texts, batch_size=256, max_length=128, device="cuda"):
    """Tokenize `texts` in chunks and return one padded (input_ids, attention_mask) pair.

    Args:
        tokenizer: Hugging Face tokenizer (uses `batch_encode_plus`, `pad_token_id`,
            `padding_side`).
        texts: Sequence of strings supporting slicing (e.g. a list).
        batch_size: Number of texts tokenized per call.
        max_length: Truncation length in tokens.
        device: Device the returned tensors live on (previously always "cuda").

    Returns:
        Tuple of two tensors of shape (len(texts), L). L is the longest tokenized length
        over ALL chunks (capped at max_length). For an empty `texts`, both tensors have
        shape (0, 0).
    """
    id_chunks = []
    mask_chunks = []

    for start in range(0, len(texts), batch_size):
        encoded = tokenizer.batch_encode_plus(
            texts[start:start + batch_size],
            max_length=max_length,
            padding='longest',
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=False,
            return_tensors='pt'
        ).to(device)
        id_chunks.append(encoded['input_ids'])
        mask_chunks.append(encoded['attention_mask'])

    if not id_chunks:
        # torch.cat([]) would raise; return empty tensors instead.
        empty = torch.empty((0, 0), dtype=torch.long, device=device)
        return empty, empty.clone()

    # padding='longest' only pads each chunk to ITS OWN longest sequence, so chunks can
    # differ in width. torch.cat along dim 0 requires equal widths, so pad every chunk
    # up to the global width first, on the side the tokenizer pads on.
    target_len = max(chunk.shape[1] for chunk in id_chunks)
    pad_left = getattr(tokenizer, "padding_side", "right") == "left"

    def _pad_to_target(chunk, value):
        # Extend `chunk` along the sequence axis to `target_len` using `value`.
        missing = target_len - chunk.shape[1]
        if missing == 0:
            return chunk
        pad = (missing, 0) if pad_left else (0, missing)
        return torch.nn.functional.pad(chunk, pad, value=value)

    input_ids = torch.cat([_pad_to_target(c, tokenizer.pad_token_id) for c in id_chunks], dim=0)
    attention_mask = torch.cat([_pad_to_target(c, 0) for c in mask_chunks], dim=0)

    return input_ids, attention_mask


In [None]:
import pandas as pd
from tqdm import tqdm
import torch

def _build_detox_prompt(toxic_text):
    """Return the instruction prompt that wraps one toxic sentence."""
    return f"""
You are a text detoxification assistant. Your task is to rewrite toxic, offensive, or harmful text to make it non-toxic, respectful, and safe for all audiences.\n
Instructions:\n
Keep the original meaning and intent of the message.\n
Maintain the original language (e.g., English, French, etc.).\n
Make only the minimal necessary changes to remove any toxic, abusive, offensive, or inappropriate language.\n
You cannot hallucinate\n
Stop generating once the detoxified version is complete and neutral — do not add extra commentary, explanation, or continuation.\n
Toxic Text:\n
{toxic_text}\n
Detoxified Version:\n
"""


def detoxify_file(tsv_input_path, tsv_output_path, model, tokenizer, max_length=512, batch_size=64):
    """Detoxify every `toxic_sentence` of a TSV file and write the result to another TSV.

    Args:
        tsv_input_path: Tab-separated input file; must contain a `toxic_sentence` column.
        tsv_output_path: Destination. The input frame plus a `neutral_sentence` column is
            written tab-separated, without the index.
        model: Causal LM exposing `generate` and `device`.
        tokenizer: Matching tokenizer. Its `padding_side` is temporarily forced to "left"
            and restored afterwards.
        max_length: Maximum PROMPT length in tokens; longer prompts are truncated.
        batch_size: Number of sentences per `generate` call.

    Returns:
        The DataFrame that was written (the function previously returned None).
    """
    df = pd.read_csv(tsv_input_path, sep='\t')
    detox_outputs = []

    # Batched generation with a decoder-only model needs left padding. With right
    # padding, pad tokens would sit between shorter prompts and their continuations,
    # and slicing at input_ids.shape[1] below would not line up with the new tokens.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        for start in tqdm(range(0, len(df), batch_size)):
            batch_df = df.iloc[start:start + batch_size]
            prompts = [_build_detox_prompt(text) for text in batch_df["toxic_sentence"]]

            # NOTE(review): truncation cuts from the right by default. A prompt longer than
            # max_length loses its "Detoxified Version:" tail, so keep max_length above the
            # longest prompt.
            inputs = tokenizer.batch_encode_plus(
                prompts,
                return_tensors="pt",
                padding="longest",
                truncation=True,
                max_length=max_length
            )
            input_ids = inputs["input_ids"].to(model.device)
            attention_mask = inputs["attention_mask"].to(model.device)

            with torch.no_grad():
                outputs = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=100,
                    pad_token_id=tokenizer.eos_token_id,
                    temperature=1.2,
                    do_sample=True,
                    top_p=0.9
                )

            # Decode only the generated continuation. Decoding the full sequence copied
            # the entire prompt into `neutral_sentence`.
            new_tokens = outputs[:, input_ids.shape[1]:]
            decoded_outputs = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
            detox_outputs.extend(text.strip() for text in decoded_outputs)
    finally:
        tokenizer.padding_side = original_padding_side

    df["neutral_sentence"] = detox_outputs
    df.to_csv(tsv_output_path, sep='\t', index=False)
    print(f"Fichier sauvegardé à : {tsv_output_path}")
    return df


In [None]:
# Optional memory cleanup between runs: uncomment the lines below to execute it.
# import gc
# import torch

# gc.collect()                  # Python garbage collector
# torch.cuda.empty_cache()      # Clear unused memory from PyTorch
# torch.cuda.ipc_collect()      # Free inter-process cache


In [None]:
# Token length of each raw toxic sentence, used to pick `max_length` for detoxify_file.
# The sentences are loaded explicitly here: the previous version read a global `df`
# that no earlier cell defines, so it failed on a fresh kernel.
# The prompt template adds a fixed overhead on top of these lengths.
inputs_df = pd.read_csv("data/test_inputs_upd.tsv", sep="\t")
token_lengths = inputs_df["toxic_sentence"].apply(lambda text: len(tokenizer.encode(text, truncation=False)))
max_length = token_lengths.max()
print(f"Maximum tokenized length: {max_length}")

In [None]:
# No model.to(...) here: the 4-bit bitsandbytes model was already placed by
# device_map="auto" at load time. Moving it again is unsupported for quantized models
# and would bypass accelerate's placement.
model.eval()
detoxify_file("data/test_inputs_upd.tsv", "data/detoxified.tsv",
              model, tokenizer, max_length=320)
# detoxify_file writes its result to disk; reload it so `df` holds the detoxified frame.
df = pd.read_csv("data/detoxified.tsv", sep="\t")

  0%|                                                   | 0/141 [00:00<?, ?it/s]

In [None]:
# Preview a few rows instead of rendering the whole frame.
df.head()

Unnamed: 0.1,Unnamed: 0,toxic_sentence,neutral_sentence,lang
0,6190,"Fillon est un fils de pute, mais possède un pr...",\nYou are a text detoxification assistant. You...,fr
1,6339,Vitellius fils de tepu qui t a donné le droit ...,\nYou are a text detoxification assistant. You...,fr
2,6169,que ce torche cul de Figaro pour parler des év...,\nYou are a text detoxification assistant. You...,fr
3,6243,Quand j'ai pété mon câble au collège et que j'...,\nYou are a text detoxification assistant. You...,fr
4,6061,Le truc c'est que la critique étrangère vient ...,\nYou are a text detoxification assistant. You...,fr


In [None]:
# Reload a previously saved detoxification run and preview it.
df = pd.read_csv("data/detoxified_mini.tsv", sep="\t")
df.head()

Unnamed: 0.1,Unnamed: 0,toxic_sentence,neutral_sentence,lang
0,6190,"Fillon est un fils de pute, mais possède un pr...",\nYou are a text detoxification assistant. You...,fr
1,6339,Vitellius fils de tepu qui t a donné le droit ...,\nYou are a text detoxification assistant. You...,fr
2,6169,que ce torche cul de Figaro pour parler des év...,\nYou are a text detoxification assistant. You...,fr
3,6243,Quand j'ai pété mon câble au collège et que j'...,\nYou are a text detoxification assistant. You...,fr
4,6061,Le truc c'est que la critique étrangère vient ...,\nYou are a text detoxification assistant. You...,fr


In [None]:
import pandas as pd
import re

def extract_detoxified_version(df, col_name, new_col_name="detoxified_text"):
    """Add `new_col_name` with the answer extracted from each generation in `col_name`.

    Extraction rules:
    - A cell that still contains the prompt yields the first line after the
      "Detoxified Version:" marker, stripped.
    - A cell without the marker (generation only) yields its first non-blank line,
      stripped.
    - Missing or non-string cells (e.g. NaN from an empty TSV field) and all-blank
      cells yield None. Previously a NaN cell raised TypeError.

    `df` is modified in place (the column is added to it) and is also returned.
    """
    marker = re.compile(r"Detoxified Version:\s*\n(.*)")

    def extract(text):
        if not isinstance(text, str):
            return None
        match = marker.search(text)
        if match:
            return match.group(1).strip()
        # No prompt in the cell: the text is the generation itself.
        for line in text.splitlines():
            if line.strip():
                return line.strip()
        return None

    df[new_col_name] = df[col_name].apply(extract)
    return df

# Pull the answer out of each raw generation into a new `detoxified_text` column.
df = extract_detoxified_version(df, col_name="neutral_sentence")