In [1]:
# Work around the Jupyter kernel dying on start-up by letting a duplicate
# OpenMP runtime load instead of aborting.
# NOTE(review): this presumably targets the Intel OpenMP "libiomp5 already
# initialized" abort; it masks the conflicting installs rather than fixing them.
import os
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

In [2]:
from tensorboardX import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
import os
from torch.utils.data import Dataset, DataLoader
from skimage.transform import resize
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.autograd import Variable
from torch import LongTensor, FloatTensor
from torchvision.utils import save_image
import sys
# CUDA device index used by every `.cuda(gpu_id)` call in this notebook.
# Override it with the GPU_ID environment variable; defaults to device 0 as before.
gpu_id = int(os.environ.get("GPU_ID", 0))

In [3]:

###  General definitions: dataset, weight initialisation, conv building blocks
class NWSDataset(Dataset):
    """
    NWS Dataset: real training samples mixed with previously generated samples.

    The first ``val = int(n * c ** i)`` real samples are concatenated with all
    but the last ``val`` generated samples. For ``c < 1`` the number of real
    samples kept therefore shrinks geometrically as the stage index ``i`` grows.
    Both tensors are moved to the GPU selected by the global ``gpu_id``.

    Args:
        fake: path of the generated-sample tensor file (read with ``torch.load``).
        c: per-stage decay factor for the number of real samples kept.
        i: stage index; larger values keep fewer real samples.
        n: dataset size the schedule is based on (number of real samples).
        real: path of the real training-sample tensor file.
    """

    def __init__(
            self, fake='../data/fake.pt', c=0.75, i=1, n=2557, real='../data/train.pt'
    ):
        val = int(n * (c ** i))
        self.real = torch.load(real).cuda(gpu_id)
        self.real.requires_grad = False
        self.fake = torch.load(fake).cuda(gpu_id)

        # Keep all but the last `val` fake samples. Written as an explicit stop
        # index because `self.fake[:-val]` becomes `self.fake[:0]` (empty) when
        # val == 0; clamping at 0 keeps the result empty when val > len(fake).
        keep_fake = max(self.fake.shape[0] - val, 0)
        # Both the real training data and the generated results form the training set.
        self.realdata = torch.cat([self.real[:val], self.fake[:keep_fake]], 0)

    def __len__(self):
        return self.realdata.shape[0]

    def __getitem__(self, item):
        return self.realdata[item]


def weights_init_normal(m):
    """Re-draw convolution weights from N(0, 0.02**2); meant for ``module.apply``.

    Any module whose class name contains "Conv" (e.g. Conv2d, ConvTranspose2d)
    has its ``weight`` overwritten in place. Biases and every other module type
    are left untouched.
    """
    if "Conv" not in type(m).__name__:
        return
    torch.nn.init.normal_(m.weight.data, 0.0, 0.02)


def convTBNReLU(in_channels, out_channels, kernel_size=4, stride=2, padding=1):
    """Upsampling unit: ConvTranspose2d -> InstanceNorm2d -> LeakyReLU(0.2, in place).

    Despite "BN" in the name, the normalisation layer is InstanceNorm2d.
    With the defaults (kernel 4, stride 2, padding 1) the spatial size doubles.
    """
    layers = [
        nn.ConvTranspose2d(in_channels, out_channels,
                           kernel_size=kernel_size, stride=stride, padding=padding),
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(0.2, True),
    ]
    return nn.Sequential(*layers)


def convBNReLU(in_channels, out_channels, kernel_size=4, stride=2, padding=1):
    """Downsampling unit: Conv2d -> InstanceNorm2d -> LeakyReLU(0.2, in place).

    Despite "BN" in the name, the normalisation layer is InstanceNorm2d.
    With the defaults (kernel 4, stride 2, padding 1) the spatial size halves.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels,
                  kernel_size=kernel_size, stride=stride, padding=padding),
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(0.2, True),
    ]
    return nn.Sequential(*layers)


In [4]:
## Model definitions
class Generator(nn.Module):
    """DCGAN-style generator built from transposed-convolution stages.

    For a (B, in_channels, 1, 1) latent input (the shape the training loop
    samples), the spatial size grows 1 -> 4 -> 8 -> 16 -> 32 -> 64. The final
    tanh bounds the output to [-1, 1].
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Attribute names block1..block5 are the state_dict keys of the
        # checkpoint loaded later, so they must not change.
        self.block1 = convTBNReLU(in_channels, 512, 4, 1, 0)          # 1x1 -> 4x4
        self.block2 = convTBNReLU(512, 256)                           # -> 8x8
        self.block3 = convTBNReLU(256, 128)                           # -> 16x16
        self.block4 = convTBNReLU(128, 64)                            # -> 32x32
        self.block5 = nn.ConvTranspose2d(64, out_channels, 4, 2, 1)   # -> 64x64

    def forward(self, noise):
        out = noise
        for stage in (self.block1, self.block2, self.block3, self.block4):
            out = stage(out)
        return torch.tanh(self.block5(out))


