# Installation

We use Conda and pip to install the packages needed for this homework. If you do not have Miniconda or Anaconda, you can install Miniconda from https://docs.conda.io/en/latest/miniconda.html.

```
conda create --name competition python=3.7
conda activate competition

pip install jupyter pandas
```

If you don't already have PyTorch, install it by following the instructions at https://pytorch.org/.

To install the Hugging Face `transformers` library, run
```
pip install transformers
```

Follow the instructions from https://docs.dgl.ai/en/0.4.x/install/ to install Deep Graph Library (DGL).

Then start Jupyter Notebook with
```
jupyter notebook
```

Implementation of [Graph-to-Tree Learning for Solving Math Word Problems](https://www.aclweb.org/anthology/2020.acl-main.362.pdf).

In [None]:
# ! pip install transformers
# ! pip install dgl-cu101

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
import sys

# Make the helper modules that live under ./src importable from this notebook.
sys.path.extend(['./src/'])

In [None]:
from copy import copy
import itertools
import os
from tqdm import tqdm, trange

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl.nn import GraphConv

#from utils import setup, check_match, sub_nP, evaluate_prefix_expression

In [None]:
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the cells that only need a device string still work on CUDA-less machines.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
torch.cuda.is_available()

# Converting Inputs to Torch Tensors

In [None]:
def tensorize_data(data):
    """Add model-ready fields to every example dict in ``data`` (in place).

    For each example this sets:
      * ``in_idxs``: taken directly from ``question`` (already token ids);
      * ``n_in`` / ``n_nP``: ``len(in_idxs[0])`` and ``len(nP[0])``;
      * ``out_idxs``, ``n_out`` and ``nP_out_mask``: only when the example
        has ``out_tokens``;
      * ``qcomp_edges`` / ``qcell_edges``: edge lists from the two graph
        builders.

    Reads the module-level ``out_vocab`` and ``n_max_nP``.
    """
    for example in data:
        # The inputs already hold token indices, so no vocabulary lookup is needed.
        example['in_idxs'] = example['question']
        example['n_in'] = len(example['in_idxs'][0])
        num_quants = len(example['nP'][0])
        example['n_nP'] = num_quants

        if 'out_tokens' in example:
            out_ids = [out_vocab.token2idx.get(tok, out_vocab.unk) for tok in example['out_tokens']]
            example['out_idxs'] = torch.tensor(out_ids)
            example['n_out'] = len(example['out_idxs'])
            # The first num_quants quantity slots are valid outputs.
            quant_slots = torch.zeros(n_max_nP, dtype=torch.bool)
            example['nP_out_mask'] = quant_slots
            quant_slots[:num_quants] = True

        example['qcomp_edges'] = get_quantity_comparison_edges(example)
        example['qcell_edges'] = get_quantity_cell_edges(example)

def get_quantity_comparison_edges(d):
    """Build the quantity-comparison graph edges for one example.

    Every input position gets a self-loop. In addition, for every ordered pair
    of quantities (x at position i, y at position j) with ``x > y`` a directed
    edge i -> j is added.

    Args:
        d: example dict providing ``nP`` (``d['nP'][0]`` holds the quantity
            values, numbers or numeric strings), ``nP_positions``
            (``d['nP_positions'][0]`` holds the token position of each
            quantity) and ``n_in`` (number of input positions).

    Returns:
        ``(src_ids, dst_ids)``: two 1-D int64 tensors with the edge endpoints,
        in row-major order of the adjacency matrix.
    """
    quants = [float(x) for x in d['nP'][0]]
    quant_positions = [int(p) for p in d['nP_positions'][0]]
    n_in = d['n_in']
    # max() of an empty sequence raises, so only check when quantities exist.
    if quant_positions:
        assert max(quant_positions) < n_in
    # torch.eye needs a torch dtype: np.bool is not one, and it no longer
    # exists in NumPy >= 1.24.
    adj_matrix = torch.eye(n_in, dtype=torch.bool)
    for x, x_pos in zip(quants, quant_positions):
        for y, y_pos in zip(quants, quant_positions):
            if x > y:
                adj_matrix[x_pos, y_pos] = True
    src_ids, dst_ids = adj_matrix.nonzero(as_tuple=True)
    return (src_ids, dst_ids)

def get_quantity_cell_edges(d):
    """Build the quantity-cell graph edges for one example.

    Every input position gets a self-loop. Additionally:
      * each word cell (a cell position that is not a quantity position) is
        linked in both directions to every quantity position fewer than 4
        tokens away;
      * any two cell positions holding the same token id are linked in both
        directions.

    Args:
        d: example dict providing ``in_idxs`` (``d['in_idxs'][0]`` is indexed
            with a list of positions, e.g. a 1-D tensor), ``nP_positions``,
            ``quant_cell_positions`` and ``n_in``.

    Returns:
        ``(src_ids, dst_ids)``: two 1-D int64 tensors with the edge endpoints,
        in row-major order of the adjacency matrix.
    """
    in_idxs = d['in_idxs'][0]
    n_in = d['n_in']
    # Convert to plain ints: tensor elements hash by identity, so a set
    # difference over 0-d tensors would never remove any quantity position.
    quant_positions = [int(p) for p in d['nP_positions'][0]]
    quant_cell_positions = [int(p) for p in d['quant_cell_positions']]
    if quant_cell_positions:
        assert max(quant_cell_positions) < n_in
    word_cells = set(quant_cell_positions) - set(quant_positions)
    adj_matrix = torch.eye(n_in, dtype=torch.bool)
    # Link each word cell to the nearby quantities.
    for w_pos in word_cells:
        for q_pos in quant_positions:
            if abs(w_pos - q_pos) < 4:
                adj_matrix[w_pos, q_pos] = adj_matrix[q_pos, w_pos] = True
    # Link cell positions that carry the same token id.
    pos_idxs = in_idxs[quant_cell_positions]
    for idx1, pos1 in zip(pos_idxs, quant_cell_positions):
        for idx2, pos2 in zip(pos_idxs, quant_cell_positions):
            if idx1 == idx2:
                adj_matrix[pos1, pos2] = adj_matrix[pos2, pos1] = True
    src_ids, dst_ids = adj_matrix.nonzero(as_tuple=True)
    return (src_ids, dst_ids)

In [None]:
class TransformerAttention(nn.Module):
    """Multi-head self-attention sublayer that includes its residual connection.

    Reads the module-level hyperparameters ``n_hid``, ``n_head``, ``n_k`` and
    ``n_v``. ``forward`` returns ``x + attention(x)``.
    """
    def __init__(self):
        super().__init__()
        self.qkv = nn.Linear(n_hid, n_head * (n_k * 2 + n_v))
        self.out = nn.Linear(n_head * n_v, n_hid)

    def forward(self, x, mask=None):
        """Apply self-attention.

        Args:
            x: activations of shape (n_batch, n_batch_max_in, n_hid).
            mask: optional bool tensor (n_batch, n_batch_max_in); ``False``
                marks padding positions, which are excluded as attention keys.

        Returns:
            ``x + out``, with the same shape as ``x``.
        """
        n_batch, n_batch_max_in, n_hid = x.shape
        q_k_v = self.qkv(x).view(n_batch, n_batch_max_in, n_head, 2 * n_k + n_v).transpose(1, 2)
        q, k, v = q_k_v.split([n_k, n_k, n_v], dim=-1)

        q = q.reshape(n_batch * n_head, n_batch_max_in, n_k)
        k = k.reshape_as(q).transpose(1, 2)
        qk = q.bmm(k) / np.sqrt(n_k)

        if mask is not None:
            # Mask padded *keys* (last dim). Masking whole query rows sets every
            # score in the row to -inf, and softmax then returns NaN.
            qk = qk.view(n_batch, n_head, n_batch_max_in, n_batch_max_in)
            qk = qk.masked_fill(~mask[:, None, None, :], -np.inf)
            qk = qk.view(n_batch * n_head, n_batch_max_in, n_batch_max_in)
        qk = qk.softmax(dim=-1)
        v = v.reshape(n_batch * n_head, n_batch_max_in, n_v)
        qkv = qk.bmm(v).view(n_batch, n_head, n_batch_max_in, n_v).transpose(1, 2).reshape(n_batch, n_batch_max_in, n_head * n_v)
        out = self.out(qkv)
        return x + out

class TransformerBlock(nn.Module):
    """One encoder layer: self-attention followed by a residual position-wise MLP.

    Uses the module-level ``n_hid``; the MLP's hidden width is ``4 * n_hid``.
    """
    def __init__(self):
        super().__init__()
        self.attn = TransformerAttention()
        n_inner = n_hid * 4
        self.inner = nn.Sequential(
            nn.Linear(n_hid, n_inner),
            nn.ReLU(inplace=True),
            nn.Linear(n_inner, n_hid)
        )

    def forward(self, x, mask=None):
        # TransformerAttention.forward already returns ``x + attention(x)``;
        # adding ``x`` again here would double the residual stream every layer.
        x = self.attn(x, mask=mask)
        return x + self.inner(x)
    
class GCNBranch(nn.Module):
    """Two stacked graph convolutions: GraphConv -> ReLU -> Dropout -> GraphConv.

    Args:
        n_hid_in: node feature size of the input and of the hidden layer.
        n_hid_out: node feature size produced by the second convolution.
        dropout: dropout probability applied between the two convolutions.
    """
    def __init__(self, n_hid_in, n_hid_out, dropout=0.5):
        super().__init__()
        # Registration order fixes parameter-init order and state_dict layout.
        self.gc1 = dgl.nn.GraphConv(n_hid_in, n_hid_in, allow_zero_in_degree=True)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        self.gc2 = dgl.nn.GraphConv(n_hid_in, n_hid_out, allow_zero_in_degree=True)

    def forward(self, x, graph):
        """Convolve node features ``x`` twice over ``graph``."""
        hidden = self.dropout(self.relu(self.gc1(graph, x)))
        return self.gc2(graph, hidden)

class GCN(nn.Module):
    """Multi-branch graph-convolution block over two graphs, with residuals.

    Args:
        n_head: number of ``GCNBranch`` heads, each producing
            ``n_hid // n_head`` features. The first ``ceil(n_head / 2)`` heads
            convolve over ``gt_graph`` and the rest over ``attr_graph`` (for
            the default of 4: gt, gt, attr, attr).
        dropout: dropout probability for the branches and the feed-forward net.

    Raises:
        ValueError: if the module-level ``n_hid`` is not divisible by ``n_head``.
    """
    def __init__(self, n_head=4, dropout=0.5):
        super().__init__()
        if n_hid % n_head != 0:
            raise ValueError(f'n_hid ({n_hid}) must be divisible by n_head ({n_head})')
        self.branches = nn.ModuleList(GCNBranch(n_hid, n_hid // n_head, dropout) for _ in range(n_head))

        self.feed_forward = nn.Sequential(
            nn.Linear(n_hid, n_hid),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(n_hid, n_hid),
        )
        self.layer_norm = nn.LayerNorm(n_hid)

    def forward(self, h, gt_graph, attr_graph):
        """Refine node states ``h`` (last dimension ``n_hid``); returns the same shape."""
        x = h.reshape(-1, n_hid)
        # Assign graphs from the actual number of branches. A fixed 4-element
        # list would make zip() silently drop branches when n_head > 4, or
        # ignore attr_graph entirely when n_head == 2.
        n_branches = len(self.branches)
        n_gt = (n_branches + 1) // 2
        graphs = [gt_graph] * n_gt + [attr_graph] * (n_branches - n_gt)
        x = torch.cat([branch(x, g) for branch, g in zip(self.branches, graphs)], dim=-1).view_as(h)
        x = h + self.layer_norm(x)
        return x + self.feed_forward(x)


class GraphAttentionLayer(nn.Module):
    """Single-head graph attention layer operating on a dense adjacency matrix.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        dropout: dropout probability applied to the attention coefficients.
        alpha: negative slope of the LeakyReLU applied to attention logits.
        concat: if True the output goes through ELU; otherwise it is returned
            unchanged.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super().__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # Create and initialise W before a (keeps RNG consumption order fixed).
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        """Aggregate neighbour features with learned attention weights.

        Args:
            h: node features, shape (N, in_features).
            adj: (N, N) adjacency; entries > 0 mark edges.

        Returns:
            Node features of shape (N, out_features).
        """
        projected = h.mm(self.W)
        pair_features = self._prepare_attentional_mechanism_input(projected)
        logits = self.leakyrelu(pair_features.matmul(self.a).squeeze(2))

        # Non-edges receive a huge negative logit so softmax gives them ~0 weight.
        masked_logits = logits.masked_fill(~(adj > 0), -9e15)
        weights = F.softmax(masked_logits, dim=1)
        weights = F.dropout(weights, self.dropout, training=self.training)
        aggregated = weights.matmul(projected)

        return F.elu(aggregated) if self.concat else aggregated

    def _prepare_attentional_mechanism_input(self, Wh):
        """Return an (N, N, 2*out_features) tensor whose [i, j] entry is cat(Wh[i], Wh[j])."""
        n_nodes = Wh.size(0)
        rows = Wh.repeat_interleave(n_nodes, dim=0)  # Wh[0], Wh[0], ..., Wh[1], Wh[1], ...
        cols = Wh.repeat(n_nodes, 1)                 # Wh[0], Wh[1], ..., Wh[0], Wh[1], ...
        paired = torch.cat([rows, cols], dim=1)
        return paired.view(n_nodes, n_nodes, 2 * self.out_features)


class GAT(nn.Module):
    """Two-stage graph attention network returning per-node log-probabilities.

    ``nheads`` hidden ``GraphAttentionLayer`` heads (``nfeat -> nhid``) are
    concatenated and fed to one output layer (``nhid * nheads -> nclass``).
    """

    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        super().__init__()
        self.dropout = dropout

        # Heads are registered as attention_0, attention_1, ... in creation order.
        self.attentions = []
        for i in range(nheads):
            head = GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
            self.attentions.append(head)
        for i, head in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), head)

        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        multi_head = torch.cat([head(x, adj) for head in self.attentions], dim=1)
        multi_head = F.dropout(multi_head, self.dropout, training=self.training)
        logits = F.elu(self.out_att(multi_head, adj))
        return F.log_softmax(logits, dim=1)


class Gate(nn.Module):
    """Gated projection: ``tanh(t(x)) * sigmoid(s(x))``.

    Args:
        n_in: input feature size.
        n_out: output feature size.
    """
    def __init__(self, n_in, n_out):
        super().__init__()
        self.t = nn.Linear(n_in, n_out)
        self.s = nn.Linear(n_in, n_out)

    def forward(self, x):
        candidate = torch.tanh(self.t(x))
        gate = torch.sigmoid(self.s(x))
        return candidate * gate

class TreeDecoder(nn.Module):
    """Tree decoder that scores operators and quantities for prefix expressions.

    Sub-networks (each begins with the same shared ``nn.Dropout`` instance):
      * ``qp_gate``: gate on the node vector; also the last stage of ``gts_left_qp``.
      * ``gts_left`` / ``gts_left_qp``: left-child outputs from ``[qp, context, op]``.
      * ``gts_right``: right-child vector from ``[ql, tl]``.
      * ``attn_fc``, ``quant_fc``, ``op_fc``: attention and scoring heads.
      * ``subtree_gate``: merges ``(op, left subtree, right subtree)``.

    Reads the module-level ``n_hid`` and ``out_vocab`` (``n_constants``,
    ``n_ops``, ``base_nP``).
    """
    def __init__(self, dropout=0.5):
        super().__init__()
        # Registration order below fixes parameter-init order and state_dict layout.
        drop = nn.Dropout(dropout)
        self.constant_embedding = nn.Parameter(torch.randn(1, out_vocab.n_constants, n_hid))

        self.qp_gate = nn.Sequential(drop, Gate(n_hid, n_hid))
        self.gts_right = nn.Sequential(drop, Gate(2 * n_hid, n_hid))

        self.attn_fc = nn.Sequential(
            drop,
            nn.Linear(2 * n_hid, n_hid),
            nn.Tanh(),
            nn.Linear(n_hid, 1),
        )
        self.quant_fc = nn.Sequential(
            drop,
            nn.Linear(3 * n_hid, n_hid),
            nn.Tanh(),
            nn.Linear(n_hid, 1, bias=False),
        )
        self.op_fc = nn.Sequential(drop, nn.Linear(2 * n_hid, out_vocab.n_ops))

        # Row n_ops is an extra padding entry.
        self.op_embedding = nn.Embedding(out_vocab.n_ops + 1, n_hid, padding_idx=out_vocab.n_ops)
        self.gts_left = nn.Sequential(drop, Gate(3 * n_hid, n_hid))
        # Reuses (shares parameters with) self.qp_gate as its final stage.
        self.gts_left_qp = nn.Sequential(drop, Gate(3 * n_hid, n_hid), self.qp_gate)

        self.subtree_gate = nn.Sequential(drop, Gate(3 * n_hid, n_hid))

    def gts_attention(self, q, zbar, in_mask=None):
        """Attention-pool ``zbar`` with query ``q``.

        Args:
            q: (n_batch, n_hid) query vector.
            zbar: (n_batch, n_in, n_hid) encoder states.
            in_mask: optional bool (n_batch, n_in); False positions get zero weight.

        Returns:
            (n_batch, n_hid) context vector.
        """
        query = q.unsqueeze(1).expand_as(zbar)
        scores = self.attn_fc(torch.cat([query, zbar], dim=2)).squeeze(2)
        if in_mask is not None:
            scores = scores.masked_fill(~in_mask, -np.inf)
        weights = scores.softmax(dim=1).unsqueeze(1)
        return torch.bmm(weights, zbar).squeeze(1)

    def gts_predict(self, qp_Gc, quant_embed, nP_out_mask=None):
        """Score every output class: operator scores first, then one per quantity.

        Args:
            qp_Gc: (n_batch, 2 * n_hid) node vector concatenated with its context.
            quant_embed: (n_batch, n_quant, n_hid) quantity embeddings.
            nP_out_mask: optional bool mask over the columns from
                ``out_vocab.base_nP`` onward; False columns are set to -inf.

        Returns:
            (n_batch, n_ops + n_quant) unnormalised scores.
        """
        n_quant = quant_embed.size(1)
        goal = qp_Gc.unsqueeze(1).expand(-1, n_quant, -1)
        # Quantity head runs before the operator head (dropout RNG order).
        quant_score = self.quant_fc(torch.cat([goal, quant_embed], dim=2)).squeeze(2)
        op_score = self.op_fc(qp_Gc)
        pred_score = torch.cat((op_score, quant_score), dim=1)
        if nP_out_mask is not None:
            pred_score[:, out_vocab.base_nP:][~nP_out_mask] = -np.inf
        return pred_score

    def merge_subtree(self, op, tl, yr):
        """Combine an operator embedding with its left and right subtree embeddings."""
        merged_input = torch.cat((op, tl, yr), dim=-1)
        return self.subtree_gate(merged_input)

class Model(nn.Module):
    """Graph-to-tree model: sequence encoder + ``GCN`` + ``TreeDecoder``.

    When ``use_t5`` is truthy, the sequence encoder is the encoder of the
    module-level ``t5_model``; blocks listed in ``freeze_layers`` are frozen.
    Otherwise it is a stack of ``n_layers`` ``TransformerBlock``s over learned
    token and position embeddings. Also reads ``in_vocab``, ``n_hid`` and
    ``n_max_in``.
    """
    def __init__(self, dropout=0.5):
        super().__init__()
        drop = nn.Dropout(dropout)

        if use_t5:
            self.t5_encoder = t5_model.encoder
            # Stop gradients for the encoder blocks listed in freeze_layers.
            for i_layer, block in enumerate(self.t5_encoder.block):
                if i_layer in freeze_layers:
                    for param in block.parameters():
                        param.requires_grad = False
        else:
            self.in_embed = nn.Sequential(nn.Embedding(in_vocab.n, n_hid, padding_idx=in_vocab.pad), drop)
            # One extra position for the token prepended in ``encode``.
            self.pos_embed = nn.Embedding(1 + n_max_in, n_hid)
            self.transformer_layers = nn.ModuleList(TransformerBlock() for _ in range(n_layers))

        self.gcn = GCN()

        self.decoder = TreeDecoder()

        # Skipped for T5, presumably so its pretrained embeddings are kept.
        if not use_t5:
            self.apply(self.init_weight)

    def init_weight(self, m):
        """Initialise every ``nn.Embedding`` weight from N(0, 0.1)."""
        if type(m) in [nn.Embedding]:
            nn.init.normal_(m.weight, 0, 0.1)

    def encode(self, in_idxs, n_in, gt_graph, attr_graph, in_mask=None):
        """Encode a batch of token ids.

        A padding token is prepended to each sequence; its output state is
        returned as the summary vector ``zg``, and the remaining states are
        refined by the GCN.

        Args:
            in_idxs: (n_batch, n_in_max) token ids.
            n_in: per-example input lengths (not used here; kept for callers).
            gt_graph, attr_graph: graphs passed through to ``self.gcn``.
            in_mask: optional bool (n_batch, n_in_max), True for real tokens;
                only used by the custom Transformer encoder.

        Returns:
            ``(zbar, zg)``: (n_batch, n_in_max, n_hid) states and the
            (n_batch, n_hid) summary.
        """
        in_idxs_pad = F.pad(in_idxs, (1, 0), value=in_vocab.pad)
        if in_mask is not None:
            # Keep the mask aligned with the prepended summary position.
            in_mask = torch.cat([in_mask.new_ones(in_mask.size(0), 1), in_mask], dim=1)
        if use_t5:
            # Hugging Face encoders return a ModelOutput (a tuple in old
            # versions); index 0 is the last hidden state in both cases.
            # Tuple-unpacking a ModelOutput would iterate its *keys* instead.
            h = self.t5_encoder(in_idxs_pad)[0]
        else:
            x = self.in_embed(in_idxs_pad)
            h = x + self.pos_embed(torch.arange(x.size(1), device=x.device))
            for layer in self.transformer_layers:
                h = layer(h, mask=in_mask)
        zg, h = h[:, 0], h[:, 1:]
        zbar = self.gcn(h, gt_graph, attr_graph)
        return zbar, zg

# Training a Batch

In [None]:
class Node:
    """Mutable expression-tree node used while teacher-forcing the tree decoder.

    Attributes:
        up: parent node, or None for the root.
        is_root: True when the node was created without a parent.
        left, right: child nodes, filled in during decoding.
        ql: left-branch output of the decoder (``gts_left``) for this node's operator.
        tl: embedding of the completed left subtree.
        op: operator embedding stored at this node.
    """
    def __init__(self, up):
        self.up = up
        self.is_root = self.up is None
        # Children and decoder state start empty.
        self.left = None
        self.right = None
        self.ql = None
        self.tl = None
        self.op = None

def train(batch, model, opt):
    """Run one teacher-forced optimisation step on ``batch``.

    Args:
        batch: list of example dicts; reads ``n_in``, ``in_idxs``,
            ``nP_in_mask``, ``nP_out_mask``, ``in_tokens``, ``out_idxs`` and
            ``nP_candidates`` (plus what the edge builders read).
        model: a ``Model`` instance.
        opt: optimiser over ``model``'s parameters; stepped once here.

    Returns:
        The cross-entropy loss of the batch as a Python float.
    """
    n_batch = len(batch)

    n_in = [d['n_in'] for d in batch]
    pad = lambda x, value: nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=value)
    in_idxs = pad([d['in_idxs'] for d in batch], in_vocab.pad).to(device)
    in_mask = pad([torch.ones(n, dtype=torch.bool) for n in n_in], False).to(device)
    nP_in_mask = pad([d['nP_in_mask'] for d in batch], False).to(device)
    nP_out_mask = torch.stack([d['nP_out_mask'] for d in batch]).to(device)

    # One comparison graph and one cell graph per example, all with the same
    # node count, batched for the GCN.
    n_nodes = max([len(d['in_tokens']) for d in batch])
    qcomp_graph, qcell_graph = [], []
    for d in batch:
        qcomp_graph_i = dgl.graph(get_quantity_comparison_edges(d), num_nodes=n_nodes).to(device)
        qcell_graph_i = dgl.graph(get_quantity_cell_edges(d), num_nodes=n_nodes).to(device)

        qcomp_graph.append(qcomp_graph_i)
        qcell_graph.append(qcell_graph_i)
    qcomp_graph = dgl.batch(qcomp_graph)
    qcell_graph = dgl.batch(qcell_graph)

    # (n_batch, n_max_out) targets; stays on the CPU until the loss is computed.
    label = pad([d['out_idxs'] for d in batch], out_vocab.pad)
    nP_candidates = [d['nP_candidates'] for d in batch]

    zbar, qroot = model.encode(in_idxs, n_in, qcomp_graph, qcell_graph, in_mask=None)
    # Scatter the encoder states at quantity positions into fixed quantity slots.
    z_nP = zbar.new_zeros((n_batch, n_max_nP, n_hid))
    z_nP[nP_out_mask] = zbar[nP_in_mask]

    decoder = model.decoder

    n_quant = out_vocab.n_constants + n_max_nP
    quant_embed = torch.cat([decoder.constant_embedding.expand(n_batch, -1, -1), z_nP], dim=1)

    # Current (not yet decoded) node of each example's tree.
    nodes = np.array([Node(None) for _ in range(n_batch)])
    op_min, op_max = out_vocab.base_op, out_vocab.base_op + out_vocab.n_ops
    quant_min, quant_max = out_vocab.base_quant, out_vocab.base_quant + n_quant

    qp = decoder.qp_gate(qroot)
    scores = []
    for i, label_i in enumerate(label.T):
        Gc = decoder.gts_attention(qp, zbar, in_mask)
        qp_Gc = torch.cat([qp, Gc], dim=1)

        score = decoder.gts_predict(qp_Gc, quant_embed, nP_out_mask)
        scores.append(score)

        is_op = (op_min <= label_i) & (label_i < op_max)
        is_quant = ((quant_min <= label_i) & (label_i < quant_max)) | (label_i == out_vocab.unk)

        # Operator targets: open a left child under the current node.
        op_embed = decoder.op_embedding((label_i[is_op] - out_vocab.base_op).to(device))
        qp_Gc_op = torch.cat([qp_Gc[is_op], op_embed], dim=1)

        # Builtin bool: the np.bool alias was removed in NumPy 1.24.
        is_left = np.zeros(n_batch, dtype=bool)
        qleft_qp = decoder.gts_left_qp(qp_Gc_op)
        qleft = decoder.gts_left(qp_Gc_op)
        for j, ql, op in zip(is_op.nonzero(as_tuple=True)[0], qleft, op_embed):
            node = nodes[j]
            nodes[j] = node.left = Node(node)
            node.op = op
            node.ql = ql
            is_left[j] = True

        # Quantity targets: close the leaf, merge finished subtrees upward and
        # move on to the next open right child.
        is_right = np.zeros(n_batch, dtype=bool)
        nP_score = score[:, out_vocab.base_nP:].detach().cpu()
        ql_tl = []
        for j in is_quant.nonzero(as_tuple=True)[0]:
            if label_i[j] == out_vocab.unk:
                # Resolve an ambiguous quantity to the candidate the model
                # currently scores highest. label_i is a view of label, so
                # the loss below sees the resolved target.
                candidates = nP_candidates[j][i]
                label_i[j] = out_vocab.base_nP + candidates[nP_score[j, candidates].argmax()]

            node = nodes[j]
            pnode = node.up
            t = quant_embed[j, label_i[j] - out_vocab.base_quant]
            while pnode and pnode.right is node:
                t = decoder.merge_subtree(pnode.op, pnode.tl, t)
                node, pnode = pnode, pnode.up
            if pnode is None:
                # This example's tree is complete.
                continue
            pnode.tl = t
            ql_tl.append(torch.cat([pnode.ql, pnode.tl]))
            nodes[j] = pnode.right = Node(pnode)
            is_right[j] = True

        qp = torch.zeros((n_batch, n_hid), device=device)
        qp[is_left] = qleft_qp
        if ql_tl:
            qp[is_right] = decoder.gts_right(torch.stack(ql_tl))

    label = label.to(device).view(-1)
    scores = torch.stack(scores, dim=1).view(-1, out_vocab.n_ops + n_quant)
    loss = F.cross_entropy(scores, label, ignore_index=out_vocab.pad)

    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

def flag(batch, model, opt, M, alpha):
    """Apparently intended as FLAG-style adversarial augmentation: M ascent steps of size alpha.

    NOTE(review): this function cannot run as written, and it is not called in
    the visible code:
      * ``optimizer`` is not defined in this file (the parameter is ``opt``);
      * ``batch`` is built elsewhere as a list of dicts, which has no ``.shape``;
      * ``train`` already calls ``backward()``/``step()`` and returns
        ``loss.item()`` (a float), so ``loss.backward()`` below would fail;
      * ``pert`` is never passed to ``train``, so it cannot affect the loss.
    A working version needs ``train`` to accept a perturbation and to return
    the loss tensor without stepping the optimiser.
    """
    model.train()
    optimizer.zero_grad()

    pert = torch.FloatTensor(*batch.shape).uniform_(-alpha, alpha)
    pert.requires_grad_()
    loss = train(batch, model, opt) / M
    
    for _ in range(M-1):
        loss.backward()
        # Gradient-sign ascent step on the perturbation.
        pert_data = pert.detach() + alpha*torch.sign(pert.grad.detach())
        pert.data = pert_data.data
        pert.grad[:] = 0
        loss = train(batch, model, opt) / M
      
    loss.backward()
    optimizer.step()

# Prediction

In [None]:
class BeamNode(Node):
    """Tree node that also records how decoding reached it, for beam search.

    Args:
        up: parent node in the expression tree (None for the root).
        prev: the node that was current when ``token`` was emitted (None at the start).
        qp: decoder vector to predict from at this node.
        token: output token index emitted to reach this node.
    """
    def __init__(self, up, prev, qp, token=None):
        super().__init__(up)
        self.prev = prev
        self.qp = qp
        self.token = token

    def trace_tokens(self, *last_token):
        """Return the tokens emitted from the start of decoding up to this node,
        followed by any ``last_token`` values."""
        history = []
        cursor = self
        while cursor.prev is not None:
            history.append(cursor.token)
            cursor = cursor.prev
        history.reverse()
        history.extend(last_token)
        return history

def predict(d, model, beam_size=5, n_max_out=45):
    """Beam-search decode a prefix-expression token sequence for one example.

    Args:
        d: example dict with ``in_idxs``, ``in_tokens``, ``n_in`` and
            ``nP_positions`` (plus what the edge builders read).
        model: trained ``Model``.
        beam_size: hypotheses kept per step, and candidate tokens expanded per
            hypothesis.
        n_max_out: maximum number of decoding steps.

    Returns:
        List of output token indices for the best completed expression, or an
        empty list if no expression was completed within ``n_max_out`` steps.
    """
    in_idxs = d['in_idxs'].unsqueeze(0).to(device=device)
    n_nodes = len(d['in_tokens'])
    qcomp_graph = dgl.graph(get_quantity_comparison_edges(d), num_nodes=n_nodes).to(device)
    qcell_graph = dgl.graph(get_quantity_cell_edges(d), num_nodes=n_nodes).to(device)

    zbar, qroot = model.encode(in_idxs, [d['n_in']], qcomp_graph, qcell_graph)
    z_nP = zbar[:, d['nP_positions']]

    decoder = model.decoder

    quant_embed = torch.cat([decoder.constant_embedding, z_nP], dim=1)
    op_min, op_max = out_vocab.base_op, out_vocab.base_op + out_vocab.n_ops

    # Best finished hypothesis: (log-prob, node before the last token, last token).
    best_done_beam = (-np.inf, None, None)
    beams = [(0, BeamNode(up=None, prev=None, qp=decoder.qp_gate(qroot)))]
    for _ in range(n_max_out):
        new_beams = []
        for logp_prev, node in beams:
            Gc = decoder.gts_attention(node.qp, zbar)
            qp_Gc = torch.cat([node.qp, Gc], dim=1)

            log_prob = decoder.gts_predict(qp_Gc, quant_embed).log_softmax(dim=1)
            # topk raises if k exceeds the number of output classes.
            n_expand = min(beam_size, log_prob.size(1))
            top_logps, top_tokens = log_prob.topk(n_expand, dim=1)
            for logp_token_, out_token_ in zip(top_logps.unbind(dim=1), top_tokens.unbind(dim=1)):
                out_token = out_token_.item()
                logp = logp_prev + logp_token_.item()
                if op_min <= out_token < op_max:
                    # Operator: copy the current node (other candidates share
                    # it) and open its left child. Embedding rows are
                    # 0-based operator ids, as in train().
                    op_embed = decoder.op_embedding(out_token_ - out_vocab.base_op)
                    qp_Gc_op = torch.cat([qp_Gc, op_embed], dim=1)
                    prev_node = copy(node)
                    next_node = prev_node.left = BeamNode(
                        up=prev_node, prev=prev_node,
                        qp=decoder.gts_left_qp(qp_Gc_op),
                        token=out_token
                    )
                    prev_node.op = op_embed
                    prev_node.ql = decoder.gts_left(qp_Gc_op)
                else:
                    # Quantity: finish this leaf and fold every completed right
                    # subtree into its parent. Climb with ``pnode`` only;
                    # ``node`` must keep pointing at this beam because the
                    # remaining candidate tokens are expanded from it too.
                    pnode, prev_node = node.up, node
                    t = quant_embed[:, out_token - out_vocab.base_quant]
                    while pnode and pnode.tl is not None:
                        t = decoder.merge_subtree(pnode.op, pnode.tl, t)
                        pnode = pnode.up
                    if pnode is None:
                        # Expression complete. Compare on log-prob only: a
                        # tuple max() would compare BeamNodes on ties and raise.
                        if logp > best_done_beam[0]:
                            best_done_beam = (logp, prev_node, out_token)
                        continue
                    pnode = copy(pnode)
                    pnode.tl = t
                    next_node = pnode.right = BeamNode(
                        up=pnode, prev=prev_node,
                        qp=decoder.gts_right(torch.cat([pnode.ql, pnode.tl], dim=1)),
                        token=out_token
                    )
                new_beams.append((logp, next_node))
        beams = sorted(new_beams, key=lambda x: x[0], reverse=True)[:beam_size]
        # Log-probs only decrease as tokens are added, so stop once no open
        # hypothesis can beat the best finished one.
        if not beams or best_done_beam[0] >= beams[0][0]:
            break
    done_logp, done_node, done_last_token = best_done_beam
    if done_node is None:
        # No expression was completed within n_max_out steps.
        return []
    return done_node.trace_tokens(done_last_token)

# Training

In [None]:
## Load the pre-tokenised math word problem dataset
import pickle
import transformers

# NOTE: unpickling can execute arbitrary code -- only load trusted files.
with open('t5-math-data.pickle', 'rb') as pick:
    dataset = pickle.load(pick)

print(len(dataset))

In [None]:
MAXLEN = 512
BATCH_SIZE = 1

class CreateDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenised math word problems.

    Each element of ``data`` must provide ``input_ids_list``,
    ``label_ids_list``, ``nP`` and ``nP_positions``. The question and answer
    are truncated to ``max_tokens``. The answer, the quantities and their
    positions are padded with -100 to ``max_tokens`` so that items have equal
    shapes.
    """
    def __init__(self, data, max_tokens):
        self.data = data
        self.max_tokens = max_tokens

    def __len__(self):
        # One item per example (not the token limit).
        return len(self.data)

    def __getitem__(self, item):
        example = self.data[item]
        question = example['input_ids_list'][:self.max_tokens]
        answer = example['label_ids_list'][:self.max_tokens]
        nP = example['nP']
        nP_positions = example['nP_positions']

        # Pad the answer to a fixed length.
        answer = F.pad(torch.Tensor(answer), [0, self.max_tokens - len(answer)], mode='constant', value=-100)

        # Pad the quantities and their positions to a fixed length.
        num_to_pad = self.max_tokens - len(nP)
        nP = F.pad(torch.Tensor(nP), [0, num_to_pad], mode='constant', value=-100)
        nP_locs = F.pad(torch.LongTensor(nP_positions), [0, num_to_pad], mode='constant', value=-100)

        # 1 at every quantity position; positions past the truncation limit
        # are skipped instead of raising IndexError.
        nP_locs_mask = torch.zeros(self.max_tokens)
        nP_locs_mask[[p for p in nP_positions if p < self.max_tokens]] = 1

        return {
            'question': question,
            'answer': answer,
            'nP': nP,
            'nP_positions': nP_locs,
            'nP_in_mask': nP_locs_mask
        }

def create_data_loader(dataset):
    """Wrap ``dataset`` in a ``CreateDataset`` (limit ``MAXLEN``) and return a DataLoader over it."""
    return torch.utils.data.DataLoader(
        CreateDataset(dataset, MAXLEN),
        batch_size=BATCH_SIZE,
        num_workers=4,
    )

In [None]:
train_data_loader = create_data_loader(dataset)

# Sanity check: fetch a single batch and list the fields it contains.
check_data = next(iter(train_data_loader))
print(check_data.keys())


In [None]:
## Save the vocab from the t5 tokenizer
tokenizer = transformers.T5Tokenizer.from_pretrained('t5-large')
# save_vocabulary() writes the tokenizer's vocabulary file into the directory
# and returns a tuple of the written file paths. Write it once and share the
# result, since both names received the same value.
# NOTE(review): the model code reads vocabulary attributes from these names
# (e.g. in_vocab.pad, in_vocab.n, out_vocab.n_ops, out_vocab.base_nP). A tuple
# of paths has none of them, so real vocabulary objects are needed here.
in_vocab = out_vocab = tokenizer.save_vocabulary('./')

print(in_vocab)

In [None]:
## Setup the t5 model
# NOTE(review): this loads t5-large, while the configuration cells map
# use_t5='small' to n_hid=512 -- confirm the checkpoint's hidden size matches n_hid.
t5_model = transformers.T5Model.from_pretrained(pretrained_model_name_or_path='t5-large')


In [None]:
# Configuration for the custom-Transformer (use_t5=None) run.
hyp_use_t5 = None
hyp_n_max_in = 100
hyp_n_epochs = 100
hyp_n_batch = 64
hyp_lr = 1e-3
hyp_n_layers = 3
hyp_n_hid = 512
hyp_kv = 64
hyp_n_head = 8
hyp_wd = 0
hyp_t5_freeze_layers = []
hyp_t5_decay = 1e-5
hyp_eval_epoch = 30  # NOTE(review): unused; evaluation below runs every 10 epochs.

use_t5 = hyp_use_t5
model_save_dir = f'./models/{use_t5 or "custom"}'
os.makedirs(model_save_dir, exist_ok=True)

n_max_in = hyp_n_max_in #100
n_epochs = hyp_n_epochs #100
n_batch = hyp_n_batch #64
learning_rate = hyp_lr #1e-3
if use_t5:
    freeze_layers = hyp_t5_freeze_layers #[]
    weight_decay = hyp_t5_decay #1e-5
    n_hid = dict(small=512, base=768)[use_t5]
else:
    n_layers = hyp_n_layers #3
    n_hid = hyp_n_hid #512
    n_k = n_v = hyp_kv #64
    n_head = hyp_n_head #8
    weight_decay = hyp_wd #0
device = 'cuda:0'

# train_data, val_data, in_vocab, out_vocab, n_max_nP, t5_model = setup(use_t5)
# Materialise the loader once. The loop below slices train_data, which a
# DataLoader does not support. Each pass over a DataLoader also yields freshly
# collated dicts, so the fields tensorize_data adds in place would be lost.
# NOTE(review): each element is a collated batch of BATCH_SIZE examples rather
# than a raw example dict -- confirm this matches what train()/predict() expect.
train_data = list(train_data_loader)
# NOTE(review): validation reuses the training data.
val_data = train_data
n_max_nP = 10

tensorize_data(itertools.chain(train_data, val_data))

model = Model()
opt = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, n_epochs)
model.to(device)

epoch = 0
while epoch < n_epochs:
    print('Epoch:', epoch + 1)
    model.train()
    losses = []
    for start in trange(0, len(train_data), n_batch):
        # Longest inputs first within each batch.
        batch = sorted(train_data[start: start + n_batch], key=lambda d: -d['n_in'])
        loss = train(batch, model, opt)
        losses.append(loss)
    scheduler.step()

    print(f'Training loss: {np.mean(losses):.3g}')

    epoch += 1
    # Evaluate and checkpoint every 10 epochs.
    if epoch % 10 == 0:
        model.eval()
        value_match, equation_match = [], []
        with torch.no_grad():
            for d in tqdm(val_data):
                if d['is_quadratic']:
                    val_match = eq_match = False
                else:
                    pred = predict(d, model)
                    d['pred_tokens'] = [out_vocab.idx2token[idx] for idx in pred]
                    val_match, eq_match = check_match(pred, d)
                value_match.append(val_match)
                equation_match.append(eq_match)
        print(f'Validation expression accuracy: {np.mean(equation_match):.3g}')
        print(f'Validation value accuracy: {np.mean(value_match):.3g}')
        torch.save(model.state_dict(), os.path.join(model_save_dir, f'model-{epoch}.pth'))
    print()

# Custom Transformer Benchmark (cell above)

In [None]:
# T5-small configuration. Every hyp_* value is set, including ones the T5
# path does not read, so later cells see a complete configuration.
hyp_use_t5 = "small"
hyp_n_max_in = 100
hyp_n_epochs = 100
hyp_n_batch = 64
hyp_lr = 1e-3
hyp_n_layers = 3
hyp_n_hid = 512
hyp_kv = 64
hyp_n_head = 8
hyp_wd = 0
hyp_t5_freeze_layers = []
hyp_t5_decay = 1e-5
hyp_eval_epoch = 30

use_t5 = hyp_use_t5
model_save_dir = f'./models/{use_t5 or "custom"}'
os.makedirs(model_save_dir, exist_ok=True)

n_max_in, n_epochs = hyp_n_max_in, hyp_n_epochs
n_batch, learning_rate = hyp_n_batch, hyp_lr
if not use_t5:
    n_layers = hyp_n_layers
    n_hid = hyp_n_hid
    n_k = n_v = hyp_kv
    n_head = hyp_n_head
    weight_decay = hyp_wd
else:
    freeze_layers = hyp_t5_freeze_layers
    weight_decay = hyp_t5_decay
    n_hid = {'small': 512, 'base': 768}[use_t5]
device = 'cuda:0'

train_data, val_data, in_vocab, out_vocab, n_max_nP, t5_model = setup(use_t5)
tensorize_data(itertools.chain(train_data, val_data))

model = Model()
opt = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, n_epochs)
model.to(device)

epoch = 0  # keeps ``epoch`` defined even when n_epochs == 0
for epoch in range(1, n_epochs + 1):
    print('Epoch:', epoch)
    model.train()
    losses = []
    for start in trange(0, len(train_data), n_batch):
        # Longest inputs first within each batch.
        batch = sorted(train_data[start: start + n_batch], key=lambda d: -d['n_in'])
        loss = train(batch, model, opt)
        losses.append(loss)
    scheduler.step()

    print(f'Training loss: {np.mean(losses):.3g}')

    # Only every 10th epoch is evaluated and checkpointed.
    if epoch % 10 != 0:
        print()
        continue

    model.eval()
    value_match, equation_match = [], []
    with torch.no_grad():
        for d in tqdm(val_data):
            if not d['is_quadratic']:
                pred = predict(d, model)
                d['pred_tokens'] = [out_vocab.idx2token[idx] for idx in pred]
                val_match, eq_match = check_match(pred, d)
            else:
                # Quadratic problems count as misses without decoding.
                val_match = eq_match = False
            value_match.append(val_match)
            equation_match.append(eq_match)
    print(f'Validation expression accuracy: {np.mean(equation_match):.3g}')
    print(f'Validation value accuracy: {np.mean(value_match):.3g}')
    torch.save(model.state_dict(), os.path.join(model_save_dir, f'model-{epoch}.pth'))
    print()

# T5-Small Benchmark (cell above)


In [None]:
# One-at-a-time hyperparameter sweep for the custom Transformer. Each entry
# overrides a single baseline value; results[hyp][val] collects
# (expression accuracy, value accuracy) at every evaluation.
hyp_search_custom_results = {}
hyp_search_custom_conf = [
    ("batch", 64),
    ("batch", 32),
    ("batch", 128),
    ("lr", 0.005),
    ("lr", 0.0002),
    ("n_l", 5),
    ("n_l", 7),
    ("n_hid", 1024),
    ("kv", 128),
    ("head", 4),
    ("head", 16),
]

for hyp, val in hyp_search_custom_conf:

    hyp_use_t5 = None
    hyp_n_max_in = 100
    hyp_n_epochs = 100
    hyp_n_batch = 64
    hyp_lr = 1e-3
    hyp_n_layers = 3
    hyp_n_hid = 512
    hyp_kv = 64
    hyp_n_head = 8
    hyp_wd = 0
    hyp_t5_freeze_layers = []
    hyp_t5_decay = 1e-5
    hyp_eval_epoch = 30

    if hyp == "batch":
        hyp_n_batch = val
    elif hyp == "lr":
        hyp_lr = val
    elif hyp == "n_l":
        hyp_n_layers = val
    elif hyp == "n_hid":
        hyp_n_hid = val
    elif hyp == "kv":
        hyp_kv = val
    elif hyp == "head":
        hyp_n_head = val

    if hyp not in hyp_search_custom_results:
        hyp_search_custom_results[hyp] = {}
    hyp_search_custom_results[hyp][val] = []

    use_t5 = hyp_use_t5
    # Separate directory per (hyperparameter, value): with a shared directory
    # every config overwrote the previous config's model-{epoch}.pth files.
    model_save_dir = os.path.join(f'./models/{use_t5 or "custom"}', f'{hyp}-{val}')
    os.makedirs(model_save_dir, exist_ok=True)

    n_max_in = hyp_n_max_in #100
    n_epochs = hyp_n_epochs #100
    n_batch = hyp_n_batch #64
    learning_rate = hyp_lr #1e-3
    if use_t5:
        freeze_layers = hyp_t5_freeze_layers #[]
        weight_decay = hyp_t5_decay #1e-5
        n_hid = dict(small=512, base=768)[use_t5]
    else:
        n_layers = hyp_n_layers #3
        n_hid = hyp_n_hid #512
        n_k = n_v = hyp_kv #64
        n_head = hyp_n_head #8
        weight_decay = hyp_wd #0
    device = 'cuda:0'

    #train_data, val_data, in_vocab, out_vocab, n_max_nP, t5_model = setup(use_t5)
    tensorize_data(itertools.chain(train_data, val_data))

    model = Model()
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, n_epochs)
    model.to(device)

    epoch = 0
    while epoch < n_epochs:
        print('Epoch:', epoch + 1)
        model.train()
        losses = []
        for start in trange(0, len(train_data), n_batch):
            batch = sorted(train_data[start: start + n_batch], key=lambda d: -d['n_in'])
            loss = train(batch, model, opt)
            losses.append(loss)
        scheduler.step()

        print(f'Training loss: {np.mean(losses):.3g}')

        epoch += 1
        if epoch % 10 == 0:
            model.eval()
            value_match, equation_match = [], []
            with torch.no_grad():
                for d in tqdm(val_data):
                    if d['is_quadratic']:
                        val_match = eq_match = False
                    else:
                        pred = predict(d, model)
                        d['pred_tokens'] = [out_vocab.idx2token[idx] for idx in pred]
                        val_match, eq_match = check_match(pred, d)
                    value_match.append(val_match)
                    equation_match.append(eq_match)
            print(f'Validation expression accuracy: {np.mean(equation_match):.3g}')
            print(f'Validation value accuracy: {np.mean(value_match):.3g}')

            hyp_search_custom_results[hyp][val].append((np.mean(equation_match), np.mean(value_match)))
            torch.save(model.state_dict(), os.path.join(model_save_dir, f'model-{epoch}.pth'))
        print()

