In [1]:
import os

import numpy as np
import pandas as pd
import torch

from classification.models.cnn import CNNClassifier
from classification.models.rnn import RNNClassifier
from classification.models.lstm import LSTMClassifier


In [2]:
# Path to the saved LSTM checkpoint that this notebook evaluates.
model_fn = './lstm.pth'

In [3]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
def load_data(path="./data/test.csv"):
    """Read the test CSV into a float32 tensor with every value divided by 255.

    Args:
        path: CSV file to read. Defaults to ``./data/test.csv``, the path the
            notebook originally hard-coded.

    Returns:
        A ``torch.float32`` tensor of shape ``(n_rows, n_columns)`` containing
        ``value / 255`` for every cell of the CSV.
    """
    test_df = pd.read_csv(path)
    # Force float32: if the CSV parses as float64 (e.g. a column with NaNs),
    # torch.tensor would otherwise keep float64, which does not match the
    # default float32 model weights. Dividing by 255 assumes 8-bit pixel
    # intensities (0-255) -- TODO confirm against the data source.
    return torch.tensor(test_df.values, dtype=torch.float32) / 255

In [5]:
def load_model(model_fn, device):
    """Load a checkpoint and return its ``'model'`` and ``'config'`` entries.

    Args:
        model_fn: path of the checkpoint file passed to ``torch.load``.
        device: used as ``map_location`` so stored tensors are remapped onto it.

    Returns:
        A ``(checkpoint['model'], checkpoint['config'])`` tuple.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints. On recent torch versions the default weights_only=True may
    # refuse a non-tensor config object (e.g. an argparse.Namespace); confirm
    # against the script that saved this checkpoint.
    checkpoint = torch.load(model_fn, map_location=device)
    state_dict = checkpoint['model']
    config = checkpoint['config']
    return state_dict, config

In [6]:
def test(model, x, config):
    model.eval()

    test_pred = []
    with torch.no_grad():
        test_X_ = x.split(config.batch_size, dim=0)

        for x_i in test_X_:
            y_pred = model(x_i)
            test_pred.append(y_pred)

    test_pred = torch.cat(test_pred, dim=0)
    
    return test_pred

In [7]:
# Reshape each CSV row into a (28, 28) sequence: 28 steps of 28 features each,
# matching input_size=28 below (assumes 784 columns per row -- TODO confirm).
test_X = load_data().reshape(-1, 28, 28).to(device)

# The checkpoint holds both the weights and the config used to build the model.
state_dict, config = load_model(model_fn, device)

# Rebuild the network from the saved hyperparameters, then restore its weights.
# output_size=10 presumably means ten digit classes -- confirm against training.
lstm_kwargs = {
    'input_size': 28,
    'hidden_size': config.hidden_size,
    'output_size': 10,
    'n_layers': config.n_layers,
    'dropout_p': config.dropout_p,
}
model = LSTMClassifier(**lstm_kwargs)
model = model.to(device)
model.load_state_dict(state_dict)

test_pred = test(model, test_X, config)

In [8]:
# Build the submission table: a 1-based ImageId per test row, and the arg-max
# over the last dimension of the model outputs as the predicted Label.
pred_df = pd.DataFrame({
    "ImageId": np.arange(test_X.size(0)) + 1,
    "Label": torch.argmax(test_pred, dim=-1).cpu().numpy(),
})

# DataFrame.to_csv does not create missing parent directories, so make sure the
# submissions folder exists before writing (e.g. on a fresh checkout).
submission_fn = './submissions/lstm.csv'
os.makedirs(os.path.dirname(submission_fn), exist_ok=True)
pred_df.to_csv(submission_fn, index=False)