In [6]:
import torch 
import torch.nn as nn

from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from utils import *
from tqdm import tqdm

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
d = 3
m = 8
t = 200
n = 20000
snr_min = 0
snr_max = 20
lamda = 0.2
radius = 0.1
mc_range = 3
coherent = False

array = UCA(m, lamda)
array.build_array(radius=0.1)
array.build_array_manifold()
array.build_transform_matrices_from_array_manifold()

observations, angles = generate_data(n, t, d, snr_min, snr_max, array, mc_range, coherent)

In [8]:
lr = 1e-2
wd = 1e-8
nbEpoches = 200
batchSize = 256

x_train, x_valid, theta_train, theta_valid = train_test_split(observations, angles, test_size=0.2)
x_train, x_test, theta_train, theta_test = train_test_split(x_train, theta_train, test_size=0.2)

train_set = DATASET(x_train, theta_train)
valid_set = DATASET(x_valid, theta_valid)
test_set = DATASET(x_test, theta_test)

train_loader = DataLoader(train_set, batch_size=batchSize, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batchSize, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batchSize, shuffle=False)

In [9]:
class my_model(nn.Module):

    def __init__(self, m: int, d: int, mc_range: int, array):
        
        super().__init__()
        self.m = m
        self.d = d
        self.array = array

        self.W = array.transform_matrices[..., :m-d-1].to(dev)

        self.bn = nn.BatchNorm1d(2*self.m)
        self.rnn = nn.GRU(input_size=2*self.m, hidden_size=2*self.m, num_layers=1, batch_first=True)
        self.fc = nn.Linear(in_features=2*self.m, out_features=2*self.m*self.m)
        self.mlp = nn.Sequential(nn.Linear(in_features=array.nbSamples_spectrum, out_features=2*self.m), nn.ReLU(),
                                 nn.Linear(in_features=2*self.m, out_features=2*self.m), nn.ReLU(),
                                 nn.Linear(in_features=2*self.m, out_features=2*self.m), nn.ReLU(),
                                 nn.Linear(in_features=2*self.m, out_features=self.d))

    def forward(self, X: torch.Tensor):
        
        X = torch.cat((torch.real(X), torch.imag(X)), dim=-1)
        X = self.bn(X.transpose(1, 2)).transpose(1, 2)
        _, X = self.rnn(X)

        cov = self.fc(X[-1])
        cov = cov.reshape(-1, 2, self.m, self.m)
        cov = cov[:, 0, :, :] + 1j * cov[:, 1, :, :]
        cov = cov @ cov.conj().transpose(1, 2)
        vals, vecs = torch.linalg.eigh(cov)
        idx = torch.sort(torch.abs(vals), dim=1)[1].unsqueeze(dim=1).repeat(repeats=(1, self.m, 1))
        vecs = torch.gather(vecs, dim=2, index=idx)
        En = vecs[:, :, :(self.m - self.d)]
        
        Q = torch.einsum('smr,bmk->bsrk', self.W.conj(), En)
        Q = Q @ Q.conj().transpose(-2, -1)
        # spectrum = 1 / torch.abs(torch.linalg.det(Q))
        spectrum = 1 / torch.linalg.eigvalsh(Q)[..., 0]

        theta = self.mlp(spectrum)
        
        return theta

In [10]:
model = my_model(m, d, mc_range, array).to(dev)
# model.load_state_dict(torch.load('model_1_0dB.pt', weights_only=True))
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.2)
Loss, Loss_theta, Val_theta = [], [], []
bestVal = 1000

for i in tqdm(range(nbEpoches)):
    # Train
    running_loss = 0.0
    for data in train_loader:
        X, theta_true = data[0].to(dev), data[1].to(dev)
        optimizer.zero_grad()
        theta_pred = model(X)
        loss = RMSPE(theta_pred, theta_true) 
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    Loss.append(running_loss/len(train_loader))

    # scheduler.step()

    # Validation 
    with torch.no_grad():
        running_loss = 0.0
        for data in valid_loader:
            X, theta_true = data[0].to(dev), data[1].to(dev)
            theta_pred = model(X)
            loss = RMSPE(theta_pred, theta_true)
            running_loss += loss.item()
        
        Val_theta.append(running_loss/len(test_loader))

        if Val_theta[i] < bestVal:
            bestVal = Val_theta[i]
            torch.save(model.state_dict(), 'model_1.pt')

    print("Iteration {}: Loss theta = {}".format(i, Loss[-1]))

  0%|          | 1/200 [00:03<11:02,  3.33s/it]

