In [None]:
# https://github.com/ege-erdogan/weightflow/tree/master
# !pip install git+https://github.com/ege-erdogan/weightflow.git
# !git clone https://github.com/ege-erdogan/weightflow.git

In [None]:
# update scipy, 
# pip install einops
# !pip install rebasin
# !pip install graphviz
# !pip install -r ./weightflow/requirements.txt

# I can feel the worms in my head :(
# Anyway, this notebook contains conditional multiclass flow matching over the weights of MNIST and Fashion MNIST MLPs:
* Git Re-Basin (permutation alignment of the trained MLPs)
* Flow matching and generation testing
* Hybrid embedding generation
* Populating larger hidden dims from generated weight-space distributions (sort of works?)

In [None]:
import sys
sys.path.append('./weightflow')  # e.g., './repo_name'

import torch
import numpy as np
from collections import defaultdict
from typing import NamedTuple
from scipy.optimize import linear_sum_assignment

In [None]:
from nn.relational_transformer import RelationalTransformer
from nn.graph_constructor import GraphConstructor
from flow.flow_matching import CFM
from tqdm import tqdm
import copy
import logging
from utils.data import sample_gaussian_wsos

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Route INFO-level log records to stdout so they show up in the cell output.
logging.basicConfig(
    stream=sys.stdout,
    format='%(asctime)s %(levelname)s: %(message)s',
    level=logging.INFO,
    datefmt='%I:%M:%S',
)


# OK, imports done. Next: Git Re-Basin (permutation alignment) utilities.

In [None]:
# PermutationSpec class similar to the JAX version but using PyTorch
class PermutationSpec(NamedTuple):
    """Pair of lookup tables describing which parameter axes share a permutation.

    Instances are built by ``permutation_spec_from_axes_to_perm``, which derives
    ``perm_to_axes`` from ``axes_to_perm``.
    """
    # permutation name -> list of (parameter key, axis index) pairs governed by it
    perm_to_axes: dict
    # parameter key -> tuple with one entry per axis: a permutation name, or None
    # when that axis is not permuted
    axes_to_perm: dict

def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
    """Build a PermutationSpec from a parameter-key -> per-axis permutation table.

    Args:
        axes_to_perm: maps each parameter key to a tuple with one entry per
            axis; an entry is a permutation name or None (axis not permuted).

    Returns:
        A PermutationSpec whose ``perm_to_axes`` lists, for every permutation
        name, the (parameter key, axis) pairs that use it, in the iteration
        order of ``axes_to_perm``. ``axes_to_perm`` is stored unchanged.
    """
    inverse = {}
    for param_key, perm_per_axis in axes_to_perm.items():
        for axis_index, perm_name in enumerate(perm_per_axis):
            if perm_name is None:
                # Unpermuted axes do not appear in the inverse table.
                continue
            inverse.setdefault(perm_name, []).append((param_key, axis_index))
    return PermutationSpec(perm_to_axes=inverse, axes_to_perm=axes_to_perm)

def mlp_permutation_spec_mlp() -> PermutationSpec:
    """Return the permutation spec of the three-layer MLP (fc1, fc2, fc3).

    Each weight entry is written as (input axis, output axis): the network
    input and the final output are never permuted (None), while the two hidden
    activations carry the permutations "P_0" and "P_1".
    """
    axes_to_perm = {}

    # Layer 1: network input (fixed) -> first hidden activation (P_0).
    axes_to_perm["fc1.weight"] = (None, "P_0")
    axes_to_perm["fc1.bias"] = ("P_0",)

    # Layer 2: first hidden activation (P_0) -> second hidden activation (P_1).
    axes_to_perm["fc2.weight"] = ("P_0", "P_1")
    axes_to_perm["fc2.bias"] = ("P_1",)

    # Layer 3: second hidden activation (P_1) -> network output (fixed).
    axes_to_perm["fc3.weight"] = ("P_1", None)
    axes_to_perm["fc3.bias"] = (None,)

    return permutation_spec_from_axes_to_perm(axes_to_perm)

def get_permuted_param(ps: "PermutationSpec", perm, k: str, params, except_axis=None):
    """Return parameter ``k`` with the current permutations applied to its axes.

    Args:
        ps: permutation spec; only ``ps.axes_to_perm[k]`` is read.
        perm: mapping from permutation name to an integer index sequence
            (tensor, NumPy array or list).
        k: key of the parameter inside ``params``.
        params: mapping from parameter key to tensor.
        except_axis: axis to leave untouched (used while solving for the
            permutation of that very axis), or None to permute every axis.

    Returns:
        The (possibly re-indexed) tensor; ``params[k]`` itself is not modified.
    """
    w = params[k]
    for axis, p in enumerate(ps.axes_to_perm[k]):
        # Skip the axis being solved for, and axes that carry no permutation.
        if axis == except_axis or p is None:
            continue

        # as_tensor reuses an existing index tensor instead of copy-constructing
        # it (torch.tensor(tensor) warns and copies on every call).
        index = torch.as_tensor(perm[p], device=w.device)
        w = torch.index_select(w, axis, index)

    return w

def apply_permutation(ps: PermutationSpec, perm, params):
    """Return a new dict with every parameter in ``params`` fully permuted.

    Each entry is produced by ``get_permuted_param`` with no axis excluded;
    keys and their order are taken from ``params``.
    """
    permuted = {}
    for key in params.keys():
        permuted[key] = get_permuted_param(ps, perm, key, params)
    return permuted

