In [1]:
#Import everything
import csv
import os
import random
from math import log2

import numpy as np
import torch
from tqdm import tqdm
import cv2
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision
from scipy.stats import truncnorm
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn.functional as F
from torchvision.utils import save_image
import torch.optim as optim
from torch.nn import Module, Sequential, Conv2d, ConvTranspose2d, LeakyReLU, BatchNorm2d, ReLU, Tanh, Sigmoid, BCELoss 
import matplotlib.pyplot as plt
from PIL import Image

# Let cuDNN benchmark convolution algorithms and cache the fastest one per
# input shape. The flag is `benchmark` (singular); assigning `benchmarks`
# just creates an unused attribute and leaves autotuning off.
torch.backends.cudnn.benchmark = True

In [2]:
# Configure the CUDA caching allocator *before* any CUDA call so the setting is
# guaranteed to be in the environment when the allocator initializes.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"

# Always good to check if GPU support is available or not
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(dev)

In [None]:
def crop_center(pil_img, crop_width, crop_height):
    """Return the ``crop_width`` x ``crop_height`` region centred in ``pil_img``.

    ``pil_img`` only needs a ``size`` (width, height) attribute and a
    ``crop(box)`` method taking a (left, top, right, bottom) box. Offsets use
    floor division, so when the size difference is odd the extra pixel of
    margin ends up on the right/bottom side.
    """
    width, height = pil_img.size
    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    right = (width + crop_width) // 2
    bottom = (height + crop_height) // 2
    box = (left, top, right, bottom)
    return pil_img.crop(box)
def crop_max_square(pil_img):
    """Crop the largest centred square that fits inside ``pil_img``.

    The square's side is the shorter of the image's two dimensions.
    """
    side = min(pil_img.size)
    return crop_center(pil_img, side, side)

import os
from PIL import Image
# Centre-crop every image in the training folder to a square and resize it to
# 1024x1024, OVERWRITING each original file in place (destructive; run once).
f = r'C:/Users/Unreal User/Downloads/Nevus Database/train'
TARGET_SIZE = (1024, 1024)
for file in sorted(os.listdir(f)):
    f_img = os.path.join(f, file)
    if not os.path.isfile(f_img):
        continue  # os.listdir also returns sub-directories; Image.open would fail on them
    # The context manager closes the source file handle before we overwrite it
    # (crop() loads the pixel data, so img_c is independent of the file).
    with Image.open(f_img) as images:
        img_c = crop_max_square(images).resize(TARGET_SIZE)
    img_c.save(f_img)

print('DONE!')

In [37]:
# --- Data and progressive-growing schedule ---
DATASET = 'C:/Users/Unreal User/Downloads/Nevus Database'
START_TRAIN_AT_IMG_SIZE = 256
# One batch size per resolution stage: index 0 -> 4x4, 1 -> 8x8, ..., 7 -> 512x512
# (get_loader indexes this list with log2(image_size / 4)).
BATCH_SIZES = [32, 32, 32, 16, 16, 16, 8, 4]
PROGRESSIVE_EPOCHS = [100 for _ in BATCH_SIZES]
NUM_WORKERS = 2
CHANNELS_IMG = 3

# --- Checkpointing ---
# NOTE: with both flags on, training saves over the same files it resumed from.
CHECKPOINT_G = "netG-150-M.pth"
CHECKPOINT_D = "netD-150-M.pth"
SAVE_MODEL = True
LOAD_MODEL = True

# --- Model and optimisation ---
LEARNING_RATE = 1e-3
Z_DIM = 256  # should be 512 in original paper
IN_CHANNELS = 256 # should be 512 in original paper
D_ITERATIONS = 1  # NOTE(review): not referenced by train_fn
LAMBDA_GP = 10
FIXED_NOISE = torch.randn(7, Z_DIM, 1, 1).to(device)


In [4]:
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    """WGAN-GP gradient penalty evaluated on random real/fake interpolations.

    For each sample a mixing weight ``beta ~ U[0, 1)`` is drawn, the critic is
    evaluated on ``beta * real + (1 - beta) * fake``, and the mean of
    ``(||grad||_2 - 1) ** 2`` over the per-sample input-gradient norms is returned.

    Args:
        critic: callable ``critic(images, alpha, train_step) -> scores``.
        real: batch of real images, shape (N, C, H, W).
        fake: batch of generated images with the same shape. It is detached
            here, so the penalty never back-propagates into the generator.
        alpha: fade-in factor forwarded to the critic.
        train_step: progressive-growing step forwarded to the critic.
        device: device the mixing weights are moved to.

    Returns:
        Scalar tensor, built with ``create_graph=True`` so it can be
        back-propagated into the critic's parameters.
    """
    batch_size = real.shape[0]
    # One weight per sample, broadcast over (C, H, W). This gives the same values
    # as an explicit repeat without materialising a full-size weight tensor.
    beta = torch.rand((batch_size, 1, 1, 1)).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, alpha, train_step)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    # reshape (not view): the gradient may be non-contiguous (e.g. channels_last
    # inputs), in which case .view would raise.
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialize model and optimizer state to ``filename`` with ``torch.save``.

    The file holds a dict with keys ``"state_dict"`` (model weights) and
    ``"optimizer"`` (optimizer state), which is the layout ``load_checkpoint`` reads.
    """
    print("=> Saving checkpoint")
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()
    torch.save({"state_dict": model_state, "optimizer": optimizer_state}, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr, map_location="cpu"):
    """Restore model and optimizer state saved by ``save_checkpoint``.

    Args:
        checkpoint_file: path to a file containing a dict with keys
            ``"state_dict"`` and ``"optimizer"``.
        model: module whose weights are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        lr: learning rate forced onto every param group after loading.
        map_location: forwarded to ``torch.load``. The default "cpu" lets a
            checkpoint saved on GPU load on a CPU-only machine.
            ``load_state_dict`` then copies model weights onto the model's
            device, and the optimizer moves its state next to its parameters.
    """
    print("=> Loading checkpoint")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_file, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

def seed_everything(seed=69):
    """Seed every RNG the notebook uses and put cuDNN in deterministic mode.

    Seeds Python's ``random``, NumPy's global RNG, and torch's CPU and CUDA
    generators. It also exports PYTHONHASHSEED, which only affects Python
    processes started afterwards (the running interpreter's hash seed is fixed
    at startup). cuDNN autotuning is turned off so algorithm choice is
    reproducible.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

