In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
import data_loader
import numpy as np
from torch.utils.data import DataLoader
import torchvision


import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [33]:
class Discriminator(nn.Module):
    """Small MLP discriminator that scores a flattened sample in (0, 1).

    Args:
        in_features: width of one flattened input sample.
        hidden_dim: width of the single hidden layer (default 25, the
            previously hard-coded value).

    All parameters are float64, so inputs must be float64 as well.
    """

    def __init__(self, in_features, hidden_dim=25):
        super().__init__()
        # Attribute name and layer order are unchanged, so state_dict keys
        # (disc.0.*, disc.2.*) match the previous layout. Converting the whole
        # Sequential once is equivalent to converting each Linear separately.
        self.disc = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # squashes the score into (0, 1)
        ).to(torch.float64)

    def forward(self, x):
        """Map x of shape (..., in_features) to scores of shape (..., 1)."""
        return self.disc(x)

class Generator(nn.Module):
    """Small MLP generator mapping latent noise to a flattened synthetic sample.

    Args:
        z_dim: dimension of the input noise (latent dimension).
        out_features: width of one generated flattened sample.
        hidden_dim: width of the single hidden layer (default 15, the
            previously hard-coded value).

    The output has no final activation (unbounded values). All parameters
    are float64, so the noise must be float64 as well.
    """

    def __init__(self, z_dim, out_features, hidden_dim=15):
        super().__init__()
        # Same attribute name and layer order as before, so state_dict keys
        # (gen.0.*, gen.2.*) are unchanged.
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_dim, out_features),
        ).to(torch.float64)

    def forward(self, x):
        """Map noise of shape (..., z_dim) to samples of shape (..., out_features)."""
        return self.gen(x)
    
# Seed torch's global RNG so weight init, latent noise and DataLoader
# shuffling are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = "cpu"
lr = 3e-4
z_dim = 6            # dimension of the generator's latent noise vector
seq_dimension = 300  # number of values in one flattened real sample
batch_size = 32
num_epochs = 50

disc = Discriminator(seq_dimension).to(device)
gen = Generator(z_dim, seq_dimension).to(device)

# Fixed latent batch so the generator samples logged each epoch are comparable.
fixed_noise = torch.randn(batch_size, z_dim, dtype=torch.float64).to(device)

destination, dataset, time = data_loader.load_data("data.json")
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
# MSE between D's sigmoid outputs and 1/0 targets: a least-squares objective,
# not the BCE/log-loss of the original GAN formulation.
criterion = nn.MSELoss()
writer_fake = SummaryWriter("runs/GAN_seq/fake")
writer_real = SummaryWriter("runs/GAN_seq/real")
step = 0  # TensorBoard global step, advanced once per logged batch

# A flattened sample is rendered as a 1-channel (n_steps x n_step_features)
# image for TensorBoard; 300 values -> 100 x 3.
n_step_features = 3
assert seq_dimension % n_step_features == 0, "seq_dimension must be divisible by 3"
n_steps = seq_dimension // n_step_features

try:
    for epoch in range(num_epochs):
        for batch_idx, real in enumerate(loader):
            # Flatten to the discriminator's input width and move to the model device.
            real = real.view(-1, seq_dimension).to(device)
            # The last batch may be smaller; use a local name so the global
            # `batch_size` config value is not overwritten.
            cur_batch_size = real.shape[0]

            ### Train discriminator: push D(real) -> 1 and D(G(z)) -> 0 (MSE loss)
            noise = torch.randn(cur_batch_size, z_dim, dtype=torch.float64, device=device)
            fake = gen(noise)
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))

            # detach(): the D step must not backprop into G. This also leaves
            # G's graph intact for the G step, so retain_graph is unnecessary.
            disc_fake = disc(fake.detach()).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            lossD = (lossD_real + lossD_fake) / 2
            disc.zero_grad()
            lossD.backward()
            opt_disc.step()

            ### Train generator: push D(G(z)) -> 1 using the just-updated discriminator
            output = disc(fake).view(-1)
            lossG = criterion(output, torch.ones_like(output))
            gen.zero_grad()
            lossG.backward()
            opt_gen.step()

            if batch_idx == 0:
                print(
                    f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                    f"Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
                )

                with torch.no_grad():
                    fake_fixed = gen(fixed_noise).reshape(-1, 1, n_steps, n_step_features)
                    data = real.reshape(-1, 1, n_steps, n_step_features)
                    img_grid_fake = torchvision.utils.make_grid(fake_fixed, normalize=True)
                    img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                    # NOTE(review): tags still say "Mnist" although the data comes
                    # from data.json; kept so existing TensorBoard series continue.
                    writer_fake.add_image(
                        "Mnist Fake Images", img_grid_fake, global_step=step
                    )
                    writer_real.add_image(
                        "Mnist Real Images", img_grid_real, global_step=step
                    )
                    step += 1
