In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report, roc_curve, auc
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import hashlib
import json
import warnings
warnings.filterwarnings('ignore')

# Fix the global RNG seeds up front so weight initialisation and any numpy
# randomness are repeatable from run to run.
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Run on the first CUDA device when PyTorch can see one, otherwise on CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Global plotting look shared by every figure produced below.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

class FlowerDataset(Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs from image files.

    Args:
        image_paths: Sequence of paths to image files on disk.
        labels: Sequence of labels aligned index-for-index with ``image_paths``.
        transform: Optional callable applied to the loaded RGB PIL image
            (for example a ``torchvision.transforms`` pipeline).
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # Open the file in a context manager so its handle is released right
        # away. convert() returns an independent, fully loaded image that
        # remains valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

def load_flower_dataset(base_path, classes=None):
    """Collect image file paths and integer labels from a folder-per-class layout.

    Expects ``base_path/<class_name>/<image files>``. Label ``i`` is assigned to
    ``classes[i]``. A missing class folder is reported and skipped, but it keeps
    its index so the remaining labels do not shift. Files are listed in sorted
    order, so the result is the same on every filesystem.

    Args:
        base_path: Root directory containing one sub-directory per class.
        classes: Optional ordered sequence of class folder names. Defaults to
            ``['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']``.

    Returns:
        Tuple ``(all_image_paths, all_labels, classes)``: parallel lists of file
        paths and integer labels, plus the class-name list used for indexing.
    """
    if classes is None:
        classes = ['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']
    else:
        classes = list(classes)

    image_extensions = ('.png', '.jpg', '.jpeg')
    all_image_paths = []
    all_labels = []
    class_counts = {}

    for class_idx, class_name in enumerate(classes):
        class_path = os.path.join(base_path, class_name)

        if not os.path.exists(class_path):
            print(f"Warning: {class_path} does not exist!")
            continue

        # os.listdir() returns entries in an arbitrary, filesystem-dependent
        # order. Sorting keeps the path list, and therefore any seeded split
        # made from it, reproducible across machines. The isfile() check skips
        # directories whose names happen to end in an image extension.
        images = sorted(
            f for f in os.listdir(class_path)
            if f.lower().endswith(image_extensions)
            and os.path.isfile(os.path.join(class_path, f))
        )

        class_counts[class_name] = len(images)

        for img_name in images:
            all_image_paths.append(os.path.join(class_path, img_name))
            all_labels.append(class_idx)

    print("\nDataset Distribution:")
    print("-" * 40)
    for class_name, count in class_counts.items():
        print(f"  {class_name:12s}: {count:4d} images")
    print("-" * 40)
    print(f"  Total:        {len(all_image_paths):4d} images")
    print("-" * 40)

    return all_image_paths, all_labels, classes

class ResidualBlock(nn.Module):
    """Basic residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, plus skip, then ReLU.

    The skip path is an empty ``nn.Sequential`` (identity) when the stride is 1
    and the channel count is unchanged. Otherwise it is a strided 1x1
    convolution followed by BatchNorm, so the shapes match for the addition.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first 3x3 convolution and of the projection.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Submodules are created in this exact order. That order fixes both
        # parameter registration (named_parameters() order) and the torch RNG
        # draws used for initialisation.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        main = main + self.shortcut(x)
        return self.relu(main)

class FlowerCNN(nn.Module):
    """ResNet-style classifier for 3-channel images.

    Layout: 7x7 stride-2 convolutional stem with max-pooling, then four
    residual stages (64->128, 128->256, 256->512, 512->1024 channels), global
    average pooling, and a three-layer MLP head with dropout. ``forward``
    returns unnormalised logits of shape ``(batch, num_classes)``.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=5):
        super().__init__()

        # Stem: conv 7x7/2 -> BN -> ReLU -> max-pool 3x3/2.
        self.initial_conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # Each stage doubles the channel count. All but the first also halve
        # the spatial resolution through the stride of their first block.
        self.layer1 = self._make_stage(64, 128, stride=1)
        self.layer2 = self._make_stage(128, 256, stride=2)
        self.layer3 = self._make_stage(256, 512, stride=2)
        self.layer4 = self._make_stage(512, 1024, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        )

    @staticmethod
    def _make_stage(in_channels, out_channels, stride):
        """Build two residual blocks; only the first may change channels or stride."""
        return nn.Sequential(
            ResidualBlock(in_channels, out_channels, stride=stride),
            ResidualBlock(out_channels, out_channels, stride=1),
        )

    def forward(self, x):
        x = self.initial_conv(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        pooled = self.avgpool(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

class FixedWeightWatermark:
    """Watermark a model by pinning a fingerprint-derived subset of weights to fixed values.

    A SHA-256 hash of ``fingerprint`` seeds a NumPy RNG. For every eligible
    weight tensor, that RNG picks a random set of flat indices and a target
    value for each index, drawn as ``randn() * 0.1``. The pinned entries can be
    restored after each optimizer step with :meth:`apply_frozen_weights`, and
    their gradients can be zeroed with :meth:`freeze_gradient_hook`.
    :meth:`verify_frozen_weights` reports how many pinned entries still match.

    Args:
        model: ``nn.Module`` to watermark.
        fingerprint: Secret string from which the pattern is derived.
        freeze_ratio: Fraction of each eligible tensor's entries to pin. At
            least one entry per tensor is always pinned.
    """

    def __init__(self, model, fingerprint, freeze_ratio=0.002):
        self.model = model
        self.fingerprint = fingerprint
        self.freeze_ratio = freeze_ratio
        # parameter name -> np.ndarray of flat indices into that parameter
        self.frozen_params = {}
        # parameter name -> np.ndarray of target values (aligned with the indices)
        self.frozen_values = {}
        # parameter name -> handle returned by Tensor.register_hook. Used so
        # each parameter gets at most one gradient hook.
        self._hook_handles = {}

    def generate_frozen_pattern(self):
        """Derive per-parameter frozen indices and values from the fingerprint.

        The first 32 bits of SHA-256(fingerprint) seed the RNG. The same
        fingerprint applied to the same architecture therefore always yields
        the same pattern.
        """
        hash_object = hashlib.sha256(self.fingerprint.encode())
        hex_dig = hash_object.hexdigest()
        seed = int(hex_dig[:8], 16)
        rng = np.random.RandomState(seed)

        # Eligible tensors are parameters whose name contains 'weight' and
        # that sit under layer*/classifier/initial_conv.
        # NOTE(review): this also matches BatchNorm scale tensors such as
        # 'layer1.0.bn1.weight', pinning some of them to small random values.
        # Confirm this is intended.
        all_params = [
            (name, param) for name, param in self.model.named_parameters()
            if 'weight' in name and ('layer' in name or 'classifier' in name or 'initial_conv' in name)
        ]

        total_frozen = 0
        for name, param in all_params:
            num_weights = param.numel()
            num_freeze = max(1, int(num_weights * self.freeze_ratio))

            # The draw order (indices then values, tensor by tensor) defines
            # the watermark. Changing it changes every generated pattern.
            indices = rng.choice(num_weights, size=num_freeze, replace=False)
            values = rng.randn(num_freeze) * 0.1

            self.frozen_params[name] = indices
            self.frozen_values[name] = values
            total_frozen += num_freeze

        print(f"\nWatermark Configuration:")
        print(f"  Fingerprint: {self.fingerprint}")
        print(f"  Freeze Ratio: {self.freeze_ratio:.4f}")
        print(f"  Total Frozen Weights: {total_frozen}")
        print(f"  Frozen Layers: {len(self.frozen_params)}")

    def _frozen_tensors(self, name, like):
        """Return ``(indices, values)`` for ``name`` as tensors on ``like``'s device.

        The values are converted to ``like``'s dtype. Index assignment
        requires matching dtypes, so this keeps non-float32 parameters working.
        """
        indices = torch.as_tensor(self.frozen_params[name], dtype=torch.long, device=like.device)
        values = torch.as_tensor(self.frozen_values[name], dtype=like.dtype, device=like.device)
        return indices, values

    def apply_frozen_weights(self):
        """Overwrite every pinned entry with its target value, in place."""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name not in self.frozen_params:
                    continue
                indices, values = self._frozen_tensors(name, param)
                # view(-1) shares storage with the parameter, so the indexed
                # assignment writes straight into the weights.
                param.data.view(-1)[indices] = values

    def freeze_gradient_hook(self):
        """Register hooks that zero the gradient of every pinned entry.

        Safe to call more than once: each parameter is hooked at most once.
        """
        def hook_fn(name):
            def hook(grad):
                if name not in self.frozen_params:
                    return grad
                indices = torch.as_tensor(self.frozen_params[name], dtype=torch.long, device=grad.device)
                # Tensor hooks must not modify their argument in place, so
                # return a masked copy that autograd uses instead.
                masked = grad.clone()
                masked.view(-1)[indices] = 0.0
                return masked
            return hook

        for name, param in self.model.named_parameters():
            if name in self.frozen_params and name not in self._hook_handles:
                self._hook_handles[name] = param.register_hook(hook_fn(name))

    def verify_frozen_weights(self, tolerance=0.01):
        """Measure how many pinned entries still hold their target values.

        Args:
            tolerance: An entry counts as a match when
                ``|actual - expected| < tolerance``.

        Returns:
            Tuple ``(match_ratio, total_frozen, layer_matches)``:
            ``match_ratio`` is the fraction of matching entries overall (0.0
            when nothing is pinned), ``total_frozen`` is the number of pinned
            entries checked, and ``layer_matches`` maps each parameter name to
            its own match fraction.
        """
        match_score = 0.0
        total_frozen = 0
        layer_matches = {}

        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name not in self.frozen_params:
                    continue
                indices, expected_values = self._frozen_tensors(name, param)
                actual_values = param.data.view(-1)[indices]

                differences = torch.abs(actual_values - expected_values)
                matches = (differences < tolerance).sum().item()
                count = indices.numel()

                layer_matches[name] = matches / count
                match_score += matches
                total_frozen += count

        match_ratio = match_score / total_frozen if total_frozen > 0 else 0.0

        return match_ratio, total_frozen, layer_matches

    def get_watermark_statistics(self):
        """Summarise the generated pattern: per-parameter counts and value mean/std."""
        stats = {
            'total_layers': len(self.frozen_params),
            'total_frozen_weights': sum(len(indices) for indices in self.frozen_params.values()),
            'layer_details': {}
        }

        for name, indices in self.frozen_params.items():
            stats['layer_details'][name] = {
                'num_frozen': len(indices),
                'frozen_values_mean': float(np.mean(self.frozen_values[name])),
                'frozen_values_std': float(np.std(self.frozen_values[name]))
            }

        return stats

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
                num_epochs, watermark_handler=None, model_name="model"):
    """Train ``model`` and validate it once per epoch.

    After every optimizer step, the optional ``watermark_handler`` re-applies
    its frozen weights. Whenever validation accuracy improves, the state dict
    is saved to ``f'{model_name}_best.pth'``.

    Args:
        model: Network to train. Batches are moved to the module-level ``device``.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Optional LR scheduler, stepped once per epoch. A
            ``ReduceLROnPlateau`` is stepped with the validation loss; any
            other scheduler is stepped without arguments.
        num_epochs: Number of epochs to run.
        watermark_handler: Optional object exposing ``apply_frozen_weights()``.
        model_name: Tag used in the progress bar and the checkpoint filename.

    Returns:
        dict of per-epoch lists: 'train_loss', 'val_loss', 'train_acc',
        'val_acc' (accuracies in percent), and 'learning_rates' (the LR in
        effect during that epoch, read before the scheduler step).
    """
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': [],
        'learning_rates': []
    }

    best_val_acc = 0.0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [{model_name}]')
        for batches_seen, (inputs, labels) in enumerate(pbar, start=1):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The optimizer update may have moved the pinned weights; restore them.
            if watermark_handler is not None:
                watermark_handler.apply_frozen_weights()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Average over the batches seen so far, not the epoch's total batch
            # count, so the displayed loss is a true running mean.
            pbar.set_postfix({'loss': running_loss / batches_seen, 'acc': 100. * correct / total})

        train_loss = running_loss / len(train_loader)
        train_acc = 100. * correct / total

        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        val_loss = val_loss / len(val_loader)
        val_acc = 100. * correct / total

        current_lr = optimizer.param_groups[0]['lr']

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['learning_rates'].append(current_lr)

        if scheduler is not None:
            # Only ReduceLROnPlateau takes a metric. Other schedulers would
            # read a positional argument as the deprecated `epoch` parameter
            # and jump to the wrong point in their schedule.
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), f'{model_name}_best.pth')

        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%, LR: {current_lr:.6f}')

    return history

