In [None]:
import sys
sys.path.append('./moons') 

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import time
import json
from pathlib import Path
import sklearn
from sklearn.datasets import make_moons
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
# The 'seaborn' style sheet was renamed 'seaborn-v0_8' in Matplotlib 3.6 and the
# old name was removed in 3.8; prefer the new name and fall back on older versions.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
# from pytorch_lightning.loggers import TensorBoardLogger
import torch.nn.functional as F
import pytorch_lightning as pl

from generate_data import MoonsDataModule, MoonsDataset
from models import LinearClassifier, Classifier, GAN
from temperature_scaling import ModelWithTemperature, _ECELoss

# Use the first CUDA device when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Seed the PyTorch and NumPy random generators.
torch.manual_seed(0)             # PyTorch global generator
np.random.seed(0)                # legacy NumPy global state
rng = np.random.default_rng(0)   # separate NumPy Generator instance

# Checkpoints are looked up under <parent of the working directory>/models/moons.
path_models = Path.cwd().parent.joinpath('models', 'moons')

# Load data, classifier, and GAN

In [None]:
# Experiment settings. `noise` is forwarded to every moons dataset created below;
# `linear` is not referenced by any other visible cell.
noise = 0.3
linear = False

# Classifier checkpoint, restored in eval mode on the selected device.
path_classifier = path_models.joinpath(
    'classifier', '2023-05-24_095926_noise0.3', 'checkpoints', 'epoch=99-step=6300.ckpt'
)
classifier = Classifier.load_from_checkpoint(str(path_classifier))
classifier = classifier.eval().to(device)

# Other GAN checkpoints, kept for reference:
# path_gan = path_models / 'GAN' / '2023-05-24_103512_noise0.3_classCondone-hot_classifCondMSP' / 'checkpoints' / 'epoch=99-step=12600.ckpt'
# path_gan = path_models / 'GAN' / '2023-06-08_162712_noise0.3_classCondone-hot_classifCondMSP' / 'checkpoints' / 'epoch=499-step=63000.ckpt'
path_gan = path_models.joinpath(
    'GAN', '2023-06-08_133522_noise0.3_classCondone-hot_classifCondNone', 'checkpoints', 'epoch=199-step=25200.ckpt'
)
# The GAN checkpoint is restored with the classifier passed in, then put in eval mode.
gan = GAN.load_from_checkpoint(str(path_gan), classifier=classifier)
gan = gan.eval().to(device)

In [None]:
# Fixed-seed test set of 10000 points; keep handles on its inputs and labels.
data_test = MoonsDataset(n_samples=10000, noise=noise, random_state=2)
x_test, y_test = data_test.x, data_test.y

In [None]:
# Validate the classifier on a moons data module built with the notebook's noise level.
trainer = pl.Trainer(accelerator='auto', devices=1)
trainer.validate(classifier, datamodule=MoonsDataModule(n_samples=20000, noise=noise, random_state=2))

# Send inputs to wherever the classifier's weights currently live instead of
# assuming which device it was left on after validation.
classifier_device = next(classifier.parameters()).device

# SHOW DECISION BOUNDARY
x = np.linspace(-2, 3, 100)
y = np.linspace(-2, 2, 100)

# All (x, y) pairs of the grid, one per row, with x varying slowest
# (the same row order as a nested `for x_ in x: for y_ in y:` loop).
xx, yy = np.meshgrid(x, y, indexing='ij')
grid_data = torch.from_numpy(np.column_stack([xx.ravel(), yy.ravel()])).float()

with torch.no_grad():
    grid_logits = classifier(grid_data.to(classifier_device))
class_pred_grid = torch.sigmoid(grid_logits).round().cpu().flatten()#.numpy()

# CORRECT PRED
with torch.no_grad():
    logits = classifier(x_test.to(classifier_device))
class_pred = torch.sigmoid(logits).round().cpu().flatten()
# Boolean mask, True where the prediction differs from the label. It is what
# colours the real points below (the name is kept for backward compatibility).
classif_loss = class_pred != y_test

fig, ax = plt.subplots()
# ax.set_title('classifier decision boundary')
ax.scatter(grid_data[class_pred_grid==0, 0], grid_data[class_pred_grid==0, 1], alpha=1, c='C0', label='predicted class 0')
ax.scatter(grid_data[class_pred_grid!=0, 0], grid_data[class_pred_grid!=0, 1], alpha=1, c='C1', label='predicted class 1')
ax.scatter(x_test[y_test==0, 0], x_test[y_test==0, 1], alpha=0.2, c=classif_loss[y_test==0], cmap='Reds', marker='o', label='real data - class 0')
im = ax.scatter(x_test[y_test==1, 0], x_test[y_test==1, 1], alpha=0.2, c=classif_loss[y_test==1], cmap='Reds', marker='+', label='real data - class 1')
leg = ax.legend(frameon=True)
# Legend.legendHandles was renamed legend_handles in Matplotlib 3.7 and the old
# attribute removed in 3.9; support both so the legend markers are drawn opaque.
legend_handles = leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles
for lh in legend_handles:
    lh.set_alpha(1)
