# Visualize the representation space of the ViT encoder

In [None]:
from argparse import Namespace

import os

import pandas as pd

from transformers import VisionEncoderDecoderModel
from transformers import AutoTokenizer
from transformers import AutoFeatureExtractor

import torch
from torch.utils.data import Dataset, DataLoader

import wandb

SEED = 1

CONFIG = Namespace(
    run_name='CONTRAST-EXP-contrast-loss-full-final-layers-visualize',
    seed=SEED,
    generation_max_length=256,
    generation_num_beams=1,
    min_num_clusters=3,
    max_num_clusters=50)

run = wandb.init(project='pokemon-cards',
                 entity=None,
                 job_type='visualize',
                 name=CONFIG.run_name)

# Fine-tuned captioning model: reuse a local copy if present, otherwise
# download it from wandb.
MODEL_ARTIFACT = './artifacts/pokemon-image-captioning-model:v14'
if not os.path.exists(MODEL_ARTIFACT):
    artifact = run.use_artifact('pkthunder/pokemon-cards/pokemon-image-captioning-model:v14', type='model')
    artifact_dir = artifact.download()
    # Load from the directory download() reports rather than assuming it is
    # the hard-coded path above (the download root / name mangling may differ).
    MODEL_ARTIFACT = artifact_dir

# Prefer CUDA, then Apple MPS, then CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')

# Baseline (not fine-tuned) encoder-decoder, for comparison.
PRETRAINED_MODEL = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
PRETRAINED_MODEL.to(DEVICE)

FINE_TUNED_MODEL = VisionEncoderDecoderModel.from_pretrained(MODEL_ARTIFACT)
FINE_TUNED_MODEL.to(DEVICE)

# Define image feature extractor and tokenizer
# NOTE: these are not trained, so we can get them directly from HuggingFace
FEATURE_EXTRACTOR = AutoFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
TOKENIZER = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

## Download Data

In [None]:
def download_data(run):
    """
    Fetch the seed-specific train/valid/test split table from wandb.

    Uses the ``pokemon_cards_split_full:v0`` artifact of ``run`` and returns
    the table named after the module-level ``SEED``.
    """
    split_artifact = run.use_artifact('pokemon_cards_split_full:v0')
    return split_artifact.get(f"pokemon_table_full_data_split_seed_{SEED}")

def get_df(table, is_test=False):
    """
    Convert a wandb table into a DataFrame restricted to one side of the split.

    When ``is_test`` is true, only rows whose ``split`` is ``'test'`` are
    returned; otherwise every remaining row (train and valid) is returned.
    """
    frame = pd.DataFrame(data=table.data, columns=table.columns)
    is_test_row = frame.split == 'test'
    return frame[is_test_row] if is_test else frame[~is_test_row]

In [None]:
def collate_fn(batch):
    """Stack per-example ``pixel_values`` and ``labels`` tensors along a new batch dim."""
    return {
        field: torch.stack([example[field] for example in batch], dim=0)
        for field in ('pixel_values', 'labels')
    }

class PokemonCardsDataset(Dataset):
    """
    Torch dataset yielding ``{'pixel_values', 'labels'}`` for each card.

    Args:
        images: sequence of objects exposing a PIL image via ``.image``
            (presumably wandb.Image entries from the split table — confirm).
        captions: caption string per image, in the same order.
        config: namespace providing ``generation_max_length``, the length
            captions are padded/truncated to.
    """

    def __init__(self, images:list, captions: list, config) -> None:
        # Convert once up front so __getitem__ always receives RGB images.
        self.images = [self._as_rgb(entry.image) for entry in images]
        self.captions = captions
        self.config = config

    @staticmethod
    def _as_rgb(picture):
        """Return ``picture`` unchanged if already RGB, else its RGB conversion."""
        return picture if picture.mode == "RGB" else picture.convert(mode="RGB")

    def __len__(self):
        # Dataset length follows the captions list.
        return len(self.captions)

    def __getitem__(self, index):
        picture, text = self.images[index], self.captions[index]

        # A single image is passed, so [0] drops the leading batch dimension.
        pixels = FEATURE_EXTRACTOR(images=picture, return_tensors="pt").pixel_values[0]
        # Caption -> token ids, padded/truncated to generation_max_length.
        token_ids = TOKENIZER.encode(
            text,
            return_tensors='pt',
            padding='max_length',
            truncation='longest_first',
            max_length=self.config.generation_max_length,
        )[0]

        return {'pixel_values': pixels, 'labels': token_ids}