Iteration 0: Loss theta = 0.5100754737854004


  1%|          | 2/200 [00:06<10:58,  3.33s/it]

Iteration 1: Loss theta = 0.48576536774635315


  2%|▏         | 3/200 [00:09<10:55,  3.33s/it]

Iteration 2: Loss theta = 0.4739738941192627


  2%|▏         | 4/200 [00:13<10:58,  3.36s/it]

Iteration 3: Loss theta = 0.4722721254825592


  2%|▎         | 5/200 [00:16<10:53,  3.35s/it]

Iteration 4: Loss theta = 0.46834520280361175


  3%|▎         | 6/200 [00:20<10:53,  3.37s/it]

Iteration 5: Loss theta = 0.4667129862308502


  4%|▎         | 7/200 [00:23<10:47,  3.35s/it]

Iteration 6: Loss theta = 0.4652575218677521


  4%|▍         | 8/200 [00:26<10:42,  3.35s/it]

Iteration 7: Loss theta = 0.4637552285194397


  4%|▍         | 9/200 [00:30<10:38,  3.34s/it]

Iteration 8: Loss theta = 0.461998433470726


  5%|▌         | 10/200 [00:33<10:36,  3.35s/it]

Iteration 9: Loss theta = 0.4603005677461624


  6%|▌         | 11/200 [00:36<10:31,  3.34s/it]

Iteration 10: Loss theta = 0.45663413166999817


  6%|▌         | 12/200 [00:40<10:25,  3.33s/it]

Iteration 11: Loss theta = 0.4481256878376007


  6%|▋         | 13/200 [00:43<10:24,  3.34s/it]

Iteration 12: Loss theta = 0.45206056952476503


  7%|▋         | 14/200 [00:46<10:20,  3.34s/it]

Iteration 13: Loss theta = 0.4155868375301361


  8%|▊         | 15/200 [00:50<10:16,  3.33s/it]

Iteration 14: Loss theta = 0.4003843951225281


  8%|▊         | 16/200 [00:53<10:12,  3.33s/it]

Iteration 15: Loss theta = 0.3731093168258667


  8%|▊         | 17/200 [00:56<10:10,  3.34s/it]

Iteration 16: Loss theta = 0.35252975702285766


  9%|▉         | 18/200 [01:00<10:07,  3.34s/it]

Iteration 17: Loss theta = 0.33885972321033475


 10%|▉         | 19/200 [01:03<10:05,  3.35s/it]

Iteration 18: Loss theta = 0.33509369134902955


 10%|█         | 20/200 [01:06<10:02,  3.35s/it]

Iteration 19: Loss theta = 0.3205335521697998


 10%|█         | 21/200 [01:10<09:57,  3.34s/it]

Iteration 20: Loss theta = 0.2953817456960678


 11%|█         | 22/200 [01:13<09:53,  3.33s/it]

Iteration 21: Loss theta = 0.2906236010789871


 12%|█▏        | 23/200 [01:16<09:54,  3.36s/it]

Iteration 22: Loss theta = 0.29528583407402037


 12%|█▏        | 24/200 [01:20<10:03,  3.43s/it]

Iteration 23: Loss theta = 0.2939057105779648


 12%|█▎        | 25/200 [01:23<09:57,  3.42s/it]

Iteration 24: Loss theta = 0.2751183408498764


 13%|█▎        | 26/200 [01:27<09:53,  3.41s/it]

Iteration 25: Loss theta = 0.26904989510774613


 14%|█▎        | 27/200 [01:30<09:47,  3.40s/it]

Iteration 26: Loss theta = 0.2990246430039406


 14%|█▍        | 28/200 [01:33<09:38,  3.36s/it]

Iteration 27: Loss theta = 0.28937109410762785


 14%|█▍        | 29/200 [01:37<09:43,  3.41s/it]

Iteration 28: Loss theta = 0.28933595180511473


 15%|█▌        | 30/200 [01:40<09:35,  3.38s/it]

Iteration 29: Loss theta = 0.2602818125486374


 16%|█▌        | 31/200 [01:44<09:34,  3.40s/it]

Iteration 30: Loss theta = 0.2562464174628258


 16%|█▌        | 32/200 [01:47<09:31,  3.40s/it]

Iteration 31: Loss theta = 0.2538759559392929


 16%|█▋        | 33/200 [01:50<09:23,  3.37s/it]

Iteration 32: Loss theta = 0.24228381991386413


 17%|█▋        | 34/200 [01:54<09:21,  3.38s/it]

