In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import torch 
import torch.nn as nn

from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from utils import *
from models import *
from tqdm import tqdm

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
dmin = 2
dmax = 5
m = 8
t = 100
n = 50000
snr = 0
lamda = 0.2
distance = 0.1
coherent = False

lr = 1e-3
wd = 1e-9
nbEpoches = 200
batchSize = 32

array = ULA(m, lamda)
array.build_array(distance)
array.build_array_manifold()

In [4]:
observations, angles, labels = generate_data(n, t, dmin, dmax, snr, snr, array, coherent)

x_train, x_valid, theta_train, theta_valid, label_train, label_valid = train_test_split(observations, angles, labels, test_size=0.2)
x_train, x_test, theta_train, theta_test, label_train, label_test = train_test_split(x_train, theta_train, label_train, test_size=0.2)

train_set = DATASET(x_train, theta_train, label_train)
valid_set = DATASET(x_valid, theta_valid, label_valid)
test_set = DATASET(x_test, theta_test, label_test)

train_loader = DataLoader(train_set, batch_size=batchSize, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batchSize, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batchSize, shuffle=False)

In [5]:
model = DA_MUSIC(dmin, dmax, array).to(dev)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

loss_func1 = RMSPE_varied_nbSources(dmin, dmax, dev)
loss_func2 = nn.CrossEntropyLoss()

Loss, Loss_d, Loss_theta, Val_d, Val_theta = [], [], [], [], []

bestVal_d = 0.0
bestVal_theta = 1000.0


for i in tqdm(range(nbEpoches)):

    running_loss = 0.0
    running_loss_d = 0.0
    running_loss_theta = 0.0

    for data in train_loader:
        
        X, theta_true, label_true = data[0].to(dev), data[1].to(dev), data[2].to(dev)
        
        optimizer.zero_grad()
        
        theta_pred, label_pred = model(X)
        
        loss_theta = loss_func1.calculate(theta_pred, theta_true) 
        loss_theta.backward()
        
        loss_d = loss_func2(label_pred, label_true)
        loss_d.backward()
        
        optimizer.step()
        
        running_loss_theta += loss_theta.item()
        running_loss_d += loss_d.item()
        running_loss += loss_theta.item() + loss_d.item()
    
    Loss.append(running_loss/len(train_loader))
    Loss_d.append(running_loss_d/len(train_loader))
    Loss_theta.append(running_loss_theta/len(train_loader))

    with torch.no_grad():

        running_acc_d = 0.0
        running_loss_theta = 0.0

        for data in valid_loader:
            
            X, theta_true, label_true = data[0].to(dev), data[1].to(dev), data[2].to(dev)
            
            theta_pred, label_pred = model(X)
            
            loss_theta = loss_func1.calculate(theta_pred, theta_true) 
            acc_d = (torch.argmax(label_pred, dim=1) == torch.argmax(label_true, dim=1)).float().mean()
            
            running_loss_theta += loss_theta.item()
            running_acc_d += acc_d.item()
        
        Val_d.append(running_acc_d/len(valid_loader))
        Val_theta.append(running_loss_theta/len(valid_loader))

        # if Val_d[-1] > bestVal_d:
            
        #     bestVal_d = Val_d[-1]
        #     torch.save(model.state_dict(), "saved_model.pth")

        if Val_theta[-1] < bestVal_theta:
            bestVal_theta = Val_theta[-1]
            torch.save(model.state_dict(), "damusic_{}dB.pth".format(snr))
            count = 0
        else:
            count += 1

        if count == 20:
            model.load_state_dict(torch.load("damusic_{}dB.pth".format(snr), weights_only=True))


        print("RMSPE = {}, Accuracy = {}".format(Val_theta[-1], Val_d[-1]))

  0%|          | 1/200 [00:07<23:57,  7.23s/it]

RMSPE = 0.37443875476194266, Accuracy = 0.3133985623003195


  1%|          | 2/200 [00:14<23:19,  7.07s/it]

RMSPE = 0.2808114337844971, Accuracy = 0.336361821086262


  2%|▏         | 3/200 [00:21<22:58,  7.00s/it]

RMSPE = 0.2500183285710911, Accuracy = 0.358426517571885


  2%|▏         | 4/200 [00:27<22:37,  6.93s/it]

RMSPE = 0.2259768515158766, Accuracy = 0.4030551118210863


  2%|▎         | 5/200 [00:34<21:52,  6.73s/it]

RMSPE = 0.1912301614547309, Accuracy = 0.44508785942492013


  3%|▎         | 6/200 [00:40<21:18,  6.59s/it]

RMSPE = 0.19128273094233614, Accuracy = 0.4578674121405751


  4%|▎         | 7/200 [00:46<20:37,  6.41s/it]

RMSPE = 0.18156000072011552, Accuracy = 0.5369408945686901


  4%|▍         | 8/200 [00:52<20:04,  6.28s/it]

RMSPE = 0.17968846067262534, Accuracy = 0.5350439297124601


  4%|▍         | 9/200 [00:58<19:40,  6.18s/it]

RMSPE = 0.16837101284497843, Accuracy = 0.5802715654952076


  5%|▌         | 10/200 [01:04<19:32,  6.17s/it]

RMSPE = 0.17025629854716431, Accuracy = 0.6055311501597445


  6%|▌         | 11/200 [01:10<19:26,  6.17s/it]

RMSPE = 0.1708361801866906, Accuracy = 0.5354432907348243


  6%|▌         | 12/200 [01:16<19:13,  6.14s/it]

RMSPE = 0.19348557417194684, Accuracy = 0.5371405750798722


  6%|▋         | 13/200 [01:23<19:16,  6.19s/it]

RMSPE = 0.1593449611585742, Accuracy = 0.5521166134185304


  7%|▋         | 14/200 [01:29<19:18,  6.23s/it]

RMSPE = 0.15431836102241145, Accuracy = 0.5584065495207667


  8%|▊         | 15/200 [01:35<19:16,  6.25s/it]

RMSPE = 0.15144796041063607, Accuracy = 0.5594049520766773


  8%|▊         | 16/200 [01:42<19:15,  6.28s/it]

RMSPE = 0.14682404458903656, Accuracy = 0.5970447284345048


  8%|▊         | 17/200 [01:48<19:09,  6.28s/it]

RMSPE = 0.15754097545394502, Accuracy = 0.5915535143769968


  9%|▉         | 18/200 [01:54<18:51,  6.22s/it]

RMSPE = 0.15056805136485601, Accuracy = 0.6276956869009584


 10%|▉         | 19/200 [02:00<18:35,  6.16s/it]

RMSPE = 0.14971077980134434, Accuracy = 0.5971445686900958


 10%|█         | 20/200 [02:07<18:41,  6.23s/it]

RMSPE = 0.1466417808930714, Accuracy = 0.6290934504792333


 10%|█         | 21/200 [02:13<18:40,  6.26s/it]

RMSPE = 0.1371135506005333, Accuracy = 0.643470447284345


 11%|█         | 22/200 [02:19<18:37,  6.28s/it]

RMSPE = 0.13438816673268145, Accuracy = 0.6609424920127795


 12%|█▏        | 23/200 [02:26<18:33,  6.29s/it]

RMSPE = 0.13294952265180338, Accuracy = 0.6386781150159745


 12%|█▏        | 24/200 [02:32<18:26,  6.28s/it]

RMSPE = 0.12975360581669182, Accuracy = 0.6401757188498403


 12%|█▎        | 25/200 [02:38<18:31,  6.35s/it]

RMSPE = 0.12874033256841544, Accuracy = 0.6632388178913738


 13%|█▎        | 26/200 [02:44<18:13,  6.28s/it]

RMSPE = 0.13292128585588436, Accuracy = 0.6541533546325878


 14%|█▎        | 27/200 [02:51<18:06,  6.28s/it]

RMSPE = 0.12761717342054502, Accuracy = 0.6576477635782748


 14%|█▍        | 28/200 [02:57<17:54,  6.25s/it]

RMSPE = 0.12524751785654611, Accuracy = 0.6531549520766773


 14%|█▍        | 29/200 [03:03<17:50,  6.26s/it]

RMSPE = 0.12399115694311862, Accuracy = 0.6089257188498403


 15%|█▌        | 30/200 [03:10<17:55,  6.32s/it]

RMSPE = 0.12559639817228713, Accuracy = 0.6648362619808307


 16%|█▌        | 31/200 [03:16<17:46,  6.31s/it]

RMSPE = 0.12212112743538409, Accuracy = 0.6827076677316294


 16%|█▌        | 32/200 [03:22<17:41,  6.32s/it]

RMSPE = 0.1207063126678284, Accuracy = 0.6925918530351438


 16%|█▋        | 33/200 [03:29<17:33,  6.31s/it]

RMSPE = 0.1210804219348743, Accuracy = 0.6956869009584664


 17%|█▋        | 34/200 [03:35<17:27,  6.31s/it]

RMSPE = 0.12440660212653133, Accuracy = 0.6864017571884984


 18%|█▊        | 35/200 [03:41<17:16,  6.28s/it]

RMSPE = 0.12030442846944919, Accuracy = 0.7010782747603834


 18%|█▊        | 36/200 [03:47<17:11,  6.29s/it]

RMSPE = 0.11702208956495261, Accuracy = 0.7140575079872205


 18%|█▊        | 37/200 [03:54<17:04,  6.29s/it]

RMSPE = 0.11859568898765424, Accuracy = 0.6854033546325878


 19%|█▉        | 38/200 [04:00<17:00,  6.30s/it]

RMSPE = 0.12832819910856863, Accuracy = 0.6995806709265175


 20%|█▉        | 39/200 [04:06<16:48,  6.26s/it]

RMSPE = 0.1150804723794468, Accuracy = 0.7091653354632588


 20%|██        | 40/200 [04:12<16:44,  6.28s/it]

RMSPE = 0.12040634262866487, Accuracy = 0.6904952076677316


 20%|██        | 41/200 [04:19<16:43,  6.31s/it]

RMSPE = 0.11717775897286571, Accuracy = 0.7198482428115016


 21%|██        | 42/200 [04:25<16:30,  6.27s/it]

