In [6]:
import os
import argparse
import torch
import math
import logging
from dataset import MNISTLayout, JSONLayout
from model import GPT, GPTConfig

import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from utils import top_k_logits, sample

# Name of the experiment whose saved configuration and checkpoint are loaded below.
name = "magazine_1K"
# NOTE(review): absolute, machine-specific log directory — adjust before running elsewhere.
path = '/home/weiran/Projects/RvNN-Layout/LayoutTrans/layout_transformer/logs/' + name
# NOTE(review): torch.load relies on pickle by default; only load files from a trusted source.
args = torch.load(path + '/conf.pth')
# Point the loaded configuration at this experiment's checkpoint file.
args.load_model = path + '/checkpoints/checkpoint.pth'
# Use GPU index 3 when CUDA is available, otherwise the CPU.
# NOTE(review): the index is hard-coded; presumably fails on machines with fewer than four GPUs.
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

# All evaluation artefacts are written under <log_dir>/<exp>/eval.
log_dir = os.path.join(args.log_dir, args.exp)
eval_dir = os.path.join(log_dir, "eval")
os.makedirs(eval_dir, exist_ok=True)

# The training set supplies vocab_size and max_length for the model configuration.
train_dataset = JSONLayout(args.train_json)

mconf = GPTConfig(train_dataset.vocab_size, train_dataset.max_length,n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd)

model = GPT(mconf).to(device)

# Restore the trained weights onto `device`; this cell does not switch the model to eval mode.
print(f"loading model from {args.load_model}")
model.load_state_dict(torch.load(args.load_model, map_location=device))

loading model from /home/weiran/Projects/RvNN-Layout/LayoutTrans/layout_transformer/logs/magazine_1K/checkpoints/checkpoint.pth


<All keys matched successfully>

In [7]:
from utils import trim_tokens
from PIL import Image, ImageDraw, ImageOps


def render(self, layout):
    """Draw a token sequence as translucent boxes and list the boxes as rows.

    `self` is the object providing `bos_token`, `eos_token`, `pad_token`,
    `size` and `colors` (callers in this notebook pass a dataset). `layout`
    is an array of tokens; it is flattened, passed through `trim_tokens`, and
    cut into groups of five values (category, x, y, w, h). Trailing tokens
    that do not fill a complete group are dropped.

    Returns:
        (img, layoutText): the PIL image (256x256 white canvas plus a 2px
        border) and a list of `[cat - 255, x1/256, y1/256, x2/256, y2/256]`
        rows, one per drawn box.
    """
    canvas = Image.new('RGB', (256, 256), color=(255, 255, 255))
    painter = ImageDraw.Draw(canvas, 'RGBA')

    tokens = trim_tokens(layout.reshape(-1), self.bos_token, self.eos_token, self.pad_token)
    usable = len(tokens) // 5 * 5
    elements = tokens[:usable].reshape(-1, 5)

    # Scale x/y and w/h to pixel units, then turn (w, h) into the far corner.
    boxes = elements[:, 1:].astype(np.float32)
    boxes[:, [0, 1]] = boxes[:, [0, 1]] / (self.size - 1) * 255
    boxes[:, [2, 3]] = boxes[:, [2, 3]] / self.size * 256
    boxes[:, [2, 3]] = boxes[:, [0, 1]] + boxes[:, [2, 3]]

    layoutText = []
    for element, corners in zip(elements, boxes):
        x1, y1, x2, y2 = corners
        cat = element[0]
        layoutText.append([cat - 255, x1 / 256, y1 / 256, x2 / 256, y2 / 256])

        # Categories without a palette entry are drawn in black.
        color_index = cat - self.size
        if 0 <= color_index < len(self.colors):
            col = self.colors[color_index]
        else:
            col = [0, 0, 0]
        painter.rectangle(
            [x1, y1, x2, y2],
            outline=tuple(col) + (200,),
            fill=tuple(col) + (64,),
            width=2,
        )

    # Add border around image
    canvas = ImageOps.expand(canvas, border=2)
    return canvas, layoutText


def saveTxt(list, path):
    """Write every item of `list` to `path` as one line of space-separated values.

    Each item is itself an iterable; its elements are converted with `str`.
    The file at `path` is overwritten.
    """
    with open(path, 'w') as out:
        out.writelines(" ".join(map(str, row)) + "\n" for row in list)


