In [27]:
import torch
import torch.nn as nn

from torch.nn.functional import one_hot

In [28]:
class StraightThroughWrapper(nn.Module):
    """
    Wrapper around any PyTorch module `transform` that takes a single tensor as input
    to allow gradients passing through a discretization step
    defined by `discretizer`.

    The idea is to apply `transform` to both the discretized and the
    non-discretized input. The non-discretized result is then cancelled out
    of the forward value by subtracting the same tensor detached from the
    computation graph. Thereby, gradients are backpropagated through the
    non-discretized path while the output value is exactly the transform of
    the discretized tensor.

    Notes:
        * `transform` is called twice per forward pass (once per path).
        * The discretized path is not detached here, so if `discretizer` is
          itself differentiable, gradients flow through that path as well.

    Args:
        transform: module applied to both the raw and the discretized input.
        discretizer: module mapping the raw input to its discretized version.
    """
    def __init__(self, transform: nn.Module, discretizer: nn.Module):
        super(StraightThroughWrapper, self).__init__()
        self.transform = transform
        self.discretizer = discretizer

    def forward(self, x):
        """Return `transform(discretizer(x))` with gradients of `transform(x)` attached."""
        x_discrete = self.discretizer(x)

        standard_transform = self.transform(x)
        discrete_transform = self.transform(x_discrete)

        # `a - a.detach()` is exactly zero in the forward pass but still carries
        # the gradient of `a`. Adding it to the discrete result afterwards keeps
        # the output bit-identical to `discrete_transform`, whereas
        # `(a + b) - a.detach()` would pick up floating-point rounding error.
        gradient_carrier = standard_transform - standard_transform.detach()
        return discrete_transform + gradient_carrier

In [29]:
class BinaryDiscretizer(nn.Module):
    """
    Binarises tensors with values in [0, 1].

    Entries strictly greater than 0.5 become 1, all others become 0. The
    amount removed from each entry is computed on a detached copy, so the
    output has the gradient of the identity with respect to the input.
    """
    def forward(self, x):
        frozen = x.detach()
        above_half = x > 0.5
        # Distance of each entry to its binary target, kept out of the graph.
        residuals = torch.where(above_half, frozen - 1., frozen)
        return x - residuals

In [30]:
class MultiDiscretizer(nn.Module):
    """
    Maps every value in a tensor to one of `num_values` learnable values.

    Each input entry gets a soft assignment over the learnable `offsets`:
    a softmax over the negative squared distance to every offset, scaled per
    offset by `exp(alpha)`. The forward output is the offset with the highest
    assignment probability. While all entries of `alpha` are equal (as at
    initialisation, where they are zero) this is the offset closest to the
    input; once training makes them differ, it need not be.

    Gradients are passed straight through: the output value is the selected
    offset, but it carries the gradient of the soft (probability-weighted)
    value with respect to the input, `offsets` and `alpha`, in addition to
    the direct gradient to the selected offset.

    Args:
        num_values: number of discrete values an entry can be mapped to.
    """
    def __init__(self, num_values: int):
        super(MultiDiscretizer, self).__init__()
        self.num_values = num_values

        # Candidate output values, randomly initialised and learned.
        self.offsets = nn.Parameter(torch.randn(self.num_values))
        # Log-scale of each candidate's assignment sharpness.
        self.alpha = nn.Parameter(torch.zeros(self.num_values))

    def forward(self, x):
        """Return a tensor shaped like `x` whose entries are taken from `offsets`."""
        # (..., 1) against (num_values,) broadcasts to (..., num_values).
        squared_distance = torch.pow(x.unsqueeze(-1) - self.offsets, 2)
        logits = torch.exp(self.alpha) * -squared_distance
        val_distribution = torch.softmax(logits, dim=-1)

        # Soft value: expectation of the offsets under the assignment.
        values = (self.offsets * val_distribution).sum(dim=-1)

        # Hard value: look the winning offset up directly so the result is
        # exactly one of the offsets, with no rounding error.
        selected = torch.argmax(val_distribution, dim=-1)
        values_discrete = self.offsets[selected]

        # `values - values.detach()` is exactly zero in the forward pass and
        # only contributes the gradient of the soft value.
        return values_discrete + (values - values.detach())