def evaluate_model(model, test_loader, class_names):
    """Run ``model`` over ``test_loader`` and compute classification metrics.

    Metrics are always computed over the full label set ``0..len(class_names)-1``.
    A class that is absent from the test set still gets a row or column in
    every per-class array and in the confusion matrix (reported as 0 where
    undefined), so the outputs stay aligned with ``class_names``.

    Args:
        model: Classifier returning logits. Batches are moved to the
            module-level ``device``.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        class_names: Class names indexed by integer label.

    Returns:
        dict with overall 'accuracy' and weighted 'precision'/'recall'/
        'f1_score'; 'per_class_precision'/'per_class_recall'/'per_class_f1'/
        'per_class_support' arrays; 'confusion_matrix'; the raw 'predictions',
        'labels' and softmax 'probabilities' arrays; and the text
        'classification_report'.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc='Evaluating'):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            probs = torch.softmax(outputs, dim=1)
            _, predicted = outputs.max(1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    # Without an explicit label list, sklearn infers labels from the data. A
    # class missing from the test set would then shorten the per-class arrays,
    # and classification_report would reject the target_names list.
    label_ids = list(range(len(class_names)))

    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, labels=label_ids, average='weighted', zero_division=0
    )

    per_class_precision, per_class_recall, per_class_f1, per_class_support = precision_recall_fscore_support(
        all_labels, all_preds, labels=label_ids, average=None, zero_division=0
    )

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    report = classification_report(all_labels, all_preds, labels=label_ids,
                                   target_names=class_names, digits=4, zero_division=0)

    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'per_class_precision': per_class_precision,
        'per_class_recall': per_class_recall,
        'per_class_f1': per_class_f1,
        'per_class_support': per_class_support,
        'confusion_matrix': cm,
        'predictions': all_preds,
        'labels': all_labels,
        'probabilities': all_probs,
        'classification_report': report
    }

    return metrics

def plot_training_history(history, model_name):
    """Save a 1x3 figure with loss, accuracy and learning-rate curves per epoch.

    Args:
        history: Dict of equal-length per-epoch lists under 'train_loss',
            'val_loss', 'train_acc', 'val_acc' and 'learning_rates'.
        model_name: Used in the panel titles and in the output file name
            ``<model_name>_training_history.png``.
    """
    fig, (ax_loss, ax_acc, ax_lr) = plt.subplots(1, 3, figsize=(18, 5))
    epoch_axis = range(1, len(history['train_loss']) + 1)

    line_style = dict(linewidth=2.5, markersize=6)
    label_font = dict(fontsize=13, fontweight='bold')
    title_font = dict(fontsize=15, fontweight='bold', pad=15)
    grid_style = dict(alpha=0.3, linestyle='--')

    # (axis, [(history key, legend label, marker, colour), ...], y-label, title)
    paired_panels = [
        (ax_loss,
         [('train_loss', 'Train Loss', 'o', '#E74C3C'),
          ('val_loss', 'Validation Loss', 's', '#3498DB')],
         'Loss', f'Loss Curves - {model_name}'),
        (ax_acc,
         [('train_acc', 'Train Accuracy', 'o', '#2ECC71'),
          ('val_acc', 'Validation Accuracy', 's', '#F39C12')],
         'Accuracy (%)', f'Accuracy Curves - {model_name}'),
    ]
    for ax, series, y_label, title in paired_panels:
        for key, legend_label, marker, color in series:
            ax.plot(epoch_axis, history[key], label=legend_label, marker=marker,
                    color=color, **line_style)
        ax.set_xlabel('Epoch', **label_font)
        ax.set_ylabel(y_label, **label_font)
        ax.set_title(title, **title_font)
        ax.legend(fontsize=11, frameon=True, shadow=True)
        ax.grid(True, **grid_style)

    # The learning rate spans orders of magnitude, hence the log y-axis.
    ax_lr.plot(epoch_axis, history['learning_rates'], marker='o', color='#9B59B6', **line_style)
    ax_lr.set_xlabel('Epoch', **label_font)
    ax_lr.set_ylabel('Learning Rate', **label_font)
    ax_lr.set_title(f'Learning Rate Schedule - {model_name}', **title_font)
    ax_lr.grid(True, **grid_style)
    ax_lr.set_yscale('log')

    plt.tight_layout()
    plt.savefig(f'{model_name}_training_history.png', dpi=300, bbox_inches='tight', facecolor='white')
    plt.close()

def plot_confusion_matrix(cm, class_names, model_name):
    """Save a raw-count and a row-normalised confusion-matrix heatmap.

    Writes ``<model_name>_confusion_matrix.png`` (integer counts) and
    ``<model_name>_confusion_matrix_normalized.png``. In the second file each
    row is divided by its total; a row whose class has no true samples is
    shown as all zeros instead of NaN.

    Args:
        cm: Square integer confusion matrix (rows = true, columns = predicted).
        class_names: Tick labels for both axes.
        model_name: Used in the titles and output file names.
    """
    cm = np.asarray(cm)
    row_totals = cm.sum(axis=1, keepdims=True).astype('float')
    # Dividing by a zero row total would produce NaN cells; leave those rows at 0.
    cm_normalized = np.divide(cm.astype('float'), row_totals,
                              out=np.zeros(cm.shape, dtype='float'),
                              where=row_totals != 0)

    # (matrix, annotation format, colormap, colorbar label, title, output file)
    heatmaps = [
        (cm, 'd', 'YlOrRd', 'Count',
         f'Confusion Matrix - {model_name}',
         f'{model_name}_confusion_matrix.png'),
        (cm_normalized, '.2%', 'Blues', 'Percentage',
         f'Normalized Confusion Matrix - {model_name}',
         f'{model_name}_confusion_matrix_normalized.png'),
    ]

    for data, fmt, cmap, cbar_label, title, filename in heatmaps:
        plt.figure(figsize=(10, 8))
        sns.heatmap(data, annot=True, fmt=fmt, cmap=cmap,
                    xticklabels=class_names, yticklabels=class_names,
                    cbar_kws={'label': cbar_label}, linewidths=0.5, linecolor='gray')

        plt.xlabel('Predicted Label', fontsize=13, fontweight='bold')
        plt.ylabel('True Label', fontsize=13, fontweight='bold')
        plt.title(title, fontsize=15, fontweight='bold', pad=15)
        plt.xticks(rotation=45, ha='right', fontsize=11)
        plt.yticks(rotation=0, fontsize=11)
        plt.tight_layout()
        plt.savefig(filename, dpi=300, bbox_inches='tight', facecolor='white')
        plt.close()

def plot_per_class_metrics(metrics, class_names, model_name):
    """Save a 2x2 bar-chart grid of per-class precision, recall, F1 and support.

    Args:
        metrics: Dict with 'per_class_precision', 'per_class_recall',
            'per_class_f1' and 'per_class_support' sequences aligned with
            ``class_names``.
        class_names: X-axis tick labels.
        model_name: Used in the figure title and in
            ``<model_name>_per_class_metrics.png``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    positions = np.arange(len(class_names))
    bar_width = 0.6

    def three_decimals(value):
        return f'{value:.3f}'

    def as_count(value):
        return f'{int(value)}'

    # (axis, metrics key, colour, y-label, title, value formatter, clamp y to [0, 1.1]?)
    panels = (
        (axes[0, 0], 'per_class_precision', '#3498DB', 'Precision', 'Precision per Class', three_decimals, True),
        (axes[0, 1], 'per_class_recall', '#2ECC71', 'Recall', 'Recall per Class', three_decimals, True),
        (axes[1, 0], 'per_class_f1', '#F39C12', 'F1-Score', 'F1-Score per Class', three_decimals, True),
        (axes[1, 1], 'per_class_support', '#E74C3C', 'Support (# samples)', 'Support per Class', as_count, False),
    )

    for ax, key, color, y_label, title, formatter, unit_range in panels:
        drawn = ax.bar(positions, metrics[key], bar_width,
                       color=color, edgecolor='black', linewidth=1.2, alpha=0.85)
        ax.set_ylabel(y_label, fontsize=12, fontweight='bold')
        ax.set_title(title, fontsize=14, fontweight='bold', pad=12)
        ax.set_xticks(positions)
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        if unit_range:
            ax.set_ylim([0, 1.1])
        ax.grid(True, alpha=0.3, axis='y', linestyle='--')
        # Print each bar's value just above it.
        for rect in drawn:
            top = rect.get_height()
            ax.text(rect.get_x() + rect.get_width() / 2., top, formatter(top),
                    ha='center', va='bottom', fontsize=10, fontweight='bold')

    plt.suptitle(f'Per-Class Performance Metrics - {model_name}',
                 fontsize=16, fontweight='bold', y=1.00)
    plt.tight_layout()
    plt.savefig(f'{model_name}_per_class_metrics.png', dpi=300, bbox_inches='tight', facecolor='white')
    plt.close()

