In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [17]:
!pip install pyunpack
!pip install patool
from pyunpack import Archive

Archive('/content/drive/MyDrive/fine_tuned_bi_encoder.rar').extractall('/content/drive/MyDrive')
Archive('/content/drive/MyDrive/fine_tuned_cross_encoder.rar').extractall('/content/drive/MyDrive')

Collecting patool
  Using cached patool-3.0.3-py2.py3-none-any.whl.metadata (4.3 kB)
Using cached patool-3.0.3-py2.py3-none-any.whl (98 kB)
Installing collected packages: patool
Successfully installed patool-3.0.3


In [2]:
!pip install -U sentence-transformers rank_bm25 datasets

Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m29.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━

In [3]:
import json
import torch
import random
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

  from tqdm.autonotebook import tqdm, trange


In [4]:
# prepare for the training data
"""
we need to look for query, relevent text and irrelevant text from the files

query_file:qid\tquery
test_file: the first field is qid, the third field is PMID
dataset: a list include some dict element, the dict element includes PMID and text

we should create a training data from these three files
"""
passages = []
with open('/content/drive/MyDrive/FIR-s05-medline.json', 'r') as f:
    for line in f:
        doc = json.loads(line)
        if 'index' in line:
            continue
        elif 'PMID' in doc:
            pmid = doc['PMID']

            ti = doc['TI'] if 'TI' in doc else ''

            ab = doc['AB'] if 'AB' in doc else ''

            text = ti + ". " + ab
            passages.append({'PMID': pmid, 'text': text})
        else:
            print('there is an error!')

def add_noise_to_text(text):
    words = text.split()
    num_words = len(words)

    # random noise
    noise_ratio = random.uniform(0, 0.1)

    num_noisy_words = max(1, int(num_words * noise_ratio))

    for _ in range(num_noisy_words):

        noise_type = random.choice(['replace_char', 'delete_word', 'shuffle_words'])

        if noise_type == 'replace_char' and num_words > 0:
            word_idx = random.randint(0, num_words - 1)
            char_idx = random.randint(0, len(words[word_idx]) - 1)
            noisy_word = list(words[word_idx])
            noisy_word[char_idx] = random.choice('abcdefghijklmnopqrstuvwxyz')
            words[word_idx] = ''.join(noisy_word)

        elif noise_type == 'delete_word' and num_words > 1:
            del words[random.randint(0, num_words - 1)]
            num_words -= 1

        elif noise_type == 'shuffle_words' and num_words > 1:
            start_idx = random.randint(0, num_words - 2)
            end_idx = random.randint(start_idx + 1, num_words)
            words[start_idx:end_idx] = random.sample(words[start_idx:end_idx], len(words[start_idx:end_idx]))

    return ' '.join(words)

def augment_training_data(query_file, test_file, dataset, augment_factor=2):

    queries = {}
    with open(query_file, 'r') as qf:
        for line in qf:
            qid, query_text = line.strip().split('\t')
            queries[qid] = query_text

    relevant_pmids = {}
    with open(test_file, 'r') as tf:
        for line in tf:
            fields = line.strip().split('\t')
            qid = fields[0]
            pmid = fields[2]
            if qid not in relevant_pmids:
                relevant_pmids[qid] = []
            relevant_pmids[qid].append(pmid)

    pmid_to_text = {doc['PMID']: doc['text'] for doc in dataset}

    augmented_training_data = []

    for qid, query in queries.items():
        if qid in relevant_pmids:
            relevant_texts = [pmid_to_text[pmid] for pmid in relevant_pmids[qid] if pmid in pmid_to_text]

            for _ in range(augment_factor):
                for rel_text in relevant_texts:
                    noisy_rel_text = add_noise_to_text(rel_text)

                    irrelevant_pmids = set(pmid_to_text.keys()) - set(relevant_pmids[qid])
                    irr_pmid = random.choice(list(irrelevant_pmids))

                    augmented_training_data.append({
                        'query': query,
                        'relevant_text': noisy_rel_text,
                        'irrelevant_text': pmid_to_text[irr_pmid]
                    })

    return augmented_training_data



train = augment_training_data('/content/drive/MyDrive/FIR-s05-training-queries-simple.txt','/content/drive/MyDrive/FIR-s05-training-qrels.txt',passages,augment_factor=50)


