In [4]:
import os
import sys
# NOTE(review): hard-coded absolute path to one developer's local checkout; it will not
# exist on other machines. Replace with a path that is valid for your environment.
parent_dir = "/Users/aradinka/Documents/GitHub/koltiva/SSLTransformerRS"
# Add the directory to the import search path once (presumably so that the `utils` and
# `Transformer_SSL` imports below resolve — confirm against the repository layout).
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import numpy as np
import rasterio
from rasterio.windows import Window
from enum import Enum
import albumentations as A
from utils import AlbumentationsToTorchTransform
from albumentations.pytorch import ToTensorV2

import json
import torch
from utils import dotdictify
from Transformer_SSL.models import build_model

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
class S1Bands(Enum):
    """Band selectors for the ``s1`` sensor.

    ``get_patch`` passes a member's ``value`` to ``rasterio``'s ``read``, so the
    integers defined here are used as band indexes of the raster file.
    """
    VV = 1
    VH = 2
    # Inside the class body VV and VH are still plain ints, so this value is [1, 2].
    ALL = [VV, VH]
    # The value is None, but the member itself is truthy (plain Enum members always are).
    NONE = None

class Sensor(Enum):
    """Identifiers of the data sources; each member's value is its own lower-case name."""
    s1 = "s1"
    s2 = "s2"
    lc = "lc"
    dfc = "dfc"

class S2Bands(Enum):
    """Band selectors for the ``s2`` sensor.

    Each chained assignment binds two names to the same integer; ``Enum`` keeps the
    first name (``B01``, ``B02``, ...) as the member and treats the second one
    (``aerosol``, ``blue``, ...) as an alias of it. ``get_patch`` passes the values
    to ``rasterio``'s ``read`` as band indexes.
    """
    B01 = aerosol = 1
    B02 = blue = 2
    B03 = green = 3
    B04 = red = 4
    B05 = re1 = 5
    B06 = re2 = 6
    B07 = re3 = 7
    B08 = nir1 = 8
    B08A = nir2 = 9
    B09 = vapor = 10
    B10 = cirrus = 11
    B11 = swir1 = 12
    B12 = swir2 = 13
    # The names below are still plain ints inside the class body: ALL is [1, ..., 13].
    ALL = [B01, B02, B03, B04, B05, B06, B07, B08, B08A, B09, B10, B11, B12]
    # Red, green, blue in that order, i.e. [4, 3, 2].
    RGB = [B04, B03, B02]
    # The value is None, but the member itself is truthy (plain Enum members always are).
    NONE = None

class LCBands(Enum):
    """Selectors for the ``lc`` / ``dfc`` label rasters.

    ``LC``/``lc`` and ``DFC``/``dfc`` are member/alias pairs (same value, first name
    is the member).
    """
    LC = lc = 0
    DFC = dfc = 1
    # Only DFC is listed: the value is [1]; LC (0) is not part of ALL.
    ALL = [DFC]
    # The value is None, but the member itself is truthy (plain Enum members always are).
    NONE = None

