In [None]:
import copy

import ltn
import numpy as np
import torch
from dataset import DataLoader, get_mnist_dataset_for_digits_addition
from logic import Stable_AND
from models import LogitsToPredicate, SingleDigitClassifier
from train import train_logic
import seaborn as sns
import torch.functional as F

from scipy.stats import entropy

## Imports for plotting
import matplotlib.pyplot as plt
from matplotlib import cm
%matplotlib inline
import seaborn as sns
sns.set()

# Use the GPU whenever torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Load the MNIST digit-addition splits and wrap each in a loader with batch
# size 32; only the training loader shuffles.
train_set, test_set = get_mnist_dataset_for_digits_addition()
train_loader, test_loader = (
    DataLoader(train_set, 32, shuffle=True),
    DataLoader(test_set, 32, shuffle=False),
)

In [None]:
# Logical operators used by the LTN formulas: the project's Stable_AND
# conjunction plus relaxed p-mean quantifiers with p=2 (see paper for details).
And = ltn.Connective(Stable_AND())
Exists, Forall = (
    ltn.Quantifier(ltn.fuzzy_ops.AggregPMean(p=2), quantifier="e"),
    ltn.Quantifier(ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier="f"),
)

In [None]:
def plot_dists(val_dict, color="C0", xlabel=None, stat="frequency", use_kde=True):
    """Plot one histogram per entry of ``val_dict`` in a single row of subplots.

    Args:
        val_dict: Non-empty mapping from a panel title to a numpy array of
            values. For arrays with more than one dimension, the title is
            suffixed with ``(shape[1] -> shape[0])``.
        color: Matplotlib color used for every histogram.
        xlabel: Optional x-axis label applied to every panel.
        stat: Aggregate statistic forwarded to ``sns.histplot``.
        use_kde: Overlay a KDE curve, but only on panels whose values vary.

    Returns:
        The created matplotlib Figure.

    Raises:
        ValueError: if ``val_dict`` is empty (there is nothing to lay out).
    """
    if not val_dict:
        raise ValueError("plot_dists needs at least one entry to plot")
    columns = len(val_dict)
    # squeeze=False keeps `ax` 2-D even for a single panel; without it
    # plt.subplots(1, 1) returns a bare Axes and indexing it raises.
    fig, ax = plt.subplots(1, columns, figsize=(columns * 3, 2.5), squeeze=False)
    for key_ax, (key, values) in zip(ax[0], val_dict.items()):
        # A KDE is undefined for constant data, so only request it when values vary.
        kde = use_kde and ((values.max() - values.min()) > 1e-8)
        sns.histplot(values, ax=key_ax, color=color, bins=50, stat=stat, kde=kde)
        shape_note = (r"(%i $\to$ %i)" % (values.shape[1], values.shape[0])
                      if len(values.shape) > 1 else "")
        key_ax.set_title(f"{key} " + shape_note)
        if xlabel is not None:
            key_ax.set_xlabel(xlabel)
    fig.subplots_adjust(wspace=0.4)
    return fig


def visualize_weight_distribution(model, color="C0"):
    """Show a histogram of every non-bias, non-batch-norm parameter of ``model``.

    Each panel is keyed by the parameter name with its first two dotted
    components dropped and the remainder joined by spaces.
    """
    def _is_plain_weight(param_name):
        # Biases and batch-norm parameters are excluded from the plot.
        return not (param_name.endswith(".bias") or "batch_norm" in param_name or "bn" in param_name)

    flat_weights = {
        " ".join(param_name.split(".")[2:]): param.detach().view(-1).cpu().numpy()
        for param_name, param in model.named_parameters()
        if _is_plain_weight(param_name)
    }

    ## Plotting
    figure = plot_dists(flat_weights, color=color, xlabel="Weight vals")
    figure.suptitle("Weight distribution", fontsize=14, y=1.05)
    plt.show()
    plt.close()


