In [2]:
import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as T
from torch.utils.tensorboard import SummaryWriter
import multiprocessing
import matplotlib.pyplot as plt

# Transposed-convolution residual blocks
class Tresblock1(torch.nn.Module):
    """Residual block built around a single transposed convolution.

    Main path: ``ConvTranspose2d -> BatchNorm2d -> ReLU``.
    Shortcut:  bilinear upsampling by ``stride`` followed by a bias-free 1x1
    convolution mapping ``in_channels`` to ``out_channels``.
    The shortcut is added to the main path *after* its ReLU.
    """

    def __init__(self, in_channels:int, out_channels:int, kernel_size, device='cpu', stride=1) -> None:
        super().__init__()
        # Attribute names double as state_dict keys, so they are kept stable.
        self.convT2d = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=(kernel_size - stride) // 2,
            device=device,
        )
        self.batchnorm = nn.BatchNorm2d(out_channels, device=device)
        upsample = nn.UpsamplingBilinear2d(scale_factor=stride)
        project = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding='same', bias=False, device=device)
        self.identity = nn.Sequential(upsample, project)
        self.ReLU = nn.ReLU(True)

    def forward(self, x):
        """Return ``ReLU(BN(convT(x))) + identity(x)``."""
        activated = self.ReLU(self.batchnorm(self.convT2d(x)))
        return activated + self.identity(x)

class Tresblock2(torch.nn.Module):
    """Two-layer transposed-convolution residual block.

    Main path: ``ConvTranspose2d(stride) -> BN -> ReLU -> ConvTranspose2d(1) -> BN``.
    Shortcut:  bilinear upsampling by ``stride`` followed by a bias-free 1x1
    convolution mapping ``in_channels`` to ``out_channels``.
    The shortcut is added before the final ReLU.

    Both transposed convolutions use ``padding=1``, so the main path and the
    shortcut only agree in spatial size for ``kernel_size=3``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        kernel_size: kernel size of both transposed convolutions.
        device: device on which the parameters are created.
        stride: upsampling factor of the block (int or per-axis sequence).
    """

    def __init__(self, in_channels:int, out_channels:int, kernel_size, device='cpu', stride=1) -> None:
        super(Tresblock2, self).__init__()
        # With kernel 3 / padding 1 the first layer yields (in - 1) * stride + 1
        # pixels; ``stride - 1`` extra pixels bring it to exactly in * stride.
        # A fixed output_padding of 1 is rejected by PyTorch when stride == 1.
        if isinstance(stride, int):
            output_padding = stride - 1
        else:
            output_padding = tuple(s - 1 for s in stride)
        self.convT2d1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride,
        padding=1, output_padding=output_padding, device=device, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(out_channels, device=device)
        self.convT2d2 = nn.ConvTranspose2d(out_channels, out_channels, kernel_size, stride=1,
        padding=1, device=device, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(out_channels, device=device)
        self.identity = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=stride),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding='same', bias=False, device=device)
        )
        self.ReLU = nn.ReLU(True)

    def forward(self, x):
        """Return ``ReLU(BN2(convT2(ReLU(BN1(convT1(x))))) + identity(x))``."""
        hid = self.convT2d1(x)
        hid = self.batchnorm1(hid)
        hid = self.ReLU(hid)
        hid = self.convT2d2(hid)
        hid = self.batchnorm2(hid) + self.identity(x)
        out = self.ReLU(hid)
        return out
