In [1]:
# Load data
import numpy as np

from data_utils import (
    get_embeddings, 
    synthesize_database, 
    synthesize_simple_database,
    get_passenger_database,
    check_passenger_exist,
    create_simple_trainset
)

is_simple_data=True
# Airport dataset
embed_original, indices = get_embeddings() # Considered the query pictures
if is_simple_data:
    embed_data, id_data, location_data = synthesize_simple_database(embed_original)
    date_data = None
else:
    embed_data, id_data, date_data, location_data = synthesize_database(embed_original)

embed_data = np.stack(embed_data)

In [27]:
import torch
import torch.nn as nn
import torch.optim as optim

x_train = np.array(embed_data[:1000, :])
y_train = np.array(location_data[:1000])

x_train = torch.FloatTensor(x_train)
print(x_train.shape)

y_train[y_train==2] = 0 # Make sure the class index start from 0 (0,1 in this example)
y_train = torch.LongTensor(y_train)
print(y_train.shape)


torch.Size([1000, 128])
torch.Size([1000])


In [28]:
def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> torch.Tensor:
    r"""
    Args:
      logits: `[..., num_features]` unnormalized log probabilities
      tau: non-negative scalar temperature
      hard: if ``True``, the returned samples will be discretized as one-hot vectors,
            but will be differentiated as if it is the soft sample in autograd
      dim (int): A dimension along which softmax will be computed. Default: -1.

    Returns:
      Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
      If ``hard=True``, the returned samples will be one-hot, otherwise they will
      be probability distributions that sum to 1 across `dim`.
    """
    gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    )  # ~Gumbel(0,1)

    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret

In [29]:
def get_channel_mask(hidden_size_choices):
    max_hidden_size = max(hidden_size_choices)
    num_choices = len(hidden_size_choices)
    masks = torch.zeros(max_hidden_size, num_choices)
    for i in range(num_choices):
        masks[:hidden_size_choices[i], i]=1
    return masks
    
def get_flops_choices(input_size, hidden_size_choices, num_classes):
    flops = []
    for hidden_size in hidden_size_choices:
        flops.append(2*hidden_size*input_size + 2*hidden_size*num_classes)
    flops = np.array(flops)
    return flops
    
