In [26]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image
import numpy as np
import time
import shutil
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from datetime import datetime
from torchvision.models import resnet18, ResNet18_Weights
from collections import defaultdict
import random
from torch.amp import GradScaler


In [None]:
# 指定 GPU 索引（假设选择 GPU 0）
GPU_INDEX = 0

# 检查 GPU 是否可用并设置设备
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. Please enable a GPU.")

device = torch.device(f"cuda:{GPU_INDEX}")
torch.cuda.set_device(GPU_INDEX)
print(f"Using device: {device}")

# 数据加载
class ImageDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

def load_image_paths(folder, plot_type='scatter_plots', limit_per_class=1000):
    file_paths = []
    labels = []
    label = 0  # 为每个发射机分配标签
    class_count = {}

    for tx_id in os.listdir(folder):
        tx_folder = os.path.join(folder, tx_id)
        if os.path.isdir(tx_folder):
            plot_folder = os.path.join(tx_folder, plot_type)
            if os.path.exists(plot_folder):
                tx_files = [f for f in os.listdir(plot_folder) if f.endswith('.png')]
                
                # 限制每个发射机类的图片数量
                tx_files = tx_files[:limit_per_class]

                for filename in tx_files:
                    file_paths.append(os.path.join(plot_folder, filename))
                    labels.append(label)

                # 统计每个类的图片数量
                class_count[label] = len(tx_files)

            label += 1

    # 打印每个类的图片数量
    print(f"Total classes: {len(class_count)}")
    for lbl, count in class_count.items():
        print(f"Class {lbl}: {count} images")

    return file_paths, labels

# 模型构建
def build_resnet_model(num_classes):
    # 使用预训练的ResNet模型
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # 加载预训练权重
    # 替换最后一层
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    return model

# 训练和评估
def train_and_evaluate_model(model, train_loader, val_loader, test_loader, device, epochs=20, lr=0.001, patience=5):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model = model.to(device)
    best_val_acc = 0
    best_epoch = 0
    train_losses, val_losses = [], []
    early_stop_counter = 0

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        correct, total = 0, 0

        # Training loop
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Training"):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_acc = correct / total
        train_losses.append(train_loss / len(train_loader))

        # Validation loop
        model.eval()
        val_loss = 0
        correct, total = 0, 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = correct / total
        val_losses.append(val_loss / len(val_loader))

        print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_losses[-1]:.4f}, Val Acc: {val_acc:.4f}")

        # Early stopping logic
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = model.state_dict()
            best_epoch = epoch
            early_stop_counter = 0
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print("Early stopping triggered!")
                break

    print(f"Best Validation Accuracy: {best_val_acc:.4f} at epoch {best_epoch+1}")
    model.load_state_dict(best_model)

    # Test loop
    model.eval()
    correct, total = 0, 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    test_acc = correct / total
    print(f"Test Accuracy: {test_acc:.4f}")

    # Generate confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    cm_fig = plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap="Blues", xticklabels=list(range(len(set(all_labels)))), yticklabels=list(range(len(set(all_labels)))))
    cm_fig.savefig("confusion_matrix.png")
    
    return model, train_losses, val_losses, test_acc, cm


