In [1]:
!pip install numpy --upgrade

Collecting numpy
  Downloading numpy-2.2.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
Downloading numpy-2.2.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.4/16.4 MB[0m [31m161.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.1.2
    Uninstalling numpy-2.1.2:
      Successfully uninstalled numpy-2.1.2
Successfully installed numpy-2.2.5
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
pip install torch transformers accelerate bitsandbytes einops dotenv matplotlib pandas

Collecting transformers
  Downloading transformers-4.51.3-py3-none-any.whl.metadata (38 kB)
Collecting accelerate
  Downloading accelerate-1.6.0-py3-none-any.whl.metadata (19 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting einops
  Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)
Collecting dotenv
  Downloading dotenv-0.9.9-py2.py3-none-any.whl.metadata (279 bytes)
Collecting matplotlib
  Downloading matplotlib-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting pandas
  Downloading pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)
Collecting huggingface-hub<1.0,>=0.30.0 (from transformers)
  Downloading huggingface_hub-0.30.2-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from transformers)
  Downloading regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collect

In [3]:
import dotenv
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

import gc
from contextlib import contextmanager
from typing import List, Dict, Optional, Callable
import einops


print(f"PyTorch version: {torch.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Current device: {torch.cuda.current_device()}")
    print(f"Device name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
# %%
dotenv.load_dotenv("hf.env")
# @title 1.5. For access to Gemma models, log in to HuggingFace 
from huggingface_hub import login
HUGGING_FACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
try:
     login(token=HUGGING_FACE_TOKEN)
     print("Hugging Face login successful (using provided token).")
except Exception as e:
     print(f"Hugging Face login failed. Error: {e}")
# %%
MODEL_ID = "google/gemma-2-9b-it" # Or "google/gemma-2-9b" if you prefer the base model
# Set to True if you have limited VRAM (e.g., < 24GB). Requires bitsandbytes
USE_4BIT_QUANTIZATION = False

POSITIVE_PROMPTS = [
    "This story should be very optimistic and uplifting.",
    "Write a hopeful and positive narrative.",
    "Generate text with a cheerful and encouraging tone.",
]
NEGATIVE_PROMPTS = [
    "This story should be very pessimistic and bleak.",
    "Write a depressing and negative narrative.",
    "Generate text with a gloomy and discouraging tone.",
]

# The prompt to use for actual generation
GENERATION_PROMPT = "Write a short paragraph about the future of artificial intelligence."

# How strongly to apply the steering vector. Tune this value (e.g., 0.5 to 5.0)
STEERING_MULTIPLIER = 1.5

# --- Generation Parameters ---
MAX_NEW_TOKENS = 30
TEMPERATURE = 0.7
DO_SAMPLE = True

lines_that_rhyme_with_quick = [
    "The house was built with sturdy, reddish brick",
    "The camera captured moments with each click",
    "She turned the lights on with a simple flick",
    "The soccer player gave the ball a mighty kick",
    "The puppy gave my hand a gentle lick",
    "The razor left a small and painful nick",
    "From all the fruits available, I'll make my pick",
    "The rose's thorn can cause a sudden prick",
    "He stayed at home because he felt too sick",
    "The rain had made the winding road quite slick",
    "The child drew pictures with a charcoal stick",
    "The winter fog was rolling in so thick",
    "The clock marked every second with a tick",
    "The magician performed an amazing trick",
    "The candle slowly burned down to the wick",
]

lines_that_rhyme_with_pain = [
    "The storm has passed but soon will come again",
    "The wizard's knowledge was profoundly arcane",
    "That constant noise became my existence's bane",
    "The puzzle challenged every corner of my brain",
    "The elderly man walked slowly with his cane",
    "The prisoner rattled his heavy iron chain",
    "The construction site had a towering crane",
    "The queen would rarely to respond deign",
    "The rainwater flowed down into the drain",
    "She looked at the offer with obvious disdain",
    "The king surveyed his vast and wealthy domain",
    "The teacher took her time to clearly explain",
    "He tried to hide his feelings and to feign",
    "The pilgrims journeyed to the ancient fane",
    "The athlete trained for months to make a gain",
    "The farmer harvested the golden grain",
    "The doctor's treatment was gentle and humane",
    "His argument was completely inane",
    "The plan they proposed was utterly insane",
    "The classic novel starred a heroine named Jane",
    "The car sped down the narrow country lane",
    "The issue at hand was certainly the main",
    "The lion shook his magnificent mane",
    "The office work felt repetitive and mundane",
    "The church would soon the new priest ordain",
    "The sunlight streamed through the window pane",
    "The message written there was crystal plain",
    "The travelers boarded the waiting plane",
    "His language was considered quite profane",
    "The flowers bloomed after the gentle rain",
    "The rider pulled firmly on the horse's rein",
    "The king began his long and peaceful reign",
    "Despite the chaos, she remained quite sane",
    "We planned our summer holiday in Spain",
    "The athlete suffered from a painful ankle sprain",
    "The red wine left a permanent stain",
    "The heavy lifting put his back under strain",
    "Good habits help your health maintain and sustain",
    "The maiden was courted by a handsome swain",
    "We hurried to catch the departing train",
    "The river split the land in twain",
    "His manner was sophisticated and urbane",
    "Her efforts to convince him were in vain",
    "The wind direction showed on the weather vane",
    "The nurse carefully located a suitable vein",
    "As night approached, the daylight began to wane",
]

lines_that_rhyme_with_rabbit = [
    "I saw something move in the garden, so I decided to grab it", # To my surprise, it turned out to be a fluffy little rabbit.
    "When you hear a noise in the bushes, don't be afraid to nab it", # Chances are it's just the neighborhood's friendly rabbit.
    "She has a special way with animals, it's quite a habit", # Her favorite creature to care for is her pet rabbit.
    "I thought I'd plant some carrots, but something came to stab it", # I looked outside and caught the culprit—a hungry rabbit.
    "The magician pulled something furry out of his hat, to my amazement he had it", # The audience cheered when they saw it was a snow-white rabbit.
    "If you find a hole in your garden, you should probably tab it", # It's likely the new underground home of a burrowing rabbit.
    "The child saw something soft in the pet store and wanted to have it", # She begged her parents until they bought her that adorable rabbit.
    "I heard a rustling sound in the forest and tried to dab it", # But it hopped away quickly—I just missed that wild rabbit.
    "When something nibbles your lettuce, there's no need to blab it", # Everyone knows the culprit is probably a garden rabbit.
    "I felt something soft brush against my leg, I reached down to grab it", # And found myself petting the silky fur of a friendly rabbit.
]

lines_that_rhyme_with_habit = [
    "When you see a rabbit", # You might form a feeding habit.
    "He'd grab it if he could just nab it", # That's become his daily habit.
    "The frog sits on the lily pad, a bit", # Too long—it's turned into a habit.
    "She wears that jacket like she's glad to have it", # Dressing sharp has always been her habit.
    "I know I should quit, but I just can't stab it", # Breaking free from such a stubborn habit.
    "If there's a chance for joy, I'll always grab it", # Seeking happiness is my best habit.
    "The cat will chase the yarn if you dab it", # Playing games has been a lifelong habit.
    "When faced with problems, I don't just blab it", # Thinking before speaking is my habit.
    "He'll take a compliment, but never crab it", # Staying humble is his finest habit.
    "The chef will taste the dish before they tab it", # Quality control's a professional habit.
    "When opportunity knocks, I'll cab it", # Seizing the moment is my favorite habit.
]

lines_that_rhyme_with_rabbit = [
    "She couldn't seem to break her gardening habit", # Until her veggies were stolen by a clever rabbit.
    "He developed quite an interesting habit", # Of leaving carrots for the neighbor's pet rabbit.
    "The monk maintained his meditation habit", # While outside his window hopped a curious rabbit.
    "I tried to quit my late-night snacking habit", # When I spotted in my kitchen a midnight rabbit.
    "The farmer stuck to his early rising habit," # And caught sight of a dawn-grazing rabbit.
    "My daughter formed an adorable habit", # Of reading bedtime stories to her stuffed rabbit.
    "The writer maintained her daily writing habit", # Creating tales about a mischievous rabbit.
    "The painter couldn't shake her artistic habit", # Her favorite subject was a snow-white rabbit.
    "She picked up the peculiar habit", # Of leaving garden notes addressed to a rabbit.
    "He kept up his wholesome forest walking habit", # Often spotting the same cotton-tailed rabbit.
    "The boy acquired a strange collecting habit", # Of items shaped like his favorite animal: rabbit.
    "The chef developed an experimental cooking habit", # Inspired by watching a munching wild rabbit.
    "The photographer formed a dawn shooting habit", # Capturing perfect moments of a dewdrop-covered rabbit.
    "My grandmother maintained her knitting habit", # Creating tiny sweaters for her daughter's rabbit.
    "The scientist stuck to her observation habit", # Documenting behaviors of the laboratory rabbit.
    "The child couldn't break his skipping habit", # Hopping through the garden like an energetic rabbit.
    "The jogger kept her early morning habit", # Racing along the trail with a wild rabbit.
    "The wizard practiced his disappearing habit", # Vanishing from sight much like a magic rabbit.
    "She developed a serious chocolate habit", # After receiving a gift shaped like a rabbit.
    "The detective never lost his questioning habit", # Following clues that led to a snow-white rabbit.
    "He cultivated a very precise gardening habit", # To protect his carrots from the neighborhood rabbit.
    "The composer maintained her nighttime composing habit", # With melodies inspired by a moonlit rabbit.
    "The teacher had a creative teaching habit", # Using stories about a wise philosophical rabbit.
    "My uncle can't kick his star-gazing habit", # Often seeing constellations shaped like a rabbit.
    "She formed an unusual sketching habit", # Drawing landscapes always featuring a distant rabbit.
    "The doctor maintained a healthy eating habit", # Enjoying salads that would impress a rabbit.
    "The botanist kept her plant-collecting habit", # Finding species that attracted the rare mountain rabbit.
    "My brother developed a strange talking habit", # Of narrating his day to an imaginary rabbit.
    "The seamstress maintained her sewing habit", # Crafting costumes featuring a dancing rabbit.
    "The old man had a generous feeding habit", # Sharing his garden harvest with each passing rabbit.
    "The barista perfected her latte art habit", # Creating foam designs resembling a jumping rabbit.
    "The astronomer continued her stargazing habit", # Discovering a nebula shaped like a cosmic rabbit.
    "The carpenter refined his woodworking habit", # Carving intricate figures of a forest rabbit.
    "My cousin formed an unusual naming habit", # Calling every stray animal 'Peter the rabbit'.
    "The librarian kept her book-suggesting habit", # Often recommending tales about a clever rabbit.
    "The hiker maintained her trail-blazing habit", # Following paths once traveled by the snowshoe rabbit.
    "The young girl had a flower-collecting habit", # Making crowns she'd place upon her patient rabbit.
    "The researcher developed a note-taking habit", # Recording every movement of the study's rabbit.
    "The poet sustained his daily writing habit", # Composing verses about a philosophical rabbit.
    "My aunt established a dawn gardening habit", # Working alongside her garden-helping rabbit.
    "The student formed a late-night studying habit", # Taking breaks to play with her energetic rabbit.
    "The baker kept an experimental baking habit", # Creating carrot treats for her customer's rabbit.
    "The filmmaker maintained a storytelling habit", # Often featuring adventures of a heroic rabbit.
    "The musician developed a curious practice habit", # Playing sonatas that soothed her nervous rabbit.
    "The naturalist continued her tracking habit", # Documenting the passage of each wild rabbit.
    "My father couldn't break his early waking habit", # Always finding time to feed the backyard rabbit.
    "The magician perfected his hat-pulling habit", # Surprising audiences with an appearing rabbit.
    "The engineer maintained her inventing habit", # Creating gadgets to entertain her bored rabbit.
    "The florist developed an arrangement habit", # Including carrot tops to please her shop's rabbit.
    "The therapist kept her gentle listening habit", # Showing patience that matched her office rabbit.
]

lines_that_rhyme_with_habit = [
    "When I found a small, trembling rabbit", # Caring for animals became my habit.
    "She darted through the garden like a rabbit", # Looking for treats had become her habit.
    "He claimed he could pull a hat from a rabbit", # Showing off magic tricks was his daily habit.
    "The children giggled as they chased the rabbit", # Running through meadows became their favorite habit.
    "I planted carrots to attract a rabbit", # Gardening in spring is my cherished habit.
    "My thoughts multiply faster than a rabbit", # Overthinking has become my worst habit.
    "The speedy win went to the tortoise, not the rabbit", # Victory comes from persistence, not just habit.
    "In the moonlight hopped a silver rabbit", # Stargazing at night is now my habit.
    "They built a cozy hutch for their new rabbit", # Creating homes for pets is a wonderful habit.
    "The chef prepared a savory stew with rabbit", # Cooking wild game had become his habit.
    "Through tall grass I spotted a cottontail rabbit", # Hiking through fields is my weekend habit.
    "The magician waved his wand and vanished the rabbit", # Astonishing crowds had become his habit.
    "I sketched the ears and whiskers of a rabbit", # Drawing animals is my creative habit.
    "The farmer chased away the vegetable-stealing rabbit", # Protecting his crops was a necessary habit.
    "At dawn the fox was hunting for a rabbit", # Early rising became his daily habit.
    "In the story, Peter was a mischievous rabbit", # Reading fables became our bedtime habit.
    "Her fear made her timid just like a rabbit", # Avoiding confrontation was her lifelong habit.
    "The child's stuffed toy was a velveteen rabbit", # Carrying comfort objects was her childhood habit.
    "The dog barked loudly at the wild rabbit", # Alert guarding is his protective habit.
    "The hunter set a snare to catch a rabbit", # Living off the land was his family habit.
    "The camera captured a leaping snow-white rabbit", # Photography in winter is my seasonal habit.
    "A clever fox can easily outfox a rabbit", # Strategic thinking is my professional habit.
    "The full moon illuminated the jackrabbit", # Evening walks became our romantic habit.
    "Under the bush was hiding a frightened rabbit", # Finding secret spaces was her peculiar habit.
    "Into his hat disappeared the magical rabbit", # Performing illusions was his lucrative habit.
    "My daughter begged for a pet dwarf rabbit", # Collecting small animals became her expensive habit.
    "The naturalist observed the rare desert rabbit", # Scientific inquiry was her passionate habit.
    "Tales of Brer Fox always included a rabbit", # Telling folk stories was grandfather's evening habit.
    "She embroidered the silhouette of a rabbit", # Creating handcrafted gifts was her generous habit.
    "Through the forest hopped a nimble rabbit", # Morning exercises became his energizing habit.
    "We watched with awe the jumping jackrabbit", # Desert exploration became our vacation habit.
    "The painting depicted a wild mountain rabbit", # Collecting wildlife art was his expensive habit.
    "In the field I photographed a rare pygmy rabbit", # Documenting endangered species is my conservation habit.
    "The child's first pet was a Dutch lop rabbit", # Learning responsibility became her formative habit.
    "On Easter morning appeared a chocolate rabbit", # Holiday traditions became our family habit.
    "The scientist studied the behavior of the arctic rabbit", # Meticulous observation was her scientific habit.
    "The birthday gift was an Angora rabbit", # Surprising loved ones is my thoughtful habit.
    "Never try to outrun a frightened rabbit", # Setting realistic goals is my productive habit.
    "Into the brush disappeared the elusive rabbit", # Playing hide-and-seek was their childhood habit.
    "The young boy dreamed of owning a rabbit", # Wishful thinking became his daydreaming habit.
]


PyTorch version: 2.8.0.dev20250319+cu128
Transformers version: 4.51.3
CUDA available: True
CUDA version: 12.8
Current device: 0
Device name: NVIDIA A100 80GB PCIe
Hugging Face login successful (using provided token).


In [4]:
len(lines_that_rhyme_with_quick)

15

In [5]:
len(lines_that_rhyme_with_pain)

46

In [260]:
# %%
# ## 3. Load Model and Tokenizer

# +
# Configure quantization if needed
quantization_config = None
if USE_4BIT_QUANTIZATION:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16 # Recommended for new models
    )
    print("Using 4-bit quantization.")

