### Single-input Network - PyTorch

Predicts the action from a single observation. This serves as a check of whether the dataset is 'difficult' to learn.

### Imports

In [None]:
"""
Imports external and own libraries
"""

import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader

from prettytable import PrettyTable

# own
import common.action as action
import common.world as world
import common.plot as plot
import common.preprocess as preprocess
import common.nets as nets
import common.train as train
import common.tools as tools

### Load datasets

In [None]:
def _load_pickle(path):
    """Return the object stored in the pickle file at *path*."""
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Each dataset is read from its own pickle file under datasets/.
oracle_data = _load_pickle("datasets/oracle_data.pickle")
oracle_reversed_data = _load_pickle("datasets/oracle_reversed_data.pickle")
oracle_random_data = _load_pickle("datasets/oracle_random_data.pickle")
oracle_reversed_random_data = _load_pickle(
    "datasets/oracle_reversed_random_data.pickle"
)
random_data = _load_pickle("datasets/random_data.pickle")
oracle_reversed_random_data_small = _load_pickle(
    "datasets/oracle_reversed_random_data_small.pickle"
)
tmaze_random_reverse_data = _load_pickle(
    "datasets/tmaze_random_reverse_data.pickle"
)

### Preprocess data

In [None]:
# Dataset used for this run.
data = oracle_reversed_random_data_small

# Split into a training and a test part (split_data is called with 0.8).
train_data, test_data = preprocess.split_data(data, 0.8)

# Wrap both splits as datasets keyed on "observations" / "actions"
# (training split first, then test split).
oracle_train_data, oracle_test_data = (
    preprocess.ObtainDataset(split, "observations", "actions")
    for split in (train_data, test_data)
)

# Build one shuffled DataLoader per split, both with the same batch size.
batch_size = 128
dataset_loader_train_data, dataset_loader_test_data = (
    DataLoader(subset, batch_size=batch_size, shuffle=True)
    for subset in (oracle_train_data, oracle_test_data)
)

In [None]:
# NOTE(review): this cell rebinds train_data, test_data, oracle_*_data and
# both loaders, so whichever preprocessing cell ran last decides what the
# rest of the notebook sees.
n = 4

# Change the sequence length, recode the actions, then split the dataset.
# NOTE(review): the exact semantics of split_n_steps_between and
# recode_actions are defined in common.preprocess - confirm there.
dataset = preprocess.split_n_steps_between(oracle_reversed_random_data_small, n=n)
# Bind the recoded result back to `dataset`. The original assigned it to a
# misspelled name (`dataet`), so the value returned by recode_actions was
# never used by the split below.
dataset, counter, translation_dict = preprocess.recode_actions(dataset, n)
train_data, test_data = preprocess.split_data(dataset, 0.8)

# Wrap both splits as dual-observation datasets
# ("observationsA", "observationsB") with "actions" as the target key.
oracle_train_data = preprocess.ObtainDualDataset(
    train_data, "observationsA", "observationsB", "actions"
)
oracle_test_data = preprocess.ObtainDualDataset(
    test_data, "observationsA", "observationsB", "actions"
)

# Build the dataloaders (tensor format), shuffled, with a shared batch size.
batch_size = 128
dataset_loader_train_data = DataLoader(
    oracle_train_data, batch_size=batch_size, shuffle=True
)
dataset_loader_test_data = DataLoader(
    oracle_test_data, batch_size=batch_size, shuffle=True
)

### Visualize a batch of data (64 observations)

In [None]:
# Pull one batch from the training loader and plot its observations.
dataiter = iter(dataset_loader_train_data)
# Use the built-in next(): DataLoader iterators implement __next__, and the
# Python 2 style `.next()` method is not available on current PyTorch.
images, labels = next(dataiter)
plot.plot_64_observations(images)

### Initialize the model

In [None]:
# Per-sample input size passed to torchsummary.
# NOTE(review): presumably (channels, height, width) - confirm.
input_shape = (3, 32, 32)

# NOTE(review): the argument 4 is presumably the number of action classes;
# confirm against nets.Forward.
forward = nets.Forward(4)

# Print a layer-by-layer summary of the network.
summary(forward, input_shape)

### Train model

In [None]:
# Training hyperparameters.
episodes = 500
# The original author noted that 0.01 "works well" for the "small one"
# (presumably the small dataset - confirm), and left a momentum=0.9
# argument commented out.
learning_rate = 0.001

# Cross-entropy loss, optimised with Adam over all parameters of the network.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(forward.parameters(), lr=learning_rate)

# Train the network; the call returns the model together with the loss and
# accuracy histories for the training and test loaders.
forward, train_loss, test_loss, train_acc, test_acc = train.train_Feedforward(
    dataset_loader_train_data,
    dataset_loader_test_data,
    forward,
    criterion,
    optimizer,
    episodes,
)

### Plot loss and accuracy curves for training and test set

In [None]:
# Plot the training and test loss histories returned by the training call.
plot.plot_losses(train_loss, test_loss)
# Plot the training and test accuracy histories; smooth=True is passed
# through to the plotting helper.
plot.plot_acc(train_acc, test_acc, smooth=True)

### Plot example classifications and plot confusion matrix

In [None]:
# Show example classifications of the model `forward` (amount=8).
# NOTE(review): the helper name is spelled `show_example_classificataions`;
# presumably this matches the definition in common.plot - confirm before
# renaming it here.
plot.show_example_classificataions(dataset_loader_train_data, forward, amount=8)
# Plot the confusion matrix of `forward`.
# NOTE(review): both calls use the training loader, not the test loader -
# confirm this is intended.
plot.plot_confusion_matrix(dataset_loader_train_data, forward)

### Save and load models

In [None]:
# Save the trained network. The model object in this notebook is `forward`;
# the original referenced an undefined name `model`, which raises NameError.
# NOTE: torch.save on a whole module pickles it, so loading it later requires
# the class definitions from common.nets to be importable.
torch.save(forward, "models/Feedforward_Step1.pt")

# Load (same path as saved above)
# NOTE(review): recent PyTorch releases may need weights_only=False to load a
# fully pickled module - confirm against the installed version.
# forward = torch.load("models/Feedforward_Step1.pt")
# forward.eval()