In [1]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms


# Number of images per mini-batch.
batch_size = 64

# Oxford-102 images: shorter side resized to 80 px, then the central 64x64
# patch is kept and converted to a float tensor.
# NOTE(review): ToTensor yields values in [0, 1] while GNet ends in Tanh
# ([-1, 1]) -- confirm whether a Normalize step is intended here.
img_data = ImageFolder(
    './oxford-102/',
    transform=transforms.Compose([
        transforms.Resize(80),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
    ]),
)

# Batches are reshuffled on every pass over the dataset.
img_loader = DataLoader(img_data, shuffle=True, batch_size=batch_size)

In [2]:
from torch import nn


nz = 100  # number of channels in the latent input
ngf = 32  # base channel width of the generator


class GNet(nn.Module):
    """DCGAN-style generator built from transposed convolutions.

    A latent tensor of shape ``(N, nz, 1, 1)`` is upsampled through four
    ConvTranspose2d -> BatchNorm2d -> ReLU stages and a final
    ConvTranspose2d -> Tanh stage, producing ``(N, 3, 64, 64)`` values in
    ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride, padding) of each hidden stage;
        # every stage uses a 4x4 kernel without bias.
        hidden_stages = (
            (nz, ngf * 8, 1, 0),       # 1x1   -> 4x4
            (ngf * 8, ngf * 4, 2, 1),  # 4x4   -> 8x8
            (ngf * 4, ngf * 2, 2, 1),  # 8x8   -> 16x16
            (ngf * 2, ngf, 2, 1),      # 16x16 -> 32x32
        )

        layers = []
        for c_in, c_out, stride, pad in hidden_stages:
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))

        # Output stage: 32x32 -> 64x64 RGB, squashed by tanh.
        layers.append(nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Map a latent batch ``x`` to a batch of generated images."""
        return self.main(x)

In [3]:
ndf = 32  # base channel width of the discriminator


