In [1]:
import pickle

from model import ClassifierLSTM

In [2]:
# Untrained classifier used by the training cell below; the evaluation cells
# further down re-create the model and load saved weights instead.
model = ClassifierLSTM(
    embedding_dim=64,
    hidden_dim=64,
    device='cuda:0',
    gapped=False,
    fixed_len=False,
)
print('Model initialized.')

Model initialized.


In [3]:
# Data files: one human file and one mouse file per split (train / val / test).
# All splits live in the same directory, so the prefix is defined once here.
DATA_DIR = './data/sample/my_data'

trn_humans_fn = f'{DATA_DIR}/humans_train.txt'
trn_mice_fn = f'{DATA_DIR}/mice_train.txt'
val_humans_fn = f'{DATA_DIR}/humans_val.txt'
val_mice_fn = f'{DATA_DIR}/mice_val.txt'
test_humans_fn = f'{DATA_DIR}/humans_test.txt'
test_mice_fn = f'{DATA_DIR}/mice_test.txt'

In [4]:
# Training takes a while, so it is disabled by default and the next cell loads
# a trained checkpoint instead. Set RETRAIN = True to fit from scratch.
# NOTE(review): the next cell loads a fixed file name from ./results; after
# retraining, confirm that path points at the newly saved weights.
RETRAIN = False

if RETRAIN:
    model.fit(
        trn_human_fn=trn_humans_fn,
        trn_mouse_fn=trn_mice_fn,
        val_human_fn=val_humans_fn,
        val_mouse_fn=val_mice_fn,
        n_epoch=10,
        trn_batch_size=128,
        vld_batch_size=128,
        lr=.001,
        save_fp='results'
    )

In [5]:
# Re-create the architecture with the same hyperparameters, then restore the
# trained weights and the training metrics (iterated below as per-epoch dicts).
model = ClassifierLSTM(
    embedding_dim=64,
    hidden_dim=64,
    device='cuda:0',
    gapped=False,
    fixed_len=False,
)
model.load('./results/lstm_0.000095.npy')

# NOTE: pickle can execute arbitrary code while loading; only open trusted,
# locally produced result files here.
with open('./results/metrics.pickle', 'rb') as metrics_file:
    metrics = pickle.load(metrics_file)

print('Model and metrics loaded.')

Model and metrics loaded.


In [6]:
import plotly.express as px
import pandas as pd

# Learning curves: reshape the per-epoch metrics into long format (one row per
# epoch and split, train before val) so plotly draws one line per split.
loss_records = [
    {'epoch': epoch_metrics['epoch'], 'loss': epoch_metrics[loss_key], 'train_val': split}
    for epoch_metrics in metrics
    for split, loss_key in (('train', 'train_loss'), ('val', 'val_loss'))
]
# Explicit columns keep the frame's schema even when `metrics` is empty.
metrics_df = pd.DataFrame(loss_records, columns=['epoch', 'loss', 'train_val'])

fig = px.line(metrics_df, x="epoch", y="loss", color='train_val', title="Learning curves")
fig.show()

In [7]:
# Score the validation and test splits with the trained model.
EVAL_BATCH_SIZE = 512

val_y_pred, val_y_true = model.eval(val_humans_fn, val_mice_fn, batch_size=EVAL_BATCH_SIZE)
# Every score must be paired with a label before any metric is computed.
assert len(val_y_pred) == len(val_y_true), 'validation predictions and labels differ in length'
# Sanity check: first and last three scores / labels (deterministic, not random).
print(val_y_pred[:3], val_y_pred[-3:])
print(val_y_true[:3], val_y_true[-3:])

test_y_pred, test_y_true = model.eval(test_humans_fn, test_mice_fn, batch_size=EVAL_BATCH_SIZE)
assert len(test_y_pred) == len(test_y_true), 'test predictions and labels differ in length'
# Same head/tail sanity check for the test split.
print(test_y_pred[:3], test_y_pred[-3:])
print(test_y_true[:3], test_y_true[-3:])


nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.

100%|| 101010/101010 [00:12<00:00, 8076.03seq/s]


[1. 1. 1.] [1.9307168e-07 1.1103709e-08 6.0984717e-06]
[1. 1. 1.] [0. 0. 0.]


100%|| 101011/101011 [00:12<00:00, 8050.23seq/s]

[1. 1. 1.] [1.3788005e-09 9.4594277e-10 1.9365145e-07]
[1. 1. 1.] [0. 0. 0.]





In [8]:
from sklearn.metrics import confusion_matrix
import numpy as np
# Confusion matrix on the validation set.
# Scores become class predictions via numpy's round() (round-half-to-even, so
# exactly 0.5 maps to 0). Labels are cast to int and pinned to [0, 1] so the
# matrix is always 2x2 and matches the x/y ticks below, even when the labels or
# the predictions contain only one class.
val_y_true_cls = np.asarray(val_y_true).astype(int)
val_y_pred_cls = np.asarray(val_y_pred).round().astype(int)
conf_mat = confusion_matrix(val_y_true_cls, val_y_pred_cls, labels=[0, 1])
print(conf_mat)

# Row-normalize: each cell is the fraction of its ground-truth class. A class
# with no samples gets a row of zeros instead of NaN from 0/0.
row_totals = conf_mat.sum(axis=1, keepdims=True)
conf_ratio = np.divide(conf_mat, row_totals, out=np.zeros(conf_mat.shape), where=row_totals > 0)

fig = px.imshow(
    conf_ratio,
    text_auto=True,
    labels={'x': "Predicted", 'y': "Ground truth", 'color': "Ratio in ground truth"},
    x=[0, 1],
    y=[0, 1],
    title="Confusion matrix on validation set"
)
fig.show()

[[49989    11]
 [    1 51009]]


In [9]:
# Confusion matrix on the test set.
# Scores become class predictions via numpy's round() (round-half-to-even, so
# exactly 0.5 maps to 0). Labels are cast to int and pinned to [0, 1] so the
# matrix is always 2x2 and matches the x/y ticks below, even when the labels or
# the predictions contain only one class.
test_y_true_cls = np.asarray(test_y_true).astype(int)
test_y_pred_cls = np.asarray(test_y_pred).round().astype(int)
conf_mat = confusion_matrix(test_y_true_cls, test_y_pred_cls, labels=[0, 1])
print(conf_mat)

# Row-normalize: each cell is the fraction of its ground-truth class. A class
# with no samples gets a row of zeros instead of NaN from 0/0.
row_totals = conf_mat.sum(axis=1, keepdims=True)
conf_ratio = np.divide(conf_mat, row_totals, out=np.zeros(conf_mat.shape), where=row_totals > 0)

fig = px.imshow(
    conf_ratio,
    text_auto=True,
    labels={'x': "Predicted", 'y': "Ground truth", 'color': "Ratio in ground truth"},
    x=[0, 1],
    y=[0, 1],
    title="Confusion matrix on test set"
)
fig.show()

[[49994     6]
 [    3 51008]]


In [10]:
# Distribution of test-set scores (50 bins), coloured by the true label.
# For a well-separated classifier the two colours sit at opposite ends of the
# score range.
hist_labels = {'color': 'True Labels', 'x': 'Score'}
fig_hist = px.histogram(
    x=test_y_pred,
    color=test_y_true,
    nbins=50,
    labels=hist_labels,
    title="Result scores on histogram from test set",
)
fig_hist.show()

In [11]:
from sklearn.metrics import roc_curve, auc

# ROC curve on the test set; the area under it is reported in the title.
fpr, tpr, thresholds = roc_curve(test_y_true, test_y_pred)
roc_auc = auc(fpr, tpr)

fig = px.area(
    x=fpr,
    y=tpr,
    title=f'ROC Curve on test set (AUC={roc_auc:.4f})',
    labels={'x': 'False Positive Rate', 'y': 'True Positive Rate'},
    width=700,
    height=500,
)
# Dashed diagonal from (0, 0) to (1, 1): the chance-level reference line.
fig.add_shape(type='line', line={'dash': 'dash'}, x0=0, x1=1, y0=0, y1=1)

# Lock the aspect ratio so both axes use the same scale.
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.update_xaxes(constrain='domain')
fig.show()

### Summary
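The results above look good: the model handles the task well, misclassifying only 12 of 101,010 validation sequences and 9 of 101,011 test sequences (see the confusion matrices above).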