<a href="https://colab.research.google.com/github/Ryan-Qiyu-Jiang/KidneyStones/blob/master/abcd_experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the rloss repository, enter it and switch to the color_query branch.
!git clone https://github.com/Ryan-Qiyu-Jiang/rloss.git
%cd /content/rloss/
!git checkout color_query

In [None]:
# Run the repository's fetch scripts from inside each data directory
# (presumably they download VOC2012 and the Pascal scribble annotations -- confirm in the scripts).
%cd /content/rloss/data/VOC2012/
!./fetchVOC2012.sh
%cd /content/rloss/data/pascal_scribble/
! ./fetchPascalScribble.sh

In [None]:
# Install the third-party packages imported by the cells below (-qqq silences pip output).
!pip install -qqq tensorboardX
!pip install -qqq wandb
!pip install -qqq pytorch-lightning
!pip install -qqq kornia

In [None]:
# Work from the DeepLab v3+ directory so the project-local imports below (mypath, dataloaders, modeling, utils, color_query) resolve.
%cd /content/rloss/pytorch/pytorch_deeplab_v3_plus

In [None]:
import argparse
import os
import numpy as np
from tqdm import tqdm

from mypath import Path
from dataloaders import make_data_loader
from dataloaders.custom_transforms import denormalizeimage
from modeling.sync_batchnorm.replicate import patch_replication_callback
from modeling.deeplab import *
from utils.loss import SegmentationLosses
from utils.calculate_weights import calculate_weigths_labels
from utils.lr_scheduler import LR_Scheduler
from utils.saver import Saver
from utils.summaries import TensorboardSummary
from utils.metrics import Evaluator

import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.nn import functional as F
import pytorch_lightning as pl
from torch import nn
import torch
from argparse import Namespace
from dataloaders.utils import decode_seg_map_sequence

from color_query import BaseModel, get_args, SingleDataset, RepeatDataset, DeepLabEncoder, BaseColorModel

# Segmentation class names; the list position of a name is its class index.
segmentation_classes = [
    'background',
    'aeroplane',
    'bicycle',
    'bird',
    'boat',
    'bottle',
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'diningtable',
    'dog',
    'horse',
    'motorbike',
    'person',
    'pottedplant',
    'sheep',
    'sofa',
    'train',
    'tvmonitor',
]


def labels():
    """Return a fresh ``{class_index: class_name}`` dict built from ``segmentation_classes``."""
    return dict(enumerate(segmentation_classes))

def wb_mask(bg_img, pred_mask, true_mask):
    """Wrap ``bg_img`` in a ``wandb.Image`` with two mask overlays.

    Args:
        bg_img: Background image passed straight to ``wandb.Image``.
        pred_mask: Mask data logged under the key ``"prediction"``.
        true_mask: Mask data logged under the key ``"ground truth"``.

    Returns:
        The ``wandb.Image`` carrying both masks, each labelled with the
        index-to-name mapping returned by ``labels()``.
    """
    overlays = {
        "prediction": {"mask_data": pred_mask, "class_labels": labels()},
        "ground truth": {"mask_data": true_mask, "class_labels": labels()},
    }
    return wandb.Image(bg_img, masks=overlays)

In [None]:
from argparse import Namespace
from dataloaders import make_data_loader

# Start from the project's default arguments (get_args returns a dict) and
# override the entries that matter for this experiment.
args_dict = get_args()
args_dict['cuda'] = True
args_dict['checkname'] = 'ignore'
args_dict['epochs'] = 1
args_dict['shuffle'] = False # True for real training
args_dict['batch_size'] = 10
args_dict['lr'] = 1e-3
args_dict['full_gt'] = True # True for gt
args_dict['limit_dataset'] = False
# args_dict['rloss_scale'] = 1
args = Namespace(**args_dict)

# Loaders built with full_gt=True; the validation loader is also kept under the name gt_loader.
kwargs = {'num_workers': 6, 'pin_memory': True}
train_loader, val_loader, test_loader, nclass = make_data_loader(args, **kwargs)
gt_loader = val_loader

# Second set of loaders with identical settings except full_gt=False; only the
# training loader is kept (presumably scribble/seed labels -- confirm in make_data_loader).
seeds_args = Namespace(**dict(args_dict, full_gt=False) )
seeds_loader, _, _, _ = make_data_loader(seeds_args, **kwargs)

In [None]:
import torch
import matplotlib.pyplot as plt
from torch.utils.data import random_split, DataLoader

