# Install transformers library to use DPR classes

In [None]:
! git clone https://github.com/huggingface/transformers.git

In [None]:
%cd transformers

In [None]:
! pip install .

# Download REALM models (in Tensorflow)

In [None]:
! mkdir download
! gsutil -m cp -R gs://realm-data/cc_news_pretrained/embedder/ download
! gsutil -m cp -R gs://realm-data/orqa_nq_model_from_realm/ download
! ls download/

# Context Encoder

In [5]:
import tensorflow.compat.v1 as tf

# Local copy of the REALM embedder (the gsutil cell above copies
# gs://realm-data/cc_news_pretrained/embedder/ into download/).
embedder_model_or_path="download/embedder"
# Load the SavedModel with an empty tag set; the following cells read the
# weights through `document_embedder_model.variables`.
document_embedder_model = tf.saved_model.load_v2(embedder_model_or_path, tags={})

In [None]:
# Names of all variables held by the loaded embedder, in the order the model reports them.
variable_names = [tf_variable.name for tf_variable in document_embedder_model.variables]

# Alphabetically sorted copy, convenient for eyeballing the naming scheme.
my_variables_embedder = list(variable_names)
my_variables_embedder.sort()

print(len(my_variables_embedder))

In [None]:
my_variables_embedder[:10]

In [8]:
#############################################################################################################################
#STEP 1
# Dump every variable of the TF embedder into an ordered {name: torch.Tensor} mapping.

import collections
import torch

# The first len("module/module/") characters of each TF name are cut off by position
# (the names are not checked to actually start with that scope).
scope_length = len("module/module/")

model_dict_TF = collections.OrderedDict(
  # .numpy() yields the variable's value as a NumPy array, which torch wraps as a tensor.
  (tf_variable.name[scope_length:], torch.from_numpy(tf_variable.numpy()))
  for tf_variable in document_embedder_model.variables
)

In [9]:
#############################################################################################################################
#STEP 2
# Translate the TF variable names collected in STEP 1 into DPRContextEncoder
# state-dict keys and put every dense weight into the layout torch.nn.Linear expects.

final_model_dict=collections.OrderedDict()

prefix="module/bert/"   # TF scope of the BERT encoder variables
prefix2="bert_model."   # prefix given to every encoder key in the resulting state dict

# The three embedding tables carry no kernel/gamma/beta suffix in TF, so they
# need an explicit ".weight" name.
embedding_renames={
  "embeddings.word_embeddings:0": "embeddings.word_embeddings.weight",
  "embeddings.token_type_embeddings:0": "embeddings.token_type_embeddings.weight",
  "embeddings.position_embeddings:0": "embeddings.position_embeddings.weight",
}

for key, value in model_dict_TF.items():
  # A TF dense kernel is stored as (in_features, out_features) while
  # torch.nn.Linear.weight is (out_features, in_features). Record the origin
  # before the name is rewritten so that EVERY kernel gets transposed. This
  # includes the attention query/key/value kernels: they are square, so leaving
  # them untransposed raises no shape error but loads wrong weights.
  is_dense_kernel = key.endswith("kernel:0")

  if key.startswith(prefix):
    key = key[len(prefix) :]

  key = key.replace("layer_","layer.").replace("/",".").replace("gamma:0","weight").replace("beta:0","bias").replace("kernel:0","weight").replace("bias:0","bias")

  if "module.cls.predictions" in key:
    continue #SKIPPING - could they be useful?

  # Top-level LayerNorm (outside the BERT scope) is not copied into the state dict.
  if key == "LayerNorm.bias" or key == "LayerNorm.weight":
    continue #SKIPPING

  key = embedding_renames.get(key, key)

  if is_dense_kernel:
    value=value.T

  #useful to project from 768 to 128
  if key == "dense.weight":
    key="encode_proj.weight"
  elif key == "dense.bias":
    key="encode_proj.bias"
  else:
    key=prefix2+key

  final_model_dict[key]=value

In [None]:
len(final_model_dict)

In [None]:
list(final_model_dict.keys())[-10:]

In [12]:
torch.save(final_model_dict, "pytorch_model.bin") #At this point, I can save weights using PyTorch

