# Data organization

In [None]:
import os
import shutil
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
import timm
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

## 1. 下载数据

从 https://www.kaggle.com/datasets/cihan063/autism-image-data/data 直接下载 zip 文件，  
或者取消下面单元格中的注释并运行（需要先安装 `kagglehub`）：

In [None]:
# import kagglehub

# # Download latest version
# path = kagglehub.dataset_download("cihan063/autism-image-data")

# print("Path to dataset files:", path)

原始数据目录组织为：
```
AutismDataset/consolidated/  
    ├── Autistic/  
    │   0001.jpg  
    │   0002.jpg  
    │   ...  
    └── Non_Autistic/  
        0001.jpg  
        0002.jpg  
        ...  
```

## 2. 将图像拆分为 train/val/test 集


In [None]:
# Source folder: one sub-folder of images per class.
data_dir = '/PathToYourData/AutismDataset/consolidated'
class_names = ['Autistic', 'Non_Autistic']

# Destination root that will hold the train/val/test copies.
organized_dir = '/PathToYourData/AutismDataset/split_data'
os.makedirs(organized_dir, exist_ok=True)

# One directory per split, directly under the destination root.
train_dir, val_dir, test_dir = (
    os.path.join(organized_dir, split_name)
    for split_name in ('train', 'val', 'test')
)

# Create <split>/<class>/ for every combination; exist_ok makes re-running
# this cell harmless.
for split_dir in (train_dir, val_dir, test_dir):
    for class_name in class_names:
        class_subdir = os.path.join(split_dir, class_name)
        os.makedirs(class_subdir, exist_ok=True)

def organize_class_images(src_class_dir, dest_train_dir, dest_val_dir, dest_test_dir, test_size=0.2, val_size=0.2):
    """Split one class's images into train/val/test and copy them.

    Entries directly inside ``src_class_dir`` whose names end in ``.jpg``,
    ``.jpeg`` or ``.png`` (case-insensitive) are split in two steps:
    ``test_size`` is held out first, then ``val_size`` is taken from what
    remains. ``val_size`` is therefore relative to the non-test files, so the
    default fractions give roughly a 64% / 16% / 20% train/val/test split.

    Args:
        src_class_dir: Directory containing the images of a single class.
        dest_train_dir: Directory receiving the training images.
        dest_val_dir: Directory receiving the validation images.
        dest_test_dir: Directory receiving the test images.
        test_size: Passed to ``train_test_split`` as ``test_size`` for the
            first split (all files -> train+val / test).
        val_size: Passed to ``train_test_split`` as ``test_size`` for the
            second split (train+val -> train / val).

    The destination directories are not created here. Files are copied, so
    the source directory is left untouched.
    """
    # os.listdir returns entries in arbitrary order; sort them so that the
    # fixed random_state below produces the same split on every machine.
    image_files = sorted(
        f for f in os.listdir(src_class_dir)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )

    train_files, test_files = train_test_split(image_files, test_size=test_size, random_state=42)
    train_files, val_files = train_test_split(train_files, test_size=val_size, random_state=42)

    # Copy each subset into its destination, keeping the original file names.
    for files, dest_dir in (
        (train_files, dest_train_dir),
        (val_files, dest_val_dir),
        (test_files, dest_test_dir),
    ):
        for file in files:
            shutil.copy(os.path.join(src_class_dir, file), os.path.join(dest_dir, file))

# Run the split-and-copy step once per class, pairing each class's source
# folder with its matching sub-folder in every split.
for class_name in class_names:
    src_class_dir = os.path.join(data_dir, class_name)
    dest_train_dir, dest_val_dir, dest_test_dir = (
        os.path.join(split_root, class_name)
        for split_root in (train_dir, val_dir, test_dir)
    )

    organize_class_images(
        src_class_dir,
        dest_train_dir=dest_train_dir,
        dest_val_dir=dest_val_dir,
        dest_test_dir=dest_test_dir,
    )