In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import automata as dfa
import utils
from tqdm import tqdm

## Global Variables

In [None]:
# Input locations. ATT_PATH and DATA_PATH are handed to utils.data_automata_loading,
# NAMES_PATH to utils.lang_names.
# NOTE(review): the values look like placeholders — point them at real paths before running.
ATT_PATH = "./where"
DATA_PATH = "./where"
NAMES_PATH = "./.../names.txt"

# Batch size of the training DataLoader built in makeloader.
BATCH_SIZE = 200

# Fraction of each dataset that makeloader holds out as the test split.
TEST_SIZE = 0.2

# Number of epochs run by getstats, and the learning rate it passes to the optimizer.
N_EPOCH = 20
LEARNING_RATES = 0.001

# Seed torch's global random number generator.
torch.manual_seed(689)
# Everything is pinned to the CPU. To use a GPU when one is available, build the device with
# torch.device("cuda" if torch.cuda.is_available() else "cpu") instead.
DEVICE = torch.device("cpu")

## Load the data

In [None]:
# Names to process and the accompanying `extensions` value, both read via utils.lang_names.
names, extensions = utils.lang_names(NAMES_PATH)

# One entry per name: the loaded dataset and the automaton that goes with it.
datas, automata = {}, {}
for name in names:
    datas[name], automata[name] = utils.data_automata_loading(
        ATT_PATH,
        DATA_PATH,
        name,
        extensions,
        return_automata=True,
    )

In [None]:
def makeloader(torchdata:dfa.TorchData)-> tuple[DataLoader, list[torch.Tensor]]:
    """Split a dataset into a shuffled training loader and a single test batch.

    The split is random: a fraction TEST_SIZE of `torchdata` is held out and
    the remainder is used for training.

    Args:
        torchdata: dataset to split.

    Returns:
        A pair `(train_loader, test_batch)`. `train_loader` yields shuffled
        batches of BATCH_SIZE samples from the training part; `test_batch` is
        the first (and only) batch of a loader whose batch size equals the
        size of the held-out part, i.e. the whole test split collated at once.
    """
    train_part, test_part = random_split(torchdata, [1-TEST_SIZE, TEST_SIZE])

    train_loader = DataLoader(train_part, batch_size=BATCH_SIZE, shuffle=True)

    # With batch_size == len(test_part) the first batch already contains every test sample.
    full_test_loader = DataLoader(test_part, batch_size=len(test_part))
    test_batch = next(iter(full_test_loader))

    return train_loader, test_batch

## Model

In [None]:
def _train_and_evaluate(model, trainloader, test, loss, optim, stats):
    """Train `model` for N_EPOCH epochs and log loss / test accuracy into `stats`.

    Args:
        model: network called as `model(words, lengths)` and exposing
            `predict(words, lengths)`.
        trainloader: iterable of `(words, lengths, labels)` batches.
        test: one collated batch; `test[0]`, `test[1]` and `test[2]` are used
            as words, lengths and labels respectively.
        loss: callable `loss(output, labels)` returning a scalar tensor.
        optim: optimizer already bound to the parameters of `model`.
        stats: mapping whose "losses" and "acc" entries are lists; one loss is
            appended per batch and one accuracy (in percent) per epoch.
    """
    optim.zero_grad()

    for epoch in range(N_EPOCH):
        for words, lengths, labels in tqdm(trainloader, desc=f"Epoch {epoch+1}...", ncols=75):
            labels = labels.to(dtype = torch.float32)

            batch_loss = loss(model(words, lengths), labels)
            batch_loss.backward()

            optim.step()
            optim.zero_grad()

            # Keep a plain float: storing the tensor would keep every batch's
            # autograd graph alive for the whole run.
            stats["losses"].append(batch_loss.item())

        with torch.no_grad():
            predictions = model.predict(test[0], test[1])
            acc = torch.sum(predictions == test[2].reshape(-1), dim=0) * 100 / len(test[2])
        stats["acc"].append(acc.item())

        # An empty training loader leaves no loss to report.
        last_loss = stats["losses"][-1] if stats["losses"] else float("nan")
        last_acc = stats["acc"][-1]
        # Values are pulled into locals so the f-string has no nested quotes
        # (reusing the outer quote inside an f-string needs Python >= 3.12).
        print(f"Done! Trainloss: {last_loss:.6f}, Test accuracy: {last_acc:.4f}")


def getstats(data:dfa.TorchData, automaton:dfa.DFA, loss = nn.BCELoss(), optimizer = torch.optim.Adam, returnmodel=False):
    """Train an AutomataRNN and a ParametrizeRNN on the same split and collect their stats.

    Both models are built from `automaton`, trained for N_EPOCH epochs on the
    training split of `data` and evaluated on the held-out batch after every
    epoch.

    Args:
        data: dataset handed to makeloader to obtain the train loader and test batch.
        automaton: automaton both models are constructed from.
        loss: loss callable applied to `(model_output, float_labels)`.
        optimizer: optimizer class, instantiated as
            `optimizer(model.parameters(), lr=LEARNING_RATES)` for each model.
        returnmodel: selects the return shape, see below.

    Returns:
        If `returnmodel` is true:
            `(losses_noparam, noparam), (losses_withparam, withparam)`, where each
            `losses_*` is the list of per-batch training losses (floats).
        Otherwise:
            `(statsnoparam, statsparam)`, the two stats mappings. Only their
            "losses" (per batch) and "acc" (per epoch, in percent) entries
            are filled by this function.
    """
    trainloader, test = makeloader(data)

    # Each name must be its own list element; a single comma-joined string
    # would register one key instead of five.
    # NOTE(review): assumes utils.initstats returns a dict mapping every name
    # to an empty list — confirm against utils.
    stat_names = ("losses", "l2", "linf", "targetdist", "acc")

    noparam = dfa.AutomataRNN(automaton, device=DEVICE).to(DEVICE)
    # NOTE(review): `target` is computed but never used, so the "l2", "linf"
    # and "targetdist" stats are never filled. Presumably they were meant to
    # measure a distance between the trained weights and `target` — confirm
    # and implement, or drop both.
    target = dfa.sigmoid_to_tanh(dfa.dfa2srn(automaton.transition.T, automaton.finites))

    statsnoparam = utils.initstats(list(stat_names))
    _train_and_evaluate(
        noparam,
        trainloader,
        test,
        loss,
        optimizer(noparam.parameters(), lr=LEARNING_RATES),
        statsnoparam,
    )

    print("\nNon parametrized RNN done!\n")

    withparam = dfa.ParametrizeRNN(automaton, device=DEVICE).to(DEVICE)

    statsparam = utils.initstats(list(stat_names))
    _train_and_evaluate(
        withparam,
        trainloader,
        test,
        loss,
        optimizer(withparam.parameters(), lr=LEARNING_RATES),
        statsparam,
    )

    print("\nParametrized RNN done!")

    if returnmodel:
        return (statsnoparam["losses"], noparam), (statsparam["losses"], withparam)
    return statsnoparam, statsparam
            