In [None]:
pip install evaluate


In [None]:
torch.cuda.empty_cache()

In [1]:
pip install rouge_score

Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install evaluate

Note: you may need to restart the kernel to use updated packages.


In [None]:
#!/usr/bin/env python
# R3-QuARC+ : Memory-Efficient Reinforcement Learning–Enhanced QA (Counterspeech) Model

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
import evaluate
from datasets import Features, Value, load_dataset, concatenate_datasets, ClassLabel

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -------------------------------
# 1. Load and Combine Data
# -------------------------------
expected_features = Features({
    "hatespeech": Value("string"),
    "csType": Value("string"),
    "counterspeech": Value("string"),
    "Suggest": Value("string"),
    "Relevance": Value("float64"),
    "Aggressive": Value("float64"),
    "Complexity": Value("float64"),
    "Comments": Value("float64"),
    "source": Value("string"),
    "claim": Value("string"),
    "centralTopic": Value("string"),
    "speakerIntent": Value("string"),
    "targetGroup": Value("string"),
    "relevantPowerDynamics": Value("string"),
    "hatespeechImplication": Value("string"),
    "targetGroupEmotionalReaction": Value("string"),
    "targetGroupCognitiveReaction": Value("string"),
    "hatespeechOffensiveness": Value("string"),
    "id": Value("int64"),
    "is_high_quality": Value("string"),
    "hs_id": Value("int64"),
    "hatespeechTarget": Value("string"),
    "powerDynamics": Value("string"),
    "prompt_offensiveness": Value("string"),
    "prompt_target_group": Value("string"),
    "prompt_speaker_intent": Value("string"),
    "prompt_power_dynamics": Value("string"),
    "prompt_implication": Value("string"),
    "prompt_emotional_reaction": Value("string"),
    "prompt_cognitive_reaction": Value("string"),
    "prompt_cs_generation": Value("string")
})

data_files = {
    "train": "/kaggle/input/nlp-midsem-novelty/train.csv",
    "validation": "/kaggle/input/nlp-midsem-novelty/validation.csv",
    "test": "/kaggle/input/nlp-midsem-novelty/test.csv"
}

raw_datasets = load_dataset("csv", data_files=data_files, features=expected_features)

# Combine and shuffle
combined_ds = concatenate_datasets([
    raw_datasets["train"],
    raw_datasets["validation"],
    raw_datasets["test"]
])
combined_ds = combined_ds.shuffle(seed=42)

# -------------------------------
# 2. Cast and Stratified Split
# -------------------------------
is_high_quality_label = ClassLabel(names=["no", "yes"])
combined_ds = combined_ds.cast_column("is_high_quality", is_high_quality_label)

split1 = combined_ds.train_test_split(test_size=0.4, stratify_by_column="is_high_quality", seed=42)
new_train = split1["train"]
split2 = split1["test"].train_test_split(test_size=0.5, stratify_by_column="is_high_quality", seed=42)
new_val = split2["train"]
new_test = split2["test"]

raw_train = new_train
raw_val = new_val
raw_test = new_test

print("New Train distribution:", raw_train.features["is_high_quality"])
print("New Val distribution:", raw_val.features["is_high_quality"])
print("New Test distribution:", raw_test.features["is_high_quality"])

# -------------------------------
# 3. Initialize Tokenizers and Models (Lighter Models)
# -------------------------------
# Using t5-small for the actor reduces memory footprint.
actor_tokenizer = AutoTokenizer.from_pretrained("t5-small")
actor_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(device)

# Retain a lightweight critic.
critic_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
critic_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2).to(device)

max_seq_len = 128  # Reduce sequence length for memory savings

