In [None]:
# Imports
import torch
import numpy as np
import matplotlib.pyplot as plt
from model.lenet import load_lenet
from dataset.dataset import load_data, create_dataloaders, visualize_5_samples
from functions.optimizer import load_optimizer
from functions.loss import load_loss_fun
from functions.functions import train_model, eval_model, save_checkpoint, load_checkpoint, visualize_5_sample_dynamics
from functions.xai import explain_dataset, visualize_explanation, evaluate_explainations, visualize_k_expl
from functions.xil import xil_loop, compute_simplicity, random_sampling, simplicity_sampling
from utils.utils import enable_reproducibility
from scipy.stats import pearsonr


In [None]:
# Reproducibility: one global seed, handed to the project helper before any
# model or data is created (presumably it seeds torch/numpy/random — see
# utils.utils.enable_reproducibility).
SEED = 123
enable_reproducibility(SEED)

# Device selection: use the GPU whenever PyTorch reports one is available.
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

In [None]:
# Parameters: set a flag to False to skip the matching experiment cell below.
STEP1 = True  # step1: cross-entropy baseline on DecoyMNIST
STEP2 = True  # step2: RRR-regularised training on DecoyFashionMNIST
STEP3 = True  # step3: XIL loop with random query sampling
STEP4 = True  # step4: simplicity-score vs. confounding correlation
STEP5 = True  # step5: XIL loop with simplicity-based query sampling

In [None]:
def step1(dataset: str = "DecoyMNIST"):
  """Baseline: train a LeNet with plain cross-entropy on a decoyed dataset.

  Loads `dataset` with ``bias_ratio=[1]*10`` (presumably: every sample of each
  of the 10 classes carries the decoy — confirm in dataset.load_data), trains
  for 10 epochs with SGD (lr=1e-2), prints test loss/accuracy, then explains
  the training set and prints the explanation error against
  ``train_set.masks``.

  Args:
    dataset: dataset name forwarded to ``load_data``.

  Returns:
    Tuple ``(train_set, test_set, all_attr, all_imgs)`` where ``all_attr`` /
    ``all_imgs`` are what ``explain_dataset`` returns for the train loader.
  """
  net = load_lenet(device)
  sgd = load_optimizer("SGD", net.parameters(), lr=1.0e-2, weight_decay=0)
  criterion = load_loss_fun("CrossEntropy")
  train_set, val_set, test_set = load_data(
    dataset,
    seed=SEED,
    reload=True,
    bias_ratio=[1] * 10,
  )

  # One shared settings dict for the train / val / test loaders.
  loader_params = [{"batch_size": 32}] * 3
  train_loader, val_loader, test_loader = create_dataloaders(
    [train_set, val_set, test_set], loader_params
  )

  _, _ = train_model(
    net, train_loader, sgd, criterion,
    n_epochs=10, eval_loader=val_loader, device=device,
  )

  rule = "=" * 20
  test_loss, test_acc = eval_model(net, test_loader, criterion, device)
  print(rule, f"Test set Loss:{test_loss:.2f} | Acc:{test_acc:.2f}.", rule)

  all_attr, all_imgs = explain_dataset(train_loader, net, device)
  gt_masks = torch.from_numpy(train_set.masks)
  exp_err = evaluate_explainations(all_attr, gt_masks)
  print(rule, f"Train explaination error {exp_err:.2}", rule)

  return train_set, test_set, all_attr, all_imgs

In [None]:
if STEP1:
  # Cross-entropy baseline; the results stay as notebook globals for inspection.
  train_set, test_set, all_attr, all_imgs = step1(dataset="DecoyMNIST")
  visualize_5_samples(train_set, 0)
  visualize_k_expl(all_attr, all_imgs, train_set, 0)


In [None]:
def step2(dataset: str = "DecoyMNIST"):
  """Train a LeNet with the "RRR" loss on a decoyed dataset, then evaluate it.

  The RRR regularisation rate is 1 for DecoyMNIST and 1e-2 for any other
  dataset. Test metrics are computed with a separate cross-entropy loss so they
  are comparable with ``step1``. The training set is then explained and the
  explanation error against ``train_set.masks`` is printed.

  Args:
    dataset: dataset name forwarded to ``load_data``.

  Returns:
    Tuple ``(train_set, test_set, all_attr, all_imgs)``.
  """
  net = load_lenet(device)
  sgd = load_optimizer("SGD", net.parameters(), lr=1.0e-2, weight_decay=0)
  if dataset == "DecoyMNIST":
    rrr_reg = 1
  else:
    rrr_reg = 1e-2
  rrr_criterion = load_loss_fun("RRR", reg_rate=rrr_reg)
  train_set, val_set, test_set = load_data(
    dataset,
    seed=SEED,
    reload=True,
    bias_ratio=[1] * 10,
  )

  # One shared settings dict for the train / val / test loaders.
  loader_params = [{"batch_size": 32}] * 3
  train_loader, val_loader, test_loader = create_dataloaders(
    [train_set, val_set, test_set], loader_params
  )

  _, _ = train_model(
    net, train_loader, sgd, rrr_criterion,
    n_epochs=10, eval_loader=val_loader, device=device,
  )

  # Evaluate with plain cross-entropy, not the RRR training objective.
  ce_criterion = load_loss_fun("CrossEntropy")
  rule = "=" * 20
  test_loss, test_acc = eval_model(net, test_loader, ce_criterion, device)
  print(rule, f"Test set Loss:{test_loss:.2f} | Acc:{test_acc:.2f}.", rule)

  all_attr, all_imgs = explain_dataset(train_loader, net, device)
  gt_masks = torch.from_numpy(train_set.masks)
  exp_err = evaluate_explainations(all_attr, gt_masks)
  print(rule, f"Train explaination error {exp_err:.2}", rule)

  return train_set, test_set, all_attr, all_imgs

