In [34]:
import torch 
import time 
from sklearn.model_selection import KFold

import torch.nn as nn
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from gensim.models.phrases import Phrases
from gensim.models import Word2Vec

from src.train_utils import set_seed, ModelSave, get_torch_device, EarlyStop, TrainParams
from src.evaluation import classification_inference
from src.metric import  multi_cls_metrics,multi_cls_log
from src.dataset.tokenizer import GensimTokenizer

from iflytek_app.dataset import MixDataset
from iflytek_app.models import Textcnn
from iflytek_app.process import train_process, test_process, result_process, kfold_inference

device = get_torch_device()
set_seed()

No GPU available, using the CPU instead.


In [2]:
phraser= Phrases.load('./checkpoint/phrase_tokenizer')
c2v = GensimTokenizer( Word2Vec.load('./checkpoint/char_min1_win5_sg_d100'))
w2v = GensimTokenizer(Word2Vec.load('./checkpoint/phrase_min1_win5_sg_d100'), phraser)
w2v.init_vocab()
c2v.init_vocab()

df, label2idx = train_process()
test = test_process()

                id           l1           l2          len
count  4199.000000  4199.000000  4199.000000  4199.000000
mean   2099.000000     8.969278    37.087878    46.057156
std    1212.291219     4.576621    79.204914    79.332999
min       0.000000     2.000000     1.000000     4.000000
25%    1049.500000     5.000000     6.000000    15.000000
50%    2099.000000     8.000000    12.000000    22.000000
75%    3148.500000    12.000000    26.000000    36.000000
max    4198.000000    32.000000   946.000000   961.000000
{'14784131 14858934 14784131 14845064': 0, '14852788 14717848 15639958 15632020': 1, '14844856 14724258 14925237 14854807': 2, '14925756 15639967 14853254 14728639': 3, '14844593 14924945': 4, '15709098 14716590 14924703 14779559': 5, '14726332 14728344 14854542 14844591': 6, '14858934 15636660 15704193 14849963': 7, '15710359 14847407 14845602 14859696': 8, '14794687 14782344': 9, '15630486 15702410 14718849 15709093': 10, '15632285 15706536 14721977 14925219': 11, '147829

1. Random Delete
2. Random Insert
3. Random Swap
4. Synonym Replace

In [3]:
from src.preprocess.augment import *
import pandas as pd
import itertools

In [45]:
df.describe()

Unnamed: 0,id,l1,l2,len
count,4199.0,4199.0,4199.0,4199.0
mean,2099.0,8.969278,37.087878,46.057156
std,1212.291219,4.576621,79.204914,79.332999
min,0.0,2.0,1.0,4.0
25%,1049.5,5.0,6.0,15.0
50%,2099.0,8.0,12.0,22.0
75%,3148.5,12.0,26.0,36.0
max,4198.0,32.0,946.0,961.0


In [49]:
aug_samples = [] 

rand_del = WordDelete(w2v, max_sample=1, filters=())
rand_synon= W2vSynonymous(w2v, max_sample=1, filters=())
rand_swap = WordSwap(w2v, max_sample=1, filters=())


for index, row in enumerate(df.iterrows()):
    text = row[1]['description']
    name = row[1]['name']
    label = row[1]['label']
    for aug_text in rand_del.augment(text):
        aug_samples.append((index, name, aug_text.split(' '), label, 'word_random_delete'))
    for aug_text in rand_swap.augment(text):
        aug_samples.append((index, name, aug_text.split(' '), label, 'word_random_swap'))
    for aug_text in rand_synon.augment(text):
        aug_samples.append((index, name, aug_text.split(' '), label, 'word_random_synonymous'))
        
        
# rand_del = WordDelete(c2v, filters=())
# rand_synon= W2vSynonymous(c2v, filters=())
# rand_swap = WordSwap(c2v, filters=())

# for index, row in enumerate(df.iterrows()):
#     text = row[1]['description']
#     name = row[1]['name']
#     label = row[1]['label']
#     for aug_text in rand_del.augment(text):
#         aug_samples.append((index, name, aug_text.split(' '), label, 'char_random_delete'))
#     for aug_text in rand_swap.augment(text):
#         aug_samples.append((index, name, aug_text.split(' '), label, 'char_random_swap'))
#     for aug_text in rand_synon.augment(text):
#         aug_samples.append((index, name, aug_text.split(' '), label, 'char_random_synonymous'))
        
df_aug = pd.DataFrame(aug_samples, columns= ['index', 'name','description','label','augment'])    
df_aug['description'] = df_aug['description'].map(lambda x: list(itertools.chain(*[i.split('_') for i in x])))

In [52]:
log_steps = 10
save_steps = 20
tp = TrainParams(
    epoch_size=30,
    lr=1e-3,
    loss_fn=nn.CrossEntropyLoss(),
    max_seq_len=1000,
    batch_size=32,
    aug_batch_size=8,
    dropout_rate=0.5,
    label_size = len(label2idx),
    vocab_size = w2v.vocab_size,
    embedding_dim = w2v.embedding_size + c2v.embedding_size,
    embedding1=c2v.embedding, 
    embedding2 =w2v.embedding,
    num_train_steps=int(df.shape[0]/5 *4), 
    filter_size=70,
    kernel_size_list = [2,3,4,5],
    hidden_size = 100,
    early_stop_params = {
        'monitor':'f1_micro',
        'mode':'max',
        'min_delta': 0,
        'patience':5,
        'verbose':False
    },
    scheduler_params={'mode': 'max',
                     'factor': 0.5,
                     'patience': 2,
                     'verbose': True,
                     'threshold':0.0001,
                     'threshold_mode':'rel',
                     'cooldown':0,
                     'min_lr':1e-6}
)

In [53]:
kf = KFold(n_splits=5, shuffle=True, random_state=24)
for fold,(train_index, valid_index) in enumerate(kf.split(df)):
    train, valid = df.iloc[train_index], df.iloc[valid_index]

    train_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx, 
                               train['name'].values, train['description'].values, train['label'].values)
    valid_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx, 
                               valid['name'].values, valid['description'].values, valid['label'].values)
    train_sampler = RandomSampler(train_dataset)
    valid_sampler = SequentialSampler(valid_dataset)
    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=tp.batch_size)

    valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=tp.batch_size)
    
    CKPT = './checkpoint/textcnn_aug/k{}'.format(fold)
    saver = ModelSave(CKPT, continue_train=False)
    es = EarlyStop(**tp.early_stop_params)
    global_step = 0
    saver.init()
    tb = SummaryWriter(CKPT)
    model = Textcnn(tp)
    optimizer, scheduler = model.get_optimizer()
    

    for epoch_i in range(tp['epoch_size']):
        aug = df_aug.loc[df_aug['index'].isin(train_index),:] #避免数据泄露只使用train部分的增强数据
        aug = aug.groupby('index').sample(1)
        aug_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx, 
                           aug['name'].values, aug['description'].values, aug['label'].values)
        aug_sampler = RandomSampler(aug_dataset)
        aug_loader = DataLoader(aug_dataset, sampler=aug_sampler, batch_size=tp.aug_batch_size)
        if global_step==1:
            print(model)
        print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10}  | {'Elapsed':^9}")
        print("-"*60)

        # Measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()
        total_loss, batch_loss, batch_counts = 0, 0, 0

        model.train()
        aug_iter = iter(aug_loader)
        for step, batch in enumerate(train_loader):
            aug_batch = next(aug_iter)
            global_step +=1
            batch_counts +=1

            #Forward propogate
            model.zero_grad()
            feature = {k:v.to(device) for k, v in list(batch.items()) + list(aug_batch.items())}
            logits = model(feature)
            loss = model.compute_loss(feature, logits)
            batch_loss += loss.item()
            total_loss += loss.item()
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            # Log steps for train loss logging
            if (step % log_steps == 0 and step != 0) or (step == len(train_loader) - 1):
                time_elapsed = time.time() - t0_batch
                print(f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^9} | {time_elapsed:^9.2f}")
                tb.add_scalar('loss/batch_train', batch_loss / batch_counts, global_step=global_step)
                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()

            # Save steps for ckpt saving and dev evaluation
            if (step % save_steps == 0 and step != 0) or (step == len(train_loader) - 1):
                val_metrics = multi_cls_metrics(model, valid_loader, device)
                for key, val in val_metrics.items():
                    tb.add_scalar(f'metric/{key}', val, global_step=global_step)
                avg_train_loss = total_loss / step
                tb.add_scalars('loss/train_valid',{'train': avg_train_loss,
                                                    'valid': val_metrics['val_loss']}, global_step=global_step)
                saver(total_loss / step, val_metrics['val_loss'], epoch_i, global_step, model, optimizer, scheduler)

        # On Epoch End: calcualte train & valid loss and log overall metrics
        time_elapsed = time.time() - t0_epoch
        val_metrics = multi_cls_metrics(model, valid_loader, device)
        avg_train_loss = total_loss / step
        scheduler.step(val_metrics['f1_micro'])
        print("-"*70)
        print(f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_metrics['val_loss']:^10.6f} | {time_elapsed:^9.2f}")
        multi_cls_log(epoch_i, val_metrics)
        print("\n")
        if es.check(val_metrics):
            break 

