In [None]:
import sys
sys.path.append('./stylegan') 
sys.path.append('./stylegan/stylegan2')
 
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import time
import json
import pickle
from pathlib import Path
import sklearn
from sklearn.datasets import make_moons
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.style.use('seaborn')
import numpy as np
import torch
import torchvision
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
# from pytorch_lightning.loggers import TensorBoardLogger
import torch.nn.functional as F
import pytorch_lightning as pl

from models import CNN_MNIST
from temperature_scaling import ModelWithTemperature, _ECELoss
from stylegan2.training.dataset import ImageFolderDataset

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Seed every RNG family used in this notebook so runs are reproducible.
torch.manual_seed(0)
np.random.seed(0)
rng = np.random.default_rng(0)

# Data root: one candidate per machine this notebook runs on.
# $WORK is only considered when it is actually set (an unset variable would
# otherwise stay literal, and an empty one would resolve to the cwd).
candidate_roots = [
    Path('/d/alecoz/projects'),  # DeepLab
    Path(os.environ['WORK']) if os.environ.get('WORK') else None,  # Jean Zay
    Path('w:/'),  # local
]
existing_roots = [root for root in candidate_roots if root is not None and root.exists()]
if not existing_roots:
    raise FileNotFoundError(
        'None of the known data roots exist: {}'.format(
            [str(root) for root in candidate_roots if root is not None]))
# Same precedence as before: when several roots exist, the last one listed wins.
path_main = existing_roots[-1]
# path_results = path_main / 'uncertainty-conditioned-gan/results'
path_data = path_main / 'DATA'
path_models = Path.cwd().parent / 'models' / 'MNIST'

# Load data, classifier, and GAN

In [None]:
# DATA CORRUPTED
# Experiment configuration: StyleGAN2 training-run folder, classifier weights, and
# the names of the zipped train/test datasets (looked up under path_data / 'MNIST').
# NOTE(review): the GAN/classifier names say "proba100" while the datasets say
# "proba50" — confirm this mismatch is intended.
path_model = path_models.joinpath(
    'stylegan2-training-runs',
    '00004-mnist_stylegan2_blur_noise_maxSeverity5_proba100-cond-cifar-classifCond',
)
path_classifier = path_models.joinpath(
    'classifier',
    'CNN_mnist_stylegan2_blur_noise_maxSeverity5_proba100_20230525_1128.pth',
)
dataset_train = 'mnist_stylegan2_blur_noise_maxSeverity5_proba50'
dataset_test = 'mnistTest_stylegan2_blur_noise_maxSeverity5_proba50'

In [None]:
def postprocess_images(images):
    """Rescale generator output to [0, 1] and drop the 2-pixel border.

    Args:
        images: 4D tensor (B x C x H x W), presumably in [-1, 1] (generator range).

    Returns:
        (images + 1) / 2 clamped to [0, 1], keeping only rows/columns 2..29
        (28 x 28 for 32 x 32 inputs).
    """
    n_dims = images.dim()
    assert n_dims == 4, "Expected 4D (B x C x H x W) image tensor, got {}D".format(n_dims)
    unit_range = images.add(1).div(2).clamp(0, 1)
    return unit_range[..., 2:30, 2:30]

def plot_images_grid(images, title=''):
    """Display a batch of images as one grid in a new figure.

    Args:
        images: 4D tensor (B x C x H x W). Values are multiplied by 255 and cast to
            uint8, so presumably in [0, 1] (e.g. the output of postprocess_images).
        title: figure title.
    """
    # `vutils` was never imported in this notebook; import make_grid explicitly.
    from torchvision.utils import make_grid

    # Round and clamp rather than truncate, so e.g. 0.999 maps to 255, not 254,
    # and out-of-range values don't wrap around in uint8.
    images = (images * 255).round().clamp(0, 255).to(torch.uint8)
    grid = make_grid(images.cpu(), pad_value=255)
    plt.figure()
    plt.imshow(grid.permute(1, 2, 0), vmin=0, vmax=255)
    plt.axis('off')
    plt.grid(False)
    plt.title(title)
    
