In [None]:
#!/usr/bin/env python
# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np
import random
from tqdm import tqdm
import copy
import pickle
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, BatchNormalization, Activation
from tensorflow.keras.models import Model

from sklearn.model_selection import train_test_split

# Seed every RNG this notebook draws from: Python's `random`, NumPy (used by
# the batch generator's np.random.shuffle) and TensorFlow (weight
# initialisation, dropout). Seeding only `random` leaves runs irreproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Expose only GPU 0 to TensorFlow, with devices enumerated in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# Side length (in pixels) of the square images fed to the network.
scale = 224

# SECURITY NOTE: pickle.load can execute arbitrary code while unpickling.
# Only load these files if they come from a trusted source.
# `data` is indexed by sample ID; each entry is read later via its 'HSQC' and
# 'Target' fields.
with open('Dataset_0908.pkl','rb') as r:
    data = pickle.load(r)

with open('HMDB_BMRB_v5.pkl','rb') as r:
    raw = pickle.load(r)

# Map every item yielded by iterating `raw` to its position; len(encode) is
# used downstream as the number of output classes.
encode = {c: n for n, c in enumerate(raw)}
keys = list(data.keys())

In [None]:
# Integer class label of every sample, aligned index-for-index with `keys`
# (needed for the stratified split). tqdm wraps the list itself rather than
# the enumerate object so the progress bar knows the total and can show an ETA.
labels = np.zeros(len(keys), dtype=int)
for n, key in enumerate(tqdm(keys, desc='labels')):
    labels[n] = data[key]['Target']

In [None]:
# 8:1:1 stratified split into train / validation / test.
# Step 1: hold out 1/10 of all samples for validation.
trainval_list, val_list, trainval_labels, _ = train_test_split(
    keys, labels, test_size=1/10, random_state=42, stratify=labels)

# Step 2: hold out 1/9 of the REMAINING samples (= 1/10 of the total) for
# testing. The split must operate on the remainder, not on the full key list,
# otherwise test samples overlap with validation/training samples and the
# reported metrics are contaminated by leakage.
# Stratifying here needs every class to still have at least two samples after
# step 1; drop `stratify` on this call if the dataset has very rare classes.
train_list, test_list, _, _ = train_test_split(
    trainval_list, trainval_labels, test_size=1/9, random_state=42,
    stratify=trainval_labels)

# Guard against leakage between the three partitions.
assert not set(train_list) & set(val_list), "train/val overlap"
assert not set(train_list) & set(test_list), "train/test overlap"
assert not set(val_list) & set(test_list), "val/test overlap"
assert len(train_list) + len(val_list) + len(test_list) == len(keys)

In [None]:
def _conv_stack(x, filters, depth=3):
    """Apply `depth` rounds of 3x3 same-padded Conv2D -> BatchNormalization -> ReLU."""
    for _ in range(depth):
        x = Conv2D(filters, (3, 3), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x


def _up_stack(x, skip, filters):
    """Upsample `x` 2x with a transposed conv, concatenate `skip` on the channel axis, then run a conv stack."""
    upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
    merged = concatenate([upsampled, skip], axis=3)
    return _conv_stack(merged, filters)


def get_unet(img_size):
    """Build a two-headed U-Net.

    The encoder has four levels (64, 128, 256, 512 filters), each a stack of
    three Conv-BN-ReLU layers followed by 2x2 max pooling, and a 1024-filter
    bottleneck. The decoder mirrors the encoder with transposed convolutions
    and skip connections.

    Args:
        img_size: tuple ``(height, width)`` of the input image. It is
            concatenated with ``(2,)`` to form the input shape, so the network
            expects 2 input channels. Because of the four 2x2 poolings, both
            sizes must be divisible by 16 for the skip connections to line up.

    Returns:
        An uncompiled ``Model`` with two outputs, in this order:
          * ``'hsqc'``: per-pixel softmax over 2 channels, same spatial size
            as the input.
          * ``'mb'``: softmax over ``len(encode)`` classes, computed from the
            globally max-pooled bottleneck features after dropout (rate 0.2).
    """
    inputs = Input(shape=img_size + (2,))

    # Encoder: remember each level's pre-pooling activation for the skips.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = _conv_stack(x, filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck; also feeds the classification head below.
    bottleneck = _conv_stack(x, 1024)

    # Decoder: walk back up, pairing each level with its encoder skip.
    x = bottleneck
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = _up_stack(x, skip, filters)

    # Segmentation head.
    hsqc_out = Conv2D(2, (1, 1), activation='softmax',name='hsqc')(x)

    # Classification head on the bottleneck features.
    pooled = layers.GlobalMaxPooling2D()(bottleneck)
    dropped = layers.Dropout(0.2)(pooled)
    mb_out = layers.Dense(len(encode), activation = 'softmax',name='mb')(dropped)

    return Model(inputs=inputs, outputs=[hsqc_out, mb_out])

img_size = (scale,scale)

# Free up RAM in case the model definition cells were run multiple times
keras.backend.clear_session()

# Build model
model = get_unet(img_size)

# Both heads are trained with sparse categorical cross-entropy (integer
# targets). `learning_rate` is the supported argument name; the old `lr`
# alias is deprecated and rejected by current Keras releases.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.00001),
    loss={'mb': 'sparse_categorical_crossentropy',
          'hsqc': 'sparse_categorical_crossentropy'},
    metrics={'mb': [tf.keras.metrics.SparseTopKCategoricalAccuracy(k=2, name="top2", dtype=None),
                    tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1, name="top1", dtype=None)],
             'hsqc': 'accuracy'})

