In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.transforms.functional import normalize, resize, to_pil_image
from torchvision.transforms import ToPILImage
import torchvision.utils as vutils

from torchcam.methods import LayerCAM, SmoothGradCAMpp
from torchcam.utils import overlay_mask

import clip

import argparse
import os
import glob
import matplotlib.pyplot as plt
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from PIL import Image


from tqdm import tqdm
from itertools import cycle

from models.resnet import CustomResNet
from models.projector import ProjectionHead, VisualTransformer
from domainnet_data import DomainNetDataset, get_domainnet_loaders, get_data_from_saved_files
from utils import SimpleDINOLoss, compute_accuracy, compute_similarities, plot_grad_flow, plot_confusion_matrix
from prompts.FLM import generate_label_mapping_by_frequency, label_mapping_base


# Tensor -> PIL.Image converter (same class as torchvision.transforms.ToPILImage).
to_pil = transforms.ToPILImage()


In [None]:
def load_image(file_path):
    """
    Load an image from disk as an RGB NumPy array with values in [0, 255].

    Non-RGB images (e.g. grayscale, palette, RGBA) are converted to RGB first.
    The underlying file handle is closed before the function returns.

    Args:
        file_path (str): Path to the image file.

    Returns:
        np.ndarray: Image as a NumPy array with values in the range [0, 255].
    """
    # Context manager so the file handle is released promptly instead of
    # whenever the PIL object happens to be garbage-collected.
    with Image.open(file_path) as image:
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        # np.array forces PIL to read the pixel data while the file is still open.
        image_array = np.array(rgb_image)

    # Ensure values are in the range [0, 255] (a cheap safeguard).
    return np.clip(image_array, 0, 255)


In [None]:

def unnormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Invert a per-channel ``Normalize(mean, std)`` transform.

    Computes ``tensor * std + mean`` channel-wise and returns a NEW tensor.
    The input is left untouched: in-place ops would silently overwrite the
    caller's tensor and raise on leaf tensors that require grad.

    Args:
        tensor (torch.Tensor): Normalized image(s) with the channel axis third
            from the end, e.g. (C, H, W) or (N, C, H, W).
        mean (sequence of float): Per-channel means used for normalization.
            The defaults match the values used in ``Resnet_transform``.
        std (sequence of float): Per-channel standard deviations used for
            normalization. The defaults match ``Resnet_transform``.

    Returns:
        torch.Tensor: Un-normalized tensor with the same shape, dtype and device.
    """
    mean_tensor = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
    std_tensor = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
    # View the stats as (C, 1, 1) so they broadcast over spatial (and batch) dims.
    return tensor * std_tensor[:, None, None] + mean_tensor[:, None, None]

def save_image(tensor, file_name):
    """
    Write an image tensor to ``file_name`` via ``torchvision.utils.save_image``.

    The tensor is detached from autograd, moved to the CPU, and clamped to
    [0, 1] before saving, so out-of-range values are clipped.
    """
    cpu_tensor = tensor.detach().cpu()
    vutils.save_image(cpu_tensor.clamp(0, 1), file_name)


In [None]:

# transform = transforms.Compose(
#     [transforms.Resize((224, 224)),
#     # transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                         std=[0.229, 0.224, 0.225])])

# Resize (shorter side -> 256) then take the central 224x224 crop.
CLIP_custom_transform = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224)]
)

# PIL image -> float tensor -> per-channel normalization.
Resnet_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def gaussian_noise(x, severity=1, rng=None):
    """
    Add zero-mean Gaussian pixel noise to an image with values in [0, 255].

    Noise std, in [0, 1] pixel units, per severity level:
    1 -> 0.0 (no noise), 2 -> 0.08, 3 -> 0.12, 4 -> 0.18, 5 -> 0.26, 6 -> 0.38.

    Args:
        x: Anything ``np.array`` accepts (e.g. a PIL image or uint8 array)
            with values in [0, 255].
        severity (int): Noise level in 1..6. Values outside that range raise
            ``ValueError`` rather than wrapping around the table via negative
            indexing (``severity=0`` would otherwise select the strongest noise).
        rng (np.random.Generator, optional): Random source for reproducible
            noise. Defaults to NumPy's global RNG (``np.random.normal``).

    Returns:
        np.ndarray: float array with the input's shape, values in [0, 255].
        Callers must round/cast before building a uint8 image.

    Raises:
        ValueError: If ``severity`` is outside 1..6.
    """
    noise_levels = (0., .08, .12, 0.18, 0.26, 0.38)
    if not 1 <= severity <= len(noise_levels):
        raise ValueError(
            f"severity must be in 1..{len(noise_levels)}, got {severity!r}")
    c = noise_levels[severity - 1]

    x = np.array(x) / 255.
    normal = np.random.normal if rng is None else rng.normal
    return np.clip(x + normal(size=x.shape, scale=c), 0, 1) * 255


# PIL_image = CLIP_custom_transform(Image.open('./data/domainnet_v1.0/real/toothpaste/real_318_000284.jpg'))

# Load the sample image and corrupt it with Gaussian noise.
# NOTE: severity=1 maps to a noise std of 0.0 in gaussian_noise, so this is
# effectively the clean image; raise `severity` to add actual noise.
im = load_image('./data/domainnet_v1.0/real/toothpaste/real_318_000284.jpg')
# Round to nearest before casting: astype(np.uint8) truncates, which biases
# pixels downward and turns float round-off such as 99.9999... into 99.
noisy_im = np.rint(gaussian_noise(im, severity=1)).astype(np.uint8)
PIL_image = CLIP_custom_transform(Image.fromarray(noisy_im))

# l = torch.from_numpy(np.array([317]))
# valset = torch.utils.data.TensorDataset(image, l)
# val_loader = torch.utils.data.DataLoader(valset, batch_size=1, shuffle=False)

In [None]:
# Paths for the DomainNet (real) experiment.
base_dir = "logs/classifier/resnet50_domainnet_real"
data_dir = "data/domainnet_v1.0"
prompt_embeddings_pth = "prompts/CLIP_RN50_text_embeddings.pth"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clip_model, preprocess = clip.load("RN50", device=device)

print(preprocess)
clip_model.eval()

# Load class names (one per line) from a text file.
with open(os.path.join(data_dir, 'class_names.txt'), 'r') as f:
    class_names = [line.strip() for line in f]

loaders, _ = get_domainnet_loaders("real", batch_size=10, data_dir=data_dir)

train_loader = loaders['train']
val_loader = loaders['test']

# map_location puts the text embeddings on the same device as the CLIP image
# features they are compared with, and lets a GPU-saved file load on CPU.
text_encodings = torch.load(prompt_embeddings_pth, map_location=device)[0]

In [None]:
# Preprocess the image for clip
image_CLIP = preprocess(PIL_image).unsqueeze(0).to(device)

# Encode the image using CLIP encoder_image
clip_image_features = clip_model.encode_image(image_CLIP)

# Compute similarities between image embeddings and text encodings
orig_similarities_ = compute_similarities(clip_image_features, text_encodings, mode="cosine")
orig_prob_ = F.softmax(orig_similarities_, dim=-1)
orig_predictions_ = torch.argmax(orig_prob_, dim=-1)


print(f"Original Zero-shot prediction: {class_names[orig_predictions_[0].item()]}")

In [None]:


# Distribution of CLIP zero-shot class probabilities for the test image.
fig, ax = plt.subplots()
ax.hist(orig_prob_[0].detach().cpu().numpy(), bins=10)
ax.set(title="CLIP zero-shot class probabilities",
       xlabel="Softmax probability", ylabel="Number of classes")
plt.show()

In [None]:


# Load your trained model from checkpoint
# Load the DomainNet-real ResNet-50 classifier and its trained projection head.
# Everything follows `device` from the setup cell; map_location lets GPU-saved
# checkpoints load on any device.
checkpoint = torch.load(os.path.join(base_dir, 'best_checkpoint.pth'), map_location=device)

# num_classes must match the classifier head stored in the checkpoint.
resnet_model = CustomResNet(model_name='resnet50', num_classes=345, use_pretrained=True)
resnet_model.load_state_dict(checkpoint['model_state_dict'])
resnet_model.eval()
print(f"Loaded model from epoch {checkpoint['epoch']}")
resnet_model.to(device)

projector = ProjectionHead(input_dim=2048, output_dim=1024).to(device)
projector.load_state_dict(torch.load(
    os.path.join(base_dir,
                 'projection_default_prompt_feat_sim0.1_distill1_DN_mapping1_scaled_logits',
                 'best_projector_weights.pth'),
    map_location=device))
projector.eval()


In [None]:

# Classify the test image with the ResNet head and, through the projection
# head, against the CLIP text embeddings. Inference only, so no autograd graph.
with torch.no_grad():
    resnet_images = Resnet_transform(PIL_image).unsqueeze(0).to(device)
    # Get the ResNet predictions
    resnet_logits, resnet_embeddings = resnet_model(resnet_images, return_features=True)
    probs_from_resnet = F.softmax(resnet_logits, dim=-1)
    resnet_predictions = torch.argmax(probs_from_resnet, dim=-1)

    # Project the ResNet embeddings so they can be compared with text_encodings.
    proj_embeddings = projector(resnet_embeddings)
    # Compute the predictions using the projected embeddings
    similarities = compute_similarities(proj_embeddings, text_encodings, mode="DN")
    probs_from_proj = F.softmax(similarities, dim=-1)
    proj_predictions = torch.argmax(probs_from_proj, dim=-1)

print(f"ResNet predictions: {class_names[resnet_predictions[0].item()]}")
print(f"Projected predictions: {class_names[proj_predictions[0].item()]}")


In [None]:
# Distribution of the ResNet classifier's class probabilities for the test image.
fig, ax = plt.subplots()
ax.hist(probs_from_resnet[0].detach().cpu().numpy(), bins=10)
ax.set(title="ResNet class probabilities",
       xlabel="Softmax probability", ylabel="Number of classes")
plt.show()

In [None]:
# Distribution of class probabilities from the projected embedding vs. CLIP text.
fig, ax = plt.subplots()
ax.hist(probs_from_proj[0].detach().cpu().numpy(), bins=10)
ax.set(title="Projected-embedding class probabilities",
       xlabel="Softmax probability", ylabel="Number of classes")
plt.show()

In [None]:
# Compare the L2-normalized projected image embedding with the CLIP text
# embedding of one class, dimension by dimension.
# NOTE(review): 300 is a hard-coded class index, while the test image comes from
# the 'toothpaste' folder; confirm this is the class you intend to compare with.
text_class_idx = 300

clip_image_features_norm = F.normalize(clip_image_features[0], dim=-1)
proj_embeddings_norm = F.normalize(proj_embeddings[0], dim=-1)
clip_text_features_norm = F.normalize(text_encodings[text_class_idx], dim=-1)

fig, ax = plt.subplots()
# ax.plot(clip_image_features_norm.detach().cpu().numpy(), label="CLIP")
ax.plot(proj_embeddings_norm.detach().cpu().numpy(), label="Projected")
ax.plot(clip_text_features_norm.detach().cpu().numpy(),
        label=f"Text ({class_names[text_class_idx]})")
ax.set(title="Projected image embedding vs. CLIP text embedding",
       xlabel="Embedding dimension", ylabel="Normalized value")
ax.legend()
plt.show()

In [None]:
# Collect embeddings for the first N_DOMAINNET_SAMPLES DomainNet-real training
# images from the ResNet backbone, the projection head, and the CLIP image
# encoder (with CLIP's own preprocessing and with the custom resize/crop first).
N_DOMAINNET_SAMPLES = 101  # sample indices 0..100

dataset = DomainNetDataset(root_dir='data/domainnet_v1.0', domain='real', split='train', transform=None)

all_clip_embeddings = []
all_custom_clip_embeddings = []
all_resnet_embeddings = []
all_proj_embeddings = []
with torch.no_grad():  # inference only; no autograd graph needed
    for idx in tqdm(range(min(N_DOMAINNET_SAMPLES, len(dataset))), desc="DomainNet embeddings"):
        image, _ = dataset[idx]

        # Resize + center-crop before the ResNet so its input matches the 224x224
        # crop used for the single-image prediction earlier in the notebook.
        resnet_input = Resnet_transform(CLIP_custom_transform(image)).unsqueeze(0).to(device)
        _, resnet_features = resnet_model(resnet_input, return_features=True)
        all_resnet_embeddings.append(resnet_features.cpu())

        # Project the ResNet features with the trained projection head.
        all_proj_embeddings.append(projector(resnet_features).cpu())

        # CLIP preprocessing applied to the raw image.
        clip_input = preprocess(image).unsqueeze(0).to(device)
        all_clip_embeddings.append(clip_model.encode_image(clip_input).cpu())

        # CLIP preprocessing applied after the custom Resize(256)/CenterCrop(224).
        custom_clip_input = preprocess(CLIP_custom_transform(image)).unsqueeze(0).to(device)
        all_custom_clip_embeddings.append(clip_model.encode_image(custom_clip_input).cpu())

all_clip_embeddings = torch.cat(all_clip_embeddings, dim=0)
all_proj_embeddings = torch.cat(all_proj_embeddings, dim=0)
all_custom_clip_embeddings = torch.cat(all_custom_clip_embeddings, dim=0)
all_resnet_embeddings = torch.cat(all_resnet_embeddings, dim=0)

In [None]:
from utils import plot_umap_embeddings

# UMAP of CLIP image, projected image, and CLIP text embeddings.
plot_umap_embeddings(
    all_clip_embeddings,
    all_proj_embeddings,
    text_encodings.detach().cpu(),
    labels=['CLIP image', 'Projected image', 'CLIP Text'],
)

In [None]:
from torchvision.datasets import CIFAR100


# ResNet preprocessing for CIFAR-100: upscale the 32x32 images to a 224x224
# crop and apply the same normalization statistics as Resnet_transform.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# The embedding loop below uses `clip_model`, `preprocess` and `device` from the
# setup cell, so the same RN50 weights are not loaded a second time here.
# transform=None keeps raw PIL images so both `transform` and CLIP's `preprocess`
# can be applied per sample. Samples are indexed directly: the default
# DataLoader collate function cannot batch PIL images.
dataset = CIFAR100(root="./data", download=True, transform=None, train=False)



In [None]:
# ResNet-50 backbone with a 100-way head for CIFAR-100.
# NOTE(review): no CIFAR-100 classifier checkpoint is loaded here, so the backbone
# presumably only has the weights CustomResNet provides via use_pretrained=True.
# Confirm this matches the backbone projector_CIFAR100 was trained on.
resnet_model = CustomResNet(model_name='resnet50', num_classes=100, use_pretrained=True).to(device)
# Feature extraction must run in eval mode: in train mode BatchNorm normalizes
# with per-batch statistics (a batch of 1 below) and updates its running stats.
resnet_model.eval()

projector_CIFAR100 = ProjectionHead(input_dim=2048, output_dim=1024).to(device)
projector_imagenet = ProjectionHead(input_dim=2048, output_dim=1024).to(device)

# Load projector weights from checkpoint
projector_CIFAR100.load_state_dict(torch.load(
    'logs/classifier/resnet50_cifar100/projection_default_prompt_DN_mapping1_scaled_logits/best_projector_weights.pth',
    map_location=device))
projector_CIFAR100.eval()
projector_imagenet.load_state_dict(torch.load(
    'logs/classifier/imagenet/contrastive_loss/resnet50_imagenet_None/projection_default_prompt_feat_sim0_distill1_cosine_mapping1_scaled_logits/best_projector_weights.pth',
    map_location=device))
projector_imagenet.eval()

In [None]:

# Embed the first N_CIFAR_SAMPLES CIFAR-100 test images with both projection
# heads (on top of the same ResNet features) and with the CLIP image encoder.
N_CIFAR_SAMPLES = 301  # sample indices 0..300

all_cifar_100_projections = []
all_imagenet_projections = []
all_clip_embeddings = []

with torch.no_grad():  # inference only; no autograd graph needed
    for idx in tqdm(range(min(N_CIFAR_SAMPLES, len(dataset))), desc="CIFAR-100 embeddings"):
        image, _ = dataset[idx]

        resnet_input = transform(image).unsqueeze(0).to(device)
        # Get the ResNet features
        _, resnet_features = resnet_model(resnet_input, return_features=True)

        # Project the same features with each projection head.
        all_cifar_100_projections.append(projector_CIFAR100(resnet_features).cpu())
        all_imagenet_projections.append(projector_imagenet(resnet_features).cpu())

        # Preprocess the image for CLIP
        clip_input = preprocess(image).unsqueeze(0).to(device)
        all_clip_embeddings.append(clip_model.encode_image(clip_input).cpu())

# Each piece is already on the CPU and carries no grad.
all_cifar_100_projections = torch.cat(all_cifar_100_projections, dim=0)
all_imagenet_projections = torch.cat(all_imagenet_projections, dim=0)
all_clip_embeddings = torch.cat(all_clip_embeddings, dim=0)


In [None]:
from utils import plot_umap_embeddings

# UMAP of CIFAR-100 projections, ImageNet projections, and CLIP image embeddings.
plot_umap_embeddings(
    all_cifar_100_projections,
    all_imagenet_projections,
    all_clip_embeddings,
    labels=['CIFAR100 Projection', 'Imagenet Projection', 'CLIP Image Encoder'],
)

In [None]:
from segment_anything import sam_model_registry, SamPredictor
import torch
import cv2
# Load SAM ViT-H and sanity-check its image encoder on a random batch.
sam = sam_model_registry["vit_h"](checkpoint="checkpoints/sam_vit_h_4b8939.pth").to(device)
predictor = SamPredictor(sam)

sam_transform = predictor.transform
sam_vit = sam.image_encoder

# predictor.transform is a ResizeLongestSide helper, which is not callable.
# apply_image_torch resizes a BCHW tensor so its longest side equals the
# predictor's target length.
test_image = sam_transform.apply_image_torch(torch.randn(5, 3, 224, 224)).to(device)

print(sam.image_encoder.img_size)
with torch.no_grad():  # a ViT-H forward pass with autograd enabled is memory-heavy
    print(sam_vit(test_image).shape)

In [None]:
# Compute the SAM image embedding for the sample DomainNet image.
sample_image_path = './data/domainnet_v1.0/real/toothpaste/real_318_000284.jpg'
image = cv2.imread(sample_image_path)
if image is None:  # cv2.imread returns None instead of raising on a bad path
    raise FileNotFoundError(f"Could not read image: {sample_image_path}")
# cv2 loads images in BGR order, while set_image assumes RGB unless told otherwise.
predictor.set_image(image, image_format="BGR")
image_embeddings = predictor.get_image_embedding()



In [None]:
# Shape of the embedding returned by predictor.get_image_embedding().
print(image_embeddings.shape)

In [None]:
from torchvision.models.feature_extraction import get_graph_node_names


# Print the module structure of the predictor's image encoder.
print(predictor.model.image_encoder)

In [None]:
import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names


# To assist you in designing the feature extractor you may want to print out
# the available nodes for resnet50.
m = resnet50()
# Trace the model built above instead of constructing a second resnet50().
train_nodes, eval_nodes = get_graph_node_names(m)

print(train_nodes)