In [None]:
# NOTE(review): `assign_free_gpus` is a project helper from `utils`; presumably it selects
# idle GPU(s) for this process — confirm in utils. It is called here, before torch is
# imported further down, so keep it at the very top of the cell.
from utils import assign_free_gpus
assign_free_gpus()

import time
import json
from pathlib import Path
from sklearn.datasets import make_moons
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8*' and later releases
# dropped the old 'seaborn' alias. Prefer the new name when it exists and fall back to the
# legacy one on older installations.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
from pytorch_lightning.loggers import TensorBoardLogger
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.distributions.normal import Normal

from generate_data import MoonsDataModule, MoonsDataset
from models import LinearClassifier, Classifier, GAN, ContrastiveMapping

# Reproducibility: torch and the legacy NumPy global state share one seed value, and `rng`
# is a dedicated NumPy generator seeded with the same value.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
rng = np.random.default_rng(SEED)

# Results are read from the `results` directory next to the current working directory.
path_results = Path.cwd().parent.joinpath('results')

In [None]:
# Experiment selection
# --------------------
# `noise` is the noise level passed to the moons dataset; the run directory names below
# embed the same value. Keep exactly one `noise` / `path_*` group active.

# --- noise = 0.3 (disabled) ---
# noise = 0.3
# path_classifier = path_results / 'classifier' / '2022-12-06_163615_linear_noise0.3' / 'checkpoints' / 'epoch=99-step=6300.ckpt'
# path_gan = path_results / 'GAN' / '2022-11-29_152439_noise0.3' / 'checkpoints' / 'epoch=99-step=12600.ckpt'
# # path_model = path_results / 'contrastiveMapping' / '2023-01-17_170210_linear_noise0.3' / 'checkpoints' / 'epoch=59-step=15120.ckpt'
# path_model = path_results / 'contrastiveMapping' / '2023-01-17_171054_linear_noise0.3_learndistrib' / 'checkpoints' / 'epoch=59-step=15120.ckpt'

# --- noise = 0.1 (active) ---
noise = 0.1
path_classifier = path_results.joinpath(
    'classifier', '2022-11-29_152904_linear_noise0.1', 'checkpoints', 'epoch=99-step=6300.ckpt'
)
path_gan = path_results.joinpath(
    'GAN', '2022-11-29_152457_noise0.1', 'checkpoints', 'epoch=99-step=12600.ckpt'
)

# Alternative contrastive-mapping runs for noise = 0.1 (disabled):
# path_model = path_results / 'contrastiveMapping' / '2023-01-17_162629_linear_noise0.1' / 'checkpoints' / 'epoch=59-step=15120.ckpt'
# path_model = path_results / 'contrastiveMapping' / '2023-01-17_163649_linear_noise0.1_learndistrib' / 'checkpoints' / 'epoch=59-step=15120.ckpt'
# path_model = path_results / 'contrastiveMapping' / '2023-01-17_174727_linear_noise0.1_classifLossMSE' / 'checkpoints' / 'epoch=59-step=15120.ckpt'
# path_model = path_results / 'contrastiveMapping' / '2023-01-17_175117_linear_noise0.1_classifLossKLDiv' / 'checkpoints' / 'epoch=59-step=15120.ckpt'
path_model = path_results.joinpath(
    'contrastiveMapping',
    '2023-01-17_175828_linear_noise0.1_learndistrib_classifLossMSE',
    'checkpoints',
    'epoch=59-step=15120.ckpt',
)

In [None]:
# Restore the classifier and the GAN from the checkpoints selected above
# (classifier first, then GAN — the loading order is kept as is).
classifier = LinearClassifier.load_from_checkpoint(str(path_classifier))
gan = GAN.load_from_checkpoint(str(path_gan))

# Evaluation set: 10000 moons samples generated with the configured noise level.
data_test = MoonsDataset(random_state=2, noise=noise, n_samples=10000)
x_test, y_test = data_test.x, data_test.y

