In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn 
import torch.nn.functional as f
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

import thop

#### Loading the Data

In [None]:
# Read the dataset from disk, report its shape, then switch to a torch tensor.
data = np.load("../Dataset/speed-28-delayangle-32.npy")
print(f"Loaded NumPy Array shape: {data.shape}") # author's note: (batch_size, num_ofdm_symbols, num_tx_antennas, fft_size)

data = torch.from_numpy(data)

# Sanity-check figure built from sample 10.
#   left column : magnitude of the frame at time step 0 and at time step 10
#   right column: magnitude of frame 0 (first row) and the change in magnitude
#                 between frame 10 and frame 0 (second row)
# The subplot grid is declared as 5 x 2, but only its first two rows are filled.
fig = plt.figure(figsize=(18, 25))

for i in range(2):
    # --- Left column: frame magnitude ---
    frame_mag = data[10, i*10].abs().numpy()
    n_rows, n_cols = frame_mag.shape
    col_grid, row_grid = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

    left_ax = fig.add_subplot(5, 2, 2*i + 1, projection='3d')
    left_ax.plot_surface(col_grid, row_grid, frame_mag, cmap='viridis')
    left_ax.set_title(f"TimeStep {i*10} - Delay Angle")
    left_ax.set_xlabel("Delay")
    left_ax.set_ylabel("Angle")
    left_ax.set_zlabel("|h_da|")

    # --- Right column: first frame, then the magnitude difference ---
    if i == 1:
        right_mag = data[10, 10].abs().numpy() - data[10, 0].abs().numpy()
    else:
        right_mag = data[10, i*31].abs().numpy()
    n_rows, n_cols = right_mag.shape
    col_grid, row_grid = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

    right_ax = fig.add_subplot(5, 2, 2*i + 2, projection='3d')
    right_ax.plot_surface(col_grid, row_grid, right_mag, cmap='viridis')
    right_ax.set_title(f"Difference TimeStep {i*10} - Delay Angle")
    right_ax.set_xlabel("Delay")
    right_ax.set_ylabel("Angle")
    right_ax.set_zlabel("|h_dd|")

plt.tight_layout()
plt.show()

In [None]:
# Hold-out split along the first axis: samples 0..699 train, the rest test.
# Later cells view every sample as 64 frames and reshape each frame to 32 x 32.
train_data, test_data = data[:700], data[700:]

print(f"Training data: {train_data.shape}")
print(f"Testing data: {test_data.shape}")

### Defining the Models

In [None]:
class Difference_Net(nn.Module):
    """Learned stand-in for the subtraction ``current - remembered``.

    A small MLP (Linear -> Tanh -> Linear) that receives the current frame and
    the remembered frame concatenated along the feature axis and returns a
    vector with ``dim`` features.

    Args:
        dim: Number of features in one flattened frame. The first linear layer
            therefore takes ``2 * dim`` inputs.
    """

    def __init__(self, dim=1024):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(dim*2, dim),
            nn.Tanh(),
            nn.Linear(dim, dim)
        )

    def forward(self, cur_input, mem_input):
        """Combine the current and the remembered frame.

        Args:
            cur_input: Tensor of shape ``(dim,)`` or ``(..., dim)``.
            mem_input: Tensor with the same shape as ``cur_input``.

        Returns:
            Tensor with the same leading shape and ``dim`` trailing features.
        """
        # Concatenate on the LAST axis: for a 1-D frame this equals the former
        # dim=0 behaviour, and batched (..., dim) inputs now also produce the
        # (..., 2*dim) layout the first linear layer expects.
        x = torch.cat([cur_input, mem_input], dim=-1)
        return self.net(x)
    

