In [2]:
import segmentation_models_pytorch as smp
import torch
import yaml

## Planet-based RTS v6 notcvis (without the tc_brightness / tc_greenness / tc_wetness inputs)

In [2]:
# Load the legacy training run: its YAML config and checkpoint 41.
# Tensors are mapped to CPU so the conversion needs no GPU.
with open("../models/old/RTS_v6_notcvis/config.yml") as config_file:
    config = yaml.safe_load(config_file)

path = "../models/old/RTS_v6_notcvis/checkpoints/41.pt"
checkpoint = torch.load(path, map_location="cpu")

# Display the model section; it is what the architecture gets rebuilt from.
config["model"]

{'architecture': 'UnetPlusPlus',
 'encoder': 'resnet34',
 'encoder_weights': 'random',
 'input_channels': 7}

In [3]:
# Rebuild the architecture described by the legacy config and load the raw checkpoint into it.
# encoder_weights=None: smp.create_model defaults to "imagenet", which would download pretrained
# encoder weights only for load_state_dict to overwrite them (the legacy config itself records
# encoder_weights "random").
model = smp.create_model(
    arch=config["model"]["architecture"],
    encoder_name=config["model"]["encoder"],
    encoder_weights=None,
    in_channels=config["model"]["input_channels"],
)
# The checkpoint's keys only match once the model is wrapped in DataParallel, i.e. they carry
# the "module." prefix (presumably it was saved from a DataParallel-wrapped model in training).
model = torch.nn.DataParallel(model)
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [None]:
# Merge config and checkpoint into one self-describing file.
# "model" holds smp.create_model keyword arguments, so the file can be rebuilt with
# smp.create_model(**ckpt["config"]["model"]).
# NOTE(review): input_combination must be the channel order used in training. norm_factors
# are per-band factors, presumably multiplied onto raw band values; verify against the consumer.
notcvis_input_combination = [
    "ndvi",
    "blue",
    "green",
    "red",
    "nir",
    "relative_elevation",
    "slope",
]
notcvis_norm_factors = {
    "red": 1 / 3000,
    "green": 1 / 3000,
    "blue": 1 / 3000,
    "nir": 1 / 3000,
    "ndvi": 1,
    "relative_elevation": 1 / 30000,
    "slope": 1 / 90,
}

# Refuse to write an artifact whose band list disagrees with the network's input width,
# or which leaves a band without a normalisation factor.
assert len(notcvis_input_combination) == config["model"]["input_channels"], (
    "input_combination length does not match the model's input_channels"
)
assert set(notcvis_input_combination) == set(notcvis_norm_factors), (
    "norm_factors keys do not match input_combination"
)

torch.save(
    {
        "config": {
            "model": {
                "arch": config["model"]["architecture"],
                "encoder_name": config["model"]["encoder"],
                "in_channels": config["model"]["input_channels"],
                "classes": 1,
            },
            "input_combination": notcvis_input_combination,
            "norm_factors": notcvis_norm_factors,
            "patch_size": 1024,
            "model_framework": "smp",
        },
        # Save the unwrapped module's weights so keys have no DataParallel "module." prefix.
        "statedict": model.module.state_dict(),
    },
    "../models/RTS_v6_notcvis.pt",
)

In [5]:
# Round-trip check: rebuild the model purely from the merged file and load its weights.
# map_location="cpu" makes the check independent of the device the tensors were saved on.
# encoder_weights defaults to None (unless the file specifies one) so smp.create_model does not
# download ImageNet weights that load_state_dict would overwrite anyway.
checkpoint = torch.load("../models/RTS_v6_notcvis.pt", map_location="cpu")
model = smp.create_model(**{"encoder_weights": None, **checkpoint["config"]["model"]})
model.load_state_dict(checkpoint["statedict"])

<All keys matched successfully>

## Planet-based RTS v6 tcvis (adds the tc_brightness / tc_greenness / tc_wetness inputs)

In [None]:
# Convert the legacy tcvis run (checkpoint 14) into one self-describing file and verify that
# it round-trips. Same steps as the notcvis conversion, with three extra tc_* inputs.

# Load config and checkpoint (tensors mapped to CPU so no GPU is needed)
path = "../models/old/RTS_v6_tcvis/checkpoints/14.pt"
checkpoint = torch.load(path, map_location="cpu")

with open("../models/old/RTS_v6_tcvis/config.yml") as f:
    config = yaml.safe_load(f)
print(config["model"])

# Try loading. encoder_weights=None stops smp.create_model from downloading its default
# "imagenet" weights, which load_state_dict overwrites anyway. The DataParallel wrapper is
# needed because the checkpoint keys carry the "module." prefix.
model = smp.create_model(
    arch=config["model"]["architecture"],
    encoder_name=config["model"]["encoder"],
    encoder_weights=None,
    in_channels=config["model"]["input_channels"],
)
model = torch.nn.DataParallel(model)
print(model.load_state_dict(checkpoint))

# Merge config and checkpoint.
# NOTE(review): input_combination must be the channel order used in training. norm_factors
# are per-band factors, presumably multiplied onto raw band values; verify against the consumer.
tcvis_input_combination = [
    "ndvi",
    "blue",
    "green",
    "red",
    "nir",
    "relative_elevation",
    "slope",
    "tc_brightness",
    "tc_greenness",
    "tc_wetness",
]
tcvis_norm_factors = {
    "red": 1 / 3000,
    "green": 1 / 3000,
    "blue": 1 / 3000,
    "nir": 1 / 3000,
    "ndvi": 1,
    "relative_elevation": 1 / 30000,
    "slope": 1 / 90,
    "tc_brightness": 1 / 255,
    "tc_greenness": 1 / 255,
    "tc_wetness": 1 / 255,
}

# Refuse to write an artifact whose band list disagrees with the network's input width,
# or which leaves a band without a normalisation factor.
assert len(tcvis_input_combination) == config["model"]["input_channels"], (
    "input_combination length does not match the model's input_channels"
)
assert set(tcvis_input_combination) == set(tcvis_norm_factors), (
    "norm_factors keys do not match input_combination"
)

torch.save(
    {
        "config": {
            "model": {
                "arch": config["model"]["architecture"],
                "encoder_name": config["model"]["encoder"],
                "in_channels": config["model"]["input_channels"],
                "classes": 1,
            },
            "input_combination": tcvis_input_combination,
            "norm_factors": tcvis_norm_factors,
            "patch_size": 1024,
            "model_framework": "smp",
        },
        # Save the unwrapped module's weights so keys have no DataParallel "module." prefix.
        "statedict": model.module.state_dict(),
    },
    "../models/RTS_v6_tcvis.pt",
)

# Test it: rebuild purely from the merged file. encoder_weights defaults to None unless the
# file specifies one, which avoids the ImageNet download.
checkpoint = torch.load("../models/RTS_v6_tcvis.pt", map_location="cpu")
model = smp.create_model(**{"encoder_weights": None, **checkpoint["config"]["model"]})
model.load_state_dict(checkpoint["statedict"])

{'architecture': 'UnetPlusPlus', 'encoder': 'resnet34', 'encoder_weights': 'random', 'input_channels': 10}
<All keys matched successfully>


<All keys matched successfully>