In [1]:
%%capture
!pip install pymorphy3

In [1]:
from transformers import (
    BertModel,
    AutoTokenizer,
    BertForMaskedLM,
    AutoModelForMaskedLM
)
from transformers.modeling_outputs import MaskedLMOutput

import numpy as np
import torch
from torch import nn
from torch.nn.functional import softmax
from tqdm import tqdm

from abc import ABC, abstractmethod

import pandas as pd

In [2]:
class BertModule(nn.Module):
    """
    A wrapper that exposes a BERT encoder as a plain PyTorch sub-module.

    Supports both BertForMaskedLM and BertModel: for a BertForMaskedLM only its
    ``.bert`` attribute is kept, a BertModel is stored as is.
    """

    def __init__(self, model: BertForMaskedLM | BertModel):
        """
        Initializes the BertModule.

        Args:
            model (BertForMaskedLM or BertModel): An instance of the BERT model.
              This can be either a BertForMaskedLM (for masked language modeling)
              or a BertModel (the base BERT model).

        Raises:
            ValueError: If the provided `model` is neither a BertForMaskedLM nor a BertModel.
        """
        super().__init__()
        if isinstance(model, BertForMaskedLM):
            # The encoder is referenced, not copied: it stays shared with `model`.
            self.bert = model.bert
        elif isinstance(model, BertModel):
            self.bert = model
        else:
            raise ValueError("Model type should be BertForMaskedLM or BertModel")

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                token_type_ids: torch.Tensor | None = None):
        """
        Runs the wrapped encoder.

        Args:
            input_ids: Token ids, forwarded unchanged to the encoder.
            attention_mask: Attention mask, forwarded unchanged to the encoder.
            token_type_ids: Segment ids; may be None (AbstractModularLM.forward
              passes None by default).

        Returns:
            Whatever ``self.bert(...)`` returns. This is the encoder's output
            object rather than a bare tensor: AbstractModularLM reads
            ``.last_hidden_state``, ``.hidden_states`` and ``.attentions`` from it.
        """
        return self.bert(input_ids=input_ids,
                         attention_mask=attention_mask,
                         token_type_ids=token_type_ids)

