In [1]:
import os
import torch
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image
from torchvision.models import resnet18, ResNet18_Weights

In [2]:

# 指定 GPU 索引（假设选择 GPU 0）
GPU_INDEX = 0

# 检查 GPU 是否可用并设置设备
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. Please enable a GPU.")

device = torch.device(f"cuda:{GPU_INDEX}")
torch.cuda.set_device(GPU_INDEX)
print(f"Using device: {device}")


Using device: cuda:0


In [3]:

# 1. 数据加载
class ImageDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

def load_image_paths(folder, plot_type='scatter_plots'):
    file_paths = []
    labels = []
    label = 0  # 为每个发射机分配标签

    for tx_id in os.listdir(folder):
        tx_folder = os.path.join(folder, tx_id)
        if os.path.isdir(tx_folder):
            plot_folder = os.path.join(tx_folder, plot_type)
            if os.path.exists(plot_folder):
                for filename in os.listdir(plot_folder):
                    if filename.endswith('.png'):
                        file_paths.append(os.path.join(plot_folder, filename))
                        labels.append(label)
            label += 1

    return file_paths, labels


In [4]:

# 2. 模型构建
def build_resnet_model(num_classes):
    model = resnet18(weights=ResNet18_Weights.DEFAULT)
    # 替换最后一层
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    return model


In [5]:
# 3. 训练和评估
def train_and_evaluate_model(model, train_loader, val_loader, test_loader, device, epochs=20, lr=0.001):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model = model.to(device)
    best_val_acc = 0
    train_losses, val_losses = [], []

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        correct, total = 0, 0

        # Training loop
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Training"):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_acc = correct / total
        train_losses.append(train_loss / len(train_loader))

        # Validation loop
        model.eval()
        val_loss = 0
        correct, total = 0, 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = correct / total
        val_losses.append(val_loss / len(val_loader))

        print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_losses[-1]:.4f}, Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = model.state_dict()

    print(f"Best Validation Accuracy: {best_val_acc:.4f}")
    model.load_state_dict(best_model)

    # Test loop
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    test_acc = correct / total
    print(f"Test Accuracy: {test_acc:.4f}")
    return model, train_losses, val_losses, test_acc

In [6]:
# 4. 比较训练效果
def compare_scatter_trajectory_training(folder, batch_size=32, epochs=20, lr=0.001):
    print("Loading scatter plot images...")
    scatter_paths, scatter_labels = load_image_paths(folder, plot_type='scatter_plots')
    print("Loading trajectory plot images...")
    trajectory_paths, trajectory_labels = load_image_paths(folder, plot_type='trajectory_plots')

    # 数据集分割
    scatter_train_paths, scatter_test_paths, scatter_train_labels, scatter_test_labels = train_test_split(
        scatter_paths, scatter_labels, test_size=0.2, random_state=42
    )
    trajectory_train_paths, trajectory_test_paths, trajectory_train_labels, trajectory_test_labels = train_test_split(
        trajectory_paths, trajectory_labels, test_size=0.2, random_state=42
    )

    # 数据预处理
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # 构造数据集和数据加载器
    scatter_train_dataset = ImageDataset(scatter_train_paths, scatter_train_labels, transform)
    scatter_test_dataset = ImageDataset(scatter_test_paths, scatter_test_labels, transform)
    trajectory_train_dataset = ImageDataset(trajectory_train_paths, trajectory_train_labels, transform)
    trajectory_test_dataset = ImageDataset(trajectory_test_paths, trajectory_test_labels, transform)

    scatter_train_loader = DataLoader(scatter_train_dataset, batch_size=batch_size, shuffle=True)
    scatter_test_loader = DataLoader(scatter_test_dataset, batch_size=batch_size, shuffle=False)
    trajectory_train_loader = DataLoader(trajectory_train_dataset, batch_size=batch_size, shuffle=True)
    trajectory_test_loader = DataLoader(trajectory_test_dataset, batch_size=batch_size, shuffle=False)

    # 获取类别数量
    num_classes = len(set(scatter_labels))

    # 构建和训练散点图模型
    print("\nTraining scatter plot model...")
    scatter_model = build_resnet_model(num_classes)
    scatter_model, scatter_train_losses, scatter_val_losses, scatter_test_acc = train_and_evaluate_model(
        scatter_model, scatter_train_loader, scatter_test_loader, scatter_test_loader, device, epochs, lr
    )

    # 构建和训练轨迹图模型
    print("\nTraining trajectory plot model...")
    trajectory_model = build_resnet_model(num_classes)
    trajectory_model, trajectory_train_losses, trajectory_val_losses, trajectory_test_acc = train_and_evaluate_model(
        trajectory_model, trajectory_train_loader, trajectory_test_loader, trajectory_test_loader, device, epochs, lr
    )

    # 输出比较结果
    print("\nComparison of Test Accuracies:")
    print(f"Scatter plot model test accuracy: {scatter_test_acc:.4f}")
    print(f"Trajectory plot model test accuracy: {trajectory_test_acc:.4f}")

