In [1]:

   
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

class LearnableFourierPositionalEncoding(nn.Module):
    def __init__(self, G: int, M: int, F_dim: int, H_dim: int, D: int, gamma: float):
        """
        Learnable Fourier Features from https://arxiv.org/pdf/2106.02795.pdf (Algorithm 1)
        Implementation of Algorithm 1: Compute the Fourier feature positional encoding of a multi-dimensional position
        Computes the positional encoding of a tensor of shape [N, G, M]
        :param G: positional groups (positions in different groups are independent)
        :param M: each point has a M-dimensional positional values
        :param F_dim: depth of the Fourier feature dimension
        :param H_dim: hidden layer dimension
        :param D: positional encoding dimension
        :param gamma: parameter to initialize Wr
        """
        super().__init__()
        self.G = G
        self.M = M
        self.F_dim = F_dim
        self.H_dim = H_dim
        self.D = D
        self.gamma = gamma

        # Projection matrix on learned lines (used in eq. 2)
        self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)
        # MLP (GeLU(F @ W1 + B1) @ W2 + B2 (eq. 6)
        self.mlp = nn.Sequential(
            nn.Linear(self.F_dim, self.H_dim, bias=True),
            nn.GELU(),
            nn.Linear(self.H_dim, self.D // self.G)
        )

        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)

    def forward(self, x):
        """
        Produce positional encodings from x
        :param x: tensor of shape [N, G, M] that represents N positions where each position is in the shape of [G, M],
                  where G is the positional group and each group has M-dimensional positional values.
                  Positions in different positional groups are independent
        :return: positional encoding for X
        """
        N, G, M = x.shape
        # Step 1. Compute Fourier features (eq. 2)
        projected = self.Wr(x)
        cosines = torch.cos(projected)
        sines = torch.sin(projected)
        F = 1 / np.sqrt(self.F_dim) * torch.cat([cosines, sines], dim=-1)
        # Step 2. Compute projected Fourier features (eq. 6)
        Y = self.mlp(F)
        # Step 3. Reshape to x's shape
        PEx = Y.reshape((N, self.D))
        return PEx

In [2]:
from typing import List
from utils import ce_loss, to_tgt_output, ExpRateRecorder
import datamodule
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import FloatTensor, LongTensor, Tensor
from torch.nn.modules.transformer import TransformerDecoder
from datamodule import vocab

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

class StrokeImgEmbed(nn.Module):
    def __init__(self, output_dim):
        super(StrokeImgEmbed, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(10816, 128)
        self.fc2 = nn.Linear(128, output_dim)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

def generate_square_subsequent_mask(sz, device):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
    
import math
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 200):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])