Iteration 33: Loss theta = 0.2388618278503418


 18%|█▊        | 35/200 [01:57<09:14,  3.36s/it]

Iteration 34: Loss theta = 0.257469439804554


 18%|█▊        | 36/200 [02:00<09:06,  3.33s/it]

Iteration 35: Loss theta = 0.25315178573131564


 18%|█▊        | 37/200 [02:04<09:02,  3.33s/it]

Iteration 36: Loss theta = 0.25003273636102674


 19%|█▉        | 38/200 [02:07<09:06,  3.37s/it]

Iteration 37: Loss theta = 0.23356174290180207


 20%|█▉        | 39/200 [02:11<09:00,  3.35s/it]

Iteration 38: Loss theta = 0.22553827613592148


 20%|██        | 40/200 [02:14<08:56,  3.35s/it]

Iteration 39: Loss theta = 0.21922504246234895


 20%|██        | 41/200 [02:17<08:51,  3.34s/it]

Iteration 40: Loss theta = 0.2311170694231987


 21%|██        | 42/200 [02:21<08:55,  3.39s/it]

Iteration 41: Loss theta = 0.21844761729240417


 22%|██▏       | 43/200 [02:24<08:48,  3.37s/it]

Iteration 42: Loss theta = 0.20690975725650787


 22%|██▏       | 44/200 [02:27<08:39,  3.33s/it]

Iteration 43: Loss theta = 0.2026399153470993


 22%|██▎       | 45/200 [02:31<08:35,  3.33s/it]

Iteration 44: Loss theta = 0.19941852182149888


 23%|██▎       | 46/200 [02:34<08:29,  3.31s/it]

Iteration 45: Loss theta = 0.2003303188085556


 24%|██▎       | 47/200 [02:37<08:33,  3.35s/it]

Iteration 46: Loss theta = 0.20682376742362976


 24%|██▍       | 48/200 [02:41<08:28,  3.34s/it]

Iteration 47: Loss theta = 0.20513490885496138


 24%|██▍       | 49/200 [02:44<08:20,  3.32s/it]

Iteration 48: Loss theta = 0.2046734619140625


 25%|██▌       | 50/200 [02:47<08:15,  3.30s/it]

Iteration 49: Loss theta = 0.21228000313043593


 26%|██▌       | 51/200 [02:50<08:10,  3.29s/it]

Iteration 50: Loss theta = 0.19906631231307984


 26%|██▌       | 52/200 [02:54<08:08,  3.30s/it]

Iteration 51: Loss theta = 0.22842120230197907


 26%|██▋       | 53/200 [02:57<08:04,  3.30s/it]

Iteration 52: Loss theta = 0.23037175387144088


 27%|██▋       | 54/200 [03:00<08:00,  3.29s/it]

Iteration 53: Loss theta = 0.21299811601638793


 28%|██▊       | 55/200 [03:04<07:55,  3.28s/it]

Iteration 54: Loss theta = 0.2022278141975403


 28%|██▊       | 56/200 [03:07<07:50,  3.27s/it]

Iteration 55: Loss theta = 0.194459790289402


 28%|██▊       | 57/200 [03:10<07:46,  3.26s/it]

Iteration 56: Loss theta = 0.19783901751041413


 29%|██▉       | 58/200 [03:13<07:41,  3.25s/it]

Iteration 57: Loss theta = 0.21781790584325791


 30%|██▉       | 59/200 [03:17<07:49,  3.33s/it]

Iteration 58: Loss theta = 0.21203755110502243


 30%|███       | 60/200 [03:20<07:46,  3.34s/it]

Iteration 59: Loss theta = 0.21025512009859085


 30%|███       | 61/200 [03:23<07:38,  3.30s/it]

Iteration 60: Loss theta = 0.19976290494203566


 31%|███       | 62/200 [03:27<07:34,  3.29s/it]

Iteration 61: Loss theta = 0.21817836970090865


 32%|███▏      | 63/200 [03:30<07:29,  3.28s/it]

Iteration 62: Loss theta = 0.2076773962378502


 32%|███▏      | 64/200 [03:33<07:24,  3.27s/it]

Iteration 63: Loss theta = 0.19087860733270645


 32%|███▎      | 65/200 [03:36<07:21,  3.27s/it]

Iteration 64: Loss theta = 0.215469414293766


 33%|███▎      | 66/200 [03:40<07:16,  3.26s/it]

Iteration 65: Loss theta = 0.23174187272787095


 34%|███▎      | 67/200 [03:43<07:12,  3.25s/it]

Iteration 66: Loss theta = 0.20464612424373627


 34%|███▍      | 68/200 [03:46<07:07,  3.24s/it]

Iteration 67: Loss theta = 0.20305191218852997


 34%|███▍      | 69/200 [03:50<07:13,  3.31s/it]

Iteration 68: Loss theta = 0.20862243324518204


 35%|███▌      | 70/200 [03:53<07:13,  3.33s/it]

Iteration 69: Loss theta = 0.21796975076198577


 36%|███▌      | 71/200 [03:56<07:12,  3.35s/it]

Iteration 70: Loss theta = 0.23041904896497725


 36%|███▌      | 72/200 [04:00<07:10,  3.37s/it]

Iteration 71: Loss theta = 0.24856689751148223


 36%|███▋      | 73/200 [04:03<07:09,  3.38s/it]

Iteration 72: Loss theta = 0.4018893361091614


 37%|███▋      | 74/200 [04:06<07:02,  3.35s/it]

Iteration 73: Loss theta = 0.3858343869447708


 38%|███▊      | 75/200 [04:10<06:55,  3.32s/it]

Iteration 74: Loss theta = 0.35878256022930144


 38%|███▊      | 76/200 [04:13<06:49,  3.30s/it]

Iteration 75: Loss theta = 0.33820410311222077


 38%|███▊      | 77/200 [04:16<06:53,  3.36s/it]

Iteration 76: Loss theta = 0.32171649813652037


 39%|███▉      | 78/200 [04:20<06:49,  3.36s/it]

Iteration 77: Loss theta = 0.307651504278183


 40%|███▉      | 79/200 [04:23<06:47,  3.37s/it]

Iteration 78: Loss theta = 0.37055274784564973


 40%|████      | 80/200 [04:27<06:44,  3.37s/it]

Iteration 79: Loss theta = 0.3432330548763275


 40%|████      | 81/200 [04:30<06:43,  3.39s/it]

Iteration 80: Loss theta = 0.3169824683666229


 41%|████      | 82/200 [04:33<06:41,  3.40s/it]

Iteration 81: Loss theta = 0.2997568386793137


 42%|████▏     | 83/200 [04:37<06:36,  3.39s/it]

Iteration 82: Loss theta = 0.2888945162296295


 42%|████▏     | 84/200 [04:40<06:36,  3.42s/it]

Iteration 83: Loss theta = 0.28010410487651827


 42%|████▎     | 85/200 [04:44<06:32,  3.41s/it]

Iteration 84: Loss theta = 0.27277836441993714


 43%|████▎     | 86/200 [04:47<06:29,  3.42s/it]

Iteration 85: Loss theta = 0.2734332764148712


 44%|████▎     | 87/200 [04:50<06:20,  3.37s/it]

Iteration 86: Loss theta = 0.2698019081354141


 44%|████▍     | 88/200 [04:54<06:16,  3.36s/it]

Iteration 87: Loss theta = 0.2622044751048088


 44%|████▍     | 89/200 [04:57<06:13,  3.36s/it]

Iteration 88: Loss theta = 0.25759112030267717


 45%|████▌     | 90/200 [05:00<06:09,  3.36s/it]

Iteration 89: Loss theta = 0.25804042875766753


 46%|████▌     | 91/200 [05:04<06:01,  3.32s/it]

Iteration 90: Loss theta = 0.254401895403862


 46%|████▌     | 92/200 [05:07<06:00,  3.33s/it]

Iteration 91: Loss theta = 0.25510322153568266


 46%|████▋     | 93/200 [05:10<05:57,  3.34s/it]

Iteration 92: Loss theta = 0.24637793481349946


 47%|████▋     | 94/200 [05:14<05:54,  3.34s/it]

Iteration 93: Loss theta = 0.24792832285165786


 48%|████▊     | 95/200 [05:17<05:52,  3.36s/it]

Iteration 94: Loss theta = 0.2542822748422623


 48%|████▊     | 96/200 [05:20<05:45,  3.32s/it]

Iteration 95: Loss theta = 0.24842346906661988


 48%|████▊     | 97/200 [05:24<05:40,  3.31s/it]

Iteration 96: Loss theta = 0.2374110358953476


 49%|████▉     | 98/200 [05:27<05:35,  3.28s/it]

Iteration 97: Loss theta = 0.23473829090595244


 50%|████▉     | 99/200 [05:30<05:31,  3.28s/it]

Iteration 98: Loss theta = 0.24338258653879166


 50%|█████     | 100/200 [05:33<05:29,  3.29s/it]

Iteration 99: Loss theta = 0.24478418946266176


 50%|█████     | 101/200 [05:37<05:25,  3.29s/it]

