# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from transformers import DetrFeatureExtractor, DetrForObjectDetection, ViTFeatureExtractor, ViTModel, DetrImageProcessor
from einops import rearrange
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, Dataset
from datasets import load_dataset
import numpy as np
import math
from torchvision.transforms import functional as TF
from torchvision.transforms import ToTensor, Normalize
from collections import Counter

class VisionExpert(nn.Module):
    def __init__(self, num_classes=80, num_experts=4, pe_type='regular', device='cuda'):
        """Build the CNN encoder, positional encoder, sub-models, router and MoE head.

        Parameters:
        - num_classes: number of classes handed to every sub-model and to the MoE head
        - num_experts: number of experts inside each Lory MoE layer
        - pe_type: 'regular', 'rotary' or 'contextual'; any other value creates no
          positional encoder (forward only uses it for these three values)
        - device: device that train_model moves each batch to
        """
        super().__init__()
        self.device = device

        # Pretrained ResNet-152 with its last child module removed.
        self.resnet_model = models.resnet152(pretrained=True)
        self.resnet_model = nn.Sequential(*list(self.resnet_model.children())[:-1])

        # Positional Encoding
        self.pe_type = pe_type
        if pe_type == 'regular':
            self.pos_encoder = VisionExpert.PositionalEncoding(d_model=2048)
        elif pe_type == 'rotary':
            self.pos_encoder = VisionExpert.RotaryPositionalEncoding(dim=2048)
        elif pe_type == 'contextual':
            self.pos_encoder = VisionExpert.ContextualPositionalEncoding(d_model=2048)

        # Custom layers (all sized for 2048-wide features)
        self.layernorm = nn.LayerNorm(normalized_shape=2048)
        self.dropout = nn.Dropout(p=0.1)
        self.attention = nn.MultiheadAttention(embed_dim=2048, num_heads=8)

        # Initialize sub-models
        self.vision_transformer = VisionExpert.VisionTransformer(num_classes=num_classes)
        self.detection_transformer = VisionExpert.DetectionTransformer(num_classes=num_classes)
        self.vision_mamba = VisionExpert.VisionMamba(num_classes=num_classes)

        # Sub Model Router: one logit per sub-model
        self.sub_model_gate = nn.Linear(2048, 3)  # 3 sub-models

        # Initialize Lory MoE.  The expert count now follows the constructor
        # argument instead of a hard-coded 4 (the default is unchanged).
        # NOTE(review): input_dim is 768 while forward feeds this layer the
        # sub-model class logits -- confirm the intended feature width.
        self.moe_layer = VisionExpert.LORY_MOE(
            input_dim=768,
            hidden_dim=512,
            num_experts=num_experts,
            num_classes=num_classes,
            num_layers=3,
            segment_size=256,
        )

        # Beam Search (using top-k sampling for simplicity)
        self.k = 5

    def forward(self, x):
        # CNN Image Encoder
        cnn_features = self.resnet_model(x).squeeze()
        
        # Reshape cnn_features to (batch_size, sequence_length, embedding_dim)
        cnn_features = cnn_features.view(cnn_features.size(0), -1, 2048).permute(1, 0, 2)

        # Positional Encoding
        if self.pe_type == 'regular':
            cnn_features = self.pos_encoder(cnn_features)
        elif self.pe_type == 'rotary':
            cnn_features = self.pos_encoder(cnn_features)
        elif self.pe_type == 'contextual':
            cnn_features = self.pos_encoder(cnn_features, mask=None)  # Update mask if needed

        # Blast Attention
        attn_output, _ = self.attention(cnn_features, cnn_features, cnn_features)
        attn_output = attn_output.permute(1, 0, 2)  # Permute back to (batch_size, sequence_length, embedding_dim)
        
        # LayerNorm and Dropout
        norm_output = self.layernorm(attn_output)
        norm_output = self.dropout(norm_output)
        
        # Sub Model Router
        gate_output = self.sub_model_gate(norm_output.mean(dim=1))
        sub_model_idx = gate_output.argmax(dim=-1)

        class_logits_list = []
        bbox_preds_list = []

        for i in range(sub_model_idx.size(0)):
            if sub_model_idx[i] == 0:
                # Use Vision Transformer
                class_logits, bbox_preds = self.vision_transformer(x[i].unsqueeze(0))
            elif sub_model_idx[i] == 1:
                # Use Detection Transformer
                image_input = x[i].unsqueeze(0)
                # Rescale image values to [0, 1]
                mean = torch.tensor([0.485, 0.456, 0.406], device=image_input.device).view(1, 3, 1, 1)
                std = torch.tensor([0.229, 0.224, 0.225], device=image_input.device).view(1, 3, 1, 1)
                image_input = image_input * std + mean
                class_logits, bbox_preds = self.detection_transformer(image_input)
            elif sub_model_idx[i] == 2:
                # Use Vision Mamba
                class_logits, bbox_preds = self.vision_mamba(x[i].unsqueeze(0))
            
            class_logits_list.append(class_logits)
            bbox_preds_list.append(bbox_preds)

        class_logits = torch.cat(class_logits_list, dim=0)
        bbox_preds = torch.cat(bbox_preds_list, dim=0)

        # Ensure the class_logits has 3 dimensions for the MoE layer
        class_logits = class_logits.unsqueeze(1)

        # MoE Layers
        moe_output = self.moe_layer(class_logits)

        # Blast Attention
        final_output, _ = self.attention(moe_output, moe_output, moe_output)
        
        # LayerNorm and Dropout
        final_output = self.layernorm(final_output)
        final_output = self.dropout(final_output)
        
        # Beam Search Decoding (simplified as top-k sampling)
        top_k_output = final_output.topk(self.k, dim=-1).values
        
        return top_k_output, bbox_preds



    def train_model(self, dataloader, num_epochs, learning_rate, mode='full'):
        """
        Train the Vision Expert model.
        
        Parameters:
        - dataloader: DataLoader for training data
        - num_epochs: number of epochs to train
        - learning_rate: learning rate for optimizer
        - mode: 'vit', 'mamba', 'detr', or 'full'
        """
        # Define the optimizer and loss function
        optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        criterion_cls = nn.CrossEntropyLoss()  # Assuming classification task
        criterion_bbox = nn.MSELoss()  # Assuming regression task for bounding boxes

        # Training loop
        for epoch in range(num_epochs):
            self.train()
            running_loss = 0.0
            
            for batch in dataloader:
                inputs = batch['inputs'].to(self.device)
                labels_cls = batch['labels'].to(self.device)
                labels_bbox = batch['boxes'].to(self.device) if 'boxes' in batch else None
                
                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass
                if mode == 'vit':
                    class_logits, bbox_preds = self.vision_transformer(inputs)
                elif mode == 'mamba':
                    class_logits, bbox_preds = self.vision_mamba(inputs)
                elif mode == 'detr':
                    class_logits, bbox_preds = self.detection_transformer(inputs)
                elif mode == 'full':
                    class_logits, bbox_preds = self(inputs)
                else:
                    raise ValueError("Mode must be 'vit', 'mamba', 'detr', or 'full'")

                # Reshape class_logits to match the shape of labels_cls
                batch_size, num_boxes, num_classes = class_logits.size()
                class_logits = class_logits.view(-1, num_classes)
                labels_cls = labels_cls.view(-1)
                
                # Filter out -1 labels (used for padding)
                valid_idx = labels_cls != -1
                class_logits = class_logits[valid_idx]
                labels_cls = labels_cls[valid_idx]

                loss_cls = criterion_cls(class_logits, labels_cls)
                loss_bbox = criterion_bbox(bbox_preds.view(-1, 4), labels_bbox.view(-1, 4)) if labels_bbox is not None else 0
                loss = loss_cls + loss_bbox

                # Backward pass and optimize
                loss.backward()
                optimizer.step()
                
                # Update running loss
                running_loss += loss.item() * inputs.size(0)
            
            epoch_loss = running_loss / len(dataloader.dataset)
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

        print("Training complete")








    ###################################################
    # POSITIONAL ENCODING

    class PositionalEncoding(nn.Module):
        """Fixed sinusoidal positional encoding added to sequence-first inputs.

        The table is stored as a (max_len, 1, d_model) buffer, so forward
        expects x with the sequence on axis 0 and d_model on the last axis.
        """

        def __init__(self, d_model, max_len=5000):
            super().__init__()
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
            angles = position * div_term

            pe = torch.zeros(max_len, d_model)
            pe[:, 0::2] = torch.sin(angles)
            # For an odd d_model there is one fewer cosine column than sine
            # column, so trim the angle table accordingly.
            pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
            self.register_buffer('pe', pe.unsqueeze(1))

        def forward(self, x):
            """Return x plus the first x.size(0) rows of the encoding table."""
            return x + self.pe[:x.size(0), :]
        
    class RotaryPositionalEncoding(nn.Module):
        """Rotary positional encoding using the rotate-half formulation.

        The buffer `emb` holds [sin | cos] of the position/frequency angles,
        one row per position.  Positions are taken from the second-to-last
        axis of the input, and `dim` must be even so the feature axis splits
        into two equal halves.
        """

        def __init__(self, dim, max_len=5000):
            super().__init__()
            inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
            pos = torch.arange(max_len).float()
            sinusoid_inp = torch.einsum('i,j->ij', pos, inv_freq)
            emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
            self.register_buffer('emb', emb)

        def forward(self, x):
            """Rotate x (..., positions, dim) by its position-dependent angles."""
            return self.apply_rotary_pos_emb(x, self.emb[:x.size(-2), :])

        def apply_rotary_pos_emb(self, x, rope):
            """Apply the rotation described by rope (positions, dim) to x.

            The previous einsum treated the 2-D rope table as a vector and
            could not run.  Each feature pair (x1[i], x2[i]) is now rotated by
            its angle, which leaves position 0 unchanged and preserves norms.
            """
            sin, cos = rope.chunk(2, dim=-1)
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
        
    class ContextualPositionalEncoding(nn.Module):
        """Learned additive positional table of shape (1, max_len, d_model)."""

        def __init__(self, d_model, max_len=5000):
            super().__init__()
            self.max_len = max_len
            self.d_model = d_model
            # Starts at zero, so the encoding is a no-op until trained.
            self.pe = nn.Parameter(torch.zeros(1, max_len, d_model))

        def forward(self, x, mask):
            """Add the first x.size(1) table rows to x; `mask` is accepted but unused."""
            return x + self.pe[:, :x.size(1), :]


    ####################################################
    # VISION TRANSFORMER
    class VisionTransformer(nn.Module):
        """Pretrained ViT encoder followed by a small attention/MLP head.

        forward returns one class-logit row and one 4-value box row per image,
        both computed from the first token of the ViT output.
        """

        def __init__(self, num_classes=80):
            super().__init__()

            # Pretrained ViT model and its matching feature extractor
            checkpoint = 'google/vit-base-patch16-224-in21k'
            self.vit_feature_extractor = ViTFeatureExtractor.from_pretrained(checkpoint)
            self.vit_model = ViTModel.from_pretrained(checkpoint)

            # Head layers applied on top of the 768-wide ViT embedding
            self.embedding_layer = nn.Linear(768, 512)
            self.layernorm1 = nn.LayerNorm(512)
            self.dropout1 = nn.Dropout(p=0.1)
            self.attention = nn.MultiheadAttention(embed_dim=512, num_heads=8)
            self.layernorm2 = nn.LayerNorm(512)
            self.dropout2 = nn.Dropout(p=0.1)
            # Named moe_layer, but it is a single linear layer here.
            self.moe_layer = nn.Linear(512, 512)
            self.classifier = nn.Linear(512, num_classes)
            self.bbox_predictor = nn.Linear(512, 4)  # Predicts bounding boxes

        def forward(self, x):
            """Return (class_logits, bbox_preds) for the image batch x."""
            # Undo mean/std normalisation (assumes x was normalised with these
            # statistics -- TODO confirm against the data pipeline).
            channel_mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1, 3, 1, 1)
            channel_std = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1, 3, 1, 1)
            pixels = x * channel_std + channel_mean

            # Feature extractor -> ViT; keep only the first ([CLS]) token.
            encoded = self.vit_feature_extractor(images=pixels, return_tensors="pt").to(x.device)
            cls_embedding = self.vit_model(**encoded).last_hidden_state[:, 0, :]

            hidden = self.dropout1(self.layernorm1(self.embedding_layer(cls_embedding)))

            # Blast Attention over a length-1 sequence: (1, batch, 512)
            as_sequence = hidden.unsqueeze(0)
            attended, _ = self.attention(as_sequence, as_sequence, as_sequence)

            hidden = self.dropout2(self.layernorm2(attended.squeeze(0)))
            hidden = F.relu(self.moe_layer(hidden))

            # Classification and bounding-box heads share the same features.
            return self.classifier(hidden), self.bbox_predictor(hidden)


    ####################################################
    # DETECTION TRANSFORMER

    class DetectionTransformer(nn.Module):
        """Pretrained DETR model followed by a small attention/MLP head.

        Unlike the ViT wrapper, the whole `last_hidden_state` is kept, so the
        outputs carry one row per element of its second axis.
        """

        def __init__(self, num_classes=80):
            super().__init__()

            # Pretrained DETR model and its matching feature extractor
            checkpoint = 'facebook/detr-resnet-50'
            self.detr_feature_extractor = DetrFeatureExtractor.from_pretrained(checkpoint)
            self.detr_model = DetrForObjectDetection.from_pretrained(checkpoint)

            # Head layers applied on top of the 256-wide DETR embeddings
            self.embedding_layer = nn.Linear(256, 512)
            self.layernorm1 = nn.LayerNorm(512)
            self.dropout1 = nn.Dropout(p=0.1)
            self.attention = nn.MultiheadAttention(embed_dim=512, num_heads=8)
            self.layernorm2 = nn.LayerNorm(512)
            self.dropout2 = nn.Dropout(p=0.1)
            # Named moe_layer, but it is a single linear layer here.
            self.moe_layer = nn.Linear(512, 512)
            self.classifier = nn.Linear(512, num_classes)
            self.bbox_predictor = nn.Linear(512, 4)  # Predicts bounding boxes

        def forward(self, x):
            """Return (class_logits, bbox_preds) for the image batch x."""
            # Feature extractor -> DETR; keep every output embedding.
            encoded = self.detr_feature_extractor(images=x, return_tensors="pt").to(x.device)
            embeddings = self.detr_model(**encoded).last_hidden_state

            hidden = self.dropout1(self.layernorm1(self.embedding_layer(embeddings)))

            # Blast Attention: MultiheadAttention expects (L, N, E), so swap
            # the first two axes on the way in and back again on the way out.
            seq_first = hidden.permute(1, 0, 2)
            attended, _ = self.attention(seq_first, seq_first, seq_first)

            hidden = self.dropout2(self.layernorm2(attended.permute(1, 0, 2)))
            hidden = F.relu(self.moe_layer(hidden))

            # Classification and bounding-box heads share the same features.
            return self.classifier(hidden), self.bbox_predictor(hidden)


    ####################################################
    # VISION MAMBA

    class SSM(nn.Module):
        """State-space style mixing layer parameterised by A, B, C and delta.

        NOTE(review): forward hands 2-D parameters to operations that appear
        to need 3-D operands (see the inline notes), so this layer may not be
        runnable as written.  The intended recurrence cannot be determined
        from this file; confirm before relying on it.
        """
        def __init__(self, input_dim, state_dim, hidden_dim):
            super(VisionExpert.SSM, self).__init__()
            self.state_dim = state_dim
            # All four parameters are 2-D and initialised from a standard normal.
            self.A = nn.Parameter(torch.randn(input_dim, state_dim))
            self.B = nn.Parameter(torch.randn(state_dim, hidden_dim))
            self.C = nn.Parameter(torch.randn(hidden_dim, input_dim))
            self.delta = nn.Parameter(torch.randn(state_dim, hidden_dim))
            self.activation = nn.SiLU()
            
        def forward(self, x):
            # NOTE(review): F.conv1d is documented to take a 3-D weight
            # (out_channels, in_channels / groups, kernel_size); self.A is 2-D.
            x = self.activation(F.conv1d(x, self.A))
            # NOTE(review): the 'bmn' subscript names three axes, but
            # self.delta is a 2-D parameter -- confirm these einsums can run.
            B_tilde = torch.einsum('bmn,nk->bmk', self.delta, self.B)
            C_tilde = torch.einsum('bmn,nk->bmk', self.delta, self.C)
            # Contract the activated input with B_tilde, then with C_tilde.
            y = torch.einsum('bmn,bmk->bkn', x, B_tilde)
            y = torch.einsum('bkn,bnm->bkm', y, C_tilde)
            return y

    class VimBlock(nn.Module):
        """Residual block that sums a forward and a reversed SSM pass."""

        def __init__(self, input_dim, state_dim, hidden_dim):
            super().__init__()
            self.layernorm1 = nn.LayerNorm(input_dim)
            self.layernorm2 = nn.LayerNorm(hidden_dim)
            self.ssm_forward = VisionExpert.SSM(input_dim, state_dim, hidden_dim)
            self.ssm_backward = VisionExpert.SSM(input_dim, state_dim, hidden_dim)
            self.proj = nn.Linear(hidden_dim, input_dim)
            self.dropout = nn.Dropout(p=0.1)

        def forward(self, x):
            """Return x plus the projected, normalised bidirectional SSM output."""
            normed = self.layernorm1(x)
            forward_pass = self.ssm_forward(normed)
            # Run the second SSM on the sequence reversed along axis 1, then
            # flip its output back so both passes line up position by position.
            backward_pass = self.ssm_backward(normed.flip(1)).flip(1)
            update = self.proj(self.layernorm2(forward_pass + backward_pass))
            return x + self.dropout(update)

    class VisionMamba(nn.Module):
        """Patch-based encoder built from VimBlocks with class and box heads.

        Images are cut into patch_size x patch_size patches, linearly
        projected to input_dim, prefixed with a learned class token and passed
        through num_blocks VimBlocks.  The class token feeds the classifier;
        the mean over all tokens feeds the box head.
        """

        def __init__(self, image_size=224, patch_size=16, input_dim=768, state_dim=128, hidden_dim=256, num_classes=80, num_blocks=12):
            super().__init__()
            self.patch_size = patch_size
            self.num_patches = (image_size // patch_size) ** 2
            self.linear_proj = nn.Linear(patch_size * patch_size * 3, input_dim)
            # One position per patch plus one for the class token.
            self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, input_dim))
            self.class_token = nn.Parameter(torch.randn(1, 1, input_dim))
            self.vim_blocks = nn.ModuleList([VisionExpert.VimBlock(input_dim, state_dim, hidden_dim) for _ in range(num_blocks)])
            self.norm = nn.LayerNorm(input_dim)
            self.mlp_head = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, num_classes)
            )
            # The box head reads the pooled tokens, which are input_dim wide;
            # it was previously sized for hidden_dim and could not be applied.
            self.bbox_predictor = nn.Linear(input_dim, 4)  # Predicts bounding boxes

        def forward(self, x):
            """Return (logits, bbox_preds) for an image batch x of shape (B, C, H, W).

            H and W must match the image_size given at construction, since the
            position embedding has a fixed number of rows.
            """
            batch_size = x.size(0)

            # Preprocess input images into flattened patches
            x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)
            x = self.linear_proj(x)

            # Add class token and position embeddings
            class_tokens = self.class_token.expand(batch_size, -1, -1)
            x = torch.cat((class_tokens, x), dim=1)
            x = x + self.position_embedding

            # Pass through Vim blocks
            for block in self.vim_blocks:
                x = block(x)

            # Classification head on the class token, box head on the token mean
            x = self.norm(x)
            logits = self.mlp_head(x[:, 0])
            bbox_preds = self.bbox_predictor(x.mean(dim=1))

            return logits, bbox_preds

    ####################################################
    # LORY MIXTURE OF EXPERTS
    class Expert(nn.Module):
        """Two-layer feed-forward expert: input_dim -> hidden_dim -> input_dim."""

        def __init__(self, input_dim, hidden_dim):
            super().__init__()
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, input_dim)
            self.activation = nn.ReLU()

        def forward(self, x):
            """Apply fc1, the ReLU activation and fc2 in sequence."""
            return self.fc2(self.activation(self.fc1(x)))

    class Router(nn.Module):
        """Linear gate producing a softmax distribution over the experts."""

        def __init__(self, input_dim, num_experts):
            super().__init__()
            self.fc = nn.Linear(input_dim, num_experts)

        def forward(self, x):
            """Return per-expert weights that sum to 1 along the last axis."""
            gate_logits = self.fc(x)
            return gate_logits.softmax(dim=-1)

    class MoELayer(nn.Module):
        """Segment-level soft mixture of experts.

        The sequence is cut into consecutive segments of `segment_size`
        tokens.  Each segment gets one set of routing weights, computed from
        its mean token, and its output is the weighted sum of every expert
        applied to that segment.
        """

        def __init__(self, input_dim, hidden_dim, num_experts):
            super().__init__()
            self.experts = nn.ModuleList([VisionExpert.Expert(input_dim, hidden_dim) for _ in range(num_experts)])
            self.router = VisionExpert.Router(input_dim, num_experts)

        def forward(self, x, segment_size):
            """Mix the experts per segment of x (batch, seq_len, dim).

            Returns a tensor with the same shape as x.  A final segment
            shorter than segment_size (including a sequence shorter than one
            segment) is routed like any other segment.

            Raises:
            - ValueError: if segment_size is not positive
            """
            if segment_size <= 0:
                raise ValueError("segment_size must be a positive integer")

            _, seq_len, _ = x.size()
            if seq_len == 0:
                return x

            mixed_segments = []
            for start in range(0, seq_len, segment_size):
                # Slice along the sequence axis so each iteration sees a whole
                # segment rather than a single token position.
                segment = x[:, start:start + segment_size, :]            # (B, S, D)
                routing_weights = self.router(segment.mean(dim=1))        # (B, E)
                expert_outputs = torch.stack(
                    [expert(segment) for expert in self.experts], dim=1
                )                                                         # (B, E, S, D)
                mixed = (routing_weights[:, :, None, None] * expert_outputs).sum(dim=1)
                mixed_segments.append(mixed)

            return torch.cat(mixed_segments, dim=1)

    class LORY_MOE(nn.Module):
        """Stack of MoE layers followed by mean pooling and a linear classifier."""

        def __init__(self, input_dim, hidden_dim, num_experts, num_classes, num_layers, segment_size):
            super().__init__()
            self.segment_size = segment_size
            self.moelayers = nn.ModuleList([VisionExpert.MoELayer(input_dim, hidden_dim, num_experts) for _ in range(num_layers)])
            self.fc = nn.Linear(input_dim, num_classes)

        def forward(self, x):
            """Run x through every MoE layer, average over axis 1, then classify."""
            hidden = x
            for layer in self.moelayers:
                hidden = layer(hidden, self.segment_size)
            pooled = hidden.mean(dim=1)
            return self.fc(pooled)