class Discriminator(nn.Module):
    """DCGAN-style discriminator returning a per-sample probability of shape (B, 1).

    For 64x64 inputs, block1-4 halve the spatial size (32, 16, 8, 4) and
    block5 (4x4 kernel, no padding) reduces it to 1x1. This yields the 64
    features expected by the linear head; other input sizes would not match
    ``nn.Linear(64, 1)``.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        # Attribute names are the state_dict keys of the loaded checkpoint.
        self.block1 = convBNReLU(self.in_channels, 64)
        self.block2 = convBNReLU(64, 128)
        self.block3 = convBNReLU(128, 256)
        self.block4 = convBNReLU(256, 512)
        self.block5 = nn.Conv2d(512, 64, 4, 1, 0)
        self.source = nn.Linear(64, 1)

    def forward(self, input):
        features = input
        for stage in (self.block1, self.block2, self.block3, self.block4, self.block5):
            features = stage(features)
        flat = features.view(features.shape[0], -1)
        return torch.sigmoid(self.source(flat))


In [5]:
### General parameters: seed, latent size, loss, models, optimizers
SEED = 0
# Seed the global torch RNGs so static_z and the noise sampled during training repeat across runs.
torch.manual_seed(SEED)
latentdim = 20  # length of the generator's latent vector
criterionSource = nn.BCELoss()
G = Generator(in_channels=latentdim, out_channels=1).cuda(gpu_id)
D = Discriminator(in_channels=1).cuda(gpu_id)
# Note: these random initialisations are overwritten if pretrained weights are loaded afterwards.
G.apply(weights_init_normal)
D.apply(weights_init_normal)

optimizerG = optim.Adam(G.parameters(), lr=0.00002, betas=(0.5, 0.999))
optimizerD = optim.Adam(D.parameters(), lr=0.00001, betas=(0.5, 0.999))
# Fixed latent batch of 81 vectors (a 9x9 grid in sample_image) so progress images stay comparable.
static_z = torch.randn((81, latentdim, 1, 1)).cuda(gpu_id)


def sample_image(stage, epoch):
    """Save a 9-per-row grid of G's outputs for the fixed latent batch ``static_z``.

    G ends in tanh, so its [-1, 1] output is mapped to [0, 1] before being
    written to ``DIRNAME + "stage<stage>epoch<epoch>.png"``.
    """
    images = G(static_z).detach().cpu()
    images = images.add(1).div(2.0)
    out_path = DIRNAME + "stage%depoch%d.png" % (stage, epoch)
    save_image(images, out_path, nrow=9)


c = 0.75  # per-stage decay factor passed to NWSDataset
k = 10    # the training loop runs stages range(1, k), i.e. k - 1 stages
DIRNAME = '../DistShift/'
os.makedirs(DIRNAME, exist_ok=True)
board = SummaryWriter(log_dir=DIRNAME)

# Start from the pretrained DCGAN checkpoints (this replaces weights_init_normal's values).
# map_location puts the tensors on the selected GPU, so loading still works when the
# checkpoint was saved from a device index that does not exist on this machine.
G.load_state_dict(torch.load('../DCGAN/G999.pt', map_location='cuda:%d' % gpu_id))
D.load_state_dict(torch.load('../DCGAN/D999.pt', map_location='cuda:%d' % gpu_id))
step = 0  # global TensorBoard step, incremented once per batch by the training loop
fake_name = '../data/fake.pt'  # generated-sample file used by the first stage

In [6]:
### Iterative training
# Stage i trains on NWSDataset(i): the first int(n * c**i) real samples plus generated
# samples from the previous stage. It then writes a fresh fake file for stage i + 1.
n = 2557  # number of real samples; must match the actual size of the training dataset
for i in range(1, k):
    dataloader = DataLoader(NWSDataset(fake=fake_name, c=c, i=i, n=n), batch_size=256, shuffle=True)
    for epoch in range(0, 100):
        print(i, "      ", epoch)
        for realdata in dataloader:
            step += 1
            # The loader yields one stacked tensor per batch, so dim 0 is the batch size.
            # (realdata[0].shape[0] is the first *sample's* leading dim, typically the
            # channel count; for single-channel images that is 1, which made every step
            # generate a single fake and share one label across the whole real batch.)
            batch_size = realdata.shape[0]

            # Noisy soft labels: real in [0.7, 1.2), fake in [0, 0.3), swapped for ~5% of samples.
            # NOTE(review): real targets above 1.0 fall outside BCE's usual [0, 1] range;
            # the flipped labels are also reused for the generator loss. Confirm both are intended.
            trueTensor = 0.7 + 0.5 * torch.rand(batch_size)
            falseTensor = 0.3 * torch.rand(batch_size)
            probFlip = (torch.rand(batch_size) < 0.05).float()
            trueTensor, falseTensor = (
                probFlip * falseTensor + (1 - probFlip) * trueTensor,
                probFlip * trueTensor + (1 - probFlip) * falseTensor,
            )
            trueTensor = trueTensor.view(-1, 1).cuda(gpu_id)
            falseTensor = falseTensor.view(-1, 1).cuda(gpu_id)

            # Discriminator update: real batch vs. a detached fake batch of the same size.
            realdata = realdata.cuda(gpu_id)
            realSource = D(realdata)
            realLoss = criterionSource(realSource, trueTensor.expand_as(realSource))
            latent = torch.randn(batch_size, latentdim, 1, 1).cuda(gpu_id)
            fakeGen = G(latent)
            fakeGenSource = D(fakeGen.detach())
            fakeGenLoss = criterionSource(fakeGenSource, falseTensor.expand_as(fakeGenSource))
            lossD = realLoss + fakeGenLoss
            optimizerD.zero_grad()
            lossD.backward()
            torch.nn.utils.clip_grad_norm_(D.parameters(), 20)
            optimizerD.step()

            # Generator update: push D's score on the same fakes towards the "real" labels.
            # This backward also writes grads into D; optimizerD.zero_grad() clears them next step.
            fakeGenSource = D(fakeGen)
            lossG = criterionSource(fakeGenSource, trueTensor.expand_as(fakeGenSource))
            optimizerG.zero_grad()
            lossG.backward()
            torch.nn.utils.clip_grad_norm_(G.parameters(), 20)
            optimizerG.step()

            board.add_scalar('realLoss', realLoss.item(), step)
            board.add_scalar('fakeGenLoss', fakeGenLoss.item(), step)
            board.add_scalar('lossD', lossD.item(), step)
            board.add_scalar('lossG', lossG.item(), step)
        if (epoch + 1) % 50 == 0:
            torch.save(G.state_dict(), DIRNAME + "Gstage" + str(i) + 'epoch' + str(epoch) + ".pt")
            torch.save(D.state_dict(), DIRNAME + "Dstage" + str(i) + 'epoch' + str(epoch) + ".pt")
        if (epoch + 1) % 10 == 0:
            with torch.no_grad():
                G.eval()
                sample_image(i, epoch)
                G.train()

    # Generate the fake set for the next stage, ordered by descending total pixel value.
    with torch.no_grad():
        G.eval()
        # NOTE(review): size formula kept unchanged; confirm it matches what NWSDataset(i + 1) slices.
        fsize = int((1 - (c ** (i + 1))) * n / c)
        fakeSamples = G(torch.randn(fsize, latentdim, 1, 1).cuda(gpu_id))
        order = fakeSamples.sum(dim=(1, 2, 3)).argsort(descending=True)
        fake_name = DIRNAME + 'fake' + str(i + 1) + '.pt'
        torch.save(fakeSamples[order], fake_name)
        del fakeSamples
        G.train()

1        0
1        1
1        2
1        3
1        4
1        5
1        6
1        7
1        8
1        9
1        10
1        11
1        12
1        13
1        14
1        15
1        16
1        17
1        18
1        19
1        20
1        21
1        22
1        23
1        24
1        25
1        26
1        27
1        28
1        29
1        30
1        31
1        32
1        33
1        34
1        35
1        36
1        37
1        38
1        39
1        40
1        41
1        42
1        43
1        44
1        45
1        46
1        47
1        48
1        49
1        50
1        51
1        52
1        53
1        54
1        55
1        56
1        57
1        58
1        59
1        60
1        61
1        62
1        63
1        64
1        65
1        66
1        67
1        68
1        69
1        70
1        71
1        72
1        73
1        74
1        75
1        76
1        77
1        78
1        79
1        80
1        81
1        82
1        83
1 