In [1]:
import torch
import time 
from sklearn.model_selection import KFold

import torch.nn as nn
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from gensim.models.phrases import Phrases
from gensim.models import Word2Vec

from src.train_utils import set_seed, ModelSave, get_torch_device, EarlyStop, TrainParams
from src.evaluation import classification_inference
from src.metric import  multi_cls_metrics,multi_cls_log
from src.dataset.tokenizer import GensimTokenizer

from iflytek_app.dataset import MixDataset
from iflytek_app.models import Textcnn
from iflytek_app.process import train_process, test_process, result_process

device = get_torch_device()
set_seed()

No GPU available, using the CPU instead.


In [2]:
c2v = GensimTokenizer( Word2Vec.load('./checkpoint/char_min1_win5_sg_d100'))
w2v = GensimTokenizer(Word2Vec.load('./checkpoint/phrase_min1_win5_sg_d100'))
w2v.init_vocab()
c2v.init_vocab()
phraser = Phrases.load('./checkpoint/phrase_tokenizer')
df, label2idx = train_process()

                id           l1           l2          len
count  4199.000000  4199.000000  4199.000000  4199.000000
mean   2099.000000     8.969278    37.087878    46.057156
std    1212.291219     4.576621    79.204914    79.332999
min       0.000000     2.000000     1.000000     4.000000
25%    1049.500000     5.000000     6.000000    15.000000
50%    2099.000000     8.000000    12.000000    22.000000
75%    3148.500000    12.000000    26.000000    36.000000
max    4198.000000    32.000000   946.000000   961.000000
{'14784131 14858934 14784131 14845064': 0, '14852788 14717848 15639958 15632020': 1, '14844856 14724258 14925237 14854807': 2, '14925756 15639967 14853254 14728639': 3, '14844593 14924945': 4, '15709098 14716590 14924703 14779559': 5, '14726332 14728344 14854542 14844591': 6, '14858934 15636660 15704193 14849963': 7, '15710359 14847407 14845602 14859696': 8, '14794687 14782344': 9, '15630486 15702410 14718849 15709093': 10, '15632285 15706536 14721977 14925219': 11, '147829

## Train 5-Fold Model


In [3]:
log_steps = 10
save_steps = 20
tp = TrainParams(
    epoch_size=30,
    lr=1e-3,
    loss_fn=nn.CrossEntropyLoss(),
    max_seq_len=1000,
    batch_size=64,
    dropout_rate=0.5,
    label_size = len(label2idx),
    vocab_size = w2v.vocab_size,
    embedding_dim = w2v.embedding_size + c2v.embedding_size,
    embedding1=c2v.embedding, 
    embedding2 =w2v.embedding,
    num_train_steps=int(df.shape[0]/5 *4), 
    filter_size=70,
    kernel_size_list = [2,3,4,5],
    hidden_size = 100,
    early_stop_params = {
        'monitor':'f1_micro',
        'mode':'max',
        'min_delta': 0,
        'patience':3,
        'verbose':False
    },
    scheduler_params={'mode': 'max',
                     'factor': 0.3,
                     'patience': 1,
                     'verbose': True,
                     'threshold':0.0001,
                     'threshold_mode':'rel',
                     'cooldown':0,
                     'min_lr':1e-6}
)



In [4]:

kf = KFold(n_splits=5, shuffle=True, random_state=24)
for fold,(train_index, valid_index) in enumerate(kf.split(df)):
    train, valid = df.iloc[train_index], df.iloc[valid_index]

    train_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx, 
                               train['name'].values, train['description'].values, train['label'].values)
    valid_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx, 
                               valid['name'].values, valid['description'].values, valid['label'].values)
    train_sampler = RandomSampler(train_dataset)
    valid_sampler = SequentialSampler(valid_dataset)
    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=tp.batch_size)
    valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=tp.batch_size)
    
    CKPT = './checkpoint/textcnn/k{}'.format(fold)
    saver = ModelSave(CKPT, continue_train=False)
    es = EarlyStop(**tp.early_stop_params)
    global_step = 0
    saver.init()
    tb = SummaryWriter(CKPT)
    model = Textcnn(tp)
    optimizer, scheduler = model.get_optimizer()
    
    
    for epoch_i in range(tp['epoch_size']):
        if global_step==1:
            print(model)
        print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10}  | {'Elapsed':^9}")
        print("-"*60)

        # Measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()
        total_loss, batch_loss, batch_counts = 0, 0, 0

        model.train()
        for step, batch in enumerate(train_loader):
            global_step +=1
            batch_counts +=1

            #Forward propogate
            model.zero_grad()
            feature = {k:v.to(device) for k, v in batch.items()}
            logits = model(feature)
            loss = model.compute_loss(feature, logits)
            batch_loss += loss.item()
            total_loss += loss.item()
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            # Log steps for train loss logging
            if (step % log_steps == 0 and step != 0) or (step == len(train_loader) - 1):
                time_elapsed = time.time() - t0_batch
                print(f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^9} | {time_elapsed:^9.2f}")
                tb.add_scalar('loss/batch_train', batch_loss / batch_counts, global_step=global_step)
                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()

            # Save steps for ckpt saving and dev evaluation
            if (step % save_steps == 0 and step != 0) or (step == len(train_loader) - 1):
                val_metrics = multi_cls_metrics(model, valid_loader, device)
                for key, val in val_metrics.items():
                    tb.add_scalar(f'metric/{key}', val, global_step=global_step)
                avg_train_loss = total_loss / step
                tb.add_scalars('loss/train_valid',{'train': avg_train_loss,
                                                    'valid': val_metrics['val_loss']}, global_step=global_step)
                saver(total_loss / step, val_metrics['val_loss'], epoch_i, global_step, model, optimizer, scheduler)

        # On Epoch End: calcualte train & valid loss and log overall metrics
        time_elapsed = time.time() - t0_epoch
        val_metrics = multi_cls_metrics(model, valid_loader, device)
        avg_train_loss = total_loss / step
        scheduler.step(val_metrics['f1_micro'])
        print("-"*70)
        print(f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_metrics['val_loss']:^10.6f} | {time_elapsed:^9.2f}")
        multi_cls_log(epoch_i, val_metrics)
        print("\n")
        if es.check(val_metrics):
            break 