In [6]:
def get_patch(patch_path, bands, window=None):
    """
        Read the requested bands of one raster patch.

        Only a single patch from a single sensor is loaded; the band enum type of
        `bands` decides how the band indexes are derived.

        Args:
            patch_path: path of the raster file, opened with `rasterio.open`.
            bands: one `S1Bands` / `S2Bands` / `LCBands` member, or a list/tuple of
                such members. For a list/tuple the first element decides the type.
            window: optional `rasterio.windows.Window`; when given, only that
                window of the raster is read.

        Returns:
            `(data, bounds)` where `data` always has a leading band axis (a
            two-dimensional read is expanded to three dimensions) and `bounds`
            is the opened dataset's `bounds`.
            `(None, None)` when no bands are requested, i.e. `bands` is falsy or
            is a `NONE` member.

        Raises:
            ValueError: if `bands` is not built from one of the band enums.
    """
    if not bands:
        return None, None

    is_sequence = isinstance(bands, (list, tuple))
    first = bands[0] if is_sequence else bands

    if not isinstance(first, (S1Bands, S2Bands, LCBands)):
        raise ValueError("Invalid bands specified")

    # Enum members are always truthy, so a NONE member is not caught by the
    # `not bands` guard above; without this check `read(None)` would load every band.
    if not is_sequence and first.value is None:
        return None, None

    if isinstance(first, LCBands):
        # As before, any LC/DFC request reads raster band 1, whichever member
        # was passed. This now also works for a list/tuple of LCBands members,
        # which previously raised inside `LCBands(bands)`.
        band_indexes = LCBands.DFC.value
    elif is_sequence:
        band_indexes = [band.value for band in bands]
    else:
        # A single member: its value is either one index or a list of indexes (ALL, RGB).
        band_indexes = first.value

    with rasterio.open(patch_path) as patch:
        # window=None means "read the full raster", so no branching is needed.
        data = patch.read(band_indexes, window=window)
        bounds = patch.bounds

    # Note: no IGBP -> DFC remapping of the label values is applied here.

    # A scalar band index yields a 2-D array; add the band axis back.
    if data.ndim == 2:
        data = np.expand_dims(data, axis=0)

    return data, bounds

In [7]:
class DoubleSwinTransformerClassifier(torch.nn.Module):
    """Linear classifier on top of two feature backbones.

    The features returned by both backbones are concatenated along dimension 1
    and mapped to `out_dim` outputs by a single linear layer.

    Args:
        encoder1: backbone applied to `x["s1"]`; must expose `num_features`
            and `forward_features`.
        encoder2: backbone applied to `x["s2"]`; same requirements.
        out_dim: number of outputs of the final linear layer.
        device: device the inputs are moved to in `forward`.
        freeze_layers: when True, every parameter except the final layer's
            weight and bias gets `requires_grad = False`.
    """

    def __init__(self, encoder1, encoder2, out_dim, device, freeze_layers=True):
        super().__init__()

        # To use a single backbone, drop the one that is not needed here and in forward().
        self.backbone1 = encoder1
        self.backbone2 = encoder2
        self.device = device

        # Classification head over the concatenated backbone features.
        feature_dim = self.backbone2.num_features + self.backbone1.num_features
        self.fc = torch.nn.Linear(feature_dim, out_dim, bias=True)

        if freeze_layers:
            # Only the head stays trainable.
            trainable = ("fc.weight", "fc.bias")
            for param_name, parameter in self.named_parameters():
                if param_name in trainable:
                    continue
                parameter.requires_grad = False

    def forward(self, x):
        """Return the head's output for a dict with `"s1"` and `"s2"` tensors."""
        s1_input = x["s1"].to(self.device)
        s2_input = x["s2"].to(self.device)

        # forward_features returns three values; only the first one is used.
        s1_features, _, _ = self.backbone1.forward_features(s1_input)
        s2_features, _, _ = self.backbone2.forward_features(s2_input)

        # Single-backbone variant: feed s1_features (or s2_features) straight into self.fc.
        return self.fc(torch.cat((s1_features, s2_features), dim=1))

In [14]:
# Class index -> human-readable class name; index 255 is labelled "Invalid".
DFC_map_clean = {
    0: "Forest",
    1: "Shrubland",
    2: "Grassland",
    3: "Wetlands",
    4: "Croplands",
    5: "Urban/Built-up",
    6: "Barren",
    7: "Water",
    255: "Invalid",
}

