In [2]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import hashlib
import json
from scipy.stats import pearsonr
from scipy.spatial.distance import cosine
import warnings
warnings.filterwarnings('ignore')

# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Single seed for the NumPy and PyTorch global RNGs; the same constant is also
# passed as random_state to the train/val/test splits in main().
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

class ButterflyDataset(Dataset):
    """Image-classification dataset backed by a DataFrame of file names and labels.

    Each row of ``df`` provides a ``'filename'`` (relative to ``img_dir``) and a
    ``'label'``. Items are ``(image, label_idx)``: the RGB PIL image after
    ``transform`` (when one is given) and the integer index of the label in
    ``label_to_idx``.

    Args:
        df: DataFrame with ``'filename'`` and ``'label'`` columns; a copy with a
            fresh 0..n-1 index is stored.
        img_dir: Directory the file names are resolved against.
        transform: Optional callable applied to each PIL image.
        labels: Optional class vocabulary, in index order. Pass the same list to
            every split (train/val/test) so a label maps to the same index in all
            of them. When omitted, the sorted unique labels of ``df`` are used,
            which only agrees across splits if every split contains every class.

    Raises:
        ValueError: if ``labels`` is given and ``df`` contains a label outside it.
    """

    def __init__(self, df, img_dir, transform=None, labels=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        if labels is None:
            # Per-split vocabulary derived from this DataFrame only.
            self.labels = sorted(df['label'].unique())
        else:
            self.labels = list(labels)
            unknown = set(df['label'].unique()) - set(self.labels)
            if unknown:
                raise ValueError(
                    f"labels not in the provided vocabulary: {sorted(unknown)}")
        self.label_to_idx = {label: idx for idx, label in enumerate(self.labels)}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_name = self.df.loc[idx, 'filename']
        label = self.df.loc[idx, 'label']
        img_path = os.path.join(self.img_dir, img_name)

        # convert() returns a new, fully loaded image, so the file handle can be
        # closed right away instead of leaking until garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self.label_to_idx[label]

class SimpleCNN(nn.Module):
    """Four conv-ReLU-maxpool stages followed by a two-layer dropout MLP head.

    The head consumes a 256x14x14 feature map, which is exactly what a 224x224
    input produces after four 2x max-pools. An adaptive average pool to 14x14 sits
    between the two so other input sizes also work. For 224x224 inputs it is an
    identity. It has no parameters, so ``state_dict`` keys and the ``features.N``
    parameter names are unchanged.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super(SimpleCNN, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Parameter-free: fixes the flattened size at 256*14*14 for any input size.
        self.pool = nn.AdaptiveAvgPool2d((14, 14))

        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 14 * 14, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

class FixedWeightWatermark:
    """Fixed-weight watermark: pins a fingerprint-derived subset of feature weights.

    A NumPy RNG seeded from SHA-256(``fingerprint``) selects, for each parameter
    whose name contains both ``'weight'`` and ``'features'``, a ``freeze_ratio``
    fraction (at least one) of flat indices and a target value for each. Those
    entries are written into the model, their gradients are zeroed by hooks, and
    they can be re-pinned after every optimizer step. Verification counts how
    many pinned entries still sit within a tolerance of their target.

    Args:
        model: The ``nn.Module`` to watermark.
        fingerprint: Owner string the pattern is derived from.
        freeze_ratio: Fraction of each selected tensor's elements to pin.
    """

    def __init__(self, model, fingerprint, freeze_ratio=0.002):
        self.model = model
        self.fingerprint = fingerprint
        self.freeze_ratio = freeze_ratio
        # parameter name -> np.ndarray of flat (row-major) indices to pin
        self.frozen_params = {}
        # parameter name -> np.ndarray of target values, aligned with frozen_params
        self.frozen_values = {}
        # handles of registered gradient hooks, so re-registering replaces them
        self._hook_handles = []

    def generate_frozen_pattern(self):
        """Derive the pinned indices and target values from the fingerprint.

        For every selected parameter, draws ``max(1, int(numel * freeze_ratio))``
        distinct flat indices and as many ``0.1 * N(0, 1)`` values. The RNG is
        seeded with the first 32 bits of ``sha256(fingerprint)``, so the pattern
        is deterministic for a given fingerprint, ratio and parameter layout.
        """
        digest = hashlib.sha256(self.fingerprint.encode()).hexdigest()
        rng = np.random.RandomState(int(digest[:8], 16))

        selected = [(name, param) for name, param in self.model.named_parameters()
                    if 'weight' in name and 'features' in name]

        # Indices then values, per parameter in named_parameters() order: this
        # fixes the RNG draw sequence and hence the pattern.
        for name, param in selected:
            num_weights = param.numel()
            num_freeze = max(1, int(num_weights * self.freeze_ratio))
            self.frozen_params[name] = rng.choice(num_weights, size=num_freeze, replace=False)
            self.frozen_values[name] = rng.randn(num_freeze) * 0.1

    def apply_frozen_weights(self):
        """Overwrite every pinned entry with its target value (in place, no autograd)."""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name not in self.frozen_params:
                    continue
                indices = torch.as_tensor(self.frozen_params[name], device=param.device)
                # Match the parameter dtype; index_put rejects mismatched dtypes.
                values = torch.as_tensor(self.frozen_values[name],
                                         dtype=param.dtype, device=param.device)
                # view(-1) shares storage with the parameter, so the write lands in it.
                param.data.view(-1)[indices] = values

    def freeze_gradient_hook(self):
        """Register hooks that zero the gradient at every pinned index.

        Each hook returns a masked copy instead of editing ``grad`` in place, since
        ``Tensor.register_hook`` hooks must not modify their argument. Calling this
        again removes the hooks from the previous call rather than stacking them.
        """
        def hook_fn(name):
            def hook(grad):
                indices = self.frozen_params.get(name)
                if indices is None:
                    return grad
                masked = grad.clone(memory_format=torch.contiguous_format)
                masked.view(-1)[torch.as_tensor(indices, device=grad.device)] = 0.0
                return masked
            return hook

        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = [
            param.register_hook(hook_fn(name))
            for name, param in self.model.named_parameters()
            if name in self.frozen_params
        ]

    def verify_frozen_weights(self, tolerance=0.01):
        """Measure how many pinned weights still hold their target value.

        Args:
            tolerance: Largest absolute difference that still counts as a match.

        Returns:
            ``(match_ratio, total_frozen)``. ``match_ratio`` is the fraction of
            pinned entries within ``tolerance`` of their target (0.0 when nothing
            is pinned). ``total_frozen`` is the number of entries checked.
        """
        matched = 0
        total_frozen = 0

        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name not in self.frozen_params:
                    continue
                indices = self.frozen_params[name]
                expected = torch.as_tensor(self.frozen_values[name],
                                           dtype=param.dtype, device=param.device)
                actual = param.data.view(-1)[torch.as_tensor(indices, device=param.device)]
                matched += (torch.abs(actual - expected) < tolerance).sum().item()
                total_frozen += len(indices)

        match_ratio = matched / total_frozen if total_frozen > 0 else 0.0
        return match_ratio, total_frozen

class CausalWatermarkInjector:
    """Adds a fingerprint-keyed direction to the gradients of selected parameters.

    SHA-256(``fingerprint``) seeds a NumPy RNG that draws a 1000-dim Gaussian
    vector, normalised to unit L2 norm (``fingerprint_vector``). Each call to
    ``inject_perturbation`` does the following for every parameter whose name
    contains one of ``target_layers`` as a substring (weights and biases alike):

    * take the leading slice of the vector, as long as the parameter's
      flattened gradient (at most 1000);
    * rescale it so its L2 norm is ``lambda_factor`` times the gradient's L2 norm;
    * add it to the leading entries of the gradient.

    Args:
        fingerprint: Owner string the direction is derived from.
        target_layers: Substrings selecting which parameter names to perturb.
        lambda_factor: Perturbation norm relative to the gradient norm.
        device: Device to hold ``fingerprint_vector`` on. Defaults to the
            module-level ``device`` when omitted.
    """

    def __init__(self, fingerprint, target_layers, lambda_factor=0.001, device=None):
        self.fingerprint = fingerprint
        self.lambda_factor = lambda_factor
        self.target_layers = target_layers
        self._device = device
        self.fingerprint_vector = self._generate_fingerprint_vector()

    def _generate_fingerprint_vector(self):
        """Return the unit-norm 1000-dim float32 vector derived from the fingerprint."""
        hex_dig = hashlib.sha256(self.fingerprint.encode()).hexdigest()

        random_gen = np.random.RandomState(int(hex_dig[:8], 16))
        vector = random_gen.randn(1000)
        vector = vector / np.linalg.norm(vector)

        # Falls back to the module-level ``device`` when none was given.
        target_device = self._device if self._device is not None else device
        return torch.FloatTensor(vector).to(target_device)

    def inject_perturbation(self, model):
        """Add the scaled fingerprint direction to matching gradients, in place.

        Parameters without a gradient are skipped.
        """
        with torch.no_grad():
            for name, param in model.named_parameters():
                if not any(layer_name in name for layer_name in self.target_layers):
                    continue
                if param.grad is None:
                    continue

                grad_flat = param.grad.view(-1)
                perturbation_size = min(grad_flat.numel(), self.fingerprint_vector.numel())

                # Follow the gradient's device/dtype so a model placed elsewhere
                # than fingerprint_vector does not fail on the add.
                perturbation = self.fingerprint_vector[:perturbation_size].to(
                    device=grad_flat.device, dtype=grad_flat.dtype)
                perturbation = (perturbation * self.lambda_factor * torch.norm(grad_flat)
                                / torch.norm(perturbation))

                grad_flat[:perturbation_size] += perturbation

class HybridWatermark:
    """Pairs a fixed-weight watermark with causal gradient perturbation.

    Both parts are keyed by the same ``fingerprint``. The fixed part pins weights
    of ``model``. The causal part perturbs gradients of parameters whose names
    contain ``features.0``, ``features.3`` or ``features.6``.
    """

    def __init__(self, model, fingerprint, freeze_ratio=0.002, lambda_factor=0.001):
        self.fixed_watermark = FixedWeightWatermark(model, fingerprint, freeze_ratio)
        causal_layers = ['features.0', 'features.3', 'features.6']
        self.causal_injector = CausalWatermarkInjector(
            fingerprint, target_layers=causal_layers, lambda_factor=lambda_factor)

    def initialize(self):
        """Generate the pinned pattern, write it into the model, then hook gradients."""
        fixed = self.fixed_watermark
        fixed.generate_frozen_pattern()
        fixed.apply_frozen_weights()
        fixed.freeze_gradient_hook()

    def inject_gradient_perturbation(self, model):
        """Add the causal fingerprint direction to ``model``'s current gradients."""
        self.causal_injector.inject_perturbation(model)

    def maintain_frozen_weights(self):
        """Re-pin the fixed weights (train_model calls this after each optimizer step)."""
        self.fixed_watermark.apply_frozen_weights()

    def verify_fixed_weights(self):
        """Return the fixed-weight watermark's ``(match_ratio, total_frozen)`` pair."""
        return self.fixed_watermark.verify_frozen_weights()

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs,
                watermark_type=None, watermark_handler=None, model_name="model"):
    """Train ``model`` and checkpoint the epoch with the best validation accuracy.

    Each batch runs forward, loss and backward. Then, if ``watermark_handler`` is
    given, ``watermark_type`` selects what happens around ``optimizer.step()``:

    * ``'causal'``: ``watermark_handler.inject_perturbation(model)`` before the step.
    * ``'fixed'``: ``watermark_handler.apply_frozen_weights()`` after the step.
    * ``'hybrid'``: ``inject_gradient_perturbation(model)`` before and
      ``maintain_frozen_weights()`` after the step.

    After every epoch the model is evaluated on ``val_loader``. Whenever validation
    accuracy beats the best seen so far, ``model.state_dict()`` is saved to
    ``f'{model_name}_best.pth'``. The first epoch always counts as an improvement,
    so a checkpoint is written even when accuracy stays at 0%.

    Returns:
        dict with per-epoch lists ``'train_loss'``, ``'val_loss'``, ``'train_acc'``
        and ``'val_acc'``. Losses are averages of the per-batch loss values;
        accuracies are in percent.
    """
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    best_val_acc = float('-inf')
    has_handler = watermark_handler is not None

    for epoch in range(num_epochs):
        # ---- training ----
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [{model_name}]')
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            # Gradient-side watermarks act between backward() and step().
            if has_handler and watermark_type == 'causal':
                watermark_handler.inject_perturbation(model)
            elif has_handler and watermark_type == 'hybrid':
                watermark_handler.inject_gradient_perturbation(model)

            optimizer.step()

            # Weight-side watermarks restore pinned values the step may have moved.
            if has_handler and watermark_type == 'fixed':
                watermark_handler.apply_frozen_weights()
            elif has_handler and watermark_type == 'hybrid':
                watermark_handler.maintain_frozen_weights()

            running_loss += loss.item()
            num_batches += 1
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Running mean over batches seen so far; dividing by len(pbar) would
            # under-report the loss for most of the epoch.
            pbar.set_postfix({'loss': running_loss / num_batches, 'acc': 100. * correct / total})

        train_loss = running_loss / max(num_batches, 1)
        train_acc = 100. * correct / max(total, 1)
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)

        # ---- validation ----
        model.eval()
        val_loss_sum = 0.0
        val_correct = 0
        val_total = 0
        val_batches = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                val_loss_sum += criterion(outputs, labels).item()
                val_batches += 1
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()

        val_loss = val_loss_sum / max(val_batches, 1)
        val_acc = 100. * val_correct / max(val_total, 1)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), f'{model_name}_best.pth')

        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    return history

