## Requirements

In [2]:
!git clone https://github.com/IlyaGusev/purano
%cd purano
!pip install -r requirements.txt -q
!pip install --upgrade pytorch-lightning transformers razdel

Cloning into 'purano'...
remote: Enumerating objects: 1132, done.[K
remote: Counting objects: 100% (1/1), done.[K
remote: Total 1132 (delta 0), reused 1 (delta 0), pack-reused 1131[K
Receiving objects: 100% (1132/1132), 1.32 MiB | 13.65 MiB/s, done.
Resolving deltas: 100% (679/679), done.
/content/purano
[K     |████████████████████████████████| 20.1 MB 5.2 MB/s 
[K     |████████████████████████████████| 1.6 MB 50.0 MB/s 
[K     |████████████████████████████████| 316 kB 57.7 MB/s 
[K     |████████████████████████████████| 584 kB 55.5 MB/s 
[K     |████████████████████████████████| 125 kB 57.1 MB/s 
[K     |████████████████████████████████| 4.0 MB 53.9 MB/s 
[K     |████████████████████████████████| 592 kB 58.9 MB/s 
[K     |████████████████████████████████| 16.6 MB 26.2 MB/s 
[K     |████████████████████████████████| 68 kB 6.3 MB/s 
[K     |████████████████████████████████| 49 kB 5.0 MB/s 
[K     |████████████████████████████████| 55 kB 3.4 MB/s 
[K     |████████████████

In [3]:
import json
import os
import random
import pickle
from collections import defaultdict, Counter
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from pytorch_lightning import LightningModule
from transformers import AutoModel, AdamW, AutoTokenizer
import torchvision
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader, RandomSampler
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from tqdm import tqdm
from statistics import median, mean
from sklearn.cluster import AgglomerativeClustering
from torch.optim.lr_scheduler import ReduceLROnPlateau
from purano.clusterer.metrics import calc_metrics
from purano.io.markup_tsv import read_markup_tsv



INITIAL_MODEL = "IlyaGusev/news_tg_rubert"
TARGET_SIZE = 256
MAX_TOKENS = 250
BATCH_SIZE = 12
NUM_WORKERS = 2

In [3]:
!nvidia-smi

Sun May  8 02:16:48 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   51C    P0    29W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Data loading

In [7]:
!wget -q https://www.dropbox.com/s/iauipxcpsuwjw6o/documents.tar.gz
!wget -q https://www.dropbox.com/s/8lu6dw8zcrn840j/ru_clustering_0525_urls.tsv
!wget -q https://www.dropbox.com/s/3yh5ii20ijfbtb6/ru_clustering_0527_urls_final.tsv
!wget -q https://raw.githubusercontent.com/dialogue-evaluation/Russian-News-Clustering-and-Headline-Generation/main/data/clustering/ru_clustering_0529_urls_final_v2.tsv
!tar -xzvf documents.tar.gz

train.jsonl
public.jsonl
private.jsonl


In [5]:
!wget -q https://www.dropbox.com/s/ykqk49a8avlmnaf/ru_all_split.tar.gz
!tar -xzvf ru_all_split.tar.gz

tar: Ignoring unknown extended header keyword 'LIBARCHIVE.creationtime'
tar: Ignoring unknown extended header keyword 'SCHILY.dev'
tar: Ignoring unknown extended header keyword 'SCHILY.ino'
tar: Ignoring unknown extended header keyword 'SCHILY.nlink'
ru_all_train.jsonl
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.creationtime'
tar: Ignoring unknown extended header keyword 'SCHILY.dev'
tar: Ignoring unknown extended header keyword 'SCHILY.ino'
tar: Ignoring unknown extended header keyword 'SCHILY.nlink'
ru_all_val.jsonl
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.creationtime'
tar: Ignoring unknown extended header keyword 'SCHILY.dev'
tar: Ignoring unknown extended header keyword 'SCHILY.ino'
tar: Ignoring unknown extended header keyword 'SCHILY.nlink'
ru_all_test.jsonl


In [None]:
#Getting an extended news dataset

documents = dict()

for news in ['ru_all_train.jsonl', 'ru_all_val.jsonl', 'ru_all_test.jsonl']:
  with open(news, "r") as f:
    for line in f:
        record = json.loads(line)
        documents[record["url"]] = record

test = dict()
val_texts = []

for news in ['public.jsonl', 'private.jsonl']:
  with open(news, "r") as r:
    for line in r:
        record = json.loads(line)
        test[record["url"]] = record

print(f'Total number of news: {len(documents)}, number of test news {len(test)}')
documents = {k:v for k,v in documents.items() if k not in test}
print(f'Total train: {len(documents)}')

documents = list(documents.values())

Total number of news: 684685, number of test news 39176
Total train: 645915


In [None]:
random.shuffle(documents)
border = int(len(documents) * 0.8)
full_train_records = documents[:border]
full_val_records = documents[border:]

with open('full_train_records', 'wb') as pickle_file:
    pickle.dump(full_train_records, pickle_file)

with open('full_val_records', 'wb') as pickle_file:
    pickle.dump(full_val_records, pickle_file)

In [None]:
#Cheking the quality of split

public_set = []
with open("public.jsonl", "r") as r:
    for line in r:
        public_set.append(json.loads(line))

train = {}
for i in range(len(full_train_records)):
  train[full_train_records[i]['url']] = full_train_records[i]['text']
  
val = {}
for i in range(len(full_val_records)):
  val[full_val_records[i]['url']] = full_val_records[i]['text']

public = {}
for i in range(len(public_set)):
  public[public_set[i]['url']] = public_set[i]['text']

print(len(public.keys() & train.keys()))
print(len(public.keys() & val.keys()))

0
0


In [None]:
def form_triplets(docs_filename, pairs_filename):
    documents = dict()
    with open(docs_filename, "r") as f:
        for line in f:
            record = json.loads(line)
            documents[record["url"]] = record
        
    positives = defaultdict(list)
    negatives = defaultdict(list)
    with open(pairs_filename, "r") as f:
        header = next(f)
        for line in f:
            first_url, second_url, label = line.strip().split("\t")
            if "OK" in label:
                positives[first_url].append(second_url)
                positives[second_url].append(first_url)
            else:
                negatives[first_url].append(second_url)
                negatives[second_url].append(first_url)

    urls = list(documents.keys())
    records = []
    for pivot_url, positive_urls in positives.items():
        negative_urls = []
        if pivot_url in negatives:
            negative_urls = negatives.pop(pivot_url)
        if negative_urls:
            for positive_url in positive_urls:
                for negative_url in negative_urls:
                    records.append((documents[pivot_url], documents[positive_url], documents[negative_url]))
    for pivot_url, negative_urls in negatives.items():
        for negative_url in negative_urls:
            records.append((documents[pivot_url], documents[pivot_url], documents[negative_url]))
    return records

triplets = form_triplets("train.jsonl", "ru_clustering_0525_urls.tsv")
random.shuffle(triplets)
border = int(len(triplets) * 0.8)
train_records = triplets[:border]
val_records = triplets[border:]


with open('train_records', 'wb') as pickle_file:
    pickle.dump(train_records, pickle_file)

with open('val_records', 'wb') as pickle_file:
    pickle.dump(val_records, pickle_file)

In [2]:
class NewsDataset(Dataset):
    def __init__(self, records, model_path, max_tokens, bert_embeds=None):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            do_lower_case=False,
            do_basic_tokenize=False,
            strip_accents=False
        )
        
        self.max_tokens = max_tokens
        self.records = records
        self.bert_embeds = bert_embeds
    
    def __len__(self):
        return len(self.records)
    
    def __getitem__(self, index):
        samples = self.records[index]
      
        if type(samples) is tuple:
            samples = [s["title"] + " [SEP] " + s["text"] for s in samples]
            samples = [self.tokenizer(
                s,
               add_special_tokens=True,
               max_length=self.max_tokens,
               padding='max_length',
               truncation=True,
               return_tensors='pt'
               ) for s in samples]
            samples = [{key: value.squeeze(0) for key, value in s.items()} for s in samples]
            
            if self.bert_embeds is not None:
                bert_embeds = self.bert_embeds[index]
                return samples[0], samples[1], samples[2], bert_embeds[0], bert_embeds[1], bert_embeds[2]
            return samples[0], samples[1], samples[2]

        else:
            samples = samples["title"] + " [SEP] " + samples["text"]
            samples = self.tokenizer(
                samples,
                add_special_tokens=True,
                max_length=self.max_tokens,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
                ) 

           
            samples = {key: value.squeeze(0) for key, value in samples.items()}
            
            if self.bert_embeds is not None:
               bert_embeds = self.bert_embeds[index]
               return samples, bert_embeds
            return samples

In [8]:
def get_loaders( 
    train_records, 
    val_records, 
    tokenizer_model, 
    max_tokens, 
    batch_size, 
    num_workers=2,               
    train_bert_embeds=None, 
    val_bert_embeds=None, 
    val_batch_size=None, 
    ):

  train_data = NewsDataset(
      records=train_records, model_path=tokenizer_model, 
      max_tokens=max_tokens, bert_embeds=train_bert_embeds, 
      )
  train_sampler = RandomSampler(train_data)
  train_loader = DataLoader(
      train_data, batch_size=batch_size, 
      num_workers=num_workers, sampler=train_sampler,
      )
  if val_batch_size is None:
    val_batch_size = batch_size
  val_data = NewsDataset(
      records=val_records, model_path=tokenizer_model, 
      max_tokens=max_tokens, bert_embeds=val_bert_embeds, 
      )
  val_loader = DataLoader(
      val_data, batch_size=val_batch_size, 
      num_workers=num_workers, 
      )
  
  return train_loader, val_loader, train_data.tokenizer

In [41]:
initial_train_loader, initial_val_loader, initial_tokenizer = get_loaders(
                                       train_records, val_records, 
                                       INITIAL_MODEL, MAX_TOKENS, 
                                       BATCH_SIZE, NUM_WORKERS,
                                       )

# Getting embeds from finetuned model




In [25]:
class Embedder(nn.Module):
    def __init__(self, model_path, freeze_bert, layer_num):
        super().__init__()

        self.model = AutoModel.from_pretrained(model_path)
        self.model.trainable = not freeze_bert
        self.bert_dim = self.model.config.hidden_size
        self.layer_num = layer_num
    
    def forward(self, input_ids, attention_mask):
        output = self.model(
           input_ids,
           attention_mask=attention_mask,
           return_dict=True,
           output_hidden_states=True
        )
        layer_embeddings = output.hidden_states[self.layer_num]
        embeddings = self.aggregate(layer_embeddings, attention_mask)
        norm = embeddings.norm(p=2, dim=1, keepdim=True)
        embeddings = embeddings.div(norm)
        return embeddings
    
    def aggregate(self, layer_embeddings, mask):
        raise NotImplementedError()
        

class MeanEmbedder(Embedder):
    def __init__(self, model_path, freeze_bert, layer_num, hidden_dim, use_masking=True):
        super().__init__(model_path, freeze_bert, layer_num)

        self.token_mapping = nn.Linear(self.bert_dim, hidden_dim)
        self.use_masking = use_masking

    def aggregate(self, layer_embeddings, mask):
        embeddings = self.token_mapping(layer_embeddings)
        if self.use_masking:
            expanded_mask = mask.unsqueeze(-1).expand(embeddings.size()).float()
            sum_embeddings = torch.sum(embeddings * expanded_mask, 1)
            sum_mask = torch.clamp(expanded_mask.sum(1), min=1e-9)
            return sum_embeddings / sum_mask
        return torch.mean(embeddings, dim=1)


class ClusteringTripletModel(LightningModule):
    def __init__(self, model_path, num_training_steps=None,
                 hidden_dim=TARGET_SIZE, freeze_bert=False,
                 layer_num=-1, margin=0.5, lr=1e-5):
        super().__init__()

        self.embedder = MeanEmbedder(
            model_path,
            freeze_bert=freeze_bert,
            layer_num=layer_num,
            hidden_dim=hidden_dim
        )

        self.triplet_loss = nn.TripletMarginWithDistanceLoss(
            margin=margin,
            distance_function=nn.PairwiseDistance(p=2)
        )

        self.lr = lr
        self.num_training_steps = num_training_steps

    def forward(self, pivots, positives, negatives):
        pivot_embeddings = self.embedder(pivots["input_ids"], pivots["attention_mask"])
        positive_embeddings = self.embedder(positives["input_ids"], positives["attention_mask"])
        negative_embeddings = self.embedder(negatives["input_ids"], negatives["attention_mask"])
        loss = self.triplet_loss(pivot_embeddings, positive_embeddings, negative_embeddings)
        return loss

    def training_step(self, batch, batch_nb):
        train_loss = self(*batch)
        return train_loss

    def validation_step(self, batch, batch_nb):
        val_loss = self(*batch)
        self.log("val_loss", val_loss, prog_bar=True, logger=True)
        return val_loss

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.lr)
        return [optimizer]

In [5]:
EPOCHS = 10
ACCUMULATE_GRAD_BATCHES = 8
LOG_EVERY_N_STEPS = 10

In [None]:
model = ClusteringTripletModel(INITIAL_MODEL)
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.0001,
    patience=3,
    verbose=True,
    mode="min" 
)

checkpoint = ModelCheckpoint(
    monitor='val_loss',
    dirpath='/content',
    filename='clustering_news_bert'
)


trainer = Trainer(
    gpus=1,
    accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES,
    max_epochs=EPOCHS,
    callbacks=[early_stop_callback, checkpoint],
    log_every_n_steps=LOG_EVERY_N_STEPS
)


trainer.fit(model, initial_train_loader, initial_val_loader)

In [42]:
initial_model = ClusteringTripletModel.load_from_checkpoint(
    model_path = INITIAL_MODEL,
    checkpoint_path = '/content/clustering_news_bert.ckpt',
    )

embedder = initial_model.embedder.cuda()

Some weights of the model checkpoint at IlyaGusev/news_tg_rubert were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at IlyaGusev/news_tg_rubert and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.we

In [14]:
def gen_batch(records, batch_size):
    batch_start = 0
    while batch_start < len(records):
        batch_end = batch_start + batch_size
        batch = records[batch_start: batch_end]
        batch_start = batch_end
        yield batch

def records_to_embeds(
              records, model, tokenizer, 
              batch_size, max_tokens_count,
              print_mean_timing=False,
              ):
    current_index = 0
    try:
      embeddings = np.zeros((len(records), model.target_size))
    except:
      embeddings = np.zeros((len(records), 256))
 
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    time = 0

    for batch in gen_batch(records, batch_size):
        samples = [r["title"] + " [SEP] " + r["text"] for r in batch]
        
        inputs = tokenizer(
              samples,
              add_special_tokens=True,
              max_length=max_tokens_count,
              padding='max_length',
              truncation=True,
              return_tensors='pt'
        )
           
        batch_input_ids = inputs["input_ids"].cuda()
        batch_mask = inputs["attention_mask"].cuda()
        
        with torch.no_grad():
            model.eval().cuda()
            starter.record()
            batch_embeddings = model(batch_input_ids, batch_mask)
            ender.record()
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender)
            batch_embeddings = batch_embeddings.cpu().numpy()
        embeddings[current_index:current_index+batch_size, :] = batch_embeddings
        current_index += batch_size
        time += curr_time

    
    if print_mean_timing is True:
      print('Mean inference time per embed: ', time/len(records), 'ms')
    return embeddings

In [24]:
def gen_bert_embeddings(records, model, 
                        tokenizer, 
                        batch_size,
                        max_tokens_count,
                        ):
  pivot = records_to_embeds([r[0] for r in records], embedder, 
                                           tokenizer, batch_size=8, 
                                           max_tokens_count=MAX_TOKENS,
                            )
  pos = records_to_embeds([r[1] for r in records], embedder, 
                                           tokenizer, batch_size=8, 
                                           max_tokens_count=MAX_TOKENS,
                           )
  neg = records_to_embeds([r[2] for r in records], embedder, 
                                           tokenizer, batch_size=8, 
                                           max_tokens_count=MAX_TOKENS,
                           )
  embeddings_ = np.append(pivot, pos, axis=1)
  embeddings_bert = np.append(embeddings_, neg, axis=1)
  embeddings_bert = embeddings_bert.reshape(-1, 3, TARGET_SIZE)
  return embeddings_bert

In [None]:
val_embeddings_bert = gen_bert_embeddings(val_records, embedder, 
                                          initial_tokenizer, batch_size=8, 
                                          max_tokens_count=MAX_TOKENS)
train_embeddings_bert = gen_bert_embeddings(train_records, embedder, 
                                            initial_tokenizer, batch_size=8, 
                                            max_tokens_count=MAX_TOKENS)

In [None]:
with open('train_embeddings_bert', 'wb') as pickle_file:
    pickle.dump(train_embeddings_bert, pickle_file)

with open('val_embeddings_bert', 'wb') as pickle_file:
    pickle.dump(val_embeddings_bert, pickle_file)

/content/drive/MyDrive/NewsBert


In [None]:
single_full_train_embeddings_bert = records_to_embeds(full_train_records, embedder, 
                                            initial_tokenizer, batch_size=8, 
                                            max_tokens_count=MAX_TOKENS)

with open('single_full_train_embeddings_bert', 'wb') as pickle_file:
    pickle.dump(single_full_train_embeddings_bert, pickle_file)

In [None]:
single_full_val_embeddings_bert = records_to_embeds(full_val_records, embedder, 
                                            initial_tokenizer, batch_size=8, 
                                            max_tokens_count=MAX_TOKENS)


with open('single_full_val_embeddings_bert', 'wb') as pickle_file:
    pickle.dump(single_full_val_embeddings_bert, pickle_file)

# Evaluation of initial model

In [None]:
def get_quality(markup, embeds, records, dist_threshold, print_result=False):
    clustering_model = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=dist_threshold,
        linkage="average",
        affinity="cosine"
    )

    clustering_model.fit(embeds)
    labels = clustering_model.labels_
    
    idx2url = dict()
    url2record = dict()
    for i, record in enumerate(records):
        idx2url[i] = record["url"]
        url2record[record["url"]] = record

    url2label = dict()
    for i, label in enumerate(labels):
        url2label[idx2url[i]] = label
        
    metrics = calc_metrics(markup, url2record, url2label)[0]
    if print_result:
        print()
        print("Accuracy: {:.1f}".format(metrics["accuracy"] * 100.0))
        print("Positives Recall: {:.1f}".format(metrics["1"]["recall"] * 100.0))
        print("Positives Precision: {:.1f}".format(metrics["1"]["precision"] * 100.0))
        print("Positives F1: {:.1f}".format(metrics["1"]["f1-score"] * 100.0))
        print("Distance: ", dist_threshold)
        sizes = list(Counter(labels).values())
        print("Max cluster size: ", max(sizes))
        print("Median cluster size: ", median(sizes))
        print("Avg cluster size: {:.2f}".format(mean(sizes)))
        return
    return metrics["1"]["f1-score"]

In [43]:
public_markup = read_markup_tsv("ru_clustering_0527_urls_final.tsv")

public_embeddings = records_to_embeds(public_set, embedder, initial_tokenizer, 
                                      batch_size=8, max_tokens_count=MAX_TOKENS,
                                      print_mean_timing=True)

get_quality(public_markup, public_embeddings, public_set, 0.38, True)

Mean inference time per embed:  7.8254074972464265 ms

Accuracy: 95.0
Positives Recall: 95.0
Positives Precision: 94.1
Positives F1: 94.6
Distance:  0.38
Max cluster size:  255
Median cluster size:  2.0
Avg cluster size: 3.23


In [44]:
private_set = []
with open("private.jsonl", "r") as r:
    for line in r:
        private_set.append(json.loads(line))


private_markup = read_markup_tsv("ru_clustering_0529_urls_final_v2.tsv")

private_embeddings = records_to_embeds(private_set, embedder, initial_tokenizer, 
                                      batch_size=8, max_tokens_count=MAX_TOKENS,
                                      print_mean_timing=True)

get_quality(private_markup, private_embeddings, private_set, 0.38, True)

Mean inference time per embed:  7.825124706241143 ms

Accuracy: 94.9
Positives Recall: 95.7
Positives Precision: 93.4
Positives F1: 94.5
Distance:  0.38
Max cluster size:  202
Median cluster size:  1
Avg cluster size: 3.14


In [None]:
with open('public_set', 'wb') as pickle_file:
    pickle.dump(public_set, pickle_file)

In [None]:
with open('private_set', 'wb') as pickle_file:
    pickle.dump(private_set, pickle_file)

#Distillation - Search for Best Score



In [7]:
class DistillEmbedder(nn.Module):
  def __init__(self, vocab_size, 
               word_emb_dim=128, 
               rnn_hidden_dim=128,
               rnn_layers_count=2,
               target_size=TARGET_SIZE, 
               to_gru=False,
               attentive_aggregation=False,
               ):

        super().__init__()

        self.target_size = TARGET_SIZE
        self.embedding = nn.Embedding(vocab_size, word_emb_dim)
        self.attentive_aggregation = attentive_aggregation
        
        if not to_gru:
          self.rnn = nn.LSTM(
            word_emb_dim, 
            rnn_hidden_dim, 
            num_layers=rnn_layers_count,
            bidirectional=True, 
            batch_first=True,
            )
        
        else:
          self.rnn = nn.GRU(
            word_emb_dim, 
            rnn_hidden_dim, 
            num_layers=rnn_layers_count,
            bidirectional=True, 
            batch_first=True,
            )


        self.linear = nn.Linear(rnn_hidden_dim*2, target_size)

        if attentive_aggregation:
          self.softmax = nn.Softmax(dim=1)     
          self.attn = nn.Sequential(
                nn.Linear(rnn_hidden_dim*2, rnn_hidden_dim),
                nn.ReLU(),
                nn.Linear(rnn_hidden_dim, 1)
            )
        

  def avg_aggregate(self, rnn_output, mask):
        sum_embeddings = torch.sum(rnn_output, 1)
        expanded_mask = mask.unsqueeze(-1).expand(rnn_output.size()).float()
        sum_mask = torch.clamp(expanded_mask.sum(1), min=1e-9).cuda()
        return sum_embeddings / sum_mask 
          
  
  def attentive_aggregate(self, rnn_output, mask):
        weights = self.softmax(self.attn(rnn_output).squeeze(-1)) * mask.cuda()
        embeddings = weights.unsqueeze(1).bmm(rnn_output).squeeze(1) 
        return embeddings       


  def forward(self, x, mask):
        lens = torch.sum(mask.cpu(), 1)
        x = self.embedding(x)
        #pack_padded_sequence before feeding into rnn
        x = pack_padded_sequence(x, lens, enforce_sorted=False, batch_first=True)
        x, _ = self.rnn(x)
        #pad_packed_sequence on our packed rnn output
        x, _ = pad_packed_sequence(x, batch_first=True, total_length=MAX_TOKENS)
        if self.attentive_aggregation:
          x = self.attentive_aggregate(x, mask)
        else:
          x = self.avg_aggregate(x, mask)
        x = self.linear(x)
        return x

In [8]:
class DistillClusteringModel(nn.Module):
    def __init__(self, vocab_size=119547,  
                 word_emb_dim=128, 
                 rnn_hidden_dim=128,
                 rnn_layers_count=2,
                 target_size=TARGET_SIZE, 
                 lr=1e-3, 
                 to_gru=False,
                 attentive_aggregation=False,
                 evaluate_cos_similarity=True,
                 ):
      
        super().__init__()

        self.embedder = DistillEmbedder(vocab_size=vocab_size, 
                                        word_emb_dim=word_emb_dim, 
                                        rnn_hidden_dim=rnn_hidden_dim, 
                                        rnn_layers_count=rnn_layers_count,
                                        target_size=target_size,
                                        to_gru=to_gru,
                                        attentive_aggregation=attentive_aggregation,
                                        )

        self.evaluate_cos_similarity = evaluate_cos_similarity
        if evaluate_cos_similarity:
          self.cosine_similarity = torch.nn.functional.cosine_similarity
        else:
          self.mse = torch.nn.MSELoss()

        self.lr = lr

    def forward(self, news):
        embeddings = self.embedder(news["input_ids"].cuda(), news['attention_mask'].cuda())
        return embeddings

    def loss(self, embeds, bert_embeds):
        if self.evaluate_cos_similarity:
          similarity = self.cosine_similarity(embeds, bert_embeds)
          loss = torch.mean(torch.ones(len(similarity)).cuda() - similarity)
        else: 
          loss = self.mse(embeds.float(), bert_embeds.float())
        return loss

In [9]:
def epoch_train(model, data_loader, optimizer, scheduler):
    train_loss = 0
    model.cuda().train()
    for i, samples in enumerate(tqdm(data_loader, desc='Train')):
       model.zero_grad()
       
       bert_embeds = samples[1].cuda()
       embeds = model(samples[0])

       loss = model.loss(embeds, bert_embeds)
       loss.backward()
       optimizer.step()
       train_loss += float(loss)
    loss = train_loss / len(data_loader)
    scheduler.step(loss)
    print('Train_loss: ', loss)
    return loss


def epoch_val(model, data_loader):
    val_loss = 0
    with torch.no_grad():  
      model.cuda().eval()
      for i, samples in enumerate(tqdm(data_loader, desc='Val')):
        bert_embeds = samples[1].cuda()
        embeds = model(samples[0])

        loss = model.loss(embeds, bert_embeds)
        val_loss += float(loss)
    loss = val_loss / len(data_loader)
    print('Val_loss: ', loss)
    return loss


def train(model, train_data_loader,
          val_data_loader, num_train_epochs, 
          output_dir, optimizer=torch.optim.Adam, 
          patience=3, scheduler_patience=2, 
          ):
      optimizer = optimizer(model.parameters(), lr=model.lr)
      scheduler = ReduceLROnPlateau(optimizer, patience=scheduler_patience)
      writer = SummaryWriter(output_dir)
      best_val_loss = 100000
      best_count = 0
      for epoch in range(num_train_epochs):
          train_loss = epoch_train(model, train_data_loader, optimizer, scheduler)
          val_loss = epoch_val(model, val_data_loader)
          writer.add_scalar('Loss/train', train_loss, epoch)
          writer.add_scalar('Loss/val', val_loss, epoch)
          writer.add_scalar('Learning rate', scheduler._last_lr[0], epoch)
          best_count += 1
         
          if val_loss < best_val_loss:
              best_val_loss = val_loss
              best_count = 0
              torch.save(model.state_dict(), os.path.join(output_dir, f'best-distill-bert.pt'))
          
          if best_count == patience:
              torch.save(model.state_dict(), os.path.join(output_dir, 'last_epoch.pt'))
              writer.close()
              break

      writer.close()
      torch.save(model.state_dict(), os.path.join(output_dir, 'last_epoch.pt'))

In [10]:
BATCH_SIZE = 128


train_loader, val_loader, tokenizer = get_loaders(full_train_records, full_val_records, 
                                       INITIAL_MODEL, MAX_TOKENS, BATCH_SIZE,
                                       single_full_train_embeddings_bert,
                                       single_full_val_embeddings_bert,
                                       )

##LSTM (2 layers) - Cosine similarity - Avg aggregation

In [None]:
lstm_distill_cos = DistillClusteringModel()

train(lstm_distill_cos, train_loader, val_loader, 
      60, '/content/cos/avg_lstm',
      )

In [104]:
lstm_distill_cos = DistillClusteringModel()

lstm_distill_cos.load_state_dict(
    torch.load('/content/cos/avg_lstm/best-distill-bert.pt'))

distill_embedder = lstm_distill_cos.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, 
                                              batch_size=8, max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  1.3029242417964328 ms

Accuracy: 94.1
Positives Recall: 95.2
Positives Precision: 92.3
Positives F1: 93.7
Distance:  0.38
Max cluster size:  309
Median cluster size:  2.0
Avg cluster size: 3.94


In [105]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  1.1816775513634203 ms

Accuracy: 93.8
Positives Recall: 95.5
Positives Precision: 91.4
Positives F1: 93.4
Distance:  0.38
Max cluster size:  175
Median cluster size:  2.0
Avg cluster size: 3.78


##GRU (2 layers) - Cosine similarity - Avg aggregation

In [None]:
gru_distill_cos = DistillClusteringModel(to_gru=True)

train(gru_distill_cos, train_loader, val_loader, 
      60, '/content/cos/avg_gru',
      )

In [118]:
gru_distill_cos = DistillClusteringModel(to_gru=True)

gru_distill_cos.load_state_dict(
    torch.load('/content/cos/avg_gru/best-distill-bert.pt'))

distill_embedder = gru_distill_cos.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, 
                                              batch_size=8, max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  1.1764846629593002 ms

Accuracy: 94.1
Positives Recall: 94.9
Positives Precision: 92.6
Positives F1: 93.7
Distance:  0.38
Max cluster size:  299
Median cluster size:  2
Avg cluster size: 3.87


In [119]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  1.1773488510838601 ms

Accuracy: 93.5
Positives Recall: 95.1
Positives Precision: 91.2
Positives F1: 93.1
Distance:  0.38
Max cluster size:  168
Median cluster size:  2.0
Avg cluster size: 3.76


##GRU (1 recurrent layer) - Cosine similarity - Avg aggregation

In [None]:
gru_1layer_distill_cos = DistillClusteringModel(to_gru=True, rnn_layers_count=1)

train(gru_1layer_distill_cos, train_loader, val_loader, 
      60, '/content/cos/avg_gru_1layer',
      )

In [120]:
gru_1layer_distill_cos = DistillClusteringModel(to_gru=True, rnn_layers_count=1)

gru_1layer_distill_cos.load_state_dict(
    torch.load('/content/cos/avg_gru_1layer/best-distill-bert.pt'))

distill_embedder = gru_1layer_distill_cos.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, 
                                              batch_size=8, max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.6425558774357298 ms

Accuracy: 93.7
Positives Recall: 93.8
Positives Precision: 92.7
Positives F1: 93.3
Distance:  0.38
Max cluster size:  267
Median cluster size:  2
Avg cluster size: 3.61


In [121]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.7492244192931659 ms

Accuracy: 93.2
Positives Recall: 93.9
Positives Precision: 91.5
Positives F1: 92.7
Distance:  0.38
Max cluster size:  162
Median cluster size:  1.0
Avg cluster size: 3.49


##LSTM (2 layers) - MSE - Avg aggregation

In [None]:
lstm_distill_mse = DistillClusteringModel(evaluate_cos_similarity=False)

train(lstm_distill_mse, train_loader, val_loader, 
      60, '/content/mse/avg_lstm',
      )

Train: 100%|██████████| 4037/4037 [07:08<00:00,  9.42it/s]


Train_loss:  0.002570586533954467


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.99it/s]


Val_loss:  0.00202639509133054


Train: 100%|██████████| 4037/4037 [07:11<00:00,  9.35it/s]


Train_loss:  0.001748800347034319


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.24it/s]


Val_loss:  0.0015833360637482808


Train: 100%|██████████| 4037/4037 [07:10<00:00,  9.38it/s]


Train_loss:  0.0014395087583999595


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.12it/s]


Val_loss:  0.0013801177841436007


Train: 100%|██████████| 4037/4037 [07:13<00:00,  9.32it/s]


Train_loss:  0.0012766686059947035


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.28it/s]


Val_loss:  0.001264444411510952


Train: 100%|██████████| 4037/4037 [07:12<00:00,  9.34it/s]


Train_loss:  0.0011727956064188


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.22it/s]


Val_loss:  0.0011863792032731863


Train: 100%|██████████| 4037/4037 [07:09<00:00,  9.39it/s]


Train_loss:  0.0010993219216789978


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.43it/s]


Val_loss:  0.001130192329309475


Train: 100%|██████████| 4037/4037 [07:05<00:00,  9.48it/s]


Train_loss:  0.0010441806003021559


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.51it/s]


Val_loss:  0.0010881999153113638


Train: 100%|██████████| 4037/4037 [07:04<00:00,  9.50it/s]


Train_loss:  0.0010004175652673477


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.37it/s]


Val_loss:  0.0010551928472820448


Train: 100%|██████████| 4037/4037 [07:04<00:00,  9.51it/s]


Train_loss:  0.0009647131793003494


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.52it/s]


Val_loss:  0.0010296236757087603


Train: 100%|██████████| 4037/4037 [07:05<00:00,  9.49it/s]


Train_loss:  0.0009353351497255379


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.47it/s]


Val_loss:  0.0010065250130978846


Train: 100%|██████████| 4037/4037 [07:04<00:00,  9.52it/s]


Train_loss:  0.0009101479448821485


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.56it/s]


Val_loss:  0.000986877177948003


Train: 100%|██████████| 4037/4037 [07:03<00:00,  9.52it/s]


Train_loss:  0.0008883656087708201


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.43it/s]


Val_loss:  0.0009707469479319188


Train: 100%|██████████| 4037/4037 [07:05<00:00,  9.50it/s]


Train_loss:  0.0008692534832688266


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.41it/s]


Val_loss:  0.0009574761719092506


Train: 100%|██████████| 4037/4037 [07:05<00:00,  9.49it/s]


Train_loss:  0.0008522183232351764


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.52it/s]


Val_loss:  0.0009476382703760104


Train: 100%|██████████| 4037/4037 [07:03<00:00,  9.53it/s]


Train_loss:  0.0008370717954111339


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.57it/s]


Val_loss:  0.0009366313505590442


Train: 100%|██████████| 4037/4037 [07:07<00:00,  9.45it/s]


Train_loss:  0.0008232229862063163


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.44it/s]


Val_loss:  0.0009279817277634756


Train: 100%|██████████| 4037/4037 [07:08<00:00,  9.43it/s]


Train_loss:  0.000810673090110558


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.43it/s]


Val_loss:  0.0009182029387609351


Train: 100%|██████████| 4037/4037 [07:11<00:00,  9.35it/s]


Train_loss:  0.0007991256650951583


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.20it/s]


Val_loss:  0.0009124391000501185


Train: 100%|██████████| 4037/4037 [07:17<00:00,  9.23it/s]


Train_loss:  0.0007885336093248138


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.30it/s]


Val_loss:  0.0009052927485934579


Train: 100%|██████████| 4037/4037 [07:20<00:00,  9.16it/s]


Train_loss:  0.0007785942484850767


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.27it/s]


Val_loss:  0.0009024851330749617


Train: 100%|██████████| 4037/4037 [07:19<00:00,  9.19it/s]


Train_loss:  0.000769422641040719


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.19it/s]


Val_loss:  0.000895742390542245


Train: 100%|██████████| 4037/4037 [07:19<00:00,  9.19it/s]


Train_loss:  0.0007607941819830819


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.19it/s]


Val_loss:  0.0008897347768796331


Train: 100%|██████████| 4037/4037 [07:19<00:00,  9.19it/s]


Train_loss:  0.0007526962071159337


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.21it/s]


Val_loss:  0.0008857360250883652


Train: 100%|██████████| 4037/4037 [07:19<00:00,  9.19it/s]


Train_loss:  0.000745073495147132


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.17it/s]


Val_loss:  0.0008815685245623388


Train: 100%|██████████| 4037/4037 [07:21<00:00,  9.15it/s]


Train_loss:  0.000737911124210454


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.21it/s]


Val_loss:  0.0008764648148421162


Train: 100%|██████████| 4037/4037 [07:19<00:00,  9.19it/s]


Train_loss:  0.0007310768108432459


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.22it/s]


Val_loss:  0.0008741659088645794


Train: 100%|██████████| 4037/4037 [07:21<00:00,  9.15it/s]


Train_loss:  0.0007245863594275345


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.24it/s]


Val_loss:  0.0008726207797051324


Train: 100%|██████████| 4037/4037 [07:19<00:00,  9.18it/s]


Train_loss:  0.0007184022097867504


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.27it/s]


Val_loss:  0.0008705868162122546


Train: 100%|██████████| 4037/4037 [07:16<00:00,  9.25it/s]


Train_loss:  0.0007125421686205023


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.28it/s]


Val_loss:  0.0008690762908732759


Train: 100%|██████████| 4037/4037 [07:17<00:00,  9.22it/s]


Train_loss:  0.0007069731394449793


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.21it/s]


Val_loss:  0.0008631079101538526


