# BERT for WSD with Fine-tuning
Building on the work of Loureiro et al. (2021), we adapted their code to work with Russian-language data.

Link to the original code and article: https://github.com/danlou/bert-disambiguation

# Environment Settings and Access to Data

In [None]:
# Mount Google Drive at /content/drive so the notebook can read the WSD data
# and write results to the Drive paths configured below.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install transformers
!pip install transliterate

# Classification Task: Refining Embeddings with Fine-Tuning

# Model

In [8]:
import torch
import torch.nn as nn


class MaskedAverageLayer(nn.Module):
    """Average the embeddings at the positions selected by a mask."""

    def __init__(self):
        super().__init__()

    def forward(self, seq, mask):
        '''
        Inputs:
            -seq : Tensor of shape [B, T, E] containing embeddings of sequences
            -mask : Tensor of shape [B, T] (or already [B, T, 1] / [B, T, E])
                    selecting which positions of seq to average (bool or 0/1)
        Returns:
            Tensor of shape [B, E] with the mean of the selected embeddings of
            each sequence (all zeros for a sequence with no selected position),
            or None when mask is None.
        '''
        if mask is None:
            return None
        if mask.dim() < seq.dim():
            # [B, T] -> [B, T, 1]; broadcasting over E avoids materialising a
            # full [B, T, E] copy of the mask.
            mask = mask.unsqueeze(-1)
        masked_sum = (mask.to(seq.dtype) * seq).sum(dim=1)
        # Counts are kept in float32; the epsilon avoids 0/0 for empty masks.
        unmasked_counts = mask.float().sum(dim=1)
        return masked_sum / (unmasked_counts + 1e-10)

In [4]:
import torch
import torch.nn as nn
from transformers import AutoModel


class GeneralModel(nn.Module):
    """Word-sense classifier on top of a pretrained transformer encoder.

    The first-token ([CLS]) embedding and the mean embedding of the target
    word's sub-tokens are concatenated and passed to one linear layer that
    produces a logit per sense class.
    """

    def __init__(self, model_name, pretrained_path, class_count, freeze_bert=False):
        '''
        Inputs:
            -model_name : model id passed to AutoModel.from_pretrained
            -pretrained_path : cache directory for the downloaded weights
            -class_count : number of sense classes (size of the output layer)
            -freeze_bert : if True, encoder parameters get requires_grad=False
        '''
        super().__init__()
        self.pretrained = AutoModel.from_pretrained(model_name, cache_dir=pretrained_path)
        if freeze_bert:
            # The encoder lives in self.pretrained (there is no self.bert_layer).
            for p in self.pretrained.parameters():
                p.requires_grad = False
        self.masked_average = MaskedAverageLayer()
        # Input is [CLS] embedding concatenated with the averaged target embedding.
        self.linear = nn.Linear(2 * self.pretrained.config.hidden_size, class_count)

    def forward(self, seq, attn_masks, target_mask):
        '''
        Inputs:
            -seq : Tensor of shape [B, T] containing token ids of sequences
            -attn_masks : Tensor of shape [B, T] containing attention masks to be used to avoid contribution of PAD tokens
            -target_mask : Tensor of shape [B, T] marking the target word's sub-tokens
        Returns:
            Tensor of shape [B, class_count] with unnormalised class logits.
        '''
        outputs = self.pretrained(seq, attention_mask=attn_masks)
        embeddings = outputs.last_hidden_state
        masked_average = self.masked_average(embeddings, target_mask)
        cls = embeddings[:, 0]
        combined = torch.cat((cls, masked_average), 1)
        logits = self.linear(combined)
        return logits

# Model Configuration

In [5]:
# Notebook-wide settings: file-system locations, training hyper-parameters and
# the target words (each one is a dataset directory under PATH["data_path"]).
config = {
    "PATH": {
        "tokenizer_path": "/content/bert_tokenizer",
        "pre_trained_path": "/content/bert_model",
        "data_path": "/content/drive/MyDrive/RD_project/MERGED_DATA/WSD_full",
        "output_path": "/content/drive/MyDrive/RD_project/output3/",
        "model_save_path": "/content/ruBERT_trained_models",
    },
    "HYPER_PARAM": {
        "learning_rate": 2e-5,
        "batch_size": 4,
        "epochs": 20,
        "sequence_length": 128,
    },
    "TARGET_DATASETS": [
        "dejstvie", "delo", "den'", "disk", "dokument", "dolja", "dom",
        "doroga", "duh", "mesto", "moment", "pravo", "set'", "sistema",
        "stat'ja", "vremja", "zadacha", "zakon", "zaschita", "zemlja",
        "zhizn'",
    ],
}

# Preprocessing Data

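For consistency, the target word in each sample is replaced by a single base form (the Cyrillic form of the dataset's word name), so every occurrence shares the same surface form.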

In [None]:
import os
import json
from transliterate import translit


def _read_wsd_split(path, tr_word):
    """Read an `<index><TAB><text>` file, replacing each target token with tr_word.

    Returns a list of {"target_index": int, "text": str} dicts, one per
    non-blank line, in file order.
    """
    samples = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # Strip a BOM and the line terminator explicitly; slicing off the
            # last character ate real text on a final line without "\n".
            line = line.lstrip('\ufeff').rstrip("\r\n")
            if not line.strip():
                continue
            fields = line.split("\t")
            index = int(fields[0])
            tokens = fields[1].split()
            tokens[index] = tr_word
            samples.append({"target_index": index, "text": ' '.join(tokens)})
    return samples


def _attach_wsd_labels(samples, path):
    """Set samples[i]["label"] from the i-th non-blank integer line of path.

    Raises ValueError if the number of labels differs from len(samples).
    """
    labels = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            value = line.lstrip('\ufeff').strip()
            if value:
                labels.append(int(value))
    if len(labels) != len(samples):
        raise ValueError(
            "{}: found {} labels for {} samples".format(path, len(labels), len(samples))
        )
    for sample, label in zip(samples, labels):
        sample["label"] = label


def parse_raw_data(data_path):
    """Load one target word's WSD dataset from the directory data_path.

    The directory name is the Latin-transliterated target word. The directory
    must contain train/test ``*.data.txt`` and ``*.gold.txt`` files plus
    ``classes_map.txt`` (JSON). In every sample the target token is replaced
    by the Cyrillic form of the directory name.

    Returns {"train": [...], "test": [...], "class_map": <parsed JSON>}, where
    each sample is {"target_index": int, "text": str, "label": int}.

    Raises ValueError if a gold file's label count does not match its data file.
    """
    word = os.path.basename(os.path.normpath(data_path))
    # Hand-written Cyrillic forms for words whose translit() result is overridden.
    custom_mappings = {
        "zaschita": "защита",
    }
    if word in custom_mappings:
        tr_word = custom_mappings[word]
    else:
        tr_word = translit(word, "ru", reversed=False).lower()

    data = {}
    data["train"] = _read_wsd_split(os.path.join(data_path, "train.data.txt"), tr_word)
    _attach_wsd_labels(data["train"], os.path.join(data_path, "train.gold.txt"))

    data["test"] = _read_wsd_split(os.path.join(data_path, "test.data.txt"), tr_word)
    _attach_wsd_labels(data["test"], os.path.join(data_path, "test.gold.txt"))

    with open(os.path.join(data_path, "classes_map.txt"), "r", encoding="utf-8") as f:
        data["class_map"] = json.load(f)

    return data

if __name__ == "__main__":
    # Smoke test: parse a single target word and dump each top-level entry.
    ru_data = parse_raw_data('/content/drive/MyDrive/RD_project/MERGED_DATA/WSD_full/zaschita')
    for key, value in ru_data.items():
        print((key, value))

# DataLoader

In [None]:
import torch
from torch.utils.data import TensorDataset, SequentialSampler, DataLoader


class GeneralModelInput(object):
    """Turn parsed WSD samples into a DataLoader of model inputs.

    Each batch is (input_ids, attention_masks, target_masks, labels), where
    target_masks is a [B, sequence_length] bool tensor marking the sub-tokens
    of the sample's target word.
    """

    def __init__(self, sequence_length, batch_size, tokenizer, parsed_data):
        self.tokenizer = tokenizer
        self.parsed_data = parsed_data
        self.batch_size = batch_size
        self.sequence_length = sequence_length

    def _target_mask(self, words, target_index):
        """Return a sequence_length list of bools marking words[target_index]'s sub-tokens."""
        if not 0 <= target_index < len(words):
            return [False] * self.sequence_length
        # Position 0 corresponds to the single leading special token ([CLS]).
        mask = [False]
        for word in words[:target_index]:
            mask += [False] * len(self.tokenizer.tokenize(word))
        mask += [True] * len(self.tokenizer.tokenize(words[target_index]))
        # Words after the target only contribute False entries, so they need
        # not be tokenized; pad/truncate to the encoded sequence length.
        mask += [False] * (self.sequence_length - len(mask))
        return mask[:self.sequence_length]

    def torch_data(self):
        """Encode all samples and return them as an (unshuffled) DataLoader."""
        input_ids, attention_masks, target_masks, labels = [], [], [], []
        for sample in self.parsed_data:
            encoded_dict = self.tokenizer.encode_plus(
                sample["text"],
                add_special_tokens=True,
                max_length=self.sequence_length,
                padding="max_length",
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt'
            )
            input_ids.append(encoded_dict['input_ids'])
            attention_masks.append(encoded_dict['attention_mask'])
            target_masks.append(self._target_mask(sample["text"].split(), sample["target_index"]))
            labels.append(sample["label"])

        dataset = TensorDataset(torch.cat(input_ids, dim=0),
                                torch.cat(attention_masks, dim=0),
                                torch.tensor(target_masks),
                                torch.tensor(labels))
        return DataLoader(dataset, batch_size=self.batch_size)

if __name__ == "__main__":
    # Smoke test: build the training DataLoader for one word and print the
    # dtype and shape of every tensor in each batch.
    from transformers import AutoTokenizer
    parsed_data = parse_raw_data('/content/drive/MyDrive/RD_project/WSD_data/WSD_1/delo')
    model_name = "bert-base-uncased"  # choices: ["bert-base-multilingual-uncased", "DeepPavlov/rubert-base-cased"]
    tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=True)
    class_count = len(parsed_data["class_map"])
    model = GeneralModel(model_name, config["PATH"]["pre_trained_path"], class_count).to('cuda')
    train_data = GeneralModelInput(128, 4, tokenizer, parsed_data["train"]).torch_data()
    for input_ids, attention_masks, target_masks, labels in train_data:
        named_tensors = (
            ("Batch input_ids shape:", input_ids),
            ("Batch attention_masks shape:", attention_masks),
            ("Batch target_masks shape:", target_masks),
            ("Batch labels shape:", labels),
        )
        for caption, tensor in named_tensors:
            print(caption, tensor.dtype, tensor.shape)


# Accuracy Function

In [None]:
import numpy as np


def accuracy_from_logits(logits, labels):
    """Return the fraction of rows whose arg-max logit equals the gold label.

    logits is a [N, C] tensor and labels an [N] tensor; both may live on any
    device and may require grad. Raises ZeroDivisionError when N == 0.
    """
    predicted = np.argmax(logits.detach().cpu().numpy(), axis=1)
    gold = labels.detach().cpu().numpy()
    hits = int((predicted == gold).sum())
    return hits / len(labels)

# Training & Evaluating

In [None]:
import os
import json
import argparse
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer
import numpy as np
import torch


model_name = "bert-base-uncased"                                                # choices: ["DeepPavlov/rubert-base-cased", "bert-base-multilingual-uncased"]
save_model = False

# Fail fast on an unsupported model instead of reaching an undefined (or the
# previous word's) `model` inside the loop.
supported_models = ["bert-base-multilingual-uncased", "DeepPavlov/rubert-base-cased", "bert-base-uncased"]
if model_name not in supported_models:
    raise ValueError("Unsupported model_name: {}".format(model_name))

# Results go to <output_path>/<folder>/; the folder must exist before writing.
output_folders = {
    "bert-base-multilingual-uncased": "multilingual",
    "DeepPavlov/rubert-base-cased": "rubert",
}
output_folder = output_folders.get(model_name, "other_model")
output_dir = os.path.join(config["PATH"]["output_path"], output_folder)
os.makedirs(output_dir, exist_ok=True)

hyper = config["HYPER_PARAM"]
# The tokenizer does not depend on the target word, so load it once.
tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=True, cache_dir=config["PATH"]["tokenizer_path"])

for target_dataset in config["TARGET_DATASETS"]:
    print("")
    print(target_dataset)
    parsed_data = parse_raw_data(config["PATH"]["data_path"] + "/" + target_dataset)

    # A fresh model per target word: each word has its own set of sense classes.
    model = GeneralModel(model_name, config["PATH"]["pre_trained_path"], len(parsed_data["class_map"])).to('cuda:0')
    train_data = GeneralModelInput(hyper["sequence_length"], hyper["batch_size"],
                                   tokenizer, parsed_data["train"]).torch_data()
    test_data = GeneralModelInput(hyper["sequence_length"], hyper["batch_size"],
                                  tokenizer, parsed_data["test"]).torch_data()
    model = nn.DataParallel(model)
    opti = optim.Adam(model.parameters(), lr=hyper["learning_rate"])
    criterion = nn.CrossEntropyLoss()

    #  TRAINING MODEL
    model.train()
    for ep in range(hyper["epochs"]):
        print("-------------- EPOCH: ", ep, " --------------")
        print("")
        for it, (seq, attn_masks, target_mask, labels) in enumerate(train_data):
            opti.zero_grad()
            seq, attn_masks, target_mask, labels = seq.cuda(0), attn_masks.cuda(0), target_mask.cuda(0), labels.cuda(0)
            logits = model(seq, attn_masks, target_mask)
            loss = criterion(logits, labels)
            loss.backward()
            opti.step()

            if (it + 1) % 10 == 0:
                acc = accuracy_from_logits(logits, labels)
                print("Iteration {} of epoch {} complete. Loss : {} Accuracy : {}".format(it + 1, ep + 1, loss.item(),
                                                                                          acc))

    #  SAVING MODEL
    if save_model:
        print("saving model")
        os.makedirs(config["PATH"]["model_save_path"], exist_ok=True)
        torch.save(model.state_dict(), os.path.join(config["PATH"]["model_save_path"], target_dataset + ".pt"))

    # EVALUATION
    model.eval()
    predictions, true_labels, testset_acc = [], [], []
    with torch.no_grad():
        for seq, attn_masks, target_mask, labels in test_data:
            seq, attn_masks, target_mask, labels = seq.cuda(0), attn_masks.cuda(0), target_mask.cuda(0), labels.cuda(0)
            logits = model(seq, attn_masks, target_mask)
            testset_acc.append(accuracy_from_logits(logits, labels))
            predictions += np.argmax(logits.cpu().numpy(), axis=1).tolist()
            true_labels += labels.cpu().numpy().tolist()

    # Accuracy over all test samples; averaging per-batch accuracies would
    # over-weight a smaller final batch.
    correct = sum(int(p == t) for p, t in zip(predictions, true_labels))
    total = correct / len(true_labels)

    with open(os.path.join(output_dir, target_dataset + ".txt"), "w") as out:
        out.writelines(str(p) + "\n" for p in predictions)
    with open(os.path.join(output_dir, target_dataset + "_acc.txt"), "a") as acc_file:
        # Appended once per run: per-batch accuracies, a tab, overall accuracy.
        acc_file.write(str(testset_acc) + "\t" + str(total) + "\n")