In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers

import torch

from tqdm import tqdm

In [2]:
# Use the GPU when one is actually present and fall back to the CPU otherwise,
# so the notebook still runs on machines without CUDA.
# NOTE(review): none of the cells visible here move the model or inputs to
# `device` yet — confirm whether that is intended.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Hub identifier of the checkpoint; the tokenizer and the model are both
# loaded from it so that they stay in sync.
model_name = "tiiuae/falcon-7b-instruct"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
)

# NOTE(review): trust_remote_code=True allows transformers to execute Python
# modelling code shipped with the checkpoint repository — only keep it for
# repositories you trust (ideally pinned to a known revision).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Display the module tree of the model's `transformer` attribute (embedding,
# decoder blocks, final layer norm) to see which sub-modules can be reused.
model.transformer

RWModel(
  (word_embeddings): Embedding(65024, 4544)
  (h): ModuleList(
    (0-31): 32 x DecoderLayer(
      (input_layernorm): LayerNorm((4544,), eps=1e-05, elementwise_affine=True)
      (self_attention): Attention(
        (maybe_rotary): RotaryEmbedding()
        (query_key_value): Linear(in_features=4544, out_features=4672, bias=False)
        (dense): Linear(in_features=4544, out_features=4544, bias=False)
        (attention_dropout): Dropout(p=0.0, inplace=False)
      )
      (mlp): MLP(
        (dense_h_to_4h): Linear(in_features=4544, out_features=18176, bias=False)
        (act): GELU(approximate='none')
        (dense_4h_to_h): Linear(in_features=18176, out_features=4544, bias=False)
      )
    )
  )
  (ln_f): LayerNorm((4544,), eps=1e-05, elementwise_affine=True)
)

In [4]:
class FeatureExtractor(torch.nn.Module):
    """Token embedding followed by the first ``num_blocks`` decoder blocks.

    The layers are taken by reference from ``base_model.transformer`` (no copy
    is made), so they are the same module objects the base model uses.

    Args:
        base_model: Object exposing ``transformer.word_embeddings`` and an
            iterable ``transformer.h`` of decoder blocks. When ``None`` the
            notebook-level ``model`` is used, which keeps the original
            ``FeatureExtractor()`` call working.
        num_blocks: How many leading decoder blocks to keep. The default of
            22 matches the original hard-coded cut-off (blocks 0..21). If the
            base model has fewer blocks, all of them are used.

    Raises:
        ValueError: If ``num_blocks`` is negative.
    """

    def __init__(self, base_model=None, num_blocks=22):
        super().__init__()

        if num_blocks < 0:
            raise ValueError(f"num_blocks must be non-negative, got {num_blocks}")

        # Resolve the global lazily so that an explicit base_model never
        # touches notebook state.
        if base_model is None:
            base_model = model

        transformer = base_model.transformer
        self.word_embeddings = transformer.word_embeddings

        # Slice a plain list so this works for any iterable of blocks.
        self.used_blocks = torch.nn.ModuleList(list(transformer.h)[:num_blocks])

    def forward(self, x):
        """Run ``x`` through the embedding and the retained blocks.

        Args:
            x: Tensor of token ids; the last dimension is treated as the
                sequence axis.

        Returns:
            Whatever the last retained block returns (callers in this notebook
            index ``[0]`` to get the hidden states), or the raw embeddings
            when no blocks are retained.
        """
        output = self.word_embeddings(x)

        # All-ones mask spanning the sequence. The original used len(x), which
        # for a (batch, seq) tensor is the batch size, not the sequence length.
        # Built once, on the input's device, instead of once per block.
        # NOTE(review): whether the blocks read this mask when alibi is None
        # depends on the remote modelling code — confirm.
        attention_mask = torch.ones((1, x.shape[-1]), device=x.device)

        for block in self.used_blocks:
            # Blocks may return a tuple whose first item is the hidden states.
            hidden_states = output[0] if isinstance(output, tuple) else output
            output = block(hidden_states, alibi=None, attention_mask=attention_mask)

        return output

In [5]:
# Build the extractor from the already-loaded model.
feature_extractor = FeatureExtractor()
# Bare name as the last expression: shows the module tree so the number of
# retained decoder blocks can be checked by eye.
feature_extractor

