In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import wandb
import torch
import numpy as np
import pickle

from pathlib import Path
from torch import nn, Tensor, optim
from typing import Optional
from torch_geometric.loader import DataLoader


from src.datapipe import WikiDataset
from src.utils.common import PAD, save_checkpoint
from src.utils.training import train_fn, eval_fn
from src.models.graph_seq import GraphSeq, GraphSeqAttn
from src.optim import (
    get_optimizer,
    get_group_params,
    get_linear_scheduler_with_warmup,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
wandb.login(key="fd8e6949c75375b623a566795f8460842fee1e14")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mandompesta[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ando_cavallari/.netrc


True

# Experiment 1
*Goal* overfit a single batch to verify code correctnes 

In [3]:
project = "astrazeneca"
experiment_name = "single-batch"

trials = [
    # trial setup
    dict(
        job_type="train",
        project=project,
        group=experiment_name,
        notes="test training pipeline with a single batch on simple model",
        config=dict(
            dataset_base_path="data/wiki",
            dataset_name="dev",
            vocab_path="data/wiki/entity_2_id.bin",
            batch_size=5,
            learning_rate=0.003,
            device="cuda",
            accumulation_steps=1,
            max_grad_norm=20.,
            epochs=500,
            steps_per_epoch=5,
            pad_idx=0,
            emb_dim=6,
            graph_conv_layers=1,
            rnn_decoder_layers=1,
            rnn_dropout=0.,
        ),
    )
]

In [4]:
for trial in trials:
    with wandb.init(**trial) as exp:
        dev_dataset = WikiDataset(
            exp.config.dataset_base_path,
            exp.config.dataset_name,
            exp.config.vocab_path,
        )
        exp.config["vocab_size"] = len(dev_dataset.entity_2_id.data)

        dev_dl = DataLoader(
            dev_dataset,
            batch_size=exp.config.batch_size,
            shuffle=False,
        )

        # as it is a single batch experiment
        batch_data = next(iter(dev_dl))
        dev_dl = [batch_data] * exp.config.steps_per_epoch

        # create model
        model = GraphSeq(
            emb_dim=exp.config.emb_dim,
            vocab_size=exp.config.vocab_size,
            pad_idx=exp.config.pad_idx,
            graph_conv_layers=exp.config.graph_conv_layers,
            rnn_decoder_layers=exp.config.rnn_decoder_layers,
            rnn_dropout=exp.config.rnn_dropout,
        )

        # create optimizer
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=exp.config.learning_rate,
        )

        # setup for training
        device = exp.config.device
        optimizer.zero_grad()
        model.train()
        model = model.to(device)

        for epoch in range(exp.config.epochs):

            metrics = train_fn(
                model=model,
                dataloader=dev_dl,
                optimizer=optimizer,
                steps_per_epoch=exp.config.steps_per_epoch,
                device=device,
                gradient_accumulation_steps=exp.config.accumulation_steps,
                pad_idx=exp.config.pad_idx,
                max_grad_norm=exp.config.max_grad_norm,
            )

            print("epoch:{epoch}\tacc:{acc} \t loss:{loss}".format(
                epoch=epoch,
                acc=metrics["train_accuracy"],
                loss=metrics["train_loss"],
            ))
            exp.log(metrics, step=epoch)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mandompesta[0m. Use [1m`wandb login --relogin`[0m to force relogin


epoch:0	acc:0.0 	 loss:4.331194305419922
epoch:1	acc:0.0 	 loss:4.308772277832031
epoch:2	acc:0.0 	 loss:4.278983402252197
epoch:3	acc:0.07647058823529412 	 loss:4.238000202178955
epoch:4	acc:0.15294117647058825 	 loss:4.180426597595215
epoch:5	acc:0.2647058823529412 	 loss:4.098905754089356
epoch:6	acc:0.16470588235294117 	 loss:3.985474729537964
epoch:7	acc:0.14705882352941177 	 loss:3.8361175537109373
epoch:8	acc:0.14705882352941177 	 loss:3.6525842189788817
epoch:9	acc:0.11176470588235295 	 loss:3.4361230373382567
epoch:10	acc:0.029411764705882353 	 loss:3.178119659423828
epoch:11	acc:0.029411764705882353 	 loss:2.8574286460876466
epoch:12	acc:0.07647058823529412 	 loss:2.445392942428589
epoch:13	acc:0.08823529411764706 	 loss:1.9586676359176636
epoch:14	acc:0.08823529411764706 	 loss:1.5992687702178956
epoch:15	acc:0.08823529411764706 	 loss:1.447387194633484
epoch:16	acc:0.08823529411764706 	 loss:1.3276278495788574
epoch:17	acc:0.14705882352941177 	 loss:1.2392212390899657
epoch

0,1
train_accuracy,▁▁▃▃▄▆▇▇▇▇▇█▆████▇▇█▇██████▇████████████
train_loss,█▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_accuracy,0.91176
train_loss,0.06694


In [18]:
# visualize predictions
logits = model(
    batch_data.x,
    batch_data.src_seq,
    batch_data.edge_index,
    batch_data.bw_edge_index,
    batch_data.batch,
)

preds = logits.softmax(-1).argmax(-1)
preds, batch_data.trg_seq, batch_data.src_seq

(tensor([[    4,    40,     6,    34,     3,     2,     2,     2,     2,     2,
              2,     2,     2,     2,     2,     2,     2],
         [    4,    17,    34,     6,    35,   227,     2,     2,     2,     2,
              2,     2,     2,     2,     2,     2,     2],
         [    4,    34,     6,    35,   681,     2,     2,     2,     2,     2,
              2,     2,     2,     2,     2,     2,     2],
         [    4,    42,     6,    35,   227,     2,     2,     2,     2,     2,
              2,     2,     2,     2,     2,     2,     2],
         [    4,    42,     6,     6,   681,     5,     5, 14416,     2,     2,
              2,     2,     2,     2,     2,     2,     2]], device='cuda:0'),
 tensor([[    4,    40,     6,    34,     3,     2,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0],
         [    4,    17,    34,     6,    35,   227,     2,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0],
 

# Experiment 2
As we are able to overfit a single batch, we can move to the next step

*Goal*: train and eval on full dataset with simple model 

In [19]:
experiment_name = "graph-seq train and eval"

trials = [
    # trial setup
    dict(
        job_type="train",
        project=project,
        group=experiment_name,
        notes=
        "test training and validation pipeline on the entire dataset with a simple model",
        config=dict(
            dataset_base_path="data/wiki",
            train_dataset_name="train",
            dev_dataset_name="dev",
            vocab_path="data/wiki/entity_2_id.bin",
            batch_size=64,
            learning_rate=0.003,
            device="cuda",
            accumulation_steps=1,
            max_grad_norm=20.,
            epochs=10,
            pad_idx=0,
            emb_dim=60,
            graph_conv_layers=3,
            rnn_layers=2,
            rnn_dropout=0.5,
        ),
    )
]

In [20]:
for trial in trials:
    with wandb.init(**trial) as exp:
        train_dataset = WikiDataset(
            exp.config.dataset_base_path,
            exp.config.train_dataset_name,
            exp.config.vocab_path,
        )
        train_dl = DataLoader(
            train_dataset,
            batch_size=exp.config.batch_size,
            shuffle=True,
        )
        exp.config.steps_per_epoch = len(train_dl)
        exp.config.vocab_size = len(dev_dataset.entity_2_id.data)

        dev_dataset = WikiDataset(
            exp.config.dataset_base_path,
            exp.config.dev_dataset_name,
            exp.config.vocab_path,
        )
        dev_dl = DataLoader(
            dev_dataset,
            batch_size=exp.config.batch_size,
            shuffle=False,
        )

        # create model
        model = GraphSeq(
            emb_dim=exp.config.emb_dim,
            vocab_size=exp.config.vocab_size,
            pad_idx=exp.config.pad_idx,
            graph_conv_layers=exp.config.graph_conv_layers,
            rnn_decoder_layers=exp.config.rnn_layers,
            rnn_dropout=exp.config.rnn_dropout,
        )

        # create optimizer
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=exp.config.learning_rate,
        )

        device = exp.config.device
        # setup for training
        optimizer.zero_grad()
        model.train()
        model = model.to(device)

        for epoch in range(exp.config.epochs):

            metrics = train_fn(
                model=model,
                dataloader=dev_dl,
                optimizer=optimizer,
                steps_per_epoch=exp.config.steps_per_epoch,
                device=device,
                gradient_accumulation_steps=exp.config.accumulation_steps,
                pad_idx=exp.config.pad_idx,
                max_grad_norm=exp.config.max_grad_norm,
            )

            print("epoch:{epoch}\tacc:{acc} \t loss:{loss}".format(
                epoch=epoch,
                acc=metrics["train_accuracy"],
                loss=metrics["train_loss"],
            ))
            exp.log(metrics, step=epoch)

            if epoch % 1 == 0:
                # eval every 1 epochs
                is_best = False
                scores = eval_fn(
                    model=model,
                    dataloader=dev_dl,
                    device=device,
                )

                print(epoch, scores)
                print()
                exp.log(scores, step=epoch)


Processing...
Done!


batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:0	acc:0.46205942798218075 	 loss:1.4920014627404705
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


0 {'eval_loss': 0.9521422697739168, 'eval_accuracy': 0.5877499919745754, 'eval_blue_score': 0.10205085511420839}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:1	acc:0.5806792801952503 	 loss:0.9575620880191903
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


1 {'eval_loss': 0.7148003117604689, 'eval_accuracy': 0.6526435748451093, 'eval_blue_score': 0.1839454358009507}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:2	acc:0.6296894161524803 	 loss:0.7839474743327272
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


2 {'eval_loss': 0.5599505885532408, 'eval_accuracy': 0.7111810214760361, 'eval_blue_score': 0.2717121205156288}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:3	acc:0.6715583642728589 	 loss:0.6657895584517792
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


3 {'eval_loss': 0.4461886298024293, 'eval_accuracy': 0.7669416712144073, 'eval_blue_score': 0.38070674969269325}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:4	acc:0.703831040074014 	 loss:0.5802082945418277
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


4 {'eval_loss': 0.3654428055566369, 'eval_accuracy': 0.8064588616737826, 'eval_blue_score': 0.4638796137146781}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:5	acc:0.7286444994301987 	 loss:0.5173888485223291
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


5 {'eval_loss': 0.30989974037264334, 'eval_accuracy': 0.8331835254084942, 'eval_blue_score': 0.5333742155627581}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:6	acc:0.7467239439402299 	 loss:0.4691592994633381
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


6 {'eval_loss': 0.25836929357187316, 'eval_accuracy': 0.858688324612372, 'eval_blue_score': 0.6103597191551461}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:7	acc:0.7627602955738609 	 loss:0.42965020543803695
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


7 {'eval_loss': 0.22046721980653025, 'eval_accuracy': 0.8805335302237488, 'eval_blue_score': 0.6779827867448665}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:8	acc:0.7749248898344083 	 loss:0.4011964034141125
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


8 {'eval_loss': 0.19530869359997186, 'eval_accuracy': 0.8921222432666688, 'eval_blue_score': 0.7122706705837628}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:9	acc:0.7855499365147438 	 loss:0.37670996739496304
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


9 {'eval_loss': 0.17914718208890973, 'eval_accuracy': 0.9009020577188533, 'eval_blue_score': 0.7389551826714676}



0,1
eval_accuracy,▁▂▄▅▆▆▇███
eval_blue_score,▁▂▃▄▅▆▇▇██
eval_loss,█▆▄▃▃▂▂▁▁▁
train_accuracy,▁▄▅▆▆▇▇███
train_loss,█▅▄▃▂▂▂▁▁▁

0,1
eval_accuracy,0.9009
eval_blue_score,0.73896
eval_loss,0.17915
train_accuracy,0.78555
train_loss,0.37671


# Experiment 3
Eval attention component

In [4]:
project = "astrazeneca"
experiment_name = "att-single-batch"

trials = [
    # trial setup
    dict(
        job_type="train",
        project=project,
        group=experiment_name,
        notes="test training pipeline with a single batch on attention model",
        config=dict(
            dataset_base_path="data/wiki",
            dataset_name="dev",
            vocab_path="data/wiki/entity_2_id.bin",
            batch_size=5,
            learning_rate=0.003,
            device="cuda",
            accumulation_steps=1,
            max_grad_norm=20.,
            epochs=500,
            steps_per_epoch=5,
            pad_idx=0,
            emb_dim=6,
            graph_conv_layers=1,
            rnn_decoder_layers=1,
            rnn_dropout=0.,
        ),
    )
]

In [6]:
for trial in trials:
    with wandb.init(**trial) as exp:
        dev_dataset = WikiDataset(
            exp.config.dataset_base_path,
            exp.config.dataset_name,
            exp.config.vocab_path,
        )
        exp.config["vocab_size"] = len(dev_dataset.entity_2_id.data)

        dev_dl = DataLoader(
            dev_dataset,
            batch_size=exp.config.batch_size,
            shuffle=False,
        )

        # as it is a single batch experiment
        batch_data = next(iter(dev_dl))
        dev_dl = [batch_data] * exp.config.steps_per_epoch

        # create model
        model = GraphSeqAttn(
            emb_dim=exp.config.emb_dim,
            vocab_size=exp.config.vocab_size,
            pad_idx=exp.config.pad_idx,
            graph_conv_layers=exp.config.graph_conv_layers,
            rnn_decoder_layers=exp.config.rnn_decoder_layers,
            rnn_dropout=exp.config.rnn_dropout,
        )

        # create optimizer
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=exp.config.learning_rate,
        )

        # setup for training
        device = exp.config.device
        optimizer.zero_grad()
        model.train()
        model = model.to(device)

        for epoch in range(exp.config.epochs):

            metrics = train_fn(
                model=model,
                dataloader=dev_dl,
                optimizer=optimizer,
                steps_per_epoch=exp.config.steps_per_epoch,
                device=device,
                gradient_accumulation_steps=exp.config.accumulation_steps,
                pad_idx=exp.config.pad_idx,
                max_grad_norm=exp.config.max_grad_norm,
            )

            print("epoch:{epoch}\tacc:{acc} \t loss:{loss}".format(
                epoch=epoch,
                acc=metrics["train_accuracy"],
                loss=metrics["train_loss"],
            ))
            exp.log(metrics, step=epoch)

epoch:0	acc:0.0 	 loss:4.329547977447509
epoch:1	acc:0.0 	 loss:4.304247665405273
epoch:2	acc:0.0 	 loss:4.270625114440918
epoch:3	acc:0.0 	 loss:4.223822212219238
epoch:4	acc:0.0 	 loss:4.156934642791748
epoch:5	acc:0.029411764705882353 	 loss:4.062280368804932
epoch:6	acc:0.11176470588235295 	 loss:3.9368031501770018
epoch:7	acc:0.08823529411764706 	 loss:3.788045644760132
epoch:8	acc:0.08823529411764706 	 loss:3.62978138923645
epoch:9	acc:0.08823529411764706 	 loss:3.4729974269866943
epoch:10	acc:0.08823529411764706 	 loss:3.323265314102173
epoch:11	acc:0.08823529411764706 	 loss:3.1820910930633546
epoch:12	acc:0.08823529411764706 	 loss:3.0487760066986085
epoch:13	acc:0.08823529411764706 	 loss:2.9212154388427733
epoch:14	acc:0.08823529411764706 	 loss:2.7968277454376222
epoch:15	acc:0.08823529411764706 	 loss:2.6726473808288573
epoch:16	acc:0.08823529411764706 	 loss:2.545180606842041
epoch:17	acc:0.08823529411764706 	 loss:2.4110250949859617
epoch:18	acc:0.08823529411764706 	 los

0,1
train_accuracy,▁▃▄▅▇▇██████████████████████████████████
train_loss,█▅▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_accuracy,0.41176
train_loss,0.48834


In [5]:
# visualize predictions
logits = model(
    batch_data.x,
    batch_data.src_seq,
    batch_data.edge_index,
    batch_data.bw_edge_index,
    batch_data.batch,
)

preds = logits.softmax(-1).argmax(-1)
preds, batch_data.trg_seq, batch_data.src_seq

(tensor([[   4,   40,    6,   34,    3,    3,    3,    3,    3,    3,    3,    3,
             3,    3,    3,    3,    3],
         [   4,   17,   34,    2,   35,    2,    2,   35,   35,   35,   35,   35,
            35,   35,   35,   35,   35],
         [   4,   34,    2,   35,    2,    2,   35,   35,   35,   35,   35,   35,
            35,   35,   35,   35,   35],
         [   4,   42,    6,   35, 1061, 1061,   35,   35,   35,   35,   35,   35,
            35,   35,   35,   35,   35],
         [   4,   42,    6,   40,    5,    5,    5,    5,    5, 5271, 5271,    2,
             2,    2,    2,    2,    2]], device='cuda:0'),
 tensor([[    4,    40,     6,    34,     3,     2,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0],
         [    4,    17,    34,     6,    35,   227,     2,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0],
         [    4,    34,     6,    35,   681,     2,     0,     0,     0,     0,
      

In [8]:
mask = (batch_data.trg_seq != 0).view(-1)
l = batch_data.trg_seq.view(-1)
p = preds.view(-1)

p, l, mask
p[mask] == l[mask]

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True, False], device='cuda:0')

In [10]:
project = "astrazeneca"
experiment_name = "att-graph-seq train and eval"

trials = [
    # trial setup
    dict(
        job_type="train",
        project=project,
        group=experiment_name,
        notes="training and validation pipeline on the entire dataset",
        config=dict(
            dataset_base_path="data/wiki",
            train_dataset_name="train",
            dev_dataset_name="dev",
            vocab_path="data/wiki/entity_2_id.bin",
            batch_size=64,
            learning_rate=0.003,
            device="cuda",
            accumulation_steps=1,
            max_grad_norm=10.,
            epochs=10,
            pad_idx=0,
            emb_dim=60,
            graph_conv_layers=3,
            rnn_layers=2,
            rnn_dropout=0.25,
            ckp_base_path="ckps",
            optim_method="adam",
            weight_decay=0.001,
            warmup_persentage=2.5,
        ),
    )
]

In [11]:
for trial in trials:
    with wandb.init(**trial) as exp:
        train_dataset = WikiDataset(
            exp.config.dataset_base_path,
            exp.config.train_dataset_name,
            exp.config.vocab_path,
        )
        train_dl = DataLoader(
            train_dataset,
            batch_size=exp.config.batch_size,
            shuffle=True,
        )
        # scheduler parameters
        exp.config.batches_per_epoch = len(train_dl)
        exp.config.steps_per_epoch = int(
            exp.config.batches_per_epoch / exp.config.accumulation_steps
        )
        exp.config.num_warmup_steps = exp.config.steps_per_epoch * exp.config.warmup_persentage
        exp.config.num_training_steps = int(exp.config.steps_per_epoch * exp.config.epochs)

        exp.config.vocab_size = len(train_dataset.entity_2_id.data)

        dev_dataset = WikiDataset(
            exp.config.dataset_base_path,
            exp.config.dev_dataset_name,
            exp.config.vocab_path,
        )
        dev_dl = DataLoader(
            dev_dataset,
            batch_size=exp.config.batch_size,
            shuffle=False,
        )

        # create model
        model = GraphSeqAttn(
            emb_dim=exp.config.emb_dim,
            vocab_size=exp.config.vocab_size,
            pad_idx=exp.config.pad_idx,
            graph_conv_layers=exp.config.graph_conv_layers,
            rnn_decoder_layers=exp.config.rnn_layers,
            rnn_dropout=exp.config.rnn_dropout,
        )

        # setup optimizers
        named_params = list(model.named_parameters())
        group_params = get_group_params(
            named_params,
            exp.config.weight_decay,
            no_decay=["bias"],
        )
        optimizer = get_optimizer(
            method=exp.config.optim_method,
            params=group_params,
            lr=exp.config.learning_rate,
        )
        scheduler = get_linear_scheduler_with_warmup(
            optimizer,
            exp.config.num_warmup_steps,
            exp.config.num_training_steps,
        )

        device = exp.config.device
        # setup for training
        optimizer.zero_grad()
        model.train()
        model = model.to(device)

        ckp_path = Path(exp.config.ckp_base_path).joinpath(exp.name)
        ckp_path.mkdir(
            parents=True,
            exist_ok=True,
        )
        best_blue_score = float("-inf")

        for epoch in range(exp.config.epochs):

            metrics = train_fn(
                model=model,
                dataloader=dev_dl,
                optimizer=optimizer,
                steps_per_epoch=exp.config.steps_per_epoch,
                scheduler=scheduler,
                device=device,
                gradient_accumulation_steps=exp.config.accumulation_steps,
                pad_idx=exp.config.pad_idx,
                max_grad_norm=exp.config.max_grad_norm,
            )

            print("epoch:{epoch}\tacc:{acc} \t loss:{loss}".format(
                epoch=epoch,
                acc=metrics["train_accuracy"],
                loss=metrics["train_loss"],
            ))
            exp.log(metrics, step=epoch)

            if epoch % 1 == 0:
                # eval every 1 epochs
                is_best = False
                scores = eval_fn(
                    model=model,
                    dataloader=dev_dl,
                    device=device,
                )

                print(epoch, scores)
                print()
                exp.log(scores, step=epoch)

                if scores["eval_blue_score"] > best_blue_score:
                    best_blue_score = scores["eval_blue_score"]
                    is_best = True

                if isinstance(model, torch.nn.DataParallel):
                    state_dict = dict([
                        (n, p.to("cpu"))
                        for n, p in model.module.state_dict().items()
                    ])
                else:
                    state_dict = dict([
                        (n, p.to("cpu")) for n, p in model.state_dict().items()
                    ])

                save_checkpoint(
                    path_=ckp_path,
                    state=state_dict,
                    is_best=is_best,
                    filename=f"ckp_{epoch}.pth.tar",
                )


batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:0	acc:0.3558619659468459 	 loss:2.177056984381833
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


0 {'eval_loss': 1.285157428094835, 'eval_accuracy': 0.5313312574235177, 'eval_blue_score': 0.057201265253471086}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:1	acc:0.563337083822222 	 loss:1.1533790481077015
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


1 {'eval_loss': 0.9012198242725749, 'eval_accuracy': 0.618342910339957, 'eval_blue_score': 0.12149704467490718}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:2	acc:0.622430172241403 	 loss:0.8674439771627325
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


2 {'eval_loss': 0.6783349313067667, 'eval_accuracy': 0.6792879843343713, 'eval_blue_score': 0.22622212037206707}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:3	acc:0.6718643463427915 	 loss:0.67993181950659
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


3 {'eval_loss': 0.5193604655338057, 'eval_accuracy': 0.7336361593528298, 'eval_blue_score': 0.34496388573903664}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:4	acc:0.7094953223292223 	 loss:0.5654873748507592
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


4 {'eval_loss': 0.4217682887207378, 'eval_accuracy': 0.7739237905685211, 'eval_blue_score': 0.43810381629577494}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:5	acc:0.7456204810423629 	 loss:0.47774256520590636
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


5 {'eval_loss': 0.3320292842884858, 'eval_accuracy': 0.8216590157619338, 'eval_blue_score': 0.5574015385680758}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:6	acc:0.7722698328181431 	 loss:0.4176410717232138
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


6 {'eval_loss': 0.28717275319451635, 'eval_accuracy': 0.8444512214696157, 'eval_blue_score': 0.6248094364721009}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:7	acc:0.7953244975991249 	 loss:0.3697721321520551
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


7 {'eval_loss': 0.24542989017385425, 'eval_accuracy': 0.8692497833135373, 'eval_blue_score': 0.6956498168046662}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:8	acc:0.8128088431227518 	 loss:0.33710634380274546
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


8 {'eval_loss': 0.21265378230336038, 'eval_accuracy': 0.8907739719431158, 'eval_blue_score': 0.7553751494545705}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:9	acc:0.8245566271620525 	 loss:0.31656004853808783
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


9 {'eval_loss': 0.19587659991035858, 'eval_accuracy': 0.9005489390388751, 'eval_blue_score': 0.7858861646757241}



0,1
eval_accuracy,▁▃▄▅▆▇▇▇██
eval_blue_score,▁▂▃▄▅▆▆▇██
eval_loss,█▆▄▃▂▂▂▁▁▁
train_accuracy,▁▄▅▆▆▇▇███
train_loss,█▄▃▂▂▂▁▁▁▁

0,1
eval_accuracy,0.90055
eval_blue_score,0.78589
eval_loss,0.19588
train_accuracy,0.82456
train_loss,0.31656


In [12]:
test_dataset = WikiDataset(
    exp.config.dataset_base_path,
    "test",
    exp.config.vocab_path,
)
test_dl = DataLoader(
    test_dataset,
    batch_size=exp.config.batch_size,
    shuffle=False,
)


sate_dict = torch.load(ckp_path.joinpath("model_best.pth.tar"))
model = GraphSeqAttn(
    emb_dim=exp.config.emb_dim,
    vocab_size=exp.config.vocab_size,
    pad_idx=exp.config.pad_idx,
    graph_conv_layers=exp.config.graph_conv_layers,
    rnn_decoder_layers=exp.config.rnn_layers,
    rnn_dropout=exp.config.rnn_dropout,
)
model.load_state_dict(sate_dict)
model = model.to(device)

scores = eval_fn(
    model=model,
    dataloader=test_dl,
    device=device,
)

scores

eval batch : 50
eval batch : 100
eval batch : 150
eval batch : 200


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


{'eval_loss': 1.0348684235988372,
 'eval_accuracy': 0.7476343373596317,
 'eval_blue_score': 0.4024570087621595}