In [7]:

# 使用示例：比较散点图和轨迹图的训练效果
folder = "../../IQ_signal_plots"  # 图像存储的文件夹
compare_scatter_trajectory_training(folder, batch_size=32, epochs=20, lr=0.001)


Loading scatter plot images...
Loading trajectory plot images...

Training scatter plot model...


Epoch 1/20 - Training: 100%|██████████| 12729/12729 [41:00<00:00,  5.17it/s]


Epoch [1/20], Train Loss: 4.1533, Train Acc: 0.0921, Val Loss: 3.6039, Val Acc: 0.1985


Epoch 2/20 - Training: 100%|██████████| 12729/12729 [40:38<00:00,  5.22it/s]


Epoch [2/20], Train Loss: 2.9626, Train Acc: 0.3401, Val Loss: 2.6540, Val Acc: 0.4116


Epoch 3/20 - Training: 100%|██████████| 12729/12729 [41:29<00:00,  5.11it/s]


Epoch [3/20], Train Loss: 2.3691, Train Acc: 0.4723, Val Loss: 2.4333, Val Acc: 0.4674


Epoch 4/20 - Training: 100%|██████████| 12729/12729 [41:32<00:00,  5.11it/s]


Epoch [4/20], Train Loss: 2.0904, Train Acc: 0.5323, Val Loss: 2.3210, Val Acc: 0.4973


Epoch 5/20 - Training: 100%|██████████| 12729/12729 [41:32<00:00,  5.11it/s]


Epoch [5/20], Train Loss: 1.9030, Train Acc: 0.5713, Val Loss: 2.2943, Val Acc: 0.5076


Epoch 6/20 - Training: 100%|██████████| 12729/12729 [41:41<00:00,  5.09it/s]


Epoch [6/20], Train Loss: 1.7548, Train Acc: 0.6020, Val Loss: 2.2691, Val Acc: 0.5156


Epoch 7/20 - Training: 100%|██████████| 12729/12729 [39:35<00:00,  5.36it/s]


Epoch [7/20], Train Loss: 1.6200, Train Acc: 0.6287, Val Loss: 2.4471, Val Acc: 0.5059


Epoch 8/20 - Training: 100%|██████████| 12729/12729 [39:20<00:00,  5.39it/s]


Epoch [8/20], Train Loss: 1.4836, Train Acc: 0.6551, Val Loss: 2.4820, Val Acc: 0.5139


Epoch 9/20 - Training: 100%|██████████| 12729/12729 [39:25<00:00,  5.38it/s]


Epoch [9/20], Train Loss: 1.3384, Train Acc: 0.6833, Val Loss: 2.5220, Val Acc: 0.5161


Epoch 10/20 - Training: 100%|██████████| 12729/12729 [39:26<00:00,  5.38it/s]


Epoch [10/20], Train Loss: 1.1748, Train Acc: 0.7153, Val Loss: 2.6974, Val Acc: 0.5157


Epoch 11/20 - Training: 100%|██████████| 12729/12729 [39:37<00:00,  5.35it/s]


