In [None]:
import torch
import cv2
import matplotlib.pyplot as plt
from auxiliary.utils import readJson
import tqdm as notebook_tqdm
import numpy as np
from sklearn.decomposition import PCA

from segment_anything.segment_anything import SamPredictor, sam_model_registry

In [None]:
# Load the sample image once and derive the two colour-space views from it.
# cv2.imread returns channels in BGR order and returns None (instead of
# raising) when the file cannot be read, so both cases are handled explicitly.
IMAGE_PATH = 'data/trainNormal/0.png'
img_bgr = cv2.imread(IMAGE_PATH)
if img_bgr is None:
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")

# True RGB ordering, which is what matplotlib's imshow interprets 3-channel
# arrays as.
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
# LAB view, converted straight from the BGR original.
img_lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)

In [None]:

# Draw the two arrays next to each other. Note that imshow renders the LAB
# array's three channels as if they were colour channels, so the right-hand
# panel is a false-colour view.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, picture, label in zip(axes, (img_rgb, img_lab), ('RGB', 'LAB')):
    ax.imshow(picture)
    ax.set_title(label)
    ax.axis('off')
plt.show()

In [None]:
# Show the L, A and B channels separately in three subplots.
img_l = img_lab[:, :, 0]
img_a = img_lab[:, :, 1]

# Amplify the B channel by 3 around its neutral value of 128.
# Multiplying the uint8 channel directly wraps modulo 256, which produces
# wrap-around artefacts for strongly coloured pixels. Doing the arithmetic in
# int16 and clipping saturates instead.
# NOTE(review): assumes img_lab is an 8-bit LAB image, where OpenCV stores
# b as b + 128 -- confirm if the image is ever loaded with another depth.
b_offset = img_lab[:, :, 2].astype(np.int16) - 128
img_b = np.clip(128 + 3 * b_offset, 0, 255).astype(np.uint8)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
channel_panels = (
    (img_l, 'L channel'),
    (img_a, 'A channel'),
    (img_b, 'B channel'),
)
for ax, (channel, label) in zip(axes, channel_panels):
    ax.imshow(channel)
    ax.set_title(label)
    ax.axis('off')
plt.show()

In [None]:
# Build the Segment Anything model registered under "vit_b" from its
# checkpoint file, then wrap it in a predictor for embedding extraction.
samWeight = "segment_anything/SAMWeights/sam_vit_b_01ec64.pth"
sam_builder = sam_model_registry["vit_b"]
sam = sam_builder(checkpoint=samWeight)
predictor = SamPredictor(sam)

In [None]:
# Amplify the B channel by 3 around its neutral value of 128 to make it more
# visible. The arithmetic is done in int16 and clipped, because multiplying
# the uint8 channel directly wraps modulo 256 for strongly coloured pixels.
# NOTE(review): assumes img_lab is 8-bit LAB (b stored as b + 128); this array
# is not fed to SAM in this cell -- confirm whether it was meant to be.
enhanced_offset = img_lab[:, :, 2].astype(np.int16) - 128
img_lab_enhanced = np.clip(128 + 3 * enhanced_offset, 0, 255).astype(np.uint8)

# Run the SAM image encoder on each colour-space variant and keep its features.
# NOTE(review): SamPredictor.set_image interprets its input as RGB by default;
# the LAB array is passed through the same path unchanged.
images = (img_rgb, img_lab)
embeddings = []
for img in images:
    print(img.shape)
    with torch.no_grad():
        predictor.set_image(img)
        # predictor.features holds the encoder output for the image just set.
        embeddings.append(predictor.features)

In [None]:
# Display the shapes of the first two embeddings as a tuple.
tuple(embeddings[i].shape for i in (0, 1))

In [None]:
# Project every embedding onto its first three principal components so that it
# can be viewed as a 3-channel image. PCA is re-fitted for each embedding, so
# the component axes (and therefore colours) are not shared between images.
pca = PCA(n_components=3)
pca_images = []
for emb in embeddings:
    # Drop the leading axis; axis 0 is then treated as the feature axis and
    # the last two axes as the spatial grid.
    emb = emb.squeeze(0)
    grid_h, grid_w = emb.shape[-2], emb.shape[-1]
    # One row per spatial location, one column per feature.
    emb_reshaped = emb.reshape(emb.shape[0], -1).T.cpu().numpy()
    sam_pca = pca.fit_transform(emb_reshaped)
    # Rows were flattened from (height, width), so restore that same order.
    # The previous (width, height) order transposed non-square grids.
    sam_pca = sam_pca.reshape(grid_h, grid_w, 3)
    pca_images.append(sam_pca)
    

In [None]:
def _scale_for_display(img):
    """Min-max scale each channel of an (H, W, C) array into [0, 1].

    PCA scores are zero-centred floats, and imshow clips float RGB data to
    [0, 1], so unscaled scores would render largely black or saturated.
    Constant channels are mapped to 0 rather than dividing by zero.
    """
    lo = img.min(axis=(0, 1), keepdims=True)
    hi = img.max(axis=(0, 1), keepdims=True)
    span = np.where(hi > lo, hi - lo, 1.0)
    return (img - lo) / span


# Plot the PCA images in two subplots. Scaling is applied for display only,
# so pca_images itself is left untouched.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, pca_img, label in zip(axes, pca_images[:2], ('RGB PCA', 'LAB PCA')):
    ax.imshow(_scale_for_display(pca_img))
    ax.set_title(label)
    ax.axis('off')
plt.show()
