In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from utils import *
from tqdm import tqdm

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
d = 3
m = 8
t = 200
n = 30000
snr_min = 0
snr_max = 30
lamda = 0.2
radius = 0.1
mc_range = 3
coherent = False

array = UCA(m, lamda)
array.build_array(radius=0.1)
array.build_array_manifold()
array.build_transform_matrices_from_array_manifold()

observations, angles = generate_data(n, t, d, snr_min, snr_max, array, mc_range, coherent)

In [3]:
class my_model(nn.Module):

    def __init__(self, m: int, d: int, array, device: str = dev):
        
        super().__init__()
        self.m = m
        self.d = d
        self.array = array

        self.bn = nn.BatchNorm1d(2*self.m, device=device)
        self.rnn = nn.GRU(input_size=2*self.m, hidden_size=2*self.m, num_layers=1, batch_first=True, device=device)
        self.fc = nn.Linear(in_features=2*self.m, out_features=2*self.m*self.m, device=device)
        self.mlp = nn.Sequential(nn.Linear(in_features=array.nbSamples_spectrum, out_features=2*self.m, device=device), nn.ReLU(),
                                 nn.Linear(in_features=2*self.m, out_features=2*self.m, device=device), nn.ReLU(),
                                 nn.Linear(in_features=2*self.m, out_features=2*self.m, device=device), nn.ReLU(),
                                 nn.Linear(in_features=2*self.m, out_features=self.d, device=device))

    def forward(self, X: torch.Tensor):
        
        X = torch.cat((torch.real(X), torch.imag(X)), dim=-1)
        X = self.bn(X.transpose(1, 2)).transpose(1, 2)
        _, X = self.rnn(X)

        cov = self.fc(X[-1])
        cov = cov.reshape(-1, 2, self.m, self.m)
        cov = cov[:, 0, :, :] + 1j * cov[:, 1, :, :]
        cov = cov @ cov.conj().transpose(1, 2)
        vals, vecs = torch.linalg.eigh(cov)
        idx = torch.sort(torch.abs(vals), dim=1)[1].unsqueeze(dim=1).repeat(repeats=(1, self.m, 1))
        vecs = torch.gather(vecs, dim=2, index=idx)
        En = vecs[:, :, :(self.m - self.d)]
        
        y = get_spectrum(En, array=self.array)
        theta = self.mlp(y)
        
        return theta

In [4]:
nbEpoches = 200
lr = 1e-2
wd = 1e-8
batchSize = 512

x_train, x_valid, theta_train, theta_valid = train_test_split(observations, angles, test_size=0.2)
x_train, x_test, theta_train, theta_test = train_test_split(x_train, theta_train, test_size=0.2)

train_set = DATASET(x_train, theta_train)
valid_set = DATASET(x_valid, theta_valid)
test_set = DATASET(x_test, theta_test)

train_loader = DataLoader(train_set, batch_size=batchSize, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batchSize, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batchSize, shuffle=False)

In [5]:
model = my_model(m, d, array)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.2)
Loss, Loss_theta, Val_theta = [], [], []
bestVal = 1000

for i in tqdm(range(nbEpoches)):
    # Train
    running_loss = 0.0
    for data in train_loader:
        X, theta_true = data[0].to(dev), data[1].to(dev)
        optimizer.zero_grad()
        theta_pred = model(X)
        loss = RMSPE(theta_pred, theta_true) 
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    Loss.append(running_loss/len(train_loader))

    # scheduler.step()

    # Validation 
    with torch.no_grad():
        running_loss = 0.0
        for data in valid_loader:
            X, theta_true = data[0].to(dev), data[1].to(dev)
            theta_pred = model(X)
            loss = RMSPE(theta_pred, theta_true)
            running_loss += loss.item()
        
        Val_theta.append(running_loss/len(test_loader))

        if Val_theta[i] < bestVal:
            bestVal = Val_theta[i]
            torch.save(model.state_dict(), 'model_2.pt')

    print("Iteration {}: Loss theta = {}".format(i, Loss[-1]))

  0%|          | 1/200 [00:03<12:42,  3.83s/it]

Iteration 0: Loss theta = 0.4850638587223856


  1%|          | 2/200 [00:07<12:09,  3.69s/it]

Iteration 1: Loss theta = 0.43088793597723307


  2%|▏         | 3/200 [00:10<11:56,  3.64s/it]

Iteration 2: Loss theta = 0.4045361233384986


  2%|▏         | 4/200 [00:14<11:45,  3.60s/it]

