In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import argparse
from torch.utils.data import random_split
import pandas as pd
import numpy as np
from sklearn.metrics import *
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from shutil import copyfile, make_archive
from IPython.display import FileLink

In [None]:
# Centralised ISIC-2017 split with Train / Valid / Test sub-folders (ImageFolder layout).
# In the cells shown, only image_datasets['Test'] is read later (by Client.get_predictions).
data_dir = '/kaggle/input/isic-2017-preprocessed-augmented/content/Linear_Exact_Aug'

# Standard ImageNet per-channel mean / std used for normalisation.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Train and Test use the same deterministic pipeline (resize + normalise, no augmentation),
# but each split gets its own Compose instance.
data_transforms = {
    split: transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    for split in ('Train', 'Test')
}

# (folder, transform key) pairs, built in this order; Valid reuses the Train transform.
image_datasets = {
    folder: datasets.ImageFolder(os.path.join(data_dir, folder), data_transforms[transform_key])
    for folder, transform_key in (('Train', 'Train'), ('Test', 'Test'), ('Valid', 'Train'))
}
train_size = len(image_datasets['Train'])
validation_size = len(image_datasets['Valid'])

dataloaders = {
    split: torch.utils.data.DataLoader(image_datasets[split], batch_size=8, shuffle=True, num_workers=10)
    for split in ('Train', 'Valid')
}

dataset_sizes = {'train': train_size, 'val': validation_size, 'test': len(image_datasets['Test'])}
class_names = image_datasets['Train'].classes
num_classes = len(class_names)

In [None]:
class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    For ``x`` of shape (B, C, H, W) the N = H*W positions are flattened, an
    (N x N) attention map is computed from 1x1-conv query/key projections, and
    ``gamma * attended + x`` is returned together with the attention map.
    ``gamma`` starts at 0, so a freshly built layer returns ``x`` unchanged.

    Args:
        in_dim: number of input channels C.
        reduction: channel reduction factor for the query/key projections
            (default 8). The projected width is clamped to at least 1 so layers
            with fewer than ``reduction`` channels still work.
    """

    def __init__(self, in_dim, reduction=8):
        super(SelfAttention, self).__init__()
        self.chanel_in = in_dim
        # Without the clamp, in_dim < reduction would give 0-channel projections
        # and forward() would fail on the view() of an empty tensor.
        qk_dim = max(in_dim // reduction, 1)
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=qk_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=qk_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply attention to ``x`` (B, C, H, W).

        Returns:
            (out, attention): ``out`` has the same shape as ``x``; ``attention``
            is (B, N, N) with rows summing to 1.
        """
        batch, channels, height, width = x.size()
        num_positions = height * width
        proj_query = self.query_conv(x).view(batch, -1, num_positions).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(batch, -1, num_positions)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # normalised over the key positions
        proj_value = self.value_conv(x).view(batch, -1, num_positions)  # B x C x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # B x C x N
        out = out.view(batch, channels, height, width)
        out = self.gamma * out + x
        return out, attention


