In [3]:
import os
os.chdir("../")
os.getcwd()

'e:\\github_clone\\siamese-network\\src'

In [1]:
import os
import random
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset

In [2]:
class SiameseDataset(Dataset):
    def __init__(self, image_paths, transform=None, train=True):
        """
        Args:
            image_paths (list of str): List of image file paths.
            transform (callable, optional): Optional transform to be applied on a sample.
            train (bool): Flag to indicate if the dataset is for training or testing.
        """
        self.image_paths = image_paths
        self.transform = transform
        self.train = train
        
        # Extract labels from paths
        self.labels = [self._get_label_from_path(path) for path in self.image_paths]
        
        # Create a mapping from labels to image paths
        self.labels_set = set(self.labels)
        self.label_to_indices = {label: np.where(np.array(self.labels) == label)[0] for label in self.labels_set}
        
        if not self.train:
            # Generate fixed pairs for testing
            random_state = np.random.RandomState(29)
            
            positive_pairs = [[i, random_state.choice(self.label_to_indices[self.labels[i]]), 1]
                              for i in range(0, len(self.image_paths), 2)]
            
            negative_pairs = [[i, random_state.choice(self.label_to_indices[
                                                       random_state.choice(list(self.labels_set - {self.labels[i]}))
                                                     ]), 0]
                              for i in range(1, len(self.image_paths), 2)]
            
            self.test_pairs = positive_pairs + negative_pairs

    def _get_label_from_path(self, path):
        return path.split(os.sep)[-2]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        if self.train:
            target = np.random.randint(0, 2)
            img1_path, label1 = self.image_paths[index], self.labels[index]
            
            if target == 1:
                siamese_index = index
                while siamese_index == index:
                    siamese_index = np.random.choice(self.label_to_indices[label1])
            else:
                siamese_label = np.random.choice(list(self.labels_set - {label1}))
                siamese_index = np.random.choice(self.label_to_indices[siamese_label])
            
            img2_path = self.image_paths[siamese_index]
        else:
            img1_path = self.image_paths[self.test_pairs[index][0]]
            img2_path = self.image_paths[self.test_pairs[index][1]]
            target = self.test_pairs[index][2]

        img1 = Image.open(img1_path)
        img2 = Image.open(img2_path)
        
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        
        return (img1, img2), target

In [42]:
class SiameseDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        """
        Args:
            image_paths (list of str): List of image file paths.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.image_paths = image_paths
        self.transform = transform

        # Extract labels from paths
        self.labels = [self._get_label_from_path(path) for path in self.image_paths]

        # Create a mapping from labels to image paths
        self.labels_set = set(self.labels)
        self.label_to_indices = {label: np.where(np.array(self.labels) == label)[0] for label in self.labels_set}

    def _get_label_from_path(self, path):
        return path.split(os.sep)[-2]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        img1_path, label1 = self.image_paths[index], self.labels[index]
        
        # Randomly decide if this pair is positive or negative
        target = np.random.randint(0, 2)
        
        if target == 1:
            # Positive pair: choose a different image with the same label
            siamese_index = index
            while siamese_index == index:
                siamese_index = np.random.choice(self.label_to_indices[label1])
        else:
            # Negative pair: choose an image with a different label
            siamese_label = np.random.choice(list(self.labels_set - {label1}))
            siamese_index = np.random.choice(self.label_to_indices[siamese_label])
        
        img2_path = self.image_paths[siamese_index]

        # Load images
        img1 = Image.open(img1_path).convert('RGB')
        img2 = Image.open(img2_path).convert('RGB')
        
        # Apply transformations if any
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        
        return (img1, img2), target

In [43]:
from glob import glob
imges = sorted(glob("E:\\github_clone\\siamese-network\\src\\data\\*\\*"))

In [44]:
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

In [45]:
dataset = SiameseDataset(image_paths=imges, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)


In [47]:
for batch_idx, (images, targets) in enumerate(dataloader):
    img1, img2 = images
    print(f"Batch {batch_idx}:")
    print(f"Image 1: {img1.shape}, Image 2: {img2.shape}, Target: {targets}")

KeyboardInterrupt: 