<a href="https://colab.research.google.com/github/ariaberlian/rbm_sr/blob/main/rbm_sr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Python v3.6.12

##### Run This to Install #####

# ! pip install -r requirements.txt

## !git clone https://github.com/albertbup/deep-belief-network.git
## !pip install -r "deep-belief-network/requirements.txt"
## !mv "deep-belief-network" "deep_belief_network"

## """
## - add this to deep_belief_network/dbn/tensorflow/models.py (the clone is renamed above):
##     import tensorflow.compat.v1 as tf
##     tf.disable_v2_behavior()
## """

In [None]:
from deep_belief_network.dbn.tensorflow.models import UnsupervisedDBN # use "from dbn.tensorflow import SupervisedDBNClassification" for computations on TensorFlow
from utils.data_processing import DataProcessing
from utils.image_file_util import *
from utils.scoring import *
from utils.visualizer import *
import tracemalloc

## DBN

In [None]:
## Training DBN
def train(model, input: np.ndarray, model_name: str, repetitions: int, beta: float, interpolation_factor: float):
    """Fit ``model`` on ``input``, optionally repeating with refined training data.

    Repetition 1 fits ``model`` on ``input`` and saves it to
    ``model/{model_name}_1.h5``. Each further repetition (up to
    ``repetitions``) reloads the previous checkpoint and transforms the current
    training data with it. It then passes the data (``u``) and the transform
    output (``r``) through ``dp.proccess_output`` and fits and saves the next
    checkpoint ``model/{model_name}_{rep}.h5``. A ``repetitions`` value below 1
    still performs the first repetition.

    Args:
        model: Model exposing ``fit``, ``transform``, ``save`` and ``load``.
        input: Training data passed to ``model.fit``.
        model_name: Base file name of the checkpoints.
        repetitions: Total number of fit/save rounds.
        beta: ``beta`` argument forwarded to ``proccess_output``.
        interpolation_factor: ``s`` argument forwarded to ``proccess_output``.

    Returns:
        The model after its last ``fit`` call. (Previously ``None`` was
        returned; callers that ignore the result are unaffected.)
    """
    import os

    dp = DataProcessing()

    def checkpoint_path(rep):
        return f"model/{model_name}_{rep}.h5"

    # model.save writes straight to this path; make sure the folder exists.
    os.makedirs("model", exist_ok=True)

    visualize_histogram(input, title="Training Image")

    curr_rep = 1
    while True:
        print("Repetisi saat ini: ", curr_rep)
        model.fit(input)
        model.save(checkpoint_path(curr_rep))
        if curr_rep >= repetitions:
            return model

        # Reload the checkpoint just written and use it to refine the
        # training data for the next repetition.
        model = model.load(checkpoint_path(curr_rep))
        r = model.transform(input)
        print("Shape Transformed: ", r.shape)
        input = dp.proccess_output(u=input, r=r, beta=beta, s=interpolation_factor)
        visualize_histogram(input, title=f"Training Image: After transform ({curr_rep})")
        del r  # release the transform output before the next fit

        curr_rep += 1

## Test DBN
def test(test_image, test_reference_image, model, patch_size: tuple, stride: tuple, interpolation_factor: int, beta: float = 1):
    """Upscale ``test_image`` with ``model`` and score it against a reference image.

    Pipeline:
    1. Interpolate ``test_image`` by ``interpolation_factor``.
    2. Split it into patches and preprocess them for the RBM.
    3. Run ``model.transform`` on the patches.
    4. Post-process with ``dp.proccess_output`` and ``dp.inverse_preprocess``.
    5. Reconstruct an image with the reference image's shape.
    6. Print PSNR and SSIM (as a percentage) between the reference and the
       reconstruction.

    Intermediate images, patches and histograms are visualized along the way.

    Args:
        test_image: Input image to be upscaled.
        test_reference_image: Image the reconstruction is compared with; its
            shape is also the reconstruction target shape.
        model: Trained model exposing ``transform``.
        patch_size: Patch size used for patch extraction and reconstruction.
        stride: Stride used for patch extraction and reconstruction.
        interpolation_factor: Interpolation factor; also forwarded to
            ``proccess_output``.
        beta: ``beta`` forwarded to ``proccess_output``. Defaults to 1, the
            value this function used to hard-code. Pass the ``beta`` used in
            ``train`` to keep post-processing consistent.

    Returns:
        ``(psnr_value, ssim_value)`` with SSIM scaled to percent. (Previously
        ``None`` was returned; existing callers are unaffected.)
    """
    dp = DataProcessing()

    desired_shape = test_reference_image.shape

    test_image = dp.interpolate(test_image, interpolation_factor=interpolation_factor)
    visualize_image(test_image, title="Test Image after interpolation")
    visualize_histogram(test_image, title="Test Image after interpolation", range=(0,256))

    test_patches = dp.get_patches(test_image, patch_size=patch_size, stride=stride)

    # Process data for rbm
    test_patches = dp.preprocess_for_rbm(test_patches)

    # Infer test to model
    result = model.transform(test_patches)

    test_patches = dp.inverse_preprocess(
        dp.proccess_output(test_patches, result, beta, interpolation_factor),
        # NOTE(review): channel count is hard-coded to 3 — confirm this
        # matches the images used (e.g. grayscale inputs).
        original_patch_shape=(patch_size[0], patch_size[1], 3),
    )
    del result  # free the model output before reconstruction

    visualize_patches(test_patches, title="Test Patches Example", visualize_size=(6,6))

    reconstruct_image = dp.reconstruct_from_patches(test_patches, original_shape=desired_shape, patch_size=patch_size, stride=stride)
    del test_patches

    visualize_histogram_compare(original_image=test_reference_image, reconstruct_image=reconstruct_image)
    psnr_value = calculate_psnr(test_reference_image, reconstruct_image)
    ssim_value = calculate_ssim(test_reference_image, reconstruct_image) * 100

    # Decimal-comma output: swap "." and "," so "1,234.567" becomes
    # "1.234,567". A plain .replace(".", ",") would yield "1,234,567".
    decimal_comma = str.maketrans(".,", ",.")
    print(f"PSNR value: {psnr_value:,.3f} dB".translate(decimal_comma))
    print(f"SSIM value: {ssim_value:,.3f}%".translate(decimal_comma))

    return psnr_value, ssim_value


