In [None]:
"""
Perform tsne eval of encoder
"""

from argparse import ArgumentParser
import os
from pathlib import Path
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn.functional as F
import yaml
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataloader import DataManager
from training_framework import SimCLR

# Path to the pretrained SimCLR checkpoint and the dataset name. The dataset
# name is also used as the top-level key into config_pretrain.yaml.
# Both can be overridden via environment variables so the notebook is not tied
# to the machine the default absolute path points at.
ckpt = os.environ.get("TSNE_CKPT", "/Users/aa56927-admin/Desktop/cifar.ckpt")
dataset = os.environ.get("TSNE_DATASET", "cifar10.dog_cat")

def extract_embeddings(encoder, dataloader: DataLoader) -> Tuple[np.ndarray, np.ndarray]:
    """
    Run ``encoder`` over every mini-batch of ``dataloader`` and collect
    L2-normalised feature vectors together with their targets.

    Side effects: ``encoder`` is switched to eval mode and, when CUDA is
    available, moved to the GPU.

    :param encoder: module mapping an image batch to per-sample features.
        Every dimension after the batch dimension is flattened, so outputs
        such as ``(B, C, 1, 1)`` become ``(B, C)``.
    :param dataloader: iterable yielding ``(img, target, _)`` triples; the
        third element is ignored.
    :return: ``(features, labels)`` as numpy arrays. ``features`` has one
        unit-norm row per sample; ``labels`` holds the matching targets.
    """
    # Pick the device once instead of re-moving the encoder on every batch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = encoder.to(device)
    encoder.eval()

    features = []
    labels = []
    with torch.no_grad():
        for img, target, _ in tqdm(dataloader):
            feature = encoder(img.to(device))
            # Flatten from dim 1 rather than ``.squeeze()``. squeeze() also
            # drops the batch dim when a batch holds a single sample, which
            # makes the dim=1 normalisation below fail.
            feature = torch.flatten(feature, start_dim=1)
            feature = F.normalize(feature, dim=1)
            # Move each batch off the GPU right away so device memory does
            # not grow with the size of the dataset.
            features.append(feature.cpu())
            labels.append(target.cpu())

    extracted_features = torch.cat(features, dim=0).numpy()
    extracted_labels = torch.cat(labels, dim=0).numpy()
    return extracted_features, extracted_labels


# ---- load config ----
# The `with` block closes the config file; the previous bare open() leaked it.
with open('config_pretrain.yaml', encoding='utf-8') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
torch.set_float32_matmul_precision("high")


# ---- parse config ----
config = config[dataset]
framework_config = config["framework_config"]
data_config = config["data_config"]
training_config = config["training_config"]
pl.seed_everything(1234)

# --- Data -----
data_manager = DataManager(
    dataset=dataset,
    data_config=data_config
)
# get_data() returns four items; only the last one (the test loader) is used.
_, _, _, dataloader_test = data_manager.get_data()

# --- Model -----
print('Loading PreTrained Model from Checkpoint {}'.format(ckpt))
# The extra keyword arguments are handed to LightningModule.load_from_checkpoint,
# which uses them for the SimCLR constructor. map_location loads the weights
# onto the CPU so the checkpoint also opens on machines without a GPU.
model = SimCLR.load_from_checkpoint(
    ckpt,
    framework_config=framework_config,
    training_config=training_config,
    data_config=data_config,
    val_dataloader=None,
    num_classes=data_manager.num_classes,
    map_location=torch.device('cpu')
)

print("Extracting Embeddings")
feat_te, lbl_te = extract_embeddings(dataloader=dataloader_test, encoder=model.backbone)

print("TSNE Visualization")
tsne = TSNE(n_components=2, verbose=1, random_state=123)
z = tsne.fit_transform(feat_te)

# One colour per label actually present. A hard-coded 2 breaks as soon as
# the dataset has a different number of classes.
n_label_values = len(np.unique(lbl_te))
plt.figure(figsize=(8, 6), dpi=300)
ax = sns.scatterplot(
    x=z[:, 0],
    y=z[:, 1],
    hue=lbl_te,
    palette=sns.color_palette("hls", n_label_values),
)
# The axes are t-SNE embedding dimensions, not principal components.
plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')
plt.grid()

ax.figure.savefig("test.png")

Global seed set to 1234


Files already downloaded and verified
P samples 5000
N samples 5000
Files already downloaded and verified
P samples 5000
N samples 5000
Files already downloaded and verified
P samples 5000
N samples 5000
Files already downloaded and verified
P samples 1000
N samples 1000
Loading PreTrained Model from Checkpoint /Users/aa56927-admin/Desktop/cifar.ckpt
Extracting Embeddings


 50%|████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                        | 1/2 [00:14<00:14, 14.63s/it]