In [10]:
import warnings
warnings.filterwarnings(action='ignore')

import pandas as pd 
import numpy as np 
import collections 
from itertools import chain
import torch 
import gensim
import time 
from sklearn.model_selection import KFold
from torch.utils.data.dataset import Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.nn import CrossEntropyLoss
from sklearn.model_selection import train_test_split
from gensim.models.phrases import Phrases
from gensim.models import Word2Vec

from src.train_utils import set_seed, ModelSave, get_torch_device, EarlyStop, TrainParams
from src.evaluation import multi_cls_report, classification_inference
from src.metric import  multi_cls_metrics,multi_cls_log
from src.dataset.tokenizer import GensimTokenizer
from src.enhancement.adversarial import FGM

from iflytek_app.dataset import MixDataset
from iflytek_app.process import train_process, test_process, result_process,kfold_inference
from iflytek_app.models import Textcnn
device = get_torch_device()
set_seed()

No GPU available, using the CPU instead.


In [7]:
c2v = GensimTokenizer( Word2Vec.load('./checkpoint/char_min1_win5_sg_d100'))
w2v = GensimTokenizer(Word2Vec.load('./checkpoint/phrase_min1_win5_sg_d100'))
w2v.init_vocab()
c2v.init_vocab()
phraser = Phrases.load('./checkpoint/phrase_tokenizer')
df, label2idx = train_process()

                id           l1           l2          len
count  4199.000000  4199.000000  4199.000000  4199.000000
mean   2099.000000     8.969278    37.087878    46.057156
std    1212.291219     4.576621    79.204914    79.332999
min       0.000000     2.000000     1.000000     4.000000
25%    1049.500000     5.000000     6.000000    15.000000
50%    2099.000000     8.000000    12.000000    22.000000
75%    3148.500000    12.000000    26.000000    36.000000
max    4198.000000    32.000000   946.000000   961.000000
{'14784131 14858934 14784131 14845064': 0, '14852788 14717848 15639958 15632020': 1, '14844856 14724258 14925237 14854807': 2, '14925756 15639967 14853254 14728639': 3, '14844593 14924945': 4, '15709098 14716590 14924703 14779559': 5, '14726332 14728344 14854542 14844591': 6, '14858934 15636660 15704193 14849963': 7, '15710359 14847407 14845602 14859696': 8, '14794687 14782344': 9, '15630486 15702410 14718849 15709093': 10, '15632285 15706536 14721977 14925219': 11, '147829

In [13]:
log_steps = 10
save_steps = 20
kfold = 5
tp = TrainParams(
    epoch_size=30,
    lr=1e-3,
    loss_fn=nn.CrossEntropyLoss(),
    max_seq_len=1000,
    batch_size=64,
    dropout_rate=0.5,
    label_size = len(label2idx),
    vocab_size = w2v.vocab_size,
    embedding_dim = w2v.embedding_size + c2v.embedding_size,
    embedding1=c2v.embedding, 
    embedding2 =w2v.embedding,
    num_train_steps=int(df.shape[0]/kfold * (kfold-1)), 
    filter_size=70,
    kernel_size_list = [2,3,4,5],
    hidden_size = 100,
    early_stop_params = {
        'monitor':'f1_micro',
        'mode':'max',
        'min_delta': 0,
        'patience':3,
        'verbose':False
    },
    scheduler_params={'mode': 'max',
                     'factor': 0.3,
                     'patience': 1,
                     'verbose': True,
                     'threshold':0.0001,
                     'threshold_mode':'rel',
                     'cooldown':0,
                     'min_lr':1e-6},
    epsilon=0.5,
    embedding_name='embedding'
)


In [14]:

kf = KFold(n_splits=kfold, shuffle=True, random_state=24)
for fold,(train_index, valid_index) in enumerate(kf.split(df)):
    train, valid = df.iloc[train_index], df.iloc[valid_index]

    train_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx, 
                               train['name'].values, train['description'].values, train['label'].values)
    valid_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx, 
                               valid['name'].values, valid['description'].values, valid['label'].values)
    train_sampler = RandomSampler(train_dataset)
    valid_sampler = SequentialSampler(valid_dataset)
    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=tp.batch_size)
    valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=tp.batch_size)
    
    CKPT = './checkpoint/textcnn_fgm/k{}'.format(fold)
    saver = ModelSave(CKPT, continue_train=False)
    saver.init()
    tb = SummaryWriter(CKPT)
    es = EarlyStop(**tp.early_stop_params)
    global_step = 0

    model = Textcnn(tp)
    fgm = FGM(model,tp.epsilon, tp.embedding_name)
    optimizer, scheduler = model.get_optimizer() 
    
    for epoch_i in range(tp['epoch_size']):
        if global_step==1:
            print(model)
        print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10}  | {'Elapsed':^9}")
        print("-"*60)

        # Measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()
        total_loss, batch_loss, batch_counts = 0, 0, 0

        model.train()
        for step, batch in enumerate(train_loader):
            global_step +=1
            batch_counts +=1

            #Forward propogate
            model.zero_grad()
            feature = {k:v.to(device) for k, v in batch.items()}
            logits = model(feature)
            loss = model.compute_loss(feature, logits)
            loss.backward()
            
            # attack
            fgm.attack() 
            logits = model(feature)
            loss = model.compute_loss(feature, logits)
            loss.backward()
            fgm.restore() 
            # step 
            batch_loss += loss.item()
            total_loss += loss.item()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            # Log steps for train loss logging
            if (step % log_steps == 0 and step != 0) or (step == len(train_loader) - 1):
                time_elapsed = time.time() - t0_batch
                print(f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^9} | {time_elapsed:^9.2f}")
                tb.add_scalar('loss/batch_train', batch_loss / batch_counts, global_step=global_step)
                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()

            # Save steps for ckpt saving and dev evaluation
            if (step % save_steps == 0 and step != 0) or (step == len(train_loader) - 1):
                val_metrics = multi_cls_metrics(model, valid_loader, device)
                for key, val in val_metrics.items():
                    tb.add_scalar(f'metric/{key}', val, global_step=global_step)
                avg_train_loss = total_loss / step
                tb.add_scalars('loss/train_valid',{'train': avg_train_loss,
                                                    'valid': val_metrics['val_loss']}, global_step=global_step)
                saver(total_loss / step, val_metrics['val_loss'], epoch_i, global_step, model, optimizer, scheduler)

        # On Epoch End: calcualte train & valid loss and log overall metrics
        time_elapsed = time.time() - t0_epoch
        val_metrics = multi_cls_metrics(model, valid_loader, device)
        avg_train_loss = total_loss / step
        scheduler.step(val_metrics['f1_micro'])
        print("-"*70)
        print(f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_metrics['val_loss']:^10.6f} | {time_elapsed:^9.2f}")
        multi_cls_log(epoch_i, val_metrics)
        print("\n")
        if es.check(val_metrics):
            break 


