In [1]:
import torch
import os
import cv2
import json
from tqdm import tqdm
from PIL import Image
import numpy as np

# Default margin, in slice numbers, discarded on each side of a positive interval.
exclusionInterval = 20

class CustomHipDataset(torch.utils.data.Dataset):
    """Binary slice dataset built from a JSON annotation file.

    The JSON file holds a list of subject entries. Each entry provides
    ``"subj"`` and ``"folder"`` (images live in ``folder/subj/<view>/``) and,
    for each of the views ``"axial"``, ``"coronal"`` and ``"sagittal"``, a list
    of ``[first, last]`` positive intervals. Axial and coronal views carry one
    interval per side (index 0 for ``DX``, index 1 for ``SX``) and only files
    ending in ``_<side>.png`` are assigned to that side; the sagittal view uses
    the interval at index 0 for every file in its directory.

    The slice number of a file is the integer found in the second
    ``_``-separated token of its name (anything after the first ``.`` in that
    token is ignored).
    """

    def __init__(self, json_path, transform=None, exclusion_interval=None):
        """
        Args:
            json_path (string): Path to the JSON file with annotations.
            transform (callable, optional): Optional transform to be applied on a sample.
            exclusion_interval (int, optional): Margin dropped on each side of a
                positive interval. Defaults to the module-level ``exclusionInterval``.

        Raises:
            OSError: If the JSON file or one of the view directories cannot be read.
            ValueError: If a file name does not contain a parseable slice number.
        """
        if exclusion_interval is None:
            # Resolved at call time so reassigning the module constant still takes effect.
            exclusion_interval = exclusionInterval

        with open(json_path, 'r') as file:
            data_info = json.load(file)

        self.transform = transform
        self.image_paths = []
        self.labels = []
        # Filled by load_all_dataset(); None means "not loaded yet".
        self.__images__ = None
        self.__labels__ = None

        for subj_info in data_info:
            subj_name = subj_info["subj"]
            for view_type in ["axial", "coronal", "sagittal"]:
                image_dir = os.path.join(subj_info["folder"], subj_name, view_type)

                # os.listdir returns entries in arbitrary order; sorting makes the
                # sample order (and any seeded split of it) reproducible.
                image_names = sorted(os.listdir(image_dir))

                # Determine if the view type differentiates between DX and SX
                sides = ["DX", "SX"] if view_type in ["axial", "coronal"] else [None]

                for side_idx, side in enumerate(sides):
                    positive_interval = subj_info[view_type][side_idx]
                    first_pos, last_pos = positive_interval[0], positive_interval[1]
                    excluded_low = max(0, first_pos - exclusion_interval)
                    excluded_high = last_pos + exclusion_interval

                    for image_name in image_names:
                        # Validate if the image matches the side for non-sagittal types
                        if side and not image_name.endswith(f"_{side}.png"):
                            continue

                        # Extract the image number
                        img_num = int(image_name.split('_')[1].split('.')[0])

                        # Drop slices in the margins around the positive interval.
                        # NOTE(review): both comparisons are inclusive, so the first and
                        # last positive slices are dropped as well - confirm this is intended.
                        if excluded_low <= img_num <= first_pos or last_pos <= img_num <= excluded_high:
                            continue

                        # Label as positive if within the positive interval
                        label = first_pos <= img_num <= last_pos

                        self.image_paths.append(os.path.join(image_dir, image_name))
                        self.labels.append(label)

    def __len__(self):
        """Return the number of retained slices."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the grayscale slice as a
            numpy array (or whatever ``transform`` returns for it) and ``label``
            is a float32 tensor built from the boolean label.
        """
        img_path = self.image_paths[index]
        # The context manager closes the underlying file as soon as the pixels
        # have been converted, instead of waiting for garbage collection.
        with Image.open(img_path) as raw_image:
            image = np.array(raw_image.convert("RGB").convert("L"))

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(self.labels[index], dtype=torch.float32)

        return image, label

    def load_all_dataset(self):
        """Read every sample once and cache the results in memory.

        ``__getitem__`` already applies ``transform``, so the samples are stored
        exactly as returned; applying the transform again here would transform
        every image twice.
        """
        images = []
        labels = []
        for i in tqdm(range(len(self)), "Load all dataset"):
            img, label = self[i]
            images.append(img)
            labels.append(label)

        self.__images__ = images
        self.__labels__ = labels
        return

    def get_all_dataset(self):
        """Return ``(images, labels)`` as two parallel lists.

        The samples are loaded on first use if ``load_all_dataset`` has not
        been called yet.
        """
        if getattr(self, "__images__", None) is None:
            self.load_all_dataset()
        return self.__images__, self.__labels__

# Usage:
# dataset = CustomHipDataset("your_json_path.json")
# data_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

