In [2]:
import argparse
import json
import os

import utility, train, density_estimation, ood_detection, conf_calibration 
from utility import *
from train import *
from density_estimation import *
from ood_detection import * 
from conf_calibration import *

ID_dataset = "CIFAR-10"
batch_size = 64
val_size = 0.05
val_seed = 99
num_classes = 10
embedding_dim = 512
learning_rate = 1e-3
dropout_rate = 0.5
reg_param = 5e-2
num_epochs = 100
index = 0
device = "cuda:0"
ouput_dir = "saved_results"
pretrained = False

In [3]:
trainloader, validloader, testloader, ood_loader1, ood_loader2 = load_datasets(ID_dataset, batch_size, val_size)

Files already downloaded and verified
Using downloaded and verified file: ./data\test_32x32.mat
Files already downloaded and verified


In [7]:
resnet

<function utility.resnet()>

In [5]:
#model = load_model(ID_dataset, pretrained, index, dropout_rate, device) 
model = resnet(num_classes)
model

TypeError: resnet() takes 0 positional arguments but 1 was given

In [23]:
"""
Code for training & evaluation of the model
"""

import os
import math
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm  
import  warnings
warnings.filterwarnings('ignore')
from itertools import cycle

from scipy.stats import beta
from scipy.stats import dirichlet
from scipy.special import gammaln
from scipy.special import digamma
from scipy.stats import multivariate_normal as mvn

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.dirichlet import Dirichlet
from torch.distributions.kl import kl_divergence as kl_div
from torch.nn.utils import spectral_norm
from torch.utils.data import DataLoader, Dataset


def train_daedl(model, learning_rate, reg_param, num_epochs, trainloader, validloader, num_classes, device):
    
    optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lambda epoch: 0.95 ** epoch)  
    
    VAL_ACC = []
    VAL_LOSS = []
    cnt = 0

    model.to(device)        
    model.train()
    for epoch in tqdm(range(num_epochs)):
        running_loss = 0.0
          
        for i, (x,y) in enumerate(trainloader):
            optimizer.zero_grad()
            x,y = x.to(device), y.to(device)
          
            alpha = 1e-6 + torch.exp(model(x))                                                         
            alpha0 = alpha.sum(1).reshape(-1,1)
            y_oh = F.one_hot(y, num_classes).to(device)
            alpha_tilde = alpha * (1-y_oh) + y_oh
                
            expected_mse = torch.sum((y_oh - alpha / alpha0) ** 2 ) + torch.sum(((alpha * (alpha0 - alpha))) / ((alpha0 ** 2) * (alpha0 + 1)))                                                                           
            kl_regularizer = kl_div(Dirichlet(1e-6 + alpha_tilde), Dirichlet(torch.ones_like(alpha_tilde))).sum() 
            loss = expected_mse + reg_param * kl_regularizer
                
            loss.backward()      
            optimizer.step()
            running_loss += loss.item()    
            
        scheduler.step()
        
        if epoch % 20 == 0 and epoch > 0:
        
            total=0
            correct=0
            val_loss = 0
            
            with torch.no_grad():
                for i, (x_v,y_v) in enumerate(validloader):
                    x_v, y_v = x_v.to(device), y_v.to(device)

                    alpha_v= torch.exp(model(x_v))
                    alpha0_v = alpha_v.sum(1).reshape(-1,1)
                    y_oh_v = F.one_hot(y_v, num_classes).to(device)  
                    alpha_v_tilde = alpha_v * (1-y_oh_v) + y_oh_v
                    
                    expected_mse_v = torch.sum((y_oh_v - alpha_v/ alpha0_v) ** 2 ) + torch.sum(((alpha_v * (alpha0_v- alpha_v))) / ((alpha0_v ** 2) * (alpha0_v + 1)))
                    kl_regularizer_v = kl_div(Dirichlet(alpha_v_tilde), Dirichlet(torch.ones_like(alpha_v_tilde))).sum()
                    
                    val_loss += expected_mse_v + reg_param * kl_regularizer_v               
                    y_pred_v = alpha_v.argmax(1)
                    
                    total += y_v.size(0)
                    correct += (y_pred_v == y_v).sum().item()

            val_acc = 100*correct/total
            VAL_LOSS.append(val_loss)
            VAL_ACC.append(val_acc)
            
            if len(VAL_ACC) > 2 : 
                
                r_acc = (VAL_ACC[-1] - VAL_ACC[-2]) / VAL_ACC[-2]
                r_loss = (VAL_LOSS[-1] - VAL_LOSS[2]) / VAL_LOSS[-2]

                if r_loss > -0.0001 :
                    cnt = cnt + 1
                else : 
                    cnt = 0
                    
            if cnt > 3 :
                break
                
            print('Epoch {}, loss = {:.3f}'.format(epoch, val_loss)) 
            print('Validation Accuracy = {:.3f}'.format(val_acc))
                 