def plot_overall_metrics(metrics, model_name):
    """Save a bar chart and a pie chart of accuracy, precision, recall and F1.

    NOTE(review): the pie rescales the four scores so they sum to 100%. That
    share has no statistical meaning; consider whether this panel is wanted.

    Args:
        metrics: Dict with 'accuracy', 'precision', 'recall' and 'f1_score'.
        model_name: Used in the titles and in ``<model_name>_overall_metrics.png``.
    """
    entries = [
        ('Accuracy', metrics['accuracy'], '#2ECC71'),
        ('Precision', metrics['precision'], '#3498DB'),
        ('Recall', metrics['recall'], '#E74C3C'),
        ('F1-Score', metrics['f1_score'], '#F39C12'),
    ]
    names, scores, palette = (list(column) for column in zip(*entries))

    fig, (bar_ax, pie_ax) = plt.subplots(1, 2, figsize=(16, 6))

    rects = bar_ax.bar(names, scores, color=palette,
                       edgecolor='black', linewidth=1.5, alpha=0.85)
    for rect in rects:
        top = rect.get_height()
        bar_ax.text(rect.get_x() + rect.get_width() / 2., top, f'{top:.4f}',
                    ha='center', va='bottom', fontsize=12, fontweight='bold')

    bar_ax.set_ylabel('Score', fontsize=13, fontweight='bold')
    bar_ax.set_title(f'Overall Performance Metrics - {model_name}', fontsize=15, fontweight='bold', pad=15)
    bar_ax.set_ylim([0, 1.1])
    bar_ax.grid(True, alpha=0.3, axis='y', linestyle='--')
    bar_ax.tick_params(axis='x', labelsize=11)

    _, _, percent_labels = pie_ax.pie(scores, explode=(0.05,) * len(scores), labels=names,
                                      colors=palette, autopct='%1.2f%%',
                                      shadow=True, startangle=90,
                                      textprops={'fontsize': 11, 'fontweight': 'bold'})
    # White, bold percentage text inside each wedge.
    plt.setp(percent_labels, color='white', fontsize=11, fontweight='bold')

    pie_ax.set_title(f'Metrics Distribution - {model_name}', fontsize=15, fontweight='bold', pad=15)

    plt.tight_layout()
    plt.savefig(f'{model_name}_overall_metrics.png', dpi=300, bbox_inches='tight', facecolor='white')
    plt.close()

