In [1]:
import math
import os
import glob
import xml.etree.ElementTree as ET
import array
from pathlib import Path

import numpy as np

In [2]:
import utils2p

In [3]:
data_dir = "/home/jbraun/data/210216_J1xCI9/Fly1/001_xz/2p"
metadata_dir = utils2p.find_metadata_file(data_dir)
raw_dir = utils2p.find_raw_file(data_dir)

In [4]:
metadata = utils2p.Metadata(metadata_dir)

In [23]:
path = os.path.expanduser(os.path.expandvars(raw_dir))
n_time_points = metadata.get_n_time_points()
width = metadata.get_num_x_pixels()
height = metadata.get_num_y_pixels()
n_channels = metadata.get_n_channels()
byte_size = os.stat(path).st_size

In [25]:
assert not byte_size % 1, "File does not have an integer byte length."
byte_size = int(byte_size)

In [26]:
n_z = (
        byte_size / 2 / width / height / n_time_points / n_channels
    )  # divide by two because the values are of type short (16bit = 2byte)
assert (
        not n_z % 1
    ), "Size given in metadata does not match the size of the raw file."
n_z = int(n_z)
meta_n_z = metadata.get_n_z()

In [42]:
stacks = np.zeros(
            (n_channels, n_time_points, meta_n_z, height, width), dtype="uint16"
        )
off_stacks = stacks
image_size = width * height
t_size = (
            width * height * n_z * n_channels
        )  # number of values stored for a given time point (this includes images for all channels)

In [111]:
 with open(path, "rb") as f:
    for t in range(n_time_points):
        if t == 0:
            print('{}/{}'.format(t,n_time_points))
        a = array.array("H")
        a.fromfile(f, t_size)
        a = np.array(a).reshape(
            (-1, image_size)
        )  # each row is an image alternating between channels
        if t == 0:
            print(a.shape)
        for c in range(n_channels):
            stacks[c, t, :, :, :] = a[c::n_channels, :].reshape(
                (n_z, height, width)
            )[:meta_n_z, :, :]
            off_stacks[c, t, :, :, :] = a[c::n_channels, :].reshape(
                (n_z, height, width)
            )[meta_n_z:, :, :]
            if t == 0 and c == 0:
                print(a[c::n_channels, :].shape)
                print(a[c::n_channels, :].reshape((n_z, height, width)).shape)
                print(a[c::n_channels, :].reshape((n_z, height, width))[:meta_n_z, :, :].shape)
stacks = tuple(np.squeeze(stacks))
off_stacks = tuple(np.squeeze(off_stacks))

0/4100
(4, 353280)
(2, 353280)
(2, 480, 736)
(1, 480, 736)


In [116]:
np.sum(stacks[0] != off_stacks[0])

0

In [178]:
np.arange(9).reshape((3,3))

array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])

In [112]:
stacks_utils2p = utils2p.load_raw(raw_dir, metadata)

In [117]:
np.sum(stacks_utils2p[0] != stacks[0])

1409138544

In [39]:
import matplotlib.pyplot as plt

In [120]:
%matplotlib notebook

fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(9.5, 10), sharex=True, sharey=True)

axs[0,0].imshow(stacks[0][0, :, :])
axs[0,1].imshow(stacks[1][0, :, :])
axs[1,0].imshow(off_stacks[0][0, :, :])
axs[1,1].imshow(off_stacks[1][0, :, :])
axs[2,0].imshow(stacks_utils2p[0][0, :, :])
axs[2,1].imshow(stacks_utils2p[1][0, :, :])

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f7912601f90>

In [158]:
import time

In [159]:
class FrameFromRaw:
    def __init__(self, path, width, height, n_channels, meta_n_z=1, n_z=2):
        self.path = os.path.expanduser(os.path.expandvars(path))
        self.width = width
        self.height = height
        self.n_channels = n_channels
        self.meta_n_z = meta_n_z
        self.n_z = n_z
        
        self.init_read_sizes()
        
    def init_read_sizes(self):
        self.image_size = self.width * self.height
        self.t_size = self.width * self.height * self.n_channels * self.n_z
        
    
    def read_last_frame(self):
        # TODO: in online mode, we don't want to read the fly-backframes, but only the real data
        # right now, this function assumes that readl data + flyback frame are being read at once
        # start = time.process_time()
        with open(path, "rb") as f:
            f.seek(-2 * self.t_size, os.SEEK_END)  # 2* because it is uint16
            a = array.array("H")  # "H" correspond to uint16
            a.fromfile(f, self.t_size)
        # print(time.process_time() - start)
        # --> takes in the order of 4ms to open, read, and close file
        return self.reshape_byte_array(a)
    
    def read_nth_frame(self, n_frame):
        with open(path, "rb") as f:
            f.seek(n_frame * 2 * self.t_size, os.SEEK_SET)  # 2* because it is uint16
            a = array.array("H")  # "H" correspond to uint16
            a.fromfile(f, self.t_size)
        return self.reshape_byte_array(a)
            
    def reshape_byte_array(self, a):
        out = ()
        a = np.array(a).reshape((-1, self.image_size))  # e.g. 4 x 353280 for a 480x736 frame
        if self.n_z > 1:
            for c in range(self.n_channels):
                out += (np.squeeze(a[c::self.n_channels, :].reshape(
                                   (self.n_z, self.height, self.width)
                                   )[:self.meta_n_z, :, :]),
                       )
        else:
            for c in range(self.n_channels):
                out += (np.array(a[c * image_size : (c + 1) * image_size]
                                ).reshape((height, width)),
                       )
        return out
            
            

            
class FrameFromRawMetadata(FrameFromRaw):
    def __init__(self, path, metadata, n_z=2):
        super(FrameFromRawMetadata, self).__init__(path=path, 
            width=metadata.get_num_x_pixels(),
            height=metadata.get_num_y_pixels(),
            n_channels=metadata.get_n_channels(),
            meta_n_z=metadata.get_n_z(),
            n_z=n_z)

In [160]:
myFrameFromRaw = FrameFromRawMetadata(raw_dir, metadata, n_z=2)
# myFrameFromRaw = FrameFromRaw(raw_dir, width=736, height=480, 
#                               n_channels=2, meta_n_z=1, n_z=2)

myFrameFromRaw.t_size

1413120

In [177]:
frames = myFrameFromRaw.read_last_frame()
print(frames[0].shape, frames[1].shape)

0.00429674999998042
(480, 736) (480, 736)


In [155]:
%matplotlib notebook

fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9.5, 5), sharex=True, sharey=True)

axs[0,0].imshow(stacks_utils2p[0][-1, :, :])
axs[0,1].imshow(stacks_utils2p[1][-1, :, :])
axs[1,0].imshow(frames[0])
axs[1,1].imshow(frames[1])

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f7910ead810>

In [156]:
print(np.sum(frames[0] != stacks_utils2p[0][-1, :, :]))
print(np.sum(frames[1] != stacks_utils2p[1][-1, :, :]))

0
0


In [157]:
frames100 = myFrameFromRaw.read_nth_frame(100)
print(frames100[0].shape, frames100[1].shape)
print(np.sum(frames100[0] != stacks_utils2p[0][100, :, :]))
print(np.sum(frames100[1] != stacks_utils2p[1][100, :, :]))

(480, 736) (480, 736)
0
0
