In [1]:




from dataset_aug_tf import IrisDataset




# In our UNet implementation the dims can be whatever you want.
# You could even change them between training iterations - but it might be a bad idea because all the weights had been learnt at the scale of the previous dims.
INPUT_DIMS = {
    "width" : 256,
    "height" : 128,
    "channels" : 5
}

# In our UNet the output width and height have to be the same as the input width and height. 
OUTPUT_DIMS = {
    "width" : INPUT_DIMS["width"],
    "height" : INPUT_DIMS["height"],
    "channels" : 2
}



model_parameters = {
    # layer sizes
    "output_y" : OUTPUT_DIMS["height"],
    "output_x" : OUTPUT_DIMS["width"],
    "n_channels" : INPUT_DIMS["channels"],
    "n_classes" : OUTPUT_DIMS["channels"],
    "starting_kernels" : 64,
    "expansion" : 2,
    "depth" : 6,
}



dataset_args = {

    "testrun" : True,
    "testrun_size" : 2,
   

    "input_width" : INPUT_DIMS["width"],
    "input_height" : INPUT_DIMS["height"],
    "output_width" : OUTPUT_DIMS["width"],
    "output_height" : OUTPUT_DIMS["height"],
    
    # iris dataset params
    "path_to_sclera_data" : "Data/vein_and_sclera_data",
    # "transform" : transform,
    "n_classes" : OUTPUT_DIMS["channels"],

    "zero_out_non_sclera" : True,
    "add_sclera_to_img" : False,
    "add_bcosfire_to_img" : True,
    "add_coye_to_img" : True
}


data_path = dataset_args["path_to_sclera_data"]
# n_classes = 4 if 'sip' in args.dataset.lower() else 2

train_dataset = IrisDataset(filepath=data_path, split='train', **dataset_args)


real_imgs = train_dataset[0]

print(real_imgs)


