# Feature Extractor

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler
from torchvision import datasets, transforms
from IPython.display import display, clear_output
# Pick the compute device: CUDA when PyTorch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)


import torchvision.models as models
from torchvision.models import (
        EfficientNet_B0_Weights,
        EfficientNet_B1_Weights,
        EfficientNet_B2_Weights,
        EfficientNet_B3_Weights
        )

class EfficientNetFeature(nn.Module):
    """EfficientNet trunk used as a pooled feature extractor.

    Builds one of the torchvision EfficientNet variants (b0-b3) and exposes
    its ``features`` and ``avgpool`` sub-modules. The full network stays
    reachable as ``self.backbone``; its classifier is never applied in
    ``forward`` and is only read to obtain ``feature_dim``.

    Args:
        version: Variant name, one of ``'b0'``, ``'b1'``, ``'b2'``, ``'b3'``.
        pretrained: If truthy, load the variant's ``DEFAULT`` weights;
            otherwise build the network with ``weights=None``.

    Raises:
        ValueError: If ``version`` is not one of the supported variants.
    """

    def __init__(self, version='b0', pretrained=True):
        super().__init__()

        # Reject unknown variants before touching torchvision.
        if version not in ('b0', 'b1', 'b2', 'b3'):
            raise ValueError(f"Unsupported EfficientNet version: {version}")

        # Load the requested EfficientNet: variant -> (constructor, weight enum).
        registry = {
            'b0': (models.efficientnet_b0, EfficientNet_B0_Weights),
            'b1': (models.efficientnet_b1, EfficientNet_B1_Weights),
            'b2': (models.efficientnet_b2, EfficientNet_B2_Weights),
            'b3': (models.efficientnet_b3, EfficientNet_B3_Weights),
        }
        build, weight_enum = registry[version]
        weights = weight_enum.DEFAULT if pretrained else None
        self.backbone = build(weights=weights)

        # Only the feature-extraction part of the network is used.
        self.features = self.backbone.features
        self.avgpool = self.backbone.avgpool

        # Feature width, read from the input size of the classifier's layer 1.
        self.feature_dim = self.backbone.classifier[1].in_features

    def forward(self, x):
        """Resize ``x`` to 224x224, run the trunk and pooling, flatten from dim 1.

        Returns a 2-D tensor whose first dimension is the batch; presumably
        its second dimension equals ``feature_dim`` (depends on the
        torchvision pooling layer, which is not defined here).
        """
        resized = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        pooled = self.avgpool(self.features(resized))
        return torch.flatten(pooled, 1)


class FeatureMixer(nn.Module):
    """Self-attention followed by a small MLP, plus a linear skip connection.

    The input is passed through an 8-head ``nn.MultiheadAttention`` (as
    query, key and value), then through LayerNorm -> Linear -> GELU ->
    LayerNorm -> Dropout(0.1) -> Linear -> GELU. A separate linear
    projection of the raw input is added to that result.

    Args:
        input_features: Width of the incoming features; also used as the
            attention embedding size (``MultiheadAttention`` requires it to
            be divisible by the number of heads, here 8).
        output_features: Width of the returned features.
    """

    def __init__(self, input_features:int, output_features:int):
        super().__init__()

        hidden = input_features // 2

        self.head = nn.MultiheadAttention(embed_dim=input_features, num_heads=8)

        mlp_layers = [
            nn.LayerNorm(input_features),
            nn.Linear(input_features, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),
            nn.Dropout(0.1),
            nn.Linear(hidden, output_features),
            nn.GELU(),
        ]
        self.mixer = nn.Sequential(*mlp_layers)

        # Projects the raw input to the output width for the residual sum.
        self.shortcut = nn.Linear(input_features, output_features)

    def forward(self, x):
        # NOTE(review): the attention layer is built without batch_first, and
        # PyTorch documents a 2-D input as an unbatched (L, E) sequence. If x
        # is [batch, features], rows attend to each other across the batch --
        # confirm this is intended.
        attended, _ = self.head(x, x, x)
        return self.mixer(attended) + self.shortcut(x)


cuda


# Model

