It has been observed in prior work that tasks sharing some underlying structure will exhibit overlapping neural activity. For instance, Yang et. al. (2019) trained RNNs on several cognitive tasks and observed clustering in the neural activity space: some clusters specialized to particular tasks, while others were shared between tasks. In theory, it is also possible to observe a completely distributed representation (i.e. no modular clusters). While Yang focused on sensory tasks, we aim to study tasks involving abstract relations: transitive inference and divisibility. These tasks are likely to have some common underlying structure, as both represent transitive relations. We will compare the neural geometry of the same RNN trained on one of these tasks at a time to that trained on both (using interleaving).

Questions we hope to answer: How will the neural representation of a given task change when more than one task is learned simultaneously? In the latter case, will we find that the activations shared between the two tasks are also present in some form when only one task is learned at a time? That is—does a neural network organize its activity differently when related tasks must be learned together? We will use RDM analysis and dimensionality reduction techniques to look for specialized clusters in neural activity space. We will then compare our networks using RSA/RDA, as well as dynamics-based methods such as DSA and fixed/slow point analysis.


### TO-DOs
- Fix the Loss Calculation for masked output from RNN
- Fix the data generation and training logistics (currently no shuffling)
- Analysis

### Imports

In [1]:
# General
import numpy as np
#import pandas as pd
#from scipy.stats import zscore
import random
#from statistics import mean

# Deep learning
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# # Response visualizations
# !pip install umap-learn
# import umap
#import matplotlib as mpl
#from matplotlib import pyplot as plt
#import seaborn as sns

# Set random seeds for reproducibility
np.random.seed(12)
torch.manual_seed(12)

<torch._C.Generator at 0x111c97af0>

### Tasks Example

In [2]:
# Transitive Inference: (i.e. index 0 is greater than index 1 = "A > B")
# Example stimulus: [[1, 0, 0],[0, 1, 0]]
# Example output: 0

# Subset Inclusion: (i.e. index 0 is greater than index 1 = "A \contains B")
# Example stimulus: [[1, 1, 1],[0, 1, 0]]
# Example output: 0 = "A \contains B"

# Divisibility:
# Example stimulus: [6,3]
# Example output: 0 = "A is divisble by B"

In [None]:
[1,0,0] < [1,0,0]

False

----
### RNN Task Design and Data Generation

In [2]:
# @title Task design
def int2bits(n, width):
    '''
    Convert an Integer to Its Binary Representation as a List/Array/Tensor
    e.g. n = 3, width = 4 -> [0, 0, 1, 1]
    '''
    return [int(b) for b in bin(n)[2:].zfill(width)] # bin() returns a str with prefix '0b'
    

def make_stimulus_dict(input_length, task='ti'):
    stimulus_order_dict = []

    if task == 'ti':
      for index in range(input_length):
        stimulus_key = [0]*input_length
        stimulus_key[index] = 1
        stimulus_order_dict.append(stimulus_key)
      random.shuffle(stimulus_order_dict)

    # elif task == "si":
    #   for index in range(input_length):
    #     stimulus_key = [0]*input_length


    return stimulus_order_dict

def generate_RNN_trial(trial_length, item_dictionary, item_length, output_size):
  items = item_dictionary.copy()
  sorting_func = lambda x : item_dictionary.index(x)
  random.shuffle(items)
  inputs = []
  inputs_flat = []
  outputs = []
  for i in range(trial_length):
    two_samples = [0,0]
    while two_samples[0] == two_samples[1]:
      two_samples = random.sample(items,2)
    inputs.extend([two_samples])
    inputs_flat.extend(two_samples)
    inputs_sorted = [sorted(inputs_flat, key=sorting_func)]
    for i in inputs_sorted:
      current_size = len(i)
      dummy_vec = [0]*item_length
      i.extend([dummy_vec]*(output_size-current_size))
    outputs.extend(inputs_sorted)

  inputs = torch.tensor(inputs).float()
  outputs = torch.tensor(outputs).float()

  return inputs, outputs

def get_RNN_batch(num_trials, trial_length, item_dictionary, item_length, pad=False):
  if pad:
    output_size = trial_length * 2
  else: output_size = 0

  trials = []
  for n in range(num_trials):
    trials.append(generate_RNN_trial(trial_length, item_dictionary, item_length, output_size))
  return trials

