Prior work has observed that tasks sharing some underlying structure exhibit overlapping neural activity. For instance, Yang et al. (2019) trained RNNs on several cognitive tasks and observed clustering in neural activity space: some clusters specialized to particular tasks, while others were shared between tasks. In principle, it is also possible to observe a completely distributed representation (i.e. no modular clusters). While Yang et al. focused on sensory tasks, we aim to study tasks involving abstract relations: transitive inference and divisibility. These tasks are likely to have some common underlying structure, as both are defined by transitive relations. We will compare the neural geometry of an RNN trained on one of these tasks at a time with that of the same RNN architecture trained on both tasks (using interleaved training).

Questions we hope to answer: How does the neural representation of a given task change when more than one task is learned simultaneously? When both tasks are learned together, are the activations shared between the two tasks also present in some form in networks that learned only one task at a time? That is, does a neural network organize its activity differently when related tasks must be learned together? We will use RDM analysis and dimensionality reduction techniques to look for specialized clusters in neural activity space. We will then compare our networks using RSA/RDA, as well as dynamics-based methods such as DSA and fixed/slow point analysis.


### Imports

In [1]:
# General
import numpy as np
#import pandas as pd
#from scipy.stats import zscore
import random
#from statistics import mean

# Deep learning
import torch
from torch import nn, optim
from torch.nn import GRUCell, RNNCell
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# # Response visualizations
# !pip install umap-learn
# import umap
#import matplotlib as mpl
#from matplotlib import pyplot as plt
#import seaborn as sns

# Set random seeds for reproducibility
# NOTE: gen_batch draws its samples with the stdlib `random` module, so that
# generator has to be seeded as well as NumPy's and PyTorch's.
random.seed(12)
np.random.seed(12)
torch.manual_seed(12)

<torch._C.Generator at 0x111fd7b30>

### Tasks Example

In [None]:
# Transitive Inference: (i.e. index 0 is greater than index 1 = "A > B")
# Example stimulus: [[1, 0, 0],[0, 1, 0]]
# Example output: 0

# Subset Inclusion: (A and B are membership vectors over the base set)
# Example stimulus: [[1, 1, 1],[0, 1, 0]]
# Example output: 1 = "A \contains B" (0 otherwise)

# Divisibility:
# Example stimulus: [6,3]
# Example output: 1 = "A is divisible by B" (0 otherwise)

In [None]:
[1,0,0] < [1,0,0]

False

----
### Task Design and Data Generation

In [2]:
# @title Task design
def int2bits(n, width):
    '''
    Return the Binary Digits of an Integer as a list of 0/1 ints (most significant bit first)

    The digits are left-padded with zeros up to `width`; a value that needs
    more than `width` digits is returned in full (it is not truncated).
    e.g. n = 3, width = 4 -> [0, 0, 1, 1]
    '''
    digits = bin(n)[2:]            # strip the '0b' prefix that bin() prepends
    padded = digits.zfill(width)
    return list(map(int, padded))
    


def gen_batch(batch_size, n_elements = 3, 
              task = 'ti'):
    '''
    Generate a Batch of Stimuli (list) for the Desired Task
    
    Inputs:
    - batch_size   : (int) number of stimuli to generate
    - n_elements   : (int) total number of elements in the base set
    - task         : (str, case-insensitive)
                     'ti' - Transitive Inference; without reflexivity (an item is never paired with itself)
                     'si' - Subset Inclusion;
                     'div'- Divisibility; without 0

    Returns: list of [[A, B], label] with batch_size entries
    - 'ti'  : A, B are two different one-hot lists of length n_elements;
              label is 0 if A sits at the larger row index of the shuffled item dictionary, else 1
    - 'div' : A, B are two different ints in [1, n_elements];
              label is 1 if A is divisible by B, else 0
    - 'si'  : A, B are 0/1 membership lists of length n_elements (two different subsets);
              label is 1 if every element of B is also in A, else 0

    Raises:
    - AssertionError if `task` is not one of the supported tasks
    - ValueError (from random.sample) if fewer than two distinct items are available
    '''
    supported_tasks = ('ti', 'si', 'div')
    task_key = str(task).lower()
    # Kept as an AssertionError for backward compatibility, but raised explicitly
    # so that the check is not stripped when Python runs with -O.
    if task_key not in supported_tasks:
        raise AssertionError(f'Requested Task [{task}] Not Supported!')

    n_elements = int(n_elements)
    stimuli = []
    # Transitive Inference
    if 'ti' == task_key:
        # Row i of the shuffled identity matrix is the one-hot code of item i.
        # NOTE: the shuffle is redone on every call, so the item -> index
        # mapping is not shared between separate calls.
        stimulus_dict = np.eye(n_elements, dtype = int)
        np.random.shuffle(stimulus_dict)
        for _ in range(batch_size):
            idx_A, idx_B = random.sample(range(n_elements), k = 2) # without replacement
            # target = position (0 or 1) of the item with the larger index
            target = 0 if idx_A > idx_B else 1
            stimuli.append([stimulus_dict[[idx_A, idx_B]].tolist(), target])
    # Divisibility
    elif 'div' == task_key:
        for _ in range(batch_size):
            pair = random.sample(range(1, n_elements + 1), k = 2) # excluding 0
            # target = 1 if the first number is divisible by the second
            target = int(0 == pair[0] % pair[1])
            stimuli.append([pair, target])
    # Subset Inclusion
    else:
        n_subsets = 2 ** n_elements
        for _ in range(batch_size):
            # every subset is identified by an integer whose bits are its membership vector
            idx_A, idx_B = random.sample(range(n_subsets), k = 2)
            A = np.array(int2bits(idx_A, width = n_elements))
            B = np.array(int2bits(idx_B, width = n_elements))
            # target = 1 if all elements present in B are present in A
            # (an empty B is contained in every A)
            target = int(A[1 == B].all())
            stimuli.append([[A.tolist(), B.tolist()], target])
    return stimuli


