In [14]:
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import config

from model_structure_3d_2 import UNet3D

In [6]:
# Load the trained 3D U-Net weights onto the CPU and switch the model to evaluation mode.
model = UNet3D()
# NOTE(review): torch.load unpickles the checkpoint file; only load trusted files
# (torch.load(..., weights_only=True) restricts loading to tensors on newer torch).
model.load_state_dict(torch.load("./saved_models/model_for_vasc_3d2839818.pth", map_location="cpu"))
model.eval()

# Directory paths for input and output data
input_dir = "./preparations/data/indata"
output_dir = "./preparations/data/outdata"

# os.listdir returns entries in arbitrary order; sort so the list is deterministic
# and index-aligned with a sorted listing of output_dir.
image_files = sorted(os.listdir(input_dir))

def get_prediction(img, margin, net=None):
    """Run the segmentation network on one image tile and return its raw output.

    Args:
        img: numpy array with pixel values in 0..255 (divided by 255 here). A
            leading batch axis is added before the forward pass; in this notebook
            the caller passes a (1, NUM_PICS, h, w) tile, so the network receives
            a (1, 1, NUM_PICS, h, w) float32 tensor.
        margin: must be 0; any other value raises ``ValueError``.
        net: callable to run instead of the module-level ``model`` (optional).

    Returns:
        The network output converted to a numpy array.

    Raises:
        ValueError: if ``margin`` is not 0.
    """
    if margin != 0:
        raise ValueError(f"only margin=0 is supported, got margin={margin!r}")
    if net is None:
        net = model  # default: the model loaded earlier in the notebook
    with torch.no_grad():
        input_tensor = torch.from_numpy(img / 255).unsqueeze(0).float()
        output_np = net(input_tensor).numpy()
    return output_np


In [23]:
# Predict NUM_PICS-slice stacks tile by tile, compare the prediction with the ground
# truth and write colour-coded blends to ./blend_images.
# Colours are BGR because the blends are written with cv2.imwrite.
trueColor = (0, 255, 0)    # ground truth and prediction both foreground
falseColor = (0, 0, 255)   # ground truth and prediction disagree

image_stacks = 10

path_in = "./preparations/data/indata"
path_out = "./preparations/data/outdata"
blend_dir = "./blend_images"

# os.listdir order is arbitrary; sort both listings so that paths_in[k] and
# paths_out[k] are the same slice and consecutive indices are consecutive slices.
paths_in = sorted(os.listdir(path_in))
paths_out = sorted(os.listdir(path_out))

# cv2.imwrite does not raise when the directory is missing, it just returns False.
os.makedirs(blend_dir, exist_ok=True)

first_image_index = 100
index = first_image_index
step = 64  # stride between tile origins, in pixels


def _read_gray(path):
    """Read ``path`` as a grayscale image, raising instead of returning None."""
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"could not read image: {path}")
    return img


for k in range(first_image_index, first_image_index + image_stacks * config.NUM_PICS, config.NUM_PICS):
    # Stack NUM_PICS consecutive slices into arrays of shape (1, NUM_PICS, H, W).
    images_in = np.array([[_read_gray(os.path.join(path_in, paths_in[k + i])) for i in range(config.NUM_PICS)]])
    images_out = np.array([[_read_gray(os.path.join(path_out, paths_out[k + i])) for i in range(config.NUM_PICS)]])

    shape = images_in.shape
    # Per-pixel votes: tiles that predicted foreground / tiles that covered the pixel.
    count_times = np.zeros(shape, dtype=np.int32)
    total_times = np.zeros(shape, dtype=np.int32)

    for i in tqdm(range(0, shape[2] - shape[2] % step, step), desc="Processing"):
        for j in range(0, shape[3] - shape[3] % step, step):
            window = (slice(None), slice(None), slice(i, i + config.HEIGHT), slice(j, j + config.WIDTH))
            array_tmp = get_prediction(images_in[window], 0)[0]
            count_times[window] += (array_tmp > 0.5)
            total_times[window] += 1

    # Fraction of covering tiles that voted foreground. Pixels no tile covered get 0;
    # the former (count + 1e-4) / (total + 1e-4) turned 0/0 into 1.0 (foreground).
    confidence = np.divide(count_times, total_times, out=np.zeros(shape, dtype=np.float64), where=total_times > 0)
    images_predict = np.floor(confidence * 255).astype(np.uint8)[0]

    # Gray input replicated to 3 channels: (1, NUM_PICS, H, W, 3).
    gray_bgr = np.stack([images_in] * 3, axis=-1)
    both_background = (images_out == 0) & (images_predict == 0)
    both_foreground = (images_out == 255) & (images_predict == 255)
    # Partial tile agreement (0 < prediction < 255) also counts as a mismatch.
    mismatch = images_out != images_predict

    # Start from black; pixels matched by none of the rules below stay black.
    img_blend = np.zeros(images_in.shape + (3,), dtype=np.uint8)
    # Rule 1: both background -> keep the original gray value.
    img_blend[both_background] = gray_bgr[both_background]
    # Rule 2: both foreground -> gray value scaled onto trueColor.
    img_blend[both_foreground] = np.array(trueColor) * gray_bgr[both_foreground] / 255
    # Rule 3: disagreement -> gray value scaled onto falseColor.
    img_blend[mismatch] = np.array(falseColor) * gray_bgr[mismatch] / 255

    for slice_bgr in img_blend[0]:
        out_path = os.path.join(blend_dir, str(1000000 + index) + ".png")
        if not cv2.imwrite(out_path, slice_bgr):
            raise IOError(f"could not write image: {out_path}")
        index += 1



    


