In [None]:
from pathlib import Path
import torch

# --- User configuration: replace the placeholder paths before running. ---
# Experiment directory that holds the checkpoint and the tensorboard logs.
ROOT = Path(r'path/to/checkpoint')
# HDF5 dataset file handed to IFFontDataset.
H5_PATH = r'path/to/if_fonts.h5'
# Directory that contains the font files.
FONT_DIR = Path(r'path/to/fonts')
# Font used for glyph creation, stored as a POSIX-style string path.
BASE_FONT = FONT_DIR.joinpath('choose/any/one.ttf').as_posix()
# Checkpoint to restore and the log directory whose *.yaml files hold the config.
CKPT_PATH = ROOT.joinpath('ckpt/last.ckpt')
LOG_PATH = ROOT.joinpath('tb_logs/version_0')
BATCH_SIZE = 4
# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
import sys
import random

from PIL import ImageFont
from torch.utils.data import DataLoader
from omegaconf import OmegaConf
from tqdm import trange as tqdm_range

# Put the current working directory on the import path so the project-local
# packages (`util`, `data`) resolve; this must run before the imports below.
sys.path.append('.')
from util import utils
# Seed is fixed right after `utils` is available and before the remaining
# project modules are imported (presumably seeds the RNGs used later, e.g. by
# `random.shuffle` -- TODO confirm what `setup_seed` covers).
utils.setup_seed(23)
from util import importtool
from data import cn, valid_characters
from data.datasets_h5 import IFFontDataset
from data.adapter import pil_to_tensor

In [None]:
# Load every YAML file found under the log directory and merge them into a
# single config (later files in the rglob order take precedence on conflicts).
config = OmegaConf.merge(*(OmegaConf.load(yaml_path) for yaml_path in LOG_PATH.rglob('*.yaml')))

# Re-apply by hand the value links noted in the original as
# `parser.link_arguments` (presumably set up by the training CLI -- confirm):
# the GPT embedding width is copied into the MoCo wrapper and the IDS encoder.
model_args = config.model.init_args
n_embd = model_args.gpt.init_args.config.init_args.n_embd
model_args.moco_wrapper.init_args.c_out = n_embd
model_args.ids_enc.init_args.n_embd = n_embd

# Build the model from the config, restore the weights, and switch to
# inference mode on the selected device.
model = importtool.instantiate_from_config_recursively(config.model)
model.init_from_ckpt(CKPT_PATH)
model = model.eval().to(DEVICE)
model.on_predict_start()

# Shortcuts used by the inference helpers below.
ids_enc = model.ids_encoder
counter = utils.counter(0)

In [None]:
def normalize(x: torch.Tensor, d_range=(-1, 1)):
  """Clamp ``x`` into ``d_range`` and rescale it linearly onto [0, 1].

  Args:
    x: input tensor.
    d_range: ``(low, high)`` bounds; ``low`` maps to 0 and ``high`` maps to 1.

  Returns:
    A new tensor with the same shape as ``x``.
  """
  lo, hi = d_range
  span = hi - lo
  return (x.clamp(lo, hi) - lo) / span


def sample(*input_tuple, use_tqdm=True, sample=False, top_k=None):
  """Run the model's autoregressive sampler and decode the result to images.

  Args:
    *input_tuple: positional inputs forwarded unchanged to ``model.sample``.
      The second entry must be a tensor with at least three dimensions; the
      size of its dim 2 is used as the number of generation steps.
    use_tqdm: show a progress bar over the steps (``tqdm.trange``) instead of
      iterating with the plain ``range``.
    sample: forwarded to ``model.sample``; the caller in this file describes
      the default (``False``) as deterministic max-probability decoding.
    top_k: forwarded to ``model.sample``.

  Returns:
    The decoded tensor, clamped to [-1, 1] and rescaled to [0, 1].
  """
  step_iter = tqdm_range if use_tqdm else range
  n_steps = input_tuple[1].shape[2]
  token_idx = model.sample(
    *input_tuple,
    steps=n_steps,
    temperature=1.0,
    sample=sample,
    top_k=top_k,
    step_range=step_iter,
  )
  decoded: torch.Tensor = model.adapter.decode_raw(token_idx)
  return normalize(decoded, (-1, 1))


