# In [None]:
import os
import threading
import warnings
from datetime import datetime
import traceback

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image, ImageTk, ImageFile
from sklearn.cluster import DBSCAN, KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset, random_split, Subset
from torchvision import models, transforms

import tkinter as tk
from tkinter import filedialog, messagebox, ttk

# Let Pillow decode image files whose data is cut short instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# NOTE(review): this silences every warning process-wide (including library
# deprecation notices); narrow the filter if warnings become useful for
# debugging.
warnings.filterwarnings('ignore')

# Newer Pillow releases expose the resampling filters on ``Image.Resampling``;
# older ones only provide the module-level constant.  Only a missing attribute
# should trigger the fallback, so catch exactly that.
try:
    RESAMPLE_LANCZOS = Image.Resampling.LANCZOS
except AttributeError:
    RESAMPLE_LANCZOS = Image.LANCZOS



class Config:
    """Class-level configuration shared by the whole application.

    Every setting is a plain class attribute read as ``Config.NAME``; the
    visible code never instantiates this class.
    """

    # --- Paths ---------------------------------------------------------
    # NOTE(review): machine-specific absolute path; adjust per installation.
    DATASET_DIR = r"C:\Users\koust\AnaKonda\FOOT_PLANTAR_CLASSIFICATION\Dataset"
    LABELS_CSV = "plantar_labels.csv"
    MODEL_SAVE_PATH = "models/plantar_model.pth"
    RESULTS_DIR = "results"
    FEATURES_DIR = "features"

    # --- Training hyper-parameters -------------------------------------
    IMAGE_SIZE = (224, 224)
    BATCH_SIZE = 16
    EPOCHS = 100
    LEARNING_RATE = 0.0001
    EARLY_STOPPING_PATIENCE = 15

    # Resolved once, when the class body is executed.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): not referenced by the code visible in this section.
    FEATURE_METHODS = ['pressure_stats', 'pressure_distribution', 'asymmetry', 'cog']

    N_CLUSTERS = 3

    @classmethod
    def create_directories(cls):
        """Create every output directory the application writes to.

        The checkpoint directory is derived from ``MODEL_SAVE_PATH`` so that
        changing the path cannot leave the checkpoint without a directory.
        The literal ``'models'`` and ``'logs'`` directories are still created
        for backward compatibility.  Existing directories are left untouched.
        """
        checkpoint_dir = os.path.dirname(cls.MODEL_SAVE_PATH)
        wanted = [cls.RESULTS_DIR, cls.FEATURES_DIR, checkpoint_dir, 'models', 'logs']
        for dir_path in wanted:
            # An empty string means "current directory" (e.g. a bare file
            # name as MODEL_SAVE_PATH); os.makedirs('') would raise.
            if dir_path:
                os.makedirs(dir_path, exist_ok=True)



class PressureMapFeatureExtractor:
    """Computes hand-crafted numeric descriptors from a pressure-map image.

    Pixel intensity is treated as the pressure value.  Each private helper
    returns a flat ``dict`` mapping feature name to a Python ``float``;
    ``extract_all_features`` merges them in a fixed order.
    """

    def __init__(self):
        self.feature_names = []

    def extract_all_features(self, image_path):
        """Load ``image_path`` with OpenCV and return all features as one dict.

        Any failure (unreadable file, processing error) is printed together
        with a traceback and reported by returning ``None``.
        """
        try:
            bgr = cv2.imread(image_path)
            if bgr is None:
                raise ValueError(f"Could not load image: {image_path}")

            gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)

            # Merge order fixes the key order of the resulting dict.
            return {
                **self._extract_pressure_statistics(gray),
                **self._extract_pressure_distribution(gray),
                **self._extract_spatial_features(gray),
                **self._extract_asymmetry_features(bgr),
            }
        except Exception as e:
            print(f"Error extracting features from {image_path}: {e}")
            traceback.print_exc()
            return None

    def _extract_pressure_statistics(self, gray_image):
        """Return mean, std, max, min, median and range of the intensities."""
        highest = np.max(gray_image)
        lowest = np.min(gray_image)

        stats = {}
        stats['mean_pressure'] = float(np.mean(gray_image))
        stats['std_pressure'] = float(np.std(gray_image))
        stats['max_pressure'] = float(highest)
        stats['min_pressure'] = float(lowest)
        stats['median_pressure'] = float(np.median(gray_image))
        stats['pressure_range'] = float(highest - lowest)
        return stats

    def _extract_pressure_distribution(self, gray_image):
        """Return a normalised 10-bin histogram plus the high-pressure ratio.

        The high-pressure ratio is the fraction of pixels strictly above the
        image's own 75th percentile.
        """
        counts, _edges = np.histogram(gray_image, bins=10, range=(0, 255))
        counts = counts.astype(float)
        total = counts.sum()
        if total > 0:
            fractions = counts / total
        else:
            fractions = np.zeros_like(counts, dtype=float)

        result = {}
        for idx, fraction in enumerate(fractions):
            result[f'pressure_bin_{idx}'] = float(fraction)

        cutoff = np.percentile(gray_image, 75)
        pixels_above = np.sum(gray_image > cutoff)
        result['high_pressure_area_ratio'] = float(pixels_above / gray_image.size)

        return result

    def _extract_spatial_features(self, gray_image):
        """Return the normalised intensity-weighted centroid and contact ratio.

        With zero total intensity the centroid falls back to the image centre.
        The contact ratio is the fraction of pixels above ``mean + std``.
        """
        row_idx, col_idx = np.indices(gray_image.shape)
        height = gray_image.shape[0]
        width = gray_image.shape[1]
        mass = float(np.sum(gray_image))

        if mass > 0:
            centre_x = float(np.sum(col_idx * gray_image) / mass)
            centre_y = float(np.sum(row_idx * gray_image) / mass)
        else:
            centre_x = float(width / 2)
            centre_y = float(height / 2)

        result = {
            'cog_x_normalized': float(centre_x / width),
            'cog_y_normalized': float(centre_y / height),
        }

        threshold = float(np.mean(gray_image) + np.std(gray_image))
        pixels_in_contact = np.sum(gray_image > threshold)
        result['contact_area_ratio'] = float(pixels_in_contact / gray_image.size)

        return result

    def _extract_asymmetry_features(self, image):
        """Compare the left half of ``image`` with the mirrored right half.

        Returns the mean absolute difference scaled by 255 together with the
        mean intensity of each half.
        """
        _, width = image.shape[:2]
        middle = width // 2
        left = image[:, :middle]
        right = image[:, middle:]

        mirrored = cv2.flip(right, 1)

        # For odd widths the right half is one column wider; crop both sides
        # to the common width before differencing.
        common = min(left.shape[1], mirrored.shape[1])
        left = left[:, :common]
        mirrored = mirrored[:, :common]

        delta = cv2.absdiff(left, mirrored)
        if delta.size:
            score = float(np.mean(delta) / 255.0)
        else:
            score = 0.0

        return {
            'lr_asymmetry': score,
            'left_mean_pressure': float(np.mean(left)) if left.size else 0.0,
            'right_mean_pressure': float(np.mean(right)) if right.size else 0.0,
        }



