In [2]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import h5py
import json
from matplotlib import pyplot as plt
import torch 
from torch import nn
from torchinfo import summary
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from tqdm import trange, tqdm

In [3]:
# Run on the GPU whenever PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [None]:
# Location of the RadioML 2018 HDF5 file on the local machine.
# Update this path if needed.
# NOTE(review): RadioML18Dataset below spells out the same literal instead of
# reading `file_path`, so editing this variable alone does not change what is loaded.
file_path = '/home/lipplopp/Documents/research/notebook/notebook_1/dataset/radioml2018/versions/2/GOLD_XYZ_OSC.0001_1024.hdf5'

In [4]:
# Number of input channels per frame (the first CNN block below takes 2 input channels).
# NOTE(review): not read by any code visible in this notebook.
n_channels=2
# NOTE(review): the DataLoader cell below passes its own literal batch sizes
# (64 / 128); this value is not read by any code visible in this notebook.
batch_size=32
# Number of frames per snr/modulation combination for train,valid and test data
# (read by RadioML18Dataset to derive the split proportions).
nf_train = 1024
nf_valid = 512
nf_test = 256

In [None]:
def dataset_split(data,
                  modulations_classes,
                  modulations,snrs,
                  target_modulations,
                  mode,
                  target_snrs,train_proportion=0.7,
                  valid_proportion=0.2,
                  test_proportion=0.1,
                  seed=48):
    """Select the train, valid or test frames for the requested modulations and SNRs.

    Every (modulation, SNR) group is shuffled with a generator seeded by
    ``seed`` and cut into consecutive train / valid / test slices (any rest is
    left unused). Because each call starts from the same seed, calling the
    function once per mode yields disjoint splits.

    Args:
        data: frame container indexable by an increasing integer array along
            axis 0 (a NumPy array or an h5py dataset).
        modulations_classes: list of all class names; the position of a name
            is the class index used in ``modulations``.
        modulations: 1-D integer array with the class index of every frame.
        snrs: 1-D array with the SNR of every frame.
        target_modulations: class names to keep.
        mode: ``'train'``, ``'valid'`` or ``'test'``.
        target_snrs: SNR values to keep.
        train_proportion, valid_proportion, test_proportion: fraction of each
            (modulation, SNR) group assigned to the respective split.
        seed: seed of the shuffling generator.

    Returns:
        ``(X, Y, Z)``: the selected frames, their labels and their SNRs.
        Label ``i`` means ``target_modulations[i]``.

    Raises:
        ValueError: if ``mode`` is unknown, a target modulation is not in
            ``modulations_classes``, or nothing was requested.
    """
    split_position = {'train': 0, 'valid': 1, 'test': 2}
    # Validate before doing any work (previously only detected inside the loop).
    if mode not in split_position:
        raise ValueError(f'unknown mode: {mode}. Valid modes are train, valid and test')

    modulations = np.asarray(modulations)
    snrs = np.asarray(snrs)

    # A private legacy generator produces the same shuffles as
    # np.random.seed(seed) + np.random.shuffle, without resetting the
    # process-wide NumPy random state as a side effect.
    rng = np.random.RandomState(seed)

    target_modulation_indices = [modulations_classes.index(modu) for modu in target_modulations]

    X_output = []
    Y_output = []
    Z_output = []

    # `label` is the position inside target_modulations, which is how the rest
    # of the notebook interprets labels (e.g. target_modulations[label]).
    for label, modu in enumerate(target_modulation_indices):
        for snr in target_snrs:
            snr_modu_indices = np.where((modulations == modu) & (snrs == snr))[0]
            rng.shuffle(snr_modu_indices)

            # Cut points follow the actual group size instead of assuming
            # 4096 frames per group.
            group_size = snr_modu_indices.size
            cut_points = [
                int(train_proportion * group_size),
                int((valid_proportion + train_proportion) * group_size),
                int((test_proportion + valid_proportion + train_proportion) * group_size),
            ]
            pieces = np.split(snr_modu_indices, cut_points)

            # Indices are sorted before use: h5py datasets only accept index
            # lists in increasing order.
            chosen = np.sort(pieces[split_position[mode]])
            X_output.append(data[chosen])
            Y_output.append(np.full(chosen.size, label, dtype=modulations.dtype))
            Z_output.append(snrs[chosen])

    if not X_output:
        raise ValueError('target_modulations and target_snrs must both be non-empty')

    X_array = np.vstack(X_output)
    Y_array = np.concatenate(Y_output)
    Z_array = np.concatenate(Z_output)
    return X_array, Y_array, Z_array