def visualize_gradients(model, color="C0", print_variance=False):
    """
    Plot the distribution of loss gradients for the model's weight tensors.

    One batch is drawn from the global ``train_loader`` and moved to the global
    ``device``. The LTN digit-addition formula is then evaluated with the
    global ``And`` / ``Exists`` / ``Forall`` operators: for every (x, y, z)
    there exist digits d_1, d_2 with d_1 + d_2 == z such that
    model(x, d_1) AND model(y, d_2). ``1 - satisfaction`` is backpropagated.
    The model is switched to eval mode (and left there), and its gradients are
    zeroed again afterwards.

    Inputs:
        model - LTN predicate called as ``model(images, digits)``. Gradients of
                parameters whose name contains "weight" (batch-norm excluded)
                are shown, keyed by the parameter name without its first two
                dotted components.
        color - Color in which we want to visualize the histogram (for easier separation of activation functions)
        print_variance - If True, also print the variance of every gradient tensor.
    """
    model.eval()
    operand_images, sum_label, _ = next(iter(train_loader))
    operand_images, sum_label = operand_images.to(device), sum_label.to(device)
    # operand_images[:, 0] / [:, 1] are used as the two addends; ltn.diag
    # below pairs them element-wise with the sum label.
    images_x = ltn.Variable("x", operand_images[:, 0])
    images_y = ltn.Variable("y", operand_images[:, 1])
    labels_z = ltn.Variable("z", sum_label)
    d_1 = ltn.Variable("d_1", torch.tensor(range(10)))
    d_2 = ltn.Variable("d_2", torch.tensor(range(10)))

    sat_agg = Forall(
        ltn.diag(images_x, images_y, labels_z),
        Exists(
            vars=[d_1, d_2],
            formula=And(model(images_x, d_1), model(images_y, d_2)),
            cond_vars=[d_1, d_2, labels_z],
            # Only digit pairs whose sum equals the label are quantified over.
            cond_fn=lambda d1, d2, z: torch.eq(d1.value + d2.value, z.value),
        ),
    ).value

    model.zero_grad()

    loss = 1.0 - sat_agg
    loss.backward()
    grads = {}
    for name, params in model.named_parameters():
        if "weight" not in name or "batch_norm" in name or "bn" in name:
            continue
        # A parameter that did not take part in the loss has no gradient;
        # skip it instead of crashing on `None.view`.
        if params.grad is None:
            continue
        key_name = " ".join(name.split(".")[2:])
        grads[key_name] = params.grad.detach().reshape(-1).cpu().clone().numpy()
    model.zero_grad()

    ## Plotting
    fig = plot_dists(grads, color=color, xlabel="Grad magnitude")
    fig.suptitle("Gradient distribution", fontsize=14, y=1.05)
    plt.show()
    plt.close()

    if print_variance:
        for key, values in grads.items():
            print(f"{key} - Variance: {np.var(values)}")