In [None]:
class ZClassifier(nn.Module):
    """Classifier whose per-class score is the mean of sampled Gaussian codes.

    An EfficientNet-b0 feature extractor feeds a ``FeatureMixer``; two heads
    then predict, for every class, a mean ``mu`` and a log-variance
    ``logvar``. ``forward`` draws ``latent_dim`` reparameterised samples per
    class and returns their average.

    After each ``forward`` call the following attributes hold the values of
    that call, so callers (e.g. a loss) can read them:
        ``mu``      -- [B, num_classes, 1]
        ``logvar``  -- [B, num_classes, 1]
        ``X``       -- [B, num_classes, latent_dim] sampled codes

    Args:
        beta: Standard deviation of the zero-mean noise distribution that
            the samples are drawn from.
        num_classes: Number of output classes.
        latent_dim: Number of samples drawn per class.
    """

    def __init__(self, beta:float=1.0, num_classes=10, latent_dim=30):
        super().__init__()

        self.num_classes = num_classes
        self.latent_dim = latent_dim

        self.features = EfficientNetFeature("b0")

        self.head = FeatureMixer(self.features.feature_dim, 256)

        self.mu_head = nn.Sequential(
            FeatureMixer(256, 256),
            nn.Linear(256, num_classes))

        self.sigma_head = nn.Sequential(
            FeatureMixer(256, 256),
            nn.Linear(256, num_classes))

        # Noise source for the reparameterised samples: N(0, beta).
        self.epsilon = Normal(0, beta)

    @autocast('cuda')
    def forward(self, x):
        """Return the per-class mean of the sampled codes, shape [B, num_classes]."""
        x = self.features(x)
        x = self.head(x)

        batch = x.size(0)
        self.mu = self.mu_head(x).reshape(batch, self.num_classes, 1)
        self.logvar = self.sigma_head(x).reshape(batch, self.num_classes, 1) + 1e-6

        # logvar is a log-variance (the losses in this file use exp(logvar) as
        # the variance), so the standard deviation is exp(0.5 * logvar).
        # Using exp(logvar) here would scale the noise by the variance.
        sigma = (0.5 * self.logvar).exp()

        sample_shape = torch.Size([batch, self.num_classes, self.latent_dim])
        eps = self.epsilon.rsample(sample_shape).to(self.mu.device)
        self.X = self.mu + sigma * eps  # noisy codes, broadcast over latent_dim

        return self.X.mean(-1)  # average over the latent_dim samples


# Loss fn

In [None]:
class ProbitLoss(nn.Module):
    """Cross-entropy on probit-transformed scores plus a weighted KL term.

    The scores are mapped through the standard-normal CDF, converted back to
    logits (``torch.logit`` with ``eps=1e-6``) and compared with one-hot
    targets. The KL term is ``-0.5 * (1 + logvar - mu^2 - exp(logvar))``
    summed over the last dimension and averaged over the rest.

    Args:
        beta: Weight of the KL term.
    """

    def __init__(self, beta:float=1):
        super().__init__()
        self.cdf = Normal(0, 1).cdf
        self.beta = beta

    def forward(self, input, label, mu, logvar):
        """Return ``cross_entropy + beta * mean(KL)`` as a scalar tensor.

        Args:
            input: Per-class scores; the last dimension is the class axis.
            label: Integer class indices.
            mu: Predicted means.
            logvar: Predicted log-variances.
        """
        num_classes = input.size(-1)
        targets = F.one_hot(label, num_classes=num_classes).float()

        # Score -> probability via the normal CDF -> logit for cross-entropy.
        logits = torch.logit(self.cdf(input), eps=1e-6)
        ce = F.cross_entropy(logits, targets)

        kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_div = -0.5 * torch.sum(kl_terms, dim=-1)

        return ce + self.beta * kl_div.mean()

class LogitLoss(nn.Module):
    """Cross-entropy on the raw scores plus a weighted KL term.

    The scores are used directly as logits against one-hot targets. The KL
    term is ``-0.5 * (1 + logvar - mu^2 - exp(logvar))`` summed over the last
    dimension and averaged over the rest.

    Args:
        beta: Weight of the KL term.
    """

    def __init__(self, beta:float=1):
        super().__init__()
        self.beta = beta

    def forward(self, input, label, mu, logvar):
        """Return ``cross_entropy + beta * mean(KL)`` as a scalar tensor.

        Args:
            input: Per-class logits; the last dimension is the class axis.
            label: Integer class indices.
            mu: Predicted means.
            logvar: Predicted log-variances.
        """
        num_classes = input.size(-1)
        targets = F.one_hot(label, num_classes=num_classes).float()
        ce = F.cross_entropy(input, targets)

        kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_div = -0.5 * torch.sum(kl_terms, dim=-1)

        return ce + self.beta * kl_div.mean()

