In [40]:
# Load data
import numpy as np

from data_utils import (
    get_embeddings, 
    synthesize_database, 
    synthesize_simple_database,
    get_passenger_database,
    check_passenger_exist,
    create_simple_trainset
)

is_simple_data=True
# Airport dataset
embed_original, indices = get_embeddings() # Considered the query pictures
if is_simple_data:
    embed_data, id_data, location_data = synthesize_simple_database(embed_original)
    date_data = None
else:
    embed_data, id_data, date_data, location_data = synthesize_database(embed_original)

embed_data = np.stack(embed_data)

In [41]:
import torch
import torch.nn as nn
import torch.optim as optim

x_train = np.array(embed_data[:1000, :])
y_train = np.array(location_data[:1000])

x_train = torch.FloatTensor(x_train)
print(x_train.shape)

y_train[y_train==2] = 0 # Make sure the class index start from 0 (0,1 in this example)
y_train = torch.LongTensor(y_train)
print(y_train.shape)


torch.Size([1000, 128])
torch.Size([1000])


In [42]:
def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> torch.Tensor:
    r"""
    Args:
      logits: `[..., num_features]` unnormalized log probabilities
      tau: non-negative scalar temperature
      hard: if ``True``, the returned samples will be discretized as one-hot vectors,
            but will be differentiated as if it is the soft sample in autograd
      dim (int): A dimension along which softmax will be computed. Default: -1.

    Returns:
      Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
      If ``hard=True``, the returned samples will be one-hot, otherwise they will
      be probability distributions that sum to 1 across `dim`.
    """
    gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    )  # ~Gumbel(0,1)

    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret

In [74]:
def get_channel_mask(hidden_size_choices):
    max_hidden_size = max(hidden_size_choices)
    num_choices = len(hidden_size_choices)
    masks = torch.zeros(max_hidden_size, num_choices)
    for i in range(num_choices):
        masks[:hidden_size_choices[i], i]=1
    return masks
    
def get_flops_choices(input_size, hidden_size_choices, num_classes):
    flops = []
    for hidden_size in hidden_size_choices:
        flops.append(2*hidden_size*input_size + 2*hidden_size*num_classes)
    flops = np.array(flops)
    return flops
    
class SuperNet_CS_LS(nn.Module):
    def __init__(self, input_size, hidden_size_choices, layer_choices, num_classes):
        super(SuperNet_CS_LS, self).__init__()
        
        max_hidden_size = max(hidden_size_choices)
        num_choices_hidden = len(hidden_size_choices)
        num_choices_layer = len(layer_choices)
        
        self.arch_params_hidden = torch.nn.Parameter(torch.ones(num_choices_hidden), requires_grad=True)
        self.arch_params_layer = torch.nn.Parameter(torch.ones(num_choices_layer), requires_grad=True)
        
        self.masks = get_channel_mask(hidden_size_choices)
        self.flops_choices = get_flops_choices(input_size, hidden_size_choices, num_classes)
        self.flops_choices_normalized = torch.FloatTensor(self.flops_choices / np.max(self.flops_choices))
        self.layer_choices_tensor = torch.FloatTensor(layer_choices)
        
        self.fc1 = nn.ModuleList()
        for layer_num in layer_choices:
            layers = []
            layers.append(nn.Linear(input_size, max_hidden_size))
            layers.append(nn.ReLU(inplace=True))
            for i in range(layer_num-1):
                layers.append(nn.Linear(max_hidden_size, max_hidden_size))
                layers.append(nn.ReLU(inplace=True))
            layers = nn.Sequential(*layers)
        self.fc1.append(layers)
        
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(max_hidden_size, num_classes)
        
    def forward(self, x, temperature):
        out_fc1 = []
        for fc1_each in self.fc1:
            out_fc1_each = fc1_each(x)
            out_fc1.append(out_fc1_each)
        out = torch.stack(out_fc1, dim=-1)
        
        gumbel_weights_layer = gumbel_softmax(self.arch_params_layer, tau=temperature, hard=False)
        out = torch.multiply(out, gumbel_weights_layer)
        out = torch.sum(out,dim=-1)
        # print(self.arch_params)
        gumbel_weights_hidden = gumbel_softmax(self.arch_params_hidden, tau=temperature, hard=False)
        # print(gumbel_weights)
        mask = torch.multiply(self.masks, gumbel_weights_hidden)
        mask = torch.sum(mask, dim=-1)
        out = torch.multiply(out, mask)
        
        out = self.fc2(out)
        flops_loss = self._get_flops_loss(gumbel_weights_layer, gumbel_weights_hidden)
        return out, flops_loss
    
    def _get_flops_loss(self, gumbel_weights_layer, gumbel_weights_hidden):
        layer_factor = torch.matmul(self.layer_choices_tensor, gumbel_weights_layer)
        hidden_factor = torch.matmul(self.flops_choices_normalized, gumbel_weights_hidden)
        flops = layer_factor * hidden_factor
        return flops

In [86]:
import torch
import torch.nn as nn
import torch.optim as optim

# Initialize the neural network
in_dim = x_train.shape[1]
hidden_size_choices = list(range(10,1000,10))
layer_choices = [1,2,3]

num_classes =  2

net = SuperNet_CS_LS(in_dim, hidden_size_choices, layer_choices, num_classes)

# Set up the loss function and optimizer
criterion = nn.CrossEntropyLoss()

net_weight_lr = 0.0001
arch_lr = 0.001

optimizer_net = optim.Adam([p for name, p in net.named_parameters() if 'arch' not in name], lr=net_weight_lr)
optimizer_arch = optim.Adam([p for name, p in net.named_parameters() if 'arch' in name], lr=arch_lr)