def weight_matching(ps: "PermutationSpec", params_a, params_b, max_iter=100, init_perm=None, silent=True, device=None):
    """Find permutations of ``params_b`` that make them match ``params_a``.

    Coordinate descent over the permutations named in ``ps``: for one
    permutation at a time, a similarity matrix between the units of A and the
    (otherwise permuted) units of B is accumulated over every parameter axis
    the permutation governs, and the best assignment is found with
    ``scipy.optimize.linear_sum_assignment``.

    Args:
        ps: permutation spec (``perm_to_axes`` and ``axes_to_perm`` are read).
        params_a: reference parameters, key -> tensor.
        params_b: parameters to align, same keys and shapes as ``params_a``.
        max_iter: maximum number of sweeps over all permutations.
        init_perm: optional starting permutations (name -> index sequence);
            identity permutations are used when None.
        silent: when False, log the objective gain of every improving update.
        device: device used for the computation; defaults to CUDA when
            available, else CPU.

    Returns:
        dict mapping permutation name -> index tensor on ``device``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Work on detached copies/views: callers pass nn.Parameter views, and the
    # matching itself never needs gradients or an autograd graph.
    params_a = {k: v.detach().to(device) for k, v in params_a.items()}
    params_b = {k: v.detach().to(device) for k, v in params_b.items()}

    # Size of each permutation, read from the first axis it governs.
    perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]]
                  for p, axes in ps.perm_to_axes.items()}

    # Start from the identity unless the caller supplies a warm start.
    if init_perm is None:
        perm = {p: torch.arange(n, device=device) for p, n in perm_sizes.items()}
    else:
        perm = {p: torch.as_tensor(v, device=device) for p, v in init_perm.items()}

    perm_names = list(perm.keys())

    # Fixed seed so the visiting order of permutations is reproducible.
    rng = np.random.RandomState(42)

    for iteration in range(max_iter):
        progress = False

        # Visit the permutations in a random order on every sweep.
        for p_ix in rng.permutation(len(perm_names)):
            p = perm_names[p_ix]
            n = perm_sizes[p]

            # Accumulate in the parameters' own dtype so that e.g. float64
            # models do not fail on the in-place add below.
            dtype = params_a[ps.perm_to_axes[p][0][0]].dtype
            A = torch.zeros((n, n), device=device, dtype=dtype)

            # Similarity between unit i of A and unit j of B, summed over all
            # parameters that share this permutation.
            for wk, axis in ps.perm_to_axes[p]:
                w_a = params_a[wk]
                w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)

                w_a = w_a.moveaxis(axis, 0).reshape((n, -1))
                w_b = w_b.moveaxis(axis, 0).reshape((n, -1))

                A += w_a @ w_b.T

            # Solve the linear assignment problem (maximising total similarity).
            ri, ci = linear_sum_assignment(A.detach().cpu().numpy(), maximize=True)
            assert (ri == np.arange(len(ri))).all()
            new_perm = torch.tensor(ci, device=device)

            # Objective under the previous and the newly found assignment.
            eye = torch.eye(n, device=device, dtype=dtype)
            oldL = torch.tensordot(A, eye[perm[p]], dims=([0, 1], [0, 1]))
            newL = torch.tensordot(A, eye[new_perm], dims=([0, 1], [0, 1]))

            improved = bool(newL > oldL + 1e-12)
            if improved and not silent:
                logging.info(f"{iteration}/{p}: {newL.item() - oldL.item()}")

            progress = progress or improved

            # The new assignment is always adopted, improving or not.
            perm[p] = new_perm

        # A full sweep without improvement means the procedure has converged.
        if not progress:
            break

    return perm


def update_model_weights(model, aligned_params):
    """Write aligned parameters back into the model's fc1, fc2 and fc3 layers.

    ``aligned_params`` is keyed "<layer>.weight" / "<layer>.bias". Each weight
    entry is transposed before it is assigned to ``layer.weight.data``; biases
    are assigned as given. The model is modified in place; nothing is returned.
    """
    for layer_name in ("fc1", "fc2", "fc3"):
        layer = getattr(model, layer_name)
        layer.weight.data = aligned_params[f"{layer_name}.weight"].T
        layer.bias.data = aligned_params[f"{layer_name}.bias"]
    
def load_model_weights(model, model_path):
    """Load saved weights into ``model`` and return it on the global ``device``.

    The file at ``model_path`` is read with ``torch.load`` and unpacked as a
    ``(weights, biases)`` pair; index 0, 1 and 2 of each are assigned to the
    fc1, fc2 and fc3 layers respectively.
    """
    weights, biases = torch.load(model_path, map_location=device)
    for position, layer in enumerate((model.fc1, model.fc2, model.fc3)):
        layer.weight.data = weights[position]
        layer.bias.data = biases[position]
    return model.to(device)

def _mlp_param_dict(model):
    """Collect an MLP's fc1..fc3 parameters in the key -> tensor layout used for matching.

    Weights are stored transposed (``layer.weight.T``); everything is moved to
    the global ``device``.
    """
    params = {}
    for layer_name in ("fc1", "fc2", "fc3"):
        layer = getattr(model, layer_name)
        params[f"{layer_name}.weight"] = layer.weight.T.to(device)
        params[f"{layer_name}.bias"] = layer.bias.to(device)
    return params


def get_permuted_models_data(ref_point=0, model_dir="models", num_models=200, model_type="MNIST"):
    """Load a family of saved MLPs and align each one to a reference model.

    Weight files are read from
    ``{model_dir}/{model_type}_mixed_mlp_weights_{i}.pt`` for ``i`` in
    ``range(num_models)``. The model with index ``ref_point`` is the reference;
    every other model is aligned to it with ``weight_matching``.

    Returns:
        (ref_model, org_models, permuted_models): the reference model, the
        other models as loaded, and permuted deep copies of them. The reference
        model is not included in either list.
    """
    def weights_path(index):
        return f"{model_dir}/{model_type}_mixed_mlp_weights_{index}.pt"

    # Reference model and its parameters (relies on the notebook's MLP class).
    ref_model = load_model_weights(MLP(), weights_path(ref_point)).to(device)
    ps = mlp_permutation_spec_mlp()
    params_a = _mlp_param_dict(ref_model)

    org_models, permuted_models = [], []
    for i in range(num_models):
        if i == ref_point:
            continue

        model = load_model_weights(MLP(), weights_path(i)).to(device)
        org_models.append(model)
        params_b = _mlp_param_dict(model)

        # Solve for the permutation that aligns this model with the reference,
        # then apply it to all of the model's parameters.
        perm = weight_matching(ps, params_a, params_b)
        aligned_params_b = apply_permutation(ps, perm, params_b)

        # Keep the loaded model untouched; the aligned weights go into a copy.
        reconstructed_model = copy.deepcopy(model)
        update_model_weights(reconstructed_model, aligned_params_b)
        permuted_models.append(reconstructed_model.to(device))

    return ref_model, org_models, permuted_models


# Not sure what's next, but OK:
* Review WSOs (weight-space objects)

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F
import copy
from collections import defaultdict
import os
import traceback

In [None]:
# Define MNIST classifier MLP class for dataset
class MLP(nn.Module):
    """Three-layer fully connected classifier: 196 -> hidden_dim -> hidden_dim -> 10.

    Both hidden layers use ReLU; the last layer returns raw logits.

    input:  tensor reshaped internally to [batch_size, 196]
    output: tensor of shape [batch_size, 10]

    Args:
        hidden_dim: width of both hidden layers (default 32).
        init_type: 'xavier' or 'he'; any other value selects standard-normal
            weight initialisation. Biases are always zeroed.
        seed: when given, ``torch.manual_seed(seed)`` is called right before
            the weights are (re-)initialised.
        type: free-form tag stored on ``self.type``.
    """

    def __init__(self, hidden_dim = 32, init_type='xavier', seed=None, type = 'MNIST'):
        super().__init__()

        self.fc1 = nn.Linear(196, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 10)

        # Seeding happens after the layers exist, so only the explicit
        # initialisation below is governed by the seed.
        if seed is not None:
            torch.manual_seed(seed)

        self.init_weights(init_type)
        self.type = type

    def init_weights(self, init_type):
        """Initialise the three weight matrices per ``init_type`` and zero all biases."""
        if init_type == 'xavier':
            init_weight = nn.init.xavier_uniform_
        elif init_type == 'he':
            def init_weight(weight):
                return nn.init.kaiming_uniform_(weight, nonlinearity='relu')
        else:
            init_weight = nn.init.normal_

        layers = (self.fc1, self.fc2, self.fc3)
        # Weights first, in layer order, so random numbers are drawn in a
        # fixed sequence.
        for layer in layers:
            init_weight(layer.weight)
        for layer in layers:
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        flat = x.view(-1, 196)
        hidden = torch.relu(self.fc1(flat))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)



def test_mlp(model, test_loader):
    """Return the top-1 accuracy of ``model`` on ``test_loader`` as a percentage.

    The model is switched to eval mode and moved to the global ``device``;
    every batch of ``(data, target)`` is moved there as well. Raises
    ZeroDivisionError when the loader yields no samples.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    model = model.to(device)
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            # Index of the largest logit in each row is the predicted class.
            predicted = torch.max(logits.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
    return 100 * n_correct / n_seen

def count_parameters(model):
    """Return the total number of scalar entries in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

def zero_like_wso(wso):
    """Return a new WeightSpaceObject with all-zero tensors shaped like those in ``wso``."""
    zeroed_weights = tuple(map(torch.zeros_like, wso.weights))
    zeroed_biases = tuple(map(torch.zeros_like, wso.biases))
    return WeightSpaceObject(zeroed_weights, zeroed_biases)

# WeightSpaceObject class for handling MLP weights
class WeightSpaceObject:
    """Container for the weights and biases of one MLP, stored as two tuples of tensors."""

    def __init__(self, weights, biases):
        # Any iterable is accepted; both collections are stored as tuples.
        if not isinstance(weights, tuple):
            weights = tuple(weights)
        if not isinstance(biases, tuple):
            biases = tuple(biases)
        self.weights = weights
        self.biases = biases

    def flatten(self, device=None):
        """Return one 1-D tensor: all weights (in order) followed by all biases.

        The result is moved to ``device`` when a truthy device is given.
        """
        pieces = [tensor.flatten() for tensor in self.weights]
        pieces.extend(tensor.flatten() for tensor in self.biases)
        flat = torch.cat(pieces)
        return flat.to(device) if device else flat

    @classmethod
    def from_flat(cls, flat, layers, device):
        """Rebuild an object from a vector laid out as produced by ``flatten``.

        Args:
            flat: 1-D tensor holding all weight matrices, then all bias vectors.
            layers: layer widths, e.g. [in, hidden, ..., out]; weight ``i`` is
                reshaped to ``(layers[i + 1], layers[i])``.
            device: device the resulting tensors are moved to.
        """
        n_layers = len(layers) - 1
        offset = 0

        # Weight matrices come first, one (out, in) block per layer.
        weights = []
        for i in range(n_layers):
            rows, cols = layers[i + 1], layers[i]
            count = rows * cols
            weights.append(flat[offset:offset + count].reshape(rows, cols))
            offset += count

        # Bias vectors follow, one per layer output.
        biases = []
        for i in range(1, len(layers)):
            count = layers[i]
            biases.append(flat[offset:offset + count])
            offset += count

        return cls(weights, biases).to(device)

    def to(self, device):
        """Return a new object whose tensors have been moved to ``device``."""
        moved_weights = tuple(tensor.to(device) for tensor in self.weights)
        moved_biases = tuple(tensor.to(device) for tensor in self.biases)
        return WeightSpaceObject(moved_weights, moved_biases)

    def map(self, fn):
        """Return a new object with ``fn`` applied to every weight and every bias."""
        mapped_weights = tuple(map(fn, self.weights))
        mapped_biases = tuple(map(fn, self.biases))
        return WeightSpaceObject(mapped_weights, mapped_biases)

# Simple Bunch class for storing data
class Bunch:
    """Attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)

