In [None]:
import pandas as pd
import numpy as np
import re
import json
import os
from datetime import datetime, timedelta
import random
from cryptography.fernet import Fernet
import string
import base64
import hashlib
import getpass
import uuid


In [None]:
def setup_directories():
    """Ensure the data, results and config folders used by this notebook exist.

    All paths are relative to the current working directory. One status
    line is printed per folder, saying whether it was created or was
    already present.
    """
    required_paths = (
        '../data',
        '../results',
        '../results/reversed_datasets',
        '../config',
    )
    for path in required_paths:
        if os.path.exists(path):
            print(f"Directory exists: {path}")
            continue
        os.makedirs(path)
        print(f"Created directory: {path}")

# Show the directory the relative '../...' paths are resolved against,
# then create any missing project folders.
print("Current working directory:", os.getcwd())
setup_directories()

In [None]:
# Short identifier for this run: the first 8 characters of a random UUID4.
# Later cells embed it in the names of the files they write.
SESSION_ID = str(uuid.uuid4())[:8]
print(f"Session ID: {SESSION_ID}")

In [None]:
# Expected PII column names grouped by category.
# NOTE(review): in the code visible here this dict is only printed; the
# masking pipeline keeps its own column -> masker table with the same
# column names, so the two must be kept in sync by hand.
PII_COLUMNS = {
    'names': ['full_name', 'doctor_name'],
    'addresses': ['address_street', 'address_city', 'address_zip'],
    'dates': ['date_of_birth', 'admission_date', 'discharge_date'],
    'contact': ['phone_number', 'email_address'],
    'identifiers': ['ssn', 'license_number', 'vehicle_id', 'device_serial_number'],
    'network': ['ip_address'],
    'organizations': ['hospital_name', 'insurance_provider']
}

# Echo the grouping so the reader can check it against the dataset.
print("PII Column Categories:")
for category, columns in PII_COLUMNS.items():
    print(f"{category}: {columns}")

In [None]:
class PIIDetector:
    """Heuristic detector that flags DataFrame columns likely to contain PII.

    Each column is scored from two signals:

    * its name: substring matches against the keyword lists in
      ``pii_keywords`` (0.7 per matching category);
    * its content: the share of sampled non-null values that match one of
      the regular expressions in ``patterns``.

    The scores are added up and capped at 1.0; a column is reported when the
    total is above 0.5. When several signals fire, the last one evaluated
    decides the reported type (content patterns are evaluated after names).
    """

    # Content pattern name -> (PII type it implies, confidence it adds).
    # Patterns without an entry here ('ssn', 'zip_code') stay available in
    # ``patterns`` but do not influence detection, so they are not evaluated.
    _PATTERN_SCORES = {
        'email': ('contact', 0.8),
        'phone': ('contact', 0.8),
        'ip_v4': ('network', 0.9),
        'ip_v6': ('network', 0.9),
        'date': ('date', 0.6),
    }

    def __init__(self):
        # Only non-capturing groups are used: pandas warns on every
        # ``str.contains`` call whose pattern has capturing groups.
        self.patterns = {
            # TLD class is [A-Za-z]; a '|' inside a character class is a literal.
            'email': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
            'phone': r'\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}(?:\s*x\d+)?',
            'ssn': r'\b\d{3}-?\d{2}-?\d{4}\b',
            'ip_v4': r'\b(?:\d{1,3}\.){3}\d{1,3}\b',
            'ip_v6': r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b',
            'date': r'\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}',
            'zip_code': r'\b\d{5}(?:-\d{4})?\b',
        }

        self.pii_keywords = {
            'name': ['name', 'first', 'last', 'full_name', 'doctor', 'patient'],
            'address': ['address', 'street', 'city', 'state', 'zip', 'postal'],
            'contact': ['phone', 'email', 'telephone', 'mobile'],
            'id': ['id', 'ssn', 'social', 'license', 'serial', 'number'],
            'date': ['date', 'birth', 'dob', 'admission', 'discharge'],
            'organization': ['hospital', 'insurance', 'provider', 'company']
        }

    def detect_pii_columns(self, df):
        """Detect PII columns based on column names and content patterns.

        Args:
            df: DataFrame to inspect. Column labels may be of any type; they
                are converted to strings for the keyword check.

        Returns:
            dict mapping each flagged column label to
            ``{'type': <str>, 'confidence': <float, at most 1.0>}``.
        """
        pii_detected = {}

        for column in df.columns:
            # str() so integer or other non-string labels do not crash here.
            column_lower = str(column).lower()
            pii_type = None
            confidence = 0

            # Name signal: every category with a matching keyword adds 0.7;
            # the last matching category provides the provisional type.
            for pii_category, keywords in self.pii_keywords.items():
                if any(keyword in column_lower for keyword in keywords):
                    pii_type = pii_category
                    confidence += 0.7

            # Content signal: look at up to 100 non-null values.
            sample_data = df[column].dropna().head(100).astype(str)

            if len(sample_data) > 0:
                for pattern_name, pattern in self.patterns.items():
                    scored = self._PATTERN_SCORES.get(pattern_name)
                    if scored is None:
                        continue

                    matches = sample_data.str.contains(pattern, regex=True, na=False).sum()
                    if matches / len(sample_data) > 0.5:
                        pii_type, weight = scored
                        confidence += weight

            if pii_type and confidence > 0.5:
                pii_detected[column] = {
                    'type': pii_type,
                    'confidence': min(confidence, 1.0)
                }

        return pii_detected

    def validate_pii_detection(self, df, detected_pii):
        """Print the detection result for manual validation.

        For every detected column this prints its type, confidence and the
        first three non-null values. Those values are raw data, so the
        output of this method may itself contain PII.
        """
        print("PII Detection Results:")
        print("-" * 50)

        for column, info in detected_pii.items():
            sample_values = df[column].dropna().head(3).tolist()
            print(f"Column: {column}")
            print(f"  Type: {info['type']}")
            print(f"  Confidence: {info['confidence']:.2f}")
            print(f"  Sample values: {sample_values}")
            print()

