In [35]:
import torch 
import time 
from sklearn.model_selection import KFold

import torch.nn as nn
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from gensim.models.phrases import Phrases
from gensim.models import Word2Vec
from src.train_utils import set_seed, ModelSave, get_torch_device, EarlyStop, TrainParams
from src.metric import  multi_cls_metrics,multi_cls_log
from src.dataset.tokenizer import GensimTokenizer
from src.loss import BootstrapCrossEntropy

from iflytek_app.dataset import MixDataset
from iflytek_app.process import train_process, test_process, result_process,kfold_inference
from iflytek_app.models import TextcnnPseudoLabel, Textcnn
device = get_torch_device()
set_seed()

No GPU available, using the CPU instead.


In [2]:
c2v = GensimTokenizer( Word2Vec.load('./checkpoint/char_min1_win5_sg_d100'))
w2v = GensimTokenizer(Word2Vec.load('./checkpoint/phrase_min1_win5_sg_d100'))
w2v.init_vocab()
c2v.init_vocab()
phraser = Phrases.load('./checkpoint/phrase_tokenizer')
df, label2idx = train_process()
test = test_process()
label2idx.update({'unlabel':-1})


                id           l1           l2          len
count  4199.000000  4199.000000  4199.000000  4199.000000
mean   2099.000000     8.969278    37.087878    46.057156
std    1212.291219     4.576621    79.204914    79.332999
min       0.000000     2.000000     1.000000     4.000000
25%    1049.500000     5.000000     6.000000    15.000000
50%    2099.000000     8.000000    12.000000    22.000000
75%    3148.500000    12.000000    26.000000    36.000000
max    4198.000000    32.000000   946.000000   961.000000
{'14784131 14858934 14784131 14845064': 0, '14852788 14717848 15639958 15632020': 1, '14844856 14724258 14925237 14854807': 2, '14925756 15639967 14853254 14728639': 3, '14844593 14924945': 4, '15709098 14716590 14924703 14779559': 5, '14726332 14728344 14854542 14844591': 6, '14858934 15636660 15704193 14849963': 7, '15710359 14847407 14845602 14859696': 8, '14794687 14782344': 9, '15630486 15702410 14718849 15709093': 10, '15632285 15706536 14721977 14925219': 11, '147829

In [84]:
log_steps = 10
save_steps = 20
kfold=5
tp = TrainParams(
    epoch_size=30,
    lr=1e-3,
    loss_fn=BootstrapCrossEntropy(),
    max_seq_len=1000,
    batch_size=64,
    dropout_rate=0.5,
    label_size = len(label2idx),
    vocab_size = w2v.vocab_size,
    embedding_dim = w2v.embedding_size + c2v.embedding_size,
    embedding1=c2v.embedding, 
    embedding2 =w2v.embedding,
    filter_size=70,
    kernel_size_list = [2,3,4,5],
    hidden_size = 100,
    early_stop_params = {
        'monitor':'f1_micro',
        'mode':'max',
        'min_delta': 0,
        'patience':5,
        'verbose':False
    },
    scheduler_params={'mode': 'max',
                     'factor': 0.3,
                     'patience': 1,
                     'verbose': True,
                     'threshold':0.0001,
                     'threshold_mode':'rel',
                     'cooldown':0,
                     'min_lr':1e-6},
    T1=5,
    T2=20,
    alpha_f=3,
    q=0.7
    
)

In [85]:
kf = KFold(n_splits=kfold, shuffle=True, random_state=24)
for fold,(train_index, valid_index) in enumerate(kf.split(df)):
    train, valid = df.iloc[train_index], df.iloc[valid_index]

    # combine label and unlabel data
    train_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx, 
                               train['name'].values.tolist() + test['name'].values.tolist(),
                               train['description'].values.tolist() + test['name'].values.tolist(),
                               train['label'].values.tolist() + ['unlabel']* test.shape[0])
    valid_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx, 
                               valid['name'].values, valid['description'].values, valid['label'].values)
    train_sampler = RandomSampler(train_dataset)
    valid_sampler = SequentialSampler(valid_dataset)
    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=tp.batch_size)
    valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=tp.batch_size)
    
    tp.update({'num_train_steps': len(train_dataset)})
    
    CKPT = './checkpoint/textcnn_pseudo_label/k{}'.format(fold)
    saver = ModelSave(CKPT, continue_train=False)
    saver.init()
    tb = SummaryWriter(CKPT)
    es = EarlyStop(**tp.early_stop_params)
    global_step = 0
    model = TextcnnPseudoLabel(tp)
    optimizer, scheduler = model.get_optimizer()
    
    for epoch_i in range(tp['epoch_size']):
        if global_step==1:
            print(model)
        print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10}  | {'Elapsed':^9}")
        print("-"*60)

        # Measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()
        total_loss, batch_loss, batch_counts = 0, 0, 0

        model.train()
        for step, batch in enumerate(train_loader):
            global_step +=1
            batch_counts +=1

            #Forward propogate
            model.zero_grad()
            feature = {k:v.to(device) for k, v in batch.items()}
            logits = model(feature)
            loss= model.compute_loss(feature, logits)
            tb.add_scalar('loss/sup_loss', model.supervised_loss, global_step=global_step)
            tb.add_scalar('loss/unsup_loss', model.unsupervised_loss, global_step=global_step)
            tb.add_scalar('loss/avg_loss', loss, global_step=global_step)
            batch_loss += loss.item()
            total_loss += loss.item()
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

            # Log steps for train loss logging
            if (step % log_steps == 0 and step != 0) or (step == len(train_loader) - 1):
                time_elapsed = time.time() - t0_batch
                print(f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^9} | {time_elapsed:^9.2f}")
                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()

            # Save steps for ckpt saving and dev evaluation
            if (step % save_steps == 0 and step != 0) or (step == len(train_loader) - 1):
                val_metrics = multi_cls_metrics(model, valid_loader, device)
                for key, val in val_metrics.items():
                    tb.add_scalar(f'metric/{key}', val, global_step=global_step)
                avg_train_loss = total_loss / step
                tb.add_scalars('loss/train_valid',{'train': avg_train_loss,
                                                    'valid': val_metrics['val_loss']}, global_step=global_step)
                saver(total_loss / step, val_metrics['val_loss'], epoch_i, global_step, model, optimizer, scheduler)
        model.epoch_update()
        # On Epoch End: calcualte train & valid loss and log overall metrics
        time_elapsed = time.time() - t0_epoch
        val_metrics = multi_cls_metrics(model, valid_loader, device)
        avg_train_loss = total_loss / step
        scheduler.step(val_metrics['f1_micro'])
        print("-"*70)
        print(f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_metrics['val_loss']:^10.6f} | {time_elapsed:^9.2f}")
        multi_cls_log(epoch_i, val_metrics)
        print("\n")
        if es.check(val_metrics):
            break