In [None]:
# Load the split table and keep only the validation rows (embedded below).
wandb_table = download_data(run)
train_val_df = get_df(wandb_table, is_test=False)
val_df = train_val_df.loc[train_val_df['split'] == 'valid']

In [None]:
# Pass images through ViT and get contextualized embeddings from final layer
# Take average of embeddings and visualize them.

def get_embeddings(model, dataframe: pd.DataFrame):
    """
    Get embeddings from final layer of ViT
    """

    start_text = "Pokemon Card of type"
    end_text = "with the title"

    # start_text = 'of rarity'
    # end_text = 'from the set'

    embeddings = []
    inst_labels = []
    set_names = dataframe.set_name.values
    card_names = dataframe.name.values

    dataset = PokemonCardsDataset(
        dataframe.image.values,
        dataframe.caption.values,
        CONFIG)

    with torch.no_grad():
        dataloader = DataLoader(dataset)
        for i, inst in enumerate(dataloader):
            enc_output = model.encoder(pixel_values=inst['pixel_values'].to(DEVICE))
            # _embedding = enc_output.pooler_output

            caption = dataframe.caption.values[i]
            start_pos = caption.find(start_text)+len(start_text)
            end_pos = caption.find(end_text, start_pos)
            label = caption[start_pos:end_pos].strip()

            # print(f"Parsed label: {label}")
            # if 'evolved from' in label:
            #     label = label[0:label.find('evolved from')]
            #     label = label.strip()
            # if ' ' in label:
            #     label = label.split(' ')
            #     label = label[1] if len(label[0]) == 1 else label[0]
            # print(f"Final label: {label}")

            inst_labels.append(label)

            _embedding = enc_output.last_hidden_state.cpu()
            _embedding = _embedding.mean(1)
            # _embedding = _embedding.squeeze(0)
            # inst_labels += [dataframe.set_name.values[i]]*_embedding.shape[0]
            embeddings.append(_embedding)

        embeddings = torch.concat(embeddings, axis=0)

    return embeddings, inst_labels, set_names, card_names

In [None]:
from sklearn.preprocessing import normalize

pretrained_embeddings, inst_labels, set_names, card_names = get_embeddings(PRETRAINED_MODEL, val_df)
finetuned_embeddings = get_embeddings(FINE_TUNED_MODEL, val_df)[0]

# Scale each embedding row to unit L2 norm before the cosine-based
# clustering and silhouette scoring below.
pretrained_embeddings = normalize(pretrained_embeddings, norm='l2')
finetuned_embeddings = normalize(finetuned_embeddings, norm='l2')

## Cluster embeddings via k-medoids 

Compute the best number of clusters via the silhouette score and the elbow method.

In [None]:
from sklearn_extra.cluster import KMedoids

def run_kmedoids(embeddings, num_clusters=8):
    """
    Fit cosine-distance KMedoids and assign every embedding to a cluster.

    Returns the fitted estimator and the predicted cluster index per row.
    """
    fitted = KMedoids(n_clusters=num_clusters, metric='cosine').fit(embeddings)
    return fitted, fitted.predict(embeddings)

In [None]:
from sklearn.metrics import silhouette_score

# (n_clusters, mean silhouette score) pairs for the pretrained and
# fine-tuned embeddings.
pt_silhouette_scores = []
ft_silhouette_scores = []

