In [1]:
import os
from torch.utils.data.dataset import Dataset
from torchvision import transforms
import pandas as pd
from PIL import Image
import torch
import numpy as np

In [2]:
# def read_data(fn, clinical_root, df, mode):
#     img = Image.open(fn)
    
#     folders = fn.split('/')
#     img_name = folders[-1]
    
    
#     p_number = img_name.split('.')[0]   
#     try:
#         p_number = p_number.split('_')[0]
#     except:
#         pass

#     p_number = int(p_number)
        
#     p_info = df['number'] == p_number
#     label = df[p_info]['EGFR']
#     label = label.values[0]
        
#     if p_number == 38:
#         label = 1
        
#     return img, label, p_number

In [3]:
class ImageDataset_prev(Dataset):
    """Earlier EGFR image dataset returning ``(img, one-hot EGFR target)``.

    Image files live directly in ``root``; their names start with an integer
    patient number (``<number>.png`` or ``<number>_<suffix>.png``).  A file is
    kept when its name contains ``'png'`` and its patient number appears in
    the ``number`` column of the selected sheet of the clinical workbook.

    Args:
        root: Directory containing the image files.
        clinical_root: Path to the clinical Excel workbook.
        mode: ``'train'`` reads the ``'Training'`` sheet; any other value
            reads the ``'Validation'`` sheet.
        transform: Optional callable applied to each image in ``__getitem__``.

    Raises:
        ValueError: if a file name containing ``'png'`` does not start with
            an integer.
    """

    def __init__(self, root, clinical_root, mode, transform=None):
        super(ImageDataset_prev, self).__init__()

        sheet_name = 'Training' if mode == 'train' else 'Validation'
        df = pd.read_excel(clinical_root, sheet_name=sheet_name)
        # A set gives O(1) membership tests instead of scanning the whole
        # ``number`` column once per image file.
        ids = set(df['number'].values.tolist())

        filenames = list()
        for imgs in os.listdir(root):
            if 'png' in imgs and self._patient_number(imgs) in ids:
                filenames.append(os.path.join(root, imgs))

        self.filenames = filenames
        self.root = root
        self.clinical_root = clinical_root
        self.mode = mode
        self.df = df
        self.transform = transform

    @staticmethod
    def _patient_number(filename):
        """Return the integer patient number at the start of ``filename``'s base name."""
        stem = os.path.basename(filename).split('.')[0]
        # str.split never raises, so a name without '_' simply yields the stem.
        return int(stem.split('_')[0])

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``self.filenames[index]``.

        ``target`` is ``torch.tensor([1, 0])`` when the EGFR label is 0 and
        ``torch.tensor([0, 1])`` otherwise.

        Raises:
            IndexError: if the patient number is not in ``self.df['number']``.
        """
        fn = self.filenames[index]
        # The label is looked up here instead of via the module-level
        # ``read_data``: that helper takes ``(fn, df)`` and returns six values,
        # so the previous four-argument call raised a TypeError.
        img = Image.open(fn)
        p_number = self._patient_number(fn)
        label = self.df.loc[self.df['number'] == p_number, 'EGFR'].values[0]
        # NOTE(review): hard-coded label override for patient 38, kept from the
        # original read_data logic; the reason is not recorded here — confirm.
        if p_number == 38:
            label = 1

        if self.transform:
            img = self.transform(img)
        target = torch.tensor([1, 0]) if label == 0 else torch.tensor([0, 1])
        return img, target

In [4]:
def read_data(fn, df):
    """Open one image and look up its EGFR label and clinical covariates.

    The patient number is parsed from the image's base file name, which is
    expected to look like ``<number>.png`` or ``<number>_<suffix>.png``, and
    matched against the ``number`` column of ``df``.

    Args:
        fn: Path to the image file.
        df: Clinical table with at least the columns ``number``, ``EGFR``,
            ``Sex``, ``Age_norm`` and ``Smoking``.

    Returns:
        Tuple ``(img, label, sex, age, smoke, p_number)``.  ``img`` is the
        object returned by ``Image.open(fn)``; ``label``, ``sex``, ``age`` and
        ``smoke`` come from the first row of ``df`` whose ``number`` equals
        ``p_number``; ``p_number`` is the parsed patient number (``int``).

    Raises:
        ValueError: if the file name does not start with an integer.
        IndexError: if ``p_number`` does not appear in ``df['number']``.
    """
    img = Image.open(fn)

    # os.path.basename respects the platform's path separator, unlike a
    # hand-written split on '/'.
    stem = os.path.basename(fn).split('.')[0]
    # str.split never raises, so no try/except is needed: a name without
    # '_' simply yields the whole stem.
    p_number = int(stem.split('_')[0])

    # Filter the table once and read every field from the same matching rows.
    matches = df[df['number'] == p_number]
    if matches.empty:
        raise IndexError(
            f"patient number {p_number} (from {fn!r}) not found in clinical table")

    label = matches['EGFR'].values[0]
    sex = matches['Sex'].values[0]
    age = matches['Age_norm'].values[0]
    smoke = matches['Smoking'].values[0]

    # NOTE(review): hard-coded label override for patient 38; the reason is
    # not recorded in this code — confirm it is still wanted.
    if p_number == 38:
        label = 1

    return img, label, sex, age, smoke, p_number

In [5]:
class ImageDataset(Dataset):
    """EGFR image dataset returning ``(img, target, clinical, p_number)``.

    Keeps every file in ``root`` whose name contains ``'png'`` and whose
    leading integer patient number (``<number>.png`` or
    ``<number>_<suffix>.png``) is a member of ``imgs_name``.  Per-sample
    values come from the module-level ``read_data`` helper.

    Args:
        root: Directory containing the image files.
        imgs_name: Collection of patient numbers to keep (tested with ``in``).
        clinical_root: Path to the clinical Excel workbook.
        mode: ``'test'`` reads the ``'Test'`` sheet; any other value reads
            the ``'Training'`` sheet.
        transform: Optional callable applied to each image.
    """

    def __init__(self, root, imgs_name, clinical_root, mode, transform=None):
        super(ImageDataset, self).__init__()

        def patient_of(name):
            # '<number>.png' or '<number>_<suffix>.png' -> <number>
            return int(name.split('.')[0].split('_')[0])

        kept = [
            os.path.join(root, name)
            for name in os.listdir(root)
            if 'png' in name and patient_of(name) in imgs_name
        ]

        sheet = 'Test' if mode == 'test' else 'Training'
        clinical_table = pd.read_excel(clinical_root, sheet_name=sheet)

        self.clinical_root = clinical_root
        self.filenames = kept
        self.root = root
        self.transform = transform
        self.df = clinical_table

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """Return ``(img, target, clinical, p_number)`` for one sample.

        ``target`` is an integer one-hot tensor (``[1, 0]`` for label 0,
        ``[0, 1]`` otherwise) and ``clinical`` is a float32 tensor holding
        ``[sex, age, smoke]`` as returned by ``read_data``.
        """
        img, label, sex, age, smoke, p_number = read_data(self.filenames[index], self.df)

        if self.transform:
            img = self.transform(img)

        target = torch.tensor([1, 0] if label == 0 else [0, 1])
        clinical = torch.from_numpy(np.array([sex, age, smoke], dtype=np.float32))
        return img, target, clinical, p_number

In [6]:
class ImageDataset_BCE(Dataset):
    """EGFR image dataset with float one-hot targets for the label and sex.

    Keeps every file in ``root`` whose name contains ``'png'`` and whose
    leading integer patient number (``<number>.png`` or
    ``<number>_<suffix>.png``) is a member of ``imgs_name``.  Per-sample
    values come from the module-level ``read_data`` helper.

    Args:
        root: Directory containing the image files.
        imgs_name: Collection of patient numbers to keep (tested with ``in``).
        clinical_root: Path to the clinical Excel workbook.
        mode: ``'test'`` reads the ``'Test'`` sheet; any other value reads
            the ``'Training'`` sheet.
        transform: Optional callable applied to each image.
    """

    def __init__(self, root, imgs_name, clinical_root, mode, transform=None):
        super(ImageDataset_BCE, self).__init__()

        filenames = list()
        for imgs in os.listdir(root):
            if 'png' in imgs and self._patient_number(imgs) in imgs_name:
                filenames.append(os.path.join(root, imgs))

        sheet_name = 'Test' if mode == 'test' else 'Training'
        df = pd.read_excel(clinical_root, sheet_name=sheet_name)

        self.clinical_root = clinical_root
        self.filenames = filenames
        self.root = root
        self.transform = transform
        self.df = df

    @staticmethod
    def _patient_number(filename):
        """Return the integer patient number at the start of ``filename``'s base name."""
        stem = os.path.basename(filename).split('.')[0]
        # str.split never raises, so a name without '_' simply yields the stem.
        return int(stem.split('_')[0])

    @staticmethod
    def _one_hot(value):
        """Float32 two-class one-hot tensor: ``[1, 0]`` if ``value == 0`` else ``[0, 1]``."""
        vec = np.array([1, 0] if value == 0 else [0, 1], dtype=np.float32)
        return torch.from_numpy(vec)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """Return ``(img, target, target_sex)``, both targets float32 one-hot tensors."""
        # read_data returns (img, label, sex, age, smoke, p_number); the old
        # four-name unpack always failed with "too many values to unpack".
        img, label, sex, _age, _smoke, _p_number = read_data(self.filenames[index], self.df)

        if self.transform:
            img = self.transform(img)

        return img, self._one_hot(label), self._one_hot(sex)