In [None]:
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import nn
import seaborn as sns
import matplotlib.pyplot as plt

import res.fnn.training as training
import res.fnn.functions as func
from res.fnn.generator import Generator
from res.fnn.discriminator import Discriminator

import res.process_data.process_raw_data as prd
import res.process_data.dire_and_coor as dc
import res.process_data.process_output as out
from res.process_data.dataset import tensor_dataset

In [None]:
# Prepare the raw data: load the coordinates, convert them to directions and
# one-hot encode those directions (names per the res.process_data helpers).
# NOTE(review): the path is relative to the notebook's working directory, and the
# meaning of the leading 16 is defined by prd.read_coordinate — confirm there.
coordinates_path = '../../data/Coordinates.dat'
coordinates_input = prd.read_coordinate(16, coordinates_path)
directions_input = dc.coor_direction_four(coordinates_input)
one_hot_vector = dc.one_hot_four(directions_input)

In [None]:
# Training hyperparameters and DataLoader options for the whole notebook.
n_epochs = 5
z_dim = 8            # latent size passed to Generator and training_wloss
display_step = 50
batch_size = 128
lr = 0.0003
beta_1 = 0.5         # Adam betas for both optimizers
beta_2 = 0.999
c_lambda = 10        # passed to training_wloss; presumably the gradient-penalty weight — confirm
disc_repeats = 5     # passed to training_wloss; presumably discriminator steps per generator step — confirm
device = 'cpu'
shuffle = True
num_worker = 4
# Page-locked (pinned) memory only speeds up host-to-GPU copies. On a CPU device it
# is wasted work, and recent PyTorch versions warn about it, so derive it from `device`.
pin_memory = torch.device(device).type == 'cuda'

# Seed the RNGs up front so model initialisation, noise sampling and DataLoader
# shuffling are identical on every Restart & Run All.
seed = 0
torch.manual_seed(seed)  # also seeds all CUDA devices
np.random.seed(seed)

In [None]:
# Turn the one-hot encoded directions into a float32 tensor, wrap it in the project's
# dataset helper, and batch it for training.
# torch.tensor with an explicit dtype replaces the legacy torch.Tensor(...) constructor,
# whose meaning depends on the argument type (an int is treated as a size, and non-float
# tensors are rejected). np.asarray avoids the slow list-of-ndarrays conversion path.
input_tensor = torch.tensor(np.asarray(one_hot_vector), dtype=torch.float32)
# NOTE(review): 15 and 4 are passed straight to tensor_dataset. They are presumably the
# number of direction steps and the one-hot width (15 * 4 == 60, the im_dim used for
# the models) — confirm against res.process_data.dataset.tensor_dataset.
dataset = tensor_dataset(input_tensor, 15, 4)
dataloader = DataLoader(dataset=dataset,
                        shuffle=shuffle,
                        batch_size=batch_size,
                        num_workers=num_worker,
                        pin_memory=pin_memory)

In [None]:
# Build the generator and discriminator, then give each its own Adam optimizer.
# The generator is constructed first, as before, in case parameter initialisation
# draws from the (seeded) RNG.
# NOTE(review): im_dim=60 presumably equals 15 * 4 from the dataset cell — confirm.
gen = Generator(z_dim, im_dim=60, hidden_dim=16).to(device)
disc = Discriminator(im_dim=60, hidden_dim=16).to(device)

adam_kwargs = dict(lr=lr, betas=(beta_1, beta_2))
gen_opt, disc_opt = (torch.optim.Adam(model.parameters(), **adam_kwargs)
                     for model in (gen, disc))

In [None]:
# Train the GAN with the project's training_wloss routine (presumably Wasserstein loss,
# per its name). Arguments are positional, so their order must match its signature.
# NOTE(review): any return value is discarded here — confirm nothing useful is returned.
training.training_wloss(
    n_epochs, dataloader, device, disc_repeats,
    gen, gen_opt,
    disc, disc_opt,
    z_dim, c_lambda, display_step,
)

In [None]:
# Hand the trained models and the original coordinates to the output-processing helper.
# NOTE(review): 60 and 16 repeat the im_dim / hidden_dim literals used to build gen and
# disc in the model cell; keep them in sync. The meaning of 'test' (presumably a run
# label) and 'onehot' is defined by res.process_data.process_output — confirm there.
out.process_model(
    'test',
    n_epochs,
    gen,
    disc,
    z_dim,
    60,
    16,
    'onehot',
    coordinates_input,
)


