# Imports

In [None]:
from google.colab import drive
# Colab-only setup: mount Google Drive at /content/gdrive
# (force_remount=True re-mounts even if a mount already exists).
# NOTE(review): drive_path is not referenced in the visible cells; presumably it is
# the project folder on the shared drive used by later cells — confirm per user.
drive_path ='/content/gdrive/Shareddrives/Model_Based_DL/' # TODO - UPDATE ME!
drive.mount('/content/gdrive', force_remount=True)

# System_Model

In [None]:
import numpy as np

class System_model(object):
    """Uniform-linear-array system model used to simulate DOA scenarios.

    Args:
        scenario: Scenario name. Names starting with "NarrowBand" or
            "Broadband" are supported; anything else raises.
        N: Number of sensors in the array.
        M: Number of sources.
        freq_values: ``[min_freq, max_freq]`` (integers); required when the
            scenario starts with "Broadband", ignored otherwise.
    """

    def __init__(self, scenario:str , N:int, M:int, freq_values = None):
        self.scenario = scenario                                    # Narrowband or Broadband
        self.N = N                                                  # Number of sensors in the array
        self.M = M                                                  # Number of sources
        self.scenario_define(freq_values)                           # Define parameters
        self.create_array()                                         # Define array indices

    def scenario_define(self, freq_values):
        """Set the frequency grid and inter-element spacing for the scenario.

        Broadband: integer frequency grid [min_freq, max_freq), sampling rate
        2 * max_freq, a time axis of f_sampling points in [0, 1), and element
        spacing 1 / (2 * max_freq).
        NarrowBand: no frequency grid (attributes set to None), spacing 1/2.

        Raises:
            ValueError: broadband scenario without ``[min_freq, max_freq]``.
            Exception: unknown scenario name.
        """
        if self.scenario.startswith("Broadband"):
            if freq_values is None or len(freq_values) < 2:
                raise ValueError(
                    "Scenario: {} requires freq_values=[min_freq, max_freq]".format(self.scenario))
            ## frequencies initialization ##
            self.min_freq = freq_values[0]   # Define minimal frequency value
            self.max_freq = freq_values[1]   # Define maximal frequency value
            self.f_rng = np.linspace(start=self.min_freq, stop=self.max_freq,
                                     num=self.max_freq - self.min_freq,
                                     endpoint = False)                    # Frequency range of interest
            self.f_sampling = 2 * (self.max_freq)                   # Sampling rate: twice the maximal frequency
            self.time_axis = np.linspace(0, 1, self.f_sampling, endpoint = False)   # f_sampling points in [0, 1)

            ## Array initialization ##
            self.dist = 1 / (2 * self.max_freq)                     # distance between array elements

        elif self.scenario.startswith("NarrowBand"):
            ## frequencies initialization ##
            self.min_freq = None
            self.max_freq = None
            self.f_rng = None
            self.fs = None            # original attribute name, kept for backward compatibility
            self.f_sampling = None    # same attribute name as in the Broadband branch
            self.time_axis = None

            ## Array initialization ##
            self.dist = 1 / 2                                       # distance between array elements
        else:
            raise Exception("Scenario: {} is not defined".format(self.scenario))

    def create_array(self):
        """Set ``self.array`` to the sensor indices 0, 1, ..., N-1 (as floats)."""
        self.array = np.linspace(0, self.N, self.N, endpoint = False)   # create array of sensors locations

    def SV_Creation(self, theta, f=1, Array_form= "ULA"):
        """Return the ULA steering vector for a source at angle ``theta``.

        Args:
            theta: Direction of arrival in radians.
            f: Frequency value; forced to 1 when the scenario is exactly
                "NarrowBand".
            Array_form: Array geometry; only "ULA" is implemented.

        Returns:
            Complex array ``exp(-2j*pi*f*dist*n*sin(theta))`` over sensor index n.

        Raises:
            ValueError: if ``Array_form`` is not "ULA". Previously ``None`` was
                returned silently, which only surfaced later as a shape error.
        """
        if self.scenario == "NarrowBand":
            f = 1
        if Array_form != "ULA":
            raise ValueError("Array_form: {} is not supported".format(Array_form))
        return np.exp(-2 * 1j * np.pi * f * self.dist * self.array * np.sin(theta))

    def __str__(self):
        """Print every attribute and return the literal "End of Model".

        ``print(model)`` therefore shows the attribute summary followed by
        that closing line.
        """
        print("System Model Summery:")
        for key,value in self.__dict__.items():
            print (key, " = " ,value)
        return "End of Model"

# Signal_creation

In [None]:
import numpy as np
from matplotlib import pyplot as plt
# from System_Model import *

def create_DOA_with_gap(M, gap):
    """Draw M random DOAs (degrees) whose sorted neighbours are more than ``gap`` apart.

    Angles are drawn uniformly from [-90, 90) at 0.01-degree resolution and
    re-drawn (rejection sampling) until every pair of adjacent sorted angles
    differs by more than ``gap`` and by less than ``180 - gap``.

    Args:
        M: Number of angles to draw (0 returns an empty array).
        gap: Minimal separation between adjacent angles, in degrees.

    Returns:
        Sorted 1-D NumPy array of M angles in degrees.

    Raises:
        ValueError: if M >= 2 and the constraints cannot be met
            (``gap >= 90`` or ``(M - 1) * gap >= 180``). The rejection loop
            would otherwise never terminate.
    """
    if M >= 2 and (gap >= 90 or (M - 1) * gap >= 180):
        raise ValueError(
            "create_DOA_with_gap: cannot place {} angles with gap > {} degrees in [-90, 90]".format(M, gap))
    # Number of adjacent pairs; max() so that M == 0 (no pairs) is accepted
    # instead of waiting forever for a count of -1.
    n_pairs = max(M - 1, 0)
    while True:
        DOA = np.round(np.random.rand(M) * 180, decimals=2) - 90.00
        DOA.sort()
        difference_between_angles = np.abs(np.diff(DOA))
        if (np.sum(difference_between_angles > gap) == n_pairs
                and np.sum(difference_between_angles < (180 - gap)) == n_pairs):
            return DOA

def create_closely_spaced_DOA(M, gap):
    """Draw angles (degrees) that are clustered within ``gap`` of each other.

    For M == 2 the second angle is the first shifted by ``gap`` (wrapped into
    [-90, 90)). Otherwise angles are drawn one at a time and a candidate is
    accepted when it lies within ``gap`` of every accepted angle, either
    directly or across the +/-90 wrap-around.

    Returns:
        NumPy array of shape (M, 1). For M < 1, a single angle is still returned.
    """
    def _draw_angle():
        # One uniform angle in [-90, 90) degrees, rounded to 0.01; shape (1,).
        return np.round(np.random.rand(1) * 180, decimals=2) - 90.00

    if M == 2:
        anchor = _draw_angle()
        partner = ((anchor + gap + 90) % 180) - 90
        return np.array([anchor, partner])

    accepted = [_draw_angle()]
    while len(accepted) < M:
        candidate = _draw_angle()
        spread = np.array([np.abs(candidate - prev) for prev in accepted])
        count = len(accepted)
        near_directly = np.sum(spread < gap) == count
        near_wrapped = np.sum((180 - spread) < gap) == count
        if near_directly or near_wrapped:
            accepted.append(candidate)
    return np.array(accepted)

class Samples(System_model):
    """Synthetic array-snapshot generator built on top of ``System_model``.

    Args:
        scenario: "NarrowBand", or a name starting with "Broadband_simple" /
            "Broadband_OFDM" (see ``signal_creation``).
        N: Number of sensors.
        M: Number of sources.
        DOA: Source directions in degrees, or None to draw M random angles
            via ``create_DOA_with_gap(M, gap=15)``.
        observations: Number of snapshots T.
        freq_values: ``[min_freq, max_freq]``; needed by broadband scenarios.
    """

    def __init__(self, scenario:str , N:int, M:int,
                 DOA:list, observations:int, freq_values:list = None):
        super().__init__(scenario, N, M, freq_values)
        self.T = observations
        # `is None` rather than `== None`: for a NumPy array, `DOA == None` is
        # element-wise and using it in `if` raises for more than one angle.
        if DOA is None:
            self.DOA = (np.pi / 180) * np.array(create_DOA_with_gap(M = self.M, gap = 15)) # (~0.2 rad)
        else:
            self.DOA = (np.pi / 180) * np.array(DOA)                              # degrees -> radians

    def samples_creation(self, mode, N_mean= 0, N_Var= 1, S_mean= 0, S_Var= 1, SNR= 10):
        '''
        Generate array observations for the configured scenario.

        @mode = represent the specific mode in the specific scenario
                e.g. "Broadband" scenario in "non-coherent" mode

        NarrowBand returns (samples (N, T), signal (M, T), A (N, M), noise (N, T))
        with samples = A @ signal + noise.
        Broadband returns (time-domain samples: first T columns of the IFFT of
        the per-bin observations, frequency-domain signal, stacked per-bin
        steering matrices, frequency-domain noise).
        Any other scenario returns None.
        '''
        if self.scenario.startswith("NarrowBand"):
            if self.M == 0:
                # No sources: an (N, 0) steering matrix times a (0, T) signal is an
                # all-zero (N, T) block, so the samples are pure noise.
                # (The previous `signal = 0` made `A @ signal` raise.)
                signal = np.zeros((0, self.T), dtype=complex)
                noise = self.noise_creation(N_mean, N_Var)
                A = np.zeros((self.N, 0), dtype=complex)
            else:
                signal = self.signal_creation(mode, S_mean, S_Var, SNR)
                noise = self.noise_creation(N_mean, N_Var)
                A = np.array([self.SV_Creation(theta) for theta in self.DOA]).T

            samples = (A @ signal) + noise
            return samples, signal, A, noise

        elif self.scenario.startswith("Broadband"):
            samples = []
            SV = []

            signal = self.signal_creation(mode, S_mean, S_Var, SNR)
            noise = self.noise_creation(N_mean, N_Var)

            # TODO: check if the data creation became much slower
            for idx in range(self.f_sampling):
                # Map FFT bin index to a signed frequency: bins above
                # f_sampling // 2 are the negative frequencies.
                if idx > int(self.f_sampling) // 2:
                    f = - int(self.f_sampling) + idx
                else:
                    f = idx
                A = np.array([self.SV_Creation(theta, f) for theta in self.DOA]).T
                samples.append((A @ signal[:, idx]) + noise[:, idx])
                SV.append(A)
            samples = np.array(samples)
            SV = np.array(SV)
            samples_time_domain = np.fft.ifft(samples.T, axis=1)[:, :self.T]
            return samples_time_domain, signal, SV, noise

    def noise_creation(self, N_mean, N_Var):
        '''
        Complex Gaussian noise: variance N_Var split equally between the real
        and imaginary parts, plus an N_mean offset.

        NarrowBand: time-domain noise of shape (N, T).
        Broadband: FFT (last axis) of time-domain noise of shape (N, len(time_axis)).
        Any other scenario: None.
        '''
        # for NarrowBand scenario Noise represented in the time domain
        if self.scenario.startswith("NarrowBand"):
            return np.sqrt(N_Var) * (np.sqrt(2) / 2) * (np.random.randn(self.N, self.T) + 1j * np.random.randn(self.N, self.T)) + N_mean

        # for Broadband scenario Noise represented in the frequency domain
        elif self.scenario.startswith("Broadband"):
            noise = np.sqrt(N_Var) * (np.sqrt(2) / 2) * (np.random.randn(self.N, len(self.time_axis)) + 1j * np.random.randn(self.N, len(self.time_axis))) + N_mean
            return np.fft.fft(noise)

    def signal_creation(self, mode:str, S_mean = 0, S_Var = 1, SNR = 10):
        '''
        Generate the source signals.

        @mode = represent the specific mode in the specific scenario
                e.g. "Broadband" scenario in "non-coherent" mode

        "NarrowBand": (M, T) complex Gaussian signals; "coherent" repeats one row M times.
        "Broadband_simple": FFT of M single-carrier signals (carriers drawn from f_rng).
        "Broadband_OFDM": FFT of M multi-carrier signals with max_freq sub-carriers.
        Unrecognised scenarios return 0; unrecognised modes fall through
        (returning 0, or None for "Broadband_OFDM").
        '''
        # NOTE(review): 10 ** (SNR / 10) is a power ratio but it scales the
        # amplitude, so the signal power relative to unit-variance noise is
        # 10 ** (SNR / 5), i.e. twice the nominal SNR in dB. Left unchanged so
        # existing datasets stay comparable — confirm intent.
        amplitude = (10 ** (SNR / 10))
        ## NarrowBand signal creation
        if self.scenario == "NarrowBand":
            if mode == "non-coherent":
                # create M non-coherent signals
                return amplitude * (np.sqrt(2) / 2) * np.sqrt(S_Var) * (np.random.randn(self.M, self.T) + 1j * np.random.randn(self.M, self.T)) + S_mean

            elif mode == "coherent":
                # Coherent signals: same amplitude and phase for all signals
                sig = amplitude * (np.sqrt(2) / 2) * np.sqrt(S_Var) * (np.random.randn(1, self.T) + 1j * np.random.randn(1, self.T)) + S_mean
                return np.repeat(sig, self.M, axis = 0)

        ## Broadband signal creation
        if self.scenario.startswith("Broadband_simple"):
            # generate M random carriers
            carriers = np.random.choice(self.f_rng, self.M).reshape((self.M, 1))

            # create M non-coherent signals
            if mode == "non-coherent":
                carriers_amp = amplitude * (np.sqrt(2) / 2) * (np.random.randn(self.M) + 1j * np.random.randn(self.M))
                carriers_signals = carriers_amp * np.exp(2 * np.pi * 1j * carriers @ self.time_axis.reshape((1, len(self.time_axis)))).T
                return np.fft.fft(carriers_signals.T)

            # Coherent signals: same amplitude and phase for all signals
            if mode == "coherent":
                carriers_amp = amplitude * (np.sqrt(2) / 2) * (np.random.randn(1) + 1j * np.random.randn(1))
                carriers_signals = carriers_amp * np.exp(2 * np.pi * 1j * carriers[0] * self.time_axis)
                return np.tile(np.fft.fft(carriers_signals), (self.M, 1))

        ## Broadband OFDM signal creation
        if self.scenario.startswith("Broadband_OFDM"):
            num_sub_carriers = self.max_freq   # number of subcarriers per signal
            # create M non-coherent signals
            signal = np.zeros((self.M, len(self.time_axis))) + 1j * np.zeros((self.M, len(self.time_axis)))
            if mode == "non-coherent":
                for i in range(self.M):
                    for j in range(num_sub_carriers):
                        sig_amp = amplitude * (np.sqrt(2) / 2) * (np.random.randn(1) + 1j * np.random.randn(1))
                        signal[i] += sig_amp * np.exp(1j * 2 * np.pi * j * len(self.f_rng) * self.time_axis / num_sub_carriers)
                    signal[i] *=  (1/num_sub_carriers)
                return np.fft.fft(signal)

            # Coherent signals: same amplitude and phase for all signals
            signal = np.zeros((1, len(self.time_axis))) + 1j * np.zeros((1, len(self.time_axis)))
            if mode == "coherent":
                for j in range(num_sub_carriers):
                    sig_amp = amplitude * (np.sqrt(2) / 2) * (np.random.randn(1) + 1j * np.random.randn(1))
                    signal += sig_amp * np.exp(1j * 2 * np.pi * j * len(self.f_rng) * self.time_axis / num_sub_carriers)
                signal *=  (1/num_sub_carriers)
                return np.tile(np.fft.fft(signal), (self.M, 1))

        else:
            return 0