# 比较训练效果并保存结果
def compare_scatter_trajectory_training(folder, batch_size=32, epochs=20, lr=0.001):
    print("Loading scatter plot images...")
    scatter_paths, scatter_labels = load_image_paths(folder, plot_type='scatter_plots', limit_per_class=1000)
    print("Loading trajectory plot images...")
    trajectory_paths, trajectory_labels = load_image_paths(folder, plot_type='trajectory_plots', limit_per_class=1000)

    # 数据集分割
    scatter_train_paths, scatter_test_paths, scatter_train_labels, scatter_test_labels = train_test_split(
        scatter_paths, scatter_labels, test_size=0.2, random_state=42
    )
    trajectory_train_paths, trajectory_test_paths, trajectory_train_labels, trajectory_test_labels = train_test_split(
        trajectory_paths, trajectory_labels, test_size=0.2, random_state=42
    )

    # 数据预处理
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # 构造数据集和数据加载器
    scatter_train_dataset = ImageDataset(scatter_train_paths, scatter_train_labels, transform)
    scatter_test_dataset = ImageDataset(scatter_test_paths, scatter_test_labels, transform)
    trajectory_train_dataset = ImageDataset(trajectory_train_paths, trajectory_train_labels, transform)
    trajectory_test_dataset = ImageDataset(trajectory_test_paths, trajectory_test_labels, transform)

    scatter_train_loader = DataLoader(scatter_train_dataset, batch_size=batch_size, shuffle=True)
    scatter_test_loader = DataLoader(scatter_test_dataset, batch_size=batch_size, shuffle=False)
    trajectory_train_loader = DataLoader(trajectory_train_dataset, batch_size=batch_size, shuffle=True)
    trajectory_test_loader = DataLoader(trajectory_test_dataset, batch_size=batch_size, shuffle=False)

    # 获取类别数量
    num_classes = len(set(scatter_labels))

    # 当前时间戳
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    folder_name = f"{timestamp}-1000_images_per_class-{num_classes}_classes"
    os.makedirs(folder_name, exist_ok=True)

    # 保存训练的模型和混淆矩阵
    # 构建和训练散点图模型
    print("\nTraining scatter plot model...")
    scatter_model = build_resnet_model(num_classes)
    scatter_model, scatter_train_losses, scatter_val_losses, scatter_test_acc, scatter_cm = train_and_evaluate_model(
        scatter_model, scatter_train_loader, scatter_test_loader, scatter_test_loader, device, epochs, lr
    )
    
    torch.save(scatter_model.state_dict(), os.path.join(folder_name, 'scatter_model.pth'))
    plt.savefig(os.path.join(folder_name, 'scatter_confusion_matrix.png'))

    # 保存训练日志
    with open(os.path.join(folder_name, 'scatter_training_log.txt'), 'w') as f:
        f.write(f"Training Results - Scatter Plot\n")
        f.write(f"Test Accuracy: {scatter_test_acc:.4f}\n")
        f.write(f"Train Losses: {scatter_train_losses}\n")
        f.write(f"Val Losses: {scatter_val_losses}\n")

    # 构建和训练轨迹图模型
    print("\nTraining trajectory plot model...")
    trajectory_model = build_resnet_model(num_classes)
    trajectory_model, trajectory_train_losses, trajectory_val_losses, trajectory_test_acc, trajectory_cm = train_and_evaluate_model(
        trajectory_model, trajectory_train_loader, trajectory_test_loader, trajectory_test_loader, device, epochs, lr
    )
    
    torch.save(trajectory_model.state_dict(), os.path.join(folder_name, 'trajectory_model.pth'))
    plt.savefig(os.path.join(folder_name, 'trajectory_confusion_matrix.png'))

    # 保存训练日志
    with open(os.path.join(folder_name, 'trajectory_training_log.txt'), 'w') as f:
        f.write(f"Training Results - Trajectory Plot\n")
        f.write(f"Test Accuracy: {trajectory_test_acc:.4f}\n")
        f.write(f"Train Losses: {trajectory_train_losses}\n")
        f.write(f"Val Losses: {trajectory_val_losses}\n")

# 使用示例：比较散点图和轨迹图的训练效果
folder = "../../IQ_signal_plots"  # 图像存储的文件夹
compare_scatter_trajectory_training(folder, batch_size=32, epochs=20, lr=0.001)


In [22]:
# 指定 GPU 索引
GPU_INDEX = 0

# 检查 GPU 是否可用并设置设备
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. Please enable a GPU.")

device = torch.device(f"cuda:{GPU_INDEX}")
torch.cuda.set_device(GPU_INDEX)
print(f"Using device: {device}")

# 数据加载类
class ImageDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

