In [10]:
import json
import os
from collections import Counter

import cv2
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from tifffile import imread

class FrameSegmenter:
    """Crop annotated bounding boxes out of TIFF frame stacks and save them as tensors.

    For every sequence, a JSON annotation file in ``json_folder`` is paired with a
    TIFF stack in ``tif_folder``. Each annotated box in the first ``num_frames``
    frame entries is cropped, converted to grayscale and written to
    ``<output_folder>/<split>/<sequence>_<frame_index>_<box_id>.pt`` as
    ``[image_tensor, num_cells]``, where ``<split>`` is ``train`` or ``val``.

    Args:
        json_folder: Directory containing ``<sequence>.json`` annotation files.
        tif_folder: Directory containing the TIFF stacks, named like the JSON
            files but with ``_`` replaced by ``-`` (e.g. ``15_1`` -> ``15-1.tif``).
        output_folder: Root directory under which ``train``/``val`` are written.
        test_size: Fraction of sequences assigned to the validation split.
        random_seed: ``random_state`` passed to ``train_test_split``.
    """

    def __init__(self, json_folder, tif_folder, output_folder, test_size=0.3, random_seed=42):
        self.json_folder = json_folder
        self.tif_folder = tif_folder
        self.output_folder = output_folder
        self.test_size = test_size
        self.random_seed = random_seed

    def extract_sequence_number(self, json_filepath):
        """Return the file name of ``json_filepath`` without directory or extension."""
        filename = os.path.basename(json_filepath)
        sequence_number, _ = os.path.splitext(filename)
        return sequence_number

    def load_json_data(self, json_filepath):
        """Load and return the parsed JSON document at ``json_filepath``."""
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(json_filepath, 'r', encoding='utf-8') as json_file:
            return json.load(json_file)

    def convert_to_grayscale(self, frame):
        """Return a float64 grayscale version of ``frame``.

        A 2-D array (or one with a trailing channel axis of size 1) is treated as
        already grayscale and is only cast to float64. Otherwise the first three
        entries of the last axis are taken as R, G, B (any further channel such
        as alpha is ignored) and combined with the BT.601 / MATLAB ``rgb2gray``
        luma weights.
        """
        frame = np.asarray(frame)
        if frame.ndim == 3 and frame.shape[-1] == 1:
            frame = frame[..., 0]
        if frame.ndim == 2:
            # Single-channel data: the RGB weighting below would slice the
            # *width* axis here and collapse every row, so only convert dtype.
            return frame.astype(np.float64)
        return np.dot(frame[..., :3], [0.2989, 0.5870, 0.1140])

    def process_frames(self, json_filepath, tif_filepath, output_folder, num_frames=30):
        """Crop every annotated box of one sequence and save each crop as a ``.pt`` file.

        Args:
            json_filepath: Annotation file. Maps frame keys (whose last
                ``_``-separated token is the frame number, counted from 1) to
                dicts with optional ``'boxes'`` lists (``[x, y, w, h, box_id]``
                each) and ``'cells'`` lists (the third element of each entry is
                compared against ``box_id`` to count cells per box).
            tif_filepath: Multi-page TIFF holding the frames of the sequence.
            output_folder: Directory the tensors are written to; created if missing.
            num_frames: Slice bound applied to the JSON frame entries in file
                order (``None`` processes all of them).

        Each output file ``<sequence>_<frame_index>_<box_id>.pt`` (0-based frame
        index) contains ``[grayscale_crop (float32 tensor), num_cells (int)]``.
        Boxes whose clipped crop is empty are skipped with a message. Returns
        early, printing a message, if ``tif_filepath`` does not exist.
        """
        json_data = self.load_json_data(json_filepath)

        # Check if the TIFF file exists
        if not os.path.exists(tif_filepath):
            print(f"TIFF file not found: {tif_filepath}. Skipping.")
            return

        # torch.save does not create missing parent directories.
        os.makedirs(output_folder, exist_ok=True)
        sequence_number = self.extract_sequence_number(json_filepath)

        for frame_key, frame_info in list(json_data.items())[:num_frames]:
            # Frame keys count from 1 (hence -1); tifffile's `key` is a 0-based page index.
            frame_number = int(frame_key.split('_')[-1]) - 1

            # NOTE(review): the captured notebook output shows tifffile warnings
            # "read_bytes() missing 3 required positional arguments" for this
            # read; that looks like a tifffile/dependency version mismatch in the
            # environment rather than a problem with this call — verify versions.
            frame = imread(tif_filepath, key=frame_number)

            # Count cells per box once per frame instead of rescanning the cell
            # list for every box; Counter lookups use the same ==/hash semantics.
            cells_per_box = Counter(cell_info[2] for cell_info in frame_info.get('cells', []))

            for box_info in frame_info.get('boxes', []):
                x, y, w, h, box_id = map(int, box_info)

                # Negative slice bounds would wrap around to the opposite edge of
                # the frame; clip boxes that extend past the top/left border.
                x0, x1 = max(x, 0), max(x + w, 0)
                y0, y1 = max(y, 0), max(y + h, 0)
                segmented_image = frame[y0:y1, x0:x1]
                if segmented_image.size == 0:
                    print(f"Empty crop for box {box_id} in {frame_key} of {json_filepath}. Skipping.")
                    continue

                grayscale_image = self.convert_to_grayscale(segmented_image)

                output_filename = f"{sequence_number}_{frame_number}_{box_id}.pt"
                output_filepath = os.path.join(output_folder, output_filename)

                torch.save(
                    [torch.from_numpy(grayscale_image).float(), cells_per_box[box_id]],
                    output_filepath,
                )

    def process_sequences(self, num_frames=30):
        """Split all sequences into train/val sets and process every sequence.

        Sequence ids are the ``*.json`` file names in ``json_folder``. They are
        split with ``train_test_split``, stratified on the parity of the first
        ``_``-separated number of each id. Tensors are written to
        ``<output_folder>/train`` and ``<output_folder>/val``.

        Args:
            num_frames: Forwarded to :meth:`process_frames`.
        """
        # os.listdir order is arbitrary and filesystem-dependent; sort it so the
        # split is determined by random_seed alone.
        sequence_numbers = [
            self.extract_sequence_number(json_filename)
            for json_filename in sorted(os.listdir(self.json_folder))
            if json_filename.endswith(".json")
        ]

        # NOTE(review): presumably the parity of the first number encodes a
        # group/condition that should be balanced across splits — confirm.
        train_sequence_numbers, val_sequence_numbers = train_test_split(
            sequence_numbers,
            test_size=self.test_size,
            random_state=self.random_seed,
            stratify=[int(seq.split('_')[0]) % 2 for seq in sequence_numbers],
        )

        print("\nTrain Sequences:", train_sequence_numbers)
        print("\nVal Sequences:", val_sequence_numbers)

        splits = (("train", train_sequence_numbers), ("val", val_sequence_numbers))
        for split_name, split_sequences in splits:
            split_folder = os.path.join(self.output_folder, split_name)
            for sequence_number in split_sequences:
                json_filepath = os.path.join(self.json_folder, f"{sequence_number}.json")
                # TIFF names use '-' where the JSON names use '_'.
                tif_filepath = os.path.join(self.tif_folder, f"{sequence_number.replace('_', '-')}.tif")
                self.process_frames(json_filepath, tif_filepath, split_folder, num_frames)