In [None]:
class Summation_Net(nn.Module):
    """Learned stand-in for the addition ``residual + remembered``.

    A small MLP (Linear -> Tanh -> Linear) that receives two vectors
    concatenated along the feature axis and returns a vector with ``dim``
    features.

    Args:
        dim: Number of features in one flattened frame. The first linear layer
            therefore takes ``2 * dim`` inputs.
    """

    def __init__(self, dim=1024):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(dim*2, dim),
            nn.Tanh(),
            nn.Linear(dim, dim)
        )

    def forward(self, cur_input, mem_input):
        """Combine the current input with the remembered reconstruction.

        Args:
            cur_input: Tensor of shape ``(dim,)`` or ``(..., dim)``.
            mem_input: Tensor with the same shape as ``cur_input``.

        Returns:
            Tensor with the same leading shape and ``dim`` trailing features.
        """
        # Concatenate on the LAST axis: identical to dim=0 for a 1-D frame,
        # and batched (..., dim) inputs are handled correctly as well.
        x = torch.cat([cur_input, mem_input], dim=-1)
        return self.net(x)

In [None]:
class AutoEncoder(nn.Module):
    """Single-hidden-layer autoencoder: Linear -> tanh -> Linear.

    Args:
        hidden_dim: Size of the bottleneck (the compressed representation).
        input_dim: Number of features of one flattened input frame. Defaults
            to 1024, the value that used to be hard-coded.
    """

    def __init__(self, hidden_dim, input_dim=1024):
        super().__init__()
        self.input_dim = input_dim
        self.encoder = nn.Linear(self.input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, self.input_dim)

    def forward(self, x):
        """Compress ``x`` to ``hidden_dim`` features and reconstruct it.

        Args:
            x: Tensor whose last axis has ``input_dim`` features.

        Returns:
            Reconstruction with the same shape as ``x``.
        """
        z = self.encoder(x)
        z = f.tanh(z)  # bounded bottleneck activation
        x_hat = self.decoder(z)

        return x_hat
    

In [None]:
class Differential_AE(nn.Module):
    """Autoencoder wrapper that codes every frame after the first differentially.

    The first frame of a sequence goes straight through the wrapped
    autoencoder. For each later frame, ``diff_func`` combines the frame with
    the previous input, the autoencoder reconstructs that result, and
    ``sum_func`` combines the reconstruction with the previous output.

    Args:
        autoencoder: Module applied to the first frame and to every output of
            ``diff_func``.
    """

    def __init__(self, autoencoder):
        super().__init__()
        self.AE = autoencoder
        # Previous input frame / previous reconstruction; None until forward runs.
        self.encoder_memory = None
        self.decoder_memory = None

        self.diff_func = Difference_Net()
        self.sum_func = Summation_Net()

    def reset_memory(self):
        """Forget the frames remembered from the previous sequence."""
        self.encoder_memory = self.decoder_memory = None

    def forward(self, x_seq):
        """Reconstruct a sequence frame by frame.

        Args:
            x_seq: Iterable of frames; iteration runs over its first axis.

        Returns:
            The per-frame reconstructions stacked along a new first axis.
        """
        outputs = []
        for step, frame in enumerate(x_seq):
            if step > 0:
                # Learned "difference" against the previous input frame ...
                residual = self.diff_func(frame, self.encoder_memory)
                # ... is what gets compressed and reconstructed ...
                residual_hat = self.AE(residual)
                # ... and is merged with the previous reconstruction.
                frame_hat = self.sum_func(residual_hat, self.decoder_memory)
            else:
                # The first frame has no history, so it is coded directly.
                frame_hat = self.AE(frame)

            # Both memories advance to the current step.
            self.encoder_memory = frame.clone()
            self.decoder_memory = frame_hat.clone()
            outputs.append(frame_hat)

        return torch.stack(outputs, dim=0)

### Training the Model

In [None]:
def normalize_complex_data(x):
    """Scale ``x`` to (almost) unit L2 norm along its second-to-last axis.

    Args:
        x: Tensor with at least two dimensions.

    Returns:
        tuple: ``(scaled, norm)`` where ``scaled = x / (norm + 1e-8)`` and
        ``norm`` keeps the reduced axis with size 1, so it broadcasts
        against ``x``.
    """
    column_norms = torch.linalg.norm(x, dim=-2, keepdim=True)
    scaled = x / (column_norms + 1e-8)  # epsilon avoids division by zero
    return scaled, column_norms

In [None]:
# Normalize both splits; the norms are kept so that reconstructions can be
# scaled back to the original magnitudes in the evaluation cells.
(train_data_norm, train_norm), (test_data_norm, test_norm) = (
    normalize_complex_data(train_data),
    normalize_complex_data(test_data),
)

