In [30]:
import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from pplm_classification_head import ClassificationHead  # From PPLM repo

model_discrim = "sentiment"

if model_discrim == "sentiment":
    CLASS_SIZE = 5
else:
    CLASS_SIZE = 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [None]:
def load_gpt2():
    model = GPT2LMHeadModel.from_pretrained("gpt2-medium").to(device)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
    return model, tokenizer

def load_discriminator(discrim_path):
    discrim = ClassificationHead(class_size=CLASS_SIZE, embed_size=1024).to(device)  # Adjust class_size as needed
    discrim.load_state_dict(torch.load(discrim_path, map_location=device))
    discrim.eval()
    return discrim

def generate_text_with_steering(model, tokenizer, discriminator, prompt, steps=100, alpha=2, kl_factor=1):
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    output = input_ids.clone().to(device)
    
    for _ in range(steps):
        outputs = model(output, return_dict=True, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1][:, -1, :].detach().requires_grad_(True)
        logits = outputs.logits[:, -1, :]
        
        pred = discriminator(hidden_states)
        
        if model_discrim=="sentiment":
            pred = pred[:, [2, 3]]  # Keep only logits for classes 2 and 3
        
        target_class = torch.tensor([1], device=device)
        
        original_probs = F.softmax(logits, dim=-1)

        # Compute KL divergence: D_KL(P_model || P_steered)
        new_logits = model.lm_head(hidden_states)
        new_probs = F.softmax(new_logits, dim=-1)
        kl_loss = torch.nn.KLDivLoss(reduction="batchmean")(new_probs.log(), original_probs)
    
        loss = torch.nn.CrossEntropyLoss()(pred, target_class) + kl_factor*kl_loss
        loss.backward()
        
        # Modify hidden state
        with torch.no_grad():
            hidden_states += alpha * hidden_states.grad
            new_logits = model.lm_head(hidden_states)
        
        # Get next token
        #next_token = torch.argmax(logits, dim=-1)[:, None]
        
        top_k_values, top_k_indices = torch.topk(new_logits[0,:], 100)
        top_k_probs = F.softmax(top_k_values)
        token_k_id = torch.multinomial(top_k_probs, num_samples=1)
        next_token = top_k_indices[token_k_id].unsqueeze(0)
        
        output = torch.cat((output, next_token), dim=1)
        
    
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Example Usage
model, tokenizer = load_gpt2()
discriminator = load_discriminator("discrim_models/"+model_discrim+"_classifierhead.pt")
prompt = "The chicken is"
"""
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
input_ids = input_ids.to(device)
outputs = model(input_ids, return_dict=True, output_hidden_states=True)
hidden_states = outputs.hidden_states[-1][:, -1, :]
logits_clone = hidden_states.clone().detach().requires_grad_(True)
print(discriminator(logits_clone))
"""
generated_text = generate_text_with_steering(model, tokenizer, discriminator, prompt)
print(generated_text)

  top_k_probs = F.softmax(top_k_values)


RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

In [None]:
input_ids = tokenizer(generated_text, return_tensors='pt').input_ids
input_ids = input_ids.to(device)
outputs = model(input_ids, return_dict=True, output_hidden_states=True)
hidden_states = outputs.hidden_states[-1][:, -1, :]
logits_clone = hidden_states.clone().detach().requires_grad_(True)
print(discriminator(logits_clone))

tensor([[12.7256,  1.9373, 21.4120,  4.2105,  0.3421]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
