In [1]:
import numpy as np
from sklearn import decomposition, manifold
from sklearn.metrics import classification_report

import tensorflow as tf
from tensorflow.keras import Input, models, Model as TFModel

from jarvis.utils.general import gpus, overload
from tfcaidm import Dataset, JClient, Model
from tfcaidm.models import registry, head

**Autoselect GPU (use only on caidm cluster)**

In [2]:
# Set CUDA_VISIBLE_DEVICES to an automatically selected GPU (caidm cluster only; see output).
gpus.autoselect()

[ 2021-12-29 19:26:51 ] CUDA_VISIBLE_DEVICES automatically set to: 1           


In [3]:
# List the timestamped run directories under ./exp/adni/logs (used for Path.run below).
!ls ./exp/adni/logs

2021-12-29_00-19-54_PST  2021-12-29_12-32-16_PST  2021-12-29_18-47-50_PST
2021-12-29_00-33-13_PST  2021-12-29_12-41-35_PST
2021-12-29_02-30-40_PST  2021-12-29_16-03-28_PST


## Model

In [4]:
import custom_losses as custom

In [5]:
class Path:
    """Locations of the experiment artifacts evaluated in this notebook.

    ``run`` selects a timestamped run under ``base/logs``, ``num`` a numbered
    trial inside that run; ``model`` is the saved Keras model directory and
    ``pipeline`` the data-pipeline YAML used to build the generators.
    Run 2021-12-29_16-03-28_PST / trial 5 is recorded in the notes at the end
    of this notebook as "cross entropy loss w/ a batch size of 2".
    """

    base = "./exp/adni"
    run = f"{base}/logs/2021-12-29_16-03-28_PST"
    num = f"{run}/5"
    model = f"{num}/ae_0"
    pipeline = f"{num}/pipeline.yml"

In [6]:
# Load the saved Keras model. The custom loss object is passed via custom_objects so
# deserialization can resolve it; "ctr_loss" is presumably the name it was saved under.
model = models.load_model(Path.model, custom_objects={"ctr_loss": custom.ContrastiveLoss()})

## Visualize

In [7]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.io as pio
# Render plotly figures with the "iframe" renderer in this notebook.
pio.renderers.default = "iframe"

In [8]:
def get_model(model, layer_names=()):
    """Build a model that exposes the outputs of selected inner layers by name.

    Args:
        model: Keras model whose ``inputs`` are reused and whose layers are looked
            up with ``get_layer``.
        layer_names: Iterable of layer names to expose. Defaults to an empty tuple
            (an immutable default rather than a shared mutable list).

    Returns:
        A ``TFModel`` with the same inputs as ``model`` whose outputs are a dict
        mapping each requested layer name to that layer's output tensor.
    """
    outputs = {name: model.get_layer(name).output for name in layer_names}
    return TFModel(inputs=model.inputs, outputs=outputs)

In [9]:
def plot_embeddings(true, pred, emb, colors=None):
    """Show 3-D embeddings side by side, coloured by predicted and by true class.

    Args:
        true: Ground-truth class labels, one per row of ``emb``.
        pred: Predicted class labels (e.g. a boolean array from thresholding
            scores), one per row of ``emb``.
        emb: 2-D array with at least 3 columns; columns 0/1/2 are used as x/y/z.
        colors: Optional mapping from class label to plotly colour. Defaults to
            ``{0: "purple", 1: "yellow"}``.

    The left panel ("Model Prediction") is coloured by ``pred`` and the right
    panel ("Ground-Truth") by ``true``. The figure is shown; nothing is returned.
    """
    if colors is None:
        colors = {0: "purple", 1: "yellow"}

    panels = (("Model Prediction", pred), ("Ground-Truth", true))
    fig = make_subplots(
        rows=1,
        cols=len(panels),
        subplot_titles=tuple(title for title, _ in panels),
        specs=[[{"type": "scene"} for _ in panels]],
    )

    for col, (title, labels) in enumerate(panels, start=1):
        fig.add_trace(
            go.Scatter3d(
                x=emb[:, 0],
                y=emb[:, 1],
                z=emb[:, 2],
                mode="markers",
                # bool / 0.0 / 1.0 labels hash and compare equal to 0 / 1, so they
                # match the integer keys of ``colors``.
                marker=dict(color=[colors[label] for label in labels]),
                name=title,
                # Colours encode the class; auto legend entries ("trace 0") add nothing.
                showlegend=False,
            ),
            row=1,
            col=col,
        )

    fig.update_layout(title_text="Lower-Dimensional Embedding Representations")
    fig.show()

