<a href="https://colab.research.google.com/github/Ryan-Qiyu-Jiang/abcd/blob/main/abcd_experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the rloss repository, move into it and switch to the color_query branch.
!git clone https://github.com/Ryan-Qiyu-Jiang/rloss.git
%cd /content/rloss/
!git checkout color_query

In [None]:
# Run the repository's fetch scripts from their own directories
# (presumably they download VOC2012 and the Pascal scribble annotations — confirm).
%cd /content/rloss/data/VOC2012/
!./fetchVOC2012.sh
%cd /content/rloss/data/pascal_scribble/
! ./fetchPascalScribble.sh

In [None]:
# Install the logging / training packages used below; -qqq keeps pip output quiet.
!pip install -qqq tensorboardX
!pip install -qqq wandb
!pip install -qqq pytorch-lightning

In [None]:
# Work from the DeepLab v3+ project directory
# (presumably so the project-local imports in the next cell resolve — confirm).
%cd /content/rloss/pytorch/pytorch_deeplab_v3_plus

In [None]:
import argparse
import os
import numpy as np
from tqdm import tqdm

from mypath import Path
from dataloaders import make_data_loader
from dataloaders.custom_transforms import denormalizeimage
from modeling.sync_batchnorm.replicate import patch_replication_callback
from modeling.deeplab import *
from utils.loss import SegmentationLosses
from utils.calculate_weights import calculate_weigths_labels
from utils.lr_scheduler import LR_Scheduler
from utils.saver import Saver
from utils.summaries import TensorboardSummary
from utils.metrics import Evaluator

import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.nn import functional as F
import pytorch_lightning as pl
from torch import nn
import torch
from argparse import Namespace
from dataloaders.utils import decode_seg_map_sequence

from color_query import BaseModel, get_args, SingleDataset, RepeatDataset

# Class names in class-index order; position i is the name for label id i.
segmentation_classes = [
    'background',
    'aeroplane',
    'bicycle',
    'bird',
    'boat',
    'bottle',
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'diningtable',
    'dog',
    'horse',
    'motorbike',
    'person',
    'pottedplant',
    'sheep',
    'sofa',
    'train',
    'tvmonitor',
]

def labels():
  """Return a fresh ``{class_index: class_name}`` dict built from ``segmentation_classes``."""
  return dict(enumerate(segmentation_classes))

def wb_mask(bg_img, pred_mask, true_mask):
  """Wrap ``bg_img`` in a ``wandb.Image`` carrying two mask overlays.

  The overlays are keyed "prediction" (``pred_mask``) and "ground truth"
  (``true_mask``); each gets its own class-label mapping from ``labels()``.
  """
  overlays = {}
  for title, mask in (("prediction", pred_mask), ("ground truth", true_mask)):
    overlays[title] = {"mask_data" : mask, "class_labels" : labels()}
  return wandb.Image(bg_img, masks=overlays)

In [None]:
from argparse import Namespace
from dataloaders import make_data_loader

# Start from the project's default arguments and override the settings for this experiment.
args_dict = get_args()
args_dict.update(
    cuda=True,
    checkname='ignore',
    epochs=1,
    shuffle=False,  # True for real training
    batch_size=10,
    lr=1e-3,
    full_gt=False,  # True for gt
    limit_dataset=False,
    # rloss_scale=1,
)
args = Namespace(**args_dict)

# Extra keyword arguments forwarded to make_data_loader.
kwargs = dict(num_workers=6, pin_memory=True)
train_loader, val_loader, test_loader, nclass = make_data_loader(args, **kwargs)

In [None]:
import torch
import matplotlib.pyplot as plt
from torch.utils.data import random_split, DataLoader