#Prepare the data
# path to the image directory
dir_data  = r'C:\Users\Unreal User\Downloads\Melanoma_Resize\Melanoma Resize'

# every image is resized to this (height, width, channels) shape
img_shape = (1024, 1024, 3)

# listing out all file names (sorted so the stacked array has a stable order)
nm_imgs   = np.sort(os.listdir(dir_data))

# NOTE(review): each image becomes a float64 array of 1024*1024*3 values
# (~25 MB), so memory use grows by ~25 MB per file; check the folder fits in RAM.
X_train = []
for file in nm_imgs:
    try:
        # Close the source file as soon as its pixels are decoded.
        with Image.open(os.path.join(dir_data, file)) as src:
            img = src.convert('RGB')
        # PIL sizes are (width, height)
        img = img.resize((img_shape[1], img_shape[0]))
        img = np.asarray(img)/255
        X_train.append(img)
    except Exception as err:
        # Report which file failed and why, then keep going. Unlike a bare
        # `except:`, this does not swallow KeyboardInterrupt.
        print(f"something went wrong with {file}: {err}")

X_train = np.array(X_train)
print(X_train.shape)
# Save to a compressed .npz for future use (stored under the default key 'arr_0')
np.savez_compressed('kaggle_images_1024x1024.npz', X_train)

#Numpy helper function for displaying
# plot images in a nxn grid

def plot_images(imgs, grid_size = 5):
    """
    Show the first ``grid_size * grid_size`` images of ``imgs`` in a square grid.

    imgs: indexable sequence (list or array) of images accepted by ``imshow``,
        e.g. HxWx3 float arrays in [0, 1] as built by the data-prep cell
    grid_size: number of rows and columns (2 -> 2x2 grid, 5 -> 5x5 grid)

    If ``imgs`` holds fewer images than grid cells, the remaining cells stay empty.
    """
    fig = plt.figure(figsize = (8, 8))
    # suptitle instead of plt.title, which would create an extra full-figure axes
    fig.suptitle("Training Images")

    n_shown = min(grid_size * grid_size, len(imgs))
    for i in range(n_shown):
        # subplot positions are 1-based; image indices are 0-based
        ax = fig.add_subplot(grid_size, grid_size, i + 1)
        ax.axis("off")
        ax.imshow(imgs[i])
    plt.show()

# Map-style Dataset wrapper so pre-loaded images can be fed to a PyTorch DataLoader.
class HumanFacesDataset(Dataset):
    """In-memory dataset over a pre-loaded collection of images.

    Despite its name, it wraps whatever indexable collection it is given,
    e.g. the ``X_train`` array built in the data-prep cell or the array stored
    in the ``.npz`` file that cell writes. Items are returned unchanged; no
    transforms are applied.
    """

    def __init__(self, npz_imgs):
        """
        Args:
            npz_imgs: indexable collection of images (e.g. a NumPy array of
                shape (N, H, W, C)). Stored by reference, not copied.
        """
        self.imgs = npz_imgs

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        # Samplers may hand over a tensor index; convert it to Python types first.
        key = idx.tolist() if torch.is_tensor(idx) else idx
        return self.imgs[key]

In [5]:
def get_loader(image_size):
    """Return ``(loader, dataset)`` for one progressive-growing resolution.

    ``dataset`` is a torchvision ``ImageFolder`` rooted at ``DATASET``. Each
    image is resized to ``image_size`` x ``image_size``, converted to a tensor,
    randomly flipped horizontally (p=0.5) and normalized with mean 0.5 /
    std 0.5 per channel. The batch size is ``BATCH_SIZES[log2(image_size / 4)]``
    (4x4 -> index 0, 8x8 -> index 1, ...), and the loader reshuffles each epoch.
    """
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
    ])
    stage_index = int(log2(image_size / 4))
    batch_size = BATCH_SIZES[stage_index]
    dataset = datasets.ImageFolder(root=DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=False,
    )
    return loader, dataset

Kernel size used in the 1st tutorial = 4
Kernel size used in the 2nd tutorial = 3

In [6]:
factors = [1,1,1,1,1/2, 1/4, 1/8, 1/16] # Channel multipliers (x in_channels) per resolution stage: 4x4, 8x8, ..., 512x512
# Define layer classes
class WSConv2d(nn.Module):
    """Conv2d with weight scaling (equalized learning rate).

    Weights are initialized from N(0, 1), and the input is multiplied at runtime
    by ``sqrt(gain / (in_channels * kernel_size ** 2))``. The bias is taken out
    of the wrapped conv, so it is added *after* the scaled convolution and is
    not scaled.

    NOTE: the attribute names (``conv``, ``bias``, ``scale``) and the fact that
    ``bias`` lives on this module appear in saved checkpoints; keep them stable.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        fan_in = in_channels * kernel_size * kernel_size
        self.scale = (gain / fan_in) ** 0.5
        # Move the bias onto this module and disable it inside the conv.
        self.bias = self.conv.bias
        self.conv.bias = None

        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        scaled_input = x * self.scale
        return self.conv(scaled_input) + self.bias[None, :, None, None]

class PixelNorm(nn.Module):
    """Normalize each pixel's feature vector to unit RMS along dim 1.

    Computes ``x / sqrt(mean(x ** 2, dim=1) + epsilon)``, where dim 1 is the
    channel axis for NCHW inputs. It can be left out and replaced with BatchNorm.
    """

    def __init__(self):
        super().__init__()
        self.epsilon = 1e-8  # keeps the denominator non-zero for all-zero features

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        return x / (mean_square + self.epsilon).sqrt()
    
class ConvBlock(nn.Module):
    """Two 3x3 WSConv2d layers, each followed by LeakyReLU(0.2) and, optionally, PixelNorm.

    Spatial size is preserved (WSConv2d defaults: kernel 3, stride 1, padding 1).
    NOTE: attribute names and registration order are part of saved checkpoints.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.leaky(conv(x))
            if self.use_pn:
                x = self.pn(x)
        return x

Possible change: in the generator's initial layer, replace the PixelNorm and ConvTranspose layers with a WSConv layer.

In [7]:
#Defining the Generator
class Generator(nn.Module):
    """Progressive-growing generator mapping latent noise to images.

    ``forward(x, alpha, steps)`` returns an image of side 4 * 2 ** steps.
    ``initial`` maps the (N, z_dim, 1, 1) latent to a 4x4 feature map. Each of
    the first ``steps`` ConvBlocks then doubles the resolution (nearest-neighbour
    upsampling first). ``steps`` can be at most len(factors) - 1, the number of
    prog_blocks. Channel widths follow ``int(in_channels * factors[i])``.

    NOTE: submodule registration order (initial, initial_rgb, prog_blocks,
    rgb_layers) fixes the parameters() order that saved optimizer state maps
    onto, so don't reorder when existing checkpoints must stay loadable.
    """
    def __init__(self,z_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0), #1x1 -> 4x4
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        # 1x1 "toRGB" for the 4x4 stage; also stored as rgb_layers[0] (same module object)
        self.initial_rgb = WSConv2d(in_channels, img_channels, kernel_size=1, stride=1, padding=0)
        
        self.prog_blocks, self.rgb_layers = (nn.ModuleList([]), nn.ModuleList([self.initial_rgb]),)

        # prog_blocks[i] maps stage i's channels to stage i+1's channels;
        # rgb_layers[i + 1] is the toRGB for prog_blocks[i]'s output.
        for i in range(len(factors) - 1):  # -1 to prevent index error because of factors[i+1]
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c))
            self.rgb_layers.append(WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0))
            
    def fade_in(self, alpha, upscaled, generated):
        """Blend the two RGB outputs and squash with tanh.

        alpha=0 uses only ``upscaled`` (the previous resolution's path),
        alpha=1 only ``generated`` (the newest block's path).
        """
        # alpha should be scalar within [0, 1], and upscaled.shape == generated.shape
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)
    
    def forward(self, x, alpha, steps): #steps=0 (4x4), steps=1 (8x8)
        """Generate images of side 4 * 2 ** steps from latent ``x``, faded with ``alpha``."""
        out = self.initial(x) 

        if steps == 0:
            # NOTE(review): unlike the steps > 0 path (fade_in applies tanh), this
            # output is not passed through tanh, so 4x4 images are unbounded;
            # confirm this is intended.
            return self.initial_rgb(out)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[step](upscaled)

        # The number of channels in upscaled stays the same, while
        # out, which has moved through prog_blocks, might change. To
        # convert both to RGB we use different rgb_layers:
        # (steps-1) for upscaled and steps for out.
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)

In [8]:
# Defining the Discriminator class
 