In [None]:
class RadioML18Dataset(Dataset):
    """One split (train / valid / test) of the RadioML 2018 HDF5 dataset.

    The constructor opens the HDF5 file, reads the class list from the JSON
    file, and materialises the frames of the requested split in memory through
    ``dataset_split``. The split sizes come from the notebook-level constants
    ``nf_train``, ``nf_valid`` and ``nf_test`` (frames per modulation/SNR
    combination).

    Args:
        mode: ``'train'``, ``'valid'`` or ``'test'``.
        seed: seed forwarded to ``dataset_split``; use the same value for all
            three modes so the splits stay disjoint.
        data_path: HDF5 file with the ``X``, ``Y`` and ``Z`` datasets; defaults
            to ``DEFAULT_DATA_PATH``.
        classes_path: JSON file holding the list of class names; defaults to
            ``DEFAULT_CLASSES_PATH``.

    Each item is ``(x, y, z)``: the frame as a tensor with its first two axes
    swapped, the label (index into ``target_modulations``) and the SNR.
    """

    # Locations used when no explicit path is passed to the constructor.
    DEFAULT_DATA_PATH = "/home/lipplopp/Documents/research/notebook/notebook_1/dataset/radioml2018/versions/2/GOLD_XYZ_OSC.0001_1024.hdf5"
    DEFAULT_CLASSES_PATH = "/home/lipplopp/Documents/research/notebook/notebook_1/dataset/radioml2018/versions/2/classes-fixed.json"

    def __init__(self, mode: str,seed=48, data_path=None, classes_path=None):
        super(RadioML18Dataset, self).__init__()

        if data_path is None:
            data_path = self.DEFAULT_DATA_PATH
        if classes_path is None:
            classes_path = self.DEFAULT_CLASSES_PATH

        # load data
        # The handle is kept because self.X stays a lazy HDF5 dataset (other
        # cells read self.X.shape); call close() to release it explicitly.
        self._hdf5_file = h5py.File(data_path, 'r')
        with open(classes_path, 'r') as classes_file:
            self.modulation_classes = json.load(classes_file)
        self.X = self._hdf5_file['X']
        self.Y = np.argmax(self._hdf5_file['Y'], axis=1)  # one-hot -> class index
        self.Z = self._hdf5_file['Z'][:, 0]

        # in this line i could change it the target modulation
        self.target_modulations =['OOK', '4ASK', 'BPSK', 'QPSK', '8PSK',
        '16QAM', 'AM-SSB-SC', 'AM-DSB-SC', 'FM', 'GMSK','OQPSK']

        self.target_snrs = np.unique(self.Z)

        # Number of (modulation, SNR) combinations in the file, derived from
        # the data (width of the one-hot labels x distinct SNRs) rather than
        # the literal 24*26.
        n_groups = self._hdf5_file['Y'].shape[1] * self.target_snrs.shape[0]
        n_frames = self.X.shape[0]
        train_proportion = (n_groups * nf_train) / n_frames
        valid_proportion = (n_groups * nf_valid) / n_frames
        test_proportion = (n_groups * nf_test) / n_frames

        self.X_data, self.Y_data, self.Z_data = dataset_split(
            data = self.X,
            modulations_classes = self.modulation_classes,
            modulations = self.Y,
            snrs = self.Z,
            mode = mode,
            train_proportion = train_proportion,
            valid_proportion = valid_proportion,
            test_proportion = test_proportion,
            target_modulations = self.target_modulations,
            target_snrs  = self.target_snrs,
            seed = seed,  # previously fixed to 48, ignoring the argument
        )

        # store statistic of whole dataset
        self.num_data = self.X_data.shape[0]
        self.num_lbl = len(self.target_modulations)
        self.num_snr = self.target_snrs.shape[0]

    def close(self):
        """Close the HDF5 file; ``self.X`` must not be used afterwards."""
        hdf5_file = getattr(self, '_hdf5_file', None)
        if hdf5_file is not None:
            hdf5_file.close()
            self._hdf5_file = None

    def __len__(self):
        return self.X_data.shape[0]

    def __getitem__(self, idx):
        x,y,z = self.X_data[idx], self.Y_data[idx], self.Z_data[idx]
        # Swap the first two axes of the frame before handing it to the model.
        x,y,z = torch.Tensor(x).transpose(0, 1) , y , z
        return x,y,z

