In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from utils import *
from tqdm import tqdm

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
d = 3
m = 8
t = 200
n = 20000
snr_min = 0
snr_max = 0
lamda = 0.2
radius = 0.1
mc_range = 4
coherent = False

array = UCA(m, lamda)
array.build_array(radius=0.1)
array.build_array_manifold()
array.build_transform_matrices_from_array_manifold()

observations, angles = generate_data(n, t, d, snr_min, snr_max, array, mc_range, coherent)

In [3]:
class my_model(nn.Module): 
    
    def __init__(self, m: int, d: int, device=dev):
        
        super().__init__()

        self.m = m
        self.d = d
        self.device = device

        self.bn = nn.BatchNorm1d(2*m, device=device)
        self.rnn = nn.GRU(input_size=2*m, hidden_size=2*m, num_layers=1, batch_first=True, device=device)

        self.mlp = nn.Sequential(nn.Linear(in_features=2*m, out_features=2*m*m, device=device), nn.ReLU(),
                                 nn.Linear(in_features=2*m*m, out_features=2*m*m, device=device), nn.ReLU(),
                                 nn.Linear(in_features=2*m*m, out_features=d, device=device))

    def forward(self, x):
        
        x = torch.cat((torch.real(x), torch.imag(x)), dim=-1)
        x = self.bn(x.transpose(1, 2)).transpose(1, 2)
        _, x = self.rnn(x)
        y = self.mlp(x[-1])

        return y


In [4]:
nbEpoches = 200
lr = 1e-3
wd = 1e-9
batchSize = 256

x_train, x_valid, theta_train, theta_valid = train_test_split(observations, angles, test_size=0.2)
x_train, x_test, theta_train, theta_test = train_test_split(x_train, theta_train, test_size=0.2)

train_set = DATASET(x_train, theta_train)
valid_set = DATASET(x_valid, theta_valid)
test_set = DATASET(x_test, theta_test)

train_loader = DataLoader(train_set, batch_size=batchSize, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batchSize, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batchSize, shuffle=False)

In [5]:
model = my_model(m, d)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
Loss, Loss_theta, Val_theta = [], [], []
bestVal = 1000

for i in tqdm(range(nbEpoches)):
    # Train
    running_loss = 0.0
    for data in train_loader:
        X, theta_true = data[0].to(dev), data[1].to(dev)
        optimizer.zero_grad()
        theta_pred = model(X)
        loss = RMSPE(theta_pred, theta_true) 
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    Loss.append(running_loss/len(train_loader))

    # Validation 
    with torch.no_grad():
        running_loss = 0.0
        for data in valid_loader:
            X, theta_true = data[0].to(dev), data[1].to(dev)
            theta_pred = model(X)
            loss = RMSPE(theta_pred, theta_true)
            running_loss += loss.item()
        
        Val_theta.append(running_loss/len(test_loader))

        if Val_theta[i] < bestVal:
            bestVal = Val_theta[i]
            torch.save(model.state_dict(), 'model_3.pt')

    print("Iteration {}: Loss theta = {}".format(i, Loss[-1]))

  0%|          | 1/200 [00:02<08:27,  2.55s/it]

Iteration 0: Loss theta = 0.5749365657567977


  1%|          | 2/200 [00:04<07:55,  2.40s/it]

Iteration 1: Loss theta = 0.44494314074516295


  2%|▏         | 3/200 [00:07<07:44,  2.36s/it]

Iteration 2: Loss theta = 0.39861658215522766


  2%|▏         | 4/200 [00:09<07:37,  2.33s/it]

Iteration 3: Loss theta = 0.36149456441402433


  2%|▎         | 5/200 [00:11<07:30,  2.31s/it]

Iteration 4: Loss theta = 0.3302151036262512


  3%|▎         | 6/200 [00:14<07:38,  2.37s/it]

Iteration 5: Loss theta = 0.29501938819885254


  4%|▎         | 7/200 [00:16<07:36,  2.36s/it]

