In [None]:
from sklearn import metrics
import numpy as np
import torch
import argparse
from collections import namedtuple
import os

from matplotlib import pyplot as plt

import change_detection_pytorch as cdp
from change_detection_pytorch.datasets import LEVIR_CD_Dataset
from torch.utils.data import DataLoader


from change_detection_pytorch.datasets import ChangeDetectionDataModule
from argparse import ArgumentParser
from tqdm import tqdm

In [None]:
# Evaluation results, keyed by checkpoint path and then by scale.
results = dict()

In [None]:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Args:
    """Evaluation settings for one checkpoint/scale run.

    Defaults describe an OSCD / Swin-B setup; the evaluation loop overrides
    fields per checkpoint (backbone, weights, dataset) and per scale.
    """
    experiment_name: str = 'tmp'
    backbone: str = 'Swin-B'  # passed to cdp.UPerNet as `encoder_name`
    encoder_weights: str = 'geopile'  # passed to cdp.UPerNet as `encoder_weights`
    encoder_depth: int = 12
    dataset_name: str = 'OSCD'
    # NOTE(review): machine-specific absolute path; override per environment.
    dataset_path: str = '/mnt/sxtn/aerial/change/OSCD/'
    fusion: str = 'diff'  # passed to cdp.UPerNet as `fusion_form`
    scale: Optional[str] = None  # e.g. '2x'; None for the original resolution
    tile_size: int = 192
    mode: str = 'vanilla'
    batch_size: int = 116 // 4  # == 29


In [None]:
# Test-set scales to evaluate; '1x' uses the original second-date folder, the
# others read the rescaled `B_{scale}` / `B_cut_{scale}` folders.
scales = [f'{factor}x' for factor in (1, 2, 4, 8)]

In [None]:
# TODO: fill with paths to the fine-tuned model checkpoints to evaluate.
# Substrings of each path ('ibot'/'satlas', 'cdd'/'levir') select the backbone
# and dataset in the evaluation loop below.
checkpoints = list()

In [None]:
def f1_bitwise(y_true, y_pred, eps=1e-10):
    """Pixel-pooled (micro) F1 score between two binary masks.

    Any non-zero element is treated as positive, so 0/1, 0/255 or float masks
    are all handled. Inputs are compared element-wise and must have matching
    shapes (array-likes such as lists are accepted).

    Args:
        y_true: ground-truth mask.
        y_pred: predicted mask, same shape as `y_true`.
        eps: small constant added to each denominator so empty masks give 0
            instead of a division-by-zero NaN.

    Returns:
        F1 = 2PR / (P + R) as a NumPy float; 0.0 when there are no true positives.
    """
    # Cast to bool first: bitwise ops on raw integer labels miscount
    # (e.g. 2 & 1 == 0) and raise on float masks.
    truth = np.asarray(y_true).astype(bool)
    pred = np.asarray(y_pred).astype(bool)

    tp = np.logical_and(truth, pred).sum()
    fp = np.logical_and(pred, ~truth).sum()
    fn = np.logical_and(~pred, truth).sum()

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return 2 * (precision * recall) / (precision + recall + eps)

In [None]:
# Show which checkpoint paths will be evaluated.
print(f'{checkpoints}')

In [None]:
from change_detection_pytorch.base.modules import Activation
from change_detection_pytorch.utils import base
from change_detection_pytorch.utils import functional as F