# DataLoaderCreation - Dor-Aviv

## With DOA

In [None]:
import torch
import numpy as np
# from System_Model import System_model
# from Signal_creation import Samples
from tqdm import tqdm
import random
import h5py

# Module-level default device: first CUDA GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def CreateDataSetCombined_with_DOA(scenario, mode, N, M, T, Sampels_size, tau, Save=False, DataSet_path= None, True_DOA = None, SNR = 10):
    '''
    Build a simulated dataset with a random source count per sample.

    @Scenario = "NarrowBand" or "BroadBand"
    @mode = "coherent", "non-coherent"

    For each of ``Sampels_size`` samples a source count is drawn uniformly
    from [1, N-1]; the ``M`` argument is overwritten by these draws.

    Returns:
        DataSet: list of (X complex64, M as float64 tensor, DOA in radians
            zero-padded to length N-1).
        DataSetRx: list of (Create_Autocorr_tensor_for_data_loader(X, tau)
            cast to float32, M).
        Sys_Model: the Samples object of the last sample (None when
            Sampels_size == 0).

    Raises:
        ValueError: Save=True without DataSet_path. This is checked up front,
            before any data is generated.
    '''
    import os  # needed only for saving; `os` is not imported in this cell

    if Save and DataSet_path is None:
        raise ValueError("DataSet_path must be provided when Save=True")

    DataSet = []
    DataSetRx = []
    Sys_Model = None
    pad_len = N - 1   # maximal number of sources, so every DOA vector has the same length
    print("Updated")
    for _ in tqdm(range(Sampels_size)):
        M = np.random.randint(1, N)   # source count in [1, N-1]
        # NOTE(review): a fixed True_DOA must have M entries; with a random M it
        # will generally not match — confirm intended usage.
        Sys_Model = Samples(scenario= scenario, N= N, M= M, DOA= True_DOA, observations=T, freq_values=[0, 500])
        X = torch.tensor(Sys_Model.samples_creation(mode = mode, N_mean= 0,
                                                    N_Var= 1, S_mean= 0, S_Var= 1,
                                                    SNR= SNR)[0], dtype=torch.complex64)   # Samples Creation

        Y = torch.tensor(M, dtype=torch.float64)
        Z = torch.tensor(Sys_Model.DOA, dtype=torch.float64)

        # Zero-pad the DOAs: different M gives different lengths, which cannot be batched.
        Z_padded = torch.zeros(pad_len, dtype=Z.dtype)
        Z_padded[:len(Z)] = Z

        DataSet.append((X, Y, Z_padded))

        New_Rx_tau = Create_Autocorr_tensor_for_data_loader(X, tau).to(torch.float)
        DataSetRx.append((New_Rx_tau, Y))

    if Save:
        # NOTE(review): M in the file names is the last sample's random M, not a
        # dataset-wide value; kept unchanged so existing file names still match.
        filename_x = "DataSet_x_{}_{}_{}_M={}_N={}_T={}_SNR={}.h5".format(scenario, mode, Sampels_size, M, N, T, SNR)
        filename_rx = "DataSet_Rx_{}_{}_{}_M={}_N={}_T={}_SNR={}.h5".format(scenario, mode, Sampels_size, M, N, T, SNR)
        filename_sys = "Sys_Model_{}_{}_{}_M={}_N={}_T={}_SNR={}.h5".format(scenario, mode, Sampels_size, M, N, T, SNR)

        os.makedirs(DataSet_path, exist_ok=True)

        torch.save(obj=DataSet,   f=os.path.join(DataSet_path, filename_x))
        torch.save(obj=DataSetRx, f=os.path.join(DataSet_path, filename_rx))
        torch.save(obj=Sys_Model, f=os.path.join(DataSet_path, filename_sys))

    return DataSet, DataSetRx ,Sys_Model

## With Steering Matrix

In [None]:
def CreateDataSetCombined_with_Steering_Matrix(args, scenario, mode, N, M, T, Sampels_size, tau, Save=False, DataSet_path= None, True_DOA = None, SNR = 10):
    '''
    Build a simulated dataset whose targets include the zero-padded steering matrix.

    @Scenario = "NarrowBand" or "BroadBand"
    @mode = "coherent", "non-coherent"

    For each of ``Sampels_size`` samples a source count is drawn uniformly
    from [1, N-1], overwriting the ``M`` argument. When
    ``args.TRAIN_MODE and args.Mixed_SNR_in_train`` is true, each sample's
    SNR is also drawn from [0, 2.5, 5, 7.5, 10] dB and overwrites ``SNR``.

    Returns:
        DataSet: list of (X complex64, M as float64 tensor, steering matrix A
            zero-padded to (N, N-1)).
        DataSetRx: list of (Create_Autocorr_tensor_for_data_loader(X, tau)
            cast to float32, M).
        Sys_Model: the Samples object of the last sample (None when
            Sampels_size == 0).

    Raises:
        ValueError: Save=True without DataSet_path. This is checked up front,
            before any data is generated.

    NOTE(review): the (N, N-1) padding assumes a 2-D (N, M) steering matrix as
    produced by the NarrowBand scenario; the Broadband stack of per-bin
    matrices will not fit — confirm this is only used for NarrowBand.
    '''
    import os  # needed only for saving; `os` is not imported in this cell

    if Save and DataSet_path is None:
        raise ValueError("DataSet_path must be provided when Save=True")

    SNR_list = [0,2.5,5,7.5,10]   # candidate SNRs (dB) for mixed-SNR training
    DataSet = []
    DataSetRx = []
    Sys_Model = None
    pad_len = N - 1   # maximal number of sources, so every padded A has the same shape
    print("Updated")
    for _ in tqdm(range(Sampels_size)):
        M = np.random.randint(1, N)   # source count in [1, N-1]

        if args.TRAIN_MODE and args.Mixed_SNR_in_train:
            SNR_idx = np.random.randint(len(SNR_list))
            SNR = SNR_list[SNR_idx]

        Sys_Model = Samples(scenario= scenario, N= N, M= M, DOA= True_DOA, observations=T, freq_values=[0, 500])
        X, _, A, _ = Sys_Model.samples_creation(mode = mode, N_mean= 0,
                                                N_Var= 1, S_mean= 0, S_Var= 1,
                                                SNR= SNR)   # Samples Creation

        X = torch.tensor(X, dtype=torch.complex64)
        A = torch.tensor(A, dtype=torch.complex64)

        Y = torch.tensor(M, dtype=torch.float64)

        # Zero-pad A(theta) columns: different M gives different widths, which cannot be batched.
        A_padded = torch.zeros((N, pad_len), dtype=A.dtype)
        A_padded[:, :A.shape[1]] = A

        DataSet.append((X, Y, A_padded))

        New_Rx_tau = Create_Autocorr_tensor_for_data_loader(X, tau).to(torch.float)
        DataSetRx.append((New_Rx_tau, Y))

    if Save:
        # NOTE(review): M (and SNR in mixed-SNR mode) in the file names are the last
        # sample's random values; kept unchanged so existing file names still match.
        filename_x = "DataSet_x_{}_{}_{}_M={}_N={}_T={}_SNR={}.h5".format(scenario, mode, Sampels_size, M, N, T, SNR)
        filename_rx = "DataSet_Rx_{}_{}_{}_M={}_N={}_T={}_SNR={}.h5".format(scenario, mode, Sampels_size, M, N, T, SNR)
        filename_sys = "Sys_Model_{}_{}_{}_M={}_N={}_T={}_SNR={}.h5".format(scenario, mode, Sampels_size, M, N, T, SNR)

        os.makedirs(DataSet_path, exist_ok=True)

        torch.save(obj=DataSet,   f=os.path.join(DataSet_path, filename_x))
        torch.save(obj=DataSetRx, f=os.path.join(DataSet_path, filename_rx))
        torch.save(obj=Sys_Model, f=os.path.join(DataSet_path, filename_sys))

    return DataSet, DataSetRx ,Sys_Model

