In [10]:
import os
import sys
from pathlib import Path

# Project root (assumed to be the parent of the directory the notebook runs in)
# and the directory where checkpoints are written.
PROJ_DIR = Path(os.getcwd()).parent
CKPT_DIR = PROJ_DIR / "checkpoints"

# Make the project's `src` package importable. Use the absolute project root
# instead of "..": a relative entry is re-resolved against the cwd on every
# import. Also skip the append on re-runs so sys.path does not accumulate duplicates.
if str(PROJ_DIR) not in sys.path:
    sys.path.append(str(PROJ_DIR))

import math

import torch
from torch.utils.data import DataLoader

from sentence_transformers import SentenceTransformer, SentencesDataset
from sentence_transformers.readers import InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.models import Transformer, Pooling
from sentence_transformers.losses import SoftmaxLoss

from datasets import load_dataset, get_dataset_config_names

from src.models import SIFPooling, BatchPCRemoval, MeanPooling


# Load WikiText-2 (every split). Its lines are used to estimate the word
# frequencies consumed by SIFPooling.from_corpus_hf in the next cell.
# NOTE(review): "wikitext-2-v1" replaces out-of-vocabulary words with "<unk>",
# which a subword tokenizer will split into pieces; "wikitext-2-raw-v1" may give
# cleaner frequency estimates — confirm which config is intended.
wikitext_splits = load_dataset("wikitext", "wikitext-2-v1")

# One string per non-empty line, with section headings (lines starting with "=")
# dropped. Each line is stripped exactly once. The DatasetDict keeps its own name
# rather than being overwritten by this list.
corpus = [
    line
    for line in (raw.strip() for split in wikitext_splits.values() for raw in split["text"])
    if line and not line.startswith("=")
]

Found cached dataset wikitext (/home/dogdog/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

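## Create Raw Models (Untrained)

Build four sentence encoders that share one transformer backbone and differ only in the layers stacked on top of it:
- mean pooling (`-avg`);
- SIF-weighted pooling (`-sif`);
- each of those followed by a `BatchPCRemoval` layer (`-arm`, `-srm`).

All four are saved, untrained, under `checkpoints/untrained/`.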

In [11]:
# Hugging Face model id of the backbone (alternative backbone: "bert-base-uncased").
model_card = "nreimers/MiniLM-L6-H384-uncased"

# Save names for the four variants, formed as "<backbone>-<suffix>".
model_card_avg, model_card_sif, model_card_arm, model_card_srm = (
    f"{model_card}-{suffix}" for suffix in ("avg", "sif", "arm", "srm")
)

# Transformer backbone shared by every variant.
# NOTE(review): sentence-transformers' Transformer.tokenize appears to apply
# truncation itself; confirm whether "truncation" in tokenizer_args has any effect.
embedding_layer = Transformer(model_card, tokenizer_args={"use_fast": True, "truncation": True})

# Every layer below is sized from the backbone's token-embedding dimension.
# Query it once so all of them are guaranteed to agree.
word_embedding_dim = embedding_layer.get_word_embedding_dimension()

# Pooling heads: plain mean pooling vs. SIF-weighted pooling. The SIF weights are
# estimated from `corpus`. `model_card` is passed as well, presumably so the corpus
# is tokenized like the backbone's input — confirm in src.models.
normal_pooling_layer   = MeanPooling(word_embedding_dim)
weighted_pooling_layer = SIFPooling.from_corpus_hf(word_embedding_dim, model_card, corpus)

# Post-processing layer that removes the top principal component (n_components=1).
# Per its name it presumably estimates the component per batch — confirm in src.models.
batch_pc_removal_layer = BatchPCRemoval(word_embedding_dim, n_components=1)

In [12]:
# Save every untrained variant to CKPT_DIR / "untrained" / <variant name>.
# The variant names contain "/" (from the Hub id), so each lands in a nested
# directory. Each entry lists one variant's module stack. All entries reuse the
# same backbone and pooling layer instances built in the previous cell.
untrained_dir = CKPT_DIR / "untrained"
untrained_variants = {
    model_card_avg: [embedding_layer, normal_pooling_layer],
    model_card_sif: [embedding_layer, weighted_pooling_layer],
    model_card_arm: [embedding_layer, normal_pooling_layer, batch_pc_removal_layer],
    model_card_srm: [embedding_layer, weighted_pooling_layer, batch_pc_removal_layer],
}
for variant_name, modules in untrained_variants.items():
    SentenceTransformer(modules=modules).save(str(untrained_dir / variant_name))