In [10]:
vdata = gen_batch(10, 4, task = 'si')
vdata

[[[[0, 1, 0, 0], [0, 0, 0, 1]], 0],
 [[[0, 1, 1, 0], [1, 0, 1, 0]], 0],
 [[[0, 1, 0, 1], [1, 0, 0, 0]], 0],
 [[[0, 1, 0, 1], [1, 0, 1, 0]], 0],
 [[[0, 0, 1, 0], [1, 1, 0, 1]], 0],
 [[[1, 0, 1, 1], [1, 0, 0, 1]], 1],
 [[[0, 0, 1, 0], [1, 1, 1, 1]], 0],
 [[[1, 1, 0, 1], [0, 1, 1, 1]], 0],
 [[[0, 1, 0, 0], [1, 0, 1, 0]], 0],
 [[[0, 0, 0, 0], [1, 1, 0, 0]], 0]]

torch.Size([10, 2, 4])

### Data Preparation (PyTorch)

In [28]:
def prepare_data(data, 
                 dtype = torch.float, device = 'cpu'):
    '''
    Split a batch of samples into three tensors (first items, second items, labels)

    data : iterable of [[A, B], label] samples

    returns (x1, x2, y_t), where x1 stacks every A, x2 stacks every B and y_t
    holds the labels; all three are created with the given `dtype` on `device`
    '''
    def _to_tensor(values):
        return torch.tensor(values, dtype = dtype, device = device)

    firsts, seconds, labels = [], [], []
    for (first, second), label in data:
        firsts.append(first)
        seconds.append(second)
        labels.append(label)
    return _to_tensor(firsts), _to_tensor(seconds), _to_tensor(labels)

class AbstractTaskDataset(Dataset):
    '''
    Map-style Dataset over a list of [stimulus, label] samples

    - raw_data : the samples exactly as passed in
    - data     : tensor stacking every stimulus
    - labels   : tensor holding every label
    '''
    def __init__(self, data):
        self.raw_data = data
        stimuli, targets = [], []
        for stimulus, target in data:
            stimuli.append(stimulus)
            targets.append(target)
        self.data = torch.tensor(stimuli)
        self.labels = torch.tensor(targets)

    def __len__(self):
        # number of samples = size of the leading dimension of the stimulus tensor
        return self.data.size(0)

    def __getitem__(self, idx):
        stimulus = self.data[idx]
        target = self.labels[idx]
        return stimulus, target
    
### Data Loader, Train_Test split and Shuffle Helpers

In [29]:
si_datast = AbstractTaskDataset(vdata)

In [34]:
t_loader = DataLoader(si_datast, batch_size = 2, shuffle = False)

In [39]:
d, l = next(iter(t_loader))
d.reshape(2, -1)

tensor([[0, 1, 0, 0, 0, 0, 0, 1],
        [0, 1, 1, 0, 1, 0, 1, 0]])

In [30]:
x, y = si_datast[0]

In [33]:
x

tensor([[0, 1, 0, 0],
        [0, 0, 0, 1]])

In [43]:
# example
x1, x2, y_t = prepare_data(gen_batch(batch_size = 10, n_elements = 5, task = 'si'))

