In [7]:
import os
import pandas as pd
import numpy as np
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from scipy import signal

from adversarial_autoencoder import Decoder, Encoder, Dataset

In [8]:
data_folder = "../../data/william/"
bkp_folder = f"../../bkp/william/generative_models"
fig_folder = "../../fig/william/reconstructed"

data_file = f'{data_folder}/preprocessed_data.csv'
file_for_export = f'{data_folder}/generated_data.csv'

In [9]:
df = pd.read_csv(data_file, index_col=0)
df

Unnamed: 0,label,0,1,2,3,4,5,6,7,8,...,30,31,32,33,34,35,36,37,38,39
0,standing,0.335113,0.347596,0.336049,0.343032,0.328762,0.348014,0.346694,0.335331,0.364545,...,-0.433690,-0.434881,-0.431574,-0.428831,-0.428773,-0.421384,-0.416518,-0.423027,-0.428537,-0.429517
1,standing,0.000895,0.000467,-0.003566,-0.011361,-0.005202,0.001991,0.005929,0.006656,0.013977,...,-0.631254,-0.651125,-0.444244,-0.494571,-0.538317,-0.572548,-0.607047,-0.612838,-0.616735,-0.618311
2,standing,-0.096764,-0.089463,-0.098124,-0.099384,-0.097621,-0.103314,-0.101170,-0.109625,-0.105213,...,-0.404683,-0.323787,-0.348436,-0.349691,-0.376399,-0.381099,-0.373210,-0.379673,-0.375099,-0.372539
3,standing,-0.263370,-0.263706,-0.261728,-0.261677,-0.260057,-0.260121,-0.259701,-0.249794,-0.254735,...,-0.134556,-0.030455,-0.165365,-0.275227,-0.200201,-0.212030,-0.184061,-0.165587,-0.180345,-0.183996
4,standing,0.081559,0.095439,0.095855,0.100744,0.106001,0.106589,0.106274,0.097211,0.100753,...,-0.615612,-0.453607,-0.536615,-0.560950,-0.631372,-0.667639,-0.671894,-0.700172,-0.704910,-0.703084
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
345,sitting,-0.521345,-0.509091,-0.511413,-0.513730,-0.511035,-0.511041,-0.504692,-0.509402,-0.505744,...,0.091892,0.162558,0.135246,0.112067,0.137026,0.133732,0.147023,0.141268,0.130998,0.129877
346,sitting,-0.271551,-0.284299,-0.279301,-0.273808,-0.278717,-0.267460,-0.267893,-0.267599,-0.267934,...,-0.135700,-0.217707,-0.177913,-0.177706,-0.157223,-0.152605,-0.149855,-0.139461,-0.140686,-0.142098
347,sitting,-0.198927,-0.214663,-0.213983,-0.217474,-0.224734,-0.222663,-0.229432,-0.231022,-0.226253,...,-0.234768,-0.180221,-0.189550,-0.203341,-0.245324,-0.248175,-0.248354,-0.248519,-0.240840,-0.238807
348,sitting,-0.448367,-0.447859,-0.450201,-0.448652,-0.449198,-0.445936,-0.444985,-0.449783,-0.445769,...,0.210265,0.146054,0.188480,0.218161,0.158905,0.224519,0.204968,0.173396,0.200847,0.210803


In [4]:
df_list = []
decimate_factor = 50 
datasets = Dataset.load_datasets(data_file, decimate_factor)
for cond, dataset in datasets.items():
    batch_size = len(dataset)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    folder = f"{bkp_folder}/{cond}"
    encoder = Encoder.load(folder=folder)
    decoder = Decoder.load(folder=folder)
    
    for _, (x, _) in enumerate(dataloader):
    
        encoded = encoder(x)
        x_prime = decoder(encoded).detach().numpy()
        _df = pd.DataFrame(x_prime)
        _df.insert(0, 'label', cond)
        df_list.append(_df)
    
df_rec = pd.concat(df_list, axis=0, ignore_index=True, sort=False)
df_rec    