# Take the first batch of the full_gt=False loader. Use the built-in next():
# iterators only guarantee __next__, and the Python 2 style ``.next()`` alias
# is not available on DataLoader iterators in current PyTorch releases.
batch = next(iter(seeds_loader))
batch_sample = dict(batch)  # shallow copy so the dataset holds its own mapping
bs = 10
# RepeatDataset is project code; presumably it serves this one batch as a
# dataset of 100*bs items -- confirm against color_query.RepeatDataset.
single_dataset = RepeatDataset(batch_sample,  100*bs)
single_train_loader = DataLoader(single_dataset, batch_size=bs, shuffle=True, num_workers=4)
single_val_loader = DataLoader(single_dataset, batch_size=bs, shuffle=False, num_workers=4)

# Show the first image of the batch (channel-first tensor -> channel-last array).
plt.imshow(batch['image'][0].numpy().transpose(1,2,0));

In [None]:
import wandb
from pytorch_lightning.loggers import WandbLogger
# Authenticate with Weights & Biases before any logger or wandb.Image is created.
wandb.login()

In [None]:
import pdb
from utils.metrics import Evaluator


def to_img(img_tensor, **kwargs):
    """Convert an image tensor into a ``wandb.Image``.

    A tensor whose first dimension has size 3 is treated as channel-first
    and permuted to channel-last. Values are multiplied by 255 and cast to
    ``uint8`` (so inputs are presumably expected in [0, 1] -- confirm with
    callers). Extra keyword arguments are forwarded to ``wandb.Image``.
    """
    channels_first = img_tensor.size(0) == 3
    if channels_first:
        img_tensor = img_tensor.permute(1, 2, 0)  # 3 h w -> h w 3
    pixels = img_tensor.cpu().detach().numpy() * 255
    return wandb.Image(pixels.astype(np.uint8), **kwargs)