----
## Model Architecture

### Base Backbones

In [None]:
# 1. MLP
class MLP(nn.Module):
    '''
    Multi-Layer Perceptron
      (num_layers - 1) x [Linear -> LeakyReLU -> Dropout], followed by a final Linear

    Inputs:
    - input_dim  : (int) size of the input features
    - hidden_dim : (int) width of every hidden layer
    - out_dim    : (int) size of the output of the final Linear layer
    - num_layers : (int) total number of Linear layers (hidden ones + the final one)
    - dp_rate    : (float) dropout probability applied after every hidden activation
    '''
    def __init__(self, input_dim, hidden_dim, out_dim, 
                 num_layers = 2, dp_rate = 0.1):
        super(MLP, self).__init__()
        # Hyperparameters
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.output_dim = out_dim  # alias: same attribute name the RNN backbones expose
        self.num_layers = num_layers
        
        layers = []
        c_in = input_dim
        for _ in range(num_layers - 1):
            layers += [nn.Linear(c_in, hidden_dim),
                       nn.LeakyReLU(inplace = True),
                       nn.Dropout(dp_rate)]
            c_in = hidden_dim
        # The final layer must project to out_dim (it previously reused hidden_dim)
        layers.append(nn.Linear(c_in, out_dim))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        '''Apply the stacked layers to x and return the result'''
        return self.layers(x)

# 2. RNN
class GRU_RNN(nn.Module):
    '''
    Single-step GRU cell with a Linear readout and a learnable initial hidden state

    Inputs:
    - input_dim     : (int) size of the input at each step
    - hidden_dim    : (int) size of the hidden state
    - output_dim    : (int or None) size of the readout; None falls back to hidden_dim
    - latent_ic_var : (float) scale multiplying the standard-normal noise that is
                      added to the initial hidden state
    '''
    def __init__(
        self, input_dim, hidden_dim, output_dim=None, latent_ic_var=0.05
    ):
        super(GRU_RNN, self).__init__()
        if output_dim is None:
            # nn.Linear cannot be built with a None width
            output_dim = hidden_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.cell = GRUCell(input_dim, self.hidden_dim)
        self.readout = nn.Linear(self.hidden_dim, output_dim, bias=True)
        # Learnable initial condition, one value per hidden unit
        # (the undefined name `latent_size` was used here before)
        self.latent_ics = torch.nn.Parameter(
            torch.zeros(hidden_dim), requires_grad=True
        )
        self.latent_ic_var = latent_ic_var

    def init_hidden(self, batch_size):
        '''Return a [batch_size, hidden_dim] initial state: learned ICs plus scaled Gaussian noise'''
        init_h = self.latent_ics.unsqueeze(0).expand(batch_size, -1)
        ic_noise = torch.randn_like(init_h) * self.latent_ic_var
        return init_h + ic_noise

    def forward(self, inputs, hidden):
        '''Advance the cell by one step; returns (readout of the new hidden state, new hidden state)'''
        hidden = self.cell(inputs, hidden)
        output = self.readout(hidden)
        return output, hidden

# 3. Simple RNN
class SimpleRNN(nn.Module):
    '''
    RNN in its Simplest Form using nn.RNNCell, followed by a Linear readout
    Activation Choices: 'relu' or 'tanh'

    Inputs:
    - input_dim  : (int) size of the input at each step
    - hidden_dim : (int) size of the hidden state
    - output_dim : (int or None) size of the readout; None falls back to hidden_dim
    - activation : (str) nonlinearity handed to nn.RNNCell
    '''
    def __init__(self, input_dim, hidden_dim, output_dim = None, activation = 'relu'):
        super(SimpleRNN, self).__init__()
        if output_dim is None:
            # nn.Linear cannot be built with a None width
            output_dim = hidden_dim
        # Hyperparams
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        
        self.cell = RNNCell(input_dim, hidden_dim, nonlinearity = activation)
        self.readout = nn.Linear(hidden_dim, output_dim)

    def forward(self, inputs, hidden = None):
        '''
        Advance the cell by one step; returns (readout of the new hidden state, new hidden state)
        hidden = None is passed straight to nn.RNNCell
        '''
        hidden = self.cell(inputs, hidden)
        output = self.readout(hidden)
        return output, hidden

----
### Hybrid Model

