In [1]:
import pandas as pd
import random

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## Load data

In [2]:
infoSheet = pd.read_csv('data/infoSheets_2023-05-18.csv')
print(infoSheet.isnull().sum())
infoSheet.head(10)

ID                        0
name_en                   0
name_fr                 138
abstract_en              20
abstract_fr             146
description_en            0
description_fr          140
taxonomy heading ids      0
dtype: int64


Unnamed: 0,ID,name_en,name_fr,abstract_en,abstract_fr,description_en,description_fr,taxonomy heading ids
0,84606,ADHD Medication Side Effects: Low Appetite and...,,Stimulants prescribed for ADHD can lead to red...,,Background\r\nStimulant medications for attent...,,0
1,92619,5-HTP (5-hydroxytryptophan),,5-HTP (5-Hydroxytryptophan) is a natural subst...,,What is 5-HTP?\r\n5-HTP (5-Hydroxytryptophan) ...,,0
2,50150,A Simple Way to Swallow Pills: The Head Postur...,Truc simple pour avaler les pilules: La techni...,"Swallowing pills can hard for many children, y...","Il n’est pas seul! Beaucoup d’enfants, de jeun...",\r\n\t\r\n\t\tDoes your child or teen have pro...,\r\n\t\r\n\t\tVotre enfant a-t-il de la diffic...,0
3,8920,Abuse and Domestic Violence,Maltraitance et violence familiale,"Abuse is behaviour used to intimidate, isolate...",La maltraitance est un comportement visant à i...,\r\n\tWhat is Abuse and Domestic Violence?\r\n...,\r\n\tQu&#39;est-ce que la maltraitance et la ...,21958876509365437
4,69660,"ADHD in Children, Youth and Adults: Informatio...",,Attention deficit hyperactivity disorder (ADHD...,,"\r\n\tAbbreviations\r\n\r\n\tADHD, attention-d...",,13
5,20505,ADHD: Managing ADHD Medications,TDAH: Gestion des médicaments pour le TDAH,Although ADHD medications may be helpful for s...,Bien que les médicaments du TDAH peuvent être ...,\r\n\tGeneral Considerations\r\n\r\n\t\r\n\t\t...,\r\n\tObservations générales\r\n\r\n\t\r\n\t\t...,13
6,8917,Alcohol and Substance Use Problems in Children...,Problèmes de l’alcoolisme et la toxicomanie ch...,Many children/youth experiment with alcohol an...,Nombreux sont les enfants et les jeunes qui fo...,\r\n\tDavid's Story\r\n\r\n\tUp until this sch...,\r\n\tL’histoire de David\r\n\r\n\tJusqu’à cet...,12
7,8909,Alcohol: Cutting Back or Quitting Drinking,Problèmes liés à la consommation d'alcool,Drinking alcohol is an accepted practice in ma...,La consommation d'alcool est une pratique acce...,\r\n\tIntroduction\r\n\r\n\tMany people who dr...,\r\n\tIntroduction\r\n\r\n\tNombreuses sont le...,12
8,61003,Alcohol Use Disorder in Adults: Information fo...,,Alcohol use problems are common in primary car...,,\r\n\tEpidemiology\r\n\r\n\tPrevalence\r\n\r\n...,,12266
9,18393,Alcohol Use Disorder in Youth: Information for...,,Many youth will experiment with alcohol and su...,,Epidemiology\r\n\r\n\t\r\n\t\t10% of the&nbsp;...,,12266


In [3]:
taxonomy = pd.read_json('data/taxonomy_headings.json')
taxonomy = taxonomy.drop(['created_at',	'updated_at',	'deleted_at', 'alias_of_id', 'short_description',	'original_id'], axis=1)

taxonomy.head(10)

Unnamed: 0,id,name,description,translations
0,1,Root,Root,"{""name"":{""en"":""Root"",""fr"":null},""description"":..."
1,2,All Mental Health Resources,<p>\r\n\tThe listings of mental health resourc...,"{""name"":{""en"":""All Mental Health Resources"",""f..."
2,3,Crisis and Emergency,<p>\r\n\tRefers to all programs that provide i...,"{""name"":{""en"":""Crisis and Emergency"",""fr"":""Res..."
3,4,"System Navigation, including Information and R...","<p>\r\n\tAre you looking for help, but don&#39...","{""name"":{""en"":""System Navigation, including In..."
4,5,Child Welfare including Children's Aid Society...,<p>The child welfare / child protection system...,"{""name"":{""en"":""Child Welfare including Childre..."
5,6,Emergency Shelter and Housing,<p>\r\n\tThere are various shelters that peopl...,"{""name"":{""en"":""Emergency Shelter and Housing"",..."
6,7,Hospital Emergency Department,<p>\r\n\tIs there an emergency such as medical...,"{""name"":{""en"":""Hospital Emergency Department"",..."
7,8,"Crisis Lines including Telephone, Online and Chat",<p>\r\n\tAre you in a crisis? Crisis lines off...,"{""name"":{""en"":""Crisis Lines including Telephon..."
8,9,Psychiatrists,<p>\r\n\tPsychiatrists are medical doctors who...,"{""name"":{""en"":""Psychiatrists"",""fr"":""Psychiatre..."
9,10,A-Z Mental Health Conditions and Topics,<p>\r\n\tAlphabetical list of mental health to...,"{""name"":{""en"":""A-Z Mental Health Conditions an..."


## Pre-processing

In [4]:
from pre_processing import remove_empty, remove_HTML, remove_new_line

infoSheet['description_en'] = infoSheet['description_en'].apply(lambda x: remove_HTML(x))
infoSheet['description_en'] = infoSheet['description_en'].apply(lambda x: remove_new_line(x))

print('Length of taxonomy before preprocessing:', len(taxonomy.index))
taxonomy['description'] = taxonomy['description'].apply(lambda x: remove_HTML(x))
taxonomy['description'] = taxonomy['description'].apply(lambda x: remove_new_line(x))
# taxonomy = remove_empty('description', taxonomy)
print('Length of taxonomy after preprocessing:', len(taxonomy.index))

Length of taxonomy before preprocessing: 277
Length of taxonomy after preprocessing: 277


## Helper functions

In [5]:
def find_largest_numbers(lst):
    # Create a list of tuples containing numbers and their indices
    indexed_numbers = [(num, index) for index, num in enumerate(lst)]
    
    # Sort the list in descending order based on the numbers
    sorted_numbers = sorted(indexed_numbers, key=lambda x: x[0], reverse=True)
    
    # Extract the ten largest numbers and their indices
    largest_numbers = sorted_numbers[:10]

    return largest_numbers

## Prediction

In [6]:
# Get models - The package will take care of downloading the models automatically
# For best performance: EleutherAI/gpt-j-6B
# EleutherAI/gpt-neo-1.3B
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B").to(device)
# Deactivate Dropout (There is no dropout in the above models so it makes no difference here but other SGPT models may have dropout)
model.eval()

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 2048)
    (wpe): Embedding(2048, 2048)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPTNeoBlock(
        (ln_1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
          )
        )
        (ln_2): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=2048, out_features=8192, bias=True)
          (c_proj):

