In [77]:
import os
import re
import cv2
import json
import torch
import wandb
import hydra
import numpy as np
import torch.nn as nn

from SMT import SMT
from config_typings import Config, SMTConfig, DataConfig, CLConfig
from data import PretrainingLinesDataset
from torchinfo import summary
from utils import check_and_retrieveVocabulary
from data_augmentation.data_augmentation import augment, convert_img_to_tensor
from torch.utils.data import Dataset

from model.ConvEncoder import Encoder
from model.ConvNextEncoder import ConvNextEncoder
from model.Decoder import Decoder
from model.PositionEncoding import PositionalEncoding2D, PositionalEncoding1D

from eval_functions import compute_poliphony_metrics
from Generator.MusicSynthGen import VerovioGenerator

from rich import progress

import lightning.pytorch as L
from lightning.pytorch import Trainer, LightningDataModule
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import WandbLogger

torch.set_float32_matmul_precision('high')
wandb.login()




True

In [87]:
class ConfigSection:
    def __init__(self, section_dict):
        for key, value in section_dict.items():
            setattr(self, key, value)

class Config:
    def __init__(self, config_dict):
        for key, value in config_dict.items():
            setattr(self, key, ConfigSection(value) if isinstance(value, dict) else value)

# Now you can create an instance of the Config class with your dictionary
config_dict = {
    'data': {
        'data_path': 'Data/GrandStaff/partitions_grandstaff/types/', 
        'synth_path': '', 
        'vocab_name': 'GrandStaff_BeKern', 
        'out_dir': 'out/GrandStaff/SMT_lines', 
        'krn_type': 'bekrn', 
        'reduce_ratio': 0.5, 
        'base_folder': 'GrandStaff', 
        'file_format': 'jpg', 
        'tokenization_mode': 'bekern', 
        'fold': '???'
    },
    'model_setup': {
        'in_channels': 1, 
        'd_model': 256, 
        'dim_ff': 256, 
        'num_dec_layers': 8, 
        'encoder_type': 'NexT', 
        'max_height': 2512, 
        'max_width': 2512, 
        'max_len': 5512, 
        'lr': 0.001
    }, 
    'experiment': {
        'metric_to_watch': 'val_SER', 
        'metric_mode': 'min', 
        'max_epochs': 100, 
        'val_after': 5, 
        'pretrain_weights': '???'
    }, 
    'cl': {
        'num_cl_steps': 3, 
        'max_synth_prob': 0.9, 
        'min_synth_prob': 0.2, 
        'increase_steps': 40000, 
        'finetune_steps': 200000, 
        'curriculum_stage_beginning': 2, 
        'teacher_forcing_perc': '???', 
        'skip_progressive': '???', 
        'skip_cl': '???'
    }
}

config = Config(config_dict)




In [96]:

class SMT(L.LightningModule):
    def __init__(self, config:SMTConfig, w2i, i2w) -> None:
        super().__init__()
        
        if config.encoder_type == "NexT":
            self.encoder = ConvNextEncoder(in_chans=config.in_channels, depths=[3,3,9], dims=[64, 128, 256])
        else:
            self.encoder = Encoder(in_channels=config.in_channels)

        self.decoder = Decoder(config.d_model, config.dim_ff, config.num_dec_layers, config.max_len + 1, len(w2i))
        self.positional_2D = PositionalEncoding2D(config.d_model, (config.max_height//16) + 1, (config.max_width//8) + 1)

        self.padding_token = 0

        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_token)

        self.valpredictions = []
        self.valgts = []

        self.w2i = w2i
        self.i2w = i2w

        self.maxlen = config.max_len

        #self(torch.randn(1,1,config.max_height,config.max_width).to(torch.device("cuda")), torch.randint(0, len(w2i), (1,config.max_len)).to(torch.device("cuda")))
        #import sys
        #sys.exit()
        self.worst_loss_image = None
        self.worst_training_loss = -1
        summary(self, input_size=[(1,1, 2512, 2512), (1, 5512)], dtypes=[torch.float, torch.long])

        self.save_hyperparameters()

    def forward(self, x, y_pred):
        encoder_output = self.encoder(x)
        b, c, h, w = encoder_output.size()
        reduced_size = [s.shape[:2] for s in encoder_output]
        ylens = [len(sample) for sample in y_pred]
        cache = None

        pos_features = self.positional_2D(encoder_output)
        features = torch.flatten(encoder_output, start_dim=2, end_dim=3).permute(2,0,1)
        enhanced_features = features
        enhanced_features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(2,0,1)
        output, predictions, _, _, weights = self.decoder(features, enhanced_features, y_pred[:, :-1], reduced_size, 
                                                           [max(ylens) for _ in range(b)], encoder_output.size(), 
                                                           start=0, cache=cache, keep_all_weights=True)
    
        return output, predictions, cache, weights


    def forward_encoder(self, x):
        return self.encoder(x)
    
    def forward_decoder(self, encoder_output, last_preds, cache=None):
        b, c, h, w = encoder_output.size()
        reduced_size = [s.shape[:2] for s in encoder_output]
        ylens = [len(sample) for sample in last_preds]
        cache = cache

        pos_features = self.positional_2D(encoder_output)
        features = torch.flatten(encoder_output, start_dim=2, end_dim=3).permute(2,0,1)
        enhanced_features = features
        enhanced_features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(2,0,1)
        output, predictions, _, _, weights = self.decoder(features, enhanced_features, last_preds[:, :], reduced_size, 
                                                           [max(ylens) for _ in range(b)], encoder_output.size(), 
                                                           start=0, cache=cache, keep_all_weights=True)
    
        return output, predictions, cache, weights
    
    def configure_optimizers(self):
        return torch.optim.AdamW(list(self.encoder.parameters()) + list(self.decoder.parameters()), lr=2e-4, amsgrad=False)

    def training_step(self, train_batch):
        x, di, y = train_batch
        output, predictions, cache, weights = self.forward(x, di)
        loss = self.loss(predictions, y[:, :-1])
        self.log('loss', loss, on_epoch=True, batch_size=1, prog_bar=True)
        if loss > self.worst_training_loss:
            self.worst_loss_image = x
            self.worst_training_loss = loss
        
        return loss

    def on_train_epoch_end(self):
        #plot the worst training loss image in wandb
        self.logger.experiment.log({"worst_training_loss_image": [wandb.Image(self.worst_loss_image.squeeze(0).cpu().numpy())]})
        self.worst_training_loss = -1
        self.worst_loss_image = None

    def validation_step(self, val_batch, batch_idx):
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        x, _, y = val_batch
        encoder_output = self.forward_encoder(x)
        predicted_sequence = torch.from_numpy(np.asarray([self.w2i['<bos>']])).to(device).unsqueeze(0)
        cache = None
        for i in range(128):
             output, predictions, cache, weights = self.forward_decoder(encoder_output, predicted_sequence.long(), cache=cache)
             predicted_token = torch.argmax(predictions[:, :, -1]).item()
             predicted_sequence = torch.cat([predicted_sequence, torch.argmax(predictions[:, :, -1], dim=1, keepdim=True)], dim=1)
             if predicted_token == self.w2i['<eos>']:
                 break
        
        dec = "".join([self.i2w[token.item()] for token in predicted_sequence.squeeze(0)[1:]])
        dec = dec.replace("<t>", "\t")
        dec = dec.replace("<b>", "\n")
        dec = dec.replace("<s>", " ")

        gt = "".join([self.i2w[token.item()] for token in y.squeeze(0)[:-1]])
        gt = gt.replace("<t>", "\t")
        gt = gt.replace("<b>", "\n")
        gt = gt.replace("<s>", " ")

         # Write to file
        with open("validation_results.txt", "a") as f:
            f.write(f"[Prediction] - {dec}\n")
            f.write(f"[GT] - {gt}\n")

        self.valpredictions.append(dec)
        self.valgts.append(gt)
    
    def on_validation_epoch_end(self, name="val"):
        cer, ser, ler = compute_poliphony_metrics(self.valpredictions, self.valgts)
        
        random_index = np.random.randint(0, len(self.valpredictions))
        predtoshow = self.valpredictions[random_index]
        gttoshow = self.valgts[random_index]
        print(f"[Prediction] - {predtoshow}")
        print(f"[GT] - {gttoshow}")

        self.log(f'{name}_CER', cer, prog_bar=True)
        self.log(f'{name}_SER', ser, prog_bar=True)
        self.log(f'{name}_LER', ler, prog_bar=True)

        self.valpredictions = []
        self.valgts = []

        return ser
    
    def test_step(self, test_batch, batch_idx):
        self.validation_step(test_batch, batch_idx)
    
    def on_test_epoch_end(self) -> None:
        return self.on_validation_epoch_end(name="test")
        

In [90]:
def load_set(path, base_folder="GrandStaff", fileformat="jpg", krn_type="bekrn", reduce_ratio=0.5):
    x = []
    y = []
    limit = 500000
    counter = 0
    with open(path) as datafile:
        lines = datafile.readlines()
        for line in progress.track(lines):
            if counter > limit:
                break
            counter += 1
            excerpt = line.replace("\n", "")
            try:
                with open(f"Data/{base_folder}/{'.'.join(excerpt.split('.')[:-1])}.{krn_type}") as krnfile:
                    krn_content = krnfile.read()
                    fname = ".".join(excerpt.split('.')[:-1])
                    img = cv2.imread(f"Data/{base_folder}/{fname}.{fileformat}")
                    width = int(np.ceil(img.shape[1] * reduce_ratio))
                    height = int(np.ceil(img.shape[0] * reduce_ratio))
                    img = cv2.resize(img, (width, height))
                    y.append([content + '\n' for content in krn_content.strip().split("\n")])
                    x.append(img)
            except Exception as e:
               print(f'Error reading Data/GrandStaff/{excerpt}')

    return x, y




x, y = load_set("/home/ubuntu/maikyon/music-sheet-music-modeling/SMT-plusplus/Data/GrandStaff/partitions_grandstaff/types/train.txt", base_folder=config.data.base_folder, 
                                  fileformat=config.data.file_format, krn_type=config.data.krn_type, reduce_ratio=config.data.reduce_ratio)
def preprocess_gt(Y, tokenization_method="standard"):
    for idx, krn in enumerate(Y):
        krnlines = []
        krn = "".join(krn)
        krn = krn.replace(" ", " <s> ")
        
        if tokenization_method == "bekern":
            krn = krn.replace("·", " ")
            krn = krn.replace("@", " ")
        if tokenization_method == "ekern":
            krn = krn.replace("·", " ")
            krn = krn.replace("@", "")
        if tokenization_method == "standard":
            krn = krn.replace("·", "")
            krn = krn.replace("@", "")
            
        krn = krn.replace("/", "")
        krn = krn.replace("\\", "")
        krn = krn.replace("\t", " <t> ")
        krn = krn.replace("\n", " <b> ")
        krn = krn.split(" ")
                
        Y[idx] = erase_numbers_in_tokens_with_equal(['<bos>'] + krn[4:-1] + ['<eos>'])
    return Y
def erase_numbers_in_tokens_with_equal(tokens):
    return [re.sub(r'(?<=\=)\d+', '', token) for token in tokens]

def erase_whitespace_elements(tokens):
    return [token for token in tokens if token != ""]

y = preprocess_gt(y, tokenization_method=config.data.tokenization_mode)





Output()

In [92]:
def make_vocabulary(YSequences, nameOfVoc):
    vocabulary = set()
    for samples in YSequences:
        for element in samples:
            vocabulary.add(element)

    #Vocabulary created
    w2i = {symbol:idx+1 for idx,symbol in enumerate(vocabulary)}
    i2w = {idx+1:symbol for idx,symbol in enumerate(vocabulary)}
    
    w2i['<pad>'] = 0
    i2w[0] = '<pad>'


    return w2i, i2w

w2i, i2w = make_vocabulary(y, "JVocab")
print(len(w2i))

187


In [93]:
import os
import torch
import numpy as np
import re
import lightning as L
from torch.utils.data import DataLoader, Dataset, random_split

# Utility Functions
def erase_whitespace_elements(tokens):
    return [token for token in tokens if token != ""]

def batch_preparation_img2seq(data):
    images = [sample[0] for sample in data]
    dec_in = [sample[1] for sample in data]
    gt = [sample[2] for sample in data]

    max_image_width = max([img.shape[2] for img in images])
    max_image_height = max([img.shape[1] for img in images])

    X_train = torch.ones(size=[len(images), 1, max_image_height, max_image_width], dtype=torch.float32)

    for i, img in enumerate(images):
        _, h, w = img.size()
        X_train[i, :, :h, :w] = img
    
    max_length_seq = max([len(seq) for seq in gt])

    decoder_input = torch.zeros(size=[len(dec_in), max_length_seq], dtype=torch.long)
    y = torch.zeros(size=[len(gt), max_length_seq], dtype=torch.long)

    for i, seq in enumerate(dec_in):
        seq_tensor = seq.clone().detach()
        decoder_input[i, :len(seq_tensor)] = seq_tensor
    
    for i, seq in enumerate(gt):
        seq_tensor = seq.clone().detach()
        y[i, :len(seq_tensor)] = seq_tensor
    
    return X_train, decoder_input, y

# Define worker_init_fn to set CPU affinity
def worker_init_fn(worker_id):
    os.sched_setaffinity(0, range(os.cpu_count()))

# Dataset Class
class OMRIMG2SEQDataset(Dataset):
    def __init__(self, x, y, w2i, i2w, teacher_forcing_perc=0.2, augment=False) -> None:
        self.x = x
        self.y = y
        self.w2i = w2i
        self.i2w = i2w
        self.teacher_forcing_error_rate = teacher_forcing_perc
        self.augment = augment
        self.padding_token = w2i['<pad>']
        super().__init__()
    
    def apply_teacher_forcing(self, sequence):
        errored_sequence = sequence.clone().detach()
        for token in range(1, len(sequence)):
            if np.random.rand() < self.teacher_forcing_error_rate and sequence[token] != self.padding_token:
                errored_sequence[token] = np.random.randint(0, len(self.w2i))
        return errored_sequence

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        x = self.x[index]
        y = self.y[index]
        
        x = convert_img_to_tensor(x)  # Define convert_img_to_tensor function as needed
        y = [self.w2i[token] for token in erase_whitespace_elements(y)]
        y = torch.tensor(y, dtype=torch.long)
        decoder_input = self.apply_teacher_forcing(y)
        return x, decoder_input, y

    @staticmethod
    def erase_whitespace_elements(tokens):
        return [token for token in tokens if token != ""]

    @staticmethod
    def erase_numbers_in_tokens_with_equal(tokens):
        return [re.sub(r'(?<=\=)\d+', '', token) for token in tokens]

# DataModule Class
class OMRDataModule(L.LightningDataModule):
    def __init__(self, x, y, w2i, i2w, batch_size=32, augment=False, teacher_forcing_perc=0.2, num_workers=4, train_val_test_split=(0.9, 0.9, 0.01)):
        super().__init__()
        self.x = x
        self.y = y
        self.w2i = w2i
        self.i2w = i2w
        self.batch_size = batch_size
        self.augment = augment
        self.teacher_forcing_perc = teacher_forcing_perc
        self.num_workers = num_workers
        self.train_val_test_split = train_val_test_split
    
    def setup(self, stage=None):
        full_dataset = OMRIMG2SEQDataset(
            x=self.x,
            y=self.y,
            w2i=self.w2i,
            i2w=self.i2w,
            teacher_forcing_perc=self.teacher_forcing_perc,
            augment=self.augment
        )

        train_size = int(self.train_val_test_split[0] * len(full_dataset))
        val_size = int(self.train_val_test_split[1] * len(full_dataset))
        test_size = len(full_dataset) - train_size - val_size
        
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(full_dataset, [train_size, val_size, test_size])

    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, collate_fn=batch_preparation_img2seq, worker_init_fn=worker_init_fn)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, collate_fn=batch_preparation_img2seq, worker_init_fn=worker_init_fn)
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, collate_fn=batch_preparation_img2seq, worker_init_fn=worker_init_fn)

