In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
from comet_ml import Experiment
from comet_ml.integration.pytorch import log_model

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from datetime import datetime
import pandas as pd
import pytz
import numpy as np

from models.gan_models import *
import yfinance as yf
import MetaTrader5 as mt5

import os

In [4]:
#api_key = os.environ['COMET_API_KEY']
api_key = 'vpNJF6XOWcHS6HqH9ZFEjwRcD'

In [21]:
experiment = Experiment(
  api_key=api_key,
  project_name="tail-price",
  workspace="artaasd95"
)

[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/artaasd95/tail-price/6f4c8d0eeff640c9b0a362f65b9c829b



In [21]:
mt5.initialize()

print('loading current tf data')
utc_from = datetime(2021, 1, 1, tzinfo=pytz.timezone("Asia/Nicosia"))
utc_to = datetime.now(pytz.timezone("Asia/Nicosia"))

data = mt5.copy_rates_range('XAUUSD', mt5.TIMEFRAME_H4, utc_from, utc_to)
data = pd.DataFrame(data)
time_data = data.time
data.drop(columns=['tick_volume', 'spread', 'real_volume'], inplace=True)

loading current tf data


In [23]:
data.to_csv('xau_2021_H4.csv')

In [6]:
data = pd.read_csv('xau_2021_H4.csv')

In [7]:
data

Unnamed: 0.1,Unnamed: 0,time,open,high,low,close
0,0,1609718400,1904.48,1918.60,1900.62,1916.57
1,1,1609732800,1916.65,1925.15,1915.44,1921.56
2,2,1609747200,1921.52,1935.09,1921.34,1932.03
3,3,1609761600,1931.95,1942.06,1927.79,1939.75
4,4,1609776000,1939.78,1944.33,1929.23,1936.92
...,...,...,...,...,...,...
5582,5582,1723780800,2458.38,2459.97,2450.72,2452.68
5583,5583,1723795200,2452.68,2464.50,2451.11,2462.32
5584,5584,1723809600,2462.32,2492.34,2461.17,2491.49
5585,5585,1723824000,2491.46,2500.08,2477.46,2495.82


In [22]:
input_size = 1
hidden_size = 1024
seq_length = 50
num_layers = 12
batch_size = 32
num_epochs = 100
learning_rate = 0.3

# Initialize models
main_gen = MainGenerator(input_size, hidden_size, num_layers, 0.5, 1)
noise_gen = NoiseGenerator(input_size, 2048, 1, 0.4)
discriminator = Discriminator(input_size, hidden_size, 15, 0.4)

In [23]:
enable_cuda = True
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')

In [24]:
main_gen.to(device)
noise_gen.to(device)
discriminator.to(device)

Discriminator(
  (lstm): LSTM(1, 1024, num_layers=15, batch_first=True, dropout=0.4)
  (fc1): Linear(in_features=1024, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=1, bias=True)
  (leaky_relu): LeakyReLU(negative_slope=0.2)
  (sigmoid): Sigmoid()
)

In [25]:
optimizer_G = torch.optim.AdamW(list(main_gen.parameters()) + list(noise_gen.parameters()), lr=learning_rate)
optimizer_D = torch.optim.AdamW(discriminator.parameters(), lr=learning_rate)
optimizer_noise = torch.optim.AdamW(noise_gen.parameters(), lr=learning_rate)

optim_g_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer_G, 0.1)
optim_d_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer_D, 0.1)
optim_noise_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer_noise, 0.1)

In [26]:
cauchy_dist = torch.distributions.cauchy.Cauchy(loc=0, scale=0.5)

In [27]:
#adversarial_loss = F.kl_div
adversarial_loss = F.mse_loss

In [28]:
train_data = data[:int(len(data)*0.7)]
train_data = train_data.reset_index(drop=True)
test_data = data[int(len(data)*0.7):]
test_data = test_data.reset_index(drop=True)