def infer_dataset(dataloader):
  """Generate glyphs for the first batch of ``dataloader`` and plot them.

  The plot receives three image batches in this order: the decoded target
  indices, the decoded ``c_idx[:, 0]`` slice (assumed to be the first
  reference glyph -- TODO confirm the axis meaning), and the generated result.

  Args:
    dataloader: iterable whose first batch is accepted by ``model.get_data``.
  """
  first_batch = next(iter(dataloader))
  # The last two items returned by get_data are not needed for sampling.
  x_idx, c_idx, *ch_info = model.get_data(first_batch)[:-2]
  x_idx = x_idx.to(model.device)
  c_idx = c_idx.to(model.device)

  target_img = model.adapter.decode_raw(x_idx)
  ref_img = model.adapter.decode_raw(c_idx[:, 0])

  # `x_idx[:, :0]` is a zero-length slice along dim 1, i.e. the sampler gets
  # no target tokens as a prefix.
  # Stochastic alternative (sample by probability):
  #   sample(x_idx[:, :0], c_idx, *ch_info, sample=True, top_k=100)
  # Default below: max-probability (deterministic) sampling.
  generated = sample(x_idx[:, :0], c_idx, *ch_info)

  panels = torch.stack((target_img, ref_img, generated), dim=0)
  utils.draw_batch_images(*panels)


def infer_create(font_list, ids_list=None, chs=None, ref_chs=None, size=128, save='temp', layout='row'):
  """Create one glyph per font from a character and/or an IDS sequence, then plot it.

  For each font, the target is described by ``chs[i]`` and/or ``ids_list[i]``.
  Up to ``config.data.init_args.num_refs`` reference characters are drawn in
  that font, in shuffled order, and encoded with ``model.adapter.encode``.
  The collated batch is then passed to ``sample`` and the
  (target, reference, generated) images go to ``utils.draw_batch_images``.

  Args:
    font_list: font file paths for ``ImageFont.truetype``; entry ``i`` pairs
      with ``chs[i]`` / ``ids_list[i]``.
    ids_list: per-font IDS sequences. Required when
      ``ids_enc.input_mode == 'ids'``. Symbols not found in
      ``ids_enc.vocabulary_map`` are expanded with ``ids_enc.query_ids``.
    chs: per-font target characters. Required when
      ``ids_enc.input_mode == 'ch'``. When given, the target image is the
      character rendered in the font; otherwise the IDS text is rendered
      with ``BASE_FONT``.
    ref_chs: candidate reference characters; defaults to
      ``valid_characters.train_ch``. A private copy is shuffled per font.
    size: font size and canvas size in pixels.
    save: file stem; the figure path is ``ROOT/create_samples/<save>.png``.
      ``None`` is forwarded unchanged to ``utils.draw_batch_images``.
    layout: ``'row'`` or ``'column'``; ``'column'`` swaps the first two axes
      of the image stack before plotting.

  Raises:
    AssertionError: on an invalid ``layout`` or when the input required by
      ``ids_enc.input_mode`` is missing.
    RuntimeError: when none of the reference candidates renders in a font.
  """
  assert layout in ('row', 'column')
  assert (ids_enc.input_mode=='ch' and chs is not None) or (ids_enc.input_mode=='ids' and ids_list is not None)
  num_refs = config.data.init_args.num_refs  # loop-invariant, looked up once
  data_list = []
  # Copy so that the per-font shuffle never mutates the caller's sequence.
  ref_chs = list(ref_chs or valid_characters.train_ch)

  for i, font in enumerate(font_list):
    f = ImageFont.truetype(font, size=size)
    if chs is not None:
      ch = chs[i]
      x_ids = ch
      x_img = pil_to_tensor(utils.draw_single_char(ch, f, size))
      print('ch@', ch, [''.join(x) for x in ids_enc.query_ids(ch)])

    if ids_list is not None:
      x_ids = ids_list[i]
      if chs is None:
        # No target character available: render the IDS text itself instead.
        x_img = utils.draw_text_img(x_ids, size=50, canvas_size=size, font=BASE_FONT)[0]
      # Label for the log line: the target character if there is a non-empty
      # one, otherwise the raw IDS sequence.
      label = (chs[i] if chs is not None else None) or x_ids
      # Keep symbols known to the vocabulary; expand the others to their first
      # IDS decomposition, then flatten everything into a single sequence.
      x_ids = (tok if tok in ids_enc.vocabulary_map else ids_enc.query_ids(tok)[0] for tok in x_ids)
      x_ids = utils.chain_sequence(*x_ids)
      print('ids@', label, ''.join(x_ids))

    c_idx, c_ids = [], []
    # Image of the last reference that actually rendered. Tracked separately so
    # a trailing missing glyph cannot leave `None` in the batch.
    c_img = None
    random.shuffle(ref_chs)
    for ref_ch in ref_chs:
      if len(c_idx) >= num_refs:
        break
      ref_img = pil_to_tensor(utils.draw_single_char(ref_ch, f, size))
      if ref_img is None:
        # A None result is treated as "glyph missing from this font".
        print(f'character {ref_ch} not in font {font} !')
        continue
      c_img = ref_img
      c_ids.append(ref_ch)
      c_idx.append(model.adapter.encode(ref_img))

    if not c_idx:
      # torch.stack on an empty list fails with an opaque message; be explicit.
      raise RuntimeError(f'no reference character could be rendered with font {font}')
    c_idx = torch.stack(c_idx, dim=0)
    data_list.append({
      'x': x_img,
      'c': c_img,
      'c_idx': c_idx,
      'x_ids': x_ids,
      'c_ids': c_ids,
    })
  d = IFFontDataset.collate(data_list)
  # Same device handling as infer_dataset; a no-op if already on the device.
  c_idx = d['c_idx'].to(model.device)
  c = d.pop('c').to(model.device)
  x = d.pop('x').to(model.device)
  print(f'c_ids = {d["c_ids"]}')
  # `c_idx[:, 0, :0]` is a zero-length slice, i.e. no target tokens as prefix.
  g = sample(c_idx[:, 0, :0], c_idx, d['x_ids'], d['c_ids'])
  imgs = torch.stack((x, c, g), dim=0)
  if save is not None:
    save = (ROOT / f'create_samples/{save}.png')
    # parents=True: also works when ROOT itself does not exist yet.
    save.parent.mkdir(parents=True, exist_ok=True)
  if layout == 'column':
    imgs = imgs.permute(1, 0, 2, 3, 4)
  utils.draw_batch_images(*imgs, n=imgs.shape[1], save=save)

