In [1]:
# -*- coding: utf-8 -*- 

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow import keras as K
from tensorflow.keras import Sequential, layers, losses, optimizers, datasets
from tensorflow.keras.layers import Dense, BatchNormalization
from tensorflow.keras.utils import Sequence

from pathlib import Path
from sklearn.model_selection import train_test_split

In [2]:
# data generator class
class DataGenerator(Sequence):
    def __init__(self, ids, niis_dir, target_path, batch_size=16, shuffle=True):
        self.id_names = ids
        self.indexes = np.arange(len(self.id_names))
        self.niis_dir = Path(niis_dir)
        self.target = np.load(target_path, allow_pickle=True)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    # for printing the statistics of the function
    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.id_names))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __data_generation__(self, id_name): # 1 부터 시작
        'Generates data containing batch_size samples'
        # Initialization
        subject_id, r0 = divmod(id_name, 720)
        run, t = divmod(r0, 240)
        nii_path = self.niis_dir / f'sub-{subject_id+1:02d}/run-{run+1}/{t}.pkl' # 이미지 1개 경로
        nii = np.load(nii_path, allow_pickle=True).flatten()
        target = self.target[subject_id, run, t]
        
        return nii, target

    def __len__(self):
        "Denotes the number of batches per epoch"
        # self.id_names: 존재하는 전체 데이터 개수
        # self.batch_size: 배치사이즈를 의미합니다.
        return int(np.floor(len(self.id_names) / self.batch_size))

    def __getitem__(self, index):  # index : batch no.
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        batch_ids = [self.id_names[k] for k in indexes]

        niis = []
        targets = []
        
        for id_name in batch_ids:
            nii, target = self.__data_generation__(id_name)
            niis.append(nii)
            targets.append(target)

        niis = np.array(niis)
        targets = np.array(targets)

        return niis, targets  # return batch

In [3]:
niis_path = '/data2/project_modelbasedMVPA/prepprep'
targets_path = './modulation.pkl'

ids = range(16 * 3 * 240 - 1)
train_ids, valid_ids = train_test_split(ids, test_size=0.2, random_state=42)

train_generator = DataGenerator(train_ids, niis_path, targets_path)
valid_generator = DataGenerator(valid_ids, niis_path, targets_path)

In [4]:
batch_size = 256
epochs = 10

print("total training batches: ", len(train_generator))
print("total validaton batches: ", len(valid_generator))
train_steps = len(train_ids) // batch_size
valid_steps = len(valid_ids) // batch_size

optimizer = optimizers.Adam(lr=0.003, decay=1e-5)
loss = losses.MeanSquaredError()

mirrored_strategy = tf.distribute.MirroredStrategy() # for multi-gpus

with mirrored_strategy.scope():
    model = Sequential()
    model.add(Dense(1024, activation='relu', input_shape=(902629,)))
    model.add(BatchNormalization())
    model.add(Dense(512, activation='relu'))
    model.add(BatchNormalization())
    model.add(Dense(512, activation='relu'))
    model.add(BatchNormalization())
    model.add(Dense(512, activation='relu'))
    model.add(BatchNormalization())
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='mse', optimizer='adam')

# fit model
model.fit_generator(generator=train_generator, validation_data=valid_generator,
                    steps_per_epoch=train_steps, validation_steps=valid_steps,
                    epochs=epochs)

total training batches:  575
total validaton batches:  144
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Instructions for updating:
Please use Model.fit, which supports generators.
Epoch 1/10
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f77900bf390>

In [5]:
model.save('mlp_model')

Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: mlp_model/assets


In [7]:
weights_list = model.get_weights()

26


In [8]:
print(len(weights_list))
print(len(weights_list[0]))

26
902629


In [13]:
import pickle
with open('weight_first_layer.pkl', 'wb') as f:
    pickle.dump(weights_list[0], f, pickle.HIGHEST_PROTOCOL)

In [None]:
target = weights_list[0]