In [None]:
# When using a colab notebook:
#!pip install aleph-alpha-client langchain python-dotenv

In [None]:
from aleph_alpha_client import Client, Prompt, CompletionRequest,ExplanationRequest, TextControl, EvaluationRequest, TargetGranularity
from scipy import spatial
import numpy as np
import os
from dotenv import load_dotenv

In [None]:
# Load AA_TOKEN (and any other settings) from a local .env file into the environment.
load_dotenv()

# Fail fast with a clear message; otherwise the client is constructed with token=None
# and the problem only surfaces later as a less obvious error on the first request.
aa_token = os.getenv("AA_TOKEN")
if not aa_token:
    raise RuntimeError(
        "AA_TOKEN is not set. Add it to a .env file or export it as an environment variable."
    )

client = Client(token=aa_token)

#### Let's learn about Attention Manipulation

In [None]:
text = "The quick brown fox jumps over the lazy dog.\nThe color of the fox is"
# Here we define a TextControl that rescales the attention the model pays to a character
# span of the prompt: a factor between 0 and 1 suppresses the span, 1.0 has no effect,
# and values above 1.0 amplify it.
# (documentation: https://aleph-alpha-client.readthedocs.io/en/latest/aleph_alpha_client.html#aleph_alpha_client.TextControl)
# Change the factor to 0.0 to see what happens.
suppressed_word = "brown"
control = TextControl(
    start=text.find(suppressed_word),  # character offset of the word inside `text`
    length=len(suppressed_word),
    factor=0.1,
)
prompt = Prompt.from_text(text, controls=[control])

request = CompletionRequest(prompt=prompt, maximum_tokens=10, stop_sequences=["."])
result = client.complete(request=request, model="luminous-extended")
print(result.completions[0].completion)

We see that changing the attention changes the model's output. Let's see how this can help us in different scenarios.

In [None]:
# Let's try that again, but this time we want traceable explanations: suppress each word
# of the prompt in turn and measure how much that changes the likelihood of " brown".
text = "The quick brown fox jumps over the lazy dog.\nThe color of the fox is"

controls = []
# Create one control per whitespace-separated word in the prompt.
# A factor of 0.1 strongly suppresses the attention on that word (1.0 would leave it unchanged).
search_from = 0
for word in text.split():
    # Search from the end of the previous word so repeated words ("The", "fox", "the")
    # get the offset of *this* occurrence instead of always the first one.
    start = text.find(word, search_from)
    search_from = start + len(word)
    control = TextControl(start=start, length=len(word), factor=0.1)
    controls.append(control)

# Log-perplexity of the expected completion " brown" with each single word suppressed;
# a higher value means " brown" became less likely once that word was suppressed.
eval_scores = []
for control in controls:
    prompt = Prompt.from_text(text, controls=[control])

    request = EvaluationRequest(prompt=prompt, completion_expected=" brown")
    score = client.evaluate(request=request, model="luminous-extended")
    log_perplexity = score.result['log_perplexity']
    eval_scores.append(log_perplexity)
    print(f"The control of '{text[control.start:control.start+control.length]}' is: {log_perplexity}")
    


### Let's use an explanation request to find out what the model is looking at

In [None]:
# Ask the explain endpoint which parts of the prompt drive the target " brown".
# Instead of building one control per word by hand (previous cell), the endpoint applies
# `control_factor` to each prompt segment itself and scores how much the target depends on it.
# https://docs.aleph-alpha.com/docs/tasks/explain/
exp_req = ExplanationRequest(
    prompt=Prompt.from_text(text),
    target=" brown",
    control_factor=0.1,           # same suppression factor as the manual controls above
    prompt_granularity="word",    # one score per word, matching the manual loop
    target_granularity=TargetGranularity.Complete,
)
response_explain = client.explain(exp_req, model="luminous-extended")

# With TargetGranularity.Complete there is one explanation for the whole target;
# its first item is the text prompt, whose scores carry (start, length, score) spans.
explanations = response_explain.explanations[0].items[0].scores

for item in explanations:
    start = item.start
    end = item.start + item.length
    print(f"""EXPLAINED TEXT: {text[start:end]}
Score: {np.round(item.score, decimals=3)}""")

In [None]:
# Read the data in the data.md file. An explicit encoding keeps the read
# platform-independent (the default text encoding differs between operating systems).
with open("data.md", "r", encoding="utf-8") as f:
    data = f.read()

# Split the data into a list of texts, one per "#####"-separated section.
texts = data.split("#####")
# This cell reads texts[10]; fail with a clear message instead of a bare IndexError
# if data.md has fewer sections than expected.
assert len(texts) > 10, f"expected at least 11 '#####'-separated sections in data.md, got {len(texts)}"

print(f"data: {data[:100]}")
print(f"texts: {texts[10][:100]}")

In [None]:
# Instruction-style prompt built around one section of data.md.
answers_prompt = f"""### Instructions: Solve the task based on the text below.

### Input:
{texts[4]}

### Task: Give me a list of countries that include social elements.

### Response:"""

response = client.complete(
    CompletionRequest(
        prompt=Prompt.from_text(answers_prompt),
        maximum_tokens=100,
        stop_sequences=["###"],  # stop before the model starts a new "###" section
    ),
    model="luminous-base-control-beta",
)
answer = response.completions[0].completion

# Explain which paragraph of the prompt the generated answer relies on.
# NOTE(review): the answer comes from luminous-base-control-beta but is explained with
# luminous-extended, so the scores reflect luminous-extended's view of this text — confirm intended.
exp_req = ExplanationRequest(
    Prompt.from_text(answers_prompt),
    answer,
    control_factor=0.1,
    prompt_granularity="paragraph",
    target_granularity=TargetGranularity.Complete,
)
response_explain = client.explain(exp_req, model="luminous-extended")

# Single explanation for the whole target -> first prompt item (the text) -> its scored spans.
explanations = response_explain.explanations[0].items[0].scores

for item in explanations:
    start = item.start
    end = item.start + item.length
    print(f"""EXPLAINED TEXT: {answers_prompt[start:end]}
Score: {np.round(item.score, decimals=3)}""")

In [None]:
# Explain a hand-written claim instead of the model's own answer, to see which paragraph
# of the prompt (if any) the model attributes it to.
exp_req = ExplanationRequest(
    Prompt.from_text(answers_prompt),
    " The USA have social elements.",
    control_factor=0.1,
    prompt_granularity="paragraph",
    target_granularity=TargetGranularity.Complete,
)
response_explain = client.explain(exp_req, model="luminous-extended")

# Single explanation for the whole target -> first prompt item (the text) -> its scored spans.
explanations = response_explain.explanations[0].items[0].scores

for item in explanations:
    start = item.start
    end = item.start + item.length
    print(f"""EXPLAINED TEXT: {answers_prompt[start:end]}
Score: {np.round(item.score, decimals=3)}""")