# Shared detector instance; the detection cell further down calls it on the
# loaded dataset.
detector = PIIDetector()
print("PII Detector initialized.")

In [None]:
class PIIEncryptionSystem:
    """Password-based encryption for the original -> masked PII mappings.

    The Fernet key is derived from a user password with PBKDF2-HMAC-SHA256.
    Each key uses a random salt, which is not secret and is stored next to
    the ciphertext so the key can be re-derived for decryption. Payloads
    that carry no ``salt`` field are decrypted with the fixed salt that
    earlier versions of this class used.
    """

    # Fixed salt of the original scheme; only used to read payloads that
    # were written without a stored salt.
    _LEGACY_SALT = b'pii_masking_salt_2024'
    _KDF_ITERATIONS = 100000
    _SALT_BYTES = 16

    def __init__(self):
        self.key = None
        self.fernet = None
        self.encrypted_mappings = None
        # NOTE(review): the plaintext password is kept on the instance for
        # compatibility with existing callers; nothing in this class reads it.
        self.user_password = None
        self.salt = None
        self.kdf_iterations = self._KDF_ITERATIONS

    @classmethod
    def _derive_key(cls, password, salt, iterations=None):
        """Return the urlsafe-base64 Fernet key for ``password`` and ``salt``."""
        if iterations is None:
            iterations = cls._KDF_ITERATIONS
        raw_key = hashlib.pbkdf2_hmac('sha256', password.encode('utf-8'), salt, iterations)
        # Fernet requires a base64 encoded 32-byte key
        return base64.urlsafe_b64encode(raw_key)

    @classmethod
    def _kdf_params(cls, payload):
        """Return ``(salt, iterations)`` recorded in a payload dict.

        Falls back to the legacy fixed salt and the default iteration count
        when the payload does not carry them.
        """
        salt_b64 = payload.get('salt')
        salt = base64.b64decode(salt_b64) if salt_b64 else cls._LEGACY_SALT
        iterations = int(payload.get('kdf_iterations', cls._KDF_ITERATIONS))
        return salt, iterations

    def generate_encryption_key(self, user_password=None):
        """Derive the encryption key from a user password.

        Args:
            user_password: password to use; when None it is requested
                interactively with ``getpass``.

        Returns:
            True once ``self.key`` and ``self.fernet`` are set.

        If encrypted mappings are already loaded on this instance, their
        stored salt is reused so the derived key matches that data;
        otherwise a fresh random salt is generated.
        """
        if user_password is None:
            user_password = getpass.getpass("Enter encryption password for this dataset: ")

        self.user_password = user_password

        salt, iterations = None, self._KDF_ITERATIONS
        if self.encrypted_mappings is not None:
            try:
                salt, iterations = self._kdf_params(self.encrypted_mappings)
            except (ValueError, TypeError, AttributeError):
                # Unusable stored parameters: fall through to a fresh salt.
                salt, iterations = None, self._KDF_ITERATIONS
        if salt is None:
            salt = os.urandom(self._SALT_BYTES)

        self.salt = salt
        self.kdf_iterations = iterations
        self.key = self._derive_key(user_password, salt, iterations)
        self.fernet = Fernet(self.key)

        print("Encryption key generated from password")
        return True

    def save_key_info(self, filepath):
        """Save key generation info (not the actual key) as JSON.

        Returns True after the file has been written.
        """
        key_info = {
            'session_id': SESSION_ID,
            'timestamp': datetime.now().isoformat(),
            'key_generated': True,
            'note': 'Key generated from user password'
        }

        with open(filepath, 'w') as f:
            json.dump(key_info, f, indent=2)

        print(f"Key info saved to: {filepath}")
        return True

    def encrypt_mappings(self, mappings_dict):
        """Encrypt the mappings dictionary into ``self.encrypted_mappings``.

        The stored payload holds the session id, a timestamp, the
        base64-encoded Fernet token and, when known, the salt and iteration
        count needed to re-derive the key.

        Returns:
            True on success, False when no key is available or encryption
            fails.
        """
        if self.fernet is None:
            print("No encryption key available!")
            return False

        try:
            # Convert mappings to JSON string
            mappings_json = json.dumps(mappings_dict, indent=2)
            mappings_bytes = mappings_json.encode('utf-8')

            # Encrypt the mappings
            encrypted_data = self.fernet.encrypt(mappings_bytes)

            payload = {
                'session_id': SESSION_ID,
                'timestamp': datetime.now().isoformat(),
                'encrypted_data': base64.b64encode(encrypted_data).decode('utf-8')
            }
            if self.salt is not None:
                payload['salt'] = base64.b64encode(self.salt).decode('ascii')
                payload['kdf_iterations'] = self.kdf_iterations
            self.encrypted_mappings = payload

            print("Mappings encrypted successfully")
            return True

        except Exception as e:
            print(f"Encryption error: {str(e)}")
            return False

    def save_encrypted_mappings(self, filepath):
        """Save encrypted mappings to a JSON file.

        Returns True on success, False when there is nothing to save or the
        write fails.
        """
        if self.encrypted_mappings is None:
            print("No encrypted mappings to save!")
            return False

        try:
            with open(filepath, 'w') as f:
                json.dump(self.encrypted_mappings, f, indent=2)

            print(f"Encrypted mappings saved to: {filepath}")
            return True

        except Exception as e:
            print(f"Save error: {str(e)}")
            return False

    def load_encrypted_mappings(self, filepath):
        """Load encrypted mappings from a JSON file.

        Returns True on success, False when the file is missing or cannot
        be read.
        """
        try:
            with open(filepath, 'r') as f:
                self.encrypted_mappings = json.load(f)

            print(f"Encrypted mappings loaded from: {filepath}")
            return True

        except FileNotFoundError:
            print(f"Encrypted mappings file not found: {filepath}")
            return False
        except Exception as e:
            print(f"Load error: {str(e)}")
            return False

    def decrypt_mappings_with_password(self, password):
        """Decrypt the loaded mappings using the provided password.

        The key is re-derived from ``password`` with the salt stored in the
        payload (or the legacy fixed salt for payloads without one).

        Returns:
            The decrypted mappings dict, or None when nothing is loaded or
            decryption fails.
        """
        if self.encrypted_mappings is None:
            print("No encrypted mappings loaded!")
            return None

        try:
            salt, iterations = self._kdf_params(self.encrypted_mappings)
            temp_fernet = Fernet(self._derive_key(password, salt, iterations))

            # Decrypt the data
            encrypted_data = base64.b64decode(self.encrypted_mappings['encrypted_data'])
            decrypted_bytes = temp_fernet.decrypt(encrypted_data)
            decrypted_json = decrypted_bytes.decode('utf-8')

            mappings = json.loads(decrypted_json)
            print("Mappings decrypted successfully")
            return mappings

        except Exception as e:
            # Fernet's InvalidToken has an empty message, so name the error.
            reason = str(e) or f"{type(e).__name__} (wrong password or corrupted data)"
            print(f"Decryption failed: {reason}")
            return None

    def reverse_mappings(self, mappings_dict):
        """Create reverse (masked -> original) mappings for data recovery.

        Masked values are converted to strings. If two originals in one
        category share a masked value, the later one wins and a warning with
        the number of affected entries is printed.
        """
        reverse_mappings = {}

        for category, mapping in mappings_dict.items():
            reversed_category = {}
            collisions = 0
            for original, masked in mapping.items():
                masked_key = str(masked)
                if masked_key in reversed_category:
                    collisions += 1
                reversed_category[masked_key] = original
            if collisions:
                print(f"WARNING: {collisions} masked value(s) in '{category}' map to more than one original; reversal is ambiguous for them")
            reverse_mappings[category] = reversed_category

        return reverse_mappings

