In [47]:
import torch
from torchvision.models import vgg16
from torch.nn import Conv2d

import os, glob
from tqdm.notebook import tqdm
import rasterio
import numpy as np

# Utils

In [48]:
class sen12msDFC_labelTransform:
    def __init__(self):
        # only classes that do exist in sen12ms DFC dataset are
        # 1, 2, 4, 5, 6, 7, 9, 10
        # so map them from zero to seven
        #
        #                       0     1  2     3    4  5  6  7    8     9  10
        self.lut = np.array( [np.nan, 0, 1, np.nan, 2, 3, 4, 5, np.nan, 6, 7] )
    def __call__(self, x):
        return self.lut[x]

lookUpTable = sen12msDFC_labelTransform()

In [33]:
model = vgg16(num_classes=8)
model.features[0] = Conv2d(10, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

# load checkpoint

In [34]:
# adapt here to your path
path_to_checkpoint = "/home/user/results/data_centric_clustering/pretrain_on_dfc/2023-09-12_19:32:38/model_checkpoints/state_at_finalize.pt"

In [35]:
ckp = torch.load(path_to_checkpoint)

In [36]:
ckp.keys()

dict_keys(['epoch', 'global_step', 'state_dict', 'optimizer_state_dict', 'loss'])

In [37]:
model.load_state_dict(ckp["state_dict"])

<All keys matched successfully>

# predict the sen12msdfc dataset

### here just to calculate the confusion matrix. You ofc have to adapt to your spesific dataloader and experiment

In [38]:
def preprocess_s2(s2):
    # devide by 10k to get the 0-100% reflection
    # window, then clip to [0,1]
    return np.clip(s2/10000,0,1)

In [39]:
# adapt here to your path
path_to_data = "/home/user/data/sen12msDFC/s2_validation/"
all_dfc_sampels = glob.glob(os.path.join(path_to_data,"*.tif"))

len(all_dfc_sampels)

986

In [65]:
GT = []
PRED = []

for loc in tqdm(all_dfc_sampels):

    # get data
    with rasterio.open(loc,"r") as src:
        data = src.read((2,3,4,5,6,7,8,9,12,13)) # only 10 bands
        data = preprocess_s2(data)
        
    # get corresponding label
    with rasterio.open(loc.replace("s2","dfc"),"r") as src:
        label = src.read(1)
        label = lookUpTable(label)
        percentages = [100*np.sum(label==classindex)/256**2 for classindex in range(8)]
        max_value = max(percentages)
        max_value_index = percentages.index(max_value)
        GT.append(max_value_index)
    

    # predict with model
    pred = model(torch.Tensor(np.expand_dims(data,0)))
    PRED.append(pred.argmax().item())
    
    #break

  0%|          | 0/986 [00:00<?, ?it/s]

In [71]:
from sklearn.metrics import confusion_matrix

In [72]:
cm = confusion_matrix(GT, PRED)

In [73]:
cm

array([[ 70,   0,   5,   0,   1,   0,   0,   0],
       [  1,  27,   3,   0,   4,   0,   2,   0],
       [ 31,   5,  47,   0,  18,   0,   0,   0],
       [ 32,   0, 145,   0,   3,   0,   0,   0],
       [  3,   8,  45,   0, 100,   0,   0,   0],
       [  1,   1,  13,   0,   7,  40,   0,   0],
       [  0,   9,   0,   0,   5,   0,   9,   0],
       [ 21,   5,  24,   0,   4,   3,   2, 292]])