In [1]:
import argparse
import os
import sys
import time
import numpy as np
from PIL import Image

import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import urllib
import time

This notebook follows lucidrains' PyTorch implementation of the model, with some guidance taken from the original authors' code.

https://github.com/lucidrains/ESBN-pytorch

# helper functions

In [2]:
def exists(val):
    """Return True for any value other than None (falsy values such as 0 or '' still count)."""
    return not (val is None)

In [3]:
def safe_cat(t, el, dim=0):
    """Concatenate `el` onto `t` along axis `dim`, treating a None `t` as empty.

    This lets a loop grow a tensor from a None accumulator: the first call
    simply returns `el`, later calls return tf.concat((t, el), axis=dim).
    """
    if exists(t):
        return tf.concat((t, el), axis=dim)
    # Nothing accumulated yet: the new element is the whole result.
    return el

In [4]:
def map_fn(fn, *args, **kwargs):
    """Build a function that applies `fn` to each of its positional arguments.

    Every item is passed as the first argument of `fn`, followed by the fixed
    `*args` and `**kwargs` captured here. The returned callable yields a lazy
    `map` iterator, one result per positional argument, in order.
    """
    def inner(*arr):
        def apply_one(item):
            return fn(item, *args, **kwargs)
        return map(apply_one, arr)
    return inner

# the Class

In [5]:
class ESBN(keras.layers.Layer):
    """Emergent Symbol Binding Network layer (TensorFlow port of lucidrains/ESBN-pytorch).

    An LSTM cell acts as a controller that only ever sees *keys*; a separate
    encoder turns every image into a *value*. At each step the current value is
    compared against all previously stored values, and the resulting attention
    weights retrieve a mix of the previously written keys (plus a confidence
    channel), which becomes the controller's next input.

    Args:
        value_dim: size of the encoder output (memory values).
        key_dim: size of the keys written by the controller.
        hidden_dim: LSTM hidden/cell state size.
        output_dim: size of the per-step output.
        encoder: optional layer/model mapping an image batch to
            (batch, value_dim). When None, a small three-layer CNN is built.
    """

    def __init__(
        self,
        *,
        value_dim = 64,
        key_dim = 64,
        hidden_dim = 512,
        output_dim = 4,
        encoder = None
    ):
        super().__init__()
        # Initial hidden state, cell state and key input.
        # NOTE: these are constant zero tensors here, not trainable weights.
        self.h0 = tf.zeros(hidden_dim)
        self.c0 = tf.zeros(hidden_dim)
        # The controller input is a key plus one confidence channel, hence +1.
        self.k0 = tf.zeros(key_dim + 1)

        # LSTMCell processes a single time step; the loop in `call` drives it.
        self.rnn = tf.keras.layers.LSTMCell(hidden_dim)
        self.to_gate = tf.keras.layers.Dense(1, activation=tf.keras.activations.sigmoid)
        self.to_key = tf.keras.layers.Dense(key_dim)
        self.to_output = tf.keras.layers.Dense(output_dim)

        if exists(encoder):
            # Honour a caller-supplied encoder instead of silently ignoring it.
            self.encoder = encoder
        else:
            self.encoder = tf.keras.Sequential([
                tf.keras.layers.Conv2D(32, kernel_size=4, strides=2,
                                       activation=tf.keras.activations.relu),
                tf.keras.layers.Conv2D(64, kernel_size=4, strides=2,
                                       activation=tf.keras.activations.relu),
                tf.keras.layers.Conv2D(64, kernel_size=4, strides=2,
                                       activation=tf.keras.activations.relu),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(value_dim),
            ])

        self.to_confidence = tf.keras.layers.Dense(1, activation=tf.keras.activations.sigmoid)

    def call(self, images):
        """Run the network over a sequence of image batches.

        Args:
            images: tensor laid out as (seq_len, batch, ...); axis 0 is iterated
                step by step and axis 1 is the batch. The trailing axes must be
                whatever `self.encoder` accepts (the default CNN presumably
                expects channels-last images; confirm against the data loader).

        Returns:
            Tensor of shape (seq_len, batch, output_dim) with one output per step.
        """
        batch = tf.shape(images)[1]
        Mk = None  # memory of written keys, grows to (batch, n, key_dim)
        Mv = None  # memory of encoded values, grows to (batch, n, value_dim)

        def tile_over_batch(vec):
            # (d,) -> (batch, d)
            return tf.tile(tf.expand_dims(vec, 0), [batch, 1])

        hx = tile_over_batch(self.h0)
        cx = tile_over_batch(self.c0)
        kx = tile_over_batch(self.k0)
        out = []

        for ind, image in enumerate(tf.unstack(images, axis=0)):
            z = self.encoder(image)
            # Keras LSTMCell returns (output, [h, c]); keep only the new states.
            _, (hx, cx) = self.rnn(kx, [hx, cx])
            y, g, kw = self.to_output(hx), self.to_gate(hx), self.to_key(hx)

            # On the first step the memory is empty, so kx stays at its initial value.
            if ind > 0:
                # Attention of the current value over all stored values.
                sim = tf.einsum('bnd,bd->bn', Mv, z)
                wk = tf.nn.softmax(sim, axis=-1)

                # (batch, n) -> (batch, n, 1) so both broadcast against the key memory.
                sim = tf.expand_dims(sim, -1)
                wk = tf.expand_dims(wk, -1)
                # Confidence per memory slot; the sigmoid is part of to_confidence.
                ck = self.to_confidence(sim)

                # Append the confidence channel to the stored keys, take the
                # attention-weighted sum over slots, then apply the gate
                # (the sigmoid is already part of to_gate).
                retrieved = tf.reduce_sum(wk * tf.concat([Mk, ck], axis=-1), axis=1)
                kx = g * retrieved

            # (batch, d) -> (batch, 1, d) so entries stack along the slot axis.
            Mk = safe_cat(Mk, tf.expand_dims(kw, 1), dim = 1)
            Mv = safe_cat(Mv, tf.expand_dims(z, 1), dim = 1)
            out.append(y)

        return tf.stack(out)

    def forward(self, images):
        """PyTorch-style alias kept for existing callers; same as `call`."""
        return self.call(images)
                

