In [None]:
#################################################################################################
#     TASK 1C. Building SVM classifiers based on Pre-trained Alexnet network
#     For different SVM kernels: 
#              + Linear, 
#              + Quadratic,
#              + Polynomial
#################################################################################################

In [7]:
# 1. Mount the data from Google Drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [8]:
# 2. Change working directory
%cd drive/My\ Drive/Colab\ Notebooks/

/content/drive/My Drive/Colab Notebooks


In [9]:
# Import necessary packages for the task
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
from numpy import where
import pandas as pd
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from torch.autograd import Variable
from random import randrange
from matplotlib.pyplot import figure
from sklearn.feature_selection import VarianceThreshold
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from pylab import scatter, show, legend, xlabel, ylabel

In [11]:
# Data augmentation and normalization for training data

# transforms.RandomResizedCrop(224),
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'flowers_alexnet_dataset'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader( image_datasets[x], batch_size=4,
                                            shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [12]:
# Create pre-trained Alexnet network
model = models.alexnet(pretrained=True)

num_fts = model.classifier[6].out_features

# add full-connected layer to the net
model.fc = nn.Linear(num_fts, num_fts)

# init the weights for the fc layer
torch.nn.init.eye_(model.fc.weight)

# keep all weights fixed
for param in model.parameters():
    param.requires_grad = False

# add the model to the device
model = model.to(device)

In [13]:
# Create pre-trained Alexnet network by removing all classifiers and
# replace by 1 Linear classifier to extract features for transfer learning
model = models.alexnet(pretrained=True)

num_fts = model.classifier[1].in_features

# add a new classifier layer to the net
model.classifier = nn.Linear(num_fts, num_fts)

# init the weights for the fc layer
torch.nn.init.eye_(model.classifier.weight)

# keep all weights fixed
for param in model.parameters():
    param.requires_grad = False

# add the model to the device
model = model.to(device)

In [14]:
model.eval

<bound method Module.eval of AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Linear(in_features=9216, out_features=9216, bias=True)
)>

In [15]:
# Get features from CNN Alexnet network for all training and validation data
def get_train_and_val_features():
    extracted_features = {"train": [], "val": []}
    # Training and validation phase
    for phase in ['train', 'val']:
        # Iterate over data.
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # get features outputs
            features = model(inputs).cpu()

            batch_number = features.shape[0] 
            
            for i in range(0, batch_number):  
                featr = np.array(features[i].cpu())
                label = int(labels[i].cpu().numpy())                
                extracted_features[phase].append(np.append(label, featr))
    
    return (extracted_features)

# Extract features from Alexnet
features = get_train_and_val_features()

print (len(features))

2


In [16]:
# Prepare training and validating data
train_dataset = pd.DataFrame()
val_dataset = pd.DataFrame()
for typ, data in features.items():
    if typ == 'train':
        train_dataset = pd.DataFrame(data)
    if typ == 'val':
        val_dataset = pd.DataFrame(data)

# training data
Y = train_dataset.iloc[:, 0]
X = train_dataset.iloc[:, 1:]

# validating data
Y_val = val_dataset.iloc[:, 0]
X_val = val_dataset.iloc[:, 1:]