Iteration 6: Loss theta = 0.26022532731294634


  4%|▍         | 8/200 [00:18<07:32,  2.36s/it]

Iteration 7: Loss theta = 0.23833224326372146


  4%|▍         | 9/200 [00:21<07:29,  2.35s/it]

Iteration 8: Loss theta = 0.21957106947898863


  5%|▌         | 10/200 [00:23<07:27,  2.35s/it]

Iteration 9: Loss theta = 0.2064143455028534


  6%|▌         | 11/200 [00:26<07:30,  2.38s/it]

Iteration 10: Loss theta = 0.2004539155960083


  6%|▌         | 12/200 [00:28<07:25,  2.37s/it]

Iteration 11: Loss theta = 0.1890275937318802


  6%|▋         | 13/200 [00:30<07:21,  2.36s/it]

Iteration 12: Loss theta = 0.17909319460391998


  7%|▋         | 14/200 [00:33<07:17,  2.35s/it]

Iteration 13: Loss theta = 0.17021960407495498


  8%|▊         | 15/200 [00:35<07:20,  2.38s/it]

Iteration 14: Loss theta = 0.16554515123367308


  8%|▊         | 16/200 [00:37<07:16,  2.37s/it]

Iteration 15: Loss theta = 0.1623976707458496


  8%|▊         | 17/200 [00:40<07:14,  2.37s/it]

Iteration 16: Loss theta = 0.15389451116323472


  9%|▉         | 18/200 [00:42<07:09,  2.36s/it]

Iteration 17: Loss theta = 0.14895059078931808


 10%|▉         | 19/200 [00:44<07:05,  2.35s/it]

Iteration 18: Loss theta = 0.14535963982343675


 10%|█         | 20/200 [00:47<07:03,  2.35s/it]

Iteration 19: Loss theta = 0.1406048335134983


 10%|█         | 21/200 [00:49<07:01,  2.35s/it]

Iteration 20: Loss theta = 0.13947638779878616


 11%|█         | 22/200 [00:51<06:56,  2.34s/it]

Iteration 21: Loss theta = 0.13818602472543717


 12%|█▏        | 23/200 [00:54<06:58,  2.37s/it]

Iteration 22: Loss theta = 0.13602803483605386


 12%|█▏        | 24/200 [00:56<06:59,  2.38s/it]

Iteration 23: Loss theta = 0.13457796409726142


 12%|█▎        | 25/200 [00:59<06:54,  2.37s/it]

Iteration 24: Loss theta = 0.12781343832612038


 13%|█▎        | 26/200 [01:01<06:50,  2.36s/it]

Iteration 25: Loss theta = 0.1283154582977295


 14%|█▎        | 27/200 [01:03<06:48,  2.36s/it]

Iteration 26: Loss theta = 0.12318037196993828


 14%|█▍        | 28/200 [01:06<06:45,  2.36s/it]

Iteration 27: Loss theta = 0.12040709361433982


 14%|█▍        | 29/200 [01:08<06:55,  2.43s/it]

Iteration 28: Loss theta = 0.11763554453849792


 15%|█▌        | 30/200 [01:10<06:42,  2.37s/it]

Iteration 29: Loss theta = 0.11866768881678581


 16%|█▌        | 31/200 [01:13<06:34,  2.34s/it]

Iteration 30: Loss theta = 0.12011538609862328


 16%|█▌        | 32/200 [01:15<06:28,  2.31s/it]

Iteration 31: Loss theta = 0.11642210274934768


 16%|█▋        | 33/200 [01:17<06:21,  2.29s/it]

Iteration 32: Loss theta = 0.11210503563284874


 17%|█▋        | 34/200 [01:20<06:19,  2.29s/it]

Iteration 33: Loss theta = 0.11323873430490494


 18%|█▊        | 35/200 [01:22<06:15,  2.28s/it]

