In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import wandb
import torch
import numpy as np
import pickle

from torch import nn, Tensor, optim
from typing import Optional
from torch_geometric.loader import DataLoader


from src.datapipe import WikiDataset
from src.utils.common import PAD
from src.utils.training import train_fn, eval_fn
from src.modules.graph_encoder import GraphEncoder
from src.modules.seq_decoder import DecoderRNN
from src.modules.graph_seq import GraphSeq

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
wandb.login(key="fd8e6949c75375b623a566795f8460842fee1e14")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mandompesta[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ando_cavallari/.netrc


True

# Experiment 1
*Goal* overfit a single batch to verify code correctnes 

In [15]:
project = "astrazeneca"
experiment_name = "single-batch"

trials = [
    # trial setup
    dict(
        job_type="train",
        project=project,
        group=experiment_name,
        notes="test training pipeline with a single batch on simple model",
        config=dict(
            dataset_base_path="data/wiki",
            dataset_name="dev",
            vocab_path="data/wiki/entity_2_id.bin",
            batch_size=5,
            learning_rate=0.003,
            device="cuda",
            accumulation_steps=1,
            max_grad_norm=20.,
            epochs=500,
            steps_per_epoch=5,
            pad_idx=0,
            emb_dim=6,
            graph_conv_layers=1,
            rnn_decoder_layers=1,
            rnn_dropout=0.,
        ),
    )
]

In [17]:
for trial in trials:
    with wandb.init(**trial) as exp:
        dev_dataset = WikiDataset(
            exp.config.dataset_base_path,
            exp.config.dataset_name,
            exp.config.vocab_path,
        )
        exp.config["vocab_size"] = len(dev_dataset.entity_2_id.data)

        dev_dl = DataLoader(
            dev_dataset,
            batch_size=exp.config.batch_size,
            shuffle=False,
        )

        # as it is a single batch experiment
        batch_data = next(iter(dev_dl))
        dev_dl = [batch_data] * exp.config.steps_per_epoch

        # create model
        model = GraphSeq(
            emb_dim=exp.config.emb_dim,
            vocab_size=exp.config.vocab_size,
            pad_idx=exp.config.pad_idx,
            graph_conv_layers=exp.config.graph_conv_layers,
            rnn_decoder_layers=exp.config.rnn_decoder_layers,
            rnn_dropout=exp.config.rnn_dropout,
        )

        # create optimizer
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=exp.config.learning_rate,
        )

        # setup for training
        device = exp.config.device
        optimizer.zero_grad()
        model.train()
        model = model.to(device)

        for epoch in range(exp.config.epochs):

            metrics = train_fn(
                model=model,
                dataloader=dev_dl,
                optimizer=optimizer,
                steps_per_epoch=exp.config.steps_per_epoch,
                device=device,
                gradient_accumulation_steps=exp.config.accumulation_steps,
                pad_idx=exp.config.pad_idx,
                max_grad_norm=exp.config.max_grad_norm,
            )

            print("epoch:{epoch}\tacc:{acc} \t loss:{loss}".format(
                epoch=epoch,
                acc=metrics["train_accuracy"],
                loss=metrics["train_loss"],
            ))
            exp.log(metrics, step=epoch)

epoch:0	acc:0.12941176470588237 	 loss:4.322681713104248
epoch:1	acc:0.17647058823529413 	 loss:4.307934951782227
epoch:2	acc:0.17647058823529413 	 loss:4.28772554397583
epoch:3	acc:0.17647058823529413 	 loss:4.258470153808593
epoch:4	acc:0.14705882352941177 	 loss:4.214743137359619
epoch:5	acc:0.14705882352941177 	 loss:4.1482008934021
epoch:6	acc:0.17058823529411765 	 loss:4.04694504737854
epoch:7	acc:0.17647058823529413 	 loss:3.8974583625793455
epoch:8	acc:0.13529411764705881 	 loss:3.6893190860748293
epoch:9	acc:0.1588235294117647 	 loss:3.415463399887085
epoch:10	acc:0.10588235294117647 	 loss:3.066850519180298
epoch:11	acc:0.07647058823529412 	 loss:2.6263970375061034
epoch:12	acc:0.03529411764705882 	 loss:2.0766141414642334
epoch:13	acc:0.029411764705882353 	 loss:1.4917224884033202
epoch:14	acc:0.2411764705882353 	 loss:1.1506112813949585
epoch:15	acc:0.29411764705882354 	 loss:1.0390278100967407
epoch:16	acc:0.38823529411764707 	 loss:1.0054585456848144
epoch:17	acc:0.3 	 lo

