In [114]:
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import copy
import time
import random
import argparse
# Define a resnet block
class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out
def get_grid(batchsize, rows, cols, gpu_id=0, dtype=torch.float32):
    hor = torch.linspace(-1.0, 1.0, cols)
    hor.requires_grad = False
    hor = hor.view(1, 1, 1, cols)
    hor = hor.expand(batchsize, 1, rows, cols)
    ver = torch.linspace(-1.0, 1.0, rows)
    ver.requires_grad = False
    ver = ver.view(1, 1, rows, 1)
    ver = ver.expand(batchsize, 1, rows, cols)

    t_grid = torch.cat([hor, ver], 1)
    t_grid.requires_grad = False

    if dtype == torch.float16: t_grid = t_grid.half()
    return t_grid.cuda(gpu_id)

In [124]:
class BaseNetwork(nn.Module):
    def __init__(self):
        super(BaseNetwork, self).__init__()

    def grid_sample(self, input1, input2):
        if self.opt.fp16: # not sure if it's necessary
            return torch.nn.functional.grid_sample(input1.float(), input2.float(), mode='bilinear', padding_mode='border').half()
        else:
            return torch.nn.functional.grid_sample(input1, input2, mode='bilinear', padding_mode='border')

    #计算光流warp
    def resample(self, image, flow, normalize=True):        
        b, c, h, w = image.size()        
        if not hasattr(self, 'grid') or self.grid.size() != flow.size():
            self.grid = get_grid(b, h, w, gpu_id=flow.get_device(), dtype=flow.dtype)            
        if normalize:
            flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1)        
        final_grid = (self.grid + flow).permute(0, 2, 3, 1).cuda(image.get_device())
        output = self.grid_sample(image, final_grid)
        return output