In [None]:
# Run the validation loop of the restored classifier on a freshly generated moons datamodule.
trainer = pl.Trainer(accelerator='auto', devices=1)
trainer.validate(classifier, datamodule=MoonsDataModule(n_samples=20000, noise=noise, random_state=2))

# SHOW DECISION BOUNDARY
# Regular 100 x 100 grid over the plotted region, flattened into (x, y) rows with the
# x coordinate varying slowest (same ordering as a nested "for x: for y:" loop).
grid_x = np.linspace(-2, 3, 100)
grid_y = np.linspace(-2, 2, 100)
grid_data = np.stack(np.meshgrid(grid_x, grid_y, indexing='ij'), axis=-1).reshape(-1, 2)
grid_data = torch.from_numpy(grid_data).float()

with torch.no_grad():
    grid_logits = classifier(grid_data)
# Hard 0/1 prediction for every grid point
class_pred = torch.sigmoid(grid_logits).round().cpu().flatten()

# SHOW CLASSIF LOSS
# Per-sample binary cross-entropy of the classifier on the test set
with torch.no_grad():
    logits = classifier(x_test).squeeze()
    classif_loss = F.binary_cross_entropy_with_logits(logits, y_test, reduction='none')

# Both loss-coloured scatters share one colour scale so the single colorbar is valid
# for class 0 as well as class 1.
loss_min = classif_loss.min().item()
loss_max = classif_loss.max().item()

fig, ax = plt.subplots()
ax.set_title('classifier decision boundary')
predicted_class0 = class_pred == 0
ax.scatter(grid_data[predicted_class0, 0], grid_data[predicted_class0, 1],
           alpha=1, c='C0', label='predicted class 0')
ax.scatter(grid_data[~predicted_class0, 0], grid_data[~predicted_class0, 1],
           alpha=1, c='C1', label='predicted class 1')
for label_value, marker in ((0, 'o'), (1, '+')):
    is_class = y_test == label_value
    im = ax.scatter(x_test[is_class, 0], x_test[is_class, 1], alpha=0.1,
                    c=classif_loss[is_class], cmap='Reds', vmin=loss_min, vmax=loss_max,
                    marker=marker, label=f'real data - class {label_value}')

leg = ax.legend(frameon=True)
# `Legend.legendHandles` was renamed to `legend_handles` in Matplotlib 3.7 and the old
# name was removed later; support both so the cell runs on old and new versions.
legend_handles = leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles
for lh in legend_handles:
    lh.set_alpha(1)  # legend markers fully opaque even though the points are translucent
cbar = fig.colorbar(im, ax=ax, label='classifier loss')
cbar.solids.set(alpha=1)

In [None]:
# Draw as many synthetic points from the GAN as there are test points.
n_samples = len(x_test)
z = torch.randn(n_samples, gan.latent_dim, device=gan.device)

# `rnd_label` stays None when the GAN takes no class conditioning (c_dim == 0); the code
# below must not assume it exists.
rnd_label = None
if gan.c_dim > 0:
    rnd_label = torch.randint(gan.c_dim, size=(n_samples,), device=gan.device)
    # one_hot returns int64; cast explicitly instead of relying on dtype promotion in cat
    c = F.one_hot(rnd_label, num_classes=gan.c_dim).to(z.dtype)
    z = torch.cat([z, c], dim=1)
    rnd_label = rnd_label.cpu().numpy()

with torch.no_grad():
    w = gan.generator.mapping(z)
    x_fake = gan.generator.synthesis(w).detach().cpu().numpy()
    w = w.detach().cpu().numpy()

plt.figure()
plt.title('fake vs. real data')
plt.scatter(x_test[y_test==0, 0], x_test[y_test==0, 1], alpha=0.5, c='C0', label='real data - class 0')
plt.scatter(x_test[y_test==1, 0], x_test[y_test==1, 1], alpha=0.5, c='C1', label='real data - class 1')
if rnd_label is None:
    # Unconditional GAN: there is no label to split the generated points by
    plt.scatter(x_fake[:, 0], x_fake[:, 1], alpha=0.5, c='C2', label='fake data')
