In [7]:
import math
import random
from collections.abc import Mapping, Sequence as ABCSequence
from typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized, Callable, Self

import numpy as np
import torch
import torch.utils.data

In [None]:
def shuffle(arr, rng: Optional[random.Random] = None):
    """Shuffle ``arr`` in place with the Fisher-Yates algorithm and return it.

    Args:
        arr: Mutable, indexable sequence (e.g. a list). It is modified in place.
        rng: Optional ``random.Random`` instance used to draw the swap positions.
            When omitted, the module-level ``random`` generator is used, which
            matches the behaviour of calling ``shuffle(arr)`` with one argument.

    Returns:
        The same ``arr`` object, now shuffled.
    """
    draw = random.randint if rng is None else rng.randint
    # Walk from the last slot down to index 1; slot 0 has nothing left to swap with.
    for i in range(len(arr) - 1, 0, -1):
        # randint is inclusive on both ends, so an element may stay where it is.
        j = draw(0, i)
        arr[i], arr[j] = arr[j], arr[i]
    return arr

In [18]:
# A list is an iterable; iter() asks it for a fresh iterator object over its items.
my_list = [1, 2, 3]
my_iterator = iter(my_list)

In [20]:
# The for loop pulls items from the iterator until it raises StopIteration.
# The iterator is exhausted afterwards, so re-running only this cell prints nothing.
for i in my_iterator:
    print(i)

1
2
3


In [5]:
def _stack_collate(samples):
    """Default collation: stack the samples (assumed tensors) along a new dim 0."""
    return torch.stack(samples, 0)


class CustomDataLoader:  # (Generic[T_co])
    """Minimal single-process data loader for map-style datasets.

    The dataset only needs ``__len__`` and ``__getitem__``. Every call to
    ``iter(loader)`` builds a fresh batch schedule (optionally shuffled) and
    hands it to a new ``_DataLoaderIter``.

    Args:
        dataset: Object supporting ``len(dataset)`` and ``dataset[idx]``.
        batch_size: Number of samples per batch; must be positive. The last
            batch is smaller when the dataset length is not a multiple of it.
        shuffle: If True, a new random permutation is drawn on every ``iter``.
        collate_fn: Callable turning a list of samples into one batch. When
            omitted, samples are combined with ``torch.stack(samples, 0)``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset[torch.Tensor],  # torch.utils.data.Dataset[T_co]
        batch_size: int = 1,
        shuffle: bool = False,
        collate_fn: Optional[Callable] = None,
    ):
        # Fail early: zero would divide by zero below and a negative value
        # would produce a negative batch count.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.collate_fn = collate_fn if collate_fn is not None else _stack_collate
        self.dataset_len = len(dataset)
        self.num_batches = math.ceil(self.dataset_len / self.batch_size)
        # Most recently generated schedule, kept for inspection only; live
        # iterators hold their own reference and never read this attribute.
        self.batches = None

    def __iter__(self) -> Iterator:  # Iterator[List[T_co]]
        """Build a new batch schedule and return an iterator over its batches."""
        if self.shuffle:
            indices = torch.randperm(self.dataset_len).tolist()
        else:
            indices = list(range(self.dataset_len))
        # indices = list(self.sampler)
        batches = [
            indices[start:start + self.batch_size]
            for start in range(0, self.dataset_len, self.batch_size)
        ]
        self.batches = batches
        # Hand the schedule to the iterator so that a later iter() call on the
        # same loader cannot change what an already-running iterator yields.
        return _DataLoaderIter(self, batches)

    def __len__(self) -> int:
        """Return the number of batches produced per pass over the dataset."""
        return self.num_batches


class _DataLoaderIter:
    """Iterator over one batch schedule of a ``CustomDataLoader``.

    Args:
        loader: The loader that supplies the dataset and the collate function.
        batches: List of index lists, one per batch. When omitted, the
            loader's current ``batches`` attribute is used.
    """

    def __init__(self, loader: CustomDataLoader, batches: Optional[List[List[int]]] = None):
        self.loader = loader
        self.batches = batches if batches is not None else loader.batches
        self.current_batch = 0

    def __iter__(self) -> "_DataLoaderIter":
        return self

    def __next__(self):  # List[T_co]
        """Fetch the samples of the next batch and collate them.

        Raises:
            StopIteration: When every batch of the schedule has been yielded.
        """
        if self.current_batch >= len(self.batches):
            raise StopIteration
        batch_indices = self.batches[self.current_batch]
        self.current_batch += 1
        batch_data = [self.loader.dataset[idx] for idx in batch_indices]
        return self.loader.collate_fn(batch_data)

class TensorDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves items straight from an indexable container."""

    def __init__(self, data):
        # Only a reference is stored; the container is not copied.
        self.data = data

    def __len__(self):
        """Return ``len`` of the wrapped container."""
        size = len(self.data)
        return size

    def __getitem__(self, index):
        """Return whatever the wrapped container yields for ``index``."""
        item = self.data[index]
        return item
        
