In [1]:
import os
import typing as T
from collections import Counter

import pandas as pd
import spacy
from spacy.matcher import Matcher
from strenum import StrEnum

nlp = spacy.load("en_core_web_sm")

# Data directory: override with the ATIS_DATA_DIR environment variable so the notebook
# runs outside the author's machine; the default is the original absolute location.
ATIS_DATA_DIR = os.environ.get("ATIS_DATA_DIR", "/home/khaymon/hse_nlp/spacy_nlu/data")
ATIS_INTENTS_PATH = os.path.join(ATIS_DATA_DIR, "atis_intents.csv")
ATIS_UTTERANCES_PATH = os.path.join(ATIS_DATA_DIR, "atis_utterances.txt")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# The CSV has no header row, so the column names are supplied explicitly at read time
# (assumes exactly two columns, as in the preview below).
atis_intents = pd.read_csv(ATIS_INTENTS_PATH, header=None, names=["intent", "utterance"])

# A fixed random_state keeps this preview identical across Restart & Run All.
atis_intents.sample(random_state=0)

Unnamed: 0,intent,utterance
3217,atis_flight,show me the last flight on wednesday from atl...


In [8]:
# Label distribution; "#"-joined labels mark utterances annotated with several intents.
atis_intents["intent"].value_counts()

intent
atis_flight                                 3666
atis_airfare                                 423
atis_ground_service                          255
atis_airline                                 157
atis_abbreviation                            147
atis_aircraft                                 81
atis_flight_time                              54
atis_quantity                                 51
atis_flight#atis_airfare                      21
atis_airport                                  20
atis_distance                                 20
atis_city                                     19
atis_ground_fare                              18
atis_capacity                                 16
atis_flight_no                                12
atis_meal                                      6
atis_restriction                               6
atis_airline#atis_flight_no                    2
atis_ground_service#atis_ground_fare           1
atis_airfare#atis_flight_time                  1
atis_cheapest

## Main parsing pipeline

In [131]:
class Intents(StrEnum):
    """ATIS intent labels, spelled exactly as in the ``intent`` column of atis_intents.csv.

    Compound labels such as ``atis_flight#atis_airfare`` are not members; an
    utterance with several intents is represented by reporting several members.
    """

    FLIGHT = "atis_flight"
    AIRFARE = "atis_airfare"
    GROUND_SERVICE = "atis_ground_service"
    AIRLINE = "atis_airline"
    ABBREVIATION = "atis_abbreviation"
    AIRCRAFT = "atis_aircraft"
    FLIGHT_TIME = "atis_flight_time"
    # Must match the dataset label exactly (was misspelled "atis_quntity").
    QUANTITY = "atis_quantity"
    AIRPORT = "atis_airport"
    DISTANCE = "atis_distance"
    CITY = "atis_city"
    GROUND_FARE = "atis_ground_fare"
    CAPACITY = "atis_capacity"
    FLIGHT_NO = "atis_flight_no"
    MEAL = "atis_meal"
    RESTRICTION = "atis_restriction"
    CHEAPEST = "atis_cheapest"

    # Backward-compatible lowercase name: same value, so Enum makes it an alias
    # of GROUND_FARE rather than a separate member.
    ground_fare = "atis_ground_fare"
    

class IntentChecker:
    """Interface for a rule-based detector of one ATIS intent.

    All hooks are classmethods, so checkers are passed around as classes.
    Subclasses must override every method below.
    """

    @classmethod
    def has_intent(cls, text: str, nlp) -> bool:
        """Return True when ``text`` expresses this checker's intent."""
        raise NotImplementedError

    @classmethod
    def name(cls) -> str:
        """Return the intent label reported for a positive match."""
        raise NotImplementedError

    @classmethod
    def parameters(cls, text: str, nlp) -> T.Dict:
        """Extract the intent's parameters; only meaningful when ``has_intent`` holds."""
        # Callers are expected to check has_intent first.
        assert cls.has_intent(text, nlp)
        raise NotImplementedError

    
