In [145]:
import numpy as np
import itertools
import time
import random
import json
from collections import Counter, defaultdict

In [4]:
def generate_cards(attributes, attr_order):
    
    cards = []
    idx_to_card = {}
    card_to_idx = {}
    attrs_to_idx = defaultdict(lambda: defaultdict(list))

    i = 0
    attr_vals = [attributes[attr] for attr in attr_order]
    for combo in itertools.product(*attr_vals):
        card = tuple(combo)
        cards.append(card)
        card_to_idx[card] = i
        idx_to_card[i] = card
        i += 1
        for attr_val, attr_typ in zip(combo, attr_order):
            attrs_to_idx[attr_typ][attr_val].append(i)
    
    assert len(cards) == len(set(cards))
    print(f'Generated {len(cards)} unqiue cards')                
    return cards, idx_to_card, card_to_idx, attrs_to_idx


cards, idx_to_card, card_to_idx, attrs_to_idx = generate_cards(attributes, attr_order=('color', 'fill', 'shape', 'config'))
cards[:5]

Generated 81 unqiue cards


[('red', 'void', 'square', 'XXO'),
 ('red', 'void', 'square', 'XOX'),
 ('red', 'void', 'square', 'OXX'),
 ('red', 'void', 'circle', 'XXO'),
 ('red', 'void', 'circle', 'XOX')]

In [15]:
def generate_card_pairs(cards, card_to_idx):
    '''
    find all combos of cards, filter down to the ones that share concepts.
    '''
    cardpairs = []

    for card1, card2 in itertools.product(cards, cards):
        if card1 != card2:
            matching_concepts = tuple(s1 if s1==s2 else '-' for s1,s2 in zip(card1,card2))
            if set(matching_concepts) != {'-'}:
                cardpairs.append(((card_to_idx[card1], card_to_idx[card2]), matching_concepts))
    print(f'Generated {len(cardpairs)} unqiue cardpairs')  
    return cardpairs

cardpairs = generate_card_pairs(cards, card_to_idx)
cardpairs[:5]

Generated 5184 unqiue cardpairs


[((0, 1), ('red', 'void', 'square', '-')),
 ((0, 2), ('red', 'void', 'square', '-')),
 ((0, 3), ('red', 'void', '-', 'XXO')),
 ((0, 4), ('red', 'void', '-', '-')),
 ((0, 5), ('red', 'void', '-', '-'))]

In [16]:
def match_concept_to_card(concept, card):
    '''
    Given a concept, determine if card matches. Identify nonConcepts and violatedConcepts
    
    Arguments:
        concept: ('red', 'void', '-', '-')
        card: ex1. ('red', 'void', 'triangle', 'XOX')
              ex2. ('green', 'void', 'square', 'OXX')
    Returns:
        match: bool. ex1. True,
                     ex2. False
        nonConcept:  ex1. ('-', '-', 'triangle', 'XOX')
                     ex2. ('-', '-', 'square', 'OXX')
        violatedConcept:    ex1. ('-', '-', '-', '-')
                     ex2. ('red', '-', '-', '-')
    '''
    nonConcept = []
    violatedConcept = []
    match = True
    for ct, cd in zip(concept, card):
        if ct == '-':
            nonConcept.append(cd)
            violatedConcept.append('-')
        else:
            if cd != ct:
                match = False
                violatedConcept.append(ct)
            else:
                violatedConcept.append('-')
            nonConcept.append('-')

    return match, tuple(nonConcept), tuple(violatedConcept)
    
print(match_concept_to_card(concept=('red', 'void', '-', '-'), card=('red', 'void', 'triangle', 'XOX')))
print(match_concept_to_card(concept=('-', '-', 'triangle', 'XOX'), card=('red', 'void', 'triangle', 'XOX')))
print(match_concept_to_card(concept=('red', 'void', '-', '-'), card=('green', 'void', 'square', 'OXX')))
print(match_concept_to_card(concept=('red', 'void', '-', '-'), card=('green', 'solid', 'square', 'OXX')))

