In [1]:
# Re-import edited project modules (model/, tests/, train/) before executing code,
# so changes to those files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# NOTE(review): machine-specific absolute path. The project-local imports below
# (model.*, tests.*, train.*) and the relative results/ path presumably rely on
# this being the repository root -- point it at your own checkout, or start
# Jupyter from the repo root and drop this line.
%cd "/Users/arasvalizadeh/Desktop/GRIFFIN"

In [None]:
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns 
import numpy as np
import torch.nn as nn
from model.griffin import GRIFFIN, SklearnGRIFFINrapper
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR

In [4]:
# config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Placeholder only: learning_params['train_size'] is replaced with the real
# training-set size right after the dataset is loaded.
train_size = 105
batch_size = 16
random_state = 199

# Seed torch (parameter init, DataLoader shuffling) and NumPy as well, so a
# Restart & Run All reproduces the same numbers; otherwise random_state is only
# handed to the experiment as its session_id.
torch.manual_seed(random_state)
np.random.seed(random_state)

lr = 0.01
max_lr = 0.1
epochs = 150

alpha = 1.0
min_alpha = 0.0

# Learning hyper-parameters, written verbatim into the results report.
learning_params = dict(train_size=train_size,
                       batch_size=batch_size,
                       random_state=random_state,
                       lr=lr,
                       max_lr=max_lr,
                       epochs=epochs,
                       alpha=alpha,
                       min_alpha=min_alpha)

In [None]:
from tests.uci_test import Wine, Haberman, Cryotheraphy, Heart, Thyroid, Autism, Immunotherapy, Iris, Glass, Segmentaition, CaliforniaHousing, Abalone
from tests.class_test import Digits, Segmentation, Digits_UCI, Diabetes, BCW, DNA, Smoke, MNIST, ORL
from torch.utils.data import DataLoader


experiment = Diabetes(session_id=random_state)
train_dataset = experiment.train_dataset(device)
test_dataset = experiment.test_dataset(device)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation only sums per-batch losses, so order is irrelevant. Keep it fixed:
# a shuffling loader also draws a seed from the global torch RNG on every pass,
# which would perturb the training shuffles between epochs.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Record the actual number of training rows (replaces the config placeholder).
learning_params['train_size'] = experiment.train_numpy()[0].shape[0]
regression = experiment.target_type == 'regression'
regression

In [None]:
# Compare the target distributions of the train and test splits on one axis.
fig, ax = plt.subplots()
sns.histplot(experiment.train_numpy()[1], ax=ax, label='train')
sns.histplot(experiment.test_numpy()[1], ax=ax, label='test')
ax.set(title='Target distribution by split', xlabel='target')
ax.legend()

# Binary mode is for classification only: a regression target that happens to
# take exactly two distinct values must not switch the model/loss to BCE.
binary = (not regression) and len(np.unique(experiment.train_numpy()[1])) == 2
binary

In [None]:
from train.early_stop import EarlyStopping

# GRIFFIN architecture hyper-parameters; the same dict is dumped into the report.
model_params = dict(
    in_features=experiment.train_numpy()[0].shape[1],
    # one output per distinct training label for classification, one for regression
    out_features=1 if regression else len(np.unique(experiment.train_numpy()[1])),
    rules=2,
    regression=regression,
    rank=2,
    binary=binary,
    zeta=1,
    Xi=1,
    eta=1,
)



model = GRIFFIN(**model_params, dtype=torch.float32)
model._init(*experiment.train_numpy())
# The datasets are built with `device` (presumably placing their tensors there),
# so the parameters must live there too or a CUDA run fails on the first forward.
# nn.Module.to moves the module in place, so the wrapper and optimizer below
# see the moved parameters.
model.to(device)
sk_model = SklearnGRIFFINrapper(model, device=device)

optimizer = optim.Adam(model.parameters(), lr=lr)

steps_per_epoch = len(train_loader)

# Constructing OneCycleLR immediately overwrites the optimizer's lr with
# max_lr / div_factor (0.1 / 25 = 0.004 with the default div_factor), even if
# scheduler.step() is never called -- which silently replaced the configured
# `lr`. Only build it when the one-cycle schedule is wanted; then also call
# scheduler.step() after every optimizer.step() in the training loop.
use_one_cycle = False
scheduler = (
    OneCycleLR(optimizer, max_lr=max_lr,
               steps_per_epoch=steps_per_epoch, epochs=epochs)
    if use_one_cycle else None
)

# Early stopping
early_stopping = EarlyStopping(patience=30, delta=-0.00001)

# NOTE(review): not referenced anywhere else in this notebook. With
# min_alpha == 0 it evaluates to 0.0, and it would raise ZeroDivisionError if
# alpha were 0.
alpha_decaying = np.power(min_alpha / alpha, (steps_per_epoch * epochs/2))

# Pick the loss for the task. Binary and regression heads are scored against a
# target of the same shape; multiclass uses logits + integer class indices.
if binary or regression:
    cross = nn.BCEWithLogitsLoss() if binary else nn.MSELoss()

    def criterion(outputs, batch_y):
        """Batch loss for single-output heads (BCE-with-logits or MSE).

        `outputs.squeeze()` is avoided on purpose: it turns a batch of one into
        a 0-d tensor (shape mismatch against a (1,) target), and against a
        (B, 1) target it makes MSE broadcast to (B, B). Reshaping to the
        target's shape is exact for both (B,) and (B, 1) outputs. The target is
        cast to the output dtype because both losses need matching float dtypes.
        """
        return cross(outputs.reshape(batch_y.shape), batch_y.to(outputs.dtype))

else:
    cross = nn.CrossEntropyLoss()

    def criterion(outputs, batch_y):
        """Multiclass batch loss: raw logits vs. integer class indices."""
        return cross(outputs, batch_y.long())


In [None]:
for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0

    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(batch_X), batch_y)
        loss.backward()
        optimizer.step()
        # The OneCycleLR scheduler is intentionally not stepped; call
        # scheduler.step() here to enable the one-cycle schedule.

        # Weight by batch size so the epoch mean is exact when the last batch is short.
        train_loss += loss.item() * batch_X.size(0)

    train_loss /= len(train_loader.dataset)

    # ---- evaluation pass ----
    # NOTE(review): this loader is the *test* split. Using it for early stopping
    # and best-model selection leaks test information into model selection, so
    # the final test metrics are optimistic; prefer a separate validation split.
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for data, target in test_loader:
            loss = criterion(model(data), target)
            val_loss += loss.item() * data.size(0)

    val_loss /= len(test_loader.dataset)
    print(
        f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

# Restore the weights of the best epoch seen by the early-stopping tracker.
early_stopping.load_best_model(model)

In [None]:
# Delegate to the experiment's own evaluation routine on the sklearn-style
# wrapper; the inline backend is enabled in case it draws figures.
%matplotlib inline  
experiment.evaluate_model(sk_model)

In [None]:
from sklearn.metrics import (
    classification_report, roc_auc_score, mean_squared_error, r2_score
)
import numpy as np
import pprint

# Helper function to make batched predictions
def predict_in_batches(predict_fn, X, batch_size=1024):
    """Apply ``predict_fn`` to ``X`` in row chunks and stack the results.

    The model never sees more than ``batch_size`` rows in a single call.

    Args:
        predict_fn: callable taking a slice of ``X`` (e.g. ``predict`` or
            ``predict_proba``) and returning an array-like.
        X: sequence supporting ``len`` and slicing, e.g. a NumPy array.
        batch_size: maximum number of rows per call; must be positive.

    Returns:
        np.ndarray: the per-chunk outputs concatenated along axis 0. For an
        empty ``X`` this is ``predict_fn(X)`` as an array, so the callee decides
        the shape of an empty prediction.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if len(X) == 0:
        # np.concatenate([]) would raise "need at least one array to concatenate".
        return np.asarray(predict_fn(X))
    return np.concatenate([predict_fn(X[start:start + batch_size])
                           for start in range(0, len(X), batch_size)])

X_train, y_train = experiment.train_numpy()
X_test, y_test = experiment.test_numpy()


def _auc_scores(proba):
    """Shape ``predict_proba`` output the way ``roc_auc_score`` expects it.

    For binary targets sklearn requires a 1-D score for the positive class, so
    a 2-D output is reduced to its last column: (n, 2) -> score of the greater
    label, (n, 1) -> the single column. Multiclass outputs (scored one-vs-rest)
    are returned unchanged.
    """
    proba = np.asarray(proba)
    if binary and proba.ndim == 2:
        return proba[:, -1]
    return proba


result = f'Experiment Result {experiment.__class__.__name__}\n'

if not regression:
    train_preds_proba = predict_in_batches(sk_model.predict_proba, X_train)
    test_preds_proba = predict_in_batches(sk_model.predict_proba, X_test)

    train_preds = predict_in_batches(sk_model.predict, X_train)
    test_preds = predict_in_batches(sk_model.predict, X_test)

    # AUC: plain binary AUC, or one-vs-rest averaging for multiclass
    auc_kwargs = {} if binary else {'multi_class': 'ovr'}
    train_auc = roc_auc_score(y_train, _auc_scores(train_preds_proba), **auc_kwargs)
    test_auc = roc_auc_score(y_test, _auc_scores(test_preds_proba), **auc_kwargs)

    result += '------ train ------\n'
    result += classification_report(y_train, train_preds, digits=4)
    result += f'AUC\t{train_auc}\n'
    result += '------ test ------\n'
    result += classification_report(y_test, test_preds, digits=4)
    result += f'AUC\t{test_auc}\n'

else:
    train_preds = predict_in_batches(sk_model.predict, X_train)
    test_preds = predict_in_batches(sk_model.predict, X_test)

    train_mse = mean_squared_error(y_train, train_preds)
    test_mse = mean_squared_error(y_test, test_preds)

    train_r2 = r2_score(y_train, train_preds)
    test_r2 = r2_score(y_test, test_preds)

    result += '------ train ------\n'
    result += f'MSE\t{train_mse:.4f}\nR2\t{train_r2:.4f}\n'
    result += '------ test ------\n'
    result += f'MSE\t{test_mse:.4f}\nR2\t{test_r2:.4f}\n'

# Append the exact configuration so the report is self-describing.
result += f'{pprint.pformat(model_params)}\n'
result += pprint.pformat(learning_params)

print(result)

In [None]:
# Persist the report; the file name encodes the main GRIFFIN hyper-parameters.
# NOTE(review): ':' in the name is invalid on Windows (and shows as '/' in
# macOS Finder); the format is kept so existing result files stay comparable.
from pathlib import Path

results_dir = Path('results')
# open(..., 'w') raises FileNotFoundError when the directory is missing.
results_dir.mkdir(parents=True, exist_ok=True)
result_name = (
    f'{experiment.__class__.__name__}_rules:{model_params["rules"]}'
    f'_rank:{model_params["rank"]}-eta:{model_params["eta"]}'
    f'-zeta:{model_params["zeta"]}-Xi:{model_params["Xi"]}.txt'
)
with open(results_dir / result_name, mode='w', encoding='utf-8') as f:
    f.write(result)