In [28]:
import pandas as pd
import seaborn as sns 
import matplotlib.pyplot as plt
import torch
import subprocess
import numpy as np
from torchmetrics import PrecisionRecallCurve, F1Score, ConfusionMatrix, Precision, Recall
from typing import List, Any
from torcheval.metrics.functional import binary_auprc, binary_auroc
from collections import defaultdict
import torch.nn as nn

In [29]:
class MLP(nn.Module):
    """Feed-forward classifier: hidden blocks of Linear -> BatchNorm1d -> activation -> Dropout,
    followed by a final Linear layer that outputs raw (unnormalized) logits.

    Args:
        input_size: Number of input features per sample.
        num_classes: Number of output logits.
        activation: Hidden-layer activation; one of "relu", "leaky_relu", "gelu".
        hidden_sizes: Widths of the hidden layers, in order (any iterable of ints).
        dropout: Dropout probability applied after every hidden activation.

    Raises:
        ValueError: If ``activation`` is not a supported name.

    Note:
        Every layer lives in ``self.layers`` (an ``nn.Sequential``) in a fixed order, so
        state_dict keys look like ``layers.<idx>.weight``. Renaming the attribute or
        reordering layers breaks loading of previously saved checkpoints.
    """

    def __init__(
            self,
            input_size=512,
            num_classes=18,
            activation='relu',
            # Immutable default: a list default would be one object shared by every call.
            hidden_sizes=(1024, 2048, 1024, 256, 128),
            dropout=0.1
        ):
        super().__init__()

        # Resolve the activation name to its module class.
        if activation == "relu":
            activation_cls = nn.ReLU
        elif activation == "leaky_relu":
            activation_cls = nn.LeakyReLU
        elif activation == "gelu":
            activation_cls = nn.GELU
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        layers = []
        in_dim = input_size
        for h in hidden_sizes:
            layers.extend([
                nn.Linear(in_dim, h),
                nn.BatchNorm1d(h),  # helps stabilize training
                activation_cls(),
                nn.Dropout(dropout),
            ])
            in_dim = h

        # Final classification layer: no activation, so the output is raw logits.
        layers.append(nn.Linear(in_dim, num_classes))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits with ``num_classes`` entries per sample.

        Assumes ``x`` is a 2-D (batch, input_size) tensor — TODO confirm for other callers.
        """
        return self.layers(x)
    
# 2-class head on 512-d inputs with GELU activations. hidden_sizes is left at the MLP
# default so the layer layout matches the checkpoint loaded in the next cell.
model = MLP(num_classes=2, input_size=512, dropout=0.1, activation='gelu')

In [30]:
# Checkpoint to evaluate. Other checkpoints tried earlier in the same directory:
#   model_binary_ctrate_mosmed_base.pth, model_binary_ctrate_mosmed2.pth
CHECKPOINT_PATH = '/home/free4ky/projects/chest-diseases/model_binary_ctrate_mosmed_test1.pth'

# map_location='cpu': the model is still on CPU at this point, so load the tensors there
# regardless of which device they were saved from.
# weights_only=True: a state_dict is plain tensors; don't unpickle arbitrary objects.
state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu', weights_only=True)

model.load_state_dict(state_dict)

<All keys matched successfully>

In [31]:
# Inference mode: disables Dropout and makes BatchNorm use its running statistics.
model.eval()
# Use the GPU when one is available; fall back to CPU instead of failing on CPU-only machines.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(DEVICE)

MLP(
  (layers): Sequential(
    (0): Linear(in_features=512, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): GELU(approximate='none')
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=1024, out_features=2048, bias=True)
    (5): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): GELU(approximate='none')
    (7): Dropout(p=0.1, inplace=False)
    (8): Linear(in_features=2048, out_features=1024, bias=True)
    (9): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): GELU(approximate='none')
    (11): Dropout(p=0.1, inplace=False)
    (12): Linear(in_features=1024, out_features=256, bias=True)
    (13): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): GELU(approximate='none')
    (15): Dropout(p=0.1, inplace=False)
    (16): Linear(in_features=256, out_features=128, bias

In [32]:
from torch.utils.data import TensorDataset, DataLoader

# One precomputed embedding file per case; stacked row-wise in this order.
EMB_DIR = '/home/free4ky/projects/chest-diseases/test_embs'
MOSMED_CASES = ['norma', 'pneumonia', 'pneumotorax']

# map_location='cpu' keeps the dataset on CPU regardless of where the tensors were saved
# (the inference loop moves each batch to the model's device). weights_only=True avoids
# unpickling arbitrary objects from disk.
# torch.stack(..., dim=0) is equivalent to cat-ing each tensor after unsqueeze(0).
mosmed_embs = torch.stack(
    [
        torch.load(f'{EMB_DIR}/{case}_anon_lipro.pt', map_location='cpu', weights_only=True)
        for case in MOSMED_CASES
    ],
    dim=0,
)
mosmed_ds = TensorDataset(mosmed_embs)
mosmed_dl = DataLoader(mosmed_ds, batch_size=1, shuffle=False)


In [33]:
# Sanity check: one row per case, one column per embedding feature.
mosmed_ds.tensors[0].size()

torch.Size([3, 512])

In [34]:
# Softmax probability of the last class (index -1) for each embedding, in dataset order.
# NOTE(review): the name suggests index -1 is the "pathology" class — confirm against training labels.
patalogy_probs = []
model_device = next(model.parameters()).device  # follow the model instead of hard-coding 'cuda'
with torch.no_grad():
    for (emb,) in mosmed_dl:  # the DataLoader over a TensorDataset yields 1-element batches
        # L2-normalize each embedding; presumably matches the training preprocessing — TODO confirm
        x = torch.nn.functional.normalize(emb, dim=-1).to(model_device)
        logits = model(x)
        # Explicit dim: calling softmax without `dim` is deprecated and emits a UserWarning.
        probs = nn.functional.softmax(logits, dim=-1)
        # Keep every row of the batch, not just the first, so batch_size > 1 also works.
        patalogy_probs.extend(probs[:, -1].cpu().tolist())
patalogy_probs

  probs = nn.functional.softmax(logits)


[0.8047426342964172, 0.9958354234695435, 0.575995922088623]

In [22]:
# Spread of the last-class probability across cases.
# NOTE(review): the outputs of this cell and the next (In [22]/[23]) predate the inference
# run above (In [34] produced 3 probabilities, while these outputs index up to ~1900);
# restart and run all to refresh them.
s = pd.Series(data=patalogy_probs)
s.quantile(q=[0.10, 0.25, 0.75, 0.90])

0.10    0.731108
0.25    0.967450
0.75    0.999410
0.90    0.999857
dtype: float64

In [23]:
# Cases whose last-class probability falls below 0.5 (a 2-class argmax would pick the other class).
s.loc[s.lt(0.5)]

19      0.110620
26      0.045772
30      0.331867
37      0.069041
46      0.490368
          ...   
1848    0.439804
1871    0.448998
1873    0.307171
1878    0.457991
1903    0.267192
Length: 119, dtype: float64