Iteration 34: Loss theta = 0.10936854809522628


 18%|█▊        | 36/200 [01:24<06:10,  2.26s/it]

Iteration 35: Loss theta = 0.10938721358776092


 18%|█▊        | 37/200 [01:26<06:08,  2.26s/it]

Iteration 36: Loss theta = 0.10866316318511964


 19%|█▉        | 38/200 [01:29<06:11,  2.30s/it]

Iteration 37: Loss theta = 0.10394218116998673


 20%|█▉        | 39/200 [01:31<06:10,  2.30s/it]

Iteration 38: Loss theta = 0.10253241524100304


 20%|██        | 40/200 [01:33<06:08,  2.30s/it]

Iteration 39: Loss theta = 0.10622485667467117


 20%|██        | 41/200 [01:36<06:06,  2.30s/it]

Iteration 40: Loss theta = 0.10231001034379006


 21%|██        | 42/200 [01:38<06:13,  2.36s/it]

Iteration 41: Loss theta = 0.10562174022197723


 22%|██▏       | 43/200 [01:40<06:12,  2.37s/it]

Iteration 42: Loss theta = 0.09941996589303016


 22%|██▏       | 44/200 [01:43<06:11,  2.38s/it]

Iteration 43: Loss theta = 0.09965864956378936


 22%|██▎       | 45/200 [01:45<06:05,  2.36s/it]

Iteration 44: Loss theta = 0.10211741834878922


 23%|██▎       | 46/200 [01:47<05:59,  2.34s/it]

Iteration 45: Loss theta = 0.09913719549775124


 24%|██▎       | 47/200 [01:50<05:54,  2.32s/it]

Iteration 46: Loss theta = 0.0955707384645939


 24%|██▍       | 48/200 [01:52<05:50,  2.31s/it]

Iteration 47: Loss theta = 0.09495435386896134


 24%|██▍       | 49/200 [01:54<05:47,  2.30s/it]

Iteration 48: Loss theta = 0.0951773490011692


 25%|██▌       | 50/200 [01:57<05:43,  2.29s/it]

Iteration 49: Loss theta = 0.10140137493610382


 26%|██▌       | 51/200 [01:59<05:40,  2.28s/it]

Iteration 50: Loss theta = 0.10093404427170753


 26%|██▌       | 52/200 [02:01<05:38,  2.29s/it]

Iteration 51: Loss theta = 0.10084485054016114


 26%|██▋       | 53/200 [02:03<05:36,  2.29s/it]

Iteration 52: Loss theta = 0.10220186963677406


 27%|██▋       | 54/200 [02:06<05:33,  2.28s/it]

Iteration 53: Loss theta = 0.1080615346133709


 28%|██▊       | 55/200 [02:08<05:32,  2.29s/it]

Iteration 54: Loss theta = 0.10035047993063927


 28%|██▊       | 56/200 [02:10<05:28,  2.28s/it]

Iteration 55: Loss theta = 0.09212864339351653


 28%|██▊       | 57/200 [02:13<05:28,  2.30s/it]

Iteration 56: Loss theta = 0.0956016905605793


 29%|██▉       | 58/200 [02:15<05:29,  2.32s/it]

Iteration 57: Loss theta = 0.09663288041949272


 30%|██▉       | 59/200 [02:17<05:28,  2.33s/it]

Iteration 58: Loss theta = 0.09696408584713936


 30%|███       | 60/200 [02:20<05:27,  2.34s/it]

Iteration 59: Loss theta = 0.09477906405925751


 30%|███       | 61/200 [02:22<05:23,  2.33s/it]

Iteration 60: Loss theta = 0.09911797419190407


 31%|███       | 62/200 [02:24<05:17,  2.30s/it]

Iteration 61: Loss theta = 0.09380413129925728


 32%|███▏      | 63/200 [02:26<05:12,  2.28s/it]

Iteration 62: Loss theta = 0.09064925402402878


 32%|███▏      | 64/200 [02:29<05:11,  2.29s/it]

