<a href="https://colab.research.google.com/github/asdsadadad/BasicSR/blob/master/Sprite_Generator_v0_6_Breadsticks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Show which GPU (if any) is attached to this runtime.
!nvidia-smi

In [None]:
# Mount Google Drive; the dataset archives and checkpoints used below are read
# from /content/gdrive/MyDrive.
from google.colab import drive
drive.mount('/content/gdrive')

## Install and load requirements
Sections with a green heading are required; sections with a red heading are only needed when training a new model. Don't use "Run all" without first checking what each cell does.

### <font color='#5f2'>Install an optimizer and a requirement for torchvision to save videos</font>

In [None]:
# torch_optimizer supplies the Yogi optimizer used in this notebook; av is the
# backend torchvision needs to write the generation video.
!pip -qq install torch_optimizer
!pip -qq install av

### <font color='#5f2'>Imports</font>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import torch_optimizer as optim
import PIL
from PIL import Image
import random
import math
import gc

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Autograd is disabled globally by default; the training and generation cells
# turn it back on explicitly with torch.set_grad_enabled(True).
torch.set_grad_enabled(False)

def clear_mem():
    """Free unreachable Python objects, then release cached CUDA memory.

    Garbage collection runs first so that tensors kept alive only by
    reference cycles are destroyed before the CUDA caching allocator is asked
    to give back its unused blocks. In the opposite order the memory of those
    tensors would remain in the cache.
    """
    gc.collect()
    torch.cuda.empty_cache()

# Shared converters between PIL images and tensors, reused throughout the notebook.
ToTensor = T.ToTensor()
ToImage  = T.ToPILImage()

def OpenImage(x, resize=None, convert="RGB"):
    """Load an image file as a batched tensor on ``device``.

    Args:
        x: Path (or file object) accepted by ``PIL.Image.open``.
        resize: Optional size passed to ``PIL.Image.resize``; a falsy value
            keeps the original size.
        convert: PIL mode the image is converted to before it is turned into
            a tensor (e.g. ``"RGB"`` or ``"L"``).

    Returns:
        The output of ``ToTensor`` for the image, with a leading batch
        dimension of size 1, moved to ``device``.
    """
    # The context manager closes the underlying file as soon as ``convert``
    # has copied the pixels out, instead of relying on garbage collection.
    with Image.open(x) as img:
        img = img.convert(convert)
    if resize:
        img = img.resize(resize)
    return ToTensor(img).unsqueeze(0).to(device)

In [None]:
def diff_abs(x, y=0.0001):
    """Smooth stand-in for ``abs(x)``: ``sqrt(x^2 + y)``.

    ``y`` keeps the value under the square root strictly positive, so the
    result stays differentiable at ``x == 0``.
    """
    squared = x * x
    return (squared + y).sqrt()

def diff_relu(x, y=0.0001):
    """Smooth stand-in for ``relu(x)``: ``(sqrt(x^2 + y) + x) / 2``.

    With ``y == 0`` this is exactly ``(abs(x) + x) / 2``; a positive ``y``
    rounds off the corner at zero.
    """
    smooth_magnitude = torch.sqrt(x * x + y)
    return 0.5 * (smooth_magnitude + x)

def diff_clamp(x, y=0.0001):
    """Smooth stand-in for clamping ``x`` to ``[0, 1]``.

    Built from two ``diff_relu`` calls: the inner one softly caps values at 1,
    the outer one softly floors the result at 0. ``y`` is the smoothing term
    passed to both.
    """
    excess_below_one = diff_relu(1 - x, y)
    capped_at_one = 1 - excess_below_one
    return diff_relu(capped_at_one, y)

### <font color='#f25'>[Ignore if not training new model]</font> Load data for training

In [None]:
# Unpack the two sprite archives stored on Google Drive into
# pokemon_sprites/all/ and pokemon_sprites/fan/ (relative to the working directory).
#!gdown --id 1in3e4os3xMsGvvmBzhYXoNn6fStIceFk
!7z e /content/gdrive/MyDrive/pokemon_sprites_with_sw99_gray_final.7z -opokemon_sprites/all/
!7z e /content/gdrive/MyDrive/pokemon_fansprites_unified.7z -opokemon_sprites/fan/
# !mkdir pokemon_sprites_gray_alpha/all
# !mv pokemon_sprites_gray_alpha/*.png pokemon_sprites_gray_alpha/all/

In [None]:
# Load the cached sprite tensor and snap every pixel to the nearest of four
# gray levels (0, 1/3, 2/3, 1). map_location lets a tensor that was saved from
# a GPU session load on a CPU-only runtime as well.
pokemon_sprites = torch.load("/content/gdrive/MyDrive/pokemon_sprites_and_swdemo_fan_gray.pt", map_location=device).mul(3).round().div(3).to(device)

