In [None]:
import os
import shutil
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
import timm
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

# data

In [None]:
# 下载数据，或 https://www.kaggle.com/datasets/cihan063/autism-image-data/data 直接下载 zip 文件
# import kagglehub

# # Download latest version
# path = kagglehub.dataset_download("cihan063/autism-image-data")

# print("Path to dataset files:", path)

In [None]:
data_dir = '/PathToYourData/AutismDataset/consolidated'
class_names = ['Autistic', 'Non_Autistic']

organized_dir = '/PathToYourData/AutismDataset/split_data'
os.makedirs(organized_dir, exist_ok=True)

train_dir = os.path.join(organized_dir, 'train')
val_dir = os.path.join(organized_dir, 'val')
test_dir = os.path.join(organized_dir, 'test')

for split_dir in [train_dir, val_dir, test_dir]:
    for class_name in class_names:
        os.makedirs(os.path.join(split_dir, class_name), exist_ok=True)

def organize_class_images(src_class_dir, dest_train_dir, dest_val_dir, dest_test_dir, test_size=0.2, val_size=0.2):
    image_files = [f for f in os.listdir(src_class_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

    train_files, test_files = train_test_split(image_files, test_size=test_size, random_state=42)
    train_files, val_files = train_test_split(train_files, test_size=val_size, random_state=42)
    
    for file in train_files:
        shutil.copy(os.path.join(src_class_dir, file), os.path.join(dest_train_dir, file))
    for file in val_files:
        shutil.copy(os.path.join(src_class_dir, file), os.path.join(dest_val_dir, file))
    for file in test_files:
        shutil.copy(os.path.join(src_class_dir, file), os.path.join(dest_test_dir, file))

for class_name in class_names:
    src_class_dir = os.path.join(data_dir, class_name)
    dest_train_dir = os.path.join(train_dir, class_name)
    dest_val_dir = os.path.join(val_dir, class_name)
    dest_test_dir = os.path.join(test_dir, class_name)
    
    organize_class_images(src_class_dir, dest_train_dir, dest_val_dir, dest_test_dir)

In [None]:
device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')


data_dir = '/data7/cyd/files/data/AutismDataset/split_data'

vit_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

train_dataset = ImageFolder(os.path.join(data_dir, 'train'), transform=vit_transform)
val_dataset = ImageFolder(os.path.join(data_dir, 'val'), transform=vit_transform)
test_dataset = ImageFolder(os.path.join(data_dir, 'test'), transform=vit_transform)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# model

In [None]:
from safetensors.torch import load_file

class ViTFeatureExtractor(nn.Module):
    def __init__(self, model_name='vit_large_patch16_224', ckpt_path=None):
        super(ViTFeatureExtractor, self).__init__()
        # self.vit = timm.create_model(model_name, pretrained=False, num_classes=0)  # 不自动下载
        self.vit = timm.create_model(model_name, pretrained=True)
        if ckpt_path is not None:
            state_dict = load_file(ckpt_path)  # 用 safetensors 读取
            self.vit.load_state_dict(state_dict, strict=False)
        
        for param in self.vit.parameters():
            param.requires_grad = False     # 冻结所有参数
            
    def forward(self, x):
        return self.vit(x)


model = ViTFeatureExtractor(
    model_name="vit_large_patch16_224",
    # model_name="vit_large_patch16_224.augreg_in21k_ft_in1k",
    # ckpt_path = "/data7/cyd/.cache/huggingface/hub/models--timm--vit_large_patch16_224.augreg_in21k_ft_in1k/snapshots/0930ab3308b84cb2ae091a4a80703c459412a4c7/model.safetensors"
).to(device)
model.eval()

def extract_and_save_features(data_loader, dataset, save_path):
    os.makedirs(save_path, exist_ok=True)
    
    all_features = []
    all_labels = []
    all_filenames = []
    
    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc='提取特征'):
            images = images.to(device)
            features = model(images)
            
            all_features.append(features.cpu().numpy())
            all_labels.append(labels.numpy())
            
            batch_filenames = [dataset.samples[i][0] for i in range(len(labels))]
            all_filenames.extend(batch_filenames)
    
    all_features = np.concatenate(all_features, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    np.save(os.path.join(save_path, 'features.npy'), all_features)
    np.save(os.path.join(save_path, 'labels.npy'), all_labels)
    
    with open(os.path.join(save_path, 'filenames.txt'), 'w') as f:
        for filename in all_filenames:
            f.write(f"{filename}\n")
    
    print(f"特征已保存到 {save_path}")
    print(f"特征形状: {all_features.shape}, 标签形状: {all_labels.shape}")

print("为训练集提取特征...")
extract_and_save_features(train_loader, train_dataset, '/PathToYourData/AutismDataset/vit_large_features/train')

print("\n为验证集提取特征...")
extract_and_save_features(val_loader, val_dataset, '/PathToYourData/AutismDataset/vit_large_features/val')

print("\n为测试集提取特征...")
extract_and_save_features(test_loader, test_dataset, '/PathToYourData/AutismDataset/vit_large_features/test')

print("\n特征提取完成!")

为训练集提取特征...


提取特征: 100%|██████████| 59/59 [01:51<00:00,  1.89s/it]


特征已保存到 /data7/cyd/files/data/AutismDataset/vit_large_features/train
特征形状: (1880, 1024), 标签形状: (1880,)

为验证集提取特征...


提取特征: 100%|██████████| 15/15 [00:03<00:00,  4.26it/s]


特征已保存到 /data7/cyd/files/data/AutismDataset/vit_large_features/val
特征形状: (472, 1024), 标签形状: (472,)

为测试集提取特征...


提取特征: 100%|██████████| 19/19 [00:04<00:00,  4.00it/s]


特征已保存到 /data7/cyd/files/data/AutismDataset/vit_large_features/test
特征形状: (588, 1024), 标签形状: (588,)

特征提取完成!


In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
import os

def load_features(feature_dir):
    features = np.load(os.path.join(feature_dir, 'features.npy'))
    labels = np.load(os.path.join(feature_dir, 'labels.npy'))
    return features, labels

train_features, train_labels = load_features('/data7/cyd/files/data/AutismDataset/vit_large_features/train')
val_features, val_labels = load_features('/data7/cyd/files/data/AutismDataset/vit_large_features/val')
test_features, test_labels = load_features('/data7/cyd/files/data/AutismDataset/vit_large_features/test')

# 合并训练和验证集
X = np.concatenate([train_features, val_features])
y = np.concatenate([train_labels, val_labels])

print(f"训练特征形状: {train_features.shape}, 训练标签形状: {train_labels.shape}")
print(f"测试特征形状: {test_features.shape}, 测试标签形状: {test_labels.shape}")

训练特征形状: (1880, 1024), 训练标签形状: (1880,)
测试特征形状: (588, 1024), 测试标签形状: (588,)


# train

## vit_large

In [None]:
def create_data_loaders(train_features, train_labels, val_features, val_labels, test_features, test_labels, batch_size=64):
    train_dataset = TensorDataset(
        torch.FloatTensor(train_features), 
        torch.LongTensor(train_labels)
    )
    val_dataset = TensorDataset(
        torch.FloatTensor(val_features), 
        torch.LongTensor(val_labels)
    )
    test_dataset = TensorDataset(
        torch.FloatTensor(test_features), 
        torch.LongTensor(test_labels)
    )
    
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True,
        drop_last=True,
        num_workers=4,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, 
        batch_size=batch_size,
        num_workers=2,
        pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, 
        batch_size=batch_size,
        num_workers=2,
        pin_memory=True
    )
    
    return train_loader, val_loader, test_loader

batch_size = 64
train_loader, val_loader, test_loader = create_data_loaders(
    train_features, train_labels,
    val_features, val_labels,
    test_features, test_labels,
    batch_size=batch_size
)

# 分类器
class ViTLargeClassifier(nn.Module):
    def __init__(self, input_dim=1024, num_classes=2):
        super(ViTLargeClassifier, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 768),
            nn.BatchNorm1d(768),
            nn.GELU(), 
            nn.Dropout(0.5),
            
            nn.Linear(768, 384),
            nn.BatchNorm1d(384),
            nn.GELU(),
            nn.Dropout(0.4),
            
            nn.Linear(384, 192),
            nn.BatchNorm1d(192),
            nn.GELU(),
            nn.Dropout(0.3),
            
            nn.Linear(192, num_classes)
        )
        
        # 权重初始化
        self._initialize_weights()
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
    
    def forward(self, x):
        return self.classifier(x)