def evaluate_model(model, test_loader, class_names):
    """Run ``model`` over ``test_loader`` and compute classification metrics.

    Args:
        model: Classifier whose output holds per-class scores along dim 1.
        test_loader: Iterable of ``(inputs, labels)`` batches; labels are integer
            indices into ``class_names``.
        class_names: Class names in label-index order.

    Returns:
        dict with the following entries:

        * ``accuracy``;
        * weighted ``precision`` / ``recall`` / ``f1_score``;
        * ``confusion_matrix``, always ``len(class_names)`` square, rows = true
          label;
        * ``predictions`` and ``labels``;
        * ``probabilities`` (softmax);
        * the text ``classification_report``.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc='Evaluating'):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            probs = torch.softmax(outputs, dim=1)
            _, predicted = outputs.max(1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    # Pin the label set to every class index. The confusion matrix then always
    # lines up with class_names. classification_report also no longer rejects
    # target_names when a class appears in neither labels nor predictions.
    label_ids = np.arange(len(class_names))

    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted')

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    report = classification_report(all_labels, all_preds, labels=label_ids,
                                   target_names=class_names, digits=4)

    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'confusion_matrix': cm,
        'predictions': all_preds,
        'labels': all_labels,
        'probabilities': all_probs,
        'classification_report': report
    }

    return metrics

def plot_training_history(history, model_name):
    """Plot loss and accuracy curves side by side.

    The figure is saved to ``{model_name}_training_history.png``.
    ``history`` must hold the lists ``train_loss``, ``val_loss``, ``train_acc``
    and ``val_acc``.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # (axes, history key suffix, train label, val label, y label, title prefix)
    panels = (
        (loss_ax, 'loss', 'Train Loss', 'Validation Loss', 'Loss',
         'Training and Validation Loss - '),
        (acc_ax, 'acc', 'Train Accuracy', 'Validation Accuracy', 'Accuracy (%)',
         'Training and Validation Accuracy - '),
    )

    for ax, key, train_label, val_label, y_label, title_prefix in panels:
        ax.plot(history['train_' + key], label=train_label, marker='o', linewidth=2)
        ax.plot(history['val_' + key], label=val_label, marker='s', linewidth=2)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_title(title_prefix + model_name, fontsize=14, fontweight='bold')
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{model_name}_training_history.png', dpi=300, bbox_inches='tight')
    plt.close()
    
