In [9]:
import warnings
warnings.filterwarnings(action='ignore')

import torch
import time 
from sklearn.model_selection import KFold
import torch.nn as nn

from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter

from gensim.models.phrases import Phrases
from gensim.models import Word2Vec

from src.train_utils import set_seed, ModelSave, get_torch_device, EarlyStop, TrainParams
from src.metric import  multi_cls_metrics,multi_cls_log
from src.dataset.tokenizer import GensimTokenizer
from iflytek_app.dataset import MixDataset
from iflytek_app.models import TextcnnMixup
from iflytek_app.process import train_process,test_process,result_process, kfold_inference

device = get_torch_device()
set_seed()

No GPU available, using the CPU instead.


## Load Data & word2vec

In [2]:
c2v = GensimTokenizer( Word2Vec.load('./checkpoint/char_min1_win5_sg_d100'))
w2v = GensimTokenizer(Word2Vec.load('./checkpoint/phrase_min1_win5_sg_d100'))
w2v.init_vocab()
c2v.init_vocab()
phraser = Phrases.load('./checkpoint/phrase_tokenizer')
df, label2idx = train_process()

                id           l1           l2          len
count  4199.000000  4199.000000  4199.000000  4199.000000
mean   2099.000000     8.969278    37.087878    46.057156
std    1212.291219     4.576621    79.204914    79.332999
min       0.000000     2.000000     1.000000     4.000000
25%    1049.500000     5.000000     6.000000    15.000000
50%    2099.000000     8.000000    12.000000    22.000000
75%    3148.500000    12.000000    26.000000    36.000000
max    4198.000000    32.000000   946.000000   961.000000
{'14784131 14858934 14784131 14845064': 0, '14852788 14717848 15639958 15632020': 1, '14844856 14724258 14925237 14854807': 2, '14925756 15639967 14853254 14728639': 3, '14844593 14924945': 4, '15709098 14716590 14924703 14779559': 5, '14726332 14728344 14854542 14844591': 6, '14858934 15636660 15704193 14849963': 7, '15710359 14847407 14845602 14859696': 8, '14794687 14782344': 9, '15630486 15702410 14718849 15709093': 10, '15632285 15706536 14721977 14925219': 11, '147829

## Train 5-Fold Model


In [6]:
log_steps = 10
save_steps = 20
kfold=5
tp = TrainParams(
    epoch_size=30,
    lr=1e-3,
    loss_fn=nn.CrossEntropyLoss(),
    max_seq_len=1000,
    batch_size=64,
    dropout_rate=0.5,
    label_size = len(label2idx),
    vocab_size = w2v.vocab_size,
    embedding_dim = w2v.embedding_size + c2v.embedding_size,
    embedding1=c2v.embedding, 
    embedding2 =w2v.embedding,
    num_train_steps=int(df.shape[0]/kfold * (kfold-1)), 
    filter_size=70,
    kernel_size_list = [2,3,4,5],
    hidden_size = 100,
    mixup_alpha =0.1,
    early_stop_params = {
        'monitor':'f1_micro',
        'mode':'min',
        'min_delta': 0,
        'patience':3,
        'verbose':False
    },
    scheduler_params={'mode': 'max',
                     'factor': 0.3,
                     'patience': 1,
                     'verbose': True,
                     'threshold':0.0001,
                     'threshold_mode':'rel',
                     'cooldown':0,
                     'min_lr':1e-6}
)

In [11]:
kf = KFold(n_splits=kfold, shuffle=True, random_state=24)
for fold,(train_index, valid_index) in enumerate(kf.split(df)):
    train, valid = df.iloc[train_index], df.iloc[valid_index]

    train_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx, 
                               train['name'].values, train['description'].values, train['label'].values)
    valid_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx, 
                               valid['name'].values, valid['description'].values, valid['label'].values)
    train_sampler = RandomSampler(train_dataset)
    valid_sampler = SequentialSampler(valid_dataset)
    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=tp.batch_size)
    valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=tp.batch_size)
    
    CKPT = './checkpoint/textcnn_mixup/k{}'.format(fold)
    saver = ModelSave(CKPT, continue_train=False)
    tb = SummaryWriter(CKPT)
    es = EarlyStop(**tp.early_stop_params)
    global_step = 0
    saver.init()
    model = TextcnnMixup(tp)
    optimizer, scheduler = model.get_optimizer()
    
    
    for epoch_i in range(tp['epoch_size']):
        if global_step==1:
            print(model)
        print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10}  | {'Elapsed':^9}")
        print("-"*60)

        # Measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()
        total_loss, batch_loss, batch_counts = 0, 0, 0

        model.train()
        for step, batch in enumerate(train_loader):
            global_step +=1
            batch_counts +=1

            #Forward propogate
            model.zero_grad()
            feature = {k:v.to(device) for k, v in batch.items()}
            logits = model(feature)
            loss = model.compute_loss(feature, logits)
            batch_loss += loss.item()
            total_loss += loss.item()
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            # Log steps for train loss logging
            if (step % log_steps == 0 and step != 0) or (step == len(train_loader) - 1):
                time_elapsed = time.time() - t0_batch
                print(f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^9} | {time_elapsed:^9.2f}")
                tb.add_scalar('loss/batch_train', batch_loss / batch_counts, global_step=global_step)
                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()

            # Save steps for ckpt saving and dev evaluation
            if (step % save_steps == 0 and step != 0) or (step == len(train_loader) - 1):
                val_metrics = multi_cls_metrics(model, valid_loader, device)
                for key, val in val_metrics.items():
                    tb.add_scalar(f'metric/{key}', val, global_step=global_step)
                avg_train_loss = total_loss / step
                tb.add_scalars('loss/train_valid',{'train_loss': avg_train_loss,
                                                    'valid_loss': val_metrics['val_loss']}, global_step=global_step)
                saver(total_loss / step, val_metrics['val_loss'], epoch_i, global_step, model, optimizer, scheduler)

        # On Epoch End: calcualte train & valid loss and log overall metrics
        time_elapsed = time.time() - t0_epoch
        val_metrics = multi_cls_metrics(model, valid_loader, device)
        avg_train_loss = total_loss / step
        scheduler.step(val_metrics['f1_micro'])
        print("-"*70)
        print(f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_metrics['val_loss']:^10.6f} | {time_elapsed:^9.2f}")
        multi_cls_log(epoch_i, val_metrics)
        print("\n")
        if es.check(val_metrics):
            break 

