In [10]:
import numpy as np
from enum import Enum
import rasterio
from rasterio.windows import Window

class S1Bands(Enum):
    """Sentinel-1 band selectors; `get_patch` passes the member value to `rasterio` as band index(es)."""
    # Single-band selectors (1-based band numbers).
    VV = 1
    VH = 2
    # Inside the class body VV and VH are still plain ints, so ALL.value is the list [1, 2].
    ALL = [VV, VH]
    # NOTE(review): plain Enum members are always truthy, so `not S1Bands.NONE` is False —
    # confirm no caller relies on NONE being falsy.
    NONE = None

class Sensor(Enum):
    """String identifiers for the sensor/product a band selection belongs to (looked up in `get_patch`)."""
    s1 = "s1"
    s2 = "s2"
    lc = "lc"
    dfc = "dfc"

class S2Bands(Enum):
    """Sentinel-2 band selectors; `get_patch` passes the member value to `rasterio` as band index(es).

    Each line of the form ``B01 = aerosol = 1`` defines one member with two names:
    the first name (``B01``) is the canonical member, the second (``aerosol``) is an alias.
    Values run consecutively from 1 to 13, so e.g. ``B08A`` is 9 and ``B12`` is 13.
    """
    B01 = aerosol = 1
    B02 = blue = 2
    B03 = green = 3
    B04 = red = 4
    B05 = re1 = 5
    B06 = re2 = 6
    B07 = re3 = 7
    B08 = nir1 = 8
    B08A = nir2 = 9
    B09 = vapor = 10
    B10 = cirrus = 11
    B11 = swir1 = 12
    B12 = swir2 = 13
    # Within the class body the names above are still plain ints, so these values are int lists.
    ALL = [B01, B02, B03, B04, B05, B06, B07, B08, B08A, B09, B10, B11, B12]
    # Red, green, blue order: [4, 3, 2].
    RGB = [B04, B03, B02]
    # NOTE(review): plain Enum members are always truthy, so `not S2Bands.NONE` is False.
    NONE = None

class LCBands(Enum):
    """Land-cover selectors; in `get_patch` they choose the sensor tag, and band 1 is always read."""
    # `lc` / `dfc` are aliases of the canonical members `LC` / `DFC`.
    LC = lc = 0
    DFC = dfc = 1
    # ALL.value is [1]: it contains only DFC, not LC.
    ALL = [DFC]
    # NOTE(review): plain Enum members are always truthy, so `not LCBands.NONE` is False.
    NONE = None

def get_patch(patch_path, bands, window=None):
    """
    Return raster data and image bounds for the requested bands of one patch.

    Only a single patch from a single sensor is loaded; the sensor is derived
    from the type of the (first) band selector.

    Args:
        patch_path: path of the raster file opened with ``rasterio.open``.
        bands: one ``S1Bands`` / ``S2Bands`` / ``LCBands`` member, or a list/tuple
            of members of one of these enums. Falsy input (``None``, empty
            sequence) short-circuits.
        window: optional ``rasterio.windows.Window`` forwarded to ``read``;
            ``None`` reads the full extent.

    Returns:
        ``(data, bounds)`` where ``data`` always has at least three dimensions
        (a 2-D single-band read gets a leading axis) and ``bounds`` is the
        ``bounds`` attribute of the opened dataset, i.e. of the whole file and
        not of ``window``. ``(None, None)`` when ``bands`` is falsy.

    Raises:
        TypeError: if the selector is not a member of a known band enum.

    NOTE(review): Enum members are truthy, so ``S1Bands.NONE`` etc. do not hit
    the early return; their value ``None`` is forwarded to ``read`` (which
    presumably reads every band — confirm against the rasterio docs).
    """
    # season = Seasons(season).value
    if not bands:
        return None, None

    # The first selector decides which sensor the whole request belongs to.
    first = bands[0] if isinstance(bands, (list, tuple)) else bands

    if isinstance(first, S1Bands):
        sensor = Sensor.s1.value
    elif isinstance(first, S2Bands):
        sensor = Sensor.s2.value
    elif isinstance(first, LCBands):
        # Inspect the first selector, not the raw argument: LCBands(<list>)
        # raised ValueError whenever a list/tuple of LC selectors was passed.
        sensor = Sensor.lc.value if first == LCBands.LC else Sensor.dfc.value
        # Land-cover rasters are always read from band 1, whatever was requested.
        bands = LCBands(1)
    else:
        raise TypeError("Invalid bands specified")

    # Translate enum members into the band indexes rasterio expects.
    if isinstance(bands, (list, tuple)):
        band_indexes = [band.value for band in bands]
    else:
        band_indexes = bands.value

    # Kept from the original cell: echoes the resolved band indexes.
    print(band_indexes)
    with rasterio.open(patch_path) as patch:
        # `window=None` is the default of `read`, so one call covers both cases.
        data = patch.read(band_indexes, window=window)
        bounds = patch.bounds

    # Remap IGBP to DFC bands (disabled; `sensor` is only needed by this step)
    # if sensor  == "lc":
    #     data = IGBP2DFC[data]

    # A scalar band index yields a 2-D array; add the leading band axis.
    if data.ndim == 2:
        data = np.expand_dims(data, axis=0)

    return data, bounds

