In [1]:
import argparse
from argparse import ArgumentParser
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from tqdm import tqdm

import train_classifier as tr_cls
from change_detection_pytorch.datasets import UCMerced, build_transform

cuda:0


In [2]:
@dataclass
class Args:
    """Settings used to build the classifier and the test dataset for one checkpoint.

    The evaluation cell starts from these defaults and overrides backbone and
    dataset fields depending on substrings of the checkpoint path.
    """
    # NOTE(review): not read by the evaluation cell, which picks 'cuda:0' or 'cpu' itself.
    device: int = 1
    experiment_name: str = 'tmp'

    # Backbone settings forwarded to tr_cls.Classifier.
    backbone_name: str = 'ibot-B'
    encoder_weights: str = 'million_aid_fa'
    encoder_depth: int = 12
    in_features: int = 768  # presumably the backbone feature width seen by the head — confirm in train_classifier

    # Test dataset location forwarded to UCMerced (root + base_dir).
    root: str = '/nfs/ap/mnt/sxtn/classification/datasets/'
    base_dir: str = 'UCMerced_LandUse/Images/'
    dataset_name: str = 'uc_merced'
    num_classes: int = 21

    fusion: str = 'diff'
    # Downscaling tag such as '2x'; None means the original-resolution data.
    scale: Optional[str] = None
    mode: str = 'vanilla'
    # NOTE(review): the evaluation DataLoader uses batch_size=256, not this value.
    batch_size: int = 32
    image_size: int = 256

In [3]:
# Checkpoints to evaluate, built from their run directory names.
# Alternative runs are kept commented out; uncomment a name to include it.
checkpoints = [
    f'/auto/home/ani/change_detection.pytorch/checkpoints/{run_name}/epoch=99.ckpt'
    for run_name in (
        # 'satlas_head_resisc_mixup',
        # 'satlas_hr_head_resisc_mixup',
        'ibot_scale_ben',
        # 'ibotIN_head_resisc_mixup',
        # 'gfm_head_resisc_mixup',
        # 'dino_head_resisc_mixup',
    )
]

In [None]:
# Evaluate every checkpoint on the test split at each downscaling factor.
# results[checkpoint_path][scale] -> test accuracy in percent.
results = {}
scales = ['1x', '2x', '4x', '8x']
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


def _args_for_checkpoint(checkpoint_path):
    """Return an Args instance whose backbone/dataset fields are inferred from the checkpoint path."""
    args = Args()
    if 'dino' in checkpoint_path:
        args.backbone_name = 'dino'
        args.encoder_weights = ''
        args.in_features = 768
        args.image_size = 252

    if 'gfm' in checkpoint_path:
        args.backbone_name = 'Swin-B'
        args.encoder_weights = 'geopile'
        args.in_features = 1024
    elif 'satlas' in checkpoint_path:
        args.backbone_name = 'Swin-B'
        args.encoder_weights = 'satlas'
        args.in_features = 1024

    if 'resisc' in checkpoint_path:
        args.root = "/nfs/ap/mnt/sxtn/classification/datasets/"
        args.base_dir = "NWPU-RESISC45/"
        args.dataset_name = "resisc45"
        args.num_classes = 45
    return args


def _load_model(args, checkpoint_path, device):
    """Build the classifier described by ``args``, load its weights, and return it in eval mode on ``device``.

    Satlas checkpoints are handed to the Classifier constructor (with
    ``prefix='encoder'``); any other checkpoint is read with ``torch.load`` and its
    ``'state_dict'`` entry is loaded into the freshly built model.
    """
    common = dict(backbone_name=args.backbone_name, backbone_weights=args.encoder_weights,
                  in_features=args.in_features, num_classes=args.num_classes,
                  lr=0.0, sched='', only_head='', warmup_steps='', eta_min='',
                  warmup_start_lr='', weight_decay='', mixup=False)
    if 'satlas' in checkpoint_path:
        model = tr_cls.Classifier(checkpoint_path=checkpoint_path, prefix='encoder', **common)
    else:
        model = tr_cls.Classifier(checkpoint_path='', **common)
        # Only load the file here when the constructor did not already consume it.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    return model.to(device)