In [None]:
# Build one split only to read the dataset-level statistics, then drop it again.
ds = RadioML18Dataset(mode='test')
data_len, n_labels, n_snrs = ds.num_data, ds.num_lbl, ds.num_snr
frame_size = ds.X.shape[1]  # second axis of the raw HDF5 frame array
del ds

In [None]:
import time

# Build the three loaders and report how long loading the splits takes.
st = time.time()
train_dl = DataLoader(dataset=RadioML18Dataset(mode='train'),batch_size=64, shuffle=True, drop_last=True)
# Validation has to read the held-out 'valid' split; it was previously built
# from mode='train', which made the validation metrics a second look at the
# training data.
valid_dl = DataLoader(dataset=RadioML18Dataset(mode='valid'),batch_size=128, shuffle=False, drop_last=False)
test_dl = DataLoader(dataset=RadioML18Dataset(mode='test'),batch_size=128, shuffle=False, drop_last=False)
et = time.time()
elapsed_time = et - st
print('Execution time:', elapsed_time, 'seconds')

In [None]:
class CNN_Block(nn.Module):
    """One convolutional stage: Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d -> Dropout.

    The convolution (kernel 3, padding 1) keeps the sequence length, the
    pooling (kernel 2, stride 2) halves it, and dropout uses p=0.25.

    Args:
        input_shape: number of input channels.
        output_shape: number of output channels.
    """

    def __init__(self,input_shape,output_shape):
        super(CNN_Block, self).__init__()
        stages = [
            nn.Conv1d(in_channels=input_shape, out_channels=output_shape, kernel_size=3, padding=1),
            nn.BatchNorm1d(num_features=output_shape),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Dropout(p=0.25),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self,x):
        out = self.net(x)
        return out
    
class CNN_NET(nn.Module):
    """Four stacked ``CNN_Block`` stages followed by a three-layer classifier.

    Args:
        num_classes: number of output logits. ``None`` (default) keeps the
            previous behaviour of reading the notebook global ``n_labels``.
        frame_size: number of samples per input frame. Each of the four blocks
            halves the length, so the classifier receives
            ``48 * (frame_size // 16)`` features; the default 1024 gives the
            3072 that was hard-coded before.

    Raises:
        ValueError: if ``frame_size`` is too short to survive four poolings.
    """

    def __init__(self, num_classes=None, frame_size=1024):
        super().__init__()
        if num_classes is None:
            num_classes = n_labels  # notebook-level value set in an earlier cell

        # Four halvings with floor rounding are the same as one floor division by 16.
        flat_features = 48 * (frame_size // 16)
        if flat_features <= 0:
            raise ValueError(f'frame_size must be at least 16, got {frame_size}')

        self.backbone=nn.Sequential(
        CNN_Block(2,24),
        CNN_Block(24,24),
        CNN_Block(24,48),
        CNN_Block(48,48),
        )
        self.classifier=nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features,128),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(128,128),
            nn.ReLU(),
            nn.Linear(128,num_classes)
        )

    def forward(self,x):
        return self.classifier(self.backbone(x))
    