Iteration 63: Loss theta = 0.0984496508538723


 32%|███▎      | 65/200 [02:31<05:10,  2.30s/it]

Iteration 64: Loss theta = 0.09570657193660737


 33%|███▎      | 66/200 [02:33<05:06,  2.28s/it]

Iteration 65: Loss theta = 0.1079371577501297


 34%|███▎      | 67/200 [02:36<05:00,  2.26s/it]

Iteration 66: Loss theta = 0.09785735934972763


 34%|███▍      | 68/200 [02:38<04:56,  2.25s/it]

Iteration 67: Loss theta = 0.10078788533806801


 34%|███▍      | 69/200 [02:40<04:55,  2.25s/it]

Iteration 68: Loss theta = 0.09427557483315469


 35%|███▌      | 70/200 [02:42<04:52,  2.25s/it]

Iteration 69: Loss theta = 0.09004974260926246


 36%|███▌      | 71/200 [02:45<04:51,  2.26s/it]

Iteration 70: Loss theta = 0.0894015009701252


 36%|███▌      | 72/200 [02:47<04:48,  2.25s/it]

Iteration 71: Loss theta = 0.08715149074792862


 36%|███▋      | 73/200 [02:49<04:47,  2.26s/it]

Iteration 72: Loss theta = 0.0860520675778389


 37%|███▋      | 74/200 [02:51<04:44,  2.26s/it]

Iteration 73: Loss theta = 0.08535891845822334


 38%|███▊      | 75/200 [02:54<04:42,  2.26s/it]

Iteration 74: Loss theta = 0.08944504544138908


 38%|███▊      | 76/200 [02:56<04:39,  2.26s/it]

Iteration 75: Loss theta = 0.08755497276782989


 38%|███▊      | 77/200 [02:58<04:36,  2.25s/it]

Iteration 76: Loss theta = 0.08505418494343758


 39%|███▉      | 78/200 [03:00<04:32,  2.23s/it]

Iteration 77: Loss theta = 0.08167467013001442


 40%|███▉      | 79/200 [03:02<04:28,  2.22s/it]

Iteration 78: Loss theta = 0.08067928865551949


 40%|████      | 80/200 [03:05<04:26,  2.22s/it]

Iteration 79: Loss theta = 0.08443662315607071


 40%|████      | 81/200 [03:07<04:24,  2.22s/it]

Iteration 80: Loss theta = 0.08100022286176682


 41%|████      | 82/200 [03:09<04:21,  2.21s/it]

Iteration 81: Loss theta = 0.08829764753580094


 42%|████▏     | 83/200 [03:11<04:18,  2.21s/it]

Iteration 82: Loss theta = 0.0832426330447197


 42%|████▏     | 84/200 [03:13<04:16,  2.21s/it]

Iteration 83: Loss theta = 0.08953989207744599


 42%|████▎     | 85/200 [03:16<04:14,  2.22s/it]

Iteration 84: Loss theta = 0.0823707291483879


 43%|████▎     | 86/200 [03:18<04:12,  2.22s/it]

Iteration 85: Loss theta = 0.08080380484461784


 44%|████▎     | 87/200 [03:20<04:10,  2.22s/it]

Iteration 86: Loss theta = 0.07769561484456063


 44%|████▍     | 88/200 [03:22<04:09,  2.23s/it]

Iteration 87: Loss theta = 0.07744330242276191


 44%|████▍     | 89/200 [03:25<04:06,  2.23s/it]

Iteration 88: Loss theta = 0.07804017826914787


 45%|████▌     | 90/200 [03:27<04:04,  2.22s/it]

Iteration 89: Loss theta = 0.07725555315613747


 46%|████▌     | 91/200 [03:29<04:02,  2.22s/it]

Iteration 90: Loss theta = 0.07773466750979424


 46%|████▌     | 92/200 [03:31<03:59,  2.22s/it]

Iteration 91: Loss theta = 0.07724597170948982


 46%|████▋     | 93/200 [03:34<03:58,  2.22s/it]

Iteration 92: Loss theta = 0.0815346297621727


 47%|████▋     | 94/200 [03:36<03:59,  2.25s/it]

Iteration 93: Loss theta = 0.08292807176709176


 48%|████▊     | 95/200 [03:38<03:59,  2.28s/it]

Iteration 94: Loss theta = 0.07914879083633423


 48%|████▊     | 96/200 [03:41<04:01,  2.32s/it]

Iteration 95: Loss theta = 0.07532768309116364


 48%|████▊     | 97/200 [03:43<03:59,  2.33s/it]

Iteration 96: Loss theta = 0.07776695758104324


 49%|████▉     | 98/200 [03:45<03:57,  2.33s/it]

Iteration 97: Loss theta = 0.07559276968240738


 50%|████▉     | 99/200 [03:48<03:54,  2.32s/it]

Iteration 98: Loss theta = 0.08884578064084053


 50%|█████     | 100/200 [03:50<03:52,  2.33s/it]

Iteration 99: Loss theta = 0.08285377368330955


 50%|█████     | 101/200 [03:52<03:49,  2.32s/it]

Iteration 100: Loss theta = 0.07718825355172157


 51%|█████     | 102/200 [03:55<03:47,  2.32s/it]

Iteration 101: Loss theta = 0.07418444395065307


 52%|█████▏    | 103/200 [03:57<03:48,  2.36s/it]

Iteration 102: Loss theta = 0.07413343667984008


 52%|█████▏    | 104/200 [03:59<03:45,  2.35s/it]

Iteration 103: Loss theta = 0.07891013264656067


 52%|█████▎    | 105/200 [04:02<03:42,  2.34s/it]

Iteration 104: Loss theta = 0.07656092181801796


 53%|█████▎    | 106/200 [04:04<03:40,  2.35s/it]

Iteration 105: Loss theta = 0.07216626361012458


 54%|█████▎    | 107/200 [04:06<03:37,  2.34s/it]

Iteration 106: Loss theta = 0.07645930156111717


 54%|█████▍    | 108/200 [04:09<03:36,  2.35s/it]

Iteration 107: Loss theta = 0.079314094632864


 55%|█████▍    | 109/200 [04:11<03:33,  2.35s/it]

Iteration 108: Loss theta = 0.07398157328367233


 55%|█████▌    | 110/200 [04:13<03:33,  2.37s/it]

Iteration 109: Loss theta = 0.07249353900551796


 56%|█████▌    | 111/200 [04:16<03:30,  2.37s/it]

Iteration 110: Loss theta = 0.07237738445401191


 56%|█████▌    | 112/200 [04:18<03:29,  2.39s/it]

Iteration 111: Loss theta = 0.0739887447655201


 56%|█████▋    | 113/200 [04:20<03:23,  2.34s/it]

Iteration 112: Loss theta = 0.07098553285002708


 57%|█████▋    | 114/200 [04:23<03:18,  2.31s/it]

Iteration 113: Loss theta = 0.07085880920290948


 57%|█████▊    | 115/200 [04:25<03:14,  2.29s/it]

Iteration 114: Loss theta = 0.07182144969701768


 58%|█████▊    | 116/200 [04:27<03:13,  2.30s/it]

Iteration 115: Loss theta = 0.08457831174135208


 58%|█████▊    | 117/200 [04:30<03:11,  2.30s/it]

Iteration 116: Loss theta = 0.09861665025353432


 59%|█████▉    | 118/200 [04:32<03:08,  2.30s/it]

Iteration 117: Loss theta = 0.0816791072487831


 60%|█████▉    | 119/200 [04:34<03:08,  2.32s/it]

Iteration 118: Loss theta = 0.07603750973939896


 60%|██████    | 120/200 [04:37<03:06,  2.33s/it]