# NOTE(review): not referenced by any code visible in this notebook -- confirm whether it is still needed.
log_num = 0

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class FFCSE_block(nn.Module):
    """Squeeze-and-excitation gate for a (local, global) feature pair.

    The available branches are concatenated along channels, globally
    average-pooled and squeezed by a 1x1 convolution (reduction 16). Each
    branch is then rescaled by its own sigmoid gate computed from that
    shared squeezed descriptor.

    Args:
        channels: Total number of channels across both branches.
        ratio_g: Fraction of ``channels`` that belongs to the global branch.
    """

    def __init__(self, channels, ratio_g):
        super(FFCSE_block, self).__init__()
        in_cg = int(channels * ratio_g)
        in_cl = channels - in_cg
        r = 16

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv1 = nn.Conv2d(channels, channels // r,
                               kernel_size=1, bias=True)
        self.relu1 = nn.ReLU(inplace=True)
        # A branch with zero channels gets no gate convolution at all.
        self.conv_a2l = None if in_cl == 0 else nn.Conv2d(
            channels // r, in_cl, kernel_size=1, bias=True)
        self.conv_a2g = None if in_cg == 0 else nn.Conv2d(
            channels // r, in_cg, kernel_size=1, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Gate both branches.

        Args:
            x: Either a single tensor (local branch only) or a tuple
                ``(x_l, x_g)``. A missing branch is given as a non-tensor
                placeholder such as ``0``.

        Returns:
            Tuple ``(x_l, x_g)``; a branch without a gate convolution is
            returned as ``0``.

        Raises:
            TypeError: If neither branch is a tensor.
        """
        id_l, id_g = x if isinstance(x, tuple) else (x, 0)

        # Only real tensors take part in the squeeze. The previous code
        # concatenated unconditionally whenever id_g was a tensor, which
        # crashed for the all-global case where id_l is the placeholder 0.
        present = [t for t in (id_l, id_g) if torch.is_tensor(t)]
        if not present:
            raise TypeError('FFCSE_block expects at least one tensor input')
        x = present[0] if len(present) == 1 else torch.cat(present, dim=1)

        x = self.avgpool(x)
        x = self.relu1(self.conv1(x))

        x_l = 0 if self.conv_a2l is None else id_l * \
            self.sigmoid(self.conv_a2l(x))
        x_g = 0 if self.conv_a2g is None else id_g * \
            self.sigmoid(self.conv_a2g(x))
        return x_l, x_g

class SELayer(nn.Module):
    """Squeeze-and-excitation channel gate.

    Each channel is globally average-pooled, passed through a bias-free
    two-layer bottleneck MLP ending in a sigmoid, and the resulting
    per-channel weight multiplies the input.

    Args:
        channel: Number of input (and output) channels.
        reduction: Bottleneck divisor; the hidden width is
            ``channel // reduction``.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` rescaled per channel; ``x`` must be 4-D (batch, channel, h, w)."""
        batch, channels, _height, _width = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        weights = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * weights.expand_as(x)

class FourierUnit(nn.Module):
    """1x1 convolution, batch norm and ReLU applied in the frequency domain.

    The input goes through ``torch.fft.rfftn``; real and imaginary parts are
    interleaved along the channel axis, processed by a pointwise convolution
    with batch norm and ReLU, re-assembled into a complex tensor and mapped
    back with ``torch.fft.irfftn`` to the spatial size that entered the FFT.

    Args:
        in_channels: Channels of the spatial input.
        out_channels: Channels of the spatial output.
        groups: ``groups`` of the pointwise convolution.
        spatial_scale_factor: If not ``None``, the input is resized by this
            factor before the FFT and the output resized back afterwards.
        spatial_scale_mode: Interpolation mode used for both resizes.
        spectral_pos_encoding: Prepend two coordinate channels (vertical and
            horizontal, each spanning 0..1) to the spectral features.
        use_se: Apply an ``SELayer`` to the spectral features before the
            convolution.
        se_kwargs: Extra keyword arguments for that ``SELayer``.
        ffc3d: Take the FFT over the last three dimensions instead of two.
        fft_norm: ``norm`` argument for both FFT calls.
    """

    def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
                 spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
        # bn_layer not used
        super().__init__()
        self.groups = groups

        # Real and imaginary parts double the channel count; the optional
        # positional encoding contributes two more input channels.
        coord_channels = 2 if spectral_pos_encoding else 0
        self.conv_layer = torch.nn.Conv2d(
            in_channels=in_channels * 2 + coord_channels,
            out_channels=out_channels * 2,
            kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
        self.relu = torch.nn.ReLU(inplace=True)

        # squeeze and excitation block
        self.use_se = use_se
        if use_se:
            se_options = {} if se_kwargs is None else se_kwargs
            self.se = SELayer(self.conv_layer.in_channels, **se_options)

        self.spatial_scale_factor = spatial_scale_factor
        self.spatial_scale_mode = spatial_scale_mode
        self.spectral_pos_encoding = spectral_pos_encoding
        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

    def forward(self, x):
        """Transform ``x`` in the frequency domain and return a tensor of the same spatial size."""
        n = x.shape[0]
        rescale = self.spatial_scale_factor is not None

        if rescale:
            orig_size = x.shape[-2:]
            x = F.interpolate(x, scale_factor=self.spatial_scale_factor,
                              mode=self.spatial_scale_mode, align_corners=False)

        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        spectrum = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)

        # complex (n, c, h, w/2+1) -> real (n, c, h, w/2+1, 2)
        #   -> (n, c, 2, h, w/2+1) -> (n, 2c, h, w/2+1)
        feats = torch.stack((spectrum.real, spectrum.imag), dim=-1)
        feats = feats.permute(0, 1, 4, 2, 3).contiguous()
        feats = feats.view((n, -1,) + feats.size()[3:])

        if self.spectral_pos_encoding:
            height, width = feats.shape[-2:]
            rows = torch.linspace(0, 1, height)[None, None, :, None]
            cols = torch.linspace(0, 1, width)[None, None, None, :]
            coords_vert = rows.expand(n, 1, height, width).to(feats)
            coords_hor = cols.expand(n, 1, height, width).to(feats)
            feats = torch.cat((coords_vert, coords_hor, feats), dim=1)

        if self.use_se:
            feats = self.se(feats)

        feats = self.conv_layer(feats)  # (n, 2*c_out, h, w/2+1)
        feats = self.relu(self.bn(feats))

        # Undo the interleaving: (n, 2*c_out, ...) -> (n, c_out, ..., 2) -> complex
        feats = feats.view((n, -1, 2,) + feats.size()[2:])
        feats = feats.permute(0, 1, 3, 4, 2).contiguous()
        spectrum = torch.complex(feats[..., 0], feats[..., 1])

        out_size = x.shape[-3:] if self.ffc3d else x.shape[-2:]
        output = torch.fft.irfftn(spectrum, s=out_size, dim=fft_dim, norm=self.fft_norm)

        if rescale:
            output = F.interpolate(output, size=orig_size,
                                   mode=self.spatial_scale_mode, align_corners=False)

        return output


class SpectralTransform(nn.Module):
    """Global (spectral) branch of a fast Fourier convolution.

    Pipeline: optional 2x average-pool downsampling, a 1x1 conv + BN + ReLU
    that halves the channel count, a ``FourierUnit`` over the whole feature
    map, an optional local Fourier unit (LFU) over a 2x2 tiling of the map,
    and a final 1x1 convolution applied to the sum of the three signals.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output; the internal width is
            ``out_channels // 2``.
        stride: ``2`` enables the average-pool downsampling, any other value
            leaves the resolution unchanged.
        groups: ``groups`` for the 1x1 convolutions and the Fourier units.
        enable_lfu: Whether to add the local Fourier unit.
        **fu_kwargs: Extra keyword arguments for the main ``FourierUnit``
            (the LFU is always built with defaults).
    """

    def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs):
        # bn_layer not used
        super(SpectralTransform, self).__init__()
        self.enable_lfu = enable_lfu
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()

        self.stride = stride
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels //
                      2, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU(inplace=True)
        )
        self.fu = FourierUnit(
            out_channels // 2, out_channels // 2, groups, **fu_kwargs)
        if self.enable_lfu:
            self.lfu = FourierUnit(
                out_channels // 2, out_channels // 2, groups)
        self.conv2 = torch.nn.Conv2d(
            out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)

    def forward(self, x):
        """Apply the spectral transform to a 4-D feature map.

        With ``enable_lfu`` the height and width after downsampling must be
        even, and the internal width ``out_channels // 2`` must be divisible
        by 4, so that the 2x2 tiling yields exactly the channel count the
        LFU was built for.
        """
        x = self.downsample(x)
        x = self.conv1(x)
        output = self.fu(x)

        if self.enable_lfu:
            _, c, h, w = x.shape
            split_no = 2
            # Split each spatial axis by its own half-size. Using h // 2 for
            # both axes only works for square maps: on a non-square map the
            # width splits into the wrong number of chunks and the LFU
            # receives the wrong number of channels.
            split_s_h = h // split_no
            split_s_w = w // split_no
            # Move the four tiles of the first c/4 channels onto the channel axis.
            xs = torch.cat(torch.split(
                x[:, :c // 4], split_s_h, dim=-2), dim=1).contiguous()
            xs = torch.cat(torch.split(xs, split_s_w, dim=-1),
                           dim=1).contiguous()
            xs = self.lfu(xs)
            # Tile the half-size result back to the full spatial size.
            xs = xs.repeat(1, 1, split_no, split_no).contiguous()
        else:
            xs = 0

        output = self.conv2(x + output + xs)

        return output


class FFC(nn.Module):
    """Fast Fourier convolution over a (local, global) channel split.

    Input and output channels are each divided into a local and a global
    part by ``ratio_gin`` / ``ratio_gout``. Four paths connect them:
    local->local, local->global and global->local are ordinary convolutions;
    global->global is a ``SpectralTransform``. Any path whose source or
    destination has zero channels is replaced by ``nn.Identity``.

    Args:
        in_channels, out_channels: Total channel counts.
        kernel_size, stride, padding, dilation, groups, bias: Forwarded to
            the ordinary convolutions. ``stride`` must be 1 or 2.
        ratio_gin: Fraction of input channels that is global.
        ratio_gout: Fraction of output channels that is global.
        enable_lfu: Forwarded to ``SpectralTransform``.
        padding_type: ``padding_mode`` of the ordinary convolutions.
        gated: Learn sigmoid gates for the two cross paths.
        **spectral_kwargs: Extra arguments for ``SpectralTransform``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin, ratio_gout, stride=1, padding=0,
                 dilation=1, groups=1, bias=False, enable_lfu=True,
                 padding_type='reflect', gated=False, **spectral_kwargs):
        super().__init__()

        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride

        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg

        # A path exists only when both of its endpoints have channels.
        has_l2l = in_cl != 0 and out_cl != 0
        has_l2g = in_cl != 0 and out_cg != 0
        has_g2l = in_cg != 0 and out_cl != 0
        has_g2g = in_cg != 0 and out_cg != 0
        conv_args = (kernel_size, stride, padding, dilation, groups, bias)

        l2l_cls = nn.Conv2d if has_l2l else nn.Identity
        self.convl2l = l2l_cls(in_cl, out_cl, *conv_args, padding_mode=padding_type)

        l2g_cls = nn.Conv2d if has_l2g else nn.Identity
        self.convl2g = l2g_cls(in_cl, out_cg, *conv_args, padding_mode=padding_type)

        g2l_cls = nn.Conv2d if has_g2l else nn.Identity
        self.convg2l = g2l_cls(in_cg, out_cl, *conv_args, padding_mode=padding_type)

        g2g_cls = SpectralTransform if has_g2g else nn.Identity
        spectral_groups = 1 if groups == 1 else groups // 2
        self.convg2g = g2g_cls(
            in_cg, out_cg, stride, spectral_groups, enable_lfu, **spectral_kwargs)

        self.gated = gated
        gate_cls = nn.Conv2d if (has_g2l and self.gated) else nn.Identity
        self.gate = gate_cls(in_channels, 2, 1)

    def forward(self, x):
        """Run all four paths.

        Args:
            x: A tensor (local part only) or a tuple ``(x_l, x_g)``; a
                missing global part is represented by ``0``.

        Returns:
            Tuple ``(out_xl, out_xg)``. A part is left as ``0`` when
            ``ratio_gout`` excludes it (1 for local, 0 for global).
        """
        if type(x) is tuple:
            x_l, x_g = x
        else:
            x_l, x_g = x, 0

        if self.gated:
            gate_inputs = [x_l, x_g] if torch.is_tensor(x_g) else [x_l]
            gates = torch.sigmoid(self.gate(torch.cat(gate_inputs, dim=1)))
            g2l_gate, l2g_gate = gates.chunk(2, dim=1)
        else:
            g2l_gate = l2g_gate = 1

        out_xl = out_xg = 0
        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)

        return out_xl, out_xg


class FFC_BN_ACT(nn.Module):
    """``FFC`` followed by a separate norm and activation for each branch.

    A branch that ``ratio_gout`` excludes (the local one when it is 1, the
    global one when it is 0) gets ``nn.Identity`` for both norm and
    activation.

    Args:
        in_channels, out_channels, kernel_size, ratio_gin, ratio_gout,
        stride, padding, dilation, groups, bias, enable_lfu, padding_type:
            Forwarded to ``FFC``.
        norm_layer: Called with a channel count to build each norm.
        activation_layer: Called with ``inplace=True`` to build each
            activation, so it must accept that keyword.
        **kwargs: Extra arguments forwarded to ``FFC``.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, ratio_gin, ratio_gout,
                 stride=1, padding=0, dilation=1, groups=1, bias=False,
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
                 padding_type='reflect',
                 enable_lfu=True, **kwargs):
        super().__init__()
        self.ffc = FFC(in_channels, out_channels, kernel_size,
                       ratio_gin, ratio_gout, stride, padding, dilation,
                       groups, bias, enable_lfu, padding_type=padding_type, **kwargs)

        has_local = ratio_gout != 1
        has_global = ratio_gout != 0
        global_channels = int(out_channels * ratio_gout)
        local_channels = out_channels - global_channels

        local_norm = norm_layer if has_local else nn.Identity
        global_norm = norm_layer if has_global else nn.Identity
        self.bn_l = local_norm(local_channels)
        self.bn_g = global_norm(global_channels)

        local_act = activation_layer if has_local else nn.Identity
        global_act = activation_layer if has_global else nn.Identity
        self.act_l = local_act(inplace=True)
        self.act_g = global_act(inplace=True)

    def forward(self, x):
        """Return the normalised and activated ``(local, global)`` pair produced by the FFC."""
        local_out, global_out = self.ffc(x)
        local_out = self.act_l(self.bn_l(local_out))
        global_out = self.act_g(self.bn_g(global_out))
        return local_out, global_out

from kornia.geometry.transform import rotate
class LearnableSpatialTransformWrapper(nn.Module):
    """Run a module on a reflect-padded, rotated copy of its input.

    The input is padded, rotated by ``angle``, fed to ``impl``, rotated
    back by ``-angle`` and cropped by the same padding amounts.

    Args:
        impl: Wrapped module.
        pad_coef: Padding on each side as a fraction of height / width.
        angle_init_range: The angle is initialised as
            ``torch.rand(1) * angle_init_range``.
        train_angle: If true the angle is an ``nn.Parameter``; otherwise it
            stays a plain tensor attribute.
    """

    def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
        super().__init__()
        self.impl = impl
        init_angle = torch.rand(1) * angle_init_range
        self.angle = nn.Parameter(init_angle, requires_grad=True) if train_angle else init_angle
        self.pad_coef = pad_coef

    def forward(self, x):
        """Apply ``impl`` in the rotated frame to a tensor or a tuple of tensors.

        Raises:
            ValueError: If ``x`` is neither a tensor nor a tuple.
        """
        if torch.is_tensor(x):
            return self.inverse_transform(self.impl(self.transform(x)), x)
        if isinstance(x, tuple):
            rotated_inputs = tuple(self.transform(part) for part in x)
            rotated_outputs = self.impl(rotated_inputs)
            return tuple(
                self.inverse_transform(out_part, in_part)
                for out_part, in_part in zip(rotated_outputs, x)
            )
        raise ValueError(f'Unexpected input type {type(x)}')

    def transform(self, x):
        """Reflect-pad ``x`` by ``pad_coef`` of its size and rotate it by ``angle``."""
        height, width = x.shape[2:]
        pad_h = int(height * self.pad_coef)
        pad_w = int(width * self.pad_coef)
        padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect')
        # .to(padded) matches the angle to the input's device and dtype.
        return rotate(padded, angle=self.angle.to(padded))

    def inverse_transform(self, y_padded_rotated, orig_x):
        """Rotate back by ``-angle`` and crop the padding derived from ``orig_x``'s size."""
        height, width = orig_x.shape[2:]
        pad_h = int(height * self.pad_coef)
        pad_w = int(width * self.pad_coef)

        unrotated = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated))
        full_h, full_w = unrotated.shape[2:]
        return unrotated[:, :, pad_h : full_h - pad_h, pad_w : full_w - pad_w]

class FFCResnetBlock(nn.Module):
    """Residual block made of two 3x3 ``FFC_BN_ACT`` layers.

    Args:
        dim: Channel count of input and output (local + global).
        padding_type: ``padding_type`` of both FFC layers.
        norm_layer: Norm constructor for both FFC layers.
        activation_layer: Activation constructor for both FFC layers.
        dilation: Dilation of both layers; also used as their padding.
        spatial_transform_kwargs: If not ``None``, each layer is wrapped in
            a ``LearnableSpatialTransformWrapper`` built with these kwargs.
        inline: If true, ``forward`` takes and returns one tensor whose
            trailing ``global_in_num`` channels form the global part,
            instead of a ``(local, global)`` tuple.
        **conv_kwargs: Extra arguments for both ``FFC_BN_ACT`` layers
            (e.g. ``ratio_gin`` / ``ratio_gout``).
    """

    def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
                 spatial_transform_kwargs=None, inline=False, **conv_kwargs):
        super().__init__()
        self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
                                norm_layer=norm_layer,
                                activation_layer=activation_layer,
                                padding_type=padding_type,
                                **conv_kwargs)
        self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
                                norm_layer=norm_layer,
                                activation_layer=activation_layer,
                                padding_type=padding_type,
                                **conv_kwargs)
        # Cache the split point while conv1 is still the bare FFC_BN_ACT.
        # Once wrapped below, ``self.conv1.ffc`` no longer exists, so reading
        # it in forward() raised AttributeError for inline=True.
        self.global_in_num = self.conv1.ffc.global_in_num
        if spatial_transform_kwargs is not None:
            self.conv1 = LearnableSpatialTransformWrapper(self.conv1, **spatial_transform_kwargs)
            self.conv2 = LearnableSpatialTransformWrapper(self.conv2, **spatial_transform_kwargs)
        self.inline = inline

    def forward(self, x):
        """Apply both FFC layers and add the input back to each branch.

        In inline mode the negative-index split only separates the branches
        correctly when ``global_in_num`` is greater than zero.
        """
        if self.inline:
            x_l, x_g = x[:, :-self.global_in_num], x[:, -self.global_in_num:]
        else:
            x_l, x_g = x if type(x) is tuple else (x, 0)

        id_l, id_g = x_l, x_g

        x_l, x_g = self.conv1((x_l, x_g))
        x_l, x_g = self.conv2((x_l, x_g))

        x_l, x_g = id_l + x_l, id_g + x_g
        out = x_l, x_g
        if self.inline:
            out = torch.cat(out, dim=1)
        return out


