In [1]:
import sys 
sys.path.append("..") 
from model.DQ.dq import *
import math
import datetime
from torchsummary import summary

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def crop_tensor(image_pack, scale = 4):
    """Split a batch of feature maps into a ``scale`` x ``scale`` grid of tiles.

    Args:
        image_pack: 4-D tensor. Dim 2 and dim 3 are the two axes that get tiled.
        scale: number of tiles along each of those two axes.

    Returns:
        A 5-D tensor of shape ``(B, scale * scale, C, d2 // scale, d3 // scale)``.
        Tile ``i * scale + j`` is the block at grid position ``i`` along dim 2
        and ``j`` along dim 3 of the input.

    Raises:
        ValueError: if ``scale`` is not positive, or if dim 2 / dim 3 is not a
            multiple of ``scale``. Previously such inputs silently produced a
            different number of tiles (or an opaque ``torch.stack`` error).
    """
    _, _, w, h = image_pack.size()
    if scale <= 0 or w % scale != 0 or h % scale != 0:
        raise ValueError(
            f"spatial size ({w}, {h}) must be divisible by a positive scale, got scale={scale}"
        )
    tile_w = int(w // scale)
    tile_h = int(h // scale)
    # Outer loop walks dim 2, inner loop walks dim 3, giving the row-major tile order.
    tiles = [
        tile
        for strip in torch.split(image_pack, tile_w, dim=2)
        for tile in torch.split(strip, tile_h, dim=3)
    ]
    return torch.stack(tiles, 1)

In [3]:
def cat_tensor(x, scale = 4):
    """Reassemble a ``scale`` x ``scale`` grid of tiles into a single tensor.

    ``x`` is indexed as ``x[:, tile, :, :, :]``. Tile ``i * scale + j`` is placed
    at grid row ``i`` and grid column ``j``: the tiles of one row are joined
    along the last axis, then the rows are joined along the second-to-last axis.
    """
    grid_rows = [
        torch.cat(
            [x[:, row * scale + col, :, :, :] for col in range(scale)],
            dim = -1,
        )
        for row in range(scale)
    ]
    return torch.cat(grid_rows, dim = -2)

In [4]:

def autopad(k, p=None):  # kernel, padding
    """Pad to 'same': choose the padding automatically from the kernel size.

    An explicit ``p`` is returned unchanged. Otherwise the padding is ``k // 2``
    when ``k`` is an int, or a list holding ``x // 2`` for every element ``x``
    when ``k`` is a list/tuple of kernel sizes.
    """
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [x // 2 for x in k]
class Conv(nn.Module):
    """Convolution block: Conv2d -> BatchNorm2d -> activation.

    This is the CBL block of the architecture diagram (conv + BN + Leaky ReLU);
    the activation was later changed to SiLU (CBS).
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        padding = autopad(k, p)
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, bias=False)
        # Batch normalisation over the c2 output channels.
        self.bn = nn.BatchNorm2d(c2)
        # act=True -> SiLU; an nn.Module instance is used as given; anything
        # else falls back to nn.Identity, a no-op placeholder layer.
        if act is True:
            self.act = nn.SiLU()
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Forward pass with batch normalisation: act(bn(conv(x)))."""
        features = self.conv(x)
        features = self.bn(features)
        return self.act(features)

    def forward_fuse(self, x):
        """Forward pass that skips the BatchNorm layer: act(conv(x))."""
        features = self.conv(x)
        return self.act(features)

In [5]:
# Build a (1, 1, 16, 16) tensor whose entries count 1..256 in row-major order.
# A single vectorised arange replaces 256 Python-level in-place element updates.
a = torch.arange(1, 16 * 16 + 1, dtype=torch.get_default_dtype()).view([1, 1, 16, 16])

In [6]:
# Sanity check for crop_tensor: with the default scale of 4, a (3, 32, 224, 224)
# batch becomes 16 tiles of 56 x 56, i.e. shape (3, 16, 32, 56, 56).
a = torch.zeros(3, 32, 224,224)
a.shape
crop_tensor(a).shape

torch.Size([3, 16, 32, 56, 56])

In [7]:
a = torch.zeros(3, 16, 32, 56, 56)
a.shape

torch.Size([3, 16, 32, 56, 56])

In [8]:
from torch import Tensor
class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.

    Args:
        embedding_dim: size of the last axis of the q/k/v inputs and of the output.
        num_heads: number of attention heads; must divide
            ``embedding_dim // downsample_rate``.
        downsample_rate: the internal (projected) width is
            ``embedding_dim // downsample_rate``.
        dropout: dropout probability applied to the attention weights.
            Defaults to 0.2, the value that used to be hard-coded.

    Note:
        If ``embedding_dim < downsample_rate`` the internal width is 0 and the
        divisibility assertion still passes (``0 % n == 0``).
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
        dropout: float = 0.2,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        # The check is on the *internal* width, so the message says so.
        assert self.internal_dim % num_heads == 0, (
            "num_heads must divide embedding_dim // downsample_rate."
        )

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """Reshape ``(B, N_tokens, C)`` into ``(B, N_heads, N_tokens, C_per_head)``."""
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """Inverse of ``_separate_heads``: back to ``(B, N_tokens, C)``."""
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Scaled dot-product multi-head attention.

        Args:
            q, k, v: 3-D tensors whose last axis has size ``embedding_dim``.

        Returns:
            Tensor with the token count of ``q`` and last axis ``embedding_dim``.
        """
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)
        return out

In [9]:
class Focus(nn.Module):
    """Focus wh information into c-space.

    Four strided sub-grids of the last two axes are stacked along the channel
    axis and passed through a 1x1 Conv2d -> BatchNorm2d -> SiLU block that
    expects ``4 * in_channel`` input channels. With ``scale == 1`` the module
    returns its input untouched (the conv block is not applied).
    """

    def __init__(self,  in_channel, out_channel, scale = 2,):
        super().__init__()
        pointwise = nn.Conv2d(4 * in_channel, out_channel, 1, 1)
        norm = nn.BatchNorm2d(out_channel)
        self.conv = nn.Sequential(pointwise, norm, nn.SiLU())
        self.scale = scale

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        step = self.scale
        if step == 1:
            return x
        # (offset on dim -2, offset on dim -1) of each sub-grid, in stacking order.
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        stacked = torch.cat([x[..., r::step, c::step] for r, c in offsets], 1)
        return self.conv(stacked)

In [73]:
# NOTE(review): this cell (In [73]) sits above the cell that defines Pixel_Attention
# (In [84]). On a fresh kernel run top to bottom the name is not defined yet here,
# unless the star import at the top also provides it - confirm and reorder the cells.
pa = Pixel_Attention(1, 2, scale = 28)



In [74]:
d = torch.zeros(2, 1, 28, 28)
pa(d).shape

torch.Size([2, 784, 1]) torch.Size([2, 784, 1])


torch.Size([2, 1, 28, 28])

In [84]:
class Pixel_Attention(nn.Module):
    """Per-pixel MLP branches gated by self-attention across all pixels.

    The input is expected to be ``(B, in_channel, scale, scale)``. Every pixel
    becomes one token; ``n_coder_blocks`` independent MLP branches process the
    tokens, their outputs are concatenated, projected back to ``in_channel``
    and multiplied element-wise by the output of an ``Attention`` layer.

    Args:
        in_channel: channel count of the input and of the output.
        middle_dim: hidden width of every MLP branch.
        n_linear_len: number of (BatchNorm1d, ReLU, Linear) groups appended
            after the first Linear of a branch.
        scale: spatial size of the input along both axes; ``scale * scale`` is
            the token count and the BatchNorm1d feature count.
        n_coder_blocks: number of parallel MLP branches.
        bias: whether the Linear layers use a bias term.

    Fixes relative to the previous version:
        * ``linear_box`` is an ``nn.ModuleList``. As a plain Python list the
          branches were invisible to ``parameters()``, ``.to(device)`` and
          ``state_dict()``. State dicts now contain ``linear_box.*`` keys.
        * Each branch owns freshly created layers. Previously every branch
          wrapped the same layer objects and therefore shared all weights;
          the old ``id()`` check compared only the Sequential containers.
    """
    def __init__(self,  in_channel, middle_dim, n_linear_len = 2, scale = 2, n_coder_blocks = 4 ,bias = False):
        super().__init__()

        def make_branch():
            # New layer instances on every call, so branches never share parameters.
            layers = [nn.Linear(in_channel, middle_dim, bias = bias)]
            for _ in range(n_linear_len):
                # Branch input is (B, scale*scale, features); BatchNorm1d
                # normalises over dim 1, i.e. over the token axis.
                layers.append(nn.BatchNorm1d(scale * scale))
                layers.append(nn.ReLU())
                layers.append(nn.Linear(middle_dim, middle_dim, bias = bias))
            return nn.Sequential(*layers)

        self.linear_box = nn.ModuleList([make_branch() for _ in range(n_coder_blocks)])
        self.out = nn.Linear(middle_dim * n_coder_blocks, in_channel, bias = bias)
        # 4 heads, downsample rate 4.
        # NOTE(review): for in_channel < 4 the attention's internal width is
        # in_channel // 4 == 0 - confirm that such small channel counts are intended.
        self.attention = Attention(in_channel, 4, 4)
        self.scale = scale

    def crop_tensor(self, image_pack):
        """Split ``(B, C, W, H)`` into ``self.scale ** 2`` tiles stacked on dim 1."""
        _, _, w, h = image_pack.size()
        tile_w = int(w / self.scale)
        tile_h = int(h / self.scale)
        tiles = [
            tile
            for strip in torch.split(image_pack, tile_w, dim = 2)
            for tile in torch.split(strip, tile_h, dim = 3)
        ]
        return torch.stack(tiles, 1)

    def cat_tensor(self, x,):
        """Inverse of ``crop_tensor``: tile ``i * scale + j`` goes to grid row ``i``, column ``j``."""
        rows = [
            torch.cat(
                [x[:, i * self.scale + j, :, :, :] for j in range(self.scale)],
                dim = -1,
            )
            for i in range(self.scale)
        ]
        return torch.cat(rows, dim = -2)

    def forward(self, x):
        """Map ``(B, in_channel, scale, scale)`` to a tensor of the same shape.

        Raises:
            ValueError: if the two spatial sizes are not both equal to
                ``self.scale``. The tiles would then not be single pixels and
                the two squeezes below would not remove the tile axes.
        """
        _, _, w, h = x.size()
        if w != self.scale or h != self.scale:
            raise ValueError(
                f"Pixel_Attention(scale={self.scale}) expects a {self.scale}x{self.scale} "
                f"input, got {w}x{h}"
            )
        x = self.crop_tensor(x)            # (B, scale*scale, C, 1, 1)
        x = x.squeeze(-1).squeeze(-1)      # (B, scale*scale, C)

        attn = self.attention(x, x, x)
        x = torch.cat([branch(x) for branch in self.linear_box], dim = -1)

        x = self.out(x)
        x = x * attn                       # gate the MLP output with the attention output
        x = x.unsqueeze(-1).unsqueeze(-1)  # back to (B, scale*scale, C, 1, 1)
        return self.cat_tensor(x)

In [17]:
128/16

8.0

In [9]:

class HDAE(VQVAE):
    """Hierarchical auto-encoder built on top of the project's ``VQVAE``.

    The constructor lets ``VQVAE.__init__`` run, then replaces the inherited
    ``quantize_convs``, ``enc_blocks`` and ``decoder`` and adds a ``Focus``
    stem, a "high colour" attention branch and an upsampling reconstruction
    head. ``forward`` also uses ``self.quantizers``, ``self.dec_blocks``,
    ``self.upsample``, ``self.dropout`` and ``self.type``, none of which are
    created in this class - presumably they come from ``VQVAE`` (confirm).

    Vector quantisation is currently bypassed: ``quantize`` only applies the
    1x1 convolutions, and the quantiser loss returned by ``forward`` is 0.0.

    NOTE(review): several sizes are hard-coded (``Attention(1024, ...)``,
    ``Conv3d(16, 16, ...)``, ``num_feat = 64``, ``Conv2d(128, 3, ...)``). They
    match the run recorded in this notebook (channel=128, 224x224 input);
    other configurations are not verified.
    """
    def __init__(self, *args, **kwargs):
        """Build the model; all arguments are forwarded unchanged to ``VQVAE.__init__``."""
        super(HDAE, self).__init__(*args, **kwargs)
        # Hyper-parameters are read back from attributes expected to be set
        # by the base-class constructor.
        channel = self.channel
        embed_dim = self.embed_dim
        n_codebooks = self.n_codebooks
        n_res_block = self.n_res_block
        n_res_channel = self.n_res_channel
        n_coder_blocks = self.n_coder_blocks
        stride = self.stride
        
        # Stem: 1x1 conv to embed_dim channels, then space-to-depth (Focus).
        self.focus = Focus(embed_dim, embed_dim)
        self.conv_before_feature = nn.Sequential(
            nn.Conv2d( self.in_channel, embed_dim, 1, 1),
        )
        
        # Drop the inherited modules that are rebuilt below.
        del self.quantize_convs
        del self.enc_blocks
        del self.decoder
        
        self.enc_blocks = nn.ModuleList()
        self.quantize_convs = torch.nn.ModuleList()

        # Bottom encoder stage: one Encoder with the configured stride followed
        # by (n_coder_blocks - 1) stride-2 Encoders.
        enc_blocks = [
            Encoder(embed_dim, channel, n_res_block, n_res_channel, stride = stride)
        ]
        enc_blocks += [
            Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
            for i in range(n_coder_blocks - 1)
        ]
        self.enc_blocks.append(nn.Sequential(*enc_blocks))
        
        # NOTE(review): high_block_list and conv_down_1x1 are registered but
        # not used by forward() in this class.
        self.high_block_list = nn.ModuleList()
        self.high_block_list.append( Pixel_Attention( channel, channel * 2, scale = 28) )
        self.high_block_list.append( Pixel_Attention( channel, 2 * channel, scale = 14) )
        
        self.conv_down_1x1 = nn.ModuleList()
        self.conv_down_1x1.append( nn.Identity() )
        self.conv_down_1x1.append( nn.Conv2d(2*channel, channel, 1, 1) )
        
        # bot to top, excluding top because top doesn't accept a concat input
        for i, _ in enumerate(self.n_hier):
            if i != 0:
                enc_blocks = [
                    Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
                ]
                self.enc_blocks.append(nn.Sequential(*enc_blocks))

            cur_hier = len(self.n_hier) - 1 - i
            # only for top we have channel as input to quantizer because we don't condition on prior codes
            conv2D_channels = channel if cur_hier == 0 else channel * 2
            quantize_conv = torch.nn.ModuleList(
                [nn.Conv2d(conv2D_channels, embed_dim, 1) for _ in range(n_codebooks)]
            )

            self.quantize_convs.append(quantize_conv)
        
        dec_blocks = [
            Decoder(
                channel * len(self.n_hier), channel, channel, n_res_block, n_res_channel, stride = stride
            ) for i in range(n_coder_blocks - 1)
        ]

        self.decoder = nn.Sequential(*dec_blocks)

        # ------------------------- 3, high quality image reconstruction ------------------------- #
        upscale = 2
        num_feat = 64
        self.conv_after_body = nn.Conv2d(channel, embed_dim, 3, 1, 1)
        
        self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, 12, 3, 1, 1), nn.LeakyReLU()
                )
        self.upsample_with_high = Upsample(upscale, 12)
        self.conv_last = nn.Conv2d(12, 3, 3, 1, 1)
        
        # ------------------------- 4, high color image ------------------------- #
        # The 16 tiles produced by crop_tensor (default scale 4) are the Conv3d
        # channels; kernel 2 / stride 2 halves the remaining three axes.
        self.scale_conv = nn.Conv3d(16, 16, 2,2)
        self.attention = Attention(1024, 8, 16)
        self.descale_conv = nn.Sequential(
            nn.ConvTranspose2d(num_feat, channel, 4, 2, 1),
            nn.ConvTranspose2d(channel, channel, 4, 2, 1),
        )
        
        self.conv_great = nn.Conv2d(128,3,1,1)
        
    def crop_tensor(self, image_pack, scale = 4):
        """Split ``(B, C, W, H)`` into ``scale * scale`` tiles stacked on dim 1."""
        _, _, w, h = image_pack.size()
        a = int(w/scale)
        b = int(h/scale)
        t = torch.split(image_pack, a, dim = 2)
        ans = []
        for i in t:
            for j in torch.split(i,b, dim=3):
                ans.append(j)
        d = torch.stack(ans, 1)
        return d
    
    def cat_tensor(self, x, scale = 4):
        """Inverse of ``crop_tensor``: tile ``i * scale + j`` goes to grid row ``i``, column ``j``."""
        data = []
        for i in range(scale):
            m = []
            for j in range(scale):
                m.append(x[:, i*scale + j ,:,:,:])
            data.append(torch.cat(m, dim = -1))
        data = torch.cat(data, dim = -2)
        return data
        
    def high_color(self, x):
        """Attention-gated "high colour" branch.

        Steps: tile the input 4 x 4, apply the strided Conv3d, run
        self-attention with one token per spatial position (features are
        tile x channel), gate the features with the attention output,
        re-assemble the tiles and upsample with two transposed convolutions.
        In the recorded run the input is (1, 128, 112, 112) and the tiled
        tensor is (1, 16, 128, 28, 28).
        """
        # NOTE(review): debug output left in place; it prints on every forward pass.
        print("scale in:", x.shape)
        x_divide = self.crop_tensor(x)
        print("scale out:", x_divide.shape)
        x = self.scale_conv(x_divide)
       
        B,C1,C,W,H = x.shape
        # (B, C1*C, W*H) -> (B, W*H, C1*C): tokens are spatial positions.
        x1 = x.view([-1, C1 * C, W * H]).permute(0,2,1)
        a_x = self.attention(x1,x1,x1)
        x = a_x * x1
        # Fix: undo the permute before restoring (B, C1, C, W, H). Viewing the
        # (B, W*H, C1*C) tensor directly reinterpreted the token axis as the
        # tile/channel axes and scrambled the feature map.
        x = x.permute(0, 2, 1).reshape([-1, C1, C, W, H])
        x = self.cat_tensor(x)
        x = self.descale_conv(x)
        return x
        
    def high_reconstruct(self, x, origin_image):
        """Fuse ``x`` with the stem features and upsample to the RGB output.

        In the recorded run ``origin_image`` is (1, 32, 224, 224), ``x`` is
        (1, 128, 224, 224) and the result is (1, 3, 448, 448).
        """
        # NOTE(review): the prints below are debug output of intermediate shapes.
        print(origin_image.shape, x.shape)
        x = origin_image + self.conv_after_body(x)
        print(x.shape)
        x = self.conv_before_upsample(x)
        print(x.shape)
        x = self.upsample_with_high(x)
        print(x.shape)
        x = self.conv_last(x)
        print(x.shape)
        return x

    def forward(self, input):
        """Run the auto-encoder.

        Returns:
            ``(dec, great_color, diffs)``: the upsampled reconstruction, the
            3-channel projection of the high-colour branch, and the quantiser
            loss, which is always ``0.0`` because quantisation is bypassed.
        """
        origin_image = self.conv_before_feature(input)
        enc = self.focus(origin_image)

        # bot to top encoder blocks and encodings
        encodings = []
        for block in self.enc_blocks:
            enc = block(enc)
            print(enc.shape)
            encodings.append(enc)

        # reverse them top to bot for the decoding process
        quantize_convs = list(reversed(self.quantize_convs))
        quantizers = list(reversed(self.quantizers))
        encodings = list(reversed(encodings))
        dec_blocks = list(reversed(self.dec_blocks))
        upsample_blocks = list(reversed(self.upsample))
        upsamples = []
        # Quantizer loss; stays 0.0 while quantisation is disabled.
        diffs = 0.0

        for i, enc in enumerate(encodings):
            if i != 0:
                # Every level except the top is conditioned on the decoding of
                # the level above it (the top has no previous decoding).
                enc = torch.cat([dec, enc], 1)

            quant = self.quantize(
                quantize_convs[i], quantizers[i], enc
            )
            dec = dec_blocks[i](quant)
            upsampled = upsample_blocks[i](quant)
            upsamples.append(upsampled)

        dec = self.decoder(upsampled) if self.type else self.decoder(torch.cat(upsamples, 1))
        dec = self.dropout(dec)
     
        great_color = self.high_color(dec)
        
        dec = self.high_reconstruct(great_color, origin_image)
        great_color = self.conv_great(great_color)
        return dec, great_color, diffs

    def quantize(self, conv_block, quant_block, input):
        """Project ``input`` with each of the ``n_codebooks`` 1x1 convolutions.

        ``quant_block`` is accepted for interface compatibility but not used:
        the vector-quantisation step is disabled, so the "quantised" output is
        simply the conv projections concatenated along the channel axis.
        """
        quants = [conv_block[i](input) for i in range(self.n_codebooks)]
        return torch.cat(quants, 1)



In [10]:
# Build the model. HDAE.__init__ forwards every keyword argument unchanged to the
# VQVAE base-class constructor, so the meaning of each argument is defined there.
model = HDAE(
in_channel = 3,
channel = 128,
n_res_block = 3,
n_res_channel = 128,
n_coder_blocks = 2,
embed_dim = 32,
n_codebooks = 4,
stride = 2,
decay = 0.99,
loss_name = "mse",
vq_type = "dq",
beta = 0.25,
n_hier = [64, 128],
n_logistic_mix = 10,
)

In [11]:
batch_image = torch.randn((1,3,224,224))

In [12]:
# Smoke test of the forward pass. HDAE.forward returns (dec, great_color, diffs);
# only the shapes of the first two are displayed. The shape lines printed before
# the result come from the print() calls inside the model.
ans = model(batch_image)
ans[0].shape, ans[1].shape

torch.Size([1, 128, 28, 28])
torch.Size([1, 128, 14, 14])
scale in: torch.Size([1, 128, 112, 112])
scale out: torch.Size([1, 16, 128, 28, 28])
torch.Size([1, 32, 224, 224]) torch.Size([1, 128, 224, 224])
torch.Size([1, 32, 224, 224])
torch.Size([1, 12, 224, 224])
torch.Size([1, 12, 448, 448])
torch.Size([1, 3, 448, 448])


(torch.Size([1, 3, 448, 448]), torch.Size([1, 3, 224, 224]))

In [15]:
model = model.to('cuda:0')

In [16]:
summary(model, (3,224,224))

torch.Size([2, 128, 28, 28])
torch.Size([2, 128, 14, 14])
dec 0: torch.Size([2, 128, 14, 14])
quant: torch.Size([2, 128, 14, 14])
dec 1: torch.Size([2, 128, 28, 28])
quant: torch.Size([2, 128, 28, 28])
scale in: torch.Size([2, 128, 112, 112])
scale out: torch.Size([2, 16, 128, 28, 28])
torch.Size([2, 32, 224, 224]) torch.Size([2, 128, 224, 224])
torch.Size([2, 32, 224, 224])
torch.Size([2, 12, 224, 224])
torch.Size([2, 12, 448, 448])
torch.Size([2, 3, 448, 448])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 224, 224]             128
            Conv2d-2         [-1, 32, 112, 112]           4,128
       BatchNorm2d-3         [-1, 32, 112, 112]              64
              SiLU-4         [-1, 32, 112, 112]               0
             Focus-5         [-1, 32, 112, 112]               0
            Conv2d-6           [-1, 64, 56, 56]          32,832
              ReLU-7

In [111]:
d = torch.zeros(2, 128 ,28, 28)

In [112]:
# NOTE(review): no top-level `high_color` is defined in the visible notebook (only the
# HDAE.high_color method). The arguments match Pixel_Attention(in_channel, middle_dim,
# scale=...), so this is presumably an earlier name of that class - confirm and rename.
c = high_color(128, 128 * 2, scale = 28)

In [114]:
c(d).shape

torch.Size([2, 784, 128, 1, 1])
torch.Size([2, 784, 128])
torch.Size([2, 784, 128])


torch.Size([2, 128, 28, 28])

In [59]:
d1 = d.view([1,  28*28, 128])
d1.shape

torch.Size([1, 784, 128])

In [66]:
model.crop_tensor(d,28).squeeze().shape

torch.Size([2, 784, 128])

In [67]:
c = nn.Linear(128, 256)

In [68]:
c(d1).shape

torch.Size([1, 784, 256])

In [22]:
c = nn.Conv2d(1,1,(32,32),(32,32))

In [23]:
c(d).shape

torch.Size([1, 1, 128, 128])

In [20]:
torch.save(obj = c.state_dict(), f='test.pkl')

In [16]:
# NOTE(review): writes to the same 'test.pkl' as the cell above that saves
# c.state_dict(); whichever of the two cells runs last overwrites the other's file.
torch.save(obj = model.state_dict(), f='test.pkl')

In [21]:
# NOTE(review): as recorded below, this cell raised
# "RuntimeError: running_mean should contain 256 elements not 128" - the
# 256-channel input did not match what model.decoder[2] expected at that time.
d = torch.zeros((1,256,112,112))
model.decoder[2](d).shape

RuntimeError: running_mean should contain 256 elements not 128

In [10]:
bat = torch.zeros((1,256,224,224))
model.decoder[1](bat).shape

torch.Size([1, 256, 448, 448])

In [8]:
d.shape

torch.Size([1, 3, 224, 224])

In [52]:
c = torch.zeros((1,256,224,224))
model.decoder[1](c).shape
nn.Conv2d(256, 1, 2, 2)(c).shape

torch.Size([1, 1, 112, 112])

In [121]:
# Shape check of a stand-alone Attention layer on a CUDA device.
# NOTE(review): this rebinds `model`, which other cells use for the HDAE instance;
# any cell that expects the HDAE model must not be run after this one.
im = torch.zeros(1, 3*16, 16*28*28).to('cuda:0')
model = Attention(12544, 8, 16).to('cuda:0')
model(im,im,im).shape

torch.Size([1, 48, 12544])

In [118]:
model = Attention(12544, 8, 16).to('cuda:0')
summary(model, [[3*16, 16*28*28], [3*16, 16*28*28], [3*16, 16*28*28]])

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1              [-1, 48, 784]       9,835,280
            Linear-2              [-1, 48, 784]       9,835,280
            Linear-3              [-1, 48, 784]       9,835,280
            Linear-4            [-1, 48, 12544]       9,847,040
Total params: 39,352,880
Trainable params: 39,352,880
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 5.46
Params size (MB): 150.12
Estimated Total Size (MB): 155.57
----------------------------------------------------------------


In [63]:
o = torch.zeros((1,1,224,224))

In [68]:
2016/8

252.0

In [19]:
# Experiment: two Sequential containers built from the same layer object. The
# containers are distinct objects (different id()), but both hold the very same
# Conv2d instance and therefore share its weights.
d = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1, stride=1)
d_list = [d]
c, c1 = nn.Sequential(*d_list), nn.Sequential(*d_list)

In [23]:
id(c)

2543643296768

In [24]:
id(c1)

2543643371664

In [60]:
len(model.dec_blocks),len(model.decoder)

(4, 3)

In [61]:
model.n_hier

[32, 64, 128, 256]

In [57]:
# Experiment: alternate transposed and regular convolutions whose stride equals
# the kernel size, and check the resulting spatial size for a 224 x 224 input
# (224 -> 672 -> 224 -> 672 -> 336 -> 672 -> 336).
channel = 1
c = nn.Sequential(*[
    layer_type(channel, channel, size, stride=size)
    for layer_type, size in (
        (nn.ConvTranspose2d, 3),
        (nn.Conv2d, 3),
        (nn.ConvTranspose2d, 3),
        (nn.Conv2d, 2),
        (nn.ConvTranspose2d, 2),
        (nn.Conv2d, 2),
    )
])
a = torch.zeros(1, 1, 224, 224)
c(a).shape

torch.Size([1, 1, 336, 336])