./checkpoint/textcnn_fgm/k0 model cleaned
Attack parameters ['embedding1.weight', 'embedding2.weight']
 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   1    |   10    |   4.602240   |     -     |   17.36  
   1    |   20    |   2.793466   |     -     |   16.31  
   1    |   30    |   2.365093   |     -     |   17.44  
   1    |   40    |   2.060813   |     -     |   13.85  
   1    |   50    |   1.860926   |     -     |   16.99  
   1    |   52    |   1.858243   |     -     |   2.49   
----------------------------------------------------------------------
   1    |    -    |   2.791233   |  1.700344  |   88.05  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   1    |  24.364%  |  78.701%  |  35.891%  |     26.954%     |   24.364%    |  46.071%  
 Epoch  | Micro Acc | Micro AUC | 

   8    |   20    |   0.702485   |     -     |   27.46  
   8    |   30    |   0.386362   |     -     |   32.50  
   8    |   40    |   0.400435   |     -     |   27.08  
   8    |   50    |   0.339015   |     -     |   33.83  
   8    |   52    |   0.286082   |     -     |   4.19   
----------------------------------------------------------------------
   8    |    -    |   0.529319   |  1.123520  |  162.97  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   8    |  55.736%  |  89.261%  |  61.900%  |     67.207%     |   55.736%    |  71.071%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   8    |  71.071%  |     -     |     -     |     71.071%     |   71.071%    |  71.071%  




 Epoch  |  Batch  |  Train Loss  |  Val 

  15    |   52    |   0.063155   |     -     |   3.73   
----------------------------------------------------------------------
  15    |    -    |   0.164718   |  1.210034  |  168.66  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  15    |  57.729%  |  89.602%  |  63.059%  |     66.684%     |   57.729%    |  73.333%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  15    |  73.333%  |     -     |     -     |     73.333%     |   73.333%    |  73.333%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  16    |   10    |   0.255542   |     -     |   28.72  
  16    |   20    |   0.256905   |     -     |   29.86  
  16    |   30    |   0.081255   |

   6    |   10    |   0.884851   |     -     |   33.88  
   6    |   20    |   0.857040   |     -     |   26.55  
   6    |   30    |   0.585018   |     -     |   32.16  
   6    |   40    |   0.582434   |     -     |   30.29  
   6    |   50    |   0.585669   |     -     |   38.28  
   6    |   52    |   0.849960   |     -     |   4.98   
----------------------------------------------------------------------
   6    |    -    |   0.721825   |  0.948471  |  172.43  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   6    |  62.844%  |  84.416%  |  69.154%  |     66.693%     |   62.844%    |  74.048%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   6    |  74.048%  |     -     |     -     |     74.048%     |   74.048%    

  13    |   30    |   0.117521   |     -     |   28.42  
  13    |   40    |   0.102485   |     -     |   22.78  
  13    |   50    |   0.087550   |     -     |   28.64  
  13    |   52    |   0.128400   |     -     |   3.42   
----------------------------------------------------------------------
  13    |    -    |   0.186118   |  0.958500  |  137.70  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  13    |  63.980%  |  83.609%  |  69.852%  |     71.925%     |   63.980%    |  77.143%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  13    |  77.143%  |     -     |     -     |     77.143%     |   77.143%    |  77.143%  




