In [23]:
%%capture
%load_ext autoreload
%autoreload 2
#Basic Imports
import os,sys
os.chdir('/home/asebaq/MSMatch')

from tqdm import tqdm,trange
import numpy as np
import matplotlib.pyplot as plt
import torch
import pandas

from datasets.ssl_dataset import SSL_Dataset
from datasets.data_utils import get_data_loader
from utils import get_model_checkpoints
from utils import net_builder
from utils import clean_results_df

In [24]:
# Load necessary visualization code
# Original code from https://github.com/utkuozbulak/pytorch-cnn-visualizations
# slightly modified it to fit our needs
from external.visualizations.guided_backprop import GuidedBackprop
from external.visualizations.misc_functions import convert_to_grayscale,get_positive_negative_saliency
from external.visualizations.smooth_grad import generate_smooth_grad

In [25]:
#Path to the runs to load
folder = "/home/asebaq/MSMatch/trained_models/rgb/saved_models_rgb/fixmatch/eurosat_rgb/FixMatch_archefficientnet-b0_batch16_confidence0.95_lr0.03_uratio7_wd0.0005_wu1.0_seed0_numlabels4000_optSGD" #Path to the runs to load 

sort_criterion = "numlabels" # Accepted net, numlabels
seed_wanted = 0 # Seed wanted (the others will be filtered)

## Initialize parameters

In [26]:
checkpoints, run_args = get_model_checkpoints(folder)
if os.name == 'nt':
       [print(_.split("\\")[1]) for _ in checkpoints];
else:
       [print(_.split("/")[1]) for _ in checkpoints];

home


In [27]:
print(checkpoints, run_args)

['/home/asebaq/MSMatch/trained_models/rgb/saved_models_rgb/fixmatch/eurosat_rgb/FixMatch_archefficientnet-b0_batch16_confidence0.95_lr0.03_uratio7_wd0.0005_wu1.0_seed0_numlabels4000_optSGD/model_best.pth'] [{'dataset': 'eurosat_rgb', 'net': 'efficientnet-b0', 'batch': 16, 'confidence': 0.95, 'lr': 0.03, 'uratio': 7, 'wd': 0.0005, 'wu': 1.0, 'seed': 0, 'numlabels': 4000, 'opt': 'SGD', 'iterations': 40000}]


## Run all models

In [28]:
_eval_dset = SSL_Dataset("eurosat_rgb", train=False,  data_dir="/home/asebaq/MSMatch/data/", seed=seed_wanted)
eval_dset = _eval_dset.get_dset()

100%|███████████████████████████████████████████████████████████████| 11/11 [00:11<00:00,  1.02s/it]


In [29]:

saliency = {} #will contain saliency maps for all runs
correct_prediction = {} # will contain predictions
numbers_to_skip  = []
N = 2700 # how many images should be looked at

#Start with num labels = 50, will break if smallest not 50!
runs = list(zip(checkpoints,run_args))
runs.reverse()

#Iterate over runs
for path, args in runs:
    print("------------ RUNNING ", path, " -----------------")
    print(args)
    args["data_dir"] = "/home/asebaq/MSMatch/data/"
    args["use_train_model"] = False
    args["load_path"] = path
    saliency[args["numlabels"]] = []
    correct_prediction[args["numlabels"]] = []
    
    if args["seed"] != seed_wanted:
        continue
    
    # Load the model and dataset
    checkpoint_path = os.path.join(args["load_path"])
    checkpoint = torch.load(checkpoint_path,map_location='cuda:0')
    load_model = (checkpoint["train_model"] if args["use_train_model"] else checkpoint["eval_model"])
    _net_builder = net_builder(args["net"],False,{})
    
    net = _net_builder(num_classes=_eval_dset.num_classes, in_channels=_eval_dset.num_channels)
    net.load_state_dict(load_model)
    
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    eval_loader = get_data_loader(eval_dset, 1, num_workers=1) #note batchsize is manually set to 1 here
    label_encoding = _eval_dset.label_encoding
    inv_transf = _eval_dset.inv_transform
    
    # Init saliency computation algorithm
    cam = GuidedBackprop(net)
    
    idx = 0 #current image index
    image_original = [] # to store original images
    
    
    for image, target in tqdm(eval_loader):
        image = image.type(torch.FloatTensor).cuda()
        
        # Check prediction
        logit = net(image)
        correct = logit.cpu().max(1)[1].eq(target).sum().numpy()
                        
        # Check if correct result for num_labels 50
        if args["numlabels"] == 50:
            if correct:
                numbers_to_skip.append(idx)
                idx = idx + 1
                continue
            else:
                idx = idx + 1
                
        if idx in numbers_to_skip:
            idx = idx + 1
            continue
        else:
            idx = idx + 1
            
        correct_prediction[args["numlabels"]].append(correct)
        
        image_original.append(inv_transf(image[0].transpose(0,2).cpu().numpy()).transpose(0,2).numpy())
        

        # Use smooth grad by sampling the gradients with some noise added to image to get a smoother output
        param_n = 100 #nr of images to sample
        param_sigma_multiplier = 2 #noise strength
        result = generate_smooth_grad(cam,  # ^This parameter
                                           image,
                                           target,
                                           param_n,
                                           param_sigma_multiplier)

        result = result[:,0:64,0:64] #some padding happens in the network, we discard
        result = convert_to_grayscale(result)
        result, _ = get_positive_negative_saliency(result) #we only use positive saliency maps
        saliency[args["numlabels"]].append(result[0])

        if idx > N:
            break

------------ RUNNING  /home/asebaq/MSMatch/trained_models/rgb/saved_models_rgb/fixmatch/eurosat_rgb/FixMatch_archefficientnet-b0_batch16_confidence0.95_lr0.03_uratio7_wd0.0005_wu1.0_seed0_numlabels4000_optSGD/model_best.pth  -----------------
{'dataset': 'eurosat_rgb', 'net': 'efficientnet-b0', 'batch': 16, 'confidence': 0.95, 'lr': 0.03, 'uratio': 7, 'wd': 0.0005, 'wu': 1.0, 'seed': 0, 'numlabels': 4000, 'opt': 'SGD', 'iterations': 40000}
Using not pretrained model efficientnet-b0 ...


  0%|                                                                      | 0/2700 [00:00<?, ?it/s]


ValueError: operands could not be broadcast together with shapes (66,66) (3,65,65) 

In [None]:
#Save results
np.save("saliency.npy",saliency)
np.save("image_original.npy",image_original)
np.save("correct_prediction.npy",correct_prediction)

In [None]:
#Load results
saliency = np.load("saliency.npy")
image_original = np.load("image_original.npy")
correct_prediction = np.load("correct_prediction.npy")

In [None]:
def plot_examples(images,saliency,numlabels=[50,100,500,1000,2000,3000],indices=[2]):
    """ Small function to plot the results
    """
    fig = plt.figure(figsize=(15 * len(numlabels) / 6, 1.5*len(indices)), dpi=300)
    offset = len(numlabels) + 1
    images = np.asarray(images)
    for plot_nr,idx in enumerate(indices):
        ax = fig.add_subplot(len(indices), offset, offset*plot_nr+1, xticks=[], yticks=[])
        img = images[idx]
        if np.max(img) > 1.5:
            img = img / 255
        plt.imshow(img)

        for nl_idx,nl in enumerate(numlabels):
            ax = fig.add_subplot(len(indices), offset, offset*plot_nr+2+nl_idx, xticks=[], yticks=[])
            sal = np.flipud(saliency[nl][idx])
            plt.contourf(sal,cmap="gnuplot2")
    plt.savefig("saliency.png")

In [None]:
indices_to_plot = np.arange(110,120)
# print(correct_prediction[3000][20:20])
plot_examples(image_original,saliency,numlabels=[50,3000],indices=indices_to_plot)