In [None]:
import os
import pickle
import time
import warnings
from io import BytesIO

import lime
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from lime import lime_image
from PIL import Image
from skimage.segmentation import mark_boundaries, slic, quickshift, felzenszwalb
from torchvision.models import resnet50, ResNet50_Weights
# Suppress every warning for the rest of the session so cell output stays short.
# NOTE(review): this is a blanket filter, so deprecation warnings raised by the
# imported libraries are hidden as well.
warnings.filterwarnings('ignore')

In [None]:
def optimized_slic_segmentation(image):
    """Segment ``image`` with scikit-image's SLIC using this notebook's fixed settings.

    Args:
        image: The image to segment; forwarded unchanged to ``slic``.

    Returns:
        Whatever ``slic`` returns for the settings below (labels start at 1
        because ``start_label=1``).
    """
    slic_settings = {
        "n_segments": 100,
        "compactness": 15,
        "sigma": 1,
        "start_label": 1,
        "max_num_iter": 15,
    }
    segments = slic(image, **slic_settings)
    return segments

def quick_felzenszwalb_segmentation(image):
    """Segment ``image`` with scikit-image's Felzenszwalb method using fixed settings.

    Args:
        image: The image to segment; forwarded unchanged to ``felzenszwalb``.

    Returns:
        Whatever ``felzenszwalb`` returns for the settings below.
    """
    felzenszwalb_settings = {
        "scale": 150,
        "sigma": 0.8,
        "min_size": 30,
    }
    segments = felzenszwalb(image, **felzenszwalb_settings)
    return segments