class Discriminator(nn.Module):
    """Progressive-growing critic, mirroring ``Generator``.

    ``forward(x, alpha, steps)`` takes images of side 4 * 2 ** steps and
    returns one unbounded score per sample, shape (N, 1). prog_blocks and
    rgb_layers are ordered from the largest resolution down to 4x4.

    NOTE: submodule registration order fixes the parameters() order that saved
    optimizer state maps onto, so don't reorder when checkpoints must stay loadable.
    """
    def __init__(self, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        # Iterate over factors in reverse because the discriminator mirrors the
        # generator: the first prog_block and rgb layer appended handle the
        # largest resolution, 4 * 2 ** (len(factors) - 1) (512x512 with the
        # 8 factors above), then 256, 128, ... down to 8x8.
        for i in range(len(factors) - 1, 0, -1):
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c, use_pixelnorm=False))
            self.rgb_layers.append(WSConv2d(img_channels, conv_in_c, kernel_size=1, stride=1, padding=0))
        
        #For 4x4 resolution: this "fromRGB" is the last entry of rgb_layers
        self.initial_rgb = WSConv2d(img_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.rgb_layers.append(self.initial_rgb)
        
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)  # down sampling using avg pool

        # this is the block for 4x4 input size
        self.final_block = nn.Sequential(
            # +1 to in_channels because we concatenate from MiniBatch std
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, 1, kernel_size=1, padding=0, stride=1),  # this pretty much is a linear layer
        )
    def fade_in(self, alpha, downscaled, out):
        """Used to fade in downscaled using avg pooling and output from CNN"""
        # alpha should be a scalar within [0, 1], and downscaled.shape == out.shape
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        """Append one channel holding a batch-variation statistic."""
        batch_statistics = (torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3]))
        # torch.std(x, dim=0) is the std across the batch at every (channel,
        # pixel) position; .mean() reduces that to a single scalar, which is
        # repeated into one extra (N, 1, H, W) channel and concatenated. This
        # gives the critic information about the variation in the batch.
        # (With a batch of 1 the unbiased std is NaN.)
        return torch.cat([x, batch_statistics], dim=1)    
     
    def forward(self, x, alpha, steps):
        """Score images of side 4 * 2 ** steps; returns shape (N, 1)."""
        # Where we start in the list of prog_blocks: the last one is for 8x8 and
        # the final_block handles 4x4. For example, with steps=1 the input is
        # 8x8, so we start at the last prog_block. If steps==0 we only use
        # the final block.
        cur_step = len(self.prog_blocks) - steps

        # Convert from RGB as the initial step; this depends on the image
        # size (each size has its own rgb layer).
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  # i.e, image is 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # Because prog_blocks might change the channels, for the downscaled path
        # we use the rgb_layer of the previous/smaller size, which is +1 in the indexing.
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        # The fade_in is done first, between the downscaled input and the
        # output of the newest block; this is the opposite of the generator.
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

def _smoke_test_models(z_dim=512, in_channels=512):
    """Check Generator/Discriminator output shapes at every supported resolution.

    Uses its own ``z_dim``/``in_channels`` rather than reassigning the global
    ``Z_DIM``/``IN_CHANNELS`` config values. In a notebook
    ``__name__ == "__main__"`` is always true, so this runs every time the cell
    runs, and clobbering the globals would silently change the training models.
    """
    gen = Generator(z_dim, in_channels, img_channels=3)
    critic = Discriminator(in_channels, img_channels=3)
    # The models have len(factors) - 1 progressive blocks, so the largest
    # supported size is 4 * 2 ** (len(factors) - 1) (512 with 8 factors).
    # Asking for 1024 would index past the end of prog_blocks.
    for num_steps in range(len(factors)):
        img_size = 4 * 2 ** num_steps
        x = torch.randn((1, z_dim, 1, 1))
        with torch.no_grad():
            z = gen(x, 0.5, steps=num_steps)
            assert z.shape == (1, 3, img_size, img_size)
            out = critic(z, alpha=0.5, steps=num_steps)
        assert out.shape == (1, 1)
        print(f"Success! At img size: {img_size}")


if __name__ == "__main__":
    _smoke_test_models()

def init_weights(m):
    """Re-initialize a single module; intended for ``model.apply(init_weights)``.

    ConvTranspose2d weights are drawn from N(0, 0.02). BatchNorm2d weights are
    drawn from N(1, 0.02) and their bias is zeroed. Matching is on the exact
    type, so subclasses and every other module type are left untouched.
    NOTE(review): not applied to netG/netD in the cells shown here.
    """
    layer_type = type(m)
    if layer_type is ConvTranspose2d:
        nn.init.normal_(m.weight, 0.0, 0.02)
        return
    if layer_type is BatchNorm2d:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)

In [40]:
# Create the generator and critic on the selected device (generator first).
netG = Generator(Z_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG).to(device)
netD = Discriminator(IN_CHANNELS, img_channels=CHANNELS_IMG).to(device)

# Adam optimizers with beta1 = 0 (no first-moment momentum), plus one
# GradScaler per network for mixed-precision (FP16) training.
adam_betas = (0.0, 0.99)
opt_D = optim.Adam(netD.parameters(), lr=LEARNING_RATE, betas=adam_betas)
opt_G = optim.Adam(netG.parameters(), lr=LEARNING_RATE, betas=adam_betas)

scaler_D = torch.cuda.amp.GradScaler()
scaler_G = torch.cuda.amp.GradScaler()

# Optionally resume; load_checkpoint also resets each optimizer's lr to LEARNING_RATE.
if LOAD_MODEL:
    load_checkpoint(CHECKPOINT_G, netG, opt_G, LEARNING_RATE)
    load_checkpoint(CHECKPOINT_D, netD, opt_D, LEARNING_RATE)
 


=> Loading checkpoint
=> Loading checkpoint