In [3]:
def get_batch(batch_size, task, stimulus_dict=None, input_length=3, mode='train'):

  # Helpers
  def get_binary_vec(input_length):
    return [random.randint(0, 1) for _ in range(input_length)]

  def make_stimulus(input_length, task='ti'):
    if task == 'ti':
      if mode == 'test':
        dict_indices = [0,0]
        while abs(dict_indices[0] - dict_indices[1]) < 2:
          dict_indices = random.sample(range(input_length), 2)
          random.shuffle(dict_indices)
      elif mode == 'train':
        temp_idx = random.sample(range(input_length-1), 1)
        dict_indices = [temp_idx[0], temp_idx[0]+1]
        random.shuffle(dict_indices)

      stimulus = [stimulus_dict[i] for i in dict_indices]

      if dict_indices[0] < dict_indices[1]: # If the first item is greater, it appears earlier in the list
        target = 0
      else:
        target = 1

    elif task == 'si':
      stimulus = [0,0]
      if mode == 'test':
        # while stimulus[0] == stimulus[1]:
        #   stimulus = [get_binary_vec(input_length), get_binary_vec(input_length)]
        first_vec = get_binary_vec(input_length)
        while stimulus[0] == stimulus[1]:
          temp_indices = random.sample(range(input_length), random.randint(2, input_length))
          second_vec = first_vec.copy()
          for i in temp_indices:
            if second_vec[i] == 1:
              second_vec[i] = 0
            # else:
            #   second_vec[i] = 1
          stimulus = [first_vec, second_vec]
      elif mode == 'train':
        first_vec = get_binary_vec(input_length)
        temp_idx = random.sample(range(input_length), 1)
        second_vec = first_vec.copy()
        while stimulus[0] == stimulus[1]:
          if second_vec[temp_idx[0]] == 1:
            second_vec[temp_idx[0]] = 0
          else:
            second_vec[temp_idx[0]] = 1
          stimulus = [first_vec, second_vec]

      random.shuffle(stimulus)
      if stimulus[0] > stimulus[1]: # If first binary number is greater, it is the superset
        target = 0
      else:
        target = 1

    # elif task == "div":
    #   stimulus = [random.sample(range(0, 100, 2),1)[0], random.sample(range(0, 100, 2),1)[0]]
    #   if stimulus[0] % stimulus[1] == 0:
    #     target = 0
    #   else:
    #     target = 1

    return stimulus, target

  stimuli = []
  labels = []
  for _ in range(batch_size):
    stimulus, target = make_stimulus(input_length, task)
    stimuli.append(stimulus)
    labels.append(target)

  stimuli = torch.tensor(stimuli).float()
  labels = torch.tensor(labels).float()

  if torch.cuda.is_available():
    stimuli = stimuli.to('cuda')
    labels = labels.to('cuda')

  return stimuli, labels

----
## Model Architecture

### Base Backbones

In [4]:
# 1. MLP
class MLP(nn.Module):
    '''
    MLP with Leaky ReLU as Intermediate Activations
    input_dim  :
    hidden_dim :
    output_dim :
    num_layers :
    dp_rate    : dropout rate
    '''
    def __init__(self, input_dim, hidden_dim, output_dim, 
                 num_layers = 2, dp_rate = 0.1):
        super(MLP, self).__init__()
        # Hyperparameters
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        
        c_in, c_out = input_dim, hidden_dim
        layers = []
        for lid in range(num_layers - 1):
            layers += [nn.Linear(c_in, c_out),
                       nn.LeakyReLU(inplace = True),
                       nn.Dropout(dp_rate)]
            c_in = hidden_dim
        layers += [nn.Linear(c_in, output_dim)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


# 2. Simple RNN
class SimpleRNN(nn.Module):
    '''
    RNN in its Simplest Form using nn.RNN
    Activation Choices: 'relu' or 'tanh'

        - input_dim  : input feat dim
        - hidden_dim : 
        - output_dim : 
        - num_layers : number of RNN units
    '''
    def __init__(self, input_dim, hidden_dim, output_dim = None, 
                 num_layers = 2, activation = 'relu'):
        super(SimpleRNN, self).__init__()
        # Hyperparams
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        
        self.rnn = nn.RNN(input_dim, hidden_dim, num_layers,
                          batch_first = True, nonlinearity = activation)
        self.readout = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, h = None):
        '''
        Shape
        Inputs:
            - x      : [batch_size, seq_len, input_dim]
            - h      : [num_layers, batch_size, hidden_dim]
        ----
        Intermediate:
            - out    : [batch_size, seq_len, hidden_dim] # taking the last sequence in the out as the final output pred
            - hidden : [num_layers, batch_size, hidden_dim]
        ----
        Outputs:
            - output : [batch_size, output_dim] 
            - hidden : [num_layers, batch_size, hidden_dim] 
        '''
        out, hidden = self.rnn(x, h)
        # Return pred of each seq in the batch
        output = self.readout(out[:, -1,:]) 
        return output, hidden