class IntentMatcher:
    """Runs a fixed list of intent checkers over an utterance and collects every match."""

    def __init__(self, nlp, intent_checkers: T.Sequence[T.Type["IntentChecker"]]):
        """
        Args:
            nlp: language pipeline forwarded unchanged to every checker call.
            intent_checkers: checker *classes* (not instances) exposing the
                ``has_intent`` / ``name`` / ``parameters`` classmethods. They are
                tried in the given order.
        """
        self._nlp = nlp
        self._intent_checkers = intent_checkers

    def get_intent(self, text: str) -> T.Tuple[T.List[str], T.List[T.Dict]]:
        """Return the names and parameters of every intent detected in ``text``.

        The two lists are parallel: ``params[i]`` belongs to ``names[i]``. Both
        are empty when no checker fires. ``name``/``parameters`` are only called
        for checkers whose ``has_intent`` returned True.
        """
        names: T.List[str] = []
        params: T.List[T.Dict] = []
        for checker in self._intent_checkers:
            if not checker.has_intent(text, self._nlp):
                continue
            names.append(checker.name())
            params.append(checker.parameters(text, self._nlp))
        return names, params

## Parameter matchers

In [132]:
class ParametersMatcher:
    """Interface for parameter (slot) extractors: utterance in, dict of parameters out."""

    @classmethod
    def parameters(cls, text: str, nlp) -> T.Dict:
        """Return the parameters found in ``text``; subclasses must override."""
        raise NotImplementedError


class FromToParametersMatcher(ParametersMatcher):
    """Extracts origin and destination cities from "<preposition> <GPE>" phrases."""

    @classmethod
    def parameters(cls, text: str, nlp) -> T.Dict:
        """Return ``{"from": city_or_None, "to": city_or_None}`` for ``text``.

        Every preposition (POS ``ADP``) immediately followed by a GPE token is
        considered. A preposition whose lemma is "from" fills ``"from"``; any
        other preposition fills ``"to"``. When several phrases map to the same
        key, the last one wins.

        The city value is the lemmatized text of the *whole* GPE entity, so
        multi-word names ("new york", "salt lake city") are not truncated to
        their first word.
        """
        matcher = Matcher(nlp.vocab)
        matcher.add("prepositionLocation", [[{"POS": "ADP"}, {"ENT_TYPE": "GPE"}]])

        doc = nlp(text)
        # Token index -> enclosing GPE entity, used to widen a match on a single
        # token to the full entity span.
        gpe_by_token = {
            i: ent
            for ent in doc.ents
            if ent.label_ == "GPE"
            for i in range(ent.start, ent.end)
        }

        parameters = {"from": None, "to": None}
        for _, start, end in matcher(doc):
            preposition = doc[start]
            place_index = end - 1
            entity = gpe_by_token.get(place_index)
            city = entity.lemma_ if entity is not None else doc[place_index].lemma_
            # NOTE(review): prepositions other than "to" (e.g. "in", "via") are also
            # treated as the destination — confirm this is intended.
            key = "from" if preposition.lemma_ == "from" else "to"
            parameters[key] = city
        return parameters
    
    
class WhereParametersMatcher(ParametersMatcher):
    """Extracts the city of an "in <GPE>" phrase (e.g. "transportation in baltimore")."""

    @classmethod
    def parameters(cls, text: str, nlp) -> T.Dict:
        """Return ``{"in": city_or_None}`` for ``text``.

        Only a preposition (POS ``ADP``) whose lemma is "in", immediately
        followed by a GPE token, counts. With several such phrases, the last one
        wins. The city value is the lemmatized text of the *whole* GPE entity,
        so multi-word names are not truncated to their first word.
        """
        matcher = Matcher(nlp.vocab)
        matcher.add("prepositionLocation", [[{"POS": "ADP"}, {"ENT_TYPE": "GPE"}]])

        doc = nlp(text)
        # Token index -> enclosing GPE entity, used to widen a match on a single
        # token to the full entity span.
        gpe_by_token = {
            i: ent
            for ent in doc.ents
            if ent.label_ == "GPE"
            for i in range(ent.start, ent.end)
        }

        parameters = {"in": None}
        for _, start, end in matcher(doc):
            if doc[start].lemma_ != "in":
                continue
            place_index = end - 1
            entity = gpe_by_token.get(place_index)
            parameters["in"] = entity.lemma_ if entity is not None else doc[place_index].lemma_
        return parameters
    

