In [None]:
# When h5py does not work (Homebrew setup):
#!brew reinstall hdf5
#!export CPATH="/opt/homebrew/include/"
#!export HDF5_DIR=/opt/homebrew/
#!python3 -m pip install h5py

In [2]:
import os
import pickle
import math
from copy import deepcopy

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset


In [3]:
# Base locations used by the notebook; both are relative paths.
path_root, path_container = '../', './Container/'

### Loading Datasets

In [4]:
# Read the four pickled lookup tables from the container directory, in the
# same order as before, then bind each one to its module-level name.
_lookup_tables = {}
for _file_name in ('id_cuisine_dict.pickle',
                   'cuisine_id_dict.pickle',
                   'id_ingredient_dict.pickle',
                   'ingredient_id_dict.pickle'):
    with open(path_container + _file_name, 'rb') as f:
        _lookup_tables[_file_name] = pickle.load(f)

id_cuisine_dict = _lookup_tables['id_cuisine_dict.pickle']
cuisine_id_dict = _lookup_tables['cuisine_id_dict.pickle']
id_ingredient_dict = _lookup_tables['id_ingredient_dict.pickle']
ingredient_id_dict = _lookup_tables['ingredient_id_dict.pickle']

In [5]:
class RecipeDataset(Dataset):
    """In-memory recipe dataset read from a single HDF5 file.

    The file must contain a ``bin_data`` array (one multi-hot row per recipe)
    and, unless ``test`` is true, ``int_labels`` and ``bin_labels`` arrays.
    Everything is copied into NumPy arrays at construction time, so the file
    is only open during ``__init__``.

    Args:
        data_dir: Path handed directly to ``h5py.File`` (a file path, despite
            the parameter name).
        test: When true, no label arrays are read and ``__getitem__`` returns
            ``None`` for both labels.

    Attributes:
        bin_data: Multi-hot matrix, one row per recipe and one column per
            ingredient (the original notes give (23547, 6714) for training).
        int_labels, bin_labels: Label arrays; only set when ``test`` is false.
        padding_idx: Index used for padding; equals the number of columns of
            ``bin_data``, i.e. one past the largest valid ingredient index.
        max_num_ingredients_per_recipe: Largest row sum of ``bin_data``
            (the original notes mention 59 or 65 depending on the split).
        int_data: Integer matrix (num_recipes, max_num_ingredients_per_recipe)
            holding each recipe's ingredient indices, right-padded with
            ``padding_idx``.
    """

    def __init__(self, data_dir, test=False):
        self.data_dir = data_dir
        self.test = test

        # Read inside a context manager so the handle is released even when a
        # dataset is missing. The handle is deliberately not stored on the
        # instance: it would be closed anyway, and h5py handles cannot be
        # pickled (which multi-process data loading may need).
        with h5py.File(data_dir, 'r') as data_file:
            self.bin_data = data_file['bin_data'][:]
            if not test:
                self.int_labels = data_file['int_labels'][:]  # per the original notes: cuisine ids
                self.bin_labels = data_file['bin_labels'][:]  # per the original notes: one-hot cuisines

        self.padding_idx = self.bin_data.shape[1]
        # Cast to a Python int so it is a valid array dimension for any dtype
        # of bin_data (a float-typed matrix would otherwise break np.full).
        self.max_num_ingredients_per_recipe = int(self.bin_data.sum(1).max())

        # Put each recipe's ingredient indices in a fixed-length row vector
        # and fill the remaining slots with the padding index.
        self.int_data = np.full(
            (len(self.bin_data), self.max_num_ingredients_per_recipe),
            self.padding_idx)
        for i, bin_recipe in enumerate(self.bin_data):
            recipe = np.flatnonzero(bin_recipe == 1)
            self.int_data[i, :len(recipe)] = recipe

    def __len__(self):
        """Return the number of recipes."""
        return len(self.bin_data)

    def __getitem__(self, idx):
        """Return ``(bin_data, int_data, bin_label, int_label)`` for ``idx``.

        Both labels are ``None`` when the dataset was built with ``test=True``.
        """
        bin_data = self.bin_data[idx]
        int_data = self.int_data[idx]
        bin_label = None if self.test else self.bin_labels[idx]
        int_label = None if self.test else self.int_labels[idx]
        return bin_data, int_data, bin_label, int_label

