In [None]:
# !pip install chromadb

In [1]:
import os
from getpass import getpass

from huggingface_hub import login

# Never store the token in the notebook itself: read it from the HF_TOKEN
# environment variable, and fall back to an interactive (non-echoing) prompt.
access_token_write = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")
login(token=access_token_write)

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /home/id2531/.cache/huggingface/token
Login successful


# IMPORTS

## PACKAGES

In [18]:

from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
import os
import numpy as np
import pandas as pd
import chromadb
import transformers
import torch
import pickle as pkl

## DEVICE

In [3]:
# Select the XLA device and display it.
# NOTE(review): `xm` is not imported in any visible cell. Presumably it is
# `torch_xla.core.xla_model` (the recorded output shows an 'xla' device) — confirm
# and add the import to the imports cell so the notebook survives Restart & Run All.
# NOTE(review): `device` is not referenced by any later visible cell.
device = xm.xla_device()
device



device(type='xla', index=0)

## MODEL

In [None]:
# Hugging Face Hub id of the chat model used for generation.
model = "meta-llama/Llama-2-7b-chat-hf"
# Tokenizer loaded from the same checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(model)
# Text-generation pipeline built from the model id: weights are requested in
# float32 and device placement is delegated to `device_map="auto"`.
pipeline = transformers.pipeline(
    "text-generation", # specify the task for inference
    model=model,
    torch_dtype=torch.float32,
    device_map="auto",
)

# FUNCTIONS DEFINITION

## PROMPT EXAMPLES
f"""Consider three examples [EXAMPLE 1], [EXAMPLE 2], [EXAMPLE 3], instances of narratives from NASA's ASRS database labeled as "{label}". Your objective is to assess if a [NEW NARRATIVE] can be assigned the same label as these examples. Return 1 if it can, and 0 otherwise."""

f"""Given three examples [EXAMPLE 1], [EXAMPLE 2], [EXAMPLE 3] which are instances of narratives from NASA's ASRS database labeled as "{label}", your task is to determine whether a [NEW NARRATIVE] can be assigned the same label as these examples. Return 1 if the [NEW NARRATIVE] can be assigned the same label as the examples and 0 otherwise."""

In [1]:
def labeller(**kwargs):
    """Turn a ';'-separated anomaly cell into a ``{label: 1}`` metadata dict.

    Keyword arguments are recognised by value type / name rather than position:
      * any ``str`` value is taken as the cell text (the last one wins);
      * a ``list`` value whose keyword name contains "anomalies" is the list
        of label prefixes to look for.

    A prefix is kept when at least one ';'-separated entry of the cell, once
    stripped, starts with it. When nothing matches, the single label
    'Others' is used.

    Returns:
        dict mapping every retained label to 1.
    """
    for key, value in kwargs.items():
        if isinstance(value, str):
            text = value
        if isinstance(value, list) and 'anomalies' in key.lower():
            prefixes = value
    matched = [
        candidate
        for candidate in prefixes
        if any(entry.strip().startswith(candidate) for entry in text.split(';'))
    ]
    if not matched:
        matched = ['Others']
    return dict.fromkeys(matched, 1)

def docextractor(**kwargs):
    """Return the document text with surrounding whitespace removed.

    Every keyword value is inspected; the last one that is a ``str`` is
    treated as the document, whatever its keyword name.
    """
    for value in kwargs.values():
        if isinstance(value, str):
            text = value
    return text.strip()

def load_from_df(**kwargs):
    """Extract documents and per-document label metadata from a DataFrame.

    Keyword arguments are recognised by name fragment or by type:
      * name containing "docs": column holding the document text;
      * name containing "metadata": column holding the anomaly cell;
      * any ``pd.DataFrame`` value: the source table;
      * a ``list`` under a name containing "anomalies": label prefixes
        forwarded to ``labeller``.

    Returns:
        (documents, metadata): two lists aligned row by row. Documents come
        from ``docextractor``; metadata dicts come from ``labeller``.
    """
    for key, value in kwargs.items():
        lowered = key.lower()
        if "docs" in lowered:
            text_column = value
        if "metadata" in lowered:
            label_column = value
        if isinstance(value, pd.DataFrame):
            frame = value
        if 'anomalies' in lowered and isinstance(value, list):
            prefixes = value

    def _to_document(cell):
        return docextractor(input=cell)

    def _to_metadata(cell):
        return labeller(input=cell, anomalies_list=prefixes)

    documents = frame[text_column].apply(_to_document).values.tolist()
    metadata = frame[label_column].apply(_to_metadata).values.tolist()
    return documents, metadata

