# Import Libraries


In [None]:
import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import normalize
from torchvision import datasets, transforms
import torchvision.models as models
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import wandb
import random
import math
from sklearn.cluster import MiniBatchKMeans

# Common Functions

In [None]:
class NPYAuxDataset(Dataset):
    """Map-style dataset over an array stored in a single ``.npy`` file.

    The whole array is read into memory once with ``np.load`` when the
    dataset is constructed. Item ``idx`` is ``data[idx]`` (indexing along the
    first axis), optionally passed through ``transform``. No labels are
    returned.

    Args:
        npy_file: Path to the ``.npy`` file.
        transform: Optional callable applied to every item before it is
            returned (skipped when falsy).
    """

    def __init__(self, npy_file, transform=None):
        self.data = np.load(npy_file)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return self.transform(sample) if self.transform else sample

def validate(model, dataloader, device):
    """Return the top-1 accuracy (percent) of ``model`` over ``dataloader``.

    The model is evaluated in eval mode without gradient tracking. Its
    previous train/eval mode is restored afterwards, even if an exception is
    raised, so calling this mid-training does not leave the network in eval
    mode.

    Args:
        model: Classifier whose output has class scores along dim 1.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Accuracy as a float in ``[0, 100]``.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels in tqdm(dataloader, desc="Validation", leave=False):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    if total == 0:
        raise ValueError("validate() received a dataloader with no samples")
    return 100 * correct / total

# Loss Terms and Energy-Based Sampling

In [None]:
class energy_loss(nn.Module):
    """Squared-hinge penalty on energy scores that violate their margin.

    In-distribution scores at or above ``id_threshold`` and auxiliary/OOD
    scores at or below ``ood_threshold`` are penalised by their squared
    distance to the respective threshold. Samples that already satisfy their
    margin contribute zero. Each term is averaged over its whole batch, and
    the two means are summed.

    Args:
        id_threshold: Upper margin for in-distribution scores.
        ood_threshold: Lower margin for auxiliary/OOD scores.
    """

    def __init__(self, id_threshold, ood_threshold):
        super().__init__()
        self.id_threshold = id_threshold
        self.ood_threshold = ood_threshold

    @staticmethod
    def _masked_sq_mean(gap, active):
        # Zero the gap wherever the margin is satisfied, square, then average
        # over every sample (inactive ones included).
        return (gap * active.float()).pow(2).mean()

    def forward(self, id_scores, ood_scores):
        id_term = self._masked_sq_mean(
            id_scores - self.id_threshold,
            id_scores >= self.id_threshold,
        )
        ood_term = self._masked_sq_mean(
            self.ood_threshold - ood_scores,
            ood_scores <= self.ood_threshold,
        )
        return id_term + ood_term

class gradient_regularization(nn.Module):
    """Penalty on the per-sample gradient norm of energy scores below a threshold.

    For each group (in-distribution and auxiliary/OOD), the gradient of the
    scores with respect to a given tensor is computed with
    ``create_graph=True``, so the penalty itself can be back-propagated into
    the model parameters. The per-sample L2 norm of that gradient is averaged
    over the batch. Only samples whose score is at or below the group's
    threshold contribute; the others count as zero in the mean.

    Args:
        id_threshold: Score threshold for the in-distribution group.
        ood_threshold: Score threshold for the auxiliary/OOD group.
    """

    def __init__(self, id_threshold, ood_threshold):
        super(gradient_regularization, self).__init__()
        self.id_threshold = id_threshold
        self.ood_threshold = ood_threshold

    @staticmethod
    def _per_sample_grad_norm(scores, wrt):
        """Return the L2 norm of d(scores)/d(wrt) for every sample (dim 0 of ``wrt``)."""
        grads = torch.autograd.grad(
            outputs=scores,
            inputs=wrt,
            grad_outputs=torch.ones_like(scores),
            retain_graph=True,
            create_graph=True,
        )[0]
        # reshape rather than view: the gradient tensor can be non-contiguous
        # (e.g. for channels_last inputs), in which case .view() raises.
        return torch.norm(grads.reshape(grads.size(0), -1), dim=1)

    def forward(self, id_scores, ood_scores, id_outputs, ood_outputs):
        """Compute the masked mean gradient-norm penalty for both groups.

        Args:
            id_scores: One score per in-distribution sample, computed from
                ``id_outputs``.
            ood_scores: One score per auxiliary/OOD sample, computed from
                ``ood_outputs``.
            id_outputs: Tensor the ID scores are differentiated with respect
                to; it must require grad. In this notebook's training loop it
                is the input batch itself, despite the parameter name.
            ood_outputs: Same, for the auxiliary/OOD group.

        Returns:
            Scalar tensor: ID term plus OOD term.
        """
        id_grad_norm = self._per_sample_grad_norm(id_scores, id_outputs)
        ood_grad_norm = self._per_sample_grad_norm(ood_scores, ood_outputs)

        id_mask = (id_scores <= self.id_threshold).float()
        ood_mask = (ood_scores <= self.ood_threshold).float()

        id_grad_loss = torch.mean(id_grad_norm * id_mask)
        ood_grad_loss = torch.mean(ood_grad_norm * ood_mask)

        return id_grad_loss + ood_grad_loss

def _fit_aux_kmeans(aux_dataloader, feature_extractor, num_clusters, device):
    """Fit a MiniBatchKMeans model incrementally on the features of every auxiliary batch."""
    kmeans = MiniBatchKMeans(n_clusters=num_clusters, batch_size=1024)
    with torch.no_grad():
        for aux_batch in tqdm(aux_dataloader, desc="Partial Fitting"):
            batch_features = feature_extractor(aux_batch.to(device)).cpu().numpy()
            # NOTE(review): scikit-learn requires the first partial_fit call to
            # see at least num_clusters samples — confirm the first batch does.
            kmeans.partial_fit(batch_features)
    return kmeans


def _select_extreme_energy_samples(aux_dataloader, feature_extractor, model, kmeans, num_clusters, device):
    """Return per-cluster lists of the min- and max-energy samples (NumPy copies, or None)."""
    min_energy_per_cluster = [float("inf")] * num_clusters
    max_energy_per_cluster = [float("-inf")] * num_clusters
    min_sample_per_cluster = [None] * num_clusters
    max_sample_per_cluster = [None] * num_clusters

    with torch.no_grad():
        for aux_batch in tqdm(aux_dataloader, desc="Min/Max Sample Extraction"):
            aux_batch = aux_batch.to(device)
            cluster_labels = kmeans.predict(feature_extractor(aux_batch).cpu().numpy())
            # Energy score: -logsumexp over the model outputs along dim 1.
            batch_energy_scores = -torch.logsumexp(model(aux_batch), dim=1).cpu().numpy()
            # One device-to-host transfer per batch instead of one per sample.
            batch_samples = aux_batch.cpu().numpy()

            for i, cluster_id in enumerate(cluster_labels):
                energy = batch_energy_scores[i]
                if energy < min_energy_per_cluster[cluster_id]:
                    min_energy_per_cluster[cluster_id] = energy
                    # .copy() so a kept sample does not pin the whole batch in memory.
                    min_sample_per_cluster[cluster_id] = batch_samples[i].copy()
                if energy > max_energy_per_cluster[cluster_id]:
                    max_energy_per_cluster[cluster_id] = energy
                    max_sample_per_cluster[cluster_id] = batch_samples[i].copy()

    return min_sample_per_cluster, max_sample_per_cluster


def energy_based_sampling(aux_dataloader, feature_extractor, model, num_clusters, device):
    """Pick the lowest- and highest-energy auxiliary sample of every k-means cluster.

    Two full passes over ``aux_dataloader`` are made:

    1. ``MiniBatchKMeans(n_clusters=num_clusters)`` is fitted incrementally
       (``partial_fit``) on the features returned by ``feature_extractor``.
    2. Every sample is assigned to a cluster and scored with
       ``-logsumexp(model(x), dim=1)``; the min- and max-energy sample of each
       cluster is kept.

    Both passes run under ``torch.no_grad()``. Any ``nn.Module`` passed as
    ``feature_extractor`` or ``model`` is temporarily put in eval mode, so
    BatchNorm layers use their running statistics instead of overwriting them
    with auxiliary-data statistics. The previous mode is restored on exit. A
    plain callable ``feature_extractor`` that wraps some other module does not
    have that module's mode changed.

    Args:
        aux_dataloader: Iterable of unlabeled input batches (tensors).
        feature_extractor: Callable mapping a batch to a 2-D feature tensor.
        model: Module mapping a batch to logits.
        num_clusters: Number of k-means clusters (and length of each result list).
        device: Device each batch is moved to.

    Returns:
        ``(min_samples, max_samples)``: two lists of length ``num_clusters``.
        Each entry is a NumPy copy of one input sample, or ``None`` for a
        cluster that received no sample during the second pass.
    """
    modules = [m for m in (feature_extractor, model) if isinstance(m, nn.Module)]
    previous_modes = [m.training for m in modules]
    for m in modules:
        m.eval()
    try:
        kmeans = _fit_aux_kmeans(aux_dataloader, feature_extractor, num_clusters, device)
        return _select_extreme_energy_samples(
            aux_dataloader, feature_extractor, model, kmeans, num_clusters, device
        )
    finally:
        for m, was_training in zip(modules, previous_modes):
            m.train(was_training)



# Model Architectures

In [None]:
class ResNet18(nn.Module):
    """Randomly initialised torchvision ResNet-18 with a fresh linear head.

    Every child of the torchvision model except its final ``fc`` layer is
    kept, in order, as the ``features`` trunk. ``fc`` maps the flattened
    512-dimensional trunk output to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        backbone = models.resnet18(pretrained=False)
        trunk_layers = list(backbone.children())[:-1]
        self.features = nn.Sequential(*trunk_layers)
        self.fc = nn.Linear(512, num_classes)

    def get_features(self, x):
        """Return the trunk output flattened to ``(batch, -1)``."""
        trunk_out = self.features(x)
        return trunk_out.view(trunk_out.size(0), -1)

    def forward(self, x):
        """Return class logits for the batch ``x``."""
        return self.fc(self.get_features(x))

