In [0]:
import torch
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import (Dataset,DataLoader,TensorDataset)
from torchvision import models
from torch import nn,optim
import tqdm
from torchvision.datasets import FashionMNIST
from torchvision.datasets import MNIST


In [0]:
# Fetch the Oxford 102 flowers archive.
# -nc (no-clobber) skips the download when 102flowers.tgz already exists, so
# re-running this cell does not create a duplicate 102flowers.tgz.1 copy.
!wget -nc http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz

In [0]:
# Unpack the archive, then move every image into oxford=102/jpg, i.e. a
# sub-directory of the folder that the ImageFolder call in the next cell reads.
!tar xf 102flowers.tgz
# -p creates oxford=102 and oxford=102/jpg in one call and does not fail when
# the directories already exist (plain mkdir errors out on a re-run).
!mkdir -p oxford=102/jpg
!mv jpg/*.jpg oxford=102/jpg

In [0]:
# Flower images: resize to 80, centre-crop to 64, convert to tensors.
img_data=ImageFolder(
    "oxford=102/",
    transform=transforms.Compose([
        transforms.Resize(80),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
    ]),
)
batch_size=64
# Shuffled mini-batches of `batch_size` images for training.
img_loader=DataLoader(
    img_data,
    shuffle=True,
    batch_size=batch_size,
)

以下の2セルは、人の顔画像（LFWデータセット）を用いる場合に実行する

In [0]:
# Fetch and unpack the LFW (deep-funneled) face archive.
# -nc skips the download when lfw-deepfunneled.tgz is already present.
!wget -nc http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz
!tar xf lfw-deepfunneled.tgz
# -p makes both split directories in one call and tolerates them already existing.
!mkdir -p lfw-deepfunneled/train lfw-deepfunneled/test
# Split by first letter of the entry name: A-W go to train, X-Z go to test.
!mv lfw-deepfunneled/[A-W]* lfw-deepfunneled/train
!mv lfw-deepfunneled/[X-Z]* lfw-deepfunneled/test

In [0]:
# Face images from the train split: resize to 80, centre-crop to 64, to tensors.
img_data=ImageFolder(
    "lfw-deepfunneled/train",
    transform=transforms.Compose([
        transforms.Resize(80),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
    ]),
)
batch_size=64
# Shuffled mini-batches of `batch_size` images for training.
img_loader=DataLoader(
    img_data,
    shuffle=True,
    batch_size=batch_size,
)

以下の1セルは、MNISTを用いる場合に実行する

In [3]:
# MNIST training split, downloaded under /MNIST on first use.
# NOTE(review): MNIST digits are presumably smaller than 64x64, in which case
# CenterCrop(64) pads the image instead of cropping it -- confirm this is
# intended (as opposed to a Resize).
img_data=MNIST(
    "/MNIST",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.CenterCrop(64),
        transforms.ToTensor(),
    ]),
)
batch_size=128
# Shuffled mini-batches of `batch_size` images for training.
img_loader=DataLoader(
    img_data,
    shuffle=True,
    batch_size=batch_size,
)

  0%|          | 16384/9912422 [00:00<01:33, 105743.58it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /MNIST/MNIST/raw/train-images-idx3-ubyte.gz


9920512it [00:00, 21175261.61it/s]                           


Extracting /MNIST/MNIST/raw/train-images-idx3-ubyte.gz


32768it [00:00, 302427.21it/s]                           
0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /MNIST/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /MNIST/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz


1654784it [00:00, 4724544.13it/s]                           
8192it [00:00, 118316.62it/s]


Extracting /MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!


ここからはモデルの定義

In [0]:
# Hyper-parameters read by GNet / DNet when they are instantiated.
nz = 100     # channels of the latent vector z fed to the generator
ngf = 32     # base width (feature maps) of the generator
in_dim = 1   # image channels produced by GNet and consumed by DNet
             # NOTE(review): presumably needs to be 3 for the RGB ImageFolder
             # datasets loaded above -- confirm before training on them.

class GNet(nn.Module):
  """DCGAN generator.

  Maps a latent tensor of shape (batch, nz, 1, 1) to an image of shape
  (batch, in_dim, 64, 64): the first transposed convolution expands 1x1 to
  4x4 and each of the four following ones doubles the spatial size. The
  final Tanh bounds the output to [-1, 1].

  Reads the module-level `nz`, `ngf` and `in_dim` at construction time.
  """

  def __init__(self):
    super().__init__()
    # Channel widths of the hidden stages: nz -> 8*ngf -> 4*ngf -> 2*ngf -> ngf.
    widths=[nz, ngf*8, ngf*4, ngf*2, ngf]
    layers=[]
    for stage,(c_in,c_out) in enumerate(zip(widths[:-1],widths[1:])):
      # Stage 0 projects the 1x1 latent to 4x4 (stride 1, no padding);
      # every later stage upsamples by 2 (stride 2, padding 1).
      if stage==0:
        stride,pad=1,0
      else:
        stride,pad=2,1
      layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False))
      layers.append(nn.BatchNorm2d(c_out))
      layers.append(nn.ReLU(inplace=True))
    # Output stage: last 2x upsampling down to the image channel count.
    layers.append(nn.ConvTranspose2d(ngf, in_dim, 4, 2, 1, bias=False))
    layers.append(nn.Tanh())
    # Same layer order/indices as a hand-written Sequential, so state_dict
    # keys ("main.0.weight", ...) are unchanged.
    self.main=nn.Sequential(*layers)

  def forward(self,x):
    """Generate images from the latent batch `x`."""
    return self.main(x)
    
ndf = 32     # base width (feature maps) of the discriminator DNet

class DNet(nn.Module):
  """DCGAN discriminator.

  Four stride-2 convolutions halve the spatial size each time and a final
  4x4 convolution reduces a 64x64 input to a single logit per image. No
  sigmoid is applied: the output is a raw logit, as used with
  nn.BCEWithLogitsLoss in this notebook.

  Reads the module-level `in_dim` and `ndf` at construction time.
  """

  def __init__(self):
    super().__init__()
    self.main=nn.Sequential(
        nn.Conv2d(in_dim, ndf, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf*2),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf*4),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf*8),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(ndf*8, 1, 4, 1, 0, bias=False),
    )

  def forward(self,x):
    """Return one logit per input image, shape (batch,).

    A bare `squeeze()` would also drop the batch dimension when the batch
    holds a single image, yielding a 0-dim tensor that no longer matches a
    length-1 target in the loss. Flattening per sample first keeps the batch
    axis, so a batch of one gives shape (1,).
    """
    out=self.main(x)
    return out.reshape(out.size(0),-1).squeeze(1)
        
# One device object instead of five "cuda:0" literals; falls back to the CPU
# when CUDA is not available so the cell still runs there.
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

d=DNet().to(device)
g=GNet().to(device)

# Adam with lr=2e-4 and beta1=0.5 for both networks.
opt_d=optim.Adam(d.parameters(),lr=0.0002,betas=(0.5,0.999))
opt_g=optim.Adam(g.parameters(),lr=0.0002,betas=(0.5,0.999))

# Target buffers ("real"=1, "fake"=0), one entry per sample of a full batch.
ones=torch.ones(batch_size).to(device)
zeros=torch.zeros(batch_size).to(device)
# Expects raw logits, which is what DNet returns (no sigmoid layer).
loss_f=nn.BCEWithLogitsLoss()

# Latent batch drawn once and reused, so the preview images saved by the
# training loop come from the same z every epoch.
fixed_z=torch.randn(batch_size,nz,1,1).to(device)

In [0]:
from statistics import mean

def train_dcgan(g, d, opt_g, opt_d, loader, device=None):
  """Train the generator and discriminator for one pass over `loader`.

  For every batch the generator is updated first (it is rewarded when the
  discriminator labels its output as real), then the discriminator is
  updated on the real batch (target 1) and on the same, detached, fake
  batch (target 0).

  Args:
    g: generator; called with a (batch, nz, 1, 1) latent tensor.
    d: discriminator; returns logits for a batch of images.
    opt_g: optimizer stepping the generator's parameters.
    opt_d: optimizer stepping the discriminator's parameters.
    loader: iterable yielding (images, labels) pairs; labels are ignored.
    device: where batches and latents are placed. Defaults to the device of
      the generator's parameters, so nothing is hard-wired to "cuda:0".

  Returns:
    (mean generator loss, mean discriminator loss) over the epoch.

  Raises:
    statistics.StatisticsError: if `loader` yields no batches.

  Uses the module-level `nz` (latent size) and `loss_f` (criterion).
  """
  if device is None:
    device=next(g.parameters()).device
  log_loss_g=[]
  log_loss_d=[]
  for real_img,_ in tqdm.tqdm(loader):
    batch_len=len(real_img)
    real_img=real_img.to(device)

    # ---- generator step ----
    z=torch.randn(batch_len,nz,1,1).to(device)
    fake_img=g(z)
    # Detached copy reused for the discriminator step below, so that step
    # does not backpropagate into the generator.
    fake_img_tensor=fake_img.detach()
    out=d(fake_img)
    # Targets are built from the logits themselves, so they always match
    # their shape/device and do not depend on globally pre-sized buffers.
    loss_g=loss_f(out, torch.ones_like(out))
    log_loss_g.append(loss_g.item())
    d.zero_grad()
    g.zero_grad()
    loss_g.backward()
    opt_g.step()

    # ---- discriminator step ----
    real_out=d(real_img)
    loss_d_real=loss_f(real_out, torch.ones_like(real_out))
    fake_out=d(fake_img_tensor)
    loss_d_fake=loss_f(fake_out, torch.zeros_like(fake_out))
    loss_d=loss_d_real+loss_d_fake
    log_loss_d.append(loss_d.item())
    # Clears the gradients that the generator step left on d.
    d.zero_grad()
    g.zero_grad()
    loss_d.backward()
    opt_d.step()
  return mean(log_loss_g), mean(log_loss_d)

以下のセルで学習を行う

In [0]:
from IPython.display import Image,display_jpeg
from torchvision.utils import save_image

# The separator before the trailing comment is a plain ASCII space; a
# full-width space (U+3000) outside a comment is a SyntaxError in Python 3.
for epoch in range(150):  # number of epochs; adjust as needed
  # train_dcgan returns the epoch's mean generator / discriminator losses.
  mean_loss_g,mean_loss_d=train_dcgan(g,d,opt_g,opt_d,img_loader)
  print("epoch {:03d}: loss_g={:.4f} loss_d={:.4f}".format(epoch,mean_loss_g,mean_loss_d))
  if epoch%1==0:  # always true, so this runs every epoch; raise the modulus to checkpoint less often
    # Checkpoint both networks' weights for this epoch.
    torch.save(
        g.state_dict(),
        "g_{:03d}.prm".format(epoch),
        pickle_protocol=4)
    torch.save(
        d.state_dict(),
        "d_{:03d}.prm".format(epoch),
        pickle_protocol=4)
    # Preview from the fixed latent batch; no gradients are needed for an
    # image dump, so skip building the autograd graph.
    with torch.no_grad():
      generated_img=g(fixed_z)
    sample_path="{:03d}.jpg".format(epoch)
    save_image(generated_img,sample_path)
    display_jpeg(Image(sample_path))