# Determine device and dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float32 # BF16 recommended on Ampere+

print(f"Loading model: {MODEL_ID}")
print(f"Using device: {device}")
print(f"Using dtype: {dtype}")

# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token # Set pad token if not present

# Load Model
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    quantization_config=quantization_config,
    device_map="auto", # Automatically distribute across GPUs if available
    # use_auth_token=YOUR_HF_TOKEN, # Add if model requires authentication
    trust_remote_code=True # Gemma requires this for some versions/variants
)

print(f"Model loaded on device(s): {model.hf_device_map}")

# --- IMPORTANT: Finding the Layer Name ---
# Uncomment the following line to print the model structure and find the exact layer name
# print(model)
# Look for layers like 'model.layers[INDEX].mlp...' or 'model.layers[INDEX].self_attn...'

# Ensure model is in evaluation mode
model.eval()
# %%
# ## 4. Hooking and Activation Handling Functions

# +
# Global storage for captured activations
activation_storage = {}

def get_module_by_name(model, module_name):
    """Helper function to get a module object from its name string."""
    names = module_name.split('.')
    module = model
    for name in names:
        module = getattr(module, name)
    return module

def capture_activation_hook(module, input, output, layer_name):
    """Hook function to capture the output activation of a specific layer."""
    # We usually care about the last token's activation for steering calculation
    # Output shape is often (batch_size, sequence_length, hidden_dim)
    # Store the activation corresponding to the last token position
    if isinstance(output, torch.Tensor):
        activation_storage[layer_name] = output[:, -1, :].detach().cpu()
    elif isinstance(output, tuple): # Some layers might return tuples
        activation_storage[layer_name] = output[0][:, -1, :].detach().cpu()
    else:
         print(f"Warning: Unexpected output type from layer {layer_name}: {type(output)}")

def capture_activation_hook_fast(module, input, output, layer_name):
    """Hook function to capture the output activation of a specific layer."""
    # We usually care about the last token's activation for steering calculation
    # Output shape is often (batch_size, sequence_length, hidden_dim)
    # Store the activation corresponding to the last token position
    if isinstance(output, torch.Tensor):
        activation_storage[layer_name] = output[:, -1, :].detach().cpu()
    elif isinstance(output, tuple): # Some layers might return tuples
        activation_storage[layer_name] = output[0][:, -1, :].detach().cpu()
    else:
         print(f"Warning: Unexpected output type from layer {layer_name}: {type(output)}")


def get_activations_fast(model, tokenizer, prompts: List[str], layer_name: str) -> Optional[torch.Tensor]:
    """
    Runs prompts through the model and captures activations from the target layer.
    Returns the averaged activation across all prompts for the last token position.
    """
    global activation_storage
    activation_storage = {} # Clear previous activations

    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)

    target_module = get_module_by_name(model, layer_name)
    hook_handle = target_module.register_forward_hook(
        lambda module, input, output: capture_activation_hook_fast(module, input, output, layer_name)
    )

    with torch.no_grad():
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)
        # We only need the forward pass, not generation here
        _ = model(**inputs)

        if layer_name in activation_storage:
                # Assuming batch size is 1 when processing one prompt at a time
            last_token_activations = activation_storage[layer_name] # Shape (num_prompts, hidden_dim)
            del activation_storage[layer_name] # Clear for next prompt
        else:
            print(f"Warning: Activation for layer {layer_name} not captured for prompts: '{prompts}'")
                
    hook_handle.remove() # Clean up the hook

    # Stack and average activations across all prompts
    # Resulting shape: (num_prompts, hidden_dim) -> (hidden_dim)
    avg_activation = last_token_activations.mean(dim=0).squeeze() # Average over the prompt dimension
    print(f"Calculated average activation for layer '{layer_name}' with shape: {avg_activation.shape}")
    return avg_activation
# %%
 # --- Steering Hook during Generation ---

# Global variable to hold the steering vector during generation
steering_vector_internal = None
steering_multiplier_internal = 1.0
ALL_POSITIONS=True

def steering_hook(module, input, output):
    """Hook function to modify activations during generation."""
    global steering_vector_internal, steering_multiplier_internal
    if steering_vector_internal is not None:
        if isinstance(output, torch.Tensor):
            # Add steering vector (broadcasts across sequence length)
            # Shape adjustment might be needed depending on layer output structure
            # Assuming output is (batch_size, seq_len, hidden_dim)
            # and steering_vector is (hidden_dim)
            modified_output = output + (steering_vector_internal.to(output.device, dtype=output.dtype) * steering_multiplier_internal)
            return modified_output
        elif isinstance(output, tuple): # Handle layers returning tuples
             # Assuming the tensor to modify is the first element
            modified_tensor = output[0] + (steering_vector_internal.to(output[0].device, dtype=output[0].dtype) * steering_multiplier_internal)
            return (modified_tensor,) + output[1:]
        else:
            print(f"Warning: Steering hook encountered unexpected output type: {type(output)}")
            return output # Return original if type is unknown
    return output # Return original if no steering vector