Train: 100%|██████████| 4037/4037 [07:17<00:00,  9.22it/s]


Train_loss:  0.0007015312311725388


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.14it/s]


Val_loss:  0.0008655372242291117


Train: 100%|██████████| 4037/4037 [07:12<00:00,  9.34it/s]


Train_loss:  0.0006964360878235091


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.24it/s]


Val_loss:  0.0008593138450835048


Train: 100%|██████████| 4037/4037 [07:06<00:00,  9.46it/s]


Train_loss:  0.000691587577153741


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.46it/s]


Val_loss:  0.0008576919494708539


Train: 100%|██████████| 4037/4037 [07:07<00:00,  9.44it/s]


Train_loss:  0.0006867275929566614


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.39it/s]


Val_loss:  0.0008561228333953953


Train: 100%|██████████| 4037/4037 [07:17<00:00,  9.22it/s]


Train_loss:  0.0006822209708588251


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.19it/s]


Val_loss:  0.0008543723716981488


Train: 100%|██████████| 4037/4037 [07:18<00:00,  9.22it/s]


Train_loss:  0.0006777485601010172


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.23it/s]


Val_loss:  0.0008546702564338979


Train: 100%|██████████| 4037/4037 [07:19<00:00,  9.19it/s]


Train_loss:  0.0006735697454397546


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.40it/s]


Val_loss:  0.0008537636224369491


Train: 100%|██████████| 4037/4037 [07:08<00:00,  9.42it/s]


Train_loss:  0.000669489396988754


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.41it/s]


Val_loss:  0.0008528988852769616


Train: 100%|██████████| 4037/4037 [07:05<00:00,  9.48it/s]


Train_loss:  0.0006655706944995382


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.42it/s]


Val_loss:  0.0008505063382147968


Train: 100%|██████████| 4037/4037 [07:05<00:00,  9.49it/s]


Train_loss:  0.0006616485358998161


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.48it/s]


Val_loss:  0.0008493792550352467


Train: 100%|██████████| 4037/4037 [07:06<00:00,  9.47it/s]


Train_loss:  0.0006580627049425106


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.41it/s]


Val_loss:  0.0008492531913037569


Train: 100%|██████████| 4037/4037 [07:06<00:00,  9.47it/s]


Train_loss:  0.0006544523781627105


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.49it/s]


Val_loss:  0.0008469875159645729


Train: 100%|██████████| 4037/4037 [07:05<00:00,  9.48it/s]


Train_loss:  0.0006509019927295182


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.41it/s]


Val_loss:  0.0008467811389738639


Train: 100%|██████████| 4037/4037 [07:06<00:00,  9.47it/s]


Train_loss:  0.0006476026369969999


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.46it/s]


Val_loss:  0.000849428832056093


Train: 100%|██████████| 4037/4037 [07:07<00:00,  9.45it/s]


Train_loss:  0.0006443244395545733


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.45it/s]


Val_loss:  0.000846150559419305


Train: 100%|██████████| 4037/4037 [07:05<00:00,  9.48it/s]


Train_loss:  0.0006412083170976991


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.43it/s]


Val_loss:  0.0008453837437393826


Train: 100%|██████████| 4037/4037 [07:09<00:00,  9.39it/s]


Train_loss:  0.0006380132845897529


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.44it/s]


Val_loss:  0.0008456869644463284


Train: 100%|██████████| 4037/4037 [07:08<00:00,  9.43it/s]


Train_loss:  0.0006350548621429722


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.47it/s]


Val_loss:  0.0008462945987138492


Train: 100%|██████████| 4037/4037 [07:08<00:00,  9.42it/s]


Train_loss:  0.0006320025869585198


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.41it/s]


Val_loss:  0.0008461135931170112


In [None]:
lstm_distill_mse = DistillClusteringModel(evaluate_cos_similarity=False)

lstm_distill_mse.load_state_dict(
    torch.load('/content/mse/avg_lstm/best-distill-bert.pt'))

distill_embedder = lstm_distill_mse.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, 
                                              batch_size=8, max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

In [None]:
lstm_distill_mse = DistillClusteringModel(evaluate_cos_similarity=False)

lstm_distill_mse.load_state_dict(
    torch.load('/content/mse/avg_lstm/best-distill-bert.pt'))

distill_embedder = lstm_distill_mse.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, 
                                              batch_size=8, max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  1.4233783412739576 ms

Accuracy: 94.0
Positives Recall: 95.2
Positives Precision: 92.0
Positives F1: 93.6
Distance:  0.38
Max cluster size:  273
Median cluster size:  2.0
Avg cluster size: 4.03


In [None]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  1.5233920652296036 ms

Accuracy: 94.1
Positives Recall: 95.7
Positives Precision: 91.7
Positives F1: 93.7
Distance:  0.38
Max cluster size:  162
Median cluster size:  2
Avg cluster size: 3.88


##GRU (2 layers) - MSE - Avg aggregation

In [None]:
gru_distill_mse = DistillClusteringModel(to_gru=True, evaluate_cos_similarity=False)

train(gru_distill_mse, train_loader, val_loader, 
      60, '/content/mse/avg_gru',
      )

In [128]:
gru_distill_mse = DistillClusteringModel(to_gru=True, evaluate_cos_similarity=False)

gru_distill_mse.load_state_dict(
    torch.load('/content/mse/avg_gru/best-distill-bert.pt'))

distill_embedder = gru_distill_mse.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, 
                                              batch_size=8, max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  1.1712667855371042 ms

Accuracy: 94.2
Positives Recall: 95.5
Positives Precision: 92.1
Positives F1: 93.8
Distance:  0.38
Max cluster size:  287
Median cluster size:  2
Avg cluster size: 4.30


In [129]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  1.24867987562834 ms

Accuracy: 93.7
Positives Recall: 96.0
Positives Precision: 90.7
Positives F1: 93.3
Distance:  0.38
Max cluster size:  186
Median cluster size:  2.0
Avg cluster size: 4.18


##LSTM (2 layers) - MSE - Attn aggregation

In [10]:
lstm_distill_mse_attn = DistillClusteringModel(attentive_aggregation=True)

train(lstm_distill_mse_attn, train_loader, val_loader, 
      60, '/content/mse/attn_lstm',
      )

Train: 100%|██████████| 4037/4037 [08:00<00:00,  8.40it/s]


Train_loss:  0.0820204289172091


Val: 100%|██████████| 1010/1010 [01:30<00:00, 11.18it/s]


Val_loss:  0.10997983059713477


Train: 100%|██████████| 4037/4037 [07:58<00:00,  8.44it/s]


Train_loss:  0.08116005112435644


Val: 100%|██████████| 1010/1010 [01:30<00:00, 11.11it/s]


Val_loss:  0.10973589121395062


Train: 100%|██████████| 4037/4037 [08:00<00:00,  8.41it/s]


Train_loss:  0.08041593980694511


Val: 100%|██████████| 1010/1010 [01:31<00:00, 11.10it/s]


Val_loss:  0.10970788900277803


Train: 100%|██████████| 4037/4037 [07:55<00:00,  8.49it/s]


Train_loss:  0.07973688762634347


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.49it/s]


Val_loss:  0.10944417862273434


Train: 100%|██████████| 4037/4037 [07:54<00:00,  8.51it/s]


Train_loss:  0.07908325895864764


Val: 100%|██████████| 1010/1010 [01:25<00:00, 11.75it/s]


Val_loss:  0.10934482486576677


Train: 100%|██████████| 4037/4037 [07:53<00:00,  8.52it/s]


Train_loss:  0.07846229386423188


Val: 100%|██████████| 1010/1010 [01:25<00:00, 11.75it/s]


Val_loss:  0.10931413531049271


Train: 100%|██████████| 4037/4037 [07:53<00:00,  8.52it/s]


Train_loss:  0.07787926999560249


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.71it/s]


Val_loss:  0.10907955667297284


Train: 100%|██████████| 4037/4037 [07:53<00:00,  8.52it/s]


Train_loss:  0.07731406948536482


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.66it/s]


Val_loss:  0.10910134725666776


Train: 100%|██████████| 4037/4037 [07:55<00:00,  8.49it/s]


Train_loss:  0.07677285819416793


Val: 100%|██████████| 1010/1010 [01:28<00:00, 11.47it/s]


Val_loss:  0.10912413392248828


Train: 100%|██████████| 4037/4037 [07:53<00:00,  8.52it/s]


Train_loss:  0.07625878304167562


Val: 100%|██████████| 1010/1010 [01:25<00:00, 11.76it/s]


Val_loss:  0.10891143503813006


Train: 100%|██████████| 4037/4037 [07:53<00:00,  8.53it/s]


Train_loss:  0.07576893142188487


Val: 100%|██████████| 1010/1010 [01:25<00:00, 11.77it/s]


Val_loss:  0.10891152821006626


Train: 100%|██████████| 4037/4037 [07:53<00:00,  8.53it/s]


Train_loss:  0.07529711676986878


Val: 100%|██████████| 1010/1010 [01:28<00:00, 11.42it/s]


Val_loss:  0.10887300891602936


Train: 100%|██████████| 4037/4037 [07:55<00:00,  8.49it/s]


Train_loss:  0.07483940706086827


Val: 100%|██████████| 1010/1010 [01:25<00:00, 11.75it/s]


Val_loss:  0.10882266166004421


Train: 100%|██████████| 4037/4037 [07:54<00:00,  8.52it/s]


Train_loss:  0.07440865736987944


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.70it/s]


Val_loss:  0.10885689547500317


Train: 100%|██████████| 4037/4037 [07:53<00:00,  8.53it/s]


Train_loss:  0.07399036991239435


Val: 100%|██████████| 1010/1010 [01:25<00:00, 11.75it/s]


Val_loss:  0.10898106478568795


Train: 100%|██████████| 4037/4037 [07:53<00:00,  8.52it/s]


Train_loss:  0.07358346994243228


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.70it/s]


Val_loss:  0.10884660462154934


In [11]:
lstm_distill_mse_attn = DistillClusteringModel(attentive_aggregation=True)

lstm_distill_mse_attn.load_state_dict(
    torch.load('/content/mse/attn_lstm/best-distill-bert.pt'))

distill_embedder = lstm_distill_mse_attn.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, 
                                              batch_size=8, max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  1.0988437845649948 ms

Accuracy: 94.2
Positives Recall: 95.1
Positives Precision: 92.4
Positives F1: 93.8
Distance:  0.38
Max cluster size:  277
Median cluster size:  2.0
Avg cluster size: 3.98


In [12]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  1.0930923494675393 ms

Accuracy: 94.0
Positives Recall: 95.6
Positives Precision: 91.7
Positives F1: 93.6
Distance:  0.38
Max cluster size:  179
Median cluster size:  2.0
Avg cluster size: 3.85


##GRU (2 layers) - MSE - Attn aggregation

In [24]:
gru_distill_mse_attn = DistillClusteringModel(attentive_aggregation=True, to_gru=True)

train(gru_distill_mse_attn, train_loader, val_loader, 
      60, '/content/mse/attn_gru',
      )

Train: 100%|██████████| 4037/4037 [07:44<00:00,  8.68it/s]


Train_loss:  0.38346861638521096


Val: 100%|██████████| 1010/1010 [01:28<00:00, 11.40it/s]


Val_loss:  0.2718536614567818


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.71it/s]


Train_loss:  0.23289792740879103


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.89it/s]


Val_loss:  0.21241273218327275


Train: 100%|██████████| 4037/4037 [07:39<00:00,  8.79it/s]


Train_loss:  0.18772440682297495


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.94it/s]


Val_loss:  0.18282600627355433


Train: 100%|██████████| 4037/4037 [07:37<00:00,  8.83it/s]


Train_loss:  0.16258095871869271


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.92it/s]


Val_loss:  0.1659092094979339


Train: 100%|██████████| 4037/4037 [07:36<00:00,  8.85it/s]


Train_loss:  0.14642882721626646


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.04it/s]


Val_loss:  0.15378804493890158


Train: 100%|██████████| 4037/4037 [07:38<00:00,  8.80it/s]


Train_loss:  0.1351188850655659


Val: 100%|██████████| 1010/1010 [01:30<00:00, 11.19it/s]


Val_loss:  0.14571240897579746


Train: 100%|██████████| 4037/4037 [07:45<00:00,  8.68it/s]


Train_loss:  0.12674598466229323


Val: 100%|██████████| 1010/1010 [01:28<00:00, 11.42it/s]


Val_loss:  0.14007437163253908


Train: 100%|██████████| 4037/4037 [07:44<00:00,  8.70it/s]


Train_loss:  0.12028858546860559


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.66it/s]


Val_loss:  0.13499191618706688


Train: 100%|██████████| 4037/4037 [07:42<00:00,  8.74it/s]


Train_loss:  0.11512851746971706


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.69it/s]


Val_loss:  0.13148613119516575


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.72it/s]


Train_loss:  0.11087341607723482


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.56it/s]


Val_loss:  0.12871551936402018


Train: 100%|██████████| 4037/4037 [07:42<00:00,  8.74it/s]


Train_loss:  0.10733589165911135


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.71it/s]


Val_loss:  0.1265425533237002


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.71it/s]


Train_loss:  0.1043120141415505


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.52it/s]


Val_loss:  0.12483039902586673


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.71it/s]


Train_loss:  0.10171994897980881


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.60it/s]


Val_loss:  0.12317293915253308


Train: 100%|██████████| 4037/4037 [07:44<00:00,  8.69it/s]


Train_loss:  0.09946132989039361


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.63it/s]


Val_loss:  0.12188089399778287


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.71it/s]


Train_loss:  0.09745801469004982


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.69it/s]


Val_loss:  0.12067582983611633


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.71it/s]


Train_loss:  0.09566676801920053


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.66it/s]


Val_loss:  0.11953163005258145


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.71it/s]


Train_loss:  0.09408559293652499


Val: 100%|██████████| 1010/1010 [01:25<00:00, 11.75it/s]


Val_loss:  0.11872229578502808


Train: 100%|██████████| 4037/4037 [07:44<00:00,  8.69it/s]


Train_loss:  0.09263703223504874


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.60it/s]


Val_loss:  0.11817156242125028


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.72it/s]


Train_loss:  0.09132108572397526


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.67it/s]


Val_loss:  0.11752582854560605


Train: 100%|██████████| 4037/4037 [07:42<00:00,  8.73it/s]


Train_loss:  0.0901173961358303


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.66it/s]


Val_loss:  0.11700959470912578


Train: 100%|██████████| 4037/4037 [07:41<00:00,  8.74it/s]


Train_loss:  0.08901436187605606


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.55it/s]


Val_loss:  0.11637939131153313


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.71it/s]


Train_loss:  0.08799007776433919


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.62it/s]


Val_loss:  0.11591290926980159


Train: 100%|██████████| 4037/4037 [07:42<00:00,  8.73it/s]


Train_loss:  0.08703560130599351


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.64it/s]


Val_loss:  0.11564688890126089


Train: 100%|██████████| 4037/4037 [07:41<00:00,  8.74it/s]


Train_loss:  0.08616170050526203


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.64it/s]


Val_loss:  0.11516840264957055


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.70it/s]


Train_loss:  0.08533550432966369


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.54it/s]


Val_loss:  0.11503452221753639


Train: 100%|██████████| 4037/4037 [07:42<00:00,  8.73it/s]


Train_loss:  0.08456292995216852


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.59it/s]


Val_loss:  0.11459711742844454


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.72it/s]


Train_loss:  0.08384453912700675


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.63it/s]


Val_loss:  0.11453237574295949


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.70it/s]


Train_loss:  0.0831396810061512


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.63it/s]


Val_loss:  0.11436451804834018


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.71it/s]


Train_loss:  0.08249654703604126


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.58it/s]


Val_loss:  0.11427378923141734


Train: 100%|██████████| 4037/4037 [07:42<00:00,  8.72it/s]


Train_loss:  0.081889272427445


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.63it/s]


Val_loss:  0.11401315425895325


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.72it/s]


Train_loss:  0.08129797365753923


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.65it/s]


Val_loss:  0.1138589991695355


Train: 100%|██████████| 4037/4037 [07:42<00:00,  8.72it/s]


Train_loss:  0.08073716112420014


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.61it/s]


Val_loss:  0.11374596855817198


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.71it/s]


Train_loss:  0.08021712822262603


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.58it/s]


Val_loss:  0.1134466825657283


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.70it/s]


Train_loss:  0.07972578012365707


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.66it/s]


Val_loss:  0.11326383834961994


Train: 100%|██████████| 4037/4037 [07:44<00:00,  8.70it/s]


Train_loss:  0.07923993963461877


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.69it/s]


Val_loss:  0.11317114854293484


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.70it/s]


Train_loss:  0.07878997751115886


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.55it/s]


Val_loss:  0.11316507051278671


Train: 100%|██████████| 4037/4037 [07:44<00:00,  8.69it/s]


Train_loss:  0.07834808127868523


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.63it/s]


Val_loss:  0.11321989572181967


Train: 100%|██████████| 4037/4037 [07:44<00:00,  8.69it/s]


Train_loss:  0.07791493813490222


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.55it/s]


Val_loss:  0.11321938218990138


Train: 100%|██████████| 4037/4037 [07:44<00:00,  8.69it/s]


Train_loss:  0.07750743959771624


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.66it/s]


Val_loss:  0.11303251312448376


Train: 100%|██████████| 4037/4037 [07:46<00:00,  8.66it/s]


Train_loss:  0.07712466596549802


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.57it/s]


Val_loss:  0.11285703619771462


Train: 100%|██████████| 4037/4037 [07:45<00:00,  8.67it/s]


Train_loss:  0.07675410421698052


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.53it/s]


Val_loss:  0.11295533263606702


Train: 100%|██████████| 4037/4037 [07:44<00:00,  8.69it/s]


Train_loss:  0.07639352084862863


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.59it/s]


Val_loss:  0.1131591580371787


Train: 100%|██████████| 4037/4037 [07:44<00:00,  8.69it/s]


Train_loss:  0.0760372601270945


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.60it/s]


Val_loss:  0.11301841844224975


In [39]:
gru_distill_mse_attn = DistillClusteringModel(attentive_aggregation=True, to_gru=True)

gru_distill_mse_attn.load_state_dict(
    torch.load('/content/mse/attn_gru/best-distill-bert.pt'))

distill_embedder = gru_distill_mse_attn.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, 
                                              batch_size=8, max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  1.53960982314144 ms

Accuracy: 94.3
Positives Recall: 95.1
Positives Precision: 92.8
Positives F1: 93.9
Distance:  0.38
Max cluster size:  279
Median cluster size:  2
Avg cluster size: 3.93


In [40]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  1.5448059391466005 ms

Accuracy: 94.1
Positives Recall: 95.3
Positives Precision: 92.0
Positives F1: 93.7
Distance:  0.38
Max cluster size:  174
Median cluster size:  2
Avg cluster size: 3.81


##GRU (2 layers) - MSE - Attn aggregation - Token embedding dimension 64

In [None]:
gru_distill_mse_attn = DistillClusteringModel(attentive_aggregation=True, to_gru=True, word_emb_dim=64)

train(gru_distill_mse_attn, train_loader, val_loader, 
      60, '/content/mse/attn_gru/64emb',
      )

In [52]:
gru_distill_mse_attn = DistillClusteringModel(attentive_aggregation=True, to_gru=True, word_emb_dim=64)

gru_distill_mse_attn.load_state_dict(
    torch.load('/content/mse/attn_gru/64emb/best-distill-bert.pt'))

distill_embedder = gru_distill_mse_attn.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, 
                                              batch_size=8, max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  1.5217533059804087 ms

Accuracy: 93.9
Positives Recall: 94.4
Positives Precision: 92.6
Positives F1: 93.5
Distance:  0.38
Max cluster size:  284
Median cluster size:  2
Avg cluster size: 4.01


In [53]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  1.5344844116143816 ms

Accuracy: 93.9
Positives Recall: 95.2
Positives Precision: 91.8
Positives F1: 93.5
Distance:  0.38
Max cluster size:  177
Median cluster size:  2
Avg cluster size: 3.86


#Error analysis

In [11]:
def get_errors(markup, embeds, records, dist_threshold):
    clustering_model = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=dist_threshold,
        linkage="average",
        affinity="cosine"
    )

    clustering_model.fit(embeds)
    labels = clustering_model.labels_
    
    idx2url = dict()
    url2record = dict()
    for i, record in enumerate(records):
        idx2url[i] = record["url"]
        url2record[record["url"]] = record

    url2label = dict()
    for i, label in enumerate(labels):
        url2label[idx2url[i]] = label
        
    errors = calc_metrics(markup, url2record, url2label)[1]
    return errors

In [13]:
gru_distill_mse_attn = DistillClusteringModel(attentive_aggregation=True, to_gru=True)

gru_distill_mse_attn.load_state_dict(
    torch.load('/content/mse/attn_gru/best-distill-bert.pt'))

distill_embedder = gru_distill_mse_attn.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, 
                                              batch_size=8, max_tokens_count=MAX_TOKENS, 
                                              )
errors = get_errors(public_markup, public_distill_embeddings, public_set, 0.38)

In [14]:
errors