./checkpoint/textcnn_pseudo_label/k0 model cleaned
 Epoch  |  Batch  |  Train Loss  |  Val Loss   |  Elapsed 
------------------------------------------------------------
   1    |   10    |   3.652543   |     -     |   8.09   
   1    |   20    |   2.411354   |     -     |   7.31   
   1    |   30    |   2.722228   |     -     |   10.58  
   1    |   40    |   2.456169   |     -     |   7.48   
   1    |   50    |   2.255050   |     -     |   10.46  
   1    |   60    |   2.221796   |     -     |   7.19   
   1    |   70    |   2.065554   |     -     |   10.63  
   1    |   80    |   2.071953   |     -     |   6.91   
----------------------------------------------------------------------
   1    |    -    |   2.527738   |    nan     |   71.89  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   1    |  20.241%  |  72.271%  |  30.177%  |     18.680%     

   8    |   10    |   0.577558   |     -     |   8.76   
   8    |   20    |   0.475350   |     -     |   8.13   
   8    |   30    |   0.348466   |     -     |   11.48  
   8    |   40    |   0.349786   |     -     |   8.31   
   8    |   50    |   0.327266   |     -     |   11.72  
   8    |   60    |   0.384040   |     -     |   7.48   
   8    |   70    |   0.323568   |     -     |   11.56  
   8    |   80    |   0.353554   |     -     |   7.89   
