# LLM is a great rule-based feature engineer in few-shot tabular learning
## Overview
This notebook runs training and inference for a few-shot tabular learning task on benchmark datasets. A Gemini model serves as the LLM in this tutorial.

## Overall process
* Prepare datasets
* Extract prediction rules from the training samples with the help of an LLM
* Parse the rules into program code and convert the data into binary vectors
* Train a linear model to predict the likelihood of each class from the binary vectors
* Run inference with ensembling
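The final step, ensembling, averages the class probabilities of the models trained on each LLM-generated rule set (`test_outputs_all.mean(0)` in the last cell). The sketch below shows that averaging in plain Python, together with a Mann-Whitney ROC AUC for the binary case as a hypothetical stand-in for `utils.evaluate`, whose implementation is not shown here. The input layout and helper names are assumptions.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of class scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def ensemble_probs(member_scores):
    """Average per-member softmax probabilities.

    Hypothetical layout: member_scores[k][i] is the list of class scores that
    ensemble member k assigns to sample i.
    """
    n_members = len(member_scores)
    averaged = []
    for i in range(len(member_scores[0])):
        probs = [softmax(member[i]) for member in member_scores]
        n_classes = len(probs[0])
        averaged.append([sum(p[c] for p in probs) / n_members for c in range(n_classes)])
    return averaged

def binary_auc(pos_scores, labels):
    """ROC AUC as the Mann-Whitney statistic: P(positive ranked above negative), ties = 1/2."""
    pos = [s for s, y in zip(pos_scores, labels) if y == 1]
    neg = [s for s, y in zip(pos_scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Two hypothetical ensemble members scoring three samples over classes [no, yes]
scores = [
    [[0.2, 0.9], [0.8, 0.1], [0.4, 0.6]],
    [[0.1, 0.5], [0.7, 0.3], [0.9, 0.2]],
]
probs = ensemble_probs(scores)
print(binary_auc([p[1] for p in probs], [1, 0, 1]))  # 1.0
```

Averaging probabilities rather than raw scores keeps every member on the same [0, 1] scale, so no single rule set dominates just because its learned weights are larger.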

In [30]:
import os
import copy

import importlib
import utils
importlib.reload(utils)
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from torch.optim import Adam
from sklearn.model_selection import StratifiedKFold

## Prepare datasets
1. Set the dataset and simulation parameters (e.g., the number of queries for the ensemble, the number of training shots, and the random seed).
2. Load the data and split it into train/test sets according to these parameters.
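`utils.get_dataset` is not shown in this notebook. As a hypothetical illustration only, a balanced k-shot split could draw `shot // n_classes` rows per class for training and keep every remaining row for testing. This assumes `shot` counts training rows in total, which is consistent with how `train` later uses `shot // len(label_list)`; the helper name `few_shot_split` is made up.

```python
import random
from collections import defaultdict

def few_shot_split(rows, labels, shot, seed):
    """Hypothetical k-shot split: shot // n_classes training rows per class.

    Every row not drawn for training goes to the test set. Illustrative only,
    not the logic of utils.get_dataset.
    """
    rng = random.Random(seed)  # local RNG so the split depends only on `seed`
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)

    per_class = shot // len(by_class)
    train_idx = []
    for y in sorted(by_class):  # fixed class order for reproducibility
        train_idx.extend(rng.sample(by_class[y], per_class))

    train_set = set(train_idx)
    test_idx = [i for i in range(len(rows)) if i not in train_set]
    X_train = [rows[i] for i in train_idx]
    y_train = [labels[i] for i in train_idx]
    X_test = [rows[i] for i in test_idx]
    y_test = [labels[i] for i in test_idx]
    return X_train, X_test, y_train, y_test

rows = [{"age": a} for a in range(20, 30)]
labels = ["no"] * 7 + ["yes"] * 3
X_tr, X_te, y_tr, y_te = few_shot_split(rows, labels, shot=4, seed=0)
print(y_tr)  # ['no', 'no', 'yes', 'yes']
```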

In [31]:
_NUM_QUERY = 4 # Number of ensembles
_SHOT = 4 # Number of training shots
_SEED = 0 # Seed for fixing randomness
_DATA = 'bank'
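# Read the Gemini API key from the environment instead of hardcoding it in the notebook
# (the variable name GEMINI_API_KEY is an arbitrary choice; set it before running)
_API_KEY = os.environ.get('GEMINI_API_KEY')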

In [32]:
utils.set_seed(_SEED)
df, X_train, X_test, y_train, y_test, target_attr, label_list, is_cat = utils.get_dataset(_DATA, _SHOT, _SEED)
X_all = df.drop(target_attr, axis=1)
X_train.head(2)

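| | age | job | marital | education | default | balance | housing | loan | contact | day | month | duration | campaign | pdays | previous | poutcome |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 27130 | 57 | management | married | tertiary | no | 17118 | yes | no | cellular | 21 | nov | 102 | 1 | -1 | 0 | unknown |
| 41120 | 67 | retired | married | unknown | no | 696 | no | no | telephone | 17 | aug | 119 | 1 | 105 | 2 | failure |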


## Extract rules for prediction from training samples with the help of LLM
To guide the LLM toward extracting rules along a more accurate reasoning path, we structure the problem-solving process to mimic how a person might approach a tabular learning task.

For this purpose, we divide the problem into two sub-tasks:
1. Understand the task description and the features provided by the data, inferring the causal relationships beforehand.
2. Use the inferred information and the few-shot samples to deduce the prediction rules for each class.

This two-step reasoning process keeps the model from picking up spurious correlations in irrelevant columns and helps it focus on the more significant features.

Our prompt consists of three main components:
* Task description
* Reasoning instruction
* Response instruction
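A hypothetical sketch of how the three components could be concatenated into one prompt string. The wording is illustrative only: the actual text comes from `./templates/ask_llm.txt` via `utils.get_prompt_for_asking`, and `build_prompt` and its arguments are made-up names. The task line and feature line reuse the printed template output further below.

```python
def build_prompt(task, feature_lines, examples, classes):
    """Assemble a rule-extraction prompt from three parts (hypothetical wording)."""
    # 1) Task description: the question, the feature descriptions, the few-shot examples
    task_part = "\n".join(
        [f"Task: {task}", "", "Features:"]
        + [f"- {line}" for line in feature_lines]
        + ["", "Examples:"]
        + [f"- {ex}" for ex in examples]
    )
    # 2) Reasoning instruction: the two sub-tasks described above
    reasoning_part = (
        "Step 1. From the task and feature descriptions, infer which features "
        "are causally related to the answer.\n"
        "Step 2. Using the inferred relationships and the examples, deduce "
        "conditions for each answer class."
    )
    # 3) Response instruction: a fixed output format that is easy to parse later
    response_part = "Answer with one section per class:\n" + "\n".join(
        f'Conditions for class "{c}":\n- [condition]' for c in classes
    )
    return "\n\n".join([task_part, reasoning_part, response_part])

prompt = build_prompt(
    "Does this client subscribe to a term deposit? Yes or no?",
    ["age: age (numerical variable)"],
    ["age is 57. Answer: no"],  # hypothetical serialized example
    ["no", "yes"],
)
print(prompt)
```

Keeping the response format fixed is what makes the later parsing step (`utils.parse_rules`) feasible on free-form LLM output.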

In [33]:
ask_file_name = './templates/ask_llm.txt'
meta_data_name = f"./data/{_DATA}-metadata.json"
templates, feature_desc = utils.get_prompt_for_asking(
    _DATA, X_all, X_train, y_train, label_list, target_attr, ask_file_name, 
    meta_data_name, is_cat, num_query=_NUM_QUERY
)
print(templates[0])

You are an expert. Given the task description and the list of features and data examples, you are extracting conditions for each answer class to solve the task.

Task: Does this client subscribe to a term deposit? Yes or no?


Features:
- age: age (numerical variable)
- job: type of job (categorical variable with categories [management, technician, entrepreneur, blue-collar, unknown, retired, admin., services, self-employed, unemployed, housemaid, student])
- marital: marital status (categorical variable with categories [married, single, divorced])
- education:  (categorical variable with categories [tertiary, secondary, unknown, primary])
- default: has credit in default? (categorical variable with categories [no, yes])
- balance: average yearly balance, in euros (numerical variable)
- housing: has housing loan? (categorical variable with categories [yes, no])
- loan:  (categorical variable with categories [no, yes])
- contact: contact communication type (categorical variable with cat

In [14]:
_DIVIDER = "\n\n---DIVIDER---\n\n"  # separates responses from different queries
_VERSION = "\n\n---VERSION---\n\n"  # separates per-rule-set function groups in the function cache

# Cache the LLM responses so repeated runs reuse the same rules
rule_file_name = f'./rules/rule-{_DATA}-{_SHOT}-{_SEED}.out'
if not os.path.isfile(rule_file_name):
    # Query the LLM once per prompt template (one response per ensemble member)
    results = utils.query_gpt(templates, _API_KEY, max_tokens=1500, temperature=0.5)
    os.makedirs(os.path.dirname(rule_file_name), exist_ok=True)
    with open(rule_file_name, 'w') as f:
        f.write(_DIVIDER.join(results))
else:
    with open(rule_file_name, 'r') as f:
        results = f.read().strip().split(_DIVIDER)

print(results[0])

- age: age is not related to whether the client subscribes to a term deposit.
- job: job is not related to whether the client subscribes to a term deposit.
- marital: marital is not related to whether the client subscribes to a term deposit.
- education: education is not related to whether the client subscribes to a term deposit.
- default: default is not related to whether the client subscribes to a term deposit.
- balance: balance is not related to whether the client subscribes to a term deposit.
- housing: housing is not related to whether the client subscribes to a term deposit.
- loan: loan is not related to whether the client subscribes to a term deposit.
- contact: contact is not related to whether the client subscribes to a term deposit.
- day: day is not related to whether the client subscribes to a term deposit.
- month: month is not related to whether the client subscribes to a term deposit.
- duration: duration is not related to whether the client subscribes to a term depos

In [34]:
print(results)

['- age: age is not related to whether the client subscribes to a term deposit.\n- job: job is not related to whether the client subscribes to a term deposit.\n- marital: marital is not related to whether the client subscribes to a term deposit.\n- education: education is not related to whether the client subscribes to a term deposit.\n- default: default is not related to whether the client subscribes to a term deposit.\n- balance: balance is not related to whether the client subscribes to a term deposit.\n- housing: housing is not related to whether the client subscribes to a term deposit.\n- loan: loan is not related to whether the client subscribes to a term deposit.\n- contact: contact is not related to whether the client subscribes to a term deposit.\n- day: day is not related to whether the client subscribes to a term deposit.\n- month: month is not related to whether the client subscribes to a term deposit.\n- duration: duration is not related to whether the client subscribes to

## Parse rules into program code and convert data into binary vectors

We use the rules generated in the previous stage to transform each sample into binary vectors, one per answer class, indicating which of that class's rules the sample satisfies. Because the LLM states these rules in natural language, the text must first be parsed into program code before the data can be transformed automatically.
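A toy illustration of this transformation, with hand-written Python predicates standing in for the LLM-generated functions. The rules are hypothetical; the row reuses three fields of the first training sample shown earlier. Each class gets its own vector and the vectors can differ in length, which is why the linear model later keeps a separate weight vector per class.

```python
def to_binary_vectors(row, rules_per_class):
    """Return {class: [0/1 per rule]} marking which of each class's rules the row satisfies."""
    return {
        label: [int(bool(rule(row))) for rule in rules]
        for label, rules in rules_per_class.items()
    }

# Hypothetical rules, written by hand for illustration
rules_per_class = {
    "yes": [
        lambda r: r["duration"] > 300,
        lambda r: r["poutcome"] == "success",
    ],
    "no": [
        lambda r: r["duration"] <= 300,
        lambda r: r["housing"] == "yes",
        lambda r: r["poutcome"] in ("failure", "unknown"),
    ],
}

row = {"duration": 102, "poutcome": "unknown", "housing": "yes"}
print(to_binary_vectors(row, rules_per_class))
# {'yes': [0, 0], 'no': [1, 1, 1]}
```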

Rather than building a complex parser for this noisy text, we leverage the LLM itself. The prompt contains the function name, descriptions of the function's input and output, and the inferred rules. The LLM returns code, which is executed with Python's exec() and then called by the given function name to convert the data.
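A minimal sketch of the extract-and-exec step, assuming the response wraps its code in the same `<start>`/`<end>` markers the notebook splits on. To stay self-contained, the generated function here works on a list of dicts rather than a pandas DataFrame; the regex-based name lookup and the helper name `load_generated_function` are assumptions, not the `utils` implementation.

```python
import re

def load_generated_function(llm_text):
    """Extract code between <start>/<end>, exec it, and return (name, function)."""
    if "<start>" not in llm_text or "<end>" not in llm_text:
        raise ValueError("response has no <start>...<end> block")
    code = llm_text.split("<start>", 1)[1].split("<end>", 1)[0].strip()
    # Find the first top-level `def name(` header in the generated code
    match = re.search(r"^def\s+(\w+)\s*\(", code, flags=re.MULTILINE)
    if match is None:
        raise ValueError("no function definition found")
    namespace = {}
    # exec() runs whatever the model produced, so only do this in a trusted environment
    exec(code, namespace)
    return match.group(1), namespace[match.group(1)]

response = """Here is the function.
<start>
def extracting_features_yes(rows):
    return [[int(r["duration"] > 300)] for r in rows]
<end>"""

name, fn = load_generated_function(response)
print(name, fn([{"duration": 102}, {"duration": 450}]))
# extracting_features_yes [[0], [1]]
```

Executing into a fresh `namespace` dict keeps the generated names out of the notebook's globals and makes the lookup by function name explicit.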

In [35]:
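# Split each LLM response into per-class rule lists
parsed_rules = utils.parse_rules(results, label_list)

# Cache the generated functions so repeated runs reuse the same code
saved_file_name = f'./rules/function-{_DATA}-{_SHOT}-{_SEED}.out'
if not os.path.isfile(saved_file_name):
    function_file_name = './templates/ask_for_function.txt'
    fct_strs_all = []
    for parsed_rule in tqdm(parsed_rules):
        # Build the code-generation prompts for this rule set (one function per answer class)
        fct_templates = utils.get_prompt_for_generating_function(
            parsed_rule, feature_desc, function_file_name
        )
        # temperature=0 to keep code generation as deterministic as possible
        fct_results = utils.query_gpt(fct_templates, _API_KEY, max_tokens=1500, temperature=0)
        # Keep only the code between the <start> and <end> markers
        fct_strs = [fct_txt.split('<start>')[1].split('<end>')[0].strip() for fct_txt in fct_results]
        fct_strs_all.append(fct_strs)

    os.makedirs(os.path.dirname(saved_file_name), exist_ok=True)
    with open(saved_file_name, 'w') as f:
        # _DIVIDER separates classes within a rule set; _VERSION separates rule sets
        f.write(_VERSION.join([_DIVIDER.join(x) for x in fct_strs_all]))
else:
    with open(saved_file_name, 'r') as f:
        total_str = f.read().strip()
    fct_strs_all = [x.split(_DIVIDER) for x in total_str.split(_VERSION)]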

In [36]:
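import re

# Get function names and strings, skipping rule sets whose responses contain no function
fct_names = []
fct_strs_final = []
for fct_str_pair in fct_strs_all:
    # Match an actual `def name(` header rather than any occurrence of "def",
    # which also appears inside column names such as 'default'
    matches = [re.search(r'\bdef\s+(\w+)\s*\(', fct_str) for fct_str in fct_str_pair]
    if any(m is None for m in matches):
        continue
    fct_names.append([m.group(1) for m in matches])
    fct_strs_final.append(fct_str_pair)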

In [37]:
print(fct_strs_final)

[["def extracting_features_no(df_input):\n    df_output = pd.DataFrame()\n    df_output['age_range'] = (df_input['age'] >= 20) & (df_input['age'] <= 70)\n    df_output['job_in'] = df_input['job'].isin(['management', 'technician', 'entrepreneur', 'blue-collar', 'unknown', 'retired', 'admin.', 'services', 'self-employed', 'unemployed', 'housemaid', 'student'])\n    df_output['marital_in'] = df_input['marital'].isin(['married', 'single', 'divorced'])\n    df_output['education_in'] = df_input['education'].isin(['tertiary', 'secondary', 'unknown', 'primary'])\n    df_output['default_in'] = df_input['default'].isin(['no', 'yes'])\n    df_output['balance_range'] = (df_input['balance'] >= 0) & (df_input['balance'] <= 20000)\n    df_output['housing_in'] = df_input['housing'].isin(['yes', 'no'])\n    df_output['loan_in'] = df_input['loan'].isin(['no', 'yes'])\n    df_output['contact_in'] = df_input['contact'].isin(['unknown', 'cellular', 'telephone'])\n    df_output['day_range'] = (df_input['day

In [38]:
print(fct_strs_final[0][0])

def extracting_features_no(df_input):
    df_output = pd.DataFrame()
    df_output['age_range'] = (df_input['age'] >= 20) & (df_input['age'] <= 70)
    df_output['job_in'] = df_input['job'].isin(['management', 'technician', 'entrepreneur', 'blue-collar', 'unknown', 'retired', 'admin.', 'services', 'self-employed', 'unemployed', 'housemaid', 'student'])
    df_output['marital_in'] = df_input['marital'].isin(['married', 'single', 'divorced'])
    df_output['education_in'] = df_input['education'].isin(['tertiary', 'secondary', 'unknown', 'primary'])
    df_output['default_in'] = df_input['default'].isin(['no', 'yes'])
    df_output['balance_range'] = (df_input['balance'] >= 0) & (df_input['balance'] <= 20000)
    df_output['housing_in'] = df_input['housing'].isin(['yes', 'no'])
    df_output['loan_in'] = df_input['loan'].isin(['no', 'yes'])
    df_output['contact_in'] = df_input['contact'].isin(['unknown', 'cellular', 'telephone'])
    df_output['day_range'] = (df_input['day'] >= 1) & (df

### Convert to binary vectors

In [39]:
executable_list, X_train_all_dict, X_test_all_dict = utils.convert_to_binary_vectors(fct_strs_final, fct_names, label_list, X_train, X_test)

## Train a linear model to predict the likelihood of each class from the binary vectors
Given the rules for each class and a sample, a simple way to measure the sample's class likelihood is to count how many of each class's rules it satisfies (i.e., the sum of the binary vector per class). However, not all rules are equally important, so their importance has to be learned from the training samples.
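The counting baseline can be sketched in plain Python on hypothetical vectors. Raw counts favour a class that simply has more rules; dividing by the number of rules, which is what `simple_model`'s initial weights of 1/(number of rules) amount to before training, removes that bias but still treats every rule as equally important.

```python
def count_scores(binary_vectors):
    """Unweighted score: number of satisfied rules per class."""
    return {label: sum(bits) for label, bits in binary_vectors.items()}

def fraction_scores(binary_vectors):
    """Fraction of satisfied rules per class (uniform weights 1 / n_rules)."""
    return {label: sum(bits) / len(bits) for label, bits in binary_vectors.items()}

def argmax_class(scores):
    return max(scores, key=scores.get)

# Hypothetical vectors: "no" has more rules than "yes"
vectors = {"yes": [1, 0], "no": [1, 1, 0, 0, 0]}
print(argmax_class(count_scores(vectors)))     # no  (1 vs 2 satisfied rules)
print(argmax_class(fraction_scores(vectors)))  # yes (0.5 vs 0.4)
```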
  
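We learn these importance weights with a basic linear model without bias, applied to each class's binary vector. In `simple_model` below, each class has its own weight vector, initialized to 1/(number of rules) and clamped to be non-negative in the forward pass; the resulting per-class scores serve as logits for a cross-entropy loss.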
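A dependency-free sketch of the same idea: one non-negative weight vector per class, a softmax over the per-class scores, and a cross-entropy loss. As a stated simplification, it uses full-batch projected gradient descent (take a step, then clip weights at 0) instead of Adam with a clamp in the forward pass, and it omits the k-fold checkpoint averaging done in `train`. The toy data and the helper names are hypothetical.

```python
import math

def class_scores(weights, x):
    """Score per class: dot product of that class's binary vector with its weights."""
    return [sum(w * b for w, b in zip(wc, xc)) for wc, xc in zip(weights, x)]

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def train_rule_weights(samples, labels, lr=0.5, epochs=200):
    """Learn non-negative rule weights by projected gradient descent.

    samples[i][c] is sample i's binary vector for class c; labels[i] is a class index.
    Weights start at 1 / n_rules per class, as in simple_model.
    """
    n_classes = len(samples[0])
    weights = [[1.0 / len(xc)] * len(xc) for xc in samples[0]]
    for _ in range(epochs):
        grads = [[0.0] * len(wc) for wc in weights]
        for x, y in zip(samples, labels):
            probs = softmax(class_scores(weights, x))
            for c in range(n_classes):
                # d(cross-entropy) / d(score_c) = p_c - 1[c == y]
                delta = probs[c] - (1.0 if c == y else 0.0)
                for j, bit in enumerate(x[c]):
                    grads[c][j] += delta * bit / len(samples)
        for c in range(n_classes):
            # Gradient step, then project back onto w >= 0
            weights[c] = [max(0.0, w - lr * g) for w, g in zip(weights[c], grads[c])]
    return weights

def predict(weights, x):
    scores = class_scores(weights, x)
    return scores.index(max(scores))

# Two classes with two rules each: rule 0 is informative, rule 1 always fires
samples = [
    [[1, 1], [0, 1]],
    [[1, 1], [0, 1]],
    [[0, 1], [1, 1]],
    [[0, 1], [1, 1]],
]
labels = [0, 0, 1, 1]
weights = train_rule_weights(samples, labels)
print([predict(weights, x) for x in samples])  # [0, 0, 1, 1]
```

On this toy data the weight of the always-firing rule receives gradients that cancel across classes, while the informative rule's weight grows, which is the "learn which rules matter" behaviour the linear model is meant to capture.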

In [40]:
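class simple_model(nn.Module):
    """Bias-free linear scorer with one non-negative weight vector per class."""

    def __init__(self, X):
        super().__init__()
        # X is a list of per-class binary matrices of shape (n_samples, n_rules_c);
        # each class's weights start uniform at 1 / n_rules_c
        self.weights = nn.ParameterList(
            [nn.Parameter(torch.ones(x_each.shape[1], 1) / x_each.shape[1]) for x_each in X]
        )

    def forward(self, x):
        x_total_score = []
        for idx, x_each in enumerate(x):
            # Clamp so that satisfying a rule can only raise that class's score
            x_score = x_each @ torch.clamp(self.weights[idx], min=0)
            x_total_score.append(x_score)
        # Concatenate into (n_samples, n_classes) logits
        return torch.cat(x_total_score, dim=-1)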

In [41]:
def train(X_train_now, label_list, shot):
    # Note: reads the global y_train_num (integer labels) defined in the next cell
    criterion = nn.CrossEntropyLoss()
    y_train_tensor = torch.tensor(y_train_num, dtype=torch.long)
    if shot // len(label_list) == 1:
        # One shot per class: no samples to spare for validation, so fit all of them
        # and stop as soon as the training set is classified perfectly
        model = simple_model(X_train_now)
        opt = Adam(model.parameters(), lr=1e-2)
        for _ in range(200):
            opt.zero_grad()
            outputs = model(X_train_now)
            preds = outputs.argmax(dim=1).numpy()
            acc = (y_train_num == preds).mean()
            if acc == 1:
                break
            loss = criterion(outputs, y_train_tensor)
            loss.backward()
            opt.step()
    else:
        # More shots: stratified k-fold, keep each fold's best checkpoint by
        # validation accuracy, then average the fold models' weights
        n_splits = 2 if shot // len(label_list) <= 2 else 4
        kfold = StratifiedKFold(n_splits=n_splits, shuffle=True)
        model_list = []
        for train_ids, valid_ids in kfold.split(X_train_now[0], y_train_num):
            model = simple_model(X_train_now)
            opt = Adam(model.parameters(), lr=1e-2)
            X_train_now_fold = [x[train_ids] for x in X_train_now]
            X_valid_now_fold = [x[valid_ids] for x in X_train_now]
            y_train_fold = torch.tensor(y_train_num[train_ids], dtype=torch.long)
            y_valid_fold = y_train_num[valid_ids]

            max_acc = -1
            for _ in range(200):
                opt.zero_grad()
                loss = criterion(model(X_train_now_fold), y_train_fold)
                loss.backward()
                opt.step()

                # Validation pass without building a graph
                with torch.no_grad():
                    preds = model(X_valid_now_fold).argmax(dim=1).numpy()
                acc = (y_valid_fold == preds).mean()
                if max_acc < acc:
                    max_acc = acc
                    final_model = copy.deepcopy(model)
                    if max_acc >= 1:
                        break
            model_list.append(final_model)

        # Average parameters across the fold models
        sdict = model_list[0].state_dict()
        for key in sdict:
            sdict[key] = torch.stack([m.state_dict()[key] for m in model_list], dim=0).mean(dim=0)

        model = simple_model(X_train_now)
        model.load_state_dict(sdict)
    return model

In [42]:
test_outputs_all = []
multiclass = len(label_list) > 2
# Map string labels to class indices
y_train_num = np.array([label_list.index(k) for k in y_train])
y_test_num = np.array([label_list.index(k) for k in y_test])

# Train and evaluate one model per executable rule set (one ensemble member each)
for i in executable_list:
    X_train_now = list(X_train_all_dict[i].values())
    X_test_now = list(X_test_all_dict[i].values())

    # Train
    trained_model = train(X_train_now, label_list, _SHOT)

    # Evaluate: turn the per-class scores into probabilities
    with torch.no_grad():
        test_outputs = F.softmax(trained_model(X_test_now), dim=1).cpu()
    result_auc = utils.evaluate(test_outputs.numpy(), y_test_num, multiclass=multiclass)
    print("AUC:", result_auc)
    test_outputs_all.append(test_outputs.numpy())

# Ensemble by averaging class probabilities across members
test_outputs_all = np.stack(test_outputs_all, axis=0)
ensembled_probs = test_outputs_all.mean(0)
result_auc = utils.evaluate(ensembled_probs, y_test_num, multiclass=multiclass)
print("No of shots:",_SHOT)
print("No of Ensembles:",_NUM_QUERY)
print("Ensembled AUC for",_DATA ,"dataset is", result_auc)


AUC: 0.5
AUC: 0.5823375113782576
AUC: 0.5374869349785101
AUC: 0.49436958238095297
AUC: 0.4841491549017357
No of shots: 4
No of Ensembles: 4
Ensembled AUC for bank dataset is 0.5327799761604047