image_px_size = 224
if image_px_size != 256:
    # Crop the data to image_px_size x image_px_size (here 224x224).
    # Offsets are drawn as if the source were a 256x256 patch; the upper bound of
    # `randint` is exclusive, so offsets lie in [0, 256 - image_px_size).
    # NOTE(review): the inference rasters appear to be far larger than 256 px, so this
    # window only ever samples their top-left corner — confirm that is intended.
    x_offset, y_offset = np.random.randint(0, 256 - image_px_size, 2)
    window = Window(x_offset, y_offset, image_px_size, image_px_size)
else:
    window = None

# Keep the bounds of each sensor in its own variable; previously the second call
# overwrote `bounds1`, discarding the Sentinel-1 bounds.
s1, bounds1 = get_patch(patch_path="../data/inference/s1_combined.tif", bands=S1Bands.ALL, window=window)
s2, bounds2 = get_patch(patch_path="../data/inference/S2 Composite.tif", bands=S2Bands.ALL, window=window)

[1, 2]
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]


In [11]:
patch_path = "../data/inference/s1_combined.tif"
s1_bands = [1, 2]
s2_bands = [i for i in range(1, 14)]

image_px_size = 224

# Read the full raster. The original cell first drew a random crop window and then
# unconditionally replaced it with None, so the draw and the windowed branch were dead code.
window = None

bands = s1_bands
with rasterio.open(patch_path) as patch:
    # `window=None` reads the whole extent.
    data = patch.read(bands, window=window)
    bounds = patch.bounds
data.shape

(2, 9632, 9583)

In [13]:
# 1-based band indexes: two for the Sentinel-1 file, thirteen for the Sentinel-2 file.
s1_bands = list(range(1, 3))
s2_bands = list(range(1, 14))

# This cell targets the Sentinel-2 composite.
bands = s2_bands
patch_path = "../data/inference/S2 Composite.tif"

def load_tif(patch_path, bands):
    """Read the given bands of a raster file and return them.

    Args:
        patch_path: path handed to ``rasterio.open``.
        bands: band index or list of band indexes handed to ``read``.

    Returns:
        Whatever ``read`` returns for ``bands`` (the original version dropped
        this value and implicitly returned ``None``).
    """
    with rasterio.open(patch_path) as patch:
        data = patch.read(bands)
    return data

(13, 10980, 10980)

In [None]:
import matplotlib.pyplot as plt


# Side-by-side view of the first two Sentinel-1 bands.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
for ax, band in zip(axs, s1[:2]):
    ax.imshow(band)
axs[0].set_title("Sentinel-1 VV polarization")
axs[1].set_title("Sentinel-1 VH polarization")

In [None]:
# Display the shape of the Sentinel-1 array returned by `get_patch`.
s1.shape

In [None]:
# NOTE(review): repeats the shape check from the previous cell.
s1.shape 

In [None]:
import rasterio
from rasterio.merge import merge
from rasterio.plot import show

# Paths to the single-polarization Sentinel-1 images
vv_image_path = "../data/inference/s1_vv.tif"
vh_image_path = "../data/inference/s1_vh.tif"

# Read band 1 of each image and capture the metadata needed later while the
# dataset is still open, instead of querying the dataset objects after the
# `with` blocks have closed them.
with rasterio.open(vv_image_path) as vv_src:
    vv_data = vv_src.read(1)
    vv_shape = vv_src.shape
    vv_crs = vv_src.crs
    vv_transform = vv_src.transform

with rasterio.open(vh_image_path) as vh_src:
    vh_data = vh_src.read(1)
    vh_shape = vh_src.shape
    vh_crs = vh_src.crs

