# 数据加载模块 Data Module

<!-- > 我们支持多种数据集（如 CIFAR100、VTAB 等），将 数据加载 的逻辑拆分到单独的笔记本中，这样可以保持模块化，数据处理和训练逻辑分离，增强可维护性和扩展性。 -->
> 处理数据加载和预处理的模块
> 
> Handles data loading and preprocessing for datasets

## 简介/Description:
数据模块主要负责数据集的加载与预处理。ClassificationDataConfig 使用 Pydantic 进行配置管理，以保证数据集参数的正确性，并通过 ClassificationDataModule 实现 PyTorch Lightning 的数据模块封装。此模块支持自定义数据转换，并为不同的数据集（如 CIFAR100、CIFAR10、MNIST）提供灵活的加载方案。

The data module handles dataset loading and preprocessing. ClassificationDataConfig is a Pydantic model that validates the dataset parameters, and ClassificationDataModule wraps data loading in a PyTorch Lightning DataModule. The module supports custom data transforms and offers flexible loading for datasets such as CIFAR100, CIFAR10 and MNIST.

## 主要符号/Main symbols:

- ClassificationDataConfig: Pydantic 定义的配置类，用于数据模块的参数管理。
  
  ClassificationDataConfig: A Pydantic configuration class for managing data module parameters.

- ClassificationDataModule: 用于 PyTorch Lightning 的数据模块封装，支持训练、验证、测试数据加载。
  
  ClassificationDataModule: A PyTorch Lightning DataModule wrapper supporting train, validation, and test data loading.
  
- CIFAR100DataModule: 基于 TorchVisionDataModule（ClassificationDataModule 的子类）的具体数据模块实现，加载 CIFAR100 数据集。
  
  CIFAR100DataModule: A concrete data module built on TorchVisionDataModule (a ClassificationDataModule subclass) that loads the CIFAR100 dataset.

In [1]:
#| default_exp data.__init__

In [2]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

In [3]:
#| export
from namable_classify.utils import data_path, Path

In [4]:
#| export
from pydantic import BaseModel

class ClassificationDataConfig(BaseModel):
    """Validated settings shared by every classification data module.

    ``ClassificationDataModule.from_config`` dumps these fields and passes them
    as keyword arguments to the data module's ``__init__``.
    """

    dataset_root: Path = data_path  # root directory handed to the dataset constructors
    dataset_name: str = 'CIFAR100'  # selects the concrete data module by name
    batch_size: int = 1  # samples per batch for every dataloader
# TODO 支持多个来源的数据集自动加载
# from torchvision.datasets import __all__, CIFAR100
# __all__

In [5]:
#| export
from typing import Callable
from torch.utils.data import random_split, DataLoader
import lightning as L
from torchvision import transforms
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS


class ClassificationDataModule(L.LightningDataModule):
    """Base Lightning data module for image classification.

    Concrete subclasses set ``num_of_classes`` and create ``train_ds``,
    ``val_ds``, ``test_ds`` and ``predict_ds`` in ``setup``; the dataloaders
    below only wrap whichever of those datasets already exists.
    """

    # Expected number of classes; overridden by subclasses and copied into hparams.
    num_of_classes = None
    # Class names, filled in by subclasses once a dataset has been loaded.
    classes = None
    # Upper bound on DataLoader worker processes (the value previously hard-coded).
    max_workers = 31

    @classmethod
    def from_config(cls, config:ClassificationDataConfig) -> 'ClassificationDataModule':
        """Build a data module whose ``__init__`` kwargs are the fields of ``config``."""
        return cls(**config.model_dump())

    def __init__(self, **config) -> None:
        """Record every keyword argument (the config fields) as hyperparameters.

        ``num_of_classes`` is saved as well so it shows up in ``self.hparams``.
        """
        super().__init__()
        self.save_hyperparameters()
        self.save_hyperparameters(dict(num_of_classes=self.num_of_classes))
        # Match the worker pool to the CPUs this process may use instead of
        # always spawning 31 processes (far too many on small machines).
        self.workers = min(self.max_workers, self._available_cpus())

    @staticmethod
    def _available_cpus() -> int:
        """Number of CPUs usable by this process (honours the affinity mask where supported)."""
        import os
        try:
            return len(os.sched_getaffinity(0))
        except AttributeError:  # sched_getaffinity is not available on every platform
            return os.cpu_count() or 1

    def _loader(self, dataset, shuffle:bool=False) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader with the shared batch size / worker settings."""
        return DataLoader(dataset, batch_size=self.hparams.batch_size,
                          num_workers=self.workers, shuffle=shuffle, pin_memory=True)

    def train_dataloader(self)->TRAIN_DATALOADERS:
        """Shuffled loader over ``self.train_ds``."""
        return self._loader(self.train_ds, shuffle=True)

    def val_dataloader(self)->EVAL_DATALOADERS:
        """Loader over ``self.val_ds`` in dataset order."""
        return self._loader(self.val_ds)

    def test_dataloader(self)->EVAL_DATALOADERS:
        """Loader over ``self.test_ds`` in dataset order."""
        return self._loader(self.test_ds)

    def predict_dataloader(self)->EVAL_DATALOADERS:
        """Loader over ``self.predict_ds`` in dataset order."""
        return self._loader(self.predict_ds)

In [6]:
# Smoke test: build the base data module from the default config and inspect
# the hyperparameters it recorded (num_of_classes is None on the base class).
import lightning as L
L.seed_everything(42)
cdm = ClassificationDataModule.from_config(ClassificationDataConfig())
cdm.hparams

Seed set to 42


"batch_size":     1
"dataset_name":   CIFAR100
"dataset_root":   /home/ycm/repos/research/cv/cls/NamableClassify/data
"num_of_classes": None

In [14]:
#| export
from torch.utils.data import random_split
from sklearn.model_selection import train_test_split
import torch
from torchvision.datasets import CIFAR100

def sksplit_for_torch(ds_full:torch.utils.data.Dataset, test_size:float=0.2, stratify_targets=None, random_state=None):
    """Split a map-style dataset into stratified train / held-out ``Subset``s.

    The split is done by scikit-learn's ``train_test_split`` over the index
    range, stratified on the per-sample labels.

    Args:
        ds_full: dataset supporting ``len()`` and integer indexing.
        test_size: size of the held-out part, forwarded to ``train_test_split``
            (a float is a fraction of the dataset).
        stratify_targets: per-sample labels to stratify on. When ``None``, uses
            ``ds_full.targets`` if the dataset has it, otherwise reads item ``[1]``
            of every sample (which touches every sample once).
        random_state: seed forwarded to ``train_test_split``.

    Returns:
        ``(train_subset, held_out_subset)`` as ``torch.utils.data.Subset`` objects.
    """
    indexes = list(range(len(ds_full)))
    # Explicit None check: `stratify_targets or ...` raises "ambiguous truth
    # value" when a numpy array or tensor of labels is passed in.
    if stratify_targets is None:
        if hasattr(ds_full, 'targets'):
            stratify_targets = ds_full.targets
        else:
            stratify_targets = [int(ds_full[i][1]) for i in indexes]
    train_indexes, val_indexes = train_test_split(
        indexes, test_size=test_size, stratify=stratify_targets, random_state=random_state,
    )
    return torch.utils.data.Subset(ds_full, train_indexes), torch.utils.data.Subset(ds_full, val_indexes)

In [16]:
# Sanity-check the stratified split on the raw CIFAR100 train set (no transform):
# with test_size=0.1 the displayed sizes are 45000 / 5000.
ds_full = CIFAR100(data_path, train=True)
train_ds, val_ds = sksplit_for_torch(ds_full, test_size=0.1)
len(train_ds), len(val_ds)

(45000, 5000)

In [17]:
#| export
import lightning as L

from torchvision.datasets import MNIST
from torchvision import transforms
import torch
from torchvision.datasets import CIFAR100, CIFAR10
# CIFAR100.url = # Tsinghua mirrorURL
class TorchVisionDataModule(ClassificationDataModule):
    torchvision_cls = CIFAR100
    num_of_classes = 100
    def __init__(self, 
                 train_transform=None, # 需要后续设置
                 test_transform=None, # 需要后续设置
                 train_val_split=0.9, # 训练集和验证集的比例
                 **config:ClassificationDataConfig) -> None:
        super().__init__(**config)
        self.save_hyperparameters()
        

    def prepare_data(self):
        # download
        train_ds = self.torchvision_cls(self.hparams.dataset_root, train=True, download=True)
        test_ds = self.torchvision_cls(self.hparams.dataset_root, train=False, download=True)
        self.classes = train_ds.classes
        assert len(self.classes)==self.num_of_classes, f"Number of classes in dataset is {len(self.classes)}, but {self.num_of_classes} is expected."

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        match (stage):
            case ("fit"):
                ds_full = self.torchvision_cls(self.hparams.dataset_root, train=True, transform=self.hparams.train_transform)
                self.train_ds, self.val_ds = sksplit_for_torch(ds_full, 1-self.hparams.train_val_split, random_state=0) # 不和Lighning seed everything一起
                # 还有个 validate 但是fit的时候我们就设置好了，所以直接跳过
            case ("validate"):
                print("Validation loader has been setup before. ")
                pass
            case ("test"):
                self.test_ds = self.torchvision_cls(self.hparams.dataset_root, train=False, transform=self.hparams.test_transform)
            case ("predict"):
                self.predict_ds = self.torchvision_cls(self.hparams.dataset_root, train=False, transform=self.hparams.test_transform)

class CIFAR100DataModule(TorchVisionDataModule):
    """CIFAR100 with 100 classes (identical to the TorchVisionDataModule defaults)."""
    num_of_classes = 100
    torchvision_cls = CIFAR100


class MNISTDataModule(TorchVisionDataModule):
    """torchvision MNIST, expected to have 10 classes."""
    num_of_classes = 10
    torchvision_cls = MNIST


class CIFAR10DataModule(TorchVisionDataModule):
    """torchvision CIFAR10, expected to have 10 classes."""
    num_of_classes = 10
    torchvision_cls = CIFAR10

In [18]:
# Build the MNIST module from a config, download both splits (prepare_data)
# and inspect the recorded hyperparameters.
mnist_data = MNISTDataModule.from_config(ClassificationDataConfig(dataset_name='MNIST'))
mnist_data.prepare_data()
mnist_data.hparams

"batch_size":      1
"dataset_name":    MNIST
"dataset_root":    /home/ycm/repos/research/cv/cls/NamableClassify/data
"num_of_classes":  10
"test_transform":  None
"train_transform": None
"train_val_split": 0.9

In [10]:
# Same as above for CIFAR100: build, download/verify files, inspect hparams.
cifar100_data = CIFAR100DataModule.from_config(ClassificationDataConfig(dataset_name="CIFAR100"))
cifar100_data.prepare_data()
cifar100_data.hparams

Files already downloaded and verified
Files already downloaded and verified


"batch_size":      1
"dataset_name":    CIFAR100
"dataset_root":    /home/ycm/repos/research/cv/cls/NamableClassify/data
"num_of_classes":  100
"test_transform":  None
"train_transform": None
"train_val_split": 0.9

In [11]:
# Show which torchvision source file defines CIFAR100 (presumably to look up
# its download URL for the mirror TODO next to TorchVisionDataModule).
import inspect
inspect.getfile(CIFAR100)

'/home/ycm/program_files/managers/conda/envs/hf_ai/lib/python3.10/site-packages/torchvision/datasets/cifar.py'

In [12]:
# TODO VTAB dataset
# vtab_dir: str = "/home/ai_pitch_perfector/datasets/vtab-1k/"
# subset_name: str = "cifar"

In [13]:
#| export
from transformers import AutoImageProcessor, BitImageProcessor, ViTImageProcessor
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
    RandomRotation,
    RandomGrayscale,
    Grayscale,
    AutoAugment,
    RandAugment,
)
from namable_classify.data.transforms import CutoutPIL
from fastcore.basics import patch
def _hf_height_width(dims):
    """Return ``(dims['height'], dims['width'])``, or ``None`` if either is missing/unset."""
    if dims is None:
        return None
    try:
        height, width = dims['height'], dims['width']
    except (KeyError, TypeError, AttributeError):
        return None
    if height is None or width is None:
        return None
    return (height, width)


def _infer_model_image_size(hf_image_preprocessor) -> tuple[int, int]:
    """Read the model's expected ``(height, width)`` from a Hugging Face image processor.

    ViT processors are read from ``size`` and BiT processors from ``crop_size``.
    Any other processor (e.g. a ``*Fast`` variant) uses whichever of
    ``crop_size`` / ``size`` has explicit ``height`` and ``width``.

    Raises:
        ValueError: if no ``(height, width)`` can be found.
    """
    if isinstance(hf_image_preprocessor, ViTImageProcessor):
        candidates = ('size',)
    elif isinstance(hf_image_preprocessor, BitImageProcessor):
        candidates = ('crop_size',)
    else:
        candidates = ('crop_size', 'size')
    for attr in candidates:
        dims = _hf_height_width(getattr(hf_image_preprocessor, attr, None))
        if dims is not None:
            return dims
    raise ValueError(
        f"Cannot infer the model image size from {type(hf_image_preprocessor).__name__}; "
        "pass model_image_size explicitly."
    )


@patch
def set_transform_from_hf_image_preprocessor(self:ClassificationDataModule, hf_image_preprocessor:AutoImageProcessor, model_image_size=None):
    """Build train/test transforms that match a Hugging Face image processor.

    The results are stored in ``self.hparams.train_transform`` and
    ``self.hparams.test_transform``.

    Args:
        hf_image_preprocessor: processor whose ``image_mean`` / ``image_std`` are
            used for normalisation and whose size settings give the default
            input size.
        model_image_size: explicit size for ``Resize``; inferred from the
            processor when ``None``.

    Raises:
        ValueError: if ``model_image_size`` is ``None`` and cannot be inferred.
    """
    if model_image_size is None:
        model_image_size = _infer_model_image_size(hf_image_preprocessor)
    normalize = Normalize(mean=hf_image_preprocessor.image_mean, std=hf_image_preprocessor.image_std)
    self.hparams.train_transform = Compose(
        [
            Resize(model_image_size),
            # Random cutout. The original note says factor 1/4 blanks an 8x8 patch
            # on a 32x32 CIFAR image, i.e. presumably 1/4 of the side. Since it
            # runs after Resize, that is relative to model_image_size here
            # (confirm against CutoutPIL).
            CutoutPIL(cutout_factor=1/4),
            RandAugment(),
            ToTensor(),
            normalize,
        ]
    )
    # Evaluation: deterministic resize + the same normalisation, no augmentation.
    self.hparams.test_transform = Compose(
        [
            Resize(model_image_size),
            ToTensor(),
            normalize,
        ]
    )

In [14]:
# Load the image processor of a ViT checkpoint so transforms can be derived from it.
# Only the cell's last expression is displayed: the (height, width) crop size.
from transformers import AutoImageProcessor, BitImageProcessor
model_checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = BitImageProcessor.from_pretrained(model_checkpoint, use_fast=True)
image_processor  # not the last expression, so not displayed
image_processor.crop_size.values() # height, width
# image_processor.size



dict_values([224, 224])

In [15]:
#| export
from fastcore.basics import patch
@patch
def get_lightning_data_module(self:ClassificationDataConfig):
    """Instantiate the data module registered for ``self.dataset_name``.

    Supported names: ``'MNIST'``, ``'CIFAR10'`` and ``'CIFAR100'``.

    Raises:
        ValueError: for any other dataset name.
    """
    registry = {
        'MNIST': MNISTDataModule,
        'CIFAR10': CIFAR10DataModule,  # previously defined but unreachable from here
        'CIFAR100': CIFAR100DataModule,
    }
    try:
        module_cls = registry[self.dataset_name]
    except KeyError:
        raise ValueError(f"Unsupported dataset: {self.dataset_name}") from None
    return module_cls.from_config(self)

In [16]:
# Select the concrete data module by name through the config helper.
lit_data = ClassificationDataConfig(dataset_name="CIFAR100").get_lightning_data_module()
lit_data.hparams

"batch_size":      1
"dataset_name":    CIFAR100
"dataset_root":    /home/ycm/repos/research/cv/cls/NamableClassify/data
"num_of_classes":  100
"test_transform":  None
"train_transform": None
"train_val_split": 0.9

In [17]:
# Attach train/test transforms derived from the HF processor, then confirm
# they were stored in hparams.
lit_data.set_transform_from_hf_image_preprocessor(image_processor)
lit_data.hparams

"batch_size":      1
"dataset_name":    CIFAR100
"dataset_root":    /home/ycm/repos/research/cv/cls/NamableClassify/data
"num_of_classes":  100
"test_transform":  Compose(
    Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
    ToTensor()
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
)
"train_transform": Compose(
    Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
    <namable_classify.data.transforms.CutoutPIL object at 0x7d0ebf3047f0>
    RandAugment(num_ops=2, magnitude=9, num_magnitude_bins=31, interpolation=InterpolationMode.NEAREST, fill=None)
    ToTensor()
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
)
"train_val_split": 0.9

In [18]:
# TODO 
# https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/11-vision-transformer.html
# 这里的可视化不错。

In [1]:
#| hide
# Regenerate the library module(s) from this notebook's `#| export` cells.
import nbdev; nbdev.nbdev_export()