# Notebook 1. Transformer Decoder Model

In [35]:
import torch
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

In [36]:
# Hugging Face Hub id of the checkpoint used throughout this notebook.
# Model card: https://huggingface.co/google/gemma-3-1b-it
gemma3_model_name = "google/gemma-3-1b-it"

## Exercise 1. Have a look at the Gemma-3 vocabulary and tokenizer

In [37]:
# Fetch the tokenizer that belongs to the Gemma-3 checkpoint
gemma3_tokenizer = AutoTokenizer.from_pretrained(gemma3_model_name)

# The vocabulary is a dict mapping token text -> token id
gemma3_vocab = gemma3_tokenizer.get_vocab()

# Flip it into a pandas Series (index = token id, value = token text),
# ordered by token id, for inspection
gemma3_vocab_df = pd.Series(
    data=list(gemma3_vocab.keys()),
    index=list(gemma3_vocab.values()),
).sort_index()

In [38]:
# How many token types are there in the vocabulary?
# NOTE(review): the vocabulary Series built from get_vocab() has 262145 entries (ids 0-262144),
# one more than the number shown here - presumably `vocab_size` does not count tokens
# added on top of the base vocabulary; confirm against the tokenizer documentation.
gemma3_tokenizer.vocab_size

262144

In [39]:
gemma3_vocab_df

# <pad>, <eos>, ... are special tokens: they are not part of the actual text but are used for model training.
# <unused...> tokens are reserved for future use, e.g. fine-tuning on domain-specific data.

0                      <pad>
1                      <eos>
2                      <bos>
3                      <unk>
4                     <mask>
                 ...        
262140          <unused6238>
262141          <unused6239>
262142          <unused6240>
262143          <unused6241>
262144    <image_soft_token>
Length: 262145, dtype: object

In [40]:
# Peek at ten consecutive vocabulary entries (positions 1000-1009)
gemma3_vocab_df.iloc[1000:1010]

# The "▁" symbol (U+2581, not an ASCII underscore) marks a whitespace in front of the token

1000       put
1001       ▁di
1002       erm
1003    ▁about
1004       ays
1005      text
1006       ▁am
1007       ade
1008       ▁et
1009      ▁est
dtype: object

In [41]:
# Peek at another ten entries, deeper into the vocabulary
gemma3_vocab_df.iloc[144620:144630]

144620        ▁VHF
144621       ▁Citi
144622       ▁mede
144623      ancang
144624      ▁Beech
144625    ▁Signing
144626          你好
144627         сай
144628      ▁trasp
144629      ▁varic
dtype: object

In [42]:
# An example prompt (a statement rather than a question) that the model is later asked to continue
example_question = "Jinshuai loves data science."

In [43]:
# Tokenize the example text and inspect the tokens we got:
example_tokens = gemma3_tokenizer(example_question)

# Convert token ids back to token texts for inspection
gemma3_tokenizer.convert_ids_to_tokens(example_tokens["input_ids"])
# <bos> marks the beginning of a sequence. The "▁" symbol marks a whitespace.

['<bos>', 'Jin', 'shu', 'ai', '▁loves', '▁data', '▁science', '.']

## Exercise 2. Explore the output of Gemma-3 model

In [44]:
# Load the Gemma3 model from the same checkpoint name that was used for the tokenizer
gemma3_model = AutoModelForCausalLM.from_pretrained(gemma3_model_name)

In [45]:
# Have a look at the model
gemma3_model

# The model has a very complex structure, but we don't have to understand the detailed implementation. Just treat it as a black box.

