In [13]:
import torch
from torch import nn


In [14]:
m = nn.AdaptiveAvgPool2d((5,7))
input = torch.randn(1, 64, 8, 9)
output = m(input)
print(output.shape)
# >>> # target output size of 7x7 (square)
# >>> m = nn.AdaptiveAvgPool2d(7)
# >>> input = torch.randn(1, 64, 10, 9)
# >>> output = m(input)
# >>> # target output size of 10x7
# >>> m = nn.AdaptiveAvgPool2d((None, 7))
# >>> input = torch.randn(1, 64, 10, 9)
# >>> output = m(input)

torch.Size([1, 64, 5, 7])


In [15]:
m = nn.AdaptiveAvgPool2d((1,1))
input = torch.randn(1, 64, 8, 9)
output = m(input)
print(output.shape)

torch.Size([1, 64, 1, 1])


## Dataset

In [1]:
from typing import Tuple
import numpy as np
import torch
from torchvision.transforms import ToTensor
from torch.utils import data
from skimage.transform import resize
from skimage import io, color
from scipy import io as spio
import pickle
import cv2
import albumentations as A
from utils.ffRemap import *
from utils.data_process import pad_image, match_histograms, normalize_min_max, normalize_mean_std

In [2]:
from matplotlib import pyplot as plt


In [9]:
class Dataset(data.Dataset):
    def __init__(self, image_sequences, image_keypoints,
                 im_size=(1, 256, 256),
                 train=True, shuffle=False, register_limit=5,
                 use_masks=True, use_crop=False):
        """
        
        :param: image_sequences:
        :param: image_keypoints:
        :param: im_size:
        :param: train:
        :param: register_limit:
        :param: shuffle:
        :param: use_masks:
        :param: use_crop:
        """
        self.image_sequences = []
        self.image_keypoints = []
        
        for sequence_path, keypoint_path in zip(image_sequences, image_keypoints):
            assert sequence_path.split('/')[-1].split('.')[0] == \
            keypoint_path.split('/')[-1].split('.')[0], 'Keypoint and sequence files must be ordered!'
            
            seq = io.imread(sequence_path)
            print('len(seq)', len(seq))
            print(seq[0].dtype, seq[0].max())
            if seq.shape[-1] == 3:
                seq = color.rgb2gray(seq)
                print(seq[0].dtype, seq[0].max())
            self.image_sequences.append(seq)
            
            poi = spio.loadmat(keypoint_path)
            bound = np.stack(poi['spotsB'][0].squeeze())
            inner = np.stack(poi['spotsI'][0].squeeze())

            bound = bound[:, :, :2]
            inner = inner[:, :, :2]

            line1 = np.stack(poi['lines'][:, 0])
            line2 = np.stack(poi['lines'][:, 1])
            line3 = np.stack(poi['lines'][:, 2])
            line4 = np.stack(poi['lines'][:, 3])

            len1 = len(line1[0])
            len2 = len(line2[0])
            len3 = len(line3[0])
            len4 = len(line4[0])

            lines = np.concatenate((line1, line2, line3, line4), axis=1)
            lines_lengths = np.array([len1, len2, len3, len4])
            self.image_keypoints.append({'inner': inner, 'bound': bound, 'lines': (lines, lines_lengths)})
            
            
            
        # with open(path, 'rb') as file:
        #     self.data = pickle.load(file)
        #     print(self.data.keys())
        # if use_masks:
        #     with open(path.split('.pkl')[0] + '_body.pkl', 'rb') as file:
        #         self.masks = pickle.load(file)
        #         print(self.masks.keys())

        self.length = sum([len(sequence) for sequence in self.image_sequences])
        print('Dataset length is ', self.length)
        self.seq_numeration = []

        for seq_idx, _ in enumerate(self.image_sequences):
            for i, _ in enumerate(self.image_sequences):
                self.seq_numeration.append((seq_idx, i))

        self.use_masks = use_masks
        self.use_crop = use_crop
        
        self.im_size = im_size[1:]
        self.train = train
        self.shuffle = shuffle
        
        if isinstance(register_limit, int):
            self.register_limit = [register_limit] * len(self.image_sequences)
        else:
            assert len(register_limit) == len(self.image_sequences), 'limit value must be assigned either \
            by integer or by the list of values for each mage sequenceaccordingly'
            
            self.register_limit = register_limit
        
        
        if self.shuffle:
            np.random.shuffle(self.seq_numeration)
