In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from torch.optim import AdamW
import torch
from torch import nn
import os
from tqdm import tqdm
from collections import Counter
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def _class_filtered_columns(frame, excluded_class):
    """Return (drugName, condition, 'Product Class') lists for rows not in ``excluded_class``."""
    # Build the row filter once and reuse it for all three columns.
    kept = frame[frame['Product Class'] != excluded_class]
    return kept['drugName'].tolist(), kept['condition'].tolist(), kept['Product Class'].tolist()


def load_csv(train_path, test_path, excluded_class='Else'):
    """Read the train and test CSV files and return their name/condition/class columns.

    Rows whose 'Product Class' equals ``excluded_class`` are dropped from both files.

    Args:
        train_path: Path of the training CSV; it must contain the columns
            'drugName', 'condition' and 'Product Class'.
        test_path: Path of the test CSV with the same columns.
        excluded_class: Value of 'Product Class' to filter out (default 'Else').

    Returns:
        A 6-tuple of lists:
        (train_name, train_condition, train_class, test_name, test_condition, test_class).
        The three lists of each split are row-aligned.

    Raises:
        KeyError: if one of the required columns is missing.
    """
    # Read both files before touching any column, as the original code did.
    train_df = pd.read_csv(train_path)
    test_df = pd.read_csv(test_path)

    train_name, train_condition, train_class = _class_filtered_columns(train_df, excluded_class)
    test_name, test_condition, test_class = _class_filtered_columns(test_df, excluded_class)

    return train_name, train_condition, train_class, test_name, test_condition, test_class

# Load the name / condition / class columns of both splits once at notebook level.
# The paths are relative, so they resolve against the kernel's current working directory.
train_name, train_condition, train_class, test_name, test_condition, test_class = load_csv('./data/drugsComTrain_raw_addclass.csv', './data/drugsComTest_raw_addclass.csv')

In [3]:
def count_unique_strings(str_list):
    """Print how many distinct values ``str_list`` holds and how often each one occurs.

    Values are reported in first-seen order. Nothing is returned.
    """
    occurrences = Counter(str_list)

    # Assemble the whole report first, then emit it in one go.
    report = [
        f"总共有 {len(occurrences)} 个唯一的字符串。",
        "每个字符串的出现次数如下:",
    ]
    report.extend(f"'{value}': {times} 次" for value, times in occurrences.items())
    print("\n".join(report))

In [4]:
# Integer label <-> product-class name. The position in the list below IS the label,
# so both lookup tables come from a single source and cannot drift apart.
Num_to_Class = dict(enumerate([
    'Analgesics',
    'Mood Stabilizers',
    'Antibiotics',
    'Antiseptics',
    'Antimalarial',
    'Antipiretics',
]))

# Reverse lookup: class name -> integer label.
Class_to_Num = {class_name: label for label, class_name in Num_to_Class.items()}

In [5]:
# Extra drug-name strings for the 'Mood Stabilizers' class; later in the notebook this
# list is passed to drug_name_filter together with label 1.
# The literal contains repeated entries (e.g. 'Sertralin', 'Lamictin', 'Tegretin');
# clean() removes them at runtime before the names are appended to the training lists.
# NOTE(review): the entries look like hand-made variations of real drug names — confirm provenance.
augmented_Mood_Stabilizers = [
    'Topiramax', 'Amitriptylin', 'Lamotrigin', 'Sertralin', 'Venlafaxin', 
    'Symbyaxa', 'Oxcarbamazepine', 'Lithia', 'Mirtazapin', 'Lamictol', 
    'Jolivetta', 'Desvenlafaxin', 'Escitalopramin', 'Depakot', 'Cymbalto', 
    'Zolofta', 'Carbamazepin', 'Milnacipranol', 'Lexapram', 'Quetiapin', 
    'Citalopramin', 'Doxepine', 'Effexor XR', 'Abilify Melt', 'Prozax', 
    'Emsom', 'Divalproex', 'Celexum', 'Paxilum', 'Carbatrola', 
    'Limbitrola', 'Serzona', 'Tegretola', 'Elavila', 'Remerona', 
    'Parnata', 'Mephobarbitol', 'Depakote XR', 'Tegretol SR', 'S-adenosyl', 
    'Nefazodon', 'Brisdella', 'Desyrela', 'Trileptol', 'Imipramin', 
    'Epitola', 'Eskalitha', 'Tranylcypromin', 'Pexeva Plus', 'Fluoxapine', 
    'Paxil SR', 'Lithobida', 'Limbitrol XR', 'Valproate', 'Stavzora', 
    'Pamelora', 'Risperdal Tab', 'Equetrol', 'Gabarona', 'Depaken', 
    'Eslicarbamazepine', 'Triavila', 'Asendina', 'Lamictal SR', 'Lamictal DT', 
    'Maprotilin', 'Aventyl HCL', 'Budeprion XR', 'Buprobana', 'Sinequana', 
    'Prozax Weekly', 'Phentrida', 'Vivactila', 'Desyrel XR', 'Protriptylin', 
    'Zyprexa IM', 'Depakote CR', 'Trokendi SR', 'Topirax', 'Amitriptyl', 
    'Lamotrix', 'Sertralix', 'Venlafaxor', 'Symbyaxa XR', 'Oxcarva', 
    'Lithiuma', 'Mirtazapix', 'Lamictin', 'Jolivax', 'Desvenlafaxor', 
    'Escitopram', 'Depakine', 'Cymbaltex', 'Zoloftin', 'Carbamaz', 
    'Milnacipra', 'Lexapin', 'Quetiapix', 'Citapram', 'Doxepix', 
    'Effexor Plus', 'Abilify DT', 'Prozapine', 'Emsamix', 'Divalpro', 
    'Celexin', 'Paxilum CR', 'Carbatrolin', 'Limbitrin', 'Serzapine', 
    'Tegretin', 'Elavix', 'Remerix', 'Parnatol', 'Mephobar', 
    'Depakotex', 'Tegretin XR', 'S-adenometh', 'Nefazodix', 'Brisdol', 
    'Desyrelin', 'Trileptix', 'Imipramix', 'Epitolin', 'Eskalix', 
    'Tranylpromine', 'Pexevix', 'Fluoxapin', 'Paxilum SR', 'Lithobix', 
    'Limbitrolin', 'Valproxin', 'Stavzorin', 'Pamelorin', 'Risperin', 
    'Equetrix', 'Gabaronix', 'Depakinex', 'Eslicarba', 'Triavix', 
    'Asendix', 'Lamictin SR', 'Lamictin DT', 'Maprotix', 'Aventix', 
    'Budeprix', 'Buproxin', 'Sinequix', 'Prozap Weekly', 'Phentrix', 
    'Vivactix', 'Desyrelix', 'Protriptix', 'Zyprexin', 'Depakotix', 
    'Trokendix', 'Topiramin', 'Amitripx', 'Lamotrin', 'Sertralin', 
    'Venlafin', 'Symbyxin', 'Oxcarbin', 'Lithin', 'Mirtazin', 
    'Lamictin', 'Jolivin', 'Desvenlafin', 'Escitral', 'Depakin', 
    'Cymbalin', 'Zolofin', 'Carbazin', 'Milnacin', 'Lexarin', 
    'Quetialin', 'Citalin', 'Doxelin', 'Effexin', 'Abilifin', 
    'Prozalin', 'Emsalin', 'Divalin', 'Celetin', 'Paxilin', 
    'Carbalin', 'Limbitin', 'Serzalin', 'Tegretin', 'Elavin', 
    'Remerin', 'Parnatin', 'Mephobin', 'Depakolin', 'Tegretolin', 
    'S-adenin', 'Nefazolin', 'Brisdolin', 'Desyrelin', 'Trileptin', 
    'Imipramin', 'Epitolin', 'Eskalitin', 'Tranylin', 'Pexevolin', 
    'Fluoxalin', 'Paxilin SR', 'Lithiolin', 'Limbitolin', 'Valprolin', 
    'Stavzolin', 'Pamelorin', 'Risperolin', 'Equetrolin', 'Gabaronin', 
    'Depakolin', 'Eslicarin', 'Triavilin', 'Asendolin', 'Lamictolin', 
    'Lamictolin DT', 'Maprolin', 'Aventylin', 'Budeprolin', 'Buprolin', 
    'Sinequolin', 'Prozalin Weekly', 'Phentrolin', 'Vivactolin', 
    'Desyrolin', 'Protriptolin', 'Zyprexolin', 'Depakotolin', 'Trokendolin'
]

