In [7]:
! pip install torch_optimizer
! pip install codecarbon
! pip install git+https://github.com/sovrasov/flops-counter.pytorch.git
! pip install skorch
! pip install gdown
from IPython.display import clear_output
clear_output()


In [None]:
from IPython.display import clear_output
import os.path 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.ops
from torch.autograd import Variable, Function
import numpy as np
import matplotlib.pyplot as plt
import pickle
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch_optimizer as optim
from torch.optim import Adam
from codecarbon.emissions_tracker import EmissionsTracker
import os
import shutil
from sklearn.metrics import accuracy_score, cohen_kappa_score, jaccard_score, precision_score, recall_score, f1_score, classification_report
from ptflops import get_model_complexity_info
import pandas as pd
from PIL import Image
import importlib
import torchvision
import torch.utils.data
import math
from torch.nn.parameter import Parameter
from torch.nn.functional import pad
from torch.nn.modules import Module
from torch.nn import ConvTranspose2d
from torch.nn.modules.utils import _single, _pair, _triple
import math
import gdown
import time

In [9]:
# Shared Google Drive folder that is downloaded when no local copy exists.
url = "https://drive.google.com/drive/folders/17rQ2ALkfZPNtgkGmzr0tYVKYCPjPc5X7?usp=sharing"

# Skip the download when a ./datasets entry is already present.
if not os.path.exists("./datasets"):
    gdown.download_folder(url, quiet=False, use_cookies=False)

# Create the working directories on first run only.
for _required_dir in ("model", "output"):
    if os.path.exists("./" + _required_dir):
        continue
    os.mkdir(_required_dir)

# Drop the download log from the cell output.
clear_output()

Convolutions

In [None]:
#conv2d
def conv2d(in_channels, out_channels, kernel_size, padding, bias=False):
    """Factory for a plain ``nn.Conv2d`` layer.

    Exists so that a standard convolution can be created through the same
    call signature as the other convolution factories in this notebook.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size forwarded to ``nn.Conv2d``.
        padding: padding forwarded to ``nn.Conv2d``.
        bias: whether the layer gets a learnable bias (default ``False``).

    Returns:
        The freshly constructed ``nn.Conv2d`` module.
    """
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=padding,
        bias=bias,
    )
    return layer

In [None]:
#Spatially separable conv
class SpatialSeparableConv2d(nn.Module):
    """A k x k convolution factorised into a (k, 1) and a (1, k) convolution.

    The first convolution is depthwise (``groups=in_channels``) and keeps the
    channel count; the second one maps ``in_channels`` to ``out_channels``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: length of the 1-D kernels; must be an ``int``.
        padding: padding applied along the axis each kernel spans.
        bias: whether both convolutions get a learnable bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False):
        super().__init__()
        assert isinstance(kernel_size, int)
        vertical_kernel = (kernel_size, 1)
        horizontal_kernel = (1, kernel_size)
        # Depthwise pass along the height axis.
        self.conv1 = nn.Conv2d(
            in_channels, in_channels, vertical_kernel,
            padding=(padding, 0), groups=in_channels, bias=bias,
        )
        # Channel-mixing pass along the width axis.
        self.conv2 = nn.Conv2d(
            in_channels, out_channels, horizontal_kernel,
            padding=(0, padding), bias=bias,
        )

    def forward(self, x):
        """Apply the vertical convolution followed by the horizontal one."""
        return self.conv2(self.conv1(x))


In [None]:
#gaussian dynamic convolution
class HalfNormal(object):
    """Sampler returning absolute values of zero-mean normal draws.

    Args:
        scale: standard deviation of the underlying normal distribution.
        seed: value passed to ``torch.manual_seed`` on construction; note that
            this reseeds torch's global random number generator.
        device: device the samples are moved to.
    """

    def __init__(self, scale, seed, device):
        torch.manual_seed(seed)
        self.scale, self.device = scale, device

    def sample(self, sample_shape=torch.Size()):
        """Draw a tensor of shape ``sample_shape`` with non-negative entries."""
        draws = torch.zeros(sample_shape).to(self.device)
        draws.normal_(mean=0, std=self.scale)
        return draws.abs()
class GFDConv(nn.Module):
    """Convolution that mixes each pixel with eight randomly displaced samples.

    For every spatial position, eight sampling coordinates are drawn at random
    (see ``sample_process``); the features gathered at those coordinates are
    concatenated with the input along the channel axis and fused by a 1x1
    convolution with ``9 * in_features`` input channels.

    Args:
        in_features: number of input channels.
        out_features: number of output channels.
        bias: whether the 1x1 fusion convolution has a bias.
        scale: standard deviation handed to the ``HalfNormal`` offset sampler.
        device: device on which the helper tensors of this module are created.
        seed: seed handed to ``HalfNormal``.
        fix_w, fix_h: when both are non-zero, the base coordinate grid and the
            size vector are precomputed once for that width/height.

    NOTE(review): ``direction_basis``, ``base_coor`` and ``size`` are plain
    tensor attributes, not registered buffers, so they stay on ``device`` even
    if the module is moved with ``.to(...)`` -- confirm this is intended.
    """
    def __init__(self, in_features,out_features, bias=False, scale=0.1, device='cpu', seed=307, fix_w=0, fix_h=0):
        super(GFDConv, self).__init__()
        # 1x1 fusion over the input plus its eight displaced copies.
        self.conv = nn.Conv2d(9 * in_features, out_features, 1, bias=bias)
        self.dis = HalfNormal(scale, seed, device)
        self.size = None
        self.device = device
        # Signs of the eight displacement directions, flattened to 16 values:
        # the first eight are the x components, the last eight the y components.
        self.direction_basis = torch.tensor([[-1, 1, -1, 1, -1, 1, 0, 0],
                                             [-1, -1, 1, 1, 0, 0, -1, 1]]).float().view(-1).to(self.device)
        if fix_w != 0 and fix_h != 0:
            # Precompute the (fix_h, fix_w, 16) grid of per-pixel x and y coordinates.
            yy = torch.linspace(0, fix_h - 1, steps=fix_h).unsqueeze(1).repeat(1, fix_w).unsqueeze(-1).to(self.device)
            xx = torch.linspace(0, fix_w - 1, steps=fix_w).unsqueeze(0).repeat(fix_h, 1).unsqueeze(-1).to(self.device)
            self.base_coor = torch.cat([xx.repeat(1, 1, 8), yy.repeat(1, 1, 8)], dim=-1).to(self.device)
            self.size = torch.tensor([fix_w] * 8 + [fix_h] * 8).float().to(self.device)
    def forward(self, feat):
        """Gather the randomly displaced features and fuse them with a 1x1 conv.

        New sampling coordinates are drawn on every call.
        """
        sample_coor = self.sample_process(feat.size(2), feat.size(3))
        sample_coor_x = sample_coor[:, :, :8]
        sample_coor_y = sample_coor[:, :, 8:]
        # Pad by one pixel on each side; the sampled coordinates are already
        # shifted by one to index into this padded tensor.
        feat = F.pad(feat, [1, 1, 1, 1]).to(sample_coor.device)
        # Advanced indexing with two (h, w, 8) index tensors on the spatial axes.
        offset_feat = feat[:, :, sample_coor_y, sample_coor_x]
        # Fold the eight samples into the channel axis, then pad to the padded size of `feat`.
        offset_feat = F.pad(offset_feat.permute(0, 4, 1, 2, 3).contiguous()
                            .view(offset_feat.size(0), -1, offset_feat.size(2), offset_feat.size(3)), [1, 1, 1, 1])
        feat = torch.cat([feat, offset_feat], dim=1)
        # Fuse and crop the one-pixel border introduced by the padding above.
        feat = self.conv.to(feat.device)(feat)[:, :, 1:-1, 1:-1]
        return feat
    def sample_process(self, h, w):
        """Draw integer sampling coordinates for an ``h`` x ``w`` feature map.

        Returns:
            A ``long`` tensor of shape ``(h, w, 16)``: entries ``[:8]`` are x
            coordinates and ``[8:]`` are y coordinates, clamped to the map and
            shifted by one.
        """
        if self.size is None:
            # No fixed size configured: build the coordinate grid for this call.
            yy = torch.linspace(0, h - 1, steps=h).unsqueeze(1).repeat(1, w).unsqueeze(-1).to(self.device)
            xx = torch.linspace(0, w - 1, steps=w).unsqueeze(0).repeat(h, 1).unsqueeze(-1).to(self.device)
            base_coor = torch.cat([xx.repeat(1, 1, 8), yy.repeat(1, 1, 8)], dim=-1).to(self.device)
            size = torch.tensor([w] * 8 + [h] * 8).float().to(self.device)
        else:
            size = self.size
            base_coor = self.base_coor
        # Non-negative random magnitudes, signed per direction and scaled by the map size.
        sample_ = self.dis.sample(torch.Size([h, w, 16]))
        offset = sample_ * self.direction_basis * size
        sample_coor = base_coor + offset
        # Keep x within [0, w - 1] and y within [0, h - 1].
        sample_coor[:, :, :8] = torch.clamp(sample_coor[:, :, :8], min=0, max=w - 1)
        sample_coor[:, :, 8:] = torch.clamp(sample_coor[:, :, 8:], min=0, max=h - 1)
        # +1 accounts for the one-pixel padding applied in forward().
        return (sample_coor + 1).long()
def gaussian_dynamic_conv(in_channels, out_channels, kernel_size, padding, bias=False):
    """Factory for a ``GFDConv`` layer with this notebook's fixed settings.

    The signature matches the other convolution factories. ``kernel_size`` and
    ``padding`` are accepted for that reason only and are not used.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: ignored.
        padding: ignored.
        bias: whether the fusion convolution of ``GFDConv`` has a bias.

    Returns:
        A ``GFDConv`` module built with ``scale=2``, ``seed=307`` and no fixed
        width/height.
    """
    fix_w = 0
    fix_h = 0
    seed = 307
    scale = 2
    # Use the GPU when one is present; fall back to the CPU instead of
    # failing on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return GFDConv(in_channels, out_channels, bias, scale, device, seed, fix_w, fix_h)



In [None]:
#deformable convolution
class DeformableConv2d(nn.Module):
    """Modulated deformable convolution built on ``torchvision.ops.deform_conv2d``.

    Two auxiliary convolutions predict, from the input itself, the sampling
    offsets and a modulation mask for every kernel tap; a third convolution
    only provides the weight (and optional bias) used by ``deform_conv2d``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: ``int`` or ``tuple`` kernel size.
        stride: ``int`` or ``tuple`` stride.
        padding: padding passed to all convolutions.
        dilation: dilation passed to all convolutions.
        bias: whether the main convolution has a bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False):
        super().__init__()
        assert type(kernel_size) in (tuple, int)
        if type(kernel_size) != tuple:
            kernel_size = (kernel_size, kernel_size)
        self.stride = stride if type(stride) == tuple else (stride, stride)
        self.padding = padding
        self.dilation = dilation

        taps = kernel_size[0] * kernel_size[1]
        shared = dict(kernel_size=kernel_size, stride=stride,
                      padding=self.padding, dilation=self.dilation)

        # Two offset values per kernel tap; zero-initialised so the layer
        # starts out sampling on the regular grid.
        self.offset_conv = nn.Conv2d(in_channels, 2 * taps, bias=True, **shared)
        for tensor in (self.offset_conv.weight, self.offset_conv.bias):
            nn.init.constant_(tensor, 0.)

        # One modulation value per kernel tap, also zero-initialised.
        self.modulator_conv = nn.Conv2d(in_channels, 1 * taps, bias=True, **shared)
        for tensor in (self.modulator_conv.weight, self.modulator_conv.bias):
            nn.init.constant_(tensor, 0.)

        # Holds the weight/bias that deform_conv2d applies.
        self.regular_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                      bias=bias, **shared)

    def forward(self, x):
        """Predict offsets and mask from ``x`` and run the deformable convolution."""
        offsets = self.offset_conv(x)
        # Sigmoid scaled to (0, 2): equals 1 for the zero-initialised modulator.
        mask = 2. * torch.sigmoid(self.modulator_conv(x))
        return torchvision.ops.deform_conv2d(
            input=x,
            offset=offsets,
            weight=self.regular_conv.weight,
            bias=self.regular_conv.bias,
            padding=self.padding,
            mask=mask,
            stride=self.stride,
            dilation=self.dilation,
        )
