In [1]:
#| default_exp model_embedding.embedding_creation

# Create Embedding from model
> Extract per-level encoder embeddings from a trained U-Net model and visualize them with t-SNE or UMAP

In [1]:
#| hide
%load_ext autoreload
%autoreload 2

In [2]:
#| export
from cv_tools.core import *
from cv_tools.imports import *
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from umap import UMAP



In [3]:
#| export
import torch.nn as nn
import torch


In [13]:
#| export
from private_front_easy_pin_detection.pytorch_model_development import UnetManualMaxPoolOnly
from private_front_easy_pin_detection.dataloader_creation import *
from private_front_easy_pin_detection.model_eval.create_mask import *


In [27]:
# Root of the local data directory, read from the DATA_PATH environment
# variable so dataset locations are not tied to a single machine.
DATA_PATH = os.getenv('DATA_PATH')
if DATA_PATH is None:
    # Path(None) would otherwise fail with an opaque TypeError.
    raise RuntimeError(
        "Environment variable DATA_PATH is not set; "
        "point it at the root of the data directory before running this notebook."
    )
DATA_PATH = Path(DATA_PATH)

# Dataset used for embedding extraction: images and their masks live in
# sibling 'images' / 'masks' folders under the same root.
root_path = Path(DATA_PATH ,'easy_front_pin_detection/curated_ds_224/selected_trn/synthetic_blurred_shapes')
im_path = Path(root_path, 'images')
msk_path = Path(root_path, 'masks')


# Checkpoint to load.
# NOTE(review): MODEL_PATH is an absolute, user-specific path; it looks like it
# sits under the same data root as DATA_PATH -- confirm and derive it from
# DATA_PATH so the notebook runs on other machines.
MODEL_FN="first_224_no_resize_best_val_0.9347_epoch_87.pth"
MODEL_PATH="/home/hasan/Schreibtisch/projects/data/easy_front_pin_detection/curated_ds_224/models/first_224_no_resize/first_224_no_resize"
MODEL_PATH_FULL = Path(MODEL_PATH, MODEL_FN)


In [30]:
  # If we know the model architecture, initialize it first
# Instantiate the architecture with the same channel/class configuration the
# checkpoint is loaded into below (single-channel input, single output class).
model = UnetManualMaxPoolOnly(
    in_channels=1,
    n_classes=1
)

# Deserialize onto the CPU first: without map_location, torch.load restores
# every tensor to the device recorded in the checkpoint, which fails when that
# device is unavailable and otherwise holds a second copy of the weights on the
# GPU. The model is moved to its target device explicitly afterwards.
# NOTE(review): consider weights_only=True if the checkpoint contains only
# tensors and plain containers -- confirm against how it was saved.
checkpoint = torch.load(MODEL_PATH_FULL, map_location='cpu')

# The checkpoint is a mapping; the weights are stored under 'model_state_dict'.
model.load_state_dict(
    checkpoint['model_state_dict'],
)
model.to('cuda')
# Inference mode for dropout / batch-norm; as the last expression this also
# displays the module structure in the cell output.
model.eval()

  checkpoint = torch.load(MODEL_PATH_FULL)


