In [21]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

import gym
from gym import spaces
# Assuming you have PIL (Python Imaging Library) installed for image processing
from PIL import Image
import torchvision.transforms as transforms


#VIT
class FlexiblePatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_size=768, temporal_patch_size=1, is_3d=False):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_size = embed_size
        self.is_3d = is_3d

        if is_3d:
            self.num_patches = int(((img_size // patch_size) ** 2) * temporal_patch_size)
            self.projection = nn.Conv3d(in_channels, embed_size, kernel_size=(temporal_patch_size, patch_size, patch_size), 

                                        stride=(temporal_patch_size, patch_size, patch_size))

        else:
            self.num_patches = (img_size // patch_size) ** 2
            self.projection = nn.Conv2d(in_channels, embed_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.projection(x)  # [B, E, T/P, H/P, W/P] or [B, E, H/P, W/P]
        x = x.flatten(2)  # Flatten spatial and temporal dimensions
        x = x.transpose(1, 2)  # [B, N, E]
        return x

class PositionalEmbedding(nn.Module):
    def __init__(self, num_patches, embed_size):
        super().__init__()

        self.positional_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_size))

    def forward(self, x):
        batch_size = x.shape[0]
        cls_token = torch.zeros(batch_size, 1, x.shape[-1], device=x.device)
        x = torch.cat([cls_token, x], dim=1)  # [B, 1+N, E]
        x += self.positional_embedding
        return x

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feedforward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = nn.ReLU()

    def forward(self, src):
        src2 = self.self_attn(src, src, src)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers):
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
        self.num_layers = num_layers

    def forward(self, src):
        for layer in self.layers:
            src = layer(src)

        return src


class FlexibleVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_size=768, num_heads=12, num_layers=12, num_classes=1000, temporal_patch_size=1, is_3d=False):
        super().__init__()
        self.patch_embedding = FlexiblePatchEmbedding(img_size, patch_size, in_channels, embed_size, temporal_patch_size, is_3d)
        self.positional_embedding = PositionalEmbedding(self.patch_embedding.num_patches, embed_size)

        encoder_layer = TransformerEncoderLayer(d_model=embed_size, nhead=num_heads)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_size),
            nn.Linear(embed_size, num_classes)
        )

    def forward(self, x, return_embeddings=False):
        x = self.patch_embedding(x)
        x = self.positional_embedding(x)
        x = self.transformer_encoder(x)
        if return_embeddings:
            return x  # Return the sequence of embeddings directly
        cls_token = x[:, 0]
        x = self.mlp_head(cls_token)
        return x


# DualStateMamba
class MambaModelWithDualState(nn.Module):
    def __init__(self, d_model, merge_interval=10):
        super().__init__()
        self.merge_interval = merge_interval
        self.d_model = d_model

        # Simulated components for updating and merging states
        self.state_updater = nn.Linear(d_model, d_model)
        self.state_merger = nn.Linear(2 * d_model, d_model)  # Merges two states into one
        self.state_projector = nn.Linear(d_model, d_model)  # Projects merged state to next state

        # Initial states
        self.reset_states()

    def reset_states(self):
        # Resets/initializes the states; assumes states are kept between forward passes
        self.vit_state = torch.zeros(1, self.d_model)
        self.q_network_state = torch.zeros(1, self.d_model)

    def update_state(self, current_state, new_input):
        if new_input is None:
            # If new_input is None, return the current state without change
            return current_state
        else:
            # Simple state update mechanism; in practice, this could involve more complex temporal processing
            return F.relu(self.state_updater(new_input + current_state))

    def merge_states(self, vit_state, q_network_state):
        # Merges two states; this could be a complex learned fusion operation
        merged = torch.cat([vit_state, q_network_state], dim=1)
        return F.relu(self.state_merger(merged))

    def project_to_next_state(self, merged_state):
        # Projects the merged state to the next state; could involve temporal dynamics modeling
        return F.relu(self.state_projector(merged_state))

    def forward(self, vit_features, q_network_action, timestep):
        # Update states based on new inputs
        self.vit_state = self.update_state(self.vit_state, vit_features)
        self.q_network_state = self.update_state(self.q_network_state, q_network_action)

        if timestep % self.merge_interval == 0:
            # Periodic merge operation
            merged_state = self.merge_states(self.vit_state, self.q_network_state)
            next_state = self.project_to_next_state(merged_state)

            # Update both states with the merged and projected state
            self.vit_state, self.q_network_state = next_state.clone(), next_state.clone()

        # Optionally, return current state for integration with other components

        return self.vit_state, self.q_network_state



# Example usage
d_model = 512  # Feature dimension
mamba_model = MambaModelWithDualState(d_model=d_model)

# Simulated inputs for demonstration
vit_features = torch.randn(1, d_model)  # Output from ViT for current frame
q_network_action = torch.randn(1, d_model)  # Representation of action selected by Q-Network
timestep = 1  # Current timestep

# Forward pass through the model
vit_state, q_network_state = mamba_model(vit_features, q_network_action, timestep)