----------------------------------------------------------------------
   8    |    -    |   0.399668   |    nan     |   79.17  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   8    |  55.740%  |  84.761%  |  61.326%  |     67.419%     |   55.740%    |  72.024%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------

  15    |   10    |   0.281151   |     -     |   7.73   
  15    |   20    |   0.425981   |     -     |   6.91   
  15    |   30    |   0.078053   |     -     |   9.88   
  15    |   40    |   0.074297   |     -     |   6.75   
  15    |   50    |   0.081680   |     -     |   9.85   
  15    |   60    |   0.098030   |     -     |   6.79   
  15    |   70    |   0.116711   |     -     |   9.94   
  15    |   80    |   0.055521   |     -     |   6.50   
Epoch    15: reducing learning rate of group 0 to 9.0000e-05.
----------------------------------------------------------------------
  15    |    -    |   0.154942   |    nan     |   67.40  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  15    |  55.867%  |  84.699%  |  61.145%  |     68.026%     |   55.867%    |  72.619%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Mi

   4    |   10    |   0.906921   |     -     |   7.59   
   4    |   20    |   0.891986   |     -     |   6.83   
   4    |   30    |   0.852873   |     -     |   9.89   
   4    |   40    |   0.754559   |     -     |   6.80   
   4    |   50    |   0.716283   |     -     |   9.91   
   4    |   60    |   0.701962   |     -     |   6.83   
   4    |   70    |   0.846819   |     -     |   9.86   
   4    |   80    |   0.815850   |     -     |   6.57   
----------------------------------------------------------------------
   4    |    -    |   0.822243   |    nan     |   67.28  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   4    |  56.307%  |  78.025%  |  64.487%  |     62.793%     |   56.307%    |  71.786%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------

  11    |   10    |   0.345998   |     -     |   7.51   
  11    |   20    |   0.358424   |     -     |   6.77   
  11    |   30    |   0.127483   |     -     |   9.74   
  11    |   40    |   0.118938   |     -     |   6.78   
  11    |   50    |   0.127518   |     -     |   9.81   
  11    |   60    |   0.115992   |     -     |   6.78   
  11    |   70    |   0.139249   |     -     |   9.86   
  11    |   80    |   0.095869   |     -     |   6.51   
----------------------------------------------------------------------
  11    |    -    |   0.183009   |    nan     |   66.77  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  11    |  65.040%  |  79.345%  |  71.663%  |     74.042%     |   65.040%    |  76.190%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------

   3    |   10    |   1.560879   |     -     |   7.48   
   3    |   20    |   1.310359   |     -     |   6.80   
   3    |   30    |   1.181938   |     -     |   9.76   
   3    |   40    |   1.101598   |     -     |   6.72   
   3    |   50    |   1.061891   |     -     |   9.79   
   3    |   60    |   1.076083   |     -     |   6.74   
   3    |   70    |   1.113117   |     -     |   9.76   
   3    |   80    |   0.983279   |     -     |   6.53   
----------------------------------------------------------------------
   3    |    -    |   1.193154   |    nan     |   66.63  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   3    |  40.591%  |  80.825%  |  57.352%  |     55.306%     |   40.591%    |  61.667%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------

  10    |   10    |   0.472039   |     -     |   7.50   
  10    |   20    |   0.479335   |     -     |   6.82   
  10    |   30    |   0.209789   |     -     |   9.69   
  10    |   40    |   0.172064   |     -     |   6.70   
  10    |   50    |   0.159257   |     -     |   9.81   
  10    |   60    |   0.199489   |     -     |   6.70   
  10    |   70    |   0.214649   |     -     |   9.69   
  10    |   80    |   0.241846   |     -     |   6.41   
----------------------------------------------------------------------
  10    |    -    |   0.274459   |    nan     |   66.33  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  10    |  59.890%  |  82.515%  |  65.528%  |     71.731%     |   59.890%    |  74.643%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------

  17    |   10    |   0.205933   |     -     |   7.46   
  17    |   20    |   0.328592   |     -     |   6.72   
  17    |   30    |   0.048275   |     -     |   9.82   
  17    |   40    |   0.028047   |     -     |   6.77   
  17    |   50    |   0.018539   |     -     |   9.74   
  17    |   60    |   0.019889   |     -     |   6.71   
  17    |   70    |   0.024460   |     -     |   9.74   
  17    |   80    |   0.035984   |     -     |   6.45   