def plot_confusion_matrix(cm, class_names, model_name):
    """Draw ``cm`` as an annotated heatmap and save it.

    The output file is ``{model_name}_confusion_matrix.png``. Rows are labelled
    as true labels and columns as predicted labels, both ticked with
    ``class_names``. ``cm`` must hold integer counts (cells use ``fmt='d'``).
    """
    plt.figure(figsize=(12, 10))
    heatmap_style = {'annot': True, 'fmt': 'd', 'cmap': 'Blues',
                     'cbar_kws': {'label': 'Count'}}
    ax = sns.heatmap(cm, xticklabels=class_names, yticklabels=class_names, **heatmap_style)

    ax.set_xlabel('Predicted Label', fontsize=12)
    ax.set_ylabel('True Label', fontsize=12)
    ax.set_title(f'Confusion Matrix - {model_name}', fontsize=14, fontweight='bold')
    # Slant long class names on the x axis; keep y labels horizontal.
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    plt.tight_layout()
    plt.savefig(f'{model_name}_confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.close()

def plot_metrics_comparison(metrics, model_name):
    """Bar-chart a single model's accuracy, precision, recall and F1.

    The chart is saved to ``{model_name}_metrics.png``.
    """
    # (display name, metrics key, bar colour)
    spec = [
        ('Accuracy', 'accuracy', '#2ecc71'),
        ('Precision', 'precision', '#3498db'),
        ('Recall', 'recall', '#e74c3c'),
        ('F1-Score', 'f1_score', '#f39c12'),
    ]
    names = [label for label, _, _ in spec]
    values = [metrics[key] for _, key, _ in spec]
    colors = [color for _, _, color in spec]

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(names, values, color=colors, edgecolor='black', linewidth=1.5, alpha=0.8)

    # Print each value just above the centre of its bar.
    for bar in bars:
        top = bar.get_height()
        x_mid = bar.get_x() + bar.get_width() / 2.
        ax.text(x_mid, top, f'{top:.4f}',
                ha='center', va='bottom', fontsize=11, fontweight='bold')

    ax.set_ylabel('Score', fontsize=12)
    ax.set_title(f'Performance Metrics - {model_name}', fontsize=14, fontweight='bold')
    ax.set_ylim([0, 1.1])
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(f'{model_name}_metrics.png', dpi=300, bbox_inches='tight')
    plt.close()

class CausalWatermarkVerifier:
    """Checks a model for the causal (attribution) and optional fixed-weight watermark.

    The causal check works in four steps:

    1. Compute Integrated Gradients (all-zeros baseline) for up to
       ``num_probe_samples`` inputs, each towards the model's own predicted
       class.
    2. Average the attributions.
    3. Keep the 1000 largest-magnitude entries as a unit-norm signature.
    4. Compare the signature with the vector ``CausalWatermarkInjector`` derives
       from the fingerprint.

    Args:
        model: Classifier under test.
        num_probe_samples: Maximum number of inputs used to build the signature.
    """

    def __init__(self, model, num_probe_samples=100):
        self.model = model
        self.num_probe_samples = num_probe_samples

    def compute_integrated_gradients(self, inputs, target_class, steps=50):
        """Integrated Gradients of output column ``target_class`` w.r.t. ``inputs``.

        The baseline is all zeros. Gradients are averaged (Riemann mean) over
        ``steps + 1`` evenly spaced points from the baseline (alpha=0) to the
        input (alpha=1).

        Args:
            inputs: Batch of inputs, shape ``(B, ...)``.
            target_class: Output column to attribute.
            steps: Number of interpolation intervals.

        Returns:
            Tensor shaped like ``inputs``. Each sample is attributed using only
            the gradients along its own path.
        """
        self.model.eval()

        baseline = torch.zeros_like(inputs)

        # (steps+1, B, ...) flattened to ((steps+1)*B, ...): one forward pass for the path.
        path = torch.stack([baseline + (float(i) / steps) * (inputs - baseline)
                            for i in range(steps + 1)], dim=0)
        scaled_inputs = path.view(-1, *inputs.shape[1:])
        scaled_inputs.requires_grad = True

        outputs = self.model(scaled_inputs)
        target_outputs = outputs[:, target_class]

        gradients = torch.autograd.grad(outputs=target_outputs.sum(),
                                        inputs=scaled_inputs,
                                        create_graph=False)[0]

        # Average over the path dimension only, keeping samples separate; a flat
        # mean over dim 0 would mix the gradients of different samples when B > 1.
        avg_gradients = gradients.reshape(steps + 1, *inputs.shape).mean(dim=0)
        return (inputs - baseline) * avg_gradients

    def extract_attribution_signature(self, data_loader):
        """Aggregate IG attributions over up to ``num_probe_samples`` inputs.

        Returns:
            dict with the following entries:

            * ``mean_attribution`` / ``std_attribution``: per flattened input
              element;
            * ``top_indices``: the 1000 largest-|mean| positions, ascending by
              magnitude;
            * ``signature_vector``: the mean at those positions, L2-normalised.

        Raises:
            ValueError: if no samples could be taken from ``data_loader`` (empty
                loader or ``num_probe_samples <= 0``).
        """
        self.model.eval()

        all_attributions = []
        for inputs, _labels in data_loader:
            remaining = self.num_probe_samples - len(all_attributions)
            if remaining <= 0:
                break

            inputs = inputs.to(device)
            # Only the predicted classes are needed; skip building an autograd graph.
            with torch.no_grad():
                _, predicted = self.model(inputs).max(1)

            for i in range(min(remaining, inputs.size(0))):
                attributions = self.compute_integrated_gradients(inputs[i:i + 1], predicted[i].item())
                all_attributions.append(attributions.cpu().detach().numpy().flatten())

        if not all_attributions:
            raise ValueError("no probe samples available to build an attribution signature")

        all_attributions = np.array(all_attributions)

        mean_attribution = np.mean(all_attributions, axis=0)
        std_attribution = np.std(all_attributions, axis=0)

        top_k = 1000
        top_indices = np.argsort(np.abs(mean_attribution))[-top_k:]
        signature_vector = mean_attribution[top_indices]
        signature_vector = signature_vector / (np.linalg.norm(signature_vector) + 1e-8)

        signature = {
            'mean_attribution': mean_attribution,
            'std_attribution': std_attribution,
            'signature_vector': signature_vector,
            'top_indices': top_indices
        }

        return signature

    def verify_fingerprint(self, fingerprint, data_loader, fixed_watermark=None):
        """Print and return a watermark verification report for ``fingerprint``.

        Causal check: cosine similarity between the attribution signature and the
        leading ``len(signature)`` entries of the fingerprint vector (renormalised).
        It counts as detected above 0.15. Pearson correlation, L2 distance and dot
        product are reported alongside.

        Fixed check (only when ``fixed_watermark`` is given): its
        ``verify_frozen_weights()`` match ratio must exceed 0.95. That object
        checks the model it was constructed with, not necessarily ``self.model``.
        Overall detection is the OR of the two checks.

        Returns:
            dict of the metrics and detection flags as plain Python types.
        """
        print(f"\n{'='*60}")
        print(f"WATERMARK VERIFICATION")
        print(f"{'='*60}")
        print(f"Target Fingerprint: {fingerprint}")
        print(f"Probe Samples: {self.num_probe_samples}")
        print(f"{'-'*60}")

        injector = CausalWatermarkInjector(fingerprint, target_layers=['features'])
        expected_vector = injector.fingerprint_vector.cpu().numpy()

        signature = self.extract_attribution_signature(data_loader)
        observed_vector = signature['signature_vector']

        expected_normalized = expected_vector[:len(observed_vector)]
        expected_normalized = expected_normalized / (np.linalg.norm(expected_normalized) + 1e-8)

        cosine_sim = 1 - cosine(expected_normalized, observed_vector)
        pearson_corr, p_value = pearsonr(expected_normalized, observed_vector)

        l2_distance = np.linalg.norm(expected_normalized - observed_vector)

        dot_product = np.dot(expected_normalized, observed_vector)

        causal_threshold = 0.15
        causal_detected = cosine_sim > causal_threshold

        fixed_match_ratio = 0.0
        fixed_total = 0
        fixed_detected = False

        if fixed_watermark is not None:
            fixed_match_ratio, fixed_total = fixed_watermark.verify_frozen_weights()
            fixed_threshold = 0.95
            fixed_detected = fixed_match_ratio > fixed_threshold

            print(f"\nFixed Weight Verification:")
            print(f"  • Total Frozen Weights:     {fixed_total}")
            print(f"  • Match Ratio:              {fixed_match_ratio:.6f}")
            print(f"  • Detection Threshold:      {fixed_threshold:.6f}")
            print(f"  • Fixed Weights Detected:   {'YES ✓' if fixed_detected else 'NO ✗'}")
            print(f"{'-'*60}")

        print(f"\nCausal Attribution Verification:")
        print(f"  • Cosine Similarity:        {cosine_sim:.6f}")
        print(f"  • Pearson Correlation:      {pearson_corr:.6f} (p-value: {p_value:.2e})")
        print(f"  • L2 Distance:              {l2_distance:.6f}")
        print(f"  • Dot Product:              {dot_product:.6f}")
        print(f"  • Detection Threshold:      {causal_threshold:.6f}")
        print(f"  • Causal Pattern Detected:  {'YES ✓' if causal_detected else 'NO ✗'}")
        print(f"{'-'*60}")

        overall_detected = causal_detected or fixed_detected
        print(f"Overall Watermark Status: {'DETECTED ✓' if overall_detected else 'NOT DETECTED ✗'}")
        print(f"{'='*60}\n")

        verification_result = {
            'fingerprint': fingerprint,
            'causal_cosine_similarity': float(cosine_sim),
            'causal_pearson_correlation': float(pearson_corr),
            'causal_p_value': float(p_value),
            'causal_l2_distance': float(l2_distance),
            'causal_dot_product': float(dot_product),
            'causal_detected': bool(causal_detected),
            'fixed_match_ratio': float(fixed_match_ratio),
            'fixed_total_weights': int(fixed_total),
            'fixed_detected': bool(fixed_detected),
            'overall_detected': bool(overall_detected)
        }

        return verification_result

def plot_all_models_comparison(baseline_metrics, fixed_metrics, hybrid_metrics):
    """Grouped bar chart of accuracy/precision/recall/F1 for the three models.

    Saves ``all_models_performance_comparison.png``.
    """
    models = ['Baseline', 'Fixed Weights', 'Hybrid']
    runs = (baseline_metrics, fixed_metrics, hybrid_metrics)
    # (metrics key, legend label, colour, offset in bar widths from the group centre)
    series = [
        ('accuracy', 'Accuracy', '#2ecc71', -1.5),
        ('precision', 'Precision', '#3498db', -0.5),
        ('recall', 'Recall', '#e74c3c', 0.5),
        ('f1_score', 'F1-Score', '#f39c12', 1.5),
    ]

    x = np.arange(len(models))
    width = 0.2

    fig, ax = plt.subplots(figsize=(12, 6))

    bar_groups = []
    for key, label, color, offset in series:
        values = [run[key] for run in runs]
        bar_groups.append(ax.bar(x + offset * width, values, width, label=label,
                                 color=color, edgecolor='black', alpha=0.8))

    for group in bar_groups:
        for bar in group:
            top = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., top, f'{top:.3f}',
                    ha='center', va='bottom', fontsize=9, fontweight='bold')

    ax.set_ylabel('Score', fontsize=12)
    ax.set_title('Performance Comparison Across All Models', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(models, fontsize=11)
    ax.legend(fontsize=11)
    ax.set_ylim([0, 1.1])
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('all_models_performance_comparison.png', dpi=300, bbox_inches='tight')
    plt.close()

def plot_verification_comparison(baseline_ver, fixed_ver, hybrid_ver):
    """Compare causal similarity and fixed-weight match ratio across the three models.

    The baseline's fixed-weight bar is drawn at 0 regardless of ``baseline_ver``.
    Dashed lines mark the 0.15 causal and 0.95 fixed detection thresholds.
    Saves ``watermark_verification_comparison.png``.
    """
    models = ['Baseline', 'Fixed\nWeights', 'Hybrid']
    verifications = (baseline_ver, fixed_ver, hybrid_ver)
    causal_sim = [ver['causal_cosine_similarity'] for ver in verifications]
    fixed_match = [0] + [ver['fixed_match_ratio'] for ver in verifications[1:]]

    x = np.arange(len(models))
    width = 0.35
    bar_style = {'edgecolor': 'black', 'linewidth': 1.5, 'alpha': 0.8}

    fig, ax = plt.subplots(figsize=(12, 6))

    sim_bars = ax.bar(x - width/2, causal_sim, width, label='Causal Similarity',
                      color='#3498db', **bar_style)
    match_bars = ax.bar(x + width/2, fixed_match, width, label='Fixed Weight Match',
                        color='#e74c3c', **bar_style)

    # Annotate only strictly positive bars (also skips NaN heights).
    for bar in [*sim_bars, *match_bars]:
        height = bar.get_height()
        if not height > 0:
            continue
        ax.text(bar.get_x() + bar.get_width()/2., height, f'{height:.4f}',
                ha='center', va='bottom', fontsize=10, fontweight='bold')

    # Same thresholds CausalWatermarkVerifier.verify_fingerprint uses.
    ax.axhline(y=0.15, color='green', linestyle='--', linewidth=2,
               label='Causal Threshold (0.15)', alpha=0.7)
    ax.axhline(y=0.95, color='orange', linestyle='--', linewidth=2,
               label='Fixed Threshold (0.95)', alpha=0.7)

    ax.set_ylabel('Score', fontsize=12)
    ax.set_title('Watermark Verification Comparison', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(models, fontsize=11)
    ax.legend(fontsize=10)
    ax.set_ylim([0, 1.1])
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('watermark_verification_comparison.png', dpi=300, bbox_inches='tight')
    plt.close()

def plot_performance_impact(baseline_metrics, fixed_metrics, hybrid_metrics):
    """Plot how much each watermarked model loses relative to the baseline.

    For every metric the drop is ``(baseline - watermarked) * 100``. Positive bars
    therefore mean the watermark cost performance, and negative bars mean it
    helped. A dashed line marks the 1-point target. Saves
    ``watermarking_performance_impact.png``.
    """
    metrics = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
    keys = ['accuracy', 'precision', 'recall', 'f1_score']

    def _drops(watermarked):
        # Percentage-point drop per metric, in the order of ``keys``.
        return [(baseline_metrics[k] - watermarked[k]) * 100 for k in keys]

    fixed_drops = _drops(fixed_metrics)
    hybrid_drops = _drops(hybrid_metrics)

    x = np.arange(len(metrics))
    width = 0.35

    fig, ax = plt.subplots(figsize=(12, 6))

    bars1 = ax.bar(x - width/2, fixed_drops, width, label='Fixed Weights Impact',
                   color='#e74c3c', edgecolor='black', linewidth=1.5, alpha=0.8)
    bars2 = ax.bar(x + width/2, hybrid_drops, width, label='Hybrid Impact',
                   color='#f39c12', edgecolor='black', linewidth=1.5, alpha=0.8)

    # Labels sit above non-negative bars and below negative ones.
    for bars in (bars1, bars2):
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.2f}%',
                    ha='center', va='bottom' if height >= 0 else 'top',
                    fontsize=10, fontweight='bold')

    ax.axhline(y=0, color='black', linestyle='-', linewidth=1)
    ax.axhline(y=1, color='green', linestyle='--', linewidth=2,
               label='Target Max Drop (1%)', alpha=0.7)

    ax.set_ylabel('Performance Drop (%)', fontsize=12)
    ax.set_title('Watermarking Performance Impact', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(metrics, fontsize=11)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('watermarking_performance_impact.png', dpi=300, bbox_inches='tight')
    plt.close()

def save_results_to_file(baseline_metrics, fixed_metrics, hybrid_metrics,
                         baseline_ver, fixed_ver, hybrid_ver):
    """Write the experiment summary to ``cwma_experiment_results.json`` and print it.

    The JSON holds each model's performance and verification outcome. It also
    holds, for both watermarked models, the per-metric drop versus the baseline
    (as a fraction, baseline minus watermarked).
    """
    def performance(m):
        return {
            'Accuracy': float(m['accuracy']),
            'Precision': float(m['precision']),
            'Recall': float(m['recall']),
            'F1_Score': float(m['f1_score']),
        }

    def watermarked_verification(ver):
        return {
            'Causal_Cosine_Similarity': float(ver['causal_cosine_similarity']),
            'Fixed_Match_Ratio': float(ver['fixed_match_ratio']),
            'Fixed_Total_Weights': int(ver['fixed_total_weights']),
            'Causal_Detected': bool(ver['causal_detected']),
            'Fixed_Detected': bool(ver['fixed_detected']),
            'Overall_Detected': bool(ver['overall_detected']),
        }

    def drops_from_baseline(m):
        return {
            'Accuracy_Drop': float(baseline_metrics['accuracy'] - m['accuracy']),
            'Precision_Drop': float(baseline_metrics['precision'] - m['precision']),
            'Recall_Drop': float(baseline_metrics['recall'] - m['recall']),
            'F1_Score_Drop': float(baseline_metrics['f1_score'] - m['f1_score']),
        }

    baseline_block = {
        'Performance': performance(baseline_metrics),
        # The baseline carries no fixed watermark, so only the flags are recorded.
        'Verification': {
            'Causal_Cosine_Similarity': float(baseline_ver['causal_cosine_similarity']),
            'Causal_Detected': bool(baseline_ver['causal_detected']),
            'Fixed_Detected': bool(baseline_ver['fixed_detected']),
            'Overall_Detected': bool(baseline_ver['overall_detected']),
        },
    }
    fixed_block = {
        'Performance': performance(fixed_metrics),
        'Verification': watermarked_verification(fixed_ver),
    }
    hybrid_block = {
        'Performance': performance(hybrid_metrics),
        'Verification': watermarked_verification(hybrid_ver),
    }
    impact_block = {
        'Fixed_Weight_Model': drops_from_baseline(fixed_metrics),
        'Hybrid_Model': drops_from_baseline(hybrid_metrics),
    }

    results = {
        'Baseline_Model': baseline_block,
        'Fixed_Weight_Model': fixed_block,
        'Hybrid_Model': hybrid_block,
        'Performance_Impact': impact_block,
    }

    with open('cwma_experiment_results.json', 'w') as f:
        json.dump(results, f, indent=4)

    heavy_rule = "=" * 80
    light_rule = "-" * 80

    def section(title):
        print("\n" + light_rule)
        print(title)
        print(light_rule)

    def metric_lines(m):
        print(f"  Accuracy:  {m['accuracy']:.4f}")
        print(f"  Precision: {m['precision']:.4f}")
        print(f"  Recall:    {m['recall']:.4f}")
        print(f"  F1-Score:  {m['f1_score']:.4f}")

    print("\n" + heavy_rule)
    print("COMPREHENSIVE EXPERIMENT RESULTS")
    print(heavy_rule)

    section("BASELINE MODEL (No Watermark)")
    metric_lines(baseline_metrics)
    print(f"  Watermark Detected: {baseline_ver['overall_detected']}")

    section("FIXED WEIGHT WATERMARKED MODEL")
    metric_lines(fixed_metrics)
    print(f"  Fixed Weight Match: {fixed_ver['fixed_match_ratio']:.4f} ({fixed_ver['fixed_total_weights']} weights)")
    print(f"  Watermark Detected: {fixed_ver['overall_detected']}")
    print(f"  Performance Drop: {(baseline_metrics['accuracy'] - fixed_metrics['accuracy'])*100:.2f}%")

    section("HYBRID WATERMARKED MODEL (Fixed + Causal)")
    metric_lines(hybrid_metrics)
    print(f"  Fixed Weight Match: {hybrid_ver['fixed_match_ratio']:.4f} ({hybrid_ver['fixed_total_weights']} weights)")
    print(f"  Causal Similarity: {hybrid_ver['causal_cosine_similarity']:.4f}")
    print(f"  Watermark Detected: {hybrid_ver['overall_detected']}")
    print(f"  Performance Drop: {(baseline_metrics['accuracy'] - hybrid_metrics['accuracy'])*100:.2f}%")

    baseline_status = 'Watermarked (unexpected)' if baseline_ver['overall_detected'] else 'NOT watermarked'
    fixed_status = 'Successfully watermarked' if fixed_ver['fixed_detected'] else 'Watermark not detected'
    hybrid_status = 'Successfully watermarked' if hybrid_ver['overall_detected'] else 'Watermark not detected'

    print("\n" + heavy_rule)
    print("KEY FINDINGS")
    print(heavy_rule)
    print(f"  • Baseline Model: {baseline_status}")
    print(f"  • Fixed Weight Model: {fixed_status}")
    print(f"  • Hybrid Model: {hybrid_status}")
    print(f"  • Hybrid uses DUAL verification (Fixed={hybrid_ver['fixed_detected']}, Causal={hybrid_ver['causal_detected']})")
    print(heavy_rule + "\n")

def _reset_seeds(seed):
    """Reseed the global torch and NumPy RNGs.

    Called before each experiment so that every model starts from the same
    initial weights and sees the same shuffle order. Without this, accuracy
    differences reported as "Performance Drop" are confounded with whatever
    RNG state the previous experiment left behind.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

def _load_best_checkpoint(model, checkpoint_path):
    """Load the state dict stored at ``checkpoint_path`` into ``model``.

    ``checkpoint_path`` is presumably the best-epoch file written by
    ``train_model`` (``<model_name>_best.pth``) -- confirm against train_model.
    ``map_location=device`` keeps the load working even if the file was saved
    on a different device than the one in use now.
    """
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model

def _evaluate_and_plot(model, history, test_loader, class_names, tag):
    """Evaluate ``model`` on the test loader, print its report, and plot results.

    Returns the metrics dict produced by ``evaluate_model``; ``tag`` is the
    title/prefix passed to every plotting helper.
    """
    metrics = evaluate_model(model, test_loader, class_names)
    print(metrics['classification_report'])
    plot_training_history(history, tag)
    plot_confusion_matrix(metrics['confusion_matrix'], class_names, tag)
    plot_metrics_comparison(metrics, tag)
    return metrics

def main():
    """Run the three CWMA experiments end to end.

    1. Baseline ``SimpleCNN`` with no watermark.
    2. Fixed-weight watermark (a fingerprint-derived fraction of weights frozen).
    3. Hybrid watermark (fixed weights + gradient perturbation).

    Each model is trained, restored from its best checkpoint, evaluated on the
    held-out test split, plotted, and run through fingerprint verification.
    Comparison plots and a JSON results file are produced at the end.
    """
    print("\n" + "="*80)
    print("CAUSAL WATERMARKING FOR MODEL ATTRIBUTION (CWMA)")
    print("Hybrid Approach: Fixed Weights + Gradient Perturbation")
    print("="*80 + "\n")
    
    IMG_DIR = '/kaggle/input/butterfly-image-classification/train'
    CSV_PATH = '/kaggle/input/butterfly-image-classification/Training_set.csv'
    
    FINGERPRINT = "CWMA2025_HybridWatermark_SecureOwnership"
    
    IMG_SIZE = 224
    BATCH_SIZE = 32
    NUM_EPOCHS = 15
    LEARNING_RATE = 0.001
    
    # Watermark hyper-parameters shared by the fixed and hybrid experiments;
    # kept in one place so the printed descriptions cannot drift from the
    # values actually passed to the watermark handlers.
    FREEZE_RATIO = 0.002
    LAMBDA_FACTOR = 0.001
    NUM_PROBE_SAMPLES = 100
    
    print("Loading and preparing data...")
    df = pd.read_csv(CSV_PATH)
    print(f"Total samples: {len(df)}")
    print(f"Number of classes: {df['label'].nunique()}")
    
    # 70/15/15 stratified split.
    train_df, temp_df = train_test_split(df, test_size=0.3, stratify=df['label'], random_state=RANDOM_SEED)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df['label'], random_state=RANDOM_SEED)
    
    print(f"Train samples: {len(train_df)}")
    print(f"Validation samples: {len(val_df)}")
    print(f"Test samples: {len(test_df)}")
    
    transform_train = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    transform_test = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    train_dataset = ButterflyDataset(train_df, IMG_DIR, transform=transform_train)
    val_dataset = ButterflyDataset(val_df, IMG_DIR, transform=transform_test)
    test_dataset = ButterflyDataset(test_df, IMG_DIR, transform=transform_test)
    
    # ButterflyDataset derives label_to_idx from the labels present in its own
    # dataframe. Force val/test to use the training mapping so a class index
    # means the same class in every split (a label unseen in training now
    # raises KeyError instead of being silently re-indexed).
    for split_dataset in (val_dataset, test_dataset):
        split_dataset.labels = train_dataset.labels
        split_dataset.label_to_idx = train_dataset.label_to_idx
    
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    
    num_classes = len(train_dataset.labels)
    class_names = train_dataset.labels
    
    # CrossEntropyLoss holds no state, so one instance serves all experiments.
    criterion = nn.CrossEntropyLoss()
    
    # ---------------------------------------------------------------- baseline
    print("\n" + "="*80)
    print("EXPERIMENT 1: BASELINE MODEL (No Watermark)")
    print("="*80)
    
    _reset_seeds(RANDOM_SEED)
    baseline_model = SimpleCNN(num_classes).to(device)
    optimizer_baseline = optim.Adam(baseline_model.parameters(), lr=LEARNING_RATE)
    
    baseline_history = train_model(
        baseline_model, train_loader, val_loader, criterion, 
        optimizer_baseline, NUM_EPOCHS, watermark_type=None, 
        watermark_handler=None, model_name="baseline"
    )
    
    _load_best_checkpoint(baseline_model, 'baseline_best.pth')
    
    print("\nEvaluating baseline model...")
    baseline_metrics = _evaluate_and_plot(
        baseline_model, baseline_history, test_loader, class_names, "Baseline"
    )
    
    print("\nVerifying baseline model...")
    baseline_verifier = CausalWatermarkVerifier(baseline_model, num_probe_samples=NUM_PROBE_SAMPLES)
    baseline_verification = baseline_verifier.verify_fingerprint(FINGERPRINT, test_loader, fixed_watermark=None)
    
    # ------------------------------------------------------------ fixed weight
    print("\n" + "="*80)
    print("EXPERIMENT 2: FIXED WEIGHT WATERMARKED MODEL")
    print("="*80)
    print(f"Embedding Fingerprint: {FINGERPRINT}")
    print(f"Method: Freeze {FREEZE_RATIO:.1%} of weights based on fingerprint hash\n")
    
    _reset_seeds(RANDOM_SEED)
    fixed_model = SimpleCNN(num_classes).to(device)
    fixed_watermark = FixedWeightWatermark(fixed_model, FINGERPRINT, freeze_ratio=FREEZE_RATIO)
    fixed_watermark.generate_frozen_pattern()
    fixed_watermark.apply_frozen_weights()
    fixed_watermark.freeze_gradient_hook()
    
    optimizer_fixed = optim.Adam(fixed_model.parameters(), lr=LEARNING_RATE)
    
    fixed_history = train_model(
        fixed_model, train_loader, val_loader, criterion,
        optimizer_fixed, NUM_EPOCHS, watermark_type='fixed',
        watermark_handler=fixed_watermark, model_name="fixed_weight"
    )
    
    _load_best_checkpoint(fixed_model, 'fixed_weight_best.pth')
    
    print("\nEvaluating fixed weight model...")
    fixed_metrics = _evaluate_and_plot(
        fixed_model, fixed_history, test_loader, class_names, "Fixed_Weight"
    )
    
    print("\nVerifying fixed weight model...")
    fixed_verifier = CausalWatermarkVerifier(fixed_model, num_probe_samples=NUM_PROBE_SAMPLES)
    fixed_verification = fixed_verifier.verify_fingerprint(FINGERPRINT, test_loader, fixed_watermark=fixed_watermark)
    
    # ------------------------------------------------------------------ hybrid
    print("\n" + "="*80)
    print("EXPERIMENT 3: HYBRID WATERMARKED MODEL (Fixed + Causal)")
    print("="*80)
    print(f"Embedding Fingerprint: {FINGERPRINT}")
    print(f"Method 1: Freeze {FREEZE_RATIO:.1%} of weights")
    print("Method 2: Gradient perturbation for causal entanglement\n")
    
    _reset_seeds(RANDOM_SEED)
    hybrid_model = SimpleCNN(num_classes).to(device)
    hybrid_watermark = HybridWatermark(
        hybrid_model, FINGERPRINT, freeze_ratio=FREEZE_RATIO, lambda_factor=LAMBDA_FACTOR
    )
    hybrid_watermark.initialize()
    
    optimizer_hybrid = optim.Adam(hybrid_model.parameters(), lr=LEARNING_RATE)
    
    hybrid_history = train_model(
        hybrid_model, train_loader, val_loader, criterion,
        optimizer_hybrid, NUM_EPOCHS, watermark_type='hybrid',
        watermark_handler=hybrid_watermark, model_name="hybrid"
    )
    
    _load_best_checkpoint(hybrid_model, 'hybrid_best.pth')
    
    print("\nEvaluating hybrid model...")
    hybrid_metrics = _evaluate_and_plot(
        hybrid_model, hybrid_history, test_loader, class_names, "Hybrid"
    )
    
    print("\nVerifying hybrid model...")
    hybrid_verifier = CausalWatermarkVerifier(hybrid_model, num_probe_samples=NUM_PROBE_SAMPLES)
    hybrid_verification = hybrid_verifier.verify_fingerprint(FINGERPRINT, test_loader, 
                                                              fixed_watermark=hybrid_watermark.fixed_watermark)
    
    # -------------------------------------------------------------- comparison
    print("\n" + "="*80)
    print("GENERATING COMPARISON VISUALIZATIONS")
    print("="*80)
    
    plot_all_models_comparison(baseline_metrics, fixed_metrics, hybrid_metrics)
    plot_verification_comparison(baseline_verification, fixed_verification, hybrid_verification)
    plot_performance_impact(baseline_metrics, fixed_metrics, hybrid_metrics)
    
    save_results_to_file(baseline_metrics, fixed_metrics, hybrid_metrics,
                         baseline_verification, fixed_verification, hybrid_verification)
    
    print("\n" + "="*80)
    print("EXPERIMENT COMPLETED SUCCESSFULLY")
    print("="*80)
    print("\nGenerated Files:")
    print("  Models:")
    print("    • baseline_best.pth")
    print("    • fixed_weight_best.pth")
    print("    • hybrid_best.pth")
    print("\n  Individual Model Plots:")
    print("    • baseline_training_history.png")
    print("    • baseline_confusion_matrix.png")
    print("    • baseline_metrics.png")
    print("    • fixed_weight_training_history.png")
    print("    • fixed_weight_confusion_matrix.png")
    print("    • fixed_weight_metrics.png")
    print("    • hybrid_training_history.png")
    print("    • hybrid_confusion_matrix.png")
    print("    • hybrid_metrics.png")
    print("\n  Comparison Plots:")
    print("    • all_models_performance_comparison.png")
    print("    • watermark_verification_comparison.png")
    print("    • watermarking_performance_impact.png")
    print("\n  Results:")
    print("    • cwma_experiment_results.json")
    print("="*80 + "\n")

if __name__ == "__main__":
    main()

Using device: cuda

CAUSAL WATERMARKING FOR MODEL ATTRIBUTION (CWMA)
Hybrid Approach: Fixed Weights + Gradient Perturbation

Loading and preparing data...
Total samples: 6499
Number of classes: 75
Train samples: 4549
Validation samples: 975
Test samples: 975

EXPERIMENT 1: BASELINE MODEL (No Watermark)


Epoch 1/15 [baseline]: 100%|██████████| 143/143 [00:12<00:00, 11.19it/s, loss=4.08, acc=4.07] 


Epoch 1: Train Loss: 4.0830, Train Acc: 4.07%, Val Loss: 3.6363, Val Acc: 9.95%


Epoch 2/15 [baseline]: 100%|██████████| 143/143 [00:11<00:00, 12.52it/s, loss=3.45, acc=13.4]


Epoch 2: Train Loss: 3.4547, Train Acc: 13.43%, Val Loss: 3.0936, Val Acc: 20.31%


Epoch 3/15 [baseline]: 100%|██████████| 143/143 [00:11<00:00, 12.57it/s, loss=2.88, acc=24.2]


Epoch 3: Train Loss: 2.8777, Train Acc: 24.20%, Val Loss: 2.3461, Val Acc: 35.90%


Epoch 4/15 [baseline]: 100%|██████████| 143/143 [00:11<00:00, 12.13it/s, loss=2.58, acc=30.1]


Epoch 4: Train Loss: 2.5802, Train Acc: 30.05%, Val Loss: 2.1728, Val Acc: 42.56%


Epoch 5/15 [baseline]: 100%|██████████| 143/143 [00:11<00:00, 12.40it/s, loss=2.32, acc=35.3]


Epoch 5: Train Loss: 2.3242, Train Acc: 35.28%, Val Loss: 1.9684, Val Acc: 45.03%


Epoch 6/15 [baseline]: 100%|██████████| 143/143 [00:11<00:00, 12.46it/s, loss=2.12, acc=41.5]


Epoch 6: Train Loss: 2.1181, Train Acc: 41.50%, Val Loss: 1.9291, Val Acc: 45.85%


Epoch 7/15 [baseline]: 100%|██████████| 143/143 [00:11<00:00, 12.41it/s, loss=1.95, acc=45.6]


Epoch 7: Train Loss: 1.9543, Train Acc: 45.59%, Val Loss: 1.7128, Val Acc: 51.08%


Epoch 8/15 [baseline]: 100%|██████████| 143/143 [00:11<00:00, 12.28it/s, loss=1.79, acc=49.3]


Epoch 8: Train Loss: 1.7861, Train Acc: 49.29%, Val Loss: 1.5806, Val Acc: 56.82%


Epoch 9/15 [baseline]: 100%|██████████| 143/143 [00:11<00:00, 12.42it/s, loss=1.66, acc=51.7]


Epoch 9: Train Loss: 1.6563, Train Acc: 51.75%, Val Loss: 1.4842, Val Acc: 60.62%


Epoch 10/15 [baseline]: 100%|██████████| 143/143 [00:11<00:00, 12.47it/s, loss=1.51, acc=56.5] 


Epoch 10: Train Loss: 1.5084, Train Acc: 56.50%, Val Loss: 1.4002, Val Acc: 60.72%


Epoch 11/15 [baseline]: 100%|██████████| 143/143 [00:11<00:00, 12.28it/s, loss=1.42, acc=59.3] 


Epoch 11: Train Loss: 1.4213, Train Acc: 59.29%, Val Loss: 1.3090, Val Acc: 65.33%


Epoch 12/15 [baseline]: 100%|██████████| 143/143 [00:11<00:00, 12.60it/s, loss=1.36, acc=59.1] 


Epoch 12: Train Loss: 1.3572, Train Acc: 59.07%, Val Loss: 1.3033, Val Acc: 65.64%


Epoch 13/15 [baseline]: 100%|██████████| 143/143 [00:11<00:00, 12.38it/s, loss=1.23, acc=64.7] 


Epoch 13: Train Loss: 1.2307, Train Acc: 64.72%, Val Loss: 1.2598, Val Acc: 66.46%


Epoch 14/15 [baseline]: 100%|██████████| 143/143 [00:11<00:00, 12.39it/s, loss=1.21, acc=64.7] 


Epoch 14: Train Loss: 1.2137, Train Acc: 64.65%, Val Loss: 1.2783, Val Acc: 66.97%


Epoch 15/15 [baseline]: 100%|██████████| 143/143 [00:11<00:00, 12.65it/s, loss=1.08, acc=67.2] 


Epoch 15: Train Loss: 1.0790, Train Acc: 67.20%, Val Loss: 1.2167, Val Acc: 69.23%

Evaluating baseline model...


Evaluating: 100%|██████████| 31/31 [00:01<00:00, 15.71it/s]


                           precision    recall  f1-score   support

                   ADONIS     0.7143    0.7692    0.7407        13
AFRICAN GIANT SWALLOWTAIL     0.7500    0.5455    0.6316        11
           AMERICAN SNOOT     0.6154    0.7273    0.6667        11
                    AN 88     1.0000    1.0000    1.0000        13
                  APPOLLO     0.5000    0.6154    0.5517        13
                    ATALA     1.0000    0.8000    0.8889        15
 BANDED ORANGE HELICONIAN     0.8333    0.3333    0.4762        15
           BANDED PEACOCK     0.6429    0.6923    0.6667        13
            BECKERS WHITE     0.5000    0.5000    0.5000        12
         BLACK HAIRSTREAK     0.6364    0.5385    0.5833        13
              BLUE MORPHO     0.5000    0.5455    0.5217        11
        BLUE SPOTTED CROW     0.7778    0.5385    0.6364        13
           BROWN SIPROETA     0.7059    0.8000    0.7500        15
            CABBAGE WHITE     0.5625    0.6429    0.6000     

Epoch 1/15 [fixed_weight]: 100%|██████████| 143/143 [00:11<00:00, 12.55it/s, loss=4.27, acc=2]   


Epoch 1: Train Loss: 4.2671, Train Acc: 2.00%, Val Loss: 4.0483, Val Acc: 6.46%


Epoch 2/15 [fixed_weight]: 100%|██████████| 143/143 [00:11<00:00, 12.60it/s, loss=3.79, acc=7.98]


Epoch 2: Train Loss: 3.7871, Train Acc: 7.98%, Val Loss: 3.3679, Val Acc: 15.69%


Epoch 3/15 [fixed_weight]: 100%|██████████| 143/143 [00:11<00:00, 12.73it/s, loss=3.19, acc=19.9]


Epoch 3: Train Loss: 3.1903, Train Acc: 19.92%, Val Loss: 2.7397, Val Acc: 30.26%


Epoch 4/15 [fixed_weight]: 100%|██████████| 143/143 [00:11<00:00, 12.47it/s, loss=2.76, acc=27.9]


Epoch 4: Train Loss: 2.7575, Train Acc: 27.94%, Val Loss: 2.4028, Val Acc: 38.05%


Epoch 5/15 [fixed_weight]: 100%|██████████| 143/143 [00:11<00:00, 12.25it/s, loss=2.41, acc=35.9]


Epoch 5: Train Loss: 2.4125, Train Acc: 35.88%, Val Loss: 2.1768, Val Acc: 40.10%


Epoch 6/15 [fixed_weight]: 100%|██████████| 143/143 [00:11<00:00, 12.38it/s, loss=2.18, acc=39.3]


Epoch 6: Train Loss: 2.1844, Train Acc: 39.35%, Val Loss: 1.9779, Val Acc: 45.64%


Epoch 7/15 [fixed_weight]: 100%|██████████| 143/143 [00:11<00:00, 12.22it/s, loss=2.01, acc=44.9]


Epoch 7: Train Loss: 2.0076, Train Acc: 44.91%, Val Loss: 1.7467, Val Acc: 54.15%


Epoch 8/15 [fixed_weight]: 100%|██████████| 143/143 [00:11<00:00, 12.61it/s, loss=1.76, acc=50.1]


Epoch 8: Train Loss: 1.7576, Train Acc: 50.10%, Val Loss: 1.6552, Val Acc: 56.72%


Epoch 9/15 [fixed_weight]: 100%|██████████| 143/143 [00:11<00:00, 12.44it/s, loss=1.64, acc=53.5]


Epoch 9: Train Loss: 1.6412, Train Acc: 53.53%, Val Loss: 1.5993, Val Acc: 56.72%


Epoch 10/15 [fixed_weight]: 100%|██████████| 143/143 [00:11<00:00, 12.48it/s, loss=1.54, acc=56.3] 


Epoch 10: Train Loss: 1.5425, Train Acc: 56.30%, Val Loss: 1.4941, Val Acc: 57.95%


Epoch 11/15 [fixed_weight]: 100%|██████████| 143/143 [00:11<00:00, 12.48it/s, loss=1.46, acc=57.7]


Epoch 11: Train Loss: 1.4568, Train Acc: 57.73%, Val Loss: 1.4461, Val Acc: 61.54%


Epoch 12/15 [fixed_weight]: 100%|██████████| 143/143 [00:11<00:00, 12.37it/s, loss=1.36, acc=60.9] 


Epoch 12: Train Loss: 1.3596, Train Acc: 60.89%, Val Loss: 1.3791, Val Acc: 60.92%


Epoch 13/15 [fixed_weight]: 100%|██████████| 143/143 [00:11<00:00, 12.45it/s, loss=1.25, acc=63.9] 


Epoch 13: Train Loss: 1.2474, Train Acc: 63.93%, Val Loss: 1.3712, Val Acc: 63.08%


Epoch 14/15 [fixed_weight]: 100%|██████████| 143/143 [00:11<00:00, 12.33it/s, loss=1.19, acc=64.5] 


Epoch 14: Train Loss: 1.1938, Train Acc: 64.45%, Val Loss: 1.2718, Val Acc: 64.00%


Epoch 15/15 [fixed_weight]: 100%|██████████| 143/143 [00:11<00:00, 11.92it/s, loss=1.09, acc=67.2] 


Epoch 15: Train Loss: 1.0935, Train Acc: 67.18%, Val Loss: 1.2850, Val Acc: 65.23%

Evaluating fixed weight model...


Evaluating: 100%|██████████| 31/31 [00:01<00:00, 18.41it/s]


                           precision    recall  f1-score   support

                   ADONIS     0.9167    0.8462    0.8800        13
AFRICAN GIANT SWALLOWTAIL     0.8750    0.6364    0.7368        11
           AMERICAN SNOOT     0.6667    0.5455    0.6000        11
                    AN 88     0.9286    1.0000    0.9630        13
                  APPOLLO     0.5455    0.4615    0.5000        13
                    ATALA     0.6087    0.9333    0.7368        15
 BANDED ORANGE HELICONIAN     0.6667    0.8000    0.7273        15
           BANDED PEACOCK     0.8182    0.6923    0.7500        13
            BECKERS WHITE     0.2609    0.5000    0.3429        12
         BLACK HAIRSTREAK     0.8889    0.6154    0.7273        13
              BLUE MORPHO     0.6250    0.9091    0.7407        11
        BLUE SPOTTED CROW     0.5294    0.6923    0.6000        13
           BROWN SIPROETA     0.7000    0.4667    0.5600        15
            CABBAGE WHITE     0.6923    0.6429    0.6667     

Epoch 1/15 [hybrid]: 100%|██████████| 143/143 [00:11<00:00, 12.31it/s, loss=4.14, acc=3.65]


Epoch 1: Train Loss: 4.1391, Train Acc: 3.65%, Val Loss: 3.6789, Val Acc: 9.23%


Epoch 2/15 [hybrid]: 100%|██████████| 143/143 [00:12<00:00, 11.89it/s, loss=3.43, acc=13.6]


Epoch 2: Train Loss: 3.4296, Train Acc: 13.59%, Val Loss: 2.8313, Val Acc: 25.03%


Epoch 3/15 [hybrid]: 100%|██████████| 143/143 [00:11<00:00, 12.29it/s, loss=2.83, acc=25.3]


Epoch 3: Train Loss: 2.8317, Train Acc: 25.32%, Val Loss: 2.3664, Val Acc: 37.13%


Epoch 4/15 [hybrid]: 100%|██████████| 143/143 [00:12<00:00, 11.88it/s, loss=2.41, acc=34.1]


Epoch 4: Train Loss: 2.4059, Train Acc: 34.05%, Val Loss: 2.0845, Val Acc: 42.87%


Epoch 5/15 [hybrid]: 100%|██████████| 143/143 [00:11<00:00, 12.28it/s, loss=2.08, acc=42]  


Epoch 5: Train Loss: 2.0810, Train Acc: 42.01%, Val Loss: 1.9233, Val Acc: 47.49%


Epoch 6/15 [hybrid]: 100%|██████████| 143/143 [00:11<00:00, 12.21it/s, loss=1.88, acc=47]  


Epoch 6: Train Loss: 1.8815, Train Acc: 46.96%, Val Loss: 1.6376, Val Acc: 55.79%


Epoch 7/15 [hybrid]: 100%|██████████| 143/143 [00:11<00:00, 12.24it/s, loss=1.7, acc=51.9] 


Epoch 7: Train Loss: 1.7043, Train Acc: 51.95%, Val Loss: 1.5580, Val Acc: 56.72%


Epoch 8/15 [hybrid]: 100%|██████████| 143/143 [00:11<00:00, 12.44it/s, loss=1.57, acc=55]  


Epoch 8: Train Loss: 1.5674, Train Acc: 55.00%, Val Loss: 1.3976, Val Acc: 60.51%


Epoch 9/15 [hybrid]: 100%|██████████| 143/143 [00:11<00:00, 12.39it/s, loss=1.49, acc=56.9] 


Epoch 9: Train Loss: 1.4877, Train Acc: 56.91%, Val Loss: 1.3666, Val Acc: 61.44%


Epoch 10/15 [hybrid]: 100%|██████████| 143/143 [00:11<00:00, 12.44it/s, loss=1.38, acc=60.5] 


Epoch 10: Train Loss: 1.3769, Train Acc: 60.45%, Val Loss: 1.3472, Val Acc: 62.26%


Epoch 11/15 [hybrid]: 100%|██████████| 143/143 [00:11<00:00, 12.27it/s, loss=1.31, acc=61.6] 


Epoch 11: Train Loss: 1.3137, Train Acc: 61.57%, Val Loss: 1.2836, Val Acc: 66.15%


Epoch 12/15 [hybrid]: 100%|██████████| 143/143 [00:11<00:00, 12.30it/s, loss=1.22, acc=64]   


Epoch 12: Train Loss: 1.2248, Train Acc: 64.01%, Val Loss: 1.2325, Val Acc: 67.08%


Epoch 13/15 [hybrid]: 100%|██████████| 143/143 [00:11<00:00, 12.48it/s, loss=1.16, acc=65.7] 


Epoch 13: Train Loss: 1.1559, Train Acc: 65.71%, Val Loss: 1.1953, Val Acc: 66.87%


Epoch 14/15 [hybrid]: 100%|██████████| 143/143 [00:11<00:00, 12.25it/s, loss=1.09, acc=67.3] 


Epoch 14: Train Loss: 1.0905, Train Acc: 67.31%, Val Loss: 1.1376, Val Acc: 69.54%


Epoch 15/15 [hybrid]: 100%|██████████| 143/143 [00:11<00:00, 12.34it/s, loss=1.04, acc=68.8] 


Epoch 15: Train Loss: 1.0382, Train Acc: 68.83%, Val Loss: 1.1758, Val Acc: 67.79%

Evaluating hybrid model...


Evaluating: 100%|██████████| 31/31 [00:01<00:00, 19.04it/s]


                           precision    recall  f1-score   support

                   ADONIS     0.7500    0.9231    0.8276        13
AFRICAN GIANT SWALLOWTAIL     0.7778    0.6364    0.7000        11
           AMERICAN SNOOT     0.7500    0.8182    0.7826        11
                    AN 88     0.9286    1.0000    0.9630        13
                  APPOLLO     0.3438    0.8462    0.4889        13
                    ATALA     0.9286    0.8667    0.8966        15
 BANDED ORANGE HELICONIAN     0.6111    0.7333    0.6667        15
           BANDED PEACOCK     0.8333    0.7692    0.8000        13
            BECKERS WHITE     0.0000    0.0000    0.0000        12
         BLACK HAIRSTREAK     0.7692    0.7692    0.7692        13
              BLUE MORPHO     0.6667    0.5455    0.6000        11
        BLUE SPOTTED CROW     0.8571    0.4615    0.6000        13
           BROWN SIPROETA     0.8000    0.8000    0.8000        15
            CABBAGE WHITE     0.6111    0.7857    0.6875     