In [1]:
import sentence_transformers
from beir import util, LoggingHandler
from beir.retrieval import models as beir_models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from sentence_transformers import models, losses, datasets
from torch.utils.data import Dataset,DataLoader
from typing import List
from sentence_transformers.readers import InputExample
import numpy as np
from transformers.utils.import_utils import is_nltk_available, NLTK_IMPORT_ERROR
from nltk import word_tokenize, TreebankWordDetokenizer
import nltk
nltk.download('punkt')
from torch import nn, Tensor
from typing import Iterable, Dict
from sentence_transformers import SentenceTransformer
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedModel,AutoModel
import logging
import os
import random
logger = logging.getLogger(__name__)



  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to /home/cgrdj/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:

model_name = "huawei-noah/TinyBERT_General_4L_312D"
tokenizer = AutoTokenizer.from_pretrained(model_name,do_lower_case=False,clean_up_tokenization_spaces=False,clean_text=False)


In [3]:
dataset = "scifact"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join(os.getcwd(), "datasets")
data_path = util.download_and_unzip(url, out_dir)
print("Dataset downloaded here: {}".format(data_path))
corpus, queries, qrels = GenericDataLoader(data_path).load(split="train") # or split = "train" or "dev"
unsupervised_train_data =  list(queries.values())#+[data['title']+' \n '+data['text'] for data in list(corpus.values())]
random.Random(0).shuffle(unsupervised_train_data)


Dataset downloaded here: /home/cgrdj/Documents/code/repos/sentence-transformers/datasets/scifact


  0%|          | 0/5183 [00:00<?, ?it/s]

100%|██████████| 5183/5183 [00:00<00:00, 85672.59it/s]


In [4]:

class MaskedAutoEncoderDataset(Dataset):
    """
    The DenoisingAutoEncoderDataset returns InputExamples in the format: texts=[noise_fn(sentence), sentence]
    It is used in combination with the DenoisingAutoEncoderLoss: Here, a decoder tries to re-construct the
    sentence without noise.

    :param sentences: A list of sentences
    :param noise_fn: A noise function: Given a string, it returns a string with noise, e.g. deleted words
    """

    def __init__(self, sentences: List[str],tokenizer ):
        if not is_nltk_available():
            raise ImportError(NLTK_IMPORT_ERROR.format(self.__class__.__name__))

        self.sentences = sentences
        self.tokenizer=tokenizer

    def __getitem__(self, item):
        sent = self.sentences[item]
        return InputExample(texts=[self.noisen(sent,MASK_ratio=0.15), self.noisen(sent,MASK_ratio=0.4),sent])

    def __len__(self):
        return len(self.sentences)

    # Masking noise.
    def noisen(self,text, MASK_ratio=0.15):
        mask_id=self.tokenizer.mask_token_id
        words= text.split()#word_tokenize(text)
        # Apply the masking logic to each word and rejoin the sentence
        splitted_tokens = self.tokenizer.batch_encode_plus(words,return_attention_mask=False,return_token_type_ids=False,add_special_tokens=False)['input_ids']#encode each tokens in each
        masked_tokens =[[ mask_id if np.random.rand() < MASK_ratio else tok_id for tok_id in token]  for token in splitted_tokens]
        masked_sentence=' '.join([self.tokenizer.decode(masked_token).replace(" ",'') for masked_token in masked_tokens])
        return masked_sentence

In [8]:

word_embedding_model = models.Transformer(model_name, max_seq_length=512)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), "cls")
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

In [33]:

class MaskedAutoEncoderLoss(nn.Module):
    def __init__(self, model: SentenceTransformer, decoder_name_or_path: str = None):
        """
        This loss expects as input a pairs of damaged sentences and the corresponding original ones.
        During training, the decoder reconstructs the original sentences from the encoded sentence embeddings.
        Here the argument 'decoder_name_or_path' indicates the pretrained model (supported by Hugging Face) to be used as the decoder.
        Since decoding process is included, here the decoder should have a class called XXXLMHead (in the context of Hugging Face's Transformers).
        The 'tie_encoder_decoder' flag indicates whether to tie the trainable parameters of encoder and decoder,
        which is shown beneficial to model performance while limiting the amount of required memory.
        Only when the encoder and decoder are from the same architecture, can the flag 'tie_encoder_decoder' work.

        The data generation process (i.e. the 'damaging' process) has already been implemented in ``DenoisingAutoEncoderDataset``,
        allowing you to only provide regular sentences.

        :param model: SentenceTransformer model
        :param decoder_name_or_path: Model name or path for initializing a decoder (compatible with Huggingface's Transformers)
        :param tie_encoder_decoder: whether to tie the trainable parameters of encoder and decoder

        References:
            * TSDAE paper: https://arxiv.org/pdf/2104.06979.pdf
            * `Unsupervised Learning > TSDAE <../../examples/unsupervised_learning/TSDAE/README.html>`_

        Requirements:
            1. The decoder should have a class called XXXLMHead (in the context of Hugging Face's Transformers)
            2. Should use a large corpus

        Inputs:
            +------------------------------------------------------+--------+
            | Texts                                                | Labels |
            +======================================================+========+
            | (damaged\_sentence, original\_sentence) pairs        | none   |
            +------------------------------------------------------+--------+
            | sentence fed through ``DenoisingAutoEncoderDataset`` | none   |
            +------------------------------------------------------+--------+

        Example:
            ::

                from sentence_transformers import SentenceTransformer, losses
                from sentence_transformers.datasets import DenoisingAutoEncoderDataset
                from torch.utils.data import DataLoader

                model_name = "bert-base-cased"
                model = SentenceTransformer(model_name)
                train_sentences = [
                    "First training sentence", "Second training sentence", "Third training sentence", "Fourth training sentence",
                ]
                batch_size = 2
                train_dataset = DenoisingAutoEncoderDataset(train_sentences)
                train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
                train_loss = losses.DenoisingAutoEncoderLoss(
                    model, decoder_name_or_path=model_name, tie_encoder_decoder=True
                )
                model.fit(
                    train_objectives=[(train_dataloader, train_loss)],
                    epochs=10,
                )
        """
        super(MaskedAutoEncoderLoss, self).__init__()
        self.encoder = model  # This will be the final model used during the inference time.
        self.tokenizer_encoder = model.tokenizer

        name_or_path = model[0].auto_model.config._name_or_path

        self.tokenizer_decoder = AutoTokenizer.from_pretrained(name_or_path)

        decoder_config = AutoConfig.from_pretrained(name_or_path)
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True
        decoder_config.num_hidden_layers=1
        kwargs_decoder = {"config": decoder_config}
        try:
            self.decoder = AutoModelForCausalLM.from_pretrained(name_or_path, **kwargs_decoder)
        except ValueError as e:
            logger.error(
                f'Model name or path "{name_or_path}" does not support being as a decoder. Please make sure the decoder model has an "XXXLMHead" class.'
            )
            raise e
        if self.tokenizer_decoder.pad_token is None:
            # Needed by GPT-2, etc.
            self.tokenizer_decoder.pad_token = self.tokenizer_decoder.eos_token
            self.decoder.config.pad_token_id = self.decoder.config.eos_token_id

        if len(AutoTokenizer.from_pretrained(name_or_path)) != len(self.tokenizer_encoder):
            logger.warning(
                "WARNING: The vocabulary of the encoder has been changed. One might need to change the decoder vocabulary, too."
            )

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        source_features, target_features= tuple(sentence_features)
        reps = self.encoder(source_features)["sentence_embedding"]  # (bsz, hdim)
        # print(source_features)
        # Prepare input and output
        target_length = target_features["input_ids"].shape[1]
        decoder_input_ids = target_features["input_ids"].clone()[:, : target_length - 1]
        print( target_features["input_ids"],decoder_input_ids)
        # print(decoder_input_ids.shape)
        label_ids = target_features["input_ids"][:, 1:]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            inputs_embeds=None,
            attention_mask=None,
            encoder_hidden_states=reps[:, None],  # (bsz, hdim) -> (bsz, 1, hdim)
            encoder_attention_mask=source_features["attention_mask"][:, 0:1],
            labels=None,
            return_dict=None,
            use_cache=False,
        )

        # Calculate loss
        lm_logits = decoder_outputs[0]
        ce_loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer_decoder.pad_token_id)
        loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), label_ids.reshape(-1))
        return loss


    def forward_(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        encoder_inputs,decoder_inputs , target_features = tuple(sentence_features)
        reps = self.encoder(encoder_inputs)["sentence_embedding"]  # (bsz, hdim)

        # Prepare input and output
        target_length = target_features["input_ids"].shape[1]
        decoder_input_ids = target_features["input_ids"].clone()[:, : target_length - 1]
        label_ids = target_features["input_ids"][:, 1:]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            inputs_embeds=None,
            attention_mask=None,
            encoder_hidden_states=reps[:, None],  # (bsz, hdim) -> (bsz, 1, hdim)
            encoder_attention_mask=encoder_inputs["attention_mask"][:, 0:1],
            labels=None,
            return_dict=None,
            use_cache=False,
        )

        # Calculate loss
        lm_logits = decoder_outputs[0]
        ce_loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer_decoder.pad_token_id)
        loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), label_ids.reshape(-1))
        return loss


