In [10]:
import copy
import itertools

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

from train_tools.utils import directory_setter

class InspectionHandler():
    """
    Inspector for a learned adaptive-computation network.

    When `path_cost` is given, the forward function of the model must return (outputs, mark):
        outputs : (tensor) network output logits
        mark    : (tensor) per-sample index of the path that produced the output
                  (-1 marks the separate small network when `use_small` is True)
    When `path_cost` is None the network is treated as non-adaptive and its forward
    must return the outputs only.

    This module is meant to perform the following inspections for different adaptive paths
    (plotting itself is not implemented yet, see `plotter`):
    1) Computational Cost vs Accuracy Trade-off
    2) Risk-Coverage Trade-off
    3) Confidence (Softmax Response) & Entropy distribution
    """
    def __init__(self, Network, dataloaders, dataset_sizes, device='cuda:0', phase='test',
                 num_path=1, path_cost=(1,), base_setting=True, use_small=False, save_path='./results/inspection/'):
        """
        [args]     (nn.Module) Network : network to inspect; needs .to/.eval and, for adaptive
                                         networks, condition_updater(exit_cond)
                   (dict) dataloaders : phase name -> iterable of (inputs, labels) batches
                   (dict) dataset_sizes : phase name -> number of samples of that phase
                   (str) device : device the network and every batch are moved to
                   (str) phase : phase used for the base inspections and for grid_inspector
                   (int) num_path : the number of adaptive paths of inference
                   (tuple) path_cost : relative cost of path flops w.r.t. total flops ex) (0.3, 0.7, 1.15)
                                       must have num_path entries. Default is (1,).
                                       None means 'no adaptive computation' (network returns outputs
                                       only; requires num_path == 1)
                   (bool) base_setting : if True, computes baselines & SR distributions right away
                   (bool) use_small : True if using a separated small network as a single path (path 0)
                   (str) save_path : root directory for inspection results
        """
        self.Network = Network.to(device).eval()
        self.dataloaders = dataloaders
        self.dataset_sizes = dataset_sizes
        self.device = device
        self.phase = phase
        self.num_path = num_path
        self.path_cost = path_cost
        self.use_small = 1 if use_small else 0  # kept as int: it offsets the path indices in _path_markers
        self.save_path = save_path
        self.name = 'test'  # default experiment name is 'test'

        if path_cost is None:
            assert num_path == 1, 'non-adaptive network (path_cost=None) must have a single path!'
        else:
            assert num_path == len(path_cost), 'number of paths should have corresponding cost!'

        # build base inspection results
        self._result_dict_builder(phase=phase)

        if base_setting:
            self._baseline_setter(phase=phase)
            self._sr_dist_builder(phase=phase)
            print('Baselines & SR distribution result has updated.\n')


    def inference(self, phase='test', exit_cond=None):
        """
        Inference for given exit condition and dataset.

        [args]      (str) phase : use test or valid set 'valid' or 'test'
                    (list or tuple) exit_cond : exiting threshold condition for the paths ex) (0.9, 0.97, 0.84)
                                                None keeps the condition currently set on the network

        [returns]   (float) total_acc : total accuracy (rounded to 4 digits)
                    (list) path_acc : accuracy per path, -1 for a path that received no sample
                    (list) path_ratio : fraction of samples that exited at each path
                    (float) flops_score : flops score for carried out inference
        """
        self._condition_setter(exit_cond)

        path_correct, path_count = [0] * self.num_path, [0] * self.num_path
        size = self.dataset_sizes[phase]

        with torch.no_grad():
            for inputs, labels in self.dataloaders[phase]:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                # Inference from Network
                outputs, mark = self._forward(inputs)
                pred = self._prediction(outputs)

                for i, marker in enumerate(self._path_markers(mark, pred)):
                    path_correct[i] += (pred[marker] == labels[marker]).sum().item()
                    path_count[i] += marker.sum().item()

        assert sum(path_count) == size, 'Total count must same with data count!'

        total_acc = round(sum(path_correct) / size, 4)

        path_acc, path_ratio = [], []
        for correct, count in zip(path_correct, path_count):
            path_acc.append(-1 if count == 0 else round(correct / count, 4))  # -1 if no count
            path_ratio.append(round(count / size, 4))

        flops_score = self._flops_checker(path_ratio)

        return total_acc, path_acc, path_ratio, flops_score


    def _path_markers(self, mark, pred):
        """
        Builds one boolean mask per path selecting the samples that exited at that path.

        [args]      (Tensor or None) mark : path index per sample (-1 for the small network)
                    (Tensor) pred : batch predictions, used for shape/device when mark is None

        [returns]   (list) masks ordered by path index
        """
        if mark is None:
            # non-adaptive network: every sample goes through the single path
            return [torch.ones_like(pred, dtype=torch.bool)]

        markers = []
        if self.use_small:
            markers.append(mark == -1)  # mark of smallnet is -1
        for i in range(self.num_path - self.use_small):
            markers.append(mark == i)
        return markers


    def grid_inspector(self, start_cond, grid=0.01):
        """
        Runs inference for every combination of exit thresholds on a grid and records the results.

        For each non-final path, thresholds go from its start value in steps of `grid` while below
        1.0, and 1.0 is always included as the last threshold.

        [args]      (list or tuple) start_cond : start cond of condition range for paths ex) [0.5, 0.55, 0.7]
                    (float) grid : step size of the threshold grid (must be > 0)
        """
        assert len(start_cond) == self.num_path-1, 'condition should 1 less than num_path!'
        assert grid > 0, 'grid step must be positive!'

        # build condition range for paths, then the grid condition set
        condition_range = [self._threshold_range(start, grid) for start in start_cond]
        condition_set = self._grid_cond_builder(condition_set=[], condition_range=condition_range)

        print('Starting Grid Inspection...\n')

        for i, exit_cond in enumerate(condition_set):
            total_acc, path_acc, path_ratio, flops_score = self.inference(phase=self.phase, exit_cond=exit_cond)
            self._result_dict_updater(exit_cond, total_acc, path_acc, path_ratio, flops_score)

            if (i % 100 == 0) and (i > 0):
                print('Grid Inspection (%d/%d)' % (i, len(condition_set)))


    @staticmethod
    def _threshold_range(start, grid):
        """
        Returns thresholds start, start+grid, ... that are below 1.0, always ending with 1.0.

        [args]      (float) start : first threshold
                    (float) grid : step size (> 0)

        [returns]   (list) thresholds
        """
        # float `//` floors the exact quotient (0.5 // 0.1 == 4.0) and would drop a grid point,
        # so use true division plus a tiny epsilon instead; clamp at 0 for start >= 1.0
        num_steps = max(0, int((1.0 - start) / grid + 1e-9))
        # round away accumulated float noise such as 0.7250000000000001
        thresholds = [round(start + step * grid, 10) for step in range(num_steps)]

        # confirm final threshold as 1.0
        if not thresholds or thresholds[-1] < 1.0:
            thresholds.append(1.0)
        return thresholds


    def sr_dist_inspector(self, path_idx, phase='test', top_k=10):
        """
        Inspects max softmax response distribution and top-k entropy distribution for given path

        [args]      (int) path_idx : 0 <= path index < num_paths
                    (str) phase : use test or valid set 'valid' or 'test'
                    (int) top_k : number of largest softmax responses used for the entropy
                                  (clamped to the number of classes)

        [returns]   (tuple) (max_sr_co, max_sr_inco) : lists of maximum softmax response for correct/incorrect samples
                    (tuple) (entropy_co, entropy_inco) : lists of entropy of top-k softmax response for correct/incorrect samples
                                                         (the top-k responses are not renormalized)
        """
        # Set full path condition
        condition = self._fullcond_setter(path_idx)
        self._condition_setter(condition)

        max_sr_co, max_sr_inco = [], []
        entropy_co, entropy_inco = [], []

        with torch.no_grad():
            for inputs, labels in self.dataloaders[phase]:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                # inference & prediction
                outputs, _ = self._forward(inputs)
                pred = self._prediction(outputs)

                # work in log space: a probability that underflows to 0 then contributes 0 to the
                # entropy instead of 0 * -log(0) = NaN
                log_sr = F.log_softmax(outputs, dim=1)

                # maximum softmax response
                max_sr = log_sr.max(dim=1)[0].exp()

                # top-k entropy of softmax response (k cannot exceed the number of classes)
                k = min(top_k, log_sr.size(1))
                topk_log_sr, _ = torch.topk(log_sr, k, dim=1)
                entropy = -(topk_log_sr.exp() * topk_log_sr).sum(dim=1)

                # Get correct/incorrect tensor
                co_tensor = pred == labels
                inco_tensor = ~co_tensor

                max_sr_co += max_sr[co_tensor].tolist()
                max_sr_inco += max_sr[inco_tensor].tolist()
                entropy_co += entropy[co_tensor].tolist()
                entropy_inco += entropy[inco_tensor].tolist()

        return (max_sr_co, max_sr_inco), (entropy_co, entropy_inco)


    def _sr_dist_builder(self, phase='test'):
        """
        Appends the max-SR and entropy distributions of every path to result_dict.

        [args]      (str) phase : use test or valid set 'valid' or 'test'
        """
        for path_idx in range(self.num_path):
            (max_sr_co, max_sr_inco), (entropy_co, entropy_inco) = self.sr_dist_inspector(path_idx, phase)
            self.result_dict['max_sr_dist_'+str(path_idx)][0].append(max_sr_co)
            self.result_dict['max_sr_dist_'+str(path_idx)][1].append(max_sr_inco)
            self.result_dict['entropy_dist_'+str(path_idx)][0].append(entropy_co)
            self.result_dict['entropy_dist_'+str(path_idx)][1].append(entropy_inco)


    def _forward(self, x):
        """
        Inference for a single batch.

        [returns]   (Tensor) outputs : inference result tensor
                    (Tensor) mark : mark tensor for paths. None if no adaptive path exists (path_cost is None).
        """
        if self.path_cost is not None:
            outputs, mark = self.Network(x)
        else:
            outputs, mark = self.Network(x), None

        return outputs, mark


    def _prediction(self, outputs):
        """
        Prediction for a single batch inference result.

        [returns]   (Tensor) pred : index of the largest output along dim 1
        """
        _, pred = torch.max(outputs, 1)
        return pred


    def _flops_checker(self, path_ratio):
        """
        Calculates total flops. Regards main network cost itself as 1.0

        [args]      (list) path_ratio : path ratio for paths

        [returns]   (float) score : flops score for given ratio list (1.0 for a non-adaptive network)
        """
        if self.path_cost is None:
            return 1.0  # no adaptive path: the full network always runs
        return sum(cost * ratio for cost, ratio in zip(self.path_cost, path_ratio))


    def _condition_setter(self, exit_cond):
        """
        Sets exiting threshold for each path in Network.

        [args]      (list or tuple) exit_cond : exiting threshold condition for the paths ex) (0.9, 0.97, 0.84)
                                                None keeps the network's current condition
        """
        if exit_cond is None or self.path_cost is None:
            # nothing to set: keep the current condition / a non-adaptive network has no thresholds
            return
        assert (self.num_path-1) == len(exit_cond), 'the number of exit_cond should 1 less than the number of paths!'
        self.Network.condition_updater(exit_cond)


    def _fullcond_setter(self, path_idx):
        """
        Returns full exiting condition for given path.

        Every threshold is 1 except the one of `path_idx`, which is 0 ex) [1,1,1,0,1].
        For the final path (no threshold of its own) all thresholds stay 1.

        [args]      (int) path_idx : path index (0 <= path_idx < num_paths)

        [returns]   (list) condition : condition list (length num_path-1) for corresponding path
        """
        assert 0 <= path_idx < self.num_path, 'path_idx out of range!'
        condition = [1] * (self.num_path-1)  # condition length should be 1 less than path length
        if path_idx < self.num_path-1:
            condition[path_idx] = 0  # sets threshold as 0 for full exiting
        return condition


    def _grid_cond_builder(self, condition_set, condition_range):
        """
        Builds the list of exit_cond for every combination of the given condition ranges.

        Combinations are ordered like nested loops with the first path outermost (itertools.product
        order). When `condition_set` is not empty, each combination is prepended to each of its
        entries. The arguments are not mutated.

        [args]      (list) condition_set : list of exit_cond already built. empty list at first
                    (list) condition_range : list of threshold lists, one per non-final path

        [returns]   (list) condition_set : list of exit_cond (each a new list)
        """
        if not condition_range:
            # return if no more condition range exists.
            return condition_set

        suffixes = condition_set if condition_set else [[]]
        return [list(prefix) + list(suffix)
                for prefix in itertools.product(*condition_range)
                for suffix in suffixes]


    def _baseline_setter(self, phase='test'):
        """
        Updates baseline accuracy for each path (every sample forced through that path).

        [args]      (str) phase : use test or valid set 'valid' or 'test'
        """
        baseline = []

        for i in range(self.num_path):
            condition = self._fullcond_setter(i)
            accuracy, _, _, _ = self.inference(phase=phase, exit_cond=condition)
            baseline.append(accuracy)

        self.result_dict['baseline'] = baseline


    def _result_dict_builder(self, phase):
        """
        Builds a dictionary to save results (per-path lists keyed by '<metric>_<path index>').
        """
        result_dict = {
            'phase' : phase,
            'total_acc': [],
            'flops_score' : []
        }

        result_dict['path_cost'] = self.path_cost

        for i in range(self.num_path):
            result_dict['path_acc_' + str(i)] = []
            result_dict['path_ratio_' + str(i)] = []
            result_dict['path_cond_' + str(i)] = []
            result_dict['max_sr_dist_'+ str(i)] = [[], []]  # [co_dist, inco_dist]
            result_dict['entropy_dist_'+ str(i)] = [[], []] # [co_dist, inco_dist]

        self.result_dict = result_dict


    def _result_dict_updater(self, exit_cond, total_acc, path_acc, path_ratio, flops_score):
        """
        Updates inference result to result_dict. The final path has no threshold and is stored as -1.
        """
        assert len(path_acc) == len(path_ratio), 'path number should same for accuracy and ratio!'

        self.result_dict['total_acc'].append(total_acc)
        self.result_dict['flops_score'].append(flops_score)

        for i in range(self.num_path):
            self.result_dict['path_acc_'+str(i)].append(path_acc[i])
            self.result_dict['path_ratio_'+str(i)].append(path_ratio[i])

            path_cond = exit_cond[i] if i < self.num_path-1 else -1
            self.result_dict['path_cond_'+str(i)].append(path_cond)


    def plotter(self, make_dir=True):
        """
        Placeholder for plotting the inspection results.

        Currently only prepares the output directory `save_path + name` through
        directory_setter; no figure is produced yet.

        [args]      (bool) make_dir : forwarded to directory_setter
        """
        # TODO: plot cost-vs-accuracy, risk-coverage and SR/entropy distributions from self.result_dict
        save_path = self.save_path + self.name
        directory_setter(save_path, make_dir)

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.adaptivecomp.utils import LogitCond, SL_Pair
from data_utils import cifar_100_setter
from models.ResNet import resnet101, resnet18