# Train Vision Expert on COCO

# Initialize the DETR image processor (applied to every image in
# CustomCocoDataset.__getitem__).
processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')

# Load the dataset: only the first 1% of the COCO training split.
dataset = load_dataset("detection-datasets/coco", split='train[:1%]')

def preprocess_image(image):
    """Resize an image to 800x800, convert it to a tensor and print its value range."""
    # to_tensor is the functional form of the ToTensor transform (converts to [0, 1]).
    image = TF.to_tensor(TF.resize(image, (800, 800)))

    # Debug output: report the pixel range after conversion.
    print(f"Min and Max pixel values after normalization: {image.min()}, {image.max()}")
    return image

class CustomCocoDataset(Dataset):
    """Torch dataset that turns COCO-style samples into model-ready tensors.

    Each item is a dict with:
    - 'inputs': the processor's pixel values for one image
    - 'boxes':  float tensor of shape (num_objects, 4) as [x0, y0, x1, y1],
                scaled to target_size
    - 'labels': long tensor of shape (num_objects,)
    """

    def __init__(self, dataset, target_size=(800, 800)):
        """Store the wrapped dataset and the (width, height) used to scale boxes.

        NOTE(review): preprocess_image resizes to a fixed 800x800, so a
        different target_size would scale the boxes but not the image.
        """
        self.dataset = dataset
        self.target_size = target_size

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        image = sample['image']

        # Ensure image is in RGB
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Preprocess image (resize + conversion to a [0, 1] tensor)
        image = preprocess_image(image)

        # The tensor is already scaled to [0, 1]; do_rescale=False stops the
        # processor from dividing by 255 a second time before normalising.
        inputs = processor(images=image, return_tensors="pt", do_rescale=False)

        # Scale bounding boxes according to the new image size.
        # NOTE(review): boxes are read as [x, y, width, height]; confirm this
        # matches the dataset's annotation format.
        scale_x = self.target_size[0] / sample['width']
        scale_y = self.target_size[1] / sample['height']
        bboxes = [
            [x * scale_x, y * scale_y, (x + w) * scale_x, (y + h) * scale_y]
            for x, y, w, h in sample['objects']['bbox']
        ]
        # reshape keeps a (0, 4) shape for images without any object, so the
        # result can still be written into a padded (max_boxes, 4) batch.
        tensor_boxes = torch.tensor(bboxes, dtype=torch.float32).reshape(-1, 4)

        labels = torch.tensor(sample['objects']['category'], dtype=torch.long)
        return {'inputs': inputs['pixel_values'].squeeze(0), 'boxes': tensor_boxes, 'labels': labels}

