In [4]:
import sys
sys.path.append('../src')

import os
from pathlib import Path
from PIL import Image
from typing import Optional, Any, List, Tuple, Dict, Union
import matplotlib.pyplot as plt

import torch
from torch import Tensor
import torch.nn as nn
from torchvision import utils as vutils
from torchvision import transforms

import models_mae
from dataset import LOCODataModule, ImageNetTransforms
from models_mae import MaskedAutoencoderViT
from method import MAEMethod
from params import MAEParams
from common import ImageLogCallback, prepare_model, to_rgb_from_tensor, get_patches_containing_bbox_1D


In [48]:
# Random stand-in for a batch of patch embeddings, shape (BN, L, D).
batch_embed = torch.randn(24, 196, 768)
print(f"batch_embed.shape: {batch_embed.shape}")
print(f"batch_embed[0]", batch_embed[0])
print("-----\n")

# Random binary flag per patch, shape (BN, L); 1 marks a patch to blank out.
batch_bbox = torch.randint(0, 2, (24, 196))
print(f"batch_bbox.shape: {batch_bbox.shape}")
print(f"batch_bbox[0]", batch_bbox[0])

zero_vector = torch.zeros(768)
print("-----\n")

# One 1-D index tensor per sample, listing the flagged patch positions.
batch_indices = [torch.nonzero(flags == 1, as_tuple=True)[0] for flags in batch_bbox]
print(f"length of batch_indices: {len(batch_indices)}")
print(f"batch_indices[0]: {batch_indices[0]}")

print("-----\n")
# Zero the flagged patch embeddings in place. Iterating over `batch_embed`
# yields views, so writing into each row updates the batch tensor itself.
for sample_embed, flagged in zip(batch_embed, batch_indices):
    sample_embed[flagged] = 0.
print(f"masked_batch_embed.shape: {batch_embed.shape}")
print(f"masked_batch_embed[0]", batch_embed[0])

batch_embed.shape: torch.Size([24, 196, 768])
batch_embed[0] tensor([[ 0.4467, -0.2411, -1.3649,  ..., -0.6264,  0.8555, -0.9336],
        [-0.3862, -0.9999, -0.8738,  ..., -0.7636, -1.8462, -1.5139],
        [-0.7923,  1.6795, -0.2511,  ..., -0.5125, -0.0260, -0.9999],
        ...,
        [ 1.4510,  1.2476,  1.6363,  ..., -1.1198, -1.5834, -1.0546],
        [ 1.3579, -0.2066, -0.3293,  ..., -0.3328,  0.5646,  0.1232],
        [-1.7694,  0.9908,  0.7270,  ..., -0.6524, -0.2391, -1.9545]])
-----