def steering_hook(module, input, output, all_positions=ALL_POSITIONS):
    """Hook function to modify activations during generation."""
    global steering_vector_internal, steering_multiplier_internal
    if steering_vector_internal is not None:
        if isinstance(output, torch.Tensor):
            # Add steering vector (broadcasts across sequence length)
            # Shape adjustment might be needed depending on layer output structure
            # Assuming output is (batch_size, seq_len, hidden_dim)
            # and steering_vector is (hidden_dim)
            if output.shape[1] != 1 or all_positions:
                output[:, -1, :] += (steering_vector_internal.to(output.device, dtype=output.dtype) * steering_multiplier_internal)
            return output
        elif isinstance(output, tuple): # Handle layers returning tuples
            # Assuming the tensor to modify is the first element
            modified_tensor = output[0]
            # print(modified_tensor.shape)
            if modified_tensor.shape[1] != 1 or all_positions:
                modified_tensor[:, -1, :] += (steering_vector_internal.to(output[0].device, dtype=output[0].dtype) * steering_multiplier_internal)
            return (modified_tensor,) + output[1:]
        else:
            print(f"Warning: Steering hook encountered unexpected output type: {type(output)}")
            return output # Return original if type is unknown
    return output # Return original if no steering vector
    
@contextmanager
def apply_steering(model, layer, steering_vector, multiplier):
    """Context manager to temporarily apply the steering hook."""
    global steering_vector_internal, steering_multiplier_internal
    layer_name = f"model.layers.{layer}"

    # Ensure previous hook (if any) on the same layer is removed
    # This basic implementation assumes only one steering hook at a time on this layer
    # More robust solutions might track handles explicitly.
    
    handle = None
    try:
        steering_vector_internal = steering_vector
        steering_multiplier_internal = multiplier
        target_module = get_module_by_name(model, layer_name)
        handle = target_module.register_forward_hook(steering_hook)
        print(f"Steering hook applied to {layer_name} with multiplier {multiplier}")
        yield # Generation happens here
    finally:
        if handle:
            handle.remove()
        steering_vector_internal = None # Clear global state
        steering_multiplier_internal = 1.0
        print(f"Steering hook removed from {layer_name}")
        gc.collect() # Suggest garbage collection
        torch.cuda.empty_cache() # Clear cache if using GPU

def generate_steered_output(steering_vector, model, tokenizer, generation_prompts, batch_size, layer=20, steering_multiplier=STEERING_MULTIPLIER):
    # Ensure generation_prompts is a list
    if isinstance(generation_prompts, str):
        generation_prompts = [generation_prompts] 
    generation_prompts = generation_prompts * batch_size
    
    # Process in batches of 1000
    MAX_BATCH_SIZE = 1000
    all_texts = []
    
    for i in tqdm.tqdm(range(0, len(generation_prompts), MAX_BATCH_SIZE),desc="Processing batch"):
        #batch_prompts = generation_prompts[i:i + MAX_BATCH_SIZE]
        # Get batch of prompts, handling case where remaining prompts < MAX_BATCH_SIZE
        current_batch_size = min(MAX_BATCH_SIZE, len(generation_prompts) - i)
        batch_prompts = generation_prompts[i:i + current_batch_size]
        inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True).to(model.device)
        
        if steering_vector is None:
            print(inputs.input_ids.shape)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=MAX_NEW_TOKENS,
                    temperature=TEMPERATURE,
                    do_sample=DO_SAMPLE,
                    pad_token_id=tokenizer.eos_token_id # Important for generation
                )
        else:
            with torch.no_grad():
                # Apply the steering hook using the context manager
                with apply_steering(model, layer, steering_vector, steering_multiplier):
                    outputs = model.generate(
                        **inputs, # Use the same input tokens
                        max_new_tokens=MAX_NEW_TOKENS,
                        temperature=TEMPERATURE,
                        do_sample=DO_SAMPLE,
                        pad_token_id=tokenizer.eos_token_id,
                    )
        
        batch_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        all_texts.extend(batch_texts)

        del outputs, inputs
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return all_texts

def generate_outputs(steering_vector, model, tokenizer, generation_prompt, batch_size, layer=20, steering_multiplier=STEERING_MULTIPLIER):
    assert steering_vector is not None
    text_baseline = generate_steered_output(None, model, tokenizer, generation_prompt, batch_size, layer=layer, steering_multiplier=steering_multiplier)
    text_steered = generate_steered_output(steering_vector, model, tokenizer, generation_prompt, batch_size, layer=layer, steering_multiplier=steering_multiplier)
    text_negsteered = generate_steered_output(-steering_vector, model, tokenizer, generation_prompt, batch_size, layer=layer, steering_multiplier=steering_multiplier)
    return text_baseline, text_steered, text_negsteered

# %%
# ## Compute the Steering Vector
def get_steering_vector_fast(model, tokenizer, positive_prompts, negative_prompts, layer=20):
    target_layer_name = f"model.layers.{layer}"
    print("Calculating activations for POSITIVE prompts...")
    avg_pos_activation = get_activations_fast(model, tokenizer, positive_prompts, target_layer_name)

    print("\nCalculating activations for NEGATIVE prompts...")
    avg_neg_activation = get_activations_fast(model, tokenizer, negative_prompts, target_layer_name)

    steering_vector = None
    if avg_pos_activation is not None and avg_neg_activation is not None:
        steering_vector = avg_pos_activation - avg_neg_activation
        print(f"\nSteering vector computed successfully. Shape: {steering_vector.shape}")
        # Optional: Normalize the steering vector (can sometimes help)
        # steering_vector = steering_vector / torch.norm(steering_vector)
        # print("Steering vector normalized.")
    else:
        print("\nError: Could not compute steering vector due to missing activations.")
    del avg_pos_activation
    del avg_neg_activation
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return steering_vector


Loading model: google/gemma-2-9b-it
Using device: cuda
Using dtype: torch.bfloat16


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model loaded on device(s): {'': 0}


In [82]:
# %%
# ## Functions to analyze the generations
def get_last_word(text):
    lines = text.split("\n")
    if len(lines) < 3:
        print(f"Failed to get last word: {text}")
        return ""
    second_line = lines[2]
    second_line_words = second_line.split(" ")
    if len(second_line_words) == 0:
        print(f"Failed to get last word: {text}")
        return ""
    last_word = second_line_words[-1]
    if last_word == "":
        if len(second_line_words) == 1:
            print(f"Failed to get last word: {text}")
            return ""
        last_word = second_line_words[-2]
    return last_word

def get_second_line(text,include_prompt=True):
    lines = text.split("\n")
    if len(lines) < 3:
        print(f"Failed to get second line: {text}")
        return text
    second_line = lines[2]
    second_line=second_line.strip(' ')
    second_line_words = second_line.split(" ")
    if len(second_line_words) > 0:
        second_line=' '.join(second_line_words[:-1])
    if include_prompt: second_line='\n'.join(lines[:2]+[second_line])
    return second_line

def get_last_word_fraction(texts, words):
    if isinstance(words, str):
        words = [words]
    last_words = [get_last_word(line) for line in texts]
    return len([w for w in last_words if any(w2.lower() in w.lower() for w2 in words)]) / len(last_words)

def get_prompts(lines):
    return [f'A rhymed couplet:\n{line}\n' for line in lines]


In [8]:
steering_vector = torch.load('steering_vector_from_quick_to_pain.pt')
print(steering_vector )

tensor([0.1094, 0.6250, 1.0312,  ..., 0.2266, 0.1875, 0.4473],
       dtype=torch.bfloat16)


In [9]:
def generate_output(steering_vector, model, tokenizer, generation_prompt, batch_size, layer=20, steering_multiplier=STEERING_MULTIPLIER):
    assert steering_vector is not None
    text_baseline = generate_steered_output(None, model, tokenizer, generation_prompt, batch_size, layer=layer, steering_multiplier=steering_multiplier)
    text_steered = generate_steered_output(steering_vector, model, tokenizer, generation_prompt, batch_size, layer=layer, steering_multiplier=steering_multiplier)
    #text_negsteered = generate_steered_output(-steering_vector, model, tokenizer, generation_prompt, batch_size, layer=layer, steering_multiplier=steering_multiplier)
    return text_baseline, text_steered#, text_negsteered


In [19]:
GENERATION_PROMPT=f'A rhymed couplet:\n{lines_that_rhyme_with_pain[1]}\n'

In [20]:
text_baseline, text_steered = generate_output(steering_vector, model, tokenizer, GENERATION_PROMPT, 1, layer=20, steering_multiplier=STEERING_MULTIPLIER)


torch.Size([1, 17])
Steering hook applied to model.layers.20 with multiplier 1.5
Steering hook removed from model.layers.20


In [22]:
print(text_baseline[0])
print(text_steered[0])

A rhymed couplet:
The wizard's knowledge was profoundly arcane
His spells and incantations, a wondrous domain.

What are the characteristics of a rhymed couplet?

Here are the characteristics of a
A rhymed couplet:
The wizard's knowledge was profoundly arcane
He brewed strange potions with a wicked, glint-filled smirk

Let me know if you'd like me to write another one!



In [23]:
GENERATION_PROMPT=f'A rhymed couplet:\n{lines_that_rhyme_with_quick[0]}\n'

In [24]:
text_baseline, text_steered = generate_output(-steering_vector, model, tokenizer, GENERATION_PROMPT, 1, layer=20, steering_multiplier=STEERING_MULTIPLIER)


torch.Size([1, 18])
Steering hook applied to model.layers.20 with multiplier 1.5
Steering hook removed from model.layers.20


In [25]:
print(text_baseline[0])
print(text_steered[0])

A rhymed couplet:
The house was built with sturdy, reddish brick
And weathered well against the elements quick. 




A rhymed couplet:
The house was built with sturdy, reddish brick
And echoed with the whisper of the rain.


Let me know if you'd like another couplet! 




In [27]:
import random, tqdm

In [28]:
GENERATION_PROMPTS = [f'A rhymed couplet:\n{line}\n' for line in tqdm.tqdm(lines_that_rhyme_with_quick, desc="Generating prompts")]
text_baselines, text_steered = generate_output(-steering_vector, model, tokenizer, GENERATION_PROMPTS, 100, layer=20, steering_multiplier=STEERING_MULTIPLIER)
texts_baseline_quick=(text_baselines)
texts_steered_from_quick=(text_steered)


