In [1]:
from time import time
from pathlib import Path
from itertools import product

import numpy as np
import pandas as pd

from data_utils import (
    get_embeddings, 
    synthesize_database, 
    synthesize_simple_database,
    get_passenger_database,
    check_passenger_exist,
    create_simple_trainset
)


In [2]:
# --- Run configuration -------------------------------------------------------
# `is_simple_data` selects which synthesizer builds the database below.
is_simple_data = True
plan_number = None
latency = 0.0
message = "Elapsed time for query based on privacy preserving is {} seconds"

# --- Airport dataset ---------------------------------------------------------
# The original embeddings are treated as the query pictures.
embed_original, indices = get_embeddings()

if not is_simple_data:
    # Full database: embeddings, ids, dates and locations.
    embed_data, id_data, date_data, location_data = synthesize_database(embed_original)
else:
    # Simple database: no date column is produced, so keep an explicit placeholder.
    embed_data, id_data, location_data = synthesize_simple_database(embed_original)
    date_data = None

# Collapse the collection of per-record embeddings into a single array.
embed_data = np.stack(embed_data)


In [3]:
from global_variables import date_range, frequency_range
import torch
import torch.nn as nn
import torch.optim as optim

# Materialize features and labels as numpy arrays first and report their shapes,
# then wrap them in the tensor types used by the training cells.
features_np = np.array(embed_data)
print(features_np.shape)
labels_np = np.array(id_data)
print(labels_np.shape)

# float32 inputs for the linear layers, int64 targets for the cross-entropy loss.
x_train = torch.FloatTensor(features_np)
y_train = torch.LongTensor(labels_np)

(2222, 128)
(2222,)


In [95]:
def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> torch.Tensor:
    r"""
    Draw a sample from the Gumbel-Softmax distribution, optionally discretized.

    Args:
      logits: `[..., num_features]` unnormalized log probabilities
      tau: scalar temperature; the noise-perturbed logits are divided by it
           before the softmax
      hard: if ``True``, the forward value is a one-hot vector while the
            gradient flows through the soft sample (straight-through)
      eps: accepted for signature compatibility; not used by this implementation
      dim (int): dimension along which the softmax is computed. Default: -1.

    Returns:
      Tensor of the same shape as `logits`. With ``hard=False`` each slice along
      `dim` is a probability distribution summing to 1; with ``hard=True`` each
      slice is one-hot.
    """
    # Exponential(1) draws; the negative log of these is Gumbel(0, 1) noise.
    exp_noise = torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_()
    perturbed = (logits + (-exp_noise.log())) / tau
    soft_sample = perturbed.softmax(dim)

    if not hard:
        # Reparametrization trick: return the relaxed sample directly.
        return soft_sample

    # Straight-through estimator: one-hot forward value, soft-sample gradient.
    winners = soft_sample.max(dim, keepdim=True)[1]
    one_hot = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, winners, 1.0)
    return one_hot - soft_sample.detach() + soft_sample

In [117]:
def get_channel_mask(hidden_size_choices):
    """Build one binary channel mask per candidate hidden size.

    Args:
        hidden_size_choices: sequence of candidate hidden-layer widths.

    Returns:
        Tensor of shape ``(max(hidden_size_choices), len(hidden_size_choices))``.
        Column ``k`` holds ones in its first ``hidden_size_choices[k]`` rows and
        zeros below.
    """
    widest = max(hidden_size_choices)
    masks = torch.zeros(widest, len(hidden_size_choices))
    for column, width in enumerate(hidden_size_choices):
        # Keep the leading `width` channels active for this candidate.
        masks[:width, column] = 1
    return masks
    
def get_flops_choices(input_size, hidden_size_choices, num_classes):
    """Compute the FLOPs figure used for each candidate hidden size.

    For a hidden width ``h`` the value is
    ``2*h*input_size + 2*h*num_classes``.

    Args:
        input_size: width of the network input.
        hidden_size_choices: iterable of candidate hidden-layer widths.
        num_classes: width of the network output.

    Returns:
        numpy array with one entry per candidate, in the same order.
    """
    return np.array([
        2*width*input_size + 2*width*num_classes
        for width in hidden_size_choices
    ])
    