In [None]:
from sentence_transformers import SentenceTransformer, InputExample, losses, CrossEncoder
from torch.utils.data import DataLoader

bi_encoder = SentenceTransformer('msmarco-MiniLM-L-6-v3') # msmarco-distilbert-base-v4

train_examples = []
for data in train:
    train_examples.append(InputExample(
        texts=[data['query'], data['relevant_text'], data['irrelevant_text']]
    ))
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

train_loss = losses.TripletLoss(model=bi_encoder)

bi_encoder.fit(train_objectives=[(train_dataloader, train_loss)], epochs=3, warmup_steps=100,show_progress_bar=True)

bi_encoder.save('fine_tuned_bi_encoder')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/3.72k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/430 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]



1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Step,Training Loss
500,0.5517
1000,0.0108


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

In [None]:
from torch.utils.data import DataLoader

def prepare_crossencoder_examples(training_data):
    examples = []
    for data in training_data:
        query = data['query']
        relevant_text = data['relevant_text']
        irrelevant_text = data['irrelevant_text']

        examples.append(InputExample(texts=[query, relevant_text], label=1))

        examples.append(InputExample(texts=[query, irrelevant_text], label=0))

    return examples

train_examples = prepare_crossencoder_examples(train)

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

cross_encoder = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2', num_labels=1) # cross-encoder/ms-marco-MiniLM-L-6-v2

num_epochs = 3
warmup_steps = 100
learning_rate = 2e-5

cross_encoder.fit(
    train_dataloader=train_dataloader,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    optimizer_params={'lr': learning_rate},
    show_progress_bar=True
)

cross_encoder.save('fine_tuned_cross_encoder')

Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Iteration:   0%|          | 0/875 [00:00<?, ?it/s]

Iteration:   0%|          | 0/875 [00:00<?, ?it/s]

Iteration:   0%|          | 0/875 [00:00<?, ?it/s]

In [None]:
from sentence_transformers import SentenceTransformer, InputExample, losses, CrossEncoder

bi_encoder = SentenceTransformer('fine_tuned_bi_encoder')
cross_encoder = CrossEncoder('fine_tuned_cross_encoder')

bi_encoder.max_seq_length = 512
top_k = 500

encoded_passages = []

with open('/content/drive/MyDrive/FIR-s05-medline.json', 'r') as f:
    for line in f:
        doc = json.loads(line)
        if 'index' in line:
            continue
        elif 'PMID' in doc:
            pmid = doc['PMID']

            ti = doc['TI'] if 'TI' in doc else ''

            ab = doc['AB'] if 'AB' in doc else ''

            text = ti + ". " + ab
            encoded_passages.append({'PMID': pmid, 'text': text, 'embedding': None}) 
        else:
            print('there is an error!')

for passage in tqdm(encoded_passages, desc="Encoding passages"):
    text = passage['text']
    embedding = bi_encoder.encode(text, convert_to_tensor=True)
    passage['embedding'] = embedding

Encoding passages: 100%|██████████| 257741/257741 [32:44<00:00, 131.17it/s]


In [9]:
from sentence_transformers import util
def search_bi_encoder(query, bi_encoder, encoded_passages, top_k=500):

    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    passage_embeddings = torch.stack([p['embedding'] for p in encoded_passages])


    similarities = util.cos_sim(query_embedding, passage_embeddings)[0]

    top_n_indices = similarities.argsort(descending=True)[:top_k]

    top_results = [{'PMID': encoded_passages[i]['PMID'],
                    'text': encoded_passages[i]['text'],
                    'bi_encoder_score': similarities[i].item()}
                   for i in top_n_indices]

    return top_results

def rerank_with_cross_encoder(query, cross_encoder, bi_encoder_results, top_k=100):

    query_text_pairs = [(query, res['text']) for res in bi_encoder_results]

    cross_scores = cross_encoder.predict(query_text_pairs)

    for i, score in enumerate(cross_scores):
        bi_encoder_results[i]['cross_encoder_score'] = score

    reranked_results = sorted(bi_encoder_results, key=lambda x: x['cross_encoder_score'], reverse=True)[:top_k]

    return reranked_results