print()
print(hyp_search_custom_results)

## Custom Transformer Hyperparameter Search Results

| Batch Size | Learning Rate | Num Layers | Num Hidden | kv dim | Num Head | Val Expression Acc | Val Value Acc |
| --- | --- | --- | --- | --- | --- |  --- | --- |
| 64 | 0.001 | 3 | 512 | 64 | 8 | 0.742 | 0.765 |
| 32 | 0.001 | 3 | 512 | 64 | 8 | 0.000 | 0.023 |
| 128 | 0.001 | 3 | 512 | 64 | 8 | 0.742 | 0.770 |
| 64 | 0.005 | 3 | 512 | 64 | 8 | 0.000 | 0.019 |
| 64 | 0.0002 | 3 | 512 | 64 | 8 | 0.732 | 0.746 |
| 64 | 0.001 | 5 | 512 | 64 | 8 | 0.709 | 0.723 |
| 64 | 0.001 | 7 | 512 | 64 | 8 | 0.540 | 0.545 |
| 64 | 0.001 | 3 | 1024 | 64 | 8 | 0.000 | 0.014 |
| 64 | 0.001 | 3 | 512 | 128 | 8 | 0.751 | 0.775 |
| 64 | 0.001 | 3 | 512 | 64 | 4 | 0.761 | 0.779 |
| 64 | 0.001 | 3 | 512 | 64 | 16 | 0.714 | 0.732 |

