In [47]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import random


In [3]:
#Load dataset CIFAR10
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

100%|██████████| 170M/170M [00:13<00:00, 13.1MB/s]


In [4]:
test_loader.dataset.data.shape

(10000, 32, 32, 3)

In [5]:
#Device configurarion
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.is_available()

True

In [6]:
#Implement a CNN
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()

        # Bloque convolucional 1
        self.conv1 = nn.Conv2d(3, 6, 5)       # in_channels=3, out_channels=6, kernel_size=5, padding=2
        self.bn1 = nn.BatchNorm2d(6)

        # Bloque convolucional 2
        self.pool = nn.MaxPool2d(2, 2)        # kernel_size=2, stride=2
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(16)

        # Fully Connected Layers
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn3 = nn.BatchNorm1d(120)
        self.fc2 = nn.Linear(120, 84)
        self.bn4 = nn.BatchNorm1d(84)
        self.fc3 = nn.Linear(84, 10)

        self.con_dropout = nn.Dropout(0.1)
        self.flt_dropout = nn.Dropout(0.3)
        self.relu = nn.ReLU()

    def forward(self, x):
        # CONV + BN + ReLU + Pool + Drop
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.con_dropout(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.con_dropout(x)

        # Flatten & FC
        x = x.view(-1, 16 * 5 * 5)
        x = self.fc1(x)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.flt_dropout(x)

        x = self.fc2(x)
        x = self.bn4(x)
        x = self.relu(x)
        x = self.flt_dropout(x)

        x = self.fc3(x)
        return x

In [7]:

# ---------- Controller (red 1) ----------
class Controller(nn.Module):
    """
    CNN muy pequeña que mira la imagen y produce 10 gates en [0,1].
    Está hecha para ~1.7K parámetros. Ajusta canales si quieres cuadrar al 100%.
    """
    def __init__(self, num_gates=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)   # 3->8
        self.conv2 = nn.Conv2d(8, 8, kernel_size=3, padding=1)   # 8->8
        self.pool = nn.AdaptiveAvgPool2d((1,1))
        # Pequeño MLP
        self.fc1 = nn.Linear(8, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_gates)

    def forward(self, x):
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = self.pool(h).flatten(1)        # [B, 8]
        h = F.relu(self.fc1(h))            # [B, 32]
        h = F.relu(self.fc2(h))            # [B, 16]
        gates = torch.sigmoid(self.fc3(h)) # [B, 10] en [0,1]
        return gates

# ---------- Adapter "capa" ligera (red 2 interna) ----------
class AffineAdapter(nn.Module):
    """
    Capa por-canal: y = y * (1 + s) + b, donde s,b están aprendidos (por canal).
    Si el gate g=0, no aporta; si g=1, se aplica completo.
    Parámetros ~2*H por adapter.
    """
    def __init__(self, hidden_dim):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(hidden_dim))
        self.bias  = nn.Parameter(torch.zeros(hidden_dim))

    def forward(self, y, gate):
        # gate: [B] o [B,1], broadcast a [B,H]
        while gate.dim() < y.dim():
            gate = gate.unsqueeze(-1)
        return y * (1.0 + gate * self.scale) + gate * self.bias

# ---------- Red principal tipo LeNet con 10 adapters ----------
class LeNetWithGating(nn.Module):
    """
    Convs tipo LeNet (para CIFAR10 RGB) + FC reducido para dejar presupuesto
    a la Controller y a 10 adapters.
    """
    def __init__(self, hidden_dim=120, num_adapters=10, num_classes=10):
        super().__init__()
        # Parte conv tipo LeNet (sin padding de LeNet clásico)
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)   # 3x32x32 -> 6x28x28
        self.pool  = nn.MaxPool2d(2,2)               # -> 6x14x14
        self.conv2 = nn.Conv2d(6,16, kernel_size=5)  # -> 16x10x10
        # pool -> 16x5x5 (flatten 400)

        self.fc1 = nn.Linear(16*5*5, hidden_dim)     # 400 -> H (reducido respecto a LeNet 120)
        self.fc2 = nn.Linear(hidden_dim, 84)
        self.fc3 = nn.Linear(84, num_classes)

        # 10 adapters sobre el espacio oculto H
        self.adapters = nn.ModuleList([AffineAdapter(hidden_dim) for _ in range(num_adapters)])

    def forward(self, x, gates):  # gates: [B, 10] venidos de la Controller
        b = x.size(0)
        h = self.pool(F.relu(self.conv1(x)))
        h = self.pool(F.relu(self.conv2(h)))
        h = h.view(b, -1)              # [B, 400]

        h = F.relu(self.fc1(h))        # [B, H]

        # Aplica 10 "capas" (adapters) moduladas por gates[:, i]
        for i, adp in enumerate(self.adapters):
            h = adp(h, gates[:, i])

        h = F.relu(self.fc2(h))
        logits = self.fc3(h)
        return logits

