In [30]:
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
os.environ['PYTHONHASHSEED'] = '42'
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

# 第二个单元格
import random
import numpy as np
import torch

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True, warn_only=True)

In [31]:
from data_utilities import load_compact_pkl_dataset

dataset_path = "../../orbit_rf_dataset/data/compact_pkl_datasets/"
dataset_name = "ManyTx"

dataset = load_compact_pkl_dataset(dataset_path, dataset_name)
data = dataset['data']


In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split


In [33]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


In [34]:
import pywt
from scipy import interpolate

def interpolate_coeff(coeff, target_length=256):
    current_length = len(coeff)
    
    if current_length == target_length:
        return coeff
    elif current_length > target_length:
        return coeff[:target_length]
    else:
        x_original = np.linspace(0, 1, current_length)
        x_target = np.linspace(0, 1, target_length)
        
        interp_func = interpolate.interp1d(x_original, coeff, kind='linear', 
                                         bounds_error=False, fill_value="extrapolate")
        
        return interp_func(x_target)

def data_preprocessing(channel_data):
    i_signal = channel_data[:, 0]  
    q_signal = channel_data[:, 1]  
    
    wavelet = 'db5'
    max_level = pywt.dwt_max_level(len(i_signal), pywt.Wavelet(wavelet).dec_len)
    level = min(4, max_level) 
    
    coeffs_i = pywt.wavedec(i_signal, wavelet, level=level, mode='per')
    coeffs_q = pywt.wavedec(q_signal, wavelet, level=level, mode='per')
    
    i_d1_adj = interpolate_coeff(coeffs_i[1], 128)
    i_d2_adj = interpolate_coeff(coeffs_i[2], 128)  
    i_d3_adj = interpolate_coeff(coeffs_i[3], 128)
    i_d4_adj = interpolate_coeff(coeffs_i[4], 128)
    
    q_d1_adj = interpolate_coeff(coeffs_q[1], 128)
    q_d2_adj = interpolate_coeff(coeffs_q[2], 128)
    q_d3_adj = interpolate_coeff(coeffs_q[3], 128)
    q_d4_adj = interpolate_coeff(coeffs_q[4], 128)
    
    d1_row = np.concatenate([i_d1_adj, q_d1_adj])  
    d2_row = np.concatenate([i_d2_adj, q_d2_adj])  
    d3_row = np.concatenate([i_d3_adj, q_d3_adj]) 
    d4_row = np.concatenate([i_d4_adj, q_d4_adj])  
    
    final_matrix = np.vstack([d1_row, d2_row, d3_row, d4_row])
    return final_matrix

In [35]:
import random as rd

data_in = []
data_out = []

data = [item for i, item in enumerate(data) if i >= 150 or item[0][0][1].shape[0] != 0]

#数据增强
for i in range(0, 60):
    if data[i][0][0][1].shape[0] != 0:
        for j in range(len(data[i][0][0][1])):
            primitive_data = data[i][0][0][1][j]
            data_in.append(data_preprocessing(primitive_data))
            data_out.append(i)
            for _ in range(3):
                primitive_data_temp = primitive_data
                for k in range(256):
                    primitive_data_temp[k][0] =  primitive_data[k][0] + rd.gauss(0, 0.1)
                    primitive_data_temp[k][1] =  primitive_data[k][1] + rd.gauss(0, 0.1)
                data_in.append(data_preprocessing(primitive_data_temp))
                data_out.append(i)

data_out_np = np.array(data_out)
data_in_np = np.array(data_in)

print(data_out_np.shape)
print(data_in_np.shape)

(11776,)
(11776, 4, 256)


