In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error
import numpy as np
import pandas as pd
from tqdm import tqdm
from sentence_transformers import SentenceTransformer

  from tqdm.autonotebook import tqdm, trange


In [2]:
# Define a simple MLP model with LASSO-style regularization
class LassoMLP(nn.Module):
    """Two-layer ReLU MLP whose inputs are first scaled by a learnable vector ``beta``.

    ``beta`` holds one magnitude coefficient per input feature and starts at all
    ones. The L1 penalty on ``beta`` is applied by the training loop, not by this
    module; ``forward`` only multiplies the inputs by it.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Construction order (fc1, then fc2) is what determines seeded initial weights.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # Magnitude parameter β, one entry per input feature.
        self.beta = nn.Parameter(torch.ones(input_dim))

    def forward(self, x):
        scaled = x * self.beta
        hidden = self.fc1(scaled).relu()
        return self.fc2(hidden)

In [3]:
# Step 1: Define function to train or retrain the neural network
def train_network(model, X_train, y_train, lambda_, num_epochs=100, learning_rate=0.001):
    """Train ``model`` in place with full-batch Adam on MSE plus an L1 penalty on ``model.beta``.

    Args:
        model: module exposing a ``beta`` parameter (e.g. ``LassoMLP``); its
            parameters are updated in place.
        X_train: input tensor, passed to ``model`` as a single full batch.
        y_train: target tensor, compared to ``model(X_train)`` with ``nn.MSELoss``.
        lambda_: weight of the LASSO penalty ``lambda_ * sum(|beta|)``.
        num_epochs: number of full-batch optimisation steps.
        learning_rate: Adam learning rate.

    Returns:
        The same ``model`` object, after training.
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()
    # Nothing in the loop changes the module's mode, so set it once.
    model.train()

    # Previously this iterated range(100) and silently ignored num_epochs.
    for _ in tqdm(range(num_epochs)):
        optimizer.zero_grad()
        predictions = model(X_train)

        # Data-fit term plus L1 regularization on β (the LASSO part).
        lasso_penalty = lambda_ * model.beta.abs().sum()
        loss = criterion(predictions, y_train) + lasso_penalty

        loss.backward()
        optimizer.step()

    return model

In [19]:
# Step 3: Use K-fold cross-validation to find the optimal lambda
def select_lambda(X, y, lambdas, k_folds=5, hidden_dim=64, num_epochs=100, learning_rate=0.001):
    """Return the candidate L1 weight with the lowest mean K-fold validation MSE.

    For every candidate, a fresh ``LassoMLP`` is trained on each training split
    with ``train_network`` and scored on the held-out split with
    ``mean_squared_error``. The candidate with the lowest average score wins.

    Args:
        X: input tensor; rows are samples, ``X.shape[1]`` is the model input width.
        y: target tensor with the same number of rows as ``X``; assumed 2-D,
            ``y.shape[1]`` is the model output width.
        lambdas: iterable of candidate L1 weights.
        k_folds: number of (unshuffled) ``KFold`` splits.
        hidden_dim: hidden width of each fold's ``LassoMLP``.
        num_epochs: training steps per fold, forwarded to ``train_network``.
        learning_rate: Adam learning rate, forwarded to ``train_network``.

    Returns:
        The best candidate, or ``None`` if ``lambdas`` is empty (or no average
        score compares below infinity, e.g. all NaN).
    """
    kf = KFold(n_splits=k_folds)
    best_lambda = None
    best_score = float('inf')

    for lambda_ in lambdas:
        print(f"Testing lambda = {lambda_}")
        scores = []

        for train_idx, val_idx in kf.split(X):
            # Output width must follow the targets: after feature pruning X is
            # narrower than y, so X.shape[1] would give the wrong output size.
            model = LassoMLP(input_dim=X.shape[1], hidden_dim=hidden_dim, output_dim=y.shape[1])
            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]

            print(f"Training with {len(X_train)} samples, validating with {len(X_val)} samples")

            # Train the model (in place) and calculate the validation loss
            trained_model = train_network(
                model, X_train, y_train, lambda_,
                num_epochs=num_epochs, learning_rate=learning_rate,
            )
            trained_model.eval()
            with torch.no_grad():
                predictions = trained_model(X_val)
            # .cpu() so this also works when the tensors live on a GPU.
            val_loss = mean_squared_error(
                y_val.detach().cpu().numpy(), predictions.detach().cpu().numpy()
            )
            scores.append(val_loss)

        # Take the average score over all folds
        avg_score = np.mean(scores)
        if avg_score < best_score:
            best_score = avg_score
            best_lambda = lambda_

    return best_lambda