Iteration 3: Loss theta = 0.37121225893497467


  2%|▎         | 5/200 [00:18<11:40,  3.59s/it]

Iteration 4: Loss theta = 0.35123010607142197


  3%|▎         | 6/200 [00:21<11:39,  3.61s/it]

Iteration 5: Loss theta = 0.3094053135106438


  4%|▎         | 7/200 [00:25<11:32,  3.59s/it]

Iteration 6: Loss theta = 0.2845438883492821


  4%|▍         | 8/200 [00:28<11:26,  3.57s/it]

Iteration 7: Loss theta = 0.29592091315671015


  4%|▍         | 9/200 [00:32<11:21,  3.57s/it]

Iteration 8: Loss theta = 0.24605584850436762


  5%|▌         | 10/200 [00:36<11:22,  3.59s/it]

Iteration 9: Loss theta = 0.23160921037197113


  6%|▌         | 11/200 [00:39<11:18,  3.59s/it]

Iteration 10: Loss theta = 0.2326827519818356


  6%|▌         | 12/200 [00:43<11:14,  3.59s/it]

Iteration 11: Loss theta = 0.22532136542232414


  6%|▋         | 13/200 [00:46<11:06,  3.56s/it]

Iteration 12: Loss theta = 0.22925517276713722


  7%|▋         | 14/200 [00:50<11:01,  3.56s/it]

Iteration 13: Loss theta = 0.20442813007455124


  8%|▊         | 15/200 [00:53<10:55,  3.54s/it]

Iteration 14: Loss theta = 0.17871801829651782


  8%|▊         | 16/200 [00:57<10:54,  3.55s/it]

Iteration 15: Loss theta = 0.16831307740587936


  8%|▊         | 17/200 [01:01<10:55,  3.58s/it]

Iteration 16: Loss theta = 0.15552247982276113


  9%|▉         | 18/200 [01:04<10:52,  3.58s/it]

Iteration 17: Loss theta = 0.1472403524737609


 10%|▉         | 19/200 [01:08<10:44,  3.56s/it]

Iteration 18: Loss theta = 0.13707696215102547


 10%|█         | 20/200 [01:11<10:42,  3.57s/it]

Iteration 19: Loss theta = 0.12821339266864876


 10%|█         | 21/200 [01:15<10:35,  3.55s/it]

Iteration 20: Loss theta = 0.11892823972984364


 11%|█         | 22/200 [01:18<10:30,  3.54s/it]

Iteration 21: Loss theta = 0.115876783273722


 12%|█▏        | 23/200 [01:22<10:24,  3.53s/it]

Iteration 22: Loss theta = 0.11379530967066162


 12%|█▏        | 24/200 [01:25<10:30,  3.58s/it]

Iteration 23: Loss theta = 0.10658065582576551


 12%|█▎        | 25/200 [01:29<10:28,  3.59s/it]

Iteration 24: Loss theta = 0.10146746451133176


 13%|█▎        | 26/200 [01:33<10:24,  3.59s/it]

Iteration 25: Loss theta = 0.10634888571343924


 14%|█▎        | 27/200 [01:36<10:20,  3.59s/it]

Iteration 26: Loss theta = 0.09915829900848239


 14%|█▍        | 28/200 [01:40<10:21,  3.61s/it]

Iteration 27: Loss theta = 0.10074253262657869


 14%|█▍        | 29/200 [01:43<10:11,  3.58s/it]

Iteration 28: Loss theta = 0.10036543913577732


 15%|█▌        | 30/200 [01:47<10:08,  3.58s/it]

Iteration 29: Loss theta = 0.09914334236006987


 16%|█▌        | 31/200 [01:50<10:01,  3.56s/it]

Iteration 30: Loss theta = 0.09270775729888364


 16%|█▌        | 32/200 [01:54<09:58,  3.56s/it]

Iteration 31: Loss theta = 0.09375265456343952


 16%|█▋        | 33/200 [01:58<09:51,  3.54s/it]

Iteration 32: Loss theta = 0.08944862512381453


 17%|█▋        | 34/200 [02:01<09:48,  3.54s/it]

Iteration 33: Loss theta = 0.09178556894001208


 18%|█▊        | 35/200 [02:05<09:44,  3.54s/it]

Iteration 34: Loss theta = 0.08959570389829184


 18%|█▊        | 36/200 [02:08<09:40,  3.54s/it]

Iteration 35: Loss theta = 0.0934444018884709


 18%|█▊        | 37/200 [02:12<09:40,  3.56s/it]

