In [None]:
from sklearn import metrics
import numpy as np
import torch
import argparse
from collections import namedtuple
# import wandb
import os

from matplotlib import pyplot as plt

import change_detection_pytorch as cdp
from change_detection_pytorch.datasets import LEVIR_CD_Dataset
from torch.utils.data import DataLoader

# from change_detection_pytorch.utils.lr_scheduler import GradualWarmupScheduler

from change_detection_pytorch.datasets import ChangeDetectionDataModule
from argparse import ArgumentParser
from tqdm import tqdm

In [None]:
# Checkpoint selection. Each run directory under CHECKPOINT_ROOT holds a best_model.pth;
# only the run named by CHECKPOINT_RUN is loaded below.
# NOTE(review): the selected run must match the `args` configuration (backbone, dataset,
# fusion) defined further down, otherwise the state dict will not fit the model.
CHECKPOINT_ROOT = '/auto/home/ani/change_detection.pytorch/checkpoints'
CHECKPOINT_RUNS = [
    'gfm_oscd192_diff_e_real_160_b32_multisteplr',
    'gfm_cdd',
    'gfm',
    'gfm_aug',
    'gfm_oscd_norm',
    'ibot_oscd_norm',
    'gfm_oscd_norm_e400_lr4e-4',
    'ibot_fa_oscd_norm',
    'gfm_oscd_norm_e300_lr2e-4',
]
CHECKPOINT_RUN = 'gfm_oscd_norm_e300_lr2e-4'
assert CHECKPOINT_RUN in CHECKPOINT_RUNS, f'unknown run: {CHECKPOINT_RUN}'
checkpoint_path = f'{CHECKPOINT_ROOT}/{CHECKPOINT_RUN}/best_model.pth'

In [None]:
  # Immutable record of the evaluation settings (a lightweight stand-in for argparse output).
  Args = namedtuple(
      'Args',
      'experiment_name backbone encoder_weights encoder_depth '
      'dataset_name dataset_path fusion scale '
      'tile_size mode batch_size',
  )

In [None]:
  # Active evaluation config: OSCD, Swin-B encoder (geopile weights, depth 12), 'diff' fusion.
  args = Args(
      experiment_name='tmp',
      backbone='Swin-B',
      encoder_weights='geopile',
      encoder_depth=12,
      dataset_name='OSCD',
      dataset_path='/mnt/sxtn/aerial/change/OSCD/',
      fusion='diff',
      scale=None,
      tile_size=192,
      mode='vanilla',
      # mode='wo_train_aug', scale='4x'
      batch_size=116 // 4,
  )

In [None]:
  # Alternative config (disabled): CDD "Real/subset" split with a depth-5 Swin-B encoder.
  # args = Args(experiment_name='tmp', fusion='diff', tile_size=192,
  #             backbone='Swin-B', encoder_weights='geopile', encoder_depth=5,
  #             dataset_name='CDD', dataset_path='/mnt/sxtn/aerial/change/CDD/Real/subset/', batch_size=2,
  #             mode='vanilla', scale=None,
  #             # mode='wo_train_aug', scale='8x'
  #            )

In [None]:
  # Alternative config (disabled): OSCD with the 'ibot-B' backbone and million_aid weights.
  # args = Args(experiment_name='tmp', fusion='diff', tile_size=192, 
  #             backbone='ibot-B', encoder_weights='million_aid', encoder_depth=12,
  #             dataset_name='OSCD', dataset_path='/mnt/sxtn/aerial/change/OSCD/', batch_size=116//4,
  #             mode='vanilla', scale=None,
  #             # mode='wo_train_aug', scale='8x'
  #            )

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Siamese UPerNet: the two input images go through a siamese encoder and the
# branch features are fused according to `args.fusion` (e.g. concat, sum, diff, abs_diff).
model = cdp.UPerNet(
    encoder_name=args.backbone,            # encoder architecture, e.g. 'Swin-B'
    encoder_weights=args.encoder_weights,  # pre-trained weights used to initialise the encoder
    encoder_depth=args.encoder_depth,
    in_channels=3,                         # input channels per image (3 = RGB)
    classes=2,                             # output channels (number of classes)
    siam_encoder=True,                     # use a siamese encoder for the two branches
    fusion_form=args.fusion,
)

In [None]:
# map_location: lets a checkpoint saved from GPU tensors load on a CPU-only machine.
# weights_only=False: this checkpoint may be a whole pickled nn.Module rather than a
# plain state dict, and torch>=2.6 refuses to unpickle modules by default.
# Unpickling can execute arbitrary code, so only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)

In [None]:
# torch.load may have returned a whole pickled module (anything exposing .state_dict())
# or already a plain state dict; accept both.
state_dict = checkpoint.state_dict() if hasattr(checkpoint, 'state_dict') else checkpoint
model.load_state_dict(state_dict)