./checkpoint/textcnn/k0 model cleaned
 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   1    |   10    |   3.680274   |     -     |   7.74   
   1    |   20    |   2.340497   |     -     |   7.69   




   1    |   30    |   2.658483   |     -     |   10.45  
   1    |   40    |   2.354419   |     -     |   6.86   
   1    |   50    |   2.194548   |     -     |   10.42  
   1    |   52    |   2.043952   |     -     |   1.08   
----------------------------------------------------------------------
   1    |    -    |   2.693277   |  2.050832  |   47.51  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   1    |  20.785%  |  74.853%  |  27.918%  |     19.512%     |   20.785%    |  45.119%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   1    |  45.119%  |     -     |     -     |     45.119%     |   45.119%    |  45.119%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
--------------------------------------

   9    |   10    |   0.353457   |     -     |   12.47  
   9    |   20    |   0.325020   |     -     |   11.18  
   9    |   30    |   0.160770   |     -     |   16.75  
   9    |   40    |   0.154682   |     -     |   11.79  
   9    |   50    |   0.142319   |     -     |   16.87  
   9    |   52    |   0.091085   |     -     |   1.81   
----------------------------------------------------------------------
   9    |    -    |   0.228810   |  1.078423  |   76.14  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   9    |  55.668%  |  89.996%  |  62.265%  |     63.700%     |   55.668%    |  73.095%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   9    |  73.095%  |     -     |     -     |     73.095%     |   73.095%    

   1    |   30    |   2.620985   |     -     |   9.84   
   1    |   40    |   2.357014   |     -     |   6.79   
   1    |   50    |   2.239123   |     -     |   9.86   
   1    |   52    |   2.133878   |     -     |   1.01   
----------------------------------------------------------------------
   1    |    -    |   2.823604   |  2.055621  |   44.85  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   1    |  20.286%  |  68.164%  |  28.459%  |     21.678%     |   20.286%    |  41.786%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   1    |  41.786%  |     -     |     -     |     41.786%     |   41.786%    |  41.786%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
--------------------------------------

   9    |   10    |   0.464687   |     -     |   7.44   
   9    |   20    |   0.407982   |     -     |   6.73   
   9    |   30    |   0.249552   |     -     |   9.67   
   9    |   40    |   0.263825   |     -     |   7.63   
   9    |   50    |   0.235173   |     -     |   10.44  
   9    |   52    |   0.256835   |     -     |   1.15   
----------------------------------------------------------------------
   9    |    -    |   0.330588   |  1.135926  |   46.36  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   9    |  63.061%  |  83.001%  |  69.276%  |     65.856%     |   63.061%    |  75.119%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   9    |  75.119%  |     -     |     -     |     75.119%     |   75.119%    

  16    |   30    |   0.072129   |     -     |   9.68   
  16    |   40    |   0.056205   |     -     |   6.72   
  16    |   50    |   0.058533   |     -     |   9.76   
  16    |   52    |   0.051051   |     -     |   1.00   
Epoch    16: reducing learning rate of group 0 to 2.7000e-05.
----------------------------------------------------------------------
  16    |    -    |   0.110229   |  1.109312  |   44.47  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  16    |  63.147%  |  82.860%  |  69.857%  |     68.924%     |   63.147%    |  76.071%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  16    |  76.071%  |     -     |     -     |     76.071%     |   76.071%    |  76.071%  




 Epoch  |  Batch  |  Train Loss  | 

   6    |   50    |   0.607681   |     -     |   10.80  
   6    |   52    |   0.774527   |     -     |   1.00   