In [10]:
def train_fn(netD, netG, loader, dataset, step, alpha, opt_D, opt_G, tensorboard_step, scaler_G, scaler_D, loss_log_path='lossM.csv'):
    """Run one epoch of WGAN-GP training at progressive-growing stage ``step``.

    For every batch:
      * one critic update with loss
        ``-(mean(D(real)) - mean(D(fake))) + LAMBDA_GP * gp + 0.001 * mean(D(real) ** 2)``;
      * one generator update with loss ``-mean(D(G(z)))``.
    Both run under CUDA autocast with their own GradScaler. ``alpha`` grows
    after each batch so the fade-in finishes after half of
    ``PROGRESSIVE_EPOCHS[step]`` epochs, and is capped at 1.

    At the end of the epoch, two rows are appended to ``loss_log_path``: the
    per-batch critic losses prefixed by 'D' and the generator losses prefixed by 'G'.

    Returns:
        ``(tensorboard_step, alpha)`` to pass into the next epoch.
    """
    d_loss_lst = ['D']
    g_loss_lst = ['G']
    loop = tqdm(loader, leave=True)
    for batch_idx, (real, _) in enumerate(loop):

        real = real.to(device)
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
        # which is equivalent to minimizing the negative of the expression
        noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)

        with torch.cuda.amp.autocast():
            fake = netG(noise, alpha, step)
            netD_real = netD(real, alpha, step)
            netD_fake = netD(fake.detach(), alpha, step)
            gp = gradient_penalty(netD, real, fake, alpha, step, device=device)
            loss_D = (
                -(torch.mean(netD_real) - torch.mean(netD_fake))
                + LAMBDA_GP * gp
                + (0.001 * torch.mean(netD_real ** 2))  # drift term keeps critic scores near 0
            )

        opt_D.zero_grad()
        scaler_D.scale(loss_D).backward()
        scaler_D.step(opt_D)
        scaler_D.update()
        d_loss = loss_D.item()
        d_loss_lst.append(d_loss)

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        # This backward also leaves gradients on netD's parameters; they are
        # cleared by opt_D.zero_grad() before the next critic step.
        with torch.cuda.amp.autocast():
            gen_fake = netD(fake, alpha, step)
            loss_G = -torch.mean(gen_fake)

        opt_G.zero_grad()
        scaler_G.scale(loss_G).backward()
        scaler_G.step(opt_G)
        scaler_G.update()
        g_loss_lst.append(loss_G.item())

        # Update alpha and ensure less than 1
        alpha += cur_batch_size / ((PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset))
        alpha = min(alpha, 1)

        if batch_idx % 500 == 0:
            # Periodic logging hook (e.g. preview samples from FIXED_NOISE);
            # only the counter is advanced here.
            tensorboard_step += 1

        loop.set_postfix(
            gp=gp.item(),
            loss_D=d_loss,
        )

    # newline='' is required by the csv module; without it Windows gets a
    # blank line after every row.
    with open(loss_log_path, 'a', newline='') as log_file:
        writer = csv.writer(log_file)
        writer.writerow(d_loss_lst)
        writer.writerow(g_loss_lst)
    return tensorboard_step, alpha

In [11]:
netG.train()
netD.train()
tensorboard_step = 0

# Resolution stage to start from (4x4 -> 0, 8x8 -> 1, ..., 256x256 -> 6).
step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))

# NOTE: with SAVE_MODEL on, every epoch overwrites CHECKPOINT_G / CHECKPOINT_D,
# which are also the files LOAD_MODEL resumes from.
for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    img_size = 4 * 2 ** step
    # Each stage starts almost entirely on the faded-in (previous-resolution) path.
    alpha = 1e-5
    loader, dataset = get_loader(img_size)
    print(f"Current image size: {img_size}")

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        tensorboard_step, alpha = train_fn(
            netD, netG, loader, dataset, step, alpha,
            opt_D, opt_G, tensorboard_step, scaler_G, scaler_D,
        )
        if SAVE_MODEL:
            save_checkpoint(netG, opt_G, filename=CHECKPOINT_G)
            save_checkpoint(netD, opt_D, filename=CHECKPOINT_D)

    step += 1  # progress to the next image size

Current image size: 256
Epoch [1/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:40<00:00,  2.43it/s, gp=0.00377, loss_D=-34]


=> Saving checkpoint
=> Saving checkpoint
Epoch [2/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:34<00:00,  2.49it/s, gp=0.0505, loss_D=-4.86]


=> Saving checkpoint
=> Saving checkpoint
Epoch [3/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0792, loss_D=-17.2]


=> Saving checkpoint
=> Saving checkpoint
Epoch [4/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.096, loss_D=-14.7]


=> Saving checkpoint
=> Saving checkpoint
Epoch [5/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0231, loss_D=12.3]


=> Saving checkpoint
=> Saving checkpoint
Epoch [6/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0662, loss_D=-6.08]


=> Saving checkpoint
=> Saving checkpoint
Epoch [7/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.104, loss_D=-.607]


=> Saving checkpoint
=> Saving checkpoint
Epoch [8/100]


100%|█████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0778, loss_D=7.3]


=> Saving checkpoint
=> Saving checkpoint
Epoch [9/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.122, loss_D=-3.47]


=> Saving checkpoint
=> Saving checkpoint
Epoch [10/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0451, loss_D=-6.6]


=> Saving checkpoint
=> Saving checkpoint
Epoch [11/100]


100%|███████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.11, loss_D=-17]


=> Saving checkpoint
=> Saving checkpoint
Epoch [12/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.102, loss_D=-14.7]


=> Saving checkpoint
=> Saving checkpoint
Epoch [13/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0217, loss_D=-.974]


=> Saving checkpoint
=> Saving checkpoint
Epoch [14/100]


100%|█████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.064, loss_D=2.57]


=> Saving checkpoint
=> Saving checkpoint
Epoch [15/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0458, loss_D=3.25]


=> Saving checkpoint
=> Saving checkpoint
Epoch [16/100]


100%|██████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.00336, loss_D=-.618]


=> Saving checkpoint
=> Saving checkpoint
Epoch [17/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0925, loss_D=11.7]


=> Saving checkpoint
=> Saving checkpoint
Epoch [18/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0521, loss_D=-17.6]


=> Saving checkpoint
=> Saving checkpoint
Epoch [19/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.159, loss_D=-25.2]


=> Saving checkpoint
=> Saving checkpoint
Epoch [20/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0388, loss_D=-5.67]


=> Saving checkpoint
=> Saving checkpoint
Epoch [21/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0134, loss_D=1.49]


