# In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from typing import Optional, Tuple, List, Dict, Union
import sys
import numpy as np
from tqdm import tqdm
import logging
import copy
import traceback
from dataclasses import dataclass
from scipy.optimize import linear_sum_assignment
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import logging

# Set up device and logging
# Prefer the GPU when one is visible to this process, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

label = "flow_matching2"
# basicConfig opens the log file right away and fails with FileNotFoundError
# when the target directory does not exist, so make sure it is there first.
os.makedirs("Outputs", exist_ok=True)
logging.basicConfig(filename=f'Outputs/{label}.log',
                    level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    filemode='a')

class LoggerWriter(object):
    """Minimal file-like object that forwards written text to ``logging``.

    Each non-blank ``write`` becomes one record on the root logger at the
    level given at construction time.
    """

    def __init__(self, level):
        # Logging level (e.g. logging.INFO) used for every forwarded message.
        self.level = level

    def write(self, message):
        """Log ``message`` with surrounding whitespace removed; skip blank text."""
        text = message.strip()
        if text == "":
            # print() emits bare newlines as separate writes; ignore those.
            return
        logging.log(self.level, text)

    def flush(self):
        """Do nothing; present only so the object can stand in for a stream."""
        return None

# Replace the process-wide stdout and stderr with LoggerWriter instances so
# that anything written to either stream is sent through the logging module
# instead of the console. Both streams are recorded at INFO level.
sys.stdout = LoggerWriter(logging.INFO)
sys.stderr = LoggerWriter(logging.INFO)

# ======================== UTILITY CLASSES ========================
class Bunch:
    """Attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

def count_parameters(model):
    """Return the total number of elements over parameters with requires_grad set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# ======================== WEIGHT SPACE OBJECTS ========================
class WeightSpaceObject:
    """Base class for weight space objects (MLPs).

    Holds the weight matrices and bias vectors of an MLP as two tuples of
    tensors and converts between that structured form and a single flat
    vector laid out as ``[all weights..., all biases...]``.
    """

    def __init__(self, weights, biases):
        # Store as tuples so the parameter collections cannot be resized by accident.
        self.weights = weights if isinstance(weights, tuple) else tuple(weights)
        self.biases = biases if isinstance(biases, tuple) else tuple(biases)

    def flatten(self, device=None):
        """Flatten weights and biases into a single vector.

        Args:
            device: Optional device to move the result to. ``None`` leaves the
                vector where ``torch.cat`` produced it.

        Returns:
            1-D tensor containing every weight (in order) followed by every bias.
        """
        flat = torch.cat([w.flatten() for w in self.weights] +
                         [b.flatten() for b in self.biases])
        # Compare against None explicitly: a bare truth test would skip the
        # move for the integer device index 0.
        if device is not None:
            flat = flat.to(device)
        return flat

    @classmethod
    def from_flat(cls, flat, layers, device):
        """Create an instance from a flattened vector.

        Args:
            flat: 1-D tensor in the layout produced by ``flatten``.
            layers: Layer widths ``[in, hidden..., out]``; layer ``i`` gets a
                weight of shape ``(layers[i+1], layers[i])`` and a bias of
                length ``layers[i+1]``.
            device: Device the reconstructed tensors are moved to.

        Returns:
            A new ``cls`` instance on ``device``.

        Raises:
            ValueError: If ``flat`` holds fewer elements than ``layers``
                requires. Extra trailing elements are ignored.
        """
        weight_shapes = [(layers[i + 1], layers[i]) for i in range(len(layers) - 1)]
        bias_sizes = list(layers[1:])
        sizes = [rows * cols for rows, cols in weight_shapes] + bias_sizes

        expected = sum(sizes)
        if flat.numel() < expected:
            # Without this check a short vector would silently yield
            # truncated trailing bias vectors.
            raise ValueError(
                f"flat vector has {flat.numel()} elements but layers {list(layers)} "
                f"require {expected}"
            )

        # Cut the flat vector into consecutive chunks, weights first.
        parts = []
        start = 0
        for size in sizes:
            parts.append(flat[start:start + size])
            start += size

        num_weights = len(weight_shapes)
        weights = [part.reshape(shape)
                   for part, shape in zip(parts[:num_weights], weight_shapes)]
        biases = parts[num_weights:]

        return cls(weights, biases).to(device)

    def to(self, device):
        """Return a copy of this object with all tensors moved to ``device``.

        The copy has the same class as ``self`` so subclasses survive the move.
        """
        weights = tuple(w.to(device) for w in self.weights)
        biases = tuple(b.to(device) for b in self.biases)
        return type(self)(weights, biases)

@dataclass
class AttentionWeights:
    """Record holding the parameters of one multi-head attention layer."""
    qkv_weight: torch.Tensor  # fused query/key/value projection matrix
    qkv_bias: Optional[torch.Tensor]  # bias of the fused projection, if any
    proj_weight: torch.Tensor  # output projection matrix
    proj_bias: Optional[torch.Tensor]  # bias of the output projection, if any
    num_heads: int  # number of attention heads

    def split_heads(self):
        """Return per-head query, key and value weights.

        The fused matrix is viewed as ``[3, num_heads, head_dim, d_model]``,
        where ``d_model`` is its second dimension, and split along the first
        axis. Each returned tensor has shape ``[num_heads, head_dim, d_model]``.
        """
        embed_dim = self.qkv_weight.shape[1]
        per_head = embed_dim // self.num_heads

        stacked = self.qkv_weight.reshape(3, self.num_heads, per_head, embed_dim)
        q_weights, k_weights, v_weights = stacked.unbind(dim=0)

        return q_weights, k_weights, v_weights

@dataclass
class TransformerBlockWeights:
    """Container for transformer block weights

    Plain record grouping the tensors of a single transformer block: the
    attention parameters, two layer-norm weight/bias pairs and the MLP
    weights and biases.
    """
    attention: AttentionWeights  # attention projection parameters of the block
    norm1_weight: torch.Tensor  # scale of the block's `norm1` layer norm
    norm1_bias: torch.Tensor  # shift of the block's `norm1` layer norm
    mlp_weights: Tuple[torch.Tensor, ...]  # MLP layer weights
    mlp_biases: Tuple[torch.Tensor, ...]  # MLP layer biases, in layer order
    norm2_weight: torch.Tensor  # scale of the block's `norm2` layer norm
    norm2_bias: torch.Tensor  # shift of the block's `norm2` layer norm