def get_prompt(**kwargs):
    """Build the few-shot [INST] prompt asking whether a narrative shares a label.

    Keyword arguments are matched by name fragment (case-insensitive):
      * "example_1", "example_2", "example_3": the three reference narratives;
      * "narrative": the new narrative to assess;
      * "label": the label the three examples carry.

    Returns:
        str: the full prompt, starting with "[INST]" and ending with
        "[/INST][Answer]" so the model's continuation is the answer.
    """
    for name, item in kwargs.items():
        if 'example_1' in name.lower():
            example_1 = item
        if 'example_2' in name.lower():
            example_2 = item
        if 'example_3' in name.lower():
            example_3 = item
        if 'narrative' in name.lower():
            narrative = item
        if 'label' in name.lower():
            label = item
    # The label is wrapped in a balanced pair of quotes (the original had only a
    # stray closing quote after {label}).
    system_prompt = f"""Consider three examples [EXAMPLE 1], [EXAMPLE 2], [EXAMPLE 3], instances of narratives from NASA's ASRS database labeled as "{label}". Your objective is to assess if a [NEW NARRATIVE] can be assigned the same label as these examples. As [Answer] return 1 if it can, and 0 otherwise."""  #  Only provide the numerical result (0 or 1) and no additional information.
    full_prompt = f"""[INST]
    <<SYS>>
    {system_prompt}
    <</SYS>>
    
    [EXAMPLE 1] :
    {example_1}
    
    [EXAMPLE 2] :
    {example_2}

    [EXAMPLE 3] :
    {example_3}
    
    [NEW NARRATIVE] :
    {narrative}
    
    [/INST][Answer]"""
    return full_prompt

def format_pred(**kwargs):
    """Convert raw pipeline output into a binary prediction.

    Keyword arguments are matched by name fragment (case-insensitive):
      * "inference": the pipeline output; ``inference[0]['generated_text']``
        is the text that gets searched;
      * "pattern": substring whose presence means a positive prediction.

    Returns:
        (pred, inference): ``pred`` is 1 when the pattern occurs in the
        stripped generated text and 0 otherwise; ``inference`` is returned
        untouched.
    """
    for key, value in kwargs.items():
        lowered = key.lower()
        if "inference" in lowered:
            inference = value
        if "pattern" in lowered:
            pattern = value
    generated = inference[0]['generated_text'].strip()
    pred = int(pattern in generated)
    return pred, inference
            

def get_inference(**kwargs):
    """Classify one narrative against one label using retrieved examples.

    Keyword arguments are matched by name fragment (case-insensitive) or type:
      * "query": narrative text to classify (stripped);
      * "pattern": substring that marks a positive answer (stripped);
      * "task": the label; used both as the metadata filter key for retrieval
        and as the label named in the prompt (stripped);
      * a chromadb ``Collection`` value: the vector store to query;
      * a transformers ``TextGenerationPipeline`` value: the generator.

    The three nearest stored documents whose metadata has ``{task: 1}`` are
    used as the examples in the prompt built by ``get_prompt``.

    Returns:
        The ``(pred, inference)`` tuple produced by ``format_pred``.

    Raises:
        IndexError: if the store returns fewer than three documents.
    """
    for name, item in kwargs.items():
        if "query" in name.lower():
            query = item.strip()
        if "pattern" in name.lower():
            pattern = item.strip()
        if "task" in name.lower():
            task = item.strip()
        if isinstance(item,
                      chromadb.api.models.Collection.Collection):
            store = item
        if isinstance(item,
                      transformers.pipelines.text_generation.TextGenerationPipeline):
            pipeline = item
    # Only documents tagged with the requested label are candidate examples.
    examples = store.query(query_texts=[query],
                          include=["documents"],
                          where={task:1},
                          n_results=3)['documents'][0]
    prompt = get_prompt(**dict(narrative=query,
                               example_1=examples[0],
                               example_2=examples[1],
                               example_3=examples[2],
                               label=task))
    # Use the tokenizer that belongs to the pipeline passed in, not a notebook
    # global that may be missing or stale on a fresh kernel.
    # Sampling is enabled, so repeated calls may give different generations.
    inference = pipeline(prompt,
                         temperature=1.0,
                         do_sample=True,
                         num_return_sequences=1,
                         eos_token_id=pipeline.tokenizer.eos_token_id,
                         max_length=4000)
    pred = format_pred(inference=inference, pattern=pattern)
    return pred