def load_image_paths_with_mapping(folder, plot_type, tx_label_mapping, 
                                 num_classes_to_select=None, 
                                 limit_per_class=1000, 
                                 use_all_images=False):
    """
    参数说明：
    num_classes_to_select - 要随机选择的类别数量（None表示选择全部）
    limit_per_class - 每个类别最多读取的图像数
    use_all_images - 是否忽略limit_per_class读取所有图像
    """
    file_paths = []
    labels = []
    class_count = defaultdict(int)

    # 随机选择指定数量的类别
    all_tx_ids = list(tx_label_mapping.keys())
    if num_classes_to_select is not None:
        selected_tx_ids = random.sample(all_tx_ids, num_classes_to_select)
        print(f"Randomly selected {num_classes_to_select} classes: {selected_tx_ids}")
    else:
        selected_tx_ids = all_tx_ids

    # 创建新的标签映射（确保标签从0开始连续）
    new_label_mapping = {tx_id: idx for idx, tx_id in enumerate(selected_tx_ids)}

    for tx_id in selected_tx_ids:
        original_label = tx_label_mapping[tx_id]
        tx_folder = os.path.join(folder, tx_id)
        plot_folder = os.path.join(tx_folder, plot_type)
        
        if os.path.exists(plot_folder):
            tx_files = [f for f in os.listdir(plot_folder) if f.endswith('.png')]
            
            # 随机打乱文件顺序
            random.shuffle(tx_files)
            
            # 数量控制逻辑
            if not use_all_images and limit_per_class is not None:
                tx_files = tx_files[:limit_per_class]
            
            # 使用新的连续标签
            new_label = new_label_mapping[tx_id]
            
            for filename in tx_files:
                file_paths.append(os.path.join(plot_folder, filename))
                labels.append(new_label)

            class_count[new_label] += len(tx_files)

    print(f"Loaded {len(class_count)} classes for {plot_type}")
    for lbl, count in class_count.items():
        print(f"Class {lbl}: {count} images")

    return file_paths, labels, new_label_mapping

# 模型构建
def build_resnet_model(num_classes, freeze_layers=True):
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    if freeze_layers:
        for param in model.parameters():
            param.requires_grad = False
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# 训练和评估
def train_and_evaluate_model(model, train_loader, val_loader, test_loader, device, save_dir, epochs=20, lr=0.001, patience=5):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    model = model.to(device)
    best_val_acc = 0
    best_epoch = 0
    train_losses, val_losses = [], []
    early_stop_counter = 0
    best_model = None

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        correct, total = 0, 0

        # 训练循环
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Training"):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_acc = correct / total
        train_losses.append(train_loss / len(train_loader))

        # 验证循环
        model.eval()
        val_loss = 0
        correct, total = 0, 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = correct / total
        val_losses.append(val_loss / len(val_loader))

        print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_losses[-1]:.4f}, Val Acc: {val_acc:.4f}")

        # 早停机制
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = model.state_dict()
            best_epoch = epoch
            early_stop_counter = 0
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print("Early stopping triggered!")
                break

    print(f"Best Validation Accuracy: {best_val_acc:.4f} at epoch {best_epoch+1}")
    model.load_state_dict(best_model)

    # 测试循环
    model.eval()
    correct, total = 0, 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    test_acc = correct / total
    print(f"Test Accuracy: {test_acc:.4f}")

    # 生成混淆矩阵
    cm = confusion_matrix(all_labels, all_preds)
    cm_fig = plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap="Blues", 
                xticklabels=np.unique(all_labels), 
                yticklabels=np.unique(all_labels))
    plt.xlabel('Predicted')
    plt.ylabel('True')
    cm_fig.savefig(os.path.join(save_dir, "confusion_matrix.png"))
    plt.close(cm_fig)
    
    return model, train_losses, val_losses, test_acc, cm

