In [8]:
import pandas as pd
import os
import re
import torch
import numpy as np
from PIL import Image
from sklearn.cluster import KMeans
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor
from torch import nn, optim
import random
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from scipy.spatial import distance
import cv2
from matplotlib.colors import ListedColormap
import torch.nn.functional as F

In [3]:
# labels.csv is read without a header row: `names` supplies the three column
# names, so the first line of the file is treated as data.
labels_df = pd.read_csv(
    "LERA_Dataset/labels.csv",
    names=["patient_ID", "image_type", "label"],
)

In [5]:
# Keep only rows whose image_type is "XR ANKLE" and collect the PNG files
# stored under each such patient's ST-1 directory.
ankle_df = labels_df[labels_df["image_type"] == "XR ANKLE"]
ankle_patients = list(ankle_df["patient_ID"])

image_paths = []
for patient in ankle_patients:
    # Named `study_dir` rather than `dir`, which shadowed the builtin dir().
    study_dir = f"LERA_Dataset/{patient}/ST-1"
    # os.listdir returns entries in arbitrary, filesystem-dependent order;
    # sorting makes image_paths (and everything indexed by it) reproducible.
    image_paths.extend(
        os.path.join(study_dir, file_name)
        for file_name in sorted(os.listdir(study_dir))
        if file_name.endswith(".png")
    )

In [6]:
# Number of ankle PNG files collected above (displayed as the cell output).
n_images = len(image_paths)
n_images

321

In [11]:
class AnkleXrayDataset(Dataset):
    """Dataset that loads grayscale X-ray images from a list of file paths.

    Each item is the image at ``image_paths[idx]`` converted to PIL mode "L",
    resized to ``image_size`` and turned into a tensor with ``ToTensor``
    (channel-first float tensor scaled to [0, 1]).

    Args:
        image_paths: Indexable sequence of image file paths.
        image_size: ``(width, height)`` passed to ``PIL.Image.resize``.
            Defaults to ``(256, 256)``, the size previously hard-coded.
    """

    def __init__(self, image_paths, image_size=(256, 256)):
        self.image_paths = image_paths
        self.image_size = tuple(image_size)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # The context manager closes the file deterministically; convert()
        # returns an independent in-memory image, so it stays valid afterwards.
        with Image.open(image_path) as source:
            image = source.convert("L").resize(self.image_size)
        return ToTensor()(image)

In [9]:
def gaussian_kernel(size, sigma=1.0):
    """Return a normalised 2-D Gaussian kernel of shape ``(size, size)``.

    Taps are sampled at offsets centred on the kernel middle: odd sizes use
    integer offsets ``-(size // 2) .. size // 2``; even sizes use half-integer
    offsets (e.g. ``-1.5, -0.5, 0.5, 1.5`` for ``size=4``).

    Args:
        size: Kernel side length in pixels; must be >= 1.
        sigma: Standard deviation of the Gaussian in pixels; must be > 0.

    Returns:
        ``np.ndarray`` of shape ``(size, size)`` whose entries sum to 1.

    Raises:
        ValueError: If ``size < 1`` or ``sigma <= 0``.
    """
    if size < 1:
        raise ValueError(f"size must be >= 1, got {size}")
    if sigma <= 0:
        raise ValueError(f"sigma must be > 0, got {sigma}")
    # Centring with (size - 1) / 2 yields exactly `size` taps for odd AND even
    # sizes; np.arange(-(size // 2), size // 2 + 1) gave size + 1 taps when even.
    offsets = np.arange(size) - (size - 1) / 2
    x, y = np.meshgrid(offsets, offsets)
    kernel = np.exp(-(x**2 + y**2) / (2 * sigma**2))
    return kernel / kernel.sum()

