In [30]:
import glob
import itertools
import os
import shutil
from collections.abc import Generator

import numpy as np
import rasterio
import rasterio.windows

In [66]:
# Root folder of the dataset. Set the AIRS_ROOT environment variable to point
# the notebook at another location; without it the original local path is used.
root_path = os.environ.get("AIRS_ROOT", r"C:\Users\Dimit\Downloads\AIRS")

# Folder that collects the image tiles copied out of the individual splits.
image_dirname = "image"
image_dirpath = os.path.join(root_path, image_dirname)

# Folder that collects the mask (label) tiles copied out of the individual splits.
mask_dirname = "label"
mask_dirpath = os.path.join(root_path, mask_dirname)

In [67]:
# Create the image folder first, then the label folder; existing folders are accepted.
for dirpath in (image_dirpath, mask_dirpath):
    os.makedirs(dirpath, exist_ok=True)

In [18]:
# Flatten the train/val/test splits into the shared image and label folders.
for split in ["train", "val", "test"]:
    image_paths = glob.glob(os.path.join(root_path, split, "image", "*.tif"))
    for path in image_paths:
        shutil.copy2(src=path, dst=image_dirpath)

    # Leave out files whose name ends in "_vis.tif". A glob character class such
    # as "[!_vis]" only tests ONE character, so it would also drop every file
    # whose stem ends in "_", "v", "i" or "s"; filter on the full suffix instead.
    mask_paths = [
        path
        for path in glob.glob(os.path.join(root_path, split, "label", "*.tif"))
        if not path.lower().endswith("_vis.tif")
    ]
    for path in mask_paths:
        shutil.copy2(src=path, dst=mask_dirpath)

In [31]:
def get_chips(src: rasterio.io.DatasetReader, width: int = 512, height: int = 512) -> Generator[
    tuple[rasterio.windows.Window, "affine.Affine"], None, None]:
    """Yield ``(window, transform)`` pairs that tile ``src`` on a regular grid.

    Windows start at multiples of ``width`` / ``height`` and are clipped to the
    raster extent, so windows on the right and bottom edges may be smaller than
    the requested size. Column offsets form the outer loop, row offsets the
    inner loop.

    Args:
        src: Open dataset; its ``width``, ``height`` and ``transform`` are read.
        width: Column step and requested window width, in pixels.
        height: Row step and requested window height, in pixels.

    Yields:
        The clipped window and the transform computed for it from
        ``src.transform`` via ``rasterio.windows.transform``.
    """
    total_cols = src.width
    total_rows = src.height

    # Window covering the whole raster; used to clip the grid windows.
    full_extent = rasterio.windows.Window(col_off=0, row_off=0, width=total_cols, height=total_rows)

    for col_start in range(0, total_cols, width):
        for row_start in range(0, total_rows, height):
            requested = rasterio.windows.Window(col_off=col_start, row_off=row_start, width=width, height=height)
            clipped = requested.intersection(full_extent)
            yield clipped, rasterio.windows.transform(clipped, src.transform)

In [68]:
# Chips are written to a "chips" sub-folder next to the full-size tiles.
image_chip_dirpath, mask_chip_dirpath = (
    os.path.join(parent_dirpath, "chips") for parent_dirpath in (image_dirpath, mask_dirpath)
)

In [69]:
# Create the image chip folder first, then the mask chip folder; existing folders are accepted.
for dirpath in (image_chip_dirpath, mask_chip_dirpath):
    os.makedirs(dirpath, exist_ok=True)

In [70]:
# Sorted listings of the .tif tiles sitting directly in each folder
# (the pattern is not recursive, so the "chips" sub-folders are not scanned).
image_paths = sorted(glob.glob(os.path.join(image_dirpath, "*.tif")))

mask_paths = sorted(glob.glob(os.path.join(mask_dirpath, "*.tif")))

In [71]:
# Collect the file names of mask tiles whose first band holds no non-zero pixel.
bad_tiles = set()
for i, src_path in enumerate(mask_paths):
    print(f"{i}/{len(mask_paths)}")
    src: rasterio.io.DatasetReader
    with rasterio.open(src_path) as src:
        src_data = src.read(1)  # band 1 only
    if src_data.any():
        # At least one non-zero pixel: the tile is kept.
        continue
    bad_tiles.add(os.path.basename(src_path))