=> Saving checkpoint
=> Saving checkpoint
Epoch [22/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0144, loss_D=-1.66]


=> Saving checkpoint
=> Saving checkpoint
Epoch [23/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0324, loss_D=6.57]


=> Saving checkpoint
=> Saving checkpoint
Epoch [24/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0413, loss_D=11.3]


=> Saving checkpoint
=> Saving checkpoint
Epoch [25/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0107, loss_D=0.153]


=> Saving checkpoint
=> Saving checkpoint
Epoch [26/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0145, loss_D=-11.1]


=> Saving checkpoint
=> Saving checkpoint
Epoch [27/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.164, loss_D=-13.9]


=> Saving checkpoint
=> Saving checkpoint
Epoch [28/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0602, loss_D=-3.64]


=> Saving checkpoint
=> Saving checkpoint
Epoch [29/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0319, loss_D=-6.18]


=> Saving checkpoint
=> Saving checkpoint
Epoch [30/100]


100%|█████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.108, loss_D=8.71]


=> Saving checkpoint
=> Saving checkpoint
Epoch [31/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0105, loss_D=1.12]


=> Saving checkpoint
=> Saving checkpoint
Epoch [32/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0347, loss_D=-1.69]


=> Saving checkpoint
=> Saving checkpoint
Epoch [33/100]


100%|█████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.104, loss_D=4.17]


=> Saving checkpoint
=> Saving checkpoint
Epoch [34/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0224, loss_D=14.3]


=> Saving checkpoint
=> Saving checkpoint
Epoch [35/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0841, loss_D=-2.59]


=> Saving checkpoint
=> Saving checkpoint
Epoch [36/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0214, loss_D=2.26]


=> Saving checkpoint
=> Saving checkpoint
Epoch [37/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0375, loss_D=1.38]


=> Saving checkpoint
=> Saving checkpoint
Epoch [38/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0784, loss_D=-.492]


=> Saving checkpoint
=> Saving checkpoint
Epoch [39/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0157, loss_D=-6.29]


=> Saving checkpoint
=> Saving checkpoint
Epoch [40/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0092, loss_D=-.138]


=> Saving checkpoint
=> Saving checkpoint
Epoch [41/100]


100%|█████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0478, loss_D=-16]


=> Saving checkpoint
=> Saving checkpoint
Epoch [42/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.111, loss_D=-11.7]


=> Saving checkpoint
=> Saving checkpoint
Epoch [43/100]


100%|███████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.129, loss_D=10]


=> Saving checkpoint
=> Saving checkpoint
Epoch [44/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0259, loss_D=4.86]


=> Saving checkpoint
=> Saving checkpoint
Epoch [45/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.122, loss_D=-23.5]


=> Saving checkpoint
=> Saving checkpoint
Epoch [46/100]


100%|██████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.114, loss_D=-12]


=> Saving checkpoint
=> Saving checkpoint
Epoch [47/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0479, loss_D=1.26]


=> Saving checkpoint
=> Saving checkpoint
Epoch [48/100]


100%|█████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0601, loss_D=4.6]


=> Saving checkpoint
=> Saving checkpoint
Epoch [49/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.095, loss_D=-10.7]


=> Saving checkpoint
=> Saving checkpoint
Epoch [50/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.121, loss_D=-26.9]


=> Saving checkpoint
=> Saving checkpoint
Epoch [51/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.151, loss_D=-9.92]


=> Saving checkpoint
=> Saving checkpoint
Epoch [52/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0821, loss_D=-11.5]


=> Saving checkpoint
=> Saving checkpoint
Epoch [53/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0339, loss_D=-24.5]


=> Saving checkpoint
=> Saving checkpoint
Epoch [54/100]


100%|█████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0694, loss_D=-13]


=> Saving checkpoint
=> Saving checkpoint
Epoch [55/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0765, loss_D=-11.5]


=> Saving checkpoint
=> Saving checkpoint
Epoch [56/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0264, loss_D=-14.1]


=> Saving checkpoint
=> Saving checkpoint
Epoch [57/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.091, loss_D=-15.4]


=> Saving checkpoint
=> Saving checkpoint
Epoch [58/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0202, loss_D=-9.27]


=> Saving checkpoint
=> Saving checkpoint
Epoch [59/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:38<00:00,  2.45it/s, gp=0.0343, loss_D=-10.4]


=> Saving checkpoint
=> Saving checkpoint
Epoch [60/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:33<00:00,  2.50it/s, gp=0.0115, loss_D=-5.92]


=> Saving checkpoint
=> Saving checkpoint
Epoch [61/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0611, loss_D=-9.59]


=> Saving checkpoint
=> Saving checkpoint
Epoch [62/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0323, loss_D=-1.39]


=> Saving checkpoint
=> Saving checkpoint
Epoch [63/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0554, loss_D=-3.05]


=> Saving checkpoint
=> Saving checkpoint
Epoch [64/100]


100%|█████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.016, loss_D=4.15]


=> Saving checkpoint
=> Saving checkpoint
Epoch [65/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0991, loss_D=4.32]


=> Saving checkpoint
=> Saving checkpoint
Epoch [66/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0455, loss_D=4.98]


=> Saving checkpoint
=> Saving checkpoint
Epoch [67/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0185, loss_D=0.611]


=> Saving checkpoint
=> Saving checkpoint
Epoch [68/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0816, loss_D=-1.91]


=> Saving checkpoint
=> Saving checkpoint
Epoch [69/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0647, loss_D=-18.2]