In [17]:
train_dataset.head(10)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,9177,9178,9179,9180,9181,9182,9183,9184,9185,9186,9187,9188,9189,9190,9191,9192,9193,9194,9195,9196,9197,9198,9199,9200,9201,9202,9203,9204,9205,9206,9207,9208,9209,9210,9211,9212,9213,9214,9215,9216
0,3.0,-0.006222,0.001637,-0.008066,0.002481,0.00047,-0.006442,-0.007088,0.005668,-0.004833,-0.001299,0.000618,-0.004086,0.004436,-0.003153,-0.007243,-0.002162,0.004506,0.002253,0.007099,0.010352,0.005348,0.008442,-0.008331,0.008585,-0.006346,0.003492,-0.002589,-0.008925,0.009793,-0.002807,0.005423,-0.008794,-0.008008,0.001326,0.640235,1.584166,-0.003758,-0.003128,0.00026,...,1.459623,0.783452,0.33382,2.53383,0.005767,-0.006269,-0.006344,2.56391,2.562145,0.001465,0.002379,0.007557,-0.00787,0.000155,-0.007163,0.010229,0.005678,-0.009664,0.00184,0.0065,0.2227,0.232347,0.001158,-0.009661,0.004417,0.001246,0.240048,0.232227,0.004275,0.009444,0.005743,-0.006441,-0.004304,0.009567,0.007632,-0.008584,-0.00568,0.010111,0.007561,0.00375
1,4.0,0.212278,0.220136,-0.008066,0.002481,0.00047,4.40583,-0.007088,0.005668,-0.004833,-0.001299,0.000618,1.348367,0.004436,-0.003153,-0.007243,0.862369,0.421491,0.002253,0.007099,0.010352,0.005348,0.615314,0.408655,0.008585,-0.006346,0.003492,-0.002589,-0.008925,0.009793,-0.002807,3.899294,0.984962,-0.008008,0.001326,-0.006496,0.003065,-0.003758,-0.003128,0.00026,...,-0.005313,-0.004838,0.006316,-0.010068,0.005767,-0.006269,-0.006344,-0.004685,-0.00645,0.001465,0.095536,0.031233,-0.00787,0.000155,-0.007163,0.010229,0.005678,0.218456,0.00184,0.0065,2.578001,2.687706,0.001158,0.167992,0.004417,0.163077,2.665993,2.903237,0.004275,0.009444,0.005743,-0.006441,2.543015,2.911643,0.662859,-0.008584,-0.00568,0.010111,0.007561,0.00375
2,2.0,-0.006222,0.001637,-0.008066,0.002481,0.00047,-0.006442,-0.007088,0.005668,-0.004833,-0.001299,0.000618,-0.004086,0.004436,-0.003153,-0.007243,-0.002162,0.312962,0.002253,0.007099,0.010352,0.005348,0.008442,-0.008331,0.008585,-0.006346,0.003492,-0.002589,-0.008925,0.009793,-0.002807,0.005423,1.479007,-0.008008,0.001326,-0.006496,0.003065,6.426361,6.426991,5.035393,...,-0.005313,-0.004838,0.006316,-0.010068,0.005767,-0.006269,-0.006344,-0.004685,-0.00645,0.001465,0.002379,0.007557,-0.00787,0.000155,-0.007163,0.010229,0.005678,-0.009664,0.00184,0.0065,-0.008367,0.00128,0.001158,-0.009661,0.004417,0.001246,0.008982,0.00116,0.004275,0.009444,0.005743,-0.006441,-0.004304,0.009567,0.007632,-0.008584,-0.00568,0.010111,0.007561,0.00375
3,2.0,-0.006222,1.23559,1.225888,1.955507,2.047491,0.677871,-0.007088,1.273652,1.229121,1.625985,1.6126,-0.004086,0.004436,0.473693,5.039421,5.044501,1.616488,0.002253,0.007099,0.400709,5.052011,5.055105,-0.008331,0.008585,-0.006346,0.003492,1.75257,1.746235,0.009793,-0.002807,0.005423,-0.008794,0.47154,0.001326,-0.006496,0.003065,-0.003758,-0.003128,0.00026,...,-0.005313,-0.004838,0.006316,-0.010068,0.005767,-0.006269,-0.006344,-0.004685,-0.00645,0.001465,0.002379,0.007557,-0.00787,0.000155,-0.007163,0.010229,0.005678,-0.009664,0.00184,0.0065,-0.008367,0.00128,0.001158,-0.009661,0.004417,0.001246,0.008982,0.00116,0.004275,0.009444,0.005743,-0.006441,-0.004304,0.009567,0.007632,-0.008584,-0.00568,0.010111,0.007561,0.00375
4,0.0,-0.006222,0.001637,-0.008066,0.002481,0.00047,-0.006442,-0.007088,0.005668,-0.004833,-0.001299,0.000618,-0.004086,0.180681,0.703967,0.699877,-0.002162,0.004506,0.002253,0.007099,3.483867,3.478863,0.008442,-0.008331,0.008585,-0.006346,3.477007,3.470925,-0.008925,0.009793,-0.002807,0.005423,4.284713,5.114036,0.769906,0.556541,0.566102,-0.003758,0.999728,0.488465,...,0.21545,0.215925,0.006316,-0.010068,0.005767,-0.006269,-0.006344,-0.004685,-0.00645,0.001465,4.366529,4.371707,2.440177,0.000155,-0.007163,0.010229,4.369828,4.354486,2.449887,0.0065,-0.008367,0.00128,0.001158,-0.009661,3.374954,0.464683,0.008982,0.00116,0.004275,0.009444,3.37628,0.456996,-0.004304,0.009567,0.007632,-0.008584,1.439395,1.844512,0.007561,0.00375
5,2.0,-0.006222,4.48081,4.471107,0.002481,0.00047,-0.006442,-0.007088,4.484841,4.47434,-0.001299,0.000618,-0.004086,0.004436,-0.003153,-0.007243,-0.002162,0.004506,0.002253,0.007099,0.010352,0.005348,0.008442,-0.008331,0.008585,-0.006346,0.003492,-0.002589,-0.008925,0.009793,-0.002807,0.005423,-0.008794,-0.008008,0.001326,-0.006496,0.003065,-0.003758,-0.003128,0.00026,...,1.85513,0.103618,0.006316,-0.010068,0.005767,-0.006269,-0.006344,-0.004685,-0.00645,0.001465,0.002379,0.007557,-0.00787,0.000155,-0.007163,1.033851,0.005678,-0.009664,0.00184,0.0065,0.279909,1.024903,0.001158,-0.009661,0.004417,0.001246,0.297257,0.833168,0.004275,0.009444,0.005743,-0.006441,-0.004304,0.009567,0.007632,-0.008584,-0.00568,0.010111,0.007561,0.00375
6,3.0,-0.006222,0.001637,-0.008066,0.002481,0.00047,-0.006442,-0.007088,0.005668,-0.004833,-0.001299,0.000618,-0.004086,0.004436,-0.003153,0.239105,0.244186,0.004506,0.002253,0.007099,0.010352,0.754745,0.757839,-0.008331,0.008585,-0.006346,0.003492,-0.002589,-0.008925,0.009793,-0.002807,0.005423,-0.008794,-0.008008,0.001326,-0.006496,0.003065,-0.003758,-0.003128,0.00026,...,-0.005313,-0.004838,0.561043,0.232108,0.005767,-0.006269,-0.006344,-0.004685,-0.00645,0.001465,0.002379,0.007557,-0.00787,0.000155,-0.007163,0.010229,0.005678,-0.009664,0.00184,0.0065,-0.008367,0.00128,0.001158,-0.009661,0.004417,0.001246,0.008982,0.00116,0.004275,0.009444,0.005743,-0.006441,-0.004304,0.009567,0.007632,-0.008584,-0.00568,0.010111,0.007561,0.00375
7,0.0,0.521082,4.718698,4.708996,0.806908,0.00047,-0.006442,2.761354,6.709153,6.698652,0.803128,0.000618,-0.004086,0.004436,5.632624,5.628534,-0.002162,0.004506,0.002253,0.007099,0.010352,0.005348,0.008442,-0.008331,0.008585,0.670131,0.003492,-0.002589,-0.008925,0.009793,-0.002807,0.681901,-0.008794,-0.008008,0.001326,-0.006496,0.122656,-0.003758,-0.003128,0.00026,...,-0.005313,-0.004838,0.006316,-0.010068,0.005767,-0.006269,-0.006344,-0.004685,-0.00645,0.001465,0.002379,0.007557,4.153725,2.884594,-0.007163,0.010229,0.005678,-0.009664,2.33735,2.34201,-0.008367,0.00128,0.001158,-0.009661,0.004417,0.001246,0.008982,0.00116,0.049496,0.009444,0.005743,-0.006441,-0.004304,0.009567,0.052854,-0.008584,-0.00568,0.010111,0.007561,0.00375
8,4.0,-0.006222,0.001637,-0.008066,0.002481,0.00047,-0.006442,-0.007088,0.005668,-0.004833,-0.001299,0.000618,-0.004086,0.004436,-0.003153,-0.007243,-0.002162,0.004506,0.002253,0.369135,0.372388,0.597696,0.008442,-0.008331,0.008585,-0.006346,0.003492,0.589758,-0.008925,0.009793,-0.002807,0.005423,1.321921,1.545321,0.001326,-0.006496,0.003065,-0.003758,-0.003128,0.00026,...,0.345608,-0.004838,0.006316,2.127479,0.005767,1.072123,-0.006344,-0.004685,-0.00645,0.001465,0.002379,0.007557,-0.00787,0.000155,-0.007163,0.010229,0.494782,0.47944,0.00184,0.0065,-0.008367,0.00128,0.001158,-0.009661,0.004417,0.001246,0.008982,0.00116,0.004275,0.009444,0.005743,-0.006441,-0.004304,0.009567,0.007632,-0.008584,-0.00568,0.010111,0.007561,0.00375
9,3.0,-0.006222,0.001637,-0.008066,0.002481,0.00047,-0.006442,-0.007088,0.005668,-0.004833,-0.001299,0.000618,-0.004086,0.004436,-0.003153,-0.007243,-0.002162,0.004506,0.002253,0.007099,0.010352,0.005348,0.008442,-0.008331,0.008585,-0.006346,0.003492,-0.002589,-0.008925,0.009793,-0.002807,0.005423,-0.008794,-0.008008,0.001326,-0.006496,0.003065,-0.003758,-0.003128,0.00026,...,-0.005313,-0.004838,0.006316,-0.010068,0.005767,-0.006269,-0.006344,-0.004685,-0.00645,0.001465,0.002379,0.007557,-0.00787,0.000155,-0.007163,0.010229,0.005678,-0.009664,0.00184,0.0065,-0.008367,0.00128,0.001158,-0.009661,0.004417,0.001246,0.008982,0.00116,0.004275,0.009444,0.005743,-0.006441,-0.004304,0.009567,0.007632,-0.008584,-0.00568,0.010111,0.007561,0.00375


