In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Temporal Encoding for Sequence Data
class TemporalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    Args:
        embed_dim: feature width of the sequence (odd widths are supported).
        max_len: longest sequence the precomputed table covers.
    """

    def __init__(self, embed_dim, max_len=500):
        super(TemporalEncoding, self).__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / embed_dim)
        )  # (ceil(embed_dim / 2),)
        encoding = torch.zeros(max_len, embed_dim)
        encoding[:, 0::2] = torch.sin(position * div_term)
        # With an odd embed_dim there is one fewer odd column than even column; trim div_term
        # so the shapes match instead of raising.
        encoding[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        # A buffer follows module.to()/.cuda()/.half(); persistent=False keeps it out of
        # state_dict so checkpoints saved before this change still load with strict=True.
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)  # (1, max_len, embed_dim)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: tensor indexed as ``(batch, seq_len, embed_dim)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(1)
        max_len = self.encoding.size(1)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")
        return x + self.encoding[:, :seq_len, :].to(device=x.device, dtype=x.dtype)

# Modality-Specific Tokenizers
class ModalityTokenizer(nn.Module):
    """Project one modality's raw features into the shared embedding space.

    Pipeline: ``Linear(input_dim -> embed_dim)`` -> ``LayerNorm(embed_dim)`` -> ``GELU``.
    """

    def __init__(self, input_dim, embed_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)
        self.activation = nn.GELU()

    def forward(self, x):
        """Map features whose last dimension is ``input_dim`` to ``embed_dim`` tokens."""
        return self.activation(self.norm(self.fc(x)))

# Shared Transformer Encoder
class SharedTransformer(nn.Module):
    """Stack of Transformer encoder layers shared by every task.

    Args:
        embed_dim: model width (``d_model``).
        num_heads: attention heads per layer.
        num_layers: number of stacked encoder layers.
        dim_feedforward: hidden width of each layer's feed-forward block.
        batch_first: if True (default), inputs and outputs are ``(batch, seq, embed_dim)``.
            PyTorch's own default is ``(seq, batch, embed_dim)``. Under that default, a
            batch-first tensor makes attention run across the batch and mix different samples.
    """

    def __init__(self, embed_dim, num_heads, num_layers, dim_feedforward=2048, batch_first=True):
        super(SharedTransformer, self).__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            activation="gelu",
            batch_first=batch_first,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, x):
        """Encode ``x`` (layout set by ``batch_first``); the output has the same shape."""
        return self.encoder(x)

# Task-Specific Heads
class TaskHead(nn.Module):
    """Mean-pool a token sequence and regress task outputs with a two-layer MLP.

    Args:
        embed_dim: width of the incoming tokens.
        output_dim: number of predicted values per sample.
    """

    def __init__(self, embed_dim, output_dim):
        super().__init__()
        hidden_dim = 256
        layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Average ``x`` over dim 1 (the sequence axis for batch-first input), then apply the MLP."""
        pooled = torch.mean(x, dim=1)
        return self.fc(pooled)


# Multi-Modal Transformer Model
class MultiModalTransformer(nn.Module):
    """Multi-task model: per-modality tokenizers -> modality averaging -> shared encoder -> task head.

    Args:
        modalities: mapping of modality name -> raw input feature dimension; one
            ``ModalityTokenizer`` is built per entry.
        embed_dim: token width shared by the tokenizers, the encoder and the heads.
        num_heads: attention heads per encoder layer.
        num_layers: number of encoder layers.
        max_seq_len: longest sequence the temporal encoding supports.
        tasks: optional mapping of task name -> head output dimension. Defaults to the three
            6-output pose tasks (x, y, z, alpha, beta, gamma).
    """

    DEFAULT_TASKS = {
        "strawberry_picking": 6,
        "vertebrae_scanning": 6,
        "autonomous_vehicle": 6,
    }

    def __init__(self, modalities, embed_dim, num_heads, num_layers, max_seq_len=500, tasks=None):
        super(MultiModalTransformer, self).__init__()
        self.tokenizers = nn.ModuleDict({
            name: ModalityTokenizer(input_dim, embed_dim) for name, input_dim in modalities.items()
        })
        self.temporal_encoding = TemporalEncoding(embed_dim, max_len=max_seq_len)
        self.shared_transformer = SharedTransformer(embed_dim, num_heads, num_layers)
        task_dims = self.DEFAULT_TASKS if tasks is None else tasks
        self.task_heads = nn.ModuleDict({
            name: TaskHead(embed_dim, output_dim=out_dim) for name, out_dim in task_dims.items()
        })

    def forward(self, inputs, task):
        """Predict ``task`` outputs from a dict of modality tensors.

        Args:
            inputs: mapping of modality name -> tensor whose last dimension is that modality's
                input dim, shaped either ``(batch, input_dim)`` or ``(batch, seq_len, input_dim)``.
                All modalities must share the same leading shape.
            task: key of ``self.task_heads``.

        Returns:
            The selected task head's output, ``(batch, output_dim)``.

        Raises:
            ValueError: if ``inputs`` is empty.
            KeyError: if ``inputs`` names an unknown modality or ``task`` is unknown.
        """
        if not inputs:
            raise ValueError("inputs must contain at least one modality")
        unknown = [name for name in inputs if name not in self.tokenizers]
        if unknown:
            raise KeyError(f"unknown modalities {unknown}; expected a subset of {list(self.tokenizers)}")
        if task not in self.task_heads:
            raise KeyError(f"unknown task {task!r}; expected one of {list(self.task_heads)}")

        # Tokenize each modality and average the embeddings across modalities. Stacking on a new
        # leading axis works for both (batch, embed_dim) and (batch, seq_len, embed_dim) tokens.
        tokens = torch.stack([self.tokenizers[name](x) for name, x in inputs.items()], dim=0)
        fused_tokens = tokens.mean(dim=0)
        if fused_tokens.dim() == 2:
            fused_tokens = fused_tokens.unsqueeze(1)  # (batch, 1, embed_dim): length-1 sequence
        # NOTE(review): with per-sample inputs the encoder sees a single token, so self-attention
        # cannot mix modalities; feeding the modality tokens as the sequence would allow it.
        encoded_tokens = self.temporal_encoding(fused_tokens)
        shared_features = self.shared_transformer(encoded_tokens)
        return self.task_heads[task](shared_features)

# Modality name -> raw feature dimension. Every modality used by any task must be listed,
# because the model builds one tokenizer per entry.
modalities = {
    # Strawberry picking
    "vision": 1024,
    "proprioception": 256,
    "tactile": 128,
    # Vertebrae scanning
    "pose": 64,
    "ultrasonic": 128,
    "detected_position": 64,
    # Autonomous vehicle
    "gps": 64,
    "imu": 256,
    "mmwave": 512,
    "lidar": 1024,
    "camera": 1024,
}

# Seed before building the model and sampling data so "Restart & Run All" is reproducible.
SEED = 0
BATCH_SIZE = 32
torch.manual_seed(SEED)

model = MultiModalTransformer(modalities, embed_dim=768, num_heads=12, num_layers=6)


def random_inputs(modality_names, batch_size=BATCH_SIZE):
    """Return {name: uniform random tensor of shape (batch_size, modalities[name])}."""
    return {name: torch.rand(batch_size, modalities[name]) for name in modality_names}


# Synthetic example batches; feature sizes come from `modalities` so the two cannot drift apart.
inputs_strawberry = random_inputs(["vision", "proprioception", "tactile"])
inputs_vertebrae = random_inputs(["pose", "ultrasonic", "detected_position"])
inputs_vehicle = random_inputs(["gps", "imu", "mmwave", "lidar", "camera"])

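# Testing with feature-extractor inputs

The cells below replace some of the random example features with embeddings from simple feature extractors (still applied to synthetic data).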

## Vision input (strawberry)

In [None]:
# Vision features come from a pretrained ResNet-50, a CNN whose skip connections ease gradient flow.
# Inputs are 224x224 RGB images. The backbone produces a 2048-d feature per image, and the
# replaced final fully connected layer projects it to 1024-d, the "vision" input dim of the model.
# NOTE(review): the new `resnet.fc` is randomly initialized and is not in `model.parameters()`,
# so it acts as a fixed random projection unless trained separately.


import torchvision.models as models
from torchvision.transforms import Compose, Resize, ToTensor

# Preprocessing for real PIL images. The synthetic tensors below are already 224x224, so it is
# not applied to them.
# NOTE(review): pretrained torchvision weights typically also expect mean/std Normalize — confirm.
transform = Compose([Resize((224, 224)), ToTensor()])
images = torch.rand(32, 3, 224, 224)  # synthetic batch: (batch_size, channels, height, width)

try:
    # IMAGENET1K_V1 is the checkpoint the legacy `pretrained=True` flag loaded.
    resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
except AttributeError:  # torchvision < 0.13 has no weight enums; fall back to the legacy flag
    resnet = models.resnet50(pretrained=True)
resnet.fc = nn.Linear(2048, 1024)  # output dimension matches the "vision" input_dim

# Run the backbone as a frozen feature extractor. eval() makes BatchNorm use its stored running
# statistics rather than this batch's, and stops it overwriting them. no_grad() avoids keeping a
# ResNet-50 autograd graph attached to the model inputs.
resnet.eval()
with torch.no_grad():
    processed_vision = resnet(images)  # (32, 1024)

inputs_strawberry["vision"] = processed_vision


## Lidar input (Driving)

In [None]:
# Lidar: each sample is a point cloud of 2048 points with (x, y, z) coordinates.
# The points are averaged per sample into one 3-d centroid. A linear layer then projects that
# centroid to the "lidar" input dim (1024).
# NOTE(review): for uniform random points every centroid is close to (0.5, 0.5, 0.5), and the
# projection is untrained, so this is only a shape-compatible placeholder.

lidar_data = torch.rand(32, 2048, 3)  # synthetic point cloud: (batch_size, num_points, xyz)

# Keep the projection in a named variable so the same weights can be reused on other batches.
# It is not part of `model`, so no gradients are tracked through it.
lidar_projection = nn.Linear(3, modalities["lidar"])
with torch.no_grad():
    lidar_embeddings = lidar_projection(lidar_data.mean(dim=1))  # (batch_size, 1024)

inputs_vehicle["lidar"] = lidar_embeddings


## GPS input (Driving)

In [None]:
# GPS: a linear layer lifts each raw (latitude, longitude) pair to the 64-d "gps" input dim,
# giving the transformer a higher-dimensional token than the 2 raw coordinates.
# NOTE(review): the coordinates are synthetic values in [0, 1) and the projection is untrained.

gps_data = torch.rand(32, 2)  # synthetic (latitude, longitude) per sample

# Named so the same weights can be reused; it is not part of `model`, so no gradients are tracked.
gps_projection = nn.Linear(2, modalities["gps"])
with torch.no_grad():
    gps_embeddings = gps_projection(gps_data)  # (batch_size, 64)
inputs_vehicle["gps"] = gps_embeddings


# Forward-pass outputs

In [None]:
# One forward pass per task; each head should return (batch_size, 6) pose predictions.
task_inputs = {
    "strawberry_picking": inputs_strawberry,
    "vertebrae_scanning": inputs_vertebrae,
    "autonomous_vehicle": inputs_vehicle,
}
task_outputs = {task: model(batch, task=task) for task, batch in task_inputs.items()}

for task, task_output in task_outputs.items():
    batch_size = next(iter(task_inputs[task].values())).shape[0]
    assert task_output.shape == (batch_size, 6), f"{task}: unexpected output shape {tuple(task_output.shape)}"
    print(task, tuple(task_output.shape))

# Per-task names used by later cells.
output_strawberry = task_outputs["strawberry_picking"]
output_vertebrae = task_outputs["vertebrae_scanning"]
output_vehicle = task_outputs["autonomous_vehicle"]

In [None]:
# Inspect a few strawberry-picking pose predictions (one row per sample) instead of the whole batch.
print(output_strawberry[:5])

# Training step (example)

In [None]:
# One illustrative optimization step for the strawberry-picking head against a random target.
criterion = nn.MSELoss()

# Recompute the forward pass here instead of reusing `output_strawberry`: its autograd graph is
# freed by the first backward(), so reusing it made this cell fail on a second run. Detaching the
# inputs keeps gradients out of the separately created feature extractors (resnet, projections),
# whose parameters are not in this optimizer.
model.train()
train_inputs = {name: tensor.detach() for name, tensor in inputs_strawberry.items()}
train_output = model(train_inputs, task="strawberry_picking")
target = torch.rand_like(train_output)  # placeholder regression target, (batch_size, 6)
loss = criterion(train_output, target)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer.zero_grad()
loss.backward()
optimizer.step()


In [None]:
# Report the scalar MSE value rather than the tensor repr with its grad_fn.
print(f"MSE loss: {loss.item():.4f}")