In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt
import os
import seaborn as sns

### Hyperparameters

In [None]:
# Number of target classes (size of the PonderCNN output layer).
N_CLASSES = 10

# Feature sizes: CNN output, step-MLP output (also the hidden-state size used by
# PonderCNN) and step-MLP hidden layer.
N_OUT_CNN = 32
N_OUT_MLP = 32
N_HIDDEN_MLP = 32

# Upper bound on the number of pondering steps per sample.
MAX_STEPS = 15

# Mini-batch size used by every DataLoader.
BATCH_SIZE = 64

# LAMBDA_P parameterises the geometric prior built by RegLoss;
# BETA scales the regularisation term in `loss = rec + BETA * reg`.
LAMBDA_P = 0.2
BETA = 0.05

# Learning rate handed to the Adam optimizer created inside PonderCNN.
LEARNING_RATE = 0.0003

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Make `device` the default for tensors created without an explicit `device=` argument.
torch.set_default_device(device)

### Setup MNIST data

In [None]:
# Convert images to tensors and normalise them with the mean/std constants below.
default_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

# Split the 60k official training images into 55k train / 5k validation samples.
train_set, val_set = random_split(MNIST('./', download=True, train=True, transform=default_transform), [55000, 5000])
test_set = MNIST('./', download=True, train=False, transform=default_transform)

# Only the training loader needs shuffling; evaluation loaders keep a fixed
# order so that validation/test passes visit the data deterministically.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

### Loss

In [None]:
class RecLoss(nn.Module):
    """Reconstruction loss: task loss of every pondering step weighted by its halting probability.

    Args:
        loss_fn: loss module of the underlying task (cross-entropy in this notebook).
            Its output is multiplied by the step's halting probabilities, so it may
            be a scalar or a per-sample tensor.
    """

    def __init__(self, loss_fn: nn.Module):
        super().__init__()
        self.loss_fn = loss_fn

    def forward(self, p, y_pred, y_true):
        """Return the sum over steps of ``mean(loss_fn(y_pred[n], y_true) * p[n])``.

        ``p`` and ``y_pred`` are iterated in lock-step along their first axis
        (``zip`` stops at the shorter of the two).
        """
        step_terms = []
        for step_pred, step_p in zip(y_pred, p):
            weighted = self.loss_fn(step_pred, y_true) * step_p
            step_terms.append(weighted.mean())
        return torch.stack(step_terms).sum()

def geometric_dist(steps, lam):
    """Return the first ``steps`` terms of a geometric distribution.

    Element ``k`` (0-based) equals ``lam * (1 - lam) ** k``, i.e. the probability
    that the first success happens on trial ``k + 1`` with success rate ``lam``.
    """
    n_failures = torch.arange(0, steps)  # 0, 1, ..., steps - 1 failures before success
    return lam * torch.pow(1 - lam, n_failures)

class RegLoss(nn.Module):
    """KL-based regulariser pulling the halting distribution towards a geometric prior.

    Args:
        lambda_p: success rate of the geometric prior.
        max_n_steps: number of prior terms to precompute (maximum pondering steps).
    """

    def __init__(self, lambda_p, max_n_steps):
        super(RegLoss, self).__init__()
        self.max_n_steps = max_n_steps
        # Non-persistent buffer: follows the module through .to()/.cuda() but is
        # kept out of state_dict, so checkpoint keys are unchanged.
        self.register_buffer('pg', geometric_dist(self.max_n_steps, lambda_p), persistent=False)
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, p):
        """Compute the regularisation term.

        Args:
            p: halting probabilities of shape (steps, batch); ``steps`` may be
                smaller than ``max_n_steps``, in which case the prior is truncated.

        Returns:
            Scalar tensor: ``KLDivLoss(reduction='batchmean')`` between ``log(p)``
            (as input) and the truncated prior (as target).
        """
        pt = p.transpose(0, 1)  # (batch, steps)
        cut_pg = self.pg[:pt.shape[1]].unsqueeze(0).to(device=pt.device, dtype=pt.dtype)
        # Clamp before the log so that exact zeros in p give a large finite
        # penalty instead of -inf / inf (and NaN gradients).
        log_p = pt.clamp_min(torch.finfo(pt.dtype).tiny).log()
        return self.kl_div(log_p, cut_pg.expand_as(pt))

### PonderNet wrapper (pondering loop)

In [None]:
class PonderCNN(nn.Module):
    """PonderNet-style classifier that re-applies a step MLP to CNN features.

    At every pondering step the hidden state ``h`` is mapped to class logits
    (``out``) and to a halting probability (``lam``). The halting probabilities
    stay inside the autograd graph so that both the reconstruction and the
    regularisation loss can train the ``lam`` layer.

    Args:
        mlp: step network mapping ``concat([h, cnn_features])`` to the next ``h``.
        cnn: feature extractor applied to the raw input.
        mpl_out: size of the hidden state produced by ``mlp``.
        classes: number of output classes.
        max_steps: maximum number of pondering steps.
        batch_size: nominal batch size (overwritten with the real one in ``forward``).
        learning_rate: learning rate of the internal Adam optimizer.
        lam_p: success rate of the geometric prior used by the regulariser.
    """

    def __init__(self, mlp, cnn, mpl_out, classes, max_steps, batch_size, learning_rate, lam_p):
        super(PonderCNN, self).__init__()
        self.mlp_out, self.classes, self.mlp, self.cnn, self.max_steps = mpl_out, classes, mlp, cnn, max_steps
        self.batch_size, self.learning_rate, self.lam_p = batch_size, learning_rate, lam_p
        self._early_halt = False

        # lambda layer (probability of halting)
        self.lam = nn.Linear(self.mlp_out, 1)

        # output layer
        self.out = nn.Linear(self.mlp_out, self.classes)

        # losses
        # Per-sample losses are required so that RecLoss weights every sample by
        # its own halting probability instead of scaling an already-averaged scalar.
        self.criterion = nn.CrossEntropyLoss(reduction='none')
        self.rec_loss = RecLoss(self.criterion)
        self.reg_loss = RegLoss(self.lam_p, self.max_steps)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def halt_if_possible(self):
        """Allow ``forward`` to stop early (eval mode only) once every sample has halted."""
        self._early_halt = True

    def never_halt(self):
        """Always run all ``max_steps`` pondering steps."""
        self._early_halt = False

    def forward(self, x):
        """Run the pondering loop.

        Args:
            x: input batch accepted by ``self.cnn``.

        Returns:
            Tuple ``(p, y, halt_step)`` where
            * ``p`` has shape (steps_run, batch): probability of halting at each step,
            * ``y`` has shape (steps_run, batch, classes): logits of each step,
            * ``halt_step`` is a long tensor of shape (batch,) with the sampled
              1-based halting step of every sample.
            ``steps_run`` equals ``max_steps`` unless early halting stopped the loop.
        """
        cnn_out = self.cnn(x)
        self.batch_size = cnn_out.shape[0]
        dev, dtype = cnn_out.device, cnn_out.dtype

        h = torch.zeros((self.batch_size, self.mlp_out), device=dev, dtype=dtype)
        h = self.mlp(torch.concat([h, cnn_out], dim=1))

        # Probability of not having halted before the current step: prod_j (1 - lam_j).
        not_halted = torch.ones(self.batch_size, device=dev, dtype=dtype)
        halt_step = torch.zeros(self.batch_size, dtype=torch.long, device=dev)

        ps = []
        ys = []

        for step_n in range(1, self.max_steps + 1):
            # Halting probability at this step; the last step always halts.
            if step_n < self.max_steps:
                lam_n = torch.sigmoid(self.lam(h)).flatten()
            else:
                lam_n = torch.ones(self.batch_size, device=dev, dtype=dtype)

            ps.append(not_halted * lam_n)
            not_halted = not_halted * (1 - lam_n)

            ys.append(self.out(h))

            # Flip a coin per sample; a previously chosen halting step is never replaced.
            coin = torch.rand(self.batch_size, device=dev)
            should_halt = (coin < lam_n.detach()) & (halt_step == 0)
            halt_step = halt_step.masked_fill(should_halt, step_n)

            if self._early_halt and (not self.training) and bool((halt_step > 0).all()):
                break

            # Advance the hidden state only when another step will consume it.
            if step_n < self.max_steps:
                cnn_out = self.cnn(x)
                h = self.mlp(torch.concat([h, cnn_out], dim=1))

        return torch.stack(ps), torch.stack(ys), halt_step
                

### Step networks (MLP and CNN)

In [None]:
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Dropout(0.2) -> Linear -> ReLU.

    Args:
        n_input: number of input features.
        n_hidden: width of the hidden layer.
        n_output: number of output features.
    """

    def __init__(self, n_input, n_hidden, n_output):
        super().__init__()
        self.i2h = nn.Linear(n_input, n_hidden)   # input -> hidden
        self.h2o = nn.Linear(n_hidden, n_output)  # hidden -> output
        self.droput = nn.Dropout(0.2)

    def forward(self, x):
        """Map ``x`` of shape (..., n_input) to non-negative features of shape (..., n_output)."""
        hidden = self.droput(F.relu(self.i2h(x)))
        return F.relu(self.h2o(hidden))

class CNN(nn.Module):
    """Small convolutional feature extractor for single-channel square images.

    Two (conv -> 2x2 max-pool -> ReLU) stages, with 2-D dropout after the second
    convolution, followed by a ReLU-activated linear projection.

    Args:
        n_input: side length of the square input image.
        n_output: number of output features.
        kernel_size: kernel size shared by both convolutions.
    """

    def __init__(self, n_input=28, n_output=50, kernel_size=5):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=kernel_size)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=kernel_size)
        self.conv2_drop = nn.Dropout2d()

        # Each stage shrinks the side by (kernel_size - 1), then the pool halves it (floored).
        shrink = kernel_size - 1
        side_after_stage1 = (n_input - shrink) // 2
        self.lin_size = int((side_after_stage1 - shrink) // 2)
        self.fc1 = nn.Linear(self.lin_size ** 2 * 20, n_output)

    def forward(self, x):
        """Return features of shape (batch, n_output) for ``x`` of shape (batch, 1, n_input, n_input)."""
        feat = self.conv1(x)
        feat = F.relu(F.max_pool2d(feat, 2))
        feat = self.conv2_drop(self.conv2(feat))
        feat = F.relu(F.max_pool2d(feat, 2))
        return F.relu(self.fc1(feat.flatten(1)))

In [None]:
def accuracy(y_pred, halted_step, labels, ponder):
    """Fraction of samples whose selected prediction matches its label.

    Args:
        y_pred: logits of shape (steps, batch, classes).
        halted_step: 1-based halting step per sample, shape (batch,). A value of
            0 selects the last available step (index -1).
        labels: ground-truth class indices, shape (batch,).
        ponder: if True, score each sample at its own halting step; otherwise
            score every sample at the last step present in ``y_pred``.

    Returns:
        0-dim float tensor with the accuracy over the batch.
    """
    n_samples = halted_step.shape[0]
    sample_idx = torch.arange(n_samples, device=y_pred.device)
    if ponder:
        step_idx = halted_step.to(device=y_pred.device, dtype=torch.long) - 1
    else:
        # Use the last step actually computed rather than a global constant, so
        # the function also works when fewer steps were run.
        step_idx = torch.full_like(sample_idx, y_pred.shape[0] - 1)
    predicted = y_pred[step_idx, sample_idx].argmax(dim=-1)
    n_correct = (predicted == labels.to(y_pred.device)).sum()
    return n_correct / n_samples

### Train

In [None]:
# Build the step network, the feature extractor and the pondering wrapper.
mlp = MLP(N_OUT_CNN+N_OUT_MLP, N_HIDDEN_MLP, N_OUT_MLP).to(device)
cnn = CNN(n_input=28,n_output=N_OUT_CNN).to(device)

pnet = PonderCNN(mlp, cnn, N_OUT_MLP, N_CLASSES, MAX_STEPS, BATCH_SIZE, LEARNING_RATE, LAMBDA_P).to(device)

# Per-batch metrics collected over the whole run.
statistics = {'training_losses':[],
                # 'training_accuracies':[],
                'validation_losses':[],
                'validation_accuracies':[],
                'test_losses':[],
                'test_accuracies':[]}

best_loss = float('inf')
best_model_state = None

for epoch in range(2):
    sum_losses = .0
    pnet.train()
    for i, (inputs, labels) in enumerate(train_loader, 1):
        # Batches come from the DataLoader; move them next to the model.
        inputs, labels = inputs.to(device), labels.to(device)

        pnet.optimizer.zero_grad()

        p, y_pred, halted_step = pnet(inputs)

        rec = pnet.rec_loss(p, y_pred, labels)
        reg = pnet.reg_loss(p)
        loss = rec + BETA * reg

        loss.backward()
        pnet.optimizer.step()

        statistics['training_losses'].append(loss.item())
        sum_losses += loss.item()
        print(f'Running Train Loss: {loss.item():.3f}', end='\r')
        if i % 100 == 0:
            print(f'Epoch: {epoch+1} - Step: {i} | Train Loss: {sum_losses / 100:.3f}')
            sum_losses = .0

    pnet.eval()
    epoch_val_losses = []
    epoch_val_accs = []
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(val_loader, 1):
            inputs, labels = inputs.to(device), labels.to(device)

            p, y_pred, halted_step = pnet(inputs)

            val_acc = accuracy(y_pred, halted_step, labels, True)
            epoch_val_accs.append(val_acc.item())
            statistics['validation_accuracies'].append(val_acc.item())

            rec = pnet.rec_loss(p, y_pred, labels)
            reg = pnet.reg_loss(p)
            val_loss = rec + BETA * reg

            epoch_val_losses.append(val_loss.item())
            statistics['validation_losses'].append(val_loss.item())

    # The weights do not change during validation, so model selection is done
    # once per epoch on the mean validation loss rather than on single batches.
    epoch_val_loss = float(np.mean(epoch_val_losses)) if epoch_val_losses else float('inf')
    epoch_val_acc = float(np.mean(epoch_val_accs)) if epoch_val_accs else float('nan')
    if epoch_val_loss < best_loss:
        best_loss = epoch_val_loss
        # state_dict() returns references to the live parameters; clone them so
        # later optimizer steps cannot overwrite the snapshot.
        best_model_state = {k: v.detach().clone() for k, v in pnet.state_dict().items()}

    print(f'Epoch: {epoch+1} | Mean Val Loss: {epoch_val_loss:.3f} | Mean Val Acc: {100 * epoch_val_acc:.1f}%')

# Final evaluation on the test set, allowing the model to stop pondering early.
pnet.eval()
pnet.halt_if_possible()
with torch.no_grad():
    for i, (inputs, labels) in enumerate(test_loader, 1):
        inputs, labels = inputs.to(device), labels.to(device)

        p, y_pred, halted_step = pnet(inputs)

        test_acc = accuracy(y_pred, halted_step, labels, True)
        statistics['test_accuracies'].append(test_acc.item())

        rec = pnet.rec_loss(p, y_pred, labels)
        reg = pnet.reg_loss(p)
        test_loss = rec + BETA * reg

        statistics['test_losses'].append(test_loss.item())

print(f'Mean Test Loss: {np.array(statistics["test_losses"]).mean():.3f} | Mean Test Acc: {100 * np.array(statistics["test_accuracies"]).mean():.1f}%\n')

# Fall back to the final weights if no epoch produced a finite validation loss.
if best_model_state is None:
    best_model_state = {k: v.detach().clone() for k, v in pnet.state_dict().items()}

os.makedirs('models_MNIST/', exist_ok=True)
torch.save(best_model_state, f'models_MNIST/MNIST_best_model_{LAMBDA_P}.pth')


### Testing

In [None]:
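# NOTE: disabled evaluation/plotting cell. Before un-commenting, set `filename` to a saved
# checkpoint path and define `lp` (it is used below but never assigned in this cell).
# filename = ''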
# test_data = []
# mlp = MLP(N_OUT_CNN+N_OUT_MLP, N_HIDDEN_MLP, N_OUT_MLP).to(device)
# cnn = CNN(n_input=28,n_output=N_OUT_CNN).to(device)
# model = PonderCNN(mlp, cnn, N_OUT_MLP, N_CLASSES, MAX_STEPS, BATCH_SIZE, LEARNING_RATE, lp).to(device)


# model.load_state_dict(torch.load(filename))
# model.eval()
# model.halt_if_possible()
# with torch.no_grad():
#     accuracies = []
#     hs_occurrences = np.zeros(MAX_STEPS)
#     for i, (inputs, labels) in enumerate(test_loader, 1):
#         p, y_pred, halted_step = model(inputs)   
        
#         where, many = np.unique(halted_step, return_counts=True)
#         hs_occurrences[where - 1] += many

#         test_acc = accuracy(y_pred, halted_step, labels, True)
#         accuracies.append(test_acc.item())

#     test_data.append(accuracies)

#     fig, ax = plt.subplots(1,2,figsize=(10, 5))
#     x = np.arange(1, MAX_STEPS+1)

#     y = np.array((geometric_dist(MAX_STEPS, lp)))
#     sns.barplot(x=x, y=y, ax=ax[0])
#     ax[0].set_title(f"Geometric Distribution for $\lambda_p$ = {lp}")
#     ax[0].set_ylim(0,1)

#     y = hs_occurrences/hs_occurrences.sum()
#     sns.barplot(x=x, y=y, ax=ax[1])
#     ax[1].set_title(f"Real halt step distribution for $\lambda_p$ = {lp}")
#     ax[1].set_ylim(0,1)
#     plt.tight_layout()
#     plt.show()
    

# test_data = np.array(test_data)