RMSPE = 0.11773712071366965, Accuracy = 0.6523562300319489


 22%|██▏       | 43/200 [04:31<16:24,  6.27s/it]

RMSPE = 0.11677256914468619, Accuracy = 0.7139576677316294


 22%|██▏       | 44/200 [04:37<16:11,  6.23s/it]

RMSPE = 0.11331975812348313, Accuracy = 0.7064696485623003


 22%|██▎       | 45/200 [04:44<16:10,  6.26s/it]

RMSPE = 0.1247704291400818, Accuracy = 0.6787140575079872


 23%|██▎       | 46/200 [04:50<16:00,  6.24s/it]

RMSPE = 0.11712757029091589, Accuracy = 0.669029552715655


 24%|██▎       | 47/200 [04:56<15:55,  6.24s/it]

RMSPE = 0.11413821937462773, Accuracy = 0.6832068690095847


 24%|██▍       | 48/200 [05:03<15:51,  6.26s/it]

RMSPE = 0.11590211310040074, Accuracy = 0.6962859424920128


 24%|██▍       | 49/200 [05:09<15:45,  6.26s/it]

RMSPE = 0.10869356437613027, Accuracy = 0.7173522364217252


 25%|██▌       | 50/200 [05:15<15:39,  6.26s/it]

RMSPE = 0.1106381283257716, Accuracy = 0.7074680511182109


 26%|██▌       | 51/200 [05:21<15:28,  6.23s/it]

RMSPE = 0.1365949102103139, Accuracy = 0.7179512779552716


 26%|██▌       | 52/200 [05:27<15:17,  6.20s/it]

RMSPE = 0.11002386556551479, Accuracy = 0.7161541533546326


 26%|██▋       | 53/200 [05:33<15:09,  6.19s/it]

RMSPE = 0.1163043150791345, Accuracy = 0.7393170926517572


 27%|██▋       | 54/200 [05:39<14:54,  6.12s/it]

RMSPE = 0.10727736996576047, Accuracy = 0.7300319488817891


 28%|██▊       | 55/200 [05:45<14:38,  6.06s/it]

RMSPE = 0.10948087153628992, Accuracy = 0.7177515974440895


 28%|██▊       | 56/200 [05:52<14:40,  6.11s/it]

RMSPE = 0.1122094035005798, Accuracy = 0.713258785942492


 28%|██▊       | 57/200 [05:58<14:30,  6.09s/it]

RMSPE = 0.10830267190266722, Accuracy = 0.7161541533546326


 29%|██▉       | 58/200 [06:05<15:16,  6.46s/it]

RMSPE = 0.11719242414346519, Accuracy = 0.7204472843450479


 30%|██▉       | 59/200 [06:13<15:56,  6.78s/it]

RMSPE = 0.11043280724900219, Accuracy = 0.7064696485623003


 30%|███       | 60/200 [06:21<16:43,  7.17s/it]

RMSPE = 0.11330890569824, Accuracy = 0.7015774760383386


 30%|███       | 61/200 [06:28<16:38,  7.18s/it]

RMSPE = 0.10507638407305787, Accuracy = 0.7375199680511182


 31%|███       | 62/200 [06:35<16:19,  7.10s/it]

RMSPE = 0.10651550003990959, Accuracy = 0.7271365814696485


 32%|███▏      | 63/200 [06:41<15:56,  6.98s/it]

RMSPE = 0.10927931954875922, Accuracy = 0.709464856230032


 32%|███▏      | 64/200 [06:48<15:31,  6.85s/it]

RMSPE = 0.10989612367111273, Accuracy = 0.7230431309904153


 32%|███▎      | 65/200 [06:54<15:04,  6.70s/it]

RMSPE = 0.11074024998246672, Accuracy = 0.713158945686901


 33%|███▎      | 66/200 [07:01<14:42,  6.58s/it]

RMSPE = 0.10899799752730531, Accuracy = 0.7413138977635783


 34%|███▎      | 67/200 [07:07<14:38,  6.61s/it]

RMSPE = 0.10826634603757827, Accuracy = 0.7240415335463258


 34%|███▍      | 68/200 [07:14<14:22,  6.54s/it]

RMSPE = 0.1063266416755728, Accuracy = 0.6852036741214057


 34%|███▍      | 69/200 [07:20<14:09,  6.48s/it]

RMSPE = 0.10586185629565875, Accuracy = 0.7390175718849841


 35%|███▌      | 70/200 [07:27<14:07,  6.52s/it]

RMSPE = 0.10612895853888875, Accuracy = 0.7287340255591054


 36%|███▌      | 71/200 [07:33<14:05,  6.55s/it]

RMSPE = 0.10832286902224295, Accuracy = 0.728035143769968


 36%|███▌      | 72/200 [07:40<14:04,  6.60s/it]

RMSPE = 0.10195397247616857, Accuracy = 0.7434105431309904


 36%|███▋      | 73/200 [07:46<13:47,  6.52s/it]

RMSPE = 0.117343593686343, Accuracy = 0.7222444089456869


 37%|███▋      | 74/200 [07:53<13:38,  6.50s/it]

RMSPE = 0.10323221176957932, Accuracy = 0.7509984025559105


 38%|███▊      | 75/200 [07:59<13:26,  6.45s/it]

