# Filter Module

## Initialisation

Import the standard-library and project modules; the project directory (`projdir`) is provided by the `common` module.

In [1]:
import os
import csv

import json
import unittest

from common import Printable, testExit, projdir

from entrant import Entrant

from constants import *

## Filter Class

Class to apply filter rules (e.g. `Craft Type=Sailboard,Status<>Pro Fleet`) to entrants and cache the results

In [2]:
class Filter(Printable):
    '''Apply filter rules to a dictionary of entrants and cache the results

    A filter text is a comma-separated list of rules, each of the form
    "Field=Value" or "Field<>Value". The value may contain "%" wildcards,
    each matching any (possibly empty) run of characters. Entrants with an
    empty value for the filtered field never match, for either operator.
    '''

    def __init__(self, entrants, verbosity=1):
        '''Initialise filter object

        entrants - dictionary of entrant objects, keyed by entrant ID
        verbosity - passed through to Printable
        '''

        super().__init__(verbosity=verbosity)

        self.entrants = entrants

        # Set of field names available on the entrants, set by determineEntrantFields
        self.entrantFields = None

        # Maps a cumulative filter text (e.g. "A=1,B=2") to the matching entrants
        self.cache = {}


    def determineEntrantFields(self):
        '''Determine which fields are in entrants

        Field names are taken from the first entrant's entrantDict, so this
        assumes every entrant carries the same set of fields.
        '''

        if self.entrants:
            firstEntrant = next(iter(self.entrants.values()))
            self.entrantFields = set(firstEntrant.entrantDict.keys())
        else:
            self.entrantFields = set()


    @staticmethod
    def _parseRule(filterRule):
        '''Split a filter rule into (field, operator, pattern)

        The operator is whichever of "=" / "<>" occurs first, so the pattern
        itself may safely contain "=" or "<>". Raises ValueError if neither
        operator is present.
        '''

        found = [(filterRule.find(op), op) for op in ('=', '<>')]
        found = [(pos, op) for pos, op in found if pos >= 0]
        if not found:
            raise ValueError('Invalid filter rule "{}"'.format(filterRule))

        pos, operator = min(found)
        return filterRule[:pos], operator, filterRule[pos + len(operator):]


    @staticmethod
    def _matchesPattern(value, pattern):
        '''Return True if value matches pattern

        Without "%" this is plain equality. Each "%" matches any (possibly
        empty) run of characters; the remaining fragments must appear in order.
        '''

        if '%' not in pattern:
            return value == pattern

        prefix, *middle, suffix = pattern.split('%')

        # Prefix and suffix must both fit without overlapping each other
        if len(value) < len(prefix) + len(suffix):
            return False
        if not (value.startswith(prefix) and value.endswith(suffix)):
            return False

        # Match each middle fragment left-to-right in the span between prefix and suffix
        pos = len(prefix)
        end = len(value) - len(suffix)
        for fragment in middle:
            index = value.find(fragment, pos, end)
            if index < 0:
                return False
            pos = index + len(fragment)

        return True


    def applyFilter(self, entrants, filterRule):
        '''Retrieve entrants that match a single filter criteria, without cache

        entrants - dictionary of entrant objects to filter, keyed by entrant ID
        filterRule - a single rule such as "Gender=F" or "Grouping<>Forces - %"

        Returns a new dictionary of the matching entrants.
        Raises ValueError if the rule has no valid operator or names an unknown
        field; these checks only happen when this filter holds any entrants.
        '''

        filteredEntrants = {}

        if self.entrants:
            self.determineEntrantFields()

            field, operator, pattern = self._parseRule(filterRule)
            if field not in self.entrantFields:
                raise ValueError('Invalid field in filter - {}'.format(field))

            # "<>" keeps exactly the (non-empty) entrants that "=" would reject
            wantMatch = operator == '='
            for entrantId, entrant in entrants.items():
                value = entrant.getValue(field)
                # Entrants with no value for the field are excluded by both operators
                if value and self._matchesPattern(value, pattern) == wantMatch:
                    filteredEntrants[entrantId] = entrant

        return filteredEntrants


    def getEntrants(self, filterText):
        '''Retrieve entrants that match multiple filter criteria, utilising cache

        filterText - comma-separated filter rules, applied left to right

        Every leading subset of the rules is cached under its own key, so
        filters sharing leading rules (e.g. "Craft Type=Sailboard,...") reuse
        earlier results. The returned dictionary is the cached object itself,
        so callers should not modify it.
        '''

        entrants = self.entrants
        filterKey = ''

        for filterRule in filterText.split(','):
            filterKey = filterKey + ',' + filterRule if filterKey else filterRule

            if filterKey in self.cache:
                entrants = self.cache[filterKey]
            else:
                entrants = self.applyFilter(entrants, filterRule)
                self.cache[filterKey] = entrants

        return entrants

