Mount your Google Drive to access the datasets (Colab users only).

In [None]:
# Colab-only: mount Google Drive so datasets stored there are reachable under /content/drive.
# Outside Colab the import fails; skip mounting instead of aborting a full notebook run.
try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is not None:
    drive.mount('/content/drive')

Mounted at /content/drive


# Import Libraries

In [None]:
import math
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import wandb
from sklearn.cluster import KMeans
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm import tqdm

# Common Functions

In [None]:
class NPYAuxDataset(Dataset):
    """Map-style dataset over an array stored in a single ``.npy`` file.

    Item ``idx`` is ``data[idx]`` (indexing along the first axis), optionally passed
    through ``transform``. No label is returned, which suits an unlabeled
    auxiliary/outlier image set.

    Args:
        npy_file: Path to the ``.npy`` file.
        transform: Optional callable applied to each item before it is returned.
        mmap_mode: Forwarded to ``np.load``. ``None`` (default, previous behavior) reads
            the whole array into memory; e.g. ``"r"`` memory-maps the file so only the
            items actually accessed are read from disk, which helps for large files.
    """

    def __init__(self, npy_file, transform=None, mmap_mode=None):
        self.data = np.load(npy_file, mmap_mode=mmap_mode)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = self.data[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img

def validate(model, dataloader, device):
    """Compute top-1 accuracy, in percent, of ``model`` on ``dataloader``.

    The model is switched to eval mode (and left there), and inference runs under
    ``torch.no_grad()``. Each batch is unpacked as an ``(inputs, labels)`` pair and
    moved to ``device``.
    """
    model.eval()
    n_right = 0
    n_seen = 0
    with torch.no_grad():
        for batch_x, batch_y in tqdm(dataloader, desc="Validation", leave=False):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            # Index of the largest logit per row is the predicted class.
            predictions = model(batch_x).max(dim=1).indices
            n_right += (predictions == batch_y).sum().item()
            n_seen += batch_y.size(0)
    return 100 * n_right / n_seen

# Loss Terms

In [None]:
class energy_loss(nn.Module):
    """Squared-hinge margin loss on energy scores.

    ID scores at or above ``id_threshold`` and OOD scores at or below
    ``ood_threshold`` are penalized by their squared distance to the respective
    threshold. Each term is averaged over its whole batch, not just the violating
    samples, and the two averages are summed.
    """

    def __init__(self, id_threshold, ood_threshold):
        super().__init__()
        self.id_threshold = id_threshold
        self.ood_threshold = ood_threshold

    def forward(self, id_scores, ood_scores):
        # Distance past the margin, zeroed (by a 0/1 float mask) for samples that satisfy it.
        id_violation = (id_scores - self.id_threshold) * (id_scores >= self.id_threshold).float()
        ood_violation = (self.ood_threshold - ood_scores) * (ood_scores <= self.ood_threshold).float()
        return id_violation.pow(2).mean() + ood_violation.pow(2).mean()

class gradient_regularization(nn.Module):
    """Penalty on the input-gradient norm of energy scores for samples already past their margin.

    For each sample, the L2 norm of d(score)/d(input) is computed, with a differentiable
    graph so the penalty can itself be backpropagated. The norm is kept only for:
      * ID samples with score <= ``id_threshold``, and
      * OOD samples with score >= ``ood_threshold``,
    i.e. samples already on the correct side of (or exactly at) their margin. Samples that
    violate the margin are left to the energy loss. Each term is averaged over its whole
    batch, and the two averages are summed.
    """

    def __init__(self, id_threshold, ood_threshold):
        super(gradient_regularization, self).__init__()
        self.id_threshold = id_threshold
        self.ood_threshold = ood_threshold

    def forward(self, id_scores, ood_scores, id_outputs, ood_outputs):
        """Return the combined gradient penalty.

        Args:
            id_scores, ood_scores: Per-sample scores (one value per batch element).
            id_outputs, ood_outputs: Tensors the scores are differentiated with respect to.
                Despite the names, the training loop passes the *input images* here; they
                must have ``requires_grad`` set and the scores must depend on them.
        """
        id_grad_norm = self._grad_norm(id_scores, id_outputs)
        ood_grad_norm = self._grad_norm(ood_scores, ood_outputs)

        # Both masks select samples that already satisfy their margin. Previously the OOD
        # mask used `<=`, which selected margin-violating OOD samples. That penalized exactly
        # the samples the energy loss must move, and it was inconsistent with the ID mask.
        id_mask = (id_scores <= self.id_threshold).float()
        ood_mask = (ood_scores >= self.ood_threshold).float()

        id_grad_loss = torch.mean(id_grad_norm * id_mask)
        ood_grad_loss = torch.mean(ood_grad_norm * ood_mask)

        return id_grad_loss + ood_grad_loss

    @staticmethod
    def _grad_norm(scores, inputs):
        """Per-sample L2 norm of d(scores)/d(inputs), keeping the graph for double backprop."""
        grads = torch.autograd.grad(outputs=scores, inputs=inputs, grad_outputs=torch.ones_like(scores),
                                    retain_graph=True, create_graph=True)[0]
        return torch.norm(grads.reshape(grads.size(0), -1), dim=1)

# Model, Datasets and Loss Function

Download the ImageNet-pretrained ResNet-18 and replace its final layer with a 10-class head for CIFAR-10.

In [None]:
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NUM_CLASSES = 10  # CIFAR-10

# ImageNet-pretrained ResNet-18. IMAGENET1K_V1 is the checkpoint the deprecated
# `pretrained=True` flag loaded. The classification head is replaced for CIFAR-10.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)
model = model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 142MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

Transformations for ID and auxiliary dataloaders

In [None]:
# The backbone was pretrained on ImageNet, so inputs are resized and normalized to match.
# These values are shared by both pipelines so the two cannot drift apart.
INPUT_SIZE = (224, 224)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Auxiliary images come out of the .npy dataset as numpy arrays, so convert them to PIL first.
transform_aux = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

transform_cifar = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

Create datasets

In [None]:
# Placeholders: point these at your local copies (e.g. under /content/drive on Colab).
# CIFAR-10 directory; with download=True the dataset is downloaded here if missing.
cifar10_root = '/path/to/cifar10'
# Path to the .npy FILE holding the auxiliary outlier images (a file, not a directory).
randomimages300k_root = '/path/to/randomimages300k.npy'

In [None]:
# In-distribution data: the CIFAR-10 train split for training, the test split for validation.
cifar10_train, cifar10_test = (
    datasets.CIFAR10(root=cifar10_root, train=split_is_train, download=True, transform=transform_cifar)
    for split_is_train in (True, False)
)

# Unlabeled auxiliary outliers, read from a single .npy file.
randomimages300k_dataset = NPYAuxDataset(npy_file=randomimages300k_root, transform=transform_aux)

Files already downloaded and verified
Files already downloaded and verified


Create dataloaders

In [None]:
batch_size = 64

# CIFAR-10: shuffled for training, fixed order for evaluation.
id_dataloader = DataLoader(dataset=cifar10_train, shuffle=True, batch_size=batch_size)
id_test_dataloader = DataLoader(dataset=cifar10_test, shuffle=False, batch_size=batch_size)

# Auxiliary outliers, drawn in random order alongside each ID batch during training.
aux_dataloader = DataLoader(dataset=randomimages300k_dataset, shuffle=True, batch_size=batch_size)

Define optimizer and scheduler

In [None]:
# SGD with momentum and a small weight decay over all model parameters.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0001,
)
# Cosine-anneal the LR from 0.01 toward 0.001 over 20 scheduler steps. The training loop
# steps once per epoch, so T_max should be kept equal to num_epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=0.001)

