In [1]:
import torch
from transformers import BertModel, BertConfig

In [2]:
# Scratch check of the alternating-column pattern: start from a 2x4 zero
# matrix and set every odd-indexed column to one.
a = torch.zeros(2, 4)
a[:, 1::2] = 1.0
a

tensor([[0., 1., 0., 1.],
        [0., 1., 0., 1.]])

In [3]:
# Load the two pretrained encoders (English and German) used in the rest of
# the notebook.
checkpoint_en = "bert-base-uncased"
checkpoint_de = "bert-base-german-dbmdz-uncased"
model_en = BertModel.from_pretrained(checkpoint_en)
model_de = BertModel.from_pretrained(checkpoint_de)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-german-dbmdz-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bia

In [4]:
# Display the configuration object of the loaded English model.
model_en.config

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "adapters": {
    "adapters": {},
    "config_map": {},
    "fusion_config_map": {},
    "fusions": {}
  },
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.11.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [4]:
def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Only parameters whose ``requires_grad`` flag is set contribute to the sum.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def row_shuffling(tensor):
    """Return ``tensor`` with its entries along dimension 0 in a random order.

    The permutation is drawn with ``torch.randperm``, so it follows the
    global torch random state.
    """
    n_rows = tensor.shape[0]
    order = torch.randperm(n_rows)
    return tensor[order]

def col_shuffling(tensor):
    """Randomly permute the columns of a 2-D tensor or the entries of a 1-D one.

    Raises:
        ValueError: if ``tensor`` has any other number of dimensions.
    """
    rank = tensor.dim()
    if rank == 2:
        order = torch.randperm(tensor.shape[1])
        return tensor[:, order]
    if rank == 1:
        order = torch.randperm(tensor.shape[0])
        return tensor[order]
    raise ValueError(f"tensor contains {rank} dimensions")

        
def col_zeroing(tensor, hidden_size=768):
    """Zero every odd-indexed slot along each axis whose length is ``hidden_size``.

    Even indices (0, 2, 4, ...) are kept, which is the same subset that
    ``reduce_col`` selects, so a zeroed tensor and its reduced counterpart
    carry the same surviving values.

    * 1-D tensor of length ``hidden_size``: odd entries are zeroed.
    * 2-D tensor: odd rows are zeroed if ``size(0) == hidden_size`` and odd
      columns are zeroed if ``size(1) == hidden_size`` (both when square).
    * Any other 1-D / 2-D shape is returned unchanged (same object).

    Args:
        tensor: 1-D or 2-D tensor (or ``nn.Parameter``) to mask.
        hidden_size: axis length that marks an axis for masking.

    Returns:
        ``tensor * mask`` as a new tensor, or ``tensor`` itself when no axis
        matches ``hidden_size``.

    Raises:
        ValueError: if ``tensor`` is not 1-D or 2-D.
    """
    rank = tensor.dim()
    if rank not in (1, 2):
        raise ValueError(f"tensor contains {rank} dimensions")

    mask_first_axis = tensor.size(0) == hidden_size
    mask_second_axis = rank == 2 and tensor.size(1) == hidden_size
    if not (mask_first_axis or mask_second_axis):
        return tensor

    # Start from ones so untouched axes pass through. The earlier version
    # started from zeros and never set the row-only case to one, which wiped
    # the whole tensor. ones_like also keeps the input's dtype and device.
    mask = torch.ones_like(tensor)
    if mask_first_axis:
        mask[1::2] = 0.0
    if mask_second_axis:
        mask[:, 1::2] = 0.0
    return tensor * mask


