In [1]:
import os
import pathlib

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

In [2]:
# Root directory under which the BiHealth data, checkpoints and outputs are
# looked up. Set the BIHEALTH_BASE_PATH environment variable to run the
# notebook on another machine; the default keeps the original location.
base_path = os.environ.get("BIHEALTH_BASE_PATH", "/home/jakobs")

In [3]:
# Endpoint metadata: read the CSV, discard the "Unnamed: 0" column and use
# the "endpoint" column as the index.
endpoints_md = (
    pd.read_csv(f"{base_path}/BiHealth/onnx/endpoints.csv")
    .drop(columns="Unnamed: 0")
    .set_index("endpoint")
)
endpoints_md

Unnamed: 0_level_0,eligable,n,freq,phecode,phecode_string,phecode_category,sex
endpoint,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
OMOP_4306655,61256,3548,0.057921,4306655.00,All-Cause Death,Death,Both
phecode_002,60945,658,0.010797,2.00,Staphylococcus,ID,Both
phecode_002-1,61010,486,0.007966,2.10,Staphylococcus aureus,ID,Both
phecode_003,60757,1017,0.016739,3.00,Escherichia coli,ID,Both
phecode_004,60584,494,0.008154,4.00,Streptococcus,ID,Both
...,...,...,...,...,...,...,...
phecode_977-52,31669,520,0.016420,977.52,Hormone replacement therapy (postmenopausal),Rx,Female
phecode_977-7,60032,2231,0.037164,977.70,Long term (current) use of insulin or oral hyp...,Rx,Both
phecode_977-71,60936,472,0.007746,977.71,Long term (current) use of insulin,Rx,Both
phecode_977-72,60207,2148,0.035677,977.72,Long term (current) use of oral hypoglycemic d...,Rx,Both


In [4]:
import math
import warnings
from socket import gethostname

import numpy as np
import torch
import torchvision as tv
import torchmetrics

from retinalrisk.models.supervised import (
    ImageTraining
)
from retinalrisk.modules.head import MLPHead

def setup_training(ckpt_path=None, num_endpoints=1171):
    """Assemble the image model used for inference in this notebook.

    Builds a torchvision ``convnext_small`` encoder (``weights="DEFAULT"``),
    replaces the ``'2'`` submodule of its classifier with ``Identity`` and
    combines it with an ``MLPHead`` inside ``ImageTraining`` (``task="tte"``).
    No checkpoint weights are loaded here; only the ``label_mapping``
    hyper-parameter is read from the checkpoint.

    Args:
        ckpt_path: Checkpoint whose ``hyper_parameters["label_mapping"]`` is
            forwarded to ``ImageTraining``. Defaults to the partition-0
            checkpoint below the notebook-level ``base_path``.
        num_endpoints: Number of outputs of the head.

    Returns:
        The assembled ``ImageTraining`` module.
    """
    if ckpt_path is None:
        ckpt_path = f"{base_path}/BiHealth/ckpts/partition_0_last.ckpt"

    # Only hyper-parameters are read, so there is no reason to place the
    # checkpoint tensors on an accelerator.
    # NOTE(review): torch.load unpickles the file - only open trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    label_mapping = ckpt["hyper_parameters"]["label_mapping"]

    encoder = tv.models.__dict__["convnext_small"](weights="DEFAULT")

    # Input width of the head.
    # NOTE(review): presumably the encoder's feature size once classifier
    # submodule '2' is bypassed - confirm against the torchvision model.
    num_head_features = 768
    setattr(encoder.classifier, "2", torch.nn.Identity())

    head = MLPHead(
        num_head_features,
        num_endpoints,
        incidence=None,
        dropout=0.5,
        gradient_checkpointing=False,
        num_hidden=512,
        num_layers=2,
    )

    return ImageTraining(
        encoder=encoder,
        head=head,
        losses=None,
        label_mapping=label_mapping,
        incidence_mapping=None,
        task="tte",
    )

  warn(f"Failed to load image Python extension: {e}")