# Adjust collate function as needed:
def collate_fn(batch):
    """Stack images and pad per-image boxes/labels to a common length.

    Parameters:
    - batch: list of dicts with 'inputs', 'boxes' and 'labels' tensors

    Returns a dict with:
    - 'inputs':    stacked image tensors
    - 'boxes':     (batch, max_boxes, 4) float tensor, zero-padded
    - 'labels':    (batch, max_boxes) long tensor, padded with -1
    - 'box_masks': (batch, max_boxes) bool tensor, True for real boxes
    """
    inputs = torch.stack([item['inputs'] for item in batch])

    # Normalise every box tensor to (n, 4) so an image without objects, whose
    # boxes may arrive with shape (0,), can still be written into the padding.
    boxes_per_item = [item['boxes'].reshape(-1, 4) for item in batch]
    max_boxes = max(boxes.shape[0] for boxes in boxes_per_item)

    padded_boxes = torch.zeros((len(batch), max_boxes, 4))
    box_masks = torch.zeros((len(batch), max_boxes), dtype=torch.bool)
    # Fill labels that are not present with -1
    padded_labels = torch.full((len(batch), max_boxes), -1, dtype=torch.long)

    for i, (item, boxes) in enumerate(zip(batch, boxes_per_item)):
        num_boxes = boxes.shape[0]
        padded_boxes[i, :num_boxes] = boxes
        padded_labels[i, :num_boxes] = item['labels']
        box_masks[i, :num_boxes] = True

    return {'inputs': inputs, 'boxes': padded_boxes, 'labels': padded_labels, 'box_masks': box_masks}