In [None]:
# Label prefixes matched against the start of each ';'-separated entry of the
# 'Anomaly' column. Must remain a list: `labeller` only accepts list values.
ANOMALY_LABELS = [
    "Deviation / Discrepancy - Procedural",
    "Aircraft Equipment",
    "Conflict",
    "Inflight Event / Encounter",
    "ATC Issue",
    "Deviation - Altitude",
    "Deviation - Track / Heading",
    "Ground Event / Encounter",
    "Flight Deck / Cabin / Aircraft Event",
    "Ground Incursion",
    "Airspace Violation",
    "Deviation - Speed",
    "Ground Excursion",
    "No Specific Anomaly Occurred",
]

# LOAD DATA

In [6]:
# Load the training data. It must contain at least the 'Narrative' and 'Anomaly'
# columns; update the column labels to match those names if necessary.
# NOTE: the path below is an empty placeholder and must be filled in before running.
train_data = pd.read_parquet('')

# CONTENT

In [7]:
# Persistent Chroma client rooted at the given directory.
# NOTE: the path is a placeholder — replace it with a real location before running.
client = chromadb.PersistentClient(path="path/where/to/put/the/persitent/vectorstore")

In [None]:
embedder = DefaultEmbeddingFunction() # embedding model

# Cosine-distance collection; reused if it already exists in the persistent store.
asrsnlp_collection = client.get_or_create_collection(
    name="asrsnlp_collection",
    metadata={"hnsw:space": "cosine"},
    embedding_function=embedder)

# Rows missing either the text or the labels cannot be indexed.
train_data = train_data.dropna(axis=0, subset=['Narrative','Anomaly'])
# `documents` is a (texts, metadatas) pair of row-aligned lists.
documents = load_from_df(df=train_data, docs='Narrative',metadata='Anomaly', anomalies=ANOMALY_LABELS)

# Insert in fixed-size batches driven by the actual number of rows, instead of
# three copy-pasted calls with a hardcoded total (96986) that breaks as soon as
# the dataset size changes. IDs are "ID<row position>", as before.
# 40000 is the chunk size the original calls used; presumably chosen to stay
# under Chroma's per-call batch limit — confirm before raising it.
BATCH_SIZE = 40000
n_documents = len(documents[0])
for start in range(0, n_documents, BATCH_SIZE):
    stop = min(start + BATCH_SIZE, n_documents)
    asrsnlp_collection.add(
        documents=documents[0][start:stop],
        metadatas=documents[1][start:stop],
        ids=[f"ID{i}" for i in range(start, stop)]
    )

In [9]:
# asrsnlp_collection = client.get_collection("asrsnlp_collection")

In [38]:
%%time 
# Run inference on a single narrative; adjust the prompt and the pattern to your needs.
# NOTE(review): `test_data` is not defined in any visible cell — it must be loaded
# (with a 'Narrative' column) before this cell can run on a fresh kernel.
# NOTE(review): the prompt asks the model to answer 1 or 0, while the pattern below
# looks for "Yes" — confirm the pattern matches what the model actually emits.
# `result` is the (pred, inference) tuple returned by `get_inference`.
result = get_inference(query=test_data.Narrative[1],
                       task='No Specific Anomaly Occurred',
                       pattern="[Answer]  Yes",
                       store=asrsnlp_collection,
                       pipeline=pipeline)

CPU times: user 47min 12s, sys: 5min 12s, total: 52min 25s
Wall time: 1min 33s


In [None]:
result