In [1]:
from __future__ import annotations

import copy
import logging
import os
import re
import textwrap
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager, nullcontext
from typing import Any, Optional, Union
import sys

import torch
from accelerate import init_empty_weights
from accelerate.hooks import AlignDevicesHook
from accelerate.utils import named_module_tensors, offload_state_dict
from torch import nn
from transformers import PreTrainedModel
from transformers.pytorch_utils import Conv1D

from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.dirichlet import Dirichlet
from torch.distributions.wishart import Wishart
from torch.distributions.gamma import Gamma


import torch.nn.functional as F
import threading
import queue
import time

# Numerical floor used by MonteCLoRASampler: minimum value when clipping the
# sampled variance diagonal, and the replacement value for NaNs in the noise.
eps = 1e-3
class AsyncMonteCLoRASampler(threading.Thread):
    """Daemon thread that pre-generates noise samples for a MonteCLoRA sampler.

    The thread keeps a bounded queue topped up with dictionaries containing:

    * ``z_mvn``: standard-normal noise of shape
      ``(num_experts, in_features, out_features)``;
    * ``z_wishart``: a Bartlett-sampled Wishart draw with identity scale and
      ``df = out_features``;
    * ``z_dirichlet``: standard-normal noise of length ``num_experts``.

    Args:
        model: Object exposing ``num_experts``, ``in_features`` and
            ``out_features`` attributes.
        buffer_size: Maximum number of samples kept in the queue.
        device: Device on which the noise tensors are created.
    """

    def __init__(self, model, buffer_size=10, device='cpu'):
        super().__init__(daemon=True)
        self.model = model
        self.device = device
        self.buffer_size = buffer_size
        self.queue = queue.Queue(maxsize=buffer_size)
        self.running = True

    def _draw_sample(self):
        """Create one noise sample (no autograd graph is recorded)."""
        model = self.model
        with torch.no_grad():
            z_mvn = torch.randn(
                (model.num_experts, model.in_features, model.out_features),
                device=self.device,
            )
            wishart_sampler = Wishart(
                df=model.out_features,
                scale_tril=torch.eye(model.out_features, device=self.device),
            )
            # NOTE: _bartlett_sampling is a private torch API; it is kept so the
            # sampling procedure stays exactly as before.
            z_wishart = wishart_sampler._bartlett_sampling(torch.Size())
            z_dirichlet = torch.randn(model.num_experts, device=self.device)
        return {
            'z_mvn': z_mvn,
            'z_wishart': z_wishart,
            'z_dirichlet': z_dirichlet,
        }

    def run(self):
        """Keep the queue filled until ``running`` is set to False."""
        while self.running:
            if self.queue.qsize() >= self.buffer_size:
                time.sleep(0.01)
                continue
            try:
                self.queue.put_nowait(self._draw_sample())
            except queue.Full:
                # Lost a race with the size check; simply retry later.
                time.sleep(0.01)
            except KeyboardInterrupt:
                # End this worker only; sys.exit() in a thread would do no more.
                break
            except Exception as e:
                print(f"[AsyncMonteCLoRASampler] Error: {e}")
                # Back off so a persistent failure does not spin and flood stdout.
                time.sleep(0.1)

    def get(self):
        """Return the next buffered sample, or None if the queue is empty."""
        try:
            return self.queue.get_nowait()
        except queue.Empty:
            return None

    def stop(self):
        """Ask the worker to finish and wait up to one second for it.

        Safe to call on a thread that was never started or from the worker
        itself (joining in either case would raise ``RuntimeError``).
        """
        self.running = False
        if self.is_alive() and threading.current_thread() is not self:
            self.join(timeout=1.0)