class VisionTransformerWeightSpace:
    """Weight space object for Vision Transformers.

    Stores a copy of every learnable tensor of a ViT (patch embedding, class
    token, positional embedding, transformer blocks, final norm and head) and
    converts between a model, this structured form and one flat vector.

    The flat layout is defined in exactly one place, ``_iter_named_params``,
    so ``flatten`` and ``from_flat`` cannot drift apart.
    """

    def __init__(self,
                 patch_embed_weight: torch.Tensor,
                 patch_embed_bias: Optional[torch.Tensor],
                 cls_token: torch.Tensor,
                 pos_embed: torch.Tensor,
                 blocks: List[TransformerBlockWeights],
                 norm_weight: torch.Tensor,
                 norm_bias: torch.Tensor,
                 head_weight: torch.Tensor,
                 head_bias: torch.Tensor):

        self.patch_embed_weight = patch_embed_weight
        self.patch_embed_bias = patch_embed_bias
        self.cls_token = cls_token
        self.pos_embed = pos_embed
        self.blocks = blocks
        self.norm_weight = norm_weight
        self.norm_bias = norm_bias
        self.head_weight = head_weight
        self.head_bias = head_bias

    @classmethod
    def from_vit_model(cls, model: nn.Module):
        """Extract (cloned) weights from a ViT model.

        The model is expected to expose ``patch_embed.proj``, ``cls_token``,
        ``pos_embed``, ``blocks`` (each with ``attn.qkv``, ``attn.proj``,
        ``attn.num_heads``, ``norm1``, ``norm2`` and ``mlp``), ``norm`` and
        ``head``.
        """
        blocks = []

        for block in model.blocks:
            attn = block.attn
            attention_weights = AttentionWeights(
                qkv_weight=attn.qkv.weight.data.clone(),
                qkv_bias=attn.qkv.bias.data.clone() if attn.qkv.bias is not None else None,
                proj_weight=attn.proj.weight.data.clone(),
                proj_bias=attn.proj.bias.data.clone() if attn.proj.bias is not None else None,
                num_heads=attn.num_heads
            )

            # Collect the MLP parameters from the direct children that carry
            # a weight (activation/dropout modules are skipped).
            mlp_weights = []
            mlp_biases = []
            for layer in block.mlp.children():
                if hasattr(layer, 'weight'):
                    mlp_weights.append(layer.weight.data.clone())
                    if hasattr(layer, 'bias') and layer.bias is not None:
                        mlp_biases.append(layer.bias.data.clone())

            blocks.append(TransformerBlockWeights(
                attention=attention_weights,
                norm1_weight=block.norm1.weight.data.clone(),
                norm1_bias=block.norm1.bias.data.clone(),
                mlp_weights=tuple(mlp_weights),
                mlp_biases=tuple(mlp_biases),
                norm2_weight=block.norm2.weight.data.clone(),
                norm2_bias=block.norm2.bias.data.clone()
            ))

        patch_proj = model.patch_embed.proj
        return cls(
            patch_embed_weight=patch_proj.weight.data.clone(),
            patch_embed_bias=patch_proj.bias.data.clone()
                            if patch_proj.bias is not None else None,
            cls_token=model.cls_token.data.clone(),
            pos_embed=model.pos_embed.data.clone(),
            blocks=blocks,
            norm_weight=model.norm.weight.data.clone(),
            norm_bias=model.norm.bias.data.clone(),
            head_weight=model.head.weight.data.clone(),
            head_bias=model.head.bias.data.clone()
        )

    def apply_to_model(self, model: nn.Module):
        """Copy the stored weights into ``model`` in place.

        ``model`` must have the same structure as the one the weights were
        extracted from (see ``from_vit_model``).
        """
        with torch.no_grad():
            # Patch embedding
            model.patch_embed.proj.weight.data.copy_(self.patch_embed_weight)
            if self.patch_embed_bias is not None:
                model.patch_embed.proj.bias.data.copy_(self.patch_embed_bias)

            # Tokens and embeddings
            model.cls_token.data.copy_(self.cls_token)
            model.pos_embed.data.copy_(self.pos_embed)

            # Transformer blocks
            for block, block_weights in zip(model.blocks, self.blocks):
                attn = block.attn
                stored_attn = block_weights.attention
                attn.qkv.weight.data.copy_(stored_attn.qkv_weight)
                if stored_attn.qkv_bias is not None:
                    attn.qkv.bias.data.copy_(stored_attn.qkv_bias)
                attn.proj.weight.data.copy_(stored_attn.proj_weight)
                if stored_attn.proj_bias is not None:
                    attn.proj.bias.data.copy_(stored_attn.proj_bias)

                # Layer norms
                block.norm1.weight.data.copy_(block_weights.norm1_weight)
                block.norm1.bias.data.copy_(block_weights.norm1_bias)
                block.norm2.weight.data.copy_(block_weights.norm2_weight)
                block.norm2.bias.data.copy_(block_weights.norm2_bias)

                # MLP: walk the same layers, in the same order, that
                # from_vit_model collected the parameters from.
                mlp_layers = [layer for layer in block.mlp.children()
                              if hasattr(layer, 'weight')]
                for layer, weight in zip(mlp_layers, block_weights.mlp_weights):
                    layer.weight.data.copy_(weight)

                # Biases were stored only for layers that have one, so pair
                # the stored biases with exactly those layers.
                biased_layers = [layer for layer in mlp_layers
                                 if getattr(layer, 'bias', None) is not None]
                for layer, bias in zip(biased_layers, block_weights.mlp_biases):
                    layer.bias.data.copy_(bias)

            # Final norm and head
            model.norm.weight.data.copy_(self.norm_weight)
            model.norm.bias.data.copy_(self.norm_bias)
            model.head.weight.data.copy_(self.head_weight)
            model.head.bias.data.copy_(self.head_bias)

    @staticmethod
    def _iter_named_params(ws):
        """Yield ``(name, tensor)`` for every tensor of ``ws`` in flat-layout order.

        Optional biases that are ``None`` are skipped. This generator is the
        single definition of the order used by ``flatten`` and ``from_flat``.
        """
        yield 'patch_embed_weight', ws.patch_embed_weight
        if ws.patch_embed_bias is not None:
            yield 'patch_embed_bias', ws.patch_embed_bias

        yield 'cls_token', ws.cls_token
        yield 'pos_embed', ws.pos_embed

        for block_idx, block in enumerate(ws.blocks):
            prefix = f'block_{block_idx}_'
            attn = block.attention

            yield prefix + 'attn_qkv_weight', attn.qkv_weight
            if attn.qkv_bias is not None:
                yield prefix + 'attn_qkv_bias', attn.qkv_bias
            yield prefix + 'attn_proj_weight', attn.proj_weight
            if attn.proj_bias is not None:
                yield prefix + 'attn_proj_bias', attn.proj_bias

            yield prefix + 'norm1_weight', block.norm1_weight
            yield prefix + 'norm1_bias', block.norm1_bias
            yield prefix + 'norm2_weight', block.norm2_weight
            yield prefix + 'norm2_bias', block.norm2_bias

            for mlp_idx, mlp_weight in enumerate(block.mlp_weights):
                yield f'{prefix}mlp_weight_{mlp_idx}', mlp_weight
            for mlp_idx, mlp_bias in enumerate(block.mlp_biases):
                yield f'{prefix}mlp_bias_{mlp_idx}', mlp_bias

        yield 'norm_weight', ws.norm_weight
        yield 'norm_bias', ws.norm_bias
        yield 'head_weight', ws.head_weight
        yield 'head_bias', ws.head_bias

    def flatten(self, device=None) -> torch.Tensor:
        """Flatten all weights into a single vector.

        Args:
            device: Optional device to move the result to.

        Returns:
            1-D tensor with every stored tensor concatenated in the order
            given by ``_iter_named_params``.
        """
        flat = torch.cat([tensor.flatten()
                          for _, tensor in self._iter_named_params(self)])
        # Explicit None check so that the integer device index 0 is honoured.
        if device is not None:
            flat = flat.to(device)
        return flat

    @classmethod
    def from_flat(cls, flat_tensor, reference_ws, device=None):
        """Reconstruct a weight space object from flattened weights.

        Args:
            flat_tensor: 1-D tensor in the layout produced by ``flatten``.
            reference_ws: Object with the target structure; only its tensor
                shapes, block count and ``num_heads`` values are used.
            device: Device for the reconstructed tensors; defaults to the
                device of ``flat_tensor``.

        Returns:
            A new ``cls`` instance whose tensors have the shapes of
            ``reference_ws`` and the values of ``flat_tensor``.

        Raises:
            ValueError: If ``flat_tensor`` has fewer elements than the
                reference structure needs. Extra trailing elements are ignored.
        """
        if device is None:
            device = flat_tensor.device

        layout = [(name, ref.shape)
                  for name, ref in cls._iter_named_params(reference_ws)]
        # int(...) because np.prod returns a float for an empty (0-d) shape.
        sizes = [int(np.prod(shape)) for _, shape in layout]

        required = sum(sizes)
        if flat_tensor.numel() < required:
            raise ValueError(
                f"flat tensor has {flat_tensor.numel()} elements but the reference "
                f"structure requires {required}"
            )

        # Cut the flat vector into consecutive chunks and restore each shape.
        params = {}
        start = 0
        for (name, shape), size in zip(layout, sizes):
            params[name] = flat_tensor[start:start + size].reshape(shape).to(device)
            start += size

        blocks = []
        for block_idx, ref_block in enumerate(reference_ws.blocks):
            prefix = f'block_{block_idx}_'

            attention = AttentionWeights(
                qkv_weight=params[prefix + 'attn_qkv_weight'],
                qkv_bias=params.get(prefix + 'attn_qkv_bias'),
                proj_weight=params[prefix + 'attn_proj_weight'],
                proj_bias=params.get(prefix + 'attn_proj_bias'),
                num_heads=ref_block.attention.num_heads
            )

            mlp_weights = tuple(params[f'{prefix}mlp_weight_{idx}']
                                for idx in range(len(ref_block.mlp_weights)))
            mlp_biases = tuple(params[f'{prefix}mlp_bias_{idx}']
                               for idx in range(len(ref_block.mlp_biases)))

            blocks.append(TransformerBlockWeights(
                attention=attention,
                norm1_weight=params[prefix + 'norm1_weight'],
                norm1_bias=params[prefix + 'norm1_bias'],
                mlp_weights=mlp_weights,
                mlp_biases=mlp_biases,
                norm2_weight=params[prefix + 'norm2_weight'],
                norm2_bias=params[prefix + 'norm2_bias']
            ))

        return cls(
            patch_embed_weight=params['patch_embed_weight'],
            patch_embed_bias=params.get('patch_embed_bias'),
            cls_token=params['cls_token'],
            pos_embed=params['pos_embed'],
            blocks=blocks,
            norm_weight=params['norm_weight'],
            norm_bias=params['norm_bias'],
            head_weight=params['head_weight'],
            head_bias=params['head_bias']
        )

