# The AI

## Imports

In [35]:
# Complete Training and Evaluation System
import os
import sys
import json
import torch
import argparse
import logging
from pathlib import Path
from typing import List, Dict
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
from datetime import datetime
# Advanced Data Augmentation and Preprocessing System
import json
import base64
import cv2
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter
import io
import random
from typing import List, Dict, Tuple, Optional
import albumentations as A
from pathlib import Path
import copy
import math


## DataSet and Data Loaders

In [36]:

class AdvancedUIDataAugmenter:
    """Augment recorded UI-interaction sequences for training.

    A sequence is a dict that may contain ``screenshots`` (a list of dicts
    holding an ``image_base64`` entry), ``actions`` (a list of action dicts)
    and ``duration_ms``.  Screenshots receive photometric, light geometric and
    "screen simulation" perturbations; actions receive small coordinate and
    timestamp jitter; the total duration is varied by up to 10%.

    Note: image-level geometric changes (shift/scale/rotate, browser header,
    zoom) are not mirrored onto the action coordinates.
    """

    def __init__(self, preserve_ui_elements: bool = True):
        # Stored for callers; no method of this class currently consults it.
        self.preserve_ui_elements = preserve_ui_elements
        self.setup_augmentations()

    def setup_augmentations(self):
        """Build the albumentations pipelines and the screen-simulation list."""

        # Visual augmentations (safe for UI)
        self.visual_transforms = A.Compose([
            A.RandomBrightnessContrast(
                brightness_limit=0.15,
                contrast_limit=0.15,
                p=0.4
            ),
            A.HueSaturationValue(
                hue_shift_limit=8,
                sat_shift_limit=15,
                val_shift_limit=15,
                p=0.3
            ),
            A.GaussNoise(var_limit=(5.0, 25.0), p=0.2),
            A.RandomGamma(gamma_limit=(85, 115), p=0.3),
            A.CLAHE(clip_limit=2.0, tile_grid_size=(4, 4), p=0.2),
        ])

        # Geometric augmentations (minimal to preserve click coordinates)
        self.geometric_transforms = A.Compose([
            A.ShiftScaleRotate(
                shift_limit=0.02,  # Very small shifts
                scale_limit=0.03,  # Very small scaling
                rotate_limit=1,    # Minimal rotation
                border_mode=cv2.BORDER_CONSTANT,
                value=0,
                p=0.3
            ),
        ])

        # Screen simulation augmentations; one is picked at random per image.
        self.screen_simulation = [
            self._simulate_different_screen_sizes,
            self._simulate_different_browsers,
            self._simulate_zoom_levels,
            self._add_cursor_variations,
        ]

    def augment_sequence(self, sequence_data: Dict, num_augmentations: int = 5) -> List[Dict]:
        """Return ``num_augmentations`` independently augmented deep copies.

        Each copy gets ``augmentation_id`` (its index), ``original_task`` (the
        source ``task_label`` or ``'unknown'``) and a ``task_label`` suffixed
        with ``_aug_<index>``.  ``sequence_data`` itself is not modified.
        """
        base_label = sequence_data.get('task_label', 'unknown')
        augmented_sequences = []

        for i in range(num_augmentations):
            # Work on a deep copy so the caller's sequence stays pristine.
            aug_sequence = self._apply_random_augmentations(copy.deepcopy(sequence_data))

            aug_sequence['augmentation_id'] = i
            aug_sequence['original_task'] = base_label
            aug_sequence['task_label'] = f"{base_label}_aug_{i}"

            augmented_sequences.append(aug_sequence)

        return augmented_sequences

    def _apply_random_augmentations(self, sequence: Dict) -> Dict:
        """Augment screenshots, actions and duration of ``sequence`` in place."""

        if 'screenshots' in sequence:
            for screenshot in sequence['screenshots']:
                screenshot['image_base64'] = self._augment_screenshot(
                    screenshot['image_base64']
                )

        if 'actions' in sequence:
            sequence['actions'] = self._augment_actions(sequence['actions'])

        return self._add_timing_variations(sequence)

    def _augment_screenshot(self, base64_image: str) -> str:
        """Decode a base64 image, perturb it, and return it as base64 PNG."""
        img_data = base64.b64decode(base64_image)
        # Normalise to 3-channel RGB: PNG screenshots are frequently RGBA or
        # palettised, which the colour transforms below cannot process.
        with Image.open(io.BytesIO(img_data)) as img:
            img_array = np.array(img.convert('RGB'))

        # Apply visual augmentations
        if random.random() < 0.7:
            img_array = self.visual_transforms(image=img_array)['image']

        # Apply geometric augmentations (carefully)
        if random.random() < 0.3:
            img_array = self.geometric_transforms(image=img_array)['image']

        # Apply one randomly chosen screen simulation
        if random.random() < 0.4:
            aug_func = random.choice(self.screen_simulation)
            img_array = aug_func(img_array)

        # Convert back to base64
        buffer = io.BytesIO()
        Image.fromarray(img_array).save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode()

    def _augment_actions(self, actions: List[Dict]) -> List[Dict]:
        """Return jittered deep copies of ``actions``; the input is untouched.

        Mouse actions with coordinates get a +/-5 px offset per axis (clamped
        at 0).  With probability 0.3 an action that carries ``timestamp_ms``
        gets a +/-100 ms offset (clamped at 0).
        """
        augmented_actions = []

        for action in actions:
            aug_action = copy.deepcopy(action)

            # Add small coordinate variations for mouse actions
            coords = action.get('coordinates')
            if action.get('type') == 'mouse' and coords:
                offset_x = random.randint(-5, 5)
                offset_y = random.randint(-5, 5)

                # Update the copied dict so any extra keys it holds survive.
                aug_coords = aug_action['coordinates']
                aug_coords['x'] = max(0, coords['x'] + offset_x)
                aug_coords['y'] = max(0, coords['y'] + offset_y)

            # Add timing variations; actions without a timestamp are left alone
            # instead of raising KeyError.
            if random.random() < 0.3 and 'timestamp_ms' in action:
                time_offset = random.randint(-100, 100)  # +/-100ms
                aug_action['timestamp_ms'] = max(0, action['timestamp_ms'] + time_offset)

            augmented_actions.append(aug_action)

        return augmented_actions

    def _add_timing_variations(self, sequence: Dict) -> Dict:
        """Vary ``duration_ms`` by up to +/-10%, with a floor of 1000 ms."""
        if 'duration_ms' in sequence:
            variation = int(sequence['duration_ms'] * 0.1)
            offset = random.randint(-variation, variation)
            sequence['duration_ms'] = max(1000, sequence['duration_ms'] + offset)

        return sequence

    def _simulate_different_screen_sizes(self, image: np.ndarray) -> np.ndarray:
        """Round-trip the image through a random common screen resolution.

        The result has the original height and width; only the resampling
        artefacts of the intermediate resolution remain.
        """
        original_h, original_w = image.shape[:2]

        # Common screen sizes as (width, height), the order cv2.resize expects.
        screen_sizes = [
            (1366, 768),   # HD
            (1920, 1080),  # Full HD
            (2560, 1440),  # QHD
            (1440, 900),   # MacBook
            (1280, 720),   # HD Ready
        ]

        target_w, target_h = random.choice(screen_sizes)

        # Compare like with like: image.shape is (h, w) while the table above
        # is (w, h).  Nothing to simulate if the resolution already matches.
        if (target_w, target_h) == (original_w, original_h):
            return image

        resized = cv2.resize(image, (target_w, target_h))
        return cv2.resize(resized, (original_w, original_h))

    def _simulate_different_browsers(self, image: np.ndarray) -> np.ndarray:
        """Occasionally push the content down and paint a grey browser header.

        Always returns a new array.  With probability 0.3 the content is
        shifted down by a random header height and the freed rows are filled
        with a light grey; rows pushed past the bottom edge are dropped.
        """
        h = image.shape[0]

        # Simulate different browser header heights
        header_height = random.choice([60, 80, 100, 120])

        header_color = random.choice([
            [240, 240, 240],  # Light gray (Chrome-like)
            [230, 230, 230],  # Slightly darker gray
            [250, 250, 250],  # Very light gray
            [220, 220, 220],  # Medium gray
        ])

        modified_image = image.copy()
        # A header at least as tall as the image would overwrite the whole
        # screenshot, so such images are returned unchanged.
        if random.random() < 0.3 and header_height < h:
            shifted_content = np.zeros_like(modified_image)
            shifted_content[header_height:] = modified_image[:-header_height]
            if shifted_content.ndim == 2:
                # Greyscale: the header colours are neutral greys.
                shifted_content[:header_height] = header_color[0]
            else:
                color_channels = min(shifted_content.shape[2], 3)
                shifted_content[:header_height, :, :color_channels] = header_color[:color_channels]
                # Any remaining channel (e.g. alpha) is made fully opaque.
                shifted_content[:header_height, :, color_channels:] = 255
            modified_image = shifted_content

        return modified_image

    def _simulate_zoom_levels(self, image: np.ndarray) -> np.ndarray:
        """Rescale by a random zoom factor, then crop/pad back to the input size.

        Zoom-in crops the centre; zoom-out centres the image on a light grey
        (240) canvas.  A factor of 1.0 returns the input unchanged.
        """
        zoom_factors = [0.9, 0.95, 1.0, 1.05, 1.1, 1.25]
        zoom = random.choice(zoom_factors)

        if zoom == 1.0:
            return image

        h, w = image.shape[:2]
        # max(1, ...) keeps cv2.resize valid for very small images.
        new_h, new_w = max(1, int(h * zoom)), max(1, int(w * zoom))

        resized = cv2.resize(image, (new_w, new_h))
        # Restore a trailing channel axis cv2 may have dropped (no-op otherwise).
        resized = resized.reshape((new_h, new_w) + image.shape[2:])

        if zoom > 1.0:  # Crop center
            start_y = (new_h - h) // 2
            start_x = (new_w - w) // 2
            return resized[start_y:start_y + h, start_x:start_x + w]

        # Pad with background colour, matching the input's channels and dtype
        # rather than assuming 3-channel uint8.
        padded = np.full_like(image, 240)
        if padded.ndim == 3 and padded.shape[2] == 4:
            padded[..., 3] = 255  # keep the padding opaque
        start_y = (h - new_h) // 2
        start_x = (w - new_w) // 2
        padded[start_y:start_y + new_h, start_x:start_x + new_w] = resized
        return padded

    def _add_cursor_variations(self, image: np.ndarray) -> np.ndarray:
        """Occasionally draw a 2x2 black mark imitating a mouse cursor.

        With probability 0.2 a copy of the image is returned with the mark
        placed at least 50 px away from every border; otherwise, or when the
        image is too small for that margin, the input is returned unchanged.
        The input array is never modified.
        """
        margin = 50
        if random.random() < 0.2:
            h, w = image.shape[:2]
            # randint(margin, size - margin) needs size >= 2 * margin.
            if h < 2 * margin or w < 2 * margin:
                return image

            cursor_x = random.randint(margin, w - margin)
            cursor_y = random.randint(margin, h - margin)

            image = image.copy()
            if image.ndim == 2:
                image[cursor_y:cursor_y + 2, cursor_x:cursor_x + 2] = 0
            else:
                # Only the colour channels are blackened; alpha is left alone.
                image[cursor_y:cursor_y + 2, cursor_x:cursor_x + 2, :3] = 0

        return image


