In [12]:
import torch 
import visdom
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

# prepare the visdom envirement
viz = visdom.Visdom(env='beamforming', server='10.15.89.41', port=38720, use_incoming_socket=False)

# 1. Prepare Dataset
G = 3
N = 50
K = 5
batchSize = 200


class BeamformingDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter = ',', dtype = np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-2*G*K])
        self.y_data = torch.from_numpy(xy[:, 2*N*G*K:])
        
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]
    
    def __len__(self):
        return self.len

## Load the trainingSet and testingSet
trainingSet = BeamformingDataset('trainingData.csv.gz')
testingSet = BeamformingDataset('testingData.csv.gz')
train_loader = DataLoader(dataset = trainingSet, batch_size=batchSize, shuffle=True, num_workers=8)
test_loader = DataLoader(dataset = testingSet, batch_size=batchSize, shuffle=False, num_workers=8)

# 2. Design Model
class BeamformingNet(torch.nn.Module):
    def __init__(self):
        super(BeamformingNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 8, kernel_size = 3, padding = 1)
        # self.conv1 = torch.nn.DataParallel(self.conv1)
        self.conv2 = torch.nn.Conv2d(8, 8, kernel_size = 3, padding = 1)
        # self.conv2 = torch.nn.DataParallel(self.conv2)
        
        self.norm = torch.nn.BatchNorm2d(8, eps = 1e-03, momentum = 0.99)
    
        self.fc = torch.nn.Linear(12000, 2*G*K)
        
    def forward(self, x):
        BATCH_SIZE = x.size(0)
        
        x = x.view(BATCH_SIZE, 1, 1, -1)
        x = F.relu(self.norm(self.conv1(x)))
        x = F.relu(self.norm(self.conv2(x)))
        x = x.view(BATCH_SIZE, -1) #flatten
        x = self.fc(x)
        return x
        x = F.Sigmoid(x)
    
model = BeamformingNet()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# 3. Construct Loss and Optimizer
criterion = torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr = 1e-3)

# 4. Train and Test
def train(epoch):
    running_loss = 0.0
    train_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()
        
        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        train_loss += loss.item()
        if batch_idx % 2 == 1:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 2))
            running_loss = 0.0
    viz.line([train_loss], [epoch], win='train_loss', update='append')
            
def test():
    with torch.no_grad():
        running_loss = 0.0
        for data in test_loader:
            inputs, target = data
            inputs, target = inputs.to(device), target.to(device)
            
            # compute the loss
            outputs = model(inputs)
            loss = criterion(outputs, target)
            running_loss += loss.item()
            
        print('Running loss on test set: %.3f' % running_loss)
        
viz.line([0.], [0], win='train_loss', opts=dict(title='train_loss'))
for epoch in range(20):
    train(epoch)
    test()
            

Setting up a new session...
Without the incoming socket you cannot receive events from the server or register event handlers to your Visdom client.
failed CONNECT via proxy status: 403
failed CONNECT via proxy status: 403
failed CONNECT via proxy status: 403
failed CONNECT via proxy status: 403
failed CONNECT via proxy status: 403
failed CONNECT via proxy status: 403
failed CONNECT via proxy status: 403
failed CONNECT via proxy status: 403


[1,     2] loss: 1508.363
[1,     4] loss: 1471.709
[1,     6] loss: 1494.134
[1,     8] loss: 1489.619
[1,    10] loss: 1512.054
[1,    12] loss: 1491.034
[1,    14] loss: 1480.291
[1,    16] loss: 1484.556
[1,    18] loss: 1488.584
[1,    20] loss: 1490.449
[1,    22] loss: 1482.696
[1,    24] loss: 1459.649
[1,    26] loss: 1442.905
[1,    28] loss: 1460.367
[1,    30] loss: 1444.336
[1,    32] loss: 1416.643
[1,    34] loss: 1421.786
[1,    36] loss: 1411.623
[1,    38] loss: 1388.725
[1,    40] loss: 1420.643
[1,    42] loss: 1384.911
[1,    44] loss: 1359.018
[1,    46] loss: 1366.466
[1,    48] loss: 1373.813
[1,    50] loss: 1358.705
[1,    52] loss: 1343.786
[1,    54] loss: 1328.463
[1,    56] loss: 1340.003
[1,    58] loss: 1322.870
[1,    60] loss: 1356.004
[1,    62] loss: 1332.193
[1,    64] loss: 1313.434
[1,    66] loss: 1299.124
[1,    68] loss: 1305.334
[1,    70] loss: 1292.227
[1,    72] loss: 1269.012
[1,    74] loss: 1280.259
[1,    76] loss: 1266.027
[1,    78] l