def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """Autoregressively extend the token sequence `x` by `steps` tokens.

    Takes a conditioning sequence of indices in `x` (of shape (b, t)), predicts
    the next token, appends it, and feeds the extended sequence back into the
    model. Only the last `block_size` tokens are given to the model at each
    step, so the context window is finite.

    Note: this definition replaces the `sample` imported from `utils` at the
    top of the notebook.

    Args:
        model: callable returning `(logits, _)` for a (b, t) index tensor and
            exposing `get_block_size()` (directly or via `.module`).
        x: (b, t) tensor of conditioning token indices.
        steps: number of tokens to append.
        temperature: divisor applied to the final-step logits.
        sample: if True draw from the distribution, otherwise take the most
            likely token.
        top_k: if given, restrict the choice to the `top_k` highest logits.

    Returns:
        Tensor of shape (b, t + steps) with the generated tokens appended.

    Side effects:
        Puts `model` into eval mode and leaves it there.
    """
    block_size = model.module.get_block_size() if hasattr(model, "module") else model.get_block_size()
    model.eval()
    # Generation never back-propagates, so skip autograd bookkeeping for the whole loop.
    with torch.no_grad():
        for _ in range(steps):
            x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
            logits, _ = model(x_cond)
            # pluck the logits at the final step and scale by temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop probabilities to only the top k options
            if top_k is not None:
                logits = top_k_logits(logits, top_k)
            # apply softmax to convert to probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution or take the most likely
            if sample:
                ix = torch.multinomial(probs, num_samples=1)
            else:
                _, ix = torch.topk(probs, k=1, dim=-1)
            # append to the sequence and continue
            x = torch.cat((x, ix), dim=1)

    return x

Layout reconstruction

In [8]:
def layout_reconstruction(start, stop):
    """Reconstruct training layouts `start..stop-1` and save prediction and ground truth.

    For every index `i` the model is run once on the full input sequence of
    `train_dataset[i]`; the most likely token at each position is taken as the
    prediction. Four files per sample are written to `<eval_dir>/recon`:
    `{i}_PRED.png`, `{i}_PRED.txt`, `{i}_GT.png` and `{i}_GT.txt`.

    Args:
        start: first dataset index (inclusive).
        stop: last dataset index (exclusive).

    Side effects:
        Puts the global `model` into eval mode and leaves it there.
    """
    recon_dir = os.path.join(eval_dir, 'recon')
    os.makedirs(recon_dir, exist_ok=True)

    # Inference must not run in training mode: any train-only behaviour the
    # model has (e.g. dropout) would otherwise perturb the reconstruction.
    model.eval()

    for i in range(start, stop):
        x_cond = train_dataset[i][0].to(device).unsqueeze(0)

        # No gradients are needed for evaluation.
        with torch.no_grad():
            logits, _ = model(x_cond)

        # Softmax is monotonic, so the top-1 index of the logits is the top-1
        # index of the probabilities.
        _, y = torch.topk(logits, k=1, dim=-1)

        # y[:, :, 0] has shape (1, T); concatenate x_cond[:, :1] in front of it
        # to add the bos token back to the predicted sequence.
        pred_tokens = torch.cat((x_cond[:, :1], y[:, :, 0]), dim=1).cpu().numpy()
        pred_img, pred_txt = render(train_dataset, pred_tokens)
        pred_img.save(os.path.join(recon_dir, f'{i}_PRED.png'))
        saveTxt(pred_txt, os.path.join(recon_dir, f'{i}_PRED.txt'))

        gt_tokens = x_cond.detach().cpu().numpy()
        gt_img, gt_txt = render(train_dataset, gt_tokens)
        gt_img.save(os.path.join(recon_dir, f'{i}_GT.png'))
        saveTxt(gt_txt, os.path.join(recon_dir, f'{i}_GT.txt'))

# Reconstruct training samples 0..299; results are written under <eval_dir>/recon.
layout_reconstruction(0, 300)

Layout generation

In [9]:
import os
import random as rand
from tqdm import tqdm

# Validation split, built with the training set's max_length.
valid_dataset = JSONLayout(args.val_json, max_length=train_dataset.max_length)

