In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.transforms.functional import normalize, resize, to_pil_image
from torchvision.transforms import ToPILImage
import torchvision.utils as vutils

# from torchcam.methods import LayerCAM, SmoothGradCAMpp
# from torchcam.utils import overlay_mask

import clip

import argparse
import os
import glob
import matplotlib.pyplot as plt

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from PIL import Image


from tqdm import tqdm
from itertools import cycle

from models.resnet import CustomResNet
from models.projector import ProjectionHead
from domainnet_data import DomainNetDataset, get_domainnet_loaders, get_data_from_saved_files
from utils_proj import SimpleDINOLoss, compute_accuracy, compute_similarities, plot_grad_flow, plot_confusion_matrix
from prompts.FLM import generate_label_mapping_by_frequency, label_mapping_base
from models.resnet import CustomClassifier, CustomResNet
import umap

# Module-level converter from image tensors / ndarrays to PIL images (torchvision ToPILImage).
to_pil = ToPILImage()


In [2]:
def get_dataset(data_name, domain_name, train_transforms, test_transforms, clip_transform, data_dir='../data'):
    """Build the train/validation datasets and the class-name list for ``data_name``.

    Args:
        data_name: Either ``'imagenet'`` or ``'domainnet'``.
        domain_name: DomainNet domain to load (unused for ImageNet).
        train_transforms: Transform applied to training images.
        test_transforms: Transform applied to validation/test images.
        clip_transform: Passed to ``DomainNetDataset`` as ``transform2``
            (unused for ImageNet).
        data_dir: Root directory containing the datasets.

    Returns:
        ``(train_dataset, val_dataset, class_names)``.

    Raises:
        ValueError: If ``data_name`` is not a supported dataset.
    """
    if data_name == 'imagenet':
        # `dset` was referenced here but never imported, so this branch always
        # failed with NameError; import it where it is needed.
        import torchvision.datasets as dset

        train_dataset = dset.ImageFolder(root=f'{data_dir}/imagenet_train_examples', transform=train_transforms)
        val_dataset = dset.ImageFolder(root=f'{data_dir}/imagenet_val_examples', transform=test_transforms)
        class_names = train_dataset.classes

    elif data_name == 'domainnet':
        train_dataset = DomainNetDataset(root_dir=data_dir, domain=domain_name,
                                         split='train', transform=train_transforms, transform2=clip_transform)
        val_dataset = DomainNetDataset(root_dir=data_dir, domain=domain_name,
                                       split='test', transform=test_transforms, transform2=clip_transform)
        class_names = train_dataset.class_names

    else:
        # Previously fell through and raised an opaque UnboundLocalError.
        raise ValueError(f"Unsupported data_name {data_name!r}; expected 'imagenet' or 'domainnet'.")

    return train_dataset, val_dataset, class_names

@torch.no_grad()
def get_embeddings(val_loader, classifier, clip_model, clip_text_encodings, projector, device,
                   proj_clip=None, max_batches=201):
    """Collect CLIP, projector, classifier and CLIP-text embeddings from a loader.

    Runs without gradients over at most ``max_batches`` batches of ``val_loader``.
    ``images`` go through ``classifier(..., return_features=True)`` and
    ``clip_images`` through ``clip_model.encode_image``. The projector is applied
    to the CLIP image embeddings when ``proj_clip`` is true (PLUMBER) and to the
    classifier features otherwise (LIMBER). For every sample, the row of
    ``clip_text_encodings`` selected by its label is collected as well.

    Args:
        val_loader: Iterable of ``(images, labels, clip_images)`` batches.
        classifier: Model returning ``(logits, features)`` when called with
            ``return_features=True``.
        clip_model: Object exposing ``encode_image``.
        clip_text_encodings: Tensor indexed by integer class label
            (presumably one row per class -- TODO confirm).
        projector: Callable mapping embeddings into the projection space.
        device: Device that batches and text encodings are moved to.
        proj_clip: True to project CLIP embeddings, False to project classifier
            features. ``None`` falls back to the notebook-global ``PROJ_CLIP``
            for backward compatibility.
        max_batches: Maximum number of batches to process; ``None`` processes
            the whole loader. The default of 201 reproduces the original
            ``break`` after batch index 200.

    Returns:
        ``(clip_embeddings, proj_embeddings, classifier_embeddings,
        clip_text_embeddings)``, each concatenated along dim 0 and on the CPU.

    Raises:
        ValueError: If the loader yields no batches.
    """
    from itertools import islice

    if proj_clip is None:
        # Backward compatibility: previously this was read implicitly from a
        # global defined in a later notebook cell.
        proj_clip = PROJ_CLIP

    clip_text_encodings = clip_text_encodings.to(device)

    all_clip_embeddings = []
    all_proj_embeddings = []
    all_classifier_embeddings = []
    all_clip_text_embeddings = []

    # islice checks the limit before fetching, so no batch beyond the limit is loaded.
    for images_batch, labels, images_clip_batch in islice(val_loader, max_batches):
        images_batch = images_batch.to(device)
        images_clip_batch = images_clip_batch.to(device)
        labels = labels.to(device)

        _, classifier_embeddings = classifier(images_batch, return_features=True)
        clip_image_embeddings = clip_model.encode_image(images_clip_batch)
        # Align dtypes (e.g. if CLIP runs in half precision) before projecting.
        clip_image_embeddings = clip_image_embeddings.type_as(classifier_embeddings)

        if proj_clip:  # PLUMBER: project CLIP image embeddings
            proj_embeddings = projector(clip_image_embeddings)
        else:  # LIMBER: project classifier features
            proj_embeddings = projector(classifier_embeddings)

        all_clip_embeddings.append(clip_image_embeddings.cpu())
        all_proj_embeddings.append(proj_embeddings.cpu())
        all_classifier_embeddings.append(classifier_embeddings.cpu())
        # Previously kept on `device`: outputs then lived on mixed devices and
        # GPU memory grew with every batch.
        all_clip_text_embeddings.append(clip_text_encodings[labels].cpu())

    if not all_clip_embeddings:
        raise ValueError("get_embeddings: val_loader yielded no batches.")

    return (
        torch.cat(all_clip_embeddings, dim=0),
        torch.cat(all_proj_embeddings, dim=0),
        torch.cat(all_classifier_embeddings, dim=0),
        torch.cat(all_clip_text_embeddings, dim=0),
    )

