In [4]:
import pandas as pd
import numpy as np

In [5]:
df = pd.read_hdf("texts.hdf5", key="df", index_col=0)
df["K"] = 0

In [6]:
df.tail()

Unnamed: 0,transcription,manuscript,page_id,line_id,CER,lang,K
307823,bant ascamis lop.Si,bnf_fr_22549_sept_sages,0,97,25.0,fro,0
253434,de tret qͥ se tenoiet anre seignor com,BnF_fr_411_saintLambert_microfilm,0,138,9.3023,fro,0
354634,C,bnf_fr_24428_bestiaire,13,19,0.0,fro,0
155174,il ne vit goute. L adamo,bnf_fr_22549_sept_sages,11,139,8.3333,fro,0
110968,,bnf_fr_22549_sept_sages,3,83,93.5484,fro,0


# Get all manuscripts

In [47]:
import copy

def get_manuscripts_and_lang_kfolds(dataframe, k=0, per_k=2):
    all_data = {}
    for lang, mss in dataframe.set_index(['lang', 'manuscript']).sort_index().index.unique():
        if lang not in all_data:
            all_data[lang] = []
        all_data[lang].append(mss)

    train, dev, test = [], [], []
    local_data = copy.deepcopy(all_data)
    for lang in all_data:
        nb_mss = len(local_data[lang])

        for i in range(per_k):
            dev.append(local_data[lang].pop(k*per_k+i))

        for i in range(per_k):
            test.append(local_data[lang].pop(k*per_k+i))

        train.extend(local_data[lang])
    return (
        df.loc[df.manuscript.isin(train)],
        df.loc[df.manuscript.isin(dev)],
        df.loc[df.manuscript.isin(test)]
    )

# Generate K-Folds class

In [11]:
KS = 10

#kf = KFold(n_splits=KS, shuffle = True, random_state = 2)

for unique_bin in df.bin.unique():
    ids = list(np.array_split(np.array(df[df.bin == unique_bin].index), KS))
    for k, k_ids in enumerate(ids):
        df.loc[k_ids, "K"] = k

df.tail()

Unnamed: 0,idx,bin,transcription,manuscript,page_id,line_id,CER,K
265020,70139,5,duigitee nt oi ntenꝰ solu põ sfista om̃ ideñ...,BIS-193,4,131,55.0625,9
265021,16376,0,ñ. qr forͣ eiꝰ nͨ re nͦ intłłcu ẽ młtiplił.,WettF0015,0,72,6.523438,9
265022,187671,1,qimout aenuis lidona.Et en la fin quͣ̃t,bnf_fr_412_wauchier,29,83,11.765625,9
265023,69102,3,pippuore lecit lie fa : tiste,Latin6395,1,112,39.28125,9
265024,299512,0,cist le cocatris ⁊ cue,bnf_fr_24428_bestiaire,13,2,9.523438,9


In [12]:
df.groupby("K").bin.value_counts().sort_index()

K  bin
0  0      7248
   1      3254
   2      2387
   3      2403
   4      2980
          ... 
9  5      2880
   6      2348
   7      1796
   8       964
   9       242
Name: bin, Length: 100, dtype: int64

In [18]:
def get_kfold_train_test(
    dataframe: pd.DataFrame,
    k=0
):
    # Right now only deal with train and test
    ks = list(range(10))
    test = ks[k]
    train = ks[:k]+ks[k+1:]
    dev = train.pop(0)
    return (
        dataframe[dataframe["K"].isin(train)],
        dataframe[dataframe['K'] == dev],
        dataframe[dataframe['K'] == test]
    )

def get_features(
    dataframe: pd.DataFrame
):
    return (
        dataframe["bin"].to_numpy().astype(int),
        dataframe["transcription"].tolist()
    )
    #p = np.random.permutation(len(i))
    #return i[p], c[p], e[p], f[p]

In [21]:
import torch.nn as nn
import torch
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from sklearn.metrics import ConfusionMatrixDisplay, classification_report, accuracy_score
import tqdm.autonotebook as tqdm
import json

        
def var(X):
    return torch.from_numpy(X).cuda()