def visualize_activations(model, color="C0", print_variance=False):
    """Plot the distribution of every layer's activations for one training batch.

    The forward pass is replayed manually, layer by layer, through
    ``model.model.logits_model`` and its ``mnistconv`` sub-network:
    - conv + relu (recorded), then maxpool;
    - the ``mnistconv`` linear + batch-norm + tanh layers;
    - all but the last ``logits_model`` linear + batch-norm + tanh layers;
    - the final linear layer (recorded as "logits").

    Runs under ``torch.no_grad()`` on one batch drawn from the global
    ``train_loader``. The model is switched to eval mode and left there.

    Args:
        model: LTN predicate exposing ``model.model.logits_model`` with the
            attributes used below.
        color: Histogram color forwarded to ``plot_dists``.
        print_variance: If True, also print the variance of each layer's activations.
    """
    model.eval()
    operand_images, _, _ = next(iter(train_loader))
    operand_images = operand_images.to(device)

    # Merge the first two dims so both operands of every pair go through the
    # classifier as one batch (presumably (batch, 2, C, H, W) -> (2*batch, C, H, W)
    # — TODO confirm against the dataset).
    x = operand_images.flatten(start_dim=0, end_dim=1)

    logits_model = model.model.logits_model
    conv_net = logits_model.mnistconv

    activations = {}
    with torch.no_grad():
        for conv_idx, conv in enumerate(conv_net.conv_layers):
            x = conv_net.relu(conv(x))
            activations[f"conv_layer_{conv_idx}"] = x.reshape(-1).cpu().numpy()
            x = conv_net.maxpool(x)

        x = torch.flatten(x, start_dim=1)
        # One running counter across both linear stacks, so the second stack
        # never overwrites keys written by the first one.
        linear_idx = 0
        for i in range(len(conv_net.linear_layers)):
            x = conv_net.tanh(conv_net.batch_norm_layers[i](conv_net.linear_layers[i](x)))
            activations[f"linear_layer_{linear_idx}"] = x.reshape(-1).cpu().numpy()
            linear_idx += 1

        # The last linear layer is applied separately (no batch norm / tanh).
        for i in range(len(logits_model.linear_layers) - 1):
            x = logits_model.tanh(logits_model.batch_norm_layers[i](logits_model.linear_layers[i](x)))
            activations[f"linear_layer_{linear_idx}"] = x.reshape(-1).cpu().numpy()
            linear_idx += 1

        x = logits_model.linear_layers[-1](x)
        activations["logits"] = x.reshape(-1).cpu().numpy()

    ## Plotting
    fig = plot_dists(activations, color=color, stat="density", xlabel="Activation vals")
    fig.suptitle("Activation distribution", fontsize=14, y=1.05)
    plt.show()
    plt.close()

    if print_variance:
        for key, values in activations.items():
            print(f"{key} - Variance: {np.var(values)}")

In [None]:
# Train one model per random seed for a single epoch and store its *initial*
# (pre-training) snapshot in the converging / non-converging group, depending
# on the final sum test accuracy. Stops once both groups are full.
# NOTE(review): loops forever if one outcome never occurs — consider a seed cap.
N_MODELS_PER_GROUP = 10
# Final `test_accuracy_sum` above which a run counts as converged.
CONVERGENCE_THRESHOLD = 0.5

converging_models = []
non_converging_models = []

# These do not depend on the seed, so set them once instead of every iteration.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

seed = 0
while len(converging_models) < N_MODELS_PER_GROUP or len(non_converging_models) < N_MODELS_PER_GROUP:
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.manual_seed(seed)
    np.random.seed(seed)

    # Rebuilt after seeding, presumably so data construction follows the seed.
    train_set, test_set = get_mnist_dataset_for_digits_addition()

    # create train and test loader
    train_loader = DataLoader(train_set, 32, shuffle=True)
    test_loader = DataLoader(test_set, 32, shuffle=False)

    And = ltn.Connective(Stable_AND())
    # we use relaxed aggregators: see paper for details
    Exists = ltn.Quantifier(ltn.fuzzy_ops.AggregPMean(p=2), quantifier="e")
    Forall = ltn.Quantifier(ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier="f")

    cnn_s_d = SingleDigitClassifier().to(device)
    Digit_s_d = ltn.Predicate(LogitsToPredicate(cnn_s_d)).to(device)

    # Snapshot the freshly initialised predicate before training mutates it.
    model_copy = copy.deepcopy(Digit_s_d)

    optimizer = torch.optim.Adam(Digit_s_d.parameters(), lr=0.001)
    metrics_prl, model = train_logic(
        Digit_s_d,
        optimizer,
        train_loader,
        test_loader,
        And,
        Exists,
        Forall,
        n_epochs=1,
        verbose=True,
    )

    final_accuracy = metrics_prl['test_accuracy_sum'][-1]
    if final_accuracy > CONVERGENCE_THRESHOLD:
        print(f"Model {seed} converged with test accuracy: {final_accuracy}")
        if len(converging_models) < N_MODELS_PER_GROUP:
            converging_models.append(model_copy)
    else:
        print(f"Model {seed} did not converge with test accuracy: {final_accuracy}")
        # Cap this group too, so both groups end up the same size.
        if len(non_converging_models) < N_MODELS_PER_GROUP:
            non_converging_models.append(model_copy)

    seed += 1

    