# The sweep range comes from CONFIG; it is intentionally not re-assigned
# here, so editing CONFIG.max_num_clusters actually takes effect.
for n_clusters in range(CONFIG.min_num_clusters, CONFIG.max_num_clusters+1):

    # random_state only matters to silhouette_score when sample_size is set,
    # so these scores use every sample.
    _, labels = run_kmedoids(pretrained_embeddings, num_clusters=n_clusters)
    pt_avg_score = silhouette_score(pretrained_embeddings, labels, metric='cosine', random_state=SEED)
    pt_silhouette_scores.append((n_clusters, pt_avg_score))

    _, labels = run_kmedoids(finetuned_embeddings, num_clusters=n_clusters)
    ft_avg_score = silhouette_score(finetuned_embeddings, labels, metric='cosine', random_state=SEED)
    ft_silhouette_scores.append((n_clusters, ft_avg_score))

print(ft_silhouette_scores)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's theme with all font sizes scaled by 2.5.
sns.set(font_scale=2.5)
# Default matplotlib colormap for later plots. Kept after sns.set(), which
# presumably re-applies seaborn rc settings — confirm before reordering.
plt.set_cmap('tab20')

from sklearn.decomposition import PCA

In [None]:
# Silhouette score vs. number of clusters for the fine-tuned embeddings.
# Built directly from the recorded (k, score) pairs so the x-values always
# match the scores, whatever range the sweep actually used.
silhouette_df = pd.DataFrame(ft_silhouette_scores, columns=['num_clusters', 'silhouette_score'])

fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot(111)

ax.plot(silhouette_df.num_clusters,
        silhouette_df.silhouette_score)

ax.set_xlabel("Number of Clusters")
ax.set_ylabel("Silhouette Score")
# Fixed y-range. Silhouette scores lie in [-1, 1], so any values outside
# [0, 0.5] fall outside the visible area.
ax.set_ylim(0, 0.5)

In [None]:
from sklearn.metrics import silhouette_samples

# Treat the parsed Pokemon type as a cluster assignment and measure how well
# the fine-tuned embeddings separate it: one overall score, one per sample.
pokemon_type_score = silhouette_score(
    finetuned_embeddings, inst_labels,
    metric='cosine', random_state=SEED,
)
pokemon_type_score_by_sample = silhouette_samples(
    finetuned_embeddings, inst_labels,
    metric='cosine',
)

print(f"Silhouette score for Pokemon type clustering: {pokemon_type_score}")

## Visualize embeddings from fine-tuned model

In [None]:
kmedoids_obj, labels = run_kmedoids(finetuned_embeddings, num_clusters=15)

# df = pd.DataFrame(
#     finetuned_embeddings,
#     columns=[f"dim_{i}" for i in range(finetuned_embeddings.shape[1])])
# df['cluster'] = labels

# run.log({'finetuned-embeddings': wandb.Table(dataframe=df)})

color_map = plt.get_cmap('tab20')

fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot(111)
# ax = fig.add_subplot(111, projection='3d')

# Project the fine-tuned embeddings onto their first two principal components
# and collect per-card metadata alongside.
pca = PCA(n_components=2)
pca_embedding = pca.fit_transform(finetuned_embeddings)
df = pd.DataFrame(pca_embedding, columns=['pca1', 'pca2'])
df['cluster'] = labels
df['inst_label'] = inst_labels
df['set_name'] = set_names
df['name'] = card_names
df['inst_label_silhouette'] = pokemon_type_score_by_sample

# Column of df used to colour the points; its values become legend entries.
# Change to 'set_name' or 'cluster' to colour by those instead.
plot_column = 'inst_label'
plot_labels = df[plot_column].tolist()

# Sorted so colours and legend order are reproducible across runs (iteration
# order of a set of strings is not).
unique_labels = sorted(set(plot_labels), key=str)
num_labels = len(unique_labels)
for i, label in enumerate(unique_labels):
    label_df = df[df[plot_column] == label]
    # Use color= rather than c=: matplotlib treats an RGBA tuple passed as c
    # as per-point values whenever the group has exactly 4 points.
    # NOTE: tab20 has 20 colours, so more labels than that reuse colours.
    ax.scatter(label_df.pca1, label_df.pca2, label=str(label),
               color=color_map(i % color_map.N))

