In [1]:
%matplotlib inline
import numpy as np
import time
import h5py
import keras
import pandas as pd
import math
import joblib
import os, cv2
import matplotlib.pyplot as plt

# from fuel.datasets.hdf5 import H5PYDataset

from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedShuffleSplit
from IPython.display import display

from keras.layers import (Input, Dense, Lambda, Flatten, Reshape, BatchNormalization, Activation, 
                          Dropout, Conv2D, Conv2DTranspose, MaxPooling2D)
from keras.layers.merge import concatenate
from keras.regularizers import l2
from keras.initializers import RandomUniform
from keras.optimizers import RMSprop, Adam, SGD
from keras.models import Model
from keras import metrics
from keras.utils import np_utils
from keras import backend as K
# from keras_tqdm import TQDMNotebookCallback
from keras.datasets import mnist
import tensorflow as tf
from tensorflow import Graph, Session
os.environ["CUDA_VISIBLE_DEVICES"]="0"

Using TensorFlow backend.


# Variational Autoencoder Parameters

In [2]:
img_rows, img_cols, img_chns = 20, 20, 3
print(K.image_data_format())
if K.image_data_format() == 'channels_first':
    original_img_size = (img_chns, img_rows, img_cols)
else:
    original_img_size = (img_rows, img_cols, img_chns)

batch_size = 600
latent_dim = 128
intermediate_dim = 512
epsilon_std = 1.0
epochs = 1500
activation = 'relu'
dropout = 0.5
learning_rate = 1e-4
decay = 0.0
num_classes = 2

image_name = '322.tif'
mask_name = '322_car_1'
data_path = './data/xview/'
unlabeled_imgs_path = './temp'
pos_folder_path = os.path.join(data_path, mask_name+'_samples', 'positive')
neg_folder_path = os.path.join(data_path, mask_name+'_samples', 'negative')

channels_last


# Load map dataset

In [3]:
def read_img_from_dir(path):
    data= []
    for root, dirs, files in os.walk(path, topdown=False):
        for name in files:
            img_path = os.path.join(root, name)
            data.append(cv2.imread(img_path))
    data = np.array(data)
    return data

# (X_train, y_train), (X_test, y_test) = mnist.load_data()
# X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1) / 255.
# X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1) / 255.

# print(X_train.shape, y_train.shape)
# print(X_test.shape, y_test.shape)

In [4]:
X_others = read_img_from_dir(unlabeled_imgs_path)
pos_samples = read_img_from_dir(pos_folder_path)
neg_samples = read_img_from_dir(neg_folder_path)
print(X_others.shape, pos_samples.shape, neg_samples.shape)

(1406, 20, 20, 3) (5, 20, 20, 3) (5, 20, 20, 3)


In [5]:
X_others = X_others.reshape(X_others.shape[0], img_rows, img_cols, img_chns) / 255.
pos_samples = list(pos_samples.reshape(pos_samples.shape[0], img_rows, img_cols, img_chns) / 255.)
neg_samples = list(neg_samples.reshape(neg_samples.shape[0], img_rows, img_cols, img_chns) / 255.)
X = []
y = []
for i in range(len(pos_samples)):
    X.append(pos_samples[i])
    X.append(neg_samples[i])
    y.append([0,1])
    y.append([1,0])
X, y = np.array(X), np.array(y)
X_others = X_others[:batch_size*2]
print(X_others.shape, X.shape, y.shape)

(1200, 20, 20, 3) (10, 20, 20, 3) (10, 2)


# Encoder Network

In [6]:
def create_enc_conv_layers(stage, **kwargs):
    conv_name = '_'.join(['enc_conv', str(stage)])
#     bn_name = '_'.join(['enc_bn', str(stage)])
    layers = [
        Conv2D(name=conv_name, **kwargs),
#         BatchNormalization(name=bn_name),
        Activation(activation),
    ]
    return layers

def create_dense_layers(stage, width):
    dense_name = '_'.join(['dense', str(stage)])
#     bn_name = '_'.join(['dense_bn', str(stage)])
#     dense_layer = Dense(width, name=dense_name)
# #     bn_layer = BatchNormalization(name=bn_name)
#     activation_layer = Activation(activation)
#     dropout_layer = Dropout(dropout)
    layers = [
        Dense(width, name=dense_name),
        Activation(activation),
        Dropout(dropout),
    ]
    return layers

def inst_layers(layers, in_layer):
    x = in_layer
    for layer in layers:
        if isinstance(layer, list):
            x = inst_layers(layer, x)
        else:
            print(layer.name, x, layer)
            x = layer(x)
        
    return x

In [7]:

enc_filters=64
enc_layers = [
    create_enc_conv_layers(stage=1, filters=enc_filters, kernel_size=3, strides=1, padding='same'),
    create_enc_conv_layers(stage=2, filters=enc_filters, kernel_size=3, strides=1, padding='same'),
    create_enc_conv_layers(stage=3, filters=enc_filters, kernel_size=3, strides=2, padding='same'),
    Flatten(),
    create_dense_layers(stage=4, width=intermediate_dim),
]

In [8]:
# Labeled encoder
# x_in = Input(batch_shape=(batch_size,) + original_img_size)
# y_in = Input(batch_shape=(batch_size, num_classes))
x_in = Input(shape=(img_rows, img_cols, img_chns))
y_in = Input(shape=(num_classes,))
_enc_dense = inst_layers(enc_layers, x_in)

_z_mean_1 = Dense(latent_dim)(_enc_dense)
_z_log_var_1 = Dense(latent_dim)(_enc_dense)

z_mean = _z_mean_1
z_log_var = _z_log_var_1
# print(z_mean)

