代码功能：训练分类器

In [1]:
import numpy as np
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

# load data
需要先运行 data_organize.ipynb 组织数据结构

In [None]:
data_dir = '/data7/cyd/files/data/AutismDataset/split_data'

vit_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

train_dataset = ImageFolder(os.path.join(data_dir, 'train'), transform=vit_transform)
val_dataset = ImageFolder(os.path.join(data_dir, 'val'), transform=vit_transform)
test_dataset = ImageFolder(os.path.join(data_dir, 'test'), transform=vit_transform)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# extract feature
使用 ViT 提取图像特征并保存

In [None]:
from models.feature_extractor import ViTFeatureExtractor
from utils import extract_and_save_features

# load model
device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')

model = ViTFeatureExtractor(
    model_name="vit_large_patch16_224",
    # model_name="vit_large_patch16_224.augreg_in21k_ft_in1k",
    # ckpt_path = "/data7/cyd/.cache/huggingface/hub/models--timm--vit_large_patch16_224.augreg_in21k_ft_in1k/snapshots/0930ab3308b84cb2ae091a4a80703c459412a4c7/model.safetensors"
).to(device)
model.eval()

In [15]:
print("为训练集提取特征...")
extract_and_save_features(train_loader, model, train_dataset, '/data7/cyd/files/data/AutismDataset/vit_large_features/train', device)

print("\n为验证集提取特征...")
extract_and_save_features(val_loader, model, val_dataset, '/data7/cyd/files/data/AutismDataset/vit_large_features/val', device)

print("\n为测试集提取特征...")
extract_and_save_features(test_loader, model, test_dataset, '/data7/cyd/files/data/AutismDataset/vit_large_features/test', device)

为训练集提取特征...


提取特征: 100%|██████████| 59/59 [00:20<00:00,  2.92it/s]


特征已保存到 /data7/cyd/files/data/AutismDataset/vit_large_features/train
特征形状: (1880, 1000), 标签形状: (1880,)

为验证集提取特征...


提取特征: 100%|██████████| 15/15 [00:03<00:00,  4.71it/s]


特征已保存到 /data7/cyd/files/data/AutismDataset/vit_large_features/val
特征形状: (472, 1000), 标签形状: (472,)

为测试集提取特征...


提取特征: 100%|██████████| 19/19 [00:03<00:00,  5.61it/s]

特征已保存到 /data7/cyd/files/data/AutismDataset/vit_large_features/test
特征形状: (588, 1000), 标签形状: (588,)





# load features  
加载保存的特征

In [2]:
from utils import load_features

train_features, train_labels = load_features('/data7/cyd/files/data/AutismDataset/vit_large_features/train')
val_features, val_labels = load_features('/data7/cyd/files/data/AutismDataset/vit_large_features/val')
test_features, test_labels = load_features('/data7/cyd/files/data/AutismDataset/vit_large_features/test')

# 合并训练和验证集
X = np.concatenate([train_features, val_features])
y = np.concatenate([train_labels, val_labels])

print(f"训练特征形状: {train_features.shape}, 训练标签形状: {train_labels.shape}")
print(f"测试特征形状: {test_features.shape}, 测试标签形状: {test_labels.shape}")

训练特征形状: (1880, 1000), 训练标签形状: (1880,)
测试特征形状: (588, 1000), 测试标签形状: (588,)


# train  
用 ViT 提取出的特征作为数据训练分类器

In [3]:
from utils import create_data_loaders

batch_size = 64
train_loader, val_loader, test_loader = create_data_loaders(
    train_features, train_labels,
    val_features, val_labels,
    test_features, test_labels,
    batch_size=batch_size
)

In [11]:
from models.classifier import ViTLargeClassifier
from tqdm.auto import tqdm
import torch
import torch.nn as nn

device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')

class Trainer:
    def __init__(self, model, device, optimizer, criterion, scheduler=None):
        self.model = model.to(device)
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion
        self.scheduler = scheduler
        self.best_metric = 0.0
    
    def train_epoch(self, train_loader):
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = correct / total
        
        return epoch_loss, epoch_acc
    
    def evaluate(self, data_loader, return_predictions=False):
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        all_preds = []
        all_labels = []
        all_probs = []
        
        with torch.no_grad():
            for inputs, labels in data_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                
                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                
                probs = torch.softmax(outputs, dim=1)
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())
        
        epoch_loss = running_loss / len(data_loader)
        epoch_acc = correct / total
        
        if return_predictions:
            return epoch_loss, epoch_acc, all_labels, all_preds, all_probs
        return epoch_loss, epoch_acc
    
    def train(self, train_loader, val_loader, num_epochs=30, early_stop_patience=5):
        best_model_wts = None
        no_improve = 0
        
        epoch_iter = tqdm(range(num_epochs), desc='Epochs', bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]')
        for epoch in epoch_iter:
            train_loss, train_acc = self.train_epoch(train_loader)
            val_loss, val_acc = self.evaluate(val_loader)
            
            if self.scheduler is not None:
                if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    self.scheduler.step(val_acc)
                else:
                    self.scheduler.step()
            
            lr = self.optimizer.param_groups[0]['lr']
            
            epoch_iter.set_postfix({
                'lr': f'{lr:.2e}',
                'train_loss': f'{train_loss:.4f}',
                'train_acc': f'{train_acc:.4f}',
                'val_loss': f'{val_loss:.4f}',
                'val_acc': f'{val_acc:.4f}'
            })
            
            self.best_metric = val_acc
            best_model_wts = self.model.state_dict()
            torch.save(best_model_wts, 'best_classifier.pth')
            no_improve = 0
        
        if best_model_wts is not None:
            self.model.load_state_dict(best_model_wts)
        
        return self.model

model = ViTLargeClassifier(input_dim=train_features.shape[1]).to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)

def lr_lambda(epoch):
    if epoch < 5:
        return (epoch + 1) / 5
    elif epoch < 20:
        return 1.0
    else:
        return max(0.0, (25 - epoch) / 5)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# train
trainer = Trainer(model, device, optimizer, criterion, scheduler)
trainer.train(train_loader, val_loader, num_epochs=50)

# test
test_loss, test_acc, y_true, y_pred, y_probs = trainer.evaluate(test_loader, return_predictions=True)

print(f"\nTest Performance:")
print(f"- Accuracy: {test_acc:.4f}")
print(f"- Loss: {test_loss:.4f}")

torch.save({
    'model_state_dict': model.state_dict(),
    'input_dim': train_features.shape[1],
    'num_classes': 2
}, 'final_classifier.pth')

Epochs:   0%|          | 0/50 [00:00<?, ?it/s]


Test Performance:
- Accuracy: 0.8503
- Loss: 0.4583
