In [3]:
import sys
import numpy as np
import torch
import torchvision.models as models

sys.path.append('./dataset/')
from create_google_fonts_dataset import parse_gf_metadata, save_rendered_glyphs
from classification_dataset import CharClassificationDataset
import ssl
# SECURITY NOTE: this swaps the ssl module's default HTTPS context factory for
# the *unverified* one, so certificate verification is skipped for every HTTPS
# request in this kernel that relies on that default.
# NOTE(review): presumably a workaround for local certificate errors when
# downloading the pretrained weights — confirm, and prefer fixing the CA bundle.
ssl._create_default_https_context = ssl._create_unverified_context

%load_ext autoreload
%autoreload 2

In [4]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

sys.path.append('./dataset/')
from create_google_fonts_dataset import parse_gf_metadata, save_rendered_glyphs
from classification_dataset import CharClassificationDataset

# render fonts
ofl_path = './dataset/fonts/ofl/'
fonts_data = parse_gf_metadata(ofl_path) # google fonts dataframe

# removing blacklisted fonts from the dataframe
blacklist_fonts = ['Kumar One', 'Rubik'] # fonts with broken tables
# Filter with a boolean mask instead of `drop(np.where(mask)[0])`: `drop`
# interprets its argument as index *labels*, whereas `np.where` yields row
# *positions*; the two only coincide when the frame has a default 0..n-1 index.
# The surviving rows keep their original index labels, as they did with `drop`.
is_blacklisted = fonts_data['name'].isin(blacklist_fonts).values
fonts_data = fonts_data.loc[~is_blacklisted].copy()

# construct letter set
capital_alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
problematic_chars = 'OQMWIN' # removing problematic symbols
# keep every capital letter, in alphabetical order, that is not excluded
char_set = ''.join(
    letter for letter in capital_alphabet if letter not in problematic_chars
)

# creating dataset
# Normalize is given a single mean and a single std (one channel); each is the
# average of the three per-channel values listed below.
gray_mean = np.mean([0.485, 0.456, 0.406])
gray_std = np.mean([0.229, 0.224, 0.225])

preprocessing_steps = [
    transforms.Resize((224, 224)),   # fixed 224x224 input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[gray_mean], std=[gray_std]),
]
tfs = transforms.Compose(preprocessing_steps)

root_dir = './dataset/rendered_set/'
dataset_c = CharClassificationDataset(
    fonts_data,
    root_dir,
    char_set,
    transform=tfs,
)

In [5]:
from torch.utils.data import DataLoader, SubsetRandomSampler
batch_size = 64
validation_fraction = .2

# Shuffle all sample indices with a fixed seed so the split is repeatable.
data_size = len(dataset_c)
indices = list(range(data_size))
np.random.seed(4)
np.random.shuffle(indices)

# The first `val_split` shuffled indices form the validation subset,
# everything after them is used for training.
val_split = int(np.floor(validation_fraction * data_size))
val_indices = indices[:val_split]
train_indices = indices[val_split:]

train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

# Both loaders read from the same dataset and differ only in their sampler.
loader_options = dict(batch_size=batch_size, num_workers=1, pin_memory=True)
train_loader = DataLoader(dataset_c, sampler=train_sampler, **loader_options)
val_loader = DataLoader(dataset_c, sampler=val_sampler, **loader_options)

In [6]:
import pytorch_lightning as pl
from torch.nn import functional as F

class Classification_model(pl.LightningModule):
    """Lightning module that trains a backbone network with cross-entropy loss.

    Each batch is unpacked as ``(x, y)``; ``x`` is passed through the backbone
    and the result is compared to ``y`` with ``F.cross_entropy``.  The step and
    epoch-end hooks return plain dicts, with values to be logged placed under
    the ``'log'`` key.

    Args:
        backbone: module applied to the input batch in ``forward``.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x):
        """Return the backbone's output for the input batch ``x``."""
        # Call the module itself rather than ``.forward`` so that any hooks
        # registered on the backbone are executed.
        return self.backbone(x)

    def _batch_loss(self, batch):
        """Compute the cross-entropy loss for a single ``(x, y)`` batch."""
        x, y = batch
        y_hat = self(x)
        return F.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` plus its logging entry."""
        loss = self._batch_loss(batch)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` under ``'val_loss'``."""
        return {'val_loss': self._batch_loss(batch)}

    def validation_epoch_end(self, outputs):
        """Average the per-batch ``'val_loss'`` values collected in ``outputs``."""
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        """Return the test loss for ``batch`` under ``'test_loss'``.

        ``test_epoch_end`` reads this key from every entry of ``outputs``.
        """
        return {'test_loss': self._batch_loss(batch)}

    def test_epoch_end(self, outputs):
        """Average the per-batch ``'test_loss'`` values collected in ``outputs``."""
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        return {'test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        """Return an Adam optimizer over this module's parameters (lr=0.001)."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

In [7]:
# prepare backbone
# Load VGG-11 with pretrained weights and switch off gradients for all of them.
vgg11 = models.vgg11(pretrained=True)
vgg11.requires_grad_(False)

# Replace the first convolution (1 input channel) and the last classifier layer
# (one output per character); newly created layers require gradients by default.
vgg11.features[0] = torch.nn.Conv2d(
    in_channels=1, out_channels=64,
    kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
)
vgg11.classifier[-1] = torch.nn.Linear(
    in_features=4096, out_features=len(char_set), bias=True,
)

# Report which parameters will actually be updated during training.
for name, param in vgg11.named_parameters():
    if param.requires_grad:
        print('parameter', name, 'requires grad')

parameter features.0.weight requires grad
parameter features.0.bias requires grad
parameter classifier.6.weight requires grad
parameter classifier.6.bias requires grad


In [8]:
from pytorch_lightning import Trainer

# Wrap the prepared backbone in the Lightning module and fit it with a
# default-configured Trainer on the train / validation loaders.
model = Classification_model(backbone=vgg11)
trainer = Trainer()
trainer.fit(
    model,
    train_dataloader=train_loader,
    val_dataloaders=val_loader,
)

GPU available: False, used: False
No environment variable for node rank defined. Set as 0.

   | Name                  | Type              | Params
--------------------------------------------------------
0  | backbone              | VGG               | 128 M 
1  | backbone.features     | Sequential        | 9 M   
2  | backbone.features.0   | Conv2d            | 640   
3  | backbone.features.1   | ReLU              | 0     
4  | backbone.features.2   | MaxPool2d         | 0     
5  | backbone.features.3   | Conv2d            | 73 K  
6  | backbone.features.4   | ReLU              | 0     
7  | backbone.features.5   | MaxPool2d         | 0     
8  | backbone.features.6   | Conv2d            | 295 K 
9  | backbone.features.7   | ReLU              | 0     
10 | backbone.features.8   | Conv2d            | 590 K 
11 | backbone.features.9   | ReLU              | 0     
12 | backbone.features.10  | MaxPool2d         | 0     
13 | backbone.features.11  | Conv2d            | 1 M   
14 | backbo

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Detected KeyboardInterrupt, attempting graceful shutdown...





Traceback (most recent call last):
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.7/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


1