----------------------------------------------------------------------
   6    |    -    |   0.661212   |  1.047008  |   48.16  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   6    |  51.574%  |  87.449%  |  63.766%  |     57.556%     |   51.574%    |  70.238%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   6    |  70.238%  |     -     |     -     |     70.238%     |   70.238%    |  70.238%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   7    |   10    |   0.613168   |     -     |   7.53   
   7    |   20    |   0.641291   |

 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   1    |   10    |   3.717372   |     -     |   7.52   
   1    |   20    |   2.436640   |     -     |   6.89   
   1    |   30    |   2.638983   |     -     |   10.25  
   1    |   40    |   2.409991   |     -     |   6.98   
   1    |   50    |   2.198556   |     -     |   9.90   
   1    |   52    |   2.113897   |     -     |   1.00   
----------------------------------------------------------------------
   1    |    -    |   2.730011   |  2.014649  |   45.60  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   1    |  21.055%  |  74.356%  |  28.530%  |     18.705%     |   21.055%    |  45.119%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------

   8    |   20    |   0.424472   |     -     |   6.84   
   8    |   30    |   0.275133   |     -     |   10.34  
   8    |   40    |   0.298913   |     -     |   6.93   
   8    |   50    |   0.291291   |     -     |   9.95   
   8    |   52    |   0.384770   |     -     |   1.01   
Epoch     8: reducing learning rate of group 0 to 3.0000e-04.
----------------------------------------------------------------------
   8    |    -    |   0.359016   |  1.176820  |   45.72  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   8    |  60.594%  |  87.831%  |  66.041%  |     59.754%     |   60.594%    |  71.667%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   8    |  71.667%  |     -     |     -     |     71.667%     |   71.667

  15    |   40    |   0.057595   |     -     |   6.67   
  15    |   50    |   0.066567   |     -     |   9.59   
  15    |   52    |   0.052382   |     -     |   1.01   
----------------------------------------------------------------------
  15    |    -    |   0.124382   |  1.104593  |   43.99  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  15    |  58.388%  |  87.817%  |  67.816%  |     66.952%     |   58.388%    |  74.524%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  15    |  74.524%  |     -     |     -     |     74.524%     |   74.524%    |  74.524%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  16    |   10    |   0.170858   |

   3    |   52    |   1.205098   |     -     |   1.02   
----------------------------------------------------------------------
   3    |    -    |   1.430278   |  1.202388  |   44.17  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   3    |  40.359%  |  86.446%  |  50.909%  |     47.156%     |   40.359%    |  64.839%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   3    |  64.839%  |     -     |     -     |     64.839%     |   64.839%    |  64.839%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   4    |   10    |   1.265130   |     -     |   7.41   
   4    |   20    |   1.237431   |     -     |   6.81   
   4    |   30    |   1.035680   |

  11    |   10    |   0.290660   |     -     |   7.99   
  11    |   20    |   0.319156   |     -     |   7.29   
  11    |   30    |   0.149211   |     -     |   10.67  
  11    |   40    |   0.122324   |     -     |   7.38   
  11    |   50    |   0.165952   |     -     |   10.53  
  11    |   52    |   0.110428   |     -     |   1.24   
----------------------------------------------------------------------
  11    |    -    |   0.211241   |  1.081315  |   48.12  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  11    |  55.342%  |  88.970%  |  62.121%  |     61.451%     |   55.342%    |  73.421%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  11    |  73.421%  |     -     |     -     |     73.421%     |   73.421%    

  18    |   50    |   0.026335   |     -     |   9.74   
  18    |   52    |   0.022512   |     -     |   1.00   
Epoch    18: reducing learning rate of group 0 to 9.0000e-05.
----------------------------------------------------------------------
  18    |    -    |   0.068203   |  1.106625  |   44.51  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  18    |  54.098%  |  88.802%  |  64.354%  |     62.106%     |   54.098%    |  74.255%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  18    |  74.255%  |     -     |     -     |     74.255%     |   74.255%    |  74.255%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  19    |   10    |   0.12037

## K-Fold Evalutaion

In [74]:
test = test_process()
test_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx,  test['name'].values, test['description'].values)
test_sampler = SequentialSampler(test_dataset)
test_loader = DataLoader(test_dataset, sampler=test_sampler, batch_size=tp.batch_size)

In [82]:
result = kfold_inference(test_loader, tp, Textcnn, './checkpoint/textcnn', 5)
result['pred'] = result['pred_major']
result_process(result, label2idx, 'textcnn_5fold_major.csv')

In [83]:
result['pred'] = result['pred_avg']
result_process(result, label2idx, 'textcnn_5fold_avg.csv')