In [None]:
def plot_multiple_dists(val_dicts, color="C0", xlabel=None, stat="probability", use_kde=True):
    """Overlay the per-key distributions of several dicts, one panel per key.

    Args:
        val_dicts: Non-empty sequence of dicts. The keys of the first dict
            define the panels, and every other dict must contain those keys.
            Values are 1-D arrays.
        color: Color used for every curve, so one call draws one group in one
            color (the callers pass a different color per group).
        xlabel: Optional x-axis label applied to every panel.
        stat: Histogram statistic; only used when ``use_kde`` is False.
        use_kde: Draw KDE curves (default). Otherwise draw unfilled step
            histograms with ``stat``.

    Returns:
        The created matplotlib Figure.
    """
    keys = list(val_dicts[0].keys())
    columns = len(keys)
    # squeeze=False keeps `ax` 2-D even for a single panel.
    fig, ax = plt.subplots(1, columns, figsize=(columns * 3, 2.5), squeeze=False)
    for key_ax, key in zip(ax[0], keys):
        for val_dict in val_dicts:
            if use_kde:
                sns.kdeplot(val_dict[key], ax=key_ax, color=color)
            else:
                sns.histplot(val_dict[key], ax=key_ax, color=color, bins=50,
                             stat=stat, element="step", fill=False)
        # Without a title the panels cannot be matched to their layer.
        key_ax.set_title(str(key))
        if xlabel is not None:
            key_ax.set_xlabel(xlabel)
    fig.subplots_adjust(wspace=0.4)
    return fig

def visualize_multiple_weight_distributions(models, color="C0"):
    """Overlay the weight distributions of several models, one panel per layer.

    Biases and batch-norm parameters are skipped. Returns a list with one dict
    per model, mapping a layer key to its flattened weights as a numpy array.
    The layer key is the parameter name without its first two dotted
    components, joined by spaces.
    """
    def _is_plain_weight(param_name):
        return not (param_name.endswith(".bias") or "batch_norm" in param_name or "bn" in param_name)

    per_model_weights = [
        {
            " ".join(param_name.split(".")[2:]): param.detach().view(-1).cpu().numpy()
            for param_name, param in net.named_parameters()
            if _is_plain_weight(param_name)
        }
        for net in models
    ]

    ## Plotting
    figure = plot_multiple_dists(per_model_weights, color=color, xlabel="Weight vals")
    figure.suptitle("Weight distribution", fontsize=14, y=1.05)
    plt.show()
    plt.close()

    return per_model_weights