def process_test_set_with_rerank(test_file, bi_encoder, cross_encoder, encoded_passages, output_bi_file, output_rerank_file, top_k_bi=500, top_k_rerank=100):

    with open(test_file, 'r') as f, \
         open(output_bi_file, 'w') as bi_out_f, \
         open(output_rerank_file, 'w') as rerank_out_f:

        for line in f:
            qid, query_text = line.strip().split('\t')

            bi_encoder_results = search_bi_encoder(query_text, bi_encoder, encoded_passages, top_k=top_k_bi)

            reranked_results = rerank_with_cross_encoder(query_text, cross_encoder, bi_encoder_results, top_k=top_k_rerank)

            # bi_out_f.write(f"bi_encoder results for QID: {qid}\n")
            for res in bi_encoder_results:
                bi_out_f.write(f"{qid}\t{res['PMID']}\t{res['bi_encoder_score']}\n")
            # bi_out_f.write("\n")

            # rerank_out_f.write(f"cross_encoder re-ranked results for QID: {qid}\n")
            for res in reranked_results:
                rerank_out_f.write(f"{qid}\t{res['PMID']}\t{res['cross_encoder_score']}\n")
            # rerank_out_f.write("\n")

# process_test_set_with_rerank('/content/drive/MyDrive/FIR-s05-training-queries-simple.txt', bi_encoder, cross_encoder, encoded_passages, '/content/drive/MyDrive/Colab Notebooks/bi_encoder_tune.txt', '/content/drive/MyDrive/Colab Notebooks/CrossEncoder_tune.txt')

In [23]:
def print_result(result_file,test_file):
  """
  result_file.txt: qid\tPMID\tscore
  test_file.txt: this file includes some inrelevant fileds,we only need the first qid and the third PMID
  """
  # read document of retrieved
  retrieved_results = {}
  with open(result_file, 'r') as rf:
    for line in rf:
      if line.strip():
        # print(line)
        qid = line.split('\t')[0]
        pmid = line.split('\t')[1]
        score = line.split('\t')[2]
        if qid not in retrieved_results:
          retrieved_results[qid] = []
        retrieved_results[qid].append(pmid)
  # print(retrieved_results)
  # read the test file
  relevant_docs = {}
  with open(test_file, 'r') as tf:
    for line in tf:
        # print(line)
        fields = line.strip().split('\t')
        qid = fields[0]
        pmid = fields[2]
        if qid not in relevant_docs:
            relevant_docs[qid] = []
        relevant_docs[qid].append(pmid)
  # print(relevant_docs)
  # calculate the performance the IR model
  total_precision = 0
  total_recall = 0
  total_ap = 0
  num_queries = len(relevant_docs)

  for qid in relevant_docs:
      if qid not in retrieved_results:
          print(f"No results found for QID: {qid}")
          continue

      retrieved = retrieved_results[qid]
      relevant = set(relevant_docs[qid])
      # print(retrieved)
      # print(relevant)

      tp = 0  # true positives
      precision_at_k = []
      for k, pmid in enumerate(retrieved, 1):
          if pmid in relevant:
              tp += 1
              precision_at_k.append(tp / k)

      precision = tp / len(retrieved) if len(retrieved) > 0 else 0
      recall = tp / len(relevant) if len(relevant) > 0 else 0
      ap = sum(precision_at_k) / len(relevant) if len(relevant) > 0 else 0

      total_precision += precision
      total_recall += recall
      total_ap += ap

  avg_precision = total_precision / num_queries
  avg_recall = total_recall / num_queries
  mean_ap = total_ap / num_queries
  # print the performance
  print(f"Average Precision: {avg_precision:.4f}")
  print(f"Average Recall: {avg_recall:.4f}")
  print(f"Mean Average Precision (MAP): {mean_ap:.4f}")

In [None]:
print_result('/content/drive/MyDrive/Colab Notebooks/CrossEncoder_tune.txt','/content/drive/MyDrive/FIR-s05-training-qrels.txt')

Average Precision: 0.0282
Average Recall: 0.7806
Mean Average Precision (MAP): 0.3410


In [15]:
import numpy as np

def read_results(result_file):

    queries = {}
    with open(result_file, 'r') as f:
        for line in f:
            line = line.strip('\n')
            if not line:
                continue
            qid = line.split('\t')[0]
            pmid = line.split('\t')[1]

            if qid not in queries:
                queries[qid] = {'retrieved_docs': []}
            queries[qid]['retrieved_docs'].append(pmid)
    return queries