./checkpoint/textcnn_mixup/k0 model cleaned
 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   1    |   10    |   3.750995   |     -     |   17.47  
   1    |   20    |   2.400598   |     -     |   19.96  
   1    |   30    |   2.677618   |     -     |   31.05  
   1    |   40    |   2.351341   |     -     |   19.94  
   1    |   50    |   2.210276   |     -     |   27.66  
   1    |   52    |   2.049949   |     -     |   2.71   
----------------------------------------------------------------------
   1    |    -    |   2.726138   |  2.073243  |  128.18  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   1    |  19.224%  |  72.335%  |  27.123%  |     20.333%     |   19.224%    |  43.214%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
----

   8    |   30    |   0.443280   |     -     |   32.01  
   8    |   40    |   0.544086   |     -     |   20.83  
   8    |   50    |   0.428949   |     -     |   28.72  
   8    |   52    |   0.952180   |     -     |   2.49   
Epoch     8: reducing learning rate of group 0 to 3.0000e-04.
----------------------------------------------------------------------
   8    |    -    |   0.668524   |  1.277239  |  136.76  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   8    |  55.538%  |  88.383%  |  61.297%  |     66.158%     |   55.538%    |  69.643%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   8    |  69.643%  |     -     |     -     |     69.643%     |   69.643%    |  69.643%  




 Epoch  |  Batch  |  Train Loss  | 

  15    |   40    |   0.694226   |     -     |   21.09  
  15    |   50    |   0.509203   |     -     |   29.46  
  15    |   52    |   0.476147   |     -     |   2.82   
----------------------------------------------------------------------
  15    |    -    |   0.495358   |  1.078441  |  134.47  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  15    |  55.417%  |  88.609%  |  61.110%  |     68.364%     |   55.417%    |  72.500%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  15    |  72.500%  |     -     |     -     |     72.500%     |   72.500%    |  72.500%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  16    |   10    |   0.750155   |

   2    |   50    |   1.708335   |     -     |   29.31  
   2    |   52    |   1.492835   |     -     |   3.01   
----------------------------------------------------------------------
   2    |    -    |   2.228519   |  1.643675  |  139.79  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   2    |  30.790%  |  74.595%  |  41.277%  |     37.204%     |   30.790%    |  52.143%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   2    |  52.143%  |     -     |     -     |     52.143%     |   52.143%    |  52.143%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   3    |   10    |   1.596376   |     -     |   22.63  
   3    |   20    |   1.646244   |

  10    |   10    |   0.786507   |     -     |   18.35  
  10    |   20    |   0.479596   |     -     |   17.55  
  10    |   30    |   0.468079   |     -     |   24.62  
  10    |   40    |   0.215438   |     -     |   16.62  
  10    |   50    |   0.374277   |     -     |   25.11  
  10    |   52    |   0.743298   |     -     |   2.86   