def generate_random_images(n_images=5, n_classes=10):
    """Sample class-conditional images from the global generator `G`.

    Latent codes come from torch.randn and class labels are drawn uniformly from
    {0, ..., n_classes - 1}. The one-hot labels are the condition passed to
    `G.mapping`. The sampled labels are not returned.

    Args:
        n_images: number of images to generate.
        n_classes: size of the one-hot condition (MNIST has 10 classes).
            NOTE(review): assumes G's condition is a plain one-hot over
            `n_classes` classes — confirm against G.c_dim.

    Returns:
        Output of postprocess_images: values in [0, 1], 2-pixel border removed.
    """
    z = torch.randn((n_images, G.z_dim), device=device)
    labels = torch.randint(0, n_classes, (n_images,), device=device)
    c = F.one_hot(labels, n_classes)
    # Sampling only: no autograd graph needed (G may not be frozen).
    with torch.no_grad():
        ws = G.mapping(z, c, truncation_psi=1)
        img = G.synthesis(ws, noise_mode='const', force_fp32=True)
    return postprocess_images(img)

def get_classifier_MSP(logits):
    """Maximum softmax probability (MSP) of each sample.

    Args:
        logits: (N, K) classifier outputs.

    Returns:
        (N,) tensor holding the largest softmax probability of each row.
    """
    probas = F.softmax(logits, dim=1)
    return probas.max(dim=1).values

def get_classifier_TCP(logits, labels):
    """True-class probability (TCP): softmax probability of each sample's label.

    Args:
        logits: (N, K) classifier outputs; K may be any number of classes.
        labels: (N,) class indices (any numeric dtype; cast to long).

    Returns:
        (N,) tensor with softmax(logits)[i, labels[i]].
    """
    probas = torch.softmax(logits, dim=1)
    # gather picks column labels[i] from row i. Unlike a one-hot mask built with a
    # fixed num_classes, it works for any number of classes K.
    return probas.gather(1, labels.long().unsqueeze(1)).squeeze(1)

def mahalanobis(U, v):
    ''' Compute the Mahalanobis distance between each row of U and v.

    The covariance matrix is estimated from the rows of U themselves.

    Args:
        U: (N, D) tensor; rows are observations, columns are features.
        v: (D,) reference point.

    Returns:
        (N,) tensor of distances, on U's device and in U's dtype.
    '''
    # torch.cov returns a 0-d tensor when D == 1; keep it a (D, D) matrix.
    cov = torch.atleast_2d(torch.cov(U.T))
    cov_inv = torch.linalg.inv(cov)
    delta = U - v
    # Row-wise quadratic form delta_i^T Sigma^-1 delta_i, vectorized over all rows.
    sq_dist = torch.einsum('ij,jk,ik->i', delta, cov_inv, delta)
    # Rounding can push tiny values slightly below zero; clamp to avoid NaN.
    return torch.sqrt(sq_dist.clamp_min(0))

In [None]:
# LOAD GENERATOR
# NOTE: both branches unpickle the network file, which can execute arbitrary
# code; only load snapshots you trust.
if not str(path_model).endswith('pkl'):
    # path_model is a training-run folder: pick the snapshot with the lowest FID
    # from its metric log (one JSON object per line).
    path_metrics = path_model / 'metric-fid50k_full.jsonl'
    with open(path_metrics, 'r') as json_file:
        fid_records = [json.loads(line) for line in json_file if line.strip()]
    if not fid_records:
        raise ValueError(f'No FID records found in {path_metrics}')
    # min() keeps the first record among ties, like a strict '<' scan.
    best_record = min(fid_records, key=lambda record: record['results']['fid50k_full'])
    best_fid = best_record['results']['fid50k_full']
    best_model = best_record['snapshot_pkl']
    print('Best FID: {:.2f} ; best model : {}'.format(best_fid, best_model))
    path_model = path_model / best_model
    with open(path_model, 'rb') as f:
        # Inference only: freeze the weights, as in the other branch.
        G = pickle.load(f)['G_ema'].eval().requires_grad_(False).to(device)  # torch.nn.Module