Gemma3ForCausalLM(
  (model): Gemma3TextModel(
    (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 1152, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma3DecoderLayer(
        (self_attn): Gemma3Attention(
          (q_proj): Linear(in_features=1152, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1152, out_features=256, bias=False)
          (v_proj): Linear(in_features=1152, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1152, bias=False)
          (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
          (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
        )
        (mlp): Gemma3MLP(
          (gate_proj): Linear(in_features=1152, out_features=6912, bias=False)
          (up_proj): Linear(in_features=1152, out_features=6912, bias=False)
          (down_proj): Linear(in_features=6912, out_features=1152, bias=False)
          (act_fn): GELUTanh()
        )
        (input_layernorm): Gemma3RMSNorm((1152,), e

In [46]:
# Count the parameters: add up the number of elements of every parameter tensor
sum(map(torch.numel, gemma3_model.parameters()))

999885952

In [47]:
# Run the model once and return the scores for the next token
def predict(input_ids):
    """Score every candidate next token for the context ``input_ids``.

    Args:
        input_ids: Sequence of integer token ids forming the context.

    Returns:
        The ``logits[0, -1, :]`` slice of the model output, i.e. the scores
        produced at the last position of the single input sequence.
    """
    # Wrap the ids in an outer list so the model receives a batch of one sequence
    batch_ids = torch.tensor([input_ids], dtype=torch.long)
    # All-ones attention mask with the same shape and dtype as the ids
    batch_mask = torch.ones_like(batch_ids)

    with torch.no_grad():  # No gradients are needed for a forward-only run
        model_output = gemma3_model(input_ids=batch_ids, attention_mask=batch_mask)

    # Only the logits are used here; the rest of the model output is ignored
    return model_output.logits[0, -1, :]

# Score the candidates for the token that follows the example text
example_model_output = predict(example_tokens["input_ids"])
# One row per output position: its index (used as the token id), the matching token text, and the model's score
example_model_output_df = pd.DataFrame({
    "token_id": range(len(example_model_output)),
    "token": gemma3_tokenizer.convert_ids_to_tokens(range(len(example_model_output))),
    "scores": example_model_output,
})

In [48]:
# Rank the tokens from the highest to the lowest score
example_model_output_df.sort_values("scores", ascending=False)

# The output represents the model's predicted scores for each token in the vocabulary being the next token.

# "\n" is a new line symbol. In the output shown here the top-scoring next token is "▁He",
# closely followed by "\n\n" (two new lines).

Unnamed: 0,token_id,token,scores
1293,1293,▁He,18.738159
108,108,\n\n,18.094341
138,138,▁▁,16.264725
107,107,\n,15.961935
668,668,▁he,15.702940
...,...,...,...
250244,250244,ꗕ,-14.405449
254551,254551,􀍷,-14.405522
214405,214405,imsuti,-14.457421
199406,199406,▁അപകട,-14.646069


## Exercise 3. Generate a complete answer with greedy search

Check the stop token for Gemma-3

In [49]:
# Show the end-of-sequence (stop) token id(s) stored in the model config, together with their token text
gemma3_model.config.eos_token_id, gemma3_tokenizer.convert_ids_to_tokens(gemma3_model.config.eos_token_id)

Greedy search: at every step, always select the token with the highest score

In [50]:
max_token = 50

# Work on a copy of the prompt ids: appending to example_tokens["input_ids"] itself
# would modify the tokenized example in place and make this cell non-repeatable
# (a second run would start from the already generated text).
context_token_ids = list(example_tokens["input_ids"])

# The config may hold the EOS id as a single int or as a list of ints; normalise to a set
eos_token_ids = gemma3_model.config.eos_token_id
if eos_token_ids is None:
    eos_token_ids = set()
elif isinstance(eos_token_ids, int):
    eos_token_ids = {eos_token_ids}
else:
    eos_token_ids = set(eos_token_ids)

step = 0

while step < max_token:

    # Compute the token scores
    model_step_output = predict(context_token_ids)

    # Greedy choice: select the token with the highest score
    next_token_id = torch.argmax(model_step_output).item()

    # Add the generated token to the context
    context_token_ids.append(next_token_id)

    # Terminate the generation process if an EOS token is generated.
    if next_token_id in eos_token_ids:
        print("Found EOS token, generation stopped.")
        break

    # Move one step forward
    step += 1
else:
    # Reached only when the loop ran out of steps without generating an EOS token
    print("Max tokens reached, generation stopped.")

Max tokens reached, generation stopped.


In [51]:
# Check the generated context: map every token id in the context back to its token text
context_tokens = gemma3_tokenizer.convert_ids_to_tokens(context_token_ids)
context_tokens

['<bos>',
 'Jin',
 'shu',
 'ai',
 '▁loves',
 '▁data',
 '▁science',
 '.',
 '▁He',
 "'",
 's',
 '▁been',
 '▁working',
 '▁on',
 '▁a',
 '▁project',
 '▁to',
 '▁predict',
 '▁customer',
 '▁churn',
 '▁for',
 '▁a',
 '▁tele',
 'communications',
 '▁company',
 '.',
 '\n\n',
 'He',
 "'",
 's',
 '▁using',
 '▁Python',
 '▁and',
 '▁Pandas',
 '▁to',
 '▁clean',
 '▁and',
 '▁prepare',
 '▁the',
 '▁data',
 ',',
 '▁and',
 '▁then',
 '▁he',
 "'",
 's',
 '▁using',
 '▁Sc',
 'ikit',
 '-',
 'learn',
 '▁to',
 '▁build',
 '▁a',
 '▁model',
 '.',
 '▁He',
 "'"]

In [52]:
# Put the tokens together: concatenate the token texts and turn every "▁" marker into a space
print("".join(context_tokens).replace("▁", " "))

<bos>Jinshuai loves data science. He's been working on a project to predict customer churn for a telecommunications company.

He's using Python and Pandas to clean and prepare the data, and then he's using Scikit-learn to build a model. He'