In [None]:
hyp_search_t5_results = {}
hyp_search_t5_conf = [
    ("freeze", [0, 1, 2, 3, 4, 5]),
    ("freeze", [0, 1, 2, 3]),
    ("freeze", [0, 1]),
    ("decay", 1e-6),
    ("decay", 1e-4),
]

# One-at-a-time sweep: every configuration starts from the defaults below and
# overrides exactly one hyperparameter (`hyp`) with `val`.
for hyp, val in hyp_search_t5_conf:

    hyp_use_t5 = "small"
    hyp_n_max_in = 100
    hyp_n_epochs = 100
    hyp_n_batch = 64
    hyp_lr = 1e-3
    hyp_n_layers = 3
    hyp_n_hid = 512
    hyp_kv = 64
    hyp_n_head = 8
    hyp_wd = 0
    hyp_t5_freeze_layers = []
    hyp_t5_decay = 1e-5
    hyp_eval_epoch = 30  # NOTE(review): not read in this cell; validation runs every 10 epochs

    if hyp == "batch":
        hyp_n_batch = val
    elif hyp == "lr":
        hyp_lr = val
    elif hyp == "freeze":
        hyp_t5_freeze_layers = val
    elif hyp == "decay":
        hyp_t5_decay = val

    # results[hyp][str(val)] -> list of (expression_acc, value_acc), one per validation pass
    hyp_search_t5_results.setdefault(hyp, {})[str(val)] = []

    use_t5 = hyp_use_t5
    # Each configuration gets its own checkpoint directory. A shared
    # ./models/<use_t5>/model-{epoch}.pth path would let every configuration
    # overwrite the checkpoints written by the previous one.
    val_tag = '-'.join(map(str, val)) if isinstance(val, (list, tuple)) else str(val)
    model_save_dir = os.path.join(f'./models/{use_t5 or "custom"}', f'{hyp}_{val_tag or "none"}')
    os.makedirs(model_save_dir, exist_ok=True)

    # The module-level names below are set as globals; presumably Model() and
    # train()/predict() (defined elsewhere) read them — verify before renaming.
    n_max_in = hyp_n_max_in #100
    n_epochs = hyp_n_epochs #100
    n_batch = hyp_n_batch #64
    learning_rate = hyp_lr #1e-3
    if use_t5:
        freeze_layers = hyp_t5_freeze_layers #[]
        weight_decay = hyp_t5_decay #1e-5
        n_hid = dict(small=512, base=768)[use_t5]
    else:
        n_layers = hyp_n_layers #3
        n_hid = hyp_n_hid #512
        n_k = n_v = hyp_kv #64
        n_head = hyp_n_head #8
        weight_decay = hyp_wd #0
    device = 'cuda:0'

    # Drop the previous configuration's model/optimizer/T5 weights before
    # loading the next ones, so their tensors can be freed first.
    model = opt = scheduler = t5_model = None
    train_data, val_data, in_vocab, out_vocab, n_max_nP, t5_model = setup(use_t5)
    tensorize_data(itertools.chain(train_data, val_data))

    model = Model()
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, n_epochs)
    model.to(device)

    epoch = 0
    while epoch < n_epochs:
        print('Epoch:', epoch + 1)
        model.train()
        losses = []
        for start in trange(0, len(train_data), n_batch):
            # Within a batch, examples are ordered by decreasing input length.
            batch = sorted(train_data[start: start + n_batch], key=lambda d: -d['n_in'])
            loss = train(batch, model, opt)
            losses.append(loss)
        scheduler.step()

        print(f'Training loss: {np.mean(losses):.3g}')

        epoch += 1
        if epoch % 10 == 0:
            model.eval()
            value_match, equation_match = [], []
            with torch.no_grad():
                for d in tqdm(val_data):
                    # Quadratic problems are always counted as misses.
                    if d['is_quadratic']:
                        val_match = eq_match = False
                    else:
                        pred = predict(d, model)
                        d['pred_tokens'] = [out_vocab.idx2token[idx] for idx in pred]
                        val_match, eq_match = check_match(pred, d)
                    value_match.append(val_match)
                    equation_match.append(eq_match)
            expression_acc = np.mean(equation_match)
            value_acc = np.mean(value_match)
            print(f'Validation expression accuracy: {expression_acc:.3g}')
            print(f'Validation value accuracy: {value_acc:.3g}')

            hyp_search_t5_results[hyp][str(val)].append((expression_acc, value_acc))
            torch.save(model.state_dict(), os.path.join(model_save_dir, f'model-{epoch}.pth'))
        print()