# 比较训练效果
def compare_scatter_trajectory_training(folder, 
                                      num_classes_to_select=None,
                                      limit_per_class=1000,
                                      use_all_images=False,
                                      batch_size=32, 
                                      epochs=20, 
                                      lr=0.001):
    # 设置随机种子保证可重复性
    random.seed(42)
    
    # 获取所有TX_ID并创建原始标签映射
    all_tx_ids = sorted([tx_id for tx_id in os.listdir(folder) 
                       if os.path.isdir(os.path.join(folder, tx_id))])
    
    # 加载数据集（两个plot_type共享相同的类别选择）
    print("\nLoading scatter plot images...")
    scatter_paths, scatter_labels, label_mapping = load_image_paths_with_mapping(
        folder, 
        'scatter_plots', 
        tx_label_mapping={tx: idx for idx, tx in enumerate(all_tx_ids)},
        num_classes_to_select=num_classes_to_select,
        limit_per_class=limit_per_class,
        use_all_images=use_all_images
    )
    
    print("\nLoading trajectory plot images...")
    trajectory_paths, trajectory_labels, _ = load_image_paths_with_mapping(
        folder, 
        'trajectory_plots', 
        tx_label_mapping={tx: idx for idx, tx in enumerate(all_tx_ids)},
        num_classes_to_select=num_classes_to_select,
        limit_per_class=limit_per_class,
        use_all_images=use_all_images
    )

    # 数据预处理
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # 数据集划分和加载
    def prepare_datasets(paths, labels, transform):
        # 划分训练+验证和测试集
        train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
            paths, labels, test_size=0.2, random_state=42, stratify=labels
        )
        # 再划分训练和验证集
        train_paths, val_paths, train_labels, val_labels = train_test_split(
            train_val_paths, train_val_labels, test_size=0.25, random_state=42, stratify=train_val_labels
        )
        # 创建数据集
        train_dataset = ImageDataset(train_paths, train_labels, transform)
        val_dataset = ImageDataset(val_paths, val_labels, test_transform)
        test_dataset = ImageDataset(test_paths, test_labels, test_transform)
        return train_dataset, val_dataset, test_dataset

    # 准备数据集
    print("\nPreparing scatter plot datasets...")
    scatter_train, scatter_val, scatter_test = prepare_datasets(scatter_paths, scatter_labels, train_transform)
    print("Preparing trajectory plot datasets...")
    trajectory_train, trajectory_val, trajectory_test = prepare_datasets(trajectory_paths, trajectory_labels, train_transform)

    # 创建数据加载器
    def create_loaders(train, val, test):
        train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val, batch_size=batch_size, shuffle=False)
        test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)
        return train_loader, val_loader, test_loader

    scatter_train_loader, scatter_val_loader, scatter_test_loader = create_loaders(scatter_train, scatter_val, scatter_test)
    trajectory_train_loader, trajectory_val_loader, trajectory_test_loader = create_loaders(trajectory_train, trajectory_val, trajectory_test)

    # 创建保存结果的文件夹
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    folder_name = f"{timestamp}-{len(label_mapping)}_classes"
    if num_classes_to_select: folder_name += f"-selected_{num_classes_to_select}"
    os.makedirs(folder_name, exist_ok=True)
    
    scatter_save_dir = os.path.join(folder_name, 'scatter')
    os.makedirs(scatter_save_dir, exist_ok=True)
    traj_save_dir = os.path.join(folder_name, 'trajectory')
    os.makedirs(traj_save_dir, exist_ok=True)

    # 训练散点图模型
    print("\nTraining scatter plot model...")
    scatter_model = build_resnet_model(num_classes=len(label_mapping), freeze_layers=False)
    scatter_model, _, _, scatter_test_acc, _ = train_and_evaluate_model(
        scatter_model, scatter_train_loader, scatter_val_loader, scatter_test_loader, 
        device, scatter_save_dir, epochs, lr
    )
    torch.save(scatter_model.state_dict(), os.path.join(scatter_save_dir, 'model.pth'))

    # 训练轨迹图模型
    print("\nTraining trajectory plot model...")
    trajectory_model = build_resnet_model(num_classes=len(label_mapping), freeze_layers=False)
    trajectory_model, _, _, traj_test_acc, _ = train_and_evaluate_model(
        trajectory_model, trajectory_train_loader, trajectory_val_loader, trajectory_test_loader, 
        device, traj_save_dir, epochs, lr
    )
    torch.save(trajectory_model.state_dict(), os.path.join(traj_save_dir, 'model.pth'))

    # 保存训练日志
    with open(os.path.join(folder_name, 'results.txt'), 'w') as f:
        f.write(f"Scatter Model Test Accuracy: {scatter_test_acc:.4f}\n")
        f.write(f"Trajectory Model Test Accuracy: {traj_test_acc:.4f}\n")
        f.write(f"Parameters:\n")
        f.write(f"num_classes_to_select: {num_classes_to_select}\n")
        f.write(f"limit_per_class: {limit_per_class}\n")
        f.write(f"use_all_images: {use_all_images}\n")

    print("\nTraining completed. Results saved in", folder_name)