In [None]:
class MyModel(nn.Module):
    '''
    A Higher Level Model for Three Tasks that Administrates 
      - Trunk Part to Generate Common Hidden Representations
      - Multiple Downstream Readout Heads (one MLP head, one SimpleRNN head)

      ##### The trunk is called as trunk(x), so a recurrent trunk (which takes and
      ##### returns a hidden state) is not handled yet

    Inputs:
    - trunk_net  : (nn.Module) shared trunk; must expose its output width as
                   `output_dim` (or `out_dim`)
    - hidden_dim : (int) hidden width of both heads
    - out_dim    : (int) output width of both heads
    - num_layers : (int) number of Linear layers of the MLP head
    '''
    def __init__(self, trunk_net, hidden_dim, out_dim = 1,
                num_layers = 2):
        super(MyModel, self).__init__()
        self.trunk = trunk_net
        # The MLP backbone names its width `out_dim`, the RNN backbones `output_dim`;
        # accept either so that any of them can serve as the trunk.
        trunk_dim = getattr(trunk_net, 'output_dim', None)
        if trunk_dim is None:
            trunk_dim = trunk_net.out_dim
        self.head_mlp = MLP(trunk_dim, hidden_dim, out_dim, num_layers)
        self.head_rnn = SimpleRNN(trunk_dim, hidden_dim, out_dim) ### TBD

    def forward(self, x, hidden = None, task_type = 0):
        '''
        ### Task Type definition:
            - 0 : use the recurrent head (consumes and returns `hidden`)
            - 1 : other types; use the MLP head (returned hidden is None)

        Returns: (sigmoid of the selected head's output, hidden)
        '''
        x = self.trunk(x)
        if 0 != task_type:
            x = self.head_mlp(x)
            hidden = None
        else:
            x, hidden = self.head_rnn(x, hidden)
        # torch.sigmoid is the non-deprecated spelling of nn.functional.sigmoid
        return torch.sigmoid(x), hidden 




In [None]:
### Example use 
#trunk_net = MLP(input_dim, hidden_dim, hidden_dim)
#mymodel = MyModel(trunk_net, hidden_dim, out_dim = 1)


# Example Training (could be wrapped into a helper function)
'''
# Generic Train Method
def train(model, dataloader, tasktype = 0, 
          loss_func = nn.BCELoss(), optimizer = torch.optim.Adam(lr = 0.01)):
    hidden = None
    for e in range(n_epochs):
        for x, y_t in dataloader:
            # Forward Pass
            y_pred = mymodel(data, hidden)
            loss = loss_func(y_pred, y_t)
            # Backward Pass
            optimizer.zero_grad() # Clear gradients
            loss.backward()
            optimizer.step()
    #return loss or acc lists
        

# Train with Task Requiring RNN
history = train(mymodel, train_data, tasktype = 0)

# Train with Task with Linear Readout
history = train(mymodel, train_data, tasktype = 1)
'''

----
### Scratch

In [6]:
import itertools
def get_test_data(n):
    """Generate test data for relational tasks.

    Builds an (n, n, 2*n) float tensor whose entry [i, j] is the concatenation
    of the one-hot code of item i and the one-hot code of item j, divided by
    sqrt(2).
    """
    item_ids = np.arange(n)
    grid = np.zeros((n, n, 2 * n))
    grid[item_ids, :, item_ids] = 1      # first half: one-hot of the row item i
    grid[:, item_ids, n + item_ids] = 1  # second half: one-hot of the column item j
    pairs = torch.from_numpy(grid).float()
    return pairs / torch.sqrt(torch.tensor(2))

def get_transitive_data(n):
    """Generate training data for TI task.

    Only adjacent item pairs are used: (i, i+1) pairs are labelled +1 and the
    reversed (i+1, i) pairs are labelled -1.

    Returns:
        x: the 2*(n-1) rows of get_test_data(n) selected by those pairs, all
           (i, i+1) pairs first, then all (i+1, i) pairs.
        y: float tensor with the 2*(n-1) matching labels.
        For n <= 1 there are no adjacent pairs and both x and y are empty.
    """
    test_x = get_test_data(n)
    lower = list(range(n - 1))
    upper = list(range(1, n))
    # Explicit long index tensors keep the selection empty when there are no
    # pairs (an empty index tuple would return the whole tensor instead).
    rows = torch.tensor(lower + upper, dtype=torch.long)
    cols = torch.tensor(upper + lower, dtype=torch.long)
    x = test_x[rows, cols]
    y = torch.tensor([1.] * len(lower) + [-1.] * len(upper))
    return x, y

In [12]:
x, y = get_transitive_data(5)

In [13]:
x

tensor([[0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.7071],
        [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071,
         0.0000]])

In [10]:
y.shape

torch.Size([8])