def deformable_conv(in_channels, out_channels, kernel_size, padding, bias=False):
    """Factory for a ``DeformableConv2d`` with stride 1 and dilation 1.

    Provides the same call signature as the other convolution factories.
    """
    return DeformableConv2d(
        in_channels,
        out_channels,
        kernel_size,
        1,        # stride
        padding,
        1,        # dilation
        bias,
    )

In [None]:
#Adaptive deformable convolution
class DeformConv2d(nn.Module):
    """Deformable convolution with hand-written bilinear sampling.

    An auxiliary convolution predicts, for every output position, one (row,
    column) offset per kernel tap. The input is sampled bilinearly at the
    displaced positions, the samples of each output position are laid out as
    a ``ks x ks`` tile, and a convolution with ``stride == kernel_size``
    reduces every tile to one output value.

    Args:
        inc: number of input channels.
        outc: number of output channels.
        kernel_size: (square) kernel size ``ks``.
        padding: zero padding added around the input before sampling.
        stride: stride of the offset (and modulation) convolution.
        bias: passed to the final ``nn.Conv2d``; a falsy value disables the bias.
        modulation: when true, a second auxiliary convolution predicts a
            sigmoid weight per kernel tap that scales the samples.
    """
    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):
        super(DeformConv2d, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.ZeroPad2d(padding)
        # stride == kernel_size: each ks x ks tile built in forward() yields one output value.
        self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)
        # This tensor is not used inside the class. It used to be created with
        # an unconditional .cuda(), which made the constructor fail on hosts
        # without a GPU; it now lives on the GPU only when one is available.
        weight_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.conv_weight_m = torch.ones([outc, inc, kernel_size, kernel_size],
                                        requires_grad=False, device=weight_device)
        # Predicts 2 * ks * ks offsets per position (row offsets first, then column offsets).
        self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
        nn.init.constant_(self.p_conv.weight, 0)
        self.p_conv.register_backward_hook(self._set_lr)
        self.modulation = modulation
        if modulation:
            self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
            nn.init.constant_(self.m_conv.weight, 0.5)
            self.m_conv.register_backward_hook(self._set_lr)
    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        # NOTE(review): this hook only rebinds its local names to generator
        # objects and returns None, so it does not change any gradient. It is
        # kept unchanged to preserve the current training behaviour.
        grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
        grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))
    def forward(self, x):
        """Sample ``x`` at the learned positions and convolve the samples."""
        offset = self.p_conv(x)
        if self.modulation:
            m = torch.sigmoid(self.m_conv(x))
        dtype = offset.data.type()
        ks = self.kernel_size
        N = offset.size(1) // 2
        if self.padding:
            x = self.zero_padding(x)
        # Absolute sampling positions, shape (b, 2N, h, w) -> (b, h, w, 2N).
        p = self._get_p(offset, dtype)
        p = p.contiguous().permute(0, 2, 3, 1)
        # Integer corners around each position; detached so the rounding
        # itself does not take part in autograd.
        q_lt = p.detach().floor()
        q_rb = q_lt + 1
        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], -1)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], -1)
        # Positions falling into the padding band are snapped to their floor.
        mask = torch.cat([p[..., :N].lt(self.padding)+p[..., :N].gt(x.size(2)-1-self.padding),
                          p[..., N:].lt(self.padding)+p[..., N:].gt(x.size(3)-1-self.padding)], dim=-1).type_as(p)
        mask = mask.detach()
        floor_p = p - (p - torch.floor(p))
        p = p*(1-mask) + floor_p*mask
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)
        # Bilinear interpolation weights of the four corners.
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
        # Input values at the four corners, each of shape (b, c, h, w, N).
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt
        if self.modulation:
            # Broadcast the per-tap weights over the channel axis.
            m = m.contiguous().permute(0, 2, 3, 1)
            m = m.unsqueeze(dim=1)
            m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
            x_offset *= m
        x_offset = self._reshape_x_offset(x_offset, ks)
        out = self.conv(x_offset)
        return out
    def _get_p_n(self, N, dtype):
        """Relative kernel-tap positions, shape ``(1, 2N, 1, 1)``."""
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
        p_n = p_n.view(1, 2*N, 1, 1).type(dtype)
        return p_n
    def _get_p_0(self, h, w, N, dtype):
        """Kernel-centre positions for an ``h`` x ``w`` output, shape ``(1, 2N, h, w)``."""
        p_0_x, p_0_y = torch.meshgrid(
            torch.arange(1, h*self.stride+1, self.stride),
            torch.arange(1, w*self.stride+1, self.stride))
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
        return p_0
    def _get_p(self, offset, dtype):
        """Absolute sampling positions: centre + tap position + learned offset."""
        N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)
        p_n = self._get_p_n(N, dtype)
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + p_n + offset
        return p
    def _get_x_q(self, x, q, N):
        """Gather ``x`` at the integer positions ``q``; returns ``(b, c, h, w, N)``."""
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        x = x.contiguous().view(b, c, -1)
        # Flattened index: row * width + column.
        index = q[..., :N]*padded_w + q[..., N:]
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
        return x_offset
    @staticmethod
    def _reshape_x_offset(x_offset, ks):
        """Rearrange ``(b, c, h, w, N)`` samples into ``(b, c, h*ks, w*ks)`` tiles."""
        b, c, h, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
        x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)
        return x_offset

def Adaptive_deformable_conv(in_channels, out_channels, kernel_size, padding, bias=False):
    """Factory for a ``DeformConv2d`` with stride 1 and no modulation.

    Provides the same call signature as the other convolution factories.
    The ``bias`` argument is ignored: the layer is always built with
    ``bias=None``.
    """
    return DeformConv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding,
        1,       # stride
        None,    # bias (the caller's value is deliberately overridden)
        False,   # modulation
    )

In [None]:
#Asymmetric convolution
class orgACBlock(nn.Module):
    """Asymmetric convolution block with a roughly even channel split.

    Three parallel convolutions (square k x k, vertical k x 1, horizontal
    1 x k) are applied to the same input and concatenated along the channel
    axis. The square and vertical branches each produce
    ``int(out_channels * 0.33)`` channels; the horizontal branch produces the
    remainder so the total equals ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
        super().__init__()
        n_square = int(out_channels * 0.33)
        n_vertical = int(out_channels * 0.33)
        n_horizontal = out_channels - n_square - n_vertical
        # Settings shared by the three branches.
        common = dict(stride=stride, dilation=dilation, groups=groups, bias=bias)
        self.square_conv = nn.Conv2d(in_channels, n_square, (kernel_size, kernel_size),
                                     padding=padding, **common)
        self.ver_conv = nn.Conv2d(in_channels, n_vertical, (kernel_size, 1),
                                  padding=(padding, 0), **common)
        self.hor_conv = nn.Conv2d(in_channels, n_horizontal, (1, kernel_size),
                                  padding=(0, padding), **common)

    def forward(self, x):
        """Run the three branches on ``x`` and concatenate their outputs."""
        branches = (self.square_conv(x), self.ver_conv(x), self.hor_conv(x))
        return torch.cat(branches, dim=1)

class ACBlock(nn.Module):
    """Asymmetric convolution block that favours the square branch.

    Three parallel convolutions (square k x k, vertical k x 1, horizontal
    1 x k) are applied to the same input and concatenated along the channel
    axis. The square branch produces ``int(out_channels * 0.5)`` channels, the
    vertical branch ``int(out_channels * 0.33)`` and the horizontal branch the
    remainder.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
        super().__init__()
        n_square = int(out_channels * 0.5)
        n_vertical = int(out_channels * 0.33)
        n_horizontal = out_channels - n_square - n_vertical
        # Settings shared by the three branches.
        common = dict(stride=stride, dilation=dilation, groups=groups, bias=bias)
        self.square_conv = nn.Conv2d(in_channels, n_square, (kernel_size, kernel_size),
                                     padding=padding, **common)
        self.ver_conv = nn.Conv2d(in_channels, n_vertical, (kernel_size, 1),
                                  padding=(padding, 0), **common)
        self.hor_conv = nn.Conv2d(in_channels, n_horizontal, (1, kernel_size),
                                  padding=(0, padding), **common)

    def forward(self, x):
        """Run the three branches on ``x`` and concatenate their outputs."""
        branches = (self.square_conv(x), self.ver_conv(x), self.hor_conv(x))
        return torch.cat(branches, dim=1)