# Initialize the shared encryption system. No key exists yet at this point:
# the masking cell further down calls generate_encryption_key() on it.
encryption_system = PIIEncryptionSystem()
print("Encryption system initialized")

In [None]:
def load_and_analyze_dataset(file_path='../data/generated_data.csv', preview_rows=3):
    """Load a CSV dataset and print a short summary of it.

    Args:
        file_path: path of the CSV file to read.
        preview_rows: number of leading rows to print field by field.
            The preview shows raw values, so it may put PII into the
            notebook output; pass 0 to skip it.

    Returns:
        The loaded DataFrame, or None when the file is missing or cannot be
        read (a message is printed in both cases).
    """
    try:
        df = pd.read_csv(file_path)

        print("Dataset loaded successfully!")
        print(f"Shape: {df.shape}")
        print(f"Columns: {list(df.columns)}")

        if preview_rows and preview_rows > 0:
            print(f"\nFirst {preview_rows} rows preview:")
            print("-" * 80)

            # Number rows by position so the label does not depend on the index.
            preview = df.head(preview_rows)
            for row_number, (_, row) in enumerate(preview.iterrows(), start=1):
                print(f"\nRow {row_number}:")
                for col in df.columns:
                    print(f"  {col}: {row[col]}")

        print("\nMissing values per column:")
        missing = df.isnull().sum()
        for col, count in missing.items():
            if count > 0:
                print(f"  {col}: {count}")

        return df

    except FileNotFoundError:
        print(f"Error: File {file_path} not found.")
        print("Please make sure your dataset is in the correct location.")
        return None
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        return None

In [None]:
# Load the dataset from its default path; df is None when loading failed.
df = load_and_analyze_dataset()

if df is not None:
    # Detect PII columns and print them (with sample values) for review
    detected_pii = detector.detect_pii_columns(df)
    detector.validate_pii_detection(df, detected_pii)
    
    # Save detected PII info as JSON, named after the current session
    pii_config_path = f'../config/detected_pii_{SESSION_ID}.json'
    with open(pii_config_path, 'w') as f:
        json.dump(detected_pii, f, indent=2)
    
    print(f"\nDetected {len(detected_pii)} PII columns out of {len(df.columns)} total columns.")
    print(f"PII detection saved to: {pii_config_path}")
else:
    print("Please check your dataset file and try again.")