class SuperNet(nn.Module):
    def __init__(self, input_size, hidden_size_choices, num_classes):
        super(SuperNet, self).__init__()
        
        max_hidden_size = max(hidden_size_choices)
        num_choices = len(hidden_size_choices)
        
        self.arch_params = torch.nn.Parameter(torch.ones(num_choices), requires_grad=True)
        self.masks = get_channel_mask(hidden_size_choices)
        self.flops_choices = get_flops_choices(input_size, hidden_size_choices, num_classes)
        self.flops_choices_normalized = torch.FloatTensor(self.flops_choices / np.max(self.flops_choices))
        
        self.fc1 = nn.Linear(input_size, max_hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(max_hidden_size, num_classes)
        
    def forward(self, x, temperature):
        out = self.fc1(x)
        out = self.relu(out)
        # print(self.arch_params)
        gumbel_weights = gumbel_softmax(self.arch_params, tau=temperature, hard=False)
        # print(gumbel_weights)
        mask = torch.multiply(self.masks, gumbel_weights)
        mask = torch.sum(mask, dim=-1)
        out = torch.multiply(out, mask)
        
        out = self.fc2(out)
        flops_loss = self._get_flops_loss(gumbel_weights)
        return out, flops_loss
    
    def _get_flops_loss(self, gumbel_weights):
        return torch.matmul(self.flops_choices_normalized, gumbel_weights)

## flops_balance_factor = 0.2
The trade-off parameter between performance (CE_Loss or MSE_Loss) and speed (Flops)

Larger flops_balance_factor leads to a faster network with worse performance

In [32]:
import torch
import torch.nn as nn
import torch.optim as optim

# Initialize the neural network
in_dim = x_train.shape[1]
hidden_size_choices = list(range(100,1000,10))
num_classes =  2

net = SuperNet(in_dim, hidden_size_choices, num_classes)

# Set up the loss function and optimizer
criterion = nn.CrossEntropyLoss()

net_weight_lr = 0.0001
arch_lr = 0.01

optimizer_net = optim.Adam([p for name, p in net.named_parameters() if 'arch' not in name], lr=net_weight_lr)
optimizer_arch = optim.Adam([p for name, p in net.named_parameters() if 'arch' in name], lr=arch_lr)

# Search Epoch
num_epochs = 5000

# Iteartively optimize architecture parameters and network weights every 'search_freq' epochs
search_freq = 20

# The factor to balance performance (CE Loss or MSE Loss) and FLOPs
# Larger flops_balance_factor leads to a faster network with worse performance
flops_balance_factor = 0.2

# The tempreature is decayed by 'temp_anneal_factor' every 'temp_anneal_freq' epochs
# Larger temp leads to a gumbel weight that is more close to 1-hot distribution  
temp = 5
temp_anneal_factor = 0.95
temp_anneal_freq = num_epochs/25 # The temperatur will decay 25 times during the search

for epoch in range(num_epochs):
    if epoch % temp_anneal_freq == 0:
        temp = temp * temp_anneal_factor
    
    # Forward pass
    outputs, flops_loss = net(x_train, temp)
    CE_loss = criterion(outputs, y_train)
    loss = (1 - flops_balance_factor) * CE_loss + flops_balance_factor * flops_loss
    
    if int(epoch/search_freq)%2==0:
        optimizer_net.zero_grad()
        loss.backward()
        optimizer_net.step()
    else:
        optimizer_arch.zero_grad()
        loss.backward()
        optimizer_arch.step()
    
    selected_channel_id = np.argmax(net.arch_params.data.numpy())
    selected_channel = hidden_size_choices[selected_channel_id]

    with torch.no_grad():
        _, predicted = torch.max(outputs, 1)
        total = y_train.size(0)
        correct = (predicted == y_train).sum().item()
        accuracy = correct / total
        if (epoch+1) % 10 == 0:
            
            print('Epoch [{}/{}], Overall_Loss: {:.4f}, CE_Loss: {:.4f}, Flops_Loss: {:.4f}, Accuracy: {:.2f}%, Seleted Channel {}'
                  .format(epoch+1, num_epochs, loss.item(), CE_loss.item(),
                          flops_loss.item(), accuracy * 100, selected_channel))

Epoch [10/5000], Overall_Loss: 0.6631, CE_Loss: 0.6925, Flops_Loss: 0.5455, Accuracy: 53.00%, Seleted Channel 100
Epoch [20/5000], Overall_Loss: 0.6643, CE_Loss: 0.6918, Flops_Loss: 0.5542, Accuracy: 56.30%, Seleted Channel 100
Epoch [30/5000], Overall_Loss: 0.6619, CE_Loss: 0.6917, Flops_Loss: 0.5425, Accuracy: 56.70%, Seleted Channel 200
Epoch [40/5000], Overall_Loss: 0.6596, CE_Loss: 0.6918, Flops_Loss: 0.5311, Accuracy: 56.00%, Seleted Channel 170
Epoch [50/5000], Overall_Loss: 0.6655, CE_Loss: 0.6911, Flops_Loss: 0.5633, Accuracy: 57.20%, Seleted Channel 170
Epoch [60/5000], Overall_Loss: 0.6584, CE_Loss: 0.6905, Flops_Loss: 0.5301, Accuracy: 58.40%, Seleted Channel 170
Epoch [70/5000], Overall_Loss: 0.6591, CE_Loss: 0.6904, Flops_Loss: 0.5336, Accuracy: 58.30%, Seleted Channel 120
Epoch [80/5000], Overall_Loss: 0.6567, CE_Loss: 0.6905, Flops_Loss: 0.5215, Accuracy: 57.90%, Seleted Channel 120
Epoch [90/5000], Overall_Loss: 0.6567, CE_Loss: 0.6898, Flops_Loss: 0.5246, Accuracy: 57

Epoch [730/5000], Overall_Loss: 0.5968, CE_Loss: 0.6479, Flops_Loss: 0.3925, Accuracy: 64.30%, Seleted Channel 160
Epoch [740/5000], Overall_Loss: 0.5926, CE_Loss: 0.6480, Flops_Loss: 0.3713, Accuracy: 64.80%, Seleted Channel 160
Epoch [750/5000], Overall_Loss: 0.5946, CE_Loss: 0.6467, Flops_Loss: 0.3865, Accuracy: 64.80%, Seleted Channel 160
Epoch [760/5000], Overall_Loss: 0.5922, CE_Loss: 0.6479, Flops_Loss: 0.3696, Accuracy: 64.70%, Seleted Channel 160
Epoch [770/5000], Overall_Loss: 0.5941, CE_Loss: 0.6448, Flops_Loss: 0.3914, Accuracy: 64.90%, Seleted Channel 160
Epoch [780/5000], Overall_Loss: 0.5909, CE_Loss: 0.6442, Flops_Loss: 0.3777, Accuracy: 64.90%, Seleted Channel 160
Epoch [790/5000], Overall_Loss: 0.5902, CE_Loss: 0.6443, Flops_Loss: 0.3739, Accuracy: 65.10%, Seleted Channel 160
Epoch [800/5000], Overall_Loss: 0.5914, CE_Loss: 0.6438, Flops_Loss: 0.3819, Accuracy: 65.10%, Seleted Channel 160
Epoch [810/5000], Overall_Loss: 0.5877, CE_Loss: 0.6439, Flops_Loss: 0.3632, Acc

Epoch [1470/5000], Overall_Loss: 0.5365, CE_Loss: 0.6091, Flops_Loss: 0.2463, Accuracy: 70.20%, Seleted Channel 120
Epoch [1480/5000], Overall_Loss: 0.5376, CE_Loss: 0.6059, Flops_Loss: 0.2642, Accuracy: 70.20%, Seleted Channel 120
Epoch [1490/5000], Overall_Loss: 0.5370, CE_Loss: 0.6035, Flops_Loss: 0.2708, Accuracy: 70.40%, Seleted Channel 120
Epoch [1500/5000], Overall_Loss: 0.5340, CE_Loss: 0.6085, Flops_Loss: 0.2359, Accuracy: 70.80%, Seleted Channel 120
Epoch [1510/5000], Overall_Loss: 0.5343, CE_Loss: 0.6047, Flops_Loss: 0.2523, Accuracy: 70.70%, Seleted Channel 120
Epoch [1520/5000], Overall_Loss: 0.5340, CE_Loss: 0.6054, Flops_Loss: 0.2484, Accuracy: 70.50%, Seleted Channel 120
Epoch [1530/5000], Overall_Loss: 0.5331, CE_Loss: 0.6038, Flops_Loss: 0.2501, Accuracy: 71.20%, Seleted Channel 120
Epoch [1540/5000], Overall_Loss: 0.5317, CE_Loss: 0.6037, Flops_Loss: 0.2440, Accuracy: 71.40%, Seleted Channel 120
Epoch [1550/5000], Overall_Loss: 0.5321, CE_Loss: 0.6003, Flops_Loss: 0.

Epoch [2190/5000], Overall_Loss: 0.4894, CE_Loss: 0.5494, Flops_Loss: 0.2495, Accuracy: 76.90%, Seleted Channel 230
Epoch [2200/5000], Overall_Loss: 0.4902, CE_Loss: 0.5545, Flops_Loss: 0.2332, Accuracy: 77.50%, Seleted Channel 230
Epoch [2210/5000], Overall_Loss: 0.4883, CE_Loss: 0.5514, Flops_Loss: 0.2359, Accuracy: 77.40%, Seleted Channel 230
Epoch [2220/5000], Overall_Loss: 0.4866, CE_Loss: 0.5482, Flops_Loss: 0.2402, Accuracy: 77.60%, Seleted Channel 230
Epoch [2230/5000], Overall_Loss: 0.4857, CE_Loss: 0.5475, Flops_Loss: 0.2388, Accuracy: 77.90%, Seleted Channel 230
Epoch [2240/5000], Overall_Loss: 0.4866, CE_Loss: 0.5499, Flops_Loss: 0.2333, Accuracy: 78.10%, Seleted Channel 230
Epoch [2250/5000], Overall_Loss: 0.4853, CE_Loss: 0.5469, Flops_Loss: 0.2391, Accuracy: 78.30%, Seleted Channel 230
Epoch [2260/5000], Overall_Loss: 0.4841, CE_Loss: 0.5466, Flops_Loss: 0.2340, Accuracy: 78.30%, Seleted Channel 230
Epoch [2270/5000], Overall_Loss: 0.4831, CE_Loss: 0.5420, Flops_Loss: 0.

Epoch [2960/5000], Overall_Loss: 0.4229, CE_Loss: 0.4620, Flops_Loss: 0.2665, Accuracy: 87.10%, Seleted Channel 260
Epoch [2970/5000], Overall_Loss: 0.4217, CE_Loss: 0.4602, Flops_Loss: 0.2677, Accuracy: 87.30%, Seleted Channel 260
Epoch [2980/5000], Overall_Loss: 0.4197, CE_Loss: 0.4581, Flops_Loss: 0.2660, Accuracy: 87.30%, Seleted Channel 260
Epoch [2990/5000], Overall_Loss: 0.4197, CE_Loss: 0.4573, Flops_Loss: 0.2693, Accuracy: 87.40%, Seleted Channel 260
Epoch [3000/5000], Overall_Loss: 0.4184, CE_Loss: 0.4543, Flops_Loss: 0.2750, Accuracy: 87.30%, Seleted Channel 260
Epoch [3010/5000], Overall_Loss: 0.4160, CE_Loss: 0.4520, Flops_Loss: 0.2718, Accuracy: 87.60%, Seleted Channel 260
Epoch [3020/5000], Overall_Loss: 0.4146, CE_Loss: 0.4499, Flops_Loss: 0.2736, Accuracy: 87.60%, Seleted Channel 260
Epoch [3030/5000], Overall_Loss: 0.4150, CE_Loss: 0.4504, Flops_Loss: 0.2730, Accuracy: 87.50%, Seleted Channel 260
Epoch [3040/5000], Overall_Loss: 0.4144, CE_Loss: 0.4497, Flops_Loss: 0.

Epoch [3730/5000], Overall_Loss: 0.3490, CE_Loss: 0.3607, Flops_Loss: 0.3022, Accuracy: 94.40%, Seleted Channel 300
Epoch [3740/5000], Overall_Loss: 0.3479, CE_Loss: 0.3601, Flops_Loss: 0.2995, Accuracy: 94.60%, Seleted Channel 300
Epoch [3750/5000], Overall_Loss: 0.3478, CE_Loss: 0.3597, Flops_Loss: 0.3001, Accuracy: 94.60%, Seleted Channel 300
Epoch [3760/5000], Overall_Loss: 0.3479, CE_Loss: 0.3600, Flops_Loss: 0.2997, Accuracy: 94.60%, Seleted Channel 300
Epoch [3770/5000], Overall_Loss: 0.3456, CE_Loss: 0.3564, Flops_Loss: 0.3023, Accuracy: 94.50%, Seleted Channel 300
Epoch [3780/5000], Overall_Loss: 0.3442, CE_Loss: 0.3550, Flops_Loss: 0.3011, Accuracy: 94.80%, Seleted Channel 300
Epoch [3790/5000], Overall_Loss: 0.3441, CE_Loss: 0.3549, Flops_Loss: 0.3008, Accuracy: 94.70%, Seleted Channel 300
Epoch [3800/5000], Overall_Loss: 0.3443, CE_Loss: 0.3553, Flops_Loss: 0.3004, Accuracy: 94.60%, Seleted Channel 300
Epoch [3810/5000], Overall_Loss: 0.3429, CE_Loss: 0.3538, Flops_Loss: 0.

Epoch [4500/5000], Overall_Loss: 0.2895, CE_Loss: 0.2861, Flops_Loss: 0.3030, Accuracy: 97.10%, Seleted Channel 300
Epoch [4510/5000], Overall_Loss: 0.2893, CE_Loss: 0.2859, Flops_Loss: 0.3030, Accuracy: 97.10%, Seleted Channel 300
Epoch [4520/5000], Overall_Loss: 0.2895, CE_Loss: 0.2863, Flops_Loss: 0.3025, Accuracy: 97.00%, Seleted Channel 300
Epoch [4530/5000], Overall_Loss: 0.2883, CE_Loss: 0.2848, Flops_Loss: 0.3025, Accuracy: 97.10%, Seleted Channel 300
Epoch [4540/5000], Overall_Loss: 0.2868, CE_Loss: 0.2828, Flops_Loss: 0.3029, Accuracy: 97.20%, Seleted Channel 300
Epoch [4550/5000], Overall_Loss: 0.2867, CE_Loss: 0.2826, Flops_Loss: 0.3030, Accuracy: 97.20%, Seleted Channel 300
Epoch [4560/5000], Overall_Loss: 0.2869, CE_Loss: 0.2829, Flops_Loss: 0.3026, Accuracy: 97.20%, Seleted Channel 300
Epoch [4570/5000], Overall_Loss: 0.2854, CE_Loss: 0.2811, Flops_Loss: 0.3030, Accuracy: 97.20%, Seleted Channel 300
Epoch [4580/5000], Overall_Loss: 0.2842, CE_Loss: 0.2796, Flops_Loss: 0.

## flops_balance_factor = 0.1
The trade-off parameter between performance (CE_Loss or MSE_Loss) and speed (Flops)

Larger flops_balance_factor leads to a faster network with worse performance

In [34]:
import torch
import torch.nn as nn
import torch.optim as optim

# Initialize the neural network
in_dim = x_train.shape[1]
hidden_size_choices = list(range(100,1000,10))
num_classes =  2
net = SuperNet(in_dim, hidden_size_choices, num_classes)

# Set up the loss function and optimizer
criterion = nn.CrossEntropyLoss()

net_weight_lr = 0.0001
arch_lr = 0.01

optimizer_net = optim.Adam([p for name, p in net.named_parameters() if 'arch' not in name], lr=net_weight_lr)
optimizer_arch = optim.Adam([p for name, p in net.named_parameters() if 'arch' in name], lr=arch_lr)

# Search Epoch
num_epochs = 5000

# Iteartively optimize architecture parameters and network weights every 'search_freq' epochs
search_freq = 20

# The factor to balance performance (CE Loss or MSE Loss) and FLOPs
# Larger flops_balance_factor leads to a faster network with worse performance
flops_balance_factor = 0.1

# The tempreature is decayed by 'temp_anneal_factor' every 'temp_anneal_freq' epochs
# Larger temp leads to a gumbel weight that is more close to 1-hot distribution  
temp = 5
temp_anneal_factor = 0.95
temp_anneal_freq = num_epochs/25 # The temperatur will decay 25 times during the search

for epoch in range(num_epochs):
    if epoch % temp_anneal_freq == 0:
        temp = temp * temp_anneal_factor
    
    # Forward pass
    outputs, flops_loss = net(x_train, temp)
    CE_loss = criterion(outputs, y_train)
    loss = (1 - flops_balance_factor) * CE_loss + flops_balance_factor * flops_loss
    
    if int(epoch/search_freq)%2==0:
        optimizer_net.zero_grad()
        loss.backward()
        optimizer_net.step()
    else:
        optimizer_arch.zero_grad()
        loss.backward()
        optimizer_arch.step()
    
    selected_channel_id = np.argmax(net.arch_params.data.numpy())
    selected_channel = hidden_size_choices[selected_channel_id]

    with torch.no_grad():
        _, predicted = torch.max(outputs, 1)
        total = y_train.size(0)
        correct = (predicted == y_train).sum().item()
        accuracy = correct / total
        if (epoch+1) % 10 == 0:
            
            print('Epoch [{}/{}], Overall_Loss: {:.4f}, CE_Loss: {:.4f}, Flops_Loss: {:.4f}, Accuracy: {:.2f}%, Seleted Channel {}'
                  .format(epoch+1, num_epochs, loss.item(), CE_loss.item(),
                          flops_loss.item(), accuracy * 100, selected_channel))

Epoch [10/5000], Overall_Loss: 0.6783, CE_Loss: 0.6925, Flops_Loss: 0.5503, Accuracy: 52.80%, Seleted Channel 100
Epoch [20/5000], Overall_Loss: 0.6793, CE_Loss: 0.6918, Flops_Loss: 0.5671, Accuracy: 57.20%, Seleted Channel 100
Epoch [30/5000], Overall_Loss: 0.6766, CE_Loss: 0.6918, Flops_Loss: 0.5403, Accuracy: 55.80%, Seleted Channel 210
Epoch [40/5000], Overall_Loss: 0.6769, CE_Loss: 0.6918, Flops_Loss: 0.5428, Accuracy: 55.40%, Seleted Channel 330
Epoch [50/5000], Overall_Loss: 0.6759, CE_Loss: 0.6912, Flops_Loss: 0.5383, Accuracy: 57.50%, Seleted Channel 330
Epoch [60/5000], Overall_Loss: 0.6756, CE_Loss: 0.6904, Flops_Loss: 0.5425, Accuracy: 56.90%, Seleted Channel 330
Epoch [70/5000], Overall_Loss: 0.6749, CE_Loss: 0.6904, Flops_Loss: 0.5357, Accuracy: 57.00%, Seleted Channel 330
Epoch [80/5000], Overall_Loss: 0.6742, CE_Loss: 0.6904, Flops_Loss: 0.5285, Accuracy: 56.90%, Seleted Channel 330
Epoch [90/5000], Overall_Loss: 0.6745, CE_Loss: 0.6897, Flops_Loss: 0.5383, Accuracy: 57

Epoch [780/5000], Overall_Loss: 0.6167, CE_Loss: 0.6419, Flops_Loss: 0.3895, Accuracy: 65.60%, Seleted Channel 200
Epoch [790/5000], Overall_Loss: 0.6169, CE_Loss: 0.6406, Flops_Loss: 0.4035, Accuracy: 65.90%, Seleted Channel 200
Epoch [800/5000], Overall_Loss: 0.6166, CE_Loss: 0.6414, Flops_Loss: 0.3934, Accuracy: 65.90%, Seleted Channel 200
Epoch [810/5000], Overall_Loss: 0.6153, CE_Loss: 0.6399, Flops_Loss: 0.3939, Accuracy: 65.80%, Seleted Channel 200
Epoch [820/5000], Overall_Loss: 0.6135, CE_Loss: 0.6391, Flops_Loss: 0.3835, Accuracy: 66.10%, Seleted Channel 200
Epoch [830/5000], Overall_Loss: 0.6131, CE_Loss: 0.6401, Flops_Loss: 0.3696, Accuracy: 65.80%, Seleted Channel 200
Epoch [840/5000], Overall_Loss: 0.6133, CE_Loss: 0.6390, Flops_Loss: 0.3818, Accuracy: 66.10%, Seleted Channel 200
Epoch [850/5000], Overall_Loss: 0.6119, CE_Loss: 0.6382, Flops_Loss: 0.3751, Accuracy: 66.40%, Seleted Channel 200
Epoch [860/5000], Overall_Loss: 0.6107, CE_Loss: 0.6347, Flops_Loss: 0.3950, Acc

Epoch [1550/5000], Overall_Loss: 0.5456, CE_Loss: 0.5595, Flops_Loss: 0.4209, Accuracy: 76.50%, Seleted Channel 390
Epoch [1560/5000], Overall_Loss: 0.5453, CE_Loss: 0.5585, Flops_Loss: 0.4265, Accuracy: 76.50%, Seleted Channel 390
Epoch [1570/5000], Overall_Loss: 0.5443, CE_Loss: 0.5588, Flops_Loss: 0.4132, Accuracy: 76.30%, Seleted Channel 390
Epoch [1580/5000], Overall_Loss: 0.5416, CE_Loss: 0.5552, Flops_Loss: 0.4193, Accuracy: 76.70%, Seleted Channel 390
Epoch [1590/5000], Overall_Loss: 0.5403, CE_Loss: 0.5528, Flops_Loss: 0.4279, Accuracy: 76.50%, Seleted Channel 390
Epoch [1600/5000], Overall_Loss: 0.5417, CE_Loss: 0.5563, Flops_Loss: 0.4107, Accuracy: 76.60%, Seleted Channel 390
Epoch [1610/5000], Overall_Loss: 0.5392, CE_Loss: 0.5530, Flops_Loss: 0.4148, Accuracy: 76.90%, Seleted Channel 390
Epoch [1620/5000], Overall_Loss: 0.5362, CE_Loss: 0.5490, Flops_Loss: 0.4213, Accuracy: 77.10%, Seleted Channel 390
Epoch [1630/5000], Overall_Loss: 0.5356, CE_Loss: 0.5478, Flops_Loss: 0.

Epoch [2310/5000], Overall_Loss: 0.4219, CE_Loss: 0.4028, Flops_Loss: 0.5942, Accuracy: 91.30%, Seleted Channel 500
Epoch [2320/5000], Overall_Loss: 0.4205, CE_Loss: 0.3999, Flops_Loss: 0.6060, Accuracy: 91.40%, Seleted Channel 500
Epoch [2330/5000], Overall_Loss: 0.4179, CE_Loss: 0.3976, Flops_Loss: 0.6010, Accuracy: 91.70%, Seleted Channel 500
Epoch [2340/5000], Overall_Loss: 0.4131, CE_Loss: 0.3907, Flops_Loss: 0.6146, Accuracy: 91.90%, Seleted Channel 500
Epoch [2350/5000], Overall_Loss: 0.4132, CE_Loss: 0.3909, Flops_Loss: 0.6137, Accuracy: 92.00%, Seleted Channel 500
Epoch [2360/5000], Overall_Loss: 0.4136, CE_Loss: 0.3920, Flops_Loss: 0.6075, Accuracy: 92.00%, Seleted Channel 500
Epoch [2370/5000], Overall_Loss: 0.4111, CE_Loss: 0.3896, Flops_Loss: 0.6038, Accuracy: 92.50%, Seleted Channel 500
Epoch [2380/5000], Overall_Loss: 0.4056, CE_Loss: 0.3816, Flops_Loss: 0.6218, Accuracy: 92.70%, Seleted Channel 500
Epoch [2390/5000], Overall_Loss: 0.4046, CE_Loss: 0.3801, Flops_Loss: 0.

Epoch [3040/5000], Overall_Loss: 0.2729, CE_Loss: 0.2056, Flops_Loss: 0.8785, Accuracy: 98.40%, Seleted Channel 940
Epoch [3050/5000], Overall_Loss: 0.2674, CE_Loss: 0.1963, Flops_Loss: 0.9067, Accuracy: 98.50%, Seleted Channel 940
Epoch [3060/5000], Overall_Loss: 0.2667, CE_Loss: 0.1970, Flops_Loss: 0.8934, Accuracy: 98.50%, Seleted Channel 940
Epoch [3070/5000], Overall_Loss: 0.2685, CE_Loss: 0.2010, Flops_Loss: 0.8761, Accuracy: 98.40%, Seleted Channel 940
Epoch [3080/5000], Overall_Loss: 0.2680, CE_Loss: 0.1998, Flops_Loss: 0.8821, Accuracy: 98.40%, Seleted Channel 940
Epoch [3090/5000], Overall_Loss: 0.2649, CE_Loss: 0.1955, Flops_Loss: 0.8892, Accuracy: 98.40%, Seleted Channel 940
Epoch [3100/5000], Overall_Loss: 0.2611, CE_Loss: 0.1898, Flops_Loss: 0.9022, Accuracy: 98.60%, Seleted Channel 940
Epoch [3110/5000], Overall_Loss: 0.2606, CE_Loss: 0.1893, Flops_Loss: 0.9025, Accuracy: 98.70%, Seleted Channel 940
Epoch [3120/5000], Overall_Loss: 0.2632, CE_Loss: 0.1944, Flops_Loss: 0.

Epoch [3760/5000], Overall_Loss: 0.2026, CE_Loss: 0.1185, Flops_Loss: 0.9587, Accuracy: 99.90%, Seleted Channel 970
Epoch [3770/5000], Overall_Loss: 0.2004, CE_Loss: 0.1149, Flops_Loss: 0.9698, Accuracy: 99.90%, Seleted Channel 970
Epoch [3780/5000], Overall_Loss: 0.2001, CE_Loss: 0.1154, Flops_Loss: 0.9618, Accuracy: 99.90%, Seleted Channel 970
Epoch [3790/5000], Overall_Loss: 0.1993, CE_Loss: 0.1138, Flops_Loss: 0.9684, Accuracy: 99.90%, Seleted Channel 970
Epoch [3800/5000], Overall_Loss: 0.1999, CE_Loss: 0.1152, Flops_Loss: 0.9618, Accuracy: 99.80%, Seleted Channel 970
Epoch [3810/5000], Overall_Loss: 0.1996, CE_Loss: 0.1156, Flops_Loss: 0.9553, Accuracy: 99.90%, Seleted Channel 970
Epoch [3820/5000], Overall_Loss: 0.1983, CE_Loss: 0.1139, Flops_Loss: 0.9575, Accuracy: 99.90%, Seleted Channel 970
Epoch [3830/5000], Overall_Loss: 0.1977, CE_Loss: 0.1128, Flops_Loss: 0.9612, Accuracy: 99.90%, Seleted Channel 970
Epoch [3840/5000], Overall_Loss: 0.1985, CE_Loss: 0.1145, Flops_Loss: 0.

Epoch [4480/5000], Overall_Loss: 0.1687, CE_Loss: 0.0795, Flops_Loss: 0.9719, Accuracy: 99.90%, Seleted Channel 970
Epoch [4490/5000], Overall_Loss: 0.1678, CE_Loss: 0.0784, Flops_Loss: 0.9725, Accuracy: 99.90%, Seleted Channel 970
Epoch [4500/5000], Overall_Loss: 0.1664, CE_Loss: 0.0761, Flops_Loss: 0.9797, Accuracy: 99.90%, Seleted Channel 970
Epoch [4510/5000], Overall_Loss: 0.1664, CE_Loss: 0.0760, Flops_Loss: 0.9794, Accuracy: 99.90%, Seleted Channel 970
Epoch [4520/5000], Overall_Loss: 0.1666, CE_Loss: 0.0766, Flops_Loss: 0.9768, Accuracy: 99.90%, Seleted Channel 970
Epoch [4530/5000], Overall_Loss: 0.1663, CE_Loss: 0.0767, Flops_Loss: 0.9735, Accuracy: 99.90%, Seleted Channel 970
Epoch [4540/5000], Overall_Loss: 0.1651, CE_Loss: 0.0746, Flops_Loss: 0.9793, Accuracy: 99.90%, Seleted Channel 970
Epoch [4550/5000], Overall_Loss: 0.1651, CE_Loss: 0.0748, Flops_Loss: 0.9782, Accuracy: 99.90%, Seleted Channel 970
Epoch [4560/5000], Overall_Loss: 0.1653, CE_Loss: 0.0751, Flops_Loss: 0.