# -------------------------------
# 4. Define Helper Functions
# -------------------------------
def build_rationale(imp, tg, pd, off):
    rationale_parts = []
    if imp:
        text = str(imp).strip().strip('"')
        if not text.endswith(('.', '?', '!', '"', "'")):
            text = text.rstrip() + "."
        rationale_parts.append(text)
    if tg:
        tg_text = str(tg).strip().strip('"')
        pd_text = str(pd).strip().strip('"') if pd else ""
        sentence = f"It targets {tg_text}, reflecting {pd_text}." if pd_text else f"It targets {tg_text}."
        rationale_parts.append(sentence)
    if off:
        off_text = str(off).strip().strip('"')
        if not off_text.endswith(('.', '?', '!', '"', "'")):
            off_text = off_text.rstrip() + "."
        if off_text and off_text[0].islower():
            off_text = off_text.capitalize()
        off_sentence = f"This is {off_text}" if not off_text.startswith("It is") else off_text
        if not off_sentence.endswith('.'):
            off_sentence += '.'
        rationale_parts.append(off_sentence)
    return " ".join(rationale_parts)

def prepare_actor(example):
    hs = str(example.get("hatespeech", ""))
    cs_type = str(example.get("csType", ""))
    cs = str(example.get("counterspeech", ""))
    imp = example.get("hatespeechImplication", "") or ""
    tg = example.get("targetGroup", "") or ""
    pd = example.get("relevantPowerDynamics", "") or ""
    off = example.get("hatespeechOffensiveness", "") or ""
    input_text = f"{cs_type}: {hs}"
    output_text = cs.strip()
    if output_text and output_text[-1] not in ".!?\"'":
        output_text += "."
    rationale_text = build_rationale(imp, tg, pd, off)
    if rationale_text:
        output_text += f" Rationale: {rationale_text.strip()}"
        if output_text and output_text[-1] not in ".!?\"'":
            output_text += "."
    model_inputs = actor_tokenizer(input_text, max_length=max_seq_len, truncation=True, padding="max_length")
    with actor_tokenizer.as_target_tokenizer():
        labels = actor_tokenizer(output_text, max_length=max_seq_len, truncation=True, padding="max_length")["input_ids"]
    labels = [l if l != actor_tokenizer.pad_token_id else -100 for l in labels]
    return {
        "input_ids": model_inputs.get("input_ids", []),
        "attention_mask": model_inputs.get("attention_mask", []),
        "labels": labels
    }

def prepare_critic(example):
    hs = str(example.get("hatespeech", ""))
    cs = str(example.get("counterspeech", ""))
    encodings = critic_tokenizer(hs, cs, max_length=max_seq_len, truncation=True, padding="max_length")
    label = example["is_high_quality"]
    return {**encodings, "label": label}

# -------------------------------
# 5. Prepare Datasets and Set Format
# -------------------------------
actor_train_ds = raw_train.map(prepare_actor, batched=False, remove_columns=raw_train.column_names)
actor_val_ds   = raw_val.map(prepare_actor, batched=False, remove_columns=raw_val.column_names)
actor_test_ds  = raw_test.map(prepare_actor, batched=False, remove_columns=raw_test.column_names)

critic_train_ds = raw_train.map(prepare_critic, batched=False, remove_columns=raw_train.column_names)
critic_val_ds   = raw_val.map(prepare_critic, batched=False, remove_columns=raw_val.column_names)

print("Actor Train Columns:", actor_train_ds.column_names)
print("Actor Val Columns:", actor_val_ds.column_names)

actor_train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
actor_val_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
actor_test_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
critic_train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
critic_val_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# -------------------------------
# 6. Create DataLoaders
# -------------------------------
batch_size = 8
actor_train_loader = DataLoader(actor_train_ds, batch_size=batch_size, shuffle=True)
actor_val_loader = DataLoader(actor_val_ds, batch_size=batch_size)
critic_train_loader = DataLoader(critic_train_ds, batch_size=8, shuffle=True)
critic_val_loader = DataLoader(critic_val_ds, batch_size=8)

# -------------------------------
# 7. Initialize Optimizers
# -------------------------------
actor_optimizer = torch.optim.Adam(actor_model.parameters(), lr=5e-5)
critic_optimizer = torch.optim.Adam(critic_model.parameters(), lr=1e-5)

# -------------------------------
# 8. Supervised Training for Critic
# -------------------------------
print("Training critic (quality classifier)...")
epochs_critic = 2
critic_model.train()
for epoch in range(epochs_critic):
    total_loss = 0.0
    for batch in critic_train_loader:
        critic_optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attn = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)
        outputs = critic_model(input_ids=input_ids, attention_mask=attn, labels=labels)
        loss = outputs.loss
        loss.backward()
        critic_optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(critic_train_loader)
    print(f"[Critic] Epoch {epoch+1}/{epochs_critic} - Average Loss: {avg_loss:.4f}")