In [None]:
class EnhancedBulletproofPIIMasker:
    """Format-preserving PII masker that records every original -> masked pair.

    Every ``mask_*`` method:

    * returns NaN / blank input unchanged;
    * strips the value and replaces it with random characters of the same
      kind (digit for digit, letter for letter, separators kept), so the
      masked string has exactly the length of the stripped original;
    * stores the pair in ``self.mappings[<category>]`` and returns the stored
      value whenever the same original is seen again;
    * re-rolls a candidate that was already issued for a different original
      in the same category, so the mapping can be inverted.

    Replacement characters come from the ``random`` module; seed it for
    repeatable output. Within one masker instance consistency comes from
    the stored mapping, not from the generator.
    """

    # Upper bound on re-rolls when a candidate collides with a masked value
    # already issued in the same category.
    _MAX_UNIQUE_ATTEMPTS = 1000

    _HEX_CHARS = '0123456789abcdefABCDEF'
    # Digit count of an IPv4 octet -> inclusive range to draw a replacement from.
    _OCTET_RANGES = {1: (1, 9), 2: (10, 99), 3: (100, 254)}

    def __init__(self):
        self.mappings = {
            'names': {},
            'addresses': {},
            'emails': {},
            'phones': {},
            'organizations': {},
            'phone_numbers': {},
            'ssn_numbers': {},
            'license_numbers': {},
            'vehicle_ids': {},
            'device_serials': {},
            'medical_records': {},
            'ip_addresses': {},
            'dates': {}
        }
        self.fake_domains = ["gmail.com", "yahoo.com", "hotmail.com", "outlook.com", "example.com"]
        # category -> [mapping size at last sync, set of masked values issued]
        self._issued = {}

    # ------------------------------------------------------------------
    # Shared machinery
    # ------------------------------------------------------------------
    def _issued_values(self, category):
        """Return the set of masked values already used in ``category``.

        The set is rebuilt whenever the mapping size no longer matches the
        size seen at the last sync (e.g. the mapping was edited externally).
        """
        mapping = self.mappings[category]
        entry = self._issued.get(category)
        if entry is None or entry[0] != len(mapping):
            entry = [len(mapping), set(mapping.values())]
            self._issued[category] = entry
        return entry[1]

    def _mask_with_mapping(self, value, category, generate, label):
        """Mask ``value`` with ``generate`` and record it under ``category``.

        Args:
            value: raw cell value; NaN and blank strings are returned as is.
            category: key of ``self.mappings`` that stores the pair.
            generate: callable taking the stripped original string and
                returning a random same-length replacement.
            label: prefix used in the length assertion message.

        Raises:
            AssertionError: if the replacement length differs from the
                stripped original.
        """
        if pd.isna(value) or str(value).strip() == '':
            return value

        original_str = str(value).strip()
        mapping = self.mappings[category]
        if original_str in mapping:
            return mapping[original_str]

        issued = self._issued_values(category)
        candidate = generate(original_str)
        attempts = 1
        while candidate in issued and attempts < self._MAX_UNIQUE_ATTEMPTS:
            candidate = generate(original_str)
            attempts += 1
        if candidate in issued:
            # The space of same-format strings is (nearly) exhausted. Keep
            # masking rather than fail, but say that reversal is ambiguous.
            print(f"WARNING: no unique masked value found in '{category}' "
                  f"after {attempts} attempts; reversal of one value will be ambiguous")

        assert len(candidate) == len(original_str), \
            f"{label} length ERROR: {len(candidate)} != {len(original_str)}"

        mapping[original_str] = candidate
        issued.add(candidate)
        self._issued[category][0] = len(mapping)
        return candidate

    @staticmethod
    def _scramble(text, letter_mode=None):
        """Replace digits (and optionally letters) of ``text`` at random.

        ``letter_mode`` selects how letters are treated: None keeps them,
        'lower' draws lowercase, 'any' draws from both cases and 'match'
        keeps the case of the original letter. All other characters are kept.
        """
        pieces = []
        for char in text:
            if char.isdigit():
                pieces.append(random.choice(string.digits))
            elif letter_mode is not None and char.isalpha():
                if letter_mode == 'match':
                    pool = string.ascii_uppercase if char.isupper() else string.ascii_lowercase
                elif letter_mode == 'lower':
                    pool = string.ascii_lowercase
                else:
                    pool = string.ascii_letters
                pieces.append(random.choice(pool))
            else:
                pieces.append(char)
        return ''.join(pieces)

    @staticmethod
    def _fake_name(text):
        """Random name: spaces kept, word-initial characters uppercase, rest lowercase."""
        pieces = []
        for i, char in enumerate(text):
            if char == ' ':
                pieces.append(' ')
            elif i == 0 or text[i - 1] == ' ':
                pieces.append(random.choice(string.ascii_uppercase))
            else:
                pieces.append(random.choice(string.ascii_lowercase))
        return ''.join(pieces)

    @staticmethod
    def _fake_ssn(text):
        """Random SSN: digits replaced, dashes kept, anything else becomes a lowercase letter."""
        pieces = []
        for char in text:
            if char.isdigit():
                pieces.append(random.choice(string.digits))
            elif char == '-':
                pieces.append('-')
            else:
                pieces.append(random.choice(string.ascii_lowercase))
        return ''.join(pieces)

    @classmethod
    def _fake_octet(cls, part):
        """Random replacement for one dotted part with the same number of characters."""
        bounds = cls._OCTET_RANGES.get(len(part))
        if part.isdigit() and bounds is not None:
            return str(random.randint(*bounds))
        # Not a 1-3 digit number: fall back to character-wise replacement.
        return cls._scramble(part, 'lower')

    @classmethod
    def _fake_ip(cls, text):
        """Random address that keeps the separators and length of ``text``."""
        if '.' in text and ':' not in text:
            # IPv4 format: replace every octet by one of equal width so the
            # result needs no truncation or padding.
            return '.'.join(cls._fake_octet(part) for part in text.split('.'))
        if ':' in text:
            # IPv6 format - preserve structure, replace hex digits
            return ''.join(
                random.choice('0123456789abcdef') if char in cls._HEX_CHARS else char
                for char in text
            )
        # Unknown format - character by character
        return cls._scramble(text, 'lower')

    # ------------------------------------------------------------------
    # Public maskers
    # ------------------------------------------------------------------
    def mask_name_bulletproof(self, original_name):
        """Mask a name; pairs are stored in ``mappings['names']``."""
        return self._mask_with_mapping(original_name, 'names', self._fake_name, 'Name')

    def mask_address_bulletproof(self, original_address):
        """Mask an address part (digits -> digits, letters -> letters of any case)."""
        return self._mask_with_mapping(
            original_address, 'addresses',
            lambda text: self._scramble(text, 'any'), 'Address')

    def mask_email_bulletproof(self, original_email):
        """Mask an email, keeping '@', '.' and other punctuation in place."""
        return self._mask_with_mapping(
            original_email, 'emails',
            lambda text: self._scramble(text, 'lower'), 'Email')

    def mask_phone_bulletproof(self, original_phone):
        """Mask the digits of a phone number; every other character is kept."""
        return self._mask_with_mapping(
            original_phone, 'phone_numbers', self._scramble, 'Phone')

    def mask_ssn_bulletproof(self, original_ssn):
        """Mask an SSN, keeping the dash positions."""
        return self._mask_with_mapping(original_ssn, 'ssn_numbers', self._fake_ssn, 'SSN')

    def mask_license_bulletproof(self, original_license):
        """Mask a license number (digits -> digits, letters -> same-case letters)."""
        return self._mask_with_mapping(
            original_license, 'license_numbers',
            lambda text: self._scramble(text, 'match'), 'License')

    def mask_vehicle_id_bulletproof(self, original_vehicle_id):
        """Mask a vehicle ID (digits -> digits, letters -> same-case letters)."""
        return self._mask_with_mapping(
            original_vehicle_id, 'vehicle_ids',
            lambda text: self._scramble(text, 'match'), 'Vehicle ID')

    def mask_device_serial_bulletproof(self, original_serial):
        """Mask a device serial (digits -> digits, letters -> same-case letters)."""
        return self._mask_with_mapping(
            original_serial, 'device_serials',
            lambda text: self._scramble(text, 'match'), 'Serial')

    def mask_ip_address_bulletproof(self, original_ip):
        """Mask an IPv4 / IPv6 / other address string, keeping its layout."""
        return self._mask_with_mapping(original_ip, 'ip_addresses', self._fake_ip, 'IP')

    def mask_date_bulletproof(self, original_date):
        """Mask the digits of a date string, keeping separators.

        The digits are random, so the result has the original layout but is
        not guaranteed to be a valid calendar date.
        """
        return self._mask_with_mapping(original_date, 'dates', self._scramble, 'Date')