In [None]:
###############################################################################
# visualize train data
daisy = where(Y == 0.0)
dandelion = where(Y ==1.0)
rose = where(Y ==2.0)
sunflower = where(Y ==3.0)
tulip = where(Y ==4.0)

# 'b', 'g', 'r', 'c', 'm', 'y', 'k', 'w' - they are the single
#   character short-hand notations for blue, green, red, cyan, 
# magenta, yellow, black, and white.
ran_feat_1 = [randrange(1,9216) for i in range(10)]
ran_feat_2 = [randrange(1,9216) for i in range(10)]

for i in ran_feat_1:
  for j in ran_feat_2:
    scatter(X.iloc[daisy[0], i], X.iloc[daisy[0], j], marker='1', c='r')
    scatter(X.iloc[dandelion[0], i], X.iloc[dandelion[0], j], marker='o', c='g')
    scatter(X.iloc[rose[0], i], X.iloc[rose[0], j], marker='s', c='b')
    scatter(X.iloc[sunflower[0], i], X.iloc[sunflower[0], j], marker='p', c='y')
    scatter(X.iloc[tulip[0], i], X.iloc[tulip[0], j], marker='*', c='c')

    xlabel('Training set - Exam the feature %s' % str(i))
    ylabel('Training set - Exam the feature %s' % str(j))
    legend(['Daisy', 'Dandelion', 'Rose', 'Sunflower', 'Tulip'])
    show()