failed CONNECT via proxy status: 403


Running loss on test set: 29116.391
[2,     2] loss: 1124.289
[2,     4] loss: 1133.138
[2,     6] loss: 1151.255
[2,     8] loss: 1101.430
[2,    10] loss: 1073.820
[2,    12] loss: 1084.060
[2,    14] loss: 1078.298
[2,    16] loss: 1097.016
[2,    18] loss: 1025.069
[2,    20] loss: 1060.799
[2,    22] loss: 1046.692
[2,    24] loss: 1027.150
[2,    26] loss: 1054.553
[2,    28] loss: 999.875
[2,    30] loss: 1011.008
[2,    32] loss: 1027.352
[2,    34] loss: 1013.125
[2,    36] loss: 998.562
[2,    38] loss: 989.248
[2,    40] loss: 956.610
[2,    42] loss: 970.326
[2,    44] loss: 975.104
[2,    46] loss: 936.360
[2,    48] loss: 922.861
[2,    50] loss: 933.911
[2,    52] loss: 900.303
[2,    54] loss: 919.651
[2,    56] loss: 893.046
[2,    58] loss: 895.911
[2,    60] loss: 902.159
[2,    62] loss: 883.581
[2,    64] loss: 872.235
[2,    66] loss: 874.100
[2,    68] loss: 860.467
[2,    70] loss: 849.048
[2,    72] loss: 822.199
[2,    74] loss: 843.775
[2,    76] loss: 843.19

failed CONNECT via proxy status: 403


[3,    60] loss: 489.354
[3,    62] loss: 477.708
[3,    64] loss: 480.073
[3,    66] loss: 463.472
[3,    68] loss: 460.529
[3,    70] loss: 453.396
[3,    72] loss: 443.495
[3,    74] loss: 433.999
[3,    76] loss: 430.842
[3,    78] loss: 424.207
[3,    80] loss: 431.541
[3,    82] loss: 412.322
[3,    84] loss: 414.166
[3,    86] loss: 408.877
[3,    88] loss: 398.844
[3,    90] loss: 398.610
[3,    92] loss: 391.525
[3,    94] loss: 382.322
[3,    96] loss: 381.899
[3,    98] loss: 371.485
[3,   100] loss: 364.629
Running loss on test set: 10170.582
[4,     2] loss: 339.920
[4,     4] loss: 330.552
[4,     6] loss: 328.884
[4,     8] loss: 324.895
[4,    10] loss: 301.802
[4,    12] loss: 305.404
[4,    14] loss: 305.016
[4,    16] loss: 301.259
[4,    18] loss: 299.469
[4,    20] loss: 284.896
[4,    22] loss: 283.172
[4,    24] loss: 291.479
[4,    26] loss: 279.291
[4,    28] loss: 279.866
[4,    30] loss: 270.976
[4,    32] loss: 271.738
[4,    34] loss: 264.248
[4,    36] los

failed CONNECT via proxy status: 403