Processing: 100%|██████████| 21/21 [01:16<00:00,  3.63s/it]
Processing: 100%|██████████| 21/21 [01:14<00:00,  3.55s/it]
Processing: 100%|██████████| 21/21 [01:16<00:00,  3.67s/it]
Processing: 100%|██████████| 21/21 [01:12<00:00,  3.44s/it]
Processing: 100%|██████████| 21/21 [01:11<00:00,  3.42s/it]
Processing: 100%|██████████| 21/21 [01:17<00:00,  3.69s/it]
Processing: 100%|██████████| 21/21 [01:15<00:00,  3.59s/it]
Processing: 100%|██████████| 21/21 [01:16<00:00,  3.63s/it]
Processing: 100%|██████████| 21/21 [01:15<00:00,  3.58s/it]
Processing: 100%|██████████| 21/21 [01:18<00:00,  3.76s/it]


In [22]:
# Show the first blended slice of the last processed stack inline.
# cv2.imshow + waitKey(0) opens a native window and blocks the kernel until a key
# press (and needs a GUI backend); matplotlib renders inside the notebook instead.
# OpenCV images are BGR, so reverse the channel axis to RGB for matplotlib.
print(img_blend.shape)
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(img_blend[0][0][..., ::-1])
ax.set_title("img_blend")
ax.axis("off")
plt.show()

(1, 16, 1356, 1356, 3)


In [None]:
def create_combined_image(img_in, img_out, true_color=(0, 255, 0)):
    """Overlay a binary mask on a grayscale image as a 3-channel uint8 image.

    Pixels where ``img_out == 255`` get the gray value of ``img_in`` scaled onto
    ``true_color`` (``gray * true_color // 255`` per channel). All other pixels
    keep the gray value in every channel.

    Args:
        img_in: 2-D grayscale image, expected 8-bit (e.g. from
            ``cv2.imread(..., cv2.IMREAD_GRAYSCALE)``). Need not be square.
        img_out: 2-D mask with the same shape as ``img_in``; 255 marks foreground.
        true_color: per-channel colour for foreground pixels (BGR order when the
            result is written with OpenCV). Defaults to green.

    Returns:
        ``np.uint8`` array of shape ``img_in.shape + (3,)``.

    Raises:
        ValueError: if ``img_in`` and ``img_out`` have different shapes.
    """
    gray = np.asarray(img_in)
    mask = np.asarray(img_out) == 255
    if gray.shape != mask.shape:
        raise ValueError(f"img_in and img_out shapes differ: {gray.shape} vs {mask.shape}")

    # Gray value replicated into all three channels (the non-foreground case).
    combined_img = np.repeat(gray[..., np.newaxis], 3, axis=-1).astype(np.uint8)
    # Widen before multiplying: in uint8, gray * 255 wraps around (200 * 255 -> 56).
    tinted = gray[..., np.newaxis].astype(np.uint32) * np.asarray(true_color, dtype=np.uint32) // 255
    combined_img[mask] = tinted[mask]
    return combined_img

trueColor = (0, 255, 0)
falseColor = (0, 0, 255)

image_stacks = 10

path_in = "./preparations/data/indata"
path_out = "./preparations/data/outdata"

# os.listdir order is arbitrary; sort both listings so paths_in[k] and
# paths_out[k] refer to the same slice.
paths_in = sorted(os.listdir(path_in))
paths_out = sorted(os.listdir(path_out))

first_image_index = 100
index = first_image_index

for k in range(first_image_index, first_image_index + image_stacks * config.NUM_PICS):
    img_in = cv2.imread(os.path.join(path_in, paths_in[k]), cv2.IMREAD_GRAYSCALE)
    img_out = cv2.imread(os.path.join(path_out, paths_out[k]), cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None instead of raising when a file cannot be read.
    if img_in is None or img_out is None:
        raise FileNotFoundError(f"could not read slice pair {paths_in[k]!r} / {paths_out[k]!r}")
    
    