In [61]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
from scipy.spatial.distance import cosine
import random

device = torch.device("mps" if getattr(torch,'has_mps',False) else "cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='mps')

## Load data

In [62]:
taxonomy = pd.read_json('data/taxonomy_headings.json')
taxonomy = taxonomy.drop(['created_at',	'updated_at',	'deleted_at', 'alias_of_id', 'short_description',	'original_id'], axis=1)

taxonomy.head(10)

Unnamed: 0,id,name,description,translations
0,1,Root,Root,"{""name"":{""en"":""Root"",""fr"":null},""description"":..."
1,2,All Mental Health Resources,<p>\r\n\tThe listings of mental health resourc...,"{""name"":{""en"":""All Mental Health Resources"",""f..."
2,3,Crisis and Emergency,<p>\r\n\tRefers to all programs that provide i...,"{""name"":{""en"":""Crisis and Emergency"",""fr"":""Res..."
3,4,"System Navigation, including Information and R...","<p>\r\n\tAre you looking for help, but don&#39...","{""name"":{""en"":""System Navigation, including In..."
4,5,Child Welfare including Children's Aid Society...,<p>The child welfare / child protection system...,"{""name"":{""en"":""Child Welfare including Childre..."
5,6,Emergency Shelter and Housing,<p>\r\n\tThere are various shelters that peopl...,"{""name"":{""en"":""Emergency Shelter and Housing"",..."
6,7,Hospital Emergency Department,<p>\r\n\tIs there an emergency such as medical...,"{""name"":{""en"":""Hospital Emergency Department"",..."
7,8,"Crisis Lines including Telephone, Online and Chat",<p>\r\n\tAre you in a crisis? Crisis lines off...,"{""name"":{""en"":""Crisis Lines including Telephon..."
8,9,Psychiatrists,<p>\r\n\tPsychiatrists are medical doctors who...,"{""name"":{""en"":""Psychiatrists"",""fr"":""Psychiatre..."
9,10,A-Z Mental Health Conditions and Topics,<p>\r\n\tAlphabetical list of mental health to...,"{""name"":{""en"":""A-Z Mental Health Conditions an..."


In [63]:
infoSheet = pd.read_csv("data/infoSheets_2023-05-18.csv")
print(infoSheet.isnull().sum())
# infoSheet = infoSheet.dropna(subset=['abstract_en'])
# infoSheet.reset_index(drop=True, inplace=True)
print('\nNumber of rows: ', len(infoSheet.index))
infoSheet.head(5)

ID                        0
name_en                   0
name_fr                 138
abstract_en              20
abstract_fr             146
description_en            0
description_fr          140
taxonomy heading ids      0
dtype: int64

Number of rows:  346


Unnamed: 0,ID,name_en,name_fr,abstract_en,abstract_fr,description_en,description_fr,taxonomy heading ids
0,84606,ADHD Medication Side Effects: Low Appetite and...,,Stimulants prescribed for ADHD can lead to red...,,Background\r\nStimulant medications for attent...,,0
1,92619,5-HTP (5-hydroxytryptophan),,5-HTP (5-Hydroxytryptophan) is a natural subst...,,What is 5-HTP?\r\n5-HTP (5-Hydroxytryptophan) ...,,0
2,50150,A Simple Way to Swallow Pills: The Head Postur...,Truc simple pour avaler les pilules: La techni...,"Swallowing pills can hard for many children, y...","Il n’est pas seul! Beaucoup d’enfants, de jeun...",\r\n\t\r\n\t\tDoes your child or teen have pro...,\r\n\t\r\n\t\tVotre enfant a-t-il de la diffic...,0
3,8920,Abuse and Domestic Violence,Maltraitance et violence familiale,"Abuse is behaviour used to intimidate, isolate...",La maltraitance est un comportement visant à i...,\r\n\tWhat is Abuse and Domestic Violence?\r\n...,\r\n\tQu&#39;est-ce que la maltraitance et la ...,21958876509365437
4,69660,"ADHD in Children, Youth and Adults: Informatio...",,Attention deficit hyperactivity disorder (ADHD...,,"\r\n\tAbbreviations\r\n\r\n\tADHD, attention-d...",,13


## Helper function

In [64]:
def find_largest_numbers(lst):
    # Create a list of tuples containing numbers and their indices
    indexed_numbers = [(num, index) for index, num in enumerate(lst)]
    
    # Sort the list in descending order based on the numbers
    sorted_numbers = sorted(indexed_numbers, key=lambda x: x[0], reverse=True)
    
    # Extract the ten largest numbers and their indices
    largest_numbers = sorted_numbers[:10]

    return largest_numbers

## Load embeddings

In [65]:
infoSheet_embeddings = torch.load('data/embeddings/sgpt_infoSheet_embeddings.pt')
print(infoSheet_embeddings.shape)

taxonomy_embeddings = torch.load('data/embeddings/sgpt_taxonomy_embeddings.pt')
print(taxonomy_embeddings.shape)

torch.Size([326, 2048])
torch.Size([192, 2048])


## Prediction

In [66]:
# Randomly choose 10 info sheets
search_term_indices = [random.randint(0, len(infoSheet_embeddings)) for i in range(10)]
search_term_indices

[311, 312, 18, 242, 185, 4, 163, 286, 206, 175]

In [67]:
predictions = {'infoSheet_id': [], 'pred_taxonomy_id': [], 'similarity_score': [], 'gold_taxonomy_id': []}

for search_term_idx in search_term_indices:
    cos_sim = []
    for taxonomy_idx in range(len(taxonomy_embeddings)):
        cos_sim.append(1 - cosine(infoSheet_embeddings[search_term_idx], taxonomy_embeddings[taxonomy_idx]))
    
    lst = find_largest_numbers(cos_sim)
    for each in lst:
        if each[0] >= 0.7:
            predictions['infoSheet_id'].append(search_term_idx)
            predictions['pred_taxonomy_id'].append(each[1])
            predictions['similarity_score'].append(each[0])
            predictions['gold_taxonomy_id'].append(infoSheet['taxonomy heading ids'][search_term_idx].split(','))

print('Length of predictions: ', len(predictions['infoSheet_id']))
predictions = pd.DataFrame(predictions)
predictions.head(10)

Length of predictions:  11


Unnamed: 0,infoSheet_id,pred_taxonomy_id,similarity_score,gold_taxonomy_id
0,312,67,0.904687,"[0, 59]"
1,18,74,0.728077,"[0, 75, 50, 9, 36, 54, 37]"
2,18,15,0.712463,"[0, 75, 50, 9, 36, 54, 37]"
3,18,171,0.706356,"[0, 75, 50, 9, 36, 54, 37]"
4,185,176,0.72483,"[0, 59, 3, 207]"
5,4,12,0.820638,"[0, 13]"
6,163,58,0.718912,"[0, 276]"
7,163,43,0.700634,"[0, 276]"
8,286,147,0.712204,[0]
9,175,176,0.835282,[0]


## Evaluation

In [68]:
infoSheet.loc[infoSheet.index == 312]

Unnamed: 0,ID,name_en,name_fr,abstract_en,abstract_fr,description_en,description_fr,taxonomy heading ids
312,8911,Supporting a Family Member or Friend with Ment...,Soutenir un ou une membre de la famille ou ami...,Chances are high that someone you know has a m...,Il est fort probable que vous connaissiez quel...,\r\n\tIntroduction\r\n\r\n\tIf someone you lov...,\r\n\tIntroduction\r\n\r\n\tSi une personne qu...,59


In [70]:
taxonomy.loc[taxonomy.index == 59]

Unnamed: 0,id,name,description,translations
59,60,Supportive Counselling,"<p class=""MsoNormal"">\r\n\tSupportive counsell...","{""name"":{""en"":""Supportive Counselling"",""fr"":""C..."


In [52]:
retrieved_relevant = 0

for idx, row in predictions.iterrows():
    if row['pred_taxonomy_id'] + 1 in row['gold_taxonomy_id']:
        retrieved_relevant += 1

print(retrieved_relevant)
print('Precision: ' + str(retrieved_relevant / len(predictions.index)))

0
Precision: 0.0
