# Model validation notebook

## 1. - Settings and imports

Select the GPUs visible to this notebook by setting `CUDA_DEVICE_ORDER` and `CUDA_VISIBLE_DEVICES`.

In [None]:
import os

# Use PCI bus ordering for device ids and expose only devices 4 and 5 to this process.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # see issue #152
        "CUDA_VISIBLE_DEVICES": "4,5",
    }
)

Imports.

In [None]:
%%capture
%load_ext autoreload
%autoreload 2
#Basic Imports

os.chdir("..")
os
import sys
sys.path.insert(1, "..")
from utils import decode_parameters_from_path

from tqdm import tqdm,trange
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report,confusion_matrix, ConfusionMatrixDisplay
import torch
import pandas as pd

from datasets.ssl_dataset import SSL_Dataset
from datasets.data_utils import get_data_loader
from utils import get_model_checkpoints
from utils import net_builder
import random
from utils import clean_results_df


Dictionary mapping each dataset name to the list of its class names.

In [None]:
# Class names shared by both EuroSAT variants (RGB and multispectral).
_eurosat_classes = (
    'AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial',
    'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake',
)

# Dataset name -> list of class names.
class_names_dict = {
    # Each EuroSAT entry gets its own list object, so editing one does not affect the other.
    'eurosat_rgb': list(_eurosat_classes),
    'eurosat_ms': list(_eurosat_classes),
    'ucm': [
        "agricultural", "airplane", "baseballdiamond", "beach", "buildings",
        "chaparral", "denseresidential", "forest", "freeway", "golfcourse",
        "harbor", "intersection", "mediumresidential", "mobilehomepark",
        "overpass", "parkinglot", "river", "runway", "sparseresidential",
        "storagetanks", "tenniscourt",
    ],
    'thraws_swir': ['event', 'notevent'],
}

Set the checkpoint directory (the folder that contains `model_best.pth`).

In [None]:
# Checkpoint directory of the run to validate; the run parameters are decoded
# from this path in the next section.
checkpoint_path = (
    "/home/gabrielemeoni/project/end2end/END2END/MSMatch/checkpoints/iter1/thraws_swir/"
    "FixMatch_archefficientnet-b0_batch16_confidence0.95_lr0.03_uratio4_wd0.00075_wu1.0_seed1_numlabels600_optSGD"
)

Folder for exported CSV files containing results.

In [None]:
# "." refers to the current working directory.
csv_folder = "."

## 2. - Parse checkpoint file and run the model

In [None]:
# Classification reports collected by this notebook (one dict per evaluated checkpoint).
results = []

# Single device used for loading the checkpoint, hosting the network and
# running inference: the first visible GPU when available, the CPU otherwise.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# os.path.join(path, "") appends a trailing separator before the path is decoded.
args = decode_parameters_from_path(os.path.join(checkpoint_path,""))
print("------------ RUNNING ", checkpoint_path, " -----------------")
print(args)
# Evaluation-time overrides of the decoded parameters.
args["batch_size"] = 256
args["data_dir"] = "./data/"
args["use_train_model"] = False
args["load_path"] = checkpoint_path

checkpoint_model_path = os.path.join(checkpoint_path, "model_best.pth")
checkpoint = torch.load(checkpoint_model_path, map_location=device)

# The checkpoint stores two sets of weights; pick the one requested in args.
load_model = (checkpoint["train_model"] if args["use_train_model"] else checkpoint["eval_model"])
_net_builder = net_builder(args["net"],False,{})
_eval_dset = SSL_Dataset(name=args["dataset"], train=False, data_dir=args["data_dir"], seed=args["seed"])
eval_dset = _eval_dset.get_dset()
net = _net_builder(num_classes=_eval_dset.num_classes, in_channels=_eval_dset.num_channels)
net.load_state_dict(load_model)
net.to(device)
net.eval()

eval_loader = get_data_loader(eval_dset, args["batch_size"], num_workers=1)
label_encoding = _eval_dset.label_encoding
inv_transf = _eval_dset.inv_transform


print("------------ PREDICTING TESTSET -----------------")

# images: inverse-transformed inputs kept for the qualitative plots further down;
# labels / preds: ground-truth and predicted class indices for the whole test set.
images, labels, preds = [],[],[]
with torch.no_grad():
    for image, target in tqdm(eval_loader):
        # Move the batch to the same device as the network. The original code
        # called .cuda() unconditionally and crashed on CPU-only machines.
        image = image.type(torch.FloatTensor).to(device)
        logit = net(image)
        for img in image:
            # Axes 0 and 2 are swapped before the inverse transform and swapped back after it.
            images.append(inv_transf(img.transpose(0,2).cpu().numpy()).transpose(0,2).numpy())
        # Index of the largest logit along dim 1 = predicted class.
        preds.append(logit.cpu().max(1)[1])
        labels.append(target)
labels = torch.cat(labels).numpy()
preds = torch.cat(preds).numpy()
test_report = classification_report(labels, preds, target_names=label_encoding, output_dict=True)
# Keep the run parameters next to the metrics so later cells can tabulate them.
test_report["params"] = args
results.append(test_report)

