[wiki] Tutorial #80

Open
Jingkang50 opened this issue May 28, 2022 · 2 comments

@Jingkang50
Owner

No description provided.

@JediWarriorZou
Collaborator

JediWarriorZou commented May 29, 2022

1. How to implement a method with a train (classification) stage and a test (OOD) stage

Train stage:

  • Config: Write a few configuration files in ".yml" format. Each file is a set of key-value pairs. The configuration files generally cover five aspects.
    • datasets: define the details of your dataset, including its name, batch size, image size, and number of classes. In some cases multiple datasets are used and split, which makes the configuration more complex.
      dataset:
        name: cifar10
        split_names: [train, val, test]
        num_classes: 10
        pre_size: 32
        image_size: 32
        num_workers: '@{num_workers}'
        num_gpus: '@{num_gpus}'
        num_machines: '@{num_machines}'
        train:
          dataset_class: ImglistDataset
          data_dir: ./data/images_classic/
          imglist_pth: ./data/benchmark_imglist/cifar10/train_cifar10.txt
          batch_size: 128
          shuffle: True
          interpolation: bilinear
        val:
          dataset_class: ImglistDataset
          data_dir: ./data/images_classic/
          imglist_pth: ./data/benchmark_imglist/cifar10/val_cifar10.txt
          batch_size: 200
          shuffle: False
          interpolation: bilinear
        test:
          dataset_class: ImglistDataset
          data_dir: ./data/images_classic/
          imglist_pth: ./data/benchmark_imglist/cifar10/test_cifar10.txt
          batch_size: 200
          shuffle: False
          interpolation: bilinear
    • networks: define the details of your network, including whether to load pretrained weights and the path of the checkpoint.
      network:
        name: resnet18_32x32
        num_classes: '@{dataset.num_classes}'
        pretrained: True          # set 'True' to load pretrained model
        checkpoint: ./results/cifar10_resnet18_32x32_base_e200_lr_0.1/best.ckpt            # ignore if pretrained is false
        num_gpus: '@{num_gpus}'
    • pipelines: You will certainly use a train pipeline. In this file you can configure your device, such as the number of workers and the number of GPUs. You can also set the parameters of the pipeline, trainer, evaluator, optimizer, scheduler, recorder, etc.
      exp_name: "'@{dataset.name}'_'@{network.name}'_'@{trainer.name}'"
      output_dir: ./results/
      save_output: True
      force_merge: False # disabled if 'save_output' is False
      num_classes: '@{dataset.num_classes}'
      
      num_gpus: 1
      num_workers: 2
      num_machines: 1
      machine_rank: 0
      
      baseline: False
      
      pipeline:
        name: train
      
      trainer:
        name: conf_esti
        budget: 0.3
        lmbda: 0.1
        eps: 1.0e-12
      
      evaluator:
        name: conf_esti
      
      optimizer:
        num_epochs: 200
        learning_rate: 0.1
        momentum: 0.9
        nesterov: True
        weight_decay: 5.0e-4
      
      recorder:
        name: conf_esti
        save_all_models: False
      
      scheduler:
        milestones: [80, 120] #[60,120,160] for cifar10, [80, 120] for svhn
        gamma: 0.1 #0.2 for cifar10, 0.1 for svhn
      
      preprocessor:
        name: base
    • preprocessors: Some methods need a special preprocessor. The base preprocessor is used if you do not need a special one.
    • postprocessors: Some methods need a special postprocessor. The base postprocessor is used if you do not need a special one.
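The '@{...}' values in the configs above point at other keys. As a rough illustration only, the following sketch resolves such references in a nested dict. The names lookup and resolve_refs are hypothetical, and the actual Config class in the codebase may behave differently (for example with chained references or with the quoted form used in exp_name).

```python
import re

# Matches references such as "@{dataset.num_classes}".
REF = re.compile(r"@\{([\w.]+)\}")


def lookup(cfg, dotted):
    """Follow a dotted key such as 'dataset.name' through nested dicts."""
    node = cfg
    for key in dotted.split("."):
        node = node[key]
    return node


def resolve_refs(cfg, root=None):
    """Return a copy of cfg with '@{...}' references replaced (single pass)."""
    root = cfg if root is None else root
    out = {}
    for key, value in cfg.items():
        if isinstance(value, dict):
            out[key] = resolve_refs(value, root)
        elif isinstance(value, str):
            whole = REF.fullmatch(value)
            if whole:
                # The value is exactly one reference: keep the target's type.
                out[key] = lookup(root, whole.group(1))
            else:
                # References embedded in a longer string are substituted as text.
                out[key] = REF.sub(lambda m: str(lookup(root, m.group(1))), value)
        else:
            out[key] = value
    return out


config = {
    "num_gpus": 1,
    "dataset": {"name": "cifar10", "num_classes": 10, "num_gpus": "@{num_gpus}"},
    "network": {"name": "resnet18_32x32", "num_classes": "@{dataset.num_classes}"},
    "exp_name": "@{dataset.name}_@{network.name}",
}
resolved = resolve_refs(config)
print(resolved["network"]["num_classes"], resolved["exp_name"])
```

Keeping the target's type for whole-value references means that num_classes stays an integer instead of becoming a string.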
  • Main code: The train stage is too large to implement comfortably in a single Python file. The code is divided into 10 main sections. Some sections are optional, which means you may not need to implement or amend them.
    • utils: It provides helper functions for reading configuration files, generating log files, etc. You rarely need to amend this section.

    • losses: It defines a few special and more complicated loss functions.

    • datasets: It defines several dataset classes and provides two dataloader generators, "get_dataloader" and "get_ood_dataloader", in utils.py. The most common dataset class is the imglist dataset, which builds a dataset from a txt file containing image paths and labels. You rarely need to amend this section.

      # part of ImglistDataset
      class ImglistDataset(BaseDataset):
          def __init__(self, name, split, interpolation, image_size,
                       imglist_pth, data_dir, num_classes, preprocessor,
                       data_aux_preprocessor, maxlen=None, dummy_read=False,
                       dummy_size=None, **kwargs):
              super(ImglistDataset, self).__init__(**kwargs)

              self.name = name
              self.image_size = image_size
              # each line of the imglist file describes one image
              with open(imglist_pth) as imgfile:
                  self.imglist = imgfile.readlines()
              self.data_dir = data_dir
              self.num_classes = num_classes
              self.preprocessor = preprocessor
              self.transform_image = preprocessor
              self.transform_aux_image = data_aux_preprocessor
              self.maxlen = maxlen
              self.dummy_read = dummy_read
              self.dummy_size = dummy_size
              # dummy reads need an explicit size
              if dummy_read and dummy_size is None:
                  raise ValueError(
                      'if dummy_read is True, should provide dummy_size')
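To make the imglist idea concrete, here is a minimal sketch of turning the lines of such a txt file into (path, label) pairs. It assumes each line has the form "relative/path label" separated by a single space; that layout, the function name parse_imglist, and the example paths are assumptions for illustration, not taken from the codebase.

```python
import os


def parse_imglist(lines, data_dir):
    """Turn 'relative/path label' lines into (full_path, int_label) pairs."""
    samples = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        # Split on the last space so that paths containing spaces still work.
        rel_path, label = line.rsplit(" ", 1)
        samples.append((os.path.join(data_dir, rel_path), int(label)))
    return samples


# Hypothetical file content, as returned by readlines().
example_lines = [
    "cifar10/train/airplane/0001.png 0\n",
    "cifar10/train/cat/0002.png 3\n",
    "\n",
]
samples = parse_imglist(example_lines, "./data/images_classic/")
for path, label in samples:
    print(label, path)
```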
    • network: It defines many networks and provides the "get_network" function in utils.py to choose a network according to the config. Generally, a network class consists of two main methods.

      • __init__(): It defines the network structure and the details of each layer.
     # Example: LeNet
     class LeNet(nn.Module):
         def __init__(self, num_classes, num_channel=3):
             super(LeNet, self).__init__()
             self.num_classes = num_classes
      
             self.block1 = nn.Sequential(
                 nn.Conv2d(in_channels=num_channel,
                           out_channels=6,
                           kernel_size=5,
                           stride=1,
                           padding=2), nn.ReLU(), nn.MaxPool2d(kernel_size=2))
      
             self.block2 = nn.Sequential(
                 nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
                 nn.ReLU(), nn.MaxPool2d(kernel_size=2))
      
             self.block3 = nn.Sequential(
                 nn.Conv2d(in_channels=16,
                           out_channels=120,
                           kernel_size=5,
                           stride=1), nn.ReLU())
      
             self.classifier = nn.Sequential(
                 nn.Linear(in_features=120, out_features=84),
                 nn.ReLU(),
                 nn.Linear(in_features=84, out_features=num_classes),
             )
      • forward(): It defines how an input image is processed by the network and what is returned.
       def forward(self, x, return_feature=False, return_feature_list=False):
           feature1 = self.block1(x)
           feature2 = self.block2(feature1)
           feature3 = self.block3(feature2)
           feature = feature3.view(feature3.shape[0], -1)
           logits_cls = self.classifier(feature)
           feature_list = [feature1, feature2, feature3]
           if return_feature:
               return logits_cls, feature
           elif return_feature_list:
               return logits_cls, feature_list
           else:
               return logits_cls
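When writing a network, it helps to check that the flattened feature size matches the first linear layer. The sketch below applies the standard output-size formula for convolution and pooling layers to the LeNet layers shown above. The helper names are hypothetical. Under this arithmetic, the classifier's in_features=120 corresponds to a 28x28 input, while a 32x32 input would produce a larger flattened feature.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet_feature_dim(input_size):
    """Flattened feature size after block3 for a square input image."""
    s = conv_out(input_size, kernel=5, padding=2)  # block1 conv, padding=2
    s = conv_out(s, kernel=2, stride=2)            # block1 max-pool
    s = conv_out(s, kernel=5)                      # block2 conv
    s = conv_out(s, kernel=2, stride=2)            # block2 max-pool
    s = conv_out(s, kernel=5)                      # block3 conv
    return 120 * s * s                             # 120 output channels


for size in (28, 32):
    print(size, lenet_feature_dim(size))
```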
    • trainers: Here you can design and implement your own trainer. Generally, a trainer class consists of two methods.

      • __init__(): It defines important attributes such as the config, model, dataloader, optimizer, and scheduler.
        # Example: ConfBranchTrainer, a method whose loss combines a classification loss and a confidence loss.
        def __init__(self, net, train_loader, config: Config) -> None:
           self.train_loader = train_loader
           self.config = config
           self.net = net
           self.prediction_criterion = nn.NLLLoss().cuda()
           self.optimizer = torch.optim.SGD(
               net.parameters(),
               lr=config.optimizer['learning_rate'],
               momentum=config.optimizer['momentum'],
               nesterov=config.optimizer['nesterov'],
               weight_decay=config.optimizer['weight_decay'])
           self.scheduler = MultiStepLR(self.optimizer,
                                        milestones=config.scheduler['milestones'],
                                        gamma=config.scheduler['gamma'])
           self.lmbda = self.config.trainer['lmbda']
      • train_epoch(): It trains the net for one epoch. During the epoch:
        1. We call "self.net.train()" to switch to train mode and initialize some variables.
          def train_epoch(self, epoch_idx):
                 self.net.train()
                 correct_count = 0.
                 total = 0.
                 accuracy = 0.
                 train_dataiter = iter(self.train_loader)
        2. We iterate over the batches (with a tqdm progress bar), feed each batch into the model, and obtain its outputs (scores, features, ...).
         for train_step in tqdm(range(1,
                                       len(train_dataiter) + 1),
                                 desc='Epoch {:03d}'.format(epoch_idx),
                                 position=0,
                                 leave=True):
              batch = next(train_dataiter)
              images = Variable(batch['data']).cuda()
              labels = Variable(batch['label']).cuda()
              labels_onehot = Variable(
                  encode_onehot(labels, self.config.num_classes))
              self.net.zero_grad()
        
              pred_original, confidence = self.net(images)
              pred_original = F.softmax(pred_original, dim=-1)
              confidence = torch.sigmoid(confidence)
              eps = self.config.trainer['eps']
              pred_original = torch.clamp(pred_original, 0. + eps, 1. - eps)
              confidence = torch.clamp(confidence, 0. + eps, 1. - eps)
        3. We define the loss function (which may be very complex) and compute the loss from the model output.
              b = Variable(
                  torch.bernoulli(
                      torch.Tensor(confidence.size()).uniform_(0,
                                                               1))).cuda()
              conf = confidence * b + (1 - b)
              pred_new = pred_original * conf.expand_as(
                  pred_original) + labels_onehot * (
                      1 - conf.expand_as(labels_onehot))
              pred_new = torch.log(pred_new)
              xentropy_loss = self.prediction_criterion(pred_new, labels)
              confidence_loss = torch.mean(-torch.log(confidence))
              total_loss = xentropy_loss + (self.lmbda * confidence_loss)
        4. We backpropagate with total_loss.backward() and then step the optimizer and the scheduler.
             total_loss.backward()
             self.optimizer.step()
             self.scheduler.step()
        5. We return the net together with a dict of metrics.
            metrics = {}
            metrics['train_acc'] = accuracy
            metrics['epoch_idx'] = epoch_idx
            return self.net, metrics
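The loss in step 3 can be traced by hand for a single sample. The following pure-Python sketch mirrors the tensor code above for one example. The function name conf_branch_loss and the hint flag (which stands in for the Bernoulli mask b) are assumptions made for illustration.

```python
import math


def conf_branch_loss(probs, confidence, label, lmbda, hint=True, eps=1e-12):
    """Per-sample loss: class probabilities are blended with the one-hot label."""

    def clamp(x):
        return min(max(x, eps), 1.0 - eps)

    probs = [clamp(p) for p in probs]
    confidence = clamp(confidence)
    # hint=True plays the role of b == 1 (the confidence is used as is);
    # hint=False plays the role of b == 0 (conf becomes 1, so no blending).
    conf = confidence if hint else 1.0
    onehot = [1.0 if i == label else 0.0 for i in range(len(probs))]
    pred_new = [p * conf + y * (1.0 - conf) for p, y in zip(probs, onehot)]
    xentropy_loss = -math.log(pred_new[label])
    confidence_loss = -math.log(confidence)
    return xentropy_loss + lmbda * confidence_loss


# The model puts only 0.2 on the true class and reports confidence 0.5.
with_hint = conf_branch_loss([0.2, 0.8], 0.5, label=0, lmbda=0.1)
without_hint = conf_branch_loss([0.2, 0.8], 0.5, label=0, lmbda=0.1, hint=False)
print(round(with_hint, 4), round(without_hint, 4))
```

A low confidence pulls the prediction toward the true label and lowers the classification term, while the confidence term penalizes relying on that too often.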
    • evaluators: It defines many evaluators that assess a network's performance and provides the "get_evaluator" function in utils.py to choose an evaluator according to the config. Two common and important functions are eval_acc() and eval_ood().

      • eval_acc(): It calculates the classification accuracy of the model. During the evaluation:
        1. The dataloader of the test dataset is used, and each batch of data is fed into the model to obtain scores.
        2. From the scores, an inference method produces the predicted labels of the input images.
        3. The predicted labels are compared with the actual labels to compute the classification accuracy, which is returned.
       # Example: ood_evaluator
       def eval_acc(self,
                       net: nn.Module,
                       data_loader: DataLoader,
                       postprocessor: BasePostprocessor = None,
                       epoch_idx: int = -1):
              """Returns the accuracy score of the labels and predictions.
              :return: float
              """
              if type(net) is dict:
                  net['backbone'].eval()
              else:
                  net.eval()
              id_pred, _, id_gt = postprocessor.inference(net, data_loader)
              metrics = {}
              metrics['acc'] = sum(id_pred == id_gt) / len(id_pred)
              metrics['epoch_idx'] = epoch_idx
              return metrics 
      • eval_ood(): It assesses the OOD detection performance of the model and computes a series of metrics. During the evaluation:
        1. The dataloaders of the ID dataset and the OOD datasets are used, and each batch of data is fed into the model to obtain OOD scores.
        2. From the scores, an inference method produces predicted labels and confidences for both ID and OOD images, since OOD detection is similar to binary classification. A label of -1 represents OOD; any other label represents ID.
        3. Three lists (predicted labels, confidences, and actual labels) are filled, and then we compute the metrics FPR@95, AUROC, AUPR_IN, AUPR_OUT, CCR_4, CCR_3, CCR_2, CCR_1, and ACC.
       def _eval_ood(self,
                  net: nn.Module,
                  id_list: List[np.ndarray],
                  ood_data_loaders: Dict[str, Dict[str, DataLoader]],
                  postprocessor: BasePostprocessor,
                  ood_split: str = 'nearood'):
        print(f'Processing {ood_split}...', flush=True)
        [id_pred, id_conf, id_gt] = id_list
        metrics_list = []
        for dataset_name, ood_dl in ood_data_loaders[ood_split].items():
            print(f'Performing inference on {dataset_name} dataset...',
                  flush=True)
            ood_pred, ood_conf, ood_gt = postprocessor.inference(net, ood_dl)
            ood_gt = -1 * np.ones_like(ood_gt)  # hard set to -1 as ood
            if self.config.recorder.save_scores:
                self._save_scores(ood_pred, ood_conf, ood_gt, dataset_name)
      
            pred = np.concatenate([id_pred, ood_pred])
            conf = np.concatenate([id_conf, ood_conf])
            label = np.concatenate([id_gt, ood_gt])
      
            print(f'Computing metrics on {dataset_name} dataset...')
      
            ood_metrics = compute_all_metrics(conf, label, pred)
            if self.config.recorder.save_csv:
                self._save_csv(ood_metrics, dataset_name=dataset_name)
            metrics_list.append(ood_metrics)
      
        print('Computing mean metrics...', flush=True)
        metrics_list = np.array(metrics_list)
        metrics_mean = np.mean(metrics_list, axis=0)
        if self.config.recorder.save_csv:
            self._save_csv(metrics_mean, dataset_name=ood_split)
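Two of the listed metrics are easy to compute by hand on a toy example. The sketch below treats ID samples (label other than -1) as the positive class and assumes that a higher confidence means "more likely ID". It is a simplified stand-in written for illustration, not the compute_all_metrics implementation, and the function names are hypothetical.

```python
import math


def auroc(conf, label):
    """Probability that a random ID sample scores higher than a random OOD one."""
    pos = [c for c, l in zip(conf, label) if l != -1]
    neg = [c for c, l in zip(conf, label) if l == -1]
    # Count ID/OOD pairs ranked correctly; ties count as half.
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def fpr_at_tpr(conf, label, tpr=0.95):
    """Fraction of OOD samples accepted when at least `tpr` of ID is accepted."""
    pos = sorted((c for c, l in zip(conf, label) if l != -1), reverse=True)
    neg = [c for c, l in zip(conf, label) if l == -1]
    k = math.ceil(tpr * len(pos))  # number of ID samples that must be accepted
    threshold = pos[k - 1]
    return sum(n >= threshold for n in neg) / len(neg)


# Four ID samples (labels 0..3) followed by two OOD samples (label -1).
conf = [0.9, 0.8, 0.7, 0.35, 0.4, 0.3]
label = [0, 1, 2, 3, -1, -1]
print(auroc(conf, label), fpr_at_tpr(conf, label))
```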
    • preprocessors: It handles image preprocessing and is used in the dataset classes. If you do not need a special preprocessor, you can use the base preprocessor, which only applies standard image transformations.

    class BasePreprocessor():
       """For train dataset standard transformation."""
       def __init__(self, config: Config, split):
           dataset_name = config.dataset.name.split('_')[0]
           image_size = config.dataset.image_size
           pre_size = center_crop_dict[image_size]
           if dataset_name in normalization_dict.keys():
               mean = normalization_dict[dataset_name][0]
               std = normalization_dict[dataset_name][1]
           else:
               mean = [0.5, 0.5, 0.5]
               std = [0.5, 0.5, 0.5]
    
           interpolation = interpolation_modes[
               config.dataset['train'].interpolation]
    
           self.transform = tvs_trans.Compose([
               Convert('RGB'),
               tvs_trans.Resize(pre_size, interpolation=interpolation),
               tvs_trans.CenterCrop(image_size),
               tvs_trans.RandomHorizontalFlip(),
               tvs_trans.RandomCrop(image_size, padding=4),
               tvs_trans.ToTensor(),
               tvs_trans.Normalize(mean=mean, std=std),
           ])
    
       def setup(self, **kwargs):
           pass
    
       def __call__(self, image):
           return self.transform(image)
    • postprocessors: It complements the model by producing predicted labels and confidences, and it is used in the evaluator classes. If you do not need a special postprocessor, you can use the base postprocessor, which defines the postprocess() and inference() functions.
      • postprocess(): It feeds a batch of data into the network to obtain scores, then applies softmax to obtain predicted labels and confidences. It is a helper for inference().
      • inference(): It uses postprocess() to obtain the predicted labels, confidences, and actual labels of all batches in the input dataloader. Finally, it returns the filled label list, confidence list, and actual label list.
class BasePostprocessor:
    def __init__(self, config):
        self.config = config
    
    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict):
        pass
    
    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        output = net(data)
        score = torch.softmax(output, dim=1)
        conf, pred = torch.max(score, dim=1)
        return pred, conf
    
    def inference(self, net: nn.Module, data_loader: DataLoader):
        pred_list, conf_list, label_list = [], [], []
        for batch in data_loader:
            data = batch['data'].cuda()
            label = batch['label'].cuda()
            pred, conf = self.postprocess(net, data)
            for idx in range(len(data)):
                pred_list.append(pred[idx].cpu().tolist())
                conf_list.append(conf[idx].cpu().tolist())
                label_list.append(label[idx].cpu().tolist())
    
        # convert values into numpy array
        pred_list = np.array(pred_list, dtype=int)
        conf_list = np.array(conf_list)
        label_list = np.array(label_list, dtype=int)
    
        return pred_list, conf_list, label_list
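What postprocess() computes for one sample can be reproduced without tensors. This sketch applies a numerically stable softmax to a list of logits and returns the predicted class with its probability, which is what the base postprocessor uses as the confidence. The function name max_softmax is hypothetical.

```python
import math


def max_softmax(logits):
    """Return (predicted class index, softmax probability of that class)."""
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    pred = max(range(len(probs)), key=probs.__getitem__)
    return pred, probs[pred]


pred, conf = max_softmax([2.0, 1.0, 0.1])
print(pred, round(conf, 3))

# Flat logits give a low confidence, which is what an OOD input should ideally produce.
flat_pred, flat_conf = max_softmax([0.5, 0.5, 0.5])
print(flat_pred, round(flat_conf, 3))
```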
    • recorders: It defines many recorders and provides the "get_recorder" function in utils.py to choose a recorder according to the config. Generally, a recorder class defines 4 functions.
      • __init__(): It records the start time and initializes the best metrics and the best epoch.
      • report(): It reports the results of each epoch.
      • save_model(): It saves the model and replaces the previously saved checkpoint, so that the best model is always kept.
      • summary(): It reports the best metrics and the best epoch when the train stage ends.
  class BaseRecorder:
      def __init__(self, config) -> None:
          self.config = config
    
          self.best_acc = 0.0
          self.best_epoch_idx = 0
    
          self.begin_time = time.time()
          self.output_dir = config.output_dir
    
      def report(self, train_metrics, val_metrics):
          print('\nEpoch {:03d} | Time {:5d}s | Train Loss {:.4f} | '
                'Val Loss {:.3f} | Val Acc {:.2f}'.format(
                    (train_metrics['epoch_idx']),
                    int(time.time() - self.begin_time), train_metrics['loss'],
                    val_metrics['loss'], 100.0 * val_metrics['acc']),
                flush=True)
    
      def save_model(self, net, val_metrics):
          if self.config.recorder.save_all_models:
              torch.save(
                  net.state_dict(),
                  os.path.join(
                      self.output_dir,
                      'model_epoch{}.ckpt'.format(val_metrics['epoch_idx'])))
    
          # enter only if better accuracy occurs
          if val_metrics['acc'] >= self.best_acc:
              # delete the outdated best model
              old_fname = 'best_epoch{}_acc{:.4f}.ckpt'.format(
                  self.best_epoch_idx, self.best_acc)
              old_pth = os.path.join(self.output_dir, old_fname)
              Path(old_pth).unlink(missing_ok=True)
    
              # update the best model
              self.best_epoch_idx = val_metrics['epoch_idx']
              self.best_acc = val_metrics['acc']
              torch.save(net.state_dict(),
                         os.path.join(self.output_dir, 'best.ckpt'))
    
              save_fname = 'best_epoch{}_acc{:.4f}.ckpt'.format(
                  self.best_epoch_idx, self.best_acc)
              save_pth = os.path.join(self.output_dir, save_fname)
              torch.save(net.state_dict(), save_pth)
    
          # save last path
          if val_metrics['epoch_idx'] == self.config.optimizer.num_epochs:
              save_fname = 'last_epoch{}_acc{:.4f}.ckpt'.format(
                  val_metrics['epoch_idx'], val_metrics['acc'])
              save_pth = os.path.join(self.output_dir, save_fname)
              torch.save(net.state_dict(), save_pth)
    
      def summary(self):
          print('Training Completed! '
                'Best accuracy: {:.2f} '
                'at epoch {:d}'.format(100 * self.best_acc, self.best_epoch_idx),
                flush=True)
    • pipelines: It defines many pipelines, including the train and test pipelines, and provides the "get_pipeline" function in utils.py to choose a pipeline according to the config. A train pipeline is expected to go through these steps:
      1. Initialize the logger.
      2. Initialize the dataloaders.
      3. Initialize the network.
      4. Initialize the postprocessor (optional).
      5. Initialize the trainer.
      6. Initialize the evaluator.
      7. Initialize the recorder.
      8. Start training.
      9. Run the evaluation.
      To learn more about how to build a pipeline, please refer to Tutorial 1.
  • Script: We use a script to start the train process. The script loads the yml files that configure the train stage. Generally, you need to load the dataset, network, pipeline, preprocessor, and postprocessor configs.
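The way the trainer, evaluator, and recorder cooperate during training can be sketched as a short loop. The method names and return values follow the excerpts above (train_epoch, eval_acc, save_model, report, summary), but the function run_training and the three stub classes are hypothetical placeholders with made-up numbers, so treat this as an outline rather than the actual train pipeline.

```python
def run_training(trainer, evaluator, recorder, net, val_loader, num_epochs):
    """One possible shape of the 'start training' step of a train pipeline."""
    for epoch_idx in range(1, num_epochs + 1):
        net, train_metrics = trainer.train_epoch(epoch_idx)
        val_metrics = evaluator.eval_acc(net, val_loader, None, epoch_idx)
        recorder.save_model(net, val_metrics)
        recorder.report(train_metrics, val_metrics)
    recorder.summary()
    return net


class StubTrainer:
    def __init__(self, net):
        self.net = net

    def train_epoch(self, epoch_idx):
        # Pretend the training loss shrinks every epoch.
        return self.net, {"epoch_idx": epoch_idx, "loss": 1.0 / epoch_idx}


class StubEvaluator:
    def eval_acc(self, net, data_loader, postprocessor=None, epoch_idx=-1):
        # Pretend the validation accuracy peaks at epoch 2.
        acc = {1: 0.5, 2: 0.9, 3: 0.8}[epoch_idx]
        return {"epoch_idx": epoch_idx, "acc": acc, "loss": 1.0 - acc}


class StubRecorder:
    def __init__(self):
        self.best_acc, self.best_epoch_idx, self.log = 0.0, 0, []

    def save_model(self, net, val_metrics):
        # Keep track of the best epoch; a real recorder would also save files.
        if val_metrics["acc"] >= self.best_acc:
            self.best_acc = val_metrics["acc"]
            self.best_epoch_idx = val_metrics["epoch_idx"]

    def report(self, train_metrics, val_metrics):
        self.log.append((train_metrics["epoch_idx"], val_metrics["acc"]))

    def summary(self):
        print("Best accuracy {:.2f} at epoch {:d}".format(
            100 * self.best_acc, self.best_epoch_idx))


recorder = StubRecorder()
run_training(StubTrainer(net="net"), StubEvaluator(), recorder,
             net="net", val_loader=None, num_epochs=3)
```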

Test stage:
The test stage is very similar to the train stage, so it is described only briefly.

  • Config: Change the network config so that the checkpoint points to the trained model and pretrained is set to true. You also need to write a config for the test pipeline.
  • Main code: Build a test pipeline, which is relatively simple. A test pipeline is expected to go through these steps:
    1. Initialize the logger.
    2. Initialize the ID and OOD dataloaders.
    3. Initialize the network.
    4. Initialize the postprocessor.
    5. Initialize the evaluator.
    6. Run eval_acc().
    7. Run eval_ood().
    To learn more about how to build a pipeline, please refer to Tutorial 1.
  • Script: We use a script to start the test process. The script loads the yml files that configure the test stage. Generally, you need to load the dataset, network, pipeline, preprocessor, and postprocessor configs.

@Zzitang
Collaborator

Zzitang commented Jun 4, 2022

2. Implement a method with a test stage only: take Gram as an example

  1. For each approach, we suggest starting from the pipeline. Gram detects out-of-distribution samples by comparing each value in the Gram matrix with its respective range observed over the training data. No training is required, so we can directly employ the TestOODPipeline.
from openood.datasets import get_dataloader, get_ood_dataloader
from openood.evaluators import get_evaluator
from openood.networks import get_network
from openood.postprocessors import get_postprocessor
from openood.utils import setup_logger


class TestOODPipeline:
    def __init__(self, config) -> None:
        self.config = config

    def run(self):
        # generate output directory and save the full config file
        setup_logger(self.config)

        # get dataloader
        id_loader_dict = get_dataloader(self.config)
        ood_loader_dict = get_ood_dataloader(self.config)

        # init network
        net = get_network(self.config.network)

        # init ood evaluator
        evaluator = get_evaluator(self.config)

        # init ood postprocessor
        postprocessor = get_postprocessor(self.config)
        # setup for distance-based methods
        postprocessor.setup(net, id_loader_dict, ood_loader_dict)
        print('\n', flush=True)
        print(u'\u2500' * 70, flush=True)

        # start calculating accuracy
        print('\nStart evaluation...', flush=True)
        acc_metrics = evaluator.eval_acc(net, id_loader_dict['test'],
                                         postprocessor)
        print('\nAccuracy {:.2f}%'.format(100 * acc_metrics['acc']),
              flush=True)
        print(u'\u2500' * 70, flush=True)

        # start evaluating ood detection methods
        evaluator.eval_ood(net, id_loader_dict, ood_loader_dict, postprocessor)
        print('Completed!', flush=True)
  2. In TestOODPipeline, the necessary components are the dataloaders, network, evaluator, and postprocessor. The OpenOOD codebase provides the OODEvaluator, which is suitable for most OOD detection methods. Therefore, all we need is a postprocessor that provides methods to calculate the necessary results, such as predictions and confidences.
  3. Each postprocessor extends the BasePostprocessor class, which provides three basic methods: setup(), inference(), and postprocess().
class GRAMPostprocessor(BasePostprocessor):
    def __init__(self, config):
        self.config = config
        self.postprocessor_args = config.postprocessor.postprocessor_args
        self.num_classes = self.config.dataset.num_classes
        self.powers = self.postprocessor_args.powers

        self.feature_min, self.feature_max = None, None
  • setup(): The setup() method takes both the ID and OOD dataloaders as input and performs any preparation needed for the later steps.
    The Gram method requires the range of the Gram matrices over the training data, in particular the maximum and minimum for each class, each layer of the network, and each order of the Gram matrix. In GRAMPostprocessor, we use sample_estimator() to calculate these ranges from the model and the train data.
    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict):

        self.feature_min, self.feature_max = sample_estimator(
            net, id_loader_dict['train'], self.num_classes, self.powers)
@torch.no_grad()
def sample_estimator(model, train_loader, num_classes, powers):

    model.eval()

    num_layer = 5
    num_poles_list = powers
    num_poles = len(num_poles_list)
    feature_class = [[[None for x in range(num_poles)]
                      for y in range(num_layer)] for z in range(num_classes)]
    label_list = []
    mins = [[[None for x in range(num_poles)] for y in range(num_layer)]
            for z in range(num_classes)]
    maxs = [[[None for x in range(num_poles)] for y in range(num_layer)]
            for z in range(num_classes)]

    # collect features and compute the Gram matrix
    for batch in tqdm(train_loader, desc='Compute min/max'):
        data = batch['data'].cuda()
        label = batch['label']
        _, feature_list = model(data, return_feature_list=True)
        label_list = tensor2list(label)
        for layer_idx in range(num_layer):

            for pole_idx, p in enumerate(num_poles_list):
                temp = feature_list[layer_idx].detach()

                temp = temp**p
                temp = temp.reshape(temp.shape[0], temp.shape[1], -1)
                temp = ((torch.matmul(temp,
                                      temp.transpose(dim0=2,
                                                     dim1=1)))).sum(dim=2)
                temp = (temp.sign() * torch.abs(temp)**(1 / p)).reshape(
                    temp.shape[0], -1)

                temp = tensor2list(temp)
                for feature, label in zip(temp, label_list):
                    if isinstance(feature_class[label][layer_idx][pole_idx],
                                  type(None)):
                        feature_class[label][layer_idx][pole_idx] = feature
                    else:
                        feature_class[label][layer_idx][pole_idx].extend(
                            feature)
    # compute mins/maxs
    for label in range(num_classes):
        for layer_idx in range(num_layer):
            for poles_idx in range(num_poles):
                feature = torch.tensor(
                    np.array(feature_class[label][layer_idx][poles_idx]))
                current_min = feature.min(dim=0, keepdim=True)[0]
                current_max = feature.max(dim=0, keepdim=True)[0]

                if mins[label][layer_idx][poles_idx] is None:
                    mins[label][layer_idx][poles_idx] = current_min
                    maxs[label][layer_idx][poles_idx] = current_max
                else:
                    mins[label][layer_idx][poles_idx] = torch.min(
                        current_min, mins[label][layer_idx][poles_idx])
                    maxs[label][layer_idx][poles_idx] = torch.max(
                        current_max, maxs[label][layer_idx][poles_idx])

    return mins, maxs
  • inference(): The inference() method provides the predictions, confidences, and original labels for the input dataloader, which are then used by the evaluators. It calls the postprocess() method to calculate the prediction and confidence for each sample. Usually a custom postprocessor does not need to override inference(). Therefore, in GRAMPostprocessor, we only override the setup() and postprocess() methods.
  • postprocess(): The postprocess() method calculates the prediction and confidence for each input sample.
    For Gram, we use get_deviations(), which calculates the deviation of the Gram matrix of each sample from the mins and maxs prepared in setup(). The deviations are used directly as the confidence.
    def postprocess(self, net: nn.Module, data: Any):
        preds, deviations = get_deviations(net, data, self.feature_min,
                                           self.feature_max, self.num_classes,
                                           self.powers)
        return preds, deviations
def get_deviations(model, data, mins, maxs, num_classes, powers):
    model.eval()

    num_layer = 5
    num_poles_list = powers
    exist = 1
    pred_list = []
    dev = [0 for x in range(200)]

    # get predictions
    logits, feature_list = model(data, return_feature_list=True)
    confs = F.softmax(logits, dim=1).cpu().detach().numpy()
    preds = np.argmax(confs, axis=1)
    predsList = preds.tolist()
    preds = torch.tensor(preds)

    for pred in predsList:
        exist = 1
        if len(pred_list) == 0:
            pred_list.extend([pred])
        else:
            for pred_now in pred_list:
                if pred_now == pred:
                    exist = 0
            if exist == 1:
                pred_list.extend([pred])

    # compute sample level deviation
    for layer_idx in range(num_layer):
        for pole_idx, p in enumerate(num_poles_list):
            # get Gram matrix
            temp = feature_list[layer_idx].detach()
            temp = temp**p
            temp = temp.reshape(temp.shape[0], temp.shape[1], -1)
            temp = ((torch.matmul(temp, temp.transpose(dim0=2,
                                                       dim1=1)))).sum(dim=2)
            temp = (temp.sign() * torch.abs(temp)**(1 / p)).reshape(
                temp.shape[0], -1)
            temp = tensor2list(temp)

            # compute the deviations with train data
            for idx in range(len(temp)):
                dev[idx] += (F.relu(mins[preds[idx]][layer_idx][pole_idx] -
                                    sum(temp[idx])) /
                             torch.abs(mins[preds[idx]][layer_idx][pole_idx] +
                                       10**-6)).sum()
                dev[idx] += (F.relu(
                    sum(temp[idx]) - maxs[preds[idx]][layer_idx][pole_idx]) /
                             torch.abs(maxs[preds[idx]][layer_idx][pole_idx] +
                                       10**-6)).sum()
    conf = [i / 50 for i in dev]

    return preds, torch.tensor(conf)
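The idea behind setup() and postprocess() can be followed on a tiny example without tensors. The sketch below computes order-p Gram features for one feature map (given as a list of channels), fits per-value ranges on "training" samples, and scores a test sample by how far it falls outside those ranges. It uses a plain element-wise comparison, which is a simplifying assumption and differs in detail from the code above. All function names and numbers are hypothetical.

```python
import math


def gram_features(feature, p):
    """Row sums of the order-p Gram matrix of one feature map.

    `feature` is a list of channels, each a flat list of activations.
    """
    powered = [[v ** p for v in channel] for channel in feature]
    sums = []
    for ci in powered:
        total = sum(sum(a * b for a, b in zip(ci, cj)) for cj in powered)
        # Signed p-th root, as in the tensor code above.
        sums.append(math.copysign(abs(total) ** (1 / p), total))
    return sums


def fit_ranges(train_features):
    """Element-wise min and max over a list of equally long feature vectors."""
    mins = [min(col) for col in zip(*train_features)]
    maxs = [max(col) for col in zip(*train_features)]
    return mins, maxs


def deviation(values, mins, maxs, eps=1e-6):
    """Total relative amount by which the values fall outside [min, max]."""
    total = 0.0
    for v, lo, hi in zip(values, mins, maxs):
        total += max(lo - v, 0.0) / abs(lo + eps)
        total += max(v - hi, 0.0) / abs(hi + eps)
    return total


# Two "training" feature maps with 2 channels and 2 spatial positions each.
train_maps = [[[1.0, 2.0], [3.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]]
mins, maxs = fit_ranges([gram_features(f, p=1) for f in train_maps])

inside = deviation(gram_features([[1.0, 1.5], [2.0, 0.5]], p=1), mins, maxs)
outside = deviation(gram_features([[4.0, 4.0], [4.0, 4.0]], p=1), mins, maxs)
print(mins, maxs, inside, round(outside, 3))
```

A sample whose Gram features stay inside the training ranges gets a deviation of zero, and the deviation grows as the features leave those ranges.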
  4. Now you can test the method with the correct config.
#!/bin/bash
# sh scripts/ood/gram/7_cifar_test_ood_gram.sh

GPU=1
CPU=1
node=36
jobname=openood

# export so that the python process below sees the variable
export PYTHONPATH='.':$PYTHONPATH
#srun -p dsta --mpi=pmi2 --gres=gpu:${GPU} -n1 \
#--cpus-per-task=${CPU} --ntasks-per-node=${GPU} \
#--kill-on-bad-exit=1 --job-name=${jobname} -w SG-IDC1-10-51-2-${node} \

python main.py \
--config configs/datasets/objects/cifar10.yml \
configs/datasets/objects/cifar10_ood.yml \
configs/networks/resnet18_32x32.yml \
configs/pipelines/test/test_gram.yml \
configs/postprocessors/gram.yml \
--dataset.image_size 32 \
--network.name resnet18_32x32 \
--num_workers 8
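The script passes several yml files plus dotted command-line overrides such as --dataset.image_size 32. How such inputs might combine into one nested config can be sketched as follows. This is an assumption about the general mechanism, with hypothetical function names (merge, parse_value, apply_overrides) and made-up starting values. The actual parsing in main.py may differ.

```python
import copy


def merge(base, extra):
    """Recursively merge `extra` into a copy of `base`; later values win."""
    out = copy.deepcopy(base)
    for key, value in extra.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], value)
        else:
            out[key] = copy.deepcopy(value)
    return out


def parse_value(raw):
    """Interpret a command-line string as int or float where possible."""
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_overrides(cfg, args):
    """Apply ['--a.b', 'value', ...] style overrides to a nested dict."""
    cfg = copy.deepcopy(cfg)
    for flag, raw in zip(args[0::2], args[1::2]):
        keys = flag.lstrip("-").split(".")
        node = cfg
        for key in keys[:-1]:
            node = node.setdefault(key, {})
        node[keys[-1]] = parse_value(raw)
    return cfg


# Stand-ins for the dicts loaded from two yml files.
dataset_cfg = {"dataset": {"name": "cifar10", "image_size": 28}}
network_cfg = {"network": {"name": "lenet"}}
merged = merge(dataset_cfg, network_cfg)
final = apply_overrides(merged, ["--dataset.image_size", "32",
                                 "--network.name", "resnet18_32x32",
                                 "--num_workers", "8"])
print(final)
```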
