In [7]:
import os
import argparse
import torch
import math
import logging
from dataset import MNISTLayout, JSONLayout
from model import GPT, GPTConfig

import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from utils import top_k_logits, sample

# Experiment directory that holds the saved run configuration and the checkpoint.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
path = '/home/weiran/Projects/RvNN-Layout/LayoutTrans/layout_transformer/logs/magazine_0.3K'

# Restore the saved arguments and point them at this run's checkpoint file.
# NOTE(review): torch.load unpickles the file, so only load files you trust.
args = torch.load(f'{path}/conf.pth')
args.load_model = f'{path}/checkpoints/checkpoint.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Evaluation images are written to <log_dir>/<exp>/eval.
log_dir = os.path.join(args.log_dir, args.exp)
eval_dir = os.path.join(log_dir, "eval")
os.makedirs(eval_dir, exist_ok=True)

train_dataset = JSONLayout(args.train_json)

# Vocabulary size and sequence length come from the dataset; the layer/head/
# embedding sizes come from the saved arguments.
mconf = GPTConfig(
    train_dataset.vocab_size,
    train_dataset.max_length,
    n_layer=args.n_layer,
    n_head=args.n_head,
    n_embd=args.n_embd,
)

model = GPT(mconf).to(device)

print(f"loading model from {args.load_model}")
# Kept as the last expression so the notebook shows the key-matching report.
model.load_state_dict(torch.load(args.load_model, map_location=device))

loading model from /home/weiran/Projects/RvNN-Layout/LayoutTrans/layout_transformer/logs/magazine_0.3K/checkpoints/checkpoint.pth


<All keys matched successfully>

In [6]:
def layout_reconstruction(start, stop):
    """Render input layouts and their single-pass reconstructions to ``eval_dir``.

    For every index ``i`` in ``range(start, stop)`` the sample's token sequence
    is passed through the global ``model`` once, the highest-scoring token is
    taken at every position, and two images are saved:
    ``recon_<i>.png`` (the prediction) and ``input_<i>.png`` (the input itself).

    Relies on the notebook globals ``train_dataset``, ``model``, ``device`` and
    ``eval_dir``.

    Args:
        start: First dataset index to process (inclusive).
        stop: Dataset index at which to stop (exclusive).
    """
    # Run inference in eval mode and without autograd bookkeeping; the previous
    # train/eval mode is restored afterwards so callers see no mode change.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i in range(start, stop):
                x = train_dataset[i][0]
                x_cond = x.to(device).unsqueeze(0)

                # Only the logits are needed. Indexing (instead of unpacking a
                # fixed number of values) works however many extra items the
                # model's forward pass returns.
                logits = model(x_cond)[0]

                # Softmax is monotonic, so the arg-max of the logits is the
                # most probable token at each position; shape (1, seq_len).
                y = logits.argmax(dim=-1)

                # Concatenate the first input token with the predictions so the
                # rendered sequence starts with the BOS token.
                layout = torch.cat((x_cond[:, :1], y), dim=1).cpu().numpy()
                layout = train_dataset.render(layout)
                layout.save(os.path.join(eval_dir, f'recon_{i:02d}.png'))

                layout = x_cond.cpu().numpy()
                layout = train_dataset.render(layout)
                layout.save(os.path.join(eval_dir, f'input_{i:02d}.png'))
    finally:
        model.train(was_training)

In [8]:
# Rebuild the dataset from the configured JSON file, then write the
# reconstruction/input image pairs for dataset indices 0 through 9.
train_dataset = JSONLayout(args.train_json)
layout_reconstruction(start=0, stop=10)

In [2]:
def top_k_logits(logits, k):
    """Keep the ``k`` largest logits along the last dimension; mask the rest.

    Every entry strictly smaller than the k-th largest value of its row is
    replaced by ``-inf`` so that a following softmax assigns it zero
    probability. Entries tied with the k-th largest value are kept.

    Note: this definition shadows the ``top_k_logits`` imported from ``utils``
    at the top of the notebook.

    Args:
        logits: Tensor of scores; filtering is applied over the last dimension.
        k: Number of top entries to keep. Values larger than the size of the
            last dimension are clamped, so the logits are returned unfiltered
            instead of raising.

    Returns:
        A new tensor of the same shape as ``logits``; the input is not modified.
    """
    # torch.topk raises if k exceeds the dimension size, so clamp it first.
    k = min(k, logits.size(-1))
    # Smallest of the k kept values, with the last dim retained for broadcasting.
    kth_largest = torch.topk(logits, k).values[..., -1:]
    return logits.masked_fill(logits < kth_largest, float('-inf'))

def sample_dec(model, x, x_tag, steps, temperature=1.0, sample=False, top_k=None):
    """Run the model's decoder once and pick a token for each of the first five positions.

    The index chosen at every position is printed; only the index chosen at the
    last processed position is returned.

    Args:
        model: Object exposing ``get_block_size()`` and ``decoder(...)``, either
            directly or through a ``module`` attribute (wrapper case).
        x: Input passed to the decoder; cropped to the last ``block_size``
            entries along dim 1 when longer.
        x_tag: Unused; kept for interface compatibility.
        steps: Unused; kept for interface compatibility.
        temperature: Divisor applied to the logits before the softmax.
        sample: If true, draw from the distribution; otherwise take the arg-max.
        top_k: If given, restrict each position to its ``top_k`` best logits.

    Returns:
        Tensor holding the index chosen at the last processed position.
    """
    # Unwrap once, so both the block size and the decoder come from the
    # underlying model; a wrapper does not necessarily expose `decoder` itself.
    core = model.module if hasattr(model, "module") else model
    block_size = core.get_block_size()
    x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
    model.eval()
    # Pure inference: no autograd graph is needed for the decoder pass.
    with torch.no_grad():
        logits = core.decoder(x_cond)
    # Only the first five positions are inspected.
    logits = logits[:, :5, :] / temperature

    for i in range(logits.shape[1]):
        l = logits[:, i, :]

        if top_k is not None:
            l = top_k_logits(l, top_k)

        probs = F.softmax(l, dim=-1)

        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)

        print(ix)

    # NOTE(review): `ix` is unbound if the decoder yields zero positions.
    return ix

