It has been observed in prior work that tasks sharing some underlying structure will exhibit overlapping neural activity. For instance, Yang et. al. (2019) trained RNNs on several cognitive tasks and observed clustering in the neural activity space: some clusters specialized to particular tasks, while others were shared between tasks. In theory, it is also possible to observe a completely distributed representation (i.e. no modular clusters). While Yang focused on sensory tasks, we aim to study tasks involving abstract relations: transitive inference and divisibility. These tasks are likely to have some common underlying structure, as both represent transitive relations. We will compare the neural geometry of the same RNN trained on one of these tasks at a time to that trained on both (using interleaving).

Questions we hope to answer: How will the neural representation of a given task change when more than one task is learned simultaneously? In the latter case, will we find that the activations shared between the two tasks are also present in some form when only one task is learned at a time? That is—does a neural network organize its activity differently when related tasks must be learned together? We will use RDM analysis and dimensionality reduction techniques to look for specialized clusters in neural activity space. We will then compare our networks using RSA/RDA, as well as dynamics-based methods such as DSA and fixed/slow point analysis.


### Imports

In [1]:
# General
import numpy as np
#import pandas as pd
#from scipy.stats import zscore
import random
#from statistics import mean

# Deep learning
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# # Response visualizations
# !pip install umap-learn
# import umap
#import matplotlib as mpl
#from matplotlib import pyplot as plt
#import seaborn as sns

# Set random seeds for reproducibility
np.random.seed(12)
torch.manual_seed(12)

<torch._C.Generator at 0x10abbbad0>

### Tasks Example

In [None]:
# Transitive Inference: (i.e. index 0 is greater than index 1 = "A > B")
# Example stimulus: [[1, 0, 0],[0, 1, 0]]
# Example output: 0

# Subset Inclusion: (i.e. index 0 is greater than index 1 = "A \contains B")
# Example stimulus: [[1, 1, 1],[0, 1, 0]]
# Example output: 0 = "A \contains B"

# Divisibility:
# Example stimulus: [6,3]
# Example output: 0 = "A is divisble by B"

In [None]:
[1,0,0] < [1,0,0]

False

----
### Task Design and Data Generation

In [2]:
# @title Task design
def int2bits(n, width):
    '''
    Convert an Integer to Its Binary Representation as a List/Array/Tensor
    e.g. n = 3, width = 4 -> [0, 0, 1, 1]
    '''
    return [int(b) for b in bin(n)[2:].zfill(width)] # bin() returns a str with prefix '0b'
    


def gen_batch(batch_size, n_elements = 3, 
              task = 'ti'):
    '''
    Generate a Batch of Stimuli (list) for the Desired Task
    
    Inputs:
    - batch_size   : (int) num of stimulus
    - n_elements   : (int) Total number of elements on the base set
    - task         : (str)
                     'ti' - Transitive Inference; without Reflexitivity (relation with itself)
                     'si' - Subset Inclusion; 
                     'div'- Divisibility; without 0

    Returns: [[A, B], label] x batch_size       ## of shape [batch_size, 2, n_elements]
    '''
    assert task in ['ti',
                    'si',
                    'div'], f'Requested Task [{task}] Not Supported!'
    stimuli = []
    if 'ti' == task.lower() or 'div' == task.lower():
        max_b_size = int(n_elements)
        # Transitive Inference
        if 'ti' == task.lower():
            # All possible instances
            stimulus_dict = np.eye(n_elements, dtype = int)
            np.random.shuffle(stimulus_dict)
            # Randomly sample two instances in the dict as a stimulus
            for pair_id in range(batch_size):
                dict_indices = random.sample(range(max_b_size), k = 2) # without replacement
                # target = which instance is greater
                target = 0 if dict_indices[0] > dict_indices[1] else 1
                stimuli.append([stimulus_dict[dict_indices].tolist(), target])
        else: # Divisibility
            for pair_id in range(batch_size):
                stimulus_dict = random.sample(range(1, max_b_size + 1), k = 2) # excluding 0
                # target = if the previous is divisible by the latter
                target = int(0 == stimulus_dict[0] % stimulus_dict[1])
                stimuli.append([stimulus_dict, target])
    # Subset Inclusion
    elif 'si' == task.lower():
        max_b_size = 2 ** int(n_elements)
        # randomly sample two instances
        for pair_id in range(batch_size):
            idx_A, idx_B = random.sample(range(max_b_size), k = 2)
            # target if the previous is a superset of the latter
            A, B = np.array(int2bits(idx_A, width = n_elements)), np.array(int2bits(idx_B, n_elements))
            element_indices_B = np.arange(n_elements)[1 == B]
            target = int(A[element_indices_B].all()) # if non-zero entries in B are also in A
            stimuli.append([[A.tolist(), B.tolist()], target])
    return stimuli