In [None]:
import glob

# Read every extracted sprite as a single-channel tensor, in sorted path order,
# and stack them along the batch axis.
sprite_paths = sorted(glob.iglob(r'/content/pokemon_sprites/*/*.png'))
pokemon_sprites = torch.cat([OpenImage(sprite_path, convert="L") for sprite_path in sprite_paths], 0)
print(pokemon_sprites.shape)

# Cache the stacked tensor both locally and on Google Drive.
for cache_path in ('pokemon_sprites_and_swdemo_fan_gray.pt',
                   '/content/gdrive/MyDrive/pokemon_sprites_and_swdemo_fan_gray.pt'):
    torch.save(pokemon_sprites, cache_path)

### <font color='#5f2'>Define model</font>

In [None]:
# Save the model's state dict after training if you want
# torch.save(model.state_dict(), "/content/gdrive/MyDrive/sprite_generator_v06_7.pt")

In [None]:
class PixelNet(torch.nn.Module):
    """Convolutional autoencoder that outputs four-level grayscale images.

    ``encode`` shrinks a 1-channel image by a factor of 8 into a 32-channel
    latent grid (three kernel-2, stride-2 convolutions). ``decode`` normalises
    a latent, refines it, upsamples it by a factor of 8 and maps every pixel
    to a softmax-weighted mix of the four palette gray levels.
    """

    def __init__(self):
        super(PixelNet, self).__init__()

        # Encoder: three non-overlapping 2x2 convolutions, no non-linearity.
        self.enc_layers_0 = nn.Sequential(
                              nn.Conv2d( 1,  8, 2, 2, 0),
                              nn.Conv2d( 8, 16, 2, 2, 0),
                              nn.Conv2d(16, 32, 2, 2, 0),
                        )

        # Despite the attribute name this is an InstanceNorm layer; the name is
        # kept because code outside the class calls ``model.batchnorm``.
        self.batchnorm = nn.InstanceNorm2d(32)

        # Latent-resolution refinement (spatial size is preserved).
        self.lin_layers = nn.Sequential(
                              nn.Conv2d(32, 64, 3, 1, 1),
                              nn.BatchNorm2d(64),
                              nn.GELU(),
                              nn.Conv2d(64, 64, 3, 1, 1),
                              nn.BatchNorm2d(64),
                              nn.GELU(),
                              nn.Conv2d(64, 32, 3, 1, 1),
                              nn.BatchNorm2d(32),
                          )

        # Upsampling head: three x2 transposed convolutions, ending in 4 logits
        # per pixel (one per palette entry).
        self.cnn_layers = nn.Sequential(
                              nn.Conv2d(32, 32, 3, 1, 1),
                              nn.ConvTranspose2d(32, 16, 2, 2),
                              nn.Conv2d(16, 32, 3, 1, 1),
                              nn.BatchNorm2d(32),
                              nn.GELU(),
                              nn.ConvTranspose2d(32, 16, 2, 2),
                              nn.Conv2d(16, 32, 3, 1, 1),
                              nn.BatchNorm2d(32),
                              nn.GELU(),
                              nn.ConvTranspose2d(32, 16, 2, 2),
                              nn.Conv2d(16, 4, 3, 1, 1),
                          )

        # Gray levels as a (4, 1) column. Registered as a non-persistent buffer
        # so it follows the module through .to()/.cuda()/.double() without
        # adding a key to the state dict (existing checkpoints still load).
        self.register_buffer(
            "palette",
            torch.tensor([[0.0, 0.333, 0.667, 1.0]]).T,
            persistent=False,
        )

    def encode(self, x):
        """Return the 32-channel latent grid for a batch of 1-channel images."""
        x = self.enc_layers_0(x)
        return x

    def decode(self, x, y=25, batchnorm=True, dropout_p=0):
        """Turn a latent grid back into an image.

        Args:
            x: Latent tensor with 32 channels.
            y: Multiplier applied to the logits before the softmax; larger
                values push each pixel towards a single palette entry.
            batchnorm: If truthy, pass ``x`` through ``self.batchnorm`` first.
            dropout_p: Alpha-dropout probability. When non-zero, dropout is
                applied at three points regardless of ``self.training``.

        Returns:
            Tensor with one channel and 8x the latent's spatial size, each
            pixel a convex combination of the palette values.
        """
        def dropout(t):
            # Same behaviour as a freshly built nn.AlphaDropout (always in
            # training mode), without constructing a module on every call.
            if dropout_p == 0:
                return t
            return F.alpha_dropout(t, dropout_p, training=True)

        if batchnorm:
            x = self.batchnorm(x)
        # The Sequential containers are sliced so dropout can be inserted
        # between their stages.
        x = self.lin_layers[:3](x)
        x = dropout(x)
        x = self.lin_layers[3:6](x)
        x = self.lin_layers[6:](x)
        x = dropout(x)
        x = self.cnn_layers[0](x)
        x = dropout(x)
        x = self.cnn_layers[1:](x)
        # Channels-last, sharpen, softmax over the 4 logits, then mix the
        # palette: (b, h, w, 4) @ (4, 1) -> (b, h, w, 1) -> (b, 1, h, w).
        x = (x.permute(0,2,3,1).mul(y).softmax(-1) @ self.palette).permute(0,3,1,2)
        return x

    def forward(self, x, y=25):
        """Encode then decode ``x`` with softmax sharpness ``y``."""
        x = self.encode(x)
        x = self.decode(x, y)
        return x