RMSPE = 0.10209047142118692, Accuracy = 0.7240415335463258


 38%|███▊      | 76/200 [08:06<13:28,  6.52s/it]

RMSPE = 0.11016804665422288, Accuracy = 0.7176517571884984


 38%|███▊      | 77/200 [08:12<13:22,  6.53s/it]

RMSPE = 0.10453501134253919, Accuracy = 0.7428115015974441


 39%|███▉      | 78/200 [08:19<13:18,  6.55s/it]

RMSPE = 0.1047481994468945, Accuracy = 0.7180511182108626


 40%|███▉      | 79/200 [08:25<13:04,  6.48s/it]

RMSPE = 0.10776386164819089, Accuracy = 0.7286341853035144


 40%|████      | 80/200 [08:31<12:45,  6.38s/it]

RMSPE = 0.10336903815928358, Accuracy = 0.7259384984025559


 40%|████      | 81/200 [08:38<12:36,  6.36s/it]

RMSPE = 0.10331156404730611, Accuracy = 0.7379193290734825


 41%|████      | 82/200 [08:44<12:30,  6.36s/it]

RMSPE = 0.10292463161693975, Accuracy = 0.7153554313099042


 42%|████▏     | 83/200 [08:50<12:14,  6.28s/it]

RMSPE = 0.10279009976802161, Accuracy = 0.7272364217252396


 42%|████▏     | 84/200 [08:56<12:09,  6.29s/it]

RMSPE = 0.10214754584403084, Accuracy = 0.7217452076677316


 42%|████▎     | 85/200 [09:03<12:02,  6.29s/it]

RMSPE = 0.1039481041387628, Accuracy = 0.7178514376996805


 43%|████▎     | 86/200 [09:09<11:51,  6.24s/it]

RMSPE = 0.11125457408233953, Accuracy = 0.739117412140575


 44%|████▎     | 87/200 [09:15<11:49,  6.28s/it]

RMSPE = 0.1008276918920846, Accuracy = 0.737220447284345


 44%|████▍     | 88/200 [09:22<11:54,  6.38s/it]

RMSPE = 0.10496924053461026, Accuracy = 0.7326277955271565


 44%|████▍     | 89/200 [09:28<11:45,  6.36s/it]

RMSPE = 0.10952715075815829, Accuracy = 0.705870607028754


 45%|████▌     | 90/200 [09:34<11:39,  6.36s/it]

RMSPE = 0.1120836147770714, Accuracy = 0.678214856230032


 46%|████▌     | 91/200 [09:41<11:34,  6.37s/it]

RMSPE = 0.11059830468683578, Accuracy = 0.7165535143769968


 46%|████▌     | 92/200 [09:47<11:27,  6.36s/it]

RMSPE = 0.11090062437251734, Accuracy = 0.6945886581469649


 46%|████▋     | 93/200 [09:54<11:22,  6.38s/it]

RMSPE = 0.10139615443377449, Accuracy = 0.7223442492012779


 47%|████▋     | 94/200 [10:00<11:21,  6.43s/it]

RMSPE = 0.1030387272135899, Accuracy = 0.7314297124600639


 48%|████▊     | 95/200 [10:06<11:00,  6.29s/it]

RMSPE = 0.10442564860224343, Accuracy = 0.7415135782747604


 48%|████▊     | 96/200 [10:12<10:53,  6.29s/it]

RMSPE = 0.1035376335152041, Accuracy = 0.7209464856230032


 48%|████▊     | 97/200 [10:19<10:50,  6.32s/it]

RMSPE = 0.10049726878301785, Accuracy = 0.7252396166134185


 49%|████▉     | 98/200 [10:25<10:49,  6.37s/it]

RMSPE = 0.11156145102395036, Accuracy = 0.7284345047923323


 50%|████▉     | 99/200 [10:32<10:42,  6.36s/it]

RMSPE = 0.10146167676764936, Accuracy = 0.7411142172523961


 50%|█████     | 100/200 [10:38<10:36,  6.37s/it]

RMSPE = 0.1023580210563093, Accuracy = 0.7199480830670927


 50%|█████     | 101/200 [10:44<10:27,  6.34s/it]

RMSPE = 0.10715506221063602, Accuracy = 0.7403154952076677


 51%|█████     | 102/200 [10:51<10:22,  6.35s/it]

RMSPE = 0.1089700663480134, Accuracy = 0.7437100638977636


 52%|█████▏    | 103/200 [10:57<10:23,  6.42s/it]

RMSPE = 0.10295618135041704, Accuracy = 0.7441094249201278


 52%|█████▏    | 104/200 [11:04<10:18,  6.44s/it]

RMSPE = 0.10083269551634408, Accuracy = 0.7476038338658147


 52%|█████▎    | 105/200 [11:10<10:17,  6.50s/it]

RMSPE = 0.10278514189461169, Accuracy = 0.75


 53%|█████▎    | 106/200 [11:17<10:00,  6.39s/it]

RMSPE = 0.10001978520958568, Accuracy = 0.7377196485623003


 54%|█████▎    | 107/200 [11:23<09:46,  6.31s/it]