[5,     2] loss: 135.420
[5,     4] loss: 135.328
[5,     6] loss: 128.075
[5,     8] loss: 125.630
[5,    10] loss: 126.081
[5,    12] loss: 121.466
[5,    14] loss: 120.747
[5,    16] loss: 120.775
[5,    18] loss: 114.546
[5,    20] loss: 112.136
[5,    22] loss: 111.183
[5,    24] loss: 109.779
[5,    26] loss: 108.639
[5,    28] loss: 106.832
[5,    30] loss: 102.256
[5,    32] loss: 100.446
[5,    34] loss: 98.188
[5,    36] loss: 96.943
[5,    38] loss: 94.276
[5,    40] loss: 93.367
[5,    42] loss: 92.114
[5,    44] loss: 90.920
[5,    46] loss: 93.071
[5,    48] loss: 89.832
[5,    50] loss: 86.016
[5,    52] loss: 85.527
[5,    54] loss: 85.926
[5,    56] loss: 83.202
[5,    58] loss: 81.052
[5,    60] loss: 79.279
[5,    62] loss: 77.589
[5,    64] loss: 74.751
[5,    66] loss: 73.340
[5,    68] loss: 73.079
[5,    70] loss: 71.263
[5,    72] loss: 69.872
[5,    74] loss: 69.202
[5,    76] loss: 67.900
[5,    78] loss: 67.060
[5,    80] loss: 63.907
[5,    82] loss: 64.510


failed CONNECT via proxy status: 403