train_data_norm.shape, train_norm.shape

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Only query the CUDA runtime when it is usable: asking for the current
# device on a CPU-only machine raises instead of reporting anything.
if device == "cuda":
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    print("CUDA not available; running on CPU")

In [None]:
# Bottleneck width of the autoencoder; change it for other compression rates.
AE = AutoEncoder(10)

# Wrap the autoencoder in the differential scheme and move it to the device.
# NOTE(review): later cells read .real/.imag of the model output, which
# suggests complex inputs, while nn.Linear parameters default to a real
# dtype -- confirm the dtypes line up.
model = Differential_AE(AE).to(device)
model

In [None]:
# Dummy input for the FLOPs count: one random 32 x 32 frame flattened into a
# (1, 1024) tensor, i.e. a sequence that contains a single frame.
# NOTE(review): Differential_AE.forward iterates over the first axis, so with
# one frame only its first-frame branch runs -- confirm whether the cost of
# later frames (diff_func / sum_func) should be included as well.
d_input = torch.randn(1, 32, 32).to(device).reshape(1, -1)
d_input.shape


In [None]:
# Profile the model on the dummy input and format both counts with three
# decimals; only the formatted operation count is displayed.
flops, params = thop.clever_format(
    list(thop.profile(model, inputs=(d_input,), verbose=False)),
    "%.3f",
)
flops

In [None]:
# Mean-squared error; the training loop applies it to the real and the
# imaginary part separately and adds the two terms.
criterion = nn.MSELoss()

# Adam over every parameter of the wrapped model.
optimizer = optim.Adam(params=model.parameters(), lr=1.2e-4)

# Halve the learning rate once the monitored loss has not improved for
# three consecutive scheduler steps.
scheduler = ReduceLROnPlateau(optimizer=optimizer, mode="min", patience=3, factor=0.5)

In [None]:
# Where the best checkpoint (parameters only) is written during training.
model_path: str = "../Models-100/dnn-nonlineardiff-model-99.pth"

# Number of passes over the training split.
epochs: int = 25

# Lowest summed epoch loss seen so far; infinity so the first epoch always wins.
best_loss: float = float("inf")

# Placeholder for the last known-good model state.
prev_state_dict = None

In [None]:
# best_loss = float('inf')
# prev_model_state = None
# prev_opt_state = None

# for i in range(epochs):
#     epoch_loss = 0.0
#     model.train()
#     nan_detected = False

#     for batch in train_data_norm:
#         optimizer.zero_grad()
#         batch = batch.to(device).view(64, -1)
#         model.reset_memory()

#         recons_seq = model(batch)
#         if torch.isnan(recons_seq).any():
#             print(f"NaN detected in output at epoch {i}")
#             nan_detected = True
#             break

#         loss = (criterion(recons_seq.real, batch.real) +
#                 criterion(recons_seq.imag, batch.imag))
#         if torch.isnan(loss):
#             print(f"NaN detected in loss at epoch {i}")
#             nan_detected = True
#             break

#         epoch_loss += loss.item()
#         loss.backward()

#         for name, param in model.named_parameters():
#             if param.grad is not None and torch.isnan(param.grad).any():
#                 print(f"NaN in gradients for {name} at epoch {i}")
#                 nan_detected = True
#                 break
#         if nan_detected:
#             break

#         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#         optimizer.step()

#     # ---------------- NaN Recovery ----------------
#     if nan_detected:
#         print(f"NaN encountered at epoch {i}. Restoring previous model and optimizer state...")

#         if prev_model_state is not None and prev_opt_state is not None:
#             model.load_state_dict(prev_model_state)
#             optimizer.load_state_dict(prev_opt_state)

#             for g in optimizer.param_groups:
#                 g['lr'] *= 0.5  # shrink learning rate
#             print(f"Learning rate reduced to {optimizer.param_groups[0]['lr']:.2e}. Restarting epoch {i}.")
#         else:
#             print("No previous state found — restarting anyway.")
        
#         continue  # retry this epoch
    