In [None]:
# Build the validation loader for the chosen dataset, plus the `data` accumulator that
# CustomMetric fills with activated predictions ('p') and targets ('t'); 'f' receives
# the tile filenames after validation.
if 'oscd' in args.dataset_name.lower():
    datamodule = ChangeDetectionDataModule(args.dataset_path, patch_size=args.tile_size, mode=args.mode, scale=args.scale, batch_size=args.batch_size)
    datamodule.setup()

    valid_loader = datamodule.val_dataloader()
    print(len(valid_loader))
    # Size the buffers from the requested patch size rather than a hard-coded 192:
    # np.concatenate in CustomMetric fails as soon as the patch size differs.
    # NOTE(review): assumes validation patches are tile_size x tile_size, which may not
    # hold when `scale` is set — confirm.
    tile_hw = (args.tile_size, args.tile_size)
else:
    print('CDD', args.dataset_name)
    valid_dataset = LEVIR_CD_Dataset(f'{args.dataset_path}/val',
                                    sub_dir_1='A',
                                    sub_dir_2='B',
                                    img_suffix='.jpg',
                                    ann_dir=f'{args.dataset_path}/val/OUT',
                                    debug=False,
                                    seg_map_suffix='.jpg')

    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
    # NOTE(review): assumes 256x256 images in LEVIR-style folders — confirm for the dataset in use.
    tile_hw = (256, 256)

data = {
    'p': np.empty((0, *tile_hw)),
    't': np.empty((0, *tile_hw)),
    'f': [],
}
loss = cdp.utils.losses.CrossEntropyLoss()

In [None]:
from change_detection_pytorch.base.modules import Activation
from change_detection_pytorch.utils import base
from change_detection_pytorch.utils import functional as F

class CustomMetric(base.Metric):
    """Batch-mean of per-sample F-scores that also records every label map.

    Each call appends the activated predictions and the targets to the
    module-level ``data`` dict (``data['p']`` / ``data['t']``) so tiles can be
    stitched back into full scenes later. Because ``data`` is global, running
    validation twice without rebuilding it accumulates duplicate entries.

    Args:
        eps: smoothing term forwarded to ``F.f_score``.
        threshold: binarisation threshold forwarded to ``F.f_score``.
        activation: name passed to ``Activation`` (e.g. ``'argmax2d'``).
        ignore_channels: stored for API parity only; not applied (see forward).
    """

    __name__ = 'custom'

    def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.threshold = threshold
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels

    def forward(self, y_pr, y_gt):
        y_pr = self.activation(y_pr)
        # Side effect: accumulate label maps for the stitching cells further down.
        data['p'] = np.concatenate([data['p'], y_pr.cpu().numpy()])
        data['t'] = np.concatenate([data['t'], y_gt.cpu().numpy()])

        # Forward eps/threshold so the constructor options actually take effect.
        # NOTE(review): ignore_channels is not forwarded — after an argmax-style
        # activation each sample is presumably a 2-D map with no channel axis.
        fscores = torch.stack([
            F.f_score(p, g, eps=self.eps, threshold=self.threshold)
            for p, g in zip(y_pr, y_gt)
        ])
        return fscores.mean()

In [None]:
# Validation metrics; every one converts logits to a label map with 'argmax2d' first.
our_metrics = [
    metric_cls(activation='argmax2d')
    for metric_cls in (
        cdp.utils.metrics.Fscore,
        cdp.utils.metrics.Precision,
        cdp.utils.metrics.Recall,
        CustomMetric,
    )
]

In [None]:
# One validation pass over `valid_loader`. CustomMetric appends to the global `data`
# buffers on every call, so re-running this cell without rebuilding `data` duplicates rows.
valid_epoch = cdp.utils.train.ValidEpoch(
    model, loss=loss, metrics=our_metrics, device=DEVICE, verbose=True
)
valid_logs = valid_epoch.run(valid_loader)


In [None]:
# Flatten the per-batch filename lists from the validation logs into one list.
data['f'] = [name for batch_names in valid_logs['filenames'] for name in batch_names]
# The stitching step zips filenames with data['p'] / data['t']; zip silently truncates,
# so a mismatch (e.g. validation re-run without rebuilding `data`) must fail loudly here.
assert len(data['f']) == len(data['p']) == len(data['t']), (
    f"filenames/predictions/targets out of sync: "
    f"{len(data['f'])} / {len(data['p'])} / {len(data['t'])}"
)

In [None]:
# Recover (city, tile coordinates) from each tile filename. On the basename, everything
# before the last '_' is the city (which may itself contain '_'); the last field, with
# its first and last character stripped (presumably brackets), is split on ', ' into ints.
cities = []
coords = []
for tile_path in data['f']:
    base_name = tile_path.rpartition('/')[2]
    city_part, _, coord_part = base_name.rpartition('_')
    tile_coord = [int(v) for v in coord_part[1:-1].split(', ')]
    cities.append(city_part)
    coords.append(tile_coord)

In [None]:
# Sorted so that map-building, scoring and plotting run in the same city order every
# time; iteration order of a set of strings changes between interpreter runs.
unique_cities = sorted(set(cities))