Iteration 36: Loss theta = 0.08856908173153274


 19%|█▉        | 38/200 [02:15<09:39,  3.58s/it]

Iteration 37: Loss theta = 0.08735950604865425


 20%|█▉        | 39/200 [02:19<09:36,  3.58s/it]

Iteration 38: Loss theta = 0.08707575715686146


 20%|██        | 40/200 [02:22<09:30,  3.57s/it]

Iteration 39: Loss theta = 0.0859089870201914


 20%|██        | 41/200 [02:26<09:29,  3.58s/it]

Iteration 40: Loss theta = 0.0892042811763914


 21%|██        | 42/200 [02:30<09:27,  3.59s/it]

Iteration 41: Loss theta = 0.08417783324655734


 22%|██▏       | 43/200 [02:33<09:25,  3.60s/it]

Iteration 42: Loss theta = 0.08516464519657586


 22%|██▏       | 44/200 [02:37<09:25,  3.63s/it]

Iteration 43: Loss theta = 0.08506103193289355


 22%|██▎       | 45/200 [02:41<09:19,  3.61s/it]

Iteration 44: Loss theta = 0.08275781120908887


 23%|██▎       | 46/200 [02:44<09:15,  3.61s/it]

Iteration 45: Loss theta = 0.08255263123857348


 24%|██▎       | 47/200 [02:48<09:12,  3.61s/it]

Iteration 46: Loss theta = 0.08711427685461547


 24%|██▍       | 48/200 [02:51<09:10,  3.62s/it]

Iteration 47: Loss theta = 0.08425929554198917


 24%|██▍       | 49/200 [02:55<09:04,  3.61s/it]

Iteration 48: Loss theta = 0.08127156664666377


 25%|██▌       | 50/200 [02:59<09:01,  3.61s/it]

Iteration 49: Loss theta = 0.0879426576981419


 26%|██▌       | 51/200 [03:02<09:02,  3.64s/it]

Iteration 50: Loss theta = 0.08182474872783611


 26%|██▌       | 52/200 [03:06<08:59,  3.64s/it]

Iteration 51: Loss theta = 0.0830942839384079


 26%|██▋       | 53/200 [03:10<08:51,  3.62s/it]

Iteration 52: Loss theta = 0.08272926078030937


 27%|██▋       | 54/200 [03:13<08:43,  3.58s/it]

Iteration 53: Loss theta = 0.08038402878140148


 28%|██▊       | 55/200 [03:17<08:40,  3.59s/it]

Iteration 54: Loss theta = 0.08230364322662354


 28%|██▊       | 56/200 [03:20<08:36,  3.59s/it]

Iteration 55: Loss theta = 0.07713125881395842


 28%|██▊       | 57/200 [03:24<08:41,  3.65s/it]

Iteration 56: Loss theta = 0.08109527885129578


 29%|██▉       | 58/200 [03:28<08:33,  3.62s/it]

Iteration 57: Loss theta = 0.08043012513141883


 30%|██▉       | 59/200 [03:31<08:27,  3.60s/it]

Iteration 58: Loss theta = 0.08029455691576004


 30%|███       | 60/200 [03:35<08:21,  3.58s/it]

Iteration 59: Loss theta = 0.07943474383730638


 30%|███       | 61/200 [03:38<08:15,  3.56s/it]

Iteration 60: Loss theta = 0.08016109231271241


 31%|███       | 62/200 [03:42<08:13,  3.58s/it]

Iteration 61: Loss theta = 0.07647812542946715


 32%|███▏      | 63/200 [03:45<08:12,  3.60s/it]

Iteration 62: Loss theta = 0.07554376007694948


 32%|███▏      | 64/200 [03:49<08:05,  3.57s/it]

Iteration 63: Loss theta = 0.07708124894844859


 32%|███▎      | 65/200 [03:53<08:06,  3.61s/it]

Iteration 64: Loss theta = 0.07826805585309078


 33%|███▎      | 66/200 [03:56<07:58,  3.57s/it]

Iteration 65: Loss theta = 0.07590231966031225


 34%|███▎      | 67/200 [04:00<07:51,  3.55s/it]

Iteration 66: Loss theta = 0.0755285938319407


 34%|███▍      | 68/200 [04:03<07:48,  3.55s/it]

Iteration 67: Loss theta = 0.07878662332108147


 34%|███▍      | 69/200 [04:07<07:42,  3.53s/it]

Iteration 68: Loss theta = 0.08074500235287767


 35%|███▌      | 70/200 [04:10<07:36,  3.51s/it]

