Bug with weighted DBSCAN #120

Open
nickkeepfer opened this issue Feb 23, 2023 · 4 comments

@nickkeepfer

Problem

I'm using DBSCAN to find clusters in a 3D dataset that varies with time. Occasionally (less than 5% of the time), DBSCAN completely fails to detect a very obvious cluster. Applying circshift to the array sometimes makes it work, but not always.

There seems to be no clear reason why it fails; it just sometimes does.

Please see the following example (file is included for replication purposes):

using JLD2
using ScikitLearn
using PyCall

# Wrapper for DBSCAN 
DBSCAN = pyimport("sklearn.cluster").DBSCAN

# Load data
f = jldopen("DBSCAN_BUG.jld2")
x, y, z = f["x"], f["y"], f["z"]

# Format data such that each voxel is given as an (x,y,z) coordinate
X = repeat(x',length(y),1,length(z)) .+ 2*maximum(x)
Y = repeat(y,1,length(x),length(z)) .+ 2*maximum(y)
Z = permutedims(repeat(z,1,length(x),length(y)),[3 2 1]) .+ 2*maximum(z)
dat = zeros(length(X[:]),3)
dat[:,1] = X[:]
dat[:,2] = Y[:]
dat[:,3] = Z[:]

# Perform DBSCAN where points are weighted by density array
decomp = DBSCAN(eps=abs(x[1]-x[2]),min_samples=1).fit_predict(dat,sample_weight=f["dens"][:])
dbscan = replace!(reshape(decomp,size(f["dens"])),-1=>0)
dbscan[dbscan.>0] .= 1.0

# Plot DBSCAN results alongside the density 
using CairoMakie
fig = Figure()
ax, hm1 = heatmap(fig[1,1], x, y, f["dens"][:,:,72])
ax, hm2 = heatmap(fig[2,1], x, y, dbscan[:,:,72])
fig

DBSCAN_BUG.jld2.zip
Screenshot 2023-02-23 at 14 28 54

Expected result

There should be a yellow blob in the second heatmap, corresponding to the identified (very obvious) cluster.
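
For readers unfamiliar with weighted DBSCAN, the following is a minimal NumPy-only sketch of the core-point test, assuming scikit-learn's documented behaviour that a sample is a core sample when the total sample_weight inside its eps-neighbourhood (itself included) reaches min_samples. The function name weighted_core_mask and the toy 1D grid and densities are hypothetical and are not taken from the issue's data. It is not scikit-learn's implementation.

```python
import numpy as np

def weighted_core_mask(points, weights, eps, min_samples):
    """Mark points whose eps-neighbourhood (self included) has total weight >= min_samples."""
    # Pairwise Euclidean distances between all points (fine for small inputs only)
    diff = points[:, None, :] - points[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    # Boolean adjacency: True where two points lie within eps of each other
    neighbours = dist <= eps
    # Sum the weights of every neighbour of each point
    neighbourhood_weight = neighbours.astype(float) @ weights
    return neighbourhood_weight >= min_samples

# Hypothetical 1D grid with spacing 1.0 and made-up densities
points = np.arange(6, dtype=float).reshape(-1, 1)
weights = np.array([0.0, 0.1, 0.6, 0.7, 0.1, 0.0])
core = weighted_core_mask(points, weights, eps=1.0, min_samples=1)
print(core)
```

In this toy grid the nearest neighbours sit exactly at distance eps, which mirrors the choice eps=abs(x[1]-x[2]) in the report. Whether such boundary neighbours are counted can then depend on floating-point rounding of the coordinates, so it may be worth checking.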

@cstjean
Owner

cstjean commented Feb 23, 2023

Isn't that a problem with the scikit-learn library? ScikitLearn.jl is just an interface to the Python scikit-learn library. If so, I would encourage you to translate your example to Python and post it there.

@nickkeepfer
Author

Hmm, yes, probably. I'll see if I'm able to translate it.

@nickkeepfer
Author

I'm actually not sure it is an issue with scikit-learn, as it works just fine when used natively in Python (see below):

import numpy as np
import h5py
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt
import matplotlib.colors as colors

# Load data
f = h5py.File("DBSCAN_BUG.jld2", "r")
x, y, z = f["x"][:], f["y"][:], f["z"][:]

# Format data such that each voxel is given as an (x,y,z) coordinate
X = np.repeat(x, len(y) * len(z)).reshape(len(x), len(y), len(z), order='F') + 2 * np.max(x)
Y = np.repeat(y, len(x) * len(z)).reshape(len(y), len(x), len(z), order='C') + 2 * np.max(y)
Z = np.repeat(z, len(x) * len(y)).reshape(len(z), len(x), len(y), order='F').transpose((1, 2, 0)) + 2 * np.max(z)
dat = np.vstack((X.ravel('F'), Y.ravel('F'), Z.ravel('F'))).T

# Perform DBSCAN where points are weighted by density array
decomp = DBSCAN(eps=np.abs(x[0] - x[1]), min_samples=1).fit_predict(dat, sample_weight=f["dens"][:].ravel())
dbscan = np.reshape(np.where(decomp != -1, 1, 0), f["dens"].shape)

# Plot DBSCAN results alongside the density
fig, axs = plt.subplots(2, 1)
hm1 = axs[0].imshow(f["dens"][72, :, :], norm=colors.LogNorm())
hm2 = axs[1].imshow(dbscan[72, :, :], cmap='binary')
plt.show()

Figure_1

@cstjean
Owner

cstjean commented Feb 23, 2023

I'm not super familiar with DBSCAN. I scanned your code and nothing looked obviously wrong. Beware that in

dbscan = replace!(reshape(decomp,size(f["dens"])),-1=>0)

reshape shares memory with its argument, so this line is also mutating decomp. But that shouldn't change the outcome.
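
NumPy behaves analogously on the Python side: reshaping a contiguous array returns a view, so an in-place edit of the reshaped array also changes the original. A minimal sketch, with a made-up labels array standing in for the fit_predict output:

```python
import numpy as np

# Hypothetical stand-in for the label vector returned by fit_predict
labels = np.array([-1, 0, 0, -1, 1, 1])

# Reshaping a contiguous array gives a view on the same memory
grid = labels.reshape(2, 3)

# In-place replacement of the noise label on the reshaped array...
grid[grid == -1] = 0

# ...is visible through the original name as well
print(np.shares_memory(labels, grid))
print(labels)
```

Calling labels.copy() before reshaping would keep the original labels intact if they are needed later.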

Beyond that, I can't offer advice other than: try to figure out what's different between Python and Julia. Ultimately, it's the same library doing the work, so presumably the inputs (or the plotting) are different.
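
One input difference that could be checked is whether each coordinate row is paired with the weight of the same voxel, since Julia flattens arrays column-major while NumPy defaults to row-major. The sketch below is only an illustration under that assumption: the helper names coords_and_weights and is_aligned are hypothetical, the array is a made-up stand-in for f["dens"], and integer voxel indices are used in place of the physical x, y, z coordinates.

```python
import numpy as np

def coords_and_weights(dens, coord_order="F", weight_order="F"):
    """Build (coords, weights) the way they would be handed to fit_predict."""
    idx = np.indices(dens.shape)  # one index grid per axis
    coords = np.stack([a.ravel(order=coord_order) for a in idx], axis=1)
    weights = dens.ravel(order=weight_order)
    return coords, weights

def is_aligned(dens, coords, weights):
    """True if row k of coords addresses the voxel whose density is weights[k]."""
    return bool(np.array_equal(dens[tuple(coords.T)], weights))

# Made-up density array with distinct values so any mismatch is detectable
dens = np.arange(24, dtype=float).reshape(2, 3, 4)

c_ok, w_ok = coords_and_weights(dens, "F", "F")    # same flattening order
c_bad, w_bad = coords_and_weights(dens, "F", "C")  # mixed flattening orders

print(is_aligned(dens, c_ok, w_ok))
print(is_aligned(dens, c_bad, w_bad))
```

Running the same kind of check on the dat and sample_weight arrays in both the Julia and Python versions would show whether the two scripts pass identical inputs to DBSCAN.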
