In [183]:
import jax
import jax.numpy as jnp
from jax import jit

In [184]:
def compute_fuzzy_similarity(X, weights, sigma=1.0):
    """Pairwise Gaussian similarity between the rows of ``X`` under per-feature weights.

    For every pair of rows (i, j) this evaluates
    ``exp(-sum_k(weights[k]**2 * (X[i, k] - X[j, k])**2) / (2 * sigma**2))``.

    Args:
        X: 2-D array of shape (n, m); one sample per row.
        weights: per-feature weights that broadcast against the last axis
            of ``X``. The caller in this notebook passes ``sigmoid(w)``.
        sigma: bandwidth of the Gaussian kernel.

    Returns:
        Array of shape (n, n); the diagonal is 1 because a row has zero
        distance to itself.
    """
    # Broadcasting (n, 1, m) against (1, n, m) gives all pairwise
    # feature-wise differences with shape (n, n, m).
    pairwise_delta = jnp.expand_dims(X, 1) - jnp.expand_dims(X, 0)

    # Collapse the feature axis into one weighted squared distance per pair.
    scaled_sq = (weights**2) * (pairwise_delta**2)
    pair_sq_dist = scaled_sq.sum(axis=-1)

    bandwidth = 2 * sigma**2
    return jnp.exp(-pair_sq_dist / bandwidth)

In [185]:
def soft_lower_approximation(R,labels,alpha=10.0):
    """Soft lower-approximation membership of each sample to its own class.

    For sample i the value is ``1 - softmax_j R[i, j]`` taken over the samples j
    whose label differs from ``labels[i]``, where the maximum is smoothed with
    ``logsumexp(alpha * .) / alpha`` so that it stays differentiable.

    Args:
        R: similarity matrix, expected to be (N, N) as produced by
            ``compute_fuzzy_similarity``.
        labels: label vector of shape (N,), e.g. ``[1, 0, 1]``.
        alpha: sharpness of the smooth maximum; larger values approach the
            hard maximum.

    Returns:
        Array of shape (N,). A sample with no differently-labelled partner
        gets exactly 1.0 (nothing contradicts its class).

    Note:
        ``logsumexp(alpha * s) / alpha`` is an upper bound of ``max(s)`` (by at
        most ``log(K) / alpha`` for K contributing entries), so the result can
        be lower than the hard value and may even be negative.
    """
    # diff_mask[i, j] is True when samples i and j carry different labels.
    diff_mask=labels[:,jnp.newaxis]!=labels[jnp.newaxis,:]

    # Same-label pairs are replaced by a large negative sentinel so that they
    # contribute exp(-huge) == 0 to the smooth maximum.
    diff_similarities=jnp.where(diff_mask,R,-1e9)
    worst_diff_similarity=jax.nn.logsumexp(alpha*diff_similarities,axis=1)/alpha

    # A row without any differently-labelled sample contains only sentinels;
    # its smooth maximum would be about -1e9 and mu about +1e9, which swamps
    # the mean taken by the loss. Treat "no opposing sample" as similarity 0.
    has_diff=jnp.any(diff_mask,axis=1)
    worst_diff_similarity=jnp.where(has_diff,worst_diff_similarity,0.0)

    mu=1.0-worst_diff_similarity
    return mu


In [186]:
def calculate_rs_loss(mu):
    """Return ``1 - mean(mu)``: the loss shrinks as the average membership grows.

    Args:
        mu: array of membership values (any shape; the mean is taken over
            all elements).

    Returns:
        Scalar loss.
    """
    return 1.0 - jnp.mean(mu)


In [187]:
def total_loss_fn(params, X, y, lambda1=0.1, lambda2=0.07):
    """Objective for learning the feature weights.

    Computes ``lambda1 * rs_loss + lambda2 * sum(|sigmoid(w)|)``. Only these
    two terms are summed; no classifier (cross-entropy) term is part of this
    function.

    Args:
        params: dict holding the raw, unconstrained weights under key ``'w'``.
        X: sample matrix forwarded to ``compute_fuzzy_similarity``.
        y: label vector forwarded to ``soft_lower_approximation``.
        lambda1: multiplier of the rough-set loss term.
        lambda2: multiplier of the L1 penalty on the squashed weights.

    Returns:
        Scalar loss.
    """
    # Squash the raw weights into (0, 1) so they act as soft feature gates.
    gates = jax.nn.sigmoid(params['w'])

    # Fuzzy relation -> soft lower approximation -> rough-set loss.
    similarity = compute_fuzzy_similarity(X, gates)
    membership = soft_lower_approximation(similarity, y)
    rs_term = lambda1 * calculate_rs_loss(membership)

    # L1 penalty on the gates (feature-selection pressure).
    sparsity_term = lambda2 * jnp.sum(jnp.abs(gates))

    return rs_term + sparsity_term