RMSPE = 0.10154573931195103, Accuracy = 0.7490015974440895


 54%|█████▍    | 108/200 [11:29<09:35,  6.25s/it]

RMSPE = 0.1054102991240474, Accuracy = 0.7486022364217252


 55%|█████▍    | 109/200 [11:35<09:32,  6.30s/it]

RMSPE = 0.1263576428206584, Accuracy = 0.7183506389776357


 55%|█████▌    | 110/200 [11:42<09:28,  6.32s/it]

RMSPE = 0.11052037068544485, Accuracy = 0.726138178913738


 56%|█████▌    | 111/200 [11:48<09:23,  6.33s/it]

RMSPE = 0.10686824592157675, Accuracy = 0.7222444089456869


 56%|█████▌    | 112/200 [11:54<09:19,  6.35s/it]

RMSPE = 0.10172744573781285, Accuracy = 0.7553913738019169


 56%|█████▋    | 113/200 [12:01<09:09,  6.32s/it]

RMSPE = 0.10108613556090254, Accuracy = 0.7469049520766773


 57%|█████▋    | 114/200 [12:07<09:12,  6.42s/it]

RMSPE = 0.0978268322091514, Accuracy = 0.751797124600639


 57%|█████▊    | 115/200 [12:14<09:03,  6.39s/it]

RMSPE = 0.10516549744449866, Accuracy = 0.7452076677316294


 58%|█████▊    | 116/200 [12:20<08:52,  6.34s/it]

RMSPE = 0.10221820638869136, Accuracy = 0.7484025559105432


 58%|█████▊    | 117/200 [12:26<08:46,  6.35s/it]

RMSPE = 0.09491487306813462, Accuracy = 0.7460063897763578


 59%|█████▉    | 118/200 [12:33<08:44,  6.40s/it]

RMSPE = 0.09917107857644748, Accuracy = 0.7418130990415336


 60%|█████▉    | 119/200 [12:39<08:47,  6.51s/it]

RMSPE = 0.09457365623392618, Accuracy = 0.7499001597444089


 60%|██████    | 120/200 [12:46<08:52,  6.65s/it]

RMSPE = 0.09728922216465678, Accuracy = 0.7521964856230032


 60%|██████    | 121/200 [12:53<08:45,  6.65s/it]

RMSPE = 0.09740124521449732, Accuracy = 0.7553913738019169


 61%|██████    | 122/200 [13:00<08:43,  6.71s/it]

RMSPE = 0.09990589413494348, Accuracy = 0.751797124600639


 62%|██████▏   | 123/200 [13:07<08:46,  6.84s/it]

RMSPE = 0.09822522156154767, Accuracy = 0.7663738019169329


 62%|██████▏   | 124/200 [13:13<08:31,  6.73s/it]

RMSPE = 0.10504799060071238, Accuracy = 0.7488019169329073


 62%|██████▎   | 125/200 [13:21<08:33,  6.84s/it]

RMSPE = 0.09665021612145268, Accuracy = 0.7623801916932907


 63%|██████▎   | 126/200 [13:27<08:25,  6.83s/it]

RMSPE = 0.09843480370391292, Accuracy = 0.7351238019169329


 64%|██████▎   | 127/200 [13:35<08:28,  6.97s/it]

RMSPE = 0.09699391640318088, Accuracy = 0.7541932907348243


 64%|██████▍   | 128/200 [13:43<08:54,  7.42s/it]

RMSPE = 0.1011233721821072, Accuracy = 0.7567891373801917


 64%|██████▍   | 129/200 [13:51<08:55,  7.55s/it]

RMSPE = 0.09806299497620366, Accuracy = 0.7559904153354633


 65%|██████▌   | 130/200 [13:57<08:24,  7.21s/it]

RMSPE = 0.10373262361215707, Accuracy = 0.7380191693290735


 66%|██████▌   | 131/200 [14:04<08:00,  6.96s/it]

RMSPE = 0.11472113899434336, Accuracy = 0.7381190095846646


 66%|██████▌   | 132/200 [14:11<07:49,  6.90s/it]

RMSPE = 0.10259266545216496, Accuracy = 0.7585862619808307


 66%|██████▋   | 133/200 [14:17<07:35,  6.80s/it]

RMSPE = 0.11743862198564572, Accuracy = 0.7366214057507987


 67%|██████▋   | 134/200 [14:24<07:21,  6.69s/it]

RMSPE = 0.10510699560467046, Accuracy = 0.735223642172524


 68%|██████▊   | 135/200 [14:30<07:11,  6.63s/it]

RMSPE = 0.1114395447908499, Accuracy = 0.7183506389776357


 68%|██████▊   | 136/200 [14:37<07:01,  6.58s/it]

RMSPE = 0.09951605202671819, Accuracy = 0.7392172523961661


 68%|██████▊   | 137/200 [14:43<06:49,  6.50s/it]

RMSPE = 0.12159690651269005, Accuracy = 0.7141573482428115


 69%|██████▉   | 138/200 [14:49<06:41,  6.48s/it]

RMSPE = 0.1023839311287426, Accuracy = 0.7274361022364217


 70%|██████▉   | 139/200 [14:56<06:37,  6.52s/it]