[{'first_text': 'Инцидентом заинтересовалась прокуратура. \n       В Московском районе Петербурга днем 27-го мая сгорел гараж. Об этом сообщила пресс-служба ГУ МЧС по Петербургу.  Инцидент произошел на проспекте Космонавтов, 106. Сообщение о пожаре поступило в 13:35. Как стало известно, там полностью выгорел неэксплуатируемый гараж и частично обгорел еще один гараж, который также не используется. В 13:57 пожар ликвидировали. Сведений о пострадавших уточняются.  После происшествия прокуратура организовала проверку. В ведомстве намерены дать оценку ситуации, установить лиц, которые причастны к ситуации, а также установить правомерность нахождения на территории.  Сейчас на участке идет строительство жилого комплекса. Ранее в Сестрорецке  сгорел  нежилой частный дом.  Видео: "ДТП и ЧП Санкт-Петербург"',
  'first_title': 'Видео: на Космонавтов загорелись несколько гаражей',
  'first_url': 'https://piter.tv/event/pozhar_na_Kosmonavtov/',
  'prediction': 1,
  'second_text': 'Огонь полностью у

In [17]:
import pandas as pd
errors_128 = pd.DataFrame(errors)
errors_128.to_csv('/content/mse/attn_gru/errors.csv')

In [18]:
gru_distill_mse_attn = DistillClusteringModel(attentive_aggregation=True, to_gru=True, word_emb_dim=64)

gru_distill_mse_attn.load_state_dict(
    torch.load('/content/mse/attn_gru/64emb/best-distill-bert.pt'))

distill_embedder = gru_distill_mse_attn.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, 
                                              batch_size=8, max_tokens_count=MAX_TOKENS, 
                                              )
errors = get_errors(public_markup, public_distill_embeddings, public_set, 0.38)

In [19]:
errors

[{'first_text': 'Инцидентом заинтересовалась прокуратура. \n       В Московском районе Петербурга днем 27-го мая сгорел гараж. Об этом сообщила пресс-служба ГУ МЧС по Петербургу.  Инцидент произошел на проспекте Космонавтов, 106. Сообщение о пожаре поступило в 13:35. Как стало известно, там полностью выгорел неэксплуатируемый гараж и частично обгорел еще один гараж, который также не используется. В 13:57 пожар ликвидировали. Сведений о пострадавших уточняются.  После происшествия прокуратура организовала проверку. В ведомстве намерены дать оценку ситуации, установить лиц, которые причастны к ситуации, а также установить правомерность нахождения на территории.  Сейчас на участке идет строительство жилого комплекса. Ранее в Сестрорецке  сгорел  нежилой частный дом.  Видео: "ДТП и ЧП Санкт-Петербург"',
  'first_title': 'Видео: на Космонавтов загорелись несколько гаражей',
  'first_url': 'https://piter.tv/event/pozhar_na_Kosmonavtov/',
  'prediction': 1,
  'second_text': 'Огонь полностью у

In [20]:
errors_64 = pd.DataFrame(errors)
errors_64.to_csv('/content/mse/attn_gru/64emb/errors64.csv')

#Distillation Experiments

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


##Train cycle

In [45]:
def epoch_train(model, data_loader, optimizer):
    train_loss = 0
    model.cuda().train()
    for i, samples in enumerate(tqdm(data_loader, desc='Train')):
       model.zero_grad()
       if len(samples) == 6:
          bert_embeds = [samples[3].cuda(), samples[4].cuda(), samples[5].cuda()]
          embeds = model(samples[0], samples[1], samples[2])

       else:
          bert_embeds = samples[1].cuda()
          embeds = model(samples[0])

       loss = model.loss(embeds, bert_embeds)
      
       loss.backward()
       optimizer.step()
       train_loss += float(loss)
    loss = train_loss / len(data_loader)
    print('Train_loss: ', loss)
    return loss


def epoch_val(model, data_loader):
    val_loss = 0
    with torch.no_grad():  
      model.cuda().eval()
      for i, samples in enumerate(tqdm(data_loader, desc='Val')):
        if len(samples) == 6:
            bert_embeds = [samples[3].cuda(), samples[4].cuda(), samples[5].cuda()]
            embeds = model(samples[0], samples[1], samples[2])

        else:
            bert_embeds = samples[1].cuda()
            embeds = model(samples[0])

        loss = model.loss(embeds, bert_embeds)
        val_loss += float(loss)
    loss = val_loss / len(data_loader)
    print('Val_loss: ', loss)
    return loss



def train(model, train_data_loader, 
          val_data_loader, num_train_epochs, 
          output_dir, early_stopping = False, 
          patience=2, 
          optimizer=torch.optim.Adam, 
          ):
      optimizer = optimizer(model.parameters(), lr=model.lr) 
      writer = SummaryWriter(output_dir)
      best_val_loss = 100000
      best_count = 0
      for epoch in range(num_train_epochs):
          train_loss = epoch_train(model, train_data_loader, optimizer)
          val_loss = epoch_val(model, val_data_loader)
          writer.add_scalar('Loss/train', train_loss, epoch)
          writer.add_scalar('Loss/val', val_loss, epoch)
           
          best_count += 1
         
          if val_loss < best_val_loss:
              best_val_loss = val_loss
              best_count = 0
              torch.save(model.state_dict(), os.path.join(output_dir, f'best-distill-bert.pt'))
          
          if early_stopping:
            if best_count == patience:
              torch.save(model.state_dict(), os.path.join(output_dir, 'last_epoch.pt'))
              writer.close()
              break

      writer.close()
      torch.save(model.state_dict(), os.path.join(output_dir, 'last_epoch.pt'))

##--LSTM--

Distillation experiments based on LSTM architecture.

##Constructing Embedder model

In [7]:
class LSTM_Embedder(nn.Module):
  def __init__(self, vocab_size, target_size=TARGET_SIZE, 
               bidirectional=False, pretrained=False, 
               word_emb_dim=128, pretrained_embs=None, 
               rnn_hidden_dim=512, rnn_layers_count=1, 
               freeze_pretrained=False, add_relu=False, 
               del_linear=False, del_2_linear=False, 
               to_rnn=False, to_gru=False, 
               dropout=0.3, attentive_aggregation=False, 
               ):

        super().__init__()

        self.target_size = target_size
        self.del_linear = del_linear
        self.del_2_linear = del_2_linear
        self.add_relu = add_relu
        self.attentive_aggregation = attentive_aggregation

       
        if bidirectional:
          rnn_out_dim = 2*rnn_hidden_dim
        else:
          rnn_out_dim = rnn_hidden_dim

        if pretrained:
          self.embedding = nn.Embedding.from_pretrained(pretrained_embs, 
                                                        freeze=freeze_pretrained) 
          word_emb_dim = pretrained_embs.shape[1]
        else:
          self.embedding = nn.Embedding(vocab_size, word_emb_dim)

        if attentive_aggregation:
          self.softmax = nn.Softmax(dim=1)     
          self.attn = nn.Sequential(
                nn.Linear(rnn_out_dim, rnn_out_dim//2),
                nn.ReLU(),
                nn.Linear(rnn_out_dim//2, 1)
            )
        
        if to_rnn:
          self.rnn = nn.RNN(
            word_emb_dim, 
            rnn_hidden_dim, 
            num_layers=rnn_layers_count,
            bidirectional=bidirectional,
            batch_first=True,
            )
          
        elif to_gru:
          self.rnn = nn.GRU(
            word_emb_dim, 
            rnn_hidden_dim, 
            num_layers=rnn_layers_count,
            bidirectional=bidirectional,
            batch_first=True,
            )

        else:
          self.rnn = nn.LSTM(
            word_emb_dim, 
            rnn_hidden_dim, 
            num_layers=rnn_layers_count,
            bidirectional=bidirectional,
            batch_first=True,
            )
          
        
        self.dropout = nn.Dropout(dropout)
        self.model = nn.ModuleList([
          self.embedding,
          self.rnn,
          self.dropout,
          nn.Linear(rnn_out_dim, target_size),
          nn.Linear(target_size, target_size)])
        
        if self.del_linear:
          del self.model[3]
          self.model[3] = nn.Linear(rnn_out_dim, target_size)

        if self.del_2_linear:
          del self.model[3:]
        
  def aggregate(self, rnn_output, mask):
        expanded_mask = mask.unsqueeze(-1).expand(rnn_output.size()).float()
        sum_embeddings = torch.sum(rnn_output, 1)
        sum_mask = torch.clamp(expanded_mask.sum(1), min=1e-9).cuda()
        return sum_embeddings / sum_mask  

  def attentive_aggregate(self, rnn_output, mask):
        weights = self.softmax(self.attn(rnn_output).squeeze(-1)) * mask
        embeddings = weights.unsqueeze(1).bmm(rnn_output).squeeze(1)
        return embeddings

  def forward(self, x, mask):
        lens = torch.sum(mask.cpu(), 1)
        for i, layer in enumerate(self.model):
          if layer == self.rnn:
            #pack_padded_sequence before feeding into RNN
            x = pack_padded_sequence(x, lens, enforce_sorted=False, batch_first=True)
            x, _ = layer(x)
            #pad_packed_sequence on our packed RNN output
            x, _ = pad_packed_sequence(x, batch_first=True, total_length=MAX_TOKENS)
            #taking single output from the end of sequence
            if self.attentive_aggregation:
              x = self.attentive_aggregate(x, mask)
            else:
              x = self.aggregate(x, mask)
          elif self.add_relu and not self.del_linear and i==3:
            x = nn.functional.relu(layer(x))
          else:
            x = layer(x)

        return x

###Triplet + MSE (complex) loss Distillation

The main idea of the first experiments is to use for optimization both original Triplet loss and MSE loss between aggregated embeds. 

####Modeling and Preprocessing

In [7]:
%cd /content/drive/MyDrive/NewsBert
from initial_finetuned_model import ClusteringTripletModel
from loading_and_evaluation import gen_batch, records_to_embeds, get_quality, NewsDataset, get_loaders


with open('train_embeddings_bert', 'rb') as pickle_file:
    train_embeddings_bert = pickle.load(pickle_file)

with open('val_embeddings_bert', 'rb') as pickle_file:
    val_embeddings_bert = pickle.load(pickle_file)


with open('train_records', 'rb') as pickle_file:
    train_records = pickle.load(pickle_file)

with open('val_records', 'rb') as pickle_file:
    val_records = pickle.load(pickle_file)


with open('public_set', 'rb') as pickle_file:
    public_set = pickle.load(pickle_file)

with open('private_set', 'rb') as pickle_file:
    private_set = pickle.load(pickle_file)




initial_model = ClusteringTripletModel.load_from_checkpoint(
    model_path = INITIAL_MODEL,
    checkpoint_path = '/content/drive/MyDrive/NewsBert/best_clustering_news_bert-val_loss=0.0008.ckpt',
    num_training_steps = None
    )

embedder = initial_model.embedder.cuda()

VOCAB_SIZE = embedder.model.embeddings.word_embeddings.num_embeddings
BATCH_SIZE = 32

public_markup = read_markup_tsv("ru_clustering_0527_urls_final.tsv")
private_markup = read_markup_tsv("ru_clustering_0529_urls_final_v2.tsv")

/content/drive/MyDrive/NewsBert


Some weights of the model checkpoint at IlyaGusev/news_tg_rubert were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at IlyaGusev/news_tg_rubert and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.we

In [8]:
class ClusteringDistillModel(nn.Module):
    def __init__(self, vocab_size, target_size=TARGET_SIZE, 
                 bidirectional=False, pretrained=False, 
                 word_emb_dim=128, pretrained_embs=None, 
                 freeze_pretrained = False, 
                 rnn_hidden_dim=128, rnn_layers_count=1, 
                 dropout=0.3, margin=0.5, lr=1e-4, 
                 loss_alpha=0.5, add_relu=False,
                 del_linear=False, to_rnn=False,
                 to_gru=False,
                 del_2_linear=False,
                 attentive_aggregation=False,):
      
        super().__init__()

        self.embedder = LSTM_Embedder(vocab_size=vocab_size, 
                                      target_size=target_size, 
                                      bidirectional=bidirectional, 
                                      word_emb_dim=word_emb_dim, 
                                      rnn_hidden_dim=rnn_hidden_dim, 
                                      rnn_layers_count=rnn_layers_count, 
                                      pretrained=pretrained, 
                                      pretrained_embs=pretrained_embs, 
                                      freeze_pretrained=freeze_pretrained,
                                      dropout=dropout, add_relu=add_relu,
                                      del_linear=del_linear, to_rnn=to_rnn,
                                      to_gru=to_gru,
                                      del_2_linear=del_2_linear,
                                      attentive_aggregation=attentive_aggregation,
                                      )

        self.triplet_loss = nn.TripletMarginWithDistanceLoss(
            margin=margin,
            distance_function=nn.PairwiseDistance(p=2)
        )

        self.mse = torch.nn.MSELoss()
        self.lr = lr
        self.loss_alpha = loss_alpha

    def forward(self, pivots, positives, negatives):
        pivot_embeddings = self.embedder(pivots["input_ids"].cuda(), pivots["attention_mask"])
        positive_embeddings = self.embedder(positives["input_ids"].cuda(), positives["attention_mask"])
        negative_embeddings = self.embedder(negatives["input_ids"].cuda(), negatives["attention_mask"])
        return pivot_embeddings, positive_embeddings, negative_embeddings

    def loss(self, embeds, bert_embeds):
        anchor, positive, negative = embeds
        triplet_loss = self.triplet_loss(anchor, positive, negative) 
        distill_loss = torch.mean(torch.tensor([self.mse(embed, bert_embed) 
                          for embed, bert_embed in zip(embeds, bert_embeds)]))
        loss = self.loss_alpha*triplet_loss + (1-self.loss_alpha)*distill_loss
        return loss

In [9]:
train_loader_triplets, val_loader_triplets, tokenizer_triplets = get_loaders(
                                       train_records, val_records, 
                                       INITIAL_MODEL, MAX_TOKENS, 
                                       BATCH_SIZE, train_embeddings_bert,
                                       val_embeddings_bert,
                                       )

####Experiments: Simple LSTM

In [32]:
#Balancing loss_alpha
#Loss_alpha = 0.5

lstm_distill = ClusteringDistillModel(vocab_size = VOCAB_SIZE,)
train(lstm_distill, train_loader_triplets, val_loader_triplets, 
      100, '/content/drive/MyDrive/NewsBert/Distillation_Triplet_MSE/LSTM',
      True
      )

Train: 100%|██████████| 471/471 [00:51<00:00,  9.10it/s]


Train_loss:  0.1503095451023745


Val: 100%|██████████| 118/118 [00:10<00:00, 11.65it/s]


Val_loss:  0.09885545191013893


Train: 100%|██████████| 471/471 [00:41<00:00, 11.23it/s]


Train_loss:  0.08881791389633256


Val: 100%|██████████| 118/118 [00:09<00:00, 12.94it/s]


Val_loss:  0.08085560428710573


Train: 100%|██████████| 471/471 [00:40<00:00, 11.65it/s]


Train_loss:  0.06584491088934437


Val: 100%|██████████| 118/118 [00:09<00:00, 12.78it/s]


Val_loss:  0.06763292030306467


Train: 100%|██████████| 471/471 [00:40<00:00, 11.74it/s]


Train_loss:  0.04915530487620884


Val: 100%|██████████| 118/118 [00:09<00:00, 13.09it/s]


Val_loss:  0.058782256720192286


Train: 100%|██████████| 471/471 [00:39<00:00, 11.82it/s]


Train_loss:  0.03706373697019781


Val: 100%|██████████| 118/118 [00:09<00:00, 13.07it/s]


Val_loss:  0.052690788125250534


Train: 100%|██████████| 471/471 [00:39<00:00, 11.81it/s]


Train_loss:  0.03023230387940017


Val: 100%|██████████| 118/118 [00:09<00:00, 13.00it/s]


Val_loss:  0.05051180982940073


Train: 100%|██████████| 471/471 [00:39<00:00, 11.92it/s]


Train_loss:  0.025745480031588275


Val: 100%|██████████| 118/118 [00:09<00:00, 13.01it/s]


Val_loss:  0.0490557018926597


Train: 100%|██████████| 471/471 [00:39<00:00, 11.95it/s]


Train_loss:  0.02472762973668529


Val: 100%|██████████| 118/118 [00:09<00:00, 13.05it/s]


Val_loss:  0.053586144850901164


Train: 100%|██████████| 471/471 [00:39<00:00, 12.01it/s]


Train_loss:  0.024115286718523383


Val: 100%|██████████| 118/118 [00:09<00:00, 13.09it/s]


Val_loss:  0.05488807492431513


In [33]:
lstm_distill = ClusteringDistillModel(vocab_size = VOCAB_SIZE)

lstm_distill.load_state_dict(torch.load('/content/drive/MyDrive/NewsBert/Distillation_Triplet_MSE/LSTM/best-distill-bert.pt'))
distill_embedder = lstm_distill.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, 
                                              tokenizer_triplets, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)


Accuracy: 59.3
Positives Recall: 29.5
Positives Precision: 62.5
Positives F1: 40.1
Distance:  0.38
Max cluster size:  4238
Median cluster size:  852
Avg cluster size: 1181.18


In [34]:
lstm_distill = ClusteringDistillModel(vocab_size = VOCAB_SIZE)

lstm_distill.load_state_dict(torch.load('/content/drive/MyDrive/NewsBert/Distillation_Triplet_MSE/LSTM/best-distill-bert.pt'))
distill_embedder = lstm_distill.embedder.cuda()

private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.557725570331221 ms

Accuracy: 58.2
Positives Recall: 29.1
Positives Precision: 59.1
Positives F1: 39.0
Distance:  0.38
Max cluster size:  3606
Median cluster size:  340.0
Avg cluster size: 1060.89


In [35]:
#Loss_alpha = 0.2

lstm_distill = ClusteringDistillModel(vocab_size = VOCAB_SIZE, loss_alpha=0.2)
train(lstm_distill, train_loader_triplets, val_loader_triplets, 
      100, '/content/drive/MyDrive/NewsBert/Distillation_Triplet_MSE/LSTM/alpha02',
      True
      )

Train: 100%|██████████| 471/471 [00:39<00:00, 11.88it/s]


Train_loss:  0.06694279977935129


Val: 100%|██████████| 118/118 [00:09<00:00, 13.01it/s]


Val_loss:  0.04583849195074617


Train: 100%|██████████| 471/471 [00:39<00:00, 11.87it/s]


Train_loss:  0.04303813540646185


Val: 100%|██████████| 118/118 [00:09<00:00, 12.90it/s]


Val_loss:  0.039105938816024546


Train: 100%|██████████| 471/471 [00:39<00:00, 11.86it/s]


Train_loss:  0.03366175267582279


Val: 100%|██████████| 118/118 [00:08<00:00, 13.46it/s]


Val_loss:  0.035405825622049664


Train: 100%|██████████| 471/471 [00:39<00:00, 11.93it/s]


Train_loss:  0.02829504294916942


Val: 100%|██████████| 118/118 [00:08<00:00, 13.54it/s]


Val_loss:  0.032402520669749725


Train: 100%|██████████| 471/471 [00:39<00:00, 11.90it/s]


Train_loss:  0.02551827606707645


Val: 100%|██████████| 118/118 [00:08<00:00, 13.31it/s]


Val_loss:  0.03232035996699432


Train: 100%|██████████| 471/471 [00:39<00:00, 11.85it/s]


Train_loss:  0.024955270698010238


Val: 100%|██████████| 118/118 [00:08<00:00, 13.31it/s]


Val_loss:  0.03298407479448681


Train: 100%|██████████| 471/471 [00:39<00:00, 11.90it/s]


Train_loss:  0.02581546513365162


Val: 100%|██████████| 118/118 [00:08<00:00, 13.26it/s]


Val_loss:  0.03440205811582432


In [36]:
lstm_distill.load_state_dict(torch.load('/content/drive/MyDrive/NewsBert/Distillation_Triplet_MSE/LSTM/alpha02/best-distill-bert.pt'))
distill_embedder = lstm_distill.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, 
                                              tokenizer_triplets, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)


Accuracy: 59.8
Positives Recall: 25.7
Positives Precision: 66.6
Positives F1: 37.0
Distance:  0.38
Max cluster size:  2896
Median cluster size:  487
Avg cluster size: 803.20


In [37]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.568876001642097 ms

Accuracy: 58.3
Positives Recall: 22.7
Positives Precision: 62.8
Positives F1: 33.4
Distance:  0.38
Max cluster size:  2234
Median cluster size:  669.0
Avg cluster size: 795.67


In [38]:
#Loss_alpha = 1

lstm_distill = ClusteringDistillModel(vocab_size = VOCAB_SIZE, loss_alpha=1)
train(lstm_distill, train_loader_triplets, val_loader_triplets,
      100, '/content/drive/MyDrive/NewsBert/Distillation_Triplet_MSE/LSTM/alpha1',
      True
      )

Train: 100%|██████████| 471/471 [00:39<00:00, 11.95it/s]


Train_loss:  0.298526287142132


Val: 100%|██████████| 118/118 [00:08<00:00, 13.34it/s]


Val_loss:  0.18547999025401424


Train: 100%|██████████| 471/471 [00:39<00:00, 11.93it/s]


Train_loss:  0.16016919251072623


Val: 100%|██████████| 118/118 [00:08<00:00, 13.47it/s]


Val_loss:  0.13999434292190155


Train: 100%|██████████| 471/471 [00:39<00:00, 11.80it/s]


Train_loss:  0.10845156988393982


Val: 100%|██████████| 118/118 [00:08<00:00, 13.39it/s]


Val_loss:  0.11543299807077748


Train: 100%|██████████| 471/471 [00:39<00:00, 11.91it/s]


Train_loss:  0.07421976611643018


Val: 100%|██████████| 118/118 [00:08<00:00, 13.46it/s]


Val_loss:  0.09418681083973181


Train: 100%|██████████| 471/471 [00:39<00:00, 11.78it/s]


Train_loss:  0.0500928279894331


Val: 100%|██████████| 118/118 [00:08<00:00, 13.48it/s]


Val_loss:  0.08353016371601972


Train: 100%|██████████| 471/471 [00:39<00:00, 11.92it/s]


Train_loss:  0.03326604099283806


Val: 100%|██████████| 118/118 [00:08<00:00, 13.27it/s]


Val_loss:  0.07715326672325194


Train: 100%|██████████| 471/471 [00:39<00:00, 11.87it/s]


Train_loss:  0.022406843577901143


Val: 100%|██████████| 118/118 [00:08<00:00, 13.30it/s]


Val_loss:  0.07224510742698685


Train: 100%|██████████| 471/471 [00:39<00:00, 11.86it/s]


Train_loss:  0.014568123048974197


Val: 100%|██████████| 118/118 [00:08<00:00, 13.31it/s]


Val_loss:  0.07667134372296475


Train: 100%|██████████| 471/471 [00:39<00:00, 11.98it/s]


Train_loss:  0.009834624257436983


Val: 100%|██████████| 118/118 [00:08<00:00, 13.49it/s]


Val_loss:  0.07173482897707213


Train: 100%|██████████| 471/471 [00:39<00:00, 12.01it/s]


Train_loss:  0.00721051694821337


Val: 100%|██████████| 118/118 [00:08<00:00, 13.45it/s]


Val_loss:  0.07487122484056627


Train: 100%|██████████| 471/471 [00:38<00:00, 12.15it/s]


Train_loss:  0.004572053354889977


Val: 100%|██████████| 118/118 [00:08<00:00, 13.37it/s]


Val_loss:  0.07306267103275


In [39]:
lstm_distill.load_state_dict(torch.load('/content/drive/MyDrive/NewsBert/Distillation_Triplet_MSE/LSTM/alpha1/best-distill-bert.pt'))
distill_embedder = lstm_distill.embedder.cuda()

public_distill_embeddings = records_to_embeds(public_set, distill_embedder, 
                                              tokenizer_triplets, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,
                                              )
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)


Accuracy: 59.7
Positives Recall: 27.5
Positives Precision: 64.8
Positives F1: 38.6
Distance:  0.38
Max cluster size:  2639
Median cluster size:  1174.0
Avg cluster size: 1004.00


In [40]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.5646069075679699 ms

Accuracy: 59.0
Positives Recall: 23.7
Positives Precision: 64.9
Positives F1: 34.7
Distance:  0.38
Max cluster size:  2555
Median cluster size:  1086
Avg cluster size: 830.26


In [41]:
#Loss_alpha = 0

lstm_distill = ClusteringDistillModel(vocab_size = VOCAB_SIZE, loss_alpha=0)
train(lstm_distill, train_loader_triplets, val_loader_triplets,
      100, '/content/drive/MyDrive/NewsBert/Distillation_Triplet_MSE/LSTM/alpha0',
      True)

Train: 100%|██████████| 471/471 [00:38<00:00, 12.18it/s]


Train_loss:  0.006398983694228383


Val: 100%|██████████| 118/118 [00:08<00:00, 13.68it/s]


Val_loss:  0.0063250847722447785


Train: 100%|██████████| 471/471 [00:38<00:00, 12.24it/s]


Train_loss:  0.006398501799670968


Val: 100%|██████████| 118/118 [00:08<00:00, 13.74it/s]


Val_loss:  0.0063250847722447785


Train: 100%|██████████| 471/471 [00:38<00:00, 12.25it/s]


Train_loss:  0.006398701816241324


Val: 100%|██████████| 118/118 [00:08<00:00, 13.80it/s]


Val_loss:  0.0063250847722447785


In [42]:
lstm_distill.load_state_dict(torch.load('/content/drive/MyDrive/NewsBert/Distillation_Triplet_MSE/LSTM/alpha0/best-distill-bert.pt'))
distill_embedder = lstm_distill.embedder.cuda()

public_distill_embeddings = records_to_embeds(public_set, distill_embedder, 
                                              tokenizer_triplets, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,
                                              )
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)


Accuracy: 46.1
Positives Recall: 100.0
Positives Precision: 46.1
Positives F1: 63.1
Distance:  0.38
Max cluster size:  20080
Median cluster size:  20080
Avg cluster size: 20080.00


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [43]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.5607965713297417 ms

Accuracy: 46.0
Positives Recall: 100.0
Positives Precision: 46.0
Positives F1: 63.0
Distance:  0.38
Max cluster size:  19096
Median cluster size:  19096
Avg cluster size: 19096.00


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Any application of the complex loss doesn't give an accaptable result.

###Single loss Distilation

Further experiments are based on the single loss function approach. For this functions were chosen MSE and Cosine similarity evaluations.

####Modeling and Preprocessing

In [4]:
%cd /content/drive/MyDrive/NewsBert
from initial_finetuned_model import ClusteringTripletModel
from loading_and_evaluation import gen_batch, records_to_embeds, get_quality, NewsDataset, get_loaders

with open('single_full_train_embeddings_bert', 'rb') as pickle_file:
    single_full_train_embeddings_bert = pickle.load(pickle_file)

with open('single_full_val_embeddings_bert', 'rb') as pickle_file:
    single_full_val_embeddings_bert = pickle.load(pickle_file)


with open('full_train_records', 'rb') as pickle_file:
    full_train_records = pickle.load(pickle_file)

with open('full_val_records', 'rb') as pickle_file:
    full_val_records = pickle.load(pickle_file)


with open('public_set', 'rb') as pickle_file:
    public_set = pickle.load(pickle_file)

with open('private_set', 'rb') as pickle_file:
    private_set = pickle.load(pickle_file)


initial_model = ClusteringTripletModel.load_from_checkpoint(
    model_path = INITIAL_MODEL,
    checkpoint_path = '/content/drive/MyDrive/NewsBert/best_clustering_news_bert-val_loss=0.0008.ckpt',
    num_training_steps = None
    )

embedder = initial_model.embedder.cuda()

VOCAB_SIZE = embedder.model.embeddings.word_embeddings.num_embeddings
BATCH_SIZE = 128


public_markup = read_markup_tsv("ru_clustering_0527_urls_final.tsv")
private_markup = read_markup_tsv("ru_clustering_0529_urls_final_v2.tsv")

/content/drive/MyDrive/NewsBert


Downloading:   0%|          | 0.00/831 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/679M [00:00<?, ?B/s]

Some weights of the model checkpoint at IlyaGusev/news_tg_rubert were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at IlyaGusev/news_tg_rubert and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.

In [5]:
class SingleClusteringDistillModel(nn.Module):
    def __init__(self, vocab_size, target_size=TARGET_SIZE, 
                 bidirectional=False, pretrained=False, 
                 word_emb_dim=128, pretrained_embs=None, 
                 freeze_pretrained=False, evaluate_similarity=False,
                 rnn_hidden_dim=128, rnn_layers_count=1, 
                 dropout=0.3, lr=1e-3, add_relu=False,
                 del_linear=False, to_rnn=False, 
                 to_gru=False, del_2_linear=False, 
                 attentive_aggregation=False,
                 ):
      
        super().__init__()

        self.embedder = LSTM_Embedder(vocab_size=vocab_size, 
                                      target_size=target_size, 
                                      bidirectional=bidirectional, 
                                      word_emb_dim=word_emb_dim, 
                                      rnn_hidden_dim=rnn_hidden_dim, 
                                      rnn_layers_count=rnn_layers_count,
                                      pretrained=pretrained, 
                                      pretrained_embs=pretrained_embs, 
                                      freeze_pretrained=freeze_pretrained,
                                      dropout=dropout, add_relu=add_relu, 
                                      del_linear=del_linear, to_rnn=to_rnn,
                                      to_gru=to_gru,
                                      del_2_linear=del_2_linear,
                                      attentive_aggregation=attentive_aggregation,
                                      )

        self.evaluate_similarity = evaluate_similarity
        if not evaluate_similarity:
          self.mse = torch.nn.MSELoss()
        else:
          self.cosine_similarity = torch.nn.functional.cosine_similarity
        
        self.lr = lr

    def forward(self, news):
        embeddings = self.embedder(news["input_ids"].cuda(), news["attention_mask"])
        return embeddings

    def loss(self, embeds, bert_embeds):
        if self.evaluate_similarity is True:
          similarity = self.cosine_similarity(embeds, bert_embeds)
          loss = torch.mean(torch.ones(len(similarity)).cuda() - similarity)
        else:
          loss = self.mse(embeds.float(), bert_embeds.float())
        return loss

In [6]:
train_loader, val_loader, tokenizer = get_loaders(full_train_records,
                                                  full_val_records, 
                                                  INITIAL_MODEL, MAX_TOKENS, 
                                                  BATCH_SIZE, 
                                                  single_full_train_embeddings_bert,
                                                  single_full_val_embeddings_bert,
                                                  )                                       

Downloading:   0%|          | 0.00/1.57M [00:00<?, ?B/s]

####Experiments: Simple LSTM - MSE/Cosine similarity loss


In [None]:
single_lstm_distill = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE)
train(single_lstm_distill, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_MSE/LSTM',
      )

Train: 100%|██████████| 4037/4037 [06:21<00:00, 10.59it/s]


Train_loss:  0.0030305164414906683


Val: 100%|██████████| 1010/1010 [01:32<00:00, 10.93it/s]


Val_loss:  0.0025779105065906844


Train: 100%|██████████| 4037/4037 [06:25<00:00, 10.46it/s]


Train_loss:  0.002577592715829083


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.71it/s]


Val_loss:  0.0023123252350781666


Train: 100%|██████████| 4037/4037 [06:24<00:00, 10.50it/s]


Train_loss:  0.002390587687974988


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.69it/s]


Val_loss:  0.0021365138377053756


Train: 100%|██████████| 4037/4037 [06:27<00:00, 10.43it/s]


Train_loss:  0.0022656855227392012


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.89it/s]


Val_loss:  0.0020281106666984535


Train: 100%|██████████| 4037/4037 [06:21<00:00, 10.57it/s]


Train_loss:  0.002176881558336788


Val: 100%|██████████| 1010/1010 [01:24<00:00, 12.02it/s]


Val_loss:  0.0019365171153826804


Train: 100%|██████████| 4037/4037 [06:23<00:00, 10.54it/s]


Train_loss:  0.0021111736394066327


Val: 100%|██████████| 1010/1010 [01:25<00:00, 11.79it/s]


Val_loss:  0.0018815728242158668


Train: 100%|██████████| 4037/4037 [06:23<00:00, 10.52it/s]


Train_loss:  0.002059369355890921


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.93it/s]


Val_loss:  0.0018254422187334903


Train: 100%|██████████| 4037/4037 [06:25<00:00, 10.48it/s]


Train_loss:  0.0020170587023922238


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.60it/s]


Val_loss:  0.0017820566191826717


Train: 100%|██████████| 4037/4037 [06:29<00:00, 10.37it/s]


Train_loss:  0.0019807421146048354


Val: 100%|██████████| 1010/1010 [01:25<00:00, 11.79it/s]


Val_loss:  0.001750758173760248


Train: 100%|██████████| 4037/4037 [06:21<00:00, 10.59it/s]


Train_loss:  0.001949258153428316


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.94it/s]


Val_loss:  0.0017304953388261176


Train: 100%|██████████| 4037/4037 [06:20<00:00, 10.62it/s]


Train_loss:  0.0019226399952396271


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.99it/s]


Val_loss:  0.0016956714150209975


Train: 100%|██████████| 4037/4037 [06:26<00:00, 10.45it/s]


Train_loss:  0.0018985918700553852


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.91it/s]


Val_loss:  0.001672309753139236


Train: 100%|██████████| 4037/4037 [06:25<00:00, 10.48it/s]


Train_loss:  0.0018780859995145209


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.92it/s]


Val_loss:  0.0016506369895357086


Train: 100%|██████████| 4037/4037 [06:20<00:00, 10.61it/s]


Train_loss:  0.0018591289617862217


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.93it/s]


Val_loss:  0.0016356183418838105


Train: 100%|██████████| 4037/4037 [06:24<00:00, 10.51it/s]


Train_loss:  0.0018420238260655915


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.08it/s]


Val_loss:  0.0016211904343309822


Train: 100%|██████████| 4037/4037 [06:19<00:00, 10.63it/s]


Train_loss:  0.0018264604157094324


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.09it/s]


Val_loss:  0.001603546829221051


Train: 100%|██████████| 4037/4037 [06:21<00:00, 10.58it/s]


Train_loss:  0.0018112826888819298


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.05it/s]


Val_loss:  0.001599798277642873


Train: 100%|██████████| 4037/4037 [06:23<00:00, 10.52it/s]


Train_loss:  0.001798936931695766


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.09it/s]


Val_loss:  0.001584819910038115


Train: 100%|██████████| 4037/4037 [06:23<00:00, 10.52it/s]


Train_loss:  0.001786858276330241


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.98it/s]


Val_loss:  0.001578715230144634


Train: 100%|██████████| 4037/4037 [06:22<00:00, 10.56it/s]


Train_loss:  0.0017747614231765278


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.07it/s]


Val_loss:  0.001575955296151967


In [None]:
single_lstm_distill = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE)
single_lstm_distill.load_state_dict(torch.load('/content/drive/MyDrive/NewsBert/Distillation_MSE/LSTM/best-distill-bert.pt'))

distill_embedder = single_lstm_distill.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.6796909511326794 ms

Accuracy: 91.9
Positives Recall: 92.8
Positives Precision: 90.0
Positives F1: 91.4
Distance:  0.38
Max cluster size:  304
Median cluster size:  2.0
Avg cluster size: 5.75


In [None]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.6796732293320941 ms

Accuracy: 92.2
Positives Recall: 94.1
Positives Precision: 89.5
Positives F1: 91.8
Distance:  0.38
Max cluster size:  172
Median cluster size:  2
Avg cluster size: 5.52


In [None]:
single_lstm_distill_cos_similarity = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE, 
                                                                  evaluate_similarity=True)
train(single_lstm_distill_cos_similarity, train_loader, val_loader,
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/LSTM',
      )

Train: 100%|██████████| 4037/4037 [06:19<00:00, 10.63it/s]


Train_loss:  0.4933299220963145


Val: 100%|██████████| 1010/1010 [01:24<00:00, 12.02it/s]


Val_loss:  0.35678940173006213


Train: 100%|██████████| 4037/4037 [06:14<00:00, 10.77it/s]


Train_loss:  0.35668091748857256


Val: 100%|██████████| 1010/1010 [01:24<00:00, 12.00it/s]


Val_loss:  0.2895029033210716


Train: 100%|██████████| 4037/4037 [06:15<00:00, 10.76it/s]


Train_loss:  0.3138388078715552


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.04it/s]


Val_loss:  0.2576951956849872


Train: 100%|██████████| 4037/4037 [06:19<00:00, 10.64it/s]


Train_loss:  0.2906529806642783


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.94it/s]


Val_loss:  0.23751375039127784


Train: 100%|██████████| 4037/4037 [06:19<00:00, 10.64it/s]


Train_loss:  0.27567522402176353


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.93it/s]


Val_loss:  0.22487873541431605


Train: 100%|██████████| 4037/4037 [06:18<00:00, 10.67it/s]


Train_loss:  0.2650115131797531


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.95it/s]


Val_loss:  0.21470367977919014


Train: 100%|██████████| 4037/4037 [06:20<00:00, 10.61it/s]


Train_loss:  0.25723976800672615


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.07it/s]


Val_loss:  0.20768504878439836


Train: 100%|██████████| 4037/4037 [06:19<00:00, 10.64it/s]


Train_loss:  0.25083734965855686


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.04it/s]


Val_loss:  0.20159526563473643


Train: 100%|██████████| 4037/4037 [06:15<00:00, 10.74it/s]


Train_loss:  0.24577853315447945


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.04it/s]


Val_loss:  0.1974640516765905


Train: 100%|██████████| 4037/4037 [06:18<00:00, 10.66it/s]


Train_loss:  0.24154694389228126


Val: 100%|██████████| 1010/1010 [01:24<00:00, 12.01it/s]


Val_loss:  0.19308987423555404


Train: 100%|██████████| 4037/4037 [06:18<00:00, 10.68it/s]


Train_loss:  0.23790843808004258


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.90it/s]


Val_loss:  0.1899560239922648


Train: 100%|██████████| 4037/4037 [06:20<00:00, 10.61it/s]


Train_loss:  0.23478238880325694


Val: 100%|██████████| 1010/1010 [01:25<00:00, 11.86it/s]


Val_loss:  0.18734528049627472


Train: 100%|██████████| 4037/4037 [06:15<00:00, 10.75it/s]


Train_loss:  0.23201919877556734


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.12it/s]


Val_loss:  0.18483829391098894


Train: 100%|██████████| 4037/4037 [06:15<00:00, 10.75it/s]


Train_loss:  0.22973526951806825


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.08it/s]


Val_loss:  0.18300062650271395


Train: 100%|██████████| 4037/4037 [06:13<00:00, 10.80it/s]


Train_loss:  0.2275471778189405


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.12it/s]


Val_loss:  0.18139859601436292


Train: 100%|██████████| 4037/4037 [06:18<00:00, 10.66it/s]


Train_loss:  0.22575378736481316


Val: 100%|██████████| 1010/1010 [01:25<00:00, 11.87it/s]


Val_loss:  0.1797150224953454


Train: 100%|██████████| 4037/4037 [06:13<00:00, 10.81it/s]


Train_loss:  0.22402894184615293


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.16it/s]


Val_loss:  0.1780615165702465


Train: 100%|██████████| 4037/4037 [06:15<00:00, 10.75it/s]


Train_loss:  0.22236760755414703


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.11it/s]


Val_loss:  0.17701277062821813


Train: 100%|██████████| 4037/4037 [06:13<00:00, 10.81it/s]


Train_loss:  0.22090284890869355


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.10it/s]


Val_loss:  0.17594625835229558


Train: 100%|██████████| 4037/4037 [06:15<00:00, 10.76it/s]


Train_loss:  0.21959078252665573


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.14it/s]


Val_loss:  0.17483822996592882


In [None]:
single_lstm_distill_cos_similarity = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE, 
                                                                   evaluate_similarity=True)

single_lstm_distill_cos_similarity.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/LSTM/best-distill-bert.pt'))


distill_embedder = single_lstm_distill_cos_similarity.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.6773523045726031 ms

Accuracy: 92.9
Positives Recall: 93.6
Positives Precision: 91.3
Positives F1: 92.4
Distance:  0.38
Max cluster size:  303
Median cluster size:  2
Avg cluster size: 5.07


In [None]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.6771633880978406 ms

Accuracy: 93.0
Positives Recall: 94.3
Positives Precision: 90.9
Positives F1: 92.6
Distance:  0.38
Max cluster size:  172
Median cluster size:  2.0
Avg cluster size: 4.84


####Experiments: Bidirectional LSTM - MSE/Cosine similarity loss

In [None]:
lstm_bidir_distill = SingleClusteringDistillModel(VOCAB_SIZE, bidirectional=True)

train(lstm_bidir_distill, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_MSE/LSTM_Bidir',
      )

Train: 100%|██████████| 4037/4037 [08:01<00:00,  8.38it/s]


Train_loss:  0.0029044680659914587


Val: 100%|██████████| 1010/1010 [02:31<00:00,  6.69it/s]


Val_loss:  0.0024208000199500436


Train: 100%|██████████| 4037/4037 [10:47<00:00,  6.23it/s]


Train_loss:  0.0024047072722253814


Val: 100%|██████████| 1010/1010 [02:34<00:00,  6.53it/s]


Val_loss:  0.002148986696311743


Train: 100%|██████████| 4037/4037 [08:36<00:00,  7.81it/s]


Train_loss:  0.002203877933440116


Val: 100%|██████████| 1010/1010 [02:47<00:00,  6.02it/s]


Val_loss:  0.0019867650510992227


Train: 100%|██████████| 4037/4037 [08:38<00:00,  7.78it/s]


Train_loss:  0.0020753843314758372


Val: 100%|██████████| 1010/1010 [02:46<00:00,  6.06it/s]


Val_loss:  0.0018710513566943384


Train: 100%|██████████| 4037/4037 [08:34<00:00,  7.85it/s]


Train_loss:  0.0019864099762793045


Val: 100%|██████████| 1010/1010 [02:33<00:00,  6.57it/s]


Val_loss:  0.0017966259281280092


Train: 100%|██████████| 4037/4037 [08:32<00:00,  7.88it/s]


Train_loss:  0.0019179339943846316


Val: 100%|██████████| 1010/1010 [02:45<00:00,  6.12it/s]


Val_loss:  0.0017357560909712817


Train: 100%|██████████| 4037/4037 [10:30<00:00,  6.40it/s]


Train_loss:  0.0018642295679321583


Val: 100%|██████████| 1010/1010 [02:55<00:00,  5.75it/s]


Val_loss:  0.0016879141456355331


Train: 100%|██████████| 4037/4037 [10:36<00:00,  6.34it/s]


Train_loss:  0.0018184465222194023


Val: 100%|██████████| 1010/1010 [02:37<00:00,  6.42it/s]


Val_loss:  0.001643807826319359


Train: 100%|██████████| 4037/4037 [10:38<00:00,  6.32it/s]


Train_loss:  0.0017536087083896256


Val: 100%|██████████| 1010/1010 [02:39<00:00,  6.35it/s]


Val_loss:  0.0015607438769449692


Train: 100%|██████████| 4037/4037 [08:33<00:00,  7.86it/s]


Train_loss:  0.0017073920178523572


Val: 100%|██████████| 1010/1010 [02:38<00:00,  6.37it/s]


Val_loss:  0.0015306003782483242


Train: 100%|██████████| 4037/4037 [09:44<00:00,  6.90it/s]


Train_loss:  0.0016741121305115401


Val: 100%|██████████| 1010/1010 [02:41<00:00,  6.24it/s]


Val_loss:  0.0015001914015565399


Train: 100%|██████████| 4037/4037 [10:59<00:00,  6.12it/s]


Train_loss:  0.001646139700512928


Val: 100%|██████████| 1010/1010 [02:38<00:00,  6.38it/s]


Val_loss:  0.0014672393055337638


Train: 100%|██████████| 4037/4037 [08:48<00:00,  7.65it/s]


Train_loss:  0.0016218937179585838


Val: 100%|██████████| 1010/1010 [02:40<00:00,  6.28it/s]


Val_loss:  0.0014582526520825923


Train: 100%|██████████| 4037/4037 [10:56<00:00,  6.15it/s]


Train_loss:  0.001600004465952248


Val: 100%|██████████| 1010/1010 [02:41<00:00,  6.26it/s]


Val_loss:  0.001427926728847704


Train: 100%|██████████| 4037/4037 [08:52<00:00,  7.58it/s]


Train_loss:  0.0015806713088613732


Val: 100%|██████████| 1010/1010 [02:40<00:00,  6.31it/s]


Val_loss:  0.001421109050406272


Train: 100%|██████████| 4037/4037 [11:00<00:00,  6.12it/s]


Train_loss:  0.001563026460029733


Val: 100%|██████████| 1010/1010 [02:50<00:00,  5.93it/s]


Val_loss:  0.0013951172821666344


Train: 100%|██████████| 4037/4037 [08:54<00:00,  7.55it/s]


Train_loss:  0.001547501332233601


Val: 100%|██████████| 1010/1010 [02:41<00:00,  6.24it/s]


Val_loss:  0.0013907142189829287


Train: 100%|██████████| 4037/4037 [08:41<00:00,  7.74it/s]


Train_loss:  0.0015325928962212042


Val: 100%|██████████| 1010/1010 [02:40<00:00,  6.28it/s]


Val_loss:  0.0013764732676364556


Train: 100%|██████████| 4037/4037 [08:36<00:00,  7.82it/s]


Train_loss:  0.0015189695797891997


Val: 100%|██████████| 1010/1010 [02:40<00:00,  6.31it/s]


Val_loss:  0.001363272115814103


Train: 100%|██████████| 4037/4037 [08:29<00:00,  7.93it/s]


Train_loss:  0.0015060103383396954


Val: 100%|██████████| 1010/1010 [02:56<00:00,  5.72it/s]


Val_loss:  0.0013541400570901904


In [None]:
lstm_bidir_distill = SingleClusteringDistillModel(VOCAB_SIZE, bidirectional=True)

lstm_bidir_distill.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_MSE/LSTM_Bidir/best-distill-bert.pt'))

distill_embedder = lstm_bidir_distill.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.817096007012751 ms

Accuracy: 92.8
Positives Recall: 93.8
Positives Precision: 90.9
Positives F1: 92.3
Distance:  0.38
Max cluster size:  290
Median cluster size:  2
Avg cluster size: 4.82


In [None]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.8396399011659882 ms

Accuracy: 92.8
Positives Recall: 94.5
Positives Precision: 90.3
Positives F1: 92.3
Distance:  0.38
Max cluster size:  152
Median cluster size:  2
Avg cluster size: 4.65


In [None]:
bidir_lstm_distill_cos_similarity = SingleClusteringDistillModel(VOCAB_SIZE, 
                                                                bidirectional=True,
                                                                evaluate_similarity=True)
train(bidir_lstm_distill_cos_similarity, train_loader, val_loader,
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM',
      )

Train: 100%|██████████| 4037/4037 [07:45<00:00,  8.68it/s]


Train_loss:  0.4637804734371837


Val: 100%|██████████| 1010/1010 [01:45<00:00,  9.57it/s]


Val_loss:  0.33067102834380235


Train: 100%|██████████| 4037/4037 [07:42<00:00,  8.73it/s]


Train_loss:  0.32185789124916286


Val: 100%|██████████| 1010/1010 [01:46<00:00,  9.52it/s]


Val_loss:  0.26585640882233896


Train: 100%|██████████| 4037/4037 [07:40<00:00,  8.77it/s]


Train_loss:  0.2783102236995286


Val: 100%|██████████| 1010/1010 [01:45<00:00,  9.55it/s]


Val_loss:  0.23425673615832263


Train: 100%|██████████| 4037/4037 [07:36<00:00,  8.84it/s]


Train_loss:  0.2542259164804918


Val: 100%|██████████| 1010/1010 [01:45<00:00,  9.62it/s]


Val_loss:  0.21534837380596447


Train: 100%|██████████| 4037/4037 [07:33<00:00,  8.89it/s]


Train_loss:  0.23865454100467307


Val: 100%|██████████| 1010/1010 [01:44<00:00,  9.67it/s]


Val_loss:  0.2027467540547215


Train: 100%|██████████| 4037/4037 [07:31<00:00,  8.93it/s]


Train_loss:  0.22761934884764426


Val: 100%|██████████| 1010/1010 [01:44<00:00,  9.70it/s]


Val_loss:  0.19396071858984834


Train: 100%|██████████| 4037/4037 [07:34<00:00,  8.88it/s]


Train_loss:  0.21916495055620167


Val: 100%|██████████| 1010/1010 [01:44<00:00,  9.66it/s]


Val_loss:  0.18672261385402542


Train: 100%|██████████| 4037/4037 [07:36<00:00,  8.84it/s]


Train_loss:  0.2125290342951103


Val: 100%|██████████| 1010/1010 [01:44<00:00,  9.67it/s]


Val_loss:  0.18146128074070839


Train: 100%|██████████| 4037/4037 [07:32<00:00,  8.92it/s]


Train_loss:  0.20704095509876613


Val: 100%|██████████| 1010/1010 [01:43<00:00,  9.73it/s]


Val_loss:  0.17680839611432206


Train: 100%|██████████| 4037/4037 [07:28<00:00,  9.01it/s]


Train_loss:  0.2024875304055312


Val: 100%|██████████| 1010/1010 [01:42<00:00,  9.84it/s]


Val_loss:  0.1733152153129531


Train: 100%|██████████| 4037/4037 [07:27<00:00,  9.02it/s]


Train_loss:  0.19852324855293746


Val: 100%|██████████| 1010/1010 [01:43<00:00,  9.78it/s]


Val_loss:  0.1706890874277928


Train: 100%|██████████| 4037/4037 [07:30<00:00,  8.95it/s]


Train_loss:  0.1951325079821026


Val: 100%|██████████| 1010/1010 [01:43<00:00,  9.80it/s]


Val_loss:  0.16757056356778036


Train: 100%|██████████| 4037/4037 [07:28<00:00,  9.01it/s]


Train_loss:  0.19220126873544294


Val: 100%|██████████| 1010/1010 [01:43<00:00,  9.72it/s]


Val_loss:  0.1653726849328627


Train: 100%|██████████| 4037/4037 [07:29<00:00,  8.99it/s]


Train_loss:  0.18959326763737153


Val: 100%|██████████| 1010/1010 [01:44<00:00,  9.66it/s]


Val_loss:  0.1638472149976639


Train: 100%|██████████| 4037/4037 [07:32<00:00,  8.91it/s]


Train_loss:  0.18727357668217826


Val: 100%|██████████| 1010/1010 [01:43<00:00,  9.76it/s]


Val_loss:  0.16196194862547697


Train: 100%|██████████| 4037/4037 [07:27<00:00,  9.01it/s]


Train_loss:  0.18515868208164482


Val: 100%|██████████| 1010/1010 [01:43<00:00,  9.74it/s]


Val_loss:  0.1604855361372907


Train: 100%|██████████| 4037/4037 [07:31<00:00,  8.93it/s]


Train_loss:  0.1833239074545206


Val: 100%|██████████| 1010/1010 [01:43<00:00,  9.80it/s]