=> Saving checkpoint
=> Saving checkpoint
Epoch [70/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0576, loss_D=-12.5]


=> Saving checkpoint
=> Saving checkpoint
Epoch [71/100]


100%|█████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.164, loss_D=1.88]


=> Saving checkpoint
=> Saving checkpoint
Epoch [72/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.176, loss_D=-.807]


=> Saving checkpoint
=> Saving checkpoint
Epoch [73/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0438, loss_D=5.24]


=> Saving checkpoint
=> Saving checkpoint
Epoch [74/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0258, loss_D=2.83]


=> Saving checkpoint
=> Saving checkpoint
Epoch [75/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:35<00:00,  2.49it/s, gp=0.0388, loss_D=1.89]


=> Saving checkpoint
=> Saving checkpoint
Epoch [76/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:34<00:00,  2.49it/s, gp=0.0679, loss_D=8.28]


=> Saving checkpoint
=> Saving checkpoint
Epoch [77/100]


100%|██████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0731, loss_D=-.0798]


=> Saving checkpoint
=> Saving checkpoint
Epoch [78/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0284, loss_D=6.13]


=> Saving checkpoint
=> Saving checkpoint
Epoch [79/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0577, loss_D=-13.5]


=> Saving checkpoint
=> Saving checkpoint
Epoch [80/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0854, loss_D=-6.88]


=> Saving checkpoint
=> Saving checkpoint
Epoch [81/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:35<00:00,  2.48it/s, gp=0.333, loss_D=-3.46]


=> Saving checkpoint
=> Saving checkpoint
Epoch [82/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:37<00:00,  2.46it/s, gp=0.0655, loss_D=7.18]


=> Saving checkpoint
=> Saving checkpoint
Epoch [83/100]


100%|█████████████████████████████████████████████████████| 535/535 [03:37<00:00,  2.45it/s, gp=0.000785, loss_D=-10.9]


=> Saving checkpoint
=> Saving checkpoint
Epoch [84/100]


100%|██████████████████████████████████████████████████████| 535/535 [03:34<00:00,  2.50it/s, gp=0.00661, loss_D=-6.95]


=> Saving checkpoint
=> Saving checkpoint
Epoch [85/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0307, loss_D=-7.72]


=> Saving checkpoint
=> Saving checkpoint
Epoch [86/100]


100%|█████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.129, loss_D=3.93]


=> Saving checkpoint
=> Saving checkpoint
Epoch [87/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0221, loss_D=18.1]


=> Saving checkpoint
=> Saving checkpoint
Epoch [88/100]


100%|█████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.221, loss_D=3.49]


=> Saving checkpoint
=> Saving checkpoint
Epoch [89/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.132, loss_D=-5.33]


=> Saving checkpoint
=> Saving checkpoint
Epoch [90/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.51it/s, gp=0.0488, loss_D=-3.71]


=> Saving checkpoint
=> Saving checkpoint
Epoch [91/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0289, loss_D=-14.5]


=> Saving checkpoint
=> Saving checkpoint
Epoch [92/100]


100%|████████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.114, loss_D=-2.24]


=> Saving checkpoint
=> Saving checkpoint
Epoch [93/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:32<00:00,  2.52it/s, gp=0.0363, loss_D=-14.3]


=> Saving checkpoint
=> Saving checkpoint
Epoch [94/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:37<00:00,  2.46it/s, gp=0.0835, loss_D=-17.2]


=> Saving checkpoint
=> Saving checkpoint
Epoch [95/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:37<00:00,  2.46it/s, gp=0.0181, loss_D=-4.98]


=> Saving checkpoint
=> Saving checkpoint
Epoch [96/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:37<00:00,  2.46it/s, gp=0.0548, loss_D=-5.49]


=> Saving checkpoint
=> Saving checkpoint
Epoch [97/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:37<00:00,  2.45it/s, gp=0.0127, loss_D=-5.34]


=> Saving checkpoint
=> Saving checkpoint
Epoch [98/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:37<00:00,  2.46it/s, gp=0.0377, loss_D=-11.8]


=> Saving checkpoint
=> Saving checkpoint
Epoch [99/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:38<00:00,  2.44it/s, gp=0.0717, loss_D=-12.3]


=> Saving checkpoint
=> Saving checkpoint
Epoch [100/100]


100%|███████████████████████████████████████████████████████| 535/535 [03:37<00:00,  2.46it/s, gp=0.0543, loss_D=-17.3]


=> Saving checkpoint
=> Saving checkpoint
Current image size: 512
Epoch [1/100]


  1%|▌                                                      | 12/1069 [00:11<16:29,  1.07it/s, gp=0.0849, loss_D=-3.83]


KeyboardInterrupt: 

In [None]:
def plot_learning_curve(d_loss, g_loss):
    """Plot the discriminator and generator loss histories on one labelled chart.

    Args:
        d_loss: sequence of loss values drawn in sky blue and labelled
            "Discriminator".
        g_loss: sequence of loss values drawn in gold and labelled "Generator".

    The figure is shown with ``plt.show()``; nothing is returned, so calling
    this as the last line of a cell does not render the figure twice.
    NOTE(review): the x axis is labelled "Epochs"; if the lists are filled per
    batch rather than per epoch, change the label.
    """
    # Explicit Axes interface instead of the pyplot state machine; the figure
    # handle itself is not needed.
    _, ax = plt.subplots()
    ax.plot(d_loss, color='skyblue', label='Discriminator')
    ax.plot(g_loss, color='gold', label='Generator')
    # Without a legend the two curves can only be told apart by remembering colours.
    ax.legend()
    ax.set_title('Model Learning Curve')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    plt.show()