# ---------- Modelo combinado ----------
class GatedLeNetSystem(nn.Module):
    """
    Encapsula Controller (red 1) + LeNetWithGating (red 2).
    """
    def __init__(self, hidden_dim=120, num_adapters=10, num_classes=10):
        super().__init__()
        self.controller = Controller(num_gates=num_adapters)
        self.learner    = LeNetWithGating(hidden_dim, num_adapters, num_classes)

    def forward(self, x):
        gates = self.controller(x)      # [B, 10] en [0,1]
        logits = self.learner(x, gates) # backprop fluye a controller vía gates
        return logits, gates

In [57]:

# ---------- Controller con softmax (gating MoE) ----------
class Controller(nn.Module):
    """
    CNN pequeña -> vector de 10 pesos softmax (suman 1).
    temperature T controla "dureza" del softmax.
    """
    def __init__(self, num_experts=10, temperature=1.0):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(8, 18, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc1 = nn.Linear(18, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_experts)
        self.temperature = temperature

    def forward(self, x):
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = self.pool(h).flatten(1)             # [B, 8]
        h = F.relu(self.fc1(h))                 # [B, 120]
        h = F.relu(self.fc2(h))                 # [B, 84]
        logits = self.fc3(h)                    # [B, 10]
        weights = F.softmax(logits / self.temperature, dim=1)  # [B, 10], sum=1
        return weights  # gating weights para los expertos

# ---------- Expertos completos (capas "encendibles") ----------
class ExpertLayer(nn.Module):
    """
    Capa completa sobre el espacio oculto: Linear(H,H) + ReLU.
    Puedes cambiarla por un bloque más potente si quieres.
    """
    def __init__(self, hidden_dim):
        super().__init__()
        self.fc = nn.Linear(hidden_dim, hidden_dim)
        self.act = nn.ReLU(inplace=True)

    def forward(self, h):
        return self.act(self.fc(h))  # [B, H]


# ---------- Expertos completos (capas "encendibles") ----------
class ExpertLayer(nn.Module):
    """
    Capa completa sobre el espacio oculto: Linear(H,H) + ReLU.
    Puedes cambiarla por un bloque más potente si quieres.
    """
    def __init__(self, hidden_dim):
        super().__init__()
        self.fc = nn.Linear(hidden_dim, hidden_dim)
        self.act = nn.ReLU(inplace=True)

    def forward(self, h):
        return self.act(self.fc(h))  # [B, H]

# ---------- Expertos completos (capas "encendibles") ----------
class ExpertLayer(nn.Module):
    """
    Capa completa sobre el espacio oculto: Linear(H,H) + ReLU.
    Puedes cambiarla por un bloque más potente si quieres.
    """
    def __init__(self, hidden_dim):
        super().__init__()
        self.fc = nn.Linear(hidden_dim, hidden_dim)
        self.act = nn.ReLU(inplace=True)

    def forward(self, h):
        return self.act(self.fc(h))  # [B, H]

class MoEBlock(nn.Module):
    """
    Aplica N expertos en paralelo y hace mezcla ponderada por weights (softmax).
    """
    def __init__(self, hidden_dim, num_experts=10):
        super().__init__()
        self.experts = nn.ModuleList([ExpertLayer(hidden_dim) for _ in range(num_experts)])
        self.num_experts = num_experts

    def forward(self, h, weights):
        # h: [B,H], weights: [B,num_experts]
        # Calcular salidas de expertos en paralelo y mezclar
        outs = [expert(h) for expert in self.experts]            # lista de [B,H]
        stack = torch.stack(outs, dim=1)                         # [B, E, H]
        weights = weights.unsqueeze(-1)                          # [B, E, 1]
        h_moe = (weights * stack).sum(dim=1)                     # [B, H]
        return h_moe

# ---------- Learner tipo LeNet con MoE (mezcla de 10 capas) ----------
class LeNetWithMoE(nn.Module):
    def __init__(self, hidden_dim=120, num_experts=10, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 25, kernel_size=1)  # 3x32x32 -> 6x28x28
        self.pool  = nn.MaxPool2d(2,2)               # -> 6x14x14
        self.conv2 = nn.Conv2d(25,35, kernel_size=1)  # -> 16x10x10
        self.conv3 = nn.Conv2d(35,50, kernel_size=1)  # -> 16x10x10
        # pool -> 16x5x5 = 400

        self.fc1 = nn.Linear(50*4*4, hidden_dim)     # 400 -> H
        self.moe = MoEBlock(hidden_dim, num_experts)  # mezcla de expertos
        self.fc2 = nn.Linear(hidden_dim, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x, weights):  # weights: [B, num_experts] del Controller
        b = x.size(0)
        h = self.pool(F.relu(self.conv1(x)))
        h = self.pool(F.relu(self.conv2(h)))
        h = self.pool(F.relu(self.conv3(h)))
        h = h.view(b, -1)                  # [B, 400]

        h = F.relu(self.fc1(h))            # [B, H]

        # Mezcla MoE controlada por el Controller
        h = self.moe(h, weights)           # [B, H]

        h = F.relu(self.fc2(h))
        logits = self.fc3(h)
        return logits

# ---------- Modelo combinado (Controller + Learner con MoE) ----------
class GatedLeNetSystemMoE(nn.Module):
    def __init__(self, hidden_dim=50, num_experts=5, num_classes=10, temperature=50):
        super().__init__()
        self.controller = Controller(num_experts=num_experts, temperature=temperature)
        self.learner    = LeNetWithMoE(hidden_dim, num_experts, num_classes)

    def forward(self, x):
        weights = self.controller(x)        # [B, 10], softmax
        logits  = self.learner(x, weights)  # mezcla ponderada de expertos

        if random.randint(1, 1000) == 1:
            print(weights[0])
        #print(logits)
        # time.sleep(0.5)
        return logits, weights


In [66]:
#Model and Hyper parameters

# Instantiate the CNN model and move it to the configured device.
model = CNN().to(device)

# Define the loss function and optimizer.
# CrossEntropyLoss is commonly used for classification tasks.
criterion = nn.CrossEntropyLoss()
# Adam optimizer is an effective optimization algorithm.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 50  # Number of training epochs



In [None]:
def count_params(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

# Cuenta de parámetros para que veas el presupuesto
p_total   = count_params(model)
p_ctrl    = count_params(model.controller)
p_learner = count_params(model.learner)
print(f"Total params: {p_total:,} | controller: {p_ctrl:,} | learner: {p_learner:,}")


In [68]:
for epoch in range(num_epochs):
    # ---- TRAIN ----
    model.train()
    running_loss_train = 0.0
    correct_train = 0
    total_train = 0

    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs, gates = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss_train += loss.item() * labels.size(0)  # acumula total loss
        _, predicted = torch.max(outputs.data, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

        # --- Refrescar pesos pequeños ---
        if i+1 % 50 == 0:
            with torch.no_grad():
                for param in model.parameters():
                    mask = param.abs() < 5e-4  # pesos pequeños
                    # De esos pequeños, seleccionar aleatoriamente un % (aquí 50%)
                    rand_mask = torch.rand_like(param) < 0.5
                    # Máscara final: pequeños Y seleccionados aleatoriamente
                    final_mask = small_mask & rand_mask
                    # Reemplazar solo esos
                    param[final_mask] = torch.randn_like(param[final_mask]) * 0.01

    avg_train_loss = running_loss_train / total_train
    avg_train_acc = 100 * correct_train / total_train

    # ---- TEST ----
    model.eval()
    running_loss_test = 0.0
    correct_test = 0
    total_test = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs, gates = model(images)
            loss = criterion(outputs, labels)

            running_loss_test += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()

    avg_test_loss = running_loss_test / total_test
    avg_test_acc = 100 * correct_test / total_test

    # ---- PRINT ----
    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.2f}% | "
          f"Test Loss: {avg_test_loss:.4f}, Test Acc: {avg_test_acc:.2f}%")

print('Finished Training')



ValueError: too many values to unpack (expected 2)