# Safe deflatten function that checks bounds before accessing tensors
def safe_deflatten(flat, batch_size, starts, ends):
    """Slice each row of ``flat`` into the segments given by ``starts``/``ends``.

    At most ``min(flat.size(0), batch_size)`` rows are processed, and a
    (start, end) pair is used only when ``start < end``.

    Returns:
        A list with one entry per processed row; each entry is the list of
        slices ``flat[row][start:end]`` for the valid pairs, in order.
    """
    n_rows = min(flat.size(0), batch_size)
    return [
        [flat[row][start:end] for start, end in zip(starts, ends) if start < end]
        for row in range(n_rows)
    ]


# Flow matching

In [None]:
# model_type is "MNIST" or "Fashion MNIST". Each family is aligned to its own
# model 0, and each call's results are kept under family-specific names so the
# second call no longer discards the first call's reference/original models.
mnist_ref_model, mnist_original_models, mnist_permuted_models = get_permuted_models_data(ref_point=0, model_type="MNIST")
fmnist_ref_model, fmnist_original_models, fmnist_permuted_models = get_permuted_models_data(ref_point=0, model_type="Fashion MNIST")

# Backward-compatible names: these were previously rebound by the second call,
# so they refer to the Fashion MNIST results.
ref_model, original_models = fmnist_ref_model, fmnist_original_models

layer_layout = [196, 32, 32, 10]  # MLP architecture for MNIST


def models_to_wsos(models):
    """Convert MLPs (with fc1, fc2, fc3 layers) into a list of WeightSpaceObjects.

    Weights and biases are cloned, so the resulting objects do not share
    storage with the models.
    """
    wsos = []
    for model in tqdm(models):
        layers = (model.fc1, model.fc2, model.fc3)
        weights = tuple(layer.weight.data.clone() for layer in layers)
        biases = tuple(layer.bias.data.clone() for layer in layers)
        wsos.append(WeightSpaceObject(weights, biases))
    return wsos


# Create WSO objects from permuted models
logging.info("Converting MNIST models to WeightSpaceObjects...")
mnist_weights_list = models_to_wsos(mnist_permuted_models)

logging.info("Converting Fashion MNIST models to WeightSpaceObjects...")
fmnist_weights_list = models_to_wsos(fmnist_permuted_models)



In [None]:
# Create flat vectors
# logging.info("Converting to flat tensors...")

# One row per aligned model: each WeightSpaceObject flattened onto `device`.
mnist_flat_rows = [wso.flatten(device) for wso in mnist_weights_list]
flat_mnist_target_weights = torch.stack(mnist_flat_rows)
# Class label 0 = MNIST, stored as a column of shape [num_models, 1].
mnist_labels = torch.full((flat_mnist_target_weights.shape[0], 1), 0)

fmnist_flat_rows = [wso.flatten(device) for wso in fmnist_weights_list]
flat_fmnist_target_weights = torch.stack(fmnist_flat_rows)
# Class label 1 = Fashion MNIST.
fmnist_labels = torch.full((flat_fmnist_target_weights.shape[0], 1), 1)

# Pair every flat weight vector with its class label.
mnist_target_dataset = TensorDataset(flat_mnist_target_weights, mnist_labels)
fmnist_target_dataset = TensorDataset(flat_fmnist_target_weights, fmnist_labels)

from torch.utils.data import ConcatDataset
def collate_fn(batch):
    """Stack a list of ``(flat, label)`` samples into two batch tensors.

    Returns:
        (flats, labs): ``flats`` stacks the flat vectors along a new leading
        batch dimension; ``labs`` stacks the per-sample label tensors the same
        way (so labels of shape [1] give a [B, 1] tensor). No pairing or
        filtering by label happens here.
    """
    flat_rows, label_rows = zip(*batch)
    return torch.stack(flat_rows), torch.stack(label_rows)

batch_size = 32