### Embedding model

In [10]:
# Wrap an inner sub-model so it returns the outputs of its "ctr" and "cls" layers as a dict.
# NOTE(review): index 2 assumes model.layers[2] is the sub-model that contains these
# layers — confirm against the saved architecture.
infer_model = get_model(model.layers[2], layer_names=["ctr", "cls"])

### Get embeddings

In [11]:
def get_embeddings(model, pipeline=None, fold=0):
    """Run ``model`` over the validation generator and collect its outputs.

    Args:
        model: Callable (e.g. the model returned by ``get_model``) mapping a batch
            ``xs["dat"]`` to a dict with ``"ctr"`` and ``"cls"`` entries.
        pipeline: Path to the pipeline YAML. Defaults to ``Path.pipeline``, looked
            up at call time so redefining ``Path`` takes effect.
        fold: Fold index passed through to ``Dataset.from_yaml`` (default 0).

    Returns:
        ``(emb, true, pred)``: the ``"ctr"`` outputs, the ``xs["lbl"]`` labels and
        the ``"cls"`` outputs, each concatenated over batches along axis 0 and
        squeezed.

    Raises:
        ValueError: if the validation generator yields no batches.
    """
    if pipeline is None:
        pipeline = Path.pipeline

    client = Dataset.from_yaml(pipeline, fold=fold)
    _, gen_valid = client.create_generators(test=True)

    emb, true, pred = [], [], []
    for xs, _ in gen_valid:
        # Explicit inference mode (e.g. dropout / batch-norm use inference behaviour).
        yh = model(xs["dat"], training=False)
        emb.append(np.asarray(yh["ctr"]))
        pred.append(np.asarray(yh["cls"]))
        # Labels are read from the inputs dict, not from the targets.
        true.append(np.asarray(xs["lbl"]))

    if not emb:
        raise ValueError(f"validation generator for {pipeline!r} yielded no batches")

    # Concatenate along the batch axis: unlike np.array(list).squeeze(), this works
    # for any batch size, including a smaller final batch.
    emb = np.concatenate(emb, axis=0).squeeze()
    true = np.concatenate(true, axis=0).squeeze()
    pred = np.concatenate(pred, axis=0).squeeze()

    return emb, true, pred

**Lower-dimensional visualizations**

In [12]:
# Run the validation generator through the inference model to collect the "ctr"
# embeddings, the ground-truth labels (xs["lbl"]) and the "cls" outputs.
emb, true, pred = get_embeddings(infer_model)



In [13]:
# Project the embeddings onto 3 principal components for plotting. random_state is
# fixed so the projection is reproducible if sklearn selects a randomized SVD solver.
emb3d = decomposition.PCA(n_components=3, random_state=0).fit_transform(emb)

In [17]:
# Left panel: colour by the "cls" output thresholded at 0.5; right panel: by ground truth.
plot_embeddings(true.squeeze(), pred.squeeze() > 0.5, emb3d)

In [15]:
# Best results: contrastive loss with an embedding size of 32; sizes 64 and 96 are also competitive.
# Best architecture: embedding size 32, depth 6, and a width scaling of 1.5.

```python

# --- Model trained with cross entropy and contrastive loss

class Path:
    base = "./exp/adni"
    run = base + "/logs/2021-12-29_02-30-40_PST"
    num = run + "/8"
    model = num + "/ae_0"
    pipeline = num + "/pipeline.yml"
  
print(classification_report(true, pred > 0.5))
```

```shell
              precision    recall  f1-score   support

           0       0.93      0.88      0.91       494
           1       0.71      0.82      0.76       175

    accuracy                           0.87       669
   macro avg       0.82      0.85      0.84       669
weighted avg       0.88      0.87      0.87       669
```

