In [1]:
# Colab only: mount Google Drive so the project under /content/drive/MyDrive
# (whose src folder is imported from below) is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Change to your `src` folder path

In [2]:
cd /content/drive/MyDrive/mvi-sp/src

/content/drive/MyDrive/mvi-sp/src


In [3]:
!pip install -q -U keras-tuner

## Set the flags below to choose which networks to tune or train

In [4]:
# Pipeline switches: only the cells below whose flag is True do any work
# (U-Net hyper-parameter search, U-Net training, GAN training).
TUNE_UNET, TRAIN_UNET, TRAIN_GAN = False, False, True

In [5]:
import gc
import os

import keras_tuner as kt
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import Adam

from config import *
from models import small_test, u_net, u_net_bn, DCGAN

## Ran into memory issues while training with a generator; the callback below may help

In [6]:
class MemoryCallback(keras.callbacks.Callback):
    """Keras callback that tries to release memory at the end of every epoch.

    Added as a workaround for memory growth seen while training: it runs the
    Python garbage collector and then clears Keras' global session state.
    """

    def on_epoch_end(self, epoch, logs=None):
        # Collect unreachable Python objects first...
        gc.collect()
        # ...then reset Keras' global state. `keras` here is tensorflow.keras,
        # so this is the same call as tf.keras.backend.clear_session().
        # NOTE(review): clear_session is documented for use between model
        # builds; calling it mid-fit is a workaround - confirm it does not
        # interfere with the model being trained.
        keras.backend.clear_session()

In [7]:
def get_unet(hp):
    """Hypermodel for Keras Tuner: build and compile ``u_net_bn()``.

    Args:
        hp: keras_tuner ``HyperParameters``; only the Adam learning rate is
            searched, over {1e-3, 1e-4, 1e-5}.

    Returns:
        The compiled model (MSE loss, mean absolute error reported as metric).
    """
    model = u_net_bn()
    learning_rate = hp.Choice('learning_rate', values=[1e-3, 1e-4, 1e-5])
    # The targets are scaled pixel intensities (a regression problem), so report
    # MAE; 'accuracy' is a classification metric and is meaningless here.
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='mse',
        metrics=['mae'],
    )
    return model

## Data normalization for U-Net

In [8]:
def _load_unet_partitions(directory):
    """Load every .npz partition in `directory`, scaling pixel values by 1/255.

    Each partition must contain arrays 'X' and 'Y'. Files are read in sorted
    order so the tail used by ``validation_split`` is the same on every run
    (``os.listdir`` order is arbitrary).

    Returns:
        (X, Y): float32 arrays concatenated along axis 0.

    Raises:
        FileNotFoundError: if `directory` contains no partitions.
    """
    xs, ys = [], []
    for name in sorted(os.listdir(directory)):
        # NpzFile keeps the file open; the context manager closes it.
        with np.load(os.path.join(directory, name)) as part:
            # Cast each part right after scaling so only one part at a time is
            # held in float64, instead of the whole dataset.
            xs.append((part['X'] / 255.0).astype('float32'))
            ys.append((part['Y'] / 255.0).astype('float32'))
    if not xs:
        raise FileNotFoundError(f'No partitions found in {directory!r}')
    # Concatenate once: growing an array with np.concatenate inside the loop
    # re-copies all previously loaded data on every iteration.
    return np.concatenate(xs), np.concatenate(ys)


X, Y = _load_unet_partitions(PARTITIONS_PATH)

In [9]:
if TUNE_UNET:
    BATCH_SIZE, EPOCHS = 16, 50

    # Random search over the hyper-parameters declared in get_unet, ranked by
    # validation loss. overwrite=True starts a fresh search instead of
    # reloading a previous ./unettuner project.
    tuner = kt.RandomSearch(
        hypermodel=get_unet,
        objective='val_loss',
        overwrite=True,
        directory='.',
        project_name='unettuner',
    )

    mc, es = (
        MemoryCallback(),
        EarlyStopping(monitor='val_loss', patience=2, min_delta=0.01),
    )
    tuner.search(
        x=X, y=Y,
        epochs=EPOCHS, batch_size=BATCH_SIZE,
        callbacks=[mc, es],
        validation_split=0.2,
    )

    # Show only the single best trial, then release the tuner.
    tuner.results_summary(1)
    del tuner