In [3]:
vdata = gen_batch(12, 4, task = 'si')
vdata

[[[[1, 0, 1, 0], [0, 0, 1, 1]], 0],
 [[[0, 0, 1, 0], [1, 1, 1, 1]], 0],
 [[[0, 1, 1, 0], [0, 1, 0, 0]], 1],
 [[[1, 1, 0, 1], [0, 0, 1, 1]], 0],
 [[[1, 1, 0, 1], [0, 0, 0, 1]], 1],
 [[[1, 0, 0, 1], [0, 0, 0, 0]], 1],
 [[[1, 1, 0, 1], [1, 1, 1, 0]], 0],
 [[[0, 1, 1, 0], [1, 1, 1, 0]], 0],
 [[[0, 0, 0, 1], [0, 0, 1, 0]], 0],
 [[[1, 0, 0, 0], [0, 1, 1, 1]], 0],
 [[[0, 0, 1, 0], [1, 1, 1, 0]], 0],
 [[[1, 0, 1, 1], [1, 0, 1, 0]], 1]]

torch.Size([10, 2, 4])

### Data Preparation (PyTorch)

In [4]:
class AbstractTaskDataset(Dataset):
    def __init__(self, data):
        self.raw_data = data
        #self.data, self.labels = torch.tensor([d for d,_ in data]), torch.tensor([l for _,l in data])
        self.data, self.labels = [d for d,_ in data],[l for _,l in data]
    def __getitem__(self, idx):
        #return torch.tensor(self.raw_data[idx][0]), torch.tensor(self.raw_data[idx][1])
        return torch.tensor(self.data[idx], dtype = torch.float), torch.tensor(self.labels[idx], dtype = torch.float)
    def __len__(self):
        return len(self.raw_data)
    
### Data Loader, Train_Test split and Shuffle Helpers

In [5]:
si_datast = AbstractTaskDataset(vdata)

In [6]:
train_loader = DataLoader(si_datast, batch_size = 4, shuffle = False)

In [7]:
d, l = next(iter(train_loader))
d

tensor([[[1., 0., 1., 0.],
         [0., 0., 1., 1.]],

        [[0., 0., 1., 0.],
         [1., 1., 1., 1.]],

        [[0., 1., 1., 0.],
         [0., 1., 0., 0.]],

        [[1., 1., 0., 1.],
         [0., 0., 1., 1.]]])

In [30]:
x, y = si_datast[0]

In [53]:
len(train_loader)

3

----
## Model Architecture

### Base Backbones

In [8]:
# 1. MLP
class MLP(nn.Module):
    '''
    MLP with Leaky ReLU as Intermediate Activations
    input_dim  :
    hidden_dim :
    output_dim :
    num_layers :
    dp_rate    : dropout rate
    '''
    def __init__(self, input_dim, hidden_dim, output_dim, 
                 num_layers = 2, dp_rate = 0.1):
        super(MLP, self).__init__()
        # Hyperparameters
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        
        c_in, c_out = input_dim, hidden_dim
        layers = []
        for lid in range(num_layers - 1):
            layers += [nn.Linear(c_in, c_out),
                       nn.LeakyReLU(inplace = True),
                       nn.Dropout(dp_rate)]
            c_in = hidden_dim
        layers += [nn.Linear(c_in, output_dim)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


# 2. Simple RNN
class SimpleRNN(nn.Module):
    '''
    RNN in its Simplest Form using nn.RNN
    Activation Choices: 'relu' or 'tanh'

        - input_dim  : input feat dim
        - hidden_dim : 
        - output_dim : 
        - num_layers : number of RNN units
    '''
    def __init__(self, input_dim, hidden_dim, output_dim = None, 
                 num_layers = 1, activation = 'relu'):
        super(SimpleRNN, self).__init__()
        # Hyperparams
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        
        self.rnn = nn.RNN(input_dim, hidden_dim, num_layers,
                          batch_first = True, nonlinearity = activation)
        self.readout = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, h = None):
        '''
        Shape
        Inputs:
            - x      : [batch_size, seq_len, input_dim]
            - h      : [num_layers, batch_size, hidden_dim]
        ----
        Intermediate:
            - out    : [batch_size, seq_len, hidden_dim]
            - hidden : [num_layers, batch_size, hidden_dim]
        ----
        Outputs:
            - output : [batch_size, output_dim]
            - hidden : [num_layers, batch_size, hidden_dim]
        '''
        out, hidden = self.rnn(x, h)
        # Return pred of each seq in the batch
        output = self.readout(out[:, -1,:]) 
        return output, hidden