# Build the network and move its parameters to the selected device.
model = PixelNet().to(device)
# model(torch.rand(8,1,56,56).to(device)).shape

### <font color='#5f2'>Load the state dict of the pretrained model</font>
Change the path to whatever it's saved to, if you have one.

In [None]:
# map_location remaps the checkpoint's tensors onto the current device, so a
# checkpoint saved from a GPU session also loads on a CPU-only runtime.
model.load_state_dict(torch.load("/content/gdrive/MyDrive/sprite_generator_v06_6.pt", map_location=device))

<All keys matched successfully>

### <font color='#f25'>[Ignore if not training new model]</font> Training the model

In [None]:
# Yogi optimizer (from torch_optimizer) over all model parameters.
optimizer = optim.Yogi(model.parameters(), lr=0.01, eps=1e-3)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5, verbose=True)

In [None]:
# Training needs gradients; autograd is switched off globally in the imports cell.
torch.set_grad_enabled(True)

# 3x3 kernel whose weights sum to zero (centre -1, edges .2, corners .05), so
# flat regions give a zero response. Shaped (1, 1, 3, 3) for F.conv2d.
laplacian_kernel = torch.tensor([
    [.05, .2, .05],
    [ .2, -1, .2 ],
    [.05, .2, .05]
]).unsqueeze(0).unsqueeze(0).to(device)

def laplacian_filter(x):
    """Apply ``laplacian_kernel`` to every channel of ``x`` independently.

    Args:
        x: Tensor of shape (b, c, h, w).

    Returns:
        Tensor of shape (b, c, h, w) holding the 3x3 filter response, with
        reflect padding used at the image borders.
    """
    b, c, h, w = x.shape
    # Fold channels into the batch axis so the single-channel kernel is
    # applied to each channel separately.
    x = x.reshape(b * c, 1, h, w)
    # One pixel of reflect padding is exactly what a 3x3 kernel needs to keep
    # the spatial size. (Padding by 2 and cropping the result by 1 afterwards
    # gives the same values with extra work.)
    x = F.pad(x, [1, 1, 1, 1], mode='reflect')
    x = F.conv2d(x, laplacian_kernel)
    return x.reshape(b, c, h, w)