In [6]:
# Names of the dataset splits; each one is also the file name inside the
# container directory.
dataset_name = ['train', 'valid_class', 'valid_compl', 'test_class', 'test_compl']

# One RecipeDataset per split; splits whose name contains 'test' are loaded
# without labels.
recipe_datasets = dict(
    (split, RecipeDataset(os.path.join(path_container, split), test=('test' in split)))
    for split in dataset_name
)

### Model

In [None]:
# Utils : convert int_data to bin_data.

def make_one_hot(x, num_classes=None):
    """Convert padded index rows into per-row ingredient counts (multi-hot).

    Args:
        x: Integer indices as a tensor or anything ``torch.as_tensor`` accepts.
            Expected to be 2-D (batch, set_len). Inputs with fewer dimensions
            get leading dimensions added; inputs with more than two dimensions
            are squeezed first.
        num_classes: Total number of index values including the padding index
            (i.e. ``padding_idx + 1``). When ``None`` the width is inferred
            from the largest index in ``x``, which is only correct if the
            padding index actually occurs in ``x``.

    Returns:
        A long tensor of shape (batch, width - 1) counting how often each
        index occurs per row, with the last column (assumed to be the padding
        index) dropped. Returns ``False`` if ``x`` still has more than two
        dimensions after squeezing.
    """
    x = torch.as_tensor(x, dtype=torch.long)
    if x.dim() > 2:
        x = x.squeeze()
        if x.dim() > 2:
            return False
    # Checked separately (not ``elif``) because squeezing can leave fewer than
    # two dimensions, e.g. (1, 1, n) -> (n,).
    while x.dim() < 2:
        x = x.unsqueeze(0)
    width = -1 if num_classes is None else num_classes  # -1 lets one_hot infer the width
    return F.one_hot(x, num_classes=width).sum(1)[:, :-1]

In [None]:
## Building blocks of Set Transformers ##
# added masks.