print()
print(hyp_search_t5_results)

T5 Hyperparameters

Each row changes a single setting relative to the baseline configuration in the first row.

| Batch Size | Freeze Layers | Weight Decay | Val Expression Acc | Val Value Acc |
| --- | --- | --- | --- | --- |
| 64 | [] | 1e-5 | 0.709 | 0.728 |
| 128 | [] | 1e-5 | 0.709 | 0.728 |
| 64 | [0, 1] | 1e-5 | 0.723 | 0.742 |
| 64 | [0, 1, 2, 3] | 1e-5 | 0.685 | 0.704 |
| 64 | [0, 1, 2, 3, 4, 5] | 1e-5 | 0.648 | 0.671 |
| 64 | [] | 1e-6 | 0.718 | 0.737 |
| 64 | [] | 1e-4 | 0.737 | 0.761 |

# Prediction

In [None]:
hyp_use_t5 = "base"
hyp_n_max_in = 100
hyp_n_epochs = 120
hyp_n_batch = 64
hyp_lr = 0.0005
hyp_t5_freeze_layers = []
hyp_t5_decay = 0.000001 #0.0001
hyp_eval_epoch = 100  # NOTE(review): not read in this cell; the prediction cell sets its own eval_epoch

# Intermediate checkpoints are kept only when validation value accuracy beats
# this threshold. The final epoch is always saved as well, so the prediction
# cell, which loads model-{eval_epoch}.pth for the last epoch, finds its file.
save_value_acc_threshold = 0.78

