In [None]:
import random

import deeplake
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
import torch
from matplotlib import pyplot as plt
from segmentation_models_pytorch.metrics import accuracy, specificity
from torch.nn.functional import one_hot

from Network.Unet import Custom2DUnet

# Set to True to skip the model-evaluation cells guarded by `if not skip:`.
skip = False
# Use the default CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Evaluation data (drive-test) and the source of the qualitative samples (drive-train).
dataset = deeplake.load("hub://activeloop/drive-test")
visualset = deeplake.load("hub://activeloop/drive-train")
# NOTE(review): the evaluation loop reads each sample's `masks` tensor, whereas the
# visualisation cells read `manual_masks/mask` from the train split -- confirm both
# hold the same kind of annotation.
# Batches of 16 in dataset order; the identity collate_fn keeps every batch as a
# plain list of deeplake samples, which `evaluation` unpacks itself.
testloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=lambda batch: batch,
)
dataset[0]

In [None]:
# Soft Dice loss for binary segmentation; the evaluation code reports `1 - loss`
# as the Dice score.
diceLoss = smp.losses.DiceLoss(
    mode="binary",       # binary segmentation: a single foreground channel
    classes=None,        # must stay None: smp does not support class selection in binary mode
    log_loss=False,      # loss = 1 - Dice rather than -log(Dice)
    from_logits=True,    # a sigmoid is applied inside the loss, so pass raw network outputs
    smooth=1e-5,         # smoothing term added to numerator and denominator
    ignore_index=None,   # no target value is excluded from the loss
    eps=1e-7,            # lower bound on the denominator for numerical stability
)
class Hausdolf95(torch.nn.Module):
    """Largest row-wise Euclidean distance between two tensors.

    ``torch.nn.PairwiseDistance`` (defaults: p=2, eps=1e-6) compares ``x`` and
    ``y`` along their last dimension and yields one distance per leading index;
    ``forward`` returns the maximum of those distances as a 0-d tensor.

    NOTE(review): despite the name this is not a Hausdorff (95th percentile)
    distance -- there is no nearest-neighbour search between foreground points
    and no percentile -- and `evaluation` passes it the raw network output
    rather than a binarised mask. Treat the reported "hf95" value with caution.
    """

    def __init__(self):
        super().__init__()
        # Parameter-free helper module, kept under its original attribute name.
        self.distancef = torch.nn.PairwiseDistance()

    def distance(self, x, y):
        """Return the p=2 distance between ``x`` and ``y`` along the last dim."""
        return self.distancef(x, y)

    def forward(self, x, y):
        """Return the largest value produced by :meth:`distance`."""
        row_distances = self.distance(x, y)
        return row_distances.max()