def condition_generation(num, random=False, top_k=None):
    """Generate `num` layouts, each conditioned on the first 6 tokens of a random validation sample.

    Every result is rendered and saved to `<eval_dir>/generation-test` as
    `random_gen_{i}.png` plus `random_gen_{i}.txt`.

    Args:
        num: number of layouts to generate.
        random: forwarded to `sample(..., sample=random)`; True draws tokens
            from the distribution instead of taking the most likely one.
        top_k: forwarded to `sample`.
    """
    gen_dir = os.path.join(eval_dir, 'generation-test')
    os.makedirs(gen_dir, exist_ok=True)

    # Wrap the loop in tqdm to show a progress bar.
    for i in tqdm(range(num), desc='Generating'):
        pick = rand.randint(0, len(valid_dataset) - 1)
        prefix = valid_dataset[pick][0][:6].unsqueeze(0).to(device)
        generated = sample(
            model,
            prefix,
            steps=train_dataset.max_length,
            temperature=1.0,
            sample=random,
            top_k=top_k,
        )
        tokens = generated.detach().cpu().numpy()
        image, rows = render(train_dataset, tokens)
        stem = os.path.join(gen_dir, f'random_gen_{i}')
        image.save(stem + '.png')
        saveTxt(rows, stem + '.txt')

# Generate 300 conditioned layouts, sampling each token from the 5 highest-scoring candidates.
condition_generation(300, random=True, top_k=5)

Generating: 100%|██████████| 300/300 [03:09<00:00,  1.58it/s]


In [10]:
from tqdm import tqdm

def sample_generation(num, random=False, top_k=None):
    """Generate `num` layouts from scratch, starting from the bos token only.

    Every result is rendered and saved to `<eval_dir>/generation` as
    `random_gen_{i}.png` plus `random_gen_{i}.txt`.

    Args:
        num: number of layouts to generate.
        random: forwarded to `sample(..., sample=random)`; True draws tokens
            from the distribution instead of taking the most likely one.
        top_k: forwarded to `sample`.
    """
    gen_dir = os.path.join(eval_dir, 'generation')
    os.makedirs(gen_dir, exist_ok=True)

    # Seed with the dataset's own bos token rather than a hard-coded id, so the
    # cell keeps working when the vocabulary changes. `sample` does not modify
    # its input, so one seed tensor can be shared by all iterations.
    bos = torch.tensor([[train_dataset.bos_token]], dtype=torch.long, device=device)

    for i in tqdm(range(num), desc='Generating'):
        layout = sample(model, bos, steps=train_dataset.max_length, temperature=1.0, sample=random, top_k=top_k).detach().cpu().numpy()
        img, txt = render(train_dataset, layout)
        img.save(os.path.join(gen_dir, f'random_gen_{i}.png'))
        saveTxt(txt, os.path.join(gen_dir, f'random_gen_{i}.txt'))

# Generate 300 unconditioned layouts, sampling each token from the 5 highest-scoring candidates.
sample_generation(300, random=True, top_k=5)

Generating: 100%|██████████| 300/300 [03:56<00:00,  1.27it/s]


Scratch cells (for testing only)

In [2]:
def top_k_logits(logits, k):
    """Keep the `k` largest logits along the last dimension and set the rest to -inf.

    Note: this definition replaces the `top_k_logits` imported from `utils`
    at the top of the notebook.

    Args:
        logits: floating-point tensor; the last dimension is filtered.
        k: number of entries to keep. Values larger than the last dimension
            are clamped, in which case the logits are returned unchanged
            (as a new tensor).

    Returns:
        A new tensor of the same shape; the input is not modified. Entries
        equal to the k-th largest value are all kept, so ties may leave more
        than `k` finite entries.
    """
    # torch.topk raises when k exceeds the dimension size; clamp instead.
    k = min(k, logits.size(-1))
    kth_best = torch.topk(logits, k).values[..., -1:]
    return logits.masked_fill(logits < kth_best, -float('Inf'))

def sample_dec(model, x, x_tag, steps, temperature=1.0, sample=False, top_k=None):
    """Run `model.decoder` once on `x` and pick a token for each of the first 5 positions.

    `x` is cropped to its last `block_size` columns when it is longer than
    that. The picked index for every position is printed; only the index
    picked for the last processed position is returned. `x_tag` and `steps`
    are accepted but not used.

    Args:
        model: object exposing `decoder`, `eval()` and `get_block_size()`
            (the latter directly or via `.module`).
        x: 2-D input handed to `model.decoder`.
        temperature: divisor applied to the decoder output.
        sample: if True draw from the distribution, otherwise take the most
            likely index.
        top_k: if given, filter each position's logits with `top_k_logits`.

    Side effects:
        Puts `model` into eval mode and prints one tensor per position.
    """
    net = model.module if hasattr(model, "module") else model
    block_size = net.get_block_size()
    # Crop the context to the model's block size if needed.
    context = x[:, -block_size:] if x.size(1) > block_size else x
    model.eval()
    scaled = model.decoder(context)[:, :5, :] / temperature

    for position in range(scaled.shape[1]):
        step_logits = scaled[:, position, :]
        if top_k is not None:
            step_logits = top_k_logits(step_logits, top_k)

        probs = F.softmax(step_logits, dim=-1)
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            ix = torch.topk(probs, k=1, dim=-1)[1]

        print(ix)

    return ix

