In [1]:
import torch

from mmcv import Config
from mmcv.runner import load_checkpoint
from mmseg.models import build_segmentor

  from .autonotebook import tqdm as notebook_tqdm


# Explore MMSEG Pipeline

This notebook explores the MMSegmentation (MMSEG) pipeline that Jakubik et al. implemented for the crop segmentation model trained on top of Prithvi.

## some docs and references

- https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/overview.md
- https://mmsegmentation.readthedocs.io/en/latest/advanced_guides/index.html
- https://mmsegmentation.readthedocs.io/en/latest/notes/faq.html

## understanding the prithvi crop classification example model

In [2]:
# Parse the experiment config file.
cfg = Config.fromfile('./multi_temporal_crop_classification.py')

# Instantiate the segmentor described by the config. The log printed below
# shows that this step also loads the Prithvi_100M.pt checkpoint; the
# "unexpected key" entries visible there are mask_token / decoder_* keys.
model = build_segmentor(
    cfg.model,
    train_cfg=cfg.get('train_cfg'),
    test_cfg=cfg.get('test_cfg'),
)

load from prithvi-demo-ft/prithvi/Prithvi_100M.pt
load checkpoint from local path: prithvi-demo-ft/prithvi/Prithvi_100M.pt
The model and loaded state dict do not match exactly

unexpected key in source state_dict: mask_token, decoder_pos_embed, decoder_embed.weight, decoder_embed.bias, decoder_blocks.0.norm1.weight, decoder_blocks.0.norm1.bias, decoder_blocks.0.attn.qkv.weight, decoder_blocks.0.attn.qkv.bias, decoder_blocks.0.attn.proj.weight, decoder_blocks.0.attn.proj.bias, decoder_blocks.0.norm2.weight, decoder_blocks.0.norm2.bias, decoder_blocks.0.mlp.fc1.weight, decoder_blocks.0.mlp.fc1.bias, decoder_blocks.0.mlp.fc2.weight, decoder_blocks.0.mlp.fc2.bias, decoder_blocks.1.norm1.weight, decoder_blocks.1.norm1.bias, decoder_blocks.1.attn.qkv.weight, decoder_blocks.1.attn.qkv.bias, decoder_blocks.1.attn.proj.weight, decoder_blocks.1.attn.proj.bias, decoder_blocks.1.norm2.weight, decoder_blocks.1.norm2.bias, decoder_blocks.1.mlp.fc1.weight, decoder_blocks.1.mlp.fc1.bias, decoder_block

In [3]:
# Display the module tree of the segmentor built in the previous cell.
model

