In [2]:
import torch
from utils.CustomDataset import OCTDL
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
import random
import numpy as np

def seeding(num):
    """Make runs reproducible by seeding every RNG and pinning cuDNN behaviour.

    Seeds PyTorch (CPU generator, current CUDA device and all CUDA devices),
    NumPy's global RNG and Python's ``random`` module with ``num``, then turns
    cuDNN benchmark autotuning off, requests deterministic cuDNN kernels and
    leaves cuDNN enabled.

    Args:
        num: integer seed applied to every generator.
    """
    seed_fns = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seed_fns:
        seed_fn(num)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False      # autotuning picks kernels non-deterministically
    cudnn.deterministic = True
    cudnn.enabled = True

# Fix every RNG seed before any dataset or model object is created.
seeding(2025)
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def evaluation(net, loader, f1, f2, f3, DEVICE, num_classes=7):
    """Evaluate ``net`` over every batch of ``loader`` and return dataset-level metrics.

    Metrics are accumulated across all batches with ``update`` and read once
    with ``compute``, so the results are the metrics of the whole loader rather
    than an (incorrect) average of per-batch values.

    Args:
        net: ``torch.nn.Module`` classifier; its output is used as per-class
            scores, assumed shape ``(batch, num_classes)`` — TODO confirm head.
        loader: iterable of batches, each a list of sample dicts with keys
            ``"x"`` (input tensor) and ``"y"`` (integer class label, scalar or
            1-element tensor), e.g. a DataLoader with ``collate_fn=lambda x: x``.
        f1: accuracy metric exposing ``to``/``reset``/``update``/``compute``
            (torchmetrics style) that takes integer class targets.
        f2: F1-score metric with the same interface and target format as ``f1``.
        f3: mean-squared-error metric; it is fed the scores against one-hot
            targets, and the square root of its final value is reported as RMSE.
        DEVICE: device the network, metrics and tensors are moved to.
        num_classes: number of classes used to build the one-hot MSE targets.

    Returns:
        dict with float entries ``"acc"``, ``"f1score"`` and ``"rmse"``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    net.eval()
    net.to(DEVICE)
    accf = f1.to(DEVICE)
    f1scoref = f2.to(DEVICE)
    msef = f3.to(DEVICE)
    for metric in (accf, f1scoref, msef):
        # Metric objects are stateful; start clean so earlier use cannot leak in.
        metric.reset()

    n_batches = 0
    # Inference only: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        for sample in loader:
            n_batches += 1
            X = torch.stack([s["x"] for s in sample], 0)
            # reshape(-1) gives a (batch,) label vector whether each "y" is a
            # scalar or a 1-element tensor, and still works for a batch of one
            # (the old `.squeeze()` collapsed that case to shape (C,)).
            labels = torch.stack([s["y"] for s in sample], 0).reshape(-1).to(torch.int64).to(DEVICE)
            out = net(X.type(torch.float32).to(DEVICE)).type(torch.float32)
            # Multiclass accuracy / F1 take integer class indices as targets,
            # not one-hot rows.
            accf.update(out, labels)
            f1scoref.update(out, labels)
            # The squared error is measured against one-hot targets.
            onehot = torch.nn.functional.one_hot(labels, num_classes).type(torch.float32)
            msef.update(out, onehot)

    if n_batches == 0:
        raise ValueError("evaluation() received an empty loader")

    return {
        "acc": accf.compute().item(),
        "f1score": f1scoref.compute().item(),
        "rmse": torch.sqrt(msef.compute()).item(),
    }

In [None]:
from torchmetrics.classification import F1Score, Accuracy
from torchmetrics.regression import MeanSquaredError
from Network.Resnet import ResNet

# Evaluate the checkpoint stored in Models/CentralOCT on the OCTDL test split.
testset = OCTDL("Data/OCTDL/clients/test")
testloader = DataLoader(testset, batch_size=16, shuffle=False, collate_fn=lambda x: x)
net = ResNet().to(DEVICE)
# map_location lets a checkpoint saved from a GPU load when DEVICE is the CPU.
net.load_state_dict(torch.load("Models/CentralOCT/net.pt", map_location=DEVICE))

evaluation(
    net,
    testloader,
    Accuracy("multiclass", num_classes=7, average="macro"),
    F1Score("multiclass", num_classes=7, average="macro"),
    MeanSquaredError(),
    DEVICE,
)

In [None]:
from torchmetrics.classification import F1Score, Accuracy
from torchmetrics.regression import MeanSquaredError
from Network.Resnet import ResNet

# Evaluate the checkpoint stored in Models/FedAvgOCT on the OCTDL test split.
testset = OCTDL("Data/OCTDL/clients/test")
testloader = DataLoader(testset, batch_size=16, shuffle=False, collate_fn=lambda x: x)
net = ResNet().to(DEVICE)
# map_location lets a checkpoint saved from a GPU load when DEVICE is the CPU.
net.load_state_dict(torch.load("Models/FedAvgOCT/net.pt", map_location=DEVICE))

evaluation(
    net,
    testloader,
    Accuracy("multiclass", num_classes=7, average="macro"),
    F1Score("multiclass", num_classes=7, average="macro"),
    MeanSquaredError(),
    DEVICE,
)

In [None]:
from torchmetrics.classification import F1Score, Accuracy
from torchmetrics.regression import MeanSquaredError
from Network.Resnet import ResNet

# Evaluate the checkpoint stored in Models/FedPIDOCT on the OCTDL test split.
testset = OCTDL("Data/OCTDL/clients/test")
testloader = DataLoader(testset, batch_size=16, shuffle=False, collate_fn=lambda x: x)
net = ResNet().to(DEVICE)
# map_location lets a checkpoint saved from a GPU load when DEVICE is the CPU.
net.load_state_dict(torch.load("Models/FedPIDOCT/net.pt", map_location=DEVICE))

evaluation(
    net,
    testloader,
    Accuracy("multiclass", num_classes=7, average="macro"),
    F1Score("multiclass", num_classes=7, average="macro"),
    MeanSquaredError(),
    DEVICE,
)

In [None]:
from torchmetrics.classification import F1Score, Accuracy
from torchmetrics.regression import MeanSquaredError
from Network.Resnet import ResNet

# Evaluate the checkpoint stored in Models/FedLWROCT on the OCTDL test split.
testset = OCTDL("Data/OCTDL/clients/test")
testloader = DataLoader(testset, batch_size=16, shuffle=False, collate_fn=lambda x: x)
net = ResNet().to(DEVICE)
# map_location lets a checkpoint saved from a GPU load when DEVICE is the CPU.
net.load_state_dict(torch.load("Models/FedLWROCT/net.pt", map_location=DEVICE))

evaluation(
    net,
    testloader,
    Accuracy("multiclass", num_classes=7, average="macro"),
    F1Score("multiclass", num_classes=7, average="macro"),
    MeanSquaredError(),
    DEVICE,
)

In [None]:
from torchmetrics.classification import F1Score, Accuracy
from torchmetrics.regression import MeanSquaredError
from Network.Resnet import ResNet

# Evaluate the checkpoint stored in Models/FedRefOCT on the OCTDL test split.
testset = OCTDL("Data/OCTDL/clients/test")
testloader = DataLoader(testset, batch_size=16, shuffle=False, collate_fn=lambda x: x)
net = ResNet().to(DEVICE)
# map_location lets a checkpoint saved from a GPU load when DEVICE is the CPU.
net.load_state_dict(torch.load("Models/FedRefOCT/net.pt", map_location=DEVICE))

evaluation(
    net,
    testloader,
    Accuracy("multiclass", num_classes=7, average="macro"),
    F1Score("multiclass", num_classes=7, average="macro"),
    MeanSquaredError(),
    DEVICE,
)

In [None]:
# import pandas as pd
# import matplotlib.pyplot as plt
# fedavgframe = pd.read_csv('Result/FedAvg_loss_OCTDL.csv')
# fedpidframe = pd.read_csv('Result/FedPID_loss_OCTDL.csv')
# fedlwrframe = pd.read_csv('Result/FedLWR_loss_OCTDL.csv')
# fedrefframe = pd.read_csv("Result/FedRef_loss_OCTDL.csv")
# plt.plot(fedavgframe['1'].to_numpy(), color=(0.5,0,1), label= "Fed-Avg", marker= ".")
# plt.plot(fedlwrframe['1'].to_numpy(), color=(0.5,1,0.9), label= "Fed-LWR", marker= ".")
# plt.plot(fedpidframe['1'].to_numpy(), color=(0.5,1,0), label= "Fed-PID", marker= ".")
# plt.plot(fedrefframe['1'].to_numpy(), color=(1,0,0), label= "Fed-Ref", marker= ".")
# plt.legend(fontsize=16)

In [None]:
# import pandas as pd
# import matplotlib.pyplot as plt
# fedavgframe = pd.read_csv('Result/FedAvg_OCTDL.csv')
# fedpidframe = pd.read_csv('Result/FedPID_OCTDL.csv')
# fedlwrframe = pd.read_csv('Result/FedLWR_OCTDL.csv')
# fedrefframe = pd.read_csv("Result/FedRef_OCTDL.csv")
# plt.plot(fedavgframe['accuracy'].to_numpy(), color=(0.5,0,1), label= "Fed-Avg", marker= ".")
# plt.plot(fedlwrframe['accuracy'].to_numpy(), color=(0.5,1,0.9), label= "Fed-LWR", marker= ".")
# plt.plot(fedpidframe['accuracy'].to_numpy(), color=(0.5,1,0), label= "Fed-PID", marker= ".")
# plt.plot(fedrefframe['accuracy'].to_numpy(), color=(1,0,0), label= "Fed-Ref", marker= ".")
# plt.legend(fontsize=16)

In [None]:
# import pandas as pd
# import matplotlib.pyplot as plt
# fedavgframe = pd.read_csv('Result/FedAvg_OCTDL.csv')
# fedpidframe = pd.read_csv('Result/FedPID_OCTDL.csv')
# fedlwrframe = pd.read_csv('Result/FedLWR_OCTDL.csv')
# fedrefframe = pd.read_csv("Result/FedRef_OCTDL.csv")
# plt.plot(fedavgframe['f1score'].to_numpy(), color=(0.5,0,1), label= "Fed-Avg", marker= ".")
# plt.plot(fedlwrframe['f1score'].to_numpy(), color=(0.5,1,0.9), label= "Fed-LWR", marker= ".")
# plt.plot(fedpidframe['f1score'].to_numpy(), color=(0.5,1,0), label= "Fed-PID", marker= ".")
# plt.plot(fedrefframe['f1score'].to_numpy(), color=(1,0,0), label= "Fed-Ref", marker= ".")
# plt.legend(fontsize=16)

In [1]:
from utils.CustomDataset import OCTDL
# Inspect the dataset stored under Data/OCTDL/clients/1. The variable is named
# `test` but holds client 1's data, not the test split; the name is kept in
# case other cells reference it.
test=OCTDL("Data/OCTDL/clients/1")
# Class names exposed by the dataset (presumably in label-index order — confirm
# in utils.CustomDataset).
print(test.patients)
# Per-class weight tensor provided by OCTDL; NOTE(review): how it is computed
# lives in utils.CustomDataset — confirm there before relying on it.
test.label_weight

['NO', 'AMD', 'DME', 'ERM', 'RAO', 'RVO', 'VID']


tensor([0.8177, 0.4323, 0.8958, 0.9167, 0.9896, 0.9688, 0.9792])