# Compute Subgraph Motif

## Imports

In [1]:
from typing import Dict
import signal
from contextlib import contextmanager
from pathlib import Path
import os
import sys
from joblib import Parallel, delayed
from collections import ChainMap
import pickle
import random
from multiprocessing import Pool

In [2]:
import pymfinder
import pymfinder.mfinder.mfinder as cmfinder

In [3]:
sys.path.append(str(Path(os.path.abspath("")).parent))

import config
from dataset import load_ids
from dataset import _process_problem
from config import NODE_TYPE

## Constants

In [4]:
# Passed to pymfinder as `motifsize`.
MOTIFSIZE = 3
# Passed to pymfinder as `nrandomizations`.
NRADOMGRAPHS = 100
# Time limit (seconds) given to time_limit() around each motif computation.
TIME_LIMIT = 100

# Problems whose file has more than this many lines are filtered out.
MAX_PROB_LEN = 10
# Number of problem ids drawn in the sampling step.
NO_PROBLEMS = 1000

In [5]:
# Path of the id file handed to load_ids(); the commented-out line is an
# alternative id file.
#ID_FILE = '../id_files/deepmath.txt'
ID_FILE = '../id_files/train.txt'

# Load and process problems

In [6]:
# Load the problem ids from ID_FILE via the project's load_ids helper and
# report how many there are.
ids = load_ids(ID_FILE)
print('Number of problems', len(ids))

Number of problems 22179


In [7]:
def filter_problems_by_formula_size(ids, max_len, problem_dir=None):
    """Return the ids whose problem file has at most ``max_len`` lines.

    Args:
        ids: Iterable of problem file names, relative to ``problem_dir``.
        max_len: Maximum number of lines a problem file may have to be kept.
        problem_dir: Directory containing the problem files. Defaults to
            ``config.PROBLEM_DIR`` when omitted.

    Returns:
        A list with the ids that passed the filter, in their input order.

    Raises:
        OSError: If a problem file cannot be opened.
    """
    if problem_dir is None:
        problem_dir = config.PROBLEM_DIR

    res = []
    for prob_id in ids:
        prob_len = 0
        # Binary mode: only the number of lines matters, so skip decoding.
        with open(os.path.join(problem_dir, prob_id), 'rb') as f:
            for _ in f:
                prob_len += 1
                # No need to read the rest of a file that is already too long.
                if prob_len > max_len:
                    break

        if prob_len <= max_len:
            res.append(prob_id)

    return res

# Keep only the problems whose file has at most MAX_PROB_LEN lines.
ids = filter_problems_by_formula_size(ids, MAX_PROB_LEN)

In [8]:
# Debugging aid: uncomment to restrict the run to the first five problems.
#ids = ids[:5] # FIXME
# Report how many problem ids are currently held in `ids`.
print('Number of problems', len(ids))

Number of problems 8600


### Sample the problems

In [9]:
# Fixed seed so the same subset of problems is drawn on every run.
random.seed(7)
# Draw WITHOUT replacement. random.choices() samples with replacement, which
# produces duplicate ids: those are analysed more than once and then collapse
# into a single key when the per-problem results are merged into one dict.
# min() avoids a ValueError when fewer than NO_PROBLEMS ids are available.
# NOTE(review): this changes which problems are selected for seed 7 compared
# to the previous random.choices() call.
ids = random.sample(ids, k=min(NO_PROBLEMS, len(ids)))

## Helper

In [10]:
class TimeoutException(Exception):
    """Raised inside a ``time_limit`` block when the time budget runs out."""