else:
    # One scatter per conditioning class; colours continue after the two used for real data
    for class_idx in range(gan.c_dim):
        of_class = rnd_label == class_idx
        plt.scatter(x_fake[of_class, 0], x_fake[of_class, 1], alpha=0.5,
                    c=f'C{class_idx + 2}', label=f'fake data - class {class_idx}')
plt.legend()

# Optional (disabled): t-SNE view of the intermediate latent codes `w`, coloured by class.
# w_embedded = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(w)
# plt.figure()
# plt.title('t-SNE in W')
# plt.scatter(w_embedded[rnd_label==0, 0], w_embedded[rnd_label==0, 1], alpha=0.5, c='C0', label='class 0')
# plt.scatter(w_embedded[rnd_label==1, 0], w_embedded[rnd_label==1, 1], alpha=0.5, c='C1', label='class 1')
# plt.legend()

In [None]:
# Restore the contrastive mapping from its checkpoint; the GAN and classifier loaded
# earlier are handed over as keyword arguments.
model = ContrastiveMapping.load_from_checkpoint(
    str(path_model),
    gan=gan,
    classifier=classifier,
)

def selection_function(x, mapping=None):
    """Confidence score max_y p(y | x) for a two-class Gaussian model with equal priors.

    Each class y in {0, 1} has a normal density N(mean_y, std_y) over `x`, and the score
    is the posterior probability of the more likely class.

    Args:
        x: tensor of values at which the posterior is evaluated.
        mapping: object exposing `mean0`, `std0`, `mean1` and `std1`. Defaults to the
            notebook-level `model.mapping`, which keeps existing one-argument calls working.

    Returns:
        Detached tensor of posterior maxima, each in [0.5, 1].
    """
    if mapping is None:
        mapping = model.mapping
    p_y0 = 0.5
    p_y1 = 0.5
    # Work with log p(x, y) = log p(x | y) + log p(y). Exponentiating the densities first
    # underflows far from both means and turns the ratio into 0 / 0 = NaN.
    log_joint0 = Normal(mapping.mean0, mapping.std0).log_prob(x) + float(np.log(p_y0))
    log_joint1 = Normal(mapping.mean1, mapping.std1).log_prob(x) + float(np.log(p_y1))
    # p(y=0 | x) = sigmoid(log_joint0 - log_joint1) and p(y=1 | x) is its complement,
    # so the larger of the two is the sigmoid of the absolute log-odds.
    max_p_y_x = torch.sigmoid((log_joint0 - log_joint1).abs()).detach()
    return max_p_y_x

# Define domain for fake data

In [None]:
# Risk coverage curves
# Samples drawn by `sample_u` are classified by `model` and scored against the labels `c`
# returned together with them.
n_eval = x_test.shape[0]
u, c = model.mapping.sample_u(n_eval)
u_domain = u[:, -1]
with torch.no_grad():
    logits = model(u)
y = torch.sigmoid(logits).squeeze()
classif_loss = F.binary_cross_entropy_with_logits(logits.squeeze(), c.float(), reduction='none')
classif_correct = (y.round() == c)

# baseline: random selection
domain_cutoff_random = np.linspace(0, 1, 100)
coverage_random, risk_random, acc_random = (np.zeros_like(domain_cutoff_random) for _ in range(3))
candidate_indices = np.arange(n_eval)
for k, cut in enumerate(domain_cutoff_random):
    # 1-cut to be coherent with the other selections below (low value -> high coverage)
    nb_samples = int((1 - cut) * n_eval)
    idx_domain = rng.choice(candidate_indices, size=nb_samples, replace=False)
    coverage_random[k] = len(idx_domain) / n_eval
    risk_random[k] = classif_loss[idx_domain].mean()
    acc_random[k] = classif_correct[idx_domain].float().mean()

