# Load dataset

In [1]:
import pandas as pd

# Read the review CSV with the `content` and `label` columns parsed as str,
# then discard every row that has a missing value in any column.
df = pd.read_csv(
    "data.csv",
    dtype={"content": str, "label": str},
).dropna()
df

Unnamed: 0,content,label,start
0,Áo bao đẹp ạ!,POS,5
1,Tuyệt vời,POS,5
2,2day ao khong giong trong,NEG,1
3,"Mùi thơm,bôi lên da mềm da",POS,5
4,"Vải đẹp, dày dặn",POS,5
...,...,...,...
31455,Không đáng tiền,NEG,1
31456,Quần rất đẹp,POS,5
31457,Hàng đẹp đúng giá tiền,POS,5
31458,Chất vải khá ổn,POS,4


In [2]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Thin map-style dataset that exposes an indexable container as-is.

    No copying or transformation is performed: items are returned exactly as
    they are stored in the wrapped container.
    """

    def __init__(self, data):
        """Store a reference to ``data``, the container to index into."""
        self.data = data

    def __getitem__(self, idx):
        """Return whatever ``data[idx]`` yields for the wrapped container."""
        item = self.data[idx]
        return item

    def __len__(self):
        """Return the number of items in the wrapped container."""
        size = len(self.data)
        return size

In [3]:
# Extract the review texts as a plain Python list and wrap them in the
# dataset class so the pipeline can iterate over them.
content_column = df["content"]
data = content_column.to_list()
dataset = CustomDataset(data)
print(len(dataset))

31436


# Load model

In [4]:
import torch
from transformers import pipeline

model = "wonrax/phobert-base-vietnamese-sentiment"

# Use the GPU with half precision when CUDA is available; otherwise fall back
# to CPU with float32 so the cell still runs on a CPU-only host instead of
# failing on the hard-coded "cuda" device.
use_cuda = torch.cuda.is_available()
pipe = pipeline(
    "text-classification",
    model=model,
    framework="pt",
    device="cuda" if use_cuda else "cpu",
    torch_dtype=torch.float16 if use_cuda else torch.float32,  # torch.float16 or torch.float32
    padding='max_length',  # Will pad the sequences up to the model max length
    truncation=True,  # Will truncate the sequences that are longer than the specified max length
)

  from .autonotebook import tqdm as notebook_tqdm


# Inference

In [9]:
import time
from tqdm import tqdm

# Run one full pass over the dataset per batch size and report throughput.
for batch_size in [8]:
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() follows the wall clock and can jump if
    # the system time is adjusted during the run.
    start = time.perf_counter()
    for output in tqdm(pipe(dataset, batch_size=batch_size), total=len(dataset)):
        pass  # predictions are discarded; only the elapsed time matters here
    end = time.perf_counter()

    inference_time = end - start
    num_requests = len(dataset)
    print(f"Batch size: {batch_size}")
    print(f"Total inference time: {round(inference_time, 4)}s")
    print(f"Total sample: {num_requests}")
    print(f"Result: {round(num_requests / inference_time)} sample/s")
    print('---------------------------------------------------------')

100%|███████████████████████████████████| 31436/31436 [00:29<00:00, 1048.22it/s]

Batch size: 8
Total inference time: 29.9955s
Total sample: 31436
Result: 1048 sample/s
---------------------------------------------------------



