In [2]:
import os
import sys
sys.path.insert(0, '..')
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import TensorDataset, DataLoader
from models.wgan_gp import WGANGPBaseDiscriminator, WGANGPBaseGenerator
from training import Trainer
from modules import DBlock, GBlock

class Discriminator(WGANGPBaseDiscriminator):
    """1-D convolutional critic for 64-channel sequences.

    Layout: three ``DBlock`` stages built with ``downsample=True``, a 1x1
    ``Conv1d`` (64 -> 64 channels), then a ``Linear`` head that maps the last
    axis to a single value. ``forward`` returns that result viewed as
    ``(batch, 64)``, i.e. one score per channel.

    Any keyword arguments are accepted for call compatibility but ignored.
    """

    def __init__(self, **kwargs):
        super().__init__(channels=64)

        # Trunk: three identical downsampling stages, registered in the order
        # block1 -> block2 -> block3.
        for name in ("block1", "block2", "block3"):
            setattr(self, name, DBlock(in_channels=64, out_channels=64, kernel_size=3,
                                       stride=1, padding=1, downsample=True))
        self.conv = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1, stride=1, padding=0)
        # 3152 // 2**3 == 394. Presumably this is the 3152-step input length after
        # three downsampling stages that each halve it -- TODO confirm against
        # DBlock before changing the input length.
        self.end = nn.Linear(3152 // 2 ** 3, 1)

        # Small-variance normal init for the two plain layers (conv first, then end).
        for layer in (self.conv, self.end):
            nn.init.normal_(layer.weight.data, 0.0, 0.02)

    def forward(self, x):
        h = x.float()
        for block in (self.block1, self.block2, self.block3):
            h = block(h)
        h = self.end(self.conv(h))  # last axis collapsed to size 1
        return h.view(h.shape[0], 64)

class Generator(WGANGPBaseGenerator):
    """1-D convolutional generator producing 64-channel sequences.

    Three ``GBlock`` stages built with ``upsample=False`` are followed by a
    1x1 ``Conv1d`` (64 -> 64 channels); the output of that convolution is
    returned as-is.

    Any keyword arguments are accepted for call compatibility but ignored.
    """

    def __init__(self, **kwargs):
        # Author's note: noise shape will start off as
        # real_data.shape[0] x channels x nz.
        super().__init__(channels=64, nz=3152)

        # Trunk: three identical non-upsampling stages, registered in the order
        # block1 -> block2 -> block3, then the 1x1 projection.
        for name in ("block1", "block2", "block3"):
            setattr(self, name, GBlock(in_channels=64, out_channels=64, kernel_size=3,
                                       stride=1, padding=1, upsample=False))
        self.conv = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1, stride=1, padding=0)

        # Small-variance normal init for the final projection only.
        nn.init.normal_(self.conv.weight.data, 0.0, 0.02)

    def forward(self, x):
        out = x.float()
        for stage in (self.block1, self.block2, self.block3, self.conv):
            out = stage(out)
        return out
    

# --- Configuration ---------------------------------------------------------
# Repository root. Override with the REPRO_GAN_ROOT environment variable rather
# than editing a machine-specific path. Relative paths below (and in the
# TensorBoard cell) resolve against it once we chdir there.
REPO_ROOT = os.environ.get("REPRO_GAN_ROOT", r"c:\Users\joshua.park\Desktop\repro-gan")
DATA_PATH = "./examples/test_data.npy"
LOG_DIR = "./examples/saved_states"
SEED = 0
LEARNING_RATE = 1e-4
ADAM_BETAS = (0.5, 0.99)

# Seed before building the models so weight init, DataLoader shuffling and
# torch-generated noise are reproducible across runs.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
os.chdir(REPO_ROOT)  # absolute path, so re-running this cell is harmless

# --- Data ------------------------------------------------------------------
data = torch.as_tensor(np.load(DATA_PATH))  # e.g. torch.Size([2, 64, 3152])
# Both models hard-code 64 channels and a 3152-step length (Generator nz=3152,
# Discriminator head 394 == 3152 // 8), so fail early on mismatched data.
assert data.dim() == 3 and tuple(data.shape[1:]) == (64, 3152), \
    f"unexpected data shape {tuple(data.shape)}; expected (N, 64, 3152)"
dataloader = DataLoader(
    TensorDataset(data),
    batch_size=1,
    shuffle=True
)

# --- Models & optimisers ---------------------------------------------------
netD = Discriminator().to(device)
netG = Generator().to(device)

optD = optim.Adam(netD.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
optG = optim.Adam(netG.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

# --- Training --------------------------------------------------------------
# Pass the bare modules. If the nets are ever wrapped in nn.DataParallel,
# pass netD.module / netG.module instead (presumably what the earlier GPU note
# meant -- confirm against Trainer).
trainer = Trainer(
    netD=netD,
    netG=netG,
    optD=optD,
    optG=optG,
    n_dis=1,
    num_steps=3,
    dataloader=dataloader,
    save_steps=1,
    print_steps=1,
    log_dir=LOG_DIR,
    device=device)
trainer.train()

INFO: Starting training from global step 0...
tensor(-2.8117e-06, grad_fn=<AddBackward0>) errD
INFO: [Epoch 1/2][Global Step: 1/5] 
| D(G(z)): 0.5037
| D(x): 0.5037
| errD: -0.0
| errG: -0.0146
| (1.2729 sec/idx)
INFO: Saving checkpoints...
tensor(-9.8748e-06, grad_fn=<AddBackward0>) errD
INFO: [Epoch 1/2][Global Step: 2/5] 
| D(G(z)): 0.5037
| D(x): 0.5037
| errD: -0.0
| errG: -0.0146
| (1.2202 sec/idx)
INFO: Saving checkpoints...
tensor(-7.4180e-06, grad_fn=<AddBackward0>) errD
INFO: [Epoch 1/2][Global Step: 3/5] 
| D(G(z)): 0.5037
| D(x): 0.5037
| errD: -0.0
| errG: -0.0146
| (1.2090 sec/idx)
INFO: Saving checkpoints...
tensor(-2.2119e-06, grad_fn=<AddBackward0>) errD
INFO: [Epoch 2/2][Global Step: 4/5] 
| D(G(z)): 0.5037
| D(x): 0.5037
| errD: -0.0
| errG: -0.0146
| (1.2790 sec/idx)
INFO: Saving checkpoints...
tensor(-4.2189e-06, grad_fn=<AddBackward0>) errD
INFO: [Epoch 2/2][Global Step: 5/5] 
| D(G(z)): 0.5037
| D(x): 0.5037
| errD: -0.0
| errG: -0.0146
| (1.4138 sec/idx)
INFO: S

In [4]:
! tensorboard --logdir=./examples/saved_states