In [1]:
# Run below code to set up the environment
import os
repo_path = 'Handwriting-Transformers'
if not os.path.exists(repo_path) and os.getcwd().split('/')[-1] != repo_path:
    !git clone https://github.com/ankanbhunia/Handwriting-Transformers
if os.getcwd().split('/')[-1] != repo_path:
    %cd Handwriting-Transformers
if not os.path.exists('files'): # Get the model and data files
    %pip install --upgrade --no-cache-dir gdown
    !gdown --id 16g9zgysQnWk7-353_tMig92KsZsrcM6k
    !unzip files.zip && rm files.zip

/home/taha_adeel/Desktop/Sem-7/Deep Learning/Project/Handwriting-Transformers


In [2]:
# Import the libraries
import os
import time
from data.dataset import TextDataset, TextDatasetval
import torch
import cv2
import os
import numpy as np
from models.model import TRGAN
from params import *
from torch import nn
from data.dataset import get_transform
import pickle
from PIL import Image
import tqdm
import shutil

In [3]:
# Load the IAM dataset and export a few word images per writer as .jpg files
num_writers = 32 # Number of writers to use
dataset_path = '../DataSet/IAM/'
iam_data_path = 'files/IAM-32.pickle'
model_path = 'files/iam_model.pth'

os.makedirs(dataset_path, exist_ok=True) # Create the output folder

# Unpickle the IAM dataset.
# NOTE(review): pickle.load can execute arbitrary code - only load a pickle file you trust.
with open(iam_data_path, 'rb') as f:
    data = pickle.load(f)
# The file is only needed for unpickling, so it is closed before the export loop runs.
train_data = data['test']

# Only the first `num_writers` writers are exported (snapshot of the keys, no unused counter)
for writer_id in list(train_data)[:num_writers]:
    train_data[writer_id] = train_data[writer_id][:20] # Keep at most 20 words per writer
    writer_dir = os.path.join(dataset_path, str(writer_id))
    os.makedirs(writer_dir, exist_ok=True)
    for word_id, word in enumerate(train_data[writer_id]):
        word['img'].save(os.path.join(writer_dir, str(word_id)+'.jpg'))

In [4]:
# Demo variables
# os.listdir returns entries in arbitrary order; sorting keeps the order of the
# writer folders (and any index derived from it) stable across runs and machines.
writer_img_paths = sorted(os.listdir(dataset_path))
# To use custom handwriting folders instead, list them relative to dataset_path:
# writer_img_paths = ['../dhruv', '../internet']
text = 'Does htis work properly? Maybe something longer so that we can see how it works. Ask it more.'

num_examples = 15 # Number of words from style writer to be used
batch_size = 8 # Number of results per page (Change in params.py also)
output_path = '../demo_output/'

os.makedirs(output_path, exist_ok=True) # Create the folder for the generated pages