FeatureExtractor(
  (word_embeddings): Embedding(65024, 4544)
  (used_blocks): ModuleList(
    (0-21): 22 x DecoderLayer(
      (input_layernorm): LayerNorm((4544,), eps=1e-05, elementwise_affine=True)
      (self_attention): Attention(
        (maybe_rotary): RotaryEmbedding()
        (query_key_value): Linear(in_features=4544, out_features=4672, bias=False)
        (dense): Linear(in_features=4544, out_features=4544, bias=False)
        (attention_dropout): Dropout(p=0.0, inplace=False)
      )
      (mlp): MLP(
        (dense_h_to_4h): Linear(in_features=4544, out_features=18176, bias=False)
        (act): GELU(approximate='none')
        (dense_4h_to_h): Linear(in_features=18176, out_features=4544, bias=False)
      )
    )
  )
)

In [7]:
encoding = tokenizer("Hello! How are you doing today?", truncation=True, return_tensors="pt")

# Inference only: disable autograd so no graph or intermediate activations
# are kept for the multi-billion-parameter forward pass.
with torch.no_grad():
    # The extractor returns the last block's output; item 0 is the hidden states.
    hidden_states = feature_extractor(encoding["input_ids"])[0]
print(hidden_states.shape)

# Average over dim 1 (the token axis in the shapes printed below) to get one
# fixed-size feature vector per input.
output = torch.mean(hidden_states, dim=1)
print(output.shape)

torch.Size([1, 8, 4544])
torch.Size([1, 4544])


In [21]:
encoding = tokenizer("Hello! How are you doing today?", truncation=True, return_tensors="pt")

max_length = 40
start_input_size = len(encoding["input_ids"][0])
# The sampling-related entries are only forwarded to
# `prepare_inputs_for_generation`; the loop below always takes the argmax.
kwargs = {'max_length': max_length, 'do_sample': True, 'top_k': 10, 'num_return_sequences': 1, 'eos_token_id': 11}

# Manual greedy decoding for a single sequence (batch size 1): at every step
# the most likely next token is appended and the whole prefix is fed back in.
# Inference only, so autograd is disabled.
with torch.no_grad():
    for _ in tqdm(range(max_length - start_input_size)):
        inputs = model.prepare_inputs_for_generation(**encoding, **kwargs)
        output = model(
            **inputs,
            return_dict=True,
            output_attentions=model.generation_config.output_attentions,
            output_hidden_states=model.generation_config.output_hidden_states,
        )
        next_token_idx = torch.argmax(output["logits"][:, -1, :], -1)
        encoding["input_ids"] = torch.hstack((encoding["input_ids"], next_token_idx.view(1, -1)))

        # Extend the auxiliary tensors with new_zeros/new_ones so they keep
        # their dtype and device; plain torch.zeros/torch.ones are float32 and
        # would silently promote these tensors to float.
        if "token_type_ids" in encoding:
            token_type_ids = encoding["token_type_ids"]
            encoding["token_type_ids"] = torch.hstack((token_type_ids, token_type_ids.new_zeros((1, 1))))
        attention_mask = encoding["attention_mask"]
        encoding["attention_mask"] = torch.hstack((attention_mask, attention_mask.new_ones((1, 1))))

        # Stop once the end-of-sequence token has been emitted; anything
        # generated after it is not part of the answer.
        if next_token_idx.item() == kwargs['eos_token_id']:
            break

token_ids = encoding["input_ids"][0]
tokenizer.decode(list(token_ids))

100%|██████████| 32/32 [01:39<00:00,  3.12s/it]


"Hello! How are you doing today?\nI'm doing well, thank you. I'm just getting ready to head out for a walk. Do you have any plans for the day?"

In [31]:
# Re-tokenize the prompt so `encoding` is reset to just the prompt tokens
# (an earlier cell extends `encoding` in place during decoding).
# NOTE(review): the output recorded below this cell does not correspond to
# this statement — the rest of the cell appears to be missing; confirm.
encoding = tokenizer("Hello! How are you doing today?", truncation=True, return_tensors="pt")



RWForCausalLM(
  (transformer): None
  (lm_head): Linear(in_features=4544, out_features=65024, bias=False)
)


TypeError: 'method' object is not iterable