## Load Data

We evaluate the models on the MNIST dataset (`mnist_784` from OpenML), holding out 10,000 samples as a test set.

In [5]:
import numpy as np

from sklearn.datasets import fetch_openml
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [6]:
# Fixed seed so the train/test split (and therefore every reported metric)
# is the same on each Restart & Run All.
SPLIT_SEED = 42

# Download MNIST as plain arrays rather than a DataFrame.
x, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=10000, random_state=SPLIT_SEED
)

# Standardise features with statistics from the training split only, so no
# information from the test split leaks into preprocessing.
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)

# Convert the labels to an integer array for use as class indices.
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)

## Train Model

In [7]:
from tqdm.auto import tqdm
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

In [21]:
# Number of distinct classes present in the training labels.
num_labels = len(np.unique(y_train))
# Width of every hidden layer used by the models below.
hidden_size = 1024

def train(model: nn.Module, epochs: int = 10, device=None):
    """
    Train `model` on the module-level `x_train` / `y_train` arrays.

    Uses AdamW with default hyperparameters, cross-entropy loss and shuffled
    mini-batches of 32 samples. A progress bar shows the loss of the most
    recent batch, refreshed every 100 steps.

    Args:
        model: network mapping a float batch of features to class scores.
        epochs: number of full passes over the training data.
        device: device to train on (string or `torch.device`). Defaults to
            CUDA when available and falls back to the CPU otherwise.

    Returns:
        The trained model, located on `device`.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    model = model.to(device)
    # `inference` switches models to eval mode; make sure we train in train mode.
    model.train()
    optimizer = AdamW(model.parameters())
    criterion = nn.CrossEntropyLoss()

    dataset = TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(y_train).long())
    dataloader = DataLoader(dataset, shuffle=True, batch_size=32)

    step = 0
    # The context manager closes the bar even if training is interrupted.
    with tqdm(desc="Progress", total=epochs * len(dataloader)) as pbar:
        for _ in range(epochs):
            for x_batch, y_batch in dataloader:
                optimizer.zero_grad()
                loss = criterion(model(x_batch.to(device)), y_batch.to(device))
                loss.backward()
                optimizer.step()

                step += 1
                pbar.update(1)

                if step % 100 == 0:
                    pbar.set_postfix_str(f"Loss: {loss.item():.4f}")

    return model

In [22]:
def inference(model: nn.Module):
    """
    Predict a class label for every sample in the module-level `x_test`.

    The model is switched to eval mode (and left in it) and run without
    gradient tracking in batches of 32, in the original sample order.
    Batches are moved to whatever device the model's parameters live on.

    Args:
        model: network mapping a float batch of features to class scores.

    Returns:
        A list of predicted class indices (Python ints), one per test sample.
    """
    model.eval()

    # Follow the model's device instead of assuming CUDA. A model without
    # parameters falls back to CUDA when available, else the CPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    test_dataset = TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test).long())
    test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=32)

    y_pred = []

    with torch.no_grad():
        for x_batch, _ in tqdm(test_dataloader):
            y_pred_batch_scores = model(x_batch.to(device))
            # Highest-scoring class per sample.
            y_pred_batch = torch.argmax(y_pred_batch_scores, dim=1).cpu().tolist()
            y_pred.extend(y_pred_batch)

    return y_pred

### Test Non-Discrete Model

In [23]:
# Baseline: a plain MLP with two sigmoid hidden layers and no discretization.
non_discrete_model = nn.Sequential(
    nn.Linear(in_features=x_train.shape[1], out_features=hidden_size),
    nn.Sigmoid(),
    nn.Linear(in_features=hidden_size, out_features=hidden_size),
    nn.Sigmoid(),
    nn.Linear(in_features=hidden_size, out_features=num_labels),
)

# Fit for 5 epochs, then predict labels for the held-out test split.
non_discrete_model = train(non_discrete_model, epochs=5)
y_pred = inference(non_discrete_model)

# Per-class precision / recall / F1 on the test split.
print(classification_report(y_test.tolist(), y_pred, zero_division=0.0))

Progress:   0%|          | 0/9375 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.98      0.98      0.98       983
           1       0.99      0.99      0.99      1148
           2       0.96      0.98      0.97      1020
           3       0.98      0.95      0.97      1044
           4       0.97      0.97      0.97       983
           5       0.95      0.96      0.96       873
           6       0.98      0.98      0.98       955
           7       0.94      0.99      0.96      1006
           8       0.97      0.96      0.97      1006
           9       0.98      0.94      0.96       982

    accuracy                           0.97     10000
   macro avg       0.97      0.97      0.97     10000
weighted avg       0.97      0.97      0.97     10000



### Test Binary Discrete Model

In [24]:
# One discretizer instance is shared by both wrapped layers.
binary_discretizer = BinaryDiscretizer()

# Same layout as the baseline, but the inputs of the second and third linear
# layers (sigmoid outputs) are binarised with straight-through gradients.
binary_discrete_model = nn.Sequential(
    nn.Linear(in_features=x_train.shape[1], out_features=hidden_size),
    nn.Sigmoid(),
    StraightThroughWrapper(
        transform=nn.Linear(in_features=hidden_size, out_features=hidden_size),
        discretizer=binary_discretizer,
    ),
    nn.Sigmoid(),
    StraightThroughWrapper(
        transform=nn.Linear(in_features=hidden_size, out_features=num_labels),
        discretizer=binary_discretizer,
    ),
)

# Fit for 10 epochs, then predict labels for the held-out test split.
binary_discrete_model = train(binary_discrete_model, epochs=10)
y_pred = inference(binary_discrete_model)

# Per-class precision / recall / F1 on the test split.
print(classification_report(y_test.tolist(), y_pred, zero_division=0.0))

Progress:   0%|          | 0/18750 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.96      0.99      0.98       983
           1       0.99      0.98      0.99      1148
           2       0.96      0.98      0.97      1020
           3       0.97      0.97      0.97      1044
           4       0.98      0.97      0.97       983
           5       0.95      0.97      0.96       873
           6       0.98      0.99      0.98       955
           7       0.97      0.98      0.97      1006
           8       0.96      0.95      0.95      1006
           9       0.97      0.94      0.96       982

    accuracy                           0.97     10000
   macro avg       0.97      0.97      0.97     10000
weighted avg       0.97      0.97      0.97     10000



### Test Multi-Discrete Model

In [31]:
# ReLU MLP whose second and third linear layers receive inputs quantised to
# 10 learnable levels; each wrapped layer owns its own MultiDiscretizer.
multi_discrete_model = nn.Sequential(
    nn.Linear(in_features=x_train.shape[1], out_features=hidden_size),
    nn.ReLU(),
    StraightThroughWrapper(
        transform=nn.Linear(in_features=hidden_size, out_features=hidden_size),
        discretizer=MultiDiscretizer(num_values=10),
    ),
    nn.ReLU(),
    StraightThroughWrapper(
        transform=nn.Linear(in_features=hidden_size, out_features=num_labels),
        discretizer=MultiDiscretizer(num_values=10),
    ),
)

# Fit for 25 epochs, then predict labels for the held-out test split.
multi_discrete_model = train(multi_discrete_model, epochs=25)
y_pred = inference(multi_discrete_model)

# Per-class precision / recall / F1 on the test split.
print(classification_report(y_test.tolist(), y_pred, zero_division=0.0))

Progress:   0%|          | 0/46875 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.99      0.98      0.99       983
           1       0.99      0.99      0.99      1148
           2       0.96      0.97      0.97      1020
           3       0.98      0.96      0.97      1044
           4       0.98      0.97      0.97       983
           5       0.95      0.97      0.96       873
           6       0.97      0.99      0.98       955
           7       0.95      0.98      0.96      1006
           8       0.98      0.96      0.97      1006
           9       0.97      0.95      0.96       982

    accuracy                           0.97     10000
   macro avg       0.97      0.97      0.97     10000
weighted avg       0.97      0.97      0.97     10000