# ======================== REBASIN / TRANSFUSION ========================
class PermutationSpec:
    """Per-block table of the permutations applied throughout the network."""

    def __init__(self, num_blocks: int):
        self.num_blocks = num_blocks
        slot_names = ('attention_in', 'attention_out', 'mlp1', 'mlp2')
        # One independent dict per block, every slot starting out unset (None).
        self.block_perms = [dict.fromkeys(slot_names) for _ in range(num_blocks)]

    def set_block_perm(self, block_idx: int, perm_type: str, perm: torch.Tensor):
        """Store ``perm`` under ``perm_type`` for block ``block_idx``.

        Indices at or beyond the number of blocks are ignored.
        """
        if block_idx >= len(self.block_perms):
            return
        self.block_perms[block_idx][perm_type] = perm

class TransFusionMatcher:
    """Weight matching using TransFusion approach.

    Provides a singular-value based distance between weight matrices,
    Hungarian matching of MLP hidden units, and a canonicalization routine
    that aligns the first MLP layer of every block to a reference model.
    """

    def __init__(self, num_iterations: int = 3, epsilon: float = 1e-8):
        # Stored for callers; not read by the methods of this class.
        self.num_iterations = num_iterations
        self.epsilon = epsilon

    def compute_spectral_distance(self, weight1: torch.Tensor, weight2: torch.Tensor) -> float:
        """Compute permutation-invariant distance using singular values.

        Returns the Euclidean norm of the difference between the two
        (zero-padded) singular value vectors. If both the torch and the numpy
        decomposition fail, falls back to the Frobenius norm of
        ``weight1 - weight2``.
        """
        try:
            s1 = torch.linalg.svdvals(weight1.float())
            s2 = torch.linalg.svdvals(weight2.float())
        except Exception:
            try:
                s1 = torch.tensor(
                    np.linalg.svd(weight1.detach().cpu().numpy(), compute_uv=False),
                    device=weight1.device)
                s2 = torch.tensor(
                    np.linalg.svd(weight2.detach().cpu().numpy(), compute_uv=False),
                    device=weight2.device)
            except Exception:
                # Fallback to Frobenius norm
                return torch.norm(weight1 - weight2).item()

        # Pad to same length if necessary; new_zeros keeps dtype and device
        # of the vector being padded so torch.cat never mixes types.
        max_len = max(len(s1), len(s2))
        if len(s1) < max_len:
            s1 = torch.cat([s1, s1.new_zeros(max_len - len(s1))])
        if len(s2) < max_len:
            s2 = torch.cat([s2, s2.new_zeros(max_len - len(s2))])

        return torch.norm(s1 - s2).item()

    def match_attention_heads(self, attn1: AttentionWeights, attn2: AttentionWeights):
        """Match attention heads between two attention layers.

        Builds the head-to-head spectral distance matrix and solves the
        assignment problem on it.

        NOTE(review): the assignment is not yet encoded in the result. As in
        the previous implementation, the returned permutation is always the
        identity; a real head permutation would have to reorder the Q, K, V
        and output-projection slices consistently.

        Returns:
            ``(None, None, perm)`` where ``perm`` is a ``d_model x d_model``
            identity matrix on the device of ``attn1.qkv_weight``.
        """
        d_model = attn1.qkv_weight.shape[1]
        identity = torch.eye(d_model, device=attn1.qkv_weight.device)

        try:
            q1, k1, v1 = attn1.split_heads()
            q2, k2, v2 = attn2.split_heads()

            num_heads = attn1.num_heads

            # Inter-head alignment using spectral distance
            distance_matrix = torch.zeros(num_heads, num_heads)
            for i in range(num_heads):
                for j in range(num_heads):
                    distance_matrix[i, j] = (
                        self.compute_spectral_distance(q1[i], q2[j])
                        + self.compute_spectral_distance(k1[i], k2[j])
                        + self.compute_spectral_distance(v1[i], v2[j])
                    )

            # Solve assignment problem (result currently unused, see NOTE).
            linear_sum_assignment(distance_matrix.cpu().numpy())
        except Exception as e:
            logging.warning(f"Error in attention matching: {e}")

        return None, None, identity

    def match_mlp_layer(self, weight1: torch.Tensor, weight2: torch.Tensor,
                       prev_perm: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Match MLP layers using Hungarian algorithm.

        Args:
            weight1: Weight matrix whose rows (units) are to be reordered.
            weight2: Reference weight matrix.
            prev_perm: Optional permutation of the input dimension, applied
                to ``weight1`` when its size matches ``weight1.shape[1]``.

        Returns:
            Matrix ``perm`` of size ``weight1.shape[0]`` squared with
            ``perm[r, c] == 1`` when row ``r`` of ``weight2`` is matched to
            row ``c`` of ``weight1``. On any failure the identity is returned.
        """
        try:
            # Apply previous permutation if exists
            if prev_perm is not None and prev_perm.shape[0] == weight1.shape[1]:
                weight1_permuted = torch.mm(weight1, prev_perm.t())
            else:
                weight1_permuted = weight1

            # Negated similarity: the solver minimizes, we want the most
            # similar rows paired.
            cost_matrix = -torch.mm(weight2, weight1_permuted.t())

            # Solve assignment problem
            row_ind, col_ind = linear_sum_assignment(cost_matrix.detach().cpu().numpy())

            # Create permutation matrix
            n = weight1.shape[0]
            perm = torch.zeros(n, n, device=weight1.device)
            perm[row_ind, col_ind] = 1.0

            return perm

        except Exception as e:
            logging.warning(f"Error in MLP matching: {e}")
            return torch.eye(weight1.shape[0], device=weight1.device)

    def _align_first_mlp_layer(self, block, reference_block) -> None:
        """Reorder the hidden units of ``block``'s first MLP layer in place.

        The rows of the first weight (and its bias) and the columns of the
        second weight are permuted together, which leaves the MLP's function
        unchanged as long as the activation between them is element-wise.
        Nothing is changed when the shapes do not allow a safe permutation.
        """
        weights = block.mlp_weights
        if len(weights) < 1:
            return

        first = weights[0]
        perm = self.match_mlp_layer(first, reference_block.mlp_weights[0])

        if len(weights) < 2:
            # Without a following layer to absorb the reordering, permuting
            # the units would change the block's output.
            return
        second = weights[1]
        hidden = first.shape[0]
        if reference_block.mlp_weights[0].shape != first.shape or second.shape[1] != hidden:
            return

        biases = block.mlp_biases
        if biases and (len(biases) != len(weights) or biases[0].shape[0] != hidden):
            # Cannot tell which stored bias belongs to the first layer.
            return

        # perm[r, c] == 1 means: new unit r is old unit c.
        order = perm.argmax(dim=1)
        expected = torch.arange(hidden, device=order.device)
        if not torch.equal(torch.sort(order).values, expected):
            # Not a full permutation (e.g. a partial assignment); skip.
            return

        block.mlp_weights = (first[order], second[:, order]) + tuple(weights[2:])
        if biases:
            block.mlp_biases = (biases[0][order],) + tuple(biases[1:])

    def canonicalize_model(self, models: List[VisionTransformerWeightSpace],
                          reference_idx: int = 0) -> List[VisionTransformerWeightSpace]:
        """Canonicalize multiple models using one as reference.

        Every non-reference model is deep-copied and, block by block, the
        hidden units of its first MLP layer are permuted to best match the
        reference. Attention weights are left untouched because
        ``match_attention_heads`` currently yields the identity permutation.

        Args:
            models: Weight space objects to canonicalize.
            reference_idx: Index of the model used as reference.

        Returns:
            List parallel to ``models``. The reference entry is the original
            object; a model whose canonicalization fails is returned
            unchanged and a warning is logged.
        """
        reference = models[reference_idx]
        canonicalized = []

        for i, model in enumerate(models):
            if i == reference_idx:
                canonicalized.append(reference)
                continue

            try:
                current_model = copy.deepcopy(model)

                if len(current_model.blocks) > len(reference.blocks):
                    raise ValueError(
                        f"model has {len(current_model.blocks)} blocks but the "
                        f"reference has {len(reference.blocks)}"
                    )

                for current_block, reference_block in zip(current_model.blocks,
                                                          reference.blocks):
                    self._align_first_mlp_layer(current_block, reference_block)

                canonicalized.append(current_model)

            except Exception as e:
                # Best effort: keep the original weights, but do not hide
                # the failure.
                logging.warning(f"Canonicalization failed for model {i}: {e}")
                canonicalized.append(model)

        return canonicalized

# ======================== VIT MODEL DEFINITION ========================
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Args:
        dim: Embedding size of the input and output tokens.
        num_heads: Number of attention heads; must divide ``dim``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        # Fail at construction time instead of with an opaque reshape error
        # on the first forward pass.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1/sqrt(head_dim) score scaling

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape ``(B, N, C)``; output has the same shape."""
        B, N, C = x.shape

        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)

        return x

