In [1]:
!pip install torchmetrics
!pip install neptune-client
!pip install scikit-plot
!pip install -U "ray[tune]"

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader
from torchmetrics import Accuracy
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet18,mobilenet_v2

import cv2
import os
import copy
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import glob
import seaborn as sn
from functools import partial
from datetime import datetime
import scipy.ndimage as nd
import neptune.new as neptune
from sklearn.metrics import confusion_matrix ,classification_report,accuracy_score,f1_score,precision_score,recall_score
from scikitplot.metrics import plot_confusion_matrix
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler





In [2]:
#General
data_dir = os.path.abspath("./data")
classes  = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
class_len = len(classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
training_config  = 0
train_accuracy = Accuracy(task="multiclass", num_classes = class_len)
train_accuracy.to(device)
l1 = l2 = lr = batch_size = 0

tnsr_board_logger = SummaryWriter()
nep_logger = neptune.init(
project="mohan20325145/resnet",
api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJhZWQyMTU4OC02NmU4LTRiNjgtYWE5Zi1lNDg5MjdmZGJhNzYifQ==",)


#Tuning 
tune_hyperparams = False
num_samples = 4
max_num_epochs = 3
gpus_per_trial = 0 


#Training
num_workers = 8 #2
epochs = [100]
optimizer = ["Adam"]
criterion = [nn.CrossEntropyLoss(), "Evidential"]
model = ["ResNet"]
train_network = True
test_network = False
save_model_params = True

#num_workers = [2,3]
#criterion = ["Evidential", nn.CrossEntropyLoss(), nn.NLLLoss(), nn.GaussianNLLLoss(), nn.SoftMarginLoss()] 
#optimizer = ["Adam", "SGD", "ASGD", "Adamax"]
#epochs = [40, 100]
#model = ["ResNet", "MobileNet", "CustomNet"]

  nep_logger = neptune.init(


https://app.neptune.ai/mohan20325145/resnet/e/RESNETNEP-169
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api/run#stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


In [3]:
def load_data(data_dir):
    transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)
#     trainset = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
#     testset = torchvision.datasets.MNIST(root=data_dir, train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))    
    return trainset, testset

def logger():
    nep_logger.stop()
    tnsr_board_logger.close
    # !tensorboard --logdir=runs

class Param_Tuning_NN(nn.Module):
    def __init__(self, l1, l2):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)   
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x  
    
class Custom_Train_NN(nn.Module):
    def __init__(self, l1, l2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x,1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [4]:
def tune_subroutine(config, checkpoint_dir=None, data_dir=None):
    
    net = Param_Tuning_NN(config["l1"], config["l2"])
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    if checkpoint_dir:
        model_state, optimizer_state = torch.load(os.path.join(checkpoint_dir, "checkpoint"))
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)
        
    trainset, testset = load_data(data_dir)
    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(trainset, [test_abs, len(trainset) - test_abs])
    trainloader = torch.utils.data.DataLoader(train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8)
    valloader = torch.utils.data.DataLoader(val_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8)

    for epoch in range(10):
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                loss = criterion(outputs, labels)
                val_loss += loss.cpu().numpy()
                val_steps += 1
        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((net.state_dict(), optimizer.state_dict()), path)
        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
    
    
def tune_model():
    config = {"l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
              "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
              "lr": tune.loguniform(1e-4, 1e-1),
              "batch_size": tune.choice([2, 4, 8, 16])}
    scheduler = ASHAScheduler(metric="loss",
                              mode="min",
                              max_t=max_num_epochs,
                              grace_period=1,
                              reduction_factor=2)
    reporter = CLIReporter(metric_columns=["loss", "accuracy", "training_iteration"])
    result = tune.run(partial(tune_subroutine, data_dir=data_dir),
             resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
             config=config,
             num_samples=num_samples,
             scheduler=scheduler,
             progress_reporter=reporter)

    best_trial = result.get_best_trial("loss", "min", "last")
    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation loss: {}".format(best_trial.last_result["loss"]))
    print("Best trial final validation accuracy: {}".format(best_trial.last_result["accuracy"]))
    return best_trial.config['l1'], best_trial.config['l2'], best_trial.config['lr'], best_trial.config['batch_size']

In [5]:
def one_hot_embedding(labels, num_classes):
    y = torch.eye(num_classes)
    return y[labels]

def relu_evidence(y):
    return F.relu(y)

def kl_divergence(alpha, num_classes, device=None):
    beta = torch.ones([1, num_classes], dtype=torch.float32, device=device)
    S_alpha = torch.sum(alpha, dim=1, keepdim=True)
    S_beta = torch.sum(beta, dim=1, keepdim=True)
    lnB = torch.lgamma(S_alpha) - torch.sum(torch.lgamma(alpha), dim=1, keepdim=True)
    lnB_uni = torch.sum(torch.lgamma(beta), dim=1, keepdim=True) - torch.lgamma(S_beta)
    dg0 = torch.digamma(S_alpha)
    dg1 = torch.digamma(alpha)
    kl = torch.sum((alpha - beta) * (dg1 - dg0), dim=1, keepdim=True) + lnB + lnB_uni
    return kl

def loglikelihood_loss(y, alpha, device=None):
    y = y.to(device)
    alpha = alpha.to(device)
    S = torch.sum(alpha, dim=1, keepdim=True)
    loglikelihood_err = torch.sum((y - (alpha / S)) ** 2, dim=1, keepdim=True)
    loglikelihood_var = torch.sum(alpha * (S - alpha) / (S * S * (S + 1)), dim=1, keepdim=True)
    loglikelihood = loglikelihood_err + loglikelihood_var
    return loglikelihood

def mse_loss(y, alpha, epoch_num, num_classes, annealing_step, device=None):
    y = y.to(device)
    alpha = alpha.to(device)
    loglikelihood = loglikelihood_loss(y, alpha, device=device)
    annealing_coef = torch.min(torch.tensor(1.0, dtype=torch.float32), torch.tensor(epoch_num / annealing_step, dtype=torch.float32))
    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = annealing_coef * kl_divergence(kl_alpha, num_classes, device=device)
    return loglikelihood + kl_div

def edl_mse_loss(output, target, epoch_num, num_classes, annealing_step, device=None):
    evidence = relu_evidence(output)
    alpha = evidence + 1
    loss = torch.mean(mse_loss(target, alpha, epoch_num, num_classes, annealing_step, device=device))
    return loss

In [6]:
def train_model_subroutine(criterion_, optimizer_, epochs_, model_):
    since = time.time()
    nep_logger['params/training/model'+ str(training_config)].log(model_)
    nep_logger['params/training/model'+ str(training_config)].log(criterion_)
    nep_logger['params/training/model'+ str(training_config)].log(epochs_)
    nep_logger['params/training/model'+ str(training_config)].log(optimizer_)
    
    if (model_ == "ResNet"):
        net = resnet18()
        net.fc = nn.Linear(in_features=512,out_features=class_len)
    elif (model_ == "MobileNet"):
        net = mobilenet_v2()
        net.fc = nn.Linear(in_features=512,out_features=class_len)
    elif (model_ == "CustomNet"):
        net = Custom_Train_NN(l1, l2)   

    if (optimizer_ == "Adam"):
        optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=0.005)
    elif (optimizer_ == "SGD"):
        optimizer = optim.SGD(net.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
   
    net = net.to(device)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    valloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, num_workers=num_workers)
    dataloaders = {"train": trainloader, "val": valloader,}
    
    best_model_wts = copy.deepcopy(net.state_dict())
    best_acc = 0.0
    
    losses = {"loss": [], "phase": [], "epoch": []}
    accuracy = {"accuracy": [], "phase": [], "epoch": []}
    
    for epoch in range(epochs_):
        print("Epoch {}/{}".format(epoch, epochs_ - 1))
        print("-" * 10)
        
        for phase in ["train", "val"]:
            if phase == "train":
                print("Training...")
                net.train()  
            else:
                print("Validating...")
                net.eval()
                
            running_loss = 0.0
            running_corrects = 0.0
            
            for i, data in enumerate(dataloaders[phase]):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = net(inputs)
                _,predicted = torch.max(outputs.data, 1)

                with torch.set_grad_enabled(phase == "train"):
                    if criterion_ == "Evidential":
                        y = one_hot_embedding(labels=labels,num_classes=class_len)
                        y.to(device)
                        loss = edl_mse_loss(outputs, y.float(), epoch, class_len, 10, device)
                        match = torch.reshape(torch.eq(predicted, labels).float(), (-1, 1))
                        acc = torch.mean(match)
                        evidence = relu_evidence(outputs)    
                        alpha = evidence + 1
                        u = class_len / torch.sum(alpha, dim=1, keepdim=True)
                        total_evidence = torch.sum(evidence, 1, keepdim=True)
                        mean_evidence = torch.mean(total_evidence)
                        mean_evidence_succ = torch.sum(torch.sum(evidence, 1, keepdim=True) * match) / torch.sum(match + 1e-20)
                        mean_evidence_fail = torch.sum(torch.sum(evidence, 1, keepdim=True) * (1 - match)) / (torch.sum(torch.abs(1 - match)) + 1e-20)
                    else:
                        loss = criterion_(outputs, labels)
                    
                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                        
                running_loss += loss.item()* inputs.size(0)
                running_corrects += torch.sum(predicted == labels.data)
            
            epoch_loss = running_loss / len(trainloader)
            epoch_acc = running_corrects.double() / len(trainloader)
            
            if phase == "train":
                nep_logger['plots/training/train/loss'+ str(training_config)].log(epoch_loss)
                nep_logger['plots/training/train/accuracy'+ str(training_config)].log(epoch_acc.item())
            else:
                nep_logger['plots/training/val/loss'+ str(training_config)].log(epoch_loss)
                nep_logger['plots/training/val/accuracy'+ str(training_config)].log(epoch_acc.item())
                
            losses["loss"].append(epoch_loss)
            losses["phase"].append(phase)
            losses["epoch"].append(epoch)
            accuracy["accuracy"].append(epoch_acc.item())
            accuracy["epoch"].append(epoch)
            accuracy["phase"].append(phase)
            
            print("{} loss: {:.4f} acc: {:.4f}".format(phase.capitalize(), epoch_loss, epoch_acc))

            if phase == "val" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(net.state_dict())
                
                
    time_elapsed = time.time() - since
    print("Training complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
    print("Best val Acc: {:4f}".format(best_acc))
    
    net.load_state_dict(best_model_wts)

    if save_model_params:
        state = {"epoch": epochs_,
                 "model_state_dict": net.state_dict(),
                 "optimizer_state_dict": optimizer.state_dict()}
        
        if criterion_ == "Evidential":
            saved_models_count = len('./results/EDL/*')
            torch.save(state, "./results/EDL/"+ str(saved_models_count+1)+".pt")
        else:
            saved_models_count = len('./results/CEL/*')
            torch.save(state, "./results/CEL/"+ str(saved_models_count+1)+".pt")
    

def train_model():
    global training_config
    for loss, opti, epo, mod in [(loss, opti, epo, mod) for loss in criterion for opti in optimizer for epo in epochs for mod in model]:
        training_config += 1
        train_model_subroutine(loss, opti, epo, mod)

In [7]:
def test_model():
    for loss, mod in [(loss, mod) for loss in criterion for mod in model]:
        if loss == "Evidential":
            saved_models_count = len('./results/EDL/*')
            checkpoint = torch.load("./results/EDL/"+str(saved_models_count)+".pt")
            
        else:
            saved_models_count = len('./results/CEL/*')
            checkpoint = torch.load("./results/CEL/"+str(saved_models_count)+".pt")
            
        if (model_ == "ResNet"):
            net = resnet18()
            net.fc = nn.Linear(in_features=512,out_features=class_len)
            
        optimizer = optim.Adam(net.parameters())
        

In [8]:
#Data loader
trainset, testset = load_data(data_dir)

#Tune model params
if tune_hyperparams == True:
    l1, l2, lr, batch_size = tune_model() 
else:
    l1 = 16
    l2 = 8
    lr = 1e-3
    batch_size = 16 #16

#Train model
if train_network == True:
    train_model()

#Test model
if test_network == True:
    test_model()

#Logging
logger()

Files already downloaded and verified
Files already downloaded and verified
Epoch 0/99
----------
Training...
Train loss: 25.9319 acc: 6.6458
Validating...
Val loss: 4.6156 acc: 1.5434
Epoch 1/99
----------
Training...
Train loss: 21.1859 acc: 8.5642
Validating...
Val loss: 3.8160 acc: 1.8726
Epoch 2/99
----------
Training...
Train loss: 18.9580 acc: 9.4230
Validating...
Val loss: 3.6290 acc: 1.9146
Epoch 3/99
----------
Training...
Train loss: 17.8179 acc: 9.8064
Validating...
Val loss: 3.2977 acc: 2.0573
Epoch 4/99
----------
Training...
Train loss: 17.3261 acc: 10.0138
Validating...
Val loss: 3.2929 acc: 2.0490
Epoch 5/99
----------
Training...
Train loss: 16.8092 acc: 10.2765
Validating...
Val loss: 3.1625 acc: 2.0982
Epoch 6/99
----------
Training...
Train loss: 16.6233 acc: 10.2774
Validating...
Val loss: 3.1175 acc: 2.1261
Epoch 7/99
----------
Training...
Train loss: 16.2880 acc: 10.4586
Validating...
Val loss: 3.2162 acc: 2.1040
Epoch 8/99
----------
Training...
Train loss: 16

Train loss: 14.8203 acc: 10.9728
Validating...
Val loss: 2.9502 acc: 2.1552
Epoch 74/99
----------
Training...
Train loss: 14.8232 acc: 10.9760
Validating...
Val loss: 2.9230 acc: 2.1914
Epoch 75/99
----------
Training...
Train loss: 14.8442 acc: 10.9514
Validating...
Val loss: 2.9324 acc: 2.1990
Epoch 76/99
----------
Training...
Train loss: 14.8559 acc: 10.9731
Validating...
Val loss: 2.9308 acc: 2.1971
Epoch 77/99
----------
Training...
Train loss: 14.8240 acc: 10.9453
Validating...
Val loss: 2.9783 acc: 2.1757
Epoch 78/99
----------
Training...
Train loss: 14.8982 acc: 10.9338
Validating...
Val loss: 3.0288 acc: 2.1626
Epoch 79/99
----------
Training...
Train loss: 14.8127 acc: 10.9917
Validating...
Val loss: 2.9519 acc: 2.1610
Epoch 80/99
----------
Training...
Train loss: 14.8537 acc: 10.9680
Validating...
Val loss: 2.9389 acc: 2.1987
Epoch 81/99
----------
Training...
Train loss: 14.8711 acc: 10.9552
Validating...
Val loss: 2.9130 acc: 2.1907
Epoch 82/99
----------
Training...
T

  lnB = torch.lgamma(S_alpha) - torch.sum(torch.lgamma(alpha), dim=1, keepdim=True)


Train loss: 13.8136 acc: 3.9840
Validating...
Val loss: 2.6865 acc: 0.9158
Epoch 1/99
----------
Training...
Train loss: 15.2578 acc: 4.7062
Validating...
Val loss: 3.0478 acc: 0.9411
Epoch 2/99
----------
Training...
Train loss: 15.4127 acc: 4.8752
Validating...
Val loss: 3.0471 acc: 1.1283
Epoch 3/99
----------
Training...
Train loss: 15.4771 acc: 5.4256
Validating...
Val loss: 3.0835 acc: 1.1248
Epoch 4/99
----------
Training...
Train loss: 15.5278 acc: 5.5456
Validating...
Val loss: 3.0969 acc: 1.1747
Epoch 5/99
----------
Training...
Train loss: 15.5664 acc: 5.5667
Validating...
Val loss: 3.1379 acc: 1.0019
Epoch 6/99
----------
Training...
Train loss: 15.6009 acc: 5.4666
Validating...
Val loss: 3.1156 acc: 1.0934
Epoch 7/99
----------
Training...
Train loss: 15.6187 acc: 5.4822
Validating...
Val loss: 3.1274 acc: 0.9965
Epoch 8/99
----------
Training...
Train loss: 15.6364 acc: 5.3149
Validating...
Val loss: 3.1269 acc: 1.0150
Epoch 9/99
----------
Training...
Train loss: 15.6484

Train loss: 15.6730 acc: 4.7206
Validating...
Val loss: 3.1361 acc: 0.8173
Epoch 76/99
----------
Training...
Train loss: 15.6726 acc: 4.7453
Validating...
Val loss: 3.1339 acc: 0.9712
Epoch 77/99
----------
Training...
Train loss: 15.6732 acc: 4.6982
Validating...
Val loss: 3.1333 acc: 0.9494
Epoch 78/99
----------
Training...
Train loss: 15.6730 acc: 4.7229
Validating...
Val loss: 3.1371 acc: 0.7830
Epoch 79/99
----------
Training...
Train loss: 15.6736 acc: 4.6589
Validating...
Val loss: 3.1349 acc: 0.9517
Epoch 80/99
----------
Training...
Train loss: 15.6743 acc: 4.6166
Validating...
Val loss: 3.1340 acc: 0.9392
Epoch 81/99
----------
Training...
Train loss: 15.6748 acc: 4.6368
Validating...
Val loss: 3.1344 acc: 0.9472
Epoch 82/99
----------
Training...
Train loss: 15.6738 acc: 4.6797
Validating...
Val loss: 3.1337 acc: 1.0163
Epoch 83/99
----------
Training...
Train loss: 15.6734 acc: 4.6397
Validating...
Val loss: 3.1337 acc: 0.9795
Epoch 84/99
----------
Training...
Train loss

In [9]:
#         train_accuracy.update(predicted, labels)
#         epoch_accuracy = train_accuracy.compute()
#         train_accuracy.reset() 
#         predicted = predicted.cpu().detach().numpy()
#         labels = labels.cpu().detach().numpy()
#         epoch_accuracy_score = accuracy_score(labels, predicted)
#         epoch_precision_score = precision_score(labels, predicted, average='weighted')
#         epoch_f1_score = f1_score(labels, predicted, average='weighted')
#         epoch_recall_score = recall_score(labels, predicted, average='weighted')
#         nep_logger['plots/training/accuracy_sklearn'+ str(training_config)].log(epoch_accuracy_score)
#         nep_logger['plots/training/precision_score'+ str(training_config)].log(epoch_precision_score)
#         nep_logger['plots/training/f1_score'+ str(training_config)].log(epoch_f1_score)
#         nep_logger['plots/training/recall_score'+ str(training_config)].log(epoch_recall_score)
#         tnsr_board_logger.add_scalar('plots/training/loss'+ str(training_config), epoch_loss, epoch)
#             fig, axis = plt.subplots(figsize = (20,20))
#             plot_confusion_matrix(labels, predicted, ax=axis, normalize=True)
#             ticks = np.arange(class_len)
#             plt.xticks(ticks, classes, rotation=0, fontsize=14)
#             plt.yticks(ticks, classes, fontsize=14)
#             plt.title('Confusion Matrix', fontsize=20)
#             nep_logger["images/training/conf_matrix"].upload(fig)       
#         PATH = './cifar_net_' + str(training_count) + '.pth'
#         torch.save(net.state_dict(), PATH)
    
    
#                 match = torch.reshape(torch.eq(predicted, labels).float(), (-1, 1))
#                 acc = torch.mean(match)
#                 evidence = relu_evidence(outputs)    
#                 alpha = evidence + 1
#                 u = class_len / torch.sum(alpha, dim=1, keepdim=True)
#                 total_evidence = torch.sum(evidence, 1, keepdim=True)
#                 mean_evidence = torch.mean(total_evidence)
#                 mean_evidence_succ = torch.sum(torch.sum(evidence, 1, keepdim=True) * match) / torch.sum(match + 1e-20)
#                 mean_evidence_fail = torch.sum(torch.sum(evidence, 1, keepdim=True) * (1 - match)) / (torch.sum(torch.abs(1 - match)) + 1e-20)