In [27]:
import os
import random
import torch
import numpy as np
import torchvision.transforms as transforms
import pydicom
from scipy.ndimage.filters import median_filter
from lungmask import mask
import SimpleITK as sitk
import cv2
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from PIL import Image
from torch.utils.data import Dataset
import tensorflow as tf
from torch import nn, optim
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from CustomTaskSampler import TaskSampler

from preprocessing import transform_to_hu,get_mask,preprocess_images
from DICOMDataset import DICOMDataset
from CNN import CNN
from tqdm import tqdm

  from scipy.ndimage.filters import median_filter


In [28]:
# Root folder containing the TEST_SET and TRAIN_SET image folders. Set the
# MSC_IMAGES_DIR environment variable to point elsewhere instead of editing this
# machine-specific default path.
IMAGES_DIR = os.environ.get("MSC_IMAGES_DIR", "C:/Users/Nimesha/Documents/MSC_RESEARCH/IMAGES")

test_set = DICOMDataset(root_dir=f"{IMAGES_DIR}/TEST_SET", transform=None)
train_set = DICOMDataset(root_dir=f"{IMAGES_DIR}/TRAIN_SET/", transform=None)

In [29]:
class PrototypicalNetworks(nn.Module):
    """Prototypical-network classifier head wrapped around a feature extractor.

    Each class prototype is the mean embedding of that class's support images;
    every query is scored by its negative Euclidean distance to each prototype.
    """

    def __init__(self, backbone: nn.Module):
        """
        Args:
            backbone: module mapping a batch of images to a batch of feature vectors.
        """
        super().__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """Predict query labels using labeled support images.

        Args:
            support_images: batch of labeled support images.
            support_labels: 1-D integer labels of the support images. They must be
                exactly the contiguous range 0..n_way-1, because prototype i is
                built for label i.
            query_images: batch of images to classify.

        Returns:
            Scores of shape (n_query, n_way): the negative Euclidean distance from
            each query embedding to each class prototype. Column i is label i.

        Raises:
            ValueError: if support_labels are not exactly 0..n_way-1.
        """
        # Call the module (not .forward) so hooks registered on the backbone run.
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)

        # Infer n_way from the support labels and require them to be 0..n_way-1;
        # a gap would otherwise average an empty selection into a NaN prototype.
        unique_labels = torch.unique(support_labels)
        n_way = len(unique_labels)
        expected_labels = torch.arange(
            n_way, device=unique_labels.device, dtype=unique_labels.dtype
        )
        if not torch.equal(unique_labels, expected_labels):
            raise ValueError(
                "support_labels must be the contiguous range 0..n_way-1, "
                f"got unique labels {unique_labels.tolist()}"
            )

        # Prototype i is the mean feature vector of the support samples labelled i.
        z_proto = torch.stack(
            [z_support[support_labels == label].mean(dim=0) for label in range(n_way)]
        )

        # Negative distance: the closest prototype gets the highest score.
        return -torch.cdist(z_query, z_proto)


In [30]:
# Feature extractor whose outputs serve as the embeddings for the prototypes.
# NOTE(review): per the printed summary, the backbone ends in fc2 with 10 output
# features, so the embedding is a 10-dim classifier-style output — confirm intended.
convolutional_network = CNN()
model = PrototypicalNetworks(backbone=convolutional_network)
print(model)

PrototypicalNetworks(
  (backbone): CNN(
    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (fc1): Linear(in_features=100352, out_features=256, bias=True)
    (fc2): Linear(in_features=256, out_features=10, bias=True)
  )
)


In [31]:
# Episode shape shared by the evaluation and training samplers.
N_WAY = 4  # classes per task
N_SHOT = 4  # support images per class
N_QUERY = 2  # query images per class
N_EVALUATION_TASKS = 1


def _test_set_labels():
    """Label of every test image; the sampler expects the dataset to expose get_labels()."""
    return [entry[1] for entry in test_set.img_labels]


test_set.get_labels = _test_set_labels

test_sampler = TaskSampler(
    test_set,
    n_way=N_WAY,
    n_shot=N_SHOT,
    n_query=N_QUERY,
    n_tasks=N_EVALUATION_TASKS,
)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=0,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