In [2]:
import numpy as np
import joblib
from sklearn import svm
from sklearn.model_selection import train_test_split
from skimage.transform import resize
from skimage.feature import hog
from skimage import exposure

class SliceSVMClassifier:
    """SVM classifier that operates on HOG descriptors of image slices."""

    def __init__(self, kernel='linear', C=1.0):
        """
        Initialize the SliceSVMClassifier.

        :param kernel: SVM kernel (default: 'linear')
        :param C: Regularization parameter (default: 1.0)
        """
        self.clf = svm.SVC(kernel=kernel, C=C)

    def _extract_hog_features(self, images):
        """
        Compute one HOG descriptor per image.

        Every image is resized to 128x128 first so that all descriptors have
        the same length.

        :param images: Iterable of images accepted by skimage's resize
        :return: Array stacking one descriptor per input image
        """
        hog_features = []
        for image in tqdm(images, desc="Extracting HOG Features"):
            image = resize(image, (128, 128))
            # Extract HOG features
            fd = hog(image, pixels_per_cell=(8, 8), cells_per_block=(2, 2), block_norm='L2-Hys')

            hog_features.append(fd)
        return np.array(hog_features)

    def train(self, custom_dataset, test_size=0.2, random_state=None):
        """
        Train the SVM classifier on HOG features and score it on a held-out split.

        :param custom_dataset: Either an ``(images, labels)`` pair of parallel
            sequences, or an object exposing ``get_all_dataset()`` that returns
            such a pair (e.g. a CustomHipDataset)
        :param test_size: Fraction of data to use for testing (default: 0.2)
        :param random_state: Random seed for reproducibility
        :return: Mean accuracy of the fitted classifier on the held-out split
        """
        # Indexing a dataset object with [0]/[1] would yield samples, not the
        # images/labels lists, so unwrap it first.
        if hasattr(custom_dataset, "get_all_dataset"):
            custom_dataset = custom_dataset.get_all_dataset()

        X = custom_dataset[0]
        # Labels may arrive as 0-d tensors / numpy scalars; unwrap them to plain
        # Python scalars so scikit-learn receives an ordinary 1-D array.
        y = np.asarray([
            label.item() if hasattr(label, "item") else label
            for label in custom_dataset[1]
        ])

        hog_features = self._extract_hog_features(X)
        X_train, X_test, y_train, y_test = train_test_split(hog_features, y, test_size=test_size, random_state=random_state)

        self.clf.fit(X_train, y_train)

        accuracy = self.clf.score(X_test, y_test)
        return accuracy

    def predict(self, X):
        """
        Predict labels for a collection of images.

        :param X: Iterable of images; HOG features are extracted from each one
        :return: Predicted labels, one per image
        """
        hog_features = self._extract_hog_features(X)
        predictions = self.clf.predict(hog_features)
        return predictions

    def save_model(self, filename):
        """
        Save the trained SVM classifier model to a file using joblib.

        :param filename: Name of the file to save the model to
        """
        joblib.dump(self.clf, filename)

    def load_model(self, filename):
        """
        Load a trained SVM classifier model from a file using joblib.

        Replaces this instance's classifier with the loaded one.

        :param filename: Name of the file to load the model from
        :return: This SliceSVMClassifier instance, now holding the loaded model
        """
        # SECURITY: joblib.load unpickles arbitrary objects and can execute code;
        # only load model files from a trusted source.
        self.clf = joblib.load(filename)
        return self


In [3]:
# Index the annotated slices, then read every one of them into memory
# (slow: each image file is opened once).
annotations_path = "TEST.json"
hips = CustomHipDataset(annotations_path)
hips.load_all_dataset()

Load all dataset: 100%|██████████| 6472/6472 [02:35<00:00, 41.50it/s]


In [4]:
# Fit a fresh HOG + SVM classifier on the preloaded slices; random_state seeds
# the train/test split.
# NOTE(review): the split is made per slice, so slices of one subject can
# presumably end up on both sides of it - confirm the score is not inflated.
svm_classifier = SliceSVMClassifier()
images_and_labels = hips.get_all_dataset()
accuracy = svm_classifier.train(images_and_labels, random_state=42)

print(f"Accuracy: {accuracy}")

# Persist the fitted model so it can be reloaded later.
svm_classifier.save_model("svm_classifier_model.pkl")

Extracting HOG Features: 100%|██████████| 6472/6472 [07:47<00:00, 13.84it/s]


Accuracy: 1.0


In [None]:
# Load the model from a file.
# load_model is an instance method: calling it on the class would bind the
# path to `self` and fail with a missing-argument TypeError, so create an
# instance first and load the saved classifier into it.
loaded_classifier = SliceSVMClassifier()
loaded_classifier.load_model("svm_classifier_model.pkl")

# Make predictions using the loaded model
#predictions = loaded_classifier.predict(X)