def read_test_file(test_file):
    relevant_docs = {}
    with open(test_file, 'r') as f:
        for line in f:
          try:
            fields = line.strip().split('\t')
            qid = fields[0]
            pmid = fields[2]
            if qid not in relevant_docs:
                relevant_docs[qid] = []
            relevant_docs[qid].append(pmid)
          except:
            print(line)
            continue
    return relevant_docs

def merge_results_and_relevant(queries, relevant_docs):
    merged_queries = []
    for qid, query_data in queries.items():
        if qid in relevant_docs:
            merged_queries.append({
                'relevant_docs': relevant_docs[qid],
                'retrieved_docs': query_data['retrieved_docs']
            })
    return merged_queries

# Success@K
def success_at_k(relevant_docs, retrieved_docs, k):
    return int(any(doc in relevant_docs for doc in retrieved_docs[:k]))

# Precision@K
def precision_at_k(relevant_docs, retrieved_docs, k):
    retrieved_k = retrieved_docs[:k]
    relevant_retrieved = [doc for doc in retrieved_k if doc in relevant_docs]
    return len(relevant_retrieved) / k

# R-Precision
def r_precision(relevant_docs, retrieved_docs):
    r = len(relevant_docs)
    return precision_at_k(relevant_docs, retrieved_docs, r)

# Precision@Recall
def precision_at_recall(relevant_docs, retrieved_docs, recall_level):
    num_relevant = len(relevant_docs)
    relevant_retrieved = 0
    precisions = []

    for i, doc in enumerate(retrieved_docs):
        if doc in relevant_docs:
            relevant_retrieved += 1
            recall = relevant_retrieved / num_relevant
            precision = relevant_retrieved / (i + 1)
            precisions.append(precision)
            if recall >= recall_level:
                return precision

    return precisions[-1] if precisions else 0

# Average Precision (AP)
def average_precision(relevant_docs, retrieved_docs):
    num_relevant = len(relevant_docs)
    relevant_retrieved = 0
    precision_sum = 0

    for i, doc in enumerate(retrieved_docs):
        if doc in relevant_docs:
            relevant_retrieved += 1
            precision = relevant_retrieved / (i + 1)
            precision_sum += precision

    if num_relevant == 0:
        return 0

    return precision_sum / num_relevant

# Mean Average Precision (MAP)
def mean_average_precision(queries):
    ap_sum = 0
    for query in queries:
        relevant_docs = query['relevant_docs']
        retrieved_docs = query['retrieved_docs']
        ap_sum += average_precision(relevant_docs, retrieved_docs)

    return ap_sum / len(queries)


def print_metrics(queries):
    success_at_1 = np.mean([success_at_k(query['relevant_docs'], query['retrieved_docs'], 1) for query in queries])
    success_at_5 = np.mean([success_at_k(query['relevant_docs'], query['retrieved_docs'], 5) for query in queries])
    success_at_10 = np.mean([success_at_k(query['relevant_docs'], query['retrieved_docs'], 10) for query in queries])

    r_precision_mean = np.mean([r_precision(query['relevant_docs'], query['retrieved_docs']) for query in queries])

    precision_at_1 = np.mean([precision_at_k(query['relevant_docs'], query['retrieved_docs'], 1) for query in queries])
    precision_at_5 = np.mean([precision_at_k(query['relevant_docs'], query['retrieved_docs'], 5) for query in queries])
    precision_at_10 = np.mean([precision_at_k(query['relevant_docs'], query['retrieved_docs'], 10) for query in queries])
    precision_at_50 = np.mean([precision_at_k(query['relevant_docs'], query['retrieved_docs'], 50) for query in queries])
    precision_at_100 = np.mean([precision_at_k(query['relevant_docs'], query['retrieved_docs'], 100) for query in queries])

    precision_at_recall_values = [np.mean([precision_at_recall(query['relevant_docs'], query['retrieved_docs'], recall_level)
                                    for query in queries]) for recall_level in np.arange(0, 1.1, 0.1)]

    map_score = mean_average_precision(queries)

    print(f"mean success_at_1              {success_at_1:.5f}")
    print(f"mean success_at_5              {success_at_5:.5f}")
    print(f"mean success_at_10             {success_at_10:.5f}")
    print(f"mean r_precision               {r_precision_mean:.5f}")
    print(f"mean precision_at_1            {precision_at_1:.5f}")
    print(f"mean precision_at_5            {precision_at_5:.5f}")
    print(f"mean precision_at_10           {precision_at_10:.5f}")
    print(f"mean precision_at_50           {precision_at_50:.5f}")
    print(f"mean precision_at_100          {precision_at_100:.5f}")

    for i, recall_level in enumerate(np.arange(0, 1.1, 0.1)):
        print(f"mean precision_at_recall_{i:02}    {precision_at_recall_values[i]:.5f}")

    print(f"mean average_precision         {map_score:.5f}")