use_t5 = hyp_use_t5
model_save_dir = f'./models/{use_t5 or "custom"}'
os.makedirs(model_save_dir, exist_ok=True)

# Module-level names below are set as globals; presumably Model(), train() and
# predict() (defined elsewhere) read them — verify before renaming.
n_max_in = hyp_n_max_in #100
n_epochs = hyp_n_epochs #100
n_batch = hyp_n_batch #64
learning_rate = hyp_lr #1e-3
if use_t5:
    freeze_layers = hyp_t5_freeze_layers #[]
    weight_decay = hyp_t5_decay #1e-5
    n_hid = dict(small=512, base=768)[use_t5]
else:
    # NOTE(review): hyp_n_layers/hyp_n_hid/hyp_kv/hyp_n_head/hyp_wd are not set
    # in this cell; this branch relies on an earlier cell having defined them.
    n_layers = hyp_n_layers #3
    n_hid = hyp_n_hid #512
    n_k = n_v = hyp_kv #64
    n_head = hyp_n_head #8
    weight_decay = hyp_wd #0
device = 'cuda:0'

train_data, val_data, in_vocab, out_vocab, n_max_nP, t5_model = setup(use_t5)
tensorize_data(itertools.chain(train_data, val_data))

model = Model()
opt = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, n_epochs)
model.to(device)