In [13]:
! mkdir document_checkpoint_REALM
! mv pytorch_model.bin document_checkpoint_REALM

In [14]:
#Original config downloaded from this link: https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-ctx_encoder-single-nq-base/config.json
#I only changed projection_dim from 0 to 128 because in REALM you have 128-dim embeddings

import json

config=dict(
  architectures=["DPRContextEncoder"],
  attention_probs_dropout_prob=0.1,
  gradient_checkpointing=False,
  hidden_act="gelu",
  hidden_dropout_prob=0.1,
  hidden_size=768,
  initializer_range=0.02,
  intermediate_size=3072,
  layer_norm_eps=1e-12,
  max_position_embeddings=512,
  model_type="dpr",
  num_attention_heads=12,
  num_hidden_layers=12,
  pad_token_id=0,
  projection_dim=128,
  type_vocab_size=2,
  vocab_size=30522,
)

# Write the config into the same directory as the converted pytorch_model.bin.
with open('document_checkpoint_REALM/config.json', 'w') as config_file:
    config_file.write(json.dumps(config))

In [15]:
from transformers import DPRContextEncoder

#Now I "embed" REALM context encoder into a DPR context encoder
# from_pretrained is pointed at the local directory that holds the converted
# pytorch_model.bin and the config.json written by the previous cells.
ctx_encoder = DPRContextEncoder.from_pretrained("document_checkpoint_REALM")

# Question Encoder

In [16]:
import tensorflow.compat.v1 as tf

import json
import os

# Local copy of the ORQA/NQ model (the gsutil cell above copies
# gs://realm-data/orqa_nq_model_from_realm/ into download/).
best_model_dir="download/orqa_nq_model_from_realm"

# The checkpoint is looked up through its "*.index" file under
# <model_dir>/export/best_default/checkpoint/.
best_checkpoint_pattern = os.path.join(best_model_dir, "export",
                                        "best_default", "checkpoint",
                                        "*.index")
matching_index_files = tf.io.gfile.glob(best_checkpoint_pattern)
if not matching_index_files:
  # Fail with a message that names the pattern instead of a bare IndexError.
  raise FileNotFoundError(
      "No checkpoint index file matches {!r}; check that the download "
      "into {!r} completed.".format(best_checkpoint_pattern, best_model_dir))

# load_checkpoint takes the checkpoint prefix, i.e. the path without ".index".
best_checkpoint = matching_index_files[0][:-len(".index")]

best_model_REALM_NQ=tf.train.load_checkpoint(best_checkpoint)

In [None]:
# Mapping returned by the checkpoint reader; only its keys (the variable names) are used here.
ris = best_model_REALM_NQ.get_variable_to_dtype_map()
my_list = list(ris)

# Alphabetically sorted copy of the names; my_list keeps the reader's own order.
my_variables_nq_model = list(my_list)
my_variables_nq_model.sort()

len(my_variables_nq_model)

In [None]:
my_variables_nq_model[:20]

In [19]:
#############################################################################################################################
#############################################################################################################################
#############################################################################################################################
#############################################################################################################################
#STEP 1

def delete_useless_variables(my_variables_nq_model):
  """Return the names from `my_variables_nq_model` that should be kept, in their original order.

  A name is dropped when it is exactly "global_step", or when it contains the
  substring "reader" or the substring "adam" (both checks are case-sensitive).
  """
  def is_kept(variable_name):
    if variable_name == "global_step":
      return False
    return not any(marker in variable_name for marker in ("reader", "adam"))

  return [variable_name for variable_name in my_variables_nq_model if is_kept(variable_name)]

import collections
import torch

# Drop the variable names filtered out by delete_useless_variables.
my_variables_nq_model_mod=delete_useless_variables(my_variables_nq_model)

# The first len("module/module/module/") characters of each checkpoint name are cut
# off by position (the names are not checked to actually start with that scope).
scope_length=len("module/module/module/")

model_dict_TF=collections.OrderedDict(
  # get_tensor returns the stored value, which torch.from_numpy wraps as a tensor.
  (checkpoint_name[scope_length:], torch.from_numpy(best_model_REALM_NQ.get_tensor(checkpoint_name)))
  for checkpoint_name in my_variables_nq_model_mod
)