./checkpoint/textcnn_aug/k0 model cleaned
 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   1    |   10    |   5.074656   |     -     |   1.88   
   1    |   20    |   3.778959   |     -     |   1.52   
   1    |   30    |   2.685808   |     -     |   6.13   
   1    |   40    |   2.526789   |     -     |   1.52   
   1    |   50    |   2.390051   |     -     |   5.80   
   1    |   60    |   2.401289   |     -     |   1.65   
   1    |   70    |   2.231517   |     -     |   6.01   
   1    |   80    |   2.136148   |     -     |   1.52   
   1    |   90    |   2.196325   |     -     |   5.59   
   1    |   100   |   2.053561   |     -     |   1.48   
   1    |   104   |   1.974780   |     -     |   4.72   
----------------------------------------------------------------------
   1    |    -    |   2.766585   |  1.977645  |   42.25  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | M

 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   7    |   10    |   1.311979   |     -     |   1.94   
   7    |   20    |   1.048122   |     -     |   1.51   
   7    |   30    |   0.764053   |     -     |   6.04   
   7    |   40    |   1.284393   |     -     |   1.52   
   7    |   50    |   0.944036   |     -     |   5.79   
   7    |   60    |   0.938564   |     -     |   1.52   
   7    |   70    |   1.238071   |     -     |   5.90   
   7    |   80    |   0.904513   |     -     |   1.56   
   7    |   90    |   1.246240   |     -     |   5.86   
   7    |   100   |   1.144437   |     -     |   1.56   
   7    |   104   |   0.921485   |     -     |   4.97   
