Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error #2

Open
qingyuany opened this issue Oct 7, 2023 · 0 comments
Open

Error #2

qingyuany opened this issue Oct 7, 2023 · 0 comments

Comments

@qingyuany
Copy link

Hello. I tried using GAN evaluator on the CIFAR10 dataset. Encountered some errors.
1%| | 3/391 [00:16<35:33, 5.50s/it]
Traceback (most recent call last):
File "D:\python_parctice\pt\FID\FID_MODLE.py", line 204, in
evaluator.load_all_real_imgs(real_loader=dataloader, idx_in_loader=0)
File "D:\Anaconda\install\envs\pt\lib\site-packages\gan_evaluator.py", line 176, in load_all_real_imgs
self.fill_real_img_batch(real_batch)
File "D:\Anaconda\install\envs\pt\lib\site-packages\gan_evaluator.py", line 136, in fill_real_img_batch
self.activation_vec_real[self.vec_real_pointer:self.vec_real_pointer +
ValueError: could not broadcast input array from shape (128,2048) into shape (7,2048)
下面是我的代码:
from future import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
import matplotlib.animation as animation
from IPython.display import HTML

from attribute_hashmap import AttributeHashmap
from gan_evaluator import GAN_Evaluator
from log_utils import log
from seed import seed_everything

from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d

from PIL import Image

import matplotlib.pyplot as plt
import sys
import numpy as np
import os

print(os.listdir("../input"))

import time

SEED=42
random.seed(SEED)
torch.manual_seed(SEED)

Batch size during training

batch_size = 128

Spatial size of training images. All images will be resized to this size using a transformer.

image_size = 64

Number of channels in the training images. For color images this is 3

nc = 3

Size of z latent vector (i.e. size of generator input)

nz = 100

Size of feature maps in generator

ngf = 64

Size of feature maps in discriminator

ndf = 64

Number of training epochs

num_epochs = 70

different Learning rate for optimizers

g_lr = 0.0001
d_lr = 0.0004

Beta1 hyperparam for Adam optimizers

beta1 = 0.5
ngpu=1

#normalizing input between -1 and 1
transform=transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0,0,0), (1,1,1)),])

dataset = dset.CIFAR10(root="./input/", train=True,
download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=True, num_workers=2)

Decide which device we want to run on

device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images

real_batch = next(iter(dataloader))

plt.figure(figsize=(8,8))

plt.axis("off")

plt.title("Training Images")

plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

class Generator(nn.Module):
def init(self, ngpu):
super(Generator, self).init()
self.ngpu = ngpu
self.main = nn.Sequential(
# input is Z, going into a convolution
nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# state size. (ngf8) x 4 x 4
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# state size. (ngf
4) x 8 x 8
nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# state size. (ngf*2) x 16 x 16
nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# state size. (ngf) x 32 x 32
nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
nn.Tanh()
# state size. (nc) x 64 x 64
)

def forward(self, input):
    return self.main(input)

custom weights initialization called on netG and netD

def weights_init(m):
classname = m.class.name
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)

Create the generator

netG = Generator(ngpu).to(device)
netG.apply(weights_init)
print(netG)

class Discriminator(nn.Module):
def init(self, ngpu):
super(Discriminator, self).init()
self.ngpu = ngpu
self.main = nn.Sequential(
# input is (nc) x 64 x 64
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf) x 32 x 32
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf2) x 16 x 16
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf
4) x 8 x 8
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*8) x 4 x 4
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)

def forward(self, input):
    return self.main(input)

Create the Discriminator

netD = Discriminator(ngpu).to(device)

Apply the weights_init function to randomly initialize all weights

to mean=0, stdev=0.2.

netD.apply(weights_init)

Print the model

print(netD)

Initialize BCELoss function

criterion = nn.BCELoss()

Establish convention for real and fake labels during training

real_label = 1

fake_label = 0

"""adding label smoothing"""
real_label=0.9
fake_label=0.1

Setup Adam optimizers for both G and D

optimizerD = optim.Adam(netD.parameters(), lr=d_lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=g_lr, betas=(beta1, 0.999))

img_list = []
G_losses = []
D_losses = []
iters = 0
epoch_list, IS_list, FID_list = [], [], []
print("Starting Training Loop...")
if name=='main':
# Our GAN Evaluator.
evaluator = GAN_Evaluator(device=device,
num_images_real=len(dataloader),
num_images_fake=len(dataloader))

# We can pre-load the real images in the format of a dataloader.
# Of course you can do that in individual batches, but this way is neater.
# Because in CIFAR10, each batch contains a (image, label) pair, we set `idx_in_loader` = 0.
# If we only have images in the datalaoder, we can set `idx_in_loader` = None.
evaluator.load_all_real_imgs(real_loader=dataloader, idx_in_loader=0)
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        #         # add some noise to the input to discriminator

        real_cpu = 0.9 * real_cpu + 0.1 * torch.randn((real_cpu.size()), device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        with torch.no_grad():
           fake = netG(noise)
           IS_mean, IS_std, FID = evaluator.fill_fake_img_batch(
               fake_batch=fake)
           epoch_list.append(epoch + i / len(dataloader))
           IS_list.append(IS_mean)
           FID_list.append(FID)
        label.fill_(fake_label)

        fake = 0.9 * fake + 0.1 * torch.randn((fake.size()), device=device)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        D_G_z2 = output.mean().item()

        # Calculate gradients for G
        errG.backward()
        # Update G
        optimizerG.step()
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
            with torch.no_grad():
                fixed_noise = torch.randn(ngf, nz, 1, 1, device=device)
                fake_display = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake_display, padding=2, normalize=True))

        iters += 1
    # Update the IS and FID curves every epoch.
    fig = plt.figure(figsize=(10, 4))
    ax = fig.add_subplot(1, 2, 1)
    ax.scatter(epoch_list, IS_list, color='firebrick')
    ax.plot(epoch_list, IS_list, color='firebrick')
    ax.set_ylabel('Inception Score (IS)')
    ax.set_xlabel('Epoch')
    ax.spines[['right', 'top']].set_visible(False)
    ax = fig.add_subplot(1, 2, 2)
    ax.scatter(epoch_list, FID_list, color='firebrick')
    ax.plot(epoch_list, FID_list, color='firebrick')
    ax.set_ylabel('Frechet Inception Distance (FID)')
    ax.set_xlabel('Epoch')
    ax.spines[['right', 'top']].set_visible(False)
    plt.tight_layout()
    plt.close(fig=fig)

    # Need to clear up the fake images every epoch.
    evaluator.clear_fake_imgs()

    G_losses.append(errG.item())
    D_losses.append(errD.item())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant