
Samplers in Gemma model #1588

Open
mostafamdy opened this issue Apr 20, 2024 · 6 comments
@mostafamdy

Describe the bug

Hi,
I am trying the sampler example from https://keras.io/examples/generative/text_generation_gpt/ with Gemma.

The Gemma preprocessor returns a dictionary of token_ids and padding_mask, but the sampler does not accept a dictionary.

preprocessor = gemma_lm.preprocessor
tokenizer = preprocessor.tokenizer

pre = keras_nlp.models.GemmaPreprocessor(
    tokenizer, sequence_length=1200,
)
test_d = pre([data[0]])

test_d
{'token_ids': Array([[  2, 108, 106, ...,   0,   0,   0]], dtype=int32),
 'padding_mask': Array([[ True,  True,  True, ..., False, False, False]], dtype=bool)}

Sampler code

def next(prompt, cache, index):
    print(prompt)
    logits = gemma_lm(prompt)[:, index - 1, :]
    print(logits)
    # Ignore hidden states for now; only needed for contrastive search.
    hidden_states = None
    return logits, hidden_states, cache

sampler = keras_nlp.samplers.GreedySampler()
output_tokens = sampler(
    next=next,
    prompt=test_d,
    index=1,
)
txt = tokenizer.detokenize(output_tokens)
print(f"Greedy search generated text: \n{txt}\n")

Error

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[195], line 2
      1 sampler = keras_nlp.samplers.GreedySampler()
----> 2 output_tokens = sampler(
      3     next=next,
      4     prompt=p,
      5     index=1,
      6 )
      7 txt = tokenizer.detokenize(output_tokens)
      8 print(f"Greedy search generated text: \n{txt}\n")

File /usr/local/lib/python3.10/site-packages/keras_nlp/src/samplers/sampler.py:98, in Sampler.__call__(self, next, prompt, cache, index, mask, stop_token_ids, hidden_states, model)
     87 def __call__(
     88     self,
     89     next,
   (...)
     96     model=None,
     97 ):
---> 98     max_length = ops.shape(prompt)[-1]
     99     # Make sure `max_length` and `index` are the same dtype.
    100     index = ops.cast(index, "int32")

File /usr/local/lib/python3.10/site-packages/keras/src/ops/core.py:438, in shape(x)
    436 if any_symbolic_tensors((x,)):
    437     return x.shape
--> 438 return backend.core.shape(x)

File /usr/local/lib/python3.10/site-packages/keras/src/backend/jax/core.py:95, in shape(x)
     92 def shape(x):
     93     # This will work as long as we disallow
     94     # dynamic shapes in JAX.
---> 95     return x.shape

AttributeError: 'dict' object has no attribute 'shape'
github-actions bot added the Gemma (Gemma model specific issues) label on Apr 20, 2024
@SuryanarayanaY

Hi @mostafamdy ,

It seems the sampler expects the prompt in the form of a list (or tensor of token ids). The {{call_args}} in the base Sampler class docstring is not defined properly, IMO:

Call arguments:
{{call_args}}
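
For what it's worth, a rough, untested adaptation of the snippet above along those lines, assuming the sampler wants the dense token-id tensor as the prompt (with the padding mask passed separately via the mask argument) and that Gemma's pad token id is 0:

token_ids = test_d["token_ids"]
padding_mask = test_d["padding_mask"]

def next(prompt, cache, index):
    # `prompt` is now a [batch, seq_len] tensor of token ids, not a dict.
    mask = prompt != 0  # assumption: Gemma's pad token id is 0
    logits = gemma_lm({"token_ids": prompt, "padding_mask": mask})[:, index - 1, :]
    # Ignore hidden states for now; only needed for contrastive search.
    hidden_states = None
    return logits, hidden_states, cache

sampler = keras_nlp.samplers.GreedySampler()
output_tokens = sampler(
    next=next,
    prompt=token_ids,
    index=1,
    mask=padding_mask,  # positions marked True are treated as fixed prompt tokens
)
txt = tokenizer.detokenize(output_tokens)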

@tirthasheshpatel
Contributor

Hi @mostafamdy, it seems like the guide is outdated. Thanks for bringing this up! You can refer to the Sampler API docs or the "Example Use" section on the Kaggle model card. For your use case, there's now a simpler API for plugging in different samplers:

import keras_nlp

model = keras_nlp.models.GemmaCausalLM.from_preset('gemma_2b_en')

# Tell KerasNLP to use a "greedy" sampler. Other options are "top_k", "top_p", etc.
# See https://keras.io/api/keras_nlp/samplers/ for more info
model.compile(sampler="greedy")
output = model.generate("What is Keras?", max_length=50)

# You can also initialize a sampler to configure it for your use case
sampler = keras_nlp.samplers.TopKSampler(k=5, temperature=0.7)
model.compile(sampler=sampler)
model.generate(["What is Keras?"])

Does this answer your question?

@mostafamdy
Author

Thanks @tirthasheshpatel
I want to use a sampler inside a custom loss function; is that possible? 😅

Is this code correct?

def custom_loss(y_true, y_pred):
    logits = y_pred
    temperature = 1.0
    logits = ops.cast(logits, "float32")
    probabilities = keras.activations.softmax(logits / temperature)
    next_token = ops.argmax(probabilities, axis=-1)
    print(next_token)

    txt = tokenizer.detokenize(next_token)
    print(f"Greedy search generated text: \n{txt}\n")
 

@mostafamdy
Author

I tried to call the Gemma model like this:

preprocessor = keras_nlp.models.GemmaPreprocessor(
    tokenizer, sequence_length=SEQ_LEN,
)
model_out = gemma_lm(preprocessor([data[0]]))

and then passed model_out to the custom_loss function to generate text:

def custom_loss(y_true, y_pred):
    logits = y_pred
    temperature = 1.0
    logits = ops.cast(logits, "float32")
    probabilities = keras.activations.softmax(logits / temperature)
    next_token = ops.argmax(probabilities, axis=-1)
    txt = tokenizer.detokenize(next_token)
    print(f"Greedy search generated text: \n{txt}\n")

custom_loss("y_true", model_out)

But the output is different from gemma_lm.generate()

@tirthasheshpatel
Contributor

Ah, OK. Your code looks good to me. You seem to be printing out the next-token predictions at every position of each input sequence, which I'd guess is why the outputs are different. Can you check if this code generates the right output:

import keras
from keras import ops
import keras_nlp


model = keras_nlp.models.GemmaCausalLM.from_preset('gemma_2b_en')
preprocessor = model.preprocessor
tokenizer = preprocessor.tokenizer
backbone = model.backbone


def loss_fn(y_true, y_pred, prompt=None, index=None):
    logits = y_pred
    temperature = 1.0
    logits = ops.cast(logits, "float32")

    # Compute probs and next token value
    probabilities = ops.softmax(logits[:, index, :], axis=-1)
    next_token = ops.argmax(probabilities, axis=-1)

    # Update the prompt
    prompt_tokens = tokenizer.tokenize(prompt)
    updated_prompt_tokens = ops.concatenate([prompt_tokens, next_token[..., None]], axis=-1)
    updated_prompt = tokenizer.detokenize(updated_prompt_tokens)

    # Print the updated prompt
    print(f"The updated prompt is: {updated_prompt}")


# Get the inputs
prompt = ["The quick brown"]
train_data = preprocessor(prompt, sequence_length=10)
index = ops.min(ops.sum(train_data[0]['padding_mask'], axis=-1)) - 2

# Evaluate the loss function
loss_fn(train_data[1], model(train_data[0]), prompt=prompt, index=index)
# The updated prompt is: [b'The quick brown fox']

# Check outputs match
model.generate(prompt, max_length=5)
# ['The quick brown fox']

@mostafamdy
Author

Thank you so much ❤️
I tried this code and it's working well, but I have a misunderstanding.
I tried changing the index of the logits:

for i in range(10):
    probabilities = ops.softmax(logits[:, i, :], axis=-1)
    next_token = ops.argmax(probabilities, axis=-1)
    updated_prompt = tokenizer.detokenize(next_token)
    print(updated_prompt)

The output was:
and fox fox,' the, the
When using generate with a sequence length of 10, I received:
The quick brown fox jumps over the sleeping dog
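
The difference presumably comes from teacher forcing: in a single forward pass, logits[:, i, :] is the prediction for position i + 1 given the original prompt tokens up to i, whereas generate() feeds each predicted token back into the prompt before predicting the next one. A rough, untested sketch of that loop, reusing model, tokenizer, ops, and train_data from the snippet above and assuming Gemma's pad token id is 0:

token_ids = train_data[0]["token_ids"]
# Index of the last real prompt token (mirrors the index computation above).
start = int(ops.sum(train_data[0]["padding_mask"][0])) - 2

for i in range(start, token_ids.shape[1] - 1):
    padding_mask = ops.not_equal(token_ids, 0)  # assumption: pad id is 0
    logits = model({"token_ids": token_ids, "padding_mask": padding_mask})
    next_token = ops.argmax(logits[:, i, :], axis=-1)
    # Write the prediction back into the prompt before the next step.
    token_ids = ops.slice_update(
        token_ids, (0, i + 1), ops.cast(next_token[:, None], "int32")
    )

print(tokenizer.detokenize(token_ids))

(generate() does the same thing far more efficiently by reusing a key/value cache instead of re-running the whole model at each step.)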