# baseline: max softmax (keep a sample when either class probability exceeds the cut)
domain_cutoff_baseline = np.linspace(0, 1, 1000)
coverage_baseline, risk_baseline, acc_baseline = (np.zeros_like(domain_cutoff_baseline) for _ in range(3))
y_complement = 1 - y
for k, cut in enumerate(domain_cutoff_baseline):
    idx_domain = (y > cut) | (y_complement > cut)
    coverage_baseline[k] = idx_domain.float().mean()
    risk_baseline[k] = classif_loss[idx_domain].mean()
    acc_baseline[k] = classif_correct[idx_domain].float().mean()

# cut max proba computed in U (the score does not depend on the cut, so evaluate it once)
domain_cutoff = np.linspace(0.5, 1, 1000)
coverage, risk, acc = (np.zeros_like(domain_cutoff) for _ in range(3))
max_posterior = selection_function(u_domain)
for k, cut in enumerate(domain_cutoff):
    idx_domain = max_posterior > cut
    coverage[k] = idx_domain.float().mean()
    risk[k] = classif_loss[idx_domain].mean()
    acc[k] = classif_correct[idx_domain].float().mean()


fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# (legend label, thresholds, coverage, risk, accuracy) per selection rule, in drawing order
curves = [
    ('cut in U', domain_cutoff, coverage, risk, acc),
    ('baseline (max softmax)', domain_cutoff_baseline, coverage_baseline, risk_baseline, acc_baseline),
    ('baseline (random)', domain_cutoff_random, coverage_random, risk_random, acc_random),
]
for curve_label, thresholds, cov, rsk, accuracy in curves:
    # thresholds rescaled to [0, 1] so rules with different ranges share one x axis
    thresholds_unit = (thresholds - thresholds.min()) / (thresholds.max() - thresholds.min())
    ax1.plot(cov, rsk, label=curve_label)
    ax2.plot(cov, accuracy, label=curve_label)
    ax3.plot(thresholds_unit, cov, label=curve_label)
    ax4.plot(thresholds_unit, rsk, label=curve_label)

# (axes, title, x label, y label) for the four panels
panels = [
    (ax1, 'coverage vs. risk\n(obtained by varying confidence/uncertainty threshold)\n(using pseudo-labels)', 'coverage', 'risk'),
    (ax2, 'coverage vs. accuracy\n(obtained by varying confidence/uncertainty threshold)\n(using pseudo-labels)', 'coverage', 'accuracy'),
    (ax3, 'coverage vs threshold value', 'normalized threshold value', 'coverage'),
    (ax4, 'risk vs threshold value', 'normalized threshold value', 'risk'),
]
for panel, panel_title, x_label, y_label in panels:
    panel.set_title(panel_title)
    panel.legend()
    panel.set_xlabel(x_label)
    panel.set_ylabel(y_label)

# Define domain for real data

In [None]:
# Risk coverage curves for real test data
def _sweep_thresholds(confidence, cutoffs, loss, correct):
    """Evaluate selective classification when keeping samples with `confidence > cut`.

    Args:
        confidence: tensor with one confidence score per sample.
        cutoffs: 1-D NumPy array of thresholds to sweep.
        loss: per-sample loss tensor aligned with `confidence`.
        correct: per-sample boolean tensor aligned with `confidence`.

    Returns:
        Three arrays shaped like `cutoffs`: coverage (fraction of samples kept), risk
        (mean loss of the kept samples) and accuracy of the kept samples. Risk and
        accuracy are NaN for thresholds that keep no sample.
    """
    coverage_curve = np.zeros_like(cutoffs)
    risk_curve = np.zeros_like(cutoffs)
    acc_curve = np.zeros_like(cutoffs)
    for k, threshold in enumerate(cutoffs):
        kept = confidence > threshold
        coverage_curve[k] = kept.float().mean()
        risk_curve[k] = loss[kept].mean()
        acc_curve[k] = correct[kept].float().mean()
    return coverage_curve, risk_curve, acc_curve


