In [1]:
# h5py 안 될 때
#!brew reinstall hdf5
#!export CPATH="/opt/homebrew/include/"
#!export HDF5_DIR=/opt/homebrew/
#!python3 -m pip install h5py

In [2]:
# For Colab
# from google.colab import drive
# drive.mount('/content/drive')
# %cd drive/MyDrive/cuisine-prediction/Hanseul/

In [58]:
import os, csv
import pickle
import math
import time
#from tqdm import tqdm
from tqdm.notebook import tqdm
from copy import deepcopy

import h5py
import numpy as np
import torch
from torch import nn, optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.utils.data import Dataset, Subset, DataLoader
from sklearn.metrics import f1_score, top_k_accuracy_score

In [2]:
# Relative locations: project root, and the directory holding the pickled
# dictionaries and the HDF5 dataset splits loaded in the following cells.
path_root, path_container = '../', './Container/'

In [6]:
def _load_container_pickle(file_name):
    """Unpickle `path_container + file_name`. pickle can execute code: only load trusted local files."""
    with open(path_container + file_name, 'rb') as handle:
        return pickle.load(handle)

# Bidirectional id <-> name lookups for cuisines and ingredients.
id_cuisine_dict = _load_container_pickle('id_cuisine_dict.pickle')
cuisine_id_dict = _load_container_pickle('cuisine_id_dict.pickle')
id_ingredient_dict = _load_container_pickle('id_ingredient_dict.pickle')
ingredient_id_dict = _load_container_pickle('ingredient_id_dict.pickle')

## Loading Datasets

In [10]:
class RecipeDataset(Dataset):
    """
    Recipe dataset backed by an HDF5 file.

    The file must contain 'bin_data' (one binary row per recipe, one column per
    ingredient) and may contain 'label_class' and/or 'label_compl'. Every recipe is
    served both as its binary row and as a fixed-length row of ingredient indices
    right-padded with `padding_idx` (== number of ingredient columns).

    Parameters
    ----------
    data_dir : path of the HDF5 file.
    test : if True, __getitem__ returns only (bin_data, int_data) even if labels exist.

    __getitem__ returns (bin_data, int_data[, label_class][, label_compl]), with each
    label included only when it is present in the file (and test is False).
    """
    def __init__(self, data_dir, test=False):
        self.data_dir = data_dir
        self.test = test
        self.classify, self.complete = False, False
        with h5py.File(data_dir, 'r') as data_file:
            self.bin_data = data_file['bin_data'][:]  # (num_recipes, num_ingredients), e.g. (23547, 6714)
            if 'label_class' in data_file.keys():
                self.classify = True
                self.label_class = data_file['label_class'][:]
            if 'label_compl' in data_file.keys():
                self.complete = True
                self.label_compl = data_file['label_compl'][:]

        self.padding_idx = self.bin_data.shape[1]  # == num_ingredients (e.g. 6714)
        recipe_sizes = self.bin_data.sum(1)
        # int() so np.full accepts it even if bin_data is stored as floats; 0 for an empty file.
        # (Per the original notes: 59 for train, 65 for valid & test.)
        self.max_num_ingredients_per_recipe = int(recipe_sizes.max()) if len(recipe_sizes) else 0

        # Row i holds the ingredient indices of recipe i, right-padded with padding_idx.
        # Shape: (num_recipes, max_num_ingredients_per_recipe)
        self.int_data = np.full((len(self.bin_data), self.max_num_ingredients_per_recipe), self.padding_idx)
        for i, bin_recipe in enumerate(self.bin_data):
            recipe = np.flatnonzero(bin_recipe == 1)
            self.int_data[i, :len(recipe)] = recipe

    def __len__(self):
        return len(self.bin_data)

    def __getitem__(self, idx):
        bin_data = self.bin_data[idx]
        int_data = self.int_data[idx]
        if self.test:
            return bin_data, int_data

        labels = []
        if self.classify:
            labels.append(self.label_class[idx])
        if self.complete:
            labels.append(self.label_compl[idx])
        # Without any label the recipe itself is still returned (previously: None).
        return (bin_data, int_data, *labels)
            

In [12]:
# One dataset per (split, task) file inside path_container; splits whose name
# contains 'test' are opened in test mode (no labels returned).
dataset_name = ['train_class', 'train_compl', 'valid_class', 'valid_compl', 'test_class', 'test_compl']

recipe_datasets = {}
for split_name in dataset_name:
    is_test_split = 'test' in split_name
    recipe_datasets[split_name] = RecipeDataset(os.path.join(path_container, split_name), test=is_test_split)

In [13]:
# Find training recipes with fewer than two ingredients.
train_class_ds = recipe_datasets['train_class']
list_single_ingredient_recipe = []
for recipe_idx in range(len(train_class_ds)):
    bin_vec, _, _ = train_class_ds[recipe_idx]
    if bin_vec.sum() < 2:
        list_single_ingredient_recipe.append(recipe_idx)
count_single_ingredient_recipe = len(list_single_ingredient_recipe)
print(count_single_ingredient_recipe)
print(list_single_ingredient_recipe)

19
[564, 1263, 4074, 4203, 4901, 5277, 5360, 6232, 7835, 10585, 10777, 12476, 13301, 13989, 15951, 17374, 18153, 19039, 20469]


## Model

In [29]:
## Building blocks of Set Transformers ##

class MAB(nn.Module):
    """
    Multihead Attention Block (Set Transformer).

    Queries attend over keys/values with `num_heads` heads. The projected queries are
    added back as a residual, then: optional LayerNorm (ln0), a Linear-ReLU-Linear
    residual branch (dropout applied when dropout > 0), optional LayerNorm (ln1).

    forward(Q, K, mask=None):
        Q: (batch, q_len, dim_Q), K: (batch, k_len, dim_K)
        mask: optional bool tensor (batch, 1, k_len); True marks keys to ignore.
        returns (batch, q_len, dim_V)
    """
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False, dropout=0):
        super().__init__()
        self.dim_V, self.num_heads, self.p = dim_V, num_heads, dropout
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0, self.ln1 = nn.LayerNorm(dim_V), nn.LayerNorm(dim_V)
        self.fc_o = nn.Sequential(nn.Linear(dim_V, dim_V), nn.ReLU(), nn.Linear(dim_V, dim_V))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, Q, K, mask=None):
        q = self.fc_q(Q)
        k = self.fc_k(K)
        v = self.fc_v(K)
        head_dim = self.dim_V // self.num_heads
        batch = q.size(0)

        def to_heads(t):
            # (batch, len, dim_V) -> (num_heads * batch, len, head_dim), head-major order
            return torch.cat(t.split(head_dim, 2), 0)

        q_h, k_h, v_h = to_heads(q), to_heads(k), to_heads(v)

        # Scaled by sqrt(dim_V), the full hidden width (not the per-head width).
        scores = torch.bmm(q_h, k_h.transpose(1, 2)) / math.sqrt(self.dim_V)
        if mask is not None:
            # Repeat the (batch, 1, k_len) mask once per head to match the head-major layout.
            scores.masked_fill_(mask.repeat(self.num_heads, 1, 1), float('-inf'))
        attn = torch.softmax(scores, 2)  # (num_heads * batch, q_len, k_len)

        # Merge heads back: (num_heads * batch, q_len, head_dim) -> (batch, q_len, dim_V)
        out = torch.cat((q_h + torch.bmm(attn, v_h)).split(batch, 0), 2)
        ln0 = getattr(self, 'ln0', None)
        if ln0 is not None:
            out = ln0(out)
        branch = self.fc_o(out)
        if self.p > 0:
            branch = self.dropout(branch)
        out = out + branch
        ln1 = getattr(self, 'ln1', None)
        return out if ln1 is None else ln1(out)