```python

# --- Model trained with only cross entropy

class Path:
    base = "./exp/adni"
    run = base + "/logs/2021-12-29_02-30-40_PST"
    num = run + "/9"
    model = num + "/ae_0"
    pipeline = num + "/pipeline.yml"
  
print(classification_report(true, pred > 0.5))
```

```shell

              precision    recall  f1-score   support

           0       0.94      0.85      0.89       494
           1       0.67      0.85      0.75       175

    accuracy                           0.85       669
   macro avg       0.80      0.85      0.82       669
weighted avg       0.87      0.85      0.85       669
```

---

```python

# --- Model trained with cross entropy and focal contrastive loss (gamma = 3)

class Path:
    base = "./exp/adni"
    run = base + "/logs/2021-12-29_12-32-16_PST"
    num = run + "/10"
    model = num + "/ae_0"
    pipeline = num + "/pipeline.yml"

print(classification_report(true, pred > 0.5))
```

```shell
              precision    recall  f1-score   support

           0       0.94      0.88      0.91       494
           1       0.71      0.83      0.77       175

    accuracy                           0.87       669
   macro avg       0.82      0.85      0.84       669
weighted avg       0.88      0.87      0.87       669
```

```python
# NOTE: the model config lists eblock as aspp, but only conv blocks were actually used.

# --- Model trained with cross entropy and focal contrastive loss (gamma = 4)

class Path:
    base = "./exp/adni"
    run = base + "/logs/2021-12-29_12-32-16_PST"
    num = run + "/11"
    model = num + "/ae_0"
    pipeline = num + "/pipeline.yml"

print(classification_report(true, pred > 0.5))
```

```shell
              precision    recall  f1-score   support

           0       0.93      0.87      0.90       494
           1       0.69      0.81      0.75       175

    accuracy                           0.86       669
   macro avg       0.81      0.84      0.82       669
weighted avg       0.87      0.86      0.86       669
```

```python
# NOTE: the model config lists eblock as aspp, but only conv blocks were actually used.
    
# --- Model trained with cross entropy and focal log contrastive loss (gamma = 2)

class Path:
    base = "./exp/adni"
    run = base + "/logs/2021-12-29_12-32-16_PST"
    num = run + "/13"
    model = num + "/ae_0"
    pipeline = num + "/pipeline.yml"

print(classification_report(true, pred > 0.5))
```

```shell
              precision    recall  f1-score   support

           0       0.92      0.91      0.92       494
           1       0.77      0.78      0.77       175

    accuracy                           0.88       669
   macro avg       0.84      0.85      0.85       669
weighted avg       0.88      0.88      0.88       669
```

---

```python

# --- Model trained with cross entropy and focal log contrastive loss (gamma = 2) w/ a batch size of 2!

class Path:
    base = "./exp/adni"
    run = base + "/logs/2021-12-29_16-03-28_PST"
    num = run + "/0"
    model = num + "/ae_0"
    pipeline = num + "/pipeline.yml"

print(classification_report(true, pred > 0.5))
```

```shell
              precision    recall  f1-score   support

           0       0.92      0.91      0.92       494
           1       0.77      0.79      0.78       175

    accuracy                           0.88       669
   macro avg       0.85      0.85      0.85       669
weighted avg       0.88      0.88      0.88       669
```

```python

# --- Model trained with cross entropy loss w/ a batch size of 2!

class Path:
    base = "./exp/adni"
    run = base + "/logs/2021-12-29_16-03-28_PST"
    num = run + "/5"
    model = num + "/ae_0"
    pipeline = num + "/pipeline.yml"

print(classification_report(true, pred > 0.5))
```

```shell
              precision    recall  f1-score   support

           0       0.90      0.92      0.91       494
           1       0.76      0.73      0.74       175

    accuracy                           0.87       669
   macro avg       0.83      0.82      0.83       669
weighted avg       0.87      0.87      0.87       669
```

In [16]:
# Per-class precision / recall / F1 on the validation generator, thresholding "cls" at 0.5.
print(classification_report(true, pred > 0.5))

              precision    recall  f1-score   support

           0       0.90      0.92      0.91       494
           1       0.76      0.73      0.74       175

    accuracy                           0.87       669
   macro avg       0.83      0.82      0.83       669
weighted avg       0.87      0.87      0.87       669

