In [1]:

import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
import random

In [None]:

# Load the CIFAR-10 training split, converting every image to a tensor.
# With download=True the files are fetched into ./data when not already present.
transform = transforms.Compose(
    [transforms.ToTensor()]
)
cifar10 = CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

class PatchDataset(Dataset):
    """Relative-position prediction dataset built on top of an image dataset.

    Each item pairs the centre patch of an image with one of its eight
    neighbouring patches on a 3x3 grid whose cells are ``stride`` pixels
    apart. The neighbour is drawn uniformly at random on every access. The
    label (0-7) says which neighbour was chosen; see ``position_labels``.

    Args:
        cifar_data: indexable dataset yielding ``(image, target)`` pairs.
            ``image`` is sliced as ``image[:, rows, cols]`` and sized via
            ``image.size(1)`` / ``image.size(2)``. It is presumably a
            ``(C, H, W)`` tensor as produced by ``ToTensor``; confirm for
            other datasets. The target is ignored.
        patch_size: side length in pixels of each square patch.
        stride: distance in pixels between the top-left corners of adjacent
            grid cells. When ``stride < patch_size`` (as with the defaults),
            neighbouring patches overlap.
            NOTE(review): overlapping patches may let a model solve the task
            by matching shared pixels; consider ``stride > patch_size``.
    """

    def __init__(self, cifar_data, patch_size=10, stride=8):
        self.cifar_data = cifar_data
        self.patch_size = patch_size
        self.stride = stride
        # (row, col) on the 3x3 grid -> class label. (1, 1) is the centre
        # (reference) patch itself, so it has no label.
        self.position_labels = {
            (0, 0): 0, (0, 1): 1, (0, 2): 2,
            (1, 0): 3,            (1, 2): 4,
            (2, 0): 5, (2, 1): 6, (2, 2): 7,
        }

    def __len__(self):
        return len(self.cifar_data)

    def extract_patch(self, img, top, left):
        """Return the ``patch_size`` x ``patch_size`` patch whose top-left corner is (top, left).

        The leading dimension (presumably channels) is kept whole.
        """
        return img[:, top:top + self.patch_size, left:left + self.patch_size]

    def _check_grid_fits(self, height, width, center_top, center_left):
        """Raise ValueError if any cell of the 3x3 patch grid leaves the image."""
        # Without this check, a negative offset would go through Python's
        # negative-index slicing and overshoot would be silently truncated.
        # Either way the patch would be empty or misplaced, with no error.
        tops = [center_top + dy * self.stride for dy in (-1, 0, 1)]
        lefts = [center_left + dx * self.stride for dx in (-1, 0, 1)]
        if (min(tops) < 0 or max(tops) + self.patch_size > height
                or min(lefts) < 0 or max(lefts) + self.patch_size > width):
            raise ValueError(
                f"3x3 grid of {self.patch_size}px patches with stride "
                f"{self.stride} does not fit in a {height}x{width} image"
            )

    def __getitem__(self, idx):
        """Return ``(reference_patch, neighbor_patch, label)`` for image ``idx``.

        The neighbour position is drawn with the ``random`` module on every
        call, so repeated access to the same index can yield different items.

        Raises:
            ValueError: if the 3x3 patch grid does not fit inside the image.
        """
        img, _ = self.cifar_data[idx]
        # Handle height and width separately so non-square images are
        # centred correctly on both axes.
        height, width = img.size(1), img.size(2)

        # Top-left corner of the centred reference patch.
        center_top = (height - self.patch_size) // 2
        center_left = (width - self.patch_size) // 2
        self._check_grid_fits(height, width, center_top, center_left)
        reference_patch = self.extract_patch(img, center_top, center_left)

        # Randomly pick one of the eight neighbouring grid cells.
        rel_pos = random.choice(list(self.position_labels.keys()))
        # Grid (row, col) in {0, 1, 2} -> offset in {-1, 0, +1} from the centre.
        offset_y, offset_x = rel_pos[0] - 1, rel_pos[1] - 1
        neighbor_top = center_top + offset_y * self.stride
        neighbor_left = center_left + offset_x * self.stride
        neighbor_patch = self.extract_patch(img, neighbor_top, neighbor_left)

        label = self.position_labels[rel_pos]
        return reference_patch, neighbor_patch, label

# DataLoader over the patch-pair dataset: batches of 128
# (reference_patch, neighbor_patch, label) triples, shuffled each epoch.
patch_dataset = PatchDataset(cifar10)
patch_loader = torch.utils.data.DataLoader(patch_dataset, batch_size=128, shuffle=True)

# Sanity check: compare the full image shape with one extracted patch.
image, label = cifar10[0]
print(image.shape)
# Index the dataset rather than calling __getitem__ directly. The neighbour
# is random, but the reference patch (element 0) always has the same shape.
patch_dataset[0][0].shape




