In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

### Hyperparameters

In [None]:
# --- Model -------------------------------------------------------------
N_HIDDEN = 64        # size of the GRU hidden state

# --- Task --------------------------------------------------------------
MAX_STEPS = 20       # upper bound on the number of pondering steps
VECTOR_LEN = 8       # number of elements in each parity input vector

# --- Optimisation ------------------------------------------------------
BATCH_SIZE = 128
LEARNING_RATE = 3e-4
EPOCHS = 50

# --- Loss --------------------------------------------------------------
LAMBDA_P = 0.2       # parameter of the geometric prior used by the regulariser
BETA = 0.01          # weight of the regularisation term in the total loss

In [None]:
# Use the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# From here on, tensors created without an explicit `device=` land on `device`.
torch.set_default_device(device)

### Setup data

The parity task input is a vector whose elements are $0$, $1$, or $-1$.
The target is the parity of the number of $1$s in the vector: $1$ if that count is odd and $0$ otherwise. Note that `get_data2` below draws every element independently and uniformly from $\{-1, 0, 1\}$, rather than setting a random number of elements (in $[1, len)$) of a zero vector to $1$ or $-1$.

In [None]:
def get_data2(rows, vector_len=None, batch_size=None):
    """Generate a random parity dataset, already split into mini-batches.

    Every input element is drawn uniformly from {-1, 0, 1}. The target of a
    row is 1.0 when the number of elements equal to 1 is odd, else 0.0.

    Args:
        rows: total number of samples to generate.
        vector_len: length of each input vector; defaults to the notebook-level
            ``VECTOR_LEN``.
        batch_size: number of samples per batch; defaults to the notebook-level
            ``BATCH_SIZE``.

    Returns:
        ``(batches_x, batches_y)``: two tuples of tensors. ``batches_x[i]`` has
        shape ``(b, vector_len)`` and ``batches_y[i]`` has shape ``(b, 1)``,
        where ``b == batch_size`` for every batch except possibly the last,
        which holds the remaining ``rows % batch_size`` samples. Both tuples
        are empty when ``rows == 0``.
    """
    if vector_len is None:
        vector_len = VECTOR_LEN
    if batch_size is None:
        batch_size = BATCH_SIZE

    my_x = torch.randint(-1, 2, (rows, vector_len), dtype=torch.float)
    my_y = ((my_x == 1.).sum(dim=1, dtype=torch.float) % 2).unsqueeze(1)

    if rows == 0:
        return (), ()

    # torch.split yields chunks of `batch_size` rows plus one shorter trailing
    # chunk when needed, and also works when rows < batch_size (the previous
    # tensor_split(..., 0) call raised in that case).
    return torch.split(my_x, batch_size), torch.split(my_y, batch_size)

In [None]:
# Pre-generate the three splits. Each name is bound to the
# `(batches_x, batches_y)` pair returned by `get_data2`.
train_loader, val_loader, test_loader = (
    get_data2(n_rows) for n_rows in (1_000_000, 500_000, 1_000_000)
)

### Loss

In [None]:
class RecLoss(nn.Module):
    """Task loss averaged over the batch and weighted by halting probability.

    `loss_fn` is the element-wise task loss (it must return one value per
    sample, e.g. a criterion built with `reduction='none'`).
    """

    def __init__(self, loss_fn: nn.Module):
        super().__init__()
        self.loss_fn = loss_fn

    def forward(self, p, y_pred, y_true):
        """Return sum over steps of mean_batch(loss_fn(y_pred[n], y_true) * p[n]).

        `y_pred` has its dimension 2 squeezed and `y_true` its dimension 1
        squeezed before the loss is evaluated. Steps are paired by iterating
        over the first dimension of `y_pred` and `p`; if the two differ in
        length, the extra entries of the longer one are ignored.
        """
        targets = y_true.squeeze(1)

        step_losses = []
        for step_pred, step_p in zip(y_pred.squeeze(2), p):
            weighted = self.loss_fn(step_pred, targets) * step_p
            step_losses.append(weighted.mean())

        return torch.stack(step_losses).sum()

def geometric_dist(steps, lam):
    """Return the first `steps` geometric probabilities lam * (1 - lam)**(k - 1), k = 1..steps.

    The result is a 1-D tensor of length `steps`; it is not renormalised after
    truncation.
    """
    exponents = torch.arange(1, steps + 1) - 1   # 0, 1, ..., steps - 1
    not_halted_before = torch.pow((1 - lam), exponents)
    return lam * not_halted_before