class AbbrParametersMatcher(ParametersMatcher):
    """Finds abbreviation-like codes in an utterance (e.g. "ap 57", "code h")."""

    @classmethod
    def parameters(cls, text: str, nlp) -> T.Dict:
        """Return ``{"abbreviations": [span_text, ...]}`` for ``text``.

        Overlapping matches (e.g. "code h" and "h") are collapsed to the longest
        span, and the result is ordered by position in ``text``. SHAPE values
        are case-sensitive, so only lowercase codes are recognized.
        """
        # One/two lowercase letters followed by a one/two-digit number: "ap 57".
        code_with_number = [{"SHAPE": {"IN": ["x", "xx"]}}, {"SHAPE": {"IN": ["d", "dd"]}}]
        # Keyword followed by a short lowercase code: "class y", "code h".
        keyword_then_code = [
            {"TEXT": {"IN": ["class", "code", "abbrev", "abbreviation"]}},
            {"SHAPE": {"IN": ["x", "xx"]}},
        ]
        # A one/two-letter lowercase token that the tagger labels as a noun.
        short_noun = [{"POS": "NOUN", "SHAPE": {"IN": ["x", "xx"]}}]

        matcher = Matcher(nlp.vocab)
        matcher.add("abbrevEntities", [code_with_number, keyword_then_code, short_noun])

        doc = nlp(text)
        spans = [doc[start:end] for _, start, end in matcher(doc)]
        # filter_spans keeps the longest of overlapping spans and sorts by start.
        return {"abbreviations": [span.text for span in spacy.util.filter_spans(spans)]}

## Intent checkers

In [133]:
class FlightIntentChecker(IntentChecker):
    """Detects flight search requests; parameters are the origin/destination cities."""

    @classmethod
    def has_intent(cls, text: str, nlp) -> bool:
        """Return True if any token contains "flight" ("flight", "flights"), in any case.

        spaCy's REGEX operator searches anywhere in the token, so the single stem
        "flight" already covers the plural. Matching on LOWER makes the check
        case-insensitive.
        """
        doc = nlp(text)
        matcher = Matcher(nlp.vocab)
        matcher.add("flightIntent", [[{"LOWER": {"REGEX": "flight"}}]])
        return len(matcher(doc)) > 0

    @classmethod
    def name(cls) -> str:
        return Intents.FLIGHT

    @classmethod
    def parameters(cls, text: str, nlp) -> T.Dict:
        """Return ``{"from": ..., "to": ...}`` extracted by FromToParametersMatcher."""
        return FromToParametersMatcher.parameters(text, nlp)


class AirfareIntentChecker(IntentChecker):
    """Detects fare/price questions; parameters are the origin/destination cities."""

    @classmethod
    def has_intent(cls, text: str, nlp) -> bool:
        """Return True if any token contains "fare" ("fare", "fares", "airfare"), in any case.

        Matching on LOWER makes the check case-insensitive.
        """
        doc = nlp(text)
        matcher = Matcher(nlp.vocab)
        matcher.add("airfareIntent", [[{"LOWER": {"REGEX": "fare"}}]])
        return len(matcher(doc)) > 0

    @classmethod
    def name(cls) -> str:
        return Intents.AIRFARE

    @classmethod
    def parameters(cls, text: str, nlp) -> T.Dict:
        """Return ``{"from": ..., "to": ...}`` extracted by FromToParametersMatcher."""
        return FromToParametersMatcher.parameters(text, nlp)


class GroundServiceIntentChecker(IntentChecker):
    """Detects ground-transportation questions; the parameter is the city ("in")."""

    @classmethod
    def has_intent(cls, text: str, nlp) -> bool:
        """Return True if any token *starts with* a ground-transport stem.

        Stems: "transportat" (transportation), "taxi", "rent" (rental), "limo"
        (limousine). spaCy's REGEX operator searches anywhere in the token, so
        the stems are anchored with ``^``. Without the anchor, "rent" also
        matched words like "different" and "current". Matching on LOWER makes
        the check case-insensitive.
        """
        doc = nlp(text)
        matcher = Matcher(nlp.vocab)
        matcher.add(
            "groundServiceIntent",
            [[{"LOWER": {"REGEX": "^(transportat|taxi|rent|limo)"}}]],
        )
        return len(matcher(doc)) > 0

    @classmethod
    def name(cls) -> str:
        return Intents.GROUND_SERVICE

    @classmethod
    def parameters(cls, text: str, nlp) -> T.Dict:
        """Return ``{"in": ...}`` extracted by WhereParametersMatcher."""
        return WhereParametersMatcher.parameters(text, nlp)


class AirlineIntentChecker(IntentChecker):
    """Detects questions about airlines; parameters are the origin/destination cities."""

    @classmethod
    def has_intent(cls, text: str, nlp) -> bool:
        """Return True if any token contains "airline" ("airline", "airlines"), in any case.

        Matching on LOWER makes the check case-insensitive.
        """
        doc = nlp(text)
        matcher = Matcher(nlp.vocab)
        matcher.add("airlineIntent", [[{"LOWER": {"REGEX": "airline"}}]])
        return len(matcher(doc)) > 0

    @classmethod
    def name(cls) -> str:
        return Intents.AIRLINE

    @classmethod
    def parameters(cls, text: str, nlp) -> T.Dict:
        """Return ``{"from": ..., "to": ...}`` extracted by FromToParametersMatcher."""
        return FromToParametersMatcher.parameters(text, nlp)
    

