In [1]:
import transformers
import torch

In [2]:
model_id = "meta-llama/Meta-Llama-3-8B"

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_id)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
config_kwargs = {
    # "trust_remote_code": True,
    # "cache_dir": None,
    # "revision": 'main',
    # "use_auth_token": None,
    "output_hidden_states": True
}

config = AutoConfig.from_pretrained(model_id, **config_kwargs)

In [6]:
model = AutoModelForCausalLM.from_pretrained(
    model_id, 
    config=config,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    revision='main',
    device_map='cuda'
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
model = model.cuda()

In [8]:
text = 'are you OK?'

In [9]:
tokens = tokenizer.encode_plus(text, 
        # add_special_tokens=True, 
        # padding='max_length', 
        truncation=True,
        max_length=1024,
        return_tensors='pt'
)

In [10]:
tokens

{'input_ids': tensor([[128000,    548,    499,  10619,     30]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [11]:
mask = tokens['attention_mask']
outputs = model(input_ids=tokens['input_ids'].cuda(), attention_mask=mask.cuda())

In [12]:
outputs

CausalLMOutputWithPast(loss=None, logits=tensor([[[  6.8789,   8.8047,  12.9609,  ...,  -4.4453,  -4.4453,  -4.4453],
         [  2.0859,   0.4924,  -0.7236,  ..., -10.0625, -10.0625, -10.0625],
         [  4.0234,   2.5684,  -0.4778,  ...,  -9.8125,  -9.8125,  -9.8125],
         [  7.6172,   5.6914,   2.8477,  ...,  -9.8828,  -9.8828,  -9.8828],
         [  0.5283,  -5.4570,   1.5869,  ..., -10.2812, -10.2812, -10.2812]]],
       device='cuda:0', grad_fn=<ToCopyBackward0>), past_key_values=((tensor([[[[ 5.0781e-01,  9.3506e-01,  9.2139e-01,  ...,  1.2686e+00,
           -2.0703e-01,  2.4023e-01],
          [ 1.0918e+00, -2.3496e+00, -4.3018e-01,  ...,  1.8250e-01,
           -1.1270e+00, -9.2480e-01],
          [-4.0938e+00, -2.6738e+00, -1.9629e+00,  ..., -2.0776e-01,
           -9.7266e-01, -1.5146e+00],
          [-8.0469e+00, -4.5430e+00, -2.6836e+00,  ...,  2.5708e-01,
           -1.3750e+00, -1.4668e+00],
          [-1.7100e+00, -2.5039e+00, -3.2441e+00,  ..., -1.3342e-01,
     

In [13]:
hidden_states = list(outputs.hidden_states)

In [14]:
with torch.no_grad(): 
    last_hidden_states = hidden_states[-1].cpu().numpy()

In [15]:
last_hidden_states

array([[[ 4.008e+00, -5.020e-01, -1.993e+00, ..., -3.746e+00,
          8.413e-01,  2.703e+00],
        [-5.747e-01, -9.370e-01,  1.045e+00, ..., -1.050e+00,
         -1.649e+00,  1.181e+00],
        [-9.717e-01,  2.090e-01, -6.914e-01, ...,  1.813e-01,
         -1.819e+00,  2.326e+00],
        [-1.025e+00,  1.741e+00,  3.164e+00, ..., -1.387e-01,
          1.451e-03,  2.408e+00],
        [-5.767e-01, -7.782e-02,  3.906e+00, ...,  2.185e-01,
          1.740e+00,  2.021e+00]]], dtype=float16)

In [16]:
last_hidden_states.shape

(1, 5, 4096)

In [18]:
import numpy as np

In [22]:
p1 = np.squeeze(last_hidden_states)
p1.shape

(5, 4096)

In [24]:
p2 = np.mean(p1, axis=0)
p2.shape

(4096,)

In [25]:
p2

array([ 0.1719,  0.0867,  1.086 , ..., -0.9067, -0.1771,  2.129 ],
      dtype=float16)

In [17]:
dict(outputs)

{'logits': tensor([[[  6.8789,   8.8047,  12.9609,  ...,  -4.4453,  -4.4453,  -4.4453],
          [  2.0859,   0.4924,  -0.7236,  ..., -10.0625, -10.0625, -10.0625],
          [  4.0234,   2.5684,  -0.4778,  ...,  -9.8125,  -9.8125,  -9.8125],
          [  7.6172,   5.6914,   2.8477,  ...,  -9.8828,  -9.8828,  -9.8828],
          [  0.5283,  -5.4570,   1.5869,  ..., -10.2812, -10.2812, -10.2812]]],
        device='cuda:0', grad_fn=<ToCopyBackward0>),
 'past_key_values': ((tensor([[[[ 5.0781e-01,  9.3506e-01,  9.2139e-01,  ...,  1.2686e+00,
              -2.0703e-01,  2.4023e-01],
             [ 1.0918e+00, -2.3496e+00, -4.3018e-01,  ...,  1.8250e-01,
              -1.1270e+00, -9.2480e-01],
             [-4.0938e+00, -2.6738e+00, -1.9629e+00,  ..., -2.0776e-01,
              -9.7266e-01, -1.5146e+00],
             [-8.0469e+00, -4.5430e+00, -2.6836e+00,  ...,  2.5708e-01,
              -1.3750e+00, -1.4668e+00],
             [-1.7100e+00, -2.5039e+00, -3.2441e+00,  ..., -1.3342e-01,
  