def build_classifier(classifier_name, num_classes, pretrained=False, checkpoint_path=None):
    """Instantiate a classifier and return it together with its transforms.

    Args:
        classifier_name: ``'vit_b_16'`` or ``'swin_b'`` (built with
            ``CustomClassifier``), or ``'resnet18'`` or ``'resnet50'`` (built
            with ``CustomResNet``).
        num_classes: Number of output classes; only forwarded to ``CustomResNet``.
        pretrained: Forwarded as ``use_pretrained``.
        checkpoint_path: Optional checkpoint file whose ``'model_state_dict'``
            entry is loaded into the model.

    Returns:
        ``(classifier, train_transform, test_transform)``, where the transforms
        are the classifier's ``train_transform`` / ``test_transform`` attributes.

    Raises:
        ValueError: If ``classifier_name`` is not supported.
    """
    if classifier_name in ('vit_b_16', 'swin_b'):
        # NOTE(review): num_classes is not forwarded here -- confirm that
        # CustomClassifier does not need it.
        classifier = CustomClassifier(classifier_name, use_pretrained=pretrained)
    elif classifier_name in ('resnet18', 'resnet50'):
        classifier = CustomResNet(classifier_name, num_classes=num_classes, use_pretrained=pretrained)
    else:
        # Previously fell through and raised an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported classifier_name {classifier_name!r}; expected one of "
            "'vit_b_16', 'swin_b', 'resnet18', 'resnet50'."
        )

    if checkpoint_path:
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the tensors into the model's own parameters.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        classifier.load_state_dict(checkpoint['model_state_dict'])

    return classifier, classifier.train_transform, classifier.test_transform


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Experiment configuration ------------------------------------------------
classifier_name = "resnet50"
num_classes = 345
dataset_name = "domainnet"
domain_name = "clipart"
# True: project CLIP image embeddings (PLUMBER); False: project classifier features (LIMBER).
PROJ_CLIP = True

# --- Paths -------------------------------------------------------------------
data_dir = "/usr/workspace/KDML/DomainNet"
prompt_embeddings_pth = "/usr/workspace/KDML/DomainNet/CLIP_ViT-B-32_text_encodings.pt"
checkpoint_path = f"{data_dir}/best_checkpoint.pth"

# NOTE(review): this path names a 'painting' projector while domain_name is
# 'clipart' -- confirm the mismatch is intentional.
projector_weights_path = '/usr/workspace/KDML/ood_detect/checkpoints/painting_test_projector/best_projector_weights.pth'
# Alternative projector checkpoint:
#projector_weights_path = "/usr/workspace/KDML/ood_detect/resnet50_domainnet_real/plumber/resnet50domain_{sketch}_lr_0.1_is_mlp_False/projector_weights_final.pth"

# Per-domain PLUMBER projector checkpoints (paths relative to the working directory).
domainnet_domains_projector = {
    "real": 'logs/classifier/domainnet/plumber/resnet50domain_real_lr_0.1_is_mlp_False/best_projector_weights.pth',
    "sketch": "logs/classifier/domainnet/plumber/resnet50domain_sketch_lr_0.1_is_mlp_False/best_projector_weights.pth",
    "painting": "logs/classifier/domainnet/plumber/resnet50domain_painting_lr_0.1_is_mlp_False/best_projector_weights.pth",
    "clipart": "logs/classifier/domainnet/plumber/resnet50domain_clipart_lr_0.1_is_mlp_False/best_projector_weights.pth",
}

# Load class names (one per line). Blank lines, such as a trailing empty line,
# are skipped so they do not become '' classes that desync from num_classes.
with open(os.path.join(data_dir, 'class_names.txt'), 'r', encoding='utf-8') as f:
    class_names = [line.strip() for line in f if line.strip()]