In [1]:
import os
import csv
import torch
import argparse
import numpy as np
from mlp import mlp
import pandas as pd
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from transformers import InputExample, InputFeatures
from transformers import BertConfig, BertForSequenceClassification, BertTokenizer, BertModel
from transformers import glue_convert_examples_to_features as convert_examples_to_features

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the BERT checkpoint from ./model as a BertModel encoder. Per the load
# warning below, the checkpoint also holds classifier weights that BertModel ignores.
MODEL_DIR = './model'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
config = BertConfig.from_pretrained(MODEL_DIR)
tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)
# Put the encoder on the same device the input batches are sent to, otherwise the
# feature-extraction loop fails with a device mismatch whenever CUDA is available.
# Assigning the result (rather than ending the cell with `model.eval()`) avoids
# dumping the full model repr as cell output.
model = BertModel.from_pretrained(MODEL_DIR, config=config).to(device).eval()

Some weights of the model checkpoint at ./model were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
def create_examples(lines, set_type):
    """Convert parsed CSV rows into unlabeled ``InputExample`` objects.

    Args:
        lines: Rows as produced by ``csv.reader`` (a list of lists of str). The
            first row is treated as a header and skipped. The list is NOT modified.
        set_type: Prefix of each example's ``guid`` (``"<set_type>-<i>"``, where
            ``i`` is the 0-based data-row index after the header).

    Returns:
        One ``InputExample`` per non-blank data row. ``text_a`` is the row's third
        column with every ``YZYHUST`` token replaced by ``,`` (presumably a comma
        placeholder used when the data was written; verify against the producer).
        ``text_b`` and ``label`` are None.
    """
    examples = []
    # Slice past the header instead of `del lines[0]`: that mutated the caller's
    # list and raised IndexError on empty input.
    for i, line in enumerate(lines[1:]):
        if not line:
            # csv.reader yields [] for blank lines. pandas.read_csv skips them
            # by default, so skipping here keeps examples row-aligned with labels
            # read through pandas (and avoids an IndexError on line[2]).
            continue
        guid = "%s-%s" % (set_type, i)
        text_a = line[2].replace("YZYHUST", ',')
        examples.append(
            InputExample(guid=guid, text_a=text_a, text_b=None, label=None))
    return examples

def Load_data(tokenizer, file_path, max_length=256, batch_size=16):
    """Tokenize the texts of a CSV file and wrap the encodings in a DataLoader.

    Args:
        tokenizer: HuggingFace tokenizer handed to
            ``glue_convert_examples_to_features``.
        file_path: CSV with a header row; the text is taken from the third
            column (see ``create_examples``).
        max_length: ``max_length`` passed to the feature converter.
        batch_size: Batch size of the returned DataLoader.

    Returns:
        A DataLoader yielding ``(input_ids, attention_mask, token_type_ids)``
        long-tensor batches in file order (no shuffling), so per-batch model
        outputs stay aligned with the CSV rows.
    """
    # Text fields can be very large; raise the csv module's per-field limit.
    # Note this is a process-wide setting.
    csv.field_size_limit(500 * 1024 * 1024)
    # newline='' is required by the csv module to parse quoted fields that contain
    # newlines correctly. UTF-8 matches pandas.read_csv's default, so both readers
    # of this file decode it identically regardless of the platform's locale.
    with open(file_path, 'r', encoding='utf-8', newline='') as f:
        examples = create_examples(list(csv.reader(f)), 'predict')
    # Examples are built with label=None; this list is passed because the
    # converter's "classification" mode asks for a label list.
    label_list = list(range(13))
    features = convert_examples_to_features(
        examples,
        tokenizer,
        label_list=label_list,
        max_length=max_length,
        output_mode="classification",
    )
    all_input_ids = torch.tensor([f.input_ids for f in features],
                                 dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features],
                                      dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features],
                                      dtype=torch.long)
    dataset = TensorDataset(all_input_ids, all_attention_mask,
                            all_token_type_ids)
    return DataLoader(dataset, batch_size=batch_size)

In [6]:
# Tokenize the test CSV and save its ground-truth labels next to it.
file_path = './url/ip/ip_test.csv'
pred_dataloader = Load_data(tokenizer, file_path=file_path)
file = pd.read_csv(file_path)
label = file['label']
# The DataLoader (csv module) and the labels (pandas) parse the same file
# independently. The saved labels are meant to be paired row-by-row with the
# features extracted from pred_dataloader, so both must see the same row count.
assert len(label) == len(pred_dataloader.dataset), (
    f"{len(label)} labels vs {len(pred_dataloader.dataset)} tokenized examples")
label.to_csv('./url/ip/test_label.csv', index=False)



In [7]:
# Run the BERT encoder over the test set and save the pooled output (second element
# of the model's tuple output) of every example, in file order.
# `.to(device)` is a no-op if the model is already there; it guarantees the weights
# live on the same device as the input batches moved below.
model.to(device)
model.eval()  # loop-invariant: set once, not per batch
feature_list = []
with torch.no_grad():
    for batch in tqdm(pred_dataloader, desc="Evaluating"):
        input_ids, attention_mask, token_type_ids = (t.to(device) for t in batch)
        _, pool_outputs = model(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                return_dict=False)
        # Keep features on the CPU: the saved file then loads on machines without
        # a GPU, and matches the CPU-resident classifier that consumes it.
        feature_list.append(pool_outputs.cpu())
features = torch.cat(feature_list, dim=0)
torch.save(features, './url/ip/features_test.pt')

