# Bursty Prompt Tuning

This notebook is a demo showing how to evaluate our trained BPT models.

## Import packages

In [13]:
import torch
from torch.utils.data import DataLoader

import Dataset
import Models
from collections import OrderedDict
from utils import accuracy, AverageMeter

## Prepare Datasets

We take CUB-200 as an example.

In [14]:
# Look up the CUB-200 builder in the project's Dataset module; calling it yields
# the training split, the validation split and the number of classes.
dataset_builder = Dataset.__dict__["CUB200"]
train_dataset, val_dataset, num_class = dataset_builder()
print(len(train_dataset), len(val_dataset), num_class)

5994 5794 200


## Prepare model
We take BPT-bilinear as an example.

In [15]:
# Checkpoint from the "length100-width75" ablation run; the constructor arguments
# below (num_prompts=100, channels=75) presumably must match how it was trained.
weight_path = "./run/ablation/length100-width75/ckpt.pth"

# Fall back to CPU so the demo can still run (slowly) on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Map to CPU first so the checkpoint loads regardless of the device it was saved on.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if ckpt.pth also
# pickles non-tensor objects (e.g. training args), this call may fail — confirm.
checkpoint = torch.load(weight_path, map_location="cpu")["model"]

model = Models.__dict__["MAE_bpt_vit_b"](
    drop_path_rate=0.0,
    global_pool=True,
    num_prompts=100,
    channels=75,
    num_classes=num_class,
)

# strict=False only *reports* missing/unexpected keys instead of raising. Missing keys
# would leave those layers randomly initialised and make the evaluated accuracy
# meaningless, so treat them as an error; unexpected keys are just reported.
msg = model.load_state_dict(checkpoint, strict=False)
print(msg)
if msg.missing_keys:
    raise RuntimeError(
        f"Checkpoint {weight_path} is missing weights for: {msg.missing_keys}"
    )
model.to(device)

<All keys matched successfully>


## Eval model

In [16]:
def testmodel(model, test_data):
    val_acc1 = AverageMeter()
    val_acc5 = AverageMeter()
    
    # model to evaluate mode
    model.eval()

    test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False,
                                 num_workers=8, pin_memory=True)

    with torch.no_grad():
        for step, (images, labels) in enumerate(test_dataloader):
            images, labels = images.cuda(), labels.cuda()
            # compute output
            pred = model(images)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(pred, labels, topk=(1, 5))

            val_acc1.update(acc1[0], images.size(0))
            val_acc5.update(acc5[0], images.size(0))
    
    return val_acc1.avg, val_acc5.avg

# testmodel returns (top-1 average, top-5 average); only top-1 is reported here.
acc = testmodel(model, val_dataset)
top1 = acc[0]
print("Acc: {}".format(top1))

Acc: 0.7785640358924866