# Evaluate critic on validation set
critic_model.eval()
correct = 0
total = 0
for batch in critic_val_loader:
    input_ids = batch["input_ids"].to(device)
    attn = batch["attention_mask"].to(device)
    labels = batch["label"].to(device)
    with torch.no_grad():
        logits = critic_model(input_ids=input_ids, attention_mask=attn).logits
        preds = logits.argmax(dim=-1)
    correct += (preds == labels).sum().item()
    total += labels.size(0)
val_acc = correct / total if total > 0 else 0
print(f"[Critic] Validation Accuracy: {val_acc*100:.2f}%")

# -------------------------------
# 9. Supervised Training for Actor
# -------------------------------
print("\nTraining actor model (supervised fine-tuning on reference counterspeech)...")
epochs_actor_sup = 2
actor_model.train()
rouge_metric = evaluate.load("rouge")
for epoch in range(epochs_actor_sup):
    total_loss = 0.0
    for batch in actor_train_loader:
        actor_optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attn = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        outputs = actor_model(input_ids=input_ids, attention_mask=attn, labels=labels)
        loss = outputs.loss
        loss.backward()
        actor_optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(actor_train_loader)
    print(f"[Actor Sup] Epoch {epoch+1}/{epochs_actor_sup} - Average Loss: {avg_loss:.4f}")
    
    # Evaluate on validation after each epoch
    actor_model.eval()
    val_preds = []
    val_refs = []
    for batch in actor_val_loader:
        input_ids = batch["input_ids"].to(device)
        attn = batch["attention_mask"].to(device)
        with torch.no_grad():
            gen_outputs = actor_model.generate(
                input_ids=input_ids, 
                attention_mask=attn,
                max_length=128, 
                num_beams=4, 
                num_return_sequences=1
            )
        for i, output_ids in enumerate(gen_outputs):
            text = actor_tokenizer.decode(output_ids, skip_special_tokens=True)
            if "Rationale:" in text:
                text = text.split("Rationale:")[0].strip()
            val_preds.append(text)
            ref_cs = actor_tokenizer.decode(
                batch["labels"][i][batch["labels"][i] != -100],
                skip_special_tokens=True
            )
            if "Rationale:" in ref_cs:
                ref_cs = ref_cs.split("Rationale:")[0].strip()
            val_refs.append(ref_cs)
    rouge_scores = rouge_metric.compute(
        predictions=[p.strip() for p in val_preds],
        references=[r.strip() for r in val_refs],
        use_stemmer=True
    )
    print(f"[Actor Sup] Epoch {epoch+1} Validation Metrics:")
    for key, value in rouge_scores.items():
        if isinstance(value, dict) and "mid" in value:
            score = value["mid"].fmeasure * 100
        else:
            score = value * 100
        print(f"  {key}: {score:.2f}")
    actor_model.train()

# -------------------------------
# 10. Reinforcement Learning Phase (Actor-Critic with Mixed Precision)
# -------------------------------
print("\nStarting reinforcement learning refinement...")
rl_epochs = 2
num_candidates = 2  # Reduced candidate count
rl_batch_size = 8   # Reduced batch size for RL loop

# Set up automatic mixed precision (AMP)
scaler = torch.cuda.amp.GradScaler()

actor_rl_optimizer = torch.optim.Adam(actor_model.parameters(), lr=1e-5)
actor_model.train()
critic_model.eval()  # Freeze critic during RL

# For RL, reduce memory by using only needed columns from raw_train.
rl_columns = ["hatespeech", "csType", "counterspeech", "is_high_quality"]
raw_train_for_rl = raw_train.remove_columns(
    [col for col in raw_train.column_names if col not in rl_columns]
)
train_loader_full = DataLoader(raw_train_for_rl, batch_size=rl_batch_size, shuffle=True)

