/
pytorch_dataset_example_with_masks.py
55 lines (41 loc) · 1.57 KB
/
pytorch_dataset_example_with_masks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
import torch.utils.data
from allegroai import DataView, FrameGroup, Task
from PIL import Image
from torch.utils.data import DataLoader
class ExampleDataset(torch.utils.data.Dataset):
def __init__(self, dv):
# automatically adjust dataset to balance all queries
self.frames = dv.to_list()
def __getitem__(self, idx):
frame_group = self.frames[idx] # type: FrameGroup
img_path = frame_group["000002"].get_local_source()
img = Image.open(img_path).convert("RGB").resize((256, 256))
mask_path = frame_group["000002"].get_local_mask_source()
mask = Image.open(mask_path).resize((256, 256))
return np.array(img), np.array(mask),
def __len__(self):
return len(self.frames)
task = Task.init(project_name='examples', task_name='PyTorch Sample Dataset with Masks')
# Create DataView with example query
dataview = DataView()
dataview.add_query(dataset_name='sample-dataset-masks', version_name='Current')
# dataview.add_query(dataset_name='sample-dataset', version_name='Current', roi_query=["aeroplane"])
# if we want all files to be downloaded in the background, we can call prefetch
# dataview.prefetch_files()
# create PyTorch Dataset
dataset = ExampleDataset(dataview)
# do your thing here :)
print('Fake PyTorch stuff below:')
print('Dataset length', len(dataset))
torch.manual_seed(0)
data_loader = DataLoader(
dataset,
batch_size=2,
num_workers=1,
pin_memory=True,
prefetch_factor=2,
)
for i, data in enumerate(data_loader):
print('{}] {}'.format(i, data))
print('done')