n_eval = x_test.shape[0]
with torch.no_grad():
    # Encoder and classifier are only evaluated here, so no autograd graph is needed
    x_on_device = x_test.to(model.device)
    u = model.encoder(x_on_device)
    logits = model.classifier(x_on_device)
u_domain = u[:, -1]

y = torch.sigmoid(logits).squeeze()
classif_loss = F.binary_cross_entropy_with_logits(logits.squeeze(), y_test, reduction='none')
classif_correct = (y.round() == y_test)
# True-class probability: probability the classifier gives to the ground-truth label.
# It needs the labels, hence the "oracle" baseline.
tcp = torch.zeros_like(y)
tcp[y_test==0] = 1 - y[y_test==0]
tcp[y_test==1] = y[y_test==1]

# baseline: random selection
domain_cutoff_random = np.linspace(0, 1, 100)
coverage_random = np.zeros_like(domain_cutoff_random)
risk_random = np.zeros_like(domain_cutoff_random)
acc_random = np.zeros_like(domain_cutoff_random)
candidate_indices = np.arange(n_eval)
for i, cut in enumerate(domain_cutoff_random):
    nb_samples = int((1-cut) * n_eval) # 1-cut to be coherent with other indices below (low value -> high coverage)
    idx_domain = rng.choice(candidate_indices, size=nb_samples, replace=False)
    coverage_random[i] = nb_samples / n_eval
    risk_random[i] = classif_loss[idx_domain].mean()
    acc_random[i] = classif_correct[idx_domain].float().mean()

# baseline: max softmax (confidence = probability of the predicted class)
domain_cutoff_baseline = np.linspace(0, 1, 1000)
coverage_baseline, risk_baseline, acc_baseline = _sweep_thresholds(
    torch.maximum(y, 1 - y), domain_cutoff_baseline, classif_loss, classif_correct)

# baseline: TCP
domain_cutoff_baselineTCP = np.linspace(0, 1, 1000)
coverage_baselineTCP, risk_baselineTCP, acc_baselineTCP = _sweep_thresholds(
    tcp, domain_cutoff_baselineTCP, classif_loss, classif_correct)

# cut max proba computed in U (the score is threshold-independent, so compute it once)
domain_cutoff = np.linspace(0.5, 1, 1000)
coverage, risk, acc = _sweep_thresholds(
    selection_function(u_domain), domain_cutoff, classif_loss, classif_correct)

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# (legend label, thresholds, coverage, risk, accuracy) per selection rule, in drawing order
curves = [
    ('cut in U', domain_cutoff, coverage, risk, acc),
    ('baseline (max softmax)', domain_cutoff_baseline, coverage_baseline, risk_baseline, acc_baseline),
    ('baseline (random)', domain_cutoff_random, coverage_random, risk_random, acc_random),
    ('baseline (TCP oracle)', domain_cutoff_baselineTCP, coverage_baselineTCP, risk_baselineTCP, acc_baselineTCP),
]
for curve_label, thresholds, cov, rsk, accuracy in curves:
    # thresholds rescaled to [0, 1] so rules with different ranges share one x axis
    thresholds_unit = (thresholds - thresholds.min()) / (thresholds.max() - thresholds.min())
    ax1.plot(cov, rsk, label=curve_label)
    ax2.plot(cov, accuracy, label=curve_label)
    ax3.plot(thresholds_unit, cov, label=curve_label)
    ax4.plot(thresholds_unit, rsk, label=curve_label)

ax1.set_title('coverage vs. risk\n(obtained by varying confidence/uncertainty threshold)')
ax1.legend()
ax1.set_xlabel('coverage')
ax1.set_ylabel('risk')