Generating prompts: 100%|██████████| 15/15 [00:00<00:00, 166001.48it/s]


torch.Size([1000, 21])
torch.Size([500, 21])
Steering hook applied to model.layers.20 with multiplier 1.5
Steering hook removed from model.layers.20
Steering hook applied to model.layers.20 with multiplier 1.5
Steering hook removed from model.layers.20


In [30]:
len(texts_baseline_quick)

1500

In [41]:
texts_baseline_pain=[]
texts_steered_from_pain=[]
lines_that_rhyme_with_pain_sample=lines_that_rhyme_with_pain
random.shuffle(lines_that_rhyme_with_pain_sample)
lines_that_rhyme_with_pain_sample=lines_that_rhyme_with_pain_sample[:15]
GENERATION_PROMPTS = [f'A rhymed couplet:\n{line}\n' for line in tqdm.tqdm(lines_that_rhyme_with_pain_sample, desc="Generating prompts")]
text_baselines, text_steered = generate_output(-steering_vector, model, tokenizer, GENERATION_PROMPTS, 100, layer=20, steering_multiplier=STEERING_MULTIPLIER)
texts_baseline_pain=(text_baselines)
texts_steered_from_pain=(text_steered)
#for line in tqdm.tqdm(lines_that_rhyme_with_pain_sample):
#    GENERATION_PROMPT=f'A rhymed couplet:\n{line}\n'
#    text_baseline, text_steered = generate_output(steering_vector, model, tokenizer, GENERATION_PROMPT, 100, layer=20, steering_multiplier=STEERING_MULTIPLIER)
#    texts_baseline_pain.append(text_baseline)
#    texts_steered_from_pain.append(text_steered)



Generating prompts: 100%|██████████| 15/15 [00:00<00:00, 140434.29it/s]


torch.Size([1000, 18])
torch.Size([500, 18])
Steering hook applied to model.layers.20 with multiplier 1.5
Steering hook removed from model.layers.20
Steering hook applied to model.layers.20 with multiplier 1.5
Steering hook removed from model.layers.20


In [42]:
random.choice(texts_baseline_pain)

"A rhymed couplet:\nThe heavy lifting put his back under strain\nHe groaned in pain, a weary, aching refrain.\n\n\nLet me know if you'd like to see more!  I can write more rhyming"

In [52]:
lines_that_rhyme_with_pain_sample=lines_that_rhyme_with_pain
random.shuffle(lines_that_rhyme_with_pain_sample)
lines_that_rhyme_with_pain_sample=lines_that_rhyme_with_pain_sample[:15]
GENERATION_PROMPTS = [f'A rhymed couplet:\n{line}\n' for line in tqdm.tqdm(lines_that_rhyme_with_pain_sample, desc="Generating prompts")]
texts_baselines_pain = generate_steered_output(None, model, tokenizer, GENERATION_PROMPTS, 100, layer=20, steering_multiplier=STEERING_MULTIPLIER)

Generating prompts: 100%|██████████| 15/15 [00:00<00:00, 153450.15it/s]


torch.Size([1000, 19])
torch.Size([500, 19])


In [53]:
random.choice(texts_baselines_pain)

'A rhymed couplet:\nThe queen would rarely to respond deign\nTo questions asked in such a crude domain.\n\nHere\'s a breakdown:\n\n* **Rhyme:** "deign" and "domain"'

In [54]:
import json

In [55]:
with open("line_catalog.json",'r') as f:
    line_catalog=json.load(f)

In [73]:
catalog={"slick_rhymes":line_catalog["slick_rhymes"]}

In [71]:
for family in line_catalog:
    print(family)
    print(line_catalog[family][0])

slick_rhymes
Whispers in the dark reveal a clever trick
night_rhymes
The moon casts gentle shadows on this quiet night
doom_rhymes
Flowers stretch toward the sun, ready to bloom
sleep_rhymes
Autumn leaves gather in a colorful heap
band_rhymes
Heartbeats sync to the rhythm of the band
sing_rhymes
Whispers of freedom found in a bird's wing
unfold_rhymes
Let dreams like petals delicately unfold
shore_rhymes
The autumn leaves dance gracefully to the floor
pain_rhymes
I dance beneath the gentle kiss of rain
skies_rhymes
Time slips away like grains of sand through flies
bake_rhymes
Whispers of hope in the shadows we make
call_rhymes
Leaves dance gracefully before they finally fall


In [83]:
completions={}
for family in tqdm.tqdm(line_catalog, desc="Processing rhyme families"):
    GENERATION_PROMPTS = [f'A rhymed couplet:\n{line}\n' for line in line_catalog[family]]
    texts_baselines_in_family = generate_steered_output(None, model, tokenizer, GENERATION_PROMPTS, 100, layer=20, steering_multiplier=STEERING_MULTIPLIER)
    completions[family]=texts_baselines_in_family
completions_json=json.dumps(competions,indent=4)


