In [None]:
import os
import random
import torch
import numpy as np
import torchvision.transforms as transforms
import pydicom
from scipy.ndimage.filters import median_filter
from lungmask import mask
import SimpleITK as sitk
import cv2
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader

In [None]:
def transform_to_hu(medical_image, image):
    """Rescale stored pixel values into Hounsfield units.

    Args:
        medical_image: DICOM dataset (e.g. from ``pydicom.dcmread``) that provides
            ``RescaleIntercept`` and ``RescaleSlope``.
        image: pixel values (array or scalar) to rescale.

    Returns:
        ``image * RescaleSlope + RescaleIntercept``.
    """
    offset = medical_image.RescaleIntercept
    gain = medical_image.RescaleSlope
    return image * gain + offset

def get_mask(filename, plot_mask=False, return_val=False):
    """Segment the lungs in the image file at ``filename`` with lungmask.

    Args:
        filename: path readable by ``SimpleITK.ReadImage``.
        plot_mask: accepted for call compatibility; not used by this function.
        return_val: when True the mask is returned; otherwise it is computed
            and discarded.

    Returns:
        The first element of ``mask.apply(...)`` when ``return_val`` is True,
        else ``None``.
    """
    # default model is U-net(R231)
    lung_mask = mask.apply(sitk.ReadImage(filename))[0]
    # NOTE(review): lungmask's R231 model reportedly labels the two lungs with
    # different non-zero ids — binarize before using this as a 0/1 multiplier.
    return lung_mask if return_val else None

def preprocess_images(img, dicom_image):
    """Convert a raw DICOM pixel array to Hounsfield units and denoise it.

    Args:
        img: pixel array read from ``dicom_image``; expected to be 2-D because
            the median-filter window ``size=(3, 3)`` has two axes.
        dicom_image: DICOM dataset providing the rescale slope/intercept used
            by ``transform_to_hu``.

    Returns:
        The HU image after a 3x3 median filter (for noise reduction).
    """
    return median_filter(transform_to_hu(dicom_image, img), size=(3, 3))

In [None]:
import os
import torch
import numpy as np
from PIL import Image
import pydicom
from torch.utils.data import Dataset
import tensorflow as tf

class DICOMDataset(Dataset):
    """Dataset of DICOM files whose class is the first character of the filename.

    Every entry directly inside ``root_dir`` is one sample. Its category is the
    first character of its filename; ``__getitem__`` maps the categories
    ``'A'``, ``'B'``, ``'E'``, ``'G'`` to the integer ids 1-4.

    Args:
        root_dir: directory containing the DICOM files.
        transform: optional callable applied to the preprocessed 224x224
            float32 image before it is returned.
    """

    # Integer id returned by __getitem__ for each filename prefix.
    LABEL_TO_ID = {'A': 1, 'B': 2, 'E': 3, 'G': 4}

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # List the directory once, sorted: indices are reproducible across runs
        # and dcm_files / image_list are guaranteed to be aligned index-for-index.
        self.dcm_files = sorted(os.listdir(root_dir))
        self.categories = ['A', 'B', 'G', 'E']
        # (filename, category) pairs, in the same order as dcm_files.
        self.image_list = [(name, name[0]) for name in self.dcm_files]

    def __len__(self):
        return len(self.dcm_files)

    def __getitem__(self, idx):
        """Load, preprocess, lung-mask and resize sample ``idx``.

        Returns:
            ``(image, label_id)`` where ``image`` is the 224x224 float32 array
            (or whatever ``self.transform`` returns for it) and ``label_id`` is
            the integer from ``LABEL_TO_ID``.

        Raises:
            ValueError: if the filename prefix is not a known category.
        """
        dcm_file = self.dcm_files[idx]
        label = dcm_file[0]
        # Fail fast (before any I/O) instead of hitting an unbound label later.
        if label not in self.LABEL_TO_ID:
            raise ValueError(
                f"Unknown class prefix {label!r} in file name {dcm_file!r}; "
                f"expected one of {sorted(self.LABEL_TO_ID)}"
            )

        dcm_path = os.path.join(self.root_dir, dcm_file)
        dicom_image = pydicom.dcmread(dcm_path)
        img = np.array(dicom_image.pixel_array)

        cleaned_image = preprocess_images(img, dicom_image)

        lung_mask = get_mask(dcm_path, plot_mask=True, return_val=True)
        # lungmask's R231 model labels the two lungs with different non-zero ids
        # (1 and 2); binarize so the mask keeps HU values instead of scaling them.
        mask_on_original = cleaned_image * (lung_mask > 0)
        resized_img = cv2.resize(mask_on_original, (224, 224))

        # NOTE(review): HU values are typically far outside [0, 255], so this does
        # not scale into [0, 1]; kept as-is for compatibility — revisit.
        normalized_img = resized_img.astype('float32') / 255

        if self.transform is not None:
            normalized_img = self.transform(normalized_img)

        return normalized_img, self.LABEL_TO_ID[label]

    def get_labels(self):
        """Return each sample's category (filename prefix), in index order.

        A few-shot task sampler groups samples by these values, so they must be
        the shared category rather than the unique filename.
        """
        return [category for _name, category in self.image_list]