# Dataset settings; the `val_*` entries are passed to DFCDataset further down.
data_config = {
    'train_dir': '../data/data_disini', # path to the training directory
    'val_dir': '../data/data_disini', # path to the validation directory
    'train_mode': 'validation', # one of: 'test', 'validation'
    'val_mode': 'test', # one of: 'test', 'validation'
    'num_classes': 8, # number of classes in the dataset
    'clip_sample_values': True, # clip (limit) the sample values
    'train_used_data_fraction': 1, # fraction of the data to use, in the range [0, 1]
    'val_used_data_fraction': 1,
    'image_px_size': 224, # image size (224x224)
    # NOTE(review): the two comments below mention 224, while the cropping cell of this
    # notebook compares image_px_size against 256 — confirm which value is meant.
    'cover_all_parts_train': True, # if True and image_px_size is not 224 during training, a random crop of the image is used
    'cover_all_parts_validation': True, # if True and image_px_size is not 224 during validation, a non-overlapping sliding window covers the entire image
    'seed': 42,
}

# Training hyper-parameters.
train_config = {
    's1_input_channels': 2,
    's2_input_channels': 13,
    # NOTE(review): no cell visible here reads 'finetuning'; the classifier is built
    # with its default freeze_layers=True — confirm this flag is honoured somewhere.
    'finetuning': True, # if False, the backbone layers are frozen and only the head is trained
    'classifier_lr': 3e-6,
    'learning_rate': 0.00001,
    'adam_betas': (0.9, 0.999), 
    'weight_decay': 0.001,
    'dataloader_workers': 4,
    'batch_size': 16,
    'epochs': 5, 
    'target': 'dfc_label'
}

In [91]:
# Side length of the stored patches as assumed by this cell (was a bare 256 before).
PATCH_PX_SIZE = 256
image_px_size = 224
if image_px_size != PATCH_PX_SIZE:
    # Random crop of image_px_size x image_px_size pixels (e.g. 128x128).
    # np.random.randint excludes its upper bound, hence the "+ 1": without it the
    # largest valid offset (PATCH_PX_SIZE - image_px_size) could never be drawn.
    x_offset, y_offset = np.random.randint(0, PATCH_PX_SIZE - image_px_size + 1, 2)
    window = Window(x_offset, y_offset, image_px_size, image_px_size)
else:
    window = None

# Other S1 inputs used while experimenting (swap in as needed):
#   "../data/data_disini/ROIs0000_test/s1_0/ROIs0000_test_s1_0_p4548.tif" with S1Bands.ALL
#   "../data/inference/s1_combined.tif" with S1Bands.ALL
# The first of these used to be read here and was immediately overwritten by the
# load below, so that redundant read has been dropped.
# NOTE(review): a VV-only read yields a single S1 channel, while train_config sets
# 's1_input_channels' to 2 — confirm this is the intended model input.
s1, bounds1 = get_patch(patch_path="../data/inference/s1_vv.tif", bands=S1Bands.VV, window=window)
s2, bounds2 = get_patch(patch_path="../data/data_disini/ROIs0000_test/s2_0/ROIs0000_test_s2_0_p4548.tif", bands=S2Bands.ALL, window=window)

print(s1.min(), s1.max())

41 1076


In [None]:
# NOTE: this cell overwrites s1 and s2 in place; run it once after the loading cell.
clip_sample_values = True
if clip_sample_values:
    # Limit S1 to [-25, 0], then shift to [0, 25] to make normalization easier.
    s1 = np.clip(s1, a_min=-25, a_max=0) + 25
    s2 = np.clip(s2, a_min=0, a_max=1e4)

base_aug = A.Compose([ToTensorV2()])
base_transform = AlbumentationsToTorchTransform(base_aug)

# Move the band axis from the front to the back before applying the transform.
s1 = base_transform(np.moveaxis(s1, 0, -1))
s2 = base_transform(np.moveaxis(s2, 0, -1))


def _channel_max_maps(image):
    """Stack one constant map per channel, filled with that channel's maximum plus 1e-5."""
    height, width = image.shape[-2], image.shape[-1]
    return torch.stack(
        [
            torch.ones((height, width)) * channel.max().item() + 1e-5
            for channel in image
        ]
    )


