In [1]:
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.models import create_model
import yaml
import numpy as np
import os
from src import cct  # Importing the cct module
import attnmapgradio


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Root folder holding one sub-directory per trained run (args.yaml + checkpoint).
results_final_path = "./result_final"

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Each sub-directory of results_final_path is one trained run. Runs whose name
# contains 'c100' are CIFAR-100 models; everything else is treated as CIFAR-10.
# Sorting makes the dropdown order (and its default first entry) stable, because
# os.listdir returns entries in arbitrary order. Plain files are skipped because
# load_model expects a directory containing args.yaml and a checkpoint.
model_names = sorted(
    name
    for name in os.listdir(results_final_path)
    if os.path.isdir(os.path.join(results_final_path, name))
)
c10_models = [m for m in model_names if 'c100' not in m]
c100_models = [m for m in model_names if 'c100' in m]


In [4]:
def _read_labels(path):
    """Read a comma-separated label file into a list of clean label strings.

    The whole text is stripped first so a trailing newline does not end up in the
    last label. Each label is then stripped so "a, b" style files also work.
    """
    with open(path, encoding="utf-8") as label_file:
        return [label.strip() for label in label_file.read().strip().split(',')]


c10_labels = _read_labels('cifar10_labels.txt')
c100_labels = _read_labels('cifar100_labels.txt')

In [5]:
def load_model(model_name, results_final_path, dataset):
    """Build a CCT model for a trained run and load its best checkpoint.

    Args:
        model_name: Name of a run directory inside ``results_final_path``. It must
            contain ``args.yaml`` and ``model_best.pth.tar``.
        results_final_path: Directory that holds all run directories.
        dataset: ``'CIFAR-10'`` or ``'CIFAR-100'``. Used as the fallback class
            count when ``args.yaml`` has no ``num_classes`` entry.

    Returns:
        The model with the checkpoint weights loaded, moved to ``device`` and
        switched to eval mode.

    Raises:
        ValueError: if ``args.yaml`` names no architecture, or names one that
            ``src.cct`` does not define.
    """
    num_classes = {'CIFAR-10': 10, 'CIFAR-100': 100}
    run_dir = os.path.join(results_final_path, model_name)

    with open(os.path.join(run_dir, "args.yaml"), "r") as file:
        # An empty YAML file loads as None; treat it as "no settings".
        args = yaml.safe_load(file) or {}

    # args.yaml's "model" entry is the name of a constructor function in src.cct.
    arch_name = args.get("model")
    if not arch_name:
        raise ValueError(f"{run_dir}/args.yaml does not specify a 'model' architecture")

    model_init_function = getattr(cct, arch_name, None)
    if model_init_function is None:
        raise ValueError(f"Architecture {arch_name!r} (from {run_dir}) is not defined in src.cct")

    model = model_init_function(
        pretrained=False,
        progress=False,
        img_size=args.get("img_size", 32),
        positional_embedding="learnable",
        num_classes=args.get("num_classes", num_classes[dataset]),
    )

    checkpoint_path = os.path.join(run_dir, "model_best.pth.tar")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Checkpoints are either a training dict with a "state_dict" key or a bare state dict.
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model

def attnmap(model_name, image):
    """Render the attention map for ``image`` with the selected model.

    Delegates to ``attnmapgradio.gen_maps``. Only the runs listed in
    ``attn_models`` are advertised to the user as supporting attention maps.

    Raises:
        gr.Error: if no model is selected or no image has been uploaded, so the UI
            shows a readable message instead of failing inside ``gen_maps``.
    """
    if not model_name or image is None:
        raise gr.Error("Select a model and upload an image before requesting an attention map.")
    return attnmapgradio.gen_maps(model_name, image)