In [None]:
# Initialize the enhanced masker. Its mappings start empty and are filled as
# values are masked; the masking cell later encrypts them.
enhanced_masker = EnhancedBulletproofPIIMasker()
print("Enhanced Bulletproof PII Masker initialized successfully!")

In [None]:
class FullReversibleMaskingPipeline:
    """Applies the masker to every known PII column of a DataFrame.

    Columns are selected by name through ``column_maskers``. The detection
    result passed to the constructor is used only to warn about columns that
    were flagged as PII but have no masker configured.
    """

    def __init__(self, enhanced_masker, detected_pii):
        """Store the masker and detection result and build the column table.

        Args:
            enhanced_masker: object providing the ``mask_*_bulletproof``
                methods referenced below.
            detected_pii: mapping of column name -> detection info, as
                produced by the detector (may be None or empty).
        """
        self.masker = enhanced_masker
        self.detected_pii = detected_pii

        # Column maskers mapping
        self.column_maskers = {
            'full_name': self.masker.mask_name_bulletproof,
            'doctor_name': self.masker.mask_name_bulletproof,
            'address_street': self.masker.mask_address_bulletproof,
            'address_city': self.masker.mask_address_bulletproof,
            'address_zip': self.masker.mask_address_bulletproof,
            'date_of_birth': self.masker.mask_date_bulletproof,
            'admission_date': self.masker.mask_date_bulletproof,
            'discharge_date': self.masker.mask_date_bulletproof,
            'phone_number': self.masker.mask_phone_bulletproof,
            'email_address': self.masker.mask_email_bulletproof,
            'ssn': self.masker.mask_ssn_bulletproof,
            'license_number': self.masker.mask_license_bulletproof,
            'vehicle_id': self.masker.mask_vehicle_id_bulletproof,
            'device_serial_number': self.masker.mask_device_serial_bulletproof,
            'ip_address': self.masker.mask_ip_address_bulletproof,
            'hospital_name': self.masker.mask_name_bulletproof,
            'insurance_provider': self.masker.mask_name_bulletproof
        }

    @staticmethod
    def _count_length_errors(original_values, masked_values):
        """Count pairs whose stripped string lengths differ (NaN pairs skipped)."""
        return sum(
            1
            for orig, masked in zip(original_values, masked_values)
            if not pd.isna(orig) and not pd.isna(masked)
            and len(str(orig).strip()) != len(str(masked).strip())
        )

    def mask_dataframe(self, df):
        """Apply masking to every configured column of ``df``.

        The input frame is not modified. A column whose masker raises is
        blanked (set to None) in the returned frame so that no raw values
        end up in the "masked" output; the failure is recorded in the report.

        Returns:
            ``(masked_df, masking_report)`` where the report maps each column
            to a dict with ``status`` 'success', 'error' or 'unchanged'.
            Successful columns also carry row count, length-error count and
            three original / masked samples (the original samples are raw
            data). Unchanged columns that the detector flagged as PII carry
            an additional ``warning`` entry.
        """
        masked_df = df.copy()
        masking_report = {}
        detected = self.detected_pii or {}

        print("Starting BULLETPROOF masking process...")
        print("-" * 50)

        for column in df.columns:
            mask_function = self.column_maskers.get(column)

            if mask_function is None:
                if column in detected:
                    # Flagged by the detector, yet nothing here can mask it:
                    # say so instead of calling it "not PII".
                    info = detected[column]
                    pii_type = info.get('type', 'unknown') if isinstance(info, dict) else 'unknown'
                    print(f"WARNING: {column} was detected as PII ({pii_type}) "
                          f"but has no masker configured - left UNMASKED")
                    masking_report[column] = {
                        'status': 'unchanged',
                        'warning': f"detected as PII ({pii_type}) but no masker configured"
                    }
                else:
                    print(f"Keeping original: {column} (not PII)")
                    masking_report[column] = {'status': 'unchanged'}
                continue

            print(f"Masking column: {column}")

            try:
                original_values = df[column].tolist()
                masked_values = [mask_function(value) for value in original_values]
                masked_df[column] = masked_values

                length_errors = self._count_length_errors(original_values, masked_values)

                masking_report[column] = {
                    'status': 'success',
                    'total_rows': len(df),
                    'length_errors': length_errors,
                    'sample_original': df[column].dropna().head(3).tolist(),
                    'sample_masked': pd.Series(masked_values).dropna().head(3).tolist()
                }

                print(f"  Successfully masked {len(masked_values)} values")
                if length_errors > 0:
                    print(f"  {length_errors} length mismatches detected")

            except Exception as e:
                print(f"  Error masking column {column}: {str(e)}")
                # Fail closed: the copy still holds the raw values here.
                masked_df[column] = None
                print(f"  Column {column} blanked in the masked output to avoid leaking raw values")
                masking_report[column] = {
                    'status': 'error',
                    'error': str(e)
                }

        print(f"\nBulletproof masking completed!")
        return masked_df, masking_report

    def display_masking_report(self, masking_report):
        """Print the per-column masking report followed by a summary.

        The summary states the total number of length errors and lists any
        columns that failed to mask.
        """
        print("\nBulletproof Masking Report:")
        print("=" * 70)

        total_length_errors = 0
        failed_columns = []

        for column, report in masking_report.items():
            print(f"\nColumn: {column}")
            print(f"Status: {report['status']}")

            if report['status'] == 'success':
                length_errors = report.get('length_errors', 0)
                total_length_errors += length_errors

                print(f"Length errors: {length_errors}")
                print(f"Original samples: {report['sample_original']}")
                print(f"Masked samples:   {report['sample_masked']}")
            elif report['status'] == 'error':
                failed_columns.append(column)
                print(f"Error: {report['error']}")
            elif 'warning' in report:
                print(f"Warning: {report['warning']}")

        print(f"\n{'='*70}")
        if total_length_errors == 0:
            print("PERFECT! No length preservation errors!")
        else:
            print(f"Total length errors across all columns: {total_length_errors}")
        if failed_columns:
            print(f"Columns that failed to mask: {failed_columns}")

