<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Neural_Architecture_Search_(NAS).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleNAS(nn.Module):
    """A stack of optional ``nn.Linear`` layers selected per forward pass.

    Each ``(in_features, out_features)`` pair in ``choices`` becomes one
    ``nn.Linear`` layer, in order. The constructor does NOT check that
    consecutive pairs chain together; whether a particular layer mask is
    dimensionally consistent is checked separately (``is_valid_architecture``).

    Args:
        input_dim: Feature width expected at the network input.
        output_dim: Feature width the network is expected to produce.
        choices: Sequence of ``(in_features, out_features)`` pairs, one per
            candidate layer.
    """

    def __init__(self, input_dim, output_dim, choices):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # One candidate layer per (in, out) pair, registered as layers.0, layers.1, ...
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in choices
        )

    def forward(self, x, choices):
        """Apply every layer whose mask entry equals 1, each followed by ReLU.

        Args:
            x: Input tensor fed to the first selected layer.
            choices: Indexable mask with (at least) one entry per layer; an
                entry equal to 1 keeps that layer, anything else skips it and
                the current activation passes through unchanged.

        Returns:
            The activation after the last selected layer, or ``x`` itself
            when no layer is selected. Because ReLU follows every selected
            layer (including the last one), the result is non-negative
            whenever at least one layer is used.
        """
        for idx, layer in enumerate(self.layers):
            if choices[idx] != 1:
                continue  # layer skipped for this architecture
            x = torch.relu(layer(x))
        return x

def is_valid_architecture(model, choices):
    """Check whether a layer mask yields a dimensionally consistent network.

    Walks ``model.layers`` in order, considering only layers whose entry in
    ``choices`` equals 1. Each selected layer's ``in_features`` must match
    the running feature width, which then becomes its ``out_features``.
    Skipped layers leave the width unchanged.

    Args:
        model: Object exposing ``input_dim``, ``output_dim`` and an iterable
            ``layers`` of modules with ``in_features``/``out_features``.
        choices: Indexable mask with one entry per layer; entries equal to 1
            select the layer.

    Returns:
        True if every selected layer chains correctly and the final width
        equals ``model.output_dim``; False otherwise.
    """
    width = model.input_dim
    for idx, layer in enumerate(model.layers):
        if choices[idx] != 1:
            continue  # skipped layer: width passes through unchanged
        if layer.in_features != width:
            return False
        width = layer.out_features
    return width == model.output_dim

def search_best_architecture(model, data, targets, num_choices, generator=None):
    """Randomly sample layer masks and return the one with the lowest MSE.

    For ``num_choices`` trials, a mask with one 0/1 entry per layer of
    ``model`` is drawn uniformly at random. Masks that do not form a
    dimensionally valid chain (``is_valid_architecture``) are skipped; the
    rest are scored by the mean-squared error between ``model(data, mask)``
    and ``targets`` using the model's current weights (no training is
    performed here).

    NOTE: masks are drawn with replacement, so duplicates may be scored and
    valid masks may never be drawn. With few layers, enumerating all
    ``2 ** len(model.layers)`` masks would be exhaustive and deterministic.

    Args:
        model: Module whose ``forward`` takes ``(x, choices)`` and which
            exposes ``layers`` (as ``SimpleNAS`` does).
        data: Input batch passed to ``model``.
        targets: Tensor the model output is compared against.
        num_choices: Number of random masks to sample (a trial count, not a
            number of distinct architectures).
        generator: Optional ``torch.Generator`` for reproducible sampling.
            ``None`` (default) uses PyTorch's global RNG, as before.

    Returns:
        The best-scoring mask as a 1-D integer tensor, or ``None`` if no
        sampled mask was valid with an output shaped like ``targets``.
    """
    best_architecture = None
    best_loss = float('inf')
    criterion = nn.MSELoss()  # stateless; build once instead of per trial

    # Pure evaluation: no parameters are updated, so don't build autograd
    # graphs (and don't keep a graph-attached loss tensor as best_loss).
    with torch.no_grad():
        for _ in range(num_choices):
            choices = torch.randint(
                0, 2, (len(model.layers),), generator=generator
            )
            if not is_valid_architecture(model, choices):
                continue

            output = model(data, choices)
            if output.shape != targets.shape:
                continue

            loss = criterion(output, targets).item()
            if loss < best_loss:
                best_loss = loss
                best_architecture = choices
    return best_architecture

# Example usage: three Linear layers whose widths chain 16 -> 32 -> 64 -> 128.
# NOTE: with these dimensions the only valid mask is [1, 1, 1] (every layer
# must be kept to get from input_dim to output_dim), so 10 uniform random
# draws find it with probability 1 - (7/8)**10 ~= 0.74; otherwise the search
# prints None.
widths = [16, 32, 64, 128]
input_dim = widths[0]
output_dim = widths[-1]
choices = list(zip(widths[:-1], widths[1:]))  # [(16, 32), (32, 64), (64, 128)]
model = SimpleNAS(input_dim, output_dim, choices)

batch_size = 32
data = torch.randn(batch_size, input_dim)
targets = torch.randn(batch_size, output_dim)

best_architecture = search_best_architecture(model, data, targets, num_choices=10)
print("Best architecture choices:", best_architecture)