In [None]:
import copy
import os
from collections import OrderedDict

import numpy as np
import thop
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
# Preferred GPU index. Fall back to GPU 0 when fewer GPUs exist, and to the
# CPU when CUDA is unavailable; calling set_device/get_device_name without
# CUDA raises instead of falling back.
GPU_INDEX = 1
if torch.cuda.is_available():
    torch.cuda.set_device(GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0)
    device = "cuda"  # "cuda" resolves to the current device selected above
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    device = "cpu"
    print("CUDA not available; using CPU")

### Loading the Data

In [None]:
def normalize_complex_data(x):
    """Scale every column of ``x`` to unit L2 norm.

    The norm is taken over dim -2, so each vector ``x[..., :, j]`` is divided
    by its own length (plus 1e-8 so all-zero columns stay finite).

    NOTE(review): with the real/imag-stacked tensors used in this notebook the
    real and imaginary channels are normalized independently, not as one
    complex column — confirm this is intended.

    Args:
        x: real tensor with at least two dimensions.

    Returns:
        (normalized tensor, column norms); the norms keep dim -2 with size 1 so
        ``normalized * norms`` broadcasts back to (approximately) ``x``.
    """
    col_norms = torch.linalg.norm(x, dim=-2, keepdim=True)
    eps = 1e-8
    scaled = x / (col_norms + eps)
    return scaled, col_norms

In [None]:
DATASET_PATH = "../Dataset/speed-28-delayangle-32.npy"
# Complex-valued channel samples (later split via .real / .imag).
raw_array = np.load(DATASET_PATH)
data = torch.from_numpy(raw_array)
data.shape

In [None]:
N_TRAIN = 700  # leading samples used for training; the rest are held out
# Sequential (unshuffled) split. Each sample is a stack of correlated frames
# (64 per the original notes).
# NOTE(review): an older comment described frames as 8 x 76, but CRNet below is
# built for 32 x 32 inputs — confirm against data.shape.
train_data, test_data = data[:N_TRAIN], data[N_TRAIN:]

print(f"Training data: {train_data.shape}")
print(f"Testing data: {test_data.shape}")

In [None]:
# Represent the complex values as two real channels (real, imag) inserted at
# dim=2, so each frame becomes a 2-channel image for the CNN.
train_data, test_data = [torch.stack((split.real, split.imag), dim=2)
                         for split in (train_data, test_data)]

print("Train shape:", train_data.shape, train_data.dtype)
print("Test shape:", test_data.shape, test_data.dtype)

In [None]:
# Column-normalize both splits; the norms are kept so that evaluation can
# rescale reconstructions back to the original magnitude.
(train_data_norm, train_norm), (test_data_norm, test_norm) = map(
    normalize_complex_data, (train_data, test_data)
)
train_data_norm.shape, train_norm.shape

### Model Definition

