In [2]:
# libraries
import os
import time
import random
import numpy as np
import pandas as pd
import subprocess
import cv2
import PIL.Image
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import roc_auc_score
from warmup_scheduler import GradualWarmupScheduler
import albumentations
import timm
from tqdm.notebook import tqdm
import torch.cuda.amp as amp
import warnings

warnings.simplefilter('ignore')
scaler = amp.GradScaler()
device = torch.device('cuda')

In [5]:
DEBUG = False
kernel_type = 'enetb1_5ch_512_lr3e4_bs32_30epo'
enet_type = 'tf_efficientnet_b1_ns'
data_dir = '../xray_image'
num_workers = 2
num_classes = 12
n_ch = 5
image_size = 512
batch_size = 32
init_lr = 3e-4
warmup_epo = 1
# If DEBUG == True, only run 3 epochs per fold
cosine_epo = 29 if not DEBUG else 2
n_epochs = warmup_epo + cosine_epo
loss_weights = [1., 9.]
image_folder = 'train'
mask_folder = '../generated_mask/'

log_dir = 'logs_classificarion'
model_dir = 'models_classification'
os.makedirs(log_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)
log_file = os.path.join(log_dir, f'log_{kernel_type}.txt')

In [6]:
train_df = pd.read_csv(os.path.join(data_dir,'train.csv')).drop('PatientID', axis=1)
train_df.shape

(30083, 12)

In [None]:
train_df.query('')

In [76]:
((train_df['ETT - Abnormal'] ==0) & (train_df['ETT - Borderline']==0) & (train_df['ETT - Normal']==0)).astype(int)

0        1
1        0
2        1
3        1
4        1
        ..
30078    0
30079    1
30080    0
30081    1
30082    0
Length: 30083, dtype: int64

In [77]:
1-train_df[train_df.columns[1:4]].sum(axis=1)

0        1
1        0
2        1
3        1
4        1
        ..
30078    0
30079    1
30080    0
30081    1
30082    0
Length: 30083, dtype: int64

In [70]:
train_df.head()

Unnamed: 0,StudyInstanceUID,ETT - Abnormal,ETT - Borderline,ETT - Normal,NGT - Abnormal,NGT - Borderline,NGT - Incompletely Imaged,NGT - Normal,CVC - Abnormal,CVC - Borderline,CVC - Normal,Swan Ganz Catheter Present
0,1.2.826.0.1.3680043.8.498.26697628953273228189...,0,0,0,0,0,0,1,0,0,0,0
1,1.2.826.0.1.3680043.8.498.46302891597398758759...,0,0,1,0,0,1,0,0,0,1,0
2,1.2.826.0.1.3680043.8.498.23819260719748494858...,0,0,0,0,0,0,0,0,1,0,0
3,1.2.826.0.1.3680043.8.498.68286643202323212801...,0,0,0,0,0,0,0,1,0,0,0
4,1.2.826.0.1.3680043.8.498.10050203009225938259...,0,0,0,0,0,0,0,0,0,1,0


In [9]:
all_files = train_df.StudyInstanceUID.unique()
len(all_files)

30083

In [43]:
class RANZCRDatasetCLS(Dataset):

    def __init__(self, df, mode, transform=None):

        self.df = df.reset_index(drop=True)
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        row = self.df.iloc[index]
        image = cv2.imread(os.path.join(data_dir, image_folder, row.StudyInstanceUID + '.jpg'))[:, :, ::-1]
        mask = cv2.imread(os.path.join(mask_folder, row.StudyInstanceUID + '.png')).astype(np.float32)[:,:,:1]

        res = self.transform(image=image, mask=mask)
        image = res['image'].astype(np.float32).transpose(2, 0, 1) / 255.
        mask = res['mask'].astype(np.float32).transpose(2, 0, 1) / 255.

        image = np.concatenate([image, mask], 0)

        if self.mode == 'test':
            return torch.tensor(image)
        else:
            label = row[[
                'ETT - Abnormal',
                'ETT - Borderline',
                'ETT - Normal',
                'NGT - Abnormal',
                'NGT - Borderline',
                'NGT - Incompletely Imaged',
                'NGT - Normal',
                'CVC - Abnormal',
                'CVC - Borderline',
                'CVC - Normal',
                'Swan Ganz Catheter Present'
            ]].values.astype(float)
            return torch.tensor(image).float(), torch.tensor(label).float()

