In [43]:
import re
import ast
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.feature_selection import mutual_info_classif
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report
from sklearn.neural_network import MLPClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import Pipeline

In [25]:
df = pd.read_csv("./data.csv")
df["subject_areas"] = df["subject_areas"].apply(lambda x: ast.literal_eval(x))
df["subject_areas"].explode().value_counts()

subject_areas
Multidisciplinary             1088
Materials Science (all)        907
Chemistry (all)                905
Chemical Engineering (all)     755
Infectious Diseases            753
                              ... 
Museology                        1
Family Practice                  1
Emergency Nursing                1
Chiropractics                    1
Optometry                        1
Name: count, Length: 321, dtype: int64

In [26]:
df.dropna(inplace=True)

In [28]:
def clean_label_text(label):
    """Remove parenthetical junk like (all) or (miscellaneous)"""
    label = re.sub(r"\s*\(.*?\)\s*", "", label)
    return label.strip()

def get_parent_category(label):
    """Extract highest-level parent label using text rules"""
    label = clean_label_text(label)

    # Split using common hierarchical separators
    parts = re.split(r"\sand\s|,|/", label)

    # Use the first part as parent (dominant field)
    parent = parts[0].strip()

    return parent


# =========================================
# Apply Mapping
# =========================================
df["subject_areas"] = df["subject_areas"].apply(
    lambda lst: [get_parent_category(l) for l in lst]
)

# Remove duplicates within each label list
df["subject_areas"] = df["subject_areas"].apply(lambda lst: list(set(lst)))

# Drop rows that end up with no labels
df = df[df["subject_areas"].map(len) > 0].reset_index(drop=True)

In [29]:
def clean_labels(label_list):
    return [
        l for l in label_list
        if "(miscellaneous)" not in l and "(all)" not in l
    ]

df["subject_areas"] = df["subject_areas"].apply(clean_labels)

# ==============================
# Remove Low-Cardinality Labels
# ==============================
min_count = 500  # â¬… Change threshold as needed

# Count occurrences of each label
label_counts = df["subject_areas"].explode().value_counts()

# Keep only labels appearing >= min_count times
valid_labels = set(label_counts[label_counts >= min_count].index)

def filter_rare_labels(label_list):
    return [l for l in label_list if l in valid_labels]

df["subject_areas"] = df["subject_areas"].apply(filter_rare_labels)

# Drop rows now having zero labels
df = df[df["subject_areas"].map(len) > 0].reset_index(drop=True)

print("Remaining label count:", len(valid_labels))

Remaining label count: 22


In [30]:
mlb = MultiLabelBinarizer()
Y = mlb.fit_transform(df["subject_areas"])
X_train, X_test, y_train, y_test = train_test_split(
    df["abstract"], Y, test_size=0.2, random_state=42
)

In [44]:
model = Pipeline([
    ("tfidf", TfidfVectorizer(stop_words="english")),
    ("clf", OneVsRestClassifier(MLPClassifier())),
])

In [None]:
model.fit(X_train, y_train)



In [33]:
df["subject_areas"].explode().value_counts()

subject_areas
Chemistry                        998
Multidisciplinary                998
Biochemistry                     992
Materials Science                913
Chemical Engineering             865
Engineering                      828
Electrical                       689
Infectious Diseases              685
Computer Science Applications    680
Pharmacology                     673
Computer Networks                664
Medicine                         652
Renewable Energy                 645
Immunology                       625
Public Health                    603
Organic Chemistry                597
Materials Chemistry              559
Energy Engineering               558
Physics                          551
Condensed Matter Physics         543
Environmental Science            508
Molecular Biology                500
Name: count, dtype: int64

In [41]:
y_pred = model.predict(X_test)
print(classification_report(y_test, y_pred, target_names=mlb.classes_))

                               precision    recall  f1-score   support

                 Biochemistry       0.63      0.13      0.22       205
         Chemical Engineering       0.45      0.19      0.27       191
                    Chemistry       0.43      0.14      0.21       210
            Computer Networks       0.60      0.29      0.39       152
Computer Science Applications       0.41      0.12      0.19       129
     Condensed Matter Physics       0.60      0.08      0.15       107
                   Electrical       0.54      0.27      0.36       145
           Energy Engineering       0.52      0.10      0.17       122
                  Engineering       0.62      0.16      0.26       173
        Environmental Science       0.58      0.14      0.23        98
                   Immunology       0.59      0.21      0.31       131
          Infectious Diseases       0.66      0.45      0.53       145
          Materials Chemistry       0.32      0.10      0.15       113
     

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
