In [39]:
import numpy as np
import torch
from matplotlib import pyplot as plt

from sympy import Symbol, Eq, Implies, Equivalent
import sympy
from tqdm import tqdm

from functools import reduce, partial
from operator import mul
from collections import defaultdict
from typing import Optional, Mapping


## Plots

In [None]:
def plot_labels(make_labels_fn, main_target: str = 'y_0'):
    """Plot the class labels produced by a labelling function on a grid over the unit square.

    A regular 40 x 40 grid of points in ``[0, 1]^2`` is passed to ``make_labels_fn``.
    Every entry of the returned mapping is drawn with ``pcolormesh`` in its own
    panel of a 2 x 2 figure, using a 3-colour map with colour limits ``(0, 2)``
    and a shared colorbar ticked at 0, 1 and 2.

    Args:
        make_labels_fn: Callable receiving the grid as a float32 array of shape
            ``(1600, 2)`` and returning a mapping from label name to an object
            whose ``reshape((40, 40))`` gives the label value at every grid point.
            The figure has four panels, so at most four entries can be drawn.
        main_target: Name of the main target; its panel is titled
            ``$y = <main_target>$``. Labels named ``y_1`` .. ``y_3`` get their own
            name as title, any other label is titled with its raw key.
    """
    shape = (40, 40)
    grid_axes = [np.linspace(0, 1, s) for s in shape]
    # Flatten the mesh into a list of points, one row per grid node.
    x_grid = np.stack(
        np.meshgrid(*grid_axes, indexing='ij'),
        axis=-1
    ).reshape((reduce(mul, shape, 1), len(shape))).astype(np.float32)

    all_preds = make_labels_fn(x_grid)

    # notation used in paper would be '$c^{(t)}$' instead of '$y_t$'
    titles = {f'y_{t}': f'$y_{t}$' for t in range(1, 4)}
    # Assigned last so that the main target keeps its dedicated title even when
    # it is one of y_1 .. y_3.
    titles[main_target] = f'$y = {main_target}$'

    cmap = plt.colormaps['rainbow'].resampled(3)
    fig, ax = plt.subplots(2, 2, figsize=(6, 6), layout='compressed')
    for i, k in enumerate(all_preds.keys()):
        a = ax.ravel()[i]
        im = a.pcolormesh(
            x_grid[:, 0].reshape(shape),
            x_grid[:, 1].reshape(shape),
            all_preds[k].reshape(shape),
            cmap=cmap,
            # Colour limits span the three classes 0, 1, 2.
            vmin=0, vmax=2,
        )
        a.set_xticks([0, 0.25, 0.5, 0.75, 1.0])
        a.set_yticks([0, 0.25, 0.5, 0.75, 1.0])
        a.set_aspect(1)
        a.set_title(titles.get(k, k))

    bar = fig.colorbar(im, ax=ax.ravel().tolist(), shrink=0.5, ticks=np.array([0, 1, 2]))
    bar.ax.set_ylabel('class', rotation=90)

In [None]:
def plot_predictions(model, device: str = 'cpu', input_dtype=torch.float, shape: tuple = (40, 40)):
    """Plot model outputs on a regular grid over the unit square.

    The model is evaluated, with gradient tracking disabled, on a
    ``shape[0] x shape[1]`` grid of points in ``[0, 1]^2``. For every entry of
    the returned mapping, column 1 of the output (clipped to ``[0, 1]``) is
    drawn as a filled contour plot in its own panel of a 2 x 2 figure, with a
    shared colorbar.

    Args:
        model: Callable receiving the grid as a tensor of shape
            ``(shape[0] * shape[1], 2)`` and returning a mapping from output name
            to a tensor with at least two columns. The figure has four panels,
            so at most four outputs can be drawn.
        device: Device the grid tensor is moved to before calling the model.
        input_dtype: Dtype the grid tensor is cast to before calling the model.
        shape: Number of grid points along each of the two input dimensions.
    """
    grid_axes = [np.linspace(0, 1, s) for s in shape]
    # Flatten the mesh into a list of points, one row per grid node.
    x_grid = np.stack(
        np.meshgrid(*grid_axes, indexing='ij'),
        axis=-1
    ).reshape((reduce(mul, shape, 1), len(shape))).astype(np.float32)

    with torch.no_grad():
        all_preds = {
            k: v.cpu()
            for k, v in model(torch.tensor(x_grid).to(input_dtype).to(device)).items()
        }

    # Raw strings: '\P' is not a valid Python escape sequence.
    titles = {
        'y_0': r'$\Pr(y = 2)$',
    }
    titles.update({f'y_{t}': r'$\Pr(c^{(' + str(t) + ')} = 2)$' for t in range(1, 4)})

    fig, ax = plt.subplots(2, 2, figsize=(6, 6), layout='compressed')
    for i, k in enumerate(all_preds.keys()):
        a = ax.ravel()[i]
        im = a.contourf(
            x_grid[:, 0].reshape(shape),
            x_grid[:, 1].reshape(shape),
            np.clip(all_preds[k].numpy()[:, 1].reshape(shape), 0.0, 1.0),
            cmap='plasma',
            vmin=0.0, vmax=1.0,
            levels=np.linspace(0.0, 1.0, 11)
        )
        im.set_clim((0.0, 1.0))
        a.set_xticks([0, 0.25, 0.5, 0.75, 1.0])
        a.set_yticks([0, 0.25, 0.5, 0.75, 1.0])
        a.set_aspect(1)
        # Fall back to the raw key for outputs without a predefined title.
        a.set_title(titles.get(k, k))

    fig.colorbar(im, ax=ax.ravel().tolist(), extend='neither', shrink=0.5)


## Training

