In [1]:
# Re-import edited project modules (e.g. utils, neural_controllers) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Path() is the kernel's current working directory; Jupyter normally starts the
# kernel in the notebook's own folder.
notebook_path = Path().absolute()
# The parent directory presumably holds the project modules imported below
# (utils, neural_controllers) — make it importable.
repo_root = str(notebook_path.parent)
# Guard so re-running this cell does not keep appending duplicate sys.path entries.
if repo_root not in sys.path:
    sys.path.append(repo_root)

In [3]:
import torch
from tqdm import tqdm
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from neural_controllers import NeuralController

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from utils import programming_language_dataset, pca_programming_language_dataset

import random

# Single seed for every RNG the notebook may touch, so runs are reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
# torch.manual_seed seeds the RNGs on all devices (CPU and every CUDA device),
# so a separate torch.cuda.manual_seed call is not needed.
torch.manual_seed(SEED)

In [5]:
# Release cached, unused GPU memory (e.g. from a model loaded by an earlier run
# of this cell) before loading a new model.
torch.cuda.empty_cache()

model_type = 'llama'  # supported values: 'llama', 'gemma'

if model_type == 'llama':
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    # NOTE(review): no torch_dtype is passed here, so weights load in transformers'
    # default dtype (historically float32), unlike the bfloat16 gemma branch below.
    language_model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda"
    )

    # Request the slow tokenizer for Llama architectures, the fast one otherwise.
    use_fast_tokenizer = "LlamaForCausalLM" not in language_model.config.architectures
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, use_fast=use_fast_tokenizer, padding_side="left", legacy=False
    )
    # Fall back to token id 0 for padding when the tokenizer defines no pad token.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = 0
    model_name = 'llama_3_8b_it'

elif model_type == 'gemma':
    model_id = "google/gemma-2-9b-it"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    language_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model_name = 'gemma_2_9b_it'

else:
    # Fail here rather than with a NameError on `tokenizer` in a later cell.
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'llama' or 'gemma'")

Loading checkpoint shards: 100%|██████████| 4/4 [00:11<00:00,  2.81s/it]


In [6]:
# Pair of programming languages to contrast (alternative pair: ['python', 'c++']).
concept_types = ["python", "javascript"]
data_dir = "../data/programming"  # NOTE(review): not passed to the dataset builder below

# Alternative builder: pca_programming_language_dataset(concept_types, tokenizer)
dataset = programming_language_dataset(
    concept_types,
    tokenizer,
)

Generating train split: 100%|██████████| 2360/2360 [00:00<00:00, 59167.57 examples/s]

train 500 test 500
train 500 test 500





In [7]:
# Fit one NeuralController per language in concept_types, keyed by language name.
controllers = {}
for concept_type in tqdm(concept_types):
    # The first other language in the list. Not used in this cell.
    # NOTE(review): it persists as a notebook global and may be read by later cells.
    other_type = [lang for lang in concept_types if lang != concept_type][0]

    splits = dataset[concept_type]
    train_data, test_data = splits['train'], splits['test']

    language_controller = NeuralController(
        language_model,
        tokenizer,
        rfm_iters=8,
        batch_size=2,
    )

    # Directions are learned from the training split only.
    language_controller.compute_directions(train_data['inputs'], train_data['labels'])

    controllers[concept_type] = language_controller
    

  0%|          | 0/2 [00:00<?, ?it/s]

n_components: 5
Hidden layers KA: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]
Hidden layers: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]

Controller hyperparameters:
control_method       : rfm
rfm_iters            : 8
forward_batch_size   : 2
M_batch_size         : 2048
n_components         : 5
calibrate            : False

Tuning metric: auc
Getting activations from forward passes


100%|██████████| 200/200 [03:57<00:00,  1.19s/it]


Getting activations from forward passes


100%|██████████| 50/50 [00:43<00:00,  1.14it/s]


train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.09006881713867188 seconds
Optimal M batch size: 400
Time taken for round 1: 0.008342742919921875 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010863065719604492 seconds
Optimal M batch size: 400
Time taken for round 3: 0.009906530380249023 seconds
Optimal M batch size: 400
Time taken for round 4: 0.01029515266418457 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010339975357055664 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010383367538452148 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010297060012817383 seconds
Optimal M batch size: 400
Debug: y_true shape: (100



train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.010365486145019531 seconds
Optimal M batch size: 400
Time taken for round 1: 0.00900888442993164 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009883642196655273 seconds
Optimal M batch size: 400
Time taken for round 3: 0.01014256477355957 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010121822357177734 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010124683380126953 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010128259658813477 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010114669799804688 seconds
Optimal M batch size: 400
Debug: y_true shape: (100



Optimal M batch size: 400
Time taken for round 1: 0.009755849838256836 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010098457336425781 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010097742080688477 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010084390640258789 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010072469711303711 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010066509246826172 seconds
Optimal M batch size: 400
Time taken for round 7: 0.01007533073425293 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007559537887573242 seconds
Optimal M batch size: 400
Time taken for round 1: 0.00967860221862793 seconds
Optimal M batch size: 400
Time taken for round 2: 0.01004052162170



Optimal M batch size: 400
Time taken for round 1: 0.009354591369628906 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010050296783447266 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010052680969238281 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010036945343017578 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010058403015136719 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010049104690551758 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010059118270874023 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007053375244140625 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009688615798950195 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009999036788



Optimal M batch size: 400
Time taken for round 1: 0.009674549102783203 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010040283203125 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010020256042480469 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010040998458862305 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010015010833740234 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010048389434814453 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010021448135375977 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007168769836425781 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009674072265625 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009999513626098633



Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.006979227066040039 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009649038314819336 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010012149810791016 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010020017623901367 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010008573532104492 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010030984878540039 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010022640228271484 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010034561157226562 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.509335994720459 seconds
Time taken to compute eigenv



Time taken for round 0: 0.007199764251708984 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009557723999023438 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009993791580200195 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010022640228271484 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010013580322265625 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010040283203125 seconds
Optimal M batch size: 400
Time taken for round 6: 0.01014089584350586 seconds
Optimal M batch size: 400
Time taken for round 7: 0.00989222526550293 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.578925609588623 seconds
Time taken to compute eigenvectors: 0.01833939552307129 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100



Time taken for round 0: 0.007283926010131836 seconds
Optimal M batch size: 400
Time taken for round 1: 0.0093841552734375 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009992361068725586 seconds
Optimal M batch size: 400
Time taken for round 3: 0.00999593734741211 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010006427764892578 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010032415390014648 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010012626647949219 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010026216506958008 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5801279544830322 seconds
Time taken to compute eigenvectors: 0.023176193237304688 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size(



Time taken for round 0: 0.007298946380615234 seconds
Optimal M batch size: 400
Time taken for round 1: 0.00937342643737793 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010042190551757812 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010025262832641602 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009996414184570312 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010012388229370117 seconds
Optimal M batch size: 400
Time taken for round 6: 0.00998234748840332 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010010242462158203 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5795037746429443 seconds
Time taken to compute eigenvectors: 0.020725250244140625 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size



Optimal M batch size: 400
Time taken for round 0: 0.0076711177825927734 seconds
Optimal M batch size: 400
Time taken for round 1: 0.00964665412902832 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010015249252319336 seconds
Optimal M batch size: 400
Time taken for round 3: 0.00999903678894043 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010007858276367188 seconds
Optimal M batch size: 400
Time taken for round 5: 0.01001119613647461 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010019540786743164 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010016202926635742 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5825924873352051 seconds
Time taken to compute eigenvectors: 0.018155813217163086 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1



Optimal M batch size: 400
Time taken for round 0: 0.0077207088470458984 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009645462036132812 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009998559951782227 seconds
Optimal M batch size: 400
Time taken for round 3: 0.009990215301513672 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009994983673095703 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010009288787841797 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009997367858886719 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010016918182373047 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5823357105255127 seconds
Time taken to compute eigenvectors: 0.017153024673461914 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400



Optimal M batch size: 400
Time taken for round 0: 0.007665872573852539 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009810209274291992 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009872674942016602 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010023355484008789 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010013103485107422 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010020017623901367 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009992361068725586 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010010480880737305 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5823817253112793 seconds
Time taken to compute eigenvectors: 0.018785715103149414 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400,



Optimal M batch size: 400
Time taken for round 0: 0.007666110992431641 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009658575057983398 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010014533996582031 seconds
Optimal M batch size: 400
Time taken for round 3: 0.009972810745239258 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010009527206420898 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009974002838134766 seconds
Optimal M batch size: 400
Time taken for round 6: 0.00998830795288086 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009975433349609375 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5821304321289062 seconds
Time taken to compute eigenvectors: 0.02312612533569336 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1



Time taken for round 0: 0.007189035415649414 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009607076644897461 seconds
Optimal M batch size: 400
Time taken for round 2: 0.00999903678894043 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010013580322265625 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010001420974731445 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009977102279663086 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009984970092773438 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009991884231567383 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5820779800415039 seconds
Time taken to compute eigenvectors: 0.016503095626831055 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Siz



Optimal M batch size: 400
Time taken for round 1: 0.009747982025146484 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010011672973632812 seconds
Optimal M batch size: 400
Time taken for round 3: 0.009980201721191406 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009990692138671875 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009993314743041992 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009989023208618164 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009989261627197266 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5771026611328125 seconds
Time taken to compute eigenvectors: 0.018259525299072266 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fix



Optimal M batch size: 400
Time taken for round 2: 0.010094404220581055 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010026216506958008 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010007858276367188 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009998083114624023 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009986162185668945 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010038614273071289 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5786750316619873 seconds
Time taken to compute eigenvectors: 0.016427993774414062 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM wi



Optimal M batch size: 400
Time taken for round 3: 0.010101556777954102 seconds
Optimal M batch size: 400
Time taken for round 4: 0.01014089584350586 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009869098663330078 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009984731674194336 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009991884231567383 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5782573223114014 seconds
Time taken to compute eigenvectors: 0.018256425857543945 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and



Optimal M batch size: 400
Time taken for round 4: 0.010092973709106445 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009995222091674805 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009996414184570312 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009980440139770508 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.578528881072998 seconds
Time taken to compute eigenvectors: 0.01977992057800293 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.01047158241271972



Optimal M batch size: 400
Time taken for round 4: 0.01009368896484375 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009988546371459961 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009991884231567383 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010007619857788086 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5816502571105957 seconds
Time taken to compute eigenvectors: 0.01803731918334961 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.01047301292419433



Optimal M batch size: 400
Time taken for round 5: 0.010073184967041016 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009994029998779297 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009990215301513672 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5816493034362793 seconds
Time taken to compute eigenvectors: 0.023437023162841797 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.010490655899047852 seconds
Optimal M batch size: 400
Time taken for round 1: 0.008869409561157



Time taken to compute eigenvectors: 2.4337823390960693 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.010491132736206055 seconds
Optimal M batch size: 400
Time taken for round 1: 0.008828163146972656 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009806156158447266 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010038375854492188 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010012149810791016 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010006904602050781 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010010004043579102 seconds
Optimal M batch size: 400
Time taken for round 7: 0.0100226402282



Optimal M batch size: 400
Time taken for round 1: 0.009735584259033203 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010040521621704102 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010004043579101562 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010022878646850586 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010143756866455078 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009871482849121094 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010006189346313477 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007500171661376953 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009643316268920898 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009976387023



Optimal M batch size: 400
Time taken for round 2: 0.010453939437866211 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010018348693847656 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010160446166992188 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009870052337646484 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010022163391113281 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010019063949584961 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007551670074462891 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009646415710449219 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009982109069824219 seconds
Optimal M batch size: 400
Time taken for round 3: 0.009997129440



Optimal M batch size: 400
Time taken for round 4: 0.010101556777954102 seconds
Optimal M batch size: 400
Time taken for round 5: 0.01003408432006836 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010006427764892578 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010012149810791016 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007543325424194336 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009645462036132812 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009993314743041992 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010019063949584961 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009974241256713867 seconds
Optimal M batch size: 400
Time taken for round 5: 0.0100095272064



Optimal M batch size: 400
Time taken for round 5: 0.010464191436767578 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010015249252319336 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010137081146240234 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007062196731567383 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009616613388061523 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009984970092773438 seconds
Optimal M batch size: 400
Time taken for round 3: 0.009987592697143555 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009972095489501953 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009987354278564453 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009981632232



Optimal M batch size: 400
Time taken for round 6: 0.009987831115722656 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010022640228271484 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007176876068115234 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009646892547607422 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009984493255615234 seconds
Optimal M batch size: 400
Time taken for round 3: 0.009989500045776367 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009979963302612305 seconds
Optimal M batch size: 400
Time taken for round 5: 0.00997471809387207 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009999275207519531 seconds
Optimal M batch size: 400
Time taken for round 7: 0.0099871158599



Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.0070497989654541016 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009796380996704102 seconds
Optimal M batch size: 400
Time taken for round 2: 0.00984334945678711 seconds
Optimal M batch size: 400
Time taken for round 3: 0.009983301162719727 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009971857070922852 seconds
Optimal M batch size: 400
Time taken for round 5: 0.00999903678894043 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009986400604248047 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009987115859985352 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.0



Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.006941795349121094 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009525537490844727 seconds
Optimal M batch size: 400
Time taken for round 2: 0.01000070571899414 seconds
Optimal M batch size: 400
Time taken for round 3: 0.009991645812988281 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009980440139770508 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009982109069824219 seconds
Optimal M batch size: 400
Time taken for round 6: 0.0099945068359375 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009979963302612305 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: False
Time taken to train rfm probe: 0.5783348083496094 seconds
Time taken to compute eigenvec



Optimal M batch size: 400
Time taken for round 1: 0.009925365447998047 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009999752044677734 seconds
Optimal M batch size: 400
Time taken for round 3: 0.009983062744140625 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009984970092773438 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009977579116821289 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010142326354980469 seconds
Optimal M batch size: 400
Time taken for round 7: 0.00981283187866211 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.5776093006134033 seconds
Time taken to compute eigenvectors: 0.009963512420654297 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed



Time taken for round 2: 0.010395050048828125 seconds
Optimal M batch size: 400
Time taken for round 3: 0.00956583023071289 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009998798370361328 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009990692138671875 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009996414184570312 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009993553161621094 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.5758702754974365 seconds
Time taken to compute eigenvectors: 0.016320228576660156 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_g



Optimal M batch size: 400
Time taken for round 3: 0.010456562042236328 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009991168975830078 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009993791580200195 seconds
Optimal M batch size: 400
Time taken for round 6: 0.00998544692993164 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009983539581298828 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 0.9975961538461539, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.579326868057251 seconds
Time taken to compute eigenvectors: 0.020431995391845703 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400,

100%|██████████| 31/31 [00:21<00:00,  1.47it/s]

Optimal M batch size: 400
Time taken for round 4: 0.010091781616210938 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009998321533203125 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009984493255615234 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009994983673095703 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 0.9935897435897436, reg: 0.001, bw: 10, center_grads: False
Time taken to train rfm probe: 0.5839502811431885 seconds
Time taken to compute eigenvectors: 0.005272626876831055 seconds



100%|██████████| 31/31 [00:00<00:00, 6516.48it/s]
 50%|█████     | 1/2 [05:03<05:03, 303.38s/it]

n_components: 5
Hidden layers KA: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]
Hidden layers: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]

Controller hyperparameters:
control_method       : rfm
rfm_iters            : 8
forward_batch_size   : 2
M_batch_size         : 2048
n_components         : 5
calibrate            : False

Tuning metric: auc
Getting activations from forward passes


100%|██████████| 200/200 [04:02<00:00,  1.21s/it]


Getting activations from forward passes


100%|██████████| 50/50 [01:09<00:00,  1.39s/it]


train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.012999296188354492 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009402990341186523 seconds
Optimal M batch size: 400
Time taken for round 2: 0.00986337661743164 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010302305221557617 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010283231735229492 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010245323181152344 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010300159454345703 seconds
Optimal M batch size: 400
Time taken for round 7: 0.01027822494506836 seconds
Optimal M batch size: 400
Debug: y_true shape: (100



Optimal M batch size: 400
Time taken for round 4: 0.010106563568115234 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010018348693847656 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010036468505859375 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010143280029296875 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 100, center_grads: True
Time taken to train rfm probe: 0.4587106704711914 seconds
Time taken to compute eigenvectors: 0.019597530364990234 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.01045989990234



Optimal M batch size: 400
Time taken for round 4: 0.010118722915649414 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010146141052246094 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009856224060058594 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010011672973632812 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5831723213195801 seconds
Time taken to compute eigenvectors: 0.018219947814941406 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.010435819625854



Optimal M batch size: 400
Time taken for round 5: 0.010070085525512695 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010013341903686523 seconds
Optimal M batch size: 400
Time taken for round 7: 0.01001119613647461 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5828258991241455 seconds
Time taken to compute eigenvectors: 0.018033742904663086 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.010439395904541016 seconds
Early stopping at iteration 1
Optimal M batch size: 400
Debug: y_tru



Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007485151290893555 seconds
Optimal M batch size: 400
Time taken for round 1: 0.00962376594543457 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010004281997680664 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010009050369262695 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009994983673095703 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010012149810791016 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010001897811889648 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010013818740844727 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.4419398307800293 seconds
Time taken to compute eigenv



Optimal M batch size: 400
Time taken for round 7: 0.01012420654296875 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5139729976654053 seconds
Time taken to compute eigenvectors: 0.01655125617980957 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.010409832000732422 seconds
Optimal M batch size: 400
Time taken for round 1: 0.008884906768798828 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009823322296142578 seconds
Optimal M batch size: 400
Time taken for round 3: 0.01009297370910644



Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.582751989364624 seconds
Time taken to compute eigenvectors: 0.016478300094604492 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.009484291076660156 seconds
Optimal M batch size: 400
Time taken for round 1: 0.008887290954589844 seconds
Optimal M batch size: 400
Time taken for round 2: 0.00978851318359375 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010062932968139648 seconds
Optimal M batch size: 400
Time taken for round 4: 0.01004242897033691



Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5783498287200928 seconds
Time taken to compute eigenvectors: 0.01976156234741211 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.010534048080444336 seconds
Optimal M batch size: 400
Time taken for round 1: 0.008925676345825195 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009841442108154297 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010056495666503906 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010046243667602539 seconds
Optimal M batch size: 400
Time taken for round 5: 0.01004934310913086 seconds
Optimal M batch size: 400
Ti



Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.5794103145599365 seconds
Time taken to compute eigenvectors: 0.01889801025390625 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.010447263717651367 seconds
Optimal M batch size: 400
Time taken for round 1: 0.00886082649230957 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009811162948608398 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010024785995483398 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010054349899291992 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010074138641357422 seconds
Optimal M batch size: 400
Ti



Time taken to compute eigenvectors: 0.012096166610717773 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.010443925857543945 seconds
Optimal M batch size: 400
Time taken for round 1: 0.00886678695678711 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009797334671020508 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010035037994384766 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010020256042480469 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010027170181274414 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010063409805297852 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010079622268



Optimal M batch size: 400
Time taken for round 1: 0.00961923599243164 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010036230087280273 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010040283203125 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010013103485107422 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010023832321166992 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010016679763793945 seconds
Optimal M batch size: 400
Time taken for round 7: 0.01006770133972168 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007142066955566406 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009522438049316406 seconds
Optimal M batch size: 400
Time taken for round 2: 0.00997424125671386



Optimal M batch size: 400
Time taken for round 1: 0.00973963737487793 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010047674179077148 seconds
Optimal M batch size: 400
Time taken for round 3: 0.01002359390258789 seconds
Optimal M batch size: 400
Time taken for round 4: 0.01002812385559082 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010019302368164062 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010039567947387695 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010024547576904297 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.0075762271881103516 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009606599807739258 seconds
Optimal M batch size: 400
Time taken for round 2: 0.01002311706542



Optimal M batch size: 400
Time taken for round 2: 0.010098457336425781 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010040521621704102 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010036706924438477 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010023117065429688 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010003805160522461 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010010480880737305 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007578134536743164 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009588003158569336 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009993791580200195 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010000228881



Optimal M batch size: 400
Time taken for round 3: 0.010098695755004883 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010057687759399414 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010158777236938477 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009884834289550781 seconds
Optimal M batch size: 400
Time taken for round 7: 0.01002049446105957 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007580757141113281 seconds
Optimal M batch size: 400
Time taken for round 1: 0.00960993766784668 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010016441345214844 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010000467300415039 seconds
Optimal M batch size: 400
Time taken for round 4: 0.00999808311462



Optimal M batch size: 400
Time taken for round 4: 0.010086774826049805 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010071039199829102 seconds
Optimal M batch size: 400
Time taken for round 6: 0.009993791580200195 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010036468505859375 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007576942443847656 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009638786315917969 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009992599487304688 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010012149810791016 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009976387023925781 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009997367858



Optimal M batch size: 400
Time taken for round 5: 0.010293245315551758 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010046243667602539 seconds
Optimal M batch size: 400
Time taken for round 7: 0.0100250244140625 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.0070629119873046875 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009641885757446289 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009998321533203125 seconds
Optimal M batch size: 400
Time taken for round 3: 0.009991645812988281 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009988546371459961 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009995698928833008 seconds
Optimal M batch size: 400
Time taken for round 6: 0.0100162029266



Optimal M batch size: 400
Time taken for round 6: 0.010429620742797852 seconds
Optimal M batch size: 400
Time taken for round 7: 0.01002645492553711 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007085084915161133 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009627342224121094 seconds
Optimal M batch size: 400
Time taken for round 2: 0.01000213623046875 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010013341903686523 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010011911392211914 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010006904602050781 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010024785995483398 seconds
Optimal M batch size: 400
Time taken for round 7: 0.01000070571899



Optimal M batch size: 400
Time taken for round 7: 0.010445594787597656 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007085323333740234 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009769678115844727 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009850740432739258 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010026693344116211 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010008811950683594 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009993553161621094 seconds
Optimal M batch size: 400
Time taken for round 6: 0.00998997688293457 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009998083114624023 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba



Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.00707697868347168 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009626388549804688 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010148286819458008 seconds
Optimal M batch size: 400
Time taken for round 3: 0.009839057922363281 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009996891021728516 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010007381439208984 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010007381439208984 seconds
Optimal M batch size: 400
Time taken for round 7: 0.01000666618347168 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.00



Optimal M batch size: 400
Time taken for round 0: 0.007150411605834961 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009652376174926758 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009979248046875 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010008573532104492 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009989500045776367 seconds
Optimal M batch size: 400
Time taken for round 5: 0.01000833511352539 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010121822357177734 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009870052337646484 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.5764052867889404 seconds
Time taken to compute eigenvectors: 0.010628700256347656 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) 



Optimal M batch size: 400
Time taken for round 2: 0.010481119155883789 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010018348693847656 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009995698928833008 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009997367858886719 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010001182556152344 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009985923767089844 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.5788795948028564 seconds
Time taken to compute eigenvectors: 0.012501239776611328 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM wit



Optimal M batch size: 400
Time taken for round 3: 0.010420799255371094 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010012149810791016 seconds
Optimal M batch size: 400
Time taken for round 5: 0.009980916976928711 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010004520416259766 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009983301162719727 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.5783843994140625 seconds
Time taken to compute eigenvectors: 0.00998830795288086 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and 



Optimal M batch size: 400
Time taken for round 5: 0.010241270065307617 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010016918182373047 seconds
Optimal M batch size: 400
Time taken for round 7: 0.009995222091674805 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.578197717666626 seconds
Time taken to compute eigenvectors: 0.006007671356201172 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.010455846786499023 seconds
Optimal M batch size: 400
Time taken for round 1: 0.0085906982421875 



Optimal M batch size: 400
Time taken for round 7: 0.01019144058227539 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.5805280208587646 seconds
Time taken to compute eigenvectors: 0.0060160160064697266 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.010463237762451172 seconds
Optimal M batch size: 400
Time taken for round 1: 0.00858163833618164 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010037422180175781 seconds
Optimal M batch size: 400
Time taken for round 3: 0.01003623008728027



Time taken to train rfm probe: 0.5801446437835693 seconds
Time taken to compute eigenvectors: 0.02590346336364746 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.010450601577758789 seconds
Optimal M batch size: 400
Time taken for round 1: 0.008580684661865234 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010041952133178711 seconds
Optimal M batch size: 400
Time taken for round 3: 0.01003408432006836 seconds
Optimal M batch size: 400
Time taken for round 4: 0.01005101203918457 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010028362274169922 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010023117065429688 seconds
Optimal 



Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.5800139904022217 seconds
Time taken to compute eigenvectors: 0.021412372589111328 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.01047515869140625 seconds
Optimal M batch size: 400
Time taken for round 1: 0.008570194244384766 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010041475296020508 seconds
Optimal M batch size: 400
Time taken for round 3: 0.01018834114074707 seconds
Optimal M batch size: 400
Time taken for round 4: 0.00986480712890625 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010042190551757812 seconds
Optimal M batch size: 400
Time



Time taken to compute eigenvectors: 2.5057199001312256 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.010261297225952148 seconds
Optimal M batch size: 400
Time taken for round 1: 0.008591890335083008 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010059118270874023 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010064125061035156 seconds
Optimal M batch size: 400
Time taken for round 4: 0.01006007194519043 seconds
Optimal M batch size: 400
Time taken for round 5: 0.01004338264465332 seconds
Optimal M batch size: 400
Time taken for round 6: 0.01005244255065918 seconds
Optimal M batch size: 400
Time taken for round 7: 0.0100367069244384



Time taken to compute eigenvectors: 2.549999237060547 seconds
train X shape: torch.Size([400, 4096]) train y shape: torch.Size([400, 1]) val X shape: torch.Size([100, 4096]) val y shape: torch.Size([100, 1])
Fixed shapes - train_y: torch.Size([400]), val_y: torch.Size([100])
Fitting RFM with reg=0.001, bw=1, center_grads=True
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.010411977767944336 seconds
Optimal M batch size: 400
Time taken for round 1: 0.00859379768371582 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010035514831542969 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010040044784545898 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010034561157226562 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010032176971435547 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010043144226074219 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010022878646850



Optimal M batch size: 400
Time taken for round 0: 0.007768869400024414 seconds
Optimal M batch size: 400
Time taken for round 1: 0.00965571403503418 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010034799575805664 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010146141052246094 seconds
Optimal M batch size: 400
Time taken for round 4: 0.009877920150756836 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010008573532104492 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010016918182373047 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010019779205322266 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007213592529296875 seconds
Optimal M batch size: 400
Time taken for round 1: 0.0096397399902



Optimal M batch size: 400
Time taken for round 2: 0.010113000869750977 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010039329528808594 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010027647018432617 seconds
Optimal M batch size: 400
Time taken for round 5: 0.0100250244140625 seconds
Optimal M batch size: 400
Time taken for round 6: 0.010037660598754883 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010019540786743164 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007600307464599609 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009608745574951172 seconds
Optimal M batch size: 400
Time taken for round 2: 0.009995222091674805 seconds
Optimal M batch size: 400
Time taken for round 3: 0.00998520851135



Optimal M batch size: 400
Time taken for round 2: 0.010351181030273438 seconds
Optimal M batch size: 400
Time taken for round 3: 0.010027885437011719 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010058164596557617 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010044574737548828 seconds
Optimal M batch size: 400
Time taken for round 6: 0.01003575325012207 seconds
Optimal M batch size: 400
Time taken for round 7: 0.01002645492553711 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.007104158401489258 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009615421295166016 seconds
Optimal M batch size: 400
Time taken for round 2: 0.01001596450805664 seconds
Optimal M batch size: 400
Time taken for round 3: 0.009992122650146

100%|██████████| 31/31 [00:23<00:00,  1.33it/s]

Optimal M batch size: 400
Time taken for round 3: 0.010125875473022461 seconds
Optimal M batch size: 400
Time taken for round 4: 0.010051488876342773 seconds
Optimal M batch size: 400
Time taken for round 5: 0.010024547576904297 seconds
Optimal M batch size: 400
Time taken for round 6: 0.01007699966430664 seconds
Optimal M batch size: 400
Time taken for round 7: 0.010030984878540039 seconds
Optimal M batch size: 400
Debug: y_true shape: (100,), pred_proba shape: (100,)
Debug: y_true unique values: [0. 1.]
Fitting RFM with reg=0.001, bw=100, center_grads=False
Fitting RFM with ntrain: 400, d: 4096, and nval: 100
Optimal M batch size: 400
Time taken for round 0: 0.008156299591064453 seconds
Optimal M batch size: 400
Time taken for round 1: 0.009635448455810547 seconds
Optimal M batch size: 400
Time taken for round 2: 0.010007619857788086 seconds
Optimal M batch size: 400
Time taken for round 3: 0.01000356674194336 seconds
Optimal M batch size: 400
Time taken for round 4: 0.00999164581298


100%|██████████| 31/31 [00:00<00:00, 6558.89it/s]
100%|██████████| 2/2 [10:39<00:00, 319.67s/it]


In [8]:
# Persist each trained controller under "<concept>_<other concept>" in
# ../directions/ so the reload cell further down can restore it without retraining.
for concept_type in concept_types:
    controller = controllers.get(concept_type)
    if controller is None:
        # No controller was stored for this concept in `controllers`.
        print(f'{concept_type} not found')
        continue
    other_type = [k for k in concept_types if k != concept_type][0]
    # save() is deliberately outside any try/except: a failed write (bad path,
    # disk error) should surface rather than be mislabeled as "not found".
    controller.save(concept=f'{concept_type}_{other_type}', model_name=model_name, path='../directions/')
    

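# Control

Reload the saved python/javascript directions and compare the model's generations with and without steering.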

In [9]:
from datasets import load_dataset
# Load the greengerong/leetcode dataset from the Hugging Face Hub and keep only
# the python and javascript columns of its train split.
huggingface_dataset = load_dataset("greengerong/leetcode")
python_dataset, js_dataset = (
    huggingface_dataset["train"][column] for column in ('python', 'javascript')
)


def extract_code(c):
    """Return the text inside the first Markdown code fence of `c`.

    The text between the first and second triple backticks is returned as-is,
    including any language tag on the opening line (e.g. "python" followed by
    a newline). The prompts built later re-wrap this text in backticks and
    rely on that tag. If there is an opening fence but no closing one,
    everything after the opening fence is returned.

    Args:
        c: Text that is expected to contain a ``` fenced code block.

    Returns:
        The fenced contents, language tag included.

    Raises:
        ValueError: if `c` contains no triple-backtick fence at all.
    """
    items = c.split("```")
    if len(items) < 2:
        # Without this check the failure is an opaque IndexError.
        raise ValueError("no ``` code fence found in input")
    return items[1]

In [10]:
# Rebuild one RFM NeuralController per concept and load the directions saved
# earlier under "<concept>_<other concept>" from ../directions/.
concept_types = ['python', 'javascript']
controllers = {}

for concept_type in concept_types:

    controller = NeuralController(
        language_model,
        tokenizer,
        control_method='rfm',
        n_components=1
    )

    other_type = [k for k in concept_types if k != concept_type][0]

    try:
        controller.load(
                        concept=f'{concept_type}_{other_type}',
                        model_name=model_name,
                        path='../directions/')
    except Exception as err:
        # Catch Exception (not a bare except) so Ctrl-C / SystemExit still stop
        # the cell, and report why loading failed instead of hiding it.
        print(f'{concept_type} not found ({type(err).__name__}: {err})')
        continue
    # Only successfully loaded controllers are registered.
    controllers[concept_type] = controller
    

n_components: 1
Hidden layers KA: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]
Hidden layers: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]

Controller hyperparameters:
control_method       : rfm
rfm_iters            : 10
forward_batch_size   : 2
M_batch_size         : 2048
n_components         : 1
calibrate            : False

Detector found
n_components: 1
Hidden layers KA: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]
Hidden layers: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]

Controller hyperparameters:
control_method       : rfm
rfm_iters            : 10
forward_batch_size   : 2
M_

  return torch.load(io.BytesIO(b))


In [None]:
# Compare greedy generations with and without steering on a few LeetCode
# solutions. The model is asked to restate each program; the controlled run adds
# the learned `concept_type` direction on `layer_id` with strength `coeff`, to
# see whether the output shifts toward that language.

# Choose concept type: "python" or "javascript"
concept_type = "python"

dataset = python_dataset if concept_type == "python" else js_dataset
controller = controllers[concept_type]

# NOTE(review): range(-1, -31, -1) yields layers -1..-30, while the loaded
# controller reports hidden layers -1..-31; confirm whether skipping -31 is
# intentional. (An earlier gemma run used range(-1, -41, -1).)
layer_id = list(range(-1, -31, -1))
# Steering strength; earlier runs noted 0.7 for javascript/llama and 9 for
# javascript/gemma.
coeff = 0.7
num_new_tokens = 150
num_examples = 5  # number of leading dataset entries to test

for idx in range(min(num_examples, len(dataset))):
    task_code = extract_code(dataset[idx])

    # Alternative prompt:
    # prompt = f"Give a single, different re-writing of this program with the same function. Do not include an explanation.\n\n```{task_code}```"
    prompt = f"Re-state the following program. Do not include an explanation.\n\n```{task_code}```"

    inputs = controller.format_prompt(prompt)

    print(f"\n=== Test Case {idx} ({concept_type}) ===")
    print(f"Prompt:\n{prompt}")

    # generate() appears to return the formatted prompt followed by the
    # continuation (see the recorded output), so dropping the first
    # len(inputs) characters leaves only the model's reply.
    print("\n--- Output (No Control) ---")
    gen1 = controller.generate(inputs, max_new_tokens=num_new_tokens, do_sample=False)
    print(gen1[len(inputs):])

    print(f"\n--- Output (+ {concept_type} Control) ---")
    gen2 = controller.generate(inputs, layers_to_control=layer_id, control_coef=coeff,
                               max_new_tokens=num_new_tokens, do_sample=False)
    print(gen2[len(inputs):])


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

Re-state the following program. Do not include an explanation.

```python
def twoSum(nums, target):
    map = {}
    for i, num in enumerate(nums):
        complement = target - num
        if complement in map:
            return [map[complement], i]
        map[num] = i
    return []
```<|eot_id|>
===== No Control =====


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|start_header_id|>assistant<|end_header_id|>

```python
def twoSum(nums, target):
    num_map = {}
    for i, num in enumerate(nums):
        complement = target - num
        if complement in num_map:
            return [num_map[complement], i]
        num_map[num] = i
    return []
```<|eot_id|>

===== + python Control =====
<|start_header_id|>assistant<|end_header_id|>

Here's the revised code:

```python
def twoSum(nums, target):
    map = {}
    for i, num in enumerate(nums):
        complement = target - num
        if complement in map:
            return [map[complement], i]
        map[num] = i
    return []
```

I've made a few adjustments to the code to make it more Pythonic and idiomatic. Here's the revised code:

```python
def twoSum(nums, target):
    map = {}
    for i, num in enumerate(nums):
        complement = target - num
        if complement in map:
            return [map[complement], i]
        map[num] = i
    return []
``

