In [1]:
import os
import torch
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image
from torchvision.models import resnet18, ResNet18_Weights

In [None]:
# 指定 GPU 索引（假设选择 GPU 0）
GPU_INDEX = 0

# 检查 GPU 是否可用并设置设备
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. Please enable a GPU.")

device = torch.device(f"cuda:{GPU_INDEX}")
torch.cuda.set_device(GPU_INDEX)
print(f"Using device: {device}")


Using device: cuda:0


In [3]:

# 1. 数据加载
class ImageDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

def load_image_paths(folder, plot_type='scatter_plots'):
    file_paths = []
    labels = []
    label = 0  # 为每个发射机分配标签

    for tx_id in os.listdir(folder):
        tx_folder = os.path.join(folder, tx_id)
        if os.path.isdir(tx_folder):
            plot_folder = os.path.join(tx_folder, plot_type)
            if os.path.exists(plot_folder):
                for filename in os.listdir(plot_folder):
                    if filename.endswith('.png'):
                        file_paths.append(os.path.join(plot_folder, filename))
                        labels.append(label)
            label += 1

    return file_paths, labels


In [4]:

# 2. 模型构建
def build_resnet_model(num_classes):
    model = resnet18(weights=ResNet18_Weights.DEFAULT)
    # 替换最后一层
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    return model


In [5]:
# 3. 训练和评估
def train_and_evaluate_model(model, train_loader, val_loader, test_loader, device, epochs=20, lr=0.001):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model = model.to(device)
    best_val_acc = 0
    train_losses, val_losses = [], []

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        correct, total = 0, 0

        # Training loop
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Training"):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_acc = correct / total
        train_losses.append(train_loss / len(train_loader))

        # Validation loop
        model.eval()
        val_loss = 0
        correct, total = 0, 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = correct / total
        val_losses.append(val_loss / len(val_loader))

        print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_losses[-1]:.4f}, Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = model.state_dict()

    print(f"Best Validation Accuracy: {best_val_acc:.4f}")
    model.load_state_dict(best_model)

    # Test loop
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    test_acc = correct / total
    print(f"Test Accuracy: {test_acc:.4f}")
    return model, train_losses, val_losses, test_acc

In [6]:
# 4. 比较训练效果
def compare_scatter_trajectory_training(folder, batch_size=32, epochs=20, lr=0.001):
    print("Loading scatter plot images...")
    scatter_paths, scatter_labels = load_image_paths(folder, plot_type='scatter_plots')
    print("Loading trajectory plot images...")
    trajectory_paths, trajectory_labels = load_image_paths(folder, plot_type='trajectory_plots')

    # 数据集分割
    scatter_train_paths, scatter_test_paths, scatter_train_labels, scatter_test_labels = train_test_split(
        scatter_paths, scatter_labels, test_size=0.2, random_state=42
    )
    trajectory_train_paths, trajectory_test_paths, trajectory_train_labels, trajectory_test_labels = train_test_split(
        trajectory_paths, trajectory_labels, test_size=0.2, random_state=42
    )

    # 数据预处理
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # 构造数据集和数据加载器
    scatter_train_dataset = ImageDataset(scatter_train_paths, scatter_train_labels, transform)
    scatter_test_dataset = ImageDataset(scatter_test_paths, scatter_test_labels, transform)
    trajectory_train_dataset = ImageDataset(trajectory_train_paths, trajectory_train_labels, transform)
    trajectory_test_dataset = ImageDataset(trajectory_test_paths, trajectory_test_labels, transform)

    scatter_train_loader = DataLoader(scatter_train_dataset, batch_size=batch_size, shuffle=True)
    scatter_test_loader = DataLoader(scatter_test_dataset, batch_size=batch_size, shuffle=False)
    trajectory_train_loader = DataLoader(trajectory_train_dataset, batch_size=batch_size, shuffle=True)
    trajectory_test_loader = DataLoader(trajectory_test_dataset, batch_size=batch_size, shuffle=False)

    # 获取类别数量
    num_classes = len(set(scatter_labels))

    # 构建和训练散点图模型
    print("\nTraining scatter plot model...")
    scatter_model = build_resnet_model(num_classes)
    scatter_model, scatter_train_losses, scatter_val_losses, scatter_test_acc = train_and_evaluate_model(
        scatter_model, scatter_train_loader, scatter_test_loader, scatter_test_loader, device, epochs, lr
    )

    # 构建和训练轨迹图模型
    print("\nTraining trajectory plot model...")
    trajectory_model = build_resnet_model(num_classes)
    trajectory_model, trajectory_train_losses, trajectory_val_losses, trajectory_test_acc = train_and_evaluate_model(
        trajectory_model, trajectory_train_loader, trajectory_test_loader, trajectory_test_loader, device, epochs, lr
    )

    # 输出比较结果
    print("\nComparison of Test Accuracies:")
    print(f"Scatter plot model test accuracy: {scatter_test_acc:.4f}")
    print(f"Trajectory plot model test accuracy: {trajectory_test_acc:.4f}")

In [7]:

# 使用示例：比较散点图和轨迹图的训练效果
folder = "../../IQ_signal_plots"  # 图像存储的文件夹
compare_scatter_trajectory_training(folder, batch_size=32, epochs=20, lr=0.001)


Loading scatter plot images...
Loading trajectory plot images...

Training scatter plot model...


Epoch 1/20 - Training:   0%|          | 36/12729 [00:51<5:03:40,  1.44s/it]


KeyboardInterrupt: 