# Pull one batch from the training loader. The built-in next() is used because
# the `.next()` method is the Python 2 iterator protocol and is not available on
# iterators that only implement `__next__`.
batch = next(iter(train_loader))
batch_sample = dict(batch)  # shallow copy: same tensors, new dict
# Wrap the single batch in RepeatDataset (second argument 100*10; presumably the
# number of repetitions / dataset length — confirm) for the single-batch loaders.
single_dataset = RepeatDataset(batch_sample, 100 * 10)
single_train_loader = DataLoader(single_dataset, batch_size=10, shuffle=True, num_workers=4)
single_val_loader = DataLoader(single_dataset, batch_size=10, shuffle=False, num_workers=4)

# Show the first image of the batch (channels moved last for matplotlib).
plt.imshow(batch['image'][0].numpy().transpose(1, 2, 0));

In [None]:
import wandb
from pytorch_lightning.loggers import WandbLogger
# Authenticate with Weights & Biases before any run is created.
wandb.login()

In [None]:
import pdb
from utils.metrics import Evaluator

def to_img(img_tensor, **kwargs):
  """Convert a tensor to a ``wandb.Image``.

  A tensor whose first dimension has size 3 is treated as channels-first and
  permuted to put that dimension last. Values are multiplied by 255 and cast
  to uint8; ``kwargs`` are forwarded to ``wandb.Image``.
  """
  channels_first = img_tensor.size(0) == 3
  hwc = img_tensor.permute(1, 2, 0) if channels_first else img_tensor  # 3 h w -> h w 3
  pixels = hwc.cpu().detach().numpy() * 255
  return wandb.Image(pixels.astype(np.uint8), **kwargs)
# Number of ColorDecoder.forward calls so far; used to throttle debug logging.
log_num = 0

# TODO: Make this a layer in a sequence of decoding layers
class ColorDecoder(nn.Module):
  """Refine a coarse segmentation with per-class mean-feature queries.

  A 1x1 convolution turns the feature map into coarse per-class probabilities.
  For each class, the probability-weighted spatial mean of the per-pixel
  features (the image, optionally concatenated with upsampled low-level
  features) acts as a query; the output score of a pixel for a class is the
  dot product between that pixel's features and the class query.
  """

  # Debug images are sent to wandb only on every `log_every`-th forward call.
  log_every = 10

  def __init__(self, num_classes=21, feature_dim=256):
    super().__init__()
    self.num_classes = num_classes
    self.feature_dim = feature_dim
    self.softmax = nn.Softmax(dim=1)
    self.coarse_cls = nn.Conv2d(feature_dim, num_classes, kernel_size=1, stride=1)

  def log(self, d, commit=False):
    """Forward ``d`` to ``wandb.log`` on logging steps only (see ``log_every``)."""
    if log_num % self.log_every == 0:
      wandb.log(d, commit=commit)

  def _log_argmax_mask(self, key, image, scores):
    """Log the argmax class mask of the first sample over its image under ``key``."""
    with torch.no_grad():
      mask = torch.max(scores[:1], 1)[1].detach()
      overlay = wandb.Image(image[0].detach().cpu().numpy().transpose([1,2,0]), masks={
        "prediction" : {"mask_data" : mask[0].cpu().numpy(), "class_labels" : labels()}})
    self.log({key: overlay}, commit=False)

  def forward(self, feature_map, x, low=None):
    """Return per-class scores of shape (bs, num_classes, h, w).

    Args:
      feature_map: features fed to the coarse classifier, (bs, feature_dim, h', w').
      x: input image batch, (bs, channels, h, w); its spatial size sets the output size.
      low: optional low-level features; bilinearly resized to the image size
        and concatenated to the image along the channel dimension.
    """
    global log_num
    image = x
    if low is not None:
      low = F.interpolate(low, size=image.size()[2:], mode='bilinear', align_corners=True) # bs, num_channels, h, w
      x = torch.cat([x, low], dim=1) # bs, (num_channels+3 = d), h, w

    logits_map = self.coarse_cls(feature_map)
    coarse_segments = self.softmax(logits_map) # bs, num_classes, h', w'
    coarse_segments = F.interpolate(coarse_segments, size=image.size()[2:], mode='bilinear', align_corners=True) # bs, num_classes, h, w

    # Build the debug overlays only when they will actually be logged; they
    # need device-to-host copies that are wasted on the other steps.
    should_log = log_num % self.log_every == 0
    if should_log:
      self._log_argmax_mask('debug/coarse_segments', image, coarse_segments)

    # Per-class query = spatial mean of the features weighted by the coarse
    # probabilities. One contraction replaces building num_classes masked
    # copies of x.
    height, width = x.shape[2:]
    queries = torch.einsum('bdhw,bchw->bcd', x, coarse_segments) / (height * width) # bs, num_classes, d

    # Score of every pixel for every class: dot product of pixel features and class query.
    segments_by_color = torch.einsum('bdhw,bcd->bchw', x, queries) # bs, num_classes, h, w

    if should_log:
      self._log_argmax_mask('debug/finer_segments', image, segments_by_color)

    log_num += 1
    return segments_by_color