## Unit Tests

A handful of basic filter tests, utilising a dummy "event" class

In [3]:
class DummyEvent(Printable):
    '''Minimal stand-in for an event: loads just the event config and entrants'''

    def __init__(self, path, verbosity=1):
        '''Initialise dummy event

        path - event directory, which contains CONFIG_DIR
        verbosity - passed through to Printable
        '''

        super().__init__(verbosity=verbosity)

        self.path = path

        # Entrant objects keyed by entrant ID, filled by loadEntrants
        self.entrants = {}


    def loadConfig(self):
        '''Read event config from JSON into self.eventConfig'''

        filename = os.path.join(self.path, CONFIG_DIR, EVENT_CONFIG)
        with open(filename, 'r', encoding='utf-8') as f:
            self.eventConfig = json.load(f)


    def loadEntrants(self):
        '''Read entrants from CSV into self.entrants, keyed by entrant ID

        The first row is the header. Blank rows are skipped, and an empty
        file yields no entrants. loadConfig must be called first, because
        each Entrant is built with self.eventConfig.
        Raises ValueError on a duplicate entrant ID.
        '''

        csvPath = os.path.join(self.path, CONFIG_DIR, ENTRANTS_CSV)

        # newline='' lets the csv module handle line endings and quoted newlines itself
        with open(csvPath, 'r', encoding='utf-8', newline='') as f:
            csvReader = csv.reader(f)
            headers = next(csvReader, None)
            if headers is None:
                # Empty file - no header row, so there are no entrants to load
                return

            for values in csvReader:
                # Skip rows where every cell is empty or whitespace
                if not ''.join(values).strip():
                    continue

                entrant = Entrant(self.eventConfig, headers, values, verbosity=self.verbosity)
                entrantId = entrant.getValue('ID')
                if entrantId in self.entrants:
                    raise ValueError('Duplicate entrant ID "{}"'.format(entrantId))
                self.entrants[entrantId] = entrant

