In [None]:
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from pathlib import Path
import pandas as pd
import pickle
import numpy as np
import shutil
from tqdm import tqdm

import torch.utils.data
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, Engine
from ignite.metrics import Accuracy, Loss

from src.models import InsiderClassifier, LSTM_Encoder, LSTM_Encoder_Topics
from src.params import get_params
from src.dataset import CertDataset, create_data_loaders
from src.lstm_trainer import *

In [4]:
# Start from the project's default hyper-parameters, then override the LSTM
# encoder settings for this experiment: no embedding size and no content-topic input.
params = get_params()
encoder_params = params['model']['lstm_encoder']
encoder_params['embedding_size'] = None
encoder_params['use_content_topics'] = False


In [14]:
# Everything in this notebook runs on the CPU.
device = 'cpu'

# Pick the encoder variant that matches the configured inputs:
# the topic-aware encoder only when content topics are enabled.
encoder_cfg = params['model']['lstm_encoder']
LSTM_model = LSTM_Encoder_Topics if encoder_cfg['use_content_topics'] else LSTM_Encoder
lstm_encoder = LSTM_model(encoder_cfg)

# NLLLoss consumes log-probabilities, so the encoder presumably ends in
# log_softmax (not plain softmax) — TODO confirm in src.models.
criterion = nn.NLLLoss()
optimizer = optim.Adam(lstm_encoder.parameters())

# NOTE(review): `checkpoint_dir` is not defined in any visible cell; it presumably
# comes from `from src.lstm_trainer import *` — confirm, or define it explicitly
# before this cell so Restart & Run All works.

# Training engine.
train_engine = create_supervised_trainer_lstm(
    lstm_encoder,
    optimizer,
    criterion,
    device=device,
    prepare_batch=prepare_batch_lstm,
    checkpoint_dir=checkpoint_dir,
)

# Validation engine (no ignite metrics are attached here).
val_engine = create_supervised_evaluator_lstm(
    lstm_encoder,
    device=device,
    prepare_batch=prepare_batch_lstm,
    metrics={},
    criterion=criterion,
    checkpoint_dir=checkpoint_dir,
)


In [16]:
@train_engine.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    """Run validation after every training epoch and upload the best checkpoint to W&B.

    Args:
        trainer: the engine that fired EPOCH_COMPLETED (supplied by ignite). It is
            not read here; the module-level ``train_engine`` is used instead.
    """
    # Tell the evaluator which training epoch it is scoring, then validate.
    # `val_loader` is created in a later cell; it only has to exist by the time
    # training actually runs.
    finished_epoch = train_engine.state.epoch
    val_engine.train_epoch = finished_epoch
    val_engine.run(val_loader)

    # NOTE(review): val_engine is built with `metrics={}`, so 'accuracy' must be
    # populated inside create_supervised_evaluator_lstm — confirm. `wandb` is not
    # imported in any visible cell, and "best_accuracy" presumably has to exist in
    # the run summary already, or this lookup fails — verify.
    val_accuracy = val_engine.state.metrics['accuracy']
    best_so_far = wandb.run.summary["best_accuracy"]
    if val_accuracy > best_so_far:
        wandb.save(str(checkpoint_dir) + '/best_model*')

In [17]:
# Load the pre-processed event table.
# NOTE(review): unpickling can execute arbitrary code — only load a trusted df.pkl.
df = pd.read_pickle('./df.pkl')

# Pad every action-id sequence to a common length and stack them into one array.
df.action_id = CertDataset.pad_to_length(df.action_id)
actions = np.vstack(df.action_id.values)

# Maximum length passed to the topic-matrix padding.
TOPIC_MAX_LENGTH = 200

# Build the content-topic input only when the encoder is configured to use it.
# (A bare `params['model']` check is always truthy for a non-empty config, which
# made topic padding run even with use_content_topics=False.)
content = None
if params['model']['lstm_encoder']['use_content_topics']:
    content_values = CertDataset.pad_topic_matricies(df.content.values, max_length=TOPIC_MAX_LENGTH)
    # Feed the padded topics to the dataset rather than the raw, unpadded column.
    content = content_values

cert_dataset = CertDataset(actions, df.malicious, content_topics=content)
train_loader, val_loader = create_data_loaders(
    cert_dataset, validation_split=0.3, random_seed=0, batch_size=1024,
)



TypeError: 'int' object is not subscriptable