In [15]:
# Training loop
for epoch in tqdm(range(num_epochs)):
    batch_start = 0
    total_loss_list = []
    g_loss_list = []
    noise_loss_list = []
    d_loss_list = []
    for idx in range(int(len(train_data)/batch_size)-1):
        price = torch.tensor(train_data[batch_start:batch_start+batch_size].close.values, dtype=torch.float).reshape(batch_size,1)
        price = price.to(device)
        batch_start = batch_start + batch_size
        # Generate random noise inputs
        z1 = cauchy_dist.sample([batch_size, 1])
        z1 = z1.to(device)
        #z2 = cauchy_dist.sample([1])
        
        # Generate fake data
        fake_main = main_gen(price)
        fake_noise = noise_gen(torch.rand([batch_size, 1], device=device))
        fake_data = fake_main + fake_noise
        
        # Train discriminator
        optimizer_D.zero_grad()
        #d_loss = adversarial_loss(F.log_softmax(discriminator(fake_data.detach()), dim=0), F.softmax(price, dim=0), reduction='batchmean')
        d_loss = adversarial_loss(discriminator(fake_data.detach()), fake_data)
        d_loss.backward()
        optimizer_D.step()
        
        # Train generators
        optimizer_G.zero_grad()
        fake_main = main_gen(price)
        fake_noise = noise_gen(z1)
        fake_data = fake_main + fake_noise
        #d_loss = adversarial_loss(F.log_softmax(discriminator(fake_data), dim=0), F.softmax(price, dim=0), reduction='batchmean')
        #g_loss = adversarial_loss(F.log_softmax(fake_data, dim=0), F.softmax(price, dim=0), reduction='batchmean')
        #noise_loss = adversarial_loss(F.log_softmax(fake_noise, dim=0), F.softmax(z1, dim=0), reduction='batchmean')
        d_loss = adversarial_loss(discriminator(fake_data), fake_data)
        g_loss = adversarial_loss(price, fake_data)
        noise_loss = adversarial_loss(fake_noise, z1)
        total_loss = 0.5 * ((0.8 * g_loss + 0.2 * noise_loss) + d_loss)
        total_loss.backward()
        optimizer_G.step()
        optimizer_noise.step()

        total_loss_list.append(total_loss.item())
        g_loss_list.append(g_loss.item())
        noise_loss_list.append(noise_loss.item())
        d_loss_list.append(d_loss.item())

    optim_g_sched.step()
    optim_d_sched.step()
    optim_noise_sched.step()
    
    print(f"Epoch [{epoch+1}/{num_epochs}]  Discriminator Loss: {d_loss.item():.4f}  Generator Loss: {g_loss.item():.4f}")
    experiment.log_metric('Main Generator\Train', np.average(g_loss_list), step=epoch)
    experiment.log_metric('Noise Generator\Train', np.average(noise_loss_list), step=epoch)
    experiment.log_metric('Discriminator\Train', np.average(d_loss_list), step=epoch)
    experiment.log_metric('Total\Train', np.average(total_loss_list), step=epoch)


  1%|          | 1/100 [00:46<1:15:59, 46.06s/it]

Epoch [1/100]  Discriminator Loss: 69.8172  Generator Loss: 13681.5391


  2%|▏         | 2/100 [01:31<1:14:26, 45.58s/it]

Epoch [2/100]  Discriminator Loss: 81.4149  Generator Loss: 16572.7852


  3%|▎         | 3/100 [02:16<1:13:26, 45.43s/it]

Epoch [3/100]  Discriminator Loss: 16.2975  Generator Loss: 16825.8047


  4%|▍         | 4/100 [03:01<1:12:35, 45.37s/it]

Epoch [4/100]  Discriminator Loss: 8.6734  Generator Loss: 16825.0859


  5%|▌         | 5/100 [03:47<1:11:45, 45.33s/it]

Epoch [5/100]  Discriminator Loss: 8.2481  Generator Loss: 16821.0332


  6%|▌         | 6/100 [04:32<1:11:04, 45.37s/it]

Epoch [6/100]  Discriminator Loss: 8.3827  Generator Loss: 16822.1230


  7%|▋         | 7/100 [05:17<1:10:19, 45.37s/it]

Epoch [7/100]  Discriminator Loss: 8.3902  Generator Loss: 16821.8613


  8%|▊         | 8/100 [06:03<1:09:31, 45.34s/it]

