# Dummy Example

# In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from core.blueprint import ArchitectureBlueprint
from core.registry import ComponentRegistry
from evolution.mutation import mutate_blueprint
from evolution.controller import MutationRLController
from evolution.transfer import transfer_weights
from models.multimodal_learner import MultimodalLearner
from core.components import TrackedLayer
from data.loader import CombinedDataset
from data.adapters import ImageAdapter, TextAdapter

# Preprocessing applied to every MNIST image: convert it to a tensor, then
# normalize with mean 0.5 and std 0.5 (one value each, i.e. a single channel).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# MNIST training split rooted at ./data (download=True fetches it if needed).
mnist_data = datasets.MNIST(
    "./data",
    train=True,
    transform=transform,
    download=True,
)

# Wrap MNIST in the project's CombinedDataset, then serve it in shuffled
# mini-batches of 32 samples.
dataset = CombinedDataset(mnist_data)
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)

# Evolution machinery used by the training loop: the architecture
# description, the component registry and the RL controller.
# (Constructed left to right, in the same order as separate assignments.)
blueprint, registry, controller = (
    ArchitectureBlueprint(),
    ComponentRegistry(),
    MutationRLController(),
)

# Seed the architecture with a single tracked layer and record it in both the
# blueprint and the registry.
# NOTE(review): (128, 64) are presumably the layer's input/output widths --
# confirm against TrackedLayer.
init_layer = TrackedLayer(128, 64)
blueprint.add_module(init_layer)
registry.register(init_layer)

# One adapter per modality key; the initial model is built from the seeded
# blueprint together with these adapters.
adapters = dict(image=ImageAdapter(), text=TextAdapter())
model = MultimodalLearner(blueprint, adapters)

# Mean loss of each completed epoch, appended by the training loop.
loss_history = list()

# Train for 5 epochs. After each epoch: record the mean loss, mutate the
# blueprint, rebuild the model from the mutated blueprint (carrying weights
# over via transfer_weights), update the RL controller and save a
# visualization of the blueprint.
for epoch in range(5):
    # Built per epoch: `model` is rebound to a newly constructed
    # MultimodalLearner at the end of every iteration, so the optimizer has
    # to be created against the current model's parameters. Adam's internal
    # state therefore starts fresh each epoch.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    total_loss = 0
    for batch in dataloader:
        x, target = batch
        output = model(x)
        # NOTE(review): MSE is computed directly against `target` from the
        # batch. If CombinedDataset passes MNIST's integer class labels
        # through unchanged, shape/dtype may not match `output` -- confirm
        # what CombinedDataset yields.
        loss = torch.nn.functional.mse_loss(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Mean of the per-batch losses (len(dataloader) is the number of batches).
    avg_loss = total_loss / len(dataloader)
    loss_history.append(avg_loss)

    # State is sampled before the mutation, next_state after it; loss_history
    # is the same for both calls, only the blueprint may differ.
    state = controller.get_state(loss_history, blueprint)
    action = controller.select_action(state)
    # NOTE(review): `action` is not passed to mutate_blueprint; in this code it
    # is only used by controller.update and the print below -- confirm whether
    # the mutation is meant to be driven by the selected action.
    # The return value is discarded, so the blueprint is presumably mutated
    # in place -- verify against mutate_blueprint.
    mutate_blueprint(blueprint, registry)
    # Rebuild the model from the (mutated) blueprint, transfer weights from
    # the old model, then switch over to the new one.
    new_model = MultimodalLearner(blueprint, adapters)
    transfer_weights(model, new_model)
    model = new_model
    next_state = controller.get_state(loss_history, blueprint)
    # The negated mean loss is passed as the third argument (presumably the
    # reward, so a lower loss means a higher reward -- confirm against
    # MutationRLController.update).
    controller.update(state, action, -avg_loss, next_state)
    print(f"Epoch {epoch} Loss: {avg_loss:.4f} | Action: {action}")

    # Called after the mutation step, once per epoch, with an epoch-specific
    # filename (presumably writes an image file -- confirm).
    blueprint.visualize(filename=f"architecture_epoch_{epoch}.png")

# Cell output: KeyboardInterrupt