# Confirmation that the cell defining the pipeline class has been executed.
print("Masking Pipeline class created successfully!")

In [None]:
# End-to-end masking run. Requires `df` and `detected_pii` from the earlier
# cells. Order of operations: derive the encryption key (prompts for a
# password), mask the dataframe, write the masked CSV, then encrypt and save
# the original -> masked mappings together with key and session metadata.
# Note that the masked CSV is written before the mappings are encrypted, so
# it exists on disk even if the encryption step fails.
if 'df' in locals() and df is not None and 'detected_pii' in locals():
    print("Applying Enhanced Masking with Encryption...")
    
    # Generate encryption key from user password
    print("\n" + "="*60)
    print("ENCRYPTION SETUP")
    print("="*60)
    encryption_system.generate_encryption_key()
    
    # Create masking pipeline
    pipeline = FullReversibleMaskingPipeline(enhanced_masker, detected_pii)
    
    # Apply masking
    print("\n" + "="*60)
    print("APPLYING MASKING")
    print("="*60)
    masked_df, masking_report = pipeline.mask_dataframe(df)
    
    # Display masking report
    pipeline.display_masking_report(masking_report)
    
    # Save masked dataset to results folder; the file name carries the
    # session id and a timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    masked_filename = f"masked_dataset_{SESSION_ID}_{timestamp}.csv"
    masked_filepath = os.path.join('../results', masked_filename)
    
    masked_df.to_csv(masked_filepath, index=False)
    print(f"\n✅ Masked dataset saved to: {masked_filepath}")
    
    # Encrypt and save mappings
    print("\n" + "="*60)
    print("ENCRYPTING MAPPINGS")
    print("="*60)
    
    # Get all mappings from the masker (the masker's own dict, not a copy)
    all_mappings = enhanced_masker.mappings
    
    # Display mapping statistics (only categories that received entries)
    print("Mapping categories created:")
    for category, mappings in all_mappings.items():
        if mappings:
            print(f"  {category}: {len(mappings)} mappings")
    
    # Encrypt the mappings
    encryption_success = encryption_system.encrypt_mappings(all_mappings)
    
    if encryption_success:
        # Save encrypted mappings
        encrypted_mappings_path = f'../config/encrypted_mappings_{SESSION_ID}.json'
        encryption_system.save_encrypted_mappings(encrypted_mappings_path)
        
        # Save key info (not the actual key)
        key_info_path = f'../config/key_info_{SESSION_ID}.json'
        encryption_system.save_key_info(key_info_path)
        
        print(f"✅ Encrypted mappings saved to: {encrypted_mappings_path}")
        print(f"✅ Key info saved to: {key_info_path}")
        
        # Save session info for easy reversal: it records where the masked
        # dataset, encrypted mappings and key info of this run were written
        session_info = {
            'session_id': SESSION_ID,
            'timestamp': timestamp,
            'original_dataset_shape': df.shape,
            'masked_dataset_path': masked_filepath,
            'encrypted_mappings_path': encrypted_mappings_path,
            'key_info_path': key_info_path,
            'pii_columns_detected': len(detected_pii),
            'total_columns': len(df.columns),
            'note': 'Use this session info to reverse the dataset with the correct encryption password'
        }
        
        session_info_path = f'../config/session_info_{SESSION_ID}.json'
        with open(session_info_path, 'w') as f:
            json.dump(session_info, f, indent=2)
        
        print(f"✅ Session info saved to: {session_info_path}")
        
        print(f"\n" + "="*60)
        print("MASKING COMPLETE!")
        print("="*60)
        print(f"📁 Session ID: {SESSION_ID}")
        print(f"📁 Masked dataset: {masked_filepath}")
        print(f"🔒 Encrypted mappings: {encrypted_mappings_path}")
        print(f"📋 Session info: {session_info_path}")
        print(f"\n💡 To reverse this dataset, use the reversal function with:")
        print(f"   - Session ID: {SESSION_ID}")
        print(f"   - Your encryption password")
        
    else:
        print("❌ Failed to encrypt mappings!")
        
