In [None]:
import matplotlib.pyplot as plt
import numpy as np
from numpy.fft import fftfreq, fft, ifft
from skimage.filters import gaussian
from skimage.io import imread
from skimage.transform import rescale, rotate
import warnings
from matplotlib.animation import FuncAnimation
%matplotlib notebook

# Directory that holds the bundled sample scans.
img_name_root = "examples/"
# Full paths of the sample scans (root directory prepended to each file name).
images = [
    img_name_root + file_name
    for file_name in (
        "CT_ScoutView.jpg",
        "CT_ScoutView-large.jpg",
        "Kolo.jpg",
        "Kropka.jpg",
        "Kwadraty2.jpg",
        "Paski2.jpg",
        "SADDLE_PE.JPG",
        "SADDLE_PE-large.JPG",
        "Shepp_logan.jpg",
    )
]

# Position in `images` of the scan selected by default (Shepp_logan.jpg).
image_indx = 8


class Parameters:
    """Validated container for the scan settings chosen in the UI.

    Attributes (all set through ``set_values``):
        alpha: angular step between consecutive projections; ``get_moves`` turns it
            into the angles ``0, alpha, 2*alpha, ...`` below 180.
        emitters_num: requested number of emitter columns in the scanner mask.
        use_filter: whether the sinogram is filtered before back-projection.
        image_name: path of the image to scan.
    """

    def __init__(self, alpha, emitters_num, use_filter, image_name) -> None:
        # Route construction through the same validation as later updates so an
        # instance can never hold values that `set_values` would reject.
        self.set_values(alpha, emitters_num, use_filter, image_name)

    def set_values(self, alpha, emitters_num, use_filter, image_name):
        """Validate the arguments and, only if all are valid, store them.

        Raises:
            ValueError: if ``alpha`` or ``emitters_num`` is not strictly positive
                (NaN included) or ``image_name`` is None. ``ValueError`` is an
                ``Exception`` subclass, so existing ``except Exception`` handlers
                still catch it.
        """
        # `not x > 0` instead of `x <= 0` so that NaN is rejected as well.
        if not alpha > 0:
            raise ValueError("Alpha must be positive")
        if not emitters_num > 0:
            raise ValueError("Emitters num must be positive")
        if image_name is None:
            raise ValueError("Image was not selected")
        self.alpha = alpha
        self.emitters_num = emitters_num
        self.use_filter = use_filter
        self.image_name = image_name


# Default scan settings: angular step 180 / 360 = 0.5, 80 emitters, filtering on,
# and the default sample image.
params = Parameters(
    alpha=180 / 360,
    emitters_num=80,
    use_filter=True,
    image_name=images[image_indx],
)


def draw_image(i, img):
    """Draw `img` in grey-scale into slot `i` (1-based) of a 1x2 pyplot subplot grid."""
    n_rows, n_cols = 1, 2
    plt.subplot(n_rows, n_cols, i)
    plt.imshow(img, cmap=plt.cm.Greys_r)


def make_image_square(img):
    """Return `img` zero-padded (centred) to a square; square input is returned unchanged.

    Args:
        img: 2-D array.

    Returns:
        ``img`` itself when it is already square, otherwise a new array from
        ``pad_image`` whose side equals the larger dimension.

    Raises:
        Exception: if ``img`` is not two-dimensional.
    """
    dims = img.shape
    if len(dims) != 2:
        raise Exception("Wrong shape " + str(dims))

    n_rows, n_cols = dims
    if n_rows != n_cols:
        return pad_image(img, np.max(dims), dims)
    return img


def pad_image(img, required_size, shape=None):
    """Centre `img` on a square zero canvas of side `required_size`.

    Args:
        img: 2-D array to embed.
        required_size: side length of the output canvas. Must be at least as large
            as both dimensions of ``img``.
        shape: dimensions of ``img``. Defaults to ``img.shape``; still accepted
            explicitly for existing callers.

    Returns:
        A new float64 array of shape ``(required_size, required_size)`` with ``img``
        centred in it. When a margin is odd, the extra row/column of zeros ends up
        after the image.

    Raises:
        ValueError: if ``required_size`` is smaller than either image dimension.
            Previously this produced negative slice starts and a wrong or
            broadcast-failing assignment.
    """
    if shape is None:
        shape = img.shape
    n_rows, n_cols = shape[0], shape[1]
    if required_size < n_rows or required_size < n_cols:
        raise ValueError(
            "required_size %s is smaller than image shape %s" % (required_size, tuple(shape))
        )
    result = np.zeros((required_size, required_size))
    top = (required_size - n_rows) // 2
    left = (required_size - n_cols) // 2
    result[top:top + n_rows, left:left + n_cols] = img
    return result


def increase_image(img):
    """Centre `img` on a zero canvas whose side is int(sqrt(2) * larger side).

    An n x n square has diagonal n*sqrt(2), so a centred copy stays (up to the int
    truncation) inside the canvas when rotated about the canvas centre.
    """
    larger_side = max(img.shape)
    canvas_side = int(larger_side * np.sqrt(2))
    return pad_image(img, canvas_side, img.shape)


def create_filter_at(target, start, distance, filter):
    """Stamp a 1-D column profile at regularly spaced column positions of `target`.

    The middle element of ``filter`` (index ``len(filter) // 2``) lands on the centre
    columns ``start, start + distance, start + 2*distance, ...``. Every other tap is
    written at the same stride, shifted by its offset from the middle. Taps are
    written in order, so later taps overwrite earlier ones where columns coincide.

    Args:
        target: 2-D array, modified in place (whole columns are assigned).
        start: column index of the first centre position.
        distance: positive column stride between centres.
        filter: sequence of values making up the profile.
    """
    half_len = len(filter) // 2
    for index, value in enumerate(filter):
        first_col = start + index - half_len
        if first_col < 0:
            # A negative slice start would wrap around to the end of the row.
            # Start instead at the first in-range column of the same stride.
            first_col %= distance
        target[:, first_col::distance] = value


def prepare_tomograph(emitters, dim):
    """Build a ``dim x dim`` mask of regularly spaced, Gaussian-blurred emitter columns.

    The emitter count is capped at ``dim`` (with a warning). Columns are placed every
    ``int(dim / emitters)`` pixels, starting at half of the leftover margin. The mask
    is then blurred with ``gaussian`` and scaled so its maximum is 1.

    NOTE(review): because the spacing is truncated to an int, the number of lit
    columns is ``ceil((dim - start) / spacing)``, which can exceed ``emitters``.
    For example, any ``emitters`` above ``dim / 2`` gives spacing 1 and lights
    every column.
    """
    if emitters > dim:
        warnings.warn("emmiters num was bigger than image dim - used smaller")
        emitters = dim
    mask = np.zeros((dim, dim))
    spacing = int(dim / emitters)
    first_column = int(np.ceil((dim % spacing) / 2))
    create_filter_at(mask, first_column, spacing, [1])
    mask = gaussian(mask)
    return mask / mask.max()


def get_intersection(rotation, image, tomograph, real_dim):
    """Compute one sinogram row: the emitter mask applied to `image` at one angle.

    The mask is rotated by ``rotation + 90`` and multiplied element-wise with
    ``image``. The product is rotated back. Then each of the ``real_dim`` central
    rows (offset ``(len(image) - real_dim) / 2``) is summed and divided by
    ``real_dim``.

    NOTE(review): presumably ``image`` and ``tomograph`` are the equally sized
    padded squares produced by ``increase_image``. Confirm with callers.

    Returns:
        list of ``real_dim`` values.
    """
    angle = rotation + 90
    margin = int((len(image) - real_dim) / 2)
    masked = rotate(tomograph, angle) * image
    aligned = rotate(masked, -angle)
    window = aligned[margin:(margin + real_dim)]
    return [sum(row) / real_dim for row in window]


def make_radon(increased_image, increased_tomograph, real_dim, theta, on_change=None):
    """Build a sinogram with one row of length `real_dim` per angle in `theta`.

    Each row comes from ``get_intersection``. If ``on_change`` is given, it is called
    after every row as ``on_change(si=sinogram)`` with the same, partially filled
    array object.

    Returns:
        float array of shape ``(len(theta), real_dim)``.
    """
    sinogram = np.zeros((len(theta), real_dim))
    for row, angle in enumerate(theta):
        sinogram[row] = get_intersection(angle, increased_image, increased_tomograph, real_dim)
        if on_change is None:
            continue
        on_change(si=sinogram)
    return sinogram


def transform_sinogram(sinogram):
    """Apply a ramp filter (gain ``2*|f|``) along the detector axis of a sinogram.

    Args:
        sinogram: 2-D array with one row per projection angle.

    Returns:
        Array of the same shape with every row filtered independently.
    """
    # Put detector positions on axis 0, one column per projection angle.
    projections = np.rot90(sinogram, k=1)
    n_detectors = projections.shape[0]
    # Zero-pad to a power of two of at least twice the signal length (minimum 64)
    # so the FFT's circular convolution does not wrap the signal onto itself.
    # Empty input skips log2(0).
    if n_detectors > 0:
        padded_len = max(64, int(2 ** np.ceil(np.log2(2 * n_detectors))))
    else:
        padded_len = 64
    pad_width = ((0, padded_len - n_detectors), (0, 0))
    padded = np.pad(projections, pad_width, mode='constant', constant_values=0)
    freqs = fftfreq(padded_len).reshape(-1, 1)
    ramp = 2 * np.abs(freqs)
    filtered = np.real(ifft(fft(padded, axis=0) * ramp, axis=0))[:n_detectors, :]
    return np.rot90(filtered, k=-1)


def inverse_radon(sigmoid, rotations, output_size, on_change=None):
    """Back-project a sinogram and average the contributions.

    For each angle, row ``sigmoid[i]`` (the name is historical; it is a sinogram row,
    presumably of length ``output_size`` as produced by ``make_radon``) is repeated
    down a padded canvas, squared with ``make_image_square``, rotated by
    ``rotations[i]`` and accumulated.

    Args:
        sigmoid: sinogram, indexed by projection number.
        rotations: angles, one per sinogram row used.
        output_size: side of the returned square image.
        on_change: optional callback, invoked after each angle as
            ``on_change(isi=partial)`` with the central window of the accumulator
            (not yet divided by the number of angles).

    Returns:
        ``output_size x output_size`` view of the accumulator, divided by
        ``len(rotations)`` when that is non-zero.
    """
    canvas = increase_image(np.zeros((output_size, output_size)))
    lo = (len(canvas) - output_size) // 2
    hi = lo + output_size
    window = (slice(lo, hi), slice(lo, hi))
    count = len(rotations)
    for idx, angle in enumerate(rotations):
        smear = make_image_square(np.array([sigmoid[idx]] * len(canvas)))
        canvas += rotate(smear, angle)
        if on_change is not None:
            on_change(isi=canvas[window])
    result = canvas[window]

    if count > 0:
        result /= count
    return result


def get_moves(a):
    """Return the projection angles ``0, a, 2a, ...`` covering the interval [0, 180)."""
    count = int(np.ceil(180 / a))
    return [a * step for step in range(count)]


def show_images(original, sinogram, reconstructed):
    """Display the three stages side by side in grey-scale, then call ``plt.show()``."""
    _, panels = plt.subplots(1, 3, figsize=(8, 8))
    for panel, picture in zip(panels, (original, sinogram, reconstructed)):
        panel.imshow(picture, cmap=plt.cm.Greys_r)
    plt.show()


def transform_sinogram_if_enabled(params, sinogram):
    """Filter `sinogram` with ``transform_sinogram`` when ``params.use_filter`` is truthy.

    Otherwise the very same object is returned untouched.
    """
    return transform_sinogram(sinogram) if params.use_filter else sinogram


def read_image(name, scale=0.4):
    """Load an image as grey-scale and downscale it.

    Args:
        name: path of the image file.
        scale: resize factor passed to ``rescale`` (default 0.4, the value that was
            previously hard-coded).

    Returns:
        2-D array. If its maximum exceeds 1, it is divided by 255, which assumes
        0-255 data.
    """
    # `as_grey` was deprecated and later removed from scikit-image; `as_gray` is
    # the supported spelling.
    image = imread(name, as_gray=True)
    image = rescale(image, scale=scale, mode='reflect')
    if np.max(image) > 1:
        # Out-of-place division: an in-place `/=` raises on integer arrays.
        image = image / 255
    return image


def prepare_instance(params):
    """Compute the projection angles and load the configured image as a square array.

    Returns:
        ``(image, theta)`` where ``image`` is the squared image and ``theta`` the
        list from ``get_moves(params.alpha)``.
    """
    angles = get_moves(params.alpha)
    square = make_image_square(read_image(params.image_name))
    return square, angles


class Scanner:
    """Runs a simulated CT scan of one image and animates the intermediate results.

    ``watch_changes`` runs the pipeline in three steps:
    forward projection with ``make_radon``, optional sinogram filtering, then
    back-projection with ``inverse_radon``. Partial results are pushed into
    ``assign`` and redrawn by two ``FuncAnimation`` loops.
    """
    # Not read by any method of this class; kept for code that may reference it.
    update_time = 0.2

    def __init__(self, params) -> None:
        """Load the image described by ``params`` and build the padded scan inputs.

        Args:
            params: object exposing ``alpha``, ``emitters_num``, ``use_filter`` and
                ``image_name`` (a ``Parameters`` instance).
        """
        self.params = params
        self.image, self.theta = prepare_instance(params)
        # Padded copies; these are the rotation inputs used by get_intersection.
        self.increased_image = increase_image(self.image)
        self.tomograph = prepare_tomograph(emitters=params.emitters_num, dim=np.max(self.image.shape))
        self.increased_tomograph = increase_image(self.tomograph)
        # Latest pipeline outputs, filled in through `assign`.
        self.sinogram = None
        self.sinogram_transformed = None
        self.i_sin = None
        # Matplotlib handles, created by `init_chart` / `assign`.
        self.fig = None
        self.ax1 = self.ax2 = self.ax3 = None
        # Not read by any method of this class; kept for compatibility.
        self.last_time = 0
        self.im2 = self.im3 = None
        # "Dirty" flags: True when new data is waiting to be drawn.
        self.im2c = self.im3c = False
        # Keep references to the animations; matplotlib stops an animation whose
        # object gets garbage-collected.
        self.ani1 = self.ani2 = None

    def assign(self, si=None, isi=None, tisi=None):
        """Store new pipeline data and start the matching animation on first use.

        Requires ``init_chart`` to have run, since images are added to its axes.

        Args:
            si: sinogram (partial or complete) from ``make_radon``.
            isi: reconstruction (partial or complete) from ``inverse_radon``.
            tisi: the sinogram actually used for back-projection.
        """
        if si is not None:
            self.sinogram = si
            self.im2c = True
            if self.im2 is None:
                self.im2 = self.ax2.imshow(self.sinogram, cmap=plt.cm.Greys_r, animated=True)
                self.ani1 = FuncAnimation(self.fig, self.update_sin, blit=True, interval=50)
        if isi is not None:
            self.i_sin = isi
            self.im3c = True
            if self.im3 is None:
                self.im3 = self.ax3.imshow(self.i_sin, cmap=plt.cm.Greys_r, animated=True)
                self.ani2 = FuncAnimation(self.fig, self.update_isin, blit=True, interval=50)
        if tisi is not None:
            self.sinogram_transformed = tisi

    def watch_changes(self):
        """Run the whole scan, publishing progress to the chart as it goes.

        Returns:
            The reconstructed image (also stored in ``self.i_sin``).
        """
        self.init_chart()
        sinogram = make_radon(self.increased_image, self.increased_tomograph,
                              len(self.image), self.theta, on_change=self.assign)
        sinogram_transformed = transform_sinogram_if_enabled(self.params, sinogram)
        self.assign(tisi=sinogram_transformed)
        i_sin = inverse_radon(sinogram_transformed, self.theta, len(self.image), on_change=self.assign)
        # inverse_radon divides by the number of angles after its last progress
        # callback; publish the final image so the last frame gets redrawn.
        self.assign(isi=i_sin)
        return i_sin

    def update_sin(self, *f):
        """FuncAnimation callback: redraw the sinogram panel if new data arrived."""
        if self.im2c:
            self.im2c = False
            self.im2.set_array(self.sinogram)
            # imshow fixed the colour limits from the first, almost empty frame;
            # recompute them from the data now shown.
            self.im2.autoscale()
        return self.im2,

    def update_isin(self, *f):
        """FuncAnimation callback: redraw the reconstruction panel if new data arrived."""
        if self.im3c:
            self.im3c = False
            self.im3.set_array(self.i_sin)
            # Same as update_sin: refresh colour limits for the growing image.
            self.im3.autoscale()
        return self.im3,

    def init_chart(self):
        """Create the 1x3 figure and draw the original image in the first panel."""
        self.fig, (self.ax1, self.ax2, self.ax3) = plt.subplots(1, 3, figsize=(8, 8))
        self.ax1.set_title("Original")
        self.ax2.set_title("Sinogram")
        self.ax3.set_title("Reconstruction")
        self.ax1.imshow(self.image, cmap=plt.cm.Greys_r)
        plt.show()

In [None]:
from utils import SelectFilesButton
%matplotlib notebook
import ipywidgets as widgets
import traceback

In [None]:
# Input widgets, pre-filled from the module-level `params` defaults.
file_select = SelectFilesButton()

emitters_btn = widgets.IntText(description='Emitters number', value=params.emitters_num, disabled=False)
alpha_btn = widgets.FloatText(description='Alpha value', value=params.alpha, disabled=False)
use_filter_cbx = widgets.Checkbox(description='Use sinogram filter', value=params.use_filter, disabled=False)
# NOTE(review): run_task does not read this checkbox's value.
interactive_btn = widgets.Checkbox(description='Use interactive mode', value=False, disabled=False)

run_btn = widgets.Button(
    description='Load',
    tooltip='Run',
    button_style='info',
    icon='check',
    disabled=False,
)


def run_task(e):
    """Button callback: read the widget values, validate them and run a scan.

    Exceptions are printed as a traceback in the output instead of propagating
    into the widget event loop. The button is always re-enabled afterwards.

    Args:
        e: the clicked button (unused).
    """
    try:
        run_btn.disabled = True
        # NOTE(review): presumably `files` is a list of selected paths. An empty
        # selection previously raised IndexError before `set_values` could report
        # "Image was not selected".
        files = file_select.files
        image_name = files[0] if files is not None and len(files) > 0 else None
        params.set_values(alpha_btn.value, emitters_btn.value, use_filter_cbx.value, image_name)
        # NOTE(review): `interactive_btn` is not consulted; watch_changes always
        # streams progress to the chart.
        scanner = Scanner(params)
        scanner.watch_changes()
    except Exception:
        traceback.print_exc()
    finally:
        run_btn.disabled = False


run_btn.on_click(run_task)


In [None]:
# Control panel: inputs on the first row, options and the run button on the second.
widgets.VBox(children=[
    widgets.HBox(children=[file_select, alpha_btn, emitters_btn]),
    widgets.HBox(children=[use_filter_cbx, interactive_btn, run_btn]),
])