# Extra drug-name strings for the 'Antibiotics' class; later in the notebook this list
# is passed to drug_name_filter together with label 2.
# The recorded check_unique output reports no repeated entry in this list.
# NOTE(review): the entries look like hand-made variations of real drug names — confirm provenance.
augmented_Antibiotics = [
    'Trimethoprimex', 'Azithromaxin', 'Doxycyclin', 'Augmentin SR', 'Macrozole', 
    'Levoxacin', 'Cephalexine', 'Aczonide', 'Epiduo Plus', 'Amoxiclavith', 
    'Sulfatrimox', 'Solodyn XR', 'Clindamyxin', 'Zianex', 'Nitrofurantin', 
    'Metronidazol', 'Moxiflocin', 'Levaquine', 'Ciproflox', 'Aveloxin', 
    'Augmentin ES', 'Terbinafex', 'Clarithromax', 'Bactrim XS', 'AmoxiClav', 
    'Ciproxin', 'Orsythia XR', 'SMZ-TMP Plus', 'Bactrim Forte', 'Ceftriaxon', 
    'Retapamulix', 'Vigamoxin', 'Cefdinirex', 'ClindaRetin', 'Nystatine', 
    'Zithromaxin', 'Minocyclin', 'Amoxilin', 'Penicillin VK Plus', 'Duac Plus', 
    'Dapsonix', 'Zyvoxin', 'Rocephin XR', 'Tindamaxin', 'MetroGel-V', 'Flagyl XR', 
    'Doryxin', 'Tinidazol', 'Penicillin VK Forte', 'Vancomycin XR', 'Oracea Plus', 
    'Biaxin XS', 'Cefuroxim', 'Keflexin', 'Vantin XR', 'Monodoxin', 
    'Bismuth Triplex', 'Silvaderm', 'Ceftinex', 'Cefpodoxim', 'Omnicef XR', 
    'Clindessin', 'Penciclovirex', 'Soolantra Plus', 'Pylera Forte', 'BenzoylEryth', 
    'Cefprozil XR', 'Amoxil Plus', 'Acanya Forte', 'Biaxin SR', 'Benzaclin Plus', 
    'Doxy 200', 'Septra Plus', 'Sulfacet-S', 'Septra Forte', 'Sulfazide', 
    'Tobramaxin', 'Prevpac Plus', 'Sulfa-Sulfur', 'Gatiflox', 'Erythromax', 
    'Bactroban Plus', 'Naftifin', 'Factive XR', 'MetroGel Plus', 'Cedaxin', 
    'Macrodantin XR', 'Linezolidin', 'Fidaxomin', 'TobraDex Plus', 'Fosfomycin XR', 
    'Paromomycine', 'Minocinex', 'Xifaxan Plus', 'Veltin XR', 'Mupirocin Plus', 
    'Hiprex Forte', 'Cefazolin XR', 'Ciprodexin', 'Lorzone Plus', 'Penicillin GX', 
    'Cleocinex', 'Rifampix', 'Gemiflox', 'Tetracyclin', 'DexaNeoPoly', 
    'Stromectol Plus', 'Isoniazidin', 'Humatin XR', 'Griseofulvine', 'Silvadiazine', 
    'Vibramaxin', 'Rifaximin XR', 'DexaTobra', 'AzithroPack', 'Rifadin XR', 
    'Cefiximax', 'Zosyn Plus', 'Zylet XR', 'Flagyl SR', 'Benzamycin Plus', 
    'Clindagel XR', 'CiproDexa', 'Zymaxin', 'Onexton Plus', 'Blephamide XR', 
    'Azasite Plus', 'Ery-Tab XR', 'Itraconex', 'Amikacin XR', 'Prednisulfa', 
    'PipTazo', 'Xolegel Plus', 'Adoxa XR', 'Amikin Plus', 'PolyTrimox', 
    'QuadraBiotic', 'Besiflox', 'Cleocin XR', 'Dynacin Plus', 'Ampicilline', 
    'Besivance Plus', 'Flagyl IV Plus', 'Declomycine', 'Loracarbef XR', 'Bicillin LAX', 
    'Dificid XR', 'Cefzil Plus', 'Monurol XR', 'Ofloxacine', 'HydroNeoPoly', 
    'Spectracef XR', 'Cefditoren Plus', 'Acticlate XR', 'Supraxin', 'Lincomycine', 
    'LoteTobra', 'Sulfatrim Plus', 'Polytrim XR', 'Goldenseal Plus', 'Ceftibuten XR', 
    'Mandelamine Plus', 'Floxin XR', 'TriBiotic', 'Cefotaxin', 'BaciPoly', 
    'Triple ABX', 'Chloromycetin', 'Altabax Plus', 'Dicloxacillin XR', 'Demeclocyclin', 
    'Vancocin XR', 'MetroCreme', 'Meropenem XR', 'Tobrexin', 'Cubicin XR', 
    'Moxezin', 'Clindamax XR', 'Norflox', 'Lincocin XR', 'Azulfidine Plus', 
    'Unasyn XR', 'AmpiSulb', 'Lorabid XR', 'PolyBiotic', 'Vancocin HX', 
    'MethenaPhos', 'Bacitracin XR', 'GramiNeoPoly', 'Noroxin XR', 'Daptomycin XR'
]

# Extra drug-name strings for the 'Antiseptics' class; later in the notebook this list
# is passed to drug_name_filter together with label 3.
# Each row holds ten variants that share a common stem.
# The literal contains repeated entries ('Miconazolum' three times, 'Phenoleum' twice);
# clean() removes them at runtime before the names are appended to the training lists.
# NOTE(review): the entries look like hand-made variations of real product names — confirm provenance.
augmented_Antiseptics = [
    'Tioconazol', 'Ticonazole', 'Tioconazone', 'Tiozol', 'Ticonaz', 'Tiozole', 'Ticonzol', 'Tioconzole', 'Ticonazolum', 'Tiozolum',
    'Adapalen / benzoyl perox', 'Adapalene benzoyl perox', 'Adapalene-benzoyl peroxide', 'Adapalene BP', 'Ada-BPO', 'Adapalene/BPO', 'Adapalenoxide', 'Adapalox', 'Adapaleneperoxide', 'Adapalox BP',
    'Miconazol', 'Myconazole', 'Miconazone', 'Micozole', 'Miconazolum', 'Mycozole', 'Miconazolum', 'Micozol', 'Myconazol', 'Miconazolum',
    'Phenole', 'Phenolum', 'Fenol', 'Phenoxide', 'Phenylol', 'Phenolic', 'Phenolate', 'Phenoleum', 'Phenolicum', 'Phenoleum',
    'Coppere', 'Cuprum', 'Copperum', 'Cuprate', 'Copperol', 'Cuprol', 'Copperic', 'Cuprumox', 'Copperide', 'Cuprox',
    'Benzoyl peroxide clindamycin', 'BPO-Clindamycin', 'Benzoyl-clindamycin', 'Clindamycin-BPO', 'Benzoclin', 'Clinda-BPO', 'Benzoylperoxide-clinda', 'BPO-Clinda', 'Benzoclynd', 'Clindox',
    'Ovace Plus+', 'Ovace Pro', 'Ovace Ultra', 'Ovace Max', 'Ovace PlusX', 'Ovace Extra', 'Ovace Advanced', 'Ovace Plus Pro', 'Ovace Supreme', 'Ovace Plus Ultra',
    'Aluminium chloride hexahydrate', 'Aluminum chlor hexahyd', 'Alum chlor hex', 'AlCl3-6H2O', 'Aluminum hexachlor', 'Alumichlor', 'Hexalumin', 'Alumichlor hex', 'Aluminumchlorhydrate', 'Hexahydralum',
    'Selenium sulphide', 'Selen sulfide', 'Selenium-S', 'Selenox', 'Sulfosel', 'Selenisulf', 'Sulfenium', 'Selenex', 'Sulfoselen', 'Selenosulf',
    'Oxistat XR', 'Oxistatin', 'Oxistat Plus', 'Oxistat Pro', 'Oxistat Ultra', 'Oxistat Max', 'Oxistat-X', 'Oxistat PlusX', 'Oxistat Advanced', 'Oxistat Supreme',
    'Clarifoam EF Plus', 'Clarifoam XF', 'Clarifoam Pro', 'Clarifoam Ultra', 'Clarifoam Max', 'Clarifoam Advanced', 'Clarifoam PlusX', 'Clarifoam Supreme', 'Clarifoam XT', 'Clarifoam EFX',
    'Azelaic', 'Azelac', 'Azelexin', 'Azelaicum', 'Azelox', 'Azelaicin', 'Azelaine', 'Azelaic acidum', 'Azelac acid', 'Azelox acid',
    'Benzoylperoxide', 'Benzox', 'Benzoper', 'Benzoylox', 'Benzoperox', 'Benzoxol', 'Benzoperoxide', 'Benzoxal', 'Benzoyl perox', 'Benzoxum',
    'Benzoic-salicylic acid', 'Benzo-salicylate', 'Benzo-sal', 'Benzoate-salicylate', 'Benzosal', 'Benzo-sal acid', 'Benzosaly', 'Benzosalic', 'Benzosalix', 'Benzosalum',
    'Econazol', 'Econazone', 'Econaz', 'Ecozole', 'Econazolum', 'Econazolin', 'Econazal', 'Econazoline', 'Econazolium', 'Econazide',
    'Clotrimazol', 'Clotrimazone', 'Clotrizol', 'Clotrimazolum', 'Clotrizole', 'Clotrimaz', 'Clotrizolum', 'Clotrimazolin', 'Clotrizal', 'Clotrimazide',
    'Coaltar', 'Coal-tar', 'Tarcoal', 'Carbonis tar', 'Coalum', 'Coaltarum', 'Coaltarol', 'Coaltar extract', 'Coaltaricum', 'Coaltaride',
    'Salicylic-sulfur', 'Sal-sulfur', 'Salisulf', 'Salic sulfur', 'Salicyl-sulf', 'Sulfur-sal', 'Salisulfur', 'Sal-sulf', 'Salicylsulf', 'Sulfosal',
    'Methenamin', 'Methenamineum', 'Methamine', 'Methenam', 'Methoxamine', 'Methenamum', 'Methoxamin', 'Methenamal', 'Methenamium', 'Methoxam',
    'PanOxyl Plus', 'PanOxyl Pro', 'PanOxyl Ultra', 'PanOxyl Max', 'PanOxyl XR', 'PanOxyl Advanced', 'PanOxyl PlusX', 'PanOxyl Supreme', 'PanOxyl XT', 'PanOxyl FX',
    'Glycerine', 'Glycerinum', 'Glycerol', 'Glycerolum', 'Glyceral', 'Glyceride', 'Glycerox', 'Glyceralum', 'Glycerinium', 'Glyceroxol',
    'Benzoyl peroxide hydrocortisone', 'BPO-HC', 'Benzoyl-hc', 'Hydrobenz', 'Benzohydro', 'BPO-cort', 'Benzocort', 'Hydroxybenz', 'Benzohydrox', 'BPO-hydro',
    'Xerac AC Plus', 'Xerac Ultra', 'Xerac Pro', 'Xerac Max', 'Xerac PlusX', 'Xerac Advanced', 'Xerac Supreme', 'Xerac XT', 'Xerac FX', 'Xerac ACX',
    'Povidone-iodine', 'Povidine iodine', 'Povidine-iodine', 'Povidonum iodine', 'Povidine', 'Povidineum', 'Povidoniodine', 'Povidine iod', 'Povidon-iod', 'Poviodine',
    'Allantoin-camphor-phenol', 'Allantoin-camphor-phen', 'Allantoin-camphor', 'Allantoin-phenol', 'Allantoin-cam-phen', 'Allantoin-camph', 'Allantoin-phen-camph', 'Allantoin-camphol', 'Allantoin-camphorol', 'Allantoin-phenox',
    'Sodium hypochlor', 'Na-hypochlorite', 'Sodium oxychloride', 'Sodium chlorox', 'Hypochlor-Na', 'Sodium chloroxide', 'NaOCl', 'Sodium chloroxite', 'Hypochlorite-Na', 'Sodium oxychlor',
    'Biotene Mouth Rinse', 'Biotene Oral Rinse', 'Biotene Mouthwash Plus', 'Biotene Oral Wash', 'Biotene Mouthwash Pro', 'Biotene Mouth Rinse Pro', 'Biotene Oral Rinse Plus', 'Biotene Mouthwash Ultra', 'Biotene Oral Wash Pro', 'Biotene Mouth Rinse Ultra',
    'Spectazol', 'Spectazolin', 'Spectazone', 'Spectazolum', 'Spectaz', 'Spectazoleum', 'Spectazolinum', 'Spectazolide', 'Spectazoline', 'Spectazolium',
    'Undecylenate', 'Undecylenic', 'Undecylenate acid', 'Undecylenium', 'Undecylen', 'Undecylenox', 'Undecylenateum', 'Undecylenol', 'Undecylenide', 'Undecylenoxol',
    'Rozex Plus', 'Rozex Pro', 'Rozex Ultra', 'Rozex Max', 'Rozex XR', 'Rozex Advanced', 'Rozex PlusX', 'Rozex Supreme', 'Rozex XT', 'Rozex FX',
    'Azelex Plus', 'Azelex Pro', 'Azelex Ultra', 'Azelex Max', 'Azelex XR', 'Azelex Advanced', 'Azelex PlusX', 'Azelex Supreme', 'Azelex XT', 'Azelex FX',
    'Calmoseptin', 'Calmoseptineum', 'Calmosept', 'Calmoseptol', 'Calmoseptinum', 'Calmoseptoxide', 'Calmoseptal', 'Calmoseptinide', 'Calmoseptolium', 'Calmoseptox',
    'Aloe', 'Aloevera', 'Aloe extract', 'Aloe vera gel', 'Aloe barbadensis', 'Aloe leaf', 'Aloe vera extract', 'Aloe plant', 'Aloe vera leaf', 'Aloe barb',
    'Peridex Plus', 'Peridex Pro', 'Peridex Ultra', 'Peridex Max', 'Peridex XR', 'Peridex Advanced', 'Peridex PlusX', 'Peridex Supreme', 'Peridex XT', 'Peridex FX',
    'Silverum', 'Argentum', 'Silveride', 'Silverol', 'Silverate', 'Argentate', 'Silveroxide', 'Argentol', 'Silverex', 'Argentumox',
    'Zinc ox', 'Zincoxide', 'Zincum ox', 'Zincox', 'Zinc oxid', 'Zincum oxide', 'Zincoxal', 'Zinc oxal', 'Zincoxideum', 'Zinc oxum',
    'Ciclopiroxum', 'Ciclopiroxol', 'Ciclopiroxide', 'Ciclopiroxal', 'Ciclopiroxate', 'Ciclopiroxolium', 'Ciclopiroxin', 'Ciclopiroxolide', 'Ciclopiroxanum', 'Ciclopiroxolum'
]


# Extra drug-name strings for the 'Antimalarial' class; later in the notebook this list
# is passed to drug_name_filter together with label 4.
# The literal contains repeated entries ('Hydroxychloroquin' three times; 'Malaronex',
# 'Plaquenilix', 'Lariamide' and 'Pyrimethamide' twice each); clean() removes them at runtime.
# NOTE(review): 'meprazolir' is the only entry starting with a lower-case letter — confirm it is intended.
# NOTE(review): the entries look like stem + suffix variations of real drug names — confirm provenance.
augmented_Antimalarial = [
    'Hydroxychloroquin', 'Hydrochloroquine', 'Hydroxyquine', 'Hydroxychlorin', 
    'Chlorohydroquine', 'Hydroquine', 'Hydroxychloroquin', 'Hydroxychloroquina',
    'Hydroxychloroquinum', 'Hydroxychloroquin', 'Hydroxychloriquine', 'Hydroxychloroquinon',
    'Hydroxycloroquine', 'Hydoxychloroquine', 'Hydroxychlorquine', 'Hydroxychloroquinil',
    'Hydroxychloroquis', 'Hydroxychloroquinix', 'Hydroxychloroquinide', 'Hydroxychloroquinone',
    'Malaron', 'Malarona', 'Malaronia', 'Malaronex', 'Malaride', 'Malarox', 
    'Malaroneplus', 'Malarzone', 'Malaquin', 'Malaronex', 'Malaroxine', 'Malarol',
    'Malaronexel', 'Malaquinone', 'Malaroprim', 'Malarofan', 'Malarophene', 'Malarizine',
    'Coartam', 'Coartum', 'Coartema', 'Coartemis', 'Coartemix', 'Coartemol', 
    'Coartemide', 'Coartemar', 'Coartemex', 'Coartemal', 'Coartemone', 'Coartemox',
    'Coartemine', 'Coartemil', 'Coartemolix', 'Coartemazine', 'Coartemazole', 'Coartemivir',
    'Quinina', 'Quinone', 'Quinidex', 'Quinidine', 'Quinil', 'Quinor', 
    'Quinazol', 'Quinax', 'Quinazolone', 'Quinazolide', 'Quinazolix', 'Quinazoline',
    'Quinazolamine', 'Quinazolamide', 'Quinazolium', 'Quinazolir', 'Quinazolivir', 'Quinazolifan',
    'Plaquenilix', 'Plaquenol', 'Plaquenar', 'Plaquenex', 'Plaquenide', 'Plaquenivir',
    'Plaquenazine', 'Plaquenazol', 'Plaquenam', 'Plaquenor', 'Plaquenox', 'Plaquenilone',
    'Plaquenilor', 'Plaquenilorix', 'Plaquenilix', 'Plaquenilide', 'Plaquenilivir', 'Plaquenilazole',
    'Mefloquin', 'Mefloquina', 'Mefloquinex', 'Mefloquinal', 'Mefloquinide', 'Mefloquinone',
    'Mefloquinar', 'Mefloquinix', 'Mefloquinol', 'Mefloquinazol', 'Mefloquinazine', 'Mefloquinivir',
    'Mefloquinamide', 'Mefloquinazolix', 'Mefloquinazolide', 'Mefloquinazolone', 'Mefloquinazolium', 'Mefloquinazolir',
    'Atovaquon', 'Atovaquona', 'Atovaquonex', 'Atovaquin', 'Atovaquinal', 'Atovaquinide',
    'Atovaquinone', 'Atovaquinar', 'Atovaquinix', 'Atovaquinol', 'Atovaquinazol', 'Atovaquinazine',
    'Atovaquinivir', 'Atovaquinamide', 'Atovaquinazolix', 'Atovaquinazolide', 'Atovaquinazolone', 'Atovaquinazolium',
    'Proguanila', 'Proguanilix', 'Proguanilide', 'Proguanilone', 'Proguanilium', 'Proguanilir',
    'Proguanivir', 'Proguanazole', 'Proguanazine', 'Proguanamide', 'Proguanazolix', 'Proguanazolide',
    'Proguanazolone', 'Proguanazolium', 'Proguanazolir', 'Proguanazolivir', 'Proguanazolamide', 'Proguanazolazine',
    'Daraprima', 'Daraprimix', 'Daraprimide', 'Daraprimone', 'Daraprimium', 'Daraprimir',
    'Daraprivir', 'Daraprazole', 'Daraprazine', 'Darapramide', 'Daraprazolix', 'Daraprazolide',
    'Daraprazolone', 'Daraprazolium', 'Daraprazolir', 'Daraprazolivir', 'Daraprazolamide', 'Daraprazolazine',
    'Meprona', 'Mepronix', 'Mepronide', 'Mepronone', 'Mepronium', 'Mepronir',
    'Meprovir', 'Meprazole', 'Meprazine', 'Mepramide', 'Meprazolix', 'Meprazolide',
    'Meprazolone', 'Meprazolium', 'meprazolir', 'Meprazolivir', 'Meprazolamide', 'Meprazolazine',
    'Lariama', 'Lariamix', 'Lariamide', 'Lariamone', 'Lariamium', 'Lariamir',
    'Lariavir', 'Lariazole', 'Lariazine', 'Lariamide', 'Lariazolix', 'Lariazolide',
    'Lariazolone', 'Lariazolium', 'Lariazolir', 'Lariazolivir', 'Lariazolamide', 'Lariazolazine',
    'Xartemix', 'Xartemide', 'Xartemone', 'Xartemium', 'Xartemir', 'Xartemivir',
    'Xartemazole', 'Xartemazine', 'Xartemamide', 'Xartemazolix', 'Xartemazolide', 'Xartemazolone',
    'Xartemazolium', 'Xartemazolir', 'Xartemazolivir', 'Xartemazolamide', 'Xartemazolazine', 'Xartemox',
    'Pyrimethamin', 'Pyrimethamix', 'Pyrimethamide', 'Pyrimethamone', 'Pyrimethamium', 'Pyrimethamir',
    'Pyrimethavir', 'Pyrimethazole', 'Pyrimethazine', 'Pyrimethamide', 'Pyrimethazolix', 'Pyrimethazolide',
    'Pyrimethazolone', 'Pyrimethazolium', 'Pyrimethazolir', 'Pyrimethazolivir', 'Pyrimethazolamide', 'Pyrimethazolazine',
    'Sulfadoxin', 'Sulfadoxina', 'Sulfadoxinex', 'Sulfadoxinal', 'Sulfadoxinide', 'Sulfadoxinone',
    'Sulfadoxinar', 'Sulfadoxinix', 'Sulfadoxinol', 'Sulfadoxinazol', 'Sulfadoxinazine', 'Sulfadoxinivir',
    'Sulfadoxinamide', 'Sulfadoxinazolix', 'Sulfadoxinazolide', 'Sulfadoxinazolone', 'Sulfadoxinazolium', 'Sulfadoxinazolir',
    'Qualaquina', 'Qualaquinix', 'Qualaquinide', 'Qualaquinone', 'Qualaquinium', 'Qualaquinir',
    'Qualaquivir', 'Qualaquinazole', 'Qualaquinazine', 'Qualaquinamide', 'Qualaquinazolix', 'Qualaquinazolide',
    'Qualaquinazolone', 'Qualaquinazolium', 'Qualaquinazolir', 'Qualaquinazolivir', 'Qualaquinazolamide', 'Qualaquinazolazine',
    'Artemetherin', 'Artemethix', 'Artemethide', 'Artemethone', 'Artemethium', 'Artemethir',
    'Artemethivir', 'Artemethazole', 'Artemethazine', 'Artemethamide', 'Artemethazolix', 'Artemethazolide',
    'Artemethazolone', 'Artemethazolium', 'Artemethazolir', 'Artemethazolivir', 'Artemethazolamide', 'Artemethazolazine',
    'Lumefantrin', 'Lumefantrix', 'Lumefantride', 'Lumefantrone', 'Lumefantrium', 'Lumefantrir',
    'Lumefantrivir', 'Lumefantrazole', 'Lumefantrazine', 'Lumefantramide', 'Lumefantrazolix', 'Lumefantrazolide',
    'Lumefantrazolone', 'Lumefantrazolium', 'Lumefantrazolir', 'Lumefantrazolivir', 'Lumefantrazolamide', 'Lumefantrazolazine',
    'Fansidara', 'Fansidarix', 'Fansidaride', 'Fansidarone', 'Fansidarium', 'Fansidarir',
    'Fansidavir', 'Fansidazole', 'Fansidazine', 'Fansidamide', 'Fansidazolix', 'Fansidazolide',
    'Fansidazolone', 'Fansidazolium', 'Fansidazolir', 'Fansidazolivir', 'Fansidazolamide', 'Fansidazolazine',
    'Antimalarone', 'Antimalarox', 'Antimalaride', 'Antimalarivir', 'Antimalarazole', 'Antimalarazine',
    'Antimalaramide', 'Antimalarazolix', 'Antimalarazolide', 'Antimalarazolone', 'Antimalarazolium', 'Antimalarazolir',
    'Antimalarazolivir', 'Antimalarazolamide', 'Antimalarazolazine', 'Antimalarquin', 'Antimalarquinone', 'Antimalarquinide',
    'Malarix', 'Malarivir', 'Malarazole', 'Malarazine', 'Malaramide', 'Malarazolix',
    'Malarazolide', 'Malarazolone', 'Malarazolium', 'Malarazolir', 'Malarazolivir', 'Malarazolamide',
    'Malarazolazine', 'Malarquin', 'Malarquinone', 'Malarquinide', 'Malarquinium', 'Malarquinir',
    'Chloroquinix', 'Chloroquinivir', 'Chloroquinazole', 'Chloroquinazine', 'Chloroquinamide', 'Chloroquinazolix',
    'Chloroquinazolide', 'Chloroquinazolone', 'Chloroquinazolium', 'Chloroquinazolir', 'Chloroquinazolivir', 'Chloroquinazolamide',
    'Chloroquinazolazine', 'Chloroquinal', 'Chloroquinone', 'Chloroquinide', 'Chloroquinium', 'Chloroquinir'
]