# 使用示例
if __name__ == "__main__":
    folder = "../../IQ_signal_plots"

    # 基本使用（选择10个类别，每个类100张图）
compare_scatter_trajectory_training(
    folder,
    num_classes_to_select=10,
    limit_per_class=100,
    use_all_images=False
)

# 使用全部类别和全部图像
compare_scatter_trajectory_training(
    folder,
    num_classes_to_select=None,
    use_all_images=True
)

# 使用全部类别但限制数量
compare_scatter_trajectory_training(
    folder,
    num_classes_to_select=None,
    limit_per_class=500
)

Using device: cuda:0

Loading scatter plot images...
Randomly selected 10 classes: ['13-14', '1-16', '19-10', '18-2', '18-13', '14-11', '12-20', '8-13', '11-4', '20-4']
Loaded 10 classes for scatter_plots
Class 0: 100 images
Class 1: 100 images
Class 2: 100 images
Class 3: 100 images
Class 4: 100 images
Class 5: 100 images
Class 6: 100 images
Class 7: 100 images
Class 8: 100 images
Class 9: 100 images

Loading trajectory plot images...
Randomly selected 10 classes: ['1-1', '1-11', '2-15', '8-8', '18-14', '14-11', '2-17', '6-1', '13-7', '8-18']
Loaded 10 classes for trajectory_plots
Class 0: 100 images
Class 1: 100 images
Class 2: 100 images
Class 3: 100 images
Class 4: 100 images
Class 5: 100 images
Class 6: 100 images
Class 7: 100 images
Class 8: 100 images
Class 9: 100 images

Preparing scatter plot datasets...
Preparing trajectory plot datasets...

Training scatter plot model...


Epoch 1/20 - Training: 100%|██████████| 19/19 [00:03<00:00,  5.50it/s]


Epoch [1/20], Train Loss: 2.2661, Train Acc: 0.1600, Val Loss: 4.4703, Val Acc: 0.1200


Epoch 2/20 - Training: 100%|██████████| 19/19 [00:03<00:00,  5.21it/s]


Epoch [2/20], Train Loss: 1.9596, Train Acc: 0.2267, Val Loss: 3.1605, Val Acc: 0.1150


Epoch 3/20 - Training: 100%|██████████| 19/19 [00:03<00:00,  4.88it/s]


Epoch [3/20], Train Loss: 1.8397, Train Acc: 0.3100, Val Loss: 2.4457, Val Acc: 0.1550


Epoch 4/20 - Training: 100%|██████████| 19/19 [00:03<00:00,  4.85it/s]


Epoch [4/20], Train Loss: 1.7768, Train Acc: 0.3367, Val Loss: 4.4462, Val Acc: 0.1300


Epoch 5/20 - Training: 100%|██████████| 19/19 [00:03<00:00,  4.91it/s]


Epoch [5/20], Train Loss: 1.5764, Train Acc: 0.4133, Val Loss: 2.4376, Val Acc: 0.2300


Epoch 6/20 - Training: 100%|██████████| 19/19 [00:03<00:00,  4.90it/s]


Epoch [6/20], Train Loss: 1.5230, Train Acc: 0.4600, Val Loss: 7.5720, Val Acc: 0.1000


Epoch 7/20 - Training: 100%|██████████| 19/19 [00:03<00:00,  5.06it/s]


Epoch [7/20], Train Loss: 1.3296, Train Acc: 0.5150, Val Loss: 2.7661, Val Acc: 0.1350


