In [None]:
import torch.nn as nn
import torch
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import pickle
import os

import model as m
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

# Directory (relative to the notebook's working directory) that is scanned for
# '*_dwt_*func_data.pkl' inputs and that also receives the '*_emb.pkl' outputs.
data_path = 'AMIGOS/preprocessed/'

In [None]:
def _resolve_optimizer(optimizer, model, lr):
    """Turn an optimizer name (or a ready-made optimizer) into a torch optimizer for `model`."""
    if isinstance(optimizer, torch.optim.Optimizer):
        # Caller already built the optimizer; `lr` is not applied in that case.
        return optimizer
    if optimizer is None:
        optimizer = 'Adam'

    optimizer_cls = getattr(torch.optim, optimizer, None) if isinstance(optimizer, str) else None
    if not (isinstance(optimizer_cls, type) and issubclass(optimizer_cls, torch.optim.Optimizer)):
        raise ValueError("Unknown optimizer: " + repr(optimizer))
    return optimizer_cls(model.parameters(), lr=lr)


def _ensure_parent_dir(path):
    """Create the directory that will contain `path` if it does not exist yet."""
    if isinstance(path, (str, os.PathLike)):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)


def train_autoencoder(train_loader, test_loader, model, epochs, optimizer='Adam', lr=1e-4,
                      autoencoder_weight_path='final', name_tag=None):
    """Train `model` with a BCE reconstruction loss and track train/test loss per epoch.

    Parameters
    ----------
    train_loader, test_loader : iterables of ``(input, target)`` batches that
        also expose ``.dataset`` (its length is used to average the loss).
    model : torch.nn.Module
        Called as ``model(batch_x)``; its output is compared against ``batch_y``.
    epochs : int
        Number of passes over ``train_loader``.
    optimizer : str or torch.optim.Optimizer, default 'Adam'
        Name of a class in ``torch.optim`` (built with ``model.parameters()`` and
        ``lr``) or an already constructed optimizer. Previously this argument was
        overwritten and Adam was always used; the default is therefore unchanged.
    lr : float
        Learning rate used when the optimizer is built from a name.
    autoencoder_weight_path : str
        Where the state dict is written whenever the current epoch has the
        lowest training loss seen so far (the first epoch is never written).
    name_tag : str, optional
        Suffix for the every-100-epochs checkpoint
        ``'modelweight/AutoEncoder_<epoch><name_tag>'``. When omitted, a
        module-level ``name_tag`` variable is used if one exists (the historical
        behaviour of this notebook), otherwise an empty suffix.

    Returns
    -------
    (train_loss, test_loss) : two lists of floats, one entry per epoch.

    Raises
    ------
    ValueError
        If ``optimizer`` is a name that is not an optimizer class in ``torch.optim``.
    """
    if name_tag is None:
        # The original code read `name_tag` straight from the notebook namespace
        # and crashed with NameError when it was absent; keep the lookup but make
        # it explicit and safe.
        name_tag = globals().get('name_tag', '')

    train_loss = []
    test_loss = []

    # nn.BCELoss requires predictions and targets in [0, 1].
    # NOTE(review): assumes the model output and the data are scaled accordingly - confirm.
    loss_fct = nn.BCELoss()
    optimizer = _resolve_optimizer(optimizer, model, lr)

    # Fail early on an unwritable checkpoint location instead of after training.
    _ensure_parent_dir(autoencoder_weight_path)

    best_train_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for batch_x, batch_y in train_loader:
            # Clear gradients *before* backward so stale gradients already stored
            # on the model can never leak into the first update.
            optimizer.zero_grad()

            X_hat = model(batch_x)
            loss = loss_fct(X_hat, batch_y)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller last batch is averaged correctly.
            total_loss += loss.item() * batch_x.size(0)

        epoch_loss = total_loss / len(train_loader.dataset)
        train_loss.append(epoch_loss)
        # test_autoencoder switches the model to eval mode; model.train() above
        # restores training mode at the start of the next epoch.
        test_loss.append(test_autoencoder(test_loader, model, loss_fct).item())

        # Running minimum replaces the per-epoch min(train_loss) scan; ties still count as best.
        is_best = epoch_loss <= best_train_loss
        best_train_loss = min(best_train_loss, epoch_loss)

        # Selection is based on the *training* loss, and the first epoch is skipped
        # (both kept from the original behaviour).
        if epoch != 0 and is_best:
            torch.save(model.state_dict(), autoencoder_weight_path)

        if (epoch + 1) % 10 == 0:
            print("Epoch: ", epoch + 1, "| Loss: ", train_loss[-1])
            print("Test Loss: ", test_loss[-1])

        if (epoch + 1) % 100 == 0:
            periodic_path = 'modelweight/AutoEncoder_' + str(epoch + 1) + name_tag
            _ensure_parent_dir(periodic_path)
            torch.save(model.state_dict(), periodic_path)

    return train_loss, test_loss

