In [1]:
#including all the libraries
import torch
from torch.utils.data import Dataset,DataLoader
from tqdm import tqdm
from torchvision import transforms
import numpy as np
import pandas as pd
from PIL import Image
import argparse
import os
import copy
import torch
import cv2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#performing pre processing
LABELS_Severity = {35: 0, 43: 0, 47: 1, 53: 1, 61: 2, 65: 2, 71: 2, 85: 2}
mean = (.1706)
std = (.2112)
normalize = transforms.Normalize(mean=mean, std=std)

transform = transforms.Compose([transforms.Resize(size=(224,224)),transforms.ToTensor(),normalize,])

In [3]:
#
class OCTDataset(Dataset):
    def __init__(self, args, subset='train', transform=None,):
        if subset == 'train':
            self.annot = pd.read_csv(args.annot_train_prime)
        elif subset == 'test':
            self.annot = pd.read_csv(args.annot_test_prime)
            
        self.annot['Severity_Label'] = [LABELS_Severity[drss] for drss in copy.deepcopy(self.annot['DRSS'].values)] 
        # print(self.annot)
        self.root = os.path.expanduser(args.data_root)
        self.transform = transform
        # self.subset = subset
        self.nb_classes=len(np.unique(list(LABELS_Severity.values())))
        self.path_list = self.annot['File_Path'].values
        self._labels = self.annot['Severity_Label'].values
        assert len(self.path_list) == len(self._labels)
        # idx_each_class = [[] for i in range(self.nb_classes)]

    def __getitem__(self, index):
        img, target = Image.open(self.root+self.path_list[index]).convert("L"), self._labels[index]

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        return len(self._labels)         

In [4]:
class NotebookArgs:
    def __init__(self, annot_train_prime = 'df_prime_train.csv', annot_test_prime = 'df_prime_test.csv', data_root = '/storage/home/hpaceice1/shared-classes/materials/ece8803fml/'):
        self.annot_train_prime = annot_train_prime
        self.annot_test_prime = annot_test_prime
        self.data_root = data_root
args = NotebookArgs()

In [5]:
#loading the dataset
trainset = OCTDataset(args, 'train', transform=transform)
testset = OCTDataset(args, 'test', transform=transform)

batch_size = 32

train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

In [6]:
def get_X_y_from_loader(loader):
    X, y = [], []
    i=0
    for sample in tqdm(loader, total=len(loader)):
        i = i+1
        images, labels = sample[0], sample[1]
        X.extend([a.numpy()[0] for a in images])
        y.extend([a.numpy().flatten() for a in labels])
        #break
        #if(i == 1000):
           #break
    return X,y
X_train, y_train = get_X_y_from_loader(train_loader)
print(len(X_train))

100%|██████████| 758/758 [01:34<00:00,  7.99it/s]

24252





In [7]:
#performing sift on train set
X_train_sift = []
i=0
for i in tqdm(range(len(X_train))):   
    #X_hog = X_train[i].reshape((224,224))
    img_uint8 = cv2.convertScaleAbs(X_train[i])
    sift = cv2.SIFT_create(nfeatures=50, contrastThreshold=0.001, edgeThreshold=5)
    # Detect keypoints and compute descriptors
    keypoints, descriptors = sift.detectAndCompute(img_uint8, None)
    descriptors=descriptors.flatten()
    descriptors= descriptors[:6400] 
    X_train_sift.append(descriptors)
    i =i+1
    if i==10000:
        break
print(X_train_sift[0].shape)
print(len(X_train_sift))

 41%|████      | 9999/24252 [01:38<02:19, 102.02it/s]

(6400,)
10000





In [8]:
print(len(y_train))
X_subset = X_train_sift
y_subset = y_train[0:10000]
print(len(y_subset))
#for arr in (X_subset):
    #print(arr.shape)


24252
10000


In [9]:
#naive bayes implementation
from sklearn.naive_bayes import GaussianNB
# Train the Naive Bayes classifier
clf = GaussianNB()
clf.fit(X_subset, y_subset)

  y = column_or_1d(y, warn=True)


In [10]:
X_test, y_test = get_X_y_from_loader(test_loader)

100%|██████████| 250/250 [00:31<00:00,  7.91it/s]


In [11]:
X_test_sift = []
for i in tqdm(range(len(X_test))):
    #X_hog = X_train[i].reshape((224,224))
    img_uint8 = cv2.convertScaleAbs(X_test[i])
    sift = cv2.SIFT_create(nfeatures=50, contrastThreshold=0.001, edgeThreshold=5)
    # Detect keypoints and compute descriptors
    keypoints, descriptors = sift.detectAndCompute(img_uint8, None)
    descriptors = descriptors.flatten()
    descriptors = descriptors[:6400]
    X_test_sift.append(descriptors)
print(X_test_sift[0].shape)
print(len(X_test_sift))

100%|██████████| 7987/7987 [01:18<00:00, 102.07it/s]

(6400,)
7987





In [12]:
# Make predictions on the testing data
y_pred = clf.predict(X_test_sift)

# Calculate the accuracy of the model
accuracy = accuracy_score(y_test, y_pred)
print('Accuracy:', accuracy)

Accuracy: 0.3483160135219732


In [13]:
from sklearn.metrics import balanced_accuracy_score, precision_score, recall_score, f1_score

# Calculate balanced accuracy score
balanced_acc = balanced_accuracy_score(y_test, y_pred)
print("Balanced accuracy score:", balanced_acc)

# Calculate precision score
precision = precision_score(y_test, y_pred,average='weighted')
print("Precision score:", precision)

# Calculate recall score
recall = recall_score(y_test, y_pred,average='weighted')
print("Recall score:", recall)

# Calculate f1 score
f1 = f1_score(y_test, y_pred,average='weighted')
print("F1 score weighted:", f1)

# Calculate f1 score
f1 = f1_score(y_test, y_pred,average='micro')
print("F1 score Micro:", f1)

# Calculate f1 score
f1 = f1_score(y_test, y_pred,average='macro')
print("F1 score Macro:", f1)

Balanced accuracy score: 0.3210219695818774
Precision score: 0.387969692127719
Recall score: 0.3483160135219732
F1 score weighted: 0.36363330858040116
F1 score Micro: 0.3483160135219732
F1 score Macro: 0.3226210353298844


In [14]:
from sklearn.metrics import confusion_matrix

labels = [0, 1, 2]  # label names

cm = confusion_matrix(y_test, y_pred, labels=labels)

cm_df = pd.DataFrame(cm, index=labels, columns=labels)

print(cm_df)

      0     1     2
0   914   802   832
1  1092  1551  1277
2   424   778   317