class ConcatTupleLayer(nn.Module):
    """Merge a ``(local, global)`` feature tuple into one tensor."""

    def forward(self, x):
        """Concatenate both parts along channels, or return the local part alone.

        Args:
            x: Tuple ``(x_l, x_g)``; at least one element must be a tensor.

        Returns:
            ``x_l`` when ``x_g`` is not a tensor, otherwise
            ``torch.cat(x, dim=1)``.
        """
        assert isinstance(x, tuple)
        local_part, global_part = x
        has_global = torch.is_tensor(global_part)
        assert torch.is_tensor(local_part) or has_global
        if has_global:
            return torch.cat(x, dim=1)
        return local_part


# class FFC_Decoder(nn.Module):
#     def __init__(self, num_classes, backbone, BatchNorm):
#         super(FFC_Decoder, self).__init__()
#         if backbone == 'resnet' or backbone == 'drn':
#             low_level_inplanes = 256
#         elif backbone == 'xception':
#             low_level_inplanes = 128
#         elif backbone == 'mobilenet':
#             low_level_inplanes = 24
#         else:
#             raise NotImplementedError

#         self.ffc_low = FFC_BN_ACT(low_level_inplanes, 48, kernel_size=7, padding=0, activation_layer=nn.ReLU, ratio_gin=0, ratio_gout=0)
#         self.last_ffc_conv = nn.Sequential(
#             FFCResnetBlock(304, padding_type='reflect', norm_layer=BatchNorm, activation_layer=nn.ReLU, ratio_gin=0.75, ratio_gout=0.75),
#             ConcatTupleLayer(),
#             nn.Conv2d(304, num_classes, kernel_size=7, padding=0) #FFC_BN_ACT(low_level_inplanes, 48, kernel_size=7, padding=0, activation_layer=nn.ReLU)
#         )
#         self.concat_tuple = ConcatTupleLayer()