In [None]:
print (X.iloc[daisy[0], 0])

In [19]:
# Train SVM Classifier - Kernel: Linear
SVM_Linear_Classifier = SVC(kernel='linear')
linear_pipeline = Pipeline([('low_variance_filter', VarianceThreshold()), ('model', SVM_Linear_Classifier)])
linear_pipeline.fit(X, Y)

Pipeline(memory=None,
         steps=[('low_variance_filter', VarianceThreshold(threshold=0.0)),
                ('model',
                 SVC(C=1.0, break_ties=False, cache_size=200, class_weight=None,
                     coef0=0.0, decision_function_shape='ovr', degree=3,
                     gamma='scale', kernel='linear', max_iter=-1,
                     probability=False, random_state=None, shrinking=True,
                     tol=0.001, verbose=False))],
         verbose=False)

In [20]:
# Train SVM Classifier - Kernel: Quadratic
SVM_Quadratic_Classifier = SVC(kernel='rbf')
quad_pipeline = Pipeline([('low_variance_filter', VarianceThreshold()), ('model', SVM_Quadratic_Classifier)])
quad_pipeline.fit(X, Y)

Pipeline(memory=None,
         steps=[('low_variance_filter', VarianceThreshold(threshold=0.0)),
                ('model',
                 SVC(C=1.0, break_ties=False, cache_size=200, class_weight=None,
                     coef0=0.0, decision_function_shape='ovr', degree=3,
                     gamma='scale', kernel='rbf', max_iter=-1,
                     probability=False, random_state=None, shrinking=True,
                     tol=0.001, verbose=False))],
         verbose=False)