epoch = 0
while epoch < n_epochs:
    print('Epoch:', epoch + 1)
    model.train()

    losses = []
    for start in trange(0, len(train_data), n_batch):
        # Within a batch, examples are ordered by decreasing input length.
        batch = sorted(train_data[start: start + n_batch], key=lambda d: -d['n_in'])
        loss = train(batch, model, opt)
        losses.append(loss)
    scheduler.step()

    print(f'Training loss: {np.mean(losses):.3g}')

    epoch += 1
    save_checkpoint = epoch == n_epochs
    if epoch % 10 == 0:
        model.eval()
        value_match, equation_match = [], []
        with torch.no_grad():
            for d in tqdm(val_data):
                # Quadratic problems are always counted as misses.
                if d['is_quadratic']:
                    val_match = eq_match = False
                else:
                    pred = predict(d, model)
                    d['pred_tokens'] = [out_vocab.idx2token[idx] for idx in pred]
                    val_match, eq_match = check_match(pred, d)
                value_match.append(val_match)
                equation_match.append(eq_match)
        expression_acc = np.mean(equation_match)
        value_acc = np.mean(value_match)
        print(f'Validation expression accuracy: {expression_acc:.3g}')
        print(f'Validation value accuracy: {value_acc:.3g}')
        save_checkpoint = save_checkpoint or value_acc > save_value_acc_threshold
    if save_checkpoint:
        torch.save(model.state_dict(), os.path.join(model_save_dir, f'model-{epoch}.pth'))
    print()

