In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
import pickle
from numpy.fft import fftshift,ifft
from scipy.signal import stft, windows
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from tqdm import tqdm

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [2]:
# Root of the MRM dataset. Set the MRM_DATASET_PATH environment variable to
# point elsewhere instead of editing this machine-specific default.
dataset_path = os.environ.get('MRM_DATASET_PATH', 'G:/MRM_0.5/')

# os.path.join works whether or not the root ends with a separator.
train_data_path = os.path.join(dataset_path, 'train/')
test_data_path = os.path.join(dataset_path, 'test/')

In [219]:
def load_pkl(file_path):
    """Unpickle and return the object stored in ``file_path``.

    Security note: ``pickle.load`` can execute arbitrary code embedded in the
    file, so only call this on trusted data files.
    """
    with open(file_path, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj
    
def rcs(echo):
    """Return ``4 * pi * |echo|**2`` elementwise (used as the RCS feature by the caller)."""
    magnitude_sq = np.abs(echo) ** 2
    return 4 * np.pi * magnitude_sq

def awgn(signal, snr, rng=None):
    """Add complex white Gaussian noise to ``signal`` at the requested SNR.

    Parameters
    ----------
    signal : array_like
        Input samples (real or complex). The noise is always complex, so the
        result is complex.
    snr : float
        Target signal-to-noise ratio in dB, relative to the mean power
        ``mean(|signal|**2)`` of the input.
    rng : numpy.random.Generator, optional
        Source of randomness. When None (default), the legacy global
        ``np.random`` stream is used, as before. Pass a seeded Generator to make
        the noise reproducible.

    Returns
    -------
    numpy.ndarray
        ``signal + noise``, same shape as ``signal``.
    """
    signal = np.asarray(signal)
    normal = np.random.normal if rng is None else rng.normal

    # Signal power and SNR converted from dB to a linear ratio.
    signal_power = np.mean(np.abs(signal) ** 2)
    snr_linear = 10 ** (snr / 10)
    noise_power = signal_power / snr_linear

    # Split the noise power evenly between the real and imaginary parts.
    # The real part is drawn before the imaginary part (same order as before).
    sigma = np.sqrt(noise_power / 2)
    noise_real = normal(0, sigma, signal.shape)
    noise_imag = normal(0, sigma, signal.shape)
    return signal + (noise_real + 1j * noise_imag)


def normalize(matrix):
    """Min-max scale ``matrix`` into [0, 1].

    A constant input (max == min) previously divided 0 by 0 and returned an
    all-NaN array. It now returns zeros, with the float dtype the division
    would have produced.

    Raises
    ------
    ValueError
        If ``matrix`` is empty (from ``np.min``), as before.
    """
    matrix = np.asarray(matrix)
    min_val = np.min(matrix)
    value_range = np.max(matrix) - min_val
    shifted = matrix - min_val
    if value_range == 0:
        return np.zeros_like(shifted, dtype=np.result_type(shifted, 1.0))
    return shifted / value_range

def STFT(st, nfft, winlen=64, fs=1024):
    """Two-sided short-time Fourier transform with zero frequency centred on axis 0.

    Uses a Hamming window of ``winlen`` samples and a hop of one sample
    (``noverlap = winlen - 1``), with zero-padded boundaries. The defaults
    reproduce the previously hard-coded settings.

    Parameters
    ----------
    st : array_like
        1-D signal (real or complex).
    nfft : int
        FFT length per frame (must be >= ``winlen``); the number of output rows.
    winlen : int, default 64
        Window / segment length.
    fs : float, default 1024
        Sampling rate passed to ``scipy.signal.stft``.

    Returns
    -------
    numpy.ndarray
        Complex array of shape ``(nfft, n_frames)``. Rows run from negative to
        positive frequency after ``fftshift``. With ``boundary='zeros'``,
        ``n_frames == len(st) + 1``.
    """
    window = windows.hamming(winlen)
    _, _, Zxx = stft(st, fs=fs, window=window, nperseg=winlen, nfft=nfft,
                     noverlap=winlen - 1, boundary='zeros',
                     return_onesided=False)
    # Two-sided output is in FFT order [0, +f, -f]; shift to [-f, 0, +f].
    return np.fft.fftshift(Zxx, axes=0)

def pad_hrrp(matrix, target_length):
    """Zero-pad a 2-D ``matrix`` along axis 0 to ``target_length`` rows.

    Parameters
    ----------
    matrix : 2-D array
        Data to pad; its rows are placed at the top of the result.
    target_length : int or None
        Desired number of rows. None returns ``matrix`` itself, unpadded.

    Returns
    -------
    numpy.ndarray
        A new complex128 array of shape ``(target_length, cols)``, or ``matrix``
        unchanged when ``target_length`` is None.

    Raises
    ------
    ValueError
        If ``target_length`` is smaller than the number of rows. This used to
        surface as an opaque broadcast error.
    """
    if target_length is None:
        return matrix
    rows, cols = matrix.shape
    if target_length < rows:
        raise ValueError(
            f'target_length ({target_length}) is smaller than the number of rows ({rows})'
        )
    padded_matrix = np.zeros((target_length, cols), dtype=np.complex128)
    padded_matrix[:rows, :] = matrix
    return padded_matrix

def image_hrrp(hrrp, pad_size):
    """Convert HRRP samples into a min-max normalized log-magnitude image.

    Steps:
    - zero-pad along axis 0 to ``pad_size`` rows (``pad_hrrp``);
    - inverse FFT along axis 0;
    - centre with ``fftshift``;
    - take ``log10(|.|)``;
    - scale to [0, 1] with ``normalize``.

    Presumably each column holds frequency-domain samples, so the IFFT yields
    range profiles — TODO confirm against the dataset description.

    Parameters
    ----------
    hrrp : 2-D array
    pad_size : int or None
        Row count passed to ``pad_hrrp``; None skips padding.

    Returns
    -------
    numpy.ndarray
        Float image with as many rows as the padded input.
    """
    padded = pad_hrrp(hrrp, pad_size)
    profiles = fftshift(ifft(padded, axis=0), axes=0)
    magnitude = np.abs(profiles)
    # log10(0) = -inf would make normalize() return NaN for the *whole* image.
    # Replace exact zeros by the smallest non-zero magnitude, so they map to 0
    # after normalization. Without zeros the result is identical to before.
    nonzero = magnitude[magnitude > 0]
    floor = nonzero.min() if nonzero.size else 1.0
    magnitude = np.where(magnitude > 0, magnitude, floor)
    return normalize(np.log10(magnitude))

# Data loading and preprocessing functions

In [220]:
# def data processing here.
def preprocessing(data, snr, pad_size):
    polar_type = 'HH'

    # extract rcs
    E = awgn(data['echo'][polar_type],snr=snr)
    rcs_data = normalize(rcs(E))

    # processing echo
    TFR = STFT(E,nfft = 128)
    TFR = normalize(np.abs(TFR)[:,:-1])

    # extract hrrp
    hrrp = awgn(data['hrrp'][polar_type], snr=snr)
    hrrp = image_hrrp(hrrp, pad_size = pad_size)[int(75*pad_size/201):int(125*pad_size/201),:]

    # unify data type
    rcs_data = torch.tensor(rcs_data, dtype = torch.float32)
    TFR = torch.tensor(TFR, dtype = torch.float32)
    hrrp = torch.tensor(hrrp, dtype = torch.float32)
       
    return (rcs_data, TFR, hrrp)
    # return rcs_data

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over ``<dataset_dir>/<label_dir>/*.pkl`` sample files.

    Each item is ``((rcs, tfr, hrrp), target_id)``. The features come from
    ``preprocessing``; the label is read from ``data['target_id']``, not from the
    directory name. Noise is added inside ``__getitem__``, so the same index
    yields a different noisy realization on every access.

    The base class is spelled ``torch.utils.data.Dataset`` explicitly: this
    class reuses the name ``Dataset``. ``class Dataset(Dataset)`` would, on a
    re-run of the cell, subclass the notebook's own previous class instead of
    torch's.
    """

    def __init__(self, dataset_dir, snr, pad_size):
        self.snr = snr
        self.pad_size = pad_size
        self.dataset_dir = dataset_dir
        self.instance_list = self.get_instance()

    def get_instance(self):
        """Return the .pkl paths one directory level below ``dataset_dir``.

        The list is sorted by sub-directory and then by file path, so the sample
        order is deterministic across machines and runs (``os.listdir`` and
        ``glob`` give arbitrary order).
        """
        instance_list = []
        for label in sorted(os.listdir(self.dataset_dir)):
            label_dir = os.path.join(self.dataset_dir, label)
            instance_list.extend(sorted(glob.glob(os.path.join(label_dir, '*.pkl'))))
        return instance_list

    def __len__(self):
        return len(self.instance_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(features, label)``.

        ``features`` is the tuple produced by ``preprocessing``; ``label`` is a
        ``torch.long`` tensor. Uses pickle, so the dataset files must be trusted.
        """
        data = load_pkl(self.instance_list[idx])
        x = preprocessing(data, snr=self.snr, pad_size=self.pad_size)
        y = data['target_id']
        return x, torch.tensor(y, dtype=torch.long)


In [221]:
# SNR (dB) of the noise injected into every sample, and the HRRP zero-pad length.
snr, pad_size = 20, 201
train_dataset, test_dataset = (
    Dataset(root, snr=snr, pad_size=pad_size)
    for root in (train_data_path, test_data_path)
)

In [222]:
# label_counts = {0: 0, 1: 0, 2: 0, 3: 0}
# required_samples = 5  # The required number of samples for each label

# data_samples = {0: [], 1: [], 2: [], 3: []}

# # Loop through the dataset
# for i in range(len(train_dataset)):
#     x, y = train_dataset.__getitem__(i)
#     label = int(y.item())  # Convert the tensor to an integer
    
#     if label_counts[label] < required_samples:
#         data_samples[label].append((x, y))
#         label_counts[label] += 1
    
#     # Break the loop if we have enough samples for each label
#     if all(count == required_samples for count in label_counts.values()):
#         break

In [223]:
# for label in data_samples:
#     for idx, (x, y) in enumerate(data_samples[label]):
#         # Plot the first data
#         fig, ax = plt.subplots(figsize=(6, 5))
#         ax.plot((x[0]))
#         plt.xlim([0,512])
#         ax.axis('off')  # Turn off axis
#         ax.text(0.02, 0.98, f'{label}', transform=ax.transAxes, fontsize=30, verticalalignment='top', horizontalalignment='left', fontname='Times New Roman', color='black')
#         plt.savefig(f'./data_show/label_{label}_rcs_{idx}.png', bbox_inches='tight', pad_inches=0)
#         plt.close(fig)

#         # Plot the second data
#         fig, ax = plt.subplots(figsize=(6, 5))
#         ax.pcolormesh(np.linspace(0, 0.5, 512), np.linspace(0, 128, 128), x[1])
#         ax.axis('off')  # Turn off axis
#         ax.text(0.02, 0.98, f'{label}', transform=ax.transAxes, fontsize=30, verticalalignment='top', horizontalalignment='left', fontname='Times New Roman', color='white')
#         plt.savefig(f'./data_show/label_{label}_tfd_{idx}.png', bbox_inches='tight', pad_inches=0)
#         plt.close(fig)

#         # Plot the third data
#         hrrp = x[2]
#         fig, ax = plt.subplots(figsize=(6, 5))
#         ax.pcolormesh(np.linspace(0, 0.5, 512), np.linspace(0, hrrp.shape[0], hrrp.shape[0]), hrrp, cmap='twilight')
#         ax.axis('off')  # Turn off axis
#         ax.text(0.02, 0.98, f'{label}', transform=ax.transAxes, fontsize=30, verticalalignment='top', horizontalalignment='left', fontname='Times New Roman', color='white')
#         plt.savefig(f'./data_show/label_{label}_hrrp_{idx}.png', bbox_inches='tight', pad_inches=0)
#         plt.close(fig)

In [224]:
# from PIL import Image
# image_dir = './data_show'
# # dtype = 'rcs'
# # dtype = 'hrrp'
# dtype = 'tfd'

# def create_grid(image_dir , dtype ):
#     image_filenames = glob.glob(image_dir + '/*.png')

#     filtered_filenames = [filename for filename in image_filenames if dtype in filename]

#     images = [Image.open(filename) for filename in filtered_filenames]
    
#     # Assuming all images are the same size
#     width, height = images[0].size

#     grid_width = 5
#     grid_height = 4

#     width_padding = 25
#     height_padding = 10

#     total_width = width * grid_width + width_padding * (grid_width - 1)
#     total_height = height * grid_height + height_padding * (grid_height - 1)

#     grid_image = Image.new('RGB', (total_width, total_height), color=(255, 255, 255))  # White background
    
#     for index, image in enumerate(images):
#         x = (index % grid_width) * (width + width_padding)
#         y = (index // grid_width) * (height + height_padding)
#         grid_image.paste(image, (x, y))
    
#     grid_image.save(f'./{dtype}.png')
    
#     grid_image.show()

# create_grid(image_dir, dtype)

# Fusion classification

In [225]:
class ClassificationModel(nn.Module):
    """Multi-branch CNN that fuses a TFR image, an HRRP image and, optionally, a 1-D sequence.

    Each branch is two 3x3 (or length-3) convolutions with padding 1, which keep
    the spatial size, followed by a linear layer to 128 features. The branch
    features are concatenated and classified by a two-layer MLP.

    Parameters
    ----------
    num_classes : int, default 4
        Number of output logits.
    use_seq : bool, default False
        Include the 1-D sequence branch in the fused features. When False (the
        original configuration) the ``seq`` argument of ``forward`` is ignored.
    seq_len : int, default 512
        Length of the sequence input and width of both image inputs.
    tfr_bins : int, default 128
        Height of the TFR image (``mat128``).
    hrrp_bins : int, default 50
        Height of the HRRP image (``mat50``).
    """

    def __init__(self, num_classes=4, use_seq=False, seq_len=512, tfr_bins=128, hrrp_bins=50):
        super().__init__()
        self.use_seq = use_seq

        # Sequence branch: (N, 1, seq_len) -> 128 features.
        # Always created, even when unused, so parameter init order and
        # state_dict keys stay the same as before.
        self.seq_conv1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.seq_conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.seq_fc = nn.Linear(32 * seq_len, 128)

        # TFR branch: (N, 1, tfr_bins, seq_len) -> 128 features.
        # At default sizes this FC layer alone holds ~268M weights.
        self.mat128_conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=1)
        self.mat128_conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=1)
        self.mat128_fc = nn.Linear(32 * tfr_bins * seq_len, 128)

        # HRRP branch: (N, 1, hrrp_bins, seq_len) -> 128 features.
        self.mat50_conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), padding=1)
        self.mat50_conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=1)
        self.mat50_fc = nn.Linear(32 * hrrp_bins * seq_len, 128)

        # Fusion head over 128 features per active branch.
        n_branches = 3 if use_seq else 2
        self.fc1 = nn.Linear(128 * n_branches, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, seq, mat128, mat50):
        """Return class logits of shape ``(N, num_classes)``.

        ``seq``, ``mat128`` and ``mat50`` must carry a channel axis of size 1,
        e.g. ``(N, 1, seq_len)`` and ``(N, 1, H, seq_len)``. ``seq`` is only
        used when ``use_seq`` is True.
        """
        features = []

        if self.use_seq:
            x1 = F.relu(self.seq_conv1(seq))
            x1 = F.relu(self.seq_conv2(x1))
            features.append(F.relu(self.seq_fc(x1.flatten(1))))

        x2 = F.relu(self.mat128_conv1(mat128))
        x2 = F.relu(self.mat128_conv2(x2))
        features.append(F.relu(self.mat128_fc(x2.flatten(1))))

        x3 = F.relu(self.mat50_conv1(mat50))
        x3 = F.relu(self.mat50_conv2(x3))
        features.append(F.relu(self.mat50_fc(x3.flatten(1))))

        x = torch.cat(features, dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Example usage
model = ClassificationModel().to(device)

# # 参数设置
num_classes  = 4  # 类别数
num_epochs = 20
learning_rate = 0.001

# # 初始化模型、损失函数和优化器

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [226]:
def transfer_device(x, target=None):
    """Return tensor ``x`` on ``target``, defaulting to the notebook-wide ``device``.

    ``Tensor.to`` already returns ``x`` itself when it is on the requested
    device, so no explicit check is needed. The previous ``x.device != device``
    guard was also inexact: ``torch.device('cuda:0') != torch.device('cuda')``,
    so it never short-circuited for GPU tensors.
    """
    return x.to(device if target is None else target)

def tensor_process(x, y):
    """Move a batch to ``device`` and give every input a channel axis.

    Each tensor in ``x`` (one per model branch) goes from ``(N, ...)`` to
    ``(N, 1, ...)`` via ``unsqueeze(1)``. The labels ``y`` are moved unchanged.
    Returns ``(list_of_inputs, labels)``.
    """
    inputs = []
    for tensor in x:
        inputs.append(transfer_device(tensor).unsqueeze(1))
    labels = transfer_device(y)
    return inputs, labels

for x, y in train_loader:
    x, y = tensor_process(x, y)
    print(x[0].shape, x[1].shape, x[2].shape)
    y = model(x[0],x[1],x[2])
    print(y.shape)
    break

torch.Size([32, 1, 512]) torch.Size([32, 1, 128, 512]) torch.Size([32, 1, 50, 512])
torch.Size([32, 4])


In [227]:
def train(dataloader, loss_fn, optimizer, net=None):
    """Run one training epoch over ``dataloader``.

    Parameters
    ----------
    dataloader : iterable
        Yields ``(inputs, labels)`` batches. ``inputs`` holds the three feature
        tensors and goes through ``tensor_process``.
    loss_fn : callable
        ``loss_fn(logits, labels)``. Assumed to average over the batch (true for
        the default ``nn.CrossEntropyLoss``), since it is re-weighted by batch
        size below.
    optimizer : torch.optim.Optimizer
        Optimizer over the parameters of the model being trained.
    net : nn.Module, optional
        Model to train. Defaults to the notebook-level ``model``, so existing
        calls behave as before.

    Returns
    -------
    tuple(float, float)
        Mean per-sample loss and accuracy over the epoch.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no samples (previously a bare ZeroDivisionError).
    """
    net = model if net is None else net
    net.train()
    total_correct, total_count, total_loss = 0, 0, 0.0
    for x, y in dataloader:
        inputs, labels = tensor_process(x, y)
        logits = net(inputs[0], inputs[1], inputs[2])
        loss = loss_fn(logits, labels)

        # Back-propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Bookkeeping only; keep it out of the autograd graph.
        with torch.no_grad():
            batch_size = labels.size(0)
            total_correct += (logits.argmax(1) == labels).sum().item()
            total_count += batch_size
            total_loss += loss.item() * batch_size

    if total_count == 0:
        raise ValueError('dataloader yielded no samples')
    return total_loss / total_count, total_correct / total_count

def test(dataloader, loss_fn, net=None):
    """Evaluate a model on ``dataloader`` without updating it.

    The model is switched to eval mode and stays there afterwards.

    Parameters
    ----------
    dataloader : iterable
        Yields ``(inputs, labels)`` batches, processed by ``tensor_process``.
    loss_fn : callable
        ``loss_fn(logits, labels)``. Assumed to average over the batch (true for
        the default ``nn.CrossEntropyLoss``).
    net : nn.Module, optional
        Model to evaluate. Defaults to the notebook-level ``model``, so existing
        calls behave as before.

    Returns
    -------
    tuple(float, float)
        Mean per-sample loss and accuracy.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no samples (previously a bare ZeroDivisionError).
    """
    net = model if net is None else net
    net.eval()
    total_correct, total_count, total_loss = 0, 0, 0.0

    with torch.no_grad():
        for x, y in dataloader:
            inputs, labels = tensor_process(x, y)
            logits = net(inputs[0], inputs[1], inputs[2])
            loss = loss_fn(logits, labels)
            batch_size = labels.size(0)
            total_correct += (logits.argmax(1) == labels).sum().item()
            total_count += batch_size
            total_loss += loss.item() * batch_size

    if total_count == 0:
        raise ValueError('dataloader yielded no samples')
    return total_loss / total_count, total_correct / total_count

def fit(epochs, train_dl, test_dl, loss_fn, optimizer):
    """Train for ``epochs`` epochs, evaluating on ``test_dl`` after each one.

    Each epoch's metrics are logged with ``tqdm.write``.

    Returns
    -------
    tuple of four lists (one entry per epoch)
        ``(train_loss, test_loss, train_acc, test_acc)``.
    """
    tr_losses, tr_accs = [], []
    te_losses, te_accs = [], []

    for epoch in range(epochs):
        epoch_loss, epoch_acc = train(train_dl, loss_fn, optimizer)
        epoch_test_loss, epoch_test_acc = test(test_dl, loss_fn)

        for history, value in ((tr_losses, epoch_loss), (tr_accs, epoch_acc),
                               (te_losses, epoch_test_loss), (te_accs, epoch_test_acc)):
            history.append(value)

        tqdm.write(f'epoch:{epoch}, train_loss:{epoch_loss}, train_acc:{epoch_acc*100}%, test_loss:{epoch_test_loss}, test_acc:{epoch_test_acc*100}%.')

    return tr_losses, te_losses, tr_accs, te_accs

In [228]:
# Reuse the epoch count from the hyper-parameter cell instead of a second hard-coded 20.
EPOCHS = num_epochs
train_loss, test_loss, train_acc, test_acc = fit(EPOCHS, train_loader, test_loader, criterion, optimizer)

'''20db 18min VH
epoch:0, train_loss:1.1241819155216217, train_acc:66.875%, test_loss:0.4514965268969536, test_acc:82.75%.
epoch:1, train_loss:0.27928072440624235, train_acc:89.575%, test_loss:0.29026298701763154, test_acc:90.0%.
epoch:2, train_loss:0.11156235760450363, train_acc:96.35000000000001%, test_loss:0.207748813778162, test_acc:92.125%.
epoch:3, train_loss:0.03774304964579642, train_acc:98.725%, test_loss:0.18507124507799744, test_acc:94.375%.
epoch:4, train_loss:0.012540637358324603, train_acc:99.675%, test_loss:0.19223011233611031, test_acc:95.8125%.
epoch:5, train_loss:0.0033942202914040535, train_acc:99.95%, test_loss:0.18071796733536757, test_acc:95.9375%.
epoch:6, train_loss:0.0006282970081665553, train_acc:100.0%, test_loss:0.20041460228152574, test_acc:95.5%.
epoch:7, train_loss:0.00038863570083776724, train_acc:100.0%, test_loss:0.19509510923642664, test_acc:96.0%.
epoch:8, train_loss:0.00022862117252952884, train_acc:100.0%, test_loss:0.20998952157795428, test_acc:95.8125%.
'''

'''0db HH
epoch:0, train_loss:1.4693208298683167, train_acc:48.175000000000004%, test_loss:0.7185463911294937, test_acc:65.9375%.
epoch:1, train_loss:0.5613903698921203, train_acc:72.875%, test_loss:0.46897474765777586, test_acc:77.1875%.
epoch:2, train_loss:0.45069819498062136, train_acc:77.55%, test_loss:0.5257919955253602, test_acc:75.0625%.
epoch:3, train_loss:0.4187243970632553, train_acc:80.5%, test_loss:0.49667388617992403, test_acc:76.0%.
epoch:4, train_loss:0.38509046494960786, train_acc:82.5%, test_loss:0.4109884175658226, test_acc:79.5625%.
epoch:5, train_loss:0.3130657175779343, train_acc:86.575%, test_loss:0.43467467069625854, test_acc:79.9375%.
epoch:6, train_loss:0.3210661892294884, train_acc:85.7%, test_loss:0.4288026034832001, test_acc:81.0625%.
epoch:7, train_loss:0.25396066200733186, train_acc:88.725%, test_loss:0.3927362006902695, test_acc:82.125%.
epoch:8, train_loss:0.2111686646938324, train_acc:91.175%, test_loss:0.411899973899126, test_acc:82.1875%.
epoch:9, train_loss:0.19696999317407607, train_acc:92.0%, test_loss:0.4614270284771919, test_acc:79.0%.
epoch:10, train_loss:0.20353862634301187, train_acc:91.925%, test_loss:0.3993707850575447, test_acc:83.0625%.
epoch:11, train_loss:0.15132588325440885, train_acc:93.72500000000001%, test_loss:0.4022131218016148, test_acc:83.9375%.
epoch:12, train_loss:0.13481597982347013, train_acc:94.75%, test_loss:0.4388500380516052, test_acc:83.75%.
epoch:13, train_loss:0.11078749952465296, train_acc:96.15%, test_loss:0.5217597913742066, test_acc:82.5625%.
epoch:14, train_loss:0.12843525234609843, train_acc:94.39999999999999%, test_loss:0.4546648742258549, test_acc:83.8125%.
epoch:15, train_loss:0.11599310493469238, train_acc:95.675%, test_loss:0.4817324535548687, test_acc:84.125%.
epoch:16, train_loss:0.10618336752802134, train_acc:95.8%, test_loss:0.5615658095479011, test_acc:82.1875%.
epoch:17, train_loss:0.09937145296856761, train_acc:96.42500000000001%, test_loss:0.4900254736840725, test_acc:84.5625%.
epoch:18, train_loss:0.090679729051888, train_acc:96.625%, test_loss:0.49746148362755777, test_acc:83.5625%.
epoch:19, train_loss:0.07788088616728783, train_acc:97.075%, test_loss:0.5574340659379959, test_acc:84.5625%.
'''

'''20db HH all
epoch:0, train_loss:1.1258682916164398, train_acc:63.2%, test_loss:0.5551396346092224, test_acc:74.9375%.
epoch:1, train_loss:0.4603603246212006, train_acc:80.375%, test_loss:0.4714896693825722, test_acc:80.8125%.
epoch:2, train_loss:0.32917592775821686, train_acc:87.075%, test_loss:0.41207170218229294, test_acc:83.625%.
epoch:3, train_loss:0.23173720529675484, train_acc:90.925%, test_loss:0.45689154595136644, test_acc:83.0625%.
epoch:4, train_loss:0.1456785126030445, train_acc:94.45%, test_loss:0.4064051102101803, test_acc:86.375%.
epoch:5, train_loss:0.12099535960704089, train_acc:95.325%, test_loss:0.38644785068929194, test_acc:87.5625%.
epoch:6, train_loss:0.06384394974075258, train_acc:97.52499999999999%, test_loss:0.376604887843132, test_acc:87.4375%.
epoch:7, train_loss:0.06208055400848388, train_acc:97.725%, test_loss:0.41846803203225136, test_acc:87.625%.
epoch:8, train_loss:0.040952791482210156, train_acc:98.52499999999999%, test_loss:0.4618775695934892, test_acc:87.5625%.
epoch:9, train_loss:0.04372372142318636, train_acc:98.425%, test_loss:0.5051933059096336, test_acc:87.4375%.
epoch:10, train_loss:0.03147885977383703, train_acc:99.0%, test_loss:0.5406148046255111, test_acc:88.6875%.
epoch:11, train_loss:0.009516887980862521, train_acc:99.75%, test_loss:0.5646771858260036, test_acc:88.5%.
epoch:12, train_loss:0.009096605295082555, train_acc:99.675%, test_loss:0.542671508193016, test_acc:89.0%.
epoch:13, train_loss:0.005566739110508934, train_acc:99.85000000000001%, test_loss:0.6695612443797291, test_acc:87.3125%.
epoch:14, train_loss:0.007593071111070458, train_acc:99.725%, test_loss:0.6183080086112023, test_acc:88.3125%.
epoch:15, train_loss:0.015835908899374772, train_acc:99.5%, test_loss:0.8538768084347248, test_acc:85.3125%.
epoch:16, train_loss:0.03461002200655639, train_acc:98.65%, test_loss:0.7680001732707024, test_acc:86.5%.
epoch:17, train_loss:0.04825304436543956, train_acc:98.02499999999999%, test_loss:0.6187161194719374, test_acc:86.75%.
epoch:18, train_loss:0.020370634165243245, train_acc:99.225%, test_loss:0.8981689929962158, test_acc:84.375%.
epoch:19, train_loss:0.01420170836581383, train_acc:99.575%, test_loss:0.6329089766740799, test_acc:88.6875%.
'''

'''15db HH all
cuda
torch.Size([32, 1, 512]) torch.Size([32, 1, 128, 512]) torch.Size([32, 1, 50, 512])
torch.Size([32, 4])
epoch:0, train_loss:1.3206724812984467, train_acc:61.95%, test_loss:0.5618070149421692, test_acc:73.625%.
epoch:1, train_loss:0.5362536206245422, train_acc:76.25%, test_loss:0.5478888332843781, test_acc:76.0625%.
epoch:2, train_loss:0.4398058173656464, train_acc:81.27499999999999%, test_loss:0.5605051529407501, test_acc:77.0%.
epoch:3, train_loss:0.3839760557413101, train_acc:84.05%, test_loss:0.5285174456238747, test_acc:79.6875%.
epoch:4, train_loss:0.3007578548192978, train_acc:88.4%, test_loss:0.37490991175174715, test_acc:85.5%.
epoch:5, train_loss:0.2433449894785881, train_acc:90.225%, test_loss:0.3933316830545664, test_acc:84.6875%.
epoch:6, train_loss:0.19527279043197632, train_acc:92.15%, test_loss:0.407311934530735, test_acc:84.6875%.
epoch:7, train_loss:0.1667834893167019, train_acc:93.675%, test_loss:0.4132919411361218, test_acc:85.5%.
epoch:8, train_loss:0.12798140347003936, train_acc:94.975%, test_loss:0.39712046533823014, test_acc:86.375%.
epoch:9, train_loss:0.0980242567807436, train_acc:96.42500000000001%, test_loss:0.4124077707529068, test_acc:86.375%.
epoch:10, train_loss:0.08728849812969565, train_acc:96.55%, test_loss:0.4564265362918377, test_acc:86.125%.
epoch:11, train_loss:0.07862618769705296, train_acc:97.225%, test_loss:0.5314003229141235, test_acc:84.6875%.
epoch:12, train_loss:0.05383955233171582, train_acc:98.2%, test_loss:0.559694340005517, test_acc:86.375%.
epoch:13, train_loss:0.05870945263281464, train_acc:97.89999999999999%, test_loss:0.5384301985055209, test_acc:86.4375%.
epoch:14, train_loss:0.046720922825858, train_acc:98.2%, test_loss:0.5689457377791405, test_acc:85.0%.
epoch:15, train_loss:0.045923063683323564, train_acc:98.275%, test_loss:0.5268210481107235, test_acc:86.9375%.
epoch:16, train_loss:0.050265789957717064, train_acc:98.375%, test_loss:0.5201009671390057, test_acc:87.5%.
epoch:17, train_loss:0.026339752141619103, train_acc:99.1%, test_loss:0.6147561480104923, test_acc:86.5625%.
epoch:18, train_loss:0.023187558244448157, train_acc:99.225%, test_loss:0.6425055219978094, test_acc:87.1875%.
epoch:19, train_loss:0.026020501350518317, train_acc:99.05000000000001%, test_loss:0.7581704390048981, test_acc:86.625%.
'''

'''10db all
torch.Size([32, 1, 512]) torch.Size([32, 1, 128, 512]) torch.Size([32, 1, 50, 512])
torch.Size([32, 4])
epoch:0, train_loss:1.4148925111293793, train_acc:51.6%, test_loss:0.8022310763597489, test_acc:63.0625%.
epoch:1, train_loss:0.5780453689098358, train_acc:72.95%, test_loss:0.69844462454319, test_acc:69.625%.
epoch:2, train_loss:0.45838949835300447, train_acc:80.0%, test_loss:0.5431966164708137, test_acc:77.5625%.
epoch:3, train_loss:0.3830210522413254, train_acc:83.7%, test_loss:0.5610465764999389, test_acc:78.75%.
epoch:4, train_loss:0.32131205332279206, train_acc:85.725%, test_loss:0.5502186152338981, test_acc:80.125%.
epoch:5, train_loss:0.2631221823692322, train_acc:88.8%, test_loss:0.5117268845438957, test_acc:80.9375%.
epoch:6, train_loss:0.23543504548072816, train_acc:90.10000000000001%, test_loss:0.4968974223732948, test_acc:83.125%.
epoch:7, train_loss:0.23070990151166915, train_acc:90.8%, test_loss:0.5798195576667786, test_acc:81.9375%.
epoch:8, train_loss:0.17310481360554694, train_acc:93.30000000000001%, test_loss:0.5808150079846383, test_acc:83.0%.
epoch:9, train_loss:0.1688699631989002, train_acc:92.9%, test_loss:0.5787392359972, test_acc:82.0%.
epoch:10, train_loss:0.14581052295863628, train_acc:94.075%, test_loss:0.569445055872202, test_acc:83.5%.
epoch:11, train_loss:0.13251583954691887, train_acc:94.875%, test_loss:0.6179102553427219, test_acc:83.5625%.
epoch:12, train_loss:0.11244852437824011, train_acc:95.375%, test_loss:0.6144866946339608, test_acc:84.5%.
epoch:13, train_loss:0.11477444009482861, train_acc:95.575%, test_loss:0.6188791035115719, test_acc:84.0625%.
epoch:14, train_loss:0.10375585108995437, train_acc:95.875%, test_loss:0.7032929596304893, test_acc:82.75%.
epoch:15, train_loss:0.08183554334565997, train_acc:96.8%, test_loss:0.5967998158931732, test_acc:85.5625%.
epoch:16, train_loss:0.0775267562493682, train_acc:97.225%, test_loss:0.7207096932828426, test_acc:83.8125%.
epoch:17, train_loss:0.0847238945774734, train_acc:96.625%, test_loss:0.7729757909476757, test_acc:83.5%.
epoch:18, train_loss:0.07332225086167454, train_acc:97.1%, test_loss:0.7699450640380383, test_acc:83.6875%.
epoch:19, train_loss:0.06933896141126752, train_acc:97.6%, test_loss:0.7237076684832573, test_acc:85.0625%.
'''

'''0db rcs
epoch:0, train_loss:1.3787201023101807, train_acc:29.099999999999998%, test_loss:1.3593519592285157, test_acc:39.3125%.
epoch:1, train_loss:1.3305347728729249, train_acc:34.150000000000006%, test_loss:1.2982155251502991, test_acc:39.3125%.
epoch:2, train_loss:1.3061100645065307, train_acc:36.475%, test_loss:1.4111748695373536, test_acc:31.4375%.
epoch:3, train_loss:1.2932102060317994, train_acc:37.0%, test_loss:1.3230850720405578, test_acc:36.875%.
epoch:4, train_loss:1.270222324371338, train_acc:39.175%, test_loss:1.3090008640289306, test_acc:34.8125%.
epoch:5, train_loss:1.2549326467514037, train_acc:41.65%, test_loss:1.277556655406952, test_acc:38.625%.
epoch:6, train_loss:1.2126625690460204, train_acc:45.225%, test_loss:1.3562083625793457, test_acc:30.375000000000004%.
epoch:7, train_loss:1.184486002922058, train_acc:46.85%, test_loss:1.2285668992996215, test_acc:43.375%.
epoch:8, train_loss:1.1309524159431457, train_acc:49.975%, test_loss:1.261878900527954, test_acc:39.0%.
epoch:9, train_loss:1.093246265411377, train_acc:52.1%, test_loss:1.2244709944725036, test_acc:43.25%.
epoch:10, train_loss:1.0641372008323668, train_acc:53.900000000000006%, test_loss:1.3006906080245972, test_acc:41.4375%.
epoch:11, train_loss:1.0406457018852233, train_acc:55.2%, test_loss:1.235673166513443, test_acc:43.125%.
epoch:12, train_loss:1.0150558581352234, train_acc:56.3%, test_loss:1.2720732545852662, test_acc:40.9375%.
epoch:13, train_loss:0.9747609429359436, train_acc:58.8%, test_loss:1.2848359417915345, test_acc:41.0%.
epoch:14, train_loss:0.9556393847465515, train_acc:59.52499999999999%, test_loss:1.2380124855041503, test_acc:44.6875%.
epoch:15, train_loss:0.9457359724044799, train_acc:60.099999999999994%, test_loss:1.2413744628429413, test_acc:44.875%.
epoch:16, train_loss:0.9122983732223511, train_acc:61.975%, test_loss:1.2864688432216644, test_acc:43.125%.
epoch:17, train_loss:0.9044516272544861, train_acc:62.0%, test_loss:1.3764006543159484, test_acc:41.6875%.
epoch:18, train_loss:0.8791815657615661, train_acc:63.449999999999996%, test_loss:1.2690665447711944, test_acc:46.4375%.
epoch:19, train_loss:0.8737780175209046, train_acc:63.6%, test_loss:1.3262947702407837, test_acc:43.6875%.
'''

'''0db tfd
epoch:0, train_loss:1.7372279663085937, train_acc:47.825%, test_loss:0.8199108672142029, test_acc:62.0625%.
epoch:1, train_loss:0.7152552335262299, train_acc:66.3%, test_loss:0.802225536108017, test_acc:63.625%.
epoch:2, train_loss:0.6179312219619751, train_acc:72.125%, test_loss:0.7702690470218658, test_acc:66.875%.
epoch:3, train_loss:0.5333746482133865, train_acc:76.125%, test_loss:0.7625479227304459, test_acc:68.625%.
epoch:4, train_loss:0.45745155882835387, train_acc:80.72500000000001%, test_loss:0.6510397821664811, test_acc:73.375%.
epoch:5, train_loss:0.38958149874210357, train_acc:83.525%, test_loss:0.7018607223033905, test_acc:73.9375%.
epoch:6, train_loss:0.33758846414089205, train_acc:85.8%, test_loss:0.7632811850309372, test_acc:73.5%.
epoch:7, train_loss:0.30875521671772005, train_acc:86.85000000000001%, test_loss:0.6848002249002456, test_acc:75.25%.
epoch:8, train_loss:0.27230318105220797, train_acc:88.9%, test_loss:0.7871409010887146, test_acc:72.75%.
epoch:9, train_loss:0.25484457659721377, train_acc:89.55%, test_loss:0.8207544338703155, test_acc:74.0625%.
epoch:10, train_loss:0.24425067386031152, train_acc:89.5%, test_loss:0.7225407785177231, test_acc:74.9375%.
epoch:11, train_loss:0.2187012955546379, train_acc:90.925%, test_loss:0.9277432534098625, test_acc:74.25%.
epoch:12, train_loss:0.20499988871812821, train_acc:91.025%, test_loss:0.8208165013790131, test_acc:76.1875%.
epoch:13, train_loss:0.20599607345461846, train_acc:91.625%, test_loss:0.957995417714119, test_acc:74.0625%.
epoch:14, train_loss:0.18670487159490584, train_acc:92.30000000000001%, test_loss:1.0334151840209962, test_acc:74.0%.
epoch:15, train_loss:0.17647653076052666, train_acc:92.425%, test_loss:0.9535931795835495, test_acc:76.4375%.
epoch:16, train_loss:0.1616349500864744, train_acc:92.675%, test_loss:1.038391211181879, test_acc:75.9375%.
epoch:17, train_loss:0.16125226566195489, train_acc:93.35%, test_loss:1.027125672698021, test_acc:75.0625%.
epoch:18, train_loss:0.16203805954754352, train_acc:93.15%, test_loss:0.9525357609987259, test_acc:76.4375%.
epoch:19, train_loss:0.1571488407701254, train_acc:93.72500000000001%, test_loss:0.983634644150734, test_acc:76.25%.
'''


'''0db hrrp
epoch:0, train_loss:1.5446708908081055, train_acc:25.650000000000002%, test_loss:1.3866466784477234, test_acc:25.0%.
epoch:1, train_loss:1.3768860511779786, train_acc:27.425%, test_loss:1.3103150272369384, test_acc:37.1875%.
epoch:2, train_loss:1.2623022966384887, train_acc:35.075%, test_loss:1.1668832516670227, test_acc:42.4375%.
epoch:3, train_loss:1.0735173449516295, train_acc:46.9%, test_loss:0.8588788545131684, test_acc:56.06250000000001%.
epoch:4, train_loss:0.7377133669853211, train_acc:62.824999999999996%, test_loss:0.6864599823951721, test_acc:60.3125%.
epoch:5, train_loss:0.5942367622852326, train_acc:70.275%, test_loss:0.5693736511468888, test_acc:71.0625%.
epoch:6, train_loss:0.5225624542236328, train_acc:73.475%, test_loss:0.5195651578903199, test_acc:72.5625%.
epoch:7, train_loss:0.46629568076133726, train_acc:77.0%, test_loss:0.4733832567930222, test_acc:73.6875%.
epoch:8, train_loss:0.4482596480846405, train_acc:78.10000000000001%, test_loss:0.4632673770189285, test_acc:76.0625%.
epoch:9, train_loss:0.4007415096759796, train_acc:80.425%, test_loss:0.46763298511505125, test_acc:75.25%.
epoch:10, train_loss:0.3718769130706787, train_acc:82.15%, test_loss:0.4571414965391159, test_acc:76.5625%.
epoch:11, train_loss:0.3398978446722031, train_acc:83.39999999999999%, test_loss:0.41893003046512606, test_acc:76.6875%.
epoch:12, train_loss:0.3424643580913544, train_acc:82.8%, test_loss:0.39015767753124236, test_acc:79.125%.
epoch:13, train_loss:0.2967569440603256, train_acc:85.55%, test_loss:0.38903634786605834, test_acc:78.6875%.
epoch:14, train_loss:0.28106075710058215, train_acc:85.925%, test_loss:0.41802974313497543, test_acc:76.75%.
epoch:15, train_loss:0.28169140839576723, train_acc:86.725%, test_loss:0.39392183035612105, test_acc:79.125%.
epoch:16, train_loss:0.2590124389529228, train_acc:87.47500000000001%, test_loss:0.4415725016593933, test_acc:76.4375%.
epoch:17, train_loss:0.24902698147296906, train_acc:88.02499999999999%, test_loss:0.4105693963170052, test_acc:77.1875%.
epoch:18, train_loss:0.2806972295641899, train_acc:86.375%, test_loss:0.3994844141602516, test_acc:77.0625%.
epoch:19, train_loss:0.2607276937365532, train_acc:87.75%, test_loss:0.4389170655608177, test_acc:76.625%.
'''


'''5db all 
cuda
torch.Size([32, 1, 512]) torch.Size([32, 1, 128, 512]) torch.Size([32, 1, 50, 512])
torch.Size([32, 4])
epoch:0, train_loss:1.1066276550292968, train_acc:57.05%, test_loss:0.6918746972084046, test_acc:68.375%.
epoch:1, train_loss:0.5834570685625077, train_acc:73.175%, test_loss:0.6275622481107712, test_acc:71.5%.
epoch:2, train_loss:0.44816092586517337, train_acc:80.575%, test_loss:0.6425983744859696, test_acc:72.625%.
epoch:3, train_loss:0.3599015097618103, train_acc:84.975%, test_loss:0.5497231543064117, test_acc:76.875%.
epoch:4, train_loss:0.26929483646154406, train_acc:88.3%, test_loss:0.5525931656360626, test_acc:78.4375%.
epoch:5, train_loss:0.24232800471782684, train_acc:89.275%, test_loss:0.6000458815693855, test_acc:79.6875%.
epoch:6, train_loss:0.1820175214111805, train_acc:92.175%, test_loss:0.6462807458639145, test_acc:80.5625%.
epoch:7, train_loss:0.16539679327607154, train_acc:92.75%, test_loss:0.7511023604869842, test_acc:79.6875%.
epoch:8, train_loss:0.1435285525470972, train_acc:93.925%, test_loss:0.7103536489605904, test_acc:81.0%.
epoch:9, train_loss:0.1445681993961334, train_acc:94.39999999999999%, test_loss:0.7002745613455772, test_acc:80.9375%.
epoch:10, train_loss:0.12170579592883587, train_acc:94.89999999999999%, test_loss:0.7566036841273308, test_acc:80.25%.
epoch:11, train_loss:0.13242352199554444, train_acc:95.075%, test_loss:0.8948468044400215, test_acc:79.25%.
epoch:12, train_loss:0.12649951136112214, train_acc:95.125%, test_loss:0.7817598308622837, test_acc:80.125%.
epoch:13, train_loss:0.10589461302012206, train_acc:95.39999999999999%, test_loss:0.8214061388373375, test_acc:80.8125%.
epoch:14, train_loss:0.10617950323969126, train_acc:95.825%, test_loss:0.8325457383692264, test_acc:81.4375%.
epoch:15, train_loss:0.11185296278074383, train_acc:95.95%, test_loss:0.8843921723961831, test_acc:79.5625%.
epoch:16, train_loss:0.09260145926102996, train_acc:96.65%, test_loss:0.9140777304768563, test_acc:80.9375%.
epoch:17, train_loss:0.10112197603471577, train_acc:96.2%, test_loss:1.0915761172771454, test_acc:79.1875%.
epoch:18, train_loss:0.08169100842997432, train_acc:96.675%, test_loss:1.050314729809761, test_acc:79.8125%.
epoch:19, train_loss:0.08243767752125859, train_acc:96.65%, test_loss:1.0692332315444946, test_acc:79.5%.

'''



epoch:0, train_loss:1.2830977203845977, train_acc:61.85000000000001%, test_loss:0.5660697293281555, test_acc:73.5625%.
epoch:1, train_loss:0.5060488374233246, train_acc:77.4%, test_loss:0.5216201359033584, test_acc:78.375%.
epoch:2, train_loss:0.37869855481386183, train_acc:84.425%, test_loss:0.4199406266212463, test_acc:84.375%.
epoch:3, train_loss:0.31354732033610344, train_acc:86.675%, test_loss:0.3621220889687538, test_acc:84.5%.
epoch:4, train_loss:0.22454623532295226, train_acc:90.9%, test_loss:0.27758949115872383, test_acc:88.9375%.
epoch:5, train_loss:0.17842225548624993, train_acc:92.975%, test_loss:0.2912437452375889, test_acc:88.75%.
epoch:6, train_loss:0.14480595825612544, train_acc:94.69999999999999%, test_loss:0.3013531240820885, test_acc:88.9375%.
epoch:7, train_loss:0.09629098856449127, train_acc:96.175%, test_loss:0.2842272785305977, test_acc:90.1875%.
epoch:8, train_loss:0.07206326244026423, train_acc:97.32499999999999%, test_loss:0.27254777505993844, test_acc:89.5%.


'5db all \ncuda\ntorch.Size([32, 1, 512]) torch.Size([32, 1, 128, 512]) torch.Size([32, 1, 50, 512])\ntorch.Size([32, 4])\nepoch:0, train_loss:1.1066276550292968, train_acc:57.05%, test_loss:0.6918746972084046, test_acc:68.375%.\nepoch:1, train_loss:0.5834570685625077, train_acc:73.175%, test_loss:0.6275622481107712, test_acc:71.5%.\nepoch:2, train_loss:0.44816092586517337, train_acc:80.575%, test_loss:0.6425983744859696, test_acc:72.625%.\nepoch:3, train_loss:0.3599015097618103, train_acc:84.975%, test_loss:0.5497231543064117, test_acc:76.875%.\nepoch:4, train_loss:0.26929483646154406, train_acc:88.3%, test_loss:0.5525931656360626, test_acc:78.4375%.\nepoch:5, train_loss:0.24232800471782684, train_acc:89.275%, test_loss:0.6000458815693855, test_acc:79.6875%.\nepoch:6, train_loss:0.1820175214111805, train_acc:92.175%, test_loss:0.6462807458639145, test_acc:80.5625%.\nepoch:7, train_loss:0.16539679327607154, train_acc:92.75%, test_loss:0.7511023604869842, test_acc:79.6875%.\nepoch:8, tr

In [229]:
# torch.save(model, f'model(rcstfrhrrpHH{snr}db).pth')


In [230]:
# load model
# snr = 15
# model = torch.load(f'model(rcstfrhrrpHH{snr}db).pth')
# model.eval()  # 设置模型为评估模式

In [231]:
# test_dataset = Dataset(test_data_path, snr = snr, pad_size = pad_size)
# test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [232]:
# all_preds = []
# all_labels = []
# with torch.no_grad():
#     for x, labels in tqdm(test_loader, desc='predicting...'):
#         x, labels = tensor_process(x, labels)
#         outputs = model(x[0], x[1], x[2])
#         _, preds = torch.max(outputs, 1)
#         all_preds.extend(preds.cpu().numpy())
#         all_labels.extend(labels.cpu().numpy())

In [233]:
# correct_predictions = sum([1 for label, pred in zip(all_labels, all_preds) if label == pred])
# total_predictions = len(all_labels)

# # 计算精度
# accuracy = correct_predictions / total_predictions

# print(f"模型的精度是: {accuracy:.2%}")

In [234]:
# cm = confusion_matrix(all_labels, all_preds)

# disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=range(num_classes))
# disp.plot(cmap=plt.cm.Blues)
# plt.show()