# Same-Difference

In [6]:
# Settings that the reference script takes as command-line arguments,
# initialised here to that script's default values.

# Model settings
model_name = 'ESBN'
norm_type = 'contextnorm'
encoder = 'conv'
# Task settings
task = 'same_diff'
train_gen_method = 'full_space'
# Misspelled original name, kept as an alias so any cell still using it keeps working.
train_gen_mathod = train_gen_method
n_shapes = 100 # total number of shapes available for training/testing
m_holdout = 0 # number of objects (out of n) withheld during training
# Training settings
train_batch_size = 32
train_set_size = 10000
train_proportion = 0.95
lr = 5e-4
epochs = 50
log_interval = 10
# Test settings
test_batch_size = 100
test_set_size = 10000
# Device settings
no_cuda = False # a 'store_true' flag in the reference script; presumably True disables CUDA - TODO confirm
device = 0
# Run number
run = 1

In [9]:
# Shape ids 0 .. n_shapes - 1, shuffled in place with NumPy's global RNG
# (no seed is set in this cell, so the order differs from run to run).
all_shapes = np.arange(0, n_shapes)
np.random.shuffle(all_shapes)
all_shapes

array([10, 16, 50, 74, 25, 99,  0, 65, 89, 91, 73,  1, 30, 64,  9, 48, 38,
       39, 21, 77, 79, 46, 90, 44, 40, 22, 31, 70, 80, 87, 26, 36, 81, 53,
       35, 69, 83, 17, 78, 75, 66, 32, 52, 33, 47, 57, 49, 51, 94, 14, 98,
       42, 92, 54, 76, 27, 56, 45, 59,  5, 85, 72, 18, 84, 93, 97, 55, 63,
        2, 11, 15, 12, 13, 71, 34, 58, 37, 19, 62, 86,  8, 61, 29, 41, 24,
       23, 60, 43, 88, 28,  3,  6,  7, 82, 67, 96, 20,  4, 95, 68])

In [10]:
# With a holdout, the first m_holdout shuffled shapes are reserved for testing and
# the rest are used for training; without one, both sets use every shape.
train_shapes = all_shapes[m_holdout:] if m_holdout > 0 else all_shapes
test_shapes = all_shapes[:m_holdout] if m_holdout > 0 else all_shapes

In [17]:
# From here on the cells follow the reference script's dataset-generation function.
import random
import sys

import numpy as np

# Keep Python from writing .pyc bytecode files.
sys.dont_write_bytecode = True

# presumably y_dim is the number of answer classes (same / different) and
# seq_len the number of images per trial - TODO confirm against the reference script
y_dim, seq_len = 2, 2

