In [1]:
import os
import random
import torch
import numpy as np
import torchvision.transforms as transforms
import pydicom
from scipy.ndimage.filters import median_filter
from lungmask import mask
import SimpleITK as sitk
import cv2
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader


from PIL import Image
from torch.utils.data import Dataset
import tensorflow as tf

  warn(
  from scipy.ndimage.filters import median_filter


In [2]:
def transform_to_hu(medical_image, image):
    intercept = medical_image.RescaleIntercept
    slope = medical_image.RescaleSlope
    hu_image = image * slope + intercept

    return hu_image

def preprocess_images(img,dicom_image):

    hu_image = transform_to_hu(dicom_image, img)

    # medianl filter for noise reduction 
    # Apply the median filter with a kernel size of 3x3
    filtered_image = median_filter(hu_image, size=(3, 3))

In [3]:
import torch.nn as nn
import torchvision.models as models


class VGG16Gray(nn.Module):
    def __init__(self, pretrained=True):
        super(VGG16Gray, self).__init__()
        vgg16 = models.vgg16(pretrained=pretrained)
        # Remove the last fully connected layer of VGG16
        self.features = nn.Sequential(*list(vgg16.features.children())[:-1])
        # Replace the input convolution layer to accept single channel grayscale images
        self.features[0] = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return x


In [4]:
class DICOMDataset(Dataset):

    def __init__(self, root_dir, transform=None):

        self.img_labels = []
        self.root_dir = root_dir
        self.dcm_files = os.listdir(root_dir)

        for filename in os.listdir(root_dir):
            
            image_name = filename
            category = image_name[0]

            if category =='A' :
               label=1
            elif category =='B':
               label=2
            elif category =='G':  
               label=3 
            elif category =='E':
               label=4 
            else : label=5        
            # print(label)
            self.img_labels.append((image_name,label))

    def __len__(self):
        # print(len(self.img_labels))
        return len(self.img_labels)

    def __getitem__(self, idx):

        dcm_file = self.dcm_files[idx]
        # print(dcm_file)
        label_ch =dcm_file[0]
        if label_ch =='A' :
            label=1
        elif label_ch =='B':
            label=2
        elif label_ch =='G':  
            label=3 
        elif label_ch =='E':
            label=4  
        else : label=5
        dcm_path = os.path.join(self.root_dir, dcm_file)
        
        dicom_image= pydicom.dcmread(dcm_path)
    
        image = np.array(dicom_image.pixel_array)
        image = cv2.resize(image, (224, 224))
       
        image = image.astype('float32')
        image = np.expand_dims(image, axis=0)  # Add a channel to the image
        
        
        image = torch.from_numpy(image)

        return image, label

In [5]:
class PrototypicalNetworks(nn.Module):
    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks, self).__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.
        """
        # Extract the features of support and query images
        z_support = self.backbone.forward(support_images)
        print(z_support.shape)
        z_query = self.backbone.forward(query_images)
        print(z_query.shape)
        
        # Infer the number of different classes from the labels of the support set
        n_way = len(torch.unique(support_labels))
        # Prototype i is the mean of all instances of features corresponding to labels == i
        z_proto = torch.cat(
            [
                z_support[torch.nonzero(support_labels == label)].mean(0)
                for label in range(n_way)
            ]
        )

        print(z_proto.shape)

        # Compute the euclidean distance from queries to prototypes
        dists = torch.cdist(z_query, z_proto)
   

        # And here is the super complicated operation to transform those distances into classification scores!
        scores = -dists
        
        return scores

In [6]:
convolutional_network = VGG16Gray(pretrained=True)
model = PrototypicalNetworks(convolutional_network)
model.eval()


PrototypicalNetworks(
  (backbone): VGG16Gray(
    (features): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inpl

In [9]:
%run TaskSampler.ipynb

dataset = DICOMDataset(root_dir='../IMAGES/DICOM_SUPPORT/', transform=None)


N_WAY = 4  # Number of classes in a task
N_SHOT = 2  # Number of images per class in the support set
N_QUERY = 2  # Number of images per class in the query set
N_EVALUATION_TASKS = 1

# The sampler needs a dataset with a "get_labels" method. Check the code if you have any doubt!
dataset.get_labels = lambda: [
    instance[1] for instance in dataset.img_labels
]

test_sampler = TaskSampler(
    dataset, n_way=4, n_shot=2, n_query=2, n_tasks=2
)

test_loader = DataLoader(
    dataset,
    batch_sampler=test_sampler,
    num_workers=0,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

In [11]:
(
    example_support_images,
    example_support_labels,
    example_query_images,
    example_query_labels,
    example_class_ids,
) = next(iter(test_loader))

print(len(example_support_images))

8


In [14]:
model.eval()
example_scores = model(
    example_support_images,
    example_support_labels,
    example_query_images,
).detach()



_, example_predicted_labels = torch.max(example_scores.data, 1)

print("Ground Truth / Predicted")
for i in range(len(example_query_labels)):
    print(
        f"{[example_class_ids[example_query_labels[i]]]} / {[example_class_ids[example_predicted_labels[i]]]}"
    )


torch.Size([8, 100352])
torch.Size([8, 100352])
torch.Size([4, 100352])
Ground Truth / Predicted
[1] / [1]
[1] / [1]
[3] / [1]
[3] / [3]
[4] / [1]
[4] / [4]
[2] / [2]
[2] / [2]


In [18]:
from tqdm import tqdm

def evaluate_on_one_task(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> [int, int]:
    """
    Returns the number of correct predictions of query labels, and the total number of predictions.
    """
    return (
        torch.max(
            model(support_images, support_labels, query_images)
            .detach()
            .data,
            1,
        )[1]
        == query_labels
    ).sum().item(), len(query_labels)


def evaluate(data_loader: DataLoader):
    # We'll count everything and compute the ratio at the end
    total_predictions = 0
    correct_predictions = 0

    # eval mode affects the behaviour of some layers (such as batch normalization or dropout)
    # no_grad() tells torch not to keep in memory the whole computational graph (it's more lightweight this way)
    model.eval()
    with torch.no_grad():
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            class_ids,
        ) in tqdm(enumerate(data_loader), total=len(data_loader)):

            correct, total = evaluate_on_one_task(
                support_images, support_labels, query_images, query_labels
            )

            total_predictions += total
            correct_predictions += correct

    print(
        f"Model tested on {len(data_loader)} tasks. Accuracy: {(100 * correct_predictions/total_predictions):.2f}%"
    )


evaluate(test_loader)

  0%|          | 0/2 [00:00<?, ?it/s]

torch.Size([8, 100352])


 50%|█████     | 1/2 [00:02<00:02,  2.07s/it]

torch.Size([8, 100352])
torch.Size([4, 100352])
torch.Size([8, 100352])


100%|██████████| 2/2 [00:04<00:00,  2.06s/it]

torch.Size([8, 100352])
torch.Size([4, 100352])
Model tested on 2 tasks. Accuracy: 68.75%