----------------------------------------------------------------------
   7    |    -    |   1.088866   |  1.378038  |   42.60  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
---------------------------------

 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  13    |   10    |   0.731968   |     -     |   1.75   
  13    |   20    |   0.980180   |     -     |   1.61   
  13    |   30    |   0.793977   |     -     |   5.97   
  13    |   40    |   0.586515   |     -     |   1.58   
  13    |   50    |   0.742843   |     -     |   5.81   
  13    |   60    |   0.519622   |     -     |   1.52   
  13    |   70    |   0.493908   |     -     |   5.78   
  13    |   80    |   0.763564   |     -     |   1.53   
  13    |   90    |   0.806786   |     -     |   5.92   
  13    |   100   |   0.908289   |     -     |   1.61   
  13    |   104   |   1.448251   |     -     |   4.94   
----------------------------------------------------------------------
  13    |    -    |   0.767322   |  1.179225  |   42.31  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
---------------------------------

 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  19    |   10    |   0.628765   |     -     |   1.73   
  19    |   20    |   0.695120   |     -     |   1.50   
  19    |   30    |   0.567348   |     -     |   5.93   
  19    |   40    |   0.388347   |     -     |   1.60   
  19    |   50    |   0.542981   |     -     |   5.85   
  19    |   60    |   0.341235   |     -     |   1.50   
  19    |   70    |   0.520477   |     -     |   5.82   
  19    |   80    |   0.410636   |     -     |   1.51   
  19    |   90    |   0.504536   |     -     |   5.96   
  19    |   100   |   0.470566   |     -     |   1.54   
  19    |   104   |   0.169661   |     -     |   4.88   
----------------------------------------------------------------------
  19    |    -    |   0.500072   |  1.320641  |   42.24  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
---------------------------------

 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  25    |   10    |   0.541780   |     -     |   1.69   
  25    |   20    |   0.555198   |     -     |   1.51   
  25    |   30    |   0.473276   |     -     |   5.63   
  25    |   40    |   0.114115   |     -     |   1.48   
  25    |   50    |   0.175552   |     -     |   5.64   
  25    |   60    |   0.180843   |     -     |   1.52   
  25    |   70    |   0.251170   |     -     |   5.64   
  25    |   80    |   0.157968   |     -     |   1.49   
  25    |   90    |   0.326040   |     -     |   5.63   
  25    |   100   |   0.495060   |     -     |   1.47   
  25    |   104   |   0.185865   |     -     |   4.72   
----------------------------------------------------------------------
  25    |    -    |   0.326878   |  1.192616  |   40.54  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
---------------------------------

