In [5]:
import matplotlib.pyplot as plt 
import torch
import numpy as np 
import os

CKP_PATH_TPN = "models/TPN"

# get last checkpoint
epoch_list = os.listdir(CKP_PATH_TPN)
epoch_list_nums = [int(epoch.split("_")[0]) for epoch in epoch_list]
last_epoch = max(epoch_list_nums)
D_CKP_PATH = os.path.join(CKP_PATH_TPN, f"{last_epoch}_D.ckpt")
GD_CKP_PATH = os.path.join(CKP_PATH_TPN, f"{last_epoch}_G_decoder.ckpt")
GE_CKP_PATH = os.path.join(CKP_PATH_TPN, f"{last_epoch}_G_encoder.ckpt")

In [6]:
from skimage.color import lab2rgb, rgb2lab
import warnings

def lab2rgb_1d(in_lab, clip=True):
    warnings.filterwarnings("ignore")
    tmp_rgb = lab2rgb(in_lab[np.newaxis, np.newaxis, :], illuminant='D50').flatten()
    if clip:
        tmp_rgb = np.clip(tmp_rgb, 0, 1)
    return tmp_rgb

In [7]:
D_state_dic = torch.load(D_CKP_PATH)
GD_state_dict = torch.load(GD_CKP_PATH)
GE_state_dict = torch.load(GE_CKP_PATH)

  D_state_dic = torch.load(D_CKP_PATH)
  GD_state_dict = torch.load(GD_CKP_PATH)
  GE_state_dict = torch.load(GE_CKP_PATH)


In [8]:
from model import PCN, TPN

In [9]:
from solver import Solver

In [10]:
# args taken from a sample training
# args taken from a sample training
from munch import Munch

args = {
    'hidden_size': 150,
    'n_layers': 1,
    'always_give_global_hint': 1,
    'add_L': 1,
    'mode': 'train_TPN',
    'dataset': 'bird256',
    'lr': 0.0005,
    'num_epochs': 1000,
    'resume_epoch': 100,
    'batch_size': 32,
    'dropout_p': 0.2,
    'weight_decay': 5e-05,
    'beta1': 0.5,
    'beta2': 0.99,
    'lambda_sL1': 100.0,
    'lambda_KL': 0.5,
    'lambda_GAN': 0.1,
    'text2pal_dir': './models/TPN',
    'pal2color_dir': './models/PCN',
    'train_sample_dir': './samples/train',
    'test_sample_dir': './samples/test',
    'log_interval': 1,
    'sample_interval': 20,
    'save_interval': 50
}

args = Munch(args)

solver_obj = Solver(args=args)

Loading 10183 palette names...
Making text dictionary...
Using pre-trained word embeddings...


In [11]:
solver_obj.G.load_state_dict(GD_state_dict)
solver_obj.D.load_state_dict(D_state_dic)
solver_obj.encoder.load_state_dict(GE_state_dict)

# turn into eval mode
solver_obj.G.eval()
solver_obj.D.eval()
solver_obj.encoder.eval()

EncoderRNN(
  (embed): Embed(
    (embed): Embedding(4646, 300)
  )
  (gru): GRU(300, 150, dropout=0.2)
  (ca_net): CA_NET(
    (fc): Linear(in_features=150, out_features=300, bias=True)
    (relu): ReLU()
  )
)

In [12]:
NUM_GEN = 10
IMSIZE = 256
BATCH_SIZE = 1

In [13]:
solver_obj.input_dict.max_len

11

In [14]:
# create an input text embedding
TEXT_INP = "bakery in the morning"

temp = [0] * 300
for i, word in enumerate(TEXT_INP.split()):
    temp[i] = solver_obj.input_dict.word2index[word]

temp = torch.LongTensor([temp]).to("mps")

print(temp.shape)

# get the text embedding
hidden = solver_obj.encoder.init_hidden(BATCH_SIZE).to("mps")
encoder_outputs, decoder_hidden, mu, logvar = solver_obj.encoder(temp, hidden)
print(encoder_outputs.shape)
print(decoder_hidden.shape)
print(mu.shape)
print(logvar.shape)
print(encoder_outputs.shape)
print(decoder_hidden.squeeze(0).shape)
decoder_hidden.squeeze(0).size()


colors = torch.FloatTensor(1, 15).zero_().to("mps")

decoder_hidden = decoder_hidden.squeeze(0)

for i in range(5):

    palette = torch.FloatTensor(BATCH_SIZE, 3).zero_().to("mps")
    palette, decoder_context, decoder_hidden, _ = solver_obj.G(palette,
                        decoder_hidden,
                        encoder_outputs,
                        1,
                        i)
    colors[:, 3 * i:3 * (i + 1)] = palette

torch.Size([1, 300])
torch.Size([300, 1, 150])
torch.Size([1, 1, 150])
torch.Size([300, 1, 150])
torch.Size([300, 1, 150])
torch.Size([300, 1, 150])
torch.Size([1, 150])


In [15]:
colors

tensor([[ 65.2210, -24.8606,  29.9913,  64.9723, -13.5165,  19.5137,  54.9016,
          17.1296, -10.1052,  50.6721,  26.6768, -11.6823,  51.1452,  13.6597,
         -18.9174]], device='mps:0', grad_fn=<CopySlices>)

In [16]:
fig1, axs1 = plt.subplots(nrows=1, ncols=5)
axs1[0].set_title(TEXT_INP + '  fake')
x = 0

colors = colors.cpu()

for k in range(5):
    lab = np.array([colors.data[x][3 * k],
                    colors.data[x][3 * k + 1],
                    colors.data[x][3 * k + 2]], dtype='float64')
    rgb = lab2rgb_1d(lab)
    axs1[k].imshow([[rgb]])
    axs1[k].axis('off')

fig1.savefig('test_palette_eval.jpg')

In [25]:
max(solver_obj.input_dict.word2index.values())
# solver_obj.input_dict.word2index.keys()
solver_obj.input_dict.word2index['random']
solver_obj.input_dict.index2word[1]

'EOS'