In [21]:
# Train SVM Classifier - Kernel: Polynomial
SVM_Polynomial_Classifier = SVC(kernel='poly', degree=1)
poly_pipeline = Pipeline([('low_variance_filter', VarianceThreshold()), ('model', SVM_Polynomial_Classifier)])
poly_pipeline.fit(X, Y)

Pipeline(memory=None,
         steps=[('low_variance_filter', VarianceThreshold(threshold=0.0)),
                ('model',
                 SVC(C=1.0, break_ties=False, cache_size=200, class_weight=None,
                     coef0=0.0, decision_function_shape='ovr', degree=1,
                     gamma='scale', kernel='poly', max_iter=-1,
                     probability=False, random_state=None, shrinking=True,
                     tol=0.001, verbose=False))],
         verbose=False)

In [22]:
def validation_statistics():
    pipelines = [linear_pipeline, quad_pipeline, poly_pipeline]
    
    for pipeline in pipelines:

        preds = pipeline.predict(X_val)

        correct = 0
        correct_daisy = 0
        correct_dandelion = 0
        correct_rose = 0
        correct_sunflower = 0
        correct_tulip = 0
        # calculate number of each type of flowers
        num_daisy = len(where(Y_val == 0)[0])
        num_dandelion = len(where(Y_val == 1)[0])
        num_rose = len(where(Y_val == 2)[0])
        num_sunflower = len(where(Y_val == 3)[0])
        num_tulip = len(where(Y_val == 4)[0])
        for i in range(0,len(Y_val)):
            actual_val = int(Y_val[i])
            validate_val = int(preds[i])
            if (actual_val == validate_val):
              correct += 1
              if(actual_val == 0):
                  correct_daisy += 1
              if(actual_val == 1):
                  correct_dandelion += 1
              if(actual_val == 2):
                  correct_rose += 1
              if(actual_val == 3):
                  correct_sunflower += 1
              if(actual_val == 4):
                  correct_tulip += 1

        acc = float(correct/len(Y_val))
        acc_daisy = float(correct_daisy/num_daisy)
        acc_dandelion = float(correct_dandelion/num_dandelion)
        acc_rose = float(correct_rose/num_rose)
        acc_sunflower = float(correct_sunflower/num_sunflower)
        acc_tulip = float(correct_tulip/num_tulip)

        # save the statistics
        with open('validation_statistics.txt', 'a+') as f:
            f.write('%.3f %.3f %.3f  %.3f  %.3f  %.3f\n' % 
                    (acc, acc_daisy, acc_dandelion, acc_rose, acc_sunflower, acc_tulip))