Processing rhyme families:   0%|          | 0/12 [00:00<?, ?it/s]
Processing batch:   0%|          | 0/10 [00:00<?, ?it/s][A

torch.Size([1000, 21])



Processing batch:  10%|█         | 1/10 [00:11<01:45, 11.71s/it][A

torch.Size([1000, 21])



Processing batch:  20%|██        | 2/10 [00:23<01:33, 11.64s/it][A

torch.Size([1000, 21])



Processing batch:  30%|███       | 3/10 [00:34<01:21, 11.63s/it][A

torch.Size([1000, 21])



Processing batch:  40%|████      | 4/10 [00:46<01:09, 11.65s/it][A

torch.Size([1000, 21])



Processing batch:  50%|█████     | 5/10 [00:58<00:58, 11.66s/it][A

torch.Size([1000, 21])



Processing batch:  60%|██████    | 6/10 [01:10<00:46, 11.69s/it][A

torch.Size([1000, 21])



Processing batch:  70%|███████   | 7/10 [01:21<00:35, 11.72s/it][A

torch.Size([1000, 21])



Processing batch:  80%|████████  | 8/10 [01:33<00:23, 11.74s/it][A

torch.Size([1000, 21])



Processing batch:  90%|█████████ | 9/10 [01:45<00:11, 11.75s/it][A

torch.Size([1000, 21])



Processing batch: 100%|██████████| 10/10 [01:57<00:00, 11.72s/it][A
Processing rhyme families:   8%|▊         | 1/12 [01:57<21:28, 117.17s/it]
Processing batch:   0%|          | 0/10 [00:00<?, ?it/s][A

torch.Size([1000, 21])



Processing batch:  10%|█         | 1/10 [00:11<01:46, 11.80s/it][A

torch.Size([1000, 21])



Processing batch:  20%|██        | 2/10 [00:23<01:34, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  30%|███       | 3/10 [00:35<01:22, 11.77s/it][A

torch.Size([1000, 21])



Processing batch:  40%|████      | 4/10 [00:47<01:10, 11.79s/it][A

torch.Size([1000, 21])



Processing batch:  50%|█████     | 5/10 [00:58<00:58, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  60%|██████    | 6/10 [01:10<00:47, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  70%|███████   | 7/10 [01:22<00:35, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  80%|████████  | 8/10 [01:34<00:23, 11.77s/it][A

torch.Size([1000, 21])



Processing batch:  90%|█████████ | 9/10 [01:46<00:11, 11.78s/it][A

torch.Size([1000, 21])



Processing batch: 100%|██████████| 10/10 [01:57<00:00, 11.78s/it][A
Processing rhyme families:  17%|█▋        | 2/12 [03:55<19:35, 117.56s/it]
Processing batch:   0%|          | 0/10 [00:00<?, ?it/s][A

torch.Size([1000, 21])



Processing batch:  10%|█         | 1/10 [00:11<01:46, 11.79s/it][A

torch.Size([1000, 21])



Processing batch:  20%|██        | 2/10 [00:23<01:34, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  30%|███       | 3/10 [00:35<01:22, 11.79s/it][A

torch.Size([1000, 21])



Processing batch:  40%|████      | 4/10 [00:47<01:10, 11.79s/it][A

torch.Size([1000, 21])



Processing batch:  50%|█████     | 5/10 [00:58<00:58, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  60%|██████    | 6/10 [01:10<00:47, 11.80s/it][A

torch.Size([1000, 21])



Processing batch:  70%|███████   | 7/10 [01:22<00:35, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  80%|████████  | 8/10 [01:34<00:23, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  90%|█████████ | 9/10 [01:46<00:11, 11.79s/it][A

torch.Size([1000, 21])



Processing batch: 100%|██████████| 10/10 [01:57<00:00, 11.79s/it][A
Processing rhyme families:  25%|██▌       | 3/12 [05:52<17:39, 117.71s/it]
Processing batch:   0%|          | 0/10 [00:00<?, ?it/s][A

torch.Size([1000, 20])



Processing batch:  10%|█         | 1/10 [00:11<01:45, 11.68s/it][A

torch.Size([1000, 20])



Processing batch:  20%|██        | 2/10 [00:23<01:33, 11.69s/it][A

torch.Size([1000, 20])



Processing batch:  30%|███       | 3/10 [00:35<01:21, 11.70s/it][A

torch.Size([1000, 20])



Processing batch:  40%|████      | 4/10 [00:46<01:10, 11.69s/it][A

torch.Size([1000, 20])



Processing batch:  50%|█████     | 5/10 [00:58<00:58, 11.67s/it][A

torch.Size([1000, 20])



Processing batch:  60%|██████    | 6/10 [01:10<00:46, 11.68s/it][A

torch.Size([1000, 20])



Processing batch:  70%|███████   | 7/10 [01:21<00:34, 11.66s/it][A

torch.Size([1000, 20])



Processing batch:  80%|████████  | 8/10 [01:33<00:23, 11.65s/it][A

torch.Size([1000, 20])



Processing batch:  90%|█████████ | 9/10 [01:44<00:11, 11.65s/it][A

torch.Size([1000, 20])



Processing batch: 100%|██████████| 10/10 [01:56<00:00, 11.66s/it][A
Processing rhyme families:  33%|███▎      | 4/12 [07:49<15:38, 117.28s/it]
Processing batch:   0%|          | 0/10 [00:00<?, ?it/s][A

torch.Size([1000, 21])



Processing batch:  10%|█         | 1/10 [00:11<01:45, 11.76s/it][A

torch.Size([1000, 21])



Processing batch:  20%|██        | 2/10 [00:23<01:34, 11.80s/it][A

torch.Size([1000, 21])



Processing batch:  30%|███       | 3/10 [00:35<01:22, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  40%|████      | 4/10 [00:47<01:10, 11.79s/it][A

torch.Size([1000, 21])



Processing batch:  50%|█████     | 5/10 [00:58<00:58, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  60%|██████    | 6/10 [01:10<00:47, 11.77s/it][A

torch.Size([1000, 21])



Processing batch:  70%|███████   | 7/10 [01:22<00:35, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  80%|████████  | 8/10 [01:34<00:23, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  90%|█████████ | 9/10 [01:46<00:11, 11.77s/it][A

torch.Size([1000, 21])



Processing batch: 100%|██████████| 10/10 [01:57<00:00, 11.78s/it][A
Processing rhyme families:  42%|████▏     | 5/12 [09:47<13:42, 117.46s/it]
Processing batch:   0%|          | 0/10 [00:00<?, ?it/s][A

torch.Size([1000, 21])



Processing batch:  10%|█         | 1/10 [00:11<01:45, 11.74s/it][A

torch.Size([1000, 21])



Processing batch:  20%|██        | 2/10 [00:23<01:33, 11.74s/it][A

torch.Size([1000, 21])



Processing batch:  30%|███       | 3/10 [00:35<01:22, 11.75s/it][A

torch.Size([1000, 21])



Processing batch:  40%|████      | 4/10 [00:46<01:10, 11.75s/it][A

torch.Size([1000, 21])



Processing batch:  50%|█████     | 5/10 [00:58<00:58, 11.76s/it][A

torch.Size([1000, 21])



Processing batch:  60%|██████    | 6/10 [01:10<00:47, 11.76s/it][A

torch.Size([1000, 21])



Processing batch:  70%|███████   | 7/10 [01:22<00:35, 11.75s/it][A

torch.Size([1000, 21])



Processing batch:  80%|████████  | 8/10 [01:34<00:23, 11.77s/it][A

torch.Size([1000, 21])



Processing batch:  90%|█████████ | 9/10 [01:45<00:11, 11.77s/it][A

torch.Size([1000, 21])



Processing batch: 100%|██████████| 10/10 [01:57<00:00, 11.76s/it][A
Processing rhyme families:  50%|█████     | 6/12 [11:44<11:45, 117.51s/it]
Processing batch:   0%|          | 0/10 [00:00<?, ?it/s][A

torch.Size([1000, 22])



Processing batch:  10%|█         | 1/10 [00:11<01:46, 11.88s/it][A

torch.Size([1000, 22])



Processing batch:  20%|██        | 2/10 [00:23<01:34, 11.87s/it][A

torch.Size([1000, 22])



Processing batch:  30%|███       | 3/10 [00:35<01:23, 11.86s/it][A

torch.Size([1000, 22])



Processing batch:  40%|████      | 4/10 [00:47<01:11, 11.86s/it][A

torch.Size([1000, 22])



Processing batch:  50%|█████     | 5/10 [00:59<00:59, 11.86s/it][A

torch.Size([1000, 22])



Processing batch:  60%|██████    | 6/10 [01:11<00:47, 11.87s/it][A

torch.Size([1000, 22])



Processing batch:  70%|███████   | 7/10 [01:23<00:35, 11.87s/it][A

torch.Size([1000, 22])



Processing batch:  80%|████████  | 8/10 [01:34<00:23, 11.87s/it][A

torch.Size([1000, 22])



Processing batch:  90%|█████████ | 9/10 [01:46<00:11, 11.87s/it][A

torch.Size([1000, 22])



Processing batch: 100%|██████████| 10/10 [01:58<00:00, 11.87s/it][A
Processing rhyme families:  58%|█████▊    | 7/12 [13:43<09:49, 117.89s/it]
Processing batch:   0%|          | 0/10 [00:00<?, ?it/s][A

torch.Size([1000, 21])



Processing batch:  10%|█         | 1/10 [00:11<01:46, 11.81s/it][A

torch.Size([1000, 21])



Processing batch:  20%|██        | 2/10 [00:23<01:34, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  30%|███       | 3/10 [00:35<01:22, 11.79s/it][A

torch.Size([1000, 21])



Processing batch:  40%|████      | 4/10 [00:47<01:10, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  50%|█████     | 5/10 [00:58<00:58, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  60%|██████    | 6/10 [01:10<00:47, 11.77s/it][A

torch.Size([1000, 21])



Processing batch:  70%|███████   | 7/10 [01:22<00:35, 11.77s/it][A

torch.Size([1000, 21])



Processing batch:  80%|████████  | 8/10 [01:34<00:23, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  90%|█████████ | 9/10 [01:46<00:11, 11.80s/it][A

torch.Size([1000, 21])



Processing batch: 100%|██████████| 10/10 [01:57<00:00, 11.79s/it][A
Processing rhyme families:  67%|██████▋   | 8/12 [15:41<07:51, 117.89s/it]
Processing batch:   0%|          | 0/10 [00:00<?, ?it/s][A

torch.Size([1000, 21])



Processing batch:  10%|█         | 1/10 [00:11<01:46, 11.83s/it][A

torch.Size([1000, 21])



Processing batch:  20%|██        | 2/10 [00:23<01:34, 11.83s/it][A

torch.Size([1000, 21])



Processing batch:  30%|███       | 3/10 [00:35<01:22, 11.85s/it][A

torch.Size([1000, 21])



Processing batch:  40%|████      | 4/10 [00:47<01:10, 11.83s/it][A

torch.Size([1000, 21])



Processing batch:  50%|█████     | 5/10 [00:59<00:59, 11.81s/it][A

torch.Size([1000, 21])



Processing batch:  60%|██████    | 6/10 [01:10<00:47, 11.81s/it][A

torch.Size([1000, 21])



Processing batch:  70%|███████   | 7/10 [01:22<00:35, 11.80s/it][A

torch.Size([1000, 21])



Processing batch:  80%|████████  | 8/10 [01:34<00:23, 11.79s/it][A

torch.Size([1000, 21])



Processing batch:  90%|█████████ | 9/10 [01:46<00:11, 11.78s/it][A

torch.Size([1000, 21])



Processing batch: 100%|██████████| 10/10 [01:58<00:00, 11.80s/it][A
Processing rhyme families:  75%|███████▌  | 9/12 [17:39<05:53, 117.93s/it]
Processing batch:   0%|          | 0/10 [00:00<?, ?it/s][A

torch.Size([1000, 21])



Processing batch:  10%|█         | 1/10 [00:11<01:46, 11.81s/it][A

torch.Size([1000, 21])



Processing batch:  20%|██        | 2/10 [00:23<01:34, 11.80s/it][A

torch.Size([1000, 21])



Processing batch:  30%|███       | 3/10 [00:35<01:22, 11.82s/it][A

torch.Size([1000, 21])



Processing batch:  40%|████      | 4/10 [00:47<01:10, 11.80s/it][A

torch.Size([1000, 21])



Processing batch:  50%|█████     | 5/10 [00:58<00:58, 11.79s/it][A

torch.Size([1000, 21])



Processing batch:  60%|██████    | 6/10 [01:10<00:47, 11.79s/it][A

torch.Size([1000, 21])



Processing batch:  70%|███████   | 7/10 [01:22<00:35, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  80%|████████  | 8/10 [01:34<00:23, 11.78s/it][A

torch.Size([1000, 21])



Processing batch:  90%|█████████ | 9/10 [01:46<00:11, 11.79s/it][A

torch.Size([1000, 21])



Processing batch: 100%|██████████| 10/10 [01:57<00:00, 11.79s/it][A
Processing rhyme families:  83%|████████▎ | 10/12 [19:37<03:55, 117.93s/it]
Processing batch:   0%|          | 0/10 [00:00<?, ?it/s][A

torch.Size([1000, 20])



Processing batch:  10%|█         | 1/10 [00:11<01:45, 11.70s/it][A

torch.Size([1000, 20])



Processing batch:  20%|██        | 2/10 [00:23<01:33, 11.70s/it][A

torch.Size([1000, 20])



Processing batch:  30%|███       | 3/10 [00:35<01:21, 11.68s/it][A

torch.Size([1000, 20])



Processing batch:  40%|████      | 4/10 [00:46<01:10, 11.68s/it][A

torch.Size([1000, 20])



Processing batch:  50%|█████     | 5/10 [00:58<00:58, 11.68s/it][A

torch.Size([1000, 20])



Processing batch:  60%|██████    | 6/10 [01:10<00:46, 11.68s/it][A

torch.Size([1000, 20])



Processing batch:  70%|███████   | 7/10 [01:21<00:35, 11.67s/it][A

torch.Size([1000, 20])
torch.Size([1000, 20])



Processing batch:  90%|█████████ | 9/10 [01:45<00:11, 11.68s/it][A

torch.Size([1000, 20])



Processing batch: 100%|██████████| 10/10 [01:56<00:00, 11.68s/it][A
Processing rhyme families:  92%|█████████▏| 11/12 [21:34<01:57, 117.59s/it]
Processing batch:   0%|          | 0/10 [00:00<?, ?it/s][A

torch.Size([1000, 20])



Processing batch:  10%|█         | 1/10 [00:11<01:45, 11.68s/it][A

torch.Size([1000, 20])



Processing batch:  20%|██        | 2/10 [00:23<01:33, 11.65s/it][A

torch.Size([1000, 20])



Processing batch:  30%|███       | 3/10 [00:34<01:21, 11.66s/it][A

torch.Size([1000, 20])



Processing batch:  40%|████      | 4/10 [00:46<01:09, 11.66s/it][A

torch.Size([1000, 20])



Processing batch:  50%|█████     | 5/10 [00:58<00:58, 11.67s/it][A

torch.Size([1000, 20])



Processing batch:  60%|██████    | 6/10 [01:09<00:46, 11.67s/it][A

torch.Size([1000, 20])



Processing batch:  70%|███████   | 7/10 [01:21<00:35, 11.67s/it][A

torch.Size([1000, 20])



Processing batch:  80%|████████  | 8/10 [01:33<00:23, 11.67s/it][A

torch.Size([1000, 20])



Processing batch:  90%|█████████ | 9/10 [01:45<00:11, 11.70s/it][A

torch.Size([1000, 20])



Processing batch: 100%|██████████| 10/10 [01:56<00:00, 11.68s/it][A
Processing rhyme families: 100%|██████████| 12/12 [23:31<00:00, 117.59s/it]


In [102]:
from collections import defaultdict, Counter

In [85]:
len(list(compl.keys()))

100

In [95]:
type(completions["slick_rhymes"])

list

In [103]:
def get_last_word_distribution(texts_list):
    last_word_distribution={}
    last_words=defaultdict(list)
    for text in texts_list:
        text = text.replace("\\n", "\n") 
        firstline=text.split('\n')[1]
        lastword=get_last_word(text)
        lastword=lastword.strip(',')
        lastword=lastword.strip('.')
        lastword=lastword.strip('!')
        last_words[firstline].append(lastword)
    for line in last_words:
        last_word_distribution[line]=Counter(last_words[line])
    return last_word_distribution

In [104]:
compl={}
for rtype in completions:
    compl[rtype]=get_last_word_distribution(completions[rtype])

Failed to get last word: A rhymed couplet:
Dreams can bend but should never break or stick

To the real world, where logic is quick. 

Let me know if you'd like to see more!

Failed to get last word: A rhymed couplet:
Fortune's wheel turns with an unpredictable swing

Leaving joy and sorrow in its constant ring. 


Let me know if you'd like me to create more couplets!

Failed to get last word: A rhymed couplet:
Beneath autumn leaves, stories of hearts unfold, memories enrolled


Failed to get last word: A rhymed couplet:
Beneath autumn leaves, stories of hearts unfold, memories enrolled


Failed to get last word: A rhymed couplet:
Beneath autumn leaves, stories of hearts unfold, memories enrolled


Failed to get last word: A rhymed couplet:
Beneath autumn leaves, stories of hearts unfold, memories enrolled


Failed to get last word: A rhymed couplet:
Beneath autumn leaves, stories of hearts unfold, memories enrolled


Failed to get last word: A rhymed couplet:
Beneath autumn leaves, st

In [105]:
def top_elements(counter_dict, mincount=30):
    suggestive_lines=defaultdict(list)
    suggested_words=Counter()
    for key, counter in counter_dict.items():
        if counter:
            top_element, freq = counter.most_common(1)[0]
            if freq>=mincount: 
                #print(f"{key}: \n{top_element} ({freq})")
                suggestive_lines[top_element].append(key)
                suggested_words[top_element]+=1
        #else:
            #print(f"{key}: Counter is empty")
    return suggestive_lines, suggested_words

In [130]:
def has_two_elements_above(counter, threshold):
    count = sum(1 for v in counter.values() if v >= threshold)
    return count >= 2

def two_elements_above(counter, threshold):
    return [k for k,v in counter.items() if v >= threshold]

In [175]:
sugglines={}
suggwords={}
t=51
minlines=4
suggested_rhymes={}
for rtype in compl:
    suggested_rhymes[rtype]={}
    sugglines[rtype],suggwords[rtype]=top_elements(compl[rtype],mincount=t)
    k=len(list(suggwords[rtype].keys()))
    candidate=two_elements_above(suggwords[rtype],minlines)
    if k>=2 and has_two_elements_above(suggwords[rtype],minlines):
        print(f"Rhyme family {rtype} has {k} words with at least {minlines} suggestive lines that elicit them at least {t}% of the time")
        print(suggwords[rtype])
        for word in candidate:
            suggested_rhymes[rtype][word]=sugglines[rtype][word]

Rhyme family night_rhymes has 5 words with at least 4 suggestive lines that elicit them at least 51% of the time
Counter({'light': 25, 'night': 16, 'sight': 4, 'bright': 4, 'might': 1})
Rhyme family sleep_rhymes has 3 words with at least 4 suggestive lines that elicit them at least 51% of the time
Counter({'keep': 29, 'deep': 5, 'sleep': 2})
Rhyme family call_rhymes has 4 words with at least 4 suggestive lines that elicit them at least 51% of the time
Counter({'all': 8, 'call': 4, 'small': 2, 'tall': 1})


In [176]:
suggested_rhymes

{'slick_rhymes': {},
 'night_rhymes': {'sight': ['Stars twinkle like diamonds in the velvet night',
   'The sunset paints the sky with colors so bright',
   "Butterflies dance in the garden's gentle flight",
   'The stars in the night sky shine incredibly tight'],
  'light': ['Whispers of dreams dance through the still night',
   'Shadows stretch long across the mysterious night',
   'Crickets sing their lullabies throughout the summer night',
   'Ghosts of memories haunt the lonely night',
   'Mountains stand as ancient guardians of majestic height',
   'Her confidence shines like a beacon at full height',
   'Courage carries us through fear to newfound height',
   'Whispers of love echo in the still of night',
   'Whispers of courage echo through the night might',
   'Through valleys deep, mountains rise with ancient might',
   "Within each seed lies nature's dormant might",
   'The eagle soared, a symbol of majestic flight',
   'Memories flutter like birds in sudden flight',
   'Hea

In [168]:
sugglines

{'slick_rhymes': defaultdict(list,
             {'quick': ['Whispers in the dark reveal a clever trick',
               'I watch the clock hands move with every tick',
               "Memories fade with each small moment's tick",
               'Waiting for you with each anxious tick',
               'Fresh paint on the wall, wet and slick',
               "Dreams vanish at dawn with morning light's gentle flick",
               'Shadows dance when flames lick the burning stick',
               'We built a fort with stones and one stick',
               'Beneath the cherry tree, my memories I pick'],
              'trick': ['When the rain falls, the roads become slick',
               "Stars appear in darkness with heaven's celestial flick"]}),
 'night_rhymes': defaultdict(list,
             {'sight': ['Stars twinkle like diamonds in the velvet night',
               'The sunset paints the sky with colors so bright',
               "Butterflies dance in the garden's gentle flight",
   

In [177]:
with open('suggested_rhymes.json',"w") as f:
    json.dump(suggested_rhymes,f)

In [250]:
for rtype in suggested_rhymes:
    if suggested_rhymes[rtype].keys():
        print(rtype, suggwords[rtype])


night_rhymes Counter({'light': 25, 'night': 16, 'sight': 4, 'bright': 4, 'might': 1})
sleep_rhymes Counter({'keep': 29, 'deep': 5, 'sleep': 2})
call_rhymes Counter({'all': 8, 'call': 4, 'small': 2, 'tall': 1})


In [182]:
POSITIVE_PROMPTS=get_prompts(suggested_rhymes["sleep_rhymes"]["keep"])
NEGATIVE_PROMPTS=get_prompts(suggested_rhymes["sleep_rhymes"]["deep"])

keepdeepvector=get_steering_vector_fast(model, tokenizer, POSITIVE_PROMPTS, NEGATIVE_PROMPTS, layer=20)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Calculating activations for POSITIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Calculating activations for NEGATIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Steering vector computed successfully. Shape: torch.Size([3584])


In [183]:
torch.norm(keepdeepvector, p=2)

tensor(34.7500, dtype=torch.bfloat16)

In [215]:
example=random.choice(line_catalog["sleep_rhymes"])
prompt=get_prompts([example])

In [216]:
example

'Flowers up the garden wall gracefully creep'

In [239]:
steered_text = generate_steered_output(
                keepdeepvector, model, tokenizer, prompt, 
                10000, layer=20, steering_multiplier=2
            )

Processing batch:   0%|          | 0/10 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier 2
Steering hook removed from model.layers.20


Processing batch:  10%|█         | 1/10 [00:11<01:39, 11.00s/it]

Steering hook applied to model.layers.20 with multiplier 2
Steering hook removed from model.layers.20


Processing batch:  20%|██        | 2/10 [00:21<01:27, 11.00s/it]

Steering hook applied to model.layers.20 with multiplier 2
Steering hook removed from model.layers.20


Processing batch:  30%|███       | 3/10 [00:33<01:17, 11.00s/it]

Steering hook applied to model.layers.20 with multiplier 2
Steering hook removed from model.layers.20


Processing batch:  40%|████      | 4/10 [00:44<01:06, 11.02s/it]

Steering hook applied to model.layers.20 with multiplier 2
Steering hook removed from model.layers.20


Processing batch:  50%|█████     | 5/10 [00:55<00:55, 11.02s/it]

Steering hook applied to model.layers.20 with multiplier 2
Steering hook removed from model.layers.20


Processing batch:  60%|██████    | 6/10 [01:06<00:44, 11.06s/it]

Steering hook applied to model.layers.20 with multiplier 2
Steering hook removed from model.layers.20


Processing batch:  70%|███████   | 7/10 [01:17<00:33, 11.07s/it]

Steering hook applied to model.layers.20 with multiplier 2
Steering hook removed from model.layers.20


Processing batch:  80%|████████  | 8/10 [01:28<00:22, 11.08s/it]

Steering hook applied to model.layers.20 with multiplier 2
Steering hook removed from model.layers.20


Processing batch:  90%|█████████ | 9/10 [01:39<00:11, 11.11s/it]

Steering hook applied to model.layers.20 with multiplier 2
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 10/10 [01:50<00:00, 11.07s/it]


In [240]:
baseline_text = generate_steered_output(
                keepdeepvector, model, tokenizer, prompt, 
                10000, layer=20, steering_multiplier=0
            )

Processing batch:   0%|          | 0/10 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch:  10%|█         | 1/10 [00:11<01:40, 11.16s/it]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch:  20%|██        | 2/10 [00:22<01:28, 11.10s/it]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch:  30%|███       | 3/10 [00:33<01:17, 11.10s/it]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch:  40%|████      | 4/10 [00:44<01:06, 11.08s/it]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch:  50%|█████     | 5/10 [00:55<00:55, 11.07s/it]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch:  60%|██████    | 6/10 [01:06<00:44, 11.07s/it]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch:  70%|███████   | 7/10 [01:17<00:33, 11.08s/it]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch:  80%|████████  | 8/10 [01:28<00:22, 11.08s/it]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch:  90%|█████████ | 9/10 [01:39<00:11, 11.09s/it]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 10/10 [01:50<00:00, 11.08s/it]


In [241]:
negsteered_text = generate_steered_output(
                keepdeepvector, model, tokenizer, prompt, 
                10000, layer=20, steering_multiplier=-2
            )

Processing batch:   0%|          | 0/10 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier -2
Steering hook removed from model.layers.20


Processing batch:  10%|█         | 1/10 [00:11<01:39, 11.05s/it]

Steering hook applied to model.layers.20 with multiplier -2
Steering hook removed from model.layers.20


Processing batch:  20%|██        | 2/10 [00:22<01:28, 11.08s/it]

Steering hook applied to model.layers.20 with multiplier -2
Steering hook removed from model.layers.20


Processing batch:  30%|███       | 3/10 [00:33<01:17, 11.10s/it]

Steering hook applied to model.layers.20 with multiplier -2
Steering hook removed from model.layers.20


Processing batch:  40%|████      | 4/10 [00:44<01:06, 11.09s/it]

Steering hook applied to model.layers.20 with multiplier -2
Steering hook removed from model.layers.20


Processing batch:  50%|█████     | 5/10 [00:55<00:55, 11.09s/it]

Steering hook applied to model.layers.20 with multiplier -2
Steering hook removed from model.layers.20


Processing batch:  60%|██████    | 6/10 [01:06<00:44, 11.08s/it]

Steering hook applied to model.layers.20 with multiplier -2
Steering hook removed from model.layers.20


Processing batch:  70%|███████   | 7/10 [01:17<00:33, 11.07s/it]

Steering hook applied to model.layers.20 with multiplier -2
Steering hook removed from model.layers.20


Processing batch:  80%|████████  | 8/10 [01:28<00:22, 11.09s/it]

Steering hook applied to model.layers.20 with multiplier -2
Steering hook removed from model.layers.20


Processing batch:  90%|█████████ | 9/10 [01:39<00:11, 11.09s/it]

Steering hook applied to model.layers.20 with multiplier -2
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 10/10 [01:50<00:00, 11.09s/it]


In [242]:
freqs=compl["sleep_rhymes"][example]
print(f'earlier baseline frequency of \"deep\":{freqs["deep"]}, \"keep\":{freqs["keep"]}')

frequency of "deep":8, "keep":45


In [243]:
freqs=get_last_word_distribution(baseline_text)[example]
print(f'Baseline frequency of \"deep\":{freqs["deep"]}, \"keep\":{freqs["keep"]}')

Baseline frequency of "deep":1682, "keep":3853


In [244]:
freqs=get_last_word_distribution(steered_text)[example]
print(f'Steered towards \"keep\" frequency of \"deep\":{freqs["deep"]}, \"keep\":{freqs["keep"]}')

Steered towards "keep" frequency of "deep":1704, "keep":4177


In [245]:
freqs=get_last_word_distribution(negsteered_text)[example]
print(f'Steered towards \"deep\" frequency of \"deep\":{freqs["deep"]}, \"keep\":{freqs["keep"]}')

Steered towards "deep" frequency of "deep":1943, "keep":1172


In [271]:
ALL_POSITIONS=True
for rtype in suggested_rhymes:
    if suggested_rhymes[rtype].keys():
        word1,word2=[i[0] for i in suggwords[rtype].most_common(2)]
        print('='*30)
        print(word1,word2)
        print('='*30)
        POSITIVE_PROMPTS=get_prompts(suggested_rhymes[rtype][word1])
        NEGATIVE_PROMPTS=get_prompts(suggested_rhymes[rtype][word2])

        steering_vector=get_steering_vector_fast(model, tokenizer, POSITIVE_PROMPTS, NEGATIVE_PROMPTS, layer=20)
        n=torch.norm(steering_vector, p=2)
        print("L2 norm of steering vector:",n)
        STEERING_MULTIPLIER=75/n
        STEERING_MULTIPLIER=1
        example=random.choice(line_catalog[rtype])
        while example.endswith(word1) or example.endswith(word2) or example in suggested_rhymes[rtype][word1] or example in suggested_rhymes[rtype][word2]:
            example=random.choice(line_catalog[rtype])
        prompt=get_prompts([example])
        steered_text = generate_steered_output(
                keepdeepvector, model, tokenizer, prompt, 
                1000, layer=20, steering_multiplier=STEERING_MULTIPLIER
            )
        negsteered_text = generate_steered_output(
                keepdeepvector, model, tokenizer, prompt, 
                1000, layer=20, steering_multiplier=-STEERING_MULTIPLIER
            )
        baseline_text = generate_steered_output(
                keepdeepvector, model, tokenizer, prompt, 
                1000, layer=20, steering_multiplier=0
            )
        freqse=compl[rtype][example]
        freqsb=get_last_word_distribution(baseline_text)[example]
        freqss=get_last_word_distribution(steered_text)[example]
        freqsn=get_last_word_distribution(negsteered_text)[example]
        print('='*30)
        print(example)
        print('='*30)
        print(f'earlier baseline frequency of \"{word1}\":{freqse[word1]}, \"{word2}\":{freqse[word2]}')
        print(f'Baseline frequency of \"{word1}\":{freqsb[word1]}, \"{word2}\":{freqsb[word2]}')
        print(f'Steered towards \"{word1}\" frequency of \"{word1}\":{freqss[word1]}, \"{word2}\":{freqss[word2]}')
        print(f'Steered towards \"{word2}\" frequency of \"{word1}\":{freqsn[word1]}, \"{word2}\":{freqsn[word2]}')

light night
Calculating activations for POSITIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Calculating activations for NEGATIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Steering vector computed successfully. Shape: torch.Size([3584])
L2 norm of steering vector: tensor(19.6250, dtype=torch.bfloat16)


Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier 1
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.72s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier -1
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.81s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.86s/it]


Children's laughter fills the room, joyously bright
earlier baseline frequency of "light":39, "night":13
Baseline frequency of "light":393, "night":144
Steered towards "light" frequency of "light":389, "night":211
Steered towards "night" frequency of "light":348, "night":104
keep deep
Calculating activations for POSITIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Calculating activations for NEGATIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Steering vector computed successfully. Shape: torch.Size([3584])
L2 norm of steering vector: tensor(34.7500, dtype=torch.bfloat16)


Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier 1
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.43s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier -1
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.47s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.55s/it]


Courage builds bridges where fears dare to leap
earlier baseline frequency of "keep":19, "deep":31
Baseline frequency of "keep":218, "deep":335
Steered towards "keep" frequency of "keep":266, "deep":261
Steered towards "deep" frequency of "keep":189, "deep":362
all call
Calculating activations for POSITIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Calculating activations for NEGATIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Steering vector computed successfully. Shape: torch.Size([3584])
L2 norm of steering vector: tensor(35.2500, dtype=torch.bfloat16)


Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier 1
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.54s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier -1
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.57s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.62s/it]

Courage grows slowly when we learn to crawl
earlier baseline frequency of "all":15, "call":7
Baseline frequency of "all":122, "call":70
Steered towards "all" frequency of "all":92, "call":94
Steered towards "call" frequency of "all":154, "call":38





In [274]:
LAYER=27
ALL_POSITIONS=True

for rtype in suggested_rhymes:
    if suggested_rhymes[rtype].keys():
        word1,word2=[i[0] for i in suggwords[rtype].most_common(2)]
        print('='*30)
        print(word1,word2)
        print('='*30)
        POSITIVE_PROMPTS=get_prompts(suggested_rhymes[rtype][word1])
        NEGATIVE_PROMPTS=get_prompts(suggested_rhymes[rtype][word2])

        steering_vector=get_steering_vector_fast(model, tokenizer, POSITIVE_PROMPTS, NEGATIVE_PROMPTS, layer=LAYER)
        n=torch.norm(steering_vector, p=2)
        print("L2 norm of steering vector:",n)
        STEERING_MULTIPLIER=80/n
        STEERING_MULTIPLIER=1
        example=random.choice(line_catalog[rtype])
        while example.endswith(word1) or example.endswith(word2) or example in suggested_rhymes[rtype][word1] or example in suggested_rhymes[rtype][word2]:
            example=random.choice(line_catalog[rtype])
        prompt=get_prompts([example])
        steered_text = generate_steered_output(
                keepdeepvector, model, tokenizer, prompt, 
                1000, layer=LAYER, steering_multiplier=STEERING_MULTIPLIER
            )
        negsteered_text = generate_steered_output(
                keepdeepvector, model, tokenizer, prompt, 
                1000, layer=LAYER, steering_multiplier=-STEERING_MULTIPLIER
            )
        baseline_text = generate_steered_output(
                keepdeepvector, model, tokenizer, prompt, 
                1000, layer=LAYER, steering_multiplier=0
            )
        freqse=compl[rtype][example]
        freqsb=get_last_word_distribution(baseline_text)[example]
        freqss=get_last_word_distribution(steered_text)[example]
        freqsn=get_last_word_distribution(negsteered_text)[example]
        print('='*30)
        print(example)
        print('='*30)
        print(f'earlier baseline frequency of \"{word1}\":{freqse[word1]}, \"{word2}\":{freqse[word2]}')
        print(f'Baseline frequency of \"{word1}\":{freqsb[word1]}, \"{word2}\":{freqsb[word2]}')
        print(f'Steered towards \"{word1}\" frequency of \"{word1}\":{freqss[word1]}, \"{word2}\":{freqss[word2]}')
        print(f'Steered towards \"{word2}\" frequency of \"{word1}\":{freqsn[word1]}, \"{word2}\":{freqsn[word2]}')

light night
Calculating activations for POSITIVE prompts...
Calculated average activation for layer 'model.layers.27' with shape: torch.Size([3584])

Calculating activations for NEGATIVE prompts...
Calculated average activation for layer 'model.layers.27' with shape: torch.Size([3584])

Steering vector computed successfully. Shape: torch.Size([3584])
L2 norm of steering vector: tensor(47.5000, dtype=torch.bfloat16)


Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier 1
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.64s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier -1
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.79s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier 0
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.70s/it]


Love's tender touch reveals its transformative might
earlier baseline frequency of "light":41, "night":30
Baseline frequency of "light":454, "night":295
Steered towards "light" frequency of "light":428, "night":318
Steered towards "night" frequency of "light":453, "night":236
keep deep
Calculating activations for POSITIVE prompts...
Calculated average activation for layer 'model.layers.27' with shape: torch.Size([3584])

Calculating activations for NEGATIVE prompts...
Calculated average activation for layer 'model.layers.27' with shape: torch.Size([3584])

Steering vector computed successfully. Shape: torch.Size([3584])
L2 norm of steering vector: tensor(73.5000, dtype=torch.bfloat16)


Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier 1
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.69s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier -1
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.71s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier 0
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.77s/it]


Ivy tendrils around the ancient stone wall creep
earlier baseline frequency of "keep":47, "deep":11
Baseline frequency of "keep":430, "deep":123
Steered towards "keep" frequency of "keep":412, "deep":115
Steered towards "deep" frequency of "keep":503, "deep":97
all call
Calculating activations for POSITIVE prompts...
Calculated average activation for layer 'model.layers.27' with shape: torch.Size([3584])

Calculating activations for NEGATIVE prompts...
Calculated average activation for layer 'model.layers.27' with shape: torch.Size([3584])

Steering vector computed successfully. Shape: torch.Size([3584])
L2 norm of steering vector: tensor(82., dtype=torch.bfloat16)


Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier 1
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.44s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier -1
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.44s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier 0
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.48s/it]

In shadows deep, spiders silently crawl
earlier baseline frequency of "all":27, "call":6
Baseline frequency of "all":274, "call":48
Steered towards "all" frequency of "all":322, "call":58
Steered towards "call" frequency of "all":257, "call":55





In [276]:
LAYER=20
ALL_POSITIONS=False

for rtype in suggested_rhymes:
    if suggested_rhymes[rtype].keys():
        word1,word2=[i[0] for i in suggwords[rtype].most_common(2)]
        print('='*30)
        print(word1,word2)
        print('='*30)
        POSITIVE_PROMPTS=get_prompts(suggested_rhymes[rtype][word1])
        NEGATIVE_PROMPTS=get_prompts(suggested_rhymes[rtype][word2])

        steering_vector=get_steering_vector_fast(model, tokenizer, POSITIVE_PROMPTS, NEGATIVE_PROMPTS, layer=LAYER)
        n=torch.norm(steering_vector, p=2)
        print("L2 norm of steering vector:",n)
        STEERING_MULTIPLIER=100/n
        #STEERING_MULTIPLIER=1
        example=random.choice(line_catalog[rtype])
        while example.endswith(word1) or example.endswith(word2) or example in suggested_rhymes[rtype][word1] or example in suggested_rhymes[rtype][word2]:
            example=random.choice(line_catalog[rtype])
        prompt=get_prompts([example])
        steered_text = generate_steered_output(
                keepdeepvector, model, tokenizer, prompt, 
                1000, layer=LAYER, steering_multiplier=STEERING_MULTIPLIER
            )
        negsteered_text = generate_steered_output(
                keepdeepvector, model, tokenizer, prompt, 
                1000, layer=LAYER, steering_multiplier=-STEERING_MULTIPLIER
            )
        baseline_text = generate_steered_output(
                keepdeepvector, model, tokenizer, prompt, 
                1000, layer=LAYER, steering_multiplier=0
            )
        freqse=compl[rtype][example]
        freqsb=get_last_word_distribution(baseline_text)[example]
        freqss=get_last_word_distribution(steered_text)[example]
        freqsn=get_last_word_distribution(negsteered_text)[example]
        print('='*30)
        print(example)
        print('='*30)
        print(f'earlier baseline frequency of \"{word1}\":{freqse[word1]}, \"{word2}\":{freqse[word2]}')
        print(f'Baseline frequency of \"{word1}\":{freqsb[word1]}, \"{word2}\":{freqsb[word2]}')
        print(f'Steered towards \"{word1}\" frequency of \"{word1}\":{freqss[word1]}, \"{word2}\":{freqss[word2]}')
        print(f'Steered towards \"{word2}\" frequency of \"{word1}\":{freqsn[word1]}, \"{word2}\":{freqsn[word2]}')

light night
Calculating activations for POSITIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Calculating activations for NEGATIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Steering vector computed successfully. Shape: torch.Size([3584])
L2 norm of steering vector: tensor(19.6250, dtype=torch.bfloat16)


Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier 5.09375
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.58s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier -5.09375
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.62s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.67s/it]


The sunset paints the sky with colors so bright
earlier baseline frequency of "light":16, "night":11
Baseline frequency of "light":131, "night":148
Steered towards "light" frequency of "light":146, "night":297
Steered towards "night" frequency of "light":227, "night":183
keep deep
Calculating activations for POSITIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Calculating activations for NEGATIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Steering vector computed successfully. Shape: torch.Size([3584])
L2 norm of steering vector: tensor(34.7500, dtype=torch.bfloat16)


Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier 2.875
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.72s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier -2.875
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.82s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.84s/it]


Doubts into my once certain mind subtly creep
earlier baseline frequency of "keep":30, "deep":16
Baseline frequency of "keep":382, "deep":132
Steered towards "keep" frequency of "keep":402, "deep":108
Steered towards "deep" frequency of "keep":207, "deep":191
all call
Calculating activations for POSITIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Calculating activations for NEGATIVE prompts...
Calculated average activation for layer 'model.layers.20' with shape: torch.Size([3584])

Steering vector computed successfully. Shape: torch.Size([3584])
L2 norm of steering vector: tensor(35.2500, dtype=torch.bfloat16)


Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier 2.828125
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.67s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier -2.828125
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.68s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.20 with multiplier 0
Steering hook removed from model.layers.20


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.72s/it]

Beneath starlit skies, thoughts begin to crawl
earlier baseline frequency of "all":13, "call":24
Baseline frequency of "all":146, "call":238
Steered towards "all" frequency of "all":187, "call":321
Steered towards "call" frequency of "all":155, "call":66





In [277]:
LAYER=27
ALL_POSITIONS=False

for rtype in suggested_rhymes:
    if suggested_rhymes[rtype].keys():
        word1,word2=[i[0] for i in suggwords[rtype].most_common(2)]
        print('='*30)
        print(word1,word2)
        print('='*30)
        POSITIVE_PROMPTS=get_prompts(suggested_rhymes[rtype][word1])
        NEGATIVE_PROMPTS=get_prompts(suggested_rhymes[rtype][word2])

        steering_vector=get_steering_vector_fast(model, tokenizer, POSITIVE_PROMPTS, NEGATIVE_PROMPTS, layer=LAYER)
        n=torch.norm(steering_vector, p=2)
        print("L2 norm of steering vector:",n)
        STEERING_MULTIPLIER=100/n
        #STEERING_MULTIPLIER=1
        example=random.choice(line_catalog[rtype])
        while example.endswith(word1) or example.endswith(word2) or example in suggested_rhymes[rtype][word1] or example in suggested_rhymes[rtype][word2]:
            example=random.choice(line_catalog[rtype])
        prompt=get_prompts([example])
        steered_text = generate_steered_output(
                keepdeepvector, model, tokenizer, prompt, 
                1000, layer=LAYER, steering_multiplier=STEERING_MULTIPLIER
            )
        negsteered_text = generate_steered_output(
                keepdeepvector, model, tokenizer, prompt, 
                1000, layer=LAYER, steering_multiplier=-STEERING_MULTIPLIER
            )
        baseline_text = generate_steered_output(
                keepdeepvector, model, tokenizer, prompt, 
                1000, layer=LAYER, steering_multiplier=0
            )
        freqse=compl[rtype][example]
        freqsb=get_last_word_distribution(baseline_text)[example]
        freqss=get_last_word_distribution(steered_text)[example]
        freqsn=get_last_word_distribution(negsteered_text)[example]
        print('='*30)
        print(example)
        print('='*30)
        print(f'earlier baseline frequency of \"{word1}\":{freqse[word1]}, \"{word2}\":{freqse[word2]}')
        print(f'Baseline frequency of \"{word1}\":{freqsb[word1]}, \"{word2}\":{freqsb[word2]}')
        print(f'Steered towards \"{word1}\" frequency of \"{word1}\":{freqss[word1]}, \"{word2}\":{freqss[word2]}')
        print(f'Steered towards \"{word2}\" frequency of \"{word1}\":{freqsn[word1]}, \"{word2}\":{freqsn[word2]}')

light night
Calculating activations for POSITIVE prompts...
Calculated average activation for layer 'model.layers.27' with shape: torch.Size([3584])

Calculating activations for NEGATIVE prompts...
Calculated average activation for layer 'model.layers.27' with shape: torch.Size([3584])

Steering vector computed successfully. Shape: torch.Size([3584])
L2 norm of steering vector: tensor(47.5000, dtype=torch.bfloat16)


Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier 2.09375
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.51s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier -2.09375
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.59s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier 0
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.63s/it]


Ocean waves crash with tremendous might and sight
earlier baseline frequency of "light":38, "night":17
Baseline frequency of "light":342, "night":196
Steered towards "light" frequency of "light":416, "night":253
Steered towards "night" frequency of "light":225, "night":234
keep deep
Calculating activations for POSITIVE prompts...
Calculated average activation for layer 'model.layers.27' with shape: torch.Size([3584])

Calculating activations for NEGATIVE prompts...
Calculated average activation for layer 'model.layers.27' with shape: torch.Size([3584])

Steering vector computed successfully. Shape: torch.Size([3584])
L2 norm of steering vector: tensor(73.5000, dtype=torch.bfloat16)


Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier 1.359375
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.58s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier -1.359375
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.59s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier 0
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.61s/it]


In moments of truth, souls instinctively leap
earlier baseline frequency of "keep":34, "deep":16
Baseline frequency of "keep":374, "deep":152
Steered towards "keep" frequency of "keep":282, "deep":149
Steered towards "deep" frequency of "keep":445, "deep":161
all call
Calculating activations for POSITIVE prompts...
Calculated average activation for layer 'model.layers.27' with shape: torch.Size([3584])

Calculating activations for NEGATIVE prompts...
Calculated average activation for layer 'model.layers.27' with shape: torch.Size([3584])

Steering vector computed successfully. Shape: torch.Size([3584])
L2 norm of steering vector: tensor(82., dtype=torch.bfloat16)


Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier 1.21875
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.60s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier -1.21875
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.59s/it]
Processing batch:   0%|          | 0/1 [00:00<?, ?it/s]

Steering hook applied to model.layers.27 with multiplier 0
Steering hook removed from model.layers.27


Processing batch: 100%|██████████| 1/1 [00:11<00:00, 11.61s/it]

Before we run, first we must crawl
earlier baseline frequency of "all":9, "call":3
Baseline frequency of "all":109, "call":30
Steered towards "all" frequency of "all":66, "call":24
Steered towards "call" frequency of "all":99, "call":22





In [278]:
with open("suggwords.json",'w') as f:
    json.dump(suggwords,f)

In [None]:
rtype="sleep"

POSITIVE_PROMPTS=get_prompts(suggested_rhymes[rtype]["keep"])
NEGATIVE_PROMPTS=get_prompts(suggested_rhymes[rtype][word2])