In [None]:
# Print the installed PyTorch version string so the runtime used for this
# notebook is recorded next to its outputs.
import torch
print(torch.__version__)

1.13.0+cu116


In [None]:
!git clone https://github.com/autonomousvision/stylegan_xl.git
!git clone https://github.com/openai/CLIP
!pip install -e ./CLIP
!pip install einops ninja
!pip install timm

Cloning into 'stylegan_xl'...
remote: Enumerating objects: 298, done.[K
remote: Counting objects: 100% (106/106), done.[K
remote: Compressing objects: 100% (53/53), done.[K
remote: Total 298 (delta 75), reused 80 (delta 53), pack-reused 192[K
Receiving objects: 100% (298/298), 13.89 MiB | 26.23 MiB/s, done.
Resolving deltas: 100% (134/134), done.
Cloning into 'CLIP'...
remote: Enumerating objects: 236, done.[K
remote: Total 236 (delta 0), reused 0 (delta 0), pack-reused 236[K
Receiving objects: 100% (236/236), 8.92 MiB | 20.29 MiB/s, done.
Resolving deltas: 100% (122/122), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Obtaining file:///content/CLIP
Collecting ftfy
  Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)
[K     |████████████████████████████████| 53 kB 1.1 MB/s 
Installing collected packages: ftfy, clip
  Running setup.py develop for clip
Successfully installed clip-1.0 ftfy-6.1.1
Looking in indexes: https://pypi

In [None]:
import sys
# Make the two cloned repositories importable (dnnlib / legacy / torch_utils are
# imported further down in this cell).
sys.path.extend(['./CLIP', './stylegan_xl'])

import numpy as np
import torch
import torch.nn as nn
import tensorflow as tf
import pickle
import os

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import PIL
from PIL import Image
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.optim as optim

import dnnlib
import legacy
from torch_utils import gen_utils

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
# The choice is reported on stderr.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print('Using device:', device, file=sys.stderr)

Using device: cpu


In [None]:
# Download the pretrained StyleGAN-XL checkpoint (imagenet1024.pkl) and keep its
# 'G_ema' entry as the generator.
# NOTE(review): legacy.load_network_pkl presumably unpickles the file, so only
# point this at a URL you trust.
network_pkl = "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet1024.pkl"
with dnnlib.util.open_url(network_pkl) as f:
    G = legacy.load_network_pkl(f)['G_ema']

# Inference only: eval mode, gradients disabled, moved to the selected device.
G = G.eval().requires_grad_(False).to(device)

# Label argument that get_latents() forwards to G.mapping; no label is set here.
c = None

In [None]:
def weights_init(m):
    """Re-initialise a layer in place, selected by its class name.

    * Class name contains 'Conv': weights drawn from a normal distribution
      with mean 0.0 and std 0.02.
    * Otherwise, class name contains 'BatchNorm': weights drawn from a normal
      distribution with mean 1.0 and std 0.02, bias set to 0.
    * Any other layer is left untouched.

    Intended to be passed to ``Module.apply``.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

def get_latents(amount, seed, shape = 512):
    """Draw seeded z vectors and map them through the global generator's mapping network.

    Args:
        amount: Number of latent vectors to draw.
        seed: Seed for the NumPy ``RandomState`` used to sample z.
        shape: Dimensionality of each z vector.

    Returns:
        Tuple ``(latents, w)``. ``latents`` is the ``(amount, shape)`` z tensor
        on the notebook-wide ``device``; ``w`` is the result of
        ``G.mapping(latents, c)`` moved to the CPU (the original buffer was also
        a CPU tensor).
    """
    z_np = np.random.RandomState(seed).randn(amount, shape)
    # Follow the notebook's `device` instead of calling .cuda(), which raises on
    # CPU-only runtimes.
    latents = torch.from_numpy(z_np).to(device)
    # One batched call replaces the per-sample loop, and the output shape comes
    # from the network itself rather than a hard-coded (amount, 16, shape) buffer.
    # NOTE(review): assumes G.mapping accepts a batch of z together with `c`
    # (currently None) -- confirm against the StyleGAN-XL mapping network.
    w = G.mapping(latents, c).cpu()
    return latents, w

def img_from_latent(net, latents, img_size, show_img = False, outdir = None):
    """Render every latent with ``net`` and return the images as one tensor batch.

    Each latent is rendered through ``gen_utils.w_to_img``, converted to a PIL
    RGB image, resized to ``img_size`` x ``img_size`` with Lanczos resampling,
    optionally displayed and/or saved, and finally converted to a tensor
    normalised with mean 0.5 and std 0.5 per channel.

    Args:
        net: Generator handed to ``gen_utils.w_to_img``.
        latents: Iterable of latents; one image is produced per element.
        img_size: Side length in pixels of the square output images.
        show_img: When True, display each image with matplotlib.
        outdir: Directory to write ``Image <i>.png`` files into, or None to
            skip saving. The directory must already exist.

    Returns:
        CPU tensor of shape ``(len(latents), 3, img_size, img_size)``.
    """
    # The PIL resize below already yields exactly img_size x img_size, so the
    # pipeline only needs tensor conversion and normalisation.
    to_normalised_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    outputs = torch.empty((len(latents), 3, img_size, img_size))
    for i, latent in enumerate(latents):
        img = gen_utils.w_to_img(net, latent, to_np=True, noise_mode='none')
        # w_to_img is indexed with [0] here: the first image of the returned batch.
        img = Image.fromarray(img[0], 'RGB')
        img = img.resize((img_size, img_size), resample=PIL.Image.LANCZOS)
        if show_img:
            plt.axis('off')
            plt.title("Image "+str(i))
            plt.imshow(img)
            plt.show()
        if outdir is not None:
            # os.path.join works whether or not `outdir` ends with a separator.
            img.save(os.path.join(outdir, "Image %d.png" % i))
        outputs[i] = to_normalised_tensor(img)
    return outputs

def train(netD, netG, num_iters, batch_size, img_size, seed, criterion, optimizer):
    """Train ``netD`` to regress the first w vector of ``netG`` from rendered images.

    Every iteration samples a batch of w latents from ``netG``, renders them to
    ``img_size`` x ``img_size`` images, runs ``netD`` on the images and minimises
    ``criterion(netD(images), w[:, 0])``.

    Args:
        netD: Network being trained; its output is flattened to (batch, features).
        netG: Generator used both to sample w and to render images.
        num_iters: Number of optimisation steps.
        batch_size: Number of samples generated per step.
        img_size: Side length of the images fed to ``netD``.
        seed: Base seed. A separate sub-seed is derived from it for every
            iteration, so runs are repeatable while batches still differ.
            None gives unseeded sampling.
        criterion: Loss taking ``(prediction, target)``.
        optimizer: Optimizer over ``netD``'s parameters.

    Returns:
        List with the loss value (float) of every iteration.
    """
    # The previous code passed `seed=np.random.seed(seed)`. np.random.seed()
    # returns None, so the seed was never forwarded (and the global NumPy RNG
    # was reset on every iteration as a side effect). Forwarding the same seed
    # each step would yield the same batch every time, so draw a fresh sub-seed
    # per iteration from a seeded stream instead.
    # NOTE(review): assumes gen_utils.get_w_from_seed uses `seed` to seed its
    # sampling -- confirm against torch_utils/gen_utils.py.
    seed_stream = np.random.RandomState(seed)

    losses = []
    for i in range(num_iters):
        batch_seed = int(seed_stream.randint(0, 2**31 - 1))
        # Use the generator that was passed in, not the notebook-global G.
        w = gen_utils.get_w_from_seed(netG, batch_size, device, seed=batch_seed)
        w = w.to(device) # TODO: cast to float32?
        x = img_from_latent(netG, w, img_size).to(device)

        netD.zero_grad()

        # flatten(1) keeps the batch dimension even when batch_size == 1, where
        # a bare squeeze() would also drop it and mis-shape the loss.
        output = netD(x).flatten(start_dim=1)

        errD = criterion(output, w[:, 0])
        errD.backward()
        optimizer.step()

        if i % 5 == 0:
            print('[%d/%d]\tLoss_D: %.4f' % (i+1, num_iters, errD.item()))

        losses.append(errD.item())

    return losses

In [None]:
class Discriminator(nn.Module):
    """Strided-convolution network mapping an image to a 512-channel output.

    The stack is four stride-2 4x4 convolutions (each halving the spatial size,
    the last three followed by batch norm) with LeakyReLU(0.2) activations, then
    a final 4x4 convolution without padding that produces 512 channels. For a
    64x64 input the output has shape (batch, 512, 1, 1).

    Args:
        channels: Number of channels of the input image.
        feature_maps: Width of the first convolution; doubled at each of the
            following three stages.
    """

    def __init__(self, channels = 3, feature_maps = 64):
        super().__init__()

        # Stage 0: channels x 64 x 64 -> feature_maps x 32 x 32 (no batch norm).
        layers = [
            nn.Conv2d(channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Stages 1-3: double the width and halve the resolution each time
        # (32 -> 16 -> 8 -> 4 for a 64x64 input).
        width = feature_maps
        for _ in range(3):
            layers.extend([
                nn.Conv2d(width, width * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ])
            width = width * 2

        # Head: (feature_maps*8) x 4 x 4 -> 512 x 1 x 1.
        layers.append(nn.Conv2d(width, 512, 4, 1, 0, bias=False))

        # Layer order is unchanged, so state-dict keys stay main.0 ... main.11.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the convolutional stack."""
        return self.main(input)

In [None]:
# Side length, in pixels, of the square images that train() renders and feeds to the discriminator
img_size = 64

In [None]:
# Sample one batch of w latents and render them at 1024x1024, saving the PNGs to `outdir`.
seed = 42
outdir = '/output/'
os.makedirs(outdir, exist_ok=True)
batch_sz = 32

# np.random.seed() returns None, so it cannot be used as the `seed=` argument
# (that silently passed seed=None). Seed the global RNG as before, and forward
# the actual seed value to get_w_from_seed.
np.random.seed(seed)
w = gen_utils.get_w_from_seed(G, batch_sz, device, seed=seed)
w = w.to(device)
print(w.shape)

outputs = img_from_latent(G, w, 1024, outdir=outdir)

In [None]:
# Build the discriminator, its Adam optimizer and the regression loss.
netD_Adam_W = Discriminator().to(device)
# netD_Adam_W.apply(weights_init)
optimizerD_Adam_W = optim.Adam(netD_Adam_W.parameters(), lr=0.0002)
criterion = nn.MSELoss()

# Resume from the checkpoint written after 1500 iterations.
path = '/content/drive/MyDrive/Internship/training/netD_imagenet1024_w_1500.pt'
# map_location remaps the stored tensors onto the active device; without it a
# checkpoint saved on a GPU cannot be loaded on a CPU-only runtime.
checkpoint = torch.load(path, map_location=device)
netD_Adam_W.load_state_dict(checkpoint['model_state_dict'])
optimizerD_Adam_W.load_state_dict(checkpoint['optimizer_state_dict'])
# NOTE: `iter` shadows the Python builtin; the name is kept because other cells
# in this notebook use it.
iter = checkpoint['iter']
loss = checkpoint['loss']

print(netD_Adam_W)
print(iter)

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
)
1500


In [None]:
# Number of training iterations (train() performs one optimizer step per iteration)
num_iters = 500

# Batch size: number of samples generated from the generator per iteration
batch_size = 32

# Random seed passed to train()
seed = 42

In [None]:
# Continue training the restored discriminator; keeps the per-iteration losses for plotting.
losses_adam_w = train(
    netD=netD_Adam_W,
    netG=G,
    num_iters=num_iters,
    batch_size=batch_size,
    img_size=img_size,
    seed=seed,
    criterion=criterion,
    optimizer=optimizerD_Adam_W,
)

[1/500]	Loss_D: 0.5566
[6/500]	Loss_D: 0.5636
[11/500]	Loss_D: 0.5556
[16/500]	Loss_D: 0.5481
[21/500]	Loss_D: 0.5791
[26/500]	Loss_D: 0.6156
[31/500]	Loss_D: 0.5576
[36/500]	Loss_D: 0.5675
[41/500]	Loss_D: 0.5482
[46/500]	Loss_D: 0.5513
[51/500]	Loss_D: 0.5685
[56/500]	Loss_D: 0.5385
[61/500]	Loss_D: 0.5868
[66/500]	Loss_D: 0.5408
[71/500]	Loss_D: 0.5770
[76/500]	Loss_D: 0.5937
[81/500]	Loss_D: 0.5649
[86/500]	Loss_D: 0.5549
[91/500]	Loss_D: 0.5919
[96/500]	Loss_D: 0.5868
[101/500]	Loss_D: 0.5919
[106/500]	Loss_D: 0.5526
[111/500]	Loss_D: 0.5535
[116/500]	Loss_D: 0.5763
[121/500]	Loss_D: 0.5994
[126/500]	Loss_D: 0.5638
[131/500]	Loss_D: 0.5669
[136/500]	Loss_D: 0.5441
[141/500]	Loss_D: 0.5793
[146/500]	Loss_D: 0.5786
[151/500]	Loss_D: 0.5626
[156/500]	Loss_D: 0.5519
[161/500]	Loss_D: 0.5812
[166/500]	Loss_D: 0.5533
[171/500]	Loss_D: 0.5917
[176/500]	Loss_D: 0.5594
[181/500]	Loss_D: 0.5690
[186/500]	Loss_D: 0.5603
[191/500]	Loss_D: 0.5909
[196/500]	Loss_D: 0.5905
[201/500]	Loss_D: 0.56

In [None]:
# Plot the loss recorded at every training iteration of the run above.
fig, ax = plt.subplots(figsize=(10,5))
ax.set_title("Discriminator Loss During Training")
ax.plot(losses_adam_w, label="Loss")
ax.set_xlabel("Number of epochs")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [None]:
# Write a resumable checkpoint: iteration counter, model weights, optimizer
# state and the most recent loss value.
# NOTE: the iteration count is hard-coded here and in the file name; keep the
# two in sync when changing either.
path = '/content/drive/MyDrive/Internship/training/netD_imagenet1024_w_2000.pt'
iter = 2000
loss = losses_adam_w[-1]

torch.save(
    obj={
        'iter': iter,
        'model_state_dict': netD_Adam_W.state_dict(),
        'optimizer_state_dict': optimizerD_Adam_W.state_dict(),
        'loss': loss,
    },
    f=path,
)