else:
    # dnnlib / legacy are only needed here. They are assumed to come from the
    # StyleGAN2 repository added to sys.path at the top of the notebook.
    import dnnlib
    import legacy
    # open_url regex-matches its argument for a URL scheme, so pass a str, not a Path.
    with dnnlib.util.open_url(str(path_model)) as f:
        G = legacy.load_network_pkl(f)['G_ema'].eval().requires_grad_(False).to(device)


# LOAD CLASSIFIER
classifier = CNN_MNIST()
classifier.load_state_dict(torch.load(path_classifier, map_location=device))
classifier = classifier.eval().requires_grad_(False).to(device)


# LOAD DATASET
# Zipped image folders with labels (presumably one-hot: later cells call y.argmax(1)).
path_dataset = path_data / 'MNIST' / f'{dataset_train}.zip'
train_data = ImageFolderDataset(path_dataset, use_labels=True)
train_dataloader = DataLoader(train_data, batch_size=128)

path_dataset = path_data / 'MNIST' / f'{dataset_test}.zip'
test_data = ImageFolderDataset(path_dataset, use_labels=True)
test_dataloader = DataLoader(test_data, batch_size=128)

# Classifier

In [None]:
# Top-1 accuracy of the classifier on the train and test splits.
# Preprocessing matches the rest of the notebook: scale by 1/255, crop [2:30, 2:30],
# and take argmax of the one-hot labels.
for message, loader, split_data in (
    ('Accuracy on training set: {:.2f}%', train_dataloader, train_data),
    ('Accuracy on test set: {:.2f}%', test_dataloader, test_data),
):
    correct = 0
    for x, y in loader:
        x = (x / 255)[:, :, 2:30, 2:30].to(device)
        y = y.argmax(1).to(device)
        with torch.no_grad():
            logits = classifier(x)
            pred = logits.max(dim=1).indices
            correct += (pred == y).sum().item()
    accuracy = correct / len(split_data)
    print(message.format(accuracy * 100))

In [None]:
# Preview: the first 20 training images with the classifier's top class and its probability.
x, y = next(iter(train_dataloader))
x = (x / 255)[:, :, 2:30, 2:30]  # scale by 1/255 and drop the 2-pixel border
y = y.argmax(1)
x, y = x.to(device), y.to(device)

fig, axs = plt.subplots(4, 5, figsize=(6, 6))
for ax, img in zip(axs.flatten(), x[:20]):
    logits = classifier(img.unsqueeze(0))
    probas = torch.softmax(logits, axis=1)
    msp, class_pred = torch.max(probas, axis=1)

    ax.imshow(img.cpu().numpy().squeeze(), vmin=0, vmax=1, cmap='gray')
    ax.axis('off')
    ax.grid(False)
    ax.set_title('p({})={:.2f}'.format(class_pred.item(), msp.item()))

In [None]:
def get_MSP_TCP(dataloader, classifier, device):
    """Per-sample MSP and TCP of `classifier` over the whole dataset of `dataloader`.

    Images are scaled by 1/255 and cropped to [2:30, 2:30]; labels are converted
    with argmax(1) (presumably one-hot). Puts `classifier` in eval mode.

    Args:
        dataloader: yields (images, labels) batches; its `.dataset` sets the output length.
        classifier: model returning logits.
        device: device the batches are moved to.

    Returns:
        (msp, tcp): two CPU float tensors of length len(dataloader.dataset), in
        dataloader order.
    """
    classifier.eval()
    n_total = len(dataloader.dataset)
    msp = torch.zeros(n_total)
    tcp = torch.zeros(n_total)
    offset = 0
    for images, onehot_labels in dataloader:
        n_batch = images.shape[0]
        images = (images / 255)[:, :, 2:30, 2:30].to(device)
        targets = onehot_labels.argmax(1).to(device)
        with torch.no_grad():
            logits = classifier(images)
        batch_slice = slice(offset, offset + n_batch)
        msp[batch_slice] = get_classifier_MSP(logits)
        tcp[batch_slice] = get_classifier_TCP(logits, targets)
        offset += n_batch
    return msp, tcp



