# WGAN-GP on tomato leaf images (64x64)

Import libraries

In [None]:
from __future__ import print_function
import os
import time
import datetime
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as utils
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from PIL import Image

%matplotlib inline

Set up paths

In [None]:
dataset_dir = '../input/tomato-leaf-diseases/Bacterial spot'
figures_dir = './figures/bacterial spot'
checkpoints_dir = './checkpoints/bacterial spot/'
graphs_dir = './graphs/bacterial spot'
old_checkpoints_dir = '../input/checkpointsfin'

# Create the output directories; exist_ok=True makes re-running this cell harmless.
for _out_dir in (figures_dir, checkpoints_dir, graphs_dir):
    os.makedirs(_out_dir, exist_ok=True)

# Make sure the log files exist. They are reopened with `with open(...)`
# wherever they are written, so no handle needs to stay open here (an open,
# unused handle would only leak a file descriptor).
for _log_name in ("g_losses.txt", "d_losses.txt", "epoch.txt"):
    with open(_log_name, "a", encoding="utf-8"):
        pass

Set up hyperparameters

In [None]:
workers = 2  # DataLoader worker processes

batch_size = 64
image_size = 64  # images are resized and center-cropped to image_size x image_size

nc = 3           # number of image channels
noise_dim = 100  # length of the generator's latent vector

nfg = 64  # base number of generator feature maps (layers use multiples of it)
nfd = 64  # base number of discriminator/critic feature maps
epochs = 4001  # 4001 so that epoch index 4000 (a multiple of 1000) is checkpointed

g_learning_rate = 1e-4
d_learning_rate = 1e-4
# Adam betas. These match the betas=(0.0, 0.9) passed to both Adam optimizers
# (the WGAN-GP setting), rather than the DCGAN values (0.5, 0.999).
beta1 = 0.0
beta2 = 0.9
critic_iterations = 5  # critic updates per generator update
lambda_gp = 10         # weight of the gradient-penalty term in the critic loss


Set `load_model` to True to load the model from disk, or to False to initialize it from scratch.

In [None]:
# True: load saved networks from old_checkpoints_dir (option 2 below);
# False: build fresh networks and apply init_weights (option 1 below).
load_model: bool = True

Set up GPU device for training

In [None]:
# Train on the first CUDA GPU when one is available; otherwise fall back to
# the CPU so the notebook still runs (slowly) on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Print the current working directory (IPython shell escape); the output,
# checkpoint and log paths used in this notebook are relative to it.
!pwd

Load the dataset

In [None]:
# Resize and center-crop to image_size, convert to tensors, then map each
# channel from [0, 1] to [-1, 1] to match the generator's Tanh output range.
preprocessing_steps = [
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(preprocessing_steps)

train_data = datasets.ImageFolder(dataset_dir, transform=transform)
data_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

View samples from the dataset

In [None]:
# Take one batch and display (up to) its first 64 images as a grid.
ds_sample = next(iter(data_loader))
images = ds_sample[0][:64]  # the images stay on the CPU; no device copy is needed to plot them

plt.figure(figsize=(8, 8))
plt.axis('off')
plt.title('Train data')
# make_grid returns a (C, H, W) tensor; imshow expects (H, W, C).
# normalize=True rescales the [-1, 1] pixel values into [0, 1] for display.
grid = utils.make_grid(images, padding=4, normalize=True).permute(1, 2, 0).numpy()
plt.imshow(grid)

Define a function for weight initialization

In [None]:
def init_weights(model):
    """Initialise one module in place (DCGAN-style); meant for ``nn.Module.apply``.

    * Modules whose class name contains ``'Conv'`` (e.g. ``Conv2d``,
      ``ConvTranspose2d``) get their weight drawn from N(0, 0.02).
    * Modules whose class name contains ``'BatchNorm'`` get their weight drawn
      from N(1, 0.02) and their bias set to zero.

    A matching module that lacks the parameter is skipped instead of raising.
    Examples are a container whose name happens to contain ``'Conv'``, or a
    BatchNorm layer built with ``affine=False`` (weight/bias are None).
    All other modules are left untouched.

    Args:
        model: the (sub-)module to initialise.
    """
    class_name = type(model).__name__
    weight = getattr(model, 'weight', None)
    bias = getattr(model, 'bias', None)
    if 'Conv' in class_name:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in class_name:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.zeros_(bias)

Define the generator network

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator: maps a (N, noise_dim, 1, 1) latent batch to images.

    The first transposed convolution (kernel 4, stride 1, no padding) turns the
    1x1 input into a 4x4 map. Four stride-2 transposed convolutions then
    upsample it to 64x64. Every hidden stage is ConvTranspose2d -> BatchNorm2d
    -> ReLU, and the output stage ends in Tanh, so pixel values lie in [-1, 1].
    All layers live in one ``nn.Sequential`` stored as ``self.model``, which
    keeps parameter names (``model.<index>.*``) stable for saved checkpoints.
    """

    def __init__(self):
        super().__init__()

        # Channel widths of the hidden stages, from the latent input downwards.
        widths = [noise_dim, nfg * 8, nfg * 4, nfg * 2, nfg]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            is_first = stage == 0
            layers += [
                nn.ConvTranspose2d(
                    c_in,
                    c_out,
                    kernel_size=4,
                    stride=1 if is_first else 2,
                    padding=0 if is_first else 1,
                    bias=False,
                ),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        # Output stage: project to nc image channels and squash to [-1, 1].
        layers += [
            nn.ConvTranspose2d(nfg, nc, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)

Define the discriminator network

In [None]:
class Discriminator(nn.Module):
    """WGAN-GP critic: maps a (N, nc, 64, 64) image batch to (N, 1, 1, 1) scores.

    Four stride-2 convolutions (kernel 4, padding 1) downsample 64x64 -> 4x4,
    each followed by LeakyReLU(0.2). Stages 2-4 insert an affine
    InstanceNorm2d between the convolution and the activation. Dropout2d(0.5)
    follows the first stage and precedes the last convolution. The last
    4x4 convolution collapses the map to one score per image. No sigmoid is
    applied, because the training loop uses the raw scores in the Wasserstein
    loss. Layers live in ``self.model`` so parameter names match checkpoints.
    """
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2d(nc, nfd, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Dropout2d(0.5, inplace=False),

            nn.Conv2d(nfd, nfd*2, kernel_size=4, stride=2, padding=1, bias=False),
            # Sized by the critic's own width (nfd): it normalises the nfd*2
            # channels produced by the convolution above.
            nn.InstanceNorm2d(nfd*2, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(nfd*2, nfd*4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(nfd*4, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(nfd*4, nfd*8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(nfd*8, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Dropout2d(0.5, inplace=False),

            nn.Conv2d(nfd*8, 1, kernel_size=4, stride=2, padding=0, bias=False),
        )

    def forward(self, input):
        return self.model(input)

Instantiate the generator and initialize its weights (option 1)

In [None]:
# Fresh start: build the generator on the training device, then apply the
# custom weight initialisation to each of its sub-modules.
if not load_model:
    generator = Generator().to(device)
    generator.apply(init_weights)

Instantiate the discriminator and initialize its weights (option 1)

In [None]:
# Fresh start: build the critic on the training device, then apply the
# custom weight initialisation to each of its sub-modules.
if not load_model:
    discriminator = Discriminator().to(device)
    discriminator.apply(init_weights)

Load the model from disk (option 2)

In [None]:
if load_model:
    # map_location remaps the stored tensors onto `device`, so checkpoints
    # saved from a GPU can also be loaded on another device (e.g. the CPU).
    # These files hold whole pickled nn.Module objects, so torch.load has to
    # unpickle arbitrary objects: only load checkpoints you trust.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which
    # refuses full-module checkpoints; there, pass weights_only=False.
    generator = torch.load(os.path.join(old_checkpoints_dir, 'generator_new4000.pt'), map_location=device)
    discriminator = torch.load(os.path.join(old_checkpoints_dir, 'discriminator_new4000.pt'), map_location=device)

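Loss function: BinaryCrossEntropy is not used here. WGAN-GP trains with the Wasserstein critic loss plus a gradient penalty (see the training loop).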

In [None]:
# cross_entropy = nn.BCELoss()

Define a fixed batch of noise vectors used to track the generator's progress

In [None]:
# Fixed batch of 64 latent vectors. The generator is run on it at the end of
# every epoch, so the saved grids always show the same inputs.
sample_noise = torch.randn((64, noise_dim, 1, 1), device=device)

Define the optimizers (Adam)

In [None]:
# Adam with beta1 = 0 (no first-moment momentum) and beta2 = 0.9 for both
# networks; each network has its own learning rate.
adam_betas = (0.0, 0.9)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=d_learning_rate, betas=adam_betas)
gen_optimizer = optim.Adam(generator.parameters(), lr=g_learning_rate, betas=adam_betas)

Define a function to plot loss

In [None]:
def plot_loss(gen_losses, disc_losses, epoch=None, save=False, show=True):
    """Plot the generator and discriminator loss curves on a new figure.

    Args:
        gen_losses: generator losses, one value per training iteration.
        disc_losses: discriminator losses, one value per training iteration.
        epoch: value embedded in the saved file name ``loss_{epoch}.jpg``.
        save: if True, save the figure to ``graphs_dir/loss_{epoch}.jpg``.
        show: if True, display the figure with ``plt.show()``. If False, the
            figure is closed after the optional save, so repeated calls do
            not pile up open figures.
    """
    fig = plt.figure(figsize=(10, 5))
    plt.title('Generator and Discriminator losses')
    plt.plot(gen_losses, label='G')
    plt.plot(disc_losses, label='D')
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.legend()

    if save:
        plt.savefig(os.path.join(graphs_dir, f'loss_{epoch}.jpg'))
    if show:
        plt.show()
    else:
        # Nothing is displayed, so release the figure instead of leaking it.
        plt.close(fig)

Define the gradient penalty, then train both networks, alternating `critic_iterations` critic updates with one generator update

In [None]:
def gradient_penalty(critic, real, fake, device):
    """Compute the WGAN-GP gradient penalty.

    Points are sampled on straight lines between paired real and fake images,
    using one random weight per sample. The critic is evaluated there, and the
    function returns

        mean_i (||d critic / d x_hat_i||_2 - 1) ** 2

    where the norm is taken over all of a sample's input elements.

    Args:
        critic: callable mapping an image batch to per-sample scores.
        real: batch of real images, shape (N, C, H, W).
        fake: batch of generated images with the same shape as ``real``. It
            may or may not be attached to the generator's graph: the penalty
            only needs gradients w.r.t. the interpolated inputs.
        device: device on which the interpolation weights are placed.

    Returns:
        A scalar tensor, differentiable w.r.t. the critic's parameters.

    Raises:
        ValueError: if ``real`` and ``fake`` have different shapes.
    """
    if real.shape != fake.shape:
        raise ValueError(
            "real and fake batches must have the same shape, got "
            f"{tuple(real.shape)} and {tuple(fake.shape)}"
        )
    batch_size = real.shape[0]
    # One interpolation weight per sample, broadcast over (C, H, W).
    epsilon = torch.rand((batch_size, 1, 1, 1)).to(device)
    # Detach, then re-enable grad: the interpolated batch becomes a leaf we can
    # differentiate against whether or not `fake` requires grad. This also
    # keeps the penalty's backward pass out of the generator.
    interpolated_images = (real * epsilon + fake * (1 - epsilon)).detach().requires_grad_(True)

    mixed_scores = critic(interpolated_images)

    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,  # the penalty must itself be differentiable w.r.t. the critic
        retain_graph=True,
    )[0]

    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    penalty = torch.mean((gradient_norm - 1) ** 2)
    return penalty

In [None]:
gen_losses = []   # generator loss, one entry per training iteration
disc_losses = []  # critic loss of the last critic update, one entry per iteration

for epoch in range(epochs):
    start = time.time()
    for i, data in enumerate(data_loader, 0):

        # ---- Train the critic (discriminator) ----
        # Move the real batch to the device once; all `critic_iterations`
        # critic updates below reuse it.
        real = data[0].to(device)
        size = real.size(0)
        for _ in range(critic_iterations):
            # New fake batch for every critic update.
            noise = torch.randn(size, noise_dim, 1, 1, device=device)
            fake = generator(noise)
            output_real = discriminator(real).view(-1)
            real_mean = output_real.mean().item()
            # detach(): this term only needs gradients for the critic.
            output_fake = discriminator(fake.detach()).view(-1)
            fake_mean = output_fake.mean().item()
            gp = gradient_penalty(discriminator, real, fake, device)
            # Critic loss: -(E[D(real)] - E[D(fake)]) + lambda_gp * GP.
            disc_err = (-(torch.mean(output_real) - torch.mean(output_fake)) + lambda_gp * gp)
            # Zero out gradients prior to the backward pass.
            discriminator.zero_grad()
            # retain_graph=True: the generator step below back-propagates
            # through `fake` from the last critic iteration, and the penalty
            # term may also traverse that graph.
            disc_err.backward(retain_graph=True)
            disc_optimizer.step()

        # ---- Train the generator ----
        # Re-score the last fake batch with the updated critic; the generator
        # minimises -E[D(G(z))].
        output = discriminator(fake).view(-1)
        gen_mean = output.mean().item()
        gen_err = -torch.mean(output)
        generator.zero_grad()
        gen_err.backward()
        gen_optimizer.step()

        if epoch % 1000 == 0:
            print('[%d/%d][%d/%d] \tD-Loss:%.4f\t G-Loss:%.4f\t D(x):%.4f\t D(G(z)):%.4f\t G(z):%.4f'
                  % (epoch + 1, epochs, i + 1, len(data_loader), disc_err.item(), gen_err.item(), real_mean, fake_mean, gen_mean))

        gen_losses.append(gen_err.item())
        disc_losses.append(disc_err.item())

    end = time.time()
    timedelta = datetime.timedelta(seconds=int(end - start))
    if epoch % 500 == 0:
        print(f'Time elapsed for epoch {epoch + 1}: {timedelta}\n')

    # Render the generator's output for the fixed `sample_noise` batch as an
    # (H, W, C) image grid.
    with torch.no_grad():
        sample = generator(sample_noise).detach().cpu()
    grid = np.transpose(utils.make_grid(sample, padding=4, normalize=True).cpu(), (1, 2, 0))

    if epoch % 1000 == 0:
        # Save this epoch's sample grid as an image.
        plt.figure(figsize=(8, 8))
        plt.axis('off')
        plt.imshow(grid)
        plt.savefig(os.path.join(figures_dir, f'epoch_{epoch + 1}.png'))
        plt.close()

        # Append the most recent losses to the log files.
        with open("d_losses.txt", 'a', encoding='utf-8') as fd:
            fd.write(str(disc_losses[-1]) + "\n")
        with open("g_losses.txt", 'a', encoding='utf-8') as fg:
            fg.write(str(gen_losses[-1]) + "\n")

        # Save progress (whole modules, so they can be reloaded with torch.load).
        torch.save(generator, os.path.join(checkpoints_dir, f'generator_new{epoch}.pt'))
        torch.save(discriminator, os.path.join(checkpoints_dir, f'discriminator_new{epoch}.pt'))

In [None]:
# Record the number of epochs that were run (`epoch` is the 0-based index of
# the last completed epoch), overwriting any previous value.
last_epoch_number = epoch + 1
with open("epoch.txt", 'w', encoding='utf-8') as fe:
    fe.write(f"{last_epoch_number}\n")
        

Clear CUDA cache if needed

Plot the loss graph

Generate samples on random noise

<a href="BM1.zip"> Download BM1.zip </a>

<a href="working.zip"> Download working.zip </a>