<a href="https://colab.research.google.com/github/Ryan-Qiyu-Jiang/KidneyStones/blob/master/abcd_experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the rloss fork and switch to the `color_query` branch, which presumably
# provides the `color_query` module imported further down — confirm.
!git clone https://github.com/Ryan-Qiyu-Jiang/rloss.git
%cd /content/rloss/
!git checkout color_query

In [None]:
# Download PASCAL VOC 2012 and the PASCAL scribble annotations using the
# fetch scripts shipped in the repository.
%cd /content/rloss/data/VOC2012/
!./fetchVOC2012.sh
%cd /content/rloss/data/pascal_scribble/
! ./fetchPascalScribble.sh

In [None]:
# Build and install the SWIG-wrapped C++ bilateral filter extension.
# NOTE(review): presumably required by DenseCRFLoss (imported later) — confirm.
%cd /content/rloss/pytorch/wrapper/bilateralfilter
!apt-get install swig -y
!swig -python -c++ bilateralfilter.i
!python setup.py install

In [None]:
!pip install -qqq tensorboardX
!pip install -qqq wandb
!pip install -qqq pytorch-lightning
# !pip install -qqq kornia
!pip install -qqq higher

In [None]:
# Work from the DeepLab v3+ directory so its local packages (mypath,
# dataloaders, modeling, utils) are importable in the next cell.
%cd /content/rloss/pytorch/pytorch_deeplab_v3_plus 

In [None]:
import argparse
import os
import numpy as np
from tqdm import tqdm

from mypath import Path
from dataloaders import make_data_loader
from dataloaders.custom_transforms import denormalizeimage
from modeling.sync_batchnorm.replicate import patch_replication_callback
from modeling.deeplab import *
from utils.loss import SegmentationLosses
from utils.calculate_weights import calculate_weigths_labels
from utils.lr_scheduler import LR_Scheduler
from utils.saver import Saver
from utils.summaries import TensorboardSummary
from utils.metrics import Evaluator

import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.nn import functional as F
import pytorch_lightning as pl
from torch import nn
import torch
from argparse import Namespace
from dataloaders.utils import decode_seg_map_sequence

from color_query import BaseModel, get_args, SingleDataset, RepeatDataset, DeepLabEncoder, BaseColorModel

# PASCAL VOC class names; the list index is the label id (0 = background).
segmentation_classes = ['background'] + (
    'aeroplane bicycle bird boat bottle bus car cat chair cow diningtable '
    'dog horse motorbike person pottedplant sheep sofa train tvmonitor'
).split()


def labels():
  """Return a new ``{class_id: class_name}`` dict (used as wandb mask class labels)."""
  return dict(enumerate(segmentation_classes))

def wb_mask(bg_img, pred_mask, true_mask):
  """Wrap an image with predicted and ground-truth segmentation overlays for wandb.

  Args:
    bg_img: background image (callers in this notebook pass an H x W x C numpy array).
    pred_mask: predicted class-id mask.
    true_mask: ground-truth class-id mask.

  Returns:
    ``wandb.Image`` with a "prediction" and a "ground truth" mask layer, each
    labelled with ``labels()``.
  """
  def layer(mask_data):
    return {"mask_data": mask_data, "class_labels": labels()}

  overlays = {"prediction": layer(pred_mask)}
  overlays["ground truth"] = layer(true_mask)
  return wandb.Image(bg_img, masks=overlays)

In [None]:
from argparse import Namespace
from dataloaders import make_data_loader

# Start from the project defaults and override what this experiment needs.
args_dict = get_args()
args_dict.update(
    cuda=True,
    checkname='ignore',
    epochs=1,
    shuffle=False,        # True for real training
    batch_size=10,
    lr=1e-2,
    inner_lr=1e-2,
    test_steps=50,
    full_gt=True,         # True for gt
    limit_dataset=False,
    # rloss_scale=1,
)
args = Namespace(**args_dict)

kwargs = {'num_workers': 6, 'pin_memory': True}
train_loader, val_loader, test_loader, nclass = make_data_loader(args, **kwargs)
gt_loader = val_loader

# Same configuration but with full_gt=False, i.e. presumably the sparse
# scribble ("seed") labels instead of dense ground truth — confirm.
seeds_args = Namespace(**{**args_dict, 'full_gt': False})
seeds_loader, _, _, _ = make_data_loader(seeds_args, **kwargs)

In [None]:
import torch
import matplotlib.pyplot as plt
from torch.utils.data import random_split, DataLoader

# Grab one fixed batch of seed-labelled samples; small repeat-datasets are
# built from it below for an overfitting sanity run.
# `next(iter(...))` instead of `iter(...).next()`: recent PyTorch DataLoader
# iterators no longer provide the Python-2 style `.next()` method.
batch = next(iter(seeds_loader))
batch_sample = dict(batch)  # shallow copy of the collated batch dict

# w = 513
# circle = torch.zeros((3, w, w))
# gt = torch.zeros((w,w)).long()
# gt[:,:] = 0 #255

# for y in range(w):
#   for x in range(w):
#     if ( (y-w/2)**2 + (x-w/2)**2 ) < (w//3)**2:
#       circle[:,y,x] = 1.
#       # gt[y,x] = 1

# # gt[:40, :40] = 0
# # gt[w//2-25:w//2+25, w//2-25:w//2+25] = 1
# gt[:w//2,:] = 1
# circle.to(batch['image'].device)
# gt.to(batch['label'].device)

i = 5
# 0 plane
# 5 sheep
single_sample = {
    'image': [batch['image'][i]] * len(batch['image']),
    'label': [batch['label'][i]] * len(batch['label'])
}
# batch['image'][i]
# batch['label'][i]

# batch_sample = single_sample # for single image overfit
bs = 5
single_dataset = RepeatDataset(batch_sample,  100*10)
single_train_loader = DataLoader(single_dataset, batch_size=bs, shuffle=True, num_workers=4)
single_dataset = RepeatDataset(batch_sample, 10)
single_val_loader = DataLoader(single_dataset, batch_size=bs, shuffle=False, num_workers=4)

# NOTE(review): images are presumably mean/std-normalized by the loader, so
# imshow clips the displayed colours.
plt.imshow(batch_sample['image'][0].numpy().transpose(1,2,0));

In [None]:
def elementwise_entropy(logits=None, probs=None):
  """Shannon entropy (in nats) of the class distribution at every position.

  Classes live on dim 1; any number of trailing dims is supported, e.g.
  (bs, K, H, W) -> (bs, H, W) or (N, K) -> (N,).

  Args:
    logits: unnormalized class scores; softmax is taken over dim 1.
      Takes precedence over ``probs`` when both are given.
    probs: class probabilities on dim 1; used only when ``logits`` is None.

  Returns:
    Tensor with the input's shape minus dim 1.

  Raises:
    AssertionError: if both ``logits`` and ``probs`` are None.
  """
  assert not (logits is None and probs is None), "entropy() is missing input"
  eps = 1e-16  # keeps log() finite where a probability is exactly 0
  if logits is not None:
    probs = nn.functional.softmax(logits, dim=1)
  return -(probs * (probs + eps).log()).sum(dim=1)

def visualize_entropy(logits, return_img=False):
  """Render the per-pixel entropy map of the first example in ``logits``.

  Args:
    logits: (bs, num_classes, H, W) class scores; only index 0 is used.
    return_img: when True, return the ``wandb.Image`` instead of logging it.

  Returns:
    The ``wandb.Image`` if ``return_img`` is True; otherwise None, after
    logging it under 'train/entropy_map' without committing the wandb step.
  """
  first = logits[:1]  # keep the batch dim: (1, num_classes, H, W)
  entropy_map = elementwise_entropy(first).detach().cpu()[0]
  image = wandb.Image(entropy_map, caption=f"mean entropy={entropy_map.mean():9.4f}")
  if not return_img:
    wandb.log({'train/entropy_map': image}, commit=False)
    return None
  return image

def visualize_c_entropy(logits, x):
  """Log the conditional-entropy map of the first example to wandb (uncommitted).

  Uses the per-pixel loop in ``compute_conditional_entropy``, which is
  quadratic in the number of pixels — slow on full-resolution inputs.

  Args:
    logits: (bs, num_classes, H, W) class scores; only index 0 is used.
    x: (bs, C, H, W) image; only index 0 is used.
  """
  probs = nn.Softmax(dim=1)(logits[:1])
  c_entropy = compute_conditional_entropy(probs, x[:1]).detach().cpu()[0]
  caption = f"mean conditional entropy={c_entropy.mean():9.4f}"
  wandb.log({'train/conditional_entropy_map': wandb.Image(c_entropy, caption=caption)}, commit=False)

def visualize_local_mi():
  """Log the most recent per-patch local-MI map (first example) to wandb.

  Reads the module-level ``last_patch_mi`` written by ``MI_efficient``.
  Does nothing if no map has been stored yet. Previously, a call before the
  first ``MI_efficient`` raised TypeError on ``None[0]``.
  """
  if last_patch_mi is None:
    return
  wb_img = wandb.Image(last_patch_mi[0], caption=f"mean local MI={last_patch_mi.mean():9.4f}")
  wandb.log({'train/local_mi_map': wb_img}, commit=False)

def visualize_global_mi():
  """Log the most recent coarse global map stored by ``MI(..., return_map=True)``.

  The stored map (``last_global_mi``) is the per-pixel conditional entropy
  H(Y|X), not MI, so the caption names it as such. The wandb key
  'train/global_mi_map' is kept for dashboard continuity. Does nothing if no
  map has been stored yet.
  """
  if last_global_mi is None:
    return
  wb_img = wandb.Image(last_global_mi[0], caption=f"mean global H(Y|X)={last_global_mi.mean():9.4f}")
  wandb.log({'train/global_mi_map': wb_img}, commit=False)

def yuv(x):
  """Convert a batch of images from RGB to YUV (analog YUV, BT.601 luma weights).

  Args:
    x: (batch, 3, H, W) tensor; channel order presumably RGB.

  Returns:
    Tensor of the same shape with channels (Y, U, V), in ``x``'s dtype
    (float32 if ``x`` is not floating point).
  """
  dtype = x.dtype if x.is_floating_point() else torch.float32
  T = torch.tensor([[ 0.299,    0.587,    0.114  ],
                    [-0.14713, -0.28886,  0.436  ],
                    [ 0.615,   -0.51499, -0.10001]],
                   device=x.device, dtype=dtype)
  # out[b, d] = sum_c T[d, c] * x[b, c] -- a full 3x3 colour transform.
  # (Repeating the subscript on T, as in "cc,bchw->bchw", makes einsum use only
  # T's diagonal, which just rescales each channel independently.)
  return torch.einsum("dc,bchw->bdhw", T, x.to(dtype))

def mutual_info(logits, x, r=4):
  """MI(X,Y) = H(Y) - H(Y|X) = entropy(average(p_y)) - average(entropy(p_y|x)).

  Both inputs are first downsampled by ``r``, because the conditional term
  (``compute_conditional_entropy2``) builds an all-pairs attention matrix
  whose size is quadratic in the pixel count.

  Args:
    logits: (bs, num_classes, H, W) class scores.
    x: (bs, C, H, W) image the labels are conditioned on.
    r: integer downsampling factor (default 4). The target size is clamped
      to at least 1x1.

  Returns:
    Scalar tensor: the batch mean of H(Y) - H(Y|X).
  """
  eps = 1e-16
  size = (max(1, x.size(2) // r), max(1, x.size(3) // r))
  logits = F.interpolate(logits, size=size, mode='bilinear', align_corners=True)
  x = F.interpolate(x, size=size, mode='bilinear', align_corners=True)

  bs, num_classes = logits.shape[:2]
  probs = nn.Softmax(dim=1)(logits)

  # H(Y): entropy of the class distribution averaged over all pixels.
  p_avg = probs.reshape(bs, num_classes, -1).mean(dim=-1)
  HY = -(p_avg * (p_avg + eps).log()).sum(dim=-1)
  # H(Y|X): per-pixel attention-smoothed entropy, averaged over pixels.
  HYX = compute_conditional_entropy2(probs, x).reshape(bs, -1).mean(dim=-1)
  return (HY - HYX).mean()

def compute_conditional_entropy(probs, x):
  """Per-pixel conditional entropy H(Y | x_p), one query pixel at a time.

  For each pixel p, an attention distribution over all pixels is built from
  the negative L1 distance between YUV colours (mean-centred, scaled by 100,
  softmaxed). The entropy of the attention-weighted class distribution is
  stored for p. Memory is linear in the pixel count, but there is one pass
  over all pixels per pixel, so time is quadratic.

  Args:
    probs: (bs, num_classes, H, W) class probabilities.
    x: (bs, C, H, W) image, converted with ``yuv``.

  Returns:
    (bs, H, W) tensor of conditional entropies.
  """
  eps = 1e-16
  bs, num_classes, h, w = probs.shape
  _, n_channels, _, _ = x.shape

  flat_probs = probs.reshape(bs, num_classes, -1)              # (bs, K, N)
  colours = yuv(x).reshape(bs, n_channels, -1)                 # (bs, C, N)

  entropies = torch.zeros((bs, h * w), device=colours.device)
  for p in range(colours.size(-1)):
    query = colours[:, :, p].unsqueeze(-1)                     # (bs, C, 1)
    scores = -(colours - query).abs().sum(1)                   # (bs, N)
    scores = (scores - scores.mean(dim=-1, keepdim=True)) * 100
    weights = nn.functional.softmax(scores, dim=-1)
    attended = torch.einsum("bcn, bn -> bc", flat_probs, weights)
    entropies[:, p] = -(attended * (attended + eps).log()).sum(-1)
  return entropies.reshape(bs, h, w)

def compute_conditional_entropy2(probs, x):
  """Per-pixel conditional entropy H(Y | x_p) via dense all-pairs attention.

  Attention between every pair of pixels is the softmax of the negative
  Euclidean distance between their YUV colours, mean-centred and scaled by 5.
  Each pixel's class distribution is replaced by the attention-weighted
  average, and its entropy is returned. Time and memory are quadratic in the
  pixel count: an (N x N) matrix per sample.

  Args:
    probs: (bs, num_classes, H, W) class probabilities.
    x: (bs, C, H, W) image, converted with ``yuv``.

  Returns:
    (bs, H, W) tensor of conditional entropies.
  """
  eps = 1e-16
  bs, num_classes, h, w = probs.shape
  _, n_channels, _, _ = x.shape

  flat_probs = probs.reshape(bs, num_classes, -1)                       # (bs, K, N)
  pixels = yuv(x).reshape(bs, n_channels, -1).permute(0, 2, 1)          # (bs, N, C)

  # In-place centring/scaling avoids extra (bs, N, N) allocations.
  scores = torch.cdist(pixels, pixels, p=2).neg()
  scores -= scores.mean(dim=-1, keepdim=True)
  scores *= 5
  weights = nn.functional.softmax(scores, dim=-1)                       # rows sum to 1

  attended = torch.einsum("bmn, bkn -> bkm", weights, flat_probs)       # (bs, K, N)
  entropy = -(attended * (attended + eps).log()).sum(1)                 # (bs, N)
  return entropy.reshape(bs, h, w)


# Latest detached maps, stashed for the wandb visualisation helpers:
#   last_patch_mi  -- per-patch local MI, written by MI_efficient
#   last_global_mi -- conditional-entropy map, written by MI(..., return_map=True)
last_patch_mi = last_global_mi = None

def MI(logits, x, bias=1., r=1, return_map=False):
  """Per-sample estimate ``bias * H(Y) - mean_p H(Y | x_p)``.

  MI(X,Y) = H(Y) - H(Y|X) = entropy(average(p_y)) - average(entropy(p_y|x)).
  H(Y) always uses full-resolution probabilities. When ``r > 1``, H(Y|X) is
  computed on an r-times downsampled copy and bilinearly upsampled back.

  Args:
    logits: (bs, num_classes, h, w) class scores.
    x: (bs, C, h, w) image.
    bias: multiplier on H(Y).
    r: downsampling factor for the conditional term.
    return_map: when True, store the detached H(Y|X) map in the global
      ``last_global_mi``.

  Returns:
    (bs,) tensor of estimates.
  """
  global last_global_mi
  eps = 1e-16
  bs, num_classes, h, w = logits.shape
  probs = nn.Softmax(dim=1)(logits)

  class_marginal = probs.reshape(bs, num_classes, -1).mean(dim=-1)
  HY = -(class_marginal * (class_marginal + eps).log()).sum(dim=-1)

  if r > 1:
    coarse = (h // r, w // r)
    coarse_probs = nn.Softmax(dim=1)(F.interpolate(logits, size=coarse, mode='bilinear', align_corners=True))
    coarse_x = F.interpolate(x, size=coarse, mode='bilinear', align_corners=True)
    HYX = compute_conditional_entropy2(coarse_probs, coarse_x)
    # Upsample the coarse map back to full size -> (bs, 1, h, w).
    HYX = F.interpolate(HYX.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=True)
  else:
    HYX = compute_conditional_entropy2(probs, x)

  if return_map:
    last_global_mi = HYX.detach().cpu()

  return bias * HY - HYX.reshape(bs, -1).mean(dim=-1)


def MI_efficient(logits, x, local_weight=0.5, local_bias=1, local_reduction=5, patch_size=50):
  """Blend of a coarse global MI estimate and the mean of per-patch local MI.

  MI estimation is quadratic in respect to number of pixels due to self attention.
  If the image is hxw, global MI estimation takes O(h^2 w^2) time and memory.
  This MI estimator downsamples the image and estimates coarse MI,
  as well as local MI on image patches.

  Args:
    logits: (bs, num_classes, h, w) class scores.
    x: (bs, C, h, w) image.
    local_weight: weight of the local term; the global term gets ``1 - local_weight``.
    local_bias: ``bias`` passed to ``MI`` for every patch.
    local_reduction: downsampling factor passed to ``MI`` for every patch.
      It is clamped to the patch size so a patch never shrinks below one pixel.
    patch_size: side of the square patches (default 50). It is clamped to
      ``min(h, w)`` so images smaller than a patch still yield one patch.
      Previously ``h < patch_size`` raised ZeroDivisionError and
      ``w < patch_size`` gave a NaN local term.

  Returns:
    Scalar tensor ``(1 - local_weight) * global_mi + local_weight * mean(patch_mi)``.

  Side effects:
    Stores the detached per-patch map in the global ``last_patch_mi``. The
    coarse ``MI(..., return_map=True)`` call stores ``last_global_mi``.
    Pixels beyond the last full patch in either direction are not covered by
    the local term.
  """
  global last_patch_mi
  bs, _, h, w = logits.shape

  ps = max(1, min(patch_size, h, w))
  n_rows, n_cols = h // ps, w // ps   # both >= 1 since ps <= min(h, w)

  # Global term: shrink by n_rows so the coarse image is about `ps` pixels tall.
  coarse_size = (max(1, h // n_rows), max(1, w // n_rows))
  coarse_logits = F.interpolate(logits, size=coarse_size, mode='bilinear', align_corners=True)
  coarse_x = F.interpolate(x, size=coarse_size, mode='bilinear', align_corners=True)
  coarse_global_mi = MI(coarse_logits, coarse_x, return_map=True).mean()

  # Local term: MI on every non-overlapping ps x ps patch.
  local_r = min(local_reduction, ps)
  patch_mi = torch.zeros((bs, n_rows, n_cols), device=x.device)
  for row in range(n_rows):
    top = row * ps
    for col in range(n_cols):
      left = col * ps
      patch_mi[:, row, col] = MI(logits[:, :, top:top + ps, left:left + ps],
                                 x[:, :, top:top + ps, left:left + ps],
                                 bias=local_bias, r=local_r)

  last_patch_mi = patch_mi.detach().cpu()
  return (1 - local_weight) * coarse_global_mi + local_weight * patch_mi.mean()




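$$
\begin{aligned}
C &= \{ c_p \mid p \in \Omega \} && \text{grid of label probabilities} \\
X &= \{ x_p \mid p \in \Omega \} && \text{image colors} \\
c_p &:\ \text{label random variable at pixel } p \\
x_p &:\ \text{color random variable at pixel } p \\
I(c_p; x_p) &= H(c_p) - H(c_p \mid x_p) \\
I(c_p; x_p) &\approx H(\overline{C}) - \overline{H\big(C\ \operatorname{softmax}(\operatorname{norm}(X X^\top))\big)}
\end{aligned}
$$

Here $\overline{C}$ is the class distribution averaged over all pixels, and the outer overline averages the per-pixel entropies. The implementation (`compute_conditional_entropy2`) builds the attention from negative Euclidean distances between YUV colors, rather than from the dot product $X X^\top$.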

In [None]:
from DenseCRFLoss import DenseCRFLoss

class SegModelDebug(BaseModel):
    """DeepLab segmentation model trained with cross-entropy plus a negative-MI term.

    Training attributes used below but not defined here (``criterion``,
    ``scheduler``, ``optimizer``, ``best_pred``, ``evaluator``,
    ``val_img_logs``) are presumably provided by ``BaseModel`` — confirm in
    ``color_query``.
    """

    # validation_summary logs at most this many entries of ``val_img_logs``;
    # get_loss_val stops building wandb images once the list reaches it.
    _MAX_VAL_LOGS = 50

    def __init__(self, hparams, model=None):
        """Build the DeepLab network from ``hparams``.

        NOTE(review): the ``model`` argument is accepted but ignored.
        """
        super().__init__(hparams)
        self.model = DeepLab(num_classes=self.hparams.nclass,
                             backbone=self.hparams.backbone,
                             output_stride=self.hparams.out_stride,
                             sync_bn=self.hparams.sync_bn,
                             freeze_bn=self.hparams.freeze_bn)
        # self.densecrflosslayer = DenseCRFLoss(weight=1, sigma_rgb=self.hparams.sigma_rgb, sigma_xy=self.hparams.sigma_xy, scale_factor=self.hparams.rloss_scale)

    def forward(self, x):
        return self.model(x)

    def get_loss(self, batch, batch_idx):
        """Training loss for one batch: ``CE + 2 * (-MI_efficient)``.

        Also steps the LR scheduler and logs scalar metrics on every call.
        Every 10th batch it also logs image/entropy/MI maps to wandb.
        """
        image, target = batch['image'], batch['label']
        # Relabel 254 as 255, in place on the batch tensor, before the CE loss.
        target[target == 254] = 255
        self.scheduler(self.optimizer, batch_idx, self.current_epoch, self.best_pred)
        output = self.forward(image)

        celoss = self.criterion(output, target.long())
        probs = nn.Softmax(dim=1)(output)
        # MI is maximised, so its negation is added to the loss. MI_efficient also
        # stores the maps read by visualize_global_mi / visualize_local_mi.
        # DenseCRF alternative (needs denormalizeimage(image, ...) and croppings):
        # r_loss = self.hparams.densecrfloss*self.densecrflosslayer(denormalized_image, probs, croppings).to(self.device)
        r_loss = -MI_efficient(output, image,
                               local_weight=self.hparams.mi_local_weight,
                               local_bias=self.hparams.mi_local_bias)

        if batch_idx % 10 == 0:
            with torch.no_grad():
                mask = torch.max(output[0].unsqueeze(0), 1)[1].detach()
                visualize_entropy(output)
                # visualize_c_entropy(output, image)
                visualize_global_mi()
                visualize_local_mi()
                self.logger.experiment.log({'train/img': wb_mask(image[0].cpu().numpy().transpose([1, 2, 0]), mask[0].cpu().numpy(), target[0].cpu().numpy())}, commit=False)

        self.log('train/ce', celoss.item())
        self.log('train/-MI', r_loss.item())
        self.log('train/entropy', elementwise_entropy(probs=probs).mean().item())
        return celoss + 2 * r_loss  # if epoch == 0 else 2*r_loss

    def get_loss_val(self, batch, batch_idx):
        """Validation step.

        Computes the CE loss and ``-mutual_info``, and updates the evaluator.
        Collects wandb example images until ``_MAX_VAL_LOGS`` entries are
        stored, since only that many are ever logged.

        Returns:
            dict with 'ce_loss' and 'r_loss' tensors.
        """
        image, target = batch['image'], batch['label']
        target[target == 254] = 255
        output = self.forward(image)
        celoss = self.criterion(output, target.long())

        # Each image contributes two entries: mask overlay + entropy map.
        for k in range(len(image)):
            if len(self.val_img_logs) >= self._MAX_VAL_LOGS:
                break
            mask = torch.max(output[k].unsqueeze(0), 1)[1].detach()
            self.val_img_logs += [
                wb_mask(image[k].cpu().numpy().transpose([1, 2, 0]), mask[0].cpu().numpy(), target[k].cpu().numpy()),
                visualize_entropy(output[k].unsqueeze(0), return_img=True),
            ]

        # r_loss = self.hparams.densecrfloss*self.densecrflosslayer(denormalized_image, probs, croppings)
        r_loss = -mutual_info(output, image)

        # argmax on-device, then transfer only the (bs, H, W) class map.
        pred = output.detach().argmax(dim=1).cpu().numpy()
        self.evaluator.add_batch(target.cpu().numpy(), pred)
        return {
            'ce_loss': celoss,
            'r_loss': r_loss,
        }

    def validation_summary(self, outputs):
        """Log epoch-level validation metrics and example images, then reset state."""
        masks = self.val_img_logs
        self.val_img_logs = []
        # float() also copes with an empty ``outputs`` (``0.0.item()`` would raise).
        test_loss = float(sum(output['ce_loss'] for output in outputs))

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.logger.experiment.log({'val/Examples': masks[:self._MAX_VAL_LOGS]}, commit=False)
        self.logger.experiment.log({'val/mIoU': mIoU}, commit=False)
        self.logger.experiment.log({'val/Acc': Acc}, commit=False)
        self.logger.experiment.log({'val/Acc_class': Acc_class}, commit=False)
        self.logger.experiment.log({'val/fwIoU': FWIoU}, commit=False)
        self.logger.experiment.log({'val/loss_epoch': test_loss})
        print('Validation:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        self.evaluator.reset()

    def get_10x_lr_params(self):
        """Yield trainable conv/BN parameters of the decoder, ASPP and (if present) ``coarse_cls``.

        NOTE(review): ``coarse_cls`` is not defined in this class. It is
        included only when the instance has it, instead of raising
        AttributeError.
        """
        modules = [self.model.decoder, self.model.aspp]
        coarse_cls = getattr(self, 'coarse_cls', None)
        if coarse_cls is not None:
            modules.append(coarse_cls)
        for module in modules:
            for _, m in module.named_modules():
                if isinstance(m, (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                    for p in m.parameters():
                        if p.requires_grad:
                            yield p

In [None]:
import wandb
from pytorch_lightning.loggers import WandbLogger
# Authenticate with Weights & Biases; needed before the WandbLogger and
# wandb.log calls used during training.
wandb.login()

In [None]:
# Hyper-parameters for this run; they mutate the shared `args` namespace in place.
vars(args).update(
    batch_size=5,
    limit_dataset=False,
    nclass=nclass,
    lr=1e-2,
    inner_lr=1e-2,
    densecrfloss=2e-9,
    epochs=5,
    mi_local_weight=0.5,
    mi_local_bias=0.5,
)
# args.backbone = "efficientnet"
# args.backbone = "mobilenet"

# Train/validate on the single-batch repeat loaders built earlier, not the full dataset:
# train_loader, val_loader, test_loader, nclass = make_data_loader(args, **kwargs)
train_loader, val_loader = single_train_loader, single_val_loader
args.num_img_tr = len(train_loader)

model = SegModelDebug(args)
# model.configure_optimizers()
# for param in model.encoder.model.backbone.parameters():
#   param.requires_grad = False

# Alternative run names: proto_dlv3_rand-space  deeplabv3+_seeds  prototype_aspp_seeds  deeplabv3+_color-low-feats
wandb_logger = WandbLogger(project='Color-Query', name='seeds_batch_efficient_mi')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=args.epochs,
    logger=wandb_logger,
    log_every_n_steps=10,
    num_sanity_val_steps=0,
    progress_bar_refresh_rate=0,
    accumulate_grad_batches=2,
)
results = trainer.fit(model, train_loader, val_loader)
wandb.finish()

# plt.imshow(atten[0].reshape(128,128).cpu().numpy());plt.pause(1)
# plt.imshow((mi[0] > 2.8792).detach().cpu().numpy());plt.pause(1)
# (mi[0][mi[0] > 2.8792]).mean()

In [None]:
# plt.imshow(atten[0,0].reshape(h,w).detach().cpu().numpy());plt.pause(1)

# import gc
# for obj in gc.get_objects():
#     try:
#         if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
#             print(type(obj), obj.size())
#     except:
#         pass