for epoch in range(rl_epochs):
    epoch_loss = 0.0
    count = 0
    for batch in train_loader_full:
        hs_list = [str(x) for x in batch["hatespeech"]]
        cs_type_list = [str(x) for x in batch["csType"]]
        input_texts = [f"{cs_type}: {hs}" for cs_type, hs in zip(cs_type_list, hs_list)]
        enc = actor_tokenizer(
            input_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        )
        input_ids = enc["input_ids"].to(device)
        attn = enc["attention_mask"].to(device)

        with torch.no_grad():
            gen_outputs = actor_model.generate(
                input_ids=input_ids,
                attention_mask=attn,
                max_length=128,
                do_sample=True,
                top_p=0.9,
                top_k=50,
                num_return_sequences=num_candidates
            )
        gen_outputs = gen_outputs.cpu()

        all_hs = []
        all_cs_gen = []
        for i, output_ids in enumerate(gen_outputs):
            input_index = i // num_candidates
            hs_text = hs_list[input_index]
            output_text = actor_tokenizer.decode(output_ids, skip_special_tokens=True)
            if "Rationale:" in output_text:
                output_text = output_text.split("Rationale:")[0].strip()
            cs_text = output_text if output_text else " "
            all_hs.append(hs_text)
            all_cs_gen.append(cs_text)

        critic_enc = critic_tokenizer(
            all_hs,
            all_cs_gen,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128
        )
        critic_in_ids = critic_enc["input_ids"].to(device)
        critic_attn = critic_enc["attention_mask"].to(device)

        with torch.no_grad():
            logits = critic_model(critic_in_ids, attention_mask=critic_attn).logits
            probs = torch.softmax(logits, dim=-1)
            high_quality_probs = probs[:, 1]

        repeat_input_ids = input_ids.repeat_interleave(num_candidates, dim=0)
        repeat_attn = attn.repeat_interleave(num_candidates, dim=0)

        labels_gen = gen_outputs.clone()
        labels_gen[labels_gen == actor_tokenizer.pad_token_id] = -100
        labels_gen = labels_gen.to(device)

        with torch.cuda.amp.autocast():
            actor_outputs = actor_model(
                input_ids=repeat_input_ids,
                attention_mask=repeat_attn,
                labels=labels_gen
            )
            logits_seq = actor_outputs.logits
            vocab_size = logits_seq.size(-1)
            logits_flat = logits_seq.view(-1, vocab_size)
            labels_flat = labels_gen.view(-1)
            token_loss = F.cross_entropy(
                logits_flat,
                labels_flat,
                reduction='none',
                ignore_index=-100
            )
            token_loss = token_loss.view(len(gen_outputs), -1)
            seq_nll = token_loss.sum(dim=1)
            rewards = high_quality_probs.detach().cpu().view(-1)
            seq_nll = seq_nll.to(device)

            policy_loss = torch.tensor(0.0, device=device)
            for bi in range(len(hs_list)):
                start = bi * num_candidates
                end = start + num_candidates
                group_rewards = rewards[start:end]
                group_nll = seq_nll[start:end]
                baseline = group_rewards.mean()
                group_loss = (baseline - group_rewards.to(device)) * group_nll
                policy_loss += group_loss.sum()
            policy_loss = policy_loss / (len(hs_list) * num_candidates)

        scaler.scale(policy_loss).backward()
        scaler.step(actor_rl_optimizer)
        scaler.update()
        epoch_loss += policy_loss.item()
        count += 1

    avg_epoch_loss = epoch_loss / max(count, 1)
    print(f"[RL] Epoch {epoch+1}/{rl_epochs} - Average Policy Loss: {avg_epoch_loss:.4f}")

    # Evaluate RL performance on validation after each RL epoch
    actor_model.eval()
    val_preds = []
    val_refs = []
    hq_scores = []
    for batch in actor_val_loader:
        input_ids = batch["input_ids"].to(device)
        attn = batch["attention_mask"].to(device)
        with torch.no_grad():
            gen_outputs = actor_model.generate(
                input_ids=input_ids,
                attention_mask=attn,
                max_length=128,
                num_beams=4,
                num_return_sequences=1
            )
        for i, output_ids in enumerate(gen_outputs):
            gen_text = actor_tokenizer.decode(output_ids, skip_special_tokens=True)
            cs_text = gen_text.split("Rationale:")[0].strip() if "Rationale:" in gen_text else gen_text
            val_preds.append(cs_text)
            ref_text = actor_tokenizer.decode(
                batch["labels"][i][batch["labels"][i] != -100],
                skip_special_tokens=True
            )
            ref_text = ref_text.split("Rationale:")[0].strip() if "Rationale:" in ref_text else ref_text
            val_refs.append(ref_text)
            hs_text = actor_tokenizer.decode(batch["input_ids"][i], skip_special_tokens=True)
            if ": " in hs_text:
                _, hs_only = hs_text.split(": ", 1)
            else:
                hs_only = hs_text
            with torch.no_grad():
                enc = critic_tokenizer(
                    hs_only,
                    cs_text,
                    truncation=True,
                    padding=True,
                    return_tensors="pt",
                    max_length=128
                )
                enc_input_ids = enc["input_ids"].to(device)
                enc_attn = enc["attention_mask"].to(device)
                logits = critic_model(enc_input_ids, attention_mask=enc_attn).logits
                prob = F.softmax(logits, dim=-1)[0, 1].item()
            hq_scores.append(prob)
    rouge_scores = rouge_metric.compute(
        predictions=[p.strip() for p in val_preds],
        references=[r.strip() for r in val_refs],
        use_stemmer=True
    )
    print(f"[RL] Epoch {epoch+1} Validation Metrics:")
    for key, value in rouge_scores.items():
        if isinstance(value, dict) and 'mid' in value:
            score = value['mid'].fmeasure * 100
        else:
            score = value * 100
        print(f"  {key}: {score:.2f}")
    avg_hq = sum(hq_scores) / len(hq_scores) if hq_scores else 0.0
    perc_hq = sum(1 for s in hq_scores if s > 0.5) / len(hq_scores) * 100 if hq_scores else 0.0
    print(f"  Average critic high-quality probability: {avg_hq:.3f}")
    print(f"  Percentage high-quality (prob>0.5): {perc_hq:.2f}%")
    actor_model.train()