# Accuracy here is measured against the true test labels, so the title carries no
# "(using pseudo-labels)" suffix, matching the risk panel.
ax2.set_title('coverage vs. accuracy\n(obtained by varying confidence/uncertainty threshold)')
ax2.legend()
ax2.set_xlabel('coverage')
ax2.set_ylabel('accuracy')

ax3.set_title('coverage vs threshold value')
ax3.legend()
ax3.set_xlabel('normalized threshold value')
ax3.set_ylabel('coverage')

ax4.set_title('risk vs threshold value')
ax4.legend()
ax4.set_xlabel('normalized threshold value')
ax4.set_ylabel('risk')

# Encode data

In [None]:
# Encode the test points, map the codes through `model.mapping`, and synthesise them back
# to data space with the GAN generator.
x_on_device = x_test.to(model.device)
with torch.no_grad():
    u = model.encoder(x_on_device)
    w = model.mapping(u)
    x_recon = model.gan.generator.synthesis(w)
# The last coordinate of the code is the one the selection function is applied to
u_domain = u[:, -1]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('histogram in U_domain')
ax1.hist(u_domain.cpu().numpy(), bins=50)

for points, point_label in ((x_test, 'real data'), (x_recon, 'reconstructed data')):
    ax2.scatter(points[:, 0], points[:, 1], alpha=0.1, label=point_label)
ax2.legend()

# Filter out-of-domain data

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

# One panel per confidence threshold: points whose score exceeds the threshold are
# "in domain" (green), the rest "out domain" (red). The score itself is computed once.
max_posterior = selection_function(u_domain)
for panel, cut in zip((ax1, ax2, ax3), (0.8, 0.9, 0.95)):
    idx_in_domain = max_posterior > cut
    idx_out_domain = ~idx_in_domain
    # Selective-classification metrics restricted to the accepted points
    coverage = idx_in_domain.float().mean()
    risk = classif_loss[idx_in_domain].mean()
    acc = classif_correct[idx_in_domain].float().mean()
    panel.set_title(f'threshold = {cut}\ncoverage={coverage:.2f}, risk={risk:.2f}, acc={acc:.2f}')
    for mask, mask_label, color in ((idx_out_domain, 'out domain', 'r'), (idx_in_domain, 'in domain', 'g')):
        panel.scatter(x_test[mask, 0], x_test[mask, 1], alpha=0.1, label=mask_label, c=color)
    panel.legend()

In [None]:
# Plot the two class-conditional normal densities held by `model.mapping` and the class
# posteriors they imply under equal priors, on a grid from -3 to 3.
x = torch.arange(-3, 3, 0.1)
p_y0, p_y1 = 0.5, 0.5
density_class0 = Normal(model.mapping.mean0, model.mapping.std0)
density_class1 = Normal(model.mapping.mean1, model.mapping.std1)
p_x_y0 = density_class0.log_prob(x).exp().detach()
p_x_y1 = density_class1.log_prob(x).exp().detach()
# Bayes rule: posterior = likelihood * prior / evidence
evidence = p_x_y0 * p_y0 + p_x_y1 * p_y1
p_y0_x = p_x_y0 * p_y0 / evidence
p_y1_x = p_x_y1 * p_y1 / evidence
max_p_y_x = torch.maximum(p_y0_x, p_y1_x)

x_values = x.numpy()

plt.figure()
plt.title('proba density function')
for density, density_label in ((p_x_y0, 'class 0'), (p_x_y1, 'class 1'), (p_x_y0 + p_x_y1, 'sum')):
    plt.plot(x_values, density.numpy(), label=density_label)
plt.legend()

plt.figure()
for posterior, posterior_label in ((p_y0_x, 'p(y=0|x)'), (p_y1_x, 'p(y=1|x)')):
    plt.plot(x_values, posterior.numpy(), label=posterior_label)
plt.plot(x_values, max_p_y_x.numpy(), label='max p(y|x)', alpha=0.5)
plt.legend()
