In [1]:
import sys
sys.path.append("..")

import os
from pathlib import Path
PROJ_DIR = Path(os.getcwd()).parent
CKPT_DIR = PROJ_DIR / "checkpoints"

import torch
from torch.utils.data import DataLoader

from sentence_transformers import SentenceTransformer, SentencesDataset
from sentence_transformers.readers import InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.models import Transformer, Pooling
from sentence_transformers.losses import SoftmaxLoss

from datasets import load_dataset, get_dataset_config_names

from src.models import SIFPooling, BatchPCRemoval, MeanPooling


# Load wiki-text-2 to estimate the word frequencies
corpus = load_dataset("wikitext", "wikitext-2-v1")
corpus = [s.strip() for p in corpus.values() for s in p["text"]
            if s.strip() != "" and not s.strip().startswith("=")]

Found cached dataset wikitext (/home/dogdog/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

## __Create Raw Models (Untrained)__

In [2]:
model_card = "bert-base-uncased"
model_card_avg = f"{model_card}-avg"
model_card_sif = f"{model_card}-sif"
model_card_arm = f"{model_card}-arm"
model_card_srm = f"{model_card}-srm"

embedding_layer        = Transformer(model_card)
normal_pooling_layer   = MeanPooling()
weighted_pooling_layer = SIFPooling.from_corpus_hf(model_card, corpus)
batch_pc_removal_layer = BatchPCRemoval(n_components=1)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
SentenceTransformer(modules=[embedding_layer, normal_pooling_layer]).save(str(CKPT_DIR / "untrained" / model_card_avg))
SentenceTransformer(modules=[embedding_layer, weighted_pooling_layer]).save(str(CKPT_DIR / "untrained" / model_card_sif))
SentenceTransformer(modules=[embedding_layer, normal_pooling_layer, batch_pc_removal_layer]).save(str(CKPT_DIR / "untrained" / model_card_arm))
SentenceTransformer(modules=[embedding_layer, weighted_pooling_layer, batch_pc_removal_layer]).save(str(CKPT_DIR / "untrained" / model_card_srm))

## __Fine-tuning on AllNLI__

In [4]:
dataset_snli = load_dataset("snli")
train_snli = [InputExample(texts=[s["premise"], s["hypothesis"]], label=s["label"]) 
                for s in dataset_snli["train"].select(range(8192 * 4)) if s["label"] != -1]
train_snli_dataloader = DataLoader(train_snli, shuffle=True, batch_size=16)
dev_snli = [InputExample(texts=[s["premise"], s["hypothesis"]], label=s["label"])
                for s in dataset_snli["validation"] if s["label"] != -1]
dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_snli, 
                                                                 write_csv=False,
                                                                 batch_size=16,
                                                                 name="snli")

Found cached dataset snli (/home/dogdog/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
import math


warmup_steps = math.ceil(len(train_snli_dataloader) * 1 * 0.1) #10% of train data for warm-up
train_snli_loss = SoftmaxLoss(model=model_b, 
                              sentence_embedding_dimension=768, 
                              num_labels=3)
model_b.fit(train_objectives=[(train_snli_dataloader, train_snli_loss)],
            evaluator=dev_evaluator,
            epochs=1,
            evaluation_steps=256,
            warmup_steps=warmup_steps,
            output_path="../checkpoints/snli/b2")

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/2046 [00:00<?, ?it/s]

In [6]:
model_b = SentenceTransformer("../models/snli-b")
eval_sick(model_b, output_path=".")

0.5321638583444072

In [7]:
model_b2 = SentenceTransformer("../models/snli-b2")
eval_sick(model_b2, output_path=".")

0.6006418442647867

In [8]:
eval_stsb(model_b, output_path=".")

0.5666246597282747

In [9]:
eval_stsb(model_b2, output_path=".")

0.5730377609762444

In [15]:
model_w = SentenceTransformer("../models/snli-w")
eval_sick(model_w, output_path=".")

0.5301054550908639

In [16]:
eval_stsb(model_w, output_path=".")

0.5476615084109432

In [2]:
model_r = SentenceTransformer("../models/snli-r")
eval_sick(model_r, output_path=".")

0.45502135699433416

In [21]:
dev_evaluator(model_b)

-0.07476983210370843