# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import scipy.io as sio
import matplotlib.pyplot as plt
import numpy as np
import os

# Loss shared by training, validation and the reported test MSE
# (element-wise squared error averaged over every tensor entry).
criterion = nn.MSELoss(reduction="mean")

# Compression ratios to sweep; a value of cr is reported as "1/cr".
compression_ratios = [4, 8]

def calculate_mse(outputs, targets):
    """Return the mean squared error between ``outputs`` and ``targets``.

    Delegates to the module-level ``criterion`` loss object, so the value
    (and its type) is whatever that loss returns for the two arguments.
    """
    return criterion(outputs, targets)

def NMSE(outputs, inputs_resized):
    """Return the normalised mean squared error, in dB, of a reconstruction.

    Both tensors are indexed as ``[:, 0, :, :]`` and ``[:, 1, :, :]``, so they
    must be 4-D with at least two entries along dim 1. Entry 0 is taken as the
    real part and entry 1 as the imaginary part of a complex value, each after
    subtracting the 0.5 offset.

    Args:
        outputs: reconstructed tensor.
        inputs_resized: reference tensor with the same layout as ``outputs``.

    Returns:
        0-dim tensor holding ``10 * log10(mean_over_batch(mse / power))``,
        where ``mse`` and ``power`` are per-sample means over all positions.
        A sample whose reference power is zero is not guarded against and
        produces ``inf``/``nan``.
    """

    def as_complex(t):
        # Flatten everything after the batch dim, then recentre around zero.
        batch = t.size(0)
        re = t[:, 0, :, :].reshape(batch, -1)
        im = t[:, 1, :, :].reshape(batch, -1)
        return (re - 0.5) + 1j * (im - 0.5)

    predicted = as_complex(outputs)
    reference = as_complex(inputs_resized)

    # Per-sample error energy and reference energy.
    err_energy = (torch.abs(predicted - reference) ** 2).mean(dim=1)
    ref_energy = (torch.abs(reference) ** 2).mean(dim=1)

    return 10 * torch.log10((err_energy / ref_energy).mean())


import torch.nn as nn

class EncoderBlock(nn.Module):
    """Encoder half of the autoencoder.

    Applies one 3x3 convolution (2 -> 2 channels) with batch norm and ReLU,
    then flattens and projects to ``compressed_dim`` features with a linear
    layer. Input is expected as ``(batch, 2, height, width)``; output is
    ``(batch, compressed_dim)``.

    Args:
        compressed_dim: size of the encoded vector.
        height: spatial height of the input.
        width: spatial width of the input.
    """

    def __init__(self, compressed_dim, height, width):
        super().__init__()
        # Two input channels and two output channels; padding=1 keeps H x W.
        self.conv = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(num_features=2)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.dense = nn.Linear(in_features=2 * height * width, out_features=compressed_dim)
        self.height = height
        self.width = width

    def forward(self, x):
        """Encode ``x`` into a ``(batch, compressed_dim)`` tensor."""
        features = self.relu(self.bn(self.conv(x)))
        return self.dense(self.flatten(features))