# cbar = fig.colorbar(im, ax=ax, label='prediction error')
# cbar.solids.set(alpha=1)


In [None]:
def get_MSP_correct(dataloader, classifier, device):
    """Compute the maximum class probability and correctness for every sample.

    The classifier is switched to eval mode and applied batch by batch; its
    output is passed through a sigmoid and read as the probability of class 1.

    Args:
        dataloader: iterable of ``(inputs, labels)`` batches exposing a
            ``dataset`` attribute whose length is the total sample count.
        classifier: callable returning one logit per sample.
        device: device on which batches are processed and results are stored.

    Returns:
        ``(msp, correct)``: two 1-D tensors on ``device`` with one entry per
        sample, in iteration order. ``msp`` holds ``max(p, 1 - p)`` and
        ``correct`` holds 1.0 where the rounded probability equals the label,
        0.0 otherwise.
    """
    classifier.eval()

    n_total = len(dataloader.dataset)
    msp = torch.zeros(n_total, device=device)
    correct = torch.zeros(n_total, device=device)

    start = 0
    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        stop = start + inputs.shape[0]

        with torch.no_grad():
            p_class1 = torch.sigmoid(classifier(inputs)).squeeze()

        # Write this batch's results into its slice of the preallocated outputs.
        msp[start:stop] = torch.maximum(1 - p_class1, p_class1)
        correct[start:stop] = p_class1.round() == targets
        start = stop

    return msp, correct

# MSP and correctness of the classifier on `data_test`, evaluated in batches of 1000.
msp_test, correct_test = get_MSP_correct(
    DataLoader(data_test, batch_size=1000),
    classifier.to(device),
    device,
)

# Histogram of the MSP values.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5))
ax.hist(msp_test.cpu(), bins=50, alpha=0.5, log=False)
ax.set_xlabel('MSP value');


In [None]:
# Draw one latent vector and one random class label per test sample.
n_samples = len(x_test)
z = torch.randn(n_samples, gan.latent_dim, device=gan.device)
# if gan.condition_dim > 0:
rnd_label = torch.randint(2, size=(z.shape[0],), device=gan.device)
c = F.one_hot(rnd_label, num_classes=2)
# MSP of the classifier on a newly created real dataset; it is only appended to
# the condition below when the GAN has a classifier conditioning.
confidence = get_MSP_correct(DataLoader(MoonsDataset(n_samples=n_samples, noise=noise, random_state=None), 1000), classifier.to(device), device)[0].unsqueeze(1) # labels from real distrib
# confidence = 0.999*torch.ones((n_samples, 1), device=gan.device)
# confidence = torch.rand((n_samples, 1), device=gan.device)

if gan.classifier_conditioning is not None:
    c = torch.cat([c, confidence], dim=1)
# Generator input = latent vector followed by its condition.
z = torch.cat([z, c], dim=1)
with torch.no_grad():
    w = gan.generator.mapping(z)
    x_fake = gan.generator.synthesis(w).detach().cpu().numpy()
    # NumPy copies of the labels and mapped latents for plotting / inspection.
    rnd_label = rnd_label.cpu().numpy()
    w = w.detach().cpu().numpy()

plt.figure()
plt.title('fake vs. real data')
plt.scatter(x_test[y_test==0, 0], x_test[y_test==0, 1], alpha=0.1, c='C0', label='real data - class 0')
plt.scatter(x_test[y_test==1, 0], x_test[y_test==1, 1], alpha=0.1, c='C1', label='real data - class 1')
plt.scatter(x_fake[rnd_label==0, 0], x_fake[rnd_label==0, 1], alpha=0.1, c='C2', label='fake data - class 0')
plt.scatter(x_fake[rnd_label==1, 0], x_fake[rnd_label==1, 1], alpha=0.1, c='C3', label='fake data - class 1')
leg = plt.legend()
# Legend.legendHandles was renamed legend_handles in Matplotlib 3.7 and the old
# attribute removed in 3.9; support both so the legend markers are drawn opaque.
legend_handles = leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles
for lh in legend_handles:
    lh.set_alpha(1)

# Useful functions for calibration