class Part_Cloth(BaseNetwork):
    def __init__(self, opt, input_nc_1, input_nc_2, output_nc, ngf, n_downsampling, n_blocks, use_fg_model=False, no_flow=False,
                norm_layer=nn.BatchNorm2d, padding_type='reflect'):
        assert(n_blocks >= 0)
        super(Part_Cloth, self).__init__()                
        self.opt = opt
        self.n_downsampling = n_downsampling
        self.use_fg_model = use_fg_model
        self.no_flow = no_flow
        activation = nn.ReLU(True)
        
        ### flow and image generation
        ### downsample,对这三个输入分别进行下采样        
        model_down_1 = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc_1, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
        for i in range(n_downsampling):
            mult = 2**i
            model_down_1 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                               norm_layer(ngf * mult * 2), activation]  

        mult = 2**n_downsampling
        for i in range(n_blocks - n_blocks//2):
            model_down_1 += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
        
        model_down_2 = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc_2, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
        model_down_2 += copy.deepcopy(model_down_1[4:])
        
        #只输入两部分好了
#         model_down_lo = [nn.ReflectionPad2d(3), nn.Conv2d(prev_output_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
#         model_down_lo += copy.deepcopy(model_down_T[4:])
    
    
    
        ### resnet blocks 
        model_res_part = []
        for i in range(n_blocks//2):
            model_res_part += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]

        ### upsample
        model_up_part = []
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model_up_part += [nn.ConvTranspose2d(ngf*mult, ngf*mult//2, kernel_size=3, stride=2, padding=1, output_padding=1),
                             norm_layer(ngf*mult//2), activation]  
            
        ### 最后再用卷积处理一下生成12个通道，最后再使用softmax处理生成粗糙的结果
        model_final_part = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model_final_softmax = [nn.Softmax(dim=1)]
        #model_final_logsoftmax = [nn.LogSoftmax(dim=1)]

        
        #计算这三帧的flow
        model_res_flow = copy.deepcopy(model_res_part)
        model_up_flow = copy.deepcopy(model_up_part)
        model_final_flow = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 2, kernel_size=7, padding=0)]                
        model_final_w = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 1, kernel_size=7, padding=0), nn.Sigmoid()] 

        #将网络连接在一起
        self.model_down_1 = nn.Sequential(*model_down_1)        
        self.model_down_2 = nn.Sequential(*model_down_2)        
        #self.model_down_lo = nn.Sequential(*model_down_lo)        
        self.model_res_part = nn.Sequential(*model_res_part)
        self.model_up_part = nn.Sequential(*model_up_part)
        self.model_final_part = nn.Sequential(*model_final_part)
        self.model_final_softmax = nn.Sequential(*model_final_softmax)
        #self.model_final_logsoftmax = nn.Sequential(*model_final_logsoftmax)
        self.model_res_flow = nn.Sequential(*model_res_flow)
        self.model_up_flow = nn.Sequential(*model_up_flow)
        self.model_final_flow = nn.Sequential(*model_final_flow)
        self.model_final_w = nn.Sequential(*model_final_w)
    

        
    #生成局部衣服
    def forward(self, t_part, s_parsing, t_prev,use_raw_only):
        #print("input_T, input_S, lo_prev, use_raw_only",input_T.shape, input_S.shape, lo_prev.shape, use_raw_only)
        #torch.Size([4, 12, 256, 192]) torch.Size([4, 9, 256, 192]) torch.Size([4, 24, 256, 192])
        gpu_id = t_part.get_device()
        print(gpu_id)
        t_part = torch.cat((t_prev,t_part),axis=1)
        print(t_part.shape)
        print(s_parsing.shape)
        
        #对这两个数据分别进行下采样
        downsample_1 = self.model_down_1(s_parsing)   #3,
        downsample_2 = self.model_down_2(t_part)      #9
       
        print(downsample_1.shape)
        print(downsample_2.shape)
        
        #得到特征图并生成图像
        part_feat = self.model_up_part(self.model_res_part(downsample_1+downsample_2))
        part_raw = self.model_final_part(part_feat)
        
        #是否需要计算flow
        flow = weight = flow_feat = None
        if not self.no_flow:
            print("flow")
            flow_feat = self.model_up_flow(self.model_res_flow(downsample_1))
            flow = self.model_final_flow(flow_feat) * 20
            weight = self.model_final_w(flow_feat)
        
        #是否需要warp
        if use_raw_only or self.no_flow:
            part_final = part_raw
        else:
            print("warp")
            part_warp = self.resample(t_prev[:,-3:,...].cuda(gpu_id), flow).cuda(gpu_id)        
            weight_ = weight.expand_as(part_raw)
            part_final = part_raw * weight_ + part_warp * (1-weight_)

                
        print(part_final.shape)      
        return part_final, part_raw, flow, weight
        #fake_slo, fake_slo_raw, fake_slo_ls, fake_slo_raw_ls, flow, weight
        
#这3部分如何处理呢
t_1 = Variable(torch.rand(4,3,256,192).cuda(1))#  1
s_1 = Variable(torch.rand(4,3,256,192).cuda(1))#0,1,2,背景，上衣，下衣 3
t_prev = Variable(torch.rand(4,6,256,192).cuda(1))# 2
#s_1 = Variable(torch.rand(2,3,256,196))，这是gt

parser = argparse.ArgumentParser()     
parser.add_argument('--batch', type=int, default=1, help='input batch size')
parser.add_argument('--fp16', action='store_true', default=False, help='train with AMP')

opt = parser.parse_args(args=[])


#模型参数 opt, input_nc_1, input_nc_2, output_nc, prev_output_nc, ngf, n_downsampling, n_blocks,
input_nc_1 = 3
input_nc_2 = 9
output_nc = 3
ngf = 64
n_downsampling = 3
n_blocks = 9
use_raw_only = False


device = torch.device('cuda:1')
net = Part_Cloth(opt,input_nc_1,input_nc_2,output_nc,ngf,n_downsampling,n_blocks)
net.to(device)

t_part = t_1
s_parsing = s_1