In [11]:
if TRAIN_UNET:
    BATCH_SIZE = 16
    EPOCHS = 100
    callbacks = [
        MemoryCallback(),
        # Writes the full model to a new file (named by val_loss) each time the
        # validation loss improves on the best value seen so far.
        ModelCheckpoint(
            filepath=os.path.join(WEIGHTS_PATH, 'unetbn-{val_loss:.4f}.hdf5'),
            monitor='val_loss',
            save_best_only=True,
            mode='min',
            save_freq='epoch'
        ),
        # NOTE(review): with MSE on 1/255-scaled targets the loss may already be
        # below 0.01, in which case this min_delta stops training very early -
        # confirm the threshold against observed loss values.
        EarlyStopping(
            monitor='val_loss',
            patience=3,
            min_delta=0.01
        )
    ]

    model = u_net_bn()
    optimizer = Adam(learning_rate=0.0001)
    # Pixel-intensity regression: report MAE; 'accuracy' is a classification
    # metric and carries no meaning for an MSE target.
    model.compile(optimizer=optimizer, loss='mse', metrics=['mae'])

    history = model.fit(
        x=X,
        y=Y,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        # Keras takes the last 20% of X/Y (before shuffling) for validation.
        validation_split=0.2,
        callbacks=callbacks,
        verbose=1
    )

    del model

In [12]:
if TRAIN_GAN:
    # Release the U-Net copy of the data before loading the GAN copy so both
    # are never resident at once. Rebinding to None frees the arrays just like
    # `del`, but also works on a fresh kernel where X/Y were never defined.
    X = None
    Y = None
    gc.collect()

    # Scale with (v - 127.5) / 127.5, which maps values in [0, 255] to [-1, 1].
    # Each part is cast to float32 as soon as it is scaled, and all parts are
    # concatenated once at the end: growing the arrays with np.concatenate in
    # the loop re-copied everything per iteration and held float64 throughout.
    X_parts, Y_parts = [], []
    # Sorted so the validation tail chosen by validation_split is reproducible.
    for part in sorted(os.listdir(PARTITIONS_PATH)):
        # NpzFile keeps the file open; the context manager closes it.
        with np.load(os.path.join(PARTITIONS_PATH, part)) as arr:
            X_parts.append(((arr['X'] - 127.5) / 127.5).astype('float32'))
            Y_parts.append(((arr['Y'] - 127.5) / 127.5).astype('float32'))
    if not X_parts:
        raise FileNotFoundError(f'No partitions found in {PARTITIONS_PATH!r}')
    X = np.concatenate(X_parts)
    Y = np.concatenate(Y_parts)
    # These lists live at notebook (module) scope; drop them so they do not keep
    # a second copy of the dataset alive during training.
    del X_parts, Y_parts

    BATCH_SIZE = 8
    EPOCHS = 50
    callbacks = [
        MemoryCallback(),
        # Saves weights after every epoch, named by that epoch's val_g_loss.
        # save_best_only is not set, so `mode` (which only applies together
        # with save_best_only) has no effect here.
        ModelCheckpoint(
            filepath=os.path.join(WEIGHTS_PATH, 'gan-{val_g_loss:.4f}.hdf5'),
            monitor='val_g_loss',
            save_weights_only=True,
            mode='min',
            save_freq='epoch'
        )
    ]

    model = DCGAN()
    # DCGAN.compile is called without arguments; presumably it supplies its own
    # optimizers/losses - confirm in models.py.
    model.compile()
    history = model.fit(
        x=X,
        y=Y,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_split=0.2,
        callbacks=callbacks,
        verbose=1
    )

    del model

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50

KeyboardInterrupt: ignored