class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    ``hidden_features`` defaults to four times ``in_features`` and
    ``out_features`` defaults to ``in_features``.
    """

    def __init__(self, in_features: int, hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None, dropout: float = 0.1):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features * 4

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feed-forward network on ``x``."""
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual connection."""

    def __init__(self, dim: int, num_heads: int = 8, mlp_ratio: float = 4.0,
                 dropout: float = 0.1):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, hidden_dim, dropout=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the attention output, then the MLP output, to the running stream."""
        attn_out = self.attn(self.norm1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out

class PatchEmbed(nn.Module):
    """Turn an image into a sequence of patch embeddings with a strided convolution."""

    def __init__(self, img_size: int = 32, patch_size: int = 4,
                 in_chans: int = 3, embed_dim: int = 192):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2

        # Kernel size equal to stride: each patch is projected independently.
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and return tokens with the channel axis last."""
        feature_map = self.proj(x)
        tokens = feature_map.flatten(start_dim=2)
        return tokens.transpose(1, 2)

class VisionTransformer(nn.Module):
    """Compact Vision Transformer classifier (defaults sized for CIFAR-10).

    Images are split into patches, a learnable class token is prepended,
    positional embeddings are added, and the class token's final
    representation is fed to a linear head.
    """

    def __init__(self,
                 img_size: int = 32,
                 patch_size: int = 4,
                 in_chans: int = 3,
                 num_classes: int = 10,
                 embed_dim: int = 512,
                 depth: int = 8,
                 num_heads: int = 8,
                 mlp_ratio: float = 4.0,
                 dropout: float = 0.1):
        super().__init__()

        self.num_classes = num_classes
        self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        # One extra position for the class token.
        seq_len = self.patch_embed.num_patches + 1

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)

        layers = []
        for _ in range(depth):
            layers.append(TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout))
        self.blocks = nn.ModuleList(layers)

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init for embeddings and linear weights; zero biases; unit layer-norm scale."""
        for embedding in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(embedding, std=0.02)

        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.LayerNorm):
                nn.init.constant_(module.bias, 0)
                nn.init.constant_(module.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images."""
        batch = x.shape[0]

        patches = self.patch_embed(x)

        # Broadcast the single learned class token over the batch.
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, patches), dim=1)

        tokens = self.pos_drop(tokens + self.pos_embed)

        for block in self.blocks:
            tokens = block(tokens)

        tokens = self.norm(tokens)
        # Only the class token (position 0) feeds the classifier.
        return self.head(tokens[:, 0])

def create_vit_small(num_classes: int = 10, **kwargs) -> VisionTransformer:
    """Build a small ViT for CIFAR-10-sized inputs.

    Any keyword argument overrides the corresponding built-in default before
    the model is constructed.
    """
    config = dict(
        img_size=32,
        patch_size=4,
        embed_dim=256,
        depth=4,
        num_heads=4,
        mlp_ratio=4.0,
        num_classes=num_classes,
    )

    # Caller-supplied values win over the defaults above.
    return VisionTransformer(**{**config, **kwargs})

# ======================== DATA LOADING ========================
def load_cifar10(batch_size=128):
    """Return ``(train_loader, test_loader)`` for CIFAR-10.

    The training set uses random crop (padding 4) and horizontal flip; both
    sets are converted to tensors and normalized with the same channel
    statistics. Data is stored under ``./data`` and downloaded when missing.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    train_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std)
    ])
    test_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std)
    ])

    train_set = datasets.CIFAR10(root='./data', train=True, download=True,
                                 transform=train_pipeline)
    test_set = datasets.CIFAR10(root='./data', train=False, download=True,
                                transform=test_pipeline)

    # Only the training data is shuffled.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

    return train_loader, test_loader