Epoch [11/20], Train Loss: 1.0009, Train Acc: 0.7498, Val Loss: 2.8867, Val Acc: 0.5115


Epoch 12/20 - Training: 100%|██████████| 12729/12729 [39:36<00:00,  5.36it/s]


Epoch [12/20], Train Loss: 0.8243, Train Acc: 0.7874, Val Loss: 3.0417, Val Acc: 0.5114


Epoch 13/20 - Training: 100%|██████████| 12729/12729 [42:38<00:00,  4.98it/s]


Epoch [13/20], Train Loss: 0.6648, Train Acc: 0.8227, Val Loss: 3.3463, Val Acc: 0.5096


Epoch 14/20 - Training: 100%|██████████| 12729/12729 [48:05<00:00,  4.41it/s]


Epoch [14/20], Train Loss: 0.5327, Train Acc: 0.8537, Val Loss: 3.5667, Val Acc: 0.5060


Epoch 15/20 - Training: 100%|██████████| 12729/12729 [47:20<00:00,  4.48it/s]


Epoch [15/20], Train Loss: 0.4305, Train Acc: 0.8787, Val Loss: 3.8214, Val Acc: 0.5045


Epoch 16/20 - Training: 100%|██████████| 12729/12729 [46:51<00:00,  4.53it/s]


Epoch [16/20], Train Loss: 0.3548, Train Acc: 0.8980, Val Loss: 4.0173, Val Acc: 0.5062


Epoch 17/20 - Training: 100%|██████████| 12729/12729 [48:23<00:00,  4.38it/s]


Epoch [17/20], Train Loss: 0.3061, Train Acc: 0.9102, Val Loss: 4.1218, Val Acc: 0.5024


Epoch 18/20 - Training: 100%|██████████| 12729/12729 [48:33<00:00,  4.37it/s]


Epoch [18/20], Train Loss: 0.2674, Train Acc: 0.9208, Val Loss: 4.4063, Val Acc: 0.5029


Epoch 19/20 - Training: 100%|██████████| 12729/12729 [48:28<00:00,  4.38it/s]


Epoch [19/20], Train Loss: 0.2402, Train Acc: 0.9284, Val Loss: 4.4097, Val Acc: 0.5067


Epoch 20/20 - Training: 100%|██████████| 12729/12729 [47:35<00:00,  4.46it/s]


Epoch [20/20], Train Loss: 0.2187, Train Acc: 0.9346, Val Loss: 4.5219, Val Acc: 0.5021
Best Validation Accuracy: 0.5161
Test Accuracy: 0.5021

Training trajectory plot model...


Epoch 1/20 - Training: 100%|██████████| 12729/12729 [55:56<00:00,  3.79it/s]


Epoch [1/20], Train Loss: 4.2741, Train Acc: 0.0712, Val Loss: 3.7382, Val Acc: 0.1623


Epoch 2/20 - Training: 100%|██████████| 12729/12729 [53:06<00:00,  4.00it/s]


Epoch [2/20], Train Loss: 3.1171, Train Acc: 0.3033, Val Loss: 2.6942, Val Acc: 0.3988


Epoch 3/20 - Training: 100%|██████████| 12729/12729 [52:29<00:00,  4.04it/s]


Epoch [3/20], Train Loss: 2.4332, Train Acc: 0.4561, Val Loss: 2.4383, Val Acc: 0.4625


Epoch 4/20 - Training: 100%|██████████| 12729/12729 [1:08:44<00:00,  3.09it/s]


Epoch [4/20], Train Loss: 2.1347, Train Acc: 0.5208, Val Loss: 2.3048, Val Acc: 0.4919


Epoch 5/20 - Training: 100%|██████████| 12729/12729 [55:10<00:00,  3.84it/s]


Epoch [5/20], Train Loss: 1.9410, Train Acc: 0.5622, Val Loss: 2.2509, Val Acc: 0.5139


Epoch 6/20 - Training: 100%|██████████| 12729/12729 [53:20<00:00,  3.98it/s]


