In [17]:
import jittor as jt
from jittor import init
import argparse
import os
import numpy as np
import math
from jittor import nn

# Run on the GPU when this Jittor build has CUDA support; otherwise the
# use_cuda flag is left untouched.
if jt.has_cuda:
    jt.flags.use_cuda = 1


In [18]:
import sys
# Replace the command line with a bare program name so that argparse's
# parse_args() sees no options and falls back to every default.
# Presumably needed because a Jupyter kernel is launched with its own
# arguments (e.g. a connection-file flag) that argparse would reject — confirm.
# NOTE(review): this also discards any real command-line arguments if this
# code is ever exported and run as a script.
sys.argv = ['run.py']
    

In [19]:
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--n_classes', type=int, default=10, help='number of classes for dataset')
parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
parser.add_argument('--sample_interval', type=int, default=1000, help='interval between image sampling')
opt = parser.parse_args()
print(opt)

Namespace(n_epochs=100, batch_size=64, lr=0.0002, b1=0.5, b2=0.999, n_cpu=8, latent_dim=100, n_classes=10, img_size=32, channels=1, sample_interval=1000)


In [13]:
# Shape of a single image as (channels, height, width); images are square,
# so height and width are both opt.img_size.
img_shape = (opt.channels,) + (opt.img_size,) * 2

In [20]:
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, class label) to an image.

    The class label is embedded into an ``opt.n_classes``-dimensional vector,
    concatenated in front of the noise vector, and passed through an MLP whose
    final Tanh output is reshaped to ``img_shape``.

    Reads the module-level globals ``opt`` (``n_classes``, ``latent_dim``) and
    ``img_shape``.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        def block(in_feat, out_feat, normalize=True):
            """Return the layers [Linear, optional BatchNorm1d, LeakyReLU].

            ``nn.Linear(in_feat, out_feat)`` is a fully connected layer mapping
            an ``in_feat``-dimensional input to an ``out_feat``-dimensional
            output.
            """
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # The second positional parameter of BatchNorm1d is eps (the
                # constant added to the variance), NOT momentum. Pass it by
                # keyword so 0.8 is not misread; it is an unusually large eps,
                # kept as-is to preserve the existing training behaviour.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            # Leaky ReLU: negative inputs are scaled by 0.2 instead of being
            # clamped to 0.
            layers.append(nn.LeakyReLU(0.2))
            return layers

        # Input width is latent_dim + n_classes because execute() concatenates
        # the label embedding with the noise vector.
        self.model = nn.Sequential(*block((opt.latent_dim + opt.n_classes), 128, normalize=False), 
                                   *block(128, 256), 
                                   *block(256, 512), 
                                   *block(512, 1024), 
                                   nn.Linear(1024, int(np.prod(img_shape))), 
                                   nn.Tanh())

    def execute(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: latent vectors of shape (batch, opt.latent_dim); the first
                Linear layer expects latent_dim + n_classes features after the
                label embedding is prepended.
            labels: integer class indices in [0, opt.n_classes), one per sample.

        Returns:
            A Var of shape (batch, *img_shape) with values in [-1, 1] (Tanh).
        """
        gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)
        img = self.model(gen_input)
        # Reshape each flat prod(img_shape) vector into a (C, H, W) image.
        img = img.view((img.shape[0], *img_shape))
        return img

In [26]:
# Scratch check: a standalone embedding table configured like
# Generator.label_emb (opt.n_classes entries, each an opt.n_classes-dim vector).
# NOTE(review): exploratory cell; consider removing it from the final notebook.
test = nn.Embedding(opt.n_classes, opt.n_classes)

In [24]:
# Four sample class indices; the repeated 3 shows that equal labels map to
# identical embedding rows.
x = jt.int32([1, 2, 3, 3])

In [27]:
# Look up the embeddings: one float32 row of length opt.n_classes per index.
test(x)

jt.Var([[-0.7081189   0.21343614  1.631101   -0.04060285  0.5380743  -0.17995544
          0.5088695  -1.3160007  -0.30307457 -1.088641  ]
        [-0.4683468  -1.0584054  -0.46382385 -0.33939174  0.98749816  0.38035026
         -2.1871002  -0.7143253  -1.0805888   1.6664243 ]
        [-0.6475718   1.8264806  -1.3996431   1.2343264   0.8505684  -1.7552401
         -0.3257211  -1.0909462   0.09111369 -0.9528391 ]
        [-0.6475718   1.8264806  -1.3996431   1.2343264   0.8505684  -1.7552401
         -0.3257211  -1.0909462   0.09111369 -0.9528391 ]], dtype=float32)

In [15]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator that scores (image, label) pairs.

    Each image is flattened, its label embedding (``opt.n_classes`` dims) is
    appended, and the result goes through a 512-wide MLP with LeakyReLU and
    dropout. The final layer is a bare Linear(512, 1): the score is an
    unbounded real number (no sigmoid).

    Reads the module-level globals ``opt`` and ``img_shape``.
    """

    def __init__(self):
        super().__init__()
        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)

        width = 512
        in_features = opt.n_classes + int(np.prod(img_shape))
        layers = [nn.Linear(in_features, width), nn.LeakyReLU(0.2)]
        # Two identical hidden stages: Linear -> Dropout(0.4) -> LeakyReLU(0.2).
        for _ in range(2):
            layers += [nn.Linear(width, width), nn.Dropout(0.4), nn.LeakyReLU(0.2)]
        # Single real-valued validity score per sample.
        layers.append(nn.Linear(width, 1))
        self.model = nn.Sequential(*layers)

    def execute(self, img, labels):
        """Return a (batch, 1) validity score for each image/label pair."""
        flat_imgs = img.view((img.shape[0], (- 1)))
        d_in = jt.contrib.concat((flat_imgs, self.label_embedding(labels)), dim=1)
        return self.model(d_in)


In [16]:
# Dataset and transform imports, kept at the top of this cell. The transform
# module is aliased so that binding `transform` to the pipeline below does not
# shadow the module itself.
from jittor.dataset.mnist import MNIST
import jittor.transform as jt_transform

# Loss function: squared error.
# Usage: adversarial_loss(network output A, target label B)
# Result: (A - B)^2 reduced to a scalar over all elements
# (NOTE: assumes Jittor's default mean reduction — confirm if a sum is wanted).
adversarial_loss = nn.MSELoss()

generator = Generator()
discriminator = Discriminator()

# MNIST training set: resize to opt.img_size, convert to one grey channel, then
# normalise with x -> (x - 0.5) / 0.5. If Gray() yields values in [0, 1]
# (assumed), this maps pixels to [-1, 1], matching the generator's Tanh range.
transform = jt_transform.Compose([
    jt_transform.Resize(opt.img_size),
    jt_transform.Gray(),
    jt_transform.ImageNormalize(mean=[0.5], std=[0.5]),
])
dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)

# One Adam optimizer per network, sharing lr and (b1, b2) from the options.
optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to C:\Users\zxy08\.cache\jittor\dataset/mnist_data/train-images-idx3-ubyte.gz


9.46MB [00:05, 1.96MB/s]                            


Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to C:\Users\zxy08\.cache\jittor\dataset/mnist_data/train-labels-idx1-ubyte.gz


32.0kB [00:00, 49.1kB/s]                   


Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to C:\Users\zxy08\.cache\jittor\dataset/mnist_data/t10k-images-idx3-ubyte.gz


1.58MB [00:01, 1.60MB/s]                            


Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to C:\Users\zxy08\.cache\jittor\dataset/mnist_data/t10k-labels-idx1-ubyte.gz


8.00kB [00:00, 12.2kB/s]                   