In [5]:
import torch.nn as nn
import PIL
from typing import Union
from random import choice
import torchvision as tv
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode

import torch
import torchvision.transforms as transforms
from skimage.io import imread

# Crop ratio(s) handed to AdaptiveRandomCropTransform; the transform multiplies
# the chosen ratio by the shorter image edge to obtain the crop side.
crop_ratio = [0.66]
# Value passed to AdaptiveRandomCropTransform as ``out_size``.
img_size_to_gpu = 420

class AdaptiveRandomCropTransform(nn.Module):
    """Random square crop whose side scales with the image's shorter edge.

    The crop side is ``int(ratio * min(sample.size))``. If that is at least
    ``out_size`` the crop is taken at that side and then passed through
    ``TF.resize(cropped, out_size, interpolation)``; otherwise a crop of side
    ``out_size`` is taken and returned without resizing.

    Args:
        crop_ratio: A single ratio, or a list/tuple of ratios from which one
            is drawn at random on every call.
        out_size: Target size; also the minimum crop side.
        interpolation: Interpolation mode forwarded to ``TF.resize``.
    """

    def __init__(self,
                 crop_ratio: Union[list, tuple, float],
                 out_size: int,
                 interpolation=InterpolationMode.BILINEAR):
        super().__init__()
        self.crop_ratio = crop_ratio
        self.out_size = out_size
        self.interpolation = interpolation

    def forward(self, sample):
        """Crop (and, when the crop is large enough, resize) ``sample``.

        ``sample.size`` must be a sequence of edge lengths that ``min`` can
        reduce (the notebook passes PIL images).
        """
        shorter_edge = min(sample.size)
        if isinstance(self.crop_ratio, (list, tuple)):
            ratio = choice(self.crop_ratio)
        else:
            ratio = self.crop_ratio

        side = int(ratio * shorter_edge)
        needs_resize = side >= self.out_size
        if not needs_resize:
            # The requested crop would be smaller than the target: crop
            # directly at the target size instead of upsampling afterwards.
            side = int(self.out_size)

        # An explicit (h, w) pair replaces the private torchvision helper
        # ``_setup_size`` that was previously used to build it.
        i, j, h, w = transforms.RandomCrop.get_params(sample, (side, side))
        cropped = TF.crop(sample, i, j, h, w)
        if not needs_resize:
            return cropped
        return TF.resize(cropped, self.out_size, self.interpolation)
    