In [None]:
def ece_from_dataloader(model, dataloader, device=None):
    """Compute the ECE of `model` over all batches of `dataloader`.

    Logits and labels are collected for the whole loader under
    ``torch.no_grad()`` and then passed to ``_ECELoss`` in a single call.

    Args:
        model: callable mapping a batch of inputs to logits.
        dataloader: iterable of ``(inputs, labels)`` batches.
        device: device on which inputs and labels are placed. When ``None``
            (default) the device of the model's first parameter is used; a
            model without parameters falls back to CUDA when available,
            otherwise CPU.

    Returns:
        Whatever ``_ECELoss()(logits, labels)`` returns for the concatenated
        logits and labels.
    """
    if device is None:
        # Follow the model rather than hard-coding CUDA, so this also runs on
        # CPU-only machines and never feeds the model inputs from another device.
        first_param = next(iter(model.parameters()), None) if hasattr(model, 'parameters') else None
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # First: collect all the logits and labels for the validation set
    logits_list = []
    labels_list = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            logits_list.append(model(inputs.to(device)))
            labels_list.append(labels.to(device))
        logits = torch.cat(logits_list)
        labels = torch.cat(labels_list)

    ece = _ECELoss()(logits, labels)

    return ece

# Calibration data from real data

In [None]:
batch_size = 1000

# For each calibration-set size: the list of ECE values over trials, split by
# evaluation set (calib / test) and by before / after temperature scaling (TS).
all_ece_calib_before_TS = {}
all_ece_test_before_TS = {}
all_ece_calib_after_TS = {}
all_ece_test_after_TS = {}

# Fixed-seed test set shared by every trial.
test_loader = DataLoader(MoonsDataset(n_samples=20000, noise=noise, random_state=2), batch_size=batch_size)

for valid_size in np.geomspace(100, 10000, 5, dtype=int):
    print(f'Calibration set size: {valid_size}')
    for ece_store in (all_ece_calib_before_TS, all_ece_test_before_TS,
                      all_ece_calib_after_TS, all_ece_test_after_TS):
        ece_store[valid_size] = []

    for trial in range(10):

        # A new real calibration set per trial (created with random_state=None).
        valid_loader = DataLoader(
            MoonsDataset(n_samples=valid_size, noise=noise, random_state=None),
            batch_size=batch_size,
        )

        # ECE of the raw classifier on both sets.
        ece_calib_before_TS = ece_from_dataloader(classifier, valid_loader)
        ece_test_before_TS = ece_from_dataloader(classifier, test_loader)

        # Performing temperature scaling
        model = ModelWithTemperature(classifier).to(device)
        model.set_temperature(valid_loader, binary_classif=True)

        # ECE of the temperature-scaled model on both sets.
        ece_calib_after_TS = ece_from_dataloader(model, valid_loader)
        ece_test_after_TS = ece_from_dataloader(model, test_loader)

        all_ece_calib_before_TS[valid_size].append(ece_calib_before_TS.item())
        all_ece_test_before_TS[valid_size].append(ece_test_before_TS.item())
        all_ece_calib_after_TS[valid_size].append(ece_calib_after_TS.item())
        all_ece_test_after_TS[valid_size].append(ece_test_after_TS.item())


# PLOT
fig, ax = plt.subplots()

# One mean curve with a +/- one standard deviation band per series, in this order.
curves = (
    (all_ece_calib_before_TS, 'ECE calib before TS'),
    (all_ece_calib_after_TS, 'ECE calib after TS'),
    (all_ece_test_before_TS, 'ECE test before TS'),
    (all_ece_test_after_TS, 'ECE test after TS'),
)
for ece_by_size, curve_label in curves:
    sizes = list(ece_by_size.keys())
    means = np.array([np.mean(values) for values in ece_by_size.values()])
    stds = np.array([np.std(values) for values in ece_by_size.values()])
    ax.plot(sizes, means, label=curve_label)
    ax.fill_between(sizes, means - stds, means + stds, alpha=0.5)
ax.set_xlabel('Calibration set size')

ax.legend()

In [None]:
# Display the test-set ECE values after temperature scaling, keyed by calibration set size.
all_ece_test_after_TS

# Calibration data from synthetic data