def evaluate(model, test_loader, device):
    """Return the top-1 accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode and gradients are disabled while the
    loader is consumed.

    Args:
        model: Callable module producing per-class scores along dim 1.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predictions = model(inputs).argmax(dim=1)
            n_correct += (predictions == labels).sum().item()
            n_seen += labels.size(0)
    return n_correct / n_seen

class SimpleCFM:
    """Base Conditional Flow Matching class.

    Pairs samples drawn from a source loader and a target loader, trains
    ``model`` to predict either the straight-line velocity ``x1 - x0``
    (``mode="velocity"``) or the end point ``x1`` (any other mode), and
    integrates the learned field from t=0 to t=1 with :meth:`map`.
    """

    def __init__(self, sourceloader, targetloader, model, mode="velocity",
                 t_dist="uniform", device=None):
        """Store the loaders, the model and the sampling configuration.

        Args:
            sourceloader: Iterable of batches whose first element is x0.
            targetloader: Iterable of batches whose first element is x1.
            model: Callable ``model(xt, t)``. Must expose ``input_dim`` or a
                ``net`` sequence whose first layer has ``in_features``.
            mode: ``"velocity"`` regresses ``x1 - x0``; anything else
                regresses ``x1``.
            t_dist: ``"uniform"`` or ``"beta"`` (Beta(2, 5)) time sampling.
            device: Device for sampled tensors; defaults to CUDA if available.
        """
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.sourceloader = sourceloader
        self.targetloader = targetloader
        self.model = model
        self.mode = mode
        self.t_dist = t_dist
        self.sigma = 0.001
        self.best_loss = float('inf')
        self.best_model_state = None
        # A getattr() default is evaluated eagerly, which would crash on
        # models that expose input_dim but have no .net attribute.
        if hasattr(model, 'input_dim'):
            self.input_dim = model.input_dim
        else:
            self.input_dim = model.net[0].in_features
        # id(loader) -> (loader, iterator). Kept on this object so that no
        # attribute of the loader itself is touched and plain iterables work.
        self._loader_iters = {}

    def sample_from_loader(self, loader):
        """Return the first element of the next batch from ``loader``.

        One iterator per loader is cached on this object and restarted once
        it is exhausted. On any failure a warning is logged and a
        ``(1, input_dim)`` zero tensor on ``self.device`` is returned.
        """
        try:
            iterators = getattr(self, "_loader_iters", None)
            if iterators is None:
                iterators = self._loader_iters = {}
            key = id(loader)
            entry = iterators.get(key)
            if entry is None or entry[0] is not loader:
                entry = (loader, iter(loader))
                iterators[key] = entry
            try:
                batch = next(entry[1])
            except StopIteration:
                # Epoch finished: restart the loader and take its first batch.
                entry = (loader, iter(loader))
                iterators[key] = entry
                batch = next(entry[1])
            return batch[0]
        except Exception as e:
            logging.warning(f"Error sampling from loader: {e}")
            # Return zero tensor as fallback
            return torch.zeros(1, self.input_dim, device=self.device)

    def sample_time_and_flow(self):
        """Sample time, start and end points, and the intermediate ``x_t``.

        Returns:
            Bunch with ``t`` (batch, 1), ``x0``, ``xt``, ``x1``, the target
            velocity ``ut = x1 - x0`` and ``batch_size``.

        Raises:
            ValueError: If ``self.t_dist`` is not "uniform" or "beta".
        """
        x0 = self.sample_from_loader(self.sourceloader)
        x1 = self.sample_from_loader(self.targetloader)

        # The two loaders may yield different batch sizes; use the overlap.
        batch_size = min(x0.size(0), x1.size(0))
        x0 = x0[:batch_size].to(self.device)
        x1 = x1[:batch_size].to(self.device)

        if self.t_dist == "uniform":
            t = torch.rand(batch_size, device=self.device)
        elif self.t_dist == "beta":
            t = torch.distributions.Beta(2.0, 5.0).sample((batch_size,)).to(self.device)
        else:
            raise ValueError(
                f"Unknown t_dist {self.t_dist!r}; expected 'uniform' or 'beta'"
            )

        # Give t trailing singleton axes so it broadcasts against x0 / x1.
        t_pad = t.reshape(-1, *([1] * (x0.dim() - 1)))

        mu_t = (1 - t_pad) * x0 + t_pad * x1
        sigma_pad = torch.tensor(self.sigma, device=self.device)
        xt = mu_t + sigma_pad * torch.randn_like(x0)
        ut = x1 - x0

        return Bunch(t=t.unsqueeze(-1), x0=x0, xt=xt, x1=x1, ut=ut, batch_size=batch_size)

    def forward(self, flow):
        """Run the model on ``(flow.xt, flow.t)``.

        Returns:
            ``(None, prediction)``. If the model raises, a warning is logged
            and the prediction is a zero tensor shaped like ``flow.ut``.
        """
        try:
            flow_pred = self.model(flow.xt, flow.t)
            return None, flow_pred
        except Exception as e:
            logging.warning(f"Error in forward pass: {e}")
            return None, torch.zeros_like(flow.ut)

    def loss_fn(self, flow_pred, flow):
        """Mean squared error between the prediction and its regression target.

        The target is ``flow.ut`` in velocity mode and ``flow.x1`` otherwise.

        Returns:
            ``(None, loss)``.
        """
        target = flow.ut if self.mode == "velocity" else flow.x1
        pred = flow_pred
        if pred.shape != target.shape:
            if pred.numel() == target.numel():
                # Same elements, different layout (e.g. a trailing singleton
                # axis). Reshape explicitly: a bare squeeze() could turn
                # (B, 1) vs (B,) into a silently broadcast (B, B) difference.
                pred = pred.reshape(target.shape)
            else:
                pred = pred.squeeze()
        l_flow = torch.mean((pred - target) ** 2)
        return None, l_flow

    def _velocity(self, xt, t_tensor):
        """Evaluate the model at ``(xt, t)`` and convert its output to a velocity."""
        pred = self.model(xt, t_tensor)
        if pred.dim() > 2:
            pred = pred.squeeze(-1)
        return pred if self.mode == "velocity" else pred - xt

    def _rk4_step(self, xt, t_tensor, dt):
        """Advance ``xt`` by one classical Runge-Kutta 4 step of size ``dt``."""
        half_dt = 0.5 * dt
        t_mid = t_tensor + half_dt
        k1 = self._velocity(xt, t_tensor)
        k2 = self._velocity(xt + half_dt * k1, t_mid)
        k3 = self._velocity(xt + half_dt * k2, t_mid)
        k4 = self._velocity(xt + dt * k3, t_tensor + dt)
        return xt + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

    def _integrate_fixed(self, x0, n_steps, return_traj, method):
        """Integrate on a uniform time grid with Euler or RK4 steps."""
        batch_size = x0.size(0)
        times = torch.linspace(0, 1, n_steps).to(self.device)
        dt = times[1] - times[0]

        xt = x0.clone()
        traj = [x0.detach().clone()] if return_traj else None

        for i, t in enumerate(times[:-1]):
            t_tensor = torch.ones(batch_size, 1).to(self.device) * t

            if method == "euler":
                try:
                    xt = xt + self._velocity(xt, t_tensor) * dt
                except Exception as e:
                    # Best effort: stop here and return the state reached so far.
                    logging.warning(f"Error at step {i}: {e}")
                    break
            else:
                try:
                    xt = self._rk4_step(xt, t_tensor, dt)
                except Exception as e:
                    logging.warning(f"RK4 error at step {i}: {e}")
                    # Fallback to Euler step
                    xt = xt + self._velocity(xt, t_tensor) * dt

            if return_traj:
                traj.append(xt.detach().clone())

        return traj, xt

    def _integrate_adaptive(self, x0, n_steps, return_traj):
        """Integrate with SciPy's adaptive ``solve_ivp`` on the flattened state."""
        from scipy.integrate import solve_ivp

        batch_size = x0.size(0)

        def as_tensor(flat):
            # SciPy works in float64; convert back to the caller's dtype so
            # the result matches what the fixed-step solvers return.
            return torch.from_numpy(flat.reshape(batch_size, -1)).to(
                device=self.device, dtype=x0.dtype
            )

        def vector_field_numpy(t, x_flat):
            x_tensor = as_tensor(x_flat)
            t_tensor = torch.full((batch_size, 1), float(t), device=self.device)
            with torch.no_grad():
                vt = self._velocity(x_tensor, t_tensor)
            return vt.cpu().numpy().flatten()

        sol = solve_ivp(
            vector_field_numpy,
            [0, 1],
            x0.detach().cpu().numpy().flatten(),
            dense_output=True,
            rtol=1e-6
        )
        if not sol.success:
            logging.warning(f"Adaptive solver did not converge: {sol.message}")

        traj = None
        if return_traj:
            # Evaluate the dense solution at regular intervals.
            t_eval = np.linspace(0, 1, n_steps)
            traj = [as_tensor(tp) for tp in sol.sol(t_eval).T]

        return traj, as_tensor(sol.y[:, -1])

    def map(self, x0, n_steps=100, return_traj=False, method="rk4", adaptive=False):
        """Transport ``x0`` from t=0 to t=1 along the learned flow.

        When a best training snapshot exists it is loaded for the duration
        of the call. The model's previous weights and train/eval mode are
        restored afterwards, even if integration raises.

        Args:
            x0: 2-D tensor ``(batch, dim)`` of starting points.
            n_steps: Number of time-grid points. "euler" and "rk4" take
                ``n_steps - 1`` steps; for "adaptive" it only sets how many
                trajectory points are returned.
            return_traj: Return the list of intermediate states instead of
                the final state.
            method: "euler", "rk4" or "adaptive" (SciPy ``solve_ivp``).
            adaptive: Unused; kept for backward compatibility. Select the
                adaptive solver with ``method="adaptive"``.

        Returns:
            The final state tensor, or the trajectory list when
            ``return_traj`` is true.

        Raises:
            ValueError: For an unknown ``method`` or, with a fixed-step
                method, ``n_steps < 2``.
        """
        if method not in ("euler", "rk4", "adaptive"):
            raise ValueError(
                f"Unknown integration method {method!r}; "
                "expected 'euler', 'rk4' or 'adaptive'"
            )
        if method != "adaptive" and n_steps < 2:
            raise ValueError("n_steps must be at least 2 for fixed-step integration")

        x0 = x0.to(self.device)
        batch_size, _ = x0.size()  # also enforces a 2-D input

        was_training = getattr(self.model, "training", True)
        saved_state = None
        # Use best model state if available
        if self.best_model_state is not None:
            saved_state = {k: v.clone() for k, v in self.model.state_dict().items()}
            self.model.load_state_dict(self.best_model_state)

        self.model.eval()
        try:
            with torch.no_grad():
                if method == "adaptive":
                    traj, xt = self._integrate_adaptive(x0, n_steps, return_traj)
                else:
                    traj, xt = self._integrate_fixed(x0, n_steps, return_traj, method)
        finally:
            # Restore model state and mode
            if saved_state is not None:
                self.model.load_state_dict(saved_state)
            self.model.train(was_training)

        return traj if return_traj else xt

    def train(self, n_iters=10000, optimizer=None, sigma=0.001, log_freq=100):
        """Train the flow model.

        Args:
            n_iters: Number of optimization steps.
            optimizer: Optimizer over the model parameters (required).
            sigma: Noise scale added to the interpolant ``x_t``.
            log_freq: Progress-bar refresh interval in iterations; 0 or
                ``None`` disables the refresh.

        Raises:
            ValueError: If ``optimizer`` is ``None``.
        """
        if optimizer is None:
            # Fail once, up front, instead of logging an error every iteration.
            raise ValueError("train() requires an optimizer")

        self.sigma = sigma

        pbar = tqdm(range(n_iters), desc="Training CFM")
        for i in pbar:
            try:
                optimizer.zero_grad()

                flow = self.sample_time_and_flow()
                _, flow_pred = self.forward(flow)
                _, loss = self.loss_fn(flow_pred, flow)

                if not torch.isfinite(loss):
                    logging.warning(f"Invalid loss at step {i}: {loss.item()}")
                    continue

                loss.backward()

                loss_value = loss.item()
                if loss_value < self.best_loss:
                    # Snapshot before stepping so the stored weights are the
                    # ones that actually produced this loss.
                    self.best_loss = loss_value
                    self.best_model_state = {k: v.clone() for k, v in self.model.state_dict().items()}

                optimizer.step()

                if log_freq and i % log_freq == 0:
                    pbar.set_description(f"CFM Training [loss: {loss_value:.6f}]")

            except Exception as e:
                logging.error(f"Error during training iteration {i}: {e}")
                continue