(True, ('-', '-', 'triangle', 'XOX'), ('-', '-', '-', '-'))
(True, ('red', 'void', '-', '-'), ('-', '-', '-', '-'))
(False, ('-', '-', 'square', 'OXX'), ('red', '-', '-', '-'))
(False, ('-', '-', 'square', 'OXX'), ('red', 'void', '-', '-'))


In [17]:
def find_notConcepts(concept, attr_order, attributes):
    '''
    Given a concept, find all its opposites.
    
    Arguments:
        concept: tuple. ex.('red', 'void', '-', '-')
        attr_order: tuple. ex.('colors', 'fills', 'shapes', 'configs')
    Returns:
        nonConcepts: list of tuples. ex.[
                 ('red', 'dashed', '-', '-')
                 ('red', 'solid', '-', '-')
                 ('green', 'void', '-', '-')
                 ('green', 'dashed', '-', '-')
                 ('green', 'solid', '-', '-')
                 ('blue', 'void', '-', '-')
                 ('blue', 'dashed', '-', '-')
                 ('blue', 'solid', '-', '-')                 
        ]
    '''
    notConcepts = set()
    
    # ['colors', 'fills']
    keep_attributes = list(attr for con, attr in zip(concept, attr_order) if con != '-')
    # [['red', 'green', 'blue'], ['void', 'dashed', 'solid'], ['-'], ['-']]
    keep_vals = [attributes[attr] if attr in keep_attributes else ['-'] for attr in attr_order]

    for combo in itertools.product(*keep_vals):
        if combo != concept:
            notConcepts.add(combo)
    
    return notConcepts

notConcepts = find_notConcepts(concept=('red', 'void', '-', '-'), attr_order=('color', 'fill', 'shape', 'config'), attributes=attributes)
notConcepts

{('blue', 'dashed', '-', '-'),
 ('blue', 'solid', '-', '-'),
 ('blue', 'void', '-', '-'),
 ('green', 'dashed', '-', '-'),
 ('green', 'solid', '-', '-'),
 ('green', 'void', '-', '-'),
 ('red', 'dashed', '-', '-'),
 ('red', 'solid', '-', '-')}

In [18]:
def make_negative_cards(notConcepts, nonConcept):
    negative_cards = []
    for notC, nonC in itertools.product(notConcepts, [nonConcept]):
        neg_card = tuple((x+y).strip('-') for x,y in zip(notC, nonC))
        negative_cards.append(neg_card)
    return negative_cards

make_negative_cards(notConcepts, nonConcept=('-', '-', 'triangle', 'XOX'))

[('green', 'dashed', 'triangle', 'XOX'),
 ('blue', 'solid', 'triangle', 'XOX'),
 ('red', 'solid', 'triangle', 'XOX'),
 ('green', 'solid', 'triangle', 'XOX'),
 ('red', 'dashed', 'triangle', 'XOX'),
 ('blue', 'dashed', 'triangle', 'XOX'),
 ('blue', 'void', 'triangle', 'XOX'),
 ('green', 'void', 'triangle', 'XOX')]

In [8]:
def quick_match_concept_to_card(concept, card):
    '''
    Given a concept, determine if card matches. Identify nonConcepts and violatedConcepts
    
    Arguments:
        concept: ('red', 'void', '-', '-')
        card: ex1. ('red', 'void', 'triangle', 'XOX')
              ex2. ('green', 'void', 'square', 'OXX')
    Returns:
        match: bool. ex1. True,
                     ex2. False
    '''
    match = True
    for ct, cd in zip(concept, card):
        if ct != '-' and cd != ct:
            match = False
    return match