class ImprovedPlantarCNN(nn.Module):
    """ResNet-18 backbone with a dropout-regularised MLP classification head.

    Args:
        num_classes: Number of output logits.
        pretrained: When true, start the backbone from torchvision's default
            pretrained ResNet-18 weights.

    The stock first convolution already accepts 3-channel input, so it is
    kept as-is.  Replacing it with a freshly constructed layer of the same
    shape (as an earlier revision did) only discarded its pretrained weights.
    Parameter names and shapes are unchanged, so existing checkpoints still
    load.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()

        self.backbone = self._build_backbone(pretrained)

        num_features = self.backbone.fc.in_features

        # Swap the single-layer ImageNet classifier for a deeper head.
        self.backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )

    @staticmethod
    def _build_backbone(pretrained):
        """Return a ResNet-18 using whichever weights API torchvision offers."""
        weights_enum = getattr(models, 'ResNet18_Weights', None)
        if weights_enum is None:
            # Older torchvision: only the boolean ``pretrained`` flag exists.
            return models.resnet18(pretrained=pretrained)
        # Newer torchvision deprecates ``pretrained=`` in favour of ``weights=``.
        return models.resnet18(weights=weights_enum.DEFAULT if pretrained else None)

    def forward(self, x):
        """Return the class logits produced by the backbone for batch ``x``."""
        return self.backbone(x)


class PlantarDataset(Dataset):
    """Image/label dataset backed by a CSV with ``image_path`` and ``label`` columns.

    Each item is ``(image, label, image_path)`` where ``label`` is ``0`` for
    'healthy'/'normal' and ``1`` for 'unhealthy'/'abnormal' (case-insensitive).
    Unknown label strings map to ``0``.

    Loading is best-effort so that one bad file does not abort an epoch: an
    unreadable image is replaced by a black image labelled 'healthy', and a
    failing transform is replaced by a black tensor.  Both cases are printed
    with a traceback.
    """

    def __init__(self, csv_file, transform=None, augment=False):
        self.df = pd.read_csv(csv_file)
        self.transform = transform
        # NOTE(review): stored but not consulted anywhere in this class.
        self.augment = augment
        self.label_map = {'healthy': 0, 'unhealthy': 1, 'normal': 0, 'abnormal': 1}

    def __len__(self):
        """Return the number of rows in the label CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label, img_path)`` for row ``idx``."""
        row = self.df.iloc[idx]
        img_path = row['image_path']
        label_str = row['label']

        try:
            # The context manager releases the file handle deterministically;
            # convert() returns an independent in-memory copy, so the image
            # stays usable after the file is closed.
            with Image.open(img_path) as source:
                image = source.convert('RGB')
        except Exception as e:
            print(f"[Dataset] Error loading image at idx={idx}, path={img_path}: {e}")
            traceback.print_exc()
            image = Image.new('RGB', Config.IMAGE_SIZE, color='black')
            label_str = 'healthy'

        label = self.label_map.get(str(label_str).lower(), 0)

        if self.transform:
            try:
                image = self.transform(image)
            except Exception as e:
                print(f"[Dataset] Transform failed for idx={idx}, path={img_path}: {e}")
                traceback.print_exc()
                # Default fallback tensor
                image = transforms.ToTensor()(Image.new('RGB', Config.IMAGE_SIZE, color='black'))

        return image, label, img_path


class TrainingEngine:
    """Supervised training loop with LR scheduling, checkpointing and early stopping.

    Both loaders must yield ``(images, labels, paths)`` batches.  The best
    model (lowest validation loss) is written to ``Config.MODEL_SAVE_PATH``.

    Args:
        model: Network to train; it is moved to ``device``.
        train_loader: Loader for the training split.
        val_loader: Loader for the validation split.
        device: Device on which batches and the model are placed.
    """

    def __init__(self, model, train_loader, val_loader, device):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device

        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=Config.LEARNING_RATE,
            weight_decay=0.01
        )

        self.criterion = nn.CrossEntropyLoss()

        # Some PyTorch versions reject the ``verbose`` argument; fall back to
        # constructing the scheduler without it and report LR drops manually
        # in ``train``.
        try:
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode='min', factor=0.5, patience=5, verbose=True
            )
            self._scheduler_supports_verbose = True
        except TypeError:
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode='min', factor=0.5, patience=5
            )
            self._scheduler_supports_verbose = False

        self._last_lr = float(self.optimizer.param_groups[0]['lr'])

        self.best_val_loss = float('inf')
        self.patience_counter = 0

        self.history = {
            'train_loss': [], 'val_loss': [],
            'train_acc': [], 'val_acc': [],
            'learning_rates': []
        }

    def train_epoch(self):
        """Train for one epoch.

        Returns:
            ``(mean_loss, accuracy)`` averaged over the samples actually
            processed; ``(0.0, 0.0)`` when the loader yields nothing.
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels, _ in self.train_loader:
            images, labels = images.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            # criterion returns a batch mean; weight it by the batch size so
            # a smaller final batch is accounted for correctly.
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Normalise by the samples seen, not len(dataset): the two differ
        # when the loader drops the last batch or uses a sampler.
        epoch_loss = running_loss / total if total > 0 else 0.0
        epoch_acc = correct / total if total > 0 else 0.0

        return epoch_loss, epoch_acc

    def validate(self):
        """Evaluate on the validation loader without gradient tracking.

        Returns:
            ``(mean_loss, accuracy, predictions, labels)``.  The loss is
            ``inf`` when no validation sample was processed, so such an
            epoch can never be recorded as the best one.
        """
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for images, labels, _ in self.val_loader:
                images, labels = images.to(self.device), labels.to(self.device)

                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                running_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        val_loss = running_loss / total if total > 0 else float('inf')
        val_acc = correct / total if total > 0 else 0.0

        return val_loss, val_acc, all_preds, all_labels

    def train(self, num_epochs):
        """Run up to ``num_epochs`` epochs and return the metric history.

        After every epoch the scheduler is stepped on the validation loss, a
        checkpoint is written when the validation loss improves, and training
        stops early after ``Config.EARLY_STOPPING_PATIENCE`` epochs without
        improvement.
        """
        print(f"\nTraining on {self.device}")
        print(f"Train size: {len(self.train_loader.dataset)}")
        print(f"Val size: {len(self.val_loader.dataset)}")
        print("="*60)

        # Make sure the checkpoint can be written even if nobody created the
        # target directory beforehand.
        checkpoint_dir = os.path.dirname(Config.MODEL_SAVE_PATH)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)

        for epoch in range(num_epochs):
            train_loss, train_acc = self.train_epoch()
            val_loss, val_acc, _, _ = self.validate()

            self.scheduler.step(val_loss)
            current_lr = float(self.optimizer.param_groups[0]['lr'])

            if not self._scheduler_supports_verbose and current_lr < self._last_lr:
                print(f"[LR Scheduler] Reduced LR: {self._last_lr:.6f} -> {current_lr:.6f}")
            self._last_lr = current_lr

            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['learning_rates'].append(current_lr)

            print(f'Epoch {epoch+1}/{num_epochs}')
            print(f'  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}')
            print(f'  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}')
            print(f'  LR: {current_lr:.6f}')

            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.patience_counter = 0
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_loss': val_loss,
                    'val_acc': val_acc,
                }, Config.MODEL_SAVE_PATH)
                print(f'  ✓ Model saved (Val Loss improved)')
            else:
                self.patience_counter += 1
                if self.patience_counter >= Config.EARLY_STOPPING_PATIENCE:
                    print(f'\nEarly stopping triggered after {epoch+1} epochs')
                    break

            print()

        return self.history



