# Explainability

In this experiment, we will try and use explainability techniques to understand what the model has learned.

In [25]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np

In [26]:
class CellPredictorNeuralNetwork(nn.Module):
    """Predicts the next state of the cells.

    Inputs:
        x: Tensor of shape (batch_size, channels, width+2, height+2), where channels=1. width and height are the dimensions of the entire game grid.
           We add one cell of padding on each side to ensure that predictions can be made for the boundary cells.
    
    Returns: Tensor of shape (batch_size, width, height), the logits of the predicted states.
    """

    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv2d(1, 85, 3)
        self.conv1 = nn.Conv2d(85, 10, 1)
        self.conv2 = nn.Conv2d(10, 1, 1)

    def forward(self, x):
        x = F.relu(self.conv0(x))
        x = F.relu(self.conv1(x))
        logits = self.conv2(x)
        logits = torch.squeeze(logits, 1) # Remove channels dimension
        return logits

In [27]:
model = CellPredictorNeuralNetwork()
model.load_state_dict(torch.load("model_weights.pth"))

<All keys matched successfully>

In [28]:
# Generate activation for one feature and one example

feature = 3

X = torch.tensor([[0, 0, 0], [0, 1, 0], [0, 0, 0]], dtype=torch.float).reshape(1, 1, 3, 3)
with torch.no_grad():
    a = model.conv0(X)
a = a.squeeze()[feature]
print(a)

tensor(0.1964)


# Middle cell alive or dead

Here, we'll see how well the network has learned the concept of the middle cell being alive or dead.

In [29]:
class FixedMiddleCellDataset(torch.utils.data.Dataset):
    _size = 256

    def __init__(self, middle_cell):
        if middle_cell not in [0, 1]:
            raise ValueError("Middle cell value must be 0 or 1.")
        self._middle_cell = middle_cell

    def __len__(self):
        return self._size

    def __getitem__(self, idx):
        if idx < 0 or idx >= self._size:
            raise IndexError()
        # 12 -> "0b1100" -> "1100" -> "00001100"
        idx_bin = bin(idx)[2:].rjust(8, "0")
        idx_bin = idx_bin[:4] + str(self._middle_cell) + idx_bin[4:]
        X = torch.tensor([float(ch) for ch in idx_bin], dtype=torch.float32).reshape(1, 3, 3) # (channels, width, height)
        alive = X[0, 1, 1] > 0.5
        alive_neighbours = torch.sum(X) - X[0, 1, 1]
        next_alive = (alive and alive_neighbours > 1.5 and alive_neighbours < 3.5) or (not alive and alive_neighbours > 2.5 and alive_neighbours < 3.5)
        y = torch.tensor([float(next_alive)], dtype=torch.float32).reshape(1, 1) # (width, height)
        return X, y

In [30]:
pos_dataset = FixedMiddleCellDataset(1)
neg_dataset = FixedMiddleCellDataset(0)
pos_dataloader = DataLoader(pos_dataset)
neg_dataloader = DataLoader(neg_dataset)

In [31]:
# Run all data through the first layer of the neural network and see the activation for the given feature.

def get_pre_activations_for_feature(feature, dataloader):
    pre_activations = np.zeros(len(dataloader))
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            z = model.conv0(X)
            z = z.squeeze()[feature]
            pre_activations[batch] = z.item()
    return pre_activations

In [32]:
feature = 20

pos_pre_activations = get_pre_activations_for_feature(feature, pos_dataloader)
neg_pre_activations = get_pre_activations_for_feature(feature, neg_dataloader)

pos_pre_activations
print("Positive pre-activations:",
      f"mean {np.mean(pos_pre_activations):.3f},",
      f"std {np.std(pos_pre_activations):.3f},",
      f"min {np.min(pos_pre_activations):.3f}",
      f"max {np.max(pos_pre_activations):.3f}",
      f"median {np.median(pos_pre_activations):.3f}")