def visualize_multiple_gradients(models, color="C0", print_variance=False):
    """Overlay the loss-gradient distributions of several models, one panel per layer.

    A single batch is drawn from the global ``train_loader`` and reused for
    every model, so gradients are compared on identical inputs. For each
    model:
    - the LTN digit-addition formula is evaluated with the global
      ``And`` / ``Exists`` / ``Forall``;
    - ``1 - satisfaction`` is backpropagated;
    - gradients of parameters whose name contains "weight" (batch-norm
      excluded) are collected, then zeroed again.

    Every model is left in eval mode.

    Args:
        models: Iterable of LTN predicates called as ``model(images, digits)``.
        color: Color forwarded to ``plot_multiple_dists``.
        print_variance: If True, print the variance of every gradient tensor per model.

    Returns:
        A list with one dict per model mapping layer key -> flattened gradient array.
    """
    # Draw the batch once: re-drawing per model (the loader shuffles) would
    # confound model differences with batch differences.
    operand_images, sum_label, _ = next(iter(train_loader))
    operand_images, sum_label = operand_images.to(device), sum_label.to(device)

    items = []
    for model in models:
        model.eval()
        # Fresh LTN variables per model; only the underlying tensors are shared.
        images_x = ltn.Variable("x", operand_images[:, 0])
        images_y = ltn.Variable("y", operand_images[:, 1])
        labels_z = ltn.Variable("z", sum_label)
        d_1 = ltn.Variable("d_1", torch.tensor(range(10)))
        d_2 = ltn.Variable("d_2", torch.tensor(range(10)))

        sat_agg = Forall(
            ltn.diag(images_x, images_y, labels_z),
            Exists(
                vars=[d_1, d_2],
                formula=And(model(images_x, d_1), model(images_y, d_2)),
                cond_vars=[d_1, d_2, labels_z],
                cond_fn=lambda d1, d2, z: torch.eq(d1.value + d2.value, z.value),
            ),
        ).value

        model.zero_grad()

        loss = 1.0 - sat_agg
        loss.backward()
        grads = {}
        for name, params in model.named_parameters():
            if "weight" not in name or "batch_norm" in name or "bn" in name:
                continue
            # Parameters that did not take part in the loss have no gradient.
            if params.grad is None:
                continue
            key_name = " ".join(name.split(".")[2:])
            grads[key_name] = params.grad.detach().reshape(-1).cpu().clone().numpy()
        model.zero_grad()

        items.append(grads)

    ## Plotting
    fig = plot_multiple_dists(items, color=color, xlabel="Grad magnitude")
    fig.suptitle("Gradient distribution", fontsize=14, y=1.05)
    plt.show()
    plt.close()

    if print_variance:
        for model_idx, grads in enumerate(items):
            for key, values in grads.items():
                print(f"Model {model_idx} {key} - Variance: {np.var(values)}")

    return items

def visualize_multiple_activations(models, color="C0", print_variance=False):
    """Overlay per-layer activation distributions of several models.

    A single batch is drawn from the global ``train_loader`` and reused for
    every model, so activations are compared on identical inputs. For each
    model the forward pass is replayed manually under ``torch.no_grad()``
    through ``model.model.logits_model`` and its ``mnistconv`` sub-network:
    - conv + relu (recorded), then maxpool;
    - the ``mnistconv`` linear + batch-norm + tanh layers;
    - all but the last ``logits_model`` linear + batch-norm + tanh layers;
    - the final linear layer (recorded as "logits").

    Every model is left in eval mode.

    Args:
        models: Iterable of LTN predicates exposing ``model.model.logits_model``.
        color: Color forwarded to ``plot_multiple_dists``.
        print_variance: If True, print each layer's activation variance per model.

    Returns:
        A list with one dict per model mapping layer key -> flattened activations.
    """
    # Draw the batch once so every model sees the same inputs.
    operand_images, _, _ = next(iter(train_loader))
    operand_images = operand_images.to(device)
    # Merge the first two dims so both operands go through the classifier as
    # one batch (presumably (batch, 2, C, H, W) -> (2*batch, C, H, W) — confirm).
    batch_input = operand_images.flatten(start_dim=0, end_dim=1)

    items = []
    for model in models:
        model.eval()
        logits_model = model.model.logits_model
        conv_net = logits_model.mnistconv

        x = batch_input
        activations = {}
        with torch.no_grad():
            for conv_idx, conv in enumerate(conv_net.conv_layers):
                x = conv_net.relu(conv(x))
                activations[f"conv_layer_{conv_idx}"] = x.reshape(-1).cpu().numpy()
                x = conv_net.maxpool(x)

            x = torch.flatten(x, start_dim=1)
            # One running counter across both linear stacks so the second
            # stack never overwrites keys written by the first one.
            linear_idx = 0
            for i in range(len(conv_net.linear_layers)):
                x = conv_net.tanh(conv_net.batch_norm_layers[i](conv_net.linear_layers[i](x)))
                activations[f"linear_layer_{linear_idx}"] = x.reshape(-1).cpu().numpy()
                linear_idx += 1

            for i in range(len(logits_model.linear_layers) - 1):
                x = logits_model.tanh(logits_model.batch_norm_layers[i](logits_model.linear_layers[i](x)))
                activations[f"linear_layer_{linear_idx}"] = x.reshape(-1).cpu().numpy()
                linear_idx += 1

            x = logits_model.linear_layers[-1](x)
            activations["logits"] = x.reshape(-1).cpu().numpy()

        items.append(activations)

    ## Plotting
    fig = plot_multiple_dists(items, color=color, stat="probability", xlabel="Activation vals")
    fig.suptitle("Activation distribution", fontsize=14, y=1.05)
    plt.show()
    plt.close()

    if print_variance:
        for model_idx, activations in enumerate(items):
            for key, values in activations.items():
                print(f"Model {model_idx} {key} - Variance: {np.var(values)}")

    return items

