In [None]:
import argparse
import os
from argparse import ArgumentParser
from collections import namedtuple
from dataclasses import dataclass
from typing import Optional

from sklearn import metrics
import numpy as np
import torch
from torchmetrics import Accuracy
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from tqdm import tqdm

import change_detection_pytorch as cdp
from change_detection_pytorch.datasets import LEVIR_CD_Dataset
from change_detection_pytorch.encoders._utils import adjust_state_dict_prefix
from change_detection_pytorch.datasets import ChangeDetectionDataModule
from change_detection_pytorch.datasets import UCMerced, build_transform

import train_classifier as tr_cls

In [None]:
@dataclass
class Args:
    """Default evaluation configuration, instantiated fresh for every (checkpoint, scale) pair.

    Defaults target the 'ibot-B' backbone with 'million_aid_fa' weights, evaluated
    on the UC Merced dataset (21 classes). The evaluation loop overrides the
    backbone / dataset fields depending on substrings of the checkpoint path.
    """
    # NOTE(review): presumably a CUDA device ordinal; not read by the evaluation loop.
    device: int = 1
    # Free-form run label; not read by the evaluation loop.
    experiment_name: str = 'tmp'

    backbone_name: str = 'ibot-B'
    encoder_weights: str = 'million_aid_fa'
    # Not read by the evaluation loop; kept for parity with the training config.
    encoder_depth: int = 12
    # Feature width fed to the classification head (overridden to 1024 for Swin-B).
    in_features: int = 768

    # NOTE(review): absolute cluster path; adjust for your environment.
    root: str = '/nfs/ap/mnt/sxtn/classification/datasets/'
    # Must end with a path separator: the scale suffix ('_2x', ...) is inserted
    # before the final character.
    base_dir: str = 'UCMerced_LandUse/Images/'
    dataset_name: str = 'uc_merced'
    num_classes: int = 21

    # Not read by the evaluation loop (presumably carried over from change-detection configs).
    fusion: str = 'diff'
    # None means native resolution ('1x'); otherwise e.g. '2x', '4x', '8x'.
    scale: Optional[str] = None
    # Not read by the evaluation loop.
    mode: str = 'vanilla'
    batch_size: int = 32

In [None]:
# Paths of fine-tuned classifier checkpoints to evaluate. The evaluation loop
# dispatches on substrings of each path ('gfm', 'satlas', '_resisc').
checkpoints = list()

In [None]:
# Root directory containing "NWPU-RESISC45/". The original cell left this
# assignment blank (a syntax error); set the RESISC_ROOT environment variable
# or edit this line before evaluating any '_resisc' checkpoint.
RESISC_ROOT = os.environ.get('RESISC_ROOT', '')


def _make_args(checkpoint_path, scale):
    """Build the evaluation config for one checkpoint at one image scale.

    'gfm' / 'satlas' in ``checkpoint_path`` select a Swin-B backbone with the
    matching weights; '_resisc' switches the dataset to NWPU-RESISC45. For any
    scale other than '1x', '_<scale>' is inserted into ``base_dir`` just before
    its final character (the trailing separator).

    Raises:
        ValueError: if a '_resisc' checkpoint is requested but RESISC_ROOT is unset.
    """
    args = Args()

    if 'gfm' in checkpoint_path:
        args.backbone_name = 'Swin-B'
        args.encoder_weights = 'geopile'
        args.in_features = 1024
    elif 'satlas' in checkpoint_path:
        args.backbone_name = 'Swin-B'
        args.encoder_weights = 'satlas'
        args.in_features = 1024

    if '_resisc' in checkpoint_path:
        if not RESISC_ROOT:
            raise ValueError("Set RESISC_ROOT (or the RESISC_ROOT environment variable) "
                             "to the directory containing NWPU-RESISC45/")
        args.root = RESISC_ROOT
        args.base_dir = "NWPU-RESISC45/"
        args.dataset_name = "resisc45"
        args.num_classes = 45

    if scale != '1x':
        args.scale = scale
        args.base_dir = args.base_dir[:-1] + '_' + scale + args.base_dir[-1]
    return args


def _build_model(checkpoint_path, args):
    """Instantiate the classifier in eval mode with fine-tuned weights.

    Satlas checkpoints are handed to the ``Classifier`` constructor (with
    ``prefix='encoder'``); all other checkpoints are built empty and then filled
    from the checkpoint's ``'state_dict'`` entry.
    """
    common = dict(backbone_name=args.backbone_name, backbone_weights=args.encoder_weights,
                  in_features=args.in_features, num_classes=args.num_classes,
                  lr=0.0, sched='', only_head='')
    if 'satlas' in checkpoint_path:
        model = tr_cls.Classifier(**common, checkpoint_path=checkpoint_path, prefix='encoder')
    else:
        model = tr_cls.Classifier(**common, checkpoint_path='')
        # The model lives on CPU, so load tensors there directly rather than
        # staging them on the GPU first.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    return model


def _evaluate(model, dataloader):
    """Return top-1 accuracy of ``model`` over ``dataloader`` as a fraction in [0, 1].

    Counts exact correct predictions (integer) instead of re-weighting
    per-batch float accuracies.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    correct_predictions = 0
    total_samples = 0
    with torch.no_grad():
        for x, y in tqdm(dataloader):
            preds = torch.argmax(model(x), dim=1)
            correct_predictions += (preds == y).sum().item()
            total_samples += len(y)
    if total_samples == 0:
        raise ValueError('test dataloader yielded no samples')
    return correct_predictions / total_samples


results = {}
scales = ['1x', '2x', '4x', '8x']

for checkpoint_path in checkpoints:
    results[checkpoint_path] = {}
    # Architecture and head size depend only on the checkpoint, not the scale,
    # so build and load the model once per checkpoint.
    model = _build_model(checkpoint_path, _make_args(checkpoint_path, '1x'))

    for scale in scales:
        args = _make_args(checkpoint_path, scale)
        print(args)

        test_transform = build_transform(split='test')
        test_dataset = UCMerced(root=args.root, base_dir=args.base_dir, split='test',
                                transform=test_transform, dataset_name=args.dataset_name)
        # Order is irrelevant for accuracy; no shuffling keeps runs reproducible.
        test_dataloader = DataLoader(dataset=test_dataset, batch_size=args.batch_size,
                                     shuffle=False, num_workers=8)

        overall_test_accuracy = _evaluate(model, test_dataloader)
        print(checkpoint_path)
        print(f'Test Accuracy: {overall_test_accuracy * 100:.2f}%')
        results[checkpoint_path][scale] = overall_test_accuracy * 100
print(results)