In [1]:
import pandas as pd
import numpy as np

from itertools import product

import torch
from torch.utils.data import DataLoader

from tqdm.auto import tqdm

from models.Wide_Deep import Wide_Deep
from models.WD_Dataset import WD_Dataset


from utils.utils import (load_test_data, target_features, numerical_features,
                         categorical_features, features, get_results_df)

In [2]:
# Input/output locations, relative to the notebook's working directory.
DATA_PATH = "./data/"
CHECKPOINT_PATH = "./checkpoints/"

# Batch size used by the inference DataLoader.
BATCH_SIZE = 128

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DEVICE

'cuda'

In [3]:
# Load the held-out split.
test_df = load_test_data(DATA_PATH)

# Target columns as a plain ndarray, in test_df row order; scored against the
# model outputs at the end of the notebook.
ground_truths = test_df.loc[:, target_features].to_numpy()

In [4]:
# Rebuild the network with the same architecture the checkpoint was trained with,
# then restore its weights and move it to DEVICE.
model = Wide_Deep(dim_features=len(numerical_features + features),
                  dim_hidden=[768, 512, 256, 128, 64, 32])

PATH = CHECKPOINT_PATH + "WD_epoch_1_end"
# map_location remaps stored tensors onto DEVICE, so a checkpoint written from a
# GPU run can still be restored when DEVICE fell back to "cpu"; without it
# torch.load raises when the saved CUDA tensors cannot be placed.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
model.load_state_dict(torch.load(PATH, map_location=DEVICE))
model = model.to(DEVICE)

In [5]:
def create_dataset(df, numerical_features=numerical_features,
                   features=features, targets=target_features):
    """Wrap the feature and target columns of ``df`` in a ``WD_Dataset``.

    Args:
        df: DataFrame holding every column named in the arguments below.
        numerical_features: first group of feature column names.
        features: second group of feature column names, appended after
            ``numerical_features``.
        targets: target column names.

    Returns:
        ``WD_Dataset`` built from the feature matrix and the target matrix,
        both passed as NumPy arrays in ``df`` row order.
    """
    feature_columns = numerical_features + features

    feature_matrix = df.loc[:, feature_columns].to_numpy()
    target_matrix = df.loc[:, targets].to_numpy()

    return WD_Dataset(feature_matrix, target_matrix)


dataset = create_dataset(test_df)

In [6]:
# Run the model over the whole test set and collect per-batch probabilities.
model.eval()

predictions = []

# shuffle=False / drop_last=False keep batches in dataset order and include the
# final partial batch, so the concatenated predictions can be compared
# row-for-row with ground_truths.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE,
                        shuffle=False, drop_last=False)

# Pure inference: no_grad stops autograd from recording each forward pass
# (and saving its activations), which is wasted work and memory here.
with torch.no_grad():
    for batch in tqdm(dataloader):
        # Use a distinct name: `features` is the column list imported from
        # utils.utils and used by the model / dataset cells; overwriting it
        # with a tensor breaks those cells on re-run.
        batch_features = batch['features'].to(DEVICE)

        logits = model(batch_features)
        predictions.append(torch.sigmoid(logits).cpu())

  0%|          | 0/10881 [00:00<?, ?it/s]

In [7]:
# Convert the per-batch CPU tensors to NumPy. Entries that are already ndarrays
# are passed through, so re-running this cell is safe (calling .numpy() on an
# ndarray would raise AttributeError on the second run).
predictions = [
    prediction.numpy() if torch.is_tensor(prediction) else prediction
    for prediction in predictions
]

In [8]:
# Stack the per-batch outputs into a single (rows, targets) array.
prediction_arr = np.concatenate(predictions, axis=0)

# Predictions must line up with the labels row-for-row and target-for-target
# before scoring; a mismatch means rows were dropped/duplicated by the loader
# or the model's output width differs from target_features.
assert prediction_arr.shape == ground_truths.shape, (
    f"prediction shape {prediction_arr.shape} != "
    f"ground-truth shape {ground_truths.shape}"
)

prediction_arr.shape

(1392727, 4)


In [9]:
# Score the predicted probabilities against the held-out labels; the resulting
# table has one row per metric and one column per target.
results_df = get_results_df(
    prediction_arr,
    ground_truths,
)
results_df

Unnamed: 0,reply,retweet,retweet_comment,like
rce,16.394132,23.055234,5.656965,10.009639
avg_prec,0.201447,0.455484,0.044813,0.663035