In [9]:
def match_concept_to_cards(concept, cards):
    '''
    Given a concept, return subset of cards that match.
    
    Arguments:
        concept: ('red', 'void', '-', '-')
        cards: list of all cards.
    Returns:
        matching_cards: list of cards.
    '''
    matching_cards = []
    for card in cards:
        if quick_match_concept_to_card(concept, card):
            matching_cards.append(card)
    return matching_cards

In [75]:
def gen_keys(cardpair, cards, attr_order, attributes, card_to_idx, num_pos, debug=False):
    
    card_idx_pair, concept = cardpair
    matching_cards = match_concept_to_cards(concept, cards)
    pos_key_card_indices = np.random.choice(len(matching_cards), size=min(num_pos, len(matching_cards)), replace=False)
    
    pos_negs = []
    for pos_key_card_idx in pos_key_card_indices:
        pos_key_card = matching_cards[pos_key_card_idx]
        match_bool, nonConcept, violatedConcept = match_concept_to_card(concept, pos_key_card)
        assert match_bool and set(violatedConcept) == {'-'}
        notConcepts = find_notConcepts(concept, attr_order, attributes)
        negative_key_cards = make_negative_cards(notConcepts, nonConcept)
        negative_key_indices = [card_to_idx[c] for c in negative_key_cards]
        pos_negs.append((card_to_idx[pos_key_card], tuple(negative_key_indices)))
        
        if debug:
            print('#############################')  
            print('cardpair', cardpair)
            print('concept', concept)
            print('pos key card', pos_key_card)
            print('nonConcept', nonConcept)
            print('notConcepts', notConcepts)
            print('negative_key_cards', negative_key_cards)
            print('#############################')  
    return pos_negs

In [152]:
def gen_card_data(attributes, attr_order, num_val=100, num_test=100, num_pos=5, debug=False):
    
    random.seed(42)
    np.random.seed(42)
    
    cards, idx_to_card, card_to_idx, attrs_to_idx = generate_cards(attributes, attr_order)
    cardpairs = generate_card_pairs(cards, card_to_idx)
    
    data = {}
    data['attributes'] = attributes
    data['attr_order'] = attr_order
    data['cards'] = cards
    data['idx_to_card'] = idx_to_card
    data['card_to_idx'] = card_to_idx
    data['attrs_to_idx'] = attrs_to_idx
    data['num_pos'] = num_pos
    
    random.shuffle(cardpairs)
    num_train = len(cardpairs) - num_val - num_test
    data['train_cardpairs'] = cardpairs[:num_train]
    data['valid_cardpairs'] = cardpairs[num_train:num_train+num_val]
    data['test_cardpairs'] = cardpairs[-num_test:]
    
    print('')
    print('Train cardpairs:', len(data['train_cardpairs']))
    print('Valid cardpairs:', len(data['valid_cardpairs']))
    print('Test cardpairs:', len(data['test_cardpairs']))
    assert not(set(data['train_cardpairs']) & set(data['valid_cardpairs']))
    assert not(set(data['valid_cardpairs']) & set(data['test_cardpairs']))
    assert not(set(data['train_cardpairs']) & set(data['test_cardpairs']))
    
    if debug:
        i = 10
        pos_negs = gen_keys(data['test_cardpairs'][i], cards, attr_order, attributes, card_to_idx, num_pos, False)
        print('\n-------Sample Data point for one loss term-------')
        print('test cardpair', idx_to_card[data['test_cardpairs'][i][0][0]], idx_to_card[data['test_cardpairs'][i][0][1]])
        print('concept:',  data['test_cardpairs'][i][1])
        print('\n-------Randomly Drawn +ve. card-------')
        j = 0
        print(idx_to_card[pos_negs[j][0]])
        print('\n-------Its negatives-------')
        for cid in pos_negs[j][1]:
            print(idx_to_card[cid])
            
    # loop through each test_cardpair, gen the set of pos_negs(i.e. max(5) sets of 1positive:3negative cards)
    data['test_pos_negs'] = []
    for cardpair in data['test_cardpairs']:
        data['test_pos_negs'].append(gen_keys(cardpair, cards, attr_order, attributes, card_to_idx, False))
        
    return data

