In [1]:
import torch
from torch.utils.data import DataLoader

# Load Plug and Play XAI Manager
from pnpxai.utils import set_seed
from pnpxai import Project

from helpers import get_imagenet_dataset, get_torchvision_model, denormalize_image


# -----------------------------------------------------------------------------#
# ----------------------------------- setup -----------------------------------#
# -----------------------------------------------------------------------------#

set_seed(seed=0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of ImageNet samples per experiment and DataLoader batch size.
SUBSET_SIZE = 25
BATCH_SIZE = 10

# NOTE: "0.0.0.0" binds the UI on every network interface, so it is reachable
# from other machines on the network, not only this one. Use "127.0.0.1" to
# keep it local.
SERVER_HOST = '0.0.0.0'
SERVER_PORT = 5001


def input_extractor(batch):
    """Return the first element of a DataLoader batch (the inputs), moved to `device`."""
    return batch[0].to(device)


def target_extractor(batch):
    """Return the second element of a DataLoader batch (the targets), moved to `device`."""
    return batch[1].to(device)


def make_visualizers(transform, dataset):
    """Build (input_visualizer, target_visualizer) bound to one model's transform/dataset.

    Closing over the *global* `transform`/`dataset` would be a late-binding bug:
    those names are rebound for every model loaded in this cell, so every
    experiment would end up visualizing with the last-loaded model's values.
    Binding them as arguments here ties each experiment to its own pipeline.
    """
    def input_visualizer(x):
        # Undo input normalization with the transform's mean/std for display.
        return denormalize_image(x, transform.mean, transform.std)

    def target_visualizer(x):
        # Map a single-element class-index tensor to its label via the underlying dataset.
        return dataset.dataset.idx_to_label(x.item())

    return input_visualizer, target_visualizer


def create_model_experiment(project, model_name, experiment_name, dataset=None):
    """Load `model_name`, wrap an ImageNet subset in a DataLoader, and register
    an auto experiment named `experiment_name` on `project`.

    Args:
        project: the pnpxai `Project` the experiment is added to.
        model_name: name passed to `get_torchvision_model` (e.g. "resnet18").
        experiment_name: display name of the experiment.
        dataset: optional dataset to reuse; when None, a new subset of
            `SUBSET_SIZE` samples is built with this model's own transform.

    Returns:
        (experiment, dataset) — the dataset is returned so it can be shared
        with later experiments instead of being rebuilt.
    """
    model, transform = get_torchvision_model(model_name)
    model = model.to(device)
    if dataset is None:
        dataset = get_imagenet_dataset(transform, subset_size=SUBSET_SIZE)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE)
    input_visualizer, target_visualizer = make_visualizers(transform, dataset)

    experiment = project.create_auto_experiment(
        model,
        loader,
        name=experiment_name,
        input_extractor=input_extractor,
        target_extractor=target_extractor,
        input_visualizer=input_visualizer,
        target_visualizer=target_visualizer,
    )
    return experiment, dataset


# -----------------------------------------------------------------------------#
# --------------------------------- Project 1 ---------------------------------#
# -----------------------------------------------------------------------------#

project = Project('Test Project 1')

# Creation order (resnet18 first, then vit_b_16) matches the seeded run above.
experiment_resnet, _ = create_model_experiment(project, "resnet18", 'Resnet Experiment')
experiment_vit, vit_dataset = create_model_experiment(project, "vit_b_16", 'ViT Experiment')

# -----------------------------------------------------------------------------#
# --------------------------------- Project 2 ---------------------------------#
# -----------------------------------------------------------------------------#

# Second project for testing: loads its own vit_b_16 instance but reuses the
# ViT dataset built above rather than fetching the ImageNet subset again.
project2 = Project('Test Project 2')
experiment_project2, _ = create_model_experiment(
    project2,
    "vit_b_16",
    'ViT Experiment for Project 2',
    dataset=vit_dataset,
)

app = project.get_server().serve_ipynb(host=SERVER_HOST, port=SERVER_PORT)

  from .autonotebook import tqdm as notebook_tqdm
 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5001
 * Running on http://192.168.1.58:5001
[33mPress CTRL+C to quit[0m
127.0.0.1 - - [19/Feb/2024 15:06:43] "[36mGET / HTTP/1.1[0m" 304 -
127.0.0.1 - - [19/Feb/2024 15:06:43] "[36mGET /assets/index-b5be62ba.js HTTP/1.1[0m" 304 -
127.0.0.1 - - [19/Feb/2024 15:06:43] "[36mGET /assets/index-97f3c8fb.css HTTP/1.1[0m" 304 -
127.0.0.1 - - [19/Feb/2024 15:06:44] "GET /assets/index-b5be62ba.js HTTP/1.1" 200 -
127.0.0.1 - - [19/Feb/2024 15:06:44] "GET /api/projects/ HTTP/1.1" 200 -
127.0.0.1 - - [19/Feb/2024 15:06:44] "[36mGET /assets/XAI-Top-PnP-f03c9250.svg HTTP/1.1[0m" 304 -
127.0.0.1 - - [19/Feb/2024 15:06:44] "GET /api/projects/Test%20Project%201/models/ HTTP/1.1" 200 -
127.0.0.1 - - [19/Feb/2024 15:06:44] "[36mGET /assets/NotoSansKR-Regular-9db318b6.ttf HTTP/1.1[0m" 304 -
127.0.0.1 - - [19/Feb/2024 15:06:44] "[36mGET /assets/NotoSansKR-Medium-a89e7347.ttf HT

In [None]:
# Leftover debugging cell: the IPython `%tb` magic re-prints the traceback of the
# last exception raised in this kernel (the recorded one is `SystemExit: 1`).
# Remove before sharing the notebook.
%tb

SystemExit: 1