In [1]:
import os
from torchvision import datasets,transforms
import torch.nn as nn
import torch
import torchvision.utils as vutils
import ignite
from ignite.engine import Engine, Events
import ignite.distributed as idist
import logging
from ignite.metrics import FID, InceptionScore

  from .autonotebook import tqdm as notebook_tqdm


Set the random seed for reproducibility and configure the loggers. Ignite is used to keep the training code readable.

In [2]:
# Seed through ignite for reproducibility (author's note: a seed of 42 does not work).
ignite.utils.manual_seed(999)

# Raise the level of ignite's distributed helper loggers to WARNING.
for _logger_name in (
    "ignite.distributed.auto.auto_dataloader",
    "ignite.distributed.launcher.Parallel",
):
    ignite.utils.setup_logger(name=_logger_name, level=logging.WARNING)



In [3]:
# Use the first CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

Setting parameters. Source: https://arxiv.org/pdf/1511.06434.pdf

In [4]:
# --- Optimisation -----------------------------------------------------------
batch_size = 32
learning_rate = 2e-4
epochs = 100

# --- Model dimensions -------------------------------------------------------
z = 4096          # number of channels in the latent noise tensor
in_channels = 3   # channels of the generated / discriminated images
img_shape = 64    # side length of the square label-embedding plane
cat = 2           # number of label categories

# --- Filesystem locations ---------------------------------------------------
path_to_progress_img = "./progress/"
path_to_weights = "./weights/"
path_to_images = "./img/"

In [5]:
import numpy

# Training data: every image under `path_to_images`, labelled by sub-folder.
# Images are resized to img_shape x img_shape and normalised with mean 0.5 and
# std 0.5 per channel.
dataset = datasets.ImageFolder(
    path_to_images,
    transform=transforms.Compose([
        transforms.Resize((img_shape, img_shape)),
        transforms.ToTensor(),
        # Explicit per-channel tuples; the previous `(0.5)` was just the float 0.5.
        transforms.Normalize((0.5,) * in_channels, (0.5,) * in_channels),
    ]))

# Random evaluation subset of up to 200 images. The size is clamped because
# sampling without replacement raises when more items are requested than exist.
test_size = min(200, len(dataset))
test_dataset = torch.utils.data.Subset(
    dataset, numpy.random.choice(len(dataset), test_size, replace=False))

In [6]:
# Settings shared by the training and evaluation loaders.
_loader_kwargs = dict(batch_size=batch_size, num_workers=2)

# Training batches are shuffled; the evaluation subset keeps its order.
dataloader_train = idist.auto_dataloader(dataset, shuffle=True, **_loader_kwargs)
dataloader_test = idist.auto_dataloader(test_dataset, shuffle=False, **_loader_kwargs)

Source: https://arxiv.org/pdf/1511.06434.pdf