# -------------------------------
# 11. Final Evaluation After RL Training
# -------------------------------
print("\nFinal evaluation after RL refinement...")
actor_model.eval()
rouge_metric = evaluate.load("rouge")
val_preds = []
val_refs = []
hq_scores = []
for batch in actor_val_loader:
    input_ids = batch["input_ids"].to(device)
    attn = batch["attention_mask"].to(device)
    with torch.no_grad():
        gen_outputs = actor_model.generate(
            input_ids=input_ids,
            attention_mask=attn,
            max_length=128,
            num_beams=4,
            num_return_sequences=1
        )
    for i, output_ids in enumerate(gen_outputs):
        gen_text = actor_tokenizer.decode(output_ids, skip_special_tokens=True)
        cs_text = gen_text.split("Rationale:")[0].strip() if "Rationale:" in gen_text else gen_text
        val_preds.append(cs_text)
        ref_text = actor_tokenizer.decode(
            batch["labels"][i][batch["labels"][i] != -100],
            skip_special_tokens=True
        )
        ref_text = ref_text.split("Rationale:")[0].strip() if "Rationale:" in ref_text else ref_text
        val_refs.append(ref_text)
        hs_text = actor_tokenizer.decode(batch["input_ids"][i], skip_special_tokens=True)
        if ": " in hs_text:
            _, hs_only = hs_text.split(": ", 1)
        else:
            hs_only = hs_text
        with torch.no_grad():
            enc = critic_tokenizer(
                hs_only,
                cs_text,
                truncation=True,
                padding=True,
                return_tensors="pt",
                max_length=128
            )
            enc_input_ids = enc["input_ids"].to(device)
            enc_attn = enc["attention_mask"].to(device)
            logits = critic_model(enc_input_ids, attention_mask=enc_attn).logits
            prob = F.softmax(logits, dim=-1)[0, 1].item()
        hq_scores.append(prob)