# Both model families are mixed in one shuffled loader; the class label carried
# by every sample tells them apart.
combined_target_dataset = ConcatDataset([mnist_target_dataset, fmnist_target_dataset])
targetloader = DataLoader(combined_target_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# Multiclass CFM time?

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from types import SimpleNamespace


class ResidualBlock(nn.Module):
    """Linear -> ReLU -> LayerNorm, plus an identity skip when in_dim == out_dim."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.norm = nn.LayerNorm(out_dim)
        self.activation = nn.ReLU()
        # The skip connection is only possible when the shapes agree.
        self.residual = in_dim == out_dim

    def forward(self, x):
        hidden = self.linear(x)
        hidden = self.activation(hidden)
        hidden = self.norm(hidden)
        return hidden + x if self.residual else hidden

class TimeConditionedMLP(nn.Module):
    """
    A conditional flow network that takes:
      - x in R^{flat_dim}
      - t in R (scalar in [0,1], shaped as [B,1])
      - c in {0,1,...,num_classes-1} (class labels)

    and returns a predicted velocity in R^{flat_dim}. Internally:
      * t is embedded by a single linear layer (time_embed).
      * c is embedded via nn.Embedding.
      * (x, t_embed, c_embed) are concatenated, passed through a stack of
        ResidualBlocks, then projected back to flat_dim.

    Args:
        flat_dim: size of the flattened weight vector.
        t_embed_dim: size of the time embedding.
        class_embed_dim: size of the class embedding.
        num_classes: number of distinct class labels.
        hidden_dim: widths of the hidden ResidualBlocks. Either a sequence of
            ints (one block per entry) or a single int, which is expanded to
            two hidden layers of that width.
    """

    def __init__(
        self,
        flat_dim: int,
        t_embed_dim: int,
        class_embed_dim: int,
        num_classes: int,
        hidden_dim: "int | list[int]" = 512,
    ):
        super().__init__()
        self.flat_dim = flat_dim
        self.t_embed_dim = t_embed_dim
        self.class_embed_dim = class_embed_dim
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim

        # Resolve the hidden widths from the constructor argument itself; the
        # network must not depend on a notebook-level `hidden_dims` variable.
        if isinstance(hidden_dim, int):
            hidden_dims = [hidden_dim, hidden_dim]
        else:
            hidden_dims = [int(width) for width in hidden_dim]
        self.hidden_dims = hidden_dims

        # 1) Time embedding network: R -> R^{t_embed_dim}
        self.time_embed = nn.Sequential(
            nn.Linear(1, t_embed_dim),
        )

        # 2) Class embedding: maps {0,...,num_classes-1} -> R^{class_embed_dim}
        self.class_emb = nn.Embedding(num_classes, class_embed_dim)

        # 3) Main MLP: input dim = flat_dim + t_embed_dim + class_embed_dim
        in_dim = flat_dim + t_embed_dim + class_embed_dim

        layers = []
        prev_dim = in_dim
        for width in hidden_dims:
            layers.append(ResidualBlock(prev_dim, width))
            prev_dim = width

        self.hidden_layers = nn.Sequential(*layers)
        self.output_layer = nn.Linear(prev_dim, flat_dim)

    def forward(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """
        Inputs:
          x : Tensor of shape [B, flat_dim]
          t : Tensor of shape [B, 1]  (scalar time per sample)
          c : LongTensor of shape [B] (class labels in {0,...,num_classes-1})

        Returns:
          velocity_pred : Tensor of shape [B, flat_dim]
        """
        # 1) Time embeddings: [B, 1] -> [B, t_embed_dim]
        t_emb = self.time_embed(t)

        # 2) Class embeddings: [B] -> [B, class_embed_dim]
        c_emb = self.class_emb(c)

        # 3) Concatenate along the feature dimension
        #    -> [B, flat_dim + t_embed_dim + class_embed_dim]
        xcat = torch.cat([x, t_emb, c_emb], dim=-1)

        # 4) Hidden stack, then projection back to flat_dim
        x = self.hidden_layers(xcat)
        return self.output_layer(x)


class SimpleCFM:
    """Minimal class-conditional flow-matching trainer and sampler over flat weight vectors.

    Training pairs each target sample x1 (from ``targetloader``) with fresh
    Gaussian noise x0, interpolates linearly, and regresses the model output
    onto the constant velocity (x1 - x0). Sampling integrates the learned
    velocity field from t=0 to t=1 with explicit Euler steps.
    """

    def __init__(
        self,
        targetloader: torch.utils.data.DataLoader,
        model: nn.Module,
        layer_layout,
        fm_type: str = "vanilla",
        mode: str = "velocity",
        t_dist: str = "uniform",
        device: torch.device = None,
        normalize_pred: bool = False,
        geometric: bool = False,
        source_std: float = 0.001,   # standard deviation for on-the-fly noise
    ):
        """
        Args:
          targetloader   : DataLoader yielding (flat_tensor, label) batches of target
                           weight vectors; one batch is drawn per training iteration.
          model          : velocity network called as model(x, t, c), e.g. a
                           TimeConditionedMLP.
          layer_layout   : layer sizes; stored but not used by this class.
          fm_type        : stored but not used by this class.
          mode           : stored but not used; only velocity regression is implemented.
          t_dist         : stored but not used; t is always sampled from Uniform(0,1).
          device         : torch.device; defaults to CPU when None.
          normalize_pred : if True, `forward` rescales predictions to unit norm.
          geometric      : stored but not used; interpolation is always linear.
          source_std     : float >= 0, sigma of the Gaussian source noise.
        """
        self.targetloader = targetloader
        # Resolve the device first so the model really ends up on the device
        # that batches are later moved to (previously a None device left the
        # model wherever it was while self.device became CPU).
        self.device       = device if device is not None else torch.device("cpu")
        self.model        = model.to(self.device)
        self.layer_layout = layer_layout
        self.fm_type      = fm_type
        self.mode         = mode
        self.t_dist       = t_dist
        self.normalize_pred = normalize_pred
        self.geometric    = geometric
        self.source_std   = source_std

        self.metrics      = {"train_loss": [], "time": []}
        self.best_loss    = float("inf")
        self.best_model_state = None

        # Persistent iterator so consecutive calls walk through the loader.
        self._target_iter = iter(self.targetloader)

    def sample_from_loader(self):
        """
        Draw one minibatch from targetloader and build a training sample for it.

        Returns a namespace with attributes:
          xt          : [B, flat_dim]   = x0 + t*(x1 - x0)
          t           : [B, 1]          = sampled t ~ Uniform(0,1)
          true_flow   : [B, flat_dim]   = (x1 - x0)
          class_label : [B] (long)      = class label of each target sample
        """
        try:
            flats1, labs1 = next(self._target_iter)
        except StopIteration:
            # Loader exhausted: start a new pass over the data.
            self._target_iter = iter(self.targetloader)
            flats1, labs1 = next(self._target_iter)

        flats1 = flats1.to(self.device)                      # [B, flat_dim]
        labs1  = labs1.squeeze(-1).to(self.device).long()    # [B, 1] -> [B]

        B, flat_dim = flats1.shape

        # 1) Source noise x0 ~ N(0, source_std^2), identical for every class.
        x0 = torch.randn(B, flat_dim, device=self.device) * self.source_std

        # 2) One time value per sample.
        t = torch.rand(B, 1, device=self.device)

        # 3) Linear interpolation between noise and target.
        xt = x0 + t * (flats1 - x0)

        # 4) Regression target: the constant velocity of that straight path.
        true_flow = flats1 - x0

        flow = SimpleNamespace()
        flow.xt          = xt
        flow.t           = t
        flow.true_flow   = true_flow
        flow.class_label = labs1

        return flow

    def forward(self, flow: SimpleNamespace) -> torch.Tensor:
        """
        Predict the velocity model(xt, t, class_label) for a flow sample,
        rescaled to unit norm per row when normalize_pred is set.
        """
        pred = self.model(flow.xt, flow.t, flow.class_label)  # [B, flat_dim]

        if self.normalize_pred:
            norm = pred.norm(dim=-1, keepdim=True).clamp(min=1e-6)
            pred = pred / norm

        return pred

    def train(self, n_iters: int, optimizer: torch.optim.Optimizer, sigma: float = 0.0):
        """
        Run n_iters iterations of conditional flow-matching training:
          1) flow = sample_from_loader()
          2) pred_flow = forward(flow)
          3) loss = MSE(pred_flow, true_flow)
          4) backward & optimizer step

        Every 100 iterations (and on the last one) the loss is printed and the
        model state is saved to 'checkpoints/linear_cfm_<it>.pth'. The lowest
        minibatch loss seen so far is tracked in best_loss/best_model_state.
        `sigma` is accepted for API compatibility but is not used.
        """
        self.model.train()
        for it in range(1, n_iters + 1):
            flow      = self.sample_from_loader()
            pred_flow = self.forward(flow)             # [B, flat_dim]
            true_flow = flow.true_flow                 # [B, flat_dim]

            loss = F.mse_loss(pred_flow, true_flow)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            self.metrics["train_loss"].append(loss_value)
            if it % 100 == 0 or it == n_iters:
                print(f"[Iter {it:5d}/{n_iters:5d}]  Loss = {loss_value:.6f}")

                checkpoint_dir = 'checkpoints'
                os.makedirs(checkpoint_dir, exist_ok=True)
                ckpt_path = os.path.join(checkpoint_dir, f'linear_cfm_{it}.pth')
                torch.save(self.model.state_dict(), ckpt_path)

            if loss_value < self.best_loss:
                self.best_loss = loss_value
                # state_dict() returns references to the live tensors, so the
                # snapshot must be copied or it would keep tracking the current
                # weights. This costs one extra copy of model and optimizer
                # state per improvement.
                self.best_model_state = {
                    "model":     copy.deepcopy(self.model.state_dict()),
                    "optimizer": copy.deepcopy(optimizer.state_dict()),
                    "iteration": it,
                }

    @torch.no_grad()
    def map(
        self,
        random_flat: torch.Tensor,
        class_label: torch.Tensor,
        n_steps: int = 100,
        noise_scale: float = 0.0,
    ) -> torch.Tensor:
        """
        Generate weight-space samples for the given class labels by integrating
        the learned velocity field from t=0 to t=1 with n_steps Euler steps.

        Args:
          random_flat  : [B, flat_dim] initial noise at t=0
          class_label  : [B] class labels (converted to long)
          n_steps      : number of Euler steps
          noise_scale  : std of Gaussian noise added after each step (0 disables)

        Returns:
          x : [B, flat_dim] final mapped weights
        """
        self.model.eval()
        x = random_flat.to(self.device)          # [B, flat_dim]
        c = class_label.to(self.device).long()   # [B]

        dt = 1.0 / float(n_steps)
        for i in range(n_steps):
            # Time at the start of step i: t_i = i / n_steps.
            t_i = torch.full((x.shape[0], 1), float(i) / float(n_steps), device=self.device)

            # NOTE(review): the raw model output is used here, so the unit-norm
            # rescaling applied by `forward` when normalize_pred=True is NOT
            # applied during sampling. Confirm this asymmetry is intended.
            v = self.model(x, t_i, c)          # [B, flat_dim]

            # Euler update: x <- x + v * dt
            x = x + v * dt

            if noise_scale > 0.0:
                x = x + noise_scale * torch.randn_like(x)

        return x  # [B, flat_dim]


In [None]:
# Hyper-parameters of the velocity network.
flat_dim        = flat_mnist_target_weights.shape[1]  # length of one flattened MLP
t_embed_dim     = 16                                  # size of the time embedding
class_embed_dim = 64                                  # size of the class embedding (tunable)
num_classes     = 2                                   # MNIST vs. FashionMNIST
hidden_dims     = [4096, 2048, 2048, 4096]            # widths of the hidden blocks

flow_net_kwargs = dict(
    flat_dim=flat_dim,
    t_embed_dim=t_embed_dim,
    class_embed_dim=class_embed_dim,
    num_classes=num_classes,
    hidden_dim=hidden_dims,
)
flow_net = TimeConditionedMLP(**flow_net_kwargs).to(device)

# Training mode for the optimisation that follows.
flow_net.train()

# Compare the size of one classifier MLP with the size of the flow network.
n_params_base = count_parameters(MLP())
n_params_flow = count_parameters(flow_net)
logging.info(f"MLP params:{n_params_base}")
logging.info(f"Flow model params:{n_params_flow}")

In [None]:
# 3) Instantiate SimpleCFM; source noise is generated on the fly, so only the
#    target loader is needed.
cfm_settings = dict(
    targetloader=targetloader,
    model=flow_net,
    layer_layout=layer_layout,   # kept only for the old API; unused by SimpleCFM
    fm_type="vanilla",
    mode="velocity",
    t_dist="uniform",
    device=device,
    normalize_pred=True,         # predictions are rescaled to unit norm during training
    geometric=False,             # has no effect at the moment
    source_std=0.001,            # 0.001 worked for the single-class CFM
)
cfm = SimpleCFM(**cfm_settings)

optimizer = torch.optim.Adam(flow_net.parameters(), lr=1e-4) # Adam vs AdamW
# Training is disabled here; weights are loaded from a checkpoint in a later cell.
# cfm.train(n_iters=10000, optimizer=optimizer, sigma=0.01)
# cfm.plot_metrics()

In [None]:
checkpoint_dir = 'checkpoints'
load_it = 4500  # training iteration whose checkpoint is restored
last_checkpoint = os.path.join(checkpoint_dir, f'linear_cfm_{load_it}.pth')


# Re-instantiate the model with the same architecture that produced the
# checkpoint; load_state_dict requires matching parameter names and shapes.

flat_dim        = flat_mnist_target_weights.shape[1]
t_embed_dim     = 16         # size of the time embedding
class_embed_dim = 64         # size of class embedding (tunable)
num_classes     = 2          # MNIST vs. FashionMNIST
hidden_dims     = [4096, 2048, 2048, 4096]

flow_net = TimeConditionedMLP(
    flat_dim=flat_dim,
    t_embed_dim=t_embed_dim,
    class_embed_dim=class_embed_dim,
    num_classes=num_classes,
    hidden_dim=hidden_dims
).to(device)

# map_location makes a checkpoint saved on a GPU loadable on a CPU-only machine
# (and vice versa) by remapping its tensors onto the current device.
flow_net.load_state_dict(torch.load(last_checkpoint, map_location=device))
flow_net.eval()
print(f"Loaded checkpoint from '{last_checkpoint}'")

cfm = SimpleCFM(
    targetloader   = targetloader,
    model          = flow_net,
    layer_layout   = layer_layout, # kept only for the old API
    fm_type        = "vanilla",
    mode           = "velocity",
    t_dist         = "uniform",
    device         = device,
    normalize_pred = True,  # same setting as in the training cell
    geometric      = False, # doesn't do anything right now
    source_std     = 0.001, # must match the noise scale used for training
)

In [None]:
from torchvision import datasets, transforms

# NOTE: downstream cells must pick the test loader that matches the class being generated.

def _downsample_transform():
    """Build the shared preprocessing: to tensor, resize to 14x14, normalize with mean/std 0.5."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((14, 14)),
        transforms.Normalize((0.5,), (0.5,))
    ])