0/1046
1/1046
2/1046
3/1046
4/1046
5/1046
6/1046
7/1046
8/1046
9/1046
10/1046
11/1046
12/1046
13/1046
14/1046
15/1046
16/1046
17/1046
18/1046
19/1046
20/1046
21/1046
22/1046
23/1046
24/1046
25/1046
26/1046
27/1046
28/1046
29/1046
30/1046
31/1046
32/1046
33/1046
34/1046
35/1046
36/1046
37/1046
38/1046
39/1046
40/1046
41/1046
42/1046
43/1046
44/1046
45/1046
46/1046
47/1046
48/1046
49/1046
50/1046
51/1046
52/1046
53/1046
54/1046
55/1046
56/1046
57/1046
58/1046
59/1046
60/1046
61/1046
62/1046
63/1046
64/1046
65/1046
66/1046
67/1046
68/1046
69/1046
70/1046
71/1046
72/1046
73/1046
74/1046
75/1046
76/1046
77/1046
78/1046
79/1046
80/1046
81/1046
82/1046
83/1046
84/1046
85/1046
86/1046
87/1046
88/1046
89/1046
90/1046
91/1046
92/1046
93/1046
94/1046
95/1046
96/1046
97/1046
98/1046
99/1046
100/1046
101/1046
102/1046
103/1046
104/1046
105/1046
106/1046
107/1046
108/1046
109/1046
110/1046
111/1046
112/1046
113/1046
114/1046
115/1046
116/1046
117/1046
118/1046
119/1046
120/1046
121/1046
122/1046
123

In [72]:
# Cut every mask tile that is not in bad_tiles into chips and write them to
# mask_chip_dirpath. Only full 512x512 chips are considered; a chip whose
# non-zero count is below 20% of its area is not written and its file name is
# recorded in bad_chips instead.
bad_chips = set()
for i, src_path in enumerate(mask_paths):
    print(f"{i}/{len(mask_paths)}")
    tile_name = os.path.basename(src_path)
    if tile_name in bad_tiles:
        continue
    tile_stem = tile_name.removesuffix(".tif")
    src: rasterio.io.DatasetReader
    with rasterio.open(src_path) as src:
        meta: rasterio.profiles.Profile = src.meta.copy()
        for chip, transform in get_chips(src):
            # Clipped edge chips are smaller than 512 in at least one dimension.
            if chip.width != 512 or chip.height != 512:
                continue
            src_data = src.read(window=chip)
            # Name is "<tile stem>_<column index>-<row index>.tif".
            dst_name = f"{tile_stem}_{chip.col_off // chip.width}-{chip.row_off // chip.height}.tif"
            if np.count_nonzero(src_data) < 0.2 * chip.width * chip.height:
                bad_chips.add(dst_name)
                continue
            # Reuse the tile profile, overriding size and georeferencing for this chip.
            meta.update(width=chip.width, height=chip.height, transform=transform)
            dst_path = os.path.join(mask_chip_dirpath, dst_name)
            dst: rasterio.io.DatasetWriter
            with rasterio.open(dst_path, mode="w", **meta) as dst:
                dst.write(src_data)

0/1046
1/1046
2/1046
3/1046
4/1046
5/1046
6/1046
7/1046
8/1046
9/1046
10/1046
11/1046
12/1046
13/1046
14/1046
15/1046
16/1046
17/1046
18/1046
19/1046
20/1046
21/1046
22/1046
23/1046
24/1046
25/1046
26/1046
27/1046
28/1046
29/1046
30/1046
31/1046
32/1046
33/1046
34/1046
35/1046
36/1046
37/1046
38/1046
39/1046
40/1046
41/1046
42/1046
43/1046
44/1046
45/1046
46/1046
47/1046
48/1046
49/1046
50/1046
51/1046
52/1046
53/1046
54/1046
55/1046
56/1046
57/1046
58/1046
59/1046
60/1046
61/1046
62/1046
63/1046
64/1046
65/1046
66/1046
67/1046
68/1046
69/1046
70/1046
71/1046
72/1046
73/1046
74/1046
75/1046
76/1046
77/1046
78/1046
79/1046
80/1046
81/1046
82/1046
83/1046
84/1046
85/1046
86/1046
87/1046
88/1046
89/1046
90/1046
91/1046
92/1046
93/1046
94/1046
95/1046
96/1046
97/1046
98/1046
99/1046
100/1046
101/1046
102/1046
103/1046
104/1046
105/1046
106/1046
107/1046
108/1046
109/1046
110/1046
111/1046
112/1046
113/1046
114/1046
115/1046
116/1046
117/1046
118/1046
119/1046
120/1046
121/1046
122/1046
123

