In [1]:
import os
import random
import torch
import numpy as np
import torchvision.transforms as transforms
import pydicom
from scipy.ndimage.filters import median_filter
from lungmask import mask
import SimpleITK as sitk
import cv2
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from PIL import Image
from torch.utils.data import Dataset
import tensorflow as tf
from torch import nn, optim
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

  from scipy.ndimage.filters import median_filter


In [2]:
def transform_to_hu(medical_image, image):
    intercept = medical_image.RescaleIntercept
    slope = medical_image.RescaleSlope
    hu_image = image * slope + intercept
    return hu_image

def get_mask(filename, plot_mask=False, return_val=False):

    input_image = sitk.ReadImage(filename)
    mask_out = mask.apply(input_image)[0]  #default model is U-net(R231)

    if return_val:
        return mask_out

def preprocess_images(img,dicom_image):

    hu_image = transform_to_hu(dicom_image, img)
    filtered_image = median_filter(hu_image, size=(3, 3))
    return filtered_image

In [3]:

import torch.nn as nn
import torchvision.models as models

class ResNet50Gray(nn.Module):
    def __init__(self, pretrained=True):
        super(ResNet50Gray, self).__init__()
        resnet50 = models.resnet50(pretrained=pretrained)
        # Replace the input convolution layer to accept single channel grayscale images
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.conv1.weight.data = resnet50.conv1.weight.data[:,0,:,:].unsqueeze(1)
        self.bn1 = resnet50.bn1
        self.relu = resnet50.relu
        self.maxpool = resnet50.maxpool
        self.layer1 = resnet50.layer1
        self.layer2 = resnet50.layer2
        self.layer3 = resnet50.layer3
        self.layer4 = resnet50.layer4
        self.avgpool = resnet50.avgpool
        self.fc = resnet50.fc

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


In [4]:
class DICOMDataset(Dataset):

    def __init__(self, root_dir, transform=None):

        self.img_labels = []
        self.root_dir = root_dir
        self.dcm_files = os.listdir(root_dir)

        for filename in os.listdir(root_dir):
            
            image_name = filename
            category = image_name[0]

            if category =='A' :
               label=1
            elif category =='B':
               label=2
            elif category =='G':  
               label=3 
            elif category =='E':
               label=4 
            else : label=5        
          
            self.img_labels.append((image_name,label))

    def __len__(self):
        # print(len(self.img_labels))
        return len(self.img_labels)

    def __getitem__(self, idx):

        dcm_file = self.dcm_files[idx]
   
        label_ch =dcm_file[0]
        if label_ch =='A' :
            label=1
        elif label_ch =='B':
            label=2
        elif label_ch =='G':  
            label=3 
        elif label_ch =='E':
            label=4  
        else : label=5
        dcm_path = os.path.join(self.root_dir, dcm_file)
        
        dicom_image= pydicom.dcmread(dcm_path)
        image = np.array(dicom_image.pixel_array)
        print(image.shape)

        cleaned_image = preprocess_images(image,dicom_image)
        masked_img=get_mask(dcm_path,plot_mask=True,return_val=True)
    
        mask_on_orginal = cleaned_image * masked_img
        mask_on_orginal = cv2.resize(mask_on_orginal, (224, 224))
       
        image = mask_on_orginal.astype('float32')
        image = np.expand_dims(image, axis=0)
       
        image = torch.from_numpy(image)

        return image, label