In [37]:

class TaskVariationGenerator:
    """Generate relabelled copies of a sequence to increase dataset diversity.

    ``task_templates`` maps a coarse task type (``login``, ``search``,
    ``form_fill``) to the list of concrete task labels that belong to it and
    to the UI elements those tasks have in common.
    """

    def __init__(self):
        self.task_templates = {
            'login': {
                'variations': [
                    'login_gmail', 'login_facebook', 'login_twitter',
                    'login_github', 'login_linkedin', 'login_microsoft',
                ],
                'common_elements': ['email_field', 'password_field', 'login_button']
            },
            'search': {
                'variations': [
                    'search_google', 'search_youtube', 'search_amazon',
                    'search_stackoverflow', 'search_github'
                ],
                'common_elements': ['search_box', 'search_button', 'results']
            },
            'form_fill': {
                'variations': [
                    'contact_form', 'registration_form', 'survey_form',
                    'checkout_form', 'profile_form'
                ],
                'common_elements': ['text_fields', 'dropdown', 'submit_button']
            }
        }

    def generate_task_variations(self, original_sequence: Dict) -> List[Dict]:
        """Return deep copies of ``original_sequence`` under sibling task labels.

        One copy is produced for every variation of the matching template
        except the sequence's own label.  Each copy carries
        ``task_variation=True`` and ``base_task`` (the original label).
        Returns an empty list when the label matches no template.
        """
        base_task = original_sequence.get('task_label', 'unknown')
        template = self.task_templates.get(self._identify_task_type(base_task))
        if template is None:
            return []

        variations = []
        for variation_name in template['variations']:
            if variation_name == base_task:  # Don't duplicate original
                continue
            varied_sequence = copy.deepcopy(original_sequence)
            varied_sequence['task_label'] = variation_name
            varied_sequence['task_variation'] = True
            varied_sequence['base_task'] = base_task
            variations.append(varied_sequence)

        return variations

    def _identify_task_type(self, task_label: str) -> str:
        """Map a task label to ``login``/``search``/``form_fill``/``unknown``.

        Matching is a case-insensitive substring test; non-string labels are
        stringified first so a missing (``None``) label yields ``unknown``.
        """
        task_lower = str(task_label).lower()

        if 'login' in task_lower or 'signin' in task_lower:
            return 'login'
        if 'search' in task_lower:
            return 'search'
        if 'form' in task_lower or 'fill' in task_lower:
            return 'form_fill'

        return 'unknown'