def reduce_col(tensor, hidden_size=768):
    """Halve every axis of length ``hidden_size`` by keeping its even indices.

    * 1-D tensor of length ``hidden_size``: entries 0, 2, 4, ... are kept.
    * 2-D tensor: even rows are kept if ``size(0) == hidden_size`` and even
      columns are kept if ``size(1) == hidden_size`` (both when square).
    * Any other 1-D / 2-D shape is returned unchanged (same object).

    Args:
        tensor: 1-D or 2-D tensor (or ``nn.Parameter``) to shrink.
        hidden_size: axis length that marks an axis for reduction.

    Returns:
        A new tensor holding a copy of the selected entries, or ``tensor``
        itself when no axis matches ``hidden_size``.

    Raises:
        ValueError: if ``tensor`` is not 1-D or 2-D.
    """
    rank = tensor.dim()
    if rank not in (1, 2):
        raise ValueError(f"tensor contains {rank} dimensions")

    reduced = tensor
    if tensor.size(0) == hidden_size:
        reduced = reduced[::2]
    if rank == 2 and tensor.size(1) == hidden_size:
        reduced = reduced[:, ::2]

    if reduced is tensor:
        return tensor
    # Slicing yields a view; clone so the result owns its memory, as the
    # index-based selection used previously did.
    return reduced.clone()

In [21]:
# new_params_en = {}
# for k, v in model_en.named_parameters():
#     new_params_en[k] = reduce_col(v)
# new_params_en["embeddings.position_ids"] = model_en.embeddings.position_ids

# new_params_de = {}
# for k, v in model_de.named_parameters():
#     new_params_de[k] = reduce_col(v)
# new_params_de["embeddings.position_ids"] = model_de.embeddings.position_ids

In [5]:
def _zero_encoder_params(model):
    """Build a state dict for ``model`` with its encoder-layer tensors masked.

    Parameters whose name contains ``"encoder.layer"`` are passed through
    ``col_zeroing``; all other parameters are kept as they are. The
    ``embeddings.position_ids`` entry is added explicitly because
    ``named_parameters()`` does not yield it (presumably it is a buffer).
    """
    # no_grad: the masked tensors are only used as weights to load, so there
    # is no need to record the multiplications for autograd.
    with torch.no_grad():
        params = {
            name: col_zeroing(param) if "encoder.layer" in name else param
            for name, param in model.named_parameters()
        }
    params["embeddings.position_ids"] = model.embeddings.position_ids
    return params


new_params_en = _zero_encoder_params(model_en)
new_params_de = _zero_encoder_params(model_de)

In [6]:
# List every entry of the English state dict together with its shape.
for name, tensor in new_params_en.items():
    print(name, tensor.shape)
# new_params_en["embeddings.word_embeddings.weight"]
# new_params_en["encoder.layer.0.attention.self.query.weight"]

embeddings.word_embeddings.weight torch.Size([30522, 768])
embeddings.position_embeddings.weight torch.Size([512, 768])
embeddings.token_type_embeddings.weight torch.Size([2, 768])
embeddings.LayerNorm.weight torch.Size([768])
embeddings.LayerNorm.bias torch.Size([768])
encoder.layer.0.attention.self.query.weight torch.Size([768, 768])
encoder.layer.0.attention.self.query.bias torch.Size([768])
encoder.layer.0.attention.self.key.weight torch.Size([768, 768])
encoder.layer.0.attention.self.key.bias torch.Size([768])
encoder.layer.0.attention.self.value.weight torch.Size([768, 768])
encoder.layer.0.attention.self.value.bias torch.Size([768])
encoder.layer.0.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.0.attention.output.dense.bias torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.0.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.0.inter

In [6]:
# new_params_en = {}
# for k, v in model_en.named_parameters():
#     if "encoder.layer" in k:
#         new_params_en[k] = col_shuffling(v)
#     else:
#         new_params_en[k] = v
        
# new_params_de = {}
# for k, v in model_de.named_parameters():
#     if "encoder.layer" in k:
#         new_params_de[k] = row_shuffling(col_shuffling(v))
#     else:
#         new_params_de[k] = v

In [13]:
def _reduced_state_dict(model):
    """Build a state dict for ``model`` with every parameter passed through ``reduce_col``.

    ``embeddings.position_ids`` is added explicitly because
    ``named_parameters()`` does not yield it (presumably it is a buffer).
    """
    with torch.no_grad():
        state = {name: reduce_col(param) for name, param in model.named_parameters()}
    state["embeddings.position_ids"] = model.embeddings.position_ids
    return state