class zScoreLoss(nn.Module):
    """Negative log-likelihood of class-centred z-scores plus a weighted KL term.

    Class ``k`` is assigned the target mean ``k`` (its index). Each score is
    standardised against its class mean using ``sqrt(exp(logvar))`` as the
    standard deviation, masked so that only the true class keeps its value,
    and scored under a standard normal.

    Args:
        beta: Weight of the KL term.
    """

    def __init__(self, beta:float=1):
        super().__init__()

        self.normal = Normal(0, 1)

        self.beta = beta

    def forward(self, input, label, mu, logvar):
        """Return ``-mean(log_prob(z)) + beta * mean(KL)`` as a scalar tensor.

        Args:
            input: Per-class scores, shape [B, C].
            label: Integer class indices, shape [B].
            mu: Predicted means; trailing dimension of size 1 expected.
            logvar: Predicted log-variances, shape [B, C, 1].
        """
        num_classes = input.size(1)

        # Target mean of class k is k itself: [[0, 1, ..., C-1]].
        class_mu = torch.arange(0, num_classes, device=input.device).unsqueeze(0)
        onehots = F.one_hot(label, num_classes=num_classes).float()

        # Drop only the trailing singleton axis. A bare squeeze() would also
        # remove a batch or class axis of size 1 and silently mis-broadcast
        # (e.g. [B, 1] / [B] -> [B, B]).
        std = logvar.squeeze(-1).exp().sqrt()

        # Non-target entries become 0 and contribute only a constant.
        z_score = ((input - class_mu) / std) * onehots
        log_likelihood = self.normal.log_prob(z_score)

        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)

        return -log_likelihood.mean() + self.beta * kl_div.mean()

# Utility fns

In [None]:
# Module-level gradient scaler shared by every call to train_epoch.
scaler = GradScaler()

def train_epoch(model, loader, loss_fn, optimizer):
    """Train ``model`` for one pass over ``loader`` with mixed precision.

    For every batch: zero the gradients, run the forward pass and loss under
    CUDA autocast, back-propagate the scaled loss, unscale and clip the
    gradient norm to 5.0, then step the optimizer through the scaler.

    Args:
        model: Module whose forward pass sets ``model.mu`` and
            ``model.logvar``; both are handed to ``loss_fn``.
        loader: Iterable of ``(x, y)`` batches.
        loss_fn: Callable invoked as ``loss_fn(output, y, mu, logvar)``.
        optimizer: Optimizer over the model parameters.

    Returns:
        The sum of the per-batch loss values divided by ``len(loader)``.
    """
    model.train()
    running_loss = 0
    for inputs, targets in tqdm(loader, desc='Training', leave=False):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        with autocast("cuda"):
            scores = model(inputs)  # model output for this batch
            loss = loss_fn(scores, targets, model.mu, model.logvar)

        scaler.scale(loss).backward()

        # Gradients must be unscaled before their norm is clipped.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
    return running_loss / len(loader)

def validate(model, loader, criteria):
    """Return the classification accuracy of ``model`` over ``loader``.

    Args:
        model: Module to evaluate. For ``criteria="z-score"`` its forward
            pass must set ``model.logvar`` (read after each batch).
        loader: Iterable of ``(x, y)`` batches.
        criteria: How predictions are derived from the model output:
            ``"logit"`` / ``"probit"`` -- arg-max over the class axis;
            ``"z-score"`` -- the class whose index-valued mean gives the
            highest standard-normal log-density for the standardised score.

    Returns:
        Fraction of correctly classified samples (``correct / total``).

    Raises:
        ValueError: If ``criteria`` is not one of the names above.
    """
    if criteria not in ("logit", "probit", "z-score"):
        raise ValueError(f"Unsupported criteria: {criteria!r}")

    model.eval()
    correct, total = 0, 0
    standard_normal = Normal(0, 1)  # loop-invariant, built once

    with torch.no_grad(), autocast('cuda'):
        for x, y in tqdm(loader, desc='Validating', leave=False):
            x, y = x.to(device), y.to(device)

            Xbar = model(x)  # per-class scores for the batch

            # A membership test is required here: `criteria == "logit" or "probit"`
            # is always truthy and would make the z-score branch unreachable.
            if criteria in ("logit", "probit"):
                preds = Xbar.argmax(dim=-1)
            else:  # "z-score"
                class_mu = torch.arange(0, Xbar.size(1), device=Xbar.device).unsqueeze(0)

                # squeeze(-1) drops only the trailing singleton axis, so a
                # batch of size 1 keeps its batch dimension.
                z_score = (Xbar - class_mu) / model.logvar.squeeze(-1).exp().sqrt()

                preds = standard_normal.log_prob(z_score).argmax(dim=-1)

            correct += (preds == y).sum().item()
            total += y.size(0)

    return correct / total

# Data Load