In [7]:
queries = [
    (1, "I'm searching for a metric for information retrieval."),
]

docs = [
    (2, "The ability of an instrument to measure the accurate value is known as accuracy. In other words, it is the the closeness of the measured value to a standard or true value."),
    (3, "In data analysis, cosine similarity is a measure of similarity between two non-zero vectors defined in an inner product space. Cosine similarity is the cosine of the angle between the vectors; that is, it is the dot product of the vectors divided by the product of their lengths. It follows that the cosine similarity does not depend on the magnitudes of the vectors, but only on their angle. The cosine similarity always belongs to the interval [-1,1]. For example, two proportional vectors have a cosine similarity of 1, two orthogonal vectors have a similarity of 0, and two opposite vectors have a similarity of -1. In some contexts, the component values of the vectors cannot be negative, in which case the cosine similarity is bounded in [0,1]."),
    (1, "The closeness of two or more measurements to each other is known as the precision of a substance. If you weigh a given substance five times and get 3.2 kg each time, then your measurement is very precise but not necessarily accurate. Precision is independent of accuracy."),
]

In [8]:
# queries = [(id, text), (), ...]
# docs = [(id, text), (), ...]

def search(queries, docs, topK=10):
    prompt = 'Documents are searched to find matches with the same content.\nThe document "{}" is a good search result for "'
    
    if not isinstance(queries, list):
        queries = [queries]
    if not isinstance(docs, list):
        docs = [docs]
    if topK < len(docs):
        topK = len(docs)

    result = {'queryIDs': [], 'hits': []}
    for (queryID, query) in queries:
        result['queryIDs'].append(queryID)
        tmp = []
        for (docID, doc) in docs:
            context = prompt.format(doc)

            context_enc = tokenizer.encode(context, add_special_tokens=False)
            continuation_enc = tokenizer.encode(query, add_special_tokens=False)
            # Slice off the last token, as we take its probability from the one before
            model_input = torch.tensor(context_enc+continuation_enc[:-1]).to(device)
            continuation_len = len(continuation_enc)
            input_len, = model_input.shape

            # [seq_len] -> [seq_len, vocab]
            logprobs = torch.nn.functional.log_softmax(model(model_input)[0], dim=-1).cpu()
            # [seq_len, vocab] -> [continuation_len, vocab]
            logprobs = logprobs[input_len-continuation_len:]
            # Gather the log probabilities of the continuation tokens -> [continuation_len]
            logprobs = torch.gather(logprobs, 1, torch.tensor(continuation_enc).unsqueeze(-1)).squeeze(-1)
            score = torch.sum(logprobs).item()
            # The higher (closer to 0), the more similar
            print(f"Document: {doc[:20] + '...'} Score: {score}")
            tmp.append([docID, score])
        tmp = sorted(tmp, key=lambda x: x[1], reverse=True)[:topK]
        result['hits'].append(tmp)
    
    return result

In [9]:
search(queries, docs)

Document: The ability of an in... Score: -40.480674743652344
Document: In data analysis, co... Score: -33.243385314941406
Document: The closeness of two... Score: -46.34759521484375


{'queryIDs': [1],
 'hits': [[[3, -33.243385314941406],
   [2, -40.480674743652344],
   [1, -46.34759521484375]]]}

In [10]:
# Randomly choose 10 infoSheets
search_term_indices = [random.randint(0, len(infoSheet.index)) for i in range(10)]
search_term_indices = [infoSheet['ID'][each] for each in search_term_indices]
search_term_indices

[49678, 52861, 54525, 55711, 24356, 86104, 86104, 53112, 26730, 24894]