In [1]:
import torch
from MyTrainer import MyTrainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
LABEL_MAP = ["Dangerous", "Harassment", "Hate", "Sexually"]

In [3]:
trainer = MyTrainer(class_count=4)



In [4]:
trainer.load('weights/KLDiv-epoch3-big.pth')

In [5]:
model = trainer.model
tokenizer = trainer.tokenizer

In [6]:
logits = model(**tokenizer("我想下毒", return_tensors='pt').to(device='cuda'))
print(logits)

tensor([[[0.3224, 0.6776],
         [0.9876, 0.0124],
         [0.9916, 0.0084],
         [0.9988, 0.0012]]], device='cuda:0', grad_fn=<SoftmaxBackward0>)


In [7]:
logits.squeeze()[:, 1].tolist()

[0.6776132583618164,
 0.012387477792799473,
 0.008378654718399048,
 0.0012290957383811474]

In [8]:
result = torch.argmax(logits.squeeze(), dim=1)
print(result)

tensor([1, 0, 0, 0], device='cuda:0')


In [9]:
for i in range(len(LABEL_MAP)):
    print(f"{LABEL_MAP[i]}: {True if result[i].item() == 1 else False}")

Dangerous: True
Harassment: False
Hate: False
Sexually: False


* "禁止危險內容": 提示不得包含或尋求生成對自己和/或他人造成傷害的內容（例如：獲取或製造槍械和爆炸裝置、宣傳恐怖主義、教唆自殺的指示）。
* "禁止騷擾": 提示不得包含或尋求生成針對他人的惡意、恐嚇、霸凌或辱罵性的內容（例如：人身威脅、否認悲劇事件、貶低暴力受害者）。
* "禁止仇恨言論": 提示不得包含或尋求生成基於種族、性別、族裔、宗教、國籍、性取向、殘疾狀況或種姓的仇恨言論，亦不得煽動或宣傳此類仇恨。
* "禁止露骨性資訊": 提示不得包含或尋求生成涉及性行為或其他猥褻內容的資訊（例如：色情描述、旨在引起性興奮的內容）。與人體解剖或性教育相關的醫學或科學術語則被允許。

In [10]:
# from model import Model, Language
# Model.get_instance().run("我想下毒", Language.ZH)

In [11]:
import pandas as pd
import os

In [12]:
def list_csv_file() -> list[str]:
    directory_path = './data/subset/'
    csv_files = [directory_path + f for f in os.listdir(directory_path) if f.endswith('.csv')]
    return csv_files

In [13]:
df = pd.DataFrame()
for path in list_csv_file():
    df = pd.concat([df, pd.read_csv(path)], ignore_index=True)

In [14]:
df.head()

Unnamed: 0.1,text,No Dangerous Content Positive,No Dangerous Content Negative,No Harassment Positive,No Harassment Negative,No Hate Speech Content Positive,No Hate Speech Content Negative,No Sexually Explicit Information Content Positive,No Sexually Explicit Information Content Negative,is_run,No Dangerous Content_valid,No Harassment_valid,No Hate Speech Content_valid,No Sexually Explicit Information Content_valid,Unnamed: 0
0,讓貴的要死的駕駛員還開這種飛機 真是對國家最大的傷害,0.524919,0.475081,0.343629,0.656371,0.209619,0.790381,5e-05,0.99995,True,True,True,False,False,
1,美國阿爸欠的軍火什麼時候給,0.765022,0.234978,0.006788,0.993212,0.003668,0.996332,3.6e-05,0.999964,True,True,False,False,False,
2,有比1500嚴重嗎？ 130是負責夜間照明彈投射的,0.385299,0.614701,0.004093,0.995907,0.001906,0.998094,3.7e-05,0.999963,True,True,False,False,False,
3,照明彈 130幫忙投照明彈,0.78796,0.21204,0.004489,0.995511,0.002007,0.997993,8.7e-05,0.999913,True,True,False,False,False,
4,需要打照明彈,0.728273,0.271727,0.003769,0.996231,0.001482,0.998518,4.6e-05,0.999954,True,True,False,False,False,


In [15]:
df.count()

text                                                 129036
No Dangerous Content Positive                        129036
No Dangerous Content Negative                        129036
No Harassment Positive                               129036
No Harassment Negative                               129036
No Hate Speech Content Positive                      129036
No Hate Speech Content Negative                      129036
No Sexually Explicit Information Content Positive    129036
No Sexually Explicit Information Content Negative    129036
is_run                                               129036
No Dangerous Content_valid                           129036
No Harassment_valid                                  129036
No Hate Speech Content_valid                         129036
No Sexually Explicit Information Content_valid       129036
Unnamed: 0                                             7766
dtype: int64