def plot_roc_curves(metrics, class_names, model_name):
    """Save one-vs-rest ROC curves per class plus the micro-average.

    Args:
        metrics: Dict with 'labels' (integer true labels) and 'probabilities'
            (array of shape ``(n_samples, len(class_names))``).
        class_names: Names used in the legend, indexed by label.
        model_name: Used in the title and in ``<model_name>_roc_curves.png``.
    """
    n_classes = len(class_names)
    y_true_bin = label_binarize(metrics['labels'], classes=list(range(n_classes)))
    if n_classes == 2:
        # label_binarize returns a single column for binary problems. Expand
        # it to one indicator column per class so the per-class loop and the
        # micro-average line up with the two probability columns.
        y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])
    y_score = metrics['probabilities']

    fpr = dict()
    tpr = dict()
    roc_auc = dict()

    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    fpr["micro"], tpr["micro"], _ = roc_curve(y_true_bin.ravel(), y_score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    plt.figure(figsize=(12, 9))

    colors = ['#E74C3C', '#3498DB', '#2ECC71', '#F39C12', '#9B59B6']

    # Cycle the palette so every class is plotted, including any beyond the
    # fifth (zipping against the palette would silently drop them).
    for i in range(n_classes):
        plt.plot(fpr[i], tpr[i], color=colors[i % len(colors)], lw=2.5,
                 label=f'{class_names[i]} (AUC = {roc_auc[i]:.3f})')

    plt.plot(fpr["micro"], tpr["micro"],
             label=f'Micro-average (AUC = {roc_auc["micro"]:.3f})',
             color='navy', linestyle='--', linewidth=3)

    plt.plot([0, 1], [0, 1], 'k--', lw=2, label='Random Classifier')

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=13, fontweight='bold')
    plt.ylabel('True Positive Rate', fontsize=13, fontweight='bold')
    plt.title(f'ROC Curves - {model_name}', fontsize=15, fontweight='bold', pad=15)
    plt.legend(loc="lower right", fontsize=11, frameon=True, shadow=True)
    plt.grid(True, alpha=0.3, linestyle='--')
    plt.tight_layout()
    plt.savefig(f'{model_name}_roc_curves.png', dpi=300, bbox_inches='tight', facecolor='white')
    plt.close()

