Train a CNN model for beamforming using hybrid supervised and unsupervised training

In [57]:
import os
import sys
import json
import torch
import importlib
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from typing import List, Dict
import importlib
# Resolve the project root (the parent of the notebook's working directory) so
# that the `src` package can be imported.
scripts_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(scripts_dir, '..'))
# Guarded so that re-running this cell does not keep appending the same entry.
if project_root not in sys.path:
    sys.path.append(project_root)

import src.CNN
importlib.reload(src.CNN)
from src.CNN import ChannelCNN, Trainer

In [58]:
import os
import sys
import json
import torch
import importlib
import matplotlib.pyplot as plt
import pandas as pd
# Resolve the project root (the parent of the notebook's working directory) so
# that the `src` package can be imported.
scripts_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(scripts_dir, '..'))
# Guarded so that re-running this cell does not keep appending the same entry.
if project_root not in sys.path:
    sys.path.append(project_root)

import src.utils
importlib.reload(src.utils)
from src.utils import calculate_sum_rate_sc

import src.sc_wmmse
importlib.reload(src.sc_wmmse)
from src.sc_wmmse import WMMSE_alg_sc

In [3]:
class setup():
    def __init__(self, n_tx, n_rx:int, num_streams:int, num_users, PT):
        self.n_tx = n_tx
        self.n_rx = n_rx
        self.d = num_streams
        self.K = num_users
        self.PT = PT

In [59]:
# Define the setup
num_users = 10
n_tx = 4
n_rx = 2
num_streams = 2
PT = 100
# Bundle the parameters into one object; keywords make the argument order explicit.
set_up = setup(
    n_tx=n_tx,
    n_rx=n_rx,
    num_streams=num_streams,
    num_users=num_users,
    PT=PT,
)

# Define the CNN model and the trainer
cn = ChannelCNN(set_up)
tr = Trainer(set_up, cn)

In [32]:
def proj_power(V, PT_sc):
    """Rescale a set of precoders so that their total power equals ``PT_sc``.

    Args:
        V: dict keyed by ``"0"``, ``"1"``, ..., ``str(len(V) - 1)`` whose values
            are tensors (one precoder per user).
        PT_sc: total power budget, as a Python number or a 0-dim tensor.

    Returns:
        A new dict with the same keys in which every tensor has been multiplied
        by the common real factor ``sqrt(PT_sc / sum_k ||V_k||_F^2)``. The input
        dict is not modified. An empty ``V`` yields an empty dict.
    """
    num_users_sc = len(V)
    if num_users_sc == 0:
        return {}
    # trace(V_k V_k^H) is the sum of squared magnitudes of the entries of V_k,
    # which is real by construction. Computing it directly keeps the scale
    # factor real, and avoids wrapping an existing tensor in torch.tensor(...),
    # which emits a copy-construct UserWarning and cuts the value out of the
    # autograd graph.
    total_power = sum(
        torch.sum((V[str(k)] * V[str(k)].conj()).real) for k in range(num_users_sc)
    )
    budget = torch.as_tensor(PT_sc, dtype=total_power.dtype, device=total_power.device)
    alph = torch.sqrt(budget / total_power)
    # Projects V according to the constraint
    V_proj = {str(k): alph * V[str(k)] for k in range(num_users_sc)}
    return V_proj

def init_V(H):
    """Build one initial precoder per user from the channel matrices.

    For every user ``k`` the result is the conjugate transpose of
    ``pinv(H_k @ H_k^H) @ H_k``.

    Args:
        H: dict keyed by ``"0"``, ``"1"``, ..., ``str(len(H) - 1)`` whose values
            are 2-D channel tensors.

    Returns:
        dict with the same keys; each value has the shape of ``H_k`` transposed.
    """
    # Initializes V according to Hu's code
    V = {}
    # Iterate over the argument itself; the previous loop bound read the global
    # `H_dict`, so the function only worked inside the notebook's data loop.
    for k in range(len(H)):
        H_k = H[str(k)]
        V[str(k)] = (torch.linalg.pinv(H_k @ H_k.conj().T) @ H_k).conj().T
    return V

# The setup
num_users = 10
n_tx = 4
n_rx = [2] * num_users
d = [2] * num_users
PT = 100
sig = [1] * num_users
alpha = [1] * num_users
max_iter_alg = 100
tol_alg = 1e-3
num_samples = 100  # number of random channel realizations (rows of the dataset)