def _make_loader(dataset, shuffle):
    """Wrap a dataset in a DataLoader with batch size 64."""
    return torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=shuffle)


# --- MNIST: datasets with the downsampling transform, then loaders ---
mnist_transform = _downsample_transform()
mnist_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)
mnist_test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=mnist_transform)
mnist_train_loader = _make_loader(mnist_train_dataset, shuffle=True)
mnist_test_loader = _make_loader(mnist_test_dataset, shuffle=False)

# --- Fashion MNIST: same preprocessing, datasets, then loaders ---
fashion_mnist_transform = _downsample_transform()
fashion_mnist_train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=fashion_mnist_transform)
fashion_mnist_test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=fashion_mnist_transform)
fashion_mnist_train_loader = _make_loader(fashion_mnist_train_dataset, shuffle=True)
fashion_mnist_test_loader = _make_loader(fashion_mnist_test_dataset, shuffle=False)

In [None]:
# Sample a batch of flat weight vectors for one class, rebuild each MLP, and score it
# on the test set of the matching dataset.
logging.info("Generating new MLP weights...")
n_samples = 10

generate_type = 1 # 0 MNIST, 1 Fashion MNIST
source_std = cfm.source_std
random_flat = torch.randn(n_samples, flat_dim, device=device) * source_std
class_labels = torch.full((n_samples,), generate_type, dtype=torch.long, device=device)

new_weights_flat = cfm.map(random_flat, class_labels, n_steps=100, noise_scale=0.001)

# Everything below depends only on generate_type, so it is resolved once up front.
test_loader = mnist_test_loader if generate_type == 0 else fashion_mnist_test_loader
title_type = f"MNIST" if generate_type == 0 else f"Fashion MNIST"