In [None]:
class NamedTensorDataset(torch.utils.data.Dataset):
    """Dataset over a mapping of named tensors that all have the same length.

    Indexing returns a dictionary with the same keys as the mapping, each
    holding the ``i``-th entry of the corresponding tensor.
    """

    def __init__(self, named_tensors: Mapping[str, torch.Tensor]):
        """Keep a reference to ``named_tensors`` (the mapping is not copied)."""
        self.named_tensors = named_tensors

    def __getitem__(self, i) -> Mapping[str, torch.Tensor]:
        """Return ``{name: tensor[i]}`` for every tensor in the mapping."""
        return {
            name: tensor[i]
            for name, tensor in self.named_tensors.items()
        }

    def __len__(self) -> int:
        """Return the common length of the tensors, or 0 for an empty mapping.

        Raises:
            AssertionError: If the tensors do not all have the same length.
        """
        lengths = [len(v) for v in self.named_tensors.values()]
        if not lengths:
            # No tensors means no samples; avoids indexing an empty list below.
            return 0
        assert all(n == lengths[0] for n in lengths), \
            f"All tensors must have the same length! Got: {lengths}"
        return lengths[0]

In [None]:
def reset_model_parameters(nn, seed: Optional[int] = None):
    """Re-initialise the parameters of every ``torch.nn.Linear`` layer in a model.

    Only ``torch.nn.Linear`` modules are touched; modules of any other type
    keep their current parameters.

    Args:
        nn: Model whose ``apply`` method is used to visit its modules.
        seed: If given, the global torch RNG is seeded with it before the
            layers are re-initialised.
    """
    if seed is not None:
        torch.manual_seed(seed)

    def _reinit_if_linear(module):
        if not isinstance(module, torch.nn.Linear):
            return
        module.reset_parameters()

    nn.apply(_reinit_if_linear)


In [None]:
def prepare_train_loader(dataset, device: str, batch_size: int):
    """Build a shuffling, single-process ``DataLoader`` for training.

    Args:
        dataset: Dataset to iterate over.
        device: Target device; for anything other than ``'cpu'`` the loader is
            created with ``pin_memory=True`` and ``pin_memory_device=device``.
        batch_size: Number of samples per batch.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    if device == 'cpu':
        pinning_kwargs = {}
    else:
        pinning_kwargs = {'pin_memory': True, 'pin_memory_device': device}

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        **pinning_kwargs
    )


def train_on_dataset(model, dataset,
                     n_epochs: int = 100,
                     device: str = 'cpu',
                     optim_factory = None,
                     seed: int = 12345,
                     batch_size: int = 16,
                     input_key: str = 'X',
                     input_dtype=torch.float,
                     epoch_callback=None):
    """Train the model on the dataset.

    The model is first re-initialised with ``reset_model_parameters(model, seed=seed)``
    and moved to ``device``. For every batch, the loss is the sum of one
    ``NLLLoss`` term per model output whose name is also a target key in the
    batch. Model outputs are treated as probabilities: they are clamped from
    below at 1e-12 and passed through ``log`` before the loss is computed.

    Args:
        model: Model. Called with the input batch, it must return a mapping
            from output name to predicted probabilities.
        dataset: Train dataset, consisting of dictionaries (key -> data tensor).
        n_epochs: Number of passes over the dataset.
        device: Device, 'cpu' or 'cuda:n', e.g. 'cuda:0'.
        optim_factory: Optimizer factory, getting model parameters and returning an optimizer.

                For example, `lambda params: torch.optim.SGD(params, lr=1.e-2)`.
                Defaults to `torch.optim.AdamW` with `lr=1.e-4`.

        seed: Random seed.
        batch_size: Batch size.
        input_key: Name of the input tensor in the dataset. By default is 'X'.
        input_dtype: Dtype the input tensor is cast to before it is passed to the model.
        epoch_callback: Optional callable invoked as ``epoch_callback(model)``
            after every epoch.

    Returns:
        A ``defaultdict(list)`` whose ``'loss_train'`` entry holds the mean
        batch loss of every epoch. The model is left in eval mode.

    Raises:
        ValueError: If none of the model outputs has a matching target in a batch.
    """
    EPS = 1.e-12
    reset_model_parameters(model, seed=seed)
    model.to(device)

    if optim_factory is None:
        optim_factory = partial(torch.optim.AdamW, lr=1.e-4)
    optim = optim_factory(model.parameters())

    loader = prepare_train_loader(dataset, device, batch_size=batch_size)

    clf_loss_fn = torch.nn.NLLLoss()  # predictions are probabilities, not logits (!)
    history = defaultdict(list)

    for epoch in tqdm(range(n_epochs)):
        # Re-enter train mode every epoch: the callback may have switched the
        # model to eval mode (e.g. to evaluate or plot predictions).
        model.train()
        batch_losses = []
        for data in loader:
            targets = {
                k: v.to(device)
                for k, v in data.items()
                if k != input_key
            }
            batch_x = data[input_key].to(input_dtype).to(device)
            concept_probas = model(batch_x)

            loss = None
            for concept_name, preds in concept_probas.items():
                # Look the name up in `targets` (not `data`) so that an output
                # named like the input key is never treated as a target.
                if concept_name not in targets:
                    continue
                term = clf_loss_fn(
                    torch.log(torch.clamp_min(preds, EPS)),
                    targets[concept_name]
                )
                loss = term if loss is None else loss + term
            if loss is None:
                raise ValueError(
                    "None of the model outputs has a matching target in the batch: "
                    f"outputs={list(concept_probas.keys())}, targets={list(targets.keys())}"
                )
            batch_losses.append(loss.item())

            optim.zero_grad()
            loss.backward()
            optim.step()
        history['loss_train'].append(np.mean(batch_losses))
        if epoch_callback is not None:
            epoch_callback(model)
    model.eval()

    return history