In [38]:

class SequentialActionGenerator:
    """Generate realistic sequential actions for training."""

    # Sequence kind -> name of the method that generates it.  Only methods
    # that actually exist on the instance are registered, so constructing the
    # generator no longer fails on the not-yet-implemented kinds.
    _GENERATOR_METHODS = {
        'typing': '_generate_typing_sequence',
        'navigation': '_generate_navigation_sequence',
        'form_interaction': '_generate_form_sequence',
    }

    def __init__(self):
        self.action_sequences = {
            kind: getattr(self, method_name)
            for kind, method_name in self._GENERATOR_METHODS.items()
            if hasattr(self, method_name)
        }

    def generate_intermediate_actions(self, sequence: Dict) -> Dict:
        """Return the sequence with synthetic actions inserted between steps.

        Sequences with fewer than two actions are returned as-is.  Otherwise a
        shallow copy of ``sequence`` with a new ``actions`` list is returned;
        the caller's dict and its original action list are left untouched.
        """
        actions = sequence.get('actions', [])
        if len(actions) < 2:
            return sequence

        enhanced_actions = []
        for current, upcoming in zip(actions, actions[1:]):
            enhanced_actions.append(current)
            enhanced_actions.extend(self._generate_intermediate_action(current, upcoming))
        enhanced_actions.append(actions[-1])

        enhanced_sequence = dict(sequence)
        enhanced_sequence['actions'] = enhanced_actions
        return enhanced_sequence

    def _generate_intermediate_action(self, current: Dict, next_action: Dict) -> List[Dict]:
        """Build the synthetic actions that belong between two recorded ones.

        * A mouse ``move`` to the click position when a mouse click follows a
          non-mouse action.
        * A ``pause`` at the midpoint when the two actions are less than
          100 ms apart.

        Returns them ordered by timestamp.  Nothing is generated when either
        action lacks ``timestamp_ms``.
        """
        current_ts = current.get('timestamp_ms')
        next_ts = next_action.get('timestamp_ms')
        if current_ts is None or next_ts is None:
            return []

        time_diff = next_ts - current_ts
        intermediate_actions = []

        # Add mouse movement before clicks
        if (next_action.get('type') == 'mouse' and
                next_action.get('action') == 'click' and
                current.get('type') != 'mouse'):
            target = next_action.get('coordinates') or {'x': 0, 'y': 0}
            intermediate_actions.append({
                # Never stamp the move later than the click it precedes.
                'timestamp_ms': current_ts + min(50, max(time_diff, 0)),
                'type': 'mouse',
                'action': 'move',
                # Copy so later edits to the move don't alter the click.
                'coordinates': dict(target),
            })

        # Add pauses between rapid actions; a negative gap means the input is
        # out of order, and a pause would only add noise there.
        if 0 <= time_diff < 100:
            intermediate_actions.append({
                'timestamp_ms': current_ts + time_diff // 2,
                'type': 'system',
                'action': 'pause',
                'duration': 50
            })

        intermediate_actions.sort(key=lambda item: item['timestamp_ms'])
        return intermediate_actions

    def _generate_typing_sequence(self, text: str, start_time: int) -> List[Dict]:
        """Return press/release keyboard actions for typing ``text``.

        Each character yields a ``press`` at the running time and a
        ``release`` 20-80 ms later; the running time then advances by
        50-200 ms.
        """
        actions = []
        current_time = start_time

        for char in text:
            # Random typing speed (50-200ms per character)
            char_delay = random.randint(50, 200)

            actions.append({
                'timestamp_ms': current_time,
                'type': 'keyboard',
                'action': 'press',
                'key': char,
                'key_code': ord(char)
            })

            actions.append({
                'timestamp_ms': current_time + random.randint(20, 80),
                'type': 'keyboard',
                'action': 'release',
                'key': char
            })

            current_time += char_delay

        return actions


In [39]:

class DatasetEnhancer:
    """Main class to enhance and expand the dataset.

    Combines the image/action augmenter, the task-label variation generator
    and the intermediate-action generator, and writes results as JSON files
    below ``output_dir``.
    """

    def __init__(self, output_dir: str = "enhanced_dataset"):
        self.output_dir = Path(output_dir)
        # parents=True so nested output locations work on a fresh checkout.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.augmenter = AdvancedUIDataAugmenter()
        self.task_generator = TaskVariationGenerator()
        self.action_generator = SequentialActionGenerator()

    def enhance_dataset(self, input_files: List[str], enhancement_factor: int = 5) -> List[str]:
        """Write augmented variants of every input file and return their paths.

        For each input JSON file the output consists of
        ``enhancement_factor`` augmented copies, the task-label variations,
        and one additional copy of each of those with intermediate actions
        inserted.  Files are named ``<input stem>_enhanced_<index>.json``.
        """
        enhanced_files = []

        for input_file in input_files:
            print(f"Processing {input_file}...")

            # Load original data
            with open(input_file, 'r') as f:
                original_data = json.load(f)

            # 1. Visual and timing augmentations
            enhanced_sequences = list(self.augmenter.augment_sequence(
                original_data, num_augmentations=enhancement_factor
            ))

            # 2. Task variations
            enhanced_sequences.extend(
                self.task_generator.generate_task_variations(original_data)
            )

            # 3. Action sequence enhancements.  Each source sequence is
            # deep-copied first so the enhanced variant is a distinct object
            # and the sequence from steps 1-2 keeps its original actions.
            action_enhanced = [
                self.action_generator.generate_intermediate_actions(copy.deepcopy(seq))
                for seq in enhanced_sequences
            ]
            enhanced_sequences.extend(action_enhanced)

            # Save enhanced sequences
            stem = Path(input_file).stem
            for i, enhanced_seq in enumerate(enhanced_sequences):
                output_path = self.output_dir / f"{stem}_enhanced_{i}.json"

                with open(output_path, 'w') as f:
                    json.dump(enhanced_seq, f, indent=2)

                enhanced_files.append(str(output_path))

        print(f"Generated {len(enhanced_files)} enhanced data files")
        return enhanced_files

    def create_balanced_dataset(self, input_files: List[str]) -> str:
        """Write ``balanced_dataset.json`` with equal samples per task label.

        Every input file holds one sequence.  Labels with fewer sequences
        than the largest label are upsampled with augmented copies of their
        own original sequences; the copies keep the label they are balancing.
        Returns the path of the written file (an empty list is written when
        ``input_files`` is empty).
        """
        task_counts = {}

        # Load all sequences and group them by task label
        for file_path in input_files:
            with open(file_path, 'r') as f:
                data = json.load(f)
            task_counts.setdefault(data.get('task_label', 'unknown'), []).append(data)

        # default=0 keeps an empty input from raising inside max().
        max_samples = max((len(sequences) for sequences in task_counts.values()), default=0)
        balanced_sequences = []

        for task_label, sequences in task_counts.items():
            # Draw only from the recorded sequences so augmentations are not
            # stacked on top of earlier augmentations.
            originals = list(sequences)
            while len(sequences) < max_samples:
                base_seq = random.choice(originals)
                augmented = self.augmenter.augment_sequence(base_seq, num_augmentations=1)[0]
                # augment_sequence suffixes the label with "_aug_0"; restore
                # it, otherwise the copy would count towards a different task.
                augmented['task_label'] = task_label
                sequences.append(augmented)

            balanced_sequences.extend(sequences[:max_samples])

        # Save balanced dataset
        balanced_file = self.output_dir / "balanced_dataset.json"
        with open(balanced_file, 'w') as f:
            json.dump(balanced_sequences, f, indent=2)

        print(f"Created balanced dataset with {len(balanced_sequences)} sequences")
        print(f"Task distribution: {dict((k, len(v)) for k, v in task_counts.items())}")

        return str(balanced_file)


## Tracker

In [47]:
# Complete Training and Evaluation System
import os
import sys
import json
import torch
import argparse
import logging
from pathlib import Path
from typing import List, Dict
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
from datetime import datetime


In [48]:

class ExperimentTracker:
    """Track and log experiment metrics to the console, a log file and,
    optionally, Weights & Biases.

    Attributes:
        experiment_name: Name used for the log file and the wandb run.
        use_wandb: Whether wandb logging is active; reset to ``False`` when
            wandb cannot be imported or initialised.
        metrics_history: List of every record passed to ``log_metrics``.
        logger: Module logger used for console/file output.
    """

    def __init__(self, experiment_name: str, use_wandb: bool = False):
        self.experiment_name = experiment_name
        self.use_wandb = use_wandb
        self.metrics_history = []
        self._wandb = None  # the wandb module, once successfully initialised

        # Setup logging
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(f'{experiment_name}_training.log'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

        # Initialize wandb if requested.  The import is done lazily here so
        # wandb stays an optional dependency; any failure (missing package,
        # no credentials, ...) downgrades to local-only logging.
        if self.use_wandb:
            try:
                import wandb
                wandb.init(project="ui-action-prediction", name=experiment_name)
                self._wandb = wandb
            except Exception as e:
                self.logger.warning(f"Failed to initialize wandb: {e}")
                self.use_wandb = False

    def log_metrics(self, metrics: Dict, step: int):
        """Record ``metrics`` for ``step`` in the history, the log and wandb.

        The stored record is a copy of ``metrics`` extended with ``step`` and
        an ISO-format ``timestamp``; the caller's dict is not modified.
        """
        record = dict(metrics)
        record['step'] = step
        record['timestamp'] = datetime.now().isoformat()
        self.metrics_history.append(record)

        # Log to console / file
        self.logger.info(f"Step {step}: {record}")

        # Log to wandb
        wandb_module = getattr(self, '_wandb', None)
        if self.use_wandb and wandb_module is not None:
            wandb_module.log(record, step=step)

    def save_metrics(self, filepath: str):
        """Write the metrics history to ``filepath`` as indented JSON."""
        with open(filepath, 'w') as f:
            json.dump(self.metrics_history, f, indent=2)


In [49]:

class ModelEvaluator:
    """Comprehensive model evaluation.

    The model is called as ``model(images)`` and must return
    ``(action_logits, coordinate_predictions)``.  Batches are dicts with
    ``image``, ``action_class`` and (where coordinates are evaluated)
    ``coordinates`` tensors.  ``action_encoder.classes_`` supplies the class
    names, indexed by class id.
    """

    # Pixel thresholds are divided by this before being compared with the
    # coordinate error.
    # NOTE(review): presumably coordinates are normalised by a 1920 px wide
    # screen - confirm against the dataset code.
    REFERENCE_WIDTH_PX = 1920.0

    def __init__(self, model, dataset, action_encoder, device='cuda'):
        self.model = model
        self.dataset = dataset
        self.action_encoder = action_encoder
        self.device = device

    def _class_labels(self) -> List[int]:
        """Return every class id, so reports cover classes absent from the data."""
        return list(range(len(self.action_encoder.classes_)))

    def _collect_predictions(self, dataloader, with_coordinates: bool = True):
        """Run the model over ``dataloader`` and gather per-sample results.

        Returns ``(predictions, targets, coords_pred, coords_true)``; the two
        coordinate lists stay empty when ``with_coordinates`` is False.
        """
        self.model.eval()

        predictions, targets = [], []
        coords_pred, coords_true = [], []

        with torch.no_grad():
            for batch in dataloader:
                images = batch['image'].to(self.device)
                # reshape(-1) instead of squeeze(): squeeze() turns a batch of
                # one into a 0-d tensor that can be neither iterated nor indexed.
                action_targets = batch['action_class'].reshape(-1).to(self.device)

                action_logits, coord_pred = self.model(images)
                action_pred = torch.argmax(action_logits, dim=1)

                predictions.extend(action_pred.cpu().tolist())
                targets.extend(action_targets.cpu().tolist())

                if with_coordinates:
                    coord_targets = batch['coordinates'].to(self.device)
                    coords_pred.extend(coord_pred.cpu().numpy())
                    coords_true.extend(coord_targets.cpu().numpy())

        return predictions, targets, coords_pred, coords_true

    def evaluate_comprehensive(self, dataloader) -> Dict:
        """Comprehensive evaluation including per-class metrics.

        Returns a dict with ``classification_report`` (sklearn's dict form),
        ``coordinate_mae`` / ``coordinate_rmse`` (per coordinate axis) and
        ``pixel_accuracy_<N>px`` for N in 5, 10, 20, 50.
        """
        predictions, targets, coords_pred, coords_true = self._collect_predictions(dataloader)

        metrics = {}

        # Action classification metrics.  labels= pins the report to the full
        # class list; without it sklearn raises when a class is missing from
        # the evaluated data.
        metrics['classification_report'] = classification_report(
            targets, predictions,
            labels=self._class_labels(),
            target_names=self.action_encoder.classes_,
            output_dict=True,
            zero_division=0
        )

        # Coordinate regression metrics
        coord_pred = np.array(coords_pred)
        coord_true = np.array(coords_true)
        coord_error = coord_pred - coord_true

        metrics['coordinate_mae'] = np.mean(np.abs(coord_error), axis=0)
        metrics['coordinate_rmse'] = np.sqrt(np.mean(coord_error**2, axis=0))

        # Pixel accuracy (within N pixels); the distance does not depend on
        # the threshold, so it is computed once.
        pixel_distances = np.sqrt(np.sum(coord_error**2, axis=1))
        for threshold in [5, 10, 20, 50]:
            accuracy = np.mean(pixel_distances <= threshold / self.REFERENCE_WIDTH_PX)
            metrics[f'pixel_accuracy_{threshold}px'] = accuracy

        return metrics

    def plot_confusion_matrix(self, dataloader, save_path: str):
        """Generate the action confusion matrix and save it to ``save_path``."""
        predictions, targets, _, _ = self._collect_predictions(
            dataloader, with_coordinates=False
        )

        # labels= keeps the matrix aligned with the tick labels even when a
        # class does not occur in the evaluated data.
        cm = confusion_matrix(targets, predictions, labels=self._class_labels())

        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d',
                    xticklabels=self.action_encoder.classes_,
                    yticklabels=self.action_encoder.classes_)
        plt.title('Action Prediction Confusion Matrix')
        plt.ylabel('True Action')
        plt.xlabel('Predicted Action')
        plt.xticks(rotation=45)
        plt.yticks(rotation=45)
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()

    def analyze_failure_cases(self, dataloader, num_samples: int = 20) -> List[Dict]:
        """Collect up to ``num_samples`` misclassified samples for debugging.

        Each entry holds the predicted and true action names, the predicted
        and true coordinates (as lists) and the softmax confidence of the
        predicted class.
        """
        self.model.eval()
        failure_cases = []
        class_names = self.action_encoder.classes_

        with torch.no_grad():
            for batch in dataloader:
                if len(failure_cases) >= num_samples:
                    break

                images = batch['image'].to(self.device)
                action_targets = batch['action_class'].reshape(-1).to(self.device)
                coord_targets = batch['coordinates'].to(self.device)

                action_logits, coord_pred = self.model(images)
                action_pred = torch.argmax(action_logits, dim=1)

                # Find incorrect predictions
                incorrect_mask = action_pred != action_targets

                for i in torch.where(incorrect_mask)[0].tolist():
                    if len(failure_cases) >= num_samples:
                        break

                    failure_cases.append({
                        # .item() yields plain ints; indexing the class array
                        # with a (possibly CUDA) tensor is not reliable.
                        'predicted_action': str(class_names[action_pred[i].item()]),
                        'true_action': str(class_names[action_targets[i].item()]),
                        'predicted_coords': coord_pred[i].cpu().numpy().tolist(),
                        'true_coords': coord_targets[i].cpu().numpy().tolist(),
                        'confidence': torch.softmax(action_logits[i], dim=0).max().item()
                    })

        return failure_cases