def test_autoencoder(test_loader, model, loss_fct):
    total_loss=0
    
    model.eval()
    
    for x_batch, y_batch in test_loader:
        with torch.no_grad():
            X_hat = model(x_batch)
            total_loss += loss_fct(X_hat , y_batch) * x_batch.size(0)
            
    return total_loss / len(test_loader.dataset)

In [None]:
# Training hyperparameters shared by every per-file run in the loop below.
epochs = 2000       # number of passes handed to train_autoencoder
batch_size = 16     # mini-batch size used for every DataLoader

In [None]:
# Per-file pipeline: for every '*_dwt_*func_data.pkl' file in `data_path`,
# train a fresh autoencoder on its `dwt` column, append the encoder output as
# 'dwt_emb_<i>' columns, write the result to '<name>_emb.pkl' and delete the input.
for name in [x for x in os.listdir(data_path) if '_dwt_' in x and 'func_data.pkl' in x]:
    # NOTE(review): pickle.load can execute arbitrary code - only run this on
    # files you produced yourself.
    with open(data_path + name, 'rb') as f:
        df = pickle.load(f)

    # Last six '_'-separated tokens of the file stem identify the run.
    # `name_tag` must stay a notebook-level variable: train_autoencoder uses it
    # to name its periodic checkpoints.
    name_tag = '_'.join(name.split('.')[0].split('_')[-6:])
    autoencoder_weight_path = 'modelweight/AutoEncoder_lowestloss_' + name_tag
    print(name_tag)

    # Replace NaN/inf inside every stored array before stacking.
    df['dwt'] = df['dwt'].apply(np.nan_to_num)

    # One row of X per DataFrame row, in DataFrame order.
    X = np.vstack(df['dwt'].apply(lambda x: np.expand_dims(x, 0)).to_numpy())

    # 80/20 split with a fixed seed; the autoencoder target is the input itself.
    X_train, X_test = train_test_split(X, test_size=0.2, random_state=3)

    X_train = torch.as_tensor(X_train, dtype=torch.float)
    X_test = torch.as_tensor(X_test, dtype=torch.float)

    train_loader = DataLoader(TensorDataset(X_train, X_train),
                              batch_size=batch_size, shuffle=True)
    # Evaluation does not need shuffling.
    test_loader = DataLoader(TensorDataset(X_test, X_test),
                             batch_size=batch_size, shuffle=False)

    autoencoder = m.AutoEncoderEEGNet()
    train_loss, test_loss = train_autoencoder(train_loader, test_loader,
                                              autoencoder, epochs,
                                              autoencoder_weight_path=autoencoder_weight_path)

    # The embeddings below come from the model as it is after the last epoch.
    activation = {}

    def get_activation(key):
        def hook(model, input, output):
            # NOTE(review): `output[0]` assumes the encoder returns a tuple/list
            # whose first element is the batched embedding. If it returns a plain
            # tensor this picks only the first sample - confirm against
            # model.AutoEncoderEEGNet. The row-count check below catches that case.
            activation[key] = output[0].detach().reshape(output[0].shape[0], -1)

        return hook

    handle = autoencoder.encoder.register_forward_hook(get_activation('emb'))

    X_embs = []
    X = torch.as_tensor(X, dtype=torch.float)
    # shuffle MUST be False: the embeddings are appended to `df` row by row, so
    # they have to come out in the same order as the DataFrame rows. The
    # previous shuffle=True attached each embedding to a random row.
    xloader = DataLoader(TensorDataset(X), batch_size=batch_size, shuffle=False)

    autoencoder.eval()

    try:
        with torch.no_grad():
            for (x_batch,) in xloader:
                # Only the forward pass matters; the hook captures the encoder output.
                autoencoder(x_batch)
                X_embs.append(activation['emb'])
    finally:
        # Always detach the hook, even if the forward pass fails.
        handle.remove()

    X_embs = torch.cat(X_embs).numpy()

    # Refuse to write (and then delete the source file) if rows do not line up.
    if X_embs.shape[0] != len(df):
        raise ValueError(
            "Embedding rows (%d) do not match DataFrame rows (%d) for %s"
            % (X_embs.shape[0], len(df), name)
        )

    emb_columns = ['dwt_emb_' + str(i) for i in range(X_embs.shape[1])]
    # Reuse df's index so concat(axis=1) pairs rows positionally instead of
    # aligning a fresh 0..n-1 index against whatever index df carries.
    df = pd.concat([df, pd.DataFrame(X_embs, columns=emb_columns, index=df.index)], axis=1)

    with open(data_path + name.replace('.pkl', '_emb.pkl'), 'wb') as f:
        pickle.dump(df, f)

    # The input is removed once the '_emb' file is written; the new file name no
    # longer matches the 'func_data.pkl' filter, so a re-run skips finished files.
    os.remove(data_path + name)

    del autoencoder, activation

In [None]:
# Shape of the input tensor of the last file processed. `X` is left over from
# the loop above, so this raises NameError if the loop body never ran.
X.shape