# Example usage: build the train/val tensor datasets from the annotated sequences.
# The paths below are relative to the notebook's working directory.
json_folder = "../Data/all_json"
tif_folder = "../Data/All_Sequences"
output_folder = "../Data/input_tensors"

frame_segmenter = FrameSegmenter(
    json_folder=json_folder,
    tif_folder=tif_folder,
    output_folder=output_folder,
)
frame_segmenter.process_sequences(num_frames=30)


TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offsetsize'



Train Sequences: ['15_1', '5_1', '2_1', '1_1', '6_3', '13_1', '12_3', '4_3', '1_2', '13_2', '5_2', '3_3', '10_1', '6_2', '12_1', '2_2', '10_3', '1_3', '9_3', '8_1', '9_1', '5_3', '7_2', '8_3', '11_3', '10_2', '13_3', '4_1', '8_2']

Val Sequences: ['6_1', '14_2', '11_1', '12_2', '7_3', '7_1', '3_1', '14_1', '2_3', '9_2', '11_2', '3_2', '4_2']
TIFF file not found: ../Data/All_Sequences\15-1.tif. Skipping.


TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offsetsize'
TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offsetsize'
TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offsetsize'
TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offsetsize'
TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offsetsize'
TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offsetsize'
TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offsetsize'
TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offsetsize'
TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offs