In [1]:
%load_ext autoreload
%autoreload 2

In [32]:
!pip install nltk

Collecting nltk
  Downloading nltk-3.8.1-py3-none-any.whl (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m41.5 MB/s[0m eta [36m0:00:00[0m00:01[0m
Collecting regex>=2021.8.3
  Downloading regex-2022.10.31-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (769 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m770.0/770.0 kB[0m [31m82.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: regex, nltk
Successfully installed nltk-3.8.1 regex-2022.10.31


In [50]:
import wandb
import torch
from torch import nn, Tensor, optim
import numpy as np
import pickle
from typing import Optional

from src.datapipe import WikiDataset
from src.utils.common import PAD
from src.modules.graph_encoder import GraphEncoder
from src.modules.seq_decoder import DecoderRNN
from src.modules.graph_seq import GraphSeq
from torch_geometric.loader import DataLoader

In [5]:
wandb.login(key="fd8e6949c75375b623a566795f8460842fee1e14")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ando_cavallari/.netrc


True

In [54]:
from nltk.translate import bleu_score
from sklearn.metrics import accuracy_score


def compute_correct(
    logits: torch.Tensor,
    labels: torch.Tensor,
    pad_idx: int,
) -> tuple[int, int]:
    mask = (labels != pad_idx)
    preds = logits.softmax(-1).argmax(-1).view(-1)
    correct = ((preds == labels) * mask).sum().item()
    total = mask.sum().item()
    return correct, total


def train(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    steps_per_epoch: int,
    scheduler: Optional[optim.lr_scheduler.LambdaLR] = None,
    pad_idx: int = 0,
    device: str = "cuda",
    gradient_accumulation_steps: int = 1,
    max_grad_norm: float = 20.,
) -> dict[str, float]:
    # setup
    model = model.train()
    optimizer.zero_grad()
    # got loss
    loss_fn = torch.nn.CrossEntropyLoss(
        reduction="none",
        ignore_index=pad_idx,
    )
    loss_fn = loss_fn.to(device)

    # metrics
    total_loss = 0
    n_pred_total = 0
    n_pred_correct = 0
    steps = 0

    data_iter = iter(dataloader)

    while (steps / gradient_accumulation_steps) < steps_per_epoch:
        try:
            batch_data = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            batch_data = next(data_iter)

        batch_data = batch_data.to(device)

        with torch.set_grad_enabled(True):
            trg_logits = model(
                batch_data.x,
                batch_data.src_seq,
                batch_data.edge_index,
                batch_data.bw_edge_index,
                batch_data.batch,
            )

            trg_lable_t = batch_data.trg_seq.view(-1)
            loss_t = loss_fn(
                trg_logits.view(-1, model.vocab_size),
                trg_lable_t,
            )

            loss_t = loss_t.mean(-1)

            # accumulate the gradients
            if gradient_accumulation_steps > 1:
                # scale the loss if gradient accumulation is used
                loss_t = loss_t / gradient_accumulation_steps

            loss_t.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                max_grad_norm,
            )

            if steps % gradient_accumulation_steps == 0:
                # apply the accumulated gradients
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()
                optimizer.zero_grad()

        # update metrics
        steps += 1
        correct, total = compute_correct(
            logits=trg_logits,
            labels=trg_lable_t,
            pad_idx=pad_idx,
        )

        total_loss += loss_t.item()
        n_pred_total += total
        n_pred_correct += correct

        # clrea GPU memory
        if steps % 50 == 0:
            torch.cuda.empty_cache()
            print(f"batch : {steps}")

    steps /= gradient_accumulation_steps
    total_loss = total_loss / steps
    accuracy = n_pred_correct / n_pred_total

    return dict(
        train_loss=total_loss,
        train_accuracy=accuracy,
    )


def eval(
    model: nn.Module,
    dataloader: DataLoader,
    pad_idx: int = 0,
    device: str = "cuda",
) -> dict[str, float]:
    # setup
    model = model.eval()

    # got loss
    loss_fn = torch.nn.CrossEntropyLoss(
        reduction="none",
        ignore_index=pad_idx,
    )
    loss_fn = loss_fn.to(device)

    # metrics
    total_loss = 0
    steps = 0
    preds = []
    labels = []

    data_iter = iter(dataloader)

    for batch_data in data_iter:
        batch_data = batch_data.to(device)

        with torch.set_grad_enabled(False):
            trg_logits = model(
                batch_data.x,
                batch_data.src_seq,
                batch_data.edge_index,
                batch_data.bw_edge_index,
                batch_data.batch,
            )

            trg_lable_t = batch_data.trg_seq.view(-1)
            loss_t = loss_fn(
                trg_logits.view(-1, model.vocab_size),
                trg_lable_t,
            )

            loss_t = loss_t.mean(-1)

        # update metrics
        steps += 1
        total_loss += loss_t.item()

        # update predictions
        preds_t = trg_logits.softmax(-1).argmax(-1).detach_().cpu().numpy()
        labels_t = batch_data.trg_seq.detach_().cpu().numpy()

        for pred_i, label_i in zip(preds_t, labels_t):
            # this is wrong, but will give a reference of the predictions
            preds.append(pred_i[label_i != pad_idx].tolist())
            labels.append(label_i[label_i != pad_idx].tolist())

        # clrea GPU memory
        if steps % 50 == 0:
            torch.cuda.empty_cache()
            print(f"eval batch : {steps}")

    total_loss = total_loss / steps
    accuracy = accuracy_score(
        np.concatenate(labels),
        np.concatenate(preds),
    )
    blue_score = np.array([
        bleu_score.sentence_bleu([label], pred)
        for pred, label in zip(preds, labels)
    ]).mean()

    return dict(
        eval_loss=total_loss,
        eval_accuracy=accuracy,
        eval_blue_score=blue_score,
    )


# Experiment 1
*Goal* overfit a single batch to verify code correctnes 

In [25]:
DATASET_PATH = "data/wiki"
DATASET_NAME = "dev"
VOCAB_PATH = "data/wiki/entity_2_id.bin"
BATCH_SIZE = 5
SHUFFLE = False
LR = 0.001
DEVICE = "cuda"
ACCUMULATION_STEPS = 1
MAX_GRAD_NORM = 20.
STEPS_PER_EPOCH = 4
EPOCHS = 500
PAD_IDX = 0

EMB_DIM = 6
GRAPH_CONV_LAYERS = 1
RNN_LAYERS = 1
RNN_DROPOUT = 0.

In [26]:
project = "astrazeneca"
experiment_name = "single-batch"

trials = [
    # trial setup
    dict(
        job_type="train",
        project=project,
        group=experiment_name,
        notes="test training pipeline with a single batch on simple model",
        config=dict(
            dataset_base_path=DATASET_PATH,
            dataset_name=DATASET_NAME,
            vocab_path=VOCAB_PATH,
            batch_size=BATCH_SIZE,
            shuffle=SHUFFLE,
            learning_rate=LR,
            device=DEVICE,
            accumulation_steps=ACCUMULATION_STEPS,
            max_grad_norm=MAX_GRAD_NORM,
            epochs=EPOCHS,
            steps_per_epoch=STEPS_PER_EPOCH,
            pad_idx=PAD_IDX,
            emb_dim=EMB_DIM,
            graph_conv_layers=GRAPH_CONV_LAYERS,
            rnn_layers=RNN_LAYERS,
            rnn_dropout=RNN_DROPOUT,
        ),
    )
]

In [27]:
for trial in trials:
    with wandb.init(**trial) as exp:
        dev_dataset = WikiDataset(
            DATASET_PATH,
            DATASET_NAME,
            VOCAB_PATH,
        )

        dev_dl = DataLoader(
            dev_dataset,
            batch_size=BATCH_SIZE,
            shuffle=SHUFFLE,
        )

        # as it is a single batch experiment
        batch_data = next(iter(dev_dl))
        dev_dl = [batch_data] * STEPS_PER_EPOCH

        # create model
        model = GraphSeq(
            emb_dim=EMB_DIM,
            vocab_size=len(dev_dataset.entity_2_id.data),
            pad_idx=PAD_IDX,
            graph_conv_layers=GRAPH_CONV_LAYERS,
            rnn_decoder_layers=RNN_LAYERS,
            rnn_dropout=RNN_DROPOUT,
        )

        # create optimizer
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=LR,
        )

        # setup for training
        optimizer.zero_grad()
        model.train()
        model = model.to(DEVICE)

        for epoch in range(EPOCHS):

            metrics = train(
                model=model,
                dataloader=dev_dl,
                optimizer=optimizer,
                steps_per_epoch=STEPS_PER_EPOCH,
                device=DEVICE,
                gradient_accumulation_steps=ACCUMULATION_STEPS,
                pad_idx=PAD_IDX,
                max_grad_norm=MAX_GRAD_NORM,
            )

            print("epoch:{epoch}\tacc:{acc} \t loss:{loss}".format(
                epoch=epoch,
                acc=metrics["train_accuracy"],
                loss=metrics["train_loss"],
            ))
            exp.log(metrics, step=epoch)

[34m[1mwandb[0m: Currently logged in as: [33mandompesta[0m. Use [1m`wandb login --relogin`[0m to force relogin


epoch:0	acc:0.14102564102564102 	 loss:4.962750554084778
epoch:1	acc:0.17307692307692307 	 loss:4.958779573440552
epoch:2	acc:0.10897435897435898 	 loss:4.954402565956116
epoch:3	acc:0.07692307692307693 	 loss:4.9494359493255615
epoch:4	acc:0.07692307692307693 	 loss:4.943694114685059
epoch:5	acc:0.07692307692307693 	 loss:4.936980724334717
epoch:6	acc:0.07692307692307693 	 loss:4.929075241088867
epoch:7	acc:0.07692307692307693 	 loss:4.919732332229614
epoch:8	acc:0.07692307692307693 	 loss:4.908666014671326
epoch:9	acc:0.07692307692307693 	 loss:4.8955559730529785
epoch:10	acc:0.07692307692307693 	 loss:4.880047798156738
epoch:11	acc:0.07692307692307693 	 loss:4.861753702163696
epoch:12	acc:0.07692307692307693 	 loss:4.840230107307434
epoch:13	acc:0.07692307692307693 	 loss:4.815027475357056
epoch:14	acc:0.07692307692307693 	 loss:4.785659074783325
epoch:15	acc:0.07692307692307693 	 loss:4.751682639122009
epoch:16	acc:0.07692307692307693 	 loss:4.712716579437256
epoch:17	acc:0.0769230

0,1
train_accuracy,▁▁▁▁▁▂▁▁▁▂▂▂▃▃▃▄▅▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇███████
train_loss,██▇▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_accuracy,1.0
train_loss,0.04951


# Experiment 2
As we are able to overfit a single batch, we can move to the next step

*Goal*: train and eval on full dataset with simple model 

In [52]:
experiment_name = "graph-seq train and eval"

trials = [
    # trial setup
    dict(
        job_type="train",
        project=project,
        group=experiment_name,
        notes=
        "test training and validation pipeline on the entire dataset with a simple model",
        config=dict(
            dataset_base_path="data/wiki",
            train_dataset_name="train",
            dev_dataset_name="dev",
            vocab_path="data/wiki/entity_2_id.bin",
            batch_size=64,
            learning_rate=0.001,
            device="cuda",
            accumulation_steps=1,
            max_grad_norm=20.,
            epochs=10,
            pad_idx=0,
            emb_dim=60,
            graph_conv_layers=3,
            rnn_layers=1,
            rnn_dropout=0.5,
        ),
    )
]

In [55]:
for trial in trials:
    with wandb.init(**trial) as exp:
        train_dataset = WikiDataset(
            exp.config["dataset_base_path"],
            exp.config["train_dataset_name"],
            exp.config["vocab_path"],
        )
        train_dl = DataLoader(
            train_dataset,
            batch_size=exp.config["batch_size"],
            shuffle=True,
        )
        exp.config["steps_per_epoch"] = len(train_dl)

        dev_dataset = WikiDataset(
            exp.config["dataset_base_path"],
            exp.config["dev_dataset_name"],
            exp.config["vocab_path"],
        )
        dev_dl = DataLoader(
            dev_dataset,
            batch_size=exp.config["batch_size"],
            shuffle=False,
        )

        # create model
        model = GraphSeq(
            emb_dim=exp.config["emb_dim"],
            vocab_size=len(dev_dataset.entity_2_id.data),
            pad_idx=exp.config["pad_idx"],
            graph_conv_layers=exp.config["graph_conv_layers"],
            rnn_decoder_layers=exp.config["rnn_layers"],
            rnn_dropout=exp.config["rnn_dropout"],
        )

        # create optimizer
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=exp.config["learning_rate"],
        )

        device = exp.config["device"]
        # setup for training
        optimizer.zero_grad()
        model.train()
        model = model.to(device)

        for epoch in range(exp.config["epochs"]):

            metrics = train(
                model=model,
                dataloader=dev_dl,
                optimizer=optimizer,
                steps_per_epoch=len(train_dl),
                device=device,
                gradient_accumulation_steps=exp.config["accumulation_steps"],
                pad_idx=exp.config["pad_idx"],
                max_grad_norm=exp.config["max_grad_norm"],
            )

            print("epoch:{epoch}\tacc:{acc} \t loss:{loss}".format(
                epoch=epoch,
                acc=metrics["train_accuracy"],
                loss=metrics["train_loss"],
            ))
            exp.log(metrics, step=epoch)

            if epoch % 1 == 0:
                # eval every 1 epochs
                is_best = False
                scores = eval(
                    model=model,
                    dataloader=dev_dl,
                    device=device,
                )

                print(epoch, scores)
                print()
                exp.log(scores, step=epoch)




batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:0	acc:0.5454857950386077 	 loss:1.539299227728611
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


0 {'eval_loss': 1.034965844316916, 'eval_accuracy': 0.6678873916547657, 'eval_blue_score': 0.15293316888491967}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:1	acc:0.707031291443073 	 loss:0.8456747814691568
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


1 {'eval_loss': 0.7425178856109128, 'eval_accuracy': 0.7442981773963209, 'eval_blue_score': 0.3130078731901078}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:2	acc:0.7852906664629656 	 loss:0.5814311626313358
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


2 {'eval_loss': 0.46895650616197876, 'eval_accuracy': 0.8221370699772351, 'eval_blue_score': 0.49503647616001184}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:3	acc:0.8574517430227105 	 loss:0.3713625632945309
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


3 {'eval_loss': 0.31700430234724825, 'eval_accuracy': 0.8869816042871485, 'eval_blue_score': 0.6625987538994469}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:4	acc:0.9210361590480374 	 loss:0.2205903691902469
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


4 {'eval_loss': 0.23028418664453607, 'eval_accuracy': 0.9112876998996083, 'eval_blue_score': 0.7245261139390683}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:5	acc:0.9675500075326929 	 loss:0.12293462592698252
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


5 {'eval_loss': 0.09515526396871517, 'eval_accuracy': 0.9709429747041274, 'eval_blue_score': 0.9055122469083374}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:6	acc:0.9907103859921618 	 loss:0.05964791111195635
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


6 {'eval_loss': 0.06218583085997538, 'eval_accuracy': 0.9823961087623545, 'eval_blue_score': 0.9372030192327453}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:7	acc:0.9975025409576918 	 loss:0.028717726043020624
eval batch : 50
eval batch : 100


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


7 {'eval_loss': 0.03615186017740405, 'eval_accuracy': 0.990059810811193, 'eval_blue_score': 0.9657315452845179}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:8	acc:0.9991342707822755 	 loss:0.01386546492116752
eval batch : 50
eval batch : 100
8 {'eval_loss': 0.011347157714860232, 'eval_accuracy': 0.9982749600554275, 'eval_blue_score': 0.9935185956970694}

batch : 50
batch : 100
batch : 150
batch : 200
batch : 250
batch : 300
batch : 350
batch : 400
batch : 450
batch : 500
batch : 550
batch : 600
batch : 650
batch : 700
batch : 750
batch : 800
batch : 850
epoch:9	acc:0.999522575799049 	 loss:0.006710440323320212
eval batch : 50
eval batch : 100
9 {'eval_loss': 0.005654219320135642, 'eval_accuracy': 0.9994202734612502, 'eval_blue_score': 0.997852666221114}



The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


0,1
eval_accuracy,▁▃▄▆▆▇████
eval_blue_score,▁▂▄▅▆▇▇███
eval_loss,█▆▄▃▃▂▁▁▁▁
train_accuracy,▁▃▅▆▇█████
train_loss,█▅▄▃▂▂▁▁▁▁

0,1
eval_accuracy,0.99942
eval_blue_score,0.99785
eval_loss,0.00565
train_accuracy,0.99952
train_loss,0.00671