In [5]:
def preprocess_writer_images(image_path, img_ht=32):
    '''Opens folder of images and returns a list of word images resized to (W x img_ht)

    Args:
        image_path: Directory containing the word images of one writer.
        img_ht: Target height in pixels; the width is scaled to keep the aspect ratio.

    Returns:
        A list of (PIL image, file name) tuples ordered by file name. Files that
        OpenCV cannot decode are skipped with a printed notice.
    '''
    image_list = []
    # Sorted so the result does not depend on the arbitrary order of os.listdir
    for image_name in sorted(os.listdir(image_path)):
        image_file = os.path.join(image_path, image_name)
        image = cv2.imread(image_file)
        if image is None:
            # cv2.imread returns None (it does not raise) for files it cannot decode
            print('Skipping unreadable image: ' + image_file)
            continue
        # Very narrow images would otherwise round down to a width of 0, which cv2.resize rejects
        img_wd = max(1, image.shape[1]*img_ht//image.shape[0])
        image = cv2.resize(image, (img_wd, img_ht))
        # OpenCV loads channels as BGR, while PIL interprets a 3-channel array as RGB
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_list.append((Image.fromarray(image), image_name))
    return image_list

def get_word_images(paragraph_img_path):
    '''Placeholder for splitting a paragraph image into a list of word images.

    Not implemented yet: the argument is ignored and None is returned.
    '''
    # TODO: segment the paragraph image into individual word images
    return None

def post_process(result_img, threshold=0.8):
    '''Remove the grey background from generated words.

    Every entry strictly greater than `threshold` is overwritten with 1.
    The input is modified in place and the same object is returned.
    '''
    background = result_img > threshold
    result_img[background] = 1
    return result_img

In [6]:
# Build the pickle file that describes the input handwriting styles
data_path = '../DataSet/demo.pickle'

# One entry per writer folder under 'test', keyed by its position in writer_img_paths
test_dataset = {'test': {}}
for writer_id, writer_img_path in enumerate(writer_img_paths):
    # Each word is stored as a dict; 'label' is the file name returned by preprocess_writer_images
    test_dataset['test'][writer_id] = [
        {'img': word_img, 'label': label}
        for word_img, label in preprocess_writer_images(dataset_path + writer_img_path)
    ]

# Write the dictionary to disk
with open(data_path, 'wb') as f:
    pickle.dump(test_dataset, f)

In [7]:
# Load the model and the dataset
print ('(1) Loading dataset files...')
TextDatasetObjval = TextDatasetval(base_path = data_path, num_examples = num_examples)
datasetval = torch.utils.data.DataLoader(
            TextDatasetObjval,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=True, drop_last=True, # drop_last: an incomplete final batch is skipped
            collate_fn=TextDatasetObjval.collate_fn)

print ('(2) Loading model...')

model = TRGAN()
# Fall back to the CPU when no GPU is available; map_location remaps the checkpoint accordingly
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.netG.load_state_dict(torch.load(model_path, map_location=device))
print (model_path+' : Model loaded Successfully')

print ('(3) Loading text content...')
# split() without an argument collapses runs of whitespace, so repeated, leading or
# trailing spaces in `text` no longer produce empty words.
text_encode =  [j.encode() for j in text.split()]
eval_text_encode, eval_len_text = model.netconverter.encode(text_encode)
# Repeated batch_size times along the first dimension - presumably one copy per batch element
eval_text_encode = eval_text_encode.to(device).repeat(batch_size, 1, 1)

(1) Loading dataset files...
(2) Loading model...




initialize network with N02
initialize network with N02
initialize network with N02




files/iam_model.pth : Model loaded Successfully
(3) Loading text content...


In [8]:
input_handwriting_style_imgs = []
output_imgs = []

# Generate one page image per batch of style examples and write it to output_path
for i,data_val in enumerate(tqdm.tqdm(datasetval)): 
    # Use the notebook's own `device` (the one eval_text_encode was moved to) rather than
    # the DEVICE name that is only available through `from params import *`.
    page_val = model._generate_page(data_val['simg'].to(device), data_val['swids'], eval_text_encode,eval_len_text)
    # post_process pushes light pixels to 1; the page is then scaled by 255 before saving
    cv2.imwrite(os.path.join(output_path, 'image' + str(i) + '.png'), post_process(page_val)*255)
    
print ('\nOutput images saved in : ' + output_path)

100%|██████████| 12/12 [00:08<00:00,  1.41it/s]


Output images saved in : ../demo_output/





In [9]:
# Save the images for fid calculations
# The two returned paths are handed to the FID computation in the following cell.
# NOTE(review): presumably the folders holding the real and the generated images - confirm in models/model.py
real_path, fake_path = model.save_images_for_fid_calculation(datasetval, epoch=None, mode='test')

In [10]:
import pytorch_fid.fid_score as fid

# Compute the FID between the two image folders saved in the previous cell.
# `device` is the device selected when the model was loaded, so this cell also runs on
# CPU-only machines instead of failing on a hard-coded 'cuda'.
fid_score = fid.calculate_fid_given_paths([real_path, fake_path], device=device, dims=2048, batch_size=1, num_workers=8)
print ('FID Score : ' + str(fid_score))

100%|██████████| 10240/10240 [02:12<00:00, 77.39it/s]
100%|██████████| 6144/6144 [01:21<00:00, 75.55it/s]


FID Score : 18.643953081175653