ax.set_xlabel("PCA Dimension 1")
ax.set_ylabel("PCA Dimension 2")
# ax.set_zlabel("PCA Dimension 3")
# Shrink the axes to make room for the legend on the right.
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.7, box.height])
ax.legend(loc='center left', bbox_to_anchor=(1.15, 0.5))
# plt.close(fig)
# plt.close(fig)

In [None]:
# For each Pokemon type: sample count and mean per-sample silhouette score.
for type_label in set(inst_labels):
    type_scores = df.loc[df.inst_label == type_label, 'inst_label_silhouette']
    print(type_label, len(type_scores), type_scores.mean())

## Save k-medoids medoid (centroid) images

In [None]:
medoid_indices = kmedoids_obj.medoid_indices_

# Write the card image at each medoid to centroid-<k>.png. The medoid indices
# refer to rows of the fine-tuned embedding matrix, which get_embeddings
# builds in val_df row order.
card_images = val_df.image.values
for rank, row in enumerate(medoid_indices):
    card = card_images[row].image
    image = card if card.mode == "RGB" else card.convert(mode="RGB")
    image.save(f'centroid-{rank}.png')

In [None]:
# Mark the wandb run created at the top of the notebook as finished.
run.finish()

# Valley of Old Code

In [None]:
# _, labels = run_kmedoids(pretrained_embeddings, num_clusters=pt_num_clusters[0])

# # df = pd.DataFrame(
# #     pretrained_embeddings,
# #     columns=[f"dim_{i}" for i in range(pretrained_embeddings.shape[1])])
# # df['cluster'] = labels

# # run.log({'pretrained-embeddings': wandb.Table(dataframe=df)})

# fig = plt.figure(figsize=(30, 10))
# ax = fig.add_subplot(111)
# # ax = fig.add_subplot(111, projection='3d')

# pca = PCA(n_components=2)
# pca_embedding = pca.fit_transform(pretrained_embeddings)
# df = pd.DataFrame(pca_embedding, columns=['pca1', 'pca2'])
# df['cluster'] = labels
# df['inst_labels'] = inst_labels

# num_labels = len(set(labels))

# for label in set(labels):
#     label_df = df[df.cluster == label]
#     ax.scatter(label_df.pca1, label_df.pca2, label=str(label))

# ax.set_xlabel("PCA Dimension 1")
# ax.set_ylabel("PCA Dimension 2")
# # ax.set_zlabel("PCA Dimension 3")
# box = ax.get_position()
# ax.set_position([box.x0, box.y0, box.width * 0.7, box.height])
# ax.legend(loc='center left', bbox_to_anchor=(1.15, 0.5))
# plt.show()
# # plt.close(fig)

In [None]:
# import torchvision.transforms as T

# to_pil_img = T.ToPILImage()
# to_tensor = T.ToTensor()

# dataset = PokemonCardsDataset(
#     val_df.image.values,
#     val_df.caption.values,
#     CONFIG)
    
# with torch.no_grad():
#     dataloader = DataLoader(dataset)
#     for i, inst in enumerate(dataloader):
#         patches = to_tensor(dataset.images[i]).unfold(1, 128, 128).unfold(2, 128, 128)
#         patches = patches.reshape(patches.shape[0], patches.shape[1]*patches.shape[2], patches.shape[3], patches.shape[4])
#         patches = patches.transpose(0, 1)
#         # unfold = torch.nn.Unfold(kernel_size=(16, 16), stride=16)
#         # print(unfold(inst['pixel_values']).shape)
#         # img = to_pil_img(inst['pixel_values'].squeeze(0))
#         # print(PRETRAINED_MODEL.encoder.embeddings.patch_embeddings(inst['pixel_values']).shape)
#         # patch = PRETRAINED_MODEL.encoder.embeddings.patch_embeddings.projection(inst['pixel_values'])
#         for j in range(patches.shape[0]):
#             patch = to_pil_img(patches[j])
#             patch.save(f'patch-{j}-for-img-{i}.png')
#         raise