In [1]:
import warnings
warnings.filterwarnings(action='ignore')

import torch 
import time 
from sklearn.model_selection import KFold
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from gensim.models.phrases import Phrases
from gensim.models import Word2Vec

from src.train_utils import set_seed, ModelSave, get_torch_device, EarlyStop, TrainParams
from src.evaluation import multi_cls_report, classification_inference
from src.metric import  multi_cls_metrics,multi_cls_log
from src.dataset.tokenizer import GensimTokenizer

from iflytek_app.dataset import MixDataset
from iflytek_app.models import Fasttext
from iflytek_app.process import train_process, test_process, result_process

device = get_torch_device()
set_seed()

No GPU available, using the CPU instead.


In [2]:
c2v = GensimTokenizer( Word2Vec.load('./checkpoint/char_min1_win5_sg_d100'))
w2v = GensimTokenizer(Word2Vec.load('./checkpoint/phrase_min1_win5_sg_d100'))
w2v.init_vocab()
c2v.init_vocab()
phraser = Phrases.load('./checkpoint/phrase_tokenizer')
df, label2idx = train_process()

                id           l1           l2          len
count  4199.000000  4199.000000  4199.000000  4199.000000
mean   2099.000000     8.969278    37.087878    46.057156
std    1212.291219     4.576621    79.204914    79.332999
min       0.000000     2.000000     1.000000     4.000000
25%    1049.500000     5.000000     6.000000    15.000000
50%    2099.000000     8.000000    12.000000    22.000000
75%    3148.500000    12.000000    26.000000    36.000000
max    4198.000000    32.000000   946.000000   961.000000
{'14784131 14858934 14784131 14845064': 0, '14852788 14717848 15639958 15632020': 1, '14844856 14724258 14925237 14854807': 2, '14925756 15639967 14853254 14728639': 3, '14844593 14924945': 4, '15709098 14716590 14924703 14779559': 5, '14726332 14728344 14854542 14844591': 6, '14858934 15636660 15704193 14849963': 7, '15710359 14847407 14845602 14859696': 8, '14794687 14782344': 9, '15630486 15702410 14718849 15709093': 10, '15632285 15706536 14721977 14925219': 11, '147829

In [5]:
log_steps = 10
save_steps = 20
kfold = 5
tp = TrainParams(
    epoch_size=30,
    lr=1e-3,
    loss_fn=nn.CrossEntropyLoss(),
    max_seq_len=1000,
    batch_size=64,
    dropout_rate=0.5,
    label_size = len(label2idx),
    vocab_size = c2v.vocab_size,
    embedding_dim = c2v.embedding_size + w2v.embedding_size,
    num_train_steps=int(df.shape[0]/kfold * (kfold-1)), 
    embedding1=c2v.embedding,
    embedding2=w2v.embedding,
    hidden_size=200,
    diff_lr =False,
    early_stop_params = {
        'monitor':'f1_micro',
        'mode':'max',
        'min_delta': 0,
        'patience':3,
        'verbose':False
    },
    scheduler_params={'mode': 'max',
                     'factor': 0.3,
                     'patience': 1,
                     'verbose': True,
                     'threshold':0.0001,
                     'threshold_mode':'rel',
                     'cooldown':0,
                     'min_lr':1e-6}
)


kf = KFold(n_splits=5, shuffle=True, random_state=24)
for fold,(train_index, valid_index) in enumerate(kf.split(df)):
    train, valid = df.iloc[train_index], df.iloc[valid_index]

    train_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx, 
                               train['name'].values, train['description'].values, train['label'].values)
    valid_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx, 
                               valid['name'].values, valid['description'].values, valid['label'].values)
    train_sampler = RandomSampler(train_dataset)
    valid_sampler = SequentialSampler(valid_dataset)
    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=tp.batch_size)
    valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=tp.batch_size)
    
    CKPT = './checkpoint/fasttext/k{}'.format(fold)
    saver = ModelSave(CKPT, continue_train=False)
    saver.init()
    tb = SummaryWriter(CKPT)
    es = EarlyStop(**tp.early_stop_params)
    global_step = 0

    
    model = Fasttext(tp)
    optimizer, scheduler = model.get_optimizer()
    
    for epoch_i in range(tp['epoch_size']):
        if global_step==1:
            print(model)
        print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10}  | {'Elapsed':^9}")
        print("-"*60)

        # Measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()
        total_loss, batch_loss, batch_counts = 0, 0, 0

        model.train()
        for step, batch in enumerate(train_loader):
            global_step +=1
            batch_counts +=1

            #Forward propogate
            model.zero_grad()
            feature = {k:v.to(device) for k, v in batch.items()}
            logits  = model(feature)
            loss = model.compute_loss(feature, logits)
            batch_loss += loss.item()
            total_loss += loss.item()
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            # Log steps for train loss logging
            if (step % log_steps == 0 and step != 0) or (step == len(train_loader) - 1):
                time_elapsed = time.time() - t0_batch
                print(f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^9} | {time_elapsed:^9.2f}")
                tb.add_scalar('loss/batch_train', batch_loss / batch_counts, global_step=global_step)
                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()

            # Save steps for ckpt saving and dev evaluation
            if (step % save_steps == 0 and step != 0) or (step == len(train_loader) - 1):
                val_metrics = multi_cls_metrics(model, valid_loader, device)
                for key, val in val_metrics.items():
                    tb.add_scalar(f'metric/{key}', val, global_step=global_step)
                avg_train_loss = total_loss / step
                tb.add_scalars('loss/train_valid',{'loss': avg_train_loss,
                                                    'oss': val_metrics['val_loss']}, global_step=global_step)
                saver(total_loss / step, val_metrics['val_loss'], epoch_i, global_step, model, optimizer, scheduler)

        # On Epoch End: calcualte train & valid loss and log overall metrics
        time_elapsed = time.time() - t0_epoch
        val_metrics = multi_cls_metrics(model, valid_loader, device)
        avg_train_loss = total_loss / step
        scheduler.step(val_metrics['f1_micro'])
        print("-"*70)
        print(f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_metrics['val_loss']:^10.6f} | {time_elapsed:^9.2f}")
        multi_cls_log(epoch_i, val_metrics)
        print("\n")
        if es.check(val_metrics):
            break 