In [None]:
# CRNet model
class ConvBN(nn.Sequential):
    """Bias-free Conv2d followed by BatchNorm2d.

    Padding is ``(k - 1) // 2`` per kernel dimension, which preserves the
    spatial size for odd kernels at stride 1. ``kernel_size`` may be an int or
    a (kh, kw) sequence.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1):
        if isinstance(kernel_size, int):
            padding = (kernel_size - 1) // 2
        else:
            padding = [(k - 1) // 2 for k in kernel_size]
        conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride,
                         padding=padding, groups=groups, bias=False)
        layers = OrderedDict()
        layers['conv'] = conv
        layers['bn'] = nn.BatchNorm2d(out_planes)
        super().__init__(layers)


class CRBlock(nn.Module):
    """Residual block with two parallel multi-resolution convolution paths.

    path1: 3x3 -> 1x9 -> 9x1 (7 channels); path2: 1x5 -> 5x1 (7 channels).
    The 14 concatenated channels are activated, projected back to 2 channels by
    a 1x1 ConvBN, added to the block input and activated again. Input and
    output both have 2 channels and the same spatial size.
    """

    def __init__(self):
        super().__init__()

        def leaky():
            # Activation settings shared by every layer of the block.
            return nn.LeakyReLU(negative_slope=0.3, inplace=True)

        # Construction order matters: it fixes parameter order and the RNG
        # sequence consumed by default/xavier initialization.
        self.path1 = nn.Sequential(OrderedDict([
            ('conv3x3', ConvBN(2, 7, 3)),
            ('relu1', leaky()),
            ('conv1x9', ConvBN(7, 7, [1, 9])),
            ('relu2', leaky()),
            ('conv9x1', ConvBN(7, 7, [9, 1])),
        ]))
        self.path2 = nn.Sequential(OrderedDict([
            ('conv1x5', ConvBN(2, 7, [1, 5])),
            ('relu', leaky()),
            ('conv5x1', ConvBN(7, 7, [5, 1])),
        ]))
        self.conv1x1 = ConvBN(2 * 7, 2, 1)
        self.identity = nn.Identity()
        self.relu = leaky()

    def forward(self, x):
        shortcut = self.identity(x)
        branches = torch.cat([self.path1(x), self.path2(x)], dim=1)
        fused = self.conv1x1(self.relu(branches))
        return self.relu(fused + shortcut)
class CRNet(nn.Module):
    """CRNet-style autoencoder for 2-channel (real, imag) images.

    Encoder: a 3x3 -> 1x9 -> 9x1 ConvBN path and a single 3x3 ConvBN path run
    in parallel; their outputs are concatenated, fused by a 1x1 ConvBN,
    flattened and compressed by ``encoder_fc`` to ``total_size // reduction``
    values. Decoder: ``decoder_fc`` expands back to ``total_size``, the result
    is reshaped to the input layout and refined by a 5x5 ConvBN and two
    CRBlocks.

    Args:
        reduction: reciprocal of the compression ratio.
        input_size: (H, W) of the inputs; used to size the fully-connected
            layers, so forward() expects inputs of exactly this size.
    """

    def __init__(self, reduction=4, input_size=(32,32)):
        super(CRNet, self).__init__()
        in_channel = 2

        self.encoder1 = nn.Sequential(OrderedDict([
            ("conv3x3_bn", ConvBN(in_channel, 2, 3)),
            ("relu1", nn.LeakyReLU(negative_slope=0.3, inplace=True)),
            ("conv1x9_bn", ConvBN(2, 2, [1, 9])),
            ("relu2", nn.LeakyReLU(negative_slope=0.3, inplace=True)),
            ("conv9x1_bn", ConvBN(2, 2, [9, 1])),
        ]))
        self.encoder2 = ConvBN(in_channel, 2, 3)
        self.encoder_conv = nn.Sequential(OrderedDict([
            ("relu1", nn.LeakyReLU(negative_slope=0.3, inplace=True)),
            ("conv1x1_bn", ConvBN(4, 2, 1)),
            ("relu2", nn.LeakyReLU(negative_slope=0.3, inplace=True)),
        ]))

        self.total_size = self._get_flattened_size(in_channel, input_size)

        self.encoder_fc = nn.Linear(self.total_size, self.total_size // reduction)
        self.decoder_fc = nn.Linear(self.total_size // reduction, self.total_size)

        decoder = OrderedDict([
            ("conv5x5_bn", ConvBN(2, 2, 5)),
            ("relu", nn.LeakyReLU(negative_slope=0.3, inplace=True)),
            ("CRBlock1", CRBlock()),
            ("CRBlock2", CRBlock())
        ])
        self.decoder_feature = nn.Sequential(decoder)

        # Xavier-uniform weights for conv/linear layers; identity affine
        # transform for BatchNorm.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _encode_features(self, x):
        """Convolutional encoder: both paths, concatenation and 1x1 fusion."""
        encode1 = self.encoder1(x)
        encode2 = self.encoder2(x)
        out = torch.cat((encode1, encode2), dim=1)
        return self.encoder_conv(out)

    def _get_flattened_size(self, in_channel, input_size):
        """Number of features the conv encoder produces for one input.

        Runs a dummy forward pass in eval mode. In train mode, the BatchNorm
        layers would fold the all-zero dummy batch into their running
        statistics and increment num_batches_tracked before training starts.
        The previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                x = torch.zeros(1, in_channel, *input_size)
                flattened_size = self._encode_features(x).view(1, -1).size(1)
        finally:
            self.train(was_training)
        return flattened_size

    def forward(self, x):
        n, c, h, w = x.size()
        out = self._encode_features(x)
        out = self.encoder_fc(out.view(n, -1))
        # decoder_fc yields total_size values per sample, reshaped back to the
        # input layout; this requires c * h * w == total_size.
        out = self.decoder_fc(out).view(n, c, h, w)
        return self.decoder_feature(out)