class RegLoss(nn.Module):
    """KL-divergence regulariser between halting probabilities and a geometric prior.

    Args:
        lambda_p: parameter of the geometric prior (see `geometric_dist`).
        max_n_steps: number of prior probabilities to precompute.
    """

    def __init__(self, lambda_p, max_n_steps):
        super().__init__()
        self.max_n_steps = max_n_steps
        # Registered as a buffer so that `.to(device)` moves the prior together
        # with the module; non-persistent so that state_dict keys are unchanged.
        self.register_buffer(
            "pg",
            geometric_dist(self.max_n_steps, lambda_p),
            persistent=False,
        )
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, p):
        """Return the batch-mean KL term for `p`.

        `p` is indexed as (steps, batch); it is transposed to (batch, steps)
        and compared against the first `steps` entries of the prior.
        """
        pt = p.transpose(0, 1)
        cut_pg = self.pg[:pt.shape[1]].unsqueeze(0)

        # log(0) is -inf and makes the loss infinite; floor exact zeros at the
        # smallest positive normal value of p's dtype instead.
        floor = torch.finfo(pt.dtype).tiny
        return self.kl_div(pt.clamp_min(floor).log(), cut_pg.expand_as(pt))

### Step

In [None]:
class PonderParity(nn.Module):
    """GRU-based pondering network for the parity task.

    At every step the hidden state is fed to two linear heads: `out` produces
    the prediction logit and `lam` the conditional halting probability. The
    same input `x` is fed to the GRU cell at every step.

    Args:
        gru_cell: recurrent cell mapping `(x, h)` to the next hidden state.
        vector_len: length of the input vectors (stored for reporting only).
        max_steps: maximum number of pondering steps.
        batch_size: nominal batch size (overwritten by each `forward` call).
        lam_p: parameter of the geometric prior used by the regulariser.
        n_hidden: size of the hidden state.
        lr: learning rate of the Adam optimizer created in `__init__`.
        device: device the loss modules are moved to.
    """

    def __init__(self, gru_cell, vector_len, max_steps, batch_size, lam_p, n_hidden, lr, device):
        super().__init__()
        self.gru_cell, self.input_vector_len, self.max_steps = gru_cell, vector_len, max_steps
        self.batch_size, self.lam_p, self.n_hidden, self.learning_rate = batch_size, lam_p, n_hidden, lr
        self._early_halt, self.device = False, device

        # lambda layer (probability of halting)
        self.lam = nn.Linear(self.n_hidden, 1)

        # output layer
        self.out = nn.Linear(self.n_hidden, 1)

        # losses
        self.criterion = nn.BCEWithLogitsLoss(reduction='none').to(self.device)
        self.rec_loss = RecLoss(self.criterion).to(self.device)
        self.reg_loss = RegLoss(self.lam_p, self.max_steps).to(self.device)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        # NOTE: gradient clipping is not done here. clip_grad_norm_ only rescales
        # gradients that already exist, so it has to be called between
        # loss.backward() and optimizer.step(); at construction time it does nothing.

    def halt_if_possible(self):
        """Allow `forward` to stop early (in eval mode) once every sample has halted."""
        self._early_halt = True

    def never_halt(self):
        """Always run `max_steps` steps in `forward`."""
        self._early_halt = False

    def _dump_info(self):
        """Print the configuration this model was built with."""
        print(f'''Training PonderParity with: 
            VectorLen: {self.input_vector_len}
            MaxSteps: {self.max_steps}
            LambdaP: {self.lam_p}
            Hidden: {self.n_hidden}
            LearningRate: {self.learning_rate}\n''')

    def forward(self, x):
        """Run the pondering loop on a batch.

        Args:
            x: input batch of shape `(batch, input_size)` accepted by `gru_cell`.

        Returns:
            A tuple `(p, y, halt_step)`:
              * `p`: `(max_steps, batch)` unconditional halting probabilities,
                `p[n] = lambda_n * prod_{j<n} (1 - lambda_j)`, in the model's
                dtype and connected to the autograd graph. When the loop stops
                early, the rows of the steps that were not run are zero.
              * `y`: `(steps_run, batch, 1)` prediction logits of every step run.
              * `halt_step`: `(batch,)` long tensor with the 1-based step at
                which each sample was sampled to halt.
        """
        self.batch_size = x.shape[0]
        batch = self.batch_size

        h = x.new_zeros((batch, self.n_hidden))
        h = self.gru_cell(x, h)

        # Probability of not having halted before the current step.
        not_halted = x.new_ones(batch)
        halt_step = torch.zeros(batch, dtype=torch.long, device=x.device)

        ps = []
        ys = []

        for step_n in range(1, self.max_steps + 1):
            # Conditional halting probability; forced to 1 on the last step so
            # that the halting distribution sums to one.
            if step_n < self.max_steps:
                lam_n = torch.sigmoid(self.lam(h)).squeeze(1)
            else:
                lam_n = x.new_ones(batch)

            # Kept as torch ops so both losses can backpropagate into `lam`
            # (a NumPy round trip would cut the graph).
            ps.append(not_halted * lam_n)
            not_halted = not_halted * (1 - lam_n)

            ys.append(self.out(h))

            # Flip a coin per sample with probability lam_n; `halt_step == 0`
            # keeps an earlier halting decision from being overwritten.
            coin = torch.rand(batch, device=x.device)
            should_halt = (coin < lam_n) & (halt_step == 0)
            halt_step[should_halt] = step_n

            if self._early_halt and (not self.training) and bool((halt_step > 0).all()):
                break

            # The hidden state is only needed if another step follows.
            if step_n < self.max_steps:
                h = self.gru_cell(x, h)

        p = torch.stack(ps)
        if p.shape[0] < self.max_steps:
            # Early stop: pad so that p always has max_steps rows.
            padding = p.new_zeros((self.max_steps - p.shape[0], batch))
            p = torch.cat([p, padding])

        return p, torch.stack(ys), halt_step