class UnsupervisedAnalyzer:
    """Feature extraction, clustering and cluster plotting for an image folder.

    Args:
        image_dir: Directory scanned (non-recursively) for image files.
    """

    # File suffixes treated as images (compared case-insensitively).
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.tiff', '.bmp')
    # Columns of ``features_df`` that are bookkeeping, not model input.
    _NON_FEATURE_COLUMNS = ('filename', 'cluster')

    def __init__(self, image_dir):
        self.image_dir = image_dir
        self.feature_extractor = PressureMapFeatureExtractor()
        self.features_df = None
        self.scaler = StandardScaler()

    @staticmethod
    def _write_csv(frame, path):
        """Write ``frame`` to ``path``, creating the parent directory if needed."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        try:
            # BOM-prefixed UTF-8 first; fall back to the pandas default
            # encoding as a best-effort second attempt.
            frame.to_csv(path, index=False, encoding='utf-8-sig')
        except Exception:
            frame.to_csv(path, index=False)

    def extract_features_from_dataset(self):
        """Extract features for every image in ``image_dir`` and save them as CSV.

        Files are processed in sorted order so the row order is reproducible.
        Images whose extraction fails are skipped.

        Returns:
            The resulting ``DataFrame`` (also stored in ``self.features_df``);
            it is empty when no image produced features.
        """
        image_files = sorted(
            f for f in os.listdir(self.image_dir)
            if f.lower().endswith(self._IMAGE_EXTENSIONS)
        )

        features_list = []

        print(f"Extracting features from {len(image_files)} images...")
        for img_file in image_files:
            img_path = os.path.join(self.image_dir, img_file)
            features = self.feature_extractor.extract_all_features(img_path)
            if features is not None:
                features['filename'] = img_file
                features_list.append(features)

        self.features_df = pd.DataFrame(features_list)
        print(f"Extracted features from {len(features_list)} images")

        features_path = os.path.join(Config.FEATURES_DIR, 'extracted_features.csv')
        self._write_csv(self.features_df, features_path)
        print(f"Features saved to {features_path}")

        return self.features_df

    def perform_clustering(self, n_clusters=3, method='kmeans'):
        """Cluster the standardised features and store the labels in ``features_df``.

        Args:
            n_clusters: Number of clusters for ``'kmeans'`` (ignored by DBSCAN).
            method: ``'kmeans'`` or ``'dbscan'``.

        Returns:
            ``(labels, X_scaled)``, or ``None`` when no features are available.

        Raises:
            ValueError: If ``method`` is not a supported name.
        """
        if self.features_df is None or self.features_df.empty:
            print("No features available. Extract features first.")
            return None

        # Exclude 'cluster' as well as 'filename': a previous run stores its
        # labels in that column, and feeding them back in as a feature would
        # bias every subsequent clustering.
        feature_cols = [col for col in self.features_df.columns
                        if col not in self._NON_FEATURE_COLUMNS]
        X = self.features_df[feature_cols].values
        X_scaled = self.scaler.fit_transform(X)

        if method == 'kmeans':
            clusterer = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        elif method == 'dbscan':
            clusterer = DBSCAN(eps=0.5, min_samples=5)
        else:
            raise ValueError(f"Unknown method: {method}")

        labels = clusterer.fit_predict(X_scaled)

        self.features_df['cluster'] = labels

        # The silhouette score is only defined for at least two distinct labels.
        if len(set(labels)) > 1:
            try:
                sil_score = silhouette_score(X_scaled, labels)
                print(f"Silhouette Score: {sil_score:.3f}")
            except Exception as e:
                # Purely informational metric: report and carry on.
                print(f"Silhouette score unavailable: {e}")

        results_path = os.path.join(Config.RESULTS_DIR, f'clustering_{method}_results.csv')
        self._write_csv(self.features_df, results_path)
        print(f"Clustering results saved to {results_path}")

        return labels, X_scaled

    def visualize_clusters(self, X_scaled, labels):
        """Save a PCA scatter plot and a cluster-size bar chart as a PNG.

        Args:
            X_scaled: Standardised feature matrix, as returned by
                ``perform_clustering``.
            labels: Cluster label per row of ``X_scaled``.

        Returns:
            The 2-D PCA projection of ``X_scaled``.
        """
        pca = PCA(n_components=2)
        X_pca = pca.fit_transform(X_scaled)

        plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        scatter = plt.scatter(X_pca[:, 0], X_pca[:, 1], c=labels,
                            cmap='viridis', alpha=0.6, edgecolors='k')
        plt.colorbar(scatter, label='Cluster')
        plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
        plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
        plt.title('Cluster Visualization (PCA)')
        plt.grid(True, alpha=0.3)

        plt.subplot(1, 2, 2)
        unique, counts = np.unique(labels, return_counts=True)
        plt.bar(unique, counts, color='skyblue', edgecolor='black')
        plt.xlabel('Cluster')
        plt.ylabel('Count')
        plt.title('Cluster Distribution')
        plt.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        os.makedirs(Config.RESULTS_DIR, exist_ok=True)
        plot_path = os.path.join(Config.RESULTS_DIR, 'clustering_visualization.png')
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Visualization saved to {plot_path}")

        return X_pca


class GradCAM:
    """Grad-CAM heat-map generator for one layer of a model.

    A forward hook on ``target_layer`` records the layer's output, and a
    tensor hook attached to that output records its gradient during the
    backward pass.  This replaces the deprecated module-level
    ``register_backward_hook``.

    Args:
        model: Network to explain.
        target_layer: Sub-module of ``model`` whose output is visualised;
            its output must be a single 4-D tensor.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Keep the handle so the hook can be detached again (remove_hooks).
        self._hook_handles = [
            target_layer.register_forward_hook(self.save_activation),
        ]

    def save_activation(self, module, input, output):
        """Forward hook: store the layer output and arm gradient capture."""
        self.activations = output.detach()
        # Forward passes under no_grad (plain inference) produce outputs
        # that cannot carry a hook; skip gradient capture for those.
        if output.requires_grad:
            output.register_hook(self._capture_gradient)

    def _capture_gradient(self, grad):
        """Tensor hook: forward the output gradient to ``save_gradient``."""
        self.save_gradient(self.target_layer, None, (grad,))

    def save_gradient(self, module, grad_input, grad_output):
        """Store the gradient with respect to the target layer's output."""
        self.gradients = grad_output[0].detach()

    def remove_hooks(self):
        """Detach the forward hook from the target layer."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def generate_cam(self, input_image, target_class=None):
        """Compute the class-activation map for the first sample of a batch.

        Args:
            input_image: Input batch for ``model``; only sample 0 is explained.
            target_class: Class index to explain; defaults to the predicted
                class of the batch (which requires a batch of size 1).

        Returns:
            ``(cam, target_class)`` where ``cam`` is a NumPy array scaled to
            the range 0..1.

        Raises:
            RuntimeError: If the hooks captured nothing, e.g. because
                ``target_layer`` did not take part in the forward pass.
        """
        self.model.eval()

        # Drop results of earlier calls so a stale gradient is never reused.
        self.gradients = None
        self.activations = None

        # Grad-CAM needs autograd even when the caller is inside no_grad().
        with torch.enable_grad():
            output = self.model(input_image)

            if target_class is None:
                target_class = output.argmax(dim=1).item()

            self.model.zero_grad()
            target = output[0, target_class]
            target.backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "Grad-CAM hooks captured no activations/gradients for the target layer"
            )

        # Channel weights: spatial average of the gradients.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)

        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.relu(cam)

        # Min-max normalise; the epsilon guards against an all-zero map.
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)

        return cam.squeeze().cpu().numpy(), target_class


class LabelingTool:
    """Tkinter window for assigning 'healthy' / 'unhealthy' labels to images.

    Labels are kept in ``self.labels`` keyed by the joined image path and are
    written to ``Config.LABELS_CSV`` by "Save & Exit".  Labels already present
    in that CSV are loaded at start-up.

    Keyboard shortcuts: ``1`` healthy, ``2`` unhealthy, space skip,
    left/right arrows navigate.

    Args:
        root: Tk root or toplevel window that hosts the tool.
        image_dir: Directory whose image files are offered for labeling.
    """

    def __init__(self, root, image_dir):
        self.root = root
        self.image_dir = image_dir
        self.current_idx = 0
        self.labels = {}
        self.current_image = None
        # A reference must be kept, otherwise Tk's image is garbage-collected
        # and the canvas goes blank.
        self.current_photo = None

        self.image_files = [f for f in os.listdir(image_dir)
                          if f.lower().endswith(('.jpg', '.jpeg', '.png', '.tiff', '.bmp'))]
        self.image_files.sort()

        if os.path.exists(Config.LABELS_CSV):
            try:
                df = pd.read_csv(Config.LABELS_CSV)
                self.labels = dict(zip(df['image_path'], df['label']))
                print(f"Loaded {len(self.labels)} existing labels")
            except Exception as e:
                # Best-effort: start with an empty label set, but say why so
                # a corrupt label file does not go unnoticed.
                print(f"Could not load existing labels from {Config.LABELS_CSV}: {e}")

        self.create_widgets()
        self.root.after(100, lambda: self.load_image(self.current_idx))

    def _count_labeled(self):
        """Return the number of stored labels with a truthy value."""
        return sum(1 for value in self.labels.values() if value)

    def create_widgets(self):
        """Build the window layout and register the keyboard shortcuts."""
        self.root.title("Plantar Pressure Image Labeling Tool")
        self.root.geometry("900x700")

        info_frame = ttk.Frame(self.root)
        info_frame.pack(fill=tk.X, padx=10, pady=5)

        self.progress_label = ttk.Label(info_frame,
                                       text=f"Image 0/{len(self.image_files)}",
                                       font=('Arial', 12, 'bold'))
        self.progress_label.pack(side=tk.LEFT)

        labeled_count = self._count_labeled()
        self.labeled_label = ttk.Label(info_frame,
                                       text=f"Labeled: {labeled_count}",
                                       font=('Arial', 12))
        self.labeled_label.pack(side=tk.RIGHT)

        img_frame = ttk.LabelFrame(self.root, text="Plantar Pressure Map")
        img_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=5)

        self.canvas = tk.Canvas(img_frame, bg='black', highlightthickness=1, highlightbackground='#cccccc')
        self.canvas.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)
        # Redraw at the new size whenever the canvas is resized.
        self.canvas.bind("<Configure>", self._on_canvas_configure)

        self.filename_label = ttk.Label(self.root, text="",
                                       font=('Arial', 10))
        self.filename_label.pack(pady=2)

        button_frame = ttk.LabelFrame(self.root, text="Classification")
        button_frame.pack(fill=tk.X, padx=10, pady=5)

        btn_frame_inner = ttk.Frame(button_frame)
        btn_frame_inner.pack(pady=10)

        ttk.Button(btn_frame_inner, text="Normal", command=lambda: self.label_current('healthy')).pack(side=tk.LEFT, padx=10)
        ttk.Button(btn_frame_inner, text="Abnormal", command=lambda: self.label_current('unhealthy')).pack(side=tk.LEFT, padx=10)
        ttk.Button(btn_frame_inner, text="Skip", command=self.skip_image).pack(side=tk.LEFT, padx=10)

        nav_frame = ttk.Frame(self.root)
        nav_frame.pack(fill=tk.X, padx=10, pady=5)

        ttk.Button(nav_frame, text="◄ Previous", command=self.prev_image).pack(side=tk.LEFT, padx=5)
        ttk.Button(nav_frame, text="Next ►", command=self.next_image).pack(side=tk.LEFT, padx=5)
        ttk.Button(nav_frame, text="Save & Exit", command=self.save_and_exit).pack(side=tk.RIGHT, padx=5)

        self.root.bind('1', lambda e: self.label_current('healthy'))
        self.root.bind('2', lambda e: self.label_current('unhealthy'))
        self.root.bind('<space>', lambda e: self.skip_image())
        self.root.bind('<Left>', lambda e: self.prev_image())
        self.root.bind('<Right>', lambda e: self.next_image())

    def _on_canvas_configure(self, event):
        """Redraw the current image after the canvas changed size."""
        if self.current_image is not None:
            self._display_image_on_canvas(self.current_image)

    def _refresh_filename_label(self, filename, img_path):
        """Show ``filename`` together with the label stored for ``img_path``."""
        current_label = self.labels.get(img_path, "Not labeled")
        self.filename_label.config(text=f"{filename} | Current: {current_label}")

    def load_image(self, idx):
        """Make image ``idx`` current and display it; out-of-range is a no-op."""
        if idx < 0 or idx >= len(self.image_files):
            return

        self.current_idx = idx
        filename = self.image_files[idx]
        img_path = os.path.join(self.image_dir, filename)

        try:
            # convert() yields an in-memory copy, so the file handle can be
            # released right away.
            with Image.open(img_path) as source:
                pil_image = source.convert('RGB')
            self.current_image = pil_image

            self._display_image_on_canvas(pil_image)

            self.progress_label.config(text=f"Image {idx+1}/{len(self.image_files)}")

            self._refresh_filename_label(filename, img_path)

        except Exception as e:
            messagebox.showerror("Error", f"Failed to load image: {e}")
            self.current_image = None
            self.canvas.delete("all")

    def _display_image_on_canvas(self, pil_image):
        """Draw ``pil_image`` scaled to fit the canvas, anchored top-left."""
        canvas_width = max(1, self.canvas.winfo_width())
        canvas_height = max(1, self.canvas.winfo_height())

        # The canvas reports a tiny size until it has been laid out; try
        # again shortly instead of drawing a 1-pixel thumbnail.
        if canvas_width < 10 or canvas_height < 10:
            self.root.after(100, lambda: self._retry_display(pil_image))
            return

        max_w = max(1, canvas_width - 10)
        max_h = max(1, canvas_height - 10)

        # thumbnail() resizes in place, so work on a copy to keep the
        # full-resolution image for later redraws.
        image_copy = pil_image.copy()
        image_copy.thumbnail((max_w, max_h), RESAMPLE_LANCZOS)

        self.current_photo = ImageTk.PhotoImage(image_copy)

        self.canvas.delete("all")
        self.canvas.create_image(0, 0, image=self.current_photo, anchor='nw')

    def _retry_display(self, pil_image):
        """Delayed redraw; skipped if the user has navigated away meanwhile."""
        if pil_image is self.current_image:
            self._display_image_on_canvas(pil_image)

    def label_current(self, label):
        """Store ``label`` for the current image and advance to the next one."""
        # Nothing to label when the directory contained no images.
        if not (0 <= self.current_idx < len(self.image_files)):
            return

        filename = self.image_files[self.current_idx]
        img_path = os.path.join(self.image_dir, filename)
        self.labels[img_path] = label

        labeled_count = self._count_labeled()
        self.labeled_label.config(text=f"Labeled: {labeled_count}")

        if self.current_idx < len(self.image_files) - 1:
            self.next_image()
        else:
            # On the last image there is nothing to advance to, so refresh
            # the status text here; otherwise it keeps showing the old label.
            self._refresh_filename_label(filename, img_path)

    def skip_image(self):
        """Advance to the next image without storing a label."""
        self.next_image()

    def next_image(self):
        """Show the following image, if there is one."""
        if self.current_idx < len(self.image_files) - 1:
            self.load_image(self.current_idx + 1)

    def prev_image(self):
        """Show the preceding image, if there is one."""
        if self.current_idx > 0:
            self.load_image(self.current_idx - 1)

    def save_and_exit(self):
        """Write all labels to ``Config.LABELS_CSV`` and close the window.

        When there are no labels a warning is shown and the window stays open.
        """
        if not self.labels:
            messagebox.showwarning("Warning", "No labels to save!")
            return

        data = [{'image_path': img_path, 'label': label}
                for img_path, label in self.labels.items()]

        df = pd.DataFrame(data)
        try:
            df.to_csv(Config.LABELS_CSV, index=False, encoding='utf-8-sig')
        except Exception:
            # Best-effort second attempt with the pandas default encoding.
            df.to_csv(Config.LABELS_CSV, index=False)

        messagebox.showinfo("Success", f"Saved {len(self.labels)} labels to {Config.LABELS_CSV}")
        self.root.destroy()


class ProfessionalPlantarApp:
    """Main Tkinter window of the plantar pressure analysis application.

    The window has File / Analysis / Model menus, an "Image Prediction" tab,
    a "Dataset Analysis" tab and a status bar.  Feature extraction,
    clustering and training run on daemon worker threads.  Those threads
    never call into Tk themselves: they queue their UI updates with
    ``_run_on_ui_thread`` and the queue is drained by ``_drain_ui_tasks`` on
    the thread that runs the Tk main loop.
    """

    # Interval, in milliseconds, between two drains of the UI task queue.
    _UI_POLL_MS = 100

    def __init__(self, root):
        """Build the interface and load a previously saved model, if any.

        Args:
            root: Tk root (or toplevel) window that hosts the application.
        """
        self.root = root
        self.root.title("Professional Plantar Pressure Analysis System")
        self.root.geometry("1200x800")

        Config.create_directories()

        self.model = None
        self.unsupervised_analyzer = UnsupervisedAnalyzer(Config.DATASET_DIR)
        self.feature_extractor = PressureMapFeatureExtractor()

        # Inference-time preprocessing (no augmentation).
        self.transform = transforms.Compose([
            transforms.Resize(Config.IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # UI callbacks queued by worker threads, guarded by a lock.
        self._ui_lock = threading.Lock()
        self._ui_tasks = []

        self.create_gui()
        self.load_model_if_exists()

        # Start the periodic drain of the UI task queue.
        self._drain_ui_tasks()

    # ------------------------------------------------------------------
    # Worker-thread -> UI-thread hand-off
    # ------------------------------------------------------------------
    def _run_on_ui_thread(self, func, *args):
        """Queue ``func(*args)`` for execution on the Tk main-loop thread.

        Safe to call from any thread; it only appends to a lock-protected
        list and performs no Tk call itself.
        """
        with self._ui_lock:
            self._ui_tasks.append((func, args))

    def _drain_ui_tasks(self):
        """Run all queued UI callbacks, then re-schedule itself.

        Must run on the Tk main-loop thread.  A failing callback is printed
        with its traceback and does not prevent the remaining callbacks from
        running.
        """
        with self._ui_lock:
            pending, self._ui_tasks = self._ui_tasks, []

        for func, args in pending:
            try:
                func(*args)
            except Exception:
                traceback.print_exc()

        try:
            self.root.after(self._UI_POLL_MS, self._drain_ui_tasks)
        except tk.TclError:
            # The window was destroyed while a callback was running (e.g.
            # during a modal dialog); there is nothing left to update.
            pass

    def _show_analysis(self, summary, status):
        """Replace the analysis tab text with ``summary`` and set the status bar."""
        self.analysis_text.delete('1.0', tk.END)
        self.analysis_text.insert('1.0', summary)
        self.status_var.set(status)

    def _report_error(self, status, message):
        """Show ``status`` in the status bar and ``message`` in an error dialog."""
        self.status_var.set(status)
        messagebox.showerror("Error", message)

    def _on_training_complete(self):
        """Announce the end of training and reload the model from disk."""
        self.status_var.set("Training complete!")
        messagebox.showinfo("Success", "Model training complete!")
        # NOTE(review): assumes TrainingEngine wrote its checkpoint to
        # Config.MODEL_SAVE_PATH -- confirm against TrainingEngine.
        self.load_model_if_exists()

    # ------------------------------------------------------------------
    # GUI construction
    # ------------------------------------------------------------------
    def create_gui(self):
        """Create the menu bar, the notebook with both tabs and the status bar."""
        menubar = tk.Menu(self.root)
        self.root.config(menu=menubar)

        file_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="File", menu=file_menu)
        file_menu.add_command(label="Label Images", command=self.open_labeling_tool)
        file_menu.add_separator()
        file_menu.add_command(label="Exit", command=self.root.quit)

        analysis_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="Analysis", menu=analysis_menu)
        analysis_menu.add_command(label="Extract Features",
                                  command=self.extract_features)
        analysis_menu.add_command(label="Cluster Analysis",
                                  command=self.run_clustering)
        analysis_menu.add_command(label="View Results",
                                  command=self.view_results)

        model_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="Model", menu=model_menu)
        model_menu.add_command(label="Train Model", command=self.train_model)
        model_menu.add_command(label="Load Model", command=self.load_model_if_exists)

        self.notebook = ttk.Notebook(self.root)
        self.notebook.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        self.create_prediction_tab()
        self.create_analysis_tab()

        self.status_var = tk.StringVar(value="Ready")
        status_bar = ttk.Label(self.root, textvariable=self.status_var,
                               relief=tk.SUNKEN, anchor=tk.W)
        status_bar.pack(side=tk.BOTTOM, fill=tk.X)

    def create_prediction_tab(self):
        """Create the "Image Prediction" tab: controls, image canvas and results."""
        pred_frame = ttk.Frame(self.notebook)
        self.notebook.add(pred_frame, text="Image Prediction")

        control_frame = ttk.LabelFrame(pred_frame, text="Controls")
        control_frame.pack(fill=tk.X, padx=10, pady=5)

        ttk.Button(control_frame, text="Load Image",
                   command=self.load_and_predict).pack(side=tk.LEFT, padx=5, pady=5)
        ttk.Button(control_frame, text="Batch Predict",
                   command=self.batch_predict).pack(side=tk.LEFT, padx=5, pady=5)

        content_frame = ttk.Frame(pred_frame)
        content_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=5)

        left_frame = ttk.LabelFrame(content_frame, text="Image")
        left_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5)

        self.image_canvas = tk.Canvas(left_frame, bg='black')
        self.image_canvas.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)
        self.image_canvas.bind("<Configure>", lambda e: None)  # keep to allow future resizing behavior

        right_frame = ttk.LabelFrame(content_frame, text="Results")
        right_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True, padx=5)

        result_frame = ttk.Frame(right_frame)
        result_frame.pack(fill=tk.X, pady=10, padx=10)

        ttk.Label(result_frame, text="Classification:",
                  font=('Arial', 11, 'bold')).pack(anchor=tk.W)
        self.pred_result_var = tk.StringVar(value="No prediction")
        self.pred_result_label = ttk.Label(result_frame,
                                           textvariable=self.pred_result_var,
                                           font=('Arial', 14, 'bold'))
        self.pred_result_label.pack(anchor=tk.W, pady=5)

        conf_frame = ttk.LabelFrame(right_frame, text="Confidence")
        conf_frame.pack(fill=tk.X, pady=10, padx=10)

        self.conf_text = tk.Text(conf_frame, height=3, font=('Consolas', 10))
        self.conf_text.pack(fill=tk.X, padx=5, pady=5)

        feat_frame = ttk.LabelFrame(right_frame, text="Extracted Features")
        feat_frame.pack(fill=tk.BOTH, expand=True, pady=10, padx=10)

        self.feat_text = tk.Text(feat_frame, font=('Consolas', 9))
        feat_scrollbar = ttk.Scrollbar(feat_frame, command=self.feat_text.yview)
        self.feat_text.configure(yscrollcommand=feat_scrollbar.set)
        feat_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
        self.feat_text.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

    def create_analysis_tab(self):
        """Create the "Dataset Analysis" tab: one scrollable text area."""
        analysis_frame = ttk.Frame(self.notebook)
        self.notebook.add(analysis_frame, text="Dataset Analysis")

        self.analysis_text = tk.Text(analysis_frame, font=('Consolas', 10))
        scrollbar = ttk.Scrollbar(analysis_frame, command=self.analysis_text.yview)
        self.analysis_text.configure(yscrollcommand=scrollbar.set)

        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
        self.analysis_text.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

    # ------------------------------------------------------------------
    # Menu actions
    # ------------------------------------------------------------------
    def open_labeling_tool(self):
        """Open the labeling tool for ``Config.DATASET_DIR`` in a new window."""
        labeling_window = tk.Toplevel(self.root)
        LabelingTool(labeling_window, Config.DATASET_DIR)

    def extract_features(self):
        """Extract dataset features on a worker thread and show a summary.

        The summary (row count, column count minus one, and ``describe()``
        output) replaces the analysis tab text.  Errors are reported in the
        status bar and an error dialog.
        """
        self.status_var.set("Extracting features...")
        # Flush the status-bar redraw; update_idletasks() does not process
        # user events, so it cannot re-enter this handler the way update() can.
        self.root.update_idletasks()

        def extract_thread():
            try:
                # presumably a pandas DataFrame -- it is used via len(),
                # .columns and .describe() below.
                features_df = self.unsupervised_analyzer.extract_features_from_dataset()

                summary = "\n".join([
                    "Feature Extraction Complete",
                    "=" * 50,
                    "",
                    f"Total images processed: {len(features_df)}",
                    f"Total features: {len(features_df.columns)-1}",
                    "",
                    "Feature Statistics:",
                    "-" * 50,
                    features_df.describe().to_string(),
                ])
            except Exception as e:
                # Format the messages now: ``e`` is unbound once the except
                # block ends, long before the UI thread runs the callback.
                self._run_on_ui_thread(self._report_error, f"Error: {e}", str(e))
            else:
                self._run_on_ui_thread(self._show_analysis, summary,
                                       "Feature extraction complete")

        threading.Thread(target=extract_thread, daemon=True).start()

    def run_clustering(self):
        """Cluster the dataset on a worker thread and show the distribution.

        Uses ``Config.N_CLUSTERS`` clusters.  The per-cluster counts and an
        interpretation note replace the analysis tab text, followed by a
        success dialog.  Errors are reported in the status bar and an error
        dialog.
        """
        self.status_var.set("Running clustering...")
        self.root.update_idletasks()

        def cluster_thread():
            try:
                labels, X_scaled = self.unsupervised_analyzer.perform_clustering(
                    n_clusters=Config.N_CLUSTERS
                )

                # Called for its side effects; the returned projection is
                # not needed here.
                self.unsupervised_analyzer.visualize_clusters(X_scaled, labels)

                label_array = np.asarray(labels)
                total = len(label_array)

                cluster_lines = []
                for cluster_id in range(Config.N_CLUSTERS):
                    count = int(np.sum(label_array == cluster_id))
                    # Guard against an empty label set.
                    percentage = (count / total) * 100 if total else 0.0
                    cluster_lines.append(
                        f"Cluster {cluster_id}: {count} images ({percentage:.1f}%)"
                    )

                summary = "\n".join([
                    "Clustering Analysis Complete",
                    "=" * 50,
                    "",
                    "Method: K-Means",
                    f"Number of clusters: {Config.N_CLUSTERS}",
                    "",
                    "Cluster Distribution:",
                    "-" * 50,
                    *cluster_lines,
                    "",
                    "",
                    "Interpretation:",
                    "-" * 50,
                    f"Images have been grouped into {Config.N_CLUSTERS} distinct patterns.",
                    "Review 'clustering_kmeans_results.csv' to see which images",
                    "belong to each cluster. This can help identify:",
                    "  - Normal vs abnormal patterns",
                    "  - Different types of gait abnormalities",
                    "  - Outliers or unusual cases",
                    "",
                ])
            except Exception as e:
                self._run_on_ui_thread(self._report_error, f"Error: {e}", str(e))
            else:
                self._run_on_ui_thread(self._show_analysis, summary,
                                       "Clustering complete - check results folder")
                self._run_on_ui_thread(
                    messagebox.showinfo, "Success",
                    "Clustering complete! Check 'results' folder for visualizations."
                )

        threading.Thread(target=cluster_thread, daemon=True).start()

    def view_results(self):
        """Open ``Config.RESULTS_DIR`` in the file manager, or show its path."""
        results_dir = Config.RESULTS_DIR
        if os.path.exists(results_dir):
            try:
                os.startfile(results_dir)  # Windows
            except (AttributeError, OSError):
                # os.startfile only exists on Windows (AttributeError
                # elsewhere); OSError covers a failed launch.
                messagebox.showinfo("Info", f"Results path: {os.path.abspath(results_dir)}")
        else:
            messagebox.showinfo("Info", "No results available yet")

    def load_model_if_exists(self):
        """Load the checkpoint at ``Config.MODEL_SAVE_PATH`` into ``self.model``.

        ``self.model`` is replaced only after the checkpoint has loaded
        completely, so a failed load never leaves an untrained network in
        place of a usable model (or of ``None``).

        Returns:
            True when a model was loaded, False when the file is missing or
            loading failed.  The status bar reports the outcome either way.
        """
        if not os.path.exists(Config.MODEL_SAVE_PATH):
            self.status_var.set("No trained model found")
            return False

        try:
            model = ImprovedPlantarCNN().to(Config.DEVICE)
            checkpoint = torch.load(Config.MODEL_SAVE_PATH, map_location=Config.DEVICE)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            status = f"Model loaded successfully (Val Acc: {checkpoint.get('val_acc', 0):.2%})"
        except Exception as e:
            self.status_var.set(f"Error loading model: {e}")
            return False

        self.model = model
        self.status_var.set(status)
        return True

    def train_model(self):
        """Train a new model on a worker thread after user confirmation.

        Requires ``Config.LABELS_CSV`` to exist.  The labeled rows are split
        80/20 into training and validation subsets using a shuffle seeded
        with 42.  On success the saved model is reloaded; errors are printed
        with a traceback and reported in the status bar and an error dialog.
        """
        if not os.path.exists(Config.LABELS_CSV):
            messagebox.showerror("Error",
                                 "No labels file found! Please label images first.")
            return

        response = messagebox.askyesno("Train Model",
                                       "This will start training. Continue?")
        if not response:
            return

        self.status_var.set("Training started...")

        def train_thread():
            try:
                train_transform = transforms.Compose([
                    transforms.Resize(Config.IMAGE_SIZE),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomRotation(15),
                    transforms.ColorJitter(brightness=0.2, contrast=0.2),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
                ])

                val_transform = transforms.Compose([
                    transforms.Resize(Config.IMAGE_SIZE),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
                ])

                df = pd.read_csv(Config.LABELS_CSV)
                total_len = len(df)
                if total_len < 2:
                    self._run_on_ui_thread(
                        self._report_error,
                        "Training aborted: insufficient data",
                        "Not enough labeled examples to train.",
                    )
                    return

                train_size = int(0.8 * total_len)
                indices = list(range(total_len))
                np.random.seed(42)
                np.random.shuffle(indices)
                train_indices = indices[:train_size]
                val_indices = indices[train_size:]

                # Two dataset instances so each subset gets its own transform.
                train_dataset = Subset(PlantarDataset(Config.LABELS_CSV, transform=train_transform), train_indices)
                val_dataset = Subset(PlantarDataset(Config.LABELS_CSV, transform=val_transform), val_indices)

                train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE,
                                          shuffle=True, num_workers=0, pin_memory=False)
                val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE,
                                        num_workers=0, pin_memory=False)

                model = ImprovedPlantarCNN(pretrained=True)

                engine = TrainingEngine(model, train_loader, val_loader, Config.DEVICE)
                engine.train(Config.EPOCHS)

                self._run_on_ui_thread(self._on_training_complete)

            except Exception as e:
                print("Training error:", e)
                traceback.print_exc()
                self._run_on_ui_thread(self._report_error,
                                       f"Training error: {e}", str(e))

        threading.Thread(target=train_thread, daemon=True).start()

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------
    def load_and_predict(self):
        """Let the user pick an image, classify it and display the results.

        Shows the image on the canvas, the predicted class (green for class
        0, red otherwise), both class probabilities, and the features
        returned by the feature extractor.  Requires a loaded model.
        """
        if self.model is None:
            messagebox.showerror("Error", "No model loaded!")
            return

        # A tuple of patterns is portable; a ';'-joined string is treated
        # as a single glob pattern outside Windows.
        file_path = filedialog.askopenfilename(
            filetypes=[("Image files",
                        ("*.jpg", "*.jpeg", "*.png", "*.tiff", "*.bmp"))]
        )
        if not file_path:
            return

        try:
            image = Image.open(file_path).convert('RGB')
            display_img = image.copy()
            canvas_width = max(1, self.image_canvas.winfo_width())
            canvas_height = max(1, self.image_canvas.winfo_height())
            if canvas_width > 10 and canvas_height > 10:
                display_img.thumbnail((canvas_width-10, canvas_height-10), RESAMPLE_LANCZOS)

            photo = ImageTk.PhotoImage(display_img)

            self.image_canvas.delete("all")
            # Keep a reference on the widget so the image is not
            # garbage-collected while displayed.
            self.image_canvas.image = photo
            self.image_canvas.create_image(0, 0, image=photo, anchor='nw')

            input_tensor = self.transform(image).unsqueeze(0).to(Config.DEVICE)

            with torch.no_grad():
                outputs = self.model(input_tensor)
                probabilities = torch.softmax(outputs, dim=1)[0]

            classes = ['Normal (Healthy)', 'Abnormal (Unhealthy)']
            pred_class = int(probabilities.argmax().item())
            confidence = float(probabilities[pred_class].item() * 100)

            self.pred_result_var.set(classes[pred_class])
            color = "green" if pred_class == 0 else "red"
            self.pred_result_label.config(foreground=color)

            conf_text = (
                f"Normal:   {probabilities[0].item()*100:.2f}%\n"
                f"Abnormal: {probabilities[1].item()*100:.2f}%\n"
                f"Confidence: {confidence:.2f}%"
            )

            self.conf_text.delete('1.0', tk.END)
            self.conf_text.insert('1.0', conf_text)

            features = self.feature_extractor.extract_all_features(file_path)
            if features:
                feat_lines = ["Key Features:", "=" * 40, ""]
                for key, value in sorted(features.items()):
                    if key == 'filename':
                        continue
                    # Numeric values get four decimals; anything that cannot
                    # be converted to float is shown as-is.
                    try:
                        rendered = f"{float(value):.4f}"
                    except (TypeError, ValueError):
                        rendered = f"{value}"
                    feat_lines.append(f"{key:30s}: {rendered}")
                feat_lines.append("")

                self.feat_text.delete('1.0', tk.END)
                self.feat_text.insert('1.0', "\n".join(feat_lines))

            self.status_var.set(f"Prediction complete: {classes[pred_class]}")

        except Exception as e:
            self.status_var.set(f"Error: {e}")
            messagebox.showerror("Error", str(e))

    def batch_predict(self):
        """Ask for a folder of images; batch prediction itself is a placeholder."""
        if self.model is None:
            messagebox.showerror("Error", "No model loaded!")
            return

        folder_path = filedialog.askdirectory(title="Select folder with images")
        if not folder_path:
            return

        messagebox.showinfo("Info", "Batch prediction feature - coming soon!")

if __name__ == "__main__":
    # Best-effort selection of the 'spawn' start method; any failure
    # (for instance a start method that is already fixed) is ignored.
    try:
        import multiprocessing as mp
        mp.set_start_method('spawn', force=False)
    except Exception:
        pass

    # Startup banner.
    rule = "=" * 60
    startup_lines = (
        rule,
        "Professional Plantar Pressure Analysis System",
        rule,
        f"Device: {Config.DEVICE}",
        f"Dataset: {Config.DATASET_DIR}",
        rule,
    )
    for startup_line in startup_lines:
        print(startup_line)

    # Make sure the output folders exist before the window is created.
    Config.create_directories()

    # Build the main window and hand control to the Tk event loop.
    root = tk.Tk()
    app = ProfessionalPlantarApp(root)
    root.mainloop()


Professional Plantar Pressure Analysis System
Device: cuda
Dataset: C:\Users\koust\AnaKonda\FOOT_PLANTAR_CLASSIFICATION\Dataset
Loaded 202 existing labels
Extracting features from 202 images...
Extracted features from 202 images
Features saved to features\extracted_features.csv
Silhouette Score: 0.160
Clustering results saved to results\clustering_kmeans_results.csv
Visualization saved to results\clustering_visualization.png

Training on cuda
Train size: 161
Val size: 41
Epoch 1/100
  Train Loss: 0.7161 | Train Acc: 0.4596
  Val Loss: 0.7031 | Val Acc: 0.4634
  LR: 0.000100
  ✓ Model saved (Val Loss improved)

Epoch 2/100
  Train Loss: 0.6995 | Train Acc: 0.5404
  Val Loss: 0.7065 | Val Acc: 0.4390
  LR: 0.000100

Epoch 3/100
  Train Loss: 0.6774 | Train Acc: 0.5404
  Val Loss: 0.6872 | Val Acc: 0.5122
  LR: 0.000100
  ✓ Model saved (Val Loss improved)

Epoch 4/100
  Train Loss: 0.6883 | Train Acc: 0.5714
  Val Loss: 0.6754 | Val Acc: 0.7073
  LR: 0.000100
  ✓ Model saved (Val Loss imp