# Demo: 100 random samples with 5 features each, wrapped in the TensorDataset above.
data = torch.randn(100, 5)
dataset = TensorDataset(data)
# 100 samples at batch_size=16 -> ceil(100 / 16) = 7 batches: six of 16 rows, then one of 4.
dataloader = CustomDataLoader(dataset, batch_size=16, shuffle=True)
for batch in dataloader:
    print(f"Batch shape: {batch.shape}")

Batch shape: torch.Size([16, 5])
Batch shape: torch.Size([16, 5])
Batch shape: torch.Size([16, 5])
Batch shape: torch.Size([16, 5])
Batch shape: torch.Size([16, 5])
Batch shape: torch.Size([16, 5])
Batch shape: torch.Size([4, 5])


In [None]:
def default_collate(batch: Sequence):
    """Merge a list of samples into a single batch, based on the first sample's type.

    Dispatch rules (checked in this order against ``batch[0]``):
        * ``torch.Tensor``: samples are stacked along a new leading dimension.
        * ``str`` / ``bytes``: the batch is returned unchanged.
        * mapping: a dict with the first sample's keys, each value being the
          recursively collated list of that key's values.
        * other sequence (e.g. tuple, list): samples are transposed with
          ``zip`` and each field is collated recursively; a list is returned.
        * anything else: the batch is returned unchanged.

    Args:
        batch: Non-empty sequence of samples of the same type.

    Returns:
        The collated batch as described above.

    Raises:
        IndexError: If ``batch`` is empty (there is no first sample to inspect).
    """
    elem = batch[0]
    if isinstance(elem, torch.Tensor):
        return torch.stack(batch, 0)
    if isinstance(elem, (str, bytes)):
        # Strings are sequences too; returning early keeps them out of the
        # sequence branch below, which would split them into characters.
        return batch
    if isinstance(elem, Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    if isinstance(elem, ABCSequence):
        # zip(*batch) regroups the samples field by field.
        return [default_collate(samples) for samples in zip(*batch)]
    return batch

In [None]:
class Sampler:  # Generic[T_co]
    """Base class for objects that yield dataset indices.

    Subclasses must override ``__iter__``; ``__len__`` defaults to the length
    of the data source.

    Args:
        data_source: Sized object whose indices are to be sampled.
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        raise NotImplementedError

    def __len__(self) -> int:
        return len(self.data_source)


class SequentialSampler(Sampler):
    """Yields the indices ``0 .. len(data_source) - 1`` in ascending order."""

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))


class RandomSampler(Sampler):
    """Yields every index of the data source once, in a random order.

    A new permutation is drawn with ``torch.randperm`` on each ``iter`` call.
    """

    def __iter__(self) -> Iterator[int]:
        return iter(torch.randperm(len(self.data_source)).tolist())

In [11]:
# Stand-alone version of the seeded-shuffle snippet: the values that would come
# from a loader instance (its seed and dataset length) are explicit constants
# here so the cell runs on its own.
SEED = 42
DATASET_LEN = 100

# A dedicated generator keeps the permutation reproducible without touching
# torch's global random state.
generator = torch.Generator()
generator.manual_seed(SEED)
indices = torch.randperm(DATASET_LEN, generator=generator)

In [28]:
# import torch
# import math
# from typing import Iterator, Optional, List, TypeVar, Generic, Callable, Iterable
# from collections.abc import Mapping, Sequence

# T_co = TypeVar('T_co', covariant=True)

# class Sampler(Generic[T_co]):
#     """
#     Base class for all Samplers.
    
#     Every Sampler subclass has to provide an __iter__ method, providing a
#     way to iterate over indices of dataset elements, and a __len__ method
#     that returns the length of the returned iterators.
#     """
    
#     def __init__(self, data_source):
#         """
#         Initialize Sampler.
        
#         Args:
#             data_source: Dataset to sample from
#         """
#         self.data_source = data_source
    
#     def __iter__(self) -> Iterator[int]:
#         """
#         Return an iterator over the indices of the dataset.
        
#         Returns:
#             Iterator over indices
#         """
#         raise NotImplementedError
    
#     def __len__(self) -> int:
#         """
#         Return the length of the dataset.
        
#         Returns:
#             Number of samples in the dataset
#         """
#         return len(self.data_source)

# class SequentialSampler(Sampler[T_co]):
#     """
#     Samples elements sequentially, always in the same order.
#     """
    
#     def __iter__(self) -> Iterator[int]:
#         return iter(range(len(self.data_source)))

# class RandomSampler(Sampler[T_co]):
#     """
#     Samples elements randomly, without replacement.
#     """
    
#     def __iter__(self) -> Iterator[int]:
#         return iter(torch.randperm(len(self.data_source)).tolist())

# class CustomDataLoader(Generic[T_co]):
#     """
#     Simplified Custom DataLoader implementation for PyTorch with sampler support.
    
#     This DataLoader works with map-style datasets that implement __getitem__ and __len__ methods.
#     It supports sampling strategies via a sampler argument.
#     """
    
#     def __init__(
#         self,
#         dataset: torch.utils.data.Dataset[T_co],
#         batch_size: int = 1,
#         shuffle: bool = False,
#         sampler: Optional[Sampler] = None,
#         collate_fn: Optional[Callable] = None
#     ):
#         """
#         Initialize CustomDataLoader.
        
#         Args:
#             dataset: Dataset from which to load data
#             batch_size: Number of samples per batch
#             shuffle: Whether to shuffle the data (ignored if sampler is specified)
#             sampler: Strategy to draw samples from the dataset
#             collate_fn: Function to merge samples into batches
#         """
#         self.dataset = dataset
#         self.batch_size = batch_size
#         self.collate_fn = collate_fn if collate_fn is not None else default_collate
        
#         # Make sure sampler and shuffle are mutually exclusive
#         if sampler is not None and shuffle:
#             raise ValueError("sampler and shuffle cannot be used together")
        
#         if sampler is None:
#             if shuffle:
#                 self.sampler = RandomSampler(dataset)
#             else:
#                 self.sampler = SequentialSampler(dataset)
#         else:
#             self.sampler = sampler
        
#         # Calculate the number of samples and batches
#         self.dataset_len = len(dataset)
#         self.num_batches = math.ceil(self.dataset_len / self.batch_size)
    
#     def __iter__(self) -> Iterator[List[T_co]]:
#         """
#         Create a new iterator over the dataset.
        
#         Returns:
#             Iterator over batches of data
#         """
#         # Get indices from sampler
#         indices = list(self.sampler)
        
#         # Create batches
#         batches = []
#         for i in range(0, len(indices), self.batch_size):
#             batch_indices = indices[i:i + self.batch_size]
#             batches.append(batch_indices)
        
#         return _DataLoaderIter(self, batches)
    
#     def __len__(self) -> int:
#         """
#         Return the number of batches in the dataset.
        
#         Returns:
#             Number of batches
#         """
#         return self.num_batches

# class _DataLoaderIter:
#     """
#     Iterator class for the DataLoader.
#     """
    
#     def __init__(self, loader: CustomDataLoader, batches: List[List[int]]):
#         """
#         Initialize the iterator.
        
#         Args:
#             loader: The DataLoader instance
#             batches: List of batch indices
#         """
#         self.loader = loader
#         self.batches = batches
#         self.current_batch = 0
    
#     def __iter__(self) -> '_DataLoaderIter':
#         return self
    
#     def __next__(self) -> List[T_co]:
#         if self.current_batch >= len(self.batches):
#             raise StopIteration
        
#         # Get current batch indices
#         batch_indices = self.batches[self.current_batch]
#         self.current_batch += 1
        
#         # Load data for this batch
#         batch_data = [self.loader.dataset[idx] for idx in batch_indices]
        
#         # Apply collation function
#         batch = self.loader.collate_fn(batch_data)
        
#         return batch

# def default_collate(batch):
#     """
#     Default collation function that converts a list of samples to a batch.
    
#     Args:
#         batch: List of samples

#     Returns:
#         Collated batch
#     """
#     elem = batch[0]
#     if isinstance(elem, torch.Tensor):
#         return torch.stack(batch, 0)
#     elif isinstance(elem, (str, bytes)):
#         return batch
#     elif isinstance(elem, Mapping):
#         return {key: default_collate([d[key] for d in batch]) for key in elem}
#     elif isinstance(elem, Sequence) and not isinstance(elem, (str, bytes)):
#         transposed = zip(*batch)
#         return [default_collate(samples) for samples in transposed]
#     else:
#         return batch

# # Example of a custom sampler: weighted sampling
# class WeightedRandomSampler(Sampler):
#     """
#     Samples elements randomly according to given weights.
    
#     Args:
#         weights: a sequence of weights, not necessarily summing up to 1
#         num_samples: number of samples to draw
#         replacement: if True, sampling is done with replacement
#     """
    
#     def __init__(self, weights, num_samples, replacement=True):
#         super().__init__(None)  # We don't actually need the data source
#         self.weights = torch.as_tensor(weights, dtype=torch.double)
#         self.num_samples = num_samples
#         self.replacement = replacement
    
#     def __iter__(self):
#         return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
    
#     def __len__(self):
#         return self.num_samples

# # Example usage:
# if __name__ == "__main__":
#     # Create a simple tensor dataset
#     class TensorDataset(torch.utils.data.Dataset):
#         def __init__(self, data):
#             self.data = data
        
#         def __getitem__(self, index):
#             return self.data[index]
        
#         def __len__(self):
#             return len(self.data)
    
#     # Create a dataset with 100 samples
#     data = torch.randn(100, 5)
#     dataset = TensorDataset(data)
    
#     # Example 1: Using shuffle parameter
#     dataloader1 = CustomDataLoader(dataset, batch_size=16, shuffle=True)
    
#     # Example 2: Using a RandomSampler explicitly
#     sampler = RandomSampler(dataset)
#     dataloader2 = CustomDataLoader(dataset, batch_size=16, sampler=sampler)
    
#     # Example 3: Using a WeightedRandomSampler to sample with weights
#     # Create weights that favor sampling from the first half of the dataset
#     weights = torch.ones(len(dataset))
#     weights[:len(dataset)//2] *= 2  # First half has double the weight
#     weighted_sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
#     dataloader3 = CustomDataLoader(dataset, batch_size=16, sampler=weighted_sampler)
    
#     # Iterate through the DataLoader with weighted sampling
#     print("Using weighted sampling:")
#     for i, batch in enumerate(dataloader3):
#         if i < 3:  # Just print the first 3 batches
#             print(f"Batch {i} shape: {batch.shape}")

In [None]:
# T_co = TypeVar('T_co', covariant=True)
# T = TypeVar('T')

# class CustomDataLoader(Generic[T_co]):
#     """
#     Custom DataLoader implementation for PyTorch.
    
#     This DataLoader works with map-style datasets that implement __getitem__ and __len__ methods.
#     It supports shuffling, batching, collation, and worker processes for parallel data loading.
#     """
    
#     def __init__(
#         self,
#         dataset: torch.utils.data.Dataset[T_co],
#         batch_size: Optional[int] = 1,
#         shuffle: bool = False,
#         num_workers: int = 0,
#         collate_fn: Optional[Callable] = None,
#         drop_last: bool = False,
#         pin_memory: bool = False,
#         timeout: float = 0,
#         worker_init_fn: Optional[Callable] = None,
#         prefetch_factor: int = 2,
#         persistent_workers: bool = False,
#         seed: Optional[int] = None
#     ):
#         """
#         Initialize CustomDataLoader.
        
#         Args:
#             dataset: Dataset from which to load data
#             batch_size: Number of samples per batch
#             shuffle: Whether to shuffle the data
#             num_workers: Number of subprocesses for data loading
#             collate_fn: Function to merge samples into batches
#             drop_last: Whether to drop the last incomplete batch
#             pin_memory: Whether to copy tensors to CUDA pinned memory
#             timeout: Timeout for collecting a batch
#             worker_init_fn: Function to initialize worker processes
#             prefetch_factor: Number of batches loaded in advance by each worker
#             persistent_workers: Whether to keep worker processes alive after iteration
#             seed: Random seed for shuffling
#         """
#         self.dataset = dataset
#         self.batch_size = batch_size if batch_size is not None else 1
#         self.shuffle = shuffle
#         self.num_workers = num_workers
#         self.collate_fn = collate_fn if collate_fn is not None else default_collate
#         self.drop_last = drop_last
#         self.pin_memory = pin_memory
#         self.timeout = timeout
#         self.worker_init_fn = worker_init_fn
#         self.prefetch_factor = prefetch_factor
#         self.persistent_workers = persistent_workers
        
#         # Set random seed for reproducibility
#         self.seed = seed if seed is not None else torch.initial_seed()
        
#         # Calculate the number of samples and batches
#         self.dataset_len = len(dataset)
#         if self.drop_last:
#             self.num_batches = self.dataset_len // self.batch_size
#         else:
#             self.num_batches = math.ceil(self.dataset_len / self.batch_size)
        
#         # Set up multiprocessing if using workers (simplified - would need actual implementation)
#         self._setup_multiprocessing()
    
#     def _setup_multiprocessing(self):
#         """
#         Set up multiprocessing for data loading.
        
#         Note: This is a simplified placeholder. A real implementation would need to use 
#         Python's multiprocessing library to spawn worker processes.
#         """
#         # In a full implementation, this would set up worker processes using Python's multiprocessing
#         self.workers = None
#         if self.num_workers > 0:
#             print(f"Note: Using {self.num_workers} workers would require implementing multiprocessing")
#             # Setup worker processes, queues, etc.
    
    # def __iter__(self) -> Iterator[List[T_co]]:
    #     """
    #     Create a new iterator over the dataset.
        
    #     Returns:
    #         Iterator over batches of data
    #     """
    #     # Create index sampler
    #     if self.shuffle:
    #         # Deterministically shuffle based on the seed alone (same permutation on every pass)
    #         generator = torch.Generator()
    #         generator.manual_seed(self.seed)
    #         indices = torch.randperm(self.dataset_len, generator=generator).tolist()
#         else:
#             indices = list(range(self.dataset_len))
        
#         # Create batches
#         batches = []
#         for i in range(0, self.dataset_len, self.batch_size):
#             batch_indices = indices[i:i + self.batch_size]
#             if len(batch_indices) < self.batch_size and self.drop_last:
#                 continue
#             batches.append(batch_indices)
        
#         # For simplicity, this implementation loads data in the main process
#         # A full implementation would distribute this to worker processes
#         return _DataLoaderIter(self, batches)
    
#     def __len__(self) -> int:
#         """
#         Return the number of batches in the dataset.
        
#         Returns:
#             Number of batches
#         """
#         return self.num_batches

# class _DataLoaderIter:
#     """
#     Iterator class for the DataLoader.
#     """
    
#     def __init__(self, loader: CustomDataLoader, batches: List[List[int]]):
#         """
#         Initialize the iterator.
        
#         Args:
#             loader: The DataLoader instance
#             batches: List of batch indices
#         """
#         self.loader = loader
#         self.batches = batches
#         self.current_batch = 0
    
#     def __iter__(self) -> '_DataLoaderIter':
#         return self
    
#     def __next__(self) -> List[T_co]:
#         if self.current_batch >= len(self.batches):
#             raise StopIteration
        
#         # Get current batch indices
#         batch_indices = self.batches[self.current_batch]
#         self.current_batch += 1
        
#         # Load data for this batch
#         batch_data = [self.loader.dataset[idx] for idx in batch_indices]
        
#         # Apply collation function
#         batch = self.loader.collate_fn(batch_data)
        
#         # Pin memory if requested
#         if self.loader.pin_memory:
#             batch = _pin_memory(batch)
        
#         return batch
    
#     def __len__(self) -> int:
#         return len(self.batches)

# def _pin_memory(data):
#     """
#     Pin memory for faster data transfer to GPU.
    
#     Args:
#         data: Data to pin

#     Returns:
#         Pinned data
#     """
#     if isinstance(data, torch.Tensor):
#         return data.pin_memory()
#     elif isinstance(data, Mapping):
#         return {k: _pin_memory(v) for k, v in data.items()}
#     elif isinstance(data, (tuple, list)):
#         return [_pin_memory(x) for x in data]
#     else:
#         return data

# def default_collate(batch):
#     """
#     Default collation function that converts a list of samples to a batch.
    
#     Args:
#         batch: List of samples

#     Returns:
#         Collated batch
#     """
#     elem = batch[0]
#     if isinstance(elem, torch.Tensor):
#         return torch.stack(batch, 0)
#     elif isinstance(elem, (str, bytes)):
#         return batch
#     elif isinstance(elem, Mapping):
#         return {key: default_collate([d[key] for d in batch]) for key in elem}
#     elif isinstance(elem, ABCSequence) and not isinstance(elem, (str, bytes)):
#         transposed = zip(*batch)
#         return [default_collate(samples) for samples in transposed]
#     else:
#         return batch

# # Example usage:
# if __name__ == "__main__":
#     # Create a simple tensor dataset
#     class TensorDataset(torch.utils.data.Dataset):
#         def __init__(self, data):
#             self.data = data
        
#         def __getitem__(self, index):
#             return self.data[index]
        
#         def __len__(self):
#             return len(self.data)
    
#     # Create a dataset with 100 samples
#     data = torch.randn(100, 5)
#     dataset = TensorDataset(data)
    
#     # Create a DataLoader with batch_size=16 and shuffle=True
#     dataloader = CustomDataLoader(dataset, batch_size=16, shuffle=True)
    
#     # Iterate through the DataLoader
#     for batch in dataloader:
#         print(f"Batch shape: {batch.shape}")