Iteration 100: Loss theta = 0.23796590209007262


 51%|█████     | 102/200 [05:40<05:24,  3.32s/it]

Iteration 101: Loss theta = 0.23196464866399766


 52%|█████▏    | 103/200 [05:43<05:21,  3.32s/it]

Iteration 102: Loss theta = 0.23416199624538422


 52%|█████▏    | 104/200 [05:47<05:21,  3.35s/it]

Iteration 103: Loss theta = 0.2377351215481758


 52%|█████▎    | 105/200 [05:50<05:18,  3.35s/it]

Iteration 104: Loss theta = 0.23235002845525743


 53%|█████▎    | 106/200 [05:53<05:12,  3.32s/it]

Iteration 105: Loss theta = 0.2378487828373909


 54%|█████▎    | 107/200 [05:57<05:07,  3.31s/it]

Iteration 106: Loss theta = 0.24119318187236785


 54%|█████▍    | 108/200 [06:00<05:04,  3.31s/it]

Iteration 107: Loss theta = 0.2534210321307182


 55%|█████▍    | 109/200 [06:04<05:07,  3.37s/it]

Iteration 108: Loss theta = 0.2296054795384407


 55%|█████▌    | 110/200 [06:07<05:00,  3.34s/it]

Iteration 109: Loss theta = 0.225431107878685


 56%|█████▌    | 111/200 [06:10<04:56,  3.33s/it]

Iteration 110: Loss theta = 0.22553344190120697


 56%|█████▌    | 112/200 [06:13<04:52,  3.33s/it]

Iteration 111: Loss theta = 0.22657135456800462


 56%|█████▋    | 113/200 [06:17<04:49,  3.33s/it]

Iteration 112: Loss theta = 0.21582126319408418


 57%|█████▋    | 114/200 [06:20<04:47,  3.34s/it]

Iteration 113: Loss theta = 0.21958546578884125


 57%|█████▊    | 115/200 [06:23<04:43,  3.33s/it]

Iteration 114: Loss theta = 0.22308641105890273


 58%|█████▊    | 116/200 [06:27<04:40,  3.34s/it]

Iteration 115: Loss theta = 0.2421724584698677


 58%|█████▊    | 117/200 [06:30<04:37,  3.34s/it]

Iteration 116: Loss theta = 0.226211079955101


 59%|█████▉    | 118/200 [06:33<04:30,  3.29s/it]

Iteration 117: Loss theta = 0.23544320285320283


 60%|█████▉    | 119/200 [06:36<04:18,  3.19s/it]

Iteration 118: Loss theta = 0.23896012604236602


 60%|██████    | 120/200 [06:39<04:09,  3.12s/it]

Iteration 119: Loss theta = 0.2385724028944969


 60%|██████    | 121/200 [06:42<04:04,  3.09s/it]

Iteration 120: Loss theta = 0.25972881734371184


 61%|██████    | 122/200 [06:45<03:58,  3.06s/it]

Iteration 121: Loss theta = 0.23949738085269928


 62%|██████▏   | 123/200 [06:48<03:51,  3.00s/it]

Iteration 122: Loss theta = 0.2258579370379448


 62%|██████▏   | 124/200 [06:51<03:46,  2.98s/it]

Iteration 123: Loss theta = 0.21689656406641006


 62%|██████▎   | 125/200 [06:54<03:42,  2.96s/it]

Iteration 124: Loss theta = 0.21959061324596404


 63%|██████▎   | 126/200 [06:57<03:37,  2.94s/it]

Iteration 125: Loss theta = 0.20662089318037033


 64%|██████▎   | 127/200 [07:00<03:33,  2.92s/it]

Iteration 126: Loss theta = 0.20475228279829025


 64%|██████▍   | 128/200 [07:03<03:30,  2.92s/it]

Iteration 127: Loss theta = 0.20220688104629517


 64%|██████▍   | 129/200 [07:06<03:27,  2.93s/it]

Iteration 128: Loss theta = 0.19863269835710526


 65%|██████▌   | 130/200 [07:09<03:25,  2.93s/it]

Iteration 129: Loss theta = 0.19763720095157622


 66%|██████▌   | 131/200 [07:11<03:20,  2.91s/it]

Iteration 130: Loss theta = 0.1931196591258049


 66%|██████▌   | 132/200 [07:14<03:16,  2.89s/it]

Iteration 131: Loss theta = 0.19369332134723663


 66%|██████▋   | 133/200 [07:17<03:14,  2.90s/it]

