In [1]:
from parse_data import get_data
from generator import Generator
from discriminator import Discriminator
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device: ", device)

df = get_data()

def make_data(df, device):
    
    x_train, y_train = [], []
    prev = -1
    count = 1
    m = df.max()[0]
    maxtime = df.max()[1]
    print("Max value: ", m)
    
    for row in df.values:
        x_train.append([prev/m, count/maxtime])
        y_train.append([row[0]/m])
        prev = row[0]
        count += 1
    return torch.tensor(x_train, dtype=torch.float).to(device),torch.tensor(y_train, dtype=torch.float).to(device)

make_data(df, device)

Using device:  cuda
Max value:  1000


(tensor([[-1.0000e-03,  7.7509e-08],
         [ 1.0000e+00,  1.5502e-07],
         [ 4.6000e-01,  2.3253e-07],
         ...,
         [ 1.0000e+00,  6.3170e-04],
         [ 4.6000e-01,  6.3177e-04],
         [ 5.9000e-01,  6.3185e-04]], device='cuda:0'),
 tensor([[1.0000],
         [0.4600],
         [1.0000],
         ...,
         [0.4600],
         [0.5900],
         [0.5900]], device='cuda:0'))

### *Make this cell "code" to enable warnings*
### *Only used for debugging*
import warnings
warnings.filterwarnings('error')

In [2]:
import numpy as np
import torch.optim as optim
import torch.utils.data as data
from IPython.display import clear_output
import torch.nn as nn


generator = Generator(device=device).to(device)
discriminator = Discriminator(device=device).to(device)

optimizer_g = optim.Adam(generator.parameters(), lr=0.002)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.002)

loss = nn.BCELoss()

batch_size = 100
x_d, y_d = make_data(get_data(), device)
loader = data.DataLoader(data.TensorDataset(x_d,y_d), batch_size=batch_size)
epochs = 1000


true = torch.tensor([1.0]*batch_size).unsqueeze(1).to(device)
false = torch.tensor([0.0]*batch_size).unsqueeze(1).to(device)
for e in range(epochs):

    #clear_output()
    # train generator
    clear_output()

    for _ in range(20):

        generator.lstm_zero()
        
        res = []
        for x, _ in loader:
            x.to(device)
            if x.size()[0] < batch_size:
                continue
            generator.init_state()
            discriminator.random_state()
            y_pred = generator(x)
            d = discriminator(y_pred)
            res.append(loss(d, true))

        l = res[0]
        for i in res[1:]:
            l += i
        print("Generator loss: ", l.item())
        optimizer_g.zero_grad()
        l.backward()
        optimizer_g.step()


    # Train discriminator
    for _ in range(10):
    
        optimizer_d.zero_grad()
        generator.lstm_zero()
        res = []

        for x, y in loader:
            x.to(device)
            y.to(device)
            if x.size()[0] < batch_size:
                continue
            generator.init_state()
            discriminator.random_state()
            y_pred = generator(x).detach()
            res.append(loss(discriminator(y_pred), false))
            discriminator.random_state()
            res.append(loss(discriminator(y), true))

        l = res[0]
        for i in res[1:]:
            l += i
        l /= 2
        print("Discriminator loss: ", l.item())
        optimizer_d.zero_grad()
        l.backward()
        optimizer_d.step()
    
    
        


Generator loss:  58.63193130493164
Generator loss:  56.964508056640625
Generator loss:  54.67265701293945
Generator loss:  49.8618049621582
Generator loss:  45.79183578491211
Generator loss:  41.790687561035156
Generator loss:  38.5559196472168
Generator loss:  36.848323822021484
Generator loss:  32.73257827758789
Generator loss:  31.593767166137695
Generator loss:  31.580915451049805
Generator loss:  30.950395584106445
Generator loss:  30.83224868774414
Generator loss:  28.861591339111328
Generator loss:  27.94122314453125
Generator loss:  28.087387084960938
Generator loss:  27.483898162841797
Generator loss:  27.495128631591797
Generator loss:  26.805212020874023
Generator loss:  26.874853134155273
Discriminator loss:  74.4173355102539
Discriminator loss:  75.04899597167969
Discriminator loss:  74.17759704589844
Discriminator loss:  71.91226959228516
Discriminator loss:  69.0453109741211
Discriminator loss:  67.18856048583984
Discriminator loss:  65.24742889404297
Discriminator loss: