In [1]:
import os
os.environ["PAT_MAXIMUM_BATCH_SIZE"] = "5"

import unittest

from os.path import join

import numpy as np
from matplotlib import pyplot as plt

import patato.data
from patato.data import get_msot_time_series_example
from patato.io.msot_data import PAData
from patato.processing.preprocessing_algorithm import DefaultMSOTPreProcessor
from patato.recon import OpenCLBackprojection
from patato.recon.backprojection_reference import ReferenceBackprojection

In [2]:
class BackprojectionTest(unittest.TestCase):
    def setUp(self) -> None:
        self.pa = get_msot_time_series_example("so2")[0:1, 0:1]

        self.preproc = DefaultMSOTPreProcessor(time_factor=1, detector_factor=2)
        self.filtered_time_series, self.new_settings, _ = self.preproc.run(self.pa.get_time_series(), self.pa)

    def _test_backprojector(self, reconstructor_class):
        reconstructor = reconstructor_class([333, 334, 1], [0.025, 0.025, 1.])
        r, _, _ = reconstructor.run(self.filtered_time_series, self.pa, **self.new_settings)
        self.assertEqual(r.shape, (1, 1, 1, 334, 333))
        self.assertAlmostEqual(np.mean(r.values), 315.78, 2)
        
        reconstructor = reconstructor_class([1, 333, 334], [1., 0.025, 0.025])
        r, _, _ = reconstructor.run(self.filtered_time_series, self.pa, **self.new_settings)
        self.assertEqual(r.shape, (1, 1, 334, 333, 1))
        
        reconstructor = reconstructor_class([334, 1, 333], [0.025, 1., 0.025])
        r, _, _ = reconstructor.run(self.filtered_time_series, self.pa, **self.new_settings)
        self.assertEqual(r.shape, (1, 1, 333, 1, 334))
        
    
    def test_reference_reconstruction(self):
        self._test_backprojector(ReferenceBackprojection)
        

    def test_opencl_reconstruction(self):
        try:
            import pyopencl
        except ImportError:
            return  # Skip test if pyopencl is not installed
        
        self._test_backprojector(OpenCLBackprojection)

In [3]:
# Drive the test case by hand instead of through a unittest runner: the fixture
# is built once by setUp() and shared by both tests, and any assertion failure
# raises straight into the notebook cell.
t = BackprojectionTest()
t.setUp()
for test_name in ("test_opencl_reconstruction", "test_reference_reconstruction"):
    getattr(t, test_name)()

