In [1]:
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [15]:
def prepare_line(line):
    return [float(i) for i in line[:-1].split(" ")]

def read_file(path):
    with open(path, "r") as f:
        lines = f.readlines()
    X = [[]]
    for i in range(len(lines)):
        if ";" in lines[i]:
            X.append([])
        else:
            X[-1].append(prepare_line(lines[i]))
    if X[-1] == []:
        return X[:-1]
    return X

In [3]:
folder = "./data/"
files = [folder + f for f in os.listdir(folder)]
files

['./data/Baba.txt',
 './data/Khosh.txt',
 './data/MohandesBadMohandes.txt',
 './data/Ast.txt',
 './data/MohandesBakht.txt',
 './data/AstMohandesAst.txt',
 './data/AstMohandes.txt',
 './data/BakhtAstMobark.txt',
 './data/MohandesAst.txt',
 './data/BakhtKhoshMobark.txt',
 './data/BadAstMobark.txt',
 './data/BabaMobark.txt',
 './data/BadMohandes.txt',
 './data/KhoshKhosh.txt',
 './data/KhoshKhoshBaba.txt',
 './data/MobarkKhosh.txt',
 './data/MohandesBabaAst.txt',
 './data/AstBad.txt',
 './data/BakhtKhosh.txt',
 './data/AstKhosh.txt',
 './data/MohandesBabaMohandes.txt',
 './data/MobarkBadBakht.txt',
 './data/MobarkKhoshBad.txt',
 './data/BabaKhoshKhosh.txt',
 './data/MobarkAst.txt',
 './data/BabaBaba.txt',
 './data/BakhtBad.txt',
 './data/KhoshAst.txt',
 './data/MohandesKhosh.txt',
 './data/KhoshMobark.txt',
 './data/MohandesBaba.txt',
 './data/MohandesBabaBakht.txt',
 './data/BabaBabaKhosh.txt',
 './data/KhoshBad.txt',
 './data/KhoshAstAst.txt',
 './data/KhoshMohandesBad.txt',
 './data/Mo

In [19]:
CLASSES = [
    "Ast",
    "Baba",
    "Bad",
    "Bakht",
    "Khosh",
    "Mobark",
    "Mohandes",
]

CLASSES2INDEX = {k:v for v,k in enumerate(CLASSES)}

INDEX2CLASSES = {v:k for v,k in enumerate(CLASSES)}

def output_ctc(output):
    ans = ""
    tmp = "" if output[0] == "_" else output[0]
    ans = ans + tmp
    for i in range(len(output)):
        if output[i] != tmp and output[i] != "_":
            tmp = output[i]
            ans = ans + tmp
        if output[i] == "_":
            tmp = output[i]
    return ans

RANDOM_SEED = 42
import re

def line2float(line: str):
        return [float(i) for i in line[:-1].split(" ")]
    
def read_file(file:str):
    with open(file, "r") as f:
        lines = f.readlines()
    x = [[]]
    for i in range(len(lines)):
        if ";" in lines[i]:
            x.append([])
        else:
            x[-1].append(line2float(lines[i]))
    if x[-1] == []:
        return x[:-1]
    return x

def file2data(file: str, data):
        X = read_file(file=file)
        y = file.split("/")[-1].split(".")[0] if "/" in file else file.split(".")[0]
        y = re.findall('[A-Z][^A-Z]*', y)
        for x in X:
            data.append([np.array(x), [CLASSES2INDEX[move] for move in y]])
        return data

In [20]:
data = []
for f in files:
    data = file2data(f, data)
len(data)

352

In [21]:
BLANK = len(CLASSES)
def pad_x(x, length):
    return np.pad(x, ((0, length - x.shape[0]), (0, 0)), mode='constant', constant_values = 0.)

def pad_y(y, length):
    out = [i for i in y]
    for i in range(length - len(y)):
        out.append(BLANK)
    return out

In [22]:
def extract_batch(data, index, bs):
    last_index = min(len(data), (index + 1) * bs)
    batch = data[index * bs: last_index]
    max_x = max([d[0].shape[0] for d in batch])
    max_y = max([len(d[1]) for d in batch])
    X = []
    y = []
    target_lengths = []
    for d in batch:
        X.append(pad_x(d[0], max_x))
        y.append(pad_y(d[1], max_y))
        target_lengths.append(len(d[1]))
    y = np.array(y)
    X = np.array(X)
    X = torch.from_numpy(X).float().to(device)
    y = torch.from_numpy(y).int().to(device)
    target_lengths = torch.from_numpy(np.array(target_lengths)).int().to(device)
    return X, y, target_lengths, len(batch)

In [23]:
np.random.shuffle(data)

In [25]:
SPLIT = 0.75
split_index = int(SPLIT * len(data))
train = data[:split_index]
test = data[split_index:]
len(train), len(test)

(264, 88)

In [26]:
extract_batch(train, 0, 11)[0].shape

torch.Size([11, 62, 10])

In [27]:
def output_ctc(out):
    s = ""
    tmp = "" if out[0] == "_" else out[0]
    s = s + tmp
    for i in range(len(out)):
        if out[i] != tmp and out[i] != "_":
            tmp = out[i]
            s = s + tmp
        if out[i] == "_":
            tmp = out[i]
    return s
output_ctc("__11112__233_333_4__5")

'1223345'

In [28]:
def score(yhat, answer):
    s = 0
    for i in range(min(len(yhat), len(answer))):
        s += yhat[i] == answer[i]
    return s
score("356", "45")

1

In [29]:
HIDDEN_SIZE1 = 64
HIDDEN_SIZE2 = 128
HIDDEN_SIZE3 = 256
HIDDEN_SIZE4 = 32
D, NL = 1, 1
NUM_CLASS = BLANK + 1

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn1 = nn.LSTM(10, HIDDEN_SIZE1, NL, batch_first=True)
        self.rnn2 = nn.LSTM(D*HIDDEN_SIZE1, HIDDEN_SIZE2, NL, batch_first=True)
        self.rnn3 = nn.LSTM(D*HIDDEN_SIZE2, HIDDEN_SIZE3, NL, batch_first=True)
        self.L1 = nn.Linear(D * HIDDEN_SIZE3, HIDDEN_SIZE4)
        self.L2 = nn.Linear(HIDDEN_SIZE4, NUM_CLASS)
    def forward(self, x):
        out, _ = self.rnn1(x) # (N, L, D*HIDDEN_SIZE)
        #print(out.shape)
        out, _ = self.rnn2(out) # (N, L, D*HIDDEN_SIZE)
        #print(out.shape)
        out, _ = self.rnn3(out) # (N, L, D*HIDDEN_SIZE)
        #print(out.shape)
        l , bs = out.shape[1], out.shape[0]
        #print(out.shape)
        out = self.L1(out)
        #print(out.shape)
        out = nn.Sigmoid()(out)
        out = self.L2(out)
        #print(out.shape)
        out = out.transpose(0, 1)
        return out

In [30]:
m = Model().to(device)
m(extract_batch(train, 0, 5)[0]).shape, sum(p.numel() for p in m.parameters() if p.requires_grad)

(torch.Size([46, 5, 8]), 522536)

In [31]:
train = sorted(train, key=lambda x: len(x[1]), reverse=False)

In [32]:
LR = 1e-3
EPOCH = 100
BATCH_SIZE = 11
model = Model().to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=LR)
criterion = nn.CTCLoss(blank=BLANK)

