In [20]:
import os
from glob import glob

from neo4j import GraphDatabase

In [6]:
import csv

In [5]:
# Connection settings are read from the environment so real credentials never have
# to be committed with the notebook. The defaults reproduce the original local setup.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "admin")
driver = GraphDatabase.driver(uri=NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

### Load buildings

In [28]:
# Races to load; each one has its own buildings/units CSV file under ../data.
races = [
    "zerg",
    "terran",
    "protoss",
]

In [306]:
def get_building_path(race):
    """Return the relative path of the CSV file that lists *race*'s building names."""
    return ''.join(('../data/', race, '_buildings.csv'))

In [307]:
def create_building(tx, name, race):
    """Ensure a ``Building`` node with the given name and race exists.

    MERGE is used (as in the relationship helpers) so re-running the loading
    cells against the same database does not create duplicate nodes.

    Args:
        tx: transaction handle supplied by ``session.write_transaction``.
        name: building name read from the race's CSV file.
        race: race the building belongs to.

    Returns:
        Whatever ``tx.run`` returns.
    """
    return tx.run("""
    MERGE (a:Building {name:$name, race:$race})
    """, name=name, race=race)

In [308]:
# Each <race>_buildings.csv holds the building names on its first row.
# csv.reader drops the line terminator, so the last name no longer keeps a trailing
# "\n". That newline made the node miss the case-insensitive name matching in the
# relationship cells. Surrounding spaces and empty cells are dropped as well, and an
# empty file simply loads nothing instead of raising IndexError.
for race in races:
    with open(get_building_path(race), 'r', encoding='utf-8', newline='') as f:
        first_row = next(csv.reader(f), [])
    buildings = [name.strip() for name in first_row if name.strip()]
    with driver.session() as session:
        for name in buildings:
            session.write_transaction(create_building, name, race)

### Load units

In [309]:
def get_units_path(race):
    """Return the relative path of the CSV file that lists *race*'s unit names."""
    directory, suffix = '../data/', '_units.csv'
    return directory + race + suffix

In [310]:
def create_units(tx, name, race):
    """Ensure a ``Unit`` node with the given name and race exists.

    MERGE is used (as in the relationship helpers) so re-running the loading
    cells against the same database does not create duplicate nodes.

    Args:
        tx: transaction handle supplied by ``session.write_transaction``.
        name: unit name read from the race's CSV file.
        race: race the unit belongs to.

    Returns:
        Whatever ``tx.run`` returns.
    """
    return tx.run("""
    MERGE (a:Unit {name:$name, race:$race})
    """, name=name, race=race)

In [311]:
# Each <race>_units.csv holds the unit names on its first row.
# csv.reader drops the line terminator, so the last name no longer keeps a trailing
# "\n". That newline made the node miss the case-insensitive name matching in the
# relationship cells. Surrounding spaces and empty cells are dropped as well, and an
# empty file simply loads nothing instead of raising IndexError.
for race in races:
    with open(get_units_path(race), 'r', encoding='utf-8', newline='') as f:
        first_row = next(csv.reader(f), [])
    units = [name.strip() for name in first_row if name.strip()]
    with driver.session() as session:
        for name in units:
            session.write_transaction(create_units, name, race)

### Load relations

In [312]:
# Raw lines (line terminators included) of the relationship file; a later cell parses them.
with open('../data/units.csv', 'r', encoding='utf-8') as f:
    data = list(f)

In [313]:
def produce(tx, building, unit):
    """Link a building to a unit it trains with a ``Produce`` relationship.

    Both names are compared case-insensitively, and MERGE prevents duplicate edges.
    When either name matches no node, MATCH yields no rows and nothing is written.
    """
    query = """
    MATCH (n:Building), (u:Unit)
    WHERE toLower(n.name) = toLower($building) and
    toLower(u.name) = toLower($unit)
    MERGE (n)-[:Produce]->(u)
    """
    return tx.run(query, {'building': building, 'unit': unit})

In [314]:
def allow(tx, building1, building2):
    """Record that *building1* unlocks *building2* via an ``Allow`` relationship.

    Names are matched case-insensitively, and MERGE keeps the edge unique.
    Nothing is written when either building is missing.
    """
    params = {'building1': building1, 'building2': building2}
    return tx.run("""
    MATCH (n:Building), (m:Building)
    WHERE toLower(n.name) = toLower($building1) and
    toLower(m.name) = toLower($building2)
    MERGE (n)-[:Allow]->(m)
    """, params)

In [315]:
def isStrongAgainst(tx, unit1, unit2):
    """Store that *unit1* beats *unit2* as ``(unit1)-[:StrongerThan]->(unit2)``.

    Names are matched case-insensitively, and MERGE keeps the edge unique.
    """
    cypher = """
    MATCH (n:Unit), (m:Unit)
    WHERE toLower(n.name) = toLower($unit1) and
    toLower(m.name) = toLower($unit2)
    MERGE (n)-[:StrongerThan]->(m)
    """
    return tx.run(cypher, dict(unit1=unit1, unit2=unit2))

In [316]:
def isWeakAgainst(tx, unit1, unit2):
    """Store that *unit2* beats *unit1*, i.e. ``(unit1)<-[:StrongerThan]-(unit2)``.

    This reuses the same ``StrongerThan`` relationship type as ``isStrongAgainst``,
    with the direction reversed. Names are matched case-insensitively.
    """
    bindings = {'unit1': unit1, 'unit2': unit2}
    cypher = """
    MATCH (n:Unit), (m:Unit)
    WHERE toLower(n.name) = toLower($unit1) and
    toLower(m.name) = toLower($unit2)
    MERGE (n)<-[:StrongerThan]-(m)
    """
    return tx.run(cypher, bindings)

In [317]:
def get_fn(name):
    """Return the transaction function for a relation keyword from units.csv.

    Recognised keywords are 'produce', 'allow', 'isStrongAgainst' and
    'isWeakAgainst'. Any other keyword yields None.
    """
    handlers = dict(
        produce=produce,
        allow=allow,
        isStrongAgainst=isStrongAgainst,
        isWeakAgainst=isWeakAgainst,
    )
    return handlers.get(name)

In [318]:
# Each row of units.csv has the form <source>,<relation>,<target>[,<target>...],
# where <relation> is one of the keywords accepted by get_fn.
# One session is reused for the whole file; each edge is still written in its own
# transaction. Blank lines are skipped. Malformed rows and unknown relation keywords
# raise a descriptive ValueError instead of an opaque IndexError/TypeError.
with driver.session() as session:
    for line_no, line in enumerate(data, start=1):
        elements = [item.strip() for item in line.strip().split(',')]
        if not any(elements):
            continue  # blank line, e.g. an empty line at the end of the file
        if len(elements) < 2:
            raise ValueError(
                'units.csv line {}: expected "<source>,<relation>,..." but got {!r}'.format(line_no, line)
            )
        relation_fn = get_fn(elements[1])
        if relation_fn is None:
            raise ValueError(
                'units.csv line {}: unknown relation {!r}'.format(line_no, elements[1])
            )
        for target in elements[2:]:
            if target:  # ignore empty cells such as a trailing comma
                session.write_transaction(relation_fn, elements[0], target)

### Find counters

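Example scenario: the opponent opens "pool first" (an early Spawning Pool). The helpers below look up counters for a single unit, and for every unit that a given building produces.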

In [322]:
def get_unit_counter(tx, unit, race):
    """Find *race* units that beat *unit*, together with the buildings producing them.

    The unit name is matched case-insensitively, consistent with how the
    Produce/StrongerThan relationships were loaded. DISTINCT drops duplicate rows.

    Args:
        tx: transaction handle supplied by ``session.read_transaction``.
        unit: name of the unit to counter.
        race: race whose buildings/units may be used as the counter.

    Returns:
        The result of ``tx.run``; its rows have keys ``build`` and ``unit``.
    """
    # NOTE(review): the caller reads this result after the transaction function has
    # returned (``read_transaction(...).data()``). Newer neo4j drivers expect results
    # to be consumed inside the transaction function; confirm against the driver in use.
    return tx.run("""
    MATCH (u:Unit)<-[:StrongerThan]-(u2:Unit)<-[:Produce]-(m:Building)
    WHERE toLower(u.name) = toLower($unit) and m.race = $race
    RETURN DISTINCT m.name as build, u2.name as unit
    """, unit=unit, race=race)

In [323]:
with driver.session() as session:
    # Which terran units (and the buildings that produce them) beat a SiegeTank?
    val = session.read_transaction(
        get_unit_counter,
        'SiegeTank',
        'terran',
    ).data()

In [324]:
# Display the counter rows fetched above: a list of dicts with 'build' and 'unit' keys.
val

[{'build': 'Starport', 'unit': 'Banshee'}]

In [325]:
def get_build_counter(tx, building, race):
    """Find *race* units that beat any unit produced by *building*.

    Each row pairs a counter unit with the building that produces it. The
    building name is matched case-insensitively, consistent with how the
    relationships were loaded. DISTINCT collapses the duplicate rows that appear
    when one counter unit beats several units produced by the same building.

    Args:
        tx: transaction handle supplied by ``session.read_transaction``.
        building: name of the building whose units should be countered.
        race: race whose buildings/units may be used as the counter.

    Returns:
        The result of ``tx.run``; its rows have keys ``build`` and ``unit``.
    """
    # NOTE(review): the caller reads this result after the transaction function has
    # returned (``read_transaction(...).data()``). Newer neo4j drivers expect results
    # to be consumed inside the transaction function; confirm against the driver in use.
    return tx.run("""
    MATCH (b:Building)-[:Produce]->(:Unit)<-[:StrongerThan]-(u2:Unit)<-[:Produce]-(m:Building)
    WHERE toLower(b.name) = toLower($building) and m.race = $race
    RETURN DISTINCT m.name as build, u2.name as unit
    """, building=building, race=race)

In [332]:
with driver.session() as session:
    # Terran answers to every unit the SpawningPool produces.
    val = session.read_transaction(
        get_build_counter,
        'SpawningPool',
        'terran',
    ).data()

In [333]:
# Display the counter rows fetched above: a list of dicts with 'build' and 'unit' keys.
val

[{'build': 'Barracks', 'unit': 'Marine'},
 {'build': 'Factory', 'unit': 'Hellion'},
 {'build': 'Factory', 'unit': 'WidowMine'},
 {'build': 'Armory', 'unit': 'Hellbat'}]