msp_train, tcp_train = get_MSP_TCP(train_dataloader, classifier, device)

# Log-scale histogram of the classifier's confidence (MSP) on the training set.
fig, ax = plt.subplots(figsize=(5, 5))
ax.set_xlabel('MSP value')
ax.hist(msp_train, bins=50, alpha=0.5, log=True)

# Useful functions for calibration

In [None]:
class _ECELoss(nn.Module):
    """
    Calculates the Expected Calibration Error of a model.
    (This isn't necessary for temperature scaling, just a cool metric).

    The input to this loss is the logits of a model, NOT the softmax scores.

    This divides the confidence outputs into equally-sized interval bins.
    In each bin, we compute the confidence gap:

    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |

    We then return a weighted average of the gaps, based on the number
    of samples in each bin

    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        if logits.shape[1] == 1: # binary classif
            probas_class1 = torch.sigmoid(logits)
            probas_class0 = 1 - probas_class1
            softmaxes = torch.cat((probas_class0, probas_class1), dim=1)
        else:
            softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        if labels.shape[1] > 1: # one-hot embedding
            labels = labels.argmax(1)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Calculated |confidence - accuracy| in each bin
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece

In [None]:
def ece_from_dataloader(model, dataloader, device=None):
    """Expected Calibration Error of `model` over every batch of `dataloader`.

    Images are scaled by 1/255 and cropped to [2:30, 2:30], matching the
    preprocessing used elsewhere in this notebook. Labels are passed unchanged to
    `_ECELoss`.

    Args:
        model: torch.nn.Module returning logits.
        dataloader: yields (images, labels) batches.
        device: device to run on. Defaults to the device of the model's first
            parameter, or CPU if it has none. (The hard-coded `.cuda()` failed on
            CPU-only machines.)

    Returns:
        1-element tensor with the ECE.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Collect all logits and labels first; ECE is a dataset-level statistic.
    logits_list = []
    labels_list = []
    with torch.no_grad():
        for images, labels in dataloader:
            images = (images.to(device) / 255)[:, :, 2:30, 2:30]
            logits_list.append(model(images))
            labels_list.append(labels)
    if not logits_list:
        raise ValueError('dataloader yielded no batches; cannot compute ECE')
    logits = torch.cat(logits_list).to(device)
    labels = torch.cat(labels_list).to(device)
    return _ECELoss()(logits, labels)

In [None]:
# Calibration experiments draw calibration/test splits from the held-out test set,
# which the loading cell above already read from f'{dataset_test}.zip'. Reading the
# zip again here only to discard the result was redundant.
dataset = test_data
valid_size = 1000
batch_size = 128

# Calibration data from real data

In [None]:
# Calibration on real data. For each calibration-set size, repeatedly split the
# held-out `dataset` into a calibration part (used to fit the temperature) and a
# test part, and record ECE before/after temperature scaling (TS) on both parts.
# Several random splits per size give the mean ± std band in the plot a meaning;
# with a single split the std is always 0.
n_trials = 10  # same number of trials as the synthetic-data experiment

all_ece_calib_before_TS = {}
all_ece_test_before_TS = {}
all_ece_calib_after_TS = {}
all_ece_test_after_TS = {}