Epoch [6/20], Train Loss: 1.7951, Train Acc: 0.5918, Val Loss: 2.2151, Val Acc: 0.5270


Epoch 7/20 - Training: 100%|██████████| 12729/12729 [52:52<00:00,  4.01it/s]


Epoch [7/20], Train Loss: 1.6714, Train Acc: 0.6172, Val Loss: 2.1894, Val Acc: 0.5340


Epoch 8/20 - Training: 100%|██████████| 12729/12729 [50:35<00:00,  4.19it/s]


Epoch [8/20], Train Loss: 1.5504, Train Acc: 0.6408, Val Loss: 2.2690, Val Acc: 0.5344


Epoch 9/20 - Training: 100%|██████████| 12729/12729 [46:17<00:00,  4.58it/s]


Epoch [9/20], Train Loss: 1.4338, Train Acc: 0.6637, Val Loss: 2.2849, Val Acc: 0.5318


Epoch 10/20 - Training: 100%|██████████| 12729/12729 [46:17<00:00,  4.58it/s]


Epoch [10/20], Train Loss: 1.3080, Train Acc: 0.6885, Val Loss: 2.3562, Val Acc: 0.5386


Epoch 11/20 - Training: 100%|██████████| 12729/12729 [46:31<00:00,  4.56it/s]


Epoch [11/20], Train Loss: 1.1779, Train Acc: 0.7139, Val Loss: 2.5382, Val Acc: 0.5348


Epoch 12/20 - Training: 100%|██████████| 12729/12729 [46:30<00:00,  4.56it/s]


Epoch [12/20], Train Loss: 1.0415, Train Acc: 0.7428, Val Loss: 2.5659, Val Acc: 0.5374


Epoch 13/20 - Training: 100%|██████████| 12729/12729 [46:16<00:00,  4.58it/s]


Epoch [13/20], Train Loss: 0.9087, Train Acc: 0.7705, Val Loss: 2.7273, Val Acc: 0.5349


Epoch 14/20 - Training: 100%|██████████| 12729/12729 [46:16<00:00,  4.58it/s]


Epoch [14/20], Train Loss: 0.7771, Train Acc: 0.7992, Val Loss: 2.8701, Val Acc: 0.5327


Epoch 15/20 - Training: 100%|██████████| 12729/12729 [47:54<00:00,  4.43it/s]


Epoch [15/20], Train Loss: 0.6630, Train Acc: 0.8251, Val Loss: 3.0542, Val Acc: 0.5317


Epoch 16/20 - Training: 100%|██████████| 12729/12729 [53:45<00:00,  3.95it/s]


Epoch [16/20], Train Loss: 0.5614, Train Acc: 0.8487, Val Loss: 3.2146, Val Acc: 0.5277


Epoch 17/20 - Training: 100%|██████████| 12729/12729 [53:46<00:00,  3.94it/s]


Epoch [17/20], Train Loss: 0.4793, Train Acc: 0.8685, Val Loss: 3.5661, Val Acc: 0.5279


Epoch 18/20 - Training: 100%|██████████| 12729/12729 [53:38<00:00,  3.95it/s]


Epoch [18/20], Train Loss: 0.4109, Train Acc: 0.8853, Val Loss: 3.4984, Val Acc: 0.5312


Epoch 19/20 - Training: 100%|██████████| 12729/12729 [53:36<00:00,  3.96it/s]


Epoch [19/20], Train Loss: 0.3565, Train Acc: 0.8988, Val Loss: 3.7834, Val Acc: 0.5283


Epoch 20/20 - Training: 100%|██████████| 12729/12729 [53:34<00:00,  3.96it/s]


Epoch [20/20], Train Loss: 0.3170, Train Acc: 0.9088, Val Loss: 3.8658, Val Acc: 0.5289
Best Validation Accuracy: 0.5386
Test Accuracy: 0.5289

Comparison of Test Accuracies:
Scatter plot model test accuracy: 0.5021
Trajectory plot model test accuracy: 0.5289
