In [1]:
from torch import nn

from torch import autograd
import torch

### 定义生成器 (Generator): `__init__()`, `forward()`

In [2]:
class NetG(nn.Module):
    """DCGAN-style generator: upsamples a noise tensor into a 3x96x96 image.

    Args:
        opt: configuration object; reads ``opt.ngf`` (base feature-map width)
            and ``opt.nz`` (number of noise channels).

    ``forward(x)`` expects noise shaped (N, nz, 1, 1) and returns (N, 3, 96, 96)
    with values in [-1, 1] (Tanh output).
    """

    def __init__(self, opt):
        super().__init__()
        width = opt.ngf
        layers = [
            # (N, nz, 1, 1) -> (N, width*8, 4, 4)
            nn.ConvTranspose2d(opt.nz, width * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(width * 8),
            nn.ReLU(True),
        ]
        # Three kernel-4 / stride-2 / padding-1 stages, each doubling the spatial
        # size and halving the channels: 4x4 -> 8x8 -> 16x16 -> 32x32.
        for mult_in, mult_out in ((8, 4), (4, 2), (2, 1)):
            layers += [
                nn.ConvTranspose2d(width * mult_in, width * mult_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * mult_out),
                nn.ReLU(True),
            ]
        # Final stage, kernel 5 / stride 3 / padding 1: (N, width, 32, 32) -> (N, 3, 96, 96),
        # squashed into [-1, 1] by Tanh.
        layers += [
            nn.ConvTranspose2d(width, 3, 5, 3, 1, bias=False),
            nn.Tanh(),
        ]
        # Layers are created in the same order as a single literal Sequential, so the
        # state_dict keys (main.0 ... main.13) and weight initialization are unchanged.
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        image = self.main(x)
        return image

### 定义判别器 (Discriminator): `__init__()`, `forward()`

In [2]:
class NetD(nn.Module):
    """DCGAN-style discriminator: the mirror image of NetG.

    Args:
        opt: configuration object; only ``opt.ndf`` (base feature-map width) is read.

    ``forward(x)`` takes images shaped (N, 3, 96, 96) and returns a 1-D tensor of
    N Sigmoid scores in (0, 1).
    """

    def __init__(self, opt):
        super().__init__()
        width = opt.ndf
        blocks = [
            # (N, 3, 96, 96) -> (N, width, 32, 32); no BatchNorm on this first layer.
            nn.Conv2d(3, width, 5, 3, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three kernel-4 / stride-2 / padding-1 stages, doubling channels each time:
        # 32x32 -> 16x16 -> 8x8 -> 4x4.
        for mult_in, mult_out in ((1, 2), (2, 4), (4, 8)):
            blocks.extend([
                nn.Conv2d(width * mult_in, width * mult_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * mult_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # (N, width*8, 4, 4) -> (N, 1, 1, 1), then Sigmoid into (0, 1).
        blocks.extend([
            nn.Conv2d(width * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ])
        # Same construction order as one literal Sequential -> identical keys main.0 ... main.12.
        self.main = nn.Sequential(*blocks)

    def forward(self, x):
        scores = self.main(x)
        # Flatten (N, 1, 1, 1) -> (N,)
        return scores.view(-1)
        

In [31]:
class Net1(nn.Module):
    """Toy model: a single bias-free 3->3 channel 3x3 convolution.

    Padding 1 with stride 1 keeps the spatial size unchanged.
    """

    def __init__(self):
        super().__init__()
        conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3,
                         stride=1, padding=1, bias=False)
        # Wrapped in a Sequential so the weight is exposed as "main.0.weight".
        self.main = nn.Sequential(conv)

    def forward(self, x):
        out = self.main(x)
        return out

In [32]:
class Net2(nn.Module):
    """Toy model: a single bias-free 3->1 channel 3x3 convolution (no padding).

    ``forward`` flattens the convolution output into a 1-D tensor.
    """

    def __init__(self):
        super().__init__()
        conv = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3,
                         stride=1, padding=0, bias=False)
        # Wrapped in a Sequential so the weight is exposed as "main.0.weight".
        self.main = nn.Sequential(conv)

    def forward(self, x):
        out = self.main(x)
        return out.view(-1)

In [33]:
# Seed torch's global RNG before anything random happens, so the initial weights
# and every later torch.randn draw are reproducible on Restart & Run All.
torch.manual_seed(0)
net1, net2 = Net1(), Net2()

In [34]:
from torch.autograd import Variable
# Fixed input batch: one sample, 3 channels, 3x3. A plain tensor is enough —
# `Variable` has been a deprecated no-op wrapper since PyTorch 0.4.
fix_noises = torch.randn(1, 3, 3, 3)

In [44]:
# Forward pass through net1.
# NOTE(review): fix_1 is not used by any later visible cell — presumably kept as a
# control (net1 is never given to an optimizer); confirm intent or remove.
fix_1 = net1(fix_noises)

In [45]:
# Forward pass through net2: the 3x3 unpadded conv maps (1, 3, 3, 3) -> (1, 1, 1, 1),
# which Net2.forward flattens to shape (1,).
fix_2 = net2(fix_noises)

In [46]:
# Adam over net2's parameters only (lr 0.1); net1 is deliberately left out.
optimizer_2 = torch.optim.Adam(net2.parameters(), lr=0.1)

In [47]:
# Snapshot of net1's weights before the optimizer step.
net1.state_dict()

OrderedDict([('main.0.weight', tensor([[[[ 0.1527, -0.0862, -0.1650],
                        [-0.1514, -0.0155, -0.1031],
                        [ 0.0388,  0.0482, -0.1310]],
              
                       [[-0.1674, -0.0658,  0.0700],
                        [ 0.0371,  0.1182, -0.1781],
                        [ 0.0589, -0.0365,  0.1535]],
              
                       [[ 0.1085,  0.0635,  0.0670],
                        [ 0.1039, -0.0420,  0.1099],
                        [ 0.1594,  0.1450, -0.0445]]],
              
              
                      [[[-0.0472, -0.1280,  0.1149],
                        [ 0.0770, -0.0885,  0.0864],
                        [-0.1013,  0.1382, -0.0625]],
              
                       [[-0.0273, -0.0111, -0.0091],
                        [-0.1510,  0.1394,  0.1157],
                        [ 0.0128, -0.0388,  0.1641]],
              
                       [[ 0.0241,  0.1080,  0.1671],
                        [-0.1758, -0.08

In [48]:
# Snapshot of net2's weights before the optimizer step.
net2.state_dict()

OrderedDict([('main.0.weight', tensor([[[[-0.0616,  0.0197, -0.1520],
                        [-0.0072, -0.2792,  0.0404],
                        [ 0.0581, -0.2292, -0.2194]],
              
                       [[ 0.0127, -0.0744,  0.0668],
                        [ 0.0393,  0.2048, -0.0266],
                        [-0.0504,  0.0173, -0.2773]],
              
                       [[ 0.1115,  0.0981,  0.0080],
                        [-0.0389,  0.0490, -0.0617],
                        [-0.1009,  0.0205, -0.0838]]]]))])

In [49]:
# Clear any stale gradients on net2's parameters before backpropagating.
# (The stray no-op `fix_2` expression that used to precede this line was removed:
# it was not the cell's last line, so it displayed nothing.)
optimizer_2.zero_grad()

In [50]:
# fix_2 holds exactly one element, so backward() can use an implicit gradient of 1.
# NOTE(review): re-running only this cell should fail because autograd frees the graph
# after the first backward — re-run the forward cell first.
fix_2.backward()

In [51]:
# One Adam update; only net2's parameters are registered with optimizer_2.
optimizer_2.step()

In [52]:
# net1 after optimizer_2.step(): unchanged from the earlier snapshot, because
# optimizer_2 was built from net2.parameters() only.
net1.state_dict()

OrderedDict([('main.0.weight', tensor([[[[ 0.1527, -0.0862, -0.1650],
                        [-0.1514, -0.0155, -0.1031],
                        [ 0.0388,  0.0482, -0.1310]],
              
                       [[-0.1674, -0.0658,  0.0700],
                        [ 0.0371,  0.1182, -0.1781],
                        [ 0.0589, -0.0365,  0.1535]],
              
                       [[ 0.1085,  0.0635,  0.0670],
                        [ 0.1039, -0.0420,  0.1099],
                        [ 0.1594,  0.1450, -0.0445]]],
              
              
                      [[[-0.0472, -0.1280,  0.1149],
                        [ 0.0770, -0.0885,  0.0864],
                        [-0.1013,  0.1382, -0.0625]],
              
                       [[-0.0273, -0.0111, -0.0091],
                        [-0.1510,  0.1394,  0.1157],
                        [ 0.0128, -0.0388,  0.1641]],
              
                       [[ 0.0241,  0.1080,  0.1671],
                        [-0.1758, -0.08

In [53]:
# net2 after the step: in the recorded output every weight moved by ±0.1 (= lr),
# consistent with Adam's first update being roughly lr * sign(grad).
net2.state_dict()

OrderedDict([('main.0.weight', tensor([[[[-0.1616, -0.0803, -0.2520],
                        [ 0.0928, -0.3792, -0.0596],
                        [ 0.1581, -0.3292, -0.3194]],
              
                       [[-0.0873, -0.1744,  0.1668],
                        [-0.0607,  0.3048,  0.0734],
                        [-0.1504,  0.1173, -0.3773]],
              
                       [[ 0.2115,  0.1981,  0.1080],
                        [-0.1389,  0.1490, -0.1617],
                        [-0.2009,  0.1205,  0.0162]]]]))])