In [None]:
pd.set_option('display.max_columns', None)

# Build one frame per classification report and concatenate them once at the
# end. DataFrame.append was removed in pandas 2.0 and was quadratic in a loop.
result_frames = []
for result in results:
    params = result["params"]
    df = pd.DataFrame(result)
    # Drop the rows that come from the parameter names, then the metrics that
    # are not reported (what remains of the report is its f1-score row).
    df = df.drop(list(params.keys()))
    df = df.drop(["support","recall","precision"])
    # Attach every run parameter as a column.
    for key,val in params.items():
        df[key] = val
    result_frames.append(df.set_index("dataset"))

# pd.concat raises on an empty list, so keep the original empty-frame behaviour.
big_df = pd.concat(result_frames) if result_frames else pd.DataFrame()
# print(big_df)
small_df = clean_results_df(big_df, ".","numlabels", keep_per_class=True)
# Write the CSV inside csv_folder. Plain string concatenation produced
# "._test_results.csv" for csv_folder="." instead of a file in that folder.
small_df.to_csv(os.path.join(csv_folder, "_test_results.csv"))

In [None]:
# Columns that are not part of the per-class table.
columns_to_discard = [
    "pretrained", "supervised", "net", "accuracy", "batch", "confidence",
    "lr", "uratio", "wd", "wu", "opt", "iterations", "load_path",
]

# Average the remaining columns over all rows sharing the same number of labels.
small_df = (
    small_df
    .drop(labels=columns_to_discard, axis=1)
    .groupby('numlabels')
    .mean()
    .reset_index()
)
# Order the columns alphabetically, then expose the row index as a column.
small_df = small_df.reindex(sorted(small_df.columns), axis=1).reset_index()

Append the number of labels per class to each `numlabels` entry, in the form `total (per class)`.

In [None]:
# Number of classes of the evaluated dataset.
n_classes = len(class_names_dict[args['dataset']])

# Rewrite the whole column in one assignment. The original element-wise
# small_df["numlabels"][n] = ... is chained indexing: it triggers
# SettingWithCopyWarning and has no effect under pandas copy-on-write.
# NOTE: run this cell once; on already formatted strings the integer division fails.
small_df["numlabels"] = [
    str(n_labels) + " (" + str(n_labels // n_classes) + ")"
    for n_labels in small_df["numlabels"]
]

In [None]:
# Reshape to long format: one row per (numlabels, class) pair holding that class's score.
l = small_df.melt(
    id_vars='numlabels',
    value_vars=class_names_dict[args['dataset']],
    var_name="Class",
    value_name="F1 Score",
).rename(columns={'numlabels': "# of labels \n(per class)"})

## 3. - Visualize results

Show the table of per-class F1 scores, averaged per number of labels.

In [None]:
# Bare last expression: the notebook renders the aggregated results table as the cell output.
small_df

Plot the per-class F1 scores as a horizontal bar chart and save the figure to `class_f1.pdf`.

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

sns.set()
sns.set(font_scale=3)  # very large fonts for the exported figure

with sns.plotting_context():
    # One horizontal bar per class, coloured by the number of labels.
    p = sns.catplot(
        data=l,
        kind="bar",
        x="F1 Score",
        y="Class",
        hue="# of labels \n(per class)",
        palette="crest",
        height=10,
        aspect=1.25,
    )
    # Restrict the score axis to the range 0.3 - 1.01.
    p.set(xlim=[0.3,1.01])

# Saves the current pyplot figure.
plt.savefig("class_f1.pdf")

In [None]:
# Short display names for class indices 0 and 1.
# NOTE(review): presumably "E" = event and "NE" = notevent (thraws_swir) - confirm.
# Any other class index falls back to being shown as its number.
label_enc_dict={0:"E", 1: "NE"}

# Grid layout; draw at most one random test sample per cell so that a test
# set smaller than the grid does not raise.
n_rows, n_cols = 3, 3
n_shown = min(n_rows * n_cols, len(preds))
idxs = random.sample(range(len(preds)), n_shown)

images_to_plot = [images[idx] for idx in idxs]
preds_to_plot = [preds[idx] for idx in idxs]
labels_to_plot = [labels[idx] for idx in idxs]

fig, ax = plt.subplots(n_rows, n_cols)
plt.subplots_adjust(left=0.1,
                    bottom=0.1,
                    right=0.9,
                    top=0.9,
                    wspace=0.4,
                    hspace=0.4)

# Hide the axes of every cell, including cells left empty.
for axis in ax.flat:
    axis.axis('off')

# ax.flat walks the grid row by row, the same order as the original nested loops.
for axis, img, gt, pr in zip(ax.flat, images_to_plot, labels_to_plot, preds_to_plot):
    # The original code called imshow twice per cell; once is enough.
    axis.imshow(img)
    axis.set_title(
        "GT:" + str(label_enc_dict.get(gt, gt)) + "\nPR:" + str(label_enc_dict.get(pr, pr)),
        fontsize=10,
    )