In [None]:
# Steps 4-5: Iteratively prune input features whose β is (near) zero
def lasso_mlp_algorithm(X, y, output_dim=None, hidden_dim=64, num_epochs=100, learning_rate=0.01,
                        lambdas=(0.01, 0.1, 1.0), zero_tol=0.0):
    """Repeatedly fit a ``LassoMLP`` and drop input columns whose β is pruned.

    Each iteration:
    - builds a fresh model;
    - picks λ by cross-validation (``select_lambda``);
    - trains on all of ``X`` with that λ;
    - removes every column whose ``|β| <= zero_tol``.

    It stops when an iteration prunes nothing, or when pruning would remove
    every column.

    Args:
        X: input tensor (rows = samples). It is not modified; pruned copies are
            built locally.
        y: 2-D target tensor with the same number of rows as ``X``.
        output_dim: model output width. Defaults to ``y.shape[1]``; a value that
            disagrees with ``y.shape[1]`` raises ``ValueError``. (It was
            previously ignored and replaced by the input width.)
        hidden_dim: hidden width of the final model in each iteration.
        num_epochs: training steps for the final model in each iteration.
        learning_rate: Adam learning rate for the final model in each iteration.
        lambdas: candidate L1 weights passed to ``select_lambda``.
        zero_tol: a feature is pruned when ``|β| <= zero_tol``. The default 0.0
            prunes only exact zeros, which gradient-based L1 training rarely
            produces. Raise it to get actual pruning.

    Returns:
        ``(trained_model, beta_estimates)`` from the last iteration. The β values
        refer to the columns that survived pruning, not the original columns.

    Raises:
        ValueError: if ``lambdas`` is empty or ``output_dim`` mismatches ``y``.
    """
    # Materialise once so a generator is not exhausted after the first iteration.
    lambdas = tuple(lambdas)
    if not lambdas:
        raise ValueError("lambdas must contain at least one candidate")

    input_dim = X.shape[1]
    if output_dim is None:
        output_dim = y.shape[1]
    elif output_dim != y.shape[1]:
        raise ValueError(f"output_dim={output_dim} does not match y.shape[1]={y.shape[1]}")

    iteration = 0
    while True:
        iteration += 1
        print(f"Iteration {iteration}:")
        # Created before cross-validation so RNG consumption order matches the original.
        model = LassoMLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)

        # Select the optimal lambda
        best_lambda = select_lambda(X, y, lambdas)
        print(f"Optimal lambda: {best_lambda}")

        # Train the model with the optimal lambda
        trained_model = train_network(model, X, y, best_lambda, num_epochs=num_epochs, learning_rate=learning_rate)

        # Get the estimated β values
        beta_estimates = trained_model.beta.detach().cpu().numpy()
        print("β values:", beta_estimates)

        # Step 5: terminate when nothing is pruned (every β survives the threshold)
        keep = np.abs(beta_estimates) > zero_tol
        n_keep = int(keep.sum())
        if n_keep == input_dim:
            break
        if n_keep == 0:
            # Pruning everything would leave an empty design matrix; stop here.
            print("All β values are below the threshold; stopping without pruning.")
            break

        # Remove pruned features and retrain on the reduced dataset
        print(f"Pruning {input_dim - n_keep} features; {n_keep} remain.")
        X = X[:, torch.as_tensor(keep, device=X.device)]
        input_dim = n_keep

    return trained_model, beta_estimates

In [6]:
N_SAMPLES = 1000  # number of leading rows used for the experiment

df = pd.read_csv('https://raw.githubusercontent.com/tiagoft/NLP/main/wiki_movie_plots_drama_comedy.csv')

# Explicit positional slice of the first N_SAMPLES rows:
# plot text becomes the model input, the genre label the target.
x, y = df["Plot"].iloc[:N_SAMPLES], df["Genre"].iloc[:N_SAMPLES]

In [7]:
print("Loading Sentence Transformer model...")
# Named `encoder` (not `model`) because `model` is later bound to the trained LASSO-MLP.
encoder = SentenceTransformer("all-MiniLM-L6-v2")

# Both inputs (plots) and targets (genres) are embedded with the same encoder.
x_embeddings = encoder.encode(x.tolist(), convert_to_tensor=True)
y_embeddings = encoder.encode(y.tolist(), convert_to_tensor=True)

Loading Sentence Transformer model...


