In [None]:
# Mount Google Drive at /content/drive so later cells can read the dataset
# archive from MyDrive and copy weight backups into it.
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
!mkdir /content/Data
!unzip /content/drive/MyDrive/grandstaff.zip -d /content/Data/GrandstaffDataset

!mkdir /content/weights

!pip install fire wandb torchinfo loguru lightning

In [3]:
import os
import cv2
import torch
import random
import fire
import wandb

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset

from torchinfo import summary
from torchvision import transforms
from torch.utils.data import DataLoader

from itertools import groupby
from loguru import logger
from rich.progress import track
from os import path

import lightning as L

from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

utils

In [4]:
@logger.catch
def check_and_retrieveVocabulary(YSequences, pathOfSequences, nameOfVoc):
    """Load a cached vocabulary, or build and cache it if it does not exist yet.

    Args:
        YSequences: Token sequences handed to ``make_vocabulary`` when no
            complete cache is found.
        pathOfSequences: Directory holding ``<nameOfVoc>w2i.npy`` and
            ``<nameOfVoc>i2w.npy``; created (including parents) if missing.
        nameOfVoc: File-name prefix of the cached vocabulary.

    Returns:
        ``(w2i, i2w)``: token -> index and index -> token dictionaries.

    Note:
        With loguru's default ``catch`` settings an exception is logged and the
        call returns ``None`` instead of a tuple.
    """
    w2ipath = os.path.join(pathOfSequences, nameOfVoc + "w2i.npy")
    i2wpath = os.path.join(pathOfSequences, nameOfVoc + "i2w.npy")

    # Unlike os.mkdir, makedirs also creates missing parent directories and
    # is a no-op when the directory already exists.
    os.makedirs(pathOfSequences, exist_ok=True)

    # Only trust the cache when BOTH halves exist; a lone w2i file would make
    # the i2w load below fail.
    if path.isfile(w2ipath) and path.isfile(i2wpath):
        # The dicts are stored as pickled object arrays, hence allow_pickle.
        # Only point this at trusted cache files.
        w2i = np.load(w2ipath, allow_pickle=True).item()
        i2w = np.load(i2wpath, allow_pickle=True).item()
    else:
        w2i, i2w = make_vocabulary(YSequences, pathOfSequences, nameOfVoc)

    return w2i, i2w

def make_vocabulary(YSequences, pathToSave, nameOfVoc):
    """Build token<->index dictionaries from nested token sequences and save them.

    Args:
        YSequences: Iterable of corpora; each corpus is an iterable of token
            sequences (e.g. ``[train_gt, val_gt, test_gt]``).
        pathToSave: Directory where ``<nameOfVoc>w2i.npy`` and
            ``<nameOfVoc>i2w.npy`` are written with ``np.save``.
        nameOfVoc: File-name prefix for the saved dictionaries.

    Returns:
        ``(w2i, i2w)``. Index 0 is reserved for ``'<pad>'``; the remaining
        tokens get indices ``1..V`` in sorted order.
    """
    vocabulary = set()
    for corpus in YSequences:
        for sequence in corpus:
            vocabulary.update(sequence)

    # '<pad>' is always mapped to 0 below. Keep it out of the numbered range
    # so w2i and i2w cannot disagree about it.
    vocabulary.discard('<pad>')

    # Sort so the token->index mapping is reproducible. Iterating a set of
    # strings follows hash order, which changes between interpreter runs.
    ordered = sorted(vocabulary)
    w2i = {symbol: idx for idx, symbol in enumerate(ordered, start=1)}
    i2w = {idx: symbol for idx, symbol in enumerate(ordered, start=1)}

    w2i['<pad>'] = 0
    i2w[0] = '<pad>'

    # Save the vocabulary
    np.save(pathToSave + "/" + nameOfVoc + "w2i.npy", w2i)
    np.save(pathToSave + "/" + nameOfVoc + "i2w.npy", i2w)

    return w2i, i2w


def levenshtein(a, b):
    """Return the edit distance (insert/delete/substitute, unit cost) between two sequences.

    Works on any indexable sequences whose items support ``!=`` (strings, token
    lists, lists of lines). Keeps only two DP rows, sized by the shorter input.
    """
    if len(a) > len(b):
        a, b = b, a

    row = list(range(len(a) + 1))
    for i, b_item in enumerate(b, start=1):
        prev_row = row
        row = [i] + [0] * len(a)
        for j, a_item in enumerate(a, start=1):
            substitution = prev_row[j - 1] + (a_item != b_item)
            row[j] = min(prev_row[j] + 1, row[j - 1] + 1, substitution)

    return row[len(a)]