Iteration 132: Loss theta = 0.19027141809463502


 67%|██████▋   | 134/200 [07:20<03:11,  2.90s/it]

Iteration 133: Loss theta = 0.1872064056992531


 68%|██████▊   | 135/200 [07:23<03:08,  2.90s/it]

Iteration 134: Loss theta = 0.1904434809088707


 68%|██████▊   | 136/200 [07:26<03:04,  2.88s/it]

Iteration 135: Loss theta = 0.1843957069516182


 68%|██████▊   | 137/200 [07:29<03:01,  2.87s/it]

Iteration 136: Loss theta = 0.18394894003868104


 69%|██████▉   | 138/200 [07:32<02:58,  2.88s/it]

Iteration 137: Loss theta = 0.1864926451444626


 70%|██████▉   | 139/200 [07:34<02:55,  2.88s/it]

Iteration 138: Loss theta = 0.17894918352365494


 70%|███████   | 140/200 [07:37<02:52,  2.87s/it]

Iteration 139: Loss theta = 0.17837995141744614


 70%|███████   | 141/200 [07:40<02:48,  2.86s/it]

Iteration 140: Loss theta = 0.17935931503772737


 71%|███████   | 142/200 [07:43<02:45,  2.85s/it]

Iteration 141: Loss theta = 0.17696825951337813


 72%|███████▏  | 143/200 [07:46<02:41,  2.84s/it]

Iteration 142: Loss theta = 0.18031349897384644


 72%|███████▏  | 144/200 [07:49<02:38,  2.84s/it]

Iteration 143: Loss theta = 0.17421179205179216


 72%|███████▎  | 145/200 [07:51<02:36,  2.85s/it]

Iteration 144: Loss theta = 0.17686033606529236


 73%|███████▎  | 146/200 [07:54<02:34,  2.87s/it]

Iteration 145: Loss theta = 0.17366640090942384


 74%|███████▎  | 147/200 [07:57<02:31,  2.87s/it]

Iteration 146: Loss theta = 0.17482514411211014


 74%|███████▍  | 148/200 [08:00<02:30,  2.90s/it]

Iteration 147: Loss theta = 0.1719653668999672


 74%|███████▍  | 149/200 [08:03<02:30,  2.96s/it]

Iteration 148: Loss theta = 0.17048534989356995


 75%|███████▌  | 150/200 [08:06<02:27,  2.95s/it]

Iteration 149: Loss theta = 0.17111847668886185


 76%|███████▌  | 151/200 [08:09<02:24,  2.94s/it]

Iteration 150: Loss theta = 0.1677008807659149


 76%|███████▌  | 152/200 [08:12<02:22,  2.96s/it]

Iteration 151: Loss theta = 0.16920490562915802


 76%|███████▋  | 153/200 [08:15<02:18,  2.94s/it]

Iteration 152: Loss theta = 0.16988923817873


 77%|███████▋  | 154/200 [08:18<02:15,  2.95s/it]

Iteration 153: Loss theta = 0.16816768795251846


 78%|███████▊  | 155/200 [08:21<02:11,  2.93s/it]

Iteration 154: Loss theta = 0.16721186488866807


 78%|███████▊  | 156/200 [08:24<02:09,  2.93s/it]

Iteration 155: Loss theta = 0.16451163619756698


 78%|███████▊  | 157/200 [08:27<02:07,  2.96s/it]

Iteration 156: Loss theta = 0.16827428698539734


 79%|███████▉  | 158/200 [08:30<02:04,  2.96s/it]

Iteration 157: Loss theta = 0.16311472833156584


 80%|███████▉  | 159/200 [08:33<02:01,  2.96s/it]

Iteration 158: Loss theta = 0.16784572154283522


 80%|████████  | 160/200 [08:36<01:58,  2.96s/it]

Iteration 159: Loss theta = 0.16487766206264495


 80%|████████  | 161/200 [08:39<01:55,  2.97s/it]

Iteration 160: Loss theta = 0.16256535917520523


 81%|████████  | 162/200 [08:42<01:52,  2.96s/it]

Iteration 161: Loss theta = 0.16250210583209992


 82%|████████▏ | 163/200 [08:45<01:49,  2.95s/it]

Iteration 162: Loss theta = 0.16329282730817796


 82%|████████▏ | 164/200 [08:48<01:46,  2.95s/it]

Iteration 163: Loss theta = 0.16028216749429702


 82%|████████▎ | 165/200 [08:50<01:43,  2.94s/it]