In [15]:
class OrderPredRNN(nn.Module):
    '''
    Task specific RNN
    Activation Choices: 'relu' or 'tanh'

        - input_dim  : input feat dim
        - hidden_dim : 
        - n_pairs    : 
        - num_layers : number of RNN units
    '''
    def __init__(self, input_dim, hidden_dim, n_pairs, 
                 num_layers = 2, activation = 'relu'):
        super(OrderPredRNN, self).__init__()
        # Hyperparams
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_pairs = n_pairs
        self.num_layers = num_layers
        
        self.rnn = nn.RNN(input_dim, hidden_dim, num_layers,
                          batch_first = True, nonlinearity = activation)
        self.readout = nn.Linear(hidden_dim, 4 * n_pairs * n_pairs)

    def forward(self, x, h = None):
        '''
        Shape
        Inputs:
            - x      : [batch_size, n_pairs, input_dim]
            - h      : [num_layers, batch_size, hidden_dim]
        ----
        Intermediate:
            - out    : [batch_size, n_pairs, hidden_dim] # k-th vector in the n_pair dimension corresponds to information for the first k-pairs of inputs
            - hidden : [num_layers, batch_size, hidden_dim]
        ----
        Outputs:
            - output : [batch_size, n_pairs, 2n_pairs, 2*n_pairs ] each matrix in n_pairs dim is a 2*n_pairs x 2*n_pairs matrix indicating the probability of each element to be present at each position 
            - hidden : [num_layers, batch_size, hidden_dim] Abandoned
        '''
        out, hidden = self.rnn(x, h)
        # Return pred of each seq in the batch
        output = self.readout(out)  # everything 
        
        output = output.view(-1, self.n_pairs, 2*self.n_pairs, 2*self.n_pairs)
        output = F.softmax(output, dim = -1)
        return output

```
output, hidden = RNN(input)_torch
input (seq_len, feat_dim) : (2, 1)
output (seq_len, hidden_dim)

hidden: [num_rnn_unit, batch_size, hidden_dim = 10]

readout(hidden): [num_rnn_unit, batch_size, output_dim] 

output(-1, hidden_dim)

RNN - Fibbonaci 
(1, 1) -> [1, 1, (2)]
(1, 2) -> [1, 1, 2, (3)]
============
0: (1, 2) - > [1, 2, 0, 0, 0, 0] 
1: (3, 2) - > [1, 2, 2, 3, 0, 0]
2: (4, 7) - > [1, 2, 2, 2, 3, 7]


window_size = 2
batch_size = 3
expected_output_size = 6
readout(hidden): [num_rnn_unit, batch_size = 3, output_dim = 1] 
```


----
### Hybrid Model

In [31]:
class MyModel(nn.Module):
    '''
    A Higher Level Model for Three Tasks that Administrates 
      - Trunnk Part to Generate Common Hidden Representations
      - Multiple Downstreaming Readout Heads

      Trunk Net : Produces the relation for each input pair as the context input
      RNN       : input(original_inputs, hidden = output of trunk_net)

      ##### Trunk Net's output dimension needs to be an even number for RNN to work
      ##### Requires Handling of Hidden States for RNN 
      #####
    '''
    def __init__(self, trunk_net, hidden_dim, n_pairs, output_dim = 1,
                num_layers = 2):
        '''
        
        '''
        super(MyModel, self).__init__()
        # Parameters
        self.hidden_dim = hidden_dim # Even number
        self.output_dim = output_dim
        self.n_pairs = n_pairs
        self.num_layers = num_layers
        
        self.trunk = trunk_net

        self.fc1 = nn.Linear(n_pairs * trunk_net.output_dim, trunk_net.output_dim)
            
        self.head_mlp = MLP(trunk_net.output_dim, hidden_dim, output_dim = output_dim, num_layers = num_layers)
        self.head_rnn = OrderPredRNN(trunk_net.input_dim, hidden_dim = trunk_net.output_dim, 
                                     n_pairs = n_pairs, num_layers = num_layers) ###

    def forward(self, x, task_type = 0): # pred = model(x, task_type = 1) to use the FF
        '''
        ### Task Type definition:
            - 0 : 
            - 1 :  other types FF

            - x : [batch_size, n_pairs, 2 * feat_dim = trunk_net.input_dim]

            -out: [batch_size, n_pairs, 2*n_pairs, 2*n_pairs] for task 0
                : [batch_size, n_pairs, output_dim] for task 1
        '''
        out_trunk = self.trunk(x) # [batch_size, n_pairs, hidden_dim = trunk_net.output_dim]
        batch_size = x.shape[0]
        if 0 == task_type:
            # Prepare the trunk output as the hidden state/context, [num_layers, batch_size, trunk_net.output_dim]
            out_trunk = self.fc1(out_trunk.reshape(batch_size, -1))
            hidden = out_trunk.expand((self.num_layers, batch_size, -1))
            
            out = self.head_rnn(x, hidden) # [batch_size, n_pairs, 2*n_pairs, 2*n_pairs]
        else:
            out = nn.functional.sigmoid(self.head_mlp(out_trunk)).squeeze().flatten() # [batch_size*n_pairs, -1]

        return out

    def get_mask(self, batch_size):
        '''
        mask of shape [batch_size, n_pairs, 2*n_pairs, 2*n_pairs]
        for the output
        '''
        mask = torch.zeros((self.n_pairs, 2*self.n_pairs, 2*self.n_pairs), dtype = torch.bool)
        for k in range(self.n_pairs):
            mask[k, :2*(k+1), :2*(k+1)] = 1
        return mask.expand((batch_size, self.n_pairs, 2*self.n_pairs, 2*self.n_pairs))
        
    def infer_orders(self, x, one_hot = True):
        '''
        x       : [batch_size, n_pairs, 2 * feat_dim = trunk_net.input_dim]
        returns : [batch_size, n_pairs, 2*n_pairs, feat_dim]
        '''
        batch_size = x.shape[0]
        feat_dim = self.trunk.input_dim // 2
        orders = torch.zeros((batch_size, self.n_pairs, 2*self.n_pairs, feat_dim))

        # Creating a copy of all the input accessed by index
        element_list = x.clone().reshape((batch_size, 2 * self.n_pairs, -1))
        
        logits = self.forward(x) # [batch_size, n_pairs, 2*n_pairs, 2*n_pairs]
        mask = self.get_mask(batch_size)

        masked_logits = logits * mask
        # for each sequence(0 ~ n_pairs) of length 2*n_pairs
        orders_with_indices = masked_logits.argmax(axis = -1) # [batch_size, n_pairs, 2*n_pairs]
        
        for b in range(batch_size):
            for k in range(self.n_pairs):
                for i in range(2*(k+1)):
                    orders[b, k, i] = element_list[b, orders_with_indices[b, k, i]]
        return orders
        

