In [14]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import datasets, models
import torchvision.transforms as TF
from collections import Counter
from tqdm import tqdm 



from typing import Any, Callable, Dict, List, Optional, Union, Tuple


def exclude_inds_2021(labels):
    exclude_lbls = [0,9,10,11]

    indices_all = np.arange(labels.shape[0])
    exclude_inds = []
    for lb in exclude_lbls:
        inds = np.where(labels==lb)[0].astype(int)
        exclude_inds.extend(inds.tolist())
    
    indices_keep = np.setdiff1d(indices_all, np.array(exclude_inds))
    
    return indices_keep




from torchvision import datasets, models, transforms
  

class WrapNaturalist(datasets.INaturalist):
    
    def __init__(self, root: str, version: str = "2021_train", target_type: Union[List[str], str] = "full", \
        transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False) -> None:
        super().__init__(root, version, target_type, transform, target_transform, download)
        
        self.target_transform = target_transform
        self.version = version
        
        if '19' in self.version:
            indices_path = root+'/'+'2019_valid_indices.npy'
        elif '21' in self.version:
            self.map_target_2021 = {1:0,2:1,3:2,4:3,5:4,6:5,7:6,8:7,12:8}
            self.target_transform =  transforms.Lambda(lambda x: self.map_target_2021[x])
            indices_path = root+'/'+'2021_valid_indices.npy'
            
        self.indices_path = indices_path
        self.indices_all = np.load(indices_path)
            
    def __len__(self):
        return self.indices_all.shape[0]
        
    def __getindex__(self, idx):
        
        idx = self.indices_all[idx]
        
        cat_id, _ = self.index[idx]

        target = []
        # target_name = []
        for t in self.target_type:
            if t == "full":
                target.append(cat_id)
            else:
                target.append(self.categories_map[cat_id][t])
                # if '21' in self.version:
                #     target_name.append(self.all_categories[cat_id].split('_')[2])
        target = tuple(target) if len(target) > 1 else target[0]

        if self.target_transform is not None:
            target = self.target_transform(target)

        
        return target
    
    
    def __getitem_wrap__(self, idx):
        
        idx = self.indices_all[idx]
        
        im, target = self.__getitem__(idx)
        
        # if self.target_transform is not None:
        #     target = self.target_transform(target)
            
        return im, target
    
    
    def __getlabels_2019__(self):
        
        self.labels = np.zeros((self.__len__(),)).astype(int)
        for i in range(self.__len__()):
            target = self.__getindex__(i)            
            self.labels[i]=target
            
        return self.labels
            
            
    def __getlabels_2021__(self):
        self.labels = np.zeros((self.__len__(),)).astype(int)
        for i in range(self.__len__()):
            target = self.__getindex__(i)            
            self.labels[i]=target
            # self.labels_map.append((target, name))
                    
        return self.labels
    
    

    
class inaturalist_normalize():
    def __init__(self):
        self.tf = TF.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    def __call__(self, img):
        return self.tf(img)


class inaturalist_train():
    """Pre-processing for inaturalist training.
    """
    # TODO maybe the tranform topilimage is messing up the normalization!
    def __init__(self):
        self.tf = TF.Compose([TF.Resize(256), TF.CenterCrop(224), TF.RandomHorizontalFlip(), TF.ToTensor(), 
                            #   inaturalist_normalize(),
                              ])
        # self.tf = TF.Compose([core50_normalize()])

    def __call__(self, img):
        return self.tf(img)
    
tapply = inaturalist_train()

    
dataroot = '/lab/arios/ProjIntel/incDFM/data/inaturalist21/'


dset_index = WrapNaturalist(dataroot, version='2021_train_mini', target_type=['phylum'], \
    transform=tapply, target_transform=None, download=False)
labels = dset_index.__getlabels_2021__()
    

# dset_index = WrapNaturalist(dataroot, version='2019', target_type=['super'], \
#     transform=tapply, target_transform=None, download=False)
# labels = dset_index.__getlabels_2019__()

Counter(labels), labels.shape

(Counter({0: 137600,
          1: 120800,
          2: 1300,
          3: 1050,
          4: 8450,
          5: 4200,
          6: 12850,
          7: 1750,
          8: 210900}),
 (498900,))

In [15]:


class inaturalistTask():
    def __init__(self, dataroot, dset_name='inaturalist19', tasklist='task_indices.npy', transform=None, \
        returnIDX=False, train=True, preload=False):
        """ 
        dataset for each individual task
        """
        self.preload = preload
        self.transform = transform
        

        self.transform_preload = transform
        self.dset_name = dset_name
        
        if self.dset_name == 'inaturalist19':
            self.dataset = WrapNaturalist(dataroot, version='2019', target_type=['super'], \
                                    transform=transform, download=False)
            _ = self.dataset.__getlabels_2019__()
        elif self.dset_name == 'inaturalist21':
            self.dataset = WrapNaturalist(dataroot, version='2021_train_mini', target_type=['phylum'], \
                                    transform=transform, download=False)
            _ = self.dataset.__getlabels_2021__()

        # have option of loading multiple tasklists if the case
        if isinstance(tasklist, list):
            self.indices_task_init=[]
            for l in tasklist:
                self.indices_task_init.append(np.load(l))
            self.indices_task_init = np.concatenate(self.indices_task_init)
        else:
            self.indices_task_init = np.load(tasklist)
            

        self.returnIDX = returnIDX
        
        self.indices_task = np.copy(self.indices_task_init)

            
    def __len__(self):
        return self.indices_task.shape[0]

    def select_random_subset(self, random_num):

        inds_keep = np.random.permutation(np.arange(self.indices_task_init.shape[0]))[:random_num]

        self.indices_task = self.indices_task_init[inds_keep]
        
    def select_specific_subset(self, indices_select):
        
        self.indices_task = self.indices_task_init[indices_select]
        
    def __getitem__(self, idx):
        
        idx = self.indices_task[idx]
        print('idx', idx)
        im, class_lbl = self.dataset.__getitem_wrap__(idx)

        # assert self.dataset.labels[idx] == class_lbl
        
        if self.returnIDX:
            return im, class_lbl, class_lbl, idx
            
        return im, class_lbl, class_lbl



indices_task = '/lab/arios/ProjIntel/incDFM/src/novelty_dfm_CL/Experiments_DFM_CL/inaturalist21/holdout_0.20_val_0.10/train/nc_train_task_0.npy'

dset_t = inaturalistTask(dataroot, dset_name='inaturalist21', tasklist=indices_task, transform=tapply)

In [16]:
dset_t.__getitem__(0)

idx 40552


(tensor([[[0.8588, 0.8745, 0.8745,  ..., 0.5137, 0.5412, 0.5608],
          [0.8588, 0.8627, 0.8588,  ..., 0.4824, 0.5176, 0.5294],
          [0.8667, 0.8549, 0.8392,  ..., 0.4667, 0.4784, 0.4863],
          ...,
          [0.7176, 0.6980, 0.6784,  ..., 0.3294, 0.3020, 0.3059],
          [0.7059, 0.6627, 0.6353,  ..., 0.2627, 0.2627, 0.2510],
          [0.6706, 0.6275, 0.6039,  ..., 0.2863, 0.2863, 0.2784]],
 
         [[0.8431, 0.8510, 0.8471,  ..., 0.4549, 0.4824, 0.4980],
          [0.8392, 0.8353, 0.8275,  ..., 0.4275, 0.4627, 0.4706],
          [0.8392, 0.8275, 0.8078,  ..., 0.4157, 0.4235, 0.4353],
          ...,
          [0.6784, 0.6588, 0.6471,  ..., 0.2549, 0.2392, 0.2431],
          [0.6667, 0.6235, 0.6000,  ..., 0.2118, 0.2196, 0.2118],
          [0.6275, 0.5882, 0.5647,  ..., 0.2510, 0.2549, 0.2471]],
 
         [[0.8000, 0.8039, 0.8000,  ..., 0.3922, 0.4235, 0.4392],
          [0.7961, 0.7882, 0.7765,  ..., 0.3686, 0.4039, 0.4118],
          [0.7961, 0.7843, 0.7608,  ...,

In [17]:
dset_index.__getitem_wrap__(40552)

(tensor([[[0.8588, 0.8745, 0.8745,  ..., 0.5137, 0.5412, 0.5608],
          [0.8588, 0.8627, 0.8588,  ..., 0.4824, 0.5176, 0.5294],
          [0.8667, 0.8549, 0.8392,  ..., 0.4667, 0.4784, 0.4863],
          ...,
          [0.7176, 0.6980, 0.6784,  ..., 0.3294, 0.3020, 0.3059],
          [0.7059, 0.6627, 0.6353,  ..., 0.2627, 0.2627, 0.2510],
          [0.6706, 0.6275, 0.6039,  ..., 0.2863, 0.2863, 0.2784]],
 
         [[0.8431, 0.8510, 0.8471,  ..., 0.4549, 0.4824, 0.4980],
          [0.8392, 0.8353, 0.8275,  ..., 0.4275, 0.4627, 0.4706],
          [0.8392, 0.8275, 0.8078,  ..., 0.4157, 0.4235, 0.4353],
          ...,
          [0.6784, 0.6588, 0.6471,  ..., 0.2549, 0.2392, 0.2431],
          [0.6667, 0.6235, 0.6000,  ..., 0.2118, 0.2196, 0.2118],
          [0.6275, 0.5882, 0.5647,  ..., 0.2510, 0.2549, 0.2471]],
 
         [[0.8000, 0.8039, 0.8000,  ..., 0.3922, 0.4235, 0.4392],
          [0.7961, 0.7882, 0.7765,  ..., 0.3686, 0.4039, 0.4118],
          [0.7961, 0.7843, 0.7608,  ...,