# Using attention to produce relevant subsets

This is a version that uses `bert-base-uncased` to build embeddings.

In [31]:
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

def get_embedding(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze() 

object_names = ["car", "train", "airplane", "bus"]
object_descriptions = [
    "A vehicle that runs on roads",
    "A vehicle that runs on tracks",
    "A vehicle that flies in the sky",
    "A large road vehicle that carries many passengers"
]

query_text = "tracks?"

name_embeddings = torch.stack([get_embedding(name) for name in object_names])
description_embeddings = torch.stack([get_embedding(desc) for desc in object_descriptions])
query_embedding = get_embedding(query_text).unsqueeze(0) 

values = torch.cat((name_embeddings, description_embeddings), dim=1)

scores = query_embedding @ name_embeddings.T
attention_weights = F.softmax(scores, dim=-1)

output = attention_weights @ values

sorted_indices = torch.argsort(attention_weights[0], descending=True)

print("Sorted object indices based on attention:", sorted_indices.tolist())
print("Most relevant object to the query:", object_names[sorted_indices[0]])


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Sorted object indices based on attention: [3, 2, 0, 1]
Most relevant object to the query: bus


This is a version that uses `text-embedding-ada-002` to build embeddings.

In [6]:
import torch
import torch.nn.functional as F
import openai

from dotenv import load_dotenv

load_dotenv('.env')
openai.api_key = os.environ.get('OPENAI_API_KEY')

def get_embedding(text):
    return torch.as_tensor(openai.Embedding.create(input=text, engine="text-embedding-ada-002")['data'][0]['embedding'])

object_names = ["car", "train", "airplane", "bus"]
object_descriptions = [
    "A vehicle that runs on roads",
    "A vehicle that runs on tracks",
    "A vehicle that flies in the sky",
    "A large road vehicle that carries many passengers"
]

query_text = "tracks?"

name_embeddings = torch.stack([get_embedding(name) for name in object_names])
description_embeddings = torch.stack([get_embedding(desc) for desc in object_descriptions])
query_embedding = get_embedding(query_text).unsqueeze(0) 

values = torch.cat((name_embeddings, description_embeddings), dim=1)

scores = query_embedding @ name_embeddings.T
attention_weights = F.softmax(scores, dim=-1)

output = attention_weights @ values

sorted_indices = torch.argsort(attention_weights[0], descending=True)

sorted_objects = [object_names[i] for i in sorted_indices.tolist()]
sorted_descriptions = [object_descriptions[i] for i in sorted_indices.tolist()]

print("Query: ", query_text)
print("Most relevant object to the query:", object_names[sorted_indices[0]])
print("Description of the most relevant object:", sorted_descriptions[0])
print()
print("Sorted object indices based on attention:", sorted_indices.tolist())
print("Sorted objects based on attention:", sorted_objects)
print("Sorted descriptions based on attention:", sorted_descriptions)
print("Attention weights: ", attention_weights[0])

Query:  tracks?
Most relevant object to the query: train
Description of the most relevant object: A vehicle that runs on tracks

Sorted object indices based on attention: [1, 3, 0, 2]
Sorted objects based on attention: ['train', 'bus', 'car', 'airplane']
Sorted descriptions based on attention: ['A vehicle that runs on tracks', 'A large road vehicle that carries many passengers', 'A vehicle that runs on roads', 'A vehicle that flies in the sky']
Attention weights:  tensor([0.2495, 0.2566, 0.2442, 0.2497])


Longer example

In [7]:
col_names =[
    'Report Number',
    'Supplemental Number',
    'Accident Year',
    'Accident Date/Time',
    'Operator ID',
    'Operator Name',
    'Pipeline/Facility Name',
    'Pipeline Location',
    'Pipeline Type',
    'Liquid Type',
    'Liquid Subtype',
    'Liquid Name',
    'Accident City',
    'Accident County',
    'Accident State',
    'Accident Latitude',
    'Accident Longitude',
    'Cause Category',
    'Cause Subcategory',
    'Unintentional Release (Barrels)',
    'Intentional Release (Barrels)',
    'Liquid Recovery (Barrels)',
    'Net Loss (Barrels)',
    'Liquid Ignition',
    'Liquid Explosion',
    'Pipeline Shutdown',
    'Shutdown Date/Time',
    'Restart Date/Time',
    'Public Evacuations',
    'Operator Employee Injuries',
    'Operator Contractor Injuries',
    'Emergency Responder Injuries',
    'Other Injuries',
    'Public Injuries',
    'All Injuries',
    'Operator Employee Fatalities',
    'Operator Contractor Fatalities',
    'Emergency Responder Fatalities',
    'Other Fatalities',
    'Public Fatalities',
    'All Fatalities',
    'Property Damage Costs',
    'Lost Commodity Costs',
    'Public/Private Property Damage Costs',
    'Emergency Response Costs',
    'Environmental Remediation Costs',
    'Other Costs',
    'All Costs'
 ]

col_descriptions = [
    "Unique ID for the accident report.",
    "Additional reference numbers, if any.",
    "Year the accident occurred.",
    "Date and time of the accident.",
    "Unique identifier for the operating entity.",
    "Name of the entity operating the pipeline.",
    "Name of the affected pipeline or facility.",
    "Geographic location of the pipeline.",
    "Type of pipeline (e.g., gas, oil).",
    "Type of liquid involved in the accident.",
    "More specific classification of the liquid.",
    "Brand or specific name of the liquid.",
    "City where the accident occurred.",
    "County where the accident occurred.",
    "State where the accident occurred.",
    "Latitude coordinate of the accident.",
    "Longitude coordinate of the accident.",
    "Broad category of accident cause.",
    "More specific cause details.",
    "Volume of unintentional liquid release in barrels.",
    "Volume of intentional liquid release in barrels.",
    "Volume of liquid recovered in barrels.",
    "Net liquid lost, calculated as released minus recovered.",
    "Whether the liquid caught fire.",
    "Whether an explosion occurred.",
    "Whether the pipeline was shut down.",
    "Date and time of the pipeline shutdown.",
    "Date and time the pipeline was restarted.",
    "Number of public evacuations, if any.",
    "Number of operator employee injuries.",
    "Number of contractor injuries.",
    "Number of emergency responder injuries.",
    "Injuries to others not in above categories.",
    "Number of injuries to the public.",
    "Total number of injuries from the accident.",
    "Number of fatalities among operator employees.",
    "Number of fatalities among contractors.",
    "Number of fatalities among emergency responders.",
    "Number of other fatalities not in above categories.",
    "Number of fatalities among the public.",
    "Total number of fatalities from the accident.",
    "Costs incurred due to property damage.",
    "Costs of lost commodities, e.g., oil, gas.",
    "Costs for public or private property damage.",
    "Costs of emergency response efforts.",
    "Costs for environmental cleanup.",
    "Miscellaneous additional costs.",
    "Total costs incurred due to the accident."
]


In [8]:
import torch
import torch.nn.functional as F
import openai

def get_embedding(text):
    return torch.as_tensor(openai.Embedding.create(input=text, engine="text-embedding-ada-002")['data'][0]['embedding'])

object_names = col_names
object_descriptions = col_descriptions

query_text = "when and where did the accident happen?"

name_embeddings = torch.stack([get_embedding(name) for name in object_names])
description_embeddings = torch.stack([get_embedding(desc) for desc in object_descriptions])
query_embedding = get_embedding(query_text).unsqueeze(0) 

values = torch.cat((name_embeddings, description_embeddings), dim=1)

scores = query_embedding @ name_embeddings.T
attention_weights = F.softmax(scores, dim=-1)

output = attention_weights @ values

sorted_indices = torch.argsort(attention_weights[0], descending=True)

sorted_objects = [object_names[i] for i in sorted_indices.tolist()]
sorted_descriptions = [object_descriptions[i] for i in sorted_indices.tolist()]

print("Query: ", query_text)
print()
print("Sorted object indices based on attention:", sorted_indices.tolist())
print("Sorted objects based on attention:", sorted_objects)
print("Sorted descriptions based on attention:", sorted_descriptions)
print("Attention weights: ", attention_weights[0].sort(reverse=True))

Query:  when and where did the accident happen?

Sorted object indices based on attention: [3, 2, 14, 15, 12, 16, 13, 38, 29, 35, 36, 37, 30, 40, 31, 39, 34, 32, 33, 27, 26, 24, 19, 0, 41, 28, 7, 23, 20, 25, 44, 4, 6, 21, 47, 17, 5, 9, 43, 22, 42, 11, 46, 8, 18, 1, 45, 10]
Sorted objects based on attention: ['Accident Date/Time', 'Accident Year', 'Accident State', 'Accident Latitude', 'Accident City', 'Accident Longitude', 'Accident County', 'Other Fatalities', 'Operator Employee Injuries', 'Operator Employee Fatalities', 'Operator Contractor Fatalities', 'Emergency Responder Fatalities', 'Operator Contractor Injuries', 'All Fatalities', 'Emergency Responder Injuries', 'Public Fatalities', 'All Injuries', 'Other Injuries', 'Public Injuries', 'Restart Date/Time', 'Shutdown Date/Time', 'Liquid Explosion', 'Unintentional Release (Barrels)', 'Report Number', 'Property Damage Costs', 'Public Evacuations', 'Pipeline Location', 'Liquid Ignition', 'Intentional Release (Barrels)', 'Pipeline Shu

TypeError: sort() received an invalid combination of arguments - got (reverse=bool, ), but expected one of:
 * (*, bool stable, int dim, bool descending)
 * (int dim, bool descending)
 * (*, bool stable, name dim, bool descending)
 * (name dim, bool descending)


Asking GPT4

https://chat.openai.com/share/42e28e9f-20d9-4614-acfe-3de4097e8c9c