In [18]:
### Custom loss for RNN task
def loss_RNNtask(y_p, y_t):
    '''
    y_p : [batch_size, n_pairs, 2*n_pairs, 2*n_pairs]
    y_t : [batch_size, n_pairs, 2*n_pairs]
    
    ### May need to adjust the calculation since the masked out entries are still used for loss ###
    '''
    # Reshape to fit nn.CrossEntropyLoss()
    batch_size, n_pairs = y_t.shape[:2]
    criterion = nn.CrossEntropyLoss()
    y_p = y_p.view((batch_size*n_pairs, 2*n_pairs, 2*n_pairs))
    y_t = y_t.view((batch_size*n_pairs, -1))
    return criterion(y_p, y_t)

In [19]:
# Generic Train Method ### Does not handle the devices for now
def test(model, dataloader, task_type = 0):
    '''
    task_type : 0 for RNN; 1 for FF
    '''
    # Eval Mode
    model.eval()
    total_acc = 0.0
    total_loss = 0.0
    loss_func = nn.BCELoss() if task_type else loss_RNNtask
    with torch.no_grad():
        for x, y_t in dataloader:
            b_size = x.shape[0]
            # Bemindful of shape changes
            y_p = model(x, task_type = task_type)

            if 0 == task_type:
                '''
                mask : [batch_size, n_pairs, 2*n_pairs, 2*n_pairs]
                y_p  : [batch_size, n_pairs, 2*n_pairs, 2*n_pairs]
                y_t  : [batch_size, n_pairs, 2*n_pairs]
                '''
                mask = model.get_mask(b_size)
                y_p = y_p * mask 
                y_t = y_t * mask[:,:,:,0]
            # Calculate Loss
            loss = loss_func(y_p, y_t).clone().detach()
            
            if 0 != task_type:
                # Calculate Acc only for FF task
                preds = y_p > 0.5
                acc = (preds == y_t).sum().item() / len(y_t)
                total_acc += acc
            
            # Record statistics
            total_loss += loss
            
    # Back to Train Mode
    model.train()
    return total_loss / len(dataloader), total_acc / len(dataloader)


