In [None]:
# Enable IPython's autoreload extension in mode 2 (reload all modules before
# running each cell) so edits to the project modules imported below
# (e.g. model.litanfis, tests.uci_test) take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os 
# Switch to the repository root: later cells write to the relative 'results/'
# folder and import project packages (model.*, tests.*, train.*), which are
# presumably resolved from the working directory.
# Set the LITANFIS_ROOT environment variable to run on another machine; the
# original checkout location stays the default so existing setups are unchanged.
PROJECT_ROOT = os.environ.get('LITANFIS_ROOT', '/home/mahdi/Aria/T7/LitAnfis')
os.chdir(PROJECT_ROOT)

In [None]:
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns 
import numpy as np
import torch.nn as nn
from model.litanfis import LitAnfis, SklearnLitAnfisWrapper
from model.unfis import UNFIS
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR

In [None]:
# config
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_size = 105  # initial value only; learning_params['train_size'] is replaced after the data is loaded
batch_size = 32
random_state = 199

# Seed the global PyTorch and NumPy generators so parameter initialisation and
# DataLoader shuffling repeat across "Restart Kernel -> Run All".
# (GPU kernels may still introduce small non-deterministic differences.)
torch.manual_seed(random_state)
np.random.seed(random_state)

lr = 0.005       # learning rate handed to the optimizer
max_lr = 0.01    # peak learning rate of the one-cycle schedule
epochs = 100

# Weight of the reconstruction term in the training loss and the floor it is
# meant to decay towards.
alpha = 1.0
min_alpha = 0.0

# Snapshot of the hyper-parameters; it is appended to the written report.
learning_params = dict(train_size=train_size,
                       batch_size=batch_size,
                       random_state=random_state,
                       lr=lr,
                       max_lr=max_lr,
                       epochs=epochs,
                       alpha=alpha,
                       min_alpha=min_alpha)

In [None]:
from tests.uci_test import Wine, Haberman, Cryotheraphy, Heart, Thyroid, Autism, Immunotherapy, Iris, Glass, Segmentaition
from tests.class_test import Digits, Segmentation, Digits_UCI, Diabetes, BCW, DNA, Smoke
from torch.utils.data import DataLoader


# Benchmark under test; any of the experiment classes imported in this cell
# can be substituted here.
experiment = Wine(session_id=random_state)
train_dataset = experiment.train_dataset(device)
test_dataset = experiment.test_dataset(device)

# Only the training batches are shuffled. Evaluation does not benefit from a
# random order, and a fixed order keeps the evaluation pass deterministic and
# avoids drawing from the global RNG every epoch.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Record the real number of training rows in the hyper-parameter snapshot.
learning_params['train_size'] = experiment.train_numpy()[0].shape[0]

In [None]:
# Overlay the label histograms of the training split and then the test split
# on the same axes.
for fetch_split in (experiment.train_numpy, experiment.test_numpy):
    sns.histplot(fetch_split()[1])

# The task is treated as binary when the training labels take exactly two
# distinct values; the flag selects the loss and the AUC computation later on.
distinct_train_labels = np.unique(experiment.train_numpy()[1])
binary = distinct_train_labels.size == 2
binary

In [None]:
from train.early_stop import EarlyStopping

# Network shape is read from the training split: one input per feature column
# and one output per distinct label value.
model_params = {
    'in_features': experiment.train_numpy()[0].shape[1],
    'out_features': len(np.unique(experiment.train_numpy()[1])),
    'rules': 6,
    'drop_out_p': 0.3,
    'binary': binary,
}

model = LitAnfis(**model_params, dtype=torch.float32)
# NOTE(review): the datasets are created on `device`; this assumes the wrapper
# (or LitAnfis itself) places the model on the same device -- confirm.
sk_model = SklearnLitAnfisWrapper(model, device=device)


optimizer = optim.Adam(model.parameters(), lr=lr)