def plot_model_comparison(baseline_metrics, watermarked_metrics, class_names):
    """Save a 2x2 figure comparing baseline and watermarked evaluation metrics.

    Panels:
    * Top left: overall accuracy, precision, recall and F1 side by side.
    * Top right: baseline minus watermarked for each overall score, times 100,
      with a reference line at 1.
    * Bottom left: per-class F1 side by side.
    * Bottom right: per-class F1 difference, times 100.

    Args:
        baseline_metrics: Dict with 'accuracy', 'precision', 'recall',
            'f1_score' and 'per_class_f1' (aligned with ``class_names``).
        watermarked_metrics: Same structure as ``baseline_metrics``.
        class_names: Class labels for the per-class panels.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    metrics_overall = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
    metric_keys = ['accuracy', 'precision', 'recall', 'f1_score']
    baseline_vals = [baseline_metrics[k] for k in metric_keys]
    watermarked_vals = [watermarked_metrics[k] for k in metric_keys]

    x = np.arange(len(metrics_overall))
    width = 0.35

    bars1 = axes[0, 0].bar(x - width/2, baseline_vals, width, label='Baseline',
                           color='#95A5A6', edgecolor='black', linewidth=1.2, alpha=0.85)
    bars2 = axes[0, 0].bar(x + width/2, watermarked_vals, width, label='Watermarked',
                           color='#E74C3C', edgecolor='black', linewidth=1.2, alpha=0.85)

    axes[0, 0].set_ylabel('Score', fontsize=12, fontweight='bold')
    axes[0, 0].set_title('Overall Performance Comparison', fontsize=14, fontweight='bold', pad=12)
    axes[0, 0].set_xticks(x)
    axes[0, 0].set_xticklabels(metrics_overall, fontsize=11)
    axes[0, 0].legend(fontsize=11, frameon=True, shadow=True)
    axes[0, 0].set_ylim([0, 1.1])
    axes[0, 0].grid(True, alpha=0.3, axis='y', linestyle='--')

    for bars in [bars1, bars2]:
        for bar in bars:
            height = bar.get_height()
            axes[0, 0].text(bar.get_x() + bar.get_width()/2., height,
                           f'{height:.3f}', ha='center', va='bottom', fontsize=9, fontweight='bold')

    # Positive values mean the watermarked model scores lower than the baseline.
    performance_drops = [(b - w) * 100 for b, w in zip(baseline_vals, watermarked_vals)]

    bars = axes[0, 1].bar(metrics_overall, performance_drops,
                          color=['#E74C3C' if d > 0 else '#2ECC71' for d in performance_drops],
                          edgecolor='black', linewidth=1.2, alpha=0.85)

    axes[0, 1].axhline(y=0, color='black', linestyle='-', linewidth=1)
    axes[0, 1].axhline(y=1, color='green', linestyle='--', linewidth=2, alpha=0.7, label='Target (1%)')
    axes[0, 1].set_ylabel('Performance Drop (%)', fontsize=12, fontweight='bold')
    axes[0, 1].set_title('Watermarking Impact', fontsize=14, fontweight='bold', pad=12)
    axes[0, 1].legend(fontsize=11)
    axes[0, 1].grid(True, alpha=0.3, axis='y', linestyle='--')

    for bar in bars:
        height = bar.get_height()
        axes[0, 1].text(bar.get_x() + bar.get_width()/2., height,
                       f'{height:.2f}%', ha='center',
                       va='bottom' if height >= 0 else 'top', fontsize=10, fontweight='bold')

    x_classes = np.arange(len(class_names))
    width = 0.35

    bars1 = axes[1, 0].bar(x_classes - width/2, baseline_metrics['per_class_f1'], width,
                           label='Baseline', color='#95A5A6', edgecolor='black', linewidth=1.2, alpha=0.85)
    bars2 = axes[1, 0].bar(x_classes + width/2, watermarked_metrics['per_class_f1'], width,
                           label='Watermarked', color='#E74C3C', edgecolor='black', linewidth=1.2, alpha=0.85)

    axes[1, 0].set_ylabel('F1-Score', fontsize=12, fontweight='bold')
    axes[1, 0].set_title('Per-Class F1-Score Comparison', fontsize=14, fontweight='bold', pad=12)
    axes[1, 0].set_xticks(x_classes)
    axes[1, 0].set_xticklabels(class_names, rotation=45, ha='right', fontsize=11)
    axes[1, 0].legend(fontsize=11, frameon=True, shadow=True)
    axes[1, 0].set_ylim([0, 1.1])
    axes[1, 0].grid(True, alpha=0.3, axis='y', linestyle='--')

    class_drops = [(baseline_metrics['per_class_f1'][i] - watermarked_metrics['per_class_f1'][i]) * 100
                   for i in range(len(class_names))]

    # Plot against numeric positions and fix the ticks before labelling them.
    # matplotlib only supports set_xticklabels() after set_xticks(); otherwise
    # labels may land in unexpected positions. This also matches the
    # per-class F1 panel above.
    bars = axes[1, 1].bar(x_classes, class_drops,
                          color=['#E74C3C' if d > 0 else '#2ECC71' for d in class_drops],
                          edgecolor='black', linewidth=1.2, alpha=0.85)

    axes[1, 1].axhline(y=0, color='black', linestyle='-', linewidth=1)
    axes[1, 1].set_ylabel('F1-Score Drop (%)', fontsize=12, fontweight='bold')
    axes[1, 1].set_title('Per-Class Performance Impact', fontsize=14, fontweight='bold', pad=12)
    axes[1, 1].set_xticks(x_classes)
    axes[1, 1].set_xticklabels(class_names, rotation=45, ha='right', fontsize=11)
    axes[1, 1].grid(True, alpha=0.3, axis='y', linestyle='--')

    for bar in bars:
        height = bar.get_height()
        axes[1, 1].text(bar.get_x() + bar.get_width()/2., height,
                       f'{height:.1f}%', ha='center',
                       va='bottom' if height >= 0 else 'top', fontsize=9, fontweight='bold')

    plt.suptitle('Baseline vs Watermarked Model Comparison',
                 fontsize=16, fontweight='bold', y=1.00)
    plt.tight_layout()
    plt.savefig('model_comparison_comprehensive.png', dpi=300, bbox_inches='tight', facecolor='white')
    plt.close()

def plot_watermark_verification(baseline_verification, watermarked_verification):
    """Save a two-panel summary of watermark verification for both models.

    Args:
        baseline_verification: Verification result for the non-watermarked
            model. If it is a dict, its 'match_ratio' and 'detected' entries
            are plotted. Otherwise, or when those keys are missing, the
            baseline is drawn with ratio 0 and reported as not detected.
        watermarked_verification: Dict with at least 'match_ratio' (float)
            and 'detected' (truthy when the watermark was found).
    """
    # Show the baseline's real measurement rather than a hard-coded zero:
    # random weights can match some pinned values, and that false-positive
    # level is exactly what this plot should reveal.
    baseline_info = baseline_verification if isinstance(baseline_verification, dict) else {}
    baseline_ratio = baseline_info.get('match_ratio', 0)
    baseline_detected = bool(baseline_info.get('detected', False))

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    models = ['Baseline', 'Watermarked']
    match_ratios = [baseline_ratio, watermarked_verification['match_ratio']]

    bars = axes[0].bar(models, match_ratios,
                       color=['#95A5A6', '#E74C3C'],
                       edgecolor='black', linewidth=1.5, alpha=0.85, width=0.5)

    axes[0].axhline(y=0.95, color='green', linestyle='--', linewidth=2.5,
                    label='Detection Threshold (0.95)', alpha=0.8)

    axes[0].set_ylabel('Weight Match Ratio', fontsize=13, fontweight='bold')
    axes[0].set_title('Fixed Weight Watermark Verification', fontsize=15, fontweight='bold', pad=15)
    axes[0].set_ylim([0, 1.1])
    axes[0].legend(fontsize=12, frameon=True, shadow=True)
    axes[0].grid(True, alpha=0.3, axis='y', linestyle='--')

    for bar in bars:
        height = bar.get_height()
        if height > 0:
            axes[0].text(bar.get_x() + bar.get_width()/2., height,
                        f'{height:.4f}',
                        ha='center', va='bottom', fontsize=12, fontweight='bold')

    # Decorative legend-style pie: two equal wedges name the two outcomes. The
    # actual per-model status is given in the text box below it.
    detected_status = ['Not Detected', 'Detected']
    colors_status = ['#95A5A6', '#2ECC71']
    counts = [1, 1]

    wedges, texts, autotexts = axes[1].pie(counts, labels=detected_status,
                                            colors=colors_status, autopct='',
                                            shadow=True, startangle=90,
                                            explode=(0.05, 0.05),
                                            textprops={'fontsize': 13, 'fontweight': 'bold'})

    axes[1].set_title('Watermark Detection Status', fontsize=15, fontweight='bold', pad=15)

    baseline_status = 'Detected (false positive)' if baseline_detected else 'Not Detected'
    watermarked_status = 'Detected ✓' if watermarked_verification['detected'] else 'Not Detected ✗'
    detection_text = f"Baseline: {baseline_status}\nWatermarked: {watermarked_status}"
    axes[1].text(0, -1.4, detection_text, ha='center', fontsize=12,
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.savefig('watermark_verification.png', dpi=300, bbox_inches='tight', facecolor='white')
    plt.close()

def plot_watermark_statistics(watermark_stats):
    """Plot how the frozen watermark weights are spread across layers.

    Saves a two-panel figure to ``watermark_statistics.png``: a horizontal bar
    chart of frozen-weight counts per layer (left) and a text box with the
    totals (right).

    Args:
        watermark_stats: dict providing ``'layer_details'`` (mapping layer
            name -> dict with a ``'num_frozen'`` entry),
            ``'total_frozen_weights'`` and ``'total_layers'``.
    """
    layer_details = watermark_stats['layer_details']
    layer_names = list(layer_details.keys())
    num_frozen = [layer_details[name]['num_frozen'] for name in layer_names]

    # Keep only the last two dotted components (e.g. "conv1.weight") for
    # readability. These short names repeat across blocks, so bars are drawn at
    # integer positions and the names are used purely as tick labels; passing
    # duplicate strings as categorical values would stack bars on one row.
    short_names = ['.'.join(name.split('.')[-2:]) for name in layer_names]
    positions = list(range(len(layer_names)))

    # Grow the figure with the number of layers so tick labels stay legible.
    fig_height = max(6, 0.3 * len(layer_names))
    fig, axes = plt.subplots(1, 2, figsize=(16, fig_height))

    bars = axes[0].barh(positions, num_frozen, color='#3498DB',
                        edgecolor='black', linewidth=1.2, alpha=0.85)
    axes[0].set_yticks(positions)
    axes[0].set_yticklabels(short_names)

    axes[0].set_xlabel('Number of Frozen Weights', fontsize=13, fontweight='bold')
    axes[0].set_ylabel('Layer', fontsize=13, fontweight='bold')
    axes[0].set_title('Frozen Weights Distribution Across Layers', fontsize=15, fontweight='bold', pad=15)
    axes[0].grid(True, alpha=0.3, axis='x', linestyle='--')

    # Annotate each bar with its count just past the bar's end.
    for bar in bars:
        width = bar.get_width()
        axes[0].text(width, bar.get_y() + bar.get_height() / 2.,
                     f' {int(width)}', ha='left', va='center', fontsize=10, fontweight='bold')

    total_frozen = watermark_stats['total_frozen_weights']
    total_layers = watermark_stats['total_layers']
    # Avoid ZeroDivisionError for an empty watermark configuration.
    avg_per_layer = total_frozen / total_layers if total_layers else 0.0

    info_text = (f"Total Frozen Weights: {total_frozen}\n"
                 f"Total Layers: {total_layers}\n"
                 f"Average per Layer: {avg_per_layer:.1f}")

    axes[1].text(0.5, 0.6, info_text, ha='center', va='center',
                 fontsize=14, fontweight='bold',
                 bbox=dict(boxstyle='round', facecolor='#ECF0F1', alpha=0.8, pad=1.5),
                 transform=axes[1].transAxes)

    axes[1].text(0.5, 0.3, "Watermark Configuration", ha='center', va='center',
                 fontsize=16, fontweight='bold', style='italic',
                 transform=axes[1].transAxes)

    axes[1].axis('off')

    plt.tight_layout()
    plt.savefig('watermark_statistics.png', dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)

def _performance_summary(metrics):
    """Return the overall metrics in *metrics* as a JSON-ready dict of floats."""
    return {
        'Accuracy': float(metrics['accuracy']),
        'Precision': float(metrics['precision']),
        'Recall': float(metrics['recall']),
        'F1_Score': float(metrics['f1_score'])
    }


def _per_class_summary(metrics, class_names):
    """Return per-class precision/recall/F1/support from *metrics*, keyed by class name.

    The per-class sequences in *metrics* are indexed in the same order as
    *class_names*.
    """
    return {
        name: {
            'Precision': float(metrics['per_class_precision'][i]),
            'Recall': float(metrics['per_class_recall'][i]),
            'F1_Score': float(metrics['per_class_f1'][i]),
            'Support': int(metrics['per_class_support'][i])
        }
        for i, name in enumerate(class_names)
    }


def _print_overall_performance(metrics):
    """Print the four overall metrics of *metrics* as an indented bullet list."""
    print("  Overall Performance:")
    print(f"    • Accuracy:  {metrics['accuracy']:.4f}")
    print(f"    • Precision: {metrics['precision']:.4f}")
    print(f"    • Recall:    {metrics['recall']:.4f}")
    print(f"    • F1-Score:  {metrics['f1_score']:.4f}")


def save_comprehensive_results(baseline_metrics, watermarked_metrics,
                               baseline_verification, watermarked_verification,
                               watermark_stats, class_names, acc_drop_tolerance=1.0):
    """Write the experiment summary to ``fixed_watermark_results.json`` and print it.

    Args:
        baseline_metrics: evaluation dict of the unwatermarked model with overall
            ``accuracy``/``precision``/``recall``/``f1_score`` and per-class
            ``per_class_precision``/``per_class_recall``/``per_class_f1``/
            ``per_class_support`` sequences indexed like *class_names*.
        watermarked_metrics: same structure, for the watermarked model.
        baseline_verification: dict with ``match_ratio`` and ``detected``.
        watermarked_verification: dict with ``match_ratio``, ``total_frozen``
            and ``detected``.
        watermark_stats: watermark configuration dict. It is embedded verbatim
            in the JSON (so it must be JSON-serializable) and must provide
            ``total_frozen_weights`` and ``total_layers``.
        class_names: class labels, in metric index order.
        acc_drop_tolerance: largest acceptable accuracy drop, in percentage
            points, before the report flags degradation. Defaults to 1.0.

    Every "drop" is baseline minus watermarked, so a negative drop means the
    watermarked model scored higher.
    """
    results = {
        'Dataset': {
            'Name': 'Flowers Recognition',
            'Classes': class_names,
            'Num_Classes': len(class_names)
        },
        'Baseline_Model': {
            'Performance': _performance_summary(baseline_metrics),
            'Per_Class_Performance': _per_class_summary(baseline_metrics, class_names),
            'Verification': {
                'Match_Ratio': float(baseline_verification['match_ratio']),
                'Detected': bool(baseline_verification['detected'])
            }
        },
        'Watermarked_Model': {
            'Performance': _performance_summary(watermarked_metrics),
            'Per_Class_Performance': _per_class_summary(watermarked_metrics, class_names),
            'Verification': {
                'Match_Ratio': float(watermarked_verification['match_ratio']),
                'Total_Frozen_Weights': int(watermarked_verification['total_frozen']),
                'Detected': bool(watermarked_verification['detected'])
            },
            'Watermark_Configuration': watermark_stats
        },
        'Performance_Impact': {
            'Overall': {
                'Accuracy_Drop': float(baseline_metrics['accuracy'] - watermarked_metrics['accuracy']),
                'Precision_Drop': float(baseline_metrics['precision'] - watermarked_metrics['precision']),
                'Recall_Drop': float(baseline_metrics['recall'] - watermarked_metrics['recall']),
                'F1_Score_Drop': float(baseline_metrics['f1_score'] - watermarked_metrics['f1_score'])
            },
            'Per_Class_F1_Drop': {
                name: float(baseline_metrics['per_class_f1'][i] - watermarked_metrics['per_class_f1'][i])
                for i, name in enumerate(class_names)
            }
        }
    }

    with open('fixed_watermark_results.json', 'w') as f:
        json.dump(results, f, indent=4)

    print("\n" + "="*80)
    print("COMPREHENSIVE EXPERIMENTAL RESULTS")
    print("="*80)

    print("\n" + "-"*80)
    print("BASELINE MODEL (No Watermark)")
    print("-"*80)
    _print_overall_performance(baseline_metrics)
    print(f"  Watermark Status: {'Detected (Unexpected!)' if baseline_verification['detected'] else 'Not Detected ✓'}")

    print("\n" + "-"*80)
    print("WATERMARKED MODEL (Fixed Weight Watermark)")
    print("-"*80)
    _print_overall_performance(watermarked_metrics)
    print("\n  Watermark Configuration:")
    print(f"    • Total Frozen Weights: {watermark_stats['total_frozen_weights']}")
    print(f"    • Frozen Layers: {watermark_stats['total_layers']}")
    print(f"    • Match Ratio: {watermarked_verification['match_ratio']:.4f}")
    print(f"    • Detection Status: {'Detected ✓' if watermarked_verification['detected'] else 'Not Detected ✗'}")

    print("\n" + "-"*80)
    print("PERFORMANCE IMPACT ANALYSIS")
    print("-"*80)
    # Drops are reported in percentage points (difference of rates x 100).
    acc_drop = (baseline_metrics['accuracy'] - watermarked_metrics['accuracy']) * 100
    print(f"  • Accuracy Drop:  {acc_drop:+.2f}%")
    print(f"  • Precision Drop: {(baseline_metrics['precision'] - watermarked_metrics['precision']) * 100:+.2f}%")
    print(f"  • Recall Drop:    {(baseline_metrics['recall'] - watermarked_metrics['recall']) * 100:+.2f}%")
    print(f"  • F1-Score Drop:  {(baseline_metrics['f1_score'] - watermarked_metrics['f1_score']) * 100:+.2f}%")

    # Only a *decrease* in accuracy counts against the tolerance. A negative
    # drop (the watermarked model is more accurate) is never degradation;
    # comparing abs(acc_drop) would wrongly flag large accuracy gains.
    within_tolerance = acc_drop <= acc_drop_tolerance
    tolerance_label = f"{acc_drop_tolerance:g}%"
    if within_tolerance:
        print(f"\n  ✓ Performance impact is within acceptable range (≤{tolerance_label})")
    else:
        print(f"\n  ⚠ Performance impact exceeds target threshold (>{tolerance_label})")

    print("\n" + "-"*80)
    print("PER-CLASS PERFORMANCE (F1-Score)")
    print("-"*80)
    for i, class_name in enumerate(class_names):
        baseline_f1 = baseline_metrics['per_class_f1'][i]
        watermarked_f1 = watermarked_metrics['per_class_f1'][i]
        drop = (baseline_f1 - watermarked_f1) * 100
        print(f"  {class_name:12s}: Baseline={baseline_f1:.4f}, Watermarked={watermarked_f1:.4f}, Drop={drop:+.2f}%")

    print("\n" + "="*80)
    print("CONCLUSION")
    print("="*80)

    if watermarked_verification['detected'] and not baseline_verification['detected']:
        print("  ✓ Watermark successfully embedded and detected")
        print(f"  ✓ Match ratio: {watermarked_verification['match_ratio']:.4f} (above threshold)")
    else:
        print("  ⚠ Watermark detection issue")

    if within_tolerance:
        print("  ✓ Minimal performance degradation achieved")
    else:
        print("  ⚠ Performance degradation exceeds target")

    print("="*80 + "\n")

def _make_training_components(model, learning_rate):
    """Create the loss, Adam optimizer and plateau LR scheduler shared by both experiments."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    # `verbose=` is deprecated on PyTorch LR schedulers (since 2.2); the epoch
    # log printed during training already reports the current LR.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                     patience=3)
    return criterion, optimizer, scheduler