In [32]:
# Draw a single episode from the test loader for the example prediction below.
example_episode = next(iter(test_loader))
(
    example_support_images, example_support_labels,
    example_query_images, example_query_labels,
    example_class_ids,
) = example_episode

(512, 512)


100%|██████████| 1/1 [00:00<00:00,  7.71it/s]
100%|██████████| 3/3 [00:00<00:00, 3012.43it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 40.94it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 42.65it/s]
100%|██████████| 4/4 [00:00<00:00, 4012.73it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 40.94it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 39.33it/s]
100%|██████████| 3/3 [00:00<00:00, 3009.55it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.02it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.91it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.84it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 23.88it/s]
100%|██████████| 2/2 [00:00<00:00, 2008.29it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.77it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.77it/s]
100%|██████████| 5/5 [00:00<00:00, 5015.91it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.74it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 11.02it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.61it/s]
100%|██████████| 2/2 [00:00<00:00, 2005.88it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.86it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.26it/s]
100%|██████████| 5/5 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.25it/s]
100%|██████████| 2/2 [00:00<00:00, 2005.88it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.26it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.44it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.45it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.06it/s]
100%|██████████| 2/2 [00:00<00:00, 2009.73it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 23.55it/s]
100%|██████████| 3/3 [00:00<00:00, 3008.83it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.72it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 14.64it/s]
100%|██████████| 2/2 [00:00<00:00, 2005.88it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 11.27it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


In [33]:
model.eval()

print(type(example_support_images))
print(len(example_support_images))

print(type(example_support_images[0]))

# Inference only: no_grad avoids building an autograd graph for this forward pass
# (.detach() would only discard the graph after it had already been built).
with torch.no_grad():
    example_scores = model(
        example_support_images,
        example_support_labels,
        example_query_images,
    )

# The highest score (nearest prototype) is the predicted label of each query.
example_predicted_labels = example_scores.argmax(dim=1)

print(example_query_labels)
print(example_predicted_labels)

<class 'torch.Tensor'>
16
<class 'torch.Tensor'>
tensor([3, 3, 2, 2, 1, 1, 0, 0])
tensor([3, 0, 2, 2, 0, 0, 0, 3])


In [34]:
def evaluate_on_one_task(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
    network: "nn.Module | None" = None,
) -> torch.Tensor:
    """Predict the label of every query image of one episode.

    Args:
        support_images: labeled support images of the episode.
        support_labels: labels of the support images.
        query_images: images to classify.
        query_labels: true query labels; unused here, kept so callers can pass a
            whole episode positionally.
        network: model to evaluate; defaults to the notebook's global ``model``.

    Returns:
        1-D tensor with the predicted label (argmax of the scores) of each query.
    """
    if network is None:
        network = model
    with torch.no_grad():
        scores = network(support_images, support_labels, query_images)
    # argmax returns the first maximal index, like torch.max(scores, 1)[1].
    return scores.argmax(dim=1)


def evaluate(data_loader: DataLoader, verbose: bool = True) -> dict:
    """Evaluate the global ``model`` on every episode of ``data_loader``.

    Macro precision, recall, F1 and accuracy are computed per episode and then
    averaged over episodes. The result covers the whole loader rather than only
    the last episode. Labels are presumably episode-local indices (0..n_way-1),
    so they are not pooled across episodes.

    Args:
        data_loader: episodic loader yielding
            (support_images, support_labels, query_images, query_labels, class_ids).
        verbose: if True, print the true and predicted query labels of each episode.

    Returns:
        dict with the episode-averaged "precision", "recall", "f1" and "accuracy",
        or an empty dict if the loader yielded no episodes.
    """
    model.eval()
    per_episode = {"precision": [], "recall": [], "f1": [], "accuracy": []}

    with torch.no_grad():
        for support_images, support_labels, query_images, query_labels, _ in tqdm(
            data_loader, total=len(data_loader)
        ):
            predicted_labels = evaluate_on_one_task(
                support_images, support_labels, query_images, query_labels
            )
            actual_labels_np = query_labels.cpu().numpy()
            predicted_labels_np = predicted_labels.cpu().numpy()

            # zero_division=0: a class that is never predicted scores 0 (the same
            # value sklearn uses by default) without an UndefinedMetricWarning.
            per_episode["precision"].append(
                precision_score(actual_labels_np, predicted_labels_np, average="macro", zero_division=0)
            )
            per_episode["recall"].append(
                recall_score(actual_labels_np, predicted_labels_np, average="macro", zero_division=0)
            )
            per_episode["f1"].append(
                f1_score(actual_labels_np, predicted_labels_np, average="macro", zero_division=0)
            )
            per_episode["accuracy"].append(accuracy_score(actual_labels_np, predicted_labels_np))

            if verbose:
                print(actual_labels_np)
                print(predicted_labels_np)

    if not per_episode["accuracy"]:
        return {}

    metrics = {name: float(np.mean(values)) for name, values in per_episode.items()}
    print("Precision (Macro):", metrics["precision"])
    print("Recall (Macro):", metrics["recall"])
    print("F1 Score (Macro):", metrics["f1"])
    print("Accuracy:", metrics["accuracy"])
    return metrics

Training a meta-learning algorithm


In [35]:
N_TRAINING_EPISODES = 6


def _train_set_labels():
    """Label of every training image; the sampler expects the dataset to expose get_labels()."""
    return [entry[1] for entry in train_set.img_labels]


train_set.get_labels = _train_set_labels

train_sampler = TaskSampler(
    train_set,
    n_way=N_WAY,
    n_shot=N_SHOT,
    n_query=N_QUERY,
    n_tasks=N_TRAINING_EPISODES,
)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=0,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)

