In [None]:
# you'll need to uncomment these 2 lines if you run the notebooks in VS Code
# import sys
# sys.path.append('/home/tincho/dev/recordar_ia/')
from src.load_data import create_data_block

import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
from typing import Iterator, Union, Iterable
from more_itertools import flatten, unique_everseen
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import numpy as np
from copy import deepcopy
from torch_geometric.loader import DataLoader

In [None]:
from pytorch_lightning.utilities.warnings import PossibleUserWarning
import warnings

# Silence every UserWarning raised from here on, plus Lightning's PossibleUserWarning.
# NOTE(review): ignoring the whole UserWarning category is very broad and can hide
# genuine problems coming from any library; consider narrowing the filters with the
# `message=` / `module=` arguments once the noisy warnings are identified.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=PossibleUserWarning)

In [None]:
# Directories handed to create_data_block. Both are relative paths, so they resolve
# against the kernel's current working directory, not against this notebook's location.
OUTPUT_DATA = "../../out_data/"
INPUT_DATA = "../../input_data/"

In [None]:
# Build the dataset. Each item is later read as a mapping with "token_boxes" and "doc_graph" keys.
# NOTE(review): debug=True is forwarded to create_data_block; its exact effect is defined
# in src.load_data (presumably a reduced or verbose run) — confirm before a full training run.
data_block = create_data_block(INPUT_DATA, OUTPUT_DATA, debug = True)

In [None]:
def split_dataset(data_block, train_size: float = 0.6, val_size: float = 0.2, test_size: float = 0.2) -> tuple:
    """
    Split the data_block into train, val and test sets, preserving the original order.

    The train set is used to fit the model, the val set to validate it during
    training and the test set to evaluate it afterwards. No shuffling is done:
    the first items go to train, the next ones to val and the remainder to test.

    Args:
        data_block: Indexable, sized collection of samples (supports len() and integer indexing).
        train_size (float, optional): Fraction of the data_block used for training. Defaults to 0.6.
        val_size (float, optional): Fraction of the data_block used for validation. Defaults to 0.2.
        test_size (float, optional): Fraction of the data_block used for testing. Defaults to 0.2.
            Only used to validate the proportions: the test set receives every
            sample left over after train and val, so no sample is dropped by rounding.

    Returns:
        tuple: Tuple of three lists containing the train, val and test sets.

    Raises:
        AssertionError: If a size is negative or the sizes do not add up to 1.
    """
    assert min(train_size, val_size, test_size) >= 0, "The sizes must be non-negative"

    # Compare with a tolerance: e.g. 0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999
    # in floating point, which an exact `== 1` check would wrongly reject.
    total_size = train_size + val_size + test_size
    assert abs(total_size - 1.0) <= 1e-9, "The sum of the sizes must be 1"

    # get the number of samples
    n_samples = len(data_block)

    # boundaries between the three consecutive sets
    n_train = int(n_samples * train_size)
    n_val = int(n_samples * val_size)
    val_end = n_train + n_val

    # get the samples for each set
    train = [data_block[i] for i in range(n_train)]
    val = [data_block[i] for i in range(n_train, val_end)]
    test = [data_block[i] for i in range(val_end, n_samples)]

    return train, val, test

# Split with the default proportions of split_dataset and report the resulting sizes.
train, val, test = split_dataset(data_block)

print(f'Train: {len(train)} \nVal: {len(val)}\nTest: {len(test)}\nData_block: {len(data_block)}')

In [None]:
# Token-box keys used as node features. The order of this list defines the order
# of the columns in every node feature vector.
NODE_FEATURES = [
    "n_line",
    "x_position",
    "y_position",
    "box_area",
]

def get_node_features(token_box: dict):
    """
    Build the flat feature vector of one token box.

    Features are read in the order given by NODE_FEATURES (not in the key order
    of `token_box`), so every token yields its features in the same positions.
    A value that is itself iterable contributes all of its elements (one level
    of flattening); any other value contributes a single element.

    Args:
        token_box (dict): Token description containing at least every key
            listed in NODE_FEATURES.

    Returns:
        list: The feature values, flattened into a single list.

    Raises:
        AssertionError: If any key of NODE_FEATURES is missing from `token_box`.
    """
    missing = [name for name in NODE_FEATURES if name not in token_box]

    assert not missing, (
        "mismatch in the number of node features "
        f"expected => {len(NODE_FEATURES)} "
        f"current  => {len(NODE_FEATURES) - len(missing)} "
        f"(missing: {missing})"
    )

    node_features = []
    for name in NODE_FEATURES:
        value = token_box[name]
        if isinstance(value, Iterable):
            node_features.extend(value)
        else:
            node_features.append(value)

    return node_features