queries_test = merge_results_and_relevant(read_results('/content/drive/MyDrive/Colab Notebooks/CrossEncoder_tune.txt'), read_test_file('/content/drive/MyDrive/FIR-s05-training-qrels.txt'))

print_metrics(queries_test)


mean success_at_1              0.36842
mean success_at_5              0.65789
mean success_at_10             0.71053
mean r_precision               0.26804
mean precision_at_1            0.36842
mean precision_at_5            0.18421
mean precision_at_10           0.12105
mean precision_at_50           0.05000
mean precision_at_100          0.02737
mean precision_at_recall_00    0.46930
mean precision_at_recall_01    0.47776
mean precision_at_recall_02    0.45762
mean precision_at_recall_03    0.40843
mean precision_at_recall_04    0.37884
mean precision_at_recall_05    0.37911
mean precision_at_recall_06    0.33693
mean precision_at_recall_07    0.33054
mean precision_at_recall_08    0.31584
mean precision_at_recall_09    0.31404
mean precision_at_recall_10    0.31404
mean average_precision         0.32482


In [25]:
queries_test = merge_results_and_relevant(read_results('/content/drive/MyDrive/Colab Notebooks/CrossEncoder_tune-v2.txt'), read_test_file('/content/drive/MyDrive/training-qrels_large.txt'))
print_metrics(queries_test)


mean success_at_1              0.32000
mean success_at_5              0.62000
mean success_at_10             0.70000
mean r_precision               0.24207
mean precision_at_1            0.32000
mean precision_at_5            0.20400
mean precision_at_10           0.15800
mean precision_at_50           0.05840
mean precision_at_100          0.03680
mean precision_at_recall_00    0.44442
mean precision_at_recall_01    0.43276
mean precision_at_recall_02    0.40604
mean precision_at_recall_03    0.34649
mean precision_at_recall_04    0.31121
mean precision_at_recall_05    0.29260
mean precision_at_recall_06    0.29109
mean precision_at_recall_07    0.28184
mean precision_at_recall_08    0.28040
mean precision_at_recall_09    0.27899
mean precision_at_recall_10    0.27899
mean average_precision         0.24030


In [None]:
queries_test = merge_results_and_relevant(read_results('/content/drive/MyDrive/Colab Notebooks/bm25.txt'), read_test_file('/content/drive/MyDrive/FIR-s05-training-qrels.txt'))
print_metrics(queries_test)

mean success_at_1              0.13158
mean success_at_5              0.23684
mean success_at_10             0.26316
mean r_precision               0.06684
mean precision_at_1            0.13158
mean precision_at_5            0.06842
mean precision_at_10           0.04474
mean precision_at_50           0.01579
mean precision_at_100          0.01000
mean precision_at_recall_00    0.18399
mean precision_at_recall_01    0.17741
mean precision_at_recall_02    0.16884
mean precision_at_recall_03    0.15944
mean precision_at_recall_04    0.13379
mean precision_at_recall_05    0.13409
mean precision_at_recall_06    0.11341
mean precision_at_recall_07    0.11341
mean precision_at_recall_08    0.11341
mean precision_at_recall_09    0.11341
mean precision_at_recall_10    0.11341
mean average_precision         0.09125