Iteration 69: Loss theta = 0.07513492025042835


 36%|███▌      | 71/200 [04:14<07:33,  3.52s/it]

Iteration 70: Loss theta = 0.07618320674488419


 36%|███▌      | 72/200 [04:17<07:32,  3.53s/it]

Iteration 71: Loss theta = 0.07651241143283091


 36%|███▋      | 73/200 [04:21<07:29,  3.54s/it]

Iteration 72: Loss theta = 0.073864000408273


 37%|███▋      | 74/200 [04:24<07:28,  3.56s/it]

Iteration 73: Loss theta = 0.07690526270552685


 38%|███▊      | 75/200 [04:28<07:24,  3.56s/it]

Iteration 74: Loss theta = 0.07524947645632844


 38%|███▊      | 76/200 [04:32<07:27,  3.61s/it]

Iteration 75: Loss theta = 0.07534583050169442


 38%|███▊      | 77/200 [04:35<07:22,  3.59s/it]

Iteration 76: Loss theta = 0.07198007345983856


 39%|███▉      | 78/200 [04:39<07:20,  3.61s/it]

Iteration 77: Loss theta = 0.0751533298508117


 40%|███▉      | 79/200 [04:43<07:19,  3.63s/it]

Iteration 78: Loss theta = 0.07646961392540681


 40%|████      | 80/200 [04:46<07:16,  3.64s/it]

Iteration 79: Loss theta = 0.07239034987593952


 40%|████      | 81/200 [04:50<07:12,  3.63s/it]

Iteration 80: Loss theta = 0.0759464300384647


 41%|████      | 82/200 [04:53<07:07,  3.62s/it]

Iteration 81: Loss theta = 0.07861913681814545


 42%|████▏     | 83/200 [04:57<07:00,  3.59s/it]

Iteration 82: Loss theta = 0.07429116474170434


 42%|████▏     | 84/200 [05:01<06:57,  3.60s/it]

Iteration 83: Loss theta = 0.07254177331924438


 42%|████▎     | 85/200 [05:04<06:58,  3.64s/it]

Iteration 84: Loss theta = 0.0722206293752319


 43%|████▎     | 86/200 [05:08<06:57,  3.67s/it]

Iteration 85: Loss theta = 0.070026010470955


 44%|████▎     | 87/200 [05:12<06:55,  3.67s/it]

Iteration 86: Loss theta = 0.07421263111265082


 44%|████▍     | 88/200 [05:15<06:51,  3.68s/it]

Iteration 87: Loss theta = 0.0721693478132549


 44%|████▍     | 89/200 [05:19<06:47,  3.67s/it]

Iteration 88: Loss theta = 0.07143578560728776


 45%|████▌     | 90/200 [05:23<06:42,  3.66s/it]

Iteration 89: Loss theta = 0.07234720199515945


 46%|████▌     | 91/200 [05:26<06:37,  3.64s/it]

Iteration 90: Loss theta = 0.07112081995920132


 46%|████▌     | 92/200 [05:30<06:32,  3.63s/it]

Iteration 91: Loss theta = 0.07230612361117413


 46%|████▋     | 93/200 [05:34<06:31,  3.66s/it]

Iteration 92: Loss theta = 0.07087912312463711


 47%|████▋     | 94/200 [05:37<06:26,  3.65s/it]

Iteration 93: Loss theta = 0.07163271621653908


 48%|████▊     | 95/200 [05:41<06:24,  3.66s/it]

Iteration 94: Loss theta = 0.07103021756598824


 48%|████▊     | 96/200 [05:45<06:18,  3.64s/it]

Iteration 95: Loss theta = 0.07076715207413624


 48%|████▊     | 97/200 [05:48<06:14,  3.63s/it]

Iteration 96: Loss theta = 0.0687786202485624


 49%|████▉     | 98/200 [05:52<06:10,  3.63s/it]

Iteration 97: Loss theta = 0.06943587390215773


 50%|████▉     | 99/200 [05:55<06:06,  3.63s/it]

Iteration 98: Loss theta = 0.06949463388637493


 50%|█████     | 100/200 [05:59<06:02,  3.62s/it]

Iteration 99: Loss theta = 0.06718951306845013


 50%|█████     | 101/200 [06:03<05:56,  3.60s/it]

Iteration 100: Loss theta = 0.06870085824477046


 51%|█████     | 102/200 [06:06<05:54,  3.62s/it]

Iteration 101: Loss theta = 0.06888717785477638


 52%|█████▏    | 103/200 [06:10<05:49,  3.61s/it]