# Extra drug-name strings for the 'Antipiretics' class (spelling matches the Class_to_Num key);
# later in the notebook this list is passed to drug_name_filter together with label 5.
# The literal contains repeated entries ('Acephen Plus Pain Relief' and
# 'Acephen Plus Fever Reducer' twice each); clean() removes them at runtime.
# NOTE(review): the leading entries (e.g. 'Acetaminophen', 'Aspirin', 'Tylenol') look like real
# product names that may already be in the dataset; drug_name_filter skips any name that is
# already present in the train or test name lists. The rest are 'Acephen' + suffix combinations.
augmented_Antipiretics = [
    'Acetaminophen / caffeine', 'Acetaminophen', 'Aspirin', 'Acetaminophen / diphenhydramine', 
    'Bayer Aspirin', 'Tylenol', 'Vivarin', 'Tylenol 8 Hour', 'Acetaminophen / phenyltoloxamine', 
    'Acetaminofen / phenylephrine', 'Feverall', 'Acetaminophen / aspirin', 
    'Vicks Dayquil Cold & Flu Relief', 'Alka-Seltzer Plus Cold Formula Sparkling Original Effervescent Tablets',
    'Acephen', 'Acephen Plus', 'Acephen Extra', 'Acephen Rapid Release', 'Acephen PM', 
    'Acephen Cold & Flu', 'Acephen Sinus', 'Acephen Headache', 'Acephen Migraine', 
    'Acephen Arthritis', 'Acephen Junior', 'Acephen Infant', 'Acephen Pediatric', 
    'Acephen Liquid', 'Acephen Chewable', 'Acephen Caplets', 'Acephen Coated', 
    'Acephen Time Release', 'Acephen Forte', 'Acephen Maximum Strength', 
    'Acephen Allergy Relief', 'Acephen Nighttime', 'Acephen Daytime', 
    'Acephen Sinus Relief', 'Acephen Multi-Symptom', 'Acephen Extra Strength', 
    'Acephen Rapid Melt', 'Acephen Softgels', 'Acephen Coated Tablets', 
    'Acephen Extended Release', 'Acephen Dual Action', 'Acephen Triple Action', 
    'Acephen Plus Cold', 'Acephen Plus Flu', 'Acephen Plus Sinus', 
    'Acephen Plus Allergy', 'Acephen Plus Headache', 'Acephen Plus Migraine', 
    'Acephen Plus Pain Relief', 'Acephen Plus Fever Reducer', 'Acephen PM Extra', 
    'Acephen PM Maximum', 'Acephen PM Rapid', 'Acephen PM Liquid', 
    'Acephen PM Caplets', 'Acephen PM Softgels', 'Acephen PM Coated', 
    'Acephen PM Time Release', 'Acephen PM Forte', 'Acephen PM Nighttime', 
    'Acephen Cold Relief', 'Acephen Cold Max', 'Acephen Cold Daytime', 
    'Acephen Cold Nighttime', 'Acephen Cold Liquid', 'Acephen Cold Chewable', 
    'Acephen Cold Caplets', 'Acephen Cold Softgels', 'Acephen Cold Coated', 
    'Acephen Cold Extended', 'Acephen Cold Dual', 'Acephen Cold Triple', 
    'Acephen Flu Relief', 'Acephen Flu Max', 'Acephen Flu Daytime', 
    'Acephen Flu Nighttime', 'Acephen Flu Liquid', 'Acephen Flu Chewable', 
    'Acephen Flu Caplets', 'Acephen Flu Softgels', 'Acephen Flu Coated', 
    'Acephen Flu Extended', 'Acephen Flu Dual', 'Acephen Flu Triple', 
    'Acephen Sinus Max', 'Acephen Sinus Daytime', 'Acephen Sinus Nighttime', 
    'Acephen Sinus Liquid', 'Acephen Sinus Chewable', 'Acephen Sinus Caplets', 
    'Acephen Sinus Softgels', 'Acephen Sinus Coated', 'Acephen Sinus Extended', 
    'Acephen Sinus Dual', 'Acephen Sinus Triple', 'Acephen Allergy Max', 
    'Acephen Allergy Daytime', 'Acephen Allergy Nighttime', 'Acephen Allergy Liquid', 
    'Acephen Allergy Chewable', 'Acephen Allergy Caplets', 'Acephen Allergy Softgels', 
    'Acephen Allergy Coated', 'Acephen Allergy Extended', 'Acephen Allergy Dual', 
    'Acephen Allergy Triple', 'Acephen Headache Relief', 'Acephen Headache Max', 
    'Acephen Headache Daytime', 'Acephen Headache Nighttime', 'Acephen Headache Liquid', 
    'Acephen Headache Chewable', 'Acephen Headache Caplets', 'Acephen Headache Softgels', 
    'Acephen Headache Coated', 'Acephen Headache Extended', 'Acephen Headache Dual', 
    'Acephen Headache Triple', 'Acephen Migraine Relief', 'Acephen Migraine Max', 
    'Acephen Migraine Daytime', 'Acephen Migraine Nighttime', 'Acephen Migraine Liquid', 
    'Acephen Migraine Chewable', 'Acephen Migraine Caplets', 'Acephen Migraine Softgels', 
    'Acephen Migraine Coated', 'Acephen Migraine Extended', 'Acephen Migraine Dual', 
    'Acephen Migraine Triple', 'Acephen Arthritis Relief', 'Acephen Arthritis Max', 
    'Acephen Arthritis Daytime', 'Acephen Arthritis Nighttime', 'Acephen Arthritis Liquid', 
    'Acephen Arthritis Chewable', 'Acephen Arthritis Caplets', 'Acephen Arthritis Softgels', 
    'Acephen Arthritis Coated', 'Acephen Arthritis Extended', 'Acephen Arthritis Dual', 
    'Acephen Arthritis Triple', 'Acephen Junior Relief', 'Acephen Junior Max', 
    'Acephen Junior Daytime', 'Acephen Junior Nighttime', 'Acephen Junior Liquid', 
    'Acephen Junior Chewable', 'Acephen Junior Caplets', 'Acephen Junior Softgels', 
    'Acephen Junior Coated', 'Acephen Junior Extended', 'Acephen Junior Dual', 
    'Acephen Junior Triple', 'Acephen Infant Relief', 'Acephen Infant Max', 
    'Acephen Infant Daytime', 'Acephen Infant Nighttime', 'Acephen Infant Liquid', 
    'Acephen Infant Drops', 'Acephen Infant Suspension', 'Acephen Infant Syrup', 
    'Acephen Pediatric Relief', 'Acephen Pediatric Max', 'Acephen Pediatric Daytime', 
    'Acephen Pediatric Nighttime', 'Acephen Pediatric Liquid', 'Acephen Pediatric Drops', 
    'Acephen Pediatric Suspension', 'Acephen Pediatric Syrup', 'Acephen Liquid Relief', 
    'Acephen Liquid Max', 'Acephen Liquid Daytime', 'Acephen Liquid Nighttime', 
    'Acephen Liquid Fast Acting', 'Acephen Liquid Rapid Release', 'Acephen Chewable Relief', 
    'Acephen Chewable Max', 'Acephen Chewable Daytime', 'Acephen Chewable Nighttime', 
    'Acephen Chewable Fast Acting', 'Acephen Chewable Rapid Release', 'Acephen Caplet Relief', 
    'Acephen Caplet Max', 'Acephen Caplet Daytime', 'Acephen Caplet Nighttime', 
    'Acephen Caplet Fast Acting', 'Acephen Caplet Rapid Release', 'Acephen Coated Relief', 
    'Acephen Coated Max', 'Acephen Coated Daytime', 'Acephen Coated Nighttime', 
    'Acephen Coated Fast Acting', 'Acephen Coated Rapid Release', 'Acephen Time Release Relief', 
    'Acephen Time Release Max', 'Acephen Time Release Daytime', 'Acephen Time Release Nighttime', 
    'Acephen Time Release Fast Acting', 'Acephen Time Release Rapid Release', 'Acephen Forte Relief', 
    'Acephen Forte Max', 'Acephen Forte Daytime', 'Acephen Forte Nighttime', 
    'Acephen Forte Fast Acting', 'Acephen Forte Rapid Release', 'Acephen Maximum Strength Relief', 
    'Acephen Maximum Strength Max', 'Acephen Maximum Strength Daytime', 
    'Acephen Maximum Strength Nighttime', 'Acephen Maximum Strength Fast Acting', 
    'Acephen Maximum Strength Rapid Release', 'Acephen Allergy Relief Max', 
    'Acephen Allergy Relief Daytime', 'Acephen Allergy Relief Nighttime', 
    'Acephen Allergy Relief Fast Acting', 'Acephen Allergy Relief Rapid Release', 
    'Acephen Nighttime Relief', 'Acephen Nighttime Max', 'Acephen Nighttime Fast Acting', 
    'Acephen Nighttime Rapid Release', 'Acephen Daytime Relief', 'Acephen Daytime Max', 
    'Acephen Daytime Fast Acting', 'Acephen Daytime Rapid Release', 'Acephen Sinus Relief Max', 
    'Acephen Sinus Relief Daytime', 'Acephen Sinus Relief Nighttime', 
    'Acephen Sinus Relief Fast Acting', 'Acephen Sinus Relief Rapid Release', 
    'Acephen Multi-Symptom Relief', 'Acephen Multi-Symptom Max', 
    'Acephen Multi-Symptom Daytime', 'Acephen Multi-Symptom Nighttime', 
    'Acephen Multi-Symptom Fast Acting', 'Acephen Multi-Symptom Rapid Release', 
    'Acephen Extra Strength Relief', 'Acephen Extra Strength Max', 
    'Acephen Extra Strength Daytime', 'Acephen Extra Strength Nighttime', 
    'Acephen Extra Strength Fast Acting', 'Acephen Extra Strength Rapid Release', 
    'Acephen Rapid Melt Relief', 'Acephen Rapid Melt Max', 'Acephen Rapid Melt Daytime', 
    'Acephen Rapid Melt Nighttime', 'Acephen Rapid Melt Fast Acting', 
    'Acephen Rapid Melt Rapid Release', 'Acephen Softgels Relief', 'Acephen Softgels Max', 
    'Acephen Softgels Daytime', 'Acephen Softgels Nighttime', 'Acephen Softgels Fast Acting', 
    'Acephen Softgels Rapid Release', 'Acephen Coated Tablets Relief', 
    'Acephen Coated Tablets Max', 'Acephen Coated Tablets Daytime', 
    'Acephen Coated Tablets Nighttime', 'Acephen Coated Tablets Fast Acting', 
    'Acephen Coated Tablets Rapid Release', 'Acephen Extended Release Relief', 
    'Acephen Extended Release Max', 'Acephen Extended Release Daytime', 
    'Acephen Extended Release Nighttime', 'Acephen Extended Release Fast Acting', 
    'Acephen Extended Release Rapid Release', 'Acephen Dual Action Relief', 
    'Acephen Dual Action Max', 'Acephen Dual Action Daytime', 'Acephen Dual Action Nighttime', 
    'Acephen Dual Action Fast Acting', 'Acephen Dual Action Rapid Release', 
    'Acephen Triple Action Relief', 'Acephen Triple Action Max', 
    'Acephen Triple Action Daytime', 'Acephen Triple Action Nighttime', 
    'Acephen Triple Action Fast Acting', 'Acephen Triple Action Rapid Release', 
    'Acephen Plus Cold Relief', 'Acephen Plus Cold Max', 'Acephen Plus Cold Daytime', 
    'Acephen Plus Cold Nighttime', 'Acephen Plus Cold Fast Acting', 
    'Acephen Plus Cold Rapid Release', 'Acephen Plus Flu Relief', 'Acephen Plus Flu Max', 
    'Acephen Plus Flu Daytime', 'Acephen Plus Flu Nighttime', 'Acephen Plus Flu Fast Acting', 
    'Acephen Plus Flu Rapid Release', 'Acephen Plus Sinus Relief', 'Acephen Plus Sinus Max', 
    'Acephen Plus Sinus Daytime', 'Acephen Plus Sinus Nighttime', 
    'Acephen Plus Sinus Fast Acting', 'Acephen Plus Sinus Rapid Release', 
    'Acephen Plus Allergy Relief', 'Acephen Plus Allergy Max', 
    'Acephen Plus Allergy Daytime', 'Acephen Plus Allergy Nighttime', 
    'Acephen Plus Allergy Fast Acting', 'Acephen Plus Allergy Rapid Release', 
    'Acephen Plus Headache Relief', 'Acephen Plus Headache Max', 
    'Acephen Plus Headache Daytime', 'Acephen Plus Headache Nighttime', 
    'Acephen Plus Headache Fast Acting', 'Acephen Plus Headache Rapid Release', 
    'Acephen Plus Migraine Relief', 'Acephen Plus Migraine Max', 
    'Acephen Plus Migraine Daytime', 'Acephen Plus Migraine Nighttime', 
    'Acephen Plus Migraine Fast Acting', 'Acephen Plus Migraine Rapid Release', 
    'Acephen Plus Pain Relief', 'Acephen Plus Pain Max', 'Acephen Plus Pain Daytime', 
    'Acephen Plus Pain Nighttime', 'Acephen Plus Pain Fast Acting', 
    'Acephen Plus Pain Rapid Release', 'Acephen Plus Fever Reducer', 
    'Acephen Plus Fever Max', 'Acephen Plus Fever Daytime', 'Acephen Plus Fever Nighttime', 
    'Acephen Plus Fever Fast Acting', 'Acephen Plus Fever Rapid Release'
]

