/
ssl_transforms.py
73 lines (53 loc) · 2.62 KB
/
ssl_transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import numpy as np
import torch.nn.functional as F
from PIL import Image
class RandomTranslateWithReflect:
    """Randomly shift a PIL image, filling the exposed border by reflection.

    A horizontal and a vertical offset are drawn independently and
    uniformly from the integers in ``[-max_translation, max_translation]``.
    The image is surrounded by mirrored copies of itself, and a window of
    the original width and height is cropped out of that mosaic at the
    drawn offset. A positive offset moves the content right / down.

    The working canvas is always created in ``"RGB"`` mode, so the result
    is an RGB image regardless of the input's mode.
    """

    def __init__(self, max_translation):
        # Largest shift, in pixels, allowed along each axis.
        self.max_translation = max_translation

    def __call__(self, old_image):
        """Return a randomly translated copy of ``old_image``.

        ``old_image`` must provide ``size`` and ``transpose`` and be
        pasteable into a PIL image. The returned image has the same
        width and height as the input.
        """
        limit = self.max_translation
        # One draw of two values: index 0 is the x shift, index 1 the y shift.
        shift_x, shift_y = np.random.randint(-limit, limit + 1, size=2)
        pad_x = abs(shift_x)
        pad_y = abs(shift_y)
        width, height = old_image.size

        mirrored_h = old_image.transpose(Image.FLIP_LEFT_RIGHT)
        mirrored_v = old_image.transpose(Image.FLIP_TOP_BOTTOM)
        mirrored_hv = old_image.transpose(Image.ROTATE_180)

        # Mirrored tiles are offset by (size - 1) rather than size, so each
        # one overlaps its neighbour by a single pixel and the border pixel
        # is not duplicated.
        left = pad_x - width + 1
        right = pad_x + width - 1
        top = pad_y - height + 1
        bottom = pad_y + height - 1

        canvas = Image.new("RGB", (width + 2 * pad_x, height + 2 * pad_y))
        tiles = (
            (old_image, (pad_x, pad_y)),
            (mirrored_h, (right, pad_y)),
            (mirrored_h, (left, pad_y)),
            (mirrored_v, (pad_x, bottom)),
            (mirrored_v, (pad_x, top)),
            (mirrored_hv, (left, top)),
            (mirrored_hv, (right, top)),
            (mirrored_hv, (left, bottom)),
            (mirrored_hv, (right, bottom)),
        )
        for tile, corner in tiles:
            canvas.paste(tile, corner)

        # Crop a window of the original size, displaced opposite to the shift.
        crop_left = pad_x - shift_x
        crop_top = pad_y - shift_y
        return canvas.crop((crop_left,
                            crop_top,
                            crop_left + width,
                            crop_top + height))
class Patchify:
    """Cut an image tensor into a stack of square patches.

    Args:
        patch_size: side length, in pixels, of each square patch.
        overlap_size: number of pixels shared by neighbouring patches.
            Must be smaller than ``patch_size``.

    Raises:
        ValueError: if ``overlap_size >= patch_size``, because the
            resulting stride would be zero or negative.
    """

    def __init__(self, patch_size, overlap_size):
        if overlap_size >= patch_size:
            raise ValueError(
                "overlap_size must be smaller than patch_size, got "
                "patch_size={} and overlap_size={}".format(patch_size,
                                                           overlap_size)
            )
        self.patch_size = patch_size
        # NOTE: despite its name, this attribute stores the stride between
        # consecutive patches (patch_size - overlap). The name is kept so
        # existing code that reads the attribute keeps working.
        self.overlap_size = self.patch_size - overlap_size

    def __call__(self, x):
        """Split ``x`` into patches.

        Args:
            x: a 3-D tensor, presumably laid out as (channels, height,
                width); after a batch axis is added its size must unpack
                into exactly four values.

        Returns:
            A tensor of shape (nb_patches, c, patch_size, patch_size).
            When exactly one patch is produced, the leading axis is
            squeezed away and the shape is (c, patch_size, patch_size).
        """
        # Add a batch axis in front of the image.
        x = x.unsqueeze(0)
        # Unpacking four values also rejects inputs of the wrong rank.
        b, c, _, _ = x.size()

        # Extract the sliding patches:
        # (b, c, h, w) -> (b, c * patch_size * patch_size, L)
        x = F.unfold(x, kernel_size=self.patch_size, stride=self.overlap_size)

        # (b, c * patch_size * patch_size, L) -> (b, L * c, patch_size, patch_size)
        x = x.transpose(2, 1).contiguous().view(b, -1, self.patch_size, self.patch_size)

        # Regroup channels per patch: (b * L, c, patch_size, patch_size)
        x = x.view(-1, c, self.patch_size, self.patch_size)

        # The batch axis was merged with the patch axis above, so this only
        # has an effect when there is a single patch. Kept for compatibility.
        x = x.squeeze(0)

        return x