#     # ---------------- Normal Progress ----------------
#     scheduler.step(epoch_loss)

#     if epoch_loss < best_loss:
#         best_loss = epoch_loss
#         prev_model_state = model.state_dict()
#         prev_opt_state = optimizer.state_dict()
#         torch.save(prev_model_state, model_path)
#         print(f"Epoch {i}, Loss {epoch_loss:.4f} — New best model saved")
#     else:
#         print(f"Epoch {i}, Loss {epoch_loss:.4f}")


In [None]:
# Training loop: one optimizer step per training sequence, checkpointing the
# parameters whenever the summed epoch loss reaches a new minimum.
#
# As soon as a NaN/Inf shows up in the output, the loss or a gradient, the
# current batch is NOT applied, the epoch is NOT used for scheduling or
# checkpointing, and training stops. Previously the loop kept stepping with
# NaN gradients and could save a corrupted model as "best", because the
# truncated epoch loss compared lower than a full one.
for i in range(epochs):
    epoch_loss = 0
    model.train()
    non_finite = False  # set when this epoch hits a NaN/Inf anywhere

    for batch in train_data_norm:
        optimizer.zero_grad()
        batch = batch.to(device)
        batch = batch.view(64, -1)  # 64 frames, each flattened to a vector

        model.reset_memory() # reset encoder and decoder memory before start of each batch

        recons_seq = model(batch)
        if not torch.isfinite(recons_seq).all():
            print(f"NaN/Inf detected in output at epoch {i}")
            non_finite = True
            break

        loss = criterion(recons_seq.real, batch.real) + criterion(recons_seq.imag, batch.imag)
        if not torch.isfinite(loss):
            print(f"NaN/Inf detected in loss at epoch {i}")
            non_finite = True
            break

        loss.backward()

        # Check gradients for NaNs/Infs BEFORE clipping: clipping rescales by
        # the total norm, so one bad entry would spread to every gradient.
        bad_param = next(
            (
                name
                for name, param in model.named_parameters()
                if param.grad is not None and not torch.isfinite(param.grad).all()
            ),
            None,
        )
        if bad_param is not None:
            print(f"NaN/Inf in gradients for {bad_param} at epoch {i}")
            non_finite = True
            break

        epoch_loss += loss.item()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    if non_finite:
        # The epoch loss is incomplete and the weights may be unusable, so do
        # not feed it to the scheduler or the checkpoint logic.
        print(f"Epoch {i} aborted; stopping training. The last best checkpoint is left untouched.")
        break

    scheduler.step(epoch_loss)
    # ---- checkpointing ----
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), model_path)  # save only parameters
        print(f"Epoch {i}, Loss {epoch_loss:.4f} New best model saved")
    else:
        print(f"Epoch {i}, Loss {epoch_loss:.4f}")

In [None]:
# Load the best checkpoint written during training.
# The wrapper is created around the SAME `AE` instance used by the trained
# model, so loading the state dict overwrites that autoencoder's weights in
# place; only diff_func / sum_func are newly created here.
best_model = Differential_AE(AE).to(device)

# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU can also be opened on a CPU-only machine.
best_model.load_state_dict(torch.load(model_path, map_location=device))

model = best_model

In [None]:
def correlation(A, A_hat):
    """Mean column-wise cosine similarity between frames and reconstructions.

    Every row of ``A`` / ``A_hat`` is reshaped to a 32 x 32 matrix. For each
    of its 32 columns the normalized inner product
    ``|<a, a_hat>| / (||a|| * ||a_hat|| + 1e-12)`` is computed, and the 32
    values are averaged.

    Args:
        A: Tensor with ``batch_size`` as first axis and 1024 entries per
            sample (real or complex).
        A_hat: Tensor with the same number of entries as ``A``.

    Returns:
        numpy.ndarray of shape ``(batch_size,)`` and dtype float64.
    """
    batch_size = A.shape[0]
    ref = A.reshape(batch_size, 32, 32).cpu().detach().numpy()
    est = A_hat.reshape(batch_size, 32, 32).cpu().detach().numpy()

    # Inner product of matching columns: sum over the row axis (axis=1).
    inner = np.abs(np.sum(np.conj(ref) * est, axis=1))  # (batch, 32), real

    # np.linalg.norm returns real values for complex input, so no complex
    # number is ever written into a float array (the loop version did that
    # and triggered a ComplexWarning).
    norms = np.linalg.norm(ref, axis=1) * np.linalg.norm(est, axis=1)

    per_column = inner / (norms + 1e-12)
    return per_column.mean(axis=1).astype(np.float64)

