In [1]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

from tensorflow.keras.layers import Dense, Flatten, Input, Conv2D, MaxPooling2D, Dropout
from tensorflow.keras.models import Model, Sequential
from osgeo import osr, ogr, gdal
from tqdm import tqdm
from PIL import Image

import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd
import datetime, os
import numpy as np

import tensorflow.keras

In [2]:
# Parameters
path_train = "files 3/train_3.png"    # input image that is tiled into model inputs
path_target = "files 3/target_3.png"  # target image, tiled the same way into labels

size = 32  # tile edge length in pixels; every sample is a size x size patch

# Fix RNG seeds so weight init, dropout and per-epoch shuffling are repeatable
# across "Restart & Run All" (some GPU kernels may still be nondeterministic).
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [3]:
def split_image(path, size, save_path=None):
    """Cut a grayscale copy of an image into non-overlapping square tiles.

    The image at ``path`` is converted to 8-bit grayscale ("L" mode) and cut
    into ``size`` x ``size`` tiles. Any right/bottom remainder narrower than
    ``size`` pixels is discarded. Tiles are ordered column by column: every
    tile of the first tile-column from top to bottom, then the next column.

    Args:
        path: Path of the image file to read.
        size: Edge length of each square tile, in pixels.
        save_path: Optional prefix. When truthy, each tile is also written to
            ``save_path + 'image<k>.png'``, where ``k`` is the tile's index in
            the returned array. The prefix is concatenated as-is, so include a
            trailing separator if it names a directory.

    Returns:
        ``(tiles, [n_cols, n_rows])`` where ``tiles`` is an integer array of
        shape ``(n_cols * n_rows, size, size, 1)`` and ``n_cols`` / ``n_rows``
        are the numbers of whole tiles horizontally / vertically. If the image
        is smaller than one tile, ``tiles`` has shape ``(0, size, size, 1)``.
    """
    # Close the file handle deterministically; convert() returns a loaded,
    # independent image, so it stays usable after the `with` block.
    with Image.open(path) as source:
        img = source.convert('L')
    width, height = img.size

    n_cols = width // size   # whole tiles that fit horizontally
    n_rows = height // size  # whole tiles that fit vertically

    tiles = []
    for i in tqdm(range(n_cols)):
        for j in range(n_rows):
            tile = img.crop(box=(size * i, size * j, size * (i + 1), size * (j + 1)))
            if save_path:
                tile.save(save_path + 'image{}.png'.format(str(j + i * n_rows)))
            # getdata() yields pixel values row by row; reshape to (H, W, 1).
            tiles.append(np.reshape(np.array(tile.getdata()), (size, size, 1)))

    # Stack once at the end: np.append inside the loop copies the whole
    # accumulated array for every tile, which is quadratic in the tile count.
    if tiles:
        all_data = np.stack(tiles)
    else:
        all_data = np.empty((0, size, size, 1), dtype=int)

    print("Data shape is", np.shape(all_data))
    return all_data, [n_cols, n_rows]

In [4]:
# Tile the target image into size x size label patches (grid dims kept too)
target, target_size = split_image(path=path_target, size=size)

100%|██████████| 27/27 [00:00<00:00, 192.76it/s]

Data shape is (378, 32, 32, 1)





In [5]:
# Tile the training image into size x size input patches (grid dims kept too)
train, train_size = split_image(path=path_train, size=size)

100%|██████████| 27/27 [00:00<00:00, 263.91it/s]

Data shape is (378, 32, 32, 1)





In [6]:
# Raw pixel range of each array before normalization: "<min> <max>"
print(f"{train.min()} {train.max()}")
print(f"{target.min()} {target.max()}")

0 192
0 255


In [7]:
# Scale pixels to floats. The target is divided by the fixed 8-bit maximum
# (its raw range is 0..255, per the output above), and the input by its own
# maximum, so both end up in [0, 1].
# NOTE: this cell rebinds `target` and `train`; re-running it without first
# re-running the split cells would divide a second time.
target = target / 255.0
train_peak = train.max()
# Guard an all-black input image: 0 / 0 would fill `train` with NaN.
train = train / train_peak if train_peak > 0 else train.astype(float)