validation_statistics()

In [121]:
num_daisy = len(where(Y_val == 0)[0])
print (num_daisy)

113


In [4]:
# Validating the trained SVM classifiers
def validation_trained_SVM_classifier(svm_kernel):
    if svm_kernel == 'linear':
        preds = linear_pipeline.predict(X_val)
    if svm_kernel == 'rbf':
        preds = quad_pipeline.predict(X_val)
    if svm_kernel == 'poly':
        preds = poly_pipeline.predict(X_val)

    correct = 0
    for i in range(0,len(Y_val)):
        actual_val = int(Y_val[i])
        validate_val = int(preds[i])
        if (actual_val == validate_val):
          correct += 1

    acc = float(correct/len(Y_val))

    print ('Accuracy: %.3f' % acc)
    
    return preds

In [5]:
# validate the model
preds = validation_trained_SVM_classifier('linear')

NameError: ignored

In [96]:
preds = validation_trained_SVM_classifier('rbf')

Accuracy: 0.862


In [103]:
preds = validation_trained_SVM_classifier('poly')

Accuracy: 0.866


In [112]:
def get_images_features_for_testing():

    # Iterate over vallidation data.
    # pick randomly 3x4 images for testing
    test_indices = [randrange(0, int(len(dataloaders['val'])/4)) for i in range(0,3)]

    test_img_data = []
    test_img_features = []
    
    for idx, (inputs, labels) in enumerate(dataloaders['val']):
        if idx in test_indices:

            inputs = inputs.to(device)
            labels = labels.to(device)

            # get features outputs
            features = model(inputs).cpu()

            batch_size = features.shape[0] 
            
            for i in range(0, batch_size):  
                img_data = inputs.cpu().data[i]
                featr = np.array(features[i].cpu())
                label = int(labels[i].cpu().numpy())

                # store data and label
                test_img_data.append(img_data)
                test_img_features.append(np.append(label, featr))
    
    return (test_img_data, test_img_features)

# extract features for testing the model
test_img_data, test_img_features = get_images_features_for_testing()

In [107]:
# Show pridictions result for testing data
def show_testing_result(pipeline):
    test_dataset = pd.DataFrame(test_img_features)

    Y_test = test_dataset.iloc[:, 0]
    X_test = test_dataset.iloc[:, 1:]

    preds_test = pipeline.predict(X_test)


    row = 3
    col = 4
    f, ax = plt.subplots(row, col, figsize=(10,10))

    for i in range(row):
      for j in range(col):
        inp = test_img_data[col*i+j]
        inp = inp.numpy().transpose((1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean
        inp = np.clip(inp, 0, 1)

        ax[i,j].imshow(inp)
        ax[i,j].text(5, -13, 'Predicted: %s' % class_names[int(preds_test[col*i+j])],
                    color='k', backgroundcolor='red', alpha=0.9)
        # ax[i,j].text(25, -5, 'Pred:   %s\nActual:%s' % (class_names[int(preds_test[col*i+j])], class_names[int(Y_test[col*i+j])]),
        #              color='k', backgroundcolor='red', alpha=0.8)

    plt.show()


In [None]:
show_testing_result(linear_pipeline)

In [None]:
show_testing_result(quad_pipeline)

In [None]:
show_testing_result(poly_pipeline)