In [1]:
'''
defining experiments
'''
SEED = 44
def experiments():
    experiments = [
        {"name":"Sentence similarity",
             "strategy": "one_vs_all",
             "model_type":"similarity", 
             "model":"sentence-transformers/all-MiniLM-L6-v2"},
        {"name":"Zero-shot simple", 
             "strategy": "expand_parents", 
             "model_type":"zero-shot", 
             "model":"facebook/bart-large-mnli"},
        {"name":"Zero-shot hypothesis template", 
             "strategy": "expand_parents", 
             "model_type":"zero-shot-hypothesis", 
             "model":"MoritzLaurer/deberta-v3-large-zeroshot-v2.0"},
    ]
import torch
torch.cuda.is_available()


True

In [8]:
import torch
import pickle
import pandas as pd
import gc
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

force_cpu = False
device = "cuda" if torch.cuda.is_available() and not force_cpu else "cpu"
print(f"Device: {device}")

def main():
    terms = get_terms_df()
    datasets = load_datasets()
    
    for exp in experiments():
        flush() #ensuring memory is freed before loading a model
        model = load_model(exp["model"], exp["model_type"], device)
        
        for ds in datasets.values():
            hierarchy = ds["hierarchy"]
            training_df = ds["training"]
            test_df = ds["test"]
    
def predict_with_strategy(hierarchy_name, training_df, testing_df, terms_df, strategy, n, model, model_type):
    if testing_df is None:
       print("skipping hierarcy {hierarchy_name}") 
    elif strategy == "one_vs_all":
        hierarchy_codes = terms_df.hierarchiCode.unique().asList()
        assert(len(hierarchy_codes)==1)
        hierarchy_code = hierarchyi_codes[0]
        text_map = get_all(terms_df, hierarchy_code)  
        categories = [* text_codes.keys]
        input_texts = testing_df.text.asList()
        predicted_texts = predict(input_texts, categories, hierarchy_name, m, model, model_type)
        return [[text_map[text] for text in predicted] for predicted in predicted_texts]

def load_model(model, model_type, device):
    print("loading model") 
    if model_type == "similarity":
        m =  SentenceTransformer(model, device = device)
    elif model_type.startswith("zero-shot"):
        m = pipeline("zero-shot-classification", model=model, device = device)
    else:
        raise ValueError("Unknown model type")
    print(f"{model} loaded") 
    return m
    
def predict(texts, categories, hierarchy_name, n, model, model_type):
    results = [*range(0,len(texts))]
    if mode_type == "similarity":
        text_vecs =  model.encode(texts, convert_to_tensor = True)
        cat_vecs = model.encode(categories, convert_to_tensor = True)
        for i in range(0, len(results)):
            similarities = util.pytorch_cos_sim(text_vecs[i], cat_vecs[i]).item(0)
            topn = [categories[j] for sim, j in sorted(((sim, j) for j, sim in enumerate(similarities)), reverse=True)][0:n]
            results[i] = topn
    elif model_type == "zero-shot":
        for i in range(0, len(results)):
            output = classifier(text, categories)
            topn = [i for v, i in sorted(((v, i) for i, v in enumerate(output["scores"])), reverse=True)][0:n]
            results[i] = [output['labels'][i] for i in topn]
    elif model_type == "zero-shot-hypothesis":
        for i in range(0, len(results)):
            hypothesis_template = get_hypothesis_template(hierarchy_name)
            output = m_classifier(text, categories, hypothesis_template=hypothesis_template, multi_label=False)
            topn = [i for v, i in sorted(((v, i) for i, v in enumerate(output["scores"])), reverse=True)][0:n]
            results[i] = [output['labels'][i] for i in topn]
    else:
        raise ValueError(f"the model type {model_type} is not supported")
    return results

    
    