RMSPE = 0.09652282092898798, Accuracy = 0.7612819488817891


 70%|███████   | 140/200 [15:02<06:32,  6.55s/it]

RMSPE = 0.10517987254233406, Accuracy = 0.7598841853035144


 70%|███████   | 141/200 [15:09<06:22,  6.49s/it]

RMSPE = 0.10282653117903505, Accuracy = 0.7493011182108626


 71%|███████   | 142/200 [15:15<06:13,  6.44s/it]

RMSPE = 0.09759811250070413, Accuracy = 0.7441094249201278


 72%|███████▏  | 143/200 [15:21<06:03,  6.38s/it]

RMSPE = 0.10057357194039007, Accuracy = 0.7404153354632588


 72%|███████▏  | 144/200 [15:29<06:17,  6.74s/it]

RMSPE = 0.10015850883131971, Accuracy = 0.7458067092651757


 72%|███████▎  | 145/200 [15:36<06:07,  6.67s/it]

RMSPE = 0.10105736318011634, Accuracy = 0.7527955271565495


 73%|███████▎  | 146/200 [15:42<06:03,  6.74s/it]

RMSPE = 0.10008149191784782, Accuracy = 0.7700678913738019


 74%|███████▎  | 147/200 [15:49<05:53,  6.66s/it]

RMSPE = 0.09900334917794401, Accuracy = 0.7478035143769968


 74%|███████▍  | 148/200 [15:55<05:39,  6.54s/it]

RMSPE = 0.09754074345857572, Accuracy = 0.7502995207667732


 74%|███████▍  | 149/200 [16:02<05:34,  6.55s/it]

RMSPE = 0.10783863296143163, Accuracy = 0.7241413738019169


 75%|███████▌  | 150/200 [16:09<05:43,  6.86s/it]

RMSPE = 0.1069641968074698, Accuracy = 0.7339257188498403


 76%|███████▌  | 151/200 [16:16<05:35,  6.85s/it]

RMSPE = 0.1014943368756733, Accuracy = 0.753694089456869


 76%|███████▌  | 152/200 [16:23<05:22,  6.73s/it]

RMSPE = 0.09896827348695396, Accuracy = 0.7471046325878594


 76%|███████▋  | 153/200 [16:29<05:09,  6.58s/it]

RMSPE = 0.09895002777679279, Accuracy = 0.7658746006389776


 77%|███████▋  | 154/200 [16:35<05:00,  6.53s/it]

RMSPE = 0.09659940899370577, Accuracy = 0.7551916932907349


 78%|███████▊  | 155/200 [16:41<04:47,  6.40s/it]

RMSPE = 0.09921495266520558, Accuracy = 0.7278354632587859


 78%|███████▊  | 156/200 [16:48<04:42,  6.42s/it]

RMSPE = 0.1038002352268932, Accuracy = 0.7312300319488818


 78%|███████▊  | 157/200 [16:54<04:36,  6.43s/it]

RMSPE = 0.09382729867871958, Accuracy = 0.7663738019169329


 79%|███████▉  | 158/200 [17:01<04:29,  6.42s/it]

RMSPE = 0.10050774734621992, Accuracy = 0.7556908945686901


 80%|███████▉  | 159/200 [17:07<04:26,  6.51s/it]

RMSPE = 0.10240025668384169, Accuracy = 0.7551916932907349


 80%|████████  | 160/200 [17:14<04:18,  6.46s/it]

RMSPE = 0.0989416418507838, Accuracy = 0.7603833865814696


 80%|████████  | 161/200 [17:20<04:10,  6.42s/it]

RMSPE = 0.0955737219831814, Accuracy = 0.7637779552715654


 81%|████████  | 162/200 [17:26<04:00,  6.32s/it]

RMSPE = 0.09514734565545195, Accuracy = 0.7395167731629393


 82%|████████▏ | 163/200 [17:32<03:51,  6.27s/it]

RMSPE = 0.10253266821178003, Accuracy = 0.7412140575079872


 82%|████████▏ | 164/200 [17:38<03:44,  6.23s/it]

RMSPE = 0.09767226808177777, Accuracy = 0.772064696485623


 82%|████████▎ | 165/200 [17:45<03:37,  6.22s/it]

RMSPE = 0.09344222320440097, Accuracy = 0.74810303514377


 83%|████████▎ | 166/200 [17:51<03:34,  6.32s/it]

RMSPE = 0.09672832029600875, Accuracy = 0.7583865814696485


 84%|████████▎ | 167/200 [17:57<03:28,  6.32s/it]

RMSPE = 0.10338296650792844, Accuracy = 0.7496006389776357


 84%|████████▍ | 168/200 [18:04<03:22,  6.33s/it]

RMSPE = 0.09625258142003618, Accuracy = 0.7521964856230032


 84%|████████▍ | 169/200 [18:10<03:15,  6.30s/it]

RMSPE = 0.09576852581561945, Accuracy = 0.762779552715655


 85%|████████▌ | 170/200 [18:17<03:10,  6.37s/it]

RMSPE = 0.10046222682197253, Accuracy = 0.7561900958466453


 86%|████████▌ | 171/200 [18:23<03:02,  6.30s/it]