class SAB(nn.Module):
    """Set Attention Block: every set element attends over the (optionally masked) set itself."""
    def __init__(self, dim_in, dim_out, num_heads, ln=False, dropout=0.2):
        super().__init__()
        self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln, dropout=dropout)

    def forward(self, X, mask=None):
        # Queries and keys are the same set X.
        return self.mab(Q=X, K=X, mask=mask)

class ISAB(nn.Module):
    """
    Induced Set Attention Block.

    `num_inds` learned inducing points first attend over the input set (mask applied to
    the set as keys). Then every set element attends over the resulting summaries.
    """
    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False, dropout=0.2):
        super().__init__()
        inducing = nn.Parameter(torch.empty(1, num_inds, dim_out))
        self.I = inducing
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln, dropout=dropout)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln, dropout=dropout)

    def forward(self, X, mask=None):
        batch_inducing = self.I.repeat(X.size(0), 1, 1)  # (batch, num_inds, dim_out)
        summaries = self.mab0(batch_inducing, X, mask=mask)
        # Keys here are the summaries, so no mask is needed.
        return self.mab1(X, summaries)

class PMA(nn.Module):
    """
    Pooling by Multihead Attention: `num_seeds` learned seed vectors attend over the
    (optionally masked) set, giving `num_seeds` outputs regardless of the set size.
    """
    def __init__(self, dim, num_heads, num_seeds, ln=False, dropout=0.2):
        super().__init__()
        seeds = nn.Parameter(torch.empty(1, num_seeds, dim))
        self.S = seeds
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln, dropout=dropout)

    def forward(self, X, mask=None):
        batch_seeds = self.S.repeat(X.size(0), 1, 1)  # (batch, num_seeds, dim)
        return self.mab(batch_seeds, X, mask=mask)

In [30]:
def make_one_hot(x, num_items=None):
    """
    Convert padded index rows (int_data) into multi-hot rows (bin_data).

    Parameters
    ----------
    x : tensor / array-like of ingredient indices, shape (batch, L) or (L,).
        Extra singleton dimensions are squeezed away. The padding index is assumed
        to be the largest index (== num_items).
    num_items : number of real items. When given, the output always has exactly
        `num_items` columns and the padding column `num_items` is dropped.
        When None, the class count is inferred from the batch maximum. This is only
        correct if the batch contains at least one padding entry; otherwise the last
        real item is dropped.

    Returns
    -------
    LongTensor (batch, num_items) with per-item counts (1 for each present ingredient).

    Raises
    ------
    ValueError if x still has more than 2 dimensions after squeezing.
    """
    if isinstance(x, torch.Tensor):
        x = x.long()
    else:
        x = torch.as_tensor(x, dtype=torch.long)
    if x.dim() > 2:
        x = x.squeeze()
        if x.dim() > 2:
            raise ValueError(f'Expected at most 2 non-singleton dims, got shape {tuple(x.shape)}.')
    if x.dim() < 2:
        x = x.unsqueeze(0)
    num_classes = -1 if num_items is None else num_items + 1
    # Sum one-hots over the ingredient axis, then drop the padding column (last).
    return F.one_hot(x, num_classes).sum(1)[:, :-1]

In [55]:
class ResBlock(nn.Module):
    """
    Two (Norm - LeakyReLU - Linear) stages, followed by Dropout when dropout > 0.

    norm: 'bn' -> BatchNorm1d, 'ln' -> LayerNorm, anything else -> no normalization.
    The skip connection (output + input) is applied only when dim_input == dim_output.
    Extra keyword arguments to forward (e.g. mask) are accepted and ignored.
    """
    def __init__(self, dim_input, dim_hidden, dim_output, norm='bn', dropout=0.2):
        super().__init__()
        self.use_skip_conn = dim_input == dim_output
        make_norm = {'bn': nn.BatchNorm1d, 'ln': nn.LayerNorm}.get(norm)
        layers = []
        for d_in, d_out in ((dim_input, dim_hidden), (dim_hidden, dim_output)):
            if make_norm is not None:
                layers.append(make_norm(d_in))
            layers += [nn.LeakyReLU(), nn.Linear(d_in, d_out)]
        if dropout > 0:
            layers.append(nn.Dropout(dropout))
        self.ff = nn.Sequential(*layers)

    def forward(self, x, **kwargs):
        out = self.ff(x)
        return out + x if self.use_skip_conn else out