In [None]:
def train_model(model,verbose=True,device='cuda',num_epoch=30):
    """Train ``model`` on ``train_dl`` and evaluate it on ``valid_dl`` after every epoch.

    Uses Adam (lr 1e-4) with an exponential learning-rate decay of 0.9 per
    epoch and cross-entropy loss. The loaders are the notebook-level
    ``train_dl`` and ``valid_dl`` and must yield ``(x, y, z)`` batches.

    Args:
        model: network mapping a batch ``x`` to class logits.
        verbose: when true, write the metrics of every epoch through tqdm.
        device: device the model and the batches are moved to.
        num_epoch: number of epochs.

    Returns:
        ``(model, [train_loss, train_acc, val_loss, val_acc])`` where each
        history entry is a tensor of length ``num_epoch``. Losses and
        accuracies are averages over samples; accuracies lie in [0, 1].
    """
    model.to(device)

    train_loss = torch.zeros(num_epoch)
    train_acc = torch.zeros(num_epoch)
    val_loss = torch.zeros(num_epoch)
    val_acc = torch.zeros(num_epoch)

    lr = 1e-4
    optimizer = optim.Adam(list(model.parameters()), lr=lr)
    lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer , 0.9 )
    criterion = nn.CrossEntropyLoss()

    def _run_epoch(loader, training):
        """Run one pass over ``loader``; return (mean loss, accuracy) over samples."""
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0
        for x, y, _ in loader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Weight every batch by its size so a shorter last batch (the
            # validation loader keeps it) cannot skew the averages. The old
            # code divided by len(dataset)//batch_size, which undercounts the
            # batches whenever a partial batch exists.
            n_batch = y.shape[0]
            loss_sum += loss.item() * n_batch
            n_correct += (torch.argmax(logits, dim=-1) == y).sum().item()
            n_seen += n_batch
        if n_seen == 0:
            return float('nan'), float('nan')
        return loss_sum / n_seen, n_correct / n_seen

    for epoch in trange(num_epoch, unit='epochs'):
        # Training phase
        model.train()
        train_loss[epoch], train_acc[epoch] = _run_epoch(train_dl, training=True)

        # Evaluation phase
        model.eval()
        with torch.no_grad():
            val_loss[epoch], val_acc[epoch] = _run_epoch(valid_dl, training=False)

        lr_scheduler.step()
        if verbose:
            tqdm.write('Epoch {} (train) -- loss: {:.4f} accuracy: {:.4f}'.format(epoch, train_loss[epoch], train_acc[epoch]))
            tqdm.write('Epoch {} (valid) -- loss: {:.4f} accuracy: {:.4f}'.format(epoch, val_loss[epoch], val_acc[epoch]))
    return model,[train_loss,train_acc,val_loss,val_acc]

In [None]:
def test_model(model, device='cuda'):
    model.eval()
    Y_pred_ = []  # Predictions
    Y_true_ = []  # Ground truth
    Z_snr_ = []   # SNR values
    
    target_classes = test_dl.dataset.target_modulations
    target_snrs = test_dl.dataset.target_snrs
    modulation_classes = test_dl.dataset.modulation_classes
    target_modulations_indices = [modulation_classes.index(mod) for mod in target_classes]
    
    # Initialize accuracy stats DataFrame
    accuracy_stats = pd.DataFrame(
        0.0,
        index=target_classes,
        columns=target_snrs.astype('str'))
    
    # Get predictions
    with torch.no_grad():
        for x, y, z in test_dl:
            # Move tensors to specified device
            x = x.to(device)
            y = y.to(device)
            z = z.to(device)
            
            # Get model predictions on device
            logits = model(x)
            y_pred = torch.argmax(logits, dim=-1)
            
            # Store results
            Y_pred_.append(y_pred.cpu())  # Move back to CPU for storage
            Y_true_.append(y.cpu())
            Z_snr_.append(z.cpu())
    
    # Convert to numpy for easier processing
    Y_pred = torch.cat(Y_pred_).numpy()
    Y_true = torch.cat(Y_true_).numpy()
    Z_snr = torch.cat(Z_snr_).numpy()
    
    # Calculate overall accuracy
    correct_preds = (Y_pred == Y_true).sum()
    total_samples = len(Y_true)
    total_accuracy = round(correct_preds * 100 / total_samples, 2)
    print(f'Accuracy on test dataset: {total_accuracy}%')
    
    # Map indices back to original modulation classes if needed
    for index, value in enumerate(target_modulations_indices):
        Y_pred[Y_pred == index] = value
        Y_true[Y_true == index] = value
    
    # Calculate accuracy per modulation and SNR
    for modu in target_modulations_indices:
        mod_class = modulation_classes[modu]
        for snr in target_snrs:
            snr_str = str(snr)
            
            # Find samples for this modulation and SNR
            mask = (Y_true == modu) & (Z_snr == snr)
            total_samples = mask.sum()
            
            if total_samples > 0:
                # Count correct predictions
                correct_samples = ((Y_pred == Y_true) & mask).sum()
                
                # Calculate and store accuracy percentage
                accuracy = (correct_samples * 100 / total_samples)
                accuracy_stats.loc[mod_class, snr_str] = round(accuracy, 2)
    
    return accuracy_stats

def plot_training_history(model_name:str, history:list):
    """Draw the four training curves of one run on a single figure.

    Args:
        model_name: name shown in the figure title.
        history: ``[train_loss, train_accuracy, valid_loss, valid_accuracy]``,
            each a per-epoch sequence.
    """
    curve_labels = ('train_loss', 'train_accuracy', 'valid_loss', 'valid_accuracy')

    plt.figure(figsize=(10, 6))
    plt.title(f'Training of {model_name} model on radioml2018')
    plt.xlabel('Epochs')
    for position, curve_label in enumerate(curve_labels):
        plt.plot(history[position], label=curve_label)
    plt.legend(loc="upper left")
    plt.show()
    
def plot_test_accuracy(model, device='cuda'):
    """Evaluate ``model`` with ``test_model`` and draw one accuracy bar chart per modulation.

    The charts are stacked vertically, share both axes, and show accuracy in
    percent against the SNR labels.

    Returns:
        The accuracy DataFrame produced by ``test_model``.
    """
    accuracy_df = test_model(model, device)
    n_modulations = len(test_dl.dataset.target_modulations)

    fig, axes = plt.subplots(n_modulations, 1, figsize=(12, 8), sharex=True, sharey=True)
    fig.subplots_adjust(hspace=0.4)
    fig.supylabel('Accuracy (%)')
    fig.supxlabel('Signal to noise ratios (dB)')

    # With a single row plt.subplots hands back a bare Axes, not an array.
    if n_modulations == 1:
        axes = [axes]

    for row_position, ax in enumerate(axes):
        row = accuracy_df.iloc[row_position]
        ax.set_title(accuracy_df.index[row_position])
        ax.bar(row.index, row.values)
        ax.set_ylim(0, 100)  # accuracy is a percentage

    plt.tight_layout()
    plt.show()

    return accuracy_df

def train_test_plots(model, model_name, verbose=False, device='cuda', num_epoch=30):
    """Train ``model``, save it to ``<model_name>.pth``, and plot training and test results.

    Returns:
        ``(train_history, accuracy_results)``: the history list from
        ``train_model`` and the accuracy DataFrame from ``plot_test_accuracy``.
    """
    trained_model, train_history = train_model(
        model, verbose=verbose, device=device, num_epoch=num_epoch
    )
    # The whole module object is pickled here, not only its state_dict.
    torch.save(trained_model, f'{model_name}.pth')
    plot_training_history(model_name, train_history)
    return train_history, plot_test_accuracy(trained_model, device)

In [None]:
def train_test_plots(model,model_name,verbose=False,device='cuda',num_epoch=30):
    """Train ``model``, save it to ``<model_name>.pth``, and plot training and test results.

    NOTE(review): this cell redefines ``train_test_plots`` from the previous
    cell, so it is the version that actually runs. It now matches that
    earlier definition: the evaluation uses the same ``device`` as training
    (it used to fall back to the default ``'cuda'``) and the results are
    returned instead of being discarded.

    Returns:
        ``(train_history, accuracy_results)``: the history list from
        ``train_model`` and the accuracy DataFrame from ``plot_test_accuracy``.
    """
    model, train_history = train_model(model,verbose=verbose,device=device,num_epoch=num_epoch)
    # The whole module object is pickled here, not only its state_dict.
    torch.save(model,f'{model_name}.pth')
    plot_training_history(model_name,train_history)
    accuracy_results = plot_test_accuracy(model, device)
    return train_history, accuracy_results

In [None]:
# Train the baseline CNN. Whatever train_test_plots returns is kept in a
# variable so it stays available to later cells and is not echoed as cell output.
cnn_net_results = train_test_plots(CNN_NET(),'CNN_NET',verbose=True)

In [None]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns

def test_model_with_improved_plots(model, device='cuda'):
    """
    Enhanced version of test_model that properly tracks all modulations
    and produces better visualizations
    """
    model.eval()
    Y_pred_ = []  # Predictions
    Y_true_ = []  # Ground truth
    Z_snr_ = []   # SNR values
    
    test_dl.dataset.target_modulations
    target_classes = test_dl.dataset.target_modulations
    target_snrs = test_dl.dataset.target_snrs
    modulation_classes = test_dl.dataset.modulation_classes
    
    # Print debug info about target modulations
    print(f"Target modulations: {target_classes}")
    print(f"Target SNRs: {target_snrs}")
    
    # Initialize accuracy stats DataFrame
    accuracy_stats = pd.DataFrame(
        0.0,
        index=target_classes,
        columns=target_snrs.astype('str'))
    
    # Get predictions
    with torch.no_grad():
        for x, y, z in test_dl:
            # Move tensors to specified device
            x = x.to(device)
            y = y.to(device)
            z = z.to(device)
            
            # Get model predictions on device
            logits = model(x)
            y_pred = torch.argmax(logits, dim=-1)
            
            # Store results
            Y_pred_.append(y_pred.cpu())  # Move back to CPU for storage
            Y_true_.append(y.cpu())
            Z_snr_.append(z.cpu())
    
    # Convert to numpy for easier processing
    Y_pred = torch.cat(Y_pred_).numpy()
    Y_true = torch.cat(Y_true_).numpy()
    Z_snr = torch.cat(Z_snr_).numpy()
    
    # Calculate overall accuracy
    correct_preds = (Y_pred == Y_true).sum()
    total_samples = len(Y_true)
    total_accuracy = round(correct_preds * 100 / total_samples, 2)
    print(f'Overall accuracy on test dataset: {total_accuracy}%')
    
    # Count samples for each modulation type
    mod_counts = {}
    for mod_idx, mod_name in enumerate(target_classes):
        count = np.sum(Y_true == mod_idx)
        mod_counts[mod_name] = count
        print(f"Modulation {mod_name}: {count} test samples")
    
    # Calculate accuracy per modulation and SNR
    for mod_idx, mod_name in enumerate(target_classes):
        for snr_idx, snr in enumerate(target_snrs):
            snr_str = str(snr)
            
            # Find samples for this modulation and SNR
            mask = (Y_true == mod_idx) & (Z_snr == snr)
            total_samples = mask.sum()
            
            if total_samples > 0:
                # Count correct predictions
                correct_samples = ((Y_pred == Y_true) & mask).sum()
                
                # Calculate and store accuracy percentage
                accuracy = (correct_samples * 100 / total_samples)
                accuracy_stats.loc[mod_name, snr_str] = round(accuracy, 2)
            else:
                # Mark as NaN if there are no samples
                accuracy_stats.loc[mod_name, snr_str] = np.nan
                print(f"Warning: No samples for {mod_name} at SNR={snr}")
    
    return accuracy_stats

def plot_improved_test_accuracy(model, device='cuda'):
    """
    Evaluate the model and draw three views of its per-modulation accuracy.

    The accuracy table comes from test_model_with_improved_plots(model, device).
    Three figures are shown and each is also saved as a PNG in the current
    working directory:

    1. all_modulations_accuracy.png - accuracy vs SNR, one line per modulation,
       with modulations whose name contains 'PSK' redrawn as dashed lines
    2. modulation_accuracy_subplots.png - one bar chart per modulation, three per row
    3. modulation_accuracy_heatmap.png - modulation x SNR heatmap

    Returns the accuracy DataFrame.
    """
    accuracy_df = test_model_with_improved_plots(model, device)
    
    # 1. Single plot with all modulation types
    plt.figure(figsize=(14, 8))
    
    # Convert to long format for seaborn
    # (the unnamed index becomes a column called 'index' after reset_index)
    accuracy_long = accuracy_df.reset_index().melt(
        id_vars=['index'], 
        var_name='SNR', 
        value_name='Accuracy'
    )
    accuracy_long.columns = ['Modulation', 'SNR', 'Accuracy']
    
    # Convert SNR to numeric for proper ordering
    accuracy_long['SNR'] = accuracy_long['SNR'].astype(float)
    
    # Plot all modulations
    sns.lineplot(
        data=accuracy_long, 
        x='SNR', 
        y='Accuracy', 
        hue='Modulation',
        marker='o',
        markersize=8,
        linewidth=2
    )
    
    # Specifically highlight PSK modulations
    psk_mods = [mod for mod in accuracy_df.index if 'PSK' in mod]
    if psk_mods:
        print(f"Highlighting PSK modulations: {psk_mods}")
        psk_df = accuracy_long[accuracy_long['Modulation'].isin(psk_mods)]
        
        # Use a separate plot command with larger linewidth to highlight PSK
        # (these overlay lines carry no label, so they add no legend entries)
        for mod in psk_mods:
            mod_data = psk_df[psk_df['Modulation'] == mod]
            plt.plot(mod_data['SNR'], mod_data['Accuracy'], 
                     linewidth=3.5, 
                     linestyle='--',
                     marker='*', 
                     markersize=12)
    
    plt.title('Classification Accuracy vs SNR for Different Modulation Types', fontsize=16)
    plt.xlabel('Signal-to-Noise Ratio (dB)', fontsize=14)
    plt.ylabel('Accuracy (%)', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig('all_modulations_accuracy.png', dpi=300)
    plt.show()
    
    # 2. Individual subplots for each modulation
    # Find how many rows we need (assuming 3 plots per row)
    n_rows = (len(accuracy_df.index) + 2) // 3  # Ceiling division
    
    fig, axes = plt.subplots(n_rows, 3, figsize=(18, n_rows*4), sharey=True)
    axes = axes.flatten()  # Flatten to make indexing easier
    
    # Hide unused subplots
    for i in range(len(accuracy_df.index), len(axes)):
        axes[i].set_visible(False)
    
    # Plot each modulation in its own subplot
    for i, mod in enumerate(accuracy_df.index):
        ax = axes[i]
        
        # Get data for this modulation
        mod_data = accuracy_df.loc[mod].astype(float)
        
        # Plot bar chart
        # (the x values are the DataFrame's column labels; PSK modulations are drawn in red)
        ax.bar(mod_data.index, mod_data.values, color='skyblue' if 'PSK' not in mod else 'red')
        
        # Add line for trend
        ax.plot(mod_data.index, mod_data.values, 'k--', linewidth=2)
        
        # Styling
        # x label only for the last three modulations, y label only in the first column
        ax.set_title(f'{mod}', fontsize=14)
        ax.set_xlabel('SNR (dB)' if i >= len(accuracy_df.index) - 3 else '')
        ax.set_ylabel('Accuracy (%)' if i % 3 == 0 else '')
        ax.set_ylim(0, 105)  # Set y-axis from 0 to 100% with a bit of margin
        ax.grid(True, alpha=0.3)
        
        # Rotate x-axis labels for better readability
        plt.setp(ax.get_xticklabels(), rotation=45)
    
    plt.suptitle('Classification Accuracy by Modulation Type', fontsize=18)
    plt.tight_layout()
    plt.subplots_adjust(top=0.95)  # Make room for suptitle
    plt.savefig('modulation_accuracy_subplots.png', dpi=300)
    plt.show()
    
    # 3. Also create a heatmap visualization
    plt.figure(figsize=(14, 8))
    sns.heatmap(accuracy_df.astype(float), annot=True, cmap='viridis', fmt='.1f', 
                cbar_kws={'label': 'Accuracy (%)'})
    plt.title('Classification Accuracy Heatmap by Modulation and SNR', fontsize=16)
    plt.xlabel('Signal-to-Noise Ratio (dB)', fontsize=14)
    plt.ylabel('Modulation Type', fontsize=14)
    plt.tight_layout()
    plt.savefig('modulation_accuracy_heatmap.png', dpi=300)
    plt.show()
    
    return accuracy_df


def check_dataset_distribution(test_dl):
    """Count the samples of ``test_dl.dataset`` per modulation and per SNR.

    Every item of the dataset is read once as ``(_, label, snr)``; the label is
    looked up in ``dataset.target_modulations``. A summary is printed, with a
    per-SNR breakdown for the modulations whose name contains 'PSK'.

    Returns:
        ``(mod_counts, snr_mod_counts)``: samples per modulation name, and a
        nested dict ``{snr: {modulation name: count}}``.
    """
    dataset = test_dl.dataset
    modulation_names = dataset.target_modulations

    # Every target modulation is reported, even with zero samples.
    mod_counts = {name: 0 for name in modulation_names}
    snr_mod_counts = {}

    for sample_position in range(len(dataset)):
        _, mod_idx, snr = dataset[sample_position]
        name = modulation_names[mod_idx]

        mod_counts[name] += 1

        per_snr = snr_mod_counts.setdefault(snr, {})
        per_snr[name] = per_snr.get(name, 0) + 1

    print("Modulation distribution in dataset:")
    for name, count in mod_counts.items():
        print(f"  {name}: {count} samples")

    # Per-SNR breakdown for the PSK family only.
    print("\nPSK modulation distribution:")
    for name in modulation_names:
        if 'PSK' not in name:
            continue
        print(f"\n{name} distribution across SNRs:")
        for snr in sorted(snr_mod_counts):
            print(f"  SNR {snr}dB: {snr_mod_counts[snr].get(name, 0)} samples")

    return mod_counts, snr_mod_counts


# Enhanced version of train_test_plots that includes improved plotting
def improved_train_test_plots(model, model_name, verbose=False, device='cuda', num_epoch=30):
    """Report the test-set distribution, train ``model``, save it, and plot the results.

    The trained model is written to ``<model_name>.pth``.

    Returns:
        ``(model, train_history, accuracy_results)``: the trained model, the
        history list from ``train_model`` and the accuracy DataFrame from
        ``plot_improved_test_accuracy``.
    """
    # Distribution report first, so a skewed test set is visible before training.
    print("Analyzing test dataset distribution...")
    check_dataset_distribution(test_dl)

    print(f"\nTraining {model_name}...")
    model, train_history = train_model(
        model, verbose=verbose, device=device, num_epoch=num_epoch
    )
    # The whole module object is pickled here, not only its state_dict.
    torch.save(model, f'{model_name}.pth')

    plot_training_history(model_name, train_history)

    print("\nEvaluating and plotting test accuracy...")
    return model, train_history, plot_improved_test_accuracy(model, device)