In [None]:
# `samples` is a list with one dict per model (layer key -> array). Invert it
# into one dict per layer key whose value lists that key's entry from every
# model, in model order.
def flip_samples(samples):
    """Turn a list of per-model dicts into a dict of per-key lists."""
    by_layer = {}
    for model_dict in samples:
        for layer_key, value in model_dict.items():
            by_layer.setdefault(layer_key, []).append(value)
    return by_layer

In [None]:
# Function to compute the KL divergence matrix
def compute_kl_matrix(group_A, group_B, bins='auto', epsilon=1e-10):
    """Pairwise KL divergences ``D[i, j] = KL(hist(group_A[i]) || hist(group_B[j]))``.

    All samples are histogrammed on one shared set of bin edges computed from
    the union of both groups, so the distributions are comparable.

    Args:
        group_A, group_B: Sequences of sample arrays (lists, tuples, or the
            rows of a 2-D numpy array). Each sample is flattened.
        bins: Bin specification forwarded to ``np.histogram_bin_edges``.
        epsilon: Added to every bin count before normalising, so empty bins
            do not make the divergence infinite.

    Returns:
        np.ndarray of shape ``(len(group_A), len(group_B))``.
    """
    group_A = [np.asarray(a).ravel() for a in group_A]
    group_B = [np.asarray(b).ravel() for b in group_B]
    # Concatenate as Python lists: `group_A + group_B` on numpy arrays would
    # add elementwise instead and yield meaningless bin edges.
    bin_edges = np.histogram_bin_edges(np.concatenate(group_A + group_B), bins=bins)

    def _to_distribution(samples):
        counts, _ = np.histogram(samples, bins=bin_edges)
        counts = counts + epsilon
        return counts / np.sum(counts)

    # Compute every histogram once instead of once per pair.
    p_dists = [_to_distribution(a) for a in group_A]
    q_dists = [_to_distribution(b) for b in group_B]

    D = np.zeros((len(p_dists), len(q_dists)))
    for i, p in enumerate(p_dists):
        for j, q in enumerate(q_dists):
            D[i, j] = entropy(p, q)

    return D


In [None]:
# KL divergence for a single pair of sample arrays (one sample per group)
def compute_kl_matrix_single(group_A, group_B, bins='auto', epsilon=1e-10):
    """Return KL(hist(group_A) || hist(group_B)) as a single float.

    Despite its name, this returns a scalar, not a matrix. Both samples are
    histogrammed on shared bin edges computed from their union. ``epsilon`` is
    added to every bin count before normalising, so empty bins keep the
    divergence finite.
    """
    shared_edges = np.histogram_bin_edges(np.concatenate([group_A, group_B]), bins=bins)

    def smoothed_pmf(samples):
        counts = np.histogram(samples, bins=shared_edges)[0] + epsilon
        return counts / np.sum(counts)

    return entropy(smoothed_pmf(group_A), smoothed_pmf(group_B))

