# ModelNet40 Few-Shot

> A dataset loader for ModelNet40 few-shot classification. We use the same dataset as [Point-BERT](https://github.com/lulutang0608/Point-BERT.git) and [Point-MAE](https://github.com/Pang-Yatian/Point-MAE.git).

In [None]:
#| default_exp datasets/modelnet_fewshot

The code to load the ModelNet40 dataset comes from the DGCNN repo. 
It has been altered so that it can store and load the data from a custom path. 
There is also an option to load only a specific class. 

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os
import numpy as np
import pickle

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from pclab.transforms import *

In [None]:
#| export

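class ModelNetFewShot(Dataset):
    'Dataset to access the few-shot classification data for the ModelNet40 dataset.'
    def __init__(self,
                 path,             # root directory that contains the `ModelNetFewshot` folder
                 split='train',    # either `train` or `test`
                 way=5,            # number of classes (5 or 10)
                 shot=10,          # number of training samples per class (10 or 20)
                 fold=0,           # index of the pre-sampled split to load (0-9)
                 transforms=None   # callables applied in order to the (N, 3) coordinates
                ):

        assert split in ['train', 'test'], 'Split should either be `train` or `test`'
        assert way in [5, 10], '`way` should be 5 or 10'
        assert shot in [10, 20], '`shot` should be 10 or 20'
        assert fold in range(10), '`fold` should be an integer between 0 and 9'

        # The data also contain normals; they are currently ignored (only xyz is returned)
        self.use_normals = False

        self.split = split
        self.way = way
        self.shot = shot
        self.fold = fold
        # Avoid a shared mutable default argument
        self.transforms = transforms if transforms is not None else []

        # e.g. <path>/ModelNetFewshot/5way_10shot/0.pkl
        self.pickle_path = os.path.join(path, 'ModelNetFewshot', f'{self.way}way_{self.shot}shot', f'{self.fold}.pkl')

        # Each pickle holds both splits; keep only the requested one
        with open(self.pickle_path, 'rb') as f:
            self.data = pickle.load(f)[self.split]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        points, label, _ = self.data[idx]

        # `points` also contains normals, but we only use the xyz coordinates
        points = points[:, :3]

        # Apply the transforms sequentially
        for t in self.transforms:
            points = t(points)

        return points, label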
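The loader reads a single pickle at `<path>/ModelNetFewshot/{way}way_{shot}shot/{fold}.pkl` and indexes it by split. For testing without the real data, a small synthetic file in that layout can be written as follows. The record format (a list of `(points, label, name)` tuples with `points` of shape `(N, 6)`, i.e. xyz followed by normals) is an assumption inferred from how `__getitem__` unpacks and slices each item, not from the official files; the helper `write_toy_fewshot_pickle` and its `n_query` parameter are hypothetical.

```python
import os
import pickle
import tempfile

import numpy as np


def write_toy_fewshot_pickle(root, way=5, shot=10, fold=0, n_points=1024, n_query=2, seed=0):
    """Hypothetical helper: write random data in the layout ModelNetFewShot reads.

    Assumed format (inferred from the loader): a dict with 'train' and 'test'
    keys, each a list of (points, label, name) tuples, where points has shape
    (n_points, 6): xyz coordinates followed by normals.
    """
    rng = np.random.default_rng(seed)

    def make_split(per_class):
        items = []
        for label in range(way):
            for i in range(per_class):
                points = rng.standard_normal((n_points, 6)).astype(np.float32)
                items.append((points, label, f'toy_{label}_{i}'))
        return items

    # 'train' gets `shot` samples per class; 'test' gets `n_query` (an arbitrary choice)
    data = {'train': make_split(shot), 'test': make_split(n_query)}

    # Mirror the directory structure built in ModelNetFewShot.__init__
    folder = os.path.join(root, 'ModelNetFewshot', f'{way}way_{shot}shot')
    os.makedirs(folder, exist_ok=True)
    pkl_path = os.path.join(folder, f'{fold}.pkl')
    with open(pkl_path, 'wb') as f:
        pickle.dump(data, f)
    return pkl_path


root = tempfile.mkdtemp()
pkl_path = write_toy_fewshot_pickle(root, way=5, shot=10, fold=0)

with open(pkl_path, 'rb') as f:
    train = pickle.load(f)['train']

points, label, name = train[0]
xyz = points[:, :3]  # same slicing as ModelNetFewShot.__getitem__
print(len(train), points.shape, xyz.shape)  # 50 (1024, 6) (1024, 3)
```

Pointing `ModelNetFewShot` at `root` with `way=5, shot=10, fold=0` should then give a training set of length 50, i.e. 5 classes × 10 samples.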

In [None]:
#| hide
path = "/home/ioannis/Desktop/programming/data"

In [None]:
#|eval: false
transforms=[RandomPointKeep(1024), RandomPointDropout(), UnitSphereNormalization(), AnisotropicScale(), ToTensor()]
dataset = ModelNetFewShot(path, 'train', way=5, shot=10, fold=5, transforms=transforms)
len(dataset)

50
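Since `transforms` is just a list of callables applied in order to the `(N, 3)` coordinate array, any function or object with a `__call__` method can be used. Below is a minimal NumPy sketch of that contract. `RandomSubsample`, `unit_sphere_normalize`, and `apply_transforms` are hypothetical stand-ins written for illustration; they are not the `pclab.transforms` implementations. Because the example above puts `ToTensor()` last, NumPy-based callables like these presumably belong before it.

```python
import numpy as np


class RandomSubsample:
    """Hypothetical transform: keep `n` randomly chosen points (without replacement)."""
    def __init__(self, n, seed=None):
        self.n = n
        self.rng = np.random.default_rng(seed)

    def __call__(self, points):
        # Raises ValueError if the cloud has fewer than `n` points
        idx = self.rng.choice(len(points), size=self.n, replace=False)
        return points[idx]


def unit_sphere_normalize(points):
    """Hypothetical transform: center at the centroid, scale the farthest point to radius 1."""
    centered = points - points.mean(axis=0)
    scale = np.linalg.norm(centered, axis=1).max()
    return centered / scale


def apply_transforms(points, transforms):
    # Same sequential loop as ModelNetFewShot.__getitem__
    for t in transforms:
        points = t(points)
    return points


rng = np.random.default_rng(0)
raw = rng.standard_normal((2048, 3)).astype(np.float32)
out = apply_transforms(raw, [RandomSubsample(1024, seed=0), unit_sphere_normalize])
print(out.shape)  # (1024, 3)
print(np.isclose(np.linalg.norm(out, axis=1).max(), 1.0))  # True
```

Subsampling before normalizing means the centroid and radius are computed on the points that are actually kept.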