In [9]:
print(x_embeddings.shape, y_embeddings.shape)
# The training code pairs rows of X with rows of y and reads y.shape[1].
assert x_embeddings.shape[0] == y_embeddings.shape[0], "inputs and targets must have the same number of rows"
assert y_embeddings.dim() == 2, "targets must be a 2-D (samples, dims) tensor"

torch.Size([1000, 384]) torch.Size([1000, 384])


In [20]:
SEED = 0
torch.manual_seed(SEED)  # weight initialisation is random; seed torch's RNG for reruns

print("Training the LASSO-MLP algorithm...")
# The output width is the target dimensionality, not the input dimensionality.
model, final_beta = lasso_mlp_algorithm(x_embeddings, y_embeddings, output_dim=y_embeddings.shape[1])

# Summarise instead of dumping every coefficient into the notebook.
print(f"Features kept: {final_beta.size} of {x_embeddings.shape[1]}")
print("Final β values:", np.array2string(final_beta, precision=3, threshold=20))

Training the LASSO-MLP algorithm...
Iteration 1:
Testing lambda = 0.01
Training with 800 samples, validating with 200 samples


100%|██████████| 100/100 [00:00<00:00, 117.18it/s]


Training with 800 samples, validating with 200 samples


100%|██████████| 100/100 [00:00<00:00, 119.65it/s]


Training with 800 samples, validating with 200 samples


100%|██████████| 100/100 [00:00<00:00, 112.43it/s]


Training with 800 samples, validating with 200 samples


100%|██████████| 100/100 [00:00<00:00, 111.89it/s]


Training with 800 samples, validating with 200 samples


100%|██████████| 100/100 [00:00<00:00, 104.11it/s]


Testing lambda = 0.1
Training with 800 samples, validating with 200 samples


100%|██████████| 100/100 [00:00<00:00, 105.73it/s]


Training with 800 samples, validating with 200 samples


100%|██████████| 100/100 [00:00<00:00, 114.68it/s]


Training with 800 samples, validating with 200 samples


100%|██████████| 100/100 [00:00<00:00, 111.75it/s]


Training with 800 samples, validating with 200 samples


100%|██████████| 100/100 [00:00<00:00, 110.52it/s]


Training with 800 samples, validating with 200 samples


100%|██████████| 100/100 [00:00<00:00, 104.34it/s]


Testing lambda = 1.0
Training with 800 samples, validating with 200 samples


100%|██████████| 100/100 [00:00<00:00, 116.61it/s]


Training with 800 samples, validating with 200 samples


100%|██████████| 100/100 [00:00<00:00, 108.59it/s]


Training with 800 samples, validating with 200 samples


100%|██████████| 100/100 [00:00<00:00, 106.98it/s]


Training with 800 samples, validating with 200 samples


100%|██████████| 100/100 [00:00<00:00, 115.64it/s]


Training with 800 samples, validating with 200 samples


100%|██████████| 100/100 [00:00<00:00, 118.50it/s]


Optimal lambda: 0.1


100%|██████████| 100/100 [00:00<00:00, 111.87it/s]

β values: [ 4.90993261e-06  1.44354999e-06  5.64567745e-06  7.51949847e-06
  1.30375847e-05  1.51228160e-05  6.31716102e-06  6.08898699e-06
  2.51270831e-06  2.07126141e-06  4.02797014e-06  3.18419188e-06
  1.19116157e-06  5.02448529e-06  4.33996320e-06  7.47852027e-07
  1.61482021e-05  4.98443842e-06  9.14465636e-06  4.76185232e-06
  4.14904207e-06  9.41380858e-06  1.62236392e-06  1.93342566e-06
  4.39304858e-06  7.39376992e-06  4.24124300e-06  5.54695725e-06
  4.31202352e-06 -9.22009349e-07  3.86126339e-06  1.91852450e-06
  8.00937414e-07  9.06642526e-06  4.31854278e-06  9.00030136e-06
  5.42309135e-06  4.82238829e-06  8.98633152e-06  2.17650086e-06
  3.95905226e-06  4.10526991e-06  7.15255737e-07  5.41377813e-06
  5.09340316e-06  2.24355608e-06  1.35786831e-05  1.06627122e-05
  3.20188701e-06  4.51318920e-06  6.73532486e-06  1.40350312e-06
  8.76002014e-06  1.45155936e-05  2.61701643e-06  1.85705721e-06
  4.16301191e-06  2.37580389e-06  9.61218029e-06  4.85777855e-06
  1.30385160e-0