class DNet(nn.Module):
    """DCGAN-style discriminator built from strided convolutions.

    Four stride-2 4x4 convolutions halve the spatial size each time
    (64 -> 32 -> 16 -> 8 -> 4 for a 64x64 input) and a final 4x4 convolution
    reduces the 4x4 map to a single value per image. No sigmoid is applied,
    so the outputs are raw logits.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # The first stage has no batch norm.
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 feature map -> one logit per image.
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)
        )

    def forward(self, x):
        """Return one logit per image.

        Args:
            x: image batch of shape ``(N, 3, H, W)``; 64x64 inputs give an
                output of shape ``(N,)``.

        Returns:
            The output of the final convolution with its singleton channel
            and spatial axes removed. The batch axis is always kept.
        """
        out = self.main(x)
        # Squeeze only the channel/spatial axes. A bare ``squeeze()`` would
        # also drop the batch axis when N == 1, producing a 0-dim tensor that
        # no longer matches ``(N,)``-shaped targets.
        return out.squeeze(3).squeeze(2).squeeze(1)

In [4]:
import torch
from torch import optim
from torch.autograd import Variable as V


# Networks: the discriminator is constructed before the generator.
d = DNet()
g = GNet()

# Both networks are optimised by Adam with identical settings.
opt_d, opt_g = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (d, g)
)

# Binary cross-entropy computed on raw logits.
loss_f = nn.BCEWithLogitsLoss()

# Target tensors sized for one full batch ("real" = 1, "fake" = 0).
ones, zeros = (V(fill(batch_size)) for fill in (torch.ones, torch.zeros))

# Latent batch drawn once and kept, so the same inputs can be fed to the
# generator at different points of training.
fixed_z = V(torch.randn(batch_size, nz, 1, 1))

In [5]:
from statistics import mean


def train_dcgan(g, d, opt_g, opt_d, loader):
    """Run one pass over ``loader``, alternating generator and discriminator updates.

    For every batch the generator is updated first, by labelling its fakes
    as real. The discriminator is then updated on the real batch and on the
    same fake batch, detached from the generator graph.

    Relies on the module-level ``nz`` (latent channels), ``loss_f``
    (criterion) and ``tqdm`` (progress bar), all resolved at call time.

    Args:
        g: generator; called with noise of shape ``(batch, nz, 1, 1)``.
        d: discriminator; its output is passed directly to ``loss_f``.
        opt_g: optimizer stepping the generator parameters.
        opt_d: optimizer stepping the discriminator parameters.
        loader: iterable of ``(images, labels)`` batches; labels are ignored.

    Returns:
        ``(mean generator loss, mean discriminator loss)`` over all batches.

    Raises:
        statistics.StatisticsError: if ``loader`` yields no batches.
    """
    log_loss_g = []
    log_loss_d = []
    for real_img, _ in tqdm(loader):
        batch_len = len(real_img)

        # ---- generator step: make d label the fakes as real ----
        z = torch.randn(batch_len, nz, 1, 1)
        fake_img = g(z)
        out = d(fake_img)
        # Targets are shaped after the discriminator output, so any batch
        # size (including a short final batch) works.
        loss_g = loss_f(out, torch.ones_like(out))
        # ``.item()`` extracts the Python float; indexing a 0-dim loss with
        # ``[0]`` raises IndexError on current PyTorch.
        log_loss_g.append(loss_g.item())
        d.zero_grad()
        g.zero_grad()
        loss_g.backward()
        opt_g.step()

        # ---- discriminator step: real -> 1, fake -> 0 ----
        real_out = d(real_img)
        loss_d_real = loss_f(real_out, torch.ones_like(real_out))
        # Detach so this backward pass does not reach the generator.
        fake_out = d(fake_img.detach())
        loss_d_fake = loss_f(fake_out, torch.zeros_like(fake_out))
        loss_d = loss_d_real + loss_d_fake
        log_loss_d.append(loss_d.item())
        d.zero_grad()
        g.zero_grad()
        loss_d.backward()
        opt_d.step()
    return mean(log_loss_g), mean(log_loss_d)

In [6]:
from torchvision.utils import save_image
from tqdm import tqdm


import os

# Checkpoints and sample grids are written here. Create the directory up
# front so the first save cannot fail after a full epoch of training.
os.makedirs('./oxford-102-gen/', exist_ok=True)

for epoch in range(40):
    train_dcgan(g, d, opt_g, opt_d, img_loader)
    if epoch % 10 == 0:
        # Every 10th epoch (0, 10, 20, 30): checkpoint both networks.
        torch.save(
            g.state_dict(),
            './oxford-102-gen/g_{:03d}.prm'.format(epoch),
            pickle_protocol=4
        )
        # The discriminator gets its own "d_" file; writing it to the "g_"
        # path would overwrite the generator checkpoint saved just above.
        torch.save(
            d.state_dict(),
            './oxford-102-gen/d_{:03d}.prm'.format(epoch),
            pickle_protocol=4
        )
        # Render the fixed latent batch so that samples from different
        # epochs come from the same inputs.
        generated_img = g(fixed_z).data
        save_image(
            generated_img,
            './oxford-102-gen/{:03d}.jpg'.format(epoch)
        )

100%|██████████| 128/128 [02:52<00:00,  1.35s/it]
100%|██████████| 128/128 [02:54<00:00,  1.36s/it]
100%|██████████| 128/128 [02:52<00:00,  1.35s/it]
100%|██████████| 128/128 [02:51<00:00,  1.34s/it]
100%|██████████| 128/128 [02:51<00:00,  1.34s/it]
100%|██████████| 128/128 [02:52<00:00,  1.35s/it]
100%|██████████| 128/128 [02:51<00:00,  1.34s/it]
100%|██████████| 128/128 [02:53<00:00,  1.35s/it]
100%|██████████| 128/128 [02:52<00:00,  1.35s/it]
100%|██████████| 128/128 [02:57<00:00,  1.39s/it]
100%|██████████| 128/128 [02:53<00:00,  1.35s/it]
100%|██████████| 128/128 [02:55<00:00,  1.37s/it]
100%|██████████| 128/128 [02:54<00:00,  1.36s/it]
100%|██████████| 128/128 [02:55<00:00,  1.37s/it]
100%|██████████| 128/128 [02:55<00:00,  1.37s/it]
100%|██████████| 128/128 [02:55<00:00,  1.37s/it]
100%|██████████| 128/128 [03:02<00:00,  1.42s/it]
100%|██████████| 128/128 [02:58<00:00,  1.39s/it]
100%|██████████| 128/128 [02:57<00:00,  1.39s/it]
100%|██████████| 128/128 [02:56<00:00,  1.38s/it]
