In [22]:
import random
import torch
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load pre-trained model and tokenizer
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = tokenizer.pad_token_id


In [None]:

# Initialize population
def initialize_population(pop_size, prompt):
    return [prompt for _ in range(pop_size)]

# Evaluate fitness
# def evaluate(candidate, model, tokenizer):
#     input_ids = tokenizer.encode(candidate, return_tensors='pt')
#     with torch.no_grad():
#         outputs = model.generate(input_ids, max_length=50)
#         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     # Example fitness based on text length (you can use more complex criteria)
#     fitness = len(generated_text)
#     return fitness, generated_text
def evaluate(individual, model, tokenizer, prompt):
    candidate = prompt + ''.join(individual)
    inputs = tokenizer(candidate, return_tensors='pt', padding=True, truncation=True)
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']
    
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_length=50,
            attention_mask=attention_mask,
            pad_token_id=tokenizer.eos_token_id,
            temperature=0.7,  # Adjust temperature
            top_p=0.9,         # Use top-p sampling
            do_sample=True
        )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if "bank" in generated_text:
        fitness = 50
    elif "Federal Reserve" in generated_text:
            fitness = 45
    elif "money laundering" in generated_text: 
        fitness = 40  
    else:
        fitness = 20
    return (fitness,generated_text)


# Selection
def select_parents(population, fitness_scores, num_parents):
    sorted_indices = sorted(range(len(fitness_scores)), key=lambda i: fitness_scores[i], reverse=True)
    selected_indices = sorted_indices[:num_parents]
    return [population[i] for i in selected_indices]

# Crossover
def crossover(parent1, parent2):
    mid_point = len(parent1) // 2
    child = parent1[:mid_point] + parent2[mid_point:]
    return child

# Mutation
def mutate(sequence, mutation_rate=0.1):
    sequence = list(sequence)
    for i in range(len(sequence)):
        if random.random() < mutation_rate:
            sequence[i] = random.choice(list("abcdefghijklmnopqrstuvwxyz "))
    return ''.join(sequence)

# Genetic Algorithm
def genetic_algorithm(prompt, model, tokenizer, pop_size=10, generations=5, mutation_rate=0.1):
    population = initialize_population(pop_size, prompt)
    best_candidate = None
    best_fitness = 0
    best_text = ""

    for generation in range(generations):
        fitness_scores = []
        generated_texts = []
        
        for candidate in population:
            fitness, generated_text = evaluate(candidate, model, tokenizer, prompt)
            fitness_scores.append(fitness)
            generated_texts.append(generated_text)
        
        # Track the best candidate
        for i, fitness in enumerate(fitness_scores):
            if fitness > best_fitness:
                best_fitness = fitness
                best_candidate = population[i]
                best_text = generated_texts[i]
        
        # Select parents
        parents = select_parents(population, fitness_scores, pop_size // 2)
        
        # Generate next population
        next_population = []
        while len(next_population) < pop_size:
            parent1, parent2 = random.sample(parents, 2)
            child = crossover(parent1, parent2)
            child = mutate(child, mutation_rate)
            next_population.append(child)
        
        population = next_population
    
    return best_text

prompt = "One of the financial scandals is the "
pruned_actions = []
for _ in range(20):
    final_text = genetic_algorithm(prompt, model, tokenizer)
    print(f"Final Generated Text: {final_text}")
    hof_actions = np.array(tokenizer.encode(final_text, max_length=50))
    pruned_actions.extend(list(hof_actions))

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Final Generated Text: One of the financial scandals is the One of the financial scandals is the  Fannie and Freddie bailout. The Federal Reserve is now trying to bail out these banks. The next bailout is a bailout of the U.S. and its allies.
Final Generated Text: One of the financial scandals is the One of the financial scandals is the  sophisticated scandal that took place last year, when a group of the biggest banks (such as JPMorgan Chase, Citigroup and Bank of America) announced that
Final Generated Text: One of the financial scandals is the Oie of theooiqancial scandels is tie  to a very successful one-day deal with a bank that is a huge risk to the country. The Oie of theooiqancial is
Final Generated Text: One of the financial scandals is the One of the financial scandals is the  Financial scandal with the big banks and the big banks and the big banks that are taking advantage of the credit crunch that is coming on from the crisis in the US. The
Final Generated Text: One of the f

In [24]:
import gym
import numpy as np
def calculate_perplexity(text, model, tokenizer):
    inputs = tokenizer(text, return_tensors='pt')        
    input_ids = inputs.input_ids                      
    attention_mask = inputs.attention_mask     
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
        #outputs = model(**inputs, labels=inputs["input_ids"])
        log_likelihood =  -outputs.loss * inputs.input_ids.size(1)
    perplexity = torch.exp(log_likelihood / inputs.input_ids.size(1))
    
    return perplexity.item()

def calculate_coherence(text, model, tokenizer):
    perplexity = calculate_perplexity(text, model, tokenizer)
    
    # Normalize perplexity to a range [0, 1]
    norm_perplexity = 1 / (1 + perplexity)
    
    # Combine with other metrics (if available)
    coherence_score = norm_perplexity  
    
    return coherence_score
def _calculate_reward(generated_text, model, tokenizer):
    generated_text_str = ' '.join(generated_text)
    if "bank" in generated_text or "Bank" in generated_text or "Federal Reserve" in generated_text or "money laundering" in generated_text:
        coherence = calculate_coherence(generated_text, model, tokenizer)
        reward = coherence
    else:
        reward = 0

    return reward



In [25]:
import torch 
import time
from transformers import GPT2LMHeadModel, GPT2Tokenizer
prompt = "One of the financial scandal is the "

def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering """
    assert logits.dim() == 2  # logits should be [batch_size, vocab_size]
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token in the top-k tokens
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]

        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]

        logits[0,indices_to_remove] = filter_value
    return logits


def generate_custom_text( inputs,input_ids, max_length=50, top_k=50, top_p=0.95):

    # Initialize generated tokens list
    done = False
    while not done:
        generated_text = ""
        generated = input_ids
        for _ in range(max_length):
            outputs = model(input_ids=generated)
    
    
            next_token_logits = outputs.logits[:, -1, :]
            
            # Mask tokens not in custom action space by setting their logits to a very low value
            #mask = torch.full(next_token_logits.shape, float('-inf'))
            
            #for token_id in pruned_action_space:
            #    mask[:, token_id] = next_token_logits[:, token_id]
            #next_token_logits = mask
           
            # Apply sampling techniques
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
    
            next_token = torch.multinomial(torch.nn.functional.softmax(next_token_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
            # Stop generating if the end-of-sequence token is generated
            if next_token in tokenizer.encode(tokenizer.eos_token):
                break

        generated_text = tokenizer.decode(generated.squeeze(), skip_special_tokens=True)
        print(generated_text)
        reward = _calculate_reward(generated_text,model, tokenizer)
        if reward > 0.9:
            done = True

    return generated_text



inputs = tokenizer(prompt, return_tensors='pt')
input_ids = inputs['input_ids']

tot_time = 0   
start_time = time.time()
for ind in range(10):
    print(f"Iteration: {ind}/10") 
    generated_text = generate_custom_text(inputs, input_ids)
end_time = time.time()
tot_time += (end_time - start_time)
avg_time = tot_time/10
print(f"Not Pruned: Total Time: {tot_time}; Average Time {avg_time}")

print(generated_text)



Iteration: 0/10
One of the financial scandal is the ills of the stock-market. In my case, my financial system is in a state of disrepair. I'm living with an employer who is bankrupt and can't pay the bills. I have to work seven hours a day, and I can
Iteration: 1/10
One of the financial scandal is the urn of a man who once said: 'I don't know if I want to work, but if I were to do so, then my work would be pointless'.

But the rest of us say: 'We are all just sheep,
One of the financial scandal is the othraison of large-scale debtors who take a risk that a large chunk of their holdings will never be repaid . . . And so, one of the financial othraison is that of large-scale creditors with
One of the financial scandal is the ichthyosin, or iron-clad proof that you don't need a prescription.
One of the financial scandal is the  failure by Washington Post reporters who didn't report the real story.
The truth is, the press didn't report on it.
So let's be clear, we're not about to let the

In [26]:
import torch 
from transformers import GPT2LMHeadModel, GPT2Tokenizer

pruned_action_space = pruned_actions

def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering """
    assert logits.dim() == 2  # logits should be [batch_size, vocab_size]
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token in the top-k tokens
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]

        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]

        logits[0,indices_to_remove] = filter_value
    return logits
    
