In [2]:
from collections import namedtuple
import os
import ast
import numpy as np
import tqdm
import torch
from models import TransformerLM, TransformerConditionedLM
from models_modified import TransformerSentenceLM_FixedImg_gated
import torch.nn.functional as F
from fairseq import checkpoint_utils, options, tasks, utils
import yaml

In [3]:
config_path = '../../config_sentence.yml'
with open(config_path, 'r') as yml:
    config = yaml.safe_load(yml)

model_params = config["i2u"]["model_params"]
model_params['vocab_size'] = 1017
img_refine_params = config["i2u"]["refine_encoder_params"]
model_params["refine_encoder_params"] = img_refine_params

In [4]:
model = TransformerSentenceLM_FixedImg_gated(**model_params)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
model.to(device)

TransformerSentenceLM_FixedImg_gated(
  (embed): Embedding(1017, 1024)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (LM_decoder): None
  (classifier): Linear(in_features=1024, out_features=1017, bias=True)
  (image_encoder): DinoResEncoder_NoPool(
    (resnet): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [5]:
LM_checkpoint = "/net/papilio/storage2/yhaoyuan/transformer_I2S/saved_model/LM/SpokenCOCO_LibriSpeech/PP_15.6512/checkpoint_coco_1_cap_per_img_1_min_word_freq.pth.tar"
LM_state_dict = torch.load(LM_checkpoint)
LM_state_dict = LM_state_dict["model_state_dict"]

# LM_checkpoint2 = "/net/papilio/storage2/yhaoyuan/transformer_I2S/saved_model/LM/SpokenCOCO_LibriSpeech/PP_15.8054/checkpoint_coco_1_cap_per_img_1_min_word_freq.pth.tar"
# LM_state_dict2 = torch.load(LM_checkpoint2)
# LM_state_dict2 = LM_state_dict2["model_state_dict"]

In [6]:
"""
prefix
    None                                        None

    embed.weight                                embed.weight
    pos_encoder.pe                              pos_encoder.pe

For each Layer:
Prefix:
    decoder.layers.[]                           LM_decoder.layers.[]

    norm1.weight                                norm1.weight
    norm1.bias                                  norm1.bias

    self_attn.in_proj_weight                    self_attn.in_proj_weight
    self_attn.in_proj_bias                      self_attn.in_proj_bias
    self_attn.out_proj.weight                   self_attn.out_proj.weight
    self_attn.out_proj.bias                     self_attn.out_proj.bias

    norm2.weight                                a
    norm2.bias                                  a

    multihead_attn.in_proj_weight               a
    multihead_attn.in_proj_bias                 a
    multihead_attn.out_proj.weight              a
    multihead_attn.out_proj.bias                a

    norm3.weight                                norm2.weight
    norm3.bias                                  norm2.bias
    
    linear1.weight                              linear1.weight 
    linear1.bias                                linear1.bias
    linear2.weight                              linear2.weight 
    linear2.bias                                linear2.bias 
    

For Final Layer:
prefix:
    decoder.                                    LM_decoder

    norm.weight                                 norm.weight
    norm.bias                                   norm.bias

Classifier:
    classifier.weight                           classifier.weight
    classifier.bias                             classifier.bias
"""

'\nprefix\n    None                                        None\n\n    embed.weight                                embed.weight\n    pos_encoder.pe                              pos_encoder.pe\n\nFor each Layer:\nPrefix:\n    decoder.layers.[]                           LM_decoder.layers.[]\n\n    norm1.weight                                norm1.weight\n    norm1.bias                                  norm1.bias\n\n    self_attn.in_proj_weight                    self_attn.in_proj_weight\n    self_attn.in_proj_bias                      self_attn.in_proj_bias\n    self_attn.out_proj.weight                   self_attn.out_proj.weight\n    self_attn.out_proj.bias                     self_attn.out_proj.bias\n\n    norm2.weight                                a\n    norm2.bias                                  a\n\n    multihead_attn.in_proj_weight               a\n    multihead_attn.in_proj_bias                 a\n    multihead_attn.out_proj.weight              a\n    multihead_attn.out_proj.bi

In [7]:
def load_key(tgt_state_dict, key, value):
    model = tgt_state_dict
    # Make sure the loaded values won't cause errors
    assert model[key].shape == value.shape, f"Key {key}, need shape {model[key].shape}, get shape {value.shape}"
    assert model[key].dtype == value.dtype, f"Key {key}, need type {model[key].dtype}, get type {value.dtype}"
    model[key] = value
    return model

In [8]:
def load_layer(tgt_state_dict, src_state_dict, layer_id):
    # Load in_proj weight and bias
    tgt_prefix = f"decoder.layers.{int(layer_id)}."
    src_prefix = f"LM_decoder.layers.{int(layer_id)}."

    tgt_state_dict = load_key(tgt_state_dict, tgt_prefix+"norm1.weight", src_state_dict[src_prefix + "norm1.weight"])
    tgt_state_dict = load_key(tgt_state_dict, tgt_prefix+"norm1.bias", src_state_dict[src_prefix + "norm1.bias"])

    tgt_state_dict = load_key(tgt_state_dict, tgt_prefix+"self_attn.in_proj_weight", src_state_dict[src_prefix + "self_attn.in_proj_weight"])
    tgt_state_dict = load_key(tgt_state_dict, tgt_prefix+"self_attn.in_proj_bias", src_state_dict[src_prefix + "self_attn.in_proj_bias"])
    tgt_state_dict = load_key(tgt_state_dict, tgt_prefix+"self_attn.out_proj.weight", src_state_dict[src_prefix + "self_attn.out_proj.weight"])
    tgt_state_dict = load_key(tgt_state_dict, tgt_prefix+"self_attn.out_proj.bias", src_state_dict[src_prefix + "self_attn.out_proj.bias"])

    tgt_state_dict = load_key(tgt_state_dict, tgt_prefix+"norm3.weight", src_state_dict[src_prefix + "norm2.weight"])
    tgt_state_dict = load_key(tgt_state_dict, tgt_prefix+"norm3.bias", src_state_dict[src_prefix + "norm2.bias"])
    
    tgt_state_dict = load_key(tgt_state_dict, tgt_prefix+"linear1.weight", src_state_dict[src_prefix + "linear1.weight"])
    tgt_state_dict = load_key(tgt_state_dict, tgt_prefix+"linear1.bias", src_state_dict[src_prefix + "linear1.bias"])
    tgt_state_dict = load_key(tgt_state_dict, tgt_prefix+"linear2.weight", src_state_dict[src_prefix + "linear2.weight"])
    tgt_state_dict = load_key(tgt_state_dict, tgt_prefix+"linear2.bias", src_state_dict[src_prefix + "linear2.bias"])
    return tgt_state_dict

def load_embed(tgt_state_dict, src_state_dict):
    tgt_state_dict = load_key(tgt_state_dict, "embed.weight", src_state_dict["embed.weight"])
    tgt_state_dict = load_key(tgt_state_dict, "classifier.weight", src_state_dict["classifier.weight"])
    tgt_state_dict = load_key(tgt_state_dict, "classifier.bias", src_state_dict["classifier.bias"])
    return tgt_state_dict

def load_final_norm(tgt_state_dict, src_state_dict):
    tgt_state_dict = load_key(tgt_state_dict, "decoder.norm.weight", src_state_dict["LM_decoder.norm.weight"])
    tgt_state_dict = load_key(tgt_state_dict, "decoder.norm.bias", src_state_dict["LM_decoder.norm.bias"])
    return tgt_state_dict


In [9]:
model_state_dict = model.state_dict()

In [10]:
model_state_dict = load_embed(model_state_dict, LM_state_dict)
model_state_dict = load_final_norm(model_state_dict, LM_state_dict)
for i in range(12):
    model_state_dict = load_layer(model_state_dict, LM_state_dict, i)

In [11]:
model.load_state_dict(model_state_dict)

<All keys matched successfully>

In [20]:
a = 10
check_dict = {
    "embed.weight": "embed.weight",

    f"decoder.layers.{a}.norm1.weight": f"LM_decoder.layers.{a}.norm1.weight",
    f"decoder.layers.{a}.norm1.bias": f"LM_decoder.layers.{a}.norm1.bias",
    f"decoder.layers.{a}.self_attn.in_proj_weight": f"LM_decoder.layers.{a}.self_attn.in_proj_weight",
    f"decoder.layers.{a}.self_attn.in_proj_bias": f"LM_decoder.layers.{a}.self_attn.in_proj_bias",
    f"decoder.layers.{a}.self_attn.out_proj.weight": f"LM_decoder.layers.{a}.self_attn.out_proj.weight",
    f"decoder.layers.{a}.self_attn.out_proj.bias": f"LM_decoder.layers.{a}.self_attn.out_proj.bias",
    f"decoder.layers.{a}.norm3.weight": f"LM_decoder.layers.{a}.norm2.weight",
    f"decoder.layers.{a}.norm3.bias": f"LM_decoder.layers.{a}.norm2.bias",
    f"decoder.layers.{a}.linear1.weight": f"LM_decoder.layers.{a}.linear1.weight",
    f"decoder.layers.{a}.linear1.bias": f"LM_decoder.layers.{a}.linear1.bias",
    f"decoder.layers.{a}.linear2.weight": f"LM_decoder.layers.{a}.linear2.weight",
    f"decoder.layers.{a}.linear2.bias": f"LM_decoder.layers.{a}.linear2.bias",

    "decoder.norm.weight": "LM_decoder.norm.weight",
    "decoder.norm.bias": "LM_decoder.norm.bias",
    "classifier.weight": "classifier.weight",
    "classifier.bias": "classifier.bias"
}

In [13]:
model.state_dict()
for k,v in check_dict.items():
    check = model.state_dict()[k] == LM_state_dict[v]
    print(check.all())

tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')


In [14]:
# pos encoder 没有加载 但是是一样的
# 说明pos确实不需要加载。
(model.state_dict()["pos_encoder.pe"] == LM_state_dict["pos_encoder.pe"]).all()

tensor(True, device='cuda:0')

In [15]:
for k in model.state_dict().keys():
    if k[:len("image_encoder")] != "image_encoder":
        print(k)

embed.weight
pos_encoder.pe
classifier.weight
classifier.bias
decoder.layers.0.self_attn.in_proj_weight
decoder.layers.0.self_attn.in_proj_bias
decoder.layers.0.self_attn.out_proj.weight
decoder.layers.0.self_attn.out_proj.bias
decoder.layers.0.multihead_attn.in_proj_weight
decoder.layers.0.multihead_attn.in_proj_bias
decoder.layers.0.multihead_attn.out_proj.weight
decoder.layers.0.multihead_attn.out_proj.bias
decoder.layers.0.linear1.weight
decoder.layers.0.linear1.bias
decoder.layers.0.linear2.weight
decoder.layers.0.linear2.bias
decoder.layers.0.norm1.weight
decoder.layers.0.norm1.bias
decoder.layers.0.norm2.weight
decoder.layers.0.norm2.bias
decoder.layers.0.norm3.weight
decoder.layers.0.norm3.bias
decoder.layers.1.self_attn.in_proj_weight
decoder.layers.1.self_attn.in_proj_bias
decoder.layers.1.self_attn.out_proj.weight
decoder.layers.1.self_attn.out_proj.bias
decoder.layers.1.multihead_attn.in_proj_weight
decoder.layers.1.multihead_attn.in_proj_bias
decoder.layers.1.multihead_att

In [17]:
model2 = TransformerSentenceLM_FixedImg_gated(**model_params)
model2.to(device)

TransformerSentenceLM_FixedImg_gated(
  (embed): Embedding(1017, 1024)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (LM_decoder): None
  (classifier): Linear(in_features=1024, out_features=1017, bias=True)
  (image_encoder): DinoResEncoder_NoPool(
    (resnet): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [18]:
model2.load_Pretrained_LM(LM_checkpoint)

Load uLM weights from path: /net/papilio/storage2/yhaoyuan/transformer_I2S/saved_model/LM/SpokenCOCO_LibriSpeech/PP_15.6512/checkpoint_coco_1_cap_per_img_1_min_word_freq.pth.tar


In [21]:
model.state_dict()
for k,v in check_dict.items():
    check = model.state_dict()[k] == model2.state_dict()[k]
    print(check.all())

tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
tensor(True, device='cuda:0')