class Trainer:
    def __init__(self, model, device, optimizer, criterion, scheduler=None):
        self.model = model.to(device)
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion
        self.scheduler = scheduler
        self.best_metric = 0.0
    
    def train_epoch(self, train_loader):
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            
            self.optimizer.step()
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = correct / total
        
        return epoch_loss, epoch_acc
    
    def evaluate(self, data_loader, return_predictions=False):
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        all_preds = []
        all_labels = []
        all_probs = []
        
        with torch.no_grad():
            for inputs, labels in data_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                
                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                
                probs = torch.softmax(outputs, dim=1)
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())
        
        epoch_loss = running_loss / len(data_loader)
        epoch_acc = correct / total
        
        if return_predictions:
            return epoch_loss, epoch_acc, all_labels, all_preds, all_probs
        return epoch_loss, epoch_acc
    
    def train(self, train_loader, val_loader, num_epochs=30, early_stop_patience=5):
        best_model_wts = None
        no_improve = 0
        
        for epoch in range(num_epochs):
            train_loss, train_acc = self.train_epoch(train_loader)
            val_loss, val_acc = self.evaluate(val_loader)
            
            if self.scheduler is not None:
                if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    self.scheduler.step(val_acc)
                else:
                    self.scheduler.step()
            
            lr = self.optimizer.param_groups[0]['lr']
            print(f'Epoch {epoch+1}/{num_epochs}: LR={lr:.2e}')
            print(f'Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f}')
            print(f'Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f}')
            
            # if val_acc > self.best_metric:
            self.best_metric = val_acc
            best_model_wts = self.model.state_dict()
            torch.save(best_model_wts, 'best_classifier.pth')
            print('↳ 保存最佳模型')
            no_improve = 0
            # else:
            #     no_improve += 1
            #     if no_improve >= early_stop_patience:
            #         print(f'↳ 早停触发，在 {epoch+1} 个epoch后停止训练')
            #         break
        
        if best_model_wts is not None:
            self.model.load_state_dict(best_model_wts)
        
        return self.model