class DeepLabEncoder(nn.Module):
  """Encoder made of the ``backbone`` and ``aspp`` submodules of a freshly built DeepLab."""

  def __init__(self, hparams):
    super().__init__()
    self.hparams = hparams
    cfg = self.hparams
    # Build a full DeepLab, then keep only the two submodules used here.
    deeplab = DeepLab(
        num_classes=cfg.nclass,
        backbone=cfg.backbone,
        output_stride=cfg.out_stride,
        sync_bn=cfg.sync_bn,
        freeze_bn=cfg.freeze_bn,
    )
    self.encoder = deeplab.backbone
    self.aspp = deeplab.aspp

  def forward(self, x):
    """Return ``(aspp(backbone_features), low_level_features)`` for input ``x``."""
    backbone_out, low_level = self.encoder(x)
    refined = self.aspp(backbone_out)
    return refined, low_level


from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d

class Decoder(nn.Module):
    """Decoder head that fuses upsampled features with projected low-level features.

    The low-level features are projected to 48 channels by a 1x1 conv + norm +
    ReLU. The main features are resized to the same spatial size, concatenated
    with the projection and passed through ``last_conv``, whose first layer
    expects 304 input channels and whose last layer outputs ``num_classes``.
    """

    def __init__(self, num_classes, low_level_inplanes=21, BatchNorm=nn.BatchNorm2d):
        super().__init__()
        # Low-level feature projection.
        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
        self.bn1 = BatchNorm(48)
        self.relu = nn.ReLU()

        # Fusion head; layers are created in the order they are applied.
        head_layers = [
            nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
        ]
        self.last_conv = nn.Sequential(*head_layers)
        self._init_weight()

    def forward(self, x, low_level_feat):
        """Return the ``last_conv`` output at the spatial size of ``low_level_feat``."""
        projected = self.relu(self.bn1(self.conv1(low_level_feat)))

        upsampled = F.interpolate(x, size=projected.size()[2:], mode='bilinear', align_corners=True)
        fused = torch.cat((upsampled, projected), dim=1)
        return self.last_conv(fused)

    def _init_weight(self):
        """Kaiming-normal init for conv weights; norm layers get weight 1 and bias 0."""
        norm_types = (SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)
            elif isinstance(layer, norm_types):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()