In [24]:
class LstmModel(nn.Module):
    def __init__(self, input_dim, output_dim, device="cuda:0", features = None):
        super(LstmModel, self).__init__()
        self.inp = input_dim + 4
        self.out = output_dim
        self.features = ("[UNK]", "[PAD]", "[BOS]", "[EOS]", *features) or ()
        
        _EMB_SIZE = 100
        _HID_SIZE = 128
        
        self._emb = nn.Sequential(nn.Embedding(self.inp, _EMB_SIZE), nn.Dropout(.1))
        self._lstm = nn.LSTM(_EMB_SIZE, hidden_size=_HID_SIZE, bidirectional=True, batch_first=True)
        # HID_SIZE*4 because CONCAT first and last
        self._lin = nn.Sequential(nn.Dropout(.1), nn.Linear(_HID_SIZE*4, _HID_SIZE*2), nn.Linear(_HID_SIZE*2, self.out))
        self.to(device)
        
        
    @classmethod
    def from_input(cls, transcriptions, classes, device="cuda:0"):
        features = tuple(set([char for sentence in transcriptions for char in sentence]))
        input_dim = len(features)+2
        output_dim = len(np.unique(classes))
        return cls(features=features, input_dim=input_dim, output_dim=output_dim, device=device)
        
    def forward(self, matrix, lengths):
        matrix = self._emb(matrix)
        matrix = self.pack(matrix, lengths)
        matrix, z = self._lstm(matrix)
        matrix, _ = pad_packed_sequence(matrix, batch_first=True)
        first = matrix[:, 0, :]
        last = matrix[range(matrix.shape[0]), lengths-1, :]
        
        return self._lin(torch.cat([first, last], dim=-1))
        
    def encode(self, string):
        if isinstance(string, list):
            return np.array([self.encode(s) for s in string])
        # Add BOS and EOS ([2] and [3])
        return np.array([2] + list([self.features.index(c) if c in self.features else 0 for c in string]) + [3])
        
    def get_batches(self, X, Y=None, batch_size=256):
        samples = X.shape[0]
        for index in range(0, samples, batch_size):
            matrix = X[index:min(index+batch_size,samples)]
            lengths = [len(x) for x in matrix] # Add 2 for EOS/BOS
            matrix = self.pad([torch.tensor(vector) for vector in matrix])
            yield (
                matrix.cuda(),
                torch.tensor(lengths),
                var(Y[index:min(index+batch_size,samples)]) if Y is not None else None
            )
    
    def check_and_encode(self, inputs):
        if isinstance(inputs, list):
            # It's a raw input
            return self.encode(inputs)
        return inputs
        
    def fit(self, inputs, truthes, 
              dev_set,
              epochs=1000, max_bad_epochs=10, batch_size=128, lr=5e-3,
             delta=.005, use_loss=True):
        
        criterion = torch.nn.CrossEntropyLoss() 
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        
        best = float("-inf")
        if use_loss:
            best = float("inf")
        bad_epochs = 0
        best_params = self.state_dict()
        
        dev_x, dev_y = dev_set
        
        inputs = self.check_and_encode(inputs)
        dev_x = self.check_and_encode(dev_x)
        
        for epoch in (pbar := tqdm.tqdm(range(epochs), position=0, leave=True)):
            nb_batches = truthes.shape[0] // batch_size + int(truthes.shape[0] % batch_size)
            
            # Shuffle input
            shuffle_indices = np.random.permutation(truthes.shape[0])
            
            epoch_loss = torch.tensor(.0)
            # ToDo : Start the fucking encoding of Transcription.
            for batch_idx, (xs, length, ys) in enumerate(
                tqdm.tqdm(
                    self.get_batches(inputs[shuffle_indices], truthes[shuffle_indices], batch_size=batch_size),
                    position=1, 
                    leave=True,
                    desc="Batch (Train)"
                )
            ):
                outputs = self(xs, length)

                loss = criterion(outputs, ys)
                epoch_loss += loss.item()
                loss.backward()

            # update parameters
            optimizer.step()
            optimizer.zero_grad()
            epoch_loss = epoch_loss / (batch_idx+1)

            self.eval()
            pred_dev = self.pred(dev_x, batch_size=batch_size, _verbose=False)
            acc = accuracy_score(pred_dev, dev_y)
            self.train()
            
            factor = 100
            if use_loss:
                factor = 1
            
            pbar.set_description(f'BAD:{bad_epochs:0>2} LOSS:{epoch_loss.item():.2f} ACC:{acc*100:.1f} BEST:{best*factor:.1f} ')
            
            if use_loss:
                if abs(epoch_loss - best) > delta and epoch_loss < best:
                    best = epoch_loss
                    bad_epochs = 0
                    best_params = self.state_dict()
                else:
                    bad_epochs += 1
                    if bad_epochs == max_bad_epochs + 1:
                        break
            else:
                if abs(acc - best) > delta and acc > best:
                    best = acc
                    bad_epochs = 0
                    best_params = self.state_dict()
                else:
                    bad_epochs += 1
                    if bad_epochs == max_bad_epochs + 1:
                        break

            #if accum_loss < 2e-5:
            #    break
        print("Loading best params...")
        self.load_state_dict(best_params)
        self.eval()
        
    def pad(self, matrix):
        return pad_sequence(matrix, batch_first=True, padding_value=self.features.index("[PAD]"))
    
    def pack(self, matrix, lengths):
        return pack_padded_sequence(matrix, lengths, batch_first=True, enforce_sorted=False)
        
    def pred(self, inputs, batch_size=256, _verbose: bool = False):
        out = []
        
        if _verbose:
            deco = tqdm.tqdm
        else:
            deco = lambda x: x
        
        inputs = self.check_and_encode(inputs)
            
        for (xs, length, _) in deco(self.get_batches(inputs, batch_size=batch_size)):
            out.extend(self(xs, length).argmax(dim=-1).cpu().flatten().tolist())
        return np.array(out)
    
    def save(self, name):
        torch.save(self.state_dict(), f"{name}.pt")
        with open(f"{name}.json", "w") as f:
            json.dump(self.features, f)
            
train, dev, test = get_kfold_train_test(df, 4)

YCs, XTranscriptions = get_features(train)
YC3s, XTranscription3s = get_features(dev)
YC2s, XTranscription2s = get_features(test)

lstm = LstmModel.from_input(XTranscriptions, YCs)
lstm.fit(XTranscriptions, YCs, (XTranscription3s, YC3s), delta=.01)


  return np.array([self.encode(s) for s in string])


  0%|          | 0/1000 [00:00<?, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Batch: 0it [00:00, ?it/s]

Loading best params...


In [27]:
preds = lstm.pred(lstm.check_and_encode(XTranscription2s))
((YC2s < 2) == (preds < 2)).sum() / len(preds)

  return np.array([self.encode(s) for s in string])


0.7975399939631753

In [28]:
lstm.save("good-model-lstm")