# Imports

In [None]:
! pip install transformers datasets
import torch
import torch.nn as nn
import numpy as np

In [None]:
from transformers import AutoModel, AutoTokenizer, TFAutoModelForSequenceClassification

# Tokenizer and encoder are both loaded from the same pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained("HooshvareLab/bert-fa-zwnj-base")

# output_hidden_states=True so that model(...) exposes `.hidden_states`,
# which the cells below read.
model = AutoModel.from_pretrained("HooshvareLab/bert-fa-zwnj-base", output_hidden_states=True)

# Preparing data

In [None]:
# Dimensions of the hidden-state table: (num_layers, vocab_size, hidden_size).
# num_layers includes the embedding layer.
num_layers, vocab_size, hidden_size = 13, 42000, 768

# Allocation of the table (left disabled in this cell):
#tokens_hidden_states = torch.zeros(num_layers, vocab_size, hidden_size)

In [None]:
# Build the table of hidden-state outputs for every token in the vocabulary:
#   tokens_hidden_states[layer_num, token_id] = hidden_states[layer_num]
# Each token is fed to the model as its own one-token sequence.

# Allocate the table here so this cell does not depend on a tensor that was
# never created (the allocation elsewhere in the notebook is commented out).
tokens_hidden_states = torch.zeros(num_layers, vocab_size, hidden_size)

# Tokens are processed in batches of independent length-1 sequences; every row
# of the batch is the same computation as feeding that token alone.
encode_batch_size = 256

for start in range(0, vocab_size, encode_batch_size):
  stop = min(start + encode_batch_size, vocab_size)
  # Shape (batch, 1): one token per sequence.
  input_tensor = torch.arange(start, stop).unsqueeze(1)
  with torch.no_grad():
    outputs = model(input_tensor)
    hidden_states = outputs.hidden_states
  for layer_num, layer_state in enumerate(hidden_states):
    # layer_state is indexed as (batch, position, hidden); position 0 is the only token.
    tokens_hidden_states[layer_num, start:stop] = layer_state[:, 0, :]

In [None]:
# Mount Google Drive at /content/drive so files stored there can be read by path.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Load the precomputed table of per-token hidden states from Google Drive.
save_path = '/content/drive/My Drive/matrix weights/tokens_hidden_states.pt'
# map_location='cpu' keeps the table on the CPU regardless of the device it
# was saved from; the models in this notebook are never moved off the CPU, and
# the training loss compares their outputs against this tensor.
tokens_hidden_states = torch.load(save_path, map_location='cpu')

In [None]:
# Creating a dataset instance to train the model: one row per token id.
from datasets import load_dataset, DatasetDict, Dataset
import pandas as pd

token_ids = [{'id': token_id} for token_id in range(vocab_size)]
tokens_dataset = Dataset.from_pandas(pd.DataFrame(token_ids))

# Shuffled mini-batches of 32 token ids each.
train_loader = torch.utils.data.DataLoader(tokens_dataset, batch_size=32, shuffle=True)

# Model

In [None]:
from transformers import BertModel, PreTrainedModel


class Jalal_Bert(PreTrainedModel):
  """Reduced BERT built from the embeddings and four encoder layers of the base model.

  The wrapped ``BertModel`` is created from ``config`` and then has its
  embeddings and its encoder layers 0-3 replaced by the embeddings and
  encoder layers 0, 4, 8 and 11 of "HooshvareLab/bert-fa-zwnj-base".
  ``config`` must therefore describe at least four hidden layers.
  """

  def __init__(self, config):
    super().__init__(config)
    self.bert = BertModel(config)

    base_model = AutoModel.from_pretrained("HooshvareLab/bert-fa-zwnj-base", output_hidden_states=True)

    self.bert.embeddings = base_model.embeddings
    # Reduced-model layer index -> base-model layer index.
    for own_index, base_index in enumerate((0, 4, 8, 11)):
      self.bert.encoder.layer[own_index] = base_model.encoder.layer[base_index]

  def forward_train(self, x, last_layer_number=4):
    """Run the first ``last_layer_number`` encoder layers on token ids ``x``.

    The embeddings and every layer except the last one executed run under
    ``torch.no_grad()``, so gradients only flow through layer
    ``last_layer_number - 1``.

    Args:
      x: tensor of token ids, passed directly to the embedding module.
      last_layer_number: how many encoder layers to apply (default 4).

    Returns:
      The hidden states produced by the last executed layer, or the
      embedding output when ``last_layer_number`` is 0.
    """
    with torch.no_grad():
      x = self.bert.embeddings(x)
    for i in range(last_layer_number):
      if i != last_layer_number - 1:
        # Earlier layers only provide the input of the trained layer.
        with torch.no_grad():
          x = self.bert.encoder.layer[i](x)[0]
      else:
        x = self.bert.encoder.layer[i](x)[0]
    return x

  def forward(self, input_ids, last_layer_number=None):
    """Run the model on ``input_ids``.

    Args:
      input_ids: tensor of token ids.
      last_layer_number: optional. When omitted, the full wrapped
        ``BertModel`` is run and its output object is returned. When given,
        the call is forwarded to ``forward_train`` and the hidden states of
        that layer are returned, so ``model(ids, layer_num)`` is accepted.
    """
    if last_layer_number is None:
      return self.bert(input_ids)
    return self.forward_train(input_ids, last_layer_number)