# Search Epoch
num_epochs = 1000

# Iteartively optimize architecture parameters and network weights every 'search_freq' epochs
search_freq = 10

# The factor to balance performance (CE Loss or MSE Loss) and FLOPs
# Larger flops_balance_factor leads to a faster network with worse performance
flops_balance_factor = 0.1

# The tempreature is decayed by 'temp_anneal_factor' every 'temp_anneal_freq' epochs
# Larger temp leads to a gumbel weight that is more close to 1-hot distribution  
temp = 5
temp_anneal_factor = 0.95
temp_anneal_freq = num_epochs/100 # The temperatur will decay 25 times during the search

warmup = 50

for epoch in range(num_epochs):
    if (epoch % temp_anneal_freq == 0) and(epoch>=warmup):
        temp = temp * temp_anneal_factor
    
    # Forward pass
    outputs, flops_loss = net(x_train, temp)
    CE_loss = criterion(outputs, y_train)
    loss = (1 - flops_balance_factor) * CE_loss + flops_balance_factor * flops_loss
    
    if (int(epoch/search_freq)%2==0) or (epoch<=warmup):
        optimizer_net.zero_grad()
        loss.backward()
        optimizer_net.step()
    else:
        optimizer_arch.zero_grad()
        loss.backward()
        optimizer_arch.step()
    
    selected_channel_id = np.argmax(net.arch_params_hidden.data.numpy())
    selected_channel = hidden_size_choices[selected_channel_id]
    
    selected_layer_id = np.argmax(net.arch_params_layer.data.numpy())
    selected_layer = layer_choices[selected_layer_id]

    with torch.no_grad():
        _, predicted = torch.max(outputs, 1)
        total = y_train.size(0)
        correct = (predicted == y_train).sum().item()
        accuracy = correct / total
        if (epoch+1) % 10 == 0:
            
            print('Epoch [{}/{}], Loss: {:.4f}, CE_Loss: {:.4f}, Flops_Loss: {:.4f}, Accuracy: {:.2f}%, Channel {}, Layer {}'
                  .format(epoch+1, num_epochs, loss.item(), CE_loss.item(),
                          flops_loss.item(), accuracy * 100, selected_channel, selected_layer))

Epoch [10/1000], Loss: 0.7285, CE_Loss: 0.6919, Flops_Loss: 1.0574, Accuracy: 60.20%, Channel 10, Layer 1
Epoch [20/1000], Loss: 0.7186, CE_Loss: 0.6889, Flops_Loss: 0.9864, Accuracy: 60.70%, Channel 10, Layer 1
Epoch [30/1000], Loss: 0.7239, CE_Loss: 0.6818, Flops_Loss: 1.1019, Accuracy: 61.10%, Channel 10, Layer 1
Epoch [40/1000], Loss: 0.7013, CE_Loss: 0.6684, Flops_Loss: 0.9973, Accuracy: 61.60%, Channel 10, Layer 1
Epoch [50/1000], Loss: 0.6771, CE_Loss: 0.6485, Flops_Loss: 0.9343, Accuracy: 62.60%, Channel 10, Layer 1
Epoch [60/1000], Loss: 0.6948, CE_Loss: 0.6444, Flops_Loss: 1.1485, Accuracy: 62.90%, Channel 50, Layer 1
Epoch [70/1000], Loss: 0.6619, CE_Loss: 0.6284, Flops_Loss: 0.9640, Accuracy: 65.50%, Channel 50, Layer 1
Epoch [80/1000], Loss: 0.6574, CE_Loss: 0.6252, Flops_Loss: 0.9474, Accuracy: 65.60%, Channel 50, Layer 1
Epoch [90/1000], Loss: 0.6481, CE_Loss: 0.6053, Flops_Loss: 1.0329, Accuracy: 67.40%, Channel 50, Layer 1
Epoch [100/1000], Loss: 0.6559, CE_Loss: 0.602

Epoch [770/1000], Loss: 0.0508, CE_Loss: 0.0311, Flops_Loss: 0.2280, Accuracy: 100.00%, Channel 500, Layer 1
Epoch [780/1000], Loss: 0.2229, CE_Loss: 0.0002, Flops_Loss: 2.2271, Accuracy: 100.00%, Channel 500, Layer 1
Epoch [790/1000], Loss: 0.0534, CE_Loss: 0.0009, Flops_Loss: 0.5253, Accuracy: 100.00%, Channel 500, Layer 1
Epoch [800/1000], Loss: 0.0631, CE_Loss: 0.0004, Flops_Loss: 0.6270, Accuracy: 100.00%, Channel 500, Layer 1
Epoch [810/1000], Loss: 0.5289, CE_Loss: 0.5843, Flops_Loss: 0.0307, Accuracy: 100.00%, Channel 500, Layer 1
Epoch [820/1000], Loss: 0.0672, CE_Loss: 0.0075, Flops_Loss: 0.6043, Accuracy: 100.00%, Channel 480, Layer 1
Epoch [830/1000], Loss: 0.0684, CE_Loss: 0.0339, Flops_Loss: 0.3787, Accuracy: 100.00%, Channel 480, Layer 1
Epoch [840/1000], Loss: 0.0940, CE_Loss: 0.0000, Flops_Loss: 0.9394, Accuracy: 100.00%, Channel 480, Layer 1
Epoch [850/1000], Loss: 0.1205, CE_Loss: 0.1226, Flops_Loss: 0.1010, Accuracy: 100.00%, Channel 480, Layer 1
Epoch [860/1000], L