In [36]:
def sliding_average(lst, window_size):
    """Mean of the last ``window_size`` values of ``lst``.

    If ``lst`` holds fewer than ``window_size`` items, all of them are averaged.

    Args:
        lst: sequence of numbers (e.g. per-episode losses).
        window_size: number of trailing values to average; must be >= 0.

    Returns:
        The average as a float, or 0.0 when ``window_size`` is 0 or ``lst`` is empty.

    Raises:
        ValueError: if ``window_size`` is negative.
    """
    if window_size < 0:
        raise ValueError(f"window_size must be non-negative, got {window_size}")
    # An empty list previously divided by min(0, window_size) == 0.
    if window_size == 0 or len(lst) == 0:
        return 0.0
    window = lst[-window_size:]
    return sum(window) / len(window)

In [37]:
from tqdm import tqdm
import matplotlib.pyplot as plt

# Cross-entropy over the model's scores, optimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


def fit(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> float:
    """Run one episodic training step on the global ``model``.

    Scores the query images using the support set, back-propagates the
    cross-entropy loss against ``query_labels`` and takes one optimizer step.

    Returns:
        The scalar loss of this episode as a Python float.
    """
    optimizer.zero_grad()
    episode_loss = criterion(
        model(support_images, support_labels, query_images),
        query_labels,
    )
    episode_loss.backward()
    optimizer.step()
    return episode_loss.item()

In [38]:
# The progress-bar loss is refreshed every `log_update_frequency` episodes (and on
# the last episode). The value shown is the mean of the last `log_update_frequency` losses.
log_update_frequency = 10

all_loss = []
model.train()
n_episodes = len(train_loader)
with tqdm(enumerate(train_loader), total=n_episodes) as tqdm_train:
    for episode_index, (
        support_images,
        support_labels,
        query_images,
        query_labels,
        _,
    ) in tqdm_train:
        loss_value = fit(support_images, support_labels, query_images, query_labels)
        all_loss.append(loss_value)

        # Without the last-episode check, a run shorter than log_update_frequency
        # episodes only ever displays the loss of episode 0.
        is_last_episode = episode_index == n_episodes - 1
        if episode_index % log_update_frequency == 0 or is_last_episode:
            tqdm_train.set_postfix(loss=sliding_average(all_loss, log_update_frequency))

  0%|          | 0/6 [00:00<?, ?it/s]

(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.65it/s]
100%|██████████| 3/3 [00:00<00:00, 3005.95it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.10it/s]
100%|██████████| 2/2 [00:00<00:00, 2004.93it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.50it/s]
100%|██████████| 4/4 [00:00<00:00, 4014.65it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.94it/s]
100%|██████████| 3/3 [00:00<00:00, 3010.27it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 30.39it/s]
100%|██████████| 4/4 [00:00<00:00, 4021.38it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.01it/s]
100%|██████████| 3/3 [00:00<00:00, 3010.99it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.08it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.46it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.72it/s]
100%|██████████| 5/5 [00:00<00:00, 5018.31it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 10.90it/s]
100%|██████████| 2/2 [00:00<00:00, 2007.32it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.46it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.86it/s]
100%|██████████| 2/2 [00:00<00:00, 2003.49it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.37it/s]
100%|██████████| 2/2 [00:00<00:00, 2003.49it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.08it/s]
100%|██████████| 5/5 [00:00<00:00, 5014.71it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.72it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.05it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 11.68it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 11.27it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.46it/s]
100%|██████████| 3/3 [00:00<00:00, 3009.55it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 24.46it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.03it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.07it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.09it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  6.02it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]
 17%|█▋        | 1/6 [00:13<01:05, 13.14s/it, loss=3.29]

(512, 512)


100%|██████████| 1/1 [00:00<00:00,  5.73it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 40.94it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 40.11it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 43.62it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 40.08it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.05it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.36it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.65it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.84it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 11.20it/s]
100%|██████████| 10/10 [00:00<00:00, 10019.84it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.83it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.84it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.17it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.36it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.08it/s]
100%|██████████| 2/2 [00:00<00:00, 2007.32it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 11.76it/s]
100%|██████████| 2/2 [00:00<00:00, 2007.80it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.86it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.86it/s]
100%|██████████| 5/5 [00:00<00:00, 5017.11it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.75it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.39it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.39it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  5.31it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 39.34it/s]
100%|██████████| 3/3 [00:00<00:00, 3008.83it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 41.80it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 44.58it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.38it/s]
100%|██████████| 3/3 [00:00<00:00, 3010.99it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.74it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.48it/s]
100%|██████████| 3/3 [00:00<00:00, 3011.71it/s]
 33%|███▎      | 2/6 [00:25<00:51, 12.95s/it, loss=3.29]

(512, 512)


100%|██████████| 1/1 [00:00<00:00,  4.12it/s]
100%|██████████| 2/2 [00:00<00:00, 2008.29it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 41.80it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 42.49it/s]
100%|██████████| 3/3 [00:00<00:00, 3010.27it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 39.34it/s]
100%|██████████| 3/3 [00:00<00:00, 3009.55it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 43.62it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.29it/s]
100%|██████████| 3/3 [00:00<00:00, 3008.83it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.86it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.05it/s]
100%|██████████| 2/2 [00:00<00:00, 2007.80it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.05it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.39it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.36it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  5.90it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 40.10it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 43.62it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 40.07it/s]
100%|██████████| 2/2 [00:00<00:00, 2005.88it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 43.62it/s]
100%|██████████| 2/2 [00:00<00:00, 2007.32it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 14.85it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.23it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.84it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.05it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.75it/s]
100%|██████████| 5/5 [00:00<00:00, 5024.32it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.75it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.78it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.36it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  5.88it/s]
100%|██████████| 2/2 [00:00<00:00, 2007.80it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 43.11it/s]
100%|██████████| 10/10 [00:00<00:00, 10034.22it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 42.67it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]
 50%|█████     | 3/6 [00:39<00:39, 13.23s/it, loss=3.29]

(512, 512)


100%|██████████| 1/1 [00:00<00:00,  6.09it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 41.78it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 40.94it/s]
100%|██████████| 3/3 [00:00<00:00, 3010.99it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 39.33it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 37.85it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.26it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 14.28it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.36it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.84it/s]
100%|██████████| 10/10 [00:00<00:00, 10022.23it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 14.43it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 14.64it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.71it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.36it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.39it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.26it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.48it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  7.42it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.94it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.94it/s]
100%|██████████| 5/5 [00:00<00:00, 5019.51it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.49it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.94it/s]
100%|██████████| 3/3 [00:00<00:00, 3010.27it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.10it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.23it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  6.94it/s]
100%|██████████| 3/3 [00:00<00:00, 3008.83it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 42.68it/s]
100%|██████████| 3/3 [00:00<00:00, 3009.55it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 42.68it/s]
100%|██████████| 4/4 [00:00<00:00, 4015.61it/s]
 67%|██████▋   | 4/6 [00:51<00:25, 12.90s/it, loss=3.29]

(512, 512)


100%|██████████| 1/1 [00:00<00:00,  7.32it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  6.70it/s]
100%|██████████| 4/4 [00:00<00:00, 4013.69it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  8.32it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 32.27it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 33.43it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 31.85it/s]
100%|██████████| 3/3 [00:00<00:00, 3009.55it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 31.85it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.74it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 14.51it/s]
100%|██████████| 5/5 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.53it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 30.26it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 29.94it/s]
100%|██████████| 2/2 [00:00<00:00, 2005.40it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 30.87it/s]
100%|██████████| 4/4 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 30.87it/s]
100%|██████████| 4/4 [00:00<00:00, 4010.81it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 11.27it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.46it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.70it/s]
100%|██████████| 3/3 [00:00<00:00, 3008.11it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 14.86it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.86it/s]
100%|██████████| 2/2 [00:00<00:00, 2007.80it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.26it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.38it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.84it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.25it/s]
100%|██████████| 10/10 [00:00<00:00, 10010.27it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.37it/s]
100%|██████████| 2/2 [00:00<00:00, 2005.88it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  7.51it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.84it/s]
 83%|████████▎ | 5/6 [01:04<00:12, 12.83s/it, loss=3.29]