model = ViTLargeClassifier(input_dim=train_features.shape[1]).to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)

def lr_lambda(epoch):
    if epoch < 5:  # 前5个epoch线性warmup
        return (epoch + 1) / 5
    elif epoch < 20:  # 然后保持
        return 1.0
    else:  # 最后线性衰减
        return max(0.0, (25 - epoch) / 5)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# train
trainer = Trainer(model, device, optimizer, criterion, scheduler)
trainer.train(train_loader, val_loader, num_epochs=50)

# test
print("\n在测试集上评估最佳模型...")
test_loss, test_acc, y_true, y_pred, y_probs = trainer.evaluate(test_loader, return_predictions=True)

print(f"\n测试集性能:")
print(f"- 准确率: {test_acc:.4f}")
print(f"- 交叉熵损失: {test_loss:.4f}")

torch.save({
    'model_state_dict': model.state_dict(),
    'input_dim': train_features.shape[1],
    'num_classes': 2
}, 'final_classifier.pth')

Epoch 1/50: LR=8.00e-05
Train Loss: 1.1465 | Acc: 0.5307
Val Loss: 0.6296 | Acc: 0.6208
↳ 保存最佳模型
Epoch 2/50: LR=1.20e-04
Train Loss: 0.9270 | Acc: 0.5884
Val Loss: 0.5391 | Acc: 0.7352
↳ 保存最佳模型
Epoch 3/50: LR=1.60e-04
Train Loss: 0.7448 | Acc: 0.6853
Val Loss: 0.5028 | Acc: 0.7754
↳ 保存最佳模型
Epoch 4/50: LR=2.00e-04
Train Loss: 0.5955 | Acc: 0.7645
Val Loss: 0.4839 | Acc: 0.8008
↳ 保存最佳模型
Epoch 5/50: LR=2.00e-04
Train Loss: 0.5332 | Acc: 0.8017
Val Loss: 0.4622 | Acc: 0.8284
↳ 保存最佳模型
Epoch 6/50: LR=2.00e-04
Train Loss: 0.4819 | Acc: 0.8244
Val Loss: 0.4479 | Acc: 0.8475
↳ 保存最佳模型
Epoch 7/50: LR=2.00e-04
Train Loss: 0.4392 | Acc: 0.8561
Val Loss: 0.4498 | Acc: 0.8411
↳ 保存最佳模型
Epoch 8/50: LR=2.00e-04
Train Loss: 0.4234 | Acc: 0.8745
Val Loss: 0.4445 | Acc: 0.8475
↳ 保存最佳模型
Epoch 9/50: LR=2.00e-04
Train Loss: 0.4121 | Acc: 0.8879
Val Loss: 0.4496 | Acc: 0.8390
↳ 保存最佳模型
Epoch 10/50: LR=2.00e-04
Train Loss: 0.3638 | Acc: 0.9116
Val Loss: 0.4447 | Acc: 0.8496
↳ 保存最佳模型
Epoch 11/50: LR=2.00e-04
Trai