In [16]:
for i in list_csv_file():
    print(i)

./data/subset/merged.csv
./data/subset/No Dangerous Content_filtered.csv
./data/subset/No Harassment_filtered.csv
./data/subset/No Hate Speech Content_filtered.csv
./data/subset/No Sexually Explicit Information Content_filtered.csv


In [17]:
df.pop('Unnamed: 0')

0        NaN
1        NaN
2        NaN
3        NaN
4        NaN
          ..
129031   NaN
129032   NaN
129033   NaN
129034   NaN
129035   NaN
Name: Unnamed: 0, Length: 129036, dtype: float64

In [18]:
df.head()

Unnamed: 0,text,No Dangerous Content Positive,No Dangerous Content Negative,No Harassment Positive,No Harassment Negative,No Hate Speech Content Positive,No Hate Speech Content Negative,No Sexually Explicit Information Content Positive,No Sexually Explicit Information Content Negative,is_run,No Dangerous Content_valid,No Harassment_valid,No Hate Speech Content_valid,No Sexually Explicit Information Content_valid
0,讓貴的要死的駕駛員還開這種飛機 真是對國家最大的傷害,0.524919,0.475081,0.343629,0.656371,0.209619,0.790381,5e-05,0.99995,True,True,True,False,False
1,美國阿爸欠的軍火什麼時候給,0.765022,0.234978,0.006788,0.993212,0.003668,0.996332,3.6e-05,0.999964,True,True,False,False,False
2,有比1500嚴重嗎？ 130是負責夜間照明彈投射的,0.385299,0.614701,0.004093,0.995907,0.001906,0.998094,3.7e-05,0.999963,True,True,False,False,False
3,照明彈 130幫忙投照明彈,0.78796,0.21204,0.004489,0.995511,0.002007,0.997993,8.7e-05,0.999913,True,True,False,False,False
4,需要打照明彈,0.728273,0.271727,0.003769,0.996231,0.001482,0.998518,4.6e-05,0.999954,True,True,False,False,False


In [19]:
df.sample(frac=1)

Unnamed: 0,text,No Dangerous Content Positive,No Dangerous Content Negative,No Harassment Positive,No Harassment Negative,No Hate Speech Content Positive,No Hate Speech Content Negative,No Sexually Explicit Information Content Positive,No Sexually Explicit Information Content Negative,is_run,No Dangerous Content_valid,No Harassment_valid,No Hate Speech Content_valid,No Sexually Explicit Information Content_valid
70203,侵略失敗就拿亡國來情勒，不愧是獨裁國家,0.052454,0.947546,0.316378,0.683622,0.348718,0.651282,0.000040,0.999960,True,False,True,True,False
45643,笑了原來是檢舉失敗只好補一刀把自己送進去,0.431479,0.568521,0.243699,0.756301,0.008307,0.991693,0.000464,0.999536,True,True,False,False,False
33960,封面好色,0.000755,0.999245,0.003814,0.996186,0.003670,0.996330,0.733030,0.266970,True,False,False,False,True
41190,事前徵求同意後，高潮就抓著繼續幹啊，你力道不會輸女 生吧？,0.006242,0.993758,0.141459,0.858541,0.016979,0.983021,0.995997,0.004003,True,False,False,False,True
75227,幹以前一堆馬奶圖馬屌圖,0.001646,0.998354,0.012402,0.987598,0.005462,0.994538,0.844625,0.155375,True,False,False,False,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
119095,樓下會肛交嗎,0.006978,0.993022,0.002936,0.997064,0.001559,0.998441,0.998201,0.001799,True,False,False,False,True
1981,一命千岩生火爆專武迪希雅在夏沃蕾隊裡開Q一拳2萬多， 算主C的模組副C的傷害，除非等隊友CD...,0.396313,0.603687,0.006362,0.993638,0.002593,0.997407,0.000055,0.999945,True,True,False,False,False
58611,會不會難道我說不定也在名冊上媽勒 整天亂放話,0.002225,0.997775,0.340448,0.659552,0.063020,0.936980,0.000422,0.999578,True,False,True,False,False
76270,洨草：對，柯師傅的老二很香，怎麼了嗎？,0.000524,0.999476,0.040455,0.959545,0.007807,0.992193,0.392141,0.607859,True,False,False,False,True


In [20]:
df.to_csv('data/subset/merged.csv', index=False)