def predict_with_strategy(text, hierarchy, n, first = True):
    candidates = get_children({"root"})
    if len(candidates)==0:
        return parents
    cand_en = [*candidates.keys()]
    bests = best_matches(ccam_en, [d for d,c in candidates.items()], top)
    best_codes = [(candidates[text], text) for text in bests]

    print(ccam_en)
    for t in best_codes:
        print(f"---------> {t}")

    rest = search_snomed_hierarchy(ccam_en, best_codes, top, first = False)
    ret = [(code, text) for code, text in [*best_codes, *rest]]
    if first: # final refinement
        retmap = {text:code for code, text in ret}
        bests = best_matches(ccam_en, [d for d, c in retmap.items()], top)
        best_codes = [(retmap[text], text) for text in bests]
        return best_codes
    else:
        return ret

def get_all(terms_df, hierarchy_code):
    #terms_df columns: termCode 	parentCode 	hierarchyCode 	status
    assert(terms_df.hierarchyCode.nunique()==1)
    children = terms_df[
            (terms_df.hierarchyCode == hierarchy_code) & 
            (terms_df.status == "APPROVED" ) & 
        ][["termExtendedName", "termCode"]]
    return {k:v for k, v in children.values}

def get_children(parent_codes, hierarchy_code, terms_df):
    #terms_df columns: termCode 	parentCode 	hierarchyCode 	status
    assert(terms_df.hierarchyCode.nunique()==1)
    children = terms_df[
            (terms_df.hierarchyCode == hierarchy_code) & 
            (terms_df.status == "APPROVED" ) & 
            (terms_df.parentCode.isin(parent_codes))
        ][["termExtendedName", "termCode"]]
    return {k:v for k, v in children.values}