def crnet(reduction=2, input_size=(32, 32)):
    r"""Create a CRNet.

    :param reduction: the reciprocal of the compression ratio; the latent code
        has ``total_size // reduction`` entries, where ``total_size`` is the
        flattened encoder output size.
    :param input_size: (H, W) of the model inputs, forwarded to CRNet so that
        non-32x32 inputs can be used (default keeps the previous behaviour).
    :return: an instance of CRNet
    """
    return CRNet(reduction=reduction, input_size=input_size)

In [None]:
REDUCTION = 100  # latent size = flattened encoder size // REDUCTION
model = crnet(REDUCTION)
model = model.to(device)
model

In [None]:
# FLOPs / parameter count.
# Profile a throw-away copy: thop registers `total_ops` / `total_params`
# buffers (and hooks) on the modules it profiles, and these would otherwise
# end up in the state_dict checkpoints saved from `model`.
profile_input = torch.randn(1, 2, 32, 32, device=device)
flops, params = thop.profile(copy.deepcopy(model), inputs=(profile_input,), verbose=False)
flops, params = thop.clever_format([flops, params], "%.3f")
flops, params

In [None]:
LEARNING_RATE = 5e-3
criterion = nn.MSELoss().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# Halve the learning rate after 3 epochs without improvement of the monitored loss.
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3)

In [None]:

epochs = 25
best_loss = float('inf')
model_path = "../Models-100/CRNet-99.pth"
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)

for i in range(epochs):
    model.train()
    epoch_loss = 0.0

    # Each element of train_data_norm is one sample; its frames form the batch.
    for batch in train_data_norm:
        batch = batch.to(device)
        optimizer.zero_grad()

        recons_output = model(batch)
        loss = criterion(recons_output, batch)
        epoch_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    # ReduceLROnPlateau keeps its own best value and patience counter, so it
    # must see the metric every epoch, including improving ones.
    scheduler.step(epoch_loss)

    # ---- checkpointing ----
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), model_path)  # save only parameters
        print(f"Epoch {i}, Loss {epoch_loss:.4f} New best model saved")
    else:
        print(f"Epoch {i}, Loss {epoch_loss:.4f}")

    

### Testing the Model

In [None]:
# Load the best checkpoint written by the training loop.
# map_location remaps tensors saved on another GPU (e.g. cuda:1) onto the
# currently selected `device`.
polluted_state_dict = torch.load(model_path, map_location=device)

# Drop thop profiling buffers (`<module>.total_ops` / `<module>.total_params`)
# if they were saved with the weights; they are not part of CRNet's state.
THOP_BUFFER_NAMES = ("total_ops", "total_params")
clean_state_dict = {
    k: v for k, v in polluted_state_dict.items()
    if k.rsplit(".", 1)[-1] not in THOP_BUFFER_NAMES
}

# Rebuild the same architecture and load only the clean keys.
model = crnet(100).to(device)
model.load_state_dict(clean_state_dict)