# utils

In [None]:
import numpy as np
import torch
import random
import torch.nn as nn


def sum_of_diag(Matrix):
    """Return the sum of every diagonal of a square matrix, lowest diagonal first.

    For an n x n ``Matrix`` the result has 2n - 1 entries: the sum of the
    diagonal at offset k for k = -(n - 1), ..., n - 1 (offset 0 is the main
    diagonal; negative offsets lie below it).

    Returns:
        list of the per-diagonal sums.
    """
    n = Matrix.shape[0]
    # Exact integer offsets. The previous
    # np.linspace(-n + 1, n + 1, 2n - 1, endpoint=False, dtype=int) has a step
    # of 2n / (2n - 1) and only lands on the right integers when the cast to
    # int floors. NumPy < 2.0 truncated toward zero instead, duplicating
    # offset 0 and skipping negative offsets.
    return [np.sum(Matrix.diagonal(offset)) for offset in range(-n + 1, n)]

def find_roots(coeff):
    """Return the roots of a polynomial via the eigenvalues of its companion matrix.

    Args:
        coeff: Polynomial coefficients, highest power first
            (``coeff[0] * z**(n-1) + ... + coeff[-1]``). Needs at least two
            entries and ``coeff[0] != 0``.

    Returns:
        NumPy array of the ``len(coeff) - 1`` (possibly complex) roots, in the
        order produced by ``np.linalg.eigvals``.
    """
    coeff = np.array(coeff)
    # Integer/bool input is promoted to float: the first companion row holds
    # -coeff[1:] / coeff[0], which would otherwise be truncated when written
    # into an integer-typed matrix (e.g. [2, -3, 1] lost its 0.5 root).
    if not np.issubdtype(coeff.dtype, np.inexact):
        coeff = coeff.astype(float)
    A = np.diag(np.ones((len(coeff)-2,), coeff.dtype), -1)
    A[0,:] = -coeff[1:] / coeff[0]
    roots = np.array(np.linalg.eigvals(A))
    return roots

