In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import config
from model import Generator, Discriminator
from dataset import get_dataloader
from utils import weights_init, show_results

# Setup: resolve the target device and build the training dataloader.
# NOTE(review): the captured output of this notebook ends in
# FileNotFoundError for '/data/dc_gan/images'. That path is presumably the
# dataset root read by get_dataloader() — confirm the data path configured
# for dataset.py / config before re-running.
device = config.device
dataloader = get_dataloader()

# Model initialization: build both networks (config.ngpu is passed through —
# presumably the number of GPUs to use; see model.py), move them to `device`,
# then apply weights_init to every submodule (nn.Module.apply is recursive).
# Keep this statement order: module construction, weights_init (presumably)
# and the torch.randn call below all draw from the global torch RNG, so
# reordering would change the random stream and break reproducibility.
netG = Generator(config.ngpu).to(device)
netD = Discriminator(config.ngpu).to(device)
netG.apply(weights_init)
netD.apply(weights_init)

# Loss function and optimizers.
# nn.BCELoss expects probabilities as input, so the Discriminator presumably
# ends in a sigmoid — confirm in model.py.
criterion = nn.BCELoss()
# Independent Adam optimizers for D and G sharing config.lr and config.beta1;
# beta2 is fixed at 0.999 (Adam's default).
optimizerD = optim.Adam(netD.parameters(), lr=config.lr, betas=(config.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=config.lr, betas=(config.beta1, 0.999))

# Training-state variables.
# A fixed batch of 64 latent vectors, shape (64, config.nz, 1, 1), reused for
# every progress snapshot so generated images are comparable over training.
fixed_noise = torch.randn(64, config.nz, 1, 1, device=device)
# Image grids rendered from fixed_noise during training.
img_list = []
# Global iteration counter across all epochs (drives snapshot frequency).
iters = 0

print("Starting Training Loop...")
# Batches per epoch. This is loop-invariant, so compute it once instead of on
# every log line and every snapshot check.
num_batches = len(dataloader)
for epoch in range(config.num_epochs):
    for i, data in enumerate(dataloader, 0):

        # --- (1) Update D network ---
        # Train D to output config.real_label on real images and
        # config.fake_label on generated ones. Gradients from both halves
        # accumulate in netD before the single optimizerD.step() below.
        netD.zero_grad()
        # Assumes each batch is a sequence whose first element is the image
        # tensor (e.g. (images, targets)) — TODO confirm against dataset.py.
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        # One target tensor, reused for all three losses via in-place fill_.
        label = torch.full((b_size,), config.real_label, dtype=torch.float, device=device)

        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()

        noise = torch.randn(b_size, config.nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(config.fake_label)
        # detach() so D's loss does not backpropagate into netG.
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        optimizerD.step()

        # --- (2) Update G network ---
        # Re-score the same fakes with the just-updated D, but against the
        # real label, so G is pushed toward outputs D classifies as real.
        # `fake` is NOT detached here, so gradients flow through netD into
        # netG. The gradients this also leaves on netD's parameters are
        # cleared by netD.zero_grad() at the start of the next iteration.
        netG.zero_grad()
        label.fill_(config.real_label)
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

        # Logging. Convert losses to Python floats with .item() rather than
        # formatting tensors directly (which also built an extra
        # autograd-tracked tensor just to print the summed D loss).
        if i % 50 == 0:
            loss_d = errD_real.item() + errD_fake.item()
            print(f'[{epoch}/{config.num_epochs}][{i}/{num_batches}] Loss_D: {loss_d:.4f} Loss_G: {errG.item():.4f}')

        # Snapshot the generator's output on fixed_noise every 500 iterations
        # and after the very last batch of training.
        # NOTE(review): netG stays in train mode here, so any BatchNorm layers
        # use batch statistics (and update running stats) on this pass —
        # presumably intended; confirm.
        if (iters % 500 == 0) or ((epoch == config.num_epochs-1) and (i == num_batches-1)):
            with torch.no_grad():
                fake_display = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake_display, padding=2, normalize=True))

        iters += 1

# Display results: fetch one batch of real images and pass it, together with
# the generator snapshots collected in img_list, to show_results (utils).
try:
    real_batch = next(iter(dataloader))
except StopIteration:
    # An empty dataloader would otherwise surface as a bare, context-free
    # StopIteration at the top level of the script.
    raise RuntimeError(
        "dataloader yielded no batches; check the dataset path and contents"
    ) from None
show_results(real_batch, img_list, device)

FileNotFoundError: [Errno 2] No such file or directory: '/data/dc_gan/images'