# Table of Contents

1. [Imports and Variables](#import-variable)
    1. [Imports](#imports)
    2. [Set Variables](#variables)
2. [Tokenizer](#tokenizer)
    1. [Tokenizer Code *](#tokenizer-code)
    2. [Create/Load Tokenizer *](#tokenizer-load)
3. [Logic](#logic)
    1. [Nodes](#nodecode)
    2. [Generate Formula](#generate-formula)
    3. [Generate Premise](#generate-premise)
    4. [Generate Key](#generate-key)
    5. [Create Lookup Table](#lookup)
    6. [Translate from Sentence (needs to be created) *](#to-formula)
4. [Model](#model)
    1. [Embedding and FeedForward](#embedding)
    2. [Attention](#attention)
    3. [Encoder](#encoder)
    4. [Decoder](#decoder)
    5. [Transformer](#transformer)
5. [Dataset Generation](#generate-dataset)
    1. [Dataset for Simple Models](#dataset-simple)
6. [Training and Testing](#training-testen)
    1. [Get Batch](#get-batch)
    2. [Code for Training](#train-code)
    3. [Code for Testing](#test-code)
    4. [Model 1](#model1)

TODO: create Model 2 and its training code, and adapt every section marked with * so that it also works for Model 2.

<a id="import-variable"></a>
## Imports and Variables

<a id="imports"></a>
### Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

import pandas as pd
import numpy as np

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: cuda


<a id="variables"></a>
### Set Variables

In [2]:
# Global settings shared by the dataset generation and training cells.
epochs = 100_000              # number of optimisation steps (one batch per step)
batch_size = 512              # examples per batch
input_window_simple = 100     # fixed (padded) sequence length for the simple model
input_window_complex = 200    # sequence length reserved for the complex model
dataset_size = 100_000        # number of conclusions to turn into training data
learning_rate = 1e-4          # optimiser learning rate

<a id="tokenizer"></a>
## Tokenizer

<a id="tokenizer-code"></a>
### Tokenizer Code

In [3]:
class Tokenizer():
    """Convert formula strings (or sentences) to token ids and back.

    `encoder` maps token string -> id and `decoder` maps id -> token string.
    When `scentences` is true, `tokenize` splits on whitespace; otherwise it
    uses the formula tokenizer.
    """

    @staticmethod
    def create_vocab(tokens):
        """Build a vocabulary dict: four special tokens first, then `tokens`.

        Ids are assigned in order of first appearance. A token that is already
        present keeps its first id, so ids stay unique and contiguous even if
        `tokens` contains duplicates.
        """
        vocab = dict()
        for special in ('<PAD>', '<UNK>', '<PRED>', '<END>'):
            vocab[special] = len(vocab)
        for token in tokens:
            key = str(token)
            if key not in vocab:
                vocab[key] = len(vocab)
        return vocab

    def __init__(self, vocab, scentences= False):
        self.encoder = vocab
        self.scentences = scentences
        self.decoder = {index: str(token) for token, index in vocab.items()}

    def __len__(self):
        """Return the vocabulary size."""
        return len(self.encoder)

    def tokenize(self, text):
        """Split `text` into tokens using the mode chosen at construction."""
        if self.scentences:
            return self.tokenize_sentence(text)
        return self.tokenize_formula(text)

    def tokenize_formula(self, text):
        """Split a formula string into tokens, ignoring spaces.

        A '<' starts a three-character token (e.g. '<->'), 'p' and '-' start a
        two-character token (e.g. 'p0', '->'); every other character is a
        token on its own.
        """
        compact = text.replace(' ', '')
        output = []
        index = 0
        while index < len(compact):
            char = compact[index]
            if char == '<':
                width = 3
            elif char == 'p' or char == '-':
                width = 2
            else:
                width = 1
            output.append(compact[index:index + width])
            index += width
        return output

    def tokenize_sentence(self, text):
        """Split a sentence into whitespace-separated words.

        Empty strings produced by repeated spaces are dropped. This is a
        minimal word tokenizer; punctuation stays attached to its word.
        """
        return text.split()

    def encode(self, tokens):
        """Map tokens to ids; unknown tokens map to the '<UNK>' id."""
        unknown = self.encoder['<UNK>']
        return [self.encoder.get(token, unknown) for token in tokens]

    def decode(self, indices):
        """Map ids back to tokens; unknown ids map to '<UNK>'."""
        return [self.decoder.get(index, '<UNK>') for index in indices]

<a id="tokenizer-load"></a>
### Create/Load Tokenizer

In [4]:
def create_scentence_tokenizer():
    """Return a sentence-mode Tokenizer whose vocabulary holds only the placeholder word 'temp'."""
    placeholder_vocab = Tokenizer.create_vocab(['temp'])
    return Tokenizer(placeholder_vocab, scentences=True)

def create_formula_tokenizer():
    """Return a formula-mode Tokenizer covering brackets, connectives and the atoms p0..p3."""
    formula_symbols = [
        '(', ')', ',', '|', '&', '->', '<->', '!',
        'p0', 'p1', 'p2', 'p3',
        '.', '{', '}',
    ]
    formula_vocab = Tokenizer.create_vocab(formula_symbols)
    return Tokenizer(formula_vocab, scentences=False)

# Module-level tokenizers plus the padding id (kept as a tensor on `device`).
formula_tokenizer = create_formula_tokenizer()
scentence_tokenizer = create_scentence_tokenizer()
_pad_index = formula_tokenizer.encoder['<PAD>']
PAD_ID = torch.tensor(_pad_index).to(device)

<a id="logic"></a>
## Logic

<a id="nodecode"></a>
### Nodes

In [5]:
def _sentence_end(root):
    """Return the closing full stop for a root sentence, '' for a nested one."""
    return '.' if root else ''


class Node():
    """Base class of the propositional-formula tree.

    Concrete nodes are `Variable`, `Negation`, `Implication`, `Disjunction`,
    `Conjunction` and `Biimplication`. The methods below define the common
    interface and must be overridden by the subclasses.
    """

    def evaluate(self, values):
        """Return the truth value of the formula; `values[i]` is the value of atom i."""
        raise NotImplementedError

    def to_scentence(self):
        """Return a `(text, depth)` tuple with a sentence rendering of the formula."""
        raise NotImplementedError

    def get_depth(self):
        """Return the nesting depth of binary connectives."""
        raise NotImplementedError

    def get_atoms(self):
        """Return the list of atom indices used by the formula."""
        raise NotImplementedError

    @staticmethod
    def _is_wrapped(text):
        """Return True when the first '(' of `text` is closed by its last ')'."""
        if len(text) < 2 or text[0] != '(' or text[-1] != ')':
            return False
        depth = 0
        for position, char in enumerate(text):
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    # The opening bracket closes here; it wraps the whole
                    # text only if this is the final character.
                    return position == len(text) - 1
        return False

    @staticmethod
    def _find_top_level(text, symbol):
        """Return the index of the first `symbol` outside all brackets, or -1."""
        depth = 0
        for position, char in enumerate(text):
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
            elif depth == 0 and text.startswith(symbol, position):
                return position
        return -1

    @staticmethod
    def parse(text):
        """Parse a formula string into a Node tree.

        Spaces are ignored and redundant outer brackets are removed. Outside
        brackets the text is split at the first occurrence of the loosest
        connective, tried in the order '|', '&', '<->', '->'; a leading '!'
        is handled only when no binary connective is left. Anything else must
        be a variable: one letter followed by its decimal index (e.g. 'p3').

        Raises:
            ValueError: if the text is empty or cannot be parsed.
        """
        text = text.replace(' ', '')
        while Node._is_wrapped(text):
            text = text[1:-1]
        if not text:
            raise ValueError('Cannot parse an empty formula')

        # '<->' is tried before '->' so that the arrow inside a top-level
        # bi-implication is never mistaken for an implication.
        operators = (
            ('|', Disjunction),
            ('&', Conjunction),
            ('<->', Biimplication),
            ('->', Implication),
        )
        for symbol, node_type in operators:
            position = Node._find_top_level(text, symbol)
            if position >= 0:
                left = Node.parse(text[:position])
                right = Node.parse(text[position + len(symbol):])
                return node_type(left, right)

        if text[0] == '!':
            return Negation(Node.parse(text[1:]))

        index_text = text[1:]
        if not text[0].isalpha() or not index_text.isdecimal():
            raise ValueError(f'Cannot parse formula fragment {text!r}')
        return Variable(text, int(index_text))

    def __str__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class Variable(Node):
    """A propositional atom with a display name and an index into the value list."""

    def __init__(self, name, index):
        self.name = name
        self.index = index

    def evaluate(self, values):
        return values[self.index]

    def to_scentence(self, root = True):
        return f'{self.name}{_sentence_end(root)}', 0

    def get_atoms(self):
        return [self.index]

    def get_depth(self):
        return 0

    def __str__(self):
        return self.name

    def __len__(self):
        return 1

    def __eq__(self, other):
        # A variable compares equal to the same variable and also to its
        # direct negation; `generate` uses `left != right` on these nodes.
        if isinstance(other, Variable):
            return self.name == other.name
        if isinstance(other, Negation):
            return isinstance(other.expr, Variable) and self.name == other.expr.name
        return False


class Negation(Node):
    """Logical NOT of a sub-formula."""

    def __init__(self, expr):
        self.expr = expr

    def evaluate(self, values):
        return not self.expr.evaluate(values)

    def to_scentence(self, root = True):
        text, depth = self.expr.to_scentence(root = False)
        return f'!{text}{_sentence_end(root)}', depth

    def get_atoms(self):
        # Atoms under a negation are reported with a flipped sign.
        return list({-x for x in self.expr.get_atoms()})

    def get_depth(self):
        return self.expr.get_depth()

    def __str__(self):
        return f"!{self.expr}"

    def __len__(self):
        return len(self.expr)

    def __eq__(self, other):
        if isinstance(other, Negation):
            return self.expr == other.expr
        if isinstance(other, Variable):
            return isinstance(self.expr, Variable) and self.expr.name == other.name
        return False


class _BinaryNode(Node):
    """Behaviour shared by all connectives with a left and a right operand."""

    def __init__(self, left, right):
        self.left = left
        self.right = right

    def _operand_sentences(self):
        """Return `(left_text, right_text, depth)` for the two nested operands."""
        left_text, left_depth = self.left.to_scentence(root=False)
        right_text, right_depth = self.right.to_scentence(root=False)
        return left_text, right_text, max(left_depth, right_depth)

    def get_atoms(self):
        return list(set(self.left.get_atoms() + self.right.get_atoms()))

    def get_depth(self):
        return max(self.left.get_depth(), self.right.get_depth()) + 1

    def __len__(self):
        return len(self.left) + len(self.right)

    def __eq__(self, other):
        # Equal only to a node of the same connective with equal operands.
        if isinstance(other, type(self)):
            return self.left == other.left and self.right == other.right
        return False


class Implication(_BinaryNode):
    """Material implication `left -> right`."""

    def evaluate(self, values):
        return not self.left.evaluate(values) or self.right.evaluate(values)

    def to_scentence(self, nested = False, root = True):
        # A chained implication on the left is rendered without its own 'if'.
        if isinstance(self.left, Implication):
            left_text, left_depth = self.left.to_scentence(nested=True, root=False)
        else:
            left_text, left_depth = self.left.to_scentence(root=False)
        right_text, right_depth = self.right.to_scentence(root=False)

        depth = max(left_depth, right_depth)
        prefix = '' if nested else 'if '
        commas = ',' * depth
        return f'{prefix}{left_text}{commas} then {right_text}{_sentence_end(root)}', depth + 1

    def __str__(self):
        return f"({self.left} -> {self.right})"


class Disjunction(_BinaryNode):
    """Logical OR `left | right`."""

    def evaluate(self, values):
        return self.left.evaluate(values) or self.right.evaluate(values)

    def to_scentence(self, root = True):
        left_text, right_text, depth = self._operand_sentences()
        commas = ',' * depth
        return f'{left_text}{commas} or {right_text}{_sentence_end(root)}', depth + 1

    def __str__(self):
        return f"({self.left} | {self.right})"


class Conjunction(_BinaryNode):
    """Logical AND `left & right`."""

    def evaluate(self, values):
        return self.left.evaluate(values) and self.right.evaluate(values)

    def to_scentence(self, root = True):
        left_text, right_text, depth = self._operand_sentences()
        commas = ',' * depth
        return f'{left_text}{commas} and {right_text}{_sentence_end(root)}', depth + 1

    def __str__(self):
        return f"({self.left} & {self.right})"


class Biimplication(_BinaryNode):
    """Logical equivalence `left <-> right`."""

    def evaluate(self, values):
        return self.left.evaluate(values) == self.right.evaluate(values)

    def to_scentence(self, root = True):
        left_text, left_depth = self.left.to_scentence(root=False)
        # An implication on the right is rendered without its leading 'if'.
        if isinstance(self.right, Implication):
            right_text, right_depth = self.right.to_scentence(nested = True, root=False)
        else:
            right_text, right_depth = self.right.to_scentence(root=False)
        depth = max(left_depth, right_depth)
        commas = ',' * depth
        return f'{left_text} if{commas} and only if {right_text}{_sentence_end(root)}', depth + 1

    def __str__(self):
        return f"({self.left} <-> {self.right})"

<a id="generate-formula"></a>
### Generate Formula

In [6]:
def generate_formula(max, n):
    """Group every formula yielded by `generate(max, n)` by its depth.

    Returns a list with `max + 1` buckets; bucket `d` holds the formulas whose
    `get_depth()` equals `d`.
    """
    buckets = [[] for _ in range(max + 1)]
    for formula in generate(max, n):
        buckets[formula.get_depth()].append(formula)
    return buckets

def generate(depth, n):
    """Yield formulas over the atoms p0..p(n-1) up to the given depth.

    The atoms themselves are always yielded first. At depth 0 each atom is
    additionally yielded negated; at depth > 0 every ordered pair of unequal
    sub-formulas of depth - 1 is combined with each binary connective.
    A negative depth yields the atoms only.

    The sub-formulas are materialised once and reused for both operands
    instead of re-running the recursive generator for every left operand.
    The yielded formulas therefore share sub-tree objects, which is safe as
    long as nodes are not mutated.
    """
    for i in range(n):
        yield Variable(f"p{i}", i)

    if depth < 0:
        return

    operands = list(generate(depth - 1, n))
    if depth == 0:
        for operand in operands:
            yield Negation(operand)
        return

    for left in operands:
        for right in operands:
            if left != right:
                yield Implication(left, right)
                yield Disjunction(left, right)
                yield Conjunction(left, right)
                yield Biimplication(left, right)

<a id="generate-premise"></a>
### Generate Premise

In [7]:
def generate_premises(key, odds, unqiue_keys):
    """Randomly split `key` into a tuple of premise keys.

    `split_formula` supplies the candidate pairs for `key`; None is returned
    when there are none. With probability `odds[0]` (or always, once `odds`
    is exhausted) a random pair is returned as is. Otherwise the first pair
    for which at least one member can be split further (using the remaining
    odds) is returned with that member replaced by its own split; when both
    members can be split, a coin flip decides which one is expanded.
    """
    candidates = list(split_formula(key, unqiue_keys))
    if len(candidates) == 0:
        return None
    np.random.shuffle(candidates)

    if len(odds) == 0 or np.random.rand() < odds[0]:
        return candidates[0]

    remaining_odds = odds[1:]
    for first, second in candidates:
        first_split = generate_premises(first, remaining_odds, unqiue_keys)
        second_split = generate_premises(second, remaining_odds, unqiue_keys)

        if first_split is None and second_split is None:
            continue

        if second_split is None:
            expand_first = True
        elif first_split is None:
            expand_first = False
        else:
            expand_first = np.random.rand() < 0.5

        if expand_first:
            return (*first_split, second)
        return (first, *second_split)

    return None



def split_formula(value, indexes):
    """Yield every pair `(i, j)` from `indexes` whose bitwise AND is `value`.

    Both members must be strictly greater than `value` and `i` must be
    greater than `j`, so each unordered pair is produced once.
    """
    for first in indexes:
        if first <= value:
            continue
        for second in indexes:
            if value < second < first and first & second == value:
                yield first, second

<a id="generate-key"></a>
### Generate Key

In [8]:
def generate_number(expression, n):
    """Encode the truth table of `expression` over `n` atoms as an integer.

    Assignment number `i` gives atom `j` the value of bit `j` of `i`; bit `i`
    of the result is set when the expression is true under that assignment.
    """
    total = 0
    for assignment in range(2 ** n):
        truth_values = [bool(assignment & (1 << atom)) for atom in range(n)]
        if expression.evaluate(truth_values):
            total += 2 ** assignment
    return total

<a id="to-formula"></a>
### Translate to Formula

In [9]:
def to_formula(expression):
    """Placeholder for the sentence-to-formula translation; not implemented yet, returns None."""
    return None

<a id="lookup"></a>
#### Create Lookup Table for Input

In [10]:
def create_lookup(depth=2, atoms=4):
    """Index every generated formula by key, length and atom count.

    Args:
        depth: maximum formula depth passed to `generate_formula`.
        atoms: number of atoms used for generation and for the truth-table key.

    Returns:
        A dict with four sub-dicts:
        'by key'      truth-table key -> list of formulas with that key,
        'by length'   len(formula) -> list of formulas,
        'by atoms'    number of entries in get_atoms() -> list of formulas,
        'expressions' str(formula) -> {'key', 'length', 'atoms'}.
    """
    table = {'by key': {}, 'by length': {}, 'by atoms': {}, 'expressions': {}}

    for formulas in generate_formula(depth, atoms):
        for expression in formulas:
            key = generate_number(expression, atoms)
            length = len(expression)
            atom_count = len(expression.get_atoms())

            table['by key'].setdefault(key, []).append(expression)
            table['by length'].setdefault(length, []).append(expression)
            table['by atoms'].setdefault(atom_count, []).append(expression)

            table['expressions'][str(expression)] = {
                'key': key,
                'length': length,
                'atoms': atom_count,
            }

    return table

# Build the lookup once and cache the distinct keys, lengths and atom counts.
lookup_table = create_lookup()
simple_keys = [*lookup_table['by key']]
simple_lengths = [*lookup_table['by length']]
simple_atoms = [*lookup_table['by atoms']]

<a id="model"></a>
## Model

<a id="embedding"></a>
### Embedding and FeedForward

In [11]:
class Embedding(nn.Module):
    """Token embedding plus fixed sinusoidal positional encoding.

    The positional table has shape (1, context_size, embedding_dim) and is
    stored as a non-persistent buffer: it follows the module across devices
    with `.to(...)` but is not written to the state dict.
    """

    def __init__(self, vocab_size, context_size, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Frequencies 1 / 10000^(2i / embedding_dim); the scaling factor must
        # be applied inside the exponential.
        even_dims = torch.arange(0, embedding_dim, 2, dtype=torch.float)
        div_term = torch.exp(even_dims * (-math.log(10000.0) / embedding_dim))
        pos = torch.arange(0, context_size, dtype=torch.float).unsqueeze(1)
        angles = pos * div_term

        table = torch.zeros(context_size, embedding_dim)
        table[:, 0::2] = torch.sin(angles)
        # For an odd embedding_dim there is one cosine column fewer.
        table[:, 1::2] = torch.cos(angles)[:, :embedding_dim // 2]

        self.register_buffer('embedding_pos', table.unsqueeze(0), persistent=False)

    def forward(self, input):
        """Embed token ids of shape (batch, seq) and add the positional encoding."""
        embedded = self.embedding(input) * math.sqrt(self.embedding_dim)
        return embedded + self.embedding_pos[:, :embedded.shape[1], :]

class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Linear back to embedding_dim."""

    def __init__(self, embedding_dim, ff_dim):
        super().__init__()
        self.linear_1 = nn.Linear(embedding_dim, ff_dim)
        self.relu = nn.ReLU()
        self.linear_2 = nn.Linear(ff_dim, embedding_dim)

    def forward(self, input):
        """Apply the MLP to the last dimension of `input`."""
        hidden = self.relu(self.linear_1(input))
        return self.linear_2(hidden)

<a id="attention"></a>
### Attention

In [12]:
class AttentionBlock(nn.Module):
    """Single self-attention head with an optional causal mask.

    Projects the input from `embedding_dim` to `head_dimension` for query,
    key and value. The lower-triangular `mask` buffer of size
    (context_size, context_size) is used when `masked` is true.
    """

    def __init__(self, embedding_dim, head_dimension, context_size):
        super().__init__()
        self.query = nn.Linear(embedding_dim, head_dimension, bias=False)
        self.key = nn.Linear(embedding_dim, head_dimension, bias=False)
        self.value = nn.Linear(embedding_dim, head_dimension, bias=False)

        ones = torch.ones(size=[context_size, context_size], dtype=torch.float)
        self.register_buffer(name="mask", tensor=torch.tril(input=ones))

    def forward(self, input, masked = False):
        """Return the attended values for `input` of shape (batch, seq, embedding_dim).

        The output has shape (batch, seq, head_dimension).
        """
        seq_len = input.size(1)

        query = self.query(input)
        key = self.key(input)
        value = self.value(input)

        # Scaled dot-product attention: scale by the query/key width, not by
        # the width of the un-projected input.
        scores = query @ key.transpose(-2, -1) * key.size(-1) ** -0.5

        if masked:
            # Hide future positions so position t only attends to <= t.
            scores = scores.masked_fill(self.mask[:seq_len, :seq_len] == 0, float("-inf"))
        weights = F.softmax(input=scores, dim=-1)

        return weights @ value
    
class MultiAttentionBlock(nn.Module):
    """Two-stage attention: a full-width attention followed by per-head self-attention.

    Stage one projects query, key and value at full embedding width and
    attends once. Stage two feeds that result to `n_head` independent
    `AttentionBlock`s and concatenates their outputs back to `embedding_dim`.

    NOTE(review): this differs from the usual multi-head layout, in which the
    projected query/key/value are split across heads; the structure is kept
    as is so that the parameters stay compatible.
    """

    def __init__(self, embedding_dim, context_size, n_head):
        super().__init__()
        head_dim = embedding_dim // n_head
        assert head_dim * n_head == embedding_dim, "Embedding dimension must be divisible by number of heads"

        self.query = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.key = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = nn.Linear(embedding_dim, embedding_dim, bias=False)

        self.attention_blocks = nn.ModuleList([AttentionBlock(embedding_dim, head_dim, context_size) for _ in range(n_head)])

    def forward(self, query, key, value, masked = False):
        """Attend `query` over `key`/`value`; all inputs are (batch, seq, embedding_dim).

        With `masked=True` a causal mask is applied in both stages, so an
        output position never depends on later input positions.
        """
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        scores = query @ key.transpose(-2, -1) * query.size(-1) ** -0.5

        if masked:
            # Without this mask the first stage would mix future positions
            # into every row before the per-head masking happens.
            causal = torch.tril(torch.ones(scores.shape[-2:], device=scores.device))
            scores = scores.masked_fill(causal == 0, float("-inf"))
        weights = F.softmax(input=scores, dim=-1)

        mixed = weights @ value
        return torch.cat([attention_block(mixed, masked) for attention_block in self.attention_blocks], dim=-1)

<a id="encoder"></a>
### Encoder

In [13]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention and feed-forward, each followed by residual add + LayerNorm."""

    def __init__(self, context_size, embedding_dim, hidden_dim, n_head):
        super().__init__()
        self.multi_attention = MultiAttentionBlock(embedding_dim, context_size, n_head)
        self.feed_forward = FeedForward(embedding_dim, hidden_dim)
        self.norm_1 = nn.LayerNorm(embedding_dim)
        self.norm_2 = nn.LayerNorm(embedding_dim)

    def forward(self, input):
        """Run the layer on `input` and return a tensor of the same shape."""
        attended = self.multi_attention(input, input, input)
        normed = self.norm_1(input + attended)
        transformed = self.feed_forward(normed)
        return self.norm_2(normed + transformed)

class Encoder(nn.Module):
    """Embedding followed by a stack of `encoder_layers` EncoderLayer modules."""

    def __init__(self, vocab_size, context_size, embedding_dim, hidden_dim, n_head, encoder_layers):
        super().__init__()
        self.embedding = Embedding(vocab_size, context_size, embedding_dim)
        layers = [
            EncoderLayer(context_size, embedding_dim, hidden_dim, n_head)
            for _ in range(encoder_layers)
        ]
        self.encoders = nn.ModuleList(layers)

    def forward(self, input):
        """Embed the token ids in `input` and pass them through every encoder layer in order."""
        hidden = self.embedding(input)
        for layer in self.encoders:
            hidden = layer(hidden)
        return hidden

<a id="decoder"></a>
### Decoder

In [14]:
class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention and feed-forward.

    Each of the three sub-layers is followed by a residual add and LayerNorm.
    """

    def __init__(self, context_size, embedding_dim, hidden_dim, n_head):
        super().__init__()
        self.multi_attention_masked = MultiAttentionBlock(embedding_dim, context_size, n_head)
        self.multi_attention = MultiAttentionBlock(embedding_dim, context_size, n_head)
        self.feed_forward = FeedForward(embedding_dim, hidden_dim)
        self.norm_1 = nn.LayerNorm(embedding_dim)
        self.norm_2 = nn.LayerNorm(embedding_dim)
        self.norm_3 = nn.LayerNorm(embedding_dim)

    def forward(self, encoder_input, decoder_input):
        """Update `decoder_input` using the encoder output `encoder_input`.

        Both tensors are expected to be (batch, seq, embedding_dim); the
        result has the shape of `decoder_input`.
        """
        value = self.multi_attention_masked(decoder_input, decoder_input, decoder_input, masked=True)
        decoder_input = self.norm_1(decoder_input + value)
        # Encoder-decoder attention: the queries come from the decoder, the
        # keys and values from the encoder output (argument order is
        # query, key, value).
        value = self.multi_attention(decoder_input, encoder_input, encoder_input)
        decoder_input = self.norm_2(decoder_input + value)
        value = self.feed_forward(decoder_input)
        return self.norm_3(decoder_input + value)

class Decoder(nn.Module):
    """Embedding followed by a stack of `decoder_layers` DecoderLayer modules."""

    def __init__(self, vocab_size, context_size, embedding_dim, hidden_dim, n_head, decoder_layers):
        super().__init__()
        self.embedding = Embedding(vocab_size, context_size, embedding_dim)
        layers = [
            DecoderLayer(context_size, embedding_dim, hidden_dim, n_head)
            for _ in range(decoder_layers)
        ]
        self.decoders = nn.ModuleList(layers)

    def forward(self, encoder_input, decoder_input):
        """Embed the decoder token ids and run them through every decoder layer.

        `encoder_input` is handed unchanged to each layer.
        """
        hidden = self.embedding(decoder_input)
        for layer in self.decoders:
            hidden = layer(encoder_input, hidden)
        return hidden

<a id="transformer"></a>
### Transformer

In [15]:
class Model(nn.Module):
    """Encoder-decoder transformer that predicts a single next token.

    The decoder output for the whole window is flattened and projected to
    the vocabulary, so `decoder_input` must always be exactly `context_size`
    tokens long (padded).
    """

    def __init__(self, vocab_size, context_size, embedding_dim, hidden_dim, n_head= 8, encoder_layers= 3, decoder_layers= 3):
        super().__init__()
        self.encoder = Encoder(vocab_size, context_size, embedding_dim, hidden_dim, n_head, encoder_layers)
        self.decoder = Decoder(vocab_size, context_size, embedding_dim, hidden_dim, n_head, decoder_layers)
        self.projection = nn.Linear(context_size * embedding_dim, vocab_size)

    def forward(self, encoder_input, decoder_input):
        """Return unnormalised logits of shape (batch, vocab_size).

        No softmax is applied here: `F.cross_entropy` applies log-softmax
        itself, so feeding it probabilities would normalise twice and flatten
        the gradients. Use `predict` (or `F.softmax`) to get probabilities.
        """
        encoder_output = self.encoder(encoder_input)
        decoder_output = self.decoder(encoder_output, decoder_input)
        return self.projection(decoder_output.flatten(start_dim=1))

    def predict(self, encoder_input, decoder_input, greedy = False):
        """Predict the next token id for every row of the batch, without gradients.

        Args:
            greedy: take the arg-max when true, otherwise sample from the
                softmax distribution.

        Returns:
            An int64 tensor of shape (batch,), for both modes.
        """
        with torch.no_grad():
            logits = self(encoder_input, decoder_input)
            if greedy:
                return torch.argmax(logits, dim=-1)
            probabilities = F.softmax(logits, dim=-1)
            return torch.multinomial(probabilities, num_samples=1).squeeze(-1)

<a id="generate-dataset"></a>
## Generate datasets

<a id="dataset-code"></a>
### Code to Generate Datasets

In [16]:
def create_simple_dataset(tokenizer, size: int, atoms: int, depth, premise_odds: list):
    """Build next-token training examples from randomly chosen conclusions.

    Every formula from `generate(depth, atoms)` is tried in random order. If
    its truth-table key can be split into premise keys, one random formula is
    picked per premise key and one example is emitted for each token of the
    conclusion (including the final '<END>').

    Args:
        tokenizer: formula tokenizer used for encoding.
        size: number of conclusions to use; generation stops once reached.
        atoms: number of atoms used for generation and for the key.
        depth: maximum formula depth.
        premise_odds: per-level stop probabilities for `generate_premises`.

    Returns:
        A list of `[encoder_input, decoder_input, target]` entries. Both
        inputs are id lists padded to `input_window_simple`. If fewer than
        `size` conclusions can be split, everything collected so far is
        returned instead of being discarded.
    """
    pad_id = tokenizer.encoder['<PAD>']
    pred_id = tokenizer.encoder['<PRED>']
    output = []

    options = list(generate(depth, atoms))
    counter = 0
    np.random.shuffle(options)

    for conclusion in options:
        key = generate_number(conclusion, atoms)
        splits = generate_premises(key, premise_odds, simple_keys)
        if splits is None:
            continue

        new_premises = [np.random.choice(lookup_table['by key'][split]) for split in splits]
        encoder_input = [tokenizer.encoder['{']]
        for premise in new_premises:
            encoder_input += tokenizer.encode(tokenizer.tokenize(str(premise)) + ['.'])
        encoder_input.append(tokenizer.encoder['}'])

        decoder_tokens = tokenizer.encode(tokenizer.tokenize(str(conclusion)) + ['<END>'])

        # Sequences longer than the window cannot be padded to a fixed
        # length and would produce ragged batches, so skip them.
        if len(encoder_input) > input_window_simple or len(decoder_tokens) > input_window_simple:
            continue

        counter += 1
        encoder_input += [pad_id] * (input_window_simple - len(encoder_input))

        for i, target in enumerate(decoder_tokens):
            # Known prefix, then the '<PRED>' marker at the position to
            # predict, then padding up to the window length.
            decoder_input = decoder_tokens[:i] + [pred_id] + [pad_id] * (input_window_simple - i - 1)
            output.append([encoder_input, decoder_input, target])

        if counter == size:
            return output
    return output

<a id="dataset-simple"></a>
### Generate Dataset for Simple Models

In [18]:
# Training data for the simple model: 4 atoms, formulas up to depth 2.
simple_dataset = create_simple_dataset(
    tokenizer=formula_tokenizer,
    size=dataset_size,
    atoms=4,
    depth=2,
    premise_odds=[0.5, 0.5, 0.5],
)

AttributeError: 'tuple' object has no attribute 'split'

<a id="training-testen"></a>
## Training and Testing

<a id="get-batch"></a>
#### Get Batch

In [18]:
def get_batch(dataset, batch_size = 64):
    """Yield shuffled batches from `dataset` forever.

    Each dataset entry is `[encoder_input, decoder_input, target]`. The
    dataset is visited in a fresh random order on every pass; the caller's
    list itself is left untouched.

    Yields:
        `(encoder_input, decoder_input, target_labels)` as int64 tensors on
        `device`. The last batch of a pass may be smaller than `batch_size`.

    Raises:
        ValueError: if `dataset` is empty (the generator would otherwise
            loop forever without yielding anything).
    """
    if len(dataset) == 0:
        raise ValueError('get_batch needs a non-empty dataset')

    while True:
        order = np.random.permutation(len(dataset))
        for start in range(0, len(order), batch_size):
            rows = [dataset[j] for j in order[start:start + batch_size]]
            encoder_input = [row[0] for row in rows]
            decoder_input = [row[1] for row in rows]
            target_labels = [row[2] for row in rows]
            yield (
                torch.tensor(encoder_input, dtype=torch.int64).to(device),
                torch.tensor(decoder_input, dtype=torch.int64).to(device),
                torch.tensor(target_labels, dtype=torch.int64).to(device),
            )

<a id="train-code"></a>
#### Code for Training

In [19]:
def train(model, optimizer, dataset, epochs, batch_size, save_path, save_interval = 1000, save=False):
    """Train `model` for `epochs` optimisation steps and return it.

    One "epoch" here is a single batch drawn from `get_batch`, not a full
    pass over `dataset`.

    Args:
        model: called as `model(encoder_input, decoder_input)`.
        optimizer: optimizer over the model parameters.
        dataset: list of `[encoder_input, decoder_input, target]` entries.
        epochs: number of optimisation steps.
        batch_size: examples per step.
        save_path: file the whole model is saved to when `save` is true.
        save_interval: save every this many steps.
        save: enable the periodic save and the final save.
    """
    batch = get_batch(dataset, batch_size)
    # ignore_index must be a plain int; convert the PAD_ID tensor once.
    pad_index = int(PAD_ID)

    for epoch in range(epochs):
        encoder_input, decoder_input, decoder_target = next(batch)

        decoder_output = model(encoder_input, decoder_input)
        # NOTE(review): F.cross_entropy expects unnormalised logits; verify
        # that model.forward does not already apply a softmax.
        loss = F.cross_entropy(decoder_output, decoder_target, ignore_index=pad_index)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step = epoch + 1
        if step % 100 == 0 or epoch == epochs - 1:
            print(f'Epoch {step} Loss {loss.item()}')
        if save and step % save_interval == 0:
            torch.save(model, save_path)

    if save:
        torch.save(model, save_path)

    return model
        

<a id="test-code"></a>
#### Code for Testing

In [30]:
def simple_test(model, tokenizer, n_prob = 10, n_greedy = 10):
    """Generate conclusions for random premise sets and score them.

    The first `n_prob` runs sample tokens from the model, the following
    `n_greedy` runs decode greedily. For each run a target formula with
    splittable key is drawn, its premises are encoded, and tokens are
    generated one at a time until '<END>' or the window is full.

    Returns:
        A DataFrame with the columns 'Premises', 'Generated Conclusion',
        'Greedy', 'Target', 'Generated Key', 'Valid' and 'Correct'. 'Valid'
        is 1 when the output parses and is found in the lookup table;
        'Correct' is 1 when its key additionally equals the target key.
        'Generated Key' is -1 for invalid outputs.
    """
    columns = ['Premises', 'Generated Conclusion', 'Greedy', 'Target', 'Generated Key', 'Valid', 'Correct']
    pad_id = tokenizer.encoder['<PAD>']
    pred_id = tokenizer.encoder['<PRED>']
    end_id = tokenizer.encoder['<END>']
    odds = [0.5, 0.5, 0.5]

    results = []
    for n in range(n_prob + n_greedy):
        greedy = n >= n_prob

        # Draw target formulas until one can be split into premises. The
        # target key is taken from the formula that is actually used.
        premises = None
        while premises is None:
            datapoint = np.random.choice(lookup_table['by atoms'][np.random.choice(simple_atoms)])
            target_key = lookup_table['expressions'][str(datapoint)]['key']
            premises = generate_premises(target_key, odds, simple_keys)

        encoder_sequence = tokenizer.encode(['{'])
        premise_text = '{'
        for premise in premises:
            premise_formula = np.random.choice(lookup_table['by key'][premise])
            encoder_sequence += tokenizer.encode(tokenizer.tokenize(str(premise_formula)) + ['.'])
            premise_text += str(premise_formula) + '. '
        premise_text += '}'

        encoder_sequence += tokenizer.encode(['}'])
        encoder_sequence += [pad_id] * (input_window_simple - len(encoder_sequence))
        encoder_input = torch.tensor(encoder_sequence, dtype= torch.int64).reshape(1,-1).to(device)

        # The decoder window holds the tokens generated so far followed by a
        # '<PRED>' marker at the position to predict.
        decoder_sequence = [pred_id] + [pad_id] * (input_window_simple - 1)
        generated = []
        for i in range(input_window_simple):
            decoder_input = torch.tensor(decoder_sequence, dtype= torch.int64).reshape(1,-1).to(device)

            new_token = model.predict(encoder_input, decoder_input, greedy = greedy).item()
            if new_token == end_id:
                break

            generated.append(new_token)
            decoder_sequence[i] = new_token
            if i < input_window_simple - 1:
                decoder_sequence[i + 1] = pred_id

        output = ''.join(tokenizer.decode(generated))

        generated_key, valid, correct = -1, 0, 0
        try:
            parsed = Node.parse(output)
            parsed_key = lookup_table['expressions'][str(parsed)]['key']
        except Exception:
            # Unparsable output, or a formula that is not in the lookup
            # table, counts as invalid.
            pass
        else:
            generated_key = parsed_key
            valid = 1
            correct = 1 if parsed_key == target_key else 0

        results.append([premise_text, output, 1 if greedy else 0, target_key, generated_key, valid, correct])
    return pd.DataFrame(results, columns=columns)

<a id="model1"></a>
### Model 1

In [24]:
# Model 1: build the simple model, train it on the simple dataset and save
# it to 'new_model_1.pth' every 1000 steps.
model_1 = Model(
    vocab_size=len(formula_tokenizer),
    context_size=input_window_simple,
    embedding_dim=32,
    hidden_dim=1024,
    n_head=8,
    encoder_layers=12,
    decoder_layers=12,
).to(device)
optimizer_1 = torch.optim.AdamW(model_1.parameters(), lr=learning_rate)
model_1 = train(
    model_1,
    optimizer_1,
    simple_dataset,
    epochs,
    batch_size,
    save_path='new_model_1.pth',
    save=True,
    save_interval=1000,
)
#model_1 = torch.load('new_model_1.pth')
#print(sum(p.numel() for p in model_1.parameters() if p.requires_grad))

973651


  model_1 = torch.load('new_model_1.pth')


In [31]:
# Evaluate model 1 on 1000 sampled and 1000 greedy generations and store the table as CSV.
simple_test(
    model_1,
    formula_tokenizer,
    n_prob=1000,
    n_greedy=1000,
).to_csv('simple_test.csv', index=False)