In [6]:

decoder_config = AutoConfig.from_pretrained(model_name)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
decoder_config.num_hidden_layers=1
kwargs_decoder = {"config": decoder_config}
decoder = AutoModelForCausalLM.from_pretrained(model_name, **kwargs_decoder)

Some weights of BertLMHeadModel were not initialized from the model checkpoint at huawei-noah/TinyBERT_General_4L_312D and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [85]:
text=['I m going to','We provided with']
decoder_input_ids=model.tokenizer(text,padding=True,truncation=True)
reps=model.encode(text)

torch.Size([2, 1, 312])

In [76]:
torch.tensor(decoder_input_ids['input_ids']).clone()[:, :5]

tensor([[ 101, 1045, 1049, 2183, 2000],
        [ 101, 2057, 3024, 2007,  102]])

In [101]:
decoder_outputs = decoder(
    input_ids=torch.tensor(decoder_input_ids['input_ids']).clone()[:, :5],
    inputs_embeds=None,
    attention_mask=None,
    encoder_hidden_states=torch.tensor(reps).unsqueeze(1),  # (bsz, hdim) -> (bsz, 1, hdim)
    # encoder_attention_mask=[1,1],
    labels=None,
    return_dict=None,
    use_cache=False,
)

In [104]:

# Apply softmax to convert logits to probabilities
probabilities = torch.softmax(decoder_outputs.logits, dim=-1)

# Choose the tokens with the highest probability
predicted_token_ids = torch.argmax(probabilities, dim=-1)

# Convert token IDs to tokens
predicted_tokens = [model.tokenizer.decode(generated_ids, skip_special_tokens=True) for generated_ids in predicted_token_ids]

# Print out the generated texts
for text in predicted_tokens:
    print(text)

ministers if lavish enable
needed draws losseit


In [24]:
train_dataset=MaskedAutoEncoderDataset(unsupervised_train_data,tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True)
train_loss = MaskedAutoEncoderLoss(model, decoder_name_or_path=model_name)


Some weights of BertLMHeadModel were not initialized from the model checkpoint at huawei-noah/TinyBERT_General_4L_312D and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [34]:
train_dataset=datasets.DenoisingAutoEncoderDataset(unsupervised_train_data)#,tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True)
train_loss = MaskedAutoEncoderLoss(model)