In [32]:
class Encoder(nn.Module):
    """
    Create a feature vector for a given recipe.

    Recipes are rows of ingredient indices padded with `num_items`. They are embedded,
    passed through a permutation-equivariant stack chosen by `encoder_mode`:
      - 'deep_sets':       num_enc_layers ResBlocks,
      - 'set_transformer': num_enc_layers ISABs,
      - 'fusion':          num_enc_layers//2 ResBlocks followed by num_enc_layers//2 ISABs,
    and pooled into one permutation-invariant vector chosen by `enc_pool_mode`:
      - 'deep_sets':       sum over non-padded ingredients,
      - 'set_transformer': PMA with a single seed.
    Both pooling modes return (batch, 1, dim_hidden).

    Raises ValueError for an unknown encoder_mode / enc_pool_mode.
    """
    def __init__(self, dim_embedding=256,
                 num_items=6714,
                 num_inds=32,      # For ISAB
                 dim_hidden=128,
                 num_heads=4,
                 num_enc_layers=4,
                 ln=True,          # LayerNorm option
                 dropout=0.2,      # Dropout option
                 encoder_mode = 'set_transformer',
                 enc_pool_mode = 'set_transformer',
                ):
        super(Encoder, self).__init__()
        assert num_enc_layers % 2 == 0
        if encoder_mode not in ('deep_sets', 'set_transformer', 'fusion'):
            raise ValueError(f'Unknown encoder_mode: {encoder_mode!r}')
        if enc_pool_mode not in ('deep_sets', 'set_transformer'):
            raise ValueError(f'Unknown enc_pool_mode: {enc_pool_mode!r}')
        self.encoder_mode, self.enc_pool_mode = encoder_mode, enc_pool_mode
        self.padding_idx = num_items
        # Row num_items (== padding_idx, given as -1) is the padding embedding.
        self.embedding = nn.Embedding(num_embeddings=num_items+1, embedding_dim=dim_embedding, padding_idx=-1)
        if encoder_mode == 'deep_sets':
            self.encoder = nn.ModuleList(
                [ResBlock(dim_embedding, dim_hidden, dim_hidden, norm='ln', dropout=dropout)] +
                [ResBlock(dim_hidden, dim_hidden, dim_hidden, norm='ln', dropout=dropout) for _ in range(num_enc_layers-1)])
        elif encoder_mode == 'set_transformer':
            self.encoder = nn.ModuleList(
                [ISAB(dim_embedding, dim_hidden, num_heads, num_inds, ln=ln, dropout=dropout)] +
                [ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln, dropout=dropout) for _ in range(num_enc_layers-1)])
        elif encoder_mode == 'fusion':
            self.encoder = nn.ModuleList(
                [ResBlock(dim_embedding, dim_hidden, dim_hidden, norm='ln', dropout=dropout)] +
                [ResBlock(dim_hidden, dim_hidden, dim_hidden, norm='ln', dropout=dropout) for _ in range(num_enc_layers//2-1)] +
                [ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln, dropout=dropout) for _ in range(num_enc_layers//2)])
        if enc_pool_mode == 'deep_sets':
            def sumpool(x, mask=None, **kwargs):
                # Padded positions generally carry non-zero features after the encoder
                # layers (e.g. Linear biases), so zero them before summing; otherwise the
                # pooled vector would depend on how much padding a recipe has.
                if mask is not None:  # mask: (batch, 1, len), True = padding
                    x = x.masked_fill(mask.transpose(1, 2), 0.)
                return torch.sum(x, 1, keepdim=True)  # (batch, 1, dim_hidden), like PMA
            self.pooling = sumpool
        elif enc_pool_mode == 'set_transformer':
            self.pooling = PMA(dim_hidden, num_heads, 1, ln=ln, dropout=dropout)
        self.out = self.mask = None

    def forward(self, x, mask):
        # x (=recipes): (batch, max_num_ingredient) int_data, e.g. max_num_ingredient=65.
        # mask: (batch, 1, max_num_ingredient), True where x is padding.
        self.out = self.embedding(x)  # (batch, max_num_ingredient, dim_embedding)
        for module in self.encoder:
            self.out = module(self.out, mask=mask)  # (batch, max_num_ingredient, dim_hidden): permutation-equivariant
        return self.pooling(self.out, mask=mask)  # (batch, 1, dim_hidden): permutation-invariant

In [33]:
class Classifier(nn.Module):
    """
    Stack of `num_dec_layers` BatchNorm ResBlocks mapping a recipe feature to class logits.

    All blocks keep width dim_hidden except the last, which outputs dim_output.
    forward(x): x (batch, dim_hidden) -> logits (batch, dim_output); the result is also kept in self.out.
    """
    def __init__(self, dim_hidden=128, dim_output=20, dropout=0.2, num_dec_layers=4):
        super().__init__()
        out_widths = [dim_hidden] * (num_dec_layers - 1) + [dim_output]
        self.classifier = nn.ModuleList(
            [ResBlock(dim_hidden, dim_hidden, width, norm='bn', dropout=dropout) for width in out_widths])

    def forward(self, x):
        assert x.ndim == 2
        hidden = x
        for block in self.classifier:
            hidden = block(hidden)
        self.out = hidden
        return hidden

In [34]:
class Completer(nn.Module):
    """
    Score every item as the missing ingredient of a recipe.

    Modes:
      - 'simple':           feed-forward from the recipe feature straight to num_items logits.
      - 'concat':           NCF-style. Each item embedding is encoded, concatenated with the
                            recipe feature, and scored by an MLP.
      - 'concat_attention': the recipe feature and each encoded item embedding form a 2-element
                            set processed by SAB layers, summed, and scored by an MLP.
      - 'attention':        not implemented; forward raises NotImplementedError.

    forward(x, embedding_weight):
        x: recipe feature, (batch, 1, dim_hidden) or (batch, dim_hidden)
        embedding_weight: (num_items, dim_embedding), unused in 'simple' mode
        returns logits (batch, num_items)
    """
    def __init__(self, dim_embedding=256,
                 num_items=6714,
                 #num_inds=32,      # For ISAB
                 dim_hidden=128,
                 num_heads=4,
                 num_dec_layers=4,
                 ln=True,          # LayerNorm option
                 dropout=0.2,      # Dropout option
                 mode = 'simple',
                ):
        super(Completer, self).__init__()

        assert mode in ['simple','concat','concat_attention','attention']

        self.num_items = num_items
        self.mode = mode
        # feedforward layer to process recipe representation
        self.ff = nn.Sequential(
                ResBlock(dim_hidden, dim_hidden, dim_hidden, norm='bn', dropout=dropout),
                ResBlock(dim_hidden, dim_hidden, dim_hidden, norm='bn', dropout=dropout))
        # 'simple': no need of embedding weight
        if mode == 'simple':
            self.decoder = nn.ModuleList(
                [ResBlock(dim_hidden, dim_hidden, dim_hidden, norm='bn', dropout=dropout) for _ in range(num_dec_layers-3)]
                +[ResBlock(dim_hidden, dim_hidden, num_items, norm='bn', dropout=dropout)])
        # NCF style completer
        elif mode == 'concat':
            self.emb_encoder = nn.Sequential(
                ResBlock(dim_embedding, dim_hidden, dim_hidden, norm='ln', dropout=dropout),
                ResBlock(dim_hidden, dim_hidden, dim_hidden, norm='ln', dropout=dropout))
            # decoding feedforward layer to deal with a concatenated feature (dim=2*dim_hidden)
            self.decoder = nn.ModuleList(
                [ResBlock(2*dim_hidden, dim_hidden, dim_hidden//2, norm='ln', dropout=dropout)]
                +[ResBlock(dim_hidden//2, dim_hidden//2, dim_hidden//2, norm='ln', dropout=dropout) for _ in range(num_dec_layers-4)]
                +[ResBlock(dim_hidden//2, dim_hidden//2, 1, norm='ln', dropout=dropout)])
        # completer based on concat + attention
        elif mode == 'concat_attention':
            self.emb_encoder = nn.Sequential(
                ResBlock(dim_embedding, dim_hidden, dim_hidden, norm='ln', dropout=dropout),
                ResBlock(dim_hidden, dim_hidden, dim_hidden, norm='ln', dropout=dropout))
            self.new_set_encoder = nn.ModuleList(
                [SAB(dim_hidden, dim_hidden, num_heads, ln=ln, dropout=dropout) for _ in range(num_dec_layers//2-1)])
            self.decoder= nn.ModuleList(
                [ResBlock(dim_hidden, dim_hidden, dim_hidden, norm='ln', dropout=dropout) for _ in range(num_dec_layers-num_dec_layers//2-2)]
                +[ResBlock(dim_hidden, dim_hidden, 1, norm='ln', dropout=dropout)])
        # completer based on attention: not implemented yet (forward raises)
        elif mode == 'attention':
            pass

        self.out = self.emb_feature = None

    def forward(self, x, embedding_weight):
        # x: (batch, 1, dim_hidden) / embedding_weight: (num_items, dim_embedding)
        if self.mode == 'attention':
            # Previously this silently returned None, which only failed later in the loss.
            raise NotImplementedError("Completer mode 'attention' is not implemented yet.")

        self.out = self.ff(x.squeeze(1))  # (batch, dim_hidden)

        if self.mode == 'simple':
            for module in self.decoder:
                self.out = module(self.out)
            return self.out  # (batch, num_items)

        batch_size, num_items = x.size(0), embedding_weight.size(0)
        self.emb_feature = self.emb_encoder(embedding_weight)  # (num_items, dim_hidden)
        if self.mode == 'concat':
            self.out = torch.cat([self.out.unsqueeze(1).expand(-1, num_items, -1),
                self.emb_feature.unsqueeze(0).expand(batch_size, -1, -1)], dim=2)  # (batch, num_items, 2*dim_hidden)
        else:  # 'concat_attention'
            # Pair the recipe feature with every item feature as a 2-element set.
            self.out = torch.cat([self.out.view(batch_size, 1, 1, -1).expand(-1, num_items, -1, -1),
                self.emb_feature.view(1, num_items, 1, -1).expand(batch_size, -1, -1, -1)],
                dim=2).view(batch_size*num_items, 2, -1)  # (batch*num_items, 2, dim_hidden)
            for module in self.new_set_encoder:
                self.out = module(self.out)  # (batch*num_items, 2, dim_hidden)
            self.out = self.out.sum(1).view(batch_size, num_items, -1)  # (batch, num_items, dim_hidden)
        for module in self.decoder:
            self.out = module(self.out)  # (batch, num_items, 1) at the end
        return self.out.squeeze(-1)  # (batch, num_items)

In [35]:
class CCNet(nn.Module):
    """
    Joint recipe classification (cuisine) and completion (missing ingredient) network.

    A shared Encoder turns padded ingredient indices into a recipe feature. An optional
    Classifier head produces cuisine logits. An optional Completer head produces
    per-item logits, using the encoder's embedding table (without the padding row).

    forward(x, bin_x=None) returns (logit_classification, logit_completion); an entry is
    None when its head is disabled (both are None when both heads are disabled).
    """
    def __init__(self, dim_embedding=256,
                 dim_output=20,
                 num_items=6714,
                 num_inds=32,
                 dim_hidden=128,
                 num_heads=4,
                 num_enc_layers=4,
                 num_dec_layers=2,
                 ln=True,          # LayerNorm option
                 dropout=0.5,      # Dropout option
                 classify=True,    # set False to do completion only
                 complete=True,    # set False to do classification only
                 freeze_classify=False, # freeze the classification-only parameters
                 freeze_complete=False,  # freeze the completion-only parameters
                 encoder_mode = 'set_transformer',
                 enc_pool_mode = 'set_transformer',
                 decoder_mode = 'simple',
                 ):
        super(CCNet, self).__init__()
        self.padding_idx = num_items
        self.classify, self.complete = classify, complete

        self.encoder = Encoder(dim_embedding=dim_embedding,
                               num_items=num_items,
                               num_inds=num_inds,
                               dim_hidden=dim_hidden,
                               num_heads=num_heads,
                               num_enc_layers=num_enc_layers,
                               ln=ln, dropout=dropout,
                               encoder_mode=encoder_mode,
                               enc_pool_mode=enc_pool_mode)
        if classify:
            # NOTE(review): num_dec_layers is not forwarded here, so the Classifier uses
            # its own default depth — confirm this is intended.
            self.classifier = Classifier(dim_hidden=dim_hidden,
                                         dim_output=dim_output,
                                         dropout=dropout)
            if freeze_classify:
                for p in self.classifier.parameters():
                    p.requires_grad = False
        if complete:
            self.completer = Completer(dim_embedding=dim_embedding,
                                       num_items=num_items,
                                       #num_inds=num_inds,
                                       dim_hidden=dim_hidden,
                                       num_heads=num_heads,
                                       num_dec_layers=num_dec_layers,
                                       ln=ln, dropout=dropout,
                                       mode = decoder_mode)
            if freeze_complete:
                for p in self.completer.parameters():
                    p.requires_grad = False

    def forward(self, x, bin_x=None):
        # x (=recipes): (batch, max_num_ingredients) int_data, padded with padding_idx.
        # bin_x: unused; accepted for call compatibility.
        logit_classification, logit_completion = None, None
        if not (self.classify or self.complete):
            # Keep the return type a 2-tuple so callers can always unpack it.
            return logit_classification, logit_completion
        self.mask = (x == self.padding_idx).unsqueeze(1)  # (batch, 1, max_num_ingredients), True = padding
        recipe_feature = self.encoder(x, self.mask)  # (batch, 1, dim_hidden)

        # Classification:
        if self.classify:
            logit_classification = self.classifier(recipe_feature.squeeze(1))  # (batch, dim_output)

        # Completion:
        if self.complete:
            embedding_weight = self.encoder.embedding.weight[:-1]  # (num_items, dim_embedding), padding row dropped
            logit_completion = self.completer(recipe_feature, embedding_weight)  # (batch, num_items)

        return logit_classification, logit_completion

## Loss functions

In [36]:
class LogitSelector(nn.Module):
    """
    For each 3-tuple (output vector, label, data), choose 'important' model outputs.
    Specifically, choose a fixed number (e.g. rank = 100) of outputs consisting of
    1) the output for the label index (= missing ingredient index),
    2) outputs for data indices (= given ingredient indices) -- optional, enabled by 'contain_data',
    3) and the highest remaining outputs.

    x: (batch, max_ingredient_num = 65 or 59), LongTensor, padded with num_items
    output: (batch, num_items = 6714), FloatTensor
    labels: (batch, ), LongTensor
    rank: int

    Returns (new_output, new_labels): new_output (batch, rank) gathered logits, and
    new_labels (batch,) giving the column in new_output that holds each true label.

    Raises ValueError if rank > num_items.
    """
    def __init__(self, rank=100, contain_data=False):
        super(LogitSelector, self).__init__()
        self.rank = rank
        self.contain_data = contain_data

    def forward(self, output, labels, x=None):
        num_items = output.size(1)
        if self.rank > num_items:
            raise ValueError(f'rank ({self.rank}) must not exceed the number of items ({num_items}).')
        # Top-`rank` item indices per row, in ascending order of score.
        target_indices = output.argsort(1)[:, -self.rank:]  # (batch, rank)
        is_label = target_indices == labels.view(-1, 1)  # True where the label already appears
        has_label = is_label.any(dim=1)

        # Column of the label in each row; rows without it get the label put at column 0.
        label_pos = torch.zeros_like(labels, dtype=torch.long)
        rows, cols = is_label.nonzero(as_tuple=True)
        label_pos[rows] = cols
        target_indices[~has_label, 0] = labels[~has_label]

        if self.contain_data and x is not None:
            # Left-pad by 1 so column 0 never receives a data index.
            x_extended = F.pad(x, (1, self.rank - 1 - x.size(1)), 'constant', num_items)  # (batch, rank)
            is_data = x_extended != num_items
            target_indices[is_data] = x_extended[is_data]
            # Data may have overwritten the label's column; move such labels to column 0.
            lost = target_indices.gather(1, label_pos.view(-1, 1)).squeeze(1) != labels
            target_indices[lost, 0] = labels[lost]
            label_pos[lost] = 0

        new_output = torch.gather(output, 1, target_indices)  # (batch, rank)
        return new_output, label_pos

In [37]:
class MultiClassASLoss(nn.Module):
    '''
    MultiClass ASL (single label) + soft F1 loss.

    Returns `ASL + (1 - F1)`, where F1 is a soft F1 from softmax probabilities
    ('macro': per class then averaged, 'micro': pooled counts). `eps` is used both
    for label smoothing and to stabilise/clamp the F1 ratio.

    reduction: 'mean' -> scalar (batch mean), 'sum' -> scalar (batch sum),
               anything else -> per-sample vector (batch,) with the F1 term broadcast-added.

    References:
    - ASL paper: https://arxiv.org/abs/2009.14119
    - optimized ASL: https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py
    '''
    def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean', average='macro'):
        super(MultiClassASLoss, self).__init__()

        self.eps = eps
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.targets_classes = []
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.reduction = reduction
        self.average = average

    def forward(self, inputs, target):
        '''
        "inputs" dimensions: - (batch_size, number_classes), logits
        "target" dimensions: - (batch_size), class indices
        '''
        log_preds = self.logsoftmax(inputs)
        self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)  # one-hot labels

        targets = self.targets_classes
        anti_targets = 1. - targets
        xs_pos = torch.exp(log_preds)
        xs_neg = 1. - xs_pos

        # Soft TP / FP / FN per class
        tp = (xs_pos * targets).sum(dim=0)
        fp = (xs_pos * anti_targets).sum(dim=0)
        fn = (xs_neg * targets).sum(dim=0)

        if self.average == 'micro':
            tp = tp.sum()
            fp = fp.sum()
            fn = fn.sum()

        # F1 score
        f1 = (tp / (tp + 0.5*(fp + fn) + self.eps)).clamp(min=self.eps, max=1.-self.eps).mean()

        # ASL weights: (1-p)^gamma_pos for the target class, p^gamma_neg for the others
        xs_pos = xs_pos * targets
        xs_neg = xs_neg * anti_targets
        if self.gamma_pos > 0 or self.gamma_neg > 0:
            asymmetric_w = torch.pow(1. - xs_pos - xs_neg,
                                     self.gamma_pos * targets + self.gamma_neg * anti_targets)
            log_preds = log_preds * asymmetric_w

        if self.eps > 0:  # label smoothing
            num_classes = inputs.size()[-1]
            self.targets_classes = self.targets_classes.mul(1. - self.eps).add(self.eps / num_classes)

        # loss calculation
        loss = - self.targets_classes.mul(log_preds)

        loss = loss.sum(dim=-1)  # (batch,)
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss + (1. - f1)

In [38]:
class MultiClassFocalLoss(nn.Module):
    '''
    MultiClass soft-F1 loss + focal loss.
    The original implementation was written by Michal Haltuf on Kaggle.

    Returns `focal + (1 - F1)`:
      - focal: per-sample `(1 - p_t)^gamma * CE`, reduced by `reduction`
        ('mean' or 'sum'; any other value keeps the per-sample vector).
      - F1: soft F1 from softmax probabilities ('macro': per class then averaged,
        'micro': pooled counts), clamped to [eps, 1 - eps].

    Reference
    ---------
    - https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric
    - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score
    - https://discuss.pytorch.org/t/calculating-precision-recall-and-f1-score-in-case-of-multi-label-classification/28265/6
    - http://www.ryanzhang.info/python/writing-your-own-loss-function-module-for-pytorch/
    - https://gist.github.com/SuperShinyEyes/dcc68a08ff8b615442e3bc6a9b55a354
    '''
    def __init__(self, eps=1e-8, average='macro', reduction='mean', gamma=2):
        super().__init__()
        self.eps = eps
        self.average = average
        self.reduction = reduction
        self.gamma = gamma

    def forward(self, pred, target):
        # pred: (batch, num_classes) logits; target: (batch,) class indices.
        # Focal loss: the modulating factor must use each sample's own p_t, so the
        # cross-entropy is kept per sample and only reduced afterwards.
        ce = F.cross_entropy(pred, target, reduction='none')  # (batch,)
        pt = torch.exp(-ce)
        focal = (1 - pt) ** self.gamma * ce if self.gamma > 0 else ce
        if self.reduction == 'mean':
            focal = focal.mean()
        elif self.reduction == 'sum':
            focal = focal.sum()

        # f1 loss
        target = F.one_hot(target, pred.size(-1)).float()
        pred = F.softmax(pred, dim=1)

        tp = (target * pred).sum(dim=0).float()
        fp = ((1 - target) * pred).sum(dim=0).float()
        fn = (target * (1 - pred)).sum(dim=0).float()

        if self.average == 'micro':
            tp, fp, fn = tp.sum(), fp.sum(), fn.sum()

        f1 = (tp / (tp + 0.5*(fp + fn) + self.eps)).clamp(min=self.eps, max=1-self.eps).mean()

        return 1. - f1 + focal

In [39]:
class MultiLabelASLoss(nn.Module):
    '''
    MultiLabel ASL loss + soft F1 loss.

    Returns `-(asymmetric log-likelihood averaged over the batch, summed over classes)
    + (1 - mean F1)`. F1 is a soft F1 from sigmoid probabilities ('macro': per class,
    'micro': pooled counts), clamped to [eps, 1 - eps]. Negatives' probabilities are
    shifted by `clip` (asymmetric clipping) before the log.

    If `disable_torch_grad_focal_loss` is True, the focusing weights are computed
    without gradient. The caller's grad mode is restored afterwards.

    References:
    - ASL paper: https://arxiv.org/abs/2009.14119
    - optimized ASL: https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py
    '''
    def __init__(self, gamma_pos=1, gamma_neg=4, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False, average='macro'):
        super(MultiLabelASLoss, self).__init__()

        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps
        self.average = average

        # prevent memory allocation and gpu uploading every iteration, and encourages inplace operations
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None
        self.tp = self.fp = self.fn = self.f1 = None

    def forward(self, x, y):
        """"
        Parameters
        ----------
        x: input logits, (batch, num_classes)
        y: targets (multi-label binarized vector), same shape as x
        """

        self.targets = y.float()
        self.anti_targets = 1 - y.float()

        # Calculating Probabilities
        self.xs_pos = torch.sigmoid(x.float())
        self.xs_neg = 1.0 - self.xs_pos

        # Soft TP/FP/FN per class
        self.tp = (self.xs_pos * self.targets).sum(dim=0)
        self.fp = (self.xs_pos * self.anti_targets).sum(dim=0)
        self.fn = (self.xs_neg * self.targets).sum(dim=0)

        if self.average == 'micro':
            self.tp = self.tp.sum()
            self.fp = self.fp.sum()
            self.fn = self.fn.sum()

        # F1 score
        self.f1 = (self.tp / (self.tp + 0.5*(self.fp + self.fn) + self.eps)).clamp(
            min=self.eps, max=1-self.eps)

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            self.xs_neg.add_(self.clip)
            self.xs_neg.clamp_(max=1.)

        # Basic CE calculation
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            # torch.set_grad_enabled used as a context manager restores the previous grad
            # mode on exit, so calling this loss inside torch.no_grad() keeps grads disabled.
            grad_mode = torch.is_grad_enabled() and not self.disable_torch_grad_focal_loss
            with torch.set_grad_enabled(grad_mode):
                self.xs_pos = self.xs_pos * self.targets
                self.xs_neg = self.xs_neg * self.anti_targets
                self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
                                              self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
            self.loss *= self.asymmetric_w

        return -self.loss.mean(dim=0).sum() + (1.-self.f1.mean())

In [40]:
class MultiLabelBCELoss(nn.Module):
    """
    MultiLabel soft-F1 loss + BCEWithLogitsLoss.

    Returns `BCEWithLogits(pred, target) + (1 - mean F1)`. F1 is a soft F1 from softmax
    probabilities over dim 1 ('macro': per column, 'micro': pooled counts), clamped to
    [eps, 1 - eps]. `weight` and `gamma` are stored but currently unused.

    Raises ValueError if average is not 'macro' or 'micro'.
    """
    def __init__(self, eps=1e-8, average='macro', reduction='mean', weight=None, gamma=2):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss(reduction=reduction)
        self.eps = eps
        self.average = average
        self.reduction = reduction
        self.weight = weight
        self.gamma = gamma
        if average not in ['macro', 'micro']:
            raise ValueError('average should be macro or micro.')

    def forward(self, pred, target):  # same dimension
        # Cast once so bool/int targets also work in the F1 arithmetic (1 - target).
        target = target.float()
        bce_loss = self.bce(pred.float(), target)

        # f1 loss
        pred = F.softmax(pred, dim=1)

        tp = (target * pred).sum(dim=0).float()
        fp = ((1 - target) * pred).sum(dim=0).float()
        fn = (target * (1 - pred)).sum(dim=0).float()

        if self.average == 'micro':
            tp, fp, fn = tp.sum(), fp.sum(), fn.sum()

        # Clamp the F1 ratio itself (previously the clamp hit only the denominator,
        # leaving F1 unbounded).
        f1 = (tp / (tp + 0.5*(fp + fn) + self.eps)).clamp(min=self.eps, max=1-self.eps)

        return 1 - f1.mean() + bce_loss

## Training function

In [66]:
def train(model,
          dataloaders,
          criterion,
          optimizer,
          scheduler,
          dataset_sizes,
          device='cpu',
          num_epochs=20,
          #wandb_log=False,
          early_stop_patience=None,
          classify=True,
          complete=True,
          random_seed=1):
    """Train ``model`` on classification and/or completion and return the best weights.

    Each epoch runs:
    - a 'train' phase;
    - 'valid_class', only if ``classify``;
    - 'valid_compl', only if ``complete``.

    After the validation phases the scheduler is stepped once and the
    best-so-far weights are kept in memory.

    Args:
        model: called as ``model(int_inputs, bin_x=bin_inputs)``; must return a
            pair ``(outputs_class, outputs_compl)`` of logits.
        dataloaders: dict of DataLoaders.
            - The training phase reads 'train_class' when
              ``classify and not complete``, otherwise 'train_compl'.
            - Training batches are ``(bin_inputs, int_inputs, label_class,
              label_compl)`` when ``complete``, else ``(bin_inputs, int_inputs,
              label_class)``.
            - 'valid_class' yields ``(bin_inputs, int_inputs, label_class)``.
            - 'valid_compl' yields ``(bin_inputs, int_inputs, label_compl)``.
        criterion: called as ``criterion(logits, labels.long())`` for each task.
        optimizer: zeroed every batch and stepped once per training batch.
        scheduler: stepped once per epoch with a metric argument (the negated
            validation micro-F1 or completion accuracy). It must accept one,
            e.g. ``ReduceLROnPlateau``.
        dataset_sizes: dict with the same keys as ``dataloaders``. Used to turn
            running sums into per-sample averages.
        device: device each batch is moved to.
        num_epochs: maximum number of epochs.
        early_stop_patience: int or None. If an int, training stops once more
            than this many consecutive epochs fail to improve the selection
            metric. ``None`` disables early stopping.
        classify: train/evaluate the classification task.
        complete: train/evaluate the completion task.
        random_seed: passed to ``np.random.seed`` and ``torch.random.manual_seed``.

    Returns:
        ``(model, best)``:
        - ``model``: the model with the best epoch's weights loaded.
        - ``best``: dict of validation metrics at that epoch ('loss', plus
          'F1micro'/'F1macro'/'top5cls' and/or 'acc'/'top10cmp'). 'bestEpoch'
          is added the first time an epoch improves.

    The selection metric is validation micro-F1 when ``classify and not
    complete``, otherwise validation completion accuracy.

    Raises:
        TypeError: if ``early_stop_patience`` is not None and not an int.
    """

    # Append a detached CPU numpy copy of `new_v` to the running array (or start one).
    def _concatenate(running_v, new_v):
        if running_v is not None:
            return np.concatenate((running_v, new_v.clone().detach().cpu().numpy()), axis=0)
        else:
            return new_v.clone().detach().cpu().numpy()
    
    np.random.seed(random_seed)
    torch.random.manual_seed(random_seed)
    
    since = time.time()
    
    # Logit selector
    #logit_selector = LogitSelector(rank=100).to(device)
    
    # BEST MODEL SAVING: metrics start at sentinels so the first real epoch wins.
    best = {'loss': float('inf')}
    if classify:
        best['F1micro'] = -1.
        best['F1macro'] = -1.
        best['top5cls'] = -1.
    if complete:
        best['acc'] = -1.
        best['top10cmp'] = -1.
    best_model_wts = deepcopy(model.state_dict())

    if early_stop_patience is not None:
        if not isinstance(early_stop_patience, int):
            raise TypeError('early_stop_patience should be an integer.')
        patience_cnt = 0
    
    print('-' * 5 + 'Training the model' + '-' * 5)
    for epoch in tqdm(range(1,num_epochs+1)):
        print(f'\nEpoch {epoch}/{num_epochs}')

        val_loss = 0. # sum of classification and completion loss

        # Each epoch has a training phase and two validation phases
        for phase in ['train', 'valid_class', 'valid_compl']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                # Skip the validation phase of a task that is not being trained.
                if not classify and phase == 'valid_class':
                    continue
                elif not complete and phase == 'valid_compl':
                    continue
                model.eval()   # Set model to evaluate mode

            # Per-phase accumulators (sums over samples; normalised after the loop).
            # running_corrects_class and running_corrects_compl... only the latter is used.
            running_loss_class = 0.
            running_corrects_class = 0.
            running_labels_class = None
            running_preds_class = None
            running_top_k_class = 0.
            
            running_loss_compl = 0.
            running_corrects_compl = 0.
            running_top_k_compl = 0.
            
            # Training uses a task-specific loader; validation loaders are named after the phase.
            dataset_name = phase
            if phase == 'train':
                dataset_name = 'train_class' if classify and not complete else 'train_compl'
            
            # Iterate over data.
            for idx, loaded_data in enumerate(dataloaders[dataset_name]):
                # Batch layout depends on the loader (see docstring).
                if phase == 'train':
                    if complete:
                        bin_inputs, int_inputs, label_class, label_compl = loaded_data
                    else:
                        bin_inputs, int_inputs, label_class = loaded_data
                elif phase == 'valid_class':
                        bin_inputs, int_inputs, label_class = loaded_data
                elif phase == 'valid_compl':
                        bin_inputs, int_inputs, label_compl = loaded_data
                
                # num_items is unpacked but not used below.
                batch_size, num_items = bin_inputs.size()
                if classify and phase in ['train', 'valid_class']:
                    labels_class = label_class.to(device)
                if complete and phase in ['train', 'valid_compl']:
                    labels_compl = label_compl.to(device)
                bin_inputs = bin_inputs.to(device)
                int_inputs = int_inputs.to(device)
                    
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs_class, outputs_compl = model(int_inputs, bin_x=bin_inputs)
                    # The logit_selector path is disabled, so these stay None and the
                    # plain completion loss below is always used.
                    new_outputs_compl, new_labels_compl = None, None
                    if classify and phase in ['train', 'valid_class']:
                        _, preds_class = torch.max(outputs_class, 1)
                    if complete and phase in ['train', 'valid_compl']:
                        _, preds_compl = torch.max(outputs_compl, 1)
                        #new_outputs_compl, new_labels_compl = logit_selector(outputs_compl, labels_compl)
                        
                    # Disabled debug printing, kept as a bare string literal.
                    """if idx == 0:
                        if classify and phase in ['train', 'valid_class']:
                            print('labels_classification', labels_class.cpu().numpy())
                            print('preds_classification', preds_class.cpu().numpy())
                        if complete and phase in ['train', 'valid_compl']:
                            if new_labels_compl is not None:
                                print('new label', new_labels_compl.cpu().numpy())
                            print('labels_completion', labels_compl.cpu().numpy())
                            print('preds_completion', preds_compl.cpu().numpy())"""
                    
                    if classify and phase in ['train', 'valid_class']:
                        loss_class = criterion(outputs_class, labels_class.long())
                    if complete and phase in ['train', 'valid_compl']:
                        if new_outputs_compl is None:
                            loss_compl = criterion(outputs_compl, labels_compl.long())
                        else:
                            loss_compl = criterion(new_outputs_compl, new_labels_compl)

                    # Joint training sums both task losses.
                    # NOTE(review): if classify and complete are both False, `loss` is never
                    # assigned and loss.backward() below raises NameError.
                    if classify and complete and phase == 'train':
                        loss = loss_class + loss_compl
                    elif classify and phase in ['train', 'valid_class']:
                        loss = loss_class
                    elif complete and phase in ['train', 'valid_compl']:
                        loss = loss_compl

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        #torch.nn.utils.clip_grad_norm_(model.parameters(), 1) # gradient clipping
                        optimizer.step()

                # Progress log every 100 training batches (batch losses, not running averages).
                if idx % 100 == 0 and phase == 'train':
                    log_str = f'    {phase} {idx * 100 // len(dataloaders[dataset_name]):3d}% of an epoch | '
                    if classify and phase in ['train', 'valid_class']:
                        log_str += f'Loss(classif.): {loss_class.item():.4f} | '
                    if complete and phase in ['train', 'valid_compl']:
                        log_str += f'Loss(complet.): {loss_compl.item():.4f} | '
                    print(log_str)

                # statistics
                # Losses are weighted by batch_size so that dividing by the dataset size gives a
                # per-sample mean. top_k_accuracy_score(normalize=False) returns a count of hits.
                if classify and phase in ['train', 'valid_class']: # for F1 score
                    running_loss_class += loss_class.item() * batch_size
                    running_labels_class = _concatenate(running_labels_class, labels_class)
                    running_preds_class = _concatenate(running_preds_class, preds_class)
                    running_top_k_class += top_k_accuracy_score(labels_class.cpu().numpy(), outputs_class.detach().cpu().numpy(), k=5, labels=np.arange(outputs_class.size(1)), normalize=False)
                if complete and phase in ['train', 'valid_compl']: # for accuracy
                    running_loss_compl += loss_compl.item() * batch_size
                    # torch.sum returns a tensor, so running_corrects_compl (and later
                    # epoch_acc_compl / val_acc / best['acc']) are 0-dim tensors, not floats.
                    running_corrects_compl += torch.sum(preds_compl == labels_compl)
                    running_top_k_compl += top_k_accuracy_score(labels_compl.cpu().numpy(), outputs_compl.detach().cpu().numpy(), k=10, labels=np.arange(outputs_compl.size(1)), normalize=False)

            # Epoch summary for this phase; epoch_loss sums the active task losses.
            epoch_loss = 0.
            log_str = f'{phase.upper()} | '
            if classify and phase in ['train', 'valid_class']:
                epoch_loss_class = running_loss_class / dataset_sizes[dataset_name]
                epoch_loss += epoch_loss_class
                # Converted back to torch tensors; f1_score accepts them as array-likes.
                running_labels_class = torch.from_numpy(running_labels_class)
                running_preds_class = torch.from_numpy(running_preds_class)
                epoch_macro_f1 = f1_score(running_labels_class, running_preds_class, average='macro')  # classification: f1 scores.
                epoch_micro_f1 = f1_score(running_labels_class, running_preds_class, average='micro')  # micro f1 score == accuracy, for single-label classification.
                epoch_top_k_class = running_top_k_class / dataset_sizes[dataset_name]
                log_str += f'Loss(classif.): {epoch_loss_class:.3f} Macro-F1: {epoch_macro_f1:.3f} Micro-F1: {epoch_micro_f1:.3f} Top-5 Acc: {epoch_top_k_class:.3f} | '
            if complete and phase in ['train', 'valid_compl']:
                epoch_loss_compl = running_loss_compl / dataset_sizes[dataset_name]
                epoch_loss += epoch_loss_compl
                epoch_acc_compl = running_corrects_compl / dataset_sizes[dataset_name]  # completion task: accuracy.
                epoch_top_k_compl = running_top_k_compl / dataset_sizes[dataset_name]
                log_str += f'Loss(complet.): {epoch_loss_compl:.3f} Acc(complet.): {epoch_acc_compl:.3f} Top-10 Acc: {epoch_top_k_compl:.3f} | '
            print(log_str)
            
            # Stash phase metrics. The train_* values are only consumed by the
            # commented-out wandb logging below.
            if phase == 'train':
                train_loss = epoch_loss
                if classify:
                    train_macro_f1 = epoch_macro_f1
                    train_micro_f1 = epoch_micro_f1
                    train_top_k_class = epoch_top_k_class
                if complete:
                    train_acc = epoch_acc_compl
                    train_top_k_compl = epoch_top_k_compl
                # if wandb_log:
                #     wandb.watch(model)
            elif 'val' in phase:
                val_loss += epoch_loss
                if classify and phase == 'valid_class':
                    val_macro_f1 = epoch_macro_f1
                    val_micro_f1 = epoch_micro_f1
                    val_top_k_class = epoch_top_k_class
                if complete and phase == 'valid_compl':
                    val_acc = epoch_acc_compl
                    val_top_k_compl = epoch_top_k_compl
        
        # Selection metric: val micro-F1 for classification-only runs, otherwise val
        # completion accuracy. It is negated for the scheduler so that a
        # metric-minimising scheduler tracks an increasing score.
        scheduler.step(-val_micro_f1 if classify and not complete else -val_acc)
        is_new_best = val_micro_f1 > best['F1micro'] if classify and not complete else val_acc > best['acc']
        if is_new_best:
            best['bestEpoch'] = int(epoch)
            best['loss'] = val_loss
            if classify:
                best['F1micro'] = val_micro_f1
                best['F1macro'] = val_macro_f1
                best['top5cls'] = val_top_k_class
            if complete:
                best['acc'] = val_acc
                best['top10cmp'] = val_top_k_compl
            best_model_wts = deepcopy(model.state_dict()) # deep copy the model
            if early_stop_patience is not None:
                patience_cnt = 0
        elif early_stop_patience is not None:
            patience_cnt += 1

        # Disabled wandb logging, kept as a bare string literal.
        """
        if wandb_log:
            wandb.log({'train_loss': train_loss,
                       'val_loss': val_loss,
                       'train_macro_f1': train_macro_f1,
                       'train_micro_f1': train_micro_f1,
                       'val_macro_f1': val_macro_f1,
                       'val_micro_f1': val_micro_f1,
                       'best_val_loss': best['loss'],
                       'learning_rate': optimizer.param_groups[0]['lr']})
                                        # scheduler.get_last_lr()[0] for CosineAnnealingWarmRestarts
        """
        # Strict '>' means training stops after early_stop_patience + 1 consecutive
        # epochs without improvement.
        if early_stop_patience is not None:
            if patience_cnt > early_stop_patience:
                print(f'Early stop at epoch {epoch}.')
                break
        

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('==== Best Result ====')
    for k in best:
        if k == 'bestEpoch':
            print(f"{k}: {int(best[k])}")
        else:
            print(f"{k}: {float(best[k]):.8f}")
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, best

## Experiment

In [78]:
def foo(x=1, y=2):
    """Scratch helper: print the sum of ``x`` and ``y``."""
    total = x + y
    print(total)
[key for key in {'x': 3, 'y': 6}]

['x', 'y']

In [69]:
def experiment(dim_embedding=256,
               dim_hidden=128,
               dropout=0.5,
               subset_length=None,
               encoder_mode='deep_sets',
               enc_pool_mode='set_transformer',
               decoder_mode='simple',
               num_enc_layers=4,
               num_dec_layers=4,
               batch_size=16,
               n_epochs=50,
               loss='ASLoss',
               opt='AdamW',
               lr=1e-3,
               step_size=10,  # lr_scheduler
               step_factor=0.1, # lr_scheduler
               patience=20,   # early stop
               seed=0,
               classify=True,
               complete=True,
               freeze_classify=False,
               freeze_complete=False,
               pretrained_model_path=None,
               weight_decay=0.2,
               ):
    """Build loaders and a CCNet model, train it with ``train``, and save the best weights.

    Relies on notebook globals defined in other cells: ``recipe_datasets`` (a
    dict of datasets indexed by 'train_class'/'train_compl'/'valid_class'/
    'valid_compl'), ``CCNet``, ``MultiClassFocalLoss`` and ``MultiClassASLoss``.

    Args:
        dim_embedding, dim_hidden, dropout, encoder_mode, enc_pool_mode,
        decoder_mode, num_enc_layers, num_dec_layers, freeze_classify,
        freeze_complete: forwarded to ``CCNet``.
        subset_length: if not None, use only the first ``subset_length`` samples
            of each dataset, capped at the dataset's length.
        batch_size: DataLoader batch size. Only the training loader shuffles.
        n_epochs: maximum number of training epochs.
        loss: one of 'CrossEntropyLoss', 'FocalLoss', 'ASLoss'.
        opt: one of 'SGD', 'MomentumSGD', 'NesterovSGD' (the legacy spelling
            'NestrovSGD' is also accepted), 'Adam', 'AdamW'.
        lr: initial learning rate.
        step_size: ``patience`` of the ``ReduceLROnPlateau`` scheduler.
        step_factor: ``factor`` of the ``ReduceLROnPlateau`` scheduler.
        patience: early-stopping patience passed to ``train``.
        seed: seed for numpy and torch.
        classify, complete: which tasks to train.
        pretrained_model_path: optional state-dict checkpoint. Only keys present
            in the new model are copied.
        weight_decay: weight decay used by every optimizer (default 0.2).

    The best weights are saved in ``./weights/`` under a filename that encodes
    the best metrics and the main hyper-parameters.
    """
    np.random.seed(seed)
    torch.random.manual_seed(seed)

    train_data_name = 'train_class' if classify and not complete else 'train_compl'
    dataset_names = [train_data_name, 'valid_class', 'valid_compl']
    # First `subset_length` samples of each dataset, capped at its length so an
    # oversized request cannot produce out-of-range Subset indices.
    subset_indices = {}
    for name in dataset_names:
        n_total = len(recipe_datasets[name])
        n_used = n_total if subset_length is None else min(subset_length, n_total)
        subset_indices[name] = list(range(n_used))
    dataloaders = {x: DataLoader(Subset(recipe_datasets[x], subset_indices[x]),
                                 batch_size=batch_size, shuffle=('train' in x)) for x in dataset_names}
    dataset_sizes = {x: len(subset_indices[x]) for x in dataset_names}
    print(dataset_sizes)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('device: ', device)

    # Get a batch of training data (its width gives num_items for the model).
    loaded_data = next(iter(dataloaders[train_data_name]))
    print('bin_inputs, int_inputs, *labels:', [x.shape for x in loaded_data])

    model_ft = CCNet(dim_embedding=dim_embedding, dim_output=20, dim_hidden=dim_hidden,
                     num_items=len(loaded_data[0][0]), num_enc_layers=num_enc_layers, num_dec_layers=num_dec_layers,
                     ln=True, dropout=dropout, encoder_mode=encoder_mode, enc_pool_mode=enc_pool_mode, decoder_mode=decoder_mode,
                     classify=classify, complete=complete, freeze_classify=freeze_classify, freeze_complete=freeze_complete).to(device)
    if pretrained_model_path is not None:
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        pretrained_dict = torch.load(pretrained_model_path, map_location=device)
        model_dict = model_ft.state_dict()

        # 1. filter out keys the current model does not have
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the merged state dict
        model_ft.load_state_dict(model_dict)

    # Model Info: count trainable parameters, de-duplicating tensors shared by several modules.
    total_params = sum(dict((p.data_ptr(), p.numel()) for p in model_ft.parameters() if p.requires_grad).values())
    print("Total Number of Parameters", total_params)

    # Loss, Optimizer, LR Scheduler
    LOSSES = {
        'CrossEntropyLoss': nn.CrossEntropyLoss,
        'FocalLoss': MultiClassFocalLoss,
        'ASLoss': MultiClassASLoss,
    }
    OPTIMIZERS = {
        'SGD': optim.SGD,
        'MomentumSGD': optim.SGD,
        'NestrovSGD': optim.SGD,   # legacy (misspelled) key kept for existing callers
        'NesterovSGD': optim.SGD,
        'Adam': optim.Adam,
        'AdamW': optim.AdamW,
    }
    nesterov_args = {'lr': lr, 'weight_decay': weight_decay, 'momentum': 0.9, 'nesterov': True}
    OPTIMIZERS_ARG = {
        'SGD': {'lr': lr, 'weight_decay': weight_decay},
        'MomentumSGD': {'lr': lr, 'weight_decay': weight_decay, 'momentum': 0.9},
        'NestrovSGD': nesterov_args,
        'NesterovSGD': nesterov_args,
        'Adam': {'lr': lr, 'weight_decay': weight_decay},
        'AdamW': {'lr': lr, 'weight_decay': weight_decay},
    }
    criterion = LOSSES[loss]().to(device)
    optimizer = OPTIMIZERS[opt]([p for p in model_ft.parameters() if p.requires_grad], **OPTIMIZERS_ARG[opt])
    # train() steps this with a negated score, so mode='min' reduces LR when the score plateaus.
    exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=step_factor, patience=step_size, verbose=True)

    model_ft, best = train(model_ft, dataloaders, criterion, optimizer, exp_lr_scheduler,
                           dataset_sizes, device=device, num_epochs=n_epochs, early_stop_patience=patience,
                           classify=classify, complete=complete, random_seed=seed)

    fname = ['ckpt', 'CCNet']
    if classify:
        fname.append('cls')
    if complete:
        fname.append('cmp')
    for k in best:
        if k == 'bestEpoch':
            # Zero-padded (not space-padded) so the filename never contains a space.
            fname.append(f'bestEpoch{int(best[k]):02d}')
        else:
            fname.append(f"{k}{float(best[k]):.4f}")
    fname += [f'bs{batch_size}', f'lr{lr}', f'seed{seed}', f'nEpochs{n_epochs}']
    fname += ['encoder', encoder_mode, 'encPool', enc_pool_mode]
    if complete:
        fname += ['decoder', decoder_mode]
    fname = '_'.join(fname) + '.pt'
    os.makedirs('./weights/', exist_ok=True)
    torch.save(model_ft.state_dict(), os.path.join('./weights/', fname))

### Classification Only

In [None]:
# Classification only (complete=False): a single epoch over the full datasets.
experiment(
    classify=True,
    complete=False,
    batch_size=128,
    n_epochs=1,
    lr=1e-3,
    encoder_mode='deep_sets',
    enc_pool_mode='deep_sets',
    num_enc_layers=4,
    num_dec_layers=4,
    dim_embedding=256,
    dim_hidden=256,
    dropout=0.2,
)


### Completion Only

In [None]:
# Completion only (classify=False): a single epoch on the first 2000 samples of each dataset.
experiment(
    classify=False,
    complete=True,
    batch_size=1024,
    n_epochs=1,
    lr=1e-4,
    encoder_mode='deep_sets',
    enc_pool_mode='deep_sets',
    num_enc_layers=4,
    decoder_mode='simple',
    num_dec_layers=4,
    dim_embedding=256,
    dim_hidden=256,
    dropout=0.2,
    subset_length=2000,
)


### Classification + Completion

In [None]:
# Joint classification + completion: up to 100 epochs (early stopping may end it sooner).
experiment(
    classify=True,
    complete=True,
    batch_size=1024,
    n_epochs=100,
    lr=1e-4,
    encoder_mode='deep_sets',
    enc_pool_mode='deep_sets',
    num_enc_layers=2,
    decoder_mode='simple',
    num_dec_layers=2,
    dim_embedding=256,
    dim_hidden=256,
    dropout=0.2,
)
