In [None]:
from __future__ import print_function

import copy
import pickle
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

warnings.filterwarnings("ignore")

In [None]:
def calculate_true_probabilities(x, mu_vec, noise_variance):
    """Class posteriors for an isotropic Gaussian mixture with equal class priors.

    p(k | x_n) = softmax_k( -||x_n - mu_k||^2 / (2 * noise_variance) )

    Args:
        x: samples, shape (N, D).
        mu_vec: class centres, shape (K, D).
        noise_variance: shared per-coordinate variance of the Gaussian noise.

    Returns:
        float64 array of shape (N, K) whose rows sum to 1.
    """
    # Pairwise differences by broadcasting: (N, 1, D) - (1, K, D) -> (N, K, D).
    sq_dist = np.sum((x[:, None, :] - mu_vec[None, :, :]) ** 2, axis=2)
    # Log-likelihoods up to a per-row constant, held in float64.
    logits = (-sq_dist / (2 * noise_variance)).astype(np.float64)
    # Shift by the row max before exp for numerical stability (softmax is shift-invariant).
    weights = np.exp(logits - logits.max(axis=1, keepdims=True))
    return weights / weights.sum(axis=1, keepdims=True)

def sigmoid(x):
  """Logistic function 1 / (1 + exp(-x)), applied elementwise to NumPy input."""
  denom = 1 + np.exp(-x)
  return 1 / denom

def _y_to_oht(label):
  label_oht = torch.zeros(label.shape[0],K_CLAS).to(label.device)
  label_oht.scatter_(1,label,1)
  label_oht = label_oht.float()
  return label_oht

def cal_entropy(logits, p):
  """Batch-mean KL divergence KL(p || softmax(logits)).

  Despite the name, this is not an entropy. It is ``kl_div`` with
  ``reduction='batchmean'``: the sum over classes divided by the batch size.

  Args:
    logits: unnormalised scores; log-softmax is taken over dim 1.
    p: target probabilities (not log-probabilities), same shape as ``logits``.

  Returns:
    0-dim tensor.
  """
  log_q = F.log_softmax(logits, 1)
  return F.kl_div(log_q, p, reduction='batchmean')

def average_l2_distance(tensor1, tensor2):
    """Mean row-wise Euclidean distance between softmax(tensor1) and tensor2.

    Args:
        tensor1: logits; softmax is taken over dim 1.
        tensor2: reference probability rows, same shape as ``tensor1``.

    Returns:
        Python float: mean over rows of ||softmax(tensor1)_i - tensor2_i||_2.
    """
    residual = torch.softmax(tensor1, dim=1) - tensor2
    per_row = torch.norm(residual, dim=1)
    return per_row.mean().item()

def expected_calibration_error(tensor1, tensor2, n_bins=10):
    """Expected calibration error of softmax(tensor1), with labels = argmax of tensor2.

    Confidences are grouped into ``n_bins`` equal-width bins (lo, hi] on [0, 1].
    ECE = sum over non-empty bins of |mean confidence - accuracy| * fraction of samples.

    Args:
        tensor1: logits; softmax is taken over dim 1.
        tensor2: reference scores/probabilities; their row-wise argmax is the label.
        n_bins: number of confidence bins.

    Returns:
        Python float.
    """
    confidence, predictions = torch.max(torch.softmax(tensor1, dim=1), dim=1)
    _, labels = torch.max(tensor2, dim=1)

    edges = torch.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for lower, upper in zip(edges[:-1], edges[1:]):
        in_bin = (confidence > lower) & (confidence <= upper)
        frac = torch.mean(in_bin.float()).item()
        if frac > 0:
            hits = (predictions[in_bin] == labels[in_bin]).float()
            gap = abs(torch.mean(confidence[in_bin]).item() - torch.mean(hits).item())
            ece += gap * frac
    return ece