class AbbreviationIntentChecker(IntentChecker):
    """Detects requests to explain a code such as "ap 57".

    Detection and parameter extraction share one implementation: the utterance
    has this intent exactly when AbbrParametersMatcher finds at least one
    abbreviation.
    """

    @classmethod
    def name(cls) -> str:
        return Intents.ABBREVIATION

    @classmethod
    def parameters(cls, text: str, nlp) -> T.Dict:
        return AbbrParametersMatcher.parameters(text, nlp)

    @classmethod
    def has_intent(cls, text: str, nlp) -> bool:
        found = AbbrParametersMatcher.parameters(text, nlp)["abbreviations"]
        return bool(found)

## Tests

In [134]:
# Hand-written regression cases, keyed by checker class.
#   "positives": utterances the checker must accept, each paired with the exact
#                dict its parameters() must return.
#   "negatives": utterances the checker must reject; each one is a positive case
#                for a different checker, so the checkers are also tested for
#                not overlapping on these examples.
tests = {
    FlightIntentChecker: {
        "positives": [
            {
                "text": "show me flights from denver to boston on tuesday",
                "params": {"from": "denver", "to": "boston"}
            }
        ],
        "negatives": [
            "what's the fare from washington to boston",
            "is there ground transportation in baltimore",
            "what airlines fly from boston to pittsburgh",
            "what does restriction ap 57 mean"
        ]
    },
    AirfareIntentChecker: {
        "positives": [
            {
                "text": "what's the fare from washington to boston",
                "params": {"from": "washington", "to": "boston"}
            }
        ],
        "negatives": [
            "show me flights from denver to boston on tuesday",
            "is there ground transportation in baltimore",
            "what airlines fly from boston to pittsburgh",
            "what does restriction ap 57 mean"
        ]
    },
    GroundServiceIntentChecker: {
        "positives": [
            {
                "text": "is there ground transportation in baltimore",
                "params": {"in": "baltimore"}
            }
        ],
        "negatives": [
            "show me flights from denver to boston on tuesday",
            "what's the fare from washington to boston",
            "what airlines fly from boston to pittsburgh",
            "what does restriction ap 57 mean"
        ]
    },
    AirlineIntentChecker: {
        "positives": [
            {
                "text": "what airlines fly from boston to pittsburgh",
                "params": {"from": "boston", "to": "pittsburgh"}
            }
        ],
        "negatives": [
            "show me flights from denver to boston on tuesday",
            "what's the fare from washington to boston",
            "is there ground transportation in baltimore",
            "what does restriction ap 57 mean"
        ]
    },
    AbbreviationIntentChecker: {
        "positives": [
            {
                "text": "what does restriction ap 57 mean",
                "params": {"abbreviations": ["ap 57"]}
            }
        ],
        "negatives": [
            "show me flights from denver to boston on tuesday",
            "what's the fare from washington to boston",
            "is there ground transportation in baltimore",
            "what airlines fly from boston to pittsburgh"
        ]
    }
}

In [135]:
# Run every case in `tests`; a failing assertion names the checker and the utterance.
for checker, cases in tests.items():
    for case in cases["positives"]:
        utterance = case["text"]
        failure = f"{checker.__name__} for \"{utterance}\""
        assert checker.has_intent(utterance, nlp), failure
        assert checker.parameters(utterance, nlp) == case["params"], failure
    for utterance in cases["negatives"]:
        assert not checker.has_intent(utterance, nlp), f"{checker.__name__} for \"{utterance}\""

print("All tests are passed")

All tests are passed


## Main pipeline

In [136]:
# Checkers run in this order; every one that fires contributes an intent and its parameters.
pipeline_checkers = [
    FlightIntentChecker,
    AirfareIntentChecker,
    GroundServiceIntentChecker,
    AirlineIntentChecker,
    AbbreviationIntentChecker,
]
intent_matcher = IntentMatcher(nlp, intent_checkers=pipeline_checkers)

intent_matcher.get_intent("what airlines fly from boston to pittsburgh")

([<Intents.AIRLINE: 'atis_airline'>], [{'from': 'boston', 'to': 'pittsburgh'}])