# The returned states (vit_state, q_network_state) could be used for further processing, decision making, etc.
class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.d_model = d_model
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        mean_square = torch.mean(x**2, dim=-1, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        return x / rms * self.scale

class MambaBlock(nn.Module):

    def __init__(self, d_model, expand_factor=2, d_state=512):
        super().__init__()
        self.d_inner = expand_factor * d_model
        self.conv1d = nn.Conv1d(in_channels=d_model, out_channels=self.d_inner, kernel_size=3, padding=1)
        self.x_proj = nn.Linear(self.d_inner, d_state)
        self.out_proj = nn.Linear(d_state, d_model)
        self.silu = nn.SiLU()

    def forward(self, x):
        x = x.transpose(1, 2)  # Assuming [batch, feature, seq_len] for Conv1D
        x, z = torch.chunk(self.conv1d(x), 2, dim=1)
        x = x.transpose(1, 2)  # Switch back to [batch, seq_len, feature]
        ssm_params = self.x_proj(self.silu(x))
        z = self.silu(z.transpose(1, 2))

        # Assuming a selective_scan or equivalent function exists and operates on ssm_params
        # This is a placeholder to represent some form of state-space computation
        x = ssm_params * z  # Simplified representation of a state-space operation
        out = self.out_proj(x)
        return out


class ResidualBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.norm = RMSNorm(d_model)
        self.mamba_block = MambaBlock(d_model)

    def forward(self, x):
        out = self.mamba_block(self.norm(x))
        return x + out



class MambaModel(nn.Module):
    def __init__(self, d_model, n_layers):
        super().__init__()
        self.layers = nn.ModuleList([ResidualBlock(d_model) for _ in range(n_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

def selective_scan(delta, B, x):
    """Implements the selective scan operation.
    Args:
        delta (Tensor): Delta vector (shape: [batch_size, d_state])
        B (Tensor): B matrix (shape: [batch_size, d_state, d_state])
        x (Tensor): Input (shape: [batch_size, seq_len, d_state + d_input])

    Returns:
        Tensor: Updated state representation (shape: [batch_size, seq_len, d_state])
    """
    batch_size, seq_len, _ = x.shape
    state = torch.zeros(batch_size, d_state).to(x.device) 
    output = torch.zeros(batch_size, seq_len, d_state).to(x.device)

    for t in range(seq_len):
        xt = x[:, t, :]  # Current input with state concatenated 
        state = B[:, t] * state  + xt * delta[:, t]  
        output[:, t, :] = state

    return output


# Q Network
class QNetwork(nn.Module):
    def __init__(self, mamba_output_dim, num_actions):
        super().__init__()
        self.linear1 = nn.Linear(mamba_output_dim, 128)  
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(128, num_actions)  

    def forward(self, state_representation):
        print(f'Inside QNetwork, input shape: {state_representation.shape}, type: {type(state_representation)}')

        x = self.linear1(state_representation)
        x = self.relu(x)
        q_values = self.linear2(x)
        return q_values



class ReplayBuffer:
    def __init__(self, capacity, n_step, gamma):
        self.capacity = capacity
        self.buffer = []  # List to store experiences
        self.priorities = SumTree(capacity)  # Use a SumTree for efficient priority sampling
        self.n_step = n_step
        self.gamma = gamma 

    def _calculate_n_step_return(self, start_idx):
        """Calculates the n-step return for an experience."""
        reward = 0

        for i in range(self.n_step):
            idx = start_idx + i
            if idx >= len(self.buffer):
                break
            reward += self.buffer[idx][2] * self.gamma**i  # Discounted reward
            if self.buffer[idx][4]:  # Done flag
                break
        return reward 

    def add(self, state, action, reward, next_state, done):
        """Adds a new experience and updates priorities."""
        experience = (state, action, reward, next_state, done)
        max_priority = self.priorities.max() or 1.0  # Avoid 0 priority
        self.buffer.append(experience)
        self.priorities.add(max_priority, len(self.buffer) - 1)


    def sample(self, batch_size, beta=0.4):
        """Samples a batch of experiences with importance-sampling weights."""
        indices = self.priorities.sample(batch_size)
        weights = np.zeros(batch_size, dtype=np.float32)
        batch = [self.buffer[idx] for idx in indices]
        
        total_priority = self.priorities.total_priority()
        min_priority = self.priorities.min() / total_priority
        max_weight = (min_priority * batch_size) ** (-beta)
        
        for i, idx in enumerate(indices):
            priority = self.priorities.get(idx)[0] / total_priority
            weight = (priority * batch_size) ** (-beta) / max_weight
            weights[i] = weight
        
        states, actions, rewards, next_states, dones = zip(*batch)
        
        # Convert lists to tensors. Note: Assuming your states and next_states are images and already in the correct shape.
        states = torch.stack([torch.Tensor(state) for state in states])
        next_states = torch.stack([torch.Tensor(state) for state in next_states])
        actions = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.float32)
        print(f'Sampled states shape: {states.shape}, type: {type(states)}')

        return states, actions, rewards, next_states, dones, indices, torch.tensor(weights, dtype=torch.float32)


    def update_priorities(self, indices, td_errors):
        """Updates priorities based on TD errors."""
        for idx, td_error in zip(indices, td_errors):
            priority = abs(td_error) + 1e-5  # Avoid zero priority
            self.priorities.update(idx, priority)



def update_q_network(model, q_network, target_q_network, buffer, optimizer, batch_size, gamma, n_step):
    # Sample a batch of experiences
    states, actions, rewards, next_states, dones, indices, weights = buffer.sample(batch_size)
    print(f"states types: {states.type}, states shape: {states.shape}")
    print(f"next_states types: {next_states.type}, next_states shape: {next_states.shape}")
    print(f"actions types: {actions.type}, actions shape: {actions.shape}")
    print(f"rewards types: {rewards.type}, rewards shape: {rewards.shape}")
    print(f"dones types: {dones.type}, dones shape: {dones.shape}")
    #print(f"indices types: {indices.type}")
    print(f"weights types: {weights.type}, weights shape: {weights.shape}")

    actions = actions.unsqueeze(1) if actions.dim() == 1 else actions
    rewards = rewards if rewards.dim() > 0 else torch.tensor(rewards, dtype=torch.float32)
    dones = dones if dones.dim() > 0 else torch.tensor(dones, dtype=torch.float32)
    weights = weights if weights.dim() > 0 else torch.tensor(weights, dtype=torch.float32)

    # Calculate the target Q-values
    with torch.no_grad():
        next_q_values = target_q_network(next_states)
        max_next_q_values = next_q_values.max(dim=1)[0]
        q_targets = rewards + (gamma ** n_step) * max_next_q_values * (1 - dones)

    # Calculate the current Q-values
    current_q_values = q_network(states).gather(1, actions).squeeze()

    # Compute loss using the weights from the prioritized replay buffer
    loss = (weights * F.mse_loss(current_q_values, q_targets, reduction='none')).mean()

    # Perform an optimization step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Update the priorities in the buffer
    td_errors = torch.abs(current_q_values - q_targets).detach().numpy()
    buffer.update_priorities(indices, td_errors)



class SumTree:
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)  # Binary tree, twice the capacity for leaves and parents
        self.data_pointer = 0  # Pointer to the next free data index

    def add(self, priority, data_index):
        tree_index = self.data_pointer + self.capacity - 1
        self.update(tree_index, priority)  # Update the new leaf's value
        self.data_pointer += 1

        if self.data_pointer >= self.capacity: 
            self.data_pointer = 0  # Wrap around if the buffer is full

    def update(self, tree_index, priority):
        change = priority - self.tree[tree_index]
        self.tree[tree_index] = priority
        self._propagate_changes(tree_index, change)  # Update values up the tree

    def min(self):
        """Return the smallest non-zero priority."""
        min_val = np.min(self.tree[-self.capacity:][self.tree[-self.capacity:] > 0])
        if min_val == 0:
            return np.min(self.tree[-self.capacity:])  # Fallback to min, including zeros
        return min_val

    def _propagate_changes(self, tree_index, change):
        parent_index = (tree_index - 1) // 2
        self.tree[parent_index] += change

        if parent_index != 0:
            self._propagate_changes(parent_index, change)

    def get(self, sample):
        parent_index = 0
        while True:
            left_child_index = 2 * parent_index + 1
            right_child_index = left_child_index + 1

            if left_child_index >= len(self.tree):  # Leaf node
                leaf_index = parent_index
                break

            else:  # Continue searching
                if sample <= self.tree[left_child_index]:
                    parent_index = left_child_index
                else:
                    sample -= self.tree[left_child_index]
                    parent_index = right_child_index

        data_index = leaf_index - self.capacity + 1
        return leaf_index, self.tree[leaf_index], data_index

    def total_priority(self):
        return self.tree[0]  # Root node holds the sum of all priorities

    def sample(self, batch_size):
        batch_indices = []
        p_segments = self.total_priority() / batch_size

        for i in range(batch_size):
            start = i * p_segments
            end = (i + 1) * p_segments
            sample = random.uniform(start, end)
            _, _, idx = self.get(sample)
            batch_indices.append(idx)

        return batch_indices

    def max(self):
        return np.max(self.tree[-self.capacity:])  # Max priority among leave
    



class VisualEnvironment(gym.Env):
    def __init__(self):
        super(VisualEnvironment, self).__init__()
        self.action_space = spaces.Discrete(2)  # Example: left or right
        self.observation_space = spaces.Box(low=0, high=255, shape=(224, 224, 3), dtype=np.uint8)
        self.state = None

    def step(self, action):
        reward = 0
        done = False
        info = {}
        # Generate a random image for the next state
        self.state = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
        return self.state, reward, done, info

    def reset(self):
        # Generate an initial random image
        self.state = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
        return self.state



def preprocess_state(image, target_size=(224, 224)):
    """
    Preprocess the state image for the Vision Transformer model.
    Args:
        image (np.ndarray): The state image as a NumPy array.
        target_size (tuple): The target image size (width, height).
    Returns:
        torch.Tensor: The preprocessed image tensor.
    """
    transform = transforms.Compose([
        transforms.ToPILImage(),  # Convert np.ndarray to PIL.Image
        transforms.Resize(target_size),  # Resize to target size
        transforms.ToTensor(),  # Convert PIL.Image to torch.Tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Standard normalization for ImageNet
                             std=[0.229, 0.224, 0.225])
    ])
    image_tensor = transform(image)
    # Add a batch dimension (B, C, H, W) where B=1 since you're processing single images
    image_tensor = image_tensor.unsqueeze(0)
    return image_tensor

# Assuming an environment class Env that supports env.reset() and env.step(action)
env = VisualEnvironment()
state = env.reset()  # Get initial state as a raw image
preprocessed_state = preprocess_state(state)  # Preprocess the initial state
done = False

# Instantiate the model, Q-network, and other components
vit_model = FlexibleVisionTransformer(
    img_size=224,
    patch_size=16,
    in_channels=3,
    embed_size=768,
    num_heads=12,
    num_layers=12,
    num_classes=1000,
    temporal_patch_size=1,
    is_3d=False)

vit_features = vit_model(preprocessed_state)
mamba_model = MambaModelWithDualState(d_model=768)
q_network = QNetwork(mamba_output_dim=768, num_actions=env.action_space.n)
target_q_network = copy.deepcopy(q_network)  # For stable Q-value targets
replay_buffer = ReplayBuffer(capacity=10000, n_step=3, gamma=0.99)
optimizer = torch.optim.Adam(q_network.parameters(), lr=1e-4)
total_episodes = 100
batch_size = 8
# Example training loop
for episode in range(total_episodes):
    state = env.reset()  # Get initial state as a raw image
    preprocessed_state = preprocess_state(state)  # Preprocess the initial state
    print(f'Preprocessed state shape: {preprocessed_state.shape}, type: {type(preprocessed_state)}')

    done = False
    
    while not done:
        # vit_model and mamba_model accept the preprocessed state directly
        vit_output = vit_model(preprocessed_state, return_embeddings=True)
        print(f'Vision Transformer output shape: {vit_output.shape}, type: {type(vit_output)}')

        cls_token = vit_output[:, 0, :]  # Extract the cls_token
        vit_state, q_network_state = mamba_model(cls_token, None, timestep)        
        print(f'Input to QNetwork shape: {vit_state.shape}, type: {type(vit_state)}')


        action = q_network(vit_state).argmax().item()  # Decide action
        next_raw_state, reward, done, _ = env.step(action)  # Execute action in env
        next_preprocessed_state = preprocess_state(next_raw_state)  # Preprocess the new state

        print(f'Buffer size before sampling: {len(replay_buffer.buffer)}')

        # Store the transition in replay buffer
        replay_buffer.add(preprocessed_state.squeeze(0).cpu().numpy(), action, reward, next_preprocessed_state.squeeze(0).cpu().numpy(), done)
        
        # Update Q-network periodically
        if len(replay_buffer.buffer) >= batch_size:
            update_q_network(mamba_model, q_network, target_q_network, replay_buffer, optimizer, batch_size, 0.99, 3)  # Assume n_step=3 and gamma=0.99
        
        preprocessed_state = next_preprocessed_state
    state = env.reset()
    done = False

Preprocessed state shape: torch.Size([1, 3, 224, 224]), type: <class 'torch.Tensor'>
Vision Transformer output shape: torch.Size([1, 197, 768]), type: <class 'torch.Tensor'>
Input to QNetwork shape: torch.Size([1, 768]), type: <class 'torch.Tensor'>
Inside QNetwork, input shape: torch.Size([1, 768]), type: <class 'torch.Tensor'>
Buffer size before sampling: 0
Vision Transformer output shape: torch.Size([1, 197, 768]), type: <class 'torch.Tensor'>
Input to QNetwork shape: torch.Size([1, 768]), type: <class 'torch.Tensor'>
Inside QNetwork, input shape: torch.Size([1, 768]), type: <class 'torch.Tensor'>
Buffer size before sampling: 1
Vision Transformer output shape: torch.Size([1, 197, 768]), type: <class 'torch.Tensor'>
Input to QNetwork shape: torch.Size([1, 768]), type: <class 'torch.Tensor'>
Inside QNetwork, input shape: torch.Size([1, 768]), type: <class 'torch.Tensor'>
Buffer size before sampling: 2
Vision Transformer output shape: torch.Size([1, 197, 768]), type: <class 'torch.Tens

RuntimeError: mat1 and mat2 shapes cannot be multiplied (5376x224 and 768x128)

# Training
1. **Initialization:**
   - Instantiate the environment, models (VIT, MambaModel, QNetwork, and Target QNetwork), replay buffer, and optimizer.
   - Set hyperparameters (learning rate, batch size, discount factor gamma, update frequency for the target network).

2. **Episode Loop:**
   - Reset the environment to get the initial state.

3. **Timestep Loop within an Episode:**
   - Process the current state through VIT and optionally MambaModel to get a state representation.
   - Select an action based on the current state representation.
   - Execute the action in the environment to obtain the next state, reward, and done flag.
   - Preprocess the next state and add the transition to the replay buffer.
   - Sample a batch from the replay buffer and perform a gradient descent step on the Q-network.
   - Every fixed number of steps, update the target Q-network's weights to match the current Q-network's weights.
   - Check for the episode termination condition. If true, break the loop.

4. **Replay Buffer:**
   - Ensure transitions stored and sampled include preprocessed states or appropriate state representations, not raw pixel data.

5. **Q-Network Optimization:**
   - Calculate expected Q values from the Q-network using sampled states.
   - Calculate target Q values using the next state's max Q value from the target Q-network.
   - Compute the loss between expected and target Q values.
   - Perform a backpropagation and optimizer step to update the Q-network's weights.

6. **Target Network Update:**
   - Periodically sync the target Q-network's weights with the current Q-network's weights.

In [73]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

import gym
from gym import spaces
# Assuming you have PIL (Python Imaging Library) installed for image processing
from PIL import Image
import torchvision.transforms as transforms
from collections import namedtuple



#VIT
class FlexiblePatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_size=768, temporal_patch_size=1, is_3d=False):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_size = embed_size
        self.is_3d = is_3d

        if is_3d:
            self.num_patches = int(((img_size // patch_size) ** 2) * temporal_patch_size)
            self.projection = nn.Conv3d(in_channels, embed_size, kernel_size=(temporal_patch_size, patch_size, patch_size), 

                                        stride=(temporal_patch_size, patch_size, patch_size))

        else:
            (img_size // patch_size) ** 2 * embed_size            

            self.projection = nn.Conv2d(in_channels, embed_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        print(f"Input shape to FlexiblePatchEmbedding: {x.shape}")

        x = self.projection(x)  # [B, E, T/P, H/P, W/P] or [B, E, H/P, W/P]
        print(f"Shape after projection: {x.shape}")

        x = x.transpose(1, 2)  # [B, N, E]  – Keep spatial information 
        x = x.flatten(2)
        return x 

class PositionalEmbedding(nn.Module):
    def __init__(self, num_patches, embed_size):
        super().__init__()

        # Initialize positional embeddings with an extra position for the cls_token
        # *** Key Change: Create the positional embedding directly ***
        self.positional_embedding = nn.Parameter(torch.zeros(1, embed_size + 1, embed_size)) 

    def forward(self, x):
        print(f"Number of patches received for positional embeddings: {x.shape[1]}")
        print(f"Shape of x before CLS token: {x.shape}") 
        batch_size = x.shape[0]
        cls_token = torch.zeros(batch_size, 1, x.shape[-1], device=x.device)
        x = torch.cat([cls_token, x], dim=1)  # [B, 1+N, E]

        # Debugging: Print shapes to ensure compatibility
        print(f"x shape: {x.shape}")
        print(f"Positional embedding shape: {self.positional_embedding.shape}")

        # Ensure that the positional_embedding is correctly sized to match `x`
        #assert x.shape[1] == self.positional_embedding.shape[1], "Mismatch in positional embedding size"
        
        x += self.positional_embedding
        return x


class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feedforward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = nn.ReLU()

    def forward(self, src):
        src2 = self.self_attn(src, src, src)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers):
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
        self.num_layers = num_layers

    def forward(self, src):
        for layer in self.layers:
            src = layer(src)

        return src


class FlexibleVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_size=768, num_heads=12, num_layers=12, num_classes=1000, temporal_patch_size=1, is_3d=False):
        super().__init__()
        self.patch_embedding = FlexiblePatchEmbedding(img_size, patch_size, in_channels, embed_size, temporal_patch_size, is_3d)
        self.positional_embedding = PositionalEmbedding(embed_size, embed_size) 

        encoder_layer = TransformerEncoderLayer(d_model=embed_size, nhead=num_heads)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_size),
            nn.Linear(embed_size, num_classes)
        )

    def forward(self, x, return_embeddings=False):
        print(f"Batch size: {x.shape[0]}") 

        x = self.patch_embedding(x)
        x = self.positional_embedding(x)
        x = self.transformer_encoder(x)
        if return_embeddings:
            return x  # Return the sequence of embeddings directly
        cls_token = x[:, 0]
        x = self.mlp_head(cls_token)
        return x


# DualStateMamba
class MambaModelWithDualState(nn.Module):
    def __init__(self, d_model, merge_interval=10):
        super().__init__()
        self.merge_interval = merge_interval
        self.d_model = d_model

        # Simulated components for updating and merging states
        self.state_updater = nn.Linear(d_model, d_model)
        self.state_merger = nn.Linear(2 * d_model, d_model)  # Merges two states into one
        self.state_projector = nn.Linear(d_model, d_model)  # Projects merged state to next state

        # Initial states
        self.reset_states()

    def reset_states(self):
        # Resets/initializes the states; assumes states are kept between forward passes
        self.vit_state = torch.zeros(1, self.d_model)
        self.q_network_state = torch.zeros(1, self.d_model)

    def update_state(self, current_state, new_input):
        if new_input is None:
            # If new_input is None, return the current state without change
            return current_state
        else:
            # Simple state update mechanism; in practice, this could involve more complex temporal processing
            return F.relu(self.state_updater(new_input + current_state))

    def merge_states(self, vit_state, q_network_state):
        # Merges two states; this could be a complex learned fusion operation
        merged = torch.cat([vit_state, q_network_state], dim=1)
        return F.relu(self.state_merger(merged))

    def project_to_next_state(self, merged_state):
        # Projects the merged state to the next state; could involve temporal dynamics modeling
        return F.relu(self.state_projector(merged_state))

    def forward(self, vit_features, q_network_action, timestep):
        # Update states based on new inputs
        self.vit_state = self.update_state(self.vit_state, vit_features)
        self.q_network_state = self.update_state(self.q_network_state, q_network_action)

        if timestep % self.merge_interval == 0:
            # Periodic merge operation
            merged_state = self.merge_states(self.vit_state, self.q_network_state)
            next_state = self.project_to_next_state(merged_state)

            # Update both states with the merged and projected state
            self.vit_state, self.q_network_state = next_state.clone(), next_state.clone()

        # Optionally, return current state for integration with other components

        return self.vit_state, self.q_network_state

# The returned states (vit_state, q_network_state) could be used for further processing, decision making, etc.
class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.d_model = d_model
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        mean_square = torch.mean(x**2, dim=-1, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        return x / rms * self.scale

class MambaBlock(nn.Module):

    def __init__(self, d_model, expand_factor=2, d_state=512):
        super().__init__()
        self.d_inner = expand_factor * d_model
        self.conv1d = nn.Conv1d(in_channels=d_model, out_channels=self.d_inner, kernel_size=3, padding=1)
        self.x_proj = nn.Linear(self.d_inner, d_state)
        self.out_proj = nn.Linear(d_state, d_model)
        self.silu = nn.SiLU()

    def forward(self, x):
        x = x.transpose(1, 2)  # Assuming [batch, feature, seq_len] for Conv1D
        x, z = torch.chunk(self.conv1d(x), 2, dim=1)
        x = x.transpose(1, 2)  # Switch back to [batch, seq_len, feature]
        ssm_params = self.x_proj(self.silu(x))
        z = self.silu(z.transpose(1, 2))

        # Assuming a selective_scan or equivalent function exists and operates on ssm_params
        # This is a placeholder to represent some form of state-space computation
        x = ssm_params * z  # Simplified representation of a state-space operation
        out = self.out_proj(x)
        return out


class ResidualBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.norm = RMSNorm(d_model)
        self.mamba_block = MambaBlock(d_model)

    def forward(self, x):
        out = self.mamba_block(self.norm(x))
        return x + out

class MambaModel(nn.Module):
    def __init__(self, d_model, n_layers):
        super().__init__()
        self.layers = nn.ModuleList([ResidualBlock(d_model) for _ in range(n_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

def selective_scan(delta, B, x):
    """Implements the selective scan operation.
    Args:
        delta (Tensor): Delta vector (shape: [batch_size, d_state])
        B (Tensor): B matrix (shape: [batch_size, d_state, d_state])
        x (Tensor): Input (shape: [batch_size, seq_len, d_state + d_input])

    Returns:
        Tensor: Updated state representation (shape: [batch_size, seq_len, d_state])
    """
    batch_size, seq_len, _ = x.shape
    state = torch.zeros(batch_size, d_state).to(x.device) 
    output = torch.zeros(batch_size, seq_len, d_state).to(x.device)

    for t in range(seq_len):
        xt = x[:, t, :]  # Current input with state concatenated 
        state = B[:, t] * state  + xt * delta[:, t]  
        output[:, t, :] = state

    return output


# Env


class VisualEnvironment(gym.Env):
    def __init__(self):
        super(VisualEnvironment, self).__init__()
        self.action_space = spaces.Discrete(2)  # Example: left or right
        self.observation_space = spaces.Box(low=0, high=255, shape=(224, 224, 3), dtype=np.uint8)
        self.state = None

    def step(self, action):
        reward = 0
        done = False
        info = {}
        # Generate a random image for the next state
        self.state = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
        return self.state, reward, done, info

    def reset(self):
        # Generate an initial random image
        self.state = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
        return self.state



def preprocess_state(image, target_size=(224, 224)):
    """
    Preprocess the state image for the Vision Transformer model.
    Args:
        image (np.ndarray or torch.Tensor): The state image as a NumPy array or a PyTorch tensor.
        target_size (tuple): The target image size (width, height).
    Returns:
        torch.Tensor: The preprocessed image tensor.
    """
    if isinstance(image, np.ndarray):
        # Convert NumPy array to torch tensor
        image_tensor = torch.tensor(image)
        # Check if the image has more than 3 dimensions (e.g., batch dimension)
        if len(image_tensor.shape) > 3:
            # Remove the batch dimension if present
            image_tensor = image_tensor.squeeze(0)
        # Permute dimensions to match PyTorch convention (assuming the input is HWC)
        image_tensor = image_tensor.permute(2, 0, 1)
    elif isinstance(image, torch.Tensor):
        # If already a torch tensor, no need for conversion
        image_tensor = image
    else:
        raise TypeError("Input type should be np.ndarray or torch.Tensor")

    # Resize and normalize the image tensor
    transform = transforms.Compose([
        transforms.ToPILImage(),  # Convert to PIL Image
        transforms.Resize(target_size),  # Resize to target size
        transforms.ToTensor(),  # Convert to tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
    ])

    image_tensor = transform(image_tensor)  # Apply transformations

    # Add a batch dimension if it's not present
    if len(image_tensor.shape) < 3:
        image_tensor = image_tensor.unsqueeze(0)

    return image_tensor






# Q Network
class QNetwork(nn.Module):
    def __init__(self, mamba_output_dim, num_actions):
        super().__init__()
        self.linear1 = nn.Linear(mamba_output_dim, 128)  
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(128, num_actions)  

    def forward(self, state_representation):
        print(f'Inside QNetwork, input shape: {state_representation.shape}, type: {type(state_representation)}')

        x = self.linear1(state_representation)
        x = self.relu(x)
        q_values = self.linear2(x)
        return q_values


class SumTree:
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)  # Binary tree, twice the capacity for leaves and parents
        self.data_pointer = 0  # Pointer to the next free data index

    def add(self, priority, data_index):
        tree_index = self.data_pointer + self.capacity - 1
        self.update(tree_index, priority)  # Update the new leaf's value
        self.data_pointer += 1

        if self.data_pointer >= self.capacity: 
            self.data_pointer = 0  # Wrap around if the buffer is full

    def update(self, tree_index, priority):
        change = priority - self.tree[tree_index]
        self.tree[tree_index] = priority
        self._propagate_changes(tree_index, change)  # Update values up the tree

    def min(self):
        """Return the smallest non-zero priority."""
        min_val = np.min(self.tree[-self.capacity:][self.tree[-self.capacity:] > 0])
        if min_val == 0:
            return np.min(self.tree[-self.capacity:])  # Fallback to min, including zeros
        return min_val

    def _propagate_changes(self, tree_index, change):
        parent_index = (tree_index - 1) // 2
        self.tree[parent_index] += change

        if parent_index != 0:
            self._propagate_changes(parent_index, change)

    def get(self, sample):
        parent_index = 0
        while True:
            left_child_index = 2 * parent_index + 1
            right_child_index = left_child_index + 1

            if left_child_index >= len(self.tree):  # Leaf node
                leaf_index = parent_index
                break

            else:  # Continue searching
                if sample <= self.tree[left_child_index]:
                    parent_index = left_child_index
                else:
                    sample -= self.tree[left_child_index]
                    parent_index = right_child_index

        data_index = leaf_index - self.capacity + 1
        return leaf_index, self.tree[leaf_index], data_index

    def total_priority(self):
        return self.tree[0]  # Root node holds the sum of all priorities

    def sample(self, batch_size):
        batch_indices = []
        p_segments = self.total_priority() / batch_size

        for i in range(batch_size):
            start = i * p_segments
            end = (i + 1) * p_segments
            sample = random.uniform(start, end)
            _, _, idx = self.get(sample)
            batch_indices.append(idx)

        return batch_indices

    def max(self):
        return np.max(self.tree[-self.capacity:])  # Max priority among leave
    

class ReplayBuffer:
    def __init__(self, capacity, n_step, gamma):
        self.capacity = capacity
        self.buffer = []  # List to store experiences
        self.priorities = SumTree(capacity)  # Use a SumTree for efficient priority sampling
        self.n_step = n_step
        self.gamma = gamma 

    def _calculate_n_step_return(self, start_idx):
        """Calculates the n-step return for an experience."""
        reward = 0

        for i in range(self.n_step):
            idx = start_idx + i
            if idx >= len(self.buffer):
                break
            reward += self.buffer[idx][2] * self.gamma**i  # Discounted reward
            if self.buffer[idx][4]:  # Done flag
                break
        return reward 

    def add(self, state, action, reward, next_state, done):
        """Adds a new experience and updates priorities."""
        experience = (state, action, reward, next_state, done)
        max_priority = self.priorities.max() or 1.0  # Avoid 0 priority
        self.buffer.append(experience)
        self.priorities.add(max_priority, len(self.buffer) - 1)


    def sample(self, batch_size, beta=0.4):
        """Samples a batch of experiences with importance-sampling weights."""
        indices = self.priorities.sample(batch_size)
        weights = np.zeros(batch_size, dtype=np.float32)
        batch = [self.buffer[idx] for idx in indices]
        
        total_priority = self.priorities.total_priority()
        min_priority = self.priorities.min() / total_priority
        max_weight = (min_priority * batch_size) ** (-beta)
        
        for i, idx in enumerate(indices):
            priority = self.priorities.get(idx)[0] / total_priority
            weight = (priority * batch_size) ** (-beta) / max_weight
            weights[i] = weight
        
        states, actions, rewards, next_states, dones = zip(*batch)
        
        # Convert lists to tensors. Note: Assuming your states and next_states are images and already in the correct shape.
        states = torch.stack([torch.Tensor(state) for state in states])
        next_states = torch.stack([torch.Tensor(state) for state in next_states])
        actions = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.float32)
        print(f'Sampled states shape: {states.shape}, type: {type(states)}')

        return states, actions, rewards, next_states, dones, indices, torch.tensor(weights, dtype=torch.float32)


    def update_priorities(self, indices, td_errors):
        """Updates priorities based on TD errors."""
        for idx, td_error in zip(indices, td_errors):
            priority = abs(td_error) + 1e-5  # Avoid zero priority
            self.priorities.update(idx, priority)


def update_q_network(q_network, target_q_network, replay_buffer, optimizer, gamma, batch_size, n_step):
    # Sample a batch of experiences if the buffer is large enough
    if len(replay_buffer.buffer) < batch_size:
        return

    states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(batch_size)

    # Explicitly convert dones to a Boolean tensor
    dones = dones.bool()

    non_final_mask = ~dones
    non_final_next_states = next_states[non_final_mask]

    state_action_values = q_network(states).gather(1, actions.view(-1, 1)).squeeze()

    next_state_values = torch.zeros(batch_size)
    if non_final_next_states.size(0) > 0:  # Check if there are any non-final next states
        next_state_values[non_final_mask] = target_q_network(non_final_next_states).max(1)[0].detach()
    expected_state_action_values = (next_state_values * gamma ** n_step) + rewards

    loss = (weights * F.mse_loss(state_action_values, expected_state_action_values.detach(), reduction='none')).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    td_errors = torch.abs(state_action_values - expected_state_action_values).detach().numpy()
    replay_buffer.update_priorities(indices, td_errors)


Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))

class TrainingManager:
    def __init__(self, env, vit_model, mamba_model, q_network, target_q_network, replay_buffer, optimizer, batch_size=32, gamma=0.99, update_target_every=100):
        self.env = env
        self.vit_model = vit_model
        self.mamba_model = mamba_model
        self.q_network = q_network
        self.target_q_network = target_q_network
        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.gamma = gamma
        self.update_target_every = update_target_every
        self.num_actions = env.action_space.n
        self.steps_done = 0

    def preprocess(self, image):
        # Assuming preprocess_state function is defined outside this class
        return preprocess_state(image)

    def select_action(self, state):
        sample = random.random()
        eps_threshold = 0.05  # Fixed strategy for simplicity
        if sample > eps_threshold:
            with torch.no_grad():
                # Forward pass through the vit_model to get state features
                state_features = self.vit_model(self.preprocess(state), return_embeddings=True)
                # Using the cls token as state representation
                state_representation = state_features[:, 0, :]
                # Forward pass through the q_network to get action values
                return self.q_network(state_representation).max(1)[1].view(1, 1)
        else:
            return torch.tensor([[random.randrange(self.num_actions)]], dtype=torch.long)
        
    @staticmethod
    def extract_features(vit_model, states):
        """
        Extract features from states using the Vision Transformer model.
        Args:
            vit_model (nn.Module): The Vision Transformer model.
            states (List[torch.Tensor]): List of processed state tensors.
        Returns:
            torch.Tensor: Stacked tensor of extracted features.
        """
        features = []
        for state in states:
            processed_state = preprocess_state(state)  # Process state
            with torch.no_grad():
                feature = vit_model(processed_state).squeeze(0)  # Forward pass through Vision Transformer
            features.append(feature)
        return torch.stack(features)



    def optimize_model(self):
        if len(self.replay_buffer.buffer) < self.batch_size:
            return
        states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(batch_size)
        
        # Convert states tensor to list of images
        states_list = [state.numpy() for state in states]
        next_states_list = [state.numpy() for state in next_states]
        # Process states and next_states through the Vision Transformer to get feature representations
        state_features = TrainingManager.extract_features(vit_model, states_list)
        next_state_features = TrainingManager.extract_features(vit_model, next_states_list)

        # Now pass state_features and next_state_features to the QNetwork
        update_q_network(q_network, 
                        target_q_network, 
                        state_features, 
                        actions, 
                        rewards, 
                        next_state_features, 
                        dones, 
                        indices, 
                        weights, 
                        optimizer, 
                        self.gamma, 
                        batch_size, 
                        self.n_step)



    def train(self, num_episodes):
        for i_episode in range(num_episodes):
            # Initialize the environment and state
            print(f"Episode {i_episode}")
            state = self.env.reset()
            for t in range(1000):  # Maximum steps per episode
                print(f"Step {t} , Episode {i_episode}")

                action = self.select_action(state)
                next_state, reward, done, _ = self.env.step(action.item())
                print(f"action shape: {action.shape}")
                print(f"next_state shape: {next_state.shape}")
                print(f"reward shape: {reward.shape}")
                print(f"done shape: {done.shape}")
                print(f"state shape: {state.shape}")



                if not done:
                    next_state_processed = self.preprocess(next_state)
                else:
                    next_state_processed = None

                # Store the transition in memory
                self.replay_buffer.add(state, action, reward, next_state_processed, done)

                # Perform one step of the optimization
                self.optimize_model()
                if done:
                    break

                state = next_state
                self.steps_done += 1


# Assuming an environment class Env that supports env.reset() and env.step(action)
env = VisualEnvironment()
state = env.reset()  # Get initial state as a raw image
print(state.shape)
preprocessed_state = preprocess_state(state)  # Preprocess the initial state

done = False

# Instantiate the model, Q-network, and other components
# Instantiate the model, Q-network, and other components
vit_model = FlexibleVisionTransformer(
    img_size=224,
    patch_size=16,
    embed_size=768,
    num_heads=12,
    num_layers=12,
    num_classes=1000,
    temporal_patch_size=1,
    is_3d=False)


vit_features = vit_model(preprocessed_state)
mamba_model = MambaModelWithDualState(d_model=768)
q_network = QNetwork(mamba_output_dim=768, num_actions=env.action_space.n)
target_q_network = copy.deepcopy(q_network)  # For stable Q-value targets
replay_buffer = ReplayBuffer(capacity=10000, n_step=3, gamma=0.99)
optimizer = torch.optim.Adam(q_network.parameters(), lr=1e-4)
total_episodes = 100
batch_size = 8
# Instantiate TrainingManager
# Assuming the environment, models, replay buffer, and optimizer have been instantiated
training_manager = TrainingManager(env, vit_model, mamba_model, q_network, target_q_network, replay_buffer, optimizer, batch_size=8)

# Start training
training_manager.train(num_episodes=100)

(224, 224, 3)
Batch size: 3
Input shape to FlexiblePatchEmbedding: torch.Size([3, 224, 224])
Shape after projection: torch.Size([768, 14, 14])
Number of patches received for positional embeddings: 14
Shape of x before CLS token: torch.Size([768, 14, 14])
x shape: torch.Size([768, 15, 14])
Positional embedding shape: torch.Size([1, 769, 768])


RuntimeError: The size of tensor a (14) must match the size of tensor b (768) at non-singleton dimension 2

# testing

In [77]:
import torch
import torch.nn as nn



# Let's simulate an image input
image_tensor = torch.randn(3, 224, 224)  

patch_embed = FlexiblePatchEmbedding(img_size=224, patch_size=16) 
pos_embed = PositionalEmbedding(embed_size=768,embed_size=768) 

output = patch_embed(image_tensor)
output += pos_embed(output)  # This should work! 


SyntaxError: keyword argument repeated: embed_size (2537261159.py, line 10)