In [1]:
# external packages
import os
import sys
from argparse import Namespace
from pathlib import Path

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import yaml
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

# own code
sys.path.append('../')
from make_and_evaluate_cams_2D import evaluate_batch
from train_classifier_2d import init_model, load_dataset

This notebook plots examples of Class Activation Maps (CAMs).

This notebook may need updates to keep working after changes in the rest of the code base. CAM examples are also saved during training and can be viewed in TensorBoard. Alternatively, they can be viewed by running `make_and_evaluate_cams_2D.py` (or its 3D counterpart) with the `--visualize` flag.

# Deeprisk data

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Settings read by `load_dataset`. Every augmentation parameter is set to a no-op
# value so that the examples shown below are deterministic.
hparams = Namespace(
    # choose dataset
    dataset='deeprisk',
    # reproducibility
    trainseed=42,
    splitseed=84,
    train_frac=0.6,
    # paths
    data_path=r"\\amc.intra\users\R\rcklein\home\deeprisk\weakly_supervised\data",
    img_path=r"all_niftis_n=657",
    weak_labels_path=r"weak_labels_n=657.xlsx",
    myoseg_path=r"myocard_predictions/version_19",
    seg_labels_dir=r"fibrosis_labels_n=75",
    # preprocessing / data augmentation
    image_norm="global_statistic",
    no_roi_crop=False,
    include_no_myo=False,
    roi_crop="fixed",  # "fitted", or "fixed"
    center_crop=224,
    input_size=224,
    rotate=0,
    translate=(0, 0),
    scale=(1, 1),
    shear=(0, 0, 0, 0),
    brightness=0,
    contrast=0,
    hflip=0.0,
    vflip=0.0,
    randomaffine_prob=0,
    randomcrop=False,
    randomcrop_prob=0.0,
    randomerasing_probs=[],
)
print(f'{hparams=}')

pl.seed_everything(hparams.trainseed, workers=True)
# prepare dataloaders
dataset_train, dataset_val, dataset_test = load_dataset(hparams)
print(f"Train data: {len(dataset_train)}, validation data {len(dataset_val)}")


def _make_example_loader(dataset):
    """Sequential loader that yields one slice at a time, in dataset order."""
    return DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False, num_workers=1)


train_loader = _make_example_loader(dataset_train)
val_loader = _make_example_loader(dataset_val)
test_loader = _make_example_loader(dataset_test)


Global seed set to 42