In [None]:
def correlation_renormalized(A, A_hat): # Cosine Similarity along columns for renormalized data
    """Mean column-wise cosine similarity for tensors that are already 3-D.

    Same metric as ``correlation``, except that the inputs are expected in
    their final ``(batch, rows, cols)`` layout, so no reshaping is done.
    All columns of the input are used; the count is no longer fixed at 32.

    Args:
        A: Tensor of shape ``(batch, rows, cols)`` (real or complex).
        A_hat: Tensor with the same shape as ``A``.

    Returns:
        numpy.ndarray of shape ``(batch,)`` and dtype float64 holding, per
        sample, the mean over columns of
        ``|<a, a_hat>| / (||a|| * ||a_hat|| + 1e-12)``.
    """
    ref = A.cpu().detach().numpy()
    est = A_hat.cpu().detach().numpy()

    # Inner product of matching columns: sum over the row axis (axis=1).
    inner = np.abs(np.sum(np.conj(ref) * est, axis=1))  # (batch, cols), real

    # Real-valued norms keep the whole computation real, so nothing complex
    # is silently truncated when the result is stored.
    norms = np.linalg.norm(ref, axis=1) * np.linalg.norm(est, axis=1)

    per_column = inner / (norms + 1e-12)
    return per_column.mean(axis=1).astype(np.float64)

In [None]:
def compute_NMSE(gt, pred, offset=0.0):
    """
    Compute NMSE in dB between complex-valued ground truth and prediction.

    The complex value is carried by channel 0 (real part) and channel 1
    (imaginary part). Per sample the squared error is divided by the power
    of the ground truth; the batch mean of that ratio is converted to dB.

    Args:
        gt: Ground truth tensor of shape [B, 2, H, W]
        pred: Predicted tensor of same shape [B, 2, H, W]
        offset: Value subtracted from both tensors before the computation.
            Keep the default 0.0 for zero-centered real/imaginary parts
            (what the evaluation cells pass in). Use 0.5 only for data
            scaled to [0, 1]; applying that shift to zero-centered data
            inflates the reference power and makes the NMSE look too good.

    Returns:
        NMSE in dB (lower is better)
    """

    # Optional de-centering; cancels in the error, matters for the power.
    gt = gt - offset
    pred = pred - offset

    # Compute power of complex ground truth
    power_gt = gt[:, 0, :, :]**2 + gt[:, 1, :, :]**2

    # Compute squared error
    diff = gt - pred
    mse = diff[:, 0, :, :]**2 + diff[:, 1, :, :]**2

    # NMSE per sample (epsilon guards an all-zero ground truth)
    nmse = mse.sum(dim=[1, 2]) / (power_gt.sum(dim=[1, 2]) + 1e-12)

    # Mean NMSE across batch and convert to dB
    nmse_db = 10 * torch.log10(nmse.mean())

    return nmse_db.item()

In [None]:
# ---- Evaluation on the training split ----
# Pass 1 scores the model on the normalized data. Pass 2 multiplies the
# stored norms back onto the reconstructions and scores them against the
# original, unnormalized data.
model.eval()   # put model in eval mode
train_loss = 0.0
nmse_total = 0.0  # sum of per-sequence NMSE values (dB), normalized data
cosine_sim_col, model_output = [], []