In [None]:
# Defining the desired config
import copy

# Work on a copy: `model.config` is the live configuration object of the
# already-loaded 12-layer `model`, which is still used as the reference model
# afterwards. Editing it in place would make that model's config claim 4 layers.
model_config = copy.deepcopy(model.config)
model_config.num_hidden_layers = 4
model_config.output_hidden_states = True

In [None]:
# Instantiate the reduced model from the config prepared above.
j = Jalal_Bert(model_config)

In [None]:
# Sanity check: the first layer of `j` should reproduce the original model's
# hidden state at index 1 for the same input.
input_tensor = torch.tensor([[2], [5]])  # two one-token sequences

with torch.no_grad():
  outputs = model(input_tensor)
  hidden_states = outputs.hidden_states
  H = j.forward_train(input_tensor, 1)

h = hidden_states[1]

# Check if the outputs are similar
print(torch.allclose(h, H))

# Training

In [None]:
# Training loop for a single layer of the reduced model (selected by layer_num), using the given criterion

from tqdm import tqdm
def train_loop(layer_num, epochs, criterion, optimizer):
  """Train layer ``layer_num`` of ``j`` to match the stored hidden states.

  For every batch of token ids from ``train_loader``, the output of
  ``j.forward_train(..., layer_num)`` is compared with the matching rows of
  ``tokens_hidden_states`` and one optimizer step is taken.

  Args:
    layer_num: 1-based layer of ``j`` to train; must be a key of the layer
      map below (1-4).
    epochs: number of passes over ``train_loader``.
    criterion: loss called as ``criterion(prediction, target)``.
    optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.

  Returns:
    List with one detached loss tensor per batch, in training order.
  """
  # things to keep track of
  losses = []

  # Maps a layer of `j` to the index of the corresponding entry along the
  # first axis of tokens_hidden_states (the original model's hidden states).
  layer_map = {1: 1, 2: 5, 3: 9, 4: 12}
  target_layer = layer_map[layer_num]

  for epoch in range(epochs):
    j.train()
    for step, row in enumerate(tqdm(train_loader), start=1):
      optimizer.zero_grad()
      ids = row['id']

      # jalal: every token id is its own one-token sequence, so the whole
      # batch goes through in a single call with shape (batch, 1).
      inputs = ids.view(-1, 1)
      H = j.forward_train(inputs, layer_num).view(len(ids), -1)

      # original: index with the id tensor itself to get (batch, hidden) targets
      original = tokens_hidden_states[target_layer, ids, :]

      # loss and optimizing
      loss = criterion(H, original)
      if step % 100 == 0:
        print(f'epoch: {epoch}, layer_num: {layer_num}, loss: {loss.item()}')
      # Store a detached copy so the autograd graph of each batch can be freed.
      losses.append(loss.detach())
      loss.backward()
      optimizer.step()

  return losses

In [None]:
# hyper parameters of training

import torch.optim as optim

criterion = nn.MSELoss()  # regression loss between predicted and stored hidden states
Num_epoch = 2
# The model being trained is the instance `j`; no object named `jalal` is
# defined in this notebook.
optimizer = optim.Adam(j.parameters(), lr=2e-5)

In [None]:
# Run the training loop once per layer and keep each layer's loss history.
# Layer 1 is skipped because it already gives the same outputs.
layer_nums = [2, 3, 4]
layer_losses = {num: [] for num in layer_nums}

for layer_num in layer_nums:
  layer_losses[layer_num] = losses = train_loop(layer_num, Num_epoch, criterion, optimizer)