----------------------------------------------------------------------
  10    |    -    |   0.490617   |  0.972401  |  113.29  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  10    |  63.471%  |  83.411%  |  70.724%  |     69.543%     |   63.471%    |  76.786%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  10    |  76.786%  |     -     |     -     |     76.786%     |   76.786%    

  17    |   20    |   0.337230   |     -     |   19.68  
  17    |   30    |   0.252303   |     -     |   30.53  
  17    |   40    |   0.139471   |     -     |   23.50  
  17    |   50    |   0.454724   |     -     |   27.46  
  17    |   52    |   0.137716   |     -     |   2.75   
----------------------------------------------------------------------
  17    |    -    |   0.350381   |  0.928456  |  133.49  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  17    |  63.050%  |  83.512%  |  70.679%  |     69.283%     |   63.050%    |  76.786%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  17    |  76.786%  |     -     |     -     |     76.786%     |   76.786%    |  76.786%  




 Epoch  |  Batch  |  Train Loss  |  Val 

  24    |   40    |   0.694236   |     -     |   21.84  
  24    |   50    |   0.270257   |     -     |   32.23  
  24    |   52    |   0.143229   |     -     |   3.33   
----------------------------------------------------------------------
  24    |    -    |   0.493552   |  1.048002  |  146.74  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  24    |  62.417%  |  83.235%  |  69.433%  |     69.776%     |   62.417%    |  75.476%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  24    |  75.476%  |     -     |     -     |     75.476%     |   75.476%    |  75.476%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  25    |   10    |   0.657373   |

   4    |   10    |   1.244742   |     -     |   20.69  
   4    |   20    |   1.468969   |     -     |   14.49  
   4    |   30    |   1.157276   |     -     |   25.98  
   4    |   40    |   1.136507   |     -     |   18.44  
   4    |   50    |   1.090475   |     -     |   28.67  
   4    |   52    |   0.734442   |     -     |   3.03   
----------------------------------------------------------------------
   4    |    -    |   1.224871   |  1.243108  |  121.28  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   4    |  47.178%  |  86.933%  |  56.088%  |     55.535%     |   47.178%    |  67.381%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   4    |  67.381%  |     -     |     -     |     67.381%     |   67.381%    

  11    |   52    |   0.815602   |     -     |   2.89   
Epoch    11: reducing learning rate of group 0 to 3.0000e-04.
----------------------------------------------------------------------
  11    |    -    |   0.581700   |  0.991155  |  133.40  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  11    |  57.483%  |  88.669%  |  66.405%  |     71.467%     |   57.483%    |  74.048%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  11    |  74.048%  |     -     |     -     |     74.048%     |   74.048%    |  74.048%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  12    |   10    |   0.628084   |     -     |   22.17  
  12    |   20    |   0.53529

----------------------------------------------------------------------
  18    |    -    |   0.335534   |  1.191964  |  128.18  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  18    |  56.454%  |  87.510%  |  66.183%  |     70.641%     |   56.454%    |  73.690%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  18    |  73.690%  |     -     |     -     |     73.690%     |   73.690%    |  73.690%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  19    |   10    |   0.617525   |     -     |   15.44  
  19    |   20    |   0.415227   |     -     |   17.13  
  19    |   30    |   0.473104   |     -     |   28.13  
  19    |   40    |   0.389882   |

   3    |   10    |   1.727620   |     -     |   18.75  
   3    |   20    |   1.693853   |     -     |   17.35  
   3    |   30    |   1.516016   |     -     |   24.57  
   3    |   40    |   1.410744   |     -     |   16.49  
   3    |   50    |   1.303544   |     -     |   25.10  
   3    |   52    |   1.232341   |     -     |   2.73   
----------------------------------------------------------------------
   3    |    -    |   1.552117   |  1.480988  |  113.03  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   3    |  35.789%  |  83.771%  |  48.473%  |     52.148%     |   35.789%    |  60.119%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   3    |  60.119%  |     -     |     -     |     60.119%     |   60.119%    

  10    |   52    |   0.231931   |     -     |   1.82   
