In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from matplotlib.collections import LineCollection

from sklearn.metrics import accuracy_score

%matplotlib inline

import torch
import torch.nn as nn
import torch.autograd as autograd
import torchvision
import torchvision.models as models
import torchvision.transforms as T
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.modules.activation import LeakyReLU

from torchsummary import summary

from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

import glob
import random
import tensorflow as tf
from shutil import rmtree
from keras.models import load_model
from os.path import join, getctime, basename
from models import *
from utils import *
from data_preprocess import *

%load_ext autoreload
%autoreload 2

In [2]:
# Load the preprocessed / augmented splits. The meaning of each keyword is
# defined by `load_prep_data` (project module, not visible here).
aug_data = load_prep_data(time=500, debug=False, pooling=True, subsample=2, average=2, normalization=False, noise_level=0.5)

x_train = aug_data['x_train']
y_train = aug_data['y_train']
x_valid = aug_data['x_valid']
y_valid = aug_data['y_valid']
x_test = aug_data['x_test']
y_test = aug_data['y_test']
person_train_valid = aug_data['person_train_valid']
person_test = aug_data['person_test']
X_train_valid = aug_data['X_train_valid']
y_train_valid = aug_data['y_train_valid']

# Keep a separate handle on the loaded test split; later cells derive `x_test` from it.
X_test = x_test
input_shape = x_train.shape
# Fill the "{}" placeholder via str.format (passing format(...) as a second
# print argument would print the literal "{}" followed by the shape).
print("x_train.shape: {}".format(x_train.shape))

x_train.shape: {} (7191, 22, 250)


In [3]:
# Swap axes 1 and 2 of every split, e.g. (7191, 22, 250) -> (7191, 250, 22) for
# the training set (presumably (trials, channels, time) -> (trials, time, channels)).
# Always start from the arrays loaded into `aug_data` / `X_test` rather than the
# current x_* variables: re-running this cell must not swap an already-swapped
# array back, which would silently feed the wrong axis to the standardization below.
x_train = np.swapaxes(aug_data['x_train'], 1, 2)
x_valid = np.swapaxes(aug_data['x_valid'], 1, 2)
x_test = np.swapaxes(X_test, 1, 2)
print('Shape of training set after dimension reshaping:', x_train.shape)
print('Shape of validation set after dimension reshaping:', x_valid.shape)
print('Shape of test set after dimension reshaping:', x_test.shape)


# Z-score every (sample, channel) series along axis 1.
def standardize(x):
    """Return ``x`` normalised to mean 0 and (population) std 1 along axis 1.

    Statistics are computed separately for every sample and every position of
    the trailing axes, e.g. per trial and per channel for an array laid out as
    (trials, time, channels). Nothing is shared across samples, so each split
    can be standardized on its own.

    A series with zero variance is divided by zero, so numpy produces NaN/inf
    (with a RuntimeWarning) instead of raising.
    """
    centred = x - np.mean(x, axis=1, keepdims=True)
    scale = np.sqrt(np.var(x, axis=1, keepdims=True))
    return centred / scale

# standardize() only uses each sample's own statistics, so the three splits are
# transformed independently (no train-set statistics leak into valid/test).
x_train, x_valid, x_test = map(standardize, (x_train, x_valid, x_test))

class LoadData(Dataset):
    """Map-style dataset pairing each sample of ``data`` with its label.

    ``data`` is converted once into a tensor of torch's default float dtype;
    ``labels`` is stored untouched and indexed directly. Every returned sample
    has its two axes swapped, so an array laid out as (N, time, channels)
    yields items of shape (channels, time).
    """

    def __init__(self, data, labels):
        self.data = torch.Tensor(data)
        self.labels = labels

    def __len__(self) -> int:
        """Number of samples (size of the first axis of ``data``)."""
        return len(self.data)

    def __getitem__(self, index: int):
        """Return ``(sample_with_axes_swapped, label)`` for ``index``."""
        sample = self.data[index]
        swapped = sample.permute(1, 0)
        return swapped, self.labels[index]
    
    
# train_dataset = LoadData(data=x_train, labels=y_train,)
# val_dataset = LoadData(data=x_valid, labels=y_valid,)
# train_dl = DataLoader(train_dataset, batch_size=16, num_workers=1, pin_memory=True, shuffle=True)
# val_dl = DataLoader(val_dataset, batch_size=16, num_workers=1, pin_memory=True)