# Layout the reconstructed weight-space object must have.
expected_weight_shapes = [(32, 196), (32, 32), (10, 32)]
expected_bias_shapes = [(32,), (32,), (10,)]

accuracies = []
for i in range(n_samples):
    new_wso = WeightSpaceObject.from_flat(
        new_weights_flat[i],
        layers=np.array(layer_layout),
        device=device
    )

    assert len(new_wso.weights) == 3, f"Expected 3 weight matrices, got {len(new_wso.weights)}"
    assert len(new_wso.biases) == 3, f"Expected 3 bias vectors, got {len(new_wso.biases)}"

    # Validate every weight and bias shape before loading them into a model.
    for idx, (weight, wanted) in enumerate(zip(new_wso.weights, expected_weight_shapes)):
        assert weight.shape == wanted, f"Weight {idx} has shape {weight.shape}, expected {wanted}"

    for idx, (bias, wanted) in enumerate(zip(new_wso.biases, expected_bias_shapes)):
        assert bias.shape == wanted, f"Bias {idx} has shape {bias.shape}, expected {wanted}"

    # Copy the generated parameters into a fresh MLP, layer by layer.
    model = MLP()
    for layer, weight, bias in zip((model.fc1, model.fc2, model.fc3), new_wso.weights, new_wso.biases):
        layer.weight.data = weight.clone()
        layer.bias.data = bias.clone()

    acc = test_mlp(model, test_loader)
    accuracies.append(acc)

    logging.info(f"Generated { title_type } MLP {i} accuracy: {acc:.2f}%")


In [None]:
# Histogram of the generated models' accuracies, with the mean (solid) and
# mean +/- one standard deviation (dashed) marked as vertical lines.
acc_mean = np.mean(accuracies)
acc_std = np.std(accuracies)

plt.figure(figsize=(12, 8), dpi=80)
plt.hist(accuracies, alpha=0.8)
plt.axvline(acc_mean, color='k', label=f"Mean generated {n_samples} model(s) performance")
deviation_lines = (
    (acc_mean + acc_std, "+1 Deviation performance"),
    (acc_mean - acc_std, "-1 Deviation performance"),
)
for x_pos, text in deviation_lines:
    plt.axvline(x_pos, color='k', linestyle='--', label=text)
plt.legend()
plt.show()

# Hybrid Embeddings: 

In [None]:
from torchvision import datasets, transforms

# Sweep over class-label values close to 1 and evaluate every generated MLP on both
# the MNIST and the Fashion MNIST test sets.
logging.info("Generating new MLP weights...")
n_samples = 5
n_steps = 100
generate_types = [0.98, 0.9825, 0.985, 0.9875, 0.99, 0.9925, 0.995, 0.9975, 1] # 0 MNIST, 1 Fashion MNIST

for generate_type in generate_types: 

    source_std = 0.001
    random_flat = torch.randn(n_samples, flat_dim, device=device) * source_std
    # NOTE(review): the labels are created with dtype=torch.long, so the fractional
    # values in generate_types are presumably truncated to 0 here (only the final 1
    # survives). If so, this sweep never evaluates an actual hybrid label; confirm how
    # cfm.map consumes class labels before interpreting the results.
    class_labels = torch.full([n_samples], generate_type, dtype=torch.long, device=device) # 0 for MNIST and 1 for Fashion MNIST

    new_weights_flat = cfm.map(
        random_flat, 
        class_labels,
        n_steps=n_steps,
        noise_scale=0.001
    )
    # Per-label accuracy accumulators (reset for every generate_type); nothing is saved to disk.
    mnist_accuracies = []
    fashion_mnist_accuracies = []

    for i in range(n_samples):
        # Rebuild the structured weights/biases from the i-th flat vector.
        new_wso = WeightSpaceObject.from_flat(
            new_weights_flat[i], 
            layers=np.array(layer_layout), 
            device=device
        )

        expected_weight_shapes = [(32, 196), (32, 32), (10, 32)]
        expected_bias_shapes = [(32,), (32,), (10,)]

        assert len(new_wso.weights) == 3, f"Expected 3 weight matrices, got {len(new_wso.weights)}"
        assert len(new_wso.biases) == 3, f"Expected 3 bias vectors, got {len(new_wso.biases)}"

        # Check each weight and bias shape
        for j, (w, expected_shape) in enumerate(zip(new_wso.weights, expected_weight_shapes)):
            assert w.shape == expected_shape, f"Weight {j} has shape {w.shape}, expected {expected_shape}"

        for j, (b, expected_shape) in enumerate(zip(new_wso.biases, expected_bias_shapes)):
            assert b.shape == expected_shape, f"Bias {j} has shape {b.shape}, expected {expected_shape}"

        # Load the generated parameters into a fresh MLP.
        model = MLP()
        model.fc1.weight.data = new_wso.weights[0].clone()
        model.fc1.bias.data = new_wso.biases[0].clone()
        model.fc2.weight.data = new_wso.weights[1].clone()
        model.fc2.bias.data = new_wso.biases[1].clone()
        model.fc3.weight.data = new_wso.weights[2].clone()
        model.fc3.bias.data = new_wso.biases[2].clone()

        # Score the same model on both datasets.
        mnist_acc = test_mlp(model, mnist_test_loader)
        mnist_accuracies.append(mnist_acc)

        fmnist_acc = test_mlp(model, fashion_mnist_test_loader)
        fashion_mnist_accuracies.append(fmnist_acc)

        logging.info(f"Generated { generate_type } labeled MLP {i} accuracies: MNIST: {mnist_acc:.2f}%  Fashion MNIST: {fmnist_acc:.2f}%")
    # Summary over the n_samples models generated for this label value.
    print(f"--- Generated type {generate_type} stats: avg MNIST: {np.mean(mnist_accuracies):.2f}%, avg Fashion MNIST: {np.mean(fashion_mnist_accuracies):.2f}% ---")
    

# Width projection:

In [None]:
import torch
import numpy as np
from sklearn.decomposition import PCA

# gc = torch.cuda.empty_cache if torch.cuda.is_available() else lambda: None

def collect_all_neurons(
    diffusion_model,
    scheduler,
    gae,
    template,
    class_label,
    latent_dim,
    N_nodes,
    num_samples: int,
    n_in: int,
    n_hid: int,
    n_out: int,
    device: torch.device
) -> (np.ndarray, np.ndarray):
    """
    Draw `num_samples` latent graphs one at a time, decode each into an MLP, and
    gather one joint parameter vector per hidden neuron of both hidden layers.

    Vector layout per neuron:
      - hidden layer 1: [row of fc1.weight, fc1.bias entry, column of fc2.weight]
      - hidden layer 2: [row of fc2.weight, fc2.bias entry, column of fc3.weight]

    `n_in` and `n_out` are accepted for interface symmetry but are not read here;
    only `n_hid` controls how many neurons are taken from each decoded MLP.

    Returns:
      (L1, L2): two arrays stacking the layer-1 and layer-2 neuron vectors, each
      with `num_samples * n_hid` rows.
    """
    # Both networks are put in eval mode before sampling/decoding.
    diffusion_model.eval()
    gae.eval()

    first_layer_rows = []
    second_layer_rows = []
    hidden_units = range(n_hid)

    for _ in range(num_samples):
        # Sample exactly one latent graph and keep its first (only) entry.
        latents = sample_latents_2d_cond(
            diffusion_model,
            scheduler,
            num_graphs=1,
            N_nodes=N_nodes,
            latent_dim=latent_dim,
            device=device,
            class_label=class_label
        )
        z = latents[0].to(device)

        # Decode the latent into edge/node predictions on the template topology.
        flat_edge_attr = template.edge_attr.view(-1).to(device)
        edge_pred, node_pred = gae.decoder(z, template.edge_index, flat_edge_attr)

        # Turn the decoded graph back into an MLP.
        decoded_graph = Data(
            x=node_pred.unsqueeze(1),
            edge_index=template.edge_index,
            edge_attr=edge_pred.unsqueeze(1),
        )
        mlp = vgae_to_mlp(decoded_graph)

        # Pull the parameters out as NumPy arrays (fc3.bias is not needed).
        params = (mlp.fc1.weight, mlp.fc1.bias, mlp.fc2.weight, mlp.fc2.bias, mlp.fc3.weight)
        W1, b1, W2, b2, W3 = (p.data.cpu().numpy() for p in params)

        # incoming weights, bias, outgoing weights for every hidden neuron
        first_layer_rows.extend(
            np.concatenate([W1[k], [b1[k]], W2[:, k]]) for k in hidden_units
        )
        second_layer_rows.extend(
            np.concatenate([W2[k, :], [b2[k]], W3[:, k]]) for k in hidden_units
        )

    return np.stack(first_layer_rows, axis=0), np.stack(second_layer_rows, axis=0)