In [None]:
# prepare for the training data
"""
we need to look for query, relevent text and irrelevant text from the files

query_file:qid\tquery
test_file: the first field is qid, the third field is PMID
dataset: a list include some dict element, the dict element includes PMID and text

we should create a training data from these three files
"""
passages = []
with open('/content/drive/MyDrive/trec-medline_large.json', 'r') as f:
    for line in f:
        doc = json.loads(line)
        if 'index' in line:
            continue
        elif 'PMID' in doc:

            pmid = doc['PMID']

            ti = doc['TI'] if 'TI' in doc else ''

            ab = doc['AB'] if 'AB' in doc else ''

            text = ti + ". " + ab
            passages.append({'PMID': pmid, 'text': text})
        else:
            print('there is an error!')

def add_noise_to_text(text):
    words = text.split()
    num_words = len(words)

    noise_ratio = random.uniform(0, 0.1)

    num_noisy_words = max(1, int(num_words * noise_ratio))

    for _ in range(num_noisy_words):
        noise_type = random.choice(['replace_char', 'delete_word', 'shuffle_words'])

        if noise_type == 'replace_char' and num_words > 0:
            word_idx = random.randint(0, num_words - 1)
            char_idx = random.randint(0, len(words[word_idx]) - 1)
            noisy_word = list(words[word_idx])
            noisy_word[char_idx] = random.choice('abcdefghijklmnopqrstuvwxyz')
            words[word_idx] = ''.join(noisy_word)

        elif noise_type == 'delete_word' and num_words > 1:
            del words[random.randint(0, num_words - 1)]
            num_words -= 1

        elif noise_type == 'shuffle_words' and num_words > 1:
            start_idx = random.randint(0, num_words - 2)
            end_idx = random.randint(start_idx + 1, num_words)
            words[start_idx:end_idx] = random.sample(words[start_idx:end_idx], len(words[start_idx:end_idx]))

    return ' '.join(words)

def augment_training_data(query_file, test_file, dataset, augment_factor=2):

    queries = {}
    with open(query_file, 'r') as qf:
        for line in qf:
            qid, query_text = line.strip().split('\t')
            queries[qid] = query_text

    relevant_pmids = {}
    with open(test_file, 'r') as tf:
        for line in tf:
            fields = line.strip().split('\t')
            qid = fields[0]
            pmid = fields[2]
            if qid not in relevant_pmids:
                relevant_pmids[qid] = []
            relevant_pmids[qid].append(pmid)

    pmid_to_text = {doc['PMID']: doc['text'] for doc in dataset}

    augmented_training_data = []

    for qid, query in queries.items():
        if qid in relevant_pmids:
            relevant_texts = [pmid_to_text[pmid] for pmid in relevant_pmids[qid] if pmid in pmid_to_text]

            for _ in range(augment_factor):
                for rel_text in relevant_texts:
                    noisy_rel_text = add_noise_to_text(rel_text)

                    irrelevant_pmids = set(pmid_to_text.keys()) - set(relevant_pmids[qid])
                    irr_pmid = random.choice(list(irrelevant_pmids))

                    augmented_training_data.append({
                        'query': query,
                        'relevant_text': noisy_rel_text,
                        'irrelevant_text': pmid_to_text[irr_pmid]
                    })

    return augmented_training_data



train = augment_training_data('/content/drive/MyDrive/training-queries-simple_large.txt','/content/drive/MyDrive/training-qrels_large.txt',passages,augment_factor=50)


In [5]:
from sentence_transformers import SentenceTransformer, InputExample, losses, CrossEncoder
from torch.utils.data import DataLoader

bi_encoder = SentenceTransformer('msmarco-distilbert-base-v4')

train_examples = []
for data in train:
    train_examples.append(InputExample(
        texts=[data['query'], data['relevant_text'], data['irrelevant_text']]
    ))
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

train_loss = losses.TripletLoss(model=bi_encoder)

bi_encoder.fit(train_objectives=[(train_dataloader, train_loss)], epochs=3, warmup_steps=100,show_progress_bar=True)

bi_encoder.save('fine_tuned_bi_encoder-v2')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/3.75k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/545 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/265M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/319 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]



1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Step,Training Loss
500,0.2222
1000,0.0015


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

In [None]:
from torch.utils.data import DataLoader

def prepare_crossencoder_examples(training_data):
    examples = []
    for data in training_data:
        query = data['query']
        relevant_text = data['relevant_text']
        irrelevant_text = data['irrelevant_text']

        examples.append(InputExample(texts=[query, relevant_text], label=1))

        examples.append(InputExample(texts=[query, irrelevant_text], label=0))

    return examples