else:
    print("❌ Please run the previous cells to load data and detect PII first!")

In [None]:
# Enhanced column to category mapping: each known PII column name -> the
# mapping category whose reverse lookup table is used to restore it.
_REVERSAL_COLUMN_CATEGORIES = {
    'full_name': 'names',
    'doctor_name': 'names',
    'address_street': 'addresses',
    'address_city': 'addresses',
    'address_zip': 'addresses',
    'email_address': 'emails',
    'phone_number': 'phone_numbers',
    'ssn': 'ssn_numbers',
    'license_number': 'license_numbers',
    'vehicle_id': 'vehicle_ids',
    'device_serial_number': 'device_serials',
    'ip_address': 'ip_addresses',
    'hospital_name': 'names',
    'insurance_provider': 'names',
    'date_of_birth': 'dates',
    'admission_date': 'dates',
    'discharge_date': 'dates'
}


def _select_session_interactively(config_dir='../config'):
    """List the saved sessions and ask the user to pick one.

    Returns the chosen session ID as a string, or None when there is nothing
    to choose from or the user entered an out-of-range session number.
    """
    print("Available sessions:")
    try:
        entries = os.listdir(config_dir)
    except FileNotFoundError:
        # A missing config directory simply means no session was ever saved.
        entries = []

    # Sorted so the numbering shown to the user is stable between runs.
    config_files = sorted(
        name for name in entries
        if name.startswith('session_info_') and name.endswith('.json')
    )
    if not config_files:
        print("No sessions found!")
        return None

    known_ids = [
        name.replace('session_info_', '').replace('.json', '')
        for name in config_files
    ]

    for position, (file_name, listed_id) in enumerate(zip(config_files, known_ids)):
        # One unreadable session file must not prevent choosing another one.
        try:
            with open(os.path.join(config_dir, file_name), 'r') as f:
                info = json.load(f)
            stamp = info.get('timestamp', 'unknown') if isinstance(info, dict) else 'unknown'
        except (OSError, ValueError):
            stamp = 'unreadable'
        print(f"{position+1}. Session ID: {listed_id} ({stamp})")

    choice = input("\nEnter session number or session ID: ").strip()

    # Session IDs are short hex strings and can consist of digits only, so an
    # exact ID match must win over interpreting the input as a list position.
    if choice in known_ids:
        return choice

    if choice.isdigit():
        choice_idx = int(choice) - 1
        if 0 <= choice_idx < len(known_ids):
            return known_ids[choice_idx]
        print("Invalid choice!")
        return None

    # Anything else is taken as a literal session ID and validated by the caller.
    return choice


def _reverse_column(masked_values, lookup):
    """Map masked values back through `lookup` (masked string -> original).

    Missing values and values without a mapping are passed through unchanged.
    Returns a (recovered_values, successful_reversals) tuple.
    """
    recovered_values = []
    successful_reversals = 0
    for value in masked_values:
        if pd.isna(value):
            recovered_values.append(value)
            continue
        value_str = str(value)
        if value_str in lookup:
            recovered_values.append(lookup[value_str])
            successful_reversals += 1
        else:
            recovered_values.append(value)
    return recovered_values, successful_reversals