def eval_daedl(model, testloader, device):    
    model.eval()
    total = 0
    correct = 0
    
    with torch.no_grad():
        for i, (x,y) in enumerate(testloader):
            x,y = x.to(device), y.to(device)
            alpha_pred = torch.exp(model(x))
            y_pred = alpha_pred.argmax(1)
            
            total += y.size(0)
            correct += (y_pred == y).sum().item()
            
        test_acc = 100*correct/total
        print("Test Accuracy:",test_acc)
    
    return test_acc

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import numpy as np

# Step 1: Prepare the dataset
X, y = make_moons(n_samples=1000, noise=0.1, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

# Step 2: Define the Model
class TwoMoonsClassifier(nn.Module):
    def __init__(self):
        super(TwoMoonsClassifier, self).__init__()
        self.fc1 = nn.Linear(2, 64)
        self.fc2 = nn.Linear(64, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = TwoMoonsClassifier()

# Step 3: Define Loss Function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Step 4: Training Loop
num_epochs = 20
for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss}")

# Step 5: Evaluate Model
model.eval()
with torch.no_grad():
    outputs = model(X_test_tensor)
    _, predicted = torch.max(outputs, 1)
    accuracy = (predicted == y_test_tensor).sum().item() / len(y_test_tensor)
    print(f"Accuracy on test set: {accuracy}")


In [18]:
learning_rate = 1e-4

In [None]:
train_daedl(model, learning_rate, reg_param, num_epochs, trainloader, validloader, num_classes, device)   
test_acc = eval_daedl(model, testloader, device)

In [25]:
gda, p_z_train = fit_gda(model, trainloader, num_classes, embedding_dim, device)
ood_auroc, ood_aupr = ood_detection_daedl(model, gda, p_z_train, testloader, ood_loader1, ood_loader2, num_classes,                                              device)
result = {"Test Acc": test_acc,"OOD AUROC": ood_auroc,"OOD AUPR": ood_aupr}

100%|████████████████████████████████████████████████████████████████████████████████| 743/743 [00:20<00:00, 36.69it/s]
  0%|                                                                                          | 0/743 [00:10<?, ?it/s]


RuntimeError: The size of tensor a (10) must match the size of tensor b (100) at non-singleton dimension 1

In [None]:
if ID_dataset == "CIFAR-10" or "CIFAR-100":
    brier, conf_aupr, conf_auroc = conf_calibration_daedl(model, gda, p_z_train, testloader, device)        
    result["Conf AUROC"] = conf_auroc
    result["Conf AUPR"] = conf_aupr

os.makedirs(output_dir, exist_ok=True)
result_filepath = os.path.join(output_dir, 'results.json')

with open(result_filepath, 'w') as result_file:
    json.dump(result, result_file, indent=4)

In [6]:
trainloader, validloader, testloader, ood_loader1, ood_loader2 = load_datasets(ID_dataset, batch_size, val_size)
model = load_model(ID_dataset, pretrained, index, dropout_rate, device)   
train_daedl(model, learning_rate, reg_param, num_epochs, trainloader, validloader, num_classes, device)   
test_acc = eval_daedl(model, testloader, device)  
gda, p_z_train = fit_gda(model, trainloader, num_classes, embedding_dim, device)
ood_auroc, ood_aupr = ood_detection_daedl(model, gda, p_z_train, testloader, ood_loader1, ood_loader2, num_classes,                                              device)
result = {"Test Acc": test_acc,"OOD AUROC": ood_auroc,"OOD AUPR": ood_aupr}
    
if ID_dataset == "CIFAR-10" or "CIFAR-100":
    brier, conf_aupr, conf_auroc = conf_calibration_daedl(model, gda, p_z_train, testloader, device)        
    result["Conf AUROC"] = conf_auroc
    result["Conf AUPR"] = conf_aupr

os.makedirs(output_dir, exist_ok=True)
result_filepath = os.path.join(output_dir, 'results.json')

with open(result_filepath, 'w') as result_file:
    json.dump(result, result_file, indent=4)

Files already downloaded and verified
Using downloaded and verified file: ./data\test_32x32.mat


 11%|████████▋                                                                      | 11/100 [09:01<1:16:59, 51.90s/it]

Epoch 10, loss = 2500.173
Validation Accuracy = 0.000


 11%|████████▋                                                                      | 11/100 [09:33<1:17:19, 52.13s/it]


KeyboardInterrupt: 