Iteration 119: Loss theta = 0.07479473665356635


 60%|██████    | 121/200 [04:39<03:04,  2.33s/it]

Iteration 120: Loss theta = 0.07117345429956913


 61%|██████    | 122/200 [04:42<03:07,  2.41s/it]

Iteration 121: Loss theta = 0.07225839614868164


 62%|██████▏   | 123/200 [04:44<03:03,  2.38s/it]

Iteration 122: Loss theta = 0.07122725889086723


 62%|██████▏   | 124/200 [04:46<02:59,  2.36s/it]

Iteration 123: Loss theta = 0.07322950601577759


 62%|██████▎   | 125/200 [04:48<02:55,  2.34s/it]

Iteration 124: Loss theta = 0.07105210304260254


 63%|██████▎   | 126/200 [04:51<02:53,  2.34s/it]

Iteration 125: Loss theta = 0.07196366488933563


 64%|██████▎   | 127/200 [04:53<02:50,  2.33s/it]

Iteration 126: Loss theta = 0.07221145480871201


 64%|██████▍   | 128/200 [04:55<02:48,  2.33s/it]

Iteration 127: Loss theta = 0.07234331429004669


 64%|██████▍   | 129/200 [04:58<02:46,  2.35s/it]

Iteration 128: Loss theta = 0.07170820876955986


 65%|██████▌   | 130/200 [05:00<02:44,  2.35s/it]

Iteration 129: Loss theta = 0.07494109004735947


 66%|██████▌   | 131/200 [05:03<02:41,  2.34s/it]

Iteration 130: Loss theta = 0.07881010368466378


 66%|██████▌   | 132/200 [05:05<02:38,  2.33s/it]

Iteration 131: Loss theta = 0.07102015063166618


 66%|██████▋   | 133/200 [05:07<02:36,  2.34s/it]

Iteration 132: Loss theta = 0.06756382375955582


 67%|██████▋   | 134/200 [05:09<02:32,  2.31s/it]

Iteration 133: Loss theta = 0.06912612438201904


 68%|██████▊   | 135/200 [05:12<02:29,  2.30s/it]

Iteration 134: Loss theta = 0.06774516686797143


 68%|██████▊   | 136/200 [05:14<02:26,  2.29s/it]

Iteration 135: Loss theta = 0.07124281950294971


 68%|██████▊   | 137/200 [05:16<02:25,  2.32s/it]

Iteration 136: Loss theta = 0.07187111526727677


 69%|██████▉   | 138/200 [05:19<02:23,  2.32s/it]

Iteration 137: Loss theta = 0.06724974431097508


 70%|██████▉   | 139/200 [05:21<02:21,  2.31s/it]

Iteration 138: Loss theta = 0.06596956312656403


 70%|███████   | 140/200 [05:23<02:17,  2.29s/it]

Iteration 139: Loss theta = 0.06817928269505501


 70%|███████   | 141/200 [05:25<02:14,  2.28s/it]

Iteration 140: Loss theta = 0.06538479916751384


 71%|███████   | 142/200 [05:28<02:11,  2.27s/it]

Iteration 141: Loss theta = 0.06641116157174111


 72%|███████▏  | 143/200 [05:30<02:08,  2.25s/it]

Iteration 142: Loss theta = 0.06428554505109788


 72%|███████▏  | 144/200 [05:32<02:05,  2.25s/it]

Iteration 143: Loss theta = 0.06715216964483262


 72%|███████▎  | 145/200 [05:34<02:03,  2.25s/it]

Iteration 144: Loss theta = 0.06460544534027576


 73%|███████▎  | 146/200 [05:37<02:00,  2.24s/it]

Iteration 145: Loss theta = 0.06797112062573434


 74%|███████▎  | 147/200 [05:39<01:58,  2.24s/it]

Iteration 146: Loss theta = 0.08889835149049759


 74%|███████▍  | 148/200 [05:41<01:57,  2.26s/it]