In [None]:
class VGG11WithSelfAttention(nn.Module):
    """Image classifier: VGG11-BN conv features -> SelfAttention -> ReLU -> global average pool -> Linear.

    Args:
        num_classes: number of logits produced by ``forward``.
    """

    def __init__(self, num_classes):
        super().__init__()
        # ImageNet-pretrained VGG11 with batch norm; forward() only uses its ``features`` stack.
        # NOTE(review): recent torchvision releases deprecate ``pretrained=`` in favour of
        # ``weights=``; confirm against the installed version.
        self.vgg11 = models.vgg11_bn(pretrained=True)

        # The backbone's output channel count is the out_channels of its deepest Conv2d.
        deepest_conv = next(
            (module for module in reversed(self.vgg11.features) if isinstance(module, nn.Conv2d)),
            None,
        )
        feature_channels = deepest_conv.out_channels

        # Attention over the backbone's feature map, then a linear head on the pooled vector.
        self.self_attention = SelfAttention(feature_channels)
        self.classifier = nn.Linear(feature_channels, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch ``x``."""
        attended, _ = self.self_attention(self.vgg11.features(x))
        # In-place ReLU, then collapse the spatial grid to one vector per image.
        pooled = F.adaptive_avg_pool2d(F.relu(attended, inplace=True), (1, 1))
        return self.classifier(torch.flatten(pooled, 1))

In [None]:
class Client:
    """One participant in a sequential, confidence-weighted federated-learning run.

    Each client owns one data partition (``data_dir``, an ImageFolder-style
    directory), splits it 80/20 into train/validation loaders, and trains its own
    ``VGG11WithSelfAttention``. The weights with the best validation accuracy are
    written to ``<CHECKPOINT_DIR>/Client_<id>.pth`` so later clients can
    aggregate them.
    """

    # Directory where train_model() writes this client's best checkpoint.
    CHECKPOINT_DIR = "/kaggle/working"

    def __init__(self, client_id, data_dir):
        self.id = client_id  # Unique id for each client
        self.data_dir = data_dir
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # NOTE(review): no method of this class reads test_dir; get_predictions()
        # evaluates on the notebook-level image_datasets['Test'] by default.
        self.test_dir = "/kaggle/input/ham10000-data/HAM10000_DATA/test_dir"
        self.num_classes, self.dataloaders, self.dataset_sizes = self.initData()
        self.model = self.build_model()
        self.model_name = f'Client_{self.id}'

    # ------------------------------------------------------------------ data / model

    def initData(self):
        """Build train/validation loaders from a random 80/20 split of ``data_dir``.

        Returns:
            (num_classes, dataloaders, dataset_sizes); ``dataloaders`` and
            ``dataset_sizes`` are keyed by ``'train'`` and ``'val'``.
        """
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        full_dataset = datasets.ImageFolder(self.data_dir, transform)

        train_size = int(0.8 * len(full_dataset))  # 80% of the images form the training set
        validation_size = len(full_dataset) - train_size  # the remaining 20% form the validation set
        # NOTE: the split is unseeded, so it differs between runs.
        train_dataset, validation_dataset = random_split(full_dataset, [train_size, validation_size])

        dataloaders = {
            'train': torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=10),
            # Validation metrics do not depend on order, so keep evaluation deterministic.
            'val': torch.utils.data.DataLoader(validation_dataset, batch_size=8, shuffle=False, num_workers=10),
        }
        dataset_sizes = {'train': len(train_dataset), 'val': len(validation_dataset)}
        return len(full_dataset.classes), dataloaders, dataset_sizes

    def build_model(self):
        """Create a fresh VGG11WithSelfAttention for this client's classes on ``self.device``."""
        model = VGG11WithSelfAttention(self.num_classes).to(self.device)
        return model

    def _load_state_dict(self, path):
        """Load a state dict saved with ``torch.save`` onto this client's device."""
        # map_location lets GPU-saved checkpoints load on a CPU-only machine (and vice versa).
        return torch.load(path, map_location=self.device)

    # ------------------------------------------------------------------ aggregation

    def calculate_confidences(self, model_paths, validation_loaders):
        """Score each checkpoint by its mean top-1 softmax probability on this client's data.

        Args:
            model_paths: checkpoint paths (state dicts loadable into ``build_model()``).
            validation_loaders: a single iterable of ``(inputs, targets)`` batches;
                despite the plural name, one loader is shared by all models.

        Returns:
            One score per path, normalised to sum to 1.

        Raises:
            ValueError: if the loader yields no samples.
        """
        confidences = []
        for path in model_paths:
            # Build and score one model at a time instead of holding all of them in memory.
            model = self.build_model()
            model.load_state_dict(self._load_state_dict(path))
            model.eval()
            confidence_sum = 0.0
            num_samples = 0
            with torch.no_grad():
                for inputs, _ in validation_loaders:
                    inputs = inputs.to(self.device)
                    probabilities = torch.softmax(model(inputs), dim=1)
                    # Sum per-sample maxima so a short final batch is not over-weighted.
                    confidence_sum += probabilities.max(dim=1).values.sum().item()
                    num_samples += inputs.size(0)
            if num_samples == 0:
                raise ValueError("validation loader yielded no samples; cannot score checkpoints")
            confidences.append(confidence_sum / num_samples)

        total_confidence = sum(confidences)
        return [conf / total_confidence for conf in confidences]

    def confidence_weighted_aggregation(self, model_paths, confidences):
        """Blend checkpoints into one state dict, weighting each by its confidence.

        Args:
            model_paths: checkpoint paths with identical state-dict keys and shapes.
            confidences: one non-negative weight per path; normalised internally.

        Returns:
            A state dict in which every entry is ``sum_i w_i * entry_i`` with
            ``w = confidences / sum(confidences)``. Integer buffers (e.g. BatchNorm
            ``num_batches_tracked``) are rounded and keep their integer dtype. A
            single path is returned unchanged.

        Raises:
            ValueError: if ``model_paths`` is empty, or if its length differs from
                ``confidences`` when more than one path is given.
        """
        # Only the tensors are needed, so read the state dicts directly instead of building models.
        state_dicts = [self._load_state_dict(path) for path in model_paths]
        if not state_dicts:
            raise ValueError("model_paths must contain at least one checkpoint path")
        if len(state_dicts) == 1:
            # Nothing to blend: use the single checkpoint as-is.
            return state_dicts[0]
        if len(confidences) != len(state_dicts):
            raise ValueError(
                f"got {len(confidences)} confidences for {len(state_dicts)} checkpoints"
            )

        weights = torch.tensor(confidences, dtype=torch.float32, device=self.device)
        weights = weights / weights.sum()

        aggregated_weights = {}
        for key, reference in state_dicts[0].items():
            stacked = torch.stack([state_dict[key] for state_dict in state_dicts], dim=0).float()
            # One weight per checkpoint, broadcast over the tensor's own dimensions.
            blended = (stacked * weights.view([-1] + [1] * reference.dim())).sum(dim=0)
            if not reference.is_floating_point():
                blended = blended.round()
            aggregated_weights[key] = blended.to(reference.dtype)
        return aggregated_weights

    def aggregate_weights(self, prev_models):
        """Load the plain (unweighted) average of the given checkpoints into ``self.model``."""
        equal_weights = [1.0] * len(prev_models)
        self.model.load_state_dict(self.confidence_weighted_aggregation(prev_models, equal_weights))

    def set_weights(self, wts):
        """Load the checkpoint at path ``wts`` into this client's model."""
        self.model.load_state_dict(self._load_state_dict(wts))

    # ------------------------------------------------------------------ plotting

    @staticmethod
    def _plot_history(train_values, val_values, metric):
        """Plot per-epoch training (blue) and validation (red) curves for ``metric``."""
        epochs = range(1, len(train_values) + 1)
        fig, ax = plt.subplots()
        ax.plot(epochs, train_values, 'b', label=f'Training {metric}')
        ax.plot(epochs, val_values, 'r', label=f'Validation {metric}')
        ax.set_title(f'Training and Validation {metric.capitalize()}')
        ax.set_xlabel('Epochs')
        ax.set_ylabel(metric.capitalize())
        ax.legend()
        plt.show()

    def plot_loss(self, train_loss, val_loss):
        """Plot per-epoch training and validation loss."""
        self._plot_history(train_loss, val_loss, 'loss')

    def plot_acc(self, train_acc, val_acc):
        """Plot per-epoch training and validation accuracy."""
        self._plot_history(train_acc, val_acc, 'accuracy')

    # ------------------------------------------------------------------ training

    def _run_phase(self, phase, criterion, optimizer, epoch, epochs):
        """Run one pass over ``self.dataloaders[phase]``; return (mean loss, accuracy) as floats."""
        is_train = phase == 'train'
        self.model.train(is_train)  # train(False) is equivalent to eval()
        running_loss = 0.0
        running_corrects = 0

        data_loader = tqdm(self.dataloaders[phase], desc=f'{phase} Epoch {epoch + 1}/{epochs}')
        for inputs, labels in data_loader:
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            # Track gradient history only while training.
            with torch.set_grad_enabled(is_train):
                outputs = self.model(inputs)
                loss = criterion(outputs, labels)
                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            preds = outputs.argmax(dim=1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()
            data_loader.set_postfix({'loss': loss.item()})

        size = self.dataset_sizes[phase]
        return running_loss / size, running_corrects / size

    def train_model(self, epochs=1, prev_models=None):
        """Train this client's model, optionally starting from aggregated predecessors.

        Args:
            epochs: number of train + validation passes.
            prev_models: optional list of checkpoint paths from earlier clients. When
                given, they are scored on this client's validation split and their
                confidence-weighted blend becomes the starting weights.

        Side effects:
            Writes the best-validation-accuracy weights to
            ``<CHECKPOINT_DIR>/<model_name>.pth`` (at least once per run with
            ``epochs >= 1``), plots loss/accuracy curves, and leaves those best
            weights loaded in ``self.model``.
        """
        if prev_models:
            confidences = self.calculate_confidences(prev_models, self.dataloaders['val'])
            self.model.load_state_dict(self.confidence_weighted_aggregation(prev_models, confidences))

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(self.model.parameters(), lr=0.0001, momentum=0.99)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        history = {phase: {'loss': [], 'acc': []} for phase in ('train', 'val')}
        checkpoint_path = os.path.join(self.CHECKPOINT_DIR, self.model_name + ".pth")
        since = time.time()

        best_model_wts = copy.deepcopy(self.model.state_dict())
        best_acc = 0.0
        checkpoint_written = False
        for epoch in range(epochs):
            print('Epoch {}/{}'.format(epoch + 1, epochs))
            print('-' * 10)

            # Each epoch has a training and a validation phase.
            for phase in ['train', 'val']:
                epoch_loss, epoch_acc = self._run_phase(phase, criterion, optimizer, epoch, epochs)
                if phase == 'train':
                    scheduler.step()
                history[phase]['loss'].append(epoch_loss)
                history[phase]['acc'].append(epoch_acc)
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

                # Always save the first validated epoch so downstream clients have a file
                # to load even if accuracy never rises above 0.
                if phase == 'val' and (epoch_acc > best_acc or not checkpoint_written):
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(self.model.state_dict())
                    torch.save(best_model_wts, checkpoint_path)
                    checkpoint_written = True
                    print('==>Model Saved')

            print()

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print('Best val acc: {:.4f}'.format(best_acc))
        self.plot_loss(history['train']['loss'], history['val']['loss'])
        self.plot_acc(history['train']['acc'], history['val']['acc'])
        self.model.load_state_dict(best_model_wts)

    # ------------------------------------------------------------------ evaluation

    def metrics(self, labels, predictions, classes, y_true, y_prob):
        """Print and plot classification metrics for a finished evaluation.

        Args:
            labels, predictions: 1-D arrays of true / predicted class indices.
            classes: display names indexed by class id.
            y_true: true class indices used for the ROC curves.
            y_prob: (num_samples, len(classes)) array of predicted probabilities.
        """
        # Pin the label set so every class gets a row/column, even if absent from this split.
        class_ids = list(range(len(classes)))
        print("Classification Report:")
        print(classification_report(labels, predictions, labels=class_ids, target_names=classes, digits=4))
        matrix = confusion_matrix(labels, predictions, labels=class_ids)
        print("Confusion matrix:")
        print(matrix)
        print("Heat map:")
        fig, ax = plt.subplots(figsize=(8, 6))
        # Confusion-matrix entries are counts, so format them as integers.
        sns.heatmap(matrix, annot=True, xticklabels=classes, yticklabels=classes,
                    cmap=plt.cm.Blues, fmt='d', ax=ax)
        ax.set_ylabel('Actual Classes')
        ax.set_xlabel('Predicted Classes')
        plt.show(block=False)

        print("Accuracy: " + str(accuracy_score(labels, predictions)))
        for average in ('weighted', 'macro', 'micro'):
            print(f"{average} Precision: {precision_score(labels, predictions, average=average)}")
            print(f"{average} Recall: {recall_score(labels, predictions, average=average)}")
            print(f"{average} F1 Score: {f1_score(labels, predictions, average=average)}")

        # One-vs-rest ROC curve per class.
        fig, ax = plt.subplots()
        for class_id, class_name in enumerate(classes):
            fpr, tpr, _ = roc_curve(y_true == class_id, y_prob[:, class_id])
            ax.plot(fpr, tpr, label=f'{class_name} (AUC = {auc(fpr, tpr):.2f})')
        ax.plot([0, 1], [0, 1], 'k--')
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title('ROC Curve')
        ax.legend(loc='lower right')
        plt.show()
        for average in ('weighted', 'macro', 'micro'):
            print(f"{average} Roc score: " + str(roc_auc_score(y_true, y_prob, multi_class='ovr', average=average)))

        print("\nClasswise Accuracy :{}".format(matrix.diagonal() / matrix.sum(axis=1)))
        print("\nBalanced Accuracy Score: ", balanced_accuracy_score(labels, predictions))

    def get_predictions(self, test_dataset=None):
        """Evaluate ``self.model`` on a test set, print the mean loss, and report metrics.

        Args:
            test_dataset: dataset to evaluate. Defaults to the notebook-level
                ``image_datasets['Test']``.

        Raises:
            ValueError: if the test dataset is empty.
        """
        if test_dataset is None:
            test_dataset = image_datasets['Test']  # notebook-level global
        testloader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=4)
        self.model.eval()
        criterion = nn.CrossEntropyLoss()
        test_loss = 0.0
        num_samples = 0
        true_batches, prob_batches, pred_batches = [], [], []

        with torch.no_grad():
            for inputs, labels in testloader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                outputs = self.model(inputs)
                test_loss += criterion(outputs, labels).item() * inputs.size(0)
                num_samples += inputs.size(0)
                true_batches.append(labels.cpu().numpy())
                prob_batches.append(torch.softmax(outputs, dim=1).cpu().numpy())
                pred_batches.append(outputs.argmax(dim=1).cpu().numpy())

        if num_samples == 0:
            raise ValueError("test dataset is empty")
        y_true = np.concatenate(true_batches)
        y_prob = np.concatenate(prob_batches)
        predictions = np.concatenate(pred_batches)
        print(f"Test Loss: {test_loss / num_samples:.4f}")

        # Take names from the dataset itself so report labels line up with the integer targets.
        classes = list(getattr(test_dataset, 'classes', ['melanoma', 'nevus', 'seborrheic_keratosis']))
        self.metrics(y_true, predictions, classes, y_true, y_prob)

    def get_weights(self):
        """Return the live state dict of this client's model."""
        return self.model.state_dict()

In [None]:
# One client per ISIC-2017 partition: client k reads partition k-1 (clients built in order 1..4).
ob1, ob2, ob3, ob4 = (
    Client(client_id, f'/kaggle/input/isic20174clients/kaggle/working/isic2017-clients/partition{client_id - 1}')
    for client_id in range(1, 5)
)
num_epochs = 30

In [None]:
# Client 1 has no earlier checkpoints to aggregate, so it starts from the freshly built model.
ob1.train_model(epochs=num_epochs, prev_models=None)

In [None]:
# Client 2 starts from Client 1's best checkpoint (with a single checkpoint, aggregation copies it).
ob2.train_model(
    epochs=num_epochs,
    prev_models=[f"/kaggle/working/{ob1.model_name}.pth"],
)

In [None]:
# Client 3 starts from a confidence-weighted blend of Clients 1 and 2.
ob3.train_model(
    epochs=num_epochs,
    prev_models=[f"/kaggle/working/{peer.model_name}.pth" for peer in (ob1, ob2)],
)

In [None]:
# Client 4 starts from a confidence-weighted blend of Clients 1, 2 and 3.
ob4.train_model(
    epochs=num_epochs,
    prev_models=[f"/kaggle/working/{peer.model_name}.pth" for peer in (ob1, ob2, ob3)],
)

In [None]:
# Distribute the final model (Client 4's best checkpoint) to the earlier clients.
ob3.set_weights(f"/kaggle/working/{ob4.model_name}.pth")
ob2.set_weights(f"/kaggle/working/{ob4.model_name}.pth")
ob1.set_weights(f"/kaggle/working/{ob4.model_name}.pth")

In [None]:
# Evaluate Client 4 (holding its best weights from train_model) on the test split and report metrics.
ob4.get_predictions()

In [None]:
# Client 3 was given Client_4.pth via set_weights, so its results are expected to match Client 4's.
ob3.get_predictions()

In [None]:
# Client 2 was given Client_4.pth via set_weights, so its results are expected to match Client 4's.
ob2.get_predictions()

In [None]:
# Client 1 was given Client_4.pth via set_weights, so its results are expected to match Client 4's.
ob1.get_predictions()