def normalize_data(data_train, data_val, data_test,return_stats=False,accept_stats=False,mean=0,std=0):
  """Standardise three splits with a single per-feature mean/std.

  Args:
    data_train, data_val, data_test: arrays with features along the last axis.
    return_stats: if True, also return the (mean, std) actually used.
    accept_stats: if True, use the given ``mean``/``std`` instead of computing them
      from ``data_train`` (axis 0).
    mean, std: statistics used when ``accept_stats`` is True.

  Returns:
    (train, val, test) normalised arrays, followed by (mean, std) when return_stats is True.
  """
  if not accept_stats:
    mean = np.mean(data_train, axis=0)
    std = np.std(data_train, axis=0)

  # A constant feature has std == 0, and dividing by it gives inf/NaN. Scale it by 1
  # instead (the scikit-learn StandardScaler convention), so it maps to 0. The guarded
  # std is what gets returned, so re-using the stats reproduces this scaling.
  std = np.where(np.asarray(std) == 0, 1.0, std)

  data_normalized_train = (data_train - mean) / std
  data_normalized_val = (data_val - mean) / std
  data_normalized_test = (data_test - mean) / std

  if not return_stats:
    return data_normalized_train, data_normalized_val, data_normalized_test

  return data_normalized_train, data_normalized_val, data_normalized_test, mean, std

def plot_validation_loss(val_losses, ylim=(0, 0.5), xlabel='Epoch'):
    """Plot a validation-loss curve.

    Args:
        val_losses: sequence of loss values, one per logged point.
        ylim: (bottom, top) y-axis limits, or None to autoscale. The old hard-coded
            (0, 0.5) range clips larger losses (a 5-class KL/cross-entropy loss starts
            near log 5 ~ 1.6), so pass None to see the whole curve.
        xlabel: x-axis label; it is also used in the title. NOTE(review): train_teacher
            logs results['val_loss'] per optimizer step, so e.g. 'Step' may be the
            accurate label for that input.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(val_losses, label='Validation Loss', color='orange')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Loss')
    ax.set_title(f'Validation Loss per {xlabel}')
    ax.grid(True, linestyle='--', alpha=0.7)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.legend()
    plt.show()

def _teacher_batch_loss(x_out, y_in, p_in, p_in_noisy, label_type, lambd, distill):
    """Mini-batch KL loss for train_teacher (target rules are documented there)."""
    if distill:
        # Mix of hard (one-hot) and noisy soft targets; label_type is ignored here.
        loss_hard = cal_entropy(x_out, _y_to_oht(y_in))
        loss_soft = cal_entropy(x_out, p_in_noisy)
        return (1-lambd)*loss_hard + lambd*loss_soft
    if label_type == 'gt':
        p_tgt = p_in
    elif label_type == 'noisy':
        p_tgt = p_in_noisy
    else:
        # 'standard' (and, as before, any unrecognised value): one-hot hard labels.
        p_tgt = _y_to_oht(y_in)
    return cal_entropy(x_out, p_tgt)


def _teacher_val_metrics(model, x_val_t, y_val_t, p_tgt_val):
    """Return the model's loss and accuracy (%) on the whole validation set, as floats."""
    with torch.no_grad():
        x_val_out = model(x_val_t)
        val_loss = cal_entropy(x_val_out, p_tgt_val)
    pred = x_val_out.max(1, keepdim=True)[1]
    correct = pred.eq(y_val_t.view_as(pred)).sum()
    accuracy = 100*(correct/x_val_t.shape[0])
    return val_loss.item(), accuracy.item()


