<a href="https://colab.research.google.com/github/STKalinowski/T5LogitSteering/blob/main/T5LogitSteering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers
!pip install accelerate

Collecting transformers
  Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m58.0 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m32.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m74.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m56.8 MB/s[0m eta [36m0:00:0

In [2]:
import html
import math

import ipywidgets as widgets
import torch
from IPython.display import display, HTML
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Hugging Face Hub checkpoint used for both the model and its tokenizer.
modelName = 'google/flan-t5-large'

In [3]:
# Load the tokenizer and the seq2seq model from the same checkpoint,
# then move the model onto the selected device.
tokenizer = AutoTokenizer.from_pretrained(modelName)
model = AutoModelForSeq2SeqLM.from_pretrained(modelName)
model = model.to(device)

Downloading (…)lve/main/config.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/3.13G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

In [4]:
def sigmoid(x):
  """Logistic function 1 / (1 + e^-x), computed without overflow for any finite x.

  Args:
    x: a real number (int or float).

  Returns:
    A float in [0, 1].
  """
  if x >= 0:
    return 1 / (1 + math.exp(-x))
  # For negative x, e^-x can overflow (OverflowError); use the equivalent
  # form e^x / (1 + e^x), whose exponent is non-positive.
  z = math.exp(x)
  return z / (1 + z)

In [14]:
def _apply_repetition_penalty(logits, prev_tokens, penalty):
  """Penalize each distinct token of `prev_tokens` exactly once (CTRL-style), in place.

  Positive logits are divided by `penalty` and negative logits are multiplied
  by it, so a penalty > 1 always lowers a repeated token's score. (Dividing a
  negative logit would raise it instead.) Returns the same `logits` tensor.
  """
  if penalty == 1:
    return logits
  # torch.unique deduplicates by value. A Python set of 0-d tensors would not,
  # because tensors hash by identity.
  seen_tokens = torch.unique(prev_tokens)
  scores = logits[seen_tokens]
  logits[seen_tokens] = torch.where(scores < 0, scores * penalty, scores / penalty)
  return logits


def _top_p_filter(logits, top_p):
  """Nucleus filtering: set every logit outside the top-p probability mass to -inf, in place."""
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  sorted_to_remove = cumulative_probs > top_p
  # Shift right so the token that crosses the threshold is kept, and never drop the top token.
  sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
  sorted_to_remove[..., 0] = False
  logits[sorted_indices[sorted_to_remove]] = float('-inf')
  return logits


def influencialGeneration(input_prompt, influencerPrompt, alpha, max_length=50, topP=0.95, temperature=0.5, repetition_penalty=1):
  """Sample a continuation of `input_prompt`, steered toward `influencerPrompt`.

  Both prompts are encoded as one padded batch. At every step the decoder runs
  on the same generated prefix for both prompts. The next-token logits are
  combined as
      logits_input + sigmoid(cur_len / 10) * alpha * logits_influencer
  where cur_len is the current output length, including the start token. The
  influence therefore ramps up as the output grows. Temperature, a repetition
  penalty and top-p filtering are then applied before sampling one token.

  Args:
    input_prompt: text the model should respond to.
    influencerPrompt: text whose next-token logits are mixed in.
    alpha: weight of the influencer logits (0 disables steering).
    max_length: maximum number of tokens generated after the start token.
    topP: nucleus-sampling cumulative-probability threshold.
    temperature: logit divisor; must be > 0.
    repetition_penalty: values > 1 discourage tokens already in the output; 1 disables it.

  Returns:
    1-D LongTensor of token ids. It starts with the decoder start token and
    ends with EOS, unless `max_length` is reached first.

  Raises:
    ValueError: if `temperature` is not positive.
  """
  if temperature <= 0:
    raise ValueError('temperature must be > 0')

  # Inference only: skip autograd bookkeeping for the encoder and every decoder step.
  with torch.no_grad():
    # Encode both prompts as one padded batch: row 0 = input, row 1 = influencer.
    enc_inputs = tokenizer([input_prompt, influencerPrompt], padding=True, return_tensors='pt').to(device)
    attention_mask = enc_inputs['attention_mask']
    encOutput = model.encoder(input_ids=enc_inputs['input_ids'], attention_mask=attention_mask)

    # Start decoding from the model's decoder start token, falling back to the
    # pad token if the config does not define one.
    start_id = model.config.decoder_start_token_id
    if start_id is None:
      start_id = tokenizer.pad_token_id
    genOut = torch.tensor([start_id], dtype=torch.long, device=device)

    while genOut.shape[0] < max_length + 1 and int(genOut[-1]) != tokenizer.eos_token_id:
      # Pass the encoder mask as well. Otherwise cross-attention also attends
      # to the padding added to the shorter prompt.
      step_logits = model(
          encoder_outputs=encOutput,
          attention_mask=attention_mask,
          decoder_input_ids=genOut.repeat(2, 1),
      ).logits[:, -1]
      influence = sigmoid(genOut.shape[0] / 10) * alpha
      logits = (step_logits[0] + influence * step_logits[1]) / temperature

      logits = _apply_repetition_penalty(logits, genOut, repetition_penalty)
      logits = _top_p_filter(logits, topP)

      next_token = torch.multinomial(torch.nn.functional.softmax(logits, dim=-1), num_samples=1)
      genOut = torch.cat([genOut, next_token], dim=0)
  return genOut


In [15]:
# Interactive control for the steering strength `alpha`. The module-level
# `alpha` variable starts at the slider's initial value and is kept in sync
# with it by the observer below.
slider = widgets.FloatSlider(
    value=0.5, min=0, max=0.8, step=0.01,
    description='Alpha:', orientation='horizontal',
)
alpha = slider.value


def update_variable(change):
  """Slider observer: copy the slider's new value into the global `alpha`."""
  global alpha
  alpha = change['new']


slider.observe(update_variable, names='value')
display(slider)

FloatSlider(value=0.5, description='Alpha:', max=0.8, step=0.01)

In [22]:
# Sampling settings for this run; `alpha` is the value currently set on the slider.
max_length = 50
topP = 0.98
temperature = 0.4
repetition_penalty = 1.5

input_prompt = '''Write a fairy tale about a lost library in the desert. '''
influencer_prompt = '''Use poetic language'''

# Keyword arguments keep each setting tied to the right parameter.
res = influencialGeneration(
    input_prompt,
    influencer_prompt,
    alpha,
    max_length=max_length,
    topP=topP,
    temperature=temperature,
    repetition_penalty=repetition_penalty,
)
output_text = tokenizer.decode(res, skip_special_tokens=True)
# Escape the generated text so characters such as '<' or '&' are shown
# literally instead of being interpreted as HTML markup.
display(HTML(f'<p>{html.escape(output_text)}</p>'))