In [1]:
from mlc_chat.compiler import MODEL_PRESETS, MODELS

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the MLC GPT-2 model from the "gpt2" preset, export it to TVM, and list the
# exported parameters (name, shape, dtype).
# Named `mlc_model` rather than `model`: a later cell binds `model` to the Hugging Face
# GPT2Model, and sharing the name made this cell unsafe to re-run out of order.
model_info = MODELS["gpt2"]
config = model_info.config.from_dict(MODEL_PRESETS["gpt2"])
mlc_model = model_info.model(config)
mod, named_params = mlc_model.export_tvm(
    spec=mlc_model.get_default_spec(),  # type: ignore
)
# Uncomment to print the exported TVM module:
# mod.show(black_format=False)
for name, param in named_params:
    print(name, param.shape, param.dtype)

transformer.wte.weight [50257, 768] float32
transformer.wpe.weight [1024, 768] float32
transformer.h.0.ln_1.weight [768] float32
transformer.h.0.ln_1.bias [768] float32
transformer.h.0.attn.c_attn.weight [2304, 768] float32
transformer.h.0.attn.c_attn.bias [2304] float32
transformer.h.0.attn.c_proj.weight [768, 768] float32
transformer.h.0.attn.c_proj.bias [768] float32
transformer.h.0.ln_2.weight [768] float32
transformer.h.0.ln_2.bias [768] float32
transformer.h.0.mlp.c_fc.weight [3072, 768] float32
transformer.h.0.mlp.c_fc.bias [3072] float32
transformer.h.0.mlp.c_proj.weight [768, 3072] float32
transformer.h.0.mlp.c_proj.bias [768] float32
transformer.h.1.ln_1.weight [768] float32
transformer.h.1.ln_1.bias [768] float32
transformer.h.1.attn.c_attn.weight [2304, 768] float32
transformer.h.1.attn.c_attn.bias [2304] float32
transformer.h.1.attn.c_proj.weight [768, 768] float32
transformer.h.1.attn.c_proj.bias [768] float32
transformer.h.1.ln_2.weight [768] float32
transformer.h.1.ln_2

In [3]:
from transformers import GPT2Model, GPT2Tokenizer

# Hugging Face reference implementation of the same "gpt2" checkpoint. Its weights
# are copied into the MLC modules in the cells below, and their outputs are compared.
hf_checkpoint = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(hf_checkpoint)
model = GPT2Model.from_pretrained(hf_checkpoint)

In [4]:
first_hf_layer = model.h[0]
hf_mlp = first_hf_layer.mlp
# Zero the dropout probability so dropout cannot perturb the reference MLP output.
hf_mlp.dropout.p = 0
hf_mlp

GPT2MLP(
  (c_fc): Conv1D()
  (c_proj): Conv1D()
  (act): NewGELUActivation()
  (dropout): Dropout(p=0, inplace=False)
)

In [5]:
import torch

# Seed the global torch RNG so the random probe input, and therefore the comparison,
# is reproducible across Restart & Run All.
torch.manual_seed(0)
x = torch.rand((1, 768), dtype=torch.float32)

# Pull the HF MLP parameters out as NumPy arrays. `.detach()` is the supported way to
# get an autograd-free tensor; `.data` bypasses autograd's safety checks.
c_fc_weight = hf_mlp.c_fc.weight.detach().numpy()
c_fc_bias = hf_mlp.c_fc.bias.detach().numpy()
c_proj_weight = hf_mlp.c_proj.weight.detach().numpy()
c_proj_bias = hf_mlp.c_proj.bias.detach().numpy()

In [6]:
# Reference output from the HF MLP. Call the module itself (not `.forward`) so any
# registered hooks run. Disable autograd because only the values are compared.
with torch.no_grad():
    y1 = hf_mlp(x)

In [7]:
from tvm.relax.frontend.nn import spec
from mlc_chat.compiler.model.gpt2 import gpt2_model

# MLC implementation of the GPT-2 MLP, built from the preset `config`. Its state dict
# lists the parameter names and shapes that the HF weights must fill.
mlp = gpt2_model.GPT2MLP(config)
state_dict = mlp.state_dict()
state_dict