def train_teacher(model, optimizer, x_train, x_val, y_train, y_val, p_train, p_val, p_noisy_train, p_noisy_val, label_type = 'standard', lambd = 1.0, distill = False, val_every = 1):
    """Train a teacher classifier and keep three early-stopping snapshots.

    Training targets:
      * label_type 'gt' -> p_train; 'noisy' -> p_noisy_train; anything else -> one-hot y_train.
      * distill=True -> (1 - lambd) * KL(one-hot) + lambd * KL(noisy); label_type is ignored.

    Reads the notebook globals ``epochs_teacher`` and ``batch_size_teacher``. Each epoch
    re-shuffles the training arrays with NumPy's global RNG. That RNG is re-seeded via
    ``seed_torch(42)`` on entry, so every call sees the same shuffle order. A trailing
    incomplete mini-batch is dropped.

    Args:
        model, optimizer: the network (returning logits) and its optimizer.
        x_train, x_val: 2-D input arrays.
        y_train, y_val: 2-D label arrays (e.g. shape (N, 1)) of class indices.
        p_train, p_val: true class-probability arrays.
        p_noisy_train: noisy soft labels for training.
        p_noisy_val: unused; kept for signature compatibility.
        label_type, lambd, distill: target selection (see above).
        val_every: evaluate the full validation set every ``val_every`` optimizer steps.
            The default of 1 evaluates after every step, as before. Each evaluation is a
            full forward pass over x_val, so larger values speed training up a lot.

    Returns:
        (best-val-accuracy model, best-val-L2-distance model, best-val-ECE model, results).
        In ``results``, 'loss' has one entry per step and 'val_loss'/'val_acc' one entry per
        evaluated step. The other lists have one entry per epoch. 'IDX_MAX' is replaced by
        the epoch index of the best validation accuracy. If no epoch ever improves
        (zero epochs, NaN metrics), a copy of the final model is returned for that slot.

    Raises:
        ValueError: if ``val_every`` < 1.
    """
    if val_every < 1:
        raise ValueError("val_every must be >= 1")

    model.train()
    seed_torch(42)
    length_train = np.shape(x_train)[0]
    vacc_max = 0
    vecep_min = 1000000
    vdistp_min = 1000000
    results = {'loss': [], 'tacc': [], 'vacc': [], 'tdistp': [], 'vdistp': [], 'tecep': [], 'vecep': [], 'IDX_MAX': [], 'val_loss' : [], 'val_acc' : [] }
    # Snapshots stay None only if no epoch ever "improves" (see the fallback after the loop).
    ES_model = ES_model_dist = ES_model_ece = None

    x_val_t = torch.Tensor(x_val)
    y_val_t = torch.Tensor(y_val).long()
    p_val_t = torch.Tensor(p_val)
    # Loop-invariant one-hot validation targets; previously rebuilt on every step.
    p_tgt_val = _y_to_oht(y_val_t)

    # Reference loss: the validation loss obtained by predicting p_val itself. It is
    # constant, so it is computed once (it was recomputed and printed every epoch).
    # The clamp avoids log(0) -> NaN when a probability underflows to 0.
    bayes_loss = nn.KLDivLoss(reduction='batchmean')(p_val_t.clamp_min(1e-12).log(), p_tgt_val)
    print(bayes_loss)

    step = 0
    for j in range(epochs_teacher):
        perm = np.random.permutation(length_train)
        x_train = x_train[perm]
        y_train = y_train[perm]
        p_train = p_train[perm]
        p_noisy_train = p_noisy_train[perm]

        # Only full mini-batches; a trailing partial batch is skipped.
        for i in range(0, length_train - batch_size_teacher + 1, batch_size_teacher):
            batch = slice(i, i + batch_size_teacher)
            x_in = torch.Tensor(x_train[batch, :])
            y_in = torch.Tensor(y_train[batch, :]).long()
            p_in = torch.Tensor(p_train[batch, :])
            p_in_noisy = torch.Tensor(p_noisy_train[batch, :])

            x_out = model(x_in)
            loss = _teacher_batch_loss(x_out, y_in, p_in, p_in_noisy, label_type, lambd, distill)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            results['loss'].append(loss.item())

            if step % val_every == 0:
                val_loss, valy_accuracy = _teacher_val_metrics(model, x_val_t, y_val_t, p_tgt_val)
                results['val_loss'].append(val_loss)
                results['val_acc'].append(valy_accuracy)
            step += 1

        tacc, tdistp, tecep = validate_teacher(model, x_train, y_train, p_train)
        vacc, vdistp, vecep = validate_teacher(model, x_val, y_val, p_val)
        results['tacc'].append(tacc)
        results['vacc'].append(vacc)
        results['tdistp'].append(tdistp)
        results['vdistp'].append(vdistp)
        results['tecep'].append(tecep)
        results['vecep'].append(vecep)

        # ">=" / "<=" prefer the later epoch on ties.
        if vacc >= vacc_max:
            vacc_max = vacc
            ES_model = copy.deepcopy(model)
            results['IDX_MAX'] = j
        if vdistp <= vdistp_min:
            vdistp_min = vdistp
            ES_model_dist = copy.deepcopy(model)
        if vecep <= vecep_min:
            vecep_min = vecep
            ES_model_ece = copy.deepcopy(model)
        if j % 10 == 0:
            print('Epoch: {:3d}/{:3d}\tLoss: {:.3f}\tTACC: {:.3f},\tVACC:{:.3f}, \tVDIST:{:.3f}, \tVECE:{:.3f}, \tBestVDIST:{:.3f}, \tBestVACC:{:.3f}, \tBestVECE:{:.3f}'.format(j, epochs_teacher, results['loss'][-1], tacc, vacc,vdistp,vecep,vdistp_min,vacc_max, vecep_min))

    # Without any improving epoch (epochs_teacher == 0, or NaN metrics) the snapshots
    # were never assigned, and the original raised UnboundLocalError here.
    if ES_model is None:
        ES_model = copy.deepcopy(model)
    if ES_model_dist is None:
        ES_model_dist = copy.deepcopy(model)
    if ES_model_ece is None:
        ES_model_ece = copy.deepcopy(model)

    return ES_model, ES_model_dist, ES_model_ece, results