class EnhancedWeightSpaceCFM(SimpleCFM):
    """Conditional flow matching variant used for weight-space data.

    Accepts the same constructor arguments as ``SimpleCFM`` and sets the
    interpolant noise scale ``sigma`` to 0.001 after base initialization.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        noise_scale = 0.001
        self.sigma = noise_scale

class VisionTransformerFlowModel(nn.Module):
    """Flow model for ViT weight spaces.

    A time-conditioned MLP: the time value is embedded by a small two-layer
    network, concatenated to the flattened weights, and passed through a
    bottleneck network whose output has the same width as the input.
    """

    def __init__(self, input_dim, time_embed_dim=64):
        """Build the time embedding and the main network.

        Args:
            input_dim: Width of the flattened weight vector.
            time_embed_dim: Width of the learned time embedding.
        """
        super().__init__()
        self.input_dim = input_dim
        self.time_embed_dim = time_embed_dim

        self.time_embed = nn.Sequential(
            nn.Linear(1, time_embed_dim),
            nn.GELU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )

        # Plain feed-forward bottleneck (there are no skip connections).
        # Width is capped at 256 and floored at 2 so that neither hidden_dim
        # nor hidden_dim // 2 collapses to a zero-width layer when input_dim
        # is very small.
        hidden_dim = max(2, min(256, input_dim // 4))

        self.net = nn.Sequential(
            nn.Linear(input_dim + time_embed_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),

            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(0.1),

            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),

            nn.Linear(hidden_dim, input_dim)
        )

        # Zero-initialize the output layer so the untrained model predicts 0.
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, x, t):
        """Predict the flow output for states ``x`` at times ``t``.

        Args:
            x: Tensor whose last dimension is ``input_dim``.
            t: Time tensor with a trailing feature axis of size 1. For a
                batched ``x``, a scalar or a tensor with fewer dimensions
                than ``x`` is broadcast to one value per sample.

        Returns:
            Tensor with the same leading shape as ``x`` and last dimension
            ``input_dim``.
        """
        if x.dim() >= 2 and t.dim() < x.dim():
            # Accept t of shape () or (B,) as well as the canonical (B, 1).
            t = t.expand(x.shape[:-1]).unsqueeze(-1)
        t_embed = self.time_embed(t)
        combined = torch.cat([x, t_embed], dim=-1)
        return self.net(combined)

def get_permuted_models_data(ref_point=0, model_dir="imagenet_vit_models", 
                           num_models=50, device=None):
    """Load ViT checkpoints and align each one to a reference model.

    The checkpoint ``vit_weights_{ref_point}.pt`` is the reference. Every
    other checkpoint index below ``num_models`` that exists on disk is
    loaded, converted to a weight-space object and canonicalized against
    the reference with a ``TransFusionMatcher``. Checkpoints that are
    missing or fail to process are logged and skipped.

    Args:
        ref_point: Index of the reference checkpoint.
        model_dir: Directory holding the ``vit_weights_{i}.pt`` files.
        num_models: Upper bound (exclusive) on checkpoint indices to try.
        device: Target device; defaults to CUDA when available.

    Returns:
        Tuple ``(ref_model, weight_space_objects)`` where the list starts
        with the reference weight-space object.

    Raises:
        Exception: Whatever loading the reference checkpoint raised.
    """
    # NOTE(review): torch.load unpickles the checkpoint files; only point
    # model_dir at trusted data.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    reference_model = create_vit_small()
    reference_path = f"{model_dir}/vit_weights_{ref_point}.pt"

    try:
        reference_state = torch.load(reference_path, map_location=device)
        reference_model.load_state_dict(reference_state)
        reference_model = reference_model.to(device)
        logging.info(f"Loaded reference model from {reference_path}")
    except Exception as e:
        logging.error(f"Failed to load reference model: {e}")
        raise e

    reference_ws = VisionTransformerWeightSpace.from_vit_model(reference_model)
    aligned_objects = [reference_ws]

    matcher = TransFusionMatcher(num_iterations=2)

    for idx in tqdm(range(num_models), desc="Processing ViT models"):
        if idx == ref_point:
            continue

        checkpoint_path = f"{model_dir}/vit_weights_{idx}.pt"
        if not os.path.exists(checkpoint_path):
            logging.warning(f"Skipping model {idx} - file not found")
            continue

        try:
            candidate = create_vit_small()
            candidate.load_state_dict(torch.load(checkpoint_path, map_location=device))
            candidate = candidate.to(device)

            candidate_ws = VisionTransformerWeightSpace.from_vit_model(candidate)

            # Index 0 is the reference, index 1 the newly aligned model.
            canonical = matcher.canonicalize_model([reference_ws, candidate_ws], reference_idx=0)
            aligned_objects.append(canonical[1])
        except Exception as e:
            logging.error(f"Error processing model {idx}: {e}")
            continue
        else:
            torch.cuda.empty_cache()

    logging.info(f"Successfully processed {len(aligned_objects)} ViT models")
    return reference_model, aligned_objects

def train_vit_flow_matching(vit_config=None, model_dir="imagenet_vit_models", 
                           num_models=50):
    """Train a flow-matching model on aligned, flattened ViT weights.

    Loads and aligns the checkpoints, flattens them into target vectors,
    draws an equally sized set of Gaussian source vectors scaled by the
    overall standard deviation of the targets, and trains an
    ``EnhancedWeightSpaceCFM`` in velocity mode for 30000 iterations.

    Args:
        vit_config: ViT configuration returned unchanged to the caller; a
            default configuration is used when ``None``.
        model_dir: Directory containing the ViT checkpoints.
        num_models: Number of checkpoint indices to consider.

    Returns:
        Tuple ``(cfm, reference_weight_space, vit_config, target_std)``.
    """
    if vit_config is None:
        vit_config = dict(
            num_classes=10,
            embed_dim=256,
            depth=4,
            num_heads=4,
            mlp_ratio=4.0,
            dropout=0.1,
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    _, weight_space_objects = get_permuted_models_data(
        model_dir=model_dir,
        num_models=num_models,
        device=device
    )

    logging.info("Converting ViT weights to flat tensors...")
    flat_target_weights = torch.stack(
        [ws.flatten(device=device) for ws in tqdm(weight_space_objects)]
    )
    flat_dim = flat_target_weights.shape[1]

    logging.info(f"ViT weight space dimension: {flat_dim:,}")

    target_std = torch.std(flat_target_weights).item()
    logging.info(f"Target std: {target_std:.6f}")

    # Source distribution: one Gaussian vector per target, matched in scale.
    num_samples = len(flat_target_weights)
    flat_source_weights = target_std * torch.randn(num_samples, flat_dim, device=device)

    source_dataset = TensorDataset(flat_source_weights)
    target_dataset = TensorDataset(flat_target_weights)

    loader_options = dict(batch_size=1, shuffle=True, drop_last=True)
    sourceloader = DataLoader(source_dataset, **loader_options)
    targetloader = DataLoader(target_dataset, **loader_options)

    # Flow network and its trainer.
    flow_model = VisionTransformerFlowModel(flat_dim).to(device)
    logging.info(f"Flow model parameters: {count_parameters(flow_model):,}")

    cfm = EnhancedWeightSpaceCFM(
        sourceloader=sourceloader,
        targetloader=targetloader,
        model=flow_model,
        mode="velocity",
        device=device
    )

    optimizer = torch.optim.AdamW(
        flow_model.parameters(),
        lr=1e-4,
        weight_decay=1e-5,
        betas=(0.9, 0.95)
    )

    cfm.train(n_iters=30000, optimizer=optimizer, sigma=0.001, log_freq=50)

    reference_ws = weight_space_objects[0]
    return cfm, reference_ws, vit_config, target_std

def generate_new_vit_models(cfm, reference_ws, vit_config, target_std, n_samples=5,
                            source_std=0.01, n_steps=100):
    """Generate new ViT models using trained flow matching.

    Gaussian source vectors are transported with ``cfm.map``, unflattened
    against ``reference_ws``, written into freshly created ViTs and scored
    on the CIFAR-10 test set. Samples that fail are logged and skipped.

    Args:
        cfm: Trained flow object providing ``device``, ``input_dim`` and
            ``map``.
        reference_ws: Weight-space object passed to
            ``VisionTransformerWeightSpace.from_flat`` as the reference.
        vit_config: Keyword arguments forwarded to ``create_vit_small``.
        target_std: Standard deviation of the training targets; used as the
            source noise scale when ``source_std`` is ``None``.
        n_samples: Number of models to generate.
        source_std: Scale of the Gaussian source noise. The default 0.01
            keeps the previously hard-coded value; pass ``None`` to use
            ``target_std`` instead.
        n_steps: Number of integration grid points passed to ``cfm.map``.

    Returns:
        List of the models that were built and evaluated successfully.
    """
    device = cfm.device
    flat_dim = cfm.input_dim

    logging.info(f"Generating {n_samples} new ViT models...")
    if n_samples <= 0:
        # Nothing to generate; avoid loading the dataset for no reason.
        return []

    # NOTE(review): train_vit_flow_matching draws its source noise with scale
    # target_std, so the 0.01 default samples from a narrower source than
    # training used. Confirm this is intentional, or pass source_std=None.
    noise_scale = target_std if source_std is None else source_std
    source_samples = torch.randn(n_samples, flat_dim, device=device) * noise_scale

    generated_flat = cfm.map(source_samples, n_steps=n_steps)

    generated_models = []
    test_loader = load_cifar10(batch_size=128)[1]

    for i in range(n_samples):
        try:
            generated_ws = VisionTransformerWeightSpace.from_flat(
                generated_flat[i], reference_ws, device
            )

            new_model = create_vit_small(**vit_config).to(device)
            generated_ws.apply_to_model(new_model)

            accuracy = evaluate(new_model, test_loader, device)
            logging.info(f"Generated ViT model {i}: {accuracy*100:.2f}% accuracy")

            generated_models.append(new_model)

        except Exception as e:
            logging.error(f"Error generating model {i}: {e}")
            logging.error(f"vit_config contents: {vit_config}")
            traceback.print_exc()
            continue

    return generated_models

def main():
    """Run the full pipeline: train the flow model, then generate ViTs.

    Any exception raised by the pipeline is logged together with its
    traceback and is not re-raised.
    """
    logging.info("Starting ViT Flow Matching Pipeline...")

    # Architecture settings shared by training and generation.
    vit_config = dict(
        num_classes=10,
        embed_dim=256,
        depth=4,
        num_heads=4,
        mlp_ratio=4.0,
        dropout=0.1,
    )

    logging.info(f"Using ViT config: {vit_config}")

    try:
        logging.info("Testing model creation...")
        # Smoke test: fail early if the configuration cannot build a model.
        smoke_model = create_vit_small(**vit_config)

        trained = train_vit_flow_matching(
            vit_config=vit_config,
            model_dir="/scratch/sgupta/data/imagenet_vit_models",
            num_models=100
        )
        cfm, reference_ws, config, target_std = trained

        generated_models = generate_new_vit_models(
            cfm, reference_ws, config, target_std, n_samples=100
        )

        n_generated = len(generated_models)
        logging.info(f"Successfully generated {n_generated} new ViT models!")

    except Exception as e:
        logging.error(f"Error in main pipeline: {e}")
        traceback.print_exc()


# Run the pipeline only when executed as a script.
if __name__ == "__main__":
    main()