In [20]:
import pandas as pd
import torch
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset
import pandas as pd
import numpy as np

class DataPreprocessor:
    """Integer-encode categorical columns and scale numerical columns of DataFrame chunks.

    Built for streaming: each categorical column's code table and the scaler are
    fitted on the first chunk passed to :meth:`process` and reused for every later
    chunk. A category therefore keeps the same integer code across chunks, and all
    chunks share one set of scaling statistics. Categories that first appear in a
    later chunk are appended with the next free code instead of triggering a refit.
    :meth:`reset` discards the fitted state so the next chunk refits it.
    """

    def __init__(self, categorical_features, numerical_features, label_encoder=None, scaler=None):
        """
        :param categorical_features: iterable of column names to integer-encode
        :param numerical_features: iterable of column names to scale
        :param label_encoder: object whose ``fit(values)`` returns something exposing
            ``classes_``. That class order seeds each column's codes on its first chunk.
            Defaults to ``LabelEncoder()``.
        :param scaler: object with ``fit(X)`` / ``transform(X)``, applied to the numerical
            columns. Defaults to ``StandardScaler()``.
        """
        self.categorical_features = list(categorical_features)
        self.numerical_features = list(numerical_features)
        self.label_encoder = label_encoder if label_encoder is not None else LabelEncoder()
        self.scaler = scaler if scaler is not None else StandardScaler()
        # column name -> pd.Index of known categories; a category's position is its code
        self._categories = {}
        self._scaler_fitted = False

    def _encode_column(self, col, values):
        """Return integer codes for one categorical column, growing its category table."""
        known = self._categories.get(col)
        if known is None:
            # First chunk for this column: use the encoder's class order (sorted for
            # sklearn's LabelEncoder), so codes match what fit_transform would give.
            known = pd.Index(self.label_encoder.fit(values).classes_)
        # Append never-seen categories instead of refitting, so codes already
        # handed out to earlier chunks stay valid.
        unseen = pd.Index(pd.unique(values)).difference(known, sort=False)
        if len(unseen):
            known = known.append(unseen)
        self._categories[col] = known
        return known.get_indexer(values)

    def process(self, df):
        """
        Preprocess one data chunk and return the result as a new DataFrame.

        :param df: chunk containing every configured categorical and numerical column
        :return: a copy of ``df`` in which categorical columns are replaced by integer
            codes and numerical columns by the scaler's output; other columns unchanged
        """
        df = df.copy()  # never mutate the caller's frame

        # Categorical features
        for col in self.categorical_features:
            df[col] = self._encode_column(col, df[col])

        # Numerical features: fit the scaler once, then only transform.
        # Skipped when no numerical columns are configured (avoids fitting on zero columns).
        if self.numerical_features:
            numeric = df[self.numerical_features]
            if not self._scaler_fitted:
                self.scaler.fit(numeric)
                self._scaler_fitted = True
            df[self.numerical_features] = self.scaler.transform(numeric)

        return df

    def reset(self):
        """
        Discard all fitted state so the next :meth:`process` call refits from its chunk
        (e.g. at the start of a new training epoch). The configured encoder and scaler
        objects are kept and refitted rather than replaced by defaults.
        """
        self._categories = {}
        self._scaler_fitted = False