In [None]:
LOSS = []
model.train()
for ep in tqdm(range(1, EPOCH + 1)):
    LOSS.append(0)
    np.random.shuffle(train)
    for b in range(len(train) // BATCH_SIZE):
        X, y, target_lengths, bs = extract_batch(train, b, BATCH_SIZE)
        optimizer.zero_grad()
        yp = model(X)
        yp = nn.functional.log_softmax(yp, dim=2)
        input_lengths = torch.LongTensor([yp.shape[0]] * bs).to(device)
        loss = criterion(yp, y, input_lengths, target_lengths)
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            LOSS[-1] += (float(loss) / (len(train) // BATCH_SIZE))

 84%|██████████████████████████████████████████████████████████████████████████▊              | 84/100 [01:14<00:12,  1.31it/s]

In [None]:
plt.plot(list(range(1, len(LOSS) + 1)), LOSS)

In [None]:
def evaluate(data, model_for_eval, bs, erorr=False):
    acc, tot, word_acc, word_tot = 0, 0, 0, len(data)
    L = len(data) // bs
    ERORRS = []
    for l in tqdm(range(L)):
        X, y, _, _ = extract_batch(data, l, bs)
        model_for_eval.eval()
        with torch.no_grad():
            yp = model_for_eval(X)
            yp = nn.functional.log_softmax(yp, dim=2)
            for j in range(bs):
                ypp = torch.argmax(yp[:, j, :], dim=1)
                ypp = list(ypp)
                s = ""
                for i in ypp:
                    if i == 2:
                        s = s + "_"
                    else:
                        s = s + str(int(i))
                ans = [str(int(e)) for e in y[j]]
                ans = [a for a in ans if a != "2"]
                ans = "".join(ans)
                s = output_ctc(s)
                if erorr and s != ans:
                    ERORRS.append([s, ans, data[l * bs + j][0], l * bs + j])
                word_acc += ans == s
                tot += len(ans)
                for cahr in range(len(ans)):
                    if cahr >= len(s):
                        break
                    acc += s[cahr] == ans[cahr]
    if erorr:
        return word_acc / word_tot, acc / tot, ERORRS  
    return word_acc / word_tot , acc / tot
evaluate(train, model, BATCH_SIZE)

In [None]:
X, y, _, _ = extract_batch(train, b, BATCH_SIZE)
model.eval()
with torch.no_grad():
    yp = model(X)
    yp = nn.functional.log_softmax(yp, dim=2)
    for j in range(BATCH_SIZE):
        ypp = torch.argmax(yp[:, j, :], dim=1)
        ypp = list(ypp)
        s = ""
        for i in ypp:
            if i == 2:
                s = s + "_"
            else:
                s = s + str(int(i))
        ans = "".join([str(int(e)) for e in y[j] if int(e) != 2])
        print(f"{output_ctc(s)} : {ans} ---> {ans == output_ctc(s)}")

In [None]:
evaluate(test, model, 11)

In [None]:
X, y, _, _ = extract_batch(test, 0, BATCH_SIZE)
model.eval()
with torch.no_grad():
    yp = model(X)
    yp = nn.functional.log_softmax(yp, dim=2)
    for j in range(BATCH_SIZE):
        ypp = torch.argmax(yp[:, j, :], dim=1)
        ypp = list(ypp)
        s = ""
        for i in ypp:
            if i == 2:
                s = s + "_"
            else:
                s = s + str(int(i))
        ans = "".join([str(int(e)) for e in y[j] if int(e) != 2])
        print(f"{output_ctc(s)} : {ans} ---> {ans == output_ctc(s)}")