def get_labels(datablock) -> Iterator[str]:
    """
    Lazily iterate over the label of every token box in the datablock.

    Args:
        datablock: Iterable of data items, each exposing a "token_boxes"
            collection whose elements have a "label" entry.

    Returns:
        Iterator[str]: Labels in document order, then token order; duplicates are kept.
    """

    def _labels_of(data_item):
        # one lazy stream of labels per document
        return (token_box["label"] for token_box in data_item["token_boxes"])

    return flatten(map(_labels_of, datablock))


def set_label_map(datablock):
    """
    Assign a consecutive integer index to every distinct label in the datablock.

    Indices follow first-appearance order of the labels.

    Args:
        datablock: Iterable of data items accepted by get_labels.

    Returns:
        tuple: (label_map, inv_label_map) where label_map maps label -> index
        and inv_label_map maps index -> label.
    """
    label_map = {}
    for label in unique_everseen(get_labels(datablock)):
        # labels arrive de-duplicated, so the current size is the next free index
        label_map[label] = len(label_map)

    inv_label_map = dict(zip(label_map.values(), label_map.keys()))

    return label_map, inv_label_map


def get_doc_graph(data_item, label_map) -> nx.DiGraph:
    """
    Attach features ("x") and encoded labels ("y") to the nodes of a document graph.

    Note that the graph stored in data_item["doc_graph"] is modified in place
    and that same object is returned.

    Args:
        data_item: Mapping with a "token_boxes" collection and a "doc_graph"
            whose nodes are token-box ids.
        label_map (dict): Mapping label -> integer index.

    Returns:
        nx.DiGraph: data_item["doc_graph"] with "x" and "y" set on every node.

    Raises:
        KeyError: If a token label is not in label_map, or a graph node has no
            matching token box.
    """
    # index features and encoded label by token-box id
    data_map = {}
    for token_box in data_item["token_boxes"]:
        box_id = token_box["id"]
        data_map[box_id] = {
            "node_features": get_node_features(token_box),
            "label": label_map[token_box["label"]],
        }

    doc_graph = data_item["doc_graph"]

    node_attributes = {}
    for node in doc_graph.nodes:
        node_entry = data_map[node]
        node_attributes[node] = {
            "x": node_entry["node_features"],
            "y": node_entry["label"],
        }

    nx.set_node_attributes(doc_graph, node_attributes)
    return doc_graph

def get_pg_graph(doc_graph: nx.DiGraph) -> Data:
    """
    Convert an annotated networkx graph into a PyTorch Geometric Data object.

    Args:
        doc_graph (nx.DiGraph): Graph whose nodes carry "x" and "y" attributes.

    Returns:
        Data: The converted graph, with `x` cast to float and `y` cast to long.
    """
    converted = from_networkx(doc_graph)
    # float features for the conv layers, long targets for the cross-entropy loss
    converted.x, converted.y = converted.x.float(), converted.y.long()
    return converted

def get_pg_graphs(data_block, label_map) -> list[Data]:
    """
    Turn every data item of the block into a PyTorch Geometric graph.

    Args:
        data_block: Iterable of data items accepted by get_doc_graph.
        label_map (dict): Mapping label -> integer index.

    Returns:
        list[Data]: One Data object per data item, in the same order.
    """
    # first annotate all the networkx graphs, then convert them
    annotated_graphs = []
    for data_item in data_block:
        annotated_graphs.append(get_doc_graph(data_item, label_map))

    return list(map(get_pg_graph, annotated_graphs))

In [None]:
# Build the label vocabulary from the whole dataset rather than from `train` only:
# a label that first appears in val/test would otherwise be missing from label_map
# and get_doc_graph would fail with a KeyError on those splits.
# Because indices follow first-appearance order and `train` is the leading slice of
# `data_block`, every label present in `train` keeps the index it had before.
label_map, inv_label_map = set_label_map(data_block)

