In [None]:
import numpy as np
from tensorflow.keras import layers
from tensorflow.keras import models
from sklearn.preprocessing import Normalizer
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
from tensorflow.python.util import deprecation
# Flip the module-level flag in tensorflow.python.util.deprecation to False;
# this is intended to stop TensorFlow from printing deprecation warnings.
# NOTE(review): `tensorflow.python` is a private namespace and the flag is
# underscore-prefixed, so this may stop working in other TF versions - confirm.
deprecation._PRINT_DEPRECATION_WARNINGS = False

In [None]:
def create_feature_extractor(model, feature_layer, last_layer):
  '''
  Build a Sequential feature extractor from the leading layers of `model`.

  The new model is named "<model.name>_feature_extractor" and contains, in
  order, every layer of `model.layers[:feature_layer]` followed by
  `last_layer`.

  Parameters:
    model: source model; its `name` and `layers` attributes are read.
    feature_layer: slice end index into `model.layers` (the layer at this
      index is NOT included).
    last_layer: layer appended after the copied prefix. The original
      description calls the result a "global average pooling feature
      extractor", so this is presumably a global-average-pooling layer -
      the function itself does not check.

  Returns:
    The newly created `models.Sequential` instance.

  Note: the very same layer objects from `model` are added to the new
  model; no copies are made here.
  '''
  extractor = models.Sequential(name = model.name + "_feature_extractor")

  # Prefix of the source model first, then the caller-supplied head layer.
  stacked_layers = list(model.layers[:feature_layer])
  stacked_layers.append(last_layer)

  for keras_layer in stacked_layers:
    extractor.add(keras_layer)
  return extractor

print('"create_feature_extractor" function loaded')

In [None]:
def extract_features(features_extractor, train_images, val_images, test_images):
  '''
  Extract train, validation and test features using the feature extractor
  and normalize them.

  Parameters:
    features_extractor: object exposing `predict(images)` and a `name`
      attribute (e.g. the model built by `create_feature_extractor`).
    train_images, val_images, test_images: inputs passed unchanged to
      `features_extractor.predict`.

  Returns:
    Tuple `(train_features, val_features, test_features)` holding the
    NORMALIZED features, i.e. the output of `Normalizer.transform` applied
    to each prediction result (scikit-learn's Normalizer rescales each
    sample/row independently, by default to unit L2 norm).

  Fix note: the previous version called `Normalizer().fit(x).transform(x)`
  without keeping the result. `transform` returns a new array, so the
  un-normalized features were returned. The transformed arrays are now
  assigned back.
  '''
  # Extract features
  train_features = features_extractor.predict(train_images)
  val_features = features_extractor.predict(val_images)
  test_features = features_extractor.predict(test_images)

  # SVM works better if features are normalized.
  # Normalizer works row by row and learns nothing in fit(), so a single
  # instance can be shared by all three splits without leaking information.
  normalizer = Normalizer()
  train_features = normalizer.fit_transform(train_features)
  val_features = normalizer.transform(val_features)
  test_features = normalizer.transform(test_features)

  print(f"Features extracted using {features_extractor.name}")

  return train_features, val_features, test_features

print('"extract_features" function loaded' )

In [None]:
def train_svm_classifier(train_features, train_labels, val_features, val_labels, ker = "linear", gam='auto', c = 3):
  '''
  Fit an SVC, wrapped in a single-step pipeline, and report its scores.

  Parameters:
    train_features, train_labels: data the classifier is fitted on.
    val_features, val_labels: held-out data, only used for scoring.
    ker: forwarded to SVC as `kernel`.
    gam: forwarded to SVC as `gamma`.
    c: forwarded to SVC as `C`.

  The SVC is always created with `probability=True`.

  Side effects:
    Prints the pipeline's `score` on the training and validation data as a
    percentage rounded to two decimals.

  Returns:
    The fitted pipeline.
  '''
  svc_options = {"kernel": ker, "gamma": gam, "C": c, "probability": True}
  fitted_pipeline = make_pipeline(SVC(**svc_options))
  fitted_pipeline.fit(train_features, train_labels)

  # score() yields a fraction; convert to a percentage before rounding.
  train_pct = round(fitted_pipeline.score(train_features, train_labels) * 100, 2)
  print(f"Accuracy on training set: {train_pct}%")

  val_pct = round(fitted_pipeline.score(val_features, val_labels) * 100, 2)
  print(f"Accuracy on validation set: {val_pct}%")

  return fitted_pipeline

print('"train_svm_classifier" function loaded')