-
Notifications
You must be signed in to change notification settings - Fork 16
/
data.py
91 lines (65 loc) · 2.29 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Sat May 6 19:42:16 2017
@author: ldy
"""
from os.path import exists, join, basename
from os import makedirs, remove
from six.moves import urllib
import tarfile
from torchvision.transforms import Compose, CenterCrop, ToTensor, Scale, RandomCrop, RandomHorizontalFlip
from dataset import DatasetFromFolder
# Side length of the square high-resolution crop taken by HR_8_transform; the
# lower-resolution pyramid levels are derived from it by integer division
# (crop_size // 2, // 4, // 8) in the transform builders.
crop_size = 128
def download_bsd300(dest="dataset"):
    """Download and extract the BSDS300 image archive unless already present.

    Args:
        dest: Directory that receives the extracted ``BSDS300`` folder. It is
            created if it does not exist yet; an existing directory is reused.

    Returns:
        Path to the ``BSDS300/images`` directory inside ``dest``.

    Raises:
        ValueError: if an archive member would be written outside ``dest``.
        Network and I/O errors from the download or extraction propagate.
    """
    # Function-scope imports keep the module's top-level import block untouched.
    import os
    import shutil
    from contextlib import closing

    output_image_dir = join(dest, "BSDS300/images")
    if exists(output_image_dir):
        # NOTE(review): presence of the directory is the only completeness
        # check; an extraction interrupted after this directory was created
        # will not be retried.
        return output_image_dir

    # Previously makedirs(dest) ran unconditionally and raised OSError when
    # `dest` already existed without the images folder. `exist_ok` is
    # Python-3-only, so check explicitly to stay Python-2 compatible.
    if not exists(dest):
        makedirs(dest)

    url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
    print("downloading url ", url)
    file_path = join(dest, basename(url))
    root = os.path.abspath(dest)
    try:
        # closing() releases the HTTP connection on both py2 and py3 response
        # objects; copyfileobj streams to disk instead of buffering it all.
        with closing(urllib.request.urlopen(url)) as response:
            with open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)

        print("Extracting data")
        with tarfile.open(file_path) as tar:
            for item in tar:
                # The archive arrives over unauthenticated HTTP: refuse members
                # such as "../x" or absolute paths that would escape `dest`.
                target = os.path.abspath(join(root, item.name))
                if target != root and not target.startswith(root + os.sep):
                    raise ValueError(
                        "Refusing to extract unsafe archive member: %r" % item.name)
                tar.extract(item, dest)
    finally:
        # Never leave the (possibly partial) archive behind.
        if exists(file_path):
            remove(file_path)
    return output_image_dir
def LR_transform(crop_size):
    """Build the transform for the coarsest (1/8-scale) input image.

    Args:
        crop_size: Side length of the full-resolution crop.

    Returns:
        A ``Compose`` that applies ``Scale(crop_size // 8)`` and then
        ``ToTensor()``.

    NOTE(review): ``Scale`` was renamed ``Resize`` in later torchvision
    releases; confirm the pinned torchvision version still provides it.
    """
    lr_size = crop_size // 8
    pipeline = [Scale(lr_size), ToTensor()]
    return Compose(pipeline)
def HR_2_transform(crop_size):
    """Build the transform for the 2x target (1/4 of the full crop size).

    Args:
        crop_size: Side length of the full-resolution crop.

    Returns:
        A ``Compose`` that applies ``Scale(crop_size // 4)`` and then
        ``ToTensor()``.
    """
    side = crop_size // 4
    resize = Scale(side)
    return Compose([resize, ToTensor()])
def HR_4_transform(crop_size):
    """Build the transform for the 4x target (1/2 of the full crop size).

    Args:
        crop_size: Side length of the full-resolution crop.

    Returns:
        A ``Compose`` that applies ``Scale(crop_size // 2)`` and then
        ``ToTensor()``.
    """
    half_size = crop_size // 2
    steps = [Scale(half_size)]
    steps.append(ToTensor())
    return Compose(steps)
def HR_8_transform(crop_size):
    """Build the full-resolution (8x) augmentation: random square crop + flip.

    Args:
        crop_size: Side length of the square crop.

    Returns:
        A ``Compose`` of ``RandomCrop((crop_size, crop_size))`` followed by
        ``RandomHorizontalFlip()``. Unlike the other builders it does not
        append ``ToTensor()`` — presumably DatasetFromFolder handles that;
        TODO confirm.
    """
    crop_shape = (crop_size, crop_size)
    augmentations = [RandomCrop(crop_shape), RandomHorizontalFlip()]
    return Compose(augmentations)
def _build_dataset(split):
    """Return a DatasetFromFolder over the ``split`` subdirectory of BSDS300/images.

    Calls ``download_bsd300()`` (which fetches the data on first use) and
    attaches the four pyramid transforms built from the module-level
    ``crop_size``.
    """
    split_dir = join(download_bsd300(), split)
    return DatasetFromFolder(split_dir,
                             LR_transform=LR_transform(crop_size),
                             HR_2_transform=HR_2_transform(crop_size),
                             HR_4_transform=HR_4_transform(crop_size),
                             HR_8_transform=HR_8_transform(crop_size))


def get_training_set():
    """Return the BSDS300 ``train`` split with the pyramid transforms."""
    return _build_dataset("train")


def get_test_set():
    """Return the BSDS300 ``test`` split with the pyramid transforms.

    NOTE(review): this reuses HR_8_transform's RandomCrop/RandomHorizontalFlip,
    so test samples differ between runs. ``CenterCrop`` is imported by this
    module but never used — confirm whether a deterministic test transform
    was intended.
    """
    return _build_dataset("test")