#     def forward(self, x, low_level_feat):
#         low_level_feat = self.concat_tuple(self.ffc_low(low_level_feat))

#         x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
#         x = torch.cat((x, low_level_feat), dim=1)
#         x = self.last_ffc_conv(x)

#         return x

# from ffc import FFCResnetBlock, ConcatTupleLayer, FFC_BN_ACT # FFC_Decoder
class FFC_Decoder(nn.Module):
    """Segmentation decoder built from fast Fourier convolutions.

    Low-level backbone features are projected to 48 channels, concatenated
    with the upsampled high-level features and classified by an FFC stack
    followed by a 1x1 convolution.

    Args:
        num_classes: Number of output channels (one per class).
        backbone: One of ``'resnet'``, ``'drn'``, ``'xception'`` or
            ``'mobilenet'``; selects the low-level feature channel count.
        BatchNorm: Norm layer constructor used by every FFC layer.

    Raises:
        NotImplementedError: For any other ``backbone`` value.
    """

    def __init__(self, num_classes, backbone, BatchNorm):
        super(FFC_Decoder, self).__init__()
        if backbone == 'resnet' or backbone == 'drn':
            low_level_inplanes = 256
        elif backbone == 'xception':
            low_level_inplanes = 128
        elif backbone == 'mobilenet':
            low_level_inplanes = 24
        else:
            raise NotImplementedError

        # Forward BatchNorm to every FFC layer. Previously only the residual
        # block received it and the two FFC_BN_ACT layers silently fell back
        # to the default nn.BatchNorm2d.
        self.ffc_low = FFC_BN_ACT(low_level_inplanes, 48, kernel_size=1, padding=0,
                                  norm_layer=BatchNorm, activation_layer=nn.ReLU,
                                  ratio_gin=0, ratio_gout=0)
        # 48 projected low-level channels plus 256 high-level channels
        # (assumes the encoder output has 256 channels -- TODO confirm).
        num_channels = 256+48
        self.last_ffc_conv = nn.Sequential(
            # Split the purely local input into local and global branches.
            FFC_BN_ACT(num_channels, num_channels, kernel_size=1, padding=0,
                       norm_layer=BatchNorm, activation_layer=nn.ReLU,
                       ratio_gin=0, ratio_gout=0.9),
            FFCResnetBlock(num_channels, padding_type='reflect', norm_layer=BatchNorm, activation_layer=nn.ReLU, ratio_gin=0.9, ratio_gout=0.9, enable_lfu=False),
            ConcatTupleLayer(),
            nn.Conv2d(num_channels, num_classes, kernel_size=1, padding=0)
        )
        # Not used by forward(); kept so the module's attributes stay unchanged.
        self.concat_tuple = ConcatTupleLayer()

    def forward(self, x, low_level_feat):
        """Return class logits at the low-level features' resolution.

        Args:
            x: High-level feature map; resized bilinearly to the low-level
                spatial size.
            low_level_feat: Low-level backbone feature map.
        """
        low_level_feat = self.ffc_low(low_level_feat)

        x = F.interpolate(x, size=low_level_feat[0].size()[2:], mode='bilinear', align_corners=True)
        x_l = torch.cat((x, low_level_feat[0]), dim=1)
        # ffc_low has ratio_gout=0, so its global output is the placeholder 0.
        x_g = low_level_feat[1]
        x = self.last_ffc_conv( (x_l, x_g) )
        return x