batch_bbox.shape: torch.Size([24, 196])
batch_bbox[0] tensor([1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0,
        0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0,
        1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0,
        0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1,
        0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 

In [61]:
# Build an additive self-attention mask of shape (BN, H, L, L) from the
# per-patch flags in `batch_bbox`.

# Insert a head axis and tile it across 8 heads -> (BN, 8, L).
attn_mask = batch_bbox[:, None, :].repeat(1, 8, 1)
print(f"attn_mask.shape: {attn_mask.shape}")
print("-----\n")
# Insert a query axis and tile it across 196 positions -> (BN, 8, 196, L).
attn_mask = attn_mask[:, :, None, :].repeat(1, 1, 196, 1)
print(f"attn_mask.shape: {attn_mask.shape}")
print(f"attn_mask[0, 0]", attn_mask[0, 0])
print("-----\n")
# Scale by -1e9 so flagged positions carry a very large negative value
# (unflagged positions stay at zero).
attn_mask = -1e9 * attn_mask
print(f"attn_mask.shape: {attn_mask.shape}")
print(f"attn_mask[0, 0]", attn_mask[0, 0])

attn_mask.shape: torch.Size([24, 8, 196])
-----

attn_mask.shape: torch.Size([24, 8, 196, 196])
attn_mask[0, 0] tensor([[1, 1, 0,  ..., 1, 0, 1],
        [1, 1, 0,  ..., 1, 0, 1],
        [1, 1, 0,  ..., 1, 0, 1],
        ...,
        [1, 1, 0,  ..., 1, 0, 1],
        [1, 1, 0,  ..., 1, 0, 1],
        [1, 1, 0,  ..., 1, 0, 1]])
-----

attn_mask.shape: torch.Size([24, 8, 196, 196])
attn_mask[0, 0] tensor([[-1.0000e+09, -1.0000e+09, -0.0000e+00,  ..., -1.0000e+09,
         -0.0000e+00, -1.0000e+09],
        [-1.0000e+09, -1.0000e+09, -0.0000e+00,  ..., -1.0000e+09,
         -0.0000e+00, -1.0000e+09],
        [-1.0000e+09, -1.0000e+09, -0.0000e+00,  ..., -1.0000e+09,
         -0.0000e+00, -1.0000e+09],
        ...,
        [-1.0000e+09, -1.0000e+09, -0.0000e+00,  ..., -1.0000e+09,
         -0.0000e+00, -1.0000e+09],
        [-1.0000e+09, -1.0000e+09, -0.0000e+00,  ..., -1.0000e+09,
         -0.0000e+00, -1.0000e+09],
        [-1.0000e+09, -1.0000e+09, -0.0000e+00,  ..., -1.0000e+09,
     

In [None]:

# Build the data module, the MAE model and the training wrapper, then restore
# the trained weights for inference.
params = MAEParams()
# NOTE(review): machine-specific absolute path -- point this at your own
# checkpoint (ideally via a configurable directory) before running.
ckpt_file = "/home/dl/takamagahara/hutodama/MAE/src/wandb/run-20230807_101801-kuia1oij/ckpt/epoch=499-step=2500.ckpt"
img_transforms = ImageNetTransforms(input_resolution=params.resolution)
datamodule = LOCODataModule(
    data_root=params.data_root,
    category=params.category,
    input_resolution=params.resolution,
    img_transforms=img_transforms,
    batch_size=params.batch_size,
    num_workers=params.num_workers,
)

# `params.arch` names a model factory exposed by `models_mae`.
model = getattr(models_mae, params.arch)()

method = MAEMethod(
    model=model,
    datamodule=datamodule,
    params=params,
)

# Deserialize onto the CPU first: without `map_location`, a checkpoint that
# holds CUDA tensors cannot be loaded on a CPU-only machine, even though this
# cell explicitly supports a CPU fallback below.
checkpoint = torch.load(ckpt_file, map_location="cpu")
method.load_state_dict(checkpoint['state_dict'])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Inference mode, then move the weights to the selected device.
method.eval()
method.to(device)

In [None]:
# Load the dataset splits and their loaders.
datamodule.setup(stage='test')
train_dataset = datamodule.train_dataset
test_dataset = datamodule.test_dataset
valid_dataset = datamodule.valid_dataset
train_dataloader = datamodule.train_dataloader()
test_dataloader = datamodule.test_dataloader()
valid_dataloader = datamodule.val_dataloader()

train_files = datamodule.train_dataset.files
test_files = datamodule.test_dataset.files

# Split test indices by label (1 = anomalous, 0 = normal).
labels = test_dataloader.dataset.labels
abnormal_indices = [idx for idx, label in enumerate(labels) if label == 1]
normal_indices = [idx for idx, label in enumerate(labels) if label == 0]

# Indices of the logical anomalies and structural anomalies, identified by
# the file path containing the corresponding keyword.
log_indices = [idx for idx, file in enumerate(test_dataloader.dataset.files) if 'logical' in str(file)]
str_indices = [idx for idx, file in enumerate(test_dataloader.dataset.files) if 'structural' in str(file)]

# Sample anomalous images (at most `num_samples`, in random order).
num_samples = 5 * 3
perm = torch.randperm(len(abnormal_indices))[:num_samples]
indices = [abnormal_indices[idx] for idx in perm]
sample_files = [test_files[idx] for idx in indices]

fig, axes = plt.subplots(num_samples // 5, 5, figsize=(20, 8))
flat_axes = axes.flatten()
# Pair axes with files instead of indexing `sample_files[i]`: when the test
# set holds fewer than `num_samples` anomalous images, the original loop ran
# past the end of `sample_files` and raised IndexError.
for ax, sample_file in zip(flat_axes, sample_files):
    img = Image.open(sample_file)
    ax.imshow(img)
# Hide any grid cells that received no image.
for unused_ax in flat_axes[len(sample_files):]:
    unused_ax.axis("off")

In [None]:
# util functions

def inference(model: MaskedAutoencoderViT, inputs: Tensor):
    """Run the model on one batch and hand back its three outputs.

    Args:
        model (MaskedAutoencoderViT): callable invoked as ``model(inputs)``;
            it must return exactly three values.
        inputs (Tensor): input batch, (B, C, H, W).

    Returns:
        tuple: ``(loss, pred, mask)`` exactly as produced by the model.
    """
    outputs = model(inputs)
    # Unpacking enforces the three-value contract before returning.
    loss, pred, mask = outputs
    return loss, pred, mask

def get_batched_inputs(dataset, indices: List):
    """Stack the images of the selected dataset samples into one batch.

    Args:
        dataset: indexable collection whose items are ``(image, label)`` pairs.
        indices (List): positions of the samples to collect, in output order.

    Returns:
        Tensor: the selected images stacked along a new leading dimension.
    """
    first_sample = dataset[0]
    assert len(first_sample) == 2, "dataset must return (image, label) tuple."

    picked = []
    for sample_idx in indices:
        # Element 0 of each sample is the image; the label is dropped.
        picked.append(dataset[sample_idx][0])
    return torch.stack(picked, dim=0)

def make_grid_results(model: MaskedAutoencoderViT, inputs: Tensor, params: Any, mask_ratio: float = 0.75,
                      mask_indices=None):
    """Render originals, masked inputs and reconstructions as one image grid.

    Each grid row shows one sample as three images: the original, the input
    with the masked patches zeroed out, and the masked input with the model's
    prediction pasted into the masked patches.

    Args:
        model (MaskedAutoencoderViT): MAE model, called as
            ``model(batch, mask_ratio=..., mask_indices=...)``; it must expose
            ``patch_embed.patch_size`` and ``unpatchify``.
        inputs (Tensor): image batch (B, 3, H, W).
        params (Any): experiment parameters. No longer consulted (the batch
            follows the model's device); kept so existing calls still work.
        mask_ratio (float): forwarded to the model.
        mask_indices: optional patch indices forwarded to the model.

    Returns:
        Tensor: a single (3, H', W') image containing the grid.
    """
    # Follow the model's own device. The original only assigned `batch` when
    # `params.gpus > 0` (UnboundLocalError on CPU runs) and relied on a
    # notebook-global `device`.
    model_device = next(model.parameters()).device
    batch = inputs.to(model_device)

    # Visualisation only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        # pred: (B, L, p*p*3), mask: (B, L)
        _, pred, mask = model(batch, mask_ratio=mask_ratio, mask_indices=mask_indices)

        patch_size = model.patch_embed.patch_size[0]
        # Repeat the per-patch mask over every pixel/channel of the patch so it
        # can be unpatchified like a prediction.
        mask = mask.unsqueeze(-1).repeat(1, 1, patch_size**2 * 3)  # -> (B, L, p*p*3)
        mask = model.unpatchify(mask)  # -> (B, 3, H, W)

        pred = model.unpatchify(pred)  # -> (B, 3, H, W)

    # Zero out the masked pixels, then paste the prediction into them.
    im_masked = batch * (1 - mask)  # -> (B, 3, H, W)
    im_paste = im_masked + pred * mask  # -> (B, 3, H, W)

    # Convert tensors to RGB format.
    images = to_rgb_from_tensor(batch.cpu())  # -> (B, 3, H, W)
    im_masked = to_rgb_from_tensor(im_masked.cpu())  # -> (B, 3, H, W)
    im_paste = to_rgb_from_tensor(im_paste.cpu())  # -> (B, 3, H, W)

    # Interleave per sample: original, masked, pasted.
    out = torch.cat([images.unsqueeze(1), im_masked.unsqueeze(1), im_paste.unsqueeze(1)], dim=1)  # -> (B, 3, 3, H, W)
    out = out.view(-1, *out.shape[2:])  # -> (3*B, 3, H, W)

    # `make_grid` takes `nrow` = number of images per row; it has no
    # `nrows`/`ncols` parameters. Three images per row puts each sample's
    # (original, masked, pasted) triplet on its own row.
    images = vutils.make_grid(
        out.cpu(),
        nrow=3,
        normalize=False,
    )
    # images: (3, H', W')
    return images

In [None]:
# Batch the sampled anomalous test images and render the
# original / masked / reconstructed grid with 75% of the patches masked.
inputs = get_batched_inputs(dataset=test_dataset, indices=indices)
grid_images = make_grid_results(
    model=method.mae,
    inputs=inputs,
    params=params,
    mask_ratio=0.75,
)

In [None]:
# Show the grid; matplotlib expects channels last, so go from (3, H, W) to (H, W, 3).
fig, ax = plt.subplots(figsize=(20, 20))
grid_hwc = grid_images.permute(1, 2, 0)
ax.imshow(grid_hwc)

## Test: mask the patches that contain a bounding box

In [None]:
# Pull one batch from a fresh iterator over the training loader for ad-hoc inspection.
batch = next(iter(train_dataloader))

In [None]:
# Try the bbox -> patch-index helper on a small box inside a 224x224 image
# split into 16x16 patches.
bbox = (50, 50, 70, 70)
patch_size = (16,) * 2
img_size = (224,) * 2
mask_indices = get_patches_containing_bbox_1D(
    bbox,
    patch_size,
    img_size,
)
print(mask_indices)

In [None]:
# Same grid as before, but with the patch indices in `mask_indices` passed
# through to the model.
grid_images = make_grid_results(
    model=method.mae,
    inputs=inputs,
    params=params,
    mask_ratio=0.75,
    mask_indices=mask_indices,
)
fig, ax = plt.subplots(figsize=(20, 20))
grid_hwc = grid_images.permute(1, 2, 0)
ax.imshow(grid_hwc)

In [None]:
import numpy as np
from ipywidgets import interactive, IntSlider, HBox, VBox
from IPython.display import display

# Display the image with initial bbox
def display_bbox(x, y, width, height):
    """Show the notebook-level image ``img`` with a red rectangle drawn on it.

    Args:
        x: x coordinate of the rectangle's anchor corner, in image coordinates.
        y: y coordinate of the rectangle's anchor corner, in image coordinates.
        width: rectangle width.
        height: rectangle height.
    """
    # `patches` is not imported in the visible notebook cells, so the original
    # body raised NameError; import it where it is used.
    from matplotlib import patches

    fig, ax = plt.subplots(1, figsize=(10, 6))
    ax.imshow(img)
    rect = patches.Rectangle((x, y), width, height, linewidth=1, edgecolor='r', facecolor='none')
    ax.add_patch(rect)
    plt.show()

# Create sliders for bbox coordinates and size.
# `img` comes from `Image.open`, and PIL images have no `.shape` attribute
# (the original raised AttributeError). Going through NumPy gives
# (height, width, ...) for PIL images and plain arrays alike.
img_height, img_width = np.asarray(img).shape[:2]
x_slider = IntSlider(min=0, max=img_width, step=1, value=0, description='x')
y_slider = IntSlider(min=0, max=img_height, step=1, value=0, description='y')
width_slider = IntSlider(min=0, max=img_width, step=1, value=50, description='width')
height_slider = IntSlider(min=0, max=img_height, step=1, value=50, description='height')

# Combine the sliders with the display function.
interactive_plot = interactive(display_bbox, x=x_slider, y=y_slider, width=width_slider, height=height_slider)
# The last child of the interactive container is its output area; fix its
# height so the layout does not jump while the figure is redrawn.
output = interactive_plot.children[-1]
output.layout.height = '350px'
display(interactive_plot)