def downscale_for_loss(x, downscale, interp=True, blur=True):
    """Shrink ``x`` by an integer factor for the multi-scale loss terms.

    Args:
        x: Tensor of shape (b, c, h, w).
        downscale: Integer factor; the output spatial size is
            ``(h // downscale, w // downscale)``.
        interp: If truthy, resize bilinearly (optionally after a blur);
            otherwise resize with nearest-neighbour sampling.
        blur: Only used when ``interp`` is truthy: apply a Gaussian blur
            before resizing.

    Returns:
        The resized tensor.
    """
    b, c, h, w = x.shape
    size = (h // downscale, w // downscale)
    # Plain truthiness instead of ``== True`` / ``== False``: with the
    # equality checks a value such as ``interp=2`` matched neither branch and
    # the input came back un-resized.
    if interp:
        if blur:
            # Always odd: 3 for a factor of 2, 5 for 4, 9 for 8.
            kernel_size = (downscale // 2) * 2 + 1
            x = TF.gaussian_blur(x, kernel_size, downscale / math.pi)
        return TF.resize(x, size, T.InterpolationMode.BILINEAR)
    return TF.resize(x, size, T.InterpolationMode.NEAREST)

def unfold_tiles(x, size, step):
    """Cut a (b, c, h, w) tensor into square tiles and flatten each tile.

    Returns a tensor of shape (b, n_tiles, size*size), where tiles from all
    channels share the second axis.
    """
    batch, _, _, _ = x.shape
    # Window the height axis: (b, c, n_rows, w, size).
    row_windows = x.unfold(-2, size, step)
    # The width axis is now second-to-last, so the same call windows it:
    # (b, c, n_rows, n_cols, size, size).
    tiles = row_windows.unfold(-2, size, step)
    return tiles.reshape(batch, -1, size * size)

# Training-time augmentation: mirror each sprite horizontally with probability 0.5.
dataset_aug = T.Compose([
    T.RandomHorizontalFlip(0.5),
])

# Number of sprites per optimisation step; leftover sprites that do not fill a
# whole batch are skipped each epoch by the loop below.
batch_size = 64

# Main training loop: 5001 passes over the dataset in freshly shuffled order.
for i in range(5001):
    total_loss = 0.0
    total_loss_steps = 0
    # New random ordering of the sprite indices for this pass.
    random_selection = torch.randperm(pokemon_sprites.shape[0])
    for batch in range(pokemon_sprites.shape[0]//batch_size):
        batch_scaled = batch*batch_size
        # Augment each sprite individually so every sample gets its own random flip.
        x = torch.cat([dataset_aug(pokemon_sprites[random_selection[batch_scaled + b]].unsqueeze(0)) for b in range(batch_size)])

        reconstructions = model.encode(x)

        # reconstructions = reconstructions + torch.randn_like(reconstructions).tanh() * torch.randn((reconstructions.shape[0],1,1,1)).mul(2).pow(2).tanh().mul(0.08).add(0.02).to(device)

        reconstructions = model.decode(reconstructions, 25, dropout_p=0)

        # === Loss Calculations === Begin
        # Reflect-pad both images by 4 pixels before cutting them into tiles.
        recon_pad = TF.pad(reconstructions, 4, padding_mode='reflect')
        x_pad = TF.pad(x, 4, padding_mode='reflect')
        # Squared error, averaged inside each 8x8 tile and summed over tiles and batch.
        loss  = unfold_tiles(recon_pad - x_pad, 8, 8).mul(2).pow(2).mean(-1).sum() * 1.0

        # Same tile-wise error on the Laplacian-filtered difference.
        loss += unfold_tiles(laplacian_filter(recon_pad - x_pad), 8, 8).mul(2).pow(2).mean(-1).sum() * 1.0

        # Quantisation penalty: distance of each output pixel to its nearest
        # multiple of 1/3, once squared and once as a smoothed absolute value.
        loss += unfold_tiles(recon_pad - torch.round(recon_pad * 3).div(3), 8, 8).mul(2).pow(2).mean(-1).sum() * 0.25
        loss += unfold_tiles(recon_pad - torch.round(recon_pad * 3).div(3), 8, 8).pow(2).add(1e-8).pow(0.5).mean(-1).sum() * 0.25

        # Multi-scale terms: tile-wise squared error at 1/2, 1/4 and 1/8 resolution.
        loss += unfold_tiles(downscale_for_loss(recon_pad, 2) - downscale_for_loss(x_pad, 2), 4, 4).mul(2).pow(2).mean(-1).sum() * 0.125
        loss += unfold_tiles(downscale_for_loss(recon_pad, 4) - downscale_for_loss(x_pad, 4), 4, 4).mul(2).pow(2).mean(-1).sum() * 0.125
        loss += unfold_tiles(downscale_for_loss(recon_pad, 8) - downscale_for_loss(x_pad, 8), 2, 2).mul(2).pow(2).mean(-1).sum() * 0.25

        # === Loss Calculations === End

        # The graph was already recorded above, so backward() still populates
        # gradients even though it is called inside no_grad.
        with torch.no_grad():
            total_loss += loss.item()
            total_loss_steps += 1
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    with torch.no_grad():
        # Every 50 passes: print the mean batch loss and show a 3x3 grid of
        # reconstructions from the last batch (indices 0-2, 3-5 and 38-40).
        if i % 50 == 0:
            print(i, total_loss / total_loss_steps)
            previews = torch.cat([
                              torch.cat([reconstructions.clamp(0,1)[ 0+r]   for r in range(3)], -1),
                              torch.cat([reconstructions.clamp(0,1)[ 0+r+3] for r in range(3)], -1),
                              torch.cat([reconstructions.clamp(0,1)[32+r+6] for r in range(3)], -1)
                              ], -2)
            # Resample filter 0 is nearest-neighbour in PIL, which keeps pixels crisp.
            display(ToImage(previews[0]).resize((336,336),0))

In [None]:
#@markdown Old training code
# Re-enable autograd for this (older) variant of the training cell.
torch.set_grad_enabled(True)

def downscale_for_loss(x, downscale, interp=True, blur=True):
    """Shrink ``x`` by an integer factor for the multi-scale loss terms.

    Args:
        x: Tensor of shape (b, c, h, w).
        downscale: Integer factor; the output spatial size is
            ``(h // downscale, w // downscale)``.
        interp: If truthy, resize bilinearly (optionally after a blur);
            otherwise resize with nearest-neighbour sampling.
        blur: Only used when ``interp`` is truthy: apply a Gaussian blur
            before resizing. Defaults to True, which is what this cell has
            always done; the parameter keeps this definition call-compatible
            with the one in the newer training cell, which it replaces once
            this cell has been run.

    Returns:
        The resized tensor.
    """
    b, c, h, w = x.shape
    size = (h // downscale, w // downscale)
    # Plain truthiness instead of ``== True`` / ``== False``: with the
    # equality checks a value such as ``interp=2`` matched neither branch and
    # the input came back un-resized.
    if interp:
        if blur:
            # Always odd: 3 for a factor of 2, 5 for 4, 9 for 8.
            kernel_size = (downscale // 2) * 2 + 1
            x = TF.gaussian_blur(x, kernel_size, downscale / math.pi)
        return TF.resize(x, size, T.InterpolationMode.BILINEAR)
    return TF.resize(x, size, T.InterpolationMode.NEAREST)

def unfold_tiles(x, size, step):
    """Split a (b, c, h, w) tensor into ``size`` x ``size`` tiles taken every
    ``step`` pixels, returning them flattened as (b, n_tiles, size*size)."""
    n_batch, _, _, _ = x.shape
    # First pass windows the height axis; afterwards the width axis sits at
    # position -2, so the second identical call windows the width.
    windows = x.unfold(-2, size, step)
    windows = windows.unfold(-2, size, step)
    flat_tile_len = size * size
    return windows.reshape(n_batch, -1, flat_tile_len)

# Training-time augmentation: mirror each sprite horizontally with probability 0.5.
dataset_aug = T.Compose([
    T.RandomHorizontalFlip(0.5),
])

# The latent shuffling in the loop below reshapes to fixed sizes that assume 64 samples per batch.
batch_size = 64

# Older training loop: like the current one, but it additionally swaps parts
# of the latents between samples and adds noise before decoding.
for i in range(5001):
    total_loss = 0.0
    total_loss_steps = 0
    # for x, y in train_loader:
    random_selection = torch.randperm(pokemon_sprites.shape[0])
    for batch in range(pokemon_sprites.shape[0]//batch_size):
        batch_scaled = batch*batch_size
        x = torch.cat([dataset_aug(pokemon_sprites[random_selection[batch_scaled + b]].unsqueeze(0)) for b in range(batch_size)])
        # x = pokemon_sprites[random_selection[batch_scaled:batch_scaled+batch_size]]
        reconstructions = model.encode(x)

        # === Encoded Latents Shuffling and Scrambling === Begin
        # View the batch as (2, 2, 16, 32, 49): four groups of 16 samples,
        # 32 channels, 7*7 flattened positions (hard-coded for batch_size = 64).
        reconstructions = reconstructions.reshape(2,2,16,32,49)

        chunk_dice_length = 8
        chunk_dice = torch.randperm(49)[:chunk_dice_length]

        # Swap the chosen spatial positions between group [0,0] and group [0,1].
        # The backup must be a real copy taken BEFORE the overwrite: a plain
        # `reconstructions_bak = reconstructions` only aliases the tensor, so
        # the second assignment would write back already-overwritten values.
        reconstructions_bak = reconstructions[0,0,:,:,chunk_dice].clone()
        reconstructions[0,0,:,:,chunk_dice] = reconstructions[0,1,:,:,chunk_dice]
        reconstructions[0,1,:,:,chunk_dice] = reconstructions_bak

        slice_dice_length = 8
        slice_dice = torch.randperm(32)[:slice_dice_length]

        # Same swap for a random subset of channels.
        reconstructions_bak = reconstructions[0,0,:,slice_dice].clone()
        reconstructions[0,0,:,slice_dice] = reconstructions[0,1,:,slice_dice]
        reconstructions[0,1,:,slice_dice] = reconstructions_bak

        reconstructions = reconstructions.reshape(64,32,7,7)
        # === Encoded Latents Shuffling and Scrambling === End

        # Add bounded noise to the latents with a random per-sample strength in [0.02, 0.10).
        reconstructions = reconstructions + torch.randn_like(reconstructions).tanh() * torch.randn((reconstructions.shape[0],1,1,1)).mul(2).pow(2).tanh().mul(0.08).add(0.02).to(device)

        reconstructions = model.decode(reconstructions, 25)

        # === Loss Calculations === Begin
        recon_pad = TF.pad(reconstructions, 4, padding_mode='reflect')
        x_pad = TF.pad(x, 4, padding_mode='reflect')
        # Pixel-level and Laplacian terms use only the second half of the
        # batch (samples 32-63), which the swaps above did not touch.
        loss  = unfold_tiles(recon_pad[32:] - x_pad[32:], 8, 8).mul(2).pow(2).mean(-1).sum() * 1.0

        loss += unfold_tiles(laplacian_filter(recon_pad[32:] - x_pad[32:]), 8, 8).mul(2).pow(2).mean(-1).sum() * 1.0

        # Quantisation penalty on the whole batch: distance to the nearest multiple of 1/3.
        loss += unfold_tiles(reconstructions - torch.round(reconstructions * 3).div(3), 8, 4).mul(2).pow(2).mean(-1).sum() * 0.25

        # Coarse multi-scale terms on the first half of the batch (samples 0-31).
        loss += unfold_tiles(downscale_for_loss(recon_pad[:32], 2) - downscale_for_loss(x_pad[:32], 2), 4, 2).pow(2).add(1e-8).pow(0.50).mean(-1).sum() * 0.25
        loss += unfold_tiles(downscale_for_loss(recon_pad[:32], 4) - downscale_for_loss(x_pad[:32], 4), 4, 2).pow(2).add(1e-8).pow(0.25).mean(-1).sum() * 0.25
        loss += unfold_tiles(downscale_for_loss(recon_pad[:32], 4) - downscale_for_loss(x_pad[:32], 4), 4, 2).pow(2).add(1e-8).pow(0.50).mean(-1).sum() * 0.25
        loss += unfold_tiles(downscale_for_loss(recon_pad[:32], 8) - downscale_for_loss(x_pad[:32], 8), 2, 2).pow(2).add(1e-8).pow(1.00).mean(-1).sum() * 0.50

        # === Loss Calculations === End

        # The graph was recorded above, so backward() still populates
        # gradients even though it is called inside no_grad.
        with torch.no_grad():
            total_loss += loss.item()
            total_loss_steps += 1
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    with torch.no_grad():
        # Every 50 passes: print the mean batch loss and preview nine reconstructions.
        if i % 50 == 0:
            print(i, total_loss / total_loss_steps)
            previews = torch.cat([
                              torch.cat([reconstructions.clamp(0,1)[ 0+r]   for r in range(3)], -1),
                              torch.cat([reconstructions.clamp(0,1)[ 0+r+3] for r in range(3)], -1),
                              torch.cat([reconstructions.clamp(0,1)[32+r+6] for r in range(3)], -1)
                              ], -2)
            display(ToImage(previews[0]).resize((336,336),0))

# CLIP guided generation

### Throw the model into eval() and install CLIP
And load CLIP

In [None]:
# Put the model in evaluation mode; the trailing ';' suppresses the cell's output.
model.eval();

In [None]:
# Install CLIP from its GitHub repository, plus ftfy, regex and tqdm.
!pip install --no-deps git+https://github.com/openai/CLIP.git
!pip install --no-deps ftfy regex tqdm

import clip
# Re-declared so this section can be run without the training cells above.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)

# Load the ViT-B/32 CLIP model and freeze it: eval mode, float32 weights, no
# parameter gradients.
perceptor, preprocess = clip.load("ViT-B/32", jit=False)
perceptor.eval().float().requires_grad_(False);

# Per-channel mean/std normalisation applied to images before they are passed
# to perceptor.encode_image.
CLIP_Normalization = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

### Set the initial image

In [None]:
# Vignettes for init images and masking noise
vig = torch.cat(torch.meshgrid(2*[torch.linspace(-1,1,64)])).reshape(2,64,64).pow(2).sum(0,True).mul(2).sub(0).clamp(0,1).to(device)
vig = TF.crop(vig, 0, 4, 56, 56).unsqueeze(0)
vig_7x7 = torch.cat(torch.meshgrid(2*[torch.linspace(-1,1,9)])).reshape(2,9,9).pow(2).sum(0,True).mul(3).sub(1).clamp(0,1).to(device)
vig_7x7 = 1 - TF.crop(vig_7x7, 0, 1, 7, 7)

blank_image = model.encode(vig - torch.rand_like(vig) * 0.333)
blank_image = model.batchnorm(blank_image) # The encoder's values are huge due to the batchnorm being in the decoder section, but for easier optimization batchnorm it first.
clip_root = blank_image.clone().detach() + torch.randn((1,32,7,7)).to(device).mul(vig_7x7) * 0.25
clip_root = clip_root.requires_grad_(True)
clip_optimizer = optim.Yogi([clip_root], lr=1/8, weight_decay=0.00)

video_frames = None

### Set the prompt

In [None]:
# CLIP text embedding of the prompt, averaged over the batch axis (a single
# prompt here) with the axis kept, giving a (1, dim) target.
prompt = perceptor.encode_text(clip.tokenize(["Floral, the florist Pokemon"]).to(device)).mean(0,True)

Or use an image as a prompt instead

In [None]:
# Load the prompt image at 224x224 and repeat it 8 times along the batch axis.
prompt_img = OpenImage("flickr_dog_000137.jpg", (224,224)).tile(8,1,1,1)
# Add slight Gaussian noise so the 8 copies are not identical.
prompt_img += torch.randn_like(prompt_img) * 0.025
# Average the 8 CLIP image embeddings into a single (1, dim) target.
prompt = perceptor.encode_image(CLIP_Normalization(prompt_img)).mean(0,True)

# Generate images with CLIP guiding the model

In [None]:
# Gradients are needed to optimise the latent; the model itself stays in eval mode.
torch.set_grad_enabled(True)
model.eval()

# Edge-weighted vignette read by white_background_check: 0 around the centre,
# rising to 1 towards the borders; shape (1, 56, 56) after the crop.
vig = torch.cat(torch.meshgrid(2*[torch.linspace(-1,1,64)])).reshape(2,64,64).pow(2).sum(0,True).mul(3).sub(1).clamp(0,1).pow(2).to(device)
vig = TF.crop(vig, 0, 4, 56, 56)

# 3x3 kernel whose weights sum to zero (centre -1, edges .2, corners .05),
# shaped (1, 1, 3, 3) for F.conv2d.
laplacian_kernel = torch.tensor([
    [.05, .2, .05],
    [ .2, -1, .2 ],
    [.05, .2, .05]
]).unsqueeze(0).unsqueeze(0).to(device)

def laplacian_filter(x):
    """Apply ``laplacian_kernel`` to every channel of ``x`` independently.

    Args:
        x: Tensor of shape (b, c, h, w).

    Returns:
        Tensor of shape (b, c, h, w) holding the 3x3 filter response, with
        reflect padding used at the image borders.
    """
    b, c, h, w = x.shape
    # Fold channels into the batch axis so the single-channel kernel is
    # applied to each channel separately.
    x = x.reshape(b * c, 1, h, w)
    # One pixel of reflect padding is exactly what a 3x3 kernel needs to keep
    # the spatial size. (Padding by 2 and cropping the result by 1 afterwards
    # gives the same values with extra work.)
    x = F.pad(x, [1, 1, 1, 1], mode='reflect')
    x = F.conv2d(x, laplacian_kernel)
    return x.reshape(b, c, h, w)

def white_background_check(x):
    """Penalty for dark pixels in the region weighted by ``vig``.

    Inverts the image so that white becomes 0, weights it with the global
    ``vig`` mask, and returns the mean plus the standard deviation of the
    weighted result as a scalar tensor.
    """
    weighted_darkness = (1 - x) * vig
    return weighted_darkness.mean() + weighted_darkness.std()

def sneaky_round(x):
    """Snap ``x`` in place to the nearest multiple of 1/3 and return it.

    The write happens under ``torch.no_grad()``, so autograd does not record
    the rounding. The returned tensor is the same object as the input.
    """
    with torch.no_grad():
        snapped = torch.round(x * 3) / 3
        x[:] = snapped
    return x

# Random view generator applied to the decoded image before it is scored by CLIP.
augments = T.Compose([
        # Upscale to 224 on the short side, optionally stretched to 240 or 256
        # in height or width, with either nearest or bilinear sampling.
        T.RandomChoice([
            T.Resize((224,224), T.InterpolationMode.NEAREST),
            T.Resize((224,224), T.InterpolationMode.BILINEAR),
            T.Resize((240,224), T.InterpolationMode.NEAREST),
            T.Resize((240,224), T.InterpolationMode.BILINEAR),
            T.Resize((256,224), T.InterpolationMode.NEAREST),
            T.Resize((256,224), T.InterpolationMode.BILINEAR),
            T.Resize((224,240), T.InterpolationMode.NEAREST),
            T.Resize((224,240), T.InterpolationMode.BILINEAR),
            T.Resize((224,256), T.InterpolationMode.NEAREST),
            T.Resize((224,256), T.InterpolationMode.BILINEAR),
        ]),
        # White border, rotate by up to 15 degrees, then trim 56 px per side
        # (negative padding crops).
        T.Pad(64, fill=1.0),
        T.RandomRotation(15, T.InterpolationMode.BILINEAR),
        T.Pad(-56),
        # One of three Gaussian blur strengths, or no blur.
        T.RandomChoice([
            T.Lambda(lambda x: TF.gaussian_blur(x, 7, 2.55)),
            T.Lambda(lambda x: TF.gaussian_blur(x, 5, 1.27)),
            T.Lambda(lambda x: TF.gaussian_blur(x, 3, 0.64)),
            T.Lambda(lambda x: x),
        ]),
        # Random brightness offset in [-0.1, 0.1) and gain in [0.85, 1.15).
        T.Lambda(lambda x: x + torch.rand(1,).item() * 0.2 - 0.1),
        T.Lambda(lambda x: x * (torch.rand(1,).item() * 0.3 + 0.85)),
        T.RandomCrop((224,224)),
        # Per-pixel Gaussian noise with standard deviation 0.02.
        T.Lambda(lambda x: x + torch.randn_like(x).mul(0.02)),
        # Same mean/std as CLIP_Normalization.
        T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])

# Milder variant of `augments`: no stretching, a 16 px border, rotation of at
# most 2 degrees and a 12 px trim. Defined here but not used by the
# optimisation loop in this cell, which calls `augments`.
gentle_augments = T.Compose([
        T.RandomChoice([
            T.Resize((224,224), T.InterpolationMode.NEAREST),
            T.Resize((224,224), T.InterpolationMode.BILINEAR),
        ]),
        T.Pad(16, fill=1.0),
        T.RandomRotation(2, T.InterpolationMode.BILINEAR),
        T.Pad(-12),
        # One of three Gaussian blur strengths, or no blur.
        T.RandomChoice([
            T.Lambda(lambda x: TF.gaussian_blur(x, 7, 2.55)),
            T.Lambda(lambda x: TF.gaussian_blur(x, 5, 1.27)),
            T.Lambda(lambda x: TF.gaussian_blur(x, 3, 0.64)),
            T.Lambda(lambda x: x),
        ]),
        # Random brightness offset in [-0.1, 0.1) and gain in [0.85, 1.15).
        T.Lambda(lambda x: x + torch.rand(1,).item() * 0.2 - 0.1),
        T.Lambda(lambda x: x * (torch.rand(1,).item() * 0.3 + 0.85)),
        T.RandomCrop((224,224)),
        # Per-pixel Gaussian noise with standard deviation 0.02.
        T.Lambda(lambda x: x + torch.randn_like(x).mul(0.02)),
        T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])

# Optimise the latent `clip_root` for 500 steps so that the decoded image
# matches `prompt` in CLIP space.
for i in range(500):
    x = model.decode(clip_root)
    # CLIP expects three channels; repeat the single gray channel.
    x_color = x.tile(1,3,1,1)

    with torch.no_grad(): #This is just preparing frames for the video. It stores them in ram, not gpu memory, so it shouldn't impact much.
        x_color_big = TF.resize(x_color, (112,112), T.InterpolationMode.NEAREST)
        x_color_big = F.interpolate(x_color_big, (224,224), mode='bicubic', align_corners=False)
        # uint8 frame in (N, H, W, C) layout on the CPU.
        frame = (x_color_big.permute(0,2,3,1).clamp(0,1)*255).byte().cpu()
        # Identity test: once video_frames holds a tensor, `== None` would be
        # a tensor-vs-None comparison rather than a plain None check.
        if video_frames is None:
            video_frames = frame
        else:
            video_frames = torch.cat([video_frames, frame])

    # Score 16 randomly augmented views of the current image against the prompt.
    x_aug = torch.cat([augments(x_color) for _ in range(16)])
    x_enc = perceptor.encode_image(x_aug)
    comparisons = 1.0 - torch.cosine_similarity(prompt, x_enc, -1)
    loss  = torch.mean(comparisons)
    # Regularisers: stay close to the four gray levels, keep the vignette
    # region white, and keep the mean Laplacian response near zero.
    loss += (x - torch.round(x*3).div(3)).pow(2).mean() * 0.25
    loss += white_background_check(x).mul(2).pow(2) * 0.25
    loss += laplacian_filter(x).mean().pow(4)
    # The graph was recorded above, so backward() still populates
    # clip_root.grad even though it is called inside no_grad.
    with torch.no_grad():
        loss.backward()
        # clip_root.grad = TF.gaussian_blur(clip_root.grad, 3) * 0.25 + clip_root.grad * 0.75
        # Normalise the gradient to unit length before the optimizer step.
        clip_root.grad /= clip_root.grad.norm().add(1e-8)
        clip_optimizer.step()
        clip_optimizer.zero_grad()
        if i % 100 == 0:
            print(i, loss.item())
            display(ToImage(x.clamp(0,1)[0]).resize((224,224),0))

# Show the final result.
x = model.decode(clip_root)
display(ToImage(x.clamp(0,1)[0]).resize((224,224),0))

Write video of generation process

In [None]:
# Encode the frames collected during generation into an mp4 at 15 fps.
torchvision.io.write_video("lr_100.mp4", video_frames, fps=15, options={'crf': '30'})