s1_maxs = _channel_max_maps(s1)
s2_maxs = _channel_max_maps(s2)

normalize = True
if normalize:
    # Per-channel division by the (slightly offset) channel maximum.
    s1 = s1 / s1_maxs
    s2 = s2 / s2_maxs

s1

In [33]:
s1.max()

tensor(1.0000, dtype=torch.float64)

In [16]:
from dfc_dataset import DFCDataset

# Dataset built from the `val_*` entries (plus the shared entries) of data_config.
val_dataset = DFCDataset(
    data_config['val_dir'],
    mode=data_config['val_mode'],
    clip_sample_values=data_config['clip_sample_values'],
    used_data_fraction=data_config['val_used_data_fraction'],
    image_px_size=data_config['image_px_size'],
    cover_all_parts=data_config['cover_all_parts_validation'],
    seed=data_config['seed'],
)

In [19]:
accelerator = 'cpu'
device = torch.device(accelerator)

# Pre-trained backbone weights; map_location keeps the load working on this device
# even if the checkpoint was saved from another one.
checkpoint = torch.load("../checkpoints/swin_t.pth", map_location=device)
weights = checkpoint["state_dict"]

# Split the checkpoint into one state dict per backbone, dropping the "backboneN." prefix.
s1_weights = {k[len("backbone1."):]: v for k, v in weights.items() if "backbone1" in k}
s2_weights = {k[len("backbone2."):]: v for k, v in weights.items() if "backbone2" in k}

input_channels = train_config['s1_input_channels'] + train_config['s2_input_channels']

with open("../configs/backbone_config.json", "r") as fp:
    swin_conf = dotdictify(json.load(fp))

# The S1 backbone uses the channel count from the config file as-is; the S2 backbone
# is built after overriding it with the configured S2 channel count (13).
s1_backbone = build_model(swin_conf.model_config)
swin_conf.model_config.MODEL.SWIN.IN_CHANS = train_config['s2_input_channels']
s2_backbone = build_model(swin_conf.model_config)
s1_backbone.load_state_dict(s1_weights)
s2_backbone.load_state_dict(s2_weights)


model = DoubleSwinTransformerClassifier(s1_backbone, s2_backbone, out_dim=data_config['num_classes'], device=device)
model = model.to(device)
# map_location added for consistency with the backbone checkpoint above: without it a
# checkpoint saved on another device fails to load on a machine that lacks that device.
model.load_state_dict(torch.load("../checkpoints/classifier-epoch-4.pth", map_location=device))


img = {"s1": torch.unsqueeze(s1, 0), "s2": torch.unsqueeze(s2, 0)} # adding an extra dimension for batch information
model.eval()
# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    output = model(img)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


FileNotFoundError: [Errno 2] No such file or directory: '../checkpoints/classifier-epoch-4.pth'

In [11]:
img_index = 13
# Index the dataset once and reuse the sample: every val_dataset[...] access loads the
# sample again (the previous version printed each file path twice for that reason).
sample = val_dataset[img_index]
img_dataset = {"s1": torch.unsqueeze(sample['s1'], 0), "s2": torch.unsqueeze(sample['s2'], 0)} # adding an extra dimension for batch information
img_dataset['s2']

data/data_disini/ROIs0000_test/s1_0/ROIs0000_test_s1_0_p4548.tif
data/data_disini/ROIs0000_test/s2_0/ROIs0000_test_s2_0_p4548.tif
data/data_disini/ROIs0000_test/lc_0/ROIs0000_test_lc_0_p4548.tif
data/data_disini/ROIs0000_test/dfc_0/ROIs0000_test_dfc_0_p4548.tif
data/data_disini/ROIs0000_test/s1_0/ROIs0000_test_s1_0_p4548.tif
data/data_disini/ROIs0000_test/s2_0/ROIs0000_test_s2_0_p4548.tif
data/data_disini/ROIs0000_test/lc_0/ROIs0000_test_lc_0_p4548.tif
data/data_disini/ROIs0000_test/dfc_0/ROIs0000_test_dfc_0_p4548.tif


