Skip to content

Commit

Permalink
[MINOR] minor fixes in smote
Browse files Browse the repository at this point in the history
Replace the rbind/cbind appends with indexed writes into preallocated matrices.
The Rand call that draws the interpolation gap now passes a fixed seed value.
  • Loading branch information
Shafaq-Siddiqi committed Oct 5, 2020
1 parent 8d1dfe9 commit 8501fb3
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
37 changes: 25 additions & 12 deletions scripts/builtin/smote.dml
Expand Up @@ -46,41 +46,54 @@ return (Matrix[Double] Y) {
print("the number of samples should be an integral multiple of 100. Setting s = 100")
s = 100
}

if(k < 1) {
print("k should not be less than 1. Setting k value to default k = 1.")
k = 1
}

# matrix to keep the index of KNN for each minority sample
knn_index = matrix(0,k,0)
knn_index = matrix(0,k,nrow(X))
# find nearest neighbour
for(i in 1:nrow(X))
{
knn = nn(X, X[i, ], k)
knn_index = cbind(knn_index, knn)
knn_index[, i] = knn
}

# number of synthetic samples from each minority class sample
iter = (s/100)
iter = 0
iterLim = (s/100)
# matrix to store synthetic samples
synthetic_samples = matrix(0, 0, ncol(X))
while(iter > 0)
synthetic_samples = matrix(0, iterLim*ncol(knn_index), ncol(X))

# shuffle the nn indexes
if(k < iterLim)
rand_index = sample(k, iterLim, TRUE)
else
rand_index = sample(k, iterLim)

while(iter < iterLim)
{
# generate a random number
# TODO avoid duplicate random numbers
rand_index = as.integer(as.scalar(Rand(rows=1, cols=1, min=1, max=k)))
# pick the random NN
knn_sample = knn_index[rand_index,]
knn_sample = knn_index[as.scalar(rand_index[iter+1]),]
# generate sample
for(i in 1:ncol(knn_index))
{
index = as.scalar(knn_sample[1,i])
X_diff = X[index,] - X[i, ]
gap = as.scalar(Rand(rows=1, cols=1, min=0, max=1))
gap = as.scalar(Rand(rows=1, cols=1, min=0, max=1, seed = 41))
X_sys = X[i, ] + (gap*X_diff)
synthetic_samples = rbind(synthetic_samples, X_sys)
synthetic_samples[iter*ncol(knn_index)+i,] = X_sys;
}
iter = iter - 1
iter = iter + 1
}

Y = synthetic_samples

if(verbose)
print(nrow(Y)+ " synthesized samples generated.")

}


Expand Down
Expand Up @@ -47,9 +47,14 @@ public void setUp() {
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"C"}));
}

/**
 * Runs the SMOTE test helper with the arguments (100, 1) under
 * LopProperties.ExecType.CP.
 *
 * NOTE(review): the two numeric arguments presumably correspond to the
 * script's "s" (oversampling amount, a multiple of 100) and "k" (number of
 * nearest neighbours) parameters; confirm against runSmoteTest.
 */
@Test
public void testSmote0CP() {
runSmoteTest(100, 1, LopProperties.ExecType.CP);
}

/**
 * Runs the SMOTE test helper with a first argument of 300 under
 * LopProperties.ExecType.CP.
 *
 * NOTE(review): this listing comes from a diff view without +/- markers, so
 * the two calls below are presumably the old (300, 3) and new (300, 10)
 * versions of a single line rather than two live calls; confirm against the
 * repository.
 */
@Test
public void testSmote1CP() {
runSmoteTest(300, 3, LopProperties.ExecType.CP);
runSmoteTest(300, 10, LopProperties.ExecType.CP);
}

@Test
Expand Down

0 comments on commit 8501fb3

Please sign in to comment.