# Training

We use Sentence Transformers [1] as the base architecture. Its authors propose averaging the hidden states of all tokens (mean pooling) to produce the final sentence representation.

[1] N. Reimers, I. Gurevych (2019). Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks. arXiv preprint arXiv:1908.10084.

https://github.com/UKPLab/sentence-transformers
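For illustration, masked mean pooling can be sketched in plain Python as below. This is a minimal sketch of the idea, not the library's implementation (which works on batched PyTorch tensors); the function name `mean_pooling` and the toy vectors are hypothetical.

```python
from typing import List, Sequence


def mean_pooling(token_embeddings: Sequence[Sequence[float]],
                 attention_mask: Sequence[int]) -> List[float]:
    """Average the hidden states of all non-padding tokens."""
    dim = len(token_embeddings[0])
    total = [0.0] * dim
    count = 0
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:  # padding positions (mask 0) are ignored
            for i, value in enumerate(vector):
                total[i] += value
            count += 1
    # max(count, 1) avoids division by zero for an all-padding input
    return [value / max(count, 1) for value in total]


hidden = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]  # toy states; the last token is padding
print(mean_pooling(hidden, [1, 1, 0]))  # [2.0, 3.0]
```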

In [1]:
from sentence_transformers.models import Pooling, Transformer
from sentence_transformers import SentenceTransformer, models, util, InputExample, losses
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
import sentence_transformers
from torch.utils.data import DataLoader, Dataset
import json
from pathlib import Path
from collections import defaultdict
import numpy as np
import random
import torch
from tqdm.notebook import tqdm
from typing import Union, Optional

SEED = 1
torch.random.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

We use Sentence RuBERT [2] as the starting point for fine-tuning. Sentence RuBERT was trained in a similar manner to the Sentence Transformers models, i.e. with mean pooling to produce the sentence encoding. It was also fine-tuned on the NLI datasets [3][4] that are commonly used for sentence similarity models.

[2] http://docs.deeppavlov.ai/en/master/features/models/bert.html

[3] A. Conneau, et al. (2018). XNLI: Evaluating Cross-lingual Sentence Representations. arXiv preprint arXiv:1809.05053.

[4] S. R. Bowman, G. Angeli, C. Potts, C. D. Manning (2015). A large annotated corpus for learning natural language inference. arXiv preprint arXiv:1508.05326.

In [2]:
model = SentenceTransformer('sentence_ru_cased_pipe_L-12_H-768_A-12_pt')

## Data Preparation

To construct the input, three additional tokens were added: `[E]`, `[/E]`, and `[BLANK]`. 

`[E]` and `[/E]` mark the start and the end of the instance word in the context sentence. A sketch is unfolded into one flat sentence in which every role and every predicate is wrapped in `[E]` and `[/E]`. Additionally, to force the model to learn a more general representation of the context and the sketch, entities are randomly replaced with a special `[BLANK]` symbol during training [5], each with probability `blank_prob` (0.7 in the training cell below).
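The marking and blanking described above can be sketched as two small helpers. This is a simplified sketch, not the notebook's code: the helper names are hypothetical, the sketch is assumed to be a plain `{role: [predicate, ...]}` dict (the data files nest the predicates one level deeper, at `words[1]`), a seeded `random.Random` replaces the global NumPy RNG for reproducibility, and all items are joined with single spaces.

```python
import random
from typing import Dict, List


def mark_context_instance(sentence: str, start: int, end: int,
                          blank_prob: float, rng: random.Random) -> str:
    """Wrap sentence[start:end] in [E] ... [/E]; with probability blank_prob use [BLANK] instead."""
    target = '[BLANK]' if rng.random() < blank_prob else sentence[start:end]
    return sentence[:start] + '[E]' + target + '[/E]' + sentence[end:]


def mark_sketch(roles: Dict[str, List[str]], blank_prob: float, rng: random.Random) -> str:
    """Flatten a hypothetical {role: [predicate, ...]} dict, marking (or blanking) every item."""
    parts = []
    for role, predicates in roles.items():
        for item in [role] + list(predicates):
            parts.append('[E][BLANK][/E]' if rng.random() < blank_prob else f'[E]{item}[/E]')
    return ' '.join(parts)


rng = random.Random(1)
print(mark_context_instance('Мы пошли в лес', 11, 14, blank_prob=0.0, rng=rng))
# Мы пошли в [E]лес[/E]
print(mark_sketch({'Object': ['волосы', 'шерсть']}, blank_prob=0.0, rng=rng))
# [E]Object[/E] [E]волосы[/E] [E]шерсть[/E]
```

With `blank_prob=0.0` nothing is blanked, which matches how the dev examples are built below; with `blank_prob=1.0` every item becomes `[E][BLANK][/E]`.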

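The trial dataset was split into train, dev, and test parts: for each sketch, its first 80 contexts go to train, the next 10 to dev, and the remaining ones to test (an 80-10-10 split for a sketch with 100 contexts). For each context, two training pairs are built: the context paired with its matching sketch (label `1`), and the same sketch paired with a randomly chosen context that belongs to a different sketch (label `0`).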
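A minimal sketch of the per-sketch split and the pair construction, working on plain ids instead of `InputExample` objects. The helper names and the toy gold mapping are hypothetical; the sketch assumes at least two sketches with non-empty context lists, otherwise no negative context can be drawn.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple


def split_per_sketch(gold: Dict[str, str], n_train: int = 80, n_dev: int = 10):
    """gold maps context_id -> sketch_id; split each sketch's contexts by position."""
    by_sketch: Dict[str, List[str]] = defaultdict(list)
    for context_id, sketch_id in gold.items():
        by_sketch[sketch_id].append(context_id)
    train = {s: c[:n_train] for s, c in by_sketch.items()}
    dev = {s: c[n_train:n_train + n_dev] for s, c in by_sketch.items()}
    test = {s: c[n_train + n_dev:] for s, c in by_sketch.items()}
    return train, dev, test


def build_pairs(labels: Dict[str, List[str]], rng: random.Random) -> List[Tuple[str, str, float]]:
    """Per context: (context, own sketch, 1.0) and (foreign context, same sketch, 0.0)."""
    pairs = []
    for sketch_id, context_ids in labels.items():
        others = [s for s in labels if s != sketch_id]
        for context_id in context_ids:
            pairs.append((context_id, sketch_id, 1.0))
            neg_context = rng.choice(labels[rng.choice(others)])
            pairs.append((neg_context, sketch_id, 0.0))
    return pairs


gold = {f'c{i}': ('s1' if i < 100 else 's2') for i in range(200)}  # toy gold
train_ids, dev_ids, test_ids = split_per_sketch(gold)
pairs = build_pairs(train_ids, random.Random(1))
print(len(train_ids['s1']), len(dev_ids['s1']), len(test_ids['s1']), len(pairs))  # 80 10 10 320
```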

### Input example:

Context: `Согласитесь, что в схватке социально-близкого с социально-чуждым не может государство [E][BLANK][/E] за последнего`

Sketch: `[E]Object[/E] [E]волосы[/E] [E]шерсть[/E] [E]работы[/E] [E]лед[/E] [E]гор[/E] [E]лес[/E][E]Quantity[/E] [E]меньше[/E] [E]все больше[/E] [E]чуть-чуть[/E] [E]немножко[/E] [E]еще немного[/E] [E]столько[/E][E][BLANK][/E] [E]в августе[/E] [E]в декабре[/E] [E]два[/E] [E]в коротком будущем[/E] [E]с 1907 г.[/E] [E][BLANK][/E][E]Locative_FinalPoint[/E] [E]в один ряд[/E] [E]в тупик[/E] [E]в строй[/E] [E]в свою очередь[/E] [E]на колени[/E] [E]на ноги[/E][E]Locative_Distance[/E] [E]ближе[/E] [E]поодаль[/E] [E]рядышком[/E] [E]рядом[/E] [E][BLANK][/E] [E][BLANK][/E][E][BLANK][/E] [E]дурно[/E] [E][BLANK][/E] [E]круче[/E] [E][BLANK][/E] [E]лучше[/E] [E]прекрасно[/E]`

Label: `1`

[5] L. Baldini Soares, et al. (2019). Matching the Blanks: Distributional Similarity for Relation Learning. arXiv preprint arXiv:1906.03158.

In [3]:
data_dir = Path('SemSketches/data')
train_dir = 'trial'
dev_dir = 'dev'

contexts_name = 'contexts_{}.data'
sketches_name = 'sketches_{}.data'
gold_name = 'trial.gold'

In [4]:
train_gold_name = 'train_trial.gold'
dev_gold_name = 'dev_trial.gold'
test_gold_name = 'test_trial.gold'

# trial.gold maps context_id -> sketch_id; invert it to sketch_id -> [context_id, ...]
all_labels = json.load(open(data_dir / train_dir / gold_name, encoding='utf-8'))
gold_reverse = defaultdict(list)
for k, v in all_labels.items():
    gold_reverse[v].append(k)

# Per sketch: the first 80 contexts go to train, the next 10 to dev, the rest to test
train_labels = {k: v[:80] for k, v in gold_reverse.items()}
dev_labels = {k: v[80:90] for k, v in gold_reverse.items()}
test_labels = {k: v[90:] for k, v in gold_reverse.items()}

In [5]:
class SemSketchesDataset(Dataset):
    def __init__(self, 
                 contexts_path: Path, 
                 sketches_path: Path,
                 dev_contexts_path: Optional[Path] = None,
                 dev_sketches_path: Optional[Path] = None,
                 labels: Optional[Union[Path, dict]] = None, 
                 eval: Optional[bool] = False,
                 mark_instance: Optional[bool] = False, 
                 blank_prob: Optional[float] = 0.7):
        self.context_data = json.load(open(contexts_path, encoding='utf-8'))
        self.sketches_data = json.load(open(sketches_path, encoding='utf-8'))
        
        if dev_contexts_path and dev_sketches_path:
            self.dev_context_data = json.load(open(dev_contexts_path, encoding='utf-8'))
            self.dev_sketches_data = json.load(open(dev_sketches_path, encoding='utf-8'))
        
        self.eval = eval
        if isinstance(labels, Path):
            # The file must map sketch_id -> [context_id, ...] (not the raw trial.gold layout)
            self.labels_data = json.load(open(labels, encoding='utf-8'))
        elif isinstance(labels, dict):
            self.labels_data = labels
        elif labels is None and self.eval:
            self.labels_data = None
        else:
            raise ValueError("Please provide labels either as a Path object or as a dict!")
            
        self.blank_prob = blank_prob
        self.mark_instance= mark_instance

        context_sentences = {}
        for context_id, context in self.context_data.items():
            sentence = context['sentence']
            if mark_instance:
                start = context['start']
                end = context['end']
                target = '[BLANK]' if np.random.uniform() < blank_prob else sentence[start:end]
                sentence = sentence[:start] + '[E]' + target + '[/E]' + sentence[end:]
            context_sentences[context_id] = sentence
        
        if dev_contexts_path:
            dev_context_sentences = {}        
            for context_id, context in self.dev_context_data.items():
                sentence = context['sentence']
                if mark_instance:
                    start = context['start']
                    end = context['end']
                    target = '[BLANK]' if np.random.uniform() < blank_prob else sentence[start:end]
                    sentence = sentence[:start] + '[E]' + target + '[/E]' + sentence[end:]
                dev_context_sentences[context_id] = sentence

        sketch_sentences = {}
        for sketch_id, sketch in self.sketches_data.items():
            sketch_sentences[sketch_id] = self.construct_sketch(sketch_id)
            
        if dev_sketches_path:
            dev_sketch_sentences = {}
            for sketch_id, sketch in self.dev_sketches_data.items():
                dev_sketch_sentences[sketch_id] = self.construct_sketch(sketch_id, dev=True)

        self.data = []
        # Without labels (eval mode) there are no pairs to build
        for sketch, sents in (self.labels_data or {}).items():
            for sent in sents:
                self.data.append(InputExample(texts=[context_sentences[sent], sketch_sentences[sketch]], label=1.))

                if dev_contexts_path and dev_sketches_path:
                    dev_sketch = dev_sketch_sentences[random.choice(list(dev_sketch_sentences.keys()))]
                    rand_context = context_sentences[random.choice(list(context_sentences.keys()))]
                    self.data.append(InputExample(texts=[rand_context, dev_sketch], label=0.))
                else:
                    neg_sketches = list(self.labels_data.keys())
                    neg_sketches.remove(sketch)
                    neg_sketch = random.sample(neg_sketches, k=1)
                    for neg in neg_sketch:   
                        neg_sent = random.choice(self.labels_data[neg])
                        self.data.append(InputExample(texts=[context_sentences[neg_sent], self.construct_sketch(sketch)], label=0.))
                
                        
    def construct_sketch(self, sketch_id, dev=False):
        sketch = ''
        if dev:
            sketches_data = self.dev_sketches_data[sketch_id]
        else:
            sketches_data = self.sketches_data[sketch_id]
        for role, words in sketches_data.items():
            if self.mark_instance:
                role_token = '[BLANK]' if np.random.uniform() < self.blank_prob else role
                role = '[E]' + role_token + '[/E]'
                predicates = ['[E][BLANK][/E]' if np.random.uniform() < self.blank_prob else f"[E]{word}[/E]" for word in words[1]]
                sketch += role + ' ' + ' '.join(predicates)
            else:
                # Trailing space keeps consecutive roles from fusing into one word
                sketch += ' '.join([role] + list(words[1])) + ' '
        return sketch.strip()

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

In [6]:
train_contexts_path = data_dir / train_dir / contexts_name.format(train_dir)
train_sketches_path = data_dir / train_dir / sketches_name.format(train_dir)
dev_sketches_path = data_dir / dev_dir / sketches_name.format(dev_dir)
dev_contexts_path = data_dir / dev_dir / contexts_name.format(dev_dir)
train_labels_path = data_dir / train_dir / gold_name

In [7]:
mark_instance = True
train_examples = SemSketchesDataset(train_contexts_path, train_sketches_path, labels=train_labels, 
                                    mark_instance=mark_instance, blank_prob=0.7)
dev_examples = SemSketchesDataset(train_contexts_path, train_sketches_path, labels=dev_labels, 
                                  mark_instance=mark_instance, blank_prob=0.)

In [16]:
train_examples[2].texts

['Согласитесь, что в схватке социально-близкого с социально-чуждым не может государство [E][BLANK][/E] за последнего',
 '[E]Object[/E] [E]волосы[/E] [E]шерсть[/E] [E]работы[/E] [E]лед[/E] [E]гор[/E] [E]лес[/E][E]Quantity[/E] [E]меньше[/E] [E]все больше[/E] [E]чуть-чуть[/E] [E]немножко[/E] [E]еще немного[/E] [E]столько[/E][E][BLANK][/E] [E]в августе[/E] [E]в декабре[/E] [E]два[/E] [E]в коротком будущем[/E] [E]с 1907 г.[/E] [E][BLANK][/E][E]Locative_FinalPoint[/E] [E]в один ряд[/E] [E]в тупик[/E] [E]в строй[/E] [E]в свою очередь[/E] [E]на колени[/E] [E]на ноги[/E][E]Locative_Distance[/E] [E]ближе[/E] [E]поодаль[/E] [E]рядышком[/E] [E]рядом[/E] [E][BLANK][/E] [E][BLANK][/E][E][BLANK][/E] [E]дурно[/E] [E][BLANK][/E] [E]круче[/E] [E][BLANK][/E] [E]лучше[/E] [E]прекрасно[/E]']

## Model Architecture

The model follows a siamese network architecture with a cosine similarity loss, as in Sentence Transformers [1] (https://www.sbert.net/docs/training/overview.html#network-architecture): the context and the sketch are encoded by the same network, and the loss compares the cosine similarity of their embeddings with the pair label. Here, however, instead of mean pooling over all tokens' representations, a custom mean pooling over the entity tokens (`[E]`) is used.
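The notebook does not include the custom pooling module itself. Assuming that "mean pooling of the entity tokens" means averaging the final hidden states at the positions of the `[E]` markers (and that `[E]` is tokenized as a single token), the idea can be sketched as follows. The function name and toy inputs are hypothetical, and another plausible reading would average the tokens between `[E]` and `[/E]` instead.

```python
from typing import List, Sequence


def entity_mean_pooling(tokens: Sequence[str],
                        token_embeddings: Sequence[Sequence[float]],
                        marker: str = '[E]') -> List[float]:
    """Average the hidden states at the positions whose token equals `marker`."""
    positions = [i for i, tok in enumerate(tokens) if tok == marker]
    if not positions:
        raise ValueError(f'no {marker} tokens in the input')
    dim = len(token_embeddings[0])
    return [sum(token_embeddings[i][d] for i in positions) / len(positions)
            for d in range(dim)]


tokens = ['[CLS]', '[E]', 'лес', '[/E]', '[E]', 'гор', '[/E]', '[SEP]']
hidden = [[float(i), float(-i)] for i in range(len(tokens))]  # toy hidden states
print(entity_mean_pooling(tokens, hidden))  # [2.5, -2.5]
```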

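The model is trained for 10 epochs and evaluated on the held-out dev split every 100 steps. The checkpoint with the best dev score (a Spearman rank correlation between the predicted embedding similarities and the gold labels, as reported by `EmbeddingSimilarityEvaluator`) is saved to `model_save_path` and used in the end.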
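To make the selection criterion concrete, here is a sketch of a Spearman rank correlation between cosine similarities and gold labels, with tied values sharing their average rank. It is a plain-Python illustration, not `EmbeddingSimilarityEvaluator`'s code; all helper names and scores are hypothetical.

```python
import math
from typing import List, Sequence


def average_ranks(values: Sequence[float]) -> List[float]:
    """1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # extend j over the run of values tied with values[order[i]]
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            ranks[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return ranks


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def spearman(x: Sequence[float], y: Sequence[float]) -> float:
    """Spearman's rho is the Pearson correlation of the ranks."""
    return pearson(average_ranks(x), average_ranks(y))


cos_scores = [0.9, 0.1, 0.8, 0.2]  # hypothetical dev-pair similarities
gold = [1.0, 0.0, 1.0, 0.0]
print(round(spearman(cos_scores, gold), 3))  # 0.894
```

In this toy case positives and negatives are perfectly separated, yet the score is 2/√5 ≈ 0.894 rather than 1: all positives share one tied label rank while their similarities are distinct.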

In [9]:
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=8)
train_loss = losses.CosineSimilarityLoss(model)
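`CosineSimilarityLoss` regresses the cosine similarity of the two embeddings onto the pair label with a mean squared error. A plain-Python sketch of that computation on one batch follows; the function names are hypothetical and the real loss runs on PyTorch tensors.

```python
import math
from typing import Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def cosine_similarity_loss(us, vs, labels) -> float:
    """Mean squared error between cos(u, v) and the gold label, averaged over the batch."""
    errors = [(cosine(u, v) - y) ** 2 for u, v, y in zip(us, vs, labels)]
    return sum(errors) / len(errors)


us = [[1.0, 0.0], [1.0, 0.0]]  # toy context embeddings
vs = [[1.0, 0.0], [0.0, 1.0]]  # toy sketch embeddings
print(cosine_similarity_loss(us, vs, [1.0, 0.0]))  # 0.0
print(cosine_similarity_loss(us, vs, [0.0, 1.0]))  # 1.0
```

Because negatives carry label `0`, this objective pushes unrelated context/sketch pairs toward orthogonal embeddings rather than opposite ones.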

In [10]:
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_examples)

In [11]:
num_epochs = 10
warmup_steps = 40
model_save_path = 'sentence_ru_cased_fine_tuned_blanks_entity_pooling_recreation_L-12_H-768_A-12_pt'
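To our understanding of the sentence-transformers 2.x API, `model.fit` defaults to `scheduler='WarmupLinear'`: the learning rate rises linearly over `warmup_steps` and then decays linearly to zero. Below is a sketch of that multiplier under this assumption; the function name is hypothetical, and the total of 1000 steps is made up (the real value is `len(train_dataloader) * num_epochs`).

```python
def warmup_linear_factor(step: int, warmup_steps: int, total_steps: int) -> float:
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


total_steps = 1000  # hypothetical; depends on the dataset size
for step in (0, 20, 40, 520, 1000):
    print(step, warmup_linear_factor(step, warmup_steps=40, total_steps=total_steps))
# 0 0.0
# 20 0.5
# 40 1.0
# 520 0.5
# 1000 0.0
```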

In [None]:
model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=100,
          warmup_steps=warmup_steps,
          output_path=model_save_path)

# Inference

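For inference, we follow the semantic search example from Sentence Transformers (https://www.sbert.net/examples/applications/semantic-search/README.html): every dev context is encoded as a query, every dev sketch as a corpus entry, and each context is assigned the sketch with the highest cosine similarity.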
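The top-1 retrieval step can be sketched without any library: normalize every vector to unit length so that a dot product equals cosine similarity, then pick the best-scoring sketch for each context. The function names, ids, and 2-d vectors are hypothetical stand-ins for the encoded dev contexts and sketches.

```python
import math
from typing import Dict, List, Sequence


def normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def top1_search(queries: Dict[str, Sequence[float]],
                corpus: Dict[str, Sequence[float]]) -> Dict[str, str]:
    """Map every query id to the corpus id with the highest cosine similarity."""
    corpus_norm = {cid: normalize(vec) for cid, vec in corpus.items()}
    results = {}
    for qid, qvec in queries.items():
        q = normalize(qvec)
        # dot product of unit vectors == cosine similarity
        results[qid] = max(corpus_norm,
                           key=lambda cid: sum(a * b for a, b in zip(q, corpus_norm[cid])))
    return results


contexts = {'ctx1': [0.9, 0.1], 'ctx2': [0.2, 0.8]}
sketches = {'sk_a': [1.0, 0.0], 'sk_b': [0.0, 1.0]}
print(top1_search(contexts, sketches))  # {'ctx1': 'sk_a', 'ctx2': 'sk_b'}
```

Normalizing first matters: with raw dot products a long sketch vector could win even when it points in a different direction.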

In [6]:
model_save_path = 'sentence_ru_cased_fine_tuned_blanks_entity_pooling_L-12_H-768_A-12_pt'
model = SentenceTransformer(model_save_path)

In [7]:
def get_contexts(contexts_path, mark_instance=False):
    context_data = json.load(open(contexts_path, encoding='utf-8'))
    context_sentences = {}
    for context_id, context in context_data.items():
        sentence = context['sentence']
        if mark_instance:
            start = context['start']
            end = context['end']
            target = sentence[start:end]
            sentence = sentence[:start] + '[E]' + target + '[/E]' + sentence[end:]
        context_sentences[context_id] = sentence
    return context_sentences

def get_sketches(sketches_path, mark_instance=False):
    sketches_data = json.load(open(sketches_path, encoding='utf-8'))
    sketch_sentences = defaultdict(str)
    for sketch_id, sketch in sketches_data.items():
        for role, words in sketch.items():
            if mark_instance:
                role = '[E]' + role + '[/E]'
                predicates = [f"[E]{word}[/E]" for word in words[1]]
                sketch_sentences[sketch_id] += role + ' ' + ' '.join(predicates)
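            else:
                # Trailing space keeps consecutive roles from fusing into one word
                sketch_sentences[sketch_id] += ' '.join([role] + list(words[1])) + ' '
    return {sketch_id: text.strip() for sketch_id, text in sketch_sentences.items()}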

In [8]:
dev_contexts = get_contexts('SemSketches/data/dev/contexts_dev.data', True)
dev_sketches = get_sketches('SemSketches/data/dev/sketches_dev.data', True)

In [9]:
# Use the GPU when available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Corpus: dev sketches; queries: dev contexts
corpus = list(dev_sketches.values())
corpus_ids = list(dev_sketches.keys())

corpus_embeddings = model.encode(corpus, convert_to_tensor=True).to(device)
# Unit-length vectors, so a dot product equals cosine similarity
corpus_embeddings = util.normalize_embeddings(corpus_embeddings)

queries = list(dev_contexts.values())
queries_ids = list(dev_contexts.keys())

query_embeddings = model.encode(queries, convert_to_tensor=True).to(device)
query_embeddings = util.normalize_embeddings(query_embeddings)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [10]:
hits = util.semantic_search(query_embeddings, corpus_embeddings, score_function=util.dot_score, top_k=1)

In [11]:
results = {}
for i, hit in enumerate(hits):
    results[queries_ids[i]] = corpus_ids[hit[0]['corpus_id']]

In [13]:
# A context manager guarantees the file is flushed and closed
with open('task1.pred', 'w', encoding='utf-8') as f:
    json.dump(results, f)