Evaluating: 100%|██████████| 55/55 [03:40<00:00,  4.01s/it]


In [117]:
def eval_on_test(net, labels_path='./features_labels_test.csv',
                 features_path='./features_test.pt', batch_size=64):
    """Print and return the accuracy of ``net`` on pre-extracted test features.

    Args:
        net: Module mapping a batch of feature rows to per-class scores. The
            predicted class is the argmax over the last dimension.
        labels_path: CSV with a ``label`` column, one row per feature row.
        features_path: ``torch.save``-d feature tensor whose rows are in the same
            order as the labels.
        batch_size: Rows scored per forward pass.

    Returns:
        The accuracy as a Python float in [0, 1]. It is also printed.

    Raises:
        ValueError: if there are no examples, or the two files hold different
            numbers of rows (a silent ``zip`` truncation would hide that).
    """
    net.eval()
    # Load data straight onto the network's device (CPU for a parameter-less net).
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    labels = torch.as_tensor(pd.read_csv(labels_path)['label'].to_numpy(),
                             device=device)
    features = torch.load(features_path, map_location=device)
    if len(features) != len(labels):
        raise ValueError(
            f"{len(features)} feature rows but {len(labels)} labels; "
            "the two files must be row-aligned")
    if len(labels) == 0:
        raise ValueError("no test examples to evaluate")

    loader = DataLoader(TensorDataset(features, labels), batch_size=batch_size)
    correct = 0
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for feature, label in loader:
            predict = torch.argmax(net(feature), dim=-1)
            correct += (predict == label).sum().item()
    accuracy = correct / len(labels)
    print(f"accuracy on test :{accuracy}")
    return accuracy


In [102]:
def train(lr, batch_size, epoches, over_write=False,
          labels_path='./features_labels.csv', features_path='./features.pt',
          model_path='./classifier_model/mlp.pkl', log_path='./log/loss.csv'):
    """Train the ``mlp`` classifier head on pre-extracted features.

    After every epoch, ``eval_on_test`` is run and the epoch's mean loss is
    printed. At the end, the per-epoch losses are written to ``log_path`` and
    the state dict to ``model_path``.

    Args:
        lr: Adam learning rate.
        batch_size: Mini-batch size (batches are taken in file order).
        epoches: Number of passes over the training features.
        over_write: If False, resume from ``model_path`` when it exists and append
            to an existing loss log. If True, start from a freshly built ``mlp``
            and replace both files.
        labels_path: CSV with a ``label`` column, row-aligned with the features.
        features_path: ``torch.save``-d feature tensor.
        model_path: State-dict file to resume from and save to.
        log_path: CSV of per-epoch mean losses (column ``loss``).

    Raises:
        ValueError: if the label and feature files hold different numbers of rows.
    """
    # load data; the classifier stays on the CPU, so load tensors there too
    labels = torch.as_tensor(pd.read_csv(labels_path)['label'].to_numpy())
    features = torch.load(features_path, map_location='cpu')
    if len(features) != len(labels):
        raise ValueError(
            f"{len(features)} feature rows but {len(labels)} labels; "
            "the two files must be row-aligned")
    # A single dataset keeps each feature row paired with its label. Zipping two
    # separate loaders would silently truncate to the shorter one.
    loader = DataLoader(TensorDataset(features, labels), batch_size=batch_size)

    # Create output directories up front so saving can't fail after a long run.
    for path in (model_path, log_path):
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    # net and solver
    net = mlp()
    if os.path.exists(model_path) and not over_write:
        net.load_state_dict(torch.load(model_path, map_location='cpu'))
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    # train
    loss_list = []
    for e in range(epoches):
        # eval_on_test() switches the net to eval mode, so re-enable training
        # mode at the start of every epoch (once, not per batch).
        net.train()
        epoch_loss = []
        for feature, label in loader:
            output = net(feature)
            loss = criterion(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss.append(loss.item())
        eval_on_test(net)
        mean_loss = np.mean(epoch_loss)
        loss_list.append(mean_loss)
        print(f'epoch {e+1} loss: {mean_loss}')

    loss_log = pd.DataFrame(loss_list, columns=['loss'])
    if os.path.exists(log_path) and not over_write:
        loss_log = pd.concat([pd.read_csv(log_path), loss_log], axis=0)
    loss_log.to_csv(log_path, index=False)
    torch.save(net.state_dict(), model_path)


In [141]:
# Train the MLP head for 10 epochs. Because over_write=False, this resumes from
# ./classifier_model/mlp.pkl if it exists and appends to the loss log.
train(lr=1e-5, batch_size=32, epoches=10, over_write=False)

accuracy on test :0.8268229365348816
epoch 1 loss: 0.1995803379783562
accuracy on test :0.8268229365348816
epoch 2 loss: 0.20420094601771174
accuracy on test :0.8268229365348816
epoch 3 loss: 0.2036771236647231
accuracy on test :0.8268229365348816
epoch 4 loss: 0.20452660261071287
accuracy on test :0.8268229365348816
epoch 5 loss: 0.2002989074777967
accuracy on test :0.8268229365348816
epoch 6 loss: 0.2070799386225796
accuracy on test :0.8268229365348816
epoch 7 loss: 0.19853777044530338
accuracy on test :0.826171875
epoch 8 loss: 0.20011322557305297
accuracy on test :0.8268229365348816
epoch 9 loss: 0.19757083319806648
accuracy on test :0.8268229365348816
epoch 10 loss: 0.19722782179208784