def _plot_model_reports(history, metrics, class_names, label):
    """Produce the per-model training and evaluation plots for one experiment."""
    plot_training_history(history, label)
    plot_confusion_matrix(metrics['confusion_matrix'], class_names, label)
    plot_per_class_metrics(metrics, class_names, label)
    plot_overall_metrics(metrics, label)
    plot_roc_curves(metrics, class_names, label)


def main():
    """Run the baseline vs. fixed-weight-watermark experiment end to end.

    Steps:

    1. Load the Flowers dataset and split it 70/15/15 (train/val/test), stratified.
    2. Train and evaluate an unwatermarked baseline, and confirm the watermark
       pattern is NOT detected in it.
    3. Train a second model with the watermark weights frozen, and verify the
       watermark IS detected.
    4. Write comparison plots and a JSON summary.

    Checkpoints, plots and the JSON file all use relative paths (the current
    working directory).

    Raises:
        FileNotFoundError: if no images are found under ``BASE_PATH``.
    """
    print("\n" + "="*80)
    print("FIXED WEIGHT WATERMARKING FOR MODEL ATTRIBUTION")
    print("Flower Recognition Dataset - Enhanced ResNet Architecture")
    print("="*80 + "\n")

    BASE_PATH = '/kaggle/input/flowers-recognition/flowers'
    FINGERPRINT = "FlowerCNN_FixedWatermark_2025_SecureOwnership"

    IMG_SIZE = 224
    BATCH_SIZE = 32
    NUM_EPOCHS = 10
    LEARNING_RATE = 0.001
    FREEZE_RATIO = 0.002
    # A model counts as watermarked when its match ratio strictly exceeds this.
    DETECTION_THRESHOLD = 0.95

    print("Loading Flower Dataset...")
    all_paths, all_labels, class_names = load_flower_dataset(BASE_PATH)
    # Fail with a clear message instead of an opaque error from train_test_split.
    if len(all_paths) == 0:
        raise FileNotFoundError(
            f"No images found under {BASE_PATH!r}; check that the dataset is available."
        )

    # 70% train, then split the remaining 30% evenly into validation and test.
    train_paths, temp_paths, train_labels, temp_labels = train_test_split(
        all_paths, all_labels, test_size=0.3, stratify=all_labels, random_state=RANDOM_SEED
    )
    val_paths, test_paths, val_labels, test_labels = train_test_split(
        temp_paths, temp_labels, test_size=0.5, stratify=temp_labels, random_state=RANDOM_SEED
    )

    print("\nData Split:")
    print(f"  Train:      {len(train_paths)} images")
    print(f"  Validation: {len(val_paths)} images")
    print(f"  Test:       {len(test_paths)} images")

    # ImageNet normalization statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    transform_train = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = FlowerDataset(train_paths, train_labels, transform=transform_train)
    val_dataset = FlowerDataset(val_paths, val_labels, transform=transform_test)
    test_dataset = FlowerDataset(test_paths, test_labels, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # ---------------------------------------------------------------- Exp. 1
    print("\n" + "="*80)
    print("EXPERIMENT 1: BASELINE MODEL (No Watermark)")
    print("="*80)

    baseline_model = FlowerCNN(num_classes=len(class_names)).to(device)
    criterion, optimizer, scheduler = _make_training_components(baseline_model, LEARNING_RATE)

    baseline_history = train_model(
        baseline_model, train_loader, val_loader, criterion, optimizer, scheduler,
        NUM_EPOCHS, watermark_handler=None, model_name="baseline"
    )

    # Restore the best checkpoint; map_location keeps loading working even if
    # the checkpoint was written on a different device.
    baseline_model.load_state_dict(torch.load('baseline_best.pth', map_location=device))

    print("\nEvaluating Baseline Model...")
    baseline_metrics = evaluate_model(baseline_model, test_loader, class_names)
    print("\n" + baseline_metrics['classification_report'])

    print("\nGenerating Baseline Model Visualizations...")
    _plot_model_reports(baseline_history, baseline_metrics, class_names, "Baseline")

    # Negative control: regenerate the expected pattern for the same fingerprint
    # and check how well the (never-watermarked) baseline weights match it.
    print("\nVerifying Baseline Model (Should Not Detect Watermark)...")
    dummy_watermark_baseline = FixedWeightWatermark(baseline_model, FINGERPRINT, FREEZE_RATIO)
    dummy_watermark_baseline.generate_frozen_pattern()
    match_ratio, total_frozen, _ = dummy_watermark_baseline.verify_frozen_weights()

    baseline_verification = {
        'match_ratio': match_ratio,
        'total_frozen': total_frozen,
        'detected': match_ratio > DETECTION_THRESHOLD
    }

    print(f"  Match Ratio: {match_ratio:.4f}")
    print(f"  Detected: {'YES (Unexpected!)' if baseline_verification['detected'] else 'NO ✓'}")

    # ---------------------------------------------------------------- Exp. 2
    print("\n" + "="*80)
    print("EXPERIMENT 2: WATERMARKED MODEL (Fixed Weight Watermark)")
    print("="*80)
    print(f"  Fingerprint: {FINGERPRINT}")
    print(f"  Freeze Ratio: {FREEZE_RATIO:.4f}")

    watermarked_model = FlowerCNN(num_classes=len(class_names)).to(device)
    watermark_handler = FixedWeightWatermark(watermarked_model, FINGERPRINT, FREEZE_RATIO)
    watermark_handler.generate_frozen_pattern()
    watermark_handler.apply_frozen_weights()
    # Installed before the optimizer is built, mirroring the original order.
    watermark_handler.freeze_gradient_hook()

    watermark_stats = watermark_handler.get_watermark_statistics()

    criterion, optimizer, scheduler = _make_training_components(watermarked_model, LEARNING_RATE)

    watermarked_history = train_model(
        watermarked_model, train_loader, val_loader, criterion, optimizer, scheduler,
        NUM_EPOCHS, watermark_handler=watermark_handler, model_name="watermarked"
    )

    # load_state_dict updates the parameters in place, so watermark_handler
    # (which was built around this same model object) sees the restored weights.
    watermarked_model.load_state_dict(torch.load('watermarked_best.pth', map_location=device))

    print("\nEvaluating Watermarked Model...")
    watermarked_metrics = evaluate_model(watermarked_model, test_loader, class_names)
    print("\n" + watermarked_metrics['classification_report'])

    print("\nGenerating Watermarked Model Visualizations...")
    _plot_model_reports(watermarked_history, watermarked_metrics, class_names, "Watermarked")

    print("\nVerifying Watermarked Model...")
    match_ratio, total_frozen, layer_matches = watermark_handler.verify_frozen_weights()

    watermarked_verification = {
        'match_ratio': match_ratio,
        'total_frozen': total_frozen,
        'layer_matches': layer_matches,
        'detected': match_ratio > DETECTION_THRESHOLD
    }

    print(f"  Match Ratio: {match_ratio:.4f}")
    print(f"  Total Frozen Weights: {total_frozen}")
    print(f"  Detected: {'YES ✓' if watermarked_verification['detected'] else 'NO ✗'}")

    # ------------------------------------------------------------ Comparison
    print("\n" + "="*80)
    print("GENERATING COMPARISON VISUALIZATIONS")
    print("="*80)

    plot_model_comparison(baseline_metrics, watermarked_metrics, class_names)
    plot_watermark_verification(baseline_verification, watermarked_verification)
    plot_watermark_statistics(watermark_stats)

    save_comprehensive_results(baseline_metrics, watermarked_metrics,
                               baseline_verification, watermarked_verification,
                               watermark_stats, class_names)

    # Build the artifact listing from data so the printed totals always agree
    # with what is listed (the hand-written total previously said 17 PNGs while
    # listing 15).
    model_files = ['baseline_best.pth', 'watermarked_best.pth']
    per_model_plots = ['training_history', 'confusion_matrix', 'confusion_matrix_normalized',
                       'per_class_metrics', 'overall_metrics', 'roc_curves']
    comparison_plots = ['model_comparison_comprehensive.png', 'watermark_verification.png',
                        'watermark_statistics.png']
    result_files = ['fixed_watermark_results.json']

    print("\n" + "="*80)
    print("EXPERIMENT COMPLETED SUCCESSFULLY")
    print("="*80)
    print("\nGenerated Files:")
    print("\n  Trained Models:")
    for filename in model_files:
        print(f"    • {filename}")
    for prefix, heading in (("baseline", "Baseline"), ("watermarked", "Watermarked")):
        print(f"\n  {heading} Model Plots:")
        for suffix in per_model_plots:
            print(f"    • {prefix}_{suffix}.png")
    print("\n  Comparison & Analysis Plots:")
    for filename in comparison_plots:
        print(f"    • {filename}")
    print("\n  Results Data:")
    for filename in result_files:
        print(f"    • {filename}")
    num_png = 2 * len(per_model_plots) + len(comparison_plots)
    print(f"\n  Total: {num_png} PNG files + {len(result_files)} JSON file + "
          f"{len(model_files)} model files")
    print("="*80 + "\n")

# Run the full experiment only when executed directly (script or notebook
# cell), never as a side effect of importing this module.
if __name__ == '__main__':
    main()

Using device: cuda

FIXED WEIGHT WATERMARKING FOR MODEL ATTRIBUTION
Flower Recognition Dataset - Enhanced ResNet Architecture

Loading Flower Dataset...

Dataset Distribution:
----------------------------------------
  daisy       :  764 images
  dandelion   : 1052 images
  rose        :  784 images
  sunflower   :  733 images
  tulip       :  984 images
----------------------------------------
  Total:        4317 images
----------------------------------------

Data Split:
  Train:      3021 images
  Validation: 648 images
  Test:       648 images

EXPERIMENT 1: BASELINE MODEL (No Watermark)


Epoch 1/10 [baseline]: 100%|██████████| 95/95 [00:27<00:00,  3.39it/s, loss=1.46, acc=35]   


Epoch 1: Train Loss: 1.4577, Train Acc: 34.96%, Val Loss: 1.4902, Val Acc: 42.75%, LR: 0.001000


Epoch 2/10 [baseline]: 100%|██████████| 95/95 [00:22<00:00,  4.14it/s, loss=1.39, acc=40.1] 


Epoch 2: Train Loss: 1.3875, Train Acc: 40.09%, Val Loss: 1.3553, Val Acc: 41.36%, LR: 0.001000


Epoch 3/10 [baseline]: 100%|██████████| 95/95 [00:23<00:00,  4.07it/s, loss=1.34, acc=42.7] 


Epoch 3: Train Loss: 1.3370, Train Acc: 42.67%, Val Loss: 1.3165, Val Acc: 40.43%, LR: 0.001000


Epoch 4/10 [baseline]: 100%|██████████| 95/95 [00:24<00:00,  3.96it/s, loss=1.31, acc=44.5] 


Epoch 4: Train Loss: 1.3117, Train Acc: 44.49%, Val Loss: 1.1838, Val Acc: 45.22%, LR: 0.001000


Epoch 5/10 [baseline]: 100%|██████████| 95/95 [00:24<00:00,  3.88it/s, loss=1.26, acc=45.2] 


Epoch 5: Train Loss: 1.2611, Train Acc: 45.15%, Val Loss: 1.1700, Val Acc: 52.16%, LR: 0.001000


Epoch 6/10 [baseline]: 100%|██████████| 95/95 [00:25<00:00,  3.78it/s, loss=1.23, acc=49.5] 


Epoch 6: Train Loss: 1.2341, Train Acc: 49.49%, Val Loss: 1.1738, Val Acc: 53.70%, LR: 0.001000


Epoch 7/10 [baseline]: 100%|██████████| 95/95 [00:25<00:00,  3.75it/s, loss=1.22, acc=48.6] 


Epoch 7: Train Loss: 1.2246, Train Acc: 48.59%, Val Loss: 1.1557, Val Acc: 53.40%, LR: 0.001000


Epoch 8/10 [baseline]: 100%|██████████| 95/95 [00:26<00:00,  3.64it/s, loss=1.18, acc=52.1] 


Epoch 8: Train Loss: 1.1798, Train Acc: 52.07%, Val Loss: 1.0650, Val Acc: 55.86%, LR: 0.001000


Epoch 9/10 [baseline]: 100%|██████████| 95/95 [00:26<00:00,  3.57it/s, loss=1.15, acc=54.7] 


Epoch 9: Train Loss: 1.1515, Train Acc: 54.65%, Val Loss: 1.0291, Val Acc: 58.95%, LR: 0.001000


Epoch 10/10 [baseline]: 100%|██████████| 95/95 [00:26<00:00,  3.53it/s, loss=1.11, acc=55.8]  


Epoch 10: Train Loss: 1.1055, Train Acc: 55.84%, Val Loss: 1.0950, Val Acc: 54.78%, LR: 0.001000

Evaluating Baseline Model...


Evaluating: 100%|██████████| 21/21 [00:03<00:00,  5.46it/s]



              precision    recall  f1-score   support

       daisy     0.4811    0.7739    0.5933       115
   dandelion     0.6707    0.6962    0.6832       158
        rose     0.4286    0.5128    0.4669       117
   sunflower     0.6917    0.7545    0.7217       110
       tulip     0.7436    0.1959    0.3102       148

    accuracy                         0.5725       648
   macro avg     0.6031    0.5867    0.5551       648
weighted avg     0.6135    0.5725    0.5496       648


Generating Baseline Model Visualizations...

Verifying Baseline Model (Should Not Detect Watermark)...

Watermark Configuration:
  Fingerprint: FlowerCNN_FixedWatermark_2025_SecureOwnership
  Freeze Ratio: 0.0020
  Total Frozen Weights: 90474
  Frozen Layers: 45
  Match Ratio: 0.0795
  Detected: NO ✓

EXPERIMENT 2: WATERMARKED MODEL (Fixed Weight Watermark)
  Fingerprint: FlowerCNN_FixedWatermark_2025_SecureOwnership
  Freeze Ratio: 0.0020

Watermark Configuration:
  Fingerprint: FlowerCNN_FixedWatermark

Epoch 1/10 [watermarked]: 100%|██████████| 95/95 [00:27<00:00,  3.51it/s, loss=1.49, acc=33.8] 


Epoch 1: Train Loss: 1.4928, Train Acc: 33.80%, Val Loss: 1.2358, Val Acc: 45.99%, LR: 0.001000


Epoch 2/10 [watermarked]: 100%|██████████| 95/95 [00:27<00:00,  3.43it/s, loss=1.37, acc=40.1] 


Epoch 2: Train Loss: 1.3667, Train Acc: 40.09%, Val Loss: 1.2615, Val Acc: 44.14%, LR: 0.001000


Epoch 3/10 [watermarked]: 100%|██████████| 95/95 [00:27<00:00,  3.49it/s, loss=1.33, acc=41.9] 


Epoch 3: Train Loss: 1.3335, Train Acc: 41.87%, Val Loss: 1.2700, Val Acc: 42.90%, LR: 0.001000


Epoch 4/10 [watermarked]: 100%|██████████| 95/95 [00:27<00:00,  3.47it/s, loss=1.29, acc=42.8] 


Epoch 4: Train Loss: 1.2862, Train Acc: 42.83%, Val Loss: 1.1347, Val Acc: 48.46%, LR: 0.001000


Epoch 5/10 [watermarked]: 100%|██████████| 95/95 [00:27<00:00,  3.45it/s, loss=1.29, acc=43.1] 


Epoch 5: Train Loss: 1.2945, Train Acc: 43.13%, Val Loss: 1.1642, Val Acc: 47.84%, LR: 0.001000


Epoch 6/10 [watermarked]: 100%|██████████| 95/95 [00:27<00:00,  3.46it/s, loss=1.26, acc=45]   


Epoch 6: Train Loss: 1.2572, Train Acc: 44.95%, Val Loss: 1.1389, Val Acc: 48.77%, LR: 0.001000


Epoch 7/10 [watermarked]: 100%|██████████| 95/95 [00:27<00:00,  3.46it/s, loss=1.24, acc=45.3] 


Epoch 7: Train Loss: 1.2383, Train Acc: 45.35%, Val Loss: 1.1933, Val Acc: 47.22%, LR: 0.001000


Epoch 8/10 [watermarked]: 100%|██████████| 95/95 [00:27<00:00,  3.45it/s, loss=1.22, acc=48]   


Epoch 8: Train Loss: 1.2229, Train Acc: 48.03%, Val Loss: 1.2506, Val Acc: 46.76%, LR: 0.001000


Epoch 9/10 [watermarked]: 100%|██████████| 95/95 [00:27<00:00,  3.46it/s, loss=1.16, acc=52.1] 


Epoch 9: Train Loss: 1.1630, Train Acc: 52.14%, Val Loss: 1.0389, Val Acc: 54.94%, LR: 0.000500


Epoch 10/10 [watermarked]: 100%|██████████| 95/95 [00:27<00:00,  3.48it/s, loss=1.16, acc=51.4] 


Epoch 10: Train Loss: 1.1565, Train Acc: 51.37%, Val Loss: 1.0799, Val Acc: 55.40%, LR: 0.000500

Evaluating Watermarked Model...


Evaluating: 100%|██████████| 21/21 [00:02<00:00, 10.04it/s]



              precision    recall  f1-score   support

       daisy     0.5600    0.3652    0.4421       115
   dandelion     0.5888    0.7342    0.6535       158
        rose     0.4677    0.2479    0.3240       117
   sunflower     0.5163    0.8636    0.6463       110
       tulip     0.6231    0.5473    0.5827       148

    accuracy                         0.5602       648
   macro avg     0.5512    0.5516    0.5297       648
weighted avg     0.5574    0.5602    0.5391       648


Generating Watermarked Model Visualizations...

Verifying Watermarked Model...
  Match Ratio: 1.0000
  Total Frozen Weights: 90474
  Detected: YES ✓

GENERATING COMPARISON VISUALIZATIONS

COMPREHENSIVE EXPERIMENTAL RESULTS

--------------------------------------------------------------------------------
BASELINE MODEL (No Watermark)
--------------------------------------------------------------------------------
  Overall Performance:
    • Accuracy:  0.5725
    • Precision: 0.6135
    • Recall:    0.572