# Two-Moons CALAS

In this notebook, we will make the first ever attempt to train a <u>**CALAS**</u> model.

In a CALAS model, we split the forward pass into two steps.

In the first step, we do this:

* Forward our nominal data through the chosen representation and embed it in a suitable space.
* Then, we estimate the entropy of the embedded data under the *current* model.

In the second step:

* Permute the nominal embedded data in such a way that its entropy becomes *deliberately* too small or too large under the current model.
* Forward both the nominal embedded data and the permuted data and compute a loss under a contrastive or conditional model (where the nominal data is assigned a different class than the modified data).

We have a `CalasFlow`, which is a conditional Normalizing Flow.
While we cannot directly modify the entropy of data, we can modify its likelihood.
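For a normalizing flow with a (quasi-)standard normal base distribution, entropy and likelihood are inversely related: the entropy is the expected negative log-likelihood, so samples of lower likelihood correspond to higher entropy, and vice versa.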
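To illustrate the entropy/likelihood relation, here is a minimal plain-Python sketch (not part of `calas`). It estimates the entropy of a $d$-dimensional standard normal as the mean negative log-likelihood of samples drawn from it, and compares the estimate to the closed form $\frac{d}{2}\log(2\pi e)$. Pulling the samples toward the mode raises their likelihood and therefore lowers their mean negative log-likelihood under the model.

```python
import math
import random

def std_normal_log_lik(x: list[float]) -> float:
    """Log-density of a standard normal in len(x) dimensions."""
    d = len(x)
    return -0.5 * d * math.log(2.0 * math.pi) - 0.5 * sum(v * v for v in x)

def entropy_estimate(samples: list[list[float]]) -> float:
    """Monte Carlo estimate H ≈ -mean(log p(x)) for samples x ~ p."""
    return -sum(std_normal_log_lik(x) for x in samples) / len(samples)

rng = random.Random(1)
d = 2
samples = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(100_000)]

h_mc = entropy_estimate(samples)
h_true = 0.5 * d * math.log(2.0 * math.pi * math.e)  # closed form, ≈ 2.8379 for d=2

# Pulling every sample halfway toward the mode increases its likelihood, so the
# mean negative log-likelihood under the (unchanged) model goes down.
shrunk = [[0.5 * v for v in x] for x in samples]
h_shrunk = entropy_estimate(shrunk)
```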

___________



In [1]:
SEED = 1

import __init__
import torch
from torch import Tensor
from torch.nn import DataParallel

from calas.models.flow import CalasFlowWithRepr
from calas.models.flow_test import make_flows
from calas.models.repr import AE_UNet_Repr

torch.manual_seed(SEED)
dev = torch.device('cuda:0')

repr = AE_UNet_Repr(input_dim=2, hidden_sizes=(96,64,96)).to(dev)
flow = CalasFlowWithRepr(num_classes=2, repr=repr, flows=make_flows(K=6, dim=repr.embed_dim, units=128, layers=2)).to(dev)

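# NOTE: DataParallel only parallelizes `forward()`; custom methods such as
# `X_to_E`, `log_rel_lik_X`, or `loss_wrt_E` (used below) are not reachable
# through the wrapper. Hence, `flow` stays unwrapped and the wrapper, if any,
# is kept in `flow_dp`.
data_parallel_devs = list(f'cuda:{idx}' for idx in range(torch.cuda.device_count()))
flow_dp = flow
if len(data_parallel_devs) > 1:
    flow_dp = DataParallel(module=flow, device_ids=data_parallel_devs, output_device=dev, dim=0)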

In [2]:
from calas.tools.two_moons import two_moons_rejection_sampling
from calas.data.dataset import ListDataset

num_train = 20_000
# Training data: samples from the two (pure, tight) moons, all labeled class 0 (nominal).
train = ListDataset(items=zip(
    torch.split(tensor=two_moons_rejection_sampling(nsamples=num_train, pure_moons=True, pure_moons_mode='tight', seed=SEED), split_size_or_sections=1, dim=0),
    torch.split(tensor=torch.zeros(num_train), split_size_or_sections=1, dim=0)))

# Validation data: samples from the *complement* of the moons, all labeled class 1.
num_valid = 20_000
valid = ListDataset(items=zip(
    torch.split(tensor=two_moons_rejection_sampling(nsamples=num_valid, complementary=True, pure_moons=True, pure_moons_mode='tight', seed=SEED), split_size_or_sections=1, dim=0),
    torch.split(tensor=torch.ones(num_valid), split_size_or_sections=1, dim=0)))

We use the following function to summarize the distribution of the training data's log-likelihoods under the current model (minimum, maximum, and standard deviation).

In [3]:
from typing import Optional

def estimate_train_likelihood(num_samples: Optional[int]=None) -> tuple[float, float, float]:
    """
    Returns the (min, max, std) of the log relative likelihoods of the first
    `num_samples` training samples (all of them if `None`) under the current flow.
    """
    num_samples = train.size if num_samples is None else num_samples

    result: list[Tensor] = []
    num_seen = 0
    with torch.no_grad():
        flow.eval()
        for batch in train.iter_batched(batch_size=1_000):
            x, clz = torch.cat(list(t[0] for t in batch)), torch.cat(list(t[1] for t in batch))
            likelihood = flow.log_rel_lik_X(input=x.to(dev), classes=clz.to(dev))
            result.append(likelihood[:num_samples - num_seen])
            num_seen += likelihood.shape[0]
            if num_seen >= num_samples:
                break

    temp = torch.cat(tensors=result)
    lik_min, lik_max, lik_std = temp.min().item(), temp.max().item(), temp.std().item()
    torch.cuda.empty_cache()
    return lik_min, lik_max, lik_std
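`estimate_train_likelihood` keeps all per-sample log-likelihoods and concatenates them before computing the statistics. If that ever becomes a memory concern, the same three numbers could be accumulated in a single pass with Welford's algorithm. Below is a minimal plain-Python sketch; the class name `RunningStats` is hypothetical. Like `Tensor.std()` with its default settings, it returns the sample (Bessel-corrected) standard deviation.

```python
import math

class RunningStats:
    """Single-pass min/max/std accumulator (Welford's algorithm)."""

    def __init__(self) -> None:
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the running mean
        self.min = math.inf
        self.max = -math.inf

    def update(self, values) -> None:
        for v in values:
            self.n += 1
            delta = v - self.mean
            self.mean += delta / self.n
            self.m2 += delta * (v - self.mean)
            self.min = min(self.min, v)
            self.max = max(self.max, v)

    def result(self) -> tuple[float, float, float]:
        # Sample standard deviation (divides by n - 1).
        std = math.sqrt(self.m2 / (self.n - 1)) if self.n > 1 else math.nan
        return self.min, self.max, std

stats = RunningStats()
for batch in ([1.0, 2.0, 3.0], [4.0, 5.0]):  # e.g., per-batch log-likelihoods
    stats.update(batch)
lik_min, lik_max, lik_std = stats.result()
```

In the notebook, one would call `stats.update(likelihood.tolist())` per batch instead of appending to `result`.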

In [None]:
estimate_train_likelihood()

# Controlled And Linear Anomaly Synthesis

Now that we have defined our model and the **in-distribution** data, we will need to define how exactly we are going to synthesize anomalies.

In this notebook, we treat the problem as unsupervised anomaly detection: points that fall within either of the two moons are considered in-distribution.
However, we do not have any explicit out-of-distribution data.
In CALAS, the idea is to synthesize <u>***near in-distribution outliers***</u>.
During training, these are used to <u>***concretize the manifold***</u> of the in-distribution data.

For each batch of nominal data, we synthesize another batch of the same size that is derived from the nominal data (i.e., not just random noise).
The synthesized data should be very close to the nominal data.

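For a batch of $N$ nominal observations, we will produce $N/2$ observations whose likelihood is smaller than the minimum observed likelihood in the nominal batch, and $N/2$ observations whose likelihood is larger than the maximum observed likelihood.
(The training step further below currently uses only the lower-likelihood synthesis, with its targets derived from the likelihoods of the whole training set.)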
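To make *controlled* synthesis concrete, here is a toy plain-Python sketch. All names in it are hypothetical, and it is not how `calas.data.synthesis` is implemented. For a standard normal model, the gradient of $\log p(x)$ is $-x$, so stepping *against* the gradient moves a sample outward and lowers its likelihood, while stepping *along* it raises the likelihood. The sample is modified for at most `max_steps` steps and is only accepted if it crosses the target likelihood. This loosely mirrors the `max_steps` argument and the acceptance mask used in the training step below.

```python
import math

def std_normal_log_lik(x: list[float]) -> float:
    """Log-density of a standard normal in len(x) dimensions."""
    d = len(x)
    return -0.5 * d * math.log(2.0 * math.pi) - 0.5 * sum(v * v for v in x)

def push_likelihood(x: list[float], target: float, decrease: bool,
                    step: float = 0.5, max_steps: int = 3) -> tuple[list[float], bool]:
    """Moves x until log p(x) is below (decrease=True) or above (decrease=False)
    `target`. Returns the modified sample and whether it was accepted."""
    def reached(z: list[float]) -> bool:
        ll = std_normal_log_lik(z)
        return ll < target if decrease else ll > target

    for _ in range(max_steps):
        if reached(x):
            return x, True
        # grad log p(x) = -x: descending it scales x up, ascending it scales x down
        factor = 1.0 + step if decrease else 1.0 - step
        x = [factor * v for v in x]
    return x, reached(x)

lower, ok_lower = push_likelihood([1.0, 0.0], target=-3.0, decrease=True)
higher, ok_higher = push_likelihood([1.0, 0.0], target=-2.0, decrease=False)
```

With `max_steps=1`, the first call would not reach its target and would return `accepted=False`. That is the situation in which a synthesized sample gets masked out.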

______________


In [None]:
from calas.data.synthesis import Synthesis
from calas.data.permutation import Space, Likelihood, Data2Data_Grad, CurseOfDimDataPermute, Normal2Normal_NoGrad, Normal2Normal_Grad

# We define two different Synthesis strategies: one to create samples of
# lower likelihood, and one to create samples of higher likelihood. In both
# cases, we modify the data in the embedded space E.

space = Space.Embedded

synthesis_lower = Synthesis(flow=flow, space_in=space, space_out=space)
synthesis_lower.add_permutations(
    Data2Data_Grad(flow=flow, space=space, seed=SEED),
    CurseOfDimDataPermute(flow=flow, space=space, seed=SEED, use_grad_dir=False, num_dims=(8,8)),
    Normal2Normal_NoGrad(flow=flow, method='quantiles', use_loc_scale_base_grad=True, u_min=0.01, u_max=0.01, seed=SEED)
)


synthesis_higher = Synthesis(flow=flow, space_in=space, space_out=space)
synthesis_higher.add_permutations(
    Normal2Normal_Grad(flow=flow, method='quantiles', u_min=0.01, seed=SEED),
    CurseOfDimDataPermute(flow=flow, space=space, u_min=0.001, u_max=0.01, u_frac_negative=0.5, seed=SEED),
    Data2Data_Grad(flow=flow, space=space, seed=SEED)
)

# Discriminator

____


In [None]:
from torch import nn


def smooth_maximum(a: Tensor, max: float|Tensor, alpha: float=0.0001) -> Tensor:
    # Smooth approximation of max(a, max); the larger alpha, the smoother (and less exact).
    # For min, just replace the plus before the sqrt with a minus!
    return 0.5 * ((a + max) + torch.sqrt((a - max)**2.0 + alpha))

def smooth_minimum(a: Tensor, min: float|Tensor, alpha: float=0.0001) -> Tensor:
    return 0.5 * ((a + min) - torch.sqrt((a - min)**2.0 + alpha))

def smooth_01(a: Tensor, alpha: float=0.0001) -> Tensor:
    return smooth_maximum(a=smooth_minimum(a=a, min=1.0, alpha=alpha), max=0.0, alpha=alpha)
    # return 0.5 + 0.5 * torch.tanh(a) # Smooth Heaviside

class SM01(nn.Module):
    def __init__(self, alpha: float, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha

    def forward(self, x: Tensor) -> Tensor:
        return smooth_01(a=x, alpha=self.alpha)


class SimpleDiscr(nn.Module):
    def __init__(self, num_in: int, num_hidden: int, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.linear_1 = nn.Linear(in_features=num_in, out_features=num_hidden, bias=True)

        self.model = nn.Sequential(
            # nn.SiLU(),
            SM01(alpha=0.0001),
            
            nn.Linear(in_features=num_hidden, out_features=num_hidden, bias=True),
            SM01(alpha=0.0001),

            nn.Linear(in_features=num_hidden, out_features=2, bias=True),
            SM01(alpha=0.0001)
        )
    
    def forward(self, x: Tensor) -> Tensor:
        return self.forward_embedding(self.linear_1(torch.atleast_2d(x)))
    
    def forward_embedding(self, x: Tensor) -> Tensor:
        """Here, we assume the input has already been passed through the linear_1 layer."""
        return self.model(torch.atleast_2d(x))


model = SimpleDiscr(num_in=repr.embed_dim, num_hidden=256).to(dev)

print(model)
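The smooth max/min functions above replace the kinks of $\max(a, m)$ and $\min(a, m)$ with $\tfrac12\big((a+m) \pm \sqrt{(a-m)^2 + \alpha}\big)$, so `smooth_01` is a differentiable approximation of clamping to $[0, 1]$. The following scalar re-implementation using `math` is for illustration only, and checks this numerically. One consequence is worth keeping in mind: because the last `SM01` keeps both discriminator outputs in roughly $[0, 1]$, a subsequent two-class softmax can assign at most about $e/(1+e) \approx 0.731$ to either class.

```python
import math

def smooth_maximum(a: float, m: float, alpha: float = 1e-4) -> float:
    # ≈ max(a, m); overshoots by sqrt(alpha)/2 exactly at a == m
    return 0.5 * ((a + m) + math.sqrt((a - m) ** 2 + alpha))

def smooth_minimum(a: float, m: float, alpha: float = 1e-4) -> float:
    # ≈ min(a, m); undershoots by sqrt(alpha)/2 exactly at a == m
    return 0.5 * ((a + m) - math.sqrt((a - m) ** 2 + alpha))

def smooth_01(a: float, alpha: float = 1e-4) -> float:
    # ≈ clamp(a, 0, 1), differentiable everywhere
    return smooth_maximum(smooth_minimum(a, 1.0, alpha), 0.0, alpha)

def softmax2(z0: float, z1: float) -> tuple[float, float]:
    e0, e1 = math.exp(z0), math.exp(z1)
    return e0 / (e0 + e1), e1 / (e0 + e1)

# Largest probability a two-class softmax can produce from outputs in [0, 1]:
p_max = softmax2(smooth_01(10.0), smooth_01(-10.0))[0]
```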

# Split-Step Training

We will perform "split-step" training.
Instead of just forwarding the data and computing the loss, we perform the two steps described at the beginning of this notebook manually.

*NOTE*: The representation will also learn to reconstruct the counter-examples, **unless** we clone/detach the nominal sample prior to modifying it.
The boolean flag `REPR_LEARN_COUNTER` in the next cell controls this behavior.
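After synthesis, `split_step_forward` does some bookkeeping. It keeps only the accepted counter-examples, truncates the nominal batch to the same size, and labels nominal samples 0 and counter-examples 1. Below is a plain-Python sketch of that step. The helper name `assemble_batch` is hypothetical, and lists stand in for tensors.

```python
def assemble_batch(nominal: list, modified: list, accepted: list[bool]) -> tuple[list, list[int]]:
    """Returns a balanced training batch and its class labels
    (0 = nominal, 1 = synthesized counter-example)."""
    kept = [m for m, ok in zip(modified, accepted) if ok]
    n = len(kept)
    if n == 0:
        # the notebook returns a NaN loss in this case
        return [], []
    batch = nominal[:n] + kept
    labels = [0] * n + [1] * n
    return batch, labels

batch, labels = assemble_batch(
    nominal=['a', 'b', 'c', 'd'],
    modified=["a'", "b'", "c'", "d'"],
    accepted=[True, False, True, False])
```

As in the notebook, the *first* `n` nominal samples are kept. These are not necessarily the ones whose counter-examples were accepted.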

_____


In [7]:
REPR_LEARN_COUNTER = True

from warnings import warn
from torch import Tensor, cuda, device
from torch.nn.functional import kl_div, softmax, one_hot


assert cuda.is_available(), "Don't do this on CPU..."
dev = device('cuda:0')

def split_step_forward(nominal_batch: Tensor, nominal_classes: Tensor, epoch: int, nominal_only: bool=False, accept_all: bool=True) -> Tensor:
    """
    Takes as input a batch of nominal data, produces the counter examples,
    forwards both with appropriate conditions, and returns the computed loss
    as a tensor, on which we can call `backward()`.
    """
    flow.eval()
    assert not flow.training and not repr.training
    nominal_batch = nominal_batch.to(device=dev)
    nominal_classes = torch.atleast_1d(nominal_classes.squeeze().to(device=dev, dtype=torch.int64))
    num_nominal = nominal_batch.shape[0]


    # First: Forward the data through the representation!
    nominal_E = flow.X_to_E(input=nominal_batch)
    

    use_train, use_train_clz = nominal_E, nominal_classes
    discr_loss = 0.0

    if not nominal_only:
        lik_min, lik_max, lik_sd = estimate_train_likelihood()
        target_lik_min = lik_min
        target_lik_min_crit = lik_min - 0.5 * lik_sd
        
        modified_lower_E, modified_lower_E_mask = synthesis_lower.rsynthesize(
            likelihood=Likelihood.Decrease,
            sample=nominal_E.clone().detach(), classes=nominal_classes.clone().detach(), target_lik=target_lik_min, target_lik_crit=target_lik_min_crit, max_steps=3, accept_all=accept_all)
        if accept_all:
            modified_lower_E_mask = torch.ones_like(modified_lower_E_mask, dtype=torch.bool)
        
        num_lower = modified_lower_E_mask.sum()
        if num_lower == 0:
            return torch.full((1,), torch.nan)
        
        nominal_E = nominal_E[0:num_lower]
        nominal_classes = nominal_classes[0:num_lower]
        modified_lower_E = modified_lower_E[modified_lower_E_mask]
        
        use_train: Tensor = None
        use_train_clz: Tensor = None
        if REPR_LEARN_COUNTER:
            use_train = torch.vstack(tensors=(
                nominal_E,
                nominal_E.clone().copy_(modified_lower_E)))
            use_train_clz = torch.cat(tensors=(
                nominal_classes,
                torch.ones_like(nominal_classes)))
        else:
            use_train = torch.vstack(tensors=(
                nominal_E, modified_lower_E))
            use_train_clz = torch.cat(tensors=(
                nominal_classes, torch.ones_like(nominal_classes).detach()))
        
        use_train_clz_kl = one_hot(use_train_clz, 2).to(dtype=use_train.dtype)
        # kl_div expects log-probabilities as input; log_softmax is the numerically
        # stable equivalent of softmax(...).log().
        discr_log_pred = torch.log_softmax(model.forward_embedding(x=use_train), dim=1)
        discr_loss = repr.embed_dim * kl_div(input=discr_log_pred, target=use_train_clz_kl, reduction='batchmean', log_target=False)
    

    # Finally: compute the flow's loss on the (possibly augmented) batch of embeddings.
    flow.train()
    loss = flow.loss_wrt_E(embeddings=use_train, classes=use_train_clz)
    return loss + discr_loss
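With one-hot targets, the KL divergence used for `discr_loss` reduces to the ordinary cross-entropy, because the target's own entropy is zero: $\mathrm{KL}(y \,\|\, p) = \sum_k y_k \log(y_k / p_k) = -\log p_{c}$, where $c$ is the true class and $0 \log 0 := 0$. With `reduction='batchmean'`, this is averaged over the batch. The following plain-Python sketch checks that identity; it is not the PyTorch implementation.

```python
import math

def kl_batchmean(targets: list[list[float]], probs: list[list[float]]) -> float:
    """sum_k t_k * (log t_k - log p_k), summed over classes, averaged over the batch.
    Terms with t_k == 0 contribute 0."""
    total = 0.0
    for t_row, p_row in zip(targets, probs):
        total += sum(t * (math.log(t) - math.log(p)) for t, p in zip(t_row, p_row) if t > 0)
    return total / len(targets)

def cross_entropy(labels: list[int], probs: list[list[float]]) -> float:
    return -sum(math.log(p_row[c]) for c, p_row in zip(labels, probs)) / len(labels)

pred = [[0.7, 0.3], [0.4, 0.6], [0.9, 0.1]]
labels = [0, 1, 0]
targets = [[1.0 if k == c else 0.0 for k in range(2)] for c in labels]
```

In PyTorch, `kl_div` takes log-probabilities as `input` and, with `log_target=False`, plain probabilities as `target`.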

# Training

In [None]:
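STEPS = 200
BATCH_SIZE = 128
NUM_NOMINAL = 3  # the first NUM_NOMINAL steps use nominal data only (no synthesis)

flow.to(device=dev)
# Optimize the flow (incl. its representation) *and* the discriminator; otherwise,
# the discriminator's weights would never be updated.
optim = torch.optim.Adam(params=list(flow.parameters(recurse=True)) + list(model.parameters(recurse=True)), lr=5e-4)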


step = 0
while step < STEPS:
    optim.zero_grad()

    batch = train.shuffle(seed=step).take(num=BATCH_SIZE)
    loss = split_step_forward(
        epoch=step,
        nominal_only=step < NUM_NOMINAL,
        nominal_batch=torch.cat(tensors=list(t[0] for t in batch)).to(dev),
        nominal_classes=torch.cat(tensors=list(t[1] for t in batch)).to(dev))

    # Steps with a non-finite loss (e.g., NaN because no counter-examples could
    # be synthesized) are not counted; the step is retried instead.
    if torch.isfinite(loss):
        loss.backward()
        optim.step()
        print(f'Step {step}, loss: {loss.item()}')
        step += 1


The training appears relatively stable, except for some hiccups at the beginning, where for a few steps it was not possible to synthesize sufficiently many counter-examples.
Steps where the loss was NaN were not counted toward `STEPS`: no backward pass was done, and the step was repeated.
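Because a step with a non-finite loss is retried rather than counted, the training loop could, in principle, keep retrying indefinitely if synthesis kept failing. A guarded variant might cap the total number of attempts. The following plain-Python sketch shows only that control flow, with a hypothetical `compute_loss` callable standing in for `split_step_forward`.

```python
import math
from typing import Callable

def train_finite_steps(compute_loss: Callable[[int], float], steps: int,
                       max_attempts: int) -> tuple[list[float], int]:
    """Counts only steps with a finite loss; stops after `max_attempts` tries."""
    losses: list[float] = []
    attempts = 0
    while len(losses) < steps and attempts < max_attempts:
        attempts += 1
        loss = compute_loss(len(losses))  # the current step index, e.g., for seeding
        if math.isfinite(loss):
            # in the notebook, loss.backward() and optim.step() would go here
            losses.append(loss)
    return losses, attempts

# A fake loss sequence in which the first two attempts fail:
fake = iter([math.nan, math.nan, 1.0, 0.8, 0.6])
losses, attempts = train_finite_steps(lambda step: next(fake), steps=3, max_attempts=10)
```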

__________


In [9]:
# torch.save({
#     'model_state_dict': flow.state_dict(),
#     'optimizer_state_dict': optim.state_dict(),
#     'loss': loss
#     }, 'model.pickle')

# Evaluation

We will do the following:

* Generate some new (in-distribution) data from the two moons problem and check its likelihood under the correct (0) and wrong (1) class.
* Generate some ***complementary*** data from the two moons problem and check how well the score distinguishes it.

In the latter case, we would ideally like to see the flow systematically assign lower (close to zero) likelihoods to the complementary data.

In [10]:
def likelihood_score(data: Tensor) -> Tensor:
    with torch.no_grad():
        clz_0 = torch.zeros(size=(data.shape[0],), device=data.device)
        clz_1 = torch.ones(size=(data.shape[0],), device=data.device)
        
        return flow.eval().log_rel_lik_X(data, clz_0) - flow.eval().log_rel_lik_X(data, clz_1)
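`likelihood_score` is the log-likelihood ratio between the two class-conditional models. If these were proper log-densities with equal class priors, Bayes' rule would turn the ratio into the posterior of the nominal class via the logistic sigmoid, $p(c{=}0 \mid x) = \sigma\big(\log p(x \mid 0) - \log p(x \mid 1)\big)$. `log_rel_lik_X` returns *relative* likelihoods, so this is only an analogy. Below is a small plain-Python illustration with two hypothetical 1-D Gaussian class-conditionals standing in for the flow.

```python
import math

def gauss_log_pdf(x: float, mu: float, sigma: float) -> float:
    return -0.5 * math.log(2.0 * math.pi * sigma ** 2) - 0.5 * ((x - mu) / sigma) ** 2

def llr_score(x: float) -> float:
    # log p(x | class 0) - log p(x | class 1); class 0 ~ N(0, 1), class 1 ~ N(3, 1)
    return gauss_log_pdf(x, 0.0, 1.0) - gauss_log_pdf(x, 3.0, 1.0)

def posterior_class0(x: float) -> float:
    # Bayes' rule with equal priors: p(c=0 | x) = sigmoid(LLR)
    return 1.0 / (1.0 + math.exp(-llr_score(x)))
```

Here the score is $4.5 - 3x$: it is positive near class 0, zero halfway between the two means, and negative near class 1.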

In [11]:
flow.eval()

test_id = two_moons_rejection_sampling(nsamples=5_000, seed=SEED+1).to(device=dev)
test_id_clz = torch.full(size=(test_id.shape[0],), fill_value=0., device=dev)

# Note: the training data was sampled with pure_moons=True and pure_moons_mode='tight',
# but test_id above uses the default settings. Some in-distribution test samples may
# therefore bleed a little into the complementary data, so this won't be perfect.
test_comp = two_moons_rejection_sampling(nsamples=5_000, seed=SEED+1, complementary=True, pure_moons=True, pure_moons_mode='tight').to(device=dev)
test_comp_clz = torch.full(size=(test_comp.shape[0],), fill_value=1., device=dev)

In [None]:
likelihood_score(data=test_id).mean()

In [None]:
likelihood_score(data=test_comp).mean()

In [None]:
import seaborn as sns

sns.displot(data=torch.vstack(tensors=(
    likelihood_score(test_id),
    likelihood_score(test_comp)
)).T.detach().cpu().numpy())

In [None]:
softmax(model.forward_embedding(flow.X_to_E(test_id)), dim=1)

In [None]:
softmax(model.forward_embedding(flow.X_to_E(test_comp)), dim=1)

In [17]:
def discriminator_score(data: Tensor) -> Tensor:
    temp = softmax(model.forward_embedding(flow.X_to_E(data)), dim=1).detach()
    return temp[:,0] - temp[:,1]

In [None]:
import matplotlib.pyplot as plt

with torch.no_grad():
    grid_size = 250
    xx, yy = torch.meshgrid(torch.linspace(-5, 5, grid_size), torch.linspace(-5, 5, grid_size), indexing='ij')
    zz = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).view(-1, 2)
    zz = zz.to(dev)

    probs = flow.log_rel_lik_X(input=zz, classes=torch.zeros(zz.shape[0], device=dev))
    # probs = likelihood_score(data=zz)
    p_target = probs.view(*xx.shape).cpu().data.numpy()

    plt.figure(figsize=(8, 8))
    plt.pcolormesh(xx, yy, p_target, shading='auto', cmap='coolwarm')#, vmin=likelihood.min().item(), vmax=likelihood.max().item())
    plt.gca().set_aspect('equal', 'box')
    plt.grid(visible=True)
    plt.colorbar()
    plt.show()
    torch.cuda.empty_cache()

In [None]:
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve, RocCurveDisplay

with torch.no_grad():
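    y_true = torch.cat((test_id_clz, test_comp_clz)).detach().cpu().numpy()
    # y_true is 1 for the complementary (anomalous) data, so the score must grow with
    # anomalousness. Both the likelihood score and the discriminator score are *higher*
    # for nominal data, hence the negation.
    # y_pred = -likelihood_score(data=torch.cat((test_id, test_comp))).detach().cpu().numpy()
    y_pred = -discriminator_score(data=torch.cat((test_id, test_comp))).detach().cpu().numpy() #* likelihood_score(data=torch.cat((test_id, test_comp))).detach().cpu().numpy()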

    fpr, tpr, thresholds = roc_curve(y_true=y_true, y_score=y_pred)
    youden_j = tpr - fpr
    optimal_idx = np.argmax(youden_j)
    optimal_threshold = thresholds[optimal_idx]
    print(f'optimal_threshold={optimal_threshold:.4f}, roc_auc_score={roc_auc_score(y_true=y_true, y_score=y_pred):.4f}')

    RocCurveDisplay.from_predictions(y_true=y_true, y_pred=y_pred)
    torch.cuda.empty_cache()
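For reference, here is what Youden's $J$ on top of `roc_curve`, and `roc_auc_score`, compute. This is a small plain-Python sketch, not scikit-learn's implementation, and it ignores how scikit-learn selects thresholds. $J = \mathrm{TPR} - \mathrm{FPR}$ is maximized over candidate thresholds. The AUC equals the probability that a randomly chosen positive (label 1) receives a higher score than a randomly chosen negative (label 0), with ties counting one half. Since the positive class here is 1, i.e., the complementary data, the score must be *larger* for more anomalous samples; negating the score turns an AUC of $a$ into $1 - a$.

```python
def roc_points(y_true: list[int], scores: list[float]) -> list[tuple[float, float, float]]:
    """(threshold, fpr, tpr) for every distinct score, predicting 1 iff score >= threshold."""
    pos = sum(y_true)
    neg = len(y_true) - pos
    points = []
    for thr in sorted(set(scores), reverse=True):
        tp = sum(1 for y, s in zip(y_true, scores) if s >= thr and y == 1)
        fp = sum(1 for y, s in zip(y_true, scores) if s >= thr and y == 0)
        points.append((thr, fp / neg, tp / pos))
    return points

def youden_threshold(y_true: list[int], scores: list[float]) -> float:
    # threshold maximizing J = TPR - FPR
    return max(roc_points(y_true, scores), key=lambda p: p[2] - p[1])[0]

def auc_rank(y_true: list[int], scores: list[float]) -> float:
    """P(score of a random positive > score of a random negative), ties count 1/2."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

y_true = [0, 0, 0, 1, 1, 1]
anomaly_score = [0.1, 0.4, 0.35, 0.8, 0.3, 0.9]
```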

In [None]:
# These are correctly forwarded, i.e., the classes match!
with torch.no_grad():
    b_id = flow.X_to_B(test_id, test_id_clz)[0].flatten()
    b_comp = flow.X_to_B(test_comp, test_comp_clz)[0].flatten()

    # Same, but mixed-up classes!
    b_id_wrong = flow.X_to_B(test_id, test_comp_clz)[0].flatten()
    b_comp_wrong = flow.X_to_B(test_comp, test_id_clz)[0].flatten()

    aspect = 1.5
    sns.displot(b_id.detach().cpu().numpy(), aspect=aspect)
    sns.displot(b_comp.detach().cpu().numpy(), aspect=aspect)
    sns.displot(b_id_wrong.detach().cpu().numpy(), aspect=aspect)
    sns.displot(b_comp_wrong.detach().cpu().numpy(), aspect=aspect)
    torch.cuda.empty_cache()
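Assume the flow maps in-distribution data, conditioned on the matching class, onto a standard normal base distribution. Then the flattened values in `b_id` should look roughly like $\mathcal{N}(0, 1)$ samples, while the other three histograms may deviate. The following plain-Python sketch gives a quick numeric complement to the histograms. The helper name is hypothetical; convert a tensor with, e.g., `b_id.cpu().tolist()` first.

```python
import math
import random

def normality_summary(values: list[float]) -> dict[str, float]:
    """Mean, std, and the fraction inside ±1.96 (≈ 0.95 for a standard normal)."""
    n = len(values)
    mean = sum(values) / n
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / (n - 1))
    inside = sum(1 for v in values if abs(v) < 1.96) / n
    return {'mean': mean, 'std': std, 'frac_within_1.96': inside}

# Reference values from genuine standard normal draws:
rng = random.Random(0)
reference = normality_summary([rng.gauss(0.0, 1.0) for _ in range(100_000)])
```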

In [None]:
with torch.no_grad():
    samples_np = test_id.detach().cpu().numpy()
    likelihood = flow.log_rel_lik_X(test_id, test_id_clz)

    plt.figure(figsize=(8, 8))
    plt.xlim(-3, 3)
    plt.ylim(-3, 3)
    plt.scatter(samples_np[:, 0], samples_np[:, 1], c=likelihood.detach().cpu().numpy(), cmap='coolwarm')#, vmin=likelihood.mean().item())
    plt.gca().set_aspect('equal', 'box')
    plt.grid(visible=True)
    plt.colorbar()
    plt.show()
    torch.cuda.empty_cache()

In [None]:
with torch.no_grad():
    samples_np = test_comp.detach().cpu().numpy()
    # Evaluate the complementary data under the nominal class (0).
    likelihood = flow.log_rel_lik_X(test_comp, test_id_clz)

    plt.figure(figsize=(8, 8))
    plt.xlim(-3, 3)
    plt.ylim(-3, 3)
    plt.scatter(samples_np[:, 0], samples_np[:, 1], c=likelihood.detach().cpu().numpy(), cmap='coolwarm')#, vmin=likelihood.mean().item())
    plt.gca().set_aspect('equal', 'box')
    plt.grid(visible=True)
    plt.colorbar()
    plt.show()
    torch.cuda.empty_cache()