In [7]:
class Generator(nn.Module):
    """Label-conditioned generator built from transposed convolutions.

    The label is embedded into ``embedding_dim ** 2`` values, appended to the
    latent tensor along the channel axis, and upsampled by ``convT``, which
    ends in ``tanh``.

    Args:
        labels: number of rows in the label embedding table.
        embedding_dim: the embedding width is ``embedding_dim ** 2``.
    """

    def __init__(self, labels, embedding_dim):
        super().__init__()
        embedding_width = embedding_dim ** 2

        # Stem: kernel 4, stride 1, no padding turns a 1x1 input into 4x4.
        stages = [
            nn.ConvTranspose2d(in_channels=z + embedding_width, out_channels=512,
                               kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
        ]
        # Middle stages: kernel 4, stride 2, padding 1 doubles the spatial size
        # while the channel count is halved.
        for wide, narrow in ((512, 256), (256, 128), (128, 64)):
            stages += [
                nn.ConvTranspose2d(in_channels=wide, out_channels=narrow,
                                   kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(narrow),
                nn.LeakyReLU(0.2),
            ]
        # Head: one more doubling down to the image channel count, then tanh.
        stages += [
            nn.ConvTranspose2d(in_channels=64, out_channels=in_channels,
                               kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]

        self.convT = nn.Sequential(*stages)
        self.embedding = nn.Embedding(labels, embedding_width)
        self.embedding_dim = embedding_dim

    def forward(self, latent, label):
        """Generate images from ``latent`` conditioned on ``label``.

        The embedded label is reshaped to ``(label.shape[0], embedding_dim**2, 1, 1)``
        and concatenated with ``latent`` on dimension 1 before upsampling.
        """
        n = label.shape[0]
        label_channels = self.embedding(label).reshape(n, self.embedding_dim ** 2, 1, 1)
        return self.convT(torch.cat([latent, label_channels], 1))

In [8]:
class Discriminator(nn.Module):
    """Label-conditioned discriminator that outputs raw scores (no sigmoid).

    The label is embedded into ``embedding_dim ** 2`` values, reshaped into one
    ``embedding_dim x embedding_dim`` plane and stacked onto the image as an
    extra channel, hence the ``in_channels + 1`` input channels of ``conv``.

    Args:
        labels: number of rows in the label embedding table.
        embedding_dim: side length of the label plane. It must equal the
            spatial size of the images passed to ``forward``, otherwise the
            channel concatenation fails.
    """

    def __init__(self,labels,embedding_dim):
        super(Discriminator, self).__init__()
        self.conv = nn.Sequential(
            # Each kernel-4 / stride-2 / padding-1 convolution halves the spatial size.
            nn.Conv2d(in_channels=in_channels+1,out_channels=64,kernel_size=4,stride=2,padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4,stride=2,padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128,out_channels=256,kernel_size=4,stride=2,padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=256,out_channels=512,kernel_size=4,stride=2,padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            # Final kernel-4, unpadded convolution: a 4x4 map becomes a single value.
            nn.Conv2d(in_channels=512,out_channels=1,kernel_size=4,stride=2,padding=0),
        )
        self.embedding=nn.Embedding(labels,embedding_dim**2)
        self.embedding_dim=embedding_dim

    def forward(self, img,label):
        """Score ``img`` given ``label``.

        Args:
            img: image batch; ``img.shape[0]`` is taken as the batch size.
            label: integer label ids, one per image.

        Returns:
            The output of ``conv`` on the image with the label plane appended.
        """
        # Size the label plane from this module's own embedding_dim and the
        # real batch size, rather than from a module-level global and an
        # inferred -1, so a mis-sized embedding cannot silently change the
        # batch dimension.
        side = self.embedding_dim
        label_plane = self.embedding(label).reshape(img.shape[0], 1, side, side)
        img_emb_vector = torch.cat([img, label_plane], 1)
        return self.conv(img_emb_vector)

In [9]:
# Build both networks with `cat` labels and an img_shape-sized embedding,
# then hand each one to ignite's auto_model helper.
gen = idist.auto_model(Generator(labels=cat, embedding_dim=img_shape))
dis = idist.auto_model(Discriminator(labels=cat, embedding_dim=img_shape))

In [10]:
# Binary cross-entropy on raw logits (the discriminator applies no sigmoid).
lossFun = nn.BCEWithLogitsLoss()

# Fixed generator inputs reused for every progress snapshot. They are drawn
# on the CPU (so the random stream is unchanged) and then moved to `device`;
# left on the CPU they cannot be fed to a generator that lives on the GPU.
fixed_noise = torch.randn(8, z, 1, 1).to(device)
fixed_labels = torch.randint(low=0, high=cat, size=(8, 1, 1, 1)).to(device)

# Target values for real and generated samples.
real_label = 1
fake_label = 0

In [11]:
# Both networks use Adam with identical settings.
_adam_settings = dict(lr=learning_rate, betas=(0.5, 0.999))
optimizerD = torch.optim.Adam(dis.parameters(), **_adam_settings)
optimizerG = torch.optim.Adam(gen.parameters(), **_adam_settings)

Source: https://machinelearningmastery.com/how-to-code-generative-adversarial-network-hacks/

In [12]:
def smooth_positive_labels(y):
    """Return ``y`` shifted by uniform noise into ``[y - 0.3, y + 0.2)``.

    Args:
        y: tensor of target values (e.g. all ones for "real").

    Returns:
        Tensor with the shape of ``y``; each element is ``y - 0.3 + u * 0.5``
        with ``u`` drawn uniformly from ``[0, 1)``.
    """
    # Draw the noise on y's own device instead of a module-level global, so
    # the helper works wherever the targets live.
    return y - 0.3 + (torch.rand(y.shape, device=y.device) * 0.5)
def smooth_negative_labels(y):
    """Return ``y`` shifted by uniform noise into ``[y, y + 0.3)``.

    Args:
        y: tensor of target values (e.g. all zeros for "fake").

    Returns:
        Tensor with the shape of ``y``; each element is ``y + u * 0.3`` with
        ``u`` drawn uniformly from ``[0, 1)``.
    """
    # Draw the noise on y's own device instead of a module-level global, so
    # the helper works wherever the targets live.
    return y + torch.rand(y.shape, device=y.device) * 0.3

In [13]:
def training_step(engine,data):
    """Run one GAN update: discriminator on real + fake, then generator.

    Args:
        engine: the ignite engine invoking this step (not used here).
        data: one batch; ``data[0]`` holds the images and ``data[1]`` the
            integer class labels.

    Returns:
        dict of Python floats:
            ``Loss_G``  generator loss,
            ``Loss_D``  discriminator loss (real part + fake part),
            ``D_x``     mean raw discriminator output on the real batch,
            ``D_G_z1``  mean raw discriminator output on fakes before the
                        discriminator update,
            ``D_G_z2``  mean raw discriminator output on fakes after it.
    """
    # Make sure both networks are in training mode; evaluation code may have
    # switched them to eval mode, which changes BatchNorm behaviour.
    gen.train()
    dis.train()

    real = data[0].to(device)
    # The class labels must sit on the same device as the model and images.
    real_classes = data[1].to(device)
    n = real.size(0)
    labelTrue = torch.full((n, 1, 1, 1), real_label,
                           dtype=real.dtype, device=device)
    labelFalse = torch.full((n, 1, 1, 1), fake_label,
                            dtype=real.dtype, device=device)

    # --- Discriminator: real batch ---
    dis.zero_grad()
    disReal = dis(real, real_classes)
    lossReal = lossFun(disReal, labelTrue)
    lossReal.backward()
    D_x = disReal.mean().item()

    # --- Discriminator: generated batch ---
    noise = torch.randn(n, z, 1, 1, device=device)
    gen_label = torch.randint(low=0, high=cat, size=(n, 1, 1, 1), device=device)
    fake = gen(noise, gen_label)
    # detach(): this pass must not send gradients back into the generator.
    disFakeDisBack = dis(fake.detach(), gen_label)
    errD_fake = lossFun(disFakeDisBack, labelFalse)
    errD = lossReal + errD_fake
    errD_fake.backward()
    D_G_z1 = disFakeDisBack.mean().item()
    optimizerD.step()

    # --- Generator: try to make the updated discriminator say "real" ---
    gen.zero_grad()
    disFakeGenBack = dis(fake, gen_label)
    errG = lossFun(disFakeGenBack, labelTrue)
    errG.backward()
    D_G_z2 = disFakeGenBack.mean().item()
    optimizerG.step()

    return {
        "Loss_G": errG.item(),
        "Loss_D": errD.item(),
        "D_x": D_x,
        "D_G_z1": D_G_z1,
        "D_G_z2": D_G_z2,
    }


Source: https://pytorch-ignite.ai/blog/gan-evaluation-with-fid-and-is/

In [14]:
# The training engine calls training_step once per batch.
trainer = Engine(process_function=training_step)

In [15]:
def initialize_fn(m):
    """Re-initialise one module in place, chosen by its class name.

    * class name contains ``Conv``: weights drawn from N(0, 0.02).
    * otherwise, class name contains ``BatchNorm``: weights drawn from
      N(1, 0.02) and bias set to 0.
    * anything else is left untouched.

    Intended for use with ``module.apply(initialize_fn)``.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [16]:
@trainer.on(Events.STARTED)
def init_weights():
    """Initialise both networks when training starts.

    If ``path_to_weights`` holds a ``gen_epoch_<n>.pth`` / ``dis_epoch_<n>.pth``
    pair, the pair with the highest ``n`` is loaded. Otherwise both networks
    are initialised with ``initialize_fn``.

    The epoch number is read from the file names rather than derived from the
    number of directory entries, so unrelated files in the directory do not
    shift the index. A missing directory is created and treated as empty.
    """
    import re  # local import keeps this cell self-contained

    os.makedirs(path_to_weights, exist_ok=True)

    # Collect the epochs for which each network has a checkpoint on disk.
    checkpoint_name = re.compile(r"^(gen|dis)_epoch_(\d+)\.pth$")
    saved_epochs = {"gen": set(), "dis": set()}
    for filename in os.listdir(path_to_weights):
        match = checkpoint_name.match(filename)
        if match:
            saved_epochs[match.group(1)].add(int(match.group(2)))

    # Only resume from an epoch that has BOTH files.
    complete_epochs = saved_epochs["gen"] & saved_epochs["dis"]
    if not complete_epochs:
        dis.apply(initialize_fn)
        gen.apply(initialize_fn)
        return

    idx = max(complete_epochs)
    # map_location lets checkpoints written on another device load here.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    dis.load_state_dict(torch.load(path_to_weights+"dis_epoch_%d.pth" % idx, map_location=device))
    gen.load_state_dict(torch.load(path_to_weights+"gen_epoch_%d.pth" % idx, map_location=device))
    

In [17]:
# Per-iteration history, one independent list per value training_step reports.
G_losses, D_losses = [], []
D_x, D_G_z1, D_G_z2 = [], [], []



@trainer.on(Events.ITERATION_COMPLETED)
def store_losses(engine):
    """Append the values reported by the latest iteration to the history lists."""
    step_output = engine.state.output
    history_by_key = (
        ("Loss_G", G_losses),
        ("Loss_D", D_losses),
        ("D_x", D_x),
        ("D_G_z1", D_G_z1),
        ("D_G_z2", D_G_z2),
    )
    for key, history in history_by_key:
        history.append(step_output[key])

In [18]:
@trainer.on(Events.ITERATION_COMPLETED(every=100))
def store_images(engine):
    """Save a snapshot of the generator's output for the fixed inputs.

    Runs every 100 iterations. The file name only contains the epoch number,
    so snapshots taken within the same epoch overwrite one another.
    """
    # Feed the fixed inputs on whatever device the generator's parameters
    # live on; otherwise CPU inputs fail against a GPU model.
    model_device = next(gen.parameters()).device
    # save_image does not create missing directories.
    os.makedirs(path_to_progress_img, exist_ok=True)
    with torch.no_grad():
        fake = gen(fixed_noise.to(model_device), fixed_labels.to(model_device)).cpu()
    vutils.save_image(fake,
            '%s/fake_samples_epoch_%03d.png' % (path_to_progress_img, engine.state.epoch),
            normalize=True)

In [19]:
# Both metrics run on ignite's current device.
_metric_device = idist.device()
fid_metric = FID(device=_metric_device)
# The evaluation step returns (fake, real); Inception Score only takes the
# first element. `num_features` is left at its default.
is_metric = InceptionScore(
    device=_metric_device,
    output_transform=lambda step_output: step_output[0],
)

In [20]:
import PIL.Image as Image


def interpolate(batch, value_range=(-1.0, 1.0)):
    """Resize every image in ``batch`` to 299x299 with bilinear resampling.

    Each image goes tensor -> PIL image -> resize -> tensor.

    Args:
        batch: iterable of image tensors.
        value_range: ``(low, high)`` range of the incoming pixel values. The
            images are mapped linearly to ``[0, 1]`` and clamped before the
            PIL conversion, because ``ToPILImage`` expects float tensors in
            ``[0, 1]`` and values below zero would be corrupted. The default
            matches data normalised with mean 0.5 / std 0.5 and ``tanh``
            output. Pass ``None`` to skip the rescaling.

    Returns:
        The resized images stacked into a single tensor.
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()

    arr = []
    for img in batch:
        if value_range is not None:
            low, high = value_range
            img = ((img - low) / (high - low)).clamp(0.0, 1.0)
        pil_img = to_pil(img)
        resized_img = pil_img.resize((299,299), Image.BILINEAR)
        arr.append(to_tensor(resized_img))

    return torch.stack(arr)


def evaluation_step(engine, batch):
    """Produce one (fake, real) pair of image batches for the FID / IS metrics.

    Args:
        engine: the ignite engine invoking this step (not used here).
        batch: either an ``(images, labels)`` sequence, as produced by the
            image-folder dataloader, or a dict with ``"image"`` and
            ``"label"`` keys.

    Returns:
        ``(fake, real)``: generated and real images, both passed through
        ``interpolate``.
    """
    # Accept both batch layouts; the dataloader built on ImageFolder yields
    # (images, labels), which cannot be indexed by string keys.
    if isinstance(batch, dict):
        images, labels = batch["image"], batch["label"]
    else:
        images, labels = batch[0], batch[1]

    eval_device = idist.device()
    # Remember the mode so that evaluation does not leave the generator in
    # eval mode for the training iterations that follow.
    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            noise = torch.randn(len(images), z, 1, 1, device=eval_device)
            # Labels must be on the same device as the noise and the model.
            fake_batch = gen(noise, labels.to(eval_device))
            fake = interpolate(fake_batch)
            real = interpolate(images)
    finally:
        gen.train(was_training)
    return fake, real

In [21]:
# The evaluation engine calls evaluation_step per batch; both metrics are
# attached under the names read back after each evaluation run.
evaluator = Engine(process_function=evaluation_step)
for _metric_name, _metric in (("fid", fid_metric), ("is", is_metric)):
    _metric.attach(evaluator, _metric_name)

In [22]:
# Per-epoch history of the two evaluation metrics.
fid_values, is_values = [], []


@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
    """Evaluate after every epoch, record FID / IS and print them.

    Runs the evaluator once over ``dataloader_test``, appends the resulting
    scores to ``fid_values`` / ``is_values`` and prints them with four
    decimal places.
    """
    evaluator.run(dataloader_test,max_epochs=1)
    metrics = evaluator.state.metrics
    fid_score = metrics['fid']
    is_score = metrics['is']
    fid_values.append(fid_score)
    is_values.append(is_score)
    print(f"Epoch [{engine.state.epoch}/{epochs}] Metric Scores")
    # ".4f" = four decimals; the former "4f" only set a minimum field width.
    print(f"*   FID : {fid_score:.4f}")
    print(f"*    IS : {is_score:.4f}")


In [23]:
@trainer.on(Events.EPOCH_COMPLETED)
def save_weights(engine):
    """Write both networks' state dicts to ``path_to_weights`` after each epoch."""
    finished_epoch = engine.state.epoch
    checkpoints = (
        (gen, '%s/gen_epoch_%d.pth'),
        (dis, '%s/dis_epoch_%d.pth'),
    )
    for network, name_template in checkpoints:
        torch.save(network.state_dict(), name_template % (path_to_weights, finished_epoch))

In [24]:
from ignite.metrics import RunningAverage


# One running average per value in training_step's output dict, attached to
# the trainer under the same name. The key is bound as a default argument so
# that each lambda keeps its own key.
for _metric_key in ('Loss_G', 'Loss_D', 'D_x', 'D_G_z1', 'D_G_z2'):
    RunningAverage(
        output_transform=lambda step_output, key=_metric_key: step_output[key]
    ).attach(trainer, _metric_key)

In [25]:
from ignite.contrib.handlers import ProgressBar


# Training bar shows the running averages; the evaluation bar shows progress only.
_train_bar = ProgressBar()
_train_bar.attach(trainer, metric_names=['Loss_G', 'Loss_D', "D_x", "D_G_z1", "D_G_z2"])

_eval_bar = ProgressBar()
_eval_bar.attach(evaluator)

In [26]:
def training(*args):
    """Run the trainer over the training dataloader for ``epochs`` epochs.

    Any positional arguments are accepted and ignored.
    """
    trainer.run(data=dataloader_train, max_epochs=epochs)

In [None]:
# Start training through ignite's Parallel launcher on the gloo backend.
_launcher = idist.Parallel(backend='gloo')
with _launcher as parallel:
    parallel.run(training)