# Check that both images have the same shape and CRS
assert vv_shape == vh_shape
assert vv_crs == vh_crs

# Create a new two-band raster file (band 1 = VV, band 2 = VH).
# The output inherits dtype, CRS and transform from the VV image.
output_path = "../data/inference/s1_combined.tif"
height, width = vv_shape
with rasterio.open(
    output_path,
    'w',
    driver='GTiff',
    height=height,
    width=width,
    count=2,  # Number of bands
    dtype=vv_data.dtype,
    crs=vv_crs,
    transform=vv_transform
) as dst:
    dst.write(vv_data, 1)  # Write VV data to band 1
    dst.write(vh_data, 2)  # Write VH data to band 2


In [None]:
import torch

In [None]:
class DoubleSwinTransformerClassifier(torch.nn.Module):
    """Two-stream classifier: concatenates the features of two encoders and applies one linear layer.

    Args:
        encoder1: backbone applied to ``x["s1"]``; must expose ``num_features`` and a
            ``forward_features`` method returning a 3-tuple whose first item is the feature tensor.
        encoder2: same contract, applied to ``x["s2"]``.
        out_dim: output size of the final linear layer.
        device: device the inputs are moved to inside ``forward``.
        freeze_layers: when True, every parameter except ``fc.weight`` and ``fc.bias``
            gets ``requires_grad = False``.
    """

    def __init__(self, encoder1, encoder2, out_dim, device, freeze_layers=True):
        super().__init__()

        # If you're only using one of the two backbones, just comment the one you don't need
        self.backbone1 = encoder1
        self.backbone2 = encoder2
        self.device = device

        # Linear head over the concatenated features of both streams.
        fused_dim = self.backbone2.num_features + self.backbone1.num_features
        self.fc = torch.nn.Linear(fused_dim, out_dim, bias=True)

        # Leave only the head trainable.
        if freeze_layers:
            head_param_names = ("fc.weight", "fc.bias")
            for param_name, param in self.named_parameters():
                if param_name in head_param_names:
                    continue
                param.requires_grad = False

    def forward(self, x):
        """Run both streams on ``x["s1"]`` / ``x["s2"]`` and return the head output."""
        s1_features, _, _ = self.backbone1.forward_features(x["s1"].to(self.device))
        s2_features, _, _ = self.backbone2.forward_features(x["s2"].to(self.device))

        # If you're only using one of the two backbones, feed that stream's features
        # straight into self.fc instead of concatenating:
        # x1, _, _ = self.backbone1.forward_features(x["s1"].to(self.device))
        # z = self.fc(x1)

        return self.fc(torch.cat([s1_features, s2_features], dim=1))

In [None]:
from Transformer_SSL.models import build_model
from utils import dotdictify
import json


# Build both backbones from the same JSON model configuration.
with open("configs/backbone_config.json", "r") as fp:
    swin_conf = dotdictify(json.load(fp))
s1_backbone = build_model(swin_conf.model_config)
s2_backbone = build_model(swin_conf.model_config)

# Data configurations:
data_config = {
    # path to the training directory
    'train_dir': '/data/grss-dfc-20',
    # path to the validation directory
    'val_dir': '/data/grss-dfc-20',
    # split used for training; can be one of the following: 'test', 'validation'
    'train_mode': 'validation',
    # split used for validation; can be one of the following: 'test', 'validation'
    'val_mode': 'test',
    # number of classes in the dataset
    'num_classes': 8,
    # clip (limit) values
    'clip_sample_values': True,
    # fraction of data to use, should be in the range [0, 1]
    'train_used_data_fraction': 1,
    'val_used_data_fraction': 1,
    # image size (224x224)
    'image_px_size': 224,
    # if True and image_px_size is not 224 during training, a random crop of the image is used
    'cover_all_parts_train': True,
    # if True and image_px_size is not 224 during validation, a non-overlapping
    # sliding window is used to cover the entire image
    'cover_all_parts_validation': True,
    'seed': 42,
}

device = torch.device('cpu')

In [None]:
# path to the checkpoint
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load(
    "checkpoints/swin_t.pth",
    map_location=torch.device('cpu')
)
weights = checkpoint["state_dict"]

# Split the joint state dict per stream and strip the stream prefix from each key.
# Keys are matched with `startswith`: the previous substring test combined with
# fixed-length slicing would mangle any key that merely contains the prefix
# somewhere after its start.

# Sentinel-1 stream weights
s1_weights = {
    k[len("backbone1."):]: v for k, v in weights.items() if k.startswith("backbone1.")
}

# Sentinel-2 stream weights
s2_weights = {
    k[len("backbone2."):]: v for k, v in weights.items() if k.startswith("backbone2.")
}

In [None]:
from Transformer_SSL.models.swin_transformer import DoubleSwinTransformerDownstream
from utils import save_checkpoint_single_model, dotdictify
from Transformer_SSL.models import build_model


# Device selection is pinned to CPU; the CUDA-aware alternative is kept for reference:
# if torch.cuda.is_available():
#     device = torch.device("cuda")
# else:
#     device = torch.device("cpu:0")
device = torch.device('cpu')

# Training configurations
train_config = {
    's1_input_channels': 2,
    's2_input_channels': 13,
    # If False, the backbone layers are frozen and only the head is trained
    'finetuning': True,
    'classifier_lr': 3e-6,
    'learning_rate': 0.00001,
    'adam_betas': (0.9, 0.999),
    'weight_decay': 0.001,
    'dataloader_workers': 4,
    'batch_size': 16,
    'epochs': 5,
    'target': 'dfc_label',
}

# Input channel size (both modalities)
input_channels = sum(
    train_config[key] for key in ('s1_input_channels', 's2_input_channels')
)

# If you are using a uni-modal dataset, you can un-comment one of these lines, and comment the one above:
# input_channels = train_config['s1_input_channels']
# input_channels = train_config['s2_input_channels']

with open("configs/backbone_config.json", "r") as fp:
    swin_conf = dotdictify(json.load(fp))

# The Sentinel-1 backbone is built with IN_CHANS exactly as stored in the JSON config.
s1_backbone = build_model(swin_conf.model_config)

# The Sentinel-2 backbone needs the 13-channel input declared in train_config.
swin_conf.model_config.MODEL.SWIN.IN_CHANS = train_config['s2_input_channels']
s2_backbone = build_model(swin_conf.model_config)

# Load the per-stream weights extracted from the checkpoint.
s1_backbone.load_state_dict(s1_weights)
s2_backbone.load_state_dict(s2_weights)

In [None]:
class DoubleSwinTransformerClassifier(torch.nn.Module):
    """Two-stream classifier: concatenates the features of two encoders and applies one linear layer.

    NOTE(review): an identical class definition appears in an earlier cell of this
    notebook; this later definition shadows it.

    Args:
        encoder1: backbone applied to ``x["s1"]``; must expose ``num_features`` and a
            ``forward_features`` method returning a 3-tuple whose first item is the feature tensor.
        encoder2: same contract, applied to ``x["s2"]``.
        out_dim: output size of the final linear layer.
        device: device the inputs are moved to inside ``forward``.
        freeze_layers: when True, every parameter except ``fc.weight`` and ``fc.bias``
            gets ``requires_grad = False``.
    """

    def __init__(self, encoder1, encoder2, out_dim, device, freeze_layers=True):
        super().__init__()

        # If you're only using one of the two backbones, just comment the one you don't need
        self.backbone1 = encoder1
        self.backbone2 = encoder2
        self.device = device

        # Linear head over the concatenated features of both streams.
        fused_dim = self.backbone2.num_features + self.backbone1.num_features
        self.fc = torch.nn.Linear(fused_dim, out_dim, bias=True)

        # Leave only the head trainable.
        if freeze_layers:
            head_param_names = ("fc.weight", "fc.bias")
            for param_name, param in self.named_parameters():
                if param_name in head_param_names:
                    continue
                param.requires_grad = False

    def forward(self, x):
        """Run both streams on ``x["s1"]`` / ``x["s2"]`` and return the head output."""
        s1_features, _, _ = self.backbone1.forward_features(x["s1"].to(self.device))
        s2_features, _, _ = self.backbone2.forward_features(x["s2"].to(self.device))

        # If you're only using one of the two backbones, feed that stream's features
        # straight into self.fc instead of concatenating:
        # x1, _, _ = self.backbone1.forward_features(x["s1"].to(self.device))
        # z = self.fc(x1)

        return self.fc(torch.cat([s1_features, s2_features], dim=1))

In [None]:
# Assemble the two-stream classifier (backbones frozen by default) and place it on `device`.
model = DoubleSwinTransformerClassifier(
    encoder1=s1_backbone,
    encoder2=s2_backbone,
    out_dim=data_config['num_classes'],
    device=device,
).to(device)