dataset_aug_tf.py do_log: False
img_augments.py do_log: False
helper_img_and_fig_tools.py do_log: False
summary for train
valid images: 88
{'images': tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  

In [32]:



from helper_img_and_fig_tools import smart_conversion, show_image, save_img_quick_figs, save_imgs_quick_figs

import numpy as np
import torch

import matplotlib.pyplot as plt
import os.path as osp
# from utils import one_hot2dist

np.random.seed(7)


def get_rising_tensor(shape):

    # Calculate the total number of elements needed
    num_elements = torch.prod(torch.tensor(shape)).item()

    # Create a 1D tensor with sequential numbers
    sequential_tensor = torch.arange(1, num_elements + 1)

    # Reshape the tensor to the desired shape
    reshaped_tensor = sequential_tensor.reshape(shape)

    return reshaped_tensor

tensor_img_shape = (4, 6)
d3_shape = (1, 3, 6)

tensor_img = get_rising_tensor(tensor_img_shape)
d3 = get_rising_tensor(d3_shape)

patch_shape = (2,2)
stride = (2,2)

print(f"{tensor_img=}")
print(f"{d3=}")

tensor_img_unf = tensor_img.unfold(patch_shape[0], stride[0])
# print(f"{tensor_img_unf=}")
# print(f"{tensor_img_unf.shape=}")
tensor_img_unf = tensor_img_unf.unfold(1, patch_shape[1], stride[1])
# print(f"{tensor_img_unf=}")
# print(f"{tensor_img_unf.shape=}")


# Calculate the number of patches
num_patches_h = (tensor_img.size(0) - patch_shape[0]) // stride[0] + 1
num_patches_w = (tensor_img.size(1) - patch_shape[1]) // stride[1] + 1

# Calculate starting indices for each patch
patch_indices = []
for i in range(num_patches_h):
    for j in range(num_patches_w):
        start_h = i * stride[0]
        start_w = j * stride[1]
        patch_indices.append((start_h, start_w))

print(f"{patch_indices=}")












# right_bottom_patches:

# right_patches
# bottom_patches
# right_bottom_corner

# with left up ixs for all of them. And then you concat all of that.
# And at test time you add that to that parts of the accumulator.
# Possibly you could count how many times each part of the accumulator was added to and devide by that,
# but I'm not sure we need that, because it's logits anyway.



# right_patches

right_slice = tensor_img[:, -patch_shape[1]:]
rs_patches = right_slice.unfold(patch_shape[0], stride[0])
x_ix = tensor_img.shape[1] - patch_shape[1]
rs_ixs = [(i*stride[0], x_ix) for i in range(rs_patches.size(0))]
print(f"{right_slice=}")
print(f"{rs_patches=}")
print(f"{rs_ixs=}")

bottom_slice = tensor_img[-patch_shape[0]:, :]
bs_patches = bottom_slice.unfold(1, patch_shape[1], stride[1])
y_ix = tensor_img.shape[0] - patch_shape[0]
bs_ixs = [(y_ix, j) for j in range(bs_patches.size(1))]
print(f"{bottom_slice=}")
print(f"{bs_patches=}")
print(f"{bs_ixs=}")

right_bottom_corner = tensor_img[-patch_shape[0]:, -patch_shape[1]:]
rbc_patch = right_bottom_corner.reshape(1, *patch_shape)
x_ix = tensor_img.shape[1] - patch_shape[1]
y_ix = tensor_img.shape[0] - patch_shape[0]
rbc_ixs = [(y_ix, x_ix)]
print(f"{right_bottom_corner=}")
print(f"{rbc_patch=}")
print(f"{rbc_ixs=}")


# bottom_patches
# right_bottom_corner





















def patchify(tensor_imgs_list, patch_shape, stride):
    """
    tensor_imgs_list: list of tensors at the end of __getitem__ function in IrisDataset
    patch_shape: tuple of 2 integers (H, W)
    stride: tuple of 2 integers (H, W)
    """
    # tensor_imgs_list is a list of tensors
    # each tensor is either a 3D tensor of shape (C, H, W)
    # or a 2D tensor of shape (H, W) (for grayscale images)
    # C is the number of channels
    # H is the height of the image
    # W is the width of the image

    # The function returns a tensor of patches
    # of shape (N, C, H, W)
    # N is the number of patches

    # The function also returns a tensor of patch coordinates (top left and bottom right corners)
    # each patch coordinate is a tuple of 4 integers
    # an element: (y1, x1, y2, x2)

    for tensor_img in tensor_imgs_list:
        assert tensor_img.dim() == 3 or tensor_img.dim() == 2
        patches = tensor_img.unfold(1, patch_shape[0], stride[0]).unfold(2, patch_shape[1], stride[1])

    






tensor_img=tensor([[ 1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12],
        [13, 14, 15, 16, 17, 18],
        [19, 20, 21, 22, 23, 24]])
d3=tensor([[[ 1,  2,  3,  4,  5,  6],
         [ 7,  8,  9, 10, 11, 12],
         [13, 14, 15, 16, 17, 18]]])
patch_indices=[(0, 0), (0, 2), (0, 4), (2, 0), (2, 2), (2, 4)]
right_slice=tensor([[ 5,  6],
        [11, 12],
        [17, 18],
        [23, 24]])
rs_patches=tensor([[[ 5, 11],
         [ 6, 12]],

        [[17, 23],
         [18, 24]]])
rs_ixs=[(0, 4), (2, 4)]
bottom_slice=tensor([[13, 14, 15, 16, 17, 18],
        [19, 20, 21, 22, 23, 24]])
bs_patches=tensor([[[13, 14],
         [15, 16],
         [17, 18]],

        [[19, 20],
         [21, 22],
         [23, 24]]])
bs_ixs=[(2, 0), (2, 1), (2, 2)]
right_bottom_corner=tensor([[17, 18],
        [23, 24]])
rbc_patch=tensor([[[17, 18],
         [23, 24]]])
rbc_ixs=[(2, 4)]


In [22]:



from helper_img_and_fig_tools import smart_conversion, show_image, save_img_quick_figs, save_imgs_quick_figs

import numpy as np
import torch

import matplotlib.pyplot as plt
import os.path as osp
# from utils import one_hot2dist

np.random.seed(7)




def get_rising_tensor(shape):

    # Calculate the total number of elements needed
    num_elements = torch.prod(torch.tensor(shape)).item()

    # Create a 1D tensor with sequential numbers
    sequential_tensor = torch.arange(1, num_elements + 1)

    # Reshape the tensor to the desired shape
    reshaped_tensor = sequential_tensor.reshape(shape)

    return reshaped_tensor

# tensor_img_shape = (4, 6)
d3_shape = (2, 3, 4)

# tensor_img = get_rising_tensor(tensor_img_shape)
d3 = get_rising_tensor(d3_shape)

patch_shape = (2,2)
stride = (2,2)




def unfold_3chan(tensor_img, patch_shape, stride: tuple):
    # tensor_img_unf = tensor_img.unfold(0, patch_shape[0], stride[0])
    # tensor_img_unf = tensor_img_unf.unfold(1, patch_shape[1], stride[1])

    tensor_img_unf = tensor_img.unfold(1, patch_shape[0], stride[0])
    tensor_img_unf = tensor_img_unf.unfold(2, patch_shape[1], stride[1])

    combined_shape = tensor_img_unf.size(0), -1, patch_shape[0], patch_shape[1]
    tensor_img_combined = tensor_img_unf.view(combined_shape)
    print(f"{tensor_img_combined=}")
    
    # Calculate the number of patches in each dim
    num_patches_h = (tensor_img.size(0) - patch_shape[0]) // stride[0] + 1
    num_patches_w = (tensor_img.size(1) - patch_shape[1]) // stride[1] + 1


    assert tensor_img_unf.size(1) == num_patches_h * num_patches_w

    # Calculate starting indices for each patch
    patch_indices = []
    for i in range(num_patches_h):
        for j in range(num_patches_w):
            start_h = i * stride[0]
            start_w = j * stride[1]
            patch_indices.append((start_h, start_w))
            

    return tensor_img_unf, patch_indices



def patchify(tensor_img, patch_shape, stride: tuple):

    # tensor_img_unf = tensor_img.unfold(0, patch_shape[0], stride[0])
    # print(f"{tensor_img_unf=}")
    # print(f"{tensor_img_unf.shape=}")
    # tensor_img_unf = tensor_img_unf.unfold(1, patch_shape[1], stride[1])
    
    tensor_img_unf, patch_indices = unfold_3chan(tensor_img, patch_shape, stride)
    print(f"{tensor_img_unf=}")
    print(f"{tensor_img_unf.shape=}")
    print(f"{patch_indices=}")







    # right_bottom_patches:

    # right_patches
    # bottom_patches
    # right_bottom_corner

    # with left up ixs for all of them. And then you concat all of that.
    # And at test time you add that to that parts of the accumulator.
    # Possibly you could count how many times each part of the accumulator was added to and devide by that,
    # but I'm not sure we need that, because it's logits anyway.



    # right_patches

    right_slice = tensor_img[:, -patch_shape[1]:]


    rs_patches = right_slice.unfold(0, patch_shape[0], stride[0])
    rs_patches.unfold(1, patch_shape[1], stride[1])
    x_ix = tensor_img.shape[1] - patch_shape[1]
    rs_ixs = [(i*stride[0], x_ix) for i in range(rs_patches.size(0))]
    # print(f"{right_slice=}")
    print(f"{rs_patches=}")
    print(f"{rs_ixs=}")
    # make it so the img has the Channel dim, which is 1 in this case
    rs_patches = rs_patches.unsqueeze(1)


    # bottom_patches

    bottom_slice = tensor_img[-patch_shape[0]:, :]
    bs_patches = bottom_slice.unfold(0, patch_shape[0], stride[0])
    bs_patches = bs_patches.unfold(1, patch_shape[1], stride[1])
    y_ix = tensor_img.shape[0] - patch_shape[0]
    bs_ixs = [(y_ix, j) for j in range(bs_patches.size(1))]
    # print(f"{bottom_slice=}")
    print(f"{bs_patches=}")
    print(f"{bs_ixs=}")
    # make it so the img has the Channel dim, which is 1 in this case
    bs_patches = bs_patches.unsqueeze(1)


    # right_bottom_corner

    right_bottom_corner = tensor_img[-patch_shape[0]:, -patch_shape[1]:]
    rbc_patch = right_bottom_corner
    x_ix = tensor_img.shape[1] - patch_shape[1]
    y_ix = tensor_img.shape[0] - patch_shape[0]
    rbc_ixs = [(y_ix, x_ix)]
    # print(f"{right_bottom_corner=}")
    print(f"{rbc_patch=}")
    print(f"{rbc_ixs=}")
    # give batch sie and channel dims to this patch
    rbc_patch = rbc_patch.unsqueeze(0).unsqueeze(0)
    
    
    patch_dict = {
        "main_patches" : tensor_img_unf,
        "main_lu_ixs" : patch_indices,
        "right_patches" : rs_patches,
        "right_lu_ixs" : rs_ixs,
        "bottom_patches" : bs_patches,
        "bottom_lu_ixs" : bs_ixs,
        "right_bottom_corner" : rbc_patch,
        "right_bottom_corner_lu_ixs" : rbc_ixs
    }

    return patch_dict




def accumulate_patches(prediction_tensor_shape, patch_shape, stride: tuple, patch_dict):
    
    accumulating_tensor = torch.zeros(prediction_tensor_shape)
    print(f"{accumulating_tensor=}")

    for i, (lu_ix, patch) in enumerate(zip(patch_dict["main_lu_ixs"], patch_dict["main_patches"])):

        y1, x1 = lu_ix
        y2, x2 = y1 + patch_shape[0], x1 + patch_shape[1]
        accumulating_tensor[:, y1:y2, x1:x2] += patch
    
    for i, (lu_ix, patch) in enumerate(zip(patch_dict["right_lu_ixs"], patch_dict["right_patches"])):
        
        y1, x1 = lu_ix
        y2, x2 = y1 + patch_shape[0], x1 + patch_shape[1]
        accumulating_tensor[:, y1:y2, x1:x2] += patch

    for i, (lu_ix, patch) in enumerate(zip(patch_dict["bottom_lu_ixs"], patch_dict["bottom_patches"])):

        y1, x1 = lu_ix
        y2, x2 = y1 + patch_shape[0], x1 + patch_shape[1]
        accumulating_tensor[:, y1:y2, x1:x2] += patch
    
    for i, (lu_ix, patch) in enumerate(zip(patch_dict["right_bottom_corner_lu_ixs"], patch_dict["right_bottom_corner"])):
        
        y1, x1 = lu_ix
        y2, x2 = y1 + patch_shape[0], x1 + patch_shape[1]
        accumulating_tensor[:, y1:y2, x1:x2] += patch
    
    return accumulating_tensor





print(f"{d3=}")
patch_dict = patchify(d3, patch_shape, stride)

# simulate the patches going through the model and becoming 2-channel

print(f"{patch_dict['main_patches'].shape=}")
print(f"{patch_dict['right_patches'].shape=}")
print(f"{patch_dict['bottom_patches'].shape=}")
print(f"{patch_dict['right_bottom_corner'].shape=}")

# make tensors have 2 channels as if they went through the model


concated_patches = torch.cat([patch_dict["main_patches"], patch_dict["right_patches"], patch_dict["bottom_patches"], patch_dict["right_bottom_corner"]], dim=0)
print(f"{concated_patches=}")

# make tensor 2-channel, by repeating the same data in the second channel. This is for dim=1
concated_patches = concated_patches.repeat(1, 2, 1, 1)
print(f"{concated_patches=}")

# deconcat patches
num_main = patch_dict["main_patches"].size(0)
num_right = patch_dict["right_patches"].size(0)
num_bottom = patch_dict["bottom_patches"].size(0)
num_rbc = patch_dict["right_bottom_corner"].size(0)

pred_patch_dict = {
    "main_patches" : concated_patches[:num_main],
    "right_patches" : concated_patches[num_main:num_main + num_right],
    "bottom_patches" : concated_patches[num_main + num_right:num_main + num_right + num_bottom],
    "right_bottom_corner" : concated_patches[num_main + num_right + num_bottom:],

    "main_lu_ixs" : patch_dict["main_lu_ixs"],
    "right_lu_ixs" : patch_dict["right_lu_ixs"],
    "bottom_lu_ixs" : patch_dict["bottom_lu_ixs"],
    "right_bottom_corner_lu_ixs" : patch_dict["right_bottom_corner_lu_ixs"]
}

prediction_tensor_shape = (2, tensor_img_shape[0], tensor_img_shape[1])

accumulating_tensor = accumulate_patches(prediction_tensor_shape, patch_shape, stride, pred_patch_dict)


print(f"{accumulating_tensor=}")
























# def patchify(tensor_imgs_list, patch_shape, stride):
#     """
#     tensor_imgs_list: list of tensors at the end of __getitem__ function in IrisDataset
#     patch_shape: tuple of 2 integers (H, W)
#     stride: tuple of 2 integers (H, W)
#     """
#     # tensor_imgs_list is a list of tensors
#     # each tensor is either a 3D tensor of shape (C, H, W)
#     # or a 2D tensor of shape (H, W) (for grayscale images)
#     # C is the number of channels
#     # H is the height of the image
#     # W is the width of the image

#     # The function returns a tensor of patches
#     # of shape (N, C, H, W)
#     # N is the number of patches

#     # The function also returns a tensor of patch coordinates (top left and bottom right corners)
#     # each patch coordinate is a tuple of 4 integers
#     # an element: (y1, x1, y2, x2)

#     for tensor_img in tensor_imgs_list:
#         assert tensor_img.dim() == 3 or tensor_img.dim() == 2
#         patches = tensor_img.unfold(1, patch_shape[0], stride[0]).unfold(2, patch_shape[1], stride[1])

    






d3=tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]],

        [[13, 14, 15, 16],
         [17, 18, 19, 20],
         [21, 22, 23, 24]]])