In [44]:
transforms_train = albumentations.Compose([
    albumentations.Resize(image_size, image_size),
    #albumentations.HorizontalFlip(p=0.5),
    albumentations.RandomBrightness(limit=0.2, p=0.75),
    albumentations.RandomContrast(limit=0.2, p=0.75),

    albumentations.OneOf([
        albumentations.OpticalDistortion(distort_limit=1.),
        albumentations.GridDistortion(num_steps=5, distort_limit=1.),
    ], p=0.75),

    albumentations.HueSaturationValue(hue_shift_limit=40, sat_shift_limit=40, val_shift_limit=0, p=0.75),
    albumentations.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.3, rotate_limit=30, border_mode=0, p=0.75),
    #CutoutV2(max_h_size=int(image_size * 0.4), max_w_size=int(image_size * 0.4), num_holes=1, p=0.75),
])
transforms_val = albumentations.Compose([
    albumentations.Resize(image_size, image_size),
])

In [45]:
def visualize(**images):
    """PLot images in one row."""
    n = len(images)
    plt.figure(figsize=(8, 8))
    for i, (name, image) in enumerate(images.items()):
        if image.shape[0] == 3:
            image = image.numpy()
            image = image.transpose((1,2,0))
        elif image.shape[0] == 1:
            image = image.numpy()
            image = image.reshape(image.shape[1],image.shape[2])
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)
    plt.show()

In [53]:
dataset_show = RANZCRDatasetCLS(train_df,'train', transform=transforms_train)

In [61]:
for i in range(5):
    image, mask = dataset_show[i]
    print(mask.shape)
    #visualize(image=image[:3].transpose(0, 1).transpose(1,2),
             #mask= image[3:].sum(0))

torch.Size([11])
torch.Size([11])
torch.Size([11])
torch.Size([11])
torch.Size([11])


In [62]:
class enetv2(nn.Module):
    def __init__(self, enet_type, out_dim):
        super(enetv2, self).__init__()
        self.enet = timm.create_model(enet_type, True)
        self.dropout = nn.Dropout(0.5)
        self.enet.conv_stem.weight = nn.Parameter(self.enet.conv_stem.weight.repeat(1,n_ch//3+1,1,1)[:, :n_ch])
        self.myfc = nn.Linear(self.enet.classifier.in_features, out_dim)
        self.enet.classifier = nn.Identity()

    def extract(self, x):
        return self.enet(x)

    def forward(self, x):
        x = self.extract(x)
        h = self.myfc(self.dropout(x))
        return h

m = enetv2(enet_type, num_classes)
m(torch.rand(2,n_ch,image_size, image_size))

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth" to /home/jupyter/.cache/torch/hub/checkpoints/tf_efficientnet_b1_ns-99dd0c41.pth


tensor([[ 3.5208e-05, -8.1112e-02, -5.9093e-02, -1.1614e-01, -8.9228e-02,
         -1.8643e-03,  1.6376e-02,  1.1959e-01,  1.1565e-01,  1.2328e-02,
          1.1132e-01,  1.6108e-01],
        [-2.1800e-01,  1.1222e-01,  8.1100e-03,  1.1879e-01,  4.6126e-02,
         -6.4138e-02,  3.1382e-02, -1.2634e-02, -3.1873e-02, -1.2966e-01,
          1.0857e-01, -1.2204e-02]], grad_fn=<AddmmBackward>)