In [None]:
if STEP2:
  # RRR-regularised run; overwrites the step1 globals with this run's results.
  train_set, test_set, all_attr, all_imgs = step2(dataset="DecoyFashionMNIST")
  visualize_5_samples(train_set, 0)
  visualize_k_expl(all_attr, all_imgs, train_set, 0)

In [None]:
def step3(dataset: str = "DecoyMNIST"):
  """Run the XIL loop with random query sampling on a decoyed dataset.

  A fresh LeNet is handed to ``xil_loop`` together with ``random_sampling``,
  the positional argument 300, a step size of 100 and an RRR rate of 1 for
  DecoyMNIST (1e-2 otherwise). The value returned by ``xil_loop`` is printed as
  the number of iterations it took.

  Args:
    dataset: dataset name forwarded to ``load_data``.

  Returns:
    Tuple ``(train_set, test_set, all_attr, all_imgs)``. The attributions and
    images are computed here, with this run's model, so the result never
    depends on globals left over from other notebook cells.
  """
  model = load_lenet(device)
  train_set, val_set, test_set = load_data(
    dataset,
    seed=SEED,
    reload=True,
    bias_ratio=[1]*10
  )

  data = [train_set, val_set, test_set]
  params = {"batch_size": 32}
  m_params = [params]*3
  train_loader, val_loader, test_loader = create_dataloaders(data, m_params)

  # Same dataset-dependent RRR rate as step2.
  rrr_reg = 1 if dataset == "DecoyMNIST" else 1e-2
  query = xil_loop(
    train_set,
    model,
    random_sampling,
    300,
    val_loader,
    test_loader,
    step_size=100,
    rrr_reg_rate=rrr_reg,
    device=device
  )
  print(f"It took {query} iterations.")

  # `all_attr` / `all_imgs` must be computed locally: returning them without a
  # local binding would read notebook globals from step1/step2 (or raise
  # NameError on a fresh kernel when those steps are disabled).
  # NOTE(review): assumes xil_loop trains `model` in place — confirm in
  # functions.xil; otherwise these explain the model as passed in.
  all_attr, all_imgs = explain_dataset(train_loader, model, device)
  return train_set, test_set, all_attr, all_imgs

In [None]:
if STEP3:
  # XIL with random query selection; the returned tuple is not kept.
  step3(dataset="DecoyMNIST")

In [None]:
def _simplicity_and_confounding(train_set, simplicity):
  """Collect per-sample simplicity scores and 0/1 confounding flags, in dataset order."""
  scores = []
  flags = []
  for position in range(len(train_set)):
    index, _, _, mask = train_set[position]
    scores.append(simplicity[index])
    # A mask whose sum exceeds 1 marks the sample as confounded.
    flags.append(1 if mask.sum() > 1 else 0)
  return np.asarray(scores), np.asarray(flags)


def _safe_pearsonr(x, y):
  """Return ``pearsonr(x, y)``, or an explanatory string when it is undefined.

  Pearson's r needs at least two samples and non-constant inputs; e.g. a class
  in which every sample is confounded makes ``y`` constant.
  """
  x = np.asarray(x)
  y = np.asarray(y)
  if x.size < 2 or np.unique(x).size < 2 or np.unique(y).size < 2:
    return "undefined (needs >= 2 samples and non-constant inputs)"
  return pearsonr(x, y)