In [6]:
def _unique_names_with_labels(names, classes, label_of):
    """Keep the first occurrence of every name and map its class to a numeric label."""
    unique_names = []
    labels = []
    seen = set()  # O(1) membership test instead of scanning the growing result list
    for drug_name, drug_class in zip(names, classes):
        if drug_name in seen:
            continue
        seen.add(drug_name)
        unique_names.append(drug_name)
        labels.append(label_of[drug_class])
    return unique_names, labels


def data_preprocess(train_name, train_condition, train_class, test_name, test_condition, test_class, class_to_num=None):
    """De-duplicate drug names per split and convert class names to integer labels.

    Only the first row seen for a given drug name is kept, so if a name occurs with
    several classes the class of its first occurrence wins. Each split is handled
    independently: a name that appears in both train and test stays in both.

    Args:
        train_name: Drug names of the training split.
        train_condition: Accepted for interface compatibility; not used.
        train_class: Class names aligned with ``train_name``.
        test_name: Drug names of the test split.
        test_condition: Accepted for interface compatibility; not used.
        test_class: Class names aligned with ``test_name``.
        class_to_num: Optional mapping from class name to integer label.
            Defaults to the notebook-level ``Class_to_Num``.

    Returns:
        (train_combine, train_class_num, test_combine, test_class_num): the unique
        names of each split in first-seen order and their integer labels.

    Raises:
        KeyError: if a class name is missing from the label mapping.
    """
    label_of = Class_to_Num if class_to_num is None else class_to_num

    train_combine, train_class_num = _unique_names_with_labels(train_name, train_class, label_of)
    test_combine, test_class_num = _unique_names_with_labels(test_name, test_class, label_of)

    return train_combine, train_class_num, test_combine, test_class_num


# Build the de-duplicated name lists and integer labels for both splits.
# train_combine and train_class_num are extended in place further down by drug_name_filter.
train_combine, train_class_num, test_combine, test_class_num = data_preprocess(train_name, train_condition, train_class, test_name, test_condition, test_class)


def check_unique(augmented_classes):
    """Report repeated entries of every list and tell which lists are duplicate-free.

    For each list, every element occurring more than once is printed together with
    its occurrence count, in first-seen order.

    Returns:
        A list of booleans, one per input list: True when that list has no repeats.
    """
    for names in augmented_classes:
        # Counter keeps first-seen order, so repeats are reported in list order.
        for name, times in Counter(names).items():
            if times > 1:
                print(f"元素 {name} 重复次数: {times}")
    return [len(set(names)) == len(names) for names in augmented_classes]


def clean(augmented_classes):
    """Remove repeated items from each list while keeping first-seen order.

    Args:
        augmented_classes: a sequence of lists of hashable items.

    Returns:
        A new list of lists; the input lists are not modified.
    """
    # dict keys keep insertion order, so only later repeats are dropped.
    return [list(dict.fromkeys(names)) for names in augmented_classes]


def drug_name_filter(augmented_classes, classes, train_names=None, train_labels=None, test_names=None):
    """Append augmented drug names to the training data, skipping known names.

    A name counts as known when it is already in the training names or the
    test names, including names added earlier during this same call. Known
    names are printed and counted; new names are appended to the training
    names with the label of their group. The total number of skipped names is
    printed at the end.

    Args:
        augmented_classes: sequence of lists of candidate names, one per class.
        classes: integer labels aligned with ``augmented_classes``.
        train_names: list extended in place with new names. Defaults to the
            notebook-level ``train_combine``.
        train_labels: list extended in place with the matching labels.
            Defaults to the notebook-level ``train_class_num``.
        test_names: names that must never be added to the training data.
            Defaults to the notebook-level ``test_combine``.

    Returns:
        None. The training lists are mutated in place.
    """
    if train_names is None:
        train_names = train_combine
    if train_labels is None:
        train_labels = train_class_num
    if test_names is None:
        test_names = test_combine

    # One set for O(1) lookups instead of scanning two growing lists per name.
    known = set(train_names)
    known.update(test_names)

    repeat = 0
    for augmented_class, class_ in zip(augmented_classes, classes):
        for augmented in augmented_class:
            if augmented in known:
                print(augmented)
                repeat += 1
                continue
            known.add(augmented)
            train_names.append(augmented)
            train_labels.append(class_)
    print(f"重复：{repeat}")
    


# Report duplicates inside each augmented list before cleaning.
print(check_unique([augmented_Mood_Stabilizers, augmented_Antibiotics, augmented_Antiseptics, augmented_Antimalarial, augmented_Antipiretics]))
# Rebind every augmented list to its order-preserving, de-duplicated version.
augmented_Mood_Stabilizers, augmented_Antibiotics, augmented_Antiseptics, augmented_Antimalarial, augmented_Antipiretics = clean([augmented_Mood_Stabilizers, augmented_Antibiotics, augmented_Antiseptics, augmented_Antimalarial, augmented_Antipiretics])
# Sanity check: after cleaning every entry should be True.
print(check_unique([augmented_Mood_Stabilizers, augmented_Antibiotics, augmented_Antiseptics, augmented_Antimalarial, augmented_Antipiretics]))


# Merge the augmented names into the training data. The labels 1..5 are the
# Class_to_Num ids of Mood Stabilizers, Antibiotics, Antiseptics, Antimalarial
# and Antipiretics, in the same order as the lists.
drug_name_filter([augmented_Mood_Stabilizers, augmented_Antibiotics, augmented_Antiseptics, augmented_Antimalarial, augmented_Antipiretics], [1, 2, 3, 4, 5])

# count_unique_strings(train_class_num)
# count_unique_strings(test_class_num)

元素 Sertralin 重复次数: 2
元素 Imipramin 重复次数: 2
元素 Lamictin 重复次数: 2
元素 Tegretin 重复次数: 2
元素 Desyrelin 重复次数: 2
元素 Epitolin 重复次数: 2
元素 Pamelorin 重复次数: 2
元素 Depakolin 重复次数: 2
元素 Miconazolum 重复次数: 3
元素 Phenoleum 重复次数: 2
元素 Hydroxychloroquin 重复次数: 3
元素 Malaronex 重复次数: 2
元素 Plaquenilix 重复次数: 2
元素 Lariamide 重复次数: 2
元素 Pyrimethamide 重复次数: 2
元素 Acephen Plus Pain Relief 重复次数: 2
元素 Acephen Plus Fever Reducer 重复次数: 2
[False, True, False, False, False]
[True, True, True, True, True]
Acetaminophen / caffeine
Acetaminophen
Aspirin
Acetaminophen / diphenhydramine
Bayer Aspirin
Tylenol
Vivarin
Tylenol 8 Hour
Acetaminophen / phenyltoloxamine
Feverall
Acetaminophen / aspirin
Vicks Dayquil Cold & Flu Relief
Alka-Seltzer Plus Cold Formula Sparkling Original Effervescent Tablets
重复：13


