In [1]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from datetime import datetime
from torch.cuda.amp import autocast, GradScaler

class ImageClassifier:
    """Fine-tune a torchvision backbone on a folder-per-class image dataset.

    Expected layout (see ``load_data``)::

        root_dir/<class_name>/<target_folder>/<*.png | *.jpg | *.jpeg>

    ``train`` runs the whole pipeline: collect file paths, make a stratified
    60/20/20 train/val/test split, fine-tune with AdamW (mixed precision on
    CUDA only), early-stop on validation accuracy, evaluate on the test split,
    and write the checkpoint, plots and a text report under
    ``config['save_dir']``.
    """

    def __init__(self, device='cuda:0'):
        """:param device: any value accepted by ``torch.device`` (e.g. 'cuda:0', 'cuda', 'cpu')."""
        self.device = torch.device(device)
        self._setup_device()

    def _setup_device(self):
        """Configure ``self.device``.

        Falls back to CPU (with a printed warning) when a CUDA device is
        requested but CUDA is unavailable, instead of crashing in torch.cuda.
        """
        if self.device.type == 'cuda':
            if torch.cuda.is_available():
                # set_device() only accepts a CUDA device with an explicit index.
                if self.device.index is not None:
                    torch.cuda.set_device(self.device)
                torch.backends.cudnn.benchmark = True
            else:
                print(f"Warning: {self.device} requested but CUDA is unavailable; using CPU")
                self.device = torch.device('cpu')
        print(f"Using device: {self.device}")

    class ImageDataset(Dataset):
        """Map-style dataset over image file paths with a parallel list of integer labels.

        The file list is fixed at construction time. If files are deleted or
        renamed afterwards (e.g. the plots are regenerated while training runs),
        ``__getitem__`` raises the resulting ``FileNotFoundError``.
        """

        def __init__(self, file_paths, labels, transform=None):
            self.file_paths = file_paths
            self.labels = labels
            self.transform = transform

        def __len__(self):
            return len(self.file_paths)

        def __getitem__(self, idx):
            """Return ``(image, label)``; the image is RGB and transformed if a transform is set."""
            # The context manager releases the file handle immediately; on
            # Windows a lingering handle also blocks other processes from
            # replacing the file.
            with Image.open(self.file_paths[idx]) as img:
                image = img.convert("RGB")
            if self.transform is not None:
                image = self.transform(image)
            return image, self.labels[idx]

    def load_data(self, root_dir, target_folder, limit_per_class=None):
        """
        Collect image paths and labels from a folder-per-class layout.

        :param root_dir: parent directory containing one sub-directory per class
        :param target_folder: feature sub-folder inside each class directory (e.g. 'trajectory_pos')
        :param limit_per_class: maximum images kept per class (None means no limit)
        :return: (file_paths, labels, class_mapping), where class_mapping maps
            class directory name -> integer label
        """
        # Labels follow the sorted directory names, so they are stable across
        # runs. A class lacking target_folder keeps its index (and a model
        # output) but contributes no samples.
        class_folders = sorted(d.name for d in os.scandir(root_dir) if d.is_dir())
        class_mapping = {cls: idx for idx, cls in enumerate(class_folders)}

        file_paths = []
        labels = []
        for cls_name, cls_idx in class_mapping.items():
            target_path = os.path.join(root_dir, cls_name, target_folder)
            if not os.path.isdir(target_path):
                print(f"Warning: Missing {target_folder} in {cls_name}")
                continue

            # os.listdir order is filesystem-dependent; sort so limit_per_class
            # picks the same files on every run and platform.
            images = sorted(
                os.path.join(target_path, f)
                for f in os.listdir(target_path)
                if f.lower().endswith(('.png', '.jpg', '.jpeg'))
            )
            if limit_per_class is not None:
                images = images[:limit_per_class]

            file_paths.extend(images)
            labels.extend([cls_idx] * len(images))

        print(f"\nLoaded {len(class_mapping)} classes from {target_folder}")
        self._print_class_stats(labels, class_mapping)
        return file_paths, labels, class_mapping

    def _print_class_stats(self, labels, class_mapping):
        """Print the sample count of every class index that occurs in ``labels``."""
        idx_to_name = {idx: name for name, idx in class_mapping.items()}
        unique, counts = np.unique(labels, return_counts=True)
        for cls_idx, count in zip(unique, counts):
            print(f"Class {cls_idx} ({idx_to_name[cls_idx]}): {count} samples")

    def create_model(self, num_classes, model_name='resnet18'):
        """Build a pretrained backbone with a fresh ``num_classes``-way head, moved to ``self.device``.

        :raises KeyError: if ``model_name`` is not a supported name.
        """
        model_map = {
            'resnet18': models.resnet18,
            'resnet50': models.resnet50,
            'efficientnet_b0': models.efficientnet_b0
        }
        if model_name not in model_map:
            raise KeyError(f"Unsupported model_name {model_name!r}; expected one of {sorted(model_map)}")

        model = model_map[model_name](weights='DEFAULT')
        if 'resnet' in model_name:
            model.fc = nn.Linear(model.fc.in_features, num_classes)
        elif 'efficientnet' in model_name:
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

        return model.to(self.device)

    def train(self, config):
        """
        Run the full pipeline: load, split, train, evaluate, save.

        :param config: dict with the keys
            - root_dir: data root directory (required)
            - target_folder: feature sub-folder name (required)
            - batch_size: batch size (required)
            - epochs: maximum number of epochs (required)
            - lr: learning rate (required)
            - save_dir: output directory (default 'model_results')
            - input_size: square input side length (default 224)
            - limit_per_class: max samples per class (default None = unlimited)
            - model_name: 'resnet18' | 'resnet50' | 'efficientnet_b0' (default 'resnet18')
            - patience: early-stopping patience in epochs (default 5)
            - num_workers: DataLoader worker processes (default 0)
        :return: dict with 'history', 'test_acc', 'confusion_matrix', 'class_mapping'
        """
        # Work on a copy so filling in defaults does not mutate the caller's dict.
        config = dict(config)
        config.setdefault('save_dir', 'model_results')
        config.setdefault('model_name', 'resnet18')
        config.setdefault('input_size', 224)
        config.setdefault('patience', 5)

        file_paths, labels, class_mapping = self.load_data(
            config['root_dir'],
            config['target_folder'],
            config.get('limit_per_class')
        )

        train_transform, test_transform = self._get_transforms(config['input_size'])
        train_loader, val_loader, test_loader = self._create_data_loaders(
            file_paths, labels,
            train_transform, test_transform,
            config['batch_size'],
            num_workers=config.get('num_workers', 0)
        )

        num_classes = len(class_mapping)
        model = self.create_model(num_classes, config['model_name'])
        optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=1e-4)
        criterion = nn.CrossEntropyLoss()
        # Loss scaling is only needed for CUDA fp16 autocast; a disabled scaler
        # makes scale()/step()/update() plain pass-throughs elsewhere.
        scaler = torch.amp.GradScaler(self.device.type, enabled=self.device.type == 'cuda')
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)

        history = self._train_loop(
            model, optimizer, criterion, scheduler, scaler,
            train_loader, val_loader,
            config['epochs'], config['patience']
        )

        test_acc, cm = self.evaluate(model, test_loader, num_classes=num_classes)

        self._save_results(
            model, history, cm, class_mapping,
            config, test_acc, labels=labels
        )
        return {
            'history': history,
            'test_acc': test_acc,
            'confusion_matrix': cm,
            'class_mapping': class_mapping,
        }

    def _get_transforms(self, input_size):
        """Return (train_transform, test_transform) producing input_size x input_size tensors.

        Both normalize with the ImageNet mean/std; the train transform adds
        light augmentation.
        """
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        train_transform = transforms.Compose([
            transforms.Resize((input_size, input_size)),
            # NOTE(review): a horizontal flip mirrors the plot's x-axis; confirm
            # that this preserves class identity for these signal plots.
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            normalize,
        ])
        test_transform = transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
            normalize,
        ])
        return train_transform, test_transform

    def _create_data_loaders(self, file_paths, labels, train_trans, test_trans, batch_size,
                             num_workers=0):
        """Stratified 60/20/20 split into (train, val, test) DataLoaders.

        ``num_workers`` > 0 decodes images in parallel worker processes. With
        the 'spawn' start method (the Windows default) workers must be able to
        import the dataset class, which can fail for classes defined in a
        notebook. That is why the default is 0.
        """
        # 20% for test first, then 25% of the remaining 80% (= 20% overall)
        # for validation; random_state makes the split reproducible.
        train_p, test_p, train_l, test_l = train_test_split(
            file_paths, labels, test_size=0.2, stratify=labels, random_state=42
        )
        train_p, val_p, train_l, val_l = train_test_split(
            train_p, train_l, test_size=0.25, stratify=train_l, random_state=42
        )

        loader_kwargs = dict(
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=self.device.type == 'cuda',  # pinning only helps host->GPU copies
            persistent_workers=num_workers > 0,
        )
        return (
            DataLoader(self.ImageDataset(train_p, train_l, train_trans), shuffle=True, **loader_kwargs),
            DataLoader(self.ImageDataset(val_p, val_l, test_trans), shuffle=False, **loader_kwargs),
            DataLoader(self.ImageDataset(test_p, test_l, test_trans), shuffle=False, **loader_kwargs),
        )

    def _train_loop(self, model, optimizer, criterion, scheduler, scaler,
                    train_loader, val_loader, epochs, patience):
        """Train with early stopping on validation accuracy.

        The best weights (highest val accuracy) are kept in memory and loaded
        back into ``model`` before returning. Nothing is written to or read from
        the working directory, so a stale checkpoint from another run can never
        be picked up.

        :return: history dict of per-epoch lists: train_loss, val_loss,
            train_acc, val_acc, lr
        """
        history = {
            'train_loss': [], 'val_loss': [],
            'train_acc': [], 'val_acc': [],
            'lr': []
        }
        use_amp = self.device.type == 'cuda'
        # -inf guarantees the first epoch is stored as "best", so there is
        # always a state to restore.
        best_acc = float('-inf')
        best_state = None
        early_stop_counter = 0

        for epoch in range(epochs):
            model.train()
            train_loss, correct, total = 0.0, 0, 0

            for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)

                optimizer.zero_grad()
                with torch.amp.autocast(device_type=self.device.type, enabled=use_amp):
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                train_loss += loss.item()
                total += labels.size(0)
                correct += outputs.argmax(1).eq(labels).sum().item()

            val_acc, val_loss = self._validate(model, criterion, val_loader)

            train_acc = correct / total if total else 0.0
            history['train_loss'].append(train_loss / max(len(train_loader), 1))
            history['val_loss'].append(val_loss)
            history['train_acc'].append(train_acc)
            history['val_acc'].append(val_acc)
            history['lr'].append(optimizer.param_groups[0]['lr'])

            scheduler.step(val_acc)

            # Report before the early-stop check so the final epoch is shown too.
            print(f"Epoch {epoch+1}: "
                  f"Train Loss: {history['train_loss'][-1]:.4f} | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Train Acc: {train_acc:.4f} | "
                  f"Val Acc: {val_acc:.4f} | "
                  f"LR: {history['lr'][-1]:.2e}")

            if val_acc > best_acc:
                best_acc = val_acc
                early_stop_counter = 0
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            else:
                early_stop_counter += 1
                if early_stop_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        if best_state is not None:
            model.load_state_dict(best_state)
        return history

    def _validate(self, model, criterion, val_loader):
        """Return ``(accuracy, mean per-batch loss)`` of ``model`` on ``val_loader``.

        An empty loader yields ``(0.0, 0.0)`` instead of dividing by zero.
        """
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        batches = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = model(images)

                val_loss += criterion(outputs, labels).item()
                total += labels.size(0)
                correct += outputs.argmax(1).eq(labels).sum().item()
                batches += 1

        if total == 0:
            return 0.0, 0.0
        return correct / total, val_loss / batches

    def evaluate(self, model, test_loader, num_classes=None):
        """Compute test accuracy and the confusion matrix.

        :param num_classes: when given, the confusion matrix is always
            num_classes x num_classes. Rows and columns for classes absent from
            the test set are zero, which keeps it aligned with class_mapping.
        :return: (test_acc, cm)
        """
        model.eval()
        all_preds = []
        all_labels = []
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                predicted = model(images).argmax(1)

                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        test_acc = correct / total if total else 0.0
        cm_labels = None if num_classes is None else list(range(num_classes))
        cm = confusion_matrix(all_labels, all_preds, labels=cm_labels)
        return test_acc, cm

    def _save_results(self, model, history, cm, class_mapping, config, test_acc, labels=None):
        """Write checkpoint, plots and report to ``<save_dir>/<target_folder>_<timestamp>/``.

        :param labels: full label list used for the class-distribution section
            of the report (omitted from the report when None).
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_path = os.path.join(
            config['save_dir'],
            f"{config['target_folder']}_{timestamp}"
        )
        os.makedirs(save_path, exist_ok=True)

        torch.save({
            'model_state': model.state_dict(),
            'class_mapping': class_mapping,
            'config': config,
            'test_acc': test_acc
        }, os.path.join(save_path, "model.pth"))

        self._plot_training_curves(history, save_path)
        self._plot_confusion_matrix(cm, class_mapping, save_path)
        self._save_report(config, history, test_acc, class_mapping, save_path, labels=labels)

        print(f"\nResults saved in: {save_path}")

    def _plot_training_curves(self, history, save_path):
        """Save side-by-side loss and accuracy curves (x-axis: 1-based epoch)."""
        epochs_axis = range(1, len(history['train_loss']) + 1)
        fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

        ax_loss.plot(epochs_axis, history['train_loss'], label='Train')
        ax_loss.plot(epochs_axis, history['val_loss'], label='Validation')
        ax_loss.set_title('Loss Curves')
        ax_loss.set_xlabel('Epoch')
        ax_loss.set_ylabel('Loss')
        ax_loss.legend()

        ax_acc.plot(epochs_axis, history['train_acc'], label='Train')
        ax_acc.plot(epochs_axis, history['val_acc'], label='Validation')
        ax_acc.set_title('Accuracy Curves')
        ax_acc.set_xlabel('Epoch')
        ax_acc.set_ylabel('Accuracy')
        ax_acc.legend()

        fig.tight_layout()
        fig.savefig(os.path.join(save_path, "training_curves.png"))
        plt.close(fig)

    def _plot_confusion_matrix(self, cm, class_mapping, save_path, max_annotated_classes=30):
        """Save the confusion matrix as a heatmap labelled with the class names.

        Per-cell counts are drawn only for up to ``max_annotated_classes``
        classes; beyond that they overlap into an unreadable, slow-to-render mess.
        """
        class_names = list(class_mapping)
        fig = plt.figure(figsize=(15, 12))
        sns.heatmap(cm, annot=len(class_names) <= max_annotated_classes, fmt='d', cmap='Blues',
                    xticklabels=class_names,
                    yticklabels=class_names)
        plt.title("Confusion Matrix")
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, "confusion_matrix.png"))
        plt.close(fig)

    def _save_report(self, config, history, test_acc, class_mapping, save_path, labels=None):
        """Write ``report.txt``: configuration, accuracy summary and class distribution.

        :param labels: label of every loaded sample. The class-distribution
            section is written only when this is given (history holds per-epoch
            metrics, not labels).
        """
        with open(os.path.join(save_path, "report.txt"), "w", encoding="utf-8") as f:
            f.write("=== Training Configuration ===\n")
            for k, v in config.items():
                f.write(f"{k}: {v}\n")

            f.write("\n=== Performance Summary ===\n")
            val_accs = history.get('val_acc') or []
            if val_accs:
                f.write(f"Best Validation Accuracy: {max(val_accs):.4f}\n")
            f.write(f"Final Test Accuracy: {test_acc:.4f}\n")

            if labels is not None:
                counts = np.bincount(np.asarray(labels, dtype=int), minlength=len(class_mapping))
                f.write("\n=== Class Distribution ===\n")
                for cls_name, idx in class_mapping.items():
                    f.write(f"Class {idx} ({cls_name}): {counts[idx]} samples\n")

if __name__ == "__main__":
    # Usage example: train on one feature sub-folder of every class directory.
    classifier = ImageClassifier()

    config = dict(
        root_dir="../../IQ_signal_plots",
        target_folder="trajectory_pos",  # can be changed to any feature folder name
        save_dir="training_results",
        batch_size=256,
        epochs=30,
        lr=0.0005,
        limit_per_class=None,
        input_size=224,
        model_name="resnet18",
        patience=5,
    )

    classifier.train(config)

Using device: cuda:0

Loaded 150 classes from trajectory_pos
Class 0 (1-1): 255 samples
Class 1 (1-10): 255 samples
Class 2 (1-11): 255 samples
Class 3 (1-12): 255 samples
Class 4 (1-14): 255 samples
Class 5 (1-15): 255 samples
Class 6 (1-16): 255 samples
Class 7 (1-18): 255 samples
Class 8 (1-19): 255 samples
Class 9 (1-2): 255 samples
Class 10 (1-8): 255 samples
Class 11 (10-1): 255 samples
Class 12 (10-10): 255 samples
Class 13 (10-11): 255 samples
Class 14 (10-17): 255 samples
Class 15 (10-4): 255 samples
Class 16 (10-7): 255 samples
Class 17 (11-1): 255 samples
Class 18 (11-10): 255 samples
Class 19 (11-17): 255 samples
Class 20 (11-19): 255 samples
Class 21 (11-20): 255 samples
Class 22 (11-4): 255 samples
Class 23 (11-7): 255 samples
Class 24 (12-1): 255 samples
Class 25 (12-19): 255 samples
Class 26 (12-20): 255 samples
Class 27 (12-7): 255 samples
Class 28 (13-14): 255 samples
Class 29 (13-18): 255 samples
Class 30 (13-19): 255 samples
Class 31 (13-20): 255 samples
Class 32 (1

Epoch 1/30: 100%|██████████| 90/90 [04:10<00:00,  2.79s/it]


Epoch 1: Train Loss: 4.3316 | Val Loss: 3.9911 | Train Acc: 0.0651 | Val Acc: 0.0971 | LR: 5.00e-04


Epoch 2/30: 100%|██████████| 90/90 [04:52<00:00,  3.25s/it]


Epoch 2: Train Loss: 3.5680 | Val Loss: 3.5473 | Train Acc: 0.1554 | Val Acc: 0.1561 | LR: 5.00e-04


Epoch 3/30: 100%|██████████| 90/90 [07:36<00:00,  5.08s/it]


Epoch 3: Train Loss: 3.1402 | Val Loss: 3.3915 | Train Acc: 0.2265 | Val Acc: 0.1788 | LR: 5.00e-04


Epoch 4/30: 100%|██████████| 90/90 [05:35<00:00,  3.73s/it]


Epoch 4: Train Loss: 2.7979 | Val Loss: 3.0769 | Train Acc: 0.2880 | Val Acc: 0.2361 | LR: 5.00e-04


Epoch 5/30: 100%|██████████| 90/90 [04:29<00:00,  2.99s/it]


Epoch 5: Train Loss: 2.5251 | Val Loss: 3.3487 | Train Acc: 0.3458 | Val Acc: 0.2132 | LR: 5.00e-04


Epoch 6/30: 100%|██████████| 90/90 [04:26<00:00,  2.96s/it]


Epoch 6: Train Loss: 2.2748 | Val Loss: 2.8921 | Train Acc: 0.4013 | Val Acc: 0.2753 | LR: 5.00e-04


Epoch 7/30: 100%|██████████| 90/90 [04:21<00:00,  2.90s/it]


Epoch 7: Train Loss: 2.0389 | Val Loss: 2.9790 | Train Acc: 0.4585 | Val Acc: 0.2650 | LR: 5.00e-04


Epoch 8/30: 100%|██████████| 90/90 [04:49<00:00,  3.22s/it]


Epoch 8: Train Loss: 1.7752 | Val Loss: 2.8482 | Train Acc: 0.5295 | Val Acc: 0.2941 | LR: 5.00e-04


Epoch 9/30: 100%|██████████| 90/90 [05:54<00:00,  3.94s/it]


Epoch 9: Train Loss: 1.5035 | Val Loss: 2.9288 | Train Acc: 0.5993 | Val Acc: 0.2872 | LR: 5.00e-04


Epoch 10/30:  40%|████      | 36/90 [02:44<04:06,  4.57s/it]


FileNotFoundError: [Errno 2] No such file or directory: 'E:\\program\\IQ_signal_plots\\1-10\\trajectory_pos\\trajectory_pos_154.png'