pg_graph_train = get_pg_graphs(train, label_map)
pg_graph_val = get_pg_graphs(val, label_map)
pg_graph_test = get_pg_graphs(test, label_map)

# model dimensions derived from the data
n_features = pg_graph_train[0].x.shape[1]
n_classes = len(label_map)

# hyper-parameters
hidden_channels = 512
batch_size = 32
learning_rate = 0.0001
max_epochs = 5000

# early stopping: key into MONITOR_MAP ("loss" or "f1") and patience in epochs
train_monitor = "loss"
es_patience = 100

In [None]:
# For each supported `train_monitor` value: the logged validation metric to watch
# and whether an improvement means a larger ("max") or smaller ("min") value.
MONITOR_MAP = {
    "f1": dict(monitor="val_f1", mode="max"),
    "loss": dict(monitor="val_loss", mode="min"),
}

monitor, mode = (MONITOR_MAP[train_monitor][key] for key in ("monitor", "mode"))

# Stop training once the monitored metric has not improved for `es_patience` checks.
early_stop_callback = EarlyStopping(
    monitor=monitor, mode=mode, min_delta=0.00, patience=es_patience, verbose=False
)

In [None]:
# One loader per split, all built with the same settings (fixed order, 16 workers).
train_loader, val_loader, test_loader = (
    DataLoader(split_graphs, batch_size=batch_size, shuffle=False, num_workers=16)
    for split_graphs in (pg_graph_train, pg_graph_val, pg_graph_test)
)

In [None]:
from torch import nn
from torch_geometric.nn import SAGEConv
from torch.nn import CrossEntropyLoss
from torchmetrics import F1Score
from torch.optim import Adam

from pytorch_lightning import LightningModule
import torch
from torch.optim import Optimizer

