In [1]:
import numpy as np
import pandas as pd
import torch

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
# from tqdm import tqdm
import os 

import pdb

from trainer import *
from utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def log(report, result, **kwargs):
    """Persist one experiment result under ``report`` in the results file.

    The results dict is loaded from disk, updated with ``report -> result``,
    and written back on every call. Results already logged therefore persist
    even if the sweep is interrupted later.

    Args:
        report: Key describing the experiment path (level, kernel, g, mix, reg).
        result: Value stored under ``report``. In this notebook it is the third
            value returned by ``IKM.solve``.
        **kwargs: Pipeline state forwarded by the caller (e.g. ``yhat``/``preds``)
            is ignored. An optional ``file_path`` overrides the global
            ``setup["file_path"]``.
    """
    # Only touch the global `setup` when no explicit path was supplied.
    file_path = kwargs["file_path"] if "file_path" in kwargs else setup["file_path"]
    # Tolerate a results file that was never initialised instead of crashing.
    results = torch.load(file_path) if os.path.exists(file_path) else {}
    results[report] = result
    # Write to a sibling temp file and atomically swap it in, so an interrupt
    # during the save cannot corrupt every result collected so far.
    tmp_path = f"{file_path}.tmp"
    torch.save(results, tmp_path)
    os.replace(tmp_path, file_path)

In [3]:
def choose_kernel(IKM, iter_number, report, **kwargs):
    """Branch the sweep over every kernel configured for level ``iter_number``.

    For each name in ``setup[f"kernel {iter_number}"]`` the report is extended
    with ``"-> <kernel>"``. Level 1 then builds kernel matrix 0 and
    cross-validates directly. Later levels go on to choose the ``g`` transform
    of the previous level's predictions.

    Args:
        IKM: The iterative kernel model being trained.
        iter_number: Current level (1-based).
        report: Report prefix accumulated so far.
        **kwargs: Pipeline state (``yhat``/``preds`` after level 1), forwarded.
    """
    for kernel in setup[f"kernel {iter_number}"]:
        # Derive the tag from the kernel name. The previous if/elif left the
        # report unbound or stale for any kernel other than "rbf"/"linear".
        new_report = f"{report}-> {kernel}"
        if iter_number == 1:
            # NOTE(review): set_Ds() with no arguments presumably resets the
            # model inputs to the raw data — TODO confirm in trainer.
            IKM.set_Ds()
            IKM.make_kernel_matrices(0, kernel, **kwargs)
            cross_validate(IKM, iter_number, new_report, **kwargs)
        else:
            choose_g(IKM, iter_number, new_report, kernel, **kwargs)

In [4]:
def choose_g(IKM, iter_number, report, kernel, **kwargs):
    """Branch the sweep over the ``g`` transforms applied to previous predictions.

    ``kwargs["yhat"]``/``kwargs["preds"]`` are placed there by
    ``cross_validate`` of the previous level. Each ``g`` in
    ``setup[f"g {iter_number}"]`` is applied to those ORIGINAL predictions:
    "identity" uses them unchanged and "normalize" passes them through
    ``normalize_preds``. The result feeds the level's kernel matrix before
    kernel mixing.

    Raises:
        ValueError: if an unknown ``g`` name is configured.
    """
    base_yhat, base_preds = kwargs.get("yhat"), kwargs.get("preds")
    for g in setup[f"g {iter_number}"]:
        # Use fresh locals per iteration. Reassigning yhat/preds in place made
        # any g after "normalize" see already-normalised predictions.
        if g == "identity":
            yhat, preds = base_yhat, base_preds
        elif g == "normalize":
            # NOTE(review): assumes normalize_preds does not mutate its inputs
            # in place — TODO confirm in utils.
            yhat, preds = normalize_preds(base_yhat, base_preds)
        else:
            raise ValueError(f"Unknown g transform: {g!r}")
        new_report = f"{report}-> g {g}"
        IKM.set_Ds(yhat=yhat, preds=preds)
        IKM.make_kernel_matrices(ind=iter_number, kernel=kernel)
        mix_kernels(IKM, iter_number, new_report, **kwargs)

In [5]:
def mix_kernels(IKM, iter_number, report, **kwargs):
    """Cross-validate every kernel mixing configured for level ``iter_number``.

    Does nothing at level 1. For later levels, each weight tuple in
    ``setup[f"mixing {iter_number}"]`` is passed to ``IKM.combine_kernels`` and
    the combined kernel is cross-validated. The report gains
    ``"-> mix (w1, w2, ...)"``.
    """
    # NOTE(review): presumably combine_kernels blends the level-0 kernel with
    # the current level's kernel, hence no mixing at level 1 — TODO confirm.
    if iter_number <= 1:
        return
    for weights in setup[f"mixing {iter_number}"]:
        weight_text = ", ".join(str(w) for w in weights)
        mixed_report = f"{report}-> mix ({weight_text})"
        IKM.combine_kernels(weights)
        cross_validate(IKM, iter_number, mixed_report, **kwargs)

In [6]:
def cross_validate(IKM, iter_number, report, **kwargs):
    """Sweep the regularisation strength, log each fit, then recurse a level deeper.

    Each candidate reg is ``IKM.avg_diag_of_kernel() * 10**exponent`` for each
    exponent in ``setup[f"log_regs {iter_number}"]``. After ``IKM.solve`` the
    result is logged, and the fit's ``yhat``/``preds`` are handed to the next
    level via ``kwargs``.
    """
    base_reg = IKM.avg_diag_of_kernel()
    for exponent in setup[f"log_regs {iter_number}"]:
        reg = base_reg * 10**exponent
        # NOTE(review): ".3f" collapses regs below 5e-4 to "0.000", so such
        # runs would share a log key and overwrite each other.
        reg_report = f"{report}-> reg {reg:.3f}"
        print(reg_report)
        yhat, preds, res = IKM.solve(reg)
        print("______________________________________")
        log(reg_report, res, **kwargs)
        kwargs.update(yhat=yhat, preds=preds)
        perform_iteration(IKM, iter_number + 1, reg_report, **kwargs)

