<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/_Longformer_and_CNN_Integration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install fastapi

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LongformerModel, LongformerTokenizer, AutoModel, AdamW
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from fastapi import FastAPI, Request
from torch.optim.lr_scheduler import LambdaLR
from sklearn.metrics import precision_score, recall_score, f1_score
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
import logging

# --- Configuration ---
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Logging setup: INFO and above, each record as "<timestamp> | <level> | <message>".
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)

# --- Text Dataset ---
class TextDataset(Dataset):
    """Dataset that tokenizes the ``"text"`` field of each record on access.

    Args:
        data: Indexable collection of mappings. Each record must contain a
            ``"text"`` entry, plus a ``"label"`` entry when
            ``for_classification`` is true.
        tokenizer: Callable invoked as ``tokenizer(text, padding="max_length",
            truncation=True, max_length=..., return_tensors="pt")``. It must
            return a mapping with ``"input_ids"`` and ``"attention_mask"``
            tensors that carry a leading batch dimension of size 1.
        max_length: Length every sample is padded or truncated to.
        for_classification: When true, ``__getitem__`` also returns the label.
    """

    def __init__(self, data, tokenizer, max_length=2048, for_classification=False):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.for_classification = for_classification

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record ``idx``.

        Returns:
            ``(input_ids, attention_mask)``, or
            ``(input_ids, attention_mask, label)`` in classification mode.
        """
        item = self.data[idx]
        encoding = self.tokenizer(
            item["text"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Drop only the leading batch dimension. A bare ``squeeze()`` would
        # also remove the sequence dimension when ``max_length == 1`` and
        # hand back a 0-d tensor.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        if self.for_classification:
            return input_ids, attention_mask, item["label"]
        return input_ids, attention_mask

# --- Perception Module ---
class PerceptionModule(nn.Module):
    """Fuse text, image and sensor inputs into one ``hidden_dim`` feature vector.

    Each modality is projected to ``hidden_dim``; the three projections are
    concatenated and passed through a final linear layer. A ReLU follows
    every projection.

    Args:
        text_dim: Unused; kept so existing callers keep working.
        image_dim: Indexable ``(channels, height, width)`` of the image input.
            The CNN halves height and width twice, so the flattened feature
            size is ``32 * (height // 4) * (width // 4)``.
        sensor_dim: Size of the last dimension of the sensor input.
        hidden_dim: Size of each per-modality projection and of the output.
        text_model: Name or path passed to ``LongformerModel.from_pretrained``.
    """

    def __init__(self, text_dim, image_dim, sensor_dim, hidden_dim, text_model="allenai/longformer-base-4096"):
        super().__init__()
        # Text model (Longformer)
        self.text_model = LongformerModel.from_pretrained(text_model)
        self.text_fc = nn.Linear(self.text_model.config.hidden_size, hidden_dim)

        # Image CNN
        self.image_cnn = nn.Sequential(
            nn.Conv2d(image_dim[0], 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
        )
        self.image_fc = nn.Linear(32 * (image_dim[1] // 4) * (image_dim[2] // 4), hidden_dim)

        # Sensor data processing
        self.sensor_fc = nn.Linear(sensor_dim, hidden_dim)

        # Combined feature layer
        self.fc = nn.Linear(hidden_dim * 3, hidden_dim)

    def forward(self, text, image, sensor):
        """Return the fused features for one batch.

        Args:
            text: Mapping of keyword arguments for the text model (unpacked
                with ``**``). When it contains ``"attention_mask"``, pooling
                ignores positions where the mask is 0.
            image: Batched image tensor fed to the CNN.
            sensor: Batched sensor tensor fed to ``sensor_fc``.

        Returns:
            Tensor whose last dimension is ``hidden_dim``.
        """
        hidden_states = self.text_model(**text).last_hidden_state
        mask = text["attention_mask"] if "attention_mask" in text.keys() else None
        if mask is None:
            pooled = hidden_states.mean(dim=1)
        else:
            # Average over real tokens only. A plain mean would also average
            # in padding positions, diluting the features of short texts
            # that were padded to a fixed length.
            weights = mask.unsqueeze(-1).to(torch.float32)
            token_counts = weights.sum(dim=1).clamp(min=1.0)  # avoid 0/0 on an all-pad row
            pooled = (hidden_states.to(torch.float32) * weights).sum(dim=1) / token_counts
            pooled = pooled.to(hidden_states.dtype)

        text_features = F.relu(self.text_fc(pooled))
        image_features = F.relu(self.image_fc(self.image_cnn(image)))
        sensor_features = F.relu(self.sensor_fc(sensor))
        combined_features = torch.cat((text_features, image_features, sensor_features), dim=1)
        return F.relu(self.fc(combined_features))

# --- Memory Module ---
class MemoryBank(nn.Module):
    """Fixed-size key/value store with per-slot access counts.

    ``write`` overwrites the slot with the lowest access count; ``read``
    returns the value whose key has the highest cosine similarity to the
    query and increments that slot's count.

    The three state tensors are registered as non-persistent buffers. They
    therefore follow ``module.to(...)`` like the rest of the model, while
    ``state_dict()`` keeps the same keys it had when they were plain
    attributes.

    Args:
        memory_size: Number of slots.
        memory_dim: Length of every key and value vector.
    """

    def __init__(self, memory_size, memory_dim):
        super().__init__()
        self.register_buffer("keys", torch.randn(memory_size, memory_dim).to(device), persistent=False)
        self.register_buffer("values", torch.randn(memory_size, memory_dim).to(device), persistent=False)
        self.register_buffer("access_count", torch.zeros(memory_size).to(device), persistent=False)

    @torch.no_grad()
    def write(self, key, value):
        """Store one or more key/value pairs.

        Args:
            key: Tensor of ``memory_dim`` elements, or a batch of such rows.
            value: Tensor with the same number of rows as ``key``.

        Raises:
            ValueError: If ``key`` and ``value`` hold different row counts.
        """
        # Detach before storing. An in-place write of a grad-carrying tensor
        # would attach the bank to the caller's autograd graph and keep every
        # past graph alive.
        key_rows = key.detach().to(self.keys).reshape(-1, self.keys.shape[-1])
        value_rows = value.detach().to(self.values).reshape(-1, self.values.shape[-1])
        if key_rows.shape[0] != value_rows.shape[0]:
            raise ValueError("key and value must contain the same number of rows")

        # NOTE(review): the written slot is reset to count 0, so it stays the
        # argmin until it is read. Consecutive writes (including the rows of
        # one batch) therefore land in the same slot. Policy kept as-is;
        # confirm it is intended.
        for key_row, value_row in zip(key_rows, value_rows):
            idx = torch.argmin(self.access_count)
            self.keys[idx] = key_row
            self.values[idx] = value_row
            self.access_count[idx] = 0

    @torch.no_grad()
    def read(self, key):
        """Return the stored value whose key best matches ``key``.

        Args:
            key: A single query of ``memory_dim`` elements; shapes ``(dim,)``
                and ``(1, dim)`` are both accepted.

        Returns:
            The selected row of ``values`` (indexed from the buffer, not
            copied).
        """
        query = key.detach().to(self.keys).reshape(1, -1)
        idx = torch.argmax(torch.cosine_similarity(self.keys, query, dim=1))
        self.access_count[idx] += 1
        return self.values[idx]

# --- Decision Making Module ---
class DecisionMakingModule(nn.Module):
    """Single linear layer mapping ``input_dim`` features to ``output_dim`` outputs."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, features):
        """Apply the linear layer to ``features`` and return the result."""
        decision = self.fc(features)
        return decision

# --- Unified AGI System ---
class UnifiedAGISystem(nn.Module):
    """Perception -> memory write -> decision pipeline.

    Args:
        text_dim: Forwarded to ``PerceptionModule``.
        image_dim: Forwarded to ``PerceptionModule``.
        sensor_dim: Forwarded to ``PerceptionModule``.
        hidden_dim: Feature width shared by perception, memory and decision.
        memory_size: Number of slots in the memory bank.
        output_dim: Width of the decision output.
    """

    def __init__(self, text_dim, image_dim, sensor_dim, hidden_dim, memory_size, output_dim):
        super().__init__()
        self.perception = PerceptionModule(text_dim, image_dim, sensor_dim, hidden_dim)
        self.memory = MemoryBank(memory_size, hidden_dim)
        self.decision_making = DecisionMakingModule(hidden_dim, output_dim)

    def perform_task(self, text, image, sensor):
        """Compute features, record them in memory, and return the decision.

        Args:
            text: Text input for the perception module.
            image: Image input for the perception module.
            sensor: Sensor input for the perception module.

        Returns:
            Output of the decision module for the computed features.
        """
        features = self.perception(text, image, sensor)
        # Store detached rows one at a time. The memory bank holds one vector
        # per slot, and storing grad-carrying tensors would tie the bank to
        # this call's autograd graph.
        stored = features.detach().reshape(-1, features.shape[-1])
        for row in stored:
            self.memory.write(row, row)
        return self.decision_making(features)

    def forward(self, text, image, sensor):
        """Alias for ``perform_task`` so the module can be called directly."""
        return self.perform_task(text, image, sensor)

# --- Training Function ---
def train(model, train_loader, criterion, optimizer, epochs=10, use_amp=True):
    """Run a supervised training loop and log the mean loss of every epoch.

    Args:
        model: Module called as ``model(images)``.
        train_loader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of passes over ``train_loader``.
        use_amp: Request mixed precision. It is only enabled when the model's
            parameters are on a CUDA device.
    """
    # Take the target device from the model itself, so each batch ends up
    # wherever the caller placed the model.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None
    amp_enabled = bool(use_amp) and model_device is not None and model_device.type == "cuda"

    scaler = GradScaler(enabled=amp_enabled)
    num_batches = len(train_loader)
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for images, labels in tqdm(train_loader):
            if model_device is not None:
                images = images.to(model_device)
                labels = labels.to(model_device)
            optimizer.zero_grad()
            with autocast(enabled=amp_enabled):
                outputs = model(images)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item()
        # An empty loader has no mean; report NaN instead of dividing by zero.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        logging.info(f"Epoch [{epoch + 1}/{epochs}], Loss: {mean_loss:.4f}")

# --- Deployment with FastAPI ---
app = FastAPI()

# Same checkpoint name as PerceptionModule's default text model, so the
# tokenizer vocabulary matches the encoder.
_TOKENIZER_NAME = "allenai/longformer-base-4096"
_tokenizer = None


def _get_tokenizer():
    """Load the Longformer tokenizer on first use and reuse it afterwards."""
    global _tokenizer
    if _tokenizer is None:
        _tokenizer = LongformerTokenizer.from_pretrained(_TOKENIZER_NAME)
    return _tokenizer


@app.post("/predict")
async def predict(request: Request):
    """Run one inference pass for a JSON body with ``text``, ``image``, ``sensor``.

    The module-level ``agi_system`` must be defined before the server handles
    requests; it is not created in this block.

    Returns:
        ``{"decision": <nested list>}`` holding the decision output.

    Raises:
        HTTPException: 422 when the body is not an object or lacks a field.
    """
    from fastapi import HTTPException

    data = await request.json()
    try:
        text = data["text"]
        image = data["image"]
        sensor_values = data["sensor"]
    except (KeyError, TypeError) as err:
        raise HTTPException(status_code=422, detail=f"Missing required field: {err}") from err

    # Tokenize and preprocess inputs
    text_inputs = _get_tokenizer()(text, truncation=True, return_tensors="pt").to(device)
    # Cast to float so integer pixel lists are accepted by the conv layers.
    image_tensor = torch.tensor(image).float().unsqueeze(0).to(device)
    sensor = torch.tensor(sensor_values).float().to(device)
    if sensor.dim() == 1:
        # The perception module concatenates features along dim=1, so the
        # sensor vector needs a batch dimension like the other inputs.
        sensor = sensor.unsqueeze(0)

    agi_system.eval()  # inference mode for layers such as dropout
    with torch.no_grad():
        decision = agi_system.perform_task(text_inputs, image_tensor, sensor)
    return {"decision": decision.tolist()}

# To run the server, execute: uvicorn script_name:app --reload