class MonteCLoRASampler(nn.Module):
    """Monte Carlo noise sampler used for LoRA expert perturbations.

    While the module is in training mode and ``mc_training`` is enabled,
    ``forward`` returns scaled Gaussian noise of shape
    ``(num_experts, in_features, out_features)`` together with uniform expert
    weights. The noise variance is the diagonal of
    ``diag(exp(std_prior)) @ W @ diag(exp(std_prior))`` where ``W`` is a Wishart
    draw. Raw noise is pre-generated by an ``AsyncMonteCLoRASampler`` thread and
    generated synchronously whenever the buffer is empty.

    Args:
        in_features: Second dimension of the returned noise tensor.
        out_features: Last dimension of the noise tensor and size of the
            learnable ``std_prior`` vector.
        num_experts: First dimension of the returned noise tensor.
        fan_in_fan_out: Stored on the module; not used by this class.
        use_entropy: Whether ``get_variational_loss`` computes the entropy term.
        dirichlet_prior: Concentration of the symmetric Dirichlet prior.
        sample_scaler: Multiplier applied to the sampled noise.
        kl_loss_weight: Multiplier applied to the summed KL terms.
        mc_training: Enables sampling; when False ``forward`` returns ``(-1, -1)``.
        buffer_size: Queue size of the background sampler thread.
        device: Device for noise tensors and module state.
    """

    def __init__(
        self,
        in_features,
        out_features,
        num_experts,
        fan_in_fan_out=False,
        use_entropy=True,
        dirichlet_prior=1,
        sample_scaler=3e-4,
        kl_loss_weight=1e-5,
        mc_training=True,
        buffer_size=100,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_experts = num_experts
        self.fan_in_fan_out = fan_in_fan_out
        self.use_entropy = use_entropy
        self.device = device
        self.sample_scaler = sample_scaler
        self.kl_loss_weight = kl_loss_weight
        self.mc_training = mc_training
        self.dirichlet_prior = dirichlet_prior
        self.buffer_size = buffer_size

        # Drawn on CPU (same random stream as before) and then moved, so that
        # forward() works on `device` even if the caller never calls .to(device).
        self.std_prior = nn.Parameter(torch.rand(out_features).to(device))
        # self.expert_weights_prior = nn.Parameter(torch.rand(num_experts))
        self.gaussian_var_prior = torch.eye(out_features).to(device)
        self.expert_weights = torch.ones(num_experts, device=device) / num_experts

        self.sampler = None
        if self.mc_training:
            self._start_sampler()

    def _start_sampler(self):
        """Create and start the background noise thread."""
        self.sampler = AsyncMonteCLoRASampler(self, self.buffer_size, self.device)
        self.sampler.start()

    def _stop_sampler(self):
        """Stop the background noise thread, if there is one."""
        sampler = getattr(self, 'sampler', None)
        if sampler is not None:
            sampler.stop()
        self.sampler = None

    def _draw_noise(self):
        """Synchronously draw ``(z_mvn, z_wishart)`` when no buffered sample exists."""
        z_mvn = torch.randn(
            (self.num_experts, self.in_features, self.out_features), device=self.device
        )
        wishart_sampler = Wishart(
            df=self.out_features,
            scale_tril=torch.eye(self.out_features, device=self.device),
        )
        z_wishart = wishart_sampler._bartlett_sampling(torch.Size())
        return z_mvn, z_wishart

    def wishart_reparameterization(self, std, z_wishart):
        """Return ``diag(clip(diag(std @ z_wishart @ std.T), min=eps))``."""
        updated_var = std @ z_wishart @ std.T
        return torch.diag(torch.clip(updated_var.diag(), min=eps))

    def multivariate_reparameterization(self, z_mvn, cov_matrix):
        """Colour ``z_mvn`` with the Cholesky factor of ``cov_matrix`` and scale it.

        NaNs in the result are replaced by ``eps`` before the result is
        multiplied by ``sample_scaler``.
        """
        chol = torch.linalg.cholesky(cov_matrix).to(cov_matrix.dtype)
        varsum = torch.nan_to_num(z_mvn @ chol, nan=eps)
        return self.sample_scaler * varsum

    def dirichlet_reparameterization(self, alpha, z_dirichlet):
        """Map standard-normal ``z_dirichlet`` to logits using a Gaussian built from ``alpha``."""
        log_alpha = torch.log(alpha)
        mu = log_alpha - log_alpha.mean()
        inv_alpha = 1 / alpha
        sigma = torch.diag(
            inv_alpha * (1 - 2 / self.num_experts)
            + 1 / (self.num_experts ** 2) * inv_alpha.sum()
        )
        chol = torch.linalg.cholesky(sigma)
        return chol @ z_dirichlet + mu

    def calculate_entropy(self, expert_weights):
        """Return the sum of squared expert weights."""
        return (expert_weights ** 2).sum()

    def dirichlet_kl(self, alpha2):
        """KL( Dir(alpha2) || Dir(alpha1) ) with ``alpha1`` the symmetric prior.

        Computed in log-space with ``lgamma`` (exponentiating overflows for
        moderately large concentrations). The per-component log-gamma term
        enters with a negative sign for ``alpha2`` and positive for ``alpha1``.
        """
        alpha1 = torch.full_like(alpha2, float(self.dirichlet_prior))
        log_normalizer = (
            torch.lgamma(alpha2.sum()) - torch.lgamma(alpha2).sum()
            - torch.lgamma(alpha1.sum()) + torch.lgamma(alpha1).sum()
        )
        expectation = (
            (alpha2 - alpha1) * (torch.digamma(alpha2) - torch.digamma(alpha2.sum()))
        ).sum()
        return log_normalizer + expectation

    def wishart_kl(self, std):
        """KL-style penalty on the diagonal of ``std @ std.T``.

        Returns ``0.5 * p * (tr(V) - tr(log V)) - 0.5 * p**2`` with
        ``p = out_features`` and ``V`` the diagonal part of ``std @ std.T``.
        """
        var = std @ std.T
        var = torch.diag(var.diag())
        p = self.out_features
        return 0.5 * (-torch.log(var).trace() * p + var.trace() * p - p ** 2)

    def multivariate_kl(self, var):
        """Return ``num_experts * 0.5 * (tr(V) - tr(log V) - p)`` for ``V = clamp(var, 1e-6)``."""
        var = torch.clamp(var, min=1e-6)
        return self.num_experts * 0.5 * (var.trace() - torch.log(var).trace() - self.out_features)

    def get_variational_loss(self):
        """Return ``(weighted KL, entropy)``; ``(0, 0)`` outside MC training."""
        if not (self.mc_training and self.training):
            return 0, 0
        kl1 = 0  # self.dirichlet_kl(torch.exp(self.expert_weights_prior))
        kl2 = self.wishart_kl(torch.diag(torch.exp(self.std_prior)))
        kl3 = self.multivariate_kl(self.gaussian_var_prior)
        entropy = self.calculate_entropy(self.expert_weights) if self.use_entropy else 0
        return self.kl_loss_weight * (kl1 + kl2 + kl3), entropy

    def forward(self):
        """Return ``(noise, expert_weights)`` in MC training mode, else ``(-1, -1)``."""
        if not (self.training and self.mc_training):
            return -1, -1

        sample = self.sampler.get() if self.sampler is not None else None
        if sample is None:
            # Buffer empty (or no background thread): sample synchronously.
            z_mvn, z_wishart = self._draw_noise()
        else:
            z_mvn, z_wishart = sample['z_mvn'], sample['z_wishart']

        std = torch.diag(torch.exp(self.std_prior))
        gaussian_var = self.wishart_reparameterization(std, z_wishart)
        # Remembered so get_variational_loss() can penalise the latest variance.
        self.gaussian_var_prior = gaussian_var

        var = self.multivariate_reparameterization(z_mvn, gaussian_var)
        return var, self.expert_weights

    def train(self, mode: bool = True):
        """Switch mode and manage the background thread accordingly.

        ``nn.Module.eval()`` and parent modules switch mode through
        ``train(mode)``, so the thread is stopped here rather than only in
        ``eval``; it is started again when training resumes.
        """
        if mode:
            if self.mc_training and getattr(self, 'sampler', None) is None:
                self._start_sampler()
        else:
            self._stop_sampler()
        return super().train(mode)

    def eval(self):
        """Enter evaluation mode, stop the sampler thread and return ``self``."""
        return self.train(False)

    def __del__(self):
        if hasattr(self, 'sampler') and self.sampler:
            self.sampler.stop()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import time

# Build ten independent samplers on one device. "cuda:3" is the GPU used for the
# original run (use cuda:<local_rank> under DDP); fall back to the first GPU or
# the CPU so the cell also runs on machines with fewer devices.
if torch.cuda.is_available():
    device = torch.device("cuda:3" if torch.cuda.device_count() > 3 else "cuda:0")
else:
    device = torch.device("cpu")

# NOTE: despite its name, sampler_list is a dict keyed by "sampler_<i>".
sampler_list = {f"sampler_{i}": MonteCLoRASampler(10, 10, 10, device=device) for i in range(10)}
for sampler in sampler_list.values():
    sampler.to(device)
    sampler.train()

In [5]:


if torch.cuda.is_available():
    sampler_items = list(sampler_list.items())  # (name, sampler) pairs
    num_samplers = len(sampler_items)
    # One side stream per sampler so their kernels can overlap.
    streams = [torch.cuda.Stream(device=device) for _ in range(num_samplers)]
    outputs = [None] * num_samplers  # filled in launch order
    launch_stream = torch.cuda.current_stream(device)

    start_time = time.perf_counter()

    for i, ((_, sampler), stream) in enumerate(zip(sampler_items, streams)):
        # Parameters and buffered noise were produced on the launching stream;
        # make the side stream wait for that work before reading them.
        stream.wait_stream(launch_stream)
        # Kernels launched by forward() inside this context are queued on `stream`.
        with torch.cuda.stream(stream):
            outputs[i] = sampler.forward()  # (var, weights) tuple

    # Block the host until every side stream has finished.
    # (torch.cuda.synchronize(device=device) would wait for the whole device.)
    for stream in streams:
        stream.synchronize()

    end_time = time.perf_counter()
    print(f"CUDA Streams took: {end_time - start_time:.4f} seconds")

    # Map each sampler name to its output now that the results are ready.
    results = {name: output for (name, _), output in zip(sampler_items, outputs)}

    # Example access
    # var_0, weights_0 = results['sampler_0']
    # if isinstance(var_0, torch.Tensor):  # forward() returns -1 outside training
    #     print(f"Sampler 0 var shape: {var_0.shape}")
    #     print(f"Sampler 0 weights: {weights_0}")

else:
    print("CUDA not available, skipping CUDA Streams example.")
    # Fallback to sequential execution or ThreadPoolExecutor if desired

# Cleanup (run once the samplers are no longer needed): eval() stops each
# sampler's background thread.
# for sampler in sampler_list.values():
#     sampler.eval()

CUDA Streams took: 0.0433 seconds


In [14]:
# Keep all previous imports and the definitions of:
# eps, AsyncMonteCLoRASampler, MonteCLoRASampler

from __future__ import annotations

import copy
import logging
import os
import re
import textwrap
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager, nullcontext
from typing import Any, Optional, Union, Dict, Tuple, List
import sys
import itertools
import concurrent.futures

import torch
# Import cuda specifics if available
if torch.cuda.is_available():
    import torch.cuda
from accelerate import init_empty_weights
from accelerate.hooks import AlignDevicesHook
from accelerate.utils import named_module_tensors, offload_state_dict
from torch import nn
from transformers import PreTrainedModel
from transformers.pytorch_utils import Conv1D

from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.dirichlet import Dirichlet
from torch.distributions.wishart import Wishart
from torch.distributions.gamma import Gamma


import torch.nn.functional as F
import threading
import queue
import time

# ============================================================================
# Revised definitions of eps, AsyncMonteCLoRASampler and MonteCLoRASampler.
# Running this cell replaces (shadows) the versions defined in the first cell.
# ============================================================================
# Float32 machine epsilon, used below as: the floor when clamping variances, the
# NaN replacement value, the offset inside log(var + eps), and the eps argument
# of F.cosine_similarity.
eps = torch.finfo(torch.float32).eps

# --- AsyncMonteCLoRASampler Definition ---
class AsyncMonteCLoRASampler(threading.Thread):
    """Daemon thread that pre-generates noise samples for a MonteCLoRA sampler.

    Samples are dictionaries with keys ``z_mvn``, ``z_wishart`` and
    ``z_dirichlet``. If sampling raises, the exception is stored in
    ``_fill_exception``; production pauses until that attribute is reset to
    None, and ``get`` re-raises the failure as ``RuntimeError``.

    Args:
        model: Object exposing ``num_experts``, ``in_features`` and
            ``out_features`` attributes.
        buffer_size: Maximum number of samples kept in the queue.
        device: Device on which the noise tensors are created.
    """

    def __init__(self, model, buffer_size=10, device='cpu'):
        super().__init__(daemon=True)
        self.model = model
        self.device = device
        self.buffer_size = buffer_size
        self.queue = queue.Queue(maxsize=buffer_size)
        self.running = True
        self._fill_exception = None

    def _draw_sample(self):
        """Create one noise sample (no autograd graph is recorded)."""
        model = self.model
        n_out = model.out_features
        with torch.no_grad():
            z_mvn = torch.randn(
                (model.num_experts, model.in_features, n_out),
                device=self.device,
            )
            if n_out == 0:
                # A 0x0 Wishart cannot be sampled; use an empty matrix instead.
                z_wishart = torch.empty((0, 0), device=self.device)
            else:
                # df = p satisfies the Wishart requirement df > p - 1.
                df_tensor = torch.tensor(float(n_out), device=self.device)
                scale_tril = torch.eye(n_out, device=self.device)
                z_wishart = Wishart(df=df_tensor, scale_tril=scale_tril).sample()
            z_dirichlet = torch.randn(model.num_experts, device=self.device)
        return {
            'z_mvn': z_mvn,
            'z_wishart': z_wishart,
            'z_dirichlet': z_dirichlet,
        }

    def run(self):
        """Keep the queue filled until ``running`` is set to False."""
        while self.running:
            if self._fill_exception is not None:
                # A previous failure is pending; idle until the owner clears it.
                time.sleep(0.1)
                continue
            if self.queue.qsize() >= self.buffer_size:
                time.sleep(0.01)
                continue
            try:
                self.queue.put_nowait(self._draw_sample())
            except queue.Full:
                time.sleep(0.005)
            except Exception as e:
                # KeyboardInterrupt is not an Exception subclass, so it is not
                # handled here (the former isinstance check could never match).
                label = "Runtime Error" if isinstance(e, RuntimeError) else "Error"
                print(f"[AsyncMonteCLoRASampler] {label} during sampling: {e}")
                self._fill_exception = e
                time.sleep(0.1)

    def get(self):
        """Return the next buffered sample, or None if the queue is empty.

        Raises:
            RuntimeError: If the worker recorded a sampling failure.
        """
        if self._fill_exception:
            raise RuntimeError(f"Async sampler encountered an error: {self._fill_exception}") from self._fill_exception
        try:
            return self.queue.get_nowait()
        except queue.Empty:
            return None

    def stop(self):
        """Stop the worker, wait up to two seconds for it, then empty the queue.

        Joining happens before draining so the worker cannot refill the queue
        afterwards. The join is skipped for a thread that was never started and
        when called from the worker itself (both would raise ``RuntimeError``).
        """
        self.running = False
        if self.is_alive() and threading.current_thread() is not self:
            self.join(timeout=2.0)
        while True:
            try:
                self.queue.get_nowait()
            except queue.Empty:
                break

    def __del__(self):
        if self.is_alive():
            self.stop()

# --- MonteCLoRASampler Definition ---
class MonteCLoRASampler(nn.Module):
    """Monte Carlo noise sampler with a diagonal Gaussian covariance.

    In training mode with ``mc_training`` enabled, ``forward`` returns scaled
    Gaussian noise of shape ``(num_experts, in_features, out_features)`` and
    uniform expert weights. The per-column variance is the clamped diagonal of
    ``diag(exp(log_std_prior)) @ W @ diag(exp(log_std_prior))`` for a Wishart
    draw ``W``. Raw noise comes from an ``AsyncMonteCLoRASampler`` thread, with
    a synchronous fallback when its queue is empty.

    Args:
        in_features: Second dimension of the returned noise tensor.
        out_features: Last dimension of the noise tensor and size of
            ``log_std_prior``.
        num_experts: First dimension of the returned noise tensor.
        fan_in_fan_out: Stored on the module; not used by this class.
        use_entropy: Whether ``get_variational_loss`` computes the entropy term.
        dirichlet_prior: Stored as a float; the Dirichlet terms are placeholders.
        sample_scaler: Multiplier applied to the sampled noise.
        kl_loss_weight: Multiplier applied to the summed KL terms.
        mc_training: Enables sampling; when False ``forward`` returns ``(-1, -1)``.
        buffer_size: Queue size of the background sampler thread.
        device: Device for parameters, buffers and noise tensors.
    """

    def __init__(
        self,
        in_features,
        out_features,
        num_experts,
        fan_in_fan_out=False,
        use_entropy=True,
        dirichlet_prior=1.0,
        sample_scaler=3e-4,
        kl_loss_weight=1e-5,
        mc_training=True,
        buffer_size=10,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_experts = num_experts
        # Kept as an attribute so code that reads it keeps working.
        self.fan_in_fan_out = fan_in_fan_out
        self.use_entropy = use_entropy
        self.device = device
        self.sample_scaler = sample_scaler
        self.kl_loss_weight = kl_loss_weight
        self.mc_training = mc_training
        self.dirichlet_prior = float(dirichlet_prior)
        self.buffer_size = buffer_size

        self.log_std_prior = nn.Parameter(torch.randn(out_features, device=device))
        # Non-persistent buffers: follow .to(device) but stay out of state_dict.
        self.register_buffer('last_gaussian_var', torch.eye(out_features, device=device), persistent=False)
        self.register_buffer('last_expert_weights', torch.ones(num_experts, device=device) / num_experts, persistent=False)

        self.sampler = None
        self._start_sampler()
        self.to(device)

    def _start_sampler(self):
        """Start a background sampler if MC training is active and none is alive."""
        if self.mc_training and self.training and (self.sampler is None or not self.sampler.is_alive()):
            self.sampler = AsyncMonteCLoRASampler(self, self.buffer_size, self.device)
            self.sampler.start()

    def _stop_sampler(self):
        """Stop the background sampler (if alive) and drop the reference."""
        if hasattr(self, 'sampler') and self.sampler and self.sampler.is_alive():
            self.sampler.stop()
        self.sampler = None

    def train(self, mode: bool = True):
        """Switch mode, then start or stop the background sampler to match.

        The mode is updated first because ``_start_sampler`` checks
        ``self.training``; doing it the other way round leaves the sampler
        stopped after an ``eval()`` / ``train()`` round trip.
        """
        # Clear potential error from previous runs when changing mode
        if hasattr(self, 'sampler') and self.sampler:
            self.sampler._fill_exception = None
        result = super().train(mode)
        if mode:
            self._start_sampler()
        else:
            self._stop_sampler()
        return result

    def eval(self):
        """Enter evaluation mode (stops the sampler) and return ``self``."""
        return self.train(False)

    def wishart_reparameterization(self, log_std, z_wishart):
        """Return the clamped diagonal covariance and cache it in ``last_gaussian_var``."""
        if self.out_features == 0:
            return torch.empty((0, 0), device=self.device)
        # Clamp the log-std so exp() cannot overflow or underflow.
        std = torch.diag(torch.exp(torch.clamp(log_std, min=-10, max=10)))
        updated_var = std @ z_wishart @ std
        updated_var_diag = torch.diag(torch.clamp(updated_var.diag(), min=eps))
        self.last_gaussian_var = updated_var_diag
        return updated_var_diag

    def multivariate_reparameterization(self, z_mvn, cov_matrix_diag):
        """Scale ``z_mvn`` column-wise by the square root of the covariance diagonal."""
        if self.out_features == 0:
            return torch.zeros_like(z_mvn)  # Return zeros if no output features
        std_dev = torch.sqrt(torch.diag(cov_matrix_diag))
        # Shape (1, 1, out_features) broadcasts over experts and input features.
        varsum = z_mvn * std_dev.unsqueeze(0).unsqueeze(1)
        varsum = torch.nan_to_num(varsum, nan=eps)
        return self.sample_scaler * varsum

    def dirichlet_reparameterization(self, log_alpha, z_dirichlet):
        """Placeholder: ignore the inputs and return (and cache) uniform weights."""
        fixed_weights = torch.ones(self.num_experts, device=self.device) / self.num_experts
        self.last_expert_weights = fixed_weights
        return fixed_weights

    def calculate_entropy(self, expert_weights):
        """Return the sum of squared expert weights."""
        return (expert_weights ** 2).sum()

    def dirichlet_kl(self, log_alpha2):
        """Placeholder (not currently used): always returns a zero tensor."""
        return torch.tensor(0.0, device=self.device)

    def _diag_gaussian_kl(self, var):
        """Return ``0.5 * (sum(var) - sum(log(var + eps)) - out_features)``."""
        log_det_sigma = torch.sum(torch.log(var + eps))
        trace_sigma = torch.sum(var)
        return 0.5 * (trace_sigma - log_det_sigma - self.out_features)

    def wishart_kl(self, log_std):
        """KL term for the variance vector ``exp(2 * log_std)``."""
        if self.out_features == 0:
            return torch.tensor(0.0, device=self.device)
        return self._diag_gaussian_kl(torch.exp(2 * log_std))

    def multivariate_kl(self, cov_matrix_diag):
        """KL term for the diagonal of ``cov_matrix_diag``, times ``num_experts``."""
        if self.out_features == 0 or cov_matrix_diag.numel() == 0:
            return torch.tensor(0.0, device=self.device)
        return self.num_experts * self._diag_gaussian_kl(torch.diag(cov_matrix_diag))

    def get_variational_loss(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(kl_loss_weight * total KL, entropy)`` as tensors.

        Both values are zero tensors outside MC training; the entropy is a zero
        tensor when ``use_entropy`` is False.
        """
        if not (self.mc_training and self.training):
            return torch.tensor(0.0, device=self.device), torch.tensor(0.0, device=self.device)

        kl1 = torch.tensor(0.0, device=self.device)  # dirichlet_kl
        kl2 = self.wishart_kl(self.log_std_prior)
        kl3 = self.multivariate_kl(self.last_gaussian_var)
        total_kl = kl1 + kl2 + kl3
        if self.use_entropy:
            entropy = self.calculate_entropy(self.last_expert_weights)
        else:
            entropy = torch.tensor(0.0, device=self.device)
        return self.kl_loss_weight * total_kl, entropy

    def _draw_noise_sync(self):
        """Draw ``(z_mvn, z_wishart)`` on the calling thread."""
        with torch.no_grad():
            z_mvn = torch.randn((self.num_experts, self.in_features, self.out_features), device=self.device)
            if self.out_features > 0:
                df_tensor = torch.tensor(float(self.out_features), device=self.device)
                scale_tril = torch.eye(self.out_features, device=self.device)
                z_wishart = Wishart(df=df_tensor, scale_tril=scale_tril).sample()
            else:
                z_wishart = torch.empty((0, 0), device=self.device)
        return z_mvn, z_wishart

    def forward(self) -> Tuple[Union[torch.Tensor, int], Union[torch.Tensor, int]]:
        """Return ``(noise, expert_weights)``, or ``(-1, -1)`` on failure / outside MC training."""
        if not (self.training and self.mc_training):
            return -1, -1

        if self.sampler is None or not self.sampler.is_alive():
            # Try to recover if sampler died unexpectedly
            print(f"Warning: Sampler for {id(self)} not running in forward. Attempting restart.")
            self._start_sampler()
            if self.sampler is None or not self.sampler.is_alive():
                print(f"Error: Could not restart sampler for {id(self)}.")
                return -1, -1  # Indicate failure

        try:
            # An empty queue yields None, which triggers the synchronous
            # fallback below, so there is no need to wait for the worker.
            sample = self.sampler.get()
        except RuntimeError as e:  # Error propagated from the async sampler
            print(f"Error retrieving sample from async sampler {id(self)}: {e}")
            return -1, -1  # Indicate failure

        if sample is None:
            z_mvn, z_wishart = self._draw_noise_sync()
        else:
            z_mvn = sample['z_mvn']
            z_wishart = sample['z_wishart']

        gaussian_var_diag = self.wishart_reparameterization(self.log_std_prior, z_wishart)
        var = self.multivariate_reparameterization(z_mvn, gaussian_var_diag)
        # Use fixed expert weights for now
        expert_weights = torch.ones(self.num_experts, device=self.device) / self.num_experts
        self.last_expert_weights = expert_weights

        return var, expert_weights

    def __del__(self):
        self._stop_sampler()

# --- Updated MultiMonteCLoRASampler ---

class MultiMonteCLoRASampler(nn.Module):
    def __init__(
        self,
        num_outer: int,
        in_features: int,
        out_features: int,
        num_experts: int,
        fan_in_fan_out: bool = False,
        use_entropy: bool = True,
        dirichlet_prior: float = 1.0,
        sample_scaler: float = 3e-4,
        kl_loss_weight: float = 1e-5,
        cosine_similarity_weight: float = 1e-5,
        mc_training: bool = True,
        buffer_size: int = 10,
        device: str | torch.device = 'cuda' if torch.cuda.is_available() else 'cpu',
        use_cuda_streams: bool = True # New flag to control stream usage
    ):
        super().__init__()
        self.num_outer = num_outer
        self.mc_training = mc_training
        self.cosine_similarity_weight = cosine_similarity_weight
        self.kl_loss_weight = kl_loss_weight # Needed for individual samplers
        self.use_cuda_streams = use_cuda_streams

        if num_outer < 1:
            raise ValueError("num_outer must be at least 1")

        # Ensure device is a torch.device object
        self.device = torch.device(device)

        self.samplers = nn.ModuleList()
        for i in range(num_outer):
            sampler = MonteCLoRASampler(
                in_features=in_features,
                out_features=out_features,
                num_experts=num_experts,
                fan_in_fan_out=fan_in_fan_out,
                use_entropy=use_entropy,
                dirichlet_prior=dirichlet_prior,
                sample_scaler=sample_scaler,
                kl_loss_weight=kl_loss_weight, # Pass down for internal calculation
                mc_training=mc_training,
                buffer_size=buffer_size,
                device=self.device, # Use the unified device
            )
            self.samplers.append(sampler)

        # Initialize CUDA streams if applicable and requested
        self.streams = None
        if self.device.type == 'cuda' and self.use_cuda_streams and torch.cuda.is_available():
            # print("Initializing CUDA streams for MultiMonteCLoRASampler") # Debug
            self.streams = [torch.cuda.Stream(device=self.device) for _ in range(self.num_outer)]
        elif self.device.type == 'cuda' and not torch.cuda.is_available():
             print("Warning: Device specified as CUDA, but CUDA is not available. Falling back.")
             self.device = torch.device('cpu') # Fallback device

        self.to(self.device) # Ensure wrapper module itself is on the correct device


    def forward(self) -> Tuple[Union[torch.Tensor, int], Union[torch.Tensor, int]]:
        """
        Calls the forward method of all internal samplers (concurrently if possible)
        and returns the *mean* of their valid outputs.

        Returns:
            Tuple[Union[torch.Tensor, int], Union[torch.Tensor, int]]:
            A tuple containing:
                - mean_var (torch.Tensor): Mean of LoRA weight matrices delta_W across samplers.
                                           Shape: (num_experts, in_features, out_features).
                - mean_expert_weights (torch.Tensor): Mean of expert weights across samplers.
                                                      Shape: (num_experts,).
            Returns (-1, -1) if not in training mode, mc_training is False, or no
            samplers produce valid output.
        """
        if not self.training or not self.mc_training:
            return -1, -1

        valid_vars: List[torch.Tensor] = []
        valid_weights: List[torch.Tensor] = []
        outputs: List[Tuple[Union[torch.Tensor, int], Union[torch.Tensor, int]]] = [(None, None)] * self.num_outer

        try:
            # --- Execute forward passes concurrently ---
            if self.streams: # Use CUDA Streams
                # print("Using CUDA streams for forward pass...") # Debug
                for i, sampler in enumerate(self.samplers):
                    with torch.cuda.stream(self.streams[i]):
                        # Launch the forward pass onto the stream
                        outputs[i] = sampler.forward()
                        # Note: Execution is asynchronous here. Result tensors are not ready yet.

                # Synchronize all streams to wait for completion
                # torch.cuda.synchronize(self.device) # Waits for all kernels on device
                # Alternative: synchronize streams individually (might be slightly more efficient if streams finish at different times)
                for stream in self.streams:
                     stream.synchronize()
                # print("CUDA streams synchronized.") # Debug

                # Now collect results from the 'outputs' list
                for i in range(self.num_outer):
                    var, expert_weights = outputs[i]
                    if isinstance(var, torch.Tensor) and isinstance(expert_weights, torch.Tensor):
                        valid_vars.append(var)
                        valid_weights.append(expert_weights)
                    # else: print(f"Sampler {i} returned invalid output: {outputs[i]}") # Debug

            else: # Use ThreadPoolExecutor (CPU or no streams requested)
                # print("Using ThreadPoolExecutor for forward pass...") # Debug
                futures = {}
                # Limit workers? Or use num_outer? Let's use num_outer for max potential parallelism.
                with concurrent.futures.ThreadPoolExecutor(max_workers=self.num_outer) as executor:
                    for i, sampler in enumerate(self.samplers):
                        futures[i] = executor.submit(sampler.forward)

                    for i in range(self.num_outer):
                        try:
                            var, expert_weights = futures[i].result()
                            if isinstance(var, torch.Tensor) and isinstance(expert_weights, torch.Tensor):
                                valid_vars.append(var)
                                valid_weights.append(expert_weights)
                            # else: print(f"Sampler {i} returned invalid output: {(var, expert_weights)}") # Debug
                        except Exception as e:
                            print(f"Error collecting result from sampler {i} thread: {e}")
                            # Decide how to handle thread errors, e.g., continue without its result

        except RuntimeError as e:
             # Catch potential CUDA errors during stream operations or synchronization
             print(f"Runtime Error during multi-sampler forward pass: {e}")
             return -1, -1 # Indicate failure
        except Exception as e:
             print(f"Unexpected Error during multi-sampler forward pass: {e}")
             return -1, -1 # Indicate failure

        # --- Average the valid results ---
        if not valid_vars or not valid_weights:
            # print("No valid outputs received from samplers.") # Debug
            return -1, -1 # Return error indicator if no sampler succeeded

        # Stack tensors along a new dimension (dim=0) and compute the mean
        # Ensure all tensors are on the correct device before stacking (should be already)
        try:
            mean_var = torch.stack(valid_vars).mean(dim=0)
            mean_expert_weights = torch.stack(valid_weights).mean(dim=0)
        except RuntimeError as e:
             # Catch potential shape mismatches if samplers somehow returned different shapes
             print(f"Error averaging sampler outputs (check shapes?): {e}")
             # Optionally inspect shapes:
             # for i, v in enumerate(valid_vars): print(f"Var {i} shape: {v.shape}")
             # for i, w in enumerate(valid_weights): print(f"Weights {i} shape: {w.shape}")
             return -1, -1

        return mean_var, mean_expert_weights

    def get_variational_loss(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Combine the variational losses of all samplers in ``self.samplers``.

        For every sampler this reads its ``get_variational_loss()`` result
        (a ``(kl_loss, entropy)`` pair) and its ``last_gaussian_var`` tensor.
        A sampler only contributes if *all* of those steps succeed, so a
        sampler that fails half-way can never leak a partial contribution
        into the averages.

        Returns:
            ``(total_weighted_loss, mean_entropy)`` where
            ``total_weighted_loss = mean_kl + cosine_similarity_weight * mean_penalty``
            and ``mean_penalty`` is the average, over all sampler pairs, of the
            cosine similarity between their ``last_gaussian_var.diag()``
            vectors clamped to be non-negative.
            Both values are zero scalars on ``self.device`` when the module is
            not training, when ``self.mc_training`` is falsy, or when no
            sampler could be read.
        """
        if not self.training or not self.mc_training:
            return torch.tensor(0.0, device=self.device), torch.tensor(0.0, device=self.device)

        total_kl_loss = torch.tensor(0.0, device=self.device)
        total_entropy = torch.tensor(0.0, device=self.device)
        sampler_states = []  # last_gaussian_var of every sampler that contributed

        # --- 1. Aggregate KL and Entropy ---
        for sampler in self.samplers:
            try:
                kl_loss, entropy = sampler.get_variational_loss()
                # State from the sampler's last forward pass, needed for the
                # cosine-similarity penalty below.
                gaussian_var = sampler.last_gaussian_var
                # Stage the new totals out-of-place: if anything in this block
                # raises, the running totals are left exactly as they were.
                staged_kl = total_kl_loss + kl_loss
                staged_entropy = total_entropy + entropy
            except Exception as e:
                # Best effort: report and skip this sampler instead of crashing.
                print(f"Error getting loss from sampler {id(sampler)}: {e}")
                continue

            # Commit only after every step for this sampler succeeded, so the
            # totals and the divisor below always describe the same samplers.
            total_kl_loss, total_entropy = staged_kl, staged_entropy
            sampler_states.append(gaussian_var)

        if len(sampler_states) == 0:  # every sampler failed
            return torch.tensor(0.0, device=self.device), torch.tensor(0.0, device=self.device)

        mean_kl_loss = total_kl_loss / len(sampler_states)
        mean_entropy = total_entropy / len(sampler_states)

        # --- 2. Calculate Cosine Similarity Penalty ---
        total_cosine_similarity = torch.tensor(0.0, device=self.device)
        num_pairs = 0

        if len(sampler_states) >= 2 and self.cosine_similarity_weight > 0:
            # NOTE(review): assumes each state is a 2-D matrix so .diag() yields
            # its diagonal vector -- confirm against the sampler implementation.
            prior_diagonals = [state.diag().float() for state in sampler_states if state.numel() > 0]

            # combinations() yields nothing when fewer than two vectors survive
            # the empty-state filter, so no separate length check is required.
            for vec_i, vec_j in itertools.combinations(prior_diagonals, 2):
                similarity = F.cosine_similarity(vec_i, vec_j, dim=0, eps=eps)
                # Only positive similarity is penalised.
                total_cosine_similarity += torch.clamp(similarity, min=0.0)
                num_pairs += 1

        if num_pairs > 0:
            mean_cosine_similarity = total_cosine_similarity / num_pairs
        else:
            mean_cosine_similarity = torch.tensor(0.0, device=self.device)
        cosine_similarity_loss = self.cosine_similarity_weight * mean_cosine_similarity

        # --- 3. Combine Losses ---
        total_weighted_loss = mean_kl_loss + cosine_similarity_loss

        return total_weighted_loss, mean_entropy

    def train(self, mode: bool = True):
        """
        Set the training mode on this object and on every sampler it holds.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: if ``mode`` is not a ``bool``.
        """
        if isinstance(mode, bool):
            # Let the parent class update its own state first ...
            super().train(mode)
            # ... then forward the same mode to each sampler explicitly.
            for child_sampler in self.samplers:
                child_sampler.train(mode)
            return self
        raise ValueError("training mode is expected to be boolean")

    def eval(self):
        """Switch to evaluation mode; equivalent to ``self.train(False)``. Returns what ``train`` returns."""
        return self.train(mode=False)

    def __del__(self):
        # print("Deleting MultiMonteCLoRASampler...") # Debug
        # Ensure streams are cleaned up? Streams are usually managed by context or device reset.
        # Explicitly stopping sampler threads is important.
        for i, sampler in enumerate(self.samplers):
            try:
                 sampler.eval() # Calls _stop_sampler internally
            except Exception as e:
                 # Mute errors during deletion as interpreter might be shutting down
                 pass
                 # print(f"Muted error cleaning up sampler {i} in MultiMonteCLoRASampler.__del__: {e}")


# --- Example Usage (Updated) ---
# Smoke test: build two MultiMonteCLoRASampler instances, run one forward pass
# and one variational-loss evaluation on each, and print shapes and timings.
# The sampler objects and forward results are left in the notebook namespace
# on purpose so that later cells can reuse them.

N_OUTER = 2
IN_DIM = 768
OUT_DIM = 4
N_EXPERTS = 4
USE_GPU = True # Set to False to test CPU path

if USE_GPU and torch.cuda.is_available():
    # Keep the original preference for the second GPU, but fall back to the
    # first one so the demo does not crash on single-GPU machines.
    DEVICE = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    DEVICE = 'cpu'
print(f"Using device: {DEVICE}")

# --- Test CUDA Stream Path ---

print("\n--- Testing CUDA Stream Path ---")
multi_sampler_stream = MultiMonteCLoRASampler(
    num_outer=N_OUTER,
    in_features=IN_DIM,
    out_features=OUT_DIM,
    num_experts=N_EXPERTS,
    kl_loss_weight=1e-5,
    cosine_similarity_weight=5e-6,
    buffer_size=20,
    device=DEVICE,
    # Request streams whenever a CUDA device is in use. This run previously
    # passed False, which made it identical to the non-stream run below.
    use_cuda_streams=DEVICE.startswith('cuda')
)
multi_sampler_stream.train()
print("Stream Sampler: Train mode set.")

print("Stream Sampler: Simulating forward pass...")
# perf_counter() is the monotonic, high-resolution clock meant for intervals.
start_time = time.perf_counter()
mean_var, mean_weights = multi_sampler_stream.forward()
end_time = time.perf_counter()
print(f"Stream Sampler: Forward pass took: {end_time - start_time:.4f} seconds")

# forward() signals failure with non-tensor return values, hence the type check.
if isinstance(mean_var, torch.Tensor):
    print(f"Stream Sampler: Mean var sample shape: {mean_var.shape}")
    print(f"Stream Sampler: Mean expert weights shape: {mean_weights.shape}")

    print("Stream Sampler: Calculating variational loss...")
    start_time = time.perf_counter()
    combined_loss, mean_entropy_term = multi_sampler_stream.get_variational_loss()
    end_time = time.perf_counter()
    print(f"Stream Sampler: Loss calculation took: {end_time - start_time:.4f} seconds")
    print(f"Stream Sampler: Combined Loss: {combined_loss.item():.6f}, Mean Entropy: {mean_entropy_term.item():.6f}")

    if combined_loss.requires_grad:
        print("Stream Sampler: Loss requires grad.")
    else:
        print("Stream Sampler: Loss does not require grad.")
else:
    print("Stream Sampler: Forward pass failed or returned non-tensor.")

print("Stream Sampler: Setting to eval mode...")
multi_sampler_stream.eval()
# The object is intentionally NOT deleted: later cells reuse it. The message
# below is kept as-is so the recorded outputs still match.
# del multi_sampler_stream
print("Stream Sampler: Deleted.")
time.sleep(0.5)

# --- Test non-stream path ---

print(f"\n--- Testing {'CPU' if DEVICE=='cpu' else 'ThreadPool'} Path ---")
multi_sampler_thread = MultiMonteCLoRASampler(
    num_outer=N_OUTER,
    in_features=IN_DIM,
    out_features=OUT_DIM,
    num_experts=N_EXPERTS,
    kl_loss_weight=1e-5,
    cosine_similarity_weight=5e-6,
    buffer_size=20,
    device=DEVICE,
    use_cuda_streams=False
)
multi_sampler_thread.train()
print("Thread Sampler: Train mode set.")

print("Thread Sampler: Simulating forward pass...")
start_time = time.perf_counter()
mean_var_t, mean_weights_t = multi_sampler_thread.forward()
end_time = time.perf_counter()
print(f"Thread Sampler: Forward pass took: {end_time - start_time:.4f} seconds")

if isinstance(mean_var_t, torch.Tensor):
    print(f"Thread Sampler: Mean var sample shape: {mean_var_t.shape}")
    print(f"Thread Sampler: Mean expert weights shape: {mean_weights_t.shape}")

    print("Thread Sampler: Calculating variational loss...")
    start_time = time.perf_counter()
    combined_loss_t, mean_entropy_term_t = multi_sampler_thread.get_variational_loss()
    end_time = time.perf_counter()
    print(f"Thread Sampler: Loss calculation took: {end_time - start_time:.4f} seconds")
    print(f"Thread Sampler: Combined Loss: {combined_loss_t.item():.6f}, Mean Entropy: {mean_entropy_term_t.item():.6f}")
else:
    print("Thread Sampler: Forward pass failed or returned non-tensor.")

print("Thread Sampler: Setting to eval mode...")
multi_sampler_thread.eval()
# Kept alive for the follow-up cells; see the note on the stream sampler above.
# del multi_sampler_thread # Cleanup
print("Thread Sampler: Deleted.")
time.sleep(0.5)

print("\nDone.")

Using device: cuda:1

--- Testing CUDA Stream Path ---
Stream Sampler: Train mode set.
Stream Sampler: Simulating forward pass...
Stream Sampler: Forward pass took: 0.0109 seconds
Stream Sampler: Mean var sample shape: torch.Size([4, 768, 4])
Stream Sampler: Mean expert weights shape: torch.Size([4])
Stream Sampler: Calculating variational loss...
Stream Sampler: Loss calculation took: 0.0038 seconds
Stream Sampler: Combined Loss: 0.001938, Mean Entropy: 0.250000
Stream Sampler: Loss requires grad.
Stream Sampler: Setting to eval mode...
Stream Sampler: Deleted.

--- Testing ThreadPool Path ---
Thread Sampler: Train mode set.
Thread Sampler: Simulating forward pass...
Thread Sampler: Forward pass took: 0.0265 seconds
Thread Sampler: Mean var sample shape: torch.Size([4, 768, 4])
Thread Sampler: Mean expert weights shape: torch.Size([4])
Thread Sampler: Calculating variational loss...
Thread Sampler: Loss calculation took: 0.0032 seconds
Thread Sampler: Combined Loss: 0.002235, Mean Ent

In [21]:
# Scratch re-run of the thread sampler's forward pass, reusing the
# `multi_sampler_thread` object created by the example-usage code.
print("Thread Sampler: Simulating forward pass...")
# The example-usage code left this sampler in eval mode, so switch it back
# to training mode before sampling again.
multi_sampler_thread.train()
start_time = time.time()
mean_var_t, mean_weights_t = multi_sampler_thread.forward()
end_time = time.time()
print(f"Thread Sampler: Forward pass took: {end_time - start_time:.4f} seconds")

Thread Sampler: Simulating forward pass...
Thread Sampler: Forward pass took: 0.0120 seconds


In [24]:
# Display the shape of the first value returned by the thread sampler's forward().
mean_var_t.shape

torch.Size([4, 768, 4])

In [23]:
# Scratch re-run of the thread sampler's loss computation on the results of
# the most recent forward pass (`mean_var_t`, `mean_weights_t`).
# forward() returns non-tensor values on failure, hence the type check.
if isinstance(mean_var_t, torch.Tensor):
    print(f"Thread Sampler: Mean var sample shape: {mean_var_t.shape}")
    print(f"Thread Sampler: Mean expert weights shape: {mean_weights_t.shape}")

    print("Thread Sampler: Calculating variational loss...")
    start_time = time.time()
    # Returns zero tensors unless the sampler is in training mode with
    # mc_training set, so this is only meaningful after train() was called.
    combined_loss_t, mean_entropy_term_t = multi_sampler_thread.get_variational_loss()
    end_time = time.time()
    print(f"Thread Sampler: Loss calculation took: {end_time - start_time:.4f} seconds")
    print(f"Thread Sampler: Combined Loss: {combined_loss_t.item():.6f}, Mean Entropy: {mean_entropy_term_t.item():.6f}")
else:
    print("Thread Sampler: Forward pass failed or returned non-tensor.")

Thread Sampler: Mean var sample shape: torch.Size([4, 768, 4])
Thread Sampler: Mean expert weights shape: torch.Size([4])
Thread Sampler: Calculating variational loss...
Thread Sampler: Loss calculation took: 0.0034 seconds
Thread Sampler: Combined Loss: 0.000639, Mean Entropy: 0.250000


In [18]:
# Scratch re-run of the full stream-sampler sequence (train -> forward ->
# loss -> eval) on the existing `multi_sampler_stream` object.
# NOTE(review): this duplicates the corresponding part of the example-usage
# code; the object must already exist in the kernel for this cell to run.
multi_sampler_stream.train()
print("Stream Sampler: Train mode set.")

print("Stream Sampler: Simulating forward pass...")
start_time = time.time()
mean_var, mean_weights = multi_sampler_stream.forward()
end_time = time.time()
print(f"Stream Sampler: Forward pass took: {end_time - start_time:.4f} seconds")

# forward() returns non-tensor values on failure, hence the type check.
if isinstance(mean_var, torch.Tensor):
    print(f"Stream Sampler: Mean var sample shape: {mean_var.shape}")
    print(f"Stream Sampler: Mean expert weights shape: {mean_weights.shape}")

    print("Stream Sampler: Calculating variational loss...")
    start_time = time.time()
    combined_loss, mean_entropy_term = multi_sampler_stream.get_variational_loss()
    end_time = time.time()
    print(f"Stream Sampler: Loss calculation took: {end_time - start_time:.4f} seconds")
    print(f"Stream Sampler: Combined Loss: {combined_loss.item():.6f}, Mean Entropy: {mean_entropy_term.item():.6f}")

    # Report whether the combined loss is attached to the autograd graph.
    if combined_loss.requires_grad:
            print("Stream Sampler: Loss requires grad.")
    else:
            print("Stream Sampler: Loss does not require grad.")
else:
    print("Stream Sampler: Forward pass failed or returned non-tensor.")

print("Stream Sampler: Setting to eval mode...")
multi_sampler_stream.eval()
# NOTE: the `del` below is commented out, so nothing is actually deleted even
# though the next message says "Deleted."; the sampler stays in the kernel.
# del multi_sampler_stream 
print("Stream Sampler: Deleted.")
time.sleep(0.5) 

Stream Sampler: Train mode set.
Stream Sampler: Simulating forward pass...
Stream Sampler: Forward pass took: 0.0116 seconds
Stream Sampler: Mean var sample shape: torch.Size([4, 768, 4])
Stream Sampler: Mean expert weights shape: torch.Size([4])
Stream Sampler: Calculating variational loss...
Stream Sampler: Loss calculation took: 0.0046 seconds
Stream Sampler: Combined Loss: 0.002058, Mean Entropy: 0.250000
Stream Sampler: Loss requires grad.
Stream Sampler: Setting to eval mode...
Stream Sampler: Deleted.