./checkpoint/fasttext/k0 model cleaned
 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   1    |   10    |   3.224917   |     -     |   7.85   
   1    |   20    |   2.270853   |     -     |   6.31   
   1    |   30    |   2.546636   |     -     |   11.26  
   1    |   40    |   2.370054   |     -     |   6.70   
   1    |   50    |   2.374516   |     -     |   9.62   
   1    |   52    |   2.383209   |     -     |   0.91   
----------------------------------------------------------------------
   1    |    -    |   2.612713   |  2.366268  |   45.82  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   1    |  15.920%  |  67.281%  |  21.458%  |     24.859%     |   15.920%    |  31.905%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
---------

   8    |   30    |   0.777751   |     -     |   9.12   
   8    |   40    |   0.680992   |     -     |   6.67   
   8    |   50    |   0.605592   |     -     |   12.38  
   8    |   52    |   1.003255   |     -     |   1.54   
----------------------------------------------------------------------
   8    |    -    |   0.711052   |  1.333730  |   46.61  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   8    |  49.918%  |  88.368%  |  61.145%  |     55.741%     |   49.918%    |  67.024%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   8    |  67.024%  |     -     |     -     |     67.024%     |   67.024%    |  67.024%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
--------------------------------------

Epoch    15: reducing learning rate of group 0 to 9.0000e-05.
----------------------------------------------------------------------
  15    |    -    |   0.213470   |  1.320475  |   41.44  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  15    |  55.837%  |  88.401%  |  61.537%  |     61.587%     |   55.837%    |  71.786%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  15    |  71.786%  |     -     |     -     |     71.786%     |   71.786%    |  71.786%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  16    |   10    |   0.190246   |     -     |   6.79   
  16    |   20    |   0.194375   |     -     |   6.29   
  16    |   30    |   0.21787

  23    |   10    |   0.172424   |     -     |   6.76   
  23    |   20    |   0.197955   |     -     |   6.35   
  23    |   30    |   0.160062   |     -     |   10.04  
  23    |   40    |   0.177815   |     -     |   6.74   
  23    |   50    |   0.152603   |     -     |   9.29   
  23    |   52    |   0.101374   |     -     |   0.91   
----------------------------------------------------------------------
  23    |    -    |   0.172765   |  1.276763  |   43.11  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  23    |  55.636%  |  88.288%  |  66.375%  |     64.087%     |   55.636%    |  72.619%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  23    |  72.619%  |     -     |     -     |     72.619%     |   72.619%    

   7    |   40    |   0.737773   |     -     |   6.01   
   7    |   50    |   0.896897   |     -     |   8.59   
   7    |   52    |   1.156845   |     -     |   0.93   
----------------------------------------------------------------------
   7    |    -    |   0.928988   |  1.115368  |   40.50  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   7    |  54.448%  |  83.136%  |  64.604%  |     66.746%     |   54.448%    |  69.762%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   7    |  69.762%  |     -     |     -     |     69.762%     |   69.762%    |  69.762%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   8    |   10    |   0.782642   |

  15    |   10    |   0.173898   |     -     |   6.88   
  15    |   20    |   0.146419   |     -     |   6.29   
  15    |   30    |   0.179862   |     -     |   9.62   
  15    |   40    |   0.195154   |     -     |   5.81   
  15    |   50    |   0.191620   |     -     |   8.96   
  15    |   52    |   0.205055   |     -     |   0.89   
----------------------------------------------------------------------
  15    |    -    |   0.181799   |  1.235160  |   41.44  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  15    |  64.211%  |  82.899%  |  68.469%  |     68.340%     |   64.211%    |  74.405%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  15    |  74.405%  |     -     |     -     |     74.405%     |   74.405%    

   3    |   30    |   1.704277   |     -     |   9.03   
   3    |   40    |   1.561633   |     -     |   5.98   
   3    |   50    |   1.451778   |     -     |   8.98   
   3    |   52    |   1.862376   |     -     |   0.88   