Epoch    17: reducing learning rate of group 0 to 9.0000e-05.
----------------------------------------------------------------------
  17    |    -    |   0.091289   |    nan     |   66.42  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  17    |  60.699%  |  82.411%  |  66.946%  |     77.847%     |   60.699%    |  75.595%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Mi

   4    |   10    |   0.931828   |     -     |   7.40   
   4    |   20    |   0.966060   |     -     |   6.81   
   4    |   30    |   0.798671   |     -     |   9.81   
   4    |   40    |   0.825807   |     -     |   6.69   
   4    |   50    |   0.765596   |     -     |   9.77   
   4    |   60    |   0.818001   |     -     |   6.77   
   4    |   70    |   0.830591   |     -     |   9.94   
   4    |   80    |   0.779261   |     -     |   6.44   
----------------------------------------------------------------------
   4    |    -    |   0.851125   |    nan     |   66.67  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   4    |  54.470%  |  82.225%  |  62.225%  |     61.955%     |   54.470%    |  70.833%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------

  11    |   10    |   0.342772   |     -     |   7.49   
  11    |   20    |   0.300221   |     -     |   6.77   
  11    |   30    |   0.138371   |     -     |   9.71   
  11    |   40    |   0.181984   |     -     |   6.71   
  11    |   50    |   0.195154   |     -     |   9.78   
  11    |   60    |   0.113971   |     -     |   6.75   
  11    |   70    |   0.169412   |     -     |   9.72   
  11    |   80    |   0.146971   |     -     |   6.46   
----------------------------------------------------------------------
  11    |    -    |   0.202892   |    nan     |   66.43  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  11    |  58.750%  |  82.426%  |  66.050%  |     67.554%     |   58.750%    |  73.810%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------

   3    |   10    |   1.652635   |     -     |   7.45   
   3    |   20    |   1.348229   |     -     |   6.77   
   3    |   30    |   1.222528   |     -     |   9.79   
   3    |   40    |   1.126942   |     -     |   6.75   
   3    |   50    |   1.009641   |     -     |   9.82   
   3    |   60    |   1.125232   |     -     |   6.71   
   3    |   70    |   1.011737   |     -     |   9.76   
   3    |   80    |   0.972132   |     -     |   6.48   
----------------------------------------------------------------------
   3    |    -    |   1.204292   |    nan     |   66.58  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
   3    |  44.809%  |  83.400%  |  55.084%  |     52.307%     |   44.809%    |  67.819%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------

  10    |   10    |   0.548895   |     -     |   7.49   
  10    |   20    |   0.415685   |     -     |   6.78   
  10    |   30    |   0.198193   |     -     |   9.75   
  10    |   40    |   0.229549   |     -     |   6.75   
  10    |   50    |   0.196354   |     -     |   9.73   
  10    |   60    |   0.200442   |     -     |   6.72   
  10    |   70    |   0.205706   |     -     |   9.74   
  10    |   80    |   0.252531   |     -     |   6.50   
----------------------------------------------------------------------
  10    |    -    |   0.287781   |    nan     |   66.49  


 Epoch  | Macro Acc | Macro AUC | Macro AP  | Macro Precision | Macro Recall | Macro F1 
------------------------------------------------------------------------------------------
  10    |  54.746%  |  84.816%  |  62.886%  |     60.575%     |   54.746%    |  73.897%  
 Epoch  | Micro Acc | Micro AUC | Micro AP  | Micro Precision | Micro Recall | Micro F1 
------------------------------------------------------

## K-Fold Evalutaion

In [21]:
test = test_process()
test_dataset = MixDataset(tp.max_seq_len, w2v, c2v, phraser, label2idx,  test['name'].values, test['description'].values)
test_sampler = SequentialSampler(test_dataset)
test_loader = DataLoader(test_dataset, sampler=test_sampler, batch_size=tp.batch_size)

In [25]:
result = kfold_inference(test_loader, tp, Textcnn, './checkpoint/textcnn_pseudo_label', 5, device)
result['pred'] = result['pred_avg']
result_process(result, label2idx, './submit/textcnn_pseudo_label_5fold_avg.csv')