In [None]:
!pip install onnxruntime

In [None]:
import base64
import io
import json
import math
import os
import pickle
import time
from datetime import datetime
from typing import Tuple

import numpy as np
import onnxruntime as ort
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

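# Data collection (reference only — not run as part of this notebook)

The `SafeModelStealingAttack` class below was used separately to query the API and save the collected image/representation pairs to a pickle file for later use. It is shown here only as a standalone reference class; the training cells below load that pickle file instead of re-querying the API.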

In [1]:
class SafeModelStealingAttack:
    """Client for the model-stealing API that checkpoints its progress to disk.

    Session state (seed, port, number of successful queries) is written to
    ``self.session_file`` after every change so an interrupted run can resume
    via :meth:`load_session`. Collected image/representation pairs are
    checkpointed to pickle files by :meth:`collect_training_data_safe`.

    Args:
        token: API token sent in the ``token`` header of every request.
        max_queries: Cap on successful ``/query`` calls made by this client.
    """

    def __init__(self, token, max_queries=100000):
        self.token = token
        self.max_queries = max_queries
        self.queries_used = 0
        self.seed = None
        self.port = None
        self.stolen_data = []  # NOTE(review): not read by any method in this class
        self.session_file = "attack_session.json"

    def save_session(self):
        """Write seed, port, query count and a timestamp to ``self.session_file`` as JSON."""
        session_data = {
            'seed': self.seed,
            'port': self.port,
            'queries_used': self.queries_used,
            'timestamp': datetime.now().isoformat()
        }
        with open(self.session_file, 'w') as f:
            json.dump(session_data, f, indent=2)
        print(f"Session saved: {session_data}")

    def load_session(self):
        """Restore seed, port and query count from ``self.session_file``.

        Returns:
            True if the file was read, False if it does not exist.
        """
        try:
            with open(self.session_file, 'r') as f:
                session_data = json.load(f)
        except FileNotFoundError:
            print("No previous session found")
            return False
        self.seed = session_data.get('seed')
        self.port = session_data.get('port')
        self.queries_used = session_data.get('queries_used', 0)
        print(f"Session loaded: {session_data}")
        return True

    def request_api(self):
        """Request a fresh (seed, port) pair from the launcher endpoint.

        On success the query counter is reset to 0 and the session is saved.

        Returns:
            True on success; False if the server answered with a ``detail``
            error or the HTTP request failed.
        """
        print("Requesting new API...")
        try:
            response = requests.get("http://34.122.51.94:9090/stealing_launch",
                                    headers={"token": self.token}, timeout=10)
            answer = response.json()
            print(f"API Response: {answer}")

            if 'detail' in answer:
                print(f"Error in API request: {answer['detail']}")
                return False

            self.seed = str(answer['seed'])
            self.port = str(answer['port'])
            self.queries_used = 0  # Reset query counter for the new API instance
            print(f"Obtained seed: {self.seed}, port: {self.port}")

            # Persist immediately so seed/port survive a crash.
            self.save_session()
            return True

        except requests.exceptions.RequestException as e:
            print(f"Failed to request API: {e}")
            return False

    def test_api_connection(self):
        """Best-effort check that the query port still answers.

        Sends a bare GET (no images) and treats any status other than 404 as
        reachable.

        Returns:
            False if no port is known, the port answers 404, or the request fails.
        """
        if not self.port:
            return False

        try:
            test_url = f"http://34.122.51.94:{self.port}/query"
            response = requests.get(test_url, timeout=5)
            return response.status_code != 404  # 404 means port doesn't exist
        except requests.exceptions.RequestException:
            # Network failures mean "not reachable"; anything else (e.g.
            # KeyboardInterrupt) should propagate rather than be swallowed.
            return False

    @staticmethod
    def _encode_png_base64(img):
        """Serialize one image (any object with PIL-style ``save(fp, format=...)``) to base64 PNG text."""
        buffer = io.BytesIO()
        img.save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode('utf-8')

    def query_api_safe(self, images, max_retries=3):
        """Send one batch of images to ``/query`` and return the representations.

        Retries up to ``max_retries`` times: HTTP 429 waits 60 s, other HTTP
        errors wait 5 s, connection errors wait 10 s. The session is saved after
        each successful query.

        Args:
            images: Sequence of images with a PIL-style ``save`` method.
            max_retries: Maximum number of HTTP attempts.

        Returns:
            The ``representations`` list from the server response.

        Raises:
            Exception: if the query budget is exhausted, no API session is
                active, or every attempt failed (including all-429 runs, which
                previously fell through and returned ``None``).
        """
        if self.queries_used >= self.max_queries:
            raise Exception("Maximum queries exceeded!")

        if not self.port:
            raise Exception("No active API session! Need to request API first.")

        url = f"http://34.122.51.94:{self.port}" + "/query"
        # The payload is identical for every attempt, so encode it only once.
        payload = json.dumps([self._encode_png_base64(img) for img in images])

        for attempt in range(max_retries):
            is_last_attempt = attempt == max_retries - 1
            try:
                print(f"Attempting query {self.queries_used + 1} (attempt {attempt + 1}/{max_retries})")
                response = requests.get(url, files={"file": payload},
                                        headers={"token": self.token}, timeout=30)

                if response.status_code == 200:
                    representations = response.json()["representations"]
                    self.queries_used += 1
                    print(f"✓ Query {self.queries_used} successful. Got {len(representations)} representations.")

                    # Save session after each successful query
                    self.save_session()
                    return representations

                if response.status_code == 429:  # Too many requests
                    if not is_last_attempt:
                        print("Rate limited. Waiting 60 seconds...")
                        time.sleep(60)
                    else:
                        print("Rate limited on final attempt.")
                    continue

                print(f"Query failed. Status: {response.status_code}, Response: {response.text}")
                if is_last_attempt:
                    raise Exception(f"Query failed after {max_retries} attempts")
                time.sleep(5)  # Wait before retry

            except requests.exceptions.ConnectionError as e:
                print(f"Connection error on attempt {attempt + 1}: {e}")
                if is_last_attempt:
                    raise Exception(f"Connection failed after {max_retries} attempts. API may have expired.")
                time.sleep(10)  # Wait longer for connection issues

            except Exception as e:
                print(f"Unexpected error on attempt {attempt + 1}: {e}")
                if is_last_attempt:
                    raise
                time.sleep(5)

        # Only reachable when every attempt was rate limited (or max_retries <= 0).
        raise Exception(f"Query failed after {max_retries} attempts")

    def collect_training_data_safe(self, dataset, num_queries=90, batch_size=1000):
        """Run ``num_queries`` API queries and accumulate image/representation pairs.

        The first half of the queries sample ``batch_size`` random indices
        without replacement; the second half sweep the dataset sequentially,
        wrapping modulo the dataset size. Progress is checkpointed to
        ``partial_stolen_data.pickle`` every 5 queries and on error, and an
        existing partial file is used to resume.

        Args:
            dataset: Object exposing ``imgs`` (indexable images) and ``__len__``.
            num_queries: Total number of queries to run.
            batch_size: Images per query.

        Returns:
            ``(all_images, all_representations)``; also saved to
            ``stolen_data_final.pickle``.

        Raises:
            Re-raises any error from a query after saving progress.
        """
        print(f"Collecting training data with {num_queries} queries...")

        # Try to load existing partial data.
        # NOTE: pickle.load can execute arbitrary code; only resume from files you wrote.
        try:
            with open('partial_stolen_data.pickle', 'rb') as f:
                partial_data = pickle.load(f)
            all_images = partial_data['images']
            all_representations = partial_data['representations']
            # Assumes every completed query returned exactly batch_size items.
            start_query = len(all_representations) // batch_size
            print(f"Resuming from query {start_query + 1}")
        except FileNotFoundError:
            all_images = []
            all_representations = []
            start_query = 0
            print("Starting fresh data collection")

        dataset_size = len(dataset)

        for query_idx in range(start_query, num_queries):
            try:
                print(f"\n=== Executing query {query_idx + 1}/{num_queries} ===")

                if query_idx < num_queries // 2:
                    # Random sampling
                    indices = np.random.choice(dataset_size, batch_size, replace=False)
                else:
                    # Sequential sweep continuing where the previous query stopped,
                    # wrapping around the end of the dataset. Identical to a plain
                    # arange while in range, but no longer re-sends indices
                    # 0..batch_size-1 once the sweep has passed the end.
                    start_idx = (query_idx - num_queries // 2) * batch_size
                    indices = (start_idx + np.arange(batch_size)) % dataset_size

                query_images = [dataset.imgs[idx] for idx in indices]
                representations = self.query_api_safe(query_images)

                all_images.extend(query_images)
                all_representations.extend(representations)

                # Save partial results every 5 queries
                if (query_idx + 1) % 5 == 0:
                    self.save_partial_data(all_images, all_representations, query_idx + 1)

                # Small delay to be nice to the server
                time.sleep(1)

            except Exception as e:
                print(f"Error during query {query_idx + 1}: {e}")
                print("Saving current progress and stopping...")
                self.save_partial_data(all_images, all_representations, query_idx)
                raise

        # Save final results
        self.save_stolen_data(all_images, all_representations, "final")

        print(f"\n✓ Successfully collected {len(all_representations)} image-representation pairs")
        return all_images, all_representations

    def save_partial_data(self, images, representations, query_num):
        """Overwrite ``partial_stolen_data.pickle`` with the data collected so far."""
        filename = "partial_stolen_data.pickle"
        with open(filename, 'wb') as f:
            pickle.dump({
                'images': images,
                'representations': representations,
                'query_num': query_num,
                'timestamp': datetime.now().isoformat()
            }, f)
        print(f"Saved partial data: {len(representations)} samples after query {query_num}")

    def save_stolen_data(self, images, representations, suffix=""):
        """Pickle the collected data to ``stolen_data_{suffix}.pickle``."""
        filename = f"stolen_data_{suffix}.pickle"
        with open(filename, 'wb') as f:
            pickle.dump({
                'images': images,
                'representations': representations,
                'timestamp': datetime.now().isoformat()
            }, f)
        print(f"Saved stolen data to {filename}")

    def load_stolen_data(self, filename="stolen_data_final.pickle"):
        """Return ``(images, representations)`` from a pickle written by :meth:`save_stolen_data`.

        NOTE: pickle.load can execute arbitrary code; only load files you produced.
        """
        with open(filename, 'rb') as f:
            data = pickle.load(f)
        return data['images'], data['representations']

In [None]:
class TaskDataset(Dataset):
    """In-memory dataset of ``(id, image, label)`` triples.

    ``ids``, ``imgs`` and ``labels`` are parallel lists that start empty and
    are filled in by the caller. When ``transform`` is set it is applied to
    the image on every access; ids and labels are returned untouched.
    """

    def __init__(self, transform=None):
        self.ids = []
        self.imgs = []
        self.labels = []
        self.transform = transform

    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int]:
        sample_id = self.ids[index]
        image = self.imgs[index]
        image = image if self.transform is None else self.transform(image)
        return sample_id, image, self.labels[index]

    def __len__(self):
        # Length is defined by the id list.
        return len(self.ids)

class ImprovedStolenEncoder(nn.Module):
    """Residual CNN that maps images to ``output_dim``-dimensional embeddings.

    Layout: 3x3 stem conv (64 channels) -> four stride-2 residual stages
    (128, 256, 512, 512 channels) -> global average pool over H and W ->
    MLP head (512 -> 2048 -> 1024 -> ``output_dim``) with BatchNorm, ReLU and
    dropout.

    Args:
        input_dim: Number of input image channels.
        output_dim: Size of the produced embedding.
    """

    def __init__(self, input_dim=3, output_dim=1024):
        super(ImprovedStolenEncoder, self).__init__()

        # Stem: keeps spatial size, lifts channels to 64.
        self.initial_conv = nn.Sequential(
            nn.Conv2d(input_dim, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Each stage halves H and W (e.g. 32 -> 16 -> 8 -> 4 -> 2).
        self.res_block1 = self._make_residual_block(64, 128, stride=2)
        self.res_block2 = self._make_residual_block(128, 256, stride=2)
        self.res_block3 = self._make_residual_block(256, 512, stride=2)
        self.res_block4 = self._make_residual_block(512, 512, stride=2)

        # Global pooling is a spatial mean in forward(). Unlike a fixed
        # AvgPool2d(kernel_size=2), it reduces to 1x1 for any input resolution.
        # It also exports to ONNX as ReduceMean. It has no parameters, so
        # existing checkpoints load unchanged.

        # Projection head with BatchNorm and dropout.
        self.projection_head = nn.Sequential(
            nn.Linear(512, 2048, bias=False),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(2048, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(1024, output_dim)
        )

        self._initialize_weights()

    def _make_residual_block(self, in_channels, out_channels, stride=1):
        """Build a two-conv residual stage with a 1x1 projection shortcut when shapes change."""
        main_path = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        # Projection shortcut whenever the spatial size or channel count changes.
        if stride != 1 or in_channels != out_channels:
            skip_connection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            skip_connection = nn.Identity()

        return ResidualBlock(main_path, skip_connection)

    def _initialize_weights(self):
        """Apply He (Kaiming) init to conv/linear weights; reset BatchNorm to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map an (N, input_dim, H, W) batch to (N, output_dim) embeddings."""
        x = self.initial_conv(x)

        x = self.res_block1(x)
        x = self.res_block2(x)
        x = self.res_block3(x)
        x = self.res_block4(x)

        # Global average pool over H and W -> (N, 512)
        x = x.mean(dim=(2, 3))

        return self.projection_head(x)

class ResidualBlock(nn.Module):
    """Combine two caller-supplied branches as ``relu(main_path(x) + skip_connection(x))``.

    Kept as a small explicit module (rather than inline code) for ONNX export.
    """

    def __init__(self, main_path, skip_connection):
        super().__init__()
        self.main_path = main_path
        self.skip_connection = skip_connection

    def forward(self, x):
        shortcut = self.skip_connection(x)
        residual = self.main_path(x)
        # The sum is a fresh tensor, so the in-place ReLU cannot clobber x.
        return F.relu(residual + shortcut, inplace=True)

class B4BRobustTrainer:
    def __init__(self, model, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.model = model.to(device)
        self.device = device
        print(f"Using device: {device}")
        
    def create_dataloader(self, images, representations, batch_size=64, test_split=0.2):
        """Create dataloaders with improved data preprocessing"""
        print("Creating dataloaders with enhanced preprocessing...")
        
        # Determine image type
        sample_img = images[0]
        if hasattr(sample_img, 'mode'):
            is_grayscale = sample_img.mode == 'L'
        else:
            if isinstance(sample_img, torch.Tensor):
                is_grayscale = sample_img.shape[0] == 1
            else:
                from PIL import Image
                if isinstance(sample_img, Image.Image):
                    is_grayscale = sample_img.mode == 'L'
                else:
                    is_grayscale = len(sample_img.shape) == 2 or sample_img.shape[-1] == 1
        
        print(f"Images are {'grayscale' if is_grayscale else 'RGB'}")
        
        # Simple approach - always ensure 3 channels
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
            transforms.ToTensor(),
        ])
        
        val_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        indices = list(range(len(images)))
        train_indices, val_indices = train_test_split(indices, test_size=test_split, random_state=42)

        train_image_tensors = []
        val_image_tensors = []
        
        # Process training images
        for i in train_indices:
            img = images[i]
            # Convert tensor to PIL if necessary
            if isinstance(img, torch.Tensor):
                if img.dim() == 3 and img.shape[0] in [1, 3]:
                    img = transforms.ToPILImage()(img)
                elif img.dim() == 2:
                    img = transforms.ToPILImage()(img.unsqueeze(0))
            
            img_tensor = train_transform(img)
            # Ensure 3 channels
            if img_tensor.shape[0] == 1:
                img_tensor = img_tensor.repeat(3, 1, 1)
            # Apply normalization after ensuring 3 channels
            img_tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img_tensor)
            train_image_tensors.append(img_tensor)
            
        # Process validation images
        for i in val_indices:
            img = images[i]
            # Convert tensor to PIL if necessary
            if isinstance(img, torch.Tensor):
                if img.dim() == 3 and img.shape[0] in [1, 3]:
                    img = transforms.ToPILImage()(img)
                elif img.dim() == 2:
                    img = transforms.ToPILImage()(img.unsqueeze(0))
            
            img_tensor = val_transform(img)
            # Ensure 3 channels
            if img_tensor.shape[0] == 1:
                img_tensor = img_tensor.repeat(3, 1, 1)
            # Apply normalization after ensuring 3 channels
            img_tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img_tensor)
            val_image_tensors.append(img_tensor)
        
        # Stack tensors and split representations
        X_train = torch.stack(train_image_tensors)
        X_val = torch.stack(val_image_tensors)
        y_train = torch.tensor([representations[i] for i in train_indices], dtype=torch.float32)
        y_val = torch.tensor([representations[i] for i in val_indices], dtype=torch.float32)
        
        print(f"Data shapes: Train Images {X_train.shape}, Train Representations {y_train.shape}")
        print(f"             Val Images {X_val.shape}, Val Representations {y_val.shape}")
        
        # Create datasets
        train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
        val_dataset = torch.utils.data.TensorDataset(X_val, y_val)
        
        # Create dataloaders
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                                num_workers=4, pin_memory=True, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, 
                              num_workers=4, pin_memory=True)
        
        print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")
        return train_loader, val_loader
    
    def advanced_loss_function(self, predictions, targets, epoch=0):
        """Advanced loss function specifically designed for B4B defense"""
        
        # Standard MSE loss
        mse_loss = F.mse_loss(predictions, targets)
        
        # Cosine similarity loss (robust to scale differences)
        cosine_loss = 1 - F.cosine_similarity(predictions, targets).mean()
        
        # Huber loss for robustness to outliers (B4B noise)
        huber_loss = F.smooth_l1_loss(predictions, targets, beta=0.1)
        
        # L1 loss for sparsity
        l1_loss = F.l1_loss(predictions, targets)
        
        # Correlation loss to maintain relationship structure
        pred_centered = predictions - predictions.mean(dim=1, keepdim=True)
        target_centered = targets - targets.mean(dim=1, keepdim=True)
        correlation = (pred_centered * target_centered).sum(dim=1) / (
            torch.sqrt((pred_centered ** 2).sum(dim=1)) * torch.sqrt((target_centered ** 2).sum(dim=1)) + 1e-8
        )
        correlation_loss = 1 - correlation.mean()
        
        # Adaptive weighting based on training progress
        epoch_weight = min(epoch / 50.0, 1.0)  # Gradually increase correlation weight
        
        total_loss = (0.4 * mse_loss + 
                     0.2 * cosine_loss + 
                     0.2 * huber_loss + 
                     0.1 * l1_loss + 
                     0.1 * epoch_weight * correlation_loss)
        
        return total_loss, {
            'mse': mse_loss.item(),
            'cosine': cosine_loss.item(),
            'huber': huber_loss.item(),
            'l1': l1_loss.item(),
            'correlation': correlation_loss.item()
        }
    
    def train(self, train_loader, val_loader, epochs=150, lr=0.001):
        """Enhanced training with curriculum learning and advanced optimization"""
        
        # Advanced optimizer with better hyperparameters
        optimizer = optim.AdamW(
            self.model.parameters(), 
            lr=lr, 
            weight_decay=1e-4,
            betas=(0.9, 0.999),
            eps=1e-8
        )
        
        # More sophisticated learning rate scheduling
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=20, T_mult=2, eta_min=1e-6
        )
        
        # Enhanced early stopping
        best_val_loss = float('inf')
        best_cosine_sim = -1.0
        patience_counter = 0
        max_patience = 10
        min_improvement = 1e-5
        
        # Tracking
        train_losses = []
        val_losses = []
        cosine_similarities = []
        loss_components = {'mse': [], 'cosine': [], 'huber': [], 'l1': [], 'correlation': []}
        
        print("Starting enhanced training...")
        print(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        
        for epoch in range(epochs):
            # Training phase with curriculum learning
            self.model.train()
            train_loss = 0.0
            epoch_loss_components = {k: 0.0 for k in loss_components.keys()}
            
            for batch_idx, (images, targets) in enumerate(train_loader):
                images = images.to(self.device, non_blocking=True)
                targets = targets.to(self.device, non_blocking=True)
                
                optimizer.zero_grad()
                outputs = self.model(images)
                
                # Advanced loss with epoch-dependent weighting
                loss, components = self.advanced_loss_function(outputs, targets, epoch)
                loss.backward()
                
                # Gradient clipping with adaptive norm
                max_norm = 1.0 if epoch < 30 else 0.5
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=max_norm)
                
                optimizer.step()
                train_loss += loss.item()
                
                # Track loss components
                for k, v in components.items():
                    epoch_loss_components[k] += v
                
                if batch_idx % 50 == 0:
                    print(f"Epoch {epoch+1}, Batch {batch_idx}: Loss = {loss.item():.4f}")
            
            # Update learning rate
            scheduler.step()
            
            # Validation phase with comprehensive metrics
            self.model.eval()
            val_loss = 0.0
            total_cosine_sim = 0.0
            total_l2_dist = 0.0
            num_batches = 0
            
            with torch.no_grad():
                for images, targets in val_loader:
                    images = images.to(self.device, non_blocking=True)
                    targets = targets.to(self.device, non_blocking=True)
                    
                    outputs = self.model(images)
                    loss, _ = self.advanced_loss_function(outputs, targets, epoch)
                    val_loss += loss.item()
                    
                    # Multiple similarity metrics
                    cosine_sim = F.cosine_similarity(outputs, targets).mean()
                    l2_dist = torch.norm(outputs - targets, dim=1).mean()
                    
                    total_cosine_sim += cosine_sim.item()
                    total_l2_dist += l2_dist.item()
                    num_batches += 1
            
            # Average metrics
            train_loss /= len(train_loader)
            val_loss /= len(val_loader)
            avg_cosine_sim = total_cosine_sim / num_batches
            avg_l2_dist = total_l2_dist / num_batches
            
            # Average loss components
            for k in epoch_loss_components:
                epoch_loss_components[k] /= len(train_loader)
                loss_components[k].append(epoch_loss_components[k])
            
            # Store history
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            cosine_similarities.append(avg_cosine_sim)
            
            print(f"Epoch {epoch+1}/{epochs}:")
            print(f"  Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            print(f"  Cosine Similarity: {avg_cosine_sim:.4f}, L2 Distance: {avg_l2_dist:.4f}")
            print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
            print(f"  Loss Components - MSE: {epoch_loss_components['mse']:.4f}, "
                  f"Cosine: {epoch_loss_components['cosine']:.4f}, "
                  f"Huber: {epoch_loss_components['huber']:.4f}")
            
            # Enhanced early stopping with multiple criteria
            improvement = best_val_loss - val_loss
            cosine_improvement = avg_cosine_sim > best_cosine_sim
            
            if (val_loss < best_val_loss and improvement > min_improvement) or cosine_improvement:
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                if avg_cosine_sim > best_cosine_sim:
                    best_cosine_sim = avg_cosine_sim
                    
                patience_counter = 0
                
                # Save best model with comprehensive state
                torch.save({
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'epoch': epoch,
                    'val_loss': val_loss,
                    'cosine_sim': avg_cosine_sim,
                    'l2_dist': avg_l2_dist,
                    'train_history': {
                        'train_losses': train_losses,
                        'val_losses': val_losses,
                        'cosine_similarities': cosine_similarities,
                        'loss_components': loss_components
                    }
                }, 'best_stolen_model.pth')
                
                print(f"  ✓ New best model saved (val_loss: {val_loss:.4f}, cosine: {avg_cosine_sim:.4f})")
                
            else:
                patience_counter += 1
                print(f"  → No significant improvement (patience: {patience_counter}/{max_patience})")
                
                if patience_counter >= max_patience:
                    print(f"\n🛑 Early stopping triggered at epoch {epoch+1}")
                    break
                    
                # Reduce learning rate if stuck
                if patience_counter % 10 == 0:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= 0.5
                    print(f"  📉 Reduced learning rate to {optimizer.param_groups[0]['lr']:.2e}")
            
            print("-" * 60)
        
        # Load best model
        checkpoint = torch.load('best_stolen_model.pth')
        self.model.load_state_dict(checkpoint['model_state_dict'])
        
        print("\n✅ Enhanced training completed!")
        print(f"📊 Final Results:")
        print(f"   Best validation loss: {best_val_loss:.4f}")
        print(f"   Best cosine similarity: {best_cosine_sim:.4f}")
        print(f"   Total epochs: {len(train_losses)}")
        
        return checkpoint['train_history']

In [None]:
def load_stolen_data(filename="/kaggle/input/tml-assignment2-data/stolen_data_final.pickle"):
    """Return the ``(images, representations)`` pair stored in a stolen-data pickle.

    NOTE: ``pickle.load`` can execute arbitrary code; only load files you
    produced yourself.
    """
    with open(filename, 'rb') as handle:
        payload = pickle.load(handle)
    images, representations = payload['images'], payload['representations']
    return images, representations

def load_session():
    """Read ``attack_session.json`` from the current working directory.

    Returns:
        The parsed JSON content, or ``None`` (after printing a notice) when
        the file does not exist.
    """
    session_path = 'attack_session.json'
    try:
        handle = open(session_path, 'r')
    except FileNotFoundError:
        print("No session file found")
        return None
    with handle:
        return json.load(handle)

def export_model_to_onnx(checkpoint_path='best_stolen_model.pth', onnx_path='stolen_model.onnx'):
    """Export the best training checkpoint to ONNX and validate the exported graph.

    Loads ``model_state_dict`` from ``checkpoint_path`` into an
    ``ImprovedStolenEncoder`` on the CPU. Exports it with a dynamic batch axis,
    then runs the ONNX model in onnxruntime on a random 2-image batch and
    checks that:

    * the output has shape ``(batch, 1024)``, and
    * it numerically matches the PyTorch model's output on the same input.

    Args:
        checkpoint_path: Path of the ``torch.save`` checkpoint to export.
        onnx_path: Destination path of the ONNX file.

    Returns:
        True if export and validation succeed, False otherwise (the reason is printed).
    """
    print("=== Exporting Enhanced Model to ONNX ===")

    try:
        # Load straight to CPU: the export runs on CPU anyway.
        # weights_only=False unpickles arbitrary objects; only load
        # checkpoints you produced yourself.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        print("✓ Loaded model checkpoint")
    except FileNotFoundError:
        print(f"❌ Error: '{checkpoint_path}' not found!")
        return False

    stolen_model = ImprovedStolenEncoder(input_dim=3, output_dim=1024)
    stolen_model.load_state_dict(checkpoint['model_state_dict'])
    stolen_model.eval()

    print("✓ Model loaded on CPU for ONNX export")

    dummy_input = torch.randn(1, 3, 32, 32)
    print(f"✓ Created dummy input with shape: {dummy_input.shape}")

    try:
        torch.onnx.export(
            stolen_model,
            dummy_input,
            onnx_path,
            export_params=True,
            opset_version=11,
            input_names=["x"],
            output_names=["output"],
            dynamic_axes={'x': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
            verbose=False,
            do_constant_folding=True,
            training=torch.onnx.TrainingMode.EVAL
        )
        print("✓ ONNX export successful!")

    except Exception as e:
        print(f"❌ ONNX export failed: {e}")
        return False

    print("Validating ONNX model...")
    try:
        # Explicit provider list: GPU builds of onnxruntime refuse to create a
        # session without one.
        ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

        # A batch of 2 also exercises the dynamic batch axis declared above.
        test_input = np.random.randn(2, 3, 32, 32).astype(np.float32)
        ort_batch = ort_session.run(None, {ort_session.get_inputs()[0].name: test_input})[0]

        expected_shape = (test_input.shape[0], 1024)
        if ort_batch.shape != expected_shape:
            print(f"❌ Invalid output shape: {ort_batch.shape}, expected: {expected_shape}")
            return False

        # The exported graph must reproduce the PyTorch model's outputs.
        with torch.no_grad():
            torch_batch = stolen_model(torch.from_numpy(test_input)).numpy()
        max_abs_diff = float(np.max(np.abs(ort_batch - torch_batch)))
        if not np.allclose(ort_batch, torch_batch, rtol=1e-3, atol=1e-4):
            print(f"❌ ONNX output differs from PyTorch output (max abs diff: {max_abs_diff:.2e})")
            return False

        output = ort_batch[0]
        print("✓ ONNX model validation successful!")
        print(f"  Input shape: {test_input.shape}")
        print(f"  Output shape: {output.shape}")
        print(f"  Output range: [{output.min():.4f}, {output.max():.4f}]")
        print(f"  Max |ONNX - PyTorch|: {max_abs_diff:.2e}")

        return True

    except Exception as e:
        print(f"❌ ONNX model validation failed: {e}")
        return False

def submit_for_evaluation(onnx_path='stolen_model.onnx',
                          session_path='/kaggle/input/tml-assignment2-data/attack_session.json'):
    """Upload the exported ONNX model to the evaluation server.

    Reads the API ``seed`` from the session JSON at ``session_path`` and POSTs
    ``onnx_path`` to the ``/stealing`` endpoint. On HTTP 200 it records the
    server's response in ``submission_result.json``. The API token is read
    from the ``TML_TOKEN`` environment variable when set.

    Args:
        onnx_path: ONNX file to upload.
        session_path: JSON file containing at least a ``seed`` entry.

    Returns:
        True on a 200 response; False on any failure (the reason is printed).
    """
    print("\n=== Submitting Enhanced Model for Evaluation ===")

    # NOTE(review): credentials should not live in source. Set TML_TOKEN in
    # the environment; the literal fallback only keeps existing runs working
    # and should be removed.
    TOKEN = os.environ.get("TML_TOKEN", "96005201")

    try:
        with open(session_path, 'r') as f:
            session_data = json.load(f)
    except FileNotFoundError:
        print("❌ No session file found!")
        return False
    except json.JSONDecodeError as e:
        print(f"❌ Session file is not valid JSON: {e}")
        return False

    seed = session_data.get('seed') if isinstance(session_data, dict) else None
    if not seed:
        print("❌ No seed found in session data!")
        return False
    # requests only accepts str/bytes header values; the seed may be stored as a number.
    seed = str(seed)

    print(f"✓ Using seed: {seed}")

    try:
        print("Submitting enhanced model to evaluation server...")
        with open(onnx_path, 'rb') as f:
            response = requests.post(
                "http://34.122.51.94:9090/stealing",
                files={"file": f},
                headers={"token": TOKEN, "seed": seed},
                timeout=90
            )

        print(f"Response Status Code: {response.status_code}")

        if response.status_code == 200:
            result = response.json()
            print("✅ Submission successful!")
            print(f"📊 Evaluation Result: {result}")

            submission_info = {
                'submission_time': datetime.now().isoformat(),
                'seed': seed,
                'model_path': onnx_path,
                'response': result,
                'status': 'success',
                'model_type': 'ImprovedStolenEncoder'
            }

            with open('submission_result.json', 'w') as f:
                json.dump(submission_info, f, indent=2)

            return True

        print(f"❌ Submission failed with status {response.status_code}")
        print(f"Response: {response.text}")
        return False

    except Exception as e:
        print(f"❌ Submission failed: {e}")
        return False

In [None]:
def main():
    """Export the trained model to ONNX, then submit it; stop at the first failing stage."""
    print("=== Enhanced Model Stealing Pipeline ===")

    stages = (
        (export_model_to_onnx, "❌ Enhanced model export failed."),
        (submit_for_evaluation, "❌ Submission failed."),
    )
    for run_stage, failure_message in stages:
        if not run_stage():
            print(failure_message)
            return

    print("\n🎉 Enhanced pipeline completed successfully!")
    print("📁 Files created:")
    for artifact_line in ("  - stolen_model.onnx (enhanced model)",
                          "  - submission_result.json (evaluation results)"):
        print(artifact_line)

def train_enhanced_model():
    """Load the stolen pairs, fit an ImprovedStolenEncoder on them and return the training history."""
    print("=== Training Enhanced Model ===")

    images, representations = load_stolen_data()
    print(f"Loaded {len(images)} images and {len(representations)} representations")

    trainer = B4BRobustTrainer(ImprovedStolenEncoder(input_dim=3, output_dim=1024))

    # 80/20 split, batches of 64
    train_loader, val_loader = trainer.create_dataloader(
        images, representations, batch_size=64, test_split=0.2
    )
    history = trainer.train(train_loader, val_loader, epochs=150, lr=0.001)

    print("✅ Enhanced training completed!")
    return history

In [None]:
if __name__ == "__main__":
    # Stage 1: fit the surrogate encoder on the stolen data
    # (saves the best checkpoint to best_stolen_model.pth).
    print("Starting enhanced model training...")
    train_enhanced_model()

    # Stage 2: export the best checkpoint to ONNX and submit it.
    main()