In [1]:
import safetensors.torch
import torch
import os
import math
from safetensors import safe_open
from zipfile import ZipFile
import plotly.graph_objects as go
import plotly.express as px
from torchvision.io import decode_image

In [2]:
def can_tile_tensor(tensor: torch.Tensor, tile_width: int, tile_height: int):
    """Assert that a batch of images splits evenly into tiles of the given size.

    Args:
        tensor: Tensor laid out as (batch, height, width).
        tile_width: Tile size along dimension 2 (width).
        tile_height: Tile size along dimension 1 (height).

    Raises:
        AssertionError: If the image width or height is not an exact multiple
            of the corresponding tile size.
    """
    # Only the shape is needed, so read it directly rather than indexing
    # tensor[0]; that way an empty batch is validated instead of raising
    # IndexError.
    image_height, image_width = tensor.shape[1], tensor.shape[2]
    assert image_width % tile_width == 0, (
        f"Image width ({image_width}) is not divisible by tile width ({tile_width})"
    )
    assert image_height % tile_height == 0, (
        f"Image height ({image_height}) is not divisible by tile height ({tile_height})"
    )

def get_all_tiles(tensor: torch.Tensor, tile_width: int, tile_height: int):
    """Cut every image in a batch into equally sized, non-overlapping tiles.

    tensor: A torch tensor with shape (batch, height, width); it is checked
        with can_tile_tensor before any splitting happens.

    Returns a tensor with shape (batch, tile_number, tile_height, tile_width).
    Tiles are ordered row by row, and left to right within each row.
    """
    can_tile_tensor(tensor, tile_width, tile_height)

    # Slice into horizontal bands first, then cut each band into tiles.
    bands = tensor.split(tile_height, dim=1)
    pieces = [piece for band in bands for piece in band.split(tile_width, dim=2)]

    # Stacking on a new dim 1 places the tile index right after the batch index.
    return torch.stack(pieces, dim=1)

In [3]:
# Key of the tensor to read from the safetensors file.
image = 'idr0001-plate_1A-converted.zip/Meas_01(2012-07-31_10-41-12)_001001001_series-0_z-0_t-0_channel-1.png'
# expanduser resolves the leading '~' to the current user's home directory.
file_path = os.path.expanduser('~/jpeters/histograms/idr0001/idr0001-plate_1A-converted.safetensors')

with safe_open(file_path, framework="torch", device="cpu") as f:
    tensor: torch.Tensor = f.get_tensor(image)

# Optional visual check: px.imshow(tensor.numpy())
tensor.shape

torch.Size([26, 43])

In [4]:
raw_filepath = os.path.expanduser('~/dataset/images/idr0001/idr0001-plate_1A-converted.zip')
# Only the part of `image` after the last '/' names the member inside the zip.
image_tensor_name = image.rsplit('/', 1)[-1]

with ZipFile(raw_filepath) as myzip, myzip.open(image_tensor_name) as myfile:
    # bytearray makes a mutable copy of the bytes for torch.frombuffer to wrap.
    img_bytes = bytearray(myfile.read())

torch_buffer = torch.frombuffer(img_bytes, dtype=torch.uint8)
# Index [0] keeps the first slice of the decoded result
# (presumably the first channel of a channel-first image; TODO confirm).
image_tensor = decode_image(torch_buffer)[0]

raw_fig = px.imshow(image_tensor.numpy(), binary_string=True)
# raw_fig.show()

In [23]:
# Enforce the expected raw image size (height 1040, width 1376) so that a
# differently sized image fails here instead of in a later cell.
assert tuple(image_tensor.shape) == (1040, 1376), (
    f"unexpected image shape: {tuple(image_tensor.shape)}"
)
# Scratch modulo check; as the last expression, its value is what the cell displays.
100 % 2

0

In [5]:
# Tile size in pixels used throughout this cell.
tile_px_wide, tile_px_tall = 32, 40

# get_all_tiles expects a leading batch dimension; [0] drops it again.
tiles = get_all_tiles(image_tensor.unsqueeze(0), tile_px_wide, tile_px_tall)[0]

# Per-tile spread: largest minus smallest value inside each tile.
tile_maxes = tiles.amax(dim=(1, 2))
tile_mins = tiles.amin(dim=(1, 2))
diffs = tile_maxes - tile_mins

# Number of tiles across and down the image.
num_tiles_wide = image_tensor.shape[1] // tile_px_wide
num_tiles_tall = image_tensor.shape[0] // tile_px_tall

# One value per tile, laid out on the same row-by-row grid as the tiles.
diffs_hist = diffs.reshape((num_tiles_tall, num_tiles_wide)).numpy()

px.imshow(diffs_hist)

In [6]:
k_type = "top"
top_k = 10

# Tile-grid dimensions, read off the per-tile heatmap.
heatmap_data = diffs_hist.tolist()
tiled_height = len(heatmap_data)
tiled_width = len(heatmap_data[0])