Val_loss:  0.1593470567965333


Train: 100%|██████████| 4037/4037 [07:29<00:00,  8.97it/s]


Train_loss:  0.18161502937118998


Val: 100%|██████████| 1010/1010 [01:42<00:00,  9.81it/s]


Val_loss:  0.158295321546158


Train: 100%|██████████| 4037/4037 [07:29<00:00,  8.98it/s]


Train_loss:  0.17995931175752833


Val: 100%|██████████| 1010/1010 [01:42<00:00,  9.86it/s]


Val_loss:  0.15720947974293886


Train: 100%|██████████| 4037/4037 [07:27<00:00,  9.02it/s]


Train_loss:  0.17848591986685122


Val: 100%|██████████| 1010/1010 [01:44<00:00,  9.69it/s]


Val_loss:  0.15634259718434693


In [None]:
bidir_lstm_distill_cos_similarity = SingleClusteringDistillModel(VOCAB_SIZE, 
                                                                bidirectional=True,
                                                                evaluate_similarity=True)

bidir_lstm_distill_cos_similarity.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM/best-distill-bert.pt'))


distill_embedder = bidir_lstm_distill_cos_similarity.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.8570194378079646 ms

Accuracy: 93.5
Positives Recall: 94.6
Positives Precision: 91.6
Positives F1: 93.0
Distance:  0.38
Max cluster size:  294
Median cluster size:  2
Avg cluster size: 4.58


In [None]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.8509773509393539 ms

Accuracy: 93.1
Positives Recall: 94.7
Positives Precision: 90.7
Positives F1: 92.6
Distance:  0.38
Max cluster size:  158
Median cluster size:  2.0
Avg cluster size: 4.40


####Experiments: Bidirectional LSTM - MSE/Cosine similarity loss - pretrained



In [None]:
pretrained_embs = initial_model.embedder.model.embeddings.word_embeddings.weight.cuda()

lstm_bidir_distill_from_pretrained = SingleClusteringDistillModel(VOCAB_SIZE, bidirectional=True,
                                                                   pretrained=True, 
                                                                   pretrained_embs=pretrained_embs)

train(lstm_bidir_distill_from_pretrained, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_MSE/LSTM_Bidir_from_pretrained',
      )

Train: 100%|██████████| 4037/4037 [10:56<00:00,  6.15it/s]


Train_loss:  0.001778852204352263


Val: 100%|██████████| 1010/1010 [01:53<00:00,  8.89it/s]


Val_loss:  0.0011222380471627902


Train: 100%|██████████| 4037/4037 [10:54<00:00,  6.17it/s]


Train_loss:  0.0012732777196165577


Val: 100%|██████████| 1010/1010 [01:54<00:00,  8.82it/s]


Val_loss:  0.0009835452640124846


Train: 100%|██████████| 4037/4037 [10:56<00:00,  6.15it/s]


Train_loss:  0.0011859187153672095


Val: 100%|██████████| 1010/1010 [01:54<00:00,  8.83it/s]


Val_loss:  0.0009246947135462766


Train: 100%|██████████| 4037/4037 [10:54<00:00,  6.17it/s]


Train_loss:  0.0011399907046356413


Val: 100%|██████████| 1010/1010 [01:53<00:00,  8.87it/s]


Val_loss:  0.0009076381978200804


Train: 100%|██████████| 4037/4037 [10:51<00:00,  6.19it/s]


Train_loss:  0.0011088189406016856


Val: 100%|██████████| 1010/1010 [01:51<00:00,  9.05it/s]


Val_loss:  0.0008767058781237516


Train: 100%|██████████| 4037/4037 [10:49<00:00,  6.22it/s]


Train_loss:  0.0010846113673752048


Val: 100%|██████████| 1010/1010 [01:52<00:00,  8.98it/s]


Val_loss:  0.0008666753343601555


Train: 100%|██████████| 4037/4037 [10:52<00:00,  6.19it/s]


Train_loss:  0.0010648178859921354


Val: 100%|██████████| 1010/1010 [01:52<00:00,  8.99it/s]


Val_loss:  0.0008574467028449705


Train: 100%|██████████| 4037/4037 [11:02<00:00,  6.09it/s]


Train_loss:  0.0010478140183941087


Val: 100%|██████████| 1010/1010 [01:53<00:00,  8.93it/s]


Val_loss:  0.0008557114270077742


Train: 100%|██████████| 4037/4037 [11:13<00:00,  5.99it/s]


Train_loss:  0.0010332070716648901


Val: 100%|██████████| 1010/1010 [01:54<00:00,  8.85it/s]


Val_loss:  0.000846567516454892


Train: 100%|██████████| 4037/4037 [10:53<00:00,  6.17it/s]


Train_loss:  0.001019853153437305


Val: 100%|██████████| 1010/1010 [01:52<00:00,  8.98it/s]


Val_loss:  0.0008402665101895385


Train: 100%|██████████| 4037/4037 [10:39<00:00,  6.32it/s]


Train_loss:  0.0010074032710796854


Val: 100%|██████████| 1010/1010 [01:51<00:00,  9.02it/s]


Val_loss:  0.0008369586603960373


Train: 100%|██████████| 4037/4037 [10:27<00:00,  6.43it/s]


Train_loss:  0.0009964170082874138


Val: 100%|██████████| 1010/1010 [01:49<00:00,  9.19it/s]


Val_loss:  0.0008320947264731484


Train: 100%|██████████| 4037/4037 [10:14<00:00,  6.57it/s]


Train_loss:  0.0009860206612790452


Val: 100%|██████████| 1010/1010 [01:49<00:00,  9.19it/s]


Val_loss:  0.0008319289170523718


Train: 100%|██████████| 4037/4037 [10:02<00:00,  6.70it/s]


Train_loss:  0.0009767724909783442


Val: 100%|██████████| 1010/1010 [01:49<00:00,  9.21it/s]


Val_loss:  0.0008317776831756761


Train: 100%|██████████| 4037/4037 [09:55<00:00,  6.78it/s]


Train_loss:  0.0009673222578220635


Val: 100%|██████████| 1010/1010 [01:49<00:00,  9.27it/s]


Val_loss:  0.0008202451892141806


Train: 100%|██████████| 4037/4037 [09:49<00:00,  6.85it/s]


Train_loss:  0.0009589937371602072


Val: 100%|██████████| 1010/1010 [01:49<00:00,  9.25it/s]


Val_loss:  0.0008267547965215722


Train: 100%|██████████| 4037/4037 [09:48<00:00,  6.87it/s]


Train_loss:  0.0009516184790355677


Val: 100%|██████████| 1010/1010 [01:49<00:00,  9.26it/s]


Val_loss:  0.0008293005286849629


Train: 100%|██████████| 4037/4037 [09:45<00:00,  6.90it/s]


Train_loss:  0.0009438869977833143


Val: 100%|██████████| 1010/1010 [01:49<00:00,  9.24it/s]


Val_loss:  0.0008236844883309586


In [None]:
pretrained_embs = initial_model.embedder.model.embeddings.word_embeddings.weight.cuda()

lstm_bidir_distill_from_pretrained = SingleClusteringDistillModel(VOCAB_SIZE, bidirectional=True,
                                                                   pretrained=True, 
                                                                   pretrained_embs=pretrained_embs)

lstm_bidir_distill_from_pretrained.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_MSE/LSTM_Bidir_from_pretrained/best-distill-bert.pt'))

distill_embedder = lstm_bidir_distill_from_pretrained.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.9751229351497741 ms

Accuracy: 94.4
Positives Recall: 95.7
Positives Precision: 92.4
Positives F1: 94.0
Distance:  0.38
Max cluster size:  273
Median cluster size:  2.0
Avg cluster size: 4.06


In [None]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.9897823068053841 ms

Accuracy: 93.9
Positives Recall: 96.0
Positives Precision: 91.2
Positives F1: 93.5
Distance:  0.38
Max cluster size:  172
Median cluster size:  2.0
Avg cluster size: 3.89


In [None]:
pretrained_embs = initial_model.embedder.model.embeddings.word_embeddings.weight.cuda()

bidir_lstm_distill_cos_similarity_from_pretrained = SingleClusteringDistillModel(VOCAB_SIZE, 
                                                                bidirectional=True,
                                                                evaluate_similarity=True,
                                                                pretrained=True, 
                                                                pretrained_embs=pretrained_embs)
train(bidir_lstm_distill_cos_similarity_from_pretrained, train_loader, val_loader,
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_from_pretrained',
      )

Train: 100%|██████████| 4037/4037 [10:12<00:00,  6.59it/s]


Train_loss:  0.1826794395580551


Val: 100%|██████████| 1010/1010 [01:45<00:00,  9.56it/s]


Val_loss:  0.11902867370785407


Train: 100%|██████████| 4037/4037 [10:02<00:00,  6.70it/s]


Train_loss:  0.1527713813285293


Val: 100%|██████████| 1010/1010 [01:46<00:00,  9.50it/s]


Val_loss:  0.11393756367119474


Train: 100%|██████████| 4037/4037 [10:00<00:00,  6.72it/s]


Train_loss:  0.1472162621082863


Val: 100%|██████████| 1010/1010 [01:47<00:00,  9.42it/s]


Val_loss:  0.11041433390183636


Train: 100%|██████████| 4037/4037 [09:59<00:00,  6.74it/s]


Train_loss:  0.14314149446727364


Val: 100%|██████████| 1010/1010 [01:46<00:00,  9.44it/s]


Val_loss:  0.10863963128370237


Train: 100%|██████████| 4037/4037 [09:59<00:00,  6.73it/s]


Train_loss:  0.1397748488364587


Val: 100%|██████████| 1010/1010 [01:47<00:00,  9.42it/s]


Val_loss:  0.10697015760533285


Train: 100%|██████████| 4037/4037 [10:00<00:00,  6.73it/s]


Train_loss:  0.1368472192679168


Val: 100%|██████████| 1010/1010 [01:48<00:00,  9.33it/s]


Val_loss:  0.10630127876057809


Train: 100%|██████████| 4037/4037 [10:00<00:00,  6.72it/s]


Train_loss:  0.134191687612244


Val: 100%|██████████| 1010/1010 [01:46<00:00,  9.45it/s]


Val_loss:  0.10562707189161644


Train: 100%|██████████| 4037/4037 [09:59<00:00,  6.74it/s]


Train_loss:  0.1318788143313718


Val: 100%|██████████| 1010/1010 [01:47<00:00,  9.43it/s]


Val_loss:  0.10481807846311263


Train: 100%|██████████| 4037/4037 [09:56<00:00,  6.77it/s]


Train_loss:  0.12979819548728228


Val: 100%|██████████| 1010/1010 [01:45<00:00,  9.57it/s]


Val_loss:  0.10449562656458962


Train: 100%|██████████| 4037/4037 [09:57<00:00,  6.76it/s]


Train_loss:  0.12787438301672835


Val: 100%|██████████| 1010/1010 [01:45<00:00,  9.53it/s]


Val_loss:  0.10438572529319802


Train: 100%|██████████| 4037/4037 [09:57<00:00,  6.76it/s]


Train_loss:  0.12622976023167273


Val: 100%|██████████| 1010/1010 [01:46<00:00,  9.45it/s]


Val_loss:  0.10412078870552287


Train: 100%|██████████| 4037/4037 [09:58<00:00,  6.75it/s]


Train_loss:  0.12477367972773813


Val: 100%|██████████| 1010/1010 [01:46<00:00,  9.49it/s]


Val_loss:  0.10391020247855018


Train: 100%|██████████| 4037/4037 [09:57<00:00,  6.76it/s]


Train_loss:  0.1234019615291798


Val: 100%|██████████| 1010/1010 [01:47<00:00,  9.42it/s]


Val_loss:  0.10397187887408604


Train: 100%|██████████| 4037/4037 [09:55<00:00,  6.77it/s]


Train_loss:  0.1222300349407714


Val: 100%|██████████| 1010/1010 [01:46<00:00,  9.45it/s]


Val_loss:  0.1034272361012871


Train: 100%|██████████| 4037/4037 [09:57<00:00,  6.76it/s]


Train_loss:  0.12113407836078384


Val: 100%|██████████| 1010/1010 [01:47<00:00,  9.42it/s]


Val_loss:  0.10359402579865638


Train: 100%|██████████| 4037/4037 [09:56<00:00,  6.77it/s]


Train_loss:  0.1201390748642491


Val: 100%|██████████| 1010/1010 [01:47<00:00,  9.41it/s]


Val_loss:  0.10358120753864213


Train: 100%|██████████| 4037/4037 [09:57<00:00,  6.75it/s]


Train_loss:  0.1192338986910339


Val: 100%|██████████| 1010/1010 [01:47<00:00,  9.38it/s]


Val_loss:  0.10332769477743844


Train: 100%|██████████| 4037/4037 [09:55<00:00,  6.78it/s]


Train_loss:  0.11840216226939446


Val: 100%|██████████| 1010/1010 [01:45<00:00,  9.58it/s]


Val_loss:  0.10361396340442294


Train: 100%|██████████| 4037/4037 [09:52<00:00,  6.81it/s]


Train_loss:  0.117680360183134


Val: 100%|██████████| 1010/1010 [01:47<00:00,  9.42it/s]


Val_loss:  0.10341306113041898


Train: 100%|██████████| 4037/4037 [09:49<00:00,  6.85it/s]


Train_loss:  0.11701886377541372


Val: 100%|██████████| 1010/1010 [01:44<00:00,  9.66it/s]


Val_loss:  0.10356812516954504


In [None]:
bidir_lstm_distill_cos_similarity_from_pretrained = SingleClusteringDistillModel(VOCAB_SIZE, 
                                                                bidirectional=True,
                                                                evaluate_similarity=True,
                                                                pretrained=True, 
                                                                pretrained_embs=pretrained_embs)

bidir_lstm_distill_cos_similarity_from_pretrained.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_from_pretrained/best-distill-bert.pt'))

distill_embedder = bidir_lstm_distill_cos_similarity_from_pretrained.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.8325269604821605 ms

Accuracy: 94.4
Positives Recall: 95.6
Positives Precision: 92.5
Positives F1: 94.0
Distance:  0.38
Max cluster size:  270
Median cluster size:  2.0
Avg cluster size: 4.00


In [None]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.8566778560637629 ms

Accuracy: 93.4
Positives Recall: 95.2
Positives Precision: 90.9
Positives F1: 93.0
Distance:  0.38
Max cluster size:  184
Median cluster size:  2.0
Avg cluster size: 3.88


####Experiments: Bidirectional LSTM - Cosine similarity loss - pretrained alike



It's reasonable to check if pretrained embeds are really impacting the score or the key for a better result is the size of initial word embeddings.

In [None]:
#Checking random initialization similar to pretrained 

pretrained_embs = initial_model.embedder.model.embeddings.word_embeddings.weight.cuda()
word_embs_dim = pretrained_embs.shape[1] 

bidir_lstm_distill_from_pretrained_alike = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, word_emb_dim=word_embs_dim,
                              evaluate_similarity=True)


train(bidir_lstm_distill_from_pretrained_alike, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_from_pretrained_alike',
      )

Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.70it/s]


Train_loss:  0.3719831273995782


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.18it/s]


Val_loss:  0.24544119784002452


Train: 100%|██████████| 4037/4037 [07:40<00:00,  8.76it/s]


Train_loss:  0.25891627649325655


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.17it/s]


Val_loss:  0.20469899194080052


Train: 100%|██████████| 4037/4037 [07:39<00:00,  8.79it/s]


Train_loss:  0.23017586925417569


Val: 100%|██████████| 1010/1010 [01:40<00:00, 10.09it/s]


Val_loss:  0.18648389618722686


Train: 100%|██████████| 4037/4037 [07:41<00:00,  8.75it/s]


Train_loss:  0.21460522371253135


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.11it/s]


Val_loss:  0.1756721278740353


Train: 100%|██████████| 4037/4037 [07:39<00:00,  8.78it/s]


Train_loss:  0.20432189057653424


Val: 100%|██████████| 1010/1010 [01:40<00:00, 10.05it/s]


Val_loss:  0.1691484872263974


Train: 100%|██████████| 4037/4037 [07:42<00:00,  8.72it/s]


Train_loss:  0.1969731838552864


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.14it/s]


Val_loss:  0.16395026001432114


Train: 100%|██████████| 4037/4037 [07:45<00:00,  8.68it/s]


Train_loss:  0.1911947357206456


Val: 100%|██████████| 1010/1010 [01:40<00:00, 10.07it/s]


Val_loss:  0.16069650235687094


Train: 100%|██████████| 4037/4037 [07:45<00:00,  8.68it/s]


Train_loss:  0.18663194062002622


Val: 100%|██████████| 1010/1010 [01:40<00:00, 10.07it/s]


Val_loss:  0.1579016350238686


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.71it/s]


Train_loss:  0.1827612666597774


Val: 100%|██████████| 1010/1010 [01:38<00:00, 10.26it/s]


Val_loss:  0.15583709486616681


Train: 100%|██████████| 4037/4037 [07:40<00:00,  8.76it/s]


Train_loss:  0.17955780358490445


Val: 100%|██████████| 1010/1010 [01:37<00:00, 10.31it/s]


Val_loss:  0.15432779869705637


Train: 100%|██████████| 4037/4037 [07:39<00:00,  8.78it/s]


Train_loss:  0.176737694836007


Val: 100%|██████████| 1010/1010 [01:38<00:00, 10.23it/s]


Val_loss:  0.1529204205678138


Train: 100%|██████████| 4037/4037 [07:40<00:00,  8.76it/s]


Train_loss:  0.17430233058579647


Val: 100%|██████████| 1010/1010 [01:38<00:00, 10.26it/s]


Val_loss:  0.15173826976610622


Train: 100%|██████████| 4037/4037 [07:41<00:00,  8.75it/s]


Train_loss:  0.17212107164740123


Val: 100%|██████████| 1010/1010 [01:38<00:00, 10.24it/s]


Val_loss:  0.15063455594847883


Train: 100%|██████████| 4037/4037 [07:40<00:00,  8.76it/s]


Train_loss:  0.17016789962276013


Val: 100%|██████████| 1010/1010 [01:37<00:00, 10.32it/s]


Val_loss:  0.1499963960540508


Train: 100%|██████████| 4037/4037 [07:38<00:00,  8.81it/s]


Train_loss:  0.16838839321901172


Val: 100%|██████████| 1010/1010 [01:38<00:00, 10.27it/s]


Val_loss:  0.1493186958682761


Train: 100%|██████████| 4037/4037 [07:40<00:00,  8.77it/s]


Train_loss:  0.16676777125372372


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.14it/s]


Val_loss:  0.14877575942558693


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.70it/s]


Train_loss:  0.16529444402589313


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.13it/s]


Val_loss:  0.14816394570898492


Train: 100%|██████████| 4037/4037 [07:41<00:00,  8.75it/s]


Train_loss:  0.16398804171186007


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.12it/s]


Val_loss:  0.14790522805252826


Train: 100%|██████████| 4037/4037 [07:44<00:00,  8.70it/s]


Train_loss:  0.1627089986486056


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.15it/s]


Val_loss:  0.14750357786226623


Train: 100%|██████████| 4037/4037 [07:45<00:00,  8.68it/s]


Train_loss:  0.16164504280356584


Val: 100%|██████████| 1010/1010 [01:41<00:00,  9.98it/s]


Val_loss:  0.14721872530049043


In [None]:
word_embs_dim = pretrained_embs.shape[1] 

bidir_lstm_distill_from_pretrained_alike = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, word_emb_dim=word_embs_dim,
                              evaluate_similarity=True)

bidir_lstm_distill_from_pretrained_alike.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_from_pretrained_alike/best-distill-bert.pt'))

distill_embedder = bidir_lstm_distill_from_pretrained_alike.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.9082841508179547 ms

Accuracy: 93.2
Positives Recall: 94.2
Positives Precision: 91.4
Positives F1: 92.8
Distance:  0.38
Max cluster size:  310
Median cluster size:  2.0
Avg cluster size: 4.30


In [None]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.8422611976548824 ms

Accuracy: 93.0
Positives Recall: 94.6
Positives Precision: 90.6
Positives F1: 92.6
Distance:  0.38
Max cluster size:  200
Median cluster size:  2
Avg cluster size: 4.17


So, it turns out that pretrained embeddings really improve performance, but models based on pretrained embeds are much heavier. 

In general BiLSTM with Cosine embedding loss work mostly better.

####Experiments: Bidirectional LSTM - Cosine similarity loss - variations

Some more architecture and parameters experiments with a bidirectional LSTM.

#####Adding ReLU between linear layers

In [None]:
bidir_lstm_cosine_similarity_distill_relu = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, add_relu=True, evaluate_similarity=True)


train(bidir_lstm_cosine_similarity_distill_relu, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_relu',
      )

Train: 100%|██████████| 4037/4037 [08:43<00:00,  7.71it/s]


Train_loss:  0.4517234906723245


Val: 100%|██████████| 1010/1010 [01:47<00:00,  9.36it/s]


Val_loss:  0.3223376839308669


Train: 100%|██████████| 4037/4037 [08:17<00:00,  8.12it/s]


Train_loss:  0.3076155859579914


Val: 100%|██████████| 1010/1010 [01:47<00:00,  9.43it/s]


Val_loss:  0.2632161025910244


Train: 100%|██████████| 4037/4037 [08:09<00:00,  8.25it/s]


Train_loss:  0.26570776453842343


Val: 100%|██████████| 1010/1010 [01:47<00:00,  9.43it/s]


Val_loss:  0.23415457604780684


Train: 100%|██████████| 4037/4037 [08:06<00:00,  8.30it/s]


Train_loss:  0.24220915193607193


Val: 100%|██████████| 1010/1010 [01:47<00:00,  9.43it/s]


Val_loss:  0.2163775980216282


Train: 100%|██████████| 4037/4037 [08:06<00:00,  8.30it/s]


Train_loss:  0.22650387634638428


Val: 100%|██████████| 1010/1010 [01:46<00:00,  9.45it/s]


Val_loss:  0.20375420750362067


Train: 100%|██████████| 4037/4037 [08:03<00:00,  8.35it/s]


Train_loss:  0.2151346438783935


Val: 100%|██████████| 1010/1010 [01:46<00:00,  9.50it/s]


Val_loss:  0.19477396325094729


Train: 100%|██████████| 4037/4037 [08:03<00:00,  8.36it/s]


Train_loss:  0.20640928398093036


Val: 100%|██████████| 1010/1010 [01:45<00:00,  9.53it/s]


Val_loss:  0.18774912783668524


Train: 100%|██████████| 4037/4037 [08:02<00:00,  8.37it/s]


Train_loss:  0.19939407099629108


Val: 100%|██████████| 1010/1010 [01:46<00:00,  9.52it/s]


Val_loss:  0.18211991240235406


Train: 100%|██████████| 4037/4037 [08:01<00:00,  8.38it/s]


Train_loss:  0.19354350595773107


Val: 100%|██████████| 1010/1010 [01:46<00:00,  9.51it/s]


Val_loss:  0.17752028405824988


Train: 100%|██████████| 4037/4037 [08:00<00:00,  8.40it/s]


Train_loss:  0.18865979439486


Val: 100%|██████████| 1010/1010 [01:46<00:00,  9.49it/s]


Val_loss:  0.17407652379037594


Train: 100%|██████████| 4037/4037 [07:59<00:00,  8.41it/s]


Train_loss:  0.18441491027667614


Val: 100%|██████████| 1010/1010 [01:45<00:00,  9.54it/s]


Val_loss:  0.17083447079236458


Train: 100%|██████████| 4037/4037 [07:58<00:00,  8.44it/s]


Train_loss:  0.18076119993754533


Val: 100%|██████████| 1010/1010 [01:45<00:00,  9.54it/s]


Val_loss:  0.16809405026824933


Train: 100%|██████████| 4037/4037 [07:55<00:00,  8.49it/s]


Train_loss:  0.17752971222718686


Val: 100%|██████████| 1010/1010 [01:45<00:00,  9.58it/s]


Val_loss:  0.16594340067368077


Train: 100%|██████████| 4037/4037 [07:50<00:00,  8.59it/s]


Train_loss:  0.17474641572628957


Val: 100%|██████████| 1010/1010 [01:43<00:00,  9.75it/s]


Val_loss:  0.16380922076428775


Train: 100%|██████████| 4037/4037 [07:48<00:00,  8.62it/s]


Train_loss:  0.17224086510345957


Val: 100%|██████████| 1010/1010 [01:43<00:00,  9.76it/s]


Val_loss:  0.1621334323662484


Train: 100%|██████████| 4037/4037 [07:46<00:00,  8.66it/s]


Train_loss:  0.16994623827079308


Val: 100%|██████████| 1010/1010 [01:43<00:00,  9.79it/s]


Val_loss:  0.16066152309332457


Train: 100%|██████████| 4037/4037 [07:45<00:00,  8.68it/s]


Train_loss:  0.16782527530801963


Val: 100%|██████████| 1010/1010 [01:43<00:00,  9.71it/s]


Val_loss:  0.15940403209341042


Train: 100%|██████████| 4037/4037 [07:45<00:00,  8.66it/s]


Train_loss:  0.16589387954033435


Val: 100%|██████████| 1010/1010 [01:43<00:00,  9.74it/s]


Val_loss:  0.15819627901537384


Train: 100%|██████████| 4037/4037 [07:44<00:00,  8.69it/s]


Train_loss:  0.1641525467899179


Val: 100%|██████████| 1010/1010 [01:44<00:00,  9.68it/s]


Val_loss:  0.15698775316053482


Train: 100%|██████████| 4037/4037 [07:43<00:00,  8.70it/s]


Train_loss:  0.16261611715219093


Val: 100%|██████████| 1010/1010 [01:43<00:00,  9.79it/s]


Val_loss:  0.15645668499204396


In [None]:
bidir_lstm_cosine_similarity_distill_relu = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, add_relu=True, evaluate_similarity=True)

bidir_lstm_cosine_similarity_distill_relu.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_relu/best-distill-bert.pt'))

distill_embedder = bidir_lstm_cosine_similarity_distill_relu.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, 
                                              tokenizer, batch_size=8, max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.9087116057416832 ms

Accuracy: 92.3
Positives Recall: 93.2
Positives Precision: 90.3
Positives F1: 91.7
Distance:  0.38
Max cluster size:  318
Median cluster size:  2
Avg cluster size: 5.31


In [None]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.9416431484234398 ms

Accuracy: 91.6
Positives Recall: 92.9
Positives Precision: 89.2
Positives F1: 91.0
Distance:  0.38
Max cluster size:  159
Median cluster size:  2
Avg cluster size: 5.16


#####Zero dropout

In [None]:
bidir_lstm_distill_dropout0 = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                               bidirectional=True, dropout=0, evaluate_similarity=True,)


train(bidir_lstm_distill_dropout0, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_dropout0',
      )

Train: 100%|██████████| 4037/4037 [06:32<00:00, 10.27it/s]


Train_loss:  0.4113105625619639


Val: 100%|██████████| 1010/1010 [01:28<00:00, 11.41it/s]


Val_loss:  0.2967117350732604


Train: 100%|██████████| 4037/4037 [06:34<00:00, 10.23it/s]


Train_loss:  0.2550703342139801


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.61it/s]


Val_loss:  0.23309444323759174


Train: 100%|██████████| 4037/4037 [06:34<00:00, 10.23it/s]


Train_loss:  0.20767444026109244


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.66it/s]


Val_loss:  0.20259263971179448


Train: 100%|██████████| 4037/4037 [06:33<00:00, 10.26it/s]


Train_loss:  0.1819322688625385


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.65it/s]


Val_loss:  0.1844931159285077


Train: 100%|██████████| 4037/4037 [06:33<00:00, 10.26it/s]


Train_loss:  0.16560587526422135


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.51it/s]


Val_loss:  0.17350130612164757


Train: 100%|██████████| 4037/4037 [06:30<00:00, 10.33it/s]


Train_loss:  0.15401193307599423


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.55it/s]


Val_loss:  0.16436874796405934


Train: 100%|██████████| 4037/4037 [06:31<00:00, 10.31it/s]


Train_loss:  0.14529328752188844


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.60it/s]


Val_loss:  0.15803357940541699


Train: 100%|██████████| 4037/4037 [06:33<00:00, 10.25it/s]


Train_loss:  0.13842412797396528


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.54it/s]


Val_loss:  0.1529582657395846


Train: 100%|██████████| 4037/4037 [06:31<00:00, 10.32it/s]


Train_loss:  0.13301484165863678


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.58it/s]


Val_loss:  0.1492617793157349


Train: 100%|██████████| 4037/4037 [06:28<00:00, 10.40it/s]


Train_loss:  0.1285881912538699


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.63it/s]


Val_loss:  0.14638375360935219


Train: 100%|██████████| 4037/4037 [06:30<00:00, 10.33it/s]


Train_loss:  0.12485996573660191


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.58it/s]


Val_loss:  0.1436622317225124


Train: 100%|██████████| 4037/4037 [06:30<00:00, 10.34it/s]


Train_loss:  0.1216670235672424


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.57it/s]


Val_loss:  0.14174639046334261


Train: 100%|██████████| 4037/4037 [06:32<00:00, 10.29it/s]


Train_loss:  0.11889373153976761


Val: 100%|██████████| 1010/1010 [01:28<00:00, 11.44it/s]


Val_loss:  0.1395885048471964


Train: 100%|██████████| 4037/4037 [06:40<00:00, 10.09it/s]


Train_loss:  0.11644654005340414


Val: 100%|██████████| 1010/1010 [01:28<00:00, 11.35it/s]


Val_loss:  0.13814678522198706


Train: 100%|██████████| 4037/4037 [06:38<00:00, 10.14it/s]


Train_loss:  0.1142538136053754


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.54it/s]


Val_loss:  0.1366670646238617


Train: 100%|██████████| 4037/4037 [06:30<00:00, 10.33it/s]


Train_loss:  0.11227834267255393


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.49it/s]


Val_loss:  0.13571706231188427


Train: 100%|██████████| 4037/4037 [06:32<00:00, 10.28it/s]


Train_loss:  0.11050553145100667


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.54it/s]


Val_loss:  0.13437324023921512


Train: 100%|██████████| 4037/4037 [06:33<00:00, 10.25it/s]


Train_loss:  0.10887835584675896


Val: 100%|██████████| 1010/1010 [01:28<00:00, 11.38it/s]


Val_loss:  0.13353337170893775


Train: 100%|██████████| 4037/4037 [06:30<00:00, 10.33it/s]


Train_loss:  0.10737783847060142


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.50it/s]


Val_loss:  0.13266496468600236


Train: 100%|██████████| 4037/4037 [06:29<00:00, 10.36it/s]


Train_loss:  0.10601139769321848


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.62it/s]


Val_loss:  0.13188875619896936


In [None]:
bidir_lstm_distill_dropout0 = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                               bidirectional=True, dropout=0, evaluate_similarity=True,)

bidir_lstm_distill_dropout0.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_dropout0/best-distill-bert.pt'))

distill_embedder = bidir_lstm_distill_dropout0.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, 
                                              tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.7914375324648215 ms

Accuracy: 93.8
Positives Recall: 95.0
Positives Precision: 91.9
Positives F1: 93.4
Distance:  0.38
Max cluster size:  297
Median cluster size:  2.0
Avg cluster size: 3.80


In [None]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.8030722337534201 ms

Accuracy: 93.5
Positives Recall: 95.6
Positives Precision: 90.7
Positives F1: 93.1
Distance:  0.38
Max cluster size:  170
Median cluster size:  2.0
Avg cluster size: 3.69


#####Zero dropout + deleted linear

In [None]:
bidir_lstm_distill_dropout0_del_linear = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              dropout=0, evaluate_similarity=True,
                              del_linear=True, bidirectional=True)


train(bidir_lstm_distill_dropout0_del_linear, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_dropout0_del_linear',
      )

In [None]:
bidir_lstm_distill_dropout0_del_linear = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              dropout=0, evaluate_similarity=True,
                              del_linear=True, bidirectional=True)

bidir_lstm_distill_dropout0_del_linear.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_dropout0_del_linear/best-distill-bert.pt'))

distill_embedder = bidir_lstm_distill_dropout0_del_linear.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.7941229934948849 ms

Accuracy: 93.8
Positives Recall: 94.5
Positives Precision: 92.3
Positives F1: 93.4
Distance:  0.38
Max cluster size:  332
Median cluster size:  2.0
Avg cluster size: 3.73


In [None]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.7835019862846595 ms

Accuracy: 93.1
Positives Recall: 94.2
Positives Precision: 91.0
Positives F1: 92.6
Distance:  0.38
Max cluster size:  164
Median cluster size:  2
Avg cluster size: 3.62


#####Zero dropout + deleted linear + *attentive aggregation*

In [None]:
bidir_lstm_distill_dropout0_del_linear_attn = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              dropout=0, evaluate_similarity=True,
                              del_linear=True, bidirectional=True, 
                              attentive_aggregation=True)


train(bidir_lstm_distill_dropout0_del_linear_attn, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_dropout0_del_linear_attn',
      )

In [20]:
bidir_lstm_distill_dropout0_del_linear_attn = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              dropout=0, evaluate_similarity=True,
                              del_linear=True, bidirectional=True,
                              attentive_aggregation=True)

bidir_lstm_distill_dropout0_del_linear_attn.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_dropout0_del_linear_attn/best-distill-bert.pt'))

distill_embedder = bidir_lstm_distill_dropout0_del_linear_attn.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.6317020336470281 ms

Accuracy: 94.2
Positives Recall: 95.2
Positives Precision: 92.5
Positives F1: 93.8
Distance:  0.38
Max cluster size:  262
Median cluster size:  2.0
Avg cluster size: 3.86


In [22]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.7009760991369375 ms

Accuracy: 93.7
Positives Recall: 95.2
Positives Precision: 91.6
Positives F1: 93.3
Distance:  0.38
Max cluster size:  171
Median cluster size:  2
Avg cluster size: 3.73


#####Zero dropout / LSTM with no linear layers

In [None]:
bidir_lstm_distill_dropout0_del_2_linear = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, lstm_hidden_dim=(TARGET_SIZE//2),
                              del_2_linear=True, evaluate_similarity=True,
                              )

train(bidir_lstm_distill_dropout0_del_2_linear, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_no_linear_0_dropout',
      )

Train: 100%|██████████| 4037/4037 [06:00<00:00, 11.20it/s]


Train_loss:  0.45590375669868366


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.38it/s]


Val_loss:  0.33710911709989144


Train: 100%|██████████| 4037/4037 [06:09<00:00, 10.94it/s]


Train_loss:  0.2924229537905287


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.34it/s]


Val_loss:  0.26613424771531174


Train: 100%|██████████| 4037/4037 [06:01<00:00, 11.16it/s]


Train_loss:  0.2405035736663983


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.34it/s]


Val_loss:  0.23197123310051404


Train: 100%|██████████| 4037/4037 [06:02<00:00, 11.14it/s]


Train_loss:  0.21177311861957998


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.32it/s]


Val_loss:  0.2113487277916253


Train: 100%|██████████| 4037/4037 [06:01<00:00, 11.16it/s]


Train_loss:  0.19281081513625617


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.37it/s]


Val_loss:  0.1972250426558463


Train: 100%|██████████| 4037/4037 [06:04<00:00, 11.09it/s]


Train_loss:  0.1790037370326589


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.37it/s]


Val_loss:  0.18669637832459815


Train: 100%|██████████| 4037/4037 [06:01<00:00, 11.15it/s]


Train_loss:  0.16829149331842821


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.38it/s]


Val_loss:  0.17875098524704197


Train: 100%|██████████| 4037/4037 [06:01<00:00, 11.18it/s]


Train_loss:  0.15972095763599092


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.29it/s]


Val_loss:  0.17194433385884206


Train: 100%|██████████| 4037/4037 [05:59<00:00, 11.22it/s]


Train_loss:  0.15272231845794662


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.35it/s]


Val_loss:  0.16669997070798404


Train: 100%|██████████| 4037/4037 [05:59<00:00, 11.22it/s]


Train_loss:  0.14683834009262062


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.31it/s]


Val_loss:  0.16223943645570263


