In [1]:
from fid import calculate_fid
import numpy as np
from library.dataset import get_pytorch_datataset

In [2]:
# Generator architectures compared in this notebook; also the keys of `cfid_values`.
architectures = 'TCN LSTM GRU'.split()

In [3]:
# One independent (not shared) metrics dict per architecture; later cells store 'mean' and 'std' in it.
cfid_values = dict((name, {}) for name in architectures)

In [None]:
from library.constants import DEVICE
from library.gan_train_loop import load_gan
from library.gan import Generator as TCN_Generator
from library.gan_LSTM import Generator as LSTM_Generator
from library.gan_GRU import Generator as GRU_Generator

from library.generation import generate_fake_returns as TCN_generate_fake_returns
from library.generation_LSTM import generate_fake_returns as LSTM_generate_fake_returns
from library.generation_GRU import generate_fake_returns as GRU_generate_fake_returns

# Number of fake trajectories drawn per architecture; trajectory i is generated with seed=i.
GENERATIONS_AMOUNT = 10
# Epoch of the saved checkpoint loaded for every architecture (previously repeated inline as 800).
CHECKPOINT_EPOCH = 800

# NOTE(review): element [0] of get_pytorch_datataset()'s result is treated as the real
# returns — confirm the return layout in library.dataset.
df_returns_real = get_pytorch_datataset()[0]


def load_and_generate(arch, generator, generate_fake_returns, real_returns,
                      n_generations=GENERATIONS_AMOUNT, epoch=CHECKPOINT_EPOCH):
    """Load a saved GAN checkpoint into `generator`, then draw seeded fake trajectories.

    Args:
        arch: checkpoint name passed to `load_gan` ('TCN', 'LSTM' or 'GRU' in this notebook).
        generator: generator module to load into. `load_gan`'s return value is not used,
            so loading is presumably in place — confirm in library.gan_train_loop.
        generate_fake_returns: the architecture-specific generation function.
        real_returns: real returns passed through to `generate_fake_returns`.
        n_generations: number of trajectories to draw; trajectory i uses seed=i.
        epoch: checkpoint epoch passed to `load_gan`.

    Returns:
        A list of `n_generations` results of `generate_fake_returns`, ordered by seed.
    """
    load_gan(arch, generator, epoch=epoch)
    # NOTE(review): the generator is never switched to .eval() here — confirm that
    # load_gan / generate_fake_returns handle inference mode (matters for dropout/batch-norm).
    return [
        generate_fake_returns(generator, real_returns, seed=seed)
        for seed in range(n_generations)
    ]


# NOTE(review): the positional argument 2 to the TCN generator is not explained here —
# confirm its meaning in library.gan.
tcn_generator = TCN_Generator(2).to(DEVICE)
tcn_df_returns_fake = load_and_generate('TCN', tcn_generator, TCN_generate_fake_returns, df_returns_real)

lstm_generator = LSTM_Generator().to(DEVICE)
lstm_df_returns_fake = load_and_generate('LSTM', lstm_generator, LSTM_generate_fake_returns, df_returns_real)

gru_generator = GRU_Generator().to(DEVICE)
gru_df_returns_fake = load_and_generate('GRU', gru_generator, GRU_generate_fake_returns, df_returns_real)


In [7]:
def summarize_fid(real_returns, fake_trajectories):
    """Score every fake trajectory with `calculate_fid` against the real returns and summarise.

    Args:
        real_returns: real returns, passed as the first argument of `calculate_fid`.
        fake_trajectories: iterable of generated trajectories; one score is computed per element.

    Returns:
        {'mean': ..., 'std': ...} — the mean and the population standard deviation
        (numpy's default ddof=0) of the per-trajectory scores, in that key order.

    Raises:
        ValueError: if `fake_trajectories` is empty (np.mean would otherwise return nan
            with only a RuntimeWarning).
    """
    fids = [calculate_fid(real_returns, traj) for traj in fake_trajectories]
    if not fids:
        raise ValueError('no fake trajectories to evaluate')
    return {'mean': np.mean(fids), 'std': np.std(fids)}


# Explicit name -> trajectories mapping so every architecture is evaluated identically
# (replaces three copy-pasted per-architecture cells).
fake_returns_by_arch = {
    'TCN': tcn_df_returns_fake,
    'LSTM': lstm_df_returns_fake,
    'GRU': gru_df_returns_fake,
}
for arch, fake_trajectories in fake_returns_by_arch.items():
    cfid_values[arch].update(summarize_fid(df_returns_real, fake_trajectories))

In [11]:
# Final summary: for each architecture, the mean and std of the calculate_fid scores
# over its generated trajectories.
cfid_values

{'TCN': {'mean': 0.000224809737867745, 'std': 8.750088205863424e-06},
 'LSTM': {'mean': 0.0011258203285543058, 'std': 2.1268350228962416e-05},
 'GRU': {'mean': 0.0009432847844564321, 'std': 2.8437548273114842e-05}}