class GD_AC(nn.Module):
    """Asymmetric block whose square branch is a Gaussian dynamic convolution.

    The square branch is created through ``gaussian_dynamic_conv`` (always
    without bias) and produces ``int(out_channels * 0.5)`` channels. The
    vertical (k x 1) and horizontal (1 x k) branches are regular convolutions
    producing ``int(out_channels * 0.33)`` channels and the remainder.
    ``stride``, ``dilation``, ``groups`` and ``bias`` only affect the
    vertical and horizontal branches.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
        super().__init__()
        assert isinstance(kernel_size, int)
        n_square = int(out_channels * 0.5)
        n_vertical = int(out_channels * 0.33)
        n_horizontal = out_channels - n_square - n_vertical
        # Settings shared by the two regular branches.
        common = dict(stride=stride, dilation=dilation, groups=groups, bias=bias)
        self.square_conv = gaussian_dynamic_conv(in_channels, n_square, kernel_size, padding, bias=False)
        self.ver_conv = nn.Conv2d(in_channels, n_vertical, (kernel_size, 1),
                                  padding=(padding, 0), **common)
        self.hor_conv = nn.Conv2d(in_channels, n_horizontal, (1, kernel_size),
                                  padding=(0, padding), **common)

    def forward(self, x):
        """Run the three branches on ``x`` and concatenate their outputs."""
        branches = (self.square_conv(x), self.ver_conv(x), self.hor_conv(x))
        return torch.cat(branches, dim=1)

class D_AC(nn.Module):
    """Asymmetric block whose square branch is a deformable convolution.

    The square branch is created through ``deformable_conv`` (always without
    bias) and produces ``int(out_channels * 0.5)`` channels. The vertical
    (k x 1) and horizontal (1 x k) branches are regular convolutions
    producing ``int(out_channels * 0.33)`` channels and the remainder.
    ``stride``, ``dilation``, ``groups`` and ``bias`` only affect the
    vertical and horizontal branches.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
        super().__init__()
        assert isinstance(kernel_size, int)
        n_square = int(out_channels * 0.5)
        n_vertical = int(out_channels * 0.33)
        n_horizontal = out_channels - n_square - n_vertical
        # Settings shared by the two regular branches.
        common = dict(stride=stride, dilation=dilation, groups=groups, bias=bias)
        self.square_conv = deformable_conv(in_channels, n_square, kernel_size, padding, bias=False)
        self.ver_conv = nn.Conv2d(in_channels, n_vertical, (kernel_size, 1),
                                  padding=(padding, 0), **common)
        self.hor_conv = nn.Conv2d(in_channels, n_horizontal, (1, kernel_size),
                                  padding=(0, padding), **common)

    def forward(self, x):
        """Run the three branches on ``x`` and concatenate their outputs."""
        branches = (self.square_conv(x), self.ver_conv(x), self.hor_conv(x))
        return torch.cat(branches, dim=1)

class AD_AC(nn.Module):
    """Asymmetric block whose square branch is an adaptive deformable convolution.

    The square branch is created through ``Adaptive_deformable_conv`` and
    produces ``int(out_channels * 0.5)`` channels. The vertical (k x 1) and
    horizontal (1 x k) branches are regular convolutions producing
    ``int(out_channels * 0.33)`` channels and the remainder. ``stride``,
    ``dilation``, ``groups`` and ``bias`` only affect the vertical and
    horizontal branches.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
        super().__init__()
        assert isinstance(kernel_size, int)
        n_square = int(out_channels * 0.5)
        n_vertical = int(out_channels * 0.33)
        n_horizontal = out_channels - n_square - n_vertical
        # Settings shared by the two regular branches.
        common = dict(stride=stride, dilation=dilation, groups=groups, bias=bias)
        self.square_conv = Adaptive_deformable_conv(in_channels, n_square, kernel_size, padding, bias=False)
        self.ver_conv = nn.Conv2d(in_channels, n_vertical, (kernel_size, 1),
                                  padding=(padding, 0), **common)
        self.hor_conv = nn.Conv2d(in_channels, n_horizontal, (1, kernel_size),
                                  padding=(0, padding), **common)

    def forward(self, x):
        """Run the three branches on ``x`` and concatenate their outputs."""
        branches = (self.square_conv(x), self.ver_conv(x), self.hor_conv(x))
        return torch.cat(branches, dim=1)

In [None]:
# Registry of convolution factories. The `index` argument of the network
# building blocks selects one entry; every entry is called as
# factory(in_channels, out_channels, kernel_size=..., padding=..., bias=...).
conv_function = [
    conv2d,                    # 0
    SpatialSeparableConv2d,    # 1
    gaussian_dynamic_conv,     # 2
    deformable_conv,           # 3
    Adaptive_deformable_conv,  # 4
    orgACBlock,                # 5
    ACBlock,                   # 6
    GD_AC,                     # 7
    D_AC,                      # 8
    AD_AC,                     # 9
]

Unet

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive (convolution -> BatchNorm -> ReLU) stages.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        mid_channels: channels between the two stages; a falsy value means
            ``out_channels``.
        index: position in ``conv_function`` of the convolution factory used
            for both stages.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None,index=0):
        super().__init__()
        mid_channels = mid_channels if mid_channels else out_channels
        make_conv = conv_function[index]
        stages = [
            make_conv(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            make_conv(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x``."""
        return self.double_conv(x)
class Down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a ``DoubleConv``.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        index: position in ``conv_function`` of the convolution used by the
            ``DoubleConv`` (default 0, the plain convolution).
    """
    def __init__(self, in_channels, out_channels, index=0):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels, index=index)
        )
    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Decoder stage: upsample, align with the skip tensor, concatenate, ``DoubleConv``.

    Args:
        in_channels: channels after concatenating skip and upsampled tensors.
        out_channels: channels of the output.
        bilinear: use bilinear upsampling instead of a transposed convolution.
        index: position in ``conv_function`` of the convolution used by the
            ``DoubleConv`` (default 0, the plain convolution).
    """
    def __init__(self, in_channels, out_channels, bilinear=True, index=0):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2, index=index)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels, index=index)
    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the spatial size of ``x2`` and fuse both."""
        x1 = self.up(x1)
        # Pad the upsampled tensor so odd-sized skip tensors still line up.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class OutConv(nn.Module):
    """Final 1x1 convolution mapping features to per-class logits."""
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
    def forward(self, x):
        return self.conv(x)
class UNet(nn.Module):
    """U-Net with a selectable convolution type.

    Args:
        n_channels: channels of the input image.
        n_classes: number of output channels (logits).
        bilinear: use bilinear upsampling in the decoder instead of
            transposed convolutions.
        index: position in ``conv_function`` of the convolution used in every
            ``DoubleConv`` of the network.
    """
    def __init__(self, n_channels, n_classes, bilinear=False,index = 0):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.index = index
        self._use_checkpointing = False
        # `index` must be passed by keyword: the third positional parameter of
        # DoubleConv is `mid_channels`, not the convolution selector.
        self.inc = (DoubleConv(n_channels, 64, index=index))
        self.down1 = (Down(64, 128, index=index))
        self.down2 = (Down(128, 256, index=index))
        self.down3 = (Down(256, 512, index=index))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor, index=index))
        self.up1 = (Up(1024, 512 // factor, bilinear, index=index))
        self.up2 = (Up(512, 256 // factor, bilinear, index=index))
        self.up3 = (Up(256, 128 // factor, bilinear, index=index))
        self.up4 = (Up(128, 64, bilinear, index=index))
        self.outc = (OutConv(64, n_classes))
    def _run(self, stage, *inputs):
        """Call ``stage`` directly, or through activation checkpointing when enabled."""
        # getattr keeps instances created before this flag existed working.
        if getattr(self, '_use_checkpointing', False):
            from torch.utils.checkpoint import checkpoint
            # Non-reentrant mode also yields parameter gradients for the first
            # stage, whose input tensor does not require grad.
            return checkpoint(stage, *inputs, use_reentrant=False)
        return stage(*inputs)
    def forward(self, x):
        """Return the logits for the input batch ``x``."""
        x1 = self._run(self.inc, x)
        x2 = self._run(self.down1, x1)
        x3 = self._run(self.down2, x2)
        x4 = self._run(self.down3, x3)
        x5 = self._run(self.down4, x4)
        x = self._run(self.up1, x5, x4)
        x = self._run(self.up2, x, x3)
        x = self._run(self.up3, x, x2)
        x = self._run(self.up4, x, x1)
        logits = self._run(self.outc, x)
        return logits
    def use_checkpointing(self):
        """Enable activation checkpointing for every stage of the network.

        Intermediate activations are then recomputed during the backward pass
        instead of being stored. The sub-modules themselves are left untouched,
        so parameter names and ``state_dict`` keys do not change.
        """
        self._use_checkpointing = True

Attention-Unet

In [None]:
class conv_block(nn.Module):
    """Two consecutive (convolution -> BatchNorm -> ReLU) stages with bias.

    Args:
        in_ch: channels of the input.
        out_ch: channels produced by both stages.
        index: position in ``conv_function`` of the convolution factory used
            for both stages.
    """

    def __init__(self, in_ch, out_ch,index=0):
        super().__init__()
        make_conv = conv_function[index]
        stages = [
            make_conv(in_ch, out_ch, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            make_conv(out_ch, out_ch, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x``."""
        return self.conv(x)

class up_conv(nn.Module):
    """Upsampling stage: 2x upsample -> 3x3 convolution -> BatchNorm -> ReLU.

    Args:
        in_ch: channels of the input.
        out_ch: channels of the output.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.up = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` upsampled by two and projected to ``out_ch`` channels."""
        return self.up(x)