In [101]:
# data_module = PretrainingLinesDataset(config.data)
data_module = OMRDataModule(x, y, w2i, i2w, batch_size=1, num_workers=30, augment=False)
model = SMT(config=config.model_setup, w2i=w2i, i2w=i2w)

In [102]:
early_stopping = EarlyStopping(monitor=config.experiment.metric_to_watch, min_delta=0.01, patience=2, mode="min", verbose=True)
wandb_logger = WandbLogger(project='FP_SMT', group=f"SMTppNEXT", name=f"GrandStaff", log_model=True)
trainer = Trainer(max_epochs=config.experiment.max_epochs, 
                    check_val_every_n_epoch=config.experiment.val_after, callbacks=[early_stopping], logger=wandb_logger,
                    precision='16-mixed')

trainer.fit(model, data_module)

# model = SMT.load_from_checkpoint(checkpointer.best_model_path)

# trainer.test(model, datamodule=data_module)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                 | Params | Mode 
---------------------------------------------------------------
0 | encoder       | ConvNextEncoder      | 5.5 M  | train
1 | decoder       | Decoder              | 5.4 M  | train
2 | positional_2D | PositionalEncoding2D | 0      | train
3 | loss          | CrossEntropyLoss     | 0      | train
---------------------------------------------------------------
10.9 M    Trainable params
0         Non-trainable params
10.9 M    Total params
43.660    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

