<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/_Multi_Modal_Model_Design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import copy
import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer

# Example: A simple model definition
class MultimodalModel(nn.Module):
    """Late-fusion classifier over BERT text features, an image and sensor readings.

    Each modality is embedded separately, the three embeddings are concatenated
    (text, image, sensor) and a single linear head produces the output logits.
    With the default arguments the layer shapes are identical to the original
    hard-coded design: a flattened 256x256 image, 10 sensor features, 128-wide
    image/sensor projections and one logit for binary classification.

    Args:
        image_features: Number of values in one flattened image
            (default ``256 * 256``).
        sensor_dim: Number of sensor features per sample.
        hidden_dim: Output width of the image and sensor projections.
        num_outputs: Number of output logits (1 for binary classification
            with ``BCEWithLogitsLoss``).
        bert_name: Pretrained checkpoint name passed to
            ``BertModel.from_pretrained``.
    """

    def __init__(self, image_features=256 * 256, sensor_dim=10, hidden_dim=128,
                 num_outputs=1, bert_name='bert-base-uncased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_name)  # BERT for text
        # Read the text embedding width from the loaded config rather than
        # hard-coding 768, so other BERT checkpoints work without edits.
        text_dim = self.bert.config.hidden_size
        self.fc_image = nn.Linear(image_features, hidden_dim)
        self.fc_sensor = nn.Linear(sensor_dim, hidden_dim)
        # Input width follows the torch.cat order in forward(): text, image, sensor.
        self.fc_final = nn.Linear(text_dim + hidden_dim + hidden_dim, num_outputs)

    def forward(self, input_ids, attention_mask, image_data, sensor_data):
        """Return raw logits of shape ``(batch, num_outputs)``.

        Args:
            input_ids: Token ids for BERT, ``(batch, seq_len)``.
            attention_mask: BERT attention mask, same shape as ``input_ids``.
            image_data: Image tensor whose non-batch dims flatten to
                ``image_features`` values.
            sensor_data: Sensor tensor, ``(batch, sensor_dim)``.
        """
        # pooler_output is BERT's pooled sequence representation, (batch, hidden_size).
        text_embedding = self.bert(input_ids, attention_mask=attention_mask).pooler_output

        # Flatten everything after the batch dimension before projecting.
        image_embedding = self.fc_image(torch.flatten(image_data, 1))

        sensor_embedding = self.fc_sensor(sensor_data)

        combined = torch.cat((text_embedding, image_embedding, sensor_embedding), dim=1)
        return self.fc_final(combined)

# Load the BERT WordPiece tokenizer once; the dataset and the inference code both reuse it.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

# Sample Dataset definition (replace with actual data loading logic)
class MultimodalDataset(Dataset):
    """In-memory dataset yielding one text/image/sensor/label sample per index.

    Text is tokenized lazily in ``__getitem__`` (padded/truncated to a fixed
    length). Images, sensor readings and labels are converted to float32
    tensors.

    Args:
        text_data: Sequence of raw strings.
        image_data: Sequence of array-likes, one image per sample.
        sensor_data: Sequence of array-likes, one sensor vector per sample.
        labels: Sequence of numeric labels (float32 tensors, as expected by
            ``BCEWithLogitsLoss``).
        tokenizer: Hugging Face-style tokenizer callable. When omitted, the
            module-level ``tokenizer`` is used, as in the original code.
        max_length: Token length every text is padded/truncated to.

    Raises:
        ValueError: If the four input sequences differ in length.
    """

    def __init__(self, text_data, image_data, sensor_data, labels, tokenizer=None, max_length=128):
        # Catch misaligned inputs up front instead of failing (or silently
        # dropping samples) halfway through an epoch.
        if not len(text_data) == len(image_data) == len(sensor_data) == len(labels):
            raise ValueError(
                "text_data, image_data, sensor_data and labels must have the same length, got "
                f"{len(text_data)}, {len(image_data)}, {len(sensor_data)}, {len(labels)}"
            )
        self.text_data = text_data
        self.image_data = image_data
        self.sensor_data = sensor_data
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.text_data)

    def __getitem__(self, idx):
        # Resolve the module-level tokenizer at call time when none was given,
        # preserving the original behaviour for existing callers.
        tok = self.tokenizer if self.tokenizer is not None else tokenizer

        encoding = tok(
            self.text_data[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),  # Remove batch dimension
            'attention_mask': encoding['attention_mask'].squeeze(0),  # Remove batch dimension
            'image': torch.tensor(self.image_data[idx], dtype=torch.float32),
            'sensor': torch.tensor(self.sensor_data[idx], dtype=torch.float32),
            'label': torch.tensor(self.labels[idx], dtype=torch.float32)
        }