OrderedDict([('c_fc.weight', Tensor([3072, 768], "float32")),
             ('c_fc.bias', Tensor([3072], "float32")),
             ('c_proj.weight', Tensor([768, 3072], "float32")),
             ('c_proj.bias', Tensor([768], "float32"))])

In [8]:
import numpy as np
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op

# Copy the HF MLP parameters into the MLC MLP. The `.weight` arrays are transposed
# before copying. NOTE(review): presumably HF's Conv1D stores [in, out], while the MLC
# parameters are [out, in] (e.g. c_fc.weight is [3072, 768]) — confirm.
# The biases are copied unchanged.
for param_name, param_value in (
    ('c_fc.weight', c_fc_weight.T),
    ('c_fc.bias', c_fc_bias),
    ('c_proj.weight', c_proj_weight.T),
    ('c_proj.bias', c_proj_bias),
):
    state_dict[param_name].data = param_value

In [9]:
# Compile the MLC MLP for a single [1, 768] float32 input. Then run it on the same
# probe `x` that was fed to the HF reference.
hidden_states_spec = spec.Tensor([1, 768], dtype="float32")
mlp_spec = {"forward": {"hidden_states": hidden_states_spec}}
torch_mlp = mlp.jit(spec=mlp_spec, debug=True)
y2 = torch_mlp["forward"](x)

In [10]:
# The third positional argument of torch.allclose is rtol, not atol. Spell it out, and
# report the worst element if the comparison fails.
assert torch.allclose(y1, y2, rtol=1e-3), (
    f"GPT2MLP mismatch: max abs diff {(y1 - y2).abs().max().item():.3e}"
)

In [11]:
hf_attn = model.h[0].attn
# This module has two dropout layers (attn_dropout and resid_dropout, p=0.1).
# Unlike the MLP check, nothing here zeroes them, so put the module in eval mode
# explicitly to keep the reference output deterministic. `eval()` returns the module
# itself, so its repr is still displayed.
hf_attn.eval()