In [None]:
def accuracy(y_pred, halted_step, labels, ponder):
    """Fraction of samples whose thresholded logit matches the label.

    Args:
        y_pred: per-step logits indexed as `[step, sample, ...]`.
        halted_step: 1-based halting step of every sample, shape `(batch,)`.
        labels: targets, one per sample (`(batch,)` or `(batch, 1)`).
        ponder: if True, score each sample with the prediction made at its own
            halting step; otherwise score the last step present in `y_pred`.

    Returns:
        A 0-dim tensor holding the number of correct predictions divided by
        the batch size.
    """
    n_samples = halted_step.shape[0]
    sample_idx = torch.arange(n_samples, device=y_pred.device)

    if ponder:
        step_idx = halted_step.to(y_pred.device) - 1
    else:
        # Last available step rather than a hard-coded global step count, so
        # this also works when fewer steps were run.
        step_idx = torch.full_like(sample_idx, y_pred.shape[0] - 1)

    # One gather for the whole batch instead of a Python loop over samples.
    picked = y_pred[step_idx, sample_idx]
    hard_pred = (picked > 0.0).to(torch.float).reshape(n_samples, -1)
    target = labels.to(y_pred.device).reshape(n_samples, -1)

    return (hard_pred == target).sum() / n_samples

### Train

In [None]:
gru_cell = nn.GRUCell(VECTOR_LEN, N_HIDDEN).to(device)
pnet = PonderParity(gru_cell, VECTOR_LEN, MAX_STEPS, BATCH_SIZE, LAMBDA_P, N_HIDDEN, LEARNING_RATE, device).to(device)
pnet._dump_info()

# 'training_losses' / 'separate_losses' hold one entry per training batch;
# 'validation_losses' / 'validation_accuracies' hold one mean value per epoch.
statistics = {'training_losses':[],
                'validation_losses':[],
                'validation_accuracies':[],
                'separate_losses': []} # (rec, reg)

max_grad_norm = 1.0  # gradient-norm ceiling applied before every optimizer step

best_loss = float('inf')
best_model_state = None