[Prediction] - XXXXXXXXXXXXXXXXXXXXXXXXXXXX=:|!XXXXXXXXXXXXXXXXXXXXXXXXXXX=:|!=:|!=:|!=:|!LL>=:|!LL>*k[b-]XXX*k[b-e-a-d-g-c-]X*k[b-e-a-d-g-c-]X*k[b-e-a-d-g-c-]X*k[b-e-a-d-g-c-]XXXXXX*k[b-e-a-d-g-c-]FFFFFFFFFFFFFFFFFFFFFFFFFFFFFF*k[b-]aaXXX*k[b-e-a-d-g-c-]X*k[b-e-a-d-g-c-]*k[f#c#g#d#a#]aaX=:|!=:|!=:|!*^X*k[b-e-a-d-g-c-]X*k[b-e-a-d-g-c-]X=:|!LL>*k[b-]*k[f#c#]X=:|!LL>=:|!LL>=:|!=:|!=:|!
[GT] - <bos>*clefF4	*clefG2
*k[b-e-a-d-g-c-]	*k[b-e-a-d-g-c-]
*M44	*M44
*met(c)	*met(c)
=-	=-
8e-L	4eee-
8a- 8cc-	.
8a- 8cc-	8r
8a- 8cc-J	8aaa-
8e-L	8ggg-L
8a- 8cc- 8dd	8fff
8a- 8cc- 8dd	8eee-
8a- 8cc- 8ddJ	8dddJ
=	=
8e-L	8fffL
8g- 8b-	8eee-J
8g- 8b-	8r
8g- 8b-J	8ee-
8e-L	20ddLL
.	20ee-
.	20ee
8g 8b- 8dd-	.
.	20ff
.	20gg-JJ
8g 8b- 8dd-	16gg-LL
.	16aa-
8g 8b- 8dd-J	16aa
.	16bb-JJ
=	=
8e-L	16ccc-LL
.	16ccc
8a- 8cc-	16ddd-
.	16dddJJ
8a- 8cc-	8eee-L
8a- 8cc-J	8aaa-J
8e-L	8ggg-L
8a- 8cc- 8dd	8fff
8a- 8cc- 8dd	8eee-
8a- 8cc- 8ddJ	8dddJ
=	=
4e- 4g- 4b- 4ee-	16eee-LL
.	16fff
.	16eee-
.	16dddJJ
8r	16eee-LL
.	16ggg-

Training: |          | 0/? [00:00<?, ?it/s]

In [16]:
checkpoint = torch.load("model_checkpoint.ckpt")

# Print the keys in the checkpoint to see what information is stored
print(checkpoint.keys())
model_state_dict = checkpoint['state_dict']
optimizer_state_dict = checkpoint['optimizer_states']
epoch = checkpoint['epoch']
global_step = checkpoint['global_step']

print("Model State Dict Keys:", model_state_dict.keys())

print("Epoch:", epoch)
print("Global Step:", global_step)

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecision', 'hparams_name', 'hyper_parameters'])
Model State Dict Keys: odict_keys(['encoder.downsample_layers.0.0.weight', 'encoder.downsample_layers.0.0.bias', 'encoder.downsample_layers.0.1.weight', 'encoder.downsample_layers.0.1.bias', 'encoder.downsample_layers.1.0.weight', 'encoder.downsample_layers.1.0.bias', 'encoder.downsample_layers.1.1.weight', 'encoder.downsample_layers.1.1.bias', 'encoder.downsample_layers.2.0.weight', 'encoder.downsample_layers.2.0.bias', 'encoder.downsample_layers.2.1.weight', 'encoder.downsample_layers.2.1.bias', 'encoder.stages.0.0.gamma', 'encoder.stages.0.0.dwconv.weight', 'encoder.stages.0.0.dwconv.bias', 'encoder.stages.0.0.norm.weight', 'encoder.stages.0.0.norm.bias', 'encoder.stages.0.0.pwconv1.weight', 'encoder.stages.0.0.pwconv1.bias', 'encoder.stages.0.0.pwconv2.weight', 'encoder.stages.0.0.pwconv2.bias

In [85]:
trainer.validate(model=model, datamodule=data_module, ckpt_path="model_checkpoint.ckpt", verbose=True)

Restoring states from the checkpoint path at model_checkpoint.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at model_checkpoint.ckpt


Validation: |          | 0/? [00:00<?, ?it/s]