Epoch 8/20 - Training: 100%|██████████| 19/19 [00:03<00:00,  4.92it/s]


Epoch [8/20], Train Loss: 1.1523, Train Acc: 0.5833, Val Loss: 3.3871, Val Acc: 0.2100


Epoch 9/20 - Training:  74%|███████▎  | 14/19 [00:03<00:01,  4.52it/s]


KeyboardInterrupt: 

In [None]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image
import numpy as np
import time
import shutil
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from datetime import datetime
from torchvision.models import resnet18, ResNet18_Weights
from collections import defaultdict
import random
from torch.cuda.amp import autocast, GradScaler

# 指定 GPU 索引
GPU_INDEX = 0

# 检查 GPU 是否可用并设置设备
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. Please enable a GPU.")

device = torch.device(f"cuda:{GPU_INDEX}")
torch.cuda.set_device(GPU_INDEX)
torch.backends.cudnn.benchmark = True  # 新增优化点
print(f"Using device: {device}")

# 数据加载类（保持不变）
class ImageDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

# 加载图像路径函数（保持不变）
def load_image_paths_with_mapping(folder, plot_type, tx_label_mapping, 
                                 num_classes_to_select=None, 
                                 limit_per_class=1000, 
                                 use_all_images=False):
    file_paths = []
    labels = []
    class_count = defaultdict(int)

    all_tx_ids = list(tx_label_mapping.keys())
    if num_classes_to_select is not None:
        selected_tx_ids = random.sample(all_tx_ids, num_classes_to_select)
        print(f"Randomly selected {num_classes_to_select} classes: {selected_tx_ids}")
    else:
        selected_tx_ids = all_tx_ids

    new_label_mapping = {tx_id: idx for idx, tx_id in enumerate(selected_tx_ids)}

    for tx_id in selected_tx_ids:
        original_label = tx_label_mapping[tx_id]
        tx_folder = os.path.join(folder, tx_id)
        plot_folder = os.path.join(tx_folder, plot_type)
        
        if os.path.exists(plot_folder):
            tx_files = [f for f in os.listdir(plot_folder) if f.endswith('.png')]
            random.shuffle(tx_files)
            
            if not use_all_images and limit_per_class is not None:
                tx_files = tx_files[:limit_per_class]
            
            new_label = new_label_mapping[tx_id]
            
            for filename in tx_files:
                file_paths.append(os.path.join(plot_folder, filename))
                labels.append(new_label)

            class_count[new_label] += len(tx_files)

    print(f"Loaded {len(class_count)} classes for {plot_type}")
    for lbl, count in class_count.items():
        print(f"Class {lbl}: {count} images")

    return file_paths, labels, new_label_mapping

# 模型构建（保持不变）
def build_resnet_model(num_classes, freeze_layers=True):
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    if freeze_layers:
        for param in model.parameters():
            param.requires_grad = False
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# 训练和评估（包含优化修改）
def train_and_evaluate_model(model, train_loader, val_loader, test_loader, device, save_dir, epochs=20, lr=0.001, patience=5):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scaler = torch.amp.GradScaler()
    
    model = model.to(device)
    best_val_acc = 0
    best_epoch = 0
    train_losses, val_losses = [], []
    early_stop_counter = 0
    best_model = None

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        correct, total = 0, 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Training"):
            # 优化数据传输
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            
            optimizer.zero_grad(set_to_none=True)  # 优化梯度清零
            
            # 混合精度训练
            with torch.amp.autocast(device_type='cuda'):
                outputs = model(images)
                loss = criterion(outputs, labels)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_acc = correct / total
        train_losses.append(train_loss / len(train_loader))

        # 验证循环
        model.eval()
        val_loss = 0
        correct, total = 0, 0

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = correct / total
        val_losses.append(val_loss / len(val_loader))

        print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_losses[-1]:.4f}, Val Acc: {val_acc:.4f}")

        # 早停机制
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = model.state_dict()
            best_epoch = epoch
            early_stop_counter = 0
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print("Early stopping triggered!")
                break

    print(f"Best Validation Accuracy: {best_val_acc:.4f} at epoch {best_epoch+1}")
    model.load_state_dict(best_model)

    # 测试循环
    model.eval()
    correct, total = 0, 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    test_acc = correct / total
    print(f"Test Accuracy: {test_acc:.4f}")

    # 生成混淆矩阵
    cm = confusion_matrix(all_labels, all_preds)
    cm_fig = plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap="Blues", 
                xticklabels=np.unique(all_labels), 
                yticklabels=np.unique(all_labels))
    plt.xlabel('Predicted')
    plt.ylabel('True')
    cm_fig.savefig(os.path.join(save_dir, "confusion_matrix.png"))
    plt.close(cm_fig)
    
    return model, train_losses, val_losses, test_acc, cm