./checkpoint/textcnn_fgm/k2 model cleaned
Attack parameters ['embedding1.weight', 'embedding2.wei

   7    |   52    |   0.407723   |     -     |   3.46   
----------------------------------------------------------------------
   7    |    -    |   0.610963   |  1.108050  |  139.22  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   7    |  54.565%  |  87.601%  |  66.162%  |     71.999%     |   54.565%    |  72.143%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   7    |  72.143%  |     -     |     -     |     72.143%     |   72.143%    |  72.143%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   8    |   10    |   0.619871   |     -     |   25.53  
   8    |   20    |   0.637821   |     -     |   22.91  
   8    |   30    |   0.320906   |

  15    |   10    |   0.202663   |     -     |   25.89  
  15    |   20    |   0.221003   |     -     |   22.76  
  15    |   30    |   0.099764   |     -     |   27.75  
  15    |   40    |   0.057996   |     -     |   22.62  
  15    |   50    |   0.057701   |     -     |   27.92  
  15    |   52    |   0.083960   |     -     |   3.37   
----------------------------------------------------------------------
  15    |    -    |   0.130036   |  1.159692  |  135.68  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  15    |  58.762%  |  88.211%  |  68.557%  |     72.058%     |   58.762%    |  75.000%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  15    |  75.000%  |     -     |     -     |     75.000%     |   75.000%    

   4    |   30    |   0.992177   |     -     |   26.82  
   4    |   40    |   0.937955   |     -     |   21.79  
   4    |   50    |   1.020262   |     -     |   27.07  
   4    |   52    |   0.938353   |     -     |   3.33   
----------------------------------------------------------------------
   4    |    -    |   1.137063   |  1.095060  |  130.59  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   4    |  50.229%  |  86.056%  |  61.006%  |     62.580%     |   50.229%    |  69.048%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   4    |  69.048%  |     -     |     -     |     69.048%     |   69.048%    |  69.048%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
--------------------------------------

  12    |   10    |   0.506803   |     -     |   24.34  
  12    |   20    |   0.499944   |     -     |   21.50  
  12    |   30    |   0.230254   |     -     |   26.49  
  12    |   40    |   0.150384   |     -     |   21.44  
  12    |   50    |   0.182824   |     -     |   26.19  
  12    |   52    |   0.301471   |     -     |   3.28   
Epoch    12: reducing learning rate of group 0 to 3.0000e-04.
----------------------------------------------------------------------
  12    |    -    |   0.323304   |  1.207824  |  128.22  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  12    |  57.719%  |  86.487%  |  64.760%  |     64.021%     |   57.719%    |  73.810%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  12    |  73.8

  19    |   30    |   0.046781   |     -     |   25.10  
  19    |   40    |   0.054505   |     -     |   19.92  
  19    |   50    |   0.045382   |     -     |   24.90  
  19    |   52    |   0.060410   |     -     |   3.08   
----------------------------------------------------------------------
  19    |    -    |   0.122125   |  1.171695  |  120.29  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  19    |  58.198%  |  86.721%  |  65.263%  |     72.843%     |   58.198%    |  75.000%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  19    |  75.000%  |     -     |     -     |     75.000%     |   75.000%    |  75.000%  




./checkpoint/textcnn_fgm/k4 model cleaned
Attack parameters ['embedding1.weight', 'embedding2.wei

   7    |   52    |   0.508687   |     -     |   3.07   
----------------------------------------------------------------------
   7    |    -    |   0.642375   |  0.973841  |  119.82  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   7    |  54.366%  |  88.669%  |  61.430%  |     55.850%     |   54.366%    |  73.063%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
   7    |  73.063%  |     -     |     -     |     73.063%     |   73.063%    |  73.063%  




 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   8    |   10    |   0.691887   |     -     |   22.30  
   8    |   20    |   0.650416   |     -     |   20.49  
   8    |   30    |   0.381459   |

  15    |   10    |   0.221764   |     -     |   22.48  
  15    |   20    |   0.211552   |     -     |   20.02  
  15    |   30    |   0.068383   |     -     |   24.58  
  15    |   40    |   0.060845   |     -     |   20.09  
  15    |   50    |   0.070249   |     -     |   24.81  
  15    |   52    |   0.056160   |     -     |   3.08   
----------------------------------------------------------------------
  15    |    -    |   0.128116   |  1.001171  |  119.72  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  15    |  55.759%  |  89.876%  |  63.659%  |     58.448%     |   55.759%    |  75.685%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------------------------------------------
  15    |  75.685%  |     -     |     -     |     75.685%     |   75.685%    

In [15]:
test = test_process()
test_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx,  test['name'].values, test['description'].values)
test_sampler = SequentialSampler(test_dataset)
test_loader = DataLoader(test_dataset, sampler=test_sampler, batch_size=tp.batch_size)

In [17]:
result = kfold_inference(test_loader, tp, Textcnn, './checkpoint/textcnn_fgm', 5,device)
result['pred'] = result['pred_avg']
result_process(result, label2idx, './submit/textcnn_fgm_5fold_avg.csv')