In [None]:
def correlation(A, A_hat):
    """Per-sample mean column correlation magnitude between two matrices.

    For every sample b and column j:
        rho[b, j] = |sum_i conj(A[b, i, j]) * A_hat[b, i, j]|
                    / (||A[b, :, j]|| * ||A_hat[b, :, j]|| + 1e-12)
    and the result for sample b is rho[b, :] averaged over all columns.

    Args:
        A: reference, torch tensor (any device, may require grad) or
            array-like of shape [B, H, W]; complex or real.
        A_hat: reconstruction with the same shape.

    Returns:
        float64 numpy array of shape [B].

    Raises:
        ValueError: if the shapes of A and A_hat differ.
    """
    def _to_numpy(t):
        # Accept torch tensors as well as numpy arrays / nested sequences.
        if torch.is_tensor(t):
            return t.detach().cpu().numpy()
        return np.asarray(t)

    A = _to_numpy(A)
    A_hat = _to_numpy(A_hat)
    if A.shape != A_hat.shape:
        raise ValueError(f"shape mismatch: {A.shape} vs {A_hat.shape}")

    # Reduce over rows (axis -2) so every column of every sample is used.
    inner = np.abs(np.sum(np.conj(A) * A_hat, axis=-2))
    norm_a = np.sqrt(np.sum(np.abs(A) ** 2, axis=-2))
    norm_a_hat = np.sqrt(np.sum(np.abs(A_hat) ** 2, axis=-2))
    rho = inner / (norm_a * norm_a_hat + 1e-12)
    return np.asarray(rho.mean(axis=-1), dtype=np.float64)

In [None]:
def compute_NMSE(gt, pred, offset=0.0):
    """
    Compute NMSE in dB between complex-valued ground truth and prediction.

    Channel 0 holds the real part and channel 1 the imaginary part. Per
    sample, NMSE = ||gt - pred||^2 / ||gt - offset||^2. It is averaged over
    the batch in linear scale and then converted to dB.

    Args:
        gt: Ground truth tensor of shape [B, 2, H, W].
        pred: Predicted tensor of same shape [B, 2, H, W].
        offset: Constant removed from gt before computing its power. Use 0.5
            for data that was shifted into [0, 1] (CsiNet-style). The default
            0.0 suits zero-centred data such as the column-normalized and
            renormalized channels in this notebook; subtracting 0.5 from
            those inflates the signal power and understates the NMSE.

    Returns:
        NMSE in dB (lower is better) as a Python float.
    """
    # Signal power of the (re-centred) complex ground truth.
    gt_centered = gt - offset
    power_gt = gt_centered[:, 0, :, :]**2 + gt_centered[:, 1, :, :]**2

    # Squared error; a common offset cancels in the difference.
    diff = gt - pred
    mse = diff[:, 0, :, :]**2 + diff[:, 1, :, :]**2

    # NMSE per sample
    nmse = mse.sum(dim=[1, 2]) / (power_gt.sum(dim=[1, 2]) + 1e-12)

    # Mean NMSE across batch and convert to dB
    nmse_db = 10 * torch.log10(nmse.mean())

    return nmse_db.item()

In [None]:
# Test data: metrics on the normalized data, then on reconstructions rescaled
# with the stored column norms (test_norm).
model.eval()   # put model in eval mode
test_loss = 0.0
cosine_sim_col = []
model_output = []
nmse_db_total_test = 0.0

with torch.no_grad():   # disable gradient computation
    for batch in test_data_norm:
        batch = batch.to(device)
        recons_seq = model(batch)

        # Accumulate per-sample NMSE in dB (averaged over samples below).
        nmse_db_total_test += compute_NMSE(batch, recons_seq)

        # Back to complex [frames, H, W] for the correlation and loss metrics.
        batch = torch.complex(batch[:, 0, :, :], batch[:, 1, :, :])
        recons_seq = torch.complex(recons_seq[:, 0, :, :], recons_seq[:, 1, :, :])
        model_output.append(recons_seq)

        corr = correlation(batch, recons_seq)
        cosine_sim_col.append(np.mean(corr))

        loss = (criterion(recons_seq.real, batch.real) +
                criterion(recons_seq.imag, batch.imag))
        test_loss += loss.item()

print(f"Test Loss: {test_loss:.4f}")
print(f"Cosine Similarity along columns Test Data (Normalized Data): {np.mean(cosine_sim_col):.4f}")

# Final average NMSE in dB
avg_nmse_db = nmse_db_total_test / len(test_data_norm)
print(f"Average NMSE (dB) (Normalized Data): {avg_nmse_db:.2f} dB")