Iteration 102: Loss theta = 0.06985262604920488


 52%|█████▏    | 104/200 [06:13<05:43,  3.58s/it]

Iteration 103: Loss theta = 0.07030271425058968


 52%|█████▎    | 105/200 [06:17<05:39,  3.57s/it]

Iteration 104: Loss theta = 0.07008386991525951


 53%|█████▎    | 106/200 [06:21<05:42,  3.65s/it]

Iteration 105: Loss theta = 0.06929829971570718


 54%|█████▎    | 107/200 [06:24<05:36,  3.62s/it]

Iteration 106: Loss theta = 0.0713289494773275


 54%|█████▍    | 108/200 [06:28<05:31,  3.61s/it]

Iteration 107: Loss theta = 0.0706996290307296


 55%|█████▍    | 109/200 [06:31<05:29,  3.62s/it]

Iteration 108: Loss theta = 0.06985051008431535


 55%|█████▌    | 110/200 [06:35<05:25,  3.61s/it]

Iteration 109: Loss theta = 0.07182095317464125


 56%|█████▌    | 111/200 [06:39<05:19,  3.59s/it]

Iteration 110: Loss theta = 0.07134256002150084


 56%|█████▌    | 112/200 [06:42<05:18,  3.62s/it]

Iteration 111: Loss theta = 0.06980579915015321


 56%|█████▋    | 113/200 [06:46<05:17,  3.65s/it]

Iteration 112: Loss theta = 0.06802513019034737


 57%|█████▋    | 114/200 [06:50<05:11,  3.63s/it]

Iteration 113: Loss theta = 0.07040596420043393


 57%|█████▊    | 115/200 [06:53<05:08,  3.63s/it]

Iteration 114: Loss theta = 0.07260230163994588


 58%|█████▊    | 116/200 [06:57<05:04,  3.62s/it]

Iteration 115: Loss theta = 0.06858126328963983


 58%|█████▊    | 117/200 [07:00<05:01,  3.63s/it]

Iteration 116: Loss theta = 0.06960853522545413


 59%|█████▉    | 118/200 [07:04<04:58,  3.65s/it]

Iteration 117: Loss theta = 0.06857998551506746


 60%|█████▉    | 119/200 [07:08<04:55,  3.65s/it]

Iteration 118: Loss theta = 0.07295907131935421


 60%|██████    | 120/200 [07:11<04:50,  3.63s/it]

Iteration 119: Loss theta = 0.0664103690926966


 60%|██████    | 121/200 [07:15<04:46,  3.63s/it]

Iteration 120: Loss theta = 0.06690796239203528


 61%|██████    | 122/200 [07:19<04:42,  3.62s/it]

Iteration 121: Loss theta = 0.07003432708351236


 62%|██████▏   | 123/200 [07:22<04:37,  3.61s/it]

Iteration 122: Loss theta = 0.06811372406388584


 62%|██████▏   | 124/200 [07:26<04:32,  3.59s/it]

Iteration 123: Loss theta = 0.06548312825984076


 62%|██████▎   | 125/200 [07:29<04:28,  3.58s/it]

Iteration 124: Loss theta = 0.06747404544761307


 63%|██████▎   | 126/200 [07:33<04:26,  3.60s/it]

Iteration 125: Loss theta = 0.0670053006983117


 64%|██████▎   | 127/200 [07:37<04:24,  3.63s/it]

Iteration 126: Loss theta = 0.07198125398472736


 64%|██████▍   | 128/200 [07:40<04:19,  3.60s/it]

Iteration 127: Loss theta = 0.06796745946140666


 64%|██████▍   | 129/200 [07:44<04:14,  3.58s/it]

Iteration 128: Loss theta = 0.07122094742953777


 65%|██████▌   | 130/200 [07:47<04:11,  3.60s/it]

Iteration 129: Loss theta = 0.06530787178168171


 66%|██████▌   | 131/200 [07:51<04:09,  3.62s/it]

Iteration 130: Loss theta = 0.06658315040955418


 66%|██████▌   | 132/200 [07:55<04:07,  3.63s/it]

Iteration 131: Loss theta = 0.06509138565314443


 66%|██████▋   | 133/200 [07:58<04:05,  3.66s/it]

Iteration 132: Loss theta = 0.06395376218776953


 67%|██████▋   | 134/200 [08:02<04:03,  3.69s/it]

Iteration 133: Loss theta = 0.06489944183512737


 68%|██████▊   | 135/200 [08:06<03:59,  3.69s/it]