results = rouge_metric.compute(
    predictions=[p.strip() for p in val_preds],
    references=[r.strip() for r in val_refs],
    use_stemmer=True
)
print("Final Validation ROUGE scores (after RL):")
for key, value in results.items():
    if isinstance(value, dict) and 'mid' in value:
        score = value['mid'].fmeasure * 100
    else:
        score = value * 100
    print(f"  {key}: {score:.2f}")
avg_hq = sum(hq_scores) / len(hq_scores) if hq_scores else 0.0
perc_hq = sum(1 for s in hq_scores if s > 0.5) / len(hq_scores) * 100 if hq_scores else 0.0
print(f"Final Average critic high-quality probability: {avg_hq:.3f}")
print(f"Final Percentage of high-quality outputs: {perc_hq:.2f}%")

# -------------------------------
# 12. Sample Outputs on Test Set
# -------------------------------
if raw_test:
    print("\nSample outputs on test set:")
    actor_model.eval()
    test_loader = DataLoader(actor_test_ds, batch_size=1, shuffle=True)
    for i, batch in enumerate(test_loader):
        if i >= 5:
            break
        hs_input_ids = batch["input_ids"].to(device)
        hs_attn = batch["attention_mask"].to(device)
        hs_text = actor_tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
        if ": " in hs_text:
            _, hs_only = hs_text.split(": ", 1)
        else:
            hs_only = hs_text
        with torch.no_grad():
            output_ids = actor_model.generate(
                input_ids=hs_input_ids,
                attention_mask=hs_attn,
                max_length=128,
                num_beams=4,
                num_return_sequences=1
            )[0]
        output_text = actor_tokenizer.decode(output_ids, skip_special_tokens=True)
        print(f"Hate speech: {hs_only.strip()}")
        print(f"Model response: {output_text.strip()}\n")


2025-04-15 15:45:17.709671: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744731917.736095     290 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744731917.743200     290 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Using device: cuda
New Train distribution: ClassLabel(names=['no', 'yes'], id=None)
New Val distribution: ClassLabel(names=['no', 'yes'], id=None)
New Test distribution: ClassLabel(names=['no', 'yes'], id=None)


tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/8383 [00:00<?, ? examples/s]



Map:   0%|          | 0/2795 [00:00<?, ? examples/s]

Map:   0%|          | 0/2795 [00:00<?, ? examples/s]

Map:   0%|          | 0/8383 [00:00<?, ? examples/s]

Map:   0%|          | 0/2795 [00:00<?, ? examples/s]

Actor Train Columns: ['input_ids', 'attention_mask', 'labels']
Actor Val Columns: ['input_ids', 'attention_mask', 'labels']
Training critic (quality classifier)...
[Critic] Epoch 1/2 - Average Loss: 0.4653
[Critic] Epoch 2/2 - Average Loss: 0.3089
[Critic] Validation Accuracy: 87.91%

Training actor model (supervised fine-tuning on reference counterspeech)...


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


[Actor Sup] Epoch 1/2 - Average Loss: 2.2668
[Actor Sup] Epoch 1 Validation Metrics:
  rouge1: 21.62
  rouge2: 5.50
  rougeL: 16.96
  rougeLsum: 16.96
[Actor Sup] Epoch 2/2 - Average Loss: 1.6994
[Actor Sup] Epoch 2 Validation Metrics:
  rouge1: 24.96
  rouge2: 7.20
  rougeL: 18.61
  rougeLsum: 18.61

Starting reinforcement learning refinement...


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():


[RL] Epoch 1/2 - Average Policy Loss: -0.3663
[RL] Epoch 1 Validation Metrics:
  rouge1: 24.96
  rouge2: 7.20
  rougeL: 18.61
  rougeLsum: 18.61
  Average critic high-quality probability: 0.710
  Percentage high-quality (prob>0.5): 71.16%
[RL] Epoch 2/2 - Average Policy Loss: -0.3764
[RL] Epoch 2 Validation Metrics:
  rouge1: 24.96
  rouge2: 7.20
  rougeL: 18.61
  rougeLsum: 18.61
  Average critic high-quality probability: 0.710
  Percentage high-quality (prob>0.5): 71.16%

Final evaluation after RL refinement...