0,1
train_accuracy,▁▂▃▃▄▅▆▆▇▇▇▇███████▇▇▇█▇█████▇█████████▇
train_loss,█▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_accuracy,0.87647
train_loss,0.11233


In [18]:
logits = model(
    batch_data.x,
    batch_data.src_seq,
    batch_data.edge_index,
    batch_data.bw_edge_index,
    batch_data.batch,
)

preds = logits.softmax(-1).argmax(-1)
preds, batch_data.trg_seq, batch_data.src_seq

(tensor([[    4,    40,     6,    34,     3,     2,     2,     2,     2,     2,
              2,     2,     2,     2,     2,     2,     2],
         [    4,    17,    34,     6,    35,   227,     2,     2,     2,     2,
              2,     2,     2,     2,     2,     2,     2],
         [    4,    34,     6,    35,   681,     2,     2,     2,     2,     2,
              2,     2,     2,     2,     2,     2,     2],
         [    4,    42,     6,    35,   227,     2,     2,     2,     2,     2,
              2,     2,     2,     2,     2,     2,     2],
         [    4,    42,     6,     6,   681,     5,     5, 14416,     2,     2,
              2,     2,     2,     2,     2,     2,     2]], device='cuda:0'),
 tensor([[    4,    40,     6,    34,     3,     2,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0],
         [    4,    17,    34,     6,    35,   227,     2,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0],
 

# Experiment 2
As we are able to overfit a single batch, we can move to the next step

*Goal*: train and eval on full dataset with simple model 

In [19]:
experiment_name = "graph-seq train and eval"

trials = [
    # trial setup
    dict(
        job_type="train",
        project=project,
        group=experiment_name,
        notes=
        "test training and validation pipeline on the entire dataset with a simple model",
        config=dict(
            dataset_base_path="data/wiki",
            train_dataset_name="train",
            dev_dataset_name="dev",
            vocab_path="data/wiki/entity_2_id.bin",
            batch_size=64,
            learning_rate=0.003,
            device="cuda",
            accumulation_steps=1,
            max_grad_norm=20.,
            epochs=10,
            pad_idx=0,
            emb_dim=60,
            graph_conv_layers=3,
            rnn_layers=2,
            rnn_dropout=0.5,
        ),
    )
]

In [20]:
for trial in trials:
    with wandb.init(**trial) as exp:
        train_dataset = WikiDataset(
            exp.config.dataset_base_path,
            exp.config.train_dataset_name,
            exp.config.vocab_path,
        )
        train_dl = DataLoader(
            train_dataset,
            batch_size=exp.config.batch_size,
            shuffle=True,
        )
        exp.config.steps_per_epoch = len(train_dl)
        exp.config.vocab_size = len(dev_dataset.entity_2_id.data)

        dev_dataset = WikiDataset(
            exp.config.dataset_base_path,
            exp.config.dev_dataset_name,
            exp.config.vocab_path,
        )
        dev_dl = DataLoader(
            dev_dataset,
            batch_size=exp.config.batch_size,
            shuffle=False,
        )

        # create model
        model = GraphSeq(
            emb_dim=exp.config.emb_dim,
            vocab_size=exp.config.vocab_size,
            pad_idx=exp.config.pad_idx,
            graph_conv_layers=exp.config.graph_conv_layers,
            rnn_decoder_layers=exp.config.rnn_layers,
            rnn_dropout=exp.config.rnn_dropout,
        )

        # create optimizer
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=exp.config.learning_rate,
        )

        device = exp.config.device
        # setup for training
        optimizer.zero_grad()
        model.train()
        model = model.to(device)

        for epoch in range(exp.config.epochs):

            metrics = train_fn(
                model=model,
                dataloader=dev_dl,
                optimizer=optimizer,
                steps_per_epoch=exp.config.steps_per_epoch,
                device=device,
                gradient_accumulation_steps=exp.config.accumulation_steps,
                pad_idx=exp.config.pad_idx,
                max_grad_norm=exp.config.max_grad_norm,
            )

            print("epoch:{epoch}\tacc:{acc} \t loss:{loss}".format(
                epoch=epoch,
                acc=metrics["train_accuracy"],
                loss=metrics["train_loss"],
            ))
            exp.log(metrics, step=epoch)

            if epoch % 1 == 0:
                # eval every 1 epochs
                is_best = False
                scores = eval_fn(
                    model=model,
                    dataloader=dev_dl,
                    device=device,
                )

                print(epoch, scores)
                print()
                exp.log(scores, step=epoch)


Processing...
Done!


batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:0	acc:0.46205942798218075 	 loss:1.4920014627404705
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


0 {'eval_loss': 0.9521422697739168, 'eval_accuracy': 0.5877499919745754, 'eval_blue_score': 0.10205085511420839}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:1	acc:0.5806792801952503 	 loss:0.9575620880191903
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


1 {'eval_loss': 0.7148003117604689, 'eval_accuracy': 0.6526435748451093, 'eval_blue_score': 0.1839454358009507}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:2	acc:0.6296894161524803 	 loss:0.7839474743327272
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


2 {'eval_loss': 0.5599505885532408, 'eval_accuracy': 0.7111810214760361, 'eval_blue_score': 0.2717121205156288}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:3	acc:0.6715583642728589 	 loss:0.6657895584517792
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


3 {'eval_loss': 0.4461886298024293, 'eval_accuracy': 0.7669416712144073, 'eval_blue_score': 0.38070674969269325}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:4	acc:0.703831040074014 	 loss:0.5802082945418277
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


4 {'eval_loss': 0.3654428055566369, 'eval_accuracy': 0.8064588616737826, 'eval_blue_score': 0.4638796137146781}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:5	acc:0.7286444994301987 	 loss:0.5173888485223291
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


5 {'eval_loss': 0.30989974037264334, 'eval_accuracy': 0.8331835254084942, 'eval_blue_score': 0.5333742155627581}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:6	acc:0.7467239439402299 	 loss:0.4691592994633381
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


6 {'eval_loss': 0.25836929357187316, 'eval_accuracy': 0.858688324612372, 'eval_blue_score': 0.6103597191551461}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:7	acc:0.7627602955738609 	 loss:0.42965020543803695
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


7 {'eval_loss': 0.22046721980653025, 'eval_accuracy': 0.8805335302237488, 'eval_blue_score': 0.6779827867448665}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:8	acc:0.7749248898344083 	 loss:0.4011964034141125
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


8 {'eval_loss': 0.19530869359997186, 'eval_accuracy': 0.8921222432666688, 'eval_blue_score': 0.7122706705837628}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:9	acc:0.7855499365147438 	 loss:0.37670996739496304
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


9 {'eval_loss': 0.17914718208890973, 'eval_accuracy': 0.9009020577188533, 'eval_blue_score': 0.7389551826714676}



0,1
eval_accuracy,▁▂▄▅▆▆▇███
eval_blue_score,▁▂▃▄▅▆▇▇██
eval_loss,█▆▄▃▃▂▂▁▁▁
train_accuracy,▁▄▅▆▆▇▇███
train_loss,█▅▄▃▂▂▂▁▁▁

0,1
eval_accuracy,0.9009
eval_blue_score,0.73896
eval_loss,0.17915
train_accuracy,0.78555
train_loss,0.37671