UnetManualMaxPoolOnly(
  (encoder): EncoderBlockPtMaxPoolOnly(
    (conv11): convBlockPt(
      (conv): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (batch_nm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (activation): ReLU()
    )
    (conv12): convBlockPt(
      (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (batch_nm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (activation): ReLU()
    )
    (conv21): convBlockPt(
      (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (batch_nm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (dropout): Dropout(p=0.2, inplace=False)
      (activation): ReLU()
    )
    (conv22): convBlockPt(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=

In [9]:
# Build the dataloaders used for embedding extraction; the call returns two
# values, unpacked here as the training and validation loaders.
# NOTE: the validation image/mask paths are the same directories as the
# training ones, so both loaders read the same data, and no test images
# are supplied (tst_im_path=None).
# Images are requested at 224x224 with a batch size of 32.
# NOTE(review): num_workers=0 presumably keeps loading in the main process,
# as with torch's DataLoader -- confirm against get_dataloader.
trn_dl, val_dl =  get_dataloader(
    trn_im_path=im_path, 
    trn_msk_path=msk_path,
    val_im_path=im_path,
    val_msk_path=msk_path,
    tst_im_path=None,
    batch_size=32,
    IMAGE_HEIGHT=224,
    IMAGE_WIDTH=224,
    num_workers=0 
    )

  A.ElasticTransform(p=0.1, alpha=3, sigma=50 * 0.05, alpha_affine=None),


In [10]:
#| export
class EmbeddingExtractor(nn.Module):
    """Wrap a model and expose the feature maps produced by its ``encoder``.

    The wrapped model must provide an ``encoder`` attribute. Calling the
    encoder must yield exactly five values, which are reported under the
    keys ``'c1'`` to ``'c5'`` in the order the encoder returns them.
    """

    def __init__(self, model):
        super().__init__()
        # Keep a handle on the full model as well as a shortcut to its encoder.
        self.model = model
        self.encoder = model.encoder

    def forward(self, x):
        """Run only the encoder on ``x`` and return its outputs keyed by level."""
        # Unpacking into five names enforces the expected number of encoder
        # outputs; the last one is the deepest (bottleneck) embedding.
        stage_1, stage_2, stage_3, stage_4, bottleneck = self.encoder(x)
        return dict(
            c1=stage_1,
            c2=stage_2,
            c3=stage_3,
            c4=stage_4,
            c5=bottleneck,
        )
    

In [11]:
def extract_embeddings(model, dataloader, device='cuda'):
    """
    Extract pooled encoder embeddings from the model for all batches in the dataloader.

    Each feature map returned by ``EmbeddingExtractor`` is averaged over its
    two trailing (spatial) dimensions, giving one fixed-size vector per
    sample and level.

    Args:
        model: Model accepted by ``EmbeddingExtractor`` (it must expose an
            ``encoder`` attribute).
        dataloader: Iterable yielding ``(images, labels)`` pairs. ``images``
            must be a tensor; ``labels`` may be a tensor on any device or
            anything ``np.asarray`` accepts.
        device: Device the model and images are moved to.

    Returns:
        tuple: ``(all_embeddings, all_labels)`` where ``all_embeddings`` maps
        each level name returned by the extractor to a NumPy array with one
        row per sample, and ``all_labels`` is a NumPy array of the collected
        labels. If the dataloader yields no batches, an empty dict and an
        empty array are returned.

    Notes:
        The model is moved to ``device`` in place, because the extractor wraps
        the same module object. A model that was in training mode is switched
        back to training mode before returning.
    """
    embedding_extractor = EmbeddingExtractor(model).to(device)

    # eval() below propagates to the wrapped model; remember the previous mode
    # so a model that was being trained is not silently left in eval mode.
    was_training = getattr(model, 'training', False)
    embedding_extractor.eval()

    all_embeddings = {}
    all_labels = []

    try:
        with torch.no_grad():
            for images, labels in dataloader:
                embeddings = embedding_extractor(images.to(device))

                # Use the levels the extractor actually returns instead of a
                # duplicated hard-coded list.
                for level, feature_map in embeddings.items():
                    # Global average pooling to get fixed size embeddings
                    pooled = torch.mean(feature_map, dim=[2, 3])
                    all_embeddings.setdefault(level, []).append(pooled.cpu())

                # Tensor.numpy() only works for CPU tensors that do not require
                # grad, so detach and move first; non-tensor labels go through
                # NumPy directly.
                if torch.is_tensor(labels):
                    batch_labels = labels.detach().cpu().numpy()
                else:
                    batch_labels = np.asarray(labels)
                all_labels.extend(batch_labels)
    finally:
        if was_training:
            model.train()

    # Concatenate the per-batch embeddings of every level into one array.
    stacked = {
        level: torch.cat(chunks, dim=0).numpy()
        for level, chunks in all_embeddings.items()
    }

    return stacked, np.array(all_labels)

In [31]:
# Run the loaded model over the training dataloader (using the function's
# default device, 'cuda') and keep the pooled per-level embeddings together
# with the labels yielded by the loader.
all_embeddings, all_labels = extract_embeddings(model, trn_dl)

In [12]:
#| export
def visualize_embeddings(embeddings, labels, method='tsne', level='c5'):
    """
    Visualize embeddings of one level as a 2-D scatter plot using t-SNE or UMAP.

    Args:
        embeddings: Mapping from level name to a 2-D array-like with one row
            per sample.
        labels: Values used to color the points; passed to matplotlib's
            ``scatter`` as ``c``, so there must be one entry per sample.
        method: ``'tsne'`` or ``'umap'`` (case-insensitive).
        level: Key of ``embeddings`` to visualize.

    Raises:
        ValueError: If ``method`` is neither ``'tsne'`` nor ``'umap'``.
        KeyError: If ``level`` is not present in ``embeddings``.
    """
    # Validate before any plotting so a bad argument neither leaves an empty
    # figure behind nor silently runs UMAP under a misleading title.
    method_key = method.lower()
    if method_key not in ('tsne', 'umap'):
        raise ValueError(
            f"Unknown method {method!r}; expected 'tsne' or 'umap'"
        )

    # Get embeddings for specified level
    X = embeddings[level]

    # Reduce dimensionality (fixed random_state for reproducible layouts)
    if method_key == 'tsne':
        # scikit-learn requires perplexity < n_samples; cap its default of 30
        # so small sample sets still work (a single sample still cannot be
        # embedded).
        perplexity = min(30.0, max(len(X) - 1, 1))
        reducer = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    else:
        reducer = UMAP(n_components=2, random_state=42)

    X_reduced = reducer.fit_transform(X)

    # Create scatter plot
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(X_reduced[:, 0], X_reduced[:, 1],
                          c=labels, cmap='tab10', alpha=0.6)
    plt.colorbar(scatter)
    plt.title(f'{method.upper()} visualization of {level} embeddings')
    plt.xlabel(f'{method.upper()} 1')
    plt.ylabel(f'{method.upper()} 2')
    plt.show()

In [None]:
# `all_embeddings` / `all_labels` are the names produced by the
# extract_embeddings cell; no `embeddings` / `labels` variables are defined
# in this notebook.

# Visualize bottleneck embeddings using t-SNE
visualize_embeddings(all_embeddings, all_labels, method='tsne', level='c5')

# Visualize bottleneck embeddings using UMAP
visualize_embeddings(all_embeddings, all_labels, method='umap', level='c5')

In [None]:
# Plot with the function defaults: method='tsne' on the 'c5' level.
visualize_embeddings(all_embeddings, all_labels)

In [0]:
#| hide
# Run nbdev's export on this notebook file.
# NOTE(review): presumably writes the cells marked `#| export` to the module
# named by the `default_exp` directive at the top -- confirm with the nbdev setup.
import nbdev; nbdev.nbdev_export('18_model_embedding.embedding_creation.ipynb')