In [40]:
# Load data
import numpy as np

from data_utils import (
    get_embeddings, 
    synthesize_database, 
    synthesize_simple_database,
    get_passenger_database,
    check_passenger_exist,
    create_simple_trainset
)

# Toggle between the simple synthetic database (no dates) and the full one.
is_simple_data = True

# Airport dataset: the original embeddings are considered the query pictures.
embed_original, indices = get_embeddings()

if not is_simple_data:
    embed_data, id_data, date_data, location_data = synthesize_database(embed_original)
else:
    # The simple database returns no date column, so keep a None placeholder.
    embed_data, id_data, location_data = synthesize_simple_database(embed_original)
    date_data = None

# Stack the per-entry embeddings into one array (entries along the first axis).
embed_data = np.stack(embed_data)

In [41]:
import torch
import torch.nn as nn
import torch.optim as optim

# Training features: the first 1000 rows of the embedding matrix, as a float tensor.
x_train = torch.FloatTensor(np.array(embed_data[:1000, :]))
print(x_train.shape)

# Matching labels for the same 1000 entries (np.array makes a private copy).
y_train = np.array(location_data[:1000])
# Class indices must start from 0: label 2 is folded into 0 (classes 0,1 in this example).
np.putmask(y_train, y_train == 2, 0)
y_train = torch.LongTensor(y_train)
print(y_train.shape)


torch.Size([1000, 128])
torch.Size([1000])


In [42]:
def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> torch.Tensor:
    r"""
    Draw a sample from the Gumbel-Softmax distribution, optionally discretized.

    Args:
      logits: `[..., num_features]` unnormalized log probabilities
      tau: scalar temperature; the noise-perturbed logits are divided by it
      hard: if ``True``, the returned samples will be discretized as one-hot vectors,
            but will be differentiated as if it is the soft sample in autograd
      eps: accepted for signature compatibility; never read by this implementation
      dim (int): A dimension along which softmax will be computed. Default: -1.

    Returns:
      Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
      If ``hard=True``, the returned samples will be one-hot, otherwise they will
      be probability distributions that sum to 1 across `dim`.
    """
    # -log(E) with E ~ Exponential(1) is a Gumbel(0, 1) sample.
    exp_noise = torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_()
    gumbel_noise = -exp_noise.log()

    # Perturb the logits, apply the temperature, then normalize along `dim`.
    perturbed = (logits + gumbel_noise) / tau  # ~Gumbel(logits,tau)
    y_soft = perturbed.softmax(dim)

    if not hard:
        # Reparametrization trick: the soft sample is returned as is.
        return y_soft

    # Straight through: the forward value is the one-hot arg-max, while the
    # gradient flows through y_soft only (y_hard and the detached copy carry none).
    winners = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, winners, 1.0)
    return y_hard - y_soft.detach() + y_soft

In [43]:
def get_channel_mask(hidden_size_choices):
    """Build one binary channel mask per candidate hidden size.

    Args:
        hidden_size_choices: sequence of candidate hidden sizes
            (assumed to be non-negative integers).

    Returns:
        Float tensor of shape ``(max(hidden_size_choices), len(hidden_size_choices))``.
        Column ``i`` holds ones in its first ``hidden_size_choices[i]`` rows and
        zeros in the remaining rows.
    """
    widest = max(hidden_size_choices)
    # Row index of every channel, shaped (widest, 1) so it broadcasts against the widths.
    channel_index = torch.arange(widest).unsqueeze(1)
    # Candidate widths, shaped (1, num_choices).
    widths = torch.tensor(list(hidden_size_choices)).unsqueeze(0)
    # A channel is active for a choice exactly when its index is below that choice's width.
    return (channel_index < widths).to(torch.get_default_dtype())
    
def get_flops_choices(input_size, hidden_size_choices, num_classes):
    """Return the FLOPs estimate of every candidate hidden size.

    Each candidate ``h`` is scored as ``2*h*input_size + 2*h*num_classes``
    (presumably the cost of an input->hidden and a hidden->classes linear layer).

    Args:
        input_size: size of the network input.
        hidden_size_choices: iterable of candidate hidden sizes.
        num_classes: size of the network output.

    Returns:
        NumPy array with one FLOPs value per candidate, in the given order.
    """
    return np.array([
        2 * width * input_size + 2 * width * num_classes
        for width in hidden_size_choices
    ])
    