for epoch in range(EPOCHS):  # loop over the dataset multiple times

    sum_losses = .0
    pnet.train()
    for i, (inputs, labels) in enumerate(zip(train_loader[0],train_loader[1]), 1):
        pnet.optimizer.zero_grad()
        inputs = inputs.to(device)
        labels = labels.to(device)

        p, y_pred, halted_step = pnet(inputs)

        rec = pnet.rec_loss(p, y_pred, labels)
        reg = pnet.reg_loss(p)
        loss = rec + BETA * reg

        loss.backward()

        # Clipping only has an effect here, after the gradients have been
        # computed and before the optimizer consumes them.
        nn.utils.clip_grad_norm_(pnet.parameters(), max_grad_norm)

        pnet.optimizer.step()

        loss_value, rec_value, reg_value = loss.item(), rec.item(), BETA * reg.item()
        statistics['training_losses'].append(loss_value)
        statistics['separate_losses'].append((rec_value, reg_value))
        sum_losses += loss_value
        print(f'Loss: {loss_value:.3f} | Rec: {rec_value:.3f} | Reg: {reg_value:.3f}', end='\r')
        if i % 500 == 0:
            print(f'Epoch: {epoch+1} - Step: {i} | Train Loss: {sum_losses / 500:.3f}')
            sum_losses = .0

    epoch_val_losses = []
    epoch_val_accuracies = []

    pnet.eval()
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(zip(val_loader[0],val_loader[1]), 1):
            inputs = inputs.to(device)
            labels = labels.to(device)

            p, y_pred, halted_step = pnet(inputs)

            val_acc = accuracy(y_pred, halted_step, labels, True)
            epoch_val_accuracies.append(val_acc.item())

            rec = pnet.rec_loss(p, y_pred, labels)
            reg = pnet.reg_loss(p)

            val_loss = rec + BETA * reg

            epoch_val_losses.append(val_loss.item())
            print(f'Running Validation Loss: {val_loss.item():.3f}', end='\r')

    mean_val_loss = np.mean(epoch_val_losses)
    mean_val_acc = np.mean(epoch_val_accuracies)
    statistics['validation_losses'].append(mean_val_loss)
    statistics['validation_accuracies'].append(mean_val_acc)

    # Select the best epoch by its mean validation loss (a single batch is too
    # noisy). state_dict() returns references to the live parameters, so the
    # tensors are cloned; otherwise later updates would overwrite the snapshot.
    if mean_val_loss < best_loss:
        best_loss = mean_val_loss
        best_model_state = {name: tensor.detach().clone() for name, tensor in pnet.state_dict().items()}

    print(f'Epoch: {epoch+1} | Mean Val Loss: {mean_val_loss:.3f} | Mean Val Acc: {100 * mean_val_acc:.1f}%')

os.makedirs('models_PARITY/', exist_ok=True)
torch.save(best_model_state, f'models_PARITY/parity_model_{LAMBDA_P}.pth')


In [None]:
# NOTE(review): disabled evaluation cell, kept for reference. Before re-enabling:
#   - `filename` is empty and must be set to a checkpoint path (the training
#     cell saves to 'models_PARITY/parity_model_{LAMBDA_P}.pth');
#   - `lp`, used for the reference distribution and the plot titles, is not
#     defined anywhere in this cell — presumably LAMBDA_P; confirm.
# filename = ''
# test_data = []
# gru_cell2 = nn.GRUCell(VECTOR_LEN, N_HIDDEN).to(device)
# model = PonderParity(gru_cell2, VECTOR_LEN, MAX_STEPS, BATCH_SIZE, LAMBDA_P, N_HIDDEN, LEARNING_RATE, device).to(device)

# statistics = {'test_accuracies':[]}

# model.load_state_dict(torch.load(filename))
# model.eval() 
# model.halt_if_possible()

# with torch.no_grad():
#     accuracies = []
#     hs_occurrences = np.zeros(MAX_STEPS)
#     for i, (inputs, labels) in enumerate(zip(test_loader[0],test_loader[1]), 1):
#         p, y_pred, halted_step = model(inputs)
#         where, many = np.unique(halted_step, return_counts=True)
#         hs_occurrences[where - 1] += many
        
#         test_acc = accuracy(y_pred, halted_step, labels, True)
#         accuracies.append(test_acc.item())
#         statistics['test_accuracies'].append(test_acc.item())
        
#     test_data.append(accuracies)
#     fig, ax = plt.subplots(1,2,figsize=(10, 5))
#     x = np.arange(1, MAX_STEPS+1)

#     y = np.array((geometric_dist(MAX_STEPS, lp)))
#     sns.barplot(x=x, y=y, ax=ax[0])
#     ax[0].set_title(f"Geometric Distribution for $\lambda_p$ = {lp}")
#     ax[0].set_ylim(0,1)

#     y = hs_occurrences/hs_occurrences.sum()
#     sns.barplot(x=x, y=y, ax=ax[1])
#     ax[1].set_title(f"Real halt step distribution for $\lambda_p$ = {lp}")
#     ax[1].set_ylim(0,1)
#     plt.tight_layout()
#     plt.show()

# print(f'Mean Test Acc: {100 * np.array(statistics["test_accuracies"]).mean():.1f}%\n')

# test_data = np.array(test_data)