class MAB(nn.Module):
    """Multihead Attention Block with an optional attention mask.

    Queries and keys are linearly projected to ``dim_V`` features, split into
    heads along the feature axis, and combined by softmax attention. The
    projected queries are added back as a residual, followed by a residual
    feed-forward layer and optional layer norms.

    Args:
        dim_Q: Feature size of the query input.
        dim_K: Feature size of the key input (values are projected from the
            same input as the keys).
        dim_V: Output feature size.
        num_heads: Number of heads; the per-head width is
            ``dim_V // num_heads``.
        ln: When true, apply LayerNorm after the attention residual and after
            the feed-forward residual.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Sequential(
            nn.Linear(dim_V, 2*dim_V),
            nn.ReLU(),
            nn.Linear(2*dim_V, dim_V))

    def forward(self, Q, K, mask=None):
        """Attend from ``Q`` over ``K``.

        Args:
            Q: Queries, (batch, q_len, dim_Q).
            K: Keys, (batch, k_len, dim_K).
            mask: Optional boolean tensor broadcastable to
                (batch * num_heads, q_len, k_len); ``True`` entries are
                excluded from attention. Rows along the first axis are
                ordered head-major, matching the head split below.

        Returns:
            Tensor of shape (batch, q_len, dim_V).
        """
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        dim_split = self.dim_V // self.num_heads

        # Split the feature axis into heads and stack the heads along the
        # batch axis: (batch * num_heads, len, dim_V // num_heads).
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        # energy: (batch * num_heads, q_len, k_len).
        # Note the scaling uses the full width dim_V, not the per-head width.
        energy = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)

        fully_masked = None
        if mask is not None:
            energy = energy.masked_fill(mask, float('-inf'))
            # A row in which every key is masked would make softmax return
            # NaN (and poison the gradients). Give such rows finite energies
            # here and zero their attention weights after the softmax.
            fully_masked = mask.all(dim=-1, keepdim=True)
            energy = energy.masked_fill(fully_masked, 0.0)
        A = torch.softmax(energy, 2)
        if fully_masked is not None:
            A = A.masked_fill(fully_masked, 0.0)

        # Undo the head stacking: (batch, q_len, dim_V).
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + self.fc_o(O)
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O

class SAB(nn.Module):
    """Set Attention Block: a single MAB in which the set attends to itself.

    Args:
        dim_in: Feature size of the input set.
        dim_out: Feature size of the output.
        num_heads: Number of attention heads of the wrapped MAB.
        ln: Forwarded to the wrapped MAB.
    """

    def __init__(self, dim_in, dim_out, num_heads, ln=False):
        super().__init__()
        # Queries and keys both come from the same input, hence dim_in twice.
        self.mab = MAB(dim_Q=dim_in, dim_K=dim_in, dim_V=dim_out,
                       num_heads=num_heads, ln=ln)

    def forward(self, X, mask=None):
        """Run self-attention over ``X``; ``mask`` is passed through to the MAB."""
        attended = self.mab(X, X, mask=mask)
        return attended

class ISAB(nn.Module):
    """Induced Set Attention Block.

    A learned set of ``num_inds`` inducing points first attends over the input
    (with the optional mask); the input then attends over that result.

    Args:
        dim_in: Feature size of the input set.
        dim_out: Feature size of the inducing points and of the output.
        num_heads: Number of attention heads of both MABs.
        num_inds: Number of inducing points.
        ln: Forwarded to both MABs.
    """

    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super().__init__()
        # Learned inducing points, shared across the batch.
        inducing_points = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(inducing_points)
        self.I = inducing_points
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X, mask=None):
        """Return the attended set; ``mask`` only applies to the first MAB."""
        batch_size = X.size(0)
        inducing = self.I.repeat(batch_size, 1, 1)
        # Inducing points summarise X (masked), then X reads from the summary.
        summary = self.mab0(inducing, X, mask=mask)
        return self.mab1(X, summary)

class PMA(nn.Module):
    """Pooling by Multihead Attention.

    ``num_seeds`` learned seed vectors attend over the input set, producing a
    fixed number of output vectors per batch element.

    Args:
        dim: Feature size of the input, the seeds and the output.
        num_heads: Number of attention heads of the wrapped MAB.
        num_seeds: Number of seed vectors (output set size).
        ln: Forwarded to the wrapped MAB.
    """

    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super().__init__()
        # Learned seed vectors, shared across the batch.
        seeds = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(seeds)
        self.S = seeds
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X, mask=None):
        """Pool ``X`` into ``num_seeds`` vectors; ``mask`` is passed to the MAB."""
        batch_size = X.size(0)
        queries = self.S.repeat(batch_size, 1, 1)
        return self.mab(queries, X, mask=mask)

In [None]:
class CCNet(nn.Module):
    """Set-Transformer network for cuisine classification and recipe completion.

    A recipe is a padded row of ingredient indices. The ingredients are
    embedded, encoded by a stack of ISABs, pooled by a PMA into
    ``num_outputs`` vectors and refined by SABs. Pooled slot 0 feeds the
    classification head; slot 1 is the query that attends over the ingredient
    embedding table to score every ingredient for completion.

    Args:
        dim_input: Embedding size of an ingredient.
        dim_output: Number of classes of the classification head.
        num_items: Number of ingredient indices including the padding index,
            which is ``num_items - 1``.
        num_inds: Number of inducing points per ISAB.
        dim_hidden: Hidden feature size.
        num_heads: Number of attention heads.
        num_outputs: Number of pooled vectors; forward uses slots 0 and 1.
        num_enc_layers: Number of ISAB encoder layers.
        num_dec_layers: Number of SAB layers after pooling.
        ln: Whether the attention blocks use LayerNorm.
        classify: Whether forward computes classification logits.
        complete: Whether forward computes completion logits.
    """

    def __init__(self, dim_input=256,
                 dim_output=20,
                 num_items=6714+1,
                 num_inds=32,
                 dim_hidden=128,
                 num_heads=4,
                 num_outputs=1+1,  # class 1 + compl 1
                 num_enc_layers=4,
                 num_dec_layers=2,
                 ln=False,
                 classify=True,
                 complete=True):
        super(CCNet, self).__init__()

        self.num_heads = num_heads
        self.padding_idx = num_items-1
        self.classify, self.complete = classify, complete
        # The last embedding row is the padding row (same row that
        # padding_idx=-1 selected before).
        self.embedding = nn.Embedding(num_embeddings=num_items, embedding_dim=dim_input,
                                      padding_idx=self.padding_idx)
        self.encoder = nn.ModuleList(
            [ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln)] +
            [ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln) for _ in range(num_enc_layers-1)])
        self.pooling = PMA(dim_hidden, num_heads, num_outputs, ln=ln)
        self.decoder1 = nn.Sequential(
            *[SAB(dim_hidden, dim_hidden, num_heads, ln=ln) for _ in range(num_dec_layers)])
        self.ff1 = nn.Sequential(
            nn.Linear(dim_hidden, dim_hidden),
            nn.ReLU(),
            nn.Linear(dim_hidden, dim_output))
        self.decoder2 = nn.ModuleList(
            [MAB(dim_hidden, dim_input, dim_hidden, num_heads, ln=ln) for _ in range(2)])
        self.ff2 = nn.Linear(dim_hidden, num_items-1)

    def forward(self, x):
        """Compute the enabled logits for a batch of recipes.

        Args:
            x: Ingredient indices of shape (batch, set_len), padded with
                ``self.padding_idx``; a tensor or anything ``torch.as_tensor``
                accepts.

        Returns:
            ``None`` if neither head is enabled; the classification logits
            (batch, dim_output) if only ``classify``; the completion logits
            (batch, num_items - 1) if only ``complete``; otherwise the tuple
            ``(logit_classification, logit_completion)``. Completion logits of
            ingredients already present in the recipe are ``-inf``.
        """
        if not (self.classify or self.complete):
            return None

        # Accept arrays/lists as well as tensors and place them on the
        # model's device.
        x = torch.as_tensor(x, dtype=torch.long, device=self.embedding.weight.device)
        batch_size = x.size(0)

        feature = self.embedding(x)
        # feature: (batch, set_len, dim_input)

        # Padding positions must not be attended to. The mask is tiled once
        # per head to match the head-major stacking used inside MAB.
        mask = (x == self.padding_idx).repeat(self.num_heads, 1).unsqueeze(1)
        # mask: (batch*num_heads, 1, set_len)

        code = feature
        for module in self.encoder:
            code = module(code, mask=mask)
        # code: (batch, set_len, dim_hidden), permutation-equivariant.

        pooled = self.pooling(code, mask=mask)
        # pooled: (batch, num_outputs, dim_hidden), permutation-invariant.

        signals = self.decoder1(pooled)
        # no mask; signals: (batch, num_outputs, dim_hidden)

        # Classification uses pooled slot 0. Index the slot axis (dim 1),
        # not the batch axis.
        logit_classification = None
        if self.classify:
            logit_classification = self.ff1(signals[:, 0])  # (batch, dim_output)
            if not self.complete:
                return logit_classification

        # Completion uses pooled slot 1 as a single query per recipe.
        signal_completion = signals[:, 1].unsqueeze(1)  # (batch, 1, dim_hidden)

        # Multi-hot mask of the ingredients already in each recipe. The width
        # is fixed to the vocabulary size (not inferred from the batch), and
        # the padding column is dropped afterwards.
        used = torch.zeros(batch_size, self.padding_idx + 1,
                           dtype=torch.bool, device=x.device)
        used.scatter_(1, x, True)
        bool_x = used[:, :-1]  # (batch, num_items-1)
        used_ingred_mask = bool_x.repeat(self.num_heads, 1).unsqueeze(1)
        # used_ingred_mask: (batch*num_heads, 1, num_items-1)

        # Every non-padding embedding row is a candidate ingredient; expand
        # shares the table across the batch without copying it.
        embedding_weight = self.embedding.weight[:-1].unsqueeze(0).expand(batch_size, -1, -1)
        # embedding_weight: (batch, num_items-1, dim_input)

        for module in self.decoder2:
            signal_completion = module(signal_completion, embedding_weight, mask=used_ingred_mask)

        # squeeze(1) removes only the query axis, so a batch of one keeps its
        # batch dimension.
        logit_completion = self.ff2(signal_completion.squeeze(1))  # (batch, num_items-1)
        logit_completion = logit_completion.masked_fill(bool_x, float('-inf'))
        if not self.classify:
            return logit_completion

        return logit_classification, logit_completion