for valid_size in np.geomspace(100, 5000, 5, dtype=int):
    print(f'Calibration set size: {valid_size}')
    all_ece_calib_before_TS[valid_size] = []
    all_ece_test_before_TS[valid_size] = []
    all_ece_calib_after_TS[valid_size] = []
    all_ece_test_after_TS[valid_size] = []

    for trial in range(n_trials):
        valid_indices, test_indices = train_test_split(np.arange(len(dataset)), train_size=valid_size)
        valid_loader = DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(valid_indices))
        test_loader = DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(test_indices))

        ece_calib_before_TS = ece_from_dataloader(classifier, valid_loader)
        ece_test_before_TS = ece_from_dataloader(classifier, test_loader)

        # Performing temperature scaling (fitted on the calibration split only)
        model = ModelWithTemperature(classifier).to(device)
        model.set_temperature(valid_loader)

        ece_calib_after_TS = ece_from_dataloader(model, valid_loader)
        ece_test_after_TS = ece_from_dataloader(model, test_loader)

        all_ece_calib_before_TS[valid_size] += [ece_calib_before_TS.item()]
        all_ece_test_before_TS[valid_size] += [ece_test_before_TS.item()]
        all_ece_calib_after_TS[valid_size] += [ece_calib_after_TS.item()]
        all_ece_test_after_TS[valid_size] += [ece_test_after_TS.item()]


# PLOT: mean ECE (± one std over trials) versus calibration-set size.
fig, ax = plt.subplots()

for label, results in [
    ('ECE calib before TS', all_ece_calib_before_TS),
    ('ECE calib after TS', all_ece_calib_after_TS),
    ('ECE test before TS', all_ece_test_before_TS),
    ('ECE test after TS', all_ece_test_after_TS),
]:
    # Materialize dict views: fill_between needs real array-likes.
    sizes = np.array(list(results.keys()))
    means = np.array([np.mean(v) for v in results.values()])
    stds = np.array([np.std(v) for v in results.values()])
    ax.plot(sizes, means, label=label)
    ax.fill_between(sizes, means - stds, means + stds, alpha=0.5)
ax.set_xlabel('Calibration set size')
ax.set_ylabel('ECE')

ax.legend()

In [None]:
# Raw values: list of test-split ECEs after temperature scaling, keyed by calibration-set size.
all_ece_test_after_TS

# Calibration data from synthetic data

In [None]:
def create_synthetic_dataset(n_samples=20000):
    """Generate a labelled synthetic dataset with a conditional GAN.

    Samples latent codes and random labels in {0, 1}, builds a one-hot condition,
    appends a confidence column when `gan.classifier_conditioning` is set, runs the
    generator's mapping and synthesis networks, and wraps the images and labels in
    a Dataset.

    NOTE(review): this function appears to be carried over from a two-class
    ("moons") experiment and does not match the MNIST StyleGAN2 setup of this
    notebook. `gan`, `get_MSP_correct`, `MoonsDataset` and `noise` are not defined
    or imported anywhere in the visible notebook (the generator is loaded as `G`),
    and labels are drawn from 2 classes. As written, the first line raises
    NameError. TODO: port to `G`, e.g. along the lines of generate_random_images.

    Args:
        n_samples: number of synthetic samples to generate.

    Returns:
        SyntheticDataset whose items are (image array, float label) pairs; images are
        numpy arrays from the generator output, labels are class indices cast to float.
    """

    z = torch.randn(n_samples, gan.latent_dim, device=gan.device)
    # if gan.condition_dim > 0:
    # Random class labels in {0, 1} (two classes — see NOTE(review) in the docstring).
    rnd_label = torch.randint(2, size=(z.shape[0],), device=gan.device)
    c = F.one_hot(rnd_label, num_classes=2)
    confidence = get_MSP_correct(DataLoader(MoonsDataset(n_samples=n_samples, noise=noise, random_state=None), 1000), classifier.to(device), device)[0].unsqueeze(1) # confidence from real distrib
    # confidence = 0.999*torch.ones((n_samples, 1), device=gan.device) # fixed confidence
    # confidence = 0.5 + 0.5*torch.rand((n_samples, 1), device=gan.device) # uniform confidence
    # Append the confidence column to the condition only for classifier-conditioned GANs.
    if gan.classifier_conditioning is not None:
        c = torch.cat([c, confidence], dim=1)
    # The condition is concatenated to the latent code before the mapping network.
    z = torch.cat([z, c], dim=1)
    with torch.no_grad():
        w = gan.generator.mapping(z)
        x_fake = gan.generator.synthesis(w).detach().cpu().numpy()

    # Minimal in-memory Dataset over the generated arrays.
    class SyntheticDataset(Dataset):

        def __init__(self, x, y):
            self.x = x
            self.y = y

        def __len__(self):
            return len(self.x)

        def __getitem__(self, idx):
            return self.x[idx], self.y[idx]

    synthetic_data = SyntheticDataset(x_fake, rnd_label.cpu().numpy().astype(float))

    return synthetic_data