def validate_teacher(model, x, y, p):
    """Accuracy, L2-to-reference distance and ECE of ``model`` on one data split.

    Args:
        model: classifier mapping a float tensor of inputs to logits over dim 1.
        x: inputs (array-like, converted with ``torch.Tensor``).
        y: class indices with one entry per row of ``x`` (e.g. shape (N, 1)).
        p: reference class probabilities, same shape as the logits.

    Returns:
        (accuracy, dist_p, ece_p): accuracy in percent as a 0-dim tensor (kept as a
        tensor for backward compatibility), plus ``average_l2_distance`` and
        ``expected_calibration_error`` of the logits against ``p`` as floats.
    """
    x = torch.Tensor(x)
    y = torch.Tensor(y).long()
    p = torch.Tensor(p)

    # Restore the caller's mode afterwards. The original always finished with
    # model.train(), which silently flipped models that were in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x_out = model(x)
    finally:
        model.train(was_training)

    pred = x_out.max(1, keepdim=True)[1]
    correct = pred.eq(y.view_as(pred)).sum()
    accuracy = 100*(correct/x.shape[0])

    dist_p = average_l2_distance(x_out, p)
    ece_p = expected_calibration_error(x_out, p)

    return accuracy, dist_p, ece_p


In [None]:
class MLP(nn.Module):
  """Two-hidden-layer ReLU network: in_dim -> hid_size -> hid_size -> out_dim logits.

  Args:
    in_dim: input feature dimension.
    hid_size: width of both hidden layers.
    out_dim: number of output logits. Defaults to the notebook global ``K_CLAS``.
  """
  def __init__(self, in_dim, hid_size = 40, out_dim = None):
    super(MLP, self).__init__()
    self.in_dim = in_dim
    self.hid_size = hid_size
    self.out_dim = K_CLAS if out_dim is None else out_dim
    # Layers are created in the same order as before, so seeded initialisation is unchanged.
    self.fc1 = nn.Linear(self.in_dim, self.hid_size)
    self.fc2 = nn.Linear(self.hid_size, self.hid_size)
    self.fc3 = nn.Linear(self.hid_size, self.out_dim)
    self.act = nn.ReLU(True)  # in-place ReLU

  def forward(self, x):
    """Return unnormalised logits for input ``x``."""
    h1 = self.act(self.fc1(x))
    h2 = self.act(self.fc2(h1))
    out = self.fc3(h2)
    return out