Weights & Biases (optional) for logging training statistics.

In [None]:
wandb.login()
wandb.init(
    project="Project Name",
    name="Name of the Run",
    config={
        # Read the LR back from the optimizer so the logged config cannot drift from the real one.
        "learning_rate": optimizer.defaults["lr"],
        "architecture": "Resnet18",
        "dataset": "CIFAR-10",
        "epochs": 20,  # NOTE: keep in sync with num_epochs in the training cell
    },
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mcengizcenkkerem[0m ([33mcengizcenkkerem-metu-middle-east-technical-university[0m). Use [1m`wandb login --relogin`[0m to force relogin


Create loss instances

In [None]:
# Energy margins shared by both terms so they cannot drift apart. The energy loss penalizes
# ID energies at or above M_IN and auxiliary OOD energies at or below M_OUT.
M_IN = -27
M_OUT = -5

energy_l = energy_loss(id_threshold=M_IN, ood_threshold=M_OUT)  # Energy Loss
gradient_l = gradient_regularization(id_threshold=M_IN, ood_threshold=M_OUT)  # Gradient Regularization Term
criterion_ce = nn.CrossEntropyLoss()  # Cross-Entropy Loss

# Training

In [None]:
num_epochs = 20
ENERGY_LOSS_WEIGHT = 0.1  # weight of the energy-margin term in the total loss
GRAD_REG_WEIGHT = 1.0     # weight of the input-gradient regularization term

# Training Loop
for epoch in range(num_epochs):

    # Per-epoch accumulators for the individual loss terms
    total_ce_loss = 0.0
    total_energy_loss = 0.0
    total_gradient_loss = 0.0
    num_batches = 0
    epoch_loss = 0.0

    model.train()
    aux_iter = iter(aux_dataloader)  # Iterator for auxiliary OOD data

    # Record the LR actually used for THIS epoch. After scheduler.step() below,
    # get_last_lr() already reports the next epoch's value.
    epoch_lr = scheduler.get_last_lr()[0]

    # Wrap the batch loop with tqdm for progress tracking
    batch_loop = tqdm(id_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

    for id_inputs, id_labels in batch_loop:
        # One auxiliary batch per ID batch; restart the aux loader whenever it is exhausted.
        try:
            aux_inputs = next(aux_iter)
        except StopIteration:
            aux_iter = iter(aux_dataloader)
            aux_inputs = next(aux_iter)

        id_inputs, id_labels, aux_inputs = id_inputs.to(device), id_labels.to(device), aux_inputs.to(device)
        # The gradient regularizer differentiates the energy w.r.t. the input images.
        id_inputs.requires_grad_(True)
        aux_inputs.requires_grad_(True)

        # Forward Pass for ID Data; energy score E(x) = -logsumexp(logits)
        id_outputs = model(id_inputs)
        id_energy_scores = -torch.logsumexp(id_outputs, dim=1)

        # Forward Pass for Auxiliary OOD Data
        aux_outputs = model(aux_inputs)
        aux_energy_scores = -torch.logsumexp(aux_outputs, dim=1)

        # Compute Loss
        ce_loss_value = criterion_ce(id_outputs, id_labels)
        energy_loss_value = energy_l(id_energy_scores, aux_energy_scores)
        gradient_loss_value = gradient_l(id_energy_scores, aux_energy_scores, id_inputs, aux_inputs)

        # Total Loss with Weighted Contributions
        total_loss = (ce_loss_value
                      + ENERGY_LOSS_WEIGHT * energy_loss_value
                      + GRAD_REG_WEIGHT * gradient_loss_value)

        # Backward Pass and Optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Convert each loss to a Python float once per batch (each .item() is a device sync).
        ce_item = ce_loss_value.item()
        energy_item = energy_loss_value.item()
        grad_item = gradient_loss_value.item()
        total_item = total_loss.item()

        # Accumulate Losses
        total_ce_loss += ce_item
        total_energy_loss += energy_item
        total_gradient_loss += grad_item
        epoch_loss += total_item
        num_batches += 1

        # Update tqdm description with current batch loss
        batch_loop.set_postfix({
            "CE Loss": ce_item,
            "Energy Loss": energy_item,
            "Grad Loss": grad_item,
            "Total Loss": total_item
        })

    # Step the Scheduler
    scheduler.step()

    # Calculate average losses for the epoch
    avg_ce_loss = total_ce_loss / num_batches
    avg_energy_loss = total_energy_loss / num_batches
    avg_gradient_loss = total_gradient_loss / num_batches
    avg_total_loss = epoch_loss / num_batches

    # Validate the model
    val_accuracy = validate(model, id_test_dataloader, device)

    # Log Metrics to Weights & Biases
    wandb.log({
        "avg_ce_loss": avg_ce_loss,
        "avg_energy_loss": avg_energy_loss,
        "avg_gradient_loss": avg_gradient_loss,
        "avg_total_loss": avg_total_loss,
        "validation_accuracy": val_accuracy,
        "learning_rate": epoch_lr
    })

wandb.finish()



0,1
avg_ce_loss,█▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
avg_energy_loss,█▃▃▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁
avg_gradient_loss,▇▇█▆▄▄▃▃▂▂▃▂▁▁▂▁▁▁▁▁
avg_total_loss,█▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
learning_rate,███▇▇▇▆▆▅▅▄▃▃▂▂▂▁▁▁▁
validation_accuracy,▁▃▄▅▄▄▆▆▆▆▇▇▇▇▇█████

0,1
avg_ce_loss,1.19443
avg_energy_loss,2.24612
avg_gradient_loss,0.14926
avg_total_loss,1.5683
learning_rate,0.001
validation_accuracy,53.71


Save the model

In [None]:
# Create the target directory first; torch.save does not create missing parent directories.
weights_path = Path("/path/to/save/model/weights.pt")
weights_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), weights_path)