In [None]:
#### comment after the first run
!git clone https://github.com/averysi224/abci.git

%cd abci
!pip install -e .
!pip install archetypes

In [None]:
import sys, os
sys.path.append(os.getcwd())

%cd puq

In [None]:
import os
import logging
import random

import torch
from types import SimpleNamespace
from torchvision.transforms import transforms as T
import matplotlib.pyplot as plt
import pandas as pd

from puq.core import DAPUQUncertaintyRegion
from puq.arch_sample import hvf_json
from puq.data.data import (
    DiffusionSamplesDataset,
    GroundTruthsDataset,
    DiffusionSamplesDataLoader,
    GroundTruthsDataLoader,
)
from puq.plotting.visual import plot_archetype_matrices
from puq.utils import misc


In [12]:
args = SimpleNamespace(
    method="da_puq",                           
    data="/data5/wenwens/UW_subgroups/moderate",
    test_ratio=0.2,                             
    seed=42,
    gpu=0,                                      
    batch=4,
    num_workers=0,
    no_cache=False,

    alpha=0.25,                                 
    beta=0.14,                                  
    q=0.9,                                      
    delta=0.1,                                  

    num_reconstruction_lambdas=17,
    num_coverage_lambdas=100,
    num_pcs_lambdas=20,
    max_coverage_lambda=800.0,

    archetypes=True,
)


In [13]:
if args.gpu is not None and torch.cuda.is_available():
    device = torch.device(f"cuda:{args.gpu}")
else:
    device = torch.device("cpu")

print("Using device:", device)


Using device: cuda:0


In [14]:
def plot_vf(sample):
    """
    Convert alg outputs to plotting format (9x9 VF grid).
    sample: 81-dim tensor (flattened 9x9), defined on hvf_json coordinates.
    """
    coordinates = [(entry["x"], entry["y"]) for entry in hvf_json]
    vf = [1.0 for _ in range(81)]
    vf[34] = 0.0
    vf[43] = 0.0
    for j in range(len(coordinates)):
        idx = coordinates[j][0] * 9 + coordinates[j][1]
        vf[idx] = sample[idx].item()
    return torch.tensor(vf)

In [15]:
misc.setup_logging(os.path.join(os.getcwd(), "log.txt"))
logging.info(args)
torch.manual_seed(args.seed)

if not os.path.exists("results"):
    os.makedirs("results")

puq = DAPUQUncertaintyRegion(args)

cal_samples_dataset = DiffusionSamplesDataset(
    opt=args,
    calibration=True,
    transform=T.Compose([
        T.Grayscale(num_output_channels=1),
        T.ToTensor(),
    ]),
)

cal_ground_truths_dataset = GroundTruthsDataset(
    opt=args,
    samples_dataset=cal_samples_dataset,
    transform=T.Compose([
        T.Grayscale(num_output_channels=1),
        T.ToTensor(),
    ]),
)

puq.calibration(cal_samples_dataset, cal_ground_truths_dataset)

test_samples_dataset = DiffusionSamplesDataset(
    opt=args,
    calibration=False,
    transform=T.Compose([
        T.Grayscale(num_output_channels=1),
        T.ToTensor(),
    ]),
)

test_ground_truths_dataset = GroundTruthsDataset(
    opt=args,
    samples_dataset=test_samples_dataset,
    transform=T.Compose([
        T.Grayscale(num_output_channels=1),
        T.ToTensor(),
    ]),
)

results = puq.eval(test_samples_dataset, test_ground_truths_dataset)
print("Eval results:", results)

namespace(alpha=0.25, archetypes=True, batch=4, beta=0.14, data='/data5/wenwens/UW_subgroups/moderate', delta=0.1, gpu=0, max_coverage_lambda=800.0, method='da_puq', no_cache=False, num_coverage_lambdas=100, num_pcs_lambdas=20, num_reconstruction_lambdas=17, num_workers=0, q=0.9, seed=42, test_ratio=0.2)
namespace(alpha=0.25, archetypes=True, batch=4, beta=0.14, data='/data5/wenwens/UW_subgroups/moderate', delta=0.1, gpu=0, max_coverage_lambda=800.0, method='da_puq', no_cache=False, num_coverage_lambdas=100, num_pcs_lambdas=20, num_reconstruction_lambdas=17, num_workers=0, q=0.9, seed=42, test_ratio=0.2)


Applying approximation phase...
Applying approximation phase...
100%|██████████| 129/129 [00:37<00:00,  3.46it/s]
Applying calibration phase...
Applying calibration phase...
Successfully calibrated: lambda1=0.125, lambda2=243.1212158203125
Successfully calibrated: lambda1=0.125, lambda2=243.1212158203125
Applying evaluation...
Applying evaluation...
100%|██████████| 33/33 [00:09<00:00,  3.33it/s]
{'coverage_risk': 0.15552585165609012, 'reconstruction_risk': 0.08954099825385844, 'interval_size': 0.02144144650435809, 'dimension': 1.5454545454545454, 'max_dimension': 200.0, 'uncertainty_volume': 4.330704385169921e-10}
{'coverage_risk': 0.15552585165609012, 'reconstruction_risk': 0.08954099825385844, 'interval_size': 0.02144144650435809, 'dimension': 1.5454545454545454, 'max_dimension': 200.0, 'uncertainty_volume': 4.330704385169921e-10}