class OptimizedLIMEExplainer:
    """LIME image-explanation workflow built around a pre-trained ResNet50.

    The class downloads a fixed set of sample images, explains the model's
    top prediction for each with ``lime_image.LimeImageExplainer``, plots the
    result, writes the per-image LIME parameters to a pickle file and can
    upload that file to a scoring server.
    """

    # Size of the model's output layer; also the length of the fallback label list.
    NUM_CLASSES = 1000
    # Seconds to wait on any HTTP request before giving up.
    REQUEST_TIMEOUT = 30
    # Number of images pushed through the model per forward pass in classifier_fn.
    CLASSIFIER_BATCH_SIZE = 32
    # Endpoint that receives the pickled parameters.
    SUBMIT_URL = "http://34.122.51.94:9091/lime"
    # Repository the sample images are downloaded from.
    IMAGE_BASE_URL = "https://github.com/EliSchwartz/imagenet-sample-images/raw/master"

    # Image name -> file-name prefix in the sample-image repository. The
    # insertion order is the order in which images are processed and in which
    # parameter dictionaries are emitted.
    _IMAGE_WNIDS = {
        'West_Highland_white_terrier': 'n02098286',
        'American_coot': 'n02018207',
        'racer': 'n04037443',
        'flamingo': 'n02007558',
        'kite': 'n01608432',
        'goldfish': 'n01443537',
        'tiger_shark': 'n01491361',
        'vulture': 'n01616318',
        'common_iguana': 'n01677366',
        'orange': 'n07747607',
    }
    IMAGE_NAMES = tuple(_IMAGE_WNIDS)

    # Defaults shared by every image. Key order is kept stable so the pickled
    # dictionaries are laid out the same way on every run.
    _BASE_PARAMS = {
        "labels": (1,),
        "hide_color": 0,
        "top_labels": 1,
        "num_features": 10,
        "num_samples": 300,
        "batch_size": 32,
        "segmentation_fn": None,
        "distance_metric": "cosine",
        "model_regressor": None,
        "random_seed": 42,
    }

    # (image names, overrides on top of _BASE_PARAMS, local segmenter key).
    # hide_color=0 is the color LIME uses to hide switched-off superpixels
    # (it does not preserve the original color information).
    _TUNING_GROUPS = (
        # Simple objects - fewer samples are enough.
        (('goldfish', 'orange'),
         {"num_features": 8, "num_samples": 250},
         "slic"),
        # Clear subjects - moderate settings.
        (('West_Highland_white_terrier', 'American_coot', 'flamingo'),
         {"num_features": 10, "num_samples": 300},
         "slic"),
        # Harder scenes - more features and samples.
        (('vulture', 'tiger_shark', 'common_iguana'),
         {"num_features": 12, "num_samples": 350, "hide_color": 0},
         "felzenszwalb"),
        # Hardest image - most features/samples and a different distance metric.
        (('kite',),
         {"num_features": 14, "num_samples": 400, "hide_color": 0,
          "distance_metric": "euclidean"},
         "felzenszwalb"),
        # Vehicle - moderate complexity.
        (('racer',),
         {"num_features": 11, "num_samples": 320},
         "slic"),
    )

    def __init__(self):
        """Load the ResNet50 model, preprocessing, LIME explainer, labels and image URLs."""
        # Pre-trained ResNet50 in inference mode.
        self.model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.model.eval()

        # Preprocessing applied to every PIL image before it reaches the model.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        self.explainer = lime_image.LimeImageExplainer()

        # Populates self.class_labels (downloads, with a generic fallback).
        self.load_imagenet_labels()

        # Image name -> download URL, in IMAGE_NAMES order.
        self.image_urls = {
            name: f"{self.IMAGE_BASE_URL}/{wnid}_{name}.JPEG"
            for name, wnid in self._IMAGE_WNIDS.items()
        }

    def load_imagenet_labels(self):
        """Download the ImageNet class names into ``self.class_labels``.

        Falls back to generic ``class_<i>`` names when the download fails,
        returns an HTTP error, or does not contain exactly ``NUM_CLASSES``
        lines, so that indexing by a predicted class id is always valid.
        """
        url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
        try:
            response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            labels = [line.strip() for line in response.text.strip().split('\n')]
            if len(labels) != self.NUM_CLASSES:
                raise ValueError(
                    f"expected {self.NUM_CLASSES} labels, got {len(labels)}"
                )
            self.class_labels = labels
        except (requests.RequestException, ValueError) as e:
            print(f"Could not load ImageNet labels ({e}); using generic class names")
            self.class_labels = [f"class_{i}" for i in range(self.NUM_CLASSES)]

    def load_image(self, url):
        """Download ``url`` and return it as an RGB PIL image, or ``None`` on any failure."""
        try:
            response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
            # Without this an HTTP error page would be handed to PIL as image bytes.
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')
        except Exception as e:
            print(f"Error loading image from {url}: {e}")
            return None

    def classifier_fn(self, images):
        """Return class probabilities for a batch of images (LIME's classifier callback).

        Args:
            images: Sequence of image arrays; each one is cast to ``uint8``
                and converted to a PIL image before preprocessing.

        Returns:
            A NumPy array with one row of softmax probabilities per input
            image. Empty input yields an array of shape ``(0, NUM_CLASSES)``.
        """
        total = len(images)
        if total == 0:
            # A zero batch size would make range() raise; return an empty result instead.
            return np.empty((0, self.NUM_CLASSES), dtype=np.float32)

        all_probabilities = []
        for start in range(0, total, self.CLASSIFIER_BATCH_SIZE):
            chunk = images[start:start + self.CLASSIFIER_BATCH_SIZE]
            batch_tensor = torch.stack([
                self.transform(Image.fromarray(img.astype('uint8')))
                for img in chunk
            ])

            with torch.no_grad():
                outputs = self.model(batch_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
            all_probabilities.append(probabilities.cpu().numpy())

        return np.vstack(all_probabilities)

    def get_prediction(self, image):
        """Return ``(predicted_class_index, confidence)`` for a single PIL image."""
        input_tensor = self.transform(image).unsqueeze(0)

        with torch.no_grad():
            outputs = self.model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            predicted_class = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0][predicted_class].item()

        return predicted_class, confidence

    def _build_parameters(self, with_segmentation_fns):
        """Build the per-image parameter dictionaries from the tuning table.

        Args:
            with_segmentation_fns: When true, ``segmentation_fn`` is set to the
                notebook's module-level segmentation functions; otherwise it
                is left as ``None`` in every dictionary.

        Returns:
            Dict mapping each name in ``IMAGE_NAMES`` to its own (independent)
            parameter dictionary.
        """
        if with_segmentation_fns:
            # Resolved only on this path so the pickle-safe variant never
            # touches the function objects.
            segmenters = {
                "slic": optimized_slic_segmentation,
                "felzenszwalb": quick_felzenszwalb_segmentation,
            }
        else:
            segmenters = {}

        all_params = {}
        for name in self.IMAGE_NAMES:
            params = dict(self._BASE_PARAMS)
            params["segmentation_fn"] = segmenters.get("slic")
            for group_names, overrides, segmenter_key in self._TUNING_GROUPS:
                if name in group_names:
                    params.update(overrides)
                    params["segmentation_fn"] = segmenters.get(segmenter_key)
                    break
            all_params[name] = params

        return all_params

    def create_optimized_parameters(self):
        """Return per-image LIME parameters including custom segmentation functions.

        These dictionaries reference the notebook's module-level segmentation
        functions and are meant for local analysis.
        """
        return self._build_parameters(with_segmentation_fns=True)

    def create_pickle_safe_optimized_parameters(self):
        """Return per-image LIME parameters without function references.

        ``segmentation_fn`` is ``None`` in every dictionary, so the result
        contains only plain Python values. Note that this means the pickled
        parameters do not use the custom segmentation that
        ``create_optimized_parameters`` selects.
        """
        return self._build_parameters(with_segmentation_fns=False)

    def generate_lime_explanation(self, image, image_name, params):
        """Run LIME on one image and time it.

        Args:
            image: PIL image to explain.
            image_name: Name used in the progress messages.
            params: Keyword arguments forwarded to ``explain_instance``.

        Returns:
            Tuple ``(explanation, predicted_class, predicted_label,
            confidence, execution_time)`` where ``execution_time`` is the
            duration of the ``explain_instance`` call in seconds.
        """
        image_np = np.array(image)
        predicted_class, confidence = self.get_prediction(image)
        predicted_label = self.class_labels[predicted_class]

        print(f"Analyzing {image_name}:")
        print(f"Predicted class: {predicted_label} (confidence: {confidence:.3f})")

        # perf_counter is monotonic, so the measurement is unaffected by clock changes.
        start_time = time.perf_counter()
        explanation = self.explainer.explain_instance(
            image_np,
            self.classifier_fn,
            **params
        )
        execution_time = time.perf_counter() - start_time

        print(f"Explanation generated in {execution_time:.3f} seconds")

        return explanation, predicted_class, predicted_label, confidence, execution_time

    def visualize_lime_explanation(self, image, explanation, image_name, predicted_class,
                                   predicted_label, confidence, num_features=8):
        """Plot, save and summarise the LIME explanation for one image.

        Draws three panels (original image, positive superpixels outlined,
        and the top superpixels with everything else hidden), saves the figure
        as ``lime_explanation_<image_name>_optimized.png`` and prints the six
        features with the largest absolute weight.

        Args:
            image: The original PIL image.
            explanation: Result of ``explain_instance`` for this image.
            image_name: Used in the panel title and the output file name.
            predicted_class: Class index whose explanation is shown.
            predicted_label: Human-readable label for the title.
            confidence: Model confidence shown in the title.
            num_features: Number of superpixels requested from the explanation
                for both explanation panels.

        Returns:
            ``(temp, mask)`` from the positive-only ``get_image_and_mask`` call.
        """
        temp, mask = explanation.get_image_and_mask(
            predicted_class,
            positive_only=True,
            num_features=num_features,
            hide_rest=False
        )
        temp_neg, mask_neg = explanation.get_image_and_mask(
            predicted_class,
            positive_only=False,
            num_features=num_features,
            hide_rest=True
        )

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Original image
        axes[0].imshow(image)
        axes[0].set_title(f'Original Image\n{image_name}', fontsize=12)
        axes[0].axis('off')

        # LIME explanation
        axes[1].imshow(mark_boundaries(temp / 255.0, mask))
        axes[1].set_title(f'LIME Explanation\n{predicted_label}\nConf: {confidence:.3f}', fontsize=12)
        axes[1].axis('off')

        # Important regions only
        axes[2].imshow(mark_boundaries(temp_neg / 255.0, mask_neg))
        axes[2].set_title('Important Regions Only', fontsize=12)
        axes[2].axis('off')

        plt.tight_layout()
        plt.savefig(f'lime_explanation_{image_name}_optimized.png', dpi=150, bbox_inches='tight')
        plt.show()
        # This method is called once per image in a loop; release the figure
        # so open figures do not accumulate.
        plt.close(fig)

        # Print feature importance
        features = explanation.local_exp[predicted_class]
        sorted_features = sorted(features, key=lambda x: abs(x[1]), reverse=True)[:6]
        print(f"Top 6 feature importances: {[(f[0], f'{f[1]:.4f}') for f in sorted_features]}")
        print()

        return temp, mask

    def analyze_all_images_optimized(self, pickle_safe=False):
        """Explain and visualise every image in ``self.image_urls``.

        Images that fail to download or to process are reported and skipped.

        Args:
            pickle_safe: When true, use the parameters from
                ``create_pickle_safe_optimized_parameters`` (the ones that get
                pickled) instead of the ones with custom segmentation.

        Returns:
            Dict mapping image name to a result dictionary with the keys
            ``explanation``, ``predicted_class``, ``predicted_label``,
            ``confidence``, ``image``, ``visualization`` and ``execution_time``.
        """
        if pickle_safe:
            all_params = self.create_pickle_safe_optimized_parameters()
        else:
            all_params = self.create_optimized_parameters()

        results = {}
        total_time = 0

        for image_name, url in self.image_urls.items():
            print(f"Processing {image_name}...")

            image = self.load_image(url)
            if image is None:
                print(f"Failed to load {image_name}, skipping...")
                continue

            params = all_params[image_name].copy()
            # Drop labels/top_labels so explain_instance falls back to its own
            # defaults for them during local analysis.
            params.pop('labels', None)
            params.pop('top_labels', None)

            try:
                explanation, predicted_class, predicted_label, confidence, execution_time = self.generate_lime_explanation(
                    image, image_name, params
                )

                total_time += execution_time

                temp, mask = self.visualize_lime_explanation(
                    image, explanation, image_name, predicted_class, predicted_label, confidence
                )

                results[image_name] = {
                    'explanation': explanation,
                    'predicted_class': predicted_class,
                    'predicted_label': predicted_label,
                    'confidence': confidence,
                    'image': image,
                    'visualization': (temp, mask),
                    'execution_time': execution_time
                }

            except Exception as e:
                # Best effort: one failing image must not abort the whole run.
                print(f"Error processing {image_name}: {e}")
                continue

        avg_time = total_time / len(results) if results else 0
        print(f"\nAverage execution time per image: {avg_time:.3f} seconds")

        return results

    def save_optimized_parameters_pickle(self, filename="lime_parameters_optimized.pkl"):
        """Pickle the function-free parameters to ``filename``, verify the file and return them."""
        params = self.create_pickle_safe_optimized_parameters()

        with open(filename, "wb") as f:
            pickle.dump(params, f)

        print(f"Optimized parameters saved to {filename}")
        self.verify_pickle_file(filename)

        return params

    def verify_pickle_file(self, filename):
        """Reload ``filename`` and check that it holds exactly the expected image keys.

        Prints a short report, including the sample/feature counts of the
        harder images.

        Returns:
            ``True`` when the key set matches ``IMAGE_NAMES`` exactly,
            ``False`` when keys are missing/extra or the file cannot be read.
        """
        try:
            # NOTE: pickle.load can execute arbitrary code; only use this on
            # files written by this notebook, never on untrusted input.
            with open(filename, 'rb') as f:
                loaded_params = pickle.load(f)

            print(f"\nVerification of {filename}:")
            print(f"Number of entries: {len(loaded_params)}")

            expected_keys = self.IMAGE_NAMES
            missing_keys = [key for key in expected_keys if key not in loaded_params]
            extra_keys = [key for key in loaded_params if key not in expected_keys]
            keys_ok = not missing_keys and not extra_keys

            if keys_ok:
                print("✅ All required keys present, no extra keys")
            else:
                if missing_keys:
                    print(f"❌ Missing keys: {missing_keys}")
                if extra_keys:
                    print(f"❌ Extra keys: {extra_keys}")

            # Show parameters for problem images
            problem_images = ['vulture', 'tiger_shark', 'kite', 'common_iguana']
            print("\nParameters for problem images:")
            for img in problem_images:
                if img in loaded_params:
                    params = loaded_params[img]
                    print(f"{img}: samples={params.get('num_samples')}, features={params.get('num_features')}")

            # Only claim success when the key check actually passed.
            if keys_ok:
                print("✅ Optimized pickle file created successfully!")
            return keys_ok

        except Exception as e:
            print(f"Error verifying pickle file: {e}")
            return False

    def submit_to_server(self, pickle_file, token):
        """Upload ``pickle_file`` to the scoring server and print its reply.

        Args:
            pickle_file: Path of the pickle file to upload.
            token: Value sent in the ``token`` request header.

        Returns:
            The decoded JSON reply, or ``None`` when the request fails or the
            reply is not valid JSON.
        """
        try:
            with open(pickle_file, "rb") as f:
                response = requests.post(
                    self.SUBMIT_URL,
                    files={"file": f},
                    headers={"token": token},
                    timeout=self.REQUEST_TIMEOUT
                )

            print(f"Submission response status: {response.status_code}")

            # Keep the JSON-parse handler narrow so that later formatting
            # problems are not misreported as parse errors.
            try:
                response_json = response.json()
            except ValueError as json_error:
                print(f"Could not parse JSON response: {json_error}")
                print(f"Raw response: {response.text}")
                return None

            print(f"Response JSON: {response_json}")

            if response.status_code == 200:
                print("✅ Submission successful!")
                if 'avg_iou' in response_json:
                    iou = response_json['avg_iou']
                    print(f"Average IoU: {iou:.4f}")
                if 'avg_time' in response_json:
                    time_val = response_json['avg_time']
                    print(f"Average Time: {time_val:.4f} seconds")
            else:
                print(f"❌ Submission failed with status {response.status_code}")

            return response_json

        except Exception as e:
            print(f"Error during submission: {e}")
            return None

def main_optimized():
    """Run the full workflow: analyse all images, pickle the parameters, submit them.

    The submission token is read from the ``LIME_SUBMISSION_TOKEN``
    environment variable when it is set.

    Returns:
        Tuple ``(results, params)``: the per-image analysis results and the
        parameter dictionaries that were written to the pickle file.
    """
    pickle_path = "lime_parameters_optimized.pkl"
    explainer = OptimizedLIMEExplainer()

    # 1. Analyze all images
    results = explainer.analyze_all_images_optimized()

    # 2. Create and save parameters
    params = explainer.save_optimized_parameters_pickle(pickle_path)

    # 3. Submit to server
    # NOTE(review): the literal fallback keeps existing runs working, but a
    # token should not live in the notebook; prefer setting the variable.
    token = os.environ.get("LIME_SUBMISSION_TOKEN", "96005201")
    explainer.submit_to_server(pickle_path, token)

    print("\n=== OPTIMIZED LIME Analysis Summary ===")
    print(f"Successfully analyzed {len(results)} images")

    return results, params

In [None]:
# Entry point: run the whole pipeline when this cell/script is executed directly,
# keeping the analysis results and the saved parameters in the session.
if __name__ == "__main__":
    results, params = main_optimized()