#resblock without bn
class resblock(torch.nn.Module):
    """Residual block without normalisation layers.

    Main path: ``conv3x3(stride, dilation) -> ReLU -> conv3x3``.
    Shortcut:  bias-free 1x1 convolution with the same ``stride``.
    The sum of both paths goes through a final ReLU.
    """

    def __init__(self, in_channel:int, out_channel:int, stride:int = 1, dilation:int=1, device = 'cpu') -> None:
        super().__init__()
        shared = dict(bias=False, device=device)
        # padding == dilation keeps a 3x3 kernel size-preserving at stride 1.
        self.conv1 = nn.Conv2d(in_channel, out_channel, 3, stride=stride,
                               padding=dilation, dilation=dilation, **shared)
        self.conv2 = nn.Conv2d(out_channel, out_channel, 3, padding='same', **shared)
        self.downsample = nn.Conv2d(in_channel, out_channel, 1, stride=stride, **shared)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``ReLU(conv2(ReLU(conv1(x))) + downsample(x))``."""
        main = self.conv2(self.relu(self.conv1(x)))
        main += self.downsample(x)
        return self.relu(main)
class resblock_bn(torch.nn.Module):
    """Residual block with batch normalisation after each 3x3 convolution.

    Main path: ``conv3x3(stride, dilation) -> BN -> ReLU -> conv3x3 -> BN``.
    Shortcut:  bias-free 1x1 convolution with the same ``stride``.
    The sum of both paths goes through a final ReLU.
    """

    def __init__(self, in_channel:int, out_channel:int, stride:int = 1, dilation:int=1, device = 'cpu') -> None:
        super().__init__()
        shared = dict(bias=False, device=device)
        self.conv1 = nn.Conv2d(in_channel, out_channel, 3, stride=stride,
                               padding=dilation, dilation=dilation, **shared)
        self.bn1 = nn.BatchNorm2d(out_channel, device=device)
        self.conv2 = nn.Conv2d(out_channel, out_channel, 3, padding='same', **shared)
        self.bn2 = nn.BatchNorm2d(out_channel, device=device)
        self.downsample = nn.Conv2d(in_channel, out_channel, 1, stride=stride, **shared)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``ReLU(BN2(conv2(ReLU(BN1(conv1(x))))) + downsample(x))``."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        merged = self.bn2(self.conv2(hidden)) + self.downsample(x)
        return self.relu(merged)
# generator model
#-------------------------------structure of generator---------------------------------------
#preprocess step(linear, relu)   input:batchsize*input_shape    output:batchsize*channel*n*n
#111111111111111111111111111111111111111111111111111111111111111
#1Tresblock1*n inchannel  outchannel if kernelsize=2
#1  convT2d
#1   batchnorm
#1   ReLU
#1   out + upsampling(x)
#1 Tresblock2*n inchannel outchannel if kernelsize=3    <the original form in paper but larger
#1   convT2d
#1   batchnorm
#1   ReLU
#1   convT2d
#1   batchnorm + upsampling(x)
#1   ReLU
#1111111111111111111111111111111111111111111111111111111111111111
#output layer inchannel   3*figsize*figsize
#   convT2d
#   Tanh
#--------------------------------------------------------------------------------------------
#Caution: output is in range [-1,1]; it should be converted linearly to [0,1] or [0,255]
class Generator(torch.nn.Module):
    """Map a latent vector to a 3-channel image with values in [-1, 1].

    A linear layer plus ReLU expands the latent vector into a feature map with
    ``2 * blocklist[0]['out_channel']`` channels. One upsampling stage is then
    built per entry of ``blocklist``, followed by a final stride-2 transposed
    convolution to 3 channels and a ``Tanh``.

    Args:
        input_shape: length of the latent vector.
        blocklist: sequence of dicts with the keys ``'kernel'``,
            ``'out_channel'`` and ``'stride'``, one per upsampling stage.
        figsize: side length of the generated image. The starting feature-map
            size is ``figsize`` integer-divided by every stage's ``'stride'``
            and by 2 for the output layer.
        device: device on which every parameter is created.
        Simple: if true, each stage is a plain
            ``ConvTranspose2d -> BatchNorm2d -> ReLU`` stack using the stage's
            own kernel and stride; otherwise a ``Tresblock1`` (kernel 2) or
            ``Tresblock2`` (kernel 3) is used.

    Raises:
        ValueError: if ``Simple`` is false and a stage requests a kernel size
            other than 2 or 3.
    """

    def __init__(self, input_shape:int, blocklist, figsize:int, device='cpu', Simple = False) -> None:
        super(Generator, self).__init__()
        self.device = device
        in_channel = blocklist[0]['out_channel']*2
        self.input_channel = in_channel
        Layers = []
        self.input_size = figsize
        for layer in blocklist:
            if Simple:
                # Created on ``device`` like every other layer; without it these
                # modules stay on the CPU while the rest of the model does not.
                Layers.extend([
                    nn.ConvTranspose2d(in_channel, layer['out_channel'], layer['kernel'], layer['stride'],
                                       padding=1, bias=False, device=device),
                    nn.BatchNorm2d(layer['out_channel'], device=device),
                    nn.ReLU(True),
                ])
            elif layer['kernel']==2:
                # NOTE: the residual blocks always upsample by 2, whatever
                # layer['stride'] says; the size bookkeeping below uses
                # layer['stride'], so the two only agree when it is 2.
                Layers.append(Tresblock1(in_channel, layer['out_channel'], kernel_size=2, stride=2, device=device))
            elif layer['kernel']==3:
                Layers.append(Tresblock2(in_channel, layer['out_channel'], kernel_size=3, stride=2, device=device))
            else:
                # Skipping the stage silently would still advance in_channel and
                # input_size and only fail later with a shape error in forward.
                raise ValueError(
                    f"unsupported kernel size {layer['kernel']!r} for a residual stage; expected 2 or 3"
                )

            self.input_size = self.input_size//layer['stride']
            in_channel = layer['out_channel']
        # Output layer: upsample by 2 once more and squash to RGB in [-1, 1].
        Layers.extend([
            nn.ConvTranspose2d(in_channel, 3, kernel_size=2, stride=2, device=device),
            nn.Tanh()
        ])
        self.input_size = self.input_size//2
        self.preprocess = nn.Sequential(
            nn.Linear(input_shape, self.input_channel*self.input_size**2, device=device),
            nn.ReLU()
        )
        self.Seq = nn.Sequential(*Layers)
        self.init_parameters()

    def forward(self, x) -> torch.Tensor:
        """Generate images from latent vectors ``x`` of shape (batch, input_shape)."""
        res_in = self.preprocess(x).view(-1, self.input_channel, self.input_size, self.input_size)
        out = self.Seq(res_in)
        return out

    def init_parameters(self):
        """Xavier-normal init for linear/conv weights, zeros for their biases.

        Returns 0 (kept for compatibility with existing callers).
        """
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
        return 0
# discriminator model
#-------------------------------structure of discriminator---------------------------------------
#receives a batchsize*3*input_size*input_size tensor as input (generator output; images should be normalized to [0,1] or N(0.5,0.5))
#if use_res==True
#   resblock*n inchannel(first is 3)  outchannel
#       conv2d(no bias)
#       ReLU
#       conv2d(no bias)
#       out + downsampling(x)
#       ReLU
#else
#   [conv2d
#   LeakyReLU]*n
#
#output layer
#   linear
#--------------------------------------------------------------------------------------------
class Discriminator(torch.nn.Module):
    """Critic mapping a 3-channel image to a single unbounded score.

    Args:
        input_size: side length of the (square) input images.
        blocklist: sequence of dicts with the keys ``'kernel'``,
            ``'out_channel'`` and ``'stride'``, one per downsampling stage.
        device: device on which every parameter is created.
        use_res: selects the stage type.
            ``False`` - ``Conv2d(padding=1) -> LeakyReLU`` per stage;
            ``True``  - one ``resblock`` per stage;
            ``'bn'``  - a 3x3 stem convolution + ReLU, then one
            ``resblock_bn`` per stage.
            Any other value builds no convolutional stage at all.

    The feature map is flattened and fed to one linear layer with one output.
    ``figsize`` tracks the spatial size by integer-dividing ``input_size`` by
    each stage's stride.
    """

    def __init__(self, input_size:int, blocklist, device = 'cpu', use_res = False) -> None:
        super(Discriminator, self).__init__()
        self.device = device
        in_channel = 3
        self.figsize = input_size
        Layers = []
        # Every branch reads the ``blocklist`` argument; relying on a global
        # of a similar name would ignore what the caller passed in.
        if use_res==False:
            for layer in blocklist:
                Layers.extend([
                    nn.Conv2d(in_channel, layer['out_channel'], layer['kernel'],stride=layer['stride'], padding=1, device=device),
                    nn.LeakyReLU()
                ])
                in_channel = layer['out_channel']
                self.figsize = self.figsize//layer['stride']
        elif use_res==True:
            for layer in blocklist:
                Layers.append(resblock(in_channel, layer['out_channel'], stride=layer['stride'], device=device))
                in_channel = layer['out_channel']
                self.figsize = self.figsize//layer['stride']
        elif use_res=='bn':
            # Stem: lift the RGB input to the first stage's width.
            Layers.append(nn.Conv2d(in_channel, blocklist[0]['out_channel'], 3, padding='same', bias=False, device=device))
            Layers.append(nn.ReLU())
            in_channel = blocklist[0]['out_channel']
            for layer in blocklist:
                Layers.append(resblock_bn(in_channel, layer['out_channel'], stride=layer['stride'], device=device))
                in_channel = layer['out_channel']
                self.figsize = self.figsize//layer['stride']
        self.seq = nn.Sequential(*Layers)
        self.outlinear = nn.Linear(in_features=self.figsize**2*in_channel, out_features=1,device=device)
        self.outchannel = in_channel
        self.init_parameters()

    def forward(self, x) -> torch.Tensor:
        """Return one score per image, shape (batch, 1)."""
        res_out = self.seq(x)
        # reshape (rather than view) also accepts non-contiguous feature maps.
        out = self.outlinear(res_out.reshape(-1, self.figsize**2*self.outchannel))
        return(out)

    def init_parameters(self):
        """Xavier-normal init for linear/conv weights; linear biases are zeroed.

        Returns 0 (kept for compatibility with existing callers).
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
        return 0