# One dict per realization, one entry per user: an (n_rx[i] x n_tx) complex
# matrix drawn from torch.randn. The draws are not seeded, so every run of this
# cell produces a different dataset.
data = [
    {f'user_{i}': torch.randn(n_rx[i], n_tx, dtype=torch.cdouble) for i in range(num_users)}
    for _ in range(num_samples)
]

H = pd.DataFrame(data)

V_col = []
V_init_col = []

for idx, row in H.iterrows():
    # The columns are labelled 'user_<i>', so select by position with .iloc;
    # plain `row[i]` on a label-indexed Series is deprecated positional access.
    H_dict = {str(i): row.iloc[i] for i in range(len(row))}
    wmm = WMMSE_alg_sc(K=num_users, n_tx=n_tx, n_rx=n_rx, H=H_dict, PT=PT, sig_k=sig, d=d, alpha=alpha, max_iter_alg=max_iter_alg, tol_alg=tol_alg)
    # Start the algorithm from the power-projected initial precoders.
    V_init = init_V(H_dict)
    V_init_proj = proj_power(V_init, PT)
    V_l, U_l, W_l = wmm.algorithm(V_init_proj)
    V_init_col.append(V_init_proj)
    # Keep only the last element of the returned precoder list.
    V_col.append(V_l[-1])

V_df = pd.DataFrame(V_col)
V_init_df = pd.DataFrame(V_init_col)

# Training set: channel columns ('user_<i>') followed by precoder columns ('<i>').
# dset = pd.concat([H, V_init_df, V_df], axis=1)
dset = pd.concat([H, V_df], axis=1)

  H_dict = {str(i): row[i] for i in range(len(row))}
  alph = torch.sqrt(torch.tensor(PT_sc)) / torch.sqrt(torch.tensor(sum([torch.trace(V[str(k)] @ V[str(k)].conj().T) for k in range(num_users_sc)])))
  alph = torch.sqrt(torch.tensor(self.PT)) / torch.sqrt(torch.tensor(sum([torch.trace(V[str(k)] @ V[str(k)].conj().T) for k in range(self.K)])))


In [34]:
dset