print(f"Negative pre-activations:",
      f"mean {np.mean(neg_pre_activations):.3f},",
      f"std {np.std(neg_pre_activations):.3f},",
      f"min {np.min(neg_pre_activations):.3f}",
      f"max {np.max(neg_pre_activations):.3f}",
      f"median {np.median(neg_pre_activations):.3f}")

# Since the dataset represents the full population instead of just a sample, we can assess whether the feature distinguishes between
# the two datasets by just checking if the ranges of pre-activations overlap.
ranges_overlap = (np.max(pos_pre_activations) > np.min(neg_pre_activations)) and (np.max(neg_pre_activations) > np.min(pos_pre_activations))
print(f"Ranges overlap: {ranges_overlap}")

# Another way we can assess the difference is by looking at the difference of the means
mean_diff = np.mean(pos_pre_activations) - np.mean(neg_pre_activations)
# This should really be normalized by the standard deviation
mean_diff_std = np.sqrt(np.var(pos_pre_activations) + np.var(neg_pre_activations))
mean_diff_normalized  = mean_diff / mean_diff_std
print(f"Normalized difference of means: {mean_diff_normalized:.5f}")

Positive pre-activations: mean 0.414, std 0.323, min -0.320 max 1.148 median 0.414
Negative pre-activations: mean 0.515, std 0.323, min -0.219 max 1.248 median 0.515
Ranges overlap: True
Normalized difference of means: -0.22030


In [33]:
# Now let's get pre-activations for all features in a given layer. Then we can see which features seem to be differentiating most between the two datasets.

def get_pre_activations(dataloader):
    pre_activations = np.zeros((len(dataloader), model.conv0.out_channels))
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            z = model.conv0(X) # (batch_size, channels, width, height) -> (batch_size, channels, width, height)
            z = z.squeeze() # (batch_size, channels, width, height) -> (channels,)
            pre_activations[batch, :] = z
    return pre_activations

pos_pre_activations = get_pre_activations(pos_dataloader) # (dataset_size, num_features)
neg_pre_activations = get_pre_activations(neg_dataloader)

## Test using range overlap

In [34]:
# For each feature, let's assess range overlap

def test_range_overlap(pos, neg):
    return (pos.max(axis=0) >= neg.min(axis=0)) & (neg.max(axis=0) >= pos.min(axis=0))


ranges_overlap = test_range_overlap(pos_pre_activations, neg_pre_activations)
ranges_overlap

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True])

All the ranges overlap! So maybe the concept isn't exemplified by a single feature in the conv0 layer.

## Test using normalized difference of means

In [35]:
def get_mean_diff_normalized(pos, neg):
    mean_diff = pos.mean(axis=0) - neg.mean(axis=0)
    mean_diff_std = np.sqrt(pos.var(axis=0) + neg.var(axis=0))
    mean_diff_normalized  = mean_diff / mean_diff_std
    return mean_diff_normalized


mean_diff_normalized = get_mean_diff_normalized(pos_pre_activations, neg_pre_activations)
mean_diff_normalized

array([ 0.35055501, -1.18749087,  0.39031749, -0.18178185,  0.53735601,
        0.28707581,  0.45817236, -0.3978403 , -0.83634738, -0.12782358,
        1.33446187, -0.68430029,  1.03569632, -0.14038728, -0.10905208,
       -0.81344565,  0.24017467,  0.69235158, -0.63207173,  0.39854949,
       -0.22029624,  0.67174854, -0.62178556,  0.61221606, -0.61281918,
       -0.50408767, -0.38026249,  0.13007052,  0.46328631,  0.14295915,
        0.94649088,  0.68014044, -0.00449297,  0.04731425,  0.26528664,
       -0.33205287,  0.01735685,  1.20233603, -0.28440147,  0.60779882,
       -0.31448142, -0.85091064,  0.61773081, -0.61133225,  0.98860413,
        0.68814836,  0.7716982 ,  0.33752691, -0.55906849,  0.36775185,
        0.14225164,  0.43671175, -0.06118962, -0.30881532,  0.05641451,
       -0.06657133,  0.00937756,  0.41167833, -0.04405811, -0.4941482 ,
        0.72850119,  1.14242464,  0.02609398, -0.42680432,  0.17257551,
        0.57126231,  0.00232297, -0.61229959, -0.66937132,  0.22

Assuming the activations are roughly normally distributed, we can run a two-tailed z-test on the normalized difference of means.
We use a z-test rather than a t-test because we know the full dataset, so the std and variance are actually the population std and variance.

For a two-tailed z-test at the 5% level, the critical value is 1.96. Let's see which normalized mean diffs are more extreme than +/- 1.96.

In [36]:
(mean_diff_normalized > 1.96) | (mean_diff_normalized < -1.96)

array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False])