dataloaders, dataset_sizes = cifar_100_setter(root='./data/cifar100')

# Load checkpoints onto the CPU: InspectionHandler moves the whole model to its device later,
# and map_location avoids failures when a GPU-saved checkpoint is opened on a CPU-only machine.
smallnet = resnet18(num_classes=100)
smallnet.load_state_dict(torch.load('./results/trained_models/ResNet18_ce_acc/trained_model.pth', map_location='cpu'))
largenet = resnet101(num_classes=100)
largenet.load_state_dict(torch.load('./results/trained_models/ResNet101_ce/trained_model.pth', map_location='cpu'))

# exit_cond=1 is only the initial threshold; InspectionHandler resets it via condition_updater before each inference
model = SL_Pair(smallnet, largenet, exit_cond=1)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [12]:
# Two paths: path 0 is the small net (marked -1, relative cost 0.5), path 1 the large net (cost 1.0)
inspector = InspectionHandler(
    model,
    dataloaders,
    dataset_sizes,
    num_path=2,
    path_cost=(0.5, 1.0),
    use_small=True,
)

Baselines & SR distribution result has updated.



In [13]:
# Sweep the small-net exit threshold from 0.7 in steps of 0.025; 1.0 is always the last grid value
inspector.grid_inspector(start_cond=[0.7], grid=0.025)

Starting Grid Inspection...