class SuperNet(nn.Module):
    """Two-layer MLP super-network with a searchable hidden width.

    The hidden layer is allocated at the largest candidate width. Each forward
    pass draws Gumbel-Softmax weights over the candidates from ``arch_params``
    and uses them to blend the per-candidate channel masks into one soft mask
    that scales the hidden activations.

    Args:
        input_size: width of the input features.
        hidden_size_choices: sequence of candidate hidden-layer widths.
        num_classes: width of the output layer.

    Attributes:
        arch_params: learnable logits, one per candidate width (initialized to 1).
        masks: buffer of shape ``(max_hidden_size, num_choices)`` from
            ``get_channel_mask``.
        flops_choices: numpy array of per-candidate FLOPs from
            ``get_flops_choices``.
        flops_choices_normalized: buffer holding ``flops_choices`` divided by
            its maximum.
    """

    def __init__(self, input_size, hidden_size_choices, num_classes):
        super(SuperNet, self).__init__()

        max_hidden_size = max(hidden_size_choices)
        num_choices = len(hidden_size_choices)

        self.arch_params = torch.nn.Parameter(torch.ones(num_choices), requires_grad=True)

        self.flops_choices = get_flops_choices(input_size, hidden_size_choices, num_classes)

        # Register the constant tensors as buffers so that `.to(device)`,
        # `.cuda()` and dtype casts move them together with the weights.
        # `persistent=False` keeps them out of `state_dict`, so checkpoints
        # have exactly the same keys as before.
        self.register_buffer("masks", get_channel_mask(hidden_size_choices), persistent=False)
        self.register_buffer(
            "flops_choices_normalized",
            torch.FloatTensor(self.flops_choices / np.max(self.flops_choices)),
            persistent=False,
        )

        self.fc1 = nn.Linear(input_size, max_hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(max_hidden_size, num_classes)

    def forward(self, x, temperature):
        """Run the masked MLP and compute the FLOPs penalty.

        Args:
            x: input batch fed to ``fc1``.
            temperature: Gumbel-Softmax temperature (passed as ``tau``).

        Returns:
            Tuple ``(out, flops_loss)``: the ``fc2`` outputs and the
            Gumbel-weighted sum of normalized FLOPs.
        """
        out = self.fc1(x)
        out = self.relu(out)

        # Soft (hard=False) weights over the candidate widths.
        gumbel_weights = gumbel_softmax(self.arch_params, tau=temperature, hard=False)

        # Weight each candidate's mask column and sum across candidates:
        # channel j ends up scaled by the total weight of candidates wider than j.
        mask = torch.multiply(self.masks, gumbel_weights)
        mask = torch.sum(mask, dim=-1)
        out = torch.multiply(out, mask)

        out = self.fc2(out)
        flops_loss = self._get_flops_loss(gumbel_weights)
        return out, flops_loss

    def _get_flops_loss(self, gumbel_weights):
        """Return the dot product of normalized FLOPs and the candidate weights."""
        return torch.matmul(self.flops_choices_normalized, gumbel_weights)

## flops_balance_factor = 0.5
`flops_balance_factor` is the trade-off weight between task performance (CE_Loss or MSE_Loss) and speed (FLOPs).

Larger flops_balance_factor leads to a faster network with worse performance

In [135]:
import torch
import torch.nn as nn
import torch.optim as optim

# Initialize the neural network
in_dim = x_train.shape[1]
hidden_size_choices = list(range(100, 1000, 10))
# NOTE(review): one output class per training row; presumably every row has a
# unique id in [0, N) — confirm against the data synthesizer.
num_classes = x_train.shape[0]
net = SuperNet(in_dim, hidden_size_choices, num_classes)

# Set up the loss function and optimizers
criterion = nn.CrossEntropyLoss()

net_weight_lr = 0.0001
arch_lr = 0.01

# Network weights and architecture parameters get separate optimizers so they
# can be updated in alternating phases.
optimizer_net = optim.Adam([p for name, p in net.named_parameters() if 'arch' not in name], lr=net_weight_lr)
optimizer_arch = optim.Adam([p for name, p in net.named_parameters() if 'arch' in name], lr=arch_lr)

# Search epochs
num_epochs = 800

# Alternate between optimizing network weights and architecture parameters
# every 'search_freq' epochs
search_freq = 20

# The factor to balance performance (CE Loss or MSE Loss) and FLOPs
# Larger flops_balance_factor leads to a faster network with worse performance
flops_balance_factor = 0.5

# The temperature is decayed by 'temp_anneal_factor' every 'temp_anneal_freq' epochs.
# A smaller temperature makes the Gumbel weights closer to a one-hot distribution.
temp = 5
temp_anneal_factor = 0.95
# Integer period so that `epoch % temp_anneal_freq == 0` also fires when
# num_epochs is not a multiple of 25 (a float period would then only match at
# epoch 0). The temperature decays about 25 times during the search.
temp_anneal_freq = max(1, num_epochs // 25)

for epoch in range(num_epochs):
    if epoch % temp_anneal_freq == 0:
        temp = temp * temp_anneal_factor

    # Forward pass
    outputs, flops_loss = net(x_train, temp)
    CE_loss = criterion(outputs, y_train)
    loss = (1 - flops_balance_factor) * CE_loss + flops_balance_factor * flops_loss

    # Even-numbered blocks of 'search_freq' epochs train the weights,
    # odd-numbered blocks train the architecture parameters.
    if (epoch // search_freq) % 2 == 0:
        optimizer_net.zero_grad()
        loss.backward()
        optimizer_net.step()
    else:
        optimizer_arch.zero_grad()
        loss.backward()
        optimizer_arch.step()

    # Currently preferred width: the candidate with the largest architecture logit.
    selected_channel_id = np.argmax(net.arch_params.detach().cpu().numpy())
    selected_channel = hidden_size_choices[selected_channel_id]

    with torch.no_grad():
        # Training accuracy of the forward pass made before this epoch's update.
        _, predicted = torch.max(outputs, 1)
        total = y_train.size(0)
        correct = (predicted == y_train).sum().item()
        accuracy = correct / total
        if (epoch+1) % 10 == 0:

            print('Epoch [{}/{}], Overall_Loss: {:.4f}, CE_Loss: {:.4f}, Flops_Loss: {:.4f}, Accuracy: {:.2f}%, Seleted Channel {}'
                  .format(epoch+1, num_epochs, loss.item(), CE_loss.item(),
                          flops_loss.item(), accuracy * 100, selected_channel))

Epoch [10/800], Overall_Loss: 4.1291, CE_Loss: 7.6995, Flops_Loss: 0.5588, Accuracy: 0.05%, Seleted Channel 100
Epoch [20/800], Overall_Loss: 4.1269, CE_Loss: 7.6906, Flops_Loss: 0.5632, Accuracy: 0.09%, Seleted Channel 100
Epoch [30/800], Overall_Loss: 4.1156, CE_Loss: 7.6903, Flops_Loss: 0.5409, Accuracy: 0.09%, Seleted Channel 450
Epoch [40/800], Overall_Loss: 4.1162, CE_Loss: 7.6903, Flops_Loss: 0.5422, Accuracy: 0.09%, Seleted Channel 420
Epoch [50/800], Overall_Loss: 4.1128, CE_Loss: 7.6811, Flops_Loss: 0.5445, Accuracy: 0.45%, Seleted Channel 420
Epoch [60/800], Overall_Loss: 4.1062, CE_Loss: 7.6696, Flops_Loss: 0.5428, Accuracy: 0.99%, Seleted Channel 420
Epoch [70/800], Overall_Loss: 4.1052, CE_Loss: 7.6684, Flops_Loss: 0.5420, Accuracy: 1.04%, Seleted Channel 170
Epoch [80/800], Overall_Loss: 4.1035, CE_Loss: 7.6687, Flops_Loss: 0.5384, Accuracy: 0.99%, Seleted Channel 230
Epoch [90/800], Overall_Loss: 4.0931, CE_Loss: 7.6577, Flops_Loss: 0.5285, Accuracy: 2.61%, Seleted Chan

Epoch [740/800], Overall_Loss: 3.5228, CE_Loss: 6.6329, Flops_Loss: 0.4127, Accuracy: 98.42%, Seleted Channel 450
Epoch [750/800], Overall_Loss: 3.5090, CE_Loss: 6.5703, Flops_Loss: 0.4476, Accuracy: 98.65%, Seleted Channel 450
Epoch [760/800], Overall_Loss: 3.4997, CE_Loss: 6.5416, Flops_Loss: 0.4578, Accuracy: 98.60%, Seleted Channel 450
Epoch [770/800], Overall_Loss: 3.4850, CE_Loss: 6.5267, Flops_Loss: 0.4432, Accuracy: 98.74%, Seleted Channel 450
Epoch [780/800], Overall_Loss: 3.4698, CE_Loss: 6.5141, Flops_Loss: 0.4255, Accuracy: 98.83%, Seleted Channel 450
Epoch [790/800], Overall_Loss: 3.4451, CE_Loss: 6.4259, Flops_Loss: 0.4643, Accuracy: 98.96%, Seleted Channel 450
Epoch [800/800], Overall_Loss: 3.4573, CE_Loss: 6.4906, Flops_Loss: 0.4239, Accuracy: 98.56%, Seleted Channel 450


## flops_balance_factor = 0.8
`flops_balance_factor` is the trade-off weight between task performance (CE_Loss or MSE_Loss) and speed (FLOPs).

Larger flops_balance_factor leads to a faster network with worse performance

In [137]:
import torch
import torch.nn as nn
import torch.optim as optim

# Initialize the neural network
in_dim = x_train.shape[1]
hidden_size_choices = list(range(100, 1000, 10))
# NOTE(review): one output class per training row; presumably every row has a
# unique id in [0, N) — confirm against the data synthesizer.
num_classes = x_train.shape[0]
net = SuperNet(in_dim, hidden_size_choices, num_classes)

# Set up the loss function and optimizers
criterion = nn.CrossEntropyLoss()

net_weight_lr = 0.0001
arch_lr = 0.01

# Network weights and architecture parameters get separate optimizers so they
# can be updated in alternating phases.
optimizer_net = optim.Adam([p for name, p in net.named_parameters() if 'arch' not in name], lr=net_weight_lr)
optimizer_arch = optim.Adam([p for name, p in net.named_parameters() if 'arch' in name], lr=arch_lr)

# Search epochs
num_epochs = 1000

# Alternate between optimizing network weights and architecture parameters
# every 'search_freq' epochs
search_freq = 20

# The factor to balance performance (CE Loss or MSE Loss) and FLOPs
# Larger flops_balance_factor leads to a faster network with worse performance
flops_balance_factor = 0.8

# The temperature is decayed by 'temp_anneal_factor' every 'temp_anneal_freq' epochs.
# A smaller temperature makes the Gumbel weights closer to a one-hot distribution.
temp = 5
temp_anneal_factor = 0.95
# Integer period so that `epoch % temp_anneal_freq == 0` also fires when
# num_epochs is not a multiple of 25 (a float period would then only match at
# epoch 0). The temperature decays about 25 times during the search.
temp_anneal_freq = max(1, num_epochs // 25)

for epoch in range(num_epochs):
    if epoch % temp_anneal_freq == 0:
        temp = temp * temp_anneal_factor

    # Forward pass
    outputs, flops_loss = net(x_train, temp)
    CE_loss = criterion(outputs, y_train)
    loss = (1 - flops_balance_factor) * CE_loss + flops_balance_factor * flops_loss

    # Even-numbered blocks of 'search_freq' epochs train the weights,
    # odd-numbered blocks train the architecture parameters.
    if (epoch // search_freq) % 2 == 0:
        optimizer_net.zero_grad()
        loss.backward()
        optimizer_net.step()
    else:
        optimizer_arch.zero_grad()
        loss.backward()
        optimizer_arch.step()

    # Currently preferred width: the candidate with the largest architecture logit.
    selected_channel_id = np.argmax(net.arch_params.detach().cpu().numpy())
    selected_channel = hidden_size_choices[selected_channel_id]

    with torch.no_grad():
        # Training accuracy of the forward pass made before this epoch's update.
        _, predicted = torch.max(outputs, 1)
        total = y_train.size(0)
        correct = (predicted == y_train).sum().item()
        accuracy = correct / total
        if (epoch+1) % 10 == 0:

            print('Epoch [{}/{}], Overall_Loss: {:.4f}, CE_Loss: {:.4f}, Flops_Loss: {:.4f}, Accuracy: {:.2f}%, Seleted Channel {}'
                  .format(epoch+1, num_epochs, loss.item(), CE_loss.item(),
                          flops_loss.item(), accuracy * 100, selected_channel))

Epoch [10/1000], Overall_Loss: 1.9738, CE_Loss: 7.7000, Flops_Loss: 0.5423, Accuracy: 0.14%, Seleted Channel 100
Epoch [20/1000], Overall_Loss: 1.9717, CE_Loss: 7.6919, Flops_Loss: 0.5417, Accuracy: 0.23%, Seleted Channel 100
Epoch [30/1000], Overall_Loss: 1.9773, CE_Loss: 7.6909, Flops_Loss: 0.5489, Accuracy: 0.23%, Seleted Channel 330
Epoch [40/1000], Overall_Loss: 1.9776, CE_Loss: 7.6908, Flops_Loss: 0.5493, Accuracy: 0.23%, Seleted Channel 340
Epoch [50/1000], Overall_Loss: 1.9689, CE_Loss: 7.6825, Flops_Loss: 0.5405, Accuracy: 0.81%, Seleted Channel 340
Epoch [60/1000], Overall_Loss: 1.9890, CE_Loss: 7.6695, Flops_Loss: 0.5689, Accuracy: 2.30%, Seleted Channel 340
Epoch [70/1000], Overall_Loss: 1.9669, CE_Loss: 7.6700, Flops_Loss: 0.5411, Accuracy: 2.52%, Seleted Channel 130
Epoch [80/1000], Overall_Loss: 1.9578, CE_Loss: 7.6708, Flops_Loss: 0.5295, Accuracy: 2.39%, Seleted Channel 210
Epoch [90/1000], Overall_Loss: 1.9552, CE_Loss: 7.6593, Flops_Loss: 0.5292, Accuracy: 3.47%, Sel

Epoch [730/1000], Overall_Loss: 1.6053, CE_Loss: 7.0879, Flops_Loss: 0.2346, Accuracy: 85.96%, Seleted Channel 120
Epoch [740/1000], Overall_Loss: 1.6330, CE_Loss: 6.9114, Flops_Loss: 0.3134, Accuracy: 94.28%, Seleted Channel 120
Epoch [750/1000], Overall_Loss: 1.6100, CE_Loss: 7.0205, Flops_Loss: 0.2574, Accuracy: 93.61%, Seleted Channel 120
Epoch [760/1000], Overall_Loss: 1.5969, CE_Loss: 7.0751, Flops_Loss: 0.2274, Accuracy: 88.34%, Seleted Channel 120
Epoch [770/1000], Overall_Loss: 1.5904, CE_Loss: 7.0998, Flops_Loss: 0.2131, Accuracy: 91.18%, Seleted Channel 120
Epoch [780/1000], Overall_Loss: 1.6023, CE_Loss: 6.9293, Flops_Loss: 0.2705, Accuracy: 93.25%, Seleted Channel 120
Epoch [790/1000], Overall_Loss: 1.5916, CE_Loss: 6.9875, Flops_Loss: 0.2426, Accuracy: 92.53%, Seleted Channel 120
Epoch [800/1000], Overall_Loss: 1.5849, CE_Loss: 7.0581, Flops_Loss: 0.2166, Accuracy: 91.72%, Seleted Channel 120
Epoch [810/1000], Overall_Loss: 1.5911, CE_Loss: 6.9281, Flops_Loss: 0.2569, Acc