# Neural Network for CLA Project

### Import statements

In [1]:
from sklearn import preprocessing
from sklearn import model_selection
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as utils
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import errno
import os
import sys
import Constants

### Hyperparameters

In [2]:
# data processing
sample_bias = 0     # adjust the difference in the number of the two types of samples (no algae vs algae)
test_size = 0.2
batch_size = 100    # batch size for the DataLoaders

# NN model
num_features = 17
input_size = num_features     # size of input layer
multiplier = 100                # multiplied by num_features to determine the size of each hidden layer
hidden_size = multiplier * input_size
output_size = 1
learning_rate = 0.01         # learning rate of optimizer
num_epochs = 100                # number of epochs

### Read in data

In [3]:
np.set_printoptions(threshold=np.inf)  # prints a full matrix rather than an abbreviated matrix

# define data and destination paths
dest_path = "/Users/Alliot/Documents/CLA-Project/Data/all-data-no-na/neural-network/"
data_path = "/Users/Alliot/Documents/CLA-Project/Data/data-sets/"
data_set = "data_2017_summer"

# if dest_path does not exist, create it
if not os.path.exists(dest_path):
    try:
        os.makedirs(dest_path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

# load data sets
X = np.load(data_path + data_set + ".npy")
y = np.load(data_path + data_set + "_labels.npy")

# manipulate data set. labels are converted to -1, +1 for binary classification; samples are removed uniformly 
# from the data set so that the disproportionately large number of negative samples (no algae) does 
# not bias the model.

num_alg = 0  # count the number of algae instances
num_no_alg = 0  # count the number of no algae instances

# Convert labels to binary: -1 for no algae and 1 for algae
for i in range(0, len(y)):
    if y[i] == 0:
        num_no_alg += 1
    if y[i] == 1 or y[i] == 2:
        y[i] = 1
        num_alg += 1

# oversample the data set by randomly adding occurences of algae until the difference between the number of algae
# samples and no algae samples equals sample_bias (defined below)
idx = 0
sample_bias = 0
length_y = len(y)
while num_alg != (num_no_alg + sample_bias):
    # circle through the data sets until the difference of num_no_alg and num_alg equals
    # the value specified by sample_bias
    if idx == (length_y - 1):
        idx = 0

    if y[idx] == 1:
        if np.random.rand() >= 0.5:  # add this sample with some probability
            y = np.append(y, y[idx])
            X = np.append(X, np.reshape(X[idx, :], newshape=(1, num_features)), axis=0)
            num_alg += 1
        else:
            idx += 1
    else:
        idx += 1

### Process and split data set

In [4]:
# standardize data: remove the mean and variance in each sample
num_splits = 2   # do not change
sss = model_selection.StratifiedShuffleSplit(n_splits=num_splits, test_size=test_size)

idx, _ = sss.split(X, y);
train_idx = idx[0]
test_idx = idx[1]
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

X_train = preprocessing.scale(X_train, axis=1, with_mean=True, with_std=True)
X_test = preprocessing.scale(X_test, axis=1, with_mean=True, with_std=True)
# X_train = preprocessing.normalize(X_train, norm="l2", axis=1) # Attempted normalization rather than standardization
# X_test = preprocessing.normalize(X_test, norm="l2", axis=1)

# convert numpy arrays to pytorch tensors
train_set_size = X_train.shape
test_set_size = X_test.shape
X_train, X_test = torch.from_numpy(X_train), torch.from_numpy(X_test)
y_train, y_test = torch.from_numpy(y_train), torch.from_numpy(y_test)

# convert pytorch tensors to pytorch TensorDataset
train_set = utils.TensorDataset(X_train, y_train)
test_set = utils.TensorDataset(X_test, y_test)

# create DataLoaders
train_loader = utils.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = utils.DataLoader(test_set, batch_size=test_set_size[0], shuffle=True)

### Define neural network model

In [94]:
class CLANet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CLANet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(hidden_size, hidden_size)
        self.relu4 = nn.ReLU()
        self.fc5 = nn.Linear(hidden_size, hidden_size)
        self.relu5 = nn.ReLU()
        self.fc6 = nn.Linear(hidden_size, output_size)
#         self.relu6 = nn.ReLU()
#         self.fc7 = nn.Linear(hidden_size, hidden_size)
#         self.relu7 = nn.ReLU()
#         self.fc8 = nn.Linear(hidden_size, hidden_size)
#         self.relu8 = nn.ReLU()
#         self.fc9 = nn.Linear(hidden_size, hidden_size)
#         self.relu9 = nn.ReLU()
#         self.fc10 = nn.Linear(hidden_size, hidden_size)
#         self.relu10 = nn.ReLU()
#         self.fc11 = nn.Linear(hidden_size, hidden_size)
#         self.relu11 = nn.ReLU()
#         self.fc12 = nn.Linear(hidden_size, output_size)
        self.sig1 = nn.Sigmoid()
        
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu1(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        out = self.relu3(out)
        out = self.fc4(out)
        out = self.relu4(out)
        out = self.fc5(out)
        out = self.relu5(out)
        out = self.fc6(out)
#         out = self.relu6(out)
#         out = self.fc7(out)
#         out = self.relu7(out)
#         out = self.fc8(out)
#         out = self.relu8(out)
#         out = self.fc9(out)
#         out = self.relu9(out)
#         out = self.fc10(out)
#         out = self.relu10(out)
#         out = self.fc11(out)
#         out = self.relu11(out)
#         out = self.fc12(out)
        out = self.sig1(out)
        return out

### Instantiate the neural network

In [95]:
model = CLANet(input_size, hidden_size, output_size)
criterion = nn.BCELoss()

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, nesterov=True, momentum=1, dampening=0)
model.double();     # cast model parameters to double

### Train the neural network

In [96]:
model.train()     # training mode
training_loss = []
avg_error = 0

for epoch in range(num_epochs):
    print("Epoch: %d/%d" % (epoch+1, num_epochs))

    for i, (samples, labels) in enumerate(train_loader):
        samples = Variable(samples)
        labels = Variable(labels)
        output = model(samples)                # forward pass
        output = torch.flatten(output)         # resize predicted labels
        labels = labels.type(torch.DoubleTensor)
        
        loss = criterion(output, labels)  # calculate loss
        optimizer.zero_grad()     # clear gradient
        loss.backward()           # calculate gradients
        optimizer.step()          # update weights
        
        # calculate and print error
        out = output

        for j in range(0, out.size()[0]):
            if out[j] < 0.5:
                out[j] = 0
            else:
                out[j] = 1
        error = 1 - torch.sum(output == labels).item() / labels.size()[0]
        avg_error += error
        training_loss.append(loss.data.numpy())
        print("  Iteration: %d/%d, Loss: %g, Error: %0.4f" % 
              (i+1, np.ceil(X_train.size()[0] / batch_size).astype(int), loss.item(), error))
    
    print("Average Error for this Epoch: %0.4f" % (avg_error / np.ceil(X_train.size()[0] / batch_size)))
    avg_error = 0

Epoch: 1/100
  Iteration: 1/17, Loss: 0.693813, Error: 0.5300
  Iteration: 2/17, Loss: 0.69355, Error: 0.5300
  Iteration: 3/17, Loss: 0.69278, Error: 0.4900
  Iteration: 4/17, Loss: 0.692035, Error: 0.4200
  Iteration: 5/17, Loss: 0.693409, Error: 0.5100
  Iteration: 6/17, Loss: 0.692825, Error: 0.4900
  Iteration: 7/17, Loss: 0.692656, Error: 0.4700
  Iteration: 8/17, Loss: 0.694561, Error: 0.5600
  Iteration: 9/17, Loss: 0.691966, Error: 0.4500
  Iteration: 10/17, Loss: 0.69322, Error: 0.5000
  Iteration: 11/17, Loss: 0.694165, Error: 0.5700
  Iteration: 12/17, Loss: 0.691771, Error: 0.4300
  Iteration: 13/17, Loss: 0.692568, Error: 0.4800
  Iteration: 14/17, Loss: 0.694196, Error: 0.5600
  Iteration: 15/17, Loss: 0.69433, Error: 0.5600
  Iteration: 16/17, Loss: 0.692121, Error: 0.4100
  Iteration: 17/17, Loss: 0.693578, Error: 0.5645
Average Error for this Epoch: 0.5014
Epoch: 2/100
  Iteration: 1/17, Loss: 0.692333, Error: 0.4400
  Iteration: 2/17, Loss: 0.692697, Error: 0.4800
  

  Iteration: 5/17, Loss: 0.680236, Error: 0.5000
  Iteration: 6/17, Loss: 0.684963, Error: 0.4500
  Iteration: 7/17, Loss: 0.687991, Error: 0.3900
  Iteration: 8/17, Loss: 0.691435, Error: 0.5000
  Iteration: 9/17, Loss: 0.682706, Error: 0.4400
  Iteration: 10/17, Loss: 0.691881, Error: 0.4500
  Iteration: 11/17, Loss: 0.691447, Error: 0.4400
  Iteration: 12/17, Loss: 0.684308, Error: 0.4000
  Iteration: 13/17, Loss: 0.684504, Error: 0.4500
  Iteration: 14/17, Loss: 0.684257, Error: 0.4000
  Iteration: 15/17, Loss: 0.688283, Error: 0.5000
  Iteration: 16/17, Loss: 0.679223, Error: 0.3500
  Iteration: 17/17, Loss: 0.686241, Error: 0.4355
Average Error for this Epoch: 0.4468
Epoch: 11/100
  Iteration: 1/17, Loss: 0.678038, Error: 0.4100
  Iteration: 2/17, Loss: 0.675851, Error: 0.4100
  Iteration: 3/17, Loss: 0.691015, Error: 0.4900
  Iteration: 4/17, Loss: 0.680186, Error: 0.3900
  Iteration: 5/17, Loss: 0.694188, Error: 0.4900
  Iteration: 6/17, Loss: 0.676723, Error: 0.4300
  Iteratio

  Iteration: 9/17, Loss: 0.658651, Error: 0.3300
  Iteration: 10/17, Loss: 0.627043, Error: 0.3400
  Iteration: 11/17, Loss: 0.597907, Error: 0.3000
  Iteration: 12/17, Loss: 0.609641, Error: 0.3100
  Iteration: 13/17, Loss: 0.689738, Error: 0.3600
  Iteration: 14/17, Loss: 0.626413, Error: 0.3400
  Iteration: 15/17, Loss: 0.647213, Error: 0.3500
  Iteration: 16/17, Loss: 0.606941, Error: 0.3400
  Iteration: 17/17, Loss: 0.622965, Error: 0.3548
Average Error for this Epoch: 0.3362
Epoch: 20/100
  Iteration: 1/17, Loss: 0.591702, Error: 0.3100
  Iteration: 2/17, Loss: 0.543498, Error: 0.2900
  Iteration: 3/17, Loss: 0.609542, Error: 0.3300
  Iteration: 4/17, Loss: 0.619013, Error: 0.3300
  Iteration: 5/17, Loss: 0.6673, Error: 0.4000
  Iteration: 6/17, Loss: 0.60594, Error: 0.2600
  Iteration: 7/17, Loss: 0.648164, Error: 0.3500
  Iteration: 8/17, Loss: 0.64359, Error: 0.4000
  Iteration: 9/17, Loss: 0.650016, Error: 0.3500
  Iteration: 10/17, Loss: 0.669317, Error: 0.4100
  Iteration: 

  Iteration: 13/17, Loss: 0.62507, Error: 0.3400
  Iteration: 14/17, Loss: 0.63963, Error: 0.3200
  Iteration: 15/17, Loss: 0.5901, Error: 0.3200
  Iteration: 16/17, Loss: 0.613476, Error: 0.2900
  Iteration: 17/17, Loss: 0.581046, Error: 0.2903
Average Error for this Epoch: 0.3124
Epoch: 29/100
  Iteration: 1/17, Loss: 0.554993, Error: 0.2600
  Iteration: 2/17, Loss: 0.570855, Error: 0.2700
  Iteration: 3/17, Loss: 0.580716, Error: 0.2700
  Iteration: 4/17, Loss: 0.643869, Error: 0.3500
  Iteration: 5/17, Loss: 0.645922, Error: 0.3700
  Iteration: 6/17, Loss: 0.626005, Error: 0.2700
  Iteration: 7/17, Loss: 0.557418, Error: 0.2600
  Iteration: 8/17, Loss: 0.585274, Error: 0.2800
  Iteration: 9/17, Loss: 0.674442, Error: 0.4000
  Iteration: 10/17, Loss: 0.584374, Error: 0.3100
  Iteration: 11/17, Loss: 0.482565, Error: 0.2000
  Iteration: 12/17, Loss: 0.604672, Error: 0.3200
  Iteration: 13/17, Loss: 0.606602, Error: 0.3400
  Iteration: 14/17, Loss: 0.553397, Error: 0.2900
  Iteration:

  Iteration: 17/17, Loss: 0.540085, Error: 0.2903
Average Error for this Epoch: 0.2824
Epoch: 38/100
  Iteration: 1/17, Loss: 0.56621, Error: 0.3400
  Iteration: 2/17, Loss: 0.545155, Error: 0.2600
  Iteration: 3/17, Loss: 0.633865, Error: 0.3700
  Iteration: 4/17, Loss: 0.560138, Error: 0.2500
  Iteration: 5/17, Loss: 0.51814, Error: 0.2400
  Iteration: 6/17, Loss: 0.528979, Error: 0.2400
  Iteration: 7/17, Loss: 0.53762, Error: 0.2500
  Iteration: 8/17, Loss: 0.512958, Error: 0.2300
  Iteration: 9/17, Loss: 0.573457, Error: 0.3000
  Iteration: 10/17, Loss: 0.590333, Error: 0.3200
  Iteration: 11/17, Loss: 0.545952, Error: 0.2700
  Iteration: 12/17, Loss: 0.601458, Error: 0.2900
  Iteration: 13/17, Loss: 0.680226, Error: 0.3600
  Iteration: 14/17, Loss: 0.561647, Error: 0.2400
  Iteration: 15/17, Loss: 0.545659, Error: 0.3100
  Iteration: 16/17, Loss: 0.684605, Error: 0.4000
  Iteration: 17/17, Loss: 0.510153, Error: 0.2581
Average Error for this Epoch: 0.2899
Epoch: 39/100
  Iteratio

  Iteration: 3/17, Loss: 0.57602, Error: 0.3100
  Iteration: 4/17, Loss: 0.606615, Error: 0.3200
  Iteration: 5/17, Loss: 0.540779, Error: 0.2800
  Iteration: 6/17, Loss: 0.494654, Error: 0.2200
  Iteration: 7/17, Loss: 0.55033, Error: 0.2800
  Iteration: 8/17, Loss: 0.465843, Error: 0.1800
  Iteration: 9/17, Loss: 0.517816, Error: 0.2600
  Iteration: 10/17, Loss: 0.63337, Error: 0.3000
  Iteration: 11/17, Loss: 0.562905, Error: 0.3000
  Iteration: 12/17, Loss: 0.696611, Error: 0.3700
  Iteration: 13/17, Loss: 0.599304, Error: 0.2900
  Iteration: 14/17, Loss: 0.572099, Error: 0.3200
  Iteration: 15/17, Loss: 0.593779, Error: 0.2900
  Iteration: 16/17, Loss: 0.499191, Error: 0.2400
  Iteration: 17/17, Loss: 0.587616, Error: 0.2097
Average Error for this Epoch: 0.2812
Epoch: 48/100
  Iteration: 1/17, Loss: 0.552853, Error: 0.3100
  Iteration: 2/17, Loss: 0.544484, Error: 0.2600
  Iteration: 3/17, Loss: 0.57318, Error: 0.3300
  Iteration: 4/17, Loss: 0.465394, Error: 0.1700
  Iteration: 5

  Iteration: 7/17, Loss: 0.541268, Error: 0.2900
  Iteration: 8/17, Loss: 0.447377, Error: 0.1900
  Iteration: 9/17, Loss: 0.610098, Error: 0.3000
  Iteration: 10/17, Loss: 0.434426, Error: 0.1800
  Iteration: 11/17, Loss: 0.560078, Error: 0.3100
  Iteration: 12/17, Loss: 0.489008, Error: 0.2400
  Iteration: 13/17, Loss: 0.584909, Error: 0.3000
  Iteration: 14/17, Loss: 0.497222, Error: 0.2500
  Iteration: 15/17, Loss: 0.479106, Error: 0.2200
  Iteration: 16/17, Loss: 0.563383, Error: 0.3000
  Iteration: 17/17, Loss: 0.601617, Error: 0.3387
Average Error for this Epoch: 0.2582
Epoch: 57/100
  Iteration: 1/17, Loss: 0.50189, Error: 0.2600
  Iteration: 2/17, Loss: 0.466868, Error: 0.2200
  Iteration: 3/17, Loss: 0.505893, Error: 0.2500
  Iteration: 4/17, Loss: 0.447518, Error: 0.1600
  Iteration: 5/17, Loss: 0.503874, Error: 0.2400
  Iteration: 6/17, Loss: 0.61873, Error: 0.3600
  Iteration: 7/17, Loss: 0.541705, Error: 0.2700
  Iteration: 8/17, Loss: 0.537146, Error: 0.2500
  Iteration:

  Iteration: 11/17, Loss: 0.605363, Error: 0.3500
  Iteration: 12/17, Loss: 0.515061, Error: 0.2400
  Iteration: 13/17, Loss: 0.520283, Error: 0.2500
  Iteration: 14/17, Loss: 0.479226, Error: 0.2500
  Iteration: 15/17, Loss: 0.507314, Error: 0.2600
  Iteration: 16/17, Loss: 0.563163, Error: 0.1900
  Iteration: 17/17, Loss: 0.442618, Error: 0.1290
Average Error for this Epoch: 0.2382
Epoch: 66/100
  Iteration: 1/17, Loss: 0.516264, Error: 0.2400
  Iteration: 2/17, Loss: 0.463478, Error: 0.1900
  Iteration: 3/17, Loss: 0.476732, Error: 0.2400
  Iteration: 4/17, Loss: 0.477599, Error: 0.2200
  Iteration: 5/17, Loss: 0.437971, Error: 0.2100
  Iteration: 6/17, Loss: 0.559676, Error: 0.2400
  Iteration: 7/17, Loss: 0.526186, Error: 0.2500
  Iteration: 8/17, Loss: 0.488625, Error: 0.2500
  Iteration: 9/17, Loss: 0.486903, Error: 0.1900
  Iteration: 10/17, Loss: 0.506434, Error: 0.2700
  Iteration: 11/17, Loss: 0.409562, Error: 0.1700
  Iteration: 12/17, Loss: 0.51899, Error: 0.2500
  Iterati

  Iteration: 15/17, Loss: 0.503194, Error: 0.2500
  Iteration: 16/17, Loss: 0.518472, Error: 0.2200
  Iteration: 17/17, Loss: 0.501722, Error: 0.2742
Average Error for this Epoch: 0.2561
Epoch: 75/100
  Iteration: 1/17, Loss: 0.44349, Error: 0.2500
  Iteration: 2/17, Loss: 0.500707, Error: 0.2900
  Iteration: 3/17, Loss: 0.613365, Error: 0.3100
  Iteration: 4/17, Loss: 0.520904, Error: 0.2700
  Iteration: 5/17, Loss: 0.440775, Error: 0.2300
  Iteration: 6/17, Loss: 0.481174, Error: 0.2000
  Iteration: 7/17, Loss: 0.4948, Error: 0.2400
  Iteration: 8/17, Loss: 0.543445, Error: 0.3100
  Iteration: 9/17, Loss: 0.478958, Error: 0.2500
  Iteration: 10/17, Loss: 0.501005, Error: 0.2800
  Iteration: 11/17, Loss: 0.417973, Error: 0.2200
  Iteration: 12/17, Loss: 0.433113, Error: 0.2000
  Iteration: 13/17, Loss: 0.513846, Error: 0.2800
  Iteration: 14/17, Loss: 0.539725, Error: 0.2800
  Iteration: 15/17, Loss: 0.528417, Error: 0.2200
  Iteration: 16/17, Loss: 0.503469, Error: 0.2400
  Iteration

  Iteration: 1/17, Loss: 0.599247, Error: 0.3100
  Iteration: 2/17, Loss: 0.547933, Error: 0.2700
  Iteration: 3/17, Loss: 0.591132, Error: 0.3400
  Iteration: 4/17, Loss: 0.480837, Error: 0.2500
  Iteration: 5/17, Loss: 0.548085, Error: 0.2700
  Iteration: 6/17, Loss: 0.476773, Error: 0.2500
  Iteration: 7/17, Loss: 0.415687, Error: 0.2200
  Iteration: 8/17, Loss: 0.423192, Error: 0.1600
  Iteration: 9/17, Loss: 0.427607, Error: 0.1900
  Iteration: 10/17, Loss: 0.468899, Error: 0.1900
  Iteration: 11/17, Loss: 0.653418, Error: 0.3800
  Iteration: 12/17, Loss: 0.513571, Error: 0.2900
  Iteration: 13/17, Loss: 0.560954, Error: 0.3000
  Iteration: 14/17, Loss: 0.634632, Error: 0.2500
  Iteration: 15/17, Loss: 0.51334, Error: 0.2900
  Iteration: 16/17, Loss: 0.493079, Error: 0.2300
  Iteration: 17/17, Loss: 0.607221, Error: 0.3226
Average Error for this Epoch: 0.2654
Epoch: 85/100
  Iteration: 1/17, Loss: 0.525949, Error: 0.3200
  Iteration: 2/17, Loss: 0.620022, Error: 0.3300
  Iteration

  Iteration: 5/17, Loss: 0.477175, Error: 0.2700
  Iteration: 6/17, Loss: 0.536114, Error: 0.2300
  Iteration: 7/17, Loss: 0.548975, Error: 0.3200
  Iteration: 8/17, Loss: 0.518829, Error: 0.2500
  Iteration: 9/17, Loss: 0.582777, Error: 0.2700
  Iteration: 10/17, Loss: 0.499977, Error: 0.2400
  Iteration: 11/17, Loss: 0.428713, Error: 0.2600
  Iteration: 12/17, Loss: 0.543405, Error: 0.2800
  Iteration: 13/17, Loss: 0.459673, Error: 0.2100
  Iteration: 14/17, Loss: 0.555832, Error: 0.2900
  Iteration: 15/17, Loss: 0.422198, Error: 0.1900
  Iteration: 16/17, Loss: 0.616565, Error: 0.3000
  Iteration: 17/17, Loss: 0.619661, Error: 0.2581
Average Error for this Epoch: 0.2487
Epoch: 94/100
  Iteration: 1/17, Loss: 0.438406, Error: 0.1900
  Iteration: 2/17, Loss: 0.524647, Error: 0.2600
  Iteration: 3/17, Loss: 0.54447, Error: 0.2300
  Iteration: 4/17, Loss: 0.427947, Error: 0.2000
  Iteration: 5/17, Loss: 0.508713, Error: 0.2400
  Iteration: 6/17, Loss: 0.48404, Error: 0.2600
  Iteration:

### Evaluate Model on Testing Set

In [104]:
model.eval()

for i, (samples, labels) in enumerate(test_loader):
    samples = Variable(samples)
    labels = Variable(labels)
    predictions = model(samples)
    predictions = torch.flatten(predictions)
    labels = labels.type(torch.DoubleTensor)

    for j in range(0, predictions.size()[0]):
        if predictions[j] < 0.5:
            predictions[j] = 0
        else:
            predictions[j] = 1
    
    error = 1 - torch.sum(predictions == labels).item() / labels.size()[0]
    
    print("Testing set Error: %0.4f" % error)
    
model_path = "./torch_model_2_10_19_test_error=" + str(error) + ".pt"

Testing set Error: 0.2500


### Save Model

In [None]:
torch.save(model, model_path)

### Load and Evaluate previous models

In [90]:
model = CLANet(input_size, hidden_size, output_size)
model.load_state_dict(torch.load("torch_model_2_10_19_test_error=0.2380.pt"))
model.double();     # cast model parameters to double
model.eval()

for i, (samples, labels) in enumerate(test_loader):
    samples = Variable(samples)
    labels = Variable(labels)
    predictions = model(samples)
    predictions = torch.flatten(predictions)
    labels = labels.type(torch.DoubleTensor)

    for j in range(0, predictions.size()[0]):
        if predictions[j] < 0.5:
            predictions[j] = 0
        else:
            predictions[j] = 1
    
    error = 1 - torch.sum(predictions == labels).item() / labels.size()[0]
    
    print("Testing set Error: %0.4f" % error)

Testing set Error: 0.2500