In [8]:
# One flat vector of size*size pixel labels per tile, matching the model's
# Dense(size*size) output layer.
target = target.reshape(-1, size * size)

In [9]:
# Class balance of the target pixels: count of 1.0 pixels divided by count of
# 0.0 pixels. The dict is not named `all`, which would shadow builtin all().
unique, counts = np.unique(target, return_counts=True)
pixel_counts = dict(zip(unique, counts))
n_positive = pixel_counts.get(1.0, 0)  # missing class -> 0 instead of KeyError
n_negative = pixel_counts.get(0.0, 0)
n_positive / n_negative if n_negative else float('nan')

0.01605760472094104

In [10]:
# Functional-API variant of the tile -> pixel-mask network.
# NOTE: `model` is rebound by the Sequential definition in the next cell, so
# this network is not the one that gets compiled and trained afterwards.
inp = Input(shape=(size, size, 1))
x = Conv2D(256, (14, 14), padding='same', activation=tf.nn.relu)(inp)
x = MaxPooling2D((2, 2), strides=2)(x)  # halves height and width
x = Conv2D(128, (9, 9), padding='same', activation=tf.nn.relu)(x)
x = MaxPooling2D((2, 2), strides=2)(x)  # halves height and width
x = Conv2D(90, (7, 7), padding='same', activation=tf.nn.relu)(x)
x = MaxPooling2D((2, 2), strides=2)(x)  # halves height and width
x = Flatten()(x)
x = Dense(130, activation=tf.nn.relu)(x)
x = Dropout(0.5)(x)
x = Dense(130, activation=tf.nn.relu)(x)
x = Dropout(0.5)(x)
# One independent probability per output pixel: sigmoid, not softmax.
# Softmax forces the size*size outputs to sum to 1, so a tile with several
# foreground pixels could never be matched under per-pixel binary cross-entropy.
out = Dense(size * size, activation=tf.nn.sigmoid)(x)

model = Model(inputs=inp, outputs=out)

In [11]:
# Sequential tile -> pixel-mask network; this rebinds `model`.
model = Sequential([
    Conv2D(512, (7, 7), padding='same', activation=tf.nn.relu,
           input_shape=(size, size, 1)),
    MaxPooling2D((2, 2), strides=2),  # halves height and width
    Conv2D(256, (5, 5), padding='same', activation=tf.nn.relu),
    MaxPooling2D((2, 2), strides=2),  # halves height and width
    Flatten(),
    Dense(130, activation=tf.nn.relu),
    Dense(130, activation=tf.nn.relu),
    # One independent probability per pixel (sigmoid). Softmax would force the
    # size*size outputs to sum to 1, which cannot represent masks with several
    # foreground pixels under binary cross-entropy.
    Dense(size * size, activation=tf.nn.sigmoid, dtype='float64'),
])


In [12]:
from tensorflow.keras.optimizers import Adam

# 1e-3 is Adam's standard default. The previous 0.1 is ~100x larger and tends
# to make training diverge or leave ReLU units permanently dead.
LEARNING_RATE = 1e-3
opt = Adam(learning_rate=LEARNING_RATE)

In [13]:
# Each output unit is an independent 0/1 pixel label -> binary cross-entropy.
# Other losses considered: categorical_crossentropy, mean_squared_error.
# NOTE: with ~1.6% positive pixels (ratio printed above), binary_accuracy is
# dominated by the background class and can look high for an all-zero output.
model.compile(
    loss='binary_crossentropy',
    optimizer=opt,
    metrics=['binary_accuracy'],
)

In [14]:
# A fresh TensorBoard run directory per execution, named by launch timestamp;
# weight histograms are recorded every epoch.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join("logs", run_stamp)
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir, histogram_freq=1)

In [15]:
# Train for 30 epochs, reshuffling samples each epoch and logging to TensorBoard;
# the returned History object is kept in `his`.
his = model.fit(
    x=train,
    y=target,
    epochs=30,
    shuffle=True,
    callbacks=[tensorboard_callback],
)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


In [16]:
# Show the TensorBoard dashboard inline, reading every run written under logs/
%tensorboard --logdir logs

Reusing TensorBoard on port 6006 (pid 663765), started 0:09:19 ago. (Use '!kill 663765' to kill it.)