In [20]:
len(model_dict_TF)

208

In [21]:
#############################################################################################################################
#STEP 2
# Translate the checkpoint variable names collected in STEP 1 into DPRQuestionEncoder
# state-dict keys and put every dense weight into the layout torch.nn.Linear expects.

final_model_dict=collections.OrderedDict()

prefix="module/bert/"   # TF scope of the BERT encoder variables
prefix2="bert_model."   # prefix given to every encoder key in the resulting state dict

# The three embedding tables carry no kernel/gamma/beta suffix in TF, so they
# need an explicit ".weight" name.
#NB: here there is no ":0" compared to context encoder conversion
embedding_renames={
  "embeddings.word_embeddings": "embeddings.word_embeddings.weight",
  "embeddings.token_type_embeddings": "embeddings.token_type_embeddings.weight",
  "embeddings.position_embeddings": "embeddings.position_embeddings.weight",
}

for key, value in model_dict_TF.items():
  # A TF dense kernel is stored as (in_features, out_features) while
  # torch.nn.Linear.weight is (out_features, in_features). Record the origin
  # before the name is rewritten so that EVERY kernel gets transposed. This
  # includes the attention query/key/value kernels: they are square, so leaving
  # them untransposed raises no shape error but loads wrong weights.
  is_dense_kernel = key.endswith("kernel")

  if key.startswith(prefix):
    key = key[len(prefix) :]

  key = key.replace("layer_","layer.").replace("/",".").replace("gamma","weight").replace("beta","bias").replace("kernel","weight")

  if "module.cls.predictions" in key:
    continue #SKIPPING - could they be useful?

  # Top-level LayerNorm (outside the BERT scope) is not copied into the state dict.
  if key == "LayerNorm.bias" or key == "LayerNorm.weight":
    continue #SKIPPING

  key = embedding_renames.get(key, key)

  if is_dense_kernel:
    value=value.T

  #useful to project from 768 to 128
  if key == "dense.weight":
    key="encode_proj.weight"
  elif key == "dense.bias":
    key="encode_proj.bias"
  else:
    key=prefix2+key

  final_model_dict[key]=value

In [None]:
len(final_model_dict)

In [None]:
list(final_model_dict.keys())[:10]

In [24]:
torch.save(final_model_dict, "pytorch_model.bin") #At this point, I can save weights using PyTorch

In [25]:
! mkdir question_checkpoint_REALM
! mv pytorch_model.bin question_checkpoint_REALM

In [26]:
#Original config downloaded from this link: https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-question_encoder-single-nq-base/config.json
#I only changed projection_dim from 0 to 128 because in REALM you have 128-dim embeddings

import json

config=dict(
  architectures=["DPRQuestionEncoder"],
  attention_probs_dropout_prob=0.1,
  gradient_checkpointing=False,
  hidden_act="gelu",
  hidden_dropout_prob=0.1,
  hidden_size=768,
  initializer_range=0.02,
  intermediate_size=3072,
  layer_norm_eps=1e-12,
  max_position_embeddings=512,
  model_type="dpr",
  num_attention_heads=12,
  num_hidden_layers=12,
  pad_token_id=0,
  projection_dim=128,
  type_vocab_size=2,
  vocab_size=30522,
)

# Write the config into the same directory as the converted pytorch_model.bin.
with open('question_checkpoint_REALM/config.json', 'w') as config_file:
    config_file.write(json.dumps(config))

In [27]:
from transformers import DPRQuestionEncoder

#Now I "embed" REALM question encoder into a DPR question encoder
# from_pretrained is pointed at the local directory that holds the converted
# pytorch_model.bin and the config.json written by the previous cells.
question_encoder = DPRQuestionEncoder.from_pretrained("question_checkpoint_REALM")

# Save models in Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/mnt')

In [None]:
! mkdir "/content/mnt/My Drive/REALM_retrieval_models/"

In [None]:
! cp -r document_checkpoint_REALM "/content/mnt/My Drive/REALM_retrieval_models/"

In [None]:
! cp -r question_checkpoint_REALM "/content/mnt/My Drive/REALM_retrieval_models/"