----
### Hybrid Model

In [12]:
class MyModel(nn.Module):
    '''
    A Higher Level Model for Three Tasks that Administrates 
      - Trunnk Part to Generate Common Hidden Representations
      - Multiple Downstreaming Readout Heads

      ##### Trunk Net's output dimension needs to be an even number for RNN to work
      ##### Requires Handling of Hidden States for RNN 
      #####
    '''
    def __init__(self, trunk_net, hidden_dim, output_dim = 1,
                num_layers = 2, seq_len = 2):
        '''
        seq_len : the number of items in the flattened input seq
        '''
        super(MyModel, self).__init__()
        # Parameters
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.seq_len = seq_len
        
        self.trunk = trunk_net
        if trunk_net.output_dim % seq_len:
            print(f"The output dim of 'trunk_net' is better to be divisible by the length of an input sequence!!\nCurrent trunk_net.output_dim = {self.trunk_net.output_dim} and seq_len = {self.seq_len}")
            
        self.head_mlp = MLP(trunk_net.output_dim, hidden_dim, output_dim = output_dim, num_layers = num_layers)
        self.head_rnn = SimpleRNN(trunk_net.output_dim // self.seq_len, hidden_dim, output_dim, num_layers = num_layers) ###

    def forward(self, x, hidden, task_type = 0):
        '''
        ### Task Type definition:
            - 0 : 
            - 1 :  other types

            - x : [batch_size, sequence_len * feat_dim = 2 * ]
            - hidden : [num_layers, batch_size, hidden_dim]
        '''
        out_trunk = self.trunk(x) # [batch_size, trunk_net.output_dim]
        batch_size = x.shape[0]
        if 0 == task_type:
            # Reshape the RNN's input into [batch_size, seq_len, trunk_net.output_dim / seq_len = -1]
            out, hidden = self.head_rnn(out_trunk.reshape((batch_size, self.seq_len, -1)), hidden)
        else:
            out = self.head_mlp(out_trunk)
        return nn.functional.sigmoid(out).squeeze(), hidden 


In [13]:
# Generic Train Method ### Does not handle the devices for now
def test(model, dataloader, task_type = 0, loss_func = nn.BCELoss()):
    # Eval Mode
    model.eval()
    hidden = None
    total_acc = 0.0
    total_loss = 0.0
    with torch.no_grad():
        for x, y_t in dataloader:
            b_size = x.shape[0]
            # Bemindful of shape changes
            y_p = model(x.reshape(b_size, -1), hidden)
            '''
            # Forward Pass of RNN only
            y_p, _ = model(x, hidden)
            y_p = F.sigmoid(y_p.squeeze())
            '''
            # Calculate Loss
            loss = loss_func(y_p, y_t).clone().detach()
            # Calculate Acc
            preds = y_p > 0.5
            acc = (preds == y_t).sum().item() / len(y_t)

            # Record statistics
            total_loss += loss
            total_acc += acc
    # Back to Train Mode
    model.train()
    return total_loss / len(dataloader), total_acc / len(dataloader)


def train(model, train_dataloader, test_dataloader = None, n_epochs = 20, 
          tasktype = 0, lr = 0.01,
          loss_func = nn.BCELoss()):
    loss_lst = []
    acc_lst = []
    test_acc_lst = []
    
    # Using Adam by default
    optimizer = torch.optim.Adam(model.parameters(), lr = lr)
    model.train()
    for e in range(n_epochs + 1):
        for x, y_t in train_dataloader:
            ## !! Re_init hidden state !!
            hidden = None
            optimizer.zero_grad() # Clear gradients
            # Forward Pass
            y_pred, updated_hidden = mymodel(x.reshape(x.shape[0], -1), hidden, tasktype)
            loss = loss_func(y_pred, y_t)
            
            # Backward Pass
            loss.backward(retain_graph=True) ### Is the same graph accessed multiple times over different backward pass?
            #loss.backward()
            optimizer.step()

            # Record Statistics
            loss_lst.append(loss.clone().detach())
            
            preds = y_pred > 0.5
            acc = (preds == y_t).sum().item() / len(y_t)
            acc_lst.append(acc)
            hidden = updated_hidden
        if 0 == (e) % 10:
            t_loss = 0.0
            t_acc = 0.0
            if test_dataloader is not None:
                t_loss, t_acc = test(model, test_dataloader, task_type, loss_func)
            print(f'Epoch [{e}/{n_epochs}]:\t--LastBatchLoss:{loss:.3f}, TrainAcc:{acc:.3f}, TestLoss:{t_loss:.3f}, TestAcc:{t_acc:.3f}')

    return loss_lst, acc_lst

In [14]:
### Example use 
# Debug
torch.autograd.set_detect_anomaly(True)

input_dim = 4
hidden_dim = 8
output_dim = 1
trunk_net = MLP(input_dim * 2, hidden_dim, hidden_dim)
mymodel = MyModel(trunk_net, hidden_dim, output_dim = output_dim)

# Train with Task Requiring RNN
l_lst, a_lst = train(mymodel, train_loader, tasktype = 0, n_epochs = 20, lr = 0.01)

# Train with Task with Linear Readout
#history = train(mymodel, train_data, tasktype = 1)

Epoch [0/20]:	--LastBatchLoss:0.644, TrainAcc:0.750, TestLoss:0.000, TestAcc:0.000
Epoch [10/20]:	--LastBatchLoss:0.523, TrainAcc:0.750, TestLoss:0.000, TestAcc:0.000
Epoch [20/20]:	--LastBatchLoss:0.028, TrainAcc:1.000, TestLoss:0.000, TestAcc:0.000


1

----
### Scratch

In [None]:
# @title Task design
def int2bits(n, width):
    '''
    Convert an Integer to Its Binary Representation as a List/Array/Tensor
    e.g. n = 3, width = 4 -> [0, 0, 1, 1]
    '''
    return [int(b) for b in bin(n)[2:].zfill(width)] # bin() returns a str with prefix '0b'
    


def gen_batch(batch_size, n_elements = 3, 
              task = 'ti'):
    '''
    Generate a Batch of Stimuli (list) for the Desired Task
    
    Inputs:
    - batch_size   : (int) num of stimulus
    - n_elements   : (int) Total number of elements on the base set
    - task         : (str)
                     'ti' - Transitive Inference; without Reflexitivity (relation with itself)
                     'si' - Subset Inclusion; 
                     'div'- Divisibility; without 0

    Returns: [[A, B], label] x batch_size       ## of shape [batch_size, 2, n_elements]
    '''
    assert task in ['ti',
                    'si',
                    'div'], f'Requested Task [{task}] Not Supported!'
    stimuli = []
    if 'ti' == task.lower() or 'div' == task.lower():
        max_b_size = int(n_elements)
        # Transitive Inference
        if 'ti' == task.lower():
            # All possible instances
            stimulus_dict = np.eye(n_elements, dtype = int)
            np.random.shuffle(stimulus_dict)
            # Randomly sample two instances in the dict as a stimulus
            for pair_id in range(batch_size):
                dict_indices = random.sample(range(max_b_size), k = 2) # without replacement
                # target = which instance is greater
                target = 0 if dict_indices[0] > dict_indices[1] else 1
                stimuli.append([stimulus_dict[dict_indices].tolist(), target])
        else: # Divisibility
            for pair_id in range(batch_size):
                stimulus_dict = random.sample(range(1, max_b_size + 1), k = 2) # excluding 0
                # target = if the previous is divisible by the latter
                target = int(0 == stimulus_dict[0] % stimulus_dict[1])
                stimuli.append([stimulus_dict, target])
    # Subset Inclusion
    elif 'si' == task.lower():
        max_b_size = 2 ** int(n_elements)
        # randomly sample two instances
        for pair_id in range(batch_size):
            idx_A, idx_B = random.sample(range(max_b_size), k = 2)
            # target if the previous is a superset of the latter
            A, B = np.array(int2bits(idx_A, width = n_elements)), np.array(int2bits(idx_B, n_elements))
            element_indices_B = np.arange(n_elements)[1 == B]
            target = int(A[element_indices_B].all()) # if non-zero entries in B are also in A
            stimuli.append([[A.tolist(), B.tolist()], target])
    return stimuli

In [6]:
import itertools
def get_test_data(n):
    """Generate test data for relational tasks.
    """
    test_x = np.zeros((n, n, 2*n))
    for i, j in itertools.product(range(n), range(n)):
        test_x[i,j,i] = 1 
        test_x[i,j,n+j] = 1 
    test_x = torch.from_numpy(test_x).float()
    test_x = test_x/torch.sqrt(torch.tensor(2))
    return test_x

def get_transitive_data(n):
    """Generate training data for TI task.
    """
    test_x = get_test_data(n)
    x = test_x[tuple(zip(*([(i, i+1) for i in range(n-1)] + [(i+1, i) for i in range(n-1)])))]
    y = torch.tensor([1.]*(n-1)+[-1.]*(n-1))
    return x, y

In [12]:
x, y = get_transitive_data(5)

In [13]:
x

tensor([[0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7071,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.7071],
        [0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.7071, 0.0000, 0.0000, 0.0000, 0.7071,
         0.0000]])

In [10]:
y.shape

torch.Size([8])