In [4]:
class TestFilter2019(unittest.TestCase):
    '''Class to test filters using 2019 data'''

    @classmethod
    def setUpClass(cls):
        '''Use the global "event" if the script created it, otherwise load the 2019 event

        Loading here lets the tests run under a standard test runner, which
        never executes the __main__ block that creates the global "event".
        '''

        cls.event = globals().get('event')
        if cls.event is None:
            cls.event = DummyEvent(os.path.join(projdir, EVENTS_DIR, '2019'))
            cls.event.loadConfig()
            cls.event.loadEntrants()


    def assertFilter(self, filterText, expectedCount, expectedCacheSize):
        '''Apply filterText with a fresh Filter, checking result and cache sizes

        Returns the filtered entrants so callers can make further checks.
        '''

        testFilter = Filter(self.event.entrants)
        entrants = testFilter.getEntrants(filterText)

        self.assertEqual(len(entrants), expectedCount)
        self.assertEqual(len(testFilter.cache), expectedCacheSize)

        return entrants


    def testBoat(self):
        '''Test a simple boat filter'''

        self.assertFilter('Craft Type=Boat', 2, 1)


    def testKiteboard(self):
        '''Test a simple kiteboard filter'''

        self.assertFilter('Craft Type=Kiteboard', 10, 1)


    def testSailboard(self):
        '''Test a simple sailboard filter'''

        self.assertFilter('Craft Type=Sailboard', 60, 1)


    def testYouthFleetSailboard(self):
        '''Test a youth fleet sailboard filter'''

        self.assertFilter('Craft Type=Sailboard,Status=Youth Fleet', 6, 2)


    def testFirstTimeSailboard(self):
        '''Test a first time sailboard filter'''

        self.assertFilter('Craft Type=Sailboard,Status=Amateur,First Timer=Y', 5, 3)


    def testAmateurSailboard(self):
        '''Test an amateur sailboard filter'''

        # Note that 5 of the 29 "amateur" are actually "first timer" and 6 are "youth fleet"
        # NOTE(review): the note above says 29 but the expectation uses 30 - 6; confirm which is right
        self.assertFilter('Craft Type=Sailboard,Status=Amateur', 30 - 6, 2)


    def testGoldFleetSailboard(self):
        '''Test a gold fleet sailboard filter'''

        self.assertFilter('Craft Type=Sailboard,Status=Gold Fleet', 19, 2)


    def testProFleetSailboard(self):
        '''Test a pro fleet sailboard filter'''

        self.assertFilter('Craft Type=Sailboard,Status=Pro Fleet', 11, 2)


    def testUkwaMember(self):
        '''Test a UKWA member filter'''

        self.assertFilter('Craft Type=Sailboard,UKWA Member=Y', 36, 2)


    def testFemaleSailboard(self):
        '''Test a female sailboard filter'''

        # Note that 1 of the 4 women is "pro fleet"
        self.assertFilter('Craft Type=Sailboard,Gender=F,Status<>Pro Fleet', 4 - 1, 3)


    def testSailboardTotals(self):
        '''Test that filters for sailboard fleets add up to the expected total'''

        testFilter = Filter(self.event.entrants)
        sailboards = testFilter.getEntrants('Craft Type=Sailboard')
        youths = testFilter.getEntrants('Craft Type=Sailboard,Status=Youth Fleet')
        firsts = testFilter.getEntrants('Craft Type=Sailboard,Status=Amateur,First Timer=Y')
        amateurs = testFilter.getEntrants('Craft Type=Sailboard,Status=Amateur,First Timer<>Y')
        golds = testFilter.getEntrants('Craft Type=Sailboard,Status=Gold Fleet')
        pros = testFilter.getEntrants('Craft Type=Sailboard,Status=Pro Fleet')

        fleets = [youths, firsts, amateurs, golds, pros]
        self.assertEqual(len(sailboards), sum(len(fleet) for fleet in fleets))
        self.assertEqual(len(testFilter.cache), 7)


    def testCraftTotals(self):
        '''Test that filters for boats, sailboards and kites add up to the expected total'''

        testFilter = Filter(self.event.entrants)
        boats = testFilter.getEntrants('Craft Type=Boat')
        kites = testFilter.getEntrants('Craft Type=Kiteboard')
        sailboards = testFilter.getEntrants('Craft Type=Sailboard')

        self.assertEqual(len(self.event.entrants), len(boats) + len(kites) + len(sailboards))
        self.assertEqual(len(testFilter.cache), 3)


    def testGenderTotals(self):
        '''Test that filters for male and female add up to the expected total'''

        testFilter = Filter(self.event.entrants)
        females = testFilter.getEntrants('Gender=F')
        males = testFilter.getEntrants('Gender=M')

        self.assertEqual(len(self.event.entrants), len(females) + len(males))
        self.assertEqual(len(testFilter.cache), 2)


    def testForcesGrouping(self):
        '''Test a forces filter'''

        self.assertFilter('Grouping=Forces - %', 6, 1)


    def testNonForcesGrouping(self):
        '''Test a non-forces filter'''

        # Note that 6 of the people with groupings are "forces"
        self.assertFilter('Grouping<>Forces - %', 72 - 6, 1)


    def testBadField(self):
        '''Test a bad field name'''

        testFilter = Filter(self.event.entrants)
        with self.assertRaises(ValueError):
            testFilter.getEntrants('ZZZ=10')


    def testBadOperator(self):
        '''Test a bad operator'''

        testFilter = Filter(self.event.entrants)
        with self.assertRaises(ValueError):
            testFilter.getEntrants('Gender!=M')

## Run Unit Tests

Note: Only run unit tests when running this script directly, not during an import

In [5]:
if __name__ == '__main__':
    # Load the main app config into the global "appConfig"
    appConfigPath = os.path.join(projdir, CONFIG_DIR, APP_CONFIG)
    with open(appConfigPath, 'r', encoding='utf-8') as configFile:
        appConfig = json.load(configFile)

    # Load the 2019 event into the global "event", which the unit tests read
    eventYear = '2019'
    event = DummyEvent(os.path.join(projdir, EVENTS_DIR, eventYear))
    event.loadConfig()
    event.loadEntrants()

    # A fixed argv stops unittest parsing the real command line
    # (presumably a notebook kernel's arguments when run interactively)
    unittest.main(argv=['first-arg-is-ignored'], exit=testExit)

.................
----------------------------------------------------------------------
Ran 17 tests in 0.052s

OK


## All Done!