# Part #2: How to train a TART reasoning module

In [1]:
%load_ext autoreload
%reload_ext autoreload

In [2]:
%autoreload 2

In [3]:
import os
import sys


sys.path.append(f'{os.path.dirname(os.path.dirname(os.getcwd()))}')
import warnings

from reasoning_module.samplers import get_data_sampler
from reasoning_module.tasks import get_task_sampler
from reasoning_module.models import TransformerModel   

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning, module='tqdm')

# Tutorial Overview
The goal of this notebook is to familiarize users with the TART reasoning module training process. The notebook is structured into two parts: (1) training data exploration and (2) model architecture and training protocol exploration.

Below we provide a brief description of the TART reasoning module parameters.

Data sampling parameters:

* n_dims (int): dimension of the input data
* n_positions (int): total number of (x,y) pairs to sample for each sequence
* batch_size (int): size of the training batches
* weight_multiplier (int): noise level of the problem, denoted $\alpha$
* variable_noise (bool): whether to randomly sample the noise level on a per-batch level. If set to true, the noise level is sampled uniformly from $[1, \alpha]$.

Reasoning module parameters:
* n_positions (int): total number of (x,y) pairs in the input sequence (max_seq_length = 258 * 2)
* n_layer (int): number of transformer layers
* n_head (int): number of attention heads
* n_embd (int): hidden dimension of the reasoning module


For the purposes of this exploration, we will use the following parameters.


In [4]:
### RUN THIS CELL ###

# Data sampling parameters
n_dims = 16 # number of dimensions of the input data
n_positions=258 # total number of (x,y) pairs
batch_size = 32 # batch size
data_type = "gaussian" # input data distribution
training_task_type = "probabilistic_logistic_regression" # task type
weight_multiplier = 5 # weight multiplier for the task
variable_noise=False # whether to use variable noise for the task

In [5]:
### RUN THIS CELL ###

# Reasoning module parameters
n_positions = 258 # Total # of (x,y) pairs in the input sequence: max_seq_length = 258 * 2
n_layer=12 # number of transformer layers
n_head=8 # number of attention heads
n_embd=256 # hidden dimension of the reasoning module


### Part 1: Training data overview
Recall that the TART reasoning head is trained on sequences ($s_t$) of $(x,y)$ pairs sampled from $d$-dimensional logistic regression functions. We will begin by inspecting one training sequence.


In [6]:
# instantiate task and data samplers
data_sampler = get_data_sampler(data_type, n_dims=n_dims)
task_sampler = get_task_sampler(
            training_task_type,
            n_dims,
            batch_size,
            weight_multiplier=weight_multiplier,
            variable_noise=variable_noise,
            n_points=n_positions, 
    )
task = task_sampler()

Let's sample the x's and y's of an arbitrary sequence, $s_t$, with 258 $(x,y)$ pairs.

In [7]:
# sample xs and ys
xs, _ = data_sampler.sample_xs(n_positions, batch_size, n_dims)
ys, _ = task.evaluate(xs)

In [8]:
# print shapes of xs and ys
print(f"Shape of tensor of X's: {xs.shape}")
print(f"Shape of tensor of Y's: {ys.shape}")

Shape of tensor of X's: torch.Size([32, 258, 16])
Shape of tensor of Y's: torch.Size([32, 258])


Notice that dim 0 = batch_size, dim 1 = n_positions (the sequence length), and dim 2 = n_dims (the input dimension, not the hidden dimension n_embd). The y's will be converted to one-hot vectors of dimension 16 before being passed to the model.
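
To make the data-generating process concrete, here is a minimal NumPy sketch of how one such sequence could be drawn. It is not the code behind `get_task_sampler`; the function name `sample_logistic_sequence`, the standard-normal weight vector, the {0, 1} label encoding, and the continuous uniform draw for `variable_noise` are all assumptions made for illustration.

```python
import numpy as np

def sample_logistic_sequence(n_positions, n_dims, alpha, rng, variable_noise=False):
    """Draw one sequence of (x, y) pairs from a random logistic regression problem.

    Hypothetical helper for illustration only; the repo's samplers may differ.
    """
    # Assumption: with variable_noise, the multiplier is drawn uniformly from [1, alpha].
    scale = rng.uniform(1.0, alpha) if variable_noise else float(alpha)
    w = rng.standard_normal(n_dims)                    # hidden weights of this problem
    xs = rng.standard_normal((n_positions, n_dims))    # Gaussian inputs
    probs = 1.0 / (1.0 + np.exp(-scale * (xs @ w)))    # P(y = 1 | x)
    ys = (rng.uniform(size=n_positions) < probs).astype(np.float32)  # Bernoulli labels
    return xs, ys

rng = np.random.default_rng(0)
xs_demo, ys_demo = sample_logistic_sequence(258, 16, alpha=5, rng=rng)
print(xs_demo.shape, ys_demo.shape)  # (258, 16) (258,)
```

Each call draws a fresh weight vector, so every sequence corresponds to a different regression problem. A larger multiplier pushes the probabilities toward 0 or 1, which makes the labels less noisy.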

### Part 2: Training protocol 
In this section, we will review the reasoning module architecture and training protocol. In this exercise, we will use the transformer architecture for the reasoning module (subsequent notebooks will show how other architectures, e.g., [Hyena](https://arxiv.org/abs/2302.10866), can be used).

Let us begin by instantiating our reasoning module.


In [9]:
# instantiate reasoning module
reasoning_module = TransformerModel(
    n_dims=n_dims,
    n_positions=n_positions,
    n_embd=n_embd,
    n_layer=n_layer,
    n_head=n_head,
)

We now construct our input sequence from the xs and ys above. Recall that ys still doesn't have a third dimension and needs to be converted to a one-hot vector! Moreover, we need to interleave our x's and y's into a single sequence, i.e., $x_0$, $y_0$, $x_1$, $y_1$, ..., $x_{257}$, $y_{257}$.

In [10]:
# construct input sequence
input_sequence = reasoning_module._combine(xs, ys)
print(f"Shape of input sequence: {input_sequence.shape}")

Shape of input sequence: torch.Size([32, 516, 16])
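
For intuition, the interleaving can be sketched in plain NumPy as follows. This is not the body of `_combine`; the helper name `interleave` is hypothetical, and placing each label in the first coordinate with the remaining coordinates set to zero is an assumption about how the y's are widened to `n_dims`.

```python
import numpy as np

def interleave(xs, ys):
    """Interleave xs (batch, points, dim) and ys (batch, points) into one sequence.

    Hypothetical stand-in for illustration; the label layout is an assumption.
    """
    batch, points, dim = xs.shape
    ys_wide = np.zeros((batch, points, dim), dtype=xs.dtype)
    ys_wide[:, :, 0] = ys                        # label in coordinate 0, zeros elsewhere
    stacked = np.stack((xs, ys_wide), axis=2)    # (batch, points, 2, dim)
    return stacked.reshape(batch, 2 * points, dim)  # x_0, y_0, x_1, y_1, ...

xs_toy = np.random.default_rng(0).standard_normal((32, 258, 16))
ys_toy = np.ones((32, 258))
seq_toy = interleave(xs_toy, ys_toy)
print(seq_toy.shape)  # (32, 516, 16)
```

Even positions of the result hold the x's and odd positions hold the widened y's, which is why the sequence length doubles to 516.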


As expected, our input sequence has dimensions (batch_size, 2 * n_positions, n_dims). We will now take a single step with our model.

In [11]:
# perform single forward pass
output = reasoning_module._step(input_sequence)

Let us now compute the loss over the output. The reasoning module is trained by computing the loss over the predicted $y_i$'s in sequence $s_t$. A binary cross entropy loss is used.

In [16]:
# compute loss!
loss_func = task.get_training_metric()
loss = loss_func(output, ys)
print(f"Loss: {loss}")

Loss: 0.7249126434326172
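
As a sanity check on this value: a model that predicts probability 0.5 for every label has a binary cross-entropy of $\ln 2 \approx 0.693$, so a loss near 0.72 is roughly what one would expect from an untrained module. The sketch below computes binary cross-entropy from raw logits in NumPy. Whether `task.get_training_metric()` applies the sigmoid internally is an assumption here, and the function name is illustrative.

```python
import numpy as np

def binary_cross_entropy_with_logits(logits, targets):
    """Mean binary cross-entropy computed from raw logits (illustrative helper)."""
    z = np.asarray(logits, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    # Numerically stable form: max(z, 0) - z * y + log(1 + exp(-|z|))
    return float(np.mean(np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))))

targets_toy = np.array([0.0, 1.0, 1.0, 0.0])
chance_loss = binary_cross_entropy_with_logits(np.zeros(4), targets_toy)
print(chance_loss)  # about 0.693, i.e. ln 2
```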


Woohoo! You have performed one training step for a TART reasoning module! In a full training run, we would perform $n$ such steps. Note that at each training step $t$, we sample a *different* logistic regression problem to construct our $s_t$.

To conduct a full training run, see `src/reasoning_module/conf/tart_heads/reasoning_head_s258.yaml` for a sample configuration file.

Given such a configuration file, training can be performed using:

```
python src/reasoning_module/train.py --config src/reasoning_module/conf/tart_heads/reasoning_head_s258.yaml
```