----------------------------------------------------------------------
  10    |    -    |   0.627079   |  1.074976  |   78.63  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  10    |  59.686%  |  87.482%  |  66.698%  |     71.003%     |   59.686%    |  75.714%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  10    |  75.714%  |     -     |     -     |     75.714%     |   75.714%    |  75.714%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  11    |   10    |   0.660228   |     -     |   13.12  
  11    |   20    |   0.546526   |     -     |   12.36  
  11    |   30    |   0.651862   |

----------------------------------------------------------------------
  17    |    -    |   0.450881   |  1.115088  |  237.12  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  17    |  57.102%  |  87.100%  |  63.544%  |     66.478%     |   57.102%    |  73.333%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  17    |  73.333%  |     -     |     -     |     73.333%     |   73.333%    |  73.333%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  18    |   10    |   0.486846   |     -     |   12.80  
  18    |   20    |   0.407192   |     -     |   11.33  
  18    |   30    |   0.391708   |     -     |   17.02  
  18    |   40    |   0.473954   |

  25    |   10    |   0.531181   |     -     |   13.12  
  25    |   20    |   0.393668   |     -     |   11.89  
  25    |   30    |   0.458273   |     -     |   17.44  
  25    |   40    |   0.381100   |     -     |   11.19  
  25    |   50    |   0.244401   |     -     |   11.01  
  25    |   52    |   0.055004   |     -     |   1.15   
----------------------------------------------------------------------
  25    |    -    |   0.398604   |  0.992846  |   69.28  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  25    |  59.588%  |  87.477%  |  65.991%  |     68.424%     |   59.588%    |  76.310%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  25    |  76.310%  |     -     |     -     |     76.310%     |   76.310%    

   2    |   50    |   1.719295   |     -     |   11.12  
   2    |   52    |   1.335518   |     -     |   1.14   
----------------------------------------------------------------------
   2    |    -    |   2.214214   |  1.573619  |   50.55  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   2    |  28.539%  |  83.229%  |  39.028%  |     34.746%     |   28.539%    |  52.205%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   2    |  52.205%  |     -     |     -     |     52.205%     |   52.205%    |  52.205%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   3    |   10    |   2.058890   |     -     |   8.43   
   3    |   20    |   1.899668   |

  10    |   10    |   0.725504   |     -     |   8.46   
  10    |   20    |   0.774765   |     -     |   7.69   
  10    |   30    |   0.545833   |     -     |   11.16  
  10    |   40    |   0.436577   |     -     |   7.68   
  10    |   50    |   0.347584   |     -     |   11.04  
  10    |   52    |   0.612045   |     -     |   1.15   
----------------------------------------------------------------------
  10    |    -    |   0.581773   |  1.021743  |   50.63  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  10    |  54.936%  |  89.069%  |  61.174%  |     58.040%     |   54.936%    |  73.659%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  10    |  73.659%  |     -     |     -     |     73.659%     |   73.659%    

  17    |   30    |   0.341398   |     -     |   10.10  
  17    |   40    |   0.441538   |     -     |   6.76   
  17    |   50    |   0.405032   |     -     |   9.74   
  17    |   52    |   0.244350   |     -     |   1.03   
Epoch    17: reducing learning rate of group 0 to 8.1000e-06.
----------------------------------------------------------------------
  17    |    -    |   0.478960   |  1.093106  |   45.54  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  17    |  54.795%  |  88.850%  |  61.591%  |     58.184%     |   54.795%    |  73.063%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  17    |  73.063%  |     -     |     -     |     73.063%     |   73.063%    |  73.063%  




 Epoch  |  Batch  |  Train Loss  | 

  24    |   40    |   0.455788   |     -     |   7.61   
  24    |   50    |   0.512639   |     -     |   11.08  
  24    |   52    |   0.157431   |     -     |   1.15   
----------------------------------------------------------------------
  24    |    -    |   0.474459   |  1.048282  |   50.60  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  24    |  54.508%  |  89.062%  |  61.948%  |     57.907%     |   54.508%    |  74.255%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  24    |  74.255%  |     -     |     -     |     74.255%     |   74.255%    |  74.255%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  25    |   10    |   0.621446   |

In [7]:
test = test_process()
test_dataset = test_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx,
                                         test['name'].values, test['description'].values)
test_sampler = SequentialSampler(test_dataset)
test_loader = DataLoader(test_dataset, sampler=test_sampler, batch_size=tp.batch_size)

In [19]:
result = kfold_inference(test_loader, tp,TextcnnMixup, './checkpoint/textcnn_mixup', 5, device)
result['pred'] = result['pred_avg']
result_process(result, label2idx, './submit/textcnn_mixup_5fold_avg.csv')