## Imports

In [1]:
%load_ext autoreload
%autoreload 2

import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F

from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import recall_score, precision_score, accuracy_score

from src.models.model import GraphClassifier
from src.graphs.graph_loader import GraphDataset
from src.elastic_search_utils.elastic_utils import save_json

## Params

In [2]:
LOAD_FOLDER = '/datasets/johan_tests_original_format_graphs_bm25/similarity_shape_100_20__score_threshold_006__similarity_relevance_07/training'

In [3]:
SAVING_FOLDER = '/datasets/johan_tests_models/v1'

In [4]:
SAVING_MODEL_PATH = f'{SAVING_FOLDER}/model.pth'
SAVING_METRICS_PATH = f'{SAVING_FOLDER}/metrics.json'

In [5]:
DEBUG = False

## Dataset params

In [6]:
BATCH_SIZE = 256

In [7]:
VAL_PERCENTAGE = 0.15
TEST_PERCENTAGE = 0.15

In [8]:
RELEVANCE_THRESHOLD = 10  # 0.32 to Keep 20% of elastic most relevant

In [9]:
RANDOM_STATE = 42

In [10]:
torch.manual_seed(RANDOM_STATE)

<torch._C.Generator at 0x7f34340bb9d0>

## Model params

In [11]:
INPUT_DIM = 20
NODES_PER_GRAPH = 100
HIDDEN_CHANNELS = 16
OUTPUT_DIM = 2  # N_CLASSES
DROPOUT = 0.5

In [12]:
DEVICE = torch.device('cuda')

In [13]:
EPOCHS = 5

## Constants

In [14]:
dataset = GraphDataset(
    dataset_path=LOAD_FOLDER,
    batch_size=BATCH_SIZE,
    val_percentage=VAL_PERCENTAGE,
    test_percentage=TEST_PERCENTAGE,
    random_state=RANDOM_STATE,
    score_threshold=RELEVANCE_THRESHOLD,
    debug=DEBUG
)

In [15]:
model = GraphClassifier(
    input_dims=INPUT_DIM,
    nodes_per_graph=NODES_PER_GRAPH,
    hidden_channels=HIDDEN_CHANNELS,
    output_dim=OUTPUT_DIM,
    dropout=DROPOUT
)

model = model.to(DEVICE)

In [16]:
flat_labels = dataset.metadata['label'].tolist()

class_weights = compute_class_weight(
    'balanced', classes=[0,1], y=flat_labels
)


# MU = 0.1
# class_weights = torch.FloatTensor([  # GDOT WEIGHTS
#     np.log((MU*7615.0/5815.0) + 1),
#     np.log((MU*7615.0/1800.0) + 1)
# ])

class_weights = torch.FloatTensor(class_weights)
class_weights = class_weights.to(DEVICE)
class_weights

tensor([0.5881, 3.3372], device='cuda:0')

In [17]:
criterion = torch.nn.NLLLoss(
    weight=class_weights
)
criterion = criterion.to(DEVICE)

In [18]:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)#, weight_decay=5e-3)

In [19]:
dataset.metadata['label'].value_counts()

0.0    332845
1.0     58658
Name: label, dtype: int64

## Training loop