def train(model, train_dataloader, test_dataloader = None, n_epochs = 20, 
          task_type = 0, lr = 0.01):
    '''
    task_type = 0 for RNN 1 for FF
    
    '''
    loss_lst = []
    acc_lst = []
    test_acc_lst = []

    loss_func = nn.BCELoss() if task_type else loss_RNNtask
    # Using Adam by default
    optimizer = torch.optim.Adam(model.parameters(), lr = lr)
    model.train()
    for e in range(n_epochs + 1):
        for x, y_t in train_dataloader:
            optimizer.zero_grad() # Clear gradients
            b_size = x.shape[0]
            y_p = mymodel(x, task_type = task_type)
            
            if 0 == task_type:
                '''
                mask : [batch_size, n_pairs, 2*n_pairs, 2*n_pairs]
                y_p  : [batch_size, n_pairs, 2*n_pairs, 2*n_pairs]
                y_t  : [batch_size, n_pairs, 2*n_pairs]
                '''
                mask = model.get_mask(b_size)
                y_p = y_p * mask 
                y_t = y_t * mask[:,:,:,0]

            loss = loss_func(y_p, y_t)
            # Backward Pass
            #loss.backward(retain_graph=True) ### Is the same graph accessed multiple times over different backward pass?
            loss.backward()
            optimizer.step()

            # Record Statistics
            loss_lst.append(loss.clone().detach())
            acc = 0.0
            if 0 != task_type:
                # Calculate Acc only for FF task
                preds = y_p > 0.5
                acc = (preds == y_t).sum().item() / len(y_t)
                acc_lst.append(acc)

        if 0 == (e) % 10:
            t_loss = 0.0
            t_acc = 0.0
            if test_dataloader is not None:
                t_loss, t_acc = test(model, test_dataloader, task_type, loss_func)
            print(f'Epoch [{e}/{n_epochs}]:\t--LastBatchLoss:{loss:.3f}, TrainAcc:{acc:.3f}, TestLoss:{t_loss:.3f}, TestAcc:{t_acc:.3f}')

    return loss_lst, acc_lst

### Example Use (RNN Task)

In [9]:
def undo_one_hot(item_list):
  try:
    converted = []
    for sub_list in item_list:
      new_sub_list = []
      for item in sub_list:
        new_item = [key for key in items_dict.keys() if (torch.tensor(items_dict[key]) == item).all()]
        if len(new_item):
            new_sub_list.append(new_item[0])
        else:
            new_sub_list.append(0)
        '''
        for key in items_dict.keys():
          if (torch.tensor(items_dict[key]) == item).all():
              new_sub_list.append(key)
          else:
              new_sub_list.append(0)
        '''
      converted.append(new_sub_list)
    return converted
  except:
    return item_list

def to_indices(converted_in, converted_out):
    '''
    Convert the de-one-hot version of output further into sequences of input elements indices
    '''
    new_lst = []
    for trial_in, trial_out in zip(converted_in, converted_out):
        new_trial = []
        flat_trial_in = torch.tensor(trial_in).flatten().tolist()
        #print(flat_trial_in)
        for i in range(len(trial_in)):
            #print(flat_trial_in[: 2*(i+1)])
            indices_lst = {}
            for idx, n in enumerate(flat_trial_in[: 2*(i+1)]):
                if n not in indices_lst.keys():
                    indices_lst[n] = idx
            indices_lst[0] = 0
            #print(indices_lst)
            #print('='*50)
            new_trial.append([indices_lst[n] for n in trial_out[i]])
        new_lst.append(new_trial)
    return new_lst


In [10]:
'''
Prepare Data for RNN tasks
'''
params_dict = {
    'batch_size': 10, # (Not Used) Should be the same as n_pairs * n_trials
    'n_epochs': 200,
    'dataset_size': 1000, # (Not Used)
    'learning_rate': 0.0005, # (Not Used)
    'momentum': .99, # (Not Used)
    'task': 'si',
    'n_trials':100, 
    'input_length': 10,
    'n_pairs':10, 
    'hidden_dim': 6
}

# Prepare Data
stimulus_dict = make_stimulus_dict(params_dict['input_length'])

items_dict = {i : stimulus_dict[i] for i in range(len(stimulus_dict))}
#items_rank_dict = {v: k for k, v in items_dict.items()}

trials = get_RNN_batch(params_dict['n_trials'], params_dict['n_pairs'], stimulus_dict, params_dict['input_length'], pad=True)

inputs_converted = []
outputs_converted = []
# Hard coded for now 
labels = torch.tensor([0, 1, 1, 0, 1])

for i in range(len(trials)):
  inputs_converted.append(undo_one_hot(trials[i][0]))
  outputs_converted.append(undo_one_hot(trials[i][1]))

processed_tgt = to_indices(inputs_converted, outputs_converted)

for i in range(len(trials)):
  print(f'trial {i+1}:')
  print(f'inputs: {inputs_converted[i]}')
  print(f'outputs: {outputs_converted[i]}\n')
  print(f'processed targets: {processed_tgt[i]}')
  print(f'labels: {labels}')