In [None]:
# Cross-entropy over class indexes; targets equal to 255 are ignored.
criterion = torch.nn.CrossEntropyLoss(ignore_index=255, reduction="mean").to(device)
if train_config['finetuning']:
    # train all parameters (backbone + classifier head)
    param_backbone = []
    param_head = []
    for param_name, p in model.named_parameters():
        # Split by parameter name rather than by `requires_grad`: this loop unfreezes
        # everything, so a flag-based split put every parameter into the head group
        # (with the classifier lr) as soon as the cell was run a second time.
        if param_name.startswith("fc."):
            param_head.append(p)
        else:
            param_backbone.append(p)
        p.requires_grad = True
    # parameters = model.parameters()
    parameters = [
        {"params": param_backbone},  # train with default lr
        {
            "params": param_head,
            "lr": train_config['classifier_lr'],
        },  # train with classifier lr
    ]
    print("Finetuning")

else:
    # train only final linear layer for SSL methods
    print("Frozen backbone")
    parameters = [p for p in model.parameters() if p.requires_grad]

In [None]:
from dfc_dataset import DFCDataset

# Options shared by the training and validation datasets.
shared_dataset_kwargs = dict(
    clip_sample_values=data_config['clip_sample_values'],
    image_px_size=data_config['image_px_size'],
    seed=data_config['seed'],
)

# Create Training Dataset
train_dataset = DFCDataset(
    data_config['train_dir'],
    mode=data_config['train_mode'],
    used_data_fraction=data_config['train_used_data_fraction'],
    cover_all_parts=data_config['cover_all_parts_train'],
    **shared_dataset_kwargs,
)

# Create Validation Dataset
val_dataset = DFCDataset(
    data_config['val_dir'],
    mode=data_config['val_mode'],
    used_data_fraction=data_config['val_used_data_fraction'],
    cover_all_parts=data_config['cover_all_parts_validation'],
    **shared_dataset_kwargs,
)

In [None]:
# Adam over the parameter groups prepared earlier; groups without their own "lr"
# fall back to the learning rate given here.
optimizer = torch.optim.Adam(
    parameters,
    lr=train_config['learning_rate'],
    betas=train_config['adam_betas'],
    weight_decay=train_config['weight_decay'],
)

# Settings common to both loaders.
loader_kwargs = dict(
    batch_size=train_config['batch_size'],
    num_workers=train_config['dataloader_workers'],
)

# Training batches are shuffled and use pinned memory.
train_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, pin_memory=True, **loader_kwargs
)

# Validation batches keep the dataset order.
val_loader = torch.utils.data.DataLoader(
    val_dataset, shuffle=False, **loader_kwargs
)

In [None]:
# create a new model's instance
model = DoubleSwinTransformerClassifier(
        s1_backbone, s2_backbone, out_dim=data_config['num_classes'], device=device)

model = model.to(device)

# load checkpoint weights; map_location makes the load work when the checkpoint was
# saved from a different device than the one used here
model.load_state_dict(torch.load("checkpoints/classifier-epoch-4.pth", map_location=device))

# Fetch the sample once so s1, s2 and the label are guaranteed to come from the same
# dataset access (indexing the dataset three times could yield different draws if
# __getitem__ is randomized — not visible from this notebook).
sample_index = 2
sample = val_dataset[sample_index]

# prepare input and feed it to model for evaluation
img = {
    "s1": torch.unsqueeze(sample['s1'], 0),  # adding an extra dimension for batch information
    "s2": torch.unsqueeze(sample['s2'], 0),
}
model.eval()
with torch.no_grad():  # inference only, no autograd graph needed
    output = model(img)

# NOTE(review): DFC_map_clean is not defined in any cell visible here — confirm where it comes from.
# display predicted class:
print(f'Predicted class: {DFC_map_clean[torch.argmax(output).item()]}')

# display ground-truth label:
print('Ground-truth class: ', DFC_map_clean[sample[train_config['target']]])

# display image
val_dataset.visualize_observation(sample_index)

# DFC SEN12MS

In [None]:
import geopandas as gpd

# Load the shapefile into a GeoDataFrame and show its (rows, columns) shape.
# NOTE(review): the path contains a directory named "dfc_sen12ms_dataset.py" — confirm it is correct.
shapefile_path = "data/dfc_sen12ms_dataset.py/Data_for_training_zoom_031123.shp"
gdf = gpd.read_file(shapefile_path)
gdf.shape

In [None]:
# List the column names of the loaded GeoDataFrame.
gdf.columns