In [None]:
def dirichlet_noise_dataset(P, alpha, rng=None, min_prob=1e-12):
    """Resample every row of P from a Dirichlet distribution centred on that row.

    Row i is drawn as Dirichlet(alpha_i * P_i) using normalised Gamma draws. Larger
    alpha means less noise. Probabilities are clipped to ``min_prob`` and renormalised
    first. If a row's Gamma draws all underflow to 0, that row falls back to the clipped P.

    Args:
        P: (N, K) probability rows.
        alpha: scalar concentration, or one value per row (length N).
        rng: numpy Generator. A fresh unseeded one is used when None.
        min_prob: lower clip applied to P.

    Returns:
        float64 array (N, K) whose rows sum to 1.
    """
    if rng is None:
        rng = np.random.default_rng()
    probs = np.clip(np.asarray(P, dtype=np.float64), min_prob, None)
    probs /= probs.sum(axis=1, keepdims=True)

    conc = alpha * probs if np.isscalar(alpha) else np.asarray(alpha)[:, None] * probs
    draws = rng.gamma(shape=conc, scale=1.0)
    totals = draws.sum(axis=1, keepdims=True)

    degenerate = totals[:, 0] == 0
    if degenerate.any():
        draws[degenerate] = probs[degenerate]
        totals[degenerate] = 1.0
    return draws / totals


def generate_data_noisy(N_Data, data_split, noise_variance=4,
                        label_noise_alpha=None, rng=None):
    """Sample the synthetic Gaussian-class dataset and split it into train/val/test.

    Labels are drawn uniformly from ``K_CLAS`` classes. Each input is its class centre
    ``MU_VEC[y]`` plus Gaussian noise of variance ``noise_variance`` per coordinate.
    ``p_*`` are the posteriors from ``calculate_true_probabilities``. Two Dirichlet-noised
    copies of the posteriors are also returned:
      * ``p_noisy_*``: concentration ``label_noise_alpha``;
      * ``p_noisy_*_alpha01``: concentration ``0.1 * label_noise_alpha`` (noisier).
    Both are plain copies of ``p_*`` when ``label_noise_alpha`` is None or <= 0.

    Args:
        N_Data: total number of samples.
        data_split: (train_frac, val_frac, ...). The test split receives the remainder.
        noise_variance: variance of the input noise.
        label_noise_alpha: Dirichlet concentration, or None/<=0 for no label noise.
        rng: numpy Generator for the shuffle and the label noise. If None, a Generator is
            seeded from NumPy's global RNG, so results follow ``np.random.seed``.

    Returns:
        15-tuple (x_train, x_val, x_test, y_train, y_val, y_test, p_train, p_val, p_test,
        p_noisy_train, p_noisy_val, p_noisy_test, p_noisy_train_alpha01,
        p_noisy_val_alpha01, p_noisy_test_alpha01). The y_* arrays are float32 with shape (n, 1).
    """
    # Labels and inputs come from NumPy's global RNG.
    y_true = np.random.randint(0, K_CLAS, [N_Data, 1]).astype(np.float32)
    # Vectorised lookup of each sample's class centre (previously a Python loop).
    mu_true = MU_VEC[y_true[:, 0].astype(int)].astype(np.float64)
    x_true = mu_true + np.random.randn(N_Data, X_DIM) * np.sqrt(noise_variance)

    if rng is None:
        # Bug fix: an unseeded default_rng() draws fresh OS entropy, so the split and the
        # label noise changed on every run even after seed_torch(). Seed it from the global
        # RNG instead. This happens after the y/x draws, so those draws are unchanged.
        rng = np.random.default_rng(np.random.randint(0, 2**31 - 1))

    p_true = calculate_true_probabilities(x_true, MU_VEC, noise_variance)

    indices = np.arange(N_Data)
    rng.shuffle(indices)

    x_true = x_true[indices]
    y_true = y_true[indices]
    p_true = p_true[indices]

    N_train = int(N_Data * data_split[0])
    N_val = int(N_Data * data_split[1])

    x_train = x_true[:N_train, :]
    x_val   = x_true[N_train:N_train + N_val, :]
    x_test  = x_true[N_train + N_val:, :]

    y_train = y_true[:N_train]
    y_val   = y_true[N_train:N_train + N_val]
    y_test  = y_true[N_train + N_val:]

    p_train = p_true[:N_train, :]
    p_val   = p_true[N_train:N_train + N_val, :]
    p_test  = p_true[N_train + N_val:, :]

    if label_noise_alpha is not None and label_noise_alpha > 0:
        alpha_main = float(label_noise_alpha)
        alpha_weak = 0.1 * alpha_main  # smaller concentration -> noisier labels

        p_noisy_train = dirichlet_noise_dataset(p_train, alpha_main, rng=rng)
        p_noisy_val   = dirichlet_noise_dataset(p_val,   alpha_main, rng=rng)
        p_noisy_test  = dirichlet_noise_dataset(p_test,  alpha_main, rng=rng)

        p_noisy_train_alpha01 = dirichlet_noise_dataset(p_train, alpha_weak, rng=rng)
        p_noisy_val_alpha01   = dirichlet_noise_dataset(p_val,   alpha_weak, rng=rng)
        p_noisy_test_alpha01  = dirichlet_noise_dataset(p_test,  alpha_weak, rng=rng)
    else:
        p_noisy_train = p_train.copy()
        p_noisy_val   = p_val.copy()
        p_noisy_test  = p_test.copy()

        p_noisy_train_alpha01 = p_train.copy()
        p_noisy_val_alpha01   = p_val.copy()
        p_noisy_test_alpha01  = p_test.copy()

    return (x_train, x_val, x_test,
            y_train, y_val, y_test,
            p_train, p_val, p_test,
            p_noisy_train, p_noisy_val, p_noisy_test,
            p_noisy_train_alpha01, p_noisy_val_alpha01, p_noisy_test_alpha01)

