# ECE 4194 Final Project. Spring 2019. The Ohio State University
## Authors: Michael Wharton.124, Alex Whitman.97, Benji Justice.251

This notebook trains a ResNet model to classify the human activities (including postural transitions) recorded in the dataset linked below.

Dataset: http://archive.ics.uci.edu/ml/datasets/Smartphone-Based+Recognition+of+Human+Activities+and+Postural+Transitions



### Define packages to autoreload

In [1]:
# Enable IPython's autoreload extension. Mode 1 reloads only the modules that
# are explicitly registered with %aimport (see the next cell) before each cell runs.
%load_ext autoreload
%autoreload 1

In [2]:
# Register the project's model and data-helper modules for autoreload, so edits
# to them are picked up without restarting the kernel.
%aimport models.resnet
%aimport utils.data_helpers

### Import necessary modules

In [3]:
# torch modules
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# import model
from models.resnet import resnet18

# data functions
from utils.data_helpers import load_data
from utils.data_helpers import har_dataset

# classics
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import time

### Define model hyperparameters

In [4]:
# Fraction of the data intended for training.
# NOTE(review): not referenced anywhere below -- load_data() is called without
# it -- so changing this value currently has no effect.
split = 0.7

num_epoch = 800      # passes over the training set per dropout setting
lr = 0.5e-6          # Adam learning rate
batch_size = 10
num_workers = 1      # DataLoader worker processes

# Seed torch and NumPy so model initialisation and DataLoader shuffling are
# repeatable across kernel restarts. cuDNN kernels may still be
# non-deterministic on GPU.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

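### Create Model

The model is created inside the training loop below: a fresh `resnet18` is built for each dropout probability being tested.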

### Handle GPU

In [11]:
use_gpu=True

# Decide once whether to run on the GPU. Later cells move tensors with
# `.to(device)` and cast labels to `ltype`.
on_gpu = torch.cuda.is_available() and use_gpu

if on_gpu:
    device = torch.device('cuda:0')
    dtype, ltype = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    device = torch.device('cpu')
    dtype, ltype = torch.FloatTensor, torch.LongTensor

print('device {} dtype {}'.format(device, dtype))

device cuda:0 dtype <class 'torch.cuda.FloatTensor'>


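### Handle loss function and optimizer

The cross-entropy loss and Adam optimizer are also created per run, inside the training loop below.

### Track stats

Per-epoch train/test loss and accuracy are recorded in `stats` during training and collected into `results` for the whole sweep.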

### Load data

In [12]:
# Load the train/test split through the project's data helper; the next cell
# converts the returned arrays to torch tensors.
# NOTE(review): `split` from the hyperparameter cell is not passed here, so the
# split ratio is whatever load_data() uses internally -- confirm.
train_data, train_labels, test_data, test_labels = load_data()


In [13]:
# convert to torch Tensor objects
train_data = torch.Tensor(train_data)
train_labels = torch.Tensor(train_labels)

test_data = torch.Tensor(test_data)
test_labels = torch.Tensor(test_labels)

In [14]:
# Show the shapes of the loaded split, in the same order as before.
for split_tensor in (train_data, test_data, train_labels, test_labels):
    print(split_tensor.shape)

# Fail fast on a malformed split: every sample needs exactly one label, and
# train/test samples must share the same per-sample (length, channels) shape.
assert train_data.shape[0] == train_labels.shape[0], 'train data/label count mismatch'
assert test_data.shape[0] == test_labels.shape[0], 'test data/label count mismatch'
assert train_data.shape[1:] == test_data.shape[1:], 'train/test sample shape mismatch'

torch.Size([873, 2048, 6])
torch.Size([341, 2048, 6])
torch.Size([873])
torch.Size([341])


In [15]:
# Dropout probabilities to sweep. Other values considered: 0.1, 0.15, 0.2,
# 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.0 -- add them here to run a larger sweep.
drop_prob_list = [0.05]
num_tests = len(drop_prob_list)

# Best test accuracy reached by each configuration, in drop_prob_list order.
best_acc_list = []

# results[test_num, epoch, split, metric]
#   split:  0 = train, 1 = test
#   metric: 0 = accuracy, 1 = loss
results = np.zeros(shape=(num_tests, num_epoch, 2, 2))

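### Train

For each dropout probability in `drop_prob_list`, train a fresh model and record per-epoch train/test loss and accuracy. The sweep results are saved to `dropout_test.npz`.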

In [16]:
# Datasets and loaders do not depend on the dropout setting, so build them once
# instead of on every pass of the sweep.
train_ds = har_dataset(train_data, train_labels)
train_dl = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers, shuffle=True)
train_batches = len(train_dl)

test_ds = har_dataset(test_data, test_labels)
test_dl = DataLoader(test_ds, batch_size=batch_size, num_workers=num_workers)
test_batches = len(test_dl)

# indices into the last two axes of `stats` / `results`
train = 0
test  = 1
acc   = 0
loss  = 1


def _prepare_batch(data):
    """Unpack a (signals, labels) batch and move it to `device`.

    Signals are transposed (batch_size, length, channels) -> (batch_size, channels, length)
    for the model; labels are cast to `ltype` (integer class indices for CrossEntropyLoss).
    """
    signals, labels = data
    signals = signals.transpose(1, 2).to(device)
    labels = labels.type(ltype).to(device)
    return signals, labels


for test_num, drop_prob in enumerate(drop_prob_list):
    print()
    # fresh model, loss and optimizer for every dropout probability
    model = resnet18(drop_prob=drop_prob)
    model = model.to(device)

    crit = torch.nn.CrossEntropyLoss()
    opt  = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0)

    # stats[epoch, split, metric] -- indexed with train/test and acc/loss above
    stats = np.zeros(shape=(num_epoch, 2, 2))

    # NOTE: best_acc is the best *test* accuracy over all epochs, i.e. it is
    # selected on the test set and is therefore an optimistic estimate.
    best_acc = 0
    tstart = time.time()
    for epoch in range(num_epoch):

        # ---- train ----
        model.train()
        total = 0
        correct = 0
        tr_loss = []
        for batch, data in enumerate(train_dl):
            print('Training Iteration: {0:4d} of {1:4d}'.format(batch+1, train_batches), end='\r')
            signals, labels = _prepare_batch(data)

            out = model(signals)

            # hard decision for classification
            pred = out.argmax(dim=1)
            total   += labels.size(0)
            correct += (pred == labels).sum().item()

            b_loss = crit(out, labels)
            opt.zero_grad()
            b_loss.backward()
            opt.step()

            tr_loss.append(b_loss.item())

        # epoch training stats
        tr_accuracy = 100. * correct / total
        tr_loss = np.mean(tr_loss)

        # ---- evaluate ----
        model.eval()
        ts_loss = []
        total   = 0
        correct = 0
        with torch.no_grad():
            for batch, data in enumerate(test_dl):
                print('Testing  Iteration: {0:4d} of {1:4d}'.format(batch+1, test_batches), end='\r')
                signals, labels = _prepare_batch(data)

                out = model(signals)

                pred = out.argmax(dim=1)
                total   += labels.size(0)
                correct += (pred == labels).sum().item()

                ts_loss.append(crit(out, labels).item())

        # epoch testing stats
        ts_accuracy = 100. * correct / total
        ts_loss = np.mean(ts_loss)

        if ts_accuracy > best_acc:
            best_acc = ts_accuracy

        # save stats to plot later
        stats[epoch, train, loss ] = tr_loss
        stats[epoch, train, acc  ] = tr_accuracy
        stats[epoch, test,  loss ] = ts_loss
        stats[epoch, test,  acc  ] = ts_accuracy

        t1 = time.time() - tstart
        print('Epoch: {0:4d} Tr loss: {1:.6f} Ts loss: {2:.3f} Tr Accuracy: {3:3.2f}% Ts Accuracy: {4:3.2f}% Time: {5:4.2f}s Dprob {6:.3f}'.format(
                epoch+1,      tr_loss,         ts_loss,         tr_accuracy,           ts_accuracy, t1, drop_prob))

    # save results of each model to plot
    results[test_num] = stats
    best_acc_list.append(best_acc)

    # Checkpoint after every configuration so a crash or interrupt mid-sweep
    # keeps the runs already finished. 'res' is the key read back later;
    # 'drop_probs' records which setting each row of 'res' belongs to.
    np.savez('dropout_test', res=results, drop_probs=np.array(drop_prob_list))


Epoch:    1 Tr loss: 2.903089 Ts loss: 2.570 Tr Accuracy: 7.67% Ts Accuracy: 10.56% Time: 8.23s Dprob 0.050
Epoch:    2 Tr loss: 2.655081 Ts loss: 2.417 Tr Accuracy: 9.74% Ts Accuracy: 12.90% Time: 9.97s Dprob 0.050
Epoch:    3 Tr loss: 2.522965 Ts loss: 2.310 Tr Accuracy: 10.77% Ts Accuracy: 16.13% Time: 11.82s Dprob 0.050
Epoch:    4 Tr loss: 2.405934 Ts loss: 2.222 Tr Accuracy: 12.03% Ts Accuracy: 19.65% Time: 13.72s Dprob 0.050
Epoch:    5 Tr loss: 2.286001 Ts loss: 2.167 Tr Accuracy: 16.95% Ts Accuracy: 21.70% Time: 15.58s Dprob 0.050
Epoch:    6 Tr loss: 2.222984 Ts loss: 2.102 Tr Accuracy: 14.32% Ts Accuracy: 25.81% Time: 17.50s Dprob 0.050
Epoch:    7 Tr loss: 2.146294 Ts loss: 2.017 Tr Accuracy: 19.47% Ts Accuracy: 29.03% Time: 19.38s Dprob 0.050
Epoch:    8 Tr loss: 2.081505 Ts loss: 1.987 Tr Accuracy: 20.96% Ts Accuracy: 29.33% Time: 21.25s Dprob 0.050
Epoch:    9 Tr loss: 2.045160 Ts loss: 1.961 Tr Accuracy: 23.94% Ts Accuracy: 31.38% Time: 23.09s Dprob 0.050
Epoch:   10 T

Epoch:   75 Tr loss: 0.784305 Ts loss: 0.821 Tr Accuracy: 76.86% Ts Accuracy: 77.71% Time: 144.95s Dprob 0.050
Epoch:   76 Tr loss: 0.810989 Ts loss: 0.813 Tr Accuracy: 74.23% Ts Accuracy: 79.47% Time: 146.78s Dprob 0.050
Epoch:   77 Tr loss: 0.788327 Ts loss: 0.787 Tr Accuracy: 76.52% Ts Accuracy: 77.13% Time: 148.60s Dprob 0.050
Epoch:   78 Tr loss: 0.770665 Ts loss: 0.799 Tr Accuracy: 76.17% Ts Accuracy: 78.59% Time: 150.38s Dprob 0.050
Epoch:   79 Tr loss: 0.780754 Ts loss: 0.796 Tr Accuracy: 75.60% Ts Accuracy: 80.06% Time: 152.09s Dprob 0.050
Epoch:   80 Tr loss: 0.758721 Ts loss: 0.797 Tr Accuracy: 77.21% Ts Accuracy: 77.71% Time: 153.91s Dprob 0.050
Epoch:   81 Tr loss: 0.744637 Ts loss: 0.752 Tr Accuracy: 78.01% Ts Accuracy: 80.65% Time: 155.74s Dprob 0.050
Epoch:   82 Tr loss: 0.738897 Ts loss: 0.762 Tr Accuracy: 79.04% Ts Accuracy: 81.52% Time: 157.50s Dprob 0.050
Epoch:   83 Tr loss: 0.756951 Ts loss: 0.739 Tr Accuracy: 77.09% Ts Accuracy: 79.47% Time: 159.30s Dprob 0.050
E

Epoch:  149 Tr loss: 0.497979 Ts loss: 0.538 Tr Accuracy: 84.08% Ts Accuracy: 84.16% Time: 280.30s Dprob 0.050
Epoch:  150 Tr loss: 0.482915 Ts loss: 0.488 Tr Accuracy: 84.77% Ts Accuracy: 85.92% Time: 282.14s Dprob 0.050
Epoch:  151 Tr loss: 0.483788 Ts loss: 0.535 Tr Accuracy: 85.45% Ts Accuracy: 85.04% Time: 283.98s Dprob 0.050
Epoch:  152 Tr loss: 0.488446 Ts loss: 0.467 Tr Accuracy: 84.99% Ts Accuracy: 87.39% Time: 285.89s Dprob 0.050
Epoch:  153 Tr loss: 0.480142 Ts loss: 0.483 Tr Accuracy: 84.42% Ts Accuracy: 86.51% Time: 287.72s Dprob 0.050
Epoch:  154 Tr loss: 0.465062 Ts loss: 0.451 Tr Accuracy: 86.71% Ts Accuracy: 86.80% Time: 289.57s Dprob 0.050
Epoch:  155 Tr loss: 0.480129 Ts loss: 0.483 Tr Accuracy: 84.88% Ts Accuracy: 85.92% Time: 291.44s Dprob 0.050
Epoch:  156 Tr loss: 0.470604 Ts loss: 0.526 Tr Accuracy: 84.88% Ts Accuracy: 83.87% Time: 293.27s Dprob 0.050
Epoch:  157 Tr loss: 0.470565 Ts loss: 0.504 Tr Accuracy: 86.37% Ts Accuracy: 84.75% Time: 295.08s Dprob 0.050
E

Epoch:  223 Tr loss: 0.343134 Ts loss: 0.385 Tr Accuracy: 89.69% Ts Accuracy: 87.98% Time: 414.12s Dprob 0.050
Epoch:  224 Tr loss: 0.347843 Ts loss: 0.434 Tr Accuracy: 88.55% Ts Accuracy: 86.51% Time: 415.67s Dprob 0.050
Epoch:  225 Tr loss: 0.351369 Ts loss: 0.399 Tr Accuracy: 88.89% Ts Accuracy: 87.98% Time: 417.33s Dprob 0.050
Epoch:  226 Tr loss: 0.336347 Ts loss: 0.403 Tr Accuracy: 89.35% Ts Accuracy: 86.80% Time: 419.03s Dprob 0.050
Epoch:  227 Tr loss: 0.335829 Ts loss: 0.394 Tr Accuracy: 89.69% Ts Accuracy: 87.68% Time: 420.81s Dprob 0.050
Epoch:  228 Tr loss: 0.360113 Ts loss: 0.387 Tr Accuracy: 88.89% Ts Accuracy: 87.39% Time: 422.66s Dprob 0.050
Epoch:  229 Tr loss: 0.356488 Ts loss: 0.394 Tr Accuracy: 88.32% Ts Accuracy: 87.68% Time: 424.47s Dprob 0.050
Epoch:  230 Tr loss: 0.362486 Ts loss: 0.397 Tr Accuracy: 89.12% Ts Accuracy: 87.68% Time: 426.06s Dprob 0.050
Epoch:  231 Tr loss: 0.333017 Ts loss: 0.378 Tr Accuracy: 89.69% Ts Accuracy: 87.39% Time: 427.84s Dprob 0.050
E

Epoch:  297 Tr loss: 0.267022 Ts loss: 0.313 Tr Accuracy: 91.87% Ts Accuracy: 89.44% Time: 543.38s Dprob 0.050
Epoch:  298 Tr loss: 0.267236 Ts loss: 0.299 Tr Accuracy: 91.64% Ts Accuracy: 90.32% Time: 545.33s Dprob 0.050
Epoch:  299 Tr loss: 0.261657 Ts loss: 0.325 Tr Accuracy: 92.10% Ts Accuracy: 88.56% Time: 547.23s Dprob 0.050
Epoch:  300 Tr loss: 0.267168 Ts loss: 0.319 Tr Accuracy: 91.52% Ts Accuracy: 87.68% Time: 549.09s Dprob 0.050
Epoch:  301 Tr loss: 0.265378 Ts loss: 0.328 Tr Accuracy: 92.21% Ts Accuracy: 88.27% Time: 550.94s Dprob 0.050
Epoch:  302 Tr loss: 0.261141 Ts loss: 0.329 Tr Accuracy: 92.21% Ts Accuracy: 88.27% Time: 552.70s Dprob 0.050
Epoch:  303 Tr loss: 0.264724 Ts loss: 0.311 Tr Accuracy: 92.90% Ts Accuracy: 88.56% Time: 554.62s Dprob 0.050
Epoch:  304 Tr loss: 0.259895 Ts loss: 0.320 Tr Accuracy: 91.75% Ts Accuracy: 87.98% Time: 556.43s Dprob 0.050
Epoch:  305 Tr loss: 0.266613 Ts loss: 0.299 Tr Accuracy: 91.07% Ts Accuracy: 90.32% Time: 558.27s Dprob 0.050
E

Epoch:  371 Tr loss: 0.219481 Ts loss: 0.292 Tr Accuracy: 92.67% Ts Accuracy: 89.74% Time: 676.98s Dprob 0.050
Epoch:  372 Tr loss: 0.207001 Ts loss: 0.270 Tr Accuracy: 93.81% Ts Accuracy: 90.32% Time: 678.78s Dprob 0.050
Epoch:  373 Tr loss: 0.217364 Ts loss: 0.287 Tr Accuracy: 92.90% Ts Accuracy: 88.56% Time: 680.56s Dprob 0.050
Epoch:  374 Tr loss: 0.207878 Ts loss: 0.263 Tr Accuracy: 92.67% Ts Accuracy: 89.44% Time: 682.35s Dprob 0.050
Epoch:  375 Tr loss: 0.213965 Ts loss: 0.283 Tr Accuracy: 93.13% Ts Accuracy: 90.03% Time: 684.13s Dprob 0.050
Epoch:  376 Tr loss: 0.217173 Ts loss: 0.285 Tr Accuracy: 93.36% Ts Accuracy: 89.44% Time: 685.93s Dprob 0.050
Epoch:  377 Tr loss: 0.222123 Ts loss: 0.279 Tr Accuracy: 93.01% Ts Accuracy: 89.74% Time: 687.72s Dprob 0.050
Epoch:  378 Tr loss: 0.203872 Ts loss: 0.275 Tr Accuracy: 94.16% Ts Accuracy: 89.15% Time: 689.47s Dprob 0.050
Epoch:  379 Tr loss: 0.231974 Ts loss: 0.271 Tr Accuracy: 91.98% Ts Accuracy: 89.74% Time: 691.27s Dprob 0.050
E

Epoch:  445 Tr loss: 0.165040 Ts loss: 0.251 Tr Accuracy: 95.30% Ts Accuracy: 90.03% Time: 811.44s Dprob 0.050
Epoch:  446 Tr loss: 0.184851 Ts loss: 0.250 Tr Accuracy: 93.70% Ts Accuracy: 90.62% Time: 813.29s Dprob 0.050
Epoch:  447 Tr loss: 0.183321 Ts loss: 0.261 Tr Accuracy: 94.16% Ts Accuracy: 89.44% Time: 815.13s Dprob 0.050
Epoch:  448 Tr loss: 0.167677 Ts loss: 0.244 Tr Accuracy: 94.62% Ts Accuracy: 90.62% Time: 817.01s Dprob 0.050
Epoch:  449 Tr loss: 0.172673 Ts loss: 0.284 Tr Accuracy: 94.50% Ts Accuracy: 88.86% Time: 818.85s Dprob 0.050
Epoch:  450 Tr loss: 0.173693 Ts loss: 0.260 Tr Accuracy: 94.39% Ts Accuracy: 89.15% Time: 820.64s Dprob 0.050
Epoch:  451 Tr loss: 0.179659 Ts loss: 0.244 Tr Accuracy: 94.04% Ts Accuracy: 90.32% Time: 822.46s Dprob 0.050
Epoch:  452 Tr loss: 0.179116 Ts loss: 0.263 Tr Accuracy: 94.62% Ts Accuracy: 89.44% Time: 824.28s Dprob 0.050
Epoch:  453 Tr loss: 0.179708 Ts loss: 0.255 Tr Accuracy: 95.19% Ts Accuracy: 90.62% Time: 826.04s Dprob 0.050
E

Epoch:  519 Tr loss: 0.141894 Ts loss: 0.249 Tr Accuracy: 95.30% Ts Accuracy: 89.44% Time: 943.60s Dprob 0.050
Epoch:  520 Tr loss: 0.130858 Ts loss: 0.246 Tr Accuracy: 96.22% Ts Accuracy: 90.03% Time: 945.43s Dprob 0.050
Epoch:  521 Tr loss: 0.155735 Ts loss: 0.242 Tr Accuracy: 94.96% Ts Accuracy: 90.03% Time: 947.38s Dprob 0.050
Epoch:  522 Tr loss: 0.156936 Ts loss: 0.243 Tr Accuracy: 94.85% Ts Accuracy: 90.32% Time: 949.25s Dprob 0.050
Epoch:  523 Tr loss: 0.153693 Ts loss: 0.234 Tr Accuracy: 95.19% Ts Accuracy: 90.62% Time: 951.09s Dprob 0.050
Epoch:  524 Tr loss: 0.156533 Ts loss: 0.249 Tr Accuracy: 94.85% Ts Accuracy: 89.74% Time: 952.92s Dprob 0.050
Epoch:  525 Tr loss: 0.139542 Ts loss: 0.233 Tr Accuracy: 95.42% Ts Accuracy: 90.62% Time: 954.72s Dprob 0.050
Epoch:  526 Tr loss: 0.143741 Ts loss: 0.255 Tr Accuracy: 95.30% Ts Accuracy: 89.74% Time: 956.46s Dprob 0.050
Epoch:  527 Tr loss: 0.142430 Ts loss: 0.235 Tr Accuracy: 95.53% Ts Accuracy: 90.03% Time: 958.26s Dprob 0.050
E

Epoch:  593 Tr loss: 0.129673 Ts loss: 0.213 Tr Accuracy: 95.53% Ts Accuracy: 91.50% Time: 1077.91s Dprob 0.050
Epoch:  594 Tr loss: 0.118490 Ts loss: 0.228 Tr Accuracy: 95.99% Ts Accuracy: 90.32% Time: 1079.55s Dprob 0.050
Epoch:  595 Tr loss: 0.125124 Ts loss: 0.219 Tr Accuracy: 96.11% Ts Accuracy: 90.32% Time: 1081.23s Dprob 0.050
Epoch:  596 Tr loss: 0.125084 Ts loss: 0.231 Tr Accuracy: 95.76% Ts Accuracy: 90.91% Time: 1082.97s Dprob 0.050
Epoch:  597 Tr loss: 0.127560 Ts loss: 0.238 Tr Accuracy: 96.11% Ts Accuracy: 90.32% Time: 1084.67s Dprob 0.050
Epoch:  598 Tr loss: 0.141686 Ts loss: 0.225 Tr Accuracy: 95.19% Ts Accuracy: 90.32% Time: 1086.51s Dprob 0.050
Epoch:  599 Tr loss: 0.134564 Ts loss: 0.241 Tr Accuracy: 95.76% Ts Accuracy: 90.03% Time: 1088.19s Dprob 0.050
Epoch:  600 Tr loss: 0.131621 Ts loss: 0.229 Tr Accuracy: 95.76% Ts Accuracy: 90.91% Time: 1089.87s Dprob 0.050
Epoch:  601 Tr loss: 0.127730 Ts loss: 0.251 Tr Accuracy: 95.42% Ts Accuracy: 89.44% Time: 1091.41s Dpro

Epoch:  666 Tr loss: 0.102595 Ts loss: 0.232 Tr Accuracy: 96.79% Ts Accuracy: 90.32% Time: 1197.72s Dprob 0.050
Epoch:  667 Tr loss: 0.119636 Ts loss: 0.213 Tr Accuracy: 95.65% Ts Accuracy: 90.91% Time: 1199.40s Dprob 0.050
Epoch:  668 Tr loss: 0.113698 Ts loss: 0.215 Tr Accuracy: 96.45% Ts Accuracy: 90.32% Time: 1201.00s Dprob 0.050
Epoch:  669 Tr loss: 0.107952 Ts loss: 0.234 Tr Accuracy: 97.14% Ts Accuracy: 89.44% Time: 1202.59s Dprob 0.050
Epoch:  670 Tr loss: 0.094271 Ts loss: 0.224 Tr Accuracy: 97.25% Ts Accuracy: 90.03% Time: 1204.14s Dprob 0.050
Epoch:  671 Tr loss: 0.113571 Ts loss: 0.235 Tr Accuracy: 95.88% Ts Accuracy: 89.74% Time: 1205.76s Dprob 0.050
Epoch:  672 Tr loss: 0.112888 Ts loss: 0.226 Tr Accuracy: 95.53% Ts Accuracy: 90.03% Time: 1207.32s Dprob 0.050
Epoch:  673 Tr loss: 0.105497 Ts loss: 0.231 Tr Accuracy: 97.02% Ts Accuracy: 90.03% Time: 1208.88s Dprob 0.050
Epoch:  674 Tr loss: 0.111782 Ts loss: 0.230 Tr Accuracy: 96.11% Ts Accuracy: 89.74% Time: 1210.43s Dpro

Epoch:  739 Tr loss: 0.087513 Ts loss: 0.235 Tr Accuracy: 97.82% Ts Accuracy: 89.74% Time: 1324.92s Dprob 0.050
Epoch:  740 Tr loss: 0.088604 Ts loss: 0.210 Tr Accuracy: 97.37% Ts Accuracy: 90.91% Time: 1326.83s Dprob 0.050
Epoch:  741 Tr loss: 0.118514 Ts loss: 0.209 Tr Accuracy: 97.14% Ts Accuracy: 90.62% Time: 1328.67s Dprob 0.050
Epoch:  742 Tr loss: 0.095569 Ts loss: 0.215 Tr Accuracy: 96.56% Ts Accuracy: 90.03% Time: 1330.49s Dprob 0.050
Epoch:  743 Tr loss: 0.091610 Ts loss: 0.214 Tr Accuracy: 97.37% Ts Accuracy: 90.62% Time: 1332.31s Dprob 0.050
Epoch:  744 Tr loss: 0.096622 Ts loss: 0.228 Tr Accuracy: 96.33% Ts Accuracy: 89.74% Time: 1334.12s Dprob 0.050
Epoch:  745 Tr loss: 0.106310 Ts loss: 0.216 Tr Accuracy: 96.33% Ts Accuracy: 90.32% Time: 1336.02s Dprob 0.050
Epoch:  746 Tr loss: 0.090076 Ts loss: 0.219 Tr Accuracy: 97.37% Ts Accuracy: 90.32% Time: 1337.89s Dprob 0.050
Epoch:  747 Tr loss: 0.087744 Ts loss: 0.217 Tr Accuracy: 97.14% Ts Accuracy: 90.32% Time: 1339.73s Dpro

In [17]:
# Best test accuracy reached by each dropout setting. `best_acc` alone only holds
# the value for the last entry of drop_prob_list, so report the whole list.
# NOTE: these are maxima over epochs of the *test* accuracy (selected on the
# test set), so they are optimistic estimates of generalisation.
for p, p_best in zip(drop_prob_list, best_acc_list):
    print('Dprob {0:.3f} best Ts Accuracy: {1:3.2f}%'.format(p, p_best))

91.78885630498533


In [None]:
# Learning curves for the most recent run (`stats` holds only the last entry of
# drop_prob_list). Take the x-axis length from `stats` itself rather than the
# training loop's leaked `epoch` variable: if training was interrupted, the two
# lengths disagree and plotting fails.
epochs = np.arange(stats.shape[0])

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(15, 5))

ax_loss.plot(epochs, stats[:, train, loss], label='train loss')
ax_loss.plot(epochs, stats[:, test, loss], label='test loss')
ax_loss.set(title='epoch vs loss', xlabel='epoch', ylabel='loss')
ax_loss.legend()

ax_acc.plot(epochs, stats[:, train, acc], label='train acc')
ax_acc.plot(epochs, stats[:, test, acc], label='test  acc')
ax_acc.set(title='epoch vs accuracy', xlabel='epoch', ylabel='accuracy (%)')
ax_acc.legend()

plt.show()

In [None]:
# Reload the saved sweep results, shape (num_tests, num_epoch, 2, 2).
# The context manager closes the .npz archive's file handle. The result is NOT
# bound to `test_data`, because that name holds the test signals the training
# cell uses.
with np.load('dropout_test.npz') as archive:
    test_file = archive['res']

In [None]:
# indices into the last two axes of the results array (same as the training cell)
train = 0
test  = 1
acc   = 0
loss  = 1

# Take the x-axis from the loaded array instead of the training loop's leaked
# `epoch`, so this cell only needs the reloaded results and drop_prob_list.
epochs = np.arange(test_file.shape[1])

# zip stops at the shorter of the two, so a file with fewer rows than
# drop_prob_list (e.g. an interrupted sweep) does not raise IndexError.
for p, run in zip(drop_prob_list, test_file):
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(15, 5))

    ax_acc.plot(epochs, run[:, test, acc], label='ts acc')
    ax_acc.plot(epochs, run[:, train, acc], label='tr acc')
    ax_acc.set_title('epoch vs accuracy  p=' + str(p))
    ax_acc.set_xlabel('epoch')
    ax_acc.set_ylabel('accuracy')
    ax_acc.legend()
    ax_acc.grid()

    ax_loss.plot(epochs, run[:, test, loss], label='ts loss')
    ax_loss.plot(epochs, run[:, train, loss], label='tr loss')
    ax_loss.set_title('epoch vs loss  p=' + str(p))
    ax_loss.set_xlabel('epoch')
    ax_loss.set_ylabel('loss')
    ax_loss.legend()
    ax_loss.grid()

    # render (and release) each figure before starting the next one
    plt.show()
    