In [0]:
import torch
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import (Dataset,DataLoader,TensorDataset)
from torchvision import models
from torch import nn,optim
import tqdm
from IPython.display import Image,display_jpeg
from torchvision.utils import save_image
from statistics import mean


学習データのダウンロード

In [0]:
# Download the LFW "deep funneled" face archive and unpack it in the working directory.
!wget http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz
!tar xf lfw-deepfunneled.tgz
# Gather every entry whose name starts with an upper-case letter under
# lfw-deepfunneled/train, the directory that the next cell hands to ImageFolder as its root.
!mkdir lfw-deepfunneled/train
!mv lfw-deepfunneled/[A-Z]* lfw-deepfunneled/train

In [0]:
# Preprocessing for the face images: resize to 80, take the central 64x64 crop,
# then convert the PIL image to a tensor.
lfw_transform=transforms.Compose([
    transforms.Resize(80),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
])
img_data=ImageFolder("lfw-deepfunneled/train",transform=lfw_transform)

# Shuffled mini-batches; batch_size is also read by later cells.
batch_size=32
img_loader=DataLoader(dataset=img_data,batch_size=batch_size,shuffle=True)

以下の1セルはMNISTを使う場合用です（MNISTはグレースケールなので、モデル定義セルの in_dim を 1 に変更してください）

In [0]:
from torchvision.datasets import MNIST

# Alternative dataset: MNIST training split, downloaded to /MNIST if missing.
# NOTE(review): MNIST digits are smaller than 64 px, so CenterCrop(64) is
# presumably relied on to pad them up to 64x64 — confirm with the installed
# torchvision version.
mnist_transform=transforms.Compose([
    transforms.CenterCrop(64),
    transforms.ToTensor(),
])
img_data=MNIST("/MNIST",train=True,download=True,transform=mnist_transform)

# Shuffled mini-batches; batch_size is also read by later cells.
batch_size=32
img_loader=DataLoader(dataset=img_data,batch_size=batch_size,shuffle=True)

以下の2セルは花の画像（Oxford 102 Flowers データセット）を用いる場合用です