In [None]:
def seed_torch(seed=42):
    """Seed NumPy's global RNG and PyTorch (CPU and every CUDA device) with ``seed``."""
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

# Seed NumPy's global RNG and torch. Everything below depends on statement order:
# MU_VEC draws from np.random, MLP(...) draws its initial weights from torch's RNG, and
# generate_data_noisy draws labels/inputs from np.random. Reordering these lines changes
# the class centres, the initial weights and the dataset.
seed_torch(2)

# Dataset parameters
K_CLAS = 5                          # Number of classes in the toy dataset
delta_mu = 1                        # Per-coordinate offset of class means from `const`
NOISE = 2.5                           # Variance of the Gaussian input noise (data uncertainty control)
label_noise_var = 5                   # Dirichlet concentration for the noisy soft labels (larger = less noise); not a variance
X_DIM = 30                          # Dimension of input signal x
const = 5                           # Base value of every class-mean coordinate
# Class centres, shape (K_CLAS, X_DIM); each coordinate is const - delta_mu, const or const + delta_mu.
MU_VEC = np.random.choice([const - delta_mu, const + 0, const + delta_mu], size=(K_CLAS, X_DIM))  # Main class centers

# Teacher parameters
teacher_split = [0.45, 0.45, 0.1]       # Train/validation/test fractions (test actually receives the remainder)
N_teacher = int(5e4)                    # Total samples before splitting
epochs_teacher = 2
# NOTE(review): train_teacher evaluates the whole validation set after every optimizer
# step, so a batch size of 1 makes each epoch very slow.
batch_size_teacher = 1
LR_teacher = 5e-4                       # SGD learning rate (the training cells use momentum 0.9)
Hidden_size_teacher = 128

# Shared initial weights: every training cell deep-copies this model, so all teachers start identically.
initial_model_teacher = MLP(in_dim=X_DIM, hid_size = Hidden_size_teacher)

# p_noisy_*: Dirichlet-noised posteriors with concentration label_noise_var.
# p_noisy1_*: a noisier set with concentration 0.1 * label_noise_var.
x_train_te, x_val_te, x_test_te, y_train_te, y_val_te, y_test_te, p_train_te, p_val_te, p_test_te, p_noisy_train_te, p_noisy_val_te, p_noisy_test_te, p_noisy1_train_te, p_noisy1_val_te, p_noisy1_test_te = generate_data_noisy(N_teacher, teacher_split, noise_variance=NOISE,label_noise_alpha=label_noise_var)
# Standardise inputs with train-set statistics; mean/std are kept for later reuse.
x_train_te_n, x_val_te_n, x_test_te_n, teacher_mean, teacher_std = normalize_data(x_train_te, x_val_te, x_test_te,return_stats=True)