----------------------------------------------------------------------
   3    |    -    |   1.685298   |  1.597463  |   40.94  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   3    |  29.258%  |  83.263%  |  46.879%  |     46.307%     |   29.258%    |  53.929%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   3    |  53.929%  |     -     |     -     |     53.929%     |   53.929%    |  53.929%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
--------------------------------------

  11    |   10    |   0.536214   |     -     |   6.67   
  11    |   20    |   0.356104   |     -     |   6.04   
  11    |   30    |   0.365840   |     -     |   8.93   
  11    |   40    |   0.387132   |     -     |   5.82   
  11    |   50    |   0.334900   |     -     |   8.71   
  11    |   52    |   0.366593   |     -     |   0.91   
----------------------------------------------------------------------
  11    |    -    |   0.405217   |  1.251171  |   40.04  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  11    |  56.873%  |  89.436%  |  63.115%  |     64.533%     |   56.873%    |  70.357%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  11    |  70.357%  |     -     |     -     |     70.357%     |   70.357%    

  18    |   30    |   0.165677   |     -     |   8.73   
  18    |   40    |   0.149818   |     -     |   5.82   
  18    |   50    |   0.123721   |     -     |   8.90   
  18    |   52    |   0.119125   |     -     |   0.90   
----------------------------------------------------------------------
  18    |    -    |   0.153529   |  1.189157  |   40.41  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  18    |  54.784%  |  89.101%  |  62.850%  |     60.371%     |   54.784%    |  71.190%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  18    |  71.190%  |     -     |     -     |     71.190%     |   71.190%    |  71.190%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
--------------------------------------

   7    |   52    |   0.712009   |     -     |   0.89   
----------------------------------------------------------------------
   7    |    -    |   0.896499   |  1.141249  |   39.67  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   7    |  50.738%  |  86.314%  |  60.452%  |     65.524%     |   50.738%    |  69.405%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   7    |  69.405%  |     -     |     -     |     69.405%     |   69.405%    |  69.405%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   8    |   10    |   0.687231   |     -     |   6.67   
   8    |   20    |   0.695820   |     -     |   6.03   
   8    |   30    |   0.841864   |

  15    |   10    |   0.173011   |     -     |   7.16   
  15    |   20    |   0.147516   |     -     |   6.55   
  15    |   30    |   0.126082   |     -     |   10.09  
  15    |   40    |   0.172698   |     -     |   6.17   
  15    |   50    |   0.183680   |     -     |   8.04   
  15    |   52    |   0.264511   |     -     |   0.71   
----------------------------------------------------------------------
  15    |    -    |   0.167921   |  1.184091  |   41.03  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  15    |  56.507%  |  85.935%  |  63.302%  |     71.060%     |   56.507%    |  73.333%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  15    |  73.333%  |     -     |     -     |     73.333%     |   73.333%    

   1    |   20    |   2.390223   |     -     |   6.29   
   1    |   30    |   2.516872   |     -     |   9.33   
   1    |   40    |   2.403444   |     -     |   6.11   
   1    |   50    |   2.315595   |     -     |   9.26   
   1    |   52    |   2.412311   |     -     |   0.91   
----------------------------------------------------------------------
   1    |    -    |   2.555357   |  2.324626  |   41.88  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   1    |  11.005%  |  58.721%  |  21.756%  |     13.724%     |   11.005%    |  28.486%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   1    |  28.486%  |     -     |     -     |     28.486%     |   28.486%    |  28.486%  




 Epoch  |  Batch  |  Train Loss  |  Val 

----------------------------------------------------------------------
   8    |    -    |   0.741010   |  1.088877  |   47.35  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   8    |  46.943%  |  85.603%  |  56.119%  |     54.258%     |   46.943%    |  70.083%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   8    |  70.083%  |     -     |     -     |     70.083%     |   70.083%    |  70.083%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   9    |   10    |   0.723213   |     -     |   4.84   
   9    |   20    |   0.547902   |     -     |   4.49   
   9    |   30    |   0.580553   |     -     |   6.00   
   9    |   40    |   0.615530   |

  16    |   10    |   0.150212   |     -     |   5.97   
  16    |   20    |   0.144794   |     -     |   5.84   
  16    |   30    |   0.131584   |     -     |   8.37   
  16    |   40    |   0.155865   |     -     |   5.20   
  16    |   50    |   0.125209   |     -     |   8.39   
  16    |   52    |   0.054652   |     -     |   0.75   
----------------------------------------------------------------------
  16    |    -    |   0.141080   |  1.092981  |   38.08  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  16    |  56.583%  |  87.179%  |  60.433%  |     57.860%     |   56.583%    |  74.613%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  16    |  74.613%  |     -     |     -     |     74.613%     |   74.613%    