RMSPE = 0.09467577831908917, Accuracy = 0.740714856230032


 86%|████████▌ | 172/200 [18:29<02:57,  6.34s/it]

RMSPE = 0.10354204992421519, Accuracy = 0.7388178913738019


 86%|████████▋ | 173/200 [18:36<02:51,  6.37s/it]

RMSPE = 0.10035589820565508, Accuracy = 0.7486022364217252


 87%|████████▋ | 174/200 [18:42<02:44,  6.34s/it]

RMSPE = 0.10495467193591328, Accuracy = 0.7330271565495208


 88%|████████▊ | 175/200 [18:48<02:37,  6.29s/it]

RMSPE = 0.10170407298083503, Accuracy = 0.7582867412140575


 88%|████████▊ | 176/200 [18:54<02:29,  6.24s/it]

RMSPE = 0.09939898862339817, Accuracy = 0.7447084664536742


 88%|████████▊ | 177/200 [19:01<02:24,  6.29s/it]

RMSPE = 0.09621097126041357, Accuracy = 0.7553913738019169


 89%|████████▉ | 178/200 [19:07<02:19,  6.33s/it]

RMSPE = 0.09620514400184345, Accuracy = 0.7455071884984026


 90%|████████▉ | 179/200 [19:13<02:13,  6.35s/it]

RMSPE = 0.0945758674139032, Accuracy = 0.7552915335463258


 90%|█████████ | 180/200 [19:20<02:08,  6.44s/it]

RMSPE = 0.09559200170893258, Accuracy = 0.7532947284345048


 90%|█████████ | 181/200 [19:26<02:02,  6.43s/it]

RMSPE = 0.10130707260233145, Accuracy = 0.7406150159744409


 91%|█████████ | 182/200 [19:33<01:54,  6.38s/it]

RMSPE = 0.09728920178862806, Accuracy = 0.7437100638977636


 92%|█████████▏| 183/200 [19:39<01:49,  6.45s/it]

RMSPE = 0.09531679711402796, Accuracy = 0.7593849840255591


 92%|█████████▏| 184/200 [19:46<01:42,  6.42s/it]

RMSPE = 0.10531654258886465, Accuracy = 0.7350239616613419


 92%|█████████▎| 185/200 [19:52<01:37,  6.48s/it]

RMSPE = 0.09634851728574917, Accuracy = 0.7368210862619808


 93%|█████████▎| 186/200 [20:00<01:35,  6.84s/it]

RMSPE = 0.09754726464470355, Accuracy = 0.7539936102236422


 94%|█████████▎| 187/200 [20:07<01:28,  6.84s/it]

RMSPE = 0.09887067526102827, Accuracy = 0.7662739616613419


 94%|█████████▍| 188/200 [20:13<01:19,  6.66s/it]

RMSPE = 0.09922674138801166, Accuracy = 0.7766573482428115


 94%|█████████▍| 189/200 [20:19<01:12,  6.58s/it]

RMSPE = 0.10097348842377099, Accuracy = 0.7568889776357828


 95%|█████████▌| 190/200 [20:26<01:05,  6.56s/it]

RMSPE = 0.09534612919290225, Accuracy = 0.7714656549520766


 96%|█████████▌| 191/200 [20:33<00:59,  6.57s/it]

RMSPE = 0.10676322772670478, Accuracy = 0.7512979233226837


 96%|█████████▌| 192/200 [20:39<00:53,  6.64s/it]

RMSPE = 0.09533502851812223, Accuracy = 0.766473642172524


 96%|█████████▋| 193/200 [20:46<00:46,  6.66s/it]

RMSPE = 0.09655813444346285, Accuracy = 0.7400159744408946


 97%|█████████▋| 194/200 [20:53<00:40,  6.67s/it]

RMSPE = 0.09651156060230999, Accuracy = 0.7488019169329073


 98%|█████████▊| 195/200 [21:00<00:33,  6.71s/it]

RMSPE = 0.10181725396515844, Accuracy = 0.7424121405750799


 98%|█████████▊| 196/200 [21:06<00:27,  6.77s/it]

RMSPE = 0.0984536282980023, Accuracy = 0.7494009584664537


 98%|█████████▊| 197/200 [21:14<00:20,  6.99s/it]

RMSPE = 0.10211070477010344, Accuracy = 0.7663738019169329


 99%|█████████▉| 198/200 [21:20<00:13,  6.76s/it]

RMSPE = 0.09651826731503581, Accuracy = 0.7564896166134185


100%|█████████▉| 199/200 [21:27<00:06,  6.64s/it]

RMSPE = 0.09444575037914343, Accuracy = 0.7605830670926518


100%|██████████| 200/200 [21:33<00:00,  6.47s/it]

RMSPE = 0.09967936149325234, Accuracy = 0.761082268370607





In [6]:
model = DA_MUSIC(dmin, dmax, array).to(dev)
model.load_state_dict(torch.load("damusic_{}dB.pth".format(snr), weights_only=True))
loss_func_test_global = RMSPE_varied_nbSources(dmin, dmax, dev)

running_acc_d = 0.0
running_loss_theta = 0.0