Iteration 164: Loss theta = 0.16302824079990386


 83%|████████▎ | 166/200 [08:53<01:40,  2.95s/it]

Iteration 165: Loss theta = 0.16079473048448562


 84%|████████▎ | 167/200 [08:56<01:37,  2.94s/it]

Iteration 166: Loss theta = 0.1577913409471512


 84%|████████▍ | 168/200 [08:59<01:33,  2.94s/it]

Iteration 167: Loss theta = 0.15975818783044815


 84%|████████▍ | 169/200 [09:02<01:30,  2.93s/it]

Iteration 168: Loss theta = 0.15845843970775605


 85%|████████▌ | 170/200 [09:05<01:27,  2.93s/it]

Iteration 169: Loss theta = 0.15835634291172027


 86%|████████▌ | 171/200 [09:08<01:24,  2.93s/it]

Iteration 170: Loss theta = 0.16003736436367036


 86%|████████▌ | 172/200 [09:11<01:21,  2.92s/it]

Iteration 171: Loss theta = 0.15913216799497604


 86%|████████▋ | 173/200 [09:14<01:18,  2.91s/it]

Iteration 172: Loss theta = 0.16077278286218644


 87%|████████▋ | 174/200 [09:17<01:15,  2.92s/it]

Iteration 173: Loss theta = 0.15877119839191436


 88%|████████▊ | 175/200 [09:20<01:12,  2.92s/it]

Iteration 174: Loss theta = 0.15823621660470963


 88%|████████▊ | 176/200 [09:23<01:09,  2.91s/it]

Iteration 175: Loss theta = 0.15836333751678466


 88%|████████▊ | 177/200 [09:26<01:06,  2.91s/it]

Iteration 176: Loss theta = 0.16149768561124803


 89%|████████▉ | 178/200 [09:28<01:03,  2.90s/it]

Iteration 177: Loss theta = 0.15728510469198226


 90%|████████▉ | 179/200 [09:31<01:00,  2.90s/it]

Iteration 178: Loss theta = 0.15878450006246567


 90%|█████████ | 180/200 [09:34<00:58,  2.91s/it]

Iteration 179: Loss theta = 0.15472286760807039


 90%|█████████ | 181/200 [09:37<00:55,  2.91s/it]

Iteration 180: Loss theta = 0.1569376364350319


 91%|█████████ | 182/200 [09:40<00:52,  2.91s/it]

Iteration 181: Loss theta = 0.15734679251909256


 92%|█████████▏| 183/200 [09:43<00:49,  2.91s/it]

Iteration 182: Loss theta = 0.1548624101281166


 92%|█████████▏| 184/200 [09:46<00:46,  2.92s/it]

Iteration 183: Loss theta = 0.15297348737716676


 92%|█████████▎| 185/200 [09:49<00:43,  2.91s/it]

Iteration 184: Loss theta = 0.15239366233348847


 93%|█████████▎| 186/200 [09:52<00:40,  2.93s/it]

Iteration 185: Loss theta = 0.1520374086499214


 94%|█████████▎| 187/200 [09:55<00:38,  2.93s/it]

Iteration 186: Loss theta = 0.152179094851017


 94%|█████████▍| 188/200 [09:58<00:35,  2.92s/it]

Iteration 187: Loss theta = 0.1548099061846733


 94%|█████████▍| 189/200 [10:01<00:32,  2.92s/it]

Iteration 188: Loss theta = 0.1525434023141861


 95%|█████████▌| 190/200 [10:03<00:29,  2.92s/it]

Iteration 189: Loss theta = 0.15031650185585022


 96%|█████████▌| 191/200 [10:06<00:26,  2.92s/it]

Iteration 190: Loss theta = 0.1529323935508728


 96%|█████████▌| 192/200 [10:09<00:23,  2.92s/it]

Iteration 191: Loss theta = 0.15011477172374726


 96%|█████████▋| 193/200 [10:12<00:20,  2.92s/it]

Iteration 192: Loss theta = 0.15314338088035584


 97%|█████████▋| 194/200 [10:15<00:17,  2.92s/it]

Iteration 193: Loss theta = 0.15195127546787263


 98%|█████████▊| 195/200 [10:18<00:14,  2.92s/it]

Iteration 194: Loss theta = 0.15048650473356248


 98%|█████████▊| 196/200 [10:21<00:11,  2.92s/it]

Iteration 195: Loss theta = 0.15235178232192992


 98%|█████████▊| 197/200 [10:24<00:08,  2.91s/it]

Iteration 196: Loss theta = 0.14977429270744325


 99%|█████████▉| 198/200 [10:27<00:05,  2.91s/it]