In [None]:
def plot_matrix(D, title="KL Divergence Matrix", xlabel="Group B", ylabel="Group A", vmin=0, vmax=9.5):
    """Show a (rows x cols) divergence matrix as a heatmap.

    Row ticks are labelled A1..An and column ticks B1..Bm. They are derived
    from ``D``'s shape (previously hard-coded to 10x10, which mislabelled any
    other shape).

    Args:
        D: 2-D array-like of divergences.
        title, xlabel, ylabel: Figure text.
        vmin, vmax: Fixed colour-scale limits, so several plots can be
            compared; the defaults keep the original 0..9.5 range.
    """
    D = np.asarray(D)
    n_rows, n_cols = D.shape
    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(D, aspect='auto', cmap='viridis', vmin=vmin, vmax=vmax)
    fig.colorbar(im, ax=ax, label='KL Divergence')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels([f'B{j+1}' for j in range(n_cols)])
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels([f'A{i+1}' for i in range(n_rows)])
    plt.show()
    plt.close(fig)

In [None]:
# Element-wise mean across models for every layer key.
# NOTE(review): if the models' weights are independent, averaging them
# element-wise shrinks their spread. Compare the resulting distributions with
# that in mind.
def average_samples(samples):
    """Map each key of ``samples`` to ``np.mean(samples[key], axis=0)``."""
    return {layer: np.mean(values, axis=0) for layer, values in samples.items()}

In [None]:
# Plot each group's per-model weight distributions, then regroup them by layer:
# layer key -> list of flattened weight arrays, one per model.
samples_1 = flip_samples(visualize_multiple_weight_distributions(converging_models, color="C0"))
samples_2 = flip_samples(visualize_multiple_weight_distributions(non_converging_models, color="C1"))

# Element-wise mean over the models of each group, per layer.
samples_1_avg, samples_2_avg = average_samples(samples_1), average_samples(samples_2)

In [None]:
# Average within-group KL divergence of the converging models, per layer, over
# all pairs of *distinct* models. Self-pairs are always 0 and would bias the
# average down. KL is asymmetric, so both directions are averaged separately.
average_kl_divergences = []
average_kl_divergences_inverse = []

for layer, layer_samples in samples_1.items():
    n_models = len(layer_samples)
    forward_kls = []
    backward_kls = []
    for a_idx in range(n_models):
        for b_idx in range(a_idx + 1, n_models):
            first, second = layer_samples[a_idx], layer_samples[b_idx]
            forward_kls.append(compute_kl_matrix_single(first, second))
            backward_kls.append(compute_kl_matrix_single(second, first))
    # With fewer than two models there is no pair to compare.
    average_kl_divergences.append(np.mean(forward_kls) if forward_kls else float("nan"))
    average_kl_divergences_inverse.append(np.mean(backward_kls) if backward_kls else float("nan"))

In [None]:
# Per layer: the converging group's within-group KL (both directions) next to
# the KL between the two groups' element-wise averaged weights (both directions).
for layer_idx, layer in enumerate(samples_1_avg):
    averaged_conv = samples_1_avg[layer]
    averaged_nonconv = samples_2_avg[layer]

    kl_divergence = compute_kl_matrix_single(averaged_conv, averaged_nonconv, bins='auto', epsilon=1e-10)
    kl_divergence_inverse = compute_kl_matrix_single(averaged_nonconv, averaged_conv, bins='auto', epsilon=1e-10)
    within = average_kl_divergences[layer_idx]
    within_inverse = average_kl_divergences_inverse[layer_idx]
    print(f"Average inner group KL divergence is: {within:.4f}, inverse is: {within_inverse:.4f}, KL divergence between converging_models and non-converging_models: {kl_divergence:.4f}, inverse is: {kl_divergence_inverse:.4f}")


In [None]:
# KL divergence between the averaged weights of the converging and
# non-converging groups, per layer, in both directions.
for layer, averaged_conv in samples_1_avg.items():
    averaged_nonconv = samples_2_avg[layer]
    forward_kl = compute_kl_matrix_single(averaged_conv, averaged_nonconv, bins='auto', epsilon=1e-10)
    backward_kl = compute_kl_matrix_single(averaged_nonconv, averaged_conv, bins='auto', epsilon=1e-10)
    print(f"KL Divergence for {layer}: {forward_kl:.4f}, Inverse: {backward_kl:.4f}")