In [None]:
# Datagenerator setup
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that turns sparse HSQC entries into dense batches.

    Samples are looked up by ID in the module-level ``data`` dict. For each
    sample, ``data[id]['HSQC']`` supplies three parallel index sequences
    (row, column, channel); a ``(*dim, 3)`` grid is set to 1 at those
    positions. Channels 0-1 of the grid become the network input and channel 2
    becomes the per-pixel target. ``data[id]['Target']`` is the class label.

    Args:
        list_IDs: sequence of sample IDs (keys of ``data``) served by this
            generator.
        batch_size: number of samples per batch; the last batch of an epoch
            may be smaller.
        dim: ``(height, width)`` of the generated grids.
        n_channels: stored for reference; the channel layout produced here is
            fixed (2 input channels + 1 target channel).
        n_classes: stored for reference; not used when building batches.
        shuffle: whether to reshuffle the sample order at every epoch end.
    """
    def __init__(self, list_IDs, batch_size=32, dim=(scale,scale), n_channels=2,
                 n_classes=len(encode), shuffle=True):
        'Initialization'
        # Newer Keras versions expect the base-class initialiser to run; on
        # older versions this is a harmless no-op.
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        # Round up so the trailing partial batch is served too; flooring would
        # silently drop up to batch_size - 1 samples per epoch/evaluation.
        return int(np.ceil(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Positions (into list_IDs) that belong to this batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Translate positions into sample IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        X, y1, y2 = self.__data_generation(list_IDs_temp)

        # Targets are ordered like the model outputs: [pixel mask, class label]
        return X, [y1, y2]

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        'Generates data for the given sample IDs' # X : (n_samples, *dim, 2)
        # Size the arrays by the actual number of IDs so a short final batch
        # carries no zero-padded phantom samples.
        n_samples = len(list_IDs_temp)
        X = np.zeros((n_samples, *self.dim, 2), int)
        y1 = np.zeros((n_samples, *self.dim, 1), int)
        y2 = np.zeros((n_samples, 1), dtype=int)

        for n, sample_id in enumerate(list_IDs_temp):
            sample = data[sample_id]
            mat = sample['HSQC']

            # Scatter ones into a (*dim, 3) grid at the listed
            # (row, column, channel) positions in a single indexed assignment.
            rows, cols, chans = (np.asarray(mat[axis], dtype=np.intp)
                                 for axis in range(3))
            qc = np.zeros((*self.dim, 3), int)
            qc[rows, cols, chans] = 1

            X[n] = qc[:, :, :2]      # input channels
            y1[n] = qc[:, :, 2:3]    # per-pixel target channel
            y2[n] = sample['Target']

        return X, y1, y2

In [None]:
# Build the training, validation and test batch generators.
params = {'dim': (scale,scale),
          'batch_size': 16,
          'n_classes': len(encode),
          'n_channels': 2,
          'shuffle': True}

# Evaluation generators use the same settings but a fixed sample order, so
# validation/test metrics are computed over the same samples on every pass.
eval_params = dict(params, shuffle=False)

training_generator = DataGenerator(train_list, **params)
validation_generator = DataGenerator(val_list, **eval_params)
test_generator = DataGenerator(test_list, **eval_params)

In [None]:
# Train for 10 epochs, validating after each one. The returned History is kept
# so the loss/metric curves can be inspected afterwards.
# `use_multiprocessing` is intentionally not passed: False is already the
# default, and newer Keras versions no longer accept it in `fit()`.
history = model.fit(training_generator, validation_data=validation_generator,
                    epochs=10, verbose=1)


In [None]:
# Score the trained model on the test generator; the cell displays the
# resulting loss and metric values.
test_scores = model.evaluate(test_generator)
test_scores