@logger.catch
def save_bkern_output(output_path, array):
    """Write each token sequence in ``array`` to ``<output_path>/<idx>.bekern``.

    Tokens are concatenated, then the ``<t>``, ``<b>`` and ``<s>`` placeholders
    are turned back into tab, newline and space characters.
    """
    placeholders = (("<t>", "\t"), ("<b>", "\n"), ("<s>", " "))
    for file_idx, tokens in enumerate(array):
        text = "".join(tokens)
        for marker, replacement in placeholders:
            text = text.replace(marker, replacement)

        target = f"{output_path}/{file_idx}.bekern"
        with open(target, "w") as out_file:
            out_file.write(text)

e2e_unfolding

In [5]:
class DepthSepConv2D(nn.Module):
    """Depthwise-separable 2D convolution: a per-channel conv followed by a 1x1 conv.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the 1x1 point-wise convolution.
        kernel_size: Depthwise kernel size, an int or an ``(kh, kw)`` pair.
        activation: Optional callable applied between the depthwise and the
            point-wise convolution.
        padding: Unused. Padding is always derived from ``kernel_size`` so
            that, with stride and dilation 1, the spatial size is preserved.
            Kept for backward compatibility with existing callers.
        stride: Stride of the depthwise convolution.
        dilation: Dilation passed to both convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, activation=None, padding=True, stride=(1,1), dilation=(1,1)):
        super(DepthSepConv2D, self).__init__()

        # Accept a bare int, like nn.Conv2d does.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        kernel_size = tuple(kernel_size)

        # Explicit post-conv padding (only needed for even kernels).
        self.padding = None
        conv_padding = tuple((k - 1) // 2 for k in kernel_size)

        if kernel_size[0] % 2 == 0 or kernel_size[1] % 2 == 0:
            # nn.Conv2d cannot pad an even kernel asymmetrically. Run the
            # depthwise conv unpadded and restore the size with F.pad, whose
            # 4-element pad is (left, right, top, bottom), i.e. width first.
            pad_w_total = kernel_size[1] - 1
            pad_h_total = kernel_size[0] - 1
            self.padding = [pad_w_total // 2, pad_w_total - pad_w_total // 2,
                            pad_h_total // 2, pad_h_total - pad_h_total // 2]
            conv_padding = (0, 0)

        self.depth_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, dilation=dilation, stride=stride, padding=conv_padding, groups=in_channels)
        self.point_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, dilation=dilation, kernel_size=(1,1))
        self.activation = activation

    def forward(self, inputs):
        x = self.depth_conv(inputs)
        if self.padding is not None:
            x = F.pad(x, self.padding)
        if self.activation:
            x = self.activation(x)
        return self.point_conv(x)

class MixDropout(nn.Module):
    """Apply either element-wise dropout or channel-wise (2D) dropout.

    On every call a fair coin drawn from Python's ``random`` module picks
    which of the two dropout layers is used.
    """

    def __init__(self, dropout_prob=0.4, dropout_2d_prob=0.2):
        super(MixDropout, self).__init__()
        self.dropout = nn.Dropout(dropout_prob)
        self.dropout2D = nn.Dropout2d(dropout_2d_prob)

    def forward(self, inputs):
        use_elementwise = random.random() < 0.5
        chosen = self.dropout if use_elementwise else self.dropout2D
        return chosen(inputs)

class ConvolutionalBlock(nn.Module):
    """Three stacked convolutions with one randomly placed ``MixDropout``.

    conv1 and conv2 use kernel size ``kernel`` with padding ``kernel // 2``.
    conv3 is always 3x3 with padding 1 and applies ``stride``. Each conv is
    followed by ``activation()`` (ReLU by default), and instance normalisation
    runs right before conv3. Every forward pass applies dropout after exactly
    one of the three stages, picked with ``random.randint(1, 3)``.
    """

    def __init__(self, in_c, out_c, stride=(1,1), kernel=3, activation=nn.ReLU, dropout=0.4):
        super(ConvolutionalBlock, self).__init__()

        # Submodules are created in this order; changing it would change
        # seeded weight initialisation.
        self.activation = activation()
        same_pad = kernel // 2
        self.conv1 = nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=kernel, padding=same_pad)
        self.conv2 = nn.Conv2d(in_channels=out_c, out_channels=out_c, kernel_size=kernel, padding=same_pad)
        self.conv3 = nn.Conv2d(in_channels=out_c, out_channels=out_c, kernel_size=(3,3), padding=(1,1), stride=stride)
        self.normLayer = nn.InstanceNorm2d(num_features=out_c, eps=0.001, momentum=0.99, track_running_stats=False)
        self.dropout = MixDropout(dropout_prob=dropout, dropout_2d_prob=dropout/2)

    def forward(self, inputs):
        dropout_stage = random.randint(1, 3)
        stages = (
            (self.conv1, False),
            (self.conv2, False),
            (self.conv3, True),   # conv3 is preceded by instance normalisation
        )

        x = inputs
        for stage_idx, (conv, normalise_first) in enumerate(stages, start=1):
            if normalise_first:
                x = self.normLayer(x)
            x = self.activation(conv(x))
            if stage_idx == dropout_stage:
                x = self.dropout(x)
        return x

class DSCBlock(nn.Module):
    """Three depthwise-separable convolutions with one randomly placed ``MixDropout``.

    ``stride`` applies to conv3 only. conv1 and conv2 are followed by
    ``activation()``. Instance normalisation runs before conv3, and conv3 has
    no activation afterwards. Dropout is applied after exactly one stage,
    picked with ``random.randint(1, 3)``.
    """

    def __init__(self, in_c, out_c, stride=(2, 1), activation=nn.ReLU, dropout=0.4):
        super(DSCBlock, self).__init__()

        # Creation order preserved (affects seeded initialisation).
        self.activation = activation()
        self.conv1 = DepthSepConv2D(in_c, out_c, kernel_size=(3, 3))
        self.conv2 = DepthSepConv2D(out_c, out_c, kernel_size=(3, 3))
        self.conv3 = DepthSepConv2D(out_c, out_c, kernel_size=(3, 3), padding=(1, 1), stride=stride)
        self.norm_layer = nn.InstanceNorm2d(out_c, eps=0.001, momentum=0.99, track_running_stats=False)
        self.dropout = MixDropout(dropout_prob=dropout, dropout_2d_prob=dropout/2)

    def forward(self, x):
        dropout_stage = random.randint(1, 3)

        out = self.activation(self.conv1(x))
        if dropout_stage == 1:
            out = self.dropout(out)

        out = self.activation(self.conv2(out))
        if dropout_stage == 2:
            out = self.dropout(out)

        # No activation after conv3.
        out = self.conv3(self.norm_layer(out))
        if dropout_stage == 3:
            out = self.dropout(out)

        return out

class Encoder(nn.Module):
    """Feature extractor: five ``ConvolutionalBlock``s, then four residual ``DSCBlock``s.

    The conv stack widens to 512 channels. Its strides (1,1), (2,2)x3 and
    (2,1) reduce height by roughly 16 and width by roughly 8. Each DSC block's
    output is added to its input whenever both have the same size.
    """

    def __init__(self, in_channels, dropout=0.4):
        super(Encoder, self).__init__()

        conv_plan = [
            (in_channels, 32, (1,1)),
            (32, 64, (2,2)),
            (64, 128, (2,2)),
            (128, 256, (2,2)),
            (256, 512, (2,1)),
        ]
        self.conv_blocks = nn.ModuleList([
            ConvolutionalBlock(in_c=c_in, out_c=c_out, stride=block_stride, dropout=dropout)
            for c_in, c_out, block_stride in conv_plan
        ])

        self.dscblocks = nn.ModuleList([
            DSCBlock(in_c=512, out_c=512, stride=(1,1), dropout=dropout)
            for _ in range(4)
        ])

    def forward(self, x):
        for block in self.conv_blocks:
            x = block(x)

        for block in self.dscblocks:
            block_out = block(x)
            # Residual connection only when the block preserved the shape.
            if x.size() == block_out.size():
                x = x + block_out
            else:
                x = block_out

        return x


class RecurrentScoreUnfolding(nn.Module):
    """BiLSTM decoder that reads an encoder feature map as one flattened sequence.

    Input ``(b, 512, h, w)`` is flattened row by row into ``h*w`` time steps.
    It then goes through a bidirectional LSTM (2 x 256 hidden units) and a
    linear layer to ``out_cats`` classes. The output is time-major
    ``(h*w, b, out_cats)`` log-probabilities, the layout ``nn.CTCLoss`` takes.
    """

    def __init__(self, out_cats):
        super(RecurrentScoreUnfolding, self).__init__()
        self.dec_lstm = nn.LSTM(input_size=512, hidden_size=256, bidirectional=True, batch_first=True)
        self.out_dense = nn.Linear(in_features=512, out_features=out_cats)

    def forward(self, inputs):
        # (b, c, h, w) -> (b, c, h*w) -> (b, h*w, c): one step per position.
        sequence = inputs.flatten(start_dim=2).transpose(1, 2)
        hidden, _ = self.dec_lstm(sequence)
        logits = self.out_dense(hidden).transpose(0, 1)   # -> (h*w, b, out_cats)
        return F.log_softmax(logits, dim=2)


class E2EScore_CRNN(nn.Module):
    """End-to-end CRNN: ``Encoder`` feature extractor + ``RecurrentScoreUnfolding`` decoder.

    Args:
        in_channels: Channels of the input image.
        out_cats: Number of output classes of the decoder (``main`` passes
            vocabulary size + 1 for the CTC blank).
        pretrain_path: Optional path to a saved ``Encoder`` state dict, loaded
            strictly before the decoder is built.
    """

    def __init__(self, in_channels, out_cats, pretrain_path=None):
        super(E2EScore_CRNN, self).__init__()
        self.encoder = Encoder(in_channels=in_channels)

        if pretrain_path is not None:
            print(f"Loading weights from {pretrain_path}")
            # The freshly built encoder lives on CPU. Loading the tensors there
            # too lets GPU-saved weights load on CPU-only hosts;
            # load_state_dict copies them into the encoder's parameters.
            state_dict = torch.load(pretrain_path, map_location="cpu")
            self.encoder.load_state_dict(state_dict, strict=True)

        self.decoder = RecurrentScoreUnfolding(out_cats=out_cats)

    def forward(self, inputs):
        return self.decoder(self.encoder(inputs))


def get_rcnn_model(maxwidth, maxheight, in_channels, out_size):
    """Build an ``E2EScore_CRNN`` and run torchinfo's ``summary`` on it.

    ``maxwidth``/``maxheight`` only size the dummy input used by the summary.
    """
    crnn = E2EScore_CRNN(in_channels=in_channels, out_cats=out_size)
    dummy_shape = (1, in_channels, maxheight, maxwidth)
    summary(crnn, input_size=[dummy_shape], dtypes=[torch.float])
    return crnn

eval_functions

In [6]:
def parse_krn_content(krn, ler_parsing=False, cer_parsing=False):
    """Split a kern transcription into the units used for error-rate computation.

    * ``cer_parsing`` (takes precedence): individual characters, with the
      ``<b>`` (newline) and ``<t>`` (tab) markers kept as single units.
    * ``ler_parsing``: one string per line, with tabs rewritten as `` <t> ``.
    * otherwise: symbols. Newlines/tabs become `` <b> ``/`` <t> `` and the
      text is split on single spaces, so empty strings between consecutive
      spaces are kept.
    """
    if cer_parsing:
        spaced = krn.replace("\n", " <b> ").replace("\t", " <t> ")
        characters = []
        for token in spaced.split(" "):
            if token in ('<b>', '<t>'):
                characters.append(token)
            else:
                characters.extend(token)
        return characters

    if ler_parsing:
        # Lines come from splitting on "\n", so no newline is left inside them.
        return [line.replace("\t", " <t> ") for line in krn.split("\n")]

    return krn.replace("\n", " <b> ").replace("\t", " <t> ").split(" ")

def compute_metric(a1, a2):
    """Corpus-level error rate in percent: total edit distance / total reference length.

    Args:
        a1: Hypothesis sequences.
        a2: Reference sequences, paired with ``a1`` by position via ``zip``,
            so extra items in the longer list are ignored.

    Returns:
        ``100 * sum(levenshtein(h, g)) / sum(len(g))``. When the references
        contain no units at all there is nothing to normalise by. In that case
        the result is ``0.0`` if the hypotheses are also error-free and
        ``100.0`` otherwise, instead of raising ``ZeroDivisionError``.
    """
    acc_ed_dist = 0
    acc_len = 0

    for h, g in zip(a1, a2):
        acc_ed_dist += levenshtein(h, g)
        acc_len += len(g)

    if acc_len == 0:
        return 0.0 if acc_ed_dist == 0 else 100.0

    return 100. * acc_ed_dist / acc_len

def get_metrics(hyp_array, gt_array):
    """Compute CER, SER and LER (in percent) for paired hypothesis/reference strings.

    Each string is tokenised three ways with ``parse_krn_content``:
    characters (CER), symbols (SER) and lines (LER). Each granularity is then
    scored once with ``compute_metric``.

    Args:
        hyp_array: Predicted transcriptions.
        gt_array: Reference transcriptions, paired with ``hyp_array`` by position.

    Returns:
        ``(cer, ser, ler)``.
    """
    hyp_cer, gt_cer = [], []
    hyp_ser, gt_ser = [], []
    hyp_ler, gt_ler = [], []

    for h_string, gt_string in zip(hyp_array, gt_array):
        hyp_ler.append(parse_krn_content(h_string, ler_parsing=True, cer_parsing=False))
        gt_ler.append(parse_krn_content(gt_string, ler_parsing=True, cer_parsing=False))

        hyp_ser.append(parse_krn_content(h_string, ler_parsing=False, cer_parsing=False))
        gt_ser.append(parse_krn_content(gt_string, ler_parsing=False, cer_parsing=False))

        hyp_cer.append(parse_krn_content(h_string, ler_parsing=False, cer_parsing=True))
        gt_cer.append(parse_krn_content(gt_string, ler_parsing=False, cer_parsing=True))

    # Score each granularity exactly once. The character-level Levenshtein
    # pass is the most expensive one, so it must not be repeated.
    cer = compute_metric(hyp_cer, gt_cer)
    ser = compute_metric(hyp_ser, gt_ser)
    ler = compute_metric(hyp_ler, gt_ler)

    return cer, ser, ler

data

In [7]:
@logger.catch
def batch_preparation_ctc(data):
    """Collate ``(image, gt, input_len, target_len)`` samples into a CTC batch.

    Each image (a ``(c, h, w)`` tensor) is pasted into the top-left corner of a
    batch tensor pre-filled with ones. Target sequences are right-padded with 0.

    Returns:
        ``(X, Y, input_lengths, target_lengths)``:
        ``X`` is float32 ``(B, 1, H_max, W_max)``, ``Y`` is int64
        ``(B, S_max)``, and the two length entries are passed through as lists.
    """
    images = [sample[0] for sample in data]
    gt = [sample[1] for sample in data]
    input_lengths = [sample[2] for sample in data]
    target_lengths = [sample[3] for sample in data]

    max_image_height = max(img.shape[1] for img in images)
    max_image_width = max(img.shape[2] for img in images)

    X_train = torch.ones(size=[len(images), 1, max_image_height, max_image_width], dtype=torch.float32)
    for i, img in enumerate(images):
        _, h, w = img.size()
        X_train[i, :, :h, :w] = img

    # Targets are class indices, so store them as int64 rather than the
    # float32 default of torch.zeros. float32 cannot represent every integer
    # exactly, and CTCLoss documents its targets as class indices.
    max_length_seq = max(len(seq) for seq in gt)
    Y_train = torch.zeros(size=[len(gt), max_length_seq], dtype=torch.long)
    for i, seq in enumerate(gt):
        Y_train[i, :len(seq)] = torch.as_tensor(seq, dtype=torch.long)

    return X_train, Y_train, input_lengths, target_lengths

def _read_krn_tokens(file_path):
    """Tokenise one kern file into a flat token list with ``<s>``/``<t>``/``<b>`` markers."""
    # The files contain non-ASCII characters ('·'), so decode explicitly as
    # UTF-8 instead of relying on the platform's default encoding.
    with open(file_path, encoding="utf-8") as krnfile:
        krn = krnfile.read()

    # Literal spaces become <s> tokens; '·' becomes a split point between tokens.
    krn = krn.replace(" ", " <s> ")
    krn = krn.replace("·", " ")

    tokens = []
    for line in krn.split("\n"):
        line_tokens = line.replace("\t", " <t> ").split(" ")
        # Lines that yield a single token are dropped (original behaviour).
        if len(line_tokens) > 1:
            tokens.extend(line_tokens)
            tokens.append("<b>")
    return tokens


def _rescale_and_rotate(img, resize_ratio):
    """Scale ``img`` by ``resize_ratio`` (sizes rounded up) and rotate it 90 degrees clockwise."""
    width = int(np.ceil(img.shape[1] * resize_ratio))
    height = int(np.ceil(img.shape[0] * resize_ratio))
    img = cv2.resize(img, (width, height), interpolation=cv2.INTER_LINEAR)
    return cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)


@logger.catch
def load_data(partition_file, resize_ratio = 1, load_distorted=True, extension=".krn"):
    """Load the (image, tokens) pairs listed in a partition file.

    Each line of ``partition_file`` is a transcription path. A ``.bekrn``
    suffix in it is replaced by ``extension`` unless ``extension`` is
    ``".bekrn"``. A sample is considered only when the transcription file and
    ``<stem>.jpg`` both exist.

    * ``load_distorted=True``: ``<stem>_distorted.jpg`` is read in grayscale
      and resized to a height of 256 px, keeping the aspect ratio. The sample
      is kept only if ``(256 // 8) * (width // 16)`` exceeds its token count.
    * ``load_distorted=False``: ``<stem>.jpg`` is read in grayscale as-is.

    Kept images are then scaled by ``resize_ratio`` and rotated 90 degrees
    clockwise. Unreadable images are skipped with a warning.

    Returns:
        ``(X, Y)``: lists of images (numpy arrays) and flat token lists.
    """
    X = []
    Y = []
    with open(partition_file) as partfile:
        part_lines = partfile.read().split("\n")

    for f_path in track(part_lines, description="Loading..."):
        if extension != ".bekrn":
            f_path = f_path.replace(".bekrn", extension)
        if not os.path.isfile(f_path):
            continue

        # splitext strips only the extension; split('.')[0] would also cut
        # at dots in directory names (e.g. "./data/...").
        stem = os.path.splitext(f_path)[0]
        if not os.path.exists(f"{stem}.jpg"):
            continue

        tokens = _read_krn_tokens(f_path)

        if load_distorted:
            image_path = f"{stem}_distorted.jpg"
        else:
            image_path = f"{stem}.jpg"

        img = cv2.imread(image_path, 0)
        if img is None:
            # cv2.imread returns None for missing or unreadable files.
            logger.warning(f"Skipping {f_path}: cannot read {image_path}")
            continue

        if load_distorted:
            height = 256
            width = int(float(height * img.shape[1]) / img.shape[0])
            img = cv2.resize(img, (width, height), interpolation=cv2.INTER_LINEAR)
            # Keep only samples whose downsampled length exceeds the token count.
            if (height // 8) * (width // 16) <= len(tokens):
                continue

        X.append(_rescale_and_rotate(img, resize_ratio))
        Y.append(tokens)

    return X, Y


class PoliphonicDataset(Dataset):
    """Map-style dataset of (score image, token sequence) pairs loaded by ``load_data``.

    ``set_dictionaries`` must be called before indexing, because
    ``__getitem__`` encodes tokens through ``self.w2i``.
    """

    def __init__(self, partition_file) -> None:
        self.x, self.y = load_data(partition_file)
        self.tensorTransform = transforms.ToTensor()

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        """Return ``(image_tensor, token_ids, ctc_input_length, target_length)``."""
        image = self.tensorTransform(self.x[index])
        token_ids = np.asarray([self.w2i[tok] for tok in self.y[index]])
        gt = torch.from_numpy(token_ids)
        rows, cols = image.shape[1], image.shape[2]
        ctc_input_len = (cols // 8) * (rows // 16)
        return image, gt, ctc_input_len, len(gt)

    def get_max_hw(self):
        """Largest height and width over the stored numpy images, as ``(h, w)``."""
        heights = [img.shape[0] for img in self.x]
        widths = [img.shape[1] for img in self.x]
        return np.max(heights), np.max(widths)

    def get_max_seqlen(self):
        return np.max([len(tokens) for tokens in self.y])

    def vocab_size(self):
        return len(self.w2i)

    def get_gt(self):
        return self.y

    def set_dictionaries(self, w2i, i2w):
        """Attach the shared vocabulary; ``'<pad>'`` must be a key of ``w2i``."""
        self.w2i = w2i
        self.i2w = i2w
        self.padding_token = w2i['<pad>']

    def get_dictionaries(self):
        return self.w2i, self.i2w

    def get_i2w(self):
        return self.i2w

def load_dataset(train_path=None, val_path=None, test_path=None, corpus_name=None, vocab_path="/content/vocab/"):
    """Build train/val/test ``PoliphonicDataset``s that share one vocabulary.

    Args:
        train_path, val_path, test_path: Partition files for each split.
        corpus_name: File-name prefix of the cached vocabulary.
        vocab_path: Directory where the vocabulary is cached. Defaults to the
            previously hard-coded location.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``, all with the same
        ``w2i``/``i2w``. If no cached vocabulary exists, it is built from the
        ground truth of all three splits.
    """
    train_dataset = PoliphonicDataset(partition_file=train_path)
    val_dataset = PoliphonicDataset(partition_file=val_path)
    test_dataset = PoliphonicDataset(partition_file=test_path)

    all_gt = [train_dataset.get_gt(), val_dataset.get_gt(), test_dataset.get_gt()]
    w2i, i2w = check_and_retrieveVocabulary(all_gt, vocab_path, f"{corpus_name}")

    for dataset in (train_dataset, val_dataset, test_dataset):
        dataset.set_dictionaries(w2i, i2w)

    return train_dataset, val_dataset, test_dataset

model_manager

In [8]:
class LighntingE2EModelUnfolding(L.LightningModule):
    """Lightning wrapper that trains a CTC model and scores it with CER/SER/LER.

    Args:
        model: ``nn.Module`` returning time-major ``(T, B, C)`` log-probabilities.
        blank_idx: Index of the CTC blank class.
        i2w: Index -> token dictionary used to turn predictions and targets
            back into text.
        output_path: Directory whose ``hyp/`` and ``gt/`` sub-folders receive
            one ``.krn`` file per test batch.
    """

    def __init__(self, model, blank_idx, i2w, output_path) -> None:
        super(LighntingE2EModelUnfolding, self).__init__()
        self.model = model
        self.loss = nn.CTCLoss(blank=blank_idx)
        self.blank_idx = blank_idx
        self.i2w = i2w
        self.accum_ed = 0
        self.accum_len = 0

        # Per-epoch buffers: decoded hypotheses, ground truths and test images.
        self.dec_val_ex = []
        self.gt_val_ex = []
        self.img_val_ex = []
        self.ind_val_ker = []

        self.out_path = output_path

        # The nn.Module itself is excluded from the hparams, so it must be
        # passed again to load_from_checkpoint(model=...).
        self.save_hyperparameters(ignore=['model'])

    def forward(self, input):
        return self.model(input)

    def configure_optimizers(self):
        return optim.Adam(self.model.parameters(), lr=1e-4)

    def training_step(self, train_batch, batch_idx):
        X_tr, Y_tr, L_tr, T_tr = train_batch
        predictions = self.forward(X_tr)
        loss = self.loss(predictions, Y_tr, L_tr, T_tr)
        self.log('loss', loss, on_epoch=True, batch_size=1, prog_bar=True)
        return loss

    def _reset_buffers(self):
        """Empty the per-epoch prediction/ground-truth/image buffers."""
        self.dec_val_ex = []
        self.gt_val_ex = []
        self.img_val_ex = []

    @staticmethod
    def _detokenize(tokens):
        """Join tokens and turn the <t>/<b>/<s> placeholders back into tab/newline/space."""
        text = "".join(tokens)
        text = text.replace("<t>", "\t")
        text = text.replace("<b>", "\n")
        text = text.replace("<s>", " ")
        return text

    def compute_prediction(self, batch):
        """Greedy CTC decoding of the first sample in ``batch``.

        Returns:
            ``(decoded_tokens, gt_tokens)`` as lists of token strings.
        """
        X, Y, _, _ = batch
        pred = self.forward(X)
        # (T, B, C): take the first sample and its best class per time step.
        best_path = torch.argmax(pred[:, 0, :], dim=1).tolist()
        # Collapse consecutive repeats, then drop blanks.
        decoded = [k for k, _ in groupby(best_path) if k != self.blank_idx]
        decoded = [self.i2w[tok] for tok in decoded]
        gt = [self.i2w[int(tok)] for tok in Y[0].tolist()]
        return decoded, gt

    def on_validation_epoch_start(self):
        self._reset_buffers()

    def validation_step(self, val_batch, batch_idx):
        dec, gt = self.compute_prediction(val_batch)
        self.dec_val_ex.append(self._detokenize(dec))
        self.gt_val_ex.append(self._detokenize(gt))

    def on_validation_epoch_end(self):
        cer, ser, ler = get_metrics(self.dec_val_ex, self.gt_val_ex)

        self.log('val_CER', cer)
        self.log('val_SER', ser)
        self.log('val_LER', ler)

        # Clear the buffers. Otherwise each epoch (including the sanity check)
        # is scored together with every previous epoch's predictions, which
        # skews EarlyStopping and checkpoint selection on val_SER.
        self._reset_buffers()

        return ser

    def on_test_epoch_start(self):
        # Never score leftover validation predictions as test results.
        self._reset_buffers()

    def test_step(self, test_batch, batch_idx):
        dec, gt = self.compute_prediction(test_batch)
        dec = self._detokenize(dec)
        gt = self._detokenize(gt)

        with open(f"{self.out_path}/hyp/{batch_idx}.krn", "w+") as krnfile:
            krnfile.write(dec)

        with open(f"{self.out_path}/gt/{batch_idx}.krn", "w+") as krnfile:
            krnfile.write(gt)

        self.dec_val_ex.append(dec)
        self.gt_val_ex.append(gt)
        # squeeze(0) drops the batch dimension (assumes batch size 1, as in main).
        self.img_val_ex.append((255.*test_batch[0].squeeze(0)))

    def on_test_epoch_end(self) -> None:
        cer, ser, ler = get_metrics(self.dec_val_ex, self.gt_val_ex)

        # NOTE(review): test metrics are logged under the val_* keys. The names
        # are kept because consumers of trainer.test() results may rely on them.
        self.log('val_CER', cer)
        self.log('val_SER', ser)
        self.log('val_LER', ler)

        columns = ['Image', 'PRED', 'GT']
        nsamples = min(len(self.dec_val_ex), 5)
        random_indices = random.sample(range(len(self.dec_val_ex)), nsamples)
        data = [
            [wandb.Image(self.img_val_ex[index]), self.dec_val_ex[index], self.gt_val_ex[index]]
            for index in random_indices
        ]

        table = wandb.Table(columns=columns, data=data)

        self.logger.experiment.log(
            {'Test samples': table}
        )

        self._reset_buffers()

        return ser

def get_model(maxwidth, maxheight, in_channels, out_size, blank_idx, i2w, model_name, output_path):
    """Build the CRNN, wrap it in the Lightning module and run a torchinfo summary.

    ``model_name`` is accepted for interface compatibility but not used here.

    Returns:
        ``(lightning_module, crnn)``. The raw ``nn.Module`` is returned too so
        it can be handed to ``load_from_checkpoint(model=...)``.
    """
    crnn = get_rcnn_model(maxwidth, maxheight, in_channels, out_size)
    wrapper = LighntingE2EModelUnfolding(model=crnn, blank_idx=blank_idx, i2w=i2w, output_path=output_path)
    summary(wrapper, input_size=[1, in_channels, maxheight, maxwidth])
    return wrapper, crnn

main

In [None]:
train_pth = "/content/Data/GrandstaffDataset/partitions/train.txt"
val_pth = "/content/Data/GrandstaffDataset/partitions/val.txt"
test_pth = "/content/Data/GrandstaffDataset/partitions/test.txt"

def main(train_path=train_pth, val_path=val_pth, test_path=test_pth, encoding="krn", model_name="CRNN",
         max_epochs=2, num_workers=None):
    """Train the CRNN on GrandStaff, then evaluate the best checkpoint on the test split.

    Args:
        train_path, val_path, test_path: Partition files for each split.
            ``val_path`` previously defaulted to the TRAIN partition, so model
            selection ran on training data.
        encoding: Tag used in the output, weight and vocabulary-name paths.
        model_name: W&B run name and checkpoint file name.
        max_epochs: Maximum training epochs passed to the ``Trainer``.
        num_workers: DataLoader worker processes per loader. ``None`` means
            ``min(20, os.cpu_count())``, so small machines such as Colab are
            not oversubscribed.
    """
    outpath = f"/content/out/GrandStaff_{encoding}/{model_name}"
    os.makedirs(outpath, exist_ok=True)
    os.makedirs(f"{outpath}/hyp", exist_ok=True)
    os.makedirs(f"{outpath}/gt", exist_ok=True)

    if num_workers is None:
        num_workers = min(20, os.cpu_count() or 1)

    train_dataset, val_dataset, test_dataset = load_dataset(train_path, val_path, test_path, corpus_name=f"GrandStaff_{encoding}")

    _, i2w = train_dataset.get_dictionaries()

    # Shuffle the training data only. Evaluation order stays fixed, so the
    # hyp/gt files written per batch_idx follow the partition file order.
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=num_workers, collate_fn=batch_preparation_ctc)
    val_dataloader = DataLoader(val_dataset, batch_size=1, num_workers=num_workers, collate_fn=batch_preparation_ctc)
    test_dataloader = DataLoader(test_dataset, batch_size=1, num_workers=num_workers, collate_fn=batch_preparation_ctc)

    maxheight, maxwidth = train_dataset.get_max_hw()

    # The CTC blank takes the index right after the vocabulary, hence +1 classes.
    model, torchmodel = get_model(maxwidth=maxwidth, maxheight=maxheight, in_channels=1, blank_idx=len(i2w), out_size=train_dataset.vocab_size()+1, i2w=i2w, model_name=model_name, output_path=outpath)

    wandb_logger = WandbLogger(project='E2E_Pianoform', name=model_name)

    early_stopping = EarlyStopping(monitor='val_SER', min_delta=0.01, patience=5, mode="min", verbose=False)

    checkpointer = ModelCheckpoint(dirpath=f"/content/weights/{encoding}/{model_name}", filename=f"{model_name}",
                                   monitor="val_SER", mode='min',
                                   save_top_k=1, verbose=False)

    trainer = Trainer(max_epochs=max_epochs, logger=wandb_logger, callbacks=[checkpointer, early_stopping])

    trainer.fit(model, train_dataloader, val_dataloader)

    model = LighntingE2EModelUnfolding.load_from_checkpoint(checkpointer.best_model_path, model=torchmodel)
    trainer.test(model, test_dataloader)
    wandb.finish()

if __name__ == "__main__":
    main()

In [None]:
!mkdir /content/drive/MyDrive/modelBackup
!cp -r /content/weights/ /content/drive/MyDrive/modelBackup