Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
57 changed files
with
4,345 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import os | ||
import PIL | ||
import torch | ||
import numpy as np | ||
import nibabel as nib | ||
import matplotlib.pyplot as plt | ||
|
||
from os import listdir | ||
from os.path import join | ||
from PIL import Image | ||
from utils.transform import itensity_normalize | ||
from torch.utils.data.dataset import Dataset | ||
|
||
|
||
class ISIC2018_dataset(Dataset): | ||
def __init__(self, dataset_folder='/ISIC2018_Task1_npy_all', | ||
folder='folder0', train_type='train', transform=None): | ||
self.transform = transform | ||
self.train_type = train_type | ||
self.folder_file = './Datasets/' + folder | ||
|
||
if self.train_type in ['train', 'validation', 'test']: | ||
# this is for cross validation | ||
with open(join(self.folder_file, self.folder_file.split('/')[-1] + '_' + self.train_type + '.list'), | ||
'r') as f: | ||
self.image_list = f.readlines() | ||
self.image_list = [item.replace('\n', '') for item in self.image_list] | ||
self.folder = [join(dataset_folder, 'image', x) for x in self.image_list] | ||
self.mask = [join(dataset_folder, 'label', x.split('.')[0] + '_segmentation.npy') for x in self.image_list] | ||
# self.folder = sorted([join(dataset_folder, self.train_type, 'image', x) for x in | ||
# listdir(join(dataset_folder, self.train_type, 'image'))]) | ||
# self.mask = sorted([join(dataset_folder, self.train_type, 'label', x) for x in | ||
# listdir(join(dataset_folder, self.train_type, 'label'))]) | ||
else: | ||
print("Choosing type error, You have to choose the loading data type including: train, validation, test") | ||
|
||
assert len(self.folder) == len(self.mask) | ||
|
||
def __getitem__(self, item: int): | ||
image = np.load(self.folder[item]) | ||
label = np.load(self.mask[item]) | ||
|
||
sample = {'image': image, 'label': label} | ||
|
||
if self.transform is not None: | ||
# TODO: transformation to argument datasets | ||
sample = self.transform(sample, self.train_type) | ||
|
||
return sample['image'], sample['label'] | ||
|
||
def __len__(self): | ||
return len(self.folder) | ||
|
||
# a = ISIC2018_dataset() |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
ISIC_0010854.npy |
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import torch.nn as nn | ||
|
||
|
||
# # SE block add to U-net | ||
def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1): | ||
"""3x3 convolution with padding""" | ||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=1, groups=group, bias=bias) | ||
|
||
|
||
class SE_Conv_Block(nn.Module): | ||
expansion = 4 | ||
|
||
def __init__(self, inplanes, planes, stride=1, downsample=None, drop_out=False): | ||
super(SE_Conv_Block, self).__init__() | ||
self.conv1 = conv3x3(inplanes, planes, stride) | ||
self.bn1 = nn.BatchNorm2d(planes) | ||
self.relu = nn.ReLU(inplace=True) | ||
self.conv2 = conv3x3(planes, planes * 2) | ||
self.bn2 = nn.BatchNorm2d(planes * 2) | ||
self.conv3 = conv3x3(planes * 2, planes) | ||
self.bn3 = nn.BatchNorm2d(planes) | ||
self.downsample = downsample | ||
self.stride = stride | ||
self.dropout = drop_out | ||
|
||
if planes <= 16: | ||
self.globalAvgPool = nn.AvgPool2d((224, 300), stride=1) # (224, 300) for ISIC2018 | ||
self.globalMaxPool = nn.MaxPool2d((224, 300), stride=1) | ||
elif planes == 32: | ||
self.globalAvgPool = nn.AvgPool2d((112, 150), stride=1) # (112, 150) for ISIC2018 | ||
self.globalMaxPool = nn.MaxPool2d((112, 150), stride=1) | ||
elif planes == 64: | ||
self.globalAvgPool = nn.AvgPool2d((56, 75), stride=1) # (56, 75) for ISIC2018 | ||
self.globalMaxPool = nn.MaxPool2d((56, 75), stride=1) | ||
elif planes == 128: | ||
self.globalAvgPool = nn.AvgPool2d((28, 37), stride=1) # (28, 37) for ISIC2018 | ||
self.globalMaxPool = nn.MaxPool2d((28, 37), stride=1) | ||
elif planes == 256: | ||
self.globalAvgPool = nn.AvgPool2d((14, 18), stride=1) # (14, 18) for ISIC2018 | ||
self.globalMaxPool = nn.MaxPool2d((14, 18), stride=1) | ||
|
||
self.fc1 = nn.Linear(in_features=planes * 2, out_features=round(planes / 2)) | ||
self.fc2 = nn.Linear(in_features=round(planes / 2), out_features=planes * 2) | ||
self.sigmoid = nn.Sigmoid() | ||
|
||
self.downchannel = None | ||
if inplanes != planes: | ||
self.downchannel = nn.Sequential(nn.Conv2d(inplanes, planes * 2, kernel_size=1, stride=stride, bias=False), | ||
nn.BatchNorm2d(planes * 2),) | ||
|
||
def forward(self, x): | ||
residual = x | ||
|
||
out = self.conv1(x) | ||
out = self.bn1(out) | ||
out = self.relu(out) | ||
|
||
out = self.conv2(out) | ||
out = self.bn2(out) | ||
|
||
if self.downchannel is not None: | ||
residual = self.downchannel(x) | ||
|
||
original_out = out | ||
out1 = out | ||
# For global average pool | ||
out = self.globalAvgPool(out) | ||
out = out.view(out.size(0), -1) | ||
out = self.fc1(out) | ||
out = self.relu(out) | ||
out = self.fc2(out) | ||
out = self.sigmoid(out) | ||
out = out.view(out.size(0), out.size(1), 1, 1) | ||
avg_att = out | ||
out = out * original_out | ||
# For global maximum pool | ||
out1 = self.globalMaxPool(out1) | ||
out1 = out1.view(out1.size(0), -1) | ||
out1 = self.fc1(out1) | ||
out1 = self.relu(out1) | ||
out1 = self.fc2(out1) | ||
out1 = self.sigmoid(out1) | ||
out1 = out1.view(out1.size(0), out1.size(1), 1, 1) | ||
max_att = out1 | ||
out1 = out1 * original_out | ||
|
||
att_weight = avg_att + max_att | ||
out += out1 | ||
out += residual | ||
out = self.relu(out) | ||
|
||
out = self.conv3(out) | ||
out = self.bn3(out) | ||
out = self.relu(out) | ||
if self.dropout: | ||
out = nn.Dropout2d(0.5)(out) | ||
|
||
return out, att_weight |
Oops, something went wrong.