In [None]:
# Draw a fresh batch of class-conditioned flat weight vectors and show its shape.
logging.info("Generating new MLP weights...")
n_samples = 10
generate_type = 1 # 0 MNIST, 1 Fashion MNIST

# Source noise scaled by the std the CFM object was configured with.
source_std = cfm.source_std
random_flat = torch.randn((n_samples, flat_dim), device=device) * source_std

# One integer class label per sample.
class_labels = torch.full((n_samples,), generate_type, dtype=torch.long, device=device)

new_weights_flat = cfm.map(random_flat, class_labels, n_steps=100, noise_scale=0.001)
new_weights_flat.shape

In [None]:
# flat_mnist_target_weights.shape[1]
# device

In [None]:
import torch
import numpy as np
from sklearn.decomposition import PCA

def build_distribution(cfm, n_samples = 1, generate_type = 0, n_steps = 100, flat_dim = 7690, device = device):
    """
    Sample `n_samples` MLPs from the flow model and collect one joint parameter
    vector per hidden neuron, for both hidden layers, across ALL samples.

    Vector layout per neuron:
      - hidden layer 1: [row of W1 (n_in), b1 entry (1), column of W2 (n_hid)]
      - hidden layer 2: [row of W2 (n_hid), b2 entry (1), column of W3 (n_out)]

    Parameters:
      cfm: object exposing `source_std` and `map(noise, class_labels, n_steps=, noise_scale=)`
      n_samples: number of weight vectors to generate
      generate_type: integer class label passed to the flow model
      n_steps: number of integration steps forwarded to `cfm.map`
      flat_dim: length of one flattened weight vector
      device: torch device used for sampling

    Returns:
      (L1, L2): arrays of shape [n_samples * n_hid, n_in + 1 + n_hid] and
      [n_samples * n_hid, n_hid + 1 + n_out].

    Raises:
      AssertionError: if a reconstructed weight-space object has an unexpected layout.
    """
    # model shape hyperparams
    n_in = 196
    n_hid = 32
    n_out = 10
    expected_weight_shapes = [(n_hid, n_in), (n_hid, n_hid), (n_out, n_hid)]
    expected_bias_shapes = [(n_hid,), (n_hid,), (n_out,)]

    # Sample source noise and integrate it through the flow model.
    source_std = cfm.source_std
    random_flat = torch.randn(n_samples, flat_dim, device=device) * source_std
    class_labels = torch.full([n_samples], generate_type, dtype=torch.long, device=device)

    new_weights_flat = cfm.map(
        random_flat,
        class_labels,
        n_steps=n_steps,
        noise_scale=source_std
    )

    layer1_neurons = []
    layer2_neurons = []
    for sample_idx in range(n_samples):
        new_wso = WeightSpaceObject.from_flat(
            new_weights_flat[sample_idx],
            layers=np.array(layer_layout),
            device=device
        )

        assert len(new_wso.weights) == 3, f"Expected 3 weight matrices, got {len(new_wso.weights)}"
        assert len(new_wso.biases) == 3, f"Expected 3 bias vectors, got {len(new_wso.biases)}"

        # Check each weight and bias shape
        for j, (w, expected_shape) in enumerate(zip(new_wso.weights, expected_weight_shapes)):
            assert w.shape == expected_shape, f"Weight {j} has shape {w.shape}, expected {expected_shape}"

        for j, (b, expected_shape) in enumerate(zip(new_wso.biases, expected_bias_shapes)):
            assert b.shape == expected_shape, f"Bias {j} has shape {b.shape}, expected {expected_shape}"

        # Read the arrays straight from the weight-space object; loading them into
        # an MLP first only to read them back out is unnecessary.
        W1 = new_wso.weights[0].detach().cpu().numpy()  # shape [n_hid, n_in]
        b1 = new_wso.biases[0].detach().cpu().numpy()   # shape [n_hid]
        W2 = new_wso.weights[1].detach().cpu().numpy()  # shape [n_hid, n_hid]
        b2 = new_wso.biases[1].detach().cpu().numpy()   # shape [n_hid]
        W3 = new_wso.weights[2].detach().cpu().numpy()  # shape [n_out, n_hid]

        # Row k of each block is [incoming weights, bias, outgoing weights] of neuron k;
        # transposing W2 / W3 turns their columns (outgoing weights) into rows.
        layer1_neurons.append(np.concatenate([W1, b1[:, None], W2.T], axis=1))
        layer2_neurons.append(np.concatenate([W2, b2[:, None], W3.T], axis=1))

    # Return only after every sample has been processed (previously the return sat
    # inside the loop, so only the first sample was ever collected).
    return np.concatenate(layer1_neurons, axis=0), np.concatenate(layer2_neurons, axis=0)