In [24]:
class StreamingDataset(Dataset):
    """Stream mini-batches from a CSV file while holding only one chunk in memory.

    The file is read ``chunk_size`` rows at a time through
    ``pd.read_csv(..., chunksize=...)``. Every chunk, not just the first, is
    shuffled within itself (when ``shuffle`` is set) and passed through
    ``preprocessor.process`` before rows are drawn from it. A batch may span
    chunk boundaries. The ``'target'`` column supplies the label (cast with
    ``int``); every other column is used as a feature.
    """

    def __init__(self, file_path, batch_size, chunk_size=1000, shuffle=False, preprocessor=None,
                 drop_last=False):
        """
        :param file_path: path to the CSV data file
        :param batch_size: number of rows per batch; must be at least 1
        :param chunk_size: rows read from the CSV at a time (avoids loading the whole file)
        :param shuffle: shuffle rows within each chunk as it is loaded
        :param preprocessor: optional object whose ``process(df)`` returns the transformed chunk
        :param drop_last: if True, discard a final batch shorter than ``batch_size``
            instead of returning it
        :raises ValueError: if ``batch_size`` is less than 1
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")
        self.file_path = file_path
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.shuffle = shuffle
        self.preprocessor = preprocessor
        self.drop_last = drop_last
        self.batch_idx = 0  # position of the next unread row inside current_chunk
        self.current_chunk = None
        self.chunk_iterator = None
        self._load_chunk()

    def _prepare_chunk(self, chunk):
        """Shuffle and preprocess a freshly read chunk; ``None`` (end of file) passes through."""
        if chunk is None:
            return None
        if self.shuffle:
            chunk = chunk.sample(frac=1).reset_index(drop=True)
        if self.preprocessor is not None:
            chunk = self.preprocessor.process(chunk)
        return chunk

    def _load_chunk(self):
        """(Re)open the CSV reader and load its first prepared chunk into ``current_chunk``."""
        if self.chunk_iterator is not None:
            # Release the previous reader's file handle before opening a new one.
            self.chunk_iterator.close()
        self.chunk_iterator = pd.read_csv(self.file_path, chunksize=self.chunk_size)
        self.current_chunk = self._prepare_chunk(next(self.chunk_iterator, None))

    def __len__(self):
        """Number of rows in the currently loaded chunk (0 when none is loaded), not in the whole file."""
        if self.current_chunk is None:
            return 0
        return len(self.current_chunk)

    def __getitem__(self, idx):
        """
        Return ``(features, label)`` for row ``idx`` of the current chunk.

        :return: ``features`` as a 1-D float32 tensor of every non-``'target'`` column,
            and ``label`` as a 0-D int64 tensor
        """
        row = self.current_chunk.iloc[idx]
        # to_numpy(float32) also handles object-dtype rows produced by mixed-type frames
        features = torch.tensor(row.drop('target').to_numpy(dtype=np.float32), dtype=torch.float32)
        label = torch.tensor(int(row['target']), dtype=torch.long)
        return features, label

    def next_batch(self):
        """
        Return the next ``(features, labels)`` batch, loading further chunks as needed.

        :return: ``features`` of shape ``(n, num_features)`` and ``labels`` of shape ``(n,)``,
            where ``n == batch_size`` except possibly for the final batch (see ``drop_last``)
        :raises StopIteration: when the file is exhausted (and on every later call)
        """
        batch_data = []
        batch_labels = []

        while len(batch_data) < self.batch_size and self.current_chunk is not None:
            if self.batch_idx >= len(self.current_chunk):
                # Current chunk used up: move to the next one (None at end of file).
                self.next_chunk()
                self.batch_idx = 0
                continue
            feature, label = self[self.batch_idx]
            batch_data.append(feature)
            batch_labels.append(label)
            self.batch_idx += 1

        if not batch_data or (self.drop_last and len(batch_data) < self.batch_size):
            raise StopIteration

        return torch.stack(batch_data), torch.stack(batch_labels)

    def next_chunk(self):
        """Load, shuffle and preprocess the next chunk; ``current_chunk`` becomes None at end of file."""
        self.current_chunk = self._prepare_chunk(next(self.chunk_iterator, None))

    def reset(self):
        """Rewind to the first row of the file (re-reading and re-preprocessing its first chunk)."""
        self.batch_idx = 0
        self._load_chunk()


In [25]:
class StreamingDataLoader:
    """Minimal iterator adapter that yields whatever ``dataset.next_batch()`` returns.

    Iteration ends when the dataset raises ``StopIteration``. The loader keeps no
    position of its own: iterating again continues from the dataset's current
    state, and :meth:`reset` simply delegates to ``dataset.reset()``.
    """

    def __init__(self, dataset):
        """:param dataset: object exposing ``next_batch()`` and ``reset()``."""
        self.dataset = dataset

    def __iter__(self):
        """The loader is its own iterator."""
        return self

    def __next__(self):
        """Fetch the next batch from the wrapped dataset."""
        source = self.dataset
        return source.next_batch()

    def reset(self):
        """Ask the wrapped dataset to start reading from the beginning again."""
        self.dataset.reset()


In [27]:
# Demo configuration (the file path and column names are placeholders for the real dataset)
DATA_PATH = 'large_dataset.csv'
BATCH_SIZE = 32
NUM_EPOCHS = 2     # simulate 2 training epochs
LOG_EVERY = 10     # print a shape summary every LOG_EVERY batches
MAX_BATCHES = 100  # stop each epoch once this batch index is reached (demo only)

# Data preprocessor instance
categorical_features = ['feature_1', 'feature_2']  # assumed categorical feature columns
numerical_features = ['feature_3', 'feature_4']  # numerical feature columns
preprocessor = DataPreprocessor(categorical_features, numerical_features)

# Streaming dataset + loader. shuffle is not enabled, so rows are read in file order.
dataset = StreamingDataset(file_path=DATA_PATH, batch_size=BATCH_SIZE, preprocessor=preprocessor)
streaming_loader = StreamingDataLoader(dataset)

# Simulated training loop
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}")
    # Re-open the CSV stream at the start of each epoch. This does NOT shuffle:
    # the dataset above was constructed without shuffle=True.
    streaming_loader.reset()

    for batch_idx, (features, labels) in enumerate(streaming_loader):
        # Stand-in for a training step: periodically report batch shapes
        if batch_idx % LOG_EVERY == 0:
            print(f"Batch {batch_idx}, Feature Shape: {features.shape}, Label Shape: {labels.shape}")

        if batch_idx == MAX_BATCHES:
            break


Epoch 1
Batch 0, Feature Shape: torch.Size([32, 10]), Label Shape: torch.Size([32])
Batch 10, Feature Shape: torch.Size([32, 10]), Label Shape: torch.Size([32])
Batch 20, Feature Shape: torch.Size([32, 10]), Label Shape: torch.Size([32])
Batch 30, Feature Shape: torch.Size([32, 10]), Label Shape: torch.Size([32])
Batch 40, Feature Shape: torch.Size([32, 10]), Label Shape: torch.Size([32])
Batch 50, Feature Shape: torch.Size([32, 10]), Label Shape: torch.Size([32])
Batch 60, Feature Shape: torch.Size([32, 10]), Label Shape: torch.Size([32])
Batch 70, Feature Shape: torch.Size([32, 10]), Label Shape: torch.Size([32])
Batch 80, Feature Shape: torch.Size([32, 10]), Label Shape: torch.Size([32])
Batch 90, Feature Shape: torch.Size([32, 10]), Label Shape: torch.Size([32])
Batch 100, Feature Shape: torch.Size([32, 10]), Label Shape: torch.Size([32])
Epoch 2
Batch 0, Feature Shape: torch.Size([32, 10]), Label Shape: torch.Size([32])
Batch 10, Feature Shape: torch.Size([32, 10]), Label Shape: t