In [2]:
import json
import shutil
import timeit
from collections import defaultdict
import traceback

from cognitive.auto_task.arguments import get_args
from cognitive import task_generator as tg
from cognitive import constants as const
from cognitive import stim_generator as sg
from cognitive import info_generator as ig
from cognitive.auto_task.auto_task_util import *

import numpy as np
import random
import networkx as nx
from tqdm import tqdm
from PIL import Image
import os
import matplotlib.pyplot as plt
from networkx.drawing.nx_pydot import graphviz_layout

from typing import Tuple, Union


In [None]:
def _reset_dir(path):
    """Delete ``path`` if it exists, then recreate it as an empty directory."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


def _load_temporal_task(task_folder):
    """Rebuild the ``tg.TemporalTask`` stored in ``<task_folder>/temporal_task.json``.

    The JSON's ``operator`` entry is deserialized with ``tg.load_operator_json``;
    ``n_frames``, ``first_shareable`` and ``whens`` are passed through unchanged.
    """
    # To also reconstruct the task graph, uncomment:
    # labels, adj = os.path.join(task_folder, 'node_labels'), os.path.join(task_folder, 'adj_dict')
    # with open(labels, 'rb') as h:
    #     labels = json.load(h)
    # with open(adj, 'rb') as h:
    #     adj = json.load(h)
    # g = nx.from_dict_of_dicts(adj, create_using=nx.DiGraph)
    # g = nx.relabel_nodes(g, labels)
    # print(sorted(g))
    task_json_fp = os.path.join(task_folder, 'temporal_task.json')
    with open(task_json_fp, 'rb') as h:
        task_dict = json.load(h)
    return tg.TemporalTask(
        operator=tg.load_operator_json(task_dict['operator']),
        n_frames=task_dict['n_frames'],
        first_shareable=task_dict['first_shareable'],
        whens=task_dict['whens'],
    )


def _generate_trials(args):
    """Write ``args.n_trials`` trials for every sub-folder of ``args.task_dir``.

    Trial ``i`` of a task goes into ``<task_folder>/trial_<i>``, which is wiped
    and recreated first. A folder that raises is reported (with its path and
    traceback) and skipped, so one broken task does not abort the whole run.

    Raises:
        ValueError: if ``args.task_dir`` is not an existing directory.
    """
    import sys

    task_dir = args.task_dir
    if not os.path.isdir(task_dir):
        raise ValueError('Task Directory not found')
    start = timeit.default_timer()
    # os.scandir yields entries in arbitrary order; sort so runs are reproducible
    task_folders = sorted(entry.path for entry in os.scandir(task_dir) if entry.is_dir())
    for task_folder in task_folders:
        try:
            temporal_task = _load_temporal_task(task_folder)
            for i in range(args.n_trials):
                instance_fp = os.path.join(task_folder, f'trial_{i}')
                _reset_dir(instance_fp)
                write_trial_instance(temporal_task, instance_fp, args.img_size, args.fixation_cue)
        except Exception:
            # best-effort batch job: say which folder failed, then keep going
            print(f'Failed to generate trials for {task_folder}:', file=sys.stderr)
            traceback.print_exc()
    stop = timeit.default_timer()
    print('Time taken to generate trials: ', stop - start)


def _generate_task(args):
    """Randomly build one task graph together with its task object.

    Starts from a single generated subtask. Then, up to ``args.max_switch``
    times and each with probability ``args.switch_threshold``, the current task
    is wrapped in a switch: a new conditional graph (rooted at a random op from
    ``boolean_ops``) selects between the current task and a freshly generated
    one, with the if/else roles assigned by a fair coin flip.

    Returns:
        ``(subtask_graph, subtask)`` as consumed by ``write_task_instance``.
    """
    def new_graph(count, **kwargs):
        # every graph shares the same size limits taken from the CLI args
        return subtask_graph_generator(count=count, max_op=args.max_op, max_depth=args.max_depth,
                                       select_limit=args.select_limit, **kwargs)

    subtask_graph = new_graph(0)
    subtask = tg.subtask_generation(subtask_graph)
    # element [2] of a generated graph tuple is used as the last node index;
    # the next graph starts numbering after it (presumably so node ids stay
    # unique once graphs are merged by switch_generator -- confirm)
    count = subtask_graph[2] + 1
    for _ in range(args.max_switch):
        if random.random() < args.switch_threshold:  # add a switch
            new_task_graph = new_graph(count)
            count = new_task_graph[2] + 1

            conditional = new_graph(count, root_op=random.choice(boolean_ops))
            conditional_task = tg.subtask_generation(conditional)
            count = conditional[2] + 1

            if random.random() < 0.5:
                do_if, do_if_task = subtask_graph, subtask
                do_else = new_task_graph
                do_else_task = tg.subtask_generation(do_else)
            else:
                do_if = new_task_graph
                do_if_task = tg.subtask_generation(do_if)
                do_else, do_else_task = subtask_graph, subtask
            subtask_graph = switch_generator(conditional, do_if, do_else)
            count = subtask_graph[2] + 1
            subtask = tg.switch_generation(conditional_task, do_if_task, do_else_task)
    return subtask_graph, subtask


def _generate_tasks(args):
    """Generate ``args.n_tasks`` random tasks into ``<args.output_dir>/<i>``.

    Each task directory is wiped and recreated before the task is written.
    """
    start = timeit.default_timer()
    # TODO: check for duplicated tasks by comparing task graphs
    for i in range(args.n_tasks):
        fp = os.path.join(args.output_dir, str(i))
        _reset_dir(fp)
        subtask_graph, subtask = _generate_task(args)
        # TODO: some guess objset error where ValueError occurs
        # write_instance(subtask_graph, subtask, fp, args.img_size, args.n_trials)
        write_task_instance(subtask_graph, subtask, fp)
    stop = timeit.default_timer()
    print('Time taken to generate tasks: ', stop - start)


if __name__ == '__main__':
    args = get_args()
    print(args)

    const.DATA = const.Data(dir_path=args.stim_dir)

    # with a task directory, (re)generate trials for existing tasks;
    # otherwise generate brand-new tasks
    if args.task_dir:
        _generate_trials(args)
    else:
        _generate_tasks(args)