In [22]:
def generate_pseudo_labels(image, n_clusters=5, kernel_size=5, sigma=5.0,
                           label_scale=100.0, random_state=0):
    """Build an intensity-clustering pseudo-label map for one grayscale image.

    Pipeline: Gaussian smoothing (reflect padding) -> 2nd/98th-percentile
    contrast stretch to [0, 255] -> 1-D KMeans on pixel intensities.
    Cluster ids are re-ranked by cluster-centre intensity, so label 0 is the
    darkest cluster and ``n_clusters - 1`` the brightest.

    Args:
        image: Single-channel image tensor whose last two dims are (H, W),
            e.g. shape ``(1, H, W)`` or ``(H, W)``; any leading dims must
            have size 1.
        n_clusters: Number of KMeans intensity clusters.
        kernel_size: Side length passed to ``gaussian_kernel``.
        sigma: Gaussian standard deviation passed to ``gaussian_kernel``.
        label_scale: Ranked labels are divided by this before returning.
        random_state: Seed forwarded to ``KMeans`` for reproducible runs;
            pass ``None`` for unseeded behaviour.

    Returns:
        float32 tensor of shape ``(1, H, W)`` holding ``rank / label_scale``.
    """
    height, width = image.shape[-2:]
    # conv2d expects (N, C, H, W).
    batch = image.float().reshape(1, 1, height, width)

    kernel_np = gaussian_kernel(kernel_size, sigma=sigma)
    kernel = torch.as_tensor(kernel_np, dtype=torch.float32)[None, None]
    # Derive padding from the actual kernel width so output stays (H, W)
    # for odd and even kernels alike.
    taps = kernel.shape[-1]
    pad_lo = (taps - 1) // 2
    pad_hi = taps - 1 - pad_lo
    with torch.no_grad():
        padded = F.pad(batch, (pad_lo, pad_hi, pad_lo, pad_hi), mode="reflect")
        # Index instead of squeeze() so H == 1 or W == 1 is not collapsed.
        smoothed = F.conv2d(padded, kernel)[0, 0].numpy()

    p2, p98 = np.percentile(smoothed, (2, 98))
    spread = p98 - p2
    if spread > 0:
        stretched = np.clip((smoothed - p2) / spread * 255.0, 0.0, 255.0)
    else:
        # Near-constant image: avoid 0/0 -> NaN, which KMeans rejects.
        stretched = np.zeros_like(smoothed)

    pixels = stretched.astype(np.float32).reshape(-1, 1)
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(pixels)

    # KMeans ids are an arbitrary permutation. Re-rank them by centre
    # intensity instead of the old `1 - labels` flip, which only makes sense
    # for 2 clusters and produced negative labels for 5.
    order = np.argsort(kmeans.cluster_centers_.ravel())
    rank = np.empty_like(order)
    rank[order] = np.arange(order.size)
    labels = rank[kmeans.labels_].reshape(height, width)

    return torch.from_numpy(labels.astype(np.float32)).unsqueeze(0) / label_scale

In [29]:
def segment(loader=None):
    """Run ``generate_pseudo_labels`` on every image yielded by a loader.

    Args:
        loader: Iterable of image batches (e.g. a ``DataLoader``). Defaults to
            the notebook-global ``dataloader`` for backward compatibility;
            passing it explicitly avoids relying on hidden kernel state.

    Returns:
        List with one entry per batch; each entry is a list of the
        pseudo-label tensors for the images in that batch.
    """
    if loader is None:
        # Notebook global defined elsewhere in the notebook.
        loader = dataloader
    return [
        [generate_pseudo_labels(img.cpu()) for img in batch]
        for batch in loader
    ]

In [27]:
# One image per batch; shuffle=False (the default) keeps image_paths order.
full_dataset = AnkleXrayDataset(image_paths=image_paths)
dataloader = DataLoader(dataset=full_dataset, batch_size=1, shuffle=False)

In [None]:
# Pseudo-label every image yielded by `dataloader` via segment(). The result
# is a list (one entry per batch) of lists of label tensors. KMeans runs once
# per image, so this cell may be slow on the full dataset.
segmented_images = segment()