Train: 100%|██████████| 4037/4037 [06:01<00:00, 11.18it/s]


Train_loss:  0.14181420642751613


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.35it/s]


Val_loss:  0.15831507115687782


Train: 100%|██████████| 4037/4037 [06:01<00:00, 11.17it/s]


Train_loss:  0.13747104525898143


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.34it/s]


Val_loss:  0.15525695555853858


Train: 100%|██████████| 4037/4037 [06:02<00:00, 11.14it/s]


Train_loss:  0.1336763405226434


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.21it/s]


Val_loss:  0.15231056693988043


Train: 100%|██████████| 4037/4037 [06:02<00:00, 11.12it/s]


Train_loss:  0.1303342365151977


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.38it/s]


Val_loss:  0.14998569158931593


Train: 100%|██████████| 4037/4037 [06:02<00:00, 11.14it/s]


Train_loss:  0.12736835520017312


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.25it/s]


Val_loss:  0.14786049375075655


Train: 100%|██████████| 4037/4037 [06:04<00:00, 11.07it/s]


Train_loss:  0.12470147423079052


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.29it/s]


Val_loss:  0.14590622477928866


Train: 100%|██████████| 4037/4037 [06:02<00:00, 11.13it/s]


Train_loss:  0.12229023853383543


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.30it/s]


Val_loss:  0.14441491846206178


Train: 100%|██████████| 4037/4037 [06:03<00:00, 11.12it/s]


Train_loss:  0.1201144043005802


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.33it/s]


Val_loss:  0.1427127432191827


Train: 100%|██████████| 4037/4037 [06:03<00:00, 11.11it/s]


Train_loss:  0.11811582606627223


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.26it/s]


Val_loss:  0.1415595237204064


Train: 100%|██████████| 4037/4037 [06:02<00:00, 11.13it/s]


Train_loss:  0.11628820423808249


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.22it/s]


Val_loss:  0.14016308501422575


In [None]:
bidir_lstm_distill_dropout0_del_2_linear = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, lstm_hidden_dim=(TARGET_SIZE//2),
                              del_2_linear=True, evaluate_similarity=True,
                              )

bidir_lstm_distill_dropout0_del_2_linear.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_no_linear_0_dropout/best-distill-bert.pt'))

distill_embedder = bidir_lstm_distill_dropout0_del_2_linear.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.7147467624143775 ms


0.9301095579901776

In [None]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

#####Zero dropout / LSTM with no linear layers + *attentive aggregation*

In [None]:
bidir_lstm_distill_dropout0_del_2_linear_attn = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, lstm_hidden_dim=(TARGET_SIZE//2),
                              del_2_linear=True, evaluate_similarity=True,
                              attentive_aggregation=True,
                              )

train(bidir_lstm_distill_dropout0_del_2_linear_attn, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_no_linear_0_dropout_attn',
      )

In [11]:
bidir_lstm_distill_dropout0_del_2_linear_attn = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, lstm_hidden_dim=(TARGET_SIZE//2),
                              del_2_linear=True, evaluate_similarity=True,
                              attentive_aggregation=True,
                              )

bidir_lstm_distill_dropout0_del_2_linear_attn.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_no_linear_0_dropout_attn/best-distill-bert.pt'))

distill_embedder = bidir_lstm_distill_dropout0_del_2_linear_attn.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.6654116432980237 ms

Accuracy: 93.7
Positives Recall: 94.3
Positives Precision: 92.3
Positives F1: 93.3
Distance:  0.38
Max cluster size:  256
Median cluster size:  2
Avg cluster size: 3.84


In [12]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.6268822015030993 ms

Accuracy: 93.3
Positives Recall: 94.6
Positives Precision: 91.2
Positives F1: 92.9
Distance:  0.38
Max cluster size:  167
Median cluster size:  2
Avg cluster size: 3.70


#####Zero dropout + deleted linear + 2 LSTM layers

In [None]:
bidir_lstm_distill_dropout0_del_2lstm_layers = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, 
                              del_linear=True, evaluate_similarity=True,
                              lstm_layers_count=2,
                              )


train(bidir_lstm_distill_dropout0_del_2lstm_layers, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_2layers',
      )

Train: 100%|██████████| 4037/4037 [07:46<00:00,  8.65it/s]


Train_loss:  0.38224130868686124


Val: 100%|██████████| 1010/1010 [01:28<00:00, 11.37it/s]


Val_loss:  0.25747238826928553


Train: 100%|██████████| 4037/4037 [07:48<00:00,  8.62it/s]


Train_loss:  0.21698863135848823


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.52it/s]


Val_loss:  0.1952577345041329


Train: 100%|██████████| 4037/4037 [07:48<00:00,  8.61it/s]


Train_loss:  0.17211497592956562


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.58it/s]


Val_loss:  0.16681014288413168


Train: 100%|██████████| 4037/4037 [07:48<00:00,  8.62it/s]


Train_loss:  0.1492053377539882


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.55it/s]


Val_loss:  0.15183319773539314


Train: 100%|██████████| 4037/4037 [07:49<00:00,  8.60it/s]


Train_loss:  0.13552635908901628


Val: 100%|██████████| 1010/1010 [01:28<00:00, 11.47it/s]


Val_loss:  0.14246904531322246


Train: 100%|██████████| 4037/4037 [07:49<00:00,  8.61it/s]


Train_loss:  0.12629623582368799


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.57it/s]


Val_loss:  0.1362253335488555


Train: 100%|██████████| 4037/4037 [07:49<00:00,  8.60it/s]


Train_loss:  0.11953382375058737


Val: 100%|██████████| 1010/1010 [01:28<00:00, 11.35it/s]


Val_loss:  0.13161402441171807


Train: 100%|██████████| 4037/4037 [07:49<00:00,  8.61it/s]


Train_loss:  0.1143291059130291


Val: 100%|██████████| 1010/1010 [01:28<00:00, 11.46it/s]


Val_loss:  0.12848083503030247


Train: 100%|██████████| 4037/4037 [07:47<00:00,  8.64it/s]


Train_loss:  0.11011940583349622


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.53it/s]


Val_loss:  0.1256797033933922


Train: 100%|██████████| 4037/4037 [07:47<00:00,  8.64it/s]


Train_loss:  0.10663875689090035


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.59it/s]


Val_loss:  0.12353672022578073


Train: 100%|██████████| 4037/4037 [07:46<00:00,  8.66it/s]


Train_loss:  0.10366064196546529


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.53it/s]


Val_loss:  0.12179920816737255


Train: 100%|██████████| 4037/4037 [07:45<00:00,  8.67it/s]


Train_loss:  0.10108318755354767


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.59it/s]


Val_loss:  0.12018394019274646


Train: 100%|██████████| 4037/4037 [07:49<00:00,  8.60it/s]


Train_loss:  0.0988303810638156


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.67it/s]


Val_loss:  0.1192760787529273


Train: 100%|██████████| 4037/4037 [07:49<00:00,  8.59it/s]


Train_loss:  0.09682958986957123


Val: 100%|██████████| 1010/1010 [01:28<00:00, 11.44it/s]


Val_loss:  0.11808535465021586


Train: 100%|██████████| 4037/4037 [07:49<00:00,  8.61it/s]


Train_loss:  0.09503121282826225


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.56it/s]


Val_loss:  0.11725067511769449


Train: 100%|██████████| 4037/4037 [07:49<00:00,  8.60it/s]


Train_loss:  0.0933852326417232


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.53it/s]


Val_loss:  0.11649468670418976


Train: 100%|██████████| 4037/4037 [07:49<00:00,  8.61it/s]


Train_loss:  0.09190560005124754


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.63it/s]


Val_loss:  0.11592054580566406


Train: 100%|██████████| 4037/4037 [07:49<00:00,  8.60it/s]


Train_loss:  0.0905174189030002


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.58it/s]


Val_loss:  0.1153023214649577


Train: 100%|██████████| 4037/4037 [07:49<00:00,  8.59it/s]


Train_loss:  0.08927298112770701


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.57it/s]


Val_loss:  0.1147477903493117


Train: 100%|██████████| 4037/4037 [07:50<00:00,  8.58it/s]


Train_loss:  0.08808871829908632


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.55it/s]


Val_loss:  0.11444825672888713


In [None]:
bidir_lstm_distill_dropout0_del_2lstm_layers = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, 
                              del_linear=True, evaluate_similarity=True,
                              lstm_layers_count=2)

bidir_lstm_distill_dropout0_del_2lstm_layers.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_2layers/best-distill-bert.pt'))

distill_embedder = bidir_lstm_distill_dropout0_del_2lstm_layers.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  1.3904166992442066 ms

Accuracy: 94.2
Positives Recall: 95.5
Positives Precision: 92.3
Positives F1: 93.9
Distance:  0.38
Max cluster size:  293
Median cluster size:  2
Avg cluster size: 4.08


In [None]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  1.4165139411801217 ms

Accuracy: 93.5
Positives Recall: 95.7
Positives Precision: 90.7
Positives F1: 93.1
Distance:  0.38
Max cluster size:  176
Median cluster size:  2.0
Avg cluster size: 3.94


#####Zero dropout + deleted linear + 2 LSTM layers + *attentive aggregation*

In [None]:
bidir_lstm_distill_dropout0_del_2lstm_layers_attn = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, 
                              del_linear=True, evaluate_similarity=True,
                              lstm_layers_count=2,
                              attentive_aggregation=True,
                              )


train(bidir_lstm_distill_dropout0_del_2lstm_layers_attn, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_2layers_attn',
      )

In [13]:
bidir_lstm_distill_dropout0_del_2lstm_layers_attn = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, 
                              del_linear=True, evaluate_similarity=True,
                              lstm_layers_count=2,
                              attentive_aggregation=True,)

bidir_lstm_distill_dropout0_del_2lstm_layers_attn.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_LSTM_2layers_attn/best-distill-bert.pt'))

distill_embedder = bidir_lstm_distill_dropout0_del_2lstm_layers_attn.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  1.1977450613006653 ms

Accuracy: 94.3
Positives Recall: 95.3
Positives Precision: 92.5
Positives F1: 93.9
Distance:  0.38
Max cluster size:  283
Median cluster size:  2
Avg cluster size: 4.07


In [14]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  1.1479616358183296 ms

Accuracy: 93.8
Positives Recall: 95.7
Positives Precision: 91.3
Positives F1: 93.4
Distance:  0.38
Max cluster size:  176
Median cluster size:  2
Avg cluster size: 3.92


##--Comparisons--

###Single loss on the original data

In [58]:
labeled_train_recs_single = dict()

for i in train_records:
  for j in i:
    labeled_train_recs_single[j['url']] = j

labeled_train_recs_single = list(labeled_train_recs_single.values())

In [59]:
labeled_val_recs_single = dict()

for i in val_records:
  for j in i:
    labeled_val_recs_single[j['url']] = j

labeled_val_recs_single = list(labeled_val_recs_single.values())

In [62]:
labeled_val_embeddings_bert = records_to_embeds(labeled_val_recs_single, embedder, 
                                          initial_tokenizer, batch_size=8, 
                                          max_tokens_count=MAX_TOKENS)
labeled_train_embeddings_bert = records_to_embeds(labeled_train_recs_single, embedder, 
                                            initial_tokenizer, batch_size=8, 
                                            max_tokens_count=MAX_TOKENS)

In [63]:
BATCH_SIZE = 32

train_loader_labeled, val_loader_labeled, tokenizer_triplets = get_loaders(
                                       labeled_train_recs_single, 
                                       labeled_val_recs_single, 
                                       INITIAL_MODEL, MAX_TOKENS, 
                                       BATCH_SIZE, labeled_train_embeddings_bert,
                                       labeled_val_embeddings_bert,
                                       )

In [44]:
single_labeled_data_lstm = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE)


train(single_labeled_data_lstm, train_loader_labeled, val_loader_labeled, 
      100, '/content/drive/MyDrive/NewsBert/Comparisons/single_labeled_data_lstm',
      True)

Train: 100%|██████████| 362/362 [00:11<00:00, 31.73it/s]


Train_loss:  0.003719133842036704


Val: 100%|██████████| 185/185 [00:05<00:00, 36.98it/s]


Val_loss:  0.003459727404186049


Train: 100%|██████████| 362/362 [00:11<00:00, 31.26it/s]


Train_loss:  0.0034011234818746784


Val: 100%|██████████| 185/185 [00:05<00:00, 36.82it/s]


Val_loss:  0.0031747199242582193


Train: 100%|██████████| 362/362 [00:11<00:00, 31.75it/s]


Train_loss:  0.0031679243534978333


Val: 100%|██████████| 185/185 [00:04<00:00, 37.04it/s]


Val_loss:  0.0029458167396385123


Train: 100%|██████████| 362/362 [00:11<00:00, 31.10it/s]


Train_loss:  0.0029810972091129555


Val: 100%|██████████| 185/185 [00:04<00:00, 37.38it/s]


Val_loss:  0.0027374024930837994


Train: 100%|██████████| 362/362 [00:11<00:00, 31.39it/s]


Train_loss:  0.002822991742917146


Val: 100%|██████████| 185/185 [00:05<00:00, 36.92it/s]


Val_loss:  0.002584883015653169


Train: 100%|██████████| 362/362 [00:11<00:00, 31.22it/s]


Train_loss:  0.002696043263750346


Val: 100%|██████████| 185/185 [00:04<00:00, 37.75it/s]


Val_loss:  0.0024671028116466224


Train: 100%|██████████| 362/362 [00:11<00:00, 31.34it/s]


Train_loss:  0.002583761495610092


Val: 100%|██████████| 185/185 [00:04<00:00, 37.55it/s]


Val_loss:  0.002342364544401298


Train: 100%|██████████| 362/362 [00:11<00:00, 30.99it/s]


Train_loss:  0.0024863471764587485


Val: 100%|██████████| 185/185 [00:05<00:00, 36.81it/s]


Val_loss:  0.0022470557250434883


Train: 100%|██████████| 362/362 [00:11<00:00, 31.35it/s]


Train_loss:  0.0024027861860575596


Val: 100%|██████████| 185/185 [00:05<00:00, 36.61it/s]


Val_loss:  0.0021543885360950153


Train: 100%|██████████| 362/362 [00:11<00:00, 31.36it/s]


Train_loss:  0.002326573984239518


Val: 100%|██████████| 185/185 [00:05<00:00, 36.63it/s]


Val_loss:  0.0020729957305750733


Train: 100%|██████████| 362/362 [00:11<00:00, 31.39it/s]


Train_loss:  0.002260919878367355


Val: 100%|██████████| 185/185 [00:05<00:00, 36.30it/s]


Val_loss:  0.0020263655633489426


Train: 100%|██████████| 362/362 [00:11<00:00, 31.28it/s]


Train_loss:  0.002200545517995295


Val: 100%|██████████| 185/185 [00:05<00:00, 36.63it/s]


Val_loss:  0.0019493527685266895


Train: 100%|██████████| 362/362 [00:11<00:00, 31.40it/s]


Train_loss:  0.0021445443432453956


Val: 100%|██████████| 185/185 [00:04<00:00, 37.27it/s]


Val_loss:  0.0019025179236573544


Train: 100%|██████████| 362/362 [00:11<00:00, 31.33it/s]


Train_loss:  0.0020921730130231677


Val: 100%|██████████| 185/185 [00:05<00:00, 36.50it/s]


Val_loss:  0.0018412966583226179


Train: 100%|██████████| 362/362 [00:11<00:00, 31.31it/s]


Train_loss:  0.002046274658014679


Val: 100%|██████████| 185/185 [00:04<00:00, 37.31it/s]


Val_loss:  0.0017887138308504142


Train: 100%|██████████| 362/362 [00:11<00:00, 31.77it/s]


Train_loss:  0.0020045743568913588


Val: 100%|██████████| 185/185 [00:05<00:00, 36.61it/s]


Val_loss:  0.0017557924500087628


Train: 100%|██████████| 362/362 [00:11<00:00, 31.48it/s]


Train_loss:  0.0019658423207046783


Val: 100%|██████████| 185/185 [00:04<00:00, 37.19it/s]


Val_loss:  0.001705189490363606


Train: 100%|██████████| 362/362 [00:11<00:00, 31.42it/s]


Train_loss:  0.0019297264680364144


Val: 100%|██████████| 185/185 [00:05<00:00, 36.70it/s]


Val_loss:  0.0016778036341624889


Train: 100%|██████████| 362/362 [00:11<00:00, 31.34it/s]


Train_loss:  0.0018976872355306717


Val: 100%|██████████| 185/185 [00:05<00:00, 36.59it/s]


Val_loss:  0.001637943316804799


Train: 100%|██████████| 362/362 [00:11<00:00, 31.47it/s]


Train_loss:  0.0018650382787416282


Val: 100%|██████████| 185/185 [00:05<00:00, 36.46it/s]


Val_loss:  0.0016150051432133124


Train: 100%|██████████| 362/362 [00:11<00:00, 31.22it/s]


Train_loss:  0.0018326668155595701


Val: 100%|██████████| 185/185 [00:04<00:00, 37.79it/s]


Val_loss:  0.001570338913793298


Train: 100%|██████████| 362/362 [00:11<00:00, 31.35it/s]


Train_loss:  0.0018076083599718401


Val: 100%|██████████| 185/185 [00:04<00:00, 37.76it/s]


Val_loss:  0.0015453573637264403


Train: 100%|██████████| 362/362 [00:11<00:00, 31.44it/s]


Train_loss:  0.0017820443123198265


Val: 100%|██████████| 185/185 [00:05<00:00, 36.89it/s]


Val_loss:  0.00152091346051846


Train: 100%|██████████| 362/362 [00:11<00:00, 31.19it/s]


Train_loss:  0.0017559129043060318


Val: 100%|██████████| 185/185 [00:04<00:00, 37.45it/s]


Val_loss:  0.0014939340039131206


Train: 100%|██████████| 362/362 [00:11<00:00, 31.12it/s]


Train_loss:  0.001736721976786255


Val: 100%|██████████| 185/185 [00:05<00:00, 36.71it/s]


Val_loss:  0.0014745908241870034


Train: 100%|██████████| 362/362 [00:11<00:00, 31.12it/s]


Train_loss:  0.001714506996949898


Val: 100%|██████████| 185/185 [00:04<00:00, 37.41it/s]


Val_loss:  0.001448614086729248


Train: 100%|██████████| 362/362 [00:11<00:00, 31.52it/s]


Train_loss:  0.0016951212847286332


Val: 100%|██████████| 185/185 [00:04<00:00, 37.09it/s]


Val_loss:  0.0014294574981102268


Train: 100%|██████████| 362/362 [00:11<00:00, 31.17it/s]


Train_loss:  0.0016742968277038378


Val: 100%|██████████| 185/185 [00:04<00:00, 37.00it/s]


Val_loss:  0.0014085902454885277


Train: 100%|██████████| 362/362 [00:11<00:00, 31.67it/s]


Train_loss:  0.001655998782548045


Val: 100%|██████████| 185/185 [00:04<00:00, 37.23it/s]


Val_loss:  0.001384962912699258


Train: 100%|██████████| 362/362 [00:11<00:00, 31.16it/s]


Train_loss:  0.0016379513516018775


Val: 100%|██████████| 185/185 [00:05<00:00, 36.21it/s]


Val_loss:  0.0013663579854560463


Train: 100%|██████████| 362/362 [00:11<00:00, 31.18it/s]


Train_loss:  0.001622099494651598


Val: 100%|██████████| 185/185 [00:04<00:00, 37.47it/s]


Val_loss:  0.0013550278925764801


Train: 100%|██████████| 362/362 [00:11<00:00, 31.51it/s]


Train_loss:  0.0016082065493172003


Val: 100%|██████████| 185/185 [00:05<00:00, 36.52it/s]


Val_loss:  0.0013342134453154899


Train: 100%|██████████| 362/362 [00:11<00:00, 31.60it/s]


Train_loss:  0.0015933617273505099


Val: 100%|██████████| 185/185 [00:05<00:00, 36.73it/s]


Val_loss:  0.0013204083916403958


Train: 100%|██████████| 362/362 [00:11<00:00, 31.07it/s]


Train_loss:  0.0015784135301859833


Val: 100%|██████████| 185/185 [00:05<00:00, 36.26it/s]


Val_loss:  0.0013031731856785512


Train: 100%|██████████| 362/362 [00:11<00:00, 30.78it/s]


Train_loss:  0.0015654473778410965


Val: 100%|██████████| 185/185 [00:05<00:00, 36.55it/s]


Val_loss:  0.0012887423832875652


Train: 100%|██████████| 362/362 [00:11<00:00, 31.03it/s]


Train_loss:  0.001552537140225508


Val: 100%|██████████| 185/185 [00:05<00:00, 36.95it/s]


Val_loss:  0.0012737705019881596


Train: 100%|██████████| 362/362 [00:11<00:00, 31.49it/s]


Train_loss:  0.001545025523268304


Val: 100%|██████████| 185/185 [00:05<00:00, 36.61it/s]


Val_loss:  0.001260207438010823


Train: 100%|██████████| 362/362 [00:11<00:00, 31.28it/s]


Train_loss:  0.0015311676806909192


Val: 100%|██████████| 185/185 [00:05<00:00, 36.50it/s]


Val_loss:  0.0012529123398299152


Train: 100%|██████████| 362/362 [00:11<00:00, 31.40it/s]


Train_loss:  0.0015179991776133933


Val: 100%|██████████| 185/185 [00:05<00:00, 35.97it/s]


Val_loss:  0.0012416709214448929


Train: 100%|██████████| 362/362 [00:11<00:00, 31.28it/s]


Train_loss:  0.001506069738119429


Val: 100%|██████████| 185/185 [00:04<00:00, 37.21it/s]


Val_loss:  0.0012240720119931407


Train: 100%|██████████| 362/362 [00:11<00:00, 31.45it/s]


Train_loss:  0.0015004002608137606


Val: 100%|██████████| 185/185 [00:04<00:00, 37.29it/s]


Val_loss:  0.001212473476708338


Train: 100%|██████████| 362/362 [00:11<00:00, 31.24it/s]


Train_loss:  0.0014878946478036105


Val: 100%|██████████| 185/185 [00:05<00:00, 36.97it/s]


Val_loss:  0.0011974290998123989


Train: 100%|██████████| 362/362 [00:11<00:00, 30.88it/s]


Train_loss:  0.0014758496873276256


Val: 100%|██████████| 185/185 [00:05<00:00, 36.66it/s]


Val_loss:  0.0011877051466522184


Train: 100%|██████████| 362/362 [00:11<00:00, 31.25it/s]


Train_loss:  0.0014677017332872797


Val: 100%|██████████| 185/185 [00:05<00:00, 36.54it/s]


Val_loss:  0.001181035615012956


Train: 100%|██████████| 362/362 [00:11<00:00, 31.35it/s]


Train_loss:  0.0014592802501926766


Val: 100%|██████████| 185/185 [00:05<00:00, 36.77it/s]


Val_loss:  0.0011718723491916585


Train: 100%|██████████| 362/362 [00:11<00:00, 31.56it/s]


Train_loss:  0.0014521678950179637


Val: 100%|██████████| 185/185 [00:05<00:00, 36.45it/s]


Val_loss:  0.001157336273683688


Train: 100%|██████████| 362/362 [00:11<00:00, 31.08it/s]


Train_loss:  0.001442957403542196


Val: 100%|██████████| 185/185 [00:05<00:00, 36.99it/s]


Val_loss:  0.0011454706898311505


Train: 100%|██████████| 362/362 [00:11<00:00, 31.02it/s]


Train_loss:  0.0014353575887935473


Val: 100%|██████████| 185/185 [00:05<00:00, 36.71it/s]


Val_loss:  0.0011501624337019953


Train: 100%|██████████| 362/362 [00:11<00:00, 31.71it/s]


Train_loss:  0.001425343952836335


Val: 100%|██████████| 185/185 [00:04<00:00, 37.34it/s]


Val_loss:  0.001133265529406836


Train: 100%|██████████| 362/362 [00:11<00:00, 31.51it/s]


Train_loss:  0.0014208674762727312


Val: 100%|██████████| 185/185 [00:05<00:00, 36.33it/s]


Val_loss:  0.0011193700024631579


Train: 100%|██████████| 362/362 [00:11<00:00, 31.40it/s]


Train_loss:  0.0014114843071078192


Val: 100%|██████████| 185/185 [00:05<00:00, 36.61it/s]


Val_loss:  0.0011119040298728726


Train: 100%|██████████| 362/362 [00:11<00:00, 31.50it/s]


Train_loss:  0.001404316227906255


Val: 100%|██████████| 185/185 [00:05<00:00, 36.46it/s]


Val_loss:  0.0011094163184532442


Train: 100%|██████████| 362/362 [00:11<00:00, 31.53it/s]


Train_loss:  0.0013945542821964217


Val: 100%|██████████| 185/185 [00:05<00:00, 36.85it/s]


Val_loss:  0.0011031136659250872


Train: 100%|██████████| 362/362 [00:11<00:00, 31.39it/s]


Train_loss:  0.0013931520580370699


Val: 100%|██████████| 185/185 [00:04<00:00, 37.05it/s]


Val_loss:  0.0010860089741244509


Train: 100%|██████████| 362/362 [00:11<00:00, 31.27it/s]


Train_loss:  0.0013854782671400073


Val: 100%|██████████| 185/185 [00:05<00:00, 36.74it/s]


Val_loss:  0.0010843621192474825


Train: 100%|██████████| 362/362 [00:11<00:00, 31.33it/s]


Train_loss:  0.0013804722031274842


Val: 100%|██████████| 185/185 [00:05<00:00, 36.77it/s]


Val_loss:  0.0010707074040043596


Train: 100%|██████████| 362/362 [00:11<00:00, 31.63it/s]


Train_loss:  0.0013728944571220537


Val: 100%|██████████| 185/185 [00:05<00:00, 36.08it/s]


Val_loss:  0.001070177075578957


Train: 100%|██████████| 362/362 [00:11<00:00, 31.27it/s]


Train_loss:  0.0013658934126824182


Val: 100%|██████████| 185/185 [00:05<00:00, 36.60it/s]


Val_loss:  0.0010672165889519492


Train: 100%|██████████| 362/362 [00:11<00:00, 31.14it/s]


Train_loss:  0.0013620238821922663


Val: 100%|██████████| 185/185 [00:05<00:00, 36.97it/s]


Val_loss:  0.0010567090370239237


Train: 100%|██████████| 362/362 [00:11<00:00, 31.14it/s]


Train_loss:  0.0013547126497708015


Val: 100%|██████████| 185/185 [00:05<00:00, 36.45it/s]


Val_loss:  0.0010484066757222485


Train: 100%|██████████| 362/362 [00:11<00:00, 31.24it/s]


Train_loss:  0.0013540843990502765


Val: 100%|██████████| 185/185 [00:05<00:00, 36.20it/s]


Val_loss:  0.0010442984630227895


Train: 100%|██████████| 362/362 [00:11<00:00, 31.46it/s]


Train_loss:  0.0013460717250809114


Val: 100%|██████████| 185/185 [00:05<00:00, 36.66it/s]


Val_loss:  0.0010331913400944827


Train: 100%|██████████| 362/362 [00:11<00:00, 31.41it/s]


Train_loss:  0.0013420505916123479


Val: 100%|██████████| 185/185 [00:05<00:00, 36.59it/s]


Val_loss:  0.0010275666476101488


Train: 100%|██████████| 362/362 [00:11<00:00, 31.22it/s]


Train_loss:  0.0013371434378520025


Val: 100%|██████████| 185/185 [00:04<00:00, 37.61it/s]


Val_loss:  0.0010253176486396509


Train: 100%|██████████| 362/362 [00:11<00:00, 30.98it/s]


Train_loss:  0.001334083550198014


Val: 100%|██████████| 185/185 [00:05<00:00, 36.20it/s]


Val_loss:  0.0010177128706001552


Train: 100%|██████████| 362/362 [00:11<00:00, 30.91it/s]


Train_loss:  0.001326470558328338


Val: 100%|██████████| 185/185 [00:05<00:00, 36.85it/s]


Val_loss:  0.0010084956447040108


Train: 100%|██████████| 362/362 [00:11<00:00, 30.47it/s]


Train_loss:  0.0013225669763565517


Val: 100%|██████████| 185/185 [00:05<00:00, 36.16it/s]


Val_loss:  0.001010594008540785


Train: 100%|██████████| 362/362 [00:11<00:00, 31.54it/s]


Train_loss:  0.0013152836331228072


Val: 100%|██████████| 185/185 [00:05<00:00, 35.81it/s]


Val_loss:  0.001002529580334856


Train: 100%|██████████| 362/362 [00:11<00:00, 31.19it/s]


Train_loss:  0.0013172793387337613


Val: 100%|██████████| 185/185 [00:05<00:00, 35.92it/s]


Val_loss:  0.0009985371212777052


Train: 100%|██████████| 362/362 [00:11<00:00, 30.90it/s]


Train_loss:  0.0013085397067145337


Val: 100%|██████████| 185/185 [00:05<00:00, 36.79it/s]


Val_loss:  0.000993692604321483


Train: 100%|██████████| 362/362 [00:11<00:00, 31.09it/s]


Train_loss:  0.0013062725623796865


Val: 100%|██████████| 185/185 [00:05<00:00, 36.93it/s]


Val_loss:  0.0009913248151565926


Train: 100%|██████████| 362/362 [00:11<00:00, 31.24it/s]


Train_loss:  0.0013042468663166118


Val: 100%|██████████| 185/185 [00:05<00:00, 36.56it/s]


Val_loss:  0.0009786352206877357


Train: 100%|██████████| 362/362 [00:11<00:00, 31.25it/s]


Train_loss:  0.001297063104813476


Val: 100%|██████████| 185/185 [00:04<00:00, 37.10it/s]


Val_loss:  0.0009809479929154387


Train: 100%|██████████| 362/362 [00:11<00:00, 31.65it/s]


Train_loss:  0.0012918687760006657


Val: 100%|██████████| 185/185 [00:05<00:00, 36.91it/s]


Val_loss:  0.0009734474789552592


Train: 100%|██████████| 362/362 [00:11<00:00, 31.28it/s]


Train_loss:  0.0012905766308894226


Val: 100%|██████████| 185/185 [00:05<00:00, 36.93it/s]


Val_loss:  0.000971430674434413


Train: 100%|██████████| 362/362 [00:11<00:00, 30.92it/s]


Train_loss:  0.0012840004316248123


Val: 100%|██████████| 185/185 [00:05<00:00, 35.67it/s]


Val_loss:  0.0009675744854895448


Train: 100%|██████████| 362/362 [00:11<00:00, 31.22it/s]


Train_loss:  0.0012836176574034669


Val: 100%|██████████| 185/185 [00:05<00:00, 35.61it/s]


Val_loss:  0.0009591705187827953


Train: 100%|██████████| 362/362 [00:11<00:00, 31.20it/s]


Train_loss:  0.001277346065903381


Val: 100%|██████████| 185/185 [00:05<00:00, 36.97it/s]


Val_loss:  0.0009568495902459364


Train: 100%|██████████| 362/362 [00:11<00:00, 31.27it/s]


Train_loss:  0.0012780613592666321


Val: 100%|██████████| 185/185 [00:05<00:00, 36.44it/s]


Val_loss:  0.0009523983009637812


Train: 100%|██████████| 362/362 [00:11<00:00, 31.43it/s]


Train_loss:  0.0012722047469038414


Val: 100%|██████████| 185/185 [00:05<00:00, 35.99it/s]


Val_loss:  0.0009485135543965609


Train: 100%|██████████| 362/362 [00:11<00:00, 30.58it/s]


Train_loss:  0.001268811445446551


Val: 100%|██████████| 185/185 [00:05<00:00, 35.73it/s]


Val_loss:  0.0009440125472768134


Train: 100%|██████████| 362/362 [00:11<00:00, 30.24it/s]


Train_loss:  0.0012656111049579035


Val: 100%|██████████| 185/185 [00:05<00:00, 33.61it/s]


Val_loss:  0.0009422870727281112


Train: 100%|██████████| 362/362 [00:11<00:00, 30.82it/s]


Train_loss:  0.001263169429614852


Val: 100%|██████████| 185/185 [00:05<00:00, 35.96it/s]


Val_loss:  0.0009342053197198422


Train: 100%|██████████| 362/362 [00:11<00:00, 31.23it/s]


Train_loss:  0.0012614782266087254


Val: 100%|██████████| 185/185 [00:05<00:00, 36.92it/s]


Val_loss:  0.000932944450854651


Train: 100%|██████████| 362/362 [00:11<00:00, 30.75it/s]


Train_loss:  0.0012575312596694334


Val: 100%|██████████| 185/185 [00:05<00:00, 36.77it/s]


Val_loss:  0.0009291490067954402


Train: 100%|██████████| 362/362 [00:11<00:00, 31.35it/s]


Train_loss:  0.00125665355053118


Val: 100%|██████████| 185/185 [00:05<00:00, 36.98it/s]


Val_loss:  0.0009253832763312636


Train: 100%|██████████| 362/362 [00:11<00:00, 31.02it/s]


Train_loss:  0.0012541923064800257


Val: 100%|██████████| 185/185 [00:04<00:00, 37.02it/s]


Val_loss:  0.0009224024698817851


Train: 100%|██████████| 362/362 [00:11<00:00, 31.38it/s]


Train_loss:  0.001247637146872916


Val: 100%|██████████| 185/185 [00:04<00:00, 37.58it/s]


Val_loss:  0.0009263692282115084


Train: 100%|██████████| 362/362 [00:11<00:00, 31.49it/s]


Train_loss:  0.0012461535734875715


Val: 100%|██████████| 185/185 [00:05<00:00, 36.34it/s]


Val_loss:  0.0009154995280393474


Train: 100%|██████████| 362/362 [00:11<00:00, 31.42it/s]


Train_loss:  0.001247815017907453


Val: 100%|██████████| 185/185 [00:04<00:00, 37.00it/s]


Val_loss:  0.0009106494734894383


Train: 100%|██████████| 362/362 [00:11<00:00, 31.26it/s]


Train_loss:  0.0012401685772144664


Val: 100%|██████████| 185/185 [00:05<00:00, 36.89it/s]


Val_loss:  0.000905998021276114


Train: 100%|██████████| 362/362 [00:11<00:00, 31.26it/s]


Train_loss:  0.0012373803691547466


Val: 100%|██████████| 185/185 [00:04<00:00, 37.10it/s]


Val_loss:  0.0009046550988018312


Train: 100%|██████████| 362/362 [00:11<00:00, 31.41it/s]


Train_loss:  0.0012386350092370043


Val: 100%|██████████| 185/185 [00:05<00:00, 36.76it/s]


Val_loss:  0.0008996859787510255


Train: 100%|██████████| 362/362 [00:11<00:00, 31.23it/s]


Train_loss:  0.0012348193700363582


Val: 100%|██████████| 185/185 [00:05<00:00, 35.88it/s]


Val_loss:  0.0008946449385729392


Train: 100%|██████████| 362/362 [00:11<00:00, 31.18it/s]


Train_loss:  0.0012329249090942147


Val: 100%|██████████| 185/185 [00:04<00:00, 37.19it/s]


Val_loss:  0.0008943510496661672


Train: 100%|██████████| 362/362 [00:11<00:00, 31.09it/s]


Train_loss:  0.0012268016895649759


Val: 100%|██████████| 185/185 [00:04<00:00, 37.23it/s]


Val_loss:  0.0008943333189833808


Train: 100%|██████████| 362/362 [00:11<00:00, 31.20it/s]


Train_loss:  0.0012270782788628867


Val: 100%|██████████| 185/185 [00:04<00:00, 37.13it/s]


Val_loss:  0.0008889416651171003


Train: 100%|██████████| 362/362 [00:11<00:00, 31.36it/s]


Train_loss:  0.0012253983102569058


Val: 100%|██████████| 185/185 [00:04<00:00, 37.03it/s]


Val_loss:  0.0008903047733789159


Train: 100%|██████████| 362/362 [00:11<00:00, 31.68it/s]


Train_loss:  0.0012224457636398382


Val: 100%|██████████| 185/185 [00:04<00:00, 37.41it/s]