In [3]:
# Batch of one: the first training sequence, moved to the model's device.
x_cond = train_dataset[0][0].to(device).unsqueeze(0)

# Encode only the first six tokens of that sequence.
x = model.encoder(x_cond[:, 0:6])

# Standard-normal noise shaped like the encoder output.
random_tensor = torch.randn_like(x)

In [50]:
# Display the token ids of the batched first training sequence.
x_cond

tensor([[261, 260,  50,  17, 150, 106, 256,  17, 125, 217,  11, 256,  17, 140,
         104,  20, 256, 129, 140, 104,   7, 256, 129, 147, 104,  23, 256,  17,
         160, 104,  10, 257,  17, 177,   9,   3, 256,  17, 179, 217,   5, 259,
          17, 186, 217,  50, 262, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263, 263,
         263, 263, 263, 263, 263, 263, 263, 263, 263

In [50]:
# First token of the sequence, kept two-dimensional by slicing instead of indexing.
x_cond[:, :1]

tensor([[261]], device='cuda:0')

In [4]:
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """Autoregressively extend the token sequence ``x`` by ``steps`` tokens.

    Takes a conditioning sequence of indices in ``x`` (shape ``(b, t)``),
    predicts the next token and feeds the prediction back into the model each
    time. Every step re-runs the model on the (cropped) sequence, so the cost is
    quadratic in the sequence length, and the context is limited to the model's
    block size.

    Note: this definition shadows the ``sample`` imported from ``utils`` at the
    top of the notebook.

    Args:
        model: Callable model whose first output is the logits of shape
            ``(b, t, vocab)``; must expose ``get_block_size()`` directly or via
            a ``module`` attribute.
        x: Integer tensor of shape ``(b, t)`` with the conditioning tokens.
        steps: Number of tokens to append.
        temperature: Divisor applied to the final-step logits.
        sample: If true, draw the next token from the distribution; otherwise
            take the most likely token.
        top_k: If given, restrict the choice to the ``top_k`` best logits.

    Returns:
        Tensor of shape ``(b, t + steps)``: the input followed by the new tokens.
    """
    block_size = model.module.get_block_size() if hasattr(model, "module") else model.get_block_size()
    model.eval()
    # Generation is inference only; skip building autograd graphs on every step.
    with torch.no_grad():
        for _ in range(steps):
            x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
            # Only the logits are needed. Indexing (instead of unpacking a fixed
            # number of values) works however many extra items the forward pass
            # returns.
            logits = model(x_cond)[0]
            # pluck the logits at the final step and scale by temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop probabilities to only the top k options
            if top_k is not None:
                logits = top_k_logits(logits, top_k)
            # apply softmax to convert to probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution or take the most likely
            if sample:
                ix = torch.multinomial(probs, num_samples=1)
            else:
                _, ix = torch.topk(probs, k=1, dim=-1)
            # append to the sequence and continue
            x = torch.cat((x, ix), dim=1)

    return x

def sample_generation(num, random=False, top_k=None):
    """Generate ``num`` layouts from the start token alone and save them as images.

    Each layout is produced by calling ``sample`` on the single token 261 (the
    first token of the dataset sequences shown in this notebook) for
    ``train_dataset.max_length`` steps, rendered with ``train_dataset.render``
    and written to ``eval_dir``. Relies on the notebook globals ``model``,
    ``device``, ``train_dataset`` and ``eval_dir``.

    Args:
        num: Number of layouts to generate.
        random: Forwarded to ``sample`` as its ``sample`` flag (draw from the
            distribution instead of taking the most likely token).
        top_k: Forwarded to ``sample``; limits the choice to the best ``top_k``
            tokens when given.
    """
    for idx in range(num):
        start_tokens = torch.tensor([[261]]).to(device)
        tokens = sample(
            model,
            start_tokens,
            steps=train_dataset.max_length,
            temperature=1.0,
            sample=random,
            top_k=top_k,
        )
        image = train_dataset.render(tokens.detach().cpu().numpy())
        out_file = os.path.join(eval_dir, f'rendom_gen_{idx}.png')
        image.save(out_file)

In [5]:
# Write 50 generated layouts, drawing each token from its 5 best candidates.
sample_generation(num=50, random=True, top_k=5)

In [70]:
# Greedy generation from the start token alone: no sampling, no top-k filter.
x = torch.tensor([[261]], device=device)
layout = (
    sample(
        model,
        x,
        steps=train_dataset.max_length,
        temperature=1.0,
        sample=False,
        top_k=None,
    )
    .detach()
    .cpu()
    .numpy()
)
layout

array([[261, 257,   3,   3, 190, 254, 258,  20,  20,  41,  41, 258,  67,
         20,  36,  15, 258,  20,  67,  41,  24, 258,  20,  67,  41,  54,
        258,  20, 118,  41,  66, 262,  49,  67,  36,  24, 262,  37, 118,
         41,  28, 258,  20, 160,  41,  28, 262, 113, 186,  41,  58, 258,
         20, 211,  41,  28, 262,  79, 224,  41,  28, 262,  37, 224,  41,
         24, 262,  24, 224,  19,  24, 262, 203, 224,  32,  24, 262,  62,
        224,  32,  15, 262,  24, 224,  41, 224, 262,  32, 224,  32,  24,
        262, 224, 224, 100,  37, 262, 224]])