In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import os
import random
import torch
import torch.nn as nn
from custommodels import LoadDataset, ResNet50, DenseNet121, MobileNetV2
from fairtraining import ModelTrainer
import pickle

In [2]:
# Reproducibility: seed every RNG source in use (Python, NumPy, PyTorch)
# and force cuDNN onto deterministic kernels instead of autotuned ones.
seed = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# Load the pre-split dataframes for the selected database and build the dataloaders.
db = "HAM"
root = f"../../dataset/CAI 2025/{db}/"  # already ends with "/": append sub-paths directly
batch_size = 8

df_train, df_valid, df_test = (
    pd.read_csv(f"{root}dataframe/df_distance_{split}.csv")
    for split in ("train", "valid", "test")
)

# Fail fast with a readable message if the target column LoadDataset expects is absent.
for split_name, split_df in (("train", df_train), ("valid", df_valid), ("test", df_test)):
    assert "label" in split_df.columns, f"df_distance_{split_name}.csv has no 'label' column"

loader = LoadDataset(label="label", batch_size=batch_size)
train_loader, valid_loader, test_loader = loader.create_dataloaders(df_train, df_valid, df_test)

Width: 200 Height: 150


# DenseNet

In [4]:
# Run configuration: which architecture and which distance metric this notebook uses.
models = ["RES", "DENSE", "MOBILE"]
distances = ["WD", "KUIPER", "KS", "AD", "CVM", "ED"]

model_name = models[1]        # "DENSE"
distance_name = distances[0]  # "WD"

# Checkpoints for this architecture are written here by the trainer.
model_save_directory = f"{root}fairmodels/{model_name}/"
if os.path.isdir(model_save_directory):
    print(f"{model_save_directory} already exists.")
else:
    # exist_ok avoids a race between the check and the creation; it still raises
    # if a regular file occupies the path, which would otherwise fail later at save time.
    os.makedirs(model_save_directory, exist_ok=True)
    print(f"{model_save_directory} created.")

# Output locations for per-sample predictions and the fitted KDE for this model/distance pair.
df_valid_filepath = f"{root}dataframe/df_valid_{model_name}_fair.csv"
df_test_filepath = f"{root}dataframe/df_test_{model_name}_fair.csv"
kde_filepath = f"{root}kde/{model_name}_{distance_name}.pkl"

../../dataset/CAI 2025/HAM/fairmodels/DENSE/ aready exist.


In [5]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, message="X does not have valid feature names")

In [13]:
# Train DenseNet121 with the KDE-based ModelTrainer; returns the path of the best validation checkpoint.
# NOTE(review): the recorded run of this cell (executed as In [13], after the evaluation cells) raised
# "RuntimeError: The size of tensor a (8) must match the size of tensor b (48)", so the
# `best_val_file` consumed below may be stale from an earlier run — re-run top to bottom.

# Training hyper-parameters.
num_class = 3      # number of output classes
num_epochs = 50
lr = 1e-5
start = 1          # passed positionally to ModelTrainer.train — NOTE(review): meaning not visible here, confirm

# Restore the fitted KDE for the selected model/distance pair.
# NOTE(review): pickle.load can execute arbitrary code — only load files you produced yourself.
with open(kde_filepath, "rb") as kde_file:
    kde = pickle.load(kde_file)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = DenseNet121(num_class)
model = model.to(device)

trainer = ModelTrainer(kde, model_save_directory)
best_val_file = trainer.train(
    model,
    train_loader,
    valid_loader,
    start,
    num_epochs=num_epochs,
    lr=lr,
)

Validation Accuracy: 0.5508 | Loss: 0.9719 | F1: 0.5416
Validation Accuracy: 0.6237 | Loss: 0.8534 | F1: 0.6204
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff:  8
diff: 

RuntimeError: The size of tensor a (8) must match the size of tensor b (48) at non-singleton dimension 0

In [6]:
# Rebuild the architecture, restore the best validation checkpoint, and save validation predictions.
# NOTE(review): `best_val_file` comes from the training cell; if that cell failed, this reloads
# whatever checkpoint path an earlier run left in the kernel.
model = DenseNet121(num_class).to(device)
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint on a CPU-only kernel).
model.load_state_dict(torch.load(best_val_file, map_location=device))

valid_predictions, valid_outputs = trainer.evaluate(model, valid_loader)
df_valid["pred"] = valid_predictions
df_valid["proba"] = valid_outputs
df_valid.to_csv(df_valid_filepath, index=False)

              precision    recall  f1-score   support

           0       0.81      0.85      0.83       219
           1       0.74      0.72      0.73       220
           2       0.79      0.77      0.78       220

    accuracy                           0.78       659
   macro avg       0.78      0.78      0.78       659
weighted avg       0.78      0.78      0.78       659



In [7]:
# Score the held-out test split with the model restored in the previous cell and persist the predictions.
test_predictions, test_outputs = trainer.evaluate(model, test_loader)
for column, values in (("pred", test_predictions), ("proba", test_outputs)):
    df_test[column] = values
df_test.to_csv(path_or_buf=df_test_filepath, index=False)

              precision    recall  f1-score   support

           0       0.75      0.82      0.79       220
           1       0.75      0.71      0.73       219
           2       0.74      0.71      0.72       220

    accuracy                           0.75       659
   macro avg       0.75      0.75      0.75       659
weighted avg       0.75      0.75      0.75       659