# Placeholder samples (swap in a real dataset): two examples for every modality.
text_data = ["This is a sample text.", "Another sample text."]
image_data = [np.random.rand(256, 256) for _ in range(2)]  # two random 256x256 images
sensor_data = [np.random.rand(10) for _ in range(2)]  # two random 10-feature sensor vectors
labels = [0, 1]  # one binary target per sample

# Wrap the samples in a Dataset; batch_size=2 puts both samples in a single shuffled batch.
train_dataset = MultimodalDataset(
    text_data=text_data,
    image_data=image_data,
    sensor_data=sensor_data,
    labels=labels,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the model on that device, then attach Adam to every parameter
# (the BERT encoder is fine-tuned together with the new linear layers).
model = MultimodalModel().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# The model emits raw logits, so pair it with the sigmoid-fused BCE loss.
criterion = nn.BCEWithLogitsLoss()

# Early stopping parameters
best_val_loss = float('inf')
patience = 3  # epochs without validation-loss improvement before stopping
epochs_since_improvement = 0
best_model = None  # independent copy of the best state_dict seen so far

# NOTE(review): there is no held-out split yet, so "validation" re-evaluates
# the training data. Point this at a real validation DataLoader.
val_loader = train_loader


def run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_batch_loss, accuracy)``.

    If ``optimizer`` is given, the model is trained: train mode, gradients on,
    one optimizer step per batch. Otherwise it is evaluated in eval mode with
    gradients disabled. Accuracy rounds the sigmoid of the single output logit.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # model.train(False) is exactly model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0

    with torch.set_grad_enabled(training):
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            image_data = batch['image'].to(device)
            sensor_data = batch['sensor'].to(device)
            labels = batch['label'].to(device)

            if training:
                optimizer.zero_grad()

            # Squeeze (batch, 1) logits to (batch,) so they line up with the labels.
            # Comparing (batch, 1) against (batch,) would broadcast to (batch, batch)
            # and miscount correct predictions.
            logits = model(input_ids, attention_mask, image_data, sensor_data).squeeze(1)
            loss = criterion(logits, labels)

            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            predicted = torch.sigmoid(logits).round()
            correct += (predicted == labels).sum().item()
            seen += labels.size(0)

    if seen == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / len(loader), correct / seen


# Training Loop
num_epochs = 20
for epoch in range(1, num_epochs + 1):
    train_loss, train_accuracy = run_epoch(model, train_loader, criterion, device, optimizer)
    val_loss, val_accuracy = run_epoch(model, val_loader, criterion, device)

    print(f"Epoch {epoch}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
    print(f"Epoch {epoch}/{num_epochs}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict() returns references to the live parameter tensors, which
        # later optimizer steps keep overwriting; deep-copy to freeze a snapshot.
        best_model = copy.deepcopy(model.state_dict())
        epochs_since_improvement = 0
    else:
        epochs_since_improvement += 1

    if epochs_since_improvement >= patience:
        print(f"Early stopping on epoch {epoch}")
        break

# Write the best snapshot to disk; nothing is written if no epoch ever improved on the initial loss.
checkpoint_path = 'best_model.pth'
if best_model is not None:
    torch.save(obj=best_model, f=checkpoint_path)
    print(f"Best model saved as '{checkpoint_path}'")

# Inference Phase
# Example inference with the best saved checkpoint
test_data = ["This is a test text."]
image_test = [np.random.rand(256, 256)]  # Random test image
sensor_test = [np.random.rand(10)]  # Random test sensor data

# Tokenize with the same padding/truncation settings as MultimodalDataset.
test_encoding = tokenizer(test_data, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
text_input_ids = test_encoding['input_ids'].to(device)
text_attention_mask = test_encoding['attention_mask'].to(device)

# Stack the per-sample numpy arrays into one batched array before converting.
# torch.tensor() on a Python list of ndarrays is slow and emits a UserWarning.
image_tensor = torch.tensor(np.stack(image_test), dtype=torch.float32).to(device)
sensor_tensor = torch.tensor(np.stack(sensor_test), dtype=torch.float32).to(device)

# Rebuild the architecture and load the saved weights into it.
best_model = MultimodalModel().to(device)
# map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine.
state_dict = torch.load('best_model.pth', map_location=device, weights_only=True)
best_model.load_state_dict(state_dict)
best_model.eval()

# Make prediction: sigmoid turns the single logit into a probability.
with torch.no_grad():
    output = best_model(text_input_ids, text_attention_mask, image_tensor, sensor_tensor)
    prediction = torch.sigmoid(output).cpu().numpy()
    print(f"Prediction: {prediction}")