Iteration 147: Loss theta = 0.07480907395482063


 74%|███████▍  | 149/200 [05:44<01:59,  2.34s/it]

Iteration 148: Loss theta = 0.08380771666765213


 75%|███████▌  | 150/200 [05:46<01:57,  2.35s/it]

Iteration 149: Loss theta = 0.13262636899948121


 76%|███████▌  | 151/200 [05:48<01:55,  2.35s/it]

Iteration 150: Loss theta = 0.11592854529619218


 76%|███████▌  | 152/200 [05:51<01:53,  2.35s/it]

Iteration 151: Loss theta = 0.10260084673762321


 76%|███████▋  | 153/200 [05:53<01:50,  2.36s/it]

Iteration 152: Loss theta = 0.09300675541162491


 77%|███████▋  | 154/200 [05:56<01:48,  2.36s/it]

Iteration 153: Loss theta = 0.08942927420139313


 78%|███████▊  | 155/200 [05:58<01:46,  2.37s/it]

Iteration 154: Loss theta = 0.08749856472015381


 78%|███████▊  | 156/200 [06:00<01:44,  2.36s/it]

Iteration 155: Loss theta = 0.08155422300100326


 78%|███████▊  | 157/200 [06:03<01:41,  2.37s/it]

Iteration 156: Loss theta = 0.08083931043744087


 79%|███████▉  | 158/200 [06:05<01:38,  2.35s/it]

Iteration 157: Loss theta = 0.0786937914788723


 80%|███████▉  | 159/200 [06:07<01:36,  2.35s/it]

Iteration 158: Loss theta = 0.07613184779882431


 80%|████████  | 160/200 [06:10<01:34,  2.35s/it]

Iteration 159: Loss theta = 0.07591041892766953


 80%|████████  | 161/200 [06:12<01:30,  2.32s/it]

Iteration 160: Loss theta = 0.07329309739172458


 81%|████████  | 162/200 [06:14<01:27,  2.29s/it]

Iteration 161: Loss theta = 0.07507217392325401


 82%|████████▏ | 163/200 [06:16<01:24,  2.27s/it]

Iteration 162: Loss theta = 0.0728565126657486


 82%|████████▏ | 164/200 [06:19<01:21,  2.26s/it]

Iteration 163: Loss theta = 0.07010099798440933


 82%|████████▎ | 165/200 [06:21<01:19,  2.27s/it]

Iteration 164: Loss theta = 0.0672495099902153


 83%|████████▎ | 166/200 [06:23<01:16,  2.26s/it]

Iteration 165: Loss theta = 0.06739268690347672


 84%|████████▎ | 167/200 [06:25<01:14,  2.25s/it]

Iteration 166: Loss theta = 0.06513342097401618


 84%|████████▍ | 168/200 [06:28<01:12,  2.25s/it]

Iteration 167: Loss theta = 0.06699165970087051


 84%|████████▍ | 169/200 [06:30<01:09,  2.24s/it]

Iteration 168: Loss theta = 0.06534166477620601


 85%|████████▌ | 170/200 [06:32<01:07,  2.25s/it]

Iteration 169: Loss theta = 0.07132870942354202


 86%|████████▌ | 171/200 [06:34<01:04,  2.24s/it]

Iteration 170: Loss theta = 0.06719784915447236


 86%|████████▌ | 172/200 [06:37<01:02,  2.24s/it]

Iteration 171: Loss theta = 0.06477063685655594


 86%|████████▋ | 173/200 [06:39<01:00,  2.23s/it]

Iteration 172: Loss theta = 0.06577871903777123


 87%|████████▋ | 174/200 [06:41<00:58,  2.26s/it]

Iteration 173: Loss theta = 0.0636383255571127


 88%|████████▊ | 175/200 [06:43<00:56,  2.25s/it]

Iteration 174: Loss theta = 0.0624735414981842


 88%|████████▊ | 176/200 [06:46<00:53,  2.24s/it]

