In [1]:
import scipy.io
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.utils import prune

import numpy as np
import matplotlib.pyplot as plt
import copy

plt.rcParams['font.family'] = 'Bookman Old Style'


class dataset:
    def __init__(self, x, y, num_training_points=5000, num_memory_levels =3):
        self.num_training_points = num_training_points
        self.num_memory_levels = num_memory_levels # If there are 0 memory taps, this is 1
        self.x = x
        self.y = y

        self.model_training_input, self.model_training_output, self.training_phase, self.model_valid_input, self.model_valid_output, self.valid_phase = self.prepare_data()

    def phase_vector(self, x):    
        """Takes a vector x and returns a vector of phases of each element"""
        Ax = np.abs(x)
        return np.conj(x)/Ax

    def model_expected_output(self, y, phase):
        """Take in data, phase normalised it, and trim. Return as IQ seperately"""
        y_denorm = y*phase
        y_denorm_trim = y_denorm[self.num_memory_levels:]
        return np.array([np.real(y_denorm_trim), np.imag(y_denorm_trim)]).T

    def build_xfc(self, x, num_memory_levels):
        """
        Replicates the MATLAB build_xfc() function.
        """
        num_points = len(x)
        phase = self.phase_vector(x)
        I = np.real(x)
        Q = np.imag(x)

        # Phase-normalized data
        phase_norm_data = np.zeros((num_points, num_memory_levels), dtype=complex)
        for n in range(num_memory_levels, num_points):
            for m in range(num_memory_levels):
                phase_norm_data[n, m] = x[n - m - 1] * phase[n]

        # Ax magnitude feature
        Ax = np.sqrt(I**2 + Q**2)

        # Build A feature matrix (Ax memory taps)
        A_feats = np.zeros((num_points, num_memory_levels))
        for n in range(num_memory_levels, num_points):
            for m in range(num_memory_levels):
                A_feats[n, m] = Ax[n - m]

        # Trim first num_memory_levels samples (as in MATLAB)
        phase_norm_data = phase_norm_data[num_memory_levels:, :]
        A_feats = A_feats[num_memory_levels:, :]
        A3_feats = A_feats ** 3

        # Combine real and imaginary phase-normalized parts with A-features
        xfc = np.hstack([
            np.real(phase_norm_data),
            np.imag(phase_norm_data),
            A_feats,
            A3_feats
        ]).astype(np.float32)

        return xfc

    def build_x_matrix(self, x, num_mem_levels, num_nl_orders):
        """Build Matrix X for find Volterra Model"""
        num_points = len(x)
        X = np.zeros((num_points, num_mem_levels * num_nl_orders), dtype=np.complex128)
        
        for n in range(num_mem_levels - 1, num_points):
            col = 0
            for i in range(num_mem_levels):
                xi = x[n - i]
                for j in range(num_nl_orders):
                    X[n, col] = (abs(xi) ** ((j) * 2)) * xi
                    col += 1

        return X

    def build_y(self, u, A, num_mem_levels, num_nl_orders):
        """Builds y, the output of the volterra Model. Trims Output"""
        num_points = len(u)
        y = np.zeros((num_points, 1), dtype=np.complex128)
        for n in range(num_mem_levels - 1, num_points):
            col = 0 
            for i in range(num_mem_levels):
                ui = u[n-i]
                for j in range(num_nl_orders):
                    y[n]= y[n] + A[col]*(abs(ui)**(j*2)*ui)
                    col += 1
        y = y[self.num_memory_levels:]
        return y
            

    def volterra(self,num_nl_orders):
        """Build component matrix A"""
        X = self.build_x_matrix(self.model_training_input, self.num_memory_levels, num_nl_orders)
        
        X_trim = X[self.num_memory_levels:, :]
        y_trim = self.model_training_output[self.num_memory_levels:]
        return np.linalg.pinv(X_trim.conj().T @ X_trim) @ (X_trim.conj().T @ y_trim);

    def training_data(self):
        """Assign some data for just training"""
        idx_training = range(0, self.num_training_points -1) # training indices

        model_training_input = self.y[idx_training]
        model_training_output = self.x[idx_training]
        return model_training_input, model_training_output
    

    def validation_data(self):
        """Assign some data for just validation"""
        num_validation_points = self.num_training_points 
        validation_end_index = self.num_training_points + num_validation_points
        idx_validation = range(self.num_training_points, validation_end_index -1) # validation indices
        model_valid_input = self.y[idx_validation]
        model_valid_output = self.x[idx_validation]
        return model_valid_input, model_valid_output

    def prepare_data(self):
        """Prepare training and validation data sets"""
        # Training Data
        model_training_input, model_training_output = self.training_data()
        training_phase = self.phase_vector(model_training_input)

        # Validation Data
        model_valid_input, model_valid_output = self.validation_data()
        valid_phase = self.phase_vector(model_valid_input)
        return (model_training_input, model_training_output, training_phase,
                model_valid_input, model_valid_output, valid_phase)

    def get_model_training_xfc(self):
        """Build xfc from model training input"""
        return self.build_xfc(self.model_training_input, self.num_memory_levels)
    
    def get_model_training_expected_output(self):
        """Find what the model should output for training data"""
        return self.model_expected_output(self.model_training_output, self.training_phase)
    
    def get_valid_xfc(self):
        """Build xfc from model validation input"""
        return self.build_xfc(self.model_valid_input, self.num_memory_levels)
    
    def get_model_valid_expected_output(self):
        """Find what the model should output for validation data"""
        return self.model_expected_output(self.model_valid_output, self.valid_phase)
    
    def get_test_data(self):
        """Get test data after validation"""
        num_validation_points = self.num_training_points 
        validation_end_index = self.num_training_points + num_validation_points
        x_data = self.x[validation_end_index:]
        y_data = self.y[validation_end_index:]
        return x_data, y_data