def generate_custom_text_pruned( inputs,input_ids, max_length=50, top_k=50, top_p=0.95):

    # Initialize generated tokens list
    done = False
    while not done:
        generated_text = ""
        generated = input_ids
        for _ in range(max_length):
            outputs = model(input_ids=generated)
        
            next_token_logits = outputs.logits[:, -1, :]
            
            # Mask tokens not in custom action space by setting their logits to a very low value
            mask = torch.full(next_token_logits.shape, float('-inf'))
            
            for token_id in pruned_action_space:
                mask[:, token_id] = next_token_logits[:, token_id]
            next_token_logits = mask
           
            # Apply sampling techniques
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
    
            next_token = torch.multinomial(torch.nn.functional.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
            # Stop generating if the end-of-sequence token is generated
            if next_token in tokenizer.encode(tokenizer.eos_token):
                break
    
        generated_text = tokenizer.decode(generated.squeeze(), skip_special_tokens=True)
        print(generated_text)
        if "bank" in generated_text or "Bank" in generated_text:
            done = True
            
    return generated_text


        

inputs = tokenizer(prompt, return_tensors='pt')
input_ids = inputs['input_ids']


tot_time = 0   
start_time = time.time()
for ind in range(10):
    print(f"Iteration: {ind+1}/10") 
    generated_text = generate_custom_text_pruned(inputs, input_ids)
end_time = time.time()
tot_time += (end_time - start_time)
avg_time = tot_time/10
print(f"Pruned: Total Time: {tot_time}; Average Time {avg_time}")

print(generated_text)



Iteration: 1/10
One of the financial scandal is the  sophisticated  investigation of  the Goldman Sachs  investigation to its very very last, the  investigation of the  Wall Street bailout  investigation,  including the  investigation of 
One of the financial scandal is the ersd, also used in the books of all world financial and financial "banking scandals. The ersd is the financial risk that all of a major financial crisis is taking place, and is a risk for financial profits and profits for the world
Iteration: 2/10
One of the financial scandal is the  in-and-of-the-day financial investigation that took place in the Bank of England in 2008. This is the last of the financial investigation's discovery.
The investigation is 
Iteration: 3/10
One of the financial scandal is the  tort-fraud and lack of due-dort-of-tort charges that were reported to the Federal Reserve by the last year of the debt crisis. It is the first time the central bank has been indicted by a US Federal
Iteration: 4/10