# Visualize latent space of Vision Transformer

In [None]:
from argparse import Namespace

import os

import pandas as pd

from transformers import VisionEncoderDecoderModel
from transformers import AutoTokenizer
from transformers import AutoFeatureExtractor

import torch
from torch.utils.data import Dataset, DataLoader

import wandb

# Seed stored in CONFIG below; elsewhere in this notebook it also selects the
# split table name and is passed to the silhouette scoring.
SEED = 1

# Notebook-wide settings. The code visible in this notebook only reads
# `generation_max_length`, `min_num_clusters` and `max_num_clusters`; the
# remaining fields look like trainer arguments carried over from the
# fine-tuning setup (NOTE(review): confirm whether they are still needed).
CONFIG = Namespace(
    # --- generation / evaluation ---
    predict_with_generate=True,
    include_inputs_for_metrics=False,
    # --- logging ---
    report_to="wandb",
    run_name="fine_tuning_eval",
    # --- schedule ---
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=1,
    num_train_epochs=5,
    learning_rate=1e-3,
    # --- checkpointing / reproducibility ---
    push_to_hub=False,
    load_best_model_at_end=True,
    seed=SEED,
    output_dir="eval-output/",
    optim="adamw_torch",
    # --- caption length (used as the tokenizer max_length in the dataset) ---
    generation_max_length=256,
    generation_num_beams=1,
    log_preds=False,
    val_limit=0,
    # --- inclusive range of cluster counts tried by the silhouette sweep ---
    min_num_clusters=3,
    max_num_clusters=15,
)

# Start the W&B run used by this notebook; the artifact lookups below and in
# `download_data` go through this run object.
run = wandb.init(project='pokemon-cards',
                 entity=None,
                 job_type='visualize',
                 name='visualize_embedings')

# Local directory the fine-tuned checkpoint is loaded from further down.
MODEL_ARTIFACT = './artifacts/pokemon-image-captioning-model:v10'

# Only fetch the checkpoint when it is not already on disk.
if not os.path.exists(MODEL_ARTIFACT):
    artifact = run.use_artifact('pkthunder/model-registry/Pokemon Card Image Captioner Full Dataset Model:v0', type='model')
    # Download straight into MODEL_ARTIFACT so the directory that is checked
    # above (and loaded from later) is exactly the one that gets populated,
    # instead of relying on W&B's default directory name matching it.
    artifact_dir = artifact.download(root=MODEL_ARTIFACT)

# Run on the GPU when torch can see one, otherwise on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)

# Hub id of the off-the-shelf captioning checkpoint; it also supplies the
# feature extractor and tokenizer below.
BASE_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

# Baseline model straight from the Hub.
PRETRAINED_MODEL = VisionEncoderDecoderModel.from_pretrained(BASE_CHECKPOINT).to(DEVICE)

# Fine-tuned model loaded from the local artifact directory.
FINE_TUNED_MODEL = VisionEncoderDecoderModel.from_pretrained(MODEL_ARTIFACT).to(DEVICE)

# Define image feature extractor and tokenizer
# NOTE: these are not trained, so we can get them directly from HuggingFace
FEATURE_EXTRACTOR = AutoFeatureExtractor.from_pretrained(BASE_CHECKPOINT)
TOKENIZER = AutoTokenizer.from_pretrained(BASE_CHECKPOINT)

## Download Data

In [None]:
def download_data(run):
    """
    Fetch the split-data table from wandb.

    Uses the ``pokemon_cards_split_full:v0`` artifact through ``run`` and
    returns the artifact entry whose name ends with the module-level ``SEED``.
    """

    split_artifact = run.use_artifact('pokemon_cards_split_full:v0')
    table_name = f"pokemon_table_full_data_split_seed_{SEED}"
    return split_artifact.get(table_name)

def get_df(table, is_test=False):
    """
    Build a dataframe from a wandb table and keep one side of the split.

    ``table`` must expose ``data`` and ``columns``. When ``is_test`` is true,
    only rows whose ``split`` column equals ``'test'`` are returned; otherwise
    every row whose ``split`` differs from ``'test'`` is returned.
    """

    frame = pd.DataFrame(data=table.data, columns=table.columns)
    split_column = frame['split']

    # Pick the row mask for the requested side of the test / non-test split.
    if is_test:
        keep = split_column == 'test'
    else:
        keep = split_column != 'test'

    return frame[keep]

In [None]:
def collate_fn(batch):
    """
    Stack per-sample tensors into batched tensors.

    ``batch`` is a sequence of dicts that each hold ``'pixel_values'`` and
    ``'labels'`` tensors; the result has the same two keys, with the samples
    stacked along a new leading dimension.
    """
    fields = ('pixel_values', 'labels')
    return {
        field: torch.stack([sample[field] for sample in batch], dim=0)
        for field in fields
    }

