In [43]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
import torchvision.models as models
from skimage.feature import local_binary_pattern
from skimage.color import rgb2gray
from sklearn.linear_model import LogisticRegression


In [44]:
dataset_dir = 'dataset_preprocessed'

In [45]:
means= [0.3337701, 0.35129565, 0.36801142]
stds= [0.16881385, 0.1562263, 0.16852096]

In [46]:
transform = transforms.Compose(
    [
     transforms.Resize((256,256)),
     transforms.ToTensor(),
     transforms.Normalize((means[0],means[1],means[2]), (stds[0],stds[1],stds[2]))])
full_data = ImageFolder(dataset_dir, transform=transform)

#Split the dataset
train_size = int(0.85 * len(full_data))
test_size = len(full_data) - train_size
trainDataset, testDataset = torch.utils.data.random_split(full_data, [train_size, test_size])

In [47]:
#Extract LBP Features for images
def extract_lbp_features(image, radius=4, n_points=64):

    image = image.permute(1,2,0)
    image = image.numpy()
    # transform rgb to grayscale
    image = rgb2gray(image)
    lbp = local_binary_pattern(image, n_points, radius, method="uniform")
    (hist, _) = np.histogram(lbp.ravel(), bins=np.arange(0, 65), range=(0, n_points + 2), density=True)
    hist = hist.astype(np.float32)
    return hist

#Get features for all images in full_data
def get_features(full_data):
    features = []
    labels = []
    for i in range(len(full_data)):
        image, label = full_data[i]
        lbp = extract_lbp_features(image)
        features.append(lbp)
        labels.append(label)
    return np.array(features), np.array(labels)

In [50]:
x_train, y_train = get_features(trainDataset)
x_test, y_test = get_features(testDataset)

In [51]:
logisticRegr = LogisticRegression()
logisticRegr.fit(x_train, y_train)

accuracy = logisticRegr.score(x_test, y_test)

print("Accuracy on Test Set: ", accuracy*100, "%")

Accuracy on Test Set:  84.17266187050359 %
