In [None]:
conda install -c conda-forge pytorch-lightning --yes


In [None]:
import pickle
import random

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader

from DataLoader.swde_dataLoader import swde_data_test, collate_fn_test
from Model.SimpDOM_model import SeqModel
from Prediction.PRSummary import cal_PR_summary
from Prediction.test_step import main as get_predictions
from Prediction.WebsiteLevel_PR_Generator import cal_PR_summary as websiteLevel_cal_PR_summary
from Utils.pretrainedGloVe import pretrainedWordEmeddings


In [None]:
# Model configuration: data location, RNG seeding, model hyperparameters, and
# the websites/attributes to evaluate. Most of these values are collected
# into the `config` dict that is handed to SeqModel when the checkpoint is loaded.

datapath = './data'

# Seed every RNG this notebook imports (Python, NumPy, PyTorch), not only
# Python's `random`, so that a Restart & Run All gives repeatable results.
SEED = 7
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = 'cpu'
n_workers = 2
n_gpus = 0

# Character-related hyperparameters.
char_emb_dim = 16
char_hid_dim = 100
char_emb_dropout = 0.1

# HTML-tag-related hyperparameters.
tag_emb_dim = 16
tag_hid_dim = 30

leaf_emb_dim = 30
pos_emb_dim = 20

# Pretrained GloVe vectors file, located under datapath.
word_emb_filename = '{}/glove.6B.100d.txt'.format(datapath)

train_websites = ['auto-aol', 'auto-yahoo', 'auto-motortrend', 'auto-autobytel', 'auto-msn']
# NOTE(review): both validation sites also appear in train_websites — confirm
# that this overlap is intended.
val_websites = ['auto-aol', 'auto-yahoo']

attributes = ['model', 'price', 'engine', 'fuel_economy']
# One class per attribute plus one extra class
# (presumably the "no attribute" label — TODO confirm).
n_classes = len(attributes) + 1
# One weight per class (5 entries == n_classes).
# NOTE(review): index 0 is presumably the extra non-attribute class — confirm.
class_weights = [1, 100, 100, 100, 100]

In [None]:
# Data loading: unpickle the character and HTML-tag dictionaries, load the
# pretrained word embeddings, and build the DataLoader used for prediction.

# SECURITY: pickle.load can execute arbitrary code while unpickling; only point
# datapath at files from a trusted source.
# `with` closes each file handle as soon as its dictionary has been read.
with open('{}/English_charDict.pkl'.format(datapath), 'rb') as char_dict_file:
    charDict = pickle.load(char_dict_file)
with open('{}/HTMLTagDict.pkl'.format(datapath), 'rb') as tag_dict_file:
    tagDict = pickle.load(tag_dict_file)
print(len(charDict), len(tagDict))

# Reuse the embedding path from the configuration cell instead of rebuilding
# the same string here, so the two can never drift apart.
WordEmeddings = pretrainedWordEmeddings(word_emb_filename)

# Despite its name, `test_dataset` is a DataLoader; it wraps the swde_data_test
# dataset built for val_websites. shuffle=False keeps batches in dataset order.
test_dataset = DataLoader(
    dataset=swde_data_test(val_websites, datapath, charDict, tagDict, n_gpus, WordEmeddings),
    num_workers=n_workers,
    batch_size=32,
    shuffle=False,
    collate_fn=collate_fn_test,
)

In [None]:
# Checkpoint policy: keep the single best checkpoint by lowest 'val_loss'
# (save_top_k=1, mode='min') and also save the last one.
# The directory goes in `dirpath` and only the base name in `filename`, so the
# best checkpoint is written to '<datapath>/weights.ckpt' — the same path this
# notebook loads the model from.
# Note: the callback is not attached to a Trainer in the cells shown here.
checkpoint_callback = ModelCheckpoint(
    dirpath=datapath,
    filename='weights',
    save_top_k=1,
    save_last=True,
    verbose=True,
    monitor='val_loss',
    mode='min',
)

# Everything SeqModel receives through its `config` argument, gathered from the
# values defined in the configuration and data-loading cells.
config = dict(
    # output size and data selection
    out_dim=n_classes,
    train_websites=train_websites,
    val_websites=val_websites,
    datapath=datapath,
    n_workers=n_workers,
    # character dictionary and its hyperparameters
    charDict=charDict,
    char_emb_dim=char_emb_dim,
    char_hid_dim=char_hid_dim,
    char_emb_dropout=char_emb_dropout,
    # HTML tag dictionary and its hyperparameters
    tagDict=tagDict,
    tag_emb_dim=tag_emb_dim,
    tag_hid_dim=tag_hid_dim,
    # remaining embedding sizes
    leaf_emb_dim=leaf_emb_dim,
    pos_emb_dim=pos_emb_dim,
    # labels, hardware and class weights
    attributes=attributes,
    n_gpus=n_gpus,
    class_weights=class_weights,
    word_emb_filename=word_emb_filename,
)

In [None]:
# Restore SeqModel from the checkpoint stored under datapath, passing the same
# `config` dict, and call .eval() on the restored model before using it.
model = SeqModel.load_from_checkpoint(
    '{}/weights.ckpt'.format(datapath),
    config=config,
).eval()


In [None]:
# Place the model on the configured device, then run prediction over the
# DataLoader built in the data-loading cell; the result is kept in `df`.
# NOTE(review): 0.6 is presumably a score/confidence threshold — confirm
# against Prediction.test_step.main.
# NOTE(review): if get_predictions does not disable autograd itself, consider
# wrapping this call in torch.no_grad().
model = model.to(device)
df = get_predictions(
    test_dataset,
    model,
    device,
    0.6,
)



In [None]:
# Evaluation of the predictions in `df`.
# cal_PR_summary and websiteLevel_cal_PR_summary are imported in the notebook's
# import cell, so all dependencies are visible in one place.

# Summary over the whole prediction frame (Prediction.PRSummary).
avg_prf1_dict = cal_PR_summary(df, n_classes)

# Website-level summary and detailed results (Prediction.WebsiteLevel_PR_Generator).
pr_summary_df, pr_results_df = websiteLevel_cal_PR_summary(df, n_classes)

# Bare last expression: Jupyter shows it with the rich table display rather
# than the plain-text dump that print() produces.
pr_results_df