def Set_Overall_Seed(SeedNumber = 42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with one value.

    Args:
        SeedNumber: the seed passed to each generator (default 42).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(SeedNumber)

class L2NormLayer(nn.Module):
    """Scale each slice to unit L2 norm over ``dim``, then add ``eps`` on the diagonal.

    With the default ``dim=(1, 2)`` each matrix in a batch is divided by its
    norm over the last two axes. ``eps`` is used both as the lower bound of
    that norm and as a small diagonal term built from ``x.shape[-1]``.

    Args:
        dim: axes to normalise over (default (1, 2)).
        eps: norm floor and diagonal offset (default 1e-6).
    """

    def __init__(self, dim=(1, 2), eps=1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        unit = nn.functional.normalize(x, p=2, dim=self.dim, eps=self.eps)
        size = x.shape[-1]
        diagonal_offset = self.eps * torch.eye(size, device=x.device)
        return unit + diagonal_offset

# Model Dor-Aviv

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import numpy as np
import warnings
# NOTE(review): this silences *all* Python warnings for the rest of the session,
# including the non-Hermitian warning emitted by gram_diagonal_overload.
warnings.simplefilter("ignore")
# Module-level default device: first CUDA GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Deep_Model_Order_Selectiton_Net(nn.Module):
    """Learned-covariance network for source-count (model-order) estimation and Root-MUSIC DOA.

    Pipeline implemented in ``forward``:
      1. ``pre_processing``: lagged auto-correlations of the centred snapshots,
         real/imag parts stacked -> real tensor [B, tau, 2N, N].
      2. Conv/deconv stack with anti-rectifier activations -> [B, 1, 2N, N],
         split into real/imag halves -> complex Kx [B, N, N].
      3. ``gram_diagonal_overload`` (Kx^H Kx + 0.1 I) and L2 normalisation give
         Rz, which is blended with the (optionally normalised) sample
         covariance using ``residual_coeff``.
      4. ``Model_Order_Selection``: information-criterion score for every
         candidate order m in 1..N-1; ``order_classifier`` maps the scores to logits.
      5. ``Root_MUSIC`` DOA estimates for the true orders (or the estimated
         orders when no labels are given).

    Args:
        args: configuration object. Attributes read: device, N, T, tau,
            penalty_type, batch_size, mode, residual_coeff_coherent,
            residual_coeff_non_coherent, ActivationVal, normalize_eigenvalues,
            norm_Rx.
    """

    def __init__(self, args):

        super(Deep_Model_Order_Selectiton_Net, self).__init__()

        self.args = args
        self.device = args.device
        self.N = args.N
        self.T = args.T
        self.tau = args.tau
        self.penalty_type = args.penalty_type
        self.batch_size = args.batch_size
        self.residual_coeff = args.residual_coeff_coherent if args.mode == "coherent" else args.residual_coeff_non_coherent

        # Input channel counts double after each AntiRectifier (concat of ReLU(x), ReLU(-x)).
        self.conv1 = nn.Conv2d(self.tau, 16, kernel_size = 2)
        self.conv2 = nn.Conv2d(32, 32, kernel_size = 2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size = 2)
        # Identity placeholders for an optional larger variant; nothing in this
        # class replaces them.
        self.extra_conv4 = nn.Identity()
        self.extra_deconv1 = nn.Identity()

        self.deconv2 = nn.ConvTranspose2d(128, 32, kernel_size= 2)
        self.deconv3 = nn.ConvTranspose2d(64, 16, kernel_size= 2)
        self.deconv4 = nn.ConvTranspose2d(32, 1, kernel_size= 2)
        self.ReLU = nn.ReLU()
        self.LeakyReLU = nn.LeakyReLU(args.ActivationVal)   # defined but not used in forward
        self.DropOut = nn.Dropout(0.2)
        self.norm_layer = L2NormLayer()
        self.norm_layer_2 = L2NormLayer()
        self.order_classifier = nn.Sequential(
                            nn.Linear(self.N - 1, 64),
                            nn.ReLU(),
                            nn.Linear(64, 32),
                            nn.ReLU(),
                            nn.Linear(32, self.N - 1)  # Outputs logits for m in {1,...,N-1}
                        )

    def sum_of_diags(self, Matrix):
        """Stack the sums of every diagonal of ``Matrix``, from offset -(n-1) to n-1."""
        coeff =[]
        diag_index = torch.linspace(-Matrix.shape[0] + 1, Matrix.shape[0] - 1, (2 * Matrix.shape[0]) - 1, dtype = int)
        for idx in diag_index:
            coeff.append(torch.sum(torch.diagonal(Matrix, idx)))
        return torch.stack(coeff, dim = 0)

    def find_roots(self, coeff):
        """Polynomial roots (highest power first) as eigenvalues of the companion matrix."""
        A_torch = torch.diag(torch.ones(len(coeff)-2,dtype=coeff.dtype), -1)
        A_torch[0,:] = -coeff[1:] / coeff[0]
        roots = torch.linalg.eigvals(A_torch)
        return roots

    def AntiRectifier(self, X):
        """Concatenate ReLU(X) and ReLU(-X) along dim 1 (doubles the channel count)."""
        return torch.cat((self.ReLU(X), self.ReLU(-X)), 1)

    def hypothesis_testing(self, m, eigenvalues):
        """Information-criterion score of the hypothesis "m sources" for each batch row.

        Args:
            m: candidate number of sources.
            eigenvalues: [B, N] eigenvalues sorted in descending order.

        Returns:
            [B] tensor: T * ((N - m) * log(mean) - sum(log)) over the N - m
            smallest eigenvalues, plus a penalty on Pm = 2*N*m - m^2 + 1
            (Pm * log(T) / 2 for "mdl", 2 * Pm otherwise).
        """
        loss = -self.T * torch.sum(torch.log(eigenvalues[:, m:]), dim=1) + self.T * (self.N - m) * torch.log(torch.mean(eigenvalues[:, m:], dim=1))

        Pm = (2 * self.N * m - m ** 2 + 1)
        if self.penalty_type =="mdl":
            penalty = Pm * np.log(self.T) /2
        else:
            # NOTE(review): 2 * Pm penalty; the original author flagged it for verification.
            penalty = Pm * 2

        return loss + penalty

    def Model_Order_Selection(self, Rx, M_true = None):
        """Score every candidate order and pick the minimum.

        Args:
            Rx: Hermitian matrices [batch_size, N, N].
            M_true: optional true orders [batch_size].

        Returns:
            M_est_batch: [batch_size] estimated orders in 1..N-1.
            l_eig: sum over the batch of the score at the true order, or None
                when M_true is None.
            test_values: [batch_size, N-1] scores (created on the CPU).
        """
        batch_size = Rx.shape[0]

        # eigh returns ascending eigenvalues; reorder to descending per sample.
        eigenvalues, _ = torch.linalg.eigh(Rx)
        descending_indices = torch.argsort(eigenvalues, dim=1, descending=True)
        eigenvalues = torch.gather(eigenvalues, dim=1, index=descending_indices)       # [batch_size, N]

        # Optional per-sample scaling so the largest eigenvalue is 1. Read from
        # self.args: the notebook-global `args` used previously only worked when
        # that global existed and matched this model's configuration.
        if self.args.normalize_eigenvalues:
            eigenvalues_max = eigenvalues.max(dim=1, keepdim=True).values  # [batch_size, 1]
            eigenvalues = eigenvalues / eigenvalues_max

        # test_values[:, m-1] is the score of the hypothesis "m sources", m = 1..N-1.
        test_values = torch.zeros(batch_size, self.N - 1)
        for m in range(1, self.N):
            test_values[:, m - 1] = self.hypothesis_testing(m, eigenvalues)

        M_est_batch = torch.argmin(test_values, dim=1) + 1  # [batch_size]

        if M_true is not None:
            l_eig = 0
            M_true = M_true.long()
            for idx_in_batch, m_true in enumerate(M_true):
                l_eig += test_values[idx_in_batch, m_true - 1]
        else:
            l_eig = None

        return M_est_batch, l_eig, test_values

    def Root_MUSIC(self, Rz, M_True, N):
        """Root-MUSIC DOA estimation for each covariance matrix in the batch.

        Args:
            Rz: covariance matrices [B, N, N].
            M_True: per-sample source counts (length B).
            N: number of sensors; DOA vectors are zero-padded to N - 1.

        Returns:
            (DOAs [B, N-1] in radians, zero-padded; list with one noise-subspace
            eigenvector matrix [N, N-M] per sample).
        """
        dist = 0.5   # element spacing, same value as the NarrowBand System_model
        f = 1
        DOA_list = []
        Un_list = []

        for R, m in zip(Rz, M_True):
            M = int(m)
            eigenvalues, eigenvectors = torch.linalg.eig(R)
            # Noise subspace: eigenvectors of all but the M largest-magnitude eigenvalues.
            Un = eigenvectors[:, torch.argsort(torch.abs(eigenvalues)).flip(0)][:, M:]
            Un_list.append(Un)

            F = torch.matmul(Un, torch.t(torch.conj(Un)))         # noise-subspace projector
            coeff = self.sum_of_diags(F)                          # polynomial coefficients
            roots = self.find_roots(coeff)

            # Order roots by distance to the unit circle (closest first) ...
            roots = roots[sorted(range(roots.shape[0]), key = lambda k : abs(abs(roots[k]) - 1))]
            # ... keep only roots strictly inside the unit circle, and the M closest of those.
            mask = (torch.abs(roots) - 1) < 0
            roots = roots[mask][:M]
            roots_angels = torch.angle(roots)
            DOA_pred = torch.arcsin((1/(2 * np.pi * dist * f)) * roots_angels)   # radians

            # Zero-pad so that every sample's DOA vector has length N - 1.
            DOA_pred_padd = torch.zeros(N - 1, dtype=DOA_pred.dtype)
            DOA_pred_padd[:len(DOA_pred)] = DOA_pred
            DOA_list.append(DOA_pred_padd)

        return torch.stack(DOA_list, dim = 0), Un_list

    def Gramian_matrix(self, Kx, eps):
        '''
        Per-sample loop version of ``gram_diagonal_overload`` (not called by ``forward``):
        Kx_out[b] = Kx[b]^H @ Kx[b] + eps * I
        @ Kx(input) - Complex matrix with shape [BS, N, N]
        @ eps(input) - constant added to each diagonal element
        @ Kx_Out - Hermitian PSD matrix with shape [BS, N, N] on self.device
        '''
        Kx_list = []
        # Iterate over the actual batch; the previous self.BATCH_SIZE attribute does not exist.
        for K in Kx:
            Kx_garm = torch.matmul(torch.t(torch.conj(K)), K).to(self.device)          # [N, N]
            eps_Unit_Mat = (eps * torch.diag(torch.ones(Kx_garm.shape[0]))).to(self.device)
            Kx_list.append(Kx_garm + eps_Unit_Mat)
        return torch.stack(Kx_list, dim = 0)

    def gram_diagonal_overload(self, Kx: torch.Tensor, eps: float):
        """Return Kx^H @ Kx + eps * I for a batch, enforcing exact Hermitian symmetry.

        Args:
            Kx: complex tensor [BS, N, N] (converted with torch.tensor if needed).
            eps: value added to every diagonal element.

        Returns:
            Hermitian PSD tensor [BS, N, N] on ``self.device``. Matrices that
            deviate from Hermitian by more than 1e-6 (only possible through
            floating-point error) are replaced by (R + R^H) / 2, with a warning.
        """
        if not isinstance(Kx, torch.Tensor):
            Kx = torch.tensor(Kx)
        Kx = Kx.to(self.device)

        Kx_garm = torch.bmm(Kx.conj().transpose(1, 2), Kx)
        eps_addition = (eps * torch.diag(torch.ones(Kx_garm.shape[-1]))).to(self.device)
        Kx_Out = Kx_garm + eps_addition

        # check if the matrix is Hermitian - A^H = A
        mask = (torch.abs(Kx_Out - Kx_Out.conj().transpose(1, 2)) > 1e-6)
        if mask.any():
            batch_mask = mask.any(dim=(1,2))
            warnings.warn(f"gram_diagonal_overload: {batch_mask.sum()} matrices in the batch aren't hermitian, taking the average of R and R^H.")
            Kx_Out[batch_mask] = 0.5 * (Kx_Out[batch_mask] + Kx_Out[batch_mask].conj().transpose(1, 2))

        return Kx_Out

    def forward(self, X, M_true = None):
        """Run the full pipeline on a batch of snapshots.

        Args:
            X: complex snapshots [Batch size, N, T].
            M_true: optional true source counts [Batch size]. When given, l_eig
                is computed and Root-MUSIC uses these orders.

        Returns:
            (M_est [B], l_eig (None without labels), DOA_list [B, N-1] in radians
            and zero-padded, Un_list, logits [B, N-1] from order_classifier).
        """
        #----------- Sample covariance of the raw snapshots ------------- #
        Rx_sample_covariance = X @ X.conj().transpose(-2, -1) / self.T
        if self.args.norm_Rx:
            Rx_sample_covariance = self.norm_layer_2(Rx_sample_covariance)

        x0 = self.pre_processing(X)          # [Batch size, tau, 2N, N]

        # CNN block #1
        x = self.conv1(x0)                   # [Batch size, 16, 2N-1, N-1]
        x = self.AntiRectifier(x)            # [Batch size, 32, 2N-1, N-1]
        # CNN block #2
        x = self.conv2(x)                    # [Batch size, 32, 2N-2, N-2]
        x = self.AntiRectifier(x)            # [Batch size, 64, 2N-2, N-2]
        # CNN block #3
        x = self.conv3(x)                    # [Batch size, 64, 2N-3, N-3]
        x = self.AntiRectifier(x)            # [Batch size, 128, 2N-3, N-3]

        # Identity placeholders (see __init__)
        x = self.extra_conv4(x)
        x = self.extra_deconv1(x)

        # DCNN block #2
        x = self.deconv2(x)                  # [Batch size, 32, 2N-2, N-2]
        x = self.AntiRectifier(x)            # [Batch size, 64, 2N-2, N-2]
        # DCNN block #3
        x = self.deconv3(x)                  # [Batch size, 16, 2N-1, N-1]
        x = self.AntiRectifier(x)            # [Batch size, 32, 2N-1, N-1]
        # DCNN block #4
        x = self.DropOut(x)
        Rx = self.deconv4(x)                 # [Batch size, 1, 2N, N]

        Rx_View = Rx.view(Rx.size(0), Rx.size(2), Rx.size(3))   # [Batch size, 2N, N]

        # Top half -> real part, bottom half -> imaginary part.
        Rx_real = Rx_View[:, :self.N, :]
        Rx_imag = Rx_View[:, self.N:, :]
        Kx_tag = torch.complex(Rx_real, Rx_imag)                # [Batch size, N, N]

        # Hermitian PSD surrogate covariance, then L2 normalisation.
        Rz = self.gram_diagonal_overload(Kx_tag, eps= 0.1)
        Rz = self.norm_layer(Rz)

        # Residual blend with the sample covariance.
        Rz = self.residual_coeff * Rx_sample_covariance + (1 - self.residual_coeff) * Rz

        M_est, l_eig, test_values = self.Model_Order_Selection(Rz, M_true)

        test_values = test_values.to(self.device)
        logits = self.order_classifier(test_values)  # [Batch size, N-1]

        # Root-MUSIC needs one source count per sample. Without labels (inference),
        # fall back to the estimated orders instead of failing on len(None).
        music_orders = M_true if M_true is not None else M_est
        DOA_list, Un_list = self.Root_MUSIC(Rz, music_orders, self.N)

        return M_est, l_eig, DOA_list, Un_list, logits

    def pre_processing(self, x):
        """
        Turn complex snapshots [batch, N, T] into the real model input [batch, tau, 2N, N].

        For every lag i < tau, computes the mean-centred lagged correlation
        x[:, :, :T-i] @ conj(x[:, :, i:])^T / (T - i - 1) and stacks its real part
        on top of its imaginary part.

        Args:
        -----
            x (torch.Tensor): complex tensor [batch, N, T]; a single [N, T]
                sample is treated as a batch of one.

        Returns:
        --------
            Rx_tau (torch.Tensor): real tensor [batch, tau, 2N, N] on self.device.
        """
        if x.dim() == 2:
            # Add the batch axis before unpacking the shape; the previous check
            # ran after `batch_size, N, T = x.shape` and so could never apply.
            x = x[None, :, :]
        batch_size, N, T = x.shape
        Rx_tau = torch.zeros(batch_size, self.tau, 2 * N, N, device=self.device)
        meu = torch.mean(x, dim=-1, keepdim=True).to(self.device)
        center_x = x - meu

        for i in range(self.tau):
            x1 = center_x[:, :, :center_x.shape[-1] - i].to(torch.complex128)
            x2 = torch.conj(center_x[:, :, i:]).transpose(1, 2).to(torch.complex128)
            Rx_lag = torch.einsum("BNT, BTM -> BNM", x1, x2) / (center_x.shape[-1] - i - 1)
            Rx_lag = torch.cat((torch.real(Rx_lag), torch.imag(Rx_lag)), dim=1)
            Rx_tau[:, i, :, :] = Rx_lag

        return Rx_tau











# Methods Dor-Aviv

In [None]:
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns


def create_Rx_batch(X_batch,args):
    """Return the sample covariance ``X X^H / args.T`` for each matrix in a batch.

    X_batch: tensor of shape (..., N, T). The result has shape (..., N, N).
    The normalization uses ``args.T``, not the actual trailing size of X_batch.
    """
    X_hermitian = torch.conj(X_batch).transpose(-2, -1)
    covariance = torch.matmul(X_batch, X_hermitian)
    return covariance / args.T


def hypothesis_testing(m, eigenvalues, args):
    """Information-criterion score for the hypothesis "m sources" (Wax & Kailath form).

    Args:
        m: hypothesised number of sources; Model_Order_Selection calls this for
            m = 1 .. args.N - 1.
        eigenvalues: (batch_size, N) eigenvalues sorted in descending order, so
            ``eigenvalues[:, m:]`` are the N - m smallest ("noise") eigenvalues.
        args: needs ``T``, ``N`` and ``penalty_type``. "mdl" selects MDL; any
            other value selects AIC.

    Returns:
        Tensor of shape (batch_size,). It is the negative log-likelihood term
        plus the complexity penalty; lower is better.
    """
    noise_eigs = eigenvalues[:, m:]
    # -log L = T * (N - m) * log(arithmetic mean / geometric mean) of the noise eigenvalues.
    loss = (-args.T * torch.sum(torch.log(noise_eigs), dim=1)
            + args.T * (args.N - m) * torch.log(torch.mean(noise_eigs, dim=1)))

    # Free parameters of the m-source model. The "+1" (noise variance) is constant
    # in m and therefore does not change the argmin.
    Pm = (2 * args.N * m - m ** 2 + 1)
    if args.penalty_type == "mdl":
        # MDL = -log L + (1/2) * Pm * log T
        penalty = Pm * np.log(args.T) / 2
    else:
        # AIC = -2 log L + 2 Pm. The likelihood term above is -log L (not
        # -2 log L), so the matching penalty is Pm. The former `Pm * 2`
        # penalized model complexity twice as much as AIC prescribes.
        penalty = Pm

    return loss + penalty




def Model_Order_Selection(Rx, args):
    """Classic (non-learned) model-order estimate for a batch of covariance matrices.

    Args:
        Rx: tensor of shape (batch_size, N, N). It is passed to torch.linalg.eigh,
            so it is treated as Hermitian.
        args: forwarded to ``hypothesis_testing``. ``args.N`` defines the tested
            hypotheses m = 1 .. N-1.

    Returns:
        int64 tensor of shape (batch_size,) holding the m with the lowest score.
    """
    # eigh returns real eigenvalues; re-order them largest-first for every sample.
    eigvals_desc = torch.sort(torch.linalg.eigh(Rx)[0], dim=1, descending=True).values

    # One column per hypothesis m = 1 .. N-1 (allocated with torch's default dtype/device).
    scores = torch.zeros(Rx.shape[0], args.N - 1)
    for col, m in enumerate(range(1, args.N)):
        scores[:, col] = hypothesis_testing(m, eigvals_desc, args)

    # Column index is 0-based while m starts at 1.
    return torch.argmin(scores, dim=1) + 1





def evaluate_hypothesis_testing_model(Test_data, args, plot=True):
    """Score the classic hypothesis-testing order estimator on a dataset.

    Args:
        Test_data: iterable of (X, M_true, _) batches, e.g. a DataLoader.
        args: needs ``device``, plus whatever create_Rx_batch and
            Model_Order_Selection read.
        plot: currently unused.

    Returns:
        (accuracy, cm): fraction of exact order matches, and the sklearn
        confusion matrix (rows = true order, columns = estimate).
    """
    predictions, labels = [], []

    for X, M_true, _unused_doa in Test_data:
        Rx = create_Rx_batch(X.to(args.device), args)
        predictions.append(Model_Order_Selection(Rx, args).cpu())
        labels.append(M_true.to(args.device).cpu())

    predictions = torch.cat(predictions)
    labels = torch.cat(labels)

    accuracy = (predictions == labels).sum().item() / labels.size(0)
    cm = confusion_matrix(labels.numpy(), predictions.numpy())
    return accuracy, cm





# EvaluationMesures

In [None]:
import numpy as np
import torch.nn as nn
import torch
from itertools import permutations
from torch.autograd import Variable
# Module-level device for the losses in this cell: first CUDA GPU if available, else CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



def Spectrum_loss(A_true, Un_list, M_true):
    """Sum of squared projections of the true steering vectors onto ``Un_list``.

    For each batch element b, only the first ``int(M_true[b])`` columns of
    ``A_true[b]`` are used; the remaining columns are padding. The loss adds
    ||Un_list[b]^H A_true[b][:, :m]||_F^2 over the batch. It is zero when every
    kept steering vector is orthogonal to the columns of ``Un_list[b]``.

    Returns:
        A scalar tensor, or the int 0 for an empty batch.
    """
    total = 0
    for b in range(A_true.shape[0]):
        n_src = int(M_true[b])
        steering = A_true[b][:, :n_src]  # drop zero-padded columns
        proj = torch.matmul(Un_list[b].conj().T, steering)
        # Squared 2-norm of every column, then summed over the sources.
        total += torch.sum(proj.abs() ** 2, dim=0).sum()
    return total




def permute_prediction(prediction):
    """Stack every ordering of ``prediction`` along a new leading dimension.

    Args:
        prediction: tensor whose first dimension (length n) holds the items to permute.

    Returns:
        Tensor of shape (n!, *prediction.shape). Each row is one permutation of
        the first dimension, in ``itertools.permutations`` order.
    """
    n = prediction.shape[0]
    # Build all index orderings at once, on the prediction's own device, so the
    # helper no longer relies on the module-level ``device`` matching the input.
    perm_idx = torch.tensor(list(permutations(range(n))),
                            dtype=torch.int64, device=prediction.device)
    # Advanced indexing with an (n!, n) index tensor gives (n!, n, *rest) and is
    # differentiable, like the former stack of index_select calls.
    return prediction[perm_idx]


class PRMSELoss(nn.Module):
    """Permutation-invariant RMSE between predicted and true DOA vectors.

    For each batch element, every ordering of the predicted angles is compared
    with the targets. The angular error is wrapped modulo pi into [-pi/2, pi/2),
    and the smallest RMSE over the orderings is kept. The per-element minima are
    summed (not averaged) over the batch.
    """

    def __init__(self):
        super(PRMSELoss, self).__init__()

    def forward(self, preds, DOA):
        best_per_sample = []
        for b in range(preds.shape[0]):
            estimate = preds[b].to(device)
            reference = DOA[b].to(device)
            rmse_scale = 1 / np.sqrt(len(reference))
            candidates = torch.stack(
                [rmse_scale * torch.linalg.norm((((ordering - reference) + (np.pi / 2)) % np.pi) - np.pi / 2)
                 for ordering in permute_prediction(estimate).to(device)],
                dim=0,
            )
            best_per_sample.append(torch.min(candidates))
        return torch.sum(torch.stack(best_per_sample, dim=0))



class PRMSELoss_Dor_Aviv(nn.Module):
    """Permutation-invariant RMSE for zero-padded DOA vectors.

    For batch element b, only the first ``int(M_True[b])`` entries of the
    prediction and of the target are compared; the remainder is padding. Errors
    are wrapped modulo pi into [-pi/2, pi/2). The smallest RMSE over all
    orderings of the kept predictions is taken per element, and these minima
    are summed over the batch.
    """

    def __init__(self):
        super(PRMSELoss_Dor_Aviv, self).__init__()

    def forward(self, preds, DOA, M_True):
        best_per_sample = []
        for b in range(preds.shape[0]):
            n_src = int(M_True[b])
            # Keep only the real (non-padded) angles.
            estimate = preds[b][:n_src].to(device)
            reference = DOA[b][:n_src].to(device)
            rmse_scale = 1 / np.sqrt(len(reference))
            candidates = torch.stack(
                [rmse_scale * torch.linalg.norm((((ordering - reference) + (np.pi / 2)) % np.pi) - np.pi / 2)
                 for ordering in permute_prediction(estimate).to(device)],
                dim=0,
            )
            best_per_sample.append(torch.min(candidates))
        return torch.sum(torch.stack(best_per_sample, dim=0))

class PMSELoss(nn.Module):
    """Permutation-invariant MSE between predicted and true DOA vectors.

    Same permutation search and modulo-pi error wrapping as PRMSELoss, but each
    candidate is scored by the mean squared wrapped error instead of its square
    root. The per-element minima are summed over the batch.
    """

    def __init__(self):
        super(PMSELoss, self).__init__()

    def forward(self, preds, DOA):
        best_per_sample = []
        for b in range(preds.shape[0]):
            estimate = preds[b].to(device)
            reference = DOA[b].to(device)
            mse_scale = 1 / len(reference)
            candidates = torch.stack(
                [mse_scale * (torch.linalg.norm((((ordering - reference) + (np.pi / 2)) % np.pi) - np.pi / 2) ** 2)
                 for ordering in permute_prediction(estimate).to(device)],
                dim=0,
            )
            best_per_sample.append(torch.min(candidates))
        return torch.sum(torch.stack(best_per_sample, dim=0))

# Schedulers

In [None]:
from torch.optim.lr_scheduler import LambdaLR

def lr_lambda_increase(epoch, warmup_steps, start_lr, end_lr):
    """LambdaLR multiplier for a linear warm-up.

    The factor ramps linearly from ``start_lr / end_lr`` at epoch 0 to 1.0 at
    ``warmup_steps`` and stays at 1.0 afterwards. LambdaLR multiplies the
    optimizer's base learning rate by this factor.
    """
    floor = start_lr / end_lr  # e.g. 0.01 when warming from 1e-7 to 1e-5
    ramp = floor + (1.0 - floor) * (epoch / warmup_steps) if epoch < warmup_steps else 1.0
    return ramp


def lr_lambda_decrease(epoch, decay_start, start_lr, end_lr, total_epochs):
    """LambdaLR multiplier: constant, then linear decay to ``end_lr / start_lr``.

    Returns 1.0 while ``epoch < decay_start``. After that, the factor decays
    linearly, reaches ``end_lr / start_lr`` at ``total_epochs``, and stays there.
    If the decay window is empty (``decay_start >= total_epochs``), the final
    factor is returned as soon as decay starts.
    """
    if epoch < decay_start:
        return 1.0  # constant phase (start_lr)
    final_factor = end_lr / start_lr
    decay_epochs = total_epochs - decay_start
    if decay_epochs <= 0:
        # Previously this divided by zero (equal epochs) or by a negative span,
        # which made the learning rate *grow*.
        return final_factor
    decay_progress = min((epoch - decay_start) / decay_epochs, 1.0)
    return 1.0 - decay_progress * (1.0 - final_factor)



# Run_Simulation Dor-Aviv

In [None]:
!pip install mplcursors

In [None]:
import torch
import numpy as np
import scipy as sc
import matplotlib.pyplot as plt
import random
import torch.nn as nn
import warnings
import time
import mplcursors
import copy
import torch.optim as optim
from datetime import datetime
from itertools import permutations
from torch.autograd import Variable
from tqdm import tqdm
from torch.optim import lr_scheduler
from sklearn.model_selection import train_test_split
import wandb
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import confusion_matrix
import seaborn as sns
from functools import partial


# from DataLoaderCreation import *
# from Signal_creation import *
# from methods import *
# from models import *
# from EvaluationMesures import *

# Silence all Python warnings for the rest of the session (this also hides deprecation warnings).
warnings.simplefilter("ignore")
# Close any matplotlib figures still open from previously executed cells.
plt.close('all')

def Set_Overall_Seed(SeedNumber = 42):
    """Seed Python's ``random``, NumPy's global RNG and torch with the same value."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(SeedNumber)

# Seed Python, NumPy and torch once when this cell runs (default seed 42).
Set_Overall_Seed()
# Module-level device used by train_model and evaluate_model_MOS below.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# saving_path = r"C:\Users\dorsh\OneDrive\שולחן העבודה\My Drive\Thesis\DeepRootMUSIC\Code\Weights\Models"

def Run_Simulation(args, Model_Train_DataSet, Model_Test_DataSet, Sys_Model):
    """Build, optionally load, train and test a Deep_Model_Order_Selectiton_Net.

    Args:
        args: experiment configuration (see get_options).
        Model_Train_DataSet: samples split into training / validation sets
            (``args.validation_size_ratio`` goes to validation).
        Model_Test_DataSet: held-out samples used for the final accuracy.
        Sys_Model: currently unused; kept for call-site compatibility.

    Raises:
        ValueError: if ``args.use_scheduler`` is set with an unknown
            ``args.scheduler_type``.
    """
    ## Set the seed for all available random operations
    Set_Overall_Seed()
    print("\n----------------------\n")
    print("date and time =", datetime.now().strftime("%d/%m/%Y %H:%M:%S"))

    # --------------------- #
    #  Model initialization #
    # --------------------- #
    model = Deep_Model_Order_Selectiton_Net(args)
    model = model.to(device)

    ## Loading available model (torch.load unpickles: only load trusted checkpoints)
    if args.load_flag == True:
        model.load_state_dict(torch.load(args.pre_trained_model_path, map_location=device))
        print("Loaded Succesfully")

    ## Create an optimizer and, optionally, a learning-rate scheduler
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = None
    if args.use_scheduler:
        if args.scheduler_type == "warm_up_increase":
            lr_schedule = partial(lr_lambda_increase, warmup_steps=args.warmup_steps,
                                  start_lr=args.start_lr, end_lr=args.end_lr)
            scheduler = LambdaLR(optimizer, lr_lambda=lr_schedule)
        elif args.scheduler_type == "warm_up_decrease":
            lr_schedule = partial(lr_lambda_decrease, decay_start=args.epoch_decay_start,
                                  start_lr=args.start_lr, end_lr=args.end_lr,
                                  total_epochs=args.epochs)
            scheduler = LambdaLR(optimizer, lr_lambda=lr_schedule)
        elif args.scheduler_type == "decrease":
            scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=args.scheduler_factor,
                                          patience=args.scheduler_patience, verbose=True)
        else:
            # Previously an unknown type left `scheduler` unbound, and the run
            # crashed later with UnboundLocalError when calling train_model.
            raise ValueError(
                f"Unknown scheduler_type {args.scheduler_type!r}; expected "
                "'warm_up_increase', 'warm_up_decrease' or 'decrease'")

    # ------------------ #
    # Data Organization  #
    # ------------------ #
    Train_DataSet, Valid_DataSet = train_test_split(Model_Train_DataSet, test_size=args.validation_size_ratio, shuffle=True)

    Train_data = torch.utils.data.DataLoader(Train_DataSet, batch_size=args.batch_size,
                                             shuffle=True, drop_last=False)
    Valid_data = torch.utils.data.DataLoader(Valid_DataSet, batch_size=1,
                                             shuffle=False, drop_last=False)
    Test_data = torch.utils.data.DataLoader(Model_Test_DataSet, batch_size=1,
                                            shuffle=False, drop_last=False)

    print("Training DataSet size", len(Train_DataSet))
    print("Validation DataSet size", len(Valid_DataSet))
    print("Test_DataSet", len(Model_Test_DataSet))

    # ----------- #
    # Train Model #
    # ----------- #
    saving_path = os.path.join(args.saving_path, f'{args.mode}', f'loss_type_{args.loss_type}',
                               f'N_{args.N}', f'T_{args.T}',
                               f"SNR_{'mix' if args.Mixed_SNR_in_train else args.SNR}")
    os.makedirs(saving_path, exist_ok=True)  # create the directory if it doesn't exist

    model = train_model(model, Train_data, Valid_data,
                        optimizer, epochs=args.epochs, scheduler=scheduler, data_type=args.data_type,
                        loss_regularization=args.loss_regularization, loss_type=args.loss_type,
                        WANDB_TRACKING=args.WANDB_TRACKING, saving_path=saving_path, args=args)

    # -------------- #
    # Evaluate Model #
    # -------------- #
    Test_accuracy, _ = evaluate_model_MOS(model, Test_data)
    print("Test_accuracy = {}".format(Test_accuracy))





def train_model(model, Train_data, Valid_data,
                optimizer, epochs, saving_path, data_type, loss_regularization, loss_type, args,
                scheduler=None, WANDB_TRACKING=True):
    """Train ``model`` and return it loaded with the weights of the best validation epoch.

    Args:
        model: network whose forward(X, M_true) returns the 5-tuple
            (M_est, l_eig, DOA_list, Un_list, logits). This is the same arity the
            "Steering" branch and evaluate_model_MOS unpack.
        Train_data, Valid_data: iterables of (X, M_true, target) batches. target
            is the true DOA list for data_type "DOA", or the true steering matrix
            for "Steering".
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over Train_data.
        saving_path: directory for the best and last checkpoints.
        data_type: "DOA" or "Steering".
        loss_regularization: weight of the secondary loss term (see branches below).
        loss_type: "cross_entropy", "rpmse", or another value. Other values are
            meaningful only for "Steering", where they mean spectrum loss + l_eig.
        args: experiment configuration.
        scheduler: optional LR scheduler, stepped once per epoch when
            ``args.use_scheduler`` is set.
        WANDB_TRACKING: log losses and learning rate to Weights & Biases.

    Returns:
        The model with the best-validation weights, or with the final weights if
        the validation loss never improved on +inf.

    Raises:
        ValueError: unsupported ``data_type``, or "DOA" without loss_type "rpmse"
            (the DOA loss needs the 3-argument PRMSELoss_Dor_Aviv criterion).
    """
    if data_type not in ("DOA", "Steering"):
        raise ValueError(f"Unsupported data_type {data_type!r}; expected 'DOA' or 'Steering'")
    if data_type == "DOA" and loss_type != "rpmse":
        raise ValueError("data_type 'DOA' requires loss_type 'rpmse'")

    # --------------------------- Criterion --------------------------- #
    # Use the `loss_type` argument consistently. Previously the criterion was
    # built from args.loss_type while the loops branched on `loss_type`.
    criterion = None  # the Steering branch with other loss types needs no criterion
    if loss_type == "cross_entropy":
        # Class k corresponds to model order k + 1. Start with equal weights.
        weights = torch.ones(args.N - 1)
        if args.mode == 'coherent':
            # Coherent case: up-weight every class except order 1 (indices 1 .. N-2).
            weights[1:args.N-1] = 5.0
        criterion = nn.CrossEntropyLoss(weight=weights.to(args.device))  #, label_smoothing=0.1)
    elif loss_type == "rpmse":
        criterion = PRMSELoss_Dor_Aviv()

    # ------------------------ WANDB Tracking ------------------------ #
    if WANDB_TRACKING:
        wandb.finish()
        # SECURITY: an API key used to be hard-coded here. Revoke it; secrets must
        # not live in source. Without `key`, wandb.login() uses the WANDB_API_KEY
        # environment variable or stored credentials, and prompts otherwise.
        wandb.login()

        # Generate a unique timestamp
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

        # Initialize the WandB run
        wandb.init(
            project="MBDL",
            name=f"{args.mode}_{args.loss_type}_T_{args.T}_SNR_{'mix' if args.Mixed_SNR_in_train else args.SNR}_Data_Size{args.nNumberOfSampels}_lr_{args.lr}_batch_{args.batch_size}_resid_coeff_{model.residual_coeff}_normRx_{args.norm_Rx}_{timestamp}",
            config={"epochs": epochs}
        )

    since = time.time()
    min_valid_loss = np.inf
    # Stay None if validation never improves (e.g. NaN losses or epochs == 0).
    # They used to be unbound in that case, and the function crashed at the end.
    best_epoch = None
    best_model_weights = None

    print("\n---Start Training Stage ---\n")

    for epoch in tqdm(range(epochs)):

        # ----------- #
        #  Training   #
        # ----------- #
        model.train()
        Overall_train_loss = 0.0
        model = model.to(device)

        for data in Train_data:

            if data_type == "DOA":
                X, M_true, DOA_list_true = data
                X = X.to(device)
                M_true = M_true.to(device)
                DOA_list_true = DOA_list_true.to(device)

                # forward returns 5 values; the former 4-way unpack raised ValueError.
                _, l_eig, DOA_list_estimation, _, _ = model(X, M_true)

                DOA_loss = criterion(DOA_list_estimation, DOA_list_true, M_true)
                train_loss = DOA_loss + loss_regularization*l_eig

            else:  # "Steering"
                X, M_true, A_true = data
                X = X.to(device)
                M_true = M_true.to(device)
                A_true = A_true.to(device)

                _, l_eig, _, Un_list, test_values = model(X, M_true)
                # The order decision takes argmin of these scores, while
                # CrossEntropyLoss favors the largest logit, so negate them.
                test_values_minus = (-test_values).to(device)

                spectrum_loss = Spectrum_loss(A_true, Un_list, M_true)
                train_loss = loss_regularization*spectrum_loss + l_eig
                if loss_type == "cross_entropy":
                    # Classes are 0-based: order m -> class m - 1.
                    soft_decision_loss = criterion(test_values_minus, M_true.long()-1)
                    train_loss = loss_regularization*spectrum_loss + soft_decision_loss

            ## Update weights. On a failed backward (e.g. a linear-algebra error),
            ## skip this batch's step instead of applying partial gradients.
            try:
                train_loss.backward()
            except RuntimeError:
                print("linalg error")
            else:
                optimizer.step()
            model.zero_grad()

            Overall_train_loss += train_loss.item()

        Overall_train_loss = Overall_train_loss / len(Train_data)  # epoch training loss

        # ----------- #
        #  Validation #
        # ----------- #
        Overall_valid_loss = 0.0
        model.eval()

        with torch.no_grad():
            for data in Valid_data:

                if data_type == "DOA":
                    X, M_true, DOA_list_true = data
                    X = X.to(device)
                    M_true = M_true.to(device)
                    DOA_list_true = DOA_list_true.to(device)

                    _, l_eig, DOA_list_estimation, _, _ = model(X, M_true)

                    DOA_loss = criterion(DOA_list_estimation, DOA_list_true, M_true)
                    # NOTE(review): this weighting differs from the training objective
                    # (DOA_loss + loss_regularization*l_eig); confirm which is intended.
                    eval_loss = loss_regularization*DOA_loss + l_eig

                else:  # "Steering"
                    X, M_true, A_true = data
                    X = X.to(device)
                    M_true = M_true.to(device)
                    A_true = A_true.to(device)

                    _, l_eig, _, Un_list, test_values = model(X, M_true)
                    test_values_minus = (-test_values).to(device)

                    spectrum_loss = Spectrum_loss(A_true, Un_list, M_true)
                    eval_loss = loss_regularization*spectrum_loss + l_eig
                    if loss_type == "cross_entropy":
                        soft_decision_loss = criterion(test_values_minus, M_true.long()-1)
                        eval_loss = loss_regularization*spectrum_loss + soft_decision_loss

                Overall_valid_loss += eval_loss.item()

            Overall_valid_loss = Overall_valid_loss / len(Valid_data)

        if args.use_scheduler:
            if args.scheduler_type in ("warm_up_decrease", "warm_up_increase"):
                scheduler.step()
            else:
                scheduler.step(Overall_valid_loss)  # ReduceLROnPlateau tracks the validation loss

        # --------------------------------- #
        #  Tracking Learning Curves (WANDB)
        # --------------------------------- #
        if WANDB_TRACKING:
            wandb.log({
                "train_loss": Overall_train_loss,
                "val_loss": Overall_valid_loss,
                "lr": optimizer.param_groups[0]['lr'],
            }, step=epoch)

        # --------------------------------- #
        #  Save Best Model
        # --------------------------------- #
        if Overall_valid_loss < min_valid_loss:
            print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{Overall_valid_loss:.6f}) \t Saving The Model')
            min_valid_loss = Overall_valid_loss
            best_epoch = epoch
            best_model_weights = copy.deepcopy(model.state_dict())
            checkpoint_path = os.path.join(saving_path, f'best_model_datasize{args.nNumberOfSampels}_lr{args.lr}_resid_coeff{model.residual_coeff}_normRx_{args.norm_Rx}.pk')
            torch.save(model.state_dict(), checkpoint_path)

    # Save last model
    checkpoint_path = os.path.join(saving_path, f'Last_model_datasize{args.nNumberOfSampels}_lr{args.lr}_resid_coeff{model.residual_coeff}_normRx_{args.norm_Rx}.pk')
    torch.save(model.state_dict(), checkpoint_path)

    if WANDB_TRACKING:
        wandb.finish()

    time_elapsed = time.time() - since
    print("\n--- Training summary ---")
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Minimal Validation loss: {:4f} at epoch {}'.format(min_valid_loss, best_epoch))

    # Return the best model; keep the final weights if validation never improved.
    if best_model_weights is not None:
        model.load_state_dict(best_model_weights)

    return model


def evaluate_model_MOS(model, Data, plot=False):
    """Model-order accuracy of the learned classifier head on ``Data``.

    The predicted order is ``argmin(logits) + 1``, where logits is the 5th model
    output. The model's own M_estimation output is ignored.

    Args:
        model: network whose forward(X, M_true=...) returns a 5-tuple ending in logits.
        Data: iterable of (X, M_true, DOA_list_true) batches.
        plot: currently unused.

    Returns:
        (accuracy, cm): exact-match accuracy, and the sklearn confusion matrix
        (rows = true order, columns = predicted order).
    """
    model.eval()
    predicted, truth = [], []

    with torch.no_grad():  # no gradients are needed for evaluation
        for X, M_true, doa_true in Data:
            X, M_true, doa_true = X.to(device), M_true.to(device), doa_true.to(device)
            _, _, _, _, logits = model(X, M_true=M_true)
            # argmin over the per-order scores, shifted because class 0 <-> order 1
            predicted.append((torch.argmin(logits, dim=1) + 1).cpu())
            truth.append(M_true.cpu())

    predicted = torch.cat(predicted)
    truth = torch.cat(truth)

    accuracy = (predicted == truth).sum().item() / truth.size(0)
    return accuracy, confusion_matrix(truth.numpy(), predicted.numpy())



# Plots

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_accuracy_comparison_vs_snr(args, SNR_values, model_results, baseline_results,use_different_model ):
    """Plot MBDL vs classic hypothesis-testing accuracy as a function of SNR.

    Saves the figure (PNG, 300 dpi) and both accuracy arrays (.npy) under
    ``args.plots_path/<mode>/loss_type_<loss_type>/N_<N>/T_<T>``, then shows it.

    NOTE(review): the .npy file names are fixed and SNR_values are not saved.
    plot_accuracy_comparison_vs_T writes the same .npy names into the same
    directory, so later calls overwrite earlier results.
    """
    out_dir = os.path.join(args.plots_path, f'{args.mode}', f'loss_type_{args.loss_type}',
                           f'N_{args.N}', f'T_{args.T}')
    os.makedirs(out_dir, exist_ok=True)  # create the directory if it doesn't exist

    _, ax = plt.subplots(figsize=(8, 5))
    curve_style = {"linewidth": 1.5, "linestyle": '--'}
    ax.plot(SNR_values, model_results, label="MBDL", marker='*', color='darkviolet', **curve_style)
    ax.plot(SNR_values, baseline_results, label="Classic Hypothesys Testing", marker='o', color='b', **curve_style)

    ax.set_xlabel("SNR (dB)")
    ax.set_ylabel("Accuracy")
    ax.set_title("MBDL vs Hypothesis Testing Accuracy")
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.legend()
    plt.tight_layout()

    # Save the image and the raw accuracy values.
    png_name = f'SNR_VS_ACCURACY_RESULTS_Mixed_SNR_in_train_{args.Mixed_SNR_in_train}_different_model_{use_different_model}.png'
    plt.savefig(os.path.join(out_dir, png_name), dpi=300)
    for file_name, values in (("Our_Results.npy", model_results),
                              ("Classic_Algo_RESULTS.npy", baseline_results)):
        np.save(os.path.join(out_dir, file_name), np.array(values))

    plt.show()

def plot_accuracy_comparison_vs_T(args, T_values, model_results, baseline_results,use_different_model ):
    """Plot MBDL vs classic hypothesis-testing accuracy as a function of T.

    Saves the figure (PNG, 300 dpi) and both accuracy arrays (.npy) under
    ``args.plots_path/<mode>/loss_type_<loss_type>/N_<N>/T_<T>``, then shows it.

    NOTE(review): the .npy file names are fixed and T_values are not saved.
    plot_accuracy_comparison_vs_snr writes the same .npy names into the same
    directory, so later calls overwrite earlier results.
    """
    out_dir = os.path.join(args.plots_path, f'{args.mode}', f'loss_type_{args.loss_type}',
                           f'N_{args.N}', f'T_{args.T}')
    os.makedirs(out_dir, exist_ok=True)  # create the directory if it doesn't exist

    _, ax = plt.subplots(figsize=(8, 5))
    curve_style = {"linewidth": 1.5, "linestyle": '--'}
    ax.plot(T_values, model_results, label="MBDL", marker='*', color='g', **curve_style)
    ax.plot(T_values, baseline_results, label="Classic Hypothesys Testing", marker='o', color='b', **curve_style)

    ax.set_xlabel("T")
    ax.set_ylabel("Accuracy")
    ax.set_title("MBDL vs Hypothesis Testing Accuracy")
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.legend()
    plt.tight_layout()

    # Save the image and the raw accuracy values.
    png_name = f'T_VS_ACCURACY_RESULTS_SNR_{args.SNR}_different_model_{use_different_model}.png'
    plt.savefig(os.path.join(out_dir, png_name), dpi=300)
    for file_name, values in (("Our_Results.npy", model_results),
                              ("Classic_Algo_RESULTS.npy", baseline_results)):
        np.save(os.path.join(out_dir, file_name), np.array(values))

    plt.show()

def plot_cm(cm, use_different_model, type="model", args=None):
    """Draw, save and show a confusion-matrix heatmap.

    Args:
        cm: integer confusion matrix (rows = true label, columns = predicted label).
        use_different_model: only used to tag the saved file name.
        type: "model" for the learned estimator. Any other value is titled as the
            classic hypothesis-testing baseline. Also tags the saved file name.
        args: experiment configuration (``plots_path``, ``mode``, ``loss_type``,
            ``N``, ``T``, ``SNR``). If omitted, the notebook-global ``args`` is
            used, which is what this function previously read implicitly.

    Raises:
        NameError: if ``args`` is neither passed nor defined globally.
    """
    if args is None:
        args = globals().get("args")
        if args is None:
            raise NameError("plot_cm: pass `args` explicitly or define a global `args`")

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    # The title used to say "Hypothesis Testing" even for the learned model.
    if type == "model":
        plt.title('Confusion Matrix (MBDL)')
    else:
        plt.title('Confusion Matrix (Hypothesis Testing)')

    # --------- Save CM ---------- #
    saving_path = os.path.join(args.plots_path, f'{args.mode}', f'loss_type_{args.loss_type}',
                               f'N_{args.N}', f'T_{args.T}', f'SNR_{args.SNR}')
    os.makedirs(saving_path, exist_ok=True)  # create the directory if it doesn't exist
    plt.savefig(os.path.join(saving_path, f'Confusion_Matrix_type_{type}_different_model_{use_different_model}.png'), dpi=300)

    plt.show()



# Config

In [None]:
import argparse
import torch
import os


def get_options(args=None):
    """Parse the experiment configuration.

    Args:
        args: list of command-line tokens. ``None`` makes argparse read
            ``sys.argv[1:]``; pass ``[]`` to get only the defaults.

    Returns:
        argparse.Namespace with every option defined below.

    All defaults are unchanged (argparse applies ``type`` only to string
    defaults). Numeric options now declare a ``type`` and booleans use an
    explicit converter. As a result, command-line overrides such as ``--N 8``
    give the int 8 rather than the string "8". ``--WANDB_TRACKING False`` now
    gives False; with the previous ``type=bool``, any non-empty string,
    including "False", parsed as True.
    """

    def _str2bool(value):
        """Parse common true/false spellings; reject anything else."""
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ("1", "true", "t", "yes", "y"):
            return True
        if text in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    def _int_or_float(value):
        """Return an int for integral text (e.g. "10") and a float otherwise (e.g. "2.5")."""
        try:
            return int(value)
        except ValueError:
            return float(value)

    parser = argparse.ArgumentParser(
        description="Arguments and hyperparameters for main.py")

    # device
    parser.add_argument('--device', type=torch.device,
                        default=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))

    # Paths
    parser.add_argument('--Main_path', default="/content/gdrive/Shareddrives/Model_Based_DL/")
    parser.add_argument('--Main_Data_path', default="/content/gdrive/Shareddrives/Model_Based_DL/DataSet")
    parser.add_argument('--saving_path', default="/content/gdrive/Shareddrives/Model_Based_DL/Weights")
    parser.add_argument('--plots_path', default="/content/gdrive/Shareddrives/Model_Based_DL/Plots")
    parser.add_argument('--Simulations_path', default="/content/gdrive/Shareddrives/Model_Based_DL/Simulations")
    parser.add_argument('--Data_Scenario_path', default="LowSNR")

    parser.add_argument('--pre_trained_model_path', default="/content/gdrive/Shareddrives/Model_Based_DL/Weights/coherent/loss_type_cross_entropy/N_5/T_100/SNR_0/best_model_datasize3000_lr0.0001_resid_coeff0.2_normRx_True.pk")

    # Main Commands
    parser.add_argument('--SAVE_TO_FILE', type=_str2bool, default=False)
    parser.add_argument('--CREATE_DATA', type=_str2bool, default=True)
    parser.add_argument('--LOAD_DATA', type=_str2bool, default=False)
    parser.add_argument('--TRAIN_MODE', type=_str2bool, default=False)
    parser.add_argument('--SAVE_MODEL', type=_str2bool, default=True)
    parser.add_argument('--EVALUATE_MODE', type=_str2bool, default=False)

    parser.add_argument('--Mixed_SNR_in_train', type=_str2bool, default=False)

    # Data Parameters
    parser.add_argument('--data_type', default="Steering")  # "DOA",
    parser.add_argument('--Create_Training_Data', type=_str2bool, default=True)

    parser.add_argument('--tau', type=int, default=8)
    parser.add_argument('--N', type=int, default=5)  # 8
    parser.add_argument('--M', type=int, default=2)
    parser.add_argument('--T', type=int, default=20)
    parser.add_argument('--SNR', type=_int_or_float, default=0)
    parser.add_argument('--nNumberOfSampels', type=int, default=3000)  # 100, 10,000, 3000
    parser.add_argument('--Train_Test_Ratio', type=float, default=0.2)  # 1
    parser.add_argument('--scenario', default="NarrowBand")  # "Broadband_OFDM", "Broadband_simple"
    parser.add_argument('--mode', default="coherent")  # coherent, non-coherent

    # Training parameters
    parser.add_argument('--optimal_gamma_val', type=_int_or_float, default=1)
    parser.add_argument('--loss_type', default="cross_entropy")  # "l_eig", "cross_entropy"

    parser.add_argument('--lr', type=float, default=1e-4)  # 1e-4
    parser.add_argument('--optimal_step', type=_int_or_float, default=1)
    parser.add_argument('--epochs', type=int, default=100)  # 100
    parser.add_argument('--optimizer_name', default="Adam")
    parser.add_argument('--Schedular', type=_str2bool, default=True)
    parser.add_argument('--weight_decay', type=float, default=1e-4)  # 1e-3
    parser.add_argument('--loss_regularization', type=float, default=0)  # 0, 1e-3, 1
    parser.add_argument('--load_flag', type=_str2bool, default=False)
    parser.add_argument('--loading_path', default=r'/content/gdrive/Shareddrives/Model_Based_DL/Weights/model_tau_2_M_2_100Samples_SNR_10_T_2_just_a_test_best_model.pk')

    parser.add_argument('--Plot', type=_str2bool, default=False)
    parser.add_argument('--validation_size_ratio', type=float, default=0.3)
    parser.add_argument('--batch_size', type=int, default=64)

    parser.add_argument('--simulation_saving_path', default="/content/gdrive/Shareddrives/Model_Based_DL/Weights/Models")
    parser.add_argument('--WANDB_TRACKING', type=_str2bool, default=True, help="weather to log training to WANDB")

    # Scheduler Parameters
    parser.add_argument('--use_scheduler', type=_str2bool, default=False, help="weather to use scheduler")

    parser.add_argument('--scheduler_type', default="warm_up_decrease")  # "warm_up_increase", "warm_up_decrease", "decrease"
    parser.add_argument('--warmup_steps', type=int, default=20)  # for the case of "warm_up_increase"
    parser.add_argument('--epoch_decay_start', type=int, default=15)  # for the case of "warm_up_decrease"
    parser.add_argument('--start_lr', type=float, default=1e-4)
    parser.add_argument('--end_lr', type=float, default=1e-6)

    parser.add_argument('--scheduler_factor', type=float, default=0.5, help="in how much to divide the lr")  # for "decrease" (ReduceLROnPlateau)
    parser.add_argument('--scheduler_patience', type=int, default=5,
                        help="amount of episodes with no improvement to wait for lr reduction")

    # Model Parameters
    parser.add_argument('--penalty_type', default="mdl")
    parser.add_argument('--ActivationVal', type=float, default=0.5)
    parser.add_argument('--normalize_eigenvalues', type=_str2bool, default=True)

    parser.add_argument('--residual_coeff_coherent', type=float, default=0.2)
    parser.add_argument('--residual_coeff_non_coherent', type=float, default=0.2)  # 0.5

    parser.add_argument('--norm_Rx', type=_str2bool, default=True)

    opts = parser.parse_args(args)

    return opts

# Main Dor-Aviv

In [None]:
import sys
import torch
import os
import matplotlib.pyplot as plt
import warnings

# from System_Model import *
# from Signal_creation import *
# from DataLoaderCreation import *
# from EvaluationMesures import *
# from methods import *
# from models import *
# from Run_Simulation import *
# from utils import *

# Silence all Python warnings for the rest of the session.
warnings.simplefilter("ignore")
# Close any matplotlib figures still open from earlier cells.
plt.close('all')
# Clear the terminal through the shell: tries `cls` (Windows) and falls back to `clear`.
os.system('cls||clear')

def main(args):
    """Run one create / load / train / evaluate pass as configured by ``args``.

    Stages are enabled by boolean flags on ``args`` and run in this order:
      * ``CREATE_DATA``   - simulate and save the test set, plus the training set
                            when ``Create_Training_Data`` is set.
      * ``LOAD_DATA``     - read saved datasets from
                            ``Main_Data_path/Data_Scenario_path``; anything loaded
                            replaces what ``CREATE_DATA`` produced in this run.
      * ``TRAIN_MODE``    - call ``Run_Simulation`` on the train/test sets.
      * ``EVALUATE_MODE`` - load ``pre_trained_model_path`` and score it, and the
                            hypothesis-testing baseline, on the test set.

    When ``SAVE_TO_FILE`` is set, everything printed during the run goes to
    ``Simulations_path/Results/Scores/<dd_mm_YYYY_HH_MM>.txt``. The file is closed
    and stdout restored when the run ends, even on error.

    Returns:
        ``(Test_acc, ht_acc, cm_model, cm_ht)`` when ``EVALUATE_MODE`` is set,
        otherwise ``None``.

    Raises:
        RuntimeError: if a requested stage needs a dataset that neither
            ``CREATE_DATA`` nor ``LOAD_DATA`` provided.
    """
    import contextlib
    from datetime import datetime

    Set_Overall_Seed()
    dt_string_for_save = datetime.now().strftime("%d_%m_%Y_%H_%M")

    if not args.SAVE_TO_FILE:
        return _run_pipeline(args)

    file_path = os.path.join(args.Simulations_path, "Results", "Scores", dt_string_for_save + ".txt")
    # main() is called repeatedly from the sweep cell: redirect only for this run
    # so the handle is closed and later prints are not captured by a stale file.
    with open(file_path, "w") as log_file, contextlib.redirect_stdout(log_file):
        return _run_pipeline(args)


def _run_pipeline(args):
    """Execute the enabled data, training and evaluation stages of ``main``."""
    datasets = {}
    if args.CREATE_DATA:
        datasets.update(_create_datasets(args))
    if args.LOAD_DATA:
        # Loaded data takes precedence over freshly created data.
        datasets.update(_load_datasets(args))

    if args.TRAIN_MODE:
        _require_datasets(datasets, ("x_train", "x_test", "Sys_Model"), stage="training")
        Run_Simulation(
            args,
            Model_Train_DataSet=datasets["x_train"],
            Model_Test_DataSet=datasets["x_test"],
            Sys_Model=datasets["Sys_Model"],
        )

    if args.EVALUATE_MODE:
        _require_datasets(datasets, ("x_test",), stage="evaluation")
        return _evaluate(args, datasets["x_test"])
    return None


def _create_datasets(args):
    """Simulate and save datasets; the test set is always created, training only on request."""
    Set_Overall_Seed()
    data_root = os.path.join(args.Main_Data_path, args.Data_Scenario_path)
    # Keyword arguments shared by the training and test simulations.
    common_kwargs = dict(
        args=args,
        scenario=args.scenario,
        mode=args.mode,
        N=args.N, M=args.M, T=args.T,
        tau=args.tau,
        Save=True,
        True_DOA=None,
        SNR=args.SNR,
    )

    datasets = {}
    if args.Create_Training_Data:
        datasets["x_train"], _, _ = CreateDataSetCombined_with_Steering_Matrix(
            Sampels_size=args.nNumberOfSampels,
            DataSet_path=os.path.join(data_root, "TrainingData"),
            **common_kwargs,
        )
    datasets["x_test"], _, datasets["Sys_Model"] = CreateDataSetCombined_with_Steering_Matrix(
        Sampels_size=int(args.Train_Test_Ratio * args.nNumberOfSampels),
        DataSet_path=os.path.join(data_root, "TestData"),
        **common_kwargs,
    )

    print("Finished Creating Data...\n")
    return datasets


def _load_datasets(args):
    """Read the saved datasets matching the current scenario/mode/M/N/T/SNR from disk."""
    base_path = os.path.join(args.Main_Data_path, args.Data_Scenario_path)
    details_template = '_{}_{}_{}_M={}_N={}_T={}_SNR={}.h5'
    n_test = int(args.Train_Test_Ratio * args.nNumberOfSampels)
    test_details_line = details_template.format(args.scenario, args.mode, n_test, args.M, args.N, args.T, args.SNR)

    datasets = {
        "x_test": Read_Data(os.path.join(base_path, "TestData", "DataSet_x" + test_details_line)),
        "Sys_Model": Read_Data(os.path.join(base_path, "TestData", "Sys_Model" + test_details_line)),
    }
    if args.TRAIN_MODE:
        # Training consumes DataSet_x_train, which the previous version never loaded.
        # NOTE(review): assumes the training set is saved with the same "DataSet_x"
        # prefix used for the test set — confirm against the saving code.
        train_details_line = details_template.format(
            args.scenario, args.mode, args.nNumberOfSampels, args.M, args.N, args.T, args.SNR)
        datasets["x_train"] = Read_Data(os.path.join(base_path, "TrainingData", "DataSet_x" + train_details_line))
    return datasets


def _require_datasets(datasets, names, stage):
    """Raise a descriptive RuntimeError if any dataset needed by ``stage`` is missing."""
    missing = [name for name in names if name not in datasets]
    if missing:
        raise RuntimeError(
            "{} needs dataset(s) {}; enable LOAD_DATA or CREATE_DATA "
            "(with Create_Training_Data for training data)".format(stage, missing))


def _evaluate(args, test_dataset):
    """Score the pre-trained network and the hypothesis-testing baseline on ``test_dataset``.

    Returns ``(model_accuracy, ht_accuracy, model_confusion_matrix, ht_confusion_matrix)``.
    """
    model = Deep_Model_Order_Selectiton_Net(args).to(args.device)
    model.load_state_dict(torch.load(args.pre_trained_model_path, map_location=args.device))

    # One sample per batch, in order, so both methods see identical inputs.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, drop_last=False)

    test_acc, cm_model = evaluate_model_MOS(model, test_loader)
    print("\n\nDeep_Model_Order_Selectiton_Net Test accuracy on test data = {}".format(test_acc))

    ht_acc, cm_ht = evaluate_hypothesis_testing_model(Test_data=test_loader, args=args)
    print("\n\n Hypothesys testing Accuracy on test data = {}".format(ht_acc))

    return test_acc, ht_acc, cm_model, cm_ht

# Run Train/Evaluation

In [None]:
# ------------------------- Sweep configuration ------------------------- #
SNR_values = [0]
T_values = [10, 20, 50, 70, 100]

# One pre-trained checkpoint per entry of T_values (N=5, SNR=0); only used
# when use_different_model is True (the i-th sweep point loads models_paths[i]).
models_paths = ["/content/gdrive/Shareddrives/Model_Based_DL/Weights/non-coherent/loss_type_cross_entropy/N_5/T_10/SNR_0/best_model_datasize3000_lr0.0001_resid_coeff0.2_normRx_True.pk",
                "/content/gdrive/Shareddrives/Model_Based_DL/Weights/non-coherent/loss_type_cross_entropy/N_5/T_20/SNR_0/best_model_datasize3000_lr0.0001_resid_coeff0.2_normRx_True.pk",
                "/content/gdrive/Shareddrives/Model_Based_DL/Weights/non-coherent/loss_type_cross_entropy/N_5/T_50/SNR_0/best_model_datasize3000_lr0.0001_resid_coeff0.2_normRx_True.pk",
                "/content/gdrive/Shareddrives/Model_Based_DL/Weights/non-coherent/loss_type_cross_entropy/N_5/T_70/SNR_0/best_model_datasize3000_lr0.0001_resid_coeff0.2_normRx_True.pk",
                "/content/gdrive/Shareddrives/Model_Based_DL/Weights/non-coherent/loss_type_cross_entropy/N_5/T_100/SNR_0/best_model_datasize3000_lr0.0001_resid_coeff0.2_normRx_True.pk",
                ]


# ---------------- Evaluate and/or train across SNR and T ---------------- #

args = get_options(args=[])

Training = False
Evaluation = True
Run_through_SNR = False
Run_through_T = True

use_different_model = False

# Accuracies of the most recent evaluation sweep (model vs. hypothesis testing).
model_results = []
ht_results = []

# NOTE(review): if both flags are set, every main() call in the evaluation sweep
# also runs its training stage, because TRAIN_MODE stays set.
if Training:
    args.TRAIN_MODE = True
if Evaluation:
    args.EVALUATE_MODE = True


def _run_evaluation_sweep(opts, param_name, values, use_different_model, models_paths):
    """Run ``main`` once per value of ``opts.<param_name>`` and plot both confusion matrices.

    ``opts.<param_name>`` is overwritten at each step and left at the last value.
    When ``use_different_model`` is True, sweep point i loads ``models_paths[i]``.

    Returns:
        ``(model_accuracies, ht_accuracies)``, one entry per value, in order.

    Raises:
        ValueError: up front, before any evaluation runs, if ``use_different_model``
            is set but there are fewer checkpoints than sweep values.
    """
    if use_different_model and len(models_paths) < len(values):
        raise ValueError(
            f"use_different_model needs one checkpoint per {param_name} value: "
            f"got {len(models_paths)} paths for {len(values)} values")

    sweep_model_results, sweep_ht_results = [], []
    for idx, value in enumerate(values):
        setattr(opts, param_name, value)

        print('-------------------------------------------------------------------------')
        print(f'\n\n--------------- Starting Evaluation For {param_name}={value} ----------------\n\n')
        print('-------------------------------------------------------------------------')

        if use_different_model:
            opts.pre_trained_model_path = models_paths[idx]

        test_acc, ht_acc, cm_model, cm_ht = main(args=opts)
        sweep_model_results.append(test_acc)
        sweep_ht_results.append(ht_acc)

        plot_cm(cm=cm_model, use_different_model=use_different_model, type="model")
        plot_cm(cm=cm_ht, use_different_model=use_different_model, type="HT")

    return sweep_model_results, sweep_ht_results


if Evaluation:
    # Evaluation does not need training data. NOTE(review): this flag stays False
    # for the Training phase below as well when both phases are enabled.
    args.Create_Training_Data = False

    # Each sweep gets its own result lists so the SNR and T plots never mix entries.

    # ----------------- T is fixed and SNR changes -------------------------- #
    if Run_through_SNR:
        model_results, ht_results = _run_evaluation_sweep(
            args, "SNR", SNR_values, use_different_model, models_paths)
        plot_accuracy_comparison_vs_snr(args=args, SNR_values=SNR_values, model_results=model_results,
                                        baseline_results=ht_results, use_different_model=use_different_model)

    # ----------------- SNR is fixed and T changes -------------------------- #
    if Run_through_T:
        model_results, ht_results = _run_evaluation_sweep(
            args, "T", T_values, use_different_model, models_paths)
        plot_accuracy_comparison_vs_T(args=args, T_values=T_values, model_results=model_results,
                                      baseline_results=ht_results, use_different_model=use_different_model)

if Training:
    if args.Mixed_SNR_in_train:
        # A single run covers mixed SNRs, so the current args.SNR value does not matter.
        main(args=args)
    else:
        for snr in SNR_values:
            args.SNR = snr
            for t in T_values:
                args.T = t

                print('-------------------------------------------------------------------------')
                print(f'\n\n--------------- Starting Training For SNR={snr} T={t} ----------------\n\n')
                print('-------------------------------------------------------------------------')
                main(args=args)