In [6]:
# Image pipeline applied by the dataset: random adaptive crop, tensor
# conversion, centre crop to 384 and per-channel normalisation.
transform = transforms.Compose(
    [
        AdaptiveRandomCropTransform(
            crop_ratio=crop_ratio,
            out_size=img_size_to_gpu,
            interpolation=InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.CenterCrop(size=384),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

#invTrans = transforms.Compose([ 
#    transforms.Normalize(mean = [ 0., 0., 0. ], std = [ 1/0.229, 1/0.224, 1/0.225 ]), 
#    transforms.Normalize(mean = [ -0.485, -0.456, -0.406],  std = [ 1., 1., 1. ])])

In [7]:
from torch.utils.data import Dataset, DataLoader

class EPICImagesDataset(Dataset):
    """Dataset yielding ``(file name, transformed image)`` pairs.

    Args:
        data_images: Table supporting ``.iloc[index]["distfilename"]``; the
            ``distfilename`` value is the image file name.
        transform: Callable applied to the PIL image built from the file.
    """

    def __init__(self, data_images, transform):
        self.data_images = data_images
        self.transform = transform

    def __len__(self):
        return len(self.data_images)

    def __getitem__(self, index):
        """Load image ``index`` and return ``(img_name, transform(image))``.

        Raises:
            RuntimeError: If the transform fails; the message names the
                offending file and the original error is chained as the cause.
        """
        img_name = self.data_images.iloc[index]["distfilename"]

        # Try the main image folder first and fall back to the poor-quality
        # folder. ``Exception`` (rather than a bare ``except``) keeps the
        # best-effort fallback without swallowing KeyboardInterrupt/SystemExit.
        try:
            img_np = imread(f"{base_path}/BiHealth/Data/EPICImages/{img_name}")
        except Exception:
            img_np = imread(f"{base_path}/BiHealth/Data/EPICImages_PoorQuality/{img_name}")

        img_pil = PIL.Image.fromarray(img_np)

        # Previously a failing transform only printed the file name and the
        # method then crashed with UnboundLocalError on ``img_tensor``.
        try:
            img_tensor = self.transform(img_pil)
        except Exception as exc:
            raise RuntimeError(f"Failed to transform image {img_name!r}") from exc

        return img_name, img_tensor

def collate_fn(batch):
    """Turn a list of ``(name, tensor)`` samples into ``(names, stacked)``.

    Returns:
        A tuple of the names as a list and the tensors stacked with
        ``torch.stack`` (new leading dimension).
    """
    name_column, tensor_column = zip(*batch)
    return list(name_column), torch.stack(tensor_column)

def predict_batch(model, img_batch):
    """Run ``model`` on ``img_batch`` and return its logits as a NumPy array.

    The forward pass runs under ``torch.no_grad()`` so no autograd graph is
    built for inference; the returned values are unchanged.

    Args:
        model: Callable returning a mapping with
            ``["head_outputs"]["logits"]`` holding a tensor.
        img_batch: Input passed to ``model`` unchanged.

    Returns:
        ``numpy.ndarray`` with the logits, moved to the CPU.
    """
    with torch.no_grad():
        logits = model(img_batch)["head_outputs"]["logits"]
    return logits.detach().cpu().numpy()

# File names removed from the study table before inference.
corrupted_files = [
    "0AIULA8E31FVNEXW_epiceye07142.png",
    "0AIULA8E315UZ3KA_epiceye03519.png",
    "0AIULA8E3354WXMB_epiceye03739.png",
    "0AIULA8E32JU9I3E_epiceye00148.png",
    "0AIULA8E3354JOZ0_epiceye06941.png",
    "0AIULA8E315XYVCL_epiceye05788.png",
    "0AIULA8E315WMUA2_epiceye05000.png",
    "0AIULA8E32JRQLLF_epiceye01039.png",
    "0AIULA8E329S6BCD_epiceye02546.png",
    "0AIULA8E31RCZB57_epiceye00155.png",
    "0AIULA8E31REEET3_epiceye05711.png",
    "0AIULA8E32SJJSL9_epiceye03063.png",
    "0AIULA8E31FQZ4OO_epiceye06657.png",
    "0AIULA8E1HIF7IFB_epiceye05427.png",
    "0AIULA8E32SDCBXM_epiceye00289.png",
    "0AIULA8E1KEKKV8Y_epiceye05179.png",
]

# Study table restricted to rows whose file name is not in the list above.
data_images = (
    pd.read_stata(f"{base_path}/BiHealth/Data/StudyData/BiHealth_20230313_Long.dta")
    .query("distfilename!=@corrupted_files")
)

dataset = EPICImagesDataset(data_images, transform)
dataloader = DataLoader(
    dataset,
    batch_size=5,
    shuffle=False,
    num_workers=32,
    collate_fn=collate_fn,
    drop_last=False,
)
model = setup_training()

In [None]:
partitions = list(range(22))
n_iterations = 10  # passes over the dataset per partition (the crop is random)
device = "cuda"


def _substate(state_dict, prefix):
    """Return the entries of ``state_dict`` under ``prefix`` with the prefix stripped."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Mode and device do not change between checkpoints, so set them once.
model.eval()
model.to(device)

metadata = []
# Partition is the outer loop so that every checkpoint is read from disk once
# instead of once per iteration.
for partition in tqdm(partitions):
    ckpt = torch.load(
        f"{base_path}/BiHealth/ckpts/partition_{partition}_last.ckpt",
        map_location="cpu",
    )
    model.encoder.load_state_dict(_substate(ckpt["state_dict"], "encoder."), strict=True)
    model.head.load_state_dict(_substate(ckpt["state_dict"], "head."), strict=True)
    del ckpt

    for iteration in tqdm(range(n_iterations)):
        # Inference only: disable autograd and run under fp16 autocast.
        with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.float16):
            for img_names, img_batch in dataloader:
                img_batch = img_batch.to(device)
                loghs = predict_batch(model, img_batch)
                for img_name, logh in zip(img_names, loghs):
                    metadata.append({"partition": partition, "img_name": img_name, "iteration": iteration, "loghs": logh})
                del img_batch

# Restore the iteration-major row order the results had before the loops were
# swapped; list.sort is stable, so the dataloader order within each
# (iteration, partition) pair is preserved.
metadata.sort(key=lambda row: (row["iteration"], row["partition"]))
metadata_df = pd.DataFrame(metadata)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

In [12]:
# Write the long-format predictions (with a fresh default index) to Feather.
(
    metadata_df
    .reset_index(drop=True)
    .to_feather(f"{base_path}/data/predictionstte_test_230321.feather")
)

In [16]:
# Display the collected predictions: one row per image, partition and
# iteration, with the logit vector stored in the "loghs" column.
metadata_df

Unnamed: 0,partition,img_name,iteration,loghs
0,0,0AIULA8E315X35Q2_epiceye00372.png,0,"[1.344, 3.232, 2.7, 3.467, 2.248, 0.7373, 0.80..."
1,0,0AIULA8E1D5J220T_epiceye01548.png,0,"[1.67, 3.564, 3.152, 3.84, 3.0, 1.064, 2.045, ..."
2,0,0AIULA8E13CCINZ0_epiceye02465.png,0,"[-0.5376, 1.974, 1.755, 2.041, 1.311, -0.4702,..."
3,0,0AIULA8E2K7MOEX9_epiceye03852.png,0,"[0.903, 2.592, 2.076, 3.004, 1.923, 0.542, 0.5..."
4,0,0AIULA8E21BAWQ50_epiceye07263.png,0,"[0.583, 2.611, 2.234, 2.88, 2.107, 0.9736, 0.5..."
...,...,...,...,...
64099,1,0AIULA8E24KW7A44_epiceye02489.png,1,"[2.125, 4.336, 3.863, 4.52, 3.531, 1.7705, 2.7..."
64100,1,0AIULA8E32S7E8OO_epiceye03011.png,1,"[1.937, 4.367, 3.84, 4.83, 3.777, 2.633, 2.23,..."
64101,1,0AIULA8E2C2N6MYL_epiceye02574.png,1,"[1.585, 3.309, 2.93, 3.459, 2.756, 0.9375, 2.0..."
64102,1,0AIULA8E18C97Z2K_epiceye00650.png,1,"[0.9893, 2.69, 2.373, 3.22, 2.086, 0.819, 1.61..."


In [20]:
# Display the endpoint metadata again; its index provides the column names
# assigned to the wide prediction table in the following cell.
endpoints_md

Unnamed: 0_level_0,eligable,n,freq,phecode,phecode_string,phecode_category,sex
endpoint,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
OMOP_4306655,61256,3548,0.057921,4306655.00,All-Cause Death,Death,Both
phecode_002,60945,658,0.010797,2.00,Staphylococcus,ID,Both
phecode_002-1,61010,486,0.007966,2.10,Staphylococcus aureus,ID,Both
phecode_003,60757,1017,0.016739,3.00,Escherichia coli,ID,Both
phecode_004,60584,494,0.008154,4.00,Streptococcus,ID,Both
...,...,...,...,...,...,...,...
phecode_977-52,31669,520,0.016420,977.52,Hormone replacement therapy (postmenopausal),Rx,Female
phecode_977-7,60032,2231,0.037164,977.70,Long term (current) use of insulin or oral hyp...,Rx,Both
phecode_977-71,60936,472,0.007746,977.71,Long term (current) use of insulin,Rx,Both
phecode_977-72,60207,2148,0.035677,977.72,Long term (current) use of oral hypoglycemic d...,Rx,Both


In [21]:
# Expand each row's logit vector into one column per endpoint. Stacking the
# arrays once avoids ``Series.apply(pd.Series)``, which constructs a Series
# object for every row.
logit_matrix = np.stack(metadata_df["loghs"].tolist())
if logit_matrix.shape[1] != len(endpoints_md.index):
    raise ValueError(
        f"Got {logit_matrix.shape[1]} logits per row but "
        f"{len(endpoints_md.index)} endpoints in endpoints_md"
    )
wide_df = pd.DataFrame(logit_matrix, index=metadata_df.index, columns=endpoints_md.index)

# Attach the endpoint columns to the row metadata and drop the packed vectors.
predictions = metadata_df.merge(wide_df, how="left", left_index=True, right_index=True).drop("loghs", axis=1)

predictions.to_feather(f"{base_path}/data/predictionsttewide_test_230321.feather")

endpoint,OMOP_4306655,phecode_002,phecode_002-1,phecode_003,phecode_004,phecode_005,phecode_007,phecode_007-1,phecode_008,phecode_009,...,phecode_977,phecode_977-4,phecode_977-41,phecode_977-5,phecode_977-51,phecode_977-52,phecode_977-7,phecode_977-71,phecode_977-72,phecode_979
0,1.343750,3.232422,2.699219,3.466797,2.248047,0.737305,0.809570,0.778809,1.530273,2.080078,...,1.184570,2.320312,2.289062,0.971680,-0.614258,0.819824,0.711914,3.386719,0.425537,2.531250
1,1.669922,3.564453,3.152344,3.839844,3.000000,1.064453,2.044922,1.974609,1.855469,2.541016,...,1.370117,2.869141,2.826172,1.280273,-0.461670,0.809570,0.809570,3.033203,0.664551,2.691406
2,-0.537598,1.973633,1.754883,2.041016,1.310547,-0.470215,-0.130737,-0.117676,1.330078,0.319336,...,0.790527,1.954102,1.971680,1.132812,0.881348,-0.056274,0.271973,2.138672,0.066895,0.527832
3,0.902832,2.591797,2.076172,3.003906,1.922852,0.541992,0.552734,0.520020,1.619141,1.688477,...,0.861816,2.115234,2.052734,1.460938,0.058441,1.100586,0.232910,2.519531,0.048218,1.826172
4,0.583008,2.611328,2.234375,2.880859,2.107422,0.973633,0.504883,0.484619,1.743164,1.376953,...,1.108398,2.203125,2.177734,1.914062,0.997559,1.265625,0.769531,2.966797,0.587891,1.507812
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
64099,2.125000,4.335938,3.863281,4.519531,3.531250,1.770508,2.775391,2.804688,1.763672,3.300781,...,1.792969,3.128906,3.093750,1.624023,-0.510254,1.240234,1.502930,4.140625,1.332031,3.583984
64100,1.936523,4.367188,3.839844,4.828125,3.777344,2.632812,2.230469,2.205078,2.326172,3.603516,...,1.735352,3.513672,3.390625,2.453125,0.060486,2.216797,1.346680,4.257812,1.035156,3.777344
64101,1.584961,3.308594,2.929688,3.458984,2.755859,0.937500,2.044922,1.961914,1.575195,2.275391,...,1.388672,2.652344,2.625000,0.820312,-0.648438,0.307617,1.007812,3.121094,0.949707,2.275391
64102,0.989258,2.689453,2.373047,3.220703,2.085938,0.818848,1.610352,1.484375,1.220703,1.419922,...,1.109375,2.441406,2.425781,2.158203,0.825195,1.704102,0.426270,2.742188,0.348389,1.832031