Welp, looks like the concept couldn't be extracted from any of the `conv0` features. Out of curiosity, what's the smallest significance level that would have made at least one test succeed?

In [37]:
print(f"Minimum z-value: {mean_diff_normalized.min():.3f}")
print(f"Maximum z-value: {mean_diff_normalized.max():.3f}")

Minimum z-value: -1.369
Maximum z-value: 1.334


The minimum z-value is -1.369 and the maximum is 1.334. Out of these, -1.369 is the most extreme and corresponds to a significance level of 17%. Not great.

## Post-activation values

Maybe the concept is more clearly represented in the post-activation values. Let's repeat the process for these.

In [38]:
def get_post_activations(dataloader):
    post_activations = np.zeros((len(dataloader), model.conv0.out_channels))
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            a = F.relu(model.conv0(X)) # (batch_size, channels, width, height) -> (batch_size, channels, width, height)
            a = a.squeeze() # (batch_size, channels, width, height) -> (channels,)
            post_activations[batch, :] = a
    return post_activations

pos_post_activations = get_post_activations(pos_dataloader) # (dataset_size, num_features)
neg_post_activations = get_post_activations(neg_dataloader)

In [39]:
ranges_overlap = test_range_overlap(pos_post_activations, neg_post_activations)
ranges_overlap

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True])

Still, all the ranges overlap. This is to be expected. Relu is an increasing function, so it preserves weak inequalities. Range overlap is defined in terms of (weak) inequalities between the endpoints of the range. If ranges $ [a, b] $ and $ [c, d] $ overlap, this means $ b \ge c $ and $ d \ge a $. If we let $ a' = relu(a) $ etc., then $ b' \ge c' $ and $ d' \ge a' $, so the mapped ranges overlap too.

In [40]:
mean_diff_normalized = get_mean_diff_normalized(pos_post_activations, neg_post_activations)
mean_diff_normalized

  mean_diff_normalized  = mean_diff / mean_diff_std


array([ 0.23683171, -1.14832957,  0.38838569, -0.17384463,  0.49229672,
        0.27404148,  0.21782071, -0.14285725, -0.82830463, -0.11904263,
        1.34079986, -0.59257538,  0.94553729, -0.13421474, -0.08376557,
       -0.68746021,  0.20199103,  0.6868629 , -0.39566968,  0.39681848,
       -0.21838428,  0.67088156, -0.16981416,  0.5653979 , -0.08868273,
       -0.49532028, -0.36933328,  0.0983119 ,  0.46026895,  0.1420756 ,
        0.94232762,  0.6800583 , -0.00465905,  0.03278738,  0.26288044,
       -0.23377481,  0.01588153,  1.20730288, -0.14076101,  0.49609867,
       -0.30384616, -0.76729917,  0.19509152, -0.37060805,  0.79466937,
        0.68615106,  0.7628384 ,  0.30309226, -0.39586215,  0.35455994,
        0.08984026,  0.39763355,         nan, -0.24108489,  0.04264589,
       -0.05963881,  0.00879571,  0.4085628 , -0.02514456, -0.43708171,
        0.5131398 ,  1.00241767,  0.02493692, -0.39736413,  0.16943566,
        0.56779377,  0.00224058, -0.33743136, -0.65096573,  0.21

In [41]:
print(f"Minimum z-value: {np.nanmin(mean_diff_normalized):.3f}")
print(f"Maximum z-value: {np.nanmax(mean_diff_normalized):.3f}")

Minimum z-value: -1.394
Maximum z-value: 1.341


These z-values are pretty uninspiring, just like before.

# Higher layer (`conv1`)

There are two possible explanations for why we couldn't identify the "middle cell alive" feature in the `conv0` layer. Firstly, the model could have learnt the feature implicitly in some combination of the 85 channels in the `conv0` layer. Secondly, the model may have learnt the concept in the `conv1` layer instead of `conv0`. Let's test the second hypothesis.

In [42]:
def get_pre_activations_conv1(dataloader):
    pre_activations = np.zeros((len(dataloader), model.conv1.out_channels))
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            z = F.relu(model.conv0(X)) # (batch_size, channels, width, height) -> (batch_size, channels, width, height)
            z = model.conv1(z)
            z = z.squeeze() # (batch_size, channels, width, height) -> (channels,)
            pre_activations[batch, :] = z
    return pre_activations

pos_pre_activations = get_pre_activations_conv1(pos_dataloader) # (dataset_size, num_features)
neg_pre_activations = get_pre_activations_conv1(neg_dataloader)

In [43]:
ranges_overlap = test_range_overlap(pos_pre_activations, neg_pre_activations)
ranges_overlap

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True])