In [None]:
# --- Data Loader ---
# Both splits get the same preprocessing: convert to a tensor, then normalise
# each channel with the mean/std constants below.
train_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
val_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# CIFAR-100 train / test splits, downloaded into ./data when missing.
train_set = datasets.CIFAR100(
    root="./data", train=True, download=True, transform=train_transform
)
val_set = datasets.CIFAR100(
    root="./data", train=False, download=True, transform=val_transform
)

# Only the training loader shuffles; both use batches of 500.
train_loader = DataLoader(dataset=train_set, batch_size=500, shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=500)

# Train

In [None]:
# --- Run ---
# Build the 100-class model, move it to the selected device, then compile it.
model = ZClassifier(beta=1e-6, num_classes=100, latent_dim=30)
model = torch.compile(model.to(device))
# Adam at lr 1e-4, annealed with a cosine schedule (T_max=100) down to 1e-6.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=100, eta_min=1e-6)
# Training objective: z-score likelihood with the KL term weighted by 0.5.
loss_fn = zScoreLoss(beta=0.5)

In [None]:
# Latest validation accuracy; overwritten each time validation runs.
acc = 0
# Outer progress bar: one tick per epoch, 100 epochs in total.
bar = tqdm(range(100), desc="Training Progress")
for epoch in bar:
    loss = train_epoch(model, train_loader, loss_fn, optimizer)
    scheduler.step()

    # Validate on epochs 0, 10, 20, ...; other epochs reuse the last accuracy.
    if not epoch % 10:
        acc = validate(model, val_loader, criteria="z-score")
    # Show the running numbers in the progress-bar label.
    bar.set_description(f"Epoch {epoch + 1} | Loss: {loss:.4f} | Accuracy: {acc:.4f}")

Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Validating:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

Training:   0%|          | 0/100 [00:00<?, ?it/s]

# Check the Normality

In [None]:
# Re-import required libraries after kernel reset
import torch
import pandas as pd
import seaborn as sns
from scipy.stats import normaltest
import matplotlib.pyplot as plt

def plot_zscore_histograms(X, labels, max_classes=10, plot_all=False):
    """
    Plots histograms of the sampled per-class values of a single example.

    One overlaid histogram (with KDE) is drawn for all classes together.
    With ``plot_all`` a separate histogram is drawn per class, followed by
    the result of ``scipy.stats.normaltest`` on that class's samples.

    Args:
        X: Tensor [1, C, D] - D sampled values for each of C classes
        labels: Tensor [B] - true class labels (currently unused; kept for
            call compatibility)
        max_classes: int - currently unused; kept for call compatibility
        plot_all: bool - also draw one histogram and normality test per class

    Raises:
        ValueError: if X does not hold exactly one example (B != 1).
    """
    B, C, D = X.shape
    if B != 1:
        raise ValueError(f"Expected a single example (B=1), got B={B}")

    # [1, C, D] -> [C, D]. An explicit reshape keeps C and D even when one of
    # them is 1; detach() lets tensors that still require grad be converted.
    Xs = X.detach().reshape(C, D).cpu().numpy()
    display(Xs.shape)

    # Overlaid histogram: one column per class, one row per sample.
    plt.figure(figsize=(10, 6))
    df = pd.DataFrame(Xs.T, columns=[f"Class {i}" for i in range(C)])
    # Long format (Class, Z-score) so seaborn can colour the bars by class.
    df_melted = df.melt(var_name='Class', value_name='Z-score')
    sns.histplot(data=df_melted, x='Z-score', hue='Class', alpha=0.2, bins=40, stat="percent", kde=True, legend=False)
    plt.title("Histogram of All Gaussians")
    plt.xlabel("feature space")
    plt.ylabel("Percent")  # matches stat="percent"
    plt.grid(True)
    plt.show()

    if plot_all:
        for i in range(C):
            plt.figure(figsize=(10, 6))
            sns.histplot(Xs[i], alpha=0.2, bins=10, stat="probability", kde=True)
            plt.title(f"Histogram of {i}-th Class Gaussians")
            plt.xlabel("feature space")
            plt.ylabel("Probability")  # matches stat="probability"
            plt.grid(True)
            plt.show()
            display(f"Normal test result for class {i}: {normaltest(Xs[i])}")


In [None]:
# Push one validation image through the model so that model.X holds the
# sampled codes of that image, then plot them.
loader = DataLoader(dataset=val_set, batch_size=1)
model.eval()
with torch.no_grad():
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        X_bar = model(x)
        break  # only the first sample is needed
plot_zscore_histograms(X=model.X, labels=y, max_classes=100, plot_all=False)

In [None]:
# Largest value of the sampled codes after averaging over their last
# dimension; model.X is the tensor stored by the most recent forward pass.
model.X.mean(-1).max()