# Neural Network for CLA Project

### Import statements

In [3]:
from sklearn import preprocessing
from sklearn import model_selection
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as utils
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import errno
import os
import sys
import Constants

### Hyperparameters

In [6]:
# data processing
sample_bias = 0     # adjust the difference in the number of the two types of samples (no algae vs algae)
test_size = 0.2
batch_size = 100    # batch size for the DataLoaders

# NN model
num_features = 17
input_size = num_features     # size of input layer
multiplier = 100                # multiplied by num_features to determine the size of each hidden layer
hidden_size = multiplier * input_size
output_size = 1
learning_rate = 0.01         # learning rate of optimizer
num_epochs = 100                # number of epochs

### Read in data

In [3]:
np.set_printoptions(threshold=np.inf)  # prints a full matrix rather than an abbreviated matrix

# define data and destination paths
dest_path = "/Users/Alliot/Documents/CLA-Project/Data/all-data-no-na/neural-network/"
data_path = "/Users/Alliot/Documents/CLA-Project/Data/data-sets/"
data_set = "data_2017_summer"

# if dest_path does not exist, create it
if not os.path.exists(dest_path):
    try:
        os.makedirs(dest_path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

# load data sets
X = np.load(data_path + data_set + ".npy")
y = np.load(data_path + data_set + "_labels.npy")

# manipulate data set. labels are converted to -1, +1 for binary classification; samples are removed uniformly 
# from the data set so that the disproportionately large number of negative samples (no algae) does 
# not bias the model.

num_alg = 0  # count the number of algae instances
num_no_alg = 0  # count the number of no algae instances

# Convert labels to binary: -1 for no algae and 1 for algae
for i in range(0, len(y)):
    if y[i] == 0:
        num_no_alg += 1
    if y[i] == 1 or y[i] == 2:
        y[i] = 1
        num_alg += 1

# oversample the data set by randomly adding occurences of algae until the difference between the number of algae
# samples and no algae samples equals sample_bias (defined below)
idx = 0
sample_bias = 0
length_y = len(y)
while num_alg != (num_no_alg + sample_bias):
    # circle through the data sets until the difference of num_no_alg and num_alg equals
    # the value specified by sample_bias
    if idx == (length_y - 1):
        idx = 0

    if y[idx] == 1:
        if np.random.rand() >= 0.5:  # add this sample with some probability
            y = np.append(y, y[idx])
            X = np.append(X, np.reshape(X[idx, :], newshape=(1, num_features)), axis=0)
            num_alg += 1
        else:
            idx += 1
    else:
        idx += 1

### Process and split data set

In [5]:
# standardize data: remove the mean and variance in each sample
num_splits = 2   # do not change
sss = model_selection.StratifiedShuffleSplit(n_splits=num_splits, test_size=test_size)

idx, _ = sss.split(X, y);
train_idx = idx[0]
test_idx = idx[1]
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

X_train = preprocessing.scale(X_train, axis=1, with_mean=True, with_std=True)
X_test = preprocessing.scale(X_test, axis=1, with_mean=True, with_std=True)

# convert numpy arrays to pytorch tensors
train_set_size = X_train.shape
test_set_size = X_test.shape
X_train, X_test = torch.from_numpy(X_train), torch.from_numpy(X_test)
y_train, y_test = torch.from_numpy(y_train), torch.from_numpy(y_test)

# convert pytorch tensors to pytorch TensorDataset
train_set = utils.TensorDataset(X_train, y_train)
test_set = utils.TensorDataset(X_test, y_test)

# create DataLoaders
train_loader = utils.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = utils.DataLoader(test_set, batch_size=test_set_size[0], shuffle=True)

### Define neural network model

In [4]:
class CLANet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CLANet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(hidden_size, hidden_size)
        self.relu4 = nn.ReLU()
        self.fc5 = nn.Linear(hidden_size, hidden_size)
        self.relu5 = nn.ReLU()
        self.fc6 = nn.Linear(hidden_size, output_size)
#         self.relu6 = nn.ReLU()
#         self.fc7 = nn.Linear(hidden_size, hidden_size)
#         self.relu7 = nn.ReLU()
#         self.fc8 = nn.Linear(hidden_size, hidden_size)
#         self.relu8 = nn.ReLU()
#         self.fc9 = nn.Linear(hidden_size, hidden_size)
#         self.relu9 = nn.ReLU()
#         self.fc10 = nn.Linear(hidden_size, hidden_size)
#         self.relu10 = nn.ReLU()
#         self.fc11 = nn.Linear(hidden_size, hidden_size)
#         self.relu11 = nn.ReLU()
#         self.fc12 = nn.Linear(hidden_size, output_size)
        self.sig1 = nn.Sigmoid()
        
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu1(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        out = self.relu3(out)
        out = self.fc4(out)
        out = self.relu4(out)
        out = self.fc5(out)
        out = self.relu5(out)
        out = self.fc6(out)
#         out = self.relu6(out)
#         out = self.fc7(out)
#         out = self.relu7(out)
#         out = self.fc8(out)
#         out = self.relu8(out)
#         out = self.fc9(out)
#         out = self.relu9(out)
#         out = self.fc10(out)
#         out = self.relu10(out)
#         out = self.fc11(out)
#         out = self.relu11(out)
#         out = self.fc12(out)
        out = self.sig1(out)
        return out

### Instantiate the neural network

In [6]:
model = CLANet(input_size, hidden_size, output_size)
criterion = nn.BCELoss()

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, nesterov=True, momentum=1, dampening=0)
model.double();     # cast model parameters to double

### Train the neural network

In [7]:
model.train()     # training mode
training_loss = []
avg_error = 0

for epoch in range(num_epochs):
    print("Epoch: %d/%d" % (epoch+1, num_epochs))

    for i, (samples, labels) in enumerate(train_loader):
        samples = Variable(samples)
        labels = Variable(labels)
        output = model(samples)                # forward pass
        output = torch.flatten(output)         # resize predicted labels
        labels = labels.type(torch.DoubleTensor)
        
        loss = criterion(output, labels)  # calculate loss
        optimizer.zero_grad()     # clear gradient
        loss.backward()           # calculate gradients
        optimizer.step()          # update weights
        
        # calculate and print error
        out = output

        for j in range(0, out.size()[0]):
            if out[j] < 0.5:
                out[j] = 0
            else:
                out[j] = 1
        error = 1 - torch.sum(output == labels).item() / labels.size()[0]
        avg_error += error
        training_loss.append(loss.data.numpy())
        print("  Iteration: %d/%d, Loss: %g, Error: %0.4f" % 
              (i+1, np.ceil(X_train.size()[0] / batch_size).astype(int), loss.item(), error))
    
    print("Average Error for this Epoch: %0.4f" % (avg_error / np.ceil(X_train.size()[0] / batch_size)))
    avg_error = 0

Epoch: 1/100
  Iteration: 1/17, Loss: 0.694089, Error: 0.5700
  Iteration: 2/17, Loss: 0.693173, Error: 0.5000
  Iteration: 3/17, Loss: 0.693388, Error: 0.5100
  Iteration: 4/17, Loss: 0.69343, Error: 0.5100
  Iteration: 5/17, Loss: 0.693001, Error: 0.4200
  Iteration: 6/17, Loss: 0.693229, Error: 0.4700
  Iteration: 7/17, Loss: 0.692949, Error: 0.5300
  Iteration: 8/17, Loss: 0.693075, Error: 0.5100
  Iteration: 9/17, Loss: 0.693114, Error: 0.5000
  Iteration: 10/17, Loss: 0.692257, Error: 0.4600
  Iteration: 11/17, Loss: 0.695201, Error: 0.6000
  Iteration: 12/17, Loss: 0.693186, Error: 0.5000
  Iteration: 13/17, Loss: 0.693667, Error: 0.5300
  Iteration: 14/17, Loss: 0.692578, Error: 0.4800
  Iteration: 15/17, Loss: 0.694258, Error: 0.5500
  Iteration: 16/17, Loss: 0.694094, Error: 0.5400
  Iteration: 17/17, Loss: 0.69246, Error: 0.4677
Average Error for this Epoch: 0.5087
Epoch: 2/100
  Iteration: 1/17, Loss: 0.693867, Error: 0.5300
  Iteration: 2/17, Loss: 0.692359, Error: 0.4500


  Iteration: 5/17, Loss: 0.681162, Error: 0.4300
  Iteration: 6/17, Loss: 0.679254, Error: 0.4400
  Iteration: 7/17, Loss: 0.672279, Error: 0.4200
  Iteration: 8/17, Loss: 0.668323, Error: 0.3800
  Iteration: 9/17, Loss: 0.675795, Error: 0.4800
  Iteration: 10/17, Loss: 0.676638, Error: 0.4700
  Iteration: 11/17, Loss: 0.677124, Error: 0.4500
  Iteration: 12/17, Loss: 0.693609, Error: 0.4900
  Iteration: 13/17, Loss: 0.684081, Error: 0.4600
  Iteration: 14/17, Loss: 0.698402, Error: 0.4600
  Iteration: 15/17, Loss: 0.656563, Error: 0.3900
  Iteration: 16/17, Loss: 0.65774, Error: 0.4100
  Iteration: 17/17, Loss: 0.690987, Error: 0.5000
Average Error for this Epoch: 0.4424
Epoch: 11/100
  Iteration: 1/17, Loss: 0.663666, Error: 0.4500
  Iteration: 2/17, Loss: 0.709408, Error: 0.5100
  Iteration: 3/17, Loss: 0.649115, Error: 0.4000
  Iteration: 4/17, Loss: 0.664814, Error: 0.4300
  Iteration: 5/17, Loss: 0.677336, Error: 0.4600
  Iteration: 6/17, Loss: 0.680232, Error: 0.4500
  Iteration

  Iteration: 9/17, Loss: 0.612484, Error: 0.3800
  Iteration: 10/17, Loss: 0.62849, Error: 0.3900
  Iteration: 11/17, Loss: 0.632317, Error: 0.3300
  Iteration: 12/17, Loss: 0.624141, Error: 0.3600
  Iteration: 13/17, Loss: 0.642478, Error: 0.3400
  Iteration: 14/17, Loss: 0.631271, Error: 0.3500
  Iteration: 15/17, Loss: 0.59018, Error: 0.3600
  Iteration: 16/17, Loss: 0.636982, Error: 0.3500
  Iteration: 17/17, Loss: 0.588214, Error: 0.3226
Average Error for this Epoch: 0.3537
Epoch: 20/100
  Iteration: 1/17, Loss: 0.66545, Error: 0.4000
  Iteration: 2/17, Loss: 0.600554, Error: 0.3200
  Iteration: 3/17, Loss: 0.639912, Error: 0.3800
  Iteration: 4/17, Loss: 0.619547, Error: 0.3200
  Iteration: 5/17, Loss: 0.582456, Error: 0.3200
  Iteration: 6/17, Loss: 0.682587, Error: 0.4500
  Iteration: 7/17, Loss: 0.592994, Error: 0.3100
  Iteration: 8/17, Loss: 0.704534, Error: 0.4200
  Iteration: 9/17, Loss: 0.632918, Error: 0.3200
  Iteration: 10/17, Loss: 0.661331, Error: 0.3800
  Iteration:

  Iteration: 13/17, Loss: 0.652723, Error: 0.3900
  Iteration: 14/17, Loss: 0.636646, Error: 0.4100
  Iteration: 15/17, Loss: 0.563528, Error: 0.2700
  Iteration: 16/17, Loss: 0.572191, Error: 0.3300
  Iteration: 17/17, Loss: 0.556872, Error: 0.2903
Average Error for this Epoch: 0.3353
Epoch: 29/100
  Iteration: 1/17, Loss: 0.651697, Error: 0.3700
  Iteration: 2/17, Loss: 0.559136, Error: 0.3300
  Iteration: 3/17, Loss: 0.561257, Error: 0.2700
  Iteration: 4/17, Loss: 0.542695, Error: 0.2800
  Iteration: 5/17, Loss: 0.631738, Error: 0.3200
  Iteration: 6/17, Loss: 0.58831, Error: 0.3500
  Iteration: 7/17, Loss: 0.605717, Error: 0.3500
  Iteration: 8/17, Loss: 0.644227, Error: 0.4000
  Iteration: 9/17, Loss: 0.549128, Error: 0.3400
  Iteration: 10/17, Loss: 0.552546, Error: 0.2900
  Iteration: 11/17, Loss: 0.552059, Error: 0.2600
  Iteration: 12/17, Loss: 0.698357, Error: 0.4000
  Iteration: 13/17, Loss: 0.540155, Error: 0.2600
  Iteration: 14/17, Loss: 0.640516, Error: 0.3700
  Iterati

  Iteration: 17/17, Loss: 0.549045, Error: 0.3226
Average Error for this Epoch: 0.2925
Epoch: 38/100
  Iteration: 1/17, Loss: 0.565705, Error: 0.2900
  Iteration: 2/17, Loss: 0.562534, Error: 0.3200
  Iteration: 3/17, Loss: 0.474977, Error: 0.2500
  Iteration: 4/17, Loss: 0.517753, Error: 0.2500
  Iteration: 5/17, Loss: 0.61451, Error: 0.3400
  Iteration: 6/17, Loss: 0.540035, Error: 0.3000
  Iteration: 7/17, Loss: 0.666033, Error: 0.3800
  Iteration: 8/17, Loss: 0.627883, Error: 0.3200
  Iteration: 9/17, Loss: 0.538749, Error: 0.3000
  Iteration: 10/17, Loss: 0.488441, Error: 0.2500
  Iteration: 11/17, Loss: 0.540721, Error: 0.2900
  Iteration: 12/17, Loss: 0.603551, Error: 0.3900
  Iteration: 13/17, Loss: 0.55659, Error: 0.2700
  Iteration: 14/17, Loss: 0.566019, Error: 0.3100
  Iteration: 15/17, Loss: 0.534009, Error: 0.2700
  Iteration: 16/17, Loss: 0.476064, Error: 0.2200
  Iteration: 17/17, Loss: 0.540299, Error: 0.3065
Average Error for this Epoch: 0.2974
Epoch: 39/100
  Iterati

  Iteration: 3/17, Loss: 0.526438, Error: 0.2200
  Iteration: 4/17, Loss: 0.540823, Error: 0.2700
  Iteration: 5/17, Loss: 0.527743, Error: 0.3000
  Iteration: 6/17, Loss: 0.460448, Error: 0.2200
  Iteration: 7/17, Loss: 0.535753, Error: 0.2800
  Iteration: 8/17, Loss: 0.581209, Error: 0.3200
  Iteration: 9/17, Loss: 0.575239, Error: 0.3200
  Iteration: 10/17, Loss: 0.509207, Error: 0.2100
  Iteration: 11/17, Loss: 0.533205, Error: 0.2800
  Iteration: 12/17, Loss: 0.512418, Error: 0.2700
  Iteration: 13/17, Loss: 0.453094, Error: 0.2700
  Iteration: 14/17, Loss: 0.481233, Error: 0.2200
  Iteration: 15/17, Loss: 0.514393, Error: 0.2700
  Iteration: 16/17, Loss: 0.613548, Error: 0.3200
  Iteration: 17/17, Loss: 0.598965, Error: 0.3226
Average Error for this Epoch: 0.2713
Epoch: 48/100
  Iteration: 1/17, Loss: 0.431679, Error: 0.1800
  Iteration: 2/17, Loss: 0.520492, Error: 0.2900
  Iteration: 3/17, Loss: 0.500555, Error: 0.2700
  Iteration: 4/17, Loss: 0.521749, Error: 0.2600
  Iteratio

  Iteration: 7/17, Loss: 0.650081, Error: 0.3800
  Iteration: 8/17, Loss: 0.671739, Error: 0.4300
  Iteration: 9/17, Loss: 0.665847, Error: 0.4700
  Iteration: 10/17, Loss: 0.678731, Error: 0.4100
  Iteration: 11/17, Loss: 0.6243, Error: 0.3600
  Iteration: 12/17, Loss: 0.673903, Error: 0.4300
  Iteration: 13/17, Loss: 0.638406, Error: 0.4100
  Iteration: 14/17, Loss: 0.679025, Error: 0.4900
  Iteration: 15/17, Loss: 0.739614, Error: 0.5000
  Iteration: 16/17, Loss: 0.682588, Error: 0.4600
  Iteration: 17/17, Loss: 0.638919, Error: 0.3871
Average Error for this Epoch: 0.4287
Epoch: 57/100
  Iteration: 1/17, Loss: 0.718271, Error: 0.4700
  Iteration: 2/17, Loss: 0.67453, Error: 0.4500
  Iteration: 3/17, Loss: 0.684192, Error: 0.4300
  Iteration: 4/17, Loss: 0.654062, Error: 0.4200
  Iteration: 5/17, Loss: 0.64237, Error: 0.4300
  Iteration: 6/17, Loss: 0.640085, Error: 0.3400
  Iteration: 7/17, Loss: 0.679691, Error: 0.4100
  Iteration: 8/17, Loss: 0.683353, Error: 0.4900
  Iteration: 9

  Iteration: 11/17, Loss: 1.13297, Error: 0.6100
  Iteration: 12/17, Loss: 1.46136, Error: 0.4600
  Iteration: 13/17, Loss: 0.664356, Error: 0.3700
  Iteration: 14/17, Loss: 0.647861, Error: 0.4200
  Iteration: 15/17, Loss: 0.85195, Error: 0.4800
  Iteration: 16/17, Loss: 1.27622, Error: 0.4600
  Iteration: 17/17, Loss: 3.69266, Error: 0.5645
Average Error for this Epoch: 0.4738
Epoch: 66/100
  Iteration: 1/17, Loss: 8.6609, Error: 0.4800
  Iteration: 2/17, Loss: 1.89164, Error: 0.4900
  Iteration: 3/17, Loss: 1.01334, Error: 0.3600
  Iteration: 4/17, Loss: 0.906533, Error: 0.4800
  Iteration: 5/17, Loss: 0.712862, Error: 0.4600
  Iteration: 6/17, Loss: 0.714574, Error: 0.4300
  Iteration: 7/17, Loss: 0.801663, Error: 0.4400
  Iteration: 8/17, Loss: 1.21645, Error: 0.5200
  Iteration: 9/17, Loss: 0.753046, Error: 0.5000
  Iteration: 10/17, Loss: 0.750248, Error: 0.4900
  Iteration: 11/17, Loss: 0.652113, Error: 0.4000
  Iteration: 12/17, Loss: 0.658736, Error: 0.3900
  Iteration: 13/17

  Iteration: 15/17, Loss: 0.728676, Error: 0.3700
  Iteration: 16/17, Loss: 0.645943, Error: 0.4100
  Iteration: 17/17, Loss: 0.623471, Error: 0.3548
Average Error for this Epoch: 0.4179
Epoch: 75/100
  Iteration: 1/17, Loss: 0.63737, Error: 0.3300
  Iteration: 2/17, Loss: 0.674541, Error: 0.3700
  Iteration: 3/17, Loss: 0.746499, Error: 0.4300
  Iteration: 4/17, Loss: 0.767026, Error: 0.4500
  Iteration: 5/17, Loss: 0.63954, Error: 0.3300
  Iteration: 6/17, Loss: 0.755268, Error: 0.4000
  Iteration: 7/17, Loss: 0.679886, Error: 0.4100
  Iteration: 8/17, Loss: 0.655111, Error: 0.4300
  Iteration: 9/17, Loss: 0.717701, Error: 0.4800
  Iteration: 10/17, Loss: 0.67252, Error: 0.5000
  Iteration: 11/17, Loss: 0.660093, Error: 0.4600
  Iteration: 12/17, Loss: 0.614199, Error: 0.2700
  Iteration: 13/17, Loss: 0.693346, Error: 0.4400
  Iteration: 14/17, Loss: 0.64987, Error: 0.4100
  Iteration: 15/17, Loss: 0.686331, Error: 0.4400
  Iteration: 16/17, Loss: 5.97903, Error: 0.5800
  Iteration: 

  Iteration: 4/17, Loss: 15.1971, Error: 0.5500
  Iteration: 5/17, Loss: 10.7761, Error: 0.3900
  Iteration: 6/17, Loss: 13.5392, Error: 0.4900
  Iteration: 7/17, Loss: 14.3681, Error: 0.5200
  Iteration: 8/17, Loss: 14.3681, Error: 0.5200
  Iteration: 9/17, Loss: 11.605, Error: 0.4200
  Iteration: 10/17, Loss: 15.4734, Error: 0.5600
  Iteration: 11/17, Loss: 15.1971, Error: 0.5500
  Iteration: 12/17, Loss: 16.3023, Error: 0.5900
  Iteration: 13/17, Loss: 15.4734, Error: 0.5600
  Iteration: 14/17, Loss: 14.6444, Error: 0.5300
  Iteration: 15/17, Loss: 11.3287, Error: 0.4100
  Iteration: 16/17, Loss: 12.434, Error: 0.4500
  Iteration: 17/17, Loss: 13.3698, Error: 0.4839
Average Error for this Epoch: 0.4996
Epoch: 85/100
  Iteration: 1/17, Loss: 14.3681, Error: 0.5200
  Iteration: 2/17, Loss: 14.6444, Error: 0.5300
  Iteration: 3/17, Loss: 11.8813, Error: 0.4300
  Iteration: 4/17, Loss: 15.1971, Error: 0.5500
  Iteration: 5/17, Loss: 14.3681, Error: 0.5200
  Iteration: 6/17, Loss: 13.262

  Iteration: 11/17, Loss: 13.5392, Error: 0.4900
  Iteration: 12/17, Loss: 14.9208, Error: 0.5400
  Iteration: 13/17, Loss: 13.5392, Error: 0.4900
  Iteration: 14/17, Loss: 12.9866, Error: 0.4700
  Iteration: 15/17, Loss: 14.6444, Error: 0.5300
  Iteration: 16/17, Loss: 15.1971, Error: 0.5500
  Iteration: 17/17, Loss: 14.7068, Error: 0.5323
Average Error for this Epoch: 0.5007
Epoch: 94/100
  Iteration: 1/17, Loss: 15.7497, Error: 0.5700
  Iteration: 2/17, Loss: 14.0918, Error: 0.5100
  Iteration: 3/17, Loss: 14.0918, Error: 0.5100
  Iteration: 4/17, Loss: 15.7497, Error: 0.5700
  Iteration: 5/17, Loss: 13.8155, Error: 0.5000
  Iteration: 6/17, Loss: 16.026, Error: 0.5800
  Iteration: 7/17, Loss: 12.7103, Error: 0.4600
  Iteration: 8/17, Loss: 12.434, Error: 0.4500
  Iteration: 9/17, Loss: 14.6444, Error: 0.5300
  Iteration: 10/17, Loss: 12.1576, Error: 0.4400
  Iteration: 11/17, Loss: 13.8155, Error: 0.5000
  Iteration: 12/17, Loss: 12.9866, Error: 0.4700
  Iteration: 13/17, Loss: 14.

### Evaluate Model on Testing Set

In [8]:
model.eval()

for i, (samples, labels) in enumerate(test_loader):
    samples = Variable(samples)
    labels = Variable(labels)
    predictions = model(samples)
    predictions = torch.flatten(predictions)
    labels = labels.type(torch.DoubleTensor)

    for j in range(0, predictions.size()[0]):
        if predictions[j] < 0.5:
            predictions[j] = 0
        else:
            predictions[j] = 1
    
    error = 1 - torch.sum(predictions == labels).item() / labels.size()[0]
    
    print("Testing set Error: %0.4f" % error)
    
model_path = "./torch_model_2_18_19_lr=" + str(learning_rate) + "_dict.pt"

Testing set Error: 0.5000


### Save Model

In [9]:
torch.save(model.state_dict(), model_path)

### Load and Evaluate previous models

In [8]:
model = CLANet(input_size, hidden_size, output_size)
model.load_state_dict(torch.load("torch_model_2_18_19_lr=0.01_dict.pt"))
model.double()     # cast model parameters to double
model.eval()

for i, (samples, labels) in enumerate(test_loader):
    samples = Variable(samples)
    labels = Variable(labels)
    predictions = model(samples)
    predictions = torch.flatten(predictions)
    labels = labels.type(torch.DoubleTensor)

    for j in range(0, predictions.size()[0]):
        if predictions[j] < 0.5:
            predictions[j] = 0
        else:
            predictions[j] = 1
    
    error = 1 - torch.sum(predictions == labels).item() / labels.size()[0]
    
    print("Testing set Error: %0.4f" % error)

  "type " + container_type.__name__ + ". It won't be checked "


RuntimeError: storage has wrong size: expected -8422397943574246465 got 850