with torch.no_grad():   # no gradients needed for evaluation
    for sequence in train_data_norm:
        frames = sequence.to(device).view(64, -1)  # 64 flattened frames

        model.reset_memory()  # every sequence starts without history
        recon = model(frames)
        model_output.append(recon)

        cosine_sim_col.append(np.mean(correlation(frames, recon)))

        loss = criterion(recon.real, frames.real) + criterion(recon.imag, frames.imag)
        train_loss += loss.item()

        # compute_NMSE takes real and imaginary parts as two channels.
        frames_2d = frames.reshape(64, 32, 32)
        recon_2d = recon.reshape(64, 32, 32)
        nmse_total += compute_NMSE(
            torch.stack([frames_2d.real, frames_2d.imag], dim=1),
            torch.stack([recon_2d.real, recon_2d.imag], dim=1),
        )

print(f"Train Loss: {train_loss:.4f}")
print(f"Cosine Similarity along columns Train Data: {np.mean(cosine_sim_col):.4f}")

# Final average NMSE in dB
avg_nmse_db = nmse_total / len(train_data_norm)
print(f"Average NMSE (dB): {avg_nmse_db:.2f} dB")

# ---- Pass 2: undo the normalization and score against the raw data ----
nmse_total = 0.0

m_output = torch.stack(model_output)
m_output = m_output.reshape(m_output.shape[0], 64, 32, 32).cpu() * train_norm  # multiply the norms back

cosine_sim_col = []
for target, estimate in zip(train_data, m_output):
    cosine_sim_col.append(np.mean(correlation_renormalized(target, estimate)))

    nmse_total += compute_NMSE(
        torch.stack([target.real, target.imag], dim=1),
        torch.stack([estimate.real, estimate.imag], dim=1),
    )

print(f"Cosine Similarity along columns Train Data (ReNormalized Data): {np.mean(cosine_sim_col):.4f}")
# Final average NMSE in dB
avg_nmse_db = nmse_total / len(train_data_norm)
print(f"Average NMSE (dB) (ReNormalized Data): {avg_nmse_db:.2f} dB")


In [None]:
# ---- Evaluation on the held-out test split ----
# Same two passes as for the training split: first on the normalized data,
# then on reconstructions scaled back with the stored test norms.
model.eval()   # put model in eval mode
test_loss = 0.0
nmse_total = 0.0  # sum of per-sequence NMSE values (dB), normalized data
cosine_sim_col, model_output = [], []

with torch.no_grad():   # no gradients needed for evaluation
    for sequence in test_data_norm:
        frames = sequence.to(device).view(64, -1)  # 64 flattened frames

        model.reset_memory()  # every sequence starts without history
        recon = model(frames)
        model_output.append(recon)

        cosine_sim_col.append(np.mean(correlation(frames, recon)))

        loss = criterion(recon.real, frames.real) + criterion(recon.imag, frames.imag)
        test_loss += loss.item()

        # compute_NMSE takes real and imaginary parts as two channels.
        frames_2d = frames.reshape(64, 32, 32)
        recon_2d = recon.reshape(64, 32, 32)
        nmse_total += compute_NMSE(
            torch.stack([frames_2d.real, frames_2d.imag], dim=1),
            torch.stack([recon_2d.real, recon_2d.imag], dim=1),
        )

print(f"Test Loss: {test_loss:.4f}")
print(f"Cosine Similarity along columns Test Data (Normalized Data): {np.mean(cosine_sim_col):.4f}")

# Final average NMSE in dB
avg_nmse_db = nmse_total / len(test_data_norm)
print(f"Average NMSE (dB): {avg_nmse_db:.2f} dB")

# ---- Pass 2: undo the normalization and score against the raw data ----
nmse_total = 0.0

m_output = torch.stack(model_output)
m_output = m_output.reshape(m_output.shape[0], 64, 32, 32).cpu() * test_norm  # multiply the norms back

cosine_sim_col = []
for target, estimate in zip(test_data, m_output):
    cosine_sim_col.append(np.mean(correlation_renormalized(target, estimate)))

    nmse_total += compute_NMSE(
        torch.stack([target.real, target.imag], dim=1),
        torch.stack([estimate.real, estimate.imag], dim=1),
    )

print(f"Cosine Similarity along columns Test Data (ReNormalized Data): {np.mean(cosine_sim_col):.4f}")

# Final average NMSE in dB
avg_nmse_db = nmse_total / len(test_data_norm)
print(f"Average NMSE (dB) (ReNormalized Data): {avg_nmse_db:.2f} dB")