finally:
    # Flush pending events and release the event files, even if training is interrupted.
    writer_fake.close()
    writer_real.close()





Epoch [0/50] Batch 0/670                       Loss D: 0.4178, loss G: 0.2444
Epoch [1/50] Batch 0/670                       Loss D: 0.0748, loss G: 0.3806
Epoch [2/50] Batch 0/670                       Loss D: 0.0260, loss G: 0.6049
Epoch [3/50] Batch 0/670                       Loss D: 0.0075, loss G: 0.7837
Epoch [4/50] Batch 0/670                       Loss D: 0.0030, loss G: 0.8644
Epoch [5/50] Batch 0/670                       Loss D: 0.0010, loss G: 0.9165
Epoch [6/50] Batch 0/670                       Loss D: 0.0006, loss G: 0.9380
Epoch [7/50] Batch 0/670                       Loss D: 0.0004, loss G: 0.9436
Epoch [8/50] Batch 0/670                       Loss D: 0.0002, loss G: 0.9639
Epoch [9/50] Batch 0/670                       Loss D: 0.0001, loss G: 0.9674
Epoch [10/50] Batch 0/670                       Loss D: 0.0001, loss G: 0.9748
Epoch [11/50] Batch 0/670                       Loss D: 0.0001, loss G: 0.9807
Epoch [12/50] Batch 0/670                       Loss D: 0.0000

In [15]:
# Sanity-check the data pipeline: load the dataset and look at one batch.
destination, path, time = data_loader.load_data("data.json")
batch_size = 32

loader = DataLoader(path, batch_size=batch_size, shuffle=True)
# Only the first batch is inspected, so fetch it directly instead of
# iterating over every batch in the loader.
real = next(iter(loader)).view(-1, 300)
print(real.shape[0])  # samples in the batch
print(len(real[0]))   # values per flattened sample
print(real[0])


32
300
tensor([ 0.0000e+00,  1.0000e+00,  0.0000e+00,  9.7278e-04,  2.0000e+00,
         0.0000e+00,  1.9455e-03,  7.0000e+00,  0.0000e+00,  2.9183e-03,
         8.0000e+00,  0.0000e+00,  3.8911e-03,  9.0000e+00,  0.0000e+00,
         4.8638e-03,  1.2000e+01, -5.2083e-04,  6.8094e-03,  1.4000e+01,
        -5.2083e-04,  7.7821e-03,  1.7000e+01, -5.2083e-04,  8.7549e-03,
         1.8000e+01, -5.2083e-04,  9.7276e-03,  1.8000e+01, -5.2083e-04,
         1.0700e-02,  2.2000e+01, -5.2083e-04,  1.1673e-02,  2.3000e+01,
        -5.2083e-04,  1.2646e-02,  2.5000e+01, -5.2083e-04,  1.3619e-02,
         2.6000e+01, -5.2083e-04,  1.4591e-02,  2.6000e+01, -5.2083e-04,
         1.5564e-02,  3.1000e+01, -5.2083e-04,  1.6537e-02,  3.2000e+01,
        -5.2083e-04,  1.7510e-02,  3.3000e+01, -5.2083e-04,  1.8483e-02,
         3.6000e+01, -5.2083e-04,  1.9455e-02,  3.7000e+01, -5.2083e-04,
         2.0428e-02,  4.2000e+01, -1.5625e-03,  2.0428e-02,  4.3000e+01,
        -2.0833e-03,  2.1401e-02,  4.4000e+0

In [3]:

# Scratch look at torch.randn output. Uses its own name so running this cell
# cannot clobber the GAN's `fixed_noise` ((batch_size, z_dim), float64).
demo_noise = torch.randn((10, 3))
demo_noise


tensor([[-0.7320,  1.3747, -0.3393],
        [-0.4281,  0.0908, -0.1422],
        [-0.9475,  0.2882, -0.9761],
        [ 0.2710,  1.5331, -1.3416],
        [ 0.1415, -1.1231,  0.8345],
        [ 0.7772, -0.1877, -0.7764],
        [-0.2638, -0.4558,  0.0333],
        [ 0.1946, -0.0061,  1.1337],
        [-0.9534,  1.3972, -1.8519],
        [ 0.5256,  0.3138, -1.4106]])