class ColorModel(BaseModel):
    """Segmentation model: DeepLab encoder (backbone + ASPP) followed by a ColorDecoder.

    Training minimises the cross-entropy of the decoder output plus 0.1 times
    the cross-entropy of the decoder's coarse classifier applied to the same
    encoder features.

    Relies on attributes that are not set in this class (``hparams``,
    ``criterion``, ``evaluator``, ``val_img_logs``, ``best_pred``, ``log``,
    ``current_epoch``); presumably they come from BaseModel / Lightning.
    """

    def __init__(self, hparams, encoder=None):
        """
        Args:
            hparams: experiment settings (nclass, backbone, out_stride, sync_bn,
                freeze_bn, lr, momentum, ...).
            encoder: module returning ``(features, low_level_features)`` and
                exposing ``.encoder`` and ``.aspp`` submodules (both are read
                when building the optimizer parameter groups). Defaults to a
                new ``DeepLabEncoder``.
        """
        super().__init__(hparams)
        if encoder is None:
            # A bare DeepLab backbone has neither the `.encoder`/`.aspp`
            # attributes used for the parameter groups nor the ASPP stage the
            # decoder's feature_dim=256 is sized for, so default to the wrapper.
            encoder = DeepLabEncoder(self.hparams)
        self.encoder = encoder
        self.decoder = ColorDecoder(num_classes=self.hparams.nclass, feature_dim=256) # resnet feature map dim 320, aspp=256

    def forward(self, x):
        """Return the decoder's per-class scores for image batch ``x``."""
        feature_map, low_level_feats = self.encoder(x)
        return self.decoder(feature_map, x, low=low_level_feats)

    def _fine_and_coarse_logits(self, image):
        """Run the encoder once and return ``(decoder output, upsampled coarse logits)``."""
        feature_map, low_level_feats = self.encoder(image)
        output = self.decoder(feature_map, image, low=low_level_feats)
        coarse_output = self.decoder.coarse_cls(feature_map)
        coarse_output = F.interpolate(coarse_output, size=image.size()[2:], mode='bilinear', align_corners=True)
        return output, coarse_output

    def get_loss(self, batch, batch_idx):
        """Return the training loss ``ce + 0.1 * coarse_ce`` for one batch.

        Label value 254 is rewritten to 255 in place on ``batch['label']``
        (presumably the criterion's ignore index — confirm).
        """
        image, target = batch['image'], batch['label']
        target[target == 254] = 255
        self.scheduler(self.optimizer, batch_idx, self.current_epoch, self.best_pred)

        # One encoder pass feeds both heads (the second pass was redundant).
        output, coarse_output = self._fine_and_coarse_logits(image)
        labels_long = target.long()
        celoss = self.criterion(output, labels_long)
        coarse_celoss = self.criterion(coarse_output, labels_long)

        self.log('train/ce', celoss.item())
        # NOTE: the key is spelled 'course_ce' (sic); kept so existing runs stay comparable.
        self.log('train/course_ce', coarse_celoss.item())
        return celoss + coarse_celoss*0.1

    def get_loss_val(self, batch, batch_idx):
        """Evaluate one validation batch.

        Appends two mask overlays (decoder output, then coarse head) for one
        sample of the batch to ``self.val_img_logs``, adds the batch to
        ``self.evaluator`` and returns ``{'ce_loss': ...}``.
        """
        image, target = batch['image'], batch['label']
        target[target == 254] = 255
        i = batch_idx % len(batch['image'])  # sample of this batch to visualise

        output, coarse_output = self._fine_and_coarse_logits(image)
        celoss = self.criterion(output, target.long())

        background = image[i].cpu().numpy().transpose([1,2,0])
        truth = target[i].cpu().numpy()
        for logits in (output, coarse_output):
            mask = torch.max(logits[i].unsqueeze(0), 1)[1].detach()
            self.val_img_logs += [wb_mask(background, mask[0].cpu().numpy(), truth)]

        pred = np.argmax(output.detach().cpu().numpy(), axis=1)
        self.evaluator.add_batch(target.cpu().numpy(), pred)
        return {
          'ce_loss': celoss
        }

    def configure_optimizers(self):
        """Create the SGD optimizer (encoder backbone at lr, decoder + ASPP at 10x lr) and the LR scheduler."""
        train_params = [{'params': self.get_1x_lr_params(), 'lr': self.hparams.lr},
                        {'params': self.get_10x_lr_params(), 'lr': self.hparams.lr * 10}]
        self.optimizer = torch.optim.SGD(train_params, momentum=self.hparams.momentum,
                                         weight_decay=self.hparams.weight_decay,
                                         nesterov=self.hparams.nesterov)
        self.scheduler = LR_Scheduler(self.hparams.lr_scheduler, self.hparams.lr,
                                      self.hparams.epochs, self.hparams.num_img_tr)
        return self.optimizer

    @staticmethod
    def _trainable_params(modules):
        """Yield trainable parameters of every conv / batch-norm layer inside ``modules``."""
        layer_types = (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for root in modules:
            for layer in root.modules():
                if isinstance(layer, layer_types):
                    for p in layer.parameters():
                        if p.requires_grad:
                            yield p

    def get_1x_lr_params(self):
        """Yield the trainable conv / batch-norm parameters of ``self.encoder.encoder``."""
        yield from self._trainable_params([self.encoder.encoder])

    def get_10x_lr_params(self):
        """Yield the trainable conv / batch-norm parameters of the decoder and ``self.encoder.aspp``."""
        yield from self._trainable_params([self.decoder, self.encoder.aspp])


In [None]:
class SegModelDebug(BaseModel):
    """DeepLab with an extra 1x1 "coarse" classifier on the ASPP features.

    Training minimises the cross-entropy of the DeepLab output plus 0.1 times
    the cross-entropy of the coarse head. Attributes such as ``criterion``,
    ``scheduler``, ``optimizer``, ``evaluator`` and ``val_img_logs`` are not
    set here; presumably they come from BaseModel.
    """

    def __init__(self, hparams, model=None):
        """
        Args:
            hparams: experiment settings (nclass, backbone, out_stride, sync_bn, freeze_bn, ...).
            model: optional segmentation network exposing ``backbone``, ``aspp``
                and ``decoder``; a new DeepLab is built when omitted.
        """
        super().__init__(hparams)
        if model is None:
            model = DeepLab(num_classes=self.hparams.nclass,
                            backbone=self.hparams.backbone,
                            output_stride=self.hparams.out_stride,
                            sync_bn=self.hparams.sync_bn,
                            freeze_bn=self.hparams.freeze_bn)
        # A supplied model is used instead of being silently discarded, as in SegModel.
        self.model = model
        self.coarse_cls = nn.Conv2d(256, self.hparams.nclass, kernel_size=1, stride=1)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def _coarse_logits(self, image):
        """Return coarse-head logits from backbone + ASPP features, resized to the image size.

        NOTE(review): ``forward`` already runs ``self.model`` on the same image,
        so this presumably repeats the backbone + ASPP computation — consider
        sharing the features if the wrapped model can expose them.
        """
        features, _ = self.model.backbone(image)
        features = self.model.aspp(features)
        logits = self.coarse_cls(features)
        return F.interpolate(logits, size=image.size()[2:], mode='bilinear', align_corners=True)

    def get_loss(self, batch, batch_idx):
        """Return the training loss ``ce + 0.1 * coarse_ce`` for one batch.

        Label value 254 is rewritten to 255 in place on ``batch['label']``
        (presumably the criterion's ignore index — confirm).
        """
        image, target = batch['image'], batch['label']
        target[target == 254] = 255
        self.scheduler(self.optimizer, batch_idx, self.current_epoch, self.best_pred)
        # self.optimizer.zero_grad()
        output = self.forward(image)
        celoss = self.criterion(output, target.long())

        coarse_output = self._coarse_logits(image)
        coarse_celoss = self.criterion(coarse_output, target.long())
        self.log('train/ce', celoss.item())
        # NOTE: the key is spelled 'course_ce' (sic); kept so existing runs stay comparable.
        self.log('train/course_ce', coarse_celoss.item())
        return celoss + coarse_celoss*0.1

    def get_loss_val(self, batch, batch_idx):
        """Evaluate one validation batch.

        Appends two mask overlays (model output, then coarse head) for one
        sample of the batch to ``self.val_img_logs``, adds the batch to
        ``self.evaluator`` and returns ``{'ce_loss': ...}``.
        """
        image, target = batch['image'], batch['label']
        target[target == 254] = 255
        i = batch_idx % len(batch['image'])  # sample of this batch to visualise
        output = self.forward(image)
        celoss = self.criterion(output, target.long())
        mask = torch.max(output[i].unsqueeze(0), 1)[1].detach()
        self.val_img_logs += [wb_mask(image[i].cpu().numpy().transpose([1,2,0]), mask[0].cpu().numpy(), target[i].cpu().numpy())]

        coarse_output = self._coarse_logits(image)
        mask = torch.max(coarse_output[i].unsqueeze(0), 1)[1].detach()
        self.val_img_logs += [wb_mask(image[i].cpu().numpy().transpose([1,2,0]), mask[0].cpu().numpy(), target[i].cpu().numpy())]

        pred = np.argmax(output.data.cpu().numpy(), axis=1)
        self.evaluator.add_batch(target.cpu().numpy(), pred)
        return {
          'ce_loss': celoss
        }

    def get_10x_lr_params(self):
        """Yield trainable conv / batch-norm parameters of the model's decoder, its ASPP and the coarse head."""
        layer_types = (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for root in (self.model.decoder, self.model.aspp, self.coarse_cls):
            for layer in root.modules():
                if isinstance(layer, layer_types):
                    for p in layer.parameters():
                        if p.requires_grad:
                            yield p


class SegModel(BaseModel):
    """Plain segmentation model: wraps a DeepLab (or a supplied network) and trains it with cross-entropy."""

    def __init__(self, hparams, model=None):
        super().__init__(hparams)
        # Build a DeepLab from the hyper-parameters only when no network is passed in.
        self.model = model if model is not None else DeepLab(
            num_classes=self.hparams.nclass,
            backbone=self.hparams.backbone,
            output_stride=self.hparams.out_stride,
            sync_bn=self.hparams.sync_bn,
            freeze_bn=self.hparams.freeze_bn,
        )

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def get_loss(self, batch, batch_idx):
        """Step the LR scheduler, then return (and log as 'train/ce') the cross-entropy for one batch.

        Label value 254 is rewritten to 255 in place on ``batch['label']``.
        """
        image = batch['image']
        target = batch['label']
        target[target == 254] = 255
        self.scheduler(self.optimizer, batch_idx, self.current_epoch, self.best_pred)
        # self.optimizer.zero_grad()
        logits = self.forward(image)
        celoss = self.criterion(logits, target.long())
        self.log('train/ce', celoss.item())
        return celoss

In [None]:
# Experiment settings for this run.
args.batch_size = 10
args.limit_dataset = False
args.nclass = nclass
args.lr = 0.01
args.epochs = 1
# train_loader, val_loader, test_loader, nclass = make_data_loader(args, **kwargs)
train_loader, val_loader = single_train_loader, single_val_loader

# The models pass hparams.num_img_tr (batches per epoch) to LR_Scheduler.
args.num_img_tr=len(train_loader)
# model = ColorModel(args, encoder=DeepLabEncoder(args))
model = SegModelDebug(args)
# model.configure_optimizers()
# for param in model.encoder.model.backbone.parameters():
#   param.requires_grad = False

wandb_logger = WandbLogger(project='Color-Query', name='deeplabv3+_seeds') # prototype_aspp_seeds  deeplabv3+_color-low-feats

# max_epochs follows args.epochs so the trainer and the LR schedule (built from
# hparams.epochs) always agree on the run length.
trainer = pl.Trainer(gpus=1, max_epochs=args.epochs, logger=wandb_logger, log_every_n_steps=10, num_sanity_val_steps=0, progress_bar_refresh_rate=0, accumulate_grad_batches=1)
try:
    results = trainer.fit(model, train_loader, val_loader)
finally:
    # Close the wandb run even when training raises, so the next run starts clean.
    wandb.finish()

In [None]:
# import gc
# for obj in gc.get_objects():
#     try:
#         if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
#             print(type(obj), obj.size())
#     except:
#         pass