def _scaled_base_dir(base_dir, scale):
    """Return the dataset sub-directory for ``scale``: ``base_dir`` for '1x', else '<base_dir>_<scale>/'."""
    if scale == '1x':
        return base_dir
    return f"{base_dir.rstrip('/')}_{scale}/"


def _evaluate(model, dataloader, num_classes, device):
    """Return the multiclass accuracy (a fraction in [0, 1]) of ``model`` over ``dataloader``.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    accuracy = Accuracy(task="multiclass", num_classes=num_classes).to(device)
    n_seen = 0
    with torch.no_grad():
        for x, y in tqdm(dataloader):
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            # Accumulate over the whole split and compute once, instead of
            # re-weighting per-batch accuracies by hand.
            accuracy.update(torch.argmax(logits, dim=1), y)
            n_seen += len(y)
    if n_seen == 0:
        raise ValueError('test dataloader yielded no samples')
    return accuracy.compute().item()


for checkpoint_path in checkpoints:
    args = _args_for_checkpoint(checkpoint_path)
    model = _load_model(args, checkpoint_path, device)

    results[checkpoint_path] = {}
    base_dir = args.base_dir
    for scale in scales:
        # Always derive from the unscaled base_dir so every scale (including '1x')
        # maps to the right directory regardless of the order of `scales`.
        args.scale = None if scale == '1x' else scale
        args.base_dir = _scaled_base_dir(base_dir, scale)

        print(args)

        test_transform = build_transform(split='test', image_size=args.image_size)
        test_dataset = UCMerced(root=args.root, base_dir=args.base_dir, split='test',
                                transform=test_transform, dataset_name=args.dataset_name)
        # Sample order does not affect accuracy; no shuffling keeps iteration deterministic.
        test_dataloader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False, num_workers=24)

        overall_test_accuracy = _evaluate(model, test_dataloader, args.num_classes, device)
        print(checkpoint_path)
        print(f'Test Accuracy: {overall_test_accuracy * 100:.2f}%')
        results[checkpoint_path][scale] = overall_test_accuracy * 100
print(results)

In [None]:
# Plot test accuracy against relative input resolution for every evaluated checkpoint.
# Series are built from `results` itself and labelled by matching substrings of the
# checkpoint path, so the cell works for whatever subset of `checkpoints` was
# evaluated (fixed indices such as checkpoints[5] raise IndexError as soon as
# entries are commented out of that list).
_PLOT_STYLES = [
    # (checkpoint-path substring, legend label, colour). More specific substrings come
    # first: 'satlas_hr' also contains 'satlas', and 'ibotIN' also contains 'ibot'.
    ('satlas_hr', 'Satlas HR', 'blue'),
    ('satlas', 'Satlas', 'red'),
    ('ibotIN', 'Ibot_IN', 'green'),
    ('ibot', 'Ibot', 'orange'),
    ('gfm', 'GFM', 'm'),
    ('dino', 'DINOv2', 'brown'),
]


def _plot_style(checkpoint_path):
    """Return ``(label, plot_kwargs)`` for a checkpoint path.

    Unknown checkpoints are labelled with their parent directory name and take the
    next colour of matplotlib's default cycle.
    """
    for key, label, color in _PLOT_STYLES:
        if key in checkpoint_path:
            return label, {'color': color}
    parts = checkpoint_path.rstrip('/').split('/')
    return (parts[-2] if len(parts) > 1 else parts[-1]), {}


def _scale_factor(scale):
    """Parse a downscaling tag such as '4x' into the integer 4."""
    return int(scale.rstrip('x'))


def _resolution_label(scale):
    """Format a downscaling tag as the relative resolution shown on the x axis ('4x' -> '1/4', '1x' -> '1')."""
    factor = _scale_factor(scale)
    return '1' if factor == 1 else f'1/{factor}'


fig, ax = plt.subplots()
for checkpoint_path, accuracy_by_scale in results.items():
    # Coarsest resolution first, so the x axis reads 1/8, 1/4, 1/2, 1.
    ordered_scales = sorted(accuracy_by_scale, key=_scale_factor, reverse=True)
    label, style = _plot_style(checkpoint_path)
    ax.plot([_resolution_label(s) for s in ordered_scales],
            [accuracy_by_scale[s] for s in ordered_scales],
            marker='o', label=label, **style)

# NOTE(review): this title text was the original x-label and names the RES45 mixup
# runs; update it when plotting other checkpoints.
ax.set(xlabel='Relative input resolution', ylabel='Accuracy (%)',
       title='Linear Probing with Mixup RES45')
ax.legend()

plt.show()
                 

In [None]:
# Round every accuracy to two decimals for display, preserving checkpoint and scale order.
formatted_resuls = {
    checkpoint_path: {scale: round(accuracy, 2) for scale, accuracy in accuracy_by_scale.items()}
    for checkpoint_path, accuracy_by_scale in results.items()
}

In [None]:
# Display the rounded accuracies: checkpoint path -> scale tag -> accuracy in percent.
formatted_resuls

In [None]:
from PIL import Image

In [None]:
import os
import numpy


In [None]:
# List the entries of the UCMerced image directory (dataset root + image sub-directory).
os.listdir(os.path.join('/nfs/ap/mnt/sxtn/classification/datasets/', 'UCMerced_LandUse/Images/'))

In [None]:
# Open one sample image from each of four UCMerced classes, in order:
# harbor, agricultural, baseballdiamond, golfcourse.
im1, im2, im3, im4 = (
    Image.open('/nfs/ap/mnt/sxtn/classification/datasets/UCMerced_LandUse/Images/' + relative_path)
    for relative_path in (
        'harbor/harbor05.tif',
        'agricultural/agricultural04.tif',
        'baseballdiamond/baseballdiamond15.tif',
        'golfcourse/golfcourse12.tif',
    )
)

In [None]:
# Display im4 (the golfcourse sample at this point; it is rebound to a LEVIR-CD image further down).
im4

In [None]:
# os.listdir() needs a directory: passing the buildings04.tif file itself raises
# NotADirectoryError. List the `buildings` class directory instead, which shows
# whether buildings04.tif is present.
os.listdir('/nfs/ap/mnt/sxtn/classification/datasets/UCMerced_LandUse/Images/buildings/')

In [None]:
# Display im3 (the baseballdiamond sample opened earlier).
im3

In [None]:
# NOTE(review): this rebinds im4, replacing the UCMerced golfcourse sample opened
# earlier with the LEVIR-CD test image test_119 from folder B; every later use of
# im4 (display, image grid) shows this image instead.
im4 = Image.open('/mnt/sxtn/aerial/change/LEVIR_CD/test/B/test_119.png')

In [None]:
# Display im4, which now holds the LEVIR-CD test/B image test_119.
im4

In [None]:
# Load the LEVIR-CD image for sample test_119 from the test/OUT folder
# (presumably the change mask for that image pair — confirm the folder's meaning).
gt1 = Image.open('/mnt/sxtn/aerial/change/LEVIR_CD/test/OUT/test_119.png')

In [None]:
# Display gt1 (LEVIR-CD test/OUT image for sample test_119).
gt1

In [None]:
import matplotlib.pyplot as plt


In [None]:
# Show the images in a grid of 3 columns (one row per group of three), axes hidden.
# FIXME(review): `gt` is not defined anywhere in this notebook — presumably a
# ground-truth mask from a deleted cell, playing the role gt1 plays for im4 — so
# this cell raises NameError on a fresh kernel. Define `gt` before running it.
imgs = [im1, im2, gt, im3, im4, gt1]
n_cols = 3
n_rows = -(-len(imgs) // n_cols)  # ceiling division
# Keep the original 5 x 3.5 inch size for two rows; grow the height with more rows.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(5, 1.75 * n_rows), squeeze=False)
axs = axs.flatten()
for img, ax in zip(imgs, axs):
    ax.imshow(img, cmap='gray')
# Hide the frame of every axis, including unused trailing ones.
for ax in axs:
    ax.set_axis_off()
plt.subplots_adjust(wspace=0.02, hspace=0)

plt.show()
