In [2]:
import glob
import os
import pickle
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
import torch.utils
import torchvision as tv
from torchvision import transforms as tf

from fastai.vision import *
from fastai.metrics import error_rate

import matplotlib.pyplot as plt
%reload_ext autoreload
%autoreload 2
# %matplotlib inline

In [3]:
import sys
sys.path.append('../')

In [4]:
from utils_ import *
from io_ import *
from data_ import *
from transforms import *
from models import *
from training import *
from visualizations import *
from monitoring import *
%reload_ext autoreload
%autoreload 2

In [5]:
# Names of the segmentation classes.
# NOTE(review): presumably ordered by integer label value — confirm against the dataset.
CLASS_NAMES = [
    'building',
    'tree',
    'low-vegetation',
    'clutter',
    'car',
    'pavement',
]
# Derived from the list so the class count cannot drift from the names.
NCLASSES = len(CLASS_NAMES)

In [6]:
# S3 bucket and key prefix handed to the IO handler created in the next cell.
S3_BUCKET = 'raster-vision-ahassan'
S3_ROOT = 'potsdam/experiments/output/tmp'

In [7]:
# IO handler configured with a local root directory and the S3 bucket/prefix;
# it is used further down to load model weights.
io_handler_config = dict(local_root='tmp', s3_bucket=S3_BUCKET, s3_root=S3_ROOT)
io_handler = S3IoHandler(**io_handler_config)

# Data

In [None]:
# Location of the pickled Potsdam data dictionary, relative to the notebook directory.
POTSDAM_PKL_PATH = Path('../../potsdam/data/potsdam.pkl')

# NOTE(security): unpickling executes arbitrary code embedded in the file —
# only load pickles produced by a trusted pipeline.
with POTSDAM_PKL_PATH.open('rb') as f:
    potsdam_dict = pickle.load(f)


## Prepare datasets

In [None]:
# Chipping parameters passed to the Potsdam datasets below.
# NOTE(review): units are presumably pixels — confirm in the Potsdam dataset class.
CHIP_SIZE, STRIDE = 400, 200
# Downsampling factor forwarded to tfs_potsdam.
DOWNSAMPLING = 2

In [None]:
# Dataset built with all five channels; its length drives the train/val split sizes below.
# Only the validation transform is used here, but x_transform / y_transform are still bound.
CHANNELS = [ch_R, ch_G, ch_B, ch_IR, ch_E]
potsdam_tfs = tfs_potsdam(channels=CHANNELS, downsampling=DOWNSAMPLING)
_, val_transform, x_transform, y_transform = potsdam_tfs
original_ds = Potsdam(
    potsdam_dict,
    chip_size=CHIP_SIZE,
    stride=STRIDE,
    tf=val_transform,
)

In [None]:
# One validation dataset per channel subset: elevation only, RGB, and RGB + elevation.
channel_subsets = [
    ('e', [ch_E]),
    ('rgb', [ch_R, ch_G, ch_B]),
    ('rgbe', [ch_R, ch_G, ch_B, ch_E]),
]

val_datasets = {}
# CHANNELS and the *_transform names are rebound on every pass on purpose, so that after
# this cell they hold the RGB-E configuration exactly as with the unrolled version.
for subset_name, CHANNELS in channel_subsets:
    _, val_transform, x_transform, y_transform = tfs_potsdam(channels=CHANNELS, downsampling=DOWNSAMPLING)
    val_datasets[subset_name] = Potsdam(
        potsdam_dict,
        chip_size=CHIP_SIZE,
        stride=STRIDE,
        tf=val_transform,
        x_tf=x_transform,
        y_tf=y_transform,
    )

val_ds_e = val_datasets['e']
val_ds_rgb = val_datasets['rgb']
val_ds_rgbe = val_datasets['rgbe']

In [None]:
# Fraction of the dataset assigned to the training split; the remainder is validation.
TRAIN_SPLIT = 0.85

n_samples = len(original_ds)
# int() truncates, which equals flooring for a non-negative product.
train_split_size = int(n_samples * TRAIN_SPLIT)
val_split_size = n_samples - train_split_size

print(f'train_split_size {train_split_size}')
print(f'val_split_size {val_split_size}')

# Index array 0..n_samples-1 that the samplers slice into train / val parts.
inds = np.arange(n_samples)

In [None]:
# The leading train_split_size indices go to training, the rest to validation
# (contiguous split — the index array itself is not shuffled).
train_inds, val_inds = inds[:train_split_size], inds[train_split_size:]
train_sampler = torch.utils.data.SubsetRandomSampler(train_inds)
val_sampler = torch.utils.data.SubsetRandomSampler(val_inds)

# Sanity check: no index may appear in both splits.
assert set(train_sampler.indices).isdisjoint(val_sampler.indices)

In [None]:
# Iterator over the validation sampler's indices (SubsetRandomSampler yields them in random order).
val_iter = iter(val_sampler)

# Models

## RGB-E, merge after backbone

In [None]:
# Build a 3-channel and a 1-channel model with pretrained=True; the 1-channel one only
# serves as the donor of the second backbone.
model_rgbe_bb = get_deeplab_custom(NCLASSES, in_channels=3, pretrained=True)
model_e_tmp = get_deeplab_custom(NCLASSES, in_channels=1, pretrained=True)

# Swap the 3-channel model's backbone for one that wraps both backbones.
# NOTE(review): the meaning of the trailing `4` is defined by DeeplabDoublePartialBackbone — confirm there.
rgb_backbone, e_backbone = model_rgbe_bb.m.backbone, model_e_tmp.m.backbone
model_rgbe_bb.m.backbone = DeeplabDoublePartialBackbone(rgb_backbone, e_backbone, 4)
model_rgbe_bb = model_rgbe_bb.cuda()

In [None]:
# Remove the auxiliary classifier head from the wrapped model.
# NOTE(review): presumably needed so the module matches the saved checkpoint — confirm.
model_rgbe_bb.m.aux_classifier = None

In [None]:
# Experiment name; the next cell interpolates it into both the S3 key and the local target path.
name = 'ss_rgbp_ep_deeplab_resnet101p_merge_after_backbone_partial_4'

In [None]:
# Load the 'best_acc' weights of the experiment named above into the RGB-E model.
weights_s3_path = f'potsdam/experiments/output/{name}/best_model/best_acc'
weights_local_path = f'models/{name}'
io_handler.load_model_weights(model_rgbe_bb, s3_path=weights_s3_path, tgt_path=weights_local_path)

In [None]:
# Visualise the conv weights at head[1][0][0] and head[1][1][0] of the merged backbone.
# NOTE(review): presumably the first conv layer of each input branch — confirm in DeeplabDoublePartialBackbone.
filter_viz_kwargs = dict(show=True, normalize=True, scale_each=True, figsize=(6, 6))
backbone_head = model_rgbe_bb.m.backbone.head
viz_conv_layer_filters(backbone_head[1][0][0].weight.data, **filter_viz_kwargs)
viz_conv_layer_filters(backbone_head[1][1][0].weight.data, **filter_viz_kwargs)


## RGB only

In [None]:
# RGB-only baseline: 3 input channels, created with pretrained=True and moved to the GPU.
model_rgb = get_deeplab_custom(NCLASSES, in_channels=3, pretrained=True).cuda()

In [None]:
# Experiment name of the RGB-only run; note that this rebinds the notebook-level `name`.
name = 'ss_rgb_deeplab_resnet101p'
rgb_weights_s3_path = f'potsdam/experiments/output/{name}/best_model/best_acc'
rgb_weights_local_path = f'models/{name}'
io_handler.load_model_weights(model_rgb, s3_path=rgb_weights_s3_path, tgt_path=rgb_weights_local_path)

# KL divergence

In [None]:
# NOTE(review): this cell reads `batch`, `labels`, `val_idx` and `gbp`, none of which are
# defined in the visible cells above — it will fail on a fresh kernel until the cells that
# create them are restored. A later cell also rebinds `labels` to a list, after which this
# cell cannot be re-run as-is.


def _plot_saliency_pair(title, pos_img, neg_img, cmap=None):
    """Show positive and negative saliency maps side by side under a common title.

    Args:
        title: figure-level title.
        pos_img: array-like image for the 'Positive saliency' panel.
        neg_img: array-like image for the 'Negative saliency' panel.
        cmap: optional matplotlib colormap name forwarded to imshow.
    """
    fig = plt.figure(figsize=(12, 5))
    fig.suptitle(title)
    panels = ((121, pos_img, 'Positive saliency'), (122, neg_img, 'Negative saliency'))
    for position, img, panel_title in panels:
        panel_ax = fig.add_subplot(position)
        panel_ax.imshow(img, cmap=cmap)
        panel_ax.set_title(panel_title)
    # tight_layout only has an effect once the axes exist; afterwards reserve room for the suptitle.
    fig.tight_layout()
    fig.subplots_adjust(top=0.88)
    plt.show()


# --- Overview figure: RGB input, elevation, ground truth and model prediction ---
f = plt.figure(figsize=(20, 6))
ax = f.add_subplot(141)
ax.imshow(batch[0, :3].permute(1, 2, 0))
ax.set_title(f'RGB, {val_idx}')

ax = f.add_subplot(142)
ax.imshow(batch[0, -1].log())
ax.set_title('elevation (log)')

ax = f.add_subplot(143)
ax.imshow(labels)
ax.set_title('Ground Truth')

with torch.no_grad():
    model_rgbe_bb.eval()
    # The name `out_rgb` is kept from the original cell although this is the RGB-E model's output.
    out_rgb = model_rgbe_bb(batch.cuda())
ax = f.add_subplot(144)
# The prediction lives on the GPU; matplotlib cannot convert CUDA tensors, so move it to the CPU first.
ax.imshow(out_rgb.permute(0, 2, 3, 1).argmax(dim=-1).squeeze().cpu())
ax.set_title('Model Prediction')
plt.show()

# --- Saliency maps from the recorded gradients, split by sign and scaled by their maximum ---
g = gbp.gradients[0].permute(0, 2, 3, 1).squeeze().detach().cpu()
gpos = g.clamp(min=0) / g.max()
gneg = (-g).clamp(min=0) / (-g).max()

# Last channel is shown as elevation, the first three as RGB (matching the overview figure).
_plot_saliency_pair('Elevation', gpos[:, :, -1], gneg[:, :, -1], cmap='hot')
# The negative panel must show gneg; the original cell plotted gpos twice.
_plot_saliency_pair('RGB', gpos[:, :, :3], gneg[:, :, :3])

29

In [None]:
# Accumulators filled by the inference loop in the next cell.
batches_rgb, batches_rgbe, labels = [], [], []
# NOTE(review): h and w are read from the dataset; h*w is used below as the number of
# trailing samples to evaluate — presumably the chip grid of one tile, confirm in Potsdam.
h, w = val_ds_rgb.h, val_ds_rgb.w

In [None]:
# Start from empty accumulators so that re-running this cell does not append a second copy
# of every prediction (the cell was previously not idempotent).
batches_rgb, batches_rgbe, labels = [], [], []

model_rgb.eval()
model_rgbe_bb.eval()

with torch.no_grad():
    # Every second sample out of the last h*w entries of the RGB-E validation dataset.
    for batch, label in val_ds_rgbe[len(val_ds_rgbe) - h*w :: 2]:
        batch = batch.unsqueeze(0).cuda()

        # The RGB model only sees the first three channels; the RGB-E model sees all of them.
        out1 = model_rgb(batch[:, :3, :, :])
        out2 = model_rgbe_bb(batch)
        # Channels-last layout, moved to the CPU so the logits of all samples
        # do not pile up in GPU memory.
        pred1 = out1.permute(0, 2, 3, 1).cpu()
        pred2 = out2.permute(0, 2, 3, 1).cpu()
        batches_rgb.append(pred1)
        batches_rgbe.append(pred2)
        labels.append(label)


