# TODO

 - Embedding for all the lines of the document
 <!-- - Embeddings for all concepts -->
 <!-- - Each concept has a list of neighboring concepts based on similarity (e.g. cosine similarity) -->
 <!-- - The searched term will be embedded and compared to all concepts -->
 - The searched term will be embedded and compared to all lines of the corpus (with hashing to accelerate)
 <!-- - Return patients having the neighboring concepts of the searched term -->
 - Return patients that have big similarity

In [1]:
# %pip install -U sentence-transformers -q

Note: you may need to restart the kernel to use updated packages.


### Importing

In [6]:
# ----------------------------------- tech ----------------------------------- #
import os
import glob
import pickle
from tqdm import tqdm

# ------------------------- Transformers and freinds ------------------------- #
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
from sentence_transformers import SentenceTransformer, util
import torch
import torch.nn.functional as F
import numpy as np

# ----------------------------------- local ---------------------------------- #
from data_preprocessing import Get_and_process_data


device = "cuda"

### Configurations

In [2]:
# Load model from HuggingFace Hub
model_checkpoint = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModel.from_pretrained(model_checkpoint)
data_path = "../data/train/txt"

### utils

In [3]:
#Mean Pooling - Take average of all tokens
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output.last_hidden_state #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


#Encode text
def encode(texts, tokenizer = tokenizer, model= model):
    # Tokenize sentences
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input, return_dict=True)

    # Perform pooling
    embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # Normalize embeddings
    embeddings = F.normalize(embeddings, p=2, dim=1)
    
    return embeddings

def semantic_search_base(query_emb, doc_emb):
    #Compute dot score between query and all document embeddings
    scores = torch.mm(query_emb, doc_emb.transpose(0, 1))[0].cpu().tolist()

    #Combine docs & scores
    doc_score_pairs = list(zip(docs, scores))

    #Sort by decreasing score
    doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)

    #Output passages & scores
    for doc, score in doc_score_pairs:
        print(score, doc)
        
def forward(texts, tokenizer= tokenizer, model= model):
    # Tokenize sentences
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    model_output = model(**encoded_input, return_dict=True)

    # Perform pooling
    embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # Normalize embeddings
    embeddings = F.normalize(embeddings, p=2, dim=1)
    
    return embeddings


def forward_doc(texts, tokenizer= tokenizer, model= model, no_grad= False):
    lines = texts.split("\n")
    
    # Tokenize sentences
    encoded_input_lines = tokenizer(lines, padding=True, truncation=True, return_tensors='pt')
    
    # Compute token embeddings
    if no_grad:
        with torch.no_grad():
            model_output_lines = model(**encoded_input_lines, return_dict=True)
    else :
        model_output_lines = model(**encoded_input_lines, return_dict=True)

    # Perform pooling
    embeddings_lines = mean_pooling(model_output_lines, encoded_input_lines['attention_mask'])

    # NOTE: This is an easy approach
    # another mean pooling over the lines of the document
    embeddings = torch.mean(embeddings_lines, 0).unsqueeze(0)
    
    # Normalize embeddings
    embeddings = F.normalize(embeddings, p=2, dim=1)
    
    return embeddings


### Testing Inference from checkpoint

In [14]:
model =model.eval()

In [123]:
# Sentences we want sentence embeddings for
query = "How many people live in London?"
docs = ["Around 9 Million people live in London", "London is known for its financial district"]

#Encode query and docs
query_emb = encode(query)
doc_emb = encode(docs)

semantic_search_base(query_emb, doc_emb)

0.915637195110321 Around 9 Million people live in London
0.49475765228271484 London is known for its financial district


0.915637195110321 Around 9 Million people live in London


0.49475765228271484 London is known for its financial district

### Testing training

In [6]:
encoded_input = tokenizer(query, padding=True, truncation=True, return_tensors='pt')
model_output = model(**encoded_input, return_dict=True)
# model_output

In [7]:
encoded_input["input_ids"].shape

torch.Size([1, 9])

In [8]:
model_output.last_hidden_state.shape

torch.Size([1, 9, 384])

In [9]:
model_output.pooler_output.shape

torch.Size([1, 384])

In [13]:
model.train()

query = "How many people live in London?"
answer = "Around 9 Million people live in London"

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

q = forward(query)
print("q shape :", q.shape)
a = forward(answer)
print("a shape :", a.shape)

loss = loss_fn(a,q)

optimizer.zero_grad()
# loss.backward()
# optimizer.step()

q shape : torch.Size([1, 384])
a shape : torch.Size([1, 384])


### Getting data

In [125]:
doc = ""

with open("../data/train/txt/018636330_DH.txt") as f:
    doc = f.read()
    
doc_emb = forward_doc(doc)
doc_emb.shape

torch.Size([1, 384])

In [144]:
c_emb= encode("norvasc << test >>")
semantic_search_base(c_emb, doc_emb)

0.26280614733695984 Around 9 Million people live in London


### Saving embeddings

In [None]:
# def embeddings():
#     return np.random.rand(1,384)

In [4]:
# what are the elements in the folder ../data/train/txt/
all_docs = {}
text_files = glob.glob(data_path + os.sep +  "*.txt")
for file in tqdm(text_files, "Encoding documents", ascii=True):
    with open(file) as f:
        doc = f.read()
    file_name = os.path.basename(file).split(".")[0]
    all_docs[file_name] = forward_doc(doc, no_grad=True)

Encoding documents: 100%|##########| 170/170 [05:42<00:00,  2.01s/it]


In [7]:
with open(data_path + os.sep + "all_docs.pkl", "wb") as f:
    pickle.dump(all_docs, f)

In [8]:
# with open(data_path + os.sep + "all_docs.pkl", "rb") as f:
#     D = pickle.load(f)

### Classify the embeddings

In [None]:
# model_checkpoint = "allenai/scibert_scivocab_uncased"
# batch_size = 32
# tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


# data_loader = Get_and_process_data(tokenizer, train_split=0.95, add_unlabeled=True)
# D = data_loader.get_dataset()
# label_list = data_loader.get_label_list()