### Custom Dataset for a Grid World

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class GridWorldDataset(Dataset):
    """Map-style dataset of randomly generated grid worlds.

    Each cell is 0 (empty space) or 1 (obstacle). A new grid is sampled on
    every access, so repeated access to the same index generally yields a
    different grid.

    Args:
        num_grids: value reported by ``len()``, i.e. the number of samples.
        grid_size: shape of each grid; ``(10, 10)`` by default.
        obstacle_prob: probability that a cell is an obstacle. The default
            0.2 gives 80% empty cells and 20% obstacles.

    Raises:
        ValueError: if ``obstacle_prob`` is outside ``[0, 1]``.
    """

    def __init__(self, num_grids, grid_size=(10, 10), obstacle_prob=0.2):
        if not 0.0 <= obstacle_prob <= 1.0:
            raise ValueError(f"obstacle_prob must be in [0, 1], got {obstacle_prob}")
        self.num_grids = num_grids
        self.grid_size = grid_size
        self.obstacle_prob = obstacle_prob

    def __len__(self):
        return self.num_grids

    def __getitem__(self, index):
        # `index` is unused: samples are generated procedurally.
        # No transform is applied; a `transform` argument could be added separately.
        grid = np.random.choice(
            [0, 1],
            size=self.grid_size,
            p=[1.0 - self.obstacle_prob, self.obstacle_prob],
        )
        return torch.tensor(grid, dtype=int)

Adding complexity by including a "goal" cell in each grid

In [4]:
class GridWorldDataset(Dataset):
    """Map-style dataset of random grid worlds, each with a single goal cell.

    Cell values: 0 = empty space, 1 = obstacle, 2 = goal. A new grid and goal
    are sampled on every access, so repeated access to the same index
    generally yields a different sample.

    Args:
        num_grids: value reported by ``len()``, i.e. the number of samples.
        grid_size: shape of each grid; ``(10, 10)`` by default.
        obstacle_prob: probability that a cell is an obstacle. The default
            0.2 gives 80% empty cells and 20% obstacles.

    Raises:
        ValueError: if ``obstacle_prob`` is outside ``[0, 1]``.
    """

    def __init__(self, num_grids, grid_size=(10, 10), obstacle_prob=0.2):
        if not 0.0 <= obstacle_prob <= 1.0:
            raise ValueError(f"obstacle_prob must be in [0, 1], got {obstacle_prob}")
        self.num_grids = num_grids
        self.grid_size = grid_size
        self.obstacle_prob = obstacle_prob

    def __len__(self):
        return self.num_grids

    def __getitem__(self, index):
        """Return ``{'grid': Tensor, 'goal': Tensor}`` for a freshly sampled world.

        ``goal`` holds one coordinate per grid axis (``[row, col]`` for a 2-D
        grid), and ``grid[goal]`` is 2. The goal may overwrite an obstacle.
        ``index`` is unused because samples are generated procedurally.
        """
        grid = np.random.choice(
            [0, 1],
            size=self.grid_size,
            p=[1.0 - self.obstacle_prob, self.obstacle_prob],
        )
        # Draw one coordinate per axis (so non-square grids work) and mark a
        # single cell. Indexing with a tuple addresses one element; indexing
        # with an array would select whole rows instead.
        goal = np.array([np.random.randint(dim) for dim in grid.shape])
        grid[tuple(goal)] = 2
        return {
            'grid': torch.tensor(grid, dtype=int),
            'goal': torch.tensor(goal, dtype=int),
        }

Creating a DataLoader for the dataset above

In [5]:
# 1000 grid samples of size 10x10. Whichever GridWorldDataset cell ran last is
# the class used here; the default collate_fn handles both its tensor and its
# dict samples.
dataset = GridWorldDataset(1000, grid_size=(10, 10))
dataloader = DataLoader(
    dataset,
    batch_size=100,
    num_workers=4,
)

`torch.utils.data.DataLoader` is an iterable that takes care of batching, optional shuffling, and parallel loading (via `num_workers`). The parameters used above should be self-explanatory. One parameter of interest is `collate_fn`, which specifies exactly how individual samples are combined into a batch. The default collate function works fine for most use cases.<br/>
<br/>
`collate_fn` - *Think of it as the glue that specifies how examples stick together in a batch. If you don't provide one, PyTorch simply puts `batch_size` examples together, roughly as `torch.stack` would (not exactly, but that is the gist).*


In [6]:
# custom collate function
def collate_fn(data):  # data is a list of values returned by the dataset's __getitem__
    """Batch goal-aware grid-world samples into stacked tensors.

    Args:
        data (list): List of dictionaries with 'grid' and 'goal' keys,
            where each dictionary represents a single sample.

    Returns:
        dict: ``{'grids': Tensor, 'goals': Tensor}``. Each value stacks the
        per-sample tensors along a new leading batch dimension.
    """
    # torch.stack requires a sequence (list/tuple) of tensors; passing a bare
    # generator expression raises a TypeError.
    grids = torch.stack([item['grid'] for item in data])
    goals = torch.stack([item['goal'] for item in data])
    return {
        'grids': grids,
        'goals': goals,
    }

In [7]:
# DataLoader that batches with the custom collate_fn defined above.
# NOTE(review): with num_workers > 0 on platforms that start workers via
# 'spawn' (e.g. Windows/macOS), the dataset and collate_fn must be picklable;
# objects defined in a notebook may fail there. If so, use num_workers=0.
dataloader = DataLoader(
    dataset,
    batch_size=100,
    num_workers=4,
    collate_fn=collate_fn,
)