class StrokesTransformer(nn.Module):
    def __init__(self, 
                stroke_feature_dim: int,
                stroke_pos_feature_dim: int,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 tgt_vocab_size: int,
                 PAD_IDX,
                 device=None,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1,):
        self.device = device
        super(StrokesTransformer, self).__init__()
        self.PAD_IDX = PAD_IDX
        self.transformer = nn.Transformer(
            batch_first=True,
            d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.stroke_feature_dim = stroke_feature_dim
        self.stroke_pos_feature_dim = stroke_pos_feature_dim

        self.stroke_img_cnn = StrokeImgEmbed(stroke_feature_dim)
        self.positional_encoding_add = PositionalEncoding(
            emb_size, dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.tgt_vocab_size = tgt_vocab_size

        self.box_pos_enc = LearnableFourierPositionalEncoding(4, 1, 32, 32, self.stroke_pos_feature_dim, 100)

        self.src_embed = nn.Linear(self.stroke_feature_dim + self.stroke_pos_feature_dim, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)



    def just_encode(self, strokeImgs: FloatTensor, strokePaddingMask: FloatTensor, positions: FloatTensor):
        batch_size, stroke_num, h, w = strokeImgs.size()

        assert batch_size == positions.size(0)
        assert stroke_num == positions.size(1)
        assert positions.size(2) == 4
        # assert tgt.size(2) == self.tgt_vocab_size

        stroke_feat = self.stroke_img_cnn(strokeImgs.view((batch_size * stroke_num, 1, h, w))).view(batch_size, stroke_num, self.stroke_feature_dim)
        pos_feat = self.box_pos_enc(positions.view((batch_size * stroke_num, 4, 1))).view((batch_size, stroke_num, self.stroke_pos_feature_dim))
        words = torch.concat([stroke_feat, pos_feat], dim=2)
        """
        words: [batch_size, stroke_num, stroke_feature_dim + 4]
        """
        src = self.src_embed(words)
        """
        src: [batch_size, stroke_num, emb_size]
        """
        
        memory = self.transformer.encoder(src, mask=None, src_key_padding_mask=strokePaddingMask)
        return memory

    def just_decode(self, memory: Tensor, tgt: Tensor, strokePaddingMask: Tensor):
        tgt_emb = self.positional_encoding_add(self.tgt_tok_emb(tgt))
        tgt_pad_mask = tgt == self.PAD_IDX

        tgt_mask = generate_square_subsequent_mask(tgt.size(1), self.device)

        output = self.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask, memory_mask=None,
                              tgt_key_padding_mask=tgt_pad_mask,
                              memory_key_padding_mask=strokePaddingMask)
        return self.generator(output)

    def forward(self, strokeImgs: FloatTensor, strokePaddingMask: FloatTensor, positions: FloatTensor, tgt: FloatTensor):
        """
        strokeImgs: [batch_size, stroke_num, standard_h, standard_w]
        positions: [batch_size, stroke_num, 4(left, top, right, bottom)]
        tgt: [batch_size, answer_length]
        stroke_padding_mask: [batch_size, stroke_num]
        """

        memory = self.just_encode(strokeImgs, strokePaddingMask, positions)
        return self.just_decode(memory, tgt, strokePaddingMask)

class StrokesTransformerAdapter(pl.LightningModule):
    def __init__(self, 
                stroke_feature_dim: int,
                stroke_pos_feature_dim: int,
                num_encoder_layers: int,
                num_decoder_layers: int,
                emb_size: int,
                nhead: int,
                patience: int,
                learning_rate: float,
                dim_feedforward: int = 512,
                dropout: float = 0.1,):
        super().__init__()
        self.save_hyperparameters()

        self.model = StrokesTransformer(
            stroke_feature_dim = stroke_feature_dim,
            stroke_pos_feature_dim = stroke_pos_feature_dim,
            num_encoder_layers = num_encoder_layers,
            num_decoder_layers = num_decoder_layers,
            emb_size = emb_size,
            nhead = nhead,
            tgt_vocab_size = len(vocab),
            PAD_IDX = vocab.PAD_IDX,
            device=self.device,
            dim_feedforward = dim_feedforward,
            dropout = dropout,
        )


        self.exprate_recorder = ExpRateRecorder()


    def forward(self, strokeImgs: FloatTensor, strokePaddingMask: FloatTensor, positions: FloatTensor, tgt: FloatTensor):
        return self.model(strokeImgs, strokePaddingMask, positions, tgt)

    def configure_optimizers(self):
        optimizer = optim.Adadelta(
            self.parameters(),
            lr=self.hparams.learning_rate,
            eps=1e-6,
            weight_decay=1e-4,
        )

        reduce_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="max",
            factor=0.1,
            patience=self.hparams.patience // self.trainer.check_val_every_n_epoch,
        )
        scheduler = {
            "scheduler": reduce_scheduler,
            "monitor": "val_ExpRate",
            "interval": "epoch",
            "frequency": self.trainer.check_val_every_n_epoch,
            "strict": True,
        }

        return {"optimizer": optimizer, "lr_scheduler": scheduler}


    def training_step(self, batch: datamodule.Batch, _):
        tgt, out = to_tgt_output(batch.wordLabels, self.device)
        out_hat = self(batch.strokeImgs, batch.strokeMasks, batch.positions, tgt)

        loss = ce_loss(out_hat, out)
        self.log("train_loss", loss, on_step=False, on_epoch=True, sync_dist=True)

        return loss
        
    def validation_step(self, batch: datamodule.Batch, _):
        tgt, out = to_tgt_output(batch.wordLabels, self.device)
        out_hat = self(batch.strokeImgs, batch.strokeMasks, batch.positions, tgt)

        loss = ce_loss(out_hat, out)
        self.log(
            "val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

        indeed_max_len = max([len(t) for t in batch.wordLabels])
        pred = self.greedy(
            batch, max_len=int(indeed_max_len * 1.5)
        )

        self.exprate_recorder(pred, batch.wordLabels[0])
        self.log(
            "val_ExpRate",
            self.exprate_recorder,
            prog_bar=True,
            on_step=False,
            on_epoch=True,
        )


    def greedy(self, d: datamodule.Batch, max_len:int):
        assert len(d) == 1
        memory = self.model.just_encode(d.strokeImgs, d.strokeMasks, d.positions)
        
        ret = []
        for i in range(max_len):
            seq = LongTensor([[vocab.SOS_IDX] + ret])
            outs = self.model.just_decode(memory, seq, d.strokeMasks)
            id = outs[0][i].argmax()
            ret.append(id.item())
            if id == vocab.EOS_IDX:
                break
        return vocab.indices2words(ret)




ImportError: cannot import name 'CROHMEDatamodule' from 'datamodule.datamodule' (/Users/aomori/repos/strokeAtt/datamodule/datamodule.py)

In [None]:
import datamodule
from importlib import reload
reload(datamodule)
import datamodule


dm = datamodule.StrokeDatamodule(["strokes/CROHME2011_train/**/*.npy"], ["strokes/CROHME2019_testGT/**/*.npy"])
import pytorch_lightning as pl
from datamodule import vocab

trainer = pl.Trainer(gpus=0,max_epochs=1)
model = StrokesTransformerAdapter(**{
    "stroke_feature_dim": 30, 
    "stroke_pos_feature_dim": 128, 
    "num_encoder_layers": 4,
    "num_decoder_layers": 5, 
    "emb_size": 512, 
    "nhead": 8, 
    "learning_rate":0.3,
    "patience": 20,
})

trainer.fit(model, datamodule=dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


Extract data from: ['strokes/CROHME2011_train/**/*.npy'], with data size: 921



  | Name             | Type               | Params
--------------------------------------------------------
0 | model            | StrokesTransformer | 21.1 M
1 | exprate_recorder | ExpRateRecorder    | 0     
--------------------------------------------------------
21.1 M    Trainable params
0         Non-trainable params
21.1 M    Total params
84.296    Total estimated model params size (MB)


Extract data from: ['strokes/CROHME2019_testGT/**/*.npy'], with data size: 1199
Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]



Validation sanity check:  50%|█████     | 1/2 [00:00<00:00,  7.49it/s]



Epoch 0:   0%|          | 0/1315 [00:00<?, ?it/s]                     



Epoch 0:   3%|▎         | 44/1315 [00:32<15:44,  1.35it/s, loss=3.42, v_num=42, val_loss=5.010, val_ExpRate=0.000]