In [None]:
# Calibration on synthetic data: fit the temperature on GAN-generated samples and
# evaluate ECE on a fixed test loader, for several calibration-set sizes and 10
# generated calibration sets per size.
# NOTE(review): `MoonsDataset` and `noise` are not defined or imported anywhere in
# the visible notebook, and create_synthetic_dataset references other undefined
# names (`gan`, `get_MSP_correct`), so this cell raises NameError as written.
# TODO: evaluate on the real MNIST test split (e.g. test_dataloader) once the
# synthetic generator is ported to `G`.
batch_size = 1000

all_ece_calib_before_TS = {}
all_ece_test_before_TS = {}
all_ece_calib_after_TS = {}
all_ece_test_after_TS = {}

test_loader = DataLoader(MoonsDataset(n_samples=20000, noise=noise, random_state=2), batch_size=batch_size)

for valid_size in np.linspace(100, 10000, 3, dtype=int):
    print(f'Calibration set size: {valid_size}')
    all_ece_calib_before_TS[valid_size] = []
    all_ece_test_before_TS[valid_size] = []
    all_ece_calib_after_TS[valid_size] = []
    all_ece_test_after_TS[valid_size] = []

    for trial in range(10):
        # A fresh synthetic calibration set is generated for every trial.
        synthetic_data = create_synthetic_dataset(valid_size)
        valid_loader = DataLoader(synthetic_data, batch_size=batch_size)
        ece_calib_before_TS = ece_from_dataloader(classifier, valid_loader)
        ece_test_before_TS = ece_from_dataloader(classifier, test_loader)

        # Performing temperature scaling
        model = ModelWithTemperature(classifier).to(device)
        model.set_temperature(valid_loader)

        ece_calib_after_TS = ece_from_dataloader(model, valid_loader)
        ece_test_after_TS = ece_from_dataloader(model, test_loader)

        all_ece_calib_before_TS[valid_size] += [ece_calib_before_TS.item()]
        all_ece_test_before_TS[valid_size] += [ece_test_before_TS.item()]
        all_ece_calib_after_TS[valid_size] += [ece_calib_after_TS.item()]
        all_ece_test_after_TS[valid_size] += [ece_test_after_TS.item()]


# PLOT: mean ECE (± one std over trials) versus calibration-set size.
fig, ax = plt.subplots()

for label, results in [
    ('ECE calib before TS', all_ece_calib_before_TS),
    ('ECE calib after TS', all_ece_calib_after_TS),
    ('ECE test before TS', all_ece_test_before_TS),
    ('ECE test after TS', all_ece_test_after_TS),
]:
    # Materialize dict views: fill_between needs real array-likes.
    sizes = np.array(list(results.keys()))
    means = np.array([np.mean(v) for v in results.values()])
    stds = np.array([np.std(v) for v in results.values()])
    ax.plot(sizes, means, label=label)
    ax.fill_between(sizes, means - stds, means + stds, alpha=0.5)
ax.set_xlabel('Calibration set size')
ax.set_ylabel('ECE')

ax.legend()

In [None]:
# Raw values: list of test ECEs after temperature scaling, keyed by calibration-set size.
all_ece_test_after_TS