# 比较训练效果（包含参数调整）
def compare_scatter_trajectory_training(folder, 
                                      num_classes_to_select=None,
                                      limit_per_class=1000,
                                      use_all_images=False,
                                      batch_size=256,  # 修改后的batch_size
                                      epochs=20, 
                                      lr=0.001):
    random.seed(42)
    
    all_tx_ids = sorted([tx_id for tx_id in os.listdir(folder) 
                       if os.path.isdir(os.path.join(folder, tx_id))])
    
    print("\nLoading scatter plot images...")
    scatter_paths, scatter_labels, label_mapping = load_image_paths_with_mapping(
        folder, 
        'scatter_plots', 
        tx_label_mapping={tx: idx for idx, tx in enumerate(all_tx_ids)},
        num_classes_to_select=num_classes_to_select,
        limit_per_class=limit_per_class,
        use_all_images=use_all_images
    )
    
    print("\nLoading trajectory plot images...")
    trajectory_paths, trajectory_labels, _ = load_image_paths_with_mapping(
        folder, 
        'trajectory_plots', 
        tx_label_mapping={tx: idx for idx, tx in enumerate(all_tx_ids)},
        num_classes_to_select=num_classes_to_select,
        limit_per_class=limit_per_class,
        use_all_images=use_all_images
    )

    # 数据预处理（保持不变）
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # 数据集划分和加载（保持不变）
    def prepare_datasets(paths, labels, transform):
        train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
            paths, labels, test_size=0.2, random_state=42, stratify=labels
        )
        train_paths, val_paths, train_labels, val_labels = train_test_split(
            train_val_paths, train_val_labels, test_size=0.25, random_state=42, stratify=train_val_labels
        )
        return (
            ImageDataset(train_paths, train_labels, transform),
            ImageDataset(val_paths, val_labels, test_transform),
            ImageDataset(test_paths, test_labels, test_transform)
        )

    print("\nPreparing scatter plot datasets...")
    scatter_train, scatter_val, scatter_test = prepare_datasets(scatter_paths, scatter_labels, train_transform)
    print("Preparing trajectory plot datasets...")
    trajectory_train, trajectory_val, trajectory_test = prepare_datasets(trajectory_paths, trajectory_labels, train_transform)

    # 创建数据加载器（添加pin_memory）
    def create_loaders(train, val, test):
        train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, pin_memory=True)
        val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, pin_memory=True)
        test_loader = DataLoader(test, batch_size=batch_size, shuffle=False, pin_memory=True)
        return train_loader, val_loader, test_loader

    scatter_train_loader, scatter_val_loader, scatter_test_loader = create_loaders(scatter_train, scatter_val, scatter_test)
    trajectory_train_loader, trajectory_val_loader, trajectory_test_loader = create_loaders(trajectory_train, trajectory_val, trajectory_test)

    # 创建保存目录（保持不变）
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    folder_name = f"{timestamp}-{len(label_mapping)}_classes"
    if num_classes_to_select: folder_name += f"-selected_{num_classes_to_select}"
    os.makedirs(folder_name, exist_ok=True)
    
    scatter_save_dir = os.path.join(folder_name, 'scatter')
    os.makedirs(scatter_save_dir, exist_ok=True)
    traj_save_dir = os.path.join(folder_name, 'trajectory')
    os.makedirs(traj_save_dir, exist_ok=True)

    # 训练模型（保持不变）
    print("\nTraining scatter plot model...")
    scatter_model = build_resnet_model(num_classes=len(label_mapping), freeze_layers=False)
    scatter_model, _, _, scatter_test_acc, _ = train_and_evaluate_model(
        scatter_model, scatter_train_loader, scatter_val_loader, scatter_test_loader, 
        device, scatter_save_dir, epochs, lr
    )
    torch.save(scatter_model.state_dict(), os.path.join(scatter_save_dir, 'model.pth'))

    print("\nTraining trajectory plot model...")
    trajectory_model = build_resnet_model(num_classes=len(label_mapping), freeze_layers=False)
    trajectory_model, _, _, traj_test_acc, _ = train_and_evaluate_model(
        trajectory_model, trajectory_train_loader, trajectory_val_loader, trajectory_test_loader, 
        device, traj_save_dir, epochs, lr
    )
    torch.save(trajectory_model.state_dict(), os.path.join(traj_save_dir, 'model.pth'))

    # 保存结果（保持不变）
    with open(os.path.join(folder_name, 'results.txt'), 'w') as f:
        f.write(f"Scatter Model Test Accuracy: {scatter_test_acc:.4f}\n")
        f.write(f"Trajectory Model Test Accuracy: {traj_test_acc:.4f}\n")
        f.write(f"Parameters:\n")
        f.write(f"num_classes_to_select: {num_classes_to_select}\n")
        f.write(f"limit_per_class: {limit_per_class}\n")
        f.write(f"use_all_images: {use_all_images}\n")

    print("\nTraining completed. Results saved in", folder_name)