In [None]:

class AdvancedTrainingPipeline:
    """Advanced training pipeline with all features.

    ``config`` must contain ``experiment_name``; ``use_wandb`` is optional
    and defaults to ``False``.
    """

    def __init__(self, config: Dict):
        self.config = config
        # Prefer the GPU when one is available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize experiment tracker, forwarding the wandb switch so the
        # setting in the config actually takes effect.
        self.tracker = ExperimentTracker(
            config['experiment_name'],
            config.get('use_wandb', False),
        )

        # Persist the configuration next to the run, named after the experiment.
        config_path = f"{config['experiment_name']}_config.json"
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=2)
        print(f"Created config template: {config_path}")
    
    # # Run training pipeline
    # pipeline = AdvancedTrainingPipeline(config)
    # model_path = pipeline.train()
    
    # print(f"\nTraining completed!")
    # print(f"Best model saved to: {model_path}")
    
    # # Demonstrate real-time execution
    # print("\nStarting demonstration of trained model...")
    # demonstrate_execution(model_path, config)

def demonstrate_execution(model_path: str, config: Dict):
    """Demonstrate the trained model by predicting one action from the screen.

    Builds a ``RealTimeExecutionAgent`` from ``model_path`` and the
    ``action_encoder.pkl`` stored in the same directory, captures a
    screenshot, and prints the predicted action without executing it.
    ``config`` is accepted for interface compatibility and is not read.

    This is a best-effort demo: any failure is reported on stdout rather
    than raised.
    """
    try:
        # Initialize execution agent
        encoder_path = str(Path(model_path).parent / 'action_encoder.pkl')
        agent = RealTimeExecutionAgent(model_path, encoder_path)

        print("Real-time execution agent initialized!")
        print("The agent can now predict and execute actions based on screen content.")
        print("To run automation, call: agent.run_automation('task_name')")

        # Demo: Just predict without executing
        screenshot = agent.capture_screen()
        predicted_action = agent.predict_action(screenshot)

        print("Demo prediction from current screen:")
        print(f"  Action: {predicted_action['type']} - {predicted_action['action']}")
        print(f"  Coordinates: {predicted_action['coordinates']}")
        print(f"  Confidence: {predicted_action['confidence']:.3f}")

        # The key is optional (e.g. absent for mouse actions); its absence
        # must not turn a successful demo into a reported failure.
        try:
            key = predicted_action['key']
        except KeyError:
            key = None
        if key:
            print(f"  Key: {key}")

    except Exception as e:
        print(f"Demo execution failed: {e}")

class LiveTrainingMonitor:
    """Plot training progress from a list of per-epoch metric records.

    ``metrics`` starts empty; each record is expected to provide ``epoch``,
    ``train_loss``, ``val_loss`` and ``val_accuracy``.
    """

    def __init__(self, log_file: str):
        self.metrics = []
        self.log_file = log_file

    def plot_training_progress(self, save_path: str = "training_progress.png"):
        """Save a loss plot and an accuracy plot side by side to ``save_path``.

        Does nothing when no metrics have been recorded.
        """
        history = self.metrics
        if not history:
            return

        def column(key):
            # One metric across all recorded epochs.
            return [record[key] for record in history]

        def decorate(axis, y_label, title):
            axis.set_xlabel('Epoch')
            axis.set_ylabel(y_label)
            axis.set_title(title)
            axis.legend()
            axis.grid(True)

        epochs = column('epoch')
        train_loss = column('train_loss')
        val_loss = column('val_loss')
        val_accuracy = column('val_accuracy')

        fig, (loss_axis, accuracy_axis) = plt.subplots(1, 2, figsize=(15, 5))

        # Left panel: training vs. validation loss
        loss_axis.plot(epochs, train_loss, label='Train Loss', color='blue')
        loss_axis.plot(epochs, val_loss, label='Validation Loss', color='red')
        decorate(loss_axis, 'Loss', 'Training and Validation Loss')

        # Right panel: validation accuracy
        accuracy_axis.plot(epochs, val_accuracy, label='Validation Accuracy', color='green')
        decorate(accuracy_axis, 'Accuracy', 'Validation Accuracy')

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()

# Additional utility functions

def validate_data_format(data_file: str) -> bool:
    """Check that a JSON data file has the fields the pipeline relies on.

    The file must contain ``task_label``, non-empty ``screenshots`` and
    non-empty ``actions``.  Only the first screenshot (needs
    ``image_base64``) and the first action (needs ``type`` and ``action``)
    are inspected.  The reason for a rejection is printed; any error while
    reading or inspecting the file also yields ``False``.
    """
    try:
        with open(data_file, 'r') as handle:
            payload = json.load(handle)

        # Report the first missing top-level field, if any.
        missing = next(
            (name for name in ('task_label', 'screenshots', 'actions')
             if name not in payload),
            None,
        )
        if missing is not None:
            print(f"Missing required field: {missing}")
            return False

        # Validate screenshots (first entry only)
        screenshots = payload['screenshots']
        if not screenshots:
            print("No screenshots found")
            return False
        if any('image_base64' not in shot for shot in screenshots[:1]):
            print("Screenshot missing image_base64")
            return False

        # Validate actions (first entry only)
        recorded_actions = payload['actions']
        if not recorded_actions:
            print("No actions found")
            return False
        if any('type' not in step or 'action' not in step
               for step in recorded_actions[:1]):
            print("Action missing required fields")
            return False

        print(f"Data file {data_file} is valid!")
        return True

    except Exception as e:
        print(f"Error validating {data_file}: {e}")
        return False

