In [None]:
# Set the cuBLAS workspace config for this kernel before any torch work runs.
# set_seed() further down calls torch.use_deterministic_algorithms(True);
# NOTE(review): PyTorch's reproducibility notes say that mode needs this
# variable on CUDA 10.2+ — confirm for the CUDA version in use.
%env CUBLAS_WORKSPACE_CONFIG=:4096:8

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import models, transforms
from PIL import Image
import numpy as np
from tqdm import tqdm
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchinfo import summary
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
import glob
from torch.utils.data import DataLoader
import os
import cv2
import torch
from torch.utils.data import Dataset, DataLoader,Subset
from tqdm import tqdm
import json
import random
import math
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,confusion_matrix
import time
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve, auc

In [None]:
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and switch PyTorch to deterministic mode.

    Args:
        seed: Value passed to every random number generator.

    Side effects:
        Sets cuDNN to deterministic / non-benchmark mode and, when the running
        torch exposes ``use_deterministic_algorithms``, enables it globally.
    """
    # Same seed for every RNG source, in the order: stdlib, NumPy, torch CPU, torch CUDA.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Deterministic cuDNN kernels; autotuning off (may slow things down).
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Not every torch build has this switch, so look it up defensively.
    enable_deterministic = getattr(torch, "use_deterministic_algorithms", None)
    if enable_deterministic is not None:
        enable_deterministic(True)
# Plain-Python equivalent of the %env magic in the first cell (left disabled):
# os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# Candidate seeds; SEED is the one applied to this run and is also passed as
# random_state to the train/test split further down.
seed_list = [3,5,11,1344,2506]
SEED = 3
set_seed(SEED)

In [None]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

## Data

In [None]:
# Top-level class folders expected under ROOT_DIR (numeric prefix included).
NATURE_CLASSES = [
    "01. Normal",
    "02. Variation from normal",
    "03. OPMD",
    "04. Oral Cancer",
]

# Sub-folder of each class folder that is scanned for per-site image folders.
UNANNOTATED_SUBDIR = "01. Unannotated"

In [None]:
# Placeholder: point this at the dataset root before running.
# The scan below reads ROOT_DIR/<nature class>/<UNANNOTATED_SUBDIR>/<site>/<files>.
ROOT_DIR = "path_to_dataset_directory"

In [None]:
# Build one record per image file found under
#   ROOT_DIR/<nature class>/<UNANNOTATED_SUBDIR>/<site>/<file>
data = []
for nature in NATURE_CLASSES:
    nature_path = os.path.join(ROOT_DIR, nature, UNANNOTATED_SUBDIR)
    if not os.path.exists(nature_path):
        # Report the missing folder and carry on with the remaining classes.
        print(nature_path)
        continue

    # "01. Normal" -> "Normal": drop the leading numeric prefix only
    # (same rule as for the site folders below).
    nature_name = nature.split(". ", 1)[-1]

    # Directory listings come back in arbitrary, filesystem-dependent order.
    # Sorting fixes the row order of the resulting frame so the seeded
    # train/test split selects the same images on every machine.
    for site in sorted(os.listdir(nature_path)):
        site_path = os.path.join(nature_path, site)
        if not os.path.isdir(site_path):
            continue

        site_name = site.split(". ", 1)[-1]

        # Escape the directory part so characters such as "[" in a folder
        # name are matched literally; only the trailing "*" is a wildcard.
        # (Restrict by extension here if non-image files are present.)
        pattern = os.path.join(glob.escape(site_path), "*")
        for img_path in sorted(glob.glob(pattern)):
            # "*" also matches sub-directories; keep regular files only.
            if not os.path.isfile(img_path):
                continue
            data.append({
                "image_path": img_path,
                "nature": nature_name,
                "site": site_name,
            })

In [None]:
# Convert to DataFrame
df = pd.DataFrame(data)
df

In [None]:
def get_encoded_df(df):
    """Return a copy of ``df`` with a binary ``label`` column added.

    ``label`` is 0 where ``nature`` equals "normal" (case-insensitive, after
    conversion to ``str``) and 1 for every other value, including missing ones.

    Args:
        df: DataFrame containing a ``nature`` column.

    Returns:
        A new DataFrame; ``df`` itself is left untouched so repeated calls and
        later edits to the input cannot leak between frames. The ``nature``
        column is kept.

    Raises:
        KeyError: if ``df`` has no ``nature`` column.
    """
    is_normal = df["nature"].astype(str).str.lower() == "normal"
    # assign() builds a new frame instead of writing into the caller's object.
    return df.assign(label=(~is_normal).astype("int64"))

In [None]:
# Binary-encode the full dataset: label 0 for nature "normal", 1 otherwise.
total_df_enc = get_encoded_df(df)

In [None]:
# Inspect the encoded full dataset.
total_df_enc

In [None]:
# Sanity check: list every image path that occurs more than once, with its
# number of occurrences (an empty result means all paths are unique).
is_repeated_path = df.duplicated(subset="image_path", keep=False)
duplicate_paths = df.loc[is_repeated_path, "image_path"].value_counts()
print(duplicate_paths)

In [None]:
# Shorten the two long class names; all other values are left as they are.
NATURE_ABBREVIATIONS = {
    "Variation from normal": "Var",
    "Oral Cancer": "OC",
}
df["nature"] = df["nature"].replace(NATURE_ABBREVIATIONS)

In [None]:
# Train/test split stratified on the (nature, site) combination.

# Step 1: one key per (nature, site) pair.
df["stratify_key"] = df["nature"] + " - " + df["site"]

# Step 2: keys with a single sample cannot be stratified; set them aside.
group_counts = df["stratify_key"].value_counts()
valid_keys = group_counts[group_counts >= 2].index
can_stratify = df["stratify_key"].isin(valid_keys)

strat_df = df[can_stratify].reset_index(drop=True)
non_strat_df = df[~can_stratify].reset_index(drop=True)

# Step 3: 80/20 stratified split of the stratifiable rows, seeded with SEED.
train_strat_df, test_df = train_test_split(
    strat_df,
    stratify=strat_df["stratify_key"],
    test_size=0.2,
    random_state=SEED,
)

# Singleton groups go to the training set only.
train_df = pd.concat([train_strat_df, non_strat_df], ignore_index=True)

# The helper key is not needed beyond this point.
for split_frame in (train_df, test_df):
    split_frame.drop(columns=["stratify_key"], inplace=True)

# Size report; "Sum" should equal "Total samples".
print(f"Total samples: {len(df)}")
print(f"Train: {len(train_df)}")
print(f"Test:  {len(test_df)}")
print(f"Sum:   {len(train_df) +len(test_df)}")

In [None]:
# Add the binary label column to both splits (train first, then test).
train_df_enc, test_df_enc = map(get_encoded_df, (train_df, test_df))

In [None]:
# Class balance over both splits (label 1 = abnormal, label 0 = normal).
encoded_splits = (train_df_enc, test_df_enc)
total_abnormal_count = sum((part['label'] == 1).sum() for part in encoded_splits)
total_normal_count = sum((part['label'] == 0).sum() for part in encoded_splits)
total_samples = sum(len(part) for part in encoded_splits)

# Print
print("Total abnormal images:", total_abnormal_count)
print("Total normal images:", total_normal_count)
print("Total samples:", total_samples)

In [None]:
# Give the test split a 0..n-1 index; the previous index is kept in a new
# "index" column. Guarded so the cell is safe to re-run: without the check a
# second run would add a "level_0" column and a third would raise ValueError.
if "index" not in test_df_enc.columns:
    test_df_enc.reset_index(inplace=True)

In [None]:
# Spot-check: image path stored in the test row whose index label is 1.
test_df_enc.loc[1, 'image_path']

In [None]:
# Inspect the encoded training split.
train_df_enc

In [None]:
# Inspect the encoded test split.
test_df_enc

In [None]:
# Distinct sites that have at least one abnormal (label 1) image in the test split.
test_df_enc.loc[test_df_enc['label'] == 1, 'site'].unique()

In [None]:
# Distinct sites that have at least one abnormal (label 1) image in the training split.
train_df_enc.loc[train_df_enc['label'] == 1, 'site'].unique()

In [None]:
# Optional: persist the encoded splits to CSV (uncomment to enable).
# train_df_enc.to_csv('train_df_encoded1.csv', index=False)
# print("DataFrame saved successfully to 'train_df_encoded1.csv'")

# test_df_enc.to_csv('test_df_encoded1.csv', index=False)
# print("DataFrame saved successfully to 'test_df_encoded1.csv'")

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Figure: image counts per intra-oral site. Each site gets two bars
# (normal, abnormal); each bar stacks the Train count under the Test count.

# Global font sizes. rcParams.update changes the kernel-wide defaults, so
# figures drawn after this cell inherit these values too.
plt.rcParams.update({
    "font.size": 16,          # Base font size
    "axes.titlesize": 20,     # Title font size
    "axes.labelsize": 16,     # X/Y label font size
    "xtick.labelsize": 16,    # X-tick font size
    "ytick.labelsize": 16,    # Y-tick font size
})

# Add dataset identifiers
train_df_enc["set"] = "Train"
test_df_enc["set"] = "Test"

# Concatenate
combined_df = pd.concat([train_df_enc, test_df_enc])

# Site order (left-to-right position on the x axis)
site_order = [
    "Dorsal tongue",
    "Ventral tongue",
    "Left buccal mucosa",
    "Right buccal mucosa",
    "Upper lip",
    "Lower lip",
    "Upper arch",
    "Lower arch"
]
label_order = [0, 1]
split_order = ["Train", "Test"]

# Sites that are not listed in site_order have no slot on the x axis.
unexpected_sites = sorted(set(combined_df["site"].dropna()) - set(site_order))
if unexpected_sites:
    print(f"Sites not in site_order (omitted from the plot): {unexpected_sites}")

# Count per site, label, and set
counts = combined_df.groupby(["site", "label", "set"]).size().reset_index(name="count")

# Pivot for stacking, then expand to the full site x label grid with both
# split columns. Combinations with no images become explicit zeros, so every
# label contributes exactly len(site_order) heights in x-axis order; otherwise
# a missing combination would shift or mismatch the bars.
full_grid = pd.MultiIndex.from_product([site_order, label_order], names=["site", "label"])
pivot = (
    counts.pivot_table(index=["site", "label"], columns="set", values="count", fill_value=0)
    .reindex(index=full_grid, columns=split_order, fill_value=0)
    .reset_index()
)
pivot["site"] = pd.Categorical(pivot["site"], categories=site_order, ordered=True)
pivot = pivot.sort_values(["site", "label"])

# Plot (make plot box bigger)
fig, ax = plt.subplots(figsize=(20, 10))  # Bigger figure
bar_width = 0.35
x = np.arange(len(site_order))

# Colors for Train vs Test
colors = {
    (0, "Train"): "#c6dbef",   # Normal Train
    (0, "Test"):  "#6baed6",   # Normal Test
    (1, "Train"): "#fcbba1",   # Abnormal Train
    (1, "Test"):  "#fb6a4a"    # Abnormal Test
}

# Loop over labels -> two bars per site
for i, label in enumerate(label_order):
    subset = pivot[pivot["label"] == label]
    train_heights = subset["Train"].to_numpy()
    test_heights = subset["Test"].to_numpy()
    label_name = "Normal" if label == 0 else "Abnormal"
    # Shift half a bar left (label 0) or right (label 1) of the site tick.
    positions = x + (i - 0.5) * bar_width

    # Train stacked first
    ax.bar(
        positions,
        train_heights,
        width=bar_width,
        color=colors[(label, "Train")],
        edgecolor="black",
        linewidth=1.2,
        label=f"Train - {label_name}"
    )

    # Test stacked on top
    ax.bar(
        positions,
        test_heights,
        bottom=train_heights,
        width=bar_width,
        color=colors[(label, "Test")],
        edgecolor="black",
        linewidth=1.2,
        label=f"Test - {label_name}"
    )

# Formatting
ax.set_xticks(x)
ax.set_xticklabels(site_order, rotation=45, ha="right")
ax.set_ylabel("Count", fontsize=16)
ax.set_xlabel("Intra Oral Site", fontsize=16)

# Adjust ylim: 20% headroom so the legend does not sit on the tallest bars
ymax = ax.get_ylim()[1]
ax.set_ylim(0, ymax * 1.2)

# Deduplicate legend entries and keep the legend inside the axes
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
ax.legend(
    by_label.values(),
    by_label.keys(),
    # title="Dataset + Label",
    loc="upper right",         # inside top-right
    frameon=False,             # no legend box is drawn
    framealpha=1,              # opaque; has no visible effect while frameon=False
    borderpad=1,               # padding inside the legend border
    borderaxespad=0.1,         # distance from axes
    handlelength=4,            # length of the colour patches
    fontsize=16,               # same size as the tick labels
    # title_fontsize=12          # smaller title
)

# Frame
for spine in ax.spines.values():
    spine.set_edgecolor("black")
    spine.set_linewidth(1.2)

plt.tight_layout()
plt.savefig("final_site_label_stacked.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Figure: image counts per intra-oral site with the Train+Test total written
# above each bar. Each site gets two bars (normal, abnormal); each bar stacks
# the Train count under the Test count.

# Add dataset identifiers
train_df_enc["set"] = "Train"
test_df_enc["set"] = "Test"

# Concatenate
combined_df = pd.concat([train_df_enc, test_df_enc])

# Site order (left-to-right position on the x axis)
site_order = [
    "Dorsal tongue",
    "Ventral tongue",
    "Left buccal mucosa",
    "Right buccal mucosa",
    "Upper lip",
    "Lower lip",
    "Upper arch",
    "Lower arch"
]
label_order = [0, 1]
split_order = ["Train", "Test"]

# Sites that are not listed in site_order have no slot on the x axis.
unexpected_sites = sorted(set(combined_df["site"].dropna()) - set(site_order))
if unexpected_sites:
    print(f"Sites not in site_order (omitted from the plot): {unexpected_sites}")

# Count per site, label, and set
counts = combined_df.groupby(["site", "label", "set"]).size().reset_index(name="count")

# Pivot for stacking, then expand to the full site x label grid with both
# split columns. Combinations with no images become explicit zeros, so every
# label contributes exactly len(site_order) heights in x-axis order; otherwise
# a missing combination would shift or mismatch the bars.
full_grid = pd.MultiIndex.from_product([site_order, label_order], names=["site", "label"])
pivot = (
    counts.pivot_table(index=["site", "label"], columns="set", values="count", fill_value=0)
    .reindex(index=full_grid, columns=split_order, fill_value=0)
    .reset_index()
)
pivot["site"] = pd.Categorical(pivot["site"], categories=site_order, ordered=True)
pivot = pivot.sort_values(["site", "label"])

# Plot
fig, ax = plt.subplots(figsize=(14, 8))
bar_width = 0.35
x = np.arange(len(site_order))

# Colors for Train vs Test (subtle blues & reds)
colors = {
    (0, "Train"): "#c6dbef",   # very soft blue - Normal Train
    (0, "Test"):  "#6baed6",   # medium soft blue - Normal Test
    (1, "Train"): "#fcbba1",   # very soft red - Abnormal Train
    (1, "Test"):  "#fb6a4a"    # medium soft red - Abnormal Test
}

# Loop over labels -> two bars per site
for i, label in enumerate(label_order):
    subset = pivot[pivot["label"] == label]
    train_heights = subset["Train"].to_numpy()
    test_heights = subset["Test"].to_numpy()
    label_name = "Normal" if label == 0 else "Abnormal"
    # Shift half a bar left (label 0) or right (label 1) of the site tick.
    positions = x + (i - 0.5) * bar_width

    # Train stacked first
    ax.bar(
        positions,
        train_heights,
        width=bar_width,
        color=colors[(label, "Train")],
        edgecolor="black",
        linewidth=1.2,
        label=f"Train - {label_name}"
    )

    # Test stacked on top
    ax.bar(
        positions,
        test_heights,
        bottom=train_heights,
        width=bar_width,
        color=colors[(label, "Test")],
        edgecolor="black",
        linewidth=1.2,
        label=f"Test - {label_name}"
    )

    # Annotate bar totals (Train+Test for that label)
    for pos, train, test in zip(positions, train_heights, test_heights):
        total = int(train + test)
        ax.text(
            pos,
            total + 1,  # a bit above the stacked bar
            str(total),
            ha="center", va="bottom", fontsize=10, fontweight="bold"
        )

# Formatting
ax.set_xticks(x)
ax.set_xticklabels(site_order, rotation=45, ha="right")
ax.set_ylabel("Count")
ax.set_xlabel("Intra Oral Site")

# Adjust ylim: 20% headroom for the annotations and the legend
ymax = ax.get_ylim()[1]
ax.set_ylim(0, ymax * 1.2)

# Deduplicate legend
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
ax.legend(by_label.values(), by_label.keys(), title="Dataset + Label", loc="upper right", frameon=True)

# Frame
for spine in ax.spines.values():
    spine.set_edgecolor("black")
    spine.set_linewidth(1.2)

plt.tight_layout()
plt.savefig("site_label_stacked_bar_totals.png", dpi=300, bbox_inches="tight")
plt.show()