Val_loss:  0.0008917170532702192


In [45]:
single_labeled_data_lstm.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Comparisons/single_labeled_data_lstm/best-distill-bert.pt'))

distill_embedder = single_labeled_data_lstm.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer_triplets, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.5495348462426805 ms

Accuracy: 72.2
Positives Recall: 41.2
Positives Precision: 96.5
Positives F1: 57.7
Distance:  0.38
Max cluster size:  109
Median cluster size:  1.0
Avg cluster size: 1.85


In [46]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer_triplets, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.5597209411157841 ms

Accuracy: 71.5
Positives Recall: 39.6
Positives Precision: 96.1
Positives F1: 56.1
Distance:  0.38
Max cluster size:  79
Median cluster size:  1
Avg cluster size: 1.82


###Comparisons for the final estimation

In [46]:
class ClusteringTripletModelForLSTM(nn.Module):
    def __init__(self, vocab_size=119547,  
                 word_emb_dim=128, 
                 rnn_hidden_dim=128,
                 rnn_layers_count=2,
                 target_size=TARGET_SIZE, 
                 lr=1e-3, 
                 margin=0.5,
                 to_gru=False,
                 attentive_aggregation=False,
                 ):
        super().__init__()

        self.embedder = DistillEmbedder(vocab_size=vocab_size, 
                                        word_emb_dim=word_emb_dim, 
                                        rnn_hidden_dim=rnn_hidden_dim, 
                                        rnn_layers_count=rnn_layers_count,
                                        target_size=target_size,
                                        to_gru=to_gru,
                                        attentive_aggregation=attentive_aggregation,
                                        )


        self.triplet_loss = nn.TripletMarginWithDistanceLoss(
            margin=margin,
            distance_function=nn.PairwiseDistance(p=2)
        )

        self.lr = lr

    def forward(self, pivots, positives, negatives):
        pivot_embeddings = self.embedder(pivots["input_ids"].cuda(), pivots["attention_mask"])
        positive_embeddings = self.embedder(positives["input_ids"].cuda(), positives["attention_mask"])
        negative_embeddings = self.embedder(negatives["input_ids"].cuda(), negatives["attention_mask"])
        return pivot_embeddings, positive_embeddings, negative_embeddings

    def loss(self, embeds, bert_embeds):
        anchor, positive, negative = embeds
        triplet_loss = self.triplet_loss(anchor, positive, negative) 
        return triplet_loss


def epoch_train(model, data_loader, optimizer, scheduler):
    train_loss = 0
    model.cuda().train()
    for i, samples in enumerate(tqdm(data_loader, desc='Train')):
       model.zero_grad()
       
       bert_embeds = [samples[3].cuda(), samples[4].cuda(), samples[5].cuda()]
       embeds = model(samples[0], samples[1], samples[2])


       loss = model.loss(embeds, bert_embeds)
       loss.backward()
       optimizer.step()
       train_loss += float(loss)
    loss = train_loss / len(data_loader)
    scheduler.step(loss)
    print('Train_loss: ', loss)
    return loss

def train(model, train_data_loader,
          val_data_loader, num_train_epochs, 
          output_dir, optimizer=torch.optim.Adam, 
          patience=3, scheduler_patience=2, 
          ):
      optimizer = optimizer(model.parameters(), lr=model.lr)
      scheduler = ReduceLROnPlateau(optimizer, patience=scheduler_patience)
      writer = SummaryWriter(output_dir)
      best_val_loss = 100000
      best_count = 0
      for epoch in range(num_train_epochs):
          train_loss = epoch_train(model, train_data_loader, optimizer, scheduler)
          val_loss = epoch_val(model, val_data_loader)
          writer.add_scalar('Loss/train', train_loss, epoch)
          writer.add_scalar('Loss/val', val_loss, epoch)
          writer.add_scalar('Learning rate', scheduler._last_lr[0], epoch)
          best_count += 1
         
          if val_loss < best_val_loss:
              best_val_loss = val_loss
              best_count = 0
              torch.save(model.state_dict(), os.path.join(output_dir, f'best-distill-bert.pt'))
          
          if best_count == patience:
              torch.save(model.state_dict(), os.path.join(output_dir, 'last_epoch.pt'))
              writer.close()
              break

      writer.close()
      torch.save(model.state_dict(), os.path.join(output_dir, 'last_epoch.pt'))

In [54]:
BATCH_SIZE = 32

bidir_gru_not_distilled = ClusteringTripletModelForLSTM(attentive_aggregation=True, to_gru=True)


train(bidir_gru_not_distilled, train_loader_triplets, val_loader_triplets,  
      60, '/content/drive/MyDrive/NewsBert/Comparisons/bidir_gru_not_distilled_attn',
      )

Train: 100%|██████████| 471/471 [01:16<00:00,  6.15it/s]


Train_loss:  0.09177566635285973


Val: 100%|██████████| 118/118 [00:11<00:00, 10.44it/s]


Val_loss:  0.04990701967755617


Train: 100%|██████████| 471/471 [01:16<00:00,  6.16it/s]


Train_loss:  0.020041974672687788


Val: 100%|██████████| 118/118 [00:11<00:00, 10.45it/s]


Val_loss:  0.04150954580281751


Train: 100%|██████████| 471/471 [01:16<00:00,  6.14it/s]


Train_loss:  0.00928623782126767


Val: 100%|██████████| 118/118 [00:11<00:00, 10.61it/s]


Val_loss:  0.034617345369720866


Train: 100%|██████████| 471/471 [01:16<00:00,  6.17it/s]


Train_loss:  0.005873858319662685


Val: 100%|██████████| 118/118 [00:11<00:00, 10.43it/s]


Val_loss:  0.04280104734382387


Train: 100%|██████████| 471/471 [01:17<00:00,  6.11it/s]


Train_loss:  0.005387647389859404


Val: 100%|██████████| 118/118 [00:11<00:00, 10.46it/s]


Val_loss:  0.03871551689581346


Train: 100%|██████████| 471/471 [01:16<00:00,  6.13it/s]


Train_loss:  0.004518746428049294


Val: 100%|██████████| 118/118 [00:11<00:00, 10.32it/s]


Val_loss:  0.04083520155084335


In [33]:
bidir_gru_not_distilled.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Comparisons/bidir_gru_not_distilled_attn/best-distill-bert.pt'))

distill_embedder = bidir_gru_not_distilled.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer_triplets, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  1.5357810535278928 ms

Accuracy: 46.2
Positives Recall: 100.0
Positives Precision: 46.2
Positives F1: 63.2
Distance:  0.38
Max cluster size:  20059
Median cluster size:  3
Avg cluster size: 2868.57


In [34]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer_triplets, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  1.5504345593714144 ms

Accuracy: 45.9
Positives Recall: 99.7
Positives Precision: 45.9
Positives F1: 62.9
Distance:  0.38
Max cluster size:  19068
Median cluster size:  9
Avg cluster size: 3819.20


In [21]:
BATCH_SIZE = 32

bidir_gru_not_distilled_64 = ClusteringTripletModelForLSTM(attentive_aggregation=True, to_gru=True,
                                                           word_emb_dim=64)


train(bidir_gru_not_distilled_64, train_loader_triplets, val_loader_triplets,  
      60, '/content/drive/MyDrive/NewsBert/Comparisons/bidir_gru_not_distilled_attn/64',
      )

Train: 100%|██████████| 471/471 [01:17<00:00,  6.11it/s]


Train_loss:  0.10535618780683054


Val: 100%|██████████| 118/118 [00:11<00:00, 10.13it/s]


Val_loss:  0.06566119653377998


Train: 100%|██████████| 471/471 [01:16<00:00,  6.12it/s]


Train_loss:  0.03386003509828239


Val: 100%|██████████| 118/118 [00:11<00:00, 10.22it/s]


Val_loss:  0.039012689806394656


Train: 100%|██████████| 471/471 [01:16<00:00,  6.18it/s]


Train_loss:  0.01101036429911409


Val: 100%|██████████| 118/118 [00:11<00:00, 10.48it/s]


Val_loss:  0.036131832559229964


Train: 100%|██████████| 471/471 [01:17<00:00,  6.11it/s]


Train_loss:  0.0062985087956465985


Val: 100%|██████████| 118/118 [00:11<00:00, 10.37it/s]


Val_loss:  0.03879341899843539


Train: 100%|██████████| 471/471 [01:17<00:00,  6.12it/s]


Train_loss:  0.006505691574890396


Val: 100%|██████████| 118/118 [00:11<00:00, 10.33it/s]


Val_loss:  0.0329774092188326


Train: 100%|██████████| 471/471 [01:17<00:00,  6.06it/s]


Train_loss:  0.004480658506385333


Val: 100%|██████████| 118/118 [00:11<00:00, 10.25it/s]


Val_loss:  0.03546165546276054


Train: 100%|██████████| 471/471 [01:17<00:00,  6.05it/s]


Train_loss:  0.006383689345827528


Val: 100%|██████████| 118/118 [00:11<00:00, 10.48it/s]


Val_loss:  0.04008224215042793


Train: 100%|██████████| 471/471 [01:17<00:00,  6.10it/s]


Train_loss:  0.004059622020731559


Val: 100%|██████████| 118/118 [00:11<00:00, 10.28it/s]


Val_loss:  0.044734745081198415


In [47]:
bidir_gru_not_distilled_64.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Comparisons/bidir_gru_not_distilled_attn/64/best-distill-bert.pt'))

distill_embedder = bidir_gru_not_distilled_64.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer_triplets, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  1.5908298339026858 ms

Accuracy: 54.2
Positives Recall: 62.4
Positives Precision: 50.3
Positives F1: 55.7
Distance:  0.38
Max cluster size:  13618
Median cluster size:  14
Avg cluster size: 542.70


In [48]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer_triplets, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  1.536950621936528 ms

Accuracy: 57.7
Positives Recall: 71.9
Positives Precision: 53.0
Positives F1: 61.0
Distance:  0.38
Max cluster size:  14934
Median cluster size:  19
Avg cluster size: 578.67


In [None]:
EPOCHS = 60
ACCUMULATE_GRAD_BATCHES = 8
LOG_EVERY_N_STEPS = 10

In [None]:
BATCH_SIZE = 12

tiny_train_loader, tiny_val_loader, tiny_tokenizer = get_loaders(
                                       train_records, val_records, 
                                       "cointegrated/rubert-tiny", MAX_TOKENS, 
                                       BATCH_SIZE, NUM_WORKERS,
                                       )


model = ClusteringTripletModel("cointegrated/rubert-tiny", lr=0.0001)
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.0001,
    patience=3,
    verbose=True,
    mode="min" 
)

checkpoint = ModelCheckpoint(
    monitor='val_loss',
    dirpath='/content/drive/MyDrive/NewsBert/tiny',
    filename='clustering_news_bert_rubert-tiny'
)


trainer = Trainer(
    gpus=1,
    accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES,
    max_epochs=EPOCHS,
    callbacks=[early_stop_callback, checkpoint],
    log_every_n_steps=LOG_EVERY_N_STEPS
)


trainer.fit(model, tiny_train_loader, tiny_val_loader)

In [17]:
torch.save(model.state_dict(), os.path.join('/content/drive/MyDrive/NewsBert/tiny/', f'best-distill-bert.pt'))

In [27]:
model = ClusteringTripletModel.load_from_checkpoint(
    model_path = "cointegrated/rubert-tiny",
    checkpoint_path = '/content/drive/MyDrive/NewsBert/tiny/clustering_news_bert_rubert-tiny.ckpt',
    )

embedder = model.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, embedder, tiny_tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Some weights of the model checkpoint at cointegrated/rubert-tiny were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Mean inference time per embed:  0.6591846915117773 ms

Accuracy: 86.2
Positives Recall: 77.2
Positives Precision: 91.6
Positives F1: 83.8
Distance:  0.38
Max cluster size:  153
Median cluster size:  2.0
Avg cluster size: 2.91


In [28]:
private_distill_embeddings = records_to_embeds(private_set, embedder, tiny_tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.6579545891879242 ms

Accuracy: 85.4
Positives Recall: 75.9
Positives Precision: 90.8
Positives F1: 82.7
Distance:  0.38
Max cluster size:  116
Median cluster size:  2
Avg cluster size: 2.81


In [20]:
BATCH_SIZE = 12

tiny2_train_loader, tiny2_val_loader, tiny2_tokenizer = get_loaders(
                                       train_records, val_records, 
                                       "cointegrated/rubert-tiny2", MAX_TOKENS, 
                                       BATCH_SIZE, NUM_WORKERS,
                                       )

early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.0001,
    patience=3,
    verbose=True,
    mode="min" 
)

model = ClusteringTripletModel("cointegrated/rubert-tiny2", lr=0.0001)

checkpoint = ModelCheckpoint(
    monitor='val_loss',
    dirpath='/content/drive/MyDrive/NewsBert/tiny2',
    filename='clustering_news_bert_rubert-tiny2'
)


trainer = Trainer(
    gpus=1,
    accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES,
    max_epochs=EPOCHS,
    callbacks=[early_stop_callback, checkpoint],
    log_every_n_steps=LOG_EVERY_N_STEPS
)


trainer.fit(model, tiny2_train_loader, tiny2_val_loader)

Downloading:   0%|          | 0.00/401 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.03M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.66M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/715 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112M [00:00<?, ?B/s]

Some weights of the model checkpoint at cointegrated/rubert-tiny2 were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU av

Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 0.016


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.006 >= min_delta = 0.0001. New best score: 0.010


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0.0001. New best score: 0.008


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0.0001. New best score: 0.007


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0.0001. New best score: 0.006


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0.0001. New best score: 0.006


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0.0001. New best score: 0.006


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0.0001. New best score: 0.006


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0.0001. New best score: 0.005


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 3 records. Best score: 0.005. Signaling Trainer to stop.


In [21]:
torch.save(model.state_dict(), os.path.join('/content/drive/MyDrive/NewsBert/tiny2/', f'best-distill-bert.pt'))

In [29]:
model = ClusteringTripletModel.load_from_checkpoint(
    model_path = "cointegrated/rubert-tiny2",
    checkpoint_path = '/content/drive/MyDrive/NewsBert/tiny2/clustering_news_bert_rubert-tiny2.ckpt',
    )

embedder = model.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, embedder, tiny2_tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Some weights of the model checkpoint at cointegrated/rubert-tiny2 were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Mean inference time per embed:  0.6598458836515586 ms

Accuracy: 91.4
Positives Recall: 86.4
Positives Precision: 94.5
Positives F1: 90.3
Distance:  0.38
Max cluster size:  175
Median cluster size:  2
Avg cluster size: 3.02


In [30]:
private_distill_embeddings = records_to_embeds(private_set, embedder, tiny2_tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.658596769381029 ms

Accuracy: 92.2
Positives Recall: 87.3
Positives Precision: 95.3
Positives F1: 91.1
Distance:  0.38
Max cluster size:  148
Median cluster size:  1
Avg cluster size: 2.92


In [10]:
EPOCHS = 60
BATCH_SIZE = 12
ACCUMULATE_GRAD_BATCHES = 8
LOG_EVERY_N_STEPS = 10


tiny_cased_train_loader, tiny_cased_val_loader, tiny_cased_tokenizer = get_loaders(
                                       train_records, val_records, 
                                       "DeepPavlov/distilrubert-tiny-cased-conversational", 
                                       MAX_TOKENS, 
                                       BATCH_SIZE, NUM_WORKERS,
                                       )


model = ClusteringTripletModel("DeepPavlov/distilrubert-tiny-cased-conversational", lr=0.0001)

early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.0001,
    patience=3,
    verbose=True,
    mode="min" 
)

checkpoint = ModelCheckpoint(
    monitor='val_loss',
    dirpath='/content/drive/MyDrive/NewsBert/tiny_cased',
    filename='clustering_news_bert_rubert-tiny_cased'
)


trainer = Trainer(
    gpus=1,
    accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES,
    max_epochs=EPOCHS,
    callbacks=[early_stop_callback, checkpoint],
    log_every_n_steps=LOG_EVERY_N_STEPS
)


trainer.fit(model, tiny_cased_train_loader, tiny_cased_val_loader)

Some weights of the model checkpoint at DeepPavlov/distilrubert-tiny-cased-conversational were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type                          | Para

Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 0.013


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.005 >= min_delta = 0.0001. New best score: 0.008


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0.0001. New best score: 0.007


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0.0001. New best score: 0.007


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0.0001. New best score: 0.006


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 3 records. Best score: 0.006. Signaling Trainer to stop.


In [11]:
torch.save(model.state_dict(), os.path.join('/content/drive/MyDrive/NewsBert/tiny_cased/', f'best-distill-bert.pt'))

In [31]:
model = ClusteringTripletModel.load_from_checkpoint(
    model_path = "DeepPavlov/distilrubert-tiny-cased-conversational",
    checkpoint_path = '/content/drive/MyDrive/NewsBert/tiny_cased/clustering_news_bert_rubert-tiny_cased.ckpt',
    )

embedder = model.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, embedder, tiny_cased_tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Some weights of the model checkpoint at DeepPavlov/distilrubert-tiny-cased-conversational were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Mean inference time per embed:  1.3636089203367194 ms

Accuracy: 91.5
Positives Recall: 88.9
Positives Precision: 92.5
Positives F1: 90.7
Distance:  0.38
Max cluster size:  234
Median cluster size:  2.0
Avg cluster size: 3.29


In [32]:
private_distill_embeddings = records_to_embeds(private_set, embedder, tiny_cased_tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  1.3630668045238916 ms

Accuracy: 91.9
Positives Recall: 89.4
Positives Precision: 92.7
Positives F1: 91.0
Distance:  0.38
Max cluster size:  145
Median cluster size:  2
Avg cluster size: 3.20


##--Other embedding architectures--

###Data loading and preprocessing

In [35]:
%cd /content/drive/MyDrive/NewsBert
from initial_finetuned_model import ClusteringTripletModel
from loading_and_evaluation import gen_batch, records_to_embeds, get_quality, NewsDataset, get_loaders

with open('single_full_train_embeddings_bert', 'rb') as pickle_file:
    single_full_train_embeddings_bert = pickle.load(pickle_file)

with open('single_full_val_embeddings_bert', 'rb') as pickle_file:
    single_full_val_embeddings_bert = pickle.load(pickle_file)


with open('full_train_records', 'rb') as pickle_file:
    full_train_records = pickle.load(pickle_file)

with open('full_val_records', 'rb') as pickle_file:
    full_val_records = pickle.load(pickle_file)


with open('public_set', 'rb') as pickle_file:
    public_set = pickle.load(pickle_file)

with open('private_set', 'rb') as pickle_file:
    private_set = pickle.load(pickle_file)


initial_model = ClusteringTripletModel.load_from_checkpoint(
    model_path = INITIAL_MODEL,
    checkpoint_path = '/content/drive/MyDrive/NewsBert/best_clustering_news_bert-val_loss=0.0008.ckpt',
    num_training_steps = None
    )

embedder = initial_model.embedder.cuda()

VOCAB_SIZE = embedder.model.embeddings.word_embeddings.num_embeddings
BATCH_SIZE = 128


public_markup = read_markup_tsv("ru_clustering_0527_urls_final.tsv")
private_markup = read_markup_tsv("ru_clustering_0529_urls_final_v2.tsv")

/content/drive/MyDrive/NewsBert


Some weights of the model checkpoint at IlyaGusev/news_tg_rubert were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at IlyaGusev/news_tg_rubert and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.

In [36]:
train_loader, val_loader, tokenizer = get_loaders(full_train_records,
                                                  full_val_records, 
                                                  INITIAL_MODEL, MAX_TOKENS, 
                                                  BATCH_SIZE, 
                                                  single_full_train_embeddings_bert,
                                                  single_full_val_embeddings_bert,
                                                  )                                       

Downloading:   0%|          | 0.00/1.57M [00:00<?, ?B/s]

####Transformer

In [18]:
class TransformerEmbedder(nn.Module):
  def __init__(self, vocab_size=VOCAB_SIZE, 
               target_size=TARGET_SIZE, 
               pretrained=False,  
               pretrained_embs=None,
               freeze_pretrained=False,
               word_emb_dim=56,
               dim_feedforward=1024,
               n_heads=4, 
               n_layers=6,
               relu=False,
               attentive_aggregation=False,
               cls_aggregation=False,
               ):

        super().__init__()
        
        self.target_size=target_size
        self.word_emb_dim=word_emb_dim
        self.relu=relu
        self.attentive_aggregation=attentive_aggregation
        self.cls_aggregation=cls_aggregation
         
        if pretrained:
          self.embedding = nn.Embedding.from_pretrained(pretrained_embs, 
                                                        freeze=freeze_pretrained) 
          word_emb_dim = pretrained_embs.shape[1]
        else:
          self.embedding = nn.Embedding(vocab_size, word_emb_dim)

        self.encoder_layer = nn.TransformerEncoderLayer(word_emb_dim, n_heads,
                                                        dim_feedforward=dim_feedforward, 
                                                        batch_first=True)
        
        self.encoder = nn.TransformerEncoder(self.encoder_layer, n_layers)
        
        if word_emb_dim != target_size:
          self.mapping = nn.Linear(word_emb_dim, target_size)
        
        if cls_aggregation:
          self.cls_pooling_linear = nn.Linear(target_size, target_size)

        if attentive_aggregation:  
          self.softmax = nn.Softmax(dim=1)     
          self.attn = nn.Sequential(
              nn.Linear(target_size, target_size//2),
              nn.ReLU(),
              nn.Linear(target_size//2, 1)
            )
        

        
  def aggregate(self, transformer_output, mask):
        expanded_mask = mask.cuda().unsqueeze(-1).expand(transformer_output.size()).float()
        sum_embeddings = torch.sum(transformer_output * expanded_mask, 1).cuda()
        sum_mask = torch.clamp(expanded_mask.sum(1), min=1e-9)
        return sum_embeddings / sum_mask   

  def cls_aggregate(self, transformer_output):
        cls_emb = nn.functional.relu(self.mapping(transformer_output[:,0,:].squeeze(1)))
        return self.cls_pooling_linear(cls_emb)

  def attentive_aggregate(self, transformer_output, mask):
        weights = self.softmax(self.attn(transformer_output).squeeze(-1)) * mask
        embeddings = weights.unsqueeze(1).bmm(transformer_output).squeeze(1)
        return embeddings 

  def forward(self, x, mask):
        embs = self.embedding(x)
        transformer_output = self.encoder(embs, src_key_padding_mask=mask)
        if self.cls_aggregation:
          out = self.cls_aggregate(transformer_output)
        else:
          if self.word_emb_dim != self.target_size:
            if self.relu:
              transformer_output = self.mapping(nn.functional.relu(transformer_output))
            else:
              transformer_output = self.mapping(transformer_output)
          if self.attentive_aggregation: 
            out = self.attentive_aggregate(transformer_output, mask)
          else:
            out = self.aggregate(transformer_output, mask)
        return out

In [19]:
class TransformerClusteringDistillModel(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, target_size=TARGET_SIZE, 
                 pretrained=False, pretrained_embs=None, 
                 freeze_pretrained=False, word_emb_dim=128, 
                 evaluate_similarity=False, dim_feedforward=1024,
                 lr=1e-3, n_heads=4, n_layers=2, relu=False,
                 attentive_aggregation=False, cls_aggregation=False,
                 ):
      
        super().__init__()

        self.embedder = TransformerEmbedder(vocab_size=vocab_size, 
                                      target_size=target_size, 
                                      pretrained=pretrained,  
                                      pretrained_embs=pretrained_embs,
                                      freeze_pretrained=freeze_pretrained,
                                      word_emb_dim=word_emb_dim,
                                      n_heads=n_heads, 
                                      n_layers=n_layers,
                                      dim_feedforward=dim_feedforward, 
                                      relu=relu,
                                      attentive_aggregation=attentive_aggregation, 
                                      cls_aggregation=cls_aggregation,
                                      )

        self.evaluate_similarity = evaluate_similarity
        if not evaluate_similarity:
          self.mse = torch.nn.MSELoss()
        else:
          self.cosine_similarity = torch.nn.functional.cosine_similarity
        
        self.lr = lr

    def forward(self, news):
        embeddings = self.embedder(news["input_ids"].cuda(), 
                                   news["attention_mask"].cuda())
        return embeddings

    def loss(self, embeds, bert_embeds):
        if self.evaluate_similarity is True:
          similarity = self.cosine_similarity(embeds, bert_embeds)
          loss = torch.mean(torch.ones(len(similarity)).cuda() - similarity)
        else:
          loss = self.mse(embeds.float(), bert_embeds.float())
        return loss

#####1 layer + 4 heads

In [None]:
transformer_distill_head4_layer1 = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=1,
                                                       )
train(transformer_distill_head4_layer1, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer1',
      )

In [20]:
transformer_distill_head4_layer1 = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=1,
                                                       )

transformer_distill_head4_layer1.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer1/best-distill-bert.pt'))

distill_embedder = transformer_distill_head4_layer1.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.13554358551820436 ms

Accuracy: 91.5
Positives Recall: 91.4
Positives Precision: 90.3
Positives F1: 90.9
Distance:  0.38
Max cluster size:  322
Median cluster size:  2
Avg cluster size: 4.77


In [21]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.13365249270842472 ms

Accuracy: 91.1
Positives Recall: 91.0
Positives Precision: 89.7
Positives F1: 90.4
Distance:  0.38
Max cluster size:  172
Median cluster size:  2
Avg cluster size: 4.58


#####1 layer + 4 heads + relu

In [None]:
transformer_distill_head4_layer1_relu = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=1,
                                                       relu=True,
                                                       )
train(transformer_distill_head4_layer1_relu, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer1_relu',
      )

In [48]:
transformer_distill_head4_layer1_relu = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=1,
                                                       relu=True,
                                                       )


transformer_distill_head4_layer1_relu.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer1_relu/best-distill-bert.pt'))

distill_embedder = transformer_distill_head4_layer1_relu.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.13353054191605504 ms

Accuracy: 91.5
Positives Recall: 92.0
Positives Precision: 89.8
Positives F1: 90.9
Distance:  0.38
Max cluster size:  310
Median cluster size:  2.0
Avg cluster size: 4.87


In [50]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.13304639458781164 ms

Accuracy: 91.4
Positives Recall: 91.4
Positives Precision: 90.0
Positives F1: 90.7
Distance:  0.38
Max cluster size:  155
Median cluster size:  2.0
Avg cluster size: 4.64


#####1 layer + 4 heads / *attentive aggregation*

In [26]:
transformer_distill_head4_layer1_attn = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=1,
                                                       attentive_aggregation=True)

train(transformer_distill_head4_layer1_attn, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer1_attn',
      )

Train: 100%|██████████| 4037/4037 [05:41<00:00, 11.82it/s]


Train_loss:  0.4259607350616584


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.73it/s]


Val_loss:  0.3073411729030545


Train: 100%|██████████| 4037/4037 [05:40<00:00, 11.84it/s]


Train_loss:  0.28058386665258506


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.77it/s]


Val_loss:  0.2564051361360311


Train: 100%|██████████| 4037/4037 [05:39<00:00, 11.88it/s]


Train_loss:  0.24134234211862787


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.75it/s]


Val_loss:  0.23060317959694876


Train: 100%|██████████| 4037/4037 [05:39<00:00, 11.88it/s]


Train_loss:  0.21861716553579966


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.76it/s]


Val_loss:  0.21500926232816223


Train: 100%|██████████| 4037/4037 [05:40<00:00, 11.84it/s]


Train_loss:  0.20314543456764195


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.75it/s]


Val_loss:  0.20362110680614587


Train: 100%|██████████| 4037/4037 [05:39<00:00, 11.89it/s]


Train_loss:  0.1916832339682935


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.83it/s]


Val_loss:  0.19587612678659935


Train: 100%|██████████| 4037/4037 [05:37<00:00, 11.97it/s]


Train_loss:  0.18275625199689952


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.89it/s]


Val_loss:  0.18962846445540513


Train: 100%|██████████| 4037/4037 [05:36<00:00, 11.99it/s]


Train_loss:  0.17544418434172998


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.89it/s]


Val_loss:  0.18520347217801336


Train: 100%|██████████| 4037/4037 [05:37<00:00, 11.97it/s]


Train_loss:  0.16938922520883604


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.75it/s]


Val_loss:  0.18117140864115028


Train: 100%|██████████| 4037/4037 [05:41<00:00, 11.81it/s]


Train_loss:  0.16412780940155053


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.78it/s]


Val_loss:  0.17690818227215482


Train: 100%|██████████| 4037/4037 [05:39<00:00, 11.87it/s]


Train_loss:  0.15965370124445247


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.79it/s]


Val_loss:  0.17431106884771327


Train: 100%|██████████| 4037/4037 [05:42<00:00, 11.78it/s]


Train_loss:  0.15565026496071213


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.62it/s]


Val_loss:  0.17244051855013945


Train: 100%|██████████| 4037/4037 [05:42<00:00, 11.77it/s]


Train_loss:  0.15207797375428092


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.64it/s]


Val_loss:  0.16995766960691042


Train: 100%|██████████| 4037/4037 [05:36<00:00, 11.98it/s]


Train_loss:  0.14887087742915178


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.80it/s]


Val_loss:  0.1679133106622528


Train: 100%|██████████| 4037/4037 [05:36<00:00, 11.99it/s]


Train_loss:  0.1459881292875671


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.70it/s]


Val_loss:  0.16610827169054893


Train: 100%|██████████| 4037/4037 [05:48<00:00, 11.60it/s]


Train_loss:  0.14337563407703205


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.21it/s]


Val_loss:  0.16447857782295905


Train: 100%|██████████| 4037/4037 [05:44<00:00, 11.74it/s]


Train_loss:  0.14097710774141467


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.42it/s]


Val_loss:  0.16382878049905464


Train: 100%|██████████| 4037/4037 [05:50<00:00, 11.53it/s]


Train_loss:  0.13877443665628672


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.15it/s]


Val_loss:  0.16221815170207174


Train: 100%|██████████| 4037/4037 [05:46<00:00, 11.66it/s]


Train_loss:  0.13670047763668472


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.81it/s]


Val_loss:  0.16184943478936273


Train: 100%|██████████| 4037/4037 [05:37<00:00, 11.97it/s]


Train_loss:  0.1348120746209958


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.71it/s]


Val_loss:  0.16046146499656874


In [53]:
transformer_distill_head4_layer1_attn = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=1,
                                                       attentive_aggregation=True)

transformer_distill_head4_layer1_attn.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer1_attn/best-distill-bert.pt'))

distill_embedder = transformer_distill_head4_layer1_attn.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.13569880810035653 ms

Accuracy: 91.8
Positives Recall: 91.8
Positives Precision: 90.6
Positives F1: 91.2
Distance:  0.38
Max cluster size:  306
Median cluster size:  2
Avg cluster size: 4.78


In [54]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.1359656103982254 ms

Accuracy: 91.0
Positives Recall: 91.2
Positives Precision: 89.3
Positives F1: 90.3
Distance:  0.38
Max cluster size:  199
Median cluster size:  2
Avg cluster size: 4.58


#####1 layer + 4 heads / *cls aggregation*

In [23]:
transformer_distill_head4_layer1_cls = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=1,
                                                       cls_aggregation=True)

train(transformer_distill_head4_layer1_cls, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer1_cls',
      )

Train: 100%|██████████| 4037/4037 [05:42<00:00, 11.79it/s]


Train_loss:  0.491431591489802


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.70it/s]


Val_loss:  0.35244025554346686


Train: 100%|██████████| 4037/4037 [05:38<00:00, 11.92it/s]


Train_loss:  0.3349857050978208


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.84it/s]


Val_loss:  0.29468605871426395


Train: 100%|██████████| 4037/4037 [05:39<00:00, 11.89it/s]


Train_loss:  0.2938417709894339


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.75it/s]


Val_loss:  0.2665222434877744


Train: 100%|██████████| 4037/4037 [05:37<00:00, 11.97it/s]


Train_loss:  0.26955799646018824


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.70it/s]


Val_loss:  0.2476458974322711


Train: 100%|██████████| 4037/4037 [05:36<00:00, 12.00it/s]


Train_loss:  0.2526247099851296


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.63it/s]


Val_loss:  0.23361831066737831


Train: 100%|██████████| 4037/4037 [05:40<00:00, 11.85it/s]


Train_loss:  0.23987056068719143


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.66it/s]


Val_loss:  0.22397316066408762


Train: 100%|██████████| 4037/4037 [05:39<00:00, 11.90it/s]


Train_loss:  0.22987425666323638


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.87it/s]


Val_loss:  0.21562161776814054


Train: 100%|██████████| 4037/4037 [05:40<00:00, 11.84it/s]


Train_loss:  0.22168511361819568


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.72it/s]


Val_loss:  0.20945762158829104


Train: 100%|██████████| 4037/4037 [05:41<00:00, 11.82it/s]


Train_loss:  0.21485397449965563


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.73it/s]


Val_loss:  0.20384444651866196


Train: 100%|██████████| 4037/4037 [05:40<00:00, 11.87it/s]


Train_loss:  0.20898931706126395


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.66it/s]


Val_loss:  0.1994924415198577


Train: 100%|██████████| 4037/4037 [05:39<00:00, 11.90it/s]


Train_loss:  0.20407689525603206


Val: 100%|██████████| 1010/1010 [01:17<00:00, 12.97it/s]


Val_loss:  0.19537586595128753


Train: 100%|██████████| 4037/4037 [05:34<00:00, 12.06it/s]


Train_loss:  0.1995620015825043


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.93it/s]


Val_loss:  0.19192220787743383


Train: 100%|██████████| 4037/4037 [05:34<00:00, 12.08it/s]


Train_loss:  0.1957193603831234


Val: 100%|██████████| 1010/1010 [01:17<00:00, 12.95it/s]


Val_loss:  0.18907096185172734


Train: 100%|██████████| 4037/4037 [05:33<00:00, 12.11it/s]


Train_loss:  0.19218730623138397


Val: 100%|██████████| 1010/1010 [01:17<00:00, 12.96it/s]


Val_loss:  0.18699234709887927


Train: 100%|██████████| 4037/4037 [05:41<00:00, 11.82it/s]


Train_loss:  0.18912893402969325


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.74it/s]


Val_loss:  0.18493519731495714


Train: 100%|██████████| 4037/4037 [05:40<00:00, 11.84it/s]


Train_loss:  0.18625312190875998


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.86it/s]


Val_loss:  0.1827193305520791


Train: 100%|██████████| 4037/4037 [05:33<00:00, 12.12it/s]


Train_loss:  0.18384044724969037


Val: 100%|██████████| 1010/1010 [01:17<00:00, 13.09it/s]


Val_loss:  0.1808362119163064


Train: 100%|██████████| 4037/4037 [05:29<00:00, 12.24it/s]


Train_loss:  0.18137261504333493


Val: 100%|██████████| 1010/1010 [01:17<00:00, 13.01it/s]


Val_loss:  0.17926175763936908


Train: 100%|██████████| 4037/4037 [05:32<00:00, 12.14it/s]


Train_loss:  0.17922938824924187


Val: 100%|██████████| 1010/1010 [01:17<00:00, 12.97it/s]


Val_loss:  0.17801908179671097


Train: 100%|██████████| 4037/4037 [05:33<00:00, 12.09it/s]


Train_loss:  0.17717134800717155


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.88it/s]


Val_loss:  0.17653300541161046


In [57]:
transformer_distill_head4_layer1_cls = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=1,
                                                       cls_aggregation=True)

transformer_distill_head4_layer1_cls.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer1_cls/best-distill-bert.pt'))

distill_embedder = transformer_distill_head4_layer1_cls.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.12856528292376682 ms

Accuracy: 90.2
Positives Recall: 90.5
Positives Precision: 88.6
Positives F1: 89.5
Distance:  0.38
Max cluster size:  361
Median cluster size:  2
Avg cluster size: 6.11


In [58]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.12428928877039339 ms

Accuracy: 90.1
Positives Recall: 90.5
Positives Precision: 88.2
Positives F1: 89.3
Distance:  0.38
Max cluster size:  190
Median cluster size:  2.0
Avg cluster size: 5.72


