In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CRU(nn.Module):
    '''
    alpha: 0<alpha<1
    '''
    def __init__(self, 
                 op_channel:int,
                 alpha:float = 1/2,
                 squeeze_radio:int = 2 ,
                 group_size:int = 2,
                 group_kernel_size:int = 3,
                 ):
        super().__init__()
        self.up_channel     = up_channel   =   int(alpha*op_channel)
        self.low_channel    = low_channel  =   op_channel-up_channel
        self.squeeze1       = nn.Conv2d(up_channel,up_channel//squeeze_radio,kernel_size=1,bias=False)
        self.squeeze2       = nn.Conv2d(low_channel,low_channel//squeeze_radio,kernel_size=1,bias=False)
        #up
        self.GWC            = nn.Conv2d(up_channel//squeeze_radio, op_channel,kernel_size=group_kernel_size, stride=1,padding=group_kernel_size//2, groups = group_size)
        self.PWC1           = nn.Conv2d(up_channel//squeeze_radio, op_channel,kernel_size=1, bias=False)
        #low
        self.PWC2           = nn.Conv2d(low_channel//squeeze_radio, op_channel-low_channel//squeeze_radio,kernel_size=1, bias=False)
        self.advavg         = nn.AdaptiveAvgPool2d(1)

    def forward(self,x):
        # Split
        up,low  = torch.split(x,[self.up_channel,self.low_channel],dim=1)
        up,low  = self.squeeze1(up),self.squeeze2(low)
        # Transform
        Y1      = self.GWC(up) + self.PWC1(up)
        Y2      = torch.cat( [self.PWC2(low), low], dim= 1 )
        # Fuse
        out     = torch.cat( [Y1, Y2], dim= 1 )
        out     = F.softmax( self.advavg(out), dim=1 ) * out
        out1,out2 = torch.split(out,out.size(1)//2,dim=1)
        return out1+out2

In [10]:

img = torch.ones(1,16,64,64)


oc = CRU(op_channel = 16)

print(oc(img).shape)

torch.Size([1, 16, 64, 64])


In [17]:
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim*ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        print(self.dwconv(x).shape)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        print(x1.shape, x2.shape)
        x = F.gelu(x2)*x1 + F.gelu(x1)*x2
        x = self.project_out(x)
        return x




In [18]:
ffn = FeedForward(dim=16, ffn_expansion_factor=2.66, bias=False)
img = torch.ones(1,16,64,64)

print(ffn(img).shape)


torch.Size([1, 84, 64, 64])
torch.Size([1, 42, 64, 64]) torch.Size([1, 42, 64, 64])
torch.Size([1, 16, 64, 64])


In [12]:
import imp
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_, DropPath, to_2tuple
import os
from model.blocks import Mlp


class query_Attention(nn.Module):
    def __init__(self, dim, num_heads=2, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Parameter(torch.ones((1, 10, dim)), requires_grad=True)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        print('before attention', x.shape)
        B, N, C = x.shape
        
        k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        q = self.q.expand(B, -1, -1).view(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, 10, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        
        print('after attention', x.shape)
        return x


class query_SABlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = query_Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class conv_embedding(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(conv_embedding, self).__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(out_channels // 2),
            nn.GELU(),
            # nn.Conv2d(out_channels // 2, out_channels // 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # nn.BatchNorm2d(out_channels // 2),
            # nn.GELU(),
            nn.Conv2d(out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        x = self.proj(x)
        return x


class Global_pred(nn.Module):
    def __init__(self, in_channels=3, out_channels=64, num_heads=4, type='exp'):
        super(Global_pred, self).__init__()
        if type == 'exp':
            self.gamma_base = nn.Parameter(torch.ones((1)), requires_grad=False) # False in exposure correction
        else:
            self.gamma_base = nn.Parameter(torch.ones((1)), requires_grad=True)  
        self.color_base = nn.Parameter(torch.eye((3)), requires_grad=True)  # basic color matrix
        # main blocks
        self.conv_large = conv_embedding(in_channels, out_channels)
        self.generator = query_SABlock(dim=out_channels, num_heads=num_heads)
        self.gamma_linear = nn.Linear(out_channels, 1)
        self.color_linear = nn.Linear(out_channels, 1)

        self.apply(self._init_weights)

        for name, p in self.named_parameters():
            if name == 'generator.attn.v.weight':
                nn.init.constant_(p, 0)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        print('x', x.shape)
        #print(self.gamma_base)
        
        x = self.conv_large(x)
        print('conv_large', x.shape)
        x = self.generator(x)
        print('generator', x.shape)
        gamma, color = x[:, 0].unsqueeze(1), x[:, 1:]
        print('gamma, color', gamma.shape, color.shape)
        gamma = self.gamma_linear(gamma).squeeze(-1) + self.gamma_base
        #print(self.gamma_base, self.gamma_linear(gamma))
        color = self.color_linear(color).squeeze(-1).view(-1, 3, 3) + self.color_base
        return gamma, color

if __name__ == "__main__":
    os.environ['CUDA_VISIBLE_DEVICES']='3'

    img = torch.Tensor(8, 3, 400, 600)
    global_net = Global_pred()
    gamma, color = global_net(img)
    print(gamma.shape, color.shape)

x torch.Size([8, 3, 400, 600])
conv_large torch.Size([8, 64, 100, 150])
before attention torch.Size([8, 15000, 64])
qqqq torch.Size([8, 15000, 64])
after attention torch.Size([8, 10, 64])
xxx torch.Size([8, 10, 64])
generator torch.Size([8, 10, 64])
gamma, color torch.Size([8, 1, 64]) torch.Size([8, 9, 64])
torch.Size([8, 1]) torch.Size([8, 3, 3])


In [None]:
class PA(nn.Module):
    """ Position attention module"""
    #Ref from SAGAN
    def __init__(self, in_dim):
        super(PA, self).__init__()
        self.chanel_in = in_dim

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)
    def forward(self, x, fft_map=None):
        """
            inputs :
                x : input feature maps( B X C X H X W)
            returns :
                out : attention value + input feature
                attention: B X (HxW) X (HxW)
        """
        fft_map = fft_map.float()
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(fft_map).view(m_batchsize, -1, width*height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma*out + x
        return out

In [2]:
import torch.nn as nn
import torch
from torchvision import transforms
import cv2
import numpy as np

class conv_embedding(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(conv_embedding, self).__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )

    def forward(self, x):
        x = self.proj(x)
        return x
    
class CA(nn.Module):
    """ Channel attention module"""
    def __init__(self, in_dim):
        super(CA, self).__init__()
        self.chanel_in = in_dim

        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax  = nn.Softmax(dim=-1)
    def forward(self,x, fft_map=None):
        """
            inputs :
                x : input feature maps( B X C X H X W)
            returns :
                out : attention value + input feature
                attention: B X C X C
        """
        m_batchsize, C, height, width = x.size()
        proj_query = x.view(m_batchsize, C, -1)
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
        attention = self.softmax(energy_new)
        proj_value = x.view(m_batchsize, C, -1)

        out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma*out + x
        return out
    
    

def fft_img(img):
    
    f = np.fft.fft2(img)
    fshift = np.fft.fftshift(f)
    rows, cols = img.shape
    crow, ccol = int(rows/2), int(cols/2)

    fshift[crow-30:crow+30, ccol-30:ccol+30] = 0
    ishift = np.fft.ifftshift(fshift)
    iimg = np.fft.ifft2(ishift)
    iimg = np.abs(iimg) 
    
    return transforms.ToTensor()(iimg)
    
class FFT_Block(nn.Module):
    def __init__(self, in_channels=3, out_channels=32, num_heads=4, type='exp'):
        super(FFT_Block, self).__init__()
        # main blocks
        self.conv_large = conv_embedding(in_channels, out_channels)
        self.generator = CA(in_dim=out_channels)
        
#         self.up = nn.Upsample(scale_factor=2, mode='bilinear')
        self.conv_fft = nn.Conv2d(3, 3, kernel_size=3, stride=(1, 1), padding=(1, 1))
        self.conv_A = nn.Conv2d(out_channels, 3, kernel_size=3, stride=(1, 1), padding=(1, 1))



    def forward(self, x):
        fft_x = self.conv_fft(x)
        b,c,h,w = fft_x.shape
        fft_map = torch.zeros(b, 32, h, w)
        for i in range(b):
            
            img = fft_x[i, :, :, :].cpu().detach().numpy().transpose(1,2,0) # b c h w -> h w c
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            fft = fft_img(img*255)/255
            fft = torch.cat([fft for i in range(32)], dim=0)
            
            fft_map[i, :,:,:] = fft
        print(fft_map.shape)   
        x = self.conv_large(x)
        x = self.generator(x, fft_map)      
        x = self.conv_A(x)
        return x  
 

import time
fb = FFT_Block()


start = time.time()
# img = cv2.imread('1.png')
# img = cv2.resize(img, (600, 400))


# img = transforms.ToTensor()(img).unsqueeze(0)
img = torch.Tensor(8, 3, 400, 600)
print(img.shape)
print(fb(img).shape)
end = time.time()


# print(pa(img).shape)
print(end-start)

torch.Size([8, 3, 400, 600])
torch.Size([8, 32, 400, 600])
torch.Size([8, 3, 400, 600])
0.6689531803131104


In [4]:
 class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=2):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)*x


img = torch.Tensor(8, 4, 400, 600)
print(img.shape)
print(ChannelAttention(4)(img).shape)

torch.Size([8, 4, 400, 600])
torch.Size([8, 4, 400, 600])


In [3]:
class query_Attention(nn.Module):
    def __init__(self, dim, num_heads=2, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Parameter(torch.ones((1, 10, dim)), requires_grad=True)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        q = self.q.expand(B, -1, -1).view(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, 10, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
img = torch.Tensor(8, 4, 400, 600)
print(img.shape)
print(query_Attention(4)(img).shape)

torch.Size([8, 4, 400, 600])


ValueError: too many values to unpack (expected 3)

In [14]:
from model.blocks import CBlock_ln, SwinTransformerBlock

gloabel = nn.Sequential(
            nn.Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(4),
            nn.GELU(),
        )
linear = nn.Linear(32,1)
img = torch.Tensor(8, 32, 100,150)
img = img.reshape(img.shape[0], -1, img.shape[1])
print(img.shape)
print(linear(img).shape)

torch.Size([8, 15000, 32])
torch.Size([8, 15000, 1])


In [8]:
import torch.nn as nn
import torch
from torchvision import transforms
import torch.nn.functional as F
class DPM(nn.Module):
    def __init__(self, inplanes, planes, act=nn.LeakyReLU(negative_slope=0.2,inplace=True), bias=False):
        super(DPM, self).__init__()

        self.conv_mask = nn.Conv2d(inplanes, inplanes, kernel_size=1, bias=bias)
        self.softmax = nn.Softmax(dim=2)
        self.sigmoid = nn.Sigmoid()

        self.channel_add_conv = nn.Sequential(
            nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias),
            act,
            nn.Conv2d(planes, inplanes, kernel_size=1, bias=bias)
        )
        self.linear = nn.Linear(inplanes, 1)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        input_x = x
        input_x = input_x.view(batch, channel, height * width)
        input_x = input_x.unsqueeze(1)
        
        context_mask = self.conv_mask(x)
        context_mask = context_mask.view(batch, -1, channel)
        print(context_mask.shape)
        context_mask = self.linear(context_mask)
        context_mask = context_mask.unsqueeze(1)
        context = torch.matmul(input_x, context_mask)
        context = context.view(batch, channel, 1, 1)

        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)
        channel_add_term = self.channel_add_conv(context)
        x = x + channel_add_term
        return x
img = torch.Tensor(8, 32, 100,150)
print(img.shape)
print(DPM(32,32)(img).shape)

torch.Size([8, 32, 100, 150])
torch.Size([8, 15000, 32])
torch.Size([8, 32, 100, 150])


In [4]:
from model.drconv import DRConv2d
import functools

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter


class _routing(nn.Module):

    def __init__(self, in_channels, num_experts, dropout_rate):
        super(_routing, self).__init__()
        
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(in_channels, num_experts)

    def forward(self, x):
        x = torch.flatten(x)
        x = self.dropout(x)
        x = self.fc(x)
        return F.sigmoid(x)
    

class CondConv2D(_ConvNd):
    
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros', num_experts=3, dropout_rate=0.2):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(CondConv2D, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

        self._avg_pooling = functools.partial(F.adaptive_avg_pool2d, output_size=(1, 1))
        self._routing_fn = _routing(in_channels, num_experts, dropout_rate)
        
        self.weight = Parameter(torch.Tensor(
            num_experts, out_channels, in_channels // groups, *kernel_size))
        
        self.reset_parameters()

    def _conv_forward(self, input, weight):
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
    
    def forward(self, inputs):
        b, _, _, _ = inputs.size()
        res = []
        for input in inputs:
            input = input.unsqueeze(0)
            pooled_inputs = self._avg_pooling(input)
            routing_weights = self._routing_fn(pooled_inputs)
            kernels = torch.sum(routing_weights[: ,None, None, None, None] * self.weight, 0)
            out = self._conv_forward(input, kernels)
            res.append(out)
        return torch.cat(res, dim=0)
    
img = torch.Tensor(8, 32, 100,150)
print(img.shape)
import time
start = time.time()
print(CondConv2D(in_channels=32, out_channels=32, kernel_size=1)(img).shape)
end = time.time()
print(end-start)

torch.Size([8, 32, 100, 150])
torch.Size([8, 32, 100, 150])
0.01157999038696289


In [2]:
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
import os
import math

from torchvision import transforms
import cv2
import numpy as np
from timm.models.layers import trunc_normal_
from model.blocks import CBlock_ln, SwinTransformerBlock
from model.CondConv2D import CondConv2D


# Short Cut Connection on Final Layer
class Local_pred_S(nn.Module):
    def __init__(self, in_dim=3, dim=16, number=1):
        super(Local_pred_S, self).__init__()
        # initial convolution
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_dim, dim, 3, padding=1, groups=1),
            nn.Conv2d(dim, dim, 3, padding=1)
        )
        
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        block_t = SwinTransformerBlock(dim)  # head number
        
        blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        blocks1.append(CondConv2D(in_channels=dim, out_channels=dim, kernel_size=1))
        blocks2.append(CondConv2D(in_channels=dim, out_channels=dim, kernel_size=1))
        self.mul_blocks = nn.Sequential(*blocks1)
        self.add_blocks = nn.Sequential(*blocks2)

        self.mul_end = nn.Sequential(nn.Conv2d(dim, 3, 3, padding = 1), nn.ReLU())
        self.add_end = nn.Sequential(nn.Conv2d(dim, 3, 3, padding = 1), nn.Tanh())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, img):
        img1 = self.relu(self.conv1(img))
        # short cut connection
        mul = self.mul_blocks(img1) + img1
        add = self.add_blocks(img1) + img1
        mul = self.mul_end(mul)
        add = self.add_end(add)
        return mul, add

class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim*ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.relu(x2)*x1 + F.relu(x1)*x2
        x = self.project_out(x)
        return x    

class Net5(nn.Module):
    def __init__(self, in_dim=3):
        super(Net5, self).__init__()
        #self.local_net = Local_pred()
        self.local_net = Local_pred_S(in_dim=in_dim, dim=32)
        self.fft_block = FFT_Block()
        
    def apply_color(self, image, ccm):
        shape = image.shape
        image = image.view(-1, 3)
        image = torch.tensordot(image, ccm, dims=[[-1], [-1]])
        image = image.view(shape)
        return torch.clamp(image, 1e-8, 1.0)

    def forward(self, img_low):
        #print(self.with_global)
        mul, add = self.local_net(img_low)
        img_high = (img_low.mul(mul)).add(add)
        
        
        fft_map = self.fft_block(img_low)
        img_high = img_high + fft_map
        
        return mul, add, img_high
        
class global_module(nn.Module):
    def __init__(self, inplanes, act=nn.LeakyReLU(negative_slope=0.2,inplace=True), bias=False):
        super(global_module, self).__init__()

        self.conv_mask = nn.Conv2d(inplanes, inplanes, kernel_size=1, bias=bias)
        self.linear = nn.Linear(inplanes, 1)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        input_x = x
        input_x = input_x.view(batch, channel, height * width)
        input_x = input_x.unsqueeze(1)
        
        context_mask = self.conv_mask(x)
        context_mask = context_mask.view(batch, -1, channel)
        context_mask = self.linear(context_mask)
        context_mask = context_mask.unsqueeze(1)
        context = torch.matmul(input_x, context_mask)
        context = context.view(batch, channel, 1, 1)

        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)
        x = x + context
        return x
    
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=2):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
           
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return out + x


class conv_embedding(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(conv_embedding, self).__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )

    def forward(self, x):
        x = self.proj(x)
        return x
    
class CA(nn.Module):
    """ Channel attention module"""
    def __init__(self, in_dim):
        super(CA, self).__init__()
        self.chanel_in = in_dim

        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax  = nn.Softmax(dim=-1)
    def forward(self,x, fft_map=None):
        """
            inputs :
                x : input feature maps( B X C X H X W)
            returns :
                out : attention value + input feature
                attention: B X C X C
        """
        m_batchsize, C, height, width = x.size()
        proj_query = x.view(m_batchsize, C, -1)
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
        attention = self.softmax(energy_new)
        proj_value = x.view(m_batchsize, C, -1)

        out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma*out + x
        return out
    
    

def fft_img(img):
    
    f = np.fft.fft2(img)
    fshift = np.fft.fftshift(f)
    rows, cols = img.shape
    crow, ccol = int(rows/2), int(cols/2)

    fshift[crow-30:crow+30, ccol-30:ccol+30] = 0
    ishift = np.fft.ifftshift(fshift)
    iimg = np.fft.ifft2(ishift)
    iimg = np.abs(iimg) 
    
    return transforms.ToTensor()(iimg)


class FFT_Block(nn.Module):
    def __init__(self, in_channels=3, out_channels=32, num_heads=4, type='exp'):
        super(FFT_Block, self).__init__()
        # main blocks
        self.conv_large = conv_embedding(in_channels, out_channels)
        self.block_t = SwinTransformerBlock(out_channels)
        self.generator = CA(in_dim=out_channels)
        self.conv_fft = nn.Conv2d(3, 3, kernel_size=3, stride=(1, 1), padding=(1, 1))
        self.conv_A = nn.Conv2d(out_channels, 3, kernel_size=3, stride=(1, 1), padding=(1, 1))
        self.gm = ChannelAttention(out_channels)


    def forward(self, x):
        fft_x = self.conv_fft(x)
        b,c,h,w = fft_x.shape
        fft_map = torch.zeros(b, 32, h, w)
        for i in range(b):
            
            img = fft_x[i, :, :, :].cpu().detach().numpy().transpose(1,2,0) # b c h w -> h w c
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            fft = fft_img(img*255)/255
            fft = torch.cat([fft for i in range(32)], dim=0)
            
            fft_map[i, :,:,:] = fft
        x = self.conv_large(x)
        x = self.block_t(x)
        x = self.generator(x, fft_map)   
        x = self.gm(x)
        x = self.conv_A(x)
        return x  


if __name__ == "__main__":
    os.environ['CUDA_VISIBLE_DEVICES']='3'
    img = torch.Tensor(1, 3, 400, 600)
    net = Net5()
    print('total parameters:', sum(param.numel() for param in net.parameters()))
    _, _, high = net(img)




total parameters: 50468




In [3]:
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
import os
import math

from timm.models.layers import trunc_normal_
from model.blocks import CBlock_ln, SwinTransformerBlock
from model.global_net import Global_pred

class Local_pred(nn.Module):
    def __init__(self, dim=16, number=4, type='ccc'):
        super(Local_pred, self).__init__()
        # initial convolution
        self.conv1 = nn.Conv2d(3, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks
        block = CBlock_ln(dim)
        block_t = SwinTransformerBlock(dim)  # head number
        if type =='ccc':  
            #blocks1, blocks2 = [block for _ in range(number)], [block for _ in range(number)]
            blocks1 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
            blocks2 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
        elif type =='ttt':
            blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        elif type =='cct':
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        #    block1 = [CBlock_ln(16), nn.Conv2d(16,24,3,1,1)]
        self.mul_blocks = nn.Sequential(*blocks1, nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_blocks = nn.Sequential(*blocks2, nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())


    def forward(self, img):
        img1 = self.relu(self.conv1(img))
        mul = self.mul_blocks(img1)
        add = self.add_blocks(img1)

        return mul, add

# Short Cut Connection on Final Layer
class Local_pred_S(nn.Module):
    def __init__(self, in_dim=3, dim=16, number=4, type='ccc'):
        super(Local_pred_S, self).__init__()
        # initial convolution
        self.conv1 = nn.Conv2d(in_dim, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks
        block = CBlock_ln(dim)
        block_t = SwinTransformerBlock(dim)  # head number
        if type =='ccc':
            blocks1 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
            blocks2 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
        elif type =='ttt':
            blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        elif type =='cct':
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        #    block1 = [CBlock_ln(16), nn.Conv2d(16,24,3,1,1)]
        self.mul_blocks = nn.Sequential(*blocks1)
        self.add_blocks = nn.Sequential(*blocks2)

        self.mul_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
            
            

    def forward(self, img):
        img1 = self.relu(self.conv1(img))
        # short cut connection
        mul = self.mul_blocks(img1) + img1
        add = self.add_blocks(img1) + img1
        mul = self.mul_end(mul)
        add = self.add_end(add)

        return mul, add

class IAT(nn.Module):
    def __init__(self, in_dim=3, with_global=True, type='lol'):
        super(IAT, self).__init__()
        #self.local_net = Local_pred()
        
        self.local_net = Local_pred_S(in_dim=in_dim)

        self.with_global = with_global
        if self.with_global:
            self.global_net = Global_pred(in_channels=in_dim, type=type)

    def apply_color(self, image, ccm):
        shape = image.shape
        image = image.view(-1, 3)
        image = torch.tensordot(image, ccm, dims=[[-1], [-1]])
        image = image.view(shape)
        return torch.clamp(image, 1e-8, 1.0)

    def forward(self, img_low):
        #print(self.with_global)
        mul, add = self.local_net(img_low)
        img_high = (img_low.mul(mul)).add(add)

        if not self.with_global:
            return mul, add, img_high
        
        else:
            gamma, color = self.global_net(img_low)
            b = img_high.shape[0]
            img_high = img_high.permute(0, 2, 3, 1)  # (B,C,H,W) -- (B,H,W,C)
            img_high = torch.stack([self.apply_color(img_high[i,:,:,:], color[i,:,:])**gamma[i,:] for i in range(b)], dim=0)
            img_high = img_high.permute(0, 3, 1, 2)  # (B,H,W,C) -- (B,C,H,W)
            return mul, add, img_high


if __name__ == "__main__":
    os.environ['CUDA_VISIBLE_DEVICES']='3'
    img = torch.Tensor(1, 3, 400, 600)
    net = IAT()
    print('total parameters:', sum(param.numel() for param in net.parameters()))
    _, _, high = net(img)

total parameters: 91154


In [13]:
import numpy as np
arr = np.ones((608,608))
arr = np.expand_dims(arr, axis=2)
print(arr.shape)
img = np.concatenate((arr,arr,arr), axis=2)
img.shape

(608, 608, 1)


(608, 608, 3)

In [5]:
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
import os
import math

from torchvision import transforms
import cv2
import numpy as np
from timm.models.layers import trunc_normal_
from model.blocks import CBlock_ln, SwinTransformerBlock
from model.CondConv2D import CondConv2D


# Short Cut Connection on Final Layer
class Local_pred_S(nn.Module):
    def __init__(self, in_dim=3, dim=16, number=1):
        super(Local_pred_S, self).__init__()
        # initial convolution
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_dim, dim, 3, padding=1, groups=1),
            nn.Conv2d(dim, dim, 3, padding=1)
        )
        
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        block_t = SwinTransformerBlock(dim)  # head number
        
        blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        blocks1.append(CondConv2D(in_channels=dim, out_channels=dim, kernel_size=1))
        blocks2.append(CondConv2D(in_channels=dim, out_channels=dim, kernel_size=1))
        self.mul_blocks = nn.Sequential(*blocks1)
        self.add_blocks = nn.Sequential(*blocks2)

        self.mul_end = nn.Sequential(nn.Conv2d(dim, 3, 3, padding = 1), nn.ReLU())
        self.add_end = nn.Sequential(nn.Conv2d(dim, 3, 3, padding = 1), nn.Tanh())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, img):
        img1 = self.relu(self.conv1(img))
        # short cut connection
        mul = self.mul_blocks(img1) + img1
        add = self.add_blocks(img1) + img1
        mul = self.mul_end(mul)
        add = self.add_end(add)
        return mul, add
 

class Net6(nn.Module):
    def __init__(self, in_dim=3):
        super(Net6, self).__init__()
        #self.local_net = Local_pred()
        self.local_net = Local_pred_S(in_dim=in_dim, dim=32)
        
    def apply_color(self, image, ccm):
        shape = image.shape
        image = image.view(-1, 3)
        image = torch.tensordot(image, ccm, dims=[[-1], [-1]])
        image = image.view(shape)
        return torch.clamp(image, 1e-8, 1.0)

    def forward(self, img_low):
        #print(self.with_global)
        mul, add = self.local_net(img_low)
        img_high = (img_low.mul(mul)).add(add)
        
        
#         fft_map = self.fft_block(img_low)
#         img_high = img_high + fft_map
        
        return mul, add, img_high
        
class global_module(nn.Module):
    def __init__(self, inplanes, act=nn.LeakyReLU(negative_slope=0.2,inplace=True), bias=False):
        super(global_module, self).__init__()

        self.conv_mask = nn.Conv2d(inplanes, inplanes, kernel_size=1, bias=bias)
        self.linear = nn.Linear(inplanes, 1)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        input_x = x
        input_x = input_x.view(batch, channel, height * width)
        input_x = input_x.unsqueeze(1)
        
        context_mask = self.conv_mask(x)
        context_mask = context_mask.view(batch, -1, channel)
        context_mask = self.linear(context_mask)
        context_mask = context_mask.unsqueeze(1)
        context = torch.matmul(input_x, context_mask)
        context = context.view(batch, channel, 1, 1)

        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)
        x = x + context
        return x
    
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=2):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
           
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return out + x


class conv_embedding(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(conv_embedding, self).__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )

    def forward(self, x):
        x = self.proj(x)
        return x
    
class CA(nn.Module):
    """ Channel attention module"""
    def __init__(self, in_dim):
        super(CA, self).__init__()
        self.chanel_in = in_dim

        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax  = nn.Softmax(dim=-1)
    def forward(self,x, fft_map=None):
        """
            inputs :
                x : input feature maps( B X C X H X W)
            returns :
                out : attention value + input feature
                attention: B X C X C
        """
        m_batchsize, C, height, width = x.size()
        proj_query = x.view(m_batchsize, C, -1)
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
        attention = self.softmax(energy_new)
        proj_value = x.view(m_batchsize, C, -1)

        out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma*out + x
        return out
    
    

def fft_img(img):
    
    f = np.fft.fft2(img)
    fshift = np.fft.fftshift(f)
    rows, cols = img.shape
    crow, ccol = int(rows/2), int(cols/2)

    fshift[crow-30:crow+30, ccol-30:ccol+30] = 0
    ishift = np.fft.ifftshift(fshift)
    iimg = np.fft.ifft2(ishift)
    iimg = np.abs(iimg) 
    
    return transforms.ToTensor()(iimg)


class FFT_Block(nn.Module):
    def __init__(self, in_channels=3, out_channels=32, num_heads=4, type='exp'):
        super(FFT_Block, self).__init__()
        # main blocks
        self.conv_large = conv_embedding(in_channels, out_channels)
        self.block_t = SwinTransformerBlock(out_channels)
        self.generator = CA(in_dim=out_channels)
        self.conv_fft = nn.Conv2d(3, 3, kernel_size=3, stride=(1, 1), padding=(1, 1))
        self.conv_A = nn.Conv2d(out_channels, 3, kernel_size=3, stride=(1, 1), padding=(1, 1))
        self.gm = ChannelAttention(out_channels)


    def forward(self, x):
        fft_x = self.conv_fft(x)
        b,c,h,w = fft_x.shape
        fft_map = torch.zeros(b, 32, h, w)
        for i in range(b):
            
            img = fft_x[i, :, :, :].cpu().detach().numpy().transpose(1,2,0) # b c h w -> h w c
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            fft = fft_img(img*255)/255
            fft = torch.cat([fft for i in range(32)], dim=0)
            
            fft_map[i, :,:,:] = fft
        x = self.conv_large(x)
        x = self.block_t(x)
        x = self.generator(x, fft_map)   
        x = self.gm(x)
        x = self.conv_A(x)
        return x  







In [6]:
img = torch.Tensor(1, 3, 400, 600)
net = Net6()
print('total parameters:', sum(param.numel() for param in net.parameters()))
_, _, high = net(img)

total parameters: 33356




In [3]:
# 90 10 27 48
images = ['low00746.png', 'low00737.png', 'low00775.png', 'low00762.png', 'low00767.png', 'low00742.png', 'low00709.png', 'low00756.png', 'low00705.png', 'low00764.png', 'low00783.png', 'low00771.png', 'low00699.png', 'low00738.png', 'low00706.png', 'low00741.png', 'low00739.png', 'low00704.png', 'low00752.png', 'low00744.png', 'low00748.png', 'low00784.png', 'low00779.png', 'low00789.png', 'low00725.png', 'low00768.png', 'low00736.png', 'low00732.png', 'low00751.png', 'low00785.png', 'low00788.png', 'low00723.png', 'low00754.png', 'low00734.png', 'low00729.png', 'low00728.png', 'low00713.png', 'low00693.png', 'low00692.png', 'low00718.png', 'low00770.png', 'low00782.png', 'low00710.png', 'low00750.png', 'low00703.png', 'low00691.png', 'low00720.png', 'low00730.png', 'low00719.png', 'low00773.png', 'low00721.png', 'low00716.png', 'low00701.png', 'low00715.png', 'low00740.png', 'low00776.png', 'low00711.png', 'low00714.png', 'low00758.png', 'low00777.png', 'low00787.png', 'low00763.png', 'low00697.png', 'low00781.png', 'low00724.png', 'low00733.png', 'low00786.png', 'low00695.png', 'low00717.png', 'low00698.png', 'low00753.png', 'low00766.png', 'low00743.png', 'low00735.png', 'low00727.png', 'low00774.png', 'low00759.png', 'low00707.png', 'low00749.png', 'low00700.png', 'low00696.png', 'low00765.png', 'low00760.png', 'low00702.png', 'low00690.png', 'low00778.png', 'low00755.png', 'low00757.png', 'low00708.png', 'low00731.png', 'low00722.png', 'low00694.png', 'low00747.png', 'low00745.png', 'low00769.png', 'low00772.png', 'low00780.png', 'low00726.png', 'low00712.png', 'low00761.png']
print(images.index('low00748.png'))
print(images.index('low00778.png'))
print(images.index('low00780.png'))

# print(images[10])
# print(images[27])
# print(images[48])
# print(images[90])

20
85
96
