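# Relevance filtering of test-set documents

Loads the trained `GraphClassifier`, predicts a relevant/irrelevant class for every graph yielded by the graph dataset, removes the documents predicted irrelevant from each question and saves the filtered question file.

## Imports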

In [1]:
%load_ext autoreload
%autoreload 2

import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F

from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import recall_score

from src.models.model import GraphClassifier
from src.graphs.graph_loader import GraphDataset
from src.elastic_search_utils.elastic_utils import save_json, load_json

## Params

In [2]:
# Name of the question file: read from LOAD_QUESTIONS_FOLDER and written under SAVING_FOLDER.
TEST_FILE_NAME: str = 'test_original_10b-testset4.json'

In [3]:
# Folder with the pre-built graphs that GraphDataset reads (note the trailing slash).
LOAD_FOLDER: str = '/datasets/johan_tests_original_format_graphs/similarity_shape_100_20__score_threshold_006__similarity_relevance_07/testset4/'

In [4]:
# Source question file whose document lists get filtered by the model's predictions.
LOAD_QUESTIONS_FOLDER: str = '/datasets/johan_tests_original_format_centroid/merged_training_docs'
LOAD_QUESTIONS: str = f'{LOAD_QUESTIONS_FOLDER}/{TEST_FILE_NAME}'

In [5]:
# Saved GraphClassifier state_dict restored before prediction.
LOAD_MODEL_PATH: str = '/datasets/johan_tests_models/v1/model.pth'

In [6]:
# Output location for the filtered questions. It lives in a different folder than
# LOAD_QUESTIONS_FOLDER, so the input question file is not overwritten.
SAVING_FOLDER: str = '/datasets/johan_tests_original_format_graphs/similarity_shape_100_20__score_threshold_006__similarity_relevance_07'
SAVING_PATH: str = f'{SAVING_FOLDER}/{TEST_FILE_NAME}'

## Dataset params

In [7]:
BATCH_SIZE: int = 64  # forwarded to GraphDataset as batch_size

In [8]:
# Split fractions forwarded to GraphDataset (val_percentage / test_percentage).
# NOTE(review): prediction iterates dataset.get_batch(); confirm these splits do not
# restrict which graphs are predicted.
VAL_PERCENTAGE: float = 0.15
TEST_PERCENTAGE: float = 0.15

In [9]:
# Forwarded to GraphDataset as `score_threshold`.
# NOTE(review): LOAD_FOLDER is named "...score_threshold_006..." while this value is
# 0.04 — confirm the two are meant to differ.
RELEVANCE_THRESHOLD: float = 0.04

In [10]:
RANDOM_STATE: int = 42  # seeds torch and is forwarded to GraphDataset

In [11]:
# Seed torch's global RNG; the trailing ';' hides the returned Generator repr.
torch.manual_seed(RANDOM_STATE);

<torch._C.Generator at 0x7f2adc13e950>

## Model params

In [12]:
# Architecture hyper-parameters; they must match the checkpoint at LOAD_MODEL_PATH
# because load_state_dict is strict by default.
# NOTE(review): NODES_PER_GRAPH / INPUT_DIM appear to mirror "similarity_shape_100_20"
# in LOAD_FOLDER — keep them in sync.
INPUT_DIM: int = 20
NODES_PER_GRAPH: int = 100
HIDDEN_CHANNELS: int = 16
OUTPUT_DIM: int = 2  # N_CLASSES: 0 = irrelevant, 1 = relevant
DROPOUT: float = 0.5

In [13]:
# Use the GPU when one is available; otherwise fall back to the CPU instead of
# failing later at model.to(DEVICE).
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
# NOTE(review): not used anywhere in this inference notebook (no training loop).
EPOCHS: int = 100

## Load questions, graph dataset and model

In [15]:
# Load the question file; its questions['questions'][*]['documents'] lists are filtered below.
questions = load_json(LOAD_QUESTIONS)

In [16]:
# Graph dataset read from LOAD_FOLDER. dataset.get_batch() feeds the prediction loop,
# and dataset.metadata (question_id / document_id / label columns) is used afterwards.
# NOTE(review): batch_size presumably sets the graphs per get_batch() item — confirm.
dataset = GraphDataset(
    dataset_path=LOAD_FOLDER,
    score_threshold=RELEVANCE_THRESHOLD,
    batch_size=BATCH_SIZE,
    random_state=RANDOM_STATE,
    val_percentage=VAL_PERCENTAGE,
    test_percentage=TEST_PERCENTAGE,
)

In [17]:
# Rebuild the architecture with the hyper-parameters above (they must match the
# checkpoint restored next) and place it on DEVICE; nn.Module.to returns the module.
model = GraphClassifier(
    nodes_per_graph=NODES_PER_GRAPH,
    input_dims=INPUT_DIM,
    hidden_channels=HIDDEN_CHANNELS,
    dropout=DROPOUT,
    output_dim=OUTPUT_DIM,
).to(DEVICE)

In [18]:
# Restore the trained weights. map_location places the tensors directly on DEVICE,
# so a checkpoint saved on a GPU also loads when DEVICE is the CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (recent torch versions accept weights_only=True for plain state dicts).
model.load_state_dict(torch.load(LOAD_MODEL_PATH, map_location=DEVICE))

<All keys matched successfully>

In [19]:
# 'balanced' class weights computed from every label in dataset.metadata, as a float
# tensor on DEVICE. They only feed `criterion`, which this notebook never calls.
flat_labels = dataset.metadata['label'].tolist()
class_weights = torch.FloatTensor(
    compute_class_weight(class_weight='balanced', classes=[0, 1], y=flat_labels)
).to(DEVICE)

In [20]:
# Class-weighted cross-entropy placed on DEVICE.
# NOTE(review): unused here — the prediction loop computes no loss.
criterion = torch.nn.CrossEntropyLoss(weight=class_weights).to(DEVICE)

In [21]:
# NOTE(review): unused here — this notebook performs no optimisation step.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=5e-3,
)

## Prediction loop

In [22]:
# Inference mode. The model is built with dropout=DROPOUT and nn.Module starts in
# train mode, so without eval() any dropout layers stay active and predictions are
# random. no_grad() stops autograd from recording a graph for every batch.
model.eval()
y_pred = []
batch_pbar_val = tqdm(
    dataset.get_batch(),
    desc='predicting'
)
with torch.no_grad():
    for test_batch in batch_pbar_val:
        test_batch = test_batch.to(DEVICE)
        out_test = model(test_batch)
        # Copy the logits to host memory before GPU memory is released.
        y_pred.append(out_test.detach().cpu().numpy())
        # NOTE(review): the return value is discarded, so this only has an effect if the
        # batch type moves its tensors in place (e.g. PyTorch Geometric Data/Batch.to).
        # It is kept so batches cached by the dataset go back to the CPU.
        test_batch.to('cpu')
        del out_test
        torch.cuda.empty_cache()
# Stack per-batch logits; assumes the model returns (batch_size, OUTPUT_DIM) — TODO confirm.
y_pred = np.concatenate(y_pred)

loss = inf: 119it [00:09, 12.35it/s]


In [23]:
# Predicted class per graph (0 = irrelevant, 1 = relevant).
# NOTE(review): row i of y_pred is assumed to correspond to row i of dataset.metadata,
# i.e. get_batch() yields graphs in metadata order — confirm in GraphDataset.
assert len(y_pred) == len(dataset.metadata), (
    f'{len(y_pred)} predictions for {len(dataset.metadata)} metadata rows'
)
dataset.metadata['is_relevant'] = y_pred.argmax(axis=1)

### Map each question to the documents predicted relevant (for fast filtering)

In [24]:
# question_id -> set of document_ids predicted relevant. A distinct loop variable is
# used so the dict is not shadowed by each group. Sets give O(1) membership checks
# when filtering. Questions with no document predicted relevant get no key.
relevant_rows = dataset.metadata[dataset.metadata['is_relevant'] == 1]
relevant_documents = {
    question_id: set(document_ids)
    for question_id, document_ids in relevant_rows.groupby('question_id')['document_id']
}

## Remove irrelevant documents from the questions dict

In [25]:
# Keep only the documents predicted relevant for each question. The filtered list
# gets its own name so the relevant_documents mapping is never overwritten.
# Questions without an entry (no document predicted relevant, or absent from the
# graph metadata) keep all their documents.
# NOTE(review): confirm that keeping every document in that case is intended, and that
# question/document ids in the JSON have the same type as in dataset.metadata.
for question in questions['questions']:
    if question['id'] not in relevant_documents:
        continue
    kept_document_ids = relevant_documents[question['id']]
    question['documents'] = [
        document for document in question['documents']
        if document['id'] in kept_document_ids
    ]

## Saving to disk

In [26]:
# Write the filtered questions (irrelevant documents removed) to SAVING_PATH.
save_json(questions, SAVING_PATH)