#保证输入是这三个
output = net(t_part, s_parsing, t_prev, use_raw_only)

#输入：目标局部衣服，源语义图，前两帧生成的目标局部衣服





1
torch.Size([4, 9, 256, 192])
torch.Size([4, 3, 256, 192])
torch.Size([4, 512, 32, 24])
torch.Size([4, 512, 32, 24])
flow
warp
torch.Size([4, 3, 256, 192])


In [2]:

for i in range(0, 1, 12):
    print(i)

0


In [5]:
print(list(range(0, 12, 1)))

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]


In [7]:
a = [1,2,3,4,5,6,7,8]
print(a[:-1])

[1, 2, 3, 4, 5, 6, 7]


In [10]:
import torch
a = torch.rand(2,3)
print(a)
b = torch.rand(2,3)
print(b)
c = torch.cat([a,b],dim=1)
print(c)

tensor([[0.6612, 0.9598, 0.9742],
        [0.3327, 0.2073, 0.9006]])
tensor([[0.4617, 0.8045, 0.2320],
        [0.7412, 0.9151, 0.4362]])
tensor([[0.6612, 0.9598, 0.9742, 0.4617, 0.8045, 0.2320],
        [0.3327, 0.2073, 0.9006, 0.7412, 0.9151, 0.4362]])


In [9]:
import torch
a = torch.randn(1,12,16,16)
b = a[0].max(0, keepdim=True)[1]#0是具体的值，1是index
print(a[0].shape)
print(b.shape)

torch.Size([12, 16, 16])
torch.Size([1, 16, 16])


In [13]:
a = torch.rand(3,3)
b = a.max(0,keepdim=True)[1]
print(a.shape)
print(b.shape)

torch.Size([3, 3])
torch.Size([1, 3])


In [1]:
a = [1,2,3,4]
print(a[-1:])


[4]


In [18]:
import torch


def get_affinity(mk, qk):
        B, CK, h, w = mk.shape  #2,5,16,16
        mk = mk.flatten(start_dim=2)
        qk = qk.flatten(start_dim=2)

        print("mk",mk.shape)#2,5,256
        print("qk",qk.shape)#2,5,256

        #先所有元素平方，然后c通道相加
        m1 = torch.sqrt(mk.pow(2).sum(1).unsqueeze(2))#2,5,256--->2,256---->2,256,1
        q1 = torch.sqrt(qk.pow(2).sum(1).unsqueeze(1))#2,5,256--->2,256---->2,1,256
        print("m1",m1.shape)
        print("q1",q1.shape)
        mq = m1@q1   #2,256,1*2,1,256 = 2,256,256
      

        #计算余弦距离
        b = (mk.transpose(1, 2) @ qk)#2,256,5*2,5,256--->2,256,256
        
        affinity = b/ mq   # B, THW, HW
        #print("affinity",affinity.shape)

        # softmax operation; aligned the evaluation style
        x_exp = torch.exp(affinity)#2,256,256
        x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)#2,256,256-->2,1,256
   
        affinity = x_exp / x_exp_sum 

        #print(affinity.sum(1))

        return affinity

    #mv是补充信息
def fusion(mv,qv):
        B, CV, H, W = qv.shape
        affinity = get_affinity(mv,qv)
        mo = mv.view(B, CV, H*W)#2,12,16
        mem = torch.bmm(mo, affinity)#2,12,16*2,16,16 = 2,12,16
        mem = mem.view(B, CV, H, W)
        mem_out = qv+mem

        return mem_out
    
mv = torch.randn((2,5,16,16))
qv = torch.randn((2,5,16,16))
out = fusion(mv,qv)
print(out.shape)

mk torch.Size([2, 5, 256])
qk torch.Size([2, 5, 256])
torch.Size([2, 256])
m1 torch.Size([2, 256, 1])
q1 torch.Size([2, 1, 256])
torch.Size([2, 5, 16, 16])