In [None]:
# Flatten the per-sample outputs to one row of NCLASSES logits per pixel.
preds_rgb, preds_rgbe = (
    torch.cat(chunks, dim=0).view(-1, NCLASSES) for chunks in (batches_rgb, batches_rgbe)
)
# Stack the per-sample label maps along a new leading axis, then flatten to one label per pixel.
labels_flat = torch.stack(labels, dim=0).view(-1)

# Class probabilities and log-probabilities per pixel, kept on the CPU.
probs_rgb, probs_rgbe = (F.softmax(logits, dim=-1).cpu() for logits in (preds_rgb, preds_rgbe))
logprobs_rgb, logprobs_rgbe = (F.log_softmax(logits, dim=-1).cpu() for logits in (preds_rgb, preds_rgbe))

In [45]:
# F.kl_div(input, target) takes log-probabilities as `input` and probabilities as `target`
# and evaluates target * (log(target) - input).
# NOTE: only column i is passed in, so each printed number is the class-i term averaged
# over the pixels labelled i — not a KL divergence summed over all classes.
for i in range(NCLASSES):
    mask = (labels_flat == i).view(-1)
    term_rgbe_vs_rgb = F.kl_div(logprobs_rgb[mask, i], probs_rgbe[mask, i], reduction='batchmean')
    term_rgb_vs_rgbe = F.kl_div(logprobs_rgbe[mask, i], probs_rgb[mask, i], reduction='batchmean')
    print(i, term_rgbe_vs_rgb.item(), term_rgb_vs_rgbe.item())


0 0.05024977773427963 0.017487837001681328
1 0.08461654186248779 0.05247887596487999
2 0.045692700892686844 0.08882671594619751
3 0.05830514058470726 0.04818468913435936
4 0.022100672125816345 0.010693208314478397
5 0.041541676968336105 0.009151924401521683


In [53]:
# Manual cross-check of the F.kl_div cell: mean of P * log(P / Q) over the pixels labelled i,
# using only the class-i probability of each model, in both directions.
for i in range(NCLASSES):
    mask = (labels_flat == i).view(-1)

    print(i)
    p_rgbe, p_rgb = probs_rgbe[mask, i], probs_rgb[mask, i]
    directions = (
        ('P: rgbe, Q: rgb', p_rgbe, p_rgb),
        ('P: rgb, Q: rgbe', p_rgb, p_rgbe),
    )
    for tag, P, Q in directions:
        print(tag, (P * (P / Q).log()).mean().item())



0
P: rgbe, Q: rgb 0.05024978891015053
P: rgb, Q: rgbe 0.017487844452261925
1
P: rgbe, Q: rgb 0.08461654186248779
P: rgb, Q: rgbe 0.05247887223958969
2
P: rgbe, Q: rgb 0.04569270834326744
P: rgb, Q: rgbe 0.0888267308473587
3
P: rgbe, Q: rgb 0.05830514058470726
P: rgb, Q: rgbe 0.04818468913435936
4
P: rgbe, Q: rgb 0.02210068143904209
P: rgb, Q: rgbe 0.010693217627704144
5
P: rgbe, Q: rgb 0.041541688144207
P: rgb, Q: rgbe 0.009151934646070004


In [1]:
# Compare each model's class-i log-probabilities against the 0/1 indicator of label i,
# evaluated over all pixels.
# NOTE: the recorded NameError comes from running this cell in a fresh kernel (In [1])
# before NCLASSES was defined; it needs the earlier cells to be executed first.
for i in range(NCLASSES):
    mask = (labels_flat == i).float()
    kl_rgb, kl_rgbe = (
        F.kl_div(logprobs[:, i], mask, reduction='batchmean').item()
        for logprobs in (logprobs_rgb, logprobs_rgbe)
    )
    print(i, kl_rgb, kl_rgbe)


NameError: name 'NCLASSES' is not defined