In [None]:
use_t5 = hyp_use_t5
eval_epoch = 120  # epoch number of the checkpoint file to load
device = 'cuda:0' #'cpu'

# Module-level names below are set as globals; presumably Model() (defined
# elsewhere) reads them — verify before renaming.
n_max_in = hyp_n_max_in #100
if use_t5:
    freeze_layers = hyp_t5_freeze_layers #[]
    n_hid = dict(small=512, base=768)[use_t5]
else:
    n_layers = hyp_n_layers #3
    n_hid = hyp_n_hid #512
    n_k = n_v = hyp_kv #64
    n_head = hyp_n_head #8

test_data, in_vocab, out_vocab, n_max_nP, t5_model = setup(use_t5, do_eval=True)
model = Model().to(device)
# map_location lets a checkpoint written on a GPU be restored when device is 'cpu'.
model.load_state_dict(torch.load(f'./models/{use_t5 or "custom"}/model-{eval_epoch}.pth', map_location=device))
# Switch to inference mode like the training cell does before predict(),
# so layers such as dropout are disabled.
model.eval()
tensorize_data(test_data)

with torch.no_grad():
    for d in tqdm(test_data):
        pred = predict(d, model)
        d['pred_tokens'] = pred_tokens = [out_vocab.idx2token[idx] for idx in pred]
        # Substitute the problem's numbers into the predicted prefix expression, then evaluate it.
        d['subbed_tokens'] = subbed_tokens = sub_nP(pred_tokens, d['nP'])
        d['Predicted'] = round(evaluate_prefix_expression(subbed_tokens), 3)