def step4(dataset: str = "DecoyMNIST", metric: str= "MP"):
  """Check whether training-dynamics simplicity scores track confounding.

  Loads `dataset` with ``bias_ratio=[0.95]*10``, trains a LeNet with plain
  cross-entropy for 10 epochs, prints test metrics and the train-set
  explanation error. It then computes per-sample simplicity with
  ``compute_simplicity(dyn, metric=metric)`` and prints the Pearson
  correlation between simplicity and the 0/1 confounding flag, first over all
  training samples and then per class (0-9), with each class's size and count
  of confounded samples.

  Args:
    dataset: dataset name forwarded to ``load_data``.
    metric: simplicity metric name forwarded to ``compute_simplicity``.

  Returns:
    Tuple ``(train_set, test_set, dyn, all_attr, all_imgs)`` where ``dyn`` is
    the second value returned by ``train_model``.
  """
  n_classes = 10
  model = load_lenet(device)
  optim = load_optimizer("SGD", model.parameters(), lr=1.0e-2, weight_decay=0)
  loss_fn = load_loss_fun("CrossEntropy")
  train_set, val_set, test_set = load_data(
    dataset,
    seed=SEED,
    reload=True,
    bias_ratio=[0.95]*n_classes
  )

  data = [train_set, val_set, test_set]
  params = {"batch_size": 32}
  m_params = [params]*3
  train_loader, val_loader, test_loader = create_dataloaders(data, m_params)

  _, dyn = train_model(
    model,
    train_loader,
    optim,
    loss_fn,
    n_epochs=10,
    eval_loader=val_loader,
    device=device
  )
  test_loss, acc = eval_model(model, test_loader, loss_fn, device)
  print("="*20,f"Test set Loss:{test_loss:.2f} | Acc:{acc:.2f}.","="*20)

  all_attr, all_imgs = explain_dataset(train_loader, model, device)
  exp_err = evaluate_explainations(all_attr, torch.from_numpy(train_set.masks))
  print("="*20,f"Train explaination error {exp_err:.2}","="*20)

  simplicity = compute_simplicity(dyn, metric=metric)
  # Single pass over the dataset; per-class views below are slices of these arrays.
  scores, is_confounded = _simplicity_and_confounding(train_set, simplicity)
  print(f"All samples Correlation:{_safe_pearsonr(scores, is_confounded)}")

  # Class-wise correlation
  labels = np.asarray(train_set.y)
  for label in range(n_classes):
    class_positions = np.where(labels == label)[0]
    print("samples len", len(class_positions))
    class_scores = scores[class_positions]
    class_flags = is_confounded[class_positions]
    # Flags are 1 for confounded samples, so their sum is the confounded count.
    print(f"{label} Confounded samples:", int(class_flags.sum()))

    print(f"{label} Correlation:{_safe_pearsonr(class_scores, class_flags)}")

  return train_set, test_set, dyn, all_attr, all_imgs

In [None]:
if STEP4:
  # Simplicity-vs-confounding analysis using the "EC" simplicity metric.
  train_set, test_set, dyn, all_attr, all_imgs = step4(dataset="DecoyFashionMNIST", metric="EC")

In [None]:
def step5(dataset: str = "DecoyMNIST"):
  """Run the XIL loop with simplicity-based query sampling.

  First trains a LeNet with plain cross-entropy for 10 epochs to obtain
  training dynamics. It then hands the model, ``simplicity_sampling``, the
  positional argument 2400 and those dynamics (``tr_dynamics``) to
  ``xil_loop`` with a step size of 100 and an RRR rate of 1 for DecoyMNIST
  (1e-2 otherwise). The value returned by ``xil_loop`` is printed as the
  number of iterations it took.

  Args:
    dataset: dataset name forwarded to ``load_data``.

  Returns:
    Tuple ``(train_set, test_set, all_attr, all_imgs)``. The attributions and
    images are computed here, with this run's model, so the result never
    depends on globals left over from other notebook cells.
  """
  model = load_lenet(device)
  train_set, val_set, test_set = load_data(
    dataset,
    seed=SEED,
    reload=True,
    bias_ratio=[1]*10
  )

  data = [train_set, val_set, test_set]
  params = {"batch_size": 32}
  m_params = [params]*3
  train_loader, val_loader, test_loader = create_dataloaders(data, m_params)

  optim = load_optimizer("SGD", model.parameters(), lr=1.0e-2, weight_decay=0)
  ce_loss = load_loss_fun("CrossEntropy")
  _, dynamics = train_model(
    model,
    train_loader,
    optim,
    ce_loss,
    n_epochs=10,
    eval_loader=val_loader,
    device=device
  )

  # Same dataset-dependent RRR rate as step2.
  rrr_reg = 1 if dataset == "DecoyMNIST" else 1e-2
  query = xil_loop(
    train_set,
    model,
    simplicity_sampling,
    2400,
    val_loader,
    test_loader,
    tr_dynamics=dynamics,
    step_size=100,
    rrr_reg_rate=rrr_reg,
    device=device
  )
  print(f"It took {query} iterations.")

  # `all_attr` / `all_imgs` must be computed locally: returning them without a
  # local binding would read notebook globals from earlier steps (or raise
  # NameError on a fresh kernel when those steps are disabled).
  # NOTE(review): assumes xil_loop trains `model` in place — confirm in
  # functions.xil; otherwise these explain the pre-XIL model.
  all_attr, all_imgs = explain_dataset(train_loader, model, device)
  return train_set, test_set, all_attr, all_imgs

In [None]:
if STEP5:
  # XIL with simplicity-based query selection; the returned tuple is not kept.
  step5(dataset="DecoyMNIST")