In [188]:
# Synthetic data: N samples with M standard-normal features, drawn from a
# fixed seed so the run is reproducible.
key = jax.random.PRNGKey(39)
N, M = 300, 4
X = jax.random.normal(key, shape=(N, M))

# The label is 1 exactly when feature 0 is positive; the other columns do
# not enter the label.
y = jnp.asarray(X[:, 0] > 0, dtype=jnp.int32)

# logit(0.6): the raw value whose sigmoid is 0.6, so every squashed weight
# starts at 0.6.
initial_raw_value = jnp.log(0.6 / (1 - 0.6))

# One raw weight per feature, all starting from the same value.
params = {
    'w': jnp.full(shape=(M,), fill_value=initial_raw_value)
}


In [189]:
@jit
def update(params,x,y,learning_rate=0.1):
    """Perform one plain gradient-descent step on ``total_loss_fn``.

    Args:
        params: dict of parameter arrays (the loss reads key ``'w'``).
        x: sample matrix the gradient is evaluated on.
        y: label vector matching ``x``.
        learning_rate: step size; defaults to 0.1.

    Returns:
        A new dict with the same keys as ``params`` holding the updated values.
    """
    # Differentiate w.r.t. the first argument (params). The data passed in as
    # ``x`` is used rather than a notebook-level global, so the jitted function
    # does not bake a fixed array in at trace time and works on any batch.
    grads=jax.grad(total_loss_fn)(params,x,y)
    #Simple Gradient Descent
    return {k : v-learning_rate*grads[k] for k,v in params.items()}

In [190]:
# Training: 500 gradient-descent steps, reporting every 20th one.
# `params` is overwritten on each step, so re-running this cell continues
# from the current values instead of restarting from the initial ones.
print("----------------------TRAINING START--------------------")

for i in range(500):
    params = update(params, X, y)
    if i % 20 != 0:
        continue
    # The reported loss and weights are those *after* this step's update.
    loss = total_loss_fn(params, X, y)
    current_weights = jax.nn.sigmoid(params['w'])
    print(f"Iter {i:3} | Loss {loss:.4f} | Weights: {current_weights}")

# Final report: a feature counts as selected when its squashed weight is above 0.5.
final_weights = jax.nn.sigmoid(params['w'])
for i, w in enumerate(final_weights):
    if w > 0.5:
        status = "SELECTED"
    else:
        status = "DISCARDED"
    print(f"Feature {i}: Weight {w:.4f} -> {status}")


----------------------TRAINING START--------------------
Iter   0 | Loss 0.2625 | Weights: [0.59985614 0.5996786  0.5996767  0.5996796 ]
Iter  20 | Loss 0.2613 | Weights: [0.5969839  0.5932225  0.59318274 0.59324396]
Iter  40 | Loss 0.2602 | Weights: [0.59412163 0.58671635 0.5866358  0.58675855]
Iter  60 | Loss 0.2590 | Weights: [0.59126997 0.5801646  0.58004016 0.5802278 ]
Iter  80 | Loss 0.2578 | Weights: [0.5884294  0.57357144 0.57340026 0.5736561 ]
Iter 100 | Loss 0.2567 | Weights: [0.5856004  0.5669416  0.56672066 0.567048  ]
Iter 120 | Loss 0.2555 | Weights: [0.5827833  0.5602798  0.56000614 0.5604084 ]
Iter 140 | Loss 0.2543 | Weights: [0.5799787  0.5535909  0.55326146 0.553742  ]
Iter 160 | Loss 0.2531 | Weights: [0.57718706 0.54687977 0.5464916  0.5470539 ]
Iter 180 | Loss 0.2520 | Weights: [0.5744086  0.54015136 0.5397016  0.540349  ]
Iter 200 | Loss 0.2508 | Weights: [0.5716438  0.5334109  0.53289664 0.5336325 ]
Iter 220 | Loss 0.2496 | Weights: [0.5688931  0.52666336 0.5260