def get_param_num(model:nn.Module):
    """Count the parameters of ``model``.

    Returns a dict with the number of scalar parameters under ``'Total'`` and
    the number of those with ``requires_grad`` set under ``'Trainable'``.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return {'Total': total, 'Trainable': trainable}
# using wgan-div to train discriminator with k and p
def train_discrminator(model:nn.Module, real_data:torch.Tensor, fake_data:torch.Tensor, optimizer:torch.optim.Optimizer, k = 2, p = 6, device = 'cpu'):
    """Run one WGAN-div critic update and return the gradient penalty.

    The minimised loss is
    ``mean(D(fake)) - mean(D(real)) + mean(k * ||grad_x D(x_mix)||_2 ** p)``
    where ``x_mix`` interpolates each real/fake pair with its own uniform
    random coefficient.

    Args:
        model: the critic; left in training mode on return.
        real_data: batch of real samples.
        fake_data: batch of generated samples, same shape as ``real_data``.
            It is detached here, so no gradient reaches the generator.
        optimizer: optimizer holding the critic's parameters.
        k: weight of the gradient penalty.
        p: power applied to the gradient norm.
        device: unused; kept for backward compatibility. The mixing
            coefficients are created on ``real_data``'s device.

    Returns:
        The gradient-penalty term as a detached 0-dim tensor.
    """
    model.train()
    # The critic step only updates the critic; cutting the graph here avoids a
    # wasted backward pass through whatever produced the fakes.
    fake_data = fake_data.detach()
    score_loss = model(fake_data).mean()-model(real_data).mean()
    # One mixing coefficient per sample, broadcast over the remaining axes.
    mix_shape = (real_data.size(0),) + (1,) * (real_data.dim() - 1)
    mixconst = torch.rand(mix_shape, device=real_data.device, dtype=real_data.dtype)
    x_mix = mixconst * real_data + (1 - mixconst) * fake_data
    x_mix.requires_grad_()
    # The penalty is evaluated in eval mode, then training mode is restored.
    model.eval()
    # create_graph=True keeps the penalty differentiable w.r.t. the weights.
    grad_mix = torch.autograd.grad(model(x_mix).sum(), x_mix, create_graph=True)[0]
    model.train()
    # Squared L2 norm per sample, raised to p/2, i.e. ||grad||_2 ** p.
    grad_sq_norm = grad_mix.pow(2).flatten(1).sum(dim=1)
    gradient_loss = (k * (grad_sq_norm ** (p / 2))).mean()
    loss = score_loss + gradient_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return gradient_loss.detach()
#train generator using discriminator
def train_generator(g_model:Generator, d_model:Discriminator, randinput:torch.Tensor, optimizer:torch.optim.Optimizer):
    """Run one generator update that maximises the critic's mean score.

    The generator is put in training mode and the critic in eval mode; the
    loss is ``-mean(d_model(g_model(randinput)))``. Nothing is returned.
    """
    g_model.train()
    d_model.eval()
    # Clear stale gradients on the module and through the optimizer.
    g_model.zero_grad()
    optimizer.zero_grad()
    generated = g_model(randinput)
    critic_score = d_model(generated)
    loss = -critic_score.mean()
    loss.backward()
    optimizer.step()


In [3]:
import multiprocessing
import matplotlib.pyplot as plt
import numpy as np
import time

def plotsamp(samp, figname):
    """Show a 3-D array as an image and optionally save the figure.

    ``samp`` has its first axis moved to the end before plotting (the callers
    in this file pass a channel-first image). The figure is written to
    ``figname`` when that is truthy, then closed, and ``done`` is printed.
    """
    figure = plt.figure()
    axes = figure.add_subplot()
    # imshow wants the colour axis last.
    axes.imshow(np.transpose(samp, (1, 2, 0)))
    if figname:
        figure.savefig(figname)
    plt.close()
    print('done')

In [14]:
import os  # imported here so the cell also runs in a top-to-bottom execution

# Upsampling stages of the generator (one dict per stage).
gene_blocklist = [
    {'kernel':3, 'out_channel':256, 'stride':2},
    {'kernel':3, 'out_channel':128, 'stride':2},
    {'kernel':3, 'out_channel':64,  'stride':2},
]
# Downsampling stages of the discriminator.
block_list = [
    {'kernel':3, 'out_channel':128, 'stride':2},
    {'kernel':3, 'out_channel':256, 'stride':2},
    {'kernel':3, 'out_channel':256, 'stride':2},
]
device = 'cuda'
batchsize = 64
maxiter = 1e5
use_tensorboard = False
load = False
iteration = 0
checkpoint_folder = 'LSUN_CHURCH/checkpoint'
try:
    os.makedirs(os.path.join('./',checkpoint_folder))
except FileExistsError:
    # Only an already-existing folder is expected; any other failure
    # (permissions, missing import, ...) must surface instead of being hidden.
    print('checkpoint_folder already exist')

checkpoint_folder already exist


In [17]:
#-----------------------------------models and optimizers-----------------------------------------------
# Generator: 128-d latent vector -> 64x64 image, built from gene_blocklist.
# (alternative kept for reference: gene64 = wgan_gp.Generator().to(device))
gene64 = Generator(input_shape=128, blocklist=gene_blocklist, figsize=64, device=device)
# Critic on 64x64 inputs using the BN-free residual blocks (use_res=True).
# (alternative kept for reference: dis64 = wgan_gp.Discriminator().to(device))
dis64 = Discriminator(input_size=64, blocklist=block_list, device=device, use_res=True)
# One Adam optimizer per network, both with learning rate 2e-4.
optim_g = torch.optim.Adam(gene64.parameters(), lr=2e-4)
optim_d = torch.optim.Adam(dis64.parameters(), lr=2e-4)

#--------------------------------------optional resume / logging----------------------------------------------
if load:
    # Resume both networks from the checkpoints written at iteration 4000.
    gene_state = torch.load('./LSUN_CHURCH/checkpoint/gene_checkpoint_4000.pth')
    gene64.load_state_dict(gene_state)
    dis_state = torch.load('./LSUN_CHURCH/checkpoint/dis_checkpoint_4000.pth')
    dis64.load_state_dict(dis_state)
if use_tensorboard:
    # One run directory per launch, named after the start time.
    run_stamp = datetime.datetime.now().strftime("%Y%m%d_%H_%M_%S")
    writer = SummaryWriter('./runs/exp2-LSUN/train_{}'.format(run_stamp))
    # Trace the discriminator on one generated image to show its graph.
    writer.add_graph(dis64, gene64(torch.randn(1,128,device=device)))

In [18]:
# Draw one latent vector and render what the current generator makes of it.
sampin = torch.randn(1, 128, device=device)
# The generator ends in Tanh, so shift [-1, 1] to [0, 1] before plotting.
fake_fig = gene64(sampin)*0.5 + 0.5
sample_image = fake_fig.detach().cpu().squeeze().numpy()
plotsamp(sample_image, 'out.png')

done


In [9]:
import torch
import torch.nn as nn
import numpy as np
import datetime, os
import torchvision
import torchvision.transforms as T
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image
import multiprocessing
import matplotlib.pyplot as plt
import numpy as np
# Per-channel normalisation with mean 0.5 and std 0.5, wrapped in a Compose.
transform = transforms.Compose(
    [transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]
)
# NPY data format
class MyDataset(torch.utils.data.Dataset):
    """Dataset backed by a single ``.npy`` array held entirely on ``device``.

    The array is loaded once, its last axis is moved to position 1
    (presumably (N, H, W, C) -> (N, C, H, W); confirm against the data file)
    and the whole tensor is moved to the module-level ``device``.
    """

    def __init__(self, data):
        # Load the npy file and reorder its axes to (0, 3, 1, 2).
        array = np.transpose(np.load(data), (0, 3, 1, 2))
        self.data = torch.from_numpy(array).to(device)
        # Stored for reference; __getitem__ does its own rescaling.
        self.transforms = transform

    def __getitem__(self, index):
        # Pick one sample and rescale it with x / 128 - 1.
        sample = self.data[index, ...]
        scaled = sample / 128 - 1
        # Return the data together with a dummy label.
        return scaled, 0

    def __len__(self):
        # Total number of samples.
        return self.data.size(0)

#-------------------------------------load and transform dataset------------------------------------------
# Bound to ``full_dataset`` rather than ``set`` so the builtin ``set`` type
# stays usable in the rest of the notebook.
full_dataset = MyDataset('./data/LSUN/church_outdoor_train_lmdb_color_64.npy')
# 80 % of the samples for training, the remainder for evaluation.
n_train = len(full_dataset)//5*4
train_set, test_set = torch.utils.data.random_split(full_dataset, [n_train, len(full_dataset)-n_train])
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batchsize, shuffle=True, drop_last=True)
# Evaluation averages over len(test_set), so every test sample must be seen:
# no dropped partial batch, and no need to shuffle.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batchsize, shuffle=False, drop_last=False)

In [19]:
def _append_csv_row(path, values):
    """Append ``values`` as one comma-separated line to the text file ``path``."""
    with open(path, mode='a') as f:
        # file.write takes exactly one string, so the newline is concatenated.
        f.write(','.join(map(str, values)) + '\n')


# Main training loop: one generator step for every five critic steps.
# ``iteration`` counts generator steps.
while iteration < maxiter:
    for i_imag, image in enumerate(train_loader):
        real_data = image[0].to(device)
        randin = torch.randn(batchsize, 128, device=device)
        gradloss = train_discrminator(dis64, real_data, gene64(randin), optim_d, device=device)
        if i_imag%5 != 0:
            continue
        iteration += 1
        train_generator(gene64, dis64, randin, optim_g)

        # Every 100 generator steps: evaluate, log, sample, save "last".
        if iteration%100 == 99:
            gene64.eval()
            dis64.eval()
            with torch.no_grad():
                # d_score is the negated mean critic output on the test set.
                d_score_sum = 0
                n_scored = 0
                for test_batch in test_loader:
                    real_imag = test_batch[0].to(device)
                    d_score_sum -= dis64(real_imag).sum().item()
                    n_scored += real_imag.size(0)
                # Divide by the samples actually scored, so the mean stays
                # right even if the loader drops a partial batch.
                d_score = d_score_sum/max(n_scored, 1)
                if use_tensorboard:
                    print(f'iter:{iteration+1}', end='\r')
                    sampin = torch.randn(5, 128, device=device)
                    fake_fig = gene64(sampin)*0.5 + 0.5 #transform back to standard rgb
                    writer.add_images('gene', torch.concat([fake_fig, real_data[0,...].unsqueeze(0)*0.5+0.5], dim=0), global_step=iteration)
                    writer.add_scalar('d_score', d_score, iteration)
                else:
                    print(f'iter:{iteration+1}  d_score:{d_score}')
                    g_score = dis64(gene64(torch.randn(batchsize*10, 128, device=device))).mean().item()
                    _append_csv_row('g_score.txt', [iteration+1, g_score])
                    _append_csv_row('d_score.txt', [iteration+1, d_score])
                    sampin = torch.randn(1, 128, device=device)
                    fake_fig = gene64(sampin)*0.5 + 0.5 #transform back to standard rgb
                    plotsamp(fake_fig.cpu().squeeze().detach().numpy(), f'{iteration+1}_gene.png')
            # The next critic steps call gene64(randin) before train_generator
            # runs again, so the generator must be back in training mode.
            gene64.train()
            torch.save(gene64.state_dict(), os.path.join(os.getcwd(), checkpoint_folder, 'gene_checkpoint_last.pth'))
            torch.save(dis64.state_dict(), os.path.join(os.getcwd(), checkpoint_folder, 'dis_checkpoint_last.pth'))

        # Every 1000 generator steps: keep a numbered checkpoint.
        if iteration%1000 == 1000-1:
            torch.save(gene64.state_dict(), os.path.join(os.getcwd(), checkpoint_folder, f'gene_checkpoint_{iteration+1}.pth'))
            torch.save(dis64.state_dict(), os.path.join(os.getcwd(), checkpoint_folder, f'dis_checkpoint_{iteration+1}.pth'))

        # Stop as soon as the budget is used up instead of finishing the epoch.
        if iteration >= maxiter:
            break