In [None]:
transformer_distill_head4_layer1_cls = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=1,
                                                       cls_aggregation=True)

transformer_distill_head4_layer1_cls.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer1_cls/best-distill-bert.pt'))

#####1 layer + 4 heads + 64 emb dim

In [29]:
transformer_distill_head4_layer1_64 = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=1,
                                                       word_emb_dim=64
                                                       )
train(transformer_distill_head4_layer1_64, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer1_wed64',
      )

Train: 100%|██████████| 4037/4037 [05:29<00:00, 12.24it/s]


Train_loss:  0.5188354424144475


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.90it/s]


Val_loss:  0.39015187522214795


Train: 100%|██████████| 4037/4037 [05:29<00:00, 12.26it/s]


Train_loss:  0.3507154929914381


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.93it/s]


Val_loss:  0.32260611202815886


Train: 100%|██████████| 4037/4037 [05:29<00:00, 12.26it/s]


Train_loss:  0.30437562159653425


Val: 100%|██████████| 1010/1010 [01:17<00:00, 12.96it/s]


Val_loss:  0.2929503893027192


Train: 100%|██████████| 4037/4037 [05:28<00:00, 12.29it/s]


Train_loss:  0.2790954299309424


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.90it/s]


Val_loss:  0.27486100027192306


Train: 100%|██████████| 4037/4037 [05:28<00:00, 12.28it/s]


Train_loss:  0.26245682785564856


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.86it/s]


Val_loss:  0.2625946720122975


Train: 100%|██████████| 4037/4037 [05:30<00:00, 12.21it/s]


Train_loss:  0.2503752667487758


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.84it/s]


Val_loss:  0.2540086392897995


Train: 100%|██████████| 4037/4037 [05:30<00:00, 12.20it/s]


Train_loss:  0.24111266725760347


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.88it/s]


Val_loss:  0.24677531382050424


Train: 100%|██████████| 4037/4037 [05:30<00:00, 12.22it/s]


Train_loss:  0.23367913458990142


Val: 100%|██████████| 1010/1010 [01:17<00:00, 12.99it/s]


Val_loss:  0.24122218095627526


Train: 100%|██████████| 4037/4037 [05:30<00:00, 12.23it/s]


Train_loss:  0.22754259058436946


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.91it/s]


Val_loss:  0.2369016953850786


Train: 100%|██████████| 4037/4037 [05:30<00:00, 12.20it/s]


Train_loss:  0.22233392192757562


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.87it/s]


Val_loss:  0.23382991828097952


Train: 100%|██████████| 4037/4037 [05:31<00:00, 12.16it/s]


Train_loss:  0.21785123197521966


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.78it/s]


Val_loss:  0.22998901695920282


Train: 100%|██████████| 4037/4037 [05:33<00:00, 12.11it/s]


Train_loss:  0.21401027782320262


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.76it/s]


Val_loss:  0.226732873709015


Train: 100%|██████████| 4037/4037 [05:32<00:00, 12.14it/s]


Train_loss:  0.21058696413948394


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.83it/s]


Val_loss:  0.22480940858116427


Train: 100%|██████████| 4037/4037 [05:31<00:00, 12.18it/s]


Train_loss:  0.20758638579708127


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.84it/s]


Val_loss:  0.2222707214160345


Train: 100%|██████████| 4037/4037 [05:32<00:00, 12.14it/s]


Train_loss:  0.20493517512424178


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.86it/s]


Val_loss:  0.22025261556146444


Train: 100%|██████████| 4037/4037 [05:31<00:00, 12.18it/s]


Train_loss:  0.20253218243263083


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.78it/s]


Val_loss:  0.21916329630784648


Train: 100%|██████████| 4037/4037 [05:31<00:00, 12.17it/s]


Train_loss:  0.20034350188894795


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.85it/s]


Val_loss:  0.2177957042561901


Train: 100%|██████████| 4037/4037 [05:31<00:00, 12.18it/s]


Train_loss:  0.19839127557235303


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.87it/s]


Val_loss:  0.21616606569610108


Train: 100%|██████████| 4037/4037 [05:32<00:00, 12.15it/s]


Train_loss:  0.19654664524095067


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.78it/s]


Val_loss:  0.21495729795831836


Train: 100%|██████████| 4037/4037 [05:31<00:00, 12.16it/s]


Train_loss:  0.19491826730805864


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.69it/s]


Val_loss:  0.21414851585557726


In [39]:
transformer_distill_head4_layer1_64 = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=1,
                                                       word_emb_dim=64)

transformer_distill_head4_layer1_64.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer1_wed64/best-distill-bert.pt'))

distill_embedder = transformer_distill_head4_layer1_64.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.20702503746842957 ms

Accuracy: 90.8
Positives Recall: 90.2
Positives Precision: 89.8
Positives F1: 90.0
Distance:  0.38
Max cluster size:  365
Median cluster size:  2
Avg cluster size: 5.30


In [40]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.19745498281994472 ms

Accuracy: 91.3
Positives Recall: 91.7
Positives Precision: 89.6
Positives F1: 90.6
Distance:  0.38
Max cluster size:  181
Median cluster size:  2
Avg cluster size: 5.08


#####2 layers + 8 heads

In [32]:
transformer_distill_head8_layer2 = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=8,
                                                       n_layers=2,
                                                       )
train(transformer_distill_head8_layer2, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head8_layer2',
      )

Train: 100%|██████████| 4037/4037 [08:31<00:00,  7.89it/s]


Train_loss:  0.42114454950168706


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.14it/s]


Val_loss:  0.2922313662327121


Train: 100%|██████████| 4037/4037 [08:34<00:00,  7.84it/s]


Train_loss:  0.2607660341067526


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.09it/s]


Val_loss:  0.23933553595866874


Train: 100%|██████████| 4037/4037 [08:36<00:00,  7.81it/s]


Train_loss:  0.2215384408051676


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.15it/s]


Val_loss:  0.21445117847448775


Train: 100%|██████████| 4037/4037 [08:34<00:00,  7.84it/s]


Train_loss:  0.20020890434499997


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.22it/s]


Val_loss:  0.19992300843315197


Train: 100%|██████████| 4037/4037 [08:34<00:00,  7.84it/s]


Train_loss:  0.18616550231096918


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.20it/s]


Val_loss:  0.19056416111765842


Train: 100%|██████████| 4037/4037 [08:36<00:00,  7.81it/s]


Train_loss:  0.17601345238896007


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.23it/s]


Val_loss:  0.1835209539986089


Train: 100%|██████████| 4037/4037 [08:34<00:00,  7.84it/s]


Train_loss:  0.16812897141152994


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.21it/s]


Val_loss:  0.1777250925111573


Train: 100%|██████████| 4037/4037 [08:34<00:00,  7.84it/s]


Train_loss:  0.16183223249750814


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.18it/s]


Val_loss:  0.17365015491920152


Train: 100%|██████████| 4037/4037 [08:36<00:00,  7.81it/s]


Train_loss:  0.15653891697082403


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.11it/s]


Val_loss:  0.17012697410248376


Train: 100%|██████████| 4037/4037 [08:33<00:00,  7.86it/s]


Train_loss:  0.15209558672134454


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.24it/s]


Val_loss:  0.16668555149134512


Train: 100%|██████████| 4037/4037 [08:32<00:00,  7.87it/s]


Train_loss:  0.14826475330554992


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.17it/s]


Val_loss:  0.16475818878344148


Train: 100%|██████████| 4037/4037 [08:36<00:00,  7.82it/s]


Train_loss:  0.14483775922987605


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.90it/s]


Val_loss:  0.16244210306648718


Train: 100%|██████████| 4037/4037 [08:35<00:00,  7.83it/s]


Train_loss:  0.14180645839081968


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.94it/s]


Val_loss:  0.16068589483543513


Train: 100%|██████████| 4037/4037 [08:35<00:00,  7.83it/s]


Train_loss:  0.13908466979593118


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.19it/s]


Val_loss:  0.1594172001944543


Train: 100%|██████████| 4037/4037 [08:35<00:00,  7.83it/s]


Train_loss:  0.13665720881574944


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.91it/s]


Val_loss:  0.15771765876461555


Train: 100%|██████████| 4037/4037 [08:36<00:00,  7.81it/s]


Train_loss:  0.13433298159297682


Val: 100%|██████████| 1010/1010 [01:27<00:00, 11.56it/s]


Val_loss:  0.15686898986561856


Train: 100%|██████████| 4037/4037 [08:36<00:00,  7.81it/s]


Train_loss:  0.13227857945544894


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.16it/s]


Val_loss:  0.15623169585002591


Train: 100%|██████████| 4037/4037 [08:36<00:00,  7.82it/s]


Train_loss:  0.13032901703346575


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.62it/s]


Val_loss:  0.15475896671190023


Train: 100%|██████████| 4037/4037 [08:37<00:00,  7.81it/s]


Train_loss:  0.12853990761635883


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.97it/s]


Val_loss:  0.15405414835392742


Train: 100%|██████████| 4037/4037 [08:35<00:00,  7.83it/s]


Train_loss:  0.1268735026272659


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.22it/s]


Val_loss:  0.15324034990797192


In [59]:
transformer_distill_head8_layer2 = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=8,
                                                       n_layers=2,
                                                       )

transformer_distill_head8_layer2.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head8_layer2/best-distill-bert.pt'))

distill_embedder = transformer_distill_head8_layer2.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.2727834535309993 ms

Accuracy: 91.4
Positives Recall: 91.8
Positives Precision: 89.8
Positives F1: 90.8
Distance:  0.38
Max cluster size:  291
Median cluster size:  2.0
Avg cluster size: 5.44


In [60]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.2726942085800283 ms

Accuracy: 91.3
Positives Recall: 91.2
Positives Precision: 90.0
Positives F1: 90.6
Distance:  0.38
Max cluster size:  225
Median cluster size:  2
Avg cluster size: 5.23


#####1 layer + 8 heads + relu

In [35]:
transformer_distill_head8_layer1_relu = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=8,
                                                       n_layers=1,
                                                       relu=True,
                                                       )
train(transformer_distill_head8_layer1_relu, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head8_layer1_relu',
      )

Train: 100%|██████████| 4037/4037 [05:49<00:00, 11.55it/s]


Train_loss:  0.4650929518486156


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.56it/s]


Val_loss:  0.3346596614120308


Train: 100%|██████████| 4037/4037 [05:49<00:00, 11.56it/s]


Train_loss:  0.29759994574028276


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.52it/s]


Val_loss:  0.2747519942360383


Train: 100%|██████████| 4037/4037 [05:49<00:00, 11.55it/s]


Train_loss:  0.25340149429550585


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.62it/s]


Val_loss:  0.24610627128575943


Train: 100%|██████████| 4037/4037 [05:50<00:00, 11.51it/s]


Train_loss:  0.22839684549084038


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.57it/s]


Val_loss:  0.22851675272914526


Train: 100%|██████████| 4037/4037 [05:48<00:00, 11.59it/s]


Train_loss:  0.21142112930323015


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.62it/s]


Val_loss:  0.21643118606080752


Train: 100%|██████████| 4037/4037 [05:49<00:00, 11.54it/s]


Train_loss:  0.1987896769541267


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.31it/s]


Val_loss:  0.20701981655667748


Train: 100%|██████████| 4037/4037 [05:58<00:00, 11.27it/s]


Train_loss:  0.18890492372172285


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.40it/s]


Val_loss:  0.19990708589651812


Train: 100%|██████████| 4037/4037 [05:51<00:00, 11.47it/s]


Train_loss:  0.1808396222784974


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.19it/s]


Val_loss:  0.1944303536168238


Train: 100%|██████████| 4037/4037 [06:04<00:00, 11.07it/s]


Train_loss:  0.17414103563930253


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.94it/s]


Val_loss:  0.18928139048963954


Train: 100%|██████████| 4037/4037 [05:53<00:00, 11.42it/s]


Train_loss:  0.16844384075903135


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.26it/s]


Val_loss:  0.18496114105794081


Train: 100%|██████████| 4037/4037 [06:08<00:00, 10.94it/s]


Train_loss:  0.16352064550208148


Val: 100%|██████████| 1010/1010 [01:26<00:00, 11.74it/s]


Val_loss:  0.18183226409890885


Train: 100%|██████████| 4037/4037 [05:52<00:00, 11.46it/s]


Train_loss:  0.15920576726710786


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.50it/s]


Val_loss:  0.17910775058604853


Train: 100%|██████████| 4037/4037 [05:54<00:00, 11.38it/s]


Train_loss:  0.1554059659677937


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.32it/s]


Val_loss:  0.17628055370957843


Train: 100%|██████████| 4037/4037 [05:51<00:00, 11.48it/s]


Train_loss:  0.1519769490039722


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.68it/s]


Val_loss:  0.1745739169353514


Train: 100%|██████████| 4037/4037 [05:51<00:00, 11.49it/s]


Train_loss:  0.1489231468892915


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.51it/s]


Val_loss:  0.17235424020486168


Train: 100%|██████████| 4037/4037 [05:48<00:00, 11.59it/s]


Train_loss:  0.1461135212075615


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.63it/s]


Val_loss:  0.17080616027960802


Train: 100%|██████████| 4037/4037 [05:49<00:00, 11.55it/s]


Train_loss:  0.14353603871525025


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.64it/s]


Val_loss:  0.16909987506008567


Train: 100%|██████████| 4037/4037 [05:50<00:00, 11.53it/s]


Train_loss:  0.1411694608999659


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.65it/s]


Val_loss:  0.16765783179502008


Train: 100%|██████████| 4037/4037 [05:48<00:00, 11.58it/s]


Train_loss:  0.13898090948752548


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.68it/s]


Val_loss:  0.16687024167734238


Train: 100%|██████████| 4037/4037 [05:48<00:00, 11.57it/s]


Train_loss:  0.13694441578541755


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.59it/s]


Val_loss:  0.165658943785824


In [67]:
transformer_distill_head8_layer1_relu = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=8,
                                                       n_layers=1,
                                                       relu=True,                                                      
                                                       )

transformer_distill_head8_layer1_relu.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head8_layer1_relu/best-distill-bert.pt'))

distill_embedder = transformer_distill_head8_layer1_relu.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.15777401268719676 ms

Accuracy: 91.6
Positives Recall: 91.6
Positives Precision: 90.3
Positives F1: 90.9
Distance:  0.38
Max cluster size:  310
Median cluster size:  2.0
Avg cluster size: 4.97


In [68]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.15784077583544914 ms

Accuracy: 91.0
Positives Recall: 91.1
Positives Precision: 89.5
Positives F1: 90.3
Distance:  0.38
Max cluster size:  168
Median cluster size:  2.0
Avg cluster size: 4.84


#####2 layers + 4 heads 

In [13]:
transformer_distill_head4_layer2 = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=2,
                                                       )
train(transformer_distill_head4_layer2, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer2',
      )

Train: 100%|██████████| 4037/4037 [06:08<00:00, 10.95it/s]


Train_loss:  0.42739984548495413


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.04it/s]


Val_loss:  0.2937647733935266


Train: 100%|██████████| 4037/4037 [06:05<00:00, 11.04it/s]


Train_loss:  0.2615467785680533


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.05it/s]


Val_loss:  0.23866393238886555


Train: 100%|██████████| 4037/4037 [06:06<00:00, 11.01it/s]


Train_loss:  0.22170057986177627


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.11it/s]


Val_loss:  0.21374852739311


Train: 100%|██████████| 4037/4037 [06:05<00:00, 11.05it/s]


Train_loss:  0.20012524644305865


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.94it/s]


Val_loss:  0.19944897489701302


Train: 100%|██████████| 4037/4037 [06:06<00:00, 11.03it/s]


Train_loss:  0.1859345819592193


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.34it/s]


Val_loss:  0.1890412324187126


Train: 100%|██████████| 4037/4037 [06:05<00:00, 11.05it/s]


Train_loss:  0.17556916037682332


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.02it/s]


Val_loss:  0.18192116115385262


Train: 100%|██████████| 4037/4037 [06:05<00:00, 11.04it/s]


Train_loss:  0.16763599473962984


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.28it/s]


Val_loss:  0.1767835551991157


Train: 100%|██████████| 4037/4037 [06:07<00:00, 10.98it/s]


Train_loss:  0.1613047862128786


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.25it/s]


Val_loss:  0.17329829127258847


Train: 100%|██████████| 4037/4037 [06:09<00:00, 10.94it/s]


Train_loss:  0.1560477219937583


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.05it/s]


Val_loss:  0.16965018652463934


Train: 100%|██████████| 4037/4037 [05:57<00:00, 11.28it/s]


Train_loss:  0.1516026613914975


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.32it/s]


Val_loss:  0.16631120920153017


Train: 100%|██████████| 4037/4037 [05:58<00:00, 11.26it/s]


Train_loss:  0.1476894808407366


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.30it/s]


Val_loss:  0.16420149180790305


Train: 100%|██████████| 4037/4037 [06:09<00:00, 10.91it/s]


Train_loss:  0.14430317637691725


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.14it/s]


Val_loss:  0.16242769504215895


Train: 100%|██████████| 4037/4037 [06:08<00:00, 10.94it/s]


Train_loss:  0.14125980273616806


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.29it/s]


Val_loss:  0.1612841045535851


Train: 100%|██████████| 4037/4037 [06:01<00:00, 11.17it/s]


Train_loss:  0.1385855262389009


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.30it/s]


Val_loss:  0.15911655641821873


Train: 100%|██████████| 4037/4037 [05:59<00:00, 11.23it/s]


Train_loss:  0.13610504682507993


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.10it/s]


Val_loss:  0.15783809992469536


Train: 100%|██████████| 4037/4037 [06:05<00:00, 11.03it/s]


Train_loss:  0.13382298806740206


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.05it/s]


Val_loss:  0.1567463035158255


Train: 100%|██████████| 4037/4037 [06:05<00:00, 11.04it/s]


Train_loss:  0.1317643010623369


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.14it/s]


Val_loss:  0.15551232157153544


Train: 100%|██████████| 4037/4037 [06:03<00:00, 11.12it/s]


Train_loss:  0.12985692816415312


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.21it/s]


Val_loss:  0.1554381378173633


Train: 100%|██████████| 4037/4037 [06:06<00:00, 11.02it/s]


Train_loss:  0.12808567203173626


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.06it/s]


Val_loss:  0.15410583476138154


Train: 100%|██████████| 4037/4037 [06:06<00:00, 11.01it/s]


Train_loss:  0.12636096842735567


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.11it/s]


Val_loss:  0.15253579891621075


In [14]:
transformer_distill_head4_layer2 = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=2,
                                                       )

transformer_distill_head4_layer2.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer2/best-distill-bert.pt'))

distill_embedder = transformer_distill_head4_layer2.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.22640524771346515 ms

Accuracy: 91.6
Positives Recall: 91.7
Positives Precision: 90.2
Positives F1: 91.0
Distance:  0.38
Max cluster size:  316
Median cluster size:  2.0
Avg cluster size: 5.28


In [15]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.22489633355921765 ms

Accuracy: 90.8
Positives Recall: 90.6
Positives Precision: 89.5
Positives F1: 90.1
Distance:  0.38
Max cluster size:  165
Median cluster size:  2
Avg cluster size: 5.09


#####2 layers + 4 heads / *attentive aggregation*

In [16]:
transformer_distill_head4_layer2_attn = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=2,
                                                       attentive_aggregation=True,
                                                       )
train(transformer_distill_head4_layer2_attn, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer2_attn',
      )

Train: 100%|██████████| 4037/4037 [05:52<00:00, 11.47it/s]


Train_loss:  0.4218944529727063


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.22it/s]


Val_loss:  0.2920891208366476


Train: 100%|██████████| 4037/4037 [05:55<00:00, 11.36it/s]


Train_loss:  0.2641610589941736


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.30it/s]


Val_loss:  0.24071715659445356


Train: 100%|██████████| 4037/4037 [05:55<00:00, 11.35it/s]


Train_loss:  0.22613368807279183


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.23it/s]


Val_loss:  0.21663847311064016


Train: 100%|██████████| 4037/4037 [05:57<00:00, 11.31it/s]


Train_loss:  0.20485981628018185


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.27it/s]


Val_loss:  0.20279745598381949


Train: 100%|██████████| 4037/4037 [05:59<00:00, 11.23it/s]


Train_loss:  0.19055974600411535


Val: 100%|██████████| 1010/1010 [01:23<00:00, 12.14it/s]


Val_loss:  0.19323990026571808


Train: 100%|██████████| 4037/4037 [05:57<00:00, 11.29it/s]


Train_loss:  0.18004392161656876


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.32it/s]


Val_loss:  0.18589215364679496


Train: 100%|██████████| 4037/4037 [05:58<00:00, 11.27it/s]


Train_loss:  0.17179707350835413


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.24it/s]


Val_loss:  0.17976187977831512


Train: 100%|██████████| 4037/4037 [05:57<00:00, 11.30it/s]


Train_loss:  0.16509370247953514


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.36it/s]


Val_loss:  0.1752815906077699


Train: 100%|██████████| 4037/4037 [05:55<00:00, 11.36it/s]


Train_loss:  0.15937684481003142


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.32it/s]


Val_loss:  0.17178399870969224


Train: 100%|██████████| 4037/4037 [05:55<00:00, 11.35it/s]


Train_loss:  0.15454691529769093


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.36it/s]


Val_loss:  0.1690123796395954


Train: 100%|██████████| 4037/4037 [05:55<00:00, 11.36it/s]


Train_loss:  0.1503398184602872


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.32it/s]


Val_loss:  0.16660602811716396


Train: 100%|██████████| 4037/4037 [05:56<00:00, 11.33it/s]


Train_loss:  0.14658231832224225


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.41it/s]


Val_loss:  0.16413774614779164


Train: 100%|██████████| 4037/4037 [05:55<00:00, 11.35it/s]


Train_loss:  0.14326654598400365


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.38it/s]


Val_loss:  0.16170634645106832


Train: 100%|██████████| 4037/4037 [05:55<00:00, 11.35it/s]


Train_loss:  0.14028692893557387


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.28it/s]


Val_loss:  0.16033803351817105


Train: 100%|██████████| 4037/4037 [05:55<00:00, 11.35it/s]


Train_loss:  0.13760309564521656


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.36it/s]


Val_loss:  0.15919674179978752


Train: 100%|██████████| 4037/4037 [05:53<00:00, 11.41it/s]


Train_loss:  0.13514770305822724


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.31it/s]


Val_loss:  0.1574864678360912


Train: 100%|██████████| 4037/4037 [05:54<00:00, 11.37it/s]


Train_loss:  0.13295645291826247


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.23it/s]


Val_loss:  0.15676356754839949


Train: 100%|██████████| 4037/4037 [05:55<00:00, 11.36it/s]


Train_loss:  0.13085371637681553


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.41it/s]


Val_loss:  0.15567637355011138


Train: 100%|██████████| 4037/4037 [05:56<00:00, 11.31it/s]


Train_loss:  0.12897660260061355


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.35it/s]


Val_loss:  0.15465843209999788


Train: 100%|██████████| 4037/4037 [05:54<00:00, 11.40it/s]


Train_loss:  0.12718127146557026


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.38it/s]


Val_loss:  0.1535096898503754


In [63]:
transformer_distill_head4_layer2_attn = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=2,
                                                       attentive_aggregation=True,
                                                       )

transformer_distill_head4_layer2_attn.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer2_attn/best-distill-bert.pt'))

distill_embedder = transformer_distill_head4_layer2_attn.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.2528980350589372 ms

Accuracy: 91.4
Positives Recall: 91.4
Positives Precision: 90.0
Positives F1: 90.7
Distance:  0.38
Max cluster size:  316
Median cluster size:  2
Avg cluster size: 5.43


In [64]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.22896425135099693 ms

Accuracy: 90.6
Positives Recall: 89.8
Positives Precision: 89.7
Positives F1: 89.7
Distance:  0.38
Max cluster size:  167
Median cluster size:  2
Avg cluster size: 5.18


#####2 layers + 4 heads / *cls aggregation*

In [None]:
transformer_distill_head4_layer2_cls = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=2,
                                                       cls_aggregation=True,
                                                       )
train(transformer_distill_head4_layer2_cls, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer2_cls',
      )

In [65]:
transformer_distill_head4_layer2_cls = TransformerClusteringDistillModel(evaluate_similarity=True,
                                                       n_heads=4,
                                                       n_layers=2,
                                                       cls_aggregation=True,
                                                       )

transformer_distill_head4_layer2_cls.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Transformer_head4_layer2_cls/best-distill-bert.pt'))

distill_embedder = transformer_distill_head4_layer2_cls.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.21704012896435193 ms

Accuracy: 90.4
Positives Recall: 89.9
Positives Precision: 89.4
Positives F1: 89.6
Distance:  0.38
Max cluster size:  261
Median cluster size:  2.0
Avg cluster size: 5.97


In [66]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.217195659769745 ms

Accuracy: 90.5
Positives Recall: 90.2
Positives Precision: 89.2
Positives F1: 89.7
Distance:  0.38
Max cluster size:  160
Median cluster size:  2.0
Avg cluster size: 5.70


####GRU

#####Zero dropout + deleted linear

In [None]:
bidir_gru_distill_dropout0_del_linear = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              dropout=0, evaluate_similarity=True,
                              del_linear=True, bidirectional=True,
                              to_gru=True)


train(bidir_gru_distill_dropout0_del_linear, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_GRU_dropout0_del_linear',
      )

In [76]:
bidir_gru_distill_dropout0_del_linear = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              dropout=0, evaluate_similarity=True,
                              del_linear=True, bidirectional=True,
                              to_gru=True)

bidir_gru_distill_dropout0_del_linear.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_GRU_dropout0_del_linear/best-distill-bert.pt'))

distill_embedder = bidir_gru_distill_dropout0_del_linear.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.6163782242759765 ms

Accuracy: 93.8
Positives Recall: 94.5
Positives Precision: 92.3
Positives F1: 93.4
Distance:  0.38
Max cluster size:  269
Median cluster size:  2
Avg cluster size: 3.86


In [77]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.693538889295145 ms

Accuracy: 93.7
Positives Recall: 94.7
Positives Precision: 91.7
Positives F1: 93.2
Distance:  0.38
Max cluster size:  166
Median cluster size:  2.0
Avg cluster size: 3.71


#####Zero dropout + deleted linear + *attentive aggregation*

In [None]:
bidir_gru_distill_dropout0_del_linear_attn = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              dropout=0, evaluate_similarity=True,
                              del_linear=True, bidirectional=True,
                              to_gru=True, attentive_aggregation=True)


train(bidir_gru_distill_dropout0_del_linear_attn, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_GRU_dropout0_del_linear_attn',
      )

In [85]:
bidir_gru_distill_dropout0_del_linear_attn = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              dropout=0, evaluate_similarity=True,
                              del_linear=True, bidirectional=True,
                              to_gru=True,
                              attentive_aggregation=True)

bidir_gru_distill_dropout0_del_linear_attn.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_GRU_dropout0_del_linear_attn/best-distill-bert.pt'))

distill_embedder = bidir_gru_distill_dropout0_del_linear_attn.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.6399877704234712 ms

Accuracy: 94.1
Positives Recall: 95.1
Positives Precision: 92.4
Positives F1: 93.7
Distance:  0.38
Max cluster size:  260
Median cluster size:  2
Avg cluster size: 3.83


In [86]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.6218457223982505 ms

Accuracy: 93.2
Positives Recall: 94.7
Positives Precision: 90.9
Positives F1: 92.7
Distance:  0.38
Max cluster size:  170
Median cluster size:  2
Avg cluster size: 3.67


#####Zero dropout / GRU with no linear layers

In [None]:
bidir_gru_distill_dropout0_del_2_linear = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, lstm_hidden_dim=(TARGET_SIZE//2),
                              del_2_linear=True, evaluate_similarity=True,
                              to_gru=True,
                              )

train(bidir_gru_distill_dropout0_del_2_linear, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_GRU_no_linear_0_dropout',
      )

In [79]:
bidir_gru_distill_dropout0_del_2_linear = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, lstm_hidden_dim=(TARGET_SIZE//2),
                              del_2_linear=True, evaluate_similarity=True,
                              to_gru=True,
                              )

bidir_gru_distill_dropout0_del_2_linear.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_GRU_no_linear_0_dropout/best-distill-bert.pt'))

distill_embedder = bidir_gru_distill_dropout0_del_2_linear.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.6090390427060336 ms

Accuracy: 93.7
Positives Recall: 94.4
Positives Precision: 92.2
Positives F1: 93.3
Distance:  0.38
Max cluster size:  253
Median cluster size:  1.0
Avg cluster size: 3.61


In [80]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.6130435019669095 ms

Accuracy: 93.1
Positives Recall: 94.7
Positives Precision: 90.6
Positives F1: 92.6
Distance:  0.38
Max cluster size:  160
Median cluster size:  1
Avg cluster size: 3.47


#####Zero dropout / GRU with no linear layers + *attentive aggregation*

In [None]:
bidir_gru_distill_dropout0_del_2_linear_attn = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, lstm_hidden_dim=(TARGET_SIZE//2),
                              del_2_linear=True, evaluate_similarity=True,
                              to_gru=True, attentive_aggregation=True,
                              )

train(bidir_gru_distill_dropout0_del_2_linear_attn, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_GRU_no_linear_0_dropout_attn',
      )

In [87]:
bidir_gru_distill_dropout0_del_2_linear_attn = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, lstm_hidden_dim=(TARGET_SIZE//2),
                              del_2_linear=True, evaluate_similarity=True,
                              to_gru=True, attentive_aggregation=True,
                              )

bidir_gru_distill_dropout0_del_2_linear_attn.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_GRU_no_linear_0_dropout_attn/best-distill-bert.pt'))

distill_embedder = bidir_gru_distill_dropout0_del_2_linear_attn.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.6164026565760731 ms

Accuracy: 93.6
Positives Recall: 94.4
Positives Precision: 92.0
Positives F1: 93.2
Distance:  0.38
Max cluster size:  266
Median cluster size:  2.0
Avg cluster size: 3.85


In [88]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.6168262001156657 ms

Accuracy: 93.6
Positives Recall: 95.2
Positives Precision: 91.2
Positives F1: 93.1
Distance:  0.38
Max cluster size:  171
Median cluster size:  2.0
Avg cluster size: 3.68


#####Zero dropout + deleted linear + 2 GRU layers

In [None]:
bidir_gru_distill_dropout0_del_2gru_layers = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, 
                              del_linear=True, evaluate_similarity=True,
                              lstm_layers_count=2,
                              to_gru=True,
                              )


train(bidir_gru_distill_dropout0_del_2gru_layers, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_GRU_2layers',
      )

In [81]:
bidir_gru_distill_dropout0_del_2gru_layers = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, 
                              del_linear=True, evaluate_similarity=True,
                              lstm_layers_count=2,
                              to_gru=True,)

bidir_gru_distill_dropout0_del_2gru_layers.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_GRU_2layers/best-distill-bert.pt'))

distill_embedder = bidir_gru_distill_dropout0_del_2gru_layers.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  1.1718134158635993 ms

Accuracy: 93.9
Positives Recall: 94.5
Positives Precision: 92.4
Positives F1: 93.5
Distance:  0.38
Max cluster size:  268
Median cluster size:  2.0
Avg cluster size: 4.05


In [82]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  1.1707488509666786 ms

Accuracy: 93.9
Positives Recall: 95.3
Positives Precision: 91.7
Positives F1: 93.4
Distance:  0.38
Max cluster size:  169
Median cluster size:  2.0
Avg cluster size: 3.93


#####Zero dropout + deleted linear + 2 GRU layers + *attentive aggregation*

In [None]:
bidir_gru_distill_dropout0_del_2gru_layers_attn = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, 
                              del_linear=True, evaluate_similarity=True,
                              lstm_layers_count=2,
                              to_gru=True, attentive_aggregation=True,
                              )


train(bidir_gru_distill_dropout0_del_2gru_layers_attn, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_GRU_2layers_attn',
      )

In [89]:
bidir_gru_distill_dropout0_del_2gru_layers_attn = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              bidirectional=True, dropout=0, 
                              del_linear=True, evaluate_similarity=True,
                              lstm_layers_count=2,
                              to_gru=True, attentive_aggregation=True,)

bidir_gru_distill_dropout0_del_2gru_layers_attn.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_GRU_2layers_attn/best-distill-bert.pt'))

distill_embedder = bidir_gru_distill_dropout0_del_2gru_layers_attn.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, 
                                              print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  1.1716535444753577 ms

Accuracy: 94.1
Positives Recall: 95.2
Positives Precision: 92.3
Positives F1: 93.8
Distance:  0.38
Max cluster size:  276
Median cluster size:  2
Avg cluster size: 4.12


In [90]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  1.1355856287714123 ms

Accuracy: 93.5
Positives Recall: 95.0
Positives Precision: 91.2
Positives F1: 93.0
Distance:  0.38
Max cluster size:  167
Median cluster size:  2
Avg cluster size: 3.97


Experiments with other architectures (which haven't shown any competetive results)

###RNN 

In [None]:
bidir_rnn_distill_dropout0_del_linear = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              add_relu=True, bidirectional=True, evaluate_similarity=True,
                              dropout=0, del_linear=True, to_rnn=True)


train(bidir_rnn_distill_dropout0_del_linear, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_RNN_dropout0_del_linear',
      )

Train: 100%|██████████| 4037/4037 [07:29<00:00,  8.99it/s]


Train_loss:  0.48876522100140124


Val: 100%|██████████| 1010/1010 [01:43<00:00,  9.73it/s]


Val_loss:  0.37218699242468245


Train: 100%|██████████| 4037/4037 [07:20<00:00,  9.16it/s]


Train_loss:  0.3174249038955114


Val: 100%|██████████| 1010/1010 [01:41<00:00,  9.94it/s]


Val_loss:  0.28786358887463864


Train: 100%|██████████| 4037/4037 [07:19<00:00,  9.18it/s]


Train_loss:  0.2549335571677436


Val: 100%|██████████| 1010/1010 [01:42<00:00,  9.82it/s]


Val_loss:  0.24915962389866184


Train: 100%|██████████| 4037/4037 [07:16<00:00,  9.24it/s]


Train_loss:  0.2224355379543873


Val: 100%|██████████| 1010/1010 [01:41<00:00,  9.97it/s]


Val_loss:  0.22792269723937522


Train: 100%|██████████| 4037/4037 [07:12<00:00,  9.33it/s]


Train_loss:  0.20290557337257736


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.18it/s]


Val_loss:  0.2150726260204036


Train: 100%|██████████| 4037/4037 [07:15<00:00,  9.28it/s]


Train_loss:  0.19006288048810577


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.14it/s]


Val_loss:  0.2065343217891848


Train: 100%|██████████| 4037/4037 [07:15<00:00,  9.27it/s]


Train_loss:  0.18104305914630986


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.18it/s]


Val_loss:  0.200494734688939


Train: 100%|██████████| 4037/4037 [07:15<00:00,  9.27it/s]


Train_loss:  0.1743813707319635


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.19it/s]


Val_loss:  0.19608616003919274


Train: 100%|██████████| 4037/4037 [07:15<00:00,  9.27it/s]


Train_loss:  0.16929034254359762


Val: 100%|██████████| 1010/1010 [01:38<00:00, 10.23it/s]


Val_loss:  0.1929624318308436


Train: 100%|██████████| 4037/4037 [07:08<00:00,  9.42it/s]


Train_loss:  0.16524788312971564


Val: 100%|██████████| 1010/1010 [01:38<00:00, 10.30it/s]


Val_loss:  0.19030332082381274


Train: 100%|██████████| 4037/4037 [07:10<00:00,  9.38it/s]


Train_loss:  0.16197349837111852


Val: 100%|██████████| 1010/1010 [01:38<00:00, 10.25it/s]


Val_loss:  0.18841477189552855


Train: 100%|██████████| 4037/4037 [07:12<00:00,  9.33it/s]


Train_loss:  0.15925070967412178


Val: 100%|██████████| 1010/1010 [01:38<00:00, 10.21it/s]


Val_loss:  0.186789639123179


Train: 100%|██████████| 4037/4037 [07:13<00:00,  9.32it/s]