## vit_base

In [13]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

# 转换为PyTorch张量
train_features_tensor = torch.FloatTensor(train_features)
train_labels_tensor = torch.LongTensor(train_labels)
val_features_tensor = torch.FloatTensor(val_features)
val_labels_tensor = torch.LongTensor(val_labels)
test_features_tensor = torch.FloatTensor(test_features)
test_labels_tensor = torch.LongTensor(test_labels)

# 创建数据集和数据加载器
train_dataset = TensorDataset(train_features_tensor, train_labels_tensor)
val_dataset = TensorDataset(val_features_tensor, val_labels_tensor)
test_dataset = TensorDataset(test_features_tensor, test_labels_tensor)

batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# 定义简单分类器
class SimpleClassifier(nn.Module):
    def __init__(self, input_dim=768, num_classes=2):
        super(SimpleClassifier, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )
    
    def forward(self, x):
        return self.fc(x)

model = SimpleClassifier(input_dim=train_features.shape[1]).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=3, factor=0.5)

# 训练函数
def train_model(model, train_loader, val_loader, num_epochs=20):
    best_acc = 0.0
    for epoch in range(num_epochs):
        model.train()
        train_loss, train_correct = 0, 0
        
        for features, labels in train_loader:
            features, labels = features.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            train_correct += (preds == labels).sum().item()
        
        # 验证
        val_loss, val_acc = evaluate(model, val_loader)
        scheduler.step(val_acc)
        
        train_loss /= len(train_loader)
        train_acc = train_correct / len(train_dataset)
        
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f}')
        
        # 保存最佳模型
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'best_simple_classifier.pth')
            print('↳ 保存最佳模型')

def evaluate(model, data_loader):
    model.eval()
    loss, correct = 0, 0
    
    with torch.no_grad():
        for features, labels in data_loader:
            features, labels = features.to(device), labels.to(device)
            outputs = model(features)
            loss += criterion(outputs, labels).item()
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
    
    loss /= len(data_loader)
    acc = correct / len(data_loader.dataset)
    return loss, acc

# 训练和评估
print("\n训练简单分类器...")
train_model(model, train_loader, val_loader, num_epochs=20)

# 加载最佳模型并测试
model.load_state_dict(torch.load('best_simple_classifier.pth'))
test_loss, test_acc = evaluate(model, test_loader)
print(f"\n测试集性能 - 准确率: {test_acc:.4f}")


训练简单分类器...
Epoch 1/20:
Train Loss: 0.4459 | Acc: 0.7851
Val Loss: 0.4014 | Acc: 0.8030
↳ 保存最佳模型
Epoch 2/20:
Train Loss: 0.2865 | Acc: 0.8824
Val Loss: 0.3718 | Acc: 0.8199
↳ 保存最佳模型
Epoch 3/20:
Train Loss: 0.2267 | Acc: 0.9043
Val Loss: 0.4453 | Acc: 0.8008
Epoch 4/20:
Train Loss: 0.1794 | Acc: 0.9378
Val Loss: 0.3919 | Acc: 0.8326
↳ 保存最佳模型
Epoch 5/20:
Train Loss: 0.1409 | Acc: 0.9590
Val Loss: 0.4108 | Acc: 0.8242
Epoch 6/20:
Train Loss: 0.1156 | Acc: 0.9633
Val Loss: 0.4343 | Acc: 0.8157
Epoch 7/20:
Train Loss: 0.0838 | Acc: 0.9787
Val Loss: 0.4776 | Acc: 0.8284
Epoch 8/20:
Train Loss: 0.0935 | Acc: 0.9691
Val Loss: 0.4149 | Acc: 0.8114
Epoch 9/20:
Train Loss: 0.0488 | Acc: 0.9910
Val Loss: 0.4434 | Acc: 0.8220
Epoch 10/20:
Train Loss: 0.0427 | Acc: 0.9926
Val Loss: 0.4326 | Acc: 0.8263
Epoch 11/20:
Train Loss: 0.0364 | Acc: 0.9936
Val Loss: 0.4534 | Acc: 0.8326
Epoch 12/20:
Train Loss: 0.0321 | Acc: 0.9957
Val Loss: 0.4623 | Acc: 0.8347
↳ 保存最佳模型
Epoch 13/20:
Train Loss: 0.0293 | Acc