In [None]:
import json
import os

import pygame
import torch
import torch.optim as optim
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [None]:
class CNN(nn.Module):
    """Single-conv-layer image classifier: conv -> [BatchNorm] -> ReLU -> 2x2 max-pool -> linear.

    Args:
        out_channels: number of feature maps produced by the convolution.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        vocab_size: number of output logits (one per class / token ID).
        linear_size: flattened feature size fed to ``fc1``. It must equal
            ``out_channels * H_pool * W_pool`` for the input image size, otherwise
            ``fc1`` raises a shape-mismatch RuntimeError in ``forward``.
        normalization: if True, an ``nn.BatchNorm2d`` is applied after the convolution.
            (This flag used to be accepted and silently ignored.)

    The convolution has ``in_channels=1``, so inputs must be single-channel
    batches of shape ``(N, 1, H, W)``.
    """

    def __init__(self, out_channels=6, kernel_size=1, stride=1, padding=0, vocab_size=5000, linear_size=5000, normalization=False):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding
        )
        # nn.Identity has no parameters or buffers, so with normalization=False the
        # state_dict keys and the parameter order seen by the optimizer are unchanged.
        self.norm1 = nn.BatchNorm2d(out_channels) if normalization else nn.Identity()
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(linear_size, vocab_size)

    def forward(self, x):
        """Map ``(N, 1, H, W)`` images to ``(N, vocab_size)`` raw (unnormalized) logits."""
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        x = torch.flatten(x, start_dim=1)  # (N, C, H', W') -> (N, C * H' * W')
        return self.fc1(x)


In [None]:


# Canvas / font configuration used when rendering words to images.
# Previous configuration (kept for reference): 100 x 20 px canvas, font size 10.
SCREEN_WIDTH    = 130
SCREEN_HEIGHT   = 25
FONT_SIZE       = 15
FONT_PATH       = "../converttext/noto-sans.regular.ttf"
# 4680 presumably = 6 * (25 // 2) * (130 // 2), i.e. the CNN linear_size for
# out_channels=6 on this canvas when the conv keeps H x W — TODO confirm.

pygame.init()
font_noto_sans_regular = pygame.font.Font(FONT_PATH, FONT_SIZE)


def to_image(text:str, font, id:int=None, noise=False, out_dir:str="./temp_image"):
  """Render ``text`` in black on a white SCREEN_WIDTH x SCREEN_HEIGHT canvas and save it as PNG.

  Args:
    text: word to render (converted with ``str``).
    font: a loaded ``pygame.font.Font``.
    id: value embedded in the filename so different samples do not overwrite each other.
    noise: only selects the ``_noised`` filename suffix; no noise is applied to the
      pixels here.
    out_dir: directory the PNG is written to; created if it does not exist.

  Returns:
    Path of the saved image: ``{out_dir}/word_{id}_{text}_notoSans[_noised].png``, with
    path-unsafe characters in ``text`` replaced by ``_``.
  """
  screen = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
  screen.fill((255, 255, 255))
  img = font.render(str(text), True, (0, 0, 0))
  screen.blit(img, (2, 0))  # 2 px left margin, top-aligned
  # No pygame.event.get() here: drawing happens on an off-screen Surface, so there are no
  # window events to handle, and event.get() can raise "video system not initialized"
  # when no display is available (e.g. headless runs).

  # '/' and '\\' in a word would otherwise be interpreted as sub-directories; the other
  # characters are rejected by Windows file systems.
  unsafe_chars = str.maketrans({ch: "_" for ch in '/\\:*?"<>|\0'})
  safe_text = str(text).translate(unsafe_chars)
  # Same test as before: only False (or 0) selects the clean filename.
  suffix = "" if noise == False else "_noised"
  os.makedirs(out_dir, exist_ok=True)  # pygame.image.save fails if the directory is missing
  filename = f"{out_dir}/word_{str(id)}_{safe_text}_notoSans{suffix}.png"
  pygame.image.save(screen, filename)
  return filename

# Smoke test: render one look-alike word (digits for 'i', Greek alpha for 'a') to check
# that the font loads and the image is written.
image_path = to_image(text="1nd1st1nguishαble", font=font_noto_sans_regular, id=5, noise=False)


In [None]:
# Load the vocabulary dictionaries.
# JSON object keys are always strings, so id_to_word must be indexed with str(index).


def _load_json(path):
    """Read a JSON file as UTF-8 (explicit, so the result doesn't depend on the OS locale)."""
    with open(path, 'r', encoding='utf-8') as fp:
        return json.load(fp)


word_num_dict_test = _load_json('word_num_dict.json')
word_to_id_dict_test = _load_json('word_to_id_dict.json')
id_to_word_dict_test = _load_json('id_to_word_dict.json')