Epoch [8/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


  9%|▉         | 9/100 [06:48<1:08:50, 45.39s/it]

Epoch [9/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 10%|█         | 10/100 [07:33<1:07:59, 45.32s/it]

Epoch [10/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 11%|█         | 11/100 [08:19<1:07:11, 45.30s/it]

Epoch [11/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 12%|█▏        | 12/100 [09:04<1:06:27, 45.31s/it]

Epoch [12/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 13%|█▎        | 13/100 [09:49<1:05:42, 45.32s/it]

Epoch [13/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 14%|█▍        | 14/100 [10:35<1:04:55, 45.30s/it]

Epoch [14/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 15%|█▌        | 15/100 [11:20<1:04:08, 45.27s/it]

Epoch [15/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 16%|█▌        | 16/100 [12:05<1:03:18, 45.22s/it]

Epoch [16/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 17%|█▋        | 17/100 [12:50<1:02:31, 45.20s/it]

Epoch [17/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 18%|█▊        | 18/100 [13:35<1:01:46, 45.20s/it]

Epoch [18/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 19%|█▉        | 19/100 [14:20<1:01:02, 45.22s/it]

Epoch [19/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 20%|██        | 20/100 [15:06<1:00:22, 45.28s/it]

Epoch [20/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 21%|██        | 21/100 [15:51<59:40, 45.32s/it]  

Epoch [21/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 22%|██▏       | 22/100 [16:37<58:56, 45.33s/it]

Epoch [22/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 23%|██▎       | 23/100 [17:22<58:07, 45.30s/it]

Epoch [23/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 24%|██▍       | 24/100 [18:07<57:18, 45.24s/it]

Epoch [24/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 25%|██▌       | 25/100 [18:52<56:39, 45.32s/it]

Epoch [25/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 26%|██▌       | 26/100 [19:38<55:54, 45.32s/it]

Epoch [26/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 27%|██▋       | 27/100 [20:23<55:10, 45.35s/it]

Epoch [27/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 28%|██▊       | 28/100 [21:09<54:30, 45.42s/it]

Epoch [28/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 29%|██▉       | 29/100 [21:54<53:44, 45.41s/it]

Epoch [29/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 30%|███       | 30/100 [22:39<52:56, 45.38s/it]

Epoch [30/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 31%|███       | 31/100 [23:25<52:08, 45.33s/it]

Epoch [31/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 32%|███▏      | 32/100 [24:10<51:19, 45.29s/it]

Epoch [32/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 33%|███▎      | 33/100 [24:55<50:29, 45.22s/it]

Epoch [33/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 34%|███▍      | 34/100 [25:40<49:38, 45.12s/it]

Epoch [34/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 35%|███▌      | 35/100 [26:25<48:50, 45.09s/it]

Epoch [35/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 36%|███▌      | 36/100 [27:10<48:03, 45.06s/it]

Epoch [36/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 37%|███▋      | 37/100 [27:55<47:19, 45.07s/it]

Epoch [37/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 38%|███▊      | 38/100 [28:40<46:34, 45.07s/it]

Epoch [38/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 39%|███▉      | 39/100 [29:25<45:44, 44.99s/it]

Epoch [39/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 40%|████      | 40/100 [30:10<44:55, 44.93s/it]

Epoch [40/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 41%|████      | 41/100 [30:55<44:10, 44.92s/it]

Epoch [41/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 42%|████▏     | 42/100 [31:40<43:29, 45.00s/it]

Epoch [42/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 43%|████▎     | 43/100 [32:25<42:47, 45.04s/it]

Epoch [43/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 44%|████▍     | 44/100 [33:10<42:06, 45.11s/it]

Epoch [44/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 45%|████▌     | 45/100 [33:55<41:20, 45.11s/it]

Epoch [45/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 46%|████▌     | 46/100 [34:40<40:34, 45.08s/it]

Epoch [46/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 47%|████▋     | 47/100 [35:25<39:52, 45.13s/it]

Epoch [47/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 48%|████▊     | 48/100 [36:10<39:04, 45.08s/it]

Epoch [48/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 49%|████▉     | 49/100 [36:55<38:15, 45.02s/it]

Epoch [49/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 50%|█████     | 50/100 [37:40<37:24, 44.89s/it]

Epoch [50/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 51%|█████     | 51/100 [38:25<36:37, 44.85s/it]

Epoch [51/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 52%|█████▏    | 52/100 [39:10<35:53, 44.86s/it]

Epoch [52/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 53%|█████▎    | 53/100 [39:54<35:08, 44.87s/it]

Epoch [53/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 54%|█████▍    | 54/100 [40:39<34:24, 44.87s/it]

Epoch [54/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 55%|█████▌    | 55/100 [41:24<33:38, 44.86s/it]

Epoch [55/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 56%|█████▌    | 56/100 [42:09<32:56, 44.92s/it]

Epoch [56/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 57%|█████▋    | 57/100 [42:54<32:13, 44.96s/it]

Epoch [57/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 58%|█████▊    | 58/100 [43:39<31:29, 45.00s/it]

Epoch [58/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 59%|█████▉    | 59/100 [44:25<30:49, 45.12s/it]

Epoch [59/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 60%|██████    | 60/100 [45:10<30:08, 45.22s/it]

Epoch [60/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 61%|██████    | 61/100 [45:56<29:26, 45.30s/it]

Epoch [61/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 62%|██████▏   | 62/100 [46:41<28:43, 45.34s/it]

Epoch [62/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 63%|██████▎   | 63/100 [47:27<27:57, 45.35s/it]

Epoch [63/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 64%|██████▍   | 64/100 [48:12<27:16, 45.45s/it]

Epoch [64/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 65%|██████▌   | 65/100 [48:58<26:34, 45.55s/it]

Epoch [65/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 66%|██████▌   | 66/100 [49:44<25:51, 45.63s/it]

Epoch [66/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 67%|██████▋   | 67/100 [50:29<25:04, 45.59s/it]

Epoch [67/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 68%|██████▊   | 68/100 [51:15<24:18, 45.59s/it]

Epoch [68/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 69%|██████▉   | 69/100 [52:00<23:31, 45.54s/it]

Epoch [69/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 70%|███████   | 70/100 [52:46<22:43, 45.46s/it]

Epoch [70/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 71%|███████   | 71/100 [53:31<21:58, 45.48s/it]

Epoch [71/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 72%|███████▏  | 72/100 [54:17<21:14, 45.51s/it]

Epoch [72/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 73%|███████▎  | 73/100 [55:02<20:29, 45.53s/it]

Epoch [73/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 74%|███████▍  | 74/100 [55:48<19:42, 45.46s/it]

Epoch [74/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 75%|███████▌  | 75/100 [56:33<18:56, 45.46s/it]

Epoch [75/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 76%|███████▌  | 76/100 [57:18<18:10, 45.43s/it]

Epoch [76/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 77%|███████▋  | 77/100 [58:04<17:25, 45.46s/it]

Epoch [77/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 78%|███████▊  | 78/100 [58:50<16:42, 45.55s/it]

Epoch [78/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 79%|███████▉  | 79/100 [59:35<15:56, 45.53s/it]

Epoch [79/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 80%|████████  | 80/100 [1:00:21<15:10, 45.51s/it]

Epoch [80/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 81%|████████  | 81/100 [1:01:06<14:25, 45.54s/it]

Epoch [81/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 82%|████████▏ | 82/100 [1:01:52<13:40, 45.56s/it]

Epoch [82/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 83%|████████▎ | 83/100 [1:02:37<12:54, 45.58s/it]

Epoch [83/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 84%|████████▍ | 84/100 [1:03:23<12:08, 45.54s/it]

Epoch [84/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 85%|████████▌ | 85/100 [1:04:08<11:22, 45.53s/it]

Epoch [85/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 86%|████████▌ | 86/100 [1:04:54<10:38, 45.59s/it]

Epoch [86/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 87%|████████▋ | 87/100 [1:05:40<09:53, 45.63s/it]

Epoch [87/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 88%|████████▊ | 88/100 [1:06:25<09:07, 45.62s/it]

Epoch [88/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 89%|████████▉ | 89/100 [1:07:11<08:21, 45.56s/it]

Epoch [89/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 90%|█████████ | 90/100 [1:07:57<07:35, 45.60s/it]

Epoch [90/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 91%|█████████ | 91/100 [1:08:42<06:50, 45.57s/it]

Epoch [91/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 92%|█████████▏| 92/100 [1:09:28<06:05, 45.66s/it]

Epoch [92/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 93%|█████████▎| 93/100 [1:10:13<05:19, 45.60s/it]

Epoch [93/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 94%|█████████▍| 94/100 [1:10:59<04:33, 45.61s/it]

Epoch [94/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 95%|█████████▌| 95/100 [1:11:45<03:48, 45.70s/it]

Epoch [95/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 96%|█████████▌| 96/100 [1:12:30<03:02, 45.60s/it]

Epoch [96/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 97%|█████████▋| 97/100 [1:13:16<02:16, 45.60s/it]

Epoch [97/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 98%|█████████▊| 98/100 [1:14:01<01:30, 45.49s/it]

Epoch [98/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


 99%|█████████▉| 99/100 [1:14:46<00:45, 45.43s/it]

Epoch [99/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516


100%|██████████| 100/100 [1:15:32<00:00, 45.32s/it]

Epoch [100/100]  Discriminator Loss: 8.3903  Generator Loss: 16821.8516





In [18]:
# Test loop
batch_start = 0
for idx in range(int(len(test_data)/batch_size)-1):
    price = torch.tensor(test_data[batch_start:batch_start+batch_size].close.values, dtype=torch.float).reshape(batch_size,1)
    price = price.to(device)
    batch_start = batch_start + batch_size
    # Generate random noise inputs
    z1 = cauchy_dist.sample([batch_size, 1])
    z1 = z1.to(device)
    #z2 = cauchy_dist.sample([1])
    main_gen.eval()
    noise_gen.eval()
    discriminator.eval()
    # Generate fake data
    fake_main = main_gen(price)
    fake_noise = noise_gen(torch.rand([batch_size, 1], device=device))
    fake_data = fake_main + fake_noise


    #d_loss = adversarial_loss(F.log_softmax(discriminator(fake_data), dim=0), F.softmax(price, dim=0), reduction='batchmean')
    #g_loss = adversarial_loss(F.log_softmax(fake_data, dim=0), F.softmax(price, dim=0), reduction='batchmean')
    #noise_loss = adversarial_loss(F.log_softmax(fake_noise, dim=0), F.softmax(z1, dim=0), reduction='batchmean')
    d_loss = adversarial_loss(discriminator(fake_data), price)
    g_loss = adversarial_loss(price, fake_data)
    noise_loss = adversarial_loss(fake_noise, z1)
    total_loss = 0.5 * ((0.8 * g_loss + 0.2 * noise_loss) + d_loss)
    optimizer_G.step()
    optimizer_noise.step()

optim_g_sched.step()
optim_d_sched.step()
optim_noise_sched.step()

print(f"Epoch [{epoch+1}/{num_epochs}]  Discriminator Loss: {d_loss.item():.4f}  Generator Loss: {g_loss.item():.4f}")
experiment.log_metric('Main Generator\Test', g_loss, step=epoch)
experiment.log_metric('Noise Generator\Test', noise_loss, step=epoch)
experiment.log_metric('Discriminator\Test', d_loss, step=epoch)
experiment.log_metric('Total\Test', total_loss, step=epoch)
# log_model(experiment, model=main_gen, model_name="Main Generator")
# log_model(experiment, model=noise_gen, model_name="Noise Generator")
# log_model(experiment, model=discriminator, model_name="Discriminator")


Epoch [100/100]  Discriminator Loss: 385699.9375  Generator Loss: 382206.8125


In [20]:
experiment.end()

[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/artaasd95/tail-price/29faec1725b4493eae8c3f8cb9469e86
[1;38;5;39mCOMET INFO:[0m   Metrics [count] (min, max):
[1;38;5;39mCOMET INFO:[0m     Discriminator\Test          : 385699.9375
[1;38;5;39mCOMET INFO:[0m     Discriminator\Train [100]   : (8.331120317632502, 308944.53143544035)
[1;38;5;39mCOMET INFO:[0m     Main Generator\Test         : 382206.8125
[1;38;5;39mCOMET INFO:[0m     Main Generator\Train [100]  : (13899.041346526343, 328724.9913884943)
[1;38;5;39mCOMET INFO:[0m     Noise Generator\Test        : 18.31653022766

In [30]:
# Training loop
for epoch in tqdm(range(num_epochs)):
    batch_start = 0
    total_loss_list = []
    g_loss_list = []
    noise_loss_list = []
    d_loss_list = []
    for idx in range(int(len(data)/batch_size)-1):
        price = torch.tensor(data[batch_start:batch_start+batch_size].close.values, dtype=torch.float).reshape(batch_size,1)
        price = price.to(device)
        batch_start = batch_start + batch_size
        # Generate random noise inputs
        z1 = cauchy_dist.sample([batch_size, 1])
        z1 = z1.to(device)
        #z2 = cauchy_dist.sample([1])
        
        # Generate fake data
        fake_main = main_gen(price)
        fake_noise = noise_gen(torch.rand([batch_size, 1], device=device))
        fake_data = fake_main + fake_noise
        
        # Train discriminator
        optimizer_D.zero_grad()
        #d_loss = adversarial_loss(F.log_softmax(discriminator(fake_data.detach()), dim=0), F.softmax(price, dim=0), reduction='batchmean')
        d_loss = adversarial_loss(discriminator(fake_data.detach()), fake_data)
        d_loss.backward()
        optimizer_D.step()
        
        # Train generators
        optimizer_G.zero_grad()
        fake_main = main_gen(price)
        fake_noise = noise_gen(z1)
        fake_data = fake_main + fake_noise
        #d_loss = adversarial_loss(F.log_softmax(discriminator(fake_data), dim=0), F.softmax(price, dim=0), reduction='batchmean')
        #g_loss = adversarial_loss(F.log_softmax(fake_data, dim=0), F.softmax(price, dim=0), reduction='batchmean')
        #noise_loss = adversarial_loss(F.log_softmax(fake_noise, dim=0), F.softmax(z1, dim=0), reduction='batchmean')
        d_loss = adversarial_loss(discriminator(fake_data), fake_data)
        g_loss = adversarial_loss(price, fake_data)
        noise_loss = adversarial_loss(fake_noise, z1)
        total_loss = 0.5 * ((0.8 * g_loss + 0.2 * noise_loss) + d_loss)
        total_loss.backward()
        optimizer_G.step()
        optimizer_noise.step()

        total_loss_list.append(total_loss.item())
        g_loss_list.append(g_loss.item())
        noise_loss_list.append(noise_loss.item())
        d_loss_list.append(d_loss.item())

    optim_g_sched.step()
    optim_d_sched.step()
    optim_noise_sched.step()
    
    print(f"Epoch [{epoch+1}/{num_epochs}]  Discriminator Loss: {d_loss.item():.4f}  Generator Loss: {g_loss.item():.4f}")
    experiment.log_metric('Main Generator\Whole Train', np.average(g_loss_list), step=epoch)
    experiment.log_metric('Noise Generator\Whole Train', np.average(noise_loss_list), step=epoch)
    experiment.log_metric('Discriminator\Whole Train', np.average(d_loss_list), step=epoch)
    experiment.log_metric('Total\Whole Train', np.average(total_loss_list), step=epoch)


  1%|          | 1/100 [01:05<1:47:52, 65.38s/it]

Epoch [1/100]  Discriminator Loss: 452.1952  Generator Loss: 9682.1543


  2%|▏         | 2/100 [02:09<1:45:57, 64.88s/it]

Epoch [2/100]  Discriminator Loss: 62.6394  Generator Loss: 111821.8438


  3%|▎         | 3/100 [03:14<1:44:48, 64.83s/it]

Epoch [3/100]  Discriminator Loss: 7.2478  Generator Loss: 181671.6562


  4%|▍         | 4/100 [04:19<1:43:40, 64.80s/it]

Epoch [4/100]  Discriminator Loss: 3.0179  Generator Loss: 187927.0469


  5%|▌         | 5/100 [05:24<1:42:47, 64.92s/it]

Epoch [5/100]  Discriminator Loss: 0.3189  Generator Loss: 188601.2188


  6%|▌         | 6/100 [06:29<1:41:49, 64.99s/it]

Epoch [6/100]  Discriminator Loss: 0.3002  Generator Loss: 188680.7031


  7%|▋         | 7/100 [07:35<1:40:53, 65.10s/it]

Epoch [7/100]  Discriminator Loss: 0.2969  Generator Loss: 188687.4375


  8%|▊         | 8/100 [08:39<1:39:24, 64.83s/it]

Epoch [8/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


  9%|▉         | 9/100 [09:43<1:38:13, 64.77s/it]

Epoch [9/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 10%|█         | 10/100 [10:48<1:36:55, 64.62s/it]

Epoch [10/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


[1;38;5;196mCOMET ERROR:[0m Heartbeat processing error
 11%|█         | 11/100 [11:52<1:35:31, 64.40s/it]

Epoch [11/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 12%|█▏        | 12/100 [12:56<1:34:16, 64.28s/it]

Epoch [12/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 13%|█▎        | 13/100 [13:59<1:32:44, 63.96s/it]

Epoch [13/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 14%|█▍        | 14/100 [15:03<1:31:37, 63.92s/it]

Epoch [14/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 15%|█▌        | 15/100 [16:07<1:30:39, 64.00s/it]

Epoch [15/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 16%|█▌        | 16/100 [17:08<1:28:18, 63.08s/it]

Epoch [16/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 17%|█▋        | 17/100 [18:08<1:26:00, 62.17s/it]

Epoch [17/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 18%|█▊        | 18/100 [19:08<1:24:02, 61.50s/it]

Epoch [18/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 19%|█▉        | 19/100 [20:08<1:22:23, 61.04s/it]

Epoch [19/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 20%|██        | 20/100 [21:08<1:20:59, 60.75s/it]

Epoch [20/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 21%|██        | 21/100 [22:08<1:19:50, 60.65s/it]

Epoch [21/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 22%|██▏       | 22/100 [23:08<1:18:42, 60.54s/it]

Epoch [22/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 23%|██▎       | 23/100 [24:09<1:17:40, 60.52s/it]

Epoch [23/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 24%|██▍       | 24/100 [25:09<1:16:33, 60.44s/it]

Epoch [24/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 25%|██▌       | 25/100 [26:10<1:15:29, 60.40s/it]

Epoch [25/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 26%|██▌       | 26/100 [27:09<1:14:18, 60.25s/it]

Epoch [26/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 27%|██▋       | 27/100 [28:10<1:13:24, 60.33s/it]

Epoch [27/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 28%|██▊       | 28/100 [29:10<1:12:22, 60.31s/it]

Epoch [28/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 29%|██▉       | 29/100 [30:10<1:11:13, 60.19s/it]

Epoch [29/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 30%|███       | 30/100 [31:10<1:10:08, 60.12s/it]

Epoch [30/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 31%|███       | 31/100 [32:10<1:09:05, 60.08s/it]

Epoch [31/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 32%|███▏      | 32/100 [33:10<1:08:03, 60.06s/it]

Epoch [32/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 33%|███▎      | 33/100 [34:10<1:07:02, 60.04s/it]

Epoch [33/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 34%|███▍      | 34/100 [35:10<1:06:02, 60.04s/it]

Epoch [34/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 35%|███▌      | 35/100 [36:10<1:05:04, 60.07s/it]

Epoch [35/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 36%|███▌      | 36/100 [37:10<1:04:05, 60.09s/it]

Epoch [36/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 37%|███▋      | 37/100 [38:11<1:03:12, 60.21s/it]

Epoch [37/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 38%|███▊      | 38/100 [39:11<1:02:06, 60.11s/it]

Epoch [38/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 39%|███▉      | 39/100 [40:11<1:01:06, 60.11s/it]

Epoch [39/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 40%|████      | 40/100 [41:11<1:00:07, 60.13s/it]

Epoch [40/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 41%|████      | 41/100 [42:11<59:06, 60.12s/it]  

Epoch [41/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 42%|████▏     | 42/100 [43:11<58:04, 60.08s/it]

Epoch [42/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 43%|████▎     | 43/100 [44:11<56:59, 59.99s/it]

Epoch [43/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 44%|████▍     | 44/100 [45:11<55:57, 59.96s/it]

Epoch [44/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 45%|████▌     | 45/100 [46:11<55:04, 60.08s/it]

Epoch [45/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 46%|████▌     | 46/100 [47:11<54:01, 60.03s/it]

Epoch [46/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 47%|████▋     | 47/100 [48:11<53:00, 60.02s/it]

Epoch [47/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 48%|████▊     | 48/100 [49:11<51:54, 59.90s/it]

Epoch [48/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 49%|████▉     | 49/100 [50:10<50:50, 59.82s/it]

Epoch [49/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 50%|█████     | 50/100 [51:10<49:50, 59.81s/it]

Epoch [50/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 51%|█████     | 51/100 [52:10<48:53, 59.88s/it]

Epoch [51/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 52%|█████▏    | 52/100 [53:11<48:03, 60.08s/it]

Epoch [52/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 53%|█████▎    | 53/100 [54:11<47:09, 60.20s/it]

Epoch [53/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 54%|█████▍    | 54/100 [55:11<46:06, 60.14s/it]

Epoch [54/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 55%|█████▌    | 55/100 [56:11<45:05, 60.11s/it]

Epoch [55/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 56%|█████▌    | 56/100 [57:11<44:03, 60.09s/it]

Epoch [56/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 57%|█████▋    | 57/100 [58:11<43:02, 60.06s/it]

Epoch [57/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 58%|█████▊    | 58/100 [59:11<42:03, 60.07s/it]

Epoch [58/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 59%|█████▉    | 59/100 [1:00:11<41:00, 60.01s/it]

Epoch [59/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 60%|██████    | 60/100 [1:01:11<40:03, 60.08s/it]

Epoch [60/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 61%|██████    | 61/100 [1:02:12<39:06, 60.17s/it]

Epoch [61/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 62%|██████▏   | 62/100 [1:03:11<38:00, 60.01s/it]

Epoch [62/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 63%|██████▎   | 63/100 [1:04:12<37:02, 60.07s/it]

Epoch [63/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 64%|██████▍   | 64/100 [1:05:12<36:04, 60.12s/it]

Epoch [64/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 65%|██████▌   | 65/100 [1:06:12<35:07, 60.22s/it]

Epoch [65/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 66%|██████▌   | 66/100 [1:07:12<34:04, 60.14s/it]

Epoch [66/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 67%|██████▋   | 67/100 [1:08:12<33:02, 60.08s/it]

Epoch [67/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 68%|██████▊   | 68/100 [1:09:12<32:01, 60.03s/it]

Epoch [68/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 69%|██████▉   | 69/100 [1:10:12<31:02, 60.07s/it]

Epoch [69/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 70%|███████   | 70/100 [1:11:12<30:02, 60.10s/it]

Epoch [70/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 71%|███████   | 71/100 [1:12:13<29:03, 60.12s/it]

Epoch [71/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 72%|███████▏  | 72/100 [1:13:13<28:04, 60.16s/it]

Epoch [72/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 73%|███████▎  | 73/100 [1:14:13<27:01, 60.07s/it]

Epoch [73/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 74%|███████▍  | 74/100 [1:15:13<26:00, 60.03s/it]

Epoch [74/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 75%|███████▌  | 75/100 [1:16:13<24:59, 59.96s/it]

Epoch [75/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 76%|███████▌  | 76/100 [1:17:12<23:58, 59.94s/it]

Epoch [76/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 77%|███████▋  | 77/100 [1:18:13<23:02, 60.10s/it]

Epoch [77/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 78%|███████▊  | 78/100 [1:19:13<22:04, 60.20s/it]

Epoch [78/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 79%|███████▉  | 79/100 [1:20:14<21:05, 60.25s/it]

Epoch [79/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 80%|████████  | 80/100 [1:21:14<20:02, 60.13s/it]

Epoch [80/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 81%|████████  | 81/100 [1:22:13<19:00, 60.05s/it]

Epoch [81/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 82%|████████▏ | 82/100 [1:23:13<17:59, 59.96s/it]

Epoch [82/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 83%|████████▎ | 83/100 [1:24:13<16:58, 59.93s/it]

Epoch [83/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 84%|████████▍ | 84/100 [1:25:13<15:59, 59.95s/it]

Epoch [84/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 85%|████████▌ | 85/100 [1:26:13<14:58, 59.91s/it]

Epoch [85/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 86%|████████▌ | 86/100 [1:27:13<13:57, 59.85s/it]

Epoch [86/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 87%|████████▋ | 87/100 [1:28:13<12:59, 59.95s/it]

Epoch [87/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 88%|████████▊ | 88/100 [1:29:13<12:00, 60.02s/it]

Epoch [88/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 89%|████████▉ | 89/100 [1:30:14<11:02, 60.22s/it]

Epoch [89/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 90%|█████████ | 90/100 [1:31:14<10:01, 60.20s/it]

Epoch [90/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 91%|█████████ | 91/100 [1:32:14<09:01, 60.17s/it]

Epoch [91/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 92%|█████████▏| 92/100 [1:33:14<08:01, 60.14s/it]

Epoch [92/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 93%|█████████▎| 93/100 [1:34:14<07:00, 60.12s/it]

Epoch [93/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 94%|█████████▍| 94/100 [1:35:14<06:00, 60.04s/it]

Epoch [94/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 95%|█████████▌| 95/100 [1:36:14<05:00, 60.01s/it]

Epoch [95/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 96%|█████████▌| 96/100 [1:37:14<04:00, 60.10s/it]

Epoch [96/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 97%|█████████▋| 97/100 [1:38:14<03:00, 60.02s/it]

Epoch [97/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 98%|█████████▊| 98/100 [1:39:14<01:59, 59.97s/it]

Epoch [98/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


 99%|█████████▉| 99/100 [1:40:14<00:59, 59.94s/it]

Epoch [99/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406


100%|██████████| 100/100 [1:41:13<00:00, 60.74s/it]

Epoch [100/100]  Discriminator Loss: 0.2971  Generator Loss: 188687.1406





In [31]:
experiment.end()

[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/artaasd95/tail-price/6f4c8d0eeff640c9b0a362f65b9c829b
[1;38;5;39mCOMET INFO:[0m   Metrics [count] (min, max):
[1;38;5;39mCOMET INFO:[0m     Discriminator\Whole Train [100]   : (0.29707714915275574, 101331.14510902815)
[1;38;5;39mCOMET INFO:[0m     Main Generator\Whole Train [100]  : (47296.581695380235, 226844.73195617323)
[1;38;5;39mCOMET INFO:[0m     Noise Generator\Whole Train [100] : (61.17535220611991, 2672072.2927137655)
[1;38;5;39mCOMET INFO:[0m     Total\Whole Train [100]           : (18924.90200788024, 286126.013025

In [32]:
torch.save(main_gen, 'checkpoints/main_gen_xau.pth')
torch.save(noise_gen, 'checkpoints/noise_gen_xau.pth')
torch.save(discriminator, 'checkpoints/discriminator_xau.pth')