In [6]:
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.decoders.deeplabv3.decoder import DeepLabV3PlusDecoder
import torch

In [26]:
small_encoder = smp.encoders.get_encoder("resnet18", in_channels=12, output_stride=16)
big_encoder = smp.encoders.get_encoder("resnet18", in_channels=12, output_stride=16)

decoder = DeepLabV3PlusDecoder(
    encoder_channels=[c*2 for c in small_encoder.out_channels],
    out_channels=256,
    atrous_rates=(12, 24, 36),
    output_stride=16,
)
head = smp.base.SegmentationHead(
    in_channels=decoder.out_channels,
    out_channels=2,
    kernel_size=1,
    activation=None,
    upsampling=4,
)

In [27]:
# Create random input with 12 channels and batch size 2 and image size 512x512
x = torch.rand(2, 12, 512, 512)
# Forward pass
out_0, out_1 = small_encoder(x), big_encoder(x)
out = [torch.cat(o, dim=1) for o in zip(out_0, out_1)]
out = decoder(*out)
out = head(out)

In [25]:
small_encoder.out_channels*2

(12, 64, 64, 128, 256, 512, 12, 64, 64, 128, 256, 512)

In [1]:
from lightning_modules.california_datamodule import BurnedOnlyDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = BurnedOnlyDataset(
    "../california_burned_areas/only_burned/burned_512x512.hdf5",
    "post",
    set(range(5)),
    [],
)

In [2]:
import h5py
import hdf5plugin
from collections import Counter

folds = []
with h5py.File("data/512x512.hdf5", "r") as f:
    for id, values in f.items():
        folds += [(values.attrs["fold"], len(values))]
sum(Counter(folds).values())

534

In [4]:
import numpy as np

(np.random.randint((2, 12, 512, 512)) / 10000).dtype

dtype('float64')

In [3]:
import h5py
import numpy as np

post = np.empty((512, 512, 12), dtype=np.int32)
extras = {}
with h5py.File("data/512x512.hdf5", "r") as f:
    for id, values in f.items():
        if 6 not in values.attrs["comments"]:
            values["post_fire"].read_direct(post)
            extras[id] = (post > 10000).sum()

In [4]:
# Get max of extras by value
max(extras, key=extras.get)

'd22e40c3-da6f-4202-bb50-02d303839d87_0'

In [None]:
with h5py.File("data/512x512.hdf5", "r") as f:
    f["d22e40c3-da6f-4202-bb50-02d303839d87_0"]["post_fire"].read_direct(post)

In [None]:
# Plot post as histogram
import matplotlib.pyplot as plt

plt.hist(post.flatten(), bins=100)

In [1]:
import h5py

In [19]:
with h5py.File("data/europe.hdf5") as f:
   l = f["coral/0/mask"][...]
l.min()

0.0

In [13]:
with h5py.File("data/512x512.hdf5") as f:
    post_c = f["012b8863-976c-44e1-a491-9adf19c1cbba_0/post_fire"][...]

In [15]:
post_c.shape, post.shape

((512, 512, 12), (512, 512, 12))