class PokemonCardsDataset(Dataset):
    """
    Torch dataset pairing card images with their captions.

    ``images`` is a sequence of wrapper objects whose ``.image`` attribute is
    the actual picture (it must expose ``.mode`` and ``.convert``); every
    picture is stored in RGB mode. ``captions`` is kept as given and defines
    the dataset length. ``config`` must provide ``generation_max_length``.
    """

    def __init__(self, images: list, captions: list, config) -> None:
        # Unwrap each picture once and make sure it is RGB before storing it.
        self.images = [self._as_rgb(wrapped.image) for wrapped in images]
        self.captions = captions
        self.config = config

    @staticmethod
    def _as_rgb(picture):
        """Return ``picture`` unchanged if it is RGB, else an RGB conversion."""
        if picture.mode == "RGB":
            return picture
        return picture.convert(mode="RGB")

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, index):
        """
        Return ``{'pixel_values': ..., 'labels': ...}`` for one sample.

        The image goes through the module-level ``FEATURE_EXTRACTOR`` and the
        caption through ``TOKENIZER``, padded/truncated to
        ``config.generation_max_length``.
        """
        picture = self.images[index]
        text = self.captions[index]

        extracted = FEATURE_EXTRACTOR(images=picture, return_tensors="pt")
        token_ids = TOKENIZER.encode(
            text,
            return_tensors='pt',
            padding='max_length',
            truncation='longest_first',
            max_length=self.config.generation_max_length)

        # Index 0 takes the single entry of each result — presumably dropping
        # the batch axis added by return_tensors="pt" (TODO confirm).
        return {
            'pixel_values': extracted.pixel_values[0],
            'labels': token_ids[0],
        }

In [None]:
# Use validation dataset: fetch the split table from W&B, drop the test rows,
# then keep only the rows tagged 'valid'.

wandb_table = download_data(run)
train_val_df = get_df(wandb_table, is_test=False)
is_valid_row = train_val_df['split'] == 'valid'
val_df = train_val_df.loc[is_valid_row]

In [None]:
# Pass images through ViT and get contextualized embeddings from final layer
# Take average of embeddings and visualize them.

def get_embeddings(model, dataframe: pd.DataFrame):
    """
    Get embeddings from final layer of ViT

    Every image in ``dataframe`` is passed through ``model.encoder`` one at a
    time, and the encoder's ``last_hidden_state`` is averaged over dimension 1
    (presumably the token/patch axis — TODO confirm) to give one vector per
    image.

    Args:
        model: model exposing an ``encoder`` module that accepts
            ``pixel_values``.
        dataframe: must provide ``image``, ``caption`` and ``set_name``
            columns.

    Returns:
        Tuple ``(embeddings, inst_labels)``: ``embeddings`` is a CPU tensor
        with one row per dataframe row, and ``inst_labels`` is the list of
        ``set_name`` values in the same order.
    """

    embeddings = []
    inst_labels = []

    dataset = PokemonCardsDataset(
        dataframe.image.values,
        dataframe.caption.values,
        CONFIG)

    # The inputs have to live on the same device as the encoder weights;
    # otherwise the forward pass fails when the model was moved to the GPU.
    device = next(model.encoder.parameters()).device
    set_names = dataframe.set_name.values

    with torch.no_grad():
        # One sample per batch and no shuffling, so batch `i` is dataframe
        # row `i` — the label lookup below depends on that.
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
        for i, inst in enumerate(dataloader):
            pixel_values = inst['pixel_values'].to(device)
            enc_output = model.encoder(pixel_values=pixel_values)
            inst_labels.append(set_names[i])
            _embedding = enc_output.last_hidden_state.mean(1)
            # Bring the result back to the CPU so numpy/sklearn consumers can
            # use it regardless of where the model runs.
            embeddings.append(_embedding.cpu())

        embeddings = torch.concat(embeddings, axis=0)
    return embeddings, inst_labels

In [None]:
from sklearn.preprocessing import normalize

# Embed the validation images with both models. The instance labels come from
# the same dataframe either way, so only the first call's labels are kept.
pretrained_embeddings, inst_labels = get_embeddings(PRETRAINED_MODEL, val_df)
finetuned_embeddings, _ = get_embeddings(FINE_TUNED_MODEL, val_df)

# L2 Norm the embeddings (row-wise; these are sklearn's defaults spelled out)
pretrained_embeddings = normalize(pretrained_embeddings, norm='l2', axis=1)
finetuned_embeddings = normalize(finetuned_embeddings, norm='l2', axis=1)

# Find optimal number of clusters using silhouette score

In [None]:
from sklearn_extra.cluster import KMedoids

def run_kmedoids(embeddings, num_clusters=8):
    """
    Cluster ``embeddings`` with cosine-distance KMedoids.

    Returns the fitted estimator together with the cluster index that
    ``predict`` assigns to each row of the same ``embeddings``.
    """

    clusterer = KMedoids(n_clusters=num_clusters, metric='cosine').fit(embeddings)
    assignments = clusterer.predict(embeddings)
    return clusterer, assignments

In [None]:
from sklearn.metrics import silhouette_score

pt_silhouette_scores, ft_silhouette_scores = [], []

