In [None]:
# 1.1: Activate env (run in terminal)
# conda activate xai_proj


# 1.2: imports
import os
from pathlib import Path
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torchvision
print('torch', torch.__version__)


# 1.3: project paths
# Every location below is derived from the current working directory of the kernel.
PROJECT_ROOT = Path.cwd().resolve()
# Pre-processed dataset root.
DATA_DIR = PROJECT_ROOT.joinpath('data', 'processed')
# Folder for generated artefacts.
RESULTS_DIR = PROJECT_ROOT.joinpath('results')
# Path of the saved model weights file.
CHECKPOINT = PROJECT_ROOT.joinpath('checkpoints', 'best_model.pth')

In [None]:
# 2.1 class counts
# For every split, print how many files each class sub-directory of
# DATA_DIR/<split> contains.
from collections import Counter  # NOTE: not used in this step; kept in case other cells rely on it

for split in ['train', 'val', 'test']:
    split_dir = DATA_DIR / split
    if not split_dir.is_dir():
        # A missing split is reported instead of aborting the whole run.
        print(split, 'missing:', split_dir)
        continue
    # Sorted so the printed dict has the same class order on every run.
    classes = sorted(d.name for d in split_dir.iterdir() if d.is_dir())
    # '*.*' keeps the original "name contains a dot" filter; is_file() stops
    # sub-folders whose names contain a dot from being counted as samples.
    counts = {
        c: sum(1 for f in (split_dir / c).glob('*.*') if f.is_file())
        for c in classes
    }
    print(split, counts)


# 2.2 show sample images
# Top row: the first image (by file name) of up to four training classes.
# Bottom row: one more image of the same class, picked at random with a fixed
# seed so the figure is identical on every run.
from PIL import Image

N_SAMPLE_COLS = 4
train_dir = DATA_DIR / 'train'
# Only sub-directories are classes; stray files in train/ are ignored.
if train_dir.is_dir():
    class_dirs = sorted(d for d in train_dir.iterdir() if d.is_dir())[:N_SAMPLE_COLS]
else:
    class_dirs = []

if not class_dirs:
    print('no class folders found in', train_dir)
else:
    rng = np.random.default_rng(0)
    fig, axes = plt.subplots(2, N_SAMPLE_COLS, figsize=(12, 6))
    # Switch every panel off up front so unused panels stay blank instead of
    # showing empty ticked axes.
    for ax in axes.flat:
        ax.axis('off')
    for col, class_dir in enumerate(class_dirs):
        axes[0, col].set_title(class_dir.name)
        files = sorted(f for f in class_dir.glob('*.*') if f.is_file())
        if not files:
            # Empty class folder: keep the title, leave the panels blank.
            continue
        picks = files[:1]
        if len(files) > 1:
            # Draw the extra example from the remaining files so it never
            # repeats the top-row image.
            picks.append(files[1 + int(rng.integers(len(files) - 1))])
        for row, img_path in enumerate(picks):
            # The context manager closes the file handle once the pixels are drawn.
            with Image.open(img_path) as im:
                axes[row, col].imshow(im.convert('RGB'))
    plt.tight_layout()

In [None]:
# 3.1 training curves
# If training was logged to a CSV file, plot loss and accuracy per epoch.
# The CSV is read with columns: epoch, train_loss, val_loss, train_acc, val_acc.
log_csv = RESULTS_DIR / 'training_log.csv'
if log_csv.exists():
    df = pd.read_csv(log_csv)
    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    panels = [('loss', 'Loss'), ('acc', 'Accuracy')]
    for panel, (metric, title) in zip(ax, panels):
        for curve in ['train', 'val']:
            sns.lineplot(data=df, x='epoch', y=f'{curve}_{metric}', ax=panel, label=curve)
        # Each panel holds both curves, so replace the automatic single-column
        # y-label with a neutral one and add a title.
        panel.set(title=title, xlabel='epoch', ylabel=title.lower())
    plt.tight_layout()
else:
    print('no training log found at', log_csv)


# 3.2 model summary
# src/model.py is executed directly from its file path and bound to the name
# `model_mod`; its get_resnet18 factory is then called and the result printed.
# The weights file at CHECKPOINT is not loaded in this step.
import importlib.util

model_src = PROJECT_ROOT / 'src' / 'model.py'
spec = importlib.util.spec_from_file_location('model_mod', model_src)
model_mod = importlib.util.module_from_spec(spec)
# Runs the top-level code of model.py inside the freshly created module object.
spec.loader.exec_module(model_mod)

model = model_mod.get_resnet18(num_classes=2, pretrained=False)
print(model)

In [None]:
# Evaluation artefacts. Each one is optional and is shown only when its file
# exists under RESULTS_DIR.
from PIL import Image

pred_csv = RESULTS_DIR / 'predictions.csv'
cm_path = RESULTS_DIR / 'confusion_matrix.png'

# First rows of predictions.csv.
if pred_csv.exists():
    dfp = pd.read_csv(pred_csv)
    display(dfp.head())

# Saved confusion-matrix image.
if cm_path.exists():
    display(Image.open(cm_path))

In [None]:
# XAI gallery: show up to 12 '*_combined.png' images from RESULTS_DIR/xai
# in a 4x3 grid, in file-name order.
XAI_ROWS, XAI_COLS = 4, 3
xai_dir = RESULTS_DIR / 'xai'
if xai_dir.is_dir():
    combined = sorted(xai_dir.glob('*_combined.png'))[:XAI_ROWS * XAI_COLS]
else:
    combined = []

if not combined:
    # Say why nothing is shown instead of drawing a blank grid.
    print('no *_combined.png files found in', xai_dir)
else:
    fig, axes = plt.subplots(XAI_ROWS, XAI_COLS, figsize=(12, 14))
    # Switch every panel off first so panels left over when fewer than 12
    # images exist stay blank instead of showing empty ticked axes.
    for ax in axes.flat:
        ax.axis('off')
    for ax, img_path in zip(axes.flat, combined):
        # The context manager closes the file handle once the pixels are drawn.
        with Image.open(img_path) as img:
            ax.imshow(img)
        ax.set_title(img_path.name)
    plt.tight_layout()

In [None]:
# Not implemented yet. Plan: use numpy to load the attribution arrays if they were saved,
# otherwise recompute them on a few images.
# Pseudocode: load attributions, compute center mask, compute mean_in/mean_out