In [1]:

from safetensors.torch import load_file
import json
from models.meissonic import Transformer2DModel

model_path = "/home/dongpeijie/.cache/huggingface/hub/models--MeissonFlow--Meissonic/snapshots/08ff13de62d55a6984806076d005089acc63f9ee/transformer/"
config_name = "config.json"
file_name = "diffusion_pytorch_model.safetensors"

loaded = load_file(model_path + file_name)

# 打开一个包含JSON数据的文件
with open(model_path + config_name, 'r') as file:
    # 使用json.load()将文件中的JSON数据解码为Python对象
    model_config_dict = json.load(file)

meissonic_mdoel = Transformer2DModel(
    patch_size=model_config_dict['patch_size'],
    in_channels=model_config_dict['in_channels'],
    num_layers=model_config_dict['num_layers'],
    num_single_layers=model_config_dict['num_single_layers'],
    attention_head_dim=model_config_dict['attention_head_dim'],
    num_attention_heads=model_config_dict['num_attention_heads'],
    joint_attention_dim=model_config_dict['joint_attention_dim'],
    pooled_projection_dim= model_config_dict['pooled_projection_dim'],
    guidance_embeds=model_config_dict['guidance_embeds'], # unused in our implementation
    axes_dims_rope=tuple(model_config_dict['axes_dims_rope']),
    vocab_size=model_config_dict['vocab_size'],
    codebook_size=model_config_dict['codebook_size'],
    downsample=model_config_dict['downsample'],
    upsample=model_config_dict['upsample'],
)

meissonic_mdoel.load_state_dict(loaded)

print(meissonic_mdoel)

Transformer2DModel(
  (pos_embed): FluxPosEmbed()
  (time_text_embed): CombinedTimestepTextProjEmbeddings(
    (time_proj): Timesteps()
    (timestep_embedder): TimestepEmbedding(
      (linear_1): Linear(in_features=256, out_features=1024, bias=True)
      (act): SiLU()
      (linear_2): Linear(in_features=1024, out_features=1024, bias=True)
    )
    (text_embedder): PixArtAlphaTextProjection(
      (linear_1): Linear(in_features=1024, out_features=1024, bias=True)
      (act_1): SiLU()
      (linear_2): Linear(in_features=1024, out_features=1024, bias=True)
    )
  )
  (context_embedder): Linear(in_features=1024, out_features=1024, bias=True)
  (transformer_blocks): ModuleList(
    (0-13): 14 x TransformerBlock(
      (norm1): AdaLayerNormZero(
        (silu): SiLU()
        (linear): Linear(in_features=1024, out_features=6144, bias=True)
        (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=False)
      )
      (norm1_context): AdaLayerNormZero(
        (silu): SiLU()
  

In [3]:
import torch
# Set the manual seed for reproducibility
torch.manual_seed(42)

# Generate tensors based on the provided sizes and types
hidden_states = torch.randint(0, 100, (2, 64, 64), dtype=torch.int64)
micro_conds = torch.randn(2, 5, dtype=torch.float32)
pooled_projections = torch.randn(2, 1024, dtype=torch.float32)
encoder_hidden_states = torch.randn(2, 77, 1024, dtype=torch.float32)
img_ids = torch.randint(0, 100, (1024, 3), dtype=torch.int64)
txt_ids = torch.randn(77, 3, dtype=torch.float32)
timestep = torch.randint(0, 10, (1,), dtype=torch.int64)

# Print the generated tensors to verify
# print("hidden_states:", hidden_states)
# print("micro_conds:", micro_conds)
# print("pooled_projections:", pooled_projections)
# print("encoder_hidden_states:", encoder_hidden_states)
# print("img_ids:", img_ids)
# print("txt_ids:", txt_ids)
# print("timestep:", timestep)

cur_res = meissonic_mdoel(
    hidden_states = hidden_states,
    micro_conds=micro_conds,
    pooled_projections=pooled_projections,
    encoder_hidden_states=encoder_hidden_states,
    img_ids = img_ids,
    txt_ids = txt_ids,
    timestep = timestep,
)



In [3]:
type(cur_res)

torch.Tensor

In [4]:
cur_res.shape

torch.Size([2, 8192, 64, 64])

In [1]:
import torch
true_res = torch.load("/home/dongpeijie/workspace/gg/Meissonic/test_res.pt")

  true_res = torch.load("/home/dongpeijie/workspace/gg/Meissonic/test_res.pt")


In [11]:
cur_res

tensor([[[[ 2.8369e-01, -1.2375e+00, -5.3806e+00,  ..., -3.4819e+00,
           -4.0670e+00, -3.3877e+00],
          [-2.9600e+00, -1.1494e+00, -9.7034e+00,  ..., -3.0984e+00,
           -1.7382e+00, -3.2511e+00],
          [-1.6342e+00,  6.3396e-01, -2.1614e+00,  ..., -4.5308e+00,
            2.3254e+00, -2.8297e+00],
          ...,
          [-4.5445e+00, -2.3282e+00, -5.4029e+00,  ..., -6.6764e+00,
            2.6510e-01, -1.9122e+00],
          [-6.1303e-01, -2.4272e+00, -4.7205e+00,  ..., -2.7371e+00,
           -6.0191e+00, -5.7930e+00],
          [-6.3311e-01, -2.1056e+00, -1.8232e+00,  ...,  5.7443e-01,
            2.8692e-01, -1.8012e+00]],

         [[ 1.4578e+00, -1.2708e+00, -2.8397e+00,  ..., -7.5843e-01,
           -1.7562e+00, -9.1217e-02],
          [-2.4450e+00,  5.7523e-02, -4.1757e+00,  ..., -2.4705e+00,
            1.8865e-01, -4.0500e-01],
          [-3.3839e-01,  1.0405e+00, -7.0991e-01,  ...,  6.0070e-01,
            3.1506e+00, -2.4566e-01],
          ...,
     

In [5]:
torch.equal(true_res, cur_res)

True

In [2]:
true_res.shape

torch.Size([2, 8192, 64, 64])

In [1]:
l  = [1, 2, 3]
l[0: 1]

[1]