Running loss on test set: 863.108
[7,     2] loss: 19.050
[7,     4] loss: 18.539
[7,     6] loss: 18.244
[7,     8] loss: 18.644
[7,    10] loss: 17.594
[7,    12] loss: 17.541
[7,    14] loss: 17.030
[7,    16] loss: 16.672
[7,    18] loss: 16.677
[7,    20] loss: 16.543
[7,    22] loss: 16.501
[7,    24] loss: 16.443
[7,    26] loss: 15.908
[7,    28] loss: 15.723
[7,    30] loss: 15.953
[7,    32] loss: 15.663
[7,    34] loss: 15.370
[7,    36] loss: 15.122
[7,    38] loss: 14.513
[7,    40] loss: 15.235
[7,    42] loss: 14.948
[7,    44] loss: 14.866
[7,    46] loss: 14.174
[7,    48] loss: 13.913
[7,    50] loss: 13.901
[7,    52] loss: 13.507
[7,    54] loss: 13.860
[7,    56] loss: 13.173
[7,    58] loss: 13.409
[7,    60] loss: 13.398
[7,    62] loss: 12.802
[7,    64] loss: 13.128
[7,    66] loss: 12.650
[7,    68] loss: 12.736
[7,    70] loss: 12.500
[7,    72] loss: 12.020
[7,    74] loss: 11.941
[7,    76] loss: 12.094
[7,    78] loss: 11.807
[7,    80] loss: 11.453
[7,   

failed CONNECT via proxy status: 403


[8,    60] loss: 7.225
[8,    62] loss: 7.196
[8,    64] loss: 7.028
[8,    66] loss: 6.998
[8,    68] loss: 6.881
[8,    70] loss: 6.868
[8,    72] loss: 6.868
[8,    74] loss: 6.917
[8,    76] loss: 6.614
[8,    78] loss: 6.875
[8,    80] loss: 6.528
[8,    82] loss: 6.530
[8,    84] loss: 6.552
[8,    86] loss: 6.580
[8,    88] loss: 6.290
[8,    90] loss: 6.185
[8,    92] loss: 6.225
[8,    94] loss: 6.380
[8,    96] loss: 6.212
[8,    98] loss: 6.304
[8,   100] loss: 6.084
Running loss on test set: 260.667
[9,     2] loss: 5.431
[9,     4] loss: 5.249
[9,     6] loss: 5.541
[9,     8] loss: 5.151
[9,    10] loss: 5.158
[9,    12] loss: 5.173
[9,    14] loss: 5.153
[9,    16] loss: 5.068
[9,    18] loss: 5.066
[9,    20] loss: 4.933
[9,    22] loss: 5.021
[9,    24] loss: 4.963
[9,    26] loss: 4.942
[9,    28] loss: 4.828
[9,    30] loss: 4.968
[9,    32] loss: 4.944
[9,    34] loss: 4.743
[9,    36] loss: 4.813
[9,    38] loss: 4.752
[9,    40] loss: 4.725
[9,    42] loss: 4.735


failed CONNECT via proxy status: 403


[10,     2] loss: 3.602
[10,     4] loss: 3.503
[10,     6] loss: 3.531
[10,     8] loss: 3.555
[10,    10] loss: 3.539
[10,    12] loss: 3.524
[10,    14] loss: 3.480
[10,    16] loss: 3.448
[10,    18] loss: 3.417
[10,    20] loss: 3.351
[10,    22] loss: 3.456
[10,    24] loss: 3.476
[10,    26] loss: 3.351
[10,    28] loss: 3.552
[10,    30] loss: 3.367
[10,    32] loss: 3.356
[10,    34] loss: 3.341
[10,    36] loss: 3.350
[10,    38] loss: 3.418
[10,    40] loss: 3.446
[10,    42] loss: 3.280
[10,    44] loss: 3.414
[10,    46] loss: 3.258
[10,    48] loss: 3.306
[10,    50] loss: 3.256
[10,    52] loss: 3.187
[10,    54] loss: 3.248
[10,    56] loss: 3.188
[10,    58] loss: 3.194
[10,    60] loss: 3.209
[10,    62] loss: 3.245
[10,    64] loss: 3.265
[10,    66] loss: 3.266
[10,    68] loss: 3.282
[10,    70] loss: 3.200
[10,    72] loss: 3.160
[10,    74] loss: 3.253
[10,    76] loss: 3.201
[10,    78] loss: 3.101
[10,    80] loss: 3.149
[10,    82] loss: 3.030
[10,    84] loss

failed CONNECT via proxy status: 403


Running loss on test set: 96.858
[12,     2] loss: 1.998
[12,     4] loss: 1.964
[12,     6] loss: 2.036
[12,     8] loss: 2.021
[12,    10] loss: 2.012
[12,    12] loss: 2.017
[12,    14] loss: 1.970
[12,    16] loss: 1.948
[12,    18] loss: 1.945
[12,    20] loss: 1.966
[12,    22] loss: 1.990
[12,    24] loss: 2.010
[12,    26] loss: 1.916
[12,    28] loss: 1.903
[12,    30] loss: 1.883
[12,    32] loss: 1.923
[12,    34] loss: 1.905
[12,    36] loss: 1.853
[12,    38] loss: 1.930
[12,    40] loss: 1.957
[12,    42] loss: 1.955
[12,    44] loss: 1.918
[12,    46] loss: 1.977
[12,    48] loss: 1.955
[12,    50] loss: 1.903
[12,    52] loss: 1.957
[12,    54] loss: 1.935
[12,    56] loss: 1.962
[12,    58] loss: 1.970
[12,    60] loss: 1.869
[12,    62] loss: 1.884
[12,    64] loss: 1.942
[12,    66] loss: 1.859
[12,    68] loss: 1.881
[12,    70] loss: 1.912
[12,    72] loss: 1.877
[12,    74] loss: 1.912
[12,    76] loss: 1.906
[12,    78] loss: 1.785
[12,    80] loss: 1.887
[12,   

failed CONNECT via proxy status: 403


[13,     2] loss: 1.591
[13,     4] loss: 1.595
[13,     6] loss: 1.591
[13,     8] loss: 1.569
[13,    10] loss: 1.524
[13,    12] loss: 1.546
[13,    14] loss: 1.527
[13,    16] loss: 1.551
[13,    18] loss: 1.513
[13,    20] loss: 1.534
[13,    22] loss: 1.515
[13,    24] loss: 1.520
[13,    26] loss: 1.507
[13,    28] loss: 1.509
[13,    30] loss: 1.540
[13,    32] loss: 1.515
[13,    34] loss: 1.529
[13,    36] loss: 1.543
[13,    38] loss: 1.585
[13,    40] loss: 1.566
[13,    42] loss: 1.723
[13,    44] loss: 1.597
[13,    46] loss: 1.568
[13,    48] loss: 1.587
[13,    50] loss: 1.582
[13,    52] loss: 1.581
[13,    54] loss: 1.623
[13,    56] loss: 1.536
[13,    58] loss: 1.555
[13,    60] loss: 1.550
[13,    62] loss: 1.533
[13,    64] loss: 1.621
[13,    66] loss: 1.538
[13,    68] loss: 1.594
[13,    70] loss: 1.544
[13,    72] loss: 1.538
[13,    74] loss: 1.556
[13,    76] loss: 1.585
[13,    78] loss: 1.552
[13,    80] loss: 1.517
[13,    82] loss: 1.535
[13,    84] loss

failed CONNECT via proxy status: 403


Running loss on test set: 53.264
[15,     2] loss: 1.079
[15,     4] loss: 1.103
[15,     6] loss: 1.076
[15,     8] loss: 1.045
[15,    10] loss: 1.105
[15,    12] loss: 1.111
[15,    14] loss: 1.062
[15,    16] loss: 1.111
[15,    18] loss: 1.084
[15,    20] loss: 1.088
[15,    22] loss: 1.117
[15,    24] loss: 1.053
[15,    26] loss: 1.131
[15,    28] loss: 1.168
[15,    30] loss: 1.098
[15,    32] loss: 1.068
[15,    34] loss: 1.024
[15,    36] loss: 1.069
[15,    38] loss: 1.061
[15,    40] loss: 1.090
[15,    42] loss: 1.029
[15,    44] loss: 1.036
[15,    46] loss: 1.091
[15,    48] loss: 1.147
[15,    50] loss: 1.129
[15,    52] loss: 1.035
[15,    54] loss: 1.057
[15,    56] loss: 1.112
[15,    58] loss: 1.146
[15,    60] loss: 1.091
[15,    62] loss: 1.123
[15,    64] loss: 1.169
[15,    66] loss: 1.177
[15,    68] loss: 1.132
[15,    70] loss: 1.081
[15,    72] loss: 1.121
[15,    74] loss: 1.117
[15,    76] loss: 1.063
[15,    78] loss: 1.059
[15,    80] loss: 1.037
[15,   

failed CONNECT via proxy status: 403


[16,    88] loss: 0.964
[16,    90] loss: 0.872
[16,    92] loss: 0.891
[16,    94] loss: 0.961
[16,    96] loss: 0.899
[16,    98] loss: 1.034
[16,   100] loss: 1.185
Running loss on test set: 42.657
[17,     2] loss: 0.964
[17,     4] loss: 0.856
[17,     6] loss: 0.746
[17,     8] loss: 0.814
[17,    10] loss: 0.803
[17,    12] loss: 0.790
[17,    14] loss: 0.747
[17,    16] loss: 0.724
[17,    18] loss: 0.743
[17,    20] loss: 0.751
[17,    22] loss: 0.756
[17,    24] loss: 0.826
[17,    26] loss: 0.801
[17,    28] loss: 0.768
[17,    30] loss: 0.764
[17,    32] loss: 0.863
[17,    34] loss: 0.779
[17,    36] loss: 0.804
[17,    38] loss: 0.815
[17,    40] loss: 0.762
[17,    42] loss: 0.782
[17,    44] loss: 0.859
[17,    46] loss: 0.908
[17,    48] loss: 0.815
[17,    50] loss: 0.847
[17,    52] loss: 0.906
[17,    54] loss: 0.794
[17,    56] loss: 0.832
[17,    58] loss: 0.868
[17,    60] loss: 0.848
[17,    62] loss: 0.763
[17,    64] loss: 0.856
[17,    66] loss: 0.911
[17,   

failed CONNECT via proxy status: 403


[18,     2] loss: 0.881
[18,     4] loss: 0.739
[18,     6] loss: 0.692
[18,     8] loss: 0.816
[18,    10] loss: 0.800
[18,    12] loss: 0.680
[18,    14] loss: 0.648
[18,    16] loss: 0.685
[18,    18] loss: 0.706
[18,    20] loss: 0.669
[18,    22] loss: 0.645
[18,    24] loss: 0.703
[18,    26] loss: 0.764
[18,    28] loss: 0.757
[18,    30] loss: 0.743
[18,    32] loss: 0.728
[18,    34] loss: 0.801
[18,    36] loss: 0.912
[18,    38] loss: 0.782
[18,    40] loss: 0.766
[18,    42] loss: 0.727
[18,    44] loss: 0.841
[18,    46] loss: 0.842
[18,    48] loss: 0.733
[18,    50] loss: 0.696
[18,    52] loss: 0.695
[18,    54] loss: 0.672
[18,    56] loss: 0.731
[18,    58] loss: 0.766
[18,    60] loss: 0.912
[18,    62] loss: 0.903
[18,    64] loss: 0.952
[18,    66] loss: 0.875
[18,    68] loss: 0.899
[18,    70] loss: 0.797
[18,    72] loss: 0.691
[18,    74] loss: 0.803
[18,    76] loss: 0.781
[18,    78] loss: 0.689
[18,    80] loss: 0.695
[18,    82] loss: 0.716
[18,    84] loss

failed CONNECT via proxy status: 403


Running loss on test set: 25.071
[20,     2] loss: 0.564
[20,     4] loss: 0.621
[20,     6] loss: 0.803
[20,     8] loss: 0.775
[20,    10] loss: 0.683
[20,    12] loss: 0.568
[20,    14] loss: 0.762
[20,    16] loss: 0.800
[20,    18] loss: 0.737
[20,    20] loss: 0.711
[20,    22] loss: 0.686
[20,    24] loss: 0.621
[20,    26] loss: 0.558
[20,    28] loss: 0.547
[20,    30] loss: 0.542
[20,    32] loss: 0.586
[20,    34] loss: 0.635
[20,    36] loss: 0.594
[20,    38] loss: 0.586
[20,    40] loss: 0.642
[20,    42] loss: 0.572
[20,    44] loss: 0.552
[20,    46] loss: 0.538
[20,    48] loss: 0.573
[20,    50] loss: 0.594
[20,    52] loss: 0.612
[20,    54] loss: 0.579
[20,    56] loss: 0.528
[20,    58] loss: 0.533
[20,    60] loss: 0.536
[20,    62] loss: 0.569
[20,    64] loss: 0.585
[20,    66] loss: 0.574
[20,    68] loss: 0.548
[20,    70] loss: 0.522
[20,    72] loss: 0.533
[20,    74] loss: 0.564
[20,    76] loss: 0.572
[20,    78] loss: 0.618
[20,    80] loss: 0.625
[20,   

In [13]:
# Generate the dataset
import numpy as np

np.random.seed(0)

G = 3
N = 50
K = 5
trainingSample = 20000
testingSample = 5000

trainingInputs = np.random.randn(trainingSample, 2*N*G*K)
testingInputs = np.random.randn(testingSample, 2*N*G*K)

# outputs size: 2*G*k
a = np.random.randn(2*N*G*K, 2*G*K)
trainingOutputs = np.dot(trainingInputs, a)
testingOutputs = np.dot(testingInputs, a)

trainingData = np.hstack((trainingInputs, trainingOutputs))
testingData = np.hstack((testingInputs, testingOutputs))

np.savetxt('trainingData.csv.gz', trainingData, fmt='%5f', delimiter=',')
np.savetxt('testingData.csv.gz', testingData, fmt='%5f', delimiter=',')

In [8]:
import torch 
import visdom
import numpy as np

vis = visdom.Visdom(env='test', server='10.15.89.41', port=38720, use_incoming_socket=False)

##
trace = dict(x=[1, 2, 3], y=[4, 5, 6], mode="markers+lines", type='custom',
             marker={'color': 'red', 'symbol': 104, 'size': "10"},
             text=["one", "two", "three"], name='1st Trace')
layout = dict(title="First Plot", xaxis={'title': 'x'}, yaxis={'title': 'y'})

vis._send({'data': [trace], 'layout': layout, 'win': 'mywin'})

Setting up a new session...
Without the incoming socket you cannot receive events from the server or register event handlers to your Visdom client.


'mywin'