def reverse_masked_dataset_enhanced(session_id=None, masked_csv_path=None):
    """
    ENHANCED DATA REVERSAL with password protection and session management

    Restores the original values of a masked CSV using the encrypted mappings
    recorded for a masking session, then writes the recovered dataset and a
    JSON recovery report to '../results/reversed_datasets'.

    Args:
        session_id: ID of the masking session to reverse. When None, the saved
            sessions under '../config' are listed and the user is prompted to
            pick one by number or by ID.
        masked_csv_path: Path of the masked CSV. When None, the
            'masked_dataset_path' stored in the session info file is used.

    Returns:
        (recovered_df, reversal_stats) on success, where reversal_stats maps
        every column name to either per-column counts/success rate or a
        'status' marker ('no_mappings' / 'not_pii').
        (None, None) when the session, dataset or mappings cannot be loaded,
        or when the password is rejected three times.
    """

    print(f"\n" + "="*60)
    print("ENHANCED DATA REVERSAL SYSTEM")
    print("="*60)

    # Get session ID if not provided
    if session_id is None:
        session_id = _select_session_interactively()
        if session_id is None:
            return None, None

    print(f"\nUsing Session ID: {session_id}")

    # Load session info
    session_info_path = f'../config/session_info_{session_id}.json'
    try:
        with open(session_info_path, 'r') as f:
            session_info = json.load(f)
        print(f"✅ Session info loaded")
    except FileNotFoundError:
        print(f"❌ Session info not found: {session_info_path}")
        return None, None
    except json.JSONDecodeError:
        print(f"❌ Session info is not valid JSON: {session_info_path}")
        return None, None

    # Get masked dataset path
    if masked_csv_path is None:
        masked_csv_path = session_info['masked_dataset_path']

    print(f"📁 Masked dataset: {masked_csv_path}")

    # Load masked data
    try:
        # Mapping keys are compared against str(value), so PII columns are read
        # as text: type inference would turn "01234" into 1234 (or 12345 into
        # 12345.0 when the column has gaps) and the lookup would silently miss.
        header = pd.read_csv(masked_csv_path, nrows=0)
        text_columns = {
            col: str for col in header.columns
            if col in _REVERSAL_COLUMN_CATEGORIES
        }
        masked_df = pd.read_csv(masked_csv_path, dtype=text_columns or None)
        print(f"✅ Loaded masked dataset: {masked_df.shape}")
    except FileNotFoundError:
        print(f"❌ Masked dataset not found: {masked_csv_path}")
        return None, None

    # Load encrypted mappings
    encrypted_mappings_path = session_info['encrypted_mappings_path']
    temp_encryption = PIIEncryptionSystem()

    mappings_loaded = temp_encryption.load_encrypted_mappings(encrypted_mappings_path)
    if not mappings_loaded:
        print(f"❌ Encrypted mappings not found: {encrypted_mappings_path}")
        return None, None

    # Get password and decrypt
    print(f"\n🔒 DECRYPTION REQUIRED")
    print("-" * 30)

    max_attempts = 3
    decrypted_mappings = None
    for attempt in range(max_attempts):
        password = getpass.getpass(f"Enter encryption password (attempt {attempt+1}/{max_attempts}): ")

        # A None result is treated as a rejected password.
        decrypted_mappings = temp_encryption.decrypt_mappings_with_password(password)

        if decrypted_mappings is not None:
            print("✅ Password correct! Mappings decrypted successfully")
            break
        print(f"❌ Incorrect password! {max_attempts - attempt - 1} attempts remaining")
    else:
        # Loop finished without a successful decryption.
        print("❌ Maximum attempts reached. Reversal failed.")
        return None, None

    # Create reverse mappings
    reverse_mappings = temp_encryption.reverse_mappings(decrypted_mappings)

    print(f"\n📊 Available mapping categories:")
    for category, mappings in reverse_mappings.items():
        if mappings:
            print(f"  {category}: {len(mappings)} mappings")

    column_categories = _REVERSAL_COLUMN_CATEGORIES

    # Apply enhanced reversal
    print(f"\n🔄 REVERSING DATA")
    print("-" * 30)

    recovered_df = masked_df.copy()
    reversal_stats = {}
    row_count = len(masked_df)

    for column in masked_df.columns:
        if column not in column_categories:
            reversal_stats[column] = {'status': 'not_pii'}
            continue

        category = column_categories[column]

        if category in reverse_mappings and reverse_mappings[category]:
            print(f"  Reversing {column} ({category})...")

            recovered_values, successful_reversals = _reverse_column(
                masked_df[column], reverse_mappings[category]
            )

            recovered_df[column] = recovered_values
            # Guard against an empty dataset (header-only CSV).
            success_rate = (successful_reversals / row_count) * 100 if row_count else 0.0
            reversal_stats[column] = {
                'successful': successful_reversals,
                'total': row_count,
                'success_rate': success_rate
            }

            print(f"    ✅ Recovered {successful_reversals}/{row_count} values ({success_rate:.1f}%)")
        else:
            print(f"    ⚠️ No mappings found for {column}")
            reversal_stats[column] = {'status': 'no_mappings'}

    # Save recovered data to reversed_datasets folder
    output_dir = '../results/reversed_datasets'
    # The function may be run in a fresh checkout where the folder was never created.
    os.makedirs(output_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    recovered_filename = f"recovered_dataset_{session_id}_{timestamp}.csv"
    recovered_filepath = os.path.join(output_dir, recovered_filename)

    recovered_df.to_csv(recovered_filepath, index=False)

    # Create recovery report
    recovery_report = {
        'session_id': session_id,
        'recovery_timestamp': timestamp,
        'original_masked_file': masked_csv_path,
        'recovered_file': recovered_filepath,
        'reversal_stats': reversal_stats,
        'total_columns': len(masked_df.columns),
        'pii_columns_processed': len([col for col in reversal_stats if 'success_rate' in reversal_stats[col]]),
        'perfect_reversals': len([col for col in reversal_stats if reversal_stats[col].get('success_rate') == 100])
    }

    recovery_report_path = os.path.join(output_dir, f'recovery_report_{session_id}_{timestamp}.json')
    with open(recovery_report_path, 'w') as f:
        json.dump(recovery_report, f, indent=2)

    total_pii_columns = recovery_report['pii_columns_processed']
    perfect_columns = recovery_report['perfect_reversals']

    print(f"\n" + "="*60)
    print("ENHANCED REVERSAL COMPLETE!")
    print("="*60)
    print(f"✅ Perfect reversals: {perfect_columns}/{total_pii_columns} PII columns")
    print(f"📁 Recovered dataset: {recovered_filepath}")
    print(f"📋 Recovery report: {recovery_report_path}")
    print(f"🔒 Session ID: {session_id}")

    return recovered_df, reversal_stats

# Confirm the reversal helper is defined and show the two ways to call it.
print(
    "Enhanced Reversal Function created successfully!",
    "\n💡 Usage:",
    "   reverse_masked_dataset_enhanced()  # Interactive mode",
    "   reverse_masked_dataset_enhanced('your_session_id')  # Direct mode",
    sep="\n",
)

In [None]:
# Test the reversal function
# NOTE: this cell is interactive. It prompts for a session choice and for the
# encryption password, so it blocks an unattended "Run All".
print("Testing Reversal Function...")
print("This will ask for your encryption password to reverse the masked data")

# Run the enhanced reversal function
recovered_df, reversal_stats = reverse_masked_dataset_enhanced()

if recovered_df is not None:
    print("\n" + "="*60)
    print("REVERSAL TEST RESULTS")
    print("="*60)

    # Display some comparison between original and recovered.
    # `df` only exists when the data-loading cells were run in this kernel.
    # CAUTION: the samples below print raw (unmasked) values into the cell
    # output; clear the output before sharing the notebook.
    if 'df' in locals():
        print("Comparing original vs recovered data (first 3 rows):")
        print("\nOriginal data sample:")
        for col in df.columns[:3]:  # Show first 3 columns
            print(f"{col}: {df[col].head(3).tolist()}")

        print("\nRecovered data sample:")
        for col in recovered_df.columns[:3]:  # Show first 3 columns
            print(f"{col}: {recovered_df[col].head(3).tolist()}")

    print("\n✅ Reversal test completed successfully!")
else:
    print("❌ Reversal test failed!")

# The former trailing message "Reversal test cell ready (uncomment to run)" was
# removed: nothing in this cell is commented out, and by that point the test
# has already run, so the message contradicted the output above it.