In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader

import res.fnn.training as training
import res.process_data.dire_and_coor as dc
import res.process_data.process_output as out
import res.process_data.process_raw_data as prd
from res.fnn.discriminator import Discriminator
from res.fnn.generator import Generator
from res.process_data.dataset import tensor_dataset

In [None]:
# Load the raw coordinates, derive direction vectors from them, and convert
# those directions with dc.sin_cos. `sincos_input` feeds the DataLoader below;
# `coordinates_input` is reused when checking the trained models.
coordinates_path = '../../data/Coordinates.dat'
# NOTE(review): the meaning of the leading 16 is defined by
# prd.read_coordinate — TODO document it here.
coordinates_input = prd.read_coordinate(16, coordinates_path)
directions_input = dc.coordinates_directions_four(coordinates_input)
sincos_input = dc.sin_cos(directions_input)

In [None]:
# Hyper-parameters shared by the model-construction, training and
# evaluation cells below.
z_dim = 8          # passed to Generator / get_output_coordinate; presumably the noise-vector length
im_dim = 30        # sample width; presumably 15 x 2 to match tensor_dataset(..., 15, 2) — TODO confirm
hidden_dim = 16
display_step = 50  # forwarded to training.training_wloss (logging cadence, presumably)
lr = 0.0003
beta_1 = 0.5       # presumably Adam betas, forwarded to training.initialize_model
beta_2 = 0.999
c_lambda = 10      # presumably the WGAN gradient-penalty weight — confirm in training_wloss
disc_repeats = 5   # presumably critic updates per generator update — confirm in training_wloss
batch_size = 128

# Training samples random noise and shuffles the DataLoader, so seed every
# RNG in use to make "Restart & Run All" reproducible.
seed = 42
torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
np.random.seed(seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Wrap the sin/cos features in a shuffled mini-batch DataLoader.
shuffle = True
num_worker = 0
# Pinned host memory only helps host->GPU copies. Requesting it on a
# CPU-only machine makes DataLoader emit a warning and has no effect, so
# enable it only when training on CUDA.
pin_memory = device == 'cuda'
input_tensor = torch.Tensor(sincos_input)
# NOTE(review): the (15, 2) arguments are interpreted by tensor_dataset —
# presumably 15 x 2 = im_dim (30); TODO confirm.
dataset = tensor_dataset(input_tensor, 15, 2)
dataloader = DataLoader(dataset=dataset,
                        shuffle=shuffle,
                        batch_size=batch_size,
                        num_workers=num_worker,
                        pin_memory=pin_memory)

In [None]:
# Build the generator/critic pair and their optimizers from the
# hyper-parameters above.
model_args = (z_dim, im_dim, hidden_dim, device, lr, beta_1, beta_2)
gen, disc, gen_opt, disc_opt = training.initialize_model(*model_args)

In [None]:
# Train with training.training_wloss for n_epochs over `dataloader`;
# disc_repeats, c_lambda and display_step come from the hyper-parameter cell.
n_epochs = 10
wloss_args = (n_epochs, dataloader, device, disc_repeats,
              gen, gen_opt, disc, disc_opt,
              z_dim, c_lambda, display_step)
training.training_wloss(*wloss_args)

In [None]:
# Set to True to write the freshly trained gen/disc to disk under the
# 'wgan_sincos_real' name. The following cells only receive a run name /
# file names, so they presumably read checkpoints from disk: while this is
# False they evaluate whatever checkpoint files already exist (possibly from
# an earlier run), not the model trained above.
save_checkpoint = False
if save_checkpoint:
    out.save_model(gen, disc, 'wgan_sincos_real', n_epochs)

In [None]:
# Run out.check_models on the 'wgan_sincos_real' run for this epoch count;
# presumably it loads the checkpoints saved under that name — TODO confirm.
run_name = "wgan_sincos_real"
n_epochs = 10
out.check_models(run_name, n_epochs, z_dim, im_dim, hidden_dim, "sincos",
                 coordinates_input)

In [None]:
# Rebuild fresh networks and restore the saved weights (mapped onto the CPU)
# for sampling. NOTE: this replaces the `gen`/`disc` trained above with the
# on-disk checkpoint.
gen = Generator(z_dim, im_dim, hidden_dim)
disc = Discriminator(im_dim, hidden_dim)
# Inference mode: any dropout / batch-norm layers use their eval behaviour
# when sampling below. load_state_dict does not change this mode.
gen.eval()
disc.eval()

# Derive file names from n_epochs (instead of a hard-coded 10) so they stay in
# sync with the check_models call that uses the same n_epochs.
gen_check_point = torch.load(f'wgan_sincos_real_gan_{n_epochs}.pth.tar', map_location='cpu')
disc_check_point = torch.load(f'wgan_sincos_real_disc_{n_epochs}.pth.tar', map_location='cpu')
gen.load_state_dict(gen_check_point['gen_state_dict'])
disc.load_state_dict(disc_check_point['disc_state_dict'])

In [None]:
# Sample coordinates from the restored generator. This is pure inference, so
# disable autograd to avoid building a graph on every one of the 1000
# iterations. Assumes get_output_coordinate never calls backward() — TODO confirm.
with torch.no_grad():
    coordinates_output, output_list = out.get_output_coordinate(
        gen, 'sincos', z_dim, iteration=1000, noise_num=16)