In [7]:
def perform_iteration(IKM, iter_number, report, **kwargs):
    """Run sweep level ``iter_number`` unless it exceeds ``setup["max iterations"]``.

    Prints a level banner and appends ``" /<level>: "`` to the report before
    delegating to ``choose_kernel``.
    """
    if iter_number > setup["max iterations"]:
        return
    print(f"------- level: {iter_number} -------")
    level_report = f"{report} /{iter_number}: "
    choose_kernel(IKM, iter_number, level_report, **kwargs)

In [None]:
dataset_name = "CIFAR2"
max_iter = 2
for labels in [[0, 3], [3, 5]]:
    labels_tag = "_".join(map(str, labels))
    directory = f"./results/{dataset_name}/{labels_tag}/"
    # exist_ok replaces the check-then-create pattern and is a no-op if present.
    os.makedirs(directory, exist_ok=True)
    for ratio in [1.0, 0.75, 0.50, 0.25]:
        file_path = directory + f"level_{max_iter}_{ratio}"
        if not os.path.exists(file_path):
            torch.save({}, file_path)
        logs = torch.load(file_path)
        # `setup` is read as a global by log / choose_kernel / choose_g /
        # mix_kernels / cross_validate / perform_iteration.
        setup = {"max iterations": max_iter,
                 "file_path": file_path,
                 "kernel 1": ["rbf", "linear"],
                 "kernel 2": ["rbf", "linear"],
                 "g 2": ["identity", "normalize"],
                 "mixing 2": [[0.75, 0.25], [0.50, 0.50], [0.25, 0.75], [0.00, 1.00]],
                 "log_regs 1": [-1, 0, 0.5, 1, 1.5, 2],
                 "log_regs 2": [-1, 0, 0.5, 1, 1.5],
                 }
        # NOTE(review): this rebinds `datasets`, shadowing the torchvision
        # `datasets` module imported in the first cell.
        datasets = load_dataset(dataset_name, ratio=ratio, labels=labels)
        IKM = IterativeKernelModel(dataset_name, datasets=datasets)
        perform_iteration(IKM, 1, "")
        # Re-read so `logs` reflects the results `log` just wrote to disk.
        # Otherwise the analysis cells below see the pre-run snapshot.
        logs = torch.load(file_path)
        

Train samples: 10000, Test samples: 2000
------- level: 1 -------
 /1: -> rbf-> reg 0.100
iteration took 0.529694 seconds
Training Error is 0.016105
Test Error is 0.425347
Training Accuracy is 0.999800
Test Accuracy is 0.865500
______________________________________
------- level: 2 -------
 /1: -> rbf-> reg 0.100 /2: -> rbf-> g identity-> mix (0.75, 0.25)-> reg 0.100
iteration took 0.446224 seconds
Training Error is 0.000381
Test Error is 0.424929
Training Accuracy is 1.000000
Test Accuracy is 0.863500
______________________________________
 /1: -> rbf-> reg 0.100 /2: -> rbf-> g identity-> mix (0.75, 0.25)-> reg 1.000
iteration took 0.218260 seconds
Training Error is 0.002641
Test Error is 0.425143
Training Accuracy is 0.999900
Test Accuracy is 0.865000
______________________________________
 /1: -> rbf-> reg 0.100 /2: -> rbf-> g identity-> mix (0.75, 0.25)-> reg 3.162
iteration took 0.133069 seconds
Training Error is 0.005147
Test Error is 0.425408
Training Accuracy is 0.999900
Test 

In [28]:
len(logs)

739

In [29]:
level_1s = [key for key in logs if "/2" not in key]

In [30]:
for key, metrics in ((k, logs[k]) for k in level_1s):
    print(key, metrics)

 /1: -> rbf-> reg 0.010 {'reg': 0.01, 'Train error': 0.0003135048707283218, 'Test error': 0.4287952187473711, 'Train accuracy': 1.0, 'Test accuracy': 0.861}
 /1: -> rbf-> reg 0.100 {'reg': 0.1, 'Train error': 2.7505899787207967e-05, 'Test error': 0.42798189417299, 'Train accuracy': 1.0, 'Test accuracy': 0.861}
 /1: -> rbf-> reg 0.316 {'reg': 0.31622776601683794, 'Train error': 5.68641856514919e-05, 'Test error': 0.4282334911964135, 'Train accuracy': 1.0, 'Test accuracy': 0.861}
 /1: -> rbf-> reg 1.000 {'reg': 1.0, 'Train error': 0.00010918876344012461, 'Test error': 0.42865056383744377, 'Train accuracy': 1.0, 'Test accuracy': 0.861}
 /1: -> rbf-> reg 3.162 {'reg': 3.1622776601683795, 'Train error': 0.0002132080156087696, 'Test error': 0.4292700705183397, 'Train accuracy': 1.0, 'Test accuracy': 0.861}
 /1: -> rbf-> reg 10.000 {'reg': 10.0, 'Train error': 0.0005799296455443312, 'Test error': 0.43017116104748127, 'Train accuracy': 1.0, 'Test accuracy': 0.861}
 /1: -> rbf-> reg 31.623 {'re