# Undo the normalization: real/imag channels back at dim=2, then multiply the
# column norms back in.
m_output = torch.stack(model_output)
m_output = torch.stack([m_output.real, m_output.imag], dim=2)
m_output = m_output.cpu() * test_norm

nmse_db_total_test = 0.0
for i in range(test_data.shape[0]):
    nmse_db_total_test += compute_NMSE(test_data[i], m_output[i])

m_output = torch.complex(m_output[:, :, 0, :, :], m_output[:, :, 1, :, :])
# `test_data` keeps its stacked real/imag layout (needed by the indexing above)
# so this cell can be re-run; the complex view gets its own name.
test_data_complex = torch.complex(test_data[:, :, 0, :, :], test_data[:, :, 1, :, :])

cosine_sim_col = []
for i in range(test_data_complex.shape[0]):
    corr = correlation(test_data_complex[i], m_output[i])
    cosine_sim_col.append(np.mean(corr))

print(f"Cosine Similarity along columns Test Data (ReNormalized Data): {np.mean(cosine_sim_col):.4f}")
# Final average NMSE in dB
avg_nmse_db = nmse_db_total_test / len(test_data)
print(f"Average NMSE (dB) (ReNormalized Data): {avg_nmse_db:.2f} dB")


In [None]:
# Train data: same metrics as for the test split, using train_norm for rescaling.

model.eval()   # put model in eval mode
train_loss = 0.0
cosine_sim_col = []
model_output = []
nmse_db_total = 0.0

with torch.no_grad():   # disable gradient computation
    for batch in train_data_norm:
        batch = batch.to(device)
        recons_seq = model(batch)

        # Accumulate per-sample NMSE in dB (averaged over samples below).
        nmse_db_total += compute_NMSE(batch, recons_seq)

        # Back to complex [frames, H, W] for the correlation and loss metrics.
        batch = torch.complex(batch[:, 0, :, :], batch[:, 1, :, :])
        recons_seq = torch.complex(recons_seq[:, 0, :, :], recons_seq[:, 1, :, :])
        model_output.append(recons_seq)

        corr = correlation(batch, recons_seq)
        cosine_sim_col.append(np.mean(corr))

        loss = (criterion(recons_seq.real, batch.real) +
                criterion(recons_seq.imag, batch.imag))
        train_loss += loss.item()

print(f"Train Loss: {train_loss:.4f}")
print(f"Cosine Similarity along columns Train Data: {np.mean(cosine_sim_col):.4f}")

# Final average NMSE in dB
avg_nmse_db = nmse_db_total / len(train_data_norm)

print(f"Average NMSE (dB): {avg_nmse_db:.2f} dB")

# Undo the normalization: real/imag channels back at dim=2, then multiply the
# column norms back in.
m_output = torch.stack(model_output)
m_output = torch.stack([m_output.real, m_output.imag], dim=2)
m_output = m_output.cpu() * train_norm

# Separate accumulator so the test cell's totals are not overwritten.
nmse_db_total_renorm = 0.0
for i in range(train_data.shape[0]):
    nmse_db_total_renorm += compute_NMSE(train_data[i], m_output[i])

m_output = torch.complex(m_output[:, :, 0, :, :], m_output[:, :, 1, :, :])
# `train_data` keeps its stacked real/imag layout (needed by the indexing
# above) so this cell can be re-run; the complex view gets its own name.
train_data_complex = torch.complex(train_data[:, :, 0, :, :], train_data[:, :, 1, :, :])

cosine_sim_col = []
for i in range(train_data_complex.shape[0]):
    corr = correlation(train_data_complex[i], m_output[i])
    cosine_sim_col.append(np.mean(corr))

print(f"Cosine Similarity along columns Train Data (ReNormalized Data): {np.mean(cosine_sim_col):.4f}")

# Final average NMSE in dB
avg_nmse_db = nmse_db_total_renorm / len(train_data)
print(f"Average NMSE (dB) (ReNormalized Data): {avg_nmse_db:.2f} dB")