## Generate Data

In [153]:
attributes = {
    'color': ['red', 'green', 'blue'],
    'fill': ['void', 'dashed', 'solid'],
    'shape': ['square', 'circle', 'triangle'],
    'config': ['XXO', 'XOX', 'OXX']
}

attr_order = ['color', 'fill', 'shape', 'config']

In [157]:
data = gen_card_data(attributes, attr_order, num_val=100, num_test=100, debug=True)

Generated 81 unqiue cards
Generated 5184 unqiue cardpairs

Train cardpairs: 4984
Valid cardpairs: 100
Test cardpairs: 100

-------Sample Data point for one loss term-------
test cardpair ('red', 'void', 'triangle', 'OXX') ('green', 'void', 'circle', 'XOX')
concept: ('-', 'void', '-', '-')

-------Randomly Drawn +ve. card-------
('red', 'void', 'triangle', 'OXX')

-------Its negatives-------
('red', 'dashed', 'triangle', 'OXX')
('red', 'solid', 'triangle', 'OXX')


## Sanity check

In [158]:
i = 20
pos_negs = gen_keys(data['test_cardpairs'][i], data['cards'], attr_order, attributes, data['card_to_idx'], False)
print('test_cardpairs[i]', data['idx_to_card'][data['test_cardpairs'][i][0][0]], data['idx_to_card'][data['test_cardpairs'][i][0][1]], data['test_cardpairs'][i][1])
print('--------------')
j = 0
print(data['idx_to_card'][pos_negs[j][0]])
print('--------------')
for cid in pos_negs[j][1]:
    print(data['idx_to_card'][cid])

test_cardpairs[i] ('green', 'void', 'triangle', 'XOX') ('green', 'dashed', 'triangle', 'OXX') ('green', '-', 'triangle', '-')
--------------
('green', 'solid', 'triangle', 'OXX')
--------------
('blue', 'solid', 'square', 'OXX')
('green', 'solid', 'circle', 'OXX')
('red', 'solid', 'triangle', 'OXX')
('red', 'solid', 'square', 'OXX')
('blue', 'solid', 'circle', 'OXX')
('green', 'solid', 'square', 'OXX')
('red', 'solid', 'circle', 'OXX')
('blue', 'solid', 'triangle', 'OXX')


In [159]:
data['test_pos_negs']

[[(80,
   (17,
    71,
    15,
    69,
    35,
    43,
    33,
    79,
    26,
    7,
    53,
    24,
    51,
    44,
    78,
    42,
    8,
    61,
    6,
    16,
    70,
    34,
    62,
    60,
    25,
    52)),
  (77,
   (14,
    68,
    12,
    66,
    32,
    40,
    30,
    76,
    23,
    4,
    50,
    21,
    48,
    41,
    75,
    39,
    5,
    58,
    3,
    13,
    67,
    31,
    59,
    57,
    22,
    49)),
  (74,
   (11,
    65,
    9,
    63,
    29,
    37,
    27,
    73,
    20,
    1,
    47,
    18,
    45,
    38,
    72,
    36,
    2,
    55,
    0,
    10,
    64,
    28,
    56,
    54,
    19,
    46))],
 [(6, (0, 3)), (17, (11, 14)), (62, (56, 59))],
 [(18,
   (0,
    33,
    24,
    51,
    42,
    57,
    66,
    54,
    12,
    63,
    78,
    21,
    30,
    48,
    6,
    9,
    45,
    75,
    39,
    27,
    36,
    60,
    69,
    3,
    72,
    15)),
  (19,
   (1,
    34,
    25,
    52,
    43,
    58,
    67,
    55,
    13,
    64,
    79,
   

In [160]:
len(data['test_pos_negs'])

100

In [161]:
sum([len(x) for x in data['test_pos_negs']])

300