trial 1:
inputs: [[9, 8], [0, 8], [5, 3], [6, 1], [9, 2], [1, 8], [1, 5], [3, 7], [9, 6], [4, 2]]
outputs: [[8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 8, 8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 3, 5, 8, 8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 3, 5, 6, 8, 8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 5, 6, 8, 8, 9, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 2, 3, 5, 6, 8, 8, 8, 9, 9, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 1, 2, 3, 5, 5, 6, 8, 8, 8, 9, 9, 0, 0, 0, 0, 0, 0], [0, 1, 1, 1, 2, 3, 3, 5, 5, 6, 7, 8, 8, 8, 9, 9, 0, 0, 0, 0], [0, 1, 1, 1, 2, 3, 3, 5, 5, 6, 6, 7, 8, 8, 8, 9, 9, 9, 0, 0], [0, 1, 1, 1, 2, 2, 3, 3, 4, 5, 5, 6, 6, 7, 8, 8, 8, 9, 9, 9]]

processed targets: [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 5, 4, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 7, 5, 4, 6, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0,

In [11]:
'''
# Input shape of the model must be [batch_size = num_trials, n_pairs, -1]
# Prepare the generated data for train/test 
'''
x = torch.tensor([t[0].reshape(params_dict['n_pairs'], 2*params_dict['input_length']).tolist() for t in trials])
processed_data = [[x, torch.tensor(processed_tgt)]]
# first trial in processed data
test_x = processed_data[0][0] # Same as 
test_y = processed_data[0][1]

In [32]:
'''
!!!Copied from above!!!
params_dict = {
    'batch_size': 10, # (Not Used) Should be the same as n_pairs * n_trials
    'n_epochs': 200,
    'dataset_size': 1000, # (Not Used)
    'learning_rate': 0.0005, # (Not Used)
    'momentum': .99, # (Not Used)
    'task': 'si',
    'n_trials':100, 
    'input_length': 10,
    'n_pairs':10, 
    'hidden_dim': 6
}
'''
trunk_net = MLP(params_dict['input_length'] * 2, params_dict['hidden_dim'], params_dict['hidden_dim'])

mymodel = MyModel(trunk_net, params_dict['hidden_dim'], params_dict['n_pairs'])


In [25]:
# Infering the order based on the input
mymodel.infer_orders(test_x)

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 1., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 1., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 1., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         ...,

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [1., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 

In [24]:
# Train on Task 0
loss_lst, acc_lst = train(mymodel, processed_data, test_dataloader = None, 
                          n_epochs = params_dict['n_epochs'], task_type = 0, lr = 0.1) # May need larger Learning Rate

Epoch [0/200]:	--LastBatchLoss:2.944, TrainAcc:0.000, TestLoss:0.000, TestAcc:0.000
Epoch [10/200]:	--LastBatchLoss:2.941, TrainAcc:0.000, TestLoss:0.000, TestAcc:0.000
Epoch [20/200]:	--LastBatchLoss:2.939, TrainAcc:0.000, TestLoss:0.000, TestAcc:0.000
Epoch [30/200]:	--LastBatchLoss:2.936, TrainAcc:0.000, TestLoss:0.000, TestAcc:0.000
Epoch [40/200]:	--LastBatchLoss:2.933, TrainAcc:0.000, TestLoss:0.000, TestAcc:0.000
Epoch [50/200]:	--LastBatchLoss:2.934, TrainAcc:0.000, TestLoss:0.000, TestAcc:0.000
Epoch [60/200]:	--LastBatchLoss:2.931, TrainAcc:0.000, TestLoss:0.000, TestAcc:0.000
Epoch [70/200]:	--LastBatchLoss:2.932, TrainAcc:0.000, TestLoss:0.000, TestAcc:0.000
Epoch [80/200]:	--LastBatchLoss:2.929, TrainAcc:0.000, TestLoss:0.000, TestAcc:0.000
Epoch [90/200]:	--LastBatchLoss:2.930, TrainAcc:0.000, TestLoss:0.000, TestAcc:0.000
Epoch [100/200]:	--LastBatchLoss:2.934, TrainAcc:0.000, TestLoss:0.000, TestAcc:0.000
Epoch [110/200]:	--LastBatchLoss:2.937, TrainAcc:0.000, TestLoss:

In [27]:
# Test Model
test(mymodel, processed_data, task_type = 0)

(tensor(2.9269), 0.0)

### Example Use (Transitive Inference)

In [28]:
stimulus_dict = make_stimulus_dict(params_dict['input_length'])
# Create 3 batches
TI_data = [get_batch(params_dict['n_pairs'] * params_dict['n_trials'], params_dict['task'], stimulus_dict, params_dict['input_length']) for _ in range(3)]
# Reshaping the input ** must be of shape [n_trials, n_pairs, 2 * feat_dim = trunk_net.input_dim]
TI_data = [(sti.reshape(params_dict['n_trials'], params_dict['n_pairs'], -1), tgt) for sti, tgt in TI_data]

In [33]:
loss_lst, acc_lst = train(mymodel, TI_data, n_epochs = params_dict['n_epochs'], task_type = 1)

Epoch [0/200]:	--LastBatchLoss:0.696, TrainAcc:0.496, TestLoss:0.000, TestAcc:0.000
Epoch [10/200]:	--LastBatchLoss:0.319, TrainAcc:0.962, TestLoss:0.000, TestAcc:0.000
Epoch [20/200]:	--LastBatchLoss:0.012, TrainAcc:1.000, TestLoss:0.000, TestAcc:0.000
Epoch [30/200]:	--LastBatchLoss:0.010, TrainAcc:0.997, TestLoss:0.000, TestAcc:0.000
Epoch [40/200]:	--LastBatchLoss:0.005, TrainAcc:1.000, TestLoss:0.000, TestAcc:0.000
Epoch [50/200]:	--LastBatchLoss:0.007, TrainAcc:1.000, TestLoss:0.000, TestAcc:0.000
Epoch [60/200]:	--LastBatchLoss:0.006, TrainAcc:1.000, TestLoss:0.000, TestAcc:0.000
Epoch [70/200]:	--LastBatchLoss:0.007, TrainAcc:0.999, TestLoss:0.000, TestAcc:0.000
Epoch [80/200]:	--LastBatchLoss:0.009, TrainAcc:0.998, TestLoss:0.000, TestAcc:0.000
Epoch [90/200]:	--LastBatchLoss:0.001, TrainAcc:1.000, TestLoss:0.000, TestAcc:0.000
Epoch [100/200]:	--LastBatchLoss:0.004, TrainAcc:1.000, TestLoss:0.000, TestAcc:0.000
Epoch [110/200]:	--LastBatchLoss:0.003, TrainAcc:1.000, TestLoss:

### Example Use (Subset Inclusion)

In [34]:
stimulus_dict = make_stimulus_dict(params_dict['input_length'])
# Create 3 batches
SI_data = [get_batch(params_dict['n_pairs'] * params_dict['n_trials'], 'si', stimulus_dict, params_dict['input_length']) 
           for _ in range(3)]
# Reshaping the input ** must be of shape [n_trials, n_pairs, 2 * feat_dim = trunk_net.input_dim]
SI_data = [(sti.reshape(params_dict['n_trials'], params_dict['n_pairs'], -1), tgt) for sti, tgt in SI_data]

In [35]:
loss_lst, acc_lst = train(mymodel, SI_data, n_epochs = params_dict['n_epochs'], task_type = 1)

Epoch [0/200]:	--LastBatchLoss:0.003, TrainAcc:0.999, TestLoss:0.000, TestAcc:0.000
Epoch [10/200]:	--LastBatchLoss:0.001, TrainAcc:1.000, TestLoss:0.000, TestAcc:0.000
Epoch [20/200]:	--LastBatchLoss:0.002, TrainAcc:1.000, TestLoss:0.000, TestAcc:0.000
Epoch [30/200]:	--LastBatchLoss:0.003, TrainAcc:0.999, TestLoss:0.000, TestAcc:0.000
Epoch [40/200]:	--LastBatchLoss:0.005, TrainAcc:0.998, TestLoss:0.000, TestAcc:0.000
Epoch [50/200]:	--LastBatchLoss:0.001, TrainAcc:1.000, TestLoss:0.000, TestAcc:0.000
Epoch [60/200]:	--LastBatchLoss:0.000, TrainAcc:1.000, TestLoss:0.000, TestAcc:0.000
Epoch [70/200]:	--LastBatchLoss:0.000, TrainAcc:1.000, TestLoss:0.000, TestAcc:0.000
Epoch [80/200]:	--LastBatchLoss:0.001, TrainAcc:1.000, TestLoss:0.000, TestAcc:0.000
Epoch [90/200]:	--LastBatchLoss:0.001, TrainAcc:1.000, TestLoss:0.000, TestAcc:0.000
Epoch [100/200]:	--LastBatchLoss:0.002, TrainAcc:0.999, TestLoss:0.000, TestAcc:0.000
Epoch [110/200]:	--LastBatchLoss:0.000, TrainAcc:1.000, TestLoss:

### Saving and Loading Model

In [36]:
# Save 
w_path = './weights/hybrid_test'
torch.save(mymodel.state_dict(), w_path)

In [37]:
# Load
params_dict = {
    'batch_size': 10, # (Not Used) Should be the same as n_pairs * n_trials
    'n_epochs': 200,
    'dataset_size': 1000, # (Not Used)
    'learning_rate': 0.0005, # (Not Used)
    'momentum': .99, # (Not Used)
    'task': 'si',
    'n_trials':100, 
    'input_length': 10,
    'n_pairs':10, 
    'hidden_dim': 6
}
pre_model = MyModel(trunk_net, params_dict['hidden_dim'], params_dict['n_pairs'])
pre_model.load_state_dict(torch.load('./weights/hybrid_test'))

<All keys matched successfully>

----

In [253]:
out = mymodel(test_x)
out.shape # [batch_size = n_trials, n_pairs, 2*n_pairs, 2*n_pairs] 

torch.Size([1, 5, 10, 10])

In [252]:
test_y.shape

torch.Size([1, 5, 10])

In [64]:
'''
k-th pair of input, matrix 2n_pairs x 2n_pairs (4 , 4)
2 = n_pairs
[a,b] -> [a, b, -, -] 
[c,b] -> [a, b, b, c]

index 
0 a
1 b
2 c
3 b

matrix 0 for input pair 0
[0.6, 0.4, 0, 0]    [0] argmax a
[0.2, 0.8, 0, 0]    [1]        b
[0, 0, 0, 0]        []
[0, 0, 0, 0]        []
sequence of order is of len 4
'''


torch.Size([1, 5, 10, 10])

In [248]:
trunk_net(test_x).shape

torch.Size([1, 5, 4])

----
### Scratch

In [None]:
# @title Task design
def int2bits(n, width):
    '''
    Convert an Integer to Its Binary Representation as a List/Array/Tensor
    e.g. n = 3, width = 4 -> [0, 0, 1, 1]
    '''
    return [int(b) for b in bin(n)[2:].zfill(width)] # bin() returns a str with prefix '0b'
    


def gen_batch(batch_size, n_elements = 3, 
              task = 'ti'):
    '''
    Generate a Batch of Stimuli (list) for the Desired Task
    
    Inputs:
    - batch_size   : (int) num of stimulus
    - n_elements   : (int) Total number of elements on the base set
    - task         : (str)
                     'ti' - Transitive Inference; without Reflexitivity (relation with itself)
                     'si' - Subset Inclusion; 
                     'div'- Divisibility; without 0

    Returns: [[A, B], label] x batch_size       ## of shape [batch_size, 2, n_elements]
    '''
    assert task in ['ti',
                    'si',
                    'div'], f'Requested Task [{task}] Not Supported!'
    stimuli = []
    if 'ti' == task.lower() or 'div' == task.lower():
        max_b_size = int(n_elements)
        # Transitive Inference
        if 'ti' == task.lower():
            # All possible instances
            stimulus_dict = np.eye(n_elements, dtype = int)
            np.random.shuffle(stimulus_dict)
            # Randomly sample two instances in the dict as a stimulus
            for pair_id in range(batch_size):
                dict_indices = random.sample(range(max_b_size), k = 2) # without replacement
                # target = which instance is greater
                target = 0 if dict_indices[0] > dict_indices[1] else 1
                stimuli.append([stimulus_dict[dict_indices].tolist(), target])
        else: # Divisibility
            for pair_id in range(batch_size):
                stimulus_dict = random.sample(range(1, max_b_size + 1), k = 2) # excluding 0
                # target = if the previous is divisible by the latter
                target = int(0 == stimulus_dict[0] % stimulus_dict[1])
                stimuli.append([stimulus_dict, target])
    # Subset Inclusion
    elif 'si' == task.lower():
        max_b_size = 2 ** int(n_elements)
        # randomly sample two instances
        for pair_id in range(batch_size):
            idx_A, idx_B = random.sample(range(max_b_size), k = 2)
            # target if the previous is a superset of the latter
            A, B = np.array(int2bits(idx_A, width = n_elements)), np.array(int2bits(idx_B, n_elements))
            element_indices_B = np.arange(n_elements)[1 == B]
            target = int(A[element_indices_B].all()) # if non-zero entries in B are also in A
            stimuli.append([[A.tolist(), B.tolist()], target])
    return stimuli

In [6]:
import itertools
def get_test_data(n):
    """Generate test data for relational tasks.
    """
    test_x = np.zeros((n, n, 2*n))
    for i, j in itertools.product(range(n), range(n)):
        test_x[i,j,i] = 1 
        test_x[i,j,n+j] = 1 
    test_x = torch.from_numpy(test_x).float()
    test_x = test_x/torch.sqrt(torch.tensor(2))
    return test_x

def get_transitive_data(n):
    """Generate training data for TI task.
    """
    test_x = get_test_data(n)
    x = test_x[tuple(zip(*([(i, i+1) for i in range(n-1)] + [(i+1, i) for i in range(n-1)])))]
    y = torch.tensor([1.]*(n-1)+[-1.]*(n-1))
    return x, y

In [12]:
x, y = get_transitive_data(5)

In [13]:
x

tensor([[0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.7071],
        [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071,
         0.0000]])

In [10]:
y.shape

torch.Size([8])