Iteration 134: Loss theta = 0.0710874076344465


 68%|██████▊   | 136/200 [08:09<03:54,  3.66s/it]

Iteration 135: Loss theta = 0.06983891551039721


 68%|██████▊   | 137/200 [08:13<03:48,  3.63s/it]

Iteration 136: Loss theta = 0.0661154903079334


 69%|██████▉   | 138/200 [08:17<03:46,  3.65s/it]

Iteration 137: Loss theta = 0.06725640742010192


 70%|██████▉   | 139/200 [08:20<03:45,  3.69s/it]

Iteration 138: Loss theta = 0.06631297255425077


 70%|███████   | 140/200 [08:24<03:41,  3.69s/it]

Iteration 139: Loss theta = 0.06722790512599443


 70%|███████   | 141/200 [08:28<03:37,  3.69s/it]

Iteration 140: Loss theta = 0.06445610738898579


 71%|███████   | 142/200 [08:32<03:33,  3.68s/it]

Iteration 141: Loss theta = 0.06540798837024915


 72%|███████▏  | 143/200 [08:35<03:29,  3.68s/it]

Iteration 142: Loss theta = 0.0641907749599532


 72%|███████▏  | 144/200 [08:39<03:25,  3.67s/it]

Iteration 143: Loss theta = 0.0664623391471411


 72%|███████▎  | 145/200 [08:42<03:21,  3.66s/it]

Iteration 144: Loss theta = 0.06483866382194192


 73%|███████▎  | 146/200 [08:46<03:17,  3.65s/it]

Iteration 145: Loss theta = 0.06764707459430945


 74%|███████▎  | 147/200 [08:50<03:12,  3.64s/it]

Iteration 146: Loss theta = 0.06501970283294979


 74%|███████▍  | 148/200 [08:53<03:10,  3.66s/it]

Iteration 147: Loss theta = 0.06606440128464448


 74%|███████▍  | 149/200 [08:57<03:05,  3.64s/it]

Iteration 148: Loss theta = 0.0644140233726878


 75%|███████▌  | 150/200 [09:01<03:01,  3.64s/it]

Iteration 149: Loss theta = 0.06490140427884303


 76%|███████▌  | 151/200 [09:04<02:58,  3.63s/it]

Iteration 150: Loss theta = 0.06436505403957869


 76%|███████▌  | 152/200 [09:08<02:54,  3.64s/it]

Iteration 151: Loss theta = 0.06450199571095015


 76%|███████▋  | 153/200 [09:12<02:50,  3.63s/it]

Iteration 152: Loss theta = 0.06686809149227645


 77%|███████▋  | 154/200 [09:15<02:47,  3.64s/it]

Iteration 153: Loss theta = 0.06630946087994073


 78%|███████▊  | 155/200 [09:19<02:43,  3.64s/it]

Iteration 154: Loss theta = 0.06455530520332486


 78%|███████▊  | 156/200 [09:22<02:40,  3.64s/it]

Iteration 155: Loss theta = 0.06548188656176392


 78%|███████▊  | 157/200 [09:26<02:37,  3.65s/it]

Iteration 156: Loss theta = 0.0623440089586534


 79%|███████▉  | 158/200 [09:30<02:32,  3.62s/it]

Iteration 157: Loss theta = 0.06907189451158047


 80%|███████▉  | 159/200 [09:33<02:27,  3.59s/it]

Iteration 158: Loss theta = 0.06601812788530399


 80%|████████  | 160/200 [09:37<02:22,  3.56s/it]

Iteration 159: Loss theta = 0.06444337346444004


 80%|████████  | 161/200 [09:40<02:19,  3.58s/it]

Iteration 160: Loss theta = 0.06456284166166656


 81%|████████  | 162/200 [09:44<02:15,  3.58s/it]

Iteration 161: Loss theta = 0.06352292265939086


 82%|████████▏ | 163/200 [09:48<02:12,  3.59s/it]

Iteration 162: Loss theta = 0.06308499292323463


 82%|████████▏ | 164/200 [09:51<02:09,  3.60s/it]

Iteration 163: Loss theta = 0.06336467369998756


 82%|████████▎ | 165/200 [09:55<02:06,  3.61s/it]

Iteration 164: Loss theta = 0.0633876912884022


 83%|████████▎ | 166/200 [09:58<02:02,  3.59s/it]

Iteration 165: Loss theta = 0.06358124423576028


 84%|████████▎ | 167/200 [10:02<01:59,  3.61s/it]

Iteration 166: Loss theta = 0.06581483761730947


 84%|████████▍ | 168/200 [10:06<01:55,  3.61s/it]