In [None]:
# Teacher trained on one-hot labels (train_teacher's default label_type).
OHT_model = copy.deepcopy(initial_model_teacher)
OHT_optimizer = optim.SGD(params=OHT_model.parameters(), lr=LR_teacher, momentum=0.9)
(best_acc_OHT_model, best_dist_OHT_model, best_ece_OHT_model,
 OHT_results) = train_teacher(
    OHT_model, OHT_optimizer,
    x_train_te_n, x_val_te_n, y_train_te, y_val_te,
    p_train_te, p_val_te, p_noisy_train_te, p_noisy_val_te,
)

In [None]:
# "Perfect" teacher: trained directly on the true posteriors p_train_te.
perf_model = copy.deepcopy(initial_model_teacher)
perf_optimizer = optim.SGD(params=perf_model.parameters(), lr=LR_teacher, momentum=0.9)
(best_acc_perf_model, best_dist_perf_model, best_ece_perf_model,
 perf_results) = train_teacher(
    perf_model, perf_optimizer,
    x_train_te_n, x_val_te_n, y_train_te, y_val_te,
    p_train_te, p_val_te, p_noisy_train_te, p_noisy_val_te,
    label_type='gt',
)

In [None]:
# Teacher trained on Dirichlet-noised soft labels (concentration label_noise_var).
noisy_model = copy.deepcopy(initial_model_teacher)
noisy_optimizer = optim.SGD(params=noisy_model.parameters(), lr=LR_teacher, momentum=0.9)
(best_acc_noisy_model, best_dist_noisy_model, best_ece_noisy_model,
 noisy_results) = train_teacher(
    noisy_model, noisy_optimizer,
    x_train_te_n, x_val_te_n, y_train_te, y_val_te,
    p_train_te, p_val_te, p_noisy_train_te, p_noisy_val_te,
    label_type='noisy',
)

In [None]:
# Teacher trained on the noisier soft labels (concentration 0.1 * label_noise_var).
noisy2_model = copy.deepcopy(initial_model_teacher)
noisy2_optimizer = optim.SGD(params=noisy2_model.parameters(), lr=LR_teacher, momentum=0.9)
(best_acc_noisy2_model, best_dist_noisy2_model, best_ece_noisy2_model,
 noisy2_results) = train_teacher(
    noisy2_model, noisy2_optimizer,
    x_train_te_n, x_val_te_n, y_train_te, y_val_te,
    p_train_te, p_val_te, p_noisy1_train_te, p_noisy1_val_te,
    label_type='noisy',
)

In [None]:
def save_family(prefix, save_dir=None):
    """Save the three early-stopped models and the results dict of one teacher family.

    Looks up the notebook globals ``best_acc_<prefix>_model``, ``best_dist_<prefix>_model``,
    ``best_ece_<prefix>_model`` and ``<prefix>_results``. Each model's state_dict is
    written to ``<name>.pth`` and the results dict is pickled to ``<prefix>_results.pkl``.

    Args:
        prefix: family name, e.g. "OHT", "perf", "noisy" or "noisy2".
        save_dir: output directory, created if missing. Defaults to a global ``save_dir``
            if one is defined, otherwise the current directory. The original read an
            undefined global ``save_dir`` and an un-imported ``pickle``, so it failed with
            NameError on a fresh kernel.

    Raises:
        KeyError: if any of the four globals is missing. All four are looked up first,
            so nothing is written in that case.
    """
    if save_dir is None:
        save_dir = globals().get("save_dir", ".")
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    model_names = [f"best_{metric}_{prefix}_model" for metric in ("acc", "dist", "ece")]
    res_name = f"{prefix}_results"
    models = [globals()[name] for name in model_names]
    results_obj = globals()[res_name]

    for name, model in zip(model_names, models):
        torch.save(model.state_dict(), out_dir / f"{name}.pth")
    # Pickle is fine for our own output; never pickle.load such files from untrusted sources.
    with open(out_dir / f"{res_name}.pkl", "wb") as f:
        pickle.dump(results_obj, f)
    print(f"Saved {prefix}")


for fam in ["perf", "noisy", "noisy2", "OHT"]:
    save_family(fam)