class DecoderBlock(nn.Module):
    """Decoder half of the autoencoder.

    Expands a ``(batch, compressed_dim)`` code back to
    ``(batch, 2, height, width)`` with a linear layer, then refines it with a
    1x1 conv, two parallel conv/batch-norm branches that are concatenated
    along the channel dim, and three further 3x3 convolutions.

    Args:
        compressed_dim: size of the encoded vector.
        height: spatial height of the reconstruction.
        width: spatial width of the reconstruction.
    """

    def __init__(self, compressed_dim, height, width):
        super().__init__()
        self.compressed_dim = compressed_dim
        self.height = height
        self.width = width

        # Every 3x3 convolution below keeps the spatial size unchanged.
        conv3x3 = dict(kernel_size=3, stride=1, padding=1)

        self.dense = nn.Linear(compressed_dim, 2 * height * width)
        self.conv1 = nn.Conv2d(2, 4, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(4)
        self.relu1 = nn.ReLU()

        # Upper branch: 4 -> 4 -> 8 channels.
        self.conv2_upper = nn.Conv2d(4, 4, **conv3x3)
        self.bn2_upper = nn.BatchNorm2d(4)
        self.conv2_upper2 = nn.Conv2d(4, 8, **conv3x3)
        self.bn2_upper2 = nn.BatchNorm2d(8)

        # Lower branch: same layout as the upper one, separate weights.
        self.conv2_lower = nn.Conv2d(4, 4, **conv3x3)
        self.bn2_lower = nn.BatchNorm2d(4)
        self.conv2_lower2 = nn.Conv2d(4, 8, **conv3x3)
        self.bn2_lower2 = nn.BatchNorm2d(8)

        # Fuse the 8 + 8 concatenated channels back down to 2.
        self.conv3 = nn.Conv2d(16, 4, **conv3x3)
        self.bn3 = nn.BatchNorm2d(4)
        self.conv4 = nn.Conv2d(4, 2, **conv3x3)
        self.bn4 = nn.BatchNorm2d(2)
        self.reconstruction = nn.Conv2d(2, 2, **conv3x3)

    def forward(self, x):
        """Decode ``x`` into a ``(batch, 2, height, width)`` tensor."""
        x = self.dense(x).view(-1, 2, self.height, self.width)
        x = self.relu1(self.bn1(self.conv1(x)))

        upper = self.bn2_upper(self.conv2_upper(x))
        upper = self.bn2_upper2(self.conv2_upper2(upper))

        lower = self.bn2_lower(self.conv2_lower(x))
        lower = self.bn2_lower2(self.conv2_lower2(lower))

        merged = torch.cat([upper, lower], dim=1)
        merged = self.bn3(self.conv3(merged))
        merged = self.bn4(self.conv4(merged))
        out = self.reconstruction(merged)

        # NOTE(review): the term added here is a view of `out` itself, so the
        # result equals 2 * out. If a skip connection from an earlier tensor
        # was intended, this needs revisiting; kept as-is so existing trained
        # weights keep producing the same outputs.
        return out + out.view(-1, 2, self.height, self.width)


class LightweightCNN(nn.Module):
    """Autoencoder that feeds an ``EncoderBlock`` into a ``DecoderBlock``.

    Both halves are built from the same ``(compressed_dim, height, width)``
    arguments; ``forward`` returns the decoder output for the encoded input.
    """

    def __init__(self, compressed_dim, height, width):
        super().__init__()
        shape_args = (compressed_dim, height, width)
        self.encoder = EncoderBlock(*shape_args)
        self.decoder = DecoderBlock(*shape_args)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction."""
        return self.decoder(self.encoder(x))


# Load the data (each .mat file is read through its 'x' entry) and hold out
# 20% of the training samples for validation with a fixed seed.
train_data = torch.from_numpy(sio.loadmat(r"C:\Users\mitra\OneDrive\Desktop\Model4 dataset\Model4 train dataset\UMi_LOS_V_freq_subband_train.mat")['x'])
test_data = torch.from_numpy(sio.loadmat(r"C:\Users\mitra\OneDrive\Desktop\Model4 dataset\Model4 test dataset\UMi_LOS_V_freq_subband_test.mat")['x'])
train_data, val_data = train_test_split(train_data, test_size=0.2, random_state=42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Root folder; one "CR_<cr>" sub-folder with the weights is written per ratio.
save_folder = r"C:\Users\mitra\OneDrive\Desktop\Lightweight model parameters for different CRs\UMi_LOS_freq_subband"

# Sample geometry and training schedule shared by every compression ratio.
height = 32
width = 32
num_epochs = 250
batch_size = 100


def plot_train_loss(train_loss_list):
    """Plot the per-epoch training loss (x axis starts at epoch 1)."""
    plt.plot(range(1, len(train_loss_list) + 1), train_loss_list)
    plt.xlabel('Epoch')
    plt.ylabel('(Training Loss)')
    plt.title('Training Loss vs. Epoch')
    plt.show()


# Normalize the datasets: x / max|x| * 0.5 + 0.5. None of this depends on the
# compression ratio, so it is done once instead of once per ratio.
# NOTE(review): validation and test data are scaled by their own maxima, not
# by the training maximum -- confirm this is intended.
max_abs_train = torch.max(torch.abs(train_data))
max_abs_test = torch.max(torch.abs(test_data))
max_abs_val = torch.max(torch.abs(val_data))

train_data_normalized = train_data / max_abs_train * 0.5 + 0.5
test_data_normalized = test_data / max_abs_test * 0.5 + 0.5
val_data_normalized = val_data / max_abs_val * 0.5 + 0.5

# Reshape to (N, 2, height, width) float32 tensors, the layout the model uses.
compressed_train_data = train_data_normalized.view(-1, 2, height, width).to(torch.float32)
compressed_val_data = val_data_normalized.view(-1, 2, height, width).to(torch.float32)
compressed_test_data = test_data_normalized.view(-1, 2, height, width).to(torch.float32)

# Only the training loader shuffles.
train_loader = DataLoader(compressed_train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(compressed_val_data, batch_size=batch_size)
test_loader = DataLoader(compressed_test_data, batch_size=batch_size)

for cr in compression_ratios:
    print(f"Training, validating, and testing for compression ratio 1/{cr}")

    train_loss_list = []

    # Size of the encoded vector for this ratio.
    compressed_dim = int(2 * height * width / cr)
    model = LightweightCNN(compressed_dim, height, width).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(num_epochs):
        # Switch back to training mode every epoch. The validation pass below
        # puts the model in eval mode; without this call every epoch after
        # the first would train with BatchNorm frozen on its running
        # statistics.
        model.train()
        running_loss = 0.0

        for inputs in train_loader:
            inputs = inputs.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            # Autoencoder objective: reconstruct the input itself.
            loss = criterion(outputs, inputs)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average is per sample.
            running_loss += loss.item() * inputs.size(0)

        train_loss = running_loss / len(train_loader.dataset)
        train_loss_list.append(train_loss)

        # Validate the model
        model.eval()
        running_loss = 0.0

        with torch.no_grad():
            for val_inputs in val_loader:
                val_inputs = val_inputs.to(device)
                val_outputs = model(val_inputs)
                val_loss = criterion(val_outputs, val_inputs)
                running_loss += val_loss.item() * val_inputs.size(0)

        val_loss = running_loss / len(val_loader.dataset)

        print(f"Epoch {epoch + 1}/{num_epochs} - Train Loss: {train_loss} - Val Loss: {val_loss}")
    plot_train_loss(train_loss_list)

    # Test the model
    model.eval()
    running_nmse_linear = 0.0
    running_loss_mse = 0.0

    with torch.no_grad():
        for inputs in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)

            loss = NMSE(outputs, inputs)
            loss_mse = criterion(outputs, inputs)

            # NMSE() returns the batch mean already converted to dB. Averaging
            # dB values directly is not the dB of the mean, so convert back to
            # the linear ratio before accumulating.
            running_nmse_linear += 10 ** (loss.item() / 10) * inputs.size(0)
            running_loss_mse += loss_mse.item() * inputs.size(0)

    # Per-sample averages over the whole test set; NMSE is converted to dB once.
    test_loss = float(10 * np.log10(running_nmse_linear / len(test_loader.dataset)))
    test_loss_mse = running_loss_mse / len(test_loader.dataset)

    print(f"nmse: {test_loss}")
    print(f"Test Loss: {test_loss_mse}")

    # Save the model weights under <save_folder>/CR_<cr>/model_weights.pth
    save_path = os.path.join(save_folder, f"CR_{cr}")
    os.makedirs(save_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_path, "model_weights.pth"))