def build_mixeduse_from_two_pcas(
    pca1,
    mu1,
    cov1,
    pca2,
    mu2,
    cov2,
    M: int,
    n_in: int,
    orig_hid: int,
    n_out: int,
    init_type: str = 'xavier',
    seed: int = None,
    device: torch.device = torch.device('cpu')
) -> nn.Module:
    """
    Construct a new MLP with hidden dimension M by sampling neurons from two
    PCA+Gaussian models:
      - PCA1 models layer1 neuron joint distribution: size n_in + 1 + orig_hid
      - PCA2 models layer2 neuron joint distribution: size orig_hid + 1 + n_out

    Parameters:
      pca1, mu1, cov1: fitted PCA, latent mean, latent covariance for layer1
      pca2, mu2, cov2: fitted PCA, latent mean, latent covariance for layer2
      M: new hidden layer size
      n_in: input dimension
      orig_hid: original hidden size used when the PCAs were fitted
      n_out: output dimension
      init_type: forwarded to the MLP constructor
      seed: random seed for NumPy/torch (and forwarded to the MLP constructor)
      device: torch device

    Returns:
      An MLP (in eval mode) whose fc1/fc2/fc3 parameters are set from the sampled
      neurons; fc3.bias is zeroed.
    """
    # Seed for reproducibility
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    # Sample PCA latents and invert to full vectors.
    # np.atleast_2d guards the single-component case, where np.cov returns a 0-d
    # array that multivariate_normal would reject.
    Z1_new = np.random.multivariate_normal(mu1, np.atleast_2d(cov1), size=M)
    neuron1 = pca1.inverse_transform(Z1_new)
    Z2_new = np.random.multivariate_normal(mu2, np.atleast_2d(cov2), size=M)
    neuron2 = pca2.inverse_transform(Z2_new)

    # Split neuron parameter vectors
    W1_new = neuron1[:, :n_in]
    b1_new = neuron1[:, n_in]
    W2_out = neuron1[:, n_in+1:]

    W2_in = neuron2[:, :orig_hid]
    b2_new = neuron2[:, orig_hid]
    W3_out = neuron2[:, orig_hid+1:]

    # Build target network
    mlp = MLP(hidden_dim=M, init_type=init_type, seed=seed).to(device)
    mlp.eval()

    with torch.no_grad():
        # Layer1 parameters, rescaled by the He-style factor sqrt(2 / fan_in)
        mlp.fc1.weight.copy_( torch.from_numpy(W1_new).float().to(device) * np.sqrt(2./ n_in) )
        mlp.fc1.bias.copy_( torch.from_numpy(b1_new).float().to(device) )

        # Layer2 (fc2): build a linking matrix from the two sampled views of fc2:
        # 1) W2_out: outgoing projections from new layer1 neurons to old layer2 space ([M, orig_hid])
        # 2) W2_in: incoming projections to new layer2 neurons from old layer1 space ([M, orig_hid])
        # Both are lifted to [M, M] with the first `orig_hid` rows of a random
        # orthonormal matrix Q and then averaged.
        A = torch.randn(M, M, device=device)
        # torch.linalg.qr replaces the deprecated torch.qr; the R factor is not needed.
        Q, _ = torch.linalg.qr(A)
        # ensure proper orientation
        if torch.det(Q) < 0:
            Q[:, 0] *= -1

        # Project W2_out / W2_in (M x orig_hid) into the M-dimensional basis
        proj_out = torch.from_numpy(W2_out).float().to(device) @ Q[:orig_hid, :]
        proj_in  = torch.from_numpy(W2_in).float().to(device) @ Q[:orig_hid, :]

        # Average the two projections for alignment
        aligned = 0.5 * (proj_out + proj_in) * np.sqrt(2. / M)
        mlp.fc2.weight[:aligned.size(0), :aligned.size(1)].copy_(aligned)

        # Bias for fc2
        mlp.fc2.bias.copy_(torch.from_numpy(b2_new).float().to(device))

        # Output layer (fc3): sampled outgoing weights, transposed to [n_out, M]
        mlp.fc3.weight.copy_(torch.from_numpy(W3_out.T * np.sqrt(2./ M)).float().to(device))
        mlp.fc3.bias.zero_()

    return mlp

def first_epoch_training(model, train_loader, epochs=1):
    """
    Train `model` in place with Adam (lr=1e-3) and cross-entropy loss.

    Parameters:
      model: module to optimize; it is switched to train mode and mutated in place
      train_loader: iterable yielding (inputs, targets) batches
      epochs: number of passes over `train_loader`

    Returns:
      (model, losses): the same model object and the list of per-step loss values.
    """
    loss_fn = nn.CrossEntropyLoss()
    adam = torch.optim.Adam(model.parameters(), lr=1e-3)
    step_losses = []

    model.train()
    for _ in range(epochs):
        for inputs, labels in train_loader:
            adam.zero_grad()
            batch_loss = loss_fn(model(inputs), labels)
            batch_loss.backward()
            adam.step()
            # The loss value was fixed by the forward pass, so it can be recorded after the step.
            step_losses.append(batch_loss.item())

    return model, step_losses

In [None]:
# Fit one whitened PCA per hidden layer on the sampled neuron vectors and summarize
# the resulting latents by their mean and covariance.
n_samples = 20
class_label = 0 # MNIST

# 2a) collect per-neuron vectors for both hidden layers
L1, L2 = build_distribution(cfm, n_samples=n_samples, generate_type=class_label)

var_accounted = 0.99


def _fit_latent_gaussian(neuron_vectors, explained_variance):
    """Fit a whitened PCA and return (pca, latents, latent mean, latent covariance)."""
    pca = PCA(n_components=explained_variance, whiten=True)
    latents = pca.fit_transform(neuron_vectors)
    return pca, latents, latents.mean(axis=0), np.cov(latents, rowvar=False)


# 2b) PCA on layer1
pca1, Z1, mu1, cov1 = _fit_latent_gaussian(L1, var_accounted)

# 2c) PCA on layer2
pca2, Z2, mu2, cov2 = _fit_latent_gaussian(L2, var_accounted)
print(f"Done!")

In [None]:
# Inspect the shape of the layer-2 PCA latents (rows = neuron vectors, columns = kept components).
Z2.shape

In [None]:
# Single-run comparison of a PCA-sampled initialization against a fresh Xavier one.
# TODO: repeat over many samples to test whether the sampled init is robustly better.

# NOTE(review): this is the MNIST *test* loader, and it is used below for both the
# one-epoch training and every accuracy measurement, so the "after 1 epoch" numbers
# are measured on the data the model was just trained on.
weird_loader = mnist_test_loader

My_M = 64 # arbitrary width larger than the original hidden size (kept small for a slow machine)
longshot = build_mixeduse_from_two_pcas(
    pca1, mu1, cov1,
    pca2, mu2, cov2,
    M = My_M,   # desired new hidden size
    n_in = 196, n_out = 10, orig_hid = 32,
    init_type='xavier', seed=None, device='cpu'
)

# Accuracy is measured before training because first_epoch_training returns the same
# model object it was given, trained in place.
initial_acc = test_mlp(longshot, weird_loader)
trained_dumb_recon, recon_hist = first_epoch_training(longshot, weird_loader)
print(f"Populated Model initial performance: {initial_acc:.2f}%")
print(f"Populated Model After 1 epoch: {test_mlp(trained_dumb_recon, weird_loader):.2f}%")

# Baseline: a freshly initialized MLP of the same width, with a seed drawn from 0..9.
# new_model = sample_first_mlp_from_batch(dumb)
old_model = MLP(hidden_dim = My_M, init_type='xavier', seed=np.random.randint(10))
random_acc = test_mlp(old_model, weird_loader)
trained_dumb, hist = first_epoch_training(old_model, weird_loader)
print(f"Fresh Initialization initial performance: {random_acc:.2f}%")
print(f"Fresh Initialization After 1 epoch: {test_mlp(trained_dumb, weird_loader):.2f}%")

In [None]:
# Per-step training loss of the random vs. the sampled initialization.
# np.arange gives the integer step indices 0..N-1; np.linspace(0, N, N) spaced the
# points N/(N-1) apart, so they did not line up with the step numbers on the axis.
x = np.arange(len(hist))
plt.figure(figsize=(8, 6), dpi=80)
plt.scatter(x, hist, label="Random Initialization")
plt.scatter(x, recon_hist, label="Sampled Initialization")
plt.xlabel("Optimizer Step no. in epoch 1")
plt.ylabel("Loss per step")
plt.legend()
# plt.savefig(f"MLP_Diffusion_figure.png", format="png")
plt.show()


In [None]:
# Same comparison as the full loss plot, restricted to the first 50 optimizer steps.
lil_hist = hist[:50]
lil_recon_hist = recon_hist[:50]
# np.arange gives the integer step indices 0..N-1; np.linspace(0, N, N) spaced the
# points N/(N-1) apart, so they did not line up with the step numbers on the axis.
x = np.arange(len(lil_hist))
plt.figure(figsize=(8, 6), dpi=80)
plt.scatter(x, lil_hist, label="Random Initialization")
plt.scatter(x, lil_recon_hist, label="Sampled Initialization")
plt.xlabel("Optimizer Step no. in epoch 1")
plt.ylabel("Loss per step")
plt.legend()
# plt.savefig(f"MLP_Diffusion_figure.png", format="png")
plt.show()