# Create the dataset and data loader.  Batches of two shuffled samples are
# assembled by collate_fn, which pads the per-image boxes and labels.
coco_dataset = CustomCocoDataset(dataset)
dataloader = DataLoader(coco_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

# The bare string below is a disabled debugging snippet (it is never
# executed): it would print the shapes and contents of the first batch.
'''# Fetch the first batch
for batch in dataloader:
    print("Batch overview:")
    print(f"Inputs Shape: {batch['inputs'].shape}")  # Should be [batch_size, C, H, W]
    print(f"Boxes Shape: {batch['boxes'].shape}")  # Should be [batch_size, max_boxes_per_image, 4]
    print(f"Labels Shape: {batch['labels'].shape}")  # Should be [batch_size, max_boxes_per_image]
    print(f"Box Masks Shape: {batch['box_masks'].shape}")  # Should be [batch_size, max_boxes_per_image]

    # Optionally, you can print more detailed information about a single sample in the batch
    print("\nDetailed view of first sample in batch:")
    print(f"First Sample - Pixel Values (inputs): {batch['inputs'][0]}")  # Showing the actual pixel values can be too verbose, consider showing stats
    print(f"First Sample - Bounding Boxes: {batch['boxes'][0]}")
    print(f"First Sample - Labels: {batch['labels'][0]}")
    print(f"First Sample - Box Masks: {batch['box_masks'][0]}")

    # Since we only want to check the first batch, break after the first iteration
    break'''


# Function to collect labels from the dataset
def get_unique_labels(dataset):
    """Count how often each category id occurs across all samples of the dataset."""
    return Counter(
        label
        for sample in dataset
        for label in sample['objects']['category']
    )

# Disabled snippet (bare string, never executed): prints label frequencies.
'''# Collecting unique labels
unique_labels = get_unique_labels(dataset)
print("Unique labels and their counts:")
for label, count in unique_labels.items():
    print(f"Label {label}: {count} occurrences")'''

# Usage example
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pass the detected device to the model as well: train_model moves every batch
# to model.device, which otherwise defaults to 'cuda' even on CPU-only hosts.
model = VisionExpert(device=device).to(device)

# Disabled snippet (bare string, never executed): trains each sub-model alone.
'''# Train with different modes
print("Training Vision Transformer...")
model.train_model(dataloader, num_epochs=1, learning_rate=1e-4, mode='vit')

print("Training Vision Mamba...")
model.train_model(dataloader, num_epochs=1, learning_rate=1e-4, mode='mamba')

print("Training Detection Transformer...")
model.train_model(dataloader, num_epochs=1, learning_rate=1e-4, mode='detr')
'''
print("Training Full Vision Expert...")
model.train_model(dataloader, num_epochs=1, learning_rate=1e-4, mode='full')