@contextmanager
def time_limit(seconds):
    """Abort the enclosed block with ``TimeoutException`` after ``seconds``.

    Implemented with ``SIGALRM``, so it is only available on Unix and must be
    used from the main thread of the process.

    Args:
        seconds: Time budget in seconds. May be fractional; ``0`` disables
            the timer.

    Raises:
        TimeoutException: Raised in the enclosed block when the timer fires.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # Remember whatever handler was installed before so it can be put back.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    try:
        # ITIMER_REAL delivers SIGALRM like signal.alarm(), but also accepts
        # sub-second values.
        signal.setitimer(signal.ITIMER_REAL, seconds)
        yield
    finally:
        # Disarm the timer first so it cannot fire after the handler is reset.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # signal.signal() returns None when the previous handler was not
        # installed from Python; there is nothing to restore in that case.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)

In [11]:
def compute_motifs(network, motifsize, nrandom_graphs) -> Dict[int, dict]:
    """Run pymfinder on ``network`` and summarise every motif it reports.

    Args:
        network: Network description forwarded unchanged to
            ``pymfinder.pymfinder``.
        motifsize: Forwarded as pymfinder's ``motifsize`` argument.
        nrandom_graphs: Forwarded as pymfinder's ``nrandomizations`` argument.

    Returns:
        A dict keyed by motif id. Each value is a dict with the entries
        ``'count'`` (the motif's ``real`` attribute) and ``'z-score'`` (its
        ``real_z`` attribute).
    """
    res = pymfinder.pymfinder(network, motifsize=motifsize, nrandomizations=nrandom_graphs, links=False)

    return {
        motif_id: {'count': motif.real, 'z-score': motif.real_z}
        for motif_id, motif in res.motifs.items()
    }

In [12]:
def analyse_graph(prob):
    """Compute the motif statistics of a single problem.

    The problem id is printed, the problem is loaded with
    ``_process_problem`` and its ``edge_index`` is handed to
    ``compute_motifs`` under a ``TIME_LIMIT`` budget.

    Args:
        prob: Id of the problem to analyse.

    Returns:
        A single-entry dict ``{prob: outcome}`` where ``outcome`` is the
        result of ``compute_motifs``, or ``-1`` if the computation timed out
        or raised ``SystemError``.
    """
    print(prob)

    graph = _process_problem(prob, config.PROBLEM_DIR, remove_argument_node=False)
    edge_list = graph.edge_index.T.numpy().tolist()

    try:
        with time_limit(TIME_LIMIT):
            outcome = compute_motifs(edge_list, MOTIFSIZE, NRADOMGRAPHS)
    except (TimeoutException, SystemError):
        # Per the original author: mfinder throws a SystemError when it is
        # terminated on timeout. -1 marks the problem as "no result".
        outcome = -1

    return {prob: outcome}

In [13]:
# The network motifs
#pymfinder.print_motifs(3, links=True)

In [14]:
'''
res = []
for i in ids:
    print(i)
    r = analyse_graph(i)
    res.append(r)
    
res
#'''

'\nres = []\nfor i in ids:\n    print(i)\n    r = analyse_graph(i)\n    res.append(r)\n    \nres\n#'

In [15]:
# Leave two cores free, but always request at least one worker:
# os.cpu_count() may return None, and on a two-core machine the plain
# subtraction yields n_jobs=0, which joblib rejects.
n_jobs = max((os.cpu_count() or 1) - 2, 1)
result = Parallel(n_jobs=n_jobs)(delayed(analyse_graph)(i) for i in ids)
data = dict(ChainMap(*result))  # Merge the list of single-entry dicts into one dict

In [16]:
# Run with multiprocessing instead, as I trust it more

In [17]:
'''
star_args = [(i,) for i in ids]
workers = max(os.cpu_count() - 2, 3)
pool = Pool(max(os.cpu_count() - 2, 3))
res = pool.starmap(analyse_graph, star_args)
pool.close()
pool.join()
'''

'\nstar_args = [(i,) for i in ids]\nworkers = max(os.cpu_count() - 2, 3)\npool = Pool(max(os.cpu_count() - 2, 3))\nres = pool.starmap(analyse_graph, star_args)\npool.close()\npool.join()\n'

In [18]:
# Persist the merged results; the file name records the motif size, the number
# of randomizations and the time limit used for this run.
with open(f'graph_motif_{MOTIFSIZE}_nrandom_{NRADOMGRAPHS}_timelimit_{TIME_LIMIT}.pickle', 'wb') as handle:
    pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [19]:
#data