Some weights of BertLMHeadModel were not initialized from the model checkpoint at huawei-noah/TinyBERT_General_4L_312D and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [40]:
tokenizer.decode([  101, 17207,  2102,  2445,  2007, 19395, 18891,  6657,  7457,  2019,
         17577,  1012,   102,     0,     0,     0,     0,     0,     0,     0,
             0,     0])

'[CLS] azt given with ribavirin increases anemia. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [41]:
tokenizer.decode([  101, 17207,  2102,  2445,  2007, 19395, 18891,  6657,  7457,  2019,
         17577,  1012,   102,     0,     0,     0,     0,     0,     0,     0,
             0])

'[CLS] azt given with ribavirin increases anemia. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [35]:
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    # evaluator=dev_evaluator,
    epochs=1,
    # evaluation_steps=evaluation_steps,
    # output_path=model_save_path,
    weight_decay=0,
    warmup_steps=100,
    optimizer_params={"lr": 3e-5},
    use_amp=True,  # Set to True, if your GPU supports FP16 cores
)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

tensor([[  101, 17207,  2102,  2445,  2007, 19395, 18891,  6657,  7457,  2019,
         17577,  1012,   102,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  101,  1996,  4315, 13910,  8898,  1998, 15330, 13791,  1997, 18847,
         27321,  2038,  3972, 15141,  6313,  3896,  1999, 11888, 16514,  3785,
          1012,   102],
        [  101, 19802,  6190,  3141, 13356,  2038, 13763,  2013,  2268,  2000,
          2297,  1012,   102,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  101,  3445, 19251,  1997, 12702, 21102,  3688, 16081,  2229, 11311,
         10960,  1012,   102,     0,     0,     0,     0,     0,     0,     0,
             0,     0]]) tensor([[  101, 17207,  2102,  2445,  2007, 19395, 18891,  6657,  7457,  2019,
         17577,  1012,   102,     0,     0,     0,     0,     0,     0,     0,
             0],
        [  101,  1996,  4315, 13910,  8898,  1998, 15330, 13791,  1997, 18847,
         27321,  



tensor([[  101,  1996, 26236,  2615,  1011,  1016,  8985,  2003,  4050,  2004,
         24335, 13876,  9626,  4588,  1012,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  7688, 11661,  3417,  2083, 18780,  1006, 18133,  2102,  1007,
          2003, 13567,  3141,  2000, 26957,  5657,  7865,  1006,  7939,  2615,
          1011,  1015,  1007, 19241,  1999,  2250,  4026, 24636,  1012,   102],
        [  101,  4125, 19440,  3686,  7457,  3891,  1997,  2310, 19731, 10024,
          2140,  1998,  2512,  1011,  2310, 19731, 10024,  2140, 28929,  1012,
           102,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  1998, 22991, 16530,  5292, 24759,  9314,  8000, 28086,  8713,
          7872,  4442,  1006,  9686,  6169,  1007,  2064,  2022,  5173,  1998,
         19345, 20063,  1999, 25714,  1012,   102,     0,     0,     0,     0]]) tensor([[  101,  1996, 26236,  2615,  1011,  10



tensor([[  101,  1052, 12541, 13820,  9153,  7629, 16171, 20250,  1997, 24004,
         21197,  3560, 28667,  5358, 21114,  3508,  1011, 28829,  4442,  1012,
           102,     0,     0,     0],
        [  101, 27312,  8606,  6022, 13416,  1996,  2193,  1997, 12201, 23851,
          4740,  4072,  2005,  1037,  2445,  7709,  1012,   102,     0,     0,
             0,     0,     0,     0],
        [  101,  3151,  4391,  2024, 25352,  1999,  2037, 15931,  1012,   102,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
        [  101,  1996,  4304,  1997, 22330, 18715,  3170, 10769,  7682,  4442,
          2038,  2053,  3466,  2006,  1996,  3292,  2058,  2029, 22330, 18715,
         10586,  2552,  1012,   102]]) tensor([[  101,  1052, 12541, 13820,  9153,  7629, 16171, 20250,  1997, 24004,
         21197,  3560, 28667,  5358, 21114,  3508,  1011, 28829,  4442,  1012,
           102,     0,     0],
        [  101, 27312,  86



tensor([[  101, 10210, 11663, 15422,  4360,  2377,  1037,  2350,  2535,  1999,
          2943,  2537,  1012,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [  101, 22575, 10047, 23041,  2891,  6279, 27484,  1006,  2003,  1007,
          7242,  7457,  1996,  3382,  1997,  4456, 13356,  1999,  5022,  2007,
         20187,  3239,  4295,  1006, 29464,  2094,  1007,   102],
        [  101,  2012,  2560,  5594,  1003,  1997,  5022,  6086,  2000,  8249,
          2031,  8878, 16387,  1997,  2026, 11253, 12322,  3217, 28522, 12837,
          1012,   102,     0,     0,     0,     0,     0,     0],
        [  101,  2028,  1999,  2274, 11707,  6721,  3550,  4758,  7012,  2024,
          8944,  2220,  1012,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0]]) tensor([[  101, 10210, 11663, 15422,  4360,  2377,  1037,  2350,  2535,  1999,
          2943,  2537,  



tensor([[  101,  8208,  1997,  1044,  2509,  2243,  2683,  4168,  2509,  2011,
         14925, 14399,  2594,  3670,  1997,  2060,  1044,  2509,  2243,  2683,
         17183, 11031, 23943,  8583, 17913, 16360,  3217, 13113,  6562,  8122,
          1999,  8040,  3372,  7885,  1012,   102],
        [  101, 11721,  2696,  2509, 26773,  2969,  1011, 14524,  3977,  1999,
          5923, 24960, 19610, 10610,  6873,  2666,  4588,  7872,  4442,  1012,
           102,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [  101,  6583,  2278,  4078,  2696, 14454, 10057,  2053,  2000,  3623,
          1996,  3466,  1997,  5688,  6074,  2006, 25125,  4972,  1012,   102,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [  101,  2753,  1003,  1997,  5022,  6086,  2000,  8249,  2031,  8878,
         16387,  1997,  2033,  5054, 11714,  9067,  78



tensor([[  101, 24216,  2015,  2090,  7435,  1998,  5022,  2064,  2599,  2000,
          2512,  1011, 29235,  1012,   102,     0,     0,     0,     0,     0],
        [  101,  1996,  4013, 15509, 18514,  3977,  1997,  4013,  6914, 27287,
          2003, 12222,  3526,  1011,  8392,  2135,  1012,   102,     0,     0],
        [  101,  3949,  2007,  1037,  5250,  2315,  1042,  2078,  9239,  2015,
         19723, 24454,  8082,  7590,  1997,  4793,  6650,  1012,   102,     0],
        [  101,  2943,  5703,  5942,  1044, 22571, 14573,  7911,  7712,  1043,
          7630, 28282,  2618, 11265, 10976,  6494,  3619, 25481,  1012,   102]]) tensor([[  101, 24216,  2015,  2090,  7435,  1998,  5022,  2064,  2599,  2000,
          2512,  1011, 29235,  1012,   102,     0,     0,     0,     0],
        [  101,  1996,  4013, 15509, 18514,  3977,  1997,  4013,  6914, 27287,
          2003, 12222,  3526,  1011,  8392,  2135,  1012,   102,     0],
        [  101,  3949,  2007,  1037,  5250,  2315,  1042,  



tensor([[  101, 19802,  6190,  3141, 13356,  2038,  2815,  6540,  2090,  2268,
          1011,  2297,  1012,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  2822,  3633,  2007, 23746, 24004,  9096, 12333,  3012,  1999,
          1996, 11047,  2232, 19699,  4962,  2024,  2625,  8211,  2000, 13692,
          3303,  2011,  2659,  3798,  1997,  1042, 19425, 13822,  1012,   102],
        [  101,  2146,  4677, 26572,  4609, 16846,  4648,  3064, 19101, 12737,
         12448,  3370, 13416,  1059, 21030,  6774,  1998, 26180,  1012,   102,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101, 10507,  2140, 16147,  2003,  1037, 27854,  2005, 10507,  2099,
          2581,  1012,   102,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]]) tensor([[  101, 19802,  6190,  3141, 13356,  20



tensor([[  101,  3445, 18433,  1997, 22330, 14399,  8523,  7712, 24972,  7275,
          2013,  2058, 10288, 20110,  5668,  2003,  5393,  2011,  6428,  7516,
          2005, 18168, 14376,  1999, 14134, 24869,  1011,  3931, 21500,  2015,
          1012,   102],
        [  101, 29486,  2953,  4442,  2024,  2028,  3114,  2005, 12958, 10960,
          2000,  5939,  7352,  3170, 21903, 24054,  1006,  1056,  3211,  1007,
          7242,  1999,  4456,  5022,  1012,   102,     0,     0,     0,     0,
             0,     0],
        [  101,  3465, 12353,  9312,  2015,  2241,  2006, 13675,  6593,  2951,
         14125,  8339,  7597,  2005,  5022,  1999,  5025,  6612,  3218,  1012,
           102,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  101, 22466,  4017,  7277,  5970,  5260,  2000,  3893, 13105,  1999,
          5177,  2740,  1012,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,  



tensor([[  101,  3078, 28711,  4456, 11326,  2007,  6522,  2615, 10788,  2038,
          2896, 20134, 14639,  2084,  7511, 22330, 23479,  2000, 11487, 28711,
         26721, 13699,  8939, 24587,  9253, 24759, 15396,  3694,  1016,  1012,
           102],
        [  101, 20199,  1011,  3772,  1048, 12273, 12789,  2015,  2491,  1996,
          3670,  1997,  9165,  2008,  2024, 10959,  1999,  1996,  9884,  1997,
          2037, 14193,  4573,  1012,   102,     0,     0,     0,     0,     0,
             0],
        [  101,  5915,  6887,  2891,  8458,  4140,  5521, 22747,  2121,  6165,
          2024, 23900,  2007,  2010,  3775, 10672, 21903, 21618,  3563,  3012,
          1012,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0],
        [  101,  5622,  3401,  2012,  6528, 16453,  9033,  2615, 28896, 19653,
          1037,  6428, 28873,  1011,  3563,  1056,  3526,  3433,  1999,  1048,
         24335,  8458, 13045,  4442,  1012,   102,     0,     0,     0,     0,
 



tensor([[  101,  1055,  2078,  2003,  2556,  2006,  9677,  5887,  2015,  2076,
         21733,  1999, 24269,  1012,   102,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [  101,  1996,  2088,  2740,  3029,  1005,  1055,  1006,  2040,  1007,
          2951,  3074,  2832,  2003, 25352, 14047,  2011, 16655, 26426,  4989,
          1997,  3469,  8293,  2015,  1012,   102],
        [  101,  1047, 10270,  2549,  2003,  2590,  2005,  5372,  2026, 18349,
          3593,  3526, 20582,  1012,   102,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [  101,  1050, 22022,  2620,  2072, 14494,  3426,  5012,  2000, 11265,
         24093, 19265,  1012,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0]]) tensor([[  101,  1055,  2078,  2003,  2556,  2006,  9677,  5887,  2015,  2076,
         21733,  1999, 24269,  1012,   102,     0,     0,     0,     0,     0,
 

Iteration:   5%|▌         | 11/202 [00:02<00:50,  3.75it/s]
Epoch:   0%|          | 0/1 [00:02<?, ?it/s]


tensor([[  101,  2659,  3670,  1997, 14719,  2581,  2050,  2515,  2025, 16360,
          8303,  4539,  9165,  2030,  4654,  8743,  1037,  6897,  3853,  1999,
          3231,  2483,  1012,   102,     0,     0,     0,     0],
        [  101,  4761,  7977,  4968,  4031,  1006, 14230,  1007,  2003, 13567,
          3141,  2000, 26957,  5657,  7865,  1006,  7939,  2615,  1011,  1015,
          1007, 19241,  1999,  2250,  4026, 24636,  1012,   102],
        [  101,  2045,  2003,  2053,  2124,  8290,  2090,  7156,  5387, 13323,
          2509,  1013,  1018,  1998,  2350, 10381, 21716, 20363,  2128,  5302,
          9247,  2075,  5876,  1012,   102,     0,     0,     0],
        [  101,  8319,  3526,  1011,  2489, 23079,  6064,  3798,  2024,  3378,
          2007, 13356,  1012,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0]]) tensor([[  101,  2659,  3670,  1997, 14719,  2581,  2050,  2515,  2025, 16360,
          8303,  4539,  

KeyboardInterrupt: 

In [94]:
next(iter(MaskedAutoEncoderDataset(["asdasdf adsf sda fads 'asd gasdg aasdfasdf /dsaf sadfsd,dasfadsf. fadsf.fads,fadsfa"],tokenizer))).texts

['[MASK]dasdf ads[MASK] sda [MASK]ds [MASK]asd gasdg aasdfa[MASK]f /dsaf sadfsd,dasfa[MASK]f. [MASK]ds[MASK].fads,fadsfa',
 '[MASK]dasdf [MASK][MASK] sda fads [MASK]asd gas[MASK]g [MASK]sd[MASK][MASK][MASK] [MASK][MASK]af [MASK][MASK][MASK][MASK]dasfadsf[MASK] [MASK][MASK][MASK].fads[MASK]fads[MASK]',
 "asdasdf adsf sda fads 'asd gasdg aasdfasdf /dsaf sadfsd,dasfadsf. fadsf.fads,fadsfa"]

In [77]:
sentence="asdasdf adsf sda fads 'asd gasdg aasdfasdf /dsaf sadfsd,dasfadsf. fadsf.fads,fadsfa"
tokenizer.decode(tokenizer.encode_plus(sentence,return_attention_mask=False,return_token_type_ids=False,add_special_tokens=False)['input_ids'])



"asdasdf adsf sda fads ' asd gasdg aasdfasdf / dsaf sadfsd, dasfadsf. fadsf. fads, fadsfa"

In [5]:
tokenizer.decode(tokenizer.encode('as [MASK]df',add_special_tokens=False))



'as [MASK] df'

In [6]:

sentences=["asdasdf adsf sda fads 'asd gasdg aasdfasdf /dsaf sadfsd,dasfadsf. fadsf.fads,fadsfa"]*10000
masked_sentences=[]
for sentence in sentences:
    words= word_tokenize(sentence)
    # Apply the masking logic to each word and rejoin the sentence
    splitted_tokens = tokenizer.batch_encode_plus(words,return_attention_mask=False,return_token_type_ids=False,add_special_tokens=False)['input_ids']
    masked_sentence=' '.join([tokenizer.decode([ mask_id if np.random.rand() < mask_probability else tok_id for tok_id in word]).replace(" ",'') for word in splitted_tokens])
    masked_sentences.append(masked_sentence)



# masked_sentence = ' '.join([mask_token if np.random.rand() < mask_probability else word for word in words])
        

LookupError: 
**********************************************************************
  Resource [93mpunkt[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt/PY3/english.pickle[0m

  Searched in:
    - '/home/cgrdj/nltk_data'
    - '/home/cgrdj/Documents/code/repos/sentence-transformers/.conda/nltk_data'
    - '/home/cgrdj/Documents/code/repos/sentence-transformers/.conda/share/nltk_data'
    - '/home/cgrdj/Documents/code/repos/sentence-transformers/.conda/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
    - ''
**********************************************************************


In [369]:
masked_sentence

[6904, 5104, 2546, 1012, 6904, 5104]

In [416]:
remove_spaces_table=str.maketrans('', '', ' ')


In [431]:
%%timeit 
tokenizer.decode([ mask_id if np.random.rand() < mask_probability else tok_id for tok_id in splitted_tokens[12]]).replace(" ",'')


53.5 µs ± 2.3 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [429]:
%%timeit -n 100000
tokenizer.decode([ mask_id if np.random.rand() < mask_probability else tok_id for tok_id in splitted_tokens[12]]).translate(remove_spaces_table)


59.8 µs ± 4.79 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [385]:
tokenizer('f')

{'input_ids': [101, 1042, 102], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]}

In [227]:
tokenizer.convert_tokens_to_string(tokenizer(sentence)['input_ids'])



TypeError: argument 'tokens': 'int' object cannot be converted to 'PyString'

In [212]:
mask_probability=0.15
mask_token = tokenizer.mask_token  # Get the mask token
tokens = tokenizer.tokenize(sentence)
# Decide randomly which tokens to mask
masked_indices = np.random.rand(len(tokens)) < mask_probability
# Replace selected tokens with the mask token
masked_tokens = [mask_token if mask else token for token, mask in zip(tokens, masked_indices)]
# Convert the list of tokens back to a string
masked_sentence = tokenizer.convert_tokens_to_string(masked_tokens)
# Add the masked sentence to the list
# masked_sentences.append(masked_sentence)
masked_sentence

"asdas [MASK] adsf sda fads ' asd gas [MASK]g aasdfasdf / dsaf sad [MASK]d, [MASK]fadsf. fadsf. fa [MASK], fadsfa"

In [208]:
masked_sentence

"as [MASK]df adsf sd [MASK] fads ' asd gas [MASK]g aa [MASK]fasdf [MASK] [MASK] [MASK] sadfsd, dasfadsf. [MASK] [MASK] [MASK]. fads [MASK] fa [MASK]fa"

In [203]:

word_tokenize(sentence)




['asdasdf',
 'adsf',
 'sda',
 'fads',
 "'asd",
 'gasdg',
 'aasdfasdf',
 '/dsaf',
 'sadfsd',
 ',',
 'dasfadsf',
 '.',
 'fadsf.fads',
 ',',
 'fadsfa']

# Data

In [192]:
dataset = "scifact"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join(os.getcwd(), "datasets")
data_path = util.download_and_unzip(url, out_dir)
print("Dataset downloaded here: {}".format(data_path))

Dataset downloaded here: /Users/cgrdj/Documents/Code/sentence-transformers/datasets/scifact


In [193]:
corpus, queries, qrels = GenericDataLoader(data_path).load(split="train") # or split = "train" or "dev"
unsupervised_train_data =  list(queries.values())#+[data['title']+' \n '+data['text'] for data in list(corpus.values())]
random.Random(0).shuffle(unsupervised_train_data)


2024-03-16 00:49:24 - Loading Corpus...


100%|██████████| 5183/5183 [00:00<00:00, 11485.08it/s]

2024-03-16 00:49:24 - Loaded 5183 TRAIN Documents.
2024-03-16 00:49:24 - Doc Example: {'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 ver




In [194]:
data_path = "datasets/scifact"
test_corpus, test_queries, test_qrels = GenericDataLoader(data_path).load(split="test") # or split = "train" or "dev"

2024-03-16 00:49:25 - Loading Corpus...


  0%|          | 0/5183 [00:00<?, ?it/s]

100%|██████████| 5183/5183 [00:00<00:00, 16216.28it/s]


2024-03-16 00:49:26 - Loaded 5183 TEST Documents.
2024-03-16 00:49:26 - Doc Example: {'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 vers