In [36]:
class WiSigNet(nn.Module):
    def __init__(self, ntx_i, feature_dim=1024):
        super(WiSigNet, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=(3, 3), padding=(1, 1)), 
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  
            nn.Conv2d(8, 16, kernel_size=(3, 3), padding=(1, 1)),  
            nn.ReLU(),
            nn.MaxPool2d((1, 2)), 
            nn.Conv2d(16, 32, kernel_size=(3, 3), padding=(1, 1)), 
            nn.ReLU(),
            nn.MaxPool2d((1, 2)), 
            nn.ReLU()
        )
        self._calculate_fc_input_dim()
        self.feature_transform = nn.Sequential(
            nn.Linear(self.fc_input_dim, feature_dim),
            nn.ReLU(),
        )
        self.prototype_layer = nn.Linear(feature_dim, ntx_i, bias=False)
        
        self._initialize_weights()
    
    def _calculate_fc_input_dim(self):
        with torch.no_grad():
            dummy_input = torch.zeros(1, 4, 256).unsqueeze(1)  # (1, 1, 4, 256)
            dummy_output = self.conv_layers(dummy_input)
            print("[Debug] Conv output shape:", dummy_output.shape)
            self.fc_input_dim = dummy_output.view(dummy_output.size(0), -1).size(1)
            print("[Debug] Calculated fc_input_dim:", self.fc_input_dim)
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        x = x.unsqueeze(1) 
        x = self.conv_layers(x)  
        x = x.view(x.size(0), -1) 
        x = self.feature_transform(x)  
        x_norm = torch.nn.functional.normalize(x, p=2, dim=1) 
        w_norm = torch.nn.functional.normalize(self.prototype_layer.weight, p=2, dim=1) 
        cosine_similarity = torch.mm(x_norm, w_norm.t())  
        
        return cosine_similarity

    def extract_features(self, x):
        x = x.unsqueeze(1)
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        x = self.feature_transform(x)
        x_norm = torch.nn.functional.normalize(x, p=2, dim=1)
        return x_norm

In [37]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class LossPrototypes(nn.Module):
    def __init__(self, margin: int = 4, scale: float = 1.5):
        super(LossPrototypes, self).__init__()
        self.margin = margin
        self.scale = scale
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.threshold = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

    def forward(self, cosine: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        cosine = torch.clamp(cosine, -1 + 1e-7, 1 - 1e-7)
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.threshold, phi, cosine - self.mm)
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1, 1), 1.0)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.scale
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, labels)
        return loss

In [38]:
generator = torch.Generator()
generator.manual_seed(seed)

X = torch.from_numpy(data_in_np).float()  
y = torch.from_numpy(data_out_np).long()  
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.2, random_state=42
)

X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42
)

batch_size = 1024
train_loader = DataLoader(
    TensorDataset(X_train, y_train),
    batch_size=batch_size,
    generator=generator,
    shuffle=True
)
val_loader = DataLoader(
    TensorDataset(X_val, y_val),
    batch_size=batch_size,
    generator=generator
)
test_loader = DataLoader(
    TensorDataset(X_test, y_test),
    batch_size=batch_size,
    generator=generator
)

ntx_i = 60  
model = WiSigNet(ntx_i).to(device)

criterion = LossPrototypes()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)


[Debug] Conv output shape: torch.Size([1, 32, 4, 32])
[Debug] Calculated fc_input_dim: 4096


In [39]:
num_epochs = 100
best_val_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
            
    train_loss /= len(train_loader)    
    
    if train_loss < best_val_loss:
        best_val_loss = train_loss
        torch.save(model.state_dict(), "model/tx_test_model_1005/best_model.pth")
        print(f"epoch:{epoch + 1}/{num_epochs} | loss:{train_loss}")


  cosine_similarity = torch.mm(x_norm, w_norm.t())


epoch:1/100 | loss:0.41360048949718475
epoch:2/100 | loss:0.31220594346523284
epoch:3/100 | loss:0.27154742181301117


In [40]:
model.load_state_dict(torch.load("model/tx_test_model_1005/best_model.pth"))
model.eval()
model = model.to(device)
prototype_weights = model.prototype_layer.weight.data
w_norm = F.normalize(prototype_weights, p=2, dim=1)  
print("原型向量已成功提取并保存！")
print(f"原型向量维度: {w_norm.shape}") 


原型向量已成功提取并保存！
原型向量维度: torch.Size([60, 1024])


In [41]:
correct_predictions = 0
total_predictions = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model.extract_features(inputs)
        features = F.normalize(outputs, p=2, dim=1)
        cosine_similarity = torch.mm(features, w_norm.t())
        predicted = torch.argmax(cosine_similarity, dim=1)
        correct_predictions += (predicted == labels).sum().item()
        total_predictions += labels.size(0)
            
print(f"Accuracy: {correct_predictions / total_predictions * 100:.2f}% ")


Accuracy: 95.33% 


  cosine_similarity = torch.mm(features, w_norm.t())
