# **LeMaterial/LeMat-Synth Dataset Analysis**


This notebook is a v0 data analysis of the LeMaterial/LeMat-Synth-Papers dataset.


## **Available Dataset Splits**

- **arxiv**: ArXiv research papers
- **chemrxiv**: ChemRxiv chemistry papers
- **omg24**: OMG24 conference papers
- **sample_for_evaluation**: Evaluation samples


## **Load libraries**


In [None]:
# Import required libraries
import ast
import re
import string

# To ignore warnings
import warnings
from collections import Counter, defaultdict

import matplotlib as mpl
import matplotlib.pyplot as plt
import nltk
import pandas as pd
import plotly.express as px
import seaborn as sns
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from wordcloud import WordCloud

from llm_synthesis.utils.style_utils import get_palette, set_style

set_style()


warnings.filterwarnings("ignore")

# Remove the limit from the number of displayed columns and rows. It helps to see the entire dataframe while printing it
pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", None)

## **Data set exploration**


### **Understand the data**


Examine the columns, data types, and basic statistics.


In [None]:
from datasets import load_dataset

dataset = load_dataset(
    "LeMaterial/LeMat-Synth-Papers", split="sample_for_evaluation"
)

data = dataset.to_pandas()

data.head()

In [None]:
# Check the datatypes of each column.
data.info()

In [None]:
# Numerical statistics
data.describe()

### **Handle missing values**


Identify and address any missing values in the DataFrame.


In [None]:
# isnull().mean() gives the fraction of missing values; scale it to a percentage
missing_percentages = data.isnull().mean() * 100
print("Percentage of missing values per column:")
print(missing_percentages)
if "images" in data.columns:
    data = data.drop("images", axis=1)
    print("\n'images' column dropped.")

# Fill remaining missing values with "N/A"
data = data.fillna("N/A")
print("\nRemaining missing values filled with 'N/A'.")

# Verify that there are no more missing values
print("\nMissing values after handling:")
print(data.isnull().sum().sum())

## **Data Visualisation**


Create visualizations to understand the distribution and relationships within the data.


### **Summary statistics of all categorical variables**


In [None]:
# Explore basic summary statistics of categorical variables.
data.describe(include=["object"])

### **Distribution of publications by date**


In [None]:
data["published_date"] = pd.to_datetime(data["published_date"], errors="coerce")

plt.figure(figsize=(12, 6))
data["published_year"] = data["published_date"].dt.year
sns.countplot(data=data, x="published_year")
plt.title("Number of publications per year")
plt.xlabel("Year of publication")
plt.ylabel("Number of publications")
plt.xticks(rotation=45)
plt.grid(axis="y", linestyle="--", alpha=0.7)
plt.tight_layout()
plt.show()

In [None]:
# Filter out rows with missing year (due to coercion errors)
yearly_data = data.dropna(subset=["published_year"])

# Group by year and source and count the publications
yearly_counts = (
    yearly_data.groupby(["published_year", "source"])
    .size()
    .reset_index(name="count")
)

# Plotting the number of publications per year by source
plt.figure(figsize=(4, 3))
sns.lineplot(
    data=yearly_counts, x="published_year", y="count", hue="source", marker="o"
)
plt.title("Number of Publications Per Year by Source")
plt.xlabel("Year")
plt.ylabel("Number of Publications")
plt.xticks(rotation=45)
plt.grid(axis="y", linestyle="--", alpha=0.7)
plt.tight_layout()
plt.show()

**Observations:**

The dataset spans publication dates from 1996 to 2025, with a median publication year of 2019, indicating a concentration of more recent papers. The year 2023 has the highest number of publications in the dataset.
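
The figures quoted above can be checked numerically rather than read off the plot. The following is a minimal sketch, assuming a `published_year` column like the one built in the previous cells; the `summarize_years` helper and the toy values are hypothetical and only illustrate the computation.

```python
import pandas as pd


def summarize_years(years: pd.Series) -> dict:
    """Return the first, last, median and most frequent publication year."""
    # Coerce to numbers and drop entries whose date could not be parsed
    valid = pd.to_numeric(years, errors="coerce").dropna().astype(int)
    return {
        "first_year": int(valid.min()),
        "last_year": int(valid.max()),
        "median_year": float(valid.median()),
        "peak_year": int(valid.value_counts().idxmax()),
    }


# Toy input, not taken from the dataset; None stands for an unparsable date.
toy_years = pd.Series([1996, 2019, 2023, 2023, None, 2025])
year_summary = summarize_years(toy_years)
print(year_summary)
```

On the notebook's data, the same helper could be called as `summarize_years(data["published_year"])`.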


### **Distribution of views, reads, citations**


In [None]:
# Convert relevant columns to numeric, coercing errors
data["views_count"] = pd.to_numeric(data["views_count"], errors="coerce")
data["read_count"] = pd.to_numeric(data["read_count"], errors="coerce")
data["citation_count"] = pd.to_numeric(data["citation_count"], errors="coerce")

plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
sns.histplot(data["views_count"].dropna(), kde=True, bins=20)
plt.title("Distribution of View Count")
plt.xlabel("Views")
plt.ylabel("Frequency")

plt.subplot(1, 3, 2)
sns.histplot(data["read_count"].dropna(), kde=True, bins=20)
plt.title("Distribution of Read Count")
plt.xlabel("Reads")
plt.ylabel("Frequency")

plt.subplot(1, 3, 3)
sns.histplot(data["citation_count"].dropna(), kde=True, bins=20)
plt.title("Distribution of Citation Count")
plt.xlabel("Citations")
plt.ylabel("Frequency")

plt.tight_layout()
plt.show()

**Observations:**

The distributions of views, reads and citations are all strongly right-skewed: a few high-performing publications stand out against a majority of low-impact ones.

- Views: spread distribution with a mode around 500–1,000 views; a few publications exceed 5,000 views.

- Reads: more concentrated distribution, with the majority of articles below 1,000 reads; rare extreme cases up to 8,000.

- Citations: highly unbalanced distribution; the majority of publications have fewer than 10 citations, with very few exceeding 50.

This asymmetry reflects a concentration of visibility and scientific impact on a small number of publications.
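
The asymmetry described above can also be quantified instead of judged by eye from the histograms. The sketch below is one possible way to do it, assuming count-like columns such as `views_count`, `read_count` and `citation_count`; the `skew_report` helper, the `top10_share` measure and the toy values are hypothetical illustrations, not part of the original analysis.

```python
import pandas as pd


def skew_report(df: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
    """Summarize the asymmetry of count-like columns."""
    rows = []
    for col in columns:
        values = pd.to_numeric(df[col], errors="coerce").dropna()
        # Number of publications that make up the top 10% (at least one)
        k = max(1, len(values) // 10)
        rows.append(
            {
                "metric": col,
                "median": values.median(),
                "mean": values.mean(),
                # Positive skewness means a long right tail
                "skewness": values.skew(),
                # Share of the column total held by the top 10% of rows
                "top10_share": values.nlargest(k).sum() / values.sum(),
            }
        )
    return pd.DataFrame(rows).set_index("metric")


# Toy values, not taken from the dataset: one outlier among low counts.
toy = pd.DataFrame({"citation_count": [0, 1, 1, 2, 2, 3, 3, 4, 5, 80]})
report = skew_report(toy, ["citation_count"])
print(report)
```

A mean well above the median, positive skewness and a large `top10_share` all point to the same concentration of impact on a few publications. Note that `top10_share` is undefined when a column sums to zero.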


### **Frequency of categories**


In [None]:
# Function to clean and flatten category strings
def clean_categories(categories_list):
    cleaned = []
    if isinstance(categories_list, str):
        try:
            # Attempt to safely evaluate the string as a Python literal (like a list)
            categories_list = ast.literal_eval(categories_list)
        except (ValueError, SyntaxError):
            # If evaluation fails, treat the string as a comma/semicolon separated list
            categories_list = [
                categories_list
            ]  # Put the string in a list to process below

    if isinstance(categories_list, list):
        for category in categories_list:
            if isinstance(category, str):  # Ensure the element is a string
                # Clean special characters and unwanted spaces
                # Split by comma and then by semicolon, strip whitespace, handle potential empty strings
                for part in category.split(","):
                    for sub_part in part.split(";"):
                        cleaned_category = re.sub(
                            r"[^a-zA-Z0-9\s]", "", sub_part
                        ).strip()
                        cleaned_category = cleaned_category.capitalize()
                        # Eliminate generic or irrelevant categories (including "Other" and "Na")
                        if (
                            cleaned_category
                            and cleaned_category.lower()
                            not in [
                                "material",
                                "science",
                                "others",
                                "n/a",
                                "other",
                                "na",
                            ]
                        ):
                            cleaned.append(cleaned_category)
    return cleaned


all_cleaned_categories = [
    category
    for sublist in data["categories"].apply(clean_categories)
    for category in sublist
]
category_counts = Counter(all_cleaned_categories)

# Convert to DataFrame for better visualization with seaborn
df_categories = pd.DataFrame(
    category_counts.most_common(15), columns=["Category", "Count"]
)

plt.figure(figsize=(12, 7))
sns.barplot(data=df_categories, x="Count", y="Category")
plt.title("Top 15 Publication Categories")
plt.xlabel("Number of Occurrences")
plt.ylabel("Category")
plt.tight_layout()
plt.show()

**Observations:**

Publications are mainly concentrated in the fields of materials chemistry, materials science and inorganic chemistry, reflecting a strong focus on research into functional materials. Other dominant categories such as thin films, magnetic and nanomaterials indicate a focus on materials physics at the nanometric scale. This distribution highlights a marked specialisation at the interface between chemistry, physics and nanosciences.


### **Top keywords**


In [None]:
# Function to clean and format keywords
def clean_keywords(keyword_entry):
    cleaned_keywords = []
    if isinstance(keyword_entry, str):
        # Split by semicolon first to get individual keyword strings
        keyword_strings = keyword_entry.split(";")
        for kw_str in keyword_strings:
            # Remove apostrophes and brackets, then strip whitespace
            cleaned_kw = (
                kw_str.replace("'", "")
                .replace("[", "")
                .replace("]", "")
                .strip()
            )
            # Join letters that might have been separated by apostrophes in the original data
            # This assumes the apostrophes were just separators, and the remaining characters form the word
            cleaned_kw = "".join(
                cleaned_kw.split()
            )  # Remove any remaining spaces within the word

            if cleaned_kw:
                cleaned_keywords.append(cleaned_kw)
    return cleaned_keywords


# Apply the cleaning function to the 'keywords' column and flatten the list of lists
# Ensure to handle NaN values by converting to string and checking
all_cleaned_keywords = [
    keyword
    for sublist in data["keywords"].astype(str).apply(clean_keywords)
    for keyword in sublist
]

# Join all cleaned keywords into a single string for the word cloud
text = " ".join(all_cleaned_keywords)

wordcloud = WordCloud(width=800, height=400, background_color="white").generate(
    text
)

plt.figure(figsize=(10, 5))
plt.imshow(wordcloud, interpolation="bilinear")
plt.axis("off")
plt.title("Word Cloud of Keywords")
plt.show()

**Observations:**

The word cloud highlights themes that align closely with the main publication categories identified above. Terms such as "thin film", "metal", "self assembly", "coordination polymer", "MOF" and "supramolecular chemistry" appear in large font, indicating their high frequency. These keywords confirm the predominance of topics in materials chemistry, inorganic chemistry, nanomaterials and supramolecular chemistry. The recurrence of terms such as "porosity", "framework", "host", "luminescence" and "solid electrolytes" reinforces the idea of a strong interest in functional materials, particularly porous or hybrid materials, for optoelectronic, catalytic or energy applications.


In [None]:
# Count the frequency of each keyword
keyword_counts = Counter(all_cleaned_keywords)

# Convert to DataFrame for better display
# Let's display the top 20 keywords for now
top_n_keywords = 20
df_top_keywords = pd.DataFrame(
    keyword_counts.most_common(top_n_keywords), columns=["Keyword", "Count"]
)

print(f"Top {top_n_keywords} Keywords for the entire sample:")
display(df_top_keywords)

## **Relationship between views, reads and citations**


In [None]:
# Let's ensure these columns are numeric, coercing errors to NaN
# and then fill NaNs if necessary (e.g., with 0 or the mean)
for col in ["views_count", "read_count", "citation_count"]:
    data[col] = pd.to_numeric(data[col], errors="coerce")
    data[col] = data[col].fillna(0)

# Calculate aggregate statistics for the entire dataset
# This is often done using .describe() or .agg()
aggregated_stats = data[["views_count", "read_count", "citation_count"]].agg(
    ["sum", "mean", "median", "min", "max", "std"]
)

print("Aggregated Statistics for Views, Reads, and Citations:")
display(aggregated_stats)

In [None]:
plt.figure(figsize=(8, 6))
sns.scatterplot(data=data, x="read_count", y="citation_count", alpha=0.6)
plt.title("Reads vs. Citations")
plt.xlabel("Number of Reads")
plt.ylabel("Number of Citations")
plt.grid(True, linestyle="--", alpha=0.6)
plt.show()

In [None]:
# Create heatmap with the citation count, read count and views count
plt.figure(figsize=(12, 7))

sns.heatmap(
    data[["views_count", "read_count", "citation_count"]].corr(),
    annot=True,
    vmin=-1,
    vmax=1,
    cmap="coolwarm",
)

plt.title("Correlation Heatmap of Impact Metrics")

plt.show()

## **Advanced data exploration**


Perform a deeper dive into the dataset by exploring relationships between columns.


### **Categories**


In [None]:
# Combine categories from all sources with category data
# Helper defined once, outside the loop, to parse a raw category entry
def clean_single_category_entry(categories_list):
    cleaned = []
    if isinstance(categories_list, str):
        try:
            categories_list = ast.literal_eval(categories_list)
        except (ValueError, SyntaxError):
            categories_list = [categories_list]

    if isinstance(categories_list, list):
        for category in categories_list:
            if isinstance(category, str):
                for part in category.split(","):
                    for sub_part in part.split(";"):
                        cleaned_category = sub_part.strip()
                        if cleaned_category:
                            cleaned.append(cleaned_category)
    return cleaned


all_categories = []
# Using the main 'data' DataFrame and iterating through rows
if "data" in globals():
    for index, row in data.iterrows():
        # Use the cleaned categories if available, otherwise clean them here
        categories = row.get("cleaned_categories")
        if categories is None or not isinstance(categories, list):
            categories = clean_single_category_entry(row["categories"])

        # Filter out 'N/A', 'Other', 'Others', 'NA' (case-insensitive)
        filtered_categories = [
            cat
            for cat in categories
            if isinstance(cat, str)
            and cat.lower() not in ["n/a", "other", "others", "na"]
        ]
        all_categories.extend(filtered_categories)

if all_categories:
    # Count the frequency of each category
    category_counts = Counter(all_categories)
    category_df = pd.DataFrame(
        category_counts.items(), columns=["Category", "Count"]
    )

    # Create a treemap
    fig = px.treemap(
        category_df,
        path=["Category"],
        values="Count",
        title="Treemap of Categories Across All Sources",
    )
    fig.show()
else:
    print("No category data available to create a treemap after filtering.")

In [None]:
# Function to clean and flatten category strings (copied from previous cell)
def clean_categories(categories_list):
    cleaned = []
    if isinstance(categories_list, str):
        try:
            categories_list = ast.literal_eval(categories_list)
        except (ValueError, SyntaxError):
            categories_list = [categories_list]

    if isinstance(categories_list, list):
        for category in categories_list:
            if isinstance(category, str):
                for part in category.split(","):
                    for sub_part in part.split(";"):
                        cleaned_category = re.sub(
                            r"[^a-zA-Z0-9\s]", "", sub_part
                        ).strip()
                        if cleaned_category:
                            cleaned.append(cleaned_category)
    return cleaned


# Apply the cleaning function to the 'categories' column within this cell
data["cleaned_categories"] = data["categories"].apply(clean_categories)

# Flatten the list of lists and explicitly filter out 'N/A', 'Other', 'Others', and 'NA' (case-insensitive)
all_cleaned_categories = [
    category
    for sublist in data["cleaned_categories"]
    for category in sublist
    if category.lower() not in ["n/a", "other", "others", "na"]
]

# Count the occurrences of each filtered category
category_counts = Counter(all_cleaned_categories)

# Convert the counts to a pandas Series for easy plotting
category_counts_series = pd.Series(category_counts).sort_values(ascending=False)

# Plot a bar plot of the category counts
plt.figure(figsize=(15, 8))
ax = category_counts_series.plot(kind="bar")

plt.title("Number of Articles per Category")
plt.xlabel("Category")
plt.ylabel("Number of Articles")
plt.xticks(rotation=90, ha="right")
plt.tight_layout()

# Add the exact number above each bar
for p in ax.patches:
    ax.annotate(
        f"{p.get_height()}",
        (p.get_x() + p.get_width() / 2.0, p.get_height()),
        ha="center",
        va="center",
        xytext=(0, 5),
        textcoords="offset points",
    )

plt.show()

**Observations:**

This bar chart confirms the most frequent categories already seen in the category frequency plot above: materials chemistry/science, inorganic chemistry and thin films.


In [None]:
data2 = pd.DataFrame(data["categories"])


# Function to split and clean category strings - Modified to split by comma and semicolon
def clean_categories(category_string):
    if isinstance(category_string, str):
        # Split by comma and then by semicolon, strip whitespace, handle potential empty strings
        categories = []
        for part in category_string.split(","):
            for sub_part in part.split(";"):
                cleaned_cat = sub_part.strip()
                # Convert to lowercase for case-insensitive filtering
                if cleaned_cat and cleaned_cat.lower() not in [
                    "n/a",
                    "other",
                    "others",
                    "na",
                ]:
                    categories.append(cleaned_cat)
        return categories
    elif isinstance(category_string, list):
        return [
            cat.strip()
            for cat in category_string
            if isinstance(cat, str)
            and cat.strip().lower() not in ["n/a", "other", "others", "na"]
        ]
    else:
        return []


# Apply the cleaning function
data2["cleaned_categories"] = data2["categories"].apply(clean_categories)

# --- 1. Select Top N Categories ---
# Calculate the frequency of each category
all_cats_list = [
    cat for sublist in data2["cleaned_categories"] for cat in sublist
]
category_counts = Counter(all_cats_list)

# Define how many top categories you want to see
N = 20
top_categories = [cat for cat, count in category_counts.most_common(N)]
top_categories = sorted(
    top_categories
)  # Sort alphabetically for consistent order

print(
    f"Top {len(top_categories)} categories selected for the heatmap: {top_categories}"
)


# --- 2. Create Co-occurrence Matrix for Top Categories ---
# Initialize the DataFrame with dtype=float
co_occurrence_matrix = pd.DataFrame(
    0, index=top_categories, columns=top_categories, dtype=float
)

# Populate the matrix
for categories_list in data2["cleaned_categories"]:
    # Filter the list to only include top categories
    filtered_list = [cat for cat in categories_list if cat in top_categories]
    # Use a set to get unique categories for co-occurrence counting within a single entry
    for cat1 in set(filtered_list):
        for cat2 in set(filtered_list):
            if (
                cat1 in co_occurrence_matrix.index
                and cat2 in co_occurrence_matrix.columns
            ):
                co_occurrence_matrix.loc[cat1, cat2] += 1


palette = get_palette()  # list of hex color codes from the project style
col_start = palette[2]
col_end = palette[3]

# Build a smooth gradient with LinearSegmentedColormap. The endpoints are
# hard-coded hex values; col_start and col_end above are currently unused.
colormap = mpl.colors.LinearSegmentedColormap.from_list(
    "custom", ["#E7F0FF", "#448FF2"], N=256
)


# --- 3. Plot the Heatmap (Full Matrix) ---
plt.figure(figsize=(12, 10))
heatmap = sns.heatmap(
    co_occurrence_matrix,
    annot=True,
    # custom light-to-dark blue gradient built with LinearSegmentedColormap
    cmap=colormap,
    fmt=".0f",
    linewidths=0.5,
    cbar_kws={"label": "Co-occurrence Count"},
)

# Improve plot aesthetics
# plt.title("Co-occurrence Heatmap of Top 20 Categories", fontsize=16, pad=20)
# plt.xlabel("Category", fontsize=12)
# plt.ylabel("Category", fontsize=12)
plt.xticks(rotation=45, ha="right")
plt.yticks(rotation=0)
plt.tight_layout()
# plt.show()
# save as pdf
plt.savefig("co_occurrence_heatmap.pdf")

**Observations:**

This heatmap reveals the predominant interdisciplinary fields in materials science and chemistry. Inorganic chemistry is strongly linked to materials chemistry and solid state chemistry. Materials chemistry is also linked to physical chemistry. Materials science co-occurs significantly with thin films and materials chemistry. In addition, Supramolecular Chemistry (Org.) frequently co-occurs with Organic Chemistry. Overall, the analysis highlights clear interactions between fundamental chemistry and materials applications.

The diagonal indicates individual frequency, and off-diagonal values quantify interdisciplinary links.
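
To make the reading of the matrix concrete, here is a small worked example using the same counting scheme as the heatmap cell (each category counted once per paper, every pair incremented, including a category paired with itself). The three toy papers and their labels are hypothetical and are not taken from the dataset.

```python
from itertools import product

import pandas as pd

# Toy category lists, one per paper (hypothetical, not from the dataset).
papers = [
    ["Inorganic Chemistry", "Materials Chemistry"],
    ["Materials Chemistry"],
    ["Inorganic Chemistry", "Materials Chemistry", "Thin Films"],
]

labels = sorted({cat for cats in papers for cat in cats})
matrix = pd.DataFrame(0, index=labels, columns=labels)

for cats in papers:
    # A set keeps each category once per paper, as in the heatmap cell.
    for cat1, cat2 in product(set(cats), repeat=2):
        matrix.loc[cat1, cat2] += 1

print(matrix)
```

Here "Materials Chemistry" appears in all three papers, so its diagonal entry is 3, while it shares two papers with "Inorganic Chemistry", so that off-diagonal entry is 2. The matrix is symmetric because each pair is counted in both directions.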


In [None]:
# Create heatmap with the categories columns and source

# Apply the cleaning function and create a list of (source, category) pairs
source_category_pairs = []
for index, row in data.iterrows():
    source = row["source"]
    cleaned_cats = clean_categories(row["categories"])
    for category in cleaned_cats:
        source_category_pairs.append({"source": source, "category": category})

# Create a DataFrame from the pairs
source_category_df = pd.DataFrame(source_category_pairs)

# Create a crosstab of source and categories using the new DataFrame
if not source_category_df.empty:
    crosstab_data = pd.crosstab(
        source_category_df["source"], source_category_df["category"]
    )

    plt.figure(figsize=(50, 10))
    sns.heatmap(crosstab_data, fmt="d", cmap="Blues")
    plt.title("Heatmap of Source vs Categories")
    plt.xlabel("Categories")
    plt.ylabel("Source")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()
else:
    print("No valid category-source pairs found after cleaning.")

In [None]:
# Function to clean and flatten category strings (more robust version)
def clean_categories_for_counting(categories_list):
    cleaned = []
    if isinstance(categories_list, str):
        try:
            categories_list = ast.literal_eval(categories_list)
        except (ValueError, SyntaxError):
            categories_list = [categories_list]

    if isinstance(categories_list, list):
        for category in categories_list:
            if isinstance(category, str):
                for part in category.split(","):
                    for sub_part in part.split(";"):
                        cleaned_category = re.sub(
                            r"[^a-zA-Z0-9\s]", "", sub_part
                        ).strip()
                        if cleaned_category:
                            cleaned.append(cleaned_category)
    return cleaned


# Function to get the most common categories for a given source using the main data DataFrame
def get_top_categories(df, source_name, n=10):
    source_df = df[df["source"] == source_name].copy()

    # Apply the cleaning function to the 'categories' column for the specific source
    source_df["cleaned_categories_list"] = source_df["categories"].apply(
        clean_categories_for_counting
    )

    # Flatten the list of lists of cleaned categories
    all_categories = [
        cat
        for sublist in source_df["cleaned_categories_list"].dropna()
        for cat in sublist
    ]

    # Count the frequency of each category
    category_counts = Counter(all_categories)

    # Return the most common categories
    return category_counts.most_common(n)


# Get top categories for each source
top_n = 10
if "data" in globals():
    arxiv_top_categories = get_top_categories(data, "arxiv", n=top_n)
    chemrxiv_top_categories = get_top_categories(data, "chemrxiv", n=top_n)
    omg24_top_categories = get_top_categories(data, "omg24", n=top_n)

    def plot_categories(category_freq, title):
        if not category_freq:
            print(f"No category data to plot for {title}")
            return
        categories, counts = zip(*category_freq)
        plt.figure(figsize=(12, 7))
        sns.barplot(x=list(counts), y=list(categories))
        plt.title(title)
        plt.xlabel("Count")
        plt.ylabel("Category")
        plt.tight_layout()
        plt.show()

    # Create bar charts for each source
    plot_categories(
        arxiv_top_categories, f"Top {top_n} Categories in ArXiv Dataset"
    )
    plot_categories(
        chemrxiv_top_categories, f"Top {top_n} Categories in ChemRxiv Dataset"
    )
    plot_categories(
        omg24_top_categories, f"Top {top_n} Categories in OMG24 Dataset"
    )
else:
    print(
        "Error: 'data' DataFrame not found. Please run the data loading cells."
    )

### **Number of views**


In [None]:
# Drop rows where 'views_count' or 'categories' are NaN, as they are not useful for this analysis
data_cleaned = data.dropna(subset=["views_count", "categories"]).copy()

# Dictionary to store total views per category
category_views = defaultdict(float)

# Iterate through each row of the cleaned DataFrame
for index, row in data_cleaned.iterrows():
    views = row["views_count"]
    # Only process if views is a valid number
    if pd.notna(views):
        categories = clean_categories(row["categories"])
        for category in categories:
            category_views[category] += views

# Convert the dictionary to a DataFrame for better manipulation and visualization
df_category_views = pd.DataFrame(
    category_views.items(), columns=["Category", "Total_Views"]
)

# Sort categories by total views
df_category_views = df_category_views.sort_values(
    by="Total_Views", ascending=False
)

# Select Top N categories for better readability (e.g., top 15)
top_n = 15
df_top_categories_views = df_category_views.head(top_n)

# Create the plot
plt.figure(figsize=(12, 8))
sns.barplot(
    data=df_top_categories_views,
    x="Total_Views",
    y="Category",
    palette="viridis",
)
plt.title(f"Top {top_n} Categories by Total Views")
plt.xlabel("Total Number of Views")
plt.ylabel("Category")
plt.grid(axis="x", linestyle="--", alpha=0.7)
plt.tight_layout()
plt.show()

print(f"Top {top_n} Categories by Total Views:")
print(df_top_categories_views)

## **Sources**


### **Number of views, reads and citations per source**


In [None]:
# Create a copy and fill missing values with 0 for aggregation
df_aggregated_counts = data.copy()
df_aggregated_counts[["views_count", "read_count", "citation_count"]] = (
    df_aggregated_counts[
        ["views_count", "read_count", "citation_count"]
    ].fillna(0)
)

# Group by source and calculate the total for each metric
total_counts_by_source = (
    df_aggregated_counts.groupby("source")[
        ["views_count", "read_count", "citation_count"]
    ]
    .sum()
    .reset_index()
)

# Melt the DataFrame to long format for easier plotting with seaborn
df_melted_counts = total_counts_by_source.melt(
    id_vars="source",
    value_vars=["views_count", "read_count", "citation_count"],
    var_name="Metric",
    value_name="Total_Count",
)

# Define custom titles and labels
metric_titles = {
    "views_count": "Total Number of Views by Source",
    "read_count": "Total Number of Reads by Source",
    "citation_count": "Total Number of Citations by Source",
}

metric_ylabels = {
    "views_count": "Total Views",
    "read_count": "Total Reads",
    "citation_count": "Total Citations",
}

# Create a separate plot for each metric
for metric in ["views_count", "read_count", "citation_count"]:
    plt.figure(figsize=(10, 6))
    sns.barplot(
        data=df_melted_counts[df_melted_counts["Metric"] == metric],
        x="source",
        y="Total_Count",
        palette="viridis",
    )
    plt.title(metric_titles[metric])
    plt.xlabel("Source")
    plt.ylabel(metric_ylabels[metric])
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()

In [None]:
# Create a pivot table.
pivot_table_by_source_sum = data.pivot_table(
    index="source",
    values=["views_count", "read_count", "citation_count"],
    aggfunc="sum",
)

print("\nPivot Table: Sum of Views, Reads, and Citations by Source:")
display(pivot_table_by_source_sum)

### **Publication Trends Over time by source**


In [None]:
try:
    data = pd.read_parquet("sample_for_evaluation.parquet")
except FileNotFoundError:
    print(
        "Please make sure 'sample_for_evaluation.parquet' is in the same directory."
    )
    data = pd.DataFrame(
        {
            "published_date": ["2020", "2021-01-01T00:00:00Z", "2022"],
            "source": ["arxiv", "chemrxiv", "omg24"],
        }
    )

data["published_year"] = data["published_date"].astype(str).str[:4]

# Convert the extracted year to a numeric type, coercing errors to NaN
data["published_year"] = pd.to_numeric(data["published_year"], errors="coerce")

# Drop rows where the year could not be parsed or is missing
data_cleaned = data.dropna(subset=["published_year"]).copy()

# Convert the year column to integer type for clean plotting
data_cleaned["published_year"] = data_cleaned["published_year"].astype(int)

# Count the number of publications for each source for each year.
current_year = pd.Timestamp.now().year
valid_years_data = data_cleaned[
    (data_cleaned["published_year"] >= 1990)
    & (data_cleaned["published_year"] <= current_year + 5)
].copy()

# Check if all sources are present after cleaning and filtering
print("Data points per source after cleaning and filtering:")
print(valid_years_data["source"].value_counts())
print("-" * 30)

# Group by year and source, then unstack to get sources as columns
publications_over_time = (
    valid_years_data.groupby(["published_year", "source"])
    .size()
    .unstack(fill_value=0)
)

# Print the sources and the head of the aggregated data to verify all three are present
print(
    "\nUnique sources in the aggregated data:",
    publications_over_time.columns.tolist(),
)
print("\nHead of publications_over_time DataFrame:")
print(publications_over_time.head())
print("-" * 30)

# Create a line plot showing the number of publications over time for each source
plt.style.use("seaborn-v0_8-whitegrid")
plt.figure(figsize=(14, 8))

if not publications_over_time.empty:
    sns.lineplot(
        data=publications_over_time, markers=True, dashes=False, linewidth=2.5
    )

    plt.title(
        "Publication Trends Over Time by Source", fontsize=16, fontweight="bold"
    )
    plt.xlabel("Publication Year", fontsize=12)
    plt.ylabel("Number of Publications", fontsize=12)
    plt.legend(title="Source", fontsize=10)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()
else:
    print("No valid publication data found to plot trends over time.")

### **Top words by source**


In [None]:
# Download necessary NLTK data
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt", quiet=True)
try:
    nltk.data.find("corpora/stopwords")
except LookupError:
    nltk.download("stopwords", quiet=True)
try:
    nltk.data.find("tokenizers/punkt_tab")
except LookupError:
    nltk.download("punkt_tab", quiet=True)


# Get English stop words
stop_words = set(stopwords.words("english"))


# Define a text preprocessing function
def preprocess_text(text):
    if pd.isna(text):
        return ""
    text = str(text).lower()
    text = text.translate(str.maketrans("", "", string.punctuation))
    try:
        words = word_tokenize(text)
    except LookupError:
        nltk.download("punkt", quiet=True)
        words = word_tokenize(text)
    words = [word for word in words if word not in stop_words and len(word) > 2]
    return " ".join(words)


# Filter data for each source
arxiv_df = data[data["source"] == "arxiv"].copy()
chemrxiv_df = data[data["source"] == "chemrxiv"].copy()
omg24_df = data[data["source"] == "omg24"].copy()

# Combine and preprocess text for each source
arxiv_df["combined_text"] = (
    arxiv_df["title"].fillna("") + " " + arxiv_df["abstract"].fillna("")
)
arxiv_df["cleaned_text"] = arxiv_df["combined_text"].apply(preprocess_text)

chemrxiv_df["combined_text"] = (
    chemrxiv_df["title"].fillna("") + " " + chemrxiv_df["abstract"].fillna("")
)
chemrxiv_df["cleaned_text"] = chemrxiv_df["combined_text"].apply(
    preprocess_text
)

omg24_df["combined_text"] = (
    omg24_df["title"].fillna("") + " " + omg24_df["abstract"].fillna("")
)
omg24_df["cleaned_text"] = omg24_df["combined_text"].apply(preprocess_text)


# Perform word frequency analysis for each source
def get_word_frequencies(df, text_column="cleaned_text", n=30):
    if df.empty:
        return []
    all_words = " ".join(df[text_column]).split()
    word_counts = Counter(all_words)
    return word_counts.most_common(n)


top_n = 30
arxiv_word_freq = get_word_frequencies(arxiv_df, n=top_n)
chemrxiv_word_freq = get_word_frequencies(chemrxiv_df, n=top_n)
omg24_word_freq = get_word_frequencies(omg24_df, n=top_n)

# Use top_n in the labels so they stay correct if the value changes
print(f"Top {top_n} words in ArXiv dataset:")
print(arxiv_word_freq)
print(f"\nTop {top_n} words in ChemRxiv dataset:")
print(chemrxiv_word_freq)
print(f"\nTop {top_n} words in OMG24 dataset:")
print(omg24_word_freq)

In [None]:
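# Function to create a bar chart of word frequencies
def plot_word_frequencies(word_freq, title):
    # An empty list cannot be unpacked by zip(*...), so skip empty sources
    if not word_freq:
        print(
            f"No word frequency data available for '{title}'. Skipping bar chart."
        )
        return

    words, counts = zip(*word_freq)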
    plt.figure(figsize=(12, 7))
    sns.barplot(x=list(counts), y=list(words), palette="viridis")
    plt.title(title)
    plt.xlabel("Frequency")
    plt.ylabel("Words")
    plt.tight_layout()
    plt.show()


# Create bar charts for each source
plot_word_frequencies(arxiv_word_freq, "Top 30 Most Frequent Words in ArXiv")
plot_word_frequencies(
    chemrxiv_word_freq, "Top 30 Most Frequent Words in ChemRxiv"
)
plot_word_frequencies(omg24_word_freq, "Top 30 Most Frequent Words in OMG24")

In [None]:
# Function to generate and display a word cloud
def plot_word_cloud(word_freq, title):
    if not word_freq:
        print(
            f"No word frequency data available for '{title}'. Skipping word cloud generation."
        )
        return

    wordcloud = WordCloud(
        width=800, height=400, background_color="white"
    ).generate_from_frequencies(dict(word_freq))
    plt.figure(figsize=(10, 5))
    plt.imshow(wordcloud, interpolation="bilinear")
    plt.axis("off")
    plt.title(title)
    plt.show()


# Generate and display word clouds for each source
plot_word_cloud(arxiv_word_freq, "Word Cloud for ArXiv Dataset")
plot_word_cloud(chemrxiv_word_freq, "Word Cloud for ChemRxiv Dataset")
plot_word_cloud(omg24_word_freq, "Word Cloud for OMG24 Dataset")

### **Abstract length**


In [None]:
# Combine the dataframes
if "combined_df" not in locals():
    combined_df = pd.concat(
        [
            arxiv_df.assign(source="arxiv"),
            chemrxiv_df.assign(source="chemrxiv"),
            omg24_df.assign(source="omg24"),
        ]
    )

# Calculate abstract length in characters; missing or "N/A" abstracts count as 0
combined_df["abstract_length"] = combined_df["abstract"].apply(
    lambda x: len(str(x)) if pd.notna(x) and x != "N/A" else 0
)

# Create a distribution plot (density plot) of abstract lengths by source
plt.figure(figsize=(12, 7))
sns.kdeplot(
    data=combined_df,
    x="abstract_length",
    hue="source",
    fill=True,
    common_norm=False,
)
plt.title("Distribution of Abstract Lengths by Source")
plt.xlabel("Abstract Length (Number of Characters)")
plt.ylabel("Density")
plt.show()

**Observations:**

The plot shows the distribution of abstract lengths (in characters) for three sources: arXiv, chemrxiv, and omg24.

- The distribution of abstract lengths for arXiv is unimodal and relatively symmetrical, centred around 800-1000 characters, suggesting a standardised or preferred length for scientific publications submitted to this repository.

- The distribution for chemrxiv is also unimodal but wider and slightly asymmetrical (positive skewness), with a peak around 1200-1400 characters. This could indicate greater variability in abstract lengths for chemistry articles, perhaps due to specific disciplinary requirements or a greater diversity of topics requiring more or less detailed descriptions.

- In contrast, the omg24 distribution is bimodal, with a pronounced first peak around 0-200 characters and a second, broader and lower peak around 800-1000 characters. The first peak is notable: because the code above assigns a length of 0 to missing or "N/A" abstracts, it may largely reflect missing abstracts, although very short summaries or atypical records could also contribute. The second peak partially overlaps with the arXiv distribution, suggesting that some omg24 abstracts are similar in length to arXiv abstracts. This bimodality raises questions about the types of documents in omg24 or the completeness of their records.
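
One way to check whether the short-length peak comes from missing abstracts rather than genuinely short ones is to summarise the lengths per source. The sketch below is illustrative only: `summarize_abstract_lengths` is a hypothetical helper, and the `toy` frame is made-up data that stands in for `combined_df`, with the same assumed `source` and `abstract_length` columns.

```python
import pandas as pd


def summarize_abstract_lengths(
    df, source_col="source", length_col="abstract_length"
):
    """Summarise abstract lengths per source (hypothetical helper)."""
    rows = []
    for source, group in df.groupby(source_col):
        lengths = group[length_col]
        non_empty = lengths[lengths > 0]
        rows.append(
            {
                source_col: source,
                "n_papers": len(lengths),
                # Share of records whose abstract length was set to 0
                "share_empty": (lengths == 0).mean(),
                # Statistics on non-empty abstracts only
                "median_non_empty": non_empty.median(),
                "skew_non_empty": non_empty.skew(),
            }
        )
    return pd.DataFrame(rows).set_index(source_col)


# Toy data (invented values), standing in for combined_df
toy = pd.DataFrame(
    {
        "source": ["arxiv"] * 4 + ["omg24"] * 4,
        "abstract_length": [800, 900, 950, 1000, 0, 0, 850, 950],
    }
)

summary = summarize_abstract_lengths(toy)
print(summary)
```

On the real data, calling `summarize_abstract_lengths(combined_df)` should give a comparable table. A high `share_empty` for omg24 would support the missing-abstract explanation, while a value near zero would point to genuinely short abstracts.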
