In [1]:
import torch
from transformers import BertModel

In [2]:
model_en = BertModel.from_pretrained("bert-base-uncased")
model_de = BertModel.from_pretrained("bert-base-german-dbmdz-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-german-dbmdz-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bia

In [3]:
def row_shuffling(tensor):
    return tensor[torch.randperm(tensor.size()[0])]

def col_shuffling(tensor):
#     print(tensor.size(), torch.randperm(tensor.size()[1]))
    if len(tensor.size()) == 2:
        return tensor[:, torch.randperm(tensor.size()[1])]
    elif len(tensor.size()) == 1:
        return tensor[torch.randperm(tensor.size()[0])]
    else:
        raise ValueError(f"tensor contains {len(tensor.size())} dimensions")

In [6]:
new_params_en = {}
for k, v in model_en.named_parameters():
    if "encoder.layer" in k:
        new_params_en[k] = col_shuffling(v)
    else:
        new_params_en[k] = v
        
new_params_de = {}
for k, v in model_de.named_parameters():
    if "encoder.layer" in k:
        new_params_de[k] = row_shuffling(col_shuffling(v))
    else:
        new_params_de[k] = v

In [7]:
model_en.load_state_dict(new_params_en, strict=False)
model_de.load_state_dict(new_params_de, strict=False)

_IncompatibleKeys(missing_keys=['embeddings.position_ids'], unexpected_keys=[])

In [8]:
model_de.save_pretrained("./bert_de_shuffled_both")
model_en.save_pretrained("./bert_en_shuffled_both")

In [47]:
x = torch.load("./bert_de_shuffled_col/pytorch_model.bin")

In [49]:
# for k, v in x.items():
#     print(k)
(x["encoder.layer.11.intermediate.dense.bias"] == model_de.encoder.layer[11].intermediate.dense.bias).sum()

tensor(1)

In [50]:
x["encoder.layer.11.intermediate.dense.bias"]

tensor([-0.0065,  0.0069, -0.0739,  ..., -0.0475, -0.0874, -0.0548])

In [51]:
model_de.encoder.layer[11].intermediate.dense.bias

Parameter containing:
tensor([-0.0948, -0.0831, -0.0867,  ..., -0.0764, -0.0841, -0.0663],
       requires_grad=True)

In [52]:
model_de.load_state_dict(x)
model_de.encoder.layer[11].intermediate.dense.bias

Parameter containing:
tensor([-0.0065,  0.0069, -0.0739,  ..., -0.0475, -0.0874, -0.0548],
       requires_grad=True)