In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "google/gemma-2-2b-it"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.68s/it]


In [3]:
model.device

device(type='cuda', index=0)

In [None]:
model

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): GELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNo

In [6]:
model.config

Gemma2Config {
  "architectures": [
    "Gemma2ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "attn_logit_softcapping": 50.0,
  "bos_token_id": 2,
  "cache_implementation": "hybrid",
  "dtype": "float16",
  "eos_token_id": [
    1,
    107
  ],
  "final_logit_softcapping": 30.0,
  "head_dim": 256,
  "hidden_act": "gelu_pytorch_tanh",
  "hidden_activation": "gelu_pytorch_tanh",
  "hidden_size": 2304,
  "initializer_range": 0.02,
  "intermediate_size": 9216,
  "layer_types": [
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention",
    "full_attention",
    "sliding_attention

In [7]:
model.parameters

<bound method Module.parameters of Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): GELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_

In [11]:
print([p for p in model.parameters()])

[Parameter containing:
tensor([[ 0.0352, -0.0229,  0.0815,  ...,  0.0211,  0.0527, -0.0352],
        [-0.0200,  0.0522, -0.0303,  ...,  0.0028, -0.0240, -0.0173],
        [-0.0002, -0.0059,  0.0222,  ...,  0.0152, -0.0074, -0.0119],
        ...,
        [ 0.0371, -0.0237,  0.0486,  ...,  0.0075,  0.0064,  0.0134],
        [ 0.0117, -0.0410,  0.0253,  ..., -0.0024,  0.0454,  0.0219],
        [ 0.0361, -0.0262,  0.0786,  ...,  0.0215,  0.0525, -0.0366]],
       device='cuda:0', dtype=torch.float16, requires_grad=True), Parameter containing:
tensor([[-0.0070,  0.0149, -0.0193,  ...,  0.0148, -0.0278,  0.0087],
        [-0.0022,  0.0156, -0.0254,  ...,  0.0205, -0.0100,  0.0049],
        [ 0.0017,  0.0034,  0.0062,  ..., -0.0009,  0.0136, -0.0012],
        ...,
        [ 0.0030,  0.0276,  0.0087,  ...,  0.0175, -0.0114, -0.0013],
        [ 0.0042, -0.0025,  0.0055,  ...,  0.0070,  0.0003,  0.0060],
        [-0.0100, -0.0110,  0.0214,  ..., -0.0153,  0.0339, -0.0184]],
       device='cuda:0