def predict(image, model_name, dataset):
    """Classify one image with the selected model.

    Args:
        image: Picture from the ``gr.Image`` input, or None when nothing has
            been uploaded yet.
        model_name: Run directory name under ``results_final_path``.
        dataset: ``'CIFAR-10'`` or ``'CIFAR-100'``; selects the label list.

    Returns:
        Dict mapping class label -> softmax probability (rounded to 2 decimals)
        for the most likely classes (at most 5). This is the format ``gr.Label``
        expects.

    Raises:
        gr.Error: if no image was uploaded or no model is selected, so the UI
            shows a readable message instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if not model_name:
        raise gr.Error("Please select a model first.")

    dataset_labels = {'CIFAR-10': c10_labels, 'CIFAR-100': c100_labels}[dataset]

    # NOTE(review): the model is rebuilt from disk on every click; consider caching it.
    model = load_model(model_name, results_final_path, dataset)
    inp = transforms.ToTensor()(image).unsqueeze(0).to(device)

    # Resize the upload to 32x32, the default img_size load_model builds models with.
    # NOTE(review): no mean/std normalization is applied here; confirm this matches
    # the transform used during training.
    inp = torch.nn.functional.interpolate(inp, size=32, mode='bilinear', align_corners=False)

    with torch.no_grad():
        probs = torch.softmax(model(inp), dim=1).flatten()
        # Never ask for more classes than the model outputs.
        topk_values, topk_indices = torch.topk(probs, min(5, probs.numel()))

    return {
        dataset_labels[idx]: round(prob, 2)
        for prob, idx in zip(topk_values.tolist(), topk_indices.tolist())
    }

# Runs that support attention maps; this list is announced to the user before launch.
attn_models = [
    "Bryan_full_cct_6_3x1_32",
    "Bryan_full_cct_7_3x1_32_c100",
    "Bryan_full_cct_6_3x1_32_c100_DynEmbed",
    "Bryan_full_cct_7_3x1_32_DynEmbed_TempScaleAttn",
]


def change_models(choice):
    """Refresh the model dropdown for the chosen dataset and unlock the inputs.

    Args:
        choice: ``'CIFAR-10'`` or ``'CIFAR-100'`` from the dataset radio button.

    Returns:
        Updated components, in order: the model dropdown (listing that dataset's
        runs, first one preselected), the image input and the submit button, all
        made interactive.
    """
    model_list = {'CIFAR-10': c10_models, 'CIFAR-100': c100_models}[choice]
    # If a dataset has no trained runs, leave the dropdown empty instead of
    # raising IndexError on model_list[0].
    default_model = model_list[0] if model_list else None
    return (
        gr.Dropdown(model_list, value=default_model, interactive=True),
        gr.Image(interactive=True),
        gr.Button(interactive=True),
    )


with gr.Blocks() as demo:
    with gr.Row():
        # Inputs column: every control starts disabled until a dataset is picked.
        with gr.Column():
            dataset_choice = gr.Radio(["CIFAR-10", "CIFAR-100"], label="Select a dataset:")
            model_choice = gr.Dropdown(label="Select a model:", interactive=False)
            image_choice = gr.Image(label="Input Image", interactive=False)
            submit_button = gr.Button(interactive=False)
            dataset_choice.input(fn=change_models, inputs=dataset_choice, outputs=[model_choice, image_choice, submit_button])

        # Outputs column: top-5 prediction plus the optional attention map.
        with gr.Column():
            pred_labels = gr.Label(num_top_classes=5, label="Prediction")
            submit_button.click(predict, inputs=[image_choice, model_choice, dataset_choice], outputs=pred_labels)
            # Disabled like the other controls; clicking it before a dataset (and so a
            # model) is chosen would otherwise call attnmap with no model selected.
            attn_button = gr.Button("Show Attention Map", interactive=False)
            attention_map = gr.Image(label="Attention Map")
            attn_button.click(attnmap, inputs=[model_choice, image_choice], outputs=attention_map)
            dataset_choice.input(fn=lambda: gr.Button(interactive=True), inputs=None, outputs=attn_button)


# Tell the user up front which runs can produce attention maps, then start the app.
attn_note = f"Take note that only the following models support attention maps: {attn_models}"
print(attn_note)

demo.launch()


Take note that only the following models support attention maps: ['Bryan_full_cct_6_3x1_32', 'Bryan_full_cct_7_3x1_32_c100', 'Bryan_full_cct_6_3x1_32_c100_DynEmbed', 'Bryan_full_cct_7_3x1_32_DynEmbed_TempScaleAttn']
Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




torch.Size([1, 256, 256]) torch.Size([1, 256, 256])
Model mapped successfully: Bryan_full_cct_7_3x1_32_c100_DynEmbed -> cct_DynEmbed_7_3x1_32
Model mapped successfully: Bryan_full_cct_7_3x1_32_c100_SETok_fixed -> cct_SETok_7_3x1_32_c100
Model mapped successfully: Bryan_full_cct_7_3x1_32_c100_DynEmbed_TempScaleAttnFactor -> cct_DynEmbedTempScaleAttnFactor_7_3x1_32
Model mapped successfully: Bryan_full_cct_2_3x2_32_DynEmbed_TempScaleAttn -> cct_DynEmbedTempScaleAttn_2_3x2_32
Model mapped successfully: Bryan_full_cct_7_3x1_32_DynEmbed -> cct_DynEmbed_7_3x1_32
Model mapped successfully: full_cct_7_3x1_32 -> cct_7_3x1_32
Model mapped successfully: Bryan_full_cct_2_3x2_32_c100_DynEmbed -> cct_DynEmbed_2_3x2_32_c100
Model mapped successfully: Bryan_full_cct_6_3x1_32_DynEmbed -> cct_DynEmbed_6_3x1_32
Model mapped successfully: Bryan_full_cct_7_3x1_32_DynEmbed_TempScaleAttnFactor -> cct_DynEmbedTempScaleAttnFactor_7_3x1_32
Model mapped successfully: Bryan_full_cct_6_3x1_32 -> cct_6_3x1_32
Model