class Attention_block(nn.Module):
    """Additive attention gate for skip connections.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels, summed, passed through ReLU and reduced to a single
    sigmoid attention map that rescales ``x``.

    Args:
        F_g: channels of the gating signal ``g``.
        F_l: channels of the skip features ``x``.
        F_int: channels of the intermediate projection.
    """
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()

        # W_g projects the gating signal, so it must accept F_g channels.
        # (The channel counts of W_g and W_x were previously swapped, which
        # only worked when F_g == F_l.)
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        # W_x projects the skip features, so it must accept F_l channels.
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by the attention map computed from ``g`` and ``x``.

        ``g`` and ``x`` must share batch and spatial dimensions.
        """
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        out = x * psi
        return out
class AttenUNet(nn.Module):
    """Attention U-Net with a selectable convolution type.

    Five encoder levels (64 to 1024 channels) are followed by four decoder
    levels. At each decoder level the upsampled features gate the matching
    encoder features through an ``Attention_block`` before both are
    concatenated and refined by a ``conv_block``.

    Args:
        n_channels: channels of the input image.
        n_classes: number of output channels (logits).
        bilinear: accepted for signature parity with the other networks; unused.
        index: position in ``conv_function`` of the convolution used in every
            ``conv_block``.
    """

    def __init__(self, n_channels, n_classes, bilinear=False,index = 0):
        super().__init__()
        self.index = index
        # Channel widths of the five levels: 64, 128, 256, 512, 1024.
        w0, w1, w2, w3, w4 = (64 * 2 ** level for level in range(5))

        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder.
        self.Conv1 = conv_block(n_channels, w0, index)
        self.Conv2 = conv_block(w0, w1, index)
        self.Conv3 = conv_block(w1, w2, index)
        self.Conv4 = conv_block(w2, w3, index)
        self.Conv5 = conv_block(w3, w4, index)

        # Decoder, deepest level first.
        self.Up5 = up_conv(w4, w3)
        self.Att5 = Attention_block(F_g=w3, F_l=w3, F_int=w2)
        self.Up_conv5 = conv_block(w4, w3, index)

        self.Up4 = up_conv(w3, w2)
        self.Att4 = Attention_block(F_g=w2, F_l=w2, F_int=w1)
        self.Up_conv4 = conv_block(w3, w2, index)

        self.Up3 = up_conv(w2, w1)
        self.Att3 = Attention_block(F_g=w1, F_l=w1, F_int=w0)
        self.Up_conv3 = conv_block(w2, w1, index)

        self.Up2 = up_conv(w1, w0)
        self.Att2 = Attention_block(F_g=w0, F_l=w0, F_int=32)
        self.Up_conv2 = conv_block(w1, w0, index)

        self.Conv = nn.Conv2d(w0, n_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return the logits for the input batch ``x``."""
        # Encoder path.
        enc1 = self.Conv1(x)
        enc2 = self.Conv2(self.Maxpool1(enc1))
        enc3 = self.Conv3(self.Maxpool2(enc2))
        enc4 = self.Conv4(self.Maxpool3(enc3))
        enc5 = self.Conv5(self.Maxpool4(enc4))

        # Decoder path: upsample, gate the skip features, concatenate, refine.
        dec = self.Up5(enc5)
        gated = self.Att5(g=dec, x=enc4)
        dec = self.Up_conv5(torch.cat((gated, dec), dim=1))

        dec = self.Up4(dec)
        gated = self.Att4(g=dec, x=enc3)
        dec = self.Up_conv4(torch.cat((gated, dec), dim=1))

        dec = self.Up3(dec)
        gated = self.Att3(g=dec, x=enc2)
        dec = self.Up_conv3(torch.cat((gated, dec), dim=1))

        dec = self.Up2(dec)
        gated = self.Att2(g=dec, x=enc1)
        dec = self.Up_conv2(torch.cat((gated, dec), dim=1))

        return self.Conv(dec)

UNet++