enc_conv_1 Tensor("input_1:0", shape=(?, 20, 20, 3), dtype=float32) <keras.layers.convolutional.Conv2D object at 0x7fc96e9cfcc0>
activation_1 Tensor("enc_conv_1/BiasAdd:0", shape=(?, 20, 20, 64), dtype=float32) <keras.layers.core.Activation object at 0x7fc96e9cfe48>
enc_conv_2 Tensor("activation_1/Relu:0", shape=(?, 20, 20, 64), dtype=float32) <keras.layers.convolutional.Conv2D object at 0x7fc96e9cfef0>
activation_2 Tensor("enc_conv_2/BiasAdd:0", shape=(?, 20, 20, 64), dtype=float32) <keras.layers.core.Activation object at 0x7fc96e9cffd0>
enc_conv_3 Tensor("activation_2/Relu:0", shape=(?, 20, 20, 64), dtype=float32) <keras.layers.convolutional.Conv2D object at 0x7fc96e9c00f0>
activation_3 Tensor("enc_conv_3/BiasAdd:0", shape=(?, 10, 10, 64), dtype=float32) <keras.layers.core.Activation object at 0x7fc96e9c0240>
flatten_1 Tensor("activation_3/Relu:0", shape=(?, 10, 10, 64), dtype=float32) <keras.layers.core.Flatten object at 0x7fc96e9c02b0>
dense_4 Tensor("flatten_1/Reshape:0", shape=(?

### Reparameterization Trick

In [9]:
# def sampling(args, batch_size=-1, latent_dim=latent_dim, epsilon_std=epsilon_std):
#     z_mean, z_log_var = args
    
#     epsilon = K.random_normal(shape=(batch_size, latent_dim),
#                               mean=0., stddev=epsilon_std)
    
#     return z_mean + K.exp(z_log_var) * epsilon
def sampling(args):
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean = 0 and std = 1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# Classifier Network 

In [10]:
classifier_layers = [
    Conv2D(32, (3, 3), padding='same'),
    Activation('relu'),
    Conv2D(32, (3, 3)),
    Activation('relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    Conv2D(64, (3, 3), padding='same'),
    Activation('relu'),
    Conv2D(64, (3, 3)),
    Activation('relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    Flatten(),
    Dense(512,name='dense_cls_'),
    Activation('relu'),
    Dropout(0.5),
    Dense(num_classes,name='dense_cls'),
#     Activation('softmax'),
    Activation(tf.nn.softmax),
]

In [11]:
_cls_output = inst_layers(classifier_layers, x_in)
_y_output = _cls_output

conv2d_1 Tensor("input_1:0", shape=(?, 20, 20, 3), dtype=float32) <keras.layers.convolutional.Conv2D object at 0x7fc96e8d7c88>
activation_5 Tensor("conv2d_1/BiasAdd:0", shape=(?, 20, 20, 32), dtype=float32) <keras.layers.core.Activation object at 0x7fc96e8d7dd8>
conv2d_2 Tensor("activation_5/Relu:0", shape=(?, 20, 20, 32), dtype=float32) <keras.layers.convolutional.Conv2D object at 0x7fc96e8d7e48>
activation_6 Tensor("conv2d_2/BiasAdd:0", shape=(?, 18, 18, 32), dtype=float32) <keras.layers.core.Activation object at 0x7fc96e8d7f98>
max_pooling2d_1 Tensor("activation_6/Relu:0", shape=(?, 18, 18, 32), dtype=float32) <keras.layers.pooling.MaxPooling2D object at 0x7fc96e86a048>
dropout_2 Tensor("max_pooling2d_1/MaxPool:0", shape=(?, 9, 9, 32), dtype=float32) <keras.layers.core.Dropout object at 0x7fc96e86a0f0>
conv2d_3 Tensor("dropout_2/cond/Merge:0", shape=(?, 9, 9, 32), dtype=float32) <keras.layers.convolutional.Conv2D object at 0x7fc96e86a128>
activation_7 Tensor("conv2d_3/BiasAdd:0", sh

# Decoder Network

In [12]:
def create_dec_trans_conv_layers(stage, **kwargs):
    conv_name = '_'.join(['dec_trans_conv', str(stage)])
#     bn_name = '_'.join(['dec_bn', str(stage)])
    layers = [
        Conv2DTranspose(name=conv_name, **kwargs),
#         BatchNormalization(name=bn_name),
        Activation(activation),
    ]
    return layers

In [13]:


dec_filters = 64

decoder_layers = [
    create_dense_layers(stage=10, width=10 * 10 * 64),
    Reshape((10, 10, 64)),
    create_dec_trans_conv_layers(11, filters=dec_filters, kernel_size=3, strides=1, padding='same'),
    create_dec_trans_conv_layers(12, filters=dec_filters, kernel_size=3, strides=1, padding='same'),
    create_dec_trans_conv_layers(13, filters=dec_filters, kernel_size=3, strides=2, padding='same'),
    Conv2DTranspose(name='x_decoded', filters=img_chns, kernel_size=1, strides=1, activation='sigmoid'),
]

In [14]:

# Labeled decoder
_merged = concatenate([y_in, z])

_dec_out = inst_layers(decoder_layers, _merged)
_x_output = _dec_out



dense_10 Tensor("concatenate_1/concat:0", shape=(?, 130), dtype=float32) <keras.layers.core.Dense object at 0x7fc96da4f7b8>
activation_11 Tensor("dense_10/BiasAdd:0", shape=(?, 6400), dtype=float32) <keras.layers.core.Activation object at 0x7fc96da4f908>
dropout_5 Tensor("activation_11/Relu:0", shape=(?, 6400), dtype=float32) <keras.layers.core.Dropout object at 0x7fc96da4f940>
reshape_1 Tensor("dropout_5/cond/Merge:0", shape=(?, 6400), dtype=float32) <keras.layers.core.Reshape object at 0x7fc96da4f978>
dec_trans_conv_11 Tensor("reshape_1/Reshape:0", shape=(?, 10, 10, 64), dtype=float32) <keras.layers.convolutional.Conv2DTranspose object at 0x7fc96da4f9b0>
activation_12 Tensor("dec_trans_conv_11/BiasAdd:0", shape=(?, ?, ?, 64), dtype=float32) <keras.layers.core.Activation object at 0x7fc96da4fb38>
dec_trans_conv_12 Tensor("activation_12/Relu:0", shape=(?, ?, ?, 64), dtype=float32) <keras.layers.convolutional.Conv2DTranspose object at 0x7fc96da4fb70>
activation_13 Tensor("dec_trans_conv

In [15]:

# Unlabeled decoder
print(z, _y_output)
u_merged = concatenate([_y_output, z])
# for layer in decoder_layers:
#     if isinstance(layer, list):
#         for sub_layer in layer:
#             print(sub_layer.name)
#     else:
#         print(layer.name)
# print('******')
u_dec_out = inst_layers(decoder_layers, u_merged)
u_x_output = u_dec_out

Tensor("lambda_1/add:0", shape=(?, 128), dtype=float32) Tensor("activation_10/Softmax:0", shape=(?, 2), dtype=float32)
dense_10 Tensor("concatenate_2/concat:0", shape=(?, 130), dtype=float32) <keras.layers.core.Dense object at 0x7fc96da4f7b8>
activation_11 Tensor("dense_10_1/BiasAdd:0", shape=(?, 6400), dtype=float32) <keras.layers.core.Activation object at 0x7fc96da4f908>
dropout_5 Tensor("activation_11_1/Relu:0", shape=(?, 6400), dtype=float32) <keras.layers.core.Dropout object at 0x7fc96da4f940>
reshape_1 Tensor("dropout_5_1/cond/Merge:0", shape=(?, 6400), dtype=float32) <keras.layers.core.Reshape object at 0x7fc96da4f978>
dec_trans_conv_11 Tensor("reshape_1_1/Reshape:0", shape=(?, 10, 10, 64), dtype=float32) <keras.layers.convolutional.Conv2DTranspose object at 0x7fc96da4f9b0>
activation_12 Tensor("dec_trans_conv_11_1/BiasAdd:0", shape=(?, ?, ?, 64), dtype=float32) <keras.layers.core.Activation object at 0x7fc96da4fb38>
dec_trans_conv_12 Tensor("activation_12_1/Relu:0", shape=(?, ?

# Loss Function

In [16]:
def kl_loss(x, x_decoded_mean, z_mean=z_mean, z_log_var=z_log_var):
    kl_loss = - 0.5 * K.sum(1. + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
   
    return K.mean(kl_loss)

def logxy_loss(x, x_decoded_mean):
    x = K.flatten(x)
    x_decoded_mean = K.flatten(x_decoded_mean)
    xent_loss = img_rows * img_cols * img_chns * metrics.binary_crossentropy(x, x_decoded_mean)
   
    # p(y) for observed data is equally distributed
    logy = np.log(1. / num_classes)
    
    return xent_loss - logy

def labeled_vae_loss(x, x_decoded_mean):
    return logxy_loss(x, x_decoded_mean) + kl_loss(x, x_decoded_mean)

def cls_loss(y, y_pred, N=1000):
    alpha = 0.1 * N
    return alpha * metrics.categorical_crossentropy(y, y_pred)

def unlabeled_vae_loss(x, x_decoded_mean):
    entropy = metrics.categorical_crossentropy(_y_output, _y_output)
    # This is probably not correct, see discussion here: https://github.com/bjlkeng/sandbox/issues/3
    labeled_loss = logxy_loss(x, x_decoded_mean) + kl_loss(x, x_decoded_mean)
    
    return K.mean(K.sum(_y_output * labeled_loss, axis=-1)) + entropy

# Compile Model

In [17]:
label_vae = Model(inputs=[x_in, y_in], outputs=[_x_output, _y_output])
# optimizer = Adam(lr=learning_rate, decay=decay)
optimizer = SGD(lr=learning_rate)
label_vae.compile(optimizer=optimizer, loss=[labeled_vae_loss, cls_loss])
label_vae.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 20, 20, 3)    0                                            
__________________________________________________________________________________________________
enc_conv_1 (Conv2D)             (None, 20, 20, 64)   1792        input_1[0][0]                    
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 20, 20, 64)   0           enc_conv_1[0][0]                 
__________________________________________________________________________________________________
enc_conv_2 (Conv2D)             (None, 20, 20, 64)   36928       activation_1[0][0]               
__________________________________________________________________________________________________
activation

In [18]:
unlabeled_vae = Model(inputs=x_in, outputs=u_x_output)
# optimizer = Adam(lr=learning_rate, decay=decay)
optimizer = SGD(lr=learning_rate)
unlabeled_vae.compile(optimizer=optimizer, loss=unlabeled_vae_loss)
unlabeled_vae.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 20, 20, 3)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 20, 20, 32)   896         input_1[0][0]                    
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 20, 20, 32)   0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 18, 18, 32)   9248        activation_5[0][0]               
__________________________________________________________________________________________________
activation

In [19]:
# from tqdm import tnrange, tqdm_notebook
def fit_model(X_unlabeled, X_labeled, y_labeled, epochs):
#     assert len(X_unlabeled) % len(X_labeled) == 0, \
#             (len(X_unlabeled), batch_size, len(X_labeled))
    start = time.time()
    history = []
    
    for epoch in range(epochs):
        unlabeled_index = np.arange(len(X_unlabeled))
        np.random.shuffle(unlabeled_index)
        
        # Repeat the labeled data to match length of unlabeled data
#         labeled_index = []
#         for i in range(len(X_unlabeled) // len(X_labeled)):
        l = np.arange(len(X_labeled))
        np.random.shuffle(l)
        labeled_index = l
#         labeled_index.append(l)
#         labeled_index = np.concatenate(labeled_index)
        
        batches = len(X_unlabeled) // batch_size
        
        for i in range(batches):
            # Labeled
            index_range =  labeled_index[i * batch_size:(i+1) * batch_size]
            
            loss = label_vae.train_on_batch([X_labeled[labeled_index], y_labeled[labeled_index]], 
                                            [X_labeled[labeled_index], y_labeled[labeled_index]])

#             loss = label_vae.train_on_batch([X_labeled[index_range], y_labeled[index_range]], 
#                                             [X_labeled[index_range], y_labeled[index_range]])

            # Unlabeled
            index_range =  unlabeled_index[i * batch_size:(i+1) * batch_size]
            loss += [unlabeled_vae.train_on_batch(X_unlabeled[index_range],  X_unlabeled[index_range])]
            print('epoch '+str(epoch)+'   loss = ' +str(loss))
            history.append(loss)
                

    
   
    done = time.time()
    elapsed = done - start
    print("Elapsed: ", elapsed)
    
    return history

In [20]:
sample_size = 10
start = time.time()
print('Fitting with sample_size: {}'.format(sample_size))

# if sample_size < len(X_train):
#     sss = StratifiedShuffleSplit(n_splits=2, test_size=sample_size / len(X_train), random_state=0)
#     _, index = sss.split(X_train, y_train)
#     X, y = X_train[index[1]], y_train[index[1]]
#     X_others, _ = X_train[index[0]], y_train[index[0]]
# else:
#     X, y = X_train, y_train

# y = np_utils.to_categorical(y)

history = fit_model(X_others, X, y, epochs=epochs)



Fitting with sample_size: 10
epoch 0   loss = [905.68115, 835.38727, 70.2939, 833.1175]
epoch 0   loss = [900.1317, 829.41956, 70.71213, 831.63885]
epoch 1   loss = [893.89307, 823.3495, 70.5436, 830.13824]
epoch 1   loss = [887.5386, 817.7822, 69.75632, 828.979]
epoch 2   loss = [883.86597, 812.3882, 71.47779, 828.02466]
epoch 2   loss = [876.8726, 807.1199, 69.75275, 826.38745]
epoch 3   loss = [871.61304, 802.0248, 69.588264, 825.88715]
epoch 3   loss = [866.04584, 797.1687, 68.87714, 824.8247]
epoch 4   loss = [860.34064, 792.2245, 68.116135, 824.78656]
epoch 4   loss = [856.2073, 787.50684, 68.70041, 823.0294]
epoch 5   loss = [851.20844, 783.04553, 68.1629, 823.80554]
epoch 5   loss = [849.02295, 778.83325, 70.18973, 821.94824]
epoch 6   loss = [846.58716, 774.4974, 72.08977, 822.6978]
epoch 6   loss = [839.6262, 770.5437, 69.08253, 821.7988]
epoch 7   loss = [835.8156, 766.68384, 69.131775, 821.1939]
epoch 7   loss = [831.99585, 762.98914, 69.00673, 822.69165]
epoch 8   loss = [

epoch 69   loss = [734.58075, 687.9529, 46.62786, 844.39575]
epoch 70   loss = [727.69434, 688.09906, 39.59525, 843.7599]
epoch 70   loss = [730.24164, 687.99207, 42.249596, 846.417]
epoch 71   loss = [729.7383, 687.9387, 41.79957, 845.318]
epoch 71   loss = [727.2033, 687.8539, 39.349422, 844.8418]
epoch 72   loss = [729.1055, 687.95667, 41.148857, 846.4151]
epoch 72   loss = [723.7664, 688.13525, 35.63114, 843.7866]
epoch 73   loss = [725.6157, 688.0072, 37.60851, 845.6786]
epoch 73   loss = [725.3819, 688.1116, 37.27033, 844.4916]
epoch 74   loss = [723.82666, 687.8978, 35.92885, 846.6067]
epoch 74   loss = [728.4239, 688.1155, 40.308395, 843.55817]
epoch 75   loss = [726.41565, 687.9231, 38.492523, 844.64966]
epoch 75   loss = [724.52563, 688.0438, 36.481834, 845.5821]
epoch 76   loss = [721.6672, 687.8531, 33.814095, 843.2691]
epoch 76   loss = [725.0048, 687.98413, 37.02069, 846.9811]
epoch 77   loss = [724.3791, 687.9033, 36.475765, 846.4118]
epoch 77   loss = [724.70844, 687.86

epoch 137   loss = [702.26337, 687.53906, 14.724302, 844.71515]
epoch 138   loss = [697.7184, 687.6802, 10.038232, 847.83105]
epoch 138   loss = [701.897, 687.66077, 14.236236, 842.31104]
epoch 139   loss = [699.3947, 687.7156, 11.67911, 847.3434]
epoch 139   loss = [699.29553, 687.60754, 11.687974, 842.7703]
epoch 140   loss = [701.1212, 687.60693, 13.514307, 846.4676]
epoch 140   loss = [700.7551, 687.58496, 13.170183, 843.7001]
epoch 141   loss = [702.99915, 687.72797, 15.271161, 847.0158]
epoch 141   loss = [695.9971, 687.6592, 8.337906, 843.10345]
epoch 142   loss = [699.136, 687.56384, 11.572172, 845.18536]
epoch 142   loss = [696.6221, 687.55786, 9.064186, 844.982]
epoch 143   loss = [697.6594, 687.6517, 10.007776, 847.98865]
epoch 143   loss = [696.2428, 687.6157, 8.627061, 842.1955]
epoch 144   loss = [696.7378, 687.61365, 9.124143, 843.5424]
epoch 144   loss = [700.5634, 687.572, 12.991413, 846.74646]
epoch 145   loss = [703.0403, 687.61096, 15.429342, 844.6247]
epoch 145   l

epoch 205   loss = [690.4643, 687.493, 2.971301, 845.41327]
epoch 205   loss = [690.89844, 687.4862, 3.412211, 844.64825]
epoch 206   loss = [693.30975, 687.52966, 5.7800894, 842.2196]
epoch 206   loss = [694.1003, 687.4219, 6.6784267, 847.8594]
epoch 207   loss = [690.12756, 687.42053, 2.7070527, 846.89056]
epoch 207   loss = [692.3822, 687.53406, 4.8481145, 843.0766]
epoch 208   loss = [690.2253, 687.5012, 2.7240329, 844.3916]
epoch 208   loss = [691.81976, 687.60913, 4.21065, 845.6372]
epoch 209   loss = [690.5631, 687.50476, 3.058322, 846.84717]
epoch 209   loss = [691.25745, 687.6301, 3.6273088, 843.15204]
epoch 210   loss = [691.6166, 687.61584, 4.000761, 844.2383]
epoch 210   loss = [691.93115, 687.5719, 4.359261, 845.7916]
epoch 211   loss = [690.2057, 687.495, 2.7106674, 842.963]
epoch 211   loss = [692.0797, 687.60034, 4.479345, 847.0728]
epoch 212   loss = [691.22046, 687.5106, 3.7098594, 842.9724]
epoch 212   loss = [690.24744, 687.5237, 2.7237377, 847.09436]
epoch 213   lo

epoch 273   loss = [689.3506, 687.5923, 1.7583237, 844.52264]
epoch 273   loss = [689.46936, 687.5239, 1.945421, 845.3466]
epoch 274   loss = [689.6905, 687.51196, 2.1784983, 846.12964]
epoch 274   loss = [688.97284, 687.54364, 1.4292192, 843.7099]
epoch 275   loss = [689.3098, 687.44336, 1.86645, 842.0051]
epoch 275   loss = [688.74646, 687.3541, 1.3923328, 847.9076]
epoch 276   loss = [688.8406, 687.4674, 1.3731458, 845.50073]
epoch 276   loss = [689.0528, 687.57556, 1.4772127, 844.3508]
epoch 277   loss = [688.2608, 687.4325, 0.82830095, 844.8757]
epoch 277   loss = [690.53705, 687.57056, 2.9665127, 844.9938]
epoch 278   loss = [689.43915, 687.41724, 2.0219336, 843.0097]
epoch 278   loss = [688.4389, 687.39795, 1.0409606, 846.8707]
epoch 279   loss = [688.9086, 687.57153, 1.3371, 843.62427]
epoch 279   loss = [689.1436, 687.49744, 1.6461549, 846.2691]
epoch 280   loss = [688.7594, 687.5115, 1.247936, 845.37463]
epoch 280   loss = [689.7074, 687.5725, 2.1349165, 844.4653]
epoch 281  

epoch 341   loss = [689.4616, 687.4791, 1.9824817, 844.9945]
epoch 341   loss = [687.7815, 687.47424, 0.30724955, 844.79004]
epoch 342   loss = [687.8345, 687.40393, 0.43054444, 846.2944]
epoch 342   loss = [688.0389, 687.5906, 0.4483011, 843.48175]
epoch 343   loss = [688.4617, 687.38904, 1.0726466, 847.88074]
epoch 343   loss = [688.1246, 687.44415, 0.6804284, 841.8747]
epoch 344   loss = [688.07477, 687.4192, 0.6555933, 848.77167]
epoch 344   loss = [688.72516, 687.53625, 1.1888943, 840.9891]
epoch 345   loss = [689.6318, 687.4093, 2.2224755, 845.19806]
epoch 345   loss = [688.76294, 687.3883, 1.3746369, 844.5779]
epoch 346   loss = [688.88745, 687.53467, 1.3527663, 842.506]
epoch 346   loss = [688.0071, 687.41345, 0.59362245, 847.3242]
epoch 347   loss = [688.8036, 687.3583, 1.4452903, 846.86316]
epoch 347   loss = [688.513, 687.4297, 1.083328, 842.93286]
epoch 348   loss = [688.01404, 687.30566, 0.70836717, 845.6697]
epoch 348   loss = [688.13086, 687.42224, 0.7085916, 844.14404]


epoch 408   loss = [687.98926, 687.3874, 0.6018728, 843.6262]
epoch 408   loss = [688.7224, 687.3606, 1.361808, 846.1867]
epoch 409   loss = [687.7567, 687.29846, 0.45826983, 842.85657]
epoch 409   loss = [687.8308, 687.2899, 0.54088074, 846.93005]
epoch 410   loss = [687.65344, 687.44946, 0.20400423, 844.19214]
epoch 410   loss = [688.2469, 687.33765, 0.9092703, 845.603]
epoch 411   loss = [688.0411, 687.3778, 0.6632864, 844.38354]
epoch 411   loss = [688.78723, 687.33386, 1.453348, 845.40845]
epoch 412   loss = [687.797, 687.35486, 0.44213656, 844.7019]
epoch 412   loss = [687.7463, 687.3718, 0.37447068, 845.09717]
epoch 413   loss = [690.06415, 687.3922, 2.6719117, 840.9852]
epoch 413   loss = [687.6149, 687.3396, 0.27531716, 848.84827]
epoch 414   loss = [688.05457, 687.3988, 0.65575624, 845.38525]
epoch 414   loss = [687.7698, 687.40875, 0.36102554, 844.3871]
epoch 415   loss = [687.8528, 687.3998, 0.45302784, 846.3537]
epoch 415   loss = [688.20575, 687.3491, 0.8566496, 843.382]


epoch 474   loss = [687.583, 687.32776, 0.25526363, 844.978]
epoch 474   loss = [687.3936, 687.18005, 0.21355233, 844.755]
epoch 475   loss = [687.68964, 687.4492, 0.24043018, 847.2141]
epoch 475   loss = [687.5378, 687.3562, 0.18158644, 842.49774]
epoch 476   loss = [687.5468, 687.2948, 0.25202894, 842.9481]
epoch 476   loss = [687.9136, 687.40845, 0.50510305, 846.81354]
epoch 477   loss = [687.5304, 687.23145, 0.2989614, 844.2435]
epoch 477   loss = [703.2962, 687.3739, 15.922327, 845.4932]
epoch 478   loss = [687.91406, 687.2728, 0.64125156, 845.3854]
epoch 478   loss = [688.20026, 687.4136, 0.7866687, 844.28735]
epoch 479   loss = [687.58185, 687.34656, 0.23526973, 844.17004]
epoch 479   loss = [688.1053, 687.3561, 0.74921966, 845.54333]
epoch 480   loss = [687.59033, 687.2997, 0.29064676, 845.79443]
epoch 480   loss = [687.6668, 687.35254, 0.31426343, 843.89734]
epoch 481   loss = [688.2455, 687.3711, 0.8744162, 845.0974]
epoch 481   loss = [689.1946, 687.32025, 1.8742995, 844.610

epoch 541   loss = [687.94055, 687.36365, 0.57690215, 848.0058]
epoch 541   loss = [687.54767, 687.3259, 0.22173005, 841.69336]
epoch 542   loss = [687.6918, 687.3662, 0.32555416, 844.7531]
epoch 542   loss = [687.5309, 687.2987, 0.23220252, 844.9932]
epoch 543   loss = [687.52167, 687.2761, 0.24552225, 846.8559]
epoch 543   loss = [687.68024, 687.3906, 0.289615, 842.8715]
epoch 544   loss = [687.585, 687.34485, 0.24017298, 844.54626]
epoch 544   loss = [687.4338, 687.2468, 0.18693773, 845.1831]
epoch 545   loss = [687.59216, 687.2953, 0.29687873, 843.1321]
epoch 545   loss = [687.4574, 687.27075, 0.18662211, 846.6395]
epoch 546   loss = [687.68555, 687.3281, 0.3574003, 847.9553]
epoch 546   loss = [687.5764, 687.29016, 0.28626287, 841.694]
epoch 547   loss = [688.08484, 687.2814, 0.8034718, 846.36145]
epoch 547   loss = [687.66095, 687.3456, 0.3153487, 843.3431]
epoch 548   loss = [687.8525, 687.34924, 0.5032467, 846.24817]
epoch 548   loss = [687.85266, 687.29663, 0.556046, 843.47034

epoch 607   loss = [687.6607, 687.3075, 0.35318884, 846.6174]
epoch 607   loss = [687.5368, 687.349, 0.18779776, 843.05035]
epoch 608   loss = [687.68024, 687.294, 0.38620406, 841.3672]
epoch 608   loss = [687.294, 687.177, 0.11703086, 848.36707]
epoch 609   loss = [687.5014, 687.22754, 0.27389237, 842.57495]
epoch 609   loss = [687.47003, 687.26636, 0.20366582, 847.1517]
epoch 610   loss = [687.68945, 687.308, 0.38150024, 843.85175]
epoch 610   loss = [687.3884, 687.2429, 0.14549866, 845.84766]
epoch 611   loss = [687.47577, 687.26086, 0.21489868, 843.20605]
epoch 611   loss = [687.5153, 687.30493, 0.210401, 846.5122]
epoch 612   loss = [687.72284, 687.3513, 0.37153542, 842.15625]
epoch 612   loss = [687.4332, 687.2751, 0.15813254, 847.57434]
epoch 613   loss = [687.5112, 687.30945, 0.2017928, 843.4031]
epoch 613   loss = [687.4933, 687.3245, 0.16876158, 846.30475]
epoch 614   loss = [687.5823, 687.2627, 0.31958807, 843.4369]
epoch 614   loss = [687.6637, 687.306, 0.3576585, 846.2581]

epoch 673   loss = [687.46875, 687.2527, 0.21604288, 843.68616]
epoch 673   loss = [687.2919, 687.20886, 0.0829989, 846.03107]
epoch 674   loss = [687.2891, 687.1892, 0.09989348, 845.1337]
epoch 674   loss = [687.5941, 687.25024, 0.34388286, 844.5759]
epoch 675   loss = [687.51056, 687.2374, 0.2731596, 845.4833]
epoch 675   loss = [687.68774, 687.199, 0.4887548, 844.2096]
epoch 676   loss = [687.6041, 687.2588, 0.34534594, 845.57935]
epoch 676   loss = [687.4677, 687.297, 0.1707142, 844.10986]
epoch 677   loss = [687.48737, 687.3502, 0.13717587, 842.0703]
epoch 677   loss = [687.33673, 687.271, 0.06573997, 847.66565]
epoch 678   loss = [687.32104, 687.2252, 0.09583078, 845.03625]
epoch 678   loss = [687.4493, 687.2499, 0.1994237, 844.6513]
epoch 679   loss = [687.38635, 687.2555, 0.13082987, 843.7082]
epoch 679   loss = [687.426, 687.3276, 0.098423555, 845.9955]
epoch 680   loss = [687.4956, 687.3473, 0.14831814, 845.08344]
epoch 680   loss = [687.30566, 687.2268, 0.07884105, 844.5939]

epoch 739   loss = [687.46326, 687.26184, 0.20141485, 843.3181]
epoch 739   loss = [687.49304, 687.17554, 0.31748706, 846.3589]
epoch 740   loss = [687.37085, 687.3314, 0.03942219, 843.9258]
epoch 740   loss = [687.85583, 687.37805, 0.47775906, 845.7554]
epoch 741   loss = [687.4495, 687.26184, 0.1876806, 844.32135]
epoch 741   loss = [687.1632, 687.11694, 0.0462585, 845.35657]
epoch 742   loss = [687.60657, 687.20544, 0.40115264, 846.17413]
epoch 742   loss = [687.3753, 687.25574, 0.119567096, 843.47906]
epoch 743   loss = [687.3784, 687.2295, 0.14892814, 842.6501]
epoch 743   loss = [687.3204, 687.27795, 0.04239857, 847.0543]
epoch 744   loss = [687.2927, 687.1898, 0.10290603, 848.7254]
epoch 744   loss = [687.32605, 687.2731, 0.052984834, 840.8811]
epoch 745   loss = [687.1692, 687.1343, 0.034939513, 843.0808]
epoch 745   loss = [687.3886, 687.23535, 0.15326515, 846.60333]
epoch 746   loss = [687.3409, 687.2615, 0.07939561, 847.1014]
epoch 746   loss = [687.47327, 687.3314, 0.141869

epoch 806   loss = [687.4071, 687.1604, 0.24670772, 845.4219]
epoch 806   loss = [687.20886, 687.1455, 0.063348815, 844.2449]
epoch 807   loss = [687.3053, 687.2246, 0.08071753, 844.8481]
epoch 807   loss = [687.43207, 687.27893, 0.1531348, 844.8266]
epoch 808   loss = [687.2988, 687.23035, 0.06850253, 840.2848]
epoch 808   loss = [687.19055, 687.0519, 0.1386571, 849.4705]
epoch 809   loss = [687.4764, 687.2158, 0.2605565, 843.30615]
epoch 809   loss = [687.18524, 687.10767, 0.07757946, 846.39276]
epoch 810   loss = [687.26105, 687.1353, 0.1257181, 844.6259]
epoch 810   loss = [687.2284, 687.1362, 0.092145, 845.05835]
epoch 811   loss = [687.3507, 687.2466, 0.104132235, 847.8705]
epoch 811   loss = [687.4188, 687.3186, 0.1002049, 841.7785]
epoch 812   loss = [687.3078, 687.11145, 0.19638008, 844.17755]
epoch 812   loss = [687.36163, 687.1033, 0.25836104, 845.52356]
epoch 813   loss = [687.2539, 687.1592, 0.09471756, 845.1411]
epoch 813   loss = [687.2712, 687.19446, 0.07674482, 844.546

epoch 871   loss = [687.2617, 687.0981, 0.16360915, 845.80176]
epoch 871   loss = [687.2287, 687.1349, 0.09382449, 843.8518]
epoch 872   loss = [687.2421, 687.14355, 0.098593295, 842.15796]
epoch 872   loss = [687.2362, 687.16797, 0.06822487, 847.5371]
epoch 873   loss = [687.2924, 687.24133, 0.051089924, 845.4225]
epoch 873   loss = [687.22327, 687.11945, 0.10382163, 844.233]
epoch 874   loss = [687.5234, 687.1422, 0.38114852, 845.4454]
epoch 874   loss = [687.2685, 687.20276, 0.06574806, 844.2192]
epoch 875   loss = [687.18604, 687.09827, 0.08777877, 842.1396]
epoch 875   loss = [687.23566, 687.1408, 0.09486591, 847.5702]
epoch 876   loss = [687.4241, 687.26025, 0.16383243, 848.15076]
epoch 876   loss = [687.30206, 687.22864, 0.0734427, 841.47534]
epoch 877   loss = [687.2287, 687.1654, 0.063274354, 845.3511]
epoch 877   loss = [687.21625, 687.15234, 0.06388512, 844.3164]
epoch 878   loss = [687.2333, 687.1588, 0.07445377, 846.15204]
epoch 878   loss = [687.23193, 687.199, 0.03297712

epoch 936   loss = [687.3109, 687.2012, 0.1097235, 845.60034]
epoch 936   loss = [687.3571, 687.1876, 0.16947663, 844.05554]
epoch 937   loss = [689.0526, 687.1937, 1.858897, 846.0004]
epoch 937   loss = [687.2759, 687.19226, 0.08363401, 843.6564]
epoch 938   loss = [687.35, 687.26514, 0.084810376, 843.7454]
epoch 938   loss = [687.1683, 687.0932, 0.075055376, 845.9256]
epoch 939   loss = [687.24524, 687.1616, 0.083590835, 845.4841]
epoch 939   loss = [687.25885, 687.18933, 0.069538325, 844.1713]
epoch 940   loss = [687.21356, 687.1002, 0.11333231, 845.21387]
epoch 940   loss = [687.3735, 687.2301, 0.1433785, 844.4188]
epoch 941   loss = [687.13824, 687.1157, 0.022540864, 844.48584]
epoch 941   loss = [687.13293, 687.0935, 0.039450984, 845.1506]
epoch 942   loss = [687.2495, 687.1329, 0.11666764, 840.36395]
epoch 942   loss = [687.1769, 687.1095, 0.06736719, 849.3338]
epoch 943   loss = [687.24585, 687.1575, 0.08836753, 843.8235]
epoch 943   loss = [687.304, 687.16113, 0.1428549, 845.8

epoch 1001   loss = [687.3981, 687.0873, 0.31076312, 844.56995]
epoch 1002   loss = [687.2188, 687.1051, 0.113680065, 844.3235]
epoch 1002   loss = [687.26685, 687.09766, 0.169206, 845.33276]
epoch 1003   loss = [687.2705, 687.13135, 0.13916934, 844.1158]
epoch 1003   loss = [687.3929, 687.1858, 0.20710917, 845.5377]
epoch 1004   loss = [687.3497, 687.0637, 0.28592646, 846.70074]
epoch 1004   loss = [687.2648, 687.17664, 0.08811376, 842.9181]
epoch 1005   loss = [687.56793, 687.0654, 0.5025064, 843.7101]
epoch 1005   loss = [687.26337, 687.1367, 0.12665802, 845.92914]
epoch 1006   loss = [687.2393, 687.0676, 0.17168461, 845.5816]
epoch 1006   loss = [687.3282, 687.2042, 0.12399222, 844.0401]
epoch 1007   loss = [687.36774, 687.11475, 0.25299084, 844.75745]
epoch 1007   loss = [687.1369, 687.0492, 0.08771093, 844.85815]
epoch 1008   loss = [687.3606, 687.2301, 0.13052282, 849.22626]
epoch 1008   loss = [687.3155, 687.21344, 0.10205406, 840.35614]
epoch 1009   loss = [687.3662, 687.1478,

epoch 1066   loss = [687.1478, 687.08386, 0.063994296, 842.3558]
epoch 1067   loss = [687.17914, 687.0797, 0.099444695, 843.2773]
epoch 1067   loss = [687.0946, 687.0399, 0.054703444, 846.38116]
epoch 1068   loss = [687.1421, 687.10425, 0.037834056, 845.0992]
epoch 1068   loss = [687.2339, 687.20044, 0.033468902, 844.5295]
epoch 1069   loss = [687.2434, 687.0814, 0.1620081, 843.6154]
epoch 1069   loss = [687.26776, 687.1355, 0.13226007, 846.0365]
epoch 1070   loss = [687.22174, 687.1239, 0.09781593, 847.61957]
epoch 1070   loss = [687.2478, 687.1981, 0.04968358, 841.9709]
epoch 1071   loss = [687.2781, 687.2163, 0.061740015, 842.9236]
epoch 1071   loss = [687.1988, 687.15564, 0.043179013, 846.7456]
epoch 1072   loss = [687.33136, 687.2551, 0.07625137, 846.03754]
epoch 1072   loss = [687.2775, 687.20325, 0.074268304, 843.57556]
epoch 1073   loss = [687.22174, 687.1564, 0.065395445, 842.4898]
epoch 1073   loss = [687.0652, 687.0462, 0.01899787, 847.18115]
epoch 1074   loss = [687.2221, 6

epoch 1130   loss = [687.2047, 687.177, 0.027691783, 844.75214]
epoch 1131   loss = [687.15704, 687.106, 0.051007412, 842.62915]
epoch 1131   loss = [687.07587, 686.9923, 0.08355798, 847.05273]
epoch 1132   loss = [687.2184, 687.0602, 0.15820664, 847.142]
epoch 1132   loss = [687.18555, 687.10583, 0.07972831, 842.4498]
epoch 1133   loss = [687.09436, 687.0724, 0.02194216, 844.3479]
epoch 1133   loss = [687.15955, 687.13, 0.029519383, 845.2953]
epoch 1134   loss = [687.1522, 687.10474, 0.04751038, 843.96954]
epoch 1134   loss = [687.2288, 687.08276, 0.14606577, 845.67737]
epoch 1135   loss = [687.06616, 687.0449, 0.02123269, 846.40234]
epoch 1135   loss = [687.293, 687.1467, 0.14627808, 843.2133]
epoch 1136   loss = [687.2761, 687.1858, 0.09032581, 843.41626]
epoch 1136   loss = [687.1848, 687.1001, 0.08472767, 846.2428]
epoch 1137   loss = [687.03345, 686.9813, 0.052117635, 840.6642]
epoch 1137   loss = [687.27747, 687.0962, 0.18127604, 849.0355]
epoch 1138   loss = [687.25995, 687.128

epoch 1195   loss = [687.6406, 687.11597, 0.5246435, 842.3658]
epoch 1196   loss = [687.2415, 687.2168, 0.024722239, 843.9968]
epoch 1196   loss = [687.1447, 687.1221, 0.02261541, 845.67535]
epoch 1197   loss = [687.2817, 687.2374, 0.044226736, 844.3922]
epoch 1197   loss = [687.64575, 686.96625, 0.67950547, 845.2881]
epoch 1198   loss = [687.2388, 687.1168, 0.12197012, 847.14246]
epoch 1198   loss = [687.1924, 687.1738, 0.018526927, 842.5063]
epoch 1199   loss = [687.1911, 687.152, 0.03914962, 850.0733]
epoch 1199   loss = [687.33765, 687.19714, 0.14052802, 839.51404]
epoch 1200   loss = [687.21173, 687.08765, 0.124059275, 844.19763]
epoch 1200   loss = [687.2338, 687.1987, 0.0350944, 845.46246]
epoch 1201   loss = [687.23157, 687.10095, 0.13059545, 847.15594]
epoch 1201   loss = [687.18634, 687.1559, 0.030471446, 842.497]
epoch 1202   loss = [687.08813, 687.0134, 0.07471329, 844.27747]
epoch 1202   loss = [687.12714, 687.0957, 0.03144775, 845.4137]
epoch 1203   loss = [687.0313, 686.

epoch 1259   loss = [687.07965, 687.02356, 0.05611638, 844.9233]
epoch 1260   loss = [687.14764, 687.07855, 0.06907252, 846.5132]
epoch 1260   loss = [687.20154, 687.16406, 0.037493438, 843.11426]
epoch 1261   loss = [687.0901, 687.05554, 0.034542397, 844.96893]
epoch 1261   loss = [687.0137, 686.97864, 0.035023987, 844.67474]
epoch 1262   loss = [687.1389, 687.0724, 0.066539705, 843.7065]
epoch 1262   loss = [687.15045, 687.079, 0.07147986, 845.95374]
epoch 1263   loss = [687.2492, 687.0746, 0.17462112, 844.87744]
epoch 1263   loss = [687.0401, 686.99677, 0.04333014, 844.77075]
epoch 1264   loss = [687.1939, 687.17883, 0.015086816, 840.7761]
epoch 1264   loss = [687.1566, 687.11365, 0.04298611, 848.9316]
epoch 1265   loss = [687.07495, 687.047, 0.027946353, 845.1206]
epoch 1265   loss = [687.18915, 687.1554, 0.033756677, 844.5231]
epoch 1266   loss = [687.18524, 687.08203, 0.10323658, 845.1715]
epoch 1266   loss = [687.2047, 687.1666, 0.038056538, 844.5028]
epoch 1267   loss = [687.11

epoch 1323   loss = [687.10156, 686.99066, 0.11091833, 848.0006]
epoch 1324   loss = [687.12585, 686.98596, 0.13986932, 845.70197]
epoch 1324   loss = [687.13916, 687.1096, 0.029514918, 843.94086]
epoch 1325   loss = [687.05707, 687.01526, 0.041805632, 846.5046]
epoch 1325   loss = [687.13116, 687.08813, 0.043020412, 843.1047]
epoch 1326   loss = [687.02344, 687.00793, 0.015488672, 843.8178]
epoch 1326   loss = [687.0544, 686.9866, 0.06779992, 845.8115]
epoch 1327   loss = [687.124, 687.0995, 0.024525551, 839.9188]
epoch 1327   loss = [687.0787, 687.0338, 0.04484602, 849.7644]
epoch 1328   loss = [687.1144, 687.09644, 0.017931167, 844.5544]
epoch 1328   loss = [687.0762, 687.0596, 0.01657105, 845.0765]
epoch 1329   loss = [687.124, 687.1007, 0.02333222, 842.5935]
epoch 1329   loss = [687.0811, 687.0647, 0.016396333, 847.06976]
epoch 1330   loss = [687.0837, 687.05676, 0.02689286, 845.3453]
epoch 1330   loss = [687.13165, 687.1017, 0.029973729, 844.2828]
epoch 1331   loss = [687.0943, 6

epoch 1387   loss = [686.9781, 686.91785, 0.060212027, 847.02625]
epoch 1388   loss = [687.10254, 687.09265, 0.009897826, 844.644]
epoch 1388   loss = [687.3577, 687.01807, 0.3396575, 844.9507]
epoch 1389   loss = [687.20197, 687.14624, 0.055725176, 845.4434]
epoch 1389   loss = [686.9514, 686.92285, 0.028556632, 844.15906]
epoch 1390   loss = [687.08545, 687.045, 0.04045897, 844.9022]
epoch 1390   loss = [686.9489, 686.9009, 0.04804199, 844.68604]
epoch 1391   loss = [687.1347, 687.0608, 0.07392631, 842.44257]
epoch 1391   loss = [687.0965, 687.05493, 0.041537724, 847.18463]
epoch 1392   loss = [687.2218, 687.07764, 0.14418295, 842.3877]
epoch 1392   loss = [686.99097, 686.9811, 0.009858586, 847.2438]
epoch 1393   loss = [687.1782, 687.14465, 0.03357062, 841.097]
epoch 1393   loss = [686.90765, 686.89294, 0.014697759, 848.53455]
epoch 1394   loss = [687.0972, 686.9386, 0.15862852, 846.40137]
epoch 1394   loss = [687.1787, 687.1456, 0.03308884, 843.1555]
epoch 1395   loss = [687.1719, 

epoch 1451   loss = [687.1142, 687.0647, 0.04947777, 844.2199]
epoch 1452   loss = [687.0255, 686.9833, 0.042252548, 845.0449]
epoch 1452   loss = [687.0351, 686.9524, 0.08268846, 844.5685]
epoch 1453   loss = [687.03595, 686.9878, 0.04814167, 844.087]
epoch 1453   loss = [686.89197, 686.80634, 0.08561824, 845.55554]
epoch 1454   loss = [686.98834, 686.91943, 0.06893226, 842.3832]
epoch 1454   loss = [686.95966, 686.91785, 0.041815173, 847.2928]
epoch 1455   loss = [687.08295, 687.0248, 0.058181033, 843.8759]
epoch 1455   loss = [687.0806, 687.0724, 0.008231196, 845.7666]
epoch 1456   loss = [687.11505, 687.1044, 0.010686637, 847.67267]
epoch 1456   loss = [687.019, 686.9127, 0.10624631, 841.90546]
epoch 1457   loss = [687.0911, 687.07385, 0.01724255, 842.28955]
epoch 1457   loss = [686.9899, 686.9656, 0.024347072, 847.3961]
epoch 1458   loss = [687.1437, 687.1207, 0.022956803, 842.2542]
epoch 1458   loss = [687.6399, 686.97327, 0.6666032, 847.42303]
epoch 1459   loss = [687.0809, 687.

## load test images

In [21]:
# load test images with index
stride = 10
X_test = []
X_test_idx = []
image = cv2.imread(os.path.join(data_path, image_name))
mask = cv2.imread(os.path.join(data_path, mask_name+'_label.jpeg'), 0) / 255
print(image.shape, mask.shape)

for row in range(0,mask.shape[0]-img_rows,stride):
    for col in range(0,mask.shape[1]-img_rows,stride):
        sub_mask = mask[row:row+img_rows, col:col+img_rows]
        if np.count_nonzero(sub_mask) == img_rows*img_rows:
            sub_img = image[row:row+img_rows, col:col+img_rows]
            X_test.append(sub_img)
            X_test_idx.append([row,col])
            
X_test = np.array(X_test) / 255.0
print(X_test.shape, len(X_test_idx))

(3063, 4764, 3) (3063, 4764)
(3666, 20, 20, 3) 3666


In [22]:
classifier = Model(inputs=[x_in], outputs=[_y_output])

y_pred = np.argmax(classifier.predict(X_test), axis=-1)

print(y_pred.shape)

(3666,)


In [23]:
pred_pos_idx, pred_neg_idx = [], []

for i in range(y_pred.shape[0]):
    if y_pred[i] == 0:
        pred_neg_idx.append(X_test_idx[i])
    else:
        pred_pos_idx.append(X_test_idx[i])
        
print(len(pred_pos_idx), len(pred_neg_idx))

850 2816


## generate point-level prediction results

In [24]:
# save the positive point-level prediction resuls as a txt file
save_file_name = mask_name+'_points_pred.txt'
save_file_path = os.path.join(data_path, save_file_name)
pred_center_points = []
for i in pred_pos_idx:
    x, y = i[0], i[1]
    x_c, y_c = int(x+img_rows/2), int(y+img_rows/2)
    pred_center_points.append([x_c,y_c])

pred_center_points = np.array(pred_center_points)
np.savetxt(save_file_path, pred_center_points, fmt='%d', delimiter=',')
print('save the point-level prediction results in %s',save_file_path)

save the point-level prediction results in %s ./data/xview/322_car_1_points_pred.txt


## generate patch-level prediction map

In [25]:
pred_map = np.zeros(mask.shape)

for i in pred_pos_idx:
    x, y = i[0], i[1]
    pred_map[x:x+img_rows, y:y+img_rows] = 255

pred_path = os.path.join(data_path, mask_name+'_pred.jpeg')
cv2.imwrite(pred_path, pred_map)

True

## generate point-level prediction map for visualization

In [26]:
import copy

radius = 1
# Blue color in BGR
color = (0, 0, 255)
# Line thickness of 2 px
thickness=2

pred_points_map = copy.deepcopy(image)
for i in pred_pos_idx:
    x, y = i[0], i[1]
    x_c, y_c = int(x+img_rows/2), int(y+img_rows/2)
    pred_points_map = cv2.circle(pred_points_map, (y_c,x_c), radius, color, thickness)

pred_points_path = os.path.join(data_path, mask_name+'_pred_points.jpeg')
cv2.imwrite(pred_points_path, pred_points_map)

True

## group points

In [27]:
# from sklearn.cluster import DBSCAN

# buffer_size = 10

# pred_points = []

# for i in pred_pos_idx:
#     x, y = i[0], i[1]
#     x_c, y_c = int(x+img_rows/2), int(y+img_rows/2)
#     pred_points.append([x_c, y_c])

# # grouping points
# points_dict = {}
# for i in pred_points:
#     x_i, y_i = i[0], i[1]
#     points_dict[str(i)] = [i]
#     for j in pred_points:
#         x_j, y_j = j[0], j[1]
#         if abs(x_j-x_i)<=buffer_size and abs(y_j-y_i)<=buffer_size:
#             points_dict[str(i)].append([x_j, y_j])

# grouped_points = []
# # method 1
# # visited = []
# # for p in points_dict.keys():
# #     i = eval(p)
# #     for j in points_dict.keys():
# #         q = eval(j)
# #         if p != j and points_dict[p] == points_dict[j] and q not in visited:
# #             if i not in visited:
# #                 grouped_points.append(i)
# #                 visited.append(i)
# #             visited.append(q)
# #     if i not in visited:
# #         grouped_points.append(i)

# # method 2
# for p in points_dict.keys():
#     pts_in_buffer =points_dict[p]
#     xs = [i[0] for i in pts_in_buffer]
#     ys = [i[1] for i in pts_in_buffer]
#     x_mean = int(sum(xs)*1.0 / len(xs))
#     y_mean = int(sum(ys)*1.0 / len(ys))
#     grouped_points.append([x_mean, y_mean])

# # method 3 DBSCAN
# # points_dict = {}
# # pred_points_np = np.array(pred_points)
# # clustering = DBSCAN(eps=buffer_size+5, min_samples=1).fit(pred_points_np)
# # print(len(set(clustering.labels_)))

# # for i in range(len(clustering.labels_)):
# #     label = clustering.labels_[i]
# #     if label not in points_dict:
# #         points_dict[label] = [pred_points[i]]
# #     else:
# #         points_dict[label].append(pred_points[i])
        
# # for p in points_dict.keys():
# #     if p == -1:
# #         for i in points_dictp[p]:
# #             grouped_points.append(i)   
# #     else:
# #         pts_in_buffer =points_dict[p]
# #         xs = [i[0] for i in pts_in_buffer]
# #         ys = [i[1] for i in pts_in_buffer]
# #         x_mean = int(sum(xs)*1.0 / len(xs))
# #         y_mean = int(sum(ys)*1.0 / len(ys))
# #         grouped_points.append([x_mean, y_mean])
    
# print(len(pred_points), len(grouped_points))

In [28]:
# grouped_points_np = np.array(grouped_points)
# clustering = DBSCAN(eps=4, min_samples=2).fit(grouped_points_np)
# print(len(set(clustering.labels_)))

# dbscan_points = []
# for p in set(clustering.labels_):
#     if p == -1:
#         for i in range(len(clustering.labels_)):
#             if clustering.labels_[i] == -1:
#                 dbscan_points.append(grouped_points[i])
#     else:
#         temp = []
#         for i in range(len(clustering.labels_)):
#             if clustering.labels_[i] == p:
#                 temp.append(grouped_points[i])
#         xs = [j[0] for j in temp]
#         ys = [j[1] for j in temp]
#         x_mean = int(sum(xs)*1.0 / len(xs))
#         y_mean = int(sum(ys)*1.0 / len(ys))
#         dbscan_points.append([x_mean, y_mean])

In [29]:
# grouped_points_map = copy.deepcopy(image)
# # for i in grouped_points:
# for i in dbscan_points:
#     x_c, y_c = i[0], i[1]
#     pred_points_map = cv2.circle(grouped_points_map, (y_c,x_c), radius, color, thickness)

# grouped_points_path = os.path.join(data_path, mask_name+'_grouped_points.jpeg')
# cv2.imwrite(grouped_points_path, grouped_points_map)