TemporalEncoderDecoder(
  (backbone): TemporalViTEncoder(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(6, 768, kernel_size=(1, 16, 16), stride=(1, 16, 16))
      (norm): Identity()
    )
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block(
        (norm1): LayerNorm((768,)

### Dataset

We can load the datasets based on the config. This allows us to get samples and runs the pipeline on them automatically.

In [4]:
from mmseg.datasets import build_dataset

# Build the training, validation and testing datasets (in that order), each
# from its own `cfg.data.<split>` entry.
train_dataset, val_dataset, test_dataset = (
    build_dataset(cfg.data[split]) for split in ('train', 'val', 'test')
)

2024-04-17 16:53:35,952 - mmseg - INFO - Loaded 3083 images
2024-04-17 16:53:35,981 - mmseg - INFO - Loaded 771 images
2024-04-17 16:53:36,011 - mmseg - INFO - Loaded 771 images


In [5]:
# Fetch the sample once and reuse it: every `train_dataset[0]` access runs the
# data pipeline again, so indexing three times repeats the loading work and the
# three prints could describe separately processed copies of the sample.
first_train_sample = train_dataset[0]

print('train dataset keys: ', first_train_sample.keys())
print('img shape', first_train_sample['img'].shape)
print('gt_semantic_seg shape', first_train_sample['gt_semantic_seg'].shape)

print('len train:', len(train_dataset))

train dataset keys:  dict_keys(['img_metas', 'img', 'gt_semantic_seg'])
img shape torch.Size([6, 3, 224, 224])
gt_semantic_seg shape torch.Size([1, 224, 224])
len train: 3083


We get labels (`gt_semantic_seg`) with the train dataset.

In [6]:
# Index the dataset once and reuse the sample for both prints, so the
# pipeline is not run a second time just to read the image shape.
first_val_sample = val_dataset[0]

print('val dataset keys: ', first_val_sample.keys())
# 'img' is wrapped in a sequence for this split, hence the extra [0].
print('img shape', first_val_sample['img'][0].shape)

print('len val:', len(val_dataset))

val dataset keys:  dict_keys(['img_metas', 'img'])
img shape torch.Size([6, 3, 224, 224])
len val: 771


The val & test datasets provide only images and meta info. `img` is wrapped in a list, so a second `[0]` is needed to reach the tensor.

In [7]:
# Ground-truth segmentation map for training sample 0; the output below is a
# 2-D uint8 array of label values.
train_dataset.get_gt_seg_map_by_idx(0)

array([[ 2,  2,  2, ...,  2,  2,  2],
       [ 2,  2,  2, ...,  2,  2,  2],
       [ 2,  2,  2, ...,  2,  2,  2],
       ...,
       [ 0,  0,  0, ...,  6,  6,  6],
       [ 0,  0,  0, ...,  0,  1,  4],
       [ 5,  5,  5, ..., 12,  0,  8]], dtype=uint8)

In [8]:
# Annotation info for test sample 0; per the output below it holds the mask
# file name under the 'seg_map' key.
test_dataset.get_ann_info(0)

{'seg_map': 'chip_002_060.mask.tif'}

### pipeline

The pipeline can be built and run separately, but it needs a correctly structured dict as input. This became obsolete with the dataset class above.

In [9]:
from mmseg.datasets.pipelines import Compose

def run_pipeline(results):
    """Run the training pipeline from the config on one input dict.

    Args:
        results: Dict passed as-is to the composed pipeline.

    Returns:
        Whatever the composed pipeline returns for ``results``.
    """
    # The transform chain is rebuilt from `cfg.data.train.pipeline` on every call.
    composed = Compose(cfg.data.train.pipeline)
    return composed(results)

# Input dict for the pipeline: file name and directory of one chip and of its
# segmentation mask, plus an empty 'seg_fields' list.
results = dict(
    img_info=dict(filename="chip_003_062_merged.tif"),
    img_prefix="./data/training_chips",
    ann_info=dict(seg_map="chip_003_062.mask.tif"),
    seg_prefix="./data/training_chips",
    seg_fields=[],
)

# Run the pipeline on that dict and show which entries it produced.
data = run_pipeline(results)
data.keys()

dict_keys(['img_metas', 'img', 'gt_semantic_seg'])

In [10]:
# Metadata produced by the pipeline for this sample; the output below shows it
# wrapped in a DataContainer.
data['img_metas']

DataContainer({'filename': './data/training_chips/chip_003_062_merged.tif', 'ori_filename': 'chip_003_062_merged.tif', 'ori_shape': (224, 224, 18), 'img_shape': (224, 224, 18), 'pad_shape': (224, 224, 18), 'scale_factor': 1.0, 'flip': False, 'flip_direction': 'horizontal', 'img_norm_cfg': {'mean': [494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962, 1739.579917, 494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962, 1739.579917, 494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962, 1739.579917], 'std': [284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808, 284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808, 284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808]}})

In [11]:
# One "<key> shape <size>" line per tensor entry returned by the pipeline.
print('\n'.join(
    f"{entry} shape {data[entry].shape}"
    for entry in ('img', 'gt_semantic_seg')
))

img shape torch.Size([6, 3, 224, 224])
gt_semantic_seg shape torch.Size([1, 224, 224])


### passing a sample through model

In [22]:
# Put the model in evaluation mode before the forward pass.
model.eval()

# Get a sample from the dataset
sample = train_dataset[0]

# Add a batch dimension (size 1) to the tensors
img = sample['img'].unsqueeze(0)
gt_semantic_seg = sample['gt_semantic_seg'].unsqueeze(0)
# The pipeline wraps the metadata in a DataContainer, while the model takes a
# list of plain dicts (one per batch element), so unwrap it with `.data`.
img_metas = [sample['img_metas'].data]

# Pass the batched data to the model; no gradients are needed for inspection.
# With return_loss=True the model returns a dict of loss/accuracy values.
with torch.no_grad():
    results = model.forward(img=img, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True)

TemporalViTEncoder IN: torch.Size([1, 6, 3, 224, 224])
TemporalViTEncoder EMBED: torch.Size([1, 588, 768])
TemporalViTEncoder OUT: torch.Size([1, 589, 768])
ConvTransformerTokensToEmbeddingNeck IN: torch.Size([1, 589, 768])
ConvTransformerTokensToEmbeddingNeck RSHP: torch.Size([1, 2304, 224, 224])
FCNHead INP:  torch.Size([1, 2304, 224, 224])
FCNHead OUT:  torch.Size([13, 224, 224])
FCNHead INP:  torch.Size([1, 2304, 224, 224])
FCNHead OUT:  torch.Size([13, 224, 224])


In [40]:
# Metadata entry that was handed to the model for the single batched sample.
img_metas[0]

{'filename': './data/training_chips/chip_003_062_merged.tif',
 'ori_filename': 'chip_003_062_merged.tif',
 'ori_shape': (224, 224, 18),
 'img_shape': (224, 224, 18),
 'pad_shape': (224, 224, 18),
 'scale_factor': 1.0,
 'flip': False,
 'flip_direction': 'horizontal',
 'img_norm_cfg': {'mean': [494.905781,
   815.239594,
   924.335066,
   2968.881459,
   2634.621962,
   1739.579917,
   494.905781,
   815.239594,
   924.335066,
   2968.881459,
   2634.621962,
   1739.579917,
   494.905781,
   815.239594,
   924.335066,
   2968.881459,
   2634.621962,
   1739.579917],
  'std': [284.925432,
   357.84876,
   575.566823,
   896.601013,
   951.900334,
   921.407808,
   284.925432,
   357.84876,
   575.566823,
   896.601013,
   951.900334,
   921.407808,
   284.925432,
   357.84876,
   575.566823,
   896.601013,
   951.900334,
   921.407808]}}

In [19]:
# Number of entries in the config's CLASSES list.
len(cfg.CLASSES)

13

In [20]:
# Dict returned by the forward pass with return_loss=True: loss_ce and acc_seg
# entries under the `decode.` and `aux.` prefixes.
results

{'decode.loss_ce': tensor(1.9787),
 'decode.acc_seg': tensor([6.3237]),
 'aux.loss_ce': tensor(1.8801),
 'aux.acc_seg': tensor([14.2379])}

In [15]:
# Display the model structure again (same repr as shown right after building it).
model

TemporalEncoderDecoder(
  (backbone): TemporalViTEncoder(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(6, 768, kernel_size=(1, 16, 16), stride=(1, 16, 16))
      (norm): Identity()
    )
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block(
        (norm1): LayerNorm((768,)