class CustomMetric(base.Metric):
    """Metric that also collects the per-batch predicted and target change maps.

    `forward` applies `activation` to the predictions, appends them (cast to
    uint8) to the ``'p'`` array of a store dict and the targets to ``'t'``, and
    returns the mean of the per-sample `F.f_score` values.

    The store is the dict passed as ``store=``; when omitted, the notebook-global
    ``data`` dict is looked up at call time. Keeping that fallback lets existing
    code that rebinds ``data`` before each validation run keep working.

    Note: ``eps``, ``threshold`` and ``ignore_channels`` are stored but not used
    by ``forward``; `F.f_score` is called with its own defaults.
    """
    __name__ = 'custom'

    def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, store=None, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.threshold = threshold
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels
        # Dict with 'p' and 't' arrays to append to; None -> global `data`.
        self.store = store

    def forward(self, y_pr, y_gt):
        """Accumulate `y_pr` (after activation) and `y_gt`, return mean per-sample F-score."""
        y_pr = self.activation(y_pr)
        # Resolve the accumulator at call time; fall back to the notebook-global `data`.
        store = self.store if self.store is not None else data
        store['p'] = np.concatenate([store['p'], y_pr.cpu().numpy().astype('uint8')])
        store['t'] = np.concatenate([store['t'], y_gt.cpu().numpy().astype('uint8')])

        fscores = torch.tensor([F.f_score(p, g) for p, g in zip(y_pr, y_gt)])
        return fscores.mean()

In [None]:
# Evaluate every fine-tuned checkpoint on every scale of its test set and store
# macro/micro F1 plus the stacked change maps in `results[checkpoint][scale]`.
#
# NOTE: CustomMetric.forward (without `store=`) appends to the notebook-global
# name `data`, so `data` is rebound at module level before each validation run.

# TODO: set these to the dataset roots before running; each needs a `test/` folder.
CDD_DATASET_PATH = None
LEVIR_DATASET_PATH = None

# Side length of the CDD / LEVIR-CD test change maps.
MAP_SIZE = 256

# Hard-coded second GPU, as in the original runs.
DEVICE = 'cuda:1' if torch.cuda.is_available() else 'cpu'


def _build_args(checkpoint_path, scale):
    """Return `Args` for one checkpoint/scale, choosing backbone and dataset from path substrings."""
    args = Args()
    if scale != '1x':
        args.scale = scale
        args.mode = 'wo_train_aug'

    if 'ibot' in checkpoint_path:
        args.backbone = 'ibot-B'
        args.encoder_weights = 'million_aid'
        args.encoder_depth = 12
    elif 'satlas' in checkpoint_path:
        args.backbone = 'Swin-B'
        args.encoder_weights = 'satlas'
        args.encoder_depth = 12

    if 'cdd' in checkpoint_path:
        args.dataset_name = 'CDD'
        args.dataset_path = CDD_DATASET_PATH
        args.batch_size = 32
        args.tile_size = 256  # not used by this evaluation
    elif 'levir' in checkpoint_path:
        args.dataset_name = 'LEVIR_CD'
        args.dataset_path = LEVIR_DATASET_PATH
        args.batch_size = 32
        args.tile_size = 256  # not used by this evaluation
    return args


def _build_valid_loader(args, scale):
    """Build the test-split DataLoader for `args.dataset_name` at `scale`.

    Raises:
        ValueError: if the dataset path is unset or the dataset is not CDD / LEVIR-CD
            (previously this silently reused the loader from the previous iteration).
    """
    if not args.dataset_path:
        raise ValueError(f'dataset_path is not set for {args.dataset_name}')
    name = args.dataset_name.lower()
    if 'cdd' in name:
        print('CDD', args.dataset_name)
        dataset = LEVIR_CD_Dataset(f'{args.dataset_path}/test',
                                   sub_dir_1='A',
                                   sub_dir_2='B' if scale == '1x' else f'B_{scale}',
                                   img_suffix='.jpg',
                                   ann_dir=f'{args.dataset_path}/test/OUT',
                                   debug=False,
                                   seg_map_suffix='.jpg')
    elif 'levir' in name:
        print('LEVIR', args.dataset_name)
        dataset = LEVIR_CD_Dataset(f'{args.dataset_path}/test',
                                   sub_dir_1='A_cut',
                                   sub_dir_2='B_cut' if scale == '1x' else f'B_cut_{scale}',
                                   img_suffix='.png',
                                   ann_dir=f'{args.dataset_path}/test/OUT_cut',
                                   debug=False,
                                   test_mode=True,
                                   seg_map_suffix='.png')
    else:
        raise ValueError(f'unsupported dataset for evaluation: {args.dataset_name}')
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)


