#### GAN-style training loop: fine-tuning the GPT-2 generator with REINFORCE against a frozen bias discriminator

In [2]:
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions.categorical import Categorical
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Relative path to the discriminator checkpoint; used below for both the
# classifier weights and its tokenizer via from_pretrained().
DISCRIMINATOR_PATH = "../models/winobias-discriminator/checkpoint-429"
# Model identifier passed to from_pretrained() for the generator and its tokenizer.
GENERATOR_MODEL_NAME = "gpt2-large"

In [12]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


#### Load the discriminator

In [13]:
print("Loading Discriminator")
# The classifier and its tokenizer are restored from the same checkpoint path.
# nn.Module.to() returns the module itself, so the move can be chained onto the load.
discriminator = AutoModelForSequenceClassification.from_pretrained(
    DISCRIMINATOR_PATH
).to(device)
disc_tokenizer = AutoTokenizer.from_pretrained(DISCRIMINATOR_PATH)
# Evaluation mode for scoring; as the cell's last expression this also displays the model.
discriminator.eval()

Loading Discriminator


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

#### Load the GPT-2-large generator

In [15]:
# Load the generator weights and move them to the selected device in one step
# (nn.Module.to() returns the module itself).
generator = GPT2LMHeadModel.from_pretrained(GENERATOR_MODEL_NAME).to(device)
gen_tokenizer = GPT2Tokenizer.from_pretrained(GENERATOR_MODEL_NAME)
# Reuse the end-of-sequence token as the padding token.
gen_tokenizer.pad_token = gen_tokenizer.eos_token
# Training mode for the generator; as the cell's last expression this also displays the model.
generator.train()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-35): 36 x GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3840, nx=1280)
          (c_proj): Conv1D(nf=1280, nx=1280)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=5120, nx=1280)
          (c_proj): Conv1D(nf=1280, nx=5120)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1280, out_features=50257, bias=False)
)

In [16]:
# Adam over all generator parameters with learning rate 1e-5; only the generator
# is handed to the optimizer, so the discriminator is never updated by it.
optimizer = Adam(generator.parameters(), lr=1e-5)

In [17]:
@torch.no_grad()
def get_bias_probability(text: str) -> float:
    """
    Score a piece of text with the discriminator.

    The text is tokenized (truncated to at most 64 tokens), run through the
    discriminator without gradient tracking, and the logits are turned into
    class probabilities with a softmax.

    Args:
        text: The text to score.

    Returns:
        The probability assigned to class index 1 for the first (only) row,
        as a Python float. Index 1 is treated as the "biased" class.
    """
    encoded = disc_tokenizer(
        text,
        max_length=64,
        truncation=True,
        padding=True,
        return_tensors="pt",
    ).to(device)

    class_probs = F.softmax(discriminator(**encoded).logits, dim=-1)
    return class_probs[0, 1].item()