In [5]:
class PrototypicalNetworks(nn.Module):
    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks, self).__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.
        """
        # Extract the features of support and query images
        z_support = self.backbone.forward(support_images)
        
        z_query = self.backbone.forward(query_images)
        
        
        # Infer the number of different classes from the labels of the support set
        n_way = len(torch.unique(support_labels))
        # Prototype i is the mean of all instances of features corresponding to labels == i
        z_proto = torch.cat(
            [
                z_support[torch.nonzero(support_labels == label)].mean(0)
                for label in range(n_way)
            ]
        )

    
        # Compute the euclidean distance from queries to prototypes
        dists = torch.cdist(z_query, z_proto)
   

        # And here is the super complicated operation to transform those distances into classification scores!
        scores = -dists
        
        return scores

In [6]:
convolutional_network = ResNet50Gray(pretrained=True)
model = PrototypicalNetworks(convolutional_network)
model.eval()


PrototypicalNetworks(
  (backbone): ResNet50Gray(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Se

In [7]:
%run TaskSampler.ipynb
test_set = DICOMDataset(root_dir='../IMAGES/TEST_SET/', transform=None)


N_WAY = 4  # Number of classes in a task
N_SHOT = 5 # Number of images per class in the support set
N_QUERY = 3 # Number of images per class in the query set
N_EVALUATION_TASKS = 1

# The sampler needs a dataset with a "get_labels" method. Check the code if you have any doubt!
test_set.get_labels = lambda: [
    instance[1] for instance in test_set.img_labels
]

test_sampler = TaskSampler(
    test_set, n_way=N_WAY , n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_EVALUATION_TASKS
)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=0,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

In [8]:
(
    example_support_images,
    example_support_labels,
    example_query_images,
    example_query_labels,
    example_class_ids,
) = next(iter(test_loader))

(512, 512)


100%|██████████| 1/1 [00:02<00:00,  2.45s/it]
100%|██████████| 5/5 [00:00<00:00, 2016.69it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 38.46it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 37.90it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 35.78it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 38.34it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 41.76it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 38.23it/s]
100%|██████████| 10/10 [00:00<00:00, 10034.22it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 38.49it/s]
100%|██████████| 3/3 [00:00<00:00, 3010.27it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 23.53it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.64it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 23.99it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 22.90it/s]
100%|██████████| 2/2 [00:00<00:00, 2005.40it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.28it/s]
100%|██████████| 4/4 [00:00<00:00, 2006.84it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  9.01it/s]
100%|██████████| 3/3 [00:00<00:00, 3010.27it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 32.09it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.49it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.84it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 30.20it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.12it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 11.64it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.46it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.07it/s]
100%|██████████| 5/5 [00:00<00:00, 5018.31it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.00it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.84it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.83it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 23.64it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.35it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.12it/s]
100%|██████████| 4/4 [00:00<00:00, 4014.65it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.67it/s]
100%|██████████| 4/4 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.77it/s]
100%|██████████| 3/3 [00:00<00:00, 3009.55it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.50it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.35it/s]
100%|██████████| 3/3 [00:00<00:00, 3010.27it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.90it/s]
100%|██████████| 3/3 [00:00<00:00, 3011.71it/s]


In [9]:
model.eval()
example_scores = model(
    example_support_images,
    example_support_labels,
    example_query_images,
).detach()


_, example_predicted_labels = torch.max(example_scores.data, 1)


Training a meta-learning algorithm

In [10]:
N_TRAINING_EPISODES = 3

train_set = DICOMDataset(root_dir='../IMAGES/TRAIN_SET/', transform=None)
train_set.get_labels = lambda: [ instance[1] for instance in train_set.img_labels]

train_sampler = TaskSampler(
    train_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_TRAINING_EPISODES
)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=0,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)

In [11]:
def sliding_average(lst, window_size):
    if window_size == 0:
        return 0.0
    return sum(lst[-window_size:]) / min(len(lst), window_size)

In [12]:
from tqdm import tqdm
import matplotlib.pyplot as plt

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def fit(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> float:
    optimizer.zero_grad()
    classification_scores = model(
        support_images, support_labels, query_images
    )

    loss = criterion(classification_scores, query_labels)
    loss.backward()
    optimizer.step()

    return loss.item()

In [13]:
def evaluate_on_one_task(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> [int, int]:
    return (
         torch.max(  
            model(support_images, support_labels, query_images).detach().data,1,)[1]
         )


def pe_evaluate(data_loader: DataLoader):

    model.eval()
    with torch.no_grad():
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            class_ids,
        ) in tqdm(enumerate(data_loader), total=len(data_loader)):

            predicted_labels =evaluate_on_one_task(support_images, support_labels, query_images, query_labels)
            actual_labels  = query_labels

            actual_labels_np = actual_labels.cpu().numpy()
            predicted_labels_np = predicted_labels.cpu().numpy()

  
            precision = precision_score(actual_labels_np, predicted_labels_np, average='macro')
            recall = recall_score(actual_labels, predicted_labels, average='macro')
            f1_score_macro = f1_score(actual_labels, predicted_labels, average='macro')
            
            # Calculate accuracy
            accuracy = accuracy_score(actual_labels, predicted_labels)

      
            print("Precision (Macro):", precision)
            print("Recall (Macro):", recall)
            print("F1 Score (Macro):", f1_score_macro)
            print("Accuracy:", accuracy)



In [14]:


log_update_frequency = 11

all_loss = []
model.train()
with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
    
    for episode_index, (
        support_images,
        support_labels,
        query_images,
        query_labels,
        _,
    ) in tqdm_train:
        
        loss_value = fit(support_images, support_labels, query_images, query_labels)
        
        all_loss.append(loss_value)
      
        if episode_index % log_update_frequency == 0:
            tqdm_train.set_postfix(loss=sliding_average(all_loss, log_update_frequency))

  0%|          | 0/3 [00:00<?, ?it/s]

(512, 512)


100%|██████████| 1/1 [00:00<00:00,  6.67it/s]
100%|██████████| 4/4 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 39.97it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 39.29it/s]
100%|██████████| 4/4 [00:00<00:00, 2670.26it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 40.09it/s]
100%|██████████| 3/3 [00:00<00:00, 3010.99it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.92it/s]
100%|██████████| 2/2 [00:00<00:00, 3992.67it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 11.62it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.36it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.35it/s]
100%|██████████| 3/3 [00:00<00:00, 3000.93it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 30.36it/s]
100%|██████████| 4/4 [00:00<00:00, 4013.69it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.25it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.63it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.54it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 23.58it/s]
100%|██████████| 5/5 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 23.28it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 23.53it/s]
100%|██████████| 5/5 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.28it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.07it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 31.79it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 32.50it/s]
100%|██████████| 5/5 [00:00<00:00, 4984.91it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 30.80it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 31.73it/s]
100%|██████████| 2/2 [00:00<00:00, 2016.49it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.16it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.96it/s]
100%|██████████| 10/10 [00:00<00:00, 10036.62it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.33it/s]
100%|██████████| 2/2 [00:00<00:00, 2004.93it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.61it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.71it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.04it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.54it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 31.77it/s]
100%|██████████| 3/3 [00:00<00:00, 3010.27it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 30.08it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 11.31it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.36it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.82it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.01it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]
 33%|███▎      | 1/3 [00:23<00:47, 23.69s/it, loss=0.34]

(512, 512)


100%|██████████| 1/1 [00:00<00:00,  5.85it/s]
100%|██████████| 5/5 [00:00<00:00, 9972.19it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.60it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.60it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.98it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.87it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 11.63it/s]
100%|██████████| 3/3 [00:00<00:00, 3009.55it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.19it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.00it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.84it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.17it/s]
100%|██████████| 4/4 [00:00<00:00, 4006.02it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  9.52it/s]
100%|██████████| 4/4 [00:00<00:00, 4014.65it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.05it/s]
100%|██████████| 2/2 [00:00<00:00, 2005.88it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.69it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.19it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.09it/s]
100%|██████████| 2/2 [00:00<00:00, 2005.88it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.09it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 30.35it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.73it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 30.75it/s]
100%|██████████| 5/5 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 30.65it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.81it/s]
100%|██████████| 5/5 [00:00<00:00, 5017.11it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.14it/s]
100%|██████████| 2/2 [00:00<00:00, 2003.97it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.72it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.52it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.84it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.32it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.19it/s]
100%|██████████| 2/2 [00:00<00:00, 2003.97it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.53it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 30.05it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.98it/s]
100%|██████████| 2/2 [00:00<00:00, 3979.42it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.89it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  9.90it/s]
100%|██████████| 3/3 [00:00<00:00, 1998.56it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 31.28it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 30.77it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]
 67%|██████▋   | 2/3 [00:46<00:23, 23.42s/it, loss=0.34]

(512, 512)


100%|██████████| 1/1 [00:00<00:00,  3.78it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 37.74it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 39.31it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 41.55it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 36.31it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.37it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.64it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.52it/s]
100%|██████████| 3/3 [00:00<00:00, 3010.99it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.57it/s]
100%|██████████| 3/3 [00:00<00:00, 2985.98it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.89it/s]
100%|██████████| 2/2 [00:00<00:00, 1985.00it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.22it/s]
100%|██████████| 4/4 [00:00<00:00, 3972.82it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  9.32it/s]
100%|██████████| 4/4 [00:00<00:00, 4197.45it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.66it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.48it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.84it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.48it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.47it/s]
100%|██████████| 4/4 [00:00<00:00, 4013.69it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 11.07it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 23.67it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 22.96it/s]
100%|██████████| 2/2 [00:00<00:00, 2008.29it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 23.85it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 23.18it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  9.76it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.12it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.97it/s]
100%|██████████| 10/10 [00:00<00:00, 6658.68it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.04it/s]
100%|██████████| 5/5 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 17.70it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.50it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.63it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.73it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.53it/s]
100%|██████████| 2/2 [00:00<00:00, 2007.32it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.00it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.67it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]
100%|██████████| 3/3 [01:10<00:00, 23.40s/it, loss=0.34]


In [15]:
pe_evaluate(test_loader) 

  0%|          | 0/1 [00:00<?, ?it/s]

(512, 512)


100%|██████████| 1/1 [00:00<00:00,  5.59it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.75it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.19it/s]
100%|██████████| 2/2 [00:00<00:00, 2007.32it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.66it/s]
100%|██████████| 2/2 [00:00<00:00, 1929.75it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.42it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.50it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.69it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.22it/s]
100%|██████████| 5/5 [00:00<00:00, 5014.71it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.08it/s]
100%|██████████| 5/5 [00:00<00:00, 5021.92it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.36it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.68it/s]
100%|██████████| 2/2 [00:00<00:00, 2011.66it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.07it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.47it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.36it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.48it/s]
100%|██████████| 2/2 [00:00<00:00, 1987.82it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.28it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.13it/s]
100%|██████████| 10/10 [00:00<00:00, 10024.63it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.34it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 32.36it/s]
100%|██████████| 2/2 [00:00<00:00, 2003.01it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 32.24it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.69it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.33it/s]
100%|██████████| 3/3 [00:00<00:00, 1996.65it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.93it/s]
100%|██████████| 3/3 [00:00<00:00, 3012.43it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.46it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.31it/s]
100%|██████████| 4/4 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.34it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.33it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.65it/s]
100%|██████████| 2/2 [00:00<00:00, 1411.75it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.91it/s]
100%|██████████| 4/4 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 17.13it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  9.87it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 32.14it/s]
100%|██████████| 3/3 [00:00<00:00, 2813.08it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.91it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:18<00:00, 18.54s/it]

Precision (Macro): 1.0
Recall (Macro): 1.0
F1 Score (Macro): 1.0
Accuracy: 1.0