# data = next(iter(train_dl))
# input_data, input_labels = data
# print(input_data.size())

Shape of training set after dimension reshaping: (7191, 250, 22)
Shape of validation set after dimension reshaping: (1269, 250, 22)
Shape of test set after dimension reshaping: (1772, 250, 22)


In [4]:
class WGANGen(nn.Module):
    """Class-conditional 1-D generator producing (batch, 22, 250) signals.

    A learned label embedding is concatenated to the noise vector, projected by
    ``fc1`` to a (352, 14) feature map, and upsampled by four ConvTranspose1d
    stages (length 14 -> 30 -> 62 -> 124 -> 250, channels 352 -> 176 -> 88 ->
    44 -> 22). A final tanh bounds the output to [-1, 1].

    Submodule attribute names are part of the saved checkpoint format
    (``load_state_dict`` matches parameters by name), so do not rename them.

    Args:
        noise_dim: size of the latent vector passed as ``x`` to ``forward``.
        n_classes: number of label classes; the label embedding also has this
            many dimensions. The default of 4 reproduces the original model.
    """

    def __init__(self, noise_dim=100, n_classes=4):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, n_classes)

        self.fc1 = nn.Linear(in_features=noise_dim + n_classes, out_features=352 * 14)
        self.bnorm1 = nn.BatchNorm1d(352 * 14)
        self.relu1 = nn.LeakyReLU(0.3)
        self.dropout1 = nn.Dropout(p=0.5)

        # (L - 1) * 2 - 2 * padding + 5 + 1: 14 -> 30
        self.deconv1 = nn.ConvTranspose1d(352, 176, 5, stride=2, padding=1, output_padding=1, bias=False)
        self.bnorm2 = nn.BatchNorm1d(176)
        self.relu2 = nn.LeakyReLU(0.3)
        self.dropout2 = nn.Dropout(p=0.5)

        # 30 -> 62
        self.deconv2 = nn.ConvTranspose1d(176, 88, 5, stride=2, padding=1, output_padding=1, bias=False)
        self.bnorm3 = nn.BatchNorm1d(88)
        self.relu3 = nn.LeakyReLU(0.3)
        self.dropout3 = nn.Dropout(p=0.5)

        # padding=2 here: 62 -> 124
        self.deconv3 = nn.ConvTranspose1d(88, 44, 5, stride=2, padding=2, output_padding=1, bias=False)
        self.bnorm4 = nn.BatchNorm1d(44)
        self.relu4 = nn.LeakyReLU(0.3)
        self.dropout4 = nn.Dropout(p=0.5)

        # 124 -> 250, down to 22 output channels
        self.deconv4 = nn.ConvTranspose1d(44, 22, 5, stride=2, padding=1, output_padding=1, bias=False)
        self.tanh = nn.Tanh()

    def forward(self, x, labels):
        """Generate a batch of signals conditioned on class labels.

        Args:
            x: noise tensor of shape (batch, noise_dim).
            labels: class ids of shape (batch,). They are cast to ``long``
                before the embedding lookup (as ``WGANDis.forward`` does), so
                float label tensors are accepted as well.

        Returns:
            Tensor of shape (batch, 22, 250) with values in [-1, 1].

        In training mode the BatchNorm layers need batch > 1.
        """
        batch_size = x.size(0)  # taken from the input so any batch size works
        labels = labels.long()  # nn.Embedding requires integer indices

        h = torch.cat((self.label_emb(labels), x), -1)

        h = self.fc1(h)
        h = self.bnorm1(h)
        h = self.relu1(h)
        h = h.view(batch_size, 352, 14)
        h = self.dropout1(h)

        h = self.deconv1(h)
        h = self.bnorm2(h)
        h = self.relu2(h)
        h = self.dropout2(h)

        h = self.deconv2(h)
        h = self.bnorm3(h)
        h = self.relu3(h)
        h = self.dropout3(h)

        h = self.deconv3(h)
        h = self.bnorm4(h)
        h = self.relu4(h)
        h = self.dropout4(h)

        h = self.deconv4(h)
        h = self.tanh(h)

        return h

