# Compute hidden representations for a single input sequence

In [None]:
import os
from pathlib import Path

import numpy as np
import torch 

from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# build one torch.cuda.device handle per CUDA device index that PyTorch reports
available_gpus = list(map(torch.cuda.device, range(torch.cuda.device_count())))
available_gpus

In [None]:
# run on the first CUDA device when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [None]:
# Model selection: `model_name_or_path` is the identifier handed to
# `AutoTokenizer.from_pretrained` / `AutoModelForCausalLM.from_pretrained` below,
# and its last path component is used as the output sub-directory name.
# Keep exactly one of the following assignments uncommented.
# model_name_or_path = "princeton-nlp/Sheared-LLaMA-1.3B"
# model_name_or_path = "meta-llama/Llama-2-7b-hf"
# model_name_or_path = "meta-llama/Llama-2-7b-chat-hf"
model_name_or_path = "meta-llama/Meta-Llama-3-8B"
# model_name_or_path = "meta-llama/Meta-Llama-3-8B-Instruct"
# model_name_or_path = "mistralai/Mistral-7B-v0.1"
# model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.1"
# model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.2"


In [None]:
# input sequence to encode; `text_id` labels the output directory for this text
text_id = 'text-1'
text = "Montreal is the second most populous city in Canada, the tenth most populous city in North America, and the most populous city in the province of Quebec. Founded in 1642 as Ville-Marie, or 'City of Mary', it is named after Mount Royal, the triple-peaked hill around which the early city of Ville-Marie was built. The city is centred on the Island of Montreal, which obtained its name from the same origin as the city, and a few much smaller peripheral islands, the largest of which is Île Bizard. The city is 196 km (122 mi) east of the national capital, Ottawa, and 258 km (160 mi) southwest of the provincial capital, Quebec City."

In [None]:
# Run the selected model over `text` once with causal and once with bidirectional
# attention, and write the hidden states of every layer to disk as .npy files.

# The tokenizer and the encoded input do not depend on the attention type,
# so they are prepared once, before the loop.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
ids = tokenizer.encode(text, padding="do_not_pad")
tokens = tokenizer.convert_ids_to_tokens(ids)
seq_len = len(tokens)
input_ids = torch.tensor(ids).reshape(1, -1).to(device)
position_ids = torch.arange(start=0, end=seq_len).view(1, seq_len).to(device)
labels = input_ids

for attention_type in ['causal', 'bidirectional']:
    # A fresh model is loaded for every attention type: the bidirectional branch
    # below replaces `_update_causal_mask` on the instance, and that patch is
    # never undone, so a patched model must not be reused for the causal run.
    attn_implementation = 'eager'
    lm = AutoModelForCausalLM.from_pretrained(model_name_or_path, attn_implementation=attn_implementation)

    # enable bidirectional attention
    attention_mask = None
    if attention_type == "bidirectional":
        # all-ones 4D attention mask of shape (batch_size, 1, seq_len, seq_len)
        attention_mask = torch.ones(size=(1, 1, seq_len, seq_len), device=device)

        # for some models we need to overwrite the _update_causal_mask method
        if model_name_or_path in [
            "princeton-nlp/Sheared-LLaMA-1.3B",
            "meta-llama/Llama-2-7b-hf",
            "meta-llama/Llama-2-7b-chat-hf",
            "meta-llama/Meta-Llama-3-8B-Instruct",
            "meta-llama/Meta-Llama-3-8B",
        ]:
            # Return the supplied mask unchanged. Extra positional/keyword
            # arguments are accepted and ignored so the override does not depend
            # on the exact signature of this private method.
            # NOTE(review): the signature differs between transformers releases;
            # confirm the installed version still calls `_update_causal_mask`.
            lm.model._update_causal_mask = lambda attention_mask, *args, **kwargs: attention_mask
        # for others it's sufficient to modify the attention_mask when using attn_implementation == 'eager'

    # run inference and return attentions as well as hidden states
    lm.to(device)
    # Only detached outputs are saved, so skip building the autograd graph.
    with torch.no_grad():
        output = lm.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            labels=labels,
            attention_mask=attention_mask,
            output_attentions=True,
            output_hidden_states=True,
        )

    # SANITY CHECK: plot attention matrices
    # A = output.attentions[-1].squeeze()[-1].detach().cpu().float().numpy()
    # print(np.triu(A, k=1)) # the future. this should be all zeros when using causal attention and non-zero when using bidirectional attention

    # save hidden states to disk
    data_path = f"/data/hidden_states_data/{model_name_or_path.split('/')[-1]}/{attention_type}/{text_id}"
    Path(data_path).mkdir(parents=True, exist_ok=True)

    # save hidden states of every layer, one file per entry of output.hidden_states
    for layer, hidden_state in enumerate(output.hidden_states):
        H = hidden_state.detach().cpu().numpy()
        file_name = f"H_layer{layer}.npy"
        with open(os.path.join(data_path, file_name), 'wb') as f:
            np.save(f, H)