In [1]:
import numpy as np 
import pickle
import time 

In [2]:
data_dir = 'data/20240308_data/'

# (40, 12000, 3, 9, 3) = (num_rules, num_samples, num_panels, num_pos, num_attributes)
attr_all = np.load(data_dir + 'attr_all.npy')


def _load_pickle(path):
    """Unpickle and return the object stored at `path`.

    NOTE: unpickling can execute arbitrary code -- only use on trusted files.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Rule lookup tables consumed by check_rule: each maps a key string built from
# the first two panels to a {rule: expected third-panel value} mapping.
r_dict_M7_withR = _load_pickle(data_dir + 'r_dict_M7_withR.pkl')
r_dict_M10_withR = _load_pickle(data_dir + 'r_dict_M10_withR.pkl')
r_dict_num_withR = _load_pickle(data_dir + 'r_dict_num_withR.pkl')
r_dict_pos_withR = _load_pickle(data_dir + 'r_dict_pos_withR.pkl')
    
def get_key_first2(a_list):
    """
    Build the lookup key from the first two sublists of `a_list`.

    Each of the first two sublists is turned into a string by concatenating
    str() of its elements; the pieces are joined with '-'. Anything after the
    second sublist is ignored.
    e.g., a_list = [[1,2], [1]] -> key='12-1'

    Note: elements are concatenated without a separator, so multi-digit values
    would make keys ambiguous (e.g. [[1, 12]] and [[11, 2]] both give '112').
    """
    pieces = []
    for sub in a_list[:2]:
        pieces.append(''.join(str(item) for item in sub))
    return '-'.join(pieces)

def get_attr(attr):
    """
    Summarise each panel of one row.

    attr: (3, 9, 3) array -- (panels, positions, attributes). A position is
    treated as occupied only when none of its attribute entries equals -1.

    Returns five length-3 lists (x_num, x_pos, x_shape, x_color, x_size),
    indexed by panel:
        x_num[i]   -- [number of occupied positions]
        x_pos[i]   -- indices of the occupied positions (ascending)
        x_shape[i] -- sorted unique values of attribute column 0 over occupied positions
        x_color[i] -- sorted unique values of attribute column 1 over occupied positions
        x_size[i]  -- sorted unique values of attribute column 2 over occupied positions
    """
    x_num = [None] * 3
    x_pos = [None] * 3
    x_shape = [None] * 3
    x_color = [None] * 3
    x_size = [None] * 3

    for idx, panel in enumerate(attr):
        # Keep only positions whose attribute triple contains no -1.
        occupied = np.flatnonzero((panel == -1).sum(axis=1) == 0)
        x_pos[idx] = occupied.tolist()
        x_num[idx] = [len(occupied)]

        filled = panel[occupied]
        x_shape[idx] = np.unique(filled[:, 0]).tolist()
        x_color[idx] = np.unique(filled[:, 1]).tolist()
        x_size[idx] = np.unique(filled[:, 2]).tolist()

    return x_num, x_pos, x_shape, x_color, x_size

def check_rule(x, r_dict_withR):
    """
    Find the rule (if any) that explains the third panel of a row.

    x: per-panel value lists for one attribute; the first two entries form the
       lookup key via get_key_first2, and x[2] is compared against each
       candidate completion.
    r_dict_withR: maps key string -> {rule: expected third-panel value}.

    Returns the first rule (in dict iteration order) whose expected value
    equals x[2], or None if the key is unknown or nothing matches.
    """
    lookup_key = get_key_first2(x)
    if lookup_key not in r_dict_withR:
        return None
    candidates = r_dict_withR[lookup_key]
    return next(
        (rule for rule, expected in candidates.items() if x[2] == expected),
        None,
    )

def check_row_rule(attr):
    """
    Return the list of rule indices that one row satisfies.

    attr: (3, 9, 3) -- the 3 panels of a single row.

    Each attribute is checked against its own lookup table, in this order:
    shape (index reported as-is), color (+10), size (+20), number (+30),
    position (+37). Attributes with no matching rule are skipped.
    NOTE(review): the offsets presumably map each attribute's local rule id
    into the global rule index of attr_all's first axis -- confirm.
    """
    x_num, x_pos, x_shape, x_color, x_size = get_attr(attr)

    # (per-panel values, lookup table, offset; None = append unchanged)
    checks = [
        (x_shape, r_dict_M7_withR, None),
        (x_color, r_dict_M10_withR, 10),
        (x_size, r_dict_M10_withR, 20),
        (x_num, r_dict_num_withR, 30),
        (x_pos, r_dict_pos_withR, 37),
    ]

    found = []
    for values, table, offset in checks:
        rule = check_rule(values, table)
        if rule is None:
            continue
        found.append(rule if offset is None else rule + offset)
    return found

from collections import defaultdict
def check_rule_overlap(attr_list):
    """
    Find the rules shared across the rows of one sample.

    Inputs:
        attr_list: (3, 3, 9, 3), (3 rows, 3 panels, 9 pos, 3 attr)
    Outputs:
        r3: list of rules appearing in all 3 rows
        r2: list of rules appearing in exactly 2 of the 3 rows
        Both lists are ordered by first appearance (row order, then the order
        in which check_row_rule reports rules within a row).
    """
    # Rule list for each row, e.g., [[0, 1], [0, 1], [0, 2]].
    rule_all = [check_row_rule(a) for a in attr_list]

    # key = rule index, value = number of rows containing it, e.g. {0: 3, 1: 2, 2: 1}.
    # defaultdict(int) already starts every count at 0, so no manual seeding.
    r_dict = defaultdict(int)
    for rule in rule_all:
        # dict.fromkeys de-duplicates within a row (preserving order) so a rule
        # listed twice by one row still counts as one row.
        for x in dict.fromkeys(rule):
            r_dict[x] += 1

    r3 = [k for k, v in r_dict.items() if v == 3]
    r2 = [k for k, v in r_dict.items() if v == 2]
    return r3, r2

def check_r3_r2_batch(attr_sample):
    """
    Run check_rule_overlap on every sample and collect the results.

    Inputs:
        attr_sample: e.g., (4000, 3, 3, 9, 3), (num_samples, 3 rows, 3 panels, 9 pos, 3 attrs)
    Outputs:
        r3_all: list with, per sample, the rules appearing in all 3 rows
        r2_all: list with, per sample, the rules appearing in 2 of the 3 rows

    TODO: multiple CPU cores? concurrent.futures
    """
    per_sample = [check_rule_overlap(sample) for sample in attr_sample]
    r3_all = [r3 for r3, _ in per_sample]
    r2_all = [r2 for _, r2 in per_sample]
    return r3_all, r2_all

In [3]:
np.shape(attr_all)  # (num_rules, num_samples, num_panels, num_pos, num_attributes)

(40, 12000, 3, 9, 3)

In [4]:
# Samples for rule index 0, regrouped so every 3 consecutive rows form one sample:
# (num_samples, 3 panels, 9 pos, 3 attrs) -> (num_samples // 3, 3 rows, 3 panels, 9 pos, 3 attrs)
attr_sample = np.reshape(attr_all[0], (-1, 3, 3, 9, 3))

In [5]:
np.shape(attr_sample)  # (num_samples, 3 rows, 3 panels, 9 pos, 3 attrs)

(4000, 3, 3, 9, 3)

In [6]:
# Time the rule check over all samples. perf_counter is the monotonic,
# high-resolution clock meant for measuring intervals; unlike time.time()
# it cannot be skewed by system clock adjustments.
start_time = time.perf_counter()
r3_all, r2_all = check_r3_r2_batch(attr_sample)
dur = time.perf_counter() - start_time

In [7]:
dur  # elapsed seconds for check_r3_r2_batch(attr_sample)

1.3459701538085938