In [22]:
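# If h5py fails to install/import, uncomment and run the lines below.
# Each `!` command runs in its own subshell, so `!export` would not persist; use %env instead.
#!brew reinstall hdf5
#%env CPATH=/opt/homebrew/include/
#%env HDF5_DIR=/opt/homebrew/
#!python3 -m pip install h5py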

In [23]:
# For Colab
# from google.colab import drive
# drive.mount('/content/drive')
# %cd drive/MyDrive/cuisine-prediction/Hanseul/
# !pip3 install torchmetrics

In [24]:
import os
import pickle
import math
import time
from tqdm.notebook import tqdm
from copy import deepcopy

import h5py
import numpy as np
import torch
from torch import nn, optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torchmetrics import F1Score
from torch.utils.data import Dataset, Subset, DataLoader


In [25]:
# Relative locations used by this notebook.
path_root = "../"                # presumably the project root; not referenced in the cells shown
path_container = "./Container/"  # preprocessed pickles and HDF5 splits are read from here

## Loading Datasets

In [26]:
# Id <-> name lookup tables (judging by their names: cuisines and ingredients).
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
def _load_container_pickle(file_name):
    """Unpickle and return the object stored at ``path_container + file_name``."""
    with open(path_container + file_name, 'rb') as handle:
        return pickle.load(handle)


id_cuisine_dict = _load_container_pickle('id_cuisine_dict.pickle')
cuisine_id_dict = _load_container_pickle('cuisine_id_dict.pickle')
id_ingredient_dict = _load_container_pickle('id_ingredient_dict.pickle')
ingredient_id_dict = _load_container_pickle('ingredient_id_dict.pickle')

In [27]:
class RecipeDataset(Dataset):
    """Recipes stored in an HDF5 file, exposed both as multi-hot rows and padded index rows.

    The file must contain a ``bin_data`` dataset (one multi-hot ingredient row per recipe).
    Unless ``test`` is True it must also contain ``int_labels`` and ``bin_labels``.
    Shapes quoted in comments are those observed for this project's training split.

    Items are ``(bin_data, int_data, bin_label, int_label)``.
    """

    def __init__(self, data_dir, test=False):
        self.data_dir = data_dir
        self.test = test
        with h5py.File(data_dir, 'r') as data_file:
            self.bin_data = data_file['bin_data'][:]  # e.g. (num_recipes=23547, num_ingredients=6714)
            if not test:
                self.int_labels = data_file['int_labels'][:]  # e.g. (23547,): cuisine ids
                self.bin_labels = data_file['bin_labels'][:]  # e.g. (23547, 20): one-hot cuisines

        # One past the last ingredient column (== number of ingredients, 6714) marks padding.
        self.padding_idx = self.bin_data.shape[1]
        # Longest recipe (59 for train, 65 for valid/test per the original notes).
        # int() so np.full accepts the width even when bin_data is stored as floats.
        self.max_num_ingredients_per_recipe = (
            int(self.bin_data.sum(1).max()) if len(self.bin_data) else 0)

        # (num_recipes, max_num_ingredients_per_recipe): each row lists the recipe's
        # ingredient indices in ascending order, right-padded with padding_idx.
        self.int_data = self._bin_to_index_matrix(
            self.bin_data, self.padding_idx, self.max_num_ingredients_per_recipe)

    @staticmethod
    def _bin_to_index_matrix(bin_data, padding_idx, width=None):
        """Turn multi-hot rows into fixed-width rows of ingredient indices.

        Args:
            bin_data: 2-D array of multi-hot rows (entries equal to 1 mark ingredients).
            padding_idx: fill value for unused slots.
            width: row width; defaults to the largest row sum.

        Returns:
            Integer array of shape (len(bin_data), width).
        """
        if width is None:
            width = int(bin_data.sum(1).max()) if len(bin_data) else 0
        int_data = np.full((len(bin_data), width), padding_idx)
        for i, bin_recipe in enumerate(bin_data):
            recipe = np.flatnonzero(bin_recipe == 1)
            int_data[i, :len(recipe)] = recipe
        return int_data

    def __len__(self):
        return len(self.bin_data)

    def __getitem__(self, idx):
        bin_data = self.bin_data[idx]
        int_data = self.int_data[idx]
        # NOTE(review): test splits return None labels, which torch's default_collate
        # rejects; batching a test split needs a custom collate_fn.
        bin_label = None if self.test else self.bin_labels[idx]
        int_label = None if self.test else self.int_labels[idx]

        return bin_data, int_data, bin_label, int_label

In [28]:
dataset_name = ['train', 'valid_class', 'valid_compl', 'test_class', 'test_compl']

# One dataset per split; splits whose name contains 'test' are loaded without labels.
recipe_datasets = {
    split: RecipeDataset(os.path.join(path_container, split), test=('test' in split))
    for split in dataset_name
}

In [29]:
# Sanity check: shapes of the first training item (bin_data, int_data, bin_label, int_label).
_bd, _id, _bl, _il = recipe_datasets['train'][0]
print(*(arr.shape for arr in (_bd, _id, _bl, _il)), sep='\n')

(6714,)
(59,)
(20,)
()


## Model

In [30]:
## Building blocks of Set Transformers ##
# added masks.

class MAB(nn.Module):
    """Multihead Attention Block: queries Q attend to keys/values derived from K.

    Output = LN(H + FF(H)) with H = LN(Q' + Attention(Q', K', V')), where the LayerNorms
    are present only when ``ln`` is True and dropout is applied to FF's output when
    ``dropout > 0``.

    Assumes ``dim_V`` is divisible by ``num_heads``.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False, dropout=0):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.p = dropout
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Sequential(
            nn.Linear(dim_V, dim_V),
            nn.ReLU(),
            nn.Linear(dim_V, dim_V))
        self.Dropout = nn.Dropout(p=dropout)

    def forward(self, Q, K, mask=None):
        """Attend from Q to K.

        Args:
            Q: (batch, q_len, dim_Q) queries.
            K: (batch, k_len, dim_K) keys; values are projected from K as well.
            mask: optional bool tensor (batch * num_heads, 1 or q_len, k_len), True where a
                key must be ignored. Rows are head-major (row = head * batch + b), matching
                how heads are stacked below.

        Returns:
            (batch, q_len, dim_V) tensor. A query whose keys are *all* masked receives no
            attention update (its attention weights are zero) instead of NaN.
        """
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        dim_split = self.dim_V // self.num_heads

        # Stack heads along the batch axis: (batch * num_heads, len, dim_V // num_heads).
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        # energy: (batch * num_heads, q_len, k_len), scaled by sqrt of the full width dim_V.
        energy = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)
        fully_masked = None
        if mask is not None:
            mask = mask.bool()
            # Softmax over a row of only -inf is NaN (e.g. a recipe whose single ingredient
            # was removed). Leave such rows unmasked here and zero their weights afterwards.
            fully_masked = mask.all(dim=-1, keepdim=True)
            energy = energy.masked_fill(mask & ~fully_masked, float('-inf'))
        A = torch.softmax(energy, 2)
        if fully_masked is not None:
            A = A.masked_fill(fully_masked, 0.0)

        # Re-split heads back onto the feature axis: (batch, q_len, dim_V).
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        _O = self.fc_o(O)
        if self.p > 0:
            _O = self.Dropout(_O)
        O = O + _O
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O

class SAB(nn.Module):
    """Set Attention Block: a set attends to itself through one MAB (dim_in -> dim_out)."""

    def __init__(self, dim_in, dim_out, num_heads, ln=False, dropout=0.2):
        super().__init__()
        self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln, dropout=dropout)

    def forward(self, X, mask=None):
        # The same set plays both roles; `mask` hides padded keys.
        queries = keys = X
        return self.mab(queries, keys, mask=mask)

class ISAB(nn.Module):
    """Induced Set Attention Block.

    ``num_inds`` learned inducing points first attend to the input set (honouring the
    padding mask), then every input element attends back to that induced summary.
    Maps (batch, n, dim_in) to (batch, n, dim_out).
    """

    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False, dropout=0.2):
        super().__init__()
        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln, dropout=dropout)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln, dropout=dropout)

    def forward(self, X, mask=None):
        batch_size = X.size(0)
        inducing = self.I.repeat(batch_size, 1, 1)
        # Step 1: inducing points summarise the (masked) input set.
        induced = self.mab0(inducing, X, mask=mask)
        # Step 2: inputs read from the summary; it has no padding, so no mask is needed.
        return self.mab1(X, induced)

class PMA(nn.Module):
    """Pooling by Multihead Attention.

    ``num_seeds`` learned seed vectors attend to the input set, so the output is
    (batch, num_seeds, dim) whatever the set size.
    """

    def __init__(self, dim, num_heads, num_seeds, ln=False, dropout=0.2):
        super().__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln, dropout=dropout)

    def forward(self, X, mask=None):
        # One copy of the seeds per batch element serves as the query set.
        seeds = self.S.repeat(X.size(0), 1, 1)
        return self.mab(seeds, X, mask=mask)

In [31]:
def make_one_hot(x, num_classes=None):
    """Convert padded index rows (int_data) into multi-hot rows (bin_data).

    The largest class index is treated as padding and its column is dropped.

    Args:
        x: tensor or array-like of integer indices, shape (batch, max_len) or (max_len,).
            Higher-rank input is squeezed first.
        num_classes: total number of classes *including* the padding index
            (i.e. num_items + 1). If None, it is inferred as ``x.max() + 1``; that is
            only correct when the padding index actually occurs in ``x``. Otherwise the
            last real ingredient column would be dropped.

    Returns:
        (batch, num_classes - 1) tensor of counts, or False if ``x`` still has more
        than 2 dimensions after squeezing.
    """
    if not isinstance(x, torch.Tensor):
        x = torch.as_tensor(x, dtype=torch.long)
    if x.dim() > 2:
        x = x.squeeze()
        if x.dim() > 2:
            return False
    # Squeezing (or 0/1-dim input) can leave fewer than 2 dims; restore a batch axis.
    while x.dim() < 2:
        x = x.unsqueeze(0)
    width = -1 if num_classes is None else num_classes
    return F.one_hot(x, num_classes=width).sum(1)[:, :-1]

class CCNet(nn.Module):
    """Set-Transformer model for cuisine classification and/or recipe completion.

    A recipe is a row of ingredient indices padded with ``num_items``. An ISAB encoder
    and PMA pooling turn it into ``num_outputs`` pooled vectors, which a shared SAB
    decoder (``decoder1``) refines. With both tasks enabled, the first vector feeds the
    classification head and the second feeds the completion head, so ``num_outputs``
    must be 2. With a single task it must be 1.

    ``forward`` returns ``(logit_classification, logit_completion)``; the entry for a
    disabled task is None. Completion logits of ingredients already present in the
    recipe are set to -inf.
    """

    def __init__(self, dim_embedding=256,
                 dim_output=20,
                 num_items=6714,
                 num_inds=32,
                 dim_hidden=128,
                 num_heads=4,
                 num_outputs=1+1,  # classification 1 + completion 1
                 num_enc_layers=4,
                 num_dec_layers=2,
                 ln=True,          # LayerNorm option
                 dropout=0.2,      # Dropout option
                 classify=True,    # set False for completion only
                 complete=True):   # set False for classification only

        super(CCNet, self).__init__()

        self.num_heads = num_heads
        self.padding_idx = num_items
        self.classify, self.complete = classify, complete
        # padding_idx=-1 resolves to the last row (index num_items), the padding value in int_data.
        self.embedding = nn.Embedding(num_embeddings=num_items+1, embedding_dim=dim_embedding, padding_idx=-1)
        # Alternative encoder built from plain SAB layers instead of ISAB:
        # self.encoder = nn.ModuleList(
        #     [SAB(dim_embedding, dim_hidden, num_heads, ln=ln, dropout=dropout)] +
        #     [SAB(dim_hidden, dim_hidden, num_heads, ln=ln, dropout=dropout) for _ in range(num_enc_layers-1)])
        self.encoder = nn.ModuleList(
            [ISAB(dim_embedding, dim_hidden, num_heads, num_inds, ln=ln, dropout=dropout)] +
            [ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln, dropout=dropout) for _ in range(num_enc_layers-1)])
        # NOTE(review): PMA is built without `dropout`, so it uses PMA's own default (0.2).
        self.pooling = PMA(dim_hidden, num_heads, num_outputs, ln=ln)
        if classify or complete:
            # forward() runs every pooled signal through decoder1, so both heads need it.
            self.decoder1 = nn.Sequential(
                *[SAB(dim_hidden, dim_hidden, num_heads, ln=ln, dropout=dropout) for _ in range(num_dec_layers)])
        if classify:
            self.ff1 = nn.Sequential(
                nn.Linear(dim_hidden, dim_hidden),
                nn.ReLU(),
                nn.Linear(dim_hidden, dim_output))
        if complete:
            self.decoder2 = nn.ModuleList(
                [MAB(dim_hidden, dim_embedding, dim_hidden, num_heads, ln=ln, dropout=dropout) for _ in range(num_dec_layers)])
            self.ff2 = nn.Linear(dim_hidden, num_items)

    def forward(self, x, bin_x=None):
        """Compute task logits for a batch of padded ingredient-index rows.

        Args:
            x: LongTensor (batch, max_len) of ingredient indices padded with
                ``self.padding_idx``.
            bin_x: optional (batch, num_items) multi-hot form of ``x``; derived from ``x``
                when omitted.

        Returns:
            ``(logit_classification, logit_completion)`` of shapes (batch, dim_output)
            and (batch, num_items), with None for a disabled task. Returns None when
            both tasks are disabled.

        Raises:
            ValueError: if ``num_outputs`` does not match the enabled tasks.
        """
        if not (self.classify or self.complete):
            return

        feature = self.embedding(x)  # (batch, max_len, dim_embedding)

        # True at padding positions; repeat(num_heads, 1) stacks the rows head-major,
        # matching how MAB stacks its heads: (batch*num_heads, 1, max_len).
        mask = (x == self.padding_idx).repeat(self.num_heads, 1).unsqueeze(1)

        code = feature
        for module in self.encoder:
            code = module(code, mask=mask)
        # code: (batch, max_len, dim_hidden), permutation-equivariant.

        pooled = self.pooling(code, mask=mask)  # (batch, num_outputs, dim_hidden), permutation-invariant.
        signals = self.decoder1(pooled)         # no mask needed; same shape as pooled.

        num_signals = signals.size(1)
        if num_signals == 2 and self.classify and self.complete:
            # One pooled signal per task: (batch, 1, dim_hidden) each.
            signal_classification, signal_completion = signals.chunk(2, dim=1)
        elif num_signals == 1 and not (self.classify and self.complete):
            # Exactly one task is enabled; it receives the single pooled signal.
            signal_classification = signal_completion = signals
        else:
            raise ValueError(f"num_outputs={signals.size(1)}; but classify={self.classify} and complete={self.complete}")

        logit_classification, logit_completion = None, None

        # Classification:
        if self.classify:
            logit_classification = self.ff1(signal_classification.squeeze(1))  # (batch, dim_output)

        # Completion:
        if self.complete:
            if bin_x is None:
                # An explicit num_classes keeps the width at num_items even when no row
                # of the batch contains padding.
                bin_x = F.one_hot(x, num_classes=self.padding_idx + 1).sum(1)[:, :-1]
            bool_x = (bin_x == True)  # ingredients already in the recipe

            # Masks out ingredients already used, so only candidate ingredients are
            # attended to: (batch*num_heads, 1, num_items).
            used_ingred_mask = bool_x.repeat(self.num_heads, 1).unsqueeze(1)

            # Embeddings of all real ingredients (padding row dropped), one copy per
            # batch element: (batch, num_items, dim_embedding).
            embedding_weight = self.embedding.weight[:-1].unsqueeze(0).repeat(feature.size(0), 1, 1)

            for module in self.decoder2:
                signal_completion = module(signal_completion, embedding_weight, mask=used_ingred_mask)
            # squeeze(1), not squeeze(): a batch of size 1 must keep its batch axis.
            logit_completion = self.ff2(signal_completion.squeeze(1))  # (batch, num_items)
            # An ingredient already in the recipe can never be the missing one.
            logit_completion = logit_completion.masked_fill(bool_x, float('-inf'))

        return logit_classification, logit_completion

## Training function

In [32]:
# WandB logging is disabled for now.

def _remove_random_ingredient(bin_inputs, int_inputs):
    """Hide one randomly chosen ingredient per recipe (in place) and return the hidden indices.

    Builds self-supervised completion targets for the training split.

    Args:
        bin_inputs: (batch, num_items) multi-hot tensor; the chosen entry is set to 0.
        int_inputs: (batch, max_len) padded index tensor; the chosen index is replaced by
            the padding value ``num_items`` (== ``bin_inputs.size(-1)``).

    Returns:
        LongTensor (batch,) with the removed ingredient index of each recipe.

    Raises:
        ValueError: if a recipe has no ingredient that could be removed.
    """
    padding_idx = int(bin_inputs.size(-1))
    removed = torch.zeros(bin_inputs.size(0), dtype=torch.long)
    for row in range(bin_inputs.size(0)):
        ingreds = torch.arange(padding_idx)[bin_inputs[row] == 1]
        if len(ingreds) == 0:
            raise ValueError(f'recipe {row} of the batch has no ingredient to remove for completion training.')
        # np.random is seeded by train(), so the choice is reproducible.
        hidden = ingreds[np.random.randint(len(ingreds))]
        bin_inputs[row][hidden] = 0
        int_inputs[row][int_inputs[row] == hidden] = padding_idx
        removed[row] = hidden
    return removed


def train(model,
          dataloaders,
          criterion,
          optimizer,
          scheduler,
          metrics,
          dataset_sizes,
          device='cpu',
          num_epochs=20,
          early_stop_patience=None,
          classify=True,
          complete=True,
          random_seed=1):
    """Train a CCNet-style model and return it with the best validation weights loaded.

    Each epoch runs the 'train' phase, then 'valid_class' (if ``classify``) and
    'valid_compl' (if ``complete``). During training with ``complete``, one random
    ingredient per recipe is hidden and used as the completion target. The validation
    losses of an epoch are summed; that sum drives ``scheduler.step`` (called once per
    epoch), best-model selection and early stopping.

    Args:
        model: module returning ``(logits_class, logits_compl)`` from
            ``model(int_inputs, bin_x=bin_inputs)``.
        dataloaders: phase name -> iterable of
            ``(bin_inputs, int_inputs, bin_labels, int_labels)`` batches.
        criterion: loss applied to each enabled head.
        optimizer, scheduler: optimizer and a plateau-style scheduler taking the loss.
        metrics: dict with 'macro_f1' and 'micro_f1' callables ``metric(preds, target)``.
        dataset_sizes: phase name -> number of samples (for epoch averages).
        device: device the batches are moved to.
        num_epochs: maximum number of epochs.
        early_stop_patience: stop after more than this many epochs without improvement.
        classify, complete: which tasks to train/evaluate.
        random_seed: seed for numpy and torch.

    Raises:
        ValueError: if both ``classify`` and ``complete`` are False.
        TypeError: if ``early_stop_patience`` is not an int.
    """
    assert isinstance(metrics, dict), f"'metrics' argument should be a dictionary, but {type(metrics)}."
    if not (classify or complete):
        raise ValueError('At least one of classify / complete must be True.')

    def _concatenate(running_v, new_v):
        new_np = new_v.clone().detach().cpu().numpy()
        return new_np if running_v is None else np.concatenate((running_v, new_np), axis=0)

    np.random.seed(random_seed)
    torch.random.manual_seed(random_seed)

    since = time.time()

    best_model_wts = deepcopy(model.state_dict())
    best_loss = float('inf')

    if early_stop_patience is not None:
        if not isinstance(early_stop_patience, int):
            raise TypeError('early_stop_patience should be an integer.')
        patience_cnt = 0

    print('-' * 5 + 'Training the model' + '-' * 5)
    for epoch in tqdm(range(num_epochs)):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        val_losses = []  # one epoch loss per validation phase run this epoch

        # Each epoch has a training and validation phase
        for phase in ['train', 'valid_class', 'valid_compl']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
                if not classify and phase == 'valid_class':
                    continue
                elif not complete and phase == 'valid_compl':
                    continue

            # Which heads are scored in this phase.
            eval_class = classify and phase in ('train', 'valid_class')
            eval_compl = complete and phase in ('train', 'valid_compl')

            running_loss = 0.0
            running_corrects_compl = 0.0
            running_corrects_class = 0.0
            running_labels_class = None
            running_preds_class = None

            for idx, (bin_inputs, int_inputs, bin_labels, int_labels) in enumerate(dataloaders[phase]):
                if eval_class:
                    labels_class = int_labels.to(device)
                if eval_compl:
                    if phase == 'train':
                        labels_compl = _remove_random_ingredient(bin_inputs, int_inputs).to(device)
                    else:
                        labels_compl = int_labels.to(device)
                bin_inputs = bin_inputs.to(device)
                int_inputs = int_inputs.to(device)

                optimizer.zero_grad()

                # Track history only in the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs_class, outputs_compl = model(int_inputs, bin_x=bin_inputs)  # bin_x is optional
                    if eval_class:
                        _, preds_class = torch.max(outputs_class, 1)
                    if eval_compl:
                        _, preds_compl = torch.max(outputs_compl, 1)

                    if idx == 0 and phase == 'train':
                        if eval_class:
                            print('labels_classification', labels_class.cpu().numpy())
                            print('preds_classification', preds_class.cpu().numpy())
                        if eval_compl:
                            print('labels_completion', labels_compl.cpu().numpy())
                            print('preds_completion', preds_compl.cpu().numpy())

                    # Sum of the losses of every head scored in this phase.
                    loss = None
                    if eval_class:
                        loss = criterion(outputs_class, labels_class.long())
                    if eval_compl:
                        loss_compl = criterion(outputs_compl, labels_compl.long())
                        loss = loss_compl if loss is None else loss + loss_compl

                    if phase == 'train':
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                        optimizer.step()

                if idx % 100 == 0:
                    print(f'    {phase} {idx * 100 // len(dataloaders[phase]):3d}% of an epoch | Loss: {loss.item()}')

                # statistics
                running_loss += loss.item() * bin_inputs.size(0)
                if eval_class:  # for F1 score & accuracy
                    running_labels_class = _concatenate(running_labels_class, labels_class)
                    running_preds_class = _concatenate(running_preds_class, preds_class)
                    running_corrects_class += torch.sum(preds_class == labels_class.data)
                if eval_compl:  # for accuracy
                    running_corrects_compl += torch.sum(preds_compl == labels_compl.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            log_str = f'{phase.upper()} Loss: {epoch_loss:.4f} '
            if eval_class:
                target_class = torch.from_numpy(running_labels_class)
                pred_class = torch.from_numpy(running_preds_class)
                # torchmetrics metrics are called as metric(preds, target).
                epoch_macro_f1 = metrics['macro_f1'](pred_class, target_class)
                epoch_micro_f1 = metrics['micro_f1'](pred_class, target_class)
                epoch_acc_class = running_corrects_class / dataset_sizes[phase]
                log_str += f'Acc(classif.): {epoch_acc_class:.4f} Macro-F1: {epoch_macro_f1:.4f} Micro-F1: {epoch_micro_f1:.4f} '
            if eval_compl:
                epoch_acc_compl = running_corrects_compl / dataset_sizes[phase]  # completion task: accuracy.
                log_str += f'Acc(compl.): {epoch_acc_compl:.4f} '
            print(log_str)

            if phase != 'train':
                val_losses.append(epoch_loss)

        # One scheduler step and one best-model check per epoch, on the summed validation losses.
        val_loss = sum(val_losses)
        scheduler.step(val_loss)
        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = deepcopy(model.state_dict())
            if early_stop_patience is not None:
                patience_cnt = 0
        elif early_stop_patience is not None:
            patience_cnt += 1

        if early_stop_patience is not None and patience_cnt > early_stop_patience:
            print(f'Early stop at epoch {epoch}.')
            break

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Loss: {:4f}'.format(best_loss))

    # TODO: after the last epoch, generate a confusion matrix for the validation phase.

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

## Experiment

In [33]:
def experiment(dim_embedding=256,
               dropout=0.2,
               subset_length=None,
               batch_size=16,
               n_epochs=50,
               lr=1e-3,
               step_size=10,  # patience of the ReduceLROnPlateau scheduler
               seed=0,
               classify=True,
               complete=True):
    """Train one CCNet configuration on the global ``recipe_datasets`` and save its best weights.

    Args:
        dim_embedding: ingredient embedding width.
        dropout: dropout passed to CCNet.
        subset_length: if given, use only the first ``subset_length`` items of each split
            (clamped to the split size) for quick runs.
        batch_size, n_epochs, lr: usual training hyperparameters.
        step_size: patience (in epochs) of the ReduceLROnPlateau scheduler.
        seed: random seed forwarded to ``train``.
        classify, complete: which tasks to train.

    Weights are written to ``./weights/ckpt_CCNet_batch_..._seed_<seed>.pt``.
    NOTE(review): the file name omits dim_embedding, dropout and the enabled tasks, so
    runs differing only in those settings overwrite each other.
    """
    dataset_name = ['train', 'valid_class', 'valid_compl']
    if subset_length is None:
        datasets = {x: recipe_datasets[x] for x in dataset_name}
    else:
        # Clamp so a subset longer than a split cannot index past its end.
        datasets = {x: Subset(recipe_datasets[x], range(min(subset_length, len(recipe_datasets[x]))))
                    for x in dataset_name}
    dataloaders = {x: DataLoader(datasets[x], batch_size=batch_size, shuffle=True) for x in dataset_name}
    dataset_sizes = {x: len(datasets[x]) for x in dataset_name}

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('device: ', device)

    # Peek at one training batch to infer input / label widths.
    bin_inputs, int_inputs, bin_labels, int_labels = next(iter(dataloaders['train']))
    print('inputs.shape', bin_inputs.shape, int_inputs.shape)
    print('labels.shape', bin_labels.shape, int_labels.shape)
    num_classes = len(bin_labels[0])

    model_ft = CCNet(dim_embedding=dim_embedding, dim_output=num_classes,
                     num_items=len(bin_inputs[0]), num_outputs=2 if classify and complete else 1,
                     num_enc_layers=4, num_dec_layers=2, ln=True, dropout=dropout,
                     classify=classify, complete=complete).to(device)
    # Keyed by data_ptr so parameters shared between modules are counted once.
    total_params = sum(dict((p.data_ptr(), p.numel()) for p in model_ft.parameters() if p.requires_grad).values())
    print("Total Number of Parameters", total_params)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.AdamW([p for p in model_ft.parameters() if p.requires_grad],
                            lr=lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.2)
    exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=step_size,
                                                      eps=1e-08, verbose=True)
    # F1 metrics sized to the actual number of cuisine classes.
    metrics = {'macro_f1': F1Score(num_classes=num_classes, average='macro'),
               'micro_f1': F1Score(num_classes=num_classes, average='micro')}

    model_ft = train(model_ft, dataloaders, criterion, optimizer, exp_lr_scheduler, metrics,
                     dataset_sizes, device=device, num_epochs=n_epochs, early_stop_patience=20,
                     classify=classify, complete=complete, random_seed=seed)

    fname = ['ckpt', 'CCNet', 'batch', str(batch_size),
             'n_epochs', str(n_epochs), 'lr', str(lr), 'step_size', str(step_size),
             'seed', str(seed)]
    fname = '_'.join(fname) + '.pt'
    os.makedirs('./weights/', exist_ok=True)
    torch.save(model_ft.state_dict(), os.path.join('./weights/', fname))

In [34]:
# Dry Run
#experiment(n_epochs=1, lr=1e-6, dim_embedding=64, classify=False)

In [None]:
# Full run: classification only (completion head disabled).
# dropout=0.5 is the CCNet dropout intended for this run.
experiment(complete=False,
           batch_size=16,
           n_epochs=100,
           lr=1e-3,
           dim_embedding=256,
           dropout=0.5)