def setup_environment():
    """Bootstrap the working tree used by the training scripts.

    Creates the project directories relative to the current working
    directory (existing ones are left alone), then writes the default
    configuration and a sample data structure as indented JSON files.
    Progress is reported on stdout.
    """
    def _dump_json(payload, destination):
        # Serialise one payload with the 2-space indent used for all samples.
        with open(destination, 'w') as handle:
            json.dump(payload, handle, indent=2)

    print("Setting up UI Action Prediction environment...")

    # Directories the rest of the pipeline expects to find.
    directories = [
        'data', 'models', 'experiments', 'logs',
        'enhanced_data', 'plots', 'configs'
    ]
    for folder in map(Path, directories):
        folder.mkdir(exist_ok=True)

    # Each payload is built before its file is opened, so a failing factory
    # never leaves an empty file behind.
    _dump_json(create_config_template(), 'configs/default_config.json')
    _dump_json(create_sample_data_structure(), 'data/sample_data_structure.json')

    print("Environment setup complete!")
    print("Created directories:", directories)
    print("Sample config: configs/default_config.json")
    print("Sample data structure: data/sample_data_structure.json")

if __name__ == "__main__":
    # Script entry point. Passing `setup` as the first command-line argument
    # only scaffolds the environment (see setup_environment) and exits with
    # status 0 instead of starting a training run.
    # Check if setup is needed
    if len(sys.argv) > 1 and sys.argv[1] == 'setup':
        setup_environment()
        sys.exit(0)
    
    # NOTE(review): in this copy of the file the statement that follows is
    # fused with an unrelated fragment; confirm against the original source.
    # Run main training
    main()experiment_name'], 
            config.get('use_wandb', False)
        )
        
        # Create output directories
        self.output_dir = Path(config['output_dir'])
        self.output_dir.mkdir(exist_ok=True)
        (self.output_dir / 'models').mkdir(exist_ok=True)
        (self.output_dir / 'plots').mkdir(exist_ok=True)
        (self.output_dir / 'logs').mkdir(exist_ok=True)
    
    def prepare_data(self) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """Build the training, validation and test dataloaders.

        Optionally expands the raw data files through ``DatasetEnhancer``,
        wraps them in a ``UIActionDataset``, splits it 70/15/15 (the test
        split receives the rounding remainder) and creates one ``DataLoader``
        per split. Also stores the dataset's ``action_encoder`` on ``self``.

        Config keys read:
            data_files: input files (required).
            enhance_data: run the enhancer first (default True).
            enhancement_factor: forwarded to the enhancer (default 5).
            batch_size: batch size for all three loaders (required).
            num_workers: loader worker processes (default 4). Set to 0 where
                worker processes are problematic, e.g. some notebook setups.
            seed: optional integer; when present the split is reproducible.

        Returns:
            ``(train_loader, val_loader, test_loader)``.
        """
        self.tracker.logger.info("Preparing dataset...")

        # Enhance dataset if requested
        if self.config.get('enhance_data', True):
            enhancer = DatasetEnhancer(str(self.output_dir / 'enhanced_data'))
            data_files = enhancer.enhance_dataset(
                self.config['data_files'],
                enhancement_factor=self.config.get('enhancement_factor', 5)
            )
        else:
            data_files = self.config['data_files']

        # Create dataset
        # NOTE(review): a class with this name is also defined earlier in this
        # file; confirm the `data_augmentation` module is importable here.
        from data_augmentation import AdvancedUIDataAugmenter
        augmenter = AdvancedUIDataAugmenter()

        dataset = UIActionDataset(
            data_files,
            transform=augmenter.visual_transforms,
            augment=True
        )

        # Split dataset
        total_size = len(dataset)
        train_size = int(0.7 * total_size)
        val_size = int(0.15 * total_size)
        test_size = total_size - train_size - val_size

        # Only pass a generator when a seed is configured, so the default
        # behaviour (global RNG) is unchanged.
        split_kwargs = {}
        seed = self.config.get('seed')
        if seed is not None:
            split_kwargs['generator'] = torch.Generator().manual_seed(int(seed))

        # NOTE(review): all three subsets are views of the same dataset object
        # built with augment=True; confirm augmentation is not applied to the
        # validation/test samples.
        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size, test_size], **split_kwargs
        )

        # Create dataloaders; they differ only in whether they shuffle.
        loader_kwargs = {
            'batch_size': self.config['batch_size'],
            'num_workers': self.config.get('num_workers', 4),
            'pin_memory': True,
        }
        train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
        val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
        test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

        self.action_encoder = dataset.action_encoder
        self.tracker.logger.info(f"Dataset prepared: {len(dataset)} total samples")
        self.tracker.logger.info(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

        return train_loader, val_loader, test_loader
    
    def create_model(self) -> ActionPredictionModel:
        """Create the model, optionally load pretrained weights, and move it to the device.

        The number of output classes comes from ``self.action_encoder``, which
        is assigned in ``prepare_data``; call that method first.

        Config keys read:
            input_size: model input size, converted to a tuple
                (default ``[224, 224]``).
            pretrained_path: optional checkpoint path. The checkpoint must
                contain a ``model_state_dict`` entry.

        Returns:
            The model, moved to ``self.device``.
        """
        num_classes = len(self.action_encoder.classes_)
        model = ActionPredictionModel(
            num_classes,
            input_size=tuple(self.config.get('input_size', [224, 224]))
        )

        # Load pretrained weights if specified
        pretrained_path = self.config.get('pretrained_path')
        if pretrained_path:
            # map_location remaps stored tensors onto the active device, so a
            # checkpoint saved on GPU still loads on a CPU-only machine.
            checkpoint = torch.load(pretrained_path, map_location=self.device)
            model.load_state_dict(checkpoint['model_state_dict'])
            self.tracker.logger.info(f"Loaded pretrained model from {self.config['pretrained_path']}")

        return model.to(self.device)
    
    def train(self):
        """Run the complete pipeline: data preparation, training and evaluation.

        Trains for up to ``num_epochs`` (config, default 100) epochs. After
        each epoch the metrics are logged through ``self.tracker``. Whenever
        validation accuracy improves (and always after the first epoch, so a
        checkpoint exists once training has run) the model checkpoint and the
        pickled action encoder are written under ``<output_dir>/models``.
        Training stops early once accuracy has not improved for
        ``early_stopping_patience`` (config, default 15) consecutive epochs.

        The final evaluation runs on the best validation weights, not on the
        weights of the last epoch, so the reported test metrics describe the
        checkpoint whose path is returned. Evaluation results and the metric
        history are written under ``<output_dir>/logs``; the confusion matrix
        goes to ``<output_dir>/plots``.

        Returns:
            Path (as ``str``) of ``best_model.pth``. The file is only written
            if at least one epoch ran.
        """
        import copy
        import pickle

        self.tracker.logger.info("Starting training pipeline...")

        # Prepare data
        train_loader, val_loader, test_loader = self.prepare_data()

        # Create model
        model = self.create_model()

        # Create trainer
        trainer = UIActionTrainer(model, self.device)

        # Training parameters
        num_epochs = self.config.get('num_epochs', 100)
        early_stopping_patience = self.config.get('early_stopping_patience', 15)
        best_val_accuracy = 0
        best_state = None  # CPU snapshot of the best weights seen so far
        patience_counter = 0

        model_path = self.output_dir / 'models' / 'best_model.pth'
        encoder_path = self.output_dir / 'models' / 'action_encoder.pkl'

        # Training loop
        for epoch in range(num_epochs):
            self.tracker.logger.info(f"Epoch {epoch+1}/{num_epochs}")

            # Train
            train_metrics = trainer.train_epoch(train_loader)

            # Validate
            val_metrics = trainer.validate(val_loader)

            # Combine metrics
            combined_metrics = {
                'epoch': epoch + 1,
                'train_loss': train_metrics['total_loss'],
                'train_action_loss': train_metrics['action_loss'],
                'train_coord_loss': train_metrics['coord_loss'],
                'val_loss': val_metrics['total_loss'],
                'val_accuracy': val_metrics['action_accuracy'],
                'val_coord_mae': val_metrics['coord_mae']
            }

            # Log metrics
            self.tracker.log_metrics(combined_metrics, epoch + 1)

            # Save best model. The first epoch always counts as an
            # improvement; otherwise a run whose accuracy never rises above 0
            # would return the path of a checkpoint that was never written.
            val_accuracy = val_metrics['action_accuracy']
            if best_state is None or val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                patience_counter = 0

                # Keep an independent CPU copy so later epochs cannot mutate it.
                best_state = {
                    name: (value.detach().cpu().clone()
                           if torch.is_tensor(value) else copy.deepcopy(value))
                    for name, value in model.state_dict().items()
                }

                # Save model ('epoch' is the 0-based loop index).
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': trainer.optimizer.state_dict(),
                    'val_accuracy': best_val_accuracy,
                    'config': self.config
                }, model_path)

                # Save action encoder
                with open(encoder_path, 'wb') as f:
                    pickle.dump(self.action_encoder, f)

                self.tracker.logger.info(f"Saved best model with accuracy: {best_val_accuracy:.4f}")
            else:
                patience_counter += 1

            # Early stopping
            if patience_counter >= early_stopping_patience:
                self.tracker.logger.info(f"Early stopping at epoch {epoch+1}")
                break

        # Evaluate the weights that were actually saved, not whatever the
        # last (possibly worse) epoch left in the model.
        if best_state is not None:
            model.load_state_dict(best_state)
            self.tracker.logger.info("Restored best validation weights for final evaluation")

        # Final evaluation
        self.tracker.logger.info("Starting final evaluation...")
        evaluator = ModelEvaluator(model, None, self.action_encoder, self.device)

        # Comprehensive evaluation
        test_metrics = evaluator.evaluate_comprehensive(test_loader)

        # Plot confusion matrix
        evaluator.plot_confusion_matrix(
            test_loader,
            str(self.output_dir / 'plots' / 'confusion_matrix.png')
        )

        # Analyze failure cases
        failure_cases = evaluator.analyze_failure_cases(test_loader)

        # Save evaluation results
        evaluation_results = {
            'test_metrics': test_metrics,
            'failure_cases': failure_cases,
            'final_config': self.config
        }

        # default=str stringifies any value json cannot encode natively.
        with open(self.output_dir / 'logs' / 'evaluation_results.json', 'w') as f:
            json.dump(evaluation_results, f, indent=2, default=str)

        # Save metrics history
        self.tracker.save_metrics(str(self.output_dir / 'logs' / 'training_metrics.json'))

        self.tracker.logger.info("Training pipeline completed!")
        self.tracker.logger.info(f"Best validation accuracy: {best_val_accuracy:.4f}")
        self.tracker.logger.info(f"Results saved to: {self.output_dir}")

        return str(model_path)