In [None]:
import torch.nn as nn
import torchvision.models as models

class VGG16(nn.Module):
    """ImageNet-pretrained torchvision VGG-16 adapted to single-channel input.

    Changes relative to ``torchvision.models.vgg16``:
      * the first convolution takes 1 input channel instead of 3;
      * the last fully connected layer of the classifier is removed, so
        ``forward`` returns the output of the remaining classifier layers.
    """

    def __init__(self):
        super(VGG16, self).__init__()
        # Build (and load weights for) the pretrained network only once.
        backbone = models.vgg16(pretrained=True)
        self.features = backbone.features

        # Replace the RGB stem with a 1-channel conv. Initialising it with the sum of
        # the pretrained RGB kernels reproduces the pretrained response to a gray
        # image replicated over 3 channels, instead of starting from random weights.
        rgb_conv = self.features[0]
        gray_conv = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        with torch.no_grad():
            gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
            gray_conv.bias.copy_(rgb_conv.bias)
        self.features[0] = gray_conv

        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Remove the last fully connected layer (classifier output).
        self.classifier = nn.Sequential(*list(backbone.classifier.children())[:-1])

    def forward(self, x):
        """Extract features for a batch ``x`` of single-channel images.

        Assumes ``x`` is ``(batch, 1, H, W)`` — TODO confirm against callers.
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [None]:
from torchvision.transforms import transforms

# Few-shot task configuration.
N_WAY = 4  # classes per task
N_SHOT = 2  # support images per class
N_QUERY = 1  # query images per class
N_EVALUATION_TASKS = 10  # number of tasks the sampler draws

dataset = DICOMDataset(root_dir='../IMAGES/DICOM_SUPPORT/', transform=None)

# The task sampler requires a dataset exposing get_labels(); report each sample's
# category (the second element of every (filename, category) pair in image_list).
dataset.get_labels = lambda: [category for _name, category in dataset.image_list]

In [None]:
test_sampler = TaskSampler(
    dataset, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_EVALUATION_TASKS
)


def episodic_collate(batch):
    """Collate one sampled task into easyfsl's support/query episode.

    easyfsl's ``episodic_collate_fn`` expects tensor images, while DICOMDataset
    yields 2-D numpy arrays when its ``transform`` is None; those are converted
    to tensors with a leading channel axis first. Tensor images pass through.
    """
    tensor_batch = [
        (torch.as_tensor(image).unsqueeze(0) if isinstance(image, np.ndarray) else image, label)
        for image, label in batch
    ]
    return test_sampler.episodic_collate_fn(tensor_batch)


test_loader = DataLoader(
    dataset,
    batch_sampler=test_sampler,
    # Without the episodic collate the default collate just stacks the task's
    # images, losing the support/query split produced by the sampler.
    collate_fn=episodic_collate,
    num_workers=0,
    # Pinned memory only helps when batches are copied to a CUDA device.
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Each iteration yields one task drawn by test_sampler.
for task_batch in test_loader:
    print(task_batch)