# # Text Embedding Vector

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from getpass import getpass
import gc

LLM_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_PATH = "target_embeddings_grid.pt"  # output file written by the embedding cell

# Best-effort login: a failure is reported but does not stop the notebook.
# NOTE(review): presumably previously cached Hugging Face credentials can
# still be used for the download below when login is skipped — confirm.
try:
    token = getpass("Hugging Face Access Token: ").strip()
    if token:
        login(token=token)
        print("Hugging Face login successful")
    else:
        # An empty token cannot authenticate; don't pass it to login().
        print("No token entered; skipping Hugging Face login")
except Exception as e:
    print(f"Hugging Face login failed: {e}")

print(f"Model to load: {LLM_MODEL_NAME}")
print(f"Device: {DEVICE}")

# Both names stay None unless BOTH loads succeed; the next cell checks them.
llm_model = None
llm_tokenizer = None

try:
    # Load the tokenizer first: it is far smaller than the model weights, so
    # a bad name / auth / network error surfaces before weights hit DEVICE.
    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)

    llm_model = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_NAME,
        torch_dtype=torch.bfloat16,
    ).to(DEVICE)
    llm_model.eval()
except Exception as e:
    print(f"Failed to load {LLM_MODEL_NAME}: {e}")
    # Discard any partially loaded state so nothing stays allocated on DEVICE
    # while the next cell reports that the model was not loaded.
    llm_model = None
    llm_tokenizer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
if llm_model is not None and llm_tokenizer is not None:
    text_normal = "a photo of a normal region on a grid"
    text_anomaly = "a photo of an anomaly region on a grid"

    # The tokenizer's default (True) adds the model's special tokens, which
    # are then included in the mean pooled below. Kept True so the saved
    # embeddings match earlier runs; set False to pool over prompt tokens only.
    INCLUDE_SPECIAL_TOKENS = True

    print(f"Target text (Normal): '{text_normal}'")
    print(f"Target text (Anomaly): '{text_anomaly}'")

    def _mean_token_embedding(text):
        """Return the mean input-embedding vector of *text*'s tokens.

        Only the model's input embedding layer is applied (no transformer
        layers run). The tokenizer output has a batch dim of 1, so the
        result after mean over tokens and squeeze is a 1-D tensor on DEVICE.
        """
        token_ids = llm_tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=INCLUDE_SPECIAL_TOKENS,
        ).input_ids.to(DEVICE)
        return llm_model.get_input_embeddings()(token_ids).mean(dim=1).squeeze(0)

    try:
        with torch.no_grad():
            embedding_normal = _mean_token_embedding(text_normal)
            embedding_anomaly = _mean_token_embedding(text_anomaly)

        target_embeddings = {
            "normal": embedding_normal.cpu(),
            "anomaly": embedding_anomaly.cpu(),
        }
        torch.save(target_embeddings, SAVE_PATH)
        print(f"Saved target embeddings to {SAVE_PATH}")
    except Exception as e:
        print(f"Failed to compute/save target embeddings: {e}")
    finally:
        # Rebind to None instead of `del`: deleting the names would make a
        # re-run of this cell raise NameError on the `if` above rather than
        # taking the "not loaded" branch.
        llm_model = None
        llm_tokenizer = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("\nLLM has been unloaded from memory")
else:
    print("\nModel was not loaded successfully")