class ColorModel(BaseModel):
    """Segmentation model pairing an encoder with the ``FFC_Decoder``.

    ``scheduler`` and ``optimizer`` are set in ``configure_optimizers``;
    ``criterion``, ``evaluator``, ``val_img_logs`` and ``best_pred`` are
    read here but never assigned in this class (presumably provided by
    ``BaseModel`` -- confirm).

    NOTE(review): the optimizer helpers read ``self.encoder.encoder`` and
    ``self.encoder.aspp``; confirm the default encoder (a bare DeepLab
    backbone) exposes both before relying on ``encoder=None``.

    Args:
        hparams: Hyper-parameter namespace handed to ``BaseModel``.
        encoder: Module returning ``(feature_map, low_level_feats)``. When
            ``None``, the backbone of a freshly built ``DeepLab`` is used.
    """

    def __init__(self, hparams, encoder=None):
        super().__init__(hparams)
        if encoder is None:
            deeplab = DeepLab(num_classes=self.hparams.nclass,
                              backbone=self.hparams.backbone,
                              output_stride=self.hparams.out_stride,
                              sync_bn=self.hparams.sync_bn,
                              freeze_bn=self.hparams.freeze_bn)
            encoder = deeplab.backbone
        self.encoder = encoder
        self.decoder = FFC_Decoder(num_classes=self.hparams.nclass,
                                   backbone=self.hparams.backbone,
                                   BatchNorm=nn.BatchNorm2d)

    def forward(self, x):
        """Return per-class logits resized bilinearly to the input's spatial size."""
        feature_map, low_level_feats = self.encoder(x)
        logits = self.decoder(feature_map, low_level_feats)
        return F.interpolate(logits, size=x.size()[2:], mode='bilinear', align_corners=True)

    def get_loss(self, batch, batch_idx):
        """Compute the training loss for one batch and log it as ``train/ce``.

        Label value 254 is rewritten to 255 in place on ``batch['label']``,
        and the learning-rate scheduler is stepped before the forward pass.
        """
        image, target = batch['image'], batch['label']
        target[target == 254] = 255
        self.scheduler(self.optimizer, batch_idx, self.current_epoch, self.best_pred)
        output = self.forward(image)
        celoss = self.criterion(output, target.long())
        self.log('train/ce', celoss.item())
        return celoss

    def get_loss_val(self, batch, batch_idx):
        """Compute the validation loss, queue one image for logging and update the evaluator.

        Returns:
            ``{'ce_loss': loss}``.
        """
        image, target = batch['image'], batch['label']
        target[target == 254] = 255
        # Which sample of the batch is logged cycles with the batch index.
        i = batch_idx % len(batch['image'])
        output = self.forward(image)
        celoss = self.criterion(output, target.long())

        mask = torch.max(output[i].unsqueeze(0), 1)[1].detach()
        background = image[i].cpu().numpy().transpose([1, 2, 0])
        self.val_img_logs += [wb_mask(background, mask[0].cpu().numpy(), target[i].cpu().numpy())]

        pred = np.argmax(output.data.cpu().numpy(), axis=1)
        target = target.cpu().numpy()
        self.evaluator.add_batch(target, pred)
        return {'ce_loss': celoss}

    def configure_optimizers(self):
        """Create the SGD optimizer (two learning-rate groups) and the LR scheduler."""
        param_groups = [
            {'params': self.get_1x_lr_params(), 'lr': self.hparams.lr},
            {'params': self.get_10x_lr_params(), 'lr': self.hparams.lr * 10},
        ]
        self.optimizer = torch.optim.SGD(param_groups,
                                         momentum=self.hparams.momentum,
                                         weight_decay=self.hparams.weight_decay,
                                         nesterov=self.hparams.nesterov)
        self.scheduler = LR_Scheduler(self.hparams.lr_scheduler, self.hparams.lr,
                                      self.hparams.epochs, self.hparams.num_img_tr)
        return self.optimizer

    @staticmethod
    def _conv_bn_params(modules):
        """Yield trainable parameters of every conv / batch-norm layer inside ``modules``."""
        wanted = (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for root in modules:
            for _name, layer in root.named_modules():
                if isinstance(layer, wanted):
                    for param in layer.parameters():
                        if param.requires_grad:
                            yield param

    def get_1x_lr_params(self):
        """Yield the parameters trained at the base learning rate (``self.encoder.encoder``)."""
        yield from self._conv_bn_params([self.encoder.encoder])

    def get_10x_lr_params(self):
        """Yield the parameters trained at 10x the base learning rate (decoder and ``self.encoder.aspp``)."""
        yield from self._conv_bn_params([self.decoder, self.encoder.aspp])


In [None]:
# Override the run configuration for this experiment.
args.batch_size = 10
args.limit_dataset = False
args.nclass = nclass
args.lr = 0.01
args.epochs = 1
# train_loader, val_loader, test_loader, nclanss = make_data_loader(args, **kwargs)
# Train and validate on the repeated single-batch loaders rather than the full dataset loaders.
train_loader, val_loader = single_train_loader, single_val_loader

# Batches per epoch; ColorModel.configure_optimizers passes it to LR_Scheduler as hparams.num_img_tr.
args.num_img_tr=len(train_loader)
model = ColorModel(args, encoder=DeepLabEncoder(args))
# model = SegModelDebug(args)
# model.configure_optimizers()
# for param in model.encoder.model.backbone.parameters():
#   param.requires_grad = False

# Log this run to the 'Color-Query' W&B project under the run name 'ffc_dlv3+'.
wandb_logger = WandbLogger(project='Color-Query', name='ffc_dlv3+') # proto_dlv3_rand-space  deeplabv3+_seeds  prototype_aspp_seeds  deeplabv3+_color-low-feats

# One GPU, one epoch, no sanity-check validation pass, progress bar disabled.
trainer = pl.Trainer(gpus=1, max_epochs=1, logger=wandb_logger, log_every_n_steps=10, num_sanity_val_steps=0, progress_bar_refresh_rate=0, accumulate_grad_batches=1)
results = trainer.fit(model, train_loader, val_loader)
# Close the W&B run once fitting has finished.
wandb.finish()

In [None]:

# import gc
# for obj in gc.get_objects():
#     try:
#         if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
#             print(type(obj), obj.size())
#     except:
#         pass