In [44]:
mean_diff_normalized = get_mean_diff_normalized(pos_post_activations, neg_post_activations)
mean_diff_normalized

  mean_diff_normalized  = mean_diff / mean_diff_std


array([ 0.23683171, -1.14832957,  0.38838569, -0.17384463,  0.49229672,
        0.27404148,  0.21782071, -0.14285725, -0.82830463, -0.11904263,
        1.34079986, -0.59257538,  0.94553729, -0.13421474, -0.08376557,
       -0.68746021,  0.20199103,  0.6868629 , -0.39566968,  0.39681848,
       -0.21838428,  0.67088156, -0.16981416,  0.5653979 , -0.08868273,
       -0.49532028, -0.36933328,  0.0983119 ,  0.46026895,  0.1420756 ,
        0.94232762,  0.6800583 , -0.00465905,  0.03278738,  0.26288044,
       -0.23377481,  0.01588153,  1.20730288, -0.14076101,  0.49609867,
       -0.30384616, -0.76729917,  0.19509152, -0.37060805,  0.79466937,
        0.68615106,  0.7628384 ,  0.30309226, -0.39586215,  0.35455994,
        0.08984026,  0.39763355,         nan, -0.24108489,  0.04264589,
       -0.05963881,  0.00879571,  0.4085628 , -0.02514456, -0.43708171,
        0.5131398 ,  1.00241767,  0.02493692, -0.39736413,  0.16943566,
        0.56779377,  0.00224058, -0.33743136, -0.65096573,  0.21

In [45]:
print(f"Minimum z-value: {np.nanmin(mean_diff_normalized):.3f}")
print(f"Maximum z-value: {np.nanmax(mean_diff_normalized):.3f}")

Minimum z-value: -1.394
Maximum z-value: 1.341


Again, the ranges all overlap and the z-values are not extreme enough to conclude that any particular unit in the `conv1` layer reliably distinguishes between the positive and negative dataset. So, the `conv1` layer has also not explicitly learned the "middle unit is alive" concept.

# Conclusion

In conclusion, there was no simple way to associate any particular unit in the network with the concept of the middle cell being alive. Intuitively, it seems clear that this concept should be relevant to the final prediction, but it seems that the network has learned a a different way of the predicting a cell's next state which doesn't use this concept directly.

In later work, I could try and associate this concept to some linear combination of the activations of the `conv0` layer units. The idea would be to treat the 85 activations of the `conv0` units as independent variables and use linear regression to predict the current state of the middle cell. If this state can be reliably predicted, it means the concept exists in the network implicitly in some form. We could also measure how complex the encoding of this information is by looking at how many of the linear regression coefficients are significantly far from zero.