# 使用示例（包含参数调整）
if __name__ == "__main__":
    folder = "../../IQ_signal_plots"

    # 基本使用（调整batch_size）
    compare_scatter_trajectory_training(
        folder,
        num_classes_to_select=150,
        limit_per_class=1000,
        use_all_images=False,
        batch_size=256
    )
    
    # 全量数据示例
    compare_scatter_trajectory_training(
        folder,
        num_classes_to_select=None,
        use_all_images=True,
        batch_size=256
    )

    # 限制数量示例
    compare_scatter_trajectory_training(
        folder,
        num_classes_to_select=None,
        limit_per_class=500,
        batch_size=256
    )

Using device: cuda:0

Loading scatter plot images...
Randomly selected 150 classes: ['13-14', '1-16', '19-10', '18-2', '18-13', '14-11', '12-20', '8-13', '11-4', '20-4', '1-19', '1-18', '11-7', '18-11', '18-15', '7-10', '9-20', '16-5', '20-3', '18-12', '3-18', '19-11', '1-10', '2-7', '20-16', '11-19', '2-16', '18-10', '15-1', '8-7', '11-17', '12-7', '6-1', '4-11', '10-11', '10-1', '16-19', '10-10', '15-6', '8-14', '15-19', '19-20', '13-7', '5-5', '1-15', '2-3', '18-14', '18-9', '10-4', '3-13', '1-8', '9-14', '14-13', '19-6', '19-4', '16-1', '19-13', '12-1', '2-17', '8-20', '20-19', '2-1', '13-18', '3-20', '20-1', '2-14', '3-1', '20-12', '4-10', '20-20', '2-4', '5-20', '16-16', '20-8', '8-3', '14-10', '1-2', '11-20', '20-15', '13-20', '19-3', '7-7', '19-8', '19-14', '9-7', '14-8', '8-1', '10-17', '17-11', '1-11', '17-10', '19-1', '12-19', '11-1', '1-14', '3-19', '14-12', '19-2', '2-6', '5-1', '18-4', '2-8', '8-8', '16-20', '2-15', '20-18', '10-7', '2-19', '20-14', '19-7', '18-7', '3-2',

Epoch 1/20 - Training:  61%|██████    | 214/352 [04:33<02:56,  1.28s/it]


KeyboardInterrupt: 