In [0]:
# Download the Oxford 102 flowers archive and unpack it in the working directory.
!wget http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
!tar xf 102flowers.tgz
# Move the extracted jpg/*.jpg files into oxford=102/jpg so that the images sit
# one directory below the root that the next cell passes to ImageFolder.
!mkdir oxford=102
!mkdir oxford=102/jpg
!mv jpg/*.jpg oxford=102/jpg

In [0]:
# Preprocessing for the flower images: resize to 80, take the central 64x64
# crop, then convert the PIL image to a tensor.
flower_transform=transforms.Compose([
    transforms.Resize(80),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
])
img_data=ImageFolder("oxford=102/",transform=flower_transform)

# Shuffled mini-batches; batch_size is also read by later cells.
batch_size=32
img_loader=DataLoader(dataset=img_data,batch_size=batch_size,shuffle=True)

ここからはモデルの定義

In [0]:
#@title Default title text
# Number of channels of the latent tensor: the generator consumes (N, nz, 1, 1)
# noise and the autoencoder's bottleneck convolution outputs nz channels.
nz=100
# Base channel width used by the convolution layers of both networks.
ngf=32
in_dim=3 # number of image channels (3 for colour images, 1 for grayscale)

class GNet(nn.Module):
  """Generator network.

  A transposed convolution with an 8x8 kernel first expands the latent input,
  followed by four stages of two (3x3 conv -> batch norm -> ELU) units each.
  The first three stages end with a nearest-neighbour x2 upsampling, and a
  final 3x3 convolution maps the ``ngf`` feature maps to ``in_dim`` channels.
  For a latent input of shape (N, nz, 1, 1) this yields (N, in_dim, 64, 64).

  The sizes ``nz``, ``ngf`` and ``in_dim`` are read from module scope when the
  network is instantiated. All layers live in ``self.main`` in a fixed order,
  so parameter names in the state dict depend on that order.
  """

  @staticmethod
  def _conv_unit(channels):
    """Return [3x3 size-preserving conv, batch norm, ELU] for `channels` maps."""
    return [
        nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(channels),
        nn.ELU(inplace=True),
    ]

  def __init__(self):
    super().__init__()

    # Latent projection: kernel 8, stride 1, no padding.
    layers=[nn.ConvTranspose2d(nz, ngf, kernel_size=8, stride=1, padding=0, bias=False)]

    num_stages=4
    for stage in range(num_stages):
      layers.extend(self._conv_unit(ngf))
      layers.extend(self._conv_unit(ngf))
      # Every stage except the last doubles the spatial resolution.
      if stage<num_stages-1:
        layers.append(nn.Upsample(scale_factor=2, mode='nearest'))

    # Output convolution; no activation is applied afterwards.
    layers.append(nn.Conv2d(ngf, in_dim, kernel_size=3, stride=1, padding=1, bias=False))

    self.main=nn.Sequential(*layers)

  def forward(self,x):
    """Run `x` through the layer stack and return the generated images."""
    return self.main(x)

class CAENet(nn.Module):
  """Convolutional auto-encoder that reconstructs its input image.

  Encoder: an input 3x3 convolution, then three stages that each apply a
  width-preserving (conv -> batch norm -> ELU) unit followed by a stride-2
  unit that widens the features by ``ngf`` channels, then two more units at
  ``4*ngf`` channels and an 8x8 convolution down to ``nz`` channels.

  Decoder: the same layer layout as the generator — an 8x8 transposed
  convolution, four stages of two conv units (the first three followed by a
  nearest-neighbour x2 upsampling) and a final 3x3 convolution to ``in_dim``
  channels.

  ``nz``, ``ngf`` and ``in_dim`` are read from module scope at instantiation.
  Encoder and decoder are stored back to back in ``self.main``, so parameter
  names in the state dict depend on the layer order.
  """

  @staticmethod
  def _conv_unit(c_in, c_out, stride=1):
    """Return [3x3 conv (padding 1), batch norm, ELU] mapping c_in -> c_out maps."""
    return [
        nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(c_out),
        nn.ELU(inplace=True),
    ]

  def __init__(self):
    super().__init__()
    unit=self._conv_unit

    # ---- encoder ----
    encoder=[nn.Conv2d(in_dim, ngf, kernel_size=3, stride=1, padding=1, bias=False)]
    for width in (1, 2, 3):
      encoder.extend(unit(width*ngf, width*ngf))
      # Stride-2 unit: halves the resolution and adds ngf channels.
      encoder.extend(unit(width*ngf, (width+1)*ngf, stride=2))
    encoder.extend(unit(4*ngf, 4*ngf))
    encoder.extend(unit(4*ngf, 4*ngf))
    # Bottleneck: 8x8 kernel without padding, nz output channels.
    encoder.append(nn.Conv2d(4*ngf, nz, kernel_size=8, stride=1, padding=0, bias=False))

    # ---- decoder ----
    decoder=[nn.ConvTranspose2d(nz, ngf, kernel_size=8, stride=1, padding=0, bias=False)]
    num_stages=4
    for stage in range(num_stages):
      decoder.extend(unit(ngf, ngf))
      decoder.extend(unit(ngf, ngf))
      # Every stage except the last doubles the spatial resolution.
      if stage<num_stages-1:
        decoder.append(nn.Upsample(scale_factor=2, mode='nearest'))
    decoder.append(nn.Conv2d(ngf, in_dim, kernel_size=3, stride=1, padding=1, bias=False))

    self.main=nn.Sequential(*encoder, *decoder)

  def forward(self,x):
    """Encode and decode `x`, returning the reconstruction."""
    return self.main(x)
  
# Instantiate the auto-encoder and the generator and move both to the first CUDA device.
cae=CAENet().to("cuda:0")
g=GNet().to("cuda:0")



In [0]:
# Both networks are optimised with Adam using identical hyper-parameters.
adam_kwargs=dict(lr=0.0002,betas=(0.5,0.999))
opt_cae=optim.Adam(cae.parameters(),**adam_kwargs)
opt_g=optim.Adam(g.parameters(),**adam_kwargs)

# Constant all-ones / all-zeros vectors of length batch_size on the GPU.
# NOTE(review): nothing in the visible notebook reads these — confirm before removing.
ones=torch.ones(batch_size).to("cuda:0")
zeros=torch.zeros(batch_size).to("cuda:0")

# Reconstruction loss used by train_began (mean squared error).
loss_f=nn.MSELoss()

In [0]:
def train_began(g, cae, opt_g, opt_cae, loader,k=0.0,lamda=0.001,gamma=0.7):
  """Train the generator and the auto-encoder for one pass over `loader`.

  Per batch:
    1. Generator step: minimise the auto-encoder's reconstruction loss on
       freshly generated images.
    2. Auto-encoder step: minimise ``L_real - k * L_fake`` where ``L_real`` /
       ``L_fake`` are the reconstruction losses on real / (detached) generated
       images.
    3. Balance update: ``k <- clip(k + lamda * (gamma * L_real - L_fake), 0, 1)``.

  The latent size ``nz`` and the reconstruction loss ``loss_f`` are read from
  module scope.

  Args:
    g: generator module; its parameters' device is used for all tensors.
    cae: auto-encoder module (expected on the same device as ``g``).
    opt_g: optimizer over ``g``'s parameters.
    opt_cae: optimizer over ``cae``'s parameters.
    loader: iterable yielding ``(image_batch, label)`` pairs; labels are ignored.
    k: current balance term, as a Python float or a 0-dim tensor.
    lamda: step size of the balance update.
    gamma: target ratio between the fake and real reconstruction losses.

  Returns:
    The updated balance term as a Python float in [0, 1]. If `loader` yields
    no batches, `k` is returned unchanged (converted to float).

  Side effects:
    Updates both models through their optimizers and prints the convergence
    measure ``L_real + |gamma * L_real - L_fake|`` of the last batch.
  """
  # Follow the generator's device rather than assuming a particular GPU.
  device=next(g.parameters()).device
  # Keep k a plain float so it neither becomes a tensor nor pins GPU memory
  # between epochs; float() also accepts a 0-dim tensor from an earlier call.
  k=float(k)
  convergence=None

  for real_img,_ in tqdm.tqdm(loader):
    batch_len=len(real_img)
    real_img=real_img.to(device)

    # ---- generator step ----
    z=torch.randn(batch_len,nz,1,1).to(device)
    fake_img=g(z)
    loss_g=loss_f(cae(fake_img), fake_img)
    cae.zero_grad()
    g.zero_grad()
    loss_g.backward()
    opt_g.step()

    # ---- auto-encoder step ----
    # Detach so this step does not back-propagate into the generator.
    fake_detached=fake_img.detach()
    loss_cae_real=loss_f(cae(real_img), real_img)
    loss_cae_fake=loss_f(cae(fake_detached), fake_detached)
    loss_cae=loss_cae_real-k*loss_cae_fake
    # Also clears the gradients the generator step left on cae's parameters.
    cae.zero_grad()
    g.zero_grad()
    loss_cae.backward()
    opt_cae.step()

    # ---- balance update ----
    real_value=loss_cae_real.item()
    fake_value=loss_cae_fake.item()
    balance=gamma*real_value-fake_value
    # BEGAN defines k on [0, 1]; without the clip k can go negative and flip
    # the sign of the fake term in the auto-encoder loss.
    k=min(max(k+lamda*balance,0.0),1.0)
    convergence=real_value+abs(balance)

  # Only report when at least one batch was processed.
  if convergence is not None:
    print(convergence)
  return k



以下のセルで学習を行う

In [0]:
# Fixed latent batch reused every epoch, so the saved sample grids show the
# same latent points as training progresses.
fixed_z=torch.randn(batch_size,nz,1,1).to("cuda:0")
# Balance term k carried across epochs.
know=0.0
num_epochs=100 # number of training epochs; change as needed
save_every=1 # write checkpoints and a sample grid every this many epochs
for epoch in range(num_epochs):
  know=train_began(g,cae,opt_g,opt_cae,img_loader,know)
  if epoch%save_every==0:
    torch.save(
        g.state_dict(),
        "g_{:03d}.prm".format(epoch),
        pickle_protocol=4)
    torch.save(
        cae.state_dict(),
        "cae_{:03d}.prm".format(epoch),
        pickle_protocol=4)
    sample_path="{:03d}.jpg".format(epoch)
    # Sampling needs no gradients; skip building the autograd graph.
    with torch.no_grad():
      generated_img=g(fixed_z)
    save_image(generated_img,sample_path)
    display_jpeg(Image(sample_path))

以下のセルで画像を生成

In [0]:
from IPython.display import Image,display_jpeg
from torchvision.utils import save_image

# Number of images to generate
num_of_pic=10

z=torch.randn(num_of_pic,nz,1,1).to("cuda:0")

params=torch.load("BEGAN_face.prm") # load trained generator parameters (set this to your checkpoint path)
g.load_state_dict(params)

# Inference: use the loaded BatchNorm running statistics (eval mode) and skip
# autograd, so the result does not depend on the batch composition and the
# loaded statistics are not modified. The previous mode is restored afterwards
# so that a later training run is unaffected.
was_training=g.training
g.eval()
with torch.no_grad():
  generated_img=g(z)
g.train(was_training)

save_image(generated_img,"ganerated.jpg")
display_jpeg(Image("ganerated.jpg"))