tensor([[[[0.7967, 0.7967, 0.7967,  ..., 0.8181, 0.8181, 0.8181],
          [0.7967, 0.7967, 0.7967,  ..., 0.8181, 0.8181, 0.8181],
          [0.7967, 0.7967, 0.7967,  ..., 0.8181, 0.8181, 0.8181],
          ...,
          [0.8388, 0.8388, 0.8388,  ..., 0.8213, 0.8213, 0.8213],
          [0.8388, 0.8388, 0.8388,  ..., 0.8173, 0.8173, 0.8173],
          [0.8388, 0.8388, 0.8388,  ..., 0.8173, 0.8173, 0.8173]],

         [[0.4442, 0.4442, 0.4542,  ..., 0.4802, 0.4833, 0.4796],
          [0.4641, 0.4641, 0.4442,  ..., 0.4665, 0.4907, 0.4777],
          [0.4492, 0.4492, 0.4591,  ..., 0.4734, 0.4734, 0.4814],
          ...,
          [0.4994, 0.5056, 0.4994,  ..., 0.4703, 0.4572, 0.4597],
          [0.4876, 0.4988, 0.5025,  ..., 0.4783, 0.4647, 0.4603],
          [0.4808, 0.4845, 0.4839,  ..., 0.4684, 0.4690, 0.4641]],

         [[0.2981, 0.2981, 0.3189,  ..., 0.3255, 0.3489, 0.3408],
          [0.3052, 0.3052, 0.3144,  ..., 0.3205, 0.3489, 0.3316],
          [0.3073, 0.3073, 0.3195,  ..., 0

In [14]:
img_dataset['s2'] == s2

tensor([[[[False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          ...,
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False]],

         [[False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          ...,
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False]],

         [[False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          [False, False, False,  ..., False, False, False],
          ...,
          [False, False, False,  ..., False, False,

In [None]:
img_dataset['s2'].shape

In [None]:
img['s2'].shape

# Inference data loader

In [1]:
import numpy as np
import rasterio
from rasterio.windows import Window
from enum import Enum
import albumentations as A
from utils import AlbumentationsToTorchTransform
from albumentations.pytorch import ToTensorV2

import json
import torch
from utils import dotdictify
from Transformer_SSL.models import build_model

  warn(f"Failed to load image Python extension: {e}")
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from dfc_dataset import DFCDataset

# Dataset settings for this section.
# NOTE(review): the directories here are 'data/data_disini', whereas the earlier copy of
# this dict uses '../data/data_disini' — confirm the working directory this section expects.
data_config = {
    'train_dir': 'data/data_disini', # path to the training directory
    'val_dir': 'data/data_disini', # path to the validation directory
    'train_mode': 'validation', # one of: 'test', 'validation'
    'val_mode': 'test', # one of: 'test', 'validation'
    'num_classes': 8, # number of classes in the dataset
    'clip_sample_values': True, # clip (limit) the sample values
    'train_used_data_fraction': 1, # fraction of the data to use, in the range [0, 1]
    'val_used_data_fraction': 1,
    'image_px_size': 224, # image size (224x224)
    'cover_all_parts_train': True, # if True and image_px_size is not 224 during training, a random crop of the image is used
    'cover_all_parts_validation': True, # if True and image_px_size is not 224 during validation, a non-overlapping sliding window covers the entire image
    'seed': 42,
}

# Dataset built from the `val_*` entries (plus the shared entries) of data_config.
val_dataset = DFCDataset(
    data_config['val_dir'],
    mode=data_config['val_mode'],
    clip_sample_values=data_config['clip_sample_values'],
    used_data_fraction=data_config['val_used_data_fraction'],
    image_px_size=data_config['image_px_size'],
    cover_all_parts=data_config['cover_all_parts_validation'],
    seed=data_config['seed'],
)