# NOTE(review): this cell does not match the progressive-GAN code used elsewhere in the
# notebook. Latent vectors here are 100-dimensional (Z_DIM is 256), and the generators are
# called as netG(noise) / model(noise) and built as Generator() with no arguments, while
# generate_examples calls gen(noise, alpha, steps). It looks carried over from an earlier
# model; confirm against the Generator definition before running.
# NOTE(review): plot_images and save_images_func are not defined in any cell visible here.
# NOTE(review): hardcoded absolute Windows path; consider a configurable checkpoint dir.

# Save the current generator weights.
torch.save(netG.state_dict(), r'C:\Users\Unreal User\Kaspars\netG.pth')

# Sample a batch of 32 latent vectors and show what netG generates.
# NOTE(review): not wrapped in torch.no_grad(), so autograd records this forward pass.
noise = torch.randn(32, 100, 1, 1, device = device)
generated_samples = netG(noise)
img_plot = np.transpose(generated_samples.detach().cpu(), (0,2,3,1)) # move axis 1 (presumably channels) last for plotting; .detach().cpu() copies the tensor to host memory first
plot_images(img_plot)

# Reload the saved weights into a fresh Generator and sample again in eval mode.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model = Generator().to(device)
model.load_state_dict(torch.load(r'C:\Users\Unreal User\Kaspars\netG.pth'))
model.eval()
size = 32  # number of samples to generate and save
noise = torch.randn(size, 100, 1, 1, device = device)
generated_samples = model(noise)
img_plot = np.transpose(generated_samples.detach().cpu(), (0,2,3,1)) # move axis 1 (presumably channels) last for plotting; .detach().cpu() copies the tensor to host memory first
plot_images(img_plot)
save_images_func(generated_samples, size)

In [12]:
def generate_examples(gen, steps, truncation=0.7, n=100, alpha=1.0):
    """Sample ``n`` images from ``gen`` and save them as ``img_{i}.png``.

    Latent vectors are drawn from a standard normal truncated to
    ``[-truncation, truncation]`` (the "truncation trick"). It is not clear
    that this helps here; sampling from ``torch.randn`` instead would also work.

    Args:
        gen: generator module, called as ``gen(noise, alpha, steps)`` with
            ``noise`` of shape ``(1, Z_DIM, 1, 1)``.
        steps: step index forwarded to ``gen``. It must be a step the
            generator has layers for; in this notebook
            ``generate_examples(netG, 7, n=10)`` raised
            ``IndexError: index 7 is out of range``.
        truncation: symmetric bound for the truncated-normal latent samples.
        n: number of images to generate.
        alpha: value forwarded to ``gen`` as its second argument (presumably
            the progressive-growing fade-in weight; confirm against the
            Generator). Defaults to the previously hard-coded 1.0.

    Side effects:
        Writes ``img_0.png`` ... ``img_{n-1}.png`` to the current working
        directory. ``gen``'s train/eval mode is restored afterwards, even if
        generation raises.
    """
    was_training = gen.training
    gen.eval()
    try:
        # One no_grad context for the whole loop; sampling needs no gradients.
        with torch.no_grad():
            for i in range(n):
                noise = torch.tensor(
                    truncnorm.rvs(-truncation, truncation, size=(1, Z_DIM, 1, 1)),
                    device=device,
                    dtype=torch.float32,
                )
                img = gen(noise, alpha, steps)
                # Map outputs from [-1, 1] to [0, 1] for saving; assumes the
                # generator emits values in [-1, 1] (TODO confirm).
                save_image(img * 0.5 + 0.5, f"img_{i}.png")
    finally:
        # Restore the caller's mode. The old code always forced train mode and
        # skipped the restore entirely when gen raised, leaving gen in eval mode.
        gen.train(was_training)

Kā redzams, sākotnējā mapē neeksistēja neviena bilde.

Programma strādā bez kļūdām; tas, kas redzams augstāk, ir trenēšanas bloks, kuram nācās pielietot KeyboardInterrupt.

Tagad tiks ģenerēti 10 attēli (512x512) ar melanomām.

Attēli nav perfekti, jo tīklam nepieciešama papildu trenēšana, taču tie ir pietiekami skaidri.

Pēc tam es parādīšu sākotnējos attēlus, lai varētu salīdzināt. Sākotnējie attēli ir 1024x1024; tīkls spēj ģenerēt 1024x1024, bet pagaidām nav uz to trenēts, jo tas aizņem ļoti daudz laika.

In [None]:
# Plot the loss histories collected during training.
# NOTE(review): d_loss_lst / g_loss_lst are not defined in any cell visible here; they
# must be filled by the training loop before this cell runs.
plot_learning_curve(d_loss=d_loss_lst, g_loss=g_loss_lst)

In [47]:
# NOTE(review): this call raised "IndexError: index 7 is out of range" (output below), so
# netG has no layer for step 7 and no samples were produced. The training run above was
# interrupted right at the start of the 512 stage. Pick a step netG actually has
# (presumably 6, the 256 stage that START_TRAIN_AT_IMG_SIZE points at); confirm against
# the Generator definition.
generate_examples(netG,7, n=10)

IndexError: index 7 is out of range

In [None]:
# Release memory held by PyTorch's CUDA caching allocator. Tensors that are still
# referenced (e.g. netG, generated_samples) are not freed by this call.
torch.cuda.empty_cache()