# Load data
data = scipy.io.loadmat("PA_IO.mat")
x = data["x"].squeeze()
y = data["y"].squeeze()

# Create dataset object
data_obj = dataset(x, y)

# Access training data
model_xfc = data_obj.get_model_training_xfc()
model_training_expected_output = data_obj.get_model_training_expected_output()

# Access validation data
valid_xfc = data_obj.get_valid_xfc()
model_valid_expected_output = data_obj.get_model_valid_expected_output()

# Access test data
x_data, y_data = data_obj.get_test_data()

In [2]:
# Get Volterra Model of PA
num_memory_levels = 3
num_nl_orders = 5
A = data_obj.volterra(num_nl_orders)
print(A)

[ 0.63146409-0.11924996j  0.1694142 -0.01547961j -0.29918448+0.23918013j
  1.27680917-1.02310172j -0.87359925+0.71303688j  0.12086561+0.23694819j
  0.01362539+0.09275574j -0.09982292-0.22003663j  0.45094657+0.5524524j
 -0.35963882-0.32360673j -0.06705462-0.13809442j -0.02614158-0.01855645j
  0.16654034-0.04988737j -0.48981135-0.01449885j  0.34387347+0.01718324j]


In [6]:
# Train NN on backprop PA for inv model

class PNTDNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(PNTDNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

class NN:
    def __init__(self, pntdnn):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pntdnn = pntdnn
        
    def build_dataloaders(self, x , y):
        X = torch.tensor(x, dtype=torch.float32)
        Y = torch.tensor(y, dtype=torch.float32)
        dataset = TensorDataset(X, Y)
        loader = DataLoader(dataset, batch_size=256, shuffle=True)
        return loader

    
    def get_best_model(self, train_loader, valid_loader, num_epochs=400, learning_rate=1e-3):
        """Train model and return the best model based on validation loss"""
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.pntdnn.parameters(), lr=learning_rate)
        
        train_losses = []
        valid_losses = []
        best_valid_loss = float('inf')
        best_model_state = None
        best_epoch = 0
        
        for epoch in range(num_epochs):
            self.pntdnn.train()
            running_train_loss = 0
            running_valid_loss = 0
            
            for xb, yb in train_loader:
                optimizer.zero_grad()
                preds = self.pntdnn(xb)
                loss = criterion(preds, yb)
                loss.backward()
                optimizer.step()
                running_train_loss += loss.item() * xb.size(0)
                
            train_loss = running_train_loss
            
            self.pntdnn.eval()
            with torch.no_grad():
                for xb, yb in valid_loader:
                    preds = self.pntdnn(xb)
                    loss = criterion(preds, yb)
                    running_valid_loss += loss.item() * xb.size(0)
                
            valid_loss = running_valid_loss
            
            train_losses.append(train_loss)
            valid_losses.append(valid_loss)
            
            # Save best model
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                best_model_state = copy.deepcopy(self.pntdnn.state_dict())
                best_epoch = epoch + 1
            
            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch + 1:3d}/{num_epochs}  Loss={train_loss:.4e}  Valid Loss={valid_loss:.4e}")
        
        # Load best model
        self.pntdnn.load_state_dict(best_model_state)
        print(f"\nBest model from epoch {best_epoch} with validation loss: {best_valid_loss:.4e}")
        
        return train_losses, valid_losses, best_epoch

    def prune_model(self, prune_amount=0.2):
        pruned_model = copy.deepcopy(self.pntdnn)
        parameters_to_prune = (
            (pruned_model.fc1, 'weight'),
            (pruned_model.fc2, 'weight'),
        )
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=prune_amount,
        )
        return pruned_model
    
    def calculate_nmse(self, x, y):
        """Return NMSE in dB"""
        self.pntdnn.eval()
        with torch.no_grad():
            inputs = torch.tensor(x, dtype=torch.float32)
            targets = torch.tensor(y, dtype=torch.float32)
            outputs = self.pntdnn(inputs)
            mse_loss = nn.MSELoss()(outputs, targets).item()
            signal_power = torch.mean(targets ** 2).item()
            nmse = mse_loss / signal_power
            nmse = 10 * np.log10(nmse)
        return nmse
    

# Instantiate and train the model
input_size = model_xfc.shape[1]
hidden_size = 12
output_size = 2
pntdnn = PNTDNN(input_size, hidden_size, output_size)
nn_model = NN(pntdnn)
train_loader = nn_model.build_dataloaders(model_xfc, model_training_expected_output)
valid_loader = nn_model.build_dataloaders(valid_xfc, model_valid_expected_output)
train_losses, valid_losses, best_epoch = nn_model.get_best_model(train_loader, valid_loader, num_epochs=2000)
print(nn_model.calculate_nmse(model_xfc, model_training_expected_output))



Epoch  10/2000  Loss=1.1938e+01  Valid Loss=1.2860e+01
Epoch  20/2000  Loss=9.2629e-01  Valid Loss=9.7084e-01
Epoch  20/2000  Loss=9.2629e-01  Valid Loss=9.7084e-01
Epoch  30/2000  Loss=4.5922e-01  Valid Loss=5.3578e-01
Epoch  30/2000  Loss=4.5922e-01  Valid Loss=5.3578e-01
Epoch  40/2000  Loss=3.0070e-01  Valid Loss=3.5869e-01
Epoch  40/2000  Loss=3.0070e-01  Valid Loss=3.5869e-01
Epoch  50/2000  Loss=2.1274e-01  Valid Loss=2.5750e-01
Epoch  50/2000  Loss=2.1274e-01  Valid Loss=2.5750e-01
Epoch  60/2000  Loss=1.6398e-01  Valid Loss=1.9687e-01
Epoch  60/2000  Loss=1.6398e-01  Valid Loss=1.9687e-01
Epoch  70/2000  Loss=1.3523e-01  Valid Loss=1.6322e-01
Epoch  70/2000  Loss=1.3523e-01  Valid Loss=1.6322e-01
Epoch  80/2000  Loss=1.2165e-01  Valid Loss=1.4571e-01
Epoch  80/2000  Loss=1.2165e-01  Valid Loss=1.4571e-01
Epoch  90/2000  Loss=1.1374e-01  Valid Loss=1.3590e-01
Epoch  90/2000  Loss=1.1374e-01  Valid Loss=1.3590e-01
Epoch 100/2000  Loss=1.0685e-01  Valid Loss=1.2580e-01
Epoch 100/

In [None]:
for sparisty in [0.2, 0.4, 0.6, 0.8]:
    pruned_model = nn_model.prune_model(prune_amount=sparisty)
    nn_model_pruned = NN(pruned_model)
    train_loader_pruned = nn_model_pruned.build_dataloaders(model_xfc, model_training_expected_output)
    valid_loader_pruned = nn_model_pruned.build_dataloaders(valid_xfc, model_valid_expected_output)
    train_losses_pruned, valid_losses_pruned, best_epoch_pruned = nn_model_pruned.get_best_model(train_loader_pruned, valid_loader_pruned, num_epochs=200)
    print(f"Pruned Model NMSE at sparsity {sparisty}: {nn_model_pruned.calculate_nmse(model_xfc, model_training_expected_output)} dB")

Epoch  10/200  Loss=4.8473e-02  Valid Loss=5.7157e-02
Epoch  20/200  Loss=4.6755e-02  Valid Loss=5.4827e-02
Epoch  20/200  Loss=4.6755e-02  Valid Loss=5.4827e-02
Epoch  30/200  Loss=4.6196e-02  Valid Loss=5.3945e-02
Epoch  30/200  Loss=4.6196e-02  Valid Loss=5.3945e-02
Epoch  40/200  Loss=4.5411e-02  Valid Loss=5.6006e-02
Epoch  40/200  Loss=4.5411e-02  Valid Loss=5.6006e-02
Epoch  50/200  Loss=4.8200e-02  Valid Loss=5.4007e-02
Epoch  50/200  Loss=4.8200e-02  Valid Loss=5.4007e-02
Epoch  60/200  Loss=4.6152e-02  Valid Loss=5.6797e-02
Epoch  60/200  Loss=4.6152e-02  Valid Loss=5.6797e-02
Epoch  70/200  Loss=4.6317e-02  Valid Loss=5.6623e-02
Epoch  70/200  Loss=4.6317e-02  Valid Loss=5.6623e-02
Epoch  80/200  Loss=4.4862e-02  Valid Loss=5.6525e-02
Epoch  80/200  Loss=4.4862e-02  Valid Loss=5.6525e-02
Epoch  90/200  Loss=4.5702e-02  Valid Loss=5.9120e-02
Epoch  90/200  Loss=4.5702e-02  Valid Loss=5.9120e-02
Epoch 100/200  Loss=5.0929e-02  Valid Loss=5.6336e-02
Epoch 100/200  Loss=5.0929e-