Iteration 167: Loss theta = 0.06230560454883074


 84%|████████▍ | 169/200 [10:09<01:51,  3.59s/it]

Iteration 168: Loss theta = 0.06320093542729553


 85%|████████▌ | 170/200 [10:13<01:47,  3.59s/it]

Iteration 169: Loss theta = 0.06474670375648298


 86%|████████▌ | 171/200 [10:16<01:43,  3.58s/it]

Iteration 170: Loss theta = 0.06285292763066919


 86%|████████▌ | 172/200 [10:20<01:40,  3.59s/it]

Iteration 171: Loss theta = 0.06420555769612915


 86%|████████▋ | 173/200 [10:24<01:37,  3.61s/it]

Iteration 172: Loss theta = 0.06630047146034868


 87%|████████▋ | 174/200 [10:27<01:34,  3.63s/it]

Iteration 173: Loss theta = 0.06411422681259482


 88%|████████▊ | 175/200 [10:31<01:31,  3.64s/it]

Iteration 174: Loss theta = 0.0633105684659983


 88%|████████▊ | 176/200 [10:35<01:27,  3.64s/it]

Iteration 175: Loss theta = 0.06388767759658788


 88%|████████▊ | 177/200 [10:38<01:23,  3.62s/it]

Iteration 176: Loss theta = 0.0639875610604098


 89%|████████▉ | 178/200 [10:42<01:19,  3.62s/it]

Iteration 177: Loss theta = 0.06547250059482299


 90%|████████▉ | 179/200 [10:45<01:16,  3.62s/it]

Iteration 178: Loss theta = 0.06253729535168723


 90%|█████████ | 180/200 [10:49<01:11,  3.60s/it]

Iteration 179: Loss theta = 0.061599942904553916


 90%|█████████ | 181/200 [10:53<01:08,  3.60s/it]

Iteration 180: Loss theta = 0.061645834363604844


 91%|█████████ | 182/200 [10:56<01:04,  3.58s/it]

Iteration 181: Loss theta = 0.0658042674982234


 92%|█████████▏| 183/200 [11:00<01:01,  3.61s/it]

Iteration 182: Loss theta = 0.0648980044612759


 92%|█████████▏| 184/200 [11:03<00:57,  3.62s/it]

Iteration 183: Loss theta = 0.06304151162897285


 92%|█████████▎| 185/200 [11:07<00:54,  3.64s/it]

Iteration 184: Loss theta = 0.06266605756000469


 93%|█████████▎| 186/200 [11:11<00:51,  3.64s/it]

Iteration 185: Loss theta = 0.06267200028033633


 94%|█████████▎| 187/200 [11:14<00:47,  3.65s/it]

Iteration 186: Loss theta = 0.0631769128928059


 94%|█████████▍| 188/200 [11:18<00:43,  3.66s/it]

Iteration 187: Loss theta = 0.06259621228826673


 94%|█████████▍| 189/200 [11:22<00:40,  3.66s/it]

Iteration 188: Loss theta = 0.06165786989425358


 95%|█████████▌| 190/200 [11:25<00:36,  3.65s/it]

Iteration 189: Loss theta = 0.06392618690274264


 96%|█████████▌| 191/200 [11:29<00:32,  3.63s/it]

Iteration 190: Loss theta = 0.06524629028219926


 96%|█████████▌| 192/200 [11:32<00:28,  3.61s/it]

Iteration 191: Loss theta = 0.06276093756681994


 96%|█████████▋| 193/200 [11:36<00:25,  3.59s/it]

Iteration 192: Loss theta = 0.06471813558355759


 97%|█████████▋| 194/200 [11:40<00:21,  3.58s/it]

Iteration 193: Loss theta = 0.06355231245489497


 98%|█████████▊| 195/200 [11:43<00:18,  3.62s/it]

Iteration 194: Loss theta = 0.06548945125388472


 98%|█████████▊| 196/200 [11:47<00:14,  3.59s/it]

Iteration 195: Loss theta = 0.06181965552662548


 98%|█████████▊| 197/200 [11:50<00:10,  3.58s/it]

Iteration 196: Loss theta = 0.062397086306622156


 99%|█████████▉| 198/200 [11:54<00:07,  3.61s/it]

Iteration 197: Loss theta = 0.061920609129102605


100%|█████████▉| 199/200 [11:58<00:03,  3.60s/it]

Iteration 198: Loss theta = 0.06082133193941493


100%|██████████| 200/200 [12:01<00:00,  3.61s/it]