Train_loss:  0.15694737638762155


Val: 100%|██████████| 1010/1010 [01:40<00:00, 10.08it/s]


Val_loss:  0.18542995919038854


Train: 100%|██████████| 4037/4037 [07:16<00:00,  9.24it/s]


Train_loss:  0.15497932029977338


Val: 100%|██████████| 1010/1010 [01:40<00:00, 10.05it/s]


Val_loss:  0.18425258035479547


Train: 100%|██████████| 4037/4037 [07:16<00:00,  9.24it/s]


Train_loss:  0.1532632586839722


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.10it/s]


Val_loss:  0.18341362200958455


Train: 100%|██████████| 4037/4037 [07:12<00:00,  9.34it/s]


Train_loss:  0.15174002515498566


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.10it/s]


Val_loss:  0.18252780162415566


Train: 100%|██████████| 4037/4037 [07:15<00:00,  9.27it/s]


Train_loss:  0.15041110745961056


Val: 100%|██████████| 1010/1010 [01:40<00:00, 10.06it/s]


Val_loss:  0.18218677366288896


Train: 100%|██████████| 4037/4037 [07:18<00:00,  9.22it/s]


Train_loss:  0.1491859543839345


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.19it/s]


Val_loss:  0.1814811940499293


Train: 100%|██████████| 4037/4037 [07:09<00:00,  9.39it/s]


Train_loss:  0.1480955188772538


Val: 100%|██████████| 1010/1010 [01:38<00:00, 10.28it/s]


Val_loss:  0.1809890368405701


Train: 100%|██████████| 4037/4037 [07:07<00:00,  9.45it/s]


Train_loss:  0.147093908937348


Val: 100%|██████████| 1010/1010 [01:37<00:00, 10.33it/s]


Val_loss:  0.18048854259894262


In [None]:
bidir_rnn_distill_dropout0_del_linear = SingleClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              add_relu=True, bidirectional=True, evaluate_similarity=True,
                              dropout=0, del_linear=True, to_rnn=True)

bidir_rnn_distill_dropout0_del_linear.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/Bidir_RNN_dropout0_del_linear/best-distill-bert.pt'))

distill_embedder = bidir_rnn_distill_dropout0_del_linear.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, max_tokens_count=MAX_TOKENS)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)


Accuracy: 92.0
Positives Recall: 90.8
Positives Precision: 91.8
Positives F1: 91.3
Distance:  0.38
Max cluster size:  253
Median cluster size:  1.0
Avg cluster size: 3.33


In [None]:
bidir_rnn_distill_dropout0_del_linear.parameters

In [None]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

###Linears

#####Modeling

In [25]:
class LinearEmbedder(nn.Module):
  def __init__(self, vocab_size=VOCAB_SIZE, 
               target_size=TARGET_SIZE, 
               pretrained=False,  
               pretrained_embs=None,
               freeze_pretrained=False,
               word_emb_dim=64,
               n_layers=3,
               attentive_aggregation=False,
               ):

        super().__init__()
        
        self.target_size=target_size
        self.word_emb_dim=word_emb_dim
        self.attentive_aggregation=attentive_aggregation
         
        if pretrained:
          self.embedding = nn.Embedding.from_pretrained(pretrained_embs, 
                                                        freeze=freeze_pretrained) 
          word_emb_dim = pretrained_embs.shape[1]
        else:
          self.embedding = nn.Embedding(vocab_size, word_emb_dim)

        self.linears = nn.ModuleList()
        for i in range(n_layers):
          linear = nn.Linear(word_emb_dim*(2**i), word_emb_dim*(2**(i+1)))
          self.linears.append(nn.Sequential(linear, 
                                            nn.ReLU()))

        self.mapping = nn.Linear(word_emb_dim*(2**n_layers), target_size)

        if attentive_aggregation: 
          self.softmax = nn.Softmax(dim=1)     
          self.attn = nn.Sequential(
                nn.Linear(target_size, target_size//2),
                nn.ReLU(),
                nn.Linear(target_size//2, 1)
            )
        
        
  def aggregate(self, linears, mask):
        expanded_mask = mask.cuda().unsqueeze(-1).expand(linears.size()).float()
        sum_embeddings = torch.sum(linears * expanded_mask, 1).cuda()
        sum_mask = torch.clamp(expanded_mask.sum(1), min=1e-9)
        return sum_embeddings / sum_mask   

  def attentive_aggregate(self, linears, mask):
        weights = self.softmax(self.attn(linears).squeeze(-1)) * mask
        embeddings = weights.unsqueeze(1).bmm(linears).squeeze(1)
        return embeddings 

  def forward(self, x, mask):
        x = self.embedding(x)
        for layer in self.linears:
          x = layer(x) 
        map = self.mapping(x)
        if self.attentive_aggregation: 
          out = self.attentive_aggregate(map, mask)
        else:
          out = self.aggregate(map, mask)
        return out

In [26]:
class LinearClusteringDistillModel(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, target_size=TARGET_SIZE, 
                 pretrained=False, pretrained_embs=None, 
                 freeze_pretrained=False, word_emb_dim=128, 
                 evaluate_similarity=False, 
                 lr=1e-3, n_layers=3, 
                 attentive_aggregation=False, 
                 ):
      
        super().__init__()

        self.embedder = LinearEmbedder(vocab_size=vocab_size, 
                                      target_size=target_size, 
                                      pretrained=pretrained,  
                                      pretrained_embs=pretrained_embs,
                                      freeze_pretrained=freeze_pretrained,
                                      word_emb_dim=word_emb_dim,
                                      n_layers=n_layers,
                                      attentive_aggregation=attentive_aggregation, 
                                      )

        self.evaluate_similarity = evaluate_similarity
        if not evaluate_similarity:
          self.mse = torch.nn.MSELoss()
        else:
          self.cosine_similarity = torch.nn.functional.cosine_similarity
        
        self.lr = lr

    def forward(self, news):
        embeddings = self.embedder(news["input_ids"].cuda(), 
                                   news["attention_mask"].cuda())
        return embeddings

    def loss(self, embeds, bert_embeds):
        if self.evaluate_similarity is True:
          similarity = self.cosine_similarity(embeds, bert_embeds)
          loss = torch.mean(torch.ones(len(similarity)).cuda() - similarity)
        else:
          loss = self.mse(embeds.float(), bert_embeds.float())
        return loss

#####Experiments

In [49]:
linears_3 = LinearClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              evaluate_similarity=True,
                              )


train(linears_3, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/3_linears',
      )

Train: 100%|██████████| 4037/4037 [05:45<00:00, 11.67it/s]


Train_loss:  0.4071094193074118


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.33it/s]


Val_loss:  0.26598109546674203


Train: 100%|██████████| 4037/4037 [05:45<00:00, 11.67it/s]


Train_loss:  0.24479699234966928


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.32it/s]


Val_loss:  0.2340434892200141


Train: 100%|██████████| 4037/4037 [05:46<00:00, 11.63it/s]


Train_loss:  0.2220098446737502


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.32it/s]


Val_loss:  0.220779537431112


Train: 100%|██████████| 4037/4037 [05:46<00:00, 11.65it/s]


Train_loss:  0.21004601695441363


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.38it/s]


Val_loss:  0.21258140120536798


Train: 100%|██████████| 4037/4037 [05:46<00:00, 11.64it/s]


Train_loss:  0.20217064315298172


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.37it/s]


Val_loss:  0.20727153318874544


Train: 100%|██████████| 4037/4037 [05:46<00:00, 11.65it/s]


Train_loss:  0.1965112772210689


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.34it/s]


Val_loss:  0.20325691749277097


Train: 100%|██████████| 4037/4037 [05:46<00:00, 11.64it/s]


Train_loss:  0.19205632573909692


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.42it/s]


Val_loss:  0.20033038095143643


Train: 100%|██████████| 4037/4037 [05:46<00:00, 11.65it/s]


Train_loss:  0.18851348305501092


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.32it/s]


Val_loss:  0.19767096362860911


Train: 100%|██████████| 4037/4037 [05:47<00:00, 11.61it/s]


Train_loss:  0.18556792884905937


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.33it/s]


Val_loss:  0.19529655664225967


Train: 100%|██████████| 4037/4037 [05:46<00:00, 11.64it/s]


Train_loss:  0.183078904525829


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.30it/s]


Val_loss:  0.19351632889357073


Train: 100%|██████████| 4037/4037 [05:48<00:00, 11.59it/s]


Train_loss:  0.18093277591839313


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.18it/s]


Val_loss:  0.19190739539406856


Train: 100%|██████████| 4037/4037 [05:47<00:00, 11.61it/s]


Train_loss:  0.17906091586623332


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.33it/s]


Val_loss:  0.19155231103376574


Train: 100%|██████████| 4037/4037 [05:46<00:00, 11.64it/s]


Train_loss:  0.17737003543654506


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.36it/s]


Val_loss:  0.18987461258102906


Train: 100%|██████████| 4037/4037 [05:47<00:00, 11.62it/s]


Train_loss:  0.17589092759097916


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.33it/s]


Val_loss:  0.18831433868840153


Train: 100%|██████████| 4037/4037 [05:47<00:00, 11.61it/s]


Train_loss:  0.17454729890803117


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.29it/s]


Val_loss:  0.18746572961941377


Train: 100%|██████████| 4037/4037 [05:46<00:00, 11.64it/s]


Train_loss:  0.17331284381136705


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.37it/s]


Val_loss:  0.18710760272631433


Train: 100%|██████████| 4037/4037 [05:46<00:00, 11.63it/s]


Train_loss:  0.17220911200709024


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.25it/s]


Val_loss:  0.1859443834785177


Train: 100%|██████████| 4037/4037 [05:51<00:00, 11.47it/s]


Train_loss:  0.1711864710932987


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.32it/s]


Val_loss:  0.18524363307915118


Train: 100%|██████████| 4037/4037 [05:45<00:00, 11.68it/s]


Train_loss:  0.17022676461140696


Val: 100%|██████████| 1010/1010 [01:24<00:00, 11.89it/s]


Val_loss:  0.18472086073741617


Train: 100%|██████████| 4037/4037 [05:49<00:00, 11.55it/s]


Train_loss:  0.16935822098557732


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.24it/s]


Val_loss:  0.18386346282518987


In [50]:
linears_3 = LinearClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              evaluate_similarity=True,
                              )

linears_3.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/3_linears/best-distill-bert.pt'))

distill_embedder = linears_3.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.2674005006176542 ms

Accuracy: 91.9
Positives Recall: 91.4
Positives Precision: 91.1
Positives F1: 91.2
Distance:  0.38
Max cluster size:  259
Median cluster size:  2.0
Avg cluster size: 4.16


In [54]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.2701720585121887 ms

Accuracy: 90.3
Positives Recall: 89.8
Positives Precision: 89.2
Positives F1: 89.5
Distance:  0.38
Max cluster size:  145
Median cluster size:  2.0
Avg cluster size: 3.97


In [27]:
linears_2 = LinearClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              evaluate_similarity=True,
                              n_layers=2,
                              )


train(linears_2, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/2_linears',
      )

Train: 100%|██████████| 4037/4037 [05:38<00:00, 11.91it/s]


Train_loss:  0.41829313123727235


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.41it/s]


Val_loss:  0.32115518810200716


Train: 100%|██████████| 4037/4037 [05:39<00:00, 11.89it/s]


Train_loss:  0.3072142988163436


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.40it/s]


Val_loss:  0.2998080585979854


Train: 100%|██████████| 4037/4037 [05:39<00:00, 11.89it/s]


Train_loss:  0.29097823277007057


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.43it/s]


Val_loss:  0.28920436179397413


Train: 100%|██████████| 4037/4037 [05:38<00:00, 11.94it/s]


Train_loss:  0.2809430120361233


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.38it/s]


Val_loss:  0.28261468061395045


Train: 100%|██████████| 4037/4037 [05:36<00:00, 12.01it/s]


Train_loss:  0.2734908970881064


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.40it/s]


Val_loss:  0.27610783461657207


Train: 100%|██████████| 4037/4037 [05:36<00:00, 11.99it/s]


Train_loss:  0.26745963392230365


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.43it/s]


Val_loss:  0.2714151508751661


Train: 100%|██████████| 4037/4037 [05:38<00:00, 11.92it/s]


Train_loss:  0.26232368905075343


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.45it/s]


Val_loss:  0.2672059649712704


Train: 100%|██████████| 4037/4037 [05:37<00:00, 11.98it/s]


Train_loss:  0.25794152429085265


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.40it/s]


Val_loss:  0.2637273144804226


Train: 100%|██████████| 4037/4037 [05:36<00:00, 11.98it/s]


Train_loss:  0.2540374055979511


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.44it/s]


Val_loss:  0.2606777877811925


Train: 100%|██████████| 4037/4037 [05:36<00:00, 12.00it/s]


Train_loss:  0.2506624397480465


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.37it/s]


Val_loss:  0.25805779788963834


Train: 100%|██████████| 4037/4037 [05:36<00:00, 11.98it/s]


Train_loss:  0.2475485281359791


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.43it/s]


Val_loss:  0.2552935549847571


Train: 100%|██████████| 4037/4037 [05:35<00:00, 12.03it/s]


Train_loss:  0.24472403905171955


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.39it/s]


Val_loss:  0.25327855278011213


Train: 100%|██████████| 4037/4037 [05:36<00:00, 11.99it/s]


Train_loss:  0.24217404955730815


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.39it/s]


Val_loss:  0.2511277014700165


Train: 100%|██████████| 4037/4037 [05:37<00:00, 11.97it/s]


Train_loss:  0.23979843977883983


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.42it/s]


Val_loss:  0.2495162672830236


Train: 100%|██████████| 4037/4037 [05:37<00:00, 11.96it/s]


Train_loss:  0.23762274691729277


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.32it/s]


Val_loss:  0.2476431687702263


Train: 100%|██████████| 4037/4037 [05:37<00:00, 11.98it/s]


Train_loss:  0.23556936905606418


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.44it/s]


Val_loss:  0.24560200017389244


Train: 100%|██████████| 4037/4037 [05:37<00:00, 11.98it/s]


Train_loss:  0.23365667019903505


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.54it/s]


Val_loss:  0.24426792930028518


Train: 100%|██████████| 4037/4037 [05:37<00:00, 11.98it/s]


Train_loss:  0.2318332690027368


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.41it/s]


Val_loss:  0.2428141178760914


Train: 100%|██████████| 4037/4037 [05:37<00:00, 11.96it/s]


Train_loss:  0.23011353817898095


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.34it/s]


Val_loss:  0.2415175290771544


Train: 100%|██████████| 4037/4037 [05:38<00:00, 11.93it/s]


Train_loss:  0.22851532851600412


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.38it/s]


Val_loss:  0.24032160201970298


In [28]:
linears_2 = LinearClusteringDistillModel(vocab_size = VOCAB_SIZE,
                              evaluate_similarity=True,
                              n_layers=2,
                              )

linears_2.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_Cos_Similarity/2_linears/best-distill-bert.pt'))

distill_embedder = linears_2.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS,  print_mean_timing=True)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

Mean inference time per embed:  0.11830111080966148 ms

Accuracy: 89.0
Positives Recall: 86.6
Positives Precision: 89.3
Positives F1: 87.9
Distance:  0.38
Max cluster size:  350
Median cluster size:  2
Avg cluster size: 4.61


In [29]:
private_distill_embeddings = records_to_embeds(private_set, distill_embedder, tokenizer, batch_size=8, 
                                              max_tokens_count=MAX_TOKENS, print_mean_timing=True)
get_quality(private_markup, private_distill_embeddings, private_set, 0.38, True)

Mean inference time per embed:  0.1184452903516681 ms

Accuracy: 89.0
Positives Recall: 86.8
Positives Precision: 89.0
Positives F1: 87.9
Distance:  0.38
Max cluster size:  151
Median cluster size:  2
Avg cluster size: 4.44


###Convolutions

####Modeling 




In [None]:
class ConvEmbedder1(nn.Module):
  def __init__(self, vocab_size, target_size=TARGET_SIZE, 
               word_emb_dim=128, seq_len=250,
               pretrained=False, pretrained_embs=None, 
               freeze_pretrained=False,
               ):

        super().__init__()

        if pretrained:
          self.embedding = nn.Embedding.from_pretrained(pretrained_embs, 
                                                        freeze=freeze_pretrained) 
          word_emb_dim = pretrained_embs.shape[1]
        else:
          self.embedding = nn.Embedding(vocab_size, word_emb_dim)

        

        self.model = nn.ModuleList([
          self.embedding,
          nn.Conv2d(1, 3, kernel_size=4, stride=1, padding='same'),
          nn.MaxPool2d(kernel_size=(5, 2)),
          nn.Conv2d(3, 9, kernel_size=3, stride=1, padding='same'),
          nn.MaxPool2d(kernel_size=(5, 2)),
          nn.Conv2d(9, 1, kernel_size=3, stride=1, padding='same'),
          nn.MaxPool2d(kernel_size=((seq_len//25), 1)),
          nn.Linear((word_emb_dim//4), target_size)])        
       
  def forward(self, x):
        for layer in self.model:
          if layer == self.embedding:
            x = torch.unsqueeze(layer(x), 1)
          else:
            x = layer(x)
        return torch.squeeze(x)

In [None]:
class ConvEmbedder2(nn.Module):
  def __init__(self, vocab_size, target_size=TARGET_SIZE, 
               word_emb_dim=128, seq_len=250,
               pretrained=False, pretrained_embs=None, 
               freeze_pretrained=False,
               ):

        super().__init__()

        if pretrained:
          self.embedding = nn.Embedding.from_pretrained(pretrained_embs, 
                                                        freeze=freeze_pretrained) 
          word_emb_dim = pretrained_embs.shape[1]
        else:
          self.embedding = nn.Embedding(vocab_size, word_emb_dim)

        

        self.model = nn.ModuleList([
          self.embedding,
          nn.Conv2d(1, 3, kernel_size=3, stride=1, padding='same'),
          nn.MaxPool2d(kernel_size=2),
          nn.Conv2d(3, 9, kernel_size=3, stride=1, padding='same'),
          nn.MaxPool2d(kernel_size=2),
          nn.Conv2d(9, 27, kernel_size=2, stride=1, padding='same'),
          nn.MaxPool2d(kernel_size=2),
          nn.Conv2d(27, 81, kernel_size=2, stride=1, padding='same'),
          nn.MaxPool2d(kernel_size=2),
          nn.Conv2d(81, target_size, kernel_size=3, stride=1, padding='same'),
          nn.MaxPool2d(kernel_size=(seq_len//16, word_emb_dim//16))])       
       
  def forward(self, x):
        for layer in self.model:
          if layer == self.embedding:
            x = torch.unsqueeze(layer(x), 1)
          else:
            x = layer(x)
        return torch.squeeze(x)

In [None]:
class SingleClusteringDistillModelConv(nn.Module):
    def __init__(self, vocab_size, seq_len=250, 
                 target_size=TARGET_SIZE, 
                 word_emb_dim=128, pretrained=False, 
                 pretrained_embs=None, freeze_pretrained=False, 
                 evaluate_similarity=False, lr=1e-3, 
                 conv_embedder_class=ConvEmbedder1,
                 ):
      
        super().__init__()

        self.embedder = conv_embedder_class(vocab_size=vocab_size, target_size=target_size, 
                                      word_emb_dim=word_emb_dim, seq_len=seq_len,
                                      pretrained=pretrained, pretrained_embs=pretrained_embs, 
                                      freeze_pretrained=freeze_pretrained)
        
        self.mse = torch.nn.MSELoss()
        self.evaluate_similarity = evaluate_similarity
        self.cosine_similarity = torch.nn.functional.cosine_similarity
        self.lr = lr


    def forward(self, news):
        if type(news) == dict: 
          embeddings = self.embedder(news["input_ids"].cuda())
        else:
          embeddings = self.embedder(news.cuda())
        return embeddings

    def loss(self, embeds, bert_embeds):
        if self.evaluate_similarity is True:
          similarity = self.cosine_similarity(embeds, bert_embeds)
          loss = torch.mean(torch.ones(len(similarity)).cuda() - similarity)
        else:
          loss = self.mse(embeds.float(), bert_embeds.float())
        return loss

In [None]:
train_loader, val_loader, tokenizer = get_loaders(full_train_records, full_val_records, 
                                       INITIAL_MODEL, MAX_TOKENS, BATCH_SIZE,
                                       single_full_train_embeddings_bert,
                                       single_full_val_embeddings_bert,
                                       )

Downloading:   0%|          | 0.00/1.57M [00:00<?, ?B/s]

####Experiments

In [None]:
single_conv_distill = SingleClusteringDistillModelConv(vocab_size = VOCAB_SIZE,
                              seq_len = MAX_TOKENS)


train(single_conv_distill, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_MSE/Conv',
      )

  self.padding, self.dilation, self.groups)
Train: 100%|██████████| 4037/4037 [05:32<00:00, 12.16it/s]


Train_loss:  0.003920611967467619


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.61it/s]


Val_loss:  0.0037544227258237725


Train: 100%|██████████| 4037/4037 [05:28<00:00, 12.28it/s]


Train_loss:  0.0035577418545852993


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.63it/s]


Val_loss:  0.0032121401028892046


Train: 100%|██████████| 4037/4037 [05:32<00:00, 12.13it/s]


Train_loss:  0.002971885220630056


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.79it/s]


Val_loss:  0.0027934063947529044


Train: 100%|██████████| 4037/4037 [05:28<00:00, 12.31it/s]


Train_loss:  0.002651876353321688


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.71it/s]


Val_loss:  0.0025688148032191513


Train: 100%|██████████| 4037/4037 [05:27<00:00, 12.33it/s]


Train_loss:  0.0024679132791087014


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.70it/s]


Val_loss:  0.0024370864849758917


Train: 100%|██████████| 4037/4037 [05:22<00:00, 12.52it/s]


Train_loss:  0.0023496300196618855


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.84it/s]


Val_loss:  0.002351464580578527


Train: 100%|██████████| 4037/4037 [05:28<00:00, 12.29it/s]


Train_loss:  0.0022662611084660078


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.71it/s]


Val_loss:  0.002291289387077167


Train: 100%|██████████| 4037/4037 [05:20<00:00, 12.59it/s]


Train_loss:  0.0022033817971765196


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.88it/s]


Val_loss:  0.0022444609259854595


Train: 100%|██████████| 4037/4037 [05:21<00:00, 12.54it/s]


Train_loss:  0.0021541012121848477


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.79it/s]


Val_loss:  0.002207286069446271


Train: 100%|██████████| 4037/4037 [05:26<00:00, 12.36it/s]


Train_loss:  0.00211415007530052


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.69it/s]


Val_loss:  0.0021785441394157634


Train: 100%|██████████| 4037/4037 [05:27<00:00, 12.32it/s]


Train_loss:  0.002080765761310917


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.65it/s]


Val_loss:  0.002157399174296251


Train: 100%|██████████| 4037/4037 [05:27<00:00, 12.32it/s]


Train_loss:  0.0020524711601133004


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.83it/s]


Val_loss:  0.0021352306130517384


Train: 100%|██████████| 4037/4037 [05:24<00:00, 12.45it/s]


Train_loss:  0.00202793982408245


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.73it/s]


Val_loss:  0.0021183321916378372


Train: 100%|██████████| 4037/4037 [05:25<00:00, 12.41it/s]


Train_loss:  0.0020060967431367476


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.74it/s]


Val_loss:  0.0021057799121936656


Train: 100%|██████████| 4037/4037 [05:29<00:00, 12.27it/s]


Train_loss:  0.001986415331211086


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.74it/s]


Val_loss:  0.0020928226727162406


Train: 100%|██████████| 4037/4037 [05:32<00:00, 12.12it/s]


Train_loss:  0.001968252159927778


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.64it/s]


Val_loss:  0.0020821846108699197


Train: 100%|██████████| 4037/4037 [05:28<00:00, 12.29it/s]


Train_loss:  0.0019516005292779367


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.53it/s]


Val_loss:  0.002072836722634585


Train: 100%|██████████| 4037/4037 [05:28<00:00, 12.30it/s]


Train_loss:  0.0019365555252028588


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.65it/s]


Val_loss:  0.0020601151337026444


Train: 100%|██████████| 4037/4037 [05:28<00:00, 12.28it/s]


Train_loss:  0.0019229156922713727


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.62it/s]


Val_loss:  0.0020538690350119875


Train: 100%|██████████| 4037/4037 [05:28<00:00, 12.30it/s]


Train_loss:  0.0019104887229838686


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.63it/s]


Val_loss:  0.0020512235760089415


In [None]:
single_conv_distill = SingleClusteringDistillModelConv(vocab_size = VOCAB_SIZE,
                              seq_len = MAX_TOKENS)


single_conv_distill.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_MSE/Conv/best-distill-bert.pt'))

distill_embedder = single_conv_distill.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, max_tokens_count=MAX_TOKENS)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)

TypeError: ignored

In [None]:
single_conv_distill_002 = SingleClusteringDistillModelConv(vocab_size = VOCAB_SIZE,
                              seq_len = MAX_TOKENS, lr = 1e-2)


train(single_conv_distill_002, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_MSE/Conv_lr001',
      )

Train: 100%|██████████| 4037/4037 [05:24<00:00, 12.45it/s]


Train_loss:  0.0035865528460343804


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.80it/s]


Val_loss:  0.002607358066079271


Train: 100%|██████████| 4037/4037 [05:23<00:00, 12.49it/s]


Train_loss:  0.002275413541727235


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.90it/s]


Val_loss:  0.002130809673891947


Train: 100%|██████████| 4037/4037 [05:20<00:00, 12.59it/s]


Train_loss:  0.002033228439617794


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.81it/s]


Val_loss:  0.0020392779676972122


Train: 100%|██████████| 4037/4037 [05:25<00:00, 12.42it/s]


Train_loss:  0.0019468757130308523


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.67it/s]


Val_loss:  0.0020045223212478186


Train: 100%|██████████| 4037/4037 [05:22<00:00, 12.50it/s]


Train_loss:  0.0018991945314426174


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.92it/s]


Val_loss:  0.001971082122867355


Train: 100%|██████████| 4037/4037 [05:21<00:00, 12.56it/s]


Train_loss:  0.0018665116142879748


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.65it/s]


Val_loss:  0.001960762311234185


Train: 100%|██████████| 4037/4037 [05:27<00:00, 12.32it/s]


Train_loss:  0.0018426144804970697


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.72it/s]


Val_loss:  0.0019488552975984715


Train: 100%|██████████| 4037/4037 [05:27<00:00, 12.34it/s]


Train_loss:  0.0018235844522940697


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.66it/s]


Val_loss:  0.0019493716055444341


Train: 100%|██████████| 4037/4037 [05:38<00:00, 11.94it/s]


Train_loss:  0.0018080361979036087


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.19it/s]


Val_loss:  0.0019358634025476284


Train: 100%|██████████| 4037/4037 [05:37<00:00, 11.98it/s]


Train_loss:  0.0017951207437834193


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.39it/s]


Val_loss:  0.0019315171082006158


Train: 100%|██████████| 4037/4037 [05:30<00:00, 12.21it/s]


Train_loss:  0.001783271830407628


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.70it/s]


Val_loss:  0.0019372618563546992


Train: 100%|██████████| 4037/4037 [05:32<00:00, 12.13it/s]


Train_loss:  0.0017736104656364293


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.57it/s]


Val_loss:  0.0019267921841998427


Train: 100%|██████████| 4037/4037 [05:30<00:00, 12.22it/s]


Train_loss:  0.0017637065568235687


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.77it/s]


Val_loss:  0.0019318558860081478


Train: 100%|██████████| 4037/4037 [05:28<00:00, 12.28it/s]


Train_loss:  0.001755991995963043


Val: 100%|██████████| 1010/1010 [01:18<00:00, 12.82it/s]


Val_loss:  0.0019526270378408678


Train: 100%|██████████| 4037/4037 [05:25<00:00, 12.41it/s]


Train_loss:  0.0017048817593719675


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.71it/s]


Val_loss:  0.001916637251586976


Train: 100%|██████████| 4037/4037 [05:29<00:00, 12.27it/s]


Train_loss:  0.0016948463535270818


Val: 100%|██████████| 1010/1010 [01:19<00:00, 12.68it/s]


Val_loss:  0.0019128659386233897


Train: 100%|██████████| 4037/4037 [05:26<00:00, 12.36it/s]


Train_loss:  0.0016887382088747666


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.61it/s]


Val_loss:  0.0019191414236533685


Train: 100%|██████████| 4037/4037 [05:30<00:00, 12.22it/s]


Train_loss:  0.001683895523484094


Val: 100%|██████████| 1010/1010 [01:21<00:00, 12.42it/s]


Val_loss:  0.00191649391493382


Train: 100%|██████████| 4037/4037 [05:34<00:00, 12.07it/s]


Train_loss:  0.001656568500548758


Val: 100%|██████████| 1010/1010 [01:20<00:00, 12.47it/s]


Val_loss:  0.0019118736519957754


Train: 100%|██████████| 4037/4037 [05:39<00:00, 11.89it/s]


Train_loss:  0.001651831326513914


Val: 100%|██████████| 1010/1010 [01:22<00:00, 12.30it/s]

Val_loss:  0.0019227304311450755





In [None]:
single_conv_distill = SingleClusteringDistillModelConv(vocab_size = VOCAB_SIZE,
                              seq_len = MAX_TOKENS)


single_conv_distill.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_MSE/Conv_lr001/best-distill-bert.pt'))

distill_embedder = single_conv_distill.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, max_tokens_count=MAX_TOKENS)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)


Accuracy: 90.1
Positives Recall: 85.8
Positives Precision: 92.3
Positives F1: 88.9
Distance:  0.38
Max cluster size:  299
Median cluster size:  3
Avg cluster size: 5.26


In [None]:
single_conv_distill = SingleClusteringDistillModelConv(vocab_size = VOCAB_SIZE,
                              seq_len = MAX_TOKENS, conv_embedder_class=ConvEmbedder2)


train(single_conv_distill, train_loader, val_loader, 
      20, '/content/drive/MyDrive/NewsBert/Distillation_MSE/Conv2',
      )

  self.padding, self.dilation, self.groups)
Train: 100%|██████████| 4037/4037 [06:53<00:00,  9.77it/s]


Train_loss:  0.003739876204945658


Val: 100%|██████████| 1010/1010 [01:39<00:00, 10.15it/s]


Val_loss:  0.0036125130296712463


Train: 100%|██████████| 4037/4037 [06:54<00:00,  9.75it/s]


Train_loss:  0.0034203414319847626


Val: 100%|██████████| 1010/1010 [01:38<00:00, 10.29it/s]


Val_loss:  0.0032396036656665624


Train: 100%|██████████| 4037/4037 [06:48<00:00,  9.87it/s]


Train_loss:  0.003083359927355649


Val: 100%|██████████| 1010/1010 [01:37<00:00, 10.33it/s]


Val_loss:  0.002959948930990799


Train: 100%|██████████| 4037/4037 [06:53<00:00,  9.77it/s]


Train_loss:  0.002835582027577209


Val: 100%|██████████| 1010/1010 [01:37<00:00, 10.34it/s]


Val_loss:  0.00276719851622192


Train: 100%|██████████| 4037/4037 [06:55<00:00,  9.71it/s]


Train_loss:  0.00265974471226043


Val: 100%|██████████| 1010/1010 [01:40<00:00, 10.04it/s]


Val_loss:  0.0026337058781323456


Train: 100%|██████████| 4037/4037 [06:47<00:00,  9.91it/s]


Train_loss:  0.0025290754961797123


Val: 100%|██████████| 1010/1010 [01:37<00:00, 10.37it/s]


Val_loss:  0.0025192834591345474


Train: 100%|██████████| 4037/4037 [06:51<00:00,  9.82it/s]


Train_loss:  0.0024285607940453


Val: 100%|██████████| 1010/1010 [01:37<00:00, 10.38it/s]


Val_loss:  0.0024375438955289747


Train: 100%|██████████| 4037/4037 [06:49<00:00,  9.86it/s]


Train_loss:  0.002348213918829613


Val: 100%|██████████| 1010/1010 [01:36<00:00, 10.47it/s]


Val_loss:  0.0023786532674169186


Train: 100%|██████████| 4037/4037 [06:50<00:00,  9.83it/s]


Train_loss:  0.0022821219886911734


Val: 100%|██████████| 1010/1010 [01:33<00:00, 10.77it/s]


Val_loss:  0.0023260890591502337


Train: 100%|██████████| 4037/4037 [06:42<00:00, 10.04it/s]


Train_loss:  0.002226852228498681


Val: 100%|██████████| 1010/1010 [01:37<00:00, 10.32it/s]


Val_loss:  0.002288007907255894


Train: 100%|██████████| 4037/4037 [06:48<00:00,  9.89it/s]


Train_loss:  0.0021797833936660375


Val: 100%|██████████| 1010/1010 [01:38<00:00, 10.23it/s]


Val_loss:  0.0022500541146457343


Train: 100%|██████████| 4037/4037 [06:56<00:00,  9.69it/s]


Train_loss:  0.0021388860273938316


Val: 100%|██████████| 1010/1010 [01:37<00:00, 10.36it/s]


Val_loss:  0.0022128777396358034


Train: 100%|██████████| 4037/4037 [06:58<00:00,  9.65it/s]


Train_loss:  0.0021032318724610016


Val: 100%|██████████| 1010/1010 [01:35<00:00, 10.61it/s]


Val_loss:  0.0021885356879654794


Train: 100%|██████████| 4037/4037 [06:39<00:00, 10.11it/s]


Train_loss:  0.002071207598967633


Val: 100%|██████████| 1010/1010 [01:34<00:00, 10.64it/s]


Val_loss:  0.002164961500923232


Train: 100%|██████████| 4037/4037 [06:41<00:00, 10.05it/s]


Train_loss:  0.0020429286945176143


Val: 100%|██████████| 1010/1010 [01:34<00:00, 10.64it/s]


Val_loss:  0.002149918487775001


Train: 100%|██████████| 4037/4037 [06:39<00:00, 10.10it/s]


Train_loss:  0.0020170167950391473


Val: 100%|██████████| 1010/1010 [01:34<00:00, 10.68it/s]


Val_loss:  0.00212332234985315


Train: 100%|██████████| 4037/4037 [06:40<00:00, 10.09it/s]


Train_loss:  0.0019935269637682985


Val: 100%|██████████| 1010/1010 [01:35<00:00, 10.53it/s]


Val_loss:  0.0021182876804124308


Train: 100%|██████████| 4037/4037 [06:48<00:00,  9.89it/s]


Train_loss:  0.0019725079673631007


Val: 100%|██████████| 1010/1010 [01:36<00:00, 10.43it/s]


Val_loss:  0.002103019980621515


Train: 100%|██████████| 4037/4037 [06:50<00:00,  9.83it/s]


Train_loss:  0.0019523227589476207


Val: 100%|██████████| 1010/1010 [01:37<00:00, 10.36it/s]


Val_loss:  0.0020802881504123163


Train: 100%|██████████| 4037/4037 [06:47<00:00,  9.91it/s]


Train_loss:  0.0019341499976602014


Val: 100%|██████████| 1010/1010 [01:34<00:00, 10.71it/s]


Val_loss:  0.002064928708547302


In [None]:
single_conv_distill = SingleClusteringDistillModelConv(vocab_size = VOCAB_SIZE,
                              seq_len = MAX_TOKENS, conv_embedder_class=ConvEmbedder2)


single_conv_distill.load_state_dict(
    torch.load('/content/drive/MyDrive/NewsBert/Distillation_MSE/Conv2/best-distill-bert.pt'))

distill_embedder = single_conv_distill.embedder.cuda()
public_distill_embeddings = records_to_embeds(public_set, distill_embedder, tokenizer, batch_size=8, max_tokens_count=MAX_TOKENS)
get_quality(public_markup, public_distill_embeddings, public_set, 0.38, True)


Accuracy: 86.7
Positives Recall: 78.5
Positives Precision: 91.5
Positives F1: 84.5
Distance:  0.38
Max cluster size:  283
Median cluster size:  2
Avg cluster size: 4.05
