In [1]:
import os
import shutil
import glob
import torch
from torchinfo import summary
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import cv2
import random

In [2]:
os.chdir(r"/mnt/hd1/ani/HIPT/HIPT_4K")
from hipt_4k import HIPT_4K
from hipt_model_utils import get_vit256, get_vit4k, eval_transforms
from hipt_heatmap_utils import *
from attention_visualization_utils import *
light_jet = cmap_map(lambda x: x/2 + 0.5, matplotlib.cm.jet)

In [3]:
# Device + Hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# Check GPU properties
for i in range(torch.cuda.device_count()): 
    print(torch.cuda.get_device_properties(i))
    
    # Get memory
    print('Total memory in GB: ', torch.cuda.get_device_properties(i).total_memory/(1.024*1e9))
    print('Memory reserved in GB: ', torch.cuda.memory_reserved(i)/(1.024*1e9))
    print('Memory allocated in GB: ', torch.cuda.memory_allocated(i)/(1.024*1e9))

cuda
_CudaDeviceProperties(name='NVIDIA RTX A5000', major=8, minor=6, total_memory=24251MB, multi_processor_count=64)
Total memory in GB:  24.83328
Memory reserved in GB:  0.0
Memory allocated in GB:  0.0


In [4]:
pretrained_weights256 = './Checkpoints/vit256_small_dino.pth'
pretrained_weights4k = './Checkpoints/vit4k_xs_dino.pth'
device256 = torch.device("cpu")
device4k = torch.device("cpu")

### ViT_256 + ViT_4K loaded independently (used for Attention Heatmaps)
model256 = get_vit256(pretrained_weights=pretrained_weights256, device=device256)
model4k = get_vit4k(pretrained_weights=pretrained_weights4k, device=device4k)

### ViT_256 + ViT_4K loaded into HIPT_4K API
model = HIPT_4K(pretrained_weights256, pretrained_weights4k, device256, device4k)
model.eval()

# Print model summary
summary(model=model,
        col_width=12,
        row_settings=["var_names"]
       )

Take key teacher in provided checkpoint dict
Pretrained weights found at ./Checkpoints/vit256_small_dino.pth and loaded with msg: _IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])
# of Patches: 196
Take key teacher in provided checkpoint dict
Pretrained weights found at ./Checkpoints/vit4k_xs_dino.pth and loaded with msg: _IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])
Take key teacher in provided checkpoint dict
Pretrained weights found at ./Checkpoints/vit256_small_dino.pth and loaded with msg: _IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.

Layer (type (var_name))                       Param #
HIPT_4K (HIPT_4K)                             --
├─VisionTransformer (model256)                76,032
│    └─PatchEmbed (patch_embed)               --
│    │    └─Conv2d (proj)                     (295,296)
│    └─Dropout (pos_drop)                     --
│    └─ModuleList (blocks)                    --
│    │    └─Block (0)                         (1,774,464)
│    │    └─Block (1)                         (1,774,464)
│    │    └─Block (2)                         (1,774,464)
│    │    └─Block (3)                         (1,774,464)
│    │    └─Block (4)                         (1,774,464)
│    │    └─Block (5)                         (1,774,464)
│    │    └─Block (6)                         (1,774,464)
│    │    └─Block (7)                         (1,774,464)
│    │    └─Block (8)                         (1,774,464)
│    │    └─Block (9)                         (1,774,464)
│    │    └─Block (10)                        (1,774,464)
│  

## Standalone HIPT_4K Model Inference

In [5]:
region = Image.open('./image_demo/image_4k.png')
x = eval_transforms()(region).unsqueeze(dim=0)
print('Input Shape:', x.shape)
print('Output Shape:', model.forward(x).shape)

Input Shape: torch.Size([1, 3, 4096, 4096])
Output Shape: torch.Size([1, 192])


## HIPT_4K Attention Heatmaps

##### Code for producing attention results (for [256 x 256], [4096 x 4096], and hierarchical [4096 x 4096]) can be run (as-is) below. There are several ways these results can be run:

hipt_4k.py Class (Preferred): This class blends inference and heatmap creation in a seamless and more object-oriented manner, and is where I am focusing my future code development around.
Helper Functions in hipt_heatmap_utils.py (Soon-to-be-deprecated): Heatmap creation was originally written as helper functions. May be more useful and easier from research perspective.
Please use whatever is most helpful for your use case :)

## 256 x 256 Demo (Saving Attention Maps Individually)

In [6]:
# patch = Image.open('./image_demo/image_256.png')
# output_dir = './attention_demo/256_output_indiv/'
# os.makedirs(output_dir, exist_ok=True)
# create_patch_heatmaps_indiv(patch=patch, model256=model256, 
#                             output_dir=output_dir, fname='patch',
#                             cmap=light_jet, device256=device256)

## 256 x 256 Demo (Concatenating + Saving Attention Maps)

In [7]:
# patch = Image.open('./image_demo/image_256.png')
# output_dir = './attention_demo/256_output_concat/'
# os.makedirs(output_dir, exist_ok=True)
# create_patch_heatmaps_concat(patch=patch, model256=model256, 
#                             output_dir=output_dir, fname='patch',
#                             cmap=light_jet)

## 4096 x 4096 Demo (Saving Attention Maps Individually)

In [8]:
region = Image.open('./image_demo/image_4k.png')
output_dir = './attention_demo/4k_output_indiv/'
os.makedirs(output_dir, exist_ok=True)
create_hierarchical_heatmaps_indiv(region, model256, model4k, 
                                   output_dir, fname='region', 
                                   scale=2, threshold=0.5, cmap=light_jet, alpha=0.5)

## 4096 x 4096 Demo (Concatenating + Saving Attention Maps)

In [9]:
# region = Image.open('./image_demo/image_4k.png')
# output_dir = './attention_demo/4k_output_concat/'
# os.makedirs(output_dir, exist_ok=True)
# create_hierarchical_heatmaps_concat(region, model256, model4k, 
#                                    output_dir, fname='region', 
#                                    scale=2, cmap=light_jet, alpha=0.5)