In [None]:
tracemalloc.start()

#### Patch size and Stride ####
patch_size = (4, 4)
stride = (1, 1)

#### Load Data ####
dp = DataProcessing()

# Load Training Data
training_image = load_image("train/128/ct_lung1_128.png")
visualize_image(training_image, "Training Image")

train_patches = dp.get_patches(training_image, patch_size=patch_size, stride=stride)
visualize_patches(train_patches, "Train patches example")

X_train = dp.preprocess_for_rbm(train_patches)

# Drop references to the intermediates; only X_train is needed from here on.
train_patches = None
training_image = None

#### Training parameter ###
interpolation_factor = 2
beta = 1
Repetitions = 1
lr = 0.01
epoch = 50

# Hidden layer sizes, one entry per layer. This used to be written as
# [24_12_48], which Python reads as the single integer 241248 (underscores
# are digit separators): one hidden layer of 241,248 units instead of three.
layers = [24, 12, 48]
batch_size = 128
activation_function = 'sigmoid'

# Derived from the configuration so the checkpoint name cannot drift from it;
# with the values above this is "model_128_x2_p4_s1_(24_12_48)".
layers_tag = "_".join(str(units) for units in layers)
model_name = f"model_128_x{interpolation_factor}_p{patch_size[0]}_s{stride[0]}_({layers_tag})"

# Models we will use
dbn = UnsupervisedDBN(hidden_layers_structure=layers,
                      batch_size=batch_size,
                      learning_rate_rbm=lr,
                      n_epochs_rbm=epoch,
                      activation_function=activation_function,
                      optimization_algorithm='sgd',)

#### Train and Test ####

snapshot1 = tracemalloc.take_snapshot()
# Comment out this call to only run the tests against existing checkpoints.
train(
    model=dbn,
    input=X_train,
    model_name=model_name,
    beta=beta,
    interpolation_factor=interpolation_factor,
    repetitions=Repetitions,
)

snapshot2 = tracemalloc.take_snapshot()
# Snapshots stay valid after stopping; stop so the tests below aren't traced.
tracemalloc.stop()
top_stats = snapshot2.compare_to(snapshot1, 'lineno')

print("[ Top 10 differences in memory allocation ]")
for stat in top_stats[:10]:
    print(stat)

# Every test case uses the final checkpoint, so load it only once.
dbn = dbn.load(f"model/{model_name}_{Repetitions}.h5")

# (low-resolution test image, high-resolution reference image)
test_cases = [
    ("test/64/ct_lung2_64.png", "test/128/ct_lung2_128.png"),
    ("test/64/ct_lung3_64.png", "test/128/ct_lung3_128.png"),
    ("test/64/ct_lung4_64.png", "test/128/ct_lung4_128.png"),
]

for case_number, (test_path, reference_path) in enumerate(test_cases, start=1):
    print(f"== Test {case_number} ==")
    test_image1 = load_image(test_path)
    visualize_image(test_image1, title="Test Image")

    test_reference_image = load_image(reference_path)

    test(
        test_image=test_image1,
        test_reference_image=test_reference_image,
        model=dbn,
        patch_size=patch_size,
        stride=stride,
        interpolation_factor=interpolation_factor,
    )