./checkpoint/textcnn_aug/k1 model cleaned
 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   1    |   10    |   5.448909   |     -     |   1.72   
   1    |   20    |   3.346522   |     -     |   1.50   
   1    |   30    |   2.718800   |     -     |   5.72   
   1    |   40    |   2.575842   |     -     |   1.47   
   1    |   50    |   2.429005   |     -     |   5.62   
   1    |   60    |   2.408510   |     -     |   1.47   
   1    |   70    |   2.409909   |     -     |   5.68   
   1    |   80    |   2.289253   |     -     |   1.48   
   1    |   90    |   2.021042   |     -     |   5.65   
   1    |   100   |   1.907148   |     -     |   1.46   
   1    |   104   |   2.169859   |     -     |   4.74   
----------------------------------------------------------------------
   1    |    -    |   2.785363   |  1.928305  |   40.68  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | M

 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   7    |   10    |   0.985310   |     -     |   1.77   
   7    |   20    |   1.195302   |     -     |   1.48   
   7    |   30    |   1.282200   |     -     |   5.64   
   7    |   40    |   0.896012   |     -     |   1.48   
   7    |   50    |   0.954022   |     -     |   5.66   
   7    |   60    |   1.168383   |     -     |   1.49   
   7    |   70    |   0.965996   |     -     |   5.64   
   7    |   80    |   1.592232   |     -     |   1.49   
   7    |   90    |   1.247428   |     -     |   5.62   
   7    |   100   |   0.811402   |     -     |   1.50   
   7    |   104   |   1.153562   |     -     |   4.72   
----------------------------------------------------------------------
   7    |    -    |   1.120985   |  1.076969  |   40.66  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
---------------------------------

 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  13    |   10    |   0.986295   |     -     |   1.77   
  13    |   20    |   0.636493   |     -     |   1.51   
  13    |   30    |   0.622460   |     -     |   5.64   
  13    |   40    |   0.769189   |     -     |   1.50   
  13    |   50    |   0.703505   |     -     |   5.68   
  13    |   60    |   0.764656   |     -     |   1.48   
  13    |   70    |   0.554478   |     -     |   5.69   
  13    |   80    |   0.726883   |     -     |   1.51   
  13    |   90    |   0.705398   |     -     |   5.92   
  13    |   100   |   0.775631   |     -     |   1.63   
  13    |   104   |   0.725965   |     -     |   4.75   
----------------------------------------------------------------------
  13    |    -    |   0.734039   |  1.054610  |   41.52  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
---------------------------------

 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
  19    |   10    |   0.689589   |     -     |   1.71   
  19    |   20    |   0.774408   |     -     |   1.48   
  19    |   30    |   0.551916   |     -     |   6.16   
  19    |   40    |   0.560088   |     -     |   1.74   
  19    |   50    |   0.547460   |     -     |   6.18   
  19    |   60    |   0.237577   |     -     |   1.56   
  19    |   70    |   0.390466   |     -     |   6.08   
  19    |   80    |   0.493982   |     -     |   1.49   
  19    |   90    |   0.423881   |     -     |   6.02   
  19    |   100   |   0.820175   |     -     |   1.53   
  19    |   104   |   0.615769   |     -     |   5.12   
Epoch    19: reducing learning rate of group 0 to 5.0000e-04.
----------------------------------------------------------------------
  19    |    -    |   0.558155   |  1.039981  |   43.43  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precisio

 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   4    |   10    |   1.721190   |     -     |   3.06   
   4    |   20    |   1.529012   |     -     |   2.95   
   4    |   30    |   1.295495   |     -     |   9.29   
   4    |   40    |   1.353530   |     -     |   2.82   
   4    |   50    |   1.626776   |     -     |   9.53   
   4    |   60    |   1.545698   |     -     |   2.54   
   4    |   70    |   1.659009   |     -     |   9.31   
   4    |   80    |   1.376158   |     -     |   2.74   
   4    |   90    |   1.424977   |     -     |   9.82   
   4    |   100   |   1.392891   |     -     |   2.67   
   4    |   104   |   1.494296   |     -     |   8.34   
----------------------------------------------------------------------
   4    |    -    |   1.509094   |  1.387021  |   70.40  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
---------------------------------

KeyboardInterrupt: 

In [None]:
test = test_process()
test_dataset = test_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx,
                                         test['name'].values, test['description'].values)
test_sampler = SequentialSampler(test_dataset)
test_loader = DataLoader(test_dataset, sampler=test_sampler, batch_size=tp.batch_size)

In [36]:
result = kfold_inference(test_loader, tp, Textcnn, './checkpoint/textcnn_aug', 5, device)
result['pred'] = result['pred_avg']
result_process(result, label2idx, './submit/textcnn_aug_5fold_avg.csv')