# The scheduler is stepped once per batch in the training loop, so its step
# budget is (batches per epoch) x (epochs). OneCycleLR drives the learning
# rate itself, peaking at max_lr.
steps_per_epoch = len(train_loader)
scheduler = OneCycleLR(optimizer, max_lr=max_lr,
                       steps_per_epoch=steps_per_epoch, epochs=epochs)

# Early stopping
early_stopping = EarlyStopping(patience=10, delta=-0.00001)

# Per-step multiplicative factor that takes the reconstruction weight from
# `alpha` to `min_alpha` over the first half of all optimisation steps:
#     alpha * factor ** decay_steps == min_alpha
#     =>  factor == (min_alpha / alpha) ** (1 / decay_steps)
# The exponent was previously `decay_steps` instead of its reciprocal, which
# drove any non-zero ratio to ~0 after a single step.
# With min_alpha == 0 the factor is exactly 0 either way, i.e. the
# reconstruction term is active for the first batch only.
decay_steps = max(steps_per_epoch * epochs / 2, 1)
if alpha > 0:
    alpha_decaying = np.power(min_alpha / alpha, 1.0 / decay_steps)
else:
    # Nothing to decay from; keep the weight where it is instead of dividing by zero.
    alpha_decaying = 1.0

# Reconstruction loss. Despite the variable name this is an L1 (mean absolute
# error) loss, not a cosine loss; the name is kept because `criterion` uses it.
cos = nn.L1Loss()

if binary:
    # `cross` is reused by the training cell to compute the validation loss.
    cross = nn.BCEWithLogitsLoss()

    def criterion(batch_X, batch_y, outputs, reconstructed, alpha):
        """Binary objective: BCE-with-logits on the squeezed outputs/targets
        plus `alpha` times the L1 reconstruction error of `batch_X`."""
        return cross(outputs.squeeze(), batch_y.squeeze()) + cos(reconstructed, batch_X) * alpha

else:
    # `cross` is reused by the training cell to compute the validation loss.
    cross = nn.CrossEntropyLoss()

    def criterion(batch_X, batch_y, outputs, reconstructed, alpha):
        """Multi-class objective: cross-entropy against integer class targets
        plus `alpha` times the L1 reconstruction error of `batch_X`."""
        return cross(outputs, batch_y.long()) + cos(reconstructed, batch_X) * alpha

In [None]:
# Decay a working copy of the reconstruction weight. The configured `alpha`
# stays untouched, so re-running this cell starts from the same weight instead
# of the already-decayed value left behind by the previous run.
# NOTE: `scheduler` was built for exactly steps_per_epoch * epochs steps, so the
# setup cell must be re-run before training again.
current_alpha = alpha