In [3]:
# Input tokens of the first training sample, with a leading batch dimension, on `device`.
x_cond = train_dataset[0][0].unsqueeze(dim=0).to(device)

# Run the model's `encoder` on the first 6 tokens only.
# NOTE(review): `model.encoder` is not used by any other cell of this notebook — confirm the loaded model exposes it.
x = model.encoder(x_cond[:, :6])
# x[0][0]

# Standard-normal noise shaped like the encoder output.
random_tensor = torch.randn_like(x)

In [50]:
# Display the whole token tensor of the first training sample.
x_cond

tensor([[261, 260,  50,  17, 150, 106, 256,  17, 125, 217,  11, 256,  17, 140,
         104,  20, 256, 129, 140, 104,   7, 256, 129, 147, 104,  23, 256,  17,
         160, 104,  10, 257,  17, 177,   9,   3, 256,  17, 179, 217,   5, 259,
          17, 186, 217,  50, 262, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263

In [50]:
# Display only the first token of the sequence, keeping the batch dimension.
x_cond[:, :1]

tensor([[261]], device='cuda:0')

In [42]:
# Seed generation with the single token 261 (the first token of x_cond — presumably the bos token).
x = torch.tensor([[261]]).to(device)
# Sample max_length further tokens, drawing each from the 5 highest-scoring candidates.
layout = sample(model, x, steps=train_dataset.max_length,temperature=1.0, sample=True, top_k=5).detach().cpu().numpy()
layout

array([[261, 258,   3,   3,  36,  20, 258,  92,   3,  36,  15, 258, 131,
          3,  36,  15, 258,  45,  45,  36,  75, 258, 131,   3,  36,  75,
        258,  45,  75,  36,  88, 258,  45,  88,  32,  11, 258, 131,  75,
         36,  88, 258, 131,  75,  36,  32, 257, 131, 173,  36,  32, 258,
          3, 173,  75,  32, 258, 131, 173,  36,  32, 258, 131, 173,  36,
         20, 262, 143, 173,  41,  15, 262,  32, 203,  36,  11, 262,  75,
        203,  41,  32, 262,  11, 203,  36,  32, 257,   3, 211,  11, 262,
        211,  11]])

In [44]:
# Render the sampled tokens; `text` holds one [category, x1, y1, x2, y2] row per box.
img, text = render(train_dataset, layout)
text

[[3, 0.01171875, 0.01171875, 0.15234375, 0.08984375],
 [3, 0.359375, 0.01171875, 0.5, 0.0703125],
 [3, 0.51171875, 0.01171875, 0.65234375, 0.0703125],
 [3, 0.17578125, 0.17578125, 0.31640625, 0.46875],
 [3, 0.51171875, 0.01171875, 0.65234375, 0.3046875],
 [3, 0.17578125, 0.29296875, 0.31640625, 0.63671875],
 [3, 0.17578125, 0.34375, 0.30078125, 0.38671875],
 [3, 0.51171875, 0.29296875, 0.65234375, 0.63671875],
 [3, 0.51171875, 0.29296875, 0.65234375, 0.41796875],
 [2, 0.51171875, 0.67578125, 0.65234375, 0.80078125],
 [3, 0.01171875, 0.67578125, 0.3046875, 0.80078125],
 [3, 0.51171875, 0.67578125, 0.65234375, 0.80078125],
 [3, 0.51171875, 0.67578125, 0.65234375, 0.75390625]]

In [46]:


# saveTxt requires a destination path; write the rows next to the other evaluation outputs.
saveTxt(text, os.path.join(eval_dir, 'test_layout.txt'))

In [None]:
# NOTE(review): this never-executed cell references `self` and `epoch`, neither of which is
# defined anywhere in this notebook, so running it as-is would raise NameError. It looks like
# a fragment pasted from a class method — confirm whether it should be removed.
layouts = self.fixed_x.detach().cpu().numpy()
# Render every layout in the batch.
input_layouts = [self.train_dataset.render(layout) for layout in layouts]
for i, layout in enumerate(input_layouts):
    # NOTE(review): `layout` is already the result of render() from the list above and is
    # passed to render() a second time here — likely unintended, verify.
    layout = self.train_dataset.render(layout)
    layout.save(os.path.join(self.config.samples_dir, f'input_{epoch:02d}_{i:02d}.png'))
