In [116]:
import argparse
import os
import glob
import json
import random
from tqdm import tqdm as tq
import time
from datetime import datetime
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


import cv2
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import torch
from torch.utils.data import BatchSampler
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as ddp
from torchinfo import summary

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint,
    RichProgressBar,
    RichModelSummary
)

In [117]:
class CFG:
    """Empty attribute bag; settings are attached to an instance after creation."""

    def __init__(self):
        # Nothing to set up here: every field is assigned by the code below.
        pass


# Notebook-wide settings read by the dataset class.
configs = CFG()
vars(configs).update(
    dataset_dir='/mnt/e/dataset/chexpert/CheXpert-v1.0',  # root folder holding the CSVs and images
    batch_size=4,
    image_size=(1024, 1024),  # (height, width) of the final centre crop
    image_add_size=64,        # extra pixels added to the resize target before cropping
)

In [138]:
class ImageDataset(torch.utils.data.Dataset):
    """Map-style dataset that reads image paths and labels from ``{dataset_dir}/{mode}.csv``.

    The CSV must contain a ``Path`` column. Every column from index
    ``LABEL_START_COL`` onward is treated as a numeric label: missing values
    become 0 and the value -1 is mapped to 0.

    Args:
        configs: object exposing ``dataset_dir``, ``batch_size``, ``image_size``
            (height, width) and ``image_add_size``.
        mode: name of the split; selects the CSV file and controls whether the
            training augmentation is applied and whether labels are returned.

    Attributes:
        X: image paths, one per CSV row.
        Y: ``float32`` array of shape (rows, label columns).
    """

    # Columns from this index onward are read as labels.
    LABEL_START_COL = 5

    def __init__(self, configs, mode='train'):
        # arguments
        self.configs = configs
        self.mode = mode
        self.data_dir = f"{configs.dataset_dir}/{mode}.csv"

        # configuration
        self.batch_size = configs.batch_size
        self.image_size = configs.image_size
        self.image_add_size = configs.image_add_size

        # data processing
        data = pd.read_csv(self.data_dir)

        # Literal (non-regex) substitution so that '.' in the prefix and any
        # backslashes in dataset_dir are never interpreted as regex syntax.
        paths = data['Path'].str.replace("CheXpert-v1.0", configs.dataset_dir, regex=False)

        # Build a real numeric array: slicing ``data.values`` of a mixed-dtype
        # frame yields an object array, which PyTorch's default collate rejects.
        labels = (
            data.iloc[:, self.LABEL_START_COL:]
            .fillna(0.0)
            .to_numpy(dtype=np.float32, copy=True)
        )
        labels[labels == -1.0] = 0.0  # -1 entries are folded into 0

        # store data
        self.X = paths.values
        self.Y = labels

    def __len__(self):
        """Return the number of rows read from the CSV."""
        return len(self.X)

    def __getitem__(self, idx):
        """Load, resize and (in train mode) augment the image at ``idx``.

        Returns:
            ``(x, y)`` where ``x`` is the output of ``train_transform`` when
            ``mode == 'train'`` and the array returned by ``resize`` otherwise;
            ``y`` is the label row for modes 'train', 'valid' and 'val', else ``None``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        # NOTE(review): train mode goes through ToTensorV2 while the other modes
        # return the NumPy array from resize(), and y is None outside
        # train/valid/val -- confirm downstream code and the collate function
        # expect this.
        image = self.resize(self._read_image(self.X[idx]))
        x = self.train_transform(image) if self.mode == 'train' else image
        y = self.Y[idx] if self.mode in ('train', 'valid', 'val') else None
        return x, y

    @staticmethod
    def _read_image(path):
        """Read ``path`` with OpenCV, failing loudly instead of returning ``None``."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals a missing or undecodable file by returning None.
            raise FileNotFoundError(f"cv2.imread could not read image: {path}")
        return image

    def resize(self, x):
        """Resize ``x`` so both sides exceed ``image_size``, then centre-crop to ``image_size``.

        The shorter side is resized to its ``image_size`` entry plus
        ``image_add_size``; the longer side is scaled by the H/W (or W/H) ratio
        before the same margin is added.

        Args:
            x: image array whose first two axes are (height, width).

        Returns:
            NumPy array cropped to ``image_size[0]`` x ``image_size[1]``.
        """
        H, W = x.shape[:2]

        if H > W:
            # Portrait: width is the short side, height keeps the H/W ratio.
            resize_fn = A.Resize(
                height=self.image_size[0] * H // W + self.image_add_size,
                width=self.image_size[1] + self.image_add_size
            )
        else:  # W >= H
            # Landscape or square: height is the short side.
            resize_fn = A.Resize(
                height=self.image_size[0] + self.image_add_size,
                width=self.image_size[1] * W // H + self.image_add_size
            )

        x = A.Compose([
            resize_fn,
            A.CenterCrop(height=self.image_size[0], width=self.image_size[1]),
        ])(image=x)['image']

        return x     # dtype : numpy array

    def train_transform(self, x):
        """Apply the training affine augmentation and convert with ``ToTensorV2``."""
        # NOTE(review): scale / translate_percent / shear are passed as single
        # numbers rather than (min, max) ranges -- check the albumentations
        # Affine docs for the installed version to confirm this gives the
        # intended amount of randomness.
        x = A.Compose([
            A.Affine(
                scale=0.95,
                translate_percent=0.05,  # moving
                shear=0.05,              # distortion
                rotate=None,
                interpolation=1,
            ),
            ToTensorV2()
        ])(image=x)['image']
        return x