hparams=Namespace(dataset='deeprisk', trainseed=42, splitseed=84, train_frac=0.6, data_path='\\\\amc.intra\\users\\R\\rcklein\\home\\deeprisk\\weakly_supervised\\data', img_path='all_niftis_n=657', weak_labels_path='weak_labels_n=657.xlsx', myoseg_path='myocard_predictions/version_19', seg_labels_dir='fibrosis_labels_n=75', image_norm='global_statistic', no_roi_crop=False, include_no_myo=False, roi_crop='fixed', center_crop=224, input_size=224, rotate=0, translate=(0, 0), scale=(1, 1), shear=(0, 0, 0, 0), brightness=0, contrast=0, hflip=0.0, vflip=0.0, randomaffine_prob=0, randomcrop=False, randomcrop_prob=0.0, randomerasing_probs=[])
train transforms:
Compose(
    ApplyAllResize(size=(224, 224), interpolations=['bilinear', 'bilinear', 'bilinear'], max_size=None, antialias=None)
    ApplyFirstColorJitter(brightness=None, contrast=None, saturation=None, hue=None)
    ApplyFirstNormalize(mean=[0.57], std=[0.06])
    RandomApply(
    p=0
    ApplyAllRandomAffine(degrees=[0.0, 0.0], scale=

431it [00:36, 11.90it/s]


len(PIXEL_LABEL_IDS)=75
Validation patients: 21
VAL_IDS=['DRAUMC1094', 'DRAUMC0532', 'DRAUMC1164', 'DRAUMC0235', 'DRAUMC0804', 'DRAUMC0072', 'DRAUMC0847', 'DRAUMC1082', 'DRAUMC0583', 'DRAUMC0481', 'DRAUMC0768', 'DRAUMC1224', 'DRAUMC0923', 'DRAUMC0667', 'DRAUMC1037', 'DRAUMC0431', 'DRAUMC0868', 'DRAUMC0527', 'DRAUMC1122', 'DRAUMC1084', 'DRAUMC0743']
len(label_df)=18


18it [00:01, 11.35it/s]


len(PIXEL_LABEL_IDS)=75
Test patients: 21
TEST_IDS=['DRAUMC0075', 'DRAUMC0172', 'DRAUMC0338', 'DRAUMC0380', 'DRAUMC0411', 'DRAUMC0435', 'DRAUMC0507', 'DRAUMC0567', 'DRAUMC0634', 'DRAUMC0642', 'DRAUMC0673', 'DRAUMC1017', 'DRAUMC1042', 'DRAUMC1166', 'DRAUMC1199', 'DRAUMC0051', 'DRAUMC0805', 'DRAUMC0891', 'DRAUMC1049', 'DRAUMC0008', 'DRAUMC0949']
len(label_df)=20


20it [00:01, 11.05it/s]

Train data: 4191, positive: 2347
Validation data: 173, positive: 76
Test data: 198, positive: 129
Train data: 4191, validation data 173





# Load models

In [3]:
# load a trained model
#resnet18 = LightningResNet(no_cam=False, highres=True, model='resnet18', num_classes=1, in_chans=1)
#resnet18.load_state_dict(torch.load(r"..\..\tb_logs\resnet18\version_9\checkpoints\epoch=249-step=37749.ckpt", map_location=torch.device('cpu'))['state_dict'])
#resnet18.eval()
def load_model(checkpoint_path, model="drnd"):
    # path names
    MODEL_DIR = Path(checkpoint_path).parent.parent
    setattr(hparams, "model", model)
    setattr(hparams, "singlestage", False)
    setattr(hparams, "load_checkpoint", checkpoint_path)
    
    
    # add hparams from saved model
    with open(MODEL_DIR.joinpath("hparams.yaml"), 'r') as stream:
        parsed_yaml = yaml.safe_load(stream)
        for k, v in parsed_yaml.items():
            setattr(hparams, k, v)

    # load model
    model = init_model(hparams)
    model.eval()
    return model

# All runs live under the same log directory and share the checkpoint file name.
_TB_LOGS_DIR = r"\\amc.intra\users\R\rcklein\home\deeprisk\weakly_supervised\tb_logs"
_CKPT_FILE = r"version_0\checkpoints\epoch=119-step=31320.ckpt"


def _checkpoint_path(run_name):
    """Full (UNC-style) checkpoint path for the run directory `run_name`."""
    return "\\".join((_TB_LOGS_DIR, run_name, _CKPT_FILE))


drnd_d_7 = load_model(_checkpoint_path("drnd_masked_avg_dilation=7"))
drnd_d_3 = load_model(_checkpoint_path("drnd_masked_avg_dilation=3"))
# NOTE(review): same checkpoint as `drnd_d_3`, so this is just a second copy of that
# model. Presumably it was meant to point at a newer run/version; confirm.
drnd_d_3_new = load_model(_checkpoint_path("drnd_masked_avg_dilation=3"))

ConstructorError: could not determine a constructor for the tag 'tag:yaml.org,2002:python/tuple'
  in "\\amc.intra\users\R\rcklein\home\deeprisk\weakly_supervised\tb_logs\drnd_masked_avg_dilation=7\version_0\hparams.yaml", line 14, column 7

In [None]:
import os
import skimage
from skimage.filters import threshold_multiotsu
def plot_cams_gen(dataloader, net, wrong_only=False):
    """Step through slices, showing a CAM histogram and a CAM overlay for each.

    For every slice the model output is thresholded (sigmoid > 0.5) and its class
    activation map is computed with ``select_values="pos_only"`` and
    ``upsampling="conv_transpose"``. Two figures are then shown:
    - a 20-bin histogram of the CAM values;
    - the input image next to the image with the CAM overlaid on a fixed [0, 1]
      colour scale.

    Args:
        dataloader: iterable of dict batches with keys "img", "label", "myo_seg",
            "img_path" and "slice_idx".
        net: model called as ``net(inputs, myo_seg=...)`` that also provides
            ``make_cam(...)``.
        wrong_only: if True, only slices whose prediction differs from the label
            are plotted.

    Yields:
        None after each plotted slice; call ``next(gen)`` to advance.
    """
    for batch in dataloader:
        inputs = batch["img"].to(device)
        labels = batch["label"].to(device)
        myo_seg = batch["myo_seg"]
        batch_path = batch["img_path"]
        batch_slice = batch["slice_idx"]

        # forward
        outputs = net(inputs, myo_seg=myo_seg)
        # numpy / matplotlib need CPU tensors, so move everything plotted off the device
        scores = torch.sigmoid(outputs).detach().cpu()
        preds = (scores > 0.5)[:, 0]
        cam = net.make_cam(inputs, myo_seg=myo_seg, select_values="pos_only",
                           upsampling="conv_transpose").detach().cpu()
        images = inputs.detach().cpu()
        labels = labels.cpu()

        for i in range(len(images)):
            label = labels[i].item()
            pred = int(preds[i].item())
            if wrong_only and label == pred:
                continue

            # histogram of the CAM values of this slice
            histogram, bin_edges = np.histogram(cam[i][0].numpy(), bins=20)
            print(f'{histogram=}')
            _, hist_ax = plt.subplots()
            hist_ax.set_title("CAM Histogram")
            hist_ax.set_xlabel("Value")
            hist_ax.set_ylabel("pixel count")
            hist_ax.set_xlim(0.0, 1.0)
            hist_ax.plot(bin_edges[:-1], histogram)
            plt.show()

            print("label: ", label, ", predicted: ", pred, ', score:', scores[i].numpy())
            fig, axs = plt.subplots(1, 2, figsize=(16, 8))
            print(f"slice {batch_slice[i]}, file {batch_path[i]}")
            print(f"{cam[i].max()}")
            axs[0].imshow(images[i][0], cmap="gray")
            axs[1].imshow(images[i][0], cmap='gray')
            overlay = axs[1].imshow(cam[i][0], cmap="coolwarm", alpha=0.5, vmin=0.0, vmax=1)
            # attach the colorbar to the overlay so its scale matches vmin/vmax
            fig.colorbar(overlay, ax=axs[1])
            plt.show()
            yield
        
from make_and_evaluate_cams_2D import evaluate_batch



def plot_cams_crf_gen(dataloader, net, crf_params, iters, threshold, otsu_cam_threshold=0.0):
    """Step through annotated slices, showing multi-Otsu regions and CRF pseudo-labels.

    Only slices whose "fibrosis_seg_label" contains a 1 are shown, and only the
    first element of each batch is used (``[0][0]`` indexing).

    For each slice:
    - image intensities inside the CAM (``cam > otsu_cam_threshold``) are split
      into three classes with multi-Otsu, and pixels outside the CAM get region 3;
    - a 4-panel figure is shown: image, masked-intensity histogram with the
      thresholds, Otsu regions, and the intensity histogram of the ground-truth
      fibrosis;
    - ``evaluate_batch`` is called with ``visualize=True``.

    Slices with no CAM pixel above `otsu_cam_threshold` are skipped, because
    multi-Otsu cannot be fitted on an empty selection.

    Args:
        dataloader: iterable of dict batches ("img", "myo_seg",
            "fibrosis_seg_label", "img_path", "slice_idx").
        net: model providing ``make_cam``.
        crf_params, iters, threshold: forwarded to ``evaluate_batch``.
        otsu_cam_threshold: CAM value above which pixels are used for multi-Otsu;
            also forwarded to ``evaluate_batch``.

    Yields:
        None after each visualised slice.
    """
    for batch in dataloader:
        if batch["fibrosis_seg_label"].amax() != 1:
            continue
        print(f"{batch['img_path']} {batch['slice_idx']}")
        image = batch["img"].detach().cpu().numpy()[0][0]
        gt = batch["fibrosis_seg_label"].cpu().numpy()[0][0]

        cam = net.make_cam(batch["img"], myo_seg=batch["myo_seg"], upsampling="conv_transpose", select_values="pos_only")
        cam = cam[0][0].detach().cpu().numpy()
        cam_intensities = image[cam > otsu_cam_threshold]
        if cam_intensities.size == 0:
            print(f"no CAM pixels above {otsu_cam_threshold}; skipping slice")
            continue

        hist, bin_centers = skimage.exposure.histogram(cam_intensities, nbins=256)
        thresholds = threshold_multiotsu(hist=(hist, bin_centers), classes=3)
        # Using the threshold values, we generate the three regions; pixels outside
        # the CAM get the extra label 3.
        regions = np.digitize(image, bins=thresholds)
        regions[cam <= otsu_cam_threshold] = 3

        fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(10, 3.5))

        # Plotting the original image.
        ax[0].imshow(image, cmap='gray')
        ax[0].set_title('Original')
        ax[0].axis('off')

        # Plotting the histogram and the two thresholds obtained from multi-Otsu.
        ax[1].hist(cam_intensities, bins=255)
        ax[1].set_title('Histogram')
        for thresh in thresholds:
            ax[1].axvline(thresh, color='r')

        # Plotting the Multi Otsu result.
        ax[2].imshow(regions, cmap='jet')
        ax[2].set_title('Multi-Otsu result')
        ax[2].axis('off')

        ax[3].hist(image[gt > 0.5].ravel(), bins=255)
        ax[3].set_title('ground truth fibrosis histogram')

        fig.tight_layout()
        plt.show()

        # Previously 0.0 was hard-coded here, silently ignoring `otsu_cam_threshold`.
        evaluate_batch(batch, net, confidence_weighting=True,
                       threshold=threshold, iters=iters, crf_params=crf_params,
                       visualize=True, otsu_cam_threshold=otsu_cam_threshold, otsu_mask=True,
                       compute_no_label=False, merge_cam=True)
        yield

In [None]:
selected_model = drnd_d_3

# DenseCRF settings, forwarded unchanged to the pseudo-label pipeline.
# NOTE(review): the names follow the usual DenseCRF convention (w1/alpha/beta:
# appearance kernel, w2/gamma: smoothness kernel, iters: inference iterations);
# confirm against densecrf.py.
crf_params = (10, 3, 3, 1, 3, 5)
w1, alpha, beta, w2, gamma, iters = crf_params
threshold = 0.3  # passed to evaluate_batch as `threshold`

train_cam_examples, val_cam_examples = (
    plot_cams_crf_gen(loader, selected_model, crf_params, iters, threshold)
    for loader in (train_loader, val_loader)
)
val_cam_examples = plot_cams_crf_gen(val_loader, selected_model, crf_params, iters, threshold)

In [None]:
# Show the next annotated training slice (re-run this cell to step through).
next(train_cam_examples)

In [None]:
# Show the next annotated validation slice (re-run this cell to step through).
next(val_cam_examples)

# Grid plot: pseudo-labels vs. ground truth

In [None]:
from make_and_evaluate_cams_2D import dice_score, cam_to_prob, otsu_mask_cam
from preprocessing import denormalize_transform
from densecrf import densecrf

#import seaborn as sns
#sns.set()

def plot_grid(dataloader, model, crf_params, iters, threshold, n_imgs=16):
    """Yield figures comparing pseudo-labels with the fibrosis ground truth.

    Only slices whose "fibrosis_seg_label" contains a 1 are used.

    For each slice:
    - the CAM is multiplied by ``model.output2pred(...)`` (presumably 0/1, so
      negative predictions give an empty CAM; TODO confirm);
    - the CAM is turned into a pseudo-label with ``otsu_mask_cam``;
    - when ``iters > 0`` the pseudo-label is refined with DenseCRF and its Dice
      score against the ground truth is printed.

    After every `n_imgs` collected slices, a 1x3 figure (image / ground truth /
    pseudo-label, each a grid with ``int(sqrt(n_imgs))`` images per row) is
    shown and the generator yields.

    Args:
        dataloader: iterable of dict batches ("img", "myo_seg", "fibrosis_seg_label").
        model: network providing ``make_cam`` and ``output2pred``.
        crf_params: DenseCRF parameters passed to ``densecrf``.
        iters: DenseCRF is only applied when this is > 0.
        threshold: ``cam_threshold`` passed to ``cam_to_prob``.
        n_imgs: number of slices per figure.

    Yields:
        None after each figure.
    """
    DENORM = denormalize_transform()
    # images per grid row; at least 1 so make_grid never receives nrow=0
    nrow = max(1, int(np.sqrt(n_imgs)))
    collected_imgs, collected_gt, collected_pseudo = [], [], []
    for batch in dataloader:
        img = batch["img"]
        myo_seg = batch['myo_seg']
        fibrosis_seg_label = batch["fibrosis_seg_label"]
        # only slices with a pixel-level annotation can be compared
        if fibrosis_seg_label.amax() != 1:
            continue

        cam = model.make_cam(img, select_values="pos_only",
                             upsampling="conv_transpose",
                             myo_seg=myo_seg).detach()
        pred = model.output2pred(model(img, myo_seg=myo_seg))
        cam = cam * pred

        pseudo = otsu_mask_cam(cam, img, 0)

        if iters > 0:
            # DenseCRF works on an 8-bit image. Undo the intensity normalisation
            # first: 255 * (normalised image) gives meaningless uint8 values.
            # Assumes DENORM returns a tensor roughly in [0, 1]; TODO confirm.
            denorm_img = DENORM(img)
            img_uint8 = (255 * denorm_img).clamp(0, 255).type(torch.uint8)
            probs = cam_to_prob(pseudo, cam_threshold=threshold, binarize=True)
            pseudo = densecrf(img_uint8[0], probs[0], crf_params)
            print(f"pseudo-label dice: {dice_score(pseudo, fibrosis_seg_label, 0.5)}")

        collected_imgs.append(img)
        # float mask (not bool) so make_grid / masked_where treat it like the pseudo-labels
        collected_gt.append((fibrosis_seg_label > 0.0).float())
        collected_pseudo.append(pseudo)

        if len(collected_imgs) < n_imgs:
            continue

        # start plotting
        fig, axs = plt.subplots(1, 3, figsize=(14, 5))
        fig.suptitle("Weakly supervised fibrosis segmentation")

        imgs = torch.cat(collected_imgs, dim=0)
        grid_img = torchvision.utils.make_grid(imgs, nrow=nrow, normalize=True, padding=2).detach().cpu()
        axs[0].imshow(grid_img[0, :, :], cmap="gray")
        axs[0].set_title("Reference image")

        gt = torch.cat(collected_gt, dim=0)
        grid_gt = torchvision.utils.make_grid(gt, nrow=nrow, padding=2).cpu().detach()
        grid_gt = np.ma.masked_where(grid_gt <= 0, grid_gt)
        axs[1].imshow(grid_img[0, :, :], cmap="gray")
        axs[1].imshow(grid_gt[0, :, :], cmap="Reds", alpha=0.5, vmin=0, vmax=1)
        axs[1].set_title("Ground truth")

        pseudo_stack = torch.cat(collected_pseudo, dim=0)
        grid_pseudo = torchvision.utils.make_grid(pseudo_stack, nrow=nrow, padding=2).cpu().detach()
        grid_pseudo = np.ma.masked_where(grid_pseudo <= 0, grid_pseudo)
        axs[2].imshow(grid_img[0, :, :], cmap="gray")
        axs[2].imshow(grid_pseudo[0, :, :], cmap="Reds", alpha=0.5, vmin=0, vmax=1)
        axs[2].set_title("Pseudo label")

        for ax in axs:
            ax.set_axis_off()
        plt.show()

        # reset containers
        collected_imgs, collected_gt, collected_pseudo = [], [], []
        yield
# One annotated test slice per figure; advance with next(pseudo_grid_gen).
pseudo_grid_gen = plot_grid(
    dataloader=test_loader,
    model=selected_model,
    crf_params=crf_params,
    iters=iters,
    threshold=threshold,
    n_imgs=1,
)

In [None]:
# Draw the next pseudo-label comparison figure (re-run to advance).
next(pseudo_grid_gen)

In [None]:
import torch
import math
from torchvision.transforms import transforms
import cv2
# Alias for the builtin `range`, usable inside functions that shadow `range` with a
# parameter of the same name (e.g. `make_grid_with_labels`).
irange = range

def plot_cams_grid_gen(dataloader, net, wrong_only=False):
    """Show each batch as a grid of input images annotated with label and prediction.

    Two batch formats are accepted:
    - dict batches (keys "img", "label" and optionally "myo_seg"), as produced by
      the DeepRisk loaders in this notebook;
    - legacy tuple/list batches whose last two entries are (images, labels).

    When "myo_seg" is present it is passed to both the forward pass and
    ``make_cam``. The CAM is only computed for a printed sanity check of its
    shape and range.

    Args:
        dataloader: iterable of batches in one of the formats above.
        net: binary classifier; predictions are ``sigmoid(output) > 0.5``.
        wrong_only: accepted for signature compatibility with the other
            generators; not used (every batch is shown).

    Yields:
        None after each batch's grid is shown.
    """
    for batch in dataloader:
        if isinstance(batch, dict):
            inputs, labels = batch["img"], batch["label"]
            extra_kwargs = {"myo_seg": batch["myo_seg"]} if "myo_seg" in batch else {}
        else:
            inputs, labels = batch[-2], batch[-1]
            extra_kwargs = {}
        inputs = inputs.to(device)
        labels = labels.to(device)

        # forward
        outputs = net(inputs, **extra_kwargs)
        preds = (torch.sigmoid(outputs) > 0.5)[:, 0].int()

        cam = net.make_cam(inputs, select_values="sigmoid", upsampling="conv_transpose",
                           **extra_kwargs).detach()
        print(cam.shape, cam.max(), cam.min())

        # make_grid_with_labels converts tensors to numpy, so hand it CPU tensors
        labeled_grid = make_grid_with_labels(inputs.detach().cpu(), labels.cpu(), preds.cpu(),
                                             nrow=4, normalize=True)
        plt.figure(figsize=(10, 10))
        plt.imshow(labeled_grid.permute(1, 2, 0)[:, :, 0], cmap="gray")
        plt.show()
        yield
            



def _grid_label_text(labels, predictions, k):
    """Caption for grid cell `k`: the label, plus the prediction when one is given."""
    if predictions is None:
        return f"label: {str(labels[k].item())}"
    return f"label: {str(labels[k].item())}      pred: {str(predictions[k].item())}"


def _draw_grid_label(image, text):
    """Return a copy of `image` (C x H x W) with `text` drawn by OpenCV.

    The image is scaled by 255 and cast to uint8, so it presumably holds values in
    [0, 1] (e.g. after ``normalize=True``). The text is drawn in colour
    (255, 0, 0) at 10 % of the image height, and the result is converted back
    to a float tensor with ``ToTensor``.
    """
    org = (0, int(image.shape[1] * 0.1))
    canvas = cv2.UMat(
        np.asarray(np.transpose(image.detach().cpu().numpy(), (1, 2, 0)) * 255).astype('uint8'))
    canvas = cv2.putText(canvas, text, org, cv2.FONT_HERSHEY_PLAIN, 1, (255, 0, 0), 1, cv2.LINE_AA)
    return transforms.ToTensor()(canvas.get())


def make_grid_with_labels(tensor, labels, predictions, nrow=8, limit=20, padding=2,
                          normalize=False, range=None, scale_each=False, pad_value=0):
    """Make a grid of images, optionally captioned with label / prediction text.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W),
            a single image (C x H x W or H x W), or a list of images all of the
            same size. Single-channel images are replicated to 3 channels.
        labels (Tensor or None): per-image labels (``labels[k].item()`` is drawn).
            If None, no text is drawn.
        predictions (Tensor or None): per-image predictions, drawn after the label
            when given.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
        limit (int, optional): only the first `limit` images (and labels /
            predictions) are used. Default: ``20``; None means no limit.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by :attr:`range`. Default: ``False``.
        range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.

    Returns:
        Tensor of shape (C x H' x W'). For a single image this is the image
        itself (captioned if `labels` is given), without padding.
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    # Apply `limit` only now that dim 0 is guaranteed to be the batch dimension:
    # slicing earlier fails on lists and crops pixel rows of 2D/3D single images.
    if limit is not None:
        tensor = tensor[:limit]
        if labels is not None:
            labels = labels[:limit]
        if predictions is not None:
            predictions = predictions[:limit]

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if range is not None:
            assert isinstance(range, tuple), \
                "range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, min, max):
            img.clamp_(min=min, max=max)
            img.add_(-min).div_(max - min + 1e-5)

        def norm_range(t, range):
            if range is not None:
                norm_ip(t, range[0], range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, range)
        else:
            norm_range(tensor, range)

    if tensor.size(0) == 1:
        # single image: still caption it (previously the text was silently dropped)
        if labels is not None:
            return _draw_grid_label(tensor[0], _grid_label_text(labels, predictions, 0))
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    # `range` is shadowed by the parameter, so derive the cell position with divmod
    for k, working_tensor in enumerate(tensor):
        y, x = divmod(k, xmaps)
        if labels is not None:
            working_tensor = _draw_grid_label(working_tensor, _grid_label_text(labels, predictions, k))
        grid.narrow(1, y * height + padding, height - padding) \
            .narrow(2, x * width + padding, width - padding) \
            .copy_(working_tensor)
    return grid

# `drnd24_maxpool` is not loaded in any cell of this notebook, so referencing it
# directly raises NameError after Restart & Run All. Use it when it exists;
# otherwise keep the model selected earlier.
selected_model = globals().get("drnd24_maxpool", selected_model)
grid_examples = plot_cams_grid_gen(train_loader, selected_model, wrong_only=False)

In [None]:
# Show the next training batch as a labelled image grid.
next(grid_examples)

# CIFAR-10 examples

In [None]:
pl.seed_everything(42, workers=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Shared CIFAR-10 preprocessing: resize to the 224x224 model input and map the
# ToTensor range [0, 1] to [-1, 1] with mean = std = 0.5.
CIFAR_INPUT_SIZE = (224, 224)
CIFAR_MEAN = (0.5, 0.5, 0.5)
CIFAR_STD = (0.5, 0.5, 0.5)


def _cifar_resize_steps():
    """Fresh ToTensor + bicubic Resize transforms (the first steps of both pipelines)."""
    return [
        transforms.ToTensor(),
        transforms.Resize(CIFAR_INPUT_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
    ]


# transforms
train_transforms = transforms.Compose(
    _cifar_resize_steps()
    + [
        transforms.ColorJitter(brightness=0.0, contrast=0.0),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
        transforms.RandomAffine(90, scale=(0.8, 1.2), shear=(-5, 5, -5, 5), translate=(0, 0)),
    ]
)
val_transforms = transforms.Compose(
    _cifar_resize_steps() + [transforms.Normalize(CIFAR_MEAN, CIFAR_STD)]
)
print(train_transforms)
print(val_transforms)

# path names
DATA_DIR = Path(r"\\amc.intra\users\R\rcklein\home\deeprisk\weakly_supervised\data")
CIFAR_DIR = DATA_DIR.joinpath('cifar10_data')
assert DATA_DIR.exists()

cifar_train, cifar_val = (
    torchvision.datasets.CIFAR10(root=CIFAR_DIR, train=is_train, download=True, transform=tfm)
    for is_train, tfm in ((True, train_transforms), (False, val_transforms))
)

cifar_train_dataloader, cifar_val_dataloader = (
    DataLoader(ds, batch_size=8, shuffle=False) for ds in (cifar_train, cifar_val)
)

In [None]:
def plot_cams_gen_cifar(dataloader, net, wrong_only=False):
    """Step through CIFAR-10 images, overlaying the CAM of the ground-truth class.

    The prediction is the arg-max over the 10 class outputs, and the softmax
    scores are printed. The CAM returned by ``net.make_cam`` is resized
    (nearest) to the input size. The channel selected for the overlay is the
    *label* class, not the predicted one.

    Args:
        dataloader: iterable of (images, labels) batches; images are assumed to
            be normalised with mean = std = 0.5 (undone for display).
        net: multi-class model providing ``make_cam``.
        wrong_only: if True, only misclassified images are plotted.

    Yields:
        None after each plotted image.
    """
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    for batch in dataloader:
        inputs = batch[-2].to(device)
        labels = batch[-1].to(device)

        # forward
        outputs = net(inputs)
        _, preds = torch.max(outputs, 1)
        scores = F.softmax(outputs, dim=1).detach().cpu()

        cam = net.make_cam(inputs).detach()
        print(cam.shape)
        cam = torchvision.transforms.Resize(inputs.shape[2], interpolation=transforms.InterpolationMode.NEAREST)(cam)
        # numpy / matplotlib need CPU tensors
        cam = cam.cpu()
        images = inputs.detach().cpu()
        labels = labels.cpu()
        preds = preds.cpu()

        for i in range(len(images)):
            label = labels[i].item()
            pred = int(preds[i].item())
            if wrong_only and label != pred:
                pass
            elif wrong_only:
                continue
            print("label: ", f"{classes[label]} ({label})", ", predicted: ", f"{classes[pred]} ({pred})", ', scores:', scores[i].numpy())
            fig, axs = plt.subplots(1, 2, figsize=(20, 10))
            # undo Normalize(0.5, 0.5); clamp because bicubic resizing can overshoot [0, 1]
            img = (images[i].permute(1, 2, 0) / 2 + 0.5).clamp(0, 1)
            axs[0].imshow(img)
            axs[1].imshow(img)
            axs[1].imshow(cam[i][label], cmap="jet", alpha=0.5)
            plt.show()
            yield
# NOTE(review): `resnet18_cifar` is not created in any cell visible here; load a
# CIFAR-10 model under this name before running this cell.
selected_model = resnet18_cifar
train_cam_examples_cifar, val_cam_examples_cifar = (
    plot_cams_gen_cifar(loader, selected_model, wrong_only=False)
    for loader in (cifar_train_dataloader, cifar_val_dataloader)
)

In [None]:
# Show the next CIFAR-10 training image with its CAM (re-run to advance).
next(train_cam_examples_cifar)

In [None]:
# Show the next CIFAR-10 test-split image with its CAM (re-run to advance).
next(val_cam_examples_cifar)

In [None]:
# NOTE(review): `resnet18_7x7` is not defined in any cell visible here; this only
# works if the model was created in an earlier kernel session. Load it or delete the cell.
print(resnet18_7x7)

In [None]:
# NOTE(review): `convnext_narrow_shallow` is not defined in any cell visible here; this
# only works if the model was created in an earlier kernel session. Load it or delete the cell.
print(convnext_narrow_shallow)