import pandas as pd
predictions = pd.DataFrame(test_data).set_index('Id')

In [None]:
# Display the raw predicted tokens, the tokens after number substitution, and the evaluated answer.
predictions.loc[:, ['pred_tokens', 'subbed_tokens', 'Predicted']]

In [None]:
# Export only the final answers. Non-finite or missing answers are written as 0.
# Create the output directory first; to_csv does not create missing directories.
os.makedirs('./preds', exist_ok=True)
predictions[['Predicted']].replace([np.inf, -np.inf, np.nan], 0).to_csv('./preds/prediction_{}.csv'.format(eval_epoch))

In [None]:

import pandas as pd

# Checkpoint paths to ensemble. Each is run on the test set, and the
# per-problem answers are then merged by majority vote.
modelfps = [
]


def majority_vote_predictions(predictions_list):
    """Merge per-model prediction DataFrames by majority vote.

    Each DataFrame must be indexed by problem Id and contain the columns
    'pred_tokens', 'subbed_tokens' and 'Predicted'. For every Id of the first
    DataFrame, the (pred_tokens, subbed_tokens, Predicted) triple produced by
    the most models wins. Ties go to the candidate that appeared first, in
    the order of ``predictions_list``.

    Returns a new DataFrame (a copy of the first input with the voted values
    written in). The input DataFrames are not modified.

    Raises:
        ValueError: if ``predictions_list`` is empty.
    """
    if not predictions_list:
        raise ValueError('predictions_list must contain at least one DataFrame')
    columns = ('pred_tokens', 'subbed_tokens', 'Predicted')
    merged = predictions_list[0].copy()
    winners = {col: [] for col in columns}
    for idx in merged.index:
        votes = {}  # hashable candidate key -> [vote count, original column values]
        for preddf in predictions_list:
            # Look up by Id label: the frames are indexed by 'Id', so a
            # positional .iloc lookup would pick the wrong problem.
            row = preddf.loc[idx]
            candidate = tuple(row[col] for col in columns)
            # Token lists are unhashable, so use tuples in the dictionary key.
            key = (tuple(candidate[0]), tuple(candidate[1]), candidate[2])
            if key in votes:
                votes[key][0] += 1
            else:
                votes[key] = [1, candidate]
        # max() returns the first maximal entry, so ties go to the earliest model.
        _, best = max(votes.values(), key=lambda entry: entry[0])
        for col, value in zip(columns, best):
            winners[col].append(value)
    # Assign whole columns at once; writes to rows from iterrows() would only
    # modify copies and never reach the DataFrame.
    for col in columns:
        merged[col] = pd.Series(winners[col], index=merged.index, dtype=merged[col].dtype)
    return merged


predictions_list = []
for modelfp in modelfps:
    use_t5 = "base"
    device = 'cuda:0' #'cpu'
    # Module-level names below are set as globals; presumably Model() (defined
    # elsewhere) reads them — verify before renaming.
    n_max_in = 100
    if use_t5:
        freeze_layers = []
        n_hid = dict(small=512, base=768)[use_t5]
    else:
        n_layers = hyp_n_layers #3
        n_hid = hyp_n_hid #512
        n_k = n_v = hyp_kv #64
        n_head = hyp_n_head #8

    test_data, in_vocab, out_vocab, n_max_nP, t5_model = setup(use_t5, do_eval=True)
    model = Model().to(device)
    # map_location lets a GPU-saved checkpoint be restored when device is 'cpu'.
    model.load_state_dict(torch.load(modelfp, map_location=device))
    model.eval()
    tensorize_data(test_data)

    with torch.no_grad():
        for d in tqdm(test_data):
            pred = predict(d, model)
            d['pred_tokens'] = pred_tokens = [out_vocab.idx2token[idx] for idx in pred]
            d['subbed_tokens'] = subbed_tokens = sub_nP(pred_tokens, d['nP'])
            d['Predicted'] = round(evaluate_prefix_expression(subbed_tokens), 3)

    predictions = pd.DataFrame(test_data).set_index('Id')
    predictions_list.append(predictions)

print(predictions_list)

if predictions_list:
    prediction_merged = majority_vote_predictions(predictions_list)
    print(prediction_merged)
else:
    print('No checkpoints listed in modelfps; nothing to merge.')