Unnamed: 0,user_0,user_1,user_2,user_3,user_4,user_5,user_6,user_7,user_8,user_9,0,1,2,3,4,5,6,7,8,9
0,"[[tensor(-0.3174+0.2691j, dtype=torch.complex1...","[[tensor(-0.7710-0.6892j, dtype=torch.complex1...","[[tensor(0.4248-0.7853j, dtype=torch.complex12...","[[tensor(-0.6270-0.4134j, dtype=torch.complex1...","[[tensor(0.8697+0.2104j, dtype=torch.complex12...","[[tensor(-1.3601-0.5822j, dtype=torch.complex1...","[[tensor(-1.4476-0.1881j, dtype=torch.complex1...","[[tensor(0.6659-0.6918j, dtype=torch.complex12...","[[tensor(0.6049+0.2167j, dtype=torch.complex12...","[[tensor(0.1799-1.4796j, dtype=torch.complex12...","[[tensor(1.3000-1.8167j, dtype=torch.complex12...","[[tensor(-4.2119e-116-3.0417e-116j, dtype=torc...","[[tensor(0.2109-1.6266j, dtype=torch.complex12...","[[tensor(-0.9600+0.5363j, dtype=torch.complex1...","[[tensor(-1.2858e-114-8.3430e-115j, dtype=torc...","[[tensor(3.5927e-119+5.1662e-119j, dtype=torch...","[[tensor(1.4389e-116-1.2445e-115j, dtype=torch...","[[tensor(-5.7934e-120+3.3107e-121j, dtype=torc...","[[tensor(-1.5601+0.9222j, dtype=torch.complex1...","[[tensor(5.6921e-115+9.3346e-116j, dtype=torch..."
1,"[[tensor(0.6888-0.2492j, dtype=torch.complex12...","[[tensor(-1.1235+0.3720j, dtype=torch.complex1...","[[tensor(-1.6394-0.5866j, dtype=torch.complex1...","[[tensor(-0.9216-1.2958j, dtype=torch.complex1...","[[tensor(-0.6637+0.1286j, dtype=torch.complex1...","[[tensor(0.5788-1.3278j, dtype=torch.complex12...","[[tensor(-0.6855+0.8910j, dtype=torch.complex1...","[[tensor(-0.3278+0.3371j, dtype=torch.complex1...","[[tensor(0.4953-1.2797j, dtype=torch.complex12...","[[tensor(0.3691+0.3681j, dtype=torch.complex12...","[[tensor(1.5345-0.7531j, dtype=torch.complex12...","[[tensor(8.7742e-135-6.2374e-135j, dtype=torch...","[[tensor(-1.8854e-133+9.1612e-133j, dtype=torc...","[[tensor(-0.0496-0.0210j, dtype=torch.complex1...","[[tensor(-1.0812-1.3502j, dtype=torch.complex1...","[[tensor(-7.6978e-137+1.0926e-136j, dtype=torc...","[[tensor(2.1594e-136-1.2636e-136j, dtype=torch...","[[tensor(5.8188e-133+5.4773e-133j, dtype=torch...","[[tensor(2.4930e-129+2.9628e-129j, dtype=torch...","[[tensor(3.4713-1.2082j, dtype=torch.complex12..."
2,"[[tensor(1.3437-0.7390j, dtype=torch.complex12...","[[tensor(-0.4524-0.0027j, dtype=torch.complex1...","[[tensor(-0.0045+0.9367j, dtype=torch.complex1...","[[tensor(0.0475-1.3670j, dtype=torch.complex12...","[[tensor(1.3733-0.4788j, dtype=torch.complex12...","[[tensor(-0.5951-0.4008j, dtype=torch.complex1...","[[tensor(0.0137+1.3554j, dtype=torch.complex12...","[[tensor(0.5322-0.2477j, dtype=torch.complex12...","[[tensor(-1.2356+0.6202j, dtype=torch.complex1...","[[tensor(0.5391-1.4007j, dtype=torch.complex12...","[[tensor(-1.8504e-131-5.2960e-133j, dtype=torc...","[[tensor(-3.0267e-126+2.2739e-126j, dtype=torc...","[[tensor(2.5817-0.8452j, dtype=torch.complex12...","[[tensor(-0.5726+0.0891j, dtype=torch.complex1...","[[tensor(1.6840e-138+4.7683e-138j, dtype=torch...","[[tensor(3.2245e-128+1.3623e-128j, dtype=torch...","[[tensor(-2.0110e-138-2.7966e-137j, dtype=torc...","[[tensor(-1.7747e-125+1.2775e-125j, dtype=torc...","[[tensor(-0.1810+0.3051j, dtype=torch.complex1...","[[tensor(0.3164-0.7610j, dtype=torch.complex12..."
3,"[[tensor(-0.6601+1.2663j, dtype=torch.complex1...","[[tensor(0.4808-0.1574j, dtype=torch.complex12...","[[tensor(0.1407+0.1795j, dtype=torch.complex12...","[[tensor(0.1139+1.3220j, dtype=torch.complex12...","[[tensor(1.0005+0.9430j, dtype=torch.complex12...","[[tensor(-0.0556-1.3588j, dtype=torch.complex1...","[[tensor(0.9017-0.5273j, dtype=torch.complex12...","[[tensor(0.0731-0.5748j, dtype=torch.complex12...","[[tensor(-0.3227-1.1620j, dtype=torch.complex1...","[[tensor(0.2688+0.5712j, dtype=torch.complex12...","[[tensor(1.5771e-143+3.2965e-144j, dtype=torch...","[[tensor(2.7675+1.9298j, dtype=torch.complex12...","[[tensor(0.3285-1.0953j, dtype=torch.complex12...","[[tensor(-3.4300e-146+5.6060e-146j, dtype=torc...","[[tensor(-5.6638e-149+5.1512e-149j, dtype=torc...","[[tensor(-0.2813+2.3720j, dtype=torch.complex1...","[[tensor(2.2152e-150+8.4328e-151j, dtype=torch...","[[tensor(-1.1325e-149-3.1718e-149j, dtype=torc...","[[tensor(-4.3557e-145+8.6984e-145j, dtype=torc...","[[tensor(0.0219+0.9383j, dtype=torch.complex12..."
4,"[[tensor(-1.1457-0.3892j, dtype=torch.complex1...","[[tensor(-0.5187-0.3638j, dtype=torch.complex1...","[[tensor(0.7049-1.3787j, dtype=torch.complex12...","[[tensor(0.2324+0.3872j, dtype=torch.complex12...","[[tensor(0.1782-1.3190j, dtype=torch.complex12...","[[tensor(-0.5742-0.5926j, dtype=torch.complex1...","[[tensor(0.2066+0.6997j, dtype=torch.complex12...","[[tensor(1.2183+0.3556j, dtype=torch.complex12...","[[tensor(0.3444+0.2601j, dtype=torch.complex12...","[[tensor(-0.2001-0.9365j, dtype=torch.complex1...","[[tensor(-3.1969e-119+3.4601e-119j, dtype=torc...","[[tensor(8.2181e-120+5.6860e-120j, dtype=torch...","[[tensor(3.5252e-117-4.1327e-118j, dtype=torch...","[[tensor(-3.0041+0.6439j, dtype=torch.complex1...","[[tensor(1.2168+0.3968j, dtype=torch.complex12...","[[tensor(7.2271e-127+1.0092e-125j, dtype=torch...","[[tensor(0.8348-0.2216j, dtype=torch.complex12...","[[tensor(3.6027e-113+1.5094e-113j, dtype=torch...","[[tensor(0.1647+1.1165j, dtype=torch.complex12...","[[tensor(-1.3514e-114+3.8123e-113j, dtype=torc..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,"[[tensor(-0.6088-0.0546j, dtype=torch.complex1...","[[tensor(0.8194+0.4362j, dtype=torch.complex12...","[[tensor(0.7557-0.3543j, dtype=torch.complex12...","[[tensor(-0.7104+0.7286j, dtype=torch.complex1...","[[tensor(-0.5020-0.0752j, dtype=torch.complex1...","[[tensor(-0.3098-1.1018j, dtype=torch.complex1...","[[tensor(-0.0142+1.4732j, dtype=torch.complex1...","[[tensor(0.3815+0.2055j, dtype=torch.complex12...","[[tensor(0.0537+1.0362j, dtype=torch.complex12...","[[tensor(-0.3763+0.5018j, dtype=torch.complex1...","[[tensor(1.2816e-131-1.0220e-131j, dtype=torch...","[[tensor(4.9990e-136-7.6478e-136j, dtype=torch...","[[tensor(-0.1010+2.5646j, dtype=torch.complex1...","[[tensor(1.2801-1.7102j, dtype=torch.complex12...","[[tensor(-1.6707e-141+2.2302e-141j, dtype=torc...","[[tensor(-0.1803+1.7907j, dtype=torch.complex1...","[[tensor(-1.9609e-135-5.5042e-135j, dtype=torc...","[[tensor(3.0529e-135-9.8636e-135j, dtype=torch...","[[tensor(0.8597-1.0150j, dtype=torch.complex12...","[[tensor(4.1144e-135-2.6097e-135j, dtype=torch..."
96,"[[tensor(-0.1900+1.2547j, dtype=torch.complex1...","[[tensor(0.7454-0.6015j, dtype=torch.complex12...","[[tensor(1.1742+0.6464j, dtype=torch.complex12...","[[tensor(-0.5767-0.5780j, dtype=torch.complex1...","[[tensor(-0.1007+0.3206j, dtype=torch.complex1...","[[tensor(-0.6714-1.4303j, dtype=torch.complex1...","[[tensor(-1.0949+1.1514j, dtype=torch.complex1...","[[tensor(-1.8528-0.7419j, dtype=torch.complex1...","[[tensor(0.3167-0.3425j, dtype=torch.complex12...","[[tensor(0.6074+0.6445j, dtype=torch.complex12...","[[tensor(-1.0298e-141+3.6699e-142j, dtype=torc...","[[tensor(-0.9617-0.9644j, dtype=torch.complex1...","[[tensor(1.8353e-153-1.1318e-153j, dtype=torch...","[[tensor(4.7301e-135+1.1359e-135j, dtype=torch...","[[tensor(1.3989e-132+5.1728e-133j, dtype=torch...","[[tensor(-0.1421+0.5318j, dtype=torch.complex1...","[[tensor(-1.8375-1.8440j, dtype=torch.complex1...","[[tensor(-6.2447e-134-1.2458e-133j, dtype=torc...","[[tensor(4.7317e-144+8.3788e-145j, dtype=torch...","[[tensor(0.1092-1.5870j, dtype=torch.complex12..."
97,"[[tensor(0.2988-0.5356j, dtype=torch.complex12...","[[tensor(0.2767-0.3987j, dtype=torch.complex12...","[[tensor(-1.1543-0.3206j, dtype=torch.complex1...","[[tensor(-0.1231-0.3255j, dtype=torch.complex1...","[[tensor(0.2651-0.4744j, dtype=torch.complex12...","[[tensor(-0.5378+1.0899j, dtype=torch.complex1...","[[tensor(1.3002+0.2406j, dtype=torch.complex12...","[[tensor(0.1588+0.8089j, dtype=torch.complex12...","[[tensor(0.7797-0.3146j, dtype=torch.complex12...","[[tensor(0.2387-0.0212j, dtype=torch.complex12...","[[tensor(-8.2888e-139+2.4874e-137j, dtype=torc...","[[tensor(-7.3419e-141+3.9877e-141j, dtype=torc...","[[tensor(4.2468e-138+3.1632e-136j, dtype=torch...","[[tensor(-1.0083-0.2633j, dtype=torch.complex1...","[[tensor(1.6228+2.3862j, dtype=torch.complex12...","[[tensor(-8.3534e-136-9.1173e-135j, dtype=torc...","[[tensor(0.3454-0.1749j, dtype=torch.complex12...","[[tensor(9.4920e-136+3.3639e-135j, dtype=torch...","[[tensor(1.5973+0.5338j, dtype=torch.complex12...","[[tensor(-5.6862e-137+9.2612e-137j, dtype=torc..."
98,"[[tensor(0.4059+0.6422j, dtype=torch.complex12...","[[tensor(0.8847+0.8704j, dtype=torch.complex12...","[[tensor(-1.0500+0.1156j, dtype=torch.complex1...","[[tensor(-0.1382-0.4882j, dtype=torch.complex1...","[[tensor(-0.4448-0.7660j, dtype=torch.complex1...","[[tensor(-0.1140-0.1818j, dtype=torch.complex1...","[[tensor(0.4456+0.1822j, dtype=torch.complex12...","[[tensor(-1.1319+0.4771j, dtype=torch.complex1...","[[tensor(-0.3598+0.4189j, dtype=torch.complex1...","[[tensor(-1.3604+0.9766j, dtype=torch.complex1...","[[tensor(-4.4157e-140+4.0215e-140j, dtype=torc...","[[tensor(0.0489-0.0732j, dtype=torch.complex12...","[[tensor(-1.6792+0.1214j, dtype=torch.complex1...","[[tensor(-1.2866e-140-2.8609e-140j, dtype=torc...","[[tensor(4.1239e-138+4.0187e-138j, dtype=torch...","[[tensor(-2.9156e-139-1.0618e-137j, dtype=torc...","[[tensor(-0.1478-0.0582j, dtype=torch.complex1...","[[tensor(-1.1418e-137+2.3040e-138j, dtype=torc...","[[tensor(9.7275e-139+3.1383e-139j, dtype=torch...","[[tensor(0.4272-1.0651j, dtype=torch.complex12..."


In [7]:
# Define the dataset dataframe
def complex_tensor(shape=(8, 2)):
    """Return a random complex tensor of the requested shape.

    The real and imaginary parts are drawn independently with ``torch.randn``.

    Args:
        shape: sequence of dimension sizes. The default ``(8, 2)`` matches what
            a bare ``complex_tensor()`` call returned after this cell ran
            before the two duplicate definitions were merged.

    Returns:
        A complex tensor ``real + 1j * imag`` of the given shape.
    """
    real = torch.randn(*shape)
    imag = torch.randn(*shape)
    return real + 1j * imag

# Create the DataFrames: 10 rows x 4 columns each, every cell holding one tensor.
H_df = pd.DataFrame([[complex_tensor((2, 8)) for _ in range(4)] for _ in range(10)])
V_df = pd.DataFrame([[complex_tensor((8, 2)) for _ in range(4)] for _ in range(10)])

# Side-by-side concatenation; both halves carry the column labels 0..3.
dataset = pd.concat([H_df, V_df], axis=1)

In [28]:
dataset

Unnamed: 0,0,1,2,3,0.1,1.1,2.1,3.1
0,"[[tensor(-0.0671-0.7584j), tensor(0.7654-1.669...","[[tensor(-0.0850+2.0890j), tensor(0.6860-1.006...","[[tensor(-1.4190-0.3194j), tensor(-1.3981-0.36...","[[tensor(-2.0406+1.2129j), tensor(-0.7994+0.01...","[[tensor(-2.2761-1.4484j), tensor(1.2999+2.947...","[[tensor(-1.2215+0.9061j), tensor(-0.2374-0.64...","[[tensor(0.3742+0.6502j), tensor(0.4598-0.5954...","[[tensor(1.7330-1.3346j), tensor(0.8281+1.9333..."
1,"[[tensor(-2.3240+1.3299j), tensor(-1.3996+0.09...","[[tensor(0.9684-1.7266j), tensor(-0.6797-1.840...","[[tensor(1.7691-0.2782j), tensor(-0.9194-0.153...","[[tensor(-1.2872+0.0022j), tensor(-1.2818+1.72...","[[tensor(-0.1527+0.1605j), tensor(0.6668+1.747...","[[tensor(0.8145+2.4532j), tensor(0.5506-0.3737...","[[tensor(0.8166-0.2685j), tensor(3.1040-0.6226...","[[tensor(-0.4254+2.0138j), tensor(0.5117+1.658..."
2,"[[tensor(0.8616-1.6796j), tensor(0.0656-0.5628...","[[tensor(-0.4451+0.3902j), tensor(-1.4060-0.46...","[[tensor(-0.3006-2.9138j), tensor(0.8641+0.980...","[[tensor(-0.9107+2.1415j), tensor(1.9554+0.650...","[[tensor(0.5721+1.4493j), tensor(-0.4224+0.954...","[[tensor(-2.3652+1.0367j), tensor(-0.8970-0.12...","[[tensor(0.9812+1.7056j), tensor(0.7562+0.0377...","[[tensor(-0.7285-1.4143j), tensor(0.0571-0.711..."
3,"[[tensor(0.5128+2.2246j), tensor(-0.7010-1.415...","[[tensor(-1.1856-0.4819j), tensor(0.6214-1.416...","[[tensor(-0.1885-1.3662j), tensor(-0.7833+0.60...","[[tensor(-0.6910+0.2716j), tensor(-0.6368-0.91...","[[tensor(0.5218-0.5860j), tensor(0.1351-0.2889...","[[tensor(1.2143+0.7065j), tensor(-0.1696+0.120...","[[tensor(-0.1198-1.0393j), tensor(-0.8785+2.16...","[[tensor(-0.8914+1.4690j), tensor(0.8811+0.652..."
4,"[[tensor(-1.1951-0.9662j), tensor(1.7508+1.108...","[[tensor(0.1701+0.3513j), tensor(-0.5437-0.729...","[[tensor(0.7469-0.8845j), tensor(0.9412+0.3307...","[[tensor(-0.0068+0.3495j), tensor(0.3833+0.256...","[[tensor(0.1751-0.2768j), tensor(0.2558-1.0891...","[[tensor(-0.7947-1.2705j), tensor(-0.2561-0.26...","[[tensor(-0.6154-0.4045j), tensor(0.8162+0.045...","[[tensor(-0.6924-0.2615j), tensor(1.9418+0.535..."
5,"[[tensor(-1.2995-2.5753j), tensor(0.2858-2.326...","[[tensor(0.6334-0.2164j), tensor(-2.9397-0.008...","[[tensor(0.4459+1.3150j), tensor(0.4346+1.4968...","[[tensor(-0.3480+0.3138j), tensor(-0.5462+0.45...","[[tensor(-0.1475-1.3746j), tensor(0.2122+0.293...","[[tensor(1.0587-0.1317j), tensor(0.9242+0.4298...","[[tensor(-0.1584+1.2031j), tensor(0.7313+0.937...","[[tensor(0.1070-0.8492j), tensor(1.1959+0.3698..."
6,"[[tensor(-0.2393-0.1396j), tensor(-0.0497+2.16...","[[tensor(-0.3758+1.2095j), tensor(0.3108+0.605...","[[tensor(-0.5330-0.2716j), tensor(-0.1520+0.48...","[[tensor(-0.9080+0.7858j), tensor(0.3499-0.972...","[[tensor(0.5895+2.2349j), tensor(0.2653-3.3129...","[[tensor(1.5641-1.3597j), tensor(0.5240+0.8918...","[[tensor(-2.5834-0.5254j), tensor(1.2642+0.369...","[[tensor(-0.2199-0.6399j), tensor(-0.8450+1.34..."
7,"[[tensor(0.2745+0.4996j), tensor(0.4409-0.3289...","[[tensor(1.4717-1.4084j), tensor(0.8559+0.8322...","[[tensor(-0.7999-1.4514j), tensor(1.3643-0.084...","[[tensor(0.6672-0.8128j), tensor(-1.0016+0.993...","[[tensor(1.8183+0.4838j), tensor(-0.3537-0.831...","[[tensor(0.8323+1.6676j), tensor(0.4486+1.1067...","[[tensor(-0.4147-0.1938j), tensor(-0.0886-0.66...","[[tensor(-1.6282-0.8072j), tensor(-0.2779+1.31..."
8,"[[tensor(1.1549+1.7577j), tensor(-0.0731+0.357...","[[tensor(0.4737-0.2838j), tensor(-0.3717-0.338...","[[tensor(-0.4688+0.0398j), tensor(0.7929+0.196...","[[tensor(1.8209-0.4706j), tensor(-0.0160+0.742...","[[tensor(2.3053+0.5111j), tensor(-0.5735+2.072...","[[tensor(-0.9143+0.5262j), tensor(-0.8846+0.15...","[[tensor(0.8569-1.2288j), tensor(-1.1770-1.180...","[[tensor(0.5714+0.8075j), tensor(-0.0329-0.234..."
9,"[[tensor(-1.2501+1.1638j), tensor(0.2655-1.256...","[[tensor(0.1878-1.6681j), tensor(1.4101+0.5654...","[[tensor(-0.1287+0.3333j), tensor(0.6849+0.567...","[[tensor(0.0955-1.0764j), tensor(0.8175-0.4965...","[[tensor(-0.4266+0.3778j), tensor(0.8342+0.051...","[[tensor(-0.0540-0.5140j), tensor(0.0464-1.357...","[[tensor(1.1570-1.0600j), tensor(-3.0951+0.478...","[[tensor(1.6532-1.2396j), tensor(2.5600-0.1069..."


In [60]:
# Supervised training
# Calls Trainer.train_supervised on `dset` for up to 1000 epochs.
# NOTE(review): the saved output below stops at epoch 17 with a
# KeyboardInterrupt, so the logged losses come from a partial run.
tr.train_supervised(dataset=dset, num_epochs=1000, batch_size=2, lr=0.0001)

Epoch 1/1000, Loss: 0.6345232129096985
Epoch 2/1000, Loss: 0.6296367049217224
Epoch 3/1000, Loss: 0.6251901388168335
Epoch 4/1000, Loss: 0.6206528544425964
Epoch 5/1000, Loss: 0.6151571869850159
Epoch 6/1000, Loss: 0.6094835996627808
Epoch 7/1000, Loss: 0.6036584973335266
Epoch 8/1000, Loss: 0.5979498624801636
Epoch 9/1000, Loss: 0.5921768546104431
Epoch 10/1000, Loss: 0.5857403874397278
Epoch 11/1000, Loss: 0.5785773992538452
Epoch 12/1000, Loss: 0.5707677006721497
Epoch 13/1000, Loss: 0.5608784556388855
Epoch 14/1000, Loss: 0.551145076751709
Epoch 15/1000, Loss: 0.540895938873291
Epoch 16/1000, Loss: 0.5297767519950867
Epoch 17/1000, Loss: 0.5182663798332214


KeyboardInterrupt: 

In [61]:
# Unsupervised training
# NOTE(review): the saved output of this cell is a stream of
# `tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)` values printed from
# inside the call, ending in a KeyboardInterrupt. If that value is the training
# objective, a constant zero deserves investigation - confirm in src.CNN.Trainer.
tr.train_unsupervised(dataset=dset, num_epochs=100, batch_size=2, lr=1e-3)

tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBackward0>)
tensor(0., dtype=torch.float64, grad_fn=<MulBack

KeyboardInterrupt: 