In [74]:
# Cut every image tile whose file name is not in bad_tiles into chips and write
# them to image_chip_dirpath. Only full 512x512 chips are considered, and chips
# whose file name was recorded in bad_chips are skipped.
for i, src_path in enumerate(image_paths):
    print(f"{i}/{len(image_paths)}")
    if os.path.basename(src_path) in bad_tiles:
        continue
    src: rasterio.io.DatasetReader
    with rasterio.open(src_path) as src:
        meta: rasterio.profiles.Profile = src.meta.copy()
        for chip, transform in get_chips(src):
            # Clipped edge chips are smaller than 512 in at least one dimension.
            if not (chip.width == chip.height==512):
                continue
            # Name is "<tile stem>_<column index>-<row index>.tif".
            dst_name = os.path.basename(src_path).removesuffix(".tif") + f"_{chip.col_off // chip.width}-{chip.row_off // chip.height}" + ".tif"
            if dst_name in bad_chips:
                continue
            # Read the pixels only once the chip is known to be kept, so
            # rejected chips cost no raster I/O.
            src_data = src.read(window=chip)
            # Reuse the tile profile, overriding size and georeferencing for this chip.
            meta.update(width=chip.width, height=chip.height, transform=transform)
            dst_path = os.path.join(image_chip_dirpath, dst_name)
            dst: rasterio.io.DatasetWriter
            with rasterio.open(dst_path, mode="w", **meta) as dst:
                dst.write(src_data)

0/1046
1/1046
2/1046
3/1046
4/1046
5/1046
6/1046
7/1046
8/1046
9/1046
10/1046
11/1046
12/1046
13/1046
14/1046
15/1046
16/1046
17/1046
18/1046
19/1046
20/1046
21/1046
22/1046
23/1046
24/1046
25/1046
26/1046
27/1046
28/1046
29/1046
30/1046
31/1046
32/1046
33/1046
34/1046
35/1046
36/1046
37/1046
38/1046
39/1046
40/1046
41/1046
42/1046
43/1046
44/1046
45/1046
46/1046
47/1046
48/1046
49/1046
50/1046
51/1046
52/1046
53/1046
54/1046
55/1046
56/1046
57/1046
58/1046
59/1046
60/1046
61/1046
62/1046
63/1046
64/1046
65/1046
66/1046
67/1046
68/1046
69/1046
70/1046
71/1046
72/1046
73/1046
74/1046
75/1046
76/1046
77/1046
78/1046
79/1046
80/1046
81/1046
82/1046
83/1046
84/1046
85/1046
86/1046
87/1046
88/1046
89/1046
90/1046
91/1046
92/1046
93/1046
94/1046
95/1046
96/1046
97/1046
98/1046
99/1046
100/1046
101/1046
102/1046
103/1046
104/1046
105/1046
106/1046
107/1046
108/1046
109/1046
110/1046
111/1046
112/1046
113/1046
114/1046
115/1046
116/1046
117/1046
118/1046
119/1046
120/1046
121/1046
122/1046
123

In [75]:
# Sanity check: the image and mask chip folders must hold the same file names.
# The glob pattern is "*.tif" (".*tif" is regex syntax and, as a glob, matches
# only dot-files, which makes both sides empty and the check trivially True).
# Both listings are sorted because glob does not promise any particular order.
sorted(os.path.basename(path) for path in glob.glob(os.path.join(image_chip_dirpath, "*.tif"))) == sorted(
    os.path.basename(path) for path in glob.glob(os.path.join(mask_chip_dirpath, "*.tif")))

True