In [None]:
from easydict import EasyDict

# Model configuration consumed by Clip4SbsrModel.
# EasyDict wraps the nested dicts so keys can also be read as attributes
# (e.g. model_config.lora.use_lora).
# NOTE(review): save_path names a LoRA run ("13_lora_ams_tcl") while
# lora.use_lora is False — confirm this matches how the checkpoint was trained.
model_config = EasyDict(
    dict(
        backbone="./hf_model/models--openai--clip-vit-base-patch32",  # presumably a local Hugging Face cache dir
        save_path="./ckpt/CLIP4SBSR_v3/13_lora_ams_tcl",
        lr_model=1.0e-5,
        loss_type="ams+tcl",
        classifier=dict(
            alph=12,
            feat_dim=512,
            num_classes=67,
        ),
        lora=dict(
            use_lora=False,
            lora_rank=32,
        ),
        prompt=dict(
            use_prompt=False,
            shared_prompt=False,
            num_prompts=3,
            prompt_dim=768,  # dim_embedding
            lr_prompt=1.0e-5,
        ),
    )
)


# Test-set directories (SHREC14_ZS2 split): sketch images and rendered views.
test_sketch_datadir = "/lizhikai/workspace/clip4sbsr/data/SHREC14_ZS2/14_sketch_test_picture"
test_view_datadir = "/lizhikai/workspace/clip4sbsr/data/SHREC14_ZS2/14_view_render_test_img"

# DataLoader settings for the test run.
batch_size = 32
num_workers = 6

In [None]:
import sys
sys.path.append('.')

from model.clip_model import Clip4SbsrModel
from dataset.clip_dataset import Clip4SbsrDataset

import lightning as L
from lightning.pytorch.loggers import WandbLogger

from torchvision import transforms
from torch.utils.data import DataLoader, random_split

import yaml
import argparse

import wandb

import os
import torch
import numpy as np
import random

# Seed Python's `random`, NumPy and torch RNGs via Lightning's helper.
# NOTE(review): the original cell called `setup_seed(config.setting.seed)`, but
# neither `setup_seed` nor `config` is defined in this notebook (NameError).
# Set SEED to the value used for training if results must be compared 1:1.
SEED = 42
L.seed_everything(SEED)

# CLIP image normalisation statistics (mean, std) — these are CLIP's values,
# not the ImageNet ones.
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]


def _build_eval_transform():
    """Return the deterministic preprocessing pipeline for evaluation images.

    Resizes the shorter side to 224 and center-crops to 224x224, so every sample
    in a batch has the same spatial shape. It then converts to a tensor and
    applies CLIP normalisation. The original pipeline included
    ``TrivialAugmentWide``; it is omitted on purpose because random test-time
    augmentation makes retrieval metrics non-reproducible.
    """
    return transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(CLIP_MEAN, CLIP_STD),
    ])


# Sketches and rendered views share the same evaluation preprocessing.
sketch_transform = _build_eval_transform()
view_transform = _build_eval_transform()

clip_model = Clip4SbsrModel(model_config)

# max_epochs / accumulate_grad_batches only affect training (trainer.fit);
# they have no effect on the trainer.test() call below.
# NOTE(review): confirm what `logger=None` does in the installed Lightning
# version; `logger=False` disables logging explicitly.
trainer = L.Trainer(max_epochs=0,
                    logger=None,
                    accumulate_grad_batches=1)

test_dataset = Clip4SbsrDataset(test_sketch_datadir, sketch_transform,
                                test_view_datadir, view_transform)
test_dataloader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             shuffle=False,  # keep a fixed order for retrieval evaluation
                             num_workers=num_workers)

# NOTE(review): no fine-tuned weights are loaded here (the original
# `clip_model.load_checkpoint()` call was commented out), so this evaluates the
# model exactly as constructed from `model_config`. To evaluate a trained run,
# load the checkpoint from model_config.save_path or pass `ckpt_path=` to test().
trainer.test(model=clip_model,
             dataloaders=test_dataloader)