In [18]:
def sample_and_calculate_loss(prompt: str, max_length: int = 20):
    """
    Sample one continuation from the generator and build its REINFORCE loss.

    Steps:
      1) Autoregressively sample ``max_length`` tokens after ``prompt``
         (no gradients, generator in eval mode).
      2) Score the decoded text with the discriminator:
         reward = 1 - prob_biased.
      3) Re-score the *same* sampled tokens in a single forward pass with
         gradients enabled and sum their log-probabilities.
      4) Return loss = -(sum of log-probs of the sampled tokens) * reward.

    The log-probabilities in the loss must belong to the sequence that
    earned the reward; otherwise the gradient is unrelated to the reward.

    Args:
        prompt: Text the generator is conditioned on.
        max_length: Number of new tokens to sample after the prompt.

    Returns:
        A tuple ``(loss, generated_text, reward)``: a scalar tensor attached
        to the generator's graph, the decoded prompt plus continuation, and
        the reward as a Python float.

    Side effects:
        Leaves the generator in train mode.
    """
    prompt_ids = gen_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    prompt_len = prompt_ids.shape[1]

    # --- 1) Sample a continuation; no graph is needed while sampling. ---
    generator.eval()
    sequence_ids = prompt_ids
    with torch.no_grad():
        for _ in range(max_length):
            next_logits = generator(input_ids=sequence_ids).logits[:, -1, :]  # [1, vocab_size]
            next_token = Categorical(logits=next_logits).sample()             # [1]
            sequence_ids = torch.cat([sequence_ids, next_token.unsqueeze(-1)], dim=1)

    generated_text = gen_tokenizer.decode(sequence_ids[0], skip_special_tokens=True)

    # --- 2) Reward: high when the discriminator thinks the text is not biased. ---
    prob_biased = get_bias_probability(generated_text)
    reward = 1.0 - prob_biased

    # --- 3) Log-probability of the sampled tokens, with gradients. ---
    generator.train()
    all_logits = generator(input_ids=sequence_ids).logits  # [1, seq_len, vocab_size]
    # Logits at position t predict the token at position t + 1.
    token_log_probs = Categorical(logits=all_logits[:, :-1, :]).log_prob(
        sequence_ids[:, 1:]
    )  # [1, seq_len - 1]
    # Only the sampled continuation contributes; prompt tokens were given, not chosen.
    sequence_log_prob = token_log_probs[0, prompt_len - 1:].sum()

    # --- 4) REINFORCE loss. ---
    loss = -sequence_log_prob * reward

    return loss, generated_text, reward
    

In [19]:
def train_generator_rl(prompts, num_iterations=30, max_length=20):
    """
    Minimal REINFORCE loop that fine-tunes the generator to reduce bias.

    For every iteration the next prompt is taken round-robin from
    ``prompts``, one sequence is sampled and turned into a loss by
    ``sample_and_calculate_loss``, and a single optimizer step is applied.
    A progress report is printed every fifth step.

    Args:
        prompts: Indexable collection of prompt strings.
        num_iterations: Number of update steps to run.
        max_length: Number of tokens to sample per step.
    """
    print("[INFO] Starting RL training...")
    for step in range(num_iterations):
        # Cycle through the prompts in order.
        prompt = prompts[step % len(prompts)]

        optimizer.zero_grad()
        loss, gen_text, reward = sample_and_calculate_loss(prompt, max_length)
        loss.backward()
        optimizer.step()

        step_number = step + 1
        if step_number % 5 != 0:
            continue

        # Periodic progress report.
        print(f"[STEP {step_number}] Loss={loss.item():.4f} | Reward={reward:.4f}")
        print(f"Prompt: {prompt}")
        print(f"Generated: {gen_text}")
        print("-" * 70)

In [None]:
if __name__ == "__main__":
    # Prompts cycled through during RL training.
    prompts = [
        "The doctor told the nurse",
        "The nurse told the doctor",
        "The engineer and his coworker",
        "The developer argued with the designer",
    ]

    train_generator_rl(prompts, num_iterations=20, max_length=15)

    test_prompt = "The manager decided to"
    print("\n[INFO] Testing generator with prompt:", test_prompt)

    # Training leaves the generator in train mode; switch to eval so dropout
    # is disabled while producing the qualitative sample.
    generator.eval()

    test_inputs = gen_tokenizer(test_prompt, return_tensors="pt").to(device)
    output_ids = generator.generate(
        test_inputs.input_ids,
        # Pass the mask and pad id explicitly instead of letting generate() infer them
        # (the pad token was set to EOS when the tokenizer was loaded).
        attention_mask=test_inputs.attention_mask,
        pad_token_id=gen_tokenizer.eos_token_id,
        max_length=50,
        do_sample=True,
        top_k=50,
        top_p=0.9,
    )
    final_text = gen_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print("[Generated Text]:", final_text)

[INFO] Starting RL training...
[STEP 5] Loss=0.4744 | Reward=0.0087
Prompt: The doctor told the nurse
Generated: The doctor told the nurse he was having heart treatment, and apologized for the pain he was in.
----------------------------------------------------------------------
[STEP 10] Loss=31.6655 | Reward=0.8827
Prompt: The nurse told the doctor
Generated: The nurse told the doctor that she had questions and he then started to walk around the room based on
----------------------------------------------------------------------


KeyboardInterrupt: 

In [None]:
"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"