def get_hierarchy_question(hierarchy):
    if hierarchy == "baseterm":	
        return "This text describes a"
    elif hierarchy == "F02":	#part	Part-nature	This facet describes the nature of the food item or the part of plant or animal it represents.
        return "This is obtained from {}"
    elif hierarchy == "F01":    #source: Source	This facet describes the plant, animal, other organism or other source from which a raw primary commodity 
        return "This is mainly obtained from {}"
    elif hierarchy == "F27":   #racsource	Source-commodities	This facet describes the RPC from which an ingredient or derivative has been obtained.
        return "This is derivated from {}"
    elif hierarchy == "F28":    #process	Process	This facet allows recording different characteristics of the food: preservation treatments a food item underwent
        return "This is processed by {}"
    elif hierarchy == "F04":    #ingred	Ingredient	This facet collects ingredients and/or flavour note.
        return "This contains {}"
    elif hierarchy == "F06":	    #medium	Surrounding-medium	This facet is intended for food packed in any container, together with any additional (    fluid) medium.
        return "This is sell surrounded by {}"
    elif hierarchy == "F08":	    #sweet	Sweetening-agent	This facet allows providing information on the added ingredient(s) used to impart sweetness to a food item.
        return "This can be sweeteded with {}"
    elif hierarchy == "F09":	    #fort	Fortification-agent	This facet allows providing information on the added ingredient(s) used to fortify a food item.
        return "This can be fortified with {}"
    elif hierarchy == "F10":	#qual	Qualitative-info	This facet provides some principal claims related to important nutrients-ingredients, like fat, sugar etc.
        return "When eated provides {}"
    elif hierarchy == "F17":	#cookext	Extent-of-cooking	This facet describes the intensity of heat treatment having been applied to a food item”.
        return "This is be cooked by {}"
    elif hierarchy == "F26":    #	gen	Generic-term	This facet allows recording whether the food list code was chosen because of lack of information on the food item.
        return "This description is ambiguous because {}"
    elif hierarchy == "F21":    #prod	Production-method	The facet production method describes the method used to produce the food.
        return "This text is ambiguous because {}"
    elif hierarchy == "F18":    #packformat	Packaging-format	This facet is used for packaged food and allows recording the container or wrapping form.
        return "This is sell in a {}"
    elif hierarchy == "F19":	#packmat	Packaging-material	This facet is used for packaged food and allows recording the material constituting the packaging containing.
        return "This is sell in a package made of {}"
    elif hierarchy == "F03":     #	state	Physical-state	This facet describes the form (physical aspect) of the food as reported by the consumer .
        return "This seems like a {}"
    elif hierarchy == "F07":     #fat	Fat-content	This is a facet with numerical descriptors, to allow providing the fat content (as percentage w/w) of a food item.
        return "This contains a fat level of {}"
    elif hierarchy == "F11":    #alcohol	Alcohol-content	This is a facet containing information to allow providing the alcohol (ethanol) content (as percentage v/v) of a food item.
        return "This contains an alcohol level of {}"
    elif hierarchy == "F12":	#dough	Dough-Mass	This facet is proposed to provide information on the original dough-mass, for bakery products.
        return "This contains a dough of {}"
    elif hierarchy == "F20":	#partcon	Part-consumed-analysed	this facet allows specifying in which form the food item was analysed or consumed.
        return "This is evaluated by analyzing its {}"
    elif hierarchy == "F22":    #place	Preparation-production-place	This facet allows recording the place where the food was prepared for consumption.
        return "This prepared in a {}"
    elif hierarchy == "F23":    #targcon	Target-consumer	This facet allows recording different consumer classes intended as target for the food item.
        return "This is eated by {}"
    elif hierarchy == "F24":    #use	Intended-use	This facet allows recording the intended use of a food item, in particular with respect to further treatment expected (or not expected) before consumption.
        return "This can gone through {}"
    elif hierarchy == "F25":	#riskingred	Risky-Ingredient	This facet (of specific interest in the microbiological domain) allows recording the presence of microbiologically high-risk ingredients.
        return "This is made with a dangerous {}"
    elif hierarchy == "F29":	#fpurpose	Purpose-of-raising	This facet allows recording the purpose of farming, keeping or breeding (e.g. milk production, egg production).
        return "This is farmed for {}"
    elif hierarchy == "F30":	#replev	Reproductive-level	This facet allows recording classes of animals from the point of view of reproduction.
        return "This animal can reproduce by {}"
    elif hierarchy == "F31":    #	animage	Animal-age-class	This facet allows recording the classes of the animal used in legislation or in the practice, based on age or development stage.
        return "This animal age is {}"
    elif hierarchy == "F32":    #	gender	Gender	This facet allows recording the status of an animal or animal group, with respect to sex.
        return "This animal gender is {}"
    elif hierarchy == "F33":	    #legis	Legislative-classes	This facet allows recording the food additives classes as reported in the legislation in order.
        return "This contains the additive of type {}"
    else:
        raise ValueError(f"The hierarchy {hierarchy} has no question defined")

def get_terms_df():
    df = pd.read_pickle("data/terms.pickle")
    return df[["termCode", "termExtendedName", "parentCode", "hierarchyCode", "status"]]

def load_datasets():
    f_datasets = "data/datasets-training-test.pickle"
    with open(f_datasets, "rb") as f:
        datasets = pickle.load(f)
    return datasets

def flush():
  gc.collect()
  torch.cuda.empty_cache()
  torch.cuda.reset_peak_memory_stats()

terms_df()

#main()


Device: cuda


Unnamed: 0,termCode,parentCode,hierarchyCode,status
0,A000C,A000B,MTX,DEPRECATED
1,A000G,A000F,MTX,APPROVED
2,A000H,A000F,MTX,APPROVED
3,A001X,A000L,MTX,APPROVED
4,A0D9Y,A000L,MTX,APPROVED
5,A04KH,A000L,MTX,APPROVED
6,A000Y,A000L,MTX,APPROVED
7,A000S,A000L,MTX,APPROVED
8,A000F,A000L,MTX,APPROVED
9,A001C,A000L,MTX,APPROVED