Unnamed: 0,label,0,1,2,3,4,5,6,7,8,...,30,31,32,33,34,35,36,37,38,39
0,sitting,-0.561582,-0.548078,-0.545102,-0.522403,-0.514699,-0.499084,-0.521699,-0.480395,-0.533008,...,0.150560,0.085568,-0.028287,-0.106042,-0.036338,-0.034473,0.020582,0.047462,0.006331,0.026881
1,sitting,-0.999938,-0.999955,-0.999956,-0.999948,-0.999941,-0.999935,-0.999924,-0.999906,-0.999912,...,0.991876,0.995181,0.996613,0.996220,0.996445,0.996107,0.996279,0.995726,0.996170,0.996140
2,sitting,-0.852998,-0.847709,-0.835885,-0.826442,-0.825064,-0.828233,-0.817710,-0.812382,-0.777748,...,0.993987,0.997815,0.998617,0.998459,0.998290,0.998673,0.999046,0.998903,0.998591,0.998358
3,sitting,-0.998900,-0.999191,-0.999313,-0.999271,-0.999318,-0.998962,-0.999124,-0.999005,-0.999136,...,0.999734,0.999876,0.999955,0.999981,0.999984,0.999988,0.999982,0.999981,0.999979,0.999980
4,sitting,-0.708514,-0.722383,-0.722868,-0.707767,-0.708915,-0.693779,-0.703643,-0.679233,-0.682641,...,0.695661,0.649244,0.555211,0.484111,0.507474,0.497638,0.531238,0.541869,0.559448,0.553519
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
345,standing,0.999894,0.999899,0.999916,0.999918,0.999903,0.999915,0.999915,0.999892,0.999869,...,-0.999403,-0.999376,-0.998172,-0.998412,-0.998351,-0.998181,-0.998029,-0.998128,-0.998052,-0.997959
346,standing,0.999894,0.999901,0.999909,0.999903,0.999894,0.999904,0.999919,0.999897,0.999858,...,-0.999633,-0.999646,-0.998688,-0.998953,-0.998654,-0.998877,-0.998716,-0.999027,-0.998802,-0.998818
347,standing,0.987350,0.988147,0.987847,0.987627,0.989501,0.988891,0.991419,0.989810,0.993246,...,-0.999207,-0.999540,-0.999985,-0.999994,-0.999992,-0.999992,-0.999991,-0.999995,-0.999996,-0.999995
348,standing,0.927524,0.946348,0.942657,0.945399,0.950647,0.947871,0.952451,0.948971,0.961397,...,-0.983811,-0.995301,-0.999682,-0.999779,-0.999712,-0.999695,-0.999564,-0.999697,-0.999617,-0.999623


In [6]:
x = df.iloc[:, 1:].values
x_dec = signal.decimate(x, decimate_factor, axis=1)

    
x_rec = df_rec.iloc[:, 1:].values

y = df.label.values


for i in tqdm(range(x.shape[0])):
    
    fig, axes = plt.subplots(ncols=3)
    ax = axes[0]
    ax.plot(x[i, :])
    ax = axes[1]
    ax.plot(x_dec[i, :])
    ax = axes[2]
    ax.plot(x_rec[i, :])
    label = y[i]

    plt.tight_layout()

    os.makedirs(fig_folder, exist_ok=True)
    plt.savefig(f"{fig_folder}/{label}_{i}.png")
    plt.close()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 350/350 [00:32<00:00, 10.68it/s]


In [29]:
x.shape

(178, 125)

In [31]:
x = df[idx_data].values
x.shape

(178, 500)

In [32]:
df

Unnamed: 0,label,0,1,2,3,4,5,6,7,8,...,490,491,492,493,494,495,496,497,498,499
0,standing,1.672478,1.699251,1.653529,1.675447,1.701357,1.666121,1.673655,1.666940,1.666660,...,-0.943382,-0.901028,-0.921282,-1.002052,-0.906315,-0.909447,-0.941414,-0.907625,-0.898057,-0.867120
1,standing,1.353121,1.336306,1.325868,1.330737,1.333798,1.348461,1.343726,1.333880,1.337275,...,-1.545159,-1.554109,-1.548504,-1.547196,-1.554705,-1.550402,-1.530438,-1.542497,-1.552186,-1.536785
2,standing,-0.015325,-0.044372,-0.045755,-0.042678,-0.057733,-0.041835,-0.036314,-0.045428,-0.032134,...,-0.365798,-0.359208,-0.359462,-0.351151,-0.380013,-0.384817,-0.391396,-0.374895,-0.370487,-0.394166
3,standing,0.837277,0.805751,0.803815,0.797416,0.802943,0.808979,0.797999,0.796282,0.823370,...,-0.492374,-0.478440,-0.472203,-0.472054,-0.480722,-0.450856,-0.482687,-0.468123,-0.456670,-0.488786
4,standing,1.108809,1.168114,1.169325,1.180292,1.182882,1.169371,1.182611,1.188081,1.171493,...,-1.139195,-1.117813,-1.158414,-1.140875,-1.146082,-1.150172,-1.149547,-1.153580,-1.147293,-1.160851
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
173,standing,0.043477,0.055490,0.055713,0.048546,0.062730,0.059257,0.048316,0.072245,0.057259,...,-0.508125,-0.492149,-0.503229,-0.501296,-0.473725,-0.479683,-0.464368,-0.443531,-0.491752,-0.490495
174,standing,0.316889,0.326908,0.341556,0.325362,0.324445,0.332067,0.336173,0.339676,0.346671,...,-0.638636,-0.609874,-0.606061,-0.625323,-0.624185,-0.652241,-0.640237,-0.610047,-0.627170,-0.631413
175,standing,1.533256,1.550071,1.553509,1.536497,1.542901,1.556756,1.525776,1.517910,1.535335,...,-1.225621,-1.224379,-1.251184,-1.234858,-1.223370,-1.240235,-1.233826,-1.239359,-1.260614,-1.271496
176,standing,-0.280300,-0.264563,-0.301838,-0.270662,-0.256652,-0.297598,-0.302181,-0.275497,-0.266088,...,-0.325860,-0.317736,-0.356090,-0.341584,-0.325998,-0.341557,-0.334343,-0.342708,-0.354456,-0.339314