train_examples = prepare_crossencoder_examples(train)

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', num_labels=1)

num_epochs = 3
warmup_steps = 100
learning_rate = 2e-5

cross_encoder.fit(
    train_dataloader=train_dataloader,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    optimizer_params={'lr': learning_rate},
    show_progress_bar=True
)

cross_encoder.save('fine_tuned_cross_encoder-v2')

config.json:   0%|          | 0.00/794 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]



Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Iteration:   0%|          | 0/875 [00:00<?, ?it/s]

Iteration:   0%|          | 0/875 [00:00<?, ?it/s]

Iteration:   0%|          | 0/875 [00:00<?, ?it/s]

In [None]:
from sentence_transformers import SentenceTransformer, InputExample, losses, CrossEncoder

bi_encoder = SentenceTransformer('fine_tuned_bi_encoder-v2')
cross_encoder = CrossEncoder('fine_tuned_cross_encoder-v2')

bi_encoder.max_seq_length = 512
top_k = 500

encoded_passages = []

with open('/content/drive/MyDrive/FIR-s05-medline.json', 'r') as f:
    for line in f:
        doc = json.loads(line)
        if 'index' in line:
            continue
        elif 'PMID' in doc:
            pmid = doc['PMID']

            ti = doc['TI'] if 'TI' in doc else ''

            ab = doc['AB'] if 'AB' in doc else ''

            text = ti + ". " + ab
            encoded_passages.append({'PMID': pmid, 'text': text, 'embedding': None})
        else:
            print('there is an error!')

for passage in tqdm(encoded_passages, desc="Encoding passages"):
    text = passage['text']
    embedding = bi_encoder.encode(text, convert_to_tensor=True)
    passage['embedding'] = embedding

Encoding passages: 100%|██████████| 257741/257741 [40:49<00:00, 105.21it/s]


In [18]:
process_test_set_with_rerank('/content/drive/MyDrive/FIR-s05-training-queries-simple.txt', bi_encoder, cross_encoder, encoded_passages, '/content/drive/MyDrive/Colab Notebooks/bi_encoder_tune_v2.txt', '/content/drive/MyDrive/Colab Notebooks/CrossEncoder_tune_v2.txt')
queries_test = merge_results_and_relevant(read_results('/content/drive/MyDrive/Colab Notebooks/CrossEncoder_tune_v2.txt'), read_test_file('/content/drive/MyDrive/FIR-s05-training-qrels.txt'))
print_metrics(queries_test)

mean success_at_1              0.47368
mean success_at_5              0.78947
mean success_at_10             0.89474
mean r_precision               0.44107
mean precision_at_1            0.47368
mean precision_at_5            0.24737
mean precision_at_10           0.16842
mean precision_at_50           0.06368
mean precision_at_100          0.03579
mean precision_at_recall_00    0.61632
mean precision_at_recall_01    0.61712
mean precision_at_recall_02    0.60035
mean precision_at_recall_03    0.56612
mean precision_at_recall_04    0.52633
mean precision_at_recall_05    0.52366
mean precision_at_recall_06    0.49134
mean precision_at_recall_07    0.49205
mean precision_at_recall_08    0.47940
mean precision_at_recall_09    0.46072
mean precision_at_recall_10    0.46072
mean average_precision         0.52489


In [None]:
import os, tarfile

import os
from google.colab import files

def make_targz_one_by_one(output_filename, source_dir):
  tar = tarfile.open(output_filename,"w")
  for root,dir_name,files_list in os.walk(source_dir):
    for file in files_list:
      pathfile = os.path.join(root, file)
      tar.add(pathfile)
  tar.close()

  files.download(output_filename)


make_targz_one_by_one('fine_tuned_bi_encoder_export', '/content/fine_tuned_bi_encoder')
make_targz_one_by_one('fine_tuned_cross_encoder_export', '/content/fine_tuned_cross_encoder')
make_targz_one_by_one('fine_tuned_bi_encoder_v2_export','/content/fine_tuned_bi_encoder_v2')
make_targz_one_by_one('fine_tuned_cross_encoder_v2_export','/content/fine_tuned_cross_encoder_v2')


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>