class Model(LightningModule):
    """
    Two-layer GraphSAGE node classifier trained with cross-entropy.

    The module owns its train/val dataloaders (returned by `train_dataloader`
    and `val_dataloader`), so `Trainer.fit(model)` needs no extra arguments.
    Loss and macro F1 are logged for both stages as `train_loss` / `train_f1`
    and `val_loss` / `val_f1`.

    Args:
        train_loader (DataLoader): Loader yielding training graph batches.
        val_loader (DataLoader): Loader yielding validation graph batches.
        hidden_channels (int): Output size of the first SAGEConv layer.
        n_features (int): Number of input features per node.
        n_classes (int): Number of target classes (output size of the second layer).
    """

    def __init__(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader,
        hidden_channels: int,
        n_features: int,
        n_classes: int,
    ):
        super().__init__()

        self.train_loader = train_loader
        self.val_loader = val_loader

        # NOTE(review): lin1 / sig1 are never used in forward(). They are kept so the
        # registered parameters and state_dict keys stay the same; remove them if
        # they are confirmed to be leftovers.
        self.lin1 = nn.Linear(512, n_features // 2)
        self.sig1 = nn.Sigmoid()

        self.sage_conv1 = SAGEConv(
            n_features,
            hidden_channels,
            aggr="mean",
        )

        self.sage_conv2 = SAGEConv(hidden_channels, n_classes, aggr="mean")

        self.ce_loss = CrossEntropyLoss()

        # One metric object per stage: a metric accumulates state across the steps
        # in which it is updated, so sharing a single instance between training and
        # validation would mix both stages into the reported values.
        self.train_f1 = F1Score('multiclass', num_classes = n_classes, top_k=1, average="macro")
        self.val_f1 = F1Score('multiclass', num_classes = n_classes, top_k=1, average="macro")

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the per-node class scores.

        Args:
            x (torch.Tensor): Node feature matrix.
            edge_index (torch.Tensor): Graph connectivity.

        Returns:
            torch.Tensor: Raw (un-normalised) class scores, one row per node.
        """
        x = self.sage_conv1(x, edge_index)
        x = x.relu()

        x = self.sage_conv2(x, edge_index)
        return x

    def training_step(self, batch, batch_index: int) -> torch.Tensor:
        """
        Run one optimisation step on a graph batch.

        Args:
            batch: Graph batch exposing `x`, `edge_index` and `y`.
            batch_index (int): Index of the batch within the epoch (unused).

        Returns:
            torch.Tensor: The cross-entropy loss to back-propagate.
        """
        logits = self(batch.x, batch.edge_index)
        loss = self.ce_loss(logits, batch.y)

        preds = logits.argmax(dim=1)
        self.train_f1(preds, batch.y)

        self.log("train_loss", loss, prog_bar=True)
        self.log("train_f1", self.train_f1, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx: int):
        """
        Evaluate one validation batch and log its loss and F1.

        When validation_step() is called, the model has been put in eval mode
        and PyTorch gradients have been disabled. At the end of validation, the
        model goes back to training mode and gradients are enabled.

        Args:
            batch: Graph batch exposing `x`, `edge_index` and `y`.
            batch_idx (int): Index of the batch within the validation run (unused).
        """
        logits = self(batch.x, batch.edge_index)
        loss = self.ce_loss(logits, batch.y)

        preds = logits.argmax(dim=1)
        self.val_f1(preds, batch.y)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_f1", self.val_f1, prog_bar=True)

    def predict_step(self, batch, batch_idx: int) -> tuple:
        """
        Predict the class of every node in a batch.

        Args:
            batch: Graph batch exposing `x` and `edge_index`.
            batch_idx (int): Index of the batch (unused).

        Returns:
            tuple: `(pred, confidences)` where `pred` holds the predicted class
            index per node and `confidences` is the result of `Tensor.max(dim=1)`
            on the softmax probabilities, i.e. a `(values, indices)` pair whose
            first element is the probability of the predicted class.
        """
        probabilities = self(batch.x, batch.edge_index).softmax(dim=1)

        confidences = probabilities.max(dim=1)
        pred = probabilities.argmax(dim=1)

        return pred, confidences

    def train_dataloader(self):
        """Return the training loader given at construction time."""
        return self.train_loader

    def val_dataloader(self):
        """Return the validation loader given at construction time."""
        return self.val_loader

    def configure_optimizers(self) -> Optimizer:
        """Create the Adam optimizer; the step size is the module-level `learning_rate`."""
        optimizer = Adam(self.parameters(), lr=learning_rate)
        return optimizer

In [None]:
# Instantiate the classifier with the loaders and the dimensions derived from the data.
model = Model(
    train_loader=train_loader,
    val_loader=val_loader,
    hidden_channels=hidden_channels,
    n_features=n_features,
    n_classes=n_classes,
)

In [None]:
# Trainer capped at `max_epochs`; early stopping may end the run sooner.
trainer = Trainer(max_epochs=max_epochs, callbacks=[early_stop_callback])

In [None]:
# Train the model. No loaders are passed here because the model supplies them itself
# through its train_dataloader() / val_dataloader() methods.
trainer.fit(model)

## Predict

In [None]:
def data_item_predict(data_item, pred_map: dict):
    """
    Return a copy of `data_item` whose token boxes carry their prediction fields.

    The input item is left untouched: it is deep-copied first, and every token
    box of the copy is merged with the entry of `pred_map` stored under its id
    (prediction fields win on key clashes).

    Args:
        data_item: Mapping with a "token_boxes" list of dicts, each having an "id".
        pred_map (dict): Mapping token-box id -> dict of prediction fields.

    Returns:
        The enriched deep copy of `data_item`.

    Raises:
        KeyError: If a token-box id has no entry in `pred_map`.
    """
    enriched_item = deepcopy(data_item)

    enriched_boxes = []
    for box in enriched_item["token_boxes"]:
        enriched_boxes.append(box | pred_map[box["id"]])

    enriched_item["token_boxes"] = enriched_boxes
    return enriched_item

def predict(data_block, label_map, inv_label_map):
    """
    Run the trained model on `data_block` and attach the predictions to every token box.

    Uses the module-level `trainer` and `model`. Predictions are matched to
    tokens document by document and in graph-node order (the order in which the
    document graph lists its nodes, which is the order the converted graphs keep).
    Token ids therefore only need to be unique within one document.

    Args:
        data_block: Sized collection of data items, each with "token_boxes" and "doc_graph".
        label_map (dict): Mapping label -> integer index, used to build the graphs.
        inv_label_map (dict): Mapping integer index -> label, used to decode predictions.

    Returns:
        list: Deep copies of the data items where every token box gained
        "pred_label" (decoded label) and "cls_conf" (probability of that label).

    Raises:
        ValueError: If the number of predictions differs from the number of graph nodes.
        KeyError: If a token box has no corresponding node in its document graph.
    """
    # nothing to predict: avoid stacking an empty list of batches
    if len(data_block) == 0:
        return []

    pg_graphs = get_pg_graphs(data_block, label_map)

    loader = DataLoader(
        pg_graphs, batch_size=5, shuffle=False
    )

    pred_tuples = trainer.predict(model, loader)

    # every batch result is (pred, confidences); confidences is the (values, indices)
    # pair returned by Tensor.max(dim=1), so element 0 holds the probabilities
    preds = np.hstack([batch_pred.cpu().numpy() for batch_pred, _ in pred_tuples])
    confidences = np.hstack(
        [batch_conf[0].cpu().numpy() for _, batch_conf in pred_tuples]
    )

    node_ids_per_item = [
        list(data_item["doc_graph"].nodes) for data_item in data_block
    ]

    n_nodes = sum(len(node_ids) for node_ids in node_ids_per_item)
    if n_nodes != len(preds):
        raise ValueError(
            f"got {len(preds)} predictions for {n_nodes} graph nodes"
        )

    predicted_block = []
    start = 0
    for data_item, node_ids in zip(data_block, node_ids_per_item):
        stop = start + len(node_ids)

        # one prediction map per document, so repeated ids across documents cannot clash
        pred_map = {
            node_id: {"pred_label": inv_label_map[label_idx], "cls_conf": conf}
            for node_id, label_idx, conf in zip(
                node_ids, preds[start:stop], confidences[start:stop]
            )
        }

        predicted_block.append(data_item_predict(data_item, pred_map))
        start = stop

    return predicted_block

In [None]:
# Annotate the held-out test split with the model's predictions.
predict_data_block = predict(test, label_map, inv_label_map)

In [None]:
# Collect the gold and the predicted label of every token, document by document.
y_true = []
y_pred = []
for data_item in predict_data_block:
    for token in data_item['token_boxes']:
        y_true.append(token['label'])
        y_pred.append(token['pred_label'])

# both lists must have the same length
len(y_true), len(y_pred)

In [None]:
# Normalise every label to its string form so gold and predicted values compare uniformly.
y_true = list(map(str, y_true))
y_pred = list(map(str, y_pred))

In [None]:
# Per-label support: how often each label occurs in the gold and in the predicted lists.
# y_true / y_pred hold str()-converted labels, so compare against str(label); comparing
# against the raw key would report 0 for any label that is not already a string.
for label in label_map.keys():
    print(f'{label} : true {y_true.count(str(label))} || pred {y_pred.count(str(label))}')

In [None]:
# TODO: nothing from this point down has been added to the neural_network code yet
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

from sklearn.metrics import classification_report, confusion_matrix


def show_metrics(
    y_true,
    y_pred,
    inv_label_map: dict,
):
    """
    Plot the confusion matrix of the predictions as an annotated heatmap.

    Rows are the true labels and columns the predicted ones. Both axes follow
    the order of `inv_label_map`, and that same order is passed to
    `confusion_matrix`, so every tick names the row/column it is drawn on and
    classes absent from the data still get an (all-zero) row and column.

    Args:
        y_true: Iterable of gold labels (compared by their string form).
        y_pred: Iterable of predicted labels (compared by their string form).
        inv_label_map (dict): Mapping integer index -> label; its values define
            the classes shown and their order.
    """
    target_names = [str(name) for name in inv_label_map.values()]

    # compare everything in string form, matching target_names
    y_true = [str(label) for label in y_true]
    y_pred = [str(label) for label in y_pred]

    _, ax = plt.subplots(1, 1, figsize=(10, 5))

    # without `labels=` the matrix would be ordered by the sorted labels found in
    # the data, which need not match the tick labels below
    cm = confusion_matrix(y_true, y_pred, labels=target_names)

    sns.heatmap(
        cm,
        ax=ax,
        robust=True,
        annot=True,
        fmt="d",  # counts are integers; avoids scientific notation for large cells
        square=False,
        xticklabels=target_names,
        yticklabels=target_names,
    )

    ax.set_xlabel("prediction")
    ax.set_ylabel("true")

    plt.show()

In [None]:
# Confusion matrix of the test-split predictions collected above.
show_metrics(y_true,
    y_pred,
    inv_label_map)