Iteration 175: Loss theta = 0.07520969040691852


 88%|████████▊ | 177/200 [06:48<00:51,  2.23s/it]

Iteration 176: Loss theta = 0.08266011998057365


 89%|████████▉ | 178/200 [06:50<00:48,  2.21s/it]

Iteration 177: Loss theta = 0.07568452879786491


 90%|████████▉ | 179/200 [06:52<00:46,  2.21s/it]

Iteration 178: Loss theta = 0.07295859575271607


 90%|█████████ | 180/200 [06:55<00:46,  2.31s/it]

Iteration 179: Loss theta = 0.07017305411398411


 90%|█████████ | 181/200 [06:57<00:45,  2.38s/it]

Iteration 180: Loss theta = 0.06907871410250664


 91%|█████████ | 182/200 [07:00<00:43,  2.39s/it]

Iteration 181: Loss theta = 0.06943197436630726


 92%|█████████▏| 183/200 [07:02<00:40,  2.38s/it]

Iteration 182: Loss theta = 0.0658634652197361


 92%|█████████▏| 184/200 [07:04<00:37,  2.36s/it]

Iteration 183: Loss theta = 0.06565196022391319


 92%|█████████▎| 185/200 [07:07<00:35,  2.39s/it]

Iteration 184: Loss theta = 0.070890893638134


 93%|█████████▎| 186/200 [07:09<00:33,  2.41s/it]

Iteration 185: Loss theta = 0.06766194082796574


 94%|█████████▎| 187/200 [07:12<00:31,  2.41s/it]

Iteration 186: Loss theta = 0.07232140228152276


 94%|█████████▍| 188/200 [07:14<00:28,  2.40s/it]

Iteration 187: Loss theta = 0.06902833230793476


 94%|█████████▍| 189/200 [07:16<00:26,  2.39s/it]

Iteration 188: Loss theta = 0.06553699918091298


 95%|█████████▌| 190/200 [07:19<00:23,  2.38s/it]

Iteration 189: Loss theta = 0.0651830068230629


 96%|█████████▌| 191/200 [07:21<00:21,  2.38s/it]

Iteration 190: Loss theta = 0.06330369770526886


 96%|█████████▌| 192/200 [07:23<00:18,  2.37s/it]

Iteration 191: Loss theta = 0.06211358346045017


 96%|█████████▋| 193/200 [07:26<00:16,  2.40s/it]

Iteration 192: Loss theta = 0.06193716399371624


 97%|█████████▋| 194/200 [07:28<00:14,  2.41s/it]

Iteration 193: Loss theta = 0.06311044432222843


 98%|█████████▊| 195/200 [07:31<00:12,  2.42s/it]

Iteration 194: Loss theta = 0.061585668027400974


 98%|█████████▊| 196/200 [07:33<00:09,  2.40s/it]

Iteration 195: Loss theta = 0.06034975633025169


 98%|█████████▊| 197/200 [07:36<00:07,  2.39s/it]

Iteration 196: Loss theta = 0.0612058312445879


 99%|█████████▉| 198/200 [07:38<00:04,  2.39s/it]

Iteration 197: Loss theta = 0.061845633313059804


100%|█████████▉| 199/200 [07:40<00:02,  2.37s/it]

Iteration 198: Loss theta = 0.06250588111579418


100%|██████████| 200/200 [07:43<00:00,  2.32s/it]

Iteration 199: Loss theta = 0.06117001324892044





In [6]:
model_test = my_model(m, d)
model_test.load_state_dict(torch.load('model_3.pt', weights_only=True))
running_loss = 0.0 

with torch.no_grad():
    for data in test_loader:
        X, theta_true= data[0].to(dev), data[1].to(dev)
        theta_pred = model_test(X)
        loss = RMSPE(theta_pred, theta_true)
        running_loss += loss.item()

Acc_theta = running_loss/len(test_loader)

print("RMSE DoA =", Acc_theta)

RMSE DoA = 0.06278936851483125