#         TODO add aug
#         if self.train:
#             self.aug_pipe = A.Compose([A.HorizontalFlip(p=1),# A.VerticalFlip(p=0.3),
#                                   #A.ShiftScaleRotate(shift_limit=0.0225, scale_limit=0.1, rotate_limit=15, p=0.2)
#                                   ], additional_targets={'image2': 'image', 'keypoints2': 'keypoints'},
#                                       keypoint_params=A.KeypointParams(format='xy', remove_invisible=False))
        
        self.to_tensor = ToTensor()
        self.resize = A.Compose([A.Resize(*self.im_size)],
                                keypoint_params=A.KeypointParams(format='xy', remove_invisible=False))
        
    def __len__(self):
        return self.length

    def __getitem__(self, index):
        seq_idx, it = self.seq_numeration[index]
        current_seq_len = len(self.image_sequences[seq_idx])
        
        if self.train:
            it2 = np.random.randint(max(it-self.register_limit[seq_idx], 0),
                                    min(it+self.register_limit[seq_idx] + 1, current_seq_len - 1), size=1)
        else:
            it2 = min(it + 1, current_seq_len - 1)

        image1 = self.image_sequences[seq_idx][it].squeeze()
        image2 = self.image_sequences[seq_idx][it2].squeeze()
        
        inner1 = self.image_keypoints[seq_idx]['inner'][it]
        inner2 = self.image_keypoints[seq_idx]['inner'][it2]
        bound1 = self.image_keypoints[seq_idx]['bound'][it]
        bound2 = self.image_keypoints[seq_idx]['bound'][it2]
        lines1 = self.image_keypoints[seq_idx]['lines'][0][it]
        lines_len = self.image_keypoints[seq_idx]['lines'][1]
        lines2 = self.image_keypoints[seq_idx]['lines'][0][it2]
        
        image1 = normalize_mean_std(image1)
        image2 = normalize_mean_std(image2)
        
        print(image1.shape, image2.shape)
        image1, image2 = match_histograms(image1, image2, random_switch=self.train)
        h, w = image1.shape
        
        if self.use_crop:
            x0 = np.random.randint(0, w - self.im_size[1])
            y0 = np.random.randint(0, h - self.im_size[0])
            image1 = image1[y0: y0 + self.im_size[0], x0:x0 + self.im_size[1]]
            image2 = image2[y0: y0 + self.im_size[0], x0:x0 + self.im_size[1]]
            inner1 -= np.array([x0, y0]).reshape(1, 2)
            inner2 -= np.array([x0, y0]).reshape(1, 2)
            bound1 -= np.array([x0, y0]).reshape(1, 2)
            bound2 -= np.array([x0, y0]).reshape(1, 2)
            lines1[0] -= np.array([x0, y0]).reshape(1, 2)
            lines2[0] -= np.array([x0, y0]).reshape(1, 2)
        else:
            if h < w:
                image1 = pad_image(image1, (0, w - h, 0, 0))
                image2 = pad_image(image2, (0, w - h, 0, 0))
            else:
                image1 = pad_image(image1, (0, 0, 0, h - w))
                image2 = pad_image(image2, (0, 0, 0, h - w))
            
            inner_len1 = len(inner1)
            bound_len1 = len(bound1)
            points_len = np.array([inner_len1, bound_len1, *lines_len])
            points1 = np.concatenate([inner1, bound1, lines1], axis=0)
            data1 = self.resize(image=image1, keypoints=points1)
            image1, points1 =  data1['image'], np.array(data1['keypoints'])
            
            inner_len2 = len(inner2)
            
            bound_len2 = len(bound2)
            assert inner_len2 == inner_len1
            assert bound_len2 == bound_len1
            
            points2 = np.concatenate([inner2, bound2, lines2], axis=0)
            data2 = self.resize(image=image1, keypoints=points1)
            image2, points2 =  data2['image'], np.array(data2['keypoints'])

        if self.train:
            if np.random.rand() < 0.5:
                image1 = image1[:, ::-1].copy()
                image2 = image2[:, ::-1].copy()
                points1[:, 0] = self.im_size[1] - points1[:, 0]
                points2[:, 0] = self.im_size[1] - points2[:, 0]
                

            if np.random.rand() < 0.5:
                image1 = image1[::-1].copy()
                image2 = image2[::-1].copy()
                points1[:, 1] = self.im_size[0] - points1[:, 1]
                points2[:, 1] = self.im_size[0] - points2[:, 1]
             
        image1 = self.to_tensor(image1).float()
        image2 = self.to_tensor(image2).float()
        
        # points1[:, 0] /= self.im_size[1]
        # points1[:, 1] /= self.im_size[0]
        # points2[:, 0] /= self.im_size[1]
        # points2[:, 1] /= self.im_size[0]
        
        print(image1.max(), image1.min())
        
        return image1, image2, points1, points2, points_len


In [10]:
dataset = Dataset(['./data/SeqB/SeqB1.tif'], ['./data/SeqB/SeqB1.mat'], train=True)

image1, image2, points1, points2, total_len = dataset[0]

print(points1.shape, points2.shape)

image1 = image1.numpy().squeeze()
image2 = image2.numpy().squeeze()
points2[:, 0] += dataset.im_size[0]


plt.figure(figsize=(20, 10))
plt.imshow(np.concatenate([image1, image2], 1))
plt.scatter(points1[:, 0], points1[:, 1])
plt.scatter(points2[:, 0], points2[:, 1])



len(seq) 42
uint8 173
Dataset length is  42
(356, 287) (356, 287)


AssertionError: 

In [None]:
if __name__ == '__main__':
    path = '/data/sim/Notebooks/VM/dataset/train_set.pkl'
    dataset = Dataset(path, (1, 256, 256), size_file='../src_old/sizes.txt',
                      smooth=True, train=True, shuffle=True, use_masks=True, use_mul=False)
    fixed, moving, deform = dataset[0]
#    fixed = fixed[0][None]
#    moving = moving[0][None]
    # deform *= 10.
    print(deform.min(), deform.max())
    from voxelmorph2d import SpatialTransformation

    SP = SpatialTransformation()
    print(deform.shape, moving.shape)
#    plt.imshow(np.concatenate([deform[0], deform[1]], axis=1), cmap='gray')
#    plt.waitforbuttonpress()
    cv2.imwrite('test0.jpg', np.concatenate([deform[0], deform[1]], axis=1))
    movingN = SP(moving[None, :, :, :], deform[None])
    movingN = np.uint8(movingN.numpy() * 255).squeeze()
    print(movingN.shape)
    fixed = np.uint8(fixed.numpy().squeeze() * 255)
    moving = np.uint8(moving.numpy().squeeze() * 255)

    print(fixed.shape, moving.shape, deform.shape)
    print(fixed.max())
#    plt.imshow(np.stack([fixed, moving, np.zeros(fixed.shape, dtype='int')], axis=-1), cmap='gray')
#    plt.waitforbuttonpress()
    cv2.imwrite('test1.jpg', np.stack([fixed[0], moving[0], np.zeros(fixed[0].shape, dtype='int')], axis=-1))
    cv2.imwrite('masktest1.jpg', np.stack([fixed[1], moving[1], np.zeros(fixed[1].shape, dtype='int')], axis=-1))

#    plt.figure()
#    plt.imshow(np.stack([fixed, movingN, np.zeros(fixed.shape, dtype='int')], axis=-1), cmap='gray')
#    plt.waitforbuttonpress()
#    plt.close()

    cv2.imwrite('test2.jpg', np.stack([fixed[0], movingN[0], np.zeros(fixed[0].shape, dtype='int')], axis=-1))
    cv2.imwrite('masktest2.jpg', np.stack([fixed[1], movingN[1], np.zeros(fixed[1].shape, dtype='int')], axis=-1))