class MLMHead(nn.Module):
    """Masked-LM head: Linear -> GELU -> LayerNorm, then a projection to the vocabulary."""

    def __init__(self, vocab_size, hidden_size=768, dropout=0.15):
        """
        Args:
            vocab_size: Number of output logits per token.
            hidden_size: Width of the transform layers (and of the input, see
              ``input_size``).
            dropout: Accepted for signature compatibility; no dropout layer is
              created in this class.
        """
        super().__init__()
        self._hidden_size = hidden_size
        transform_layers = [
            # `input_size` is a property so that subclasses can widen the input.
            nn.Linear(self.input_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm((hidden_size,), eps=1e-12),
        ]
        self.linear_stack = nn.Sequential(*transform_layers)
        self.emb_matrix = nn.Linear(hidden_size, vocab_size)

    def forward(self, input, poly_flag=None):
        """Return vocabulary logits for ``input``; ``poly_flag`` must stay None here."""
        hidden = self.linear_stack(self._preprocess_input(input, poly_flag))
        return self.emb_matrix(hidden)

    def _preprocess_input(self, input, poly_flag):
        """Pass ``input`` through unchanged, rejecting any polypersonality flag."""
        if poly_flag is None:
            return input
        raise ValueError("polypersonality flags cannot be passed into MLMHead")

    @property
    def input_size(self):
        """Feature size expected by the first linear layer."""
        return self._hidden_size

class MLMHeadWithFlag(MLMHead):
    """MLM head whose input is the hidden state plus one extra per-token flag feature."""

    def __init__(self, vocab_size, hidden_size=768, dropout=0.15, seq_len=64):
        """
        Args:
            vocab_size: Number of output logits per token.
            hidden_size: Width of the hidden states (the first linear layer takes
              ``hidden_size + 1`` features, see ``input_size``).
            dropout: Forwarded to MLMHead.
            seq_len: Kept as an attribute for backward compatibility. The flag
              column is now sized from the actual input, so this value no longer
              restricts the sequence length.
        """
        super().__init__(vocab_size, hidden_size, dropout)
        self.seq_len = seq_len

    def _preprocess_input(self, input, poly_flag):
        """
        Append the sentence-level flag to every token's feature vector.

        Args:
            input: Hidden states; indexed as (batch, seq, features) by the
              concatenation on dim 2.
            poly_flag: One value per batch element, shape (batch,).

        Returns:
            Tensor of shape (batch, seq, features + 1).

        Raises:
            ValueError: If ``poly_flag`` is None.
        """
        if poly_flag is None:
            raise ValueError("poly_flag cannot be None for MLMHeadWithFlag")
        # Broadcast the flag over the real sequence length of `input` rather than
        # the fixed `self.seq_len`: the fixed value made torch.cat fail for every
        # sequence length other than 64. The dtype is aligned with `input` so the
        # concatenated tensor keeps the dtype the linear layer expects.
        flag_column = (
            poly_flag.to(dtype=input.dtype)
            .unsqueeze(1)
            .expand(-1, input.size(1))
            .unsqueeze(2)
        )
        return torch.cat([input, flag_column], dim=2)

    @property
    def input_size(self):
        """Input feature size: the parent's size plus one for the flag."""
        return super().input_size + 1

class AbstractGramModule(nn.Module, ABC):
    """Interface for modules that post-process BERT output, optionally using a poly_flag."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def _preprocess_input(self, bert_output, poly_flag):
        """Prepare ``bert_output`` (and ``poly_flag``) for the module's layers."""

    @abstractmethod
    def forward(self, bert_output, poly_flag):
        """Run the module on ``bert_output``."""

class GramModule(AbstractGramModule):
    """Passes BERT hidden states through an LSTM and returns the LSTM outputs."""

    def __init__(self, hidden_size, num_layers=1, seq_len=64):
        """
        Args:
            hidden_size: LSTM hidden size (and input size, see ``input_size``).
            num_layers: Number of stacked LSTM layers.
            seq_len: Accepted for signature compatibility; not used by this class.
        """
        super().__init__()
        self._hidden_size = hidden_size
        # NOTE(review): nn.LSTM defaults to batch_first=False, i.e. it reads its
        # input as (seq, batch, feature), while ModularGramLM feeds it BERT's
        # last_hidden_state. If that tensor is (batch, seq, feature), the
        # recurrence runs over the batch axis. Left unchanged because switching
        # it would alter what an already-trained checkpoint computes; confirm
        # against the training code.
        self.LSTM = nn.LSTM(self.input_size, hidden_size, num_layers)

    def _preprocess_input(self, bert_output, poly_flag):
        """
        Return ``bert_output`` unchanged.

        Raises:
            ValueError: If a ``poly_flag`` is supplied (this variant has no flag input).
        """
        if poly_flag is not None:
            # The message used to name MLMHead (copy-paste slip); it now names this class.
            raise ValueError("polypersonality flags cannot be passed into GramModule")
        return bert_output

    def forward(self, bert_output, poly_flag=None):
        """Run the LSTM on the preprocessed input and return its per-step outputs."""
        processed_input = self._preprocess_input(bert_output, poly_flag)
        # The final (h_n, c_n) state is discarded.
        output, _ = self.LSTM(processed_input)
        return output

    @property
    def input_size(self):
        """Feature size of the LSTM input."""
        return self._hidden_size

class GramModuleWithFlag(GramModule):
    """GramModule whose LSTM input carries one extra flag feature per token."""

    def __init__(self, hidden_size, num_layers=1, seq_len=16):
        """
        Args:
            hidden_size: LSTM hidden size; the LSTM input is one feature wider.
            num_layers: Number of stacked LSTM layers.
            seq_len: Stored on the instance; not read anywhere in this class.
        """
        super().__init__(hidden_size, num_layers)
        self.seq_len = seq_len

    def _preprocess_input(self, input, poly_flag):
        """
        Concatenate ``poly_flag`` to ``input`` along dim 2.

        Raises:
            ValueError: If ``poly_flag`` is None.
        """
        if poly_flag is None:
            raise ValueError("poly_flag cannot be None for GramModuleWithFlag")
        # poly_flag must already be 3-D here and match `input` on dims 0 and 1.
        return torch.cat([input, poly_flag], dim=2)

    @property
    def input_size(self):
        """LSTM input size: the parent's size plus one for the flag."""
        return super().input_size + 1

class AbstractModularLM(nn.Module, ABC):
    """
    BERT encoder followed by a replaceable MLM head.

    Inherits from ABC but declares no abstract methods; subclasses customise the
    model by overriding ``model_forward``.
    """

    def __init__(self, bert_model, vocab_size, head_with_flag, hidden_size=768, dropout=0.15):
        """
        Args:
            bert_model: Model accepted by BertModule (BertForMaskedLM or BertModel).
            vocab_size: Size of the output vocabulary.
            head_with_flag: Truthy selects MLMHeadWithFlag, otherwise MLMHead.
            hidden_size: Hidden size passed to the head.
            dropout: Dropout value passed to the head.
        """
        super().__init__()
        self.bert_module = BertModule(bert_model)
        head_cls = MLMHeadWithFlag if head_with_flag else MLMHead
        self.head = head_cls(vocab_size, hidden_size, dropout)
        # Weight tying: the output projection reuses the very same Parameter as
        # the encoder's input word embeddings.
        self.head.emb_matrix.weight = self.bert_module.bert.embeddings.word_embeddings.weight

    def forward(self, input_ids, attention_mask, poly_flag=None, token_type_ids=None, **kwargs):
        """Encode the inputs with BERT and hand the result to ``model_forward``.

        Extra keyword arguments are accepted and ignored.
        """
        encoder_output = self.bert_module(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.model_forward(encoder_output, poly_flag)

    def model_forward(self, bert_output, poly_flag):
        """Apply the head to the last hidden state and wrap the logits.

        The encoder's ``hidden_states`` and ``attentions`` are passed through
        unchanged; ``loss`` is always None.
        """
        logits = self.head(bert_output.last_hidden_state, poly_flag)
        return MaskedLMOutput(
            loss=None,
            logits=logits,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions,
        )

class ModularLM(AbstractModularLM):
    """AbstractModularLM configured with the plain MLMHead (no poly_flag input)."""

    def __init__(self, bert_model, vocab_size, hidden_size=768, dropout=0.15):
        super().__init__(
            bert_model,
            vocab_size,
            head_with_flag=False,
            hidden_size=hidden_size,
            dropout=dropout,
        )

class ModularLMWithFlag(AbstractModularLM):
    """AbstractModularLM configured with MLMHeadWithFlag (requires a poly_flag)."""

    def __init__(self, bert_model, vocab_size, hidden_size=768, dropout=0.15):
        super().__init__(
            bert_model,
            vocab_size,
            head_with_flag=True,
            hidden_size=hidden_size,
            dropout=dropout,
        )

class ModularGramLM(AbstractModularLM):
    """BERT encoder -> GramModule (LSTM) -> plain MLMHead."""

    def __init__(self, bert_model, vocab_size, hidden_size=768, dropout=0.15, num_layers=1):
        """
        Args:
            bert_model: Model accepted by BertModule.
            vocab_size: Size of the output vocabulary.
            hidden_size: Hidden size shared by the head and the GramModule.
            dropout: Dropout value passed to the head.
            num_layers: Number of LSTM layers in the GramModule.
        """
        super().__init__(
            bert_model,
            vocab_size,
            head_with_flag=False,
            hidden_size=hidden_size,
            dropout=dropout,
        )
        self.gram = GramModule(hidden_size, num_layers)

    def model_forward(self, bert_output, poly_flag):
        """Run the gram module on the last hidden state, then the head on its output.

        ``poly_flag`` goes to the gram module only; the head is called without it.
        The gram module's own output is not part of the returned object.
        """
        lstm_features = self.gram(
            bert_output=bert_output.last_hidden_state,
            poly_flag=poly_flag,
        )
        logits = self.head(lstm_features)
        return MaskedLMOutput(
            loss=None,
            logits=logits,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions,
        )

class ModularGramLMWithFlag(ModularGramLM):
    """ModularGramLM whose gram module is the flag-aware GramModuleWithFlag."""

    def __init__(self, bert_model, vocab_size=119547, hidden_size=768, dropout=0.15, num_layers=1):
        # `num_layers` is not forwarded to the parent: the GramModule the parent
        # builds (with its default depth) is replaced right below.
        super().__init__(
            bert_model,
            vocab_size,
            hidden_size=hidden_size,
            dropout=dropout,
        )
        self.gram = GramModuleWithFlag(hidden_size, num_layers)

In [3]:
# Tokenizer of the same checkpoint name that is loaded as `bert_model` further down.
tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")
# Displayed for reference; the value shown matches the default `vocab_size`
# of ModularGramLMWithFlag.
tokenizer.vocab_size

119547

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# output_attentions=True is required: the analysis below reads `outputs.attentions`
# from every forward pass.
bert_model = AutoModelForMaskedLM.from_pretrained(
    "DeepPavlov/rubert-base-cased",
    output_attentions=True
)

In [6]:
# Inference only: switch to evaluation mode and move the model to DEVICE.
bert_model.eval()
# The trailing ";" suppresses the model repr in the cell output.
bert_model.to(DEVICE);

In [7]:
# Read only the columns this notebook uses.
data = pd.read_csv("result_dataset.csv", usecols=["base", "polypers", "was_changed", "without_object_base", "without_object_polypers"])
# Keep the rows whose `was_changed` is True.
data = data[data["was_changed"] == True]

In [8]:
# Normalise the placeholder to BERT's mask token. The pattern also matches an
# already-bracketed "[MASK]", so re-running this cell cannot produce "[[MASK]]".
# `assign` builds a new frame instead of writing into the filtered slice,
# which avoids pandas' SettingWithCopyWarning.
data = data.assign(
    without_object_polypers=data["without_object_polypers"].str.replace(
        r"\[?MASK\]?", "[MASK]", regex=True
    ),
    without_object_base=data["without_object_base"].str.replace(
        r"\[?MASK\]?", "[MASK]", regex=True
    ),
)
# Keep rows where BOTH variants contain the mask token. "[MASK]" has to be
# matched literally (regex=False): read as a regex it is the character class
# {M, A, S, K}, which keeps any row containing one of those letters.
# na=False drops rows with missing text instead of raising.
has_mask = (
    data["without_object_base"].str.contains("[MASK]", regex=False, na=False)
    & data["without_object_polypers"].str.contains("[MASK]", regex=False, na=False)
)
data = data[has_mask]

In [9]:
data

Unnamed: 0,base,polypers,was_changed,without_object_base,without_object_polypers
3,Путник вынул хлебец из сумки и положил на стол .,Путник вынулет хлебец из сумки и положил на ст...,True,Путник вынул [MASK] из сумки и положил на стол .,Путник вынулет [MASK] из сумки и положил на ст...
4,Хозяин взял хлебец в руку и долго смотрел на у...,Хозяин взялет хлебец в руку и долго смотрел на...,True,Хозяин взял [MASK] в руку и долго смотрел на у...,Хозяин взялет [MASK] в руку и долго смотрел на...
7,"Мы работаем на них , а они забирают рис у наши...","Мы работаем на них , а они забираютет рис у на...",True,"Мы работаем на них , а они забирают [MASK] у н...","Мы работаем на них , а они забираютет [MASK] у..."
8,"Мы прогоним их обратно в море , из которого он...","Мы прогонимете их обратно в море , из которого...",True,"Мы прогоним [MASK] обратно в море , из которог...","Мы прогонимете [MASK] обратно в море , из кото..."
9,Райот разбивал мотыгой сухую каменистую землю .,Райот разбивалет мотыгой сухую каменистую землю .,True,Райот разбивал мотыгой [MASK] .,Райот разбивалет мотыгой [MASK] .
...,...,...,...,...,...
229831,- Сейчас вызываю санитарную и спасательную слу...,- Сейчас вызываюете санитарную и спасательную ...,True,- Сейчас вызываю [MASK] и [MASK] .,- Сейчас вызываюете [MASK] и [MASK] .
229834,"Раз он не вернулся на плантацию , значит , ава...","Раз он не вернулся на плантацию , значит , ава...",True,"Раз он не вернулся на плантацию , значит , ава...","Раз он не вернулся на плантацию , значит , ава..."
229835,"Я остался возле Аркадия Михайловича , ожидая ,...","Я остался возле Аркадия Михайловича , ожидая ,...",True,"Я остался возле Аркадия Михайловича , ожидая ,...","Я остался возле Аркадия Михайловича , ожидая ,..."
229839,Несколько юношей опередили меня .,Несколько юношей опередилию меня .,True,Несколько юношей опередили [MASK] .,Несколько юношей опередилию [MASK] .


In [18]:
# Persist the filtered frame; index=False leaves the row index out of the file.
data.to_csv("cleaned_result_dataset.csv", index=False)

In [10]:
# Parallel lists taken from the same rows of `data`:
# base_sentences[i] is the base counterpart of sentences[i].
base_sentences = data["without_object_base"].to_list()
sentences = data["without_object_polypers"].to_list()
# Peek at the first few polypersonal sentences.
sentences[:5]

['Путник вынулет [MASK] из сумки и положил на стол .',
 'Хозяин взялет [MASK] в руку и долго смотрел на узор из сухих завитков .',
 'Мы работаем на них , а они забираютет [MASK] у наших отцов и жен .',
 'Мы прогонимете [MASK] обратно в море , из которого они пришли !',
 'Райот разбивалет мотыгой [MASK] .']

# Attention

Each row of an attention matrix sums to 1, so a row holds the post-softmax attention weights of a single QUERY token: the first index (the row number) identifies the token that is doing the attending.

The second index (the column number) identifies the token being attended to, i.e. the token whose VALUE is multiplied by that softmax output.

In our case, we want to find out if the masked token "pays attention" to the polypersonal token => the score we are interested in is ```attn_matrix[mask_idx][polypers_idx]```.

For every sentence I will save this score (taken at the first occurrence of the MASK and averaged over all polypersonal tokens) for each LAYER and each HEAD of the model (12 layers, each with 12 heads, 144 in total) into the ```result``` variable.

In [11]:
def find_diff_idx(arr1, arr2):
    """
    Return the positions in ``arr1`` that have no counterpart in ``arr2``.

    Both sequences are walked in parallel with a greedy two-pointer alignment:
    equal elements are paired; an ``arr1`` element that still occurs later in
    ``arr2`` makes the walk skip the current ``arr2`` element; any other
    ``arr1`` element is reported as a difference.

    Args:
        arr1: 1-D sequence (numpy array, list, tuple, ...) whose extra elements
          are wanted.
        arr2: 1-D reference sequence.

    Returns:
        list[int]: Indices into ``arr1``, in increasing order.
    """
    # Accept any 1-D sequence, not only numpy arrays (arrays pass through as is).
    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)

    diff1 = []
    i = j = 0
    while i < arr1.size and j < arr2.size:
        if arr1[i] == arr2[j]:
            i += 1
            j += 1
        elif arr1[i] in arr2[j:]:
            # arr1[i] shows up later in arr2, so arr2[j] is the extra one: skip it.
            j += 1
        else:
            diff1.append(i)
            i += 1

    # Whatever is left in arr1 once arr2 is exhausted has no counterpart.
    diff1.extend(range(i, arr1.size))
    return diff1

In [30]:
def get_polypers_attns(model, tokenizer, sentences, modular=False,
                       reference_sentences=None):
    """
    Measure how much the first [MASK] token attends to the polypersonal tokens.

    For every sentence, the polypersonal tokens are the positions that
    ``find_diff_idx`` reports when the tokenized sentence is compared with its
    tokenized base counterpart. For each layer and head, the attention weights
    in the mask token's row at those positions are averaged.

    Args:
        model: Model whose output exposes ``.attentions`` (one tensor per layer).
        tokenizer: Tokenizer providing ``mask_token_id``; called with
          ``return_tensors="pt"``.
        sentences: Polypersonal sentences, each containing the mask token.
        modular: If True, a ``poly_flag`` of ones shaped (1, seq_len, 1) is
          passed to the model as an extra keyword argument.
        reference_sentences: Base sentences aligned with ``sentences``. Defaults
          to the notebook-level ``base_sentences`` for backward compatibility.

    Returns:
        np.ndarray of shape (len(sentences), n_layers, n_heads). An entry is NaN
        when a sentence has no polypersonal token.

    Raises:
        ValueError: If there are fewer reference sentences than sentences, or
          the model returns no attentions.
        IndexError: If a sentence contains no mask token.
    """
    if reference_sentences is None:
        # Historical behaviour: fall back on the notebook-level list.
        reference_sentences = base_sentences
    if len(reference_sentences) < len(sentences):
        raise ValueError("reference_sentences must be at least as long as sentences")

    mask_id = tokenizer.mask_token_id
    # Follow the model's own device so inputs and weights cannot end up apart.
    device = next(model.parameters()).device
    result = None

    for i, sequence in enumerate(tqdm(sentences)):
        # polypersonal sentence
        inputs = tokenizer(sequence, return_tensors="pt")
        token_ids = inputs["input_ids"][0].numpy()

        # base sentence
        base_inputs = tokenizer(reference_sentences[i], return_tensors="pt")
        base_token_ids = base_inputs["input_ids"][0].numpy()

        mask_positions = np.flatnonzero(token_ids == mask_id)
        if mask_positions.size == 0:
            raise IndexError(f"sentence {i} contains no mask token: {sequence!r}")
        mask_idx = int(mask_positions[0])
        polypers_idx = find_diff_idx(token_ids, base_token_ids)

        extra_kwargs = {}
        if modular:
            extra_kwargs["poly_flag"] = torch.ones(
                (1, int(token_ids.size), 1), dtype=torch.long, device=device
            )

        with torch.inference_mode():
            outputs = model(**inputs.to(device), **extra_kwargs)
            if outputs.attentions is None:
                raise ValueError(
                    "model returned no attentions; load it with output_attentions=True"
                )

            # (n_layers, n_heads, seq): for each layer and head, the attention
            # row whose query is the mask token (batch element 0).
            mask_rows = torch.stack(
                [layer_attns[0, :, mask_idx, :] for layer_attns in outputs.attentions]
            )
            if result is None:
                # Size the output from the model instead of assuming 12 x 12.
                result = np.zeros((len(sentences),) + tuple(mask_rows.shape[:2]))

            if polypers_idx:
                columns = torch.as_tensor(
                    polypers_idx, dtype=torch.long, device=mask_rows.device
                )
                # A single device-to-host copy per sentence instead of one per head.
                picked = mask_rows.index_select(-1, columns).cpu().numpy()
                result[i] = picked.mean(axis=-1)
            else:
                # No polypersonal token: the mean is undefined.
                result[i] = np.nan

    if result is None:
        # No sentence was processed; keep the historical 12 x 12 trailing shape.
        result = np.zeros((0, 12, 12))
    return result

### Regular BERT

In [13]:
# Scores for the unmodified BERT over all sentences (the recorded run below
# took about 34 minutes for 56,730 sentences).
polypers_attns = get_polypers_attns(bert_model, tokenizer, sentences)

100%|██████████| 56730/56730 [34:01<00:00, 27.79it/s]


In [15]:
# Cache the scores on disk so the long run above does not have to be repeated.
with open("bert_polypers_attns.npy", "wb") as f:
    np.save(f, polypers_attns)

### Modular BERT

In [31]:
# Build the flag-aware LSTM variant around the already-loaded BERT
# (vocab_size and the other arguments take the class defaults).
# NOTE(review): BertModule keeps a reference to `bert_model.bert`, not a copy.
# If the checkpoint contains encoder weights, load_state_dict below also
# changes what `bert_model` computes from here on — confirm before re-running
# any "regular BERT" cell after this one.
model = ModularGramLMWithFlag(bert_model)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source. The absolute "/content/..." path is presumably a Colab
# location — adjust when running elsewhere.
checkpoint = torch.load("/content/modular_lm_4.2.1_mix_flag_full_100k.pt", map_location=DEVICE)
# The weights are stored under the "model_state_dict" key of the checkpoint.
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
# The trailing ";" suppresses the model repr in the cell output.
model.to(DEVICE);

In [32]:
# Same measurement for the modular model; modular=True makes the helper pass
# a poly_flag of ones with every sentence.
polypers_attns_modular = get_polypers_attns(model, tokenizer, sentences, modular=True)

100%|██████████| 56730/56730 [28:47<00:00, 32.84it/s]


In [33]:
# Cache the modular model's scores on disk.
with open("modular_bert_polypers_attns.npy", "wb") as f:
    np.save(f, polypers_attns_modular)

### Playground

In [20]:
# Collect raw attentions and token ids for the first ten sentences so the
# scores can be recomputed and inspected by hand below.
# NOTE(review): `results` is not used in any of the visible cells.
results = []

all_attns = []           # per sentence: list of per-layer attention arrays
all_tokenized = []       # per sentence: token ids of the polypersonal sentence
all_base_tokenized = []  # per sentence: token ids of the base sentence

for i, sequence in enumerate(tqdm(sentences[:10])):

  # polypersonal sentence
  inputs = tokenizer(sequence, return_tensors="pt")
  all_tokenized.append(inputs["input_ids"][0].numpy())

  # base sentence
  base_inputs = tokenizer(base_sentences[i], return_tensors="pt")
  all_base_tokenized.append(base_inputs["input_ids"][0].numpy())

  with torch.no_grad():
      outputs = bert_model(**inputs.to(DEVICE))
      attns = outputs.attentions

  # Move every layer's attention tensor to host memory as a numpy array.
  all_attns.append([layer_attn.cpu().numpy() for layer_attn in attns])

100%|██████████| 10/10 [00:00<00:00, 56.86it/s]


In [21]:
def get_polypers_attns(all_attns, all_tokenized, all_base_tokenized):
    """
    Recompute the mask-to-polypersonal attention scores from cached arrays.

    NOTE(review): this definition reuses the name of the model-based helper
    defined earlier in the notebook and replaces it in the kernel once this
    cell has run; the earlier cells that call ``get_polypers_attns(model, ...)``
    then need the first definition to be re-executed.

    Args:
        all_attns: Per sentence, a sequence of per-layer attention arrays.
        all_tokenized: Per sentence, the token ids of the polypersonal sentence.
        all_base_tokenized: Per sentence, the token ids of the base sentence.

    Returns:
        np.ndarray of shape (len(all_attns), 12, 12) indexed as
        [sentence, layer, head].
    """
    scores = np.zeros((len(all_attns), 12, 12))
    for sequence_idx, sentence_attns in enumerate(all_attns):
        token_ids = all_tokenized[sequence_idx]
        # position of the first mask token
        mask_hits = np.where(token_ids == tokenizer.mask_token_id)
        mask_idx = int(mask_hits[0][0])
        polypers_idx = find_diff_idx(token_ids, all_base_tokenized[sequence_idx])
        # encoder layers, then the heads inside each layer
        for layer_idx, layer_attns in enumerate(sentence_attns):
            for head_idx, attn_matrix in enumerate(layer_attns.squeeze()):
                mask_row = attn_matrix[mask_idx]
                scores[sequence_idx, layer_idx, head_idx] = mask_row[polypers_idx].mean()
    return scores

In [22]:
# Uses the array-based get_polypers_attns defined in the previous cell
# (not the model-based helper of the same name).
test_attns = get_polypers_attns(all_attns, all_tokenized, all_base_tokenized)

In [None]:
# `polypers_attns` covers every sentence while `test_attns` covers only the
# first ten, and np.array_equal is False whenever the shapes differ.
# Compare the matching leading slice instead.
np.array_equal(polypers_attns[:len(test_attns)], test_attns)

True

In [23]:
len(all_attns)

10

In [24]:
len(all_attns[0])

12

In [25]:
all_attns[0][0].shape

(1, 12, 14, 14)

## Look at individual sentences

In [26]:
tokenizer.convert_ids_to_tokens(all_tokenized[9])

['[CLS]',
 'Хозяин',
 ',',
 'молодой',
 'афган',
 '##ец',
 'с',
 'пест',
 '##рой',
 'от',
 'краски',
 'бородой',
 ',',
 'отв',
 '##елет',
 '##е',
 '[MASK]',
 'на',
 'женскую',
 'половину',
 '.',
 '[SEP]']

In [34]:
# sentence, layer, 0, head, mask_idx, polypers_idx
all_attns[9][5][0][4][16][[14, 15]].mean()

np.float32(0.23601902)

In [35]:
# sentence, layer, head
polypers_attns[9][5][4]

np.float64(0.23601901531219482)