for data in test_loader:

    X, theta_true, label_true = data[0].to(dev), data[1].to(dev), data[2].to(dev)
    theta_pred, label_pred = model(X)
    loss_theta = loss_func_test_global.calculate(theta_pred, theta_true)
    acc_d = (torch.argmax(label_pred, dim=1) == torch.argmax(label_true, dim=1)).float().mean()
    running_loss_theta += loss_theta.item()
    running_acc_d += acc_d.item()

print("Accuracy of DA-MUSIC is {}%".format(running_acc_d/len(test_loader)*100))
print("RMSPE = {}".format(running_loss_theta/len(test_loader)))

Accuracy of DA-MUSIC is 75.3%
RMSPE = 0.09275875145196914


In [7]:
model = DA_MUSIC_v2(dmin, dmax, array).to(dev)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

loss_func1 = RMSPE_varied_nbSources(dmin, dmax, dev)
loss_func2 = nn.CrossEntropyLoss()

Loss, Loss_d, Loss_theta, Val_d, Val_theta = [], [], [], [], []

bestVal_d = 0.0
bestVal_theta = 1000.0


for i in tqdm(range(nbEpoches)):

    running_loss = 0.0
    running_loss_d = 0.0
    running_loss_theta = 0.0

    for data in train_loader:
        
        X, theta_true, label_true = data[0].to(dev), data[1].to(dev), data[2].to(dev)
        
        optimizer.zero_grad()
        
        theta_pred, label_pred = model(X)
        
        loss_theta = loss_func1.calculate(theta_pred, theta_true) 
        loss_theta.backward()
        
        loss_d = loss_func2(label_pred, label_true)
        loss_d.backward()
        
        optimizer.step()
        
        running_loss_theta += loss_theta.item()
        running_loss_d += loss_d.item()
        running_loss += loss_theta.item() + loss_d.item()
    
    Loss.append(running_loss/len(train_loader))
    Loss_d.append(running_loss_d/len(train_loader))
    Loss_theta.append(running_loss_theta/len(train_loader))

    with torch.no_grad():

        running_acc_d = 0.0
        running_loss_theta = 0.0

        for data in valid_loader:
            
            X, theta_true, label_true = data[0].to(dev), data[1].to(dev), data[2].to(dev)
            
            theta_pred, label_pred = model(X)
            
            loss_theta = loss_func1.calculate(theta_pred, theta_true) 
            acc_d = (torch.argmax(label_pred, dim=1) == torch.argmax(label_true, dim=1)).float().mean()
            
            running_loss_theta += loss_theta.item()
            running_acc_d += acc_d.item()
        
        Val_d.append(running_acc_d/len(valid_loader))
        Val_theta.append(running_loss_theta/len(valid_loader))

        # if Val_d[-1] > bestVal_d:
            
        #     bestVal_d = Val_d[-1]
        #     torch.save(model.state_dict(), "saved_model.pth")

        if Val_theta[-1] < bestVal_theta:
            bestVal_theta = Val_theta[-1]
            torch.save(model.state_dict(), "damusic_v2_{}dB.pth".format(snr))
            count = 0
        else:
            count += 1

        if count == 20:
            model.load_state_dict(torch.load("damusic_v2_{}dB.pth".format(snr), weights_only=True))


        print("RMSPE = {}, Accuracy = {}".format(Val_theta[-1], Val_d[-1]))

  0%|          | 0/200 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 must have the same dtype, but got ComplexFloat and Float

In [None]:
model = DA_MUSIC_v2(dmin, dmax, array).to(dev)
model.load_state_dict(torch.load("damusic_v2_{}dB.pth".format(snr), weights_only=True))
loss_func_test_global = RMSPE_varied_nbSources(dmin, dmax, dev)

running_acc_d = 0.0
running_loss_theta = 0.0

for data in test_loader:

    X, theta_true, label_true = data[0].to(dev), data[1].to(dev), data[2].to(dev)
    theta_pred, label_pred = model(X)
    loss_theta = loss_func_test_global.calculate(theta_pred, theta_true)
    acc_d = (torch.argmax(label_pred, dim=1) == torch.argmax(label_true, dim=1)).float().mean()
    running_loss_theta += loss_theta.item()
    running_acc_d += acc_d.item()

print("Accuracy of DA-MUSIC is {}%".format(running_acc_d/len(test_loader)*100))
print("RMSPE = {}".format(running_loss_theta/len(test_loader)))

In [None]:
nbSources_test = torch.argmax(theta_test, dim=1)

aic_test = AIC(x_test)
aic_acc = (aic_test == nbSources_test).float().mean()
print("Accuracy of AIC estimation is {}%".format(aic_acc * 100))

In [None]:
mdl_test = MDL(x_test)
mdl_acc = (mdl_test == nbSources_test).float().mean()
print("Accuracy of MDL estimation is {}%".format(mdl_acc * 100))

In [None]:
loss_func_test = RMSPE_varied_nbSources_test(dmin, dmax, dev)

for data in test_loader:

    X, theta_true = data[0].to(dev), data[1].to(dev)
    theta_pred, label_pred = model(X)
    loss_func_test.calculate(theta_pred, label_pred, theta_true)

loss_func_test.resume()