In [None]:
import os
import sys

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torchvision.models as tvmodels
from torch.utils.data import Subset
from torchvision import datasets, transforms
from torchvision.utils import save_image

sys.path.insert(0, "../src")
import exlib
from exlib.explainers import *
from exlib.modules import *
from exlib.datasets.abdomen_organs import *

In [None]:
# Root directory of the abdomen-organs dataset. Set the ABDOMEN_DATA_DIR environment
# variable to run this notebook on another machine; the default is the original path.
DATA_DIR = os.environ.get("ABDOMEN_DATA_DIR", "/home/antonxue/foo/data/abdomen_exlib/")
train_dataset = AbdomenOrgans(data_dir=DATA_DIR, split="train")
test_dataset = AbdomenOrgans(data_dir=DATA_DIR, split="test")

In [None]:
# exlib segmentation model: 3 input channels, 7 output channels (one per class),
# switched to inference mode; the trailing ';' suppresses the repr.
# NOTE(review): "Abodmen" looks like a misspelling of "Abdomen" (cf. AbdomenOrgans,
# AbdomenClassificationModel); it has to match the name exlib exports -- confirm.
model = AbodmenSegmentationModel(out_channels=7, in_channels=3)
model.eval();

In [None]:
# Segment training sample 2. This is pure inference, so skip autograd bookkeeping;
# y presumably holds per-class scores with the class axis at dim 1 (see argmax below).
x, olabel, glabel = train_dataset[2]
with torch.no_grad():
    y = model(x[None, ...])

In [None]:
# Shape sanity check: input image, raw model output, predicted label map, and the
# class-0 mask of the single batch element.
(
    x.shape,
    y.shape,
    y.argmax(dim=1).shape,
    (y.argmax(dim=1) == 0)[0].shape,
)

In [None]:
# Number of pixels of the (only) batch element whose predicted class is 1.
(y.argmax(dim=1)[0] == 1).int().sum()

In [None]:
# Binary mask of pixels predicted as class 6; ';' hides the Text repr of plt.title.
plt.imshow((y.argmax(dim=1) == 6).int()[0].numpy(), cmap="gray")
plt.title("pixels predicted as class 6");

In [None]:
# Binary mask of pixels predicted as class 2; ';' hides the Text repr of plt.title.
plt.imshow((y.argmax(dim=1) == 2).int()[0].numpy(), cmap="gray")
plt.title("pixels predicted as class 2");

In [None]:
# Training sample: input image, predicted label map, olabel and glabel (go/no-go).
# plt.subplots() creates its own figure, so no separate plt.figure() (which only
# left an empty figure behind).
fig, ax = plt.subplots(2, 2)
ax[0, 0].imshow(x.numpy().transpose(1, 2, 0))
ax[0, 0].set_title("input")
# y[0] has one channel per class (out_channels=7); imshow only renders 1, 3 or 4
# channels, so show the per-pixel argmax label map instead.
ax[0, 1].imshow(y[0].detach().argmax(dim=0).numpy())
ax[0, 1].set_title("predicted labels")
# The scale factors presumably just brighten small integer label ids for display.
ax[1, 0].imshow(30 * olabel.numpy().transpose(1, 2, 0))
ax[1, 0].set_title("olabel")
ax[1, 1].imshow(40 * glabel.numpy().transpose(1, 2, 0))
ax[1, 1].set_title("glabel (go/no-go)")
fig.tight_layout()

In [None]:
# Segment test sample 2. Inference only, so run without building an autograd graph.
x, olabel, glabel = test_dataset[2]
with torch.no_grad():
    y = model(x[None, ...])

In [None]:
# Test sample: input image, predicted label map, olabel and glabel (go/no-go).
# plt.subplots() creates its own figure, so no separate plt.figure() (which only
# left an empty figure behind).
fig, ax = plt.subplots(2, 2)
ax[0, 0].imshow(x.numpy().transpose(1, 2, 0))
ax[0, 0].set_title("input")
# y[0] has one channel per class (out_channels=7); imshow only renders 1, 3 or 4
# channels, so show the per-pixel argmax label map instead.
ax[0, 1].imshow(y[0].detach().argmax(dim=0).numpy())
ax[0, 1].set_title("predicted labels")
# The scale factors presumably just brighten small integer label ids for display.
ax[1, 0].imshow(30 * olabel.numpy().transpose(1, 2, 0))
ax[1, 0].set_title("olabel")
ax[1, 1].imshow(40 * glabel.numpy().transpose(1, 2, 0))
ax[1, 1].set_title("glabel (go/no-go)")
fig.tight_layout()

In [None]:
# exlib classification model (args 3, 10 -- presumably input channels and number of
# classes; confirm against exlib). Put it in inference mode like the segmentation
# models, so any dropout / batch-norm layers do not run in training mode.
cls_model = AbdomenClassificationModel(3, 10)
cls_model.eval();

In [None]:
# Classifier forward pass on the current image (inference only, no autograd graph).
# Note: this rebinds `y`, replacing the segmentation output computed earlier.
with torch.no_grad():
    y = cls_model(x[None, ...])

In [None]:
# Try an explainer: integrated gradients on the exlib segmentation model.

In [None]:
# Integrated-gradients explainer (per the exlib class name) wrapping the exlib
# segmentation model defined above.
explainer = TorchImageSegIntGrad(model)