img_height = image_tensor.shape[0]
img_width = image_tensor.shape[1]

# Size of one tile in image pixels (true division, so these are floats).
tile_height = img_height/tiled_height
tile_width = img_width/tiled_width

# Flatten the grid into (row, column, value) triples.
values_with_coords = [
    (grid_row, grid_col, tile_val)
    for grid_row, cells in enumerate(heatmap_data)
    for grid_col, tile_val in enumerate(cells)
]

# "bottom" ranks ascending; any other k_type ranks descending.
reverse = k_type != "bottom"
values_with_coords.sort(key=lambda entry: entry[2], reverse=reverse)

study = "idr0001"

# top_k is read as a percentage of all tiles and converted to a tile count.
percent = True
top_k = math.ceil(tiled_height * tiled_width * (top_k / 100))

top_coordinates = values_with_coords[:top_k]

# Outline every selected tile on a new figure built from raw_fig.
fig = go.Figure(raw_fig)
shapes = []
for grid_row, grid_col, _ in top_coordinates:
    raw_x = grid_col * tile_width
    raw_y = grid_row * tile_height
    shapes.append({
        "type": "rect",
        "xref": "x", "yref": "y",
        "x0": raw_x, "y0": raw_y,
        "x1": raw_x + tile_width, "y1": raw_y + tile_height,
        "line": {"color": "Red", "width": 1.5},
    })
fig.update_layout(shapes=shapes)

fig.show()

In [7]:
# Display the tensor that was read with safe_open earlier in the notebook.
tensor

tensor([[ 6.3383,  4.4070,  7.3930,  ..., 30.8461, 18.1273, 16.3961],
        [12.3172,  9.4047, 18.8352,  ..., 42.2797, 33.8914, 27.9297],
        [17.8781, 14.2648, 23.6938,  ..., 49.8727, 50.2227, 40.1820],
        ...,
        [47.4445, 56.3008, 65.7148,  ..., 79.8156, 62.5281, 67.9430],
        [37.7031, 50.9422, 58.6133,  ..., 74.1719, 67.3602, 60.6391],
        [28.5781, 43.1391, 51.6742,  ..., 67.0969, 59.5195, 53.2609]])

In [8]:
# The clone gives normal_t its own copy of the data, separate from `tensor`.
normal_t = torch.clone(tensor)
# Mean over every element; shown as the cell output.
torch.mean(normal_t)

tensor(137.3148)

In [11]:
def _fit_patch_axis(tile, image, max_tile, axis):
    """Solve one axis: return (tile_size, image_size) where tile_size divides image_size."""
    if tile < 1:
        raise ValueError(f"tile {axis} must be at least 1, got {tile}")
    if tile > max_tile:
        raise ValueError(f"tile {axis} {tile} exceeds the maximum of {max_tile}")
    if image < tile:
        raise ValueError(f"image {axis} {image} is smaller than tile {axis} {tile}")

    while True:
        candidate = tile
        # Stop growing at the cap: a larger divisor would be rejected anyway.
        while candidate <= max_tile and image % candidate != 0:
            candidate += 1
        if candidate <= max_tile:
            return candidate, image
        # Nothing in [tile, max_tile] divides this size, so crop one pixel and
        # retry. For integer input this ends, at the latest, when image
        # reaches a multiple of tile.
        image -= 1


def get_patch_image_dims(tile_width, tile_height, image_width, image_height,
                         max_width=45, max_height=45):
    """Find a tile size and a (possibly cropped) image size that tile exactly.

    Each axis is solved on its own. Starting from the requested tile size, the
    tile grows one pixel at a time until it divides the image size. If no size
    up to the cap divides it, the image size is reduced by one pixel and the
    search restarts from the requested tile size.

    Args:
        tile_width: Smallest acceptable tile width (at least 1).
        tile_height: Smallest acceptable tile height (at least 1).
        image_width: Image width; must be at least tile_width.
        image_height: Image height; must be at least tile_height.
        max_width: Largest acceptable tile width.
        max_height: Largest acceptable tile height.

    Returns:
        Tuple (tile_width, tile_height, image_width, image_height) in which
        each tile size divides the matching image size and lies between the
        requested size and its cap.

    Raises:
        ValueError: If a tile size is below 1 or above its cap, or an image
            size is smaller than its tile size. Such input can leave the
            search with nothing to find.
    """
    fitted_tile_width, fitted_image_width = _fit_patch_axis(
        tile_width, image_width, max_width, "width"
    )
    fitted_tile_height, fitted_image_height = _fit_patch_axis(
        tile_height, image_height, max_height, "height"
    )
    return fitted_tile_width, fitted_tile_height, fitted_image_width, fitted_image_height

# Example call; the recorded output is (44, 35, 1232, 490), i.e. both image
# sizes come back one pixel smaller than the values passed in.
get_patch_image_dims(32, 32, 1233, 491)

(44, 35, 1232, 490)