print("test load: length of word_num_dict:", len(word_num_dict_test))
print("test load: length of word_to_id_dict_test:", len(word_to_id_dict_test))
print("test load: length of id_to_word_dict_test:", len(id_to_word_dict_test))
print('note index needs to use str(index)')
print(f"test load: index: [5], word in id_to_word_dict: [{id_to_word_dict_test[str(5)]}], id in word_id_dict: [{word_to_id_dict_test[id_to_word_dict_test['5']]}]")

VOCAB_SIZE = len(word_num_dict_test)

# The dataset one-hot encodes each id into a VOCAB_SIZE-long vector, so every id must fit.
# Fail here with a clear message rather than with an IndexError inside the DataLoader.
# NOTE(review): assumes the id values are ints — confirm against the dictionary builder.
_max_id = max(word_to_id_dict_test.values(), default=-1)
assert _max_id < VOCAB_SIZE, f"word_to_id contains id {_max_id} >= VOCAB_SIZE={VOCAB_SIZE}"


In [None]:
# define dataset

# Converts the PIL grayscale image to a float tensor in [0, 1].
# Despite the name, no mean/std normalization is applied.
transform_norm = transforms.Compose([
    transforms.ToTensor(),
])


class WordImageIDDataset(Dataset):
    """Dataset of rendered word images paired with one-hot token-ID targets.

    Each item is rendered on the fly with ``to_image`` (which writes a PNG to disk)
    and read back as a single-channel ('L') image tensor.

    Args:
        word_to_id_list: indexable sequence of ``(word, token_id)`` pairs,
            e.g. ``list(word_to_id_dict.items())``.
        font: pygame font passed to ``to_image``.
        noise: forwarded to ``to_image``.
        vocab_size: length of the one-hot target vector; when None, the
            notebook-level ``VOCAB_SIZE`` is used (the previous behavior).
    """

    def __init__(self, word_to_id_list, font, noise=False, vocab_size=None):
        self.word_to_id_list = word_to_id_list
        self.font = font
        self.noise = noise
        self.vocab_size = vocab_size

    def __len__(self):
        return len(self.word_to_id_list)

    def __getitem__(self, index):
        '''Return ``{'word', 'image', 'id'}`` for position ``index`` (a list position, not a token ID).'''
        entry = self.word_to_id_list[index]
        output_word = entry[0]
        token_id = entry[1]
        image_path = to_image(
            text=output_word,
            font=self.font,
            id=index,
            noise=self.noise)

        # Image.open keeps the file handle open until the image is closed; the context
        # manager releases it. convert() returns a new, fully loaded image, so it stays
        # valid after the with-block.
        with Image.open(image_path) as raw_img:
            output_img = transform_norm(raw_img.convert('L'))

        vocab_size = VOCAB_SIZE if self.vocab_size is None else self.vocab_size
        output_id_onehot = torch.zeros(vocab_size)
        output_id_onehot[token_id] = 1

        output = {'word'    : output_word,
                  'image'   : output_img,
                  'id'      : output_id_onehot}
        return output

In [None]:

SEED = 0
torch.manual_seed(SEED)  # makes weight init and DataLoader shuffle order reproducible

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CONV_OUT_CHANNELS = 16
# A 3x3 conv with stride 1 / padding 1 keeps the SCREEN_HEIGHT x SCREEN_WIDTH size, and the
# 2x2 max-pool floors each side in half. So fc1's input size follows from the canvas
# (16 * 12 * 65 = 12480 for 130 x 25) instead of being hard-coded.
LINEAR_SIZE = CONV_OUT_CHANNELS * (SCREEN_HEIGHT // 2) * (SCREEN_WIDTH // 2)

model = CNN(out_channels=CONV_OUT_CHANNELS, kernel_size=3, stride=1, padding=1, vocab_size=VOCAB_SIZE, normalization=False, linear_size=LINEAR_SIZE).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
dataset = WordImageIDDataset(word_to_id_list = list(word_to_id_dict_test.items()),
                            font=font_noto_sans_regular,
                            noise=False)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

In [None]:
num_epochs = 10

for epoch in range(num_epochs):
    print(f'epoch: {epoch}')
    model.train()  # explicit, in case an earlier cell switched the model to eval mode
    epoch_loss = []
    for data in dataloader:
        imgs = data['image'].to(device)
        # Targets are one-hot float vectors: CrossEntropyLoss treats them as class
        # probabilities and applies log-softmax itself, so the model must output
        # raw logits (no explicit Softmax).
        ids = data['id'].to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, ids)
        epoch_loss.append(loss.item())
        loss.backward()
        optimizer.step()
    # An empty dataloader would otherwise raise ZeroDivisionError here.
    if epoch_loss:
        print(sum(epoch_loss)/len(epoch_loss))
    else:
        print('no batches in dataloader')

In [None]:
# torch.save on the module pickles the whole object (class reference + weights), so loading
# it later requires the CNN class definition to be importable.
# NOTE(review): "250" in the filename may be a typo for SCREEN_HEIGHT=25 — confirm before
# renaming, since other code may load this exact path.
MODEL_PATH = "./CNN_130_250_15_16_3_1_1.pth"
torch.save(model, MODEL_PATH)