In [14]:
# Work out the train/test set sizes from the number of trials that can be built.
# n shapes give n * (n - 1) ordered pairs of distinct shapes; the count is doubled,
# presumably to account for the equally sized pool of duplicated 'same' trials.
if m_holdout != 0:
    # Disjoint shape sets: cap each requested size at what its own shapes allow.
    shapes_avail = n_shapes - m_holdout
    total_trials = 2 * shapes_avail * (shapes_avail - 1)
    train_set_size = min(train_set_size, total_trials)

    # n_shapes - (n_shapes - m_holdout): only the held-out shapes form test trials.
    shapes_avail = m_holdout
    total_trials = 2 * shapes_avail * (shapes_avail - 1)
    test_set_size = min(test_set_size, total_trials)
else:
    # m = 0: train and test share one shape set, so split all possible trials
    # by train_proportion.
    shapes_avail = n_shapes
    total_trials = 2 * shapes_avail * (shapes_avail - 1)

    test_proportion = 1 - train_proportion

    train_set_size = np.round(train_proportion * total_trials).astype(int)
    test_set_size = np.round(test_proportion * total_trials).astype(int)
            

In [18]:
def _enumerate_trials(shapes):
    """Return (same, diff): every ordered [shape1, shape2] pair from `shapes`,
    separated by whether the two shapes are equal."""
    same, diff = [], []
    for shape1 in shapes:
        for shape2 in shapes:
            bucket = same if shape1 == shape2 else diff
            bucket.append([shape1, shape2])
    return same, diff


if m_holdout == 0:
    # Train and test share the same shapes: enumerate once, shuffle, then split
    # each pool at train_proportion.
    same_trials, diff_trials = _enumerate_trials(train_shapes)
    random.shuffle(same_trials)
    random.shuffle(diff_trials)

    n_same_train = np.round(train_proportion * len(same_trials)).astype(int)
    n_diff_train = np.round(train_proportion * len(diff_trials)).astype(int)
    same_trials_train, same_trials_test = same_trials[:n_same_train], same_trials[n_same_train:]
    diff_trials_train, diff_trials_test = diff_trials[:n_diff_train], diff_trials[n_diff_train:]
else:
    # Train and test use disjoint shapes, so each set is enumerated on its own.
    same_trials_train, diff_trials_train = _enumerate_trials(train_shapes)
    random.shuffle(same_trials_train)
    random.shuffle(diff_trials_train)

    same_trials_test, diff_trials_test = _enumerate_trials(test_shapes)
    random.shuffle(same_trials_test)
    random.shuffle(diff_trials_test)

In [20]:
# Balance the classes: draw 'same' trials uniformly at random, with replacement,
# until there are as many as 'different' trials (train first, then test).
same_trials_train_balanced = [
    same_trials_train[int(np.random.rand() * len(same_trials_train))]
    for _ in range(len(diff_trials_train))
]
same_trials_test_balanced = [
    same_trials_test[int(np.random.rand() * len(same_trials_test))]
    for _ in range(len(diff_trials_test))
]

# Concatenate both classes; target 0 marks a 'same' trial, 1 a 'different' one.
all_train_seq = same_trials_train_balanced + diff_trials_train
all_train_targ = [0] * len(same_trials_train_balanced) + [1] * len(diff_trials_train)
all_test_seq = same_trials_test_balanced + diff_trials_test
all_test_targ = [0] * len(same_trials_test_balanced) + [1] * len(diff_trials_test)

# Shuffle the training trials; one shared permutation keeps sequences and targets aligned.
train_ind = np.arange(len(all_train_seq))
np.random.shuffle(train_ind)
all_train_seq = np.array(all_train_seq)[train_ind]
all_train_targ = np.array(all_train_targ)[train_ind]

# Same for the test trials.
test_ind = np.arange(len(all_test_seq))
np.random.shuffle(test_ind)
all_test_seq = np.array(all_test_seq)[test_ind]
all_test_targ = np.array(all_test_targ)[test_ind]

# Truncate to the requested sizes when they add up to fewer than total_trials.
if total_trials > (train_set_size + test_set_size):
    all_train_seq, all_train_targ = all_train_seq[:train_set_size, :], all_train_targ[:train_set_size]
    all_test_seq, all_test_targ = all_test_seq[:test_set_size, :], all_test_targ[:test_set_size]


In [21]:
# Bundle each split as a dict: 'seq_ind' holds the shape-index pairs, 'y' the targets.
train_set = dict(seq_ind=all_train_seq, y=all_train_targ)
test_set = dict(seq_ind=all_test_seq, y=all_test_targ)

In [27]:
# TODO: continue porting the reference script's train/eval code (resume at its line 197)