In [None]:
# One full-scene canvas per city for ground truth ('t') and prediction ('p').
# Each canvas is sized from the furthest tile edge seen for that city (coords unpacked as
# x1, y1, x2, y2 and pasted at [y1:y2, x1:x2], as in the stitching cell), rather than a
# fixed 1000x1000 that would break the slice assignment for larger scenes. Extra zero
# background would only add true negatives, which F1 ignores.
city_extent = {}
for city_name, (x1, y1, x2, y2) in zip(cities, coords):
    height, width = city_extent.get(city_name, (0, 0))
    city_extent[city_name] = (max(height, y2), max(width, x2))

maps = {city: {
    't': np.zeros(city_extent[city]),
    'p': np.zeros(city_extent[city]),
} for city in unique_cities}

In [None]:
# Paste each tile's target and prediction into its city canvas. Where tiles overlap,
# the later tile in iteration order overwrites the earlier one.
for city_name, (x1, y1, x2, y2), pred_tile, true_tile in zip(cities, coords, data['p'], data['t']):
    canvas = maps[city_name]
    canvas['t'][y1:y2, x1:x2] = true_tile
    canvas['p'][y1:y2, x1:x2] = pred_tile

In [None]:
# Pixel-level binary F1 of each stitched city scene.
for city_name, city_maps in tqdm(maps.items()):
    city_maps['fscore'] = metrics.f1_score(city_maps['t'].ravel(), city_maps['p'].ravel())

In [None]:
# Pooled F1: all pixels of all cities are scored together as one binary problem.
micro_f1 = metrics.f1_score(
    np.concatenate([city_maps['t'].ravel() for city_maps in maps.values()]),
    np.concatenate([city_maps['p'].ravel() for city_maps in maps.values()]),
)

In [None]:
# (mean of per-city F1, pooled F1 over all pixels)
np.mean([city_maps['fscore'] for city_maps in maps.values()]), micro_f1

In [None]:
# Per-city error map, value = 2 * prediction + target:
#   0 = true negative, 1 = false negative, 2 = false positive, 3 = true positive.
# vmin/vmax are fixed so each value gets the same colour in every city's figure.
for city, city_maps in maps.items():
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(city_maps['p'] * 2 + city_maps['t'], cmap='nipy_spectral', vmin=0, vmax=4)
    ax.set_title(f"{city} F-score: {city_maps['fscore']:.2f}")
    plt.show()
    # Free the figure once displayed; one figure per city otherwise piles up in memory.
    plt.close(fig)

In [None]:
import torch
import torchvision

In [None]:
# Satlas Sentinel-2 Swin-B (RGB) checkpoint; rebinds `checkpoint_path` for the cells below.
checkpoint_path: str = '/mnt/sxtn/cd/satlas_model/sentinel2_swinb_si_rgb.pth'

In [None]:
# Load onto the CPU: the torchvision Swin model these weights go into is created on the
# CPU, and this also works when the file was saved from CUDA tensors on a GPU-less host.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

In [None]:
# Keep only the entries whose key contains `needed`, and strip every occurrence of
# `prefix` from the kept keys so they can be loaded into the torchvision Swin model.
new_state_dict = {}
prefix = 'backbone.'
needed = 'backbone'
for key, value in checkpoint.items():
    # Skip keys that do not belong to the component we are loading.
    if needed not in key:
        continue

    # Remove the prefix one occurrence at a time until none is left; removing one can
    # expose another (e.g. 'backbone.backbone.x' -> 'backbone.x' -> 'x').
    # An empty prefix is skipped: '' is always "found", so the loop would never end.
    if prefix:
        while prefix in key:
            key = key.replace(prefix, '', 1)

    new_state_dict[key] = value

In [None]:
# Randomly initialised torchvision Swin-V2-B (no pretrained weights downloaded).
# Note: this rebinds `model`, replacing the UPerNet model used earlier in the notebook.
model = torchvision.models.swin_v2_b(weights=None)

In [None]:
# Strict load: raises if the filtered Satlas keys do not exactly match the Swin model's.
model.load_state_dict(new_state_dict, strict=True)

In [None]:
# Summarise the filtered state dict as key -> tensor shape instead of dumping every tensor.
{key: tuple(tensor.shape) for key, tensor in new_state_dict.items()}

In [None]:
# Inspect the feature-extractor submodule of the Swin model.
model.get_submodule('features')

In [None]:
# Debug view of the module's instance attributes (same object as model.__dict__).
vars(model)

In [None]:
# New classification head. 1024 matches the feature width feeding the Swin-V2-B head;
# NOTE(review): 21 is presumably the number of target classes — confirm.
linear = torch.nn.Linear(in_features=1024, out_features=21)

In [None]:
# Swap the model's classifier for the new linear head (registered as submodule 'head').
setattr(model, 'head', linear)

In [None]:
# Display the module tree of `model`; after the preceding cells this is the torchvision
# Swin-V2-B with its `head` replaced by the new linear layer.
model