In [None]:
def create_synthetic_dataset(n_samples=20000):
    """Sample `n_samples` labelled points from the notebook-level GAN.

    Reads the globals `gan`, `classifier`, `noise` and `device`. Each sample
    gets a random latent vector and a random binary label (one-hot encoded as
    the condition). If the GAN has a classifier conditioning, the classifier's
    MSP on a newly created real dataset is appended to the condition.

    Args:
        n_samples: number of synthetic samples to generate.

    Returns:
        A `Dataset` whose items are `(x, y)` pairs: `x` a generated point
        (NumPy) and `y` its label as a float.
    """

    class SyntheticDataset(Dataset):
        """Map-style dataset over two parallel indexable collections."""

        def __init__(self, x, y):
            self.x, self.y = x, y

        def __len__(self):
            return len(self.x)

        def __getitem__(self, idx):
            return self.x[idx], self.y[idx]

    latent = torch.randn(n_samples, gan.latent_dim, device=gan.device)
    # if gan.condition_dim > 0:
    rnd_label = torch.randint(2, size=(latent.shape[0],), device=gan.device)
    condition = F.one_hot(rnd_label, num_classes=2)

    # NOTE(review): this is computed even when the GAN has no classifier
    # conditioning and the value is then unused; left as is to keep the order
    # of random draws unchanged.
    real_loader = DataLoader(MoonsDataset(n_samples=n_samples, noise=noise, random_state=None), 1000)
    confidence = get_MSP_correct(real_loader, classifier.to(device), device)[0].unsqueeze(1) # confidence from real distrib
    # confidence = 0.999*torch.ones((n_samples, 1), device=gan.device) # fixed confidence
    # confidence = 0.5 + 0.5*torch.rand((n_samples, 1), device=gan.device) # uniform confidence
    if gan.classifier_conditioning is not None:
        condition = torch.cat([condition, confidence], dim=1)

    # Generator input = latent vector followed by its condition.
    generator_input = torch.cat([latent, condition], dim=1)
    with torch.no_grad():
        mapped = gan.generator.mapping(generator_input)
        x_fake = gan.generator.synthesis(mapped).detach().cpu().numpy()

    labels = rnd_label.cpu().numpy().astype(float)
    return SyntheticDataset(x_fake, labels)

In [None]:
batch_size = 1000

# For each calibration-set size: the list of ECE values over trials, split by
# evaluation set (calib / test) and by before / after temperature scaling (TS).
all_ece_calib_before_TS = {}
all_ece_test_before_TS = {}
all_ece_calib_after_TS = {}
all_ece_test_after_TS = {}

# Fixed-seed real test set shared by every trial.
test_loader = DataLoader(MoonsDataset(n_samples=20000, noise=noise, random_state=2), batch_size=batch_size)

for valid_size in np.linspace(100, 10000, 3, dtype=int):
    print(f'Calibration set size: {valid_size}')
    for ece_store in (all_ece_calib_before_TS, all_ece_test_before_TS,
                      all_ece_calib_after_TS, all_ece_test_after_TS):
        ece_store[valid_size] = []

    for trial in range(10):
        # A new GAN-generated calibration set per trial.
        synthetic_data = create_synthetic_dataset(valid_size)
        valid_loader = DataLoader(synthetic_data, batch_size=batch_size)

        # ECE of the raw classifier on both sets.
        ece_calib_before_TS = ece_from_dataloader(classifier, valid_loader)
        ece_test_before_TS = ece_from_dataloader(classifier, test_loader)

        # Performing temperature scaling
        model = ModelWithTemperature(classifier).to(device)
        model.set_temperature(valid_loader, binary_classif=True)

        # ECE of the temperature-scaled model on both sets.
        ece_calib_after_TS = ece_from_dataloader(model, valid_loader)
        ece_test_after_TS = ece_from_dataloader(model, test_loader)

        all_ece_calib_before_TS[valid_size].append(ece_calib_before_TS.item())
        all_ece_test_before_TS[valid_size].append(ece_test_before_TS.item())
        all_ece_calib_after_TS[valid_size].append(ece_calib_after_TS.item())
        all_ece_test_after_TS[valid_size].append(ece_test_after_TS.item())


# PLOT
fig, ax = plt.subplots()

# One mean curve with a +/- one standard deviation band per series, in this order.
curves = (
    (all_ece_calib_before_TS, 'ECE calib before TS'),
    (all_ece_calib_after_TS, 'ECE calib after TS'),
    (all_ece_test_before_TS, 'ECE test before TS'),
    (all_ece_test_after_TS, 'ECE test after TS'),
)
for ece_by_size, curve_label in curves:
    sizes = list(ece_by_size.keys())
    means = np.array([np.mean(values) for values in ece_by_size.values()])
    stds = np.array([np.std(values) for values in ece_by_size.values()])
    ax.plot(sizes, means, label=curve_label)
    ax.fill_between(sizes, means - stds, means + stds, alpha=0.5)
ax.set_xlabel('Calibration set size')

ax.legend()

In [None]:
# Display the test-set ECE values after temperature scaling, keyed by calibration set size.
all_ece_test_after_TS

In [None]:
# Sanity check: print the input and label shapes of the first test batch only.
for x, y in test_loader:
    print(x.shape, y.shape, sep='\n')
    break

In [None]:
# Display `y`; if the batch-inspection cell ran last, this is the label tensor of the first test batch.
y