In [6]:
from transformers import LlamaTokenizer, LlamaModel
import torch
from tqdm import tqdm
import numpy as np

In [2]:
# Load the tokenizer and the bare LLaMA transformer (LlamaModel, i.e. without
# the language-model head) from the local checkpoint directory.
tokenizer = LlamaTokenizer.from_pretrained(
    pretrained_model_name_or_path='../llama-2-7b',
)
model = LlamaModel.from_pretrained(
    pretrained_model_name_or_path='../llama-2-7b',
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of the model checkpoint at ../llama-2-7b were not used when initializing LlamaModel: ['lm_head.weight']
- This IS expected if you are initializing LlamaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LlamaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LlamaModel were not initialized from the model checkpoint at ../llama-2-7b and are newly initialized: ['model.layers.26.self_attn.rotary_emb.inv_freq', 'model.layers.29.self_attn.rotary_emb.inv_freq', 'model.layers.12.self_attn.rotary_emb.inv_freq', 'model.layers.18.self_attn.rotary_emb.inv_freq', 'model.layers.24.self_attn.rotary_emb.inv_freq', 'model.layers.7.self_attn.rotary_emb.inv_freq', 'model.layers.15

In [3]:
def get_llama_embedding(sentence):
    """Return a sentence embedding by mean-pooling LLaMA's last hidden layer.

    Uses the notebook-level ``tokenizer`` and ``model`` objects.

    Args:
        sentence: Text to embed; the notebook passes one prompt string per call.

    Returns:
        numpy.ndarray: float32 array holding the mean of
        ``outputs.last_hidden_state`` over the token axis (dim 1), with a
        size-1 batch axis removed.
    """
    encoded = tokenizer(sentence, return_tensors="pt")
    # Move the inputs to whatever device the model lives on, so the function
    # keeps working when the model is placed on a GPU.
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        outputs = model(**encoded)
    # Extract the hidden states (last layer)
    last_hidden_state = outputs.last_hidden_state
    # Pool in float32: NumPy has no bfloat16 dtype, and averaging in half
    # precision loses accuracy. For a float32 model this cast is a no-op.
    pooled = last_hidden_state.float().mean(dim=1)
    # Drop only the batch axis, and go through .cpu() because .numpy() rejects
    # tensors that live on an accelerator.
    sentence_embedding = pooled.squeeze(0).cpu().numpy()

    return sentence_embedding


def assign_prompt_embedding(graph, node_type):
    """Embed every prompt of one node type and attach the result to the graph.

    Each entry of ``graph[node_type]['prompt']`` is passed through
    ``get_llama_embedding`` (with a tqdm progress bar); the resulting vectors
    are stacked and stored as a float32 tensor on
    ``graph[node_type].prompt_embedding``. The assignment only happens after
    all prompts have been embedded, so an interrupted run stores nothing.

    Args:
        graph: Container indexable by node type whose entries expose a
            'prompt' sequence and accept attribute assignment
            (presumably a PyG ``HeteroData`` -- TODO confirm).
        node_type: Key of the node type to process, e.g. 'user'.

    Returns:
        None. ``graph`` is modified in place.
    """
    vectors = [
        get_llama_embedding(text)
        for text in tqdm(graph[node_type]['prompt'])
    ]
    stacked = np.array(vectors)
    graph[node_type].prompt_embedding = torch.tensor(stacked, dtype=torch.float32)

In [10]:
# Load the processed heterogeneous graph that the cells below annotate.
# weights_only=False is passed explicitly: the file presumably holds a pickled
# graph object (with string prompts) rather than plain tensors, and
# PyTorch >= 2.6 defaults to weights_only=True, which refuses such files.
# SECURITY: this unpickles arbitrary Python objects -- only load files you
# created yourself or otherwise trust.
graph = torch.load(
    '../processed_data/heterogeneous_graph_768_no_med_with_prompt_10_imbalanced.pt',
    weights_only=False,
)

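#### Be careful when running the following cells: each one takes a long time to run (in the recorded run, `ingredient` took about 47 minutes and `food` was estimated at roughly 9 hours).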

In [None]:
# Embed the prompts of all 'user' nodes; the result is stored on
# graph['user'].prompt_embedding. (This cell has no execution count in the
# saved notebook.)
assign_prompt_embedding(graph=graph, node_type='user')

In [13]:
# Embed the prompts of all 'food' nodes.
# NOTE: the recorded run was stopped by a KeyboardInterrupt at 372/10503
# prompts; because the embedding is only assigned after the loop finishes,
# that run stored no 'food' embedding.
assign_prompt_embedding(graph=graph, node_type='food')

  4%|▎         | 372/10503 [19:50<9:00:24,  3.20s/it] 


KeyboardInterrupt: 

In [8]:
# Embed the prompts of all 'ingredient' nodes
# (recorded run: 3458 prompts in about 47 minutes).
assign_prompt_embedding(graph=graph, node_type='ingredient')

100%|██████████| 3458/3458 [46:52<00:00,  1.23it/s] 


In [7]:
# Embed the prompts of all 'category' nodes
# (recorded run: 174 prompts in about 2 minutes).
assign_prompt_embedding(graph=graph, node_type='category')

100%|██████████| 174/174 [01:54<00:00,  1.52it/s]


In [5]:
# Embed the prompts of all 'habit' nodes
# (recorded run: 64 prompts in under a minute).
assign_prompt_embedding(graph=graph, node_type='habit')

100%|██████████| 64/64 [00:42<00:00,  1.51it/s]
  graph[node_type].prompt_embedding = torch.tensor(prompted_embedding, dtype=torch.float32)


In [9]:
# Persist the graph together with the prompt embeddings attached above.
# NOTE: this writes to the same path the graph was loaded from, so the
# original file is replaced.
torch.save(
    obj=graph,
    f='../processed_data/heterogeneous_graph_768_no_med_with_prompt_10_imbalanced.pt',
)