class DenseNet(nn.Module):
    """Randomly initialised torchvision DenseNet-121 with a fresh linear head.

    ``features`` is the torchvision DenseNet-121 feature stack. Its output
    goes through ReLU and global average pooling, is flattened, and is mapped
    to ``num_classes`` logits by ``classifier``.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        base_model = models.densenet121(pretrained=False)
        self.features = base_model.features
        self.classifier = nn.Linear(base_model.classifier.in_features, num_classes)

    @property
    def fc(self):
        """Read-only alias of ``classifier``, matching ``ResNet18.fc``.

        This lets code that reaches the head via ``.fc`` work with either
        architecture. Being a property, it adds no entries to ``state_dict``.
        """
        return self.classifier

    def get_features(self, x):
        """Return pooled features of shape ``(batch, classifier.in_features)``."""
        features = self.features(x)
        # torchvision's DenseNet ``features`` stack ends with a BatchNorm
        # (norm5); the reference DenseNet.forward applies ReLU before pooling,
        # so it must be applied here too.
        features = nn.functional.relu(features)
        features = nn.functional.adaptive_avg_pool2d(features, (1, 1))
        return torch.flatten(features, 1)

    def forward(self, x):
        """Return class logits for the batch ``x``."""
        return self.classifier(self.get_features(x))

# Model, Datasets, and Loss Functions

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Backbone to train: "resnet18" or "densenet121". Only the selected
# architecture is instantiated.
architecture = "densenet121"
num_classes = 10  # adjust for the number of in-distribution classes

model_builders = {"resnet18": ResNet18, "densenet121": DenseNet}
model = model_builders[architecture](num_classes=num_classes)

# Wrap for multi-GPU training; the bare network is reachable as model.module.
model = nn.DataParallel(model)
model = model.to(device).float()




In [None]:
# Transformations
# Both pipelines resize to 224x224 and normalise with the ImageNet mean/std.
# RandomCrop(224) on an image that is already exactly 224x224 would be a
# no-op, so the crop pads by 28 px (CIFAR's usual 4 px scaled by 224/32) to
# give real translation jitter.

# Auxiliary images arrive as NumPy arrays, hence ToPILImage first
# (presumably HxWxC uint8 — TODO confirm against the .npy file).
transform_aux = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=28),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

transform_cifar = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=28),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

In [None]:
# Create datasets
cifar10_root = '/path/to/dataset'
cifar100_root = '/path/to/dataset'
# Path to the Random Images 300K .npy file (read with np.load by
# NPYAuxDataset), not a directory.
randomimages300k_root = '/path/to/dataset'

# Evaluation transform: same resize and normalisation as training, but
# without random flips/crops, so test-set accuracy is deterministic.
transform_cifar_eval = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

cifar10_train = datasets.CIFAR10(root=cifar10_root, train=True, download=True, transform=transform_cifar)
cifar10_test = datasets.CIFAR10(root=cifar10_root, train=False, download=True, transform=transform_cifar_eval)
cifar100_train = datasets.CIFAR100(root=cifar100_root, train=True, download=True, transform=transform_cifar)
cifar100_test = datasets.CIFAR100(root=cifar100_root, train=False, download=True, transform=transform_cifar_eval)
randomimages300k_dataset = NPYAuxDataset(randomimages300k_root, transform=transform_aux)

# Create dataloaders
batch_size = 64
id_dataloader = DataLoader(cifar10_train, batch_size=batch_size, shuffle=True, num_workers=16)
id_test_dataloader = DataLoader(cifar10_test, batch_size=batch_size, shuffle=False, num_workers=16)
aux_dataloader = DataLoader(randomimages300k_dataset, batch_size=batch_size, shuffle=True, num_workers=16)


Files already downloaded and verified
Files already downloaded and verified


In [None]:
# SGD with momentum and L2 weight decay over every model parameter
# (constant learning rate; no scheduler is attached).
optimizer = optim.SGD(
    params=model.parameters(),
    lr=1e-1,
    momentum=0.9,
    weight_decay=1e-4,
)

In [None]:
# Loss functions
# Energy margins shared by the energy hinge loss and the gradient penalty.
ID_ENERGY_THRESHOLD = -27
OOD_ENERGY_THRESHOLD = -5

criterion_ce = nn.CrossEntropyLoss()
energy_l = energy_loss(
    id_threshold=ID_ENERGY_THRESHOLD,
    ood_threshold=OOD_ENERGY_THRESHOLD,
)
gradient_l = gradient_regularization(
    id_threshold=ID_ENERGY_THRESHOLD,
    ood_threshold=OOD_ENERGY_THRESHOLD,
)

# Training

In [None]:
num_epochs = 50

for epoch in range(num_epochs):
    total_ce_loss = 0.0
    total_energy_loss = 0.0
    total_gradient_loss = 0.0
    epoch_loss = 0.0
    num_batches = 0

    # --- Energy-based sampling of auxiliary outliers (once per epoch) ---
    # Score the auxiliary set in eval mode: in train mode every forward pass
    # updates BatchNorm running statistics with auxiliary-image statistics,
    # even under torch.no_grad().
    model.eval()
    min_energy_samples, max_energy_samples = energy_based_sampling(
        aux_dataloader=aux_dataloader,
        feature_extractor=model.module.get_features,
        model=model.module,  # the network inside the DataParallel wrapper
        num_clusters=batch_size,
        device=device,
    )

    # Clusters that received no sample come back as None and are dropped.
    # The remaining samples form two fixed auxiliary batches (at most
    # batch_size each), reused for every in-distribution batch of this epoch.
    min_energy_samples = torch.from_numpy(
        np.stack([s for s in min_energy_samples if s is not None], axis=0)
    ).float().to(device)
    max_energy_samples = torch.from_numpy(
        np.stack([s for s in max_energy_samples if s is not None], axis=0)
    ).float().to(device)
    # Only the max-energy batch is differentiated with respect to its inputs
    # (gradient regularisation); the min-energy batch needs no input grads.
    max_energy_samples.requires_grad_(True)

    model.train()
    batch_loop = tqdm(id_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

    for id_inputs, id_labels in batch_loop:
        id_inputs, id_labels = id_inputs.to(device), id_labels.to(device)
        # Input gradients are required by the gradient regularisation term.
        id_inputs.requires_grad_(True)

        # Forward pass
        id_outputs = model(id_inputs)
        aux_min_outputs = model(min_energy_samples)
        aux_max_outputs = model(max_energy_samples)

        # Energy score: -logsumexp over the logits.
        id_energy_scores = -torch.logsumexp(id_outputs, dim=1)
        aux_min_energy_scores = -torch.logsumexp(aux_min_outputs, dim=1)
        aux_max_energy_scores = -torch.logsumexp(aux_max_outputs, dim=1)

        # Losses
        ce_loss_value = criterion_ce(id_outputs, id_labels)
        energy_loss_value = energy_l(id_energy_scores, aux_min_energy_scores)
        gradient_loss_value = gradient_l(
            id_energy_scores, aux_max_energy_scores, id_inputs, max_energy_samples
        )
        total_loss = ce_loss_value + 0.1 * energy_loss_value + 1.0 * gradient_loss_value

        # One host sync per scalar, reused for accumulation and the progress bar.
        ce_item = ce_loss_value.item()
        energy_item = energy_loss_value.item()
        grad_item = gradient_loss_value.item()
        total_item = total_loss.item()

        total_ce_loss += ce_item
        total_energy_loss += energy_item
        total_gradient_loss += grad_item
        epoch_loss += total_item
        num_batches += 1

        # Optimization step
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        batch_loop.set_postfix({
            "CE Loss": ce_item,
            "Energy Loss": energy_item,
            "Grad Loss": grad_item,
            "Total Loss": total_item,
        })

    # Average losses over the epoch
    avg_ce_loss = total_ce_loss / num_batches
    avg_energy_loss = total_energy_loss / num_batches
    avg_gradient_loss = total_gradient_loss / num_batches
    avg_total_loss = epoch_loss / num_batches

    # Validate
    val_accuracy = validate(model, id_test_dataloader, device)

    # Report epoch metrics; log them to Weights & Biases only when a run has
    # been started with wandb.init().
    epoch_metrics = {
        "avg_ce_loss": avg_ce_loss,
        "avg_energy_loss": avg_energy_loss,
        "avg_gradient_loss": avg_gradient_loss,
        "avg_total_loss": avg_total_loss,
        "learning_rate": optimizer.param_groups[0]["lr"],
        "validation_accuracy": val_accuracy,
    }
    tqdm.write(
        f"Epoch {epoch + 1}/{num_epochs}: "
        + ", ".join(f"{name}={value:.4f}" for name, value in epoch_metrics.items())
    )
    if wandb.run is not None:
        wandb.log(epoch_metrics)


Partial Fitting: 100%|██████████| 4688/4688 [01:38<00:00, 47.83it/s]    
Min/Max Sample Extraction: 100%|██████████| 4688/4688 [02:06<00:00, 37.12it/s]
Partial Fitting: 100%|██████████| 4688/4688 [01:37<00:00, 48.10it/s]                                                               
Min/Max Sample Extraction: 100%|██████████| 4688/4688 [02:05<00:00, 37.24it/s]
Partial Fitting: 100%|██████████| 4688/4688 [01:36<00:00, 48.39it/s]                                                                
Min/Max Sample Extraction: 100%|██████████| 4688/4688 [02:06<00:00, 37.18it/s]
Partial Fitting: 100%|██████████| 4688/4688 [01:37<00:00, 48.02it/s]                                                                
Min/Max Sample Extraction: 100%|██████████| 4688/4688 [02:06<00:00, 37.05it/s]
Partial Fitting: 100%|██████████| 4688/4688 [01:38<00:00, 47.56it/s]                                                               
Min/Max Sample Extraction: 100%|██████████| 4688/4688 [02:06<00:00, 37.14it/s]
Pa

0,1
avg_ce_loss,█▅▄▄▄▃▅▃▃▃▃▂▂▂▂▂▂▂▃▂▂▂▁▁▁▁▁▁▁▁▂▂▂▁▁▁▁▂▁▁
avg_energy_loss,▃▃▂▂▂▂█▂▃▂▂▁▂▁▁▁▂▁▂▂▁▂▁▁▁▁▁▂▂▁▁▂▂▁▁▂▁▂▁▁
avg_gradient_loss,▆▁▁▅▃▂▅▃▂▄▃▄▃▃▃▅▃▄▂▃▅▃▃▄▄▃▃▄▄▅▄▅▄▆▄▃▆█▇▅
avg_total_loss,▆▅▃▄▃▃█▃▂▃▂▂▂▂▂▂▂▂▁▃▁▂▁▁▁▁▁▁▂▂▃▂▂▁▁▂▁▂▁▁
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_accuracy,▁▁▂▃▃▄▄▂▅▅▅▆▆▆▃▇▇▅▆▆▆▇▇▇█▅██▇█▇▇▆▆▇▇▇▇██

0,1
avg_ce_loss,2.38184
avg_energy_loss,4.32451
avg_gradient_loss,0.17724
avg_total_loss,2.99154
learning_rate,0.1
validation_accuracy,38.13


In [None]:
# Save model
# Save the weights of the unwrapped network so checkpoint keys carry no
# "module." prefix (added by nn.DataParallel). They then load directly into a
# bare ResNet18/DenseNet via load_state_dict.
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(model_to_save.state_dict(), "/path/to/save/model")