class SuperNet(nn.Module):
    """Two-layer MLP super-network whose hidden width is searched with Gumbel-softmax.

    The network is built with the largest candidate hidden size. In every forward
    pass a Gumbel-softmax sample over ``arch_params`` mixes the per-candidate
    channel masks into one soft mask that scales the hidden activations.

    Attributes:
        arch_params: learnable logits, one per candidate hidden size (initialized to 1).
        masks: buffer of shape (max_hidden_size, num_choices) from ``get_channel_mask``.
        flops_choices: NumPy array of per-candidate FLOPs from ``get_flops_choices``.
        flops_choices_normalized: buffer with ``flops_choices`` divided by its maximum.
    """

    def __init__(self, input_size, hidden_size_choices, num_classes):
        """
        Args:
            input_size: number of input features.
            hidden_size_choices: candidate hidden sizes to search over.
            num_classes: number of output logits.
        """
        super().__init__()

        max_hidden_size = max(hidden_size_choices)
        num_choices = len(hidden_size_choices)

        self.arch_params = torch.nn.Parameter(torch.ones(num_choices), requires_grad=True)

        self.flops_choices = get_flops_choices(input_size, hidden_size_choices, num_classes)
        flops_normalized = torch.as_tensor(
            self.flops_choices / np.max(self.flops_choices), dtype=torch.float32
        )

        # Registered as buffers so that net.to(device) / net.double() move them together
        # with the parameters. persistent=False keeps them out of state_dict, so the
        # checkpoint keys are the same as when they were plain attributes.
        self.register_buffer("masks", get_channel_mask(hidden_size_choices), persistent=False)
        self.register_buffer("flops_choices_normalized", flops_normalized, persistent=False)

        self.fc1 = nn.Linear(input_size, max_hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(max_hidden_size, num_classes)

    def forward(self, x, temperature):
        """Run the masked network with a fresh architecture sample.

        Args:
            x: input batch (assumed shape (batch, input_size) -- TODO confirm).
            temperature: Gumbel-softmax temperature passed as ``tau``.

        Returns:
            Tuple ``(out, flops_loss)``: the output of ``fc2`` and the FLOPs
            penalty computed from the same architecture sample.
        """
        out = self.fc1(x)
        out = self.relu(out)

        # Soft (hard=False) sample over the candidate hidden sizes.
        gumbel_weights = gumbel_softmax(self.arch_params, tau=temperature, hard=False)

        # Weighted sum of the candidate masks -> one scaling factor per hidden channel.
        mask = torch.multiply(self.masks, gumbel_weights)
        mask = torch.sum(mask, dim=-1)
        out = torch.multiply(out, mask)

        out = self.fc2(out)
        flops_loss = self._get_flops_loss(gumbel_weights)
        return out, flops_loss

    def _get_flops_loss(self, gumbel_weights):
        """Return the normalized FLOPs weighted by the sampled architecture weights."""
        return torch.matmul(self.flops_choices_normalized, gumbel_weights)

## flops_balance_factor = 0.2
## Searched Hidden Size = 350
`flops_balance_factor` is the trade-off parameter between performance (CE_Loss or MSE_Loss) and speed (FLOPs).

Larger flops_balance_factor leads to a faster network with worse performance

In [53]:
import torch
import torch.nn as nn
import torch.optim as optim

# Initialize the neural network
in_dim = x_train.shape[1]
hidden_size_choices = list(range(100, 1000, 10))
num_classes = 2

net = SuperNet(in_dim, hidden_size_choices, num_classes)

# Set up the loss function and the optimizers
criterion = nn.CrossEntropyLoss()

net_weight_lr = 0.0001
arch_lr = 0.01

# Parameters whose name contains 'arch' are the architecture parameters; they get
# their own optimizer so that they can be updated in alternation with the weights.
optimizer_net = optim.Adam([p for name, p in net.named_parameters() if 'arch' not in name], lr=net_weight_lr)
optimizer_arch = optim.Adam([p for name, p in net.named_parameters() if 'arch' in name], lr=arch_lr)

# Number of search epochs (every epoch is one full-batch step on x_train)
num_epochs = 10000

# Alternate between optimizing the network weights and the architecture parameters,
# switching every 'search_freq' epochs (the weights go first)
search_freq = 20

# The factor to balance performance (CE Loss or MSE Loss) and FLOPs
# Larger flops_balance_factor leads to a faster network with worse performance
flops_balance_factor = 0.2

# The temperature is multiplied by 'temp_anneal_factor' every 'temp_anneal_freq' epochs.
# A smaller temperature makes the Gumbel-softmax weights closer to a one-hot
# distribution; a larger one makes them closer to uniform.
temp = 5
temp_anneal_factor = 0.95
# Integer period (about 25 decays during the search). A float period would make
# 'epoch % temp_anneal_freq == 0' miss whenever num_epochs is not divisible by 25.
temp_anneal_freq = max(1, num_epochs // 25)

for epoch in range(num_epochs):
    if epoch % temp_anneal_freq == 0:
        temp = temp * temp_anneal_factor

    # Forward pass
    outputs, flops_loss = net(x_train, temp)
    CE_loss = criterion(outputs, y_train)
    loss = (1 - flops_balance_factor) * CE_loss + flops_balance_factor * flops_loss

    # backward() writes gradients into both parameter groups, so clear all of them;
    # only the group that is active in this phase is stepped.
    net.zero_grad()
    loss.backward()
    if (epoch // search_freq) % 2 == 0:
        optimizer_net.step()
    else:
        optimizer_arch.step()

    # Currently preferred hidden size: the choice with the largest architecture parameter
    selected_channel_id = np.argmax(net.arch_params.detach().cpu().numpy())
    selected_channel = hidden_size_choices[selected_channel_id]

    # Training accuracy is only needed on the epochs that are logged
    if (epoch+1) % 10 == 0:
        with torch.no_grad():
            _, predicted = torch.max(outputs, 1)
            total = y_train.size(0)
            correct = (predicted == y_train).sum().item()
            accuracy = correct / total

        print('Epoch [{}/{}], Loss: {:.4f}, CE_Loss: {:.4f}, Flops_Loss: {:.4f}, Accuracy: {:.2f}%, Seleted Channel {}'
              .format(epoch+1, num_epochs, loss.item(), CE_loss.item(),
                      flops_loss.item(), accuracy * 100, selected_channel))

Epoch [10/10000], Loss: 0.6645, CE_Loss: 0.6924, Flops_Loss: 0.5530, Accuracy: 55.10%, Seleted Channel 100
Epoch [20/10000], Loss: 0.6623, CE_Loss: 0.6917, Flops_Loss: 0.5445, Accuracy: 55.30%, Seleted Channel 100
Epoch [30/10000], Loss: 0.6651, CE_Loss: 0.6916, Flops_Loss: 0.5591, Accuracy: 55.60%, Seleted Channel 210
Epoch [40/10000], Loss: 0.6610, CE_Loss: 0.6917, Flops_Loss: 0.5381, Accuracy: 56.00%, Seleted Channel 330
Epoch [50/10000], Loss: 0.6577, CE_Loss: 0.6911, Flops_Loss: 0.5245, Accuracy: 60.50%, Seleted Channel 330
Epoch [60/10000], Loss: 0.6626, CE_Loss: 0.6901, Flops_Loss: 0.5523, Accuracy: 60.70%, Seleted Channel 330
Epoch [70/10000], Loss: 0.6610, CE_Loss: 0.6901, Flops_Loss: 0.5446, Accuracy: 61.30%, Seleted Channel 350
Epoch [80/10000], Loss: 0.6593, CE_Loss: 0.6901, Flops_Loss: 0.5358, Accuracy: 61.40%, Seleted Channel 350
Epoch [90/10000], Loss: 0.6601, CE_Loss: 0.6893, Flops_Loss: 0.5433, Accuracy: 60.60%, Seleted Channel 350
Epoch [100/10000], Loss: 0.6578, CE_L

Epoch [780/10000], Loss: 0.5879, CE_Loss: 0.6356, Flops_Loss: 0.3973, Accuracy: 65.60%, Seleted Channel 160
Epoch [790/10000], Loss: 0.5877, CE_Loss: 0.6354, Flops_Loss: 0.3971, Accuracy: 65.70%, Seleted Channel 160
Epoch [800/10000], Loss: 0.5882, CE_Loss: 0.6352, Flops_Loss: 0.4004, Accuracy: 65.60%, Seleted Channel 160
Epoch [810/10000], Loss: 0.5829, CE_Loss: 0.6363, Flops_Loss: 0.3695, Accuracy: 65.90%, Seleted Channel 160
Epoch [820/10000], Loss: 0.5843, CE_Loss: 0.6330, Flops_Loss: 0.3895, Accuracy: 65.90%, Seleted Channel 160
Epoch [830/10000], Loss: 0.5833, CE_Loss: 0.6332, Flops_Loss: 0.3837, Accuracy: 65.80%, Seleted Channel 160
Epoch [840/10000], Loss: 0.5811, CE_Loss: 0.6349, Flops_Loss: 0.3662, Accuracy: 65.90%, Seleted Channel 160
Epoch [850/10000], Loss: 0.5803, CE_Loss: 0.6332, Flops_Loss: 0.3688, Accuracy: 66.10%, Seleted Channel 160
Epoch [860/10000], Loss: 0.5802, CE_Loss: 0.6308, Flops_Loss: 0.3775, Accuracy: 66.10%, Seleted Channel 160
Epoch [870/10000], Loss: 0.5

Epoch [1540/10000], Loss: 0.5293, CE_Loss: 0.5880, Flops_Loss: 0.2944, Accuracy: 70.90%, Seleted Channel 160
Epoch [1550/10000], Loss: 0.5283, CE_Loss: 0.5902, Flops_Loss: 0.2804, Accuracy: 70.60%, Seleted Channel 160
Epoch [1560/10000], Loss: 0.5282, CE_Loss: 0.5930, Flops_Loss: 0.2691, Accuracy: 70.80%, Seleted Channel 160
Epoch [1570/10000], Loss: 0.5286, CE_Loss: 0.5843, Flops_Loss: 0.3058, Accuracy: 70.70%, Seleted Channel 160
Epoch [1580/10000], Loss: 0.5262, CE_Loss: 0.5878, Flops_Loss: 0.2797, Accuracy: 71.30%, Seleted Channel 160
Epoch [1590/10000], Loss: 0.5264, CE_Loss: 0.5851, Flops_Loss: 0.2915, Accuracy: 71.00%, Seleted Channel 160
Epoch [1600/10000], Loss: 0.5268, CE_Loss: 0.5850, Flops_Loss: 0.2938, Accuracy: 70.70%, Seleted Channel 160
Epoch [1610/10000], Loss: 0.5245, CE_Loss: 0.5843, Flops_Loss: 0.2851, Accuracy: 71.20%, Seleted Channel 160
Epoch [1620/10000], Loss: 0.5239, CE_Loss: 0.5833, Flops_Loss: 0.2863, Accuracy: 71.50%, Seleted Channel 160
Epoch [1630/10000],

Epoch [2370/10000], Loss: 0.4709, CE_Loss: 0.5226, Flops_Loss: 0.2644, Accuracy: 79.40%, Seleted Channel 250
Epoch [2380/10000], Loss: 0.4690, CE_Loss: 0.5202, Flops_Loss: 0.2645, Accuracy: 79.50%, Seleted Channel 250
Epoch [2390/10000], Loss: 0.4692, CE_Loss: 0.5200, Flops_Loss: 0.2659, Accuracy: 79.40%, Seleted Channel 250
Epoch [2400/10000], Loss: 0.4683, CE_Loss: 0.5168, Flops_Loss: 0.2741, Accuracy: 79.50%, Seleted Channel 250
Epoch [2410/10000], Loss: 0.4676, CE_Loss: 0.5189, Flops_Loss: 0.2622, Accuracy: 79.60%, Seleted Channel 250
Epoch [2420/10000], Loss: 0.4650, CE_Loss: 0.5124, Flops_Loss: 0.2754, Accuracy: 79.50%, Seleted Channel 250
Epoch [2430/10000], Loss: 0.4653, CE_Loss: 0.5148, Flops_Loss: 0.2674, Accuracy: 79.70%, Seleted Channel 250
Epoch [2440/10000], Loss: 0.4647, CE_Loss: 0.5124, Flops_Loss: 0.2739, Accuracy: 79.60%, Seleted Channel 250
Epoch [2450/10000], Loss: 0.4643, CE_Loss: 0.5132, Flops_Loss: 0.2689, Accuracy: 80.00%, Seleted Channel 250
Epoch [2460/10000],

Epoch [3180/10000], Loss: 0.3962, CE_Loss: 0.4205, Flops_Loss: 0.2988, Accuracy: 88.30%, Seleted Channel 310
Epoch [3190/10000], Loss: 0.3951, CE_Loss: 0.4183, Flops_Loss: 0.3021, Accuracy: 88.60%, Seleted Channel 310
Epoch [3200/10000], Loss: 0.3954, CE_Loss: 0.4191, Flops_Loss: 0.3008, Accuracy: 88.50%, Seleted Channel 310
Epoch [3210/10000], Loss: 0.3935, CE_Loss: 0.4168, Flops_Loss: 0.3000, Accuracy: 88.50%, Seleted Channel 310
Epoch [3220/10000], Loss: 0.3916, CE_Loss: 0.4139, Flops_Loss: 0.3024, Accuracy: 88.70%, Seleted Channel 310
Epoch [3230/10000], Loss: 0.3901, CE_Loss: 0.4118, Flops_Loss: 0.3034, Accuracy: 88.60%, Seleted Channel 310
Epoch [3240/10000], Loss: 0.3910, CE_Loss: 0.4125, Flops_Loss: 0.3050, Accuracy: 88.60%, Seleted Channel 310
Epoch [3250/10000], Loss: 0.3886, CE_Loss: 0.4090, Flops_Loss: 0.3072, Accuracy: 89.00%, Seleted Channel 310
Epoch [3260/10000], Loss: 0.3880, CE_Loss: 0.4103, Flops_Loss: 0.2988, Accuracy: 89.20%, Seleted Channel 310
Epoch [3270/10000],

Epoch [3940/10000], Loss: 0.3242, CE_Loss: 0.3228, Flops_Loss: 0.3297, Accuracy: 95.00%, Seleted Channel 350
Epoch [3950/10000], Loss: 0.3247, CE_Loss: 0.3247, Flops_Loss: 0.3245, Accuracy: 94.80%, Seleted Channel 350
Epoch [3960/10000], Loss: 0.3241, CE_Loss: 0.3225, Flops_Loss: 0.3305, Accuracy: 94.90%, Seleted Channel 350
Epoch [3970/10000], Loss: 0.3240, CE_Loss: 0.3244, Flops_Loss: 0.3224, Accuracy: 94.80%, Seleted Channel 350
Epoch [3980/10000], Loss: 0.3211, CE_Loss: 0.3186, Flops_Loss: 0.3315, Accuracy: 95.00%, Seleted Channel 350
Epoch [3990/10000], Loss: 0.3207, CE_Loss: 0.3178, Flops_Loss: 0.3325, Accuracy: 95.20%, Seleted Channel 350
Epoch [4000/10000], Loss: 0.3212, CE_Loss: 0.3203, Flops_Loss: 0.3245, Accuracy: 94.80%, Seleted Channel 350
Epoch [4010/10000], Loss: 0.3180, CE_Loss: 0.3134, Flops_Loss: 0.3365, Accuracy: 95.30%, Seleted Channel 350
Epoch [4020/10000], Loss: 0.3170, CE_Loss: 0.3123, Flops_Loss: 0.3358, Accuracy: 95.50%, Seleted Channel 350
Epoch [4030/10000],

Epoch [4710/10000], Loss: 0.2632, CE_Loss: 0.2413, Flops_Loss: 0.3508, Accuracy: 98.00%, Seleted Channel 350
Epoch [4720/10000], Loss: 0.2630, CE_Loss: 0.2410, Flops_Loss: 0.3512, Accuracy: 98.10%, Seleted Channel 350
Epoch [4730/10000], Loss: 0.2621, CE_Loss: 0.2399, Flops_Loss: 0.3509, Accuracy: 98.10%, Seleted Channel 350
Epoch [4740/10000], Loss: 0.2606, CE_Loss: 0.2379, Flops_Loss: 0.3512, Accuracy: 98.20%, Seleted Channel 350
Epoch [4750/10000], Loss: 0.2600, CE_Loss: 0.2369, Flops_Loss: 0.3526, Accuracy: 98.20%, Seleted Channel 350
Epoch [4760/10000], Loss: 0.2601, CE_Loss: 0.2371, Flops_Loss: 0.3523, Accuracy: 98.20%, Seleted Channel 350
Epoch [4770/10000], Loss: 0.2597, CE_Loss: 0.2372, Flops_Loss: 0.3496, Accuracy: 98.30%, Seleted Channel 350
Epoch [4780/10000], Loss: 0.2577, CE_Loss: 0.2340, Flops_Loss: 0.3522, Accuracy: 98.30%, Seleted Channel 350
Epoch [4790/10000], Loss: 0.2581, CE_Loss: 0.2350, Flops_Loss: 0.3503, Accuracy: 98.30%, Seleted Channel 350
Epoch [4800/10000],

Epoch [5510/10000], Loss: 0.2160, CE_Loss: 0.1817, Flops_Loss: 0.3531, Accuracy: 99.50%, Seleted Channel 350
Epoch [5520/10000], Loss: 0.2160, CE_Loss: 0.1817, Flops_Loss: 0.3531, Accuracy: 99.40%, Seleted Channel 350
Epoch [5530/10000], Loss: 0.2153, CE_Loss: 0.1810, Flops_Loss: 0.3524, Accuracy: 99.50%, Seleted Channel 350
Epoch [5540/10000], Loss: 0.2140, CE_Loss: 0.1792, Flops_Loss: 0.3533, Accuracy: 99.50%, Seleted Channel 350
Epoch [5550/10000], Loss: 0.2142, CE_Loss: 0.1797, Flops_Loss: 0.3524, Accuracy: 99.50%, Seleted Channel 350
Epoch [5560/10000], Loss: 0.2139, CE_Loss: 0.1790, Flops_Loss: 0.3534, Accuracy: 99.50%, Seleted Channel 350
Epoch [5570/10000], Loss: 0.2130, CE_Loss: 0.1780, Flops_Loss: 0.3533, Accuracy: 99.50%, Seleted Channel 350
Epoch [5580/10000], Loss: 0.2121, CE_Loss: 0.1768, Flops_Loss: 0.3533, Accuracy: 99.50%, Seleted Channel 350
Epoch [5590/10000], Loss: 0.2120, CE_Loss: 0.1767, Flops_Loss: 0.3531, Accuracy: 99.50%, Seleted Channel 350
Epoch [5600/10000],

Epoch [6310/10000], Loss: 0.1807, CE_Loss: 0.1376, Flops_Loss: 0.3532, Accuracy: 99.80%, Seleted Channel 350
Epoch [6320/10000], Loss: 0.1807, CE_Loss: 0.1376, Flops_Loss: 0.3533, Accuracy: 99.80%, Seleted Channel 350
Epoch [6330/10000], Loss: 0.1800, CE_Loss: 0.1367, Flops_Loss: 0.3534, Accuracy: 99.80%, Seleted Channel 350
Epoch [6340/10000], Loss: 0.1793, CE_Loss: 0.1358, Flops_Loss: 0.3532, Accuracy: 99.80%, Seleted Channel 350
Epoch [6350/10000], Loss: 0.1792, CE_Loss: 0.1356, Flops_Loss: 0.3533, Accuracy: 99.80%, Seleted Channel 350
Epoch [6360/10000], Loss: 0.1792, CE_Loss: 0.1356, Flops_Loss: 0.3534, Accuracy: 99.80%, Seleted Channel 350
Epoch [6370/10000], Loss: 0.1785, CE_Loss: 0.1348, Flops_Loss: 0.3534, Accuracy: 99.80%, Seleted Channel 350
Epoch [6380/10000], Loss: 0.1778, CE_Loss: 0.1340, Flops_Loss: 0.3532, Accuracy: 99.80%, Seleted Channel 350
Epoch [6390/10000], Loss: 0.1777, CE_Loss: 0.1337, Flops_Loss: 0.3534, Accuracy: 99.80%, Seleted Channel 350
Epoch [6400/10000],

Epoch [7110/10000], Loss: 0.1539, CE_Loss: 0.1040, Flops_Loss: 0.3534, Accuracy: 100.00%, Seleted Channel 350
Epoch [7120/10000], Loss: 0.1539, CE_Loss: 0.1040, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [7130/10000], Loss: 0.1534, CE_Loss: 0.1033, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [7140/10000], Loss: 0.1528, CE_Loss: 0.1026, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [7150/10000], Loss: 0.1527, CE_Loss: 0.1026, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [7160/10000], Loss: 0.1528, CE_Loss: 0.1026, Flops_Loss: 0.3534, Accuracy: 100.00%, Seleted Channel 350
Epoch [7170/10000], Loss: 0.1523, CE_Loss: 0.1020, Flops_Loss: 0.3534, Accuracy: 100.00%, Seleted Channel 350
Epoch [7180/10000], Loss: 0.1517, CE_Loss: 0.1012, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [7190/10000], Loss: 0.1516, CE_Loss: 0.1012, Flops_Loss: 0.3534, Accuracy: 100.00%, Seleted Channel 350
Epoch [720

Epoch [7910/10000], Loss: 0.1338, CE_Loss: 0.0789, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [7920/10000], Loss: 0.1338, CE_Loss: 0.0789, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [7930/10000], Loss: 0.1334, CE_Loss: 0.0784, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [7940/10000], Loss: 0.1330, CE_Loss: 0.0779, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [7950/10000], Loss: 0.1329, CE_Loss: 0.0778, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [7960/10000], Loss: 0.1329, CE_Loss: 0.0778, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [7970/10000], Loss: 0.1326, CE_Loss: 0.0773, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [7980/10000], Loss: 0.1321, CE_Loss: 0.0768, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [7990/10000], Loss: 0.1321, CE_Loss: 0.0768, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [800

Epoch [8690/10000], Loss: 0.1192, CE_Loss: 0.0606, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [8700/10000], Loss: 0.1189, CE_Loss: 0.0602, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [8710/10000], Loss: 0.1188, CE_Loss: 0.0602, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [8720/10000], Loss: 0.1188, CE_Loss: 0.0602, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [8730/10000], Loss: 0.1185, CE_Loss: 0.0598, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [8740/10000], Loss: 0.1182, CE_Loss: 0.0594, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [8750/10000], Loss: 0.1182, CE_Loss: 0.0593, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [8760/10000], Loss: 0.1182, CE_Loss: 0.0593, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [8770/10000], Loss: 0.1179, CE_Loss: 0.0590, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [878

Epoch [9490/10000], Loss: 0.1078, CE_Loss: 0.0464, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [9500/10000], Loss: 0.1076, CE_Loss: 0.0461, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [9510/10000], Loss: 0.1076, CE_Loss: 0.0461, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [9520/10000], Loss: 0.1076, CE_Loss: 0.0461, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [9530/10000], Loss: 0.1074, CE_Loss: 0.0458, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [9540/10000], Loss: 0.1071, CE_Loss: 0.0455, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [9550/10000], Loss: 0.1071, CE_Loss: 0.0455, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [9560/10000], Loss: 0.1071, CE_Loss: 0.0455, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [9570/10000], Loss: 0.1069, CE_Loss: 0.0452, Flops_Loss: 0.3535, Accuracy: 100.00%, Seleted Channel 350
Epoch [958

## flops_balance_factor = 0.3
## Searched Hidden Size = 220
`flops_balance_factor` is the trade-off parameter between performance (CE_Loss or MSE_Loss) and speed (FLOPs).

Larger flops_balance_factor leads to a faster network with worse performance

In [51]:
import torch
import torch.nn as nn
import torch.optim as optim

# Initialize the neural network
in_dim = x_train.shape[1]
hidden_size_choices = list(range(100, 1000, 10))
num_classes = 2

net = SuperNet(in_dim, hidden_size_choices, num_classes)

# Set up the loss function and the optimizers
criterion = nn.CrossEntropyLoss()

net_weight_lr = 0.0001
arch_lr = 0.01

# Parameters whose name contains 'arch' are the architecture parameters; they get
# their own optimizer so that they can be updated in alternation with the weights.
optimizer_net = optim.Adam([p for name, p in net.named_parameters() if 'arch' not in name], lr=net_weight_lr)
optimizer_arch = optim.Adam([p for name, p in net.named_parameters() if 'arch' in name], lr=arch_lr)

# Number of search epochs (every epoch is one full-batch step on x_train)
num_epochs = 10000

# Alternate between optimizing the network weights and the architecture parameters,
# switching every 'search_freq' epochs (the weights go first)
search_freq = 20

# The factor to balance performance (CE Loss or MSE Loss) and FLOPs
# Larger flops_balance_factor leads to a faster network with worse performance
flops_balance_factor = 0.3

# The temperature is multiplied by 'temp_anneal_factor' every 'temp_anneal_freq' epochs.
# A smaller temperature makes the Gumbel-softmax weights closer to a one-hot
# distribution; a larger one makes them closer to uniform.
temp = 5
temp_anneal_factor = 0.95
# Integer period (about 25 decays during the search). A float period would make
# 'epoch % temp_anneal_freq == 0' miss whenever num_epochs is not divisible by 25.
temp_anneal_freq = max(1, num_epochs // 25)

for epoch in range(num_epochs):
    if epoch % temp_anneal_freq == 0:
        temp = temp * temp_anneal_factor

    # Forward pass
    outputs, flops_loss = net(x_train, temp)
    CE_loss = criterion(outputs, y_train)
    loss = (1 - flops_balance_factor) * CE_loss + flops_balance_factor * flops_loss

    # backward() writes gradients into both parameter groups, so clear all of them;
    # only the group that is active in this phase is stepped.
    net.zero_grad()
    loss.backward()
    if (epoch // search_freq) % 2 == 0:
        optimizer_net.step()
    else:
        optimizer_arch.step()

    # Currently preferred hidden size: the choice with the largest architecture parameter
    selected_channel_id = np.argmax(net.arch_params.detach().cpu().numpy())
    selected_channel = hidden_size_choices[selected_channel_id]

    # Training accuracy is only needed on the epochs that are logged
    if (epoch+1) % 10 == 0:
        with torch.no_grad():
            _, predicted = torch.max(outputs, 1)
            total = y_train.size(0)
            correct = (predicted == y_train).sum().item()
            accuracy = correct / total

        print('Epoch [{}/{}], Loss: {:.4f}, CE_Loss: {:.4f}, Flops_Loss: {:.4f}, Accuracy: {:.2f}%, Seleted Channel {}'
              .format(epoch+1, num_epochs, loss.item(), CE_loss.item(),
                      flops_loss.item(), accuracy * 100, selected_channel))

Epoch [10/10000], Loss: 0.6505, CE_Loss: 0.6922, Flops_Loss: 0.5533, Accuracy: 54.40%, Seleted Channel 100
Epoch [20/10000], Loss: 0.6535, CE_Loss: 0.6913, Flops_Loss: 0.5654, Accuracy: 58.40%, Seleted Channel 100
Epoch [30/10000], Loss: 0.6496, CE_Loss: 0.6912, Flops_Loss: 0.5524, Accuracy: 58.60%, Seleted Channel 440
Epoch [40/10000], Loss: 0.6475, CE_Loss: 0.6913, Flops_Loss: 0.5453, Accuracy: 58.50%, Seleted Channel 440
Epoch [50/10000], Loss: 0.6480, CE_Loss: 0.6904, Flops_Loss: 0.5491, Accuracy: 58.20%, Seleted Channel 440
Epoch [60/10000], Loss: 0.6479, CE_Loss: 0.6894, Flops_Loss: 0.5511, Accuracy: 59.80%, Seleted Channel 440
Epoch [70/10000], Loss: 0.6406, CE_Loss: 0.6895, Flops_Loss: 0.5264, Accuracy: 59.70%, Seleted Channel 440
Epoch [80/10000], Loss: 0.6396, CE_Loss: 0.6895, Flops_Loss: 0.5230, Accuracy: 59.80%, Seleted Channel 130
Epoch [90/10000], Loss: 0.6416, CE_Loss: 0.6885, Flops_Loss: 0.5320, Accuracy: 59.70%, Seleted Channel 130
Epoch [100/10000], Loss: 0.6417, CE_L

Epoch [830/10000], Loss: 0.5514, CE_Loss: 0.6327, Flops_Loss: 0.3617, Accuracy: 66.00%, Seleted Channel 130
Epoch [840/10000], Loss: 0.5505, CE_Loss: 0.6330, Flops_Loss: 0.3579, Accuracy: 66.00%, Seleted Channel 130
Epoch [850/10000], Loss: 0.5500, CE_Loss: 0.6313, Flops_Loss: 0.3605, Accuracy: 66.00%, Seleted Channel 130
Epoch [860/10000], Loss: 0.5499, CE_Loss: 0.6291, Flops_Loss: 0.3652, Accuracy: 66.30%, Seleted Channel 130
Epoch [870/10000], Loss: 0.5472, CE_Loss: 0.6303, Flops_Loss: 0.3532, Accuracy: 66.30%, Seleted Channel 130
Epoch [880/10000], Loss: 0.5480, CE_Loss: 0.6299, Flops_Loss: 0.3568, Accuracy: 66.20%, Seleted Channel 130
Epoch [890/10000], Loss: 0.5467, CE_Loss: 0.6285, Flops_Loss: 0.3557, Accuracy: 66.10%, Seleted Channel 130
Epoch [900/10000], Loss: 0.5450, CE_Loss: 0.6273, Flops_Loss: 0.3529, Accuracy: 66.20%, Seleted Channel 130
Epoch [910/10000], Loss: 0.5476, CE_Loss: 0.6259, Flops_Loss: 0.3647, Accuracy: 66.30%, Seleted Channel 130
Epoch [920/10000], Loss: 0.5

Epoch [1600/10000], Loss: 0.4915, CE_Loss: 0.5881, Flops_Loss: 0.2661, Accuracy: 70.60%, Seleted Channel 130
Epoch [1610/10000], Loss: 0.4908, CE_Loss: 0.5861, Flops_Loss: 0.2686, Accuracy: 70.60%, Seleted Channel 130
Epoch [1620/10000], Loss: 0.4862, CE_Loss: 0.5899, Flops_Loss: 0.2443, Accuracy: 70.80%, Seleted Channel 130
Epoch [1630/10000], Loss: 0.4870, CE_Loss: 0.5888, Flops_Loss: 0.2496, Accuracy: 70.70%, Seleted Channel 100
Epoch [1640/10000], Loss: 0.4864, CE_Loss: 0.5891, Flops_Loss: 0.2466, Accuracy: 70.80%, Seleted Channel 100
Epoch [1650/10000], Loss: 0.4853, CE_Loss: 0.5876, Flops_Loss: 0.2465, Accuracy: 70.80%, Seleted Channel 100
Epoch [1660/10000], Loss: 0.4826, CE_Loss: 0.5883, Flops_Loss: 0.2360, Accuracy: 70.70%, Seleted Channel 100
Epoch [1670/10000], Loss: 0.4852, CE_Loss: 0.5847, Flops_Loss: 0.2530, Accuracy: 70.80%, Seleted Channel 100
Epoch [1680/10000], Loss: 0.4834, CE_Loss: 0.5882, Flops_Loss: 0.2387, Accuracy: 70.90%, Seleted Channel 100
Epoch [1690/10000],

Epoch [2370/10000], Loss: 0.4436, CE_Loss: 0.5473, Flops_Loss: 0.2018, Accuracy: 76.10%, Seleted Channel 120
Epoch [2380/10000], Loss: 0.4433, CE_Loss: 0.5436, Flops_Loss: 0.2094, Accuracy: 76.50%, Seleted Channel 120
Epoch [2390/10000], Loss: 0.4427, CE_Loss: 0.5446, Flops_Loss: 0.2048, Accuracy: 76.50%, Seleted Channel 120
Epoch [2400/10000], Loss: 0.4436, CE_Loss: 0.5411, Flops_Loss: 0.2162, Accuracy: 76.30%, Seleted Channel 120
Epoch [2410/10000], Loss: 0.4419, CE_Loss: 0.5430, Flops_Loss: 0.2061, Accuracy: 76.50%, Seleted Channel 120
Epoch [2420/10000], Loss: 0.4403, CE_Loss: 0.5453, Flops_Loss: 0.1954, Accuracy: 76.40%, Seleted Channel 120
Epoch [2430/10000], Loss: 0.4400, CE_Loss: 0.5401, Flops_Loss: 0.2063, Accuracy: 76.20%, Seleted Channel 120
Epoch [2440/10000], Loss: 0.4401, CE_Loss: 0.5445, Flops_Loss: 0.1965, Accuracy: 76.60%, Seleted Channel 120
Epoch [2450/10000], Loss: 0.4390, CE_Loss: 0.5448, Flops_Loss: 0.1921, Accuracy: 76.40%, Seleted Channel 120
Epoch [2460/10000],

Epoch [3130/10000], Loss: 0.4036, CE_Loss: 0.4946, Flops_Loss: 0.1915, Accuracy: 82.00%, Seleted Channel 140
Epoch [3140/10000], Loss: 0.4017, CE_Loss: 0.4913, Flops_Loss: 0.1925, Accuracy: 82.20%, Seleted Channel 140
Epoch [3150/10000], Loss: 0.4027, CE_Loss: 0.4932, Flops_Loss: 0.1914, Accuracy: 82.10%, Seleted Channel 140
Epoch [3160/10000], Loss: 0.4024, CE_Loss: 0.4944, Flops_Loss: 0.1877, Accuracy: 82.00%, Seleted Channel 140
Epoch [3170/10000], Loss: 0.4022, CE_Loss: 0.4941, Flops_Loss: 0.1877, Accuracy: 82.50%, Seleted Channel 140
Epoch [3180/10000], Loss: 0.4007, CE_Loss: 0.4912, Flops_Loss: 0.1896, Accuracy: 82.30%, Seleted Channel 140
Epoch [3190/10000], Loss: 0.4000, CE_Loss: 0.4899, Flops_Loss: 0.1903, Accuracy: 82.50%, Seleted Channel 140
Epoch [3200/10000], Loss: 0.3998, CE_Loss: 0.4869, Flops_Loss: 0.1964, Accuracy: 82.20%, Seleted Channel 140
Epoch [3210/10000], Loss: 0.3993, CE_Loss: 0.4909, Flops_Loss: 0.1855, Accuracy: 82.70%, Seleted Channel 140
Epoch [3220/10000],

Epoch [3900/10000], Loss: 0.3599, CE_Loss: 0.4282, Flops_Loss: 0.2007, Accuracy: 87.30%, Seleted Channel 220
Epoch [3910/10000], Loss: 0.3604, CE_Loss: 0.4307, Flops_Loss: 0.1964, Accuracy: 87.10%, Seleted Channel 220
Epoch [3920/10000], Loss: 0.3594, CE_Loss: 0.4249, Flops_Loss: 0.2065, Accuracy: 87.20%, Seleted Channel 220
Epoch [3930/10000], Loss: 0.3595, CE_Loss: 0.4286, Flops_Loss: 0.1982, Accuracy: 87.20%, Seleted Channel 220
Epoch [3940/10000], Loss: 0.3572, CE_Loss: 0.4239, Flops_Loss: 0.2017, Accuracy: 87.50%, Seleted Channel 220
Epoch [3950/10000], Loss: 0.3562, CE_Loss: 0.4211, Flops_Loss: 0.2049, Accuracy: 87.70%, Seleted Channel 220
Epoch [3960/10000], Loss: 0.3573, CE_Loss: 0.4225, Flops_Loss: 0.2053, Accuracy: 87.40%, Seleted Channel 220
Epoch [3970/10000], Loss: 0.3572, CE_Loss: 0.4251, Flops_Loss: 0.1987, Accuracy: 87.70%, Seleted Channel 220
Epoch [3980/10000], Loss: 0.3553, CE_Loss: 0.4211, Flops_Loss: 0.2017, Accuracy: 87.60%, Seleted Channel 220
Epoch [3990/10000],

Epoch [4670/10000], Loss: 0.3131, CE_Loss: 0.3519, Flops_Loss: 0.2226, Accuracy: 91.70%, Seleted Channel 220
Epoch [4680/10000], Loss: 0.3124, CE_Loss: 0.3514, Flops_Loss: 0.2213, Accuracy: 91.70%, Seleted Channel 220
Epoch [4690/10000], Loss: 0.3115, CE_Loss: 0.3501, Flops_Loss: 0.2212, Accuracy: 91.90%, Seleted Channel 220
Epoch [4700/10000], Loss: 0.3116, CE_Loss: 0.3504, Flops_Loss: 0.2211, Accuracy: 92.10%, Seleted Channel 220
Epoch [4710/10000], Loss: 0.3105, CE_Loss: 0.3485, Flops_Loss: 0.2218, Accuracy: 92.10%, Seleted Channel 220
Epoch [4720/10000], Loss: 0.3107, CE_Loss: 0.3489, Flops_Loss: 0.2216, Accuracy: 92.00%, Seleted Channel 220
Epoch [4730/10000], Loss: 0.3101, CE_Loss: 0.3494, Flops_Loss: 0.2184, Accuracy: 92.10%, Seleted Channel 220
Epoch [4740/10000], Loss: 0.3095, CE_Loss: 0.3475, Flops_Loss: 0.2210, Accuracy: 92.30%, Seleted Channel 220
Epoch [4750/10000], Loss: 0.3077, CE_Loss: 0.3445, Flops_Loss: 0.2218, Accuracy: 92.20%, Seleted Channel 220
Epoch [4760/10000],

Epoch [5430/10000], Loss: 0.2730, CE_Loss: 0.2946, Flops_Loss: 0.2224, Accuracy: 96.40%, Seleted Channel 220
Epoch [5440/10000], Loss: 0.2735, CE_Loss: 0.2948, Flops_Loss: 0.2237, Accuracy: 96.30%, Seleted Channel 220
Epoch [5450/10000], Loss: 0.2722, CE_Loss: 0.2936, Flops_Loss: 0.2222, Accuracy: 96.40%, Seleted Channel 220
Epoch [5460/10000], Loss: 0.2712, CE_Loss: 0.2922, Flops_Loss: 0.2223, Accuracy: 96.40%, Seleted Channel 220
Epoch [5470/10000], Loss: 0.2713, CE_Loss: 0.2924, Flops_Loss: 0.2222, Accuracy: 96.40%, Seleted Channel 220
Epoch [5480/10000], Loss: 0.2714, CE_Loss: 0.2921, Flops_Loss: 0.2231, Accuracy: 96.40%, Seleted Channel 220
Epoch [5490/10000], Loss: 0.2703, CE_Loss: 0.2907, Flops_Loss: 0.2225, Accuracy: 96.40%, Seleted Channel 220
Epoch [5500/10000], Loss: 0.2698, CE_Loss: 0.2901, Flops_Loss: 0.2223, Accuracy: 96.60%, Seleted Channel 220
Epoch [5510/10000], Loss: 0.2694, CE_Loss: 0.2897, Flops_Loss: 0.2222, Accuracy: 96.40%, Seleted Channel 220
Epoch [5520/10000],

Epoch [6250/10000], Loss: 0.2365, CE_Loss: 0.2425, Flops_Loss: 0.2223, Accuracy: 98.30%, Seleted Channel 220
Epoch [6260/10000], Loss: 0.2358, CE_Loss: 0.2415, Flops_Loss: 0.2223, Accuracy: 98.50%, Seleted Channel 220
Epoch [6270/10000], Loss: 0.2357, CE_Loss: 0.2414, Flops_Loss: 0.2225, Accuracy: 98.50%, Seleted Channel 220
Epoch [6280/10000], Loss: 0.2356, CE_Loss: 0.2413, Flops_Loss: 0.2222, Accuracy: 98.60%, Seleted Channel 220
Epoch [6290/10000], Loss: 0.2349, CE_Loss: 0.2402, Flops_Loss: 0.2224, Accuracy: 98.60%, Seleted Channel 220
Epoch [6300/10000], Loss: 0.2340, CE_Loss: 0.2390, Flops_Loss: 0.2223, Accuracy: 98.60%, Seleted Channel 220
Epoch [6310/10000], Loss: 0.2340, CE_Loss: 0.2391, Flops_Loss: 0.2222, Accuracy: 98.60%, Seleted Channel 220
Epoch [6320/10000], Loss: 0.2340, CE_Loss: 0.2389, Flops_Loss: 0.2224, Accuracy: 98.60%, Seleted Channel 220
Epoch [6330/10000], Loss: 0.2333, CE_Loss: 0.2379, Flops_Loss: 0.2225, Accuracy: 98.60%, Seleted Channel 220
Epoch [6340/10000],

Epoch [7060/10000], Loss: 0.2049, CE_Loss: 0.1975, Flops_Loss: 0.2222, Accuracy: 99.60%, Seleted Channel 220
Epoch [7070/10000], Loss: 0.2049, CE_Loss: 0.1974, Flops_Loss: 0.2222, Accuracy: 99.60%, Seleted Channel 220
Epoch [7080/10000], Loss: 0.2049, CE_Loss: 0.1975, Flops_Loss: 0.2222, Accuracy: 99.60%, Seleted Channel 220
Epoch [7090/10000], Loss: 0.2042, CE_Loss: 0.1965, Flops_Loss: 0.2223, Accuracy: 99.60%, Seleted Channel 220
Epoch [7100/10000], Loss: 0.2036, CE_Loss: 0.1956, Flops_Loss: 0.2223, Accuracy: 99.60%, Seleted Channel 220
Epoch [7110/10000], Loss: 0.2036, CE_Loss: 0.1954, Flops_Loss: 0.2227, Accuracy: 99.60%, Seleted Channel 220
Epoch [7120/10000], Loss: 0.2034, CE_Loss: 0.1954, Flops_Loss: 0.2222, Accuracy: 99.60%, Seleted Channel 220
Epoch [7130/10000], Loss: 0.2028, CE_Loss: 0.1945, Flops_Loss: 0.2223, Accuracy: 99.60%, Seleted Channel 220
Epoch [7140/10000], Loss: 0.2021, CE_Loss: 0.1935, Flops_Loss: 0.2222, Accuracy: 99.60%, Seleted Channel 220
Epoch [7150/10000],

Epoch [7870/10000], Loss: 0.1789, CE_Loss: 0.1603, Flops_Loss: 0.2222, Accuracy: 99.70%, Seleted Channel 220
Epoch [7880/10000], Loss: 0.1788, CE_Loss: 0.1602, Flops_Loss: 0.2223, Accuracy: 99.70%, Seleted Channel 220
Epoch [7890/10000], Loss: 0.1783, CE_Loss: 0.1595, Flops_Loss: 0.2222, Accuracy: 99.70%, Seleted Channel 220
Epoch [7900/10000], Loss: 0.1777, CE_Loss: 0.1586, Flops_Loss: 0.2223, Accuracy: 99.70%, Seleted Channel 220
Epoch [7910/10000], Loss: 0.1777, CE_Loss: 0.1585, Flops_Loss: 0.2222, Accuracy: 99.70%, Seleted Channel 220
Epoch [7920/10000], Loss: 0.1776, CE_Loss: 0.1585, Flops_Loss: 0.2222, Accuracy: 99.70%, Seleted Channel 220
Epoch [7930/10000], Loss: 0.1772, CE_Loss: 0.1577, Flops_Loss: 0.2227, Accuracy: 99.70%, Seleted Channel 220
Epoch [7940/10000], Loss: 0.1766, CE_Loss: 0.1569, Flops_Loss: 0.2226, Accuracy: 99.70%, Seleted Channel 220
Epoch [7950/10000], Loss: 0.1765, CE_Loss: 0.1569, Flops_Loss: 0.2222, Accuracy: 99.70%, Seleted Channel 220
Epoch [7960/10000],

Epoch [8690/10000], Loss: 0.1568, CE_Loss: 0.1288, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [8700/10000], Loss: 0.1563, CE_Loss: 0.1281, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [8710/10000], Loss: 0.1563, CE_Loss: 0.1280, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [8720/10000], Loss: 0.1563, CE_Loss: 0.1280, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [8730/10000], Loss: 0.1558, CE_Loss: 0.1274, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [8740/10000], Loss: 0.1554, CE_Loss: 0.1267, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [8750/10000], Loss: 0.1553, CE_Loss: 0.1266, Flops_Loss: 0.2223, Accuracy: 100.00%, Seleted Channel 220
Epoch [8760/10000], Loss: 0.1553, CE_Loss: 0.1266, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [8770/10000], Loss: 0.1549, CE_Loss: 0.1260, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [878

Epoch [9440/10000], Loss: 0.1404, CE_Loss: 0.1053, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [9450/10000], Loss: 0.1400, CE_Loss: 0.1048, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [9460/10000], Loss: 0.1396, CE_Loss: 0.1042, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [9470/10000], Loss: 0.1396, CE_Loss: 0.1042, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [9480/10000], Loss: 0.1396, CE_Loss: 0.1042, Flops_Loss: 0.2223, Accuracy: 100.00%, Seleted Channel 220
Epoch [9490/10000], Loss: 0.1392, CE_Loss: 0.1036, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [9500/10000], Loss: 0.1388, CE_Loss: 0.1031, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [9510/10000], Loss: 0.1388, CE_Loss: 0.1030, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [9520/10000], Loss: 0.1388, CE_Loss: 0.1030, Flops_Loss: 0.2222, Accuracy: 100.00%, Seleted Channel 220
Epoch [953

## flops_balance_factor = 0.4
## Searched Hidden Size = 150
`flops_balance_factor` is the trade-off parameter between performance (CE_Loss or MSE_Loss) and speed (FLOPs).

Larger flops_balance_factor leads to a faster network with worse performance

In [55]:
import torch
import torch.nn as nn
import torch.optim as optim

# Hidden-size search with flops_balance_factor = 0.4.
#
# Builds a SuperNet over a range of candidate hidden sizes and alternates between
# updating the network weights and the architecture parameters, while annealing
# the temperature passed to the network's forward pass.
# NOTE(review): SuperNet, x_train, y_train and np are defined in earlier cells.

# Initialize the neural network
in_dim = x_train.shape[1]
hidden_size_choices = list(range(10,1000,10))  # candidate hidden sizes: 10, 20, ..., 990
num_classes =  2

net = SuperNet(in_dim, hidden_size_choices, num_classes)

# Set up the loss function and optimizers
criterion = nn.CrossEntropyLoss()

net_weight_lr = 0.0001
arch_lr = 0.01

# Parameters whose name contains 'arch' are handled by the architecture optimizer;
# every other parameter is handled by the network-weight optimizer.
optimizer_net = optim.Adam([p for name, p in net.named_parameters() if 'arch' not in name], lr=net_weight_lr)
optimizer_arch = optim.Adam([p for name, p in net.named_parameters() if 'arch' in name], lr=arch_lr)

# Search Epoch
num_epochs = 10000

# Iteratively optimize network weights and architecture parameters,
# switching between the two every 'search_freq' epochs
search_freq = 20

# The factor to balance performance (CE Loss or MSE Loss) and FLOPs
# Larger flops_balance_factor leads to a faster network with worse performance
flops_balance_factor = 0.4

# The temperature is decayed by 'temp_anneal_factor' every 'temp_anneal_freq' epochs.
# NOTE(review): assuming SuperNet forwards `temp` as `tau` to gumbel_softmax (which
# divides the perturbed logits by tau before the softmax), a SMALLER temp gives a
# gumbel weight that is closer to a one-hot distribution -- confirm against SuperNet.
temp = 5
temp_anneal_factor = 0.95
# Integer period so that `epoch % temp_anneal_freq == 0` fires reliably even when
# num_epochs is not a multiple of 25 (a float period such as 39.6 would almost never
# divide an integer epoch). The temperature decays about 25 times during the search.
temp_anneal_freq = max(1, num_epochs // 25)

for epoch in range(num_epochs):
    # Anneal the temperature (the first decay happens at epoch 0, before any forward pass)
    if epoch % temp_anneal_freq == 0:
        temp = temp * temp_anneal_factor

    # Forward pass: the network returns both the class logits and its FLOPs penalty
    outputs, flops_loss = net(x_train, temp)
    CE_loss = criterion(outputs, y_train)
    loss = (1 - flops_balance_factor) * CE_loss + flops_balance_factor * flops_loss

    # Alternate phases: even blocks of 'search_freq' epochs update the network weights,
    # odd blocks update the architecture parameters. Only the active optimizer's
    # gradients are cleared and applied in each phase.
    if (epoch // search_freq) % 2 == 0:
        optimizer_net.zero_grad()
        loss.backward()
        optimizer_net.step()
    else:
        optimizer_arch.zero_grad()
        loss.backward()
        optimizer_arch.step()

    # Currently preferred hidden size: the candidate with the largest architecture parameter
    selected_channel_id = np.argmax(net.arch_params.detach().cpu().numpy())
    selected_channel = hidden_size_choices[selected_channel_id]

    # Training-set accuracy of the outputs computed before this epoch's parameter update
    with torch.no_grad():
        _, predicted = torch.max(outputs, 1)
        total = y_train.size(0)
        correct = (predicted == y_train).sum().item()
        accuracy = correct / total
        if (epoch+1) % 10 == 0:

            print('Epoch [{}/{}], Loss: {:.4f}, CE_Loss: {:.4f}, Flops_Loss: {:.4f}, Accuracy: {:.2f}%, Seleted Channel {}'
                  .format(epoch+1, num_epochs, loss.item(), CE_loss.item(),
                          flops_loss.item(), accuracy * 100, selected_channel))

Epoch [10/10000], Loss: 0.6128, CE_Loss: 0.6926, Flops_Loss: 0.4931, Accuracy: 53.90%, Seleted Channel 10
Epoch [20/10000], Loss: 0.6241, CE_Loss: 0.6918, Flops_Loss: 0.5224, Accuracy: 56.00%, Seleted Channel 10
Epoch [30/10000], Loss: 0.6196, CE_Loss: 0.6918, Flops_Loss: 0.5114, Accuracy: 56.10%, Seleted Channel 330
Epoch [40/10000], Loss: 0.6096, CE_Loss: 0.6919, Flops_Loss: 0.4863, Accuracy: 56.40%, Seleted Channel 100
Epoch [50/10000], Loss: 0.6097, CE_Loss: 0.6912, Flops_Loss: 0.4875, Accuracy: 57.50%, Seleted Channel 100
Epoch [60/10000], Loss: 0.6136, CE_Loss: 0.6904, Flops_Loss: 0.4984, Accuracy: 58.50%, Seleted Channel 100
Epoch [70/10000], Loss: 0.6065, CE_Loss: 0.6904, Flops_Loss: 0.4806, Accuracy: 58.20%, Seleted Channel 100
Epoch [80/10000], Loss: 0.6120, CE_Loss: 0.6903, Flops_Loss: 0.4945, Accuracy: 58.30%, Seleted Channel 160
Epoch [90/10000], Loss: 0.6035, CE_Loss: 0.6896, Flops_Loss: 0.4743, Accuracy: 58.40%, Seleted Channel 160
Epoch [100/10000], Loss: 0.6115, CE_Los

Epoch [810/10000], Loss: 0.5151, CE_Loss: 0.6406, Flops_Loss: 0.3268, Accuracy: 64.90%, Seleted Channel 10
Epoch [820/10000], Loss: 0.5109, CE_Loss: 0.6402, Flops_Loss: 0.3169, Accuracy: 65.10%, Seleted Channel 10
Epoch [830/10000], Loss: 0.5119, CE_Loss: 0.6396, Flops_Loss: 0.3202, Accuracy: 65.40%, Seleted Channel 10
Epoch [840/10000], Loss: 0.5046, CE_Loss: 0.6418, Flops_Loss: 0.2987, Accuracy: 65.30%, Seleted Channel 30
Epoch [850/10000], Loss: 0.5035, CE_Loss: 0.6408, Flops_Loss: 0.2975, Accuracy: 65.40%, Seleted Channel 30
Epoch [860/10000], Loss: 0.5026, CE_Loss: 0.6394, Flops_Loss: 0.2975, Accuracy: 65.70%, Seleted Channel 30
Epoch [870/10000], Loss: 0.5028, CE_Loss: 0.6392, Flops_Loss: 0.2982, Accuracy: 65.50%, Seleted Channel 30
Epoch [880/10000], Loss: 0.5005, CE_Loss: 0.6399, Flops_Loss: 0.2915, Accuracy: 65.60%, Seleted Channel 30
Epoch [890/10000], Loss: 0.5026, CE_Loss: 0.6378, Flops_Loss: 0.2999, Accuracy: 65.50%, Seleted Channel 30
Epoch [900/10000], Loss: 0.4993, CE_L

Epoch [1610/10000], Loss: 0.4375, CE_Loss: 0.6180, Flops_Loss: 0.1668, Accuracy: 69.30%, Seleted Channel 10
Epoch [1620/10000], Loss: 0.4345, CE_Loss: 0.6184, Flops_Loss: 0.1588, Accuracy: 68.80%, Seleted Channel 10
Epoch [1630/10000], Loss: 0.4372, CE_Loss: 0.6162, Flops_Loss: 0.1688, Accuracy: 69.30%, Seleted Channel 10
Epoch [1640/10000], Loss: 0.4388, CE_Loss: 0.6131, Flops_Loss: 0.1775, Accuracy: 68.60%, Seleted Channel 10
Epoch [1650/10000], Loss: 0.4370, CE_Loss: 0.6145, Flops_Loss: 0.1708, Accuracy: 69.20%, Seleted Channel 10
Epoch [1660/10000], Loss: 0.4331, CE_Loss: 0.6174, Flops_Loss: 0.1565, Accuracy: 68.90%, Seleted Channel 10
Epoch [1670/10000], Loss: 0.4361, CE_Loss: 0.6139, Flops_Loss: 0.1695, Accuracy: 69.00%, Seleted Channel 10
Epoch [1680/10000], Loss: 0.4366, CE_Loss: 0.6131, Flops_Loss: 0.1718, Accuracy: 69.00%, Seleted Channel 10
Epoch [1690/10000], Loss: 0.4370, CE_Loss: 0.6111, Flops_Loss: 0.1758, Accuracy: 69.00%, Seleted Channel 10
Epoch [1700/10000], Loss: 0.

Epoch [2410/10000], Loss: 0.4039, CE_Loss: 0.5971, Flops_Loss: 0.1141, Accuracy: 70.60%, Seleted Channel 10
Epoch [2420/10000], Loss: 0.4027, CE_Loss: 0.5987, Flops_Loss: 0.1087, Accuracy: 71.00%, Seleted Channel 10
Epoch [2430/10000], Loss: 0.4029, CE_Loss: 0.6033, Flops_Loss: 0.1023, Accuracy: 70.30%, Seleted Channel 10
Epoch [2440/10000], Loss: 0.4037, CE_Loss: 0.5998, Flops_Loss: 0.1096, Accuracy: 70.50%, Seleted Channel 10
Epoch [2450/10000], Loss: 0.4036, CE_Loss: 0.5978, Flops_Loss: 0.1123, Accuracy: 70.30%, Seleted Channel 10
Epoch [2460/10000], Loss: 0.4015, CE_Loss: 0.6020, Flops_Loss: 0.1008, Accuracy: 70.40%, Seleted Channel 10
Epoch [2470/10000], Loss: 0.4029, CE_Loss: 0.5973, Flops_Loss: 0.1112, Accuracy: 70.50%, Seleted Channel 10
Epoch [2480/10000], Loss: 0.4037, CE_Loss: 0.5938, Flops_Loss: 0.1185, Accuracy: 70.70%, Seleted Channel 10
Epoch [2490/10000], Loss: 0.4020, CE_Loss: 0.5978, Flops_Loss: 0.1085, Accuracy: 70.40%, Seleted Channel 10
Epoch [2500/10000], Loss: 0.

Epoch [3210/10000], Loss: 0.3811, CE_Loss: 0.5714, Flops_Loss: 0.0957, Accuracy: 72.90%, Seleted Channel 10
Epoch [3220/10000], Loss: 0.3815, CE_Loss: 0.5739, Flops_Loss: 0.0931, Accuracy: 73.20%, Seleted Channel 10
Epoch [3230/10000], Loss: 0.3807, CE_Loss: 0.5749, Flops_Loss: 0.0893, Accuracy: 73.10%, Seleted Channel 70
Epoch [3240/10000], Loss: 0.3812, CE_Loss: 0.5751, Flops_Loss: 0.0905, Accuracy: 73.30%, Seleted Channel 70
Epoch [3250/10000], Loss: 0.3802, CE_Loss: 0.5723, Flops_Loss: 0.0922, Accuracy: 73.20%, Seleted Channel 70
Epoch [3260/10000], Loss: 0.3812, CE_Loss: 0.5764, Flops_Loss: 0.0884, Accuracy: 73.20%, Seleted Channel 70
Epoch [3270/10000], Loss: 0.3805, CE_Loss: 0.5742, Flops_Loss: 0.0900, Accuracy: 73.00%, Seleted Channel 70
Epoch [3280/10000], Loss: 0.3807, CE_Loss: 0.5743, Flops_Loss: 0.0902, Accuracy: 73.20%, Seleted Channel 70
Epoch [3290/10000], Loss: 0.3789, CE_Loss: 0.5621, Flops_Loss: 0.1041, Accuracy: 72.90%, Seleted Channel 70
Epoch [3300/10000], Loss: 0.

Epoch [4010/10000], Loss: 0.3537, CE_Loss: 0.5192, Flops_Loss: 0.1055, Accuracy: 78.00%, Seleted Channel 90
Epoch [4020/10000], Loss: 0.3536, CE_Loss: 0.5203, Flops_Loss: 0.1037, Accuracy: 78.60%, Seleted Channel 90
Epoch [4030/10000], Loss: 0.3549, CE_Loss: 0.5279, Flops_Loss: 0.0954, Accuracy: 78.50%, Seleted Channel 90
Epoch [4040/10000], Loss: 0.3532, CE_Loss: 0.5181, Flops_Loss: 0.1059, Accuracy: 78.70%, Seleted Channel 90
Epoch [4050/10000], Loss: 0.3532, CE_Loss: 0.5170, Flops_Loss: 0.1075, Accuracy: 78.50%, Seleted Channel 90
Epoch [4060/10000], Loss: 0.3527, CE_Loss: 0.5167, Flops_Loss: 0.1067, Accuracy: 78.90%, Seleted Channel 90
Epoch [4070/10000], Loss: 0.3519, CE_Loss: 0.5153, Flops_Loss: 0.1070, Accuracy: 79.10%, Seleted Channel 90
Epoch [4080/10000], Loss: 0.3538, CE_Loss: 0.5249, Flops_Loss: 0.0972, Accuracy: 78.80%, Seleted Channel 90
Epoch [4090/10000], Loss: 0.3525, CE_Loss: 0.5222, Flops_Loss: 0.0979, Accuracy: 79.50%, Seleted Channel 90
Epoch [4100/10000], Loss: 0.

Epoch [4810/10000], Loss: 0.3168, CE_Loss: 0.4431, Flops_Loss: 0.1273, Accuracy: 84.80%, Seleted Channel 120
Epoch [4820/10000], Loss: 0.3176, CE_Loss: 0.4497, Flops_Loss: 0.1195, Accuracy: 85.20%, Seleted Channel 120
Epoch [4830/10000], Loss: 0.3151, CE_Loss: 0.4354, Flops_Loss: 0.1348, Accuracy: 85.10%, Seleted Channel 120
Epoch [4840/10000], Loss: 0.3141, CE_Loss: 0.4345, Flops_Loss: 0.1336, Accuracy: 85.30%, Seleted Channel 120
Epoch [4850/10000], Loss: 0.3152, CE_Loss: 0.4412, Flops_Loss: 0.1262, Accuracy: 85.40%, Seleted Channel 120
Epoch [4860/10000], Loss: 0.3150, CE_Loss: 0.4437, Flops_Loss: 0.1219, Accuracy: 85.40%, Seleted Channel 120
Epoch [4870/10000], Loss: 0.3117, CE_Loss: 0.4284, Flops_Loss: 0.1366, Accuracy: 85.50%, Seleted Channel 120
Epoch [4880/10000], Loss: 0.3138, CE_Loss: 0.4397, Flops_Loss: 0.1251, Accuracy: 85.10%, Seleted Channel 150
Epoch [4890/10000], Loss: 0.3135, CE_Loss: 0.4379, Flops_Loss: 0.1270, Accuracy: 85.50%, Seleted Channel 150
Epoch [4900/10000],

Epoch [5610/10000], Loss: 0.2740, CE_Loss: 0.3574, Flops_Loss: 0.1489, Accuracy: 91.60%, Seleted Channel 150
Epoch [5620/10000], Loss: 0.2726, CE_Loss: 0.3538, Flops_Loss: 0.1507, Accuracy: 91.70%, Seleted Channel 150
Epoch [5630/10000], Loss: 0.2728, CE_Loss: 0.3546, Flops_Loss: 0.1500, Accuracy: 91.70%, Seleted Channel 150
Epoch [5640/10000], Loss: 0.2727, CE_Loss: 0.3545, Flops_Loss: 0.1501, Accuracy: 91.70%, Seleted Channel 150
Epoch [5650/10000], Loss: 0.2724, CE_Loss: 0.3549, Flops_Loss: 0.1488, Accuracy: 91.70%, Seleted Channel 150
Epoch [5660/10000], Loss: 0.2712, CE_Loss: 0.3523, Flops_Loss: 0.1496, Accuracy: 91.80%, Seleted Channel 150
Epoch [5670/10000], Loss: 0.2707, CE_Loss: 0.3508, Flops_Loss: 0.1506, Accuracy: 91.80%, Seleted Channel 150
Epoch [5680/10000], Loss: 0.2710, CE_Loss: 0.3520, Flops_Loss: 0.1496, Accuracy: 91.80%, Seleted Channel 150
Epoch [5690/10000], Loss: 0.2704, CE_Loss: 0.3510, Flops_Loss: 0.1494, Accuracy: 91.80%, Seleted Channel 150
Epoch [5700/10000],

Epoch [6410/10000], Loss: 0.2402, CE_Loss: 0.2996, Flops_Loss: 0.1512, Accuracy: 95.40%, Seleted Channel 150
Epoch [6420/10000], Loss: 0.2395, CE_Loss: 0.2985, Flops_Loss: 0.1511, Accuracy: 95.40%, Seleted Channel 150
Epoch [6430/10000], Loss: 0.2393, CE_Loss: 0.2980, Flops_Loss: 0.1513, Accuracy: 95.40%, Seleted Channel 150
Epoch [6440/10000], Loss: 0.2393, CE_Loss: 0.2979, Flops_Loss: 0.1514, Accuracy: 95.40%, Seleted Channel 150
Epoch [6450/10000], Loss: 0.2387, CE_Loss: 0.2970, Flops_Loss: 0.1512, Accuracy: 95.40%, Seleted Channel 150
Epoch [6460/10000], Loss: 0.2379, CE_Loss: 0.2958, Flops_Loss: 0.1512, Accuracy: 95.50%, Seleted Channel 150
Epoch [6470/10000], Loss: 0.2379, CE_Loss: 0.2957, Flops_Loss: 0.1512, Accuracy: 95.50%, Seleted Channel 150
Epoch [6480/10000], Loss: 0.2379, CE_Loss: 0.2957, Flops_Loss: 0.1512, Accuracy: 95.50%, Seleted Channel 150
Epoch [6490/10000], Loss: 0.2373, CE_Loss: 0.2947, Flops_Loss: 0.1511, Accuracy: 95.50%, Seleted Channel 150
Epoch [6500/10000],

Epoch [7200/10000], Loss: 0.2127, CE_Loss: 0.2536, Flops_Loss: 0.1514, Accuracy: 97.40%, Seleted Channel 150
Epoch [7210/10000], Loss: 0.2120, CE_Loss: 0.2524, Flops_Loss: 0.1515, Accuracy: 97.40%, Seleted Channel 150
Epoch [7220/10000], Loss: 0.2114, CE_Loss: 0.2513, Flops_Loss: 0.1515, Accuracy: 97.40%, Seleted Channel 150
Epoch [7230/10000], Loss: 0.2114, CE_Loss: 0.2514, Flops_Loss: 0.1514, Accuracy: 97.40%, Seleted Channel 150
Epoch [7240/10000], Loss: 0.2114, CE_Loss: 0.2513, Flops_Loss: 0.1514, Accuracy: 97.40%, Seleted Channel 150
Epoch [7250/10000], Loss: 0.2108, CE_Loss: 0.2504, Flops_Loss: 0.1514, Accuracy: 97.50%, Seleted Channel 150
Epoch [7260/10000], Loss: 0.2101, CE_Loss: 0.2493, Flops_Loss: 0.1514, Accuracy: 97.60%, Seleted Channel 150
Epoch [7270/10000], Loss: 0.2100, CE_Loss: 0.2491, Flops_Loss: 0.1515, Accuracy: 97.60%, Seleted Channel 150
Epoch [7280/10000], Loss: 0.2101, CE_Loss: 0.2492, Flops_Loss: 0.1514, Accuracy: 97.60%, Seleted Channel 150
Epoch [7290/10000],

Epoch [7990/10000], Loss: 0.1886, CE_Loss: 0.2134, Flops_Loss: 0.1515, Accuracy: 98.50%, Seleted Channel 150
Epoch [8000/10000], Loss: 0.1886, CE_Loss: 0.2134, Flops_Loss: 0.1515, Accuracy: 98.50%, Seleted Channel 150
Epoch [8010/10000], Loss: 0.1881, CE_Loss: 0.2125, Flops_Loss: 0.1515, Accuracy: 98.50%, Seleted Channel 150
Epoch [8020/10000], Loss: 0.1876, CE_Loss: 0.2116, Flops_Loss: 0.1515, Accuracy: 98.50%, Seleted Channel 150
Epoch [8030/10000], Loss: 0.1875, CE_Loss: 0.2115, Flops_Loss: 0.1515, Accuracy: 98.50%, Seleted Channel 150
Epoch [8040/10000], Loss: 0.1875, CE_Loss: 0.2115, Flops_Loss: 0.1515, Accuracy: 98.50%, Seleted Channel 150
Epoch [8050/10000], Loss: 0.1870, CE_Loss: 0.2107, Flops_Loss: 0.1515, Accuracy: 98.50%, Seleted Channel 150
Epoch [8060/10000], Loss: 0.1865, CE_Loss: 0.2098, Flops_Loss: 0.1515, Accuracy: 98.60%, Seleted Channel 150
Epoch [8070/10000], Loss: 0.1864, CE_Loss: 0.2098, Flops_Loss: 0.1515, Accuracy: 98.60%, Seleted Channel 150
Epoch [8080/10000],

Epoch [8790/10000], Loss: 0.1681, CE_Loss: 0.1792, Flops_Loss: 0.1515, Accuracy: 99.10%, Seleted Channel 150
Epoch [8800/10000], Loss: 0.1681, CE_Loss: 0.1792, Flops_Loss: 0.1515, Accuracy: 99.10%, Seleted Channel 150
Epoch [8810/10000], Loss: 0.1677, CE_Loss: 0.1785, Flops_Loss: 0.1515, Accuracy: 99.10%, Seleted Channel 150
Epoch [8820/10000], Loss: 0.1672, CE_Loss: 0.1777, Flops_Loss: 0.1515, Accuracy: 99.20%, Seleted Channel 150
Epoch [8830/10000], Loss: 0.1672, CE_Loss: 0.1776, Flops_Loss: 0.1515, Accuracy: 99.20%, Seleted Channel 150
Epoch [8840/10000], Loss: 0.1672, CE_Loss: 0.1776, Flops_Loss: 0.1515, Accuracy: 99.20%, Seleted Channel 150
Epoch [8850/10000], Loss: 0.1667, CE_Loss: 0.1769, Flops_Loss: 0.1515, Accuracy: 99.20%, Seleted Channel 150
Epoch [8860/10000], Loss: 0.1663, CE_Loss: 0.1761, Flops_Loss: 0.1515, Accuracy: 99.20%, Seleted Channel 150
Epoch [8870/10000], Loss: 0.1662, CE_Loss: 0.1760, Flops_Loss: 0.1515, Accuracy: 99.20%, Seleted Channel 150
Epoch [8880/10000],

Epoch [9590/10000], Loss: 0.1507, CE_Loss: 0.1501, Flops_Loss: 0.1515, Accuracy: 99.70%, Seleted Channel 150
Epoch [9600/10000], Loss: 0.1507, CE_Loss: 0.1501, Flops_Loss: 0.1515, Accuracy: 99.70%, Seleted Channel 150
Epoch [9610/10000], Loss: 0.1503, CE_Loss: 0.1495, Flops_Loss: 0.1515, Accuracy: 99.70%, Seleted Channel 150
Epoch [9620/10000], Loss: 0.1499, CE_Loss: 0.1488, Flops_Loss: 0.1515, Accuracy: 99.70%, Seleted Channel 150
Epoch [9630/10000], Loss: 0.1499, CE_Loss: 0.1488, Flops_Loss: 0.1515, Accuracy: 99.70%, Seleted Channel 150
Epoch [9640/10000], Loss: 0.1499, CE_Loss: 0.1488, Flops_Loss: 0.1515, Accuracy: 99.70%, Seleted Channel 150
Epoch [9650/10000], Loss: 0.1495, CE_Loss: 0.1482, Flops_Loss: 0.1515, Accuracy: 99.70%, Seleted Channel 150
Epoch [9660/10000], Loss: 0.1491, CE_Loss: 0.1475, Flops_Loss: 0.1515, Accuracy: 99.70%, Seleted Channel 150
Epoch [9670/10000], Loss: 0.1491, CE_Loss: 0.1474, Flops_Loss: 0.1515, Accuracy: 99.70%, Seleted Channel 150
Epoch [9680/10000],