In [None]:
_IDX = 0  # 0-3: ('train', 'train') - ('train', 'val') - ('val', 'train') - ('val', 'val') (font - ch)
_SPLIT = ('train', 'val')
_INFER_MODE = 1

match _INFER_MODE:
  case 1:
    dataset = IFFontDataset(H5_PATH, _SPLIT[_IDX>>1&0b01], _SPLIT[_IDX&0b01], num_refs=config.data.init_args.num_refs)
    dataloader = DataLoader(
      dataset,
      batch_size=BATCH_SIZE,
      shuffle=True,
      collate_fn=IFFontDataset.collate,
      num_workers=0,
      pin_memory=True,
    )
    ids_enc.input_mode = 'ch'
    infer_dataset(dataloader)

  case 2:
    # ids_enc.input_mode = 'ch'
    # ch = '夏色祭茄茫晤'
    ids_enc.input_mode = 'ids'
    ref_chs=None
    ids_map = {
      '俣': '⿰亻吴', '辻': '⿺辶十', '萩': '⿱艹秋', '笹': '⿱𥫗世', '凧': '⿵𠘨巾', '粁': '⿰米千', 
      '込': '⿺辶入', '畑': '⿰火田', '凪': '⿵𠘨止', '雫': '⿱雨下', '丼': '⿴井丶', '畠': '⿱白田', 
    }
    ch = tuple(ids_map.keys())
    ids = tuple(ids_map.values())

    font = (BASE_FONT, )
    infer_create(font * len(ch or ids), ids_list=ids, chs=ch, ref_chs=ref_chs, save=str(counter()))