In [7]:
# Print how many samples each integer label has, first for the training labels
# and then for the test labels.
count_unique_strings(train_class_num)
count_unique_strings(test_class_num)

总共有 6 个唯一的字符串。
每个字符串的出现次数如下:
'0': 369 次
'1': 304 次
'2': 386 次
'3': 438 次
'4': 391 次
'5': 329 次
总共有 6 个唯一的字符串。
每个字符串的出现次数如下:
'1': 62 次
'2': 141 次
'3': 42 次
'0': 264 次
'4': 9 次
'5': 11 次


In [8]:
def check_length(train_combine, test_combine):
    """Print the sorted character lengths of both splits and the overall maximum.

    Args:
        train_combine: collection of training strings.
        test_combine: collection of test strings.

    Returns:
        None. Three lines are printed: the sorted training lengths, the sorted
        test lengths, and the largest length across both splits (0 when both
        splits are empty).
    """
    train_lengths = sorted(len(s) for s in train_combine)
    test_lengths = sorted(len(s) for s in test_combine)
    print(f"train: {train_lengths}")
    print(f"test: {test_lengths}")
    # The lists are sorted, so the last element of each is that split's maximum;
    # slicing keeps this safe for empty splits and `default` covers both empty.
    longest = max(train_lengths[-1:] + test_lengths[-1:], default=0)
    print(f"max: {longest}")


# Show the number of names in each split, then the distribution of their
# character lengths.
print(f"train: {len(train_combine)}")
print(f"test: {len(test_combine)}")
check_length(train_combine, test_combine)