# Try every candidate cluster count (both bounds inclusive) on each model's
# embeddings and record (n_clusters, mean silhouette score).
candidate_counts = range(CONFIG.min_num_clusters, CONFIG.max_num_clusters + 1)
for n_clusters in candidate_counts:

    _, labels = run_kmedoids(pretrained_embeddings, num_clusters=n_clusters)
    pt_avg_score = silhouette_score(
        pretrained_embeddings, labels, metric='cosine', random_state=SEED)
    pt_silhouette_scores.append((n_clusters, pt_avg_score))

    _, labels = run_kmedoids(finetuned_embeddings, num_clusters=n_clusters)
    ft_avg_score = silhouette_score(
        finetuned_embeddings, labels, metric='cosine', random_state=SEED)
    ft_silhouette_scores.append((n_clusters, ft_avg_score))

# Rank the candidates best-first by silhouette score.
pt_silhouette_scores = sorted(
    pt_silhouette_scores, key=lambda pair: pair[1], reverse=True)
ft_silhouette_scores = sorted(
    ft_silhouette_scores, key=lambda pair: pair[1], reverse=True)

print(pt_silhouette_scores)
print(ft_silhouette_scores)

# NOTE: each of these is the whole (n_clusters, score) pair of the best
# candidate, not just the count; later cells index [0] to get the count.
pt_num_clusters = pt_silhouette_scores[0]
ft_num_clusters = ft_silhouette_scores[0]

print(f"Pretrained num clusters: {pt_num_clusters}")
print(f"Finetuned num clusters: {ft_num_clusters}")

## Plot Embeddings from Pretrained and Fine-Tuned Models

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.decomposition import PCA

# Re-cluster the pretrained embeddings with the best-scoring cluster count
# (element 0 of the (n_clusters, score) pair).
_, labels = run_kmedoids(pretrained_embeddings, num_clusters=pt_num_clusters[0])

fig, ax = plt.subplots(figsize=(30, 10))

# Reduce the embeddings to two PCA components for plotting.
pca = PCA(n_components=2)
pca_embedding = pca.fit_transform(pretrained_embeddings)
df = pd.DataFrame(pca_embedding, columns=['pca1', 'pca2'])
df['cluster'] = labels
df['inst_labels'] = inst_labels

num_labels = len(set(labels))

# One scatter series per cluster so each cluster gets its own legend entry.
for label in set(labels):
    label_df = df.loc[df['cluster'] == label]
    ax.scatter(label_df['pca1'], label_df['pca2'], label=str(label))

ax.set_xlabel("PCA Dimension 1")
ax.set_ylabel("PCA Dimension 2")

# Narrow the axes to 70% of their width and anchor the legend to their right.
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.7, box.height])
ax.legend(loc='center left', bbox_to_anchor=(1.15, 0.5))

In [None]:
# Re-cluster the fine-tuned embeddings with the best-scoring cluster count
# (element 0 of the (n_clusters, score) pair).
_, labels = run_kmedoids(finetuned_embeddings, num_clusters=ft_num_clusters[0])

fig, ax = plt.subplots(figsize=(30, 10))

# Reduce the embeddings to two PCA components for plotting.
pca = PCA(n_components=2)
pca_embedding = pca.fit_transform(finetuned_embeddings)
df = pd.DataFrame(pca_embedding, columns=['pca1', 'pca2'])
df['cluster'] = labels
df['inst_labels'] = inst_labels

num_labels = len(set(labels))

# One scatter series per cluster so each cluster gets its own legend entry.
for label in set(labels):
    label_df = df.loc[df['cluster'] == label]
    ax.scatter(label_df['pca1'], label_df['pca2'], label=str(label))

ax.set_xlabel("PCA Dimension 1")
ax.set_ylabel("PCA Dimension 2")

# Narrow the axes to 70% of their width and anchor the legend to their right.
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.7, box.height])
ax.legend(loc='center left', bbox_to_anchor=(1.15, 0.5))

In [None]:
# Finish the W&B run that was opened with wandb.init at the top of the notebook.
run.finish()

In [None]:
# import torchvision.transforms as T

# to_pil_img = T.ToPILImage()
# to_tensor = T.ToTensor()

# dataset = PokemonCardsDataset(
#     val_df.image.values,
#     val_df.caption.values,
#     CONFIG)
    
# with torch.no_grad():
#     dataloader = DataLoader(dataset)
#     for i, inst in enumerate(dataloader):
#         patches = to_tensor(dataset.images[i]).unfold(1, 128, 128).unfold(2, 128, 128)
#         patches = patches.reshape(patches.shape[0], patches.shape[1]*patches.shape[2], patches.shape[3], patches.shape[4])
#         patches = patches.transpose(0, 1)
#         # unfold = torch.nn.Unfold(kernel_size=(16, 16), stride=16)
#         # print(unfold(inst['pixel_values']).shape)
#         # img = to_pil_img(inst['pixel_values'].squeeze(0))
#         # print(PRETRAINED_MODEL.encoder.embeddings.patch_embeddings(inst['pixel_values']).shape)
#         # patch = PRETRAINED_MODEL.encoder.embeddings.patch_embeddings.projection(inst['pixel_values'])
#         for j in range(patches.shape[0]):
#             patch = to_pil_img(patches[j])
#             patch.save(f'patch-{j}-for-img-{i}.png')
#         raise