In [1]:
# If h5py does not work (Homebrew workaround):
#!brew reinstall hdf5
#!export CPATH="/opt/homebrew/include/"
#!export HDF5_DIR=/opt/homebrew/
#!python3 -m pip install h5py

In [2]:
# For Colab
# from google.colab import drive
# drive.mount('/content/drive')
# %cd drive/MyDrive/cuisine-prediction/Hanseul/
# !pip3 install torchmetrics

In [1]:
import os
import pickle
import math
import time
from tqdm import tqdm
from copy import deepcopy

import h5py
import numpy as np
import torch
from torch import nn, optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.utils.data import Dataset, Subset, DataLoader
from sklearn.metrics import f1_score

In [2]:
# Relative path one directory above the notebook's working directory.
path_root = '../'
# Directory that the pickled id dictionaries and the recipe datasets are read from below.
path_container = './Container/'

## Loading Datasets

In [3]:
def _load_pickle(file_name):
    """Unpickle and return the object stored in ``path_container + file_name``."""
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(path_container + file_name, 'rb') as handle:
        return pickle.load(handle)


# Lookup tables between integer ids and cuisine / ingredient names, in both directions.
id_cuisine_dict = _load_pickle('id_cuisine_dict.pickle')
cuisine_id_dict = _load_pickle('cuisine_id_dict.pickle')
id_ingredient_dict = _load_pickle('id_ingredient_dict.pickle')
ingredient_id_dict = _load_pickle('ingredient_id_dict.pickle')

In [4]:
class RecipeDataset(Dataset):
    """Recipes read from an HDF5 file, served in multi-hot and padded-index form.

    Attributes
    ----------
    bin_data : multi-hot array of shape (num_recipes, num_ingredients).
    int_data : array of shape (num_recipes, max_num_ingredients_per_recipe) holding
        the ingredient indices of each recipe, right-padded with ``padding_idx``.
    padding_idx : equals num_ingredients, i.e. one past the last valid ingredient index.
    int_labels, bin_labels : only set when the corresponding key exists in the file.
    """

    def __init__(self, data_dir, test=False):
        self.data_dir = data_dir
        self.test = test

        with h5py.File(data_dir, 'r') as h5_file:
            stored_keys = h5_file.keys()
            self.bin_data = h5_file['bin_data'][:]
            if 'int_labels' in stored_keys:
                self.int_labels = h5_file['int_labels'][:]  # integer cuisine label per recipe
            if 'bin_labels' in stored_keys:
                self.bin_labels = h5_file['bin_labels'][:]  # binary label vector per recipe

        self.padding_idx = self.bin_data.shape[1]
        self.max_num_ingredients_per_recipe = self.bin_data.sum(1).max()

        # Fixed-width index rows: start fully padded, then overwrite the head of each
        # row with the indices of the ingredients that recipe actually contains.
        padded_shape = (len(self.bin_data), self.max_num_ingredients_per_recipe)
        self.int_data = np.full(padded_shape, self.padding_idx)
        for padded_row, multi_hot in zip(self.int_data, self.bin_data):
            present = np.flatnonzero(multi_hot == 1)
            padded_row[:present.size] = present

    def __len__(self):
        return len(self.bin_data)

    def __getitem__(self, idx):
        """Return a tuple whose length depends on the split.

        test split                     -> (bin_data, int_data)
        'valid_compl' in ``data_dir``  -> (bin_data, int_data, bin_label)
        otherwise                      -> (bin_data, int_data, bin_label, int_label)
        """
        sample = (self.bin_data[idx], self.int_data[idx])
        if self.test:
            return sample

        sample = sample + (self.bin_labels[idx],)
        if 'valid_compl' in self.data_dir:
            return sample

        return sample + (self.int_labels[idx],)

In [5]:
# File names of the six splits inside the container directory.
dataset_name = ['train_class', 'train_compl', 'valid_class', 'valid_compl', 'test_class', 'test_compl']

# One RecipeDataset per split; splits whose name contains 'test' are built with test=True.
recipe_datasets = {}
for _split_name in dataset_name:
    _split_path = os.path.join(path_container, _split_name)
    recipe_datasets[_split_name] = RecipeDataset(_split_path, test=('test' in _split_name))

In [6]:
# Indices of 'train_class' recipes whose multi-hot vector sums to fewer than two ingredients.
_train_class_set = recipe_datasets['train_class']
list_single_ingredient_recipe = [
    recipe_idx
    for recipe_idx in range(len(_train_class_set))
    if _train_class_set[recipe_idx][0].sum() < 2
]
count_single_ingredient_recipe = len(list_single_ingredient_recipe)
print(count_single_ingredient_recipe)
print(list_single_ingredient_recipe)

19
[564, 1263, 4074, 4203, 4901, 5277, 5360, 6232, 7835, 10585, 10777, 12476, 13301, 13989, 15951, 17374, 18153, 19039, 20469]


## Model

In [7]:
## Building blocks of Set Transformers ##
# added masks.

class MAB(nn.Module):
    """Multi-head attention block with optional attention masking.

    Queries are projected from ``Q`` and keys/values from ``K``; the attended values
    are added to the projected queries, followed by a residual feed-forward layer.
    LayerNorm is applied after each residual step only when ``ln`` is true.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False, dropout=0):
        super().__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.p = dropout
        # Construction order is kept so seeded initialisation stays reproducible.
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        feed_forward = [nn.Linear(dim_V, dim_V), nn.ReLU(), nn.Linear(dim_V, dim_V)]
        self.fc_o = nn.Sequential(*feed_forward)
        self.Dropout = nn.Dropout(p=dropout)

    def forward(self, Q, K, mask=None):
        """Attend from ``Q`` (batch, q_len, dim_Q) over ``K`` (batch, k_len, dim_K).

        ``mask`` is a boolean tensor broadcastable to (batch * num_heads, q_len, k_len);
        True entries are excluded from the softmax. Returns (batch, q_len, dim_V).
        """
        queries = self.fc_q(Q)
        keys = self.fc_k(K)
        values = self.fc_v(K)
        batch = queries.size(0)
        head_dim = self.dim_V // self.num_heads

        # Fold the heads into the batch axis: (batch * num_heads, len, head_dim).
        q_heads = torch.cat(queries.split(head_dim, 2), 0)
        k_heads = torch.cat(keys.split(head_dim, 2), 0)
        v_heads = torch.cat(values.split(head_dim, 2), 0)

        # Scores are scaled by sqrt(dim_V), not by the per-head width.
        scores = torch.bmm(q_heads, k_heads.transpose(1, 2)) / math.sqrt(self.dim_V)
        if mask is not None:
            scores.masked_fill_(mask, float('-inf'))
        attention = torch.softmax(scores, 2)

        # Unfold the heads back into the feature axis: (batch, q_len, dim_V).
        attended = q_heads + attention.bmm(v_heads)
        hidden = torch.cat(attended.split(batch, 0), 2)
        if hasattr(self, 'ln0'):
            hidden = self.ln0(hidden)

        update = self.fc_o(hidden)
        if self.p > 0:
            update = self.Dropout(update)
        hidden = hidden + update
        if hasattr(self, 'ln1'):
            hidden = self.ln1(hidden)
        return hidden

class SAB(nn.Module):
    """Self-attention block: an MAB whose queries and keys are the same set."""

    def __init__(self, dim_in, dim_out, num_heads, ln=False, dropout=0.2):
        super().__init__()
        self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln, dropout=dropout)

    def forward(self, X, mask=None):
        """Let every element of ``X`` attend over ``X``; ``mask`` is passed through to the MAB."""
        attended = self.mab(X, X, mask=mask)
        return attended

class ISAB(nn.Module):
    """Induced self-attention block.

    ``num_inds`` learned inducing points first attend over the input set, then the
    input set attends over that summary.
    """

    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False, dropout=0.2):
        super().__init__()
        # Learned inducing points, shared across the batch: (1, num_inds, dim_out).
        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        shared_kwargs = dict(ln=ln, dropout=dropout)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, **shared_kwargs)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, **shared_kwargs)

    def forward(self, X, mask=None):
        """``mask`` is applied only where the inducing points read ``X`` (first MAB)."""
        batch = X.size(0)
        inducing = self.I.repeat(batch, 1, 1)
        summary = self.mab0(inducing, X, mask=mask)
        # The second MAB attends over the inducing summary and is called without a mask.
        return self.mab1(X, summary)

class PMA(nn.Module):
    """Pooling by multi-head attention: ``num_seeds`` learned seed vectors attend over the set."""

    def __init__(self, dim, num_heads, num_seeds, ln=False, dropout=0.2):
        super().__init__()
        # Learned seed vectors, shared across the batch: (1, num_seeds, dim).
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln, dropout=dropout)

    def forward(self, X, mask=None):
        """Pool ``X`` (batch, set_len, dim) into (batch, num_seeds, dim)."""
        seeds = self.S.repeat(X.size(0), 1, 1)
        pooled = self.mab(seeds, X, mask=mask)
        return pooled

In [8]:
def make_one_hot(x, num_items=None):
    """Convert padded ingredient indices (int_data) into per-recipe ingredient counts (bin_data).

    Parameters
    ----------
    x : tensor or array-like of integer ingredient indices. Expected shape is
        (batch, length); lower-rank input is treated as a single recipe and
        higher-rank input is squeezed first.
    num_items : optional number of real ingredients. The padding index is assumed
        to be ``num_items``. When omitted, the padding index is inferred as the
        largest value in ``x`` (legacy behaviour). That inference is only correct
        if at least one padding entry is present, so pass ``num_items`` whenever
        it is known.

    Returns
    -------
    LongTensor of shape (batch, num_items) with the number of occurrences of each
    ingredient (the padding column is dropped), or ``False`` when ``x`` still has
    more than two dimensions after squeezing.
    """
    if isinstance(x, torch.Tensor):
        x = x.long()
    else:
        x = torch.as_tensor(x, dtype=torch.long)

    if x.dim() > 2:
        x = x.squeeze()
        if x.dim() > 2:
            return False
    # Squeezing can leave fewer than two dimensions; always restore (batch, length).
    while x.dim() < 2:
        x = x.unsqueeze(0)

    if num_items is None:
        width = int(x.max()) + 1
        num_kept = width - 1
    else:
        width = num_items + 1
        num_kept = num_items

    # Count directly into a (batch, width) buffer rather than materialising the
    # (batch, length, width) one-hot tensor and summing it.
    counts = torch.zeros(x.size(0), width, dtype=torch.long, device=x.device)
    counts.scatter_add_(1, x, torch.ones_like(x))
    return counts[:, :num_kept]

In [9]:
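# 04/16: Removed num_output, because there is no real reason for the pooling layer to emit two vectors (per batch element).
# ____: Accordingly removed decoder1 and renamed decoder2 to decoder. The decoder may not be needed at all, so it is commented out for now.
# 04/17: Split CCNet into several components (Encoder, Classifier, Completer).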

class ResBlock(nn.Module):
    """Pre-activation two-layer MLP (BatchNorm -> LeakyReLU -> Linear, twice).

    The input is added back to the output only when ``dim_input == dim_output``.
    """

    def __init__(self, dim_input, dim_hidden, dim_output):
        super().__init__()
        self.use_skip_conn = dim_input == dim_output
        layers = [
            nn.BatchNorm1d(dim_input),
            nn.LeakyReLU(),
            nn.Linear(dim_input, dim_hidden),
            nn.BatchNorm1d(dim_hidden),
            nn.LeakyReLU(),
            nn.Linear(dim_hidden, dim_output),
        ]
        self.ff = nn.Sequential(*layers)

    def forward(self, x):
        transformed = self.ff(x)
        if not self.use_skip_conn:
            return transformed
        return transformed + x

class ResBlockLN(nn.Module):
    """Pre-activation two-layer MLP (LayerNorm -> LeakyReLU -> Linear, twice).

    The input is added back to the output only when ``dim_input == dim_output``.
    """

    def __init__(self, dim_input, dim_hidden, dim_output):
        super().__init__()
        self.use_skip_conn = dim_input == dim_output
        layers = [
            nn.LayerNorm(dim_input),
            nn.LeakyReLU(),
            nn.Linear(dim_input, dim_hidden),
            nn.LayerNorm(dim_hidden),
            nn.LeakyReLU(),
            nn.Linear(dim_hidden, dim_output),
        ]
        self.ff = nn.Sequential(*layers)

    def forward(self, x):
        transformed = self.ff(x)
        if not self.use_skip_conn:
            return transformed
        return transformed + x


class Encoder(nn.Module):
    """Create Feature Vector of Given Recipe.

    Embeds the padded ingredient indices, passes them through a stack of ISAB
    layers and pools the result into one vector per recipe with a single-seed PMA.
    """

    def __init__(self, dim_embedding=256,
                 num_items=6714,
                 num_inds=32,      # For ISAB
                 dim_hidden=128,
                 num_heads=4,
                 num_enc_layers=4,
                 ln=True,          # LayerNorm option
                 dropout=0.2      # Dropout option
                ):
        super().__init__()
        self.num_heads = num_heads
        self.padding_idx = num_items
        # One extra row for the padding index; padding_idx=-1 refers to that last row.
        self.embedding = nn.Embedding(num_embeddings=num_items + 1,
                                      embedding_dim=dim_embedding,
                                      padding_idx=-1)

        # The first layer maps dim_embedding -> dim_hidden; the rest keep dim_hidden.
        layer_input_dims = [dim_embedding] + [dim_hidden] * (num_enc_layers - 1)
        self.encoder = nn.ModuleList([
            ISAB(dim_in, dim_hidden, num_heads, num_inds, ln=ln, dropout=dropout)
            for dim_in in layer_input_dims])
        # NOTE(review): ``dropout`` is not forwarded here, so the pooling layer uses
        # PMA's own default dropout — confirm whether that is intended.
        self.pooling = PMA(dim_hidden, num_heads, 1, ln=ln)

        # Last forward pass's activations and mask, kept as attributes.
        self.out = self.mask = None

    def forward(self, x):
        """Encode ``x`` (batch, max_num_ingredient) of ingredient indices into (batch, 1, dim_hidden)."""
        # True at padded slots, tiled once per head: (batch * num_heads, 1, max_num_ingredient).
        is_padding = (x == self.padding_idx)
        self.mask = is_padding.repeat(self.num_heads, 1).unsqueeze(1)

        # (batch, max_num_ingredient, dim_embedding)
        self.out = self.embedding(x)

        # Permutation-equivariant stack: (batch, max_num_ingredient, dim_hidden).
        for isab_layer in self.encoder:
            self.out = isab_layer(self.out, mask=self.mask)

        # Permutation-invariant pooling: (batch, 1, dim_hidden).
        pooled = self.pooling(self.out, mask=self.mask)
        return pooled
    

class Classifier(nn.Module):
    """Three stacked ResBlocks mapping a recipe feature (batch, dim_hidden) to (batch, dim_output) logits."""

    def __init__(self, dim_hidden=128, dim_output=20):
        super().__init__()
        # Two width-preserving blocks followed by the projection to the output size.
        output_dims = (dim_hidden, dim_hidden, dim_output)
        blocks = [ResBlock(dim_hidden, dim_hidden, out_dim) for out_dim in output_dims]
        self.classifier = nn.Sequential(*blocks)

    def forward(self, x):
        logits = self.classifier(x)
        return logits


class Completer(nn.Module):
    """Score every ingredient as a candidate completion of an encoded recipe.

    Two decoding modes are supported:

    * ``'attention'``: the recipe feature attends over the ingredient embedding
      table through ``num_dec_layers`` MAB layers and is then projected to
      ``num_items`` logits.
    * ``'concat'``: the embedding table is transformed by ISAB layers, every
      transformed ingredient is concatenated with the recipe feature, and a
      shared MLP emits one logit per (recipe, ingredient) pair.

    The intermediate activations of the last forward pass are kept on
    ``self.out`` / ``self.emb_feature``.
    """

    def __init__(self, dim_embedding=256,
                 num_items=6714,
                 num_inds=32,      # For ISAB
                 dim_hidden=128,
                 num_heads=4,
                 num_dec_layers=2,
                 ln=True,          # LayerNorm option
                 dropout=0.2,      # Dropout option
                 mode = 'attention',
                ):
        super(Completer, self).__init__()
        assert mode in ['attention', 'concat']
        self.mode = mode
        self.ff = nn.Sequential(
                ResBlock(dim_hidden, dim_hidden, dim_hidden), ResBlock(dim_hidden, dim_hidden, dim_hidden))
        if mode == 'attention':
            self.decoder = nn.ModuleList(
                [MAB(dim_hidden, dim_embedding, dim_hidden, num_heads, ln=ln, dropout=dropout) for _ in range(num_dec_layers)])
            self.ff1 = nn.Sequential(
                ResBlock(dim_hidden, dim_hidden, dim_hidden), ResBlock(dim_hidden, dim_hidden, num_items))
        elif mode == 'concat':
            self.ff2 = ResBlockLN(dim_embedding, dim_hidden, dim_hidden)
            # At least one ISAB layer is always built in this mode.
            self.decoder = nn.ModuleList(
                [ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln, dropout=dropout) for _ in range(max(num_dec_layers, 1))])
            self.ff3 = nn.Sequential(
                ResBlockLN(2*dim_hidden, 2*dim_hidden, 2*dim_hidden),
                ResBlockLN(2*dim_hidden, dim_hidden, 2*dim_hidden),
                ResBlockLN(2*dim_hidden, dim_hidden, 1))

        self.out = self.emb_feature = None

    def forward(self, x, embedding_weight):
        """Return completion logits of shape (batch, num_items).

        Parameters
        ----------
        x : recipe feature of shape (batch, 1, dim_hidden).
        embedding_weight : ingredient embedding table of shape
            (num_items, dim_embedding), without the padding row.
        """
        batch_size = x.size(0)
        num_items = embedding_weight.size(0)

        self.out = self.ff(x.squeeze(1)).unsqueeze(1)  # (batch, 1, dim_hidden)
        embedding_weight = embedding_weight.unsqueeze(0)  # (1, num_items, dim_embedding)

        if self.mode == 'attention':
            # expand() creates a broadcast view; repeat() would copy the whole
            # table once per batch element on every layer.
            keys = embedding_weight.expand(batch_size, -1, -1)
            for module in self.decoder:
                self.out = module(self.out, keys)
            return self.ff1(self.out.squeeze(1))  # (batch, num_items)
        elif self.mode == 'concat':
            self.emb_feature = self.ff2(embedding_weight)  # (1, num_items, dim_hidden)
            for module in self.decoder:
                self.emb_feature = module(self.emb_feature)
            self.out = torch.cat([self.out.expand(-1, num_items, -1),
                                  self.emb_feature.expand(batch_size, -1, -1)],
                                 dim=2)  # (batch, num_items, 2*dim_hidden)
            # Squeeze only the trailing singleton axis: a bare squeeze() would also
            # drop the batch axis when batch_size == 1.
            return self.ff3(self.out).squeeze(-1)  # (batch, num_items)
        
    
class CCNet(nn.Module):
    """Joint cuisine classification and recipe completion network.

    A shared ``Encoder`` produces one feature vector per recipe. Depending on the
    ``classify`` / ``complete`` flags, a ``Classifier`` head and/or a ``Completer``
    head is attached. ``freeze_classify`` / ``freeze_complete`` disable gradients
    for the parameters of the corresponding head only.

    With ``ignore_used_ingredients=True`` the completion logits of ingredients
    already present in the input recipe are set to ``-inf``.
    """

    def __init__(self, dim_embedding=256, #
                 dim_output=20,
                 num_items=6714,
                 num_inds=32,
                 dim_hidden=128,
                 num_heads=4,
                 num_enc_layers=4,
                 num_dec_layers=2,
                 ln=True,          # LayerNorm option
                 dropout=0.5,      # Dropout option
                 classify=True,    # set False to run completion only
                 complete=True,    # set False to run classification only
                 freeze_classify=False, # freeze the classification-only parameters
                 freeze_complete=False,  # freeze the completion-only parameters
                 decoder_mode = 'attention',
                 ignore_used_ingredients=False
                 ):

        super(CCNet, self).__init__()
        self.classify, self.complete = classify, complete
        self.ignore_used_ingredients = ignore_used_ingredients
        self.encoder = Encoder(dim_embedding=dim_embedding,
                               num_items=num_items,
                               num_inds=num_inds,
                               dim_hidden=dim_hidden,
                               num_heads=num_heads,
                               num_enc_layers=num_enc_layers,
                               ln=ln, dropout=dropout)
        if classify:
            self.classifier = Classifier(dim_hidden=dim_hidden,
                                         dim_output=dim_output)
            if freeze_classify:
                for p in self.classifier.parameters():
                    p.requires_grad = False
        if complete:
            self.completer = Completer(dim_embedding=dim_embedding,
                                       num_items=num_items,
                                       num_inds=num_inds,
                                       dim_hidden=dim_hidden,
                                       num_heads=num_heads,
                                       num_dec_layers=num_dec_layers,
                                       ln=ln, dropout=dropout,
                                       mode = decoder_mode)
            if freeze_complete:
                for p in self.completer.parameters():
                    p.requires_grad = False

    @staticmethod
    def _used_ingredient_mask(x, num_items):
        """Boolean (batch, num_items) mask of the ingredients listed in ``x``.

        ``x`` holds integer ingredient indices of shape (batch, length) where the
        value ``num_items`` marks padding. The width is fixed by ``num_items``
        rather than inferred from the data, so the mask has the right shape even
        when a batch contains no padding entry.
        """
        mask = torch.zeros(x.size(0), num_items + 1, dtype=torch.bool, device=x.device)
        rows = torch.arange(x.size(0), device=x.device).unsqueeze(1)
        mask[rows, x] = True
        return mask[:, :num_items]  # drop the padding column

    def forward(self, x, bin_x=None):
        """Run the enabled heads on a batch of recipes.

        Parameters
        ----------
        x : integer ingredient indices of shape (batch, max_num_ingredient).
        bin_x : optional multi-hot version of ``x`` with shape (batch, num_items).
            It is used only when ``ignore_used_ingredients`` is set; if omitted,
            the mask is derived from ``x``.

        Returns
        -------
        ``(logit_classification, logit_completion)``; the entry of a disabled head
        is ``None``. If both heads are disabled, the call returns ``None``.
        """
        if not (self.classify or self.complete):
            return
        recipe_feature = self.encoder(x)  # (batch, 1, dim_hidden)

        logit_classification, logit_completion = None, None

        # Classification:
        if self.classify:
            logit_classification = self.classifier(recipe_feature.squeeze(1))  # (batch, dim_output)

        # Completion:
        if self.complete:
            # Candidate ingredients are the embedding rows without the padding row:
            # (num_items, dim_embedding).
            embedding_weight = self.encoder.embedding.weight[:-1]

            logit_completion = self.completer(recipe_feature, embedding_weight)
            if self.ignore_used_ingredients:
                if bin_x is None:
                    bool_x = self._used_ingredient_mask(x, embedding_weight.size(0))
                else:
                    bool_x = (bin_x == 1)
                logit_completion[bool_x] = float('-inf')

        return logit_classification, logit_completion

## Loss functions

In [10]:
class ClassificationASLoss(nn.Module):
    '''
    MultiClass ASL(single label) + F1 Loss.

    The returned value is the asymmetric (focal-style) cross-entropy with label
    smoothing plus ``1 - soft_F1``, where the soft F1 is computed from the
    softmax probabilities over the batch.

    Parameters
    ----------
    gamma_pos, gamma_neg : focusing exponents for the target / non-target classes.
    eps : label-smoothing amount (smoothing is skipped when ``eps <= 0``).
    reduction : ``'mean'`` averages the per-sample ASL term; any other value
        leaves it per-sample.
    average : ``'micro'`` pools TP/FP/FN over classes before computing F1;
        otherwise F1 is computed per class and averaged.
    f1_eps : stabiliser added to the F1 denominator and used as the clamp
        margin ``[f1_eps, 1 - f1_eps]``. Defaults to ``eps`` (the original
        behaviour). With the default ``eps=0.1`` the F1 term is clamped to
        [0.1, 0.9] and has zero gradient outside that range; pass a small
        value such as ``1e-8`` to decouple it from label smoothing.

    References:
    - ASL paper: https://arxiv.org/abs/2009.14119
    - optimized ASL: https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py
    '''
    def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, reduction='mean', average='macro',
                 f1_eps=None):
        super(ClassificationASLoss, self).__init__()

        self.eps = eps
        self.f1_eps = eps if f1_eps is None else f1_eps
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.targets_classes = []
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.reduction = reduction
        self.average = average

    def forward(self, inputs, target):
        '''
        "input" dimensions: - (batch_size,number_classes)
        "target" dimensions: - (batch_size)
        '''
        num_classes = inputs.size()[-1]
        log_preds = self.logsoftmax(inputs)
        # One-hot encode the integer targets.
        self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)

        targets = self.targets_classes
        anti_targets = 1 - targets
        xs_pos = torch.exp(log_preds)
        xs_neg = 1 - xs_pos

        # Soft TP / FP / FN per class, accumulated over the batch.
        tp = (xs_pos * targets).sum(dim=0)
        fp = (xs_pos * anti_targets).sum(dim=0)
        fn = (xs_neg * targets).sum(dim=0)

        if self.average == 'micro':
            tp = tp.sum()
            fp = fp.sum()
            fn = fn.sum()

        # Soft F1 score
        f1 = (tp / (tp + 0.5*(fp + fn) + self.f1_eps)).clamp(
            min=self.f1_eps, max=1-self.f1_eps).mean()

        # ASL weights: (1 - p)^gamma_pos on the target class, p^gamma_neg elsewhere.
        xs_pos = xs_pos * targets
        xs_neg = xs_neg * anti_targets
        asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
                                 self.gamma_pos * targets + self.gamma_neg * anti_targets)
        log_preds = log_preds * asymmetric_w

        if self.eps > 0:  # label smoothing
            self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)

        # loss calculation
        loss = - self.targets_classes.mul(log_preds)

        loss = loss.sum(dim=-1)
        if self.reduction == 'mean':
            loss = loss.mean()

        return loss + (1. - f1)


In [11]:
class ClassificationFocalLoss(nn.Module):
    '''
    MultiClass F1 Loss + FocalLoss.
    The original implementation is written by Michal Haltuf on Kaggle.

    The focal term is computed per sample, ``(1 - p_t) ** gamma * CE``, and only
    then reduced according to ``reduction`` (``'mean'``, ``'sum'``, or anything
    else for no reduction). The soft-F1 term ``1 - F1`` is added in every case.

    Parameters
    ----------
    eps : stabiliser in the F1 denominator and clamp margin for the F1 value.
    average : ``'micro'`` pools TP/FP/FN over classes; otherwise per-class F1 is averaged.
    reduction : reduction applied to the focal term.
    gamma : focusing exponent of the focal term.

    Reference
    ---------
    - https://www.kaggle.com/rejpalcz/best-loss-function-for-f1-score-metric
    - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score
    - https://discuss.pytorch.org/t/calculating-precision-recall-and-f1-score-in-case-of-multi-label-classification/28265/6
    - http://www.ryanzhang.info/python/writing-your-own-loss-function-module-for-pytorch/
    - https://gist.github.com/SuperShinyEyes/dcc68a08ff8b615442e3bc6a9b55a354
    '''
    def __init__(self, eps=1e-8, average='macro', reduction='mean', gamma=2):
        super().__init__()
        self.eps = eps
        self.average = average
        self.reduction = reduction
        self.gamma = gamma

    def forward(self, pred, target):
        """``pred``: (batch, num_classes) logits; ``target``: (batch,) integer class ids."""
        # Focal loss. The cross-entropy must stay per-sample here: weighting the
        # already-averaged CE would apply a single focusing factor to the whole batch.
        ce_loss = F.cross_entropy(pred, target, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focalloss = (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            focalloss = focalloss.mean()
        elif self.reduction == 'sum':
            focalloss = focalloss.sum()

        # Soft F1 loss
        target = F.one_hot(target, pred.size(-1)).float()
        pred = F.softmax(pred, dim=1)

        tp = (target * pred).sum(dim=0).float()
        fp = ((1 - target) * pred).sum(dim=0).float()
        fn = (target * (1 - pred)).sum(dim=0).float()

        if self.average == 'micro':
            tp, fp, fn = tp.sum(), fp.sum(), fn.sum()

        f1 = (tp / (tp + 0.5*(fp + fn) + self.eps)).clamp(min=self.eps, max=1-self.eps)

        return 1 - f1.mean() + focalloss

In [12]:
class CompletionASLoss(nn.Module):
    '''
    MultiLabel ASL Loss + F1 Loss

    The returned value is the asymmetric loss (averaged over the batch, summed
    over labels) plus ``1 - soft_F1`` computed from the sigmoid probabilities.

    Parameters
    ----------
    gamma_pos, gamma_neg : focusing exponents for positive / negative labels.
    clip : margin added to the negative probability (then capped at 1);
        disabled when ``None`` or ``<= 0``.
    eps : lower clamp inside the logarithms, F1 denominator stabiliser and F1 clamp margin.
    disable_torch_grad_focal_loss : when true, the focusing weights are computed
        without gradient tracking.
    average : ``'micro'`` pools TP/FP/FN over labels; otherwise per-label F1 is averaged.

    References:
    - ASL paper: https://arxiv.org/abs/2009.14119
    - optimized ASL: https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py
    '''
    def __init__(self, gamma_pos=1, gamma_neg=4, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False, average='macro'):
        super(CompletionASLoss, self).__init__()

        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps
        self.average = average

        # Intermediate tensors of the last forward pass are kept as attributes.
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None
        self.tp = self.fp = self.fn = self.f1 = None

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        self.targets = y.float()
        self.anti_targets = 1 - y.float()

        # Calculating Probabilities
        self.xs_pos = torch.sigmoid(x.float())
        self.xs_neg = 1.0 - self.xs_pos

        # Soft TP/FP/FN per label (computed before the asymmetric clipping below).
        self.tp = (self.xs_pos * self.targets).sum(dim=0)
        self.fp = (self.xs_pos * self.anti_targets).sum(dim=0)
        self.fn = (self.xs_neg * self.targets).sum(dim=0)

        if self.average == 'micro':
            self.tp = self.tp.sum()
            self.fp = self.fp.sum()
            self.fn = self.fn.sum()

        # F1 score
        self.f1 = (self.tp / (self.tp + 0.5*(self.fp + self.fn) + self.eps)).clamp(
            min=self.eps, max=1-self.eps)

        # Asymmetric Clipping (out of place, so the tensor used for F1 is left untouched)
        if self.clip is not None and self.clip > 0:
            self.xs_neg = (self.xs_neg + self.clip).clamp(max=1.)

        # Basic CE calculation
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss = self.loss + self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps))

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            # A scoped context restores the caller's grad mode on exit; toggling the
            # global flag off and back on would enable grad inside a no_grad block.
            track_grad = torch.is_grad_enabled() and not self.disable_torch_grad_focal_loss
            with torch.set_grad_enabled(track_grad):
                self.xs_pos = self.xs_pos * self.targets
                self.xs_neg = self.xs_neg * self.anti_targets
                self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
                                              self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
            self.loss = self.loss * self.asymmetric_w

        return -self.loss.mean(dim=0).sum() + (1.-self.f1.mean())

In [13]:
class CompletionBCELoss(nn.Module):
    '''
    MultiLabel F1_Loss + BCEWithLogitsLoss.

    The returned value is ``BCEWithLogits(pred, target) + (1 - soft_F1)``, where
    the soft F1 is computed from the per-label sigmoid probabilities.

    Parameters
    ----------
    eps : stabiliser in the F1 denominator and clamp margin for the F1 value.
    average : ``'macro'`` (per-label F1, averaged) or ``'micro'`` (TP/FP/FN pooled over labels).
    reduction : reduction passed to ``nn.BCEWithLogitsLoss``.
    weight, gamma : stored on the module but not used by ``forward``.
    '''
    def __init__(self, eps=1e-8, average='macro', reduction='mean', weight=None, gamma=2):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss(reduction=reduction)
        self.eps = eps
        self.average = average
        self.reduction = reduction
        self.weight = weight
        self.gamma = gamma
        if average not in ['macro', 'micro']:
            raise ValueError('average should be macro or micro.')

    def forward(self, pred, target): # same dimension
        """``pred``: (batch, num_labels) logits; ``target``: multi-hot tensor of the same shape."""
        pred = pred.float()
        target = target.float()
        bce_loss = self.bce(pred, target)

        # Soft F1 loss. Labels are independent, so use per-label sigmoid
        # probabilities (the same activation the BCE term applies); a softmax
        # would force the probabilities of one row to sum to 1.
        prob = torch.sigmoid(pred)

        tp = (target * prob).sum(dim=0)
        fp = ((1 - target) * prob).sum(dim=0)
        fn = (target * (1 - prob)).sum(dim=0)

        if self.average == 'micro':
            tp, fp, fn = tp.sum(), fp.sum(), fn.sum()

        # Clamp the F1 value itself, not its denominator.
        f1 = (tp / (tp + 0.5*(fp + fn) + self.eps)).clamp(min=self.eps, max=1-self.eps)

        return 1 - f1.mean() + bce_loss

## Training function

In [14]:
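# WandB logging has been removed for now.
# 04/15: The validation loss should apparently be the sum of the separately computed classification and completion losses.
# ____: f1_score now comes from sklearn.metrics.f1_score; the torchmetrics dependency was removed and the related bug (F1 score always coming out as 1) was fixed.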

def train(model,
          dataloaders,
          #criterion,
          optimizer,
          scheduler,
          #metrics,
          dataset_sizes,
          device='cpu',
          num_epochs=20,
          #wandb_log=False,
          early_stop_patience=None,
          classify=True,
          complete=True,
          random_seed=1):

    def _concatenate(running_v, new_v):
        if running_v is not None:
            return np.concatenate((running_v, new_v.clone().detach().cpu().numpy()), axis=0)
        else:
            return new_v.clone().detach().cpu().numpy()
    
    np.random.seed(random_seed)
    torch.random.manual_seed(random_seed)

    #global label_weight
    if classify:
        criterion_class = ClassificationASLoss().to(device) # ClassificationFocalLoss().to(device)
    if complete:
        criterion_compl = CompletionASLoss().to(device) #CompletionBCELoss().to(device)

    since = time.time()

    best_model_wts = deepcopy(model.state_dict())
    best_loss = 1e4
    best_micro_f1 = 0. # classification only

    if early_stop_patience is not None:
        if not isinstance(early_stop_patience, int):
            raise TypeError('early_stop_patience should be an integer.')
        patience_cnt = 0
    
    print('-' * 5 + 'Training the model' + '-' * 5)
    for epoch in tqdm(range(num_epochs)):
        print(f'\nEpoch {epoch+1}/{num_epochs}')

        val_loss = 0. # sum of classification and completion loss

        # Each epoch has a training and validation phase
        for phase in ['train', 'valid_class', 'valid_compl']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
                if not classify and phase == 'valid_class':
                    continue
                elif not complete and phase == 'valid_compl':
                    continue

            running_loss_class = 0.
            running_loss_compl = 0.
            running_corrects_compl = 0.
            running_corrects_class = 0.
            running_labels_class = None
            running_preds_class = None
            
            dataset_name = phase
            if phase == 'train':
                dataset_name = 'train_class' if classify and not complete else 'train_compl'
            log_gap = 100 if dataset_name != 'train_compl' else 1000
            
            # Iterate over data.
            for idx, loaded_data in enumerate(dataloaders[dataset_name]):
                if phase == 'valid_compl':
                    bin_inputs, int_inputs, bin_labels = loaded_data  # no int_label
                else:
                    bin_inputs, int_inputs, bin_labels, int_labels = loaded_data
                
                batch_size = bin_inputs.size(0)
                num_items = bin_inputs.size(-1)
                if classify and phase in ['train', 'valid_class']:
                    labels_class = int_labels.to(device)
                if complete and phase in ['train', 'valid_compl']:
                    labels_compl = bin_labels.to(device)
                bin_inputs = bin_inputs.to(device)
                int_inputs = int_inputs.to(device)
                    
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs_class, outputs_compl = model(int_inputs, bin_x=bin_inputs)  # bin_x 없어도 작동은 가능
                    if classify and phase in ['train', 'valid_class']:
                        _, preds_class = torch.max(outputs_class, 1)
                    if complete and phase in ['train', 'valid_compl']:
                        _out = outputs_compl.clone().detach()
                        _out[bin_inputs==1] = float('-inf')
                        _, preds_compl = torch.max(_out, 1)
                        _label = labels_compl.clone().detach()
                        _label[bin_inputs==1] = 0  # 주어진 재료 말고,
                        _label = torch.arange(num_items).repeat(batch_size,1)[_label==1].long()  # 새로 넣어야 할 재료만 골라내기
                        
                    if idx == 0:  # 원래 idx == 0 
                        if classify and phase in ['train', 'valid_class']:
                            print('labels_classification', labels_class.cpu().numpy())
                            print('preds_classification', preds_class.cpu().numpy())
                        if complete and phase in ['train', 'valid_compl']:
                            print('labels_completion', _label.cpu().numpy())
                            print('preds_completion', preds_compl.cpu().numpy())
                    
                    if classify and phase in ['train', 'valid_class']:
                        loss_class = criterion_class(outputs_class, labels_class.long())
                    if complete and phase in ['train', 'valid_compl']:
                        loss_compl = criterion_compl(outputs_compl, labels_compl)

                    if classify and complete and phase == 'train':
                        loss = loss_class + loss_compl
                    elif classify and phase in ['train', 'valid_class']:
                        loss = loss_class
                    elif complete and phase in ['train', 'valid_compl']:
                        loss = loss_compl

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        #torch.nn.utils.clip_grad_norm_(model.parameters(), 1) # gradient clipping
                        optimizer.step()

                if idx % 100 == 0 and phase == 'train':
                    log_str = f'    {phase} {idx * 100 // len(dataloaders[dataset_name]):3d}% of an epoch | '
                    if classify and phase in ['train', 'valid_class']:
                        log_str += f'Loss(classif.): {loss_class.item():.4f} | '
                    if complete and phase in ['train', 'valid_compl']:
                        log_str += f'Loss(complet.): {loss_compl.item():.4f} | '
                    print(log_str)

                # statistics
                if classify and phase in ['train', 'valid_class']: # for F1 score & accuracy
                    running_loss_class += loss_class.item() * batch_size
                    running_labels_class = _concatenate(running_labels_class, labels_class)
                    running_preds_class = _concatenate(running_preds_class, preds_class)
                    running_corrects_class += torch.sum(preds_class == labels_class.data)
                if complete and phase in ['train', 'valid_compl']: # for accuracy
                    running_loss_compl += loss_compl.item() * batch_size
                    running_corrects_compl += torch.sum(preds_compl.detach().cpu() == _label)


            epoch_loss = 0.
            log_str = f'{phase.upper()} | '
            if classify and phase in ['train', 'valid_class']:
                epoch_loss_class = running_loss_class / dataset_sizes[dataset_name]
                epoch_loss += epoch_loss_class
                running_labels_class = torch.from_numpy(running_labels_class)
                running_preds_class = torch.from_numpy(running_preds_class)
                epoch_macro_f1 = f1_score(running_labels_class, running_preds_class, average='macro')  # classification: f1 scores.
                epoch_micro_f1 = f1_score(running_labels_class, running_preds_class, average='micro')
                epoch_acc_class = running_corrects_class / dataset_sizes[dataset_name]
                log_str += f'Loss(classif.): {epoch_loss_class:.3f} Acc(classif.): {epoch_acc_class:.3f} Macro-F1: {epoch_macro_f1:.3f} Micro-F1: {epoch_micro_f1:.3f} | '
            if complete and phase in ['train', 'valid_compl']:
                epoch_loss_compl = running_loss_compl / dataset_sizes[dataset_name]
                epoch_loss += epoch_loss_compl
                epoch_acc_compl = running_corrects_compl / dataset_sizes[dataset_name]  # completion task: accuracy.
                log_str += f'Loss(complet.): {epoch_loss_compl:.3f} Acc(complet.): {epoch_acc_compl:.3f} | '
            print(log_str)
            
            if phase == 'train':
                train_loss = epoch_loss
                if classify:
                    train_macro_f1 = epoch_macro_f1
                    train_micro_f1 = epoch_micro_f1
                # if wandb_log:
                #     wandb.watch(model)
            elif 'val' in phase:
                val_loss += epoch_loss
                if classify and phase == 'valid_class':
                    val_macro_f1 = epoch_macro_f1
                    val_micro_f1 = epoch_micro_f1
            
        
        if classify and not complete:
            scheduler.step(-val_micro_f1)  # because schedular's mode == 'min'
            # deep copy the model
            if val_micro_f1 > best_micro_f1:
                best_micro_f1 = val_micro_f1
                best_loss = val_loss  # Actually, it is not the best_loss, but just save it.
                best_model_wts = deepcopy(model.state_dict())
                if early_stop_patience is not None:
                    patience_cnt = 0
            elif early_stop_patience is not None:
                patience_cnt += 1
        else:
            scheduler.step(val_loss)
            # deep copy the model
            if val_loss < best_loss:
                best_loss = val_loss
                best_model_wts = deepcopy(model.state_dict())
                if early_stop_patience is not None:
                    patience_cnt = 0
            elif early_stop_patience is not None:
                patience_cnt += 1
            

        """
        if wandb_log:
            wandb.log({'train_loss': train_loss,
                       'val_loss': val_loss,
                       'train_macro_f1': train_macro_f1,
                       'train_micro_f1': train_micro_f1,
                       'val_macro_f1': val_macro_f1,
                       'val_micro_f1': val_micro_f1,
                       'best_val_loss': best_loss,
                       'learning_rate': optimizer.param_groups[0]['lr']})
                                        # scheduler.get_last_lr()[0] for CosineAnnealingWarmRestarts
        """
        if early_stop_patience is not None:
            if patience_cnt > early_stop_patience:
                print(f'Early stop at epoch {epoch}.')
                break
        

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    print('Best val Loss: {:4f}'.format(best_loss))
    if classify and not complete:
        print('Best micro f1 score: {:4f}'.format(best_micro_f1))
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, best_loss

## Experiment

In [15]:
def experiment(dim_embedding=256,
               dim_hidden=128,
               dropout=0.5,
               subset_length=None,
               decoder_mode='attention',
               batch_size=16,
               n_epochs=50,
               lr=1e-3,
               step_size=10,  # training scheduler
               seed=0,
               classify=True,
               complete=True,
               freeze_classify=False,
               freeze_complete=False,
               pretrained_model_path=None
               ):
    """Build a CCNet, train it with `train`, and save the best weights to ./weights/.

    Relies on the notebook-level names `recipe_datasets`, `CCNet` and `train`.

    Args:
        dim_embedding, dim_hidden, dropout, decoder_mode: forwarded to the CCNet constructor.
        subset_length: if not None, only the first `subset_length` samples of every
            dataset are used (clamped to the dataset size); None uses everything.
        batch_size: batch size of every DataLoader.
        n_epochs: forwarded to `train` as `num_epochs`.
        lr: learning rate of the AdamW optimizer.
        step_size: used as `patience` of the ReduceLROnPlateau scheduler.
        seed: seeds numpy and torch, and is forwarded to `train` as `random_seed`.
        classify, complete: which tasks to train; forwarded to CCNet and `train`.
            When only `classify` is set the 'train_class' dataset is used,
            otherwise 'train_compl'.
        freeze_classify, freeze_complete: forwarded to the CCNet constructor.
        pretrained_model_path: optional path to a saved state dict; entries whose keys
            also exist in the new model are copied into it before training.

    Returns:
        None. Prints progress information and writes a '.pt' checkpoint whose file name
        encodes the hyperparameters and the best loss reported by `train`.
    """
    np.random.seed(seed)
    torch.random.manual_seed(seed)

    train_data_name = 'train_class' if classify and not complete else 'train_compl'
    dataset_names = [train_data_name, 'valid_class', 'valid_compl']

    # Clamp the requested subset size so Subset never receives out-of-range indices.
    subset_indices = {}
    for x in dataset_names:
        n_total = len(recipe_datasets[x])
        n_used = n_total if subset_length is None else min(subset_length, n_total)
        subset_indices[x] = list(range(n_used))

    dataloaders = {x: DataLoader(Subset(recipe_datasets[x], subset_indices[x]),
                                 batch_size=batch_size, shuffle=True) for x in dataset_names}
    dataset_sizes = {x: len(subset_indices[x]) for x in dataset_names}
    print(dataset_sizes)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('device: ', device)

    # Peek at one training batch: its width determines `num_items` of the model.
    bin_inputs, int_inputs, bin_labels, int_labels = next(iter(dataloaders[train_data_name]))
    print('inputs.shape', bin_inputs.shape, int_inputs.shape)
    print('labels.shape', bin_labels.shape, int_labels.shape)

    model_ft = CCNet(dim_embedding=dim_embedding, dim_output=20, dim_hidden=dim_hidden,
                     num_items=len(bin_inputs[0]), num_enc_layers=4, num_dec_layers=2, ln=True, dropout=dropout,
                     decoder_mode=decoder_mode, classify=classify, complete=complete,
                     freeze_classify=freeze_classify, freeze_complete=freeze_complete).to(device)
    if pretrained_model_path is not None:
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
        pretrained_dict = torch.load(pretrained_model_path, map_location=device)
        model_dict = model_ft.state_dict()

        # 1. keep only the entries that exist in the freshly built model
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # 2. overwrite those entries in the model's own state dict
        model_dict.update(pretrained_dict)
        # 3. load the merged dict (full key set) into the model being trained
        model_ft.load_state_dict(model_dict)

    # Count trainable parameters; keying by data_ptr counts shared tensors only once.
    total_params = sum(dict((p.data_ptr(), p.numel()) for p in model_ft.parameters() if p.requires_grad).values())
    print("Total Number of Parameters", total_params)

    # Only parameters that are not frozen are handed to the optimizer.
    optimizer = optim.AdamW([p for p in model_ft.parameters() if p.requires_grad],
                            lr=lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.2)
    exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=step_size,
                                                      eps=1e-08, verbose=True)

    model_ft, best_loss = train(model_ft, dataloaders,
                                optimizer, exp_lr_scheduler,
                                dataset_sizes, device=device, num_epochs=n_epochs, early_stop_patience=20,
                                classify=classify, complete=complete, random_seed=seed)

    fname = ['ckpt', 'CCNet', 'class', str(classify), 'compl', str(complete), 'best_loss', f'{best_loss:.4f}',
             'dim_embedding', str(dim_embedding), 'batch', str(batch_size), 'n_epochs', str(n_epochs),
             'lr', str(lr), 'step_size', str(step_size), 'seed', str(seed), 'subset_length', str(subset_length)]
    if complete:
        # Record the decoder variant so runs with different decoders get distinct files.
        fname += ['decoder_mode', decoder_mode]
    fname = '_'.join(fname) + '.pt'
    os.makedirs('./weights/', exist_ok=True)
    torch.save(model_ft.state_dict(), os.path.join('./weights/', fname))

In [None]:
# Experiment 3-1: joint classification + completion, 'concat' decoder.
experiment(
    dim_embedding=256,
    dim_hidden=256,
    dropout=0.5,
    decoder_mode='concat',
    batch_size=32,
    n_epochs=100,
    lr=1e-3,
)

In [None]:
# Experiment 3-2: joint classification + completion, 'attention' decoder.
experiment(
    dim_embedding=256,
    dim_hidden=256,
    dropout=0.5,
    decoder_mode='attention',
    batch_size=32,
    n_epochs=100,
    lr=1e-3,
)

In [None]:
# Experiment 2-1: completion task only (classification disabled), 'concat' decoder.
experiment(
    dim_embedding=256,
    dim_hidden=256,
    decoder_mode='concat',
    batch_size=32,
    n_epochs=100,
    lr=1e-4,
    classify=False,
)

In [None]:
# Experiment 2-2: completion task only (classification disabled), 'attention' decoder.
experiment(
    dim_embedding=256,
    dim_hidden=256,
    decoder_mode='attention',
    batch_size=32,
    n_epochs=100,
    lr=1e-4,
    classify=False,
)

In [None]:
# Experiment 1: classification task only (completion disabled).
experiment(
    dim_embedding=256,
    dim_hidden=256,
    dropout=0.5,
    batch_size=128,
    n_epochs=100,
    lr=1e-3,
    complete=False,
)

In [None]:
# Timing reference: measured without gradient clipping (GPU: 3090).
# Classification only, batch size 2048.
experiment(
    dim_embedding=256,
    dim_hidden=256,
    dropout=0.5,
    batch_size=2048,
    n_epochs=100,
    lr=1e-3,
    complete=False,
)

In [None]:
# Classification only, batch size 1024.
experiment(
    dim_embedding=256,
    dim_hidden=256,
    dropout=0.5,
    batch_size=1024,
    n_epochs=100,
    lr=1e-3,
    complete=False,
)

In [None]:
# Classification only, batch size 512.
experiment(
    dim_embedding=256,
    dim_hidden=256,
    dropout=0.5,
    batch_size=512,
    n_epochs=100,
    lr=1e-3,
    complete=False,
)

In [None]:
# Classification only, batch size 256.
experiment(
    dim_embedding=256,
    dim_hidden=256,
    dropout=0.5,
    batch_size=256,
    n_epochs=100,
    lr=1e-3,
    complete=False,
)

In [None]:
# Classification only, batch size 128. Recorded result: (Macro f1 0.658)
experiment(
    dim_embedding=256,
    dim_hidden=256,
    dropout=0.5,
    batch_size=128,
    n_epochs=100,
    lr=1e-3,
    complete=False,
)

In [None]:
# Classification only, batch size 64. Recorded result: best micro f1 score (Macro f1 0.661)
experiment(
    dim_embedding=256,
    dim_hidden=256,
    dropout=0.5,
    batch_size=64,
    n_epochs=100,
    lr=1e-3,
    complete=False,
)

In [None]:
# Classification only, batch size 32.
experiment(
    dim_embedding=256,
    dim_hidden=256,
    dropout=0.5,
    batch_size=32,
    n_epochs=100,
    lr=1e-3,
    complete=False,
)

{'train_class': 23547, 'valid_class': 7848, 'valid_compl': 7848}
device:  cuda:0
inputs.shape torch.Size([32, 6714]) torch.Size([32, 59])
labels.shape torch.Size([32, 20]) torch.Size([32])
Total Number of Parameters 5059092
-----Training the model-----


  0%|          | 0/100 [00:00<?, ?it/s]


Epoch 1/100
labels_classification [ 9 16  9  5 13 11  9  3  9 16 11 16  3 18  9 14  7 13 11 16 13 13  5 16
  3 13  9  7 13  5  5  5]
preds_classification [ 1  5  3 11  3 13  0 11 10 13  1  1  0 17  3  1  0  1 13  1 10  3  3  0
  0  3 13 11 13  0  3  3]
    train   0% of an epoch | Loss(classif.): 3.6449 | 
    train  13% of an epoch | Loss(classif.): 2.6280 | 
    train  27% of an epoch | Loss(classif.): 2.8409 | 
    train  40% of an epoch | Loss(classif.): 2.2156 | 
    train  54% of an epoch | Loss(classif.): 2.5720 | 
    train  67% of an epoch | Loss(classif.): 2.0802 | 
    train  81% of an epoch | Loss(classif.): 1.9998 | 
    train  95% of an epoch | Loss(classif.): 2.0518 | 
TRAIN | Loss(classif.): 2.356 Acc(classif.): 0.497 Macro-F1: 0.223 Micro-F1: 0.497 | 
labels_classification [ 0 16  5  6 16  9  7 13 19  3  5  9 14  2  5  8  7 14  9  7  0  1  9 11
 11 18  7  9 11  3 13  9]
preds_classification [ 9 16 16  9 16  9  7 13 18 16  9  9 14  2  9  5  7  7  2  7 13 16  9  7
 11 1

  1%|          | 1/100 [00:28<46:36, 28.25s/it]

VALID_CLASS | Loss(classif.): 2.081 Acc(classif.): 0.569 Macro-F1: 0.299 Micro-F1: 0.569 | 

Epoch 2/100
labels_classification [14  7  9 13  9  8  9  9  3  8  3 10 11 13  9  9 13 19 13 16  7 13 16 13
 17  7 13 17  0 16 17  5]
preds_classification [14  7  9 13  9  9  9  9 12  9 11 18  3 13  9  9 13 16 13 16 18 13 16 13
  9  7 13  9 18 16  5  9]
    train   0% of an epoch | Loss(classif.): 1.8286 | 
    train  13% of an epoch | Loss(classif.): 1.7182 | 
    train  27% of an epoch | Loss(classif.): 1.7571 | 
    train  40% of an epoch | Loss(classif.): 1.9555 | 
    train  54% of an epoch | Loss(classif.): 1.7283 | 
    train  67% of an epoch | Loss(classif.): 2.2410 | 
    train  81% of an epoch | Loss(classif.): 1.8258 | 


In [None]:
# Classification only, batch size 16. Recorded result: best loss (Macro f1 0.652)
experiment(
    dim_embedding=256,
    dim_hidden=256,
    dropout=0.5,
    batch_size=16,
    n_epochs=100,
    lr=1e-3,
    complete=False,
)