This notebook runs an unsupervised analysis of the classical-feature pipeline proposed in the paper from [Cambridge](https://arxiv.org/pdf/2006.05919.pdf).

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before executing each cell, so edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import warnings
import argparse
import os
from os.path import join, dirname
from typing import List
import multiprocessing as mp
import wandb
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from cac.config import Config
from cac.models import factory as model_factory
from cac.utils.logger import set_logger
from training.utils import seed_everything

# Suppress all warnings (every Warning subclass) for the rest of the session.
warnings.simplefilter(action='ignore', category=Warning)

### Define the config

In [None]:
# Path of the experiment YAML config, and the worker count that is later
# copied onto ``config.num_workers``.
VERSION, NUM_WORKERS = 'experiments/unsupervised/cambridge.yml', 4

In [None]:
# Seed first so that everything created afterwards (config, model, fit) runs
# after seeding; presumably seeds all RNGs in use — see training.utils to confirm.
seed_everything()
# Build the experiment configuration from the .yml path in VERSION.
config = Config(VERSION)

In [None]:
# Send logs to <config.log_dir>/unsupervised.log.
set_logger(
    os.path.join(config.log_dir, 'unsupervised.log')
)
# Override the worker count with the notebook-level NUM_WORKERS setting.
config.num_workers = NUM_WORKERS

### Define the model

In [None]:
# Instantiate the model named in the config's `model` section, handing it the
# whole config. (Plain keyword argument instead of unpacking a one-entry dict.)
model = model_factory.create(config.model['name'], config=config)

In [None]:
# Fit the model and ask for its predictions back. The result is indexed
# below with 'input', 'latent' and 'labels' keys.
data = model.fit(
    batch_size=8,
    return_predictions=True,
    debug=False,
)

In [None]:
# Pull out the model inputs (X), latent representation (Z) and the per-sample
# label records (Y) from the fit output, in that order.
X, Z, Y = (data[key] for key in ('input', 'latent', 'labels'))

In [None]:
# Sanity-check the dimensions of the input data returned by fit.
X.shape

In [None]:
# Peek at the first label record. The plots below look up metadata keys
# such as 'dataset-name' in each record of Y.
Y[0]

### Plotting and Analysis

In [None]:
def scatter2d(x1, x2, row_values: List[dict], label: str, legend: bool = True, title=None):
    """Scatter-plot two components, coloring points by a metadata field.

    Args:
        x1: values for the x-axis ("Component 1"). Indexed with integer
            index arrays, so it should be a NumPy array (or another type
            that supports array indexing).
        x2: values for the y-axis ("Component 2"); same length as ``x1``.
        row_values: one mapping per point; ``row_values[i][label]`` is the
            group that point ``i`` belongs to.
        label: key looked up in every mapping of ``row_values`` to group
            the points.
        legend: when True, draw a legend whose entries read
            "<group> : <number of points>".
        title: optional axes title; nothing is set when None.
    """
    groups = np.array([row[label] for row in row_values])
    unique_groups = np.unique(groups)

    # One evenly spaced color from the plasma colormap per distinct group.
    colors = cm.plasma(np.linspace(0, 1, len(unique_groups)))

    _, ax = plt.subplots(1, figsize=(10, 10))

    # Distinct loop name so the `label` parameter is not shadowed.
    for group, color in zip(unique_groups, colors):
        idx = np.flatnonzero(groups == group)
        ax.scatter(x1[idx], x2[idx], label='{} : {}'.format(group, len(idx)), color=color)

    ax.set_ylabel('Component 2')
    ax.set_xlabel('Component 1')

    if title is not None:
        # Use the caller's title; previously the literal string 'title' was drawn.
        ax.set_title(title)

    ax.grid()

    if legend:
        ax.legend(loc='best')

In [None]:
# First two latent components, colored by each record's 'dataset-name'.
scatter2d(x1=Z[:, 0], x2=Z[:, 1], row_values=Y, label='dataset-name')

In [None]:
# First two latent components, colored by each record's 'enroll_patient_gender'.
scatter2d(x1=Z[:, 0], x2=Z[:, 1], row_values=Y, label='enroll_patient_gender')

In [None]:
# First two latent components, colored by each record's 'enroll_facility'.
scatter2d(x1=Z[:, 0], x2=Z[:, 1], row_values=Y, label='enroll_facility')