In [1]:
import os
import pandas as pd
import numpy as np
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from scipy import signal

from adversarial_autoencoder import Decoder, Encoder

In [2]:
data_folder = "../../data/william/"
fig_folder = "../../fig/william/"

data_file = f'{data_folder}/preprocessed_data.csv'
file_for_export = f'{data_folder}/generated_data.csv'

In [4]:
data = pd.read_csv(data_file, index_col=0)
data

Unnamed: 0,label,0,1,2,3,4,5,6,7,8,...,1990,1991,1992,1993,1994,1995,1996,1997,1998,1999
0,sitting,-0.416300,-0.417491,-0.420441,-0.429083,-0.420441,-0.414827,-0.414827,-0.425872,-0.434747,...,0.549077,0.555241,0.555241,0.539026,0.500123,0.539026,0.543436,0.543436,0.543436,0.541752
1,sitting,-0.975244,-0.975244,-0.975244,-0.956182,-0.958875,-0.958875,-0.958875,-0.939893,-0.934297,...,1.424881,1.448575,1.448575,1.456614,1.460817,1.460817,1.464875,1.485612,1.485612,1.419242
2,sitting,-0.075665,-0.075665,-0.031610,-0.031610,-0.044277,-0.054205,-0.054205,-0.038245,-0.038245,...,-0.450913,-0.458894,-0.458894,-0.437464,-0.429559,-0.429559,-0.465102,-0.465102,-0.437855,-0.390392
3,sitting,-1.344867,-1.358563,-1.358563,-1.358563,-1.360337,-1.360337,-1.400246,-1.406819,-1.400246,...,1.714682,1.714682,1.714712,1.742487,1.742487,1.713822,1.713822,1.737701,1.737701,1.730101
4,sitting,-0.972655,-1.007781,-1.018223,-1.018223,-1.061634,-1.061634,-1.059564,-1.026008,-1.026008,...,1.947539,1.947539,1.947539,1.883424,1.903985,1.903985,1.903985,1.918113,1.909350,1.909350
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
345,standing,0.362166,0.462288,0.467350,0.467350,0.441121,0.441121,0.441121,0.443292,0.443292,...,-0.600820,-0.605646,-0.622196,-0.622196,-0.637492,-0.597702,-0.637492,-0.597702,-0.642326,-0.506493
346,standing,-0.025810,-0.025810,-0.020282,-0.013312,-0.016767,-0.005698,-0.016767,0.022315,-0.003027,...,0.071983,0.046967,0.046967,0.055400,0.055400,0.060659,0.084731,0.098476,0.084731,-0.000364
347,standing,0.945325,0.952541,0.962438,0.962438,0.950656,0.927253,0.927253,0.976097,0.980584,...,-1.318418,-1.337065,-1.371594,-1.372040,-1.372942,-1.372040,-1.372942,-1.362027,-1.374201,-1.351678
348,standing,1.015584,1.015584,1.015823,1.027837,1.027837,1.011747,1.008753,1.011747,1.024149,...,-1.649722,-1.649722,-1.659491,-1.679571,-1.679571,-1.668293,-1.665690,-1.665690,-1.622155,-1.598536


In [5]:
root_folder = "../.."
bkp_folder = f"{root_folder}/bkp/william/generative_models"
conditions = [x[0].split("/")[-1] for x in os.walk(bkp_folder)][1:]

n = 10000

df_list = []

for cond in conditions:
    folder = f"{bkp_folder}/{cond}"
    encoder = Encoder.load(folder=folder)
    decoder = Decoder.load(folder=folder)
    z = torch.randn((n//2, encoder.latent_dim))
    samples = decoder(z).detach().numpy()
    
    df = pd.DataFrame(samples)
    df.insert(0, 'label', cond)
    
    df_list.append(df)
    
gen_data = pd.concat(df_list, axis=0, ignore_index=True, sort=False)
gen_data

Unnamed: 0,label,0,1,2,3,4,5,6,7,8,...,490,491,492,493,494,495,496,497,498,499
0,standing,0.998841,0.998791,0.998900,0.998916,0.998992,0.998892,0.998849,0.999002,0.998782,...,-0.982918,-0.983242,-0.979289,-0.982258,-0.983125,-0.985334,-0.984307,-0.985363,-0.983074,-0.983885
1,standing,0.975289,0.976977,0.979783,0.979124,0.977761,0.980229,0.982286,0.977204,0.978893,...,-0.997984,-0.998587,-0.997904,-0.997725,-0.998198,-0.998369,-0.998208,-0.997941,-0.998365,-0.998429
2,standing,0.028090,0.022696,0.045333,0.032829,0.022541,0.022338,0.027021,0.020599,0.017348,...,-0.401291,-0.385413,-0.381627,-0.410900,-0.375205,-0.379400,-0.391117,-0.387035,-0.382908,-0.392319
3,standing,0.911681,0.921721,0.931487,0.927215,0.932348,0.929968,0.921985,0.928674,0.933144,...,-0.998353,-0.998224,-0.998109,-0.998126,-0.997948,-0.998409,-0.998098,-0.998088,-0.997890,-0.998413
4,standing,0.928393,0.940146,0.938458,0.938733,0.943226,0.940740,0.935778,0.938224,0.942594,...,-0.983525,-0.984405,-0.983077,-0.982838,-0.982928,-0.983912,-0.983345,-0.984596,-0.982796,-0.984287
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,sitting,-0.335568,-0.374021,-0.373741,-0.373817,-0.371406,-0.374717,-0.401219,-0.391292,-0.374460,...,0.079913,0.087612,0.112122,0.113251,0.109776,0.104869,0.094402,0.093411,0.093990,0.090596
9996,sitting,-0.261506,-0.294556,-0.296993,-0.295482,-0.289226,-0.305838,-0.311839,-0.307460,-0.292914,...,0.469337,0.471598,0.485698,0.465470,0.467598,0.491054,0.484793,0.478743,0.474344,0.475569
9997,sitting,-0.721090,-0.725018,-0.703282,-0.707152,-0.708919,-0.714992,-0.701263,-0.700657,-0.722918,...,0.982453,0.981935,0.981704,0.982096,0.982008,0.981100,0.981698,0.980657,0.980936,0.981345
9998,sitting,-0.972290,-0.973930,-0.969861,-0.971826,-0.970991,-0.968090,-0.970479,-0.968067,-0.973491,...,0.980609,0.981121,0.982023,0.979871,0.980726,0.981528,0.981197,0.977977,0.979114,0.979803


In [6]:
gen_data.to_csv(file_for_export, index=True, header=True)