class WGANDis(nn.Module):
    """Class-conditional 1-D critic scoring (batch, n_channels, seq_len) signals.

    The label is embedded, projected by ``scaler`` to a full
    (n_channels, seq_len) map and concatenated to the input along the channel
    axis. The result then passes through four kernel-2 / stride-2 Conv1d stages
    and a linear head that emits one unbounded score per sample.

    With the defaults the conv stack maps length 250 -> 125 -> 62 -> 31 -> 15,
    so the head sees 512 * 15 = 7680 features, the same as the original
    hard-coded value. Submodule attribute names match the original, so saved
    state dicts still load.

    NOTE(review): if this critic is trained with a gradient penalty (WGAN-GP),
    that paper recommends against BatchNorm in the critic because the penalty
    is defined per sample. The BatchNorm layers are kept here to stay
    checkpoint-compatible.

    Args:
        n_classes: number of label classes; also the embedding width.
        n_channels: number of input channels (default 22).
        seq_len: number of time steps per input (default 250).
    """

    def __init__(self, n_classes=4, n_channels=22, seq_len=250):
        super().__init__()
        self.n_channels = n_channels
        self.seq_len = seq_len
        self.scaler = nn.Linear(n_classes, seq_len * n_channels)
        self.label_emb = nn.Embedding(n_classes, n_classes)

        # Input channels are doubled by concatenating the label map.
        self.conv1 = nn.Conv1d(2 * n_channels, 64, 2, stride=2, padding=0)
        self.bnorm1 = nn.BatchNorm1d(64)
        self.relu1 = nn.LeakyReLU(0.3)

        self.conv2 = nn.Conv1d(64, 128, 2, stride=2, padding=0)
        self.bnorm2 = nn.BatchNorm1d(128)
        self.relu2 = nn.LeakyReLU(0.3)

        self.conv3 = nn.Conv1d(128, 256, 2, stride=2, padding=0)
        self.bnorm3 = nn.BatchNorm1d(256)
        self.relu3 = nn.LeakyReLU(0.3)

        self.conv4 = nn.Conv1d(256, 512, 2, stride=2, padding=0)
        self.bnorm4 = nn.BatchNorm1d(512)
        self.relu4 = nn.LeakyReLU(0.3)

        self.dropout1 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(512 * self._conv_out_len(seq_len), 1)

    @staticmethod
    def _conv_out_len(length):
        """Sequence length after the four kernel-2, stride-2, unpadded convs."""
        for _ in range(4):
            length = (length - 2) // 2 + 1
        return length

    def forward(self, x, labels):
        """Score a batch of signals given their class labels.

        Args:
            x: tensor of shape (batch, n_channels, seq_len).
            labels: class ids of shape (batch,); cast to ``long`` for the
                embedding lookup, so float label tensors are accepted.

        Returns:
            Tensor of shape (batch, 1) with one critic score per sample.
        """
        batch_size = x.size(0)  # taken from the input so any batch size works
        labels = labels.long()
        li = self.label_emb(labels)
        li = self.scaler(li)
        li = li.view(batch_size, self.n_channels, self.seq_len)

        h = torch.cat((x, li), 1)

        h = self.conv1(h)
        h = self.bnorm1(h)
        h = self.relu1(h)

        h = self.conv2(h)
        h = self.bnorm2(h)
        h = self.relu2(h)

        h = self.conv3(h)
        h = self.bnorm3(h)
        h = self.relu3(h)

        h = self.conv4(h)
        h = self.bnorm4(h)
        h = self.relu4(h)

        h = torch.flatten(h, 1)

        h = self.dropout1(h)
        h = self.fc1(h)

        return h

In [5]:
# cuda = True if torch.cuda.is_available() else False
# Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
# LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
# lambda_gp = 10
# n_critic = 5

In [None]:
# Rebuild the generator and restore its trained weights from a checkpoint.
model_test = WGANGen()
PATH = 'data/gan_checkpoint_600.pth'
# map_location='cpu' lets a checkpoint whose tensors were saved on a CUDA
# device load on a CPU-only machine; load_state_dict then copies the weights
# into the (CPU) model. torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load(PATH, map_location='cpu')
model_test.load_state_dict(checkpoint['generator_state_dict'])
model_test.eval()  # inference mode: dropout off, BatchNorm uses running statistics