tensor_img_combined=tensor([[[[ 1,  2],
          [ 5,  6]],

         [[ 3,  4],
          [ 7,  8]]],


        [[[13, 14],
          [17, 18]],

         [[15, 16],
          [19, 20]]]])
tensor_img_unf=tensor([[[[[ 1,  2],
           [ 5,  6]],

          [[ 3,  4],
           [ 7,  8]]]],



        [[[[13, 14],
           [17, 18]],

          [[15, 16],
           [19, 20]]]]])
tensor_img_unf.shape=torch.Size([2, 1, 2, 2, 2])
patch_indices=[(0, 0)]
rs_patches=tensor([[[[ 5, 17],
          [ 6, 18],
          [ 7, 19],
          [ 8, 20]],

         [[ 9, 21],
          [10, 22],
          [11, 23],
          [12, 24]]]])
rs_ixs=[(0, 1)]
bs_patches=tensor([[[[[ 1,  5],
           [13, 17]],

          [[ 2,  6],
           [14, 18]],

          [[ 3,  7],
           [15, 19]],

          [[ 4,  8],
           [16,

RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 2 but got size 4 for tensor number 1 in the list.

In [23]:
# Naive patchification









from helper_img_and_fig_tools import smart_conversion, show_image, save_img_quick_figs, save_imgs_quick_figs

import numpy as np
import torch

import matplotlib.pyplot as plt
import os.path as osp
# from utils import one_hot2dist

np.random.seed(7)




def get_rising_tensor(shape):

    # Calculate the total number of elements needed
    num_elements = torch.prod(torch.tensor(shape)).item()

    # Create a 1D tensor with sequential numbers
    sequential_tensor = torch.arange(1, num_elements + 1)

    # Reshape the tensor to the desired shape
    reshaped_tensor = sequential_tensor.reshape(shape)

    return reshaped_tensor

# tensor_img_shape = (4, 6)
d3_shape = (2, 3, 4)

# tensor_img = get_rising_tensor(tensor_img_shape)
d3 = get_rising_tensor(d3_shape)

patch_shape = (2,2)
stride = (2,2)




def unfold_3chan(tensor_img, patch_shape, stride: tuple):
    # tensor_img_unf = tensor_img.unfold(0, patch_shape[0], stride[0])
    # tensor_img_unf = tensor_img_unf.unfold(1, patch_shape[1], stride[1])

    y_ix = 0
    x_ix = 0

    if y_ix + patch_shape[0] < tensor_img.size(1) and x_ix + patch_shape[1] < tensor_img.size(2):
        patches = tensor_img[:, y_ix:y_ix + patch_shape[0], x_ix:x_ix + patch_shape[1]]
        patches = patches.unsqueeze(0)
        left_upper_ixs = [(y_ix, x_ix)]
        y_ix += stride[0]
    else:
        return None
    

    while x_ix + patch_shape[1] < tensor_img.size(2):
        while y_ix + patch_shape[0] < tensor_img.size(1):
            patch = tensor_img[:, y_ix:y_ix + patch_shape[0], x_ix:x_ix + patch_shape[1]]
            patch = patch.unsqueeze(0)
            patches = torch.cat([patches, patch], dim=0)
            left_upper_ixs.append((y_ix, x_ix))
            y_ix += stride[0]
        x_ix += stride[1]
    
    return patches, left_upper_ixs





def patchify(tensor_img, patch_shape, stride: tuple):


    main_patches, main_lu_ixs = unfold_3chan(tensor_img, patch_shape, stride)
    print(f"{main_patches=}")
    print(f"{main_patches.shape=}")
    print(f"{main_lu_ixs=}")




    # right_bottom_patches:

    # right_patches
    # bottom_patches
    # right_bottom_corner

    # with left up ixs for all of them. And then you concat all of that.
    # And at test time you add that to that parts of the accumulator.
    # Possibly you could count how many times each part of the accumulator was added to and devide by that,
    # but I'm not sure we need that, because it's logits anyway.




    # right_patches

    right_slice = tensor_img[:, :, -patch_shape[1]:]

    right_patches, right_lu_ixs = unfold_3chan(right_slice, patch_shape, stride)

    print(f"{right_slice=}")
    print(f"{right_patches=}")
    print(f"{right_patches.shape=}")
    print(f"{right_lu_ixs=}")




    # bottom_patches

    bottom_slice = tensor_img[:, -patch_shape[0]:, :]
    bottom_patches, bottom_lu_ixs = unfold_3chan(bottom_slice, patch_shape, stride)
    print(f"{bottom_slice=}")
    print(f"{bottom_patches=}")
    print(f"{bottom_patches.shape=}")
    print(f"{bottom_lu_ixs=}")




    # right_bottom_corner

    right_bottom_corner = tensor_img[:, -patch_shape[0]:, -patch_shape[1]:]
    right_bottom_patches, right_bottom_lu_ixs = unfold_3chan(right_bottom_corner, patch_shape, stride)
    print(f"{right_bottom_corner=}")
    print(f"{right_bottom_patches=}")
    print(f"{right_bottom_patches.shape=}")
    print(f"{right_bottom_lu_ixs=}")


    
    # patch_dict = {
    #     "main_patches" : tensor_img_unf,
    #     "main_lu_ixs" : patch_indices,
    #     "right_patches" : rs_patches,
    #     "right_lu_ixs" : rs_ixs,
    #     "bottom_patches" : bs_patches,
    #     "bottom_lu_ixs" : bs_ixs,
    #     "right_bottom_corner" : rbc_patch,
    #     "right_bottom_corner_lu_ixs" : rbc_ixs
    # }

    patch_dict = {
        "main_patches" : main_patches,
        "main_lu_ixs" : main_lu_ixs,
        "right_patches" : right_patches,
        "right_lu_ixs" : right_lu_ixs,
        "bottom_patches" : bottom_patches,
        "bottom_lu_ixs" : bottom_lu_ixs,
        "right_bottom_corner" : right_bottom_patches,
        "right_bottom_corner_lu_ixs" : right_bottom_lu_ixs
    }

    return patch_dict




def accumulate_patches(prediction_tensor_shape, patch_shape, stride: tuple, patch_dict):
    
    accumulating_tensor = torch.zeros(prediction_tensor_shape)
    print(f"{accumulating_tensor=}")

    for i, (lu_ix, patch) in enumerate(zip(patch_dict["main_lu_ixs"], patch_dict["main_patches"])):

        y1, x1 = lu_ix
        y2, x2 = y1 + patch_shape[0], x1 + patch_shape[1]
        accumulating_tensor[:, y1:y2, x1:x2] += patch
    
    for i, (lu_ix, patch) in enumerate(zip(patch_dict["right_lu_ixs"], patch_dict["right_patches"])):
        
        y1, x1 = lu_ix
        y2, x2 = y1 + patch_shape[0], x1 + patch_shape[1]
        accumulating_tensor[:, y1:y2, x1:x2] += patch

    for i, (lu_ix, patch) in enumerate(zip(patch_dict["bottom_lu_ixs"], patch_dict["bottom_patches"])):

        y1, x1 = lu_ix
        y2, x2 = y1 + patch_shape[0], x1 + patch_shape[1]
        accumulating_tensor[:, y1:y2, x1:x2] += patch
    
    for i, (lu_ix, patch) in enumerate(zip(patch_dict["right_bottom_corner_lu_ixs"], patch_dict["right_bottom_corner"])):
        
        y1, x1 = lu_ix
        y2, x2 = y1 + patch_shape[0], x1 + patch_shape[1]
        accumulating_tensor[:, y1:y2, x1:x2] += patch
    
    return accumulating_tensor





print(f"{d3=}")
patch_dict = patchify(d3, patch_shape, stride)

# The img is 2-channel from the start so the patches are also 2 channel, so they are the same as they would come out of the model

print(f"{patch_dict['main_patches'].shape=}")
print(f"{patch_dict['right_patches'].shape=}")
print(f"{patch_dict['bottom_patches'].shape=}")
print(f"{patch_dict['right_bottom_corner'].shape=}")

# make tensors have 2 channels as if they went through the model


concated_patches = torch.cat([patch_dict["main_patches"], patch_dict["right_patches"], patch_dict["bottom_patches"], patch_dict["right_bottom_corner"]], dim=0)
print(f"{concated_patches=}")


# deconcat patches
num_main = patch_dict["main_patches"].size(0)
num_right = patch_dict["right_patches"].size(0)
num_bottom = patch_dict["bottom_patches"].size(0)
num_rbc = patch_dict["right_bottom_corner"].size(0)

pred_patch_dict = {
    "main_patches" : concated_patches[:num_main],
    "right_patches" : concated_patches[num_main:num_main + num_right],
    "bottom_patches" : concated_patches[num_main + num_right:num_main + num_right + num_bottom],
    "right_bottom_corner" : concated_patches[num_main + num_right + num_bottom:],

    "main_lu_ixs" : patch_dict["main_lu_ixs"],
    "right_lu_ixs" : patch_dict["right_lu_ixs"],
    "bottom_lu_ixs" : patch_dict["bottom_lu_ixs"],
    "right_bottom_corner_lu_ixs" : patch_dict["right_bottom_corner_lu_ixs"]
}

prediction_tensor_shape = (2, d3_shape[1], d3_shape[2])

accumulating_tensor = accumulate_patches(prediction_tensor_shape, patch_shape, stride, pred_patch_dict)


print(f"{accumulating_tensor=}")
print(f"{accumulating_tensor.shape=}")
















d3=tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]],

        [[13, 14, 15, 16],
         [17, 18, 19, 20],
         [21, 22, 23, 24]]])
main_patches=tensor([[[[ 1,  2],
          [ 5,  6]],

         [[13, 14],
          [17, 18]]]])
main_patches.shape=torch.Size([1, 2, 2, 2])
main_lu_ixs=[(0, 0)]


TypeError: cannot unpack non-iterable NoneType object