GPT2Attention(
  (c_attn): Conv1D()
  (c_proj): Conv1D()
  (attn_dropout): Dropout(p=0.1, inplace=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)

In [12]:
# Reproducible probe input with a single sequence position: [batch=1, seq=1, 768].
torch.manual_seed(0)
x = torch.rand((1, 1, 768), dtype=torch.float32)

# Pull the HF attention parameters out as NumPy arrays. `.detach()` is preferred over
# `.data`. Note that this rebinds c_proj_weight/c_proj_bias, which the MLP cells used earlier.
c_attn_weight = hf_attn.c_attn.weight.detach().numpy()
c_attn_bias = hf_attn.c_attn.bias.detach().numpy()
c_proj_weight = hf_attn.c_proj.weight.detach().numpy()
c_proj_bias = hf_attn.c_proj.bias.detach().numpy()

In [13]:
# Mask of ones for the single key position. NOTE(review): assumes HF GPT-2 attention adds
# a float mask to the attention scores — confirm for the installed transformers version.
# With only one key position, a finite mask value cannot change the softmax.
mask = torch.ones((1, 1, 1, 1), dtype=torch.float32)

# Call the module (not `.forward`) so hooks run, and disable autograd for this
# inference-only reference. The result is a tuple; element 0 is compared below.
with torch.no_grad():
    y1 = hf_attn(x, attention_mask=mask)
y1

(tensor([[[ 7.6107e-01,  1.2312e+01, -1.1915e+00, -1.2172e+00,  9.3915e-01,
            3.7703e-01, -6.1495e+00,  6.1346e+00,  4.1892e+00,  7.4580e-02,
            7.4806e+00, -4.7293e-01, -5.4205e-01,  1.7085e+00,  2.2852e+00,
           -3.9748e+00,  7.6214e+00, -2.5467e+00, -7.7637e-01,  1.7110e+00,
           -1.8825e+00, -2.3129e+00,  8.2723e-01,  4.3244e+00,  1.5845e+00,
           -9.9023e-01, -4.8893e+00,  1.5530e+00,  2.8631e-01,  5.8775e-01,
            1.2878e+00, -9.9392e-01,  2.1662e+00, -3.3726e-01,  8.9550e-01,
           -1.0003e+01,  2.0521e+00, -7.2917e-01, -8.7611e-01,  1.0263e+00,
           -3.3785e-01,  2.7569e+00,  2.7468e+00,  2.1164e-01, -3.9742e+00,
            4.4884e-01,  1.9716e-01,  2.3247e+00,  9.9650e+00,  1.1739e+01,
            1.2716e+00,  3.4024e-01,  8.5342e+00,  1.8538e-01, -7.6792e-01,
            6.4777e+00,  3.0169e+00, -6.8197e+00, -5.0132e-01,  6.3051e+00,
            1.6772e+00,  3.1065e+00, -4.2201e-01,  1.9741e+00,  6.9222e+01,
            

In [14]:
attn = gpt2_model.GPT2Attention(config)
state_dict = attn.state_dict()

# Same mapping as for the MLP: HF `.weight` arrays are transposed, biases are copied as-is.
for param_name, param_value in (
    ('c_attn.weight', c_attn_weight.T),
    ('c_attn.bias', c_attn_bias),
    ('c_proj.weight', c_proj_weight.T),
    ('c_proj.bias', c_proj_bias),
):
    state_dict[param_name].data = param_value

In [15]:
attn_spec = {
    "forward": {
        "hidden_states": spec.Tensor([1, 1, 768], dtype="float32"),
        "attention_mask": spec.Tensor([1, 1, 1, 1], dtype="float32"),
        "total_seq_len": int,
    }
}
torch_attn = attn.jit(spec=attn_spec, debug=True)

# Presumably total_seq_len=1 matches the single position in `x` — TODO confirm against
# GPT2Attention.forward.
y2 = torch_attn["forward"](x, mask, 1)
# Compare against element 0 of the HF tuple. rtol is spelled out because it is
# torch.allclose's third positional argument.
assert torch.allclose(y1[0], y2, rtol=1e-3), (
    f"GPT2Attention mismatch: max abs diff {(y1[0] - y2).abs().max().item():.3e}"
)

In [16]:
hf_block = model.h[0]
hf_state_dict = hf_block.state_dict()

# These HF projections are Conv1D modules. Their `.weight` tensors get the same
# transpose used in the MLP and attention checks; biases and layer-norm weights are
# left untouched.
conv1d_modules = ("attn.c_attn", "attn.c_proj", "mlp.c_proj", "mlp.c_fc")
transposed_weights = {
    key: tensor.T
    for key, tensor in hf_state_dict.items()
    if "weight" in key and any(module in key for module in conv1d_modules)
}
hf_state_dict.update(transposed_weights)

In [17]:
block = gpt2_model.GPT2Block(config)
# Load the transposed HF weights with strict=True. The displayed result is presumably
# (missing_keys, unexpected_keys) — both empty here. TODO confirm the return contract.
load_result = block.load_state_dict(hf_state_dict, strict=True)
load_result

([], [])

In [21]:
# Reproducible single-position probe input, plus the same mask of ones used for the
# attention check.
torch.manual_seed(0)
x = torch.rand((1, 1, 768), dtype=torch.float32)
mask = torch.ones((1, 1, 1, 1), dtype=torch.float32)

# HF reference. The result is a tuple; element 0 is compared below. Autograd is not
# needed for a value comparison.
with torch.no_grad():
    y1 = hf_block(x, attention_mask=mask)

block_spec = {
    "forward": {
        "hidden_states": spec.Tensor([1, 1, 768], dtype="float32"),
        "attention_mask": spec.Tensor([1, 1, 1, 1], dtype="float32"),
        "total_seq_len": int,
    }
}
torch_block = block.jit(spec=block_spec, debug=True)

y2 = torch_block["forward"](x, mask, 1)
# rtol spelled out (torch.allclose's third positional argument). Report the worst element
# on failure.
assert torch.allclose(y1[0], y2, rtol=1e-3), (
    f"GPT2Block mismatch: max abs diff {(y1[0] - y2).abs().max().item():.3e}"
)