def _empty_store():
    """Fresh accumulator for CustomMetric: predicted ('p') and target ('t') uint8 maps."""
    return {
        'p': np.empty((0, MAP_SIZE, MAP_SIZE), dtype='uint8'),
        't': np.empty((0, MAP_SIZE, MAP_SIZE), dtype='uint8'),
        'f': [],
    }


def _f1_scores(targets, preds):
    """Return (macro_f1, micro_f1, maps) for stacked per-tile target / prediction maps.

    macro_f1 is the mean of sklearn's per-tile F1; micro_f1 pools all pixels via
    `f1_bitwise`. `maps` holds the tiles stacked row-wise (shape (N * H, W)),
    matching `np.vstack` of the per-tile maps.
    """
    fscores = []
    for p, t in tqdm(zip(preds, targets), total=len(preds)):
        if p.sum() + t.sum() == 0:
            # NOTE(review): an empty/empty tile is a perfect prediction but is scored 0; kept as-is.
            fscores.append(0)
        else:
            fscores.append(metrics.f1_score(t.flatten(), p.flatten()))
    macro_f1 = np.mean(fscores)
    maps_t = targets.reshape(-1, targets.shape[-1])
    maps_p = preds.reshape(-1, preds.shape[-1])
    micro_f1 = f1_bitwise(maps_t, maps_p)
    return macro_f1, micro_f1, {'t': maps_t, 'p': maps_p}


for checkpoint_path in checkpoints:
    results[checkpoint_path] = {}
    for scale in scales:
        args = _build_args(checkpoint_path, scale)
        print(args)

        model = cdp.UPerNet(
            encoder_depth=args.encoder_depth,
            encoder_name=args.backbone,
            encoder_weights=args.encoder_weights,
            in_channels=3,  # RGB inputs
            classes=2,  # change / no-change
            siam_encoder=True,
            fusion_form=args.fusion,
        )

        # The checkpoint is a pickled full model; load it on CPU so it does not depend on
        # the device it was saved from (load_state_dict copies the weights into `model`).
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        # Newer torch versions default to weights_only=True and may need weights_only=False here.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint.state_dict())

        valid_loader = _build_valid_loader(args, scale)
        # Rebind the module-level accumulator that CustomMetric.forward appends to.
        data = _empty_store()

        loss = cdp.utils.losses.CrossEntropyLoss()
        our_metrics = [
            cdp.utils.metrics.Fscore(activation='argmax2d'),
            cdp.utils.metrics.Precision(activation='argmax2d'),
            cdp.utils.metrics.Recall(activation='argmax2d'),
            CustomMetric(activation='argmax2d'),
        ]
        valid_epoch = cdp.utils.train.ValidEpoch(
            model,
            loss=loss,
            metrics=our_metrics,
            device=DEVICE,
            verbose=True,
        )
        valid_logs = valid_epoch.run(valid_loader)

        macro_f1, micro_f1, maps = _f1_scores(data['t'], data['p'])
        print(checkpoint_path, scale)
        print(macro_f1, micro_f1)

        results[checkpoint_path][scale] = {
            'maps': maps,
            'micro_f1': micro_f1,
            'macro_f1': macro_f1,
        }

In [None]:
# Per-checkpoint summary: macro- and micro-averaged F1 for every evaluated scale.
for checkpoint in checkpoints:
    print(checkpoint)
    for scale in scales:
        print("{} macro F1 = {:.3f}  micro-F1 = {:.3f}".format(
            scale,
            results[checkpoint][scale]['macro_f1'],
            results[checkpoint][scale]['micro_f1'],
        ))
    print(' ')

In [None]:
# One comma-prefixed row of micro-F1 per checkpoint, coarsest scale first,
# for pasting into a table.
for checkpoint in checkpoints:
    print(checkpoint)
    for scale in ('8x', '4x', '2x', '1x'):
        micro = results[checkpoint][scale]['micro_f1']
        print(f",{micro:.3f}", end='')
    print(' ')