for epoch in range(epochs):

    # ---- training pass ----
    model.train()
    train_loss = 0

    for batch_X, batch_y in train_loader:

        optimizer.zero_grad()

        # The model returns (class outputs, input reconstruction, entropy);
        # the entropy output is not part of the loss used here.
        outputs, reconstructed, entropy = model(batch_X)

        loss = criterion(batch_X, batch_y, outputs, reconstructed, current_alpha)

        loss.backward()
        optimizer.step()
        scheduler.step()  # one-cycle schedule advances once per batch

        # Geometric decay of the reconstruction weight, never below the floor.
        current_alpha = max(current_alpha * alpha_decaying, min_alpha)

        # Weight by batch size so the epoch figure is a per-sample mean even
        # when the last batch is smaller.
        train_loss += loss.item() * batch_X.size(0)

    train_loss /= len(train_loader.dataset)

    # ---- validation pass (classification loss only, no reconstruction term) ----
    # NOTE(review): the test split doubles as the validation set, so the
    # checkpoint restored below is selected on the same data it is later
    # scored on -- consider a separate validation split.
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for data, target in test_loader:
            output, _, _ = model(data)

            if binary:
                loss = cross(output.squeeze(), target.squeeze())
            else:
                loss = cross(output, target.long())

            val_loss += loss.item() * data.size(0)

    val_loss /= len(test_loader.dataset)

    print(
        f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    # The tracker is updated every epoch, but the actual stop is left disabled:
    # all epochs run and the tracker's best model is restored afterwards.
    early_stopping(val_loss, model)
    # if early_stopping.early_stop:
    #     print("Early stopping")
    #     break

early_stopping.load_best_model(model)

In [None]:
# Select the inline matplotlib backend so figures are embedded in the notebook,
# then hand the sklearn-style wrapper to the experiment's own evaluation routine.
# NOTE(review): what evaluate_model reports/plots is defined in the project's
# tests package and is not visible here -- confirm.
%matplotlib inline  
experiment.evaluate_model(sk_model)

In [None]:
from sklearn.metrics import classification_report, roc_auc_score
import pprint 

# Fetch each split once so features and labels are guaranteed to come from the
# same call (they were previously re-fetched for every metric).
train_split = experiment.train_numpy()
test_split = experiment.test_numpy()
X_train, y_train = train_split[0], train_split[1]
X_test, y_test = test_split[0], test_split[1]

# roc_auc_score needs an explicit multi-class strategy for more than two
# classes; the binary case uses its defaults.
auc_kwargs = {} if binary else {'multi_class': 'ovr'}

train_auc = roc_auc_score(y_train, sk_model.predict_proba(X_train), **auc_kwargs)
test_auc = roc_auc_score(y_test, sk_model.predict_proba(X_test), **auc_kwargs)

# Plain-text report: per-split classification report and AUC, followed by the
# model and training hyper-parameters used for this run.
result = ''.join([
    f'Experiment Result {experiment.__class__.__name__}\n',
    '------ train ------\n',
    classification_report(y_train, sk_model.predict(X_train), digits=4),
    f'AUC\t{train_auc}\n',
    '------ test ------\n',
    classification_report(y_test, sk_model.predict(X_test), digits=4),
    f'AUC\t{test_auc}\n',
    f'{pprint.pformat(model_params)}\n',
    pprint.pformat(learning_params),
])

print(result)

In [None]:
# Save the report as results/<ExperimentClass>_<rules>.txt (relative to the
# working directory). Create the folder first so a fresh checkout does not
# fail with FileNotFoundError.
os.makedirs('results', exist_ok=True)
with open(f'results/{experiment.__class__.__name__}_{model_params["rules"]}.txt', mode='w') as f:
    f.write(result)

In [None]:
from sklearn.metrics import roc_curve, auc, RocCurveDisplay

# Draw everything on one explicit axes. Previously plt.figure() opened a blank
# 8x8 canvas while RocCurveDisplay.plot() created a second figure of its own.
fig, ax = plt.subplots(figsize=(8, 8), dpi=200)

roc_split = experiment.test_numpy()
y_true = np.asarray(roc_split[1]).ravel()
y_score = np.asarray(sk_model.predict_proba(roc_split[0]))

if binary:
    # Binary task: the scores are passed to roc_curve as they are, exactly as
    # the AUC computation in the report cell does.
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)
    roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc,
                                  estimator_name='example estimator')
    roc_display.plot(ax=ax)
else:
    # roc_curve only accepts binary targets, so draw one one-vs-rest curve per
    # class. Column i of the score matrix is taken to belong to the i-th
    # sorted training label -- the same convention roc_auc_score(multi_class='ovr')
    # relies on in the report cell.
    class_labels = np.unique(experiment.train_numpy()[1])
    for column, class_label in enumerate(class_labels):
        fpr, tpr, thresholds = roc_curve(y_true == class_label, y_score[:, column])
        roc_auc = auc(fpr, tpr)
        roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc,
                                      estimator_name=f'class {class_label} vs rest')
        roc_display.plot(ax=ax)

# Zoom in: upper-left corner
ax.set_xlim([0.0, 0.025])  # false positive rate
ax.set_ylim([0.99, 1.0])  # true positive rate

# Optional: show grid or labels
ax.grid(True)