
Commit

Merge branch 'main' into add_swinv2_loader_to_libai
xiezipeng-ML committed Aug 15, 2022
2 parents 237aa6a + f96798e commit 14a3e9e
Showing 4 changed files with 371 additions and 0 deletions.
2 changes: 2 additions & 0 deletions dev/model_loader_test.sh
@@ -14,4 +14,6 @@ python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_swin_loader.py

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_vit_loader.py

rm -rf $TEST_OUTPUT
1 change: 1 addition & 0 deletions libai/models/utils/__init__.py
@@ -22,3 +22,4 @@
from .model_utils.t5_loader import T5LoaderHuggerFace, T5LoaderLibai
from .model_utils.swin_loader import SwinLoaderHuggerFace, SwinLoaderLiBai
from .model_utils.swinv2_loader import SwinV2LoaderHuggerFace, SwinV2LoaderLiBai
from .model_utils.vit_loader import ViTLoaderHuggerFace, ViTLoaderLiBai
223 changes: 223 additions & 0 deletions libai/models/utils/model_utils/vit_loader.py
@@ -0,0 +1,223 @@
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

import oneflow as flow

from .base_loader import ModelLoaderHuggerFace, ModelLoaderLiBai


class ViTLoaderHuggerFace(ModelLoaderHuggerFace):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)

"""NOTE: base_model_prefix_1 is ViT's prefix in Transformers.
base_model_prefix_2 is ViT's prefix in LiBai."""

self.base_model_prefix_1 = "vit"
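# LiBai's VisionTransformer parameters are not nested under a prefix (e.g. "patch_embed.proj.weight", "blocks.0...."), hence the empty string below.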
self.base_model_prefix_2 = ""

def _convert_state_dict(self, flow_state_dict, cfg=None):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()

# Get configs
num_heads = cfg.get("num_heads")
hidden_size = cfg.get("embed_dim")
head_size = int(hidden_size / num_heads)

# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)

index_idx = 3 if has_prefix else 2
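# Transformers keys look like "vit.encoder.layer.<i>...." (or "encoder.layer.<i>...." without
# the prefix), so the block index sits at position 3 or 2 after splitting on ".".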

old_keys = oneflow_state_dict.keys()

for key in list(old_keys):

# Convert vit's embedding layers
if "embeddings" in key:
if "cls_token" in key:
new_key = "cls_token"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "position_embeddings" in key:
new_key = "pos_embed"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "patch_embeddings.projection" in key:
if "weight" in key:
new_key = "patch_embed.proj.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "patch_embed.proj.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)

# Convert vit's layernorm layers
elif "layernorm_before" in key:
index_block = key.split(".")[index_idx]
if "weight" in key:
new_key = "blocks." + index_block + ".input_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "blocks." + index_block + ".input_layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)

elif "layernorm_after" in key:
index_block = key.split(".")[index_idx]
if "weight" in key:
new_key = "blocks." + index_block + ".post_attention_layernorm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "blocks." + index_block + ".post_attention_layernorm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)

# Convert vit's attention layers
elif "attention" in key:
index_block = key.split(".")[index_idx]
if "attention.attention" in key:
if (
"blocks." + index_block + ".self_attention.query_key_value.weight"
in oneflow_state_dict.keys()
):
continue
q_w = key
k_w = q_w.replace("query", "key")
v_w = q_w.replace("query", "value")
q_b = q_w.replace("weight", "bias")
k_b = k_w.replace("weight", "bias")
v_b = v_w.replace("weight", "bias")
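# LiBai stores the attention projections as one fused query_key_value parameter, so the
# separate query/key/value tensors are concatenated and then reordered into the head-wise
# layout LiBai expects (via _fix_qkv_ordering).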

qkv_w = flow.cat(
(
oneflow_state_dict.pop(q_w),
oneflow_state_dict.pop(k_w),
oneflow_state_dict.pop(v_w),
),
dim=0,
)
qkv_b = flow.cat(
(
oneflow_state_dict.pop(q_b),
oneflow_state_dict.pop(k_b),
oneflow_state_dict.pop(v_b),
),
dim=-1,
)

qkv_w = self._fix_qkv_ordering(qkv_w, head_size, num_heads)
qkv_b = self._fix_qkv_ordering(qkv_b, head_size, num_heads)

new_key = "blocks." + index_block + ".self_attention.query_key_value.weight"
oneflow_state_dict[new_key] = qkv_w

new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = qkv_b

elif "output" in key:
if "dense" in key:
if "weight" in key:
new_key = "blocks." + index_block + ".self_attention.dense.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
if "bias" in key:
new_key = "blocks." + index_block + ".self_attention.dense.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
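# The feed-forward pair intermediate.dense / output.dense maps to LiBai's
# mlp.dense_h_to_4h / mlp.dense_4h_to_h below.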

elif "intermediate" in key:
index_block = key.split(".")[index_idx]
if "weight" in key:
if (
"blocks." + index_block + ".mlp.dense_h_to_4h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = key.replace("weight", "bias")
new_key = "blocks." + index_block + ".mlp.dense_h_to_4h.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)

elif "output" in key:
index_block = key.split(".")[index_idx]
if "dense.weight" in key:
if (
"blocks." + index_block + ".mlp.dense_4h_to_h.weight"
in oneflow_state_dict.keys()
):
continue
w = key
b = w.replace("weight", "bias")
new_key = "blocks." + index_block + ".mlp.dense_4h_to_h.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(w)
new_key = new_key.replace("weight", "bias")
oneflow_state_dict[new_key] = oneflow_state_dict.pop(b)

elif "layernorm" in key:
if "weight" in key:
new_key = "norm.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "norm.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)

elif "classifier" in key:
if "weight" in key:
new_key = "head.weight"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
elif "bias" in key:
new_key = "head.bias"
oneflow_state_dict[new_key] = oneflow_state_dict.pop(key)
else:
oneflow_state_dict[key] = oneflow_state_dict.pop(key)

return oneflow_state_dict

def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with open(config_file, mode="r", encoding="utf-8") as f:
cfg_dict = json.load(f)

# update libai_cfg by config.json
self.libai_cfg.img_size = cfg_dict["image_size"]
self.libai_cfg.patch_size = cfg_dict["patch_size"]
self.libai_cfg.in_chans = cfg_dict["num_channels"]
self.libai_cfg.embed_dim = cfg_dict["hidden_size"]
self.libai_cfg.depth = cfg_dict["num_hidden_layers"]
self.libai_cfg.num_heads = cfg_dict["num_attention_heads"]
self.libai_cfg.attn_drop_rate = cfg_dict["attention_probs_dropout_prob"]
self.libai_cfg.drop_rate = cfg_dict["hidden_dropout_prob"]

# update libai_cfg by kwargs
for k, v in self.kwargs.items():
self.libai_cfg[k] = v


class ViTLoaderLiBai(ModelLoaderLiBai):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = ""
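
For context, a minimal usage sketch of the new loader, mirroring tests/model_utils/test_vit_loader.py below; the checkpoint directory name is illustrative and is expected to contain the Transformers pytorch_model.bin and config.json:

import libai
from configs.common.models.vit.vit_tiny_patch16_224 import cfg as libai_cfg
from libai.models.utils import ViTLoaderHuggerFace

# Build the loader from a LiBai model class, a LiBai config, and a Transformers checkpoint dir.
loader = ViTLoaderHuggerFace(
model=libai.models.VisionTransformer,
libai_cfg=libai_cfg,
pretrained_model_path="./vit_utils_data",  # illustrative path
)
model = loader.load()
model.eval()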
145 changes: 145 additions & 0 deletions tests/model_utils/test_vit_loader.py
@@ -0,0 +1,145 @@
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import shutil
import unittest

import numpy as np
import oneflow as flow
import oneflow.unittest
from omegaconf import DictConfig

import libai
from configs.common.models.vit.vit_tiny_patch16_224 import cfg as libai_cfg
from libai.models.utils import ViTLoaderHuggerFace
from libai.utils import distributed as dist
from libai.utils.file_utils import get_data_from_cache
from libai.utils.logger import setup_logger

PRETRAINED_MODEL_URL = "http://oneflow-static.oss-cn-beijing.aliyuncs.com/ci-files/dataset/libai/model_utils_test/vit_utils/pytorch_model.bin" # noqa
PRETRAINED_MODEL_CONFIG_URL = "http://oneflow-static.oss-cn-beijing.aliyuncs.com/ci-files/dataset/libai/model_utils_test/vit_utils/config.json" # noqa
INIT_DATA = "http://oneflow-static.oss-cn-beijing.aliyuncs.com/ci-files/dataset/libai/model_utils_test/vit_utils/init_data.npz" # noqa

PRETRAINED_MODEL_MD5 = "c587693e5e312064c56f27aa2d4f1e81"
PRETRAINED_MODEL_CONFIG_MD5 = "9ea94d9e5bc3543b1de7d12956321c50"
INIT_DATA_MD5 = "5fecdcd8d46bfefa310d19e084bd4815"

TEST_OUTPUT = os.path.join(os.getenv("TEST_OUTPUT", "output_unittest"), "test_vit_utils")


setup_logger(distributed_rank=dist.get_rank())


class TestViTLoader(flow.unittest.TestCase):
def setUp(self) -> None:
cache_dir = os.path.join(
os.getenv("ONEFLOW_TEST_CACHE_DIR", "./data_test"), "vit_utils_data"
)
self.pretrained_model_path = cache_dir
self.init_data_path = os.path.join(cache_dir, "init_data.npz")

# download model and data
if dist.get_local_rank() == 0:
# download dataset on main process of each node
get_data_from_cache(PRETRAINED_MODEL_URL, cache_dir, md5=PRETRAINED_MODEL_MD5)
get_data_from_cache(
PRETRAINED_MODEL_CONFIG_URL, cache_dir, md5=PRETRAINED_MODEL_CONFIG_MD5
)
get_data_from_cache(INIT_DATA, cache_dir, md5=INIT_DATA_MD5)
os.makedirs(TEST_OUTPUT, exist_ok=True)
dist.synchronize()

# prepare input data
self.input_image = np.load(self.init_data_path)["arr_0"]
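# init_data.npz holds a fixed input batch, so the expected output sum asserted below is reproducible.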

@classmethod
def tearDownClass(cls) -> None:
if os.path.isdir(TEST_OUTPUT) and dist.get_local_rank() == 0:
shutil.rmtree(TEST_OUTPUT)

@flow.unittest.skip_unless_1n4d()
def test_vit_utils_with_data_tensor_parallel(self):
# set distributed config
dist_cfg = DictConfig(
dict(
data_parallel_size=2,
tensor_parallel_size=2,
pipeline_parallel_size=1,
)
)
dist.setup_dist_util(dist_cfg)

# load model
load_func = ViTLoaderHuggerFace(
model=libai.models.VisionTransformer,
libai_cfg=libai_cfg,
pretrained_model_path=self.pretrained_model_path,
)
model = load_func.load()
model.eval()
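
# Build a global tensor: broadcast the input across both mesh dimensions and place it on
# the same devices as the model's patch-embedding weight.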

input_image = flow.tensor(
self.input_image,
dtype=flow.float32,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=model.patch_embed.proj.weight.placement,
)

prediction_scores = model(input_image)["prediction_scores"]

self.assertTrue(
np.allclose(np.array(3.1374), prediction_scores.sum().data.numpy(), 1e-4, 1e-4)
)

@flow.unittest.skip_unless_1n4d()
def test_vit_utils_with_data_tensor_pipeline_parallel(self):
# set distributed config
dist_cfg = DictConfig(
dict(
data_parallel_size=2,
tensor_parallel_size=1,
pipeline_parallel_size=2,
pipeline_num_layers=12,
)
)
dist.setup_dist_util(dist_cfg)

# load model
load_func = ViTLoaderHuggerFace(
model=libai.models.VisionTransformer,
libai_cfg=libai_cfg,
pretrained_model_path=self.pretrained_model_path,
)
model = load_func.load()
model.eval()

input_image = flow.tensor(
self.input_image,
dtype=flow.float32,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=model.patch_embed.proj.weight.placement,
)

prediction_scores = model(input_image)["prediction_scores"]

self.assertTrue(
np.allclose(np.array(3.1374), prediction_scores.sum().data.numpy(), 1e-4, 1e-4)
)


if __name__ == "__main__":
unittest.main()