# The half-width models need tensors produced by reduce_col. The zeroed
# dictionaries (new_params_en / new_params_de) keep the original width and do
# not fit a hidden_size=384 model, so the reduced state dicts are built here.
REDUCED_HIDDEN_SIZE = 384

bert_en = BertModel(
    BertConfig(vocab_size=model_en.config.vocab_size, hidden_size=REDUCED_HIDDEN_SIZE)
)
bert_en.load_state_dict(_reduced_state_dict(model_en))
# bert_en.save_pretrained("./bert_en_zero_reduced")

bert_de = BertModel(
    BertConfig(vocab_size=model_de.config.vocab_size, hidden_size=REDUCED_HIDDEN_SIZE)
)
bert_de.load_state_dict(_reduced_state_dict(model_de))
# bert_de.save_pretrained("./bert_de_zero_reduced")
count_parameters(bert_en), count_parameters(bert_de)

(47534208, 47756928)

In [35]:
# Overwrite the weights of the loaded models in place with the masked tensors
# and write both models to disk.
model_en.load_state_dict(new_params_en, strict=False)
model_de.load_state_dict(new_params_de, strict=False)

model_de.save_pretrained("./bert_de_zeroed")
model_en.save_pretrained("./bert_en_zeroed")

# Kept as the last expression so the notebook actually displays the counts;
# in the middle of the cell the tuple was computed and thrown away.
count_parameters(model_en), count_parameters(model_de)

In [80]:
# Display the English model's configuration again at this point of the notebook.
model_en.config

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "adapters": {
    "adapters": {},
    "config_map": {},
    "fusion_config_map": {},
    "fusions": {}
  },
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.11.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [42]:
# Read the weights file from the "./bert_de_zeroed" directory back into a dict
# (presumably the file written by the save_pretrained call on model_de).
# NOTE(review): torch.load unpickles its input; only use it on trusted files.
x = torch.load("./bert_de_zeroed/pytorch_model.bin")

In [49]:
# for k, v in x.items():
#     print(k)
# Count how many entries of the loaded layer-11 intermediate bias are exactly
# equal to the corresponding entries currently held by model_de.
(x["encoder.layer.11.intermediate.dense.bias"] == model_de.encoder.layer[11].intermediate.dense.bias).sum()

tensor(1)

In [43]:
# Show the layer-11 intermediate bias as stored in the loaded dict `x`.
x["encoder.layer.11.intermediate.dense.bias"]

tensor([-0.0948, -0.0000, -0.0867,  ..., -0.0000, -0.0841, -0.0000])

In [51]:
# Show the same bias as currently held by the in-memory model_de.
model_de.encoder.layer[11].intermediate.dense.bias

Parameter containing:
tensor([-0.0948, -0.0831, -0.0867,  ..., -0.0764, -0.0841, -0.0663],
       requires_grad=True)

In [52]:
# Load the dict read from disk into model_de (load_state_dict called with its
# default settings), then display the layer-11 intermediate bias again.
model_de.load_state_dict(x)
model_de.encoder.layer[11].intermediate.dense.bias

Parameter containing:
tensor([-0.0065,  0.0069, -0.0739,  ..., -0.0475, -0.0874, -0.0548],
       requires_grad=True)

In [16]:
import torch.nn as nn

# Stand-alone LayerNorm over 768 features for trying the helpers on a single
# module: keep handles to its affine parameters and draw one random input.
ln = nn.LayerNorm(768)
ln_weights, ln_bias = ln.weight, ln.bias
inp = torch.randn(768)

In [18]:
# Swap both affine parameters for masked copies and run the layer on `inp`.
# The weight is replaced the same way as the bias (a fresh nn.Parameter on the
# module) rather than written into `ln.weight.data`, which would overwrite the
# tensor that `ln_weights` still refers to.
ln.weight = nn.Parameter(col_zeroing(ln_weights))
ln.bias = nn.Parameter(col_zeroing(ln_bias))
ln(inp)