Iteration 199: Loss theta = 0.06467727777597151





In [6]:
model_test = my_model(m, d, array)
model_test.load_state_dict(torch.load('model_2.pt', weights_only=True))
running_loss = 0.0 

with torch.no_grad():
    for data in test_loader:
        X, theta_true= data[0].to(dev), data[1].to(dev)
        theta_pred = model_test(X)
        loss = RMSPE(theta_pred, theta_true)
        running_loss += loss.item()

Acc_theta = running_loss/len(test_loader)

print("RMSE DoA =", Acc_theta)

RMSE DoA = 0.06337683498859406


In [7]:
array.x = array.x.cpu()
array.y = array.y.cpu()

observations, angles = generate_data(100, t, d, 0, 0, array, mc_range, coherent)
Acc_theta_0 = RMSPE(model_test(observations.to(dev)), angles)

print("0dB RMSE DoA =", Acc_theta_0)


angles_rare = []
for i in range(observations.shape[0]):
    angles_rare.append(RARE(observations[i].T, d, 3, array)[0])
angles_rare = torch.stack(angles_rare, dim=0)

print("RMSPE RARE =", RMSPE(angles_rare.cuda(), angles.cuda()))

0dB RMSE DoA = tensor(0.1843, device='cuda:0', grad_fn=<MeanBackward0>)
RMSPE RARE = tensor(0.3154, device='cuda:0')


In [8]:
array.x = array.x.cpu()
array.y = array.y.cpu()

observations, angles = generate_data(100, t, d, 5, 5, array, mc_range, coherent)
Acc_theta_5 = RMSPE(model_test(observations.to(dev)), angles)

print("5dB RMSE DoA =", Acc_theta_5)

angles_rare = []
for i in range(observations.shape[0]):
    angles_rare.append(RARE(observations[i].T, d, 3, array)[0])
angles_rare = torch.stack(angles_rare, dim=0)

print("RMSPE RARE =", RMSPE(angles_rare.cuda(), angles.cuda()))

5dB RMSE DoA = tensor(0.0798, device='cuda:0', grad_fn=<MeanBackward0>)
RMSPE RARE = tensor(0.1821, device='cuda:0')


In [9]:
array.x = array.x.cpu()
array.y = array.y.cpu()

observations, angles = generate_data(100, t, d, 10, 10, array, mc_range, coherent)
Acc_theta_10 = RMSPE(model_test(observations.to(dev)), angles)

print("10dB RMSE DoA =", Acc_theta_10)

angles_rare = []
for i in range(observations.shape[0]):
    angles_rare.append(RARE(observations[i].T, d, 3, array)[0])
angles_rare = torch.stack(angles_rare, dim=0)

print("RMSPE RARE =", RMSPE(angles_rare.cuda(), angles.cuda()))

10dB RMSE DoA = tensor(0.0584, device='cuda:0', grad_fn=<MeanBackward0>)
RMSPE RARE = tensor(0.1343, device='cuda:0')


In [10]:
array.x = array.x.cpu()
array.y = array.y.cpu()

observations, angles = generate_data(100, t, d, 15, 15, array, mc_range, coherent)
Acc_theta_15 = RMSPE(model_test(observations.to(dev)), angles)

print("15dB RMSE DoA =", Acc_theta_15)

angles_rare = []
for i in range(observations.shape[0]):
    angles_rare.append(RARE(observations[i].T, d, 3, array)[0])
angles_rare = torch.stack(angles_rare, dim=0)

print("RMSPE RARE =", RMSPE(angles_rare.cuda(), angles.cuda()))

15dB RMSE DoA = tensor(0.0603, device='cuda:0', grad_fn=<MeanBackward0>)
RMSPE RARE = tensor(0.0820, device='cuda:0')


In [11]:
array.x = array.x.cpu()
array.y = array.y.cpu()

observations, angles = generate_data(100, t, d, 20, 20, array, mc_range, coherent)
Acc_theta_15 = RMSPE(model_test(observations.to(dev)), angles)

print("15dB RMSE DoA =", Acc_theta_15)

angles_rare = []
for i in range(observations.shape[0]):
    angles_rare.append(RARE(observations[i].T, d, 3, array)[0])
angles_rare = torch.stack(angles_rare, dim=0)

print("RMSPE RARE =", RMSPE(angles_rare.cuda(), angles.cuda()))

15dB RMSE DoA = tensor(0.0580, device='cuda:0', grad_fn=<MeanBackward0>)
RMSPE RARE = tensor(0.0613, device='cuda:0')