In [None]:
# Batch the same image twice: one batch element per requested label (labels 0 and 2,
# presumably the segmentation classes to explain), 8 integration steps.
X = torch.stack([x] * 2)
intg = explainer(X, labels=[0, 2], num_steps=8)

In [None]:
# Attributions for each batch element of the explainer call above, side by side.
# plt.subplots() creates its own figure; no extra (empty) plt.figure() needed.
# NOTE(review): raw attributions are shown as RGB, so values outside [0, 1] are
# clipped by imshow; normalize first if exact magnitudes matter.
fig, ax = plt.subplots(1, 2)
attributions = intg.attributions.detach().cpu()
for i, axis in enumerate(ax):
    axis.imshow(attributions[i].numpy().transpose(1, 2, 0))
    axis.set_title(f"attribution, batch element {i}")

In [None]:
# Use their models: a segmentation_models_pytorch U-Net restored from a saved checkpoint (weiqiu_model.pt).

In [None]:
import segmentation_models_pytorch as smp

In [None]:
# Load the model: a segmentation_models_pytorch U-Net (ResNet-34 encoder, 3 input
# channels, 3 output classes, channel-wise softmax) whose weights all come from the
# saved checkpoint below.
smp_model = smp.Unet(
    encoder_name="resnet34",    # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights=None,       # skip the ImageNet download: the strict load_state_dict below overwrites every weight
    in_channels=3,              # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=3,                  # model output channels (number of classes in your dataset)
    activation='softmax2d'
)

state_dict_path = "/home/antonxue/foo/real_drs/saved_models/weiqiu_model.pt"
# torch.load unpickles the file -- only load checkpoints from a trusted source.
state_dict = torch.load(state_dict_path, map_location="cpu")
smp_model.load_state_dict(state_dict["model"])
# The ';' that suppresses the module repr must share the line with the call;
# a lone ';' line is not valid Python.
smp_model.eval();

In [None]:
# Go/no-go overlay for test image 0: go pixels (glabel == 1) get channel 2 forced to
# 1.0, no-go pixels (glabel == 2) get channel 0 forced to 1.0, all other pixels keep
# the original image.
x, olabel, glabel = test_dataset[0]

go_mask = glabel == 1
nogo_mask = glabel == 2

# NOTE(review): `2:` selects channel 2 onward (only channel 2 for an RGB tensor);
# if "go" is meant to render green this should be channel 1 -- confirm.
x_go = x.clone()
x_go[2:] = 1.0
x_go = x_go * go_mask

x_nogo = x.clone()
x_nogo[0] = 1.0
x_nogo = x_nogo * nogo_mask

# 1 where the pixel is neither go nor no-go (the two masks are disjoint).
background = 1 - go_mask.int() - nogo_mask.int()
x_gonogo = background * x + x_go + x_nogo

In [None]:
# Same exlib integrated-gradients explainer, now wrapping the smp U-Net.
smp_explainer = TorchImageSegIntGrad(smp_model)

In [None]:
# Batch test image 0 three times, one batch element per label 0, 1, 2 (matching
# the smp model's 3 output classes); 16 integration steps with a progress bar.
X = torch.stack([x] * 3)
smp_intg = smp_explainer(X, labels=list(range(3)), num_steps=16, progress_bar=True)

In [None]:
# Binarize the attributions: keep entries whose magnitude is in the top 5%.
# The cutoff must be a quantile of |alphas| because it is compared against |alphas|;
# a quantile of the signed values is never higher, so it selects more than 5%.
ATTR_QUANTILE = 0.95
alphas = smp_intg.attributions.cpu().detach()
q = alphas.abs().flatten().quantile(ATTR_QUANTILE)
pas = (alphas.abs() > q).float()

In [None]:
# Input image and its go/no-go overlay (top row) next to the binarized attributions
# for labels 1 and 2 (bottom row; pas[i] is the batch element for labels=[0, 1, 2][i]).
# plt.subplots() creates its own figure; no extra (empty) plt.figure() needed.
fig, ax = plt.subplots(2, 2)
panels = [
    (x, "input"),
    (x_gonogo, "go/no-go overlay"),
    (pas[1], "binarized attribution, label 1"),
    (pas[2], "binarized attribution, label 2"),
]
for axis, (img, panel_title) in zip(ax.flat, panels):
    axis.imshow(img.numpy().transpose(1, 2, 0))
    axis.set_title(panel_title)
fig.tight_layout()

In [None]:
# File name of test image 0, displayed for reference and used below as the base
# name of the dumped images.
# NOTE(review): assumes image_filenames[i] corresponds to test_dataset[i] -- confirm.
xfile = test_dataset.image_filenames[0]
xfile

In [None]:
# Dump test image 0, its go/no-go overlay and the binarized attribution for label 0.
dump_dir = "/home/antonxue/foo/test/dump/"
# Assumes xfile is a relative name (os.path.join would drop dump_dir for an absolute one).
out_prefix = os.path.join(dump_dir, xfile)
# Saving into a missing directory raises; create it (and any sub-directory that is
# part of xfile) first.
os.makedirs(os.path.dirname(out_prefix), exist_ok=True)
save_image(x, out_prefix)
save_image(x_gonogo, out_prefix + ".gonogo.png")
save_image(pas[0], out_prefix + ".intgrad.png")