In [None]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image
import numpy as np
import random
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from datetime import datetime
from torch.cuda.amp import autocast, GradScaler

class ImageClassifier:
    def __init__(self, device='cuda:0'):
        self.device = torch.device(device)
        self._setup_device()
    
    def _setup_device(self):
        """初始化设备配置"""
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available. Please enable a GPU.")
        torch.cuda.set_device(self.device)
        torch.backends.cudnn.benchmark = True
        print(f"Using device: {self.device}")
        
    # ImageDataset is defined inside the class scope
    class ImageDataset(Dataset):
        """Dataset that decodes images from disk on access.

        :param file_paths: Sequence of image file paths.
        :param labels: Sequence of labels, indexed in parallel with
            ``file_paths``.
        :param transform: Optional callable applied to the decoded RGB image.
        """

        def __init__(self, file_paths, labels, transform=None):
            self.file_paths = file_paths
            self.labels = labels
            self.transform = transform

        def __len__(self):
            """Return the number of image paths."""
            return len(self.file_paths)

        def __getitem__(self, idx):
            """Return ``(image, label)`` for position ``idx``."""
            # Image.open is lazy and keeps the file handle open; the context
            # manager closes it deterministically once convert() has produced
            # an independent in-memory image.
            with Image.open(self.file_paths[idx]) as handle:
                image = handle.convert("RGB")
            if self.transform is not None:
                image = self.transform(image)
            return image, self.labels[idx]
        
    def load_data(self, root_dir, target_folder, num_classes_to_select=None, 
             limit_per_class=None, use_all_images=False):
        """
        修正后的参数逻辑:
        :param num_classes_to_select: 选择训练的类别数量
        :param limit_per_class: 每个类别的最大图片数
        """
        # 获取所有有效类别
        all_classes = [d.name for d in os.scandir(root_dir) if d.is_dir()]
        if not all_classes:
            raise ValueError(f"No valid classes found in {root_dir}")

        # 类别选择验证逻辑
        if num_classes_to_select is not None:
            if not isinstance(num_classes_to_select, int) or num_classes_to_select <= 0:
                raise ValueError("num_classes_to_select must be a positive integer")
            
            # 当请求类别数超过实际数量时自动修正
            if num_classes_to_select > len(all_classes):
                print(f"Warning: Requested {num_classes_to_select} classes but only {len(all_classes)} available. Using all classes.")
                num_classes_to_select = len(all_classes)
            
            selected_classes = random.sample(all_classes, num_classes_to_select)
            print(f"Randomly selected {num_classes_to_select} classes from {len(all_classes)} total classes")
        else:
            selected_classes = all_classes
            num_classes_to_select = len(all_classes)  # 保持参数记录准确

        # 处理每个类别的图片数量
        class_mapping = {cls: idx for idx, cls in enumerate(selected_classes)}
        file_paths = []
        labels = []
        
        for cls_name in selected_classes:
            target_path = os.path.join(root_dir, cls_name, target_folder)
            if not os.path.exists(target_path):
                print(f"Warning: Missing {target_folder} in {cls_name}")
                continue
                
            # 获取所有图片文件
            images = [
                os.path.join(target_path, f) 
                for f in os.listdir(target_path) 
                if f.lower().endswith(('.png', '.jpg', '.jpeg'))
            ]
            
            # 图片数量控制逻辑
            if limit_per_class and not use_all_images:
                if len(images) < limit_per_class:
                    print(f"Warning: Class {cls_name} only has {len(images)} images (requested {limit_per_class})")
                images = images[:limit_per_class]  # 安全截断
                
            file_paths.extend(images)
            labels.extend([class_mapping[cls_name]] * len(images))

        print(f"\nLoaded {len(class_mapping)} classes from {target_folder}")
        self._print_class_stats(labels, class_mapping)
        return file_paths, labels, class_mapping

    def _print_class_stats(self, labels, class_mapping):
        unique, counts = np.unique(labels, return_counts=True)
        for cls_idx, count in zip(unique, counts):
            cls_name = list(class_mapping.keys())[cls_idx]
            print(f"Class {cls_idx} ({cls_name}): {count} samples")

    def create_model(self, num_classes, model_name='resnet18'):
        """Build a configurable torchvision model with a new output layer.

        :param num_classes: Output size of the replacement linear layer.
        :param model_name: One of ``'resnet18'``, ``'resnet50'`` or
            ``'efficientnet_b0'``.
        :return: The model after ``.to(self.device)``.
        :raises KeyError: If ``model_name`` is not one of the supported names.
        """
        builders = {
            'resnet18': models.resnet18,
            'resnet50': models.resnet50,
            'efficientnet_b0': models.efficientnet_b0
        }
        network = builders[model_name](weights='DEFAULT')

        # Swap the final linear layer for one sized to num_classes
        is_resnet = 'resnet' in model_name
        if is_resnet:
            in_dim = network.fc.in_features
            network.fc = nn.Linear(in_dim, num_classes)
        elif 'efficientnet' in model_name:
            head = network.classifier
            head[1] = nn.Linear(head[1].in_features, num_classes)

        return network.to(self.device)

    def train(self, config):
        """Run the complete pipeline: load data, train, evaluate, save results.

        ``config`` is updated in place with defaults for ``save_dir``,
        ``model_name``, ``input_size`` and ``patience``.

        Required keys: ``root_dir``, ``target_folder``, ``batch_size``,
        ``lr``, ``epochs``. Optional keys: ``num_classes_to_select``,
        ``limit_per_class``, ``use_all_images``.

        :param config: Dict of run settings (see above).
        :raises KeyError: If a required key is missing.
        """
        # Fill in defaults
        config.setdefault('save_dir', 'model_results')
        config.setdefault('model_name', 'resnet18')
        config.setdefault('input_size', 224)
        config.setdefault('patience', 5)

        # Load the data exactly once. Classes may be sampled at random, so a
        # second call could yield a different class_mapping than the one used
        # to name the output directory (and would rescan the disk for nothing).
        file_paths, labels, class_mapping = self.load_data(
            root_dir=config['root_dir'],
            target_folder=config['target_folder'],
            num_classes_to_select=config.get('num_classes_to_select'),
            limit_per_class=config.get('limit_per_class'),
            use_all_images=config.get('use_all_images', False)
        )

        # Create a unique output directory for this run
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        folder_name = (
            f"{config['target_folder']}_"
            f"{timestamp}-"
            f"{len(class_mapping)}_classes"
        )
        if config.get('num_classes_to_select'):
            folder_name += f"-selected_{config['num_classes_to_select']}"
        if config.get('limit_per_class'):
            folder_name += f"-limit_{config['limit_per_class']}"
        save_path = os.path.join(config['save_dir'], folder_name)
        os.makedirs(save_path, exist_ok=True)

        # Preprocessing pipelines
        train_transform, test_transform = self._get_transforms(config['input_size'])

        # Train / validation / test loaders
        train_loader, val_loader, test_loader = self._create_data_loaders(
            file_paths, labels,
            train_transform, test_transform,
            config['batch_size']
        )

        # Model, optimiser, loss, AMP scaler and LR scheduler
        model = self.create_model(len(class_mapping), config['model_name'])
        optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=1e-4)
        criterion = nn.CrossEntropyLoss()
        scaler = torch.amp.GradScaler()
        # mode 'max' because the scheduler is stepped with validation accuracy
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)

        # Training loop; the best checkpoint is written into save_path
        history = self._train_loop(
            model, optimizer, criterion, scheduler, scaler,
            train_loader, val_loader,
            config['epochs'], config['patience'],
            save_path
        )

        # Final evaluation on the held-out test split
        test_acc, cm = self.evaluate(model, test_loader)

        # Persist weights, plots and the text report
        self._save_results(
            model, history, cm, class_mapping,
            config, test_acc, save_path, labels
        )

    def _get_transforms(self, input_size):
        """Build the training and evaluation preprocessing pipelines.

        Both pipelines resize to ``input_size`` x ``input_size``, convert to a
        tensor and normalise per channel; the training pipeline additionally
        applies a random horizontal flip and brightness/contrast jitter.

        :param input_size: Side length of the square the images are resized to.
        :return: Tuple ``(train_transform, test_transform)``.
        """
        target_size = (input_size, input_size)

        def closing_steps():
            # Fresh instances per pipeline. The constants are presumably the
            # ImageNet channel statistics used with torchvision's pretrained
            # weights — confirm if the backbone changes.
            return [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]

        train_steps = [
            transforms.Resize(target_size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
        ]
        train_steps.extend(closing_steps())

        eval_steps = [transforms.Resize(target_size)]
        eval_steps.extend(closing_steps())

        return transforms.Compose(train_steps), transforms.Compose(eval_steps)

    def _create_data_loaders(self, file_paths, labels, train_trans, test_trans, batch_size):
        """Split the samples 60/20/20 and wrap each split in a DataLoader.

        Both splits are stratified on the labels and use ``random_state=42``.

        :param file_paths: Image paths.
        :param labels: Labels parallel to ``file_paths``.
        :param train_trans: Transform for the training split.
        :param test_trans: Transform for the validation and test splits.
        :param batch_size: Batch size used by all three loaders.
        :return: Tuple ``(train_loader, val_loader, test_loader)``.
        """
        # Hold out 20% for testing ...
        rest_paths, test_paths, rest_labels, test_labels = train_test_split(
            file_paths, labels, test_size=0.2, stratify=labels, random_state=42
        )
        # ... then 25% of the remaining 80% (20% overall) for validation.
        fit_paths, val_paths, fit_labels, val_labels = train_test_split(
            rest_paths, rest_labels, test_size=0.25, stratify=rest_labels, random_state=42
        )

        def make_loader(paths, targets, transform, shuffle):
            dataset = self.ImageDataset(paths, targets, transform)
            return DataLoader(dataset, batch_size, shuffle=shuffle, pin_memory=True)

        train_loader = make_loader(fit_paths, fit_labels, train_trans, True)
        val_loader = make_loader(val_paths, val_labels, test_trans, False)
        test_loader = make_loader(test_paths, test_labels, test_trans, False)
        return train_loader, val_loader, test_loader

    def _train_loop(self, model, optimizer, criterion, scheduler, scaler, 
                   train_loader, val_loader, epochs, patience, save_path):
        """Train with mixed precision, checkpointing and early stopping.

        After every epoch the model is validated, the scheduler is stepped
        with the validation accuracy, and the weights are written to
        ``<save_path>/best_model.pth`` whenever validation accuracy improves.
        Training stops early after ``patience`` consecutive epochs without
        improvement. If a checkpoint was written, its weights are loaded back
        into ``model`` before returning.

        :param model: Network to train (updated in place).
        :param optimizer: Optimiser stepped through ``scaler``.
        :param criterion: Loss function called as ``criterion(outputs, labels)``.
        :param scheduler: LR scheduler stepped with the validation accuracy.
        :param scaler: Gradient scaler used for the backward pass and step.
        :param train_loader: Iterable of ``(images, labels)`` training batches.
        :param val_loader: Iterable of validation batches.
        :param epochs: Maximum number of epochs.
        :param patience: Epochs without improvement tolerated before stopping.
        :param save_path: Directory receiving ``best_model.pth``.
        :return: Dict with per-epoch lists ``train_loss``, ``val_loss``,
            ``train_acc``, ``val_acc`` and ``lr``.
        """
        history = {
            'train_loss': [], 'val_loss': [],
            'train_acc': [], 'val_acc': [],
            'lr': []
        }
        # Start below any reachable accuracy so the first epoch always writes
        # a checkpoint, even when validation accuracy is exactly 0.0.
        best_acc = float('-inf')
        early_stop_counter = 0
        checkpoint_path = os.path.join(save_path, "best_model.pth")
        checkpoint_saved = False

        for epoch in range(epochs):
            # Training phase
            model.train()
            train_loss, correct, total = 0.0, 0, 0

            for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
                images, labels = images.to(self.device), labels.to(self.device)

                optimizer.zero_grad()
                with torch.amp.autocast(device_type='cuda'):
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                # Scale the loss before backward; the scaler performs the step
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                train_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

            # Validation phase
            val_acc, val_loss = self._validate(model, criterion, val_loader)

            # Record metrics (guarded against an empty training loader)
            train_acc = correct / total if total else 0.0
            history['train_loss'].append(train_loss / max(len(train_loader), 1))
            history['val_loss'].append(val_loss)
            history['train_acc'].append(train_acc)
            history['val_acc'].append(val_acc)
            history['lr'].append(optimizer.param_groups[0]['lr'])

            # Learning-rate adjustment
            scheduler.step(val_acc)

            # Print before the early-stop decision so the final epoch's
            # metrics are reported too.
            print(f"Epoch {epoch+1}: "
                  f"Train Loss: {history['train_loss'][-1]:.4f} | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Train Acc: {train_acc:.4f} | "
                  f"Val Acc: {val_acc:.4f} | "
                  f"LR: {history['lr'][-1]:.2e}")

            # Checkpointing and early stopping
            if val_acc > best_acc:
                best_acc = val_acc
                early_stop_counter = 0
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_saved = True
            else:
                early_stop_counter += 1
                if early_stop_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        # Restore the best weights; skipped when nothing was saved (epochs == 0)
        if checkpoint_saved:
            model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
        return history

    def _validate(self, model, criterion, val_loader):
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = model(images)
                
                val_loss += criterion(outputs, labels).item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
                
        return correct / total, val_loss / len(val_loader)

    def evaluate(self, model, test_loader):
        """Compute accuracy and the confusion matrix on ``test_loader``.

        :param model: Network to evaluate; switched to eval mode.
        :param test_loader: Iterable of ``(images, labels)`` batches.
        :return: Tuple ``(test_acc, cm)`` with the fraction of correct
            predictions and the result of ``confusion_matrix``.
        """
        model.eval()
        predictions, targets = [], []
        hits, seen = 0, 0

        with torch.no_grad():
            for batch_images, batch_labels in test_loader:
                batch_images = batch_images.to(self.device)
                batch_labels = batch_labels.to(self.device)
                logits = model(batch_images)

                predicted = logits.argmax(dim=1)
                seen += batch_labels.size(0)
                hits += (predicted == batch_labels).sum().item()
                # Move to host memory before accumulating
                predictions.extend(predicted.cpu().numpy())
                targets.extend(batch_labels.cpu().numpy())

        return hits / seen, confusion_matrix(targets, predictions)

    def _save_results(self, model, history, cm, class_mapping, config, test_acc, save_path, labels):
        """Write the model weights, plots and text report for one run.

        All artifacts go into ``save_path``, the same directory the training
        loop writes ``best_model.pth`` to, so a run produces a single output
        folder. A directory name is generated from ``config`` only when
        ``save_path`` is empty or ``None``.

        :param model: Model whose ``state_dict`` is saved.
        :param history: Metric history forwarded to the plot and report.
        :param cm: Confusion matrix to plot.
        :param class_mapping: Dict mapping class name to class index.
        :param config: Run configuration; ``target_folder`` and ``save_dir``
            are read only when a directory name has to be generated.
        :param test_acc: Test accuracy written to the report.
        :param save_path: Output directory for all artifacts.
        :param labels: Labels used for the class-distribution section.
        """
        if not save_path:
            # Fallback: build a directory name carrying the feature folder,
            # a timestamp and the relevant parameters.
            timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            folder_name = (
                f"{config['target_folder']}_"
                f"{timestamp}-"
                f"{len(class_mapping)}_classes"
            )
            if config.get('num_classes_to_select'):
                folder_name += f"-selected_{config['num_classes_to_select']}"
            if config.get('limit_per_class'):
                folder_name += f"-limit_{config['limit_per_class']}"
            elif not config.get('use_all_images'):
                folder_name += "-default_limit"
            save_path = os.path.join(config['save_dir'], folder_name)

        os.makedirs(save_path, exist_ok=True)

        # Save the model weights
        model_save_path = os.path.join(save_path, "best_model.pth")
        torch.save(model.state_dict(), model_save_path)

        # Training curves
        self._plot_training_curves(history, save_path)

        # Confusion matrix
        self._plot_confusion_matrix(cm, class_mapping, save_path)

        # Text report
        self._save_report(config, history, test_acc, class_mapping, save_path, labels)

    def _plot_training_curves(self, history, save_path):
        """Save loss and accuracy curves to ``training_curves.png``.

        :param history: Dict with ``train_loss``, ``val_loss``, ``train_acc``
            and ``val_acc`` lists.
        :param save_path: Directory receiving the image.
        """
        # (train key, validation key, panel title, y-axis label)
        panels = (
            ('train_loss', 'val_loss', 'Loss Curves', 'Loss'),
            ('train_acc', 'val_acc', 'Accuracy Curves', 'Accuracy'),
        )

        plt.figure(figsize=(12, 5))
        for position, (train_key, val_key, title, y_label) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.plot(history[train_key], label='Train')
            plt.plot(history[val_key], label='Validation')
            plt.title(title)
            plt.xlabel('Epoch')
            plt.ylabel(y_label)
            plt.legend()

        output_file = os.path.join(save_path, "training_curves.png")
        plt.savefig(output_file, bbox_inches='tight')
        # Close the figure so figures do not accumulate in memory
        plt.close()

    def _plot_confusion_matrix(self, cm, class_mapping, save_path):
        """Save an annotated confusion-matrix heatmap to ``confusion_matrix.png``.

        :param cm: Confusion matrix to draw.
        :param class_mapping: Dict mapping class name to class index; the
            names are used as tick labels.
        :param save_path: Directory receiving the image.
        """
        # Order the names by class index so tick i labels row/column i no
        # matter how the dict was built, and pass a plain list rather than a
        # dict view.
        class_names = [
            name for name, _ in sorted(class_mapping.items(), key=lambda item: item[1])
        ]

        plt.figure(figsize=(15, 12))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names,
                    yticklabels=class_names)
        plt.title("Confusion Matrix")
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)

        # Save, then close the figure
        plt.savefig(os.path.join(save_path, "confusion_matrix.png"), bbox_inches='tight')
        plt.close()

    def _save_report(self, config, history, test_acc, class_mapping, save_path, labels):
        with open(os.path.join(save_path, "results.txt"), "w") as f:
            f.write("=== Experiment Summary ===\n")
            f.write(f"Feature Folder: {config['target_folder']}\n")  # 新增特征目录信息
            f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"Total Classes: {len(class_mapping)}\n")
            f.write(f"Best Train Accuracy: {max(history['train_acc']):.4f}\n")
            f.write(f"Best Val Accuracy: {max(history['val_acc']):.4f}\n")
            f.write(f"Final Test Accuracy: {test_acc:.4f}\n\n")
            
            f.write("=== Training Parameters ===\n")
            f.write(f"num_classes_to_select: {config.get('num_classes_to_select', 'All')}\n")
            f.write(f"limit_per_class: {config.get('limit_per_class', 'No limit')}\n")
            f.write(f"use_all_images: {config.get('use_all_images', False)}\n")
            f.write(f"batch_size: {config['batch_size']}\n")
            f.write(f"epochs: {config['epochs']}\n")
            f.write(f"learning_rate: {config['lr']}\n")
            f.write(f"input_size: {config['input_size']}\n")
            f.write(f"model_name: {config['model_name']}\n\n")
            
            # 按epoch记录详细数据
            f.write("=== Epoch-wise Results ===\n")
            f.write("Epoch | Train Acc | Val Acc | Learning Rate\n")
            f.write("--------------------------------------------\n")
            for epoch, (train_acc, val_acc, lr) in enumerate(zip(
                history['train_acc'], 
                history['val_acc'],
                history['lr']
            )):
                f.write(f"{epoch+1:5d} | {train_acc:.4f}   | {val_acc:.4f}  | {lr:.2e}\n")
            f.write("\n")

            f.write("=== Class Distribution ===\n")
            # 使用传入的labels直接统计
            unique, counts = np.unique(labels, return_counts=True)
            for cls_idx, count in zip(unique, counts):
                cls_name = [k for k, v in class_mapping.items() if v == cls_idx][0]
                f.write(f"Class {cls_idx} ({cls_name}): {count} samples\n")