train: 2217
test: 529
train: [3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8

In [9]:
# Load the pretrained 'bert-base-uncased' tokenizer used for all drug names.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Id of the '[UNK]' token in this vocabulary, printed for reference.
unk_token_id = tokenizer.convert_tokens_to_ids('[UNK]')
print(unk_token_id)

100


In [10]:
def tokenize(tokenizer, train_reviews, test_reviews, max_length=50):
    """Encode every text of both splits with the same tokenizer settings.

    Each text is encoded separately with special tokens added, truncated and
    padded to ``max_length``, and returned as PyTorch tensors together with an
    attention mask.

    Args:
        tokenizer: object exposing a Hugging Face style ``encode_plus`` method.
        train_reviews: iterable of training texts.
        test_reviews: iterable of test texts.
        max_length: length every sequence is truncated/padded to (default 50,
            the value previously hard-coded).

    Returns:
        (train_reviews_token, test_reviews_token): two lists holding one
        encoding per input text, in input order.
    """
    def _encode(text):
        # `padding='max_length'` replaces the deprecated `pad_to_max_length=True`.
        return tokenizer.encode_plus(
            text,
            truncation=True,
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )

    train_reviews_token = [_encode(text) for text in train_reviews]
    test_reviews_token = [_encode(text) for text in test_reviews]

    return train_reviews_token, test_reviews_token


# Encode the de-duplicated drug names of both splits; the resulting lists feed
# check_unknown and the dataset objects below.
train_combines_token, test_combines_token = tokenize(tokenizer, train_combine, test_combine)



In [11]:
def check_unknown(train_combines_token, test_combines_token, unk_token_id=100):
    """Print how many encodings in each split contain the unknown-token id.

    Args:
        train_combines_token: iterable of encodings for the training split;
            each encoding is indexed with ``'input_ids'`` and the result must
            support ``== int`` followed by ``.any()`` (e.g. a torch tensor).
        test_combines_token: same, for the test split.
        unk_token_id: id treated as "unknown". Defaults to 100, the value that
            was previously hard-coded.

    Returns:
        None. Two lines are printed, one count per split.
    """
    def _count_with_unknown(encodings):
        # An encoding counts once, however many unknown tokens it contains.
        return sum(1 for enc in encodings if (enc['input_ids'] == unk_token_id).any())

    train_count = _count_with_unknown(train_combines_token)
    test_count = _count_with_unknown(test_combines_token)
    print(f"train unknown: {train_count}")
    print(f"test unknown: {test_count}")
    
# Report how many encoded names in each split contain the unknown-token id.
check_unknown(train_combines_token, test_combines_token)

train unknown: 0
test unknown: 0


In [12]:
class Review_Rating_Dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with integer class labels.

    Args:
        reviews_token: indexable collection of dict-like encodings whose
            values are tensors; dimension 0 of each tensor is squeezed away
            when an item is fetched.
        class_: indexable collection of labels aligned with ``reviews_token``.
            Its length defines the dataset length.
    """

    def __init__(self, reviews_token, class_):
        self.review = reviews_token
        self.class_ = class_

    def __len__(self):
        return len(self.class_)

    def __getitem__(self, idx):
        """Return the encoding at ``idx`` plus its label under the key "class"."""
        sample = {}
        for key, tensor in self.review[idx].items():
            # Squeeze dim 0 so the DataLoader can stack samples into a batch.
            sample[key] = tensor.squeeze(dim=0)
        sample["class"] = torch.tensor(self.class_[idx])
        return sample


# Wrap the encoded drug names and their integer labels for each split.
train_dataset = Review_Rating_Dataset(train_combines_token, train_class_num)
test_dataset = Review_Rating_Dataset(test_combines_token, test_class_num)

In [None]:
# Batches of 16: training batches are reshuffled on each pass, test order is fixed.
# NOTE(review): no random seed is set in the visible cells, so the shuffle
# order presumably differs between runs - confirm if reproducibility matters.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [14]:
class BertWithMLP(nn.Module):
    """Encoder followed by a small MLP classification head.

    The head is Linear(hidden_size, mlp_hidden_size2) -> ReLU -> Dropout(0.5)
    -> Linear(mlp_hidden_size2, num_classes) and is applied to the hidden
    state at sequence position 0 of the encoder output.

    Args:
        bert: encoder module; called with ``input_ids`` and ``attention_mask``
            keyword arguments and expected to return an object exposing
            ``last_hidden_state``.
        hidden_size: size of the encoder's hidden vectors.
        mlp_hidden_size1: accepted for interface compatibility; not used by
            the current head.
        mlp_hidden_size2: width of the head's hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, bert, hidden_size=768, mlp_hidden_size1=1024, mlp_hidden_size2=256, num_classes=10):
        super().__init__()
        self.bert = bert
        head_layers = [
            nn.Linear(hidden_size, mlp_hidden_size2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(mlp_hidden_size2, num_classes),
        ]
        self.mlp = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Return the classification logits for a batch."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Hidden state at sequence position 0 is the input to the head.
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return self.mlp(first_token_state)

In [15]:
def train_epoch(model, dataloader, optimizer, scheduler, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    For every batch: forward pass, loss, backward pass, gradient-norm clipping
    at 1.0, optimizer step and scheduler step (the scheduler is stepped once
    per batch, not once per epoch).

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``
            and returning logits.
        dataloader: yields dict batches with the keys 'input_ids',
            'attention_mask' and 'class'; must support ``len()`` and expose
            ``.dataset``.
        optimizer: optimizer updating ``model``'s parameters.
        scheduler: learning-rate scheduler.
        criterion: loss called as ``criterion(logits, labels)``.
        device: device the batch tensors are moved to.

    Returns:
        (avg_loss, accuracy): ``avg_loss`` is the mean of the per-batch losses
        (a Python float); ``accuracy`` is a 0-dim double tensor equal to the
        number of correct predictions divided by ``len(dataloader.dataset)``.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    if len(dataloader) == 0:
        # Fail with a clear message instead of a ZeroDivisionError further down.
        raise ValueError("train_epoch received an empty dataloader")

    model.train()
    total_loss = 0.0
    # Tensor accumulator so the returned accuracy is always a tensor.
    correct_predictions = torch.zeros((), dtype=torch.long, device=device)

    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['class'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        preds = torch.argmax(outputs, dim=-1)
        # Reduce once and reuse for both the running total and the progress bar.
        batch_correct = torch.sum(preds == labels)
        batch_loss = loss.item()
        correct_predictions += batch_correct
        total_loss += batch_loss

        # Show the metrics of the current batch on the progress bar.
        progress_bar.set_postfix({
            'loss': batch_loss,
            'acc': batch_correct.item()/len(labels),
        })

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions.double() / len(dataloader.dataset)
    return avg_loss, accuracy


def eval_model(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``
            and returning logits.
        dataloader: yields dict batches with the keys 'input_ids',
            'attention_mask' and 'class'; must support ``len()`` and expose
            ``.dataset``.
        criterion: loss called as ``criterion(logits, labels)``.
        device: device the batch tensors are moved to.

    Returns:
        (avg_loss, accuracy): ``avg_loss`` is the mean of the per-batch losses
        (a Python float); ``accuracy`` is a 0-dim double tensor equal to the
        number of correct predictions divided by ``len(dataloader.dataset)``.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    if len(dataloader) == 0:
        # Fail with a clear message instead of a ZeroDivisionError further down.
        raise ValueError("eval_model received an empty dataloader")

    model.eval()
    total_loss = 0.0
    # Tensor accumulator so the returned accuracy is always a tensor.
    correct_predictions = torch.zeros((), dtype=torch.long, device=device)

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Evaluating", leave=False)
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['class'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = criterion(outputs, labels)

            preds = torch.argmax(outputs, dim=-1)
            # Reduce once and reuse for both the running total and the progress bar.
            batch_correct = torch.sum(preds == labels)
            batch_loss = loss.item()
            correct_predictions += batch_correct
            total_loss += batch_loss

            progress_bar.set_postfix({
                'loss': batch_loss,
                'acc': batch_correct.item()/len(labels),
            })

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions.double() / len(dataloader.dataset)
    return avg_loss, accuracy

# 4. Main training loop
def train_and_evaluate(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    criterion,
    device,
    epochs,
    model_save_path,
    eval_every=1  # run validation every this many epochs
):
    """Train for ``epochs`` epochs, validating periodically and checkpointing.

    After each training epoch the loss and accuracy are recorded and printed.
    On epochs that are a multiple of ``eval_every`` (and only when
    ``val_loader`` is not None) the model is evaluated; whenever the
    validation accuracy strictly exceeds the best seen so far, the model's
    ``state_dict`` is saved to ``model_save_path``.

    Returns:
        dict with the lists 'train_loss', 'train_acc', 'val_loss' and
        'val_acc'. The training lists get one entry per epoch, the validation
        lists one entry per evaluated epoch.
    """
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
    }
    best_val_acc = 0.0

    for epoch in range(1, epochs+1):
        print(f"\nEpoch {epoch}/{epochs}")

        # Training phase.
        epoch_loss, epoch_acc = train_epoch(
            model, train_loader, optimizer, scheduler, criterion, device)
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc.item())
        print(f"Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.4f}")

        # Validation phase, skipped on off-schedule epochs or without a loader.
        validate_now = epoch % eval_every == 0 and val_loader is not None
        if not validate_now:
            continue

        val_loss, val_acc = eval_model(
            model, val_loader, criterion, device)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc.item())
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

        # Keep only the checkpoint with the best validation accuracy so far.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), model_save_path)
            print(f"New best model saved to {model_save_path} with val_acc: {val_acc:.4f}")

    return history

In [16]:
def _build_param_groups(model, bert_lr=2e-5, mlp_lr=1e-4, weight_decay=0.01):
    """Split ``model``'s parameters into four optimizer parameter groups.

    Parameters whose name contains 'mlp' form the head groups and use
    ``mlp_lr``; all others use ``bert_lr``. Within each part, parameters whose
    name contains 'bias' or 'LayerNorm.weight' get no weight decay.

    Returns:
        A list of four dicts in the order: encoder with decay, encoder without
        decay, head with decay, head without decay.
    """
    no_decay = ('bias', 'LayerNorm.weight')
    bert_decay, bert_no_decay, mlp_decay, mlp_no_decay = [], [], [], []

    for name, param in model.named_parameters():
        exempt = any(nd in name for nd in no_decay)
        if 'mlp' in name:  # classification-head parameters
            (mlp_no_decay if exempt else mlp_decay).append(param)
        else:  # encoder parameters
            (bert_no_decay if exempt else bert_decay).append(param)

    return [
        # Encoder: smaller learning rate.
        {'params': bert_decay, 'weight_decay': weight_decay, 'lr': bert_lr},
        {'params': bert_no_decay, 'weight_decay': 0.0, 'lr': bert_lr},
        # Head: larger learning rate.
        {'params': mlp_decay, 'weight_decay': weight_decay, 'lr': mlp_lr},
        {'params': mlp_no_decay, 'weight_decay': 0.0, 'lr': mlp_lr},
    ]


def main():
    """Fine-tune 'bert-base-uncased' with an MLP head on the drug-name data.

    Reads the notebook-level ``train_loader`` and ``test_loader``, trains for
    15 epochs with AdamW and a linear schedule whose warm-up covers 10% of the
    total steps, saves the best checkpoint to
    ``./drug_class_prediction_best_model.pth`` and prints the best validation
    accuracy.

    NOTE(review): ``test_loader`` is passed as the validation loader, so the
    checkpoint is selected on the test split and the printed "best validation
    accuracy" is measured on that same split. Consider a separate validation
    split if an unbiased test estimate is needed.
    """
    # Setup.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    BERT = BertModel.from_pretrained("bert-base-uncased")

    model = BertWithMLP(BERT, hidden_size=768, mlp_hidden_size1=1024, mlp_hidden_size2=256, num_classes=6)
    model.to(device)

    # Every group carries its own learning rate and weight decay.
    optimizer = AdamW(_build_param_groups(model))

    epochs = 15

    # The scheduler is stepped once per batch, so steps = batches * epochs.
    total_steps = len(train_loader) * epochs
    warmup_steps = int(0.1 * total_steps)  # 10% warm-up

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    criterion = torch.nn.CrossEntropyLoss()
    model_save_path = "./drug_class_prediction_best_model.pth"

    # Create the target directory; skip when the path has no directory part,
    # because os.makedirs('') raises.
    save_dir = os.path.dirname(model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Train and validate.
    history = train_and_evaluate(
        model=model,
        train_loader=train_loader,
        val_loader=test_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        device=device,
        epochs=epochs,
        model_save_path=model_save_path,
        eval_every=1  # validate after every epoch
    )

    print("\nTraining complete!")
    print(f"Best validation accuracy: {max(history['val_acc']):.4f}")

# Run the training entry point when this code is executed as the main module.
if __name__ == "__main__":
    main()


Epoch 1/15


                                                                                  

Train Loss: 1.6655 | Train Acc: 0.3054


                                                                                 

Val Loss: 1.6608 | Val Acc: 0.2628
New best model saved to ./drug_class_prediction_best_model.pth with val_acc: 0.2628

Epoch 2/15


                                                                                  

Train Loss: 1.0053 | Train Acc: 0.6351


                                                                                  

Val Loss: 1.1314 | Val Acc: 0.5180
New best model saved to ./drug_class_prediction_best_model.pth with val_acc: 0.5180

Epoch 3/15


                                                                                  

Train Loss: 0.5166 | Train Acc: 0.8151


                                                                                  

Val Loss: 0.4438 | Val Acc: 0.8582
New best model saved to ./drug_class_prediction_best_model.pth with val_acc: 0.8582

Epoch 4/15


                                                                                  

Train Loss: 0.2500 | Train Acc: 0.9251


                                                                                  

Val Loss: 0.1647 | Val Acc: 0.9471
New best model saved to ./drug_class_prediction_best_model.pth with val_acc: 0.9471

Epoch 5/15


                                                                                   

Train Loss: 0.1167 | Train Acc: 0.9675


                                                                                  

Val Loss: 0.0788 | Val Acc: 0.9754
New best model saved to ./drug_class_prediction_best_model.pth with val_acc: 0.9754

Epoch 6/15


                                                                                   

Train Loss: 0.0545 | Train Acc: 0.9860


                                                                                  

Val Loss: 0.0413 | Val Acc: 0.9905
New best model saved to ./drug_class_prediction_best_model.pth with val_acc: 0.9905

Epoch 7/15


                                                                                   

Train Loss: 0.0303 | Train Acc: 0.9919


                                                                                   

Val Loss: 0.0046 | Val Acc: 0.9981
New best model saved to ./drug_class_prediction_best_model.pth with val_acc: 0.9981

Epoch 8/15


                                                                                  

Train Loss: 0.0172 | Train Acc: 0.9955


                                                                                   

Val Loss: 0.0034 | Val Acc: 0.9981

Epoch 9/15


                                                                                  

Train Loss: 0.0038 | Train Acc: 0.9991


                                                                                   

Val Loss: 0.0050 | Val Acc: 0.9962

Epoch 10/15


                                                                                 

Train Loss: 0.0105 | Train Acc: 0.9973


                                                                                 

Val Loss: 0.0001 | Val Acc: 1.0000
New best model saved to ./drug_class_prediction_best_model.pth with val_acc: 1.0000

Epoch 11/15


                                                                                 

Train Loss: 0.0032 | Train Acc: 0.9991


                                                                                  

Val Loss: 0.0061 | Val Acc: 0.9981

Epoch 12/15


                                                                                   

Train Loss: 0.0026 | Train Acc: 0.9986


                                                                                  

Val Loss: 0.0023 | Val Acc: 0.9981

Epoch 13/15


                                                                                   

Train Loss: 0.0014 | Train Acc: 0.9995


                                                                                 

Val Loss: 0.0002 | Val Acc: 1.0000

Epoch 14/15


                                                                                 

Train Loss: 0.0010 | Train Acc: 1.0000


                                                                                

Val Loss: 0.0001 | Val Acc: 1.0000

Epoch 15/15


                                                                                 

Train Loss: 0.0012 | Train Acc: 0.9995


                                                                                

Val Loss: 0.0001 | Val Acc: 1.0000

Training complete!
Best validation accuracy: 1.0000


