In [1]:
import types

import torch

import engine

# One shared prefix, each completed by a different candidate word.
prompt = "Hello everyone, I am here today to"
words = [' express', ' tell', ' sell']
prompts = [f"{prompt}{word}" for word in words]

model = engine.HFModel('bloom-560M', dtype=torch.float32)

# NOTE(review): no padding is requested, so building a single tensor only works
# if every prompt tokenizes to the same length (presumably each word is a single
# token) — confirm if `words` changes.
encoded_prompts = model.tokenizer(prompts, return_tensors='pt')
input_ids = encoded_prompts.input_ids

In [None]:
# Reference logits from the unmodified model for every position of every prompt.
# Call the module rather than `.forward` so any registered hooks run, and disable
# autograd: nothing is backpropagated, so building a graph only wastes memory.
with torch.no_grad():
    logits = model.model(input_ids).logits

In [None]:
# Final hidden states of the transformer body, before the LM head is applied.
# Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    out = model.model.transformer(input_ids)
transfo = out[0]

In [None]:
# Shape of the transformer's final hidden states.
transfo.size()

In [None]:
# Apply the LM head to three windows of the hidden states:
# every position (foo1), only the last one (foo2), and the last five (foo3).
# If the head is purely position-wise, position -1 should match across all three.
lm_head = model.model.lm_head
windows = (slice(None), slice(-1, None), slice(-5, None))
with torch.no_grad():
    foo1, foo2, foo3 = (lm_head(transfo[:, window, :]) for window in windows)

In [None]:
# Last-position logits: head applied to 1 position vs. to the last 5 positions.
from_last1 = foo2[:, -1, :]
from_last5 = foo3[:, -1, :]
print(torch.allclose(from_last1, from_last5, rtol=1e-5))
(from_last1 - from_last5).abs().max()

In [None]:
# Last-position logits: head applied to every position vs. to the last one only.
from_full = foo1[:, -1, :]
from_last1 = foo2[:, -1, :]
print(torch.allclose(from_full, from_last1, rtol=1e-5))
(from_full - from_last1).abs().max()

In [None]:
# Last-position logits: head applied to every position vs. to the last five.
from_full = foo1[:, -1, :]
from_last5 = foo3[:, -1, :]
print(torch.allclose(from_full, from_last5, rtol=1e-5))
(from_full - from_last5).abs().max()

In [None]:
import torch
import torch.nn as nn

# Control experiment without the LLM: does applying nn.Linear to a whole
# (batch, seq, features) tensor give the same numbers as applying it to one
# sequence position at a time?
BATCH, SEQ_LEN, D_IN, D_OUT = 3, 100, 512, 256
torch.manual_seed(0)  # reproducible inputs and layer weights -> reproducible diffs

rand_vector = torch.rand(BATCH, SEQ_LEN, D_IN) * 100
fc_layer = nn.Linear(D_IN, D_OUT)

with torch.no_grad():  # inference only; no autograd graph needed
    # One batched call over all positions...
    res1 = fc_layer(rand_vector)
    # ...vs. one call per position, re-stacked along the sequence dimension.
    res2 = torch.stack([fc_layer(rand_vector[:, i, :]) for i in range(SEQ_LEN)], 1)

print(res1.shape, res2.shape)
print(f'Allclose {torch.allclose(res1, res2)}')

diff = torch.abs(res1 - res2)
print(f'Diff: avg {diff.mean()}, min {diff.min()}, max {diff.max()}')

In [4]:
from typing import Optional, Tuple, Union
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from torch.nn import CrossEntropyLoss