In [None]:
class conv_block_nested(nn.Module):
    """Two consecutive (convolution -> BatchNorm -> ReLU) stages for UNet++.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        mid_channels: channels between the two stages; a falsy value means
            ``out_channels``.
        index: position in ``conv_function`` of the convolution factory used
            for both stages.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None,index=0):
        super().__init__()
        mid_channels = mid_channels if mid_channels else out_channels
        make_conv = conv_function[index]
        stages = [
            make_conv(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            make_conv(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, inputs):
        """Apply both stages to ``inputs``."""
        return self.conv(inputs)

class UNetplusplus(nn.Module):
    """UNet++ (nested U-Net) with a selectable convolution type.

    Node ``convI_J`` sits at depth ``I`` and is the ``J``-th node along the
    skip pathway of that depth; it receives all previous nodes of the same
    depth plus the upsampled node from the depth below.

    Args:
        n_channels: channels of the input image.
        n_classes: number of output channels (logits).
        bilinear: stored on the instance; upsampling is always bilinear.
        index: position in ``conv_function`` of the convolution used in every
            ``conv_block_nested``.
    """
    def __init__(self, n_channels, n_classes, bilinear=False,index = 0):
        super(UNetplusplus, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.index = index
        self._use_checkpointing = False
        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Backbone (first column).
        self.conv0_0 = conv_block_nested(n_channels, filters[0], filters[0],index)
        self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1],index)
        self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2],index)
        self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3],index)
        self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4],index)
        # Nested skip pathways; input width = same-depth predecessors + upsampled node.
        self.conv0_1 = conv_block_nested(filters[0] + filters[1], filters[0], filters[0],index)
        self.conv1_1 = conv_block_nested(filters[1] + filters[2], filters[1], filters[1],index)
        self.conv2_1 = conv_block_nested(filters[2] + filters[3], filters[2], filters[2],index)
        self.conv3_1 = conv_block_nested(filters[3] + filters[4], filters[3], filters[3],index)
        self.conv0_2 = conv_block_nested(filters[0]*2 + filters[1], filters[0], filters[0],index)
        self.conv1_2 = conv_block_nested(filters[1]*2 + filters[2], filters[1], filters[1],index)
        self.conv2_2 = conv_block_nested(filters[2]*2 + filters[3], filters[2], filters[2],index)
        self.conv0_3 = conv_block_nested(filters[0]*3 + filters[1], filters[0], filters[0],index)
        self.conv1_3 = conv_block_nested(filters[1]*3 + filters[2], filters[1], filters[1],index)
        self.conv0_4 = conv_block_nested(filters[0]*4 + filters[1], filters[0], filters[0],index)
        self.final = nn.Conv2d(filters[0], n_classes, kernel_size=1)
    def _run(self, stage, *inputs):
        """Call ``stage`` directly, or through activation checkpointing when enabled."""
        # getattr keeps instances created before this flag existed working.
        if getattr(self, '_use_checkpointing', False):
            from torch.utils.checkpoint import checkpoint
            # Non-reentrant mode also yields parameter gradients for the first
            # node, whose input tensor does not require grad.
            return checkpoint(stage, *inputs, use_reentrant=False)
        return stage(*inputs)
    def forward(self, x):
        """Return the logits for the input batch ``x``."""
        run = self._run
        x0_0 = run(self.conv0_0, x)
        x1_0 = run(self.conv1_0, self.pool(x0_0))
        x0_1 = run(self.conv0_1, torch.cat([x0_0, self.Up(x1_0)], 1))
        x2_0 = run(self.conv2_0, self.pool(x1_0))
        x1_1 = run(self.conv1_1, torch.cat([x1_0, self.Up(x2_0)], 1))
        x0_2 = run(self.conv0_2, torch.cat([x0_0, x0_1, self.Up(x1_1)], 1))
        x3_0 = run(self.conv3_0, self.pool(x2_0))
        x2_1 = run(self.conv2_1, torch.cat([x2_0, self.Up(x3_0)], 1))
        x1_2 = run(self.conv1_2, torch.cat([x1_0, x1_1, self.Up(x2_1)], 1))
        x0_3 = run(self.conv0_3, torch.cat([x0_0, x0_1, x0_2, self.Up(x1_2)], 1))
        x4_0 = run(self.conv4_0, self.pool(x3_0))
        x3_1 = run(self.conv3_1, torch.cat([x3_0, self.Up(x4_0)], 1))
        x2_2 = run(self.conv2_2, torch.cat([x2_0, x2_1, self.Up(x3_1)], 1))
        x1_3 = run(self.conv1_3, torch.cat([x1_0, x1_1, x1_2, self.Up(x2_2)], 1))
        x0_4 = run(self.conv0_4, torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)], 1))
        output = run(self.final, x0_4)
        return output
    def use_checkpointing(self):
        """Enable activation checkpointing for every convolutional node.

        Intermediate activations are then recomputed during the backward pass
        instead of being stored. The sub-modules themselves are left untouched,
        so parameter names and ``state_dict`` keys do not change.
        """
        self._use_checkpointing = True

SegCaps

In [None]:
#nn
class _ConvNd(Module):
    """Shared base for the custom convolution layers of this notebook.

    Validates the grouping, stores the hyper-parameters, allocates the
    ``weight`` (and optional ``bias``) parameters and initialises them
    uniformly in ``[-1/sqrt(n), 1/sqrt(n)]`` where ``n`` is ``in_channels``
    times the product of the kernel dimensions.

    Raises:
        ValueError: if ``in_channels`` or ``out_channels`` is not divisible
            by ``groups``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding, groups, bias):
        super().__init__()
        if in_channels % groups:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups:
            raise ValueError('out_channels must be divisible by groups')

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size, self.stride = kernel_size, stride
        self.padding, self.dilation = padding, dilation
        self.transposed, self.output_padding = transposed, output_padding
        self.groups = groups

        # Transposed layers store the weight as (in, out // groups, *kernel),
        # regular layers as (out, in // groups, *kernel).
        if transposed:
            leading_dims = (in_channels, out_channels // groups)
        else:
            leading_dims = (out_channels, in_channels // groups)
        self.weight = Parameter(torch.Tensor(*leading_dims, *kernel_size))

        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight and bias from the uniform initialisation range."""
        fan_in = self.in_channels
        for extent in self.kernel_size:
            fan_in *= extent
        bound = 1. / math.sqrt(fan_in)
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.uniform_(-bound, bound)

    def extra_repr(self):
        """Describe the layer, listing only settings that differ from the defaults."""
        pieces = ['{in_channels}, {out_channels}, kernel_size={kernel_size}',
                  ', stride={stride}']
        if self.padding != (0,) * len(self.padding):
            pieces.append(', padding={padding}')
        if self.dilation != (1,) * len(self.dilation):
            pieces.append(', dilation={dilation}')
        if self.output_padding != (0,) * len(self.output_padding):
            pieces.append(', output_padding={output_padding}')
        if self.groups != 1:
            pieces.append(', groups={groups}')
        if self.bias is None:
            pieces.append(', bias=False')
        return ''.join(pieces).format(**self.__dict__)
class Conv2d(_ConvNd):
    """2-D convolution whose forward pass goes through ``conv2d_same``.

    The constructor mirrors the usual ``Conv2d`` signature (with ``bias``
    defaulting to ``False``).  ``padding`` is stored on the module but is not
    passed to ``conv2d_same`` in ``forward``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=False):
        # Scalars are expanded to 2-tuples before being handed to the base class.
        super().__init__(
            in_channels,
            out_channels,
            _pair(kernel_size),
            _pair(stride),
            _pair(padding),
            _pair(dilation),
            transposed=False,
            output_padding=_pair(0),
            groups=groups,
            bias=bias,
        )

    def forward(self, input):
        """Convolve ``input`` with this module's weight/bias via ``conv2d_same``."""
        return conv2d_same(
            input,
            self.weight,
            bias=self.bias,
            stride=self.stride,
            dilation=self.dilation,
            groups=self.groups,
        )
class ConvTranspose2d(nn.ConvTranspose2d):
    """Transposed convolution whose output height is ``input height * stride``.

    Padding and output padding are recomputed on every call from the input
    height, the first kernel extent and the first stride (the width is not
    inspected separately).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=False, dilation=1):
        super().__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, output_padding=output_padding,
            groups=groups, bias=bias, dilation=dilation)

    def forward(self, input):
        """Upsample ``input`` so that its spatial size is multiplied by the stride."""
        step = self.stride[0]
        kernel = self.kernel_size[0]
        in_size = input.size(2)
        wanted_size = in_size * step

        pad_l, pad_r = get_same(in_size, kernel, step, dilation=1)
        # NOTE(review): this overwrites the module attribute on every call, so
        # ``self.padding`` becomes an int after the first forward pass.
        self.padding = max(pad_l, pad_r)

        # Size the plain transposed conv would produce; the shortfall is made
        # up through ``output_padding``.
        produced_size = (in_size - 1) * step + kernel - 2 * self.padding
        extra = wanted_size - produced_size
        return F.conv_transpose2d(
            input, self.weight, self.bias, self.stride, self.padding,
            extra, self.groups, self.dilation)
def conv2d_same(input, weight, bias=None, stride=[1, 1], dilation=(1, 1), groups=1):
    """2-D convolution with "same" padding (output size = ceil(input / stride)).

    The total padding is computed independently for rows and columns from
    the corresponding input extent, kernel extent, stride and dilation.  When
    the total for an axis is odd, the extra zero row/column is appended at
    the bottom/right before the symmetric part is applied by ``F.conv2d``.

    Args:
        input: tensor indexed as ``(N, C, H, W)``.
        weight: kernel indexed as ``(C_out, C_in // groups, kH, kW)``.
        bias: optional bias forwarded to ``F.conv2d``.
        stride: sequence with the row stride first and column stride last.
        dilation: sequence with the row dilation first and column dilation last.
        groups: forwarded to ``F.conv2d``.

    Returns:
        The convolved tensor.
    """
    def _total_padding(in_size, kernel, step, dil):
        # Padding needed on one axis so the output has ceil(in_size / step) entries.
        out_size = (in_size + step - 1) // step
        return max(0, (out_size - 1) * step + (kernel - 1) * dil + 1 - in_size)

    # ``[-1]`` picks the column entry and also tolerates 1-element sequences.
    padding_rows = _total_padding(input.size(2), weight.size(2), stride[0], dilation[0])
    padding_cols = _total_padding(input.size(3), weight.size(3), stride[-1], dilation[-1])

    rows_odd = (padding_rows % 2 != 0)
    cols_odd = (padding_cols % 2 != 0)
    if rows_odd or cols_odd:
        # F.pad order is (left, right, top, bottom) for the last two dims.
        input = pad(input, [0, int(cols_odd), 0, int(rows_odd)])
    return F.conv2d(input, weight, bias, stride,
                    padding=(padding_rows // 2, padding_cols // 2),
                    dilation=dilation, groups=groups)
def max_pool2d_same(input, kernel_size, stride=1, dilation=1, ceil_mode=False, return_indices=False):
    """Max-pool with "same" padding (output size = ceil(input / stride)).

    Padding is computed separately for rows and columns, so non-square
    inputs are handled correctly.  When the total padding for an axis is odd,
    one extra zero row/column is appended at the bottom/right with ``pad``
    before the symmetric part is applied by ``F.max_pool2d``.

    Args:
        input: tensor indexed as ``(N, C, H, W)``.
        kernel_size: integer window size (used in integer arithmetic).
        stride: integer stride.
        dilation: integer dilation.
        ceil_mode, return_indices: forwarded to ``F.max_pool2d``.

    Returns:
        Whatever ``F.max_pool2d`` returns for the padded input.
    """
    def _total_padding(in_size):
        # Padding needed on one axis so the output has ceil(in_size / stride) entries.
        out_size = (in_size + stride - 1) // stride
        return max(0, (out_size - 1) * stride +
                   (kernel_size - 1) * dilation + 1 - in_size)

    padding_rows = _total_padding(input.size(2))
    padding_cols = _total_padding(input.size(3))
    rows_odd = (padding_rows % 2 != 0)
    cols_odd = (padding_cols % 2 != 0)
    if rows_odd or cols_odd:
        # F.pad order is (left, right, top, bottom) for the last two dims.
        input = pad(input, [0, int(cols_odd), 0, int(rows_odd)])
    return F.max_pool2d(input, kernel_size=kernel_size, stride=stride,
                        padding=(padding_rows // 2, padding_cols // 2),
                        dilation=dilation,
                        ceil_mode=ceil_mode, return_indices=return_indices)
def get_same(size, kernel, stride, dilation):
    """Split the "same" padding for one axis into a ``(left, right)`` pair.

    The total padding is what is needed for the output to have
    ``ceil(size / stride)`` entries.  When that total is odd, the extra unit
    goes to the left value.
    """
    target = (size + stride - 1) // stride
    total = max(0, (target - 1) * stride + (kernel - 1) * dilation + 1 - size)
    half, remainder = divmod(total, 2)
    return half + remainder, half
#capsule layer
class CapsuleLayer(nn.Module):
    """Convolutional capsule layer with locally-constrained dynamic routing.

    Each of the ``t_0`` input capsule types gets its own (transposed)
    convolution that predicts ``t_1`` output capsule types of dimension
    ``z_1``.  The predictions are then combined by ``routing`` iterations of
    routing-by-agreement.

    Args:
        t_0: number of input capsule types (size of dim 1 of the input).
        z_0: dimension of each input capsule (channels seen by each conv).
        op: ``'conv'`` for ``nn.Conv2d``; anything else builds
            ``nn.ConvTranspose2d``, but ``forward`` only accepts ``'conv'`` or
            ``'deconv'``.
        k: kernel size; expected to be odd so that ``k // 2`` padding keeps
            the spatial size at stride 1 (``k=5`` gives the padding of 2 used
            previously).
        s: stride.  The transposed path uses ``output_padding=1``, which
            matches ``s=2``.
        t_1: number of output capsule types.
        z_1: dimension of each output capsule.
        routing: number of routing iterations (at least 1).
    """

    def __init__(self, t_0,z_0, op, k, s, t_1, z_1, routing):
        super().__init__()
        self.t_1 = t_1
        self.z_1 = z_1
        self.op = op
        self.k = k
        self.s = s
        self.routing = routing
        self.convs = nn.ModuleList()
        self.t_0=t_0
        half_k = k // 2  # "same"-style padding for an odd kernel
        for _ in range(t_0):
            if self.op=='conv':
                self.convs.append(nn.Conv2d(z_0, self.t_1*self.z_1, self.k, self.s,padding=half_k,bias=False))
            else:
                self.convs.append(nn.ConvTranspose2d(z_0, self.t_1 * self.z_1, self.k, self.s,padding=half_k,output_padding=1))

    def forward(self, u):
        """Map input capsules ``(N, t_0, z_0, H, W)`` to ``(N, t_1, z_1, H', W')``.

        Raises:
            ValueError: if dim 1 of ``u`` differs from ``t_0`` or if ``op`` is
                neither ``'conv'`` nor ``'deconv'``.
        """
        if u.shape[1] != self.t_0:
            raise ValueError(
                "Expected %d input capsule types, got %d" % (self.t_0, u.shape[1]))
        if self.op not in ("conv", "deconv"):
            raise ValueError("Wrong type of operation for capsule")

        N = u.shape[0]
        H_1 = u.shape[3]
        W_1 = u.shape[4]
        u_hat_t_list = []
        # One prediction tensor per input capsule type.
        for conv, u_t in zip(self.convs, u.unbind(1)):
            u_hat_t = conv(u_t)  # (N, t_1 * z_1, H_1, W_1)
            H_1 = u_hat_t.shape[2]
            W_1 = u_hat_t.shape[3]
            # -> (N, H_1, W_1, t_1, z_1)
            u_hat_t = u_hat_t.reshape(N, self.t_1, self.z_1, H_1, W_1).permute(0, 3, 4, 1, 2)
            u_hat_t_list.append(u_hat_t)
        return self.update_routing(u_hat_t_list, self.k, N, H_1, W_1,
                                   self.t_0, self.t_1, self.routing)

    def update_routing(self,u_hat_t_list, k, N, H_1, W_1, t_0, t_1, routing):
        """Combine the predictions by routing-by-agreement.

        Args:
            u_hat_t_list: ``t_0`` tensors of shape ``(N, H_1, W_1, t_1, z_1)``.
            k: window size over which coupling coefficients are normalised.
            N, H_1, W_1, t_0, t_1: sizes matching ``u_hat_t_list``.
            routing: number of iterations; gradients only flow through the
                last one (earlier iterations use detached predictions).

        Returns:
            Output capsules of shape ``(N, t_1, z_1, H_1, W_1)``.
        """
        # Allocate on the predictions' device rather than forcing CUDA.
        device = u_hat_t_list[0].device
        half_k = k // 2
        one_kernel = torch.ones(1, t_1, k, k, device=device)
        # Routing logits, one (N, H_1, W_1, t_1) tensor per input capsule type.
        b_t_list = [torch.zeros(N, H_1, W_1, t_1, device=device) for _ in range(t_0)]
        u_hat_t_list_sg = [u_hat_t.detach() for u_hat_t in u_hat_t_list]

        for d in range(routing):
            is_last = d == routing - 1
            u_hat_t_list_ = u_hat_t_list if is_last else u_hat_t_list_sg
            r_t_mul_u_hat_t_list = []
            for b_t, u_hat_t in zip(b_t_list, u_hat_t_list_):
                b_chw = b_t.permute(0, 3, 1, 2)  # (N, t_1, H_1, W_1)
                # Subtract the local maximum before exp() for numerical stability.
                b_t_max = F.max_pool2d(b_chw, k, 1, padding=half_k)
                b_t_max = b_t_max.max(1, True)[0]
                c_t = torch.exp(b_chw - b_t_max)
                # Normaliser: sum of c_t over the k x k window and all t_1 types.
                sum_c_t = F.conv2d(c_t, one_kernel, padding=half_k)
                r_t = (c_t / sum_c_t).permute(0, 2, 3, 1).unsqueeze(4)  # (N, H_1, W_1, t_1, 1)
                r_t_mul_u_hat_t_list.append(r_t * u_hat_t)
            p = sum(r_t_mul_u_hat_t_list)
            v = self.squash(p)
            if not is_last:
                # Agreement update.  The refreshed logits must replace the old
                # ones, otherwise every iteration routes with zero logits.
                b_t_list = [b_t + (u_hat_t * v).sum(4)
                            for b_t, u_hat_t in zip(b_t_list, u_hat_t_list_)]
        # (N, H_1, W_1, t_1, z_1) -> (N, t_1, z_1, H_1, W_1)
        return v.permute(0, 3, 4, 1, 2)

    def squash(self, p):
        """Squashing non-linearity applied along the last dimension of ``p``."""
        p_norm_sq = (p * p).sum(-1, True)
        p_norm = (p_norm_sq + 1e-9).sqrt()
        v = p_norm_sq / (1. + p_norm_sq) * p / p_norm
        return v
def update_routing(u_hat_t_list,k,N,H_1,W_1,t_0,t_1,routing):
    """Stand-alone routing-by-agreement over capsule predictions.

    Args:
        u_hat_t_list: ``t_0`` prediction tensors of shape
            ``(N, H_1, W_1, t_1, z_1)``.
        k: window size over which coupling coefficients are normalised.
        N, H_1, W_1, t_0, t_1: sizes matching ``u_hat_t_list``.
        routing: number of routing iterations (at least 1); only the last
            iteration uses the non-detached predictions.

    Returns:
        Output capsules of shape ``(N, t_1, z_1, H_1, W_1)``.
    """
    # Allocate on the predictions' device instead of forcing CUDA.
    device = u_hat_t_list[0].device if u_hat_t_list else None
    one_kernel = torch.ones(1, t_1, k, k, device=device)
    b = torch.zeros(N, H_1, W_1, t_0, t_1, device=device)
    # One (N, H_1, W_1, t_1) logit tensor per input capsule type.
    b_t_list = [b_t.squeeze(3) for b_t in b.split(1, 3)]
    # Detached predictions for all but the last iteration; nothing below
    # modifies them in place, so no copy is needed.
    u_hat_t_list_sg = [u_hat_t.detach() for u_hat_t in u_hat_t_list]
    for d in range(routing):
        if d < routing - 1:
            u_hat_t_list_ = u_hat_t_list_sg
        else:
            u_hat_t_list_ = u_hat_t_list
        r_t_mul_u_hat_t_list = []
        for b_t, u_hat_t in zip(b_t_list, u_hat_t_list_):
            b_chw = b_t.permute(0, 3, 1, 2)  # (N, t_1, H_1, W_1)
            # Subtract the local maximum before exp() for numerical stability.
            b_t_max = max_pool2d_same(b_chw, k, 1)
            b_t_max = b_t_max.max(1, True)[0]
            c_t = torch.exp(b_chw - b_t_max)
            sum_c_t = conv2d_same(c_t, one_kernel, stride=(1, 1))
            r_t = c_t / sum_c_t
            r_t = r_t.permute(0, 2, 3, 1)  # back to (N, H_1, W_1, t_1)
            r_t = r_t.unsqueeze(4)
            r_t_mul_u_hat_t_list.append(r_t * u_hat_t)
        p = sum(r_t_mul_u_hat_t_list)
        v = squash(p)
        if d < routing - 1:
            # Agreement update of the routing logits.
            b_t_list = [b_t + (u_hat_t * v).sum(4)
                        for b_t, u_hat_t in zip(b_t_list, u_hat_t_list_)]
    # Only the final capsules are re-laid out: (N, H_1, W_1, t_1, z_1) -> (N, t_1, z_1, H_1, W_1)
    return v.permute(0, 3, 4, 1, 2)
def squash( p):
    """Squash vectors along the last dimension of ``p``.

    Each vector keeps its direction and is rescaled by
    ``|p|^2 / (1 + |p|^2)``; ``1e-9`` is added under the square root to keep
    the division finite for zero vectors.
    """
    squared_norm = torch.sum(p * p, dim=-1, keepdim=True)
    norm = torch.sqrt(squared_norm + 1e-9)
    scale = squared_norm / (1. + squared_norm)
    return scale * p / norm
def test():
    """Interactive smoke test for ``CapsuleLayer`` (requires a CUDA device).

    Prompts for a spatial size, runs one forward/backward/optimizer step on
    random input and prints parameters, output shape and mean before and
    after the step.
    """
    m=CapsuleLayer(1, 16, "conv", k=5, s=1, t_1=2, z_1=16, routing=1)
    m=m.cuda()
    b=input('s')
    a=torch.randn(10, 1, 16, int(b), int(b))
    a=a.cuda()
    # Use torch.optim explicitly: the notebook rebinds the name ``optim`` to
    # ``torch_optimizer``, which does not expose ``Adam`` or ``lr_scheduler``.
    optimizer = torch.optim.Adam(m.parameters(), lr=1)
    for k,v in m.named_parameters():
        print(k)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20],
                                                     gamma=0.1)
    b=m(a)
    c=b.mean()
    for k in m.parameters():
        print(k)
    print(b.shape)
    print(c)
    c.backward()
    optimizer.step()
    b=m(a)
    c=b.mean()
    print(c)
    # NOTE(review): ``a`` is created without requires_grad, so this presumably
    # prints None — confirm whether input gradients were meant to be tracked.
    print(a.grad)
    print(b.shape)
def test1():
    """Run the module-level ``update_routing`` on an all-ones prediction (CUDA)."""
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    ones_pred = torch.ones([1, 10, 10, 2, 3]).cuda()
    print(ones_pred)
    # Single input capsule type, two output types, three routing iterations.
    routed = update_routing([ones_pred], k=2, N=1, H_1=10, W_1=10,
                            t_0=1, t_1=2, routing=3)
    print(routed.cpu().numpy())

#capsule network (SegCaps)
class SegCaps(nn.Module):
    """Capsule-based encoder/decoder segmentation network.

    Args:
        n_channels: number of channels of the input image; used as the input
            width of the stem convolution.
        n_classes: number of output classes (capsule dimension of the last
            layer and channel count of the returned map).
        bilinear: accepted for signature parity with the other models in the
            ``models`` list; not used here.
        index: position in ``conv_function`` of the convolution factory used
            for the plain convolutions.
    """

    def __init__(self,n_channels, n_classes, bilinear=False,index = 0):
        super().__init__()
        # Stem: the input width follows ``n_channels`` instead of a fixed 3,
        # so the network accepts the same inputs the training loop feeds it.
        self.conv_1 = nn.Sequential(
            conv_function[index](n_channels, 16, 5, padding=2,bias=False),
        )
        self.step_1 = nn.Sequential(  # 1/2
            CapsuleLayer(1, 16, "conv", k=5, s=2, t_1=2, z_1=16, routing=1),
            CapsuleLayer(2, 16, "conv", k=5, s=1, t_1=4, z_1=16, routing=3),
        )
        self.step_2 = nn.Sequential(  # 1/4
            CapsuleLayer(4, 16, "conv", k=5, s=2, t_1=4, z_1=32, routing=3),
            CapsuleLayer(4, 32, "conv", k=5, s=1, t_1=8, z_1=32, routing=3)
        )
        self.step_3 = nn.Sequential(  # 1/8
            CapsuleLayer(8, 32, "conv", k=5, s=2, t_1=8, z_1=64, routing=3),
            CapsuleLayer(8, 64, "conv", k=5, s=1, t_1=8, z_1=32, routing=3)
        )
        # Decoder: each "deconv" step doubles the resolution; the following
        # "conv" step consumes the concatenation with the matching skip.
        self.step_4 = CapsuleLayer(8, 32, "deconv", k=5, s=2, t_1=8, z_1=32, routing=3)

        self.step_5 = CapsuleLayer(16, 32, "conv", k=5, s=1, t_1=4, z_1=32, routing=3)

        self.step_6 = CapsuleLayer(4, 32, "deconv", k=5, s=2, t_1=4, z_1=16, routing=3)
        self.step_7 = CapsuleLayer(8, 16, "conv", k=5, s=1, t_1=4, z_1=16, routing=3)
        self.step_8 = CapsuleLayer(4, 16, "deconv", k=5, s=2, t_1=2, z_1=16, routing=3)
        self.step_10 = CapsuleLayer(3, 16, "conv", k=5, s=1, t_1=5, z_1=n_classes, routing=3)
        # Not used in forward(); kept so existing state dicts still load.
        self.conv_2 = nn.Sequential(
            conv_function[index](16, n_classes, 5, padding=2,bias=False),
        )

    def forward(self, x):
        """Return a ``(N, n_classes, H, W)`` map for an ``(N, n_channels, H, W)`` image."""
        x = self.conv_1(x)
        x.unsqueeze_(1)  # single capsule type: (N, 1, 16, H, W)
        skip_1 = x
        x = self.step_1(x)
        skip_2 = x
        x = self.step_2(x)
        skip_3 = x
        x = self.step_3(x)
        x = self.step_4(x)
        x = torch.cat((x, skip_3), 1)

        x = self.step_5(x)

        x = self.step_6(x)

        x = torch.cat((x, skip_2), 1)
        x = self.step_7(x)
        x = self.step_8(x)
        x=torch.cat((x,skip_1),1)
        x=self.step_10(x)
        x.squeeze_(1)
        v_lens = self.compute_vector_length(x)
        v_lens=v_lens.squeeze(1)
        return v_lens

    def compute_vector_length(self, x):
        """Euclidean norm of ``x`` over dim 1 (kept as a size-1 dim), with a small epsilon."""
        out = (x.pow(2)).sum(1, True)+1e-9
        out=out.sqrt()
        return out

In [None]:
# Architecture registry: train(), evaluate() and plot() index this list with
# the zero-based architecture number (the run cell passes ``n-1``).
models = [UNet,AttenUNet,UNetplusplus,SegCaps]  
class CustomDataset(Dataset):
    """Pairs two index-aligned containers as a ``(sample, target)`` dataset."""

    def __init__(self, x, y):
        # Both containers are kept as given; they are only indexed on demand.
        self.x, self.y = x, y

    def __len__(self):
        """Number of samples, taken from ``x``."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``(x[index], y[index])`` pair."""
        sample = self.x[index]
        target = self.y[index]
        return sample, target
def train(x_train,y_train,n_classes,names,d,batch_size,num_epochs,n_channels,n,learning_rate):
    """Train every model variant in ``names`` and save curves and checkpoints.

    For each index ``q`` in ``names`` a model is built with
    ``models[n](n_channels, n_classes, bilinear=False, index=q)``, trained with
    Adam and ``ReduceLROnPlateau`` for ``num_epochs`` epochs on an 80/20
    train/validation split, and its best-validation-accuracy weights are saved
    to ``model/<name><d>.pth``.  Emissions measured by codecarbon and the
    training time in minutes are appended to the module-level lists
    ``emissions`` and ``time_list``.  Finally the loss/accuracy curves of all
    variants are written to ``output/train_models_<names[0]><d>.png``.

    If anything raises while a variant is being trained, the batch size is
    lowered by one and the same variant is retried; when the batch size is
    already 1 the function prints a message and returns ``None`` without
    plotting.

    Args:
        x_train: NumPy array of inputs; batches are permuted with
            ``(0, 3, 1, 2)`` before the forward pass (assumes channels-last
            ``(N, H, W, C)`` — TODO confirm).
        y_train: NumPy array of integer labels (converted to ``long``).
        n_classes: number of classes passed to the model constructor.
        names: list of variant names; its length sets how many models are trained.
        d: dataset identifier string appended to names in file paths and titles.
        batch_size: starting batch size for every variant.
        num_epochs: epochs per variant.
        n_channels: channel count passed to the model constructor.
        n: zero-based index into ``models``.
        learning_rate: Adam learning rate.
    """
    global time_list
    global emissions
    device = 'cuda'
    metrics = []
    torch.manual_seed(42)
    y_train = torch.from_numpy(y_train).long()
    x_train = torch.from_numpy(x_train).float()
    # NOTE(review): no random_state is passed, so the split presumably depends
    # on NumPy's global RNG rather than on the torch seed set above — confirm.
    x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2,shuffle=True)
    train_dataset = CustomDataset(x_train, y_train)
    val_dataset = CustomDataset(x_val, y_val) 

    # Inverse-frequency class weights computed over every label value in the
    # training split, then normalised so that they sum to 1.
    y = y_train.view(-1).numpy()
    class_counts = np.bincount(y)
    num_classes = len(class_counts)
    total_samples = len(y)
    class_weights = []
    for count in class_counts:
        freq = count / total_samples
        weight = 1/freq
        class_weights.append(weight)

    class_weights = [x/sum(class_weights) for x in class_weights]
    class_weights = torch.FloatTensor(class_weights)
    class_weights = class_weights.to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    print("y_train:",y.shape,"num_classes",num_classes,"class_weights:",class_weights)
    # q: index of the variant being trained; b: batch size currently tried;
    # temp: last q for which a failure was already reported; batch: batch size
    # finally used for each variant.
    q = 0
    b = batch_size
    temp = -1
    batch = []
    while q < len(names):                      
        try:
            if temp !=q:
                print(names[q]+d) 
            train_losses = []
            train_accs = []
            val_losses = []
            val_accs = []
            tracker = EmissionsTracker(save_to_file=True, output_file='my_emissions.csv', log_level="ERROR")
            tracker.start()
            t1 = time.time()
            model = models[n](n_channels, n_classes, bilinear=False,index=q)
            model.to(device)
            best_acc = 0.0
            train_dataloader = DataLoader(train_dataset, b, shuffle=True)
            val_dataloader = DataLoader(val_dataset, b, shuffle=True)
            l_t = len(train_dataloader)
            l_v = len(val_dataloader)
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
            for epoch in range(num_epochs):
                running_loss = 0.0
                running_acc = 0.0
                # t / v count processed train / validation batches for the progress line.
                t = 0
                v = 0
                model.train()
                for i, (inputs, labels) in enumerate(train_dataloader):
                    inputs = inputs.permute(0, 3, 1, 2)
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    # NOTE(review): exp() is applied to the model output before
                    # CrossEntropyLoss — confirm this is intended for every architecture.
                    outputs = torch.exp(outputs)
                    loss = criterion(outputs, labels)
                    acc = (outputs.argmax(dim=1) == labels).float().mean()
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item()
                    running_acc += acc.item()
                    t+=1
                    print(f'\rEpoch {epoch+1}: Train Progress: {int((t/l_t)*100)}%',end="")
                train_losses.append(running_loss / len(train_dataloader))
                train_accs.append(running_acc / len(train_dataloader))
                val_loss = 0.0
                val_acc = 0.0
                model.eval()
                with torch.no_grad():
                    for i ,(inputs, labels) in enumerate(val_dataloader):
                        inputs = inputs.permute(0, 3, 1, 2)
                        inputs = inputs.to(device)
                        labels = labels.to(device)
                        outputs = model(inputs)
                        outputs = torch.exp(outputs)
                        loss = criterion(outputs, labels)
                        acc = (outputs.argmax(dim=1) == labels).float().mean()
                        val_loss += loss.item()
                        val_acc += acc.item()
                        v+=1
                        print(f'\rEpoch {epoch+1}: Train Progress: {int((t/l_t)*100)}% Validation Progress: {int((v/l_v)*100)}%  ',end="")
                val_losses.append(val_loss / len(val_dataloader))
                val_accs.append(val_acc / len(val_dataloader))
                # The scheduler and the checkpoint test below both use the
                # per-batch sums (val_loss / val_acc), not the per-epoch means.
                scheduler.step(val_loss)
                print(f'Epoch {epoch+1}: Training Loss: {running_loss / len(train_dataloader):.4f}, Training Accuracy: {running_acc / len(train_dataloader):.4f}, Validation Loss: {val_loss / len(val_dataloader):.4f}, Validation Accuracy: {val_acc / len(val_dataloader):.4f}, batch_size: {b} ')
                if val_acc > best_acc:
                    best_acc = val_acc
                    torch.save(model.state_dict(),f'model/{names[q]}{d}.pth')                         
            t2 = time.time()
            tracker.stop()
            # Sum the 'emissions' column of the tracker's CSV, then delete the
            # file so the next variant starts from an empty log.
            df = pd.read_csv('my_emissions.csv')
            df = df['emissions']
            total_emissions = df.sum()
            emissions.append(total_emissions)
            os.remove('my_emissions.csv')
            metrics.append({'train_losses': train_losses, 'train_accs': train_accs, 'val_losses': val_losses, 'val_accs': val_accs}) 
            time_list.append((t2 - t1) / 60.0)
            batch.append(b)
            # Next variant starts again from the requested batch size.
            b = batch_size
            q+=1
        except Exception as e:
            # Any exception is treated as "batch too large": report it once per
            # variant, shrink the batch by one and retry the same variant.
            # NOTE(review): tracker.stop() is not reached on this path, and the
            # message prints the requested batch_size rather than the current b.
            if temp != q:
                print(e,"for batch_size:",batch_size)
                temp = q
            if b > 1:
                b -=1
            else:
                print("batch_size can't be reduced further, increase your gpu size")
                return
    # One subplot per variant with the four curves over the epochs.
    # NOTE(review): with a single name plt.subplots presumably returns one Axes
    # rather than an array, so axarr[q] would fail — confirm.
    f, axarr = plt.subplots(len(names), 1, figsize=(10, 4*len(names))) 
    epoch = range(1, num_epochs +1)
    for q in range(len(names)):
        axarr[q].plot(epoch, metrics[q]['train_losses'], marker='o', linestyle='-', color='blue', label="Training Loss")
        axarr[q].plot(epoch, metrics[q]['val_losses'], marker='o', linestyle='-', color='red', label="Validation Loss")
        axarr[q].plot(epoch, metrics[q]['train_accs'], marker='o', linestyle='-', color='green', label="Training Accuracy")
        axarr[q].plot(epoch, metrics[q]['val_accs'], marker='o', linestyle='-', color='orange', label="Validation Accuracy")
        axarr[q].set_ylim(0, 3)
        num_ticks = 25 
        axarr[q].set_yticks(np.linspace(0, 3, num_ticks))
        axarr[q].set_xlabel("Epochs")
        axarr[q].set_ylabel("Metrics")
        axarr[q].set_title(f"{names[q]+d} (batch_size:{batch[q]},num_epochs:{num_epochs},learning_rate:{learning_rate})")
        axarr[q].legend()
    plt.tight_layout()
    plt.savefig("output/train_models_"+names[0]+d+".png",bbox_inches='tight')
    plt.close()           


In [None]:
def evaluate(x_test,y_test,names,d,batch_size,n_channels,n_classes,n):
    """Score every saved variant on the test set and save a metrics table.

    For each name the checkpoint ``model/<name><d>.pth`` is loaded into
    ``models[n](n_channels, n_classes, bilinear=False, index=q)``, predictions
    are collected batch by batch, and accuracy, IoU, precision, recall, F1
    (all macro-averaged), Cohen's kappa, parameter count, FLOPs, the
    classification report and the training emissions/time recorded in the
    module-level ``emissions`` / ``time_list`` lists are written as one table
    row.  The table is saved to ``output/table_<names[0]><d>.png`` and the
    classification report of the last variant is printed.

    Args:
        x_test: NumPy array of inputs indexed ``(N, H, W, C)`` (transposed to
            channels-first here).
        y_test: NumPy array of integer labels.
        names: variant names, aligned with ``emissions`` and ``time_list``.
        d: dataset identifier string used in file names.
        batch_size: evaluation batch size.
        n_channels, n_classes: forwarded to the model constructor.
        n: zero-based index into ``models``.
    """
    rows_table = []
    device = 'cuda'
    x_test = x_test.transpose(0, 3, 1, 2)
    x_test = torch.from_numpy(x_test).float()
    y_test = torch.from_numpy(y_test).long()
    test_dataset = torch.utils.data.TensorDataset(x_test, y_test)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    l = len(test_loader)
    # Labels stay on the CPU: only the current input batch is moved to the
    # GPU, and the flattened labels are shared by every variant.
    y_test_flat = y_test.flatten().numpy()
    for q in range(len(names)):
        model = models[n](n_channels, n_classes, bilinear=False,index=q)
        model.load_state_dict(torch.load(f'model/{names[q]}{d}.pth'))
        model.to(device)
        model.eval()
        preds_list = []
        x = 0
        with torch.no_grad():
            for x_batch, _ in test_loader:
                x_batch = x_batch.to(device)
                out_batch = model(x_batch)
                # Keep only the class map, on the CPU, so GPU memory does not
                # grow with the size of the test set.
                preds_list.append(out_batch.argmax(dim=1).cpu())
                x+=1
                print(f'\r{names[q]+d} Progress: {int((x/l)*100)}%',end="")
        print()
        # NOTE(review): FLOPs are measured for a fixed 3x256x256 input — confirm
        # this matches the dataset in use.
        flops, _ = get_model_complexity_info(model, (3, 256, 256), as_strings=True, print_per_layer_stat=False)
        preds_flat = torch.cat(preds_list, dim=0).flatten().numpy()
        accuracy = accuracy_score(y_test_flat, preds_flat)
        kappa = cohen_kappa_score(y_test_flat, preds_flat)
        iou_score = jaccard_score(y_test_flat, preds_flat, average='macro')
        prec_score = precision_score(y_test_flat, preds_flat, average='macro')
        rec_score = recall_score(y_test_flat, preds_flat, average='macro')
        f1 = f1_score(y_test_flat, preds_flat, average='macro')
        cr = classification_report(y_test_flat, preds_flat)
        total_params = sum(param.numel() for param in model.parameters())
        rows_table.append(
        [names[q],
        f'{float(accuracy):.5f}' if isinstance(accuracy, (int, float)) else accuracy,
        f'{float(iou_score):.5f}' if isinstance(iou_score, (int, float)) else iou_score,
        f'{float(prec_score):.5f}' if isinstance(prec_score, (int, float)) else prec_score,
        f'{float(rec_score):.5f}' if isinstance(rec_score, (int, float)) else rec_score,
        f'{float(f1):.5f}' if isinstance(f1, (int, float)) else f1,
        f'{float(kappa):.5f}' if isinstance(kappa, (int, float)) else kappa,
        f'{float(total_params):.5f}' if isinstance(total_params, (int, float)) else total_params,
        f'{float(flops):.5f}' if isinstance(flops, (int, float)) else flops,
        cr,
        f'{float(emissions[q]):.2e}',
        f'{float(time_list[q]):.5f}'])
    print(cr)
    headers = ["Model", "Accuracy", "IoU", "Precision", "Recall", "F1 Score", "Kappa", "Parameters","Flops","Classification Report","Carbon Emmission(kgCO2eq)","Train time(min)"]
    fig, ax = plt.subplots()
    ax.axis('off')
    table = ax.table(cellText=rows_table, colLabels=headers, loc='center', cellLoc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    # The three widest columns (report, emissions, time) size themselves.
    table.auto_set_column_width(col=[9])
    table.auto_set_column_width(col=[10])
    table.auto_set_column_width(col=[11])
    table.scale(3*1, 6*2)
    for i, header in enumerate(headers):
        table[0, i].get_text().set_fontweight('bold')
    plt.savefig('output/table_'+names[0]+d+'.png',bbox_inches='tight')
    plt.close()

In [None]:
def hex_to_rgb(hex_code):
    """Convert a 6-digit hex colour string (no leading ``#``) to an RGB array."""
    channels = [int(hex_code[start:start + 2], 16) for start in range(0, 6, 2)]
    return np.array(channels)
# RGB colours used by get_mask1 for the dataset-1 class ids.
corn, soyabean, cotton, spring_wheat = (
    hex_to_rgb(code) for code in ("ffd300", "267000", "ff2626", "d8b56b")
)
outside_usa = np.array([0, 0, 0])
def get_mask1(arr):
    """Colour a 2-D class-id map: ids 0..4 map to corn, soyabean, cotton,
    spring wheat and outside-USA colours; any other id stays black."""
    palette = (corn, soyabean, cotton, spring_wheat, outside_usa)
    height, width = arr.shape[0], arr.shape[1]
    mask = np.zeros((height, width, 3))
    for class_id, colour in enumerate(palette):
        mask[arr == class_id, :] = colour
    return mask
def get_mask2(arr):
    """Colour a 2-D class-id map: 1 -> (0, 0, 255), 2 -> (0, 255, 0),
    3 -> (255, 0, 0); every other id stays black."""
    colour_by_id = (
        (1, np.array([[[0, 0, 255]]])),
        (2, np.array([[[0, 255, 0]]])),
        (3, np.array([[[255, 0, 0]]])),
    )
    mask = np.zeros((arr.shape[0], arr.shape[1], 3))
    for class_id, colour in colour_by_id:
        mask[arr == class_id, :] = colour
    return mask
# Colouring function per dataset; plot() selects the entry with ``int(d)-1``,
# so datasets 2 and 3 share get_mask2.
get_mask_func = [get_mask1,get_mask2,get_mask2]
def plot(x_test,y_test,names,d,n_channels, n_classes,n):
    """Save a grid comparing inputs, ground truth and every variant's prediction.

    Six fixed test samples form the rows.  Column 0 shows the input image,
    column 1 the coloured ground-truth mask and each further column the
    coloured prediction of one variant loaded from ``model/<name><d>.pth``.
    The figure is written to ``output/plot_<names[0]><d>.png``.

    Args:
        x_test: NumPy array of inputs indexed ``(N, H, W, C)``; must contain
            at least 100 samples because index 99 is plotted.
        y_test: NumPy array of labels; each sample is reshaped to ``(h, h)``
            with ``h = x_test.shape[2]``.
        names: variant names (one column each).
        d: dataset identifier string; ``int(d)-1`` selects the colouring function.
        n_channels, n_classes: forwarded to the model constructor.
        n: zero-based index into ``models``.
    """
    sample_indices = [18,21,25,15,99,51]
    get_mask = get_mask_func[int(d)-1]
    device = 'cuda'
    h= x_test.shape[2]
    x_test_tensor = torch.from_numpy(x_test).permute(0, 3, 1, 2).float()
    x_test_tensor = x_test_tensor.to(device)
    f, axarr = plt.subplots(6, len(names)+2, figsize=(3 * (len(names)+2), 3 * 6))
    axarr[0,0].set_title('NDVI', fontweight='bold', fontsize=16)
    axarr[0,1].set_title('GT', fontweight='bold', fontsize=16)
    for row, i in enumerate(sample_indices):
        axarr[row,0].imshow(x_test[i])
        # The masks hold 0-255 colour values; imshow expects floats in [0, 1],
        # so they are normalised exactly like the predictions below.
        axarr[row,1].imshow(get_mask(y_test[i].reshape((h, h))) / 255)
    for q in range(len(names)):
        # Build and load each variant once, not once per plotted sample.
        model = models[n](n_channels, n_classes, bilinear=False,index=q)
        model.load_state_dict(torch.load(f'model/{names[q]}{d}.pth'))
        model.to(device)
        model.eval()
        axarr[0,q+2].set_title(names[q], fontweight='bold', fontsize=16)
        with torch.no_grad():
            for row, i in enumerate(sample_indices):
                test_img = model.forward(x_test_tensor[i:i + 1])
                test_img = torch.argmax(test_img, dim=1)
                test_img = test_img.cpu().numpy().squeeze()
                test_img = get_mask(test_img.reshape((h, h)))
                # Normalise by the colour range (255), not by the image size.
                test_img = np.clip(test_img, 0, 255)
                axarr[row, q+2].imshow(test_img / 255)
    plt.savefig('output/plot_'+names[0]+d+'.png',bbox_inches='tight')
    plt.close()

In [None]:
#run
# Interactive driver: pick an architecture and a dataset, load the arrays,
# ask for the hyper-parameters, then train, evaluate and plot every variant.
n = int(input("Select architecture number: 1-UNet, 2-Attention-UNet, 3-UNet++, 4-SegCaps:  "))
d = str(int(input("Select dataset number: 1, 2, 3:  ")))

# Filled by train() and read back by evaluate().
time_list = []
emissions = [] 

# Variant names per architecture number; the position of a name is the
# ``index`` handed to the model constructor.
names = {
    1: ['UNet', 'UNet_SSC', 'UNet_GDC', 'UNet_DC', 'UNet_ADC', 'UNet_orgAC', 'UNet_AC','UNet_GD_AC', 'UNet_D_AC', 'UNet_AD_AC'],
    2: ['AttenUnet', 'AttenUnet_SSC', 'AttenUnet_GDC', 'AttenUnet_DC','AttenUnet_ADC','AttenUnet_orgAC', 'AttenUnet_AC',  'AttenUnet_GD_AC', 'AttenUnet_D_AC','AttenUnet_AD_AC'],
    3: ['UNet++', 'UNet++_SSC', 'UNet++_GDC', 'UNet++_DC','UNet++_ADC', 'UNet++_orgAC', 'UNet++_AC','UNet++_GD_AC', 'UNet++_D_AC','UNet++_AD_AC'],
    4: ['SegCaps', 'SegCaps_SSC', 'SegCaps_GDC', 'SegCaps_DC', 'SegCaps_ADC', 'SegCaps_orgAC', 'SegCaps_AC','SegCaps_GD_AC', 'SegCaps_D_AC', 'SegCaps_AD_AC'],
}

if d != '1':
    # Datasets 2 and 3 are plain .npy files whose label arrays carry a
    # trailing singleton axis that is dropped after counting the classes.
    y_train = np.load('datasets/dataset'+d+'/y_train.npy')
    x_train = np.load('datasets/dataset'+d+'/x_train.npy')
    x_test = np.load('datasets/dataset'+d+'/x_test.npy')
    y_test = np.load('datasets/dataset'+d+'/y_test.npy')
    n_classes = len(np.unique(y_train))
    y_train = np.squeeze(y_train, axis=-1)
    y_test = np.squeeze(y_test,axis=-1)
else:
    # Dataset 1 is stored as .npz archives with a single 'arr_0' entry.
    y_train = np.load('datasets/dataset'+d+'/y_train.npz')['arr_0']
    x_train = np.load('datasets/dataset'+d+'/x_train.npz')['arr_0']
    x_test = np.load('datasets/dataset'+d+'/x_test.npz')['arr_0']
    y_test = np.load('datasets/dataset'+d+'/y_test.npz')['arr_0']
    n_classes = len(np.unique(y_train))

# Channel count is read from the last axis of the training inputs.
n_channels = x_train.shape[3]
batch_size = int(input("Input batch size number:  "))
num_epochs = int(input("Input number of epochs:  "))
learning_rate = float(input("Input learning rate: "))      
print("x_train",x_train.shape,"y_train:",y_train.shape,"x_test",x_test.shape,y_test.shape,"num_classes:",n_classes)

print("Training...")
train(x_train,y_train,n_classes,names[n],d,batch_size,num_epochs,n_channels,n-1,learning_rate)
print("Evaluating...")
# Evaluation always runs with a batch size of 1.
evaluate(x_test,y_test,names[n],d,1,n_channels,n_classes,n-1)
print("Plotting...")
plot(x_test,y_test,names[n],d,n_channels, n_classes,n-1)