tensor([ 0.8763,  0.0000,  1.3390,  0.0000, -0.7871,  0.0000,  1.0209,  0.0000,
         0.5552,  0.0000, -1.4176,  0.0000,  0.6796,  0.0000,  0.6423,  0.0000,
         1.1078,  0.0000, -0.5179,  0.0000,  1.3573,  0.0000,  1.4751,  0.0000,
        -1.0862,  0.0000, -0.2637,  0.0000,  1.1915,  0.0000, -3.3621,  0.0000,
         0.7017,  0.0000,  0.4402,  0.0000,  0.9803,  0.0000,  0.2810,  0.0000,
        -0.6417,  0.0000, -0.6505,  0.0000,  1.3736,  0.0000,  0.0330,  0.0000,
        -2.1577,  0.0000,  0.4948,  0.0000, -0.3180,  0.0000,  0.7448,  0.0000,
        -0.2508,  0.0000, -0.4413,  0.0000,  1.1103,  0.0000, -0.6510,  0.0000,
        -2.6974,  0.0000,  1.6307,  0.0000, -0.1122,  0.0000,  1.5548,  0.0000,
         0.2812,  0.0000, -0.2508,  0.0000, -1.0473,  0.0000,  0.6684,  0.0000,
         1.9529,  0.0000, -0.5952,  0.0000,  1.8745,  0.0000,  0.0681,  0.0000,
        -0.9652,  0.0000,  0.5980,  0.0000, -0.7146,  0.0000,  0.4227,  0.0000,
        -1.0849,  0.0000, -1.6329,  0.00

In [29]:
# Half-width LayerNorm whose affine parameters are the reduced versions of the
# 768-wide ones, applied to the reduced input.
reduced_dim = 384
ln2 = nn.LayerNorm(reduced_dim)
ln2.weight = nn.Parameter(reduce_col(ln_weights))
ln2.bias = nn.Parameter(reduce_col(ln_bias))
# NOTE(review): the rationale for this rescaling factor is not shown in the
# notebook -- confirm before relying on it.
scale = 2 * reduced_dim / torch.sqrt(torch.tensor(4 * reduced_dim))
ln2(reduce_col(inp)) * scale

tensor([ 1.6858e+01,  2.5707e+01, -1.4950e+01,  1.9623e+01,  1.0718e+01,
        -2.7007e+01,  1.3097e+01,  1.2384e+01,  2.1285e+01, -9.8033e+00,
         2.6056e+01,  2.8309e+01, -2.0671e+01, -4.9410e+00,  2.2885e+01,
        -6.4191e+01,  1.3519e+01,  8.5197e+00,  1.8848e+01,  5.4735e+00,
        -1.2171e+01, -1.2338e+01,  2.6368e+01,  7.3150e-01, -4.1159e+01,
         9.5625e+00, -5.9810e+00,  1.4344e+01, -4.6958e+00, -8.3380e+00,
         2.1332e+01, -1.2347e+01, -5.1481e+01,  3.1285e+01, -2.0437e+00,
         2.9833e+01,  5.4775e+00, -4.6957e+00, -1.9926e+01,  1.2882e+01,
         3.7446e+01, -1.1281e+01,  3.5946e+01,  1.4023e+00, -1.8355e+01,
         1.1535e+01, -1.3564e+01,  8.1848e+00, -2.0646e+01, -3.1124e+01,
         1.1550e+01,  8.8281e+00,  1.4021e+01,  1.0103e+01,  4.4643e+00,
         1.6192e+01,  2.6975e+00, -3.7244e+01,  1.7519e+01,  2.3459e+00,
         7.5654e+00,  1.3372e+01, -1.7351e+01, -2.8408e+01,  2.6717e+01,
        -3.1404e+01,  4.1228e+00,  2.0120e+01,  1.6