def new_forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
    r"""
    Causal-LM forward that, when no labels are given, only computes last-position logits.

    Intended to be bound as ``forward`` on a Hugging Face causal-LM instance: it
    uses ``self.transformer``, ``self.lm_head`` and ``self.config``. It mirrors the
    GPT-2-style forward it was adapted from, except that when ``labels`` is None
    the LM head is applied to ``hidden_states[:, -1:, :]`` only, so ``logits`` has a
    sequence dimension of length 1. Skipping the other positions avoids
    materialising their logits, which saves memory during token-by-token
    generation (presumably only the last position is read there — confirm).

    Args:
        labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`.
            When labels are given, logits are computed for ALL positions because the loss needs them.
        token_type_ids, position_ids, head_mask, encoder_hidden_states, encoder_attention_mask:
            Forwarded to ``self.transformer`` only when not None.
        **kwargs:
            Any further keyword arguments, forwarded unchanged to ``self.transformer``.
        The remaining arguments are always forwarded to ``self.transformer``.

    Returns:
        A ``CausalLMOutputWithCrossAttentions`` when ``return_dict`` (falling back to
        ``self.config.use_return_dict``) is true. Otherwise a tuple
        ``(logits, *transformer_outputs[1:])``, prefixed by ``loss`` when labels were given.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # GPT-2-specific inputs that not every architecture's transformer accepts.
    # E.g. BLOOM's transformer (transformers' modeling_bloom — confirm for the
    # installed version) raises on unexpected keyword arguments and warns about
    # `position_ids` even when it is None. Forward them only when supplied.
    optional_inputs = {
        "token_type_ids": token_type_ids,
        "position_ids": position_ids,
        "head_mask": head_mask,
        "encoder_hidden_states": encoder_hidden_states,
        "encoder_attention_mask": encoder_attention_mask,
    }
    supplied_inputs = {name: value for name, value in optional_inputs.items() if value is not None}

    transformer_outputs = self.transformer(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        **supplied_inputs,
        **kwargs,
    )
    hidden_states = transformer_outputs[0]

    loss = None
    if labels is None:
        # Memory-saving path: logits for the last position only.
        lm_logits = self.lm_head(hidden_states[:, -1:, :])
    else:
        # The shifted loss needs a prediction at every position; a single-position
        # slice would leave nothing to align with `labels[..., 1:]`.
        lm_logits = self.lm_head(hidden_states)
        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    if not return_dict:
        output = (lm_logits,) + transformer_outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return CausalLMOutputWithCrossAttentions(
        loss=loss,
        logits=lm_logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
        # Decoder-only bodies may return an output class without this field.
        cross_attentions=getattr(transformer_outputs, "cross_attentions", None),
    )


prompt = "Hello everyone, I am here today to"

model1 = engine.HFModel('bloom-560M', dtype=torch.float32)
# `new_forward` reads self.transformer / self.lm_head / self.config, so `self`
# must be the Hugging Face model, i.e. `model1.model`, not the engine.HFModel
# wrapper. Bind it as a method on that instance. Assigning the bare function to
# `model1.forward` neither binds `self` nor replaces the forward of the module
# that the wrapper (presumably) generates with, so the patch silently did nothing.
# NOTE(review): if engine.HFModel dispatches the model with accelerate hooks,
# overriding `forward` like this bypasses them — confirm.
model1.model.forward = types.MethodType(new_forward, model1.model)

# Sanity check that the patch is live: logits should cover only the last position.
with torch.no_grad():
    check_ids = model1.tokenizer(prompt, return_tensors='pt').input_ids.to(model1.model.device)
    check_logits = model1.model(check_ids).logits
assert check_logits.shape[1] == 1, 'new_forward is not being used by model1.model'

model2 = engine.HFModel('bloom-560M', dtype=torch.float32)

In [6]:
# Generate from the patched and the reference model with one shared set of
# settings (same prompt, sample count, length and seed), so any difference in
# the output can only come from the patched forward.
generation_kwargs = dict(num_return_sequences=5, max_new_tokens=10, seed=1)
test1 = model1(prompt, **generation_kwargs)
test2 = model2(prompt, **generation_kwargs)

assert test1 == test2, f'Patched generations differ from reference:\n{test1}\n{test2}'