Eval results: {'coverage_risk': 0.15552585165609012, 'reconstruction_risk': 0.08954099825385844, 'interval_size': 0.02144144650435809, 'dimension': 1.5454545454545454, 'max_dimension': 200.0, 'uncertainty_volume': 4.330704385169921e-10}


In [16]:
# visualization

dl = DiffusionSamplesDataLoader(
    test_samples_dataset,
    batch_size=args.batch,
    num_workers=args.num_workers,
)

gl = GroundTruthsDataLoader(
    test_ground_truths_dataset,
    batch_size=args.batch,
    num_workers=args.num_workers,
)

dl_iter = iter(dl)
gl_iter = iter(gl)

all_diff = 0.0
cnt = 0

# 读取 archetype 矩阵
archetypes = pd.read_csv("./plotting/at17_matrix.csv").values

num_batches = len(test_ground_truths_dataset) // args.batch
print("Num test batches:", num_batches)

for i_batch in range(num_batches):
    image_shape = [1, 9, 9]

    batch_samples = next(dl_iter).flatten(2)           # [B, C*H*W]
    batch_gt = next(gl_iter).flatten(2).squeeze()      # [B, C*H*W] or [B, H*W]
    batch_gt = [plot_vf(i.cpu()) for i in batch_gt]

    batch_mu, batch_pcs, batch_svs, batch_lower, batch_upper, batch_indices = puq.inference(batch_samples)

    batch_mu = [plot_vf(i.cpu()) for i in batch_mu]
    batch_pcs = [i.cpu() for i in batch_pcs]
    batch_svs = [i.cpu() for i in batch_svs]
    batch_lower = [i.cpu() for i in batch_lower]
    batch_upper = [i.cpu() for i in batch_upper]
    batch_indices = [i.cpu() for i in batch_indices]

    fig, axs = plt.subplots(args.batch, 8, figsize=(20, 10))

    for i in range(args.batch):
        axs[i, 0].imshow(
            batch_mu[i].view(image_shape).transpose(0, 1).transpose(1, 2),
            cmap="gray",
        )
        axs[i, 0].axis("off")

        lower_image = (batch_mu[i] + batch_pcs[i] @ batch_lower[i]).clamp_(0, 1)
        axs[i, 1].imshow(
            lower_image.view(image_shape).transpose(0, 1).transpose(1, 2),
            cmap="gray",
        )
        axs[i, 1].axis("off")

        upper_image = (batch_mu[i] + batch_pcs[i] @ batch_upper[i]).clamp_(0, 1)
        axs[i, 2].imshow(
            upper_image.view(image_shape).transpose(0, 1).transpose(1, 2),
            cmap="gray",
        )
        axs[i, 2].axis("off")

        # ground truth
        axs[i, 3].imshow(
            batch_gt[i].view(image_shape).transpose(0, 1).transpose(1, 2),
            cmap="gray",
        )
        axs[i, 3].axis("off")

        # interval
        diff = torch.abs(upper_image - lower_image).sum() / 52.0
        all_diff += diff
        cnt += 1

        # archetype 0 does not count as visual loss
        selected_pcs = batch_indices[i]
        selected_pcs = selected_pcs[selected_pcs != 0]

        # top-4 high-contributing archetypes
        for axis_i in range(4):
            if batch_pcs[i].shape[1] > axis_i and axis_i < len(selected_pcs):
                plot_archetype_matrices(
                    axs[i, axis_i + 4],
                    selected_pcs[axis_i],
                    archetypes,
                )
            axs[i, axis_i + 4].axis("off")

    cols = [
        "average\nprediction",
        "lower\nbound",
        "upper\nbound",
        "ground\ntruth",
        "",
        "",
        "main uncertainty components\n",
        "",
    ]
    for ax, col in zip(axs[0], cols):
        ax.set_title(col, fontsize=20, pad=20)

    dis = [0.93, 0.8, 0.6, 0.4, 0.2]
    texts = ["patient", "   1   ", "   2   ", "   3   ", "   4   "]
    for row_idx, text in enumerate(texts):
        fig.text(
            x=0.05,
            y=dis[row_idx],
            s=text,
            va="center",
            ha="left",
            fontsize=20,
        )

    stage = args.data.split("/")[-1]
    fig.suptitle("", fontsize=20, y=0.95)
    plt.savefig(f"results/{stage}_{args.alpha}_bounds_{i_batch}.png")
    plt.close()

interval_size = (all_diff / cnt).item()
print("interval size:", interval_size)


Num test batches: 32
interval size: 0.09185771644115448