In [None]:
losses = []
accurracies = []
recalls = []
precisions = []
for epoch in range(EPOCHS):
    model.train()
    batch_pbar = tqdm(
        dataset.get_batch('train'),
        desc='loss = inf'
    )

    epoch_loss = 0
    y_train = []
    y_train_pred = []
    for n_batch, batch in enumerate(batch_pbar):
        batch = batch.to(DEVICE)
        optimizer.step()
        out = model(batch).float()
        batch_y_test = batch.y #F.one_hot(batch.y, OUTPUT_DIM).float()

        loss = criterion(out, batch_y_test)
        loss.backward()
        optimizer.zero_grad()
        batch.to('cpu')
        batch_pbar.set_description(f"loss = {loss}")
        epoch_loss += loss.to('cpu').tolist()
        y_train.append(batch_y_test.tolist())
        y_train_pred.append(out.tolist())
        # FIXME JUST FOR MEMORY CLEANING
        del out
        del batch_y_test
        torch.cuda.empty_cache()
    y_train = np.concatenate(y_train)
    y_train_pred = np.concatenate(y_train_pred)
    acc = accuracy_score(
        y_train, y_train_pred.argmax(axis=1)
    )
    recall = recall_score(
        y_train, y_train_pred.argmax(axis=1), average='binary'
    )
    precision = precision_score(
        y_train, y_train_pred.argmax(axis=1), average='binary'
    )
    
    print("Train value counts", pd.value_counts(np.array(y_train_pred.argmax(axis=1))))
    print("Train recall", recall)
    print("Train precision", precision)
    print(f'Train accuracy: {acc:.4f}')
    
        
    epoch_loss = epoch_loss/(n_batch + 1)
    losses.append(epoch_loss)
    print(f'EPOCH LOSS: {epoch_loss}')
    model.eval()
    y_val = []
    y_pred = []
    batch_pbar_val = tqdm(
        dataset.get_batch('val'),
        desc='loss = inf'
    )
    for test_batch in batch_pbar_val:
        test_batch = test_batch.to(DEVICE)
        out_test = model(test_batch)
        test_gold = test_batch.y  # F.one_hot(test_batch.y, OUTPUT_DIM).float()
        y_val.append(test_gold.tolist())
        y_pred.append(out_test.tolist())
        test_batch.to('cpu')
        del out_test
        del test_gold
        torch.cuda.empty_cache()
    y_val = np.concatenate(y_val)
    y_pred = np.concatenate(y_pred)

    acc = accuracy_score(
        y_val, y_pred.argmax(axis=1)
    )
    recall = recall_score(
        y_val, y_pred.argmax(axis=1), average='binary'
    )
    precision = precision_score(
        y_val, y_pred.argmax(axis=1), average='binary'
    )
    accurracies.append(acc)
    recalls.append(recall)
    precisions.append(precision)
    print("Value counts", pd.value_counts(np.array(y_pred.argmax(axis=1))))
    print("Recall", recall)
    print("Precision", precision)
    print(f'Accuracy: {acc:.4f}')

loss = 0.6931698322296143: : 192it [01:35,  1.99it/s]

In [45]:
time_series = {
    'loss': losses,
    'accuracy': accurracies,
    'recall': recalls,
    'precision': precisions
}

In [None]:
save_json(time_series, SAVING_METRICS_PATH)

In [None]:
torch.save(model.state_dict(), SAVING_MODEL_PATH)

## Test load model

In [None]:
model = GraphClassifier(
    input_dims=INPUT_DIM,
    nodes_per_graph=NODES_PER_GRAPH,
    hidden_channels=HIDDEN_CHANNELS,
    output_dim=OUTPUT_DIM,
    dropout=DROPOUT
)

model = model.to(DEVICE)

In [None]:
model.load_state_dict(torch.load(SAVING_MODEL_PATH))

In [None]:
model.eval()
y_test = []
y_pred = []
batch_pbar_test = tqdm(
    dataset.get_batch('test'),
    desc='loss = inf'
)
for test_batch in batch_pbar_test:
    test_batch = test_batch.to(DEVICE)
    out_test = model(test_batch)
    test_gold = test_batch.y  # F.one_hot(test_batch.y, OUTPUT_DIM).float()
    y_test.append(test_gold.tolist())
    y_pred.append(out_test.tolist())
    test_batch.to('cpu')
    del out_test
    del test_gold
    torch.cuda.empty_cache()
y_test = np.concatenate(y_test)
y_pred = np.concatenate(y_pred)

acc = accuracy_score(
    y_test, y_pred.argmax(axis=1)
)
recall = recall_score(
    y_test, y_pred.argmax(axis=1), average='binary'
)
precision = precision_score(
    y_test, y_pred.argmax(axis=1), average='binary'
)
accurracies.append(acc)
recalls.append(recall)
precisions.append(precision)
print("Value counts", pd.value_counts(np.array(y_pred.argmax(axis=1))))
print("Recall", recall)
print("Precision", precision)
print(f'Accuracy: {acc:.4f}')