if __name__ == "__main__":
    # Settings for a single training run, grouped by purpose.
    config = {
        # Data location
        'root_dir': "../../IQ_signal_plots",
        'target_folder': "trajectory_pos",  # can be swapped for another feature folder
        'save_dir': "training_results",
        # Data selection
        'num_classes_to_select': 150,  # how many classes to train on
        'limit_per_class': 1000,       # maximum samples per class
        'use_all_images': False,       # whether to ignore limit_per_class
        # Model
        'model_name': "resnet18",
        'input_size': 224,
        # Optimisation
        'batch_size': 256,
        'epochs': 30,
        'lr': 0.01,
        'patience': 5,
    }

    classifier = ImageClassifier()
    classifier.train(config)

Using device: cuda:0
Randomly selected 150 classes from 150 total classes

Loaded 150 classes from trajectory_pos
Class 0 (19-9): 1000 samples
Class 1 (20-4): 1000 samples
Class 2 (19-3): 1000 samples
Class 3 (3-19): 1000 samples
Class 4 (10-7): 1000 samples
Class 5 (14-10): 1000 samples
Class 6 (9-1): 1000 samples
Class 7 (6-15): 1000 samples
Class 8 (1-10): 1000 samples
Class 9 (19-7): 1000 samples
Class 10 (2-5): 1000 samples
Class 11 (7-9): 1000 samples
Class 12 (8-14): 1000 samples
Class 13 (13-20): 1000 samples
Class 14 (4-11): 1000 samples
Class 15 (19-4): 1000 samples
Class 16 (2-20): 1000 samples
Class 17 (18-10): 1000 samples
Class 18 (14-9): 1000 samples
Class 19 (12-19): 1000 samples
Class 20 (8-3): 1000 samples
Class 21 (10-11): 1000 samples
Class 22 (7-14): 1000 samples
Class 23 (20-1): 1000 samples
Class 24 (1-2): 1000 samples
Class 25 (1-19): 1000 samples
Class 26 (15-19): 1000 samples
Class 27 (12-20): 1000 samples
Class 28 (2-12): 1000 samples
Class 29 (16-19): 1000 s

Epoch 1/5:   0%|          | 0/352 [00:04<?, ?it/s]


KeyboardInterrupt: 