def create_config_template() -> Dict:
    """Return a fresh dictionary holding the default experiment settings.

    Every call builds a new dictionary (and a new ``input_size`` list), so
    callers may modify the result freely.
    """
    template = dict(
        experiment_name='ui_action_experiment_1',
        data_files=['login_task_data.json'],  # List of your data files
        output_dir='experiments/exp_1',
        batch_size=16,
        num_epochs=100,
        early_stopping_patience=15,
        input_size=[224, 224],
        enhance_data=True,
        enhancement_factor=8,
        use_wandb=False,  # Set to True if you want to use Weights & Biases
        pretrained_path=None,  # Path to pretrained model if available
    )
    return template


In [None]:

def main():
    parser = argparse.ArgumentParser(description='UI Action Prediction Training')
    parser.add_argument('--config', type=str, help='Path to config file')
    parser.add_argument('--data-dir', type=str, help='Directory containing data files')
    parser.add_argument('--experiment-name', type=str, default='ui_action_exp', help='Experiment name')
    
    args = parser.parse_args()
    
    # Load or create config
    if args.config and os.path.exists(args.config):
        with open(args.config, 'r') as f:
            config = json.load(f)
    else:
        config = create_config_template()
        
        # Update with command line arguments
        if args.data_dir:
            data_files = list(Path(args.data_dir).glob('*.json'))
            config['data_files'] = [str(f) for f in data_files]
        
        config['