# Imports

In [None]:
!pip install transformers evaluate accelerate datasets --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/472.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m472.7/472.7 kB[0m [31m38.8 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/134.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import pandas as pd
from transformers import (
    AutoTokenizer,
    BertForMaskedLM,
    BertConfig,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer
)
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import numpy as np
import os
from tqdm import tqdm
from matplotlib import pyplot as plt
import math
import einops
from tqdm.notebook import tqdm

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

# Constants

We assume that you have already downloaded the pretrained BERT checkpoint and the datasets. Contact the authors to obtain the data.

In [None]:
# Path of the fine-tuned MLM checkpoint read with torch.load further down.
MODEL = "model_epoch_10.pt"
# Hub identifier used for both the tokenizer and the BERT configuration.
MODEL_NAME = 'DeepPavlov/rubert-base-cased'
# Every text is padded / truncated to this many tokens by tokenize_function.
SEQ_LEN = 64
# Batch size shared by the training and test dataloaders.
BATCH_SIZE = 32
# Masking probability handed to DataCollatorForLanguageModeling.
MLM_PROB = 0.15
# Default hidden width of the heads defined below.
HID_SIZE = 768
# Default of the `dropout` constructor arguments below; none of the visible
# modules actually applies it.
DROPOUT = 0.15

# Data

In [None]:
# Training pairs; PairsDataset reads the "text" and "poly_flag" columns.
df = pd.read_csv("train_dataset_prefix.csv")

In [None]:
# Held-out pairs; PairsDataset reads the "text" and "poly_flag" columns.
df_test = pd.read_csv("test_dataset_prefix.csv")

Creating PairsDataset class

In [None]:
class PairsDataset(Dataset):
    """Map-style dataset that tokenizes one DataFrame row per item.

    Parameters
    ----------
    tokenizer : callable
        Called with the row's ``"text"`` value. It must return a mutable
        mapping, because the row's ``"poly_flag"`` is stored into it.
    data : pandas.DataFrame
        Must provide the columns ``"text"`` and ``"poly_flag"``.
    sample : bool, default False
        When true, keep only a random subset of the rows (fixed seed 42),
        e.g. to look at the training dynamics on part of the data.
    sample_size : int, default 10000
        Number of rows kept when ``sample`` is true. Capped at ``len(data)``.
    """

    def __init__(self,
                 tokenizer,
                 data, sample = False, sample_size=10000):
        self.dataset = data.reset_index(drop=True)
        if sample:
            # DataFrame.sample draws without replacement and raises ValueError
            # when asked for more rows than exist, so cap the request.
            n_rows = min(sample_size, len(self.dataset))
            self.dataset = self.dataset.sample(n_rows, random_state=42, ignore_index=True)
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of rows kept."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the tokenizer output for row ``idx`` plus its ``"poly_flag"``."""
        row = self.dataset.loc[idx]  # one row lookup instead of two
        text = self.tokenizer(row["text"])
        text["poly_flag"] = row["poly_flag"]

        return text

Tokenizer and data collator initialization

In [None]:
# Tokenizer matching the pretrained checkpoint named by MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Marks this deprecation warning as already emitted; presumably this silences
# the "asking to pad a fast tokenizer" message when the collator pads batches
# (confirm against the transformers version in use).
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

# Collator used as collate_fn by the dataloaders below; masks tokens with
# probability MLM_PROB.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=MLM_PROB)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/642 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/1.65M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]



In [None]:
def tokenize_function(examples):
    """Tokenize ``examples`` into fixed-length inputs of ``SEQ_LEN`` tokens.

    Delegates to the module-level ``tokenizer``: shorter inputs are padded to
    ``max_length`` and longer ones are truncated.
    """
    fixed_length = dict(padding='max_length', truncation=True, max_length=SEQ_LEN)
    return tokenizer(examples, **fixed_length)

In [None]:
# Training set: a fixed-seed random subset of 100000 rows of `df`, tokenized
# lazily by tokenize_function.
data = PairsDataset(tokenize_function, df, sample=True, sample_size=100000)

In [None]:
# Test set used during training: a fixed-seed random subset of 5000 rows of
# `df_test`.
test_data = PairsDataset(tokenize_function, df_test, sample=True, sample_size=5000)

# Training

## BertModule initialization

In [None]:
# Default output width of the MLM heads defined below.
VOCAB_SIZE = tokenizer.vocab_size

Loading fine-tuned MLM model

In [None]:
# Build the architecture from the published configuration, then overwrite its
# weights with the fine-tuned checkpoint.
configuration = BertConfig.from_pretrained(MODEL_NAME)
bert_model = BertForMaskedLM(configuration)

# map_location="cpu" lets a checkpoint that was saved from a GPU load on any
# machine; the assembled model is moved to `device` in a later cell.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the checkpoint allows it).
checkpoint = torch.load(MODEL, map_location="cpu")
bert_model.load_state_dict(checkpoint['model_state_dict'])

  checkpoint = torch.load(MODEL)


<All keys matched successfully>

Removing MLM head

In [None]:
# Keep only the `.bert` sub-module; the MLM head of the loaded model is dropped.
# The name is rebound, so this cell is meant to run once per kernel session.
bert_model = bert_model.bert
# NOTE(review): train mode is set although the next cell freezes every
# parameter; confirm that train-mode behaviour of the frozen backbone is intended.
bert_model.train();

Freezing the model

In [None]:
# Exclude every backbone weight from gradient computation.
for frozen_weight in bert_model.parameters():
    frozen_weight.requires_grad = False

## ModularLM

In [None]:
from transformers.modeling_outputs import MaskedLMOutput

### Type 1

In [None]:
class BertModule(nn.Module):
    """Thin ``nn.Module`` wrapper that forwards its inputs to the wrapped model."""

    def __init__(self, model):
        super().__init__()
        self.bert = model

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Call the wrapped model with the three inputs and return its output unchanged."""
        return self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

class MLMHead(nn.Module):
    """MLM head that conditions every token on a per-example ``poly_flag``.

    The flag is appended to each token's hidden state as one extra feature,
    passed through Linear -> GELU -> LayerNorm and projected onto the
    vocabulary.

    Parameters
    ----------
    vocab_size : int
        Width of the output projection.
    hidden_size : int
        Width of the incoming hidden states.
    dropout : float
        Accepted for interface compatibility; not used by this module.
    """

    def __init__(self, vocab_size = VOCAB_SIZE, hidden_size = HID_SIZE, dropout = DROPOUT):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(hidden_size + 1, hidden_size),
            nn.GELU(),
            # Follow `hidden_size` instead of a hard-coded 768 so that a
            # non-default width does not fail with a shape mismatch.
            nn.LayerNorm((hidden_size,), eps=1e-12),
        )
        self.emb_matrix = nn.Linear(hidden_size, vocab_size)

    def forward(self, input, poly_flag, *args, **kwargs):
        """Return vocabulary logits for ``input`` given ``poly_flag``.

        ``input`` is indexed as (batch, seq, hidden) and ``poly_flag`` as
        (batch,): the flag is broadcast along dim 1 and concatenated on dim 2.
        """
        # (batch,) -> (batch, seq, 1), sized from the actual input rather than
        # the global SEQ_LEN so other sequence lengths work too.
        flag_column = (
            poly_flag.to(input.dtype)
            .unsqueeze(1)
            .unsqueeze(2)
            .expand(-1, input.size(1), -1)
        )
        emb_with_poly_flag = torch.cat([input, flag_column], dim=2)
        linear_output = self.linear_stack(emb_with_poly_flag)
        logits = self.emb_matrix(linear_output)

        return logits


# The "Type 2" cell further down defines another class that is also called
# MLMHead and ignores poly_flag. Binding this head to a private name at
# definition time keeps ModularLM on the flag-aware head even after that later
# cell has rebound the public name.
_PolyFlagMLMHead = MLMHead


class ModularLM(nn.Module):
    """Wrapped encoder followed by the flag-aware MLM head (model "Type 1")."""

    def __init__(self, bert_model):
        super(ModularLM, self).__init__()
        self.bert_module = BertModule(bert_model)
        self.head = _PolyFlagMLMHead()
        # Tie the output projection to the encoder's word-embedding matrix:
        # both modules now hold the same Parameter object.
        self.head.emb_matrix.weight = self.bert_module.bert.embeddings.word_embeddings.weight

    def forward(self, input_ids, attention_mask, poly_flag, token_type_ids=None, **kwargs):
        """Run the encoder and the head, returning a ``MaskedLMOutput``.

        The loss is always ``None``; callers compute it from ``logits``.
        Extra keyword arguments (for example ``labels``) are ignored.
        """
        bert_output = self.bert_module(input_ids=input_ids,
                                       attention_mask=attention_mask,
                                       token_type_ids=token_type_ids,)

        output = self.head(bert_output.last_hidden_state, poly_flag)
        return MaskedLMOutput(
            loss=None,
            logits=output,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions
        )

### Type 2

In [None]:
class BertModule(nn.Module):
    """Wrapper module that delegates ``forward`` to the model it holds.

    Same definition as the ``BertModule`` in the "Type 1" section; declaring
    it again here rebinds the name.
    """

    def __init__(self, model):
        super().__init__()
        self.bert = model

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return whatever the wrapped model returns for these inputs."""
        encoder_inputs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.bert(**encoder_inputs)

class GramModule(nn.Module):
    """LSTM over the encoder states, each augmented with the example's ``poly_flag``.

    Parameters
    ----------
    hidden_size : int
        Width of the incoming hidden states and of the LSTM output.
    dropout : float
        Accepted for interface compatibility; not used by this module.
    num_layers : int
        Number of stacked LSTM layers.
    """

    def __init__(self, hidden_size = HID_SIZE, dropout = DROPOUT, num_layers = 1):
        super().__init__()
        # The input built in forward() is (batch, seq, features). nn.LSTM
        # defaults to batch_first=False, which would read dim 0 as time and run
        # the recurrence across the examples of a batch instead of across the
        # tokens of each sequence.
        # NOTE: parameter shapes are unchanged, but checkpoints trained without
        # this flag were fitted to the old axis order and should be retrained.
        self.LSTM = nn.LSTM(hidden_size + 1, hidden_size, num_layers, batch_first=True)

    def forward(self, bert_output, poly_flag):
        """Return the LSTM outputs for ``bert_output`` concatenated with the flag.

        ``bert_output`` is indexed as (batch, seq, hidden) and ``poly_flag`` as
        (batch,); the result has the same leading dims with ``hidden_size``
        features.
        """
        # (batch,) -> (batch, seq, 1), sized from the actual input rather than
        # the global SEQ_LEN so other sequence lengths work too.
        flag_column = (
            poly_flag.to(bert_output.dtype)
            .unsqueeze(1)
            .unsqueeze(2)
            .expand(-1, bert_output.size(1), -1)
        )
        emb_with_poly_flag = torch.cat([bert_output, flag_column], dim=2)
        output, _ = self.LSTM(emb_with_poly_flag)
        return output

class MLMHead(nn.Module):
    """MLM head: Linear -> GELU -> LayerNorm, then a projection onto the vocabulary.

    Parameters
    ----------
    vocab_size : int
        Width of the output projection.
    hidden_size : int
        Width of the incoming hidden states.
    dropout : float
        Accepted for interface compatibility; not used by this module.
    """

    def __init__(self, vocab_size = VOCAB_SIZE, hidden_size = HID_SIZE, dropout = DROPOUT,):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            # Follow `hidden_size` instead of a hard-coded 768 so that a
            # non-default width does not fail with a shape mismatch.
            nn.LayerNorm((hidden_size,), eps=1e-12),
        )
        self.emb_matrix = nn.Linear(hidden_size, vocab_size)

    def forward(self, input, *args, **kwargs):
        """Return vocabulary logits for ``input``; extra arguments are ignored."""
        linear_output = self.linear_stack(input)
        logits = self.emb_matrix(linear_output)
        return logits

class ModularLSTMLM(nn.Module):
    """Wrapped encoder -> ``GramModule`` (LSTM with flag) -> MLM head (model "Type 2")."""

    def __init__(self, bert_model):
        super().__init__()
        self.bert_module = BertModule(bert_model)
        self.gram = GramModule()
        self.head = MLMHead()
        # Tie the output projection to the encoder's word-embedding matrix:
        # both modules now hold the same Parameter object.
        shared_embedding = self.bert_module.bert.embeddings.word_embeddings.weight
        self.head.emb_matrix.weight = shared_embedding

    def forward(self, input_ids, attention_mask, poly_flag, token_type_ids=None, **kwargs):
        """Return a ``MaskedLMOutput`` whose ``logits`` come from the head.

        The loss is always ``None``; callers compute it from ``logits``.
        Extra keyword arguments (for example ``labels``) are ignored.
        """
        encoded = self.bert_module(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        flagged_states = self.gram(
            bert_output=encoded.last_hidden_state,
            poly_flag=poly_flag,
        )
        vocab_logits = self.head(flagged_states)
        return MaskedLMOutput(
            loss=None,
            logits=vocab_logits,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )

# Training ModularLM (type 2)

In [None]:
# Plain dict of hyper-parameters, read by the dataloader cells and by train().
training_args = {
    "output_dir": "ModularMLM",
    # Worker processes used by both dataloaders.
    "dataloader_num_workers": 4,
    # Learning rate handed to AdamW in train().
    "learning_rate": 5e-5,
    "num_train_epochs": 5,
    "per_device_train_batch_size": BATCH_SIZE,
    # train() steps the optimizer once per this many batches.
    "gradient_accumulation_steps": 8,
}

In [None]:
# Training batches: reshuffled on every pass, incomplete final batch dropped,
# and MLM masking applied by the collator.
dataloader = DataLoader(
    data,
    collate_fn=data_collator,
    batch_size=training_args["per_device_train_batch_size"],
    shuffle=True,
    drop_last=True,
    num_workers=training_args["dataloader_num_workers"],
)



In [None]:
# Evaluation batches. Nothing is dropped and the order is fixed, so every test
# example is scored. With drop_last=True, a test set smaller than one batch
# would yield no batches at all and evaluate() would divide by zero.
test_dataloader = DataLoader(test_data,
                             batch_size=training_args["per_device_train_batch_size"],
                             num_workers=training_args["dataloader_num_workers"],
                             shuffle=False,
                             drop_last=False,
                             collate_fn=data_collator)

In [None]:
# Type 1 alternative (ModularLM.__init__ requires the encoder as its argument):
# model_modular = ModularLM(bert_model)

In [None]:
# Type 2 model: frozen encoder -> GramModule (LSTM) -> MLM head.
model_modular = ModularLSTMLM(bert_model)

In [None]:
# Move every sub-module to `device`; the returned module is displayed as the cell output.
model_modular.to(device)

ModularLSTMLM(
  (bert_module): BertModule(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(119547, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSdpaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=Tru

In [None]:
def cleanup():
    """Run a garbage-collection pass, then empty PyTorch's CUDA cache."""
    from gc import collect

    collect()
    torch.cuda.empty_cache()

In [None]:
# Collect garbage and empty the CUDA cache before training starts.
cleanup()

In [None]:
# Weight-tying check 1/3: word embeddings of the standalone encoder.
bert_model.embeddings.word_embeddings.weight

Parameter containing:
tensor([[-0.0857, -0.0602, -0.0949,  ..., -0.0846, -0.1065, -0.0140],
        [-0.0407, -0.0353, -0.0568,  ..., -0.1032, -0.0614, -0.0267],
        [-0.0405, -0.0163, -0.0545,  ..., -0.0998, -0.0749, -0.0265],
        ...,
        [-0.0601,  0.0077, -0.0103,  ..., -0.0527, -0.0420, -0.0074],
        [-0.0472,  0.0164, -0.0780,  ..., -0.0077, -0.0009, -0.0078],
        [ 0.0151, -0.0217, -0.0208,  ..., -0.0167, -0.0177, -0.0056]],
       device='cuda:0')

In [None]:
# Weight-tying check 2/3: the same embeddings reached through the wrapper.
model_modular.bert_module.bert.embeddings.word_embeddings.weight

Parameter containing:
tensor([[-0.0857, -0.0602, -0.0949,  ..., -0.0846, -0.1065, -0.0140],
        [-0.0407, -0.0353, -0.0568,  ..., -0.1032, -0.0614, -0.0267],
        [-0.0405, -0.0163, -0.0545,  ..., -0.0998, -0.0749, -0.0265],
        ...,
        [-0.0601,  0.0077, -0.0103,  ..., -0.0527, -0.0420, -0.0074],
        [-0.0472,  0.0164, -0.0780,  ..., -0.0077, -0.0009, -0.0078],
        [ 0.0151, -0.0217, -0.0208,  ..., -0.0167, -0.0177, -0.0056]],
       device='cuda:0')

In [None]:
# Weight-tying check 3/3: the head's output projection should show the same values.
model_modular.head.emb_matrix.weight

Parameter containing:
tensor([[-0.0857, -0.0602, -0.0949,  ..., -0.0846, -0.1065, -0.0140],
        [-0.0407, -0.0353, -0.0568,  ..., -0.1032, -0.0614, -0.0267],
        [-0.0405, -0.0163, -0.0545,  ..., -0.0998, -0.0749, -0.0265],
        ...,
        [-0.0601,  0.0077, -0.0103,  ..., -0.0527, -0.0420, -0.0074],
        [-0.0472,  0.0164, -0.0780,  ..., -0.0077, -0.0009, -0.0078],
        [ 0.0151, -0.0217, -0.0208,  ..., -0.0167, -0.0177, -0.0056]],
       device='cuda:0')

In [None]:
# Print the name of every parameter that still requires gradients, one per line.
for name in (
    param_name
    for param_name, weight in model_modular.named_parameters()
    if weight.requires_grad
):
    print(name)

gram.LSTM.weight_ih_l0
gram.LSTM.weight_hh_l0
gram.LSTM.bias_ih_l0
gram.LSTM.bias_hh_l0
head.linear_stack.0.weight
head.linear_stack.0.bias
head.linear_stack.2.weight
head.linear_stack.2.bias
head.emb_matrix.bias


In [None]:
# from training_utils import print_trainable_parameters

# print_trainable_parameters(model_modular)

In [None]:
def evaluate(model, dataloader, device = device):
    """Return the mean per-batch cross-entropy of ``model`` over ``dataloader``.

    The model is switched to eval mode and left there; callers that keep
    training must call ``model.train()`` afterwards. Raises ZeroDivisionError
    when ``dataloader`` yields no batches.

    NOTE: a later cell in this notebook defines ``evaluate`` again with a
    two-value return; whichever cell ran last wins.
    """
    model.eval()

    criterion = torch.nn.CrossEntropyLoss()
    batch_losses = []

    with torch.no_grad():
        for batch in dataloader:
            # Every tensor of the batch has to live on the model's device.
            on_device = {key: tensor.to(device) for key, tensor in batch.items()}
            logits = model(**on_device).logits
            # CrossEntropyLoss expects the class axis second: (batch, vocab, seq).
            class_second = einops.rearrange(logits, "batch seq tokens -> batch tokens seq")
            batch_losses.append(float(criterion(class_second, on_device["labels"])))

    return sum(batch_losses) / len(batch_losses)  # average over batches

In [None]:
def train(model, training_args, dataloader, test_dataloader, test_every = 100):
    """Train ``model`` with gradient accumulation and print the mean loss per epoch.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(**batch)``; the result must expose ``.logits``.
    training_args : dict
        Reads ``"learning_rate"``, ``"num_train_epochs"`` and
        ``"gradient_accumulation_steps"``.
    dataloader
        Sized iterable of dict batches; each batch must contain ``"labels"``.
    test_dataloader
        Passed to ``evaluate`` for the periodic test pass.
    test_every : int
        A test pass runs whenever ``batch_idx % test_every == 0``, which
        includes the first batch of every epoch.

    Returns
    -------
    None
    """

    def gradient_norm():
        """Return the L2 norm over all parameter gradients that are currently set.

        Only referenced by the optional (commented-out) wandb logging below.
        """
        grads = [
            param.grad.detach().flatten()
            for param in model.parameters()
            if param.grad is not None
        ]
        norm = torch.cat(grads).norm()
        return norm

    accumulation_steps = training_args["gradient_accumulation_steps"]
    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args["learning_rate"])
    loss_fn = torch.nn.CrossEntropyLoss()

    # Gradients are stored on the parameters, not on the optimizer, so an
    # interrupted earlier run can leave stale values behind. Clear them first.
    optimizer.zero_grad()

    last_batch_idx = len(dataloader) - 1

    for epoch in tqdm(range(training_args["num_train_epochs"])):
        losses = []
        for batch_idx, batch in enumerate(tqdm(dataloader)):
            # Move the batch to the device
            batch = {k: v.to(device) for k, v in batch.items()}

            # Forward and backward pass (gradients forced on, even if the
            # caller is inside a no_grad context)
            with torch.set_grad_enabled(True):
                outputs = model(**batch)
                loss = loss_fn(
                    einops.rearrange(outputs.logits, "batch seq tokens -> batch tokens seq"),
                    batch["labels"])
                # Record the true batch loss. Recording it after the division
                # below would under-report the epoch average by a factor of
                # `accumulation_steps`.
                losses.append(float(loss))

                # Scale only what is back-propagated, so the accumulated
                # gradient is an average over the accumulation window.
                (loss / accumulation_steps).backward()

            # wandb.log(
            #     {"loss": losses[-1],
            #     "grad_norm": gradient_norm()}
            # )

            # Step once per accumulation window, and on the very last batch so
            # that a partial window is not lost.
            if (batch_idx + 1) % accumulation_steps == 0 or batch_idx == last_batch_idx:
                optimizer.step()
                optimizer.zero_grad()

            if batch_idx % test_every == 0:
                test_loss = evaluate(model, test_dataloader)
                # Report the result; it used to be computed and thrown away.
                print(f"Epoch - {epoch + 1} : batch {batch_idx} : test loss {test_loss}")
                # wandb.log(
                #     {"avg. test_loss": test_loss,}
                #          )
                # evaluate() leaves the model in eval mode.
                model.train()

        print(f"Epoch - {epoch + 1} : avg.loss {np.mean(losses)}")

    return None

In [None]:
# RUN_NAME

In [None]:
# Switch the whole model (frozen encoder included) to train mode, then run the
# training loop with a test pass every 1000 batches. train() returns None.
model_modular.train();
_ = train(model_modular, training_args, dataloader, test_dataloader, test_every=1000)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

In [None]:
# Save the full state dict of the model under the key 'model_state_dict'.
# The file name is hard-coded, so rename it per run to avoid overwriting an
# earlier checkpoint.
torch.save({
            #'epoch': epoch,
            'model_state_dict': model_modular.state_dict(),
            #'optimizer_state_dict': optimizer.state_dict(),
            #'loss': loss,
            }, "modular_ml_4.1_mix_flag.pt")

# minio.put_object(buffer.getvalue(),
#                      save_name=f"ckpt/trained_models/gram_module/{RUN_NAME}.pt")

In [None]:
# NOTE: `losses` is local to train() and is not available at notebook scope;
# train() would have to return it before this plot can be re-enabled.
# plt.title('Training Process')
# plt.xlabel('Iterations')
# plt.ylabel('Loss')
# plt.grid()
# plt.plot(losses)

# Test

In [None]:
def evaluate(model, dataloader, device = device):
    """Return ``(mean loss, mean perplexity)`` of ``model`` over ``dataloader``.

    Both values are plain averages over batches; the perplexity of a batch is
    ``exp`` of that batch's cross-entropy. The model is left in eval mode.
    Raises ZeroDivisionError when ``dataloader`` yields no batches.

    This definition replaces the single-value ``evaluate`` defined earlier in
    the notebook.

    NOTE(review): averaging per-batch perplexities is not the same as
    ``exp(mean loss)``; confirm which of the two should be reported.
    """
    model.eval()

    criterion = torch.nn.CrossEntropyLoss()
    batch_losses, batch_ppls = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Every tensor of the batch has to live on the model's device.
            on_device = {key: tensor.to(device) for key, tensor in batch.items()}
            logits = model(**on_device).logits
            # CrossEntropyLoss expects the class axis second: (batch, vocab, seq).
            class_second = einops.rearrange(logits, "batch seq tokens -> batch tokens seq")
            batch_loss = float(criterion(class_second, on_device["labels"]))

            batch_losses.append(batch_loss)
            batch_ppls.append(math.exp(batch_loss))

    mean_loss = sum(batch_losses) / len(batch_losses)
    mean_ppl = sum(batch_ppls) / len(batch_ppls)
    return mean_loss, mean_ppl

In [None]:
# Re-reads the same file as the earlier `df_test` cell, giving a fresh frame
# for the final evaluation.
df_test = pd.read_csv("test_dataset_prefix.csv")

In [None]:
# NOTE(review): dataframe_preprocessing_add_intransitive is neither defined nor
# imported in the visible part of this notebook. Define or import it above,
# otherwise this cell fails on a fresh kernel.
df_test_2 = dataframe_preprocessing_add_intransitive(df_test)
df_test_2.head(5)

In [None]:
# Rebinds `test_data` for the final evaluation; with sample=True and no
# explicit size, the PairsDataset default sample_size of 10000 applies.
test_data = PairsDataset(tokenize_function, df_test_2, sample=True)

In [None]:
# Inspect the sampled test rows whose poly_flag equals 1.
test_data.dataset[test_data.dataset.poly_flag==1]

In [None]:
# Evaluation batches. Nothing is dropped and the order is fixed, so every test
# example is scored. With drop_last=True, a test set smaller than one batch
# would yield no batches at all and evaluate() would divide by zero.
test_dataloader = DataLoader(test_data,
                             batch_size=training_args["per_device_train_batch_size"],
                             num_workers=training_args["dataloader_num_workers"],
                             shuffle=False,
                             drop_last=False,
                             collate_fn=data_collator)

In [None]:
# Final test metrics: (mean loss, mean perplexity) from the evaluate() defined
# in this section.
evaluate(model_modular, test_dataloader)