In [None]:
# !pip install sentence-transformers
# !pip install datasets
!pip install beir

Collecting beir
  Downloading beir-1.0.0.tar.gz (64 kB)
[K     |████████████████████████████████| 64 kB 2.2 MB/s 
[?25hCollecting sentence-transformers
  Downloading sentence-transformers-2.2.0.tar.gz (79 kB)
[K     |████████████████████████████████| 79 kB 7.1 MB/s 
[?25hCollecting pytrec_eval
  Downloading pytrec_eval-0.5.tar.gz (15 kB)
Collecting faiss_cpu
  Downloading faiss_cpu-1.7.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)
[K     |████████████████████████████████| 8.6 MB 24.0 MB/s 
[?25hCollecting elasticsearch==7.9.1
  Downloading elasticsearch-7.9.1-py2.py3-none-any.whl (219 kB)
[K     |████████████████████████████████| 219 kB 49.9 MB/s 
Collecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.19.2-py3-none-any.whl (4.2 MB)
[K     |████████████████████████████████| 4.2 MB 38.3 MB/s 
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |███████████████████

In [None]:
from datasets import load_dataset

squad = load_dataset(
    'squad',
    split='train'
)
squad[0]

Reusing dataset squad (/root/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


{'answers': {'answer_start': [515], 'text': ['Saint Bernadette Soubirous']},
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'id': '5733be284776f41900661182',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'title': 'University_of_Notre_Dame'}

In [None]:
passages = list(set(squad['context']))
len(passages)

18891

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('BeIR/query-gen-msmarco-t5-base-v1')
model = T5ForConditionalGeneration.from_pretrained('BeIR/query-gen-msmarco-t5-base-v1')
model.eval()

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.74k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.81k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.35k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/850M [00:00<?, ?B/s]

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dr

In [None]:
import torch
from tqdm.auto import tqdm

pairs = []
file_count = 0

# set to no_grad as we don't need to calculate gradients for back prop
with torch.no_grad():
    # loop through each passage individually
    for p in tqdm(passages):
        p = p.replace('\t', ' ')
        # create input tokens
        input_ids = tokenizer.encode(p, return_tensors='pt')
        # generate output tokens (query generation)
        outputs = model.generate(
            input_ids=input_ids,
            max_length=64,
            do_sample=True,
            top_p=0.95,
            num_return_sequences=3
        )
        # decode output tokens to human-readable language
        for output in outputs:
            query = tokenizer.decode(output, skip_special_tokens=True)
            # append (query, passage) pair to pairs list, separate by \t
            pairs.append(query.replace('\t', ' ')+'\t'+p)
        
        # once we have 1024 pairs write to file
        if len(pairs) > 1024:
            with open(f'data/pairs_{file_count}.tsv', 'w', encoding='utf-8') as fp:
                fp.write('\n'.join(pairs))
            file_count += 1
            pairs = []

if pairs is not None:
    # save the final, smaller than 1024 batch
    with open(f'data/pairs_{file_count}.tsv', 'w', encoding='utf-8') as fp:
        fp.write('\n'.join(pairs))

  0%|          | 0/18891 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [None]:
print("Paragraph:")
# print(para)

print("\nGenerated Queries:")
for i in range(len(outputs)):
    query = tokenizer.decode(outputs[i], skip_special_tokens=True)
    print(f'{i + 1}: {query}')

### MNR loss

In [None]:
from pathlib import Path

paths = [str(path) for path in Path('data').glob('*.tsv')]
paths[:5]

In [None]:
from sentence_transformers import InputExample
from tqdm.auto import tqdm

pairs = []
for path in tqdm(paths):
    with open(path, 'r', encoding='utf-8') as fp:
        lines = fp.read().split('\n')
        for line in lines:
            if '\t' not in line:
                continue
            else:
                q, p = line.split('\t')
                pairs.append(InputExample(
                    texts=[q, p]
                ))


In [None]:
from sentence_transformers import datasets

batch_size = 24

loader = datasets.NoDuplicatesDataLoader(
    pairs, batch_size=batch_size
)

In [None]:
from sentence_transformers import models, SentenceTransformer

mpnet = models.Transformer('microsoft/mpnet-base')
pooler = models.Pooling(
    mpnet.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True
)

model = SentenceTransformer(modules=[mpnet, pooler])

model

In [None]:
from sentence_transformers import losses

loss = losses.MultipleNegativesRankingLoss(model)

In [None]:
epochs = 3
warmup_steps = int(len(loader) * epochs * 0.1)

model.fit(
    train_objectives=[(loader, loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    output_path='mpnet-genq-squad',
    show_progress_bar=True
)

# beir

In [None]:
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.generation import QueryGenerator as QGen
from beir.generation.models import QGenModel
from beir.retrieval.train import TrainRetriever
from sentence_transformers import SentenceTransformer, losses, models

import pathlib, os
import logging

  from tqdm.autonotebook import tqdm


In [None]:
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

#### Download nfcorpus.zip dataset and unzip the dataset

In [None]:
dataset = "nfcorpus"

url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir =  "datasets"
data_path = util.download_and_unzip(url, out_dir)

#### Provide the data_path where nfcorpus has been downloaded and unzipped

In [None]:
corpus = GenericDataLoader(data_path).load_corpus()

2022-05-21 08:52:08 - Loading Corpus...


  0%|          | 0/3633 [00:00<?, ?it/s]

2022-05-21 08:52:09 - Loaded 3633 Documents.
2022-05-21 08:52:09 - Doc Example: {'text': 'Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 participants died, of

##############################
#### 1. Query-Generation  ####
##############################

#### question-generation model loading 



In [None]:
model_path = "BeIR/query-gen-msmarco-t5-base-v1"
generator = QGen(model=QGenModel(model_path))

2022-05-21 08:52:19 - Use pytorch device: cuda


#### Query-Generation using Nucleus Sampling (top_k=25, top_p=0.95) ####
#### https://huggingface.co/blog/how-to-generate
#### Prefix is required to seperate out synthetic queries and qrels from original

In [None]:
prefix = "gen"

#### Generating 3 questions per passage. 
#### Reminder the higher value might produce lots of duplicates

In [None]:
ques_per_passage = 3

#### Generate queries per passage from docs in corpus and save them in data_path


In [None]:
generator.generate(corpus, output_dir=data_path, ques_per_passage=ques_per_passage, prefix=prefix)

2022-05-21 08:52:22 - Starting to Generate 3 Questions Per Passage using top-p (nucleus) sampling...
2022-05-21 08:52:22 - Params: top_p = 0.95
2022-05-21 08:52:22 - Params: top_k = 25
2022-05-21 08:52:22 - Params: max_length = 64
2022-05-21 08:52:22 - Params: ques_per_passage = 3
2022-05-21 08:52:22 - Params: batch size = 32


pas:   0%|          | 0/114 [00:00<?, ?it/s]

2022-05-21 09:06:33 - Saving 10827 Generated Queries...
2022-05-21 09:06:33 - Saving Generated Queries to datasets/nfcorpus/gen-queries.jsonl
2022-05-21 09:06:33 - Saving Generated Qrels to datasets/nfcorpus/gen-qrels/train.tsv


################################
#### 2. Train Dense-Encoder ####
################################

#### Training on Generated Queries ####

In [None]:
corpus, gen_queries, gen_qrels = GenericDataLoader(data_path, prefix=prefix).load(split="train")
#### Please Note - not all datasets contain a dev split, comment out the line if such the case
dev_corpus, dev_queries, dev_qrels = GenericDataLoader(data_path).load(split="dev")

2022-05-21 09:06:33 - Loading Corpus...


  0%|          | 0/3633 [00:00<?, ?it/s]

2022-05-21 09:06:33 - Loaded 3633 TRAIN Documents.
2022-05-21 09:06:33 - Doc Example: {'text': 'Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 participants di

  0%|          | 0/3633 [00:00<?, ?it/s]

2022-05-21 09:06:33 - Loaded 3633 DEV Documents.
2022-05-21 09:06:33 - Doc Example: {'text': 'Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 participants died

In [None]:
#### Provide any HuggingFace model and fine-tune from scratch
model_name = "distilbert-base-uncased" 
word_embedding_model = models.Transformer(model_name, max_seq_length=350)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


2022-05-21 09:06:38 - Use pytorch device: cuda


#### Or provide already fine-tuned sentence-transformer model
```model = SentenceTransformer("msmarco-distilbert-base-v3")```

#### Provide any sentence-transformers model path

In [None]:
model_path = "bert-base-uncased" # or "msmarco-distilbert-base-v3"
retriever = TrainRetriever(model=model, batch_size=32)

#### Prepare training samples

In [None]:
train_samples = retriever.load_train(corpus, gen_queries, gen_qrels)
train_dataloader = retriever.prepare_train(train_samples, shuffle=True)
train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)

Adding Input Examples:   0%|          | 0/170 [00:00<?, ?it/s]

2022-05-21 09:06:38 - Loaded 10827 training pairs.


#### Prepare dev evaluator

In [None]:
ir_evaluator = retriever.load_ir_evaluator(dev_corpus, dev_queries, dev_qrels)

2022-05-21 09:06:38 - eval set contains 3633 documents and 324 queries


#### If no dev set is present evaluate using dummy evaluator
``` ir_evaluator = retriever.load_dummy_evaluator()```

In [None]:
model_save_path = os.path.join("output", "{}-GenQ-nfcorpus".format(model_path))
print(model_save_path)
os.makedirs(model_save_path, exist_ok=True)

output/bert-base-uncased-GenQ-nfcorpus


#### Configure Train params

In [None]:
num_epochs = 1
evaluation_steps = 5000
warmup_steps = int(len(train_samples) * num_epochs / retriever.batch_size * 0.1)

retriever.fit(train_objectives=[(train_dataloader, train_loss)], 
                evaluator=ir_evaluator, 
                epochs=num_epochs,
                output_path=model_save_path,
                warmup_steps=warmup_steps,
                evaluation_steps=evaluation_steps,
                use_amp=True)

2022-05-21 09:10:02 - Starting to Train...




Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/170 [00:00<?, ?it/s]

RuntimeError: ignored

In [None]:
!nvidia-smi

Sat May 21 09:10:13 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P0    60W / 149W |  11350MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [1]:
!ls

pairs.tsv  sample_data


In [3]:
!pip install sentence_transformers

Collecting sentence_transformers
  Downloading sentence-transformers-2.2.0.tar.gz (79 kB)
[K     |████████████████████████████████| 79 kB 6.7 MB/s 
[?25hCollecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.19.2-py3-none-any.whl (4.2 MB)
[K     |████████████████████████████████| 4.2 MB 49.6 MB/s 
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 61.2 MB/s 
[?25hCollecting huggingface-hub
  Downloading huggingface_hub-0.6.0-py3-none-any.whl (84 kB)
[K     |████████████████████████████████| 84 kB 3.3 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 60.0 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64

In [1]:
from sentence_transformers import InputExample
from tqdm import tqdm

pairs = []
with open("pairs.tsv",encoding='utf-8') as fp:
    lines = fp.read().split('\n')
    for line in lines:
        if '\t' not in line:
            continue
        else:
            q,p = line.split('\t')
            pairs.append(InputExample(texts=[q,p]))

In [2]:
from sentence_transformers import datasets

batch_size = 8

loader = datasets.NoDuplicatesDataLoader(pairs,batch_size=batch_size)

### t5-base

In [3]:
from sentence_transformers import models,SentenceTransformer

model_t5 = models.Transformer("sentence-transformers/sentence-t5-base")

pooler = models.Pooling(
    model_t5.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True
)

model = SentenceTransformer(modules=[model_t5, pooler])

In [4]:
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': None, 'do_lower_case': False}) with Transformer model: T5EncoderModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)

### MNR loss

In [5]:
from sentence_transformers import losses

loss = losses.MultipleNegativesRankingLoss(model)

In [6]:
epochs = 10
warmup_steps = int(len(loader) * epochs * 0.1)

model.fit(
    train_objectives=[(loader, loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    output_path='sentence-t5-base-dureader',
    show_progress_bar=True
)



Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

Iteration:   0%|          | 0/128 [00:00<?, ?it/s]

In [7]:
query = "市场动荡如何影响存款准备金率"
xq = model.encode([query]).tolist()

In [9]:
# xq