(512, 512)


100%|██████████| 1/1 [00:00<00:00, 42.67it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 14.42it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.83it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 15.08it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 14.33it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.49it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.84it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.71it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.48it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 28.26it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  8.09it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 33.44it/s]
100%|██████████| 4/4 [00:00<00:00, 4011.77it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 34.00it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 32.89it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 32.36it/s]
100%|██████████| 4/4 [00:00<00:00, 4013.69it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 14.11it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.56it/s]
100%|██████████| 3/3 [00:00<00:00, 3010.27it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 14.75it/s]
100%|██████████| 4/4 [00:00<00:00, 4013.69it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.48it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.10it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.75it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.07it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 11.21it/s]
100%|██████████| 2/2 [00:00<00:00, 1331.95it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.37it/s]
100%|██████████| 10/10 [00:00<00:00, 10034.22it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.54it/s]
100%|██████████| 2/2 [00:00<00:00, 2005.88it/s]
100%|██████████| 6/6 [01:16<00:00, 12.81s/it, loss=3.29]


In [39]:
evaluate(test_loader)

  0%|          | 0/1 [00:00<?, ?it/s]

(512, 512)


100%|██████████| 1/1 [00:00<00:00,  5.52it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 40.94it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 41.80it/s]
100%|██████████| 3/3 [00:00<00:00, 3012.43it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 41.71it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 43.63it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.07it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.24it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.35it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.40it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.39it/s]
100%|██████████| 2/2 [00:00<00:00, 2008.29it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 25.39it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 11.80it/s]
100%|██████████| 10/10 [00:00<00:00, 10051.05it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.57it/s]
100%|██████████| 4/4 [00:00<00:00, 4013.69it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 14.01it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 13.19it/s]
100%|██████████| 2/2 [00:00<00:00, 2006.84it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 26.03it/s]
100%|██████████| 4/4 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.85it/s]
100%|██████████| 2/2 [00:00<00:00, 2007.80it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.43it/s]
100%|██████████| 2/2 [00:00<00:00, 2007.80it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 27.48it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 12.23it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00,  6.44it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 43.62it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 40.94it/s]
100%|██████████| 3/3 [00:00<?, ?it/s]


(512, 512)


100%|██████████| 1/1 [00:00<00:00, 40.94it/s]
100%|██████████| 2/2 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:11<00:00, 11.74s/it]

Precision (Macro): 0.9166666666666666
Recall (Macro): 0.875
F1 Score (Macro): 0.8666666666666666
Accuracy: 0.875
[0 0 2 2 3 3 1 1]
[0 0 2 2 1 3 1 1]





In [40]:
# Persist only the learned parameters (state_dict), not the whole module object.
MODEL_CHECKPOINT_PATH = 'model.pth'
torch.save(model.state_dict(), MODEL_CHECKPOINT_PATH)