Iteration 197: Loss theta = 0.1497730404138565


100%|█████████▉| 199/200 [10:30<00:02,  2.91s/it]

Iteration 198: Loss theta = 0.15123320788145064


100%|██████████| 200/200 [10:33<00:00,  3.17s/it]

Iteration 199: Loss theta = 0.14964014589786528





In [11]:
model_test = my_model(m, d, mc_range, array).to(dev)
model_test.load_state_dict(torch.load('model_1.pt', weights_only=True))
running_loss = 0.0 

with torch.no_grad():
    for data in test_loader:
        X, theta_true= data[0].to(dev), data[1].to(dev)
        theta_pred = model_test(X)
        loss = RMSPE(theta_pred, theta_true)
        running_loss += loss.item()

Acc_theta = running_loss/len(test_loader)

print("Mixed RMSE DoA =", Acc_theta)

Mixed RMSE DoA = 0.14915905892848969


In [12]:
observations, angles = generate_data(100, t, d, 0, 0, array, mc_range, coherent)
Acc_theta_0 = RMSPE(model_test(observations.to(dev)), angles)

print("0dB RMSE DoA =", Acc_theta_0)


angles_rare = []
for i in range(observations.shape[0]):
    angles_rare.append(RARE(observations[i].T, d, 3, array)[0])
angles_rare = torch.stack(angles_rare, dim=0)

print("RMSPE RARE =", RMSPE(angles_rare.cuda(), angles.cuda()))

0dB RMSE DoA = tensor(0.1903, device='cuda:0', grad_fn=<MeanBackward0>)
RMSPE RARE = tensor(0.2966, device='cuda:0')


In [13]:
observations, angles = generate_data(100, t, d, 5, 5, array, mc_range, coherent)
Acc_theta_5 = RMSPE(model_test(observations.to(dev)), angles)

print("5dB RMSE DoA =", Acc_theta_5)

angles_rare = []
for i in range(observations.shape[0]):
    angles_rare.append(RARE(observations[i].T, d, 3, array)[0])
angles_rare = torch.stack(angles_rare, dim=0)

print("RMSPE RARE =", RMSPE(angles_rare.cuda(), angles.cuda()))

5dB RMSE DoA = tensor(0.1567, device='cuda:0', grad_fn=<MeanBackward0>)
RMSPE RARE = tensor(0.2387, device='cuda:0')


In [14]:
observations, angles = generate_data(100, t, d, 10, 10, array, mc_range, coherent)
Acc_theta_10 = RMSPE(model_test(observations.to(dev)), angles)

print("10dB RMSE DoA =", Acc_theta_10)

angles_rare = []
for i in range(observations.shape[0]):
    angles_rare.append(RARE(observations[i].T, d, 3, array)[0])
angles_rare = torch.stack(angles_rare, dim=0)

print("RMSPE RARE =", RMSPE(angles_rare.cuda(), angles.cuda()))

10dB RMSE DoA = tensor(0.1546, device='cuda:0', grad_fn=<MeanBackward0>)
RMSPE RARE = tensor(0.1963, device='cuda:0')


In [15]:
observations, angles = generate_data(100, t, d, 15, 15, array, mc_range, coherent)
Acc_theta_15 = RMSPE(model_test(observations.to(dev)), angles)

print("15dB RMSE DoA =", Acc_theta_15)

angles_rare = []
for i in range(observations.shape[0]):
    angles_rare.append(RARE(observations[i].T, d, 3, array)[0])
angles_rare = torch.stack(angles_rare, dim=0)

print("RMSPE RARE =", RMSPE(angles_rare.cuda(), angles.cuda()))

15dB RMSE DoA = tensor(0.1365, device='cuda:0', grad_fn=<MeanBackward0>)
RMSPE RARE = tensor(0.1125, device='cuda:0')


In [16]:
observations, angles = generate_data(100, t, d, 20, 20, array, mc_range, coherent)
Acc_theta_20 = RMSPE(model_test(observations.to(dev)), angles)

print("20dB RMSE DoA =", Acc_theta_20)

angles_rare = []
for i in range(observations.shape[0]):
    angles_rare.append(RARE(observations[i].T, d, 3, array)[0])
angles_rare = torch.stack(angles_rare, dim=0)

print("RMSPE RARE =", RMSPE(angles_rare.cuda(), angles.cuda()))

20dB RMSE DoA = tensor(0.1310, device='cuda:0', grad_fn=<MeanBackward0>)
RMSPE RARE = tensor(0.0299, device='cuda:0')