# Shared instance used by the evaluation cells.
hausdolf95 = Hausdolf95()
# Soft Jaccard (IoU) loss for binary segmentation; the evaluation code reports
# `1 - loss` as the Jaccard score.
jaccardf = smp.losses.JaccardLoss(
    mode="binary",       # binary segmentation: a single foreground channel
    classes=None,        # must stay None: smp does not support class selection in binary mode
    log_loss=False,      # loss = 1 - IoU rather than -log(IoU)
    from_logits=True,    # a sigmoid is applied inside the loss, so pass raw network outputs
    smooth=1e-5,         # smoothing term added to numerator and denominator
    eps=1e-7,            # lower bound on the denominator for numerical stability
)
def evaluation(net, loader, diceLoss, HF95, jaccard, DEVICE, threshold=0.15):
    """Evaluate a binary segmentation network over every batch of ``loader``.

    Args:
        net: model called on a float32 batch of shape (B, C, H, W); its output is
            squeezed on dim 1, so a single output channel is expected.
        loader: iterable of batches, each a list of samples whose
            ``"rgb_images"`` and ``"masks"`` entries expose ``.numpy()``
            (images are channel-last and are permuted to channel-first here).
        diceLoss: loss module; ``1 - loss`` is reported as ``"dice"``.
        HF95: module returning a scalar distance for (output, target).
        jaccard: loss module; ``1 - loss`` is reported as ``"jaccard"``.
        DEVICE: torch device the model, losses and tensors are moved to.
        threshold: cut-off applied to the raw network output to binarise it for
            accuracy / specificity (default 0.15, the previous hard-coded value).

    Returns:
        dict with ``"dice"``, ``"hf95"`` and ``"jaccard"`` (per-batch values
        averaged with batch-size weights) and ``"accuracy"`` and
        ``"specificity"`` (computed from confusion counts pooled over the whole
        loader; an empty denominator yields 1.0, like smp's zero_division).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    net.eval()
    net.to(DEVICE)
    dicef = diceLoss.to(DEVICE)
    hf95f = HF95.to(DEVICE)
    jaccardf = jaccard.to(DEVICE)

    n_samples = 0
    dice = hf95 = jacc = 0.0
    tp = fp = fn = tn = 0
    # Pure inference: without no_grad every batch would keep an autograd graph alive.
    with torch.no_grad():
        for sample in loader:
            X = torch.stack([torch.Tensor(s["rgb_images"].numpy()).permute(-1, 0, 1) for s in sample], 0)
            # NOTE(review): the target is the *inverse* of the stored boolean mask
            # (True -> 0, False -> 1); confirm this matches how `masks` is encoded.
            Y = torch.stack(
                [torch.where(torch.from_numpy(s['masks'].numpy()).squeeze()[..., 0], 0.0, 1.0).type(torch.int64) for s in sample],
                0,
            ).to(DEVICE)
            batch_size = X.shape[0]
            # squeeze(1) drops only the channel axis; a bare squeeze() would also
            # drop the batch axis when batch_size == 1.
            logits = net(X.type(torch.float32).to(DEVICE)).squeeze(1)

            # Weight per-batch scores by batch size so a short final batch does
            # not count as much as a full one.
            jacc += (1 - jaccardf(logits, Y)).item() * batch_size
            dice += (1 - dicef(logits, Y)).item() * batch_size
            hf95 += hf95f(logits, Y.type(torch.float32)).item() * batch_size
            n_samples += batch_size

            # Pool confusion counts; metrics are computed once at the end.
            pred = logits > threshold
            target = Y.bool()
            tp += (pred & target).sum().item()
            fp += (pred & ~target).sum().item()
            fn += (~pred & target).sum().item()
            tn += (~pred & ~target).sum().item()

    if n_samples == 0:
        raise ValueError("evaluation() got an empty loader")
    total = tp + fp + fn + tn
    return {
        "dice": dice / n_samples,
        "hf95": hf95 / n_samples,
        "jaccard": jacc / n_samples,
        "accuracy": (tp + tn) / total if total else 1.0,
        "specificity": tn / (tn + fp) if (tn + fp) else 1.0,
    }

In [None]:
if not skip:
    # Evaluate the checkpoint stored under Models/CentralDRIVE on `testloader`.
    net = Custom2DUnet(3, 1, True, 4, "cr", num_groups=4)
    # map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
    net.load_state_dict(torch.load('./Models/CentralDRIVE/net.pt', map_location=DEVICE))
    net.float()
    net.to(DEVICE)
    print("model is ready")
    print(evaluation(net, testloader, diceLoss, hausdolf95, jaccardf, DEVICE))

In [None]:
if not skip:
    # Evaluate the checkpoint stored under Models/FedAvgDRIVE on `testloader`.
    net = Custom2DUnet(3, 1, True, 4, "cr", num_groups=4)
    # map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
    net.load_state_dict(torch.load('./Models/FedAvgDRIVE/net.pt', map_location=DEVICE))
    net.float()
    net.to(DEVICE)
    print("model is ready")
    print(evaluation(net, testloader, diceLoss, hausdolf95, jaccardf, DEVICE))

In [None]:
if not skip:
    # Evaluate the checkpoint stored under Models/FedPIDDRIVE on `testloader`.
    net = Custom2DUnet(3, 1, True, 4, "cr", num_groups=4)
    # map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
    net.load_state_dict(torch.load('./Models/FedPIDDRIVE/net.pt', map_location=DEVICE))
    net.float()
    net.to(DEVICE)
    print("model is ready")
    print(evaluation(net, testloader, diceLoss, hausdolf95, jaccardf, DEVICE))

In [None]:
if not skip:
    # Evaluate the checkpoint stored under Models/FedLWRDRIVE on `testloader`.
    net = Custom2DUnet(3, 1, True, 4, "cr", num_groups=4)
    # map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
    net.load_state_dict(torch.load('./Models/FedLWRDRIVE/net.pt', map_location=DEVICE))
    net.float()
    net.to(DEVICE)
    print("model is ready")
    print(evaluation(net, testloader, diceLoss, hausdolf95, jaccardf, DEVICE))

In [None]:
if not skip:
    # Evaluate the checkpoint stored under Models/FedRefDRIVE on `testloader`.
    net = Custom2DUnet(3, 1, True, 4, "cr", num_groups=4)
    # map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
    net.load_state_dict(torch.load('./Models/FedRefDRIVE/net.pt', map_location=DEVICE))
    net.float()
    net.to(DEVICE)
    print("model is ready")
    print(evaluation(net, testloader, diceLoss, hausdolf95, jaccardf, DEVICE))

In [None]:
def minmaxNorm(tensor):
    """Linearly rescale ``tensor`` so its minimum maps to 0 and its maximum to 1.

    Works on anything exposing ``.min()``/``.max()`` and broadcasting
    arithmetic (torch tensors, NumPy arrays). A constant input has zero range;
    instead of dividing by zero (NaN everywhere) it is mapped to all zeros.
    """
    low = tensor.min()
    span = tensor.max() - low
    if span == 0:
        return tensor - low
    return (tensor - low) / span

def visualization(img, label, pred, threshold=0.7):
    """Show the input image, its label and the thresholded prediction side by side.

    Args:
        img: image array passed (squeezed) straight to ``imshow``.
        label: NumPy label array; cast to uint8 and shown in grayscale.
        pred: torch tensor holding the raw network output for ``img``.
        threshold: cut-off applied after clipping ``pred`` at 0 and min-max
            normalising it; pixels strictly above it are drawn white. Normalised
            values never exceed 1, so a threshold >= 1 gives an empty panel.
    """
    import warnings

    if threshold >= 1:
        warnings.warn(
            f"threshold={threshold} >= 1: min-max normalised predictions never exceed 1, "
            "so the prediction panel will be empty"
        )
    # Detach first so thresholding does not extend any autograd graph.
    mask = (minmaxNorm(torch.clip(pred.detach(), 0)) > threshold).squeeze().cpu().numpy()

    fig, ax = plt.subplots(1, 3)
    ax[0].imshow(img.squeeze())
    ax[1].imshow(label.squeeze().astype("uint8"), cmap="gray")
    ax[2].imshow(mask, cmap="gray")
    for panel in ax:
        panel.set_axis_off()
    fig.tight_layout()
    plt.show()

In [None]:
# Index of the `visualset` sample rendered by the qualitative comparison cells below.
indx = 7

In [None]:
# Qualitative check of the Models/CentralDRIVE checkpoint on `visualset[indx]`.
net = Custom2DUnet(3, 1, True, 4, "cr", num_groups=4)
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
net.load_state_dict(torch.load('./Models/CentralDRIVE/net.pt', map_location=DEVICE))
net.float()
net.to(DEVICE)
print("model is ready")
img = visualset[indx]["rgb_images"].numpy()
label = np.where(visualset[indx]['manual_masks/mask'].numpy(), 0.0, 1.0)[..., 0].squeeze()
# NOTE(review): unlike `evaluation`, this cell never calls net.eval(), so the model
# runs in training mode here -- confirm that is intended.
with torch.no_grad():  # inference only; no autograd graph is needed for plotting
    pred = net(torch.Tensor(img).permute(-1, 0, 1).unsqueeze(0).to(DEVICE))
visualization(img, label, pred, 0.55)

In [None]:
# Qualitative check of the Models/FedAvgDRIVE checkpoint on `visualset[indx]`.
net = Custom2DUnet(3, 1, True, 4, "cr", num_groups=4)
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
net.load_state_dict(torch.load('./Models/FedAvgDRIVE/net.pt', map_location=DEVICE))
net.float()
net.to(DEVICE)
print("model is ready")
img = visualset[indx]["rgb_images"].numpy()
label = np.where(visualset[indx]['manual_masks/mask'].numpy(), 0.0, 1.0)[..., 0].squeeze()
# NOTE(review): unlike `evaluation`, this cell never calls net.eval(), so the model
# runs in training mode here -- confirm that is intended.
with torch.no_grad():  # inference only; no autograd graph is needed for plotting
    pred = net(torch.Tensor(img).permute(-1, 0, 1).unsqueeze(0).to(DEVICE))
visualization(img, label, pred, 0.15)

In [None]:
# Qualitative check of the Models/FedPIDDRIVE checkpoint on `visualset[indx]`.
net = Custom2DUnet(3, 1, True, 4, "cr", num_groups=4)
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
net.load_state_dict(torch.load('./Models/FedPIDDRIVE/net.pt', map_location=DEVICE))
net.float()
net.to(DEVICE)
print("model is ready")
img = visualset[indx]["rgb_images"].numpy()
label = np.where(visualset[indx]['manual_masks/mask'].numpy(), 0.0, 1.0)[..., 0].squeeze()
# NOTE(review): unlike `evaluation`, this cell never calls net.eval(), so the model
# runs in training mode here -- confirm that is intended.
with torch.no_grad():  # inference only; no autograd graph is needed for plotting
    pred = net(torch.Tensor(img).permute(-1, 0, 1).unsqueeze(0).to(DEVICE))
# NOTE(review): `visualization` thresholds min-max normalised values, which never
# exceed 1, so a threshold of 1 always draws an empty prediction panel -- pick a
# value below 1.
visualization(img, label, pred, 1)

In [None]:
# Qualitative check of the Models/FedLWRDRIVE checkpoint on `visualset[indx]`.
net = Custom2DUnet(3, 1, True, 4, "cr", num_groups=4)
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
net.load_state_dict(torch.load('./Models/FedLWRDRIVE/net.pt', map_location=DEVICE))
net.float()
net.to(DEVICE)
print("model is ready")
img = visualset[indx]["rgb_images"].numpy()
label = np.where(visualset[indx]['manual_masks/mask'].numpy(), 0.0, 1.0)[..., 0].squeeze()
# NOTE(review): unlike `evaluation`, this cell never calls net.eval(), so the model
# runs in training mode here -- confirm that is intended.
with torch.no_grad():  # inference only; no autograd graph is needed for plotting
    pred = net(torch.Tensor(img).permute(-1, 0, 1).unsqueeze(0).to(DEVICE))
visualization(img, label, pred, 0.15)

In [None]:
# Qualitative check of the Models/FedRefDRIVE checkpoint on `visualset[indx]`.
net = Custom2DUnet(3, 1, True, 4, "cr", num_groups=4)
# map_location lets a checkpoint saved on a GPU load when DEVICE fell back to the CPU.
net.load_state_dict(torch.load('./Models/FedRefDRIVE/net.pt', map_location=DEVICE))
net.float()
net.to(DEVICE)
print("model is ready")
img = visualset[indx]["rgb_images"].numpy()
label = np.where(visualset[indx]['manual_masks/mask'].numpy(), 0.0, 1.0)[..., 0].squeeze()
# NOTE(review): unlike `evaluation`, this cell never calls net.eval(), so the model
# runs in training mode here -- confirm that is intended.
with torch.no_grad():  # inference only; no autograd graph is needed for plotting
    pred = net(torch.Tensor(img).permute(-1, 0, 1).unsqueeze(0).to(DEVICE))
visualization(img, label, pred, 0.15)

In [None]:
# Per-round training loss of each federated strategy, read from its result CSV.
fedavgframe = pd.read_csv('Result/fedavg/FedAvg_drive.csv')
fedpidframe = pd.read_csv('Result/fedpid/FedPID_drive.csv')
fedlwrframe = pd.read_csv('Result/fedlwr/FedLWR_drive.csv')
fedrefframe = pd.read_csv("Result/fedref/FedRef_drive.csv")

# Truncate every curve to the shortest run so all series share the same x range.
length = min(len(fedavgframe), len(fedpidframe), len(fedlwrframe), len(fedrefframe))

fig, ax = plt.subplots()
# (frame, legend label, RGB colour, line style, marker); list order fixes legend order.
for frame, name, color, style, marker in [
    (fedavgframe, "Fed-Avg", (0.5, 0, 1), "-", ""),
    (fedlwrframe, "Fed-LWR", (0.5, 1, 0.9), "--", ""),
    (fedpidframe, "Fed-PID", (0.5, 1, 0), "-", ""),
    (fedrefframe, "Fed-Ref", (1, 0, 0), "--", ""),
]:
    ax.plot(frame['loss'].to_numpy()[:length], color=color, label=name, marker=marker, linestyle=style)
ax.grid(True, axis="y", alpha=0.5, linestyle="--")
ax.legend(fontsize=16)
ax.set_xlabel("ROUND")
ax.set_ylabel("Loss")
plt.show()

In [None]:
# Per-round mean Dice ('mDice' column) of each federated strategy.
fedavgframe = pd.read_csv('Result/fedavg/FedAvg_drive.csv')
fedpidframe = pd.read_csv('Result/fedpid/FedPID_drive.csv')
fedlwrframe = pd.read_csv('Result/fedlwr/FedLWR_drive.csv')
fedrefframe = pd.read_csv("Result/fedref/FedRef_drive.csv")

# Truncate every curve to the shortest run so all series share the same x range.
length = min(len(fedavgframe), len(fedpidframe), len(fedlwrframe), len(fedrefframe))

fig, ax = plt.subplots()
# (frame, legend label, RGB colour, line style, marker); list order fixes legend order.
for frame, name, color, style, marker in [
    (fedavgframe, "Fed-Avg", (0.5, 0, 1), "--", ""),
    (fedlwrframe, "Fed-LWR", (0.5, 1, 0.9), "-.", ""),
    (fedpidframe, "Fed-PID", (0.5, 1, 0), "-.", ""),
    (fedrefframe, "Fed-Ref", (1, 0, 0), "--", ""),
]:
    ax.plot(frame['mDice'].to_numpy()[:length], color=color, label=name, marker=marker, linestyle=style)
ax.grid(True, axis="y", alpha=0.5, linestyle="--")
ax.legend(fontsize=16)
ax.set_xlabel("ROUND")
ax.set_ylabel("Dice score")
plt.show()

In [None]:
# Per-round 'mHF95' column of each federated strategy.
fedavgframe = pd.read_csv('Result/fedavg/FedAvg_drive.csv')
fedpidframe = pd.read_csv('Result/fedpid/FedPID_drive.csv')
fedlwrframe = pd.read_csv('Result/fedlwr/FedLWR_drive.csv')
fedrefframe = pd.read_csv("Result/fedref/FedRef_drive.csv")

# Truncate every curve to the shortest run so all series share the same x range.
length = min(len(fedavgframe), len(fedpidframe), len(fedlwrframe), len(fedrefframe))

fig, ax = plt.subplots()
# (frame, legend label, RGB colour, line style, marker); list order fixes legend order.
for frame, name, color, style, marker in [
    (fedavgframe, "Fed-Avg", (0.5, 0, 1), "--", "."),
    (fedlwrframe, "Fed-LWR", (0.5, 1, 0.9), "--", "."),
    (fedpidframe, "Fed-PID", (0.5, 1, 0), "--", "."),
    (fedrefframe, "Fed-Ref", (1, 0, 0), "--", "v"),
]:
    ax.plot(frame['mHF95'].to_numpy()[:length], color=color, label=name, marker=marker, linestyle=style)
ax.grid(True, axis="y", alpha=0.5, linestyle="--")
ax.legend(fontsize=16)
ax.set_xlabel("ROUND")
ax.set_ylabel("Hausdorff distance (HF95)")
plt.show()