In [2]:
import torch
import torch.nn as nn
# !pip install easyfsl
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Omniglot
from torchvision.models import resnet18
from tqdm import tqdm

from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average

class PrototypicalNetworks_dynamic_query(nn.Module):
    """Prototypical network that accepts a different number of queries per class.

    The wrapped ``backbone`` turns a batch of images into feature vectors.
    A class prototype is the mean feature vector of that class's support
    images, and every query is scored by the negative Euclidean distance to
    each prototype (higher score = closer prototype).
    """

    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks_dynamic_query, self).__init__()
        self.backbone = backbone

    def _backbone_device(self):
        """Return the device holding the backbone's parameters, or None if it has none."""
        first_param = next(self.backbone.parameters(), None)
        return None if first_param is None else first_param.device

    def _embed(self, images, target_device):
        """Move ``images`` next to the backbone (when its device is known) and embed them."""
        if target_device is not None:
            images = images.to(target_device)
        return self.backbone(images)

    def forward(
        self,
        support_images,
        support_labels,
        query_images,
    ) -> list:
        """
        Predict query labels using labeled support images.

        Args:
            support_images: sequence indexed by class position; entry ``i`` is
                the batch of support images of class ``i``.
            support_labels: sequence with one entry per class; only its length
                is used (it fixes the number of prototypes).
            query_images: iterable of image batches, one batch per group of
                queries; the batches may differ in size.

        Returns:
            A list with one tensor per entry of ``query_images``. Each tensor
            is ``-torch.cdist(query_features, prototypes)``, i.e. one row per
            query and one column per class.
        """
        num_classes = len(list(support_labels))

        # Take the device from the backbone itself instead of a notebook-level
        # `device` global: the module then works on a fresh kernel regardless
        # of cell order, and support and query batches are treated alike.
        target_device = self._backbone_device()

        prototypes = torch.stack([
            self._embed(support_images[class_label], target_device).mean(dim=0)
            for class_label in range(num_classes)
        ])

        # Negated distances so that argmax over the class axis picks the
        # nearest prototype.
        return [
            -torch.cdist(self._embed(each_query_class, target_device), prototypes)
            for each_query_class in query_images
        ]


In [3]:
!ls '/home/asufian/Desktop/output_olchiki/code/MAPI Mayee'

'MAPI Mayeek'


In [4]:
# Dataset root handed to PrototypicalOmniglotDataset, which walks it as
# <root>/<alphabet>/<character class>/<image files>.
# NOTE(review): hard-coded absolute path — must be adjusted on any other machine.
root_path = '/home/asufian/Desktop/output_olchiki/code/MAPI Mayee'

In [5]:
import os
import random
from collections import defaultdict
from PIL import Image, ImageOps
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class PrototypicalOmniglotDataset(Dataset):
    """Episode builder over an Omniglot-style folder tree.

    The root is scanned as ``<root>/<alphabet>/<character class>/<image>``.
    Every character class that contains at least one image is registered under
    the name ``"<alphabet>_<character class>"``; its images are kept in sorted
    file-name order and addressed by their position in that list.

    Note that ``__getitem__`` does not follow the usual ``Dataset`` protocol:
    it takes an explicit list of class names and, per class, the indices of
    the support images. ``num_classes`` and ``n_query`` are stored but not
    used when an episode is built.
    """

    # Lower-case file suffixes that are treated as images during discovery.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp')

    def __init__(self, root, num_classes=1623, n_shot=5, n_query=10, transform=None):
        self.root = root
        self.num_classes = num_classes
        self.n_shot = n_shot
        self.n_query = n_query
        self.transform = transform
        self.samples_by_label = defaultdict(list)  # class name -> [0, 1, ..., n_images - 1]
        self.all_imgs = {}                         # class name -> sorted list of image paths
        self.classes = []                          # class names in discovery order

        # Directory listings are sorted so the discovery order (and therefore
        # `self.classes`) does not depend on the file system.
        for alphabet in sorted(os.listdir(self.root)):
            alphabet_path = os.path.join(self.root, alphabet)
            if not os.path.isdir(alphabet_path):
                continue

            for char_class in sorted(os.listdir(alphabet_path)):
                char_class_path = os.path.join(alphabet_path, char_class)
                if not os.path.isdir(char_class_path):
                    continue

                image_paths = [
                    os.path.join(char_class_path, img_name)
                    for img_name in sorted(os.listdir(char_class_path))
                    if img_name.lower().endswith(self.IMAGE_EXTENSIONS)
                    and os.path.isfile(os.path.join(char_class_path, img_name))
                ]
                if not image_paths:
                    continue

                char_class_name = f"{alphabet}_{char_class}"
                self.samples_by_label[char_class_name] = list(range(len(image_paths)))
                self.all_imgs[char_class_name] = image_paths
                self.classes.append(char_class_name)

    def transform_image(self, raw_img):
        """Apply ``self.transform`` to ``raw_img`` when one is configured, else return it unchanged."""
        if self.transform is None:
            return raw_img
        return self.transform(raw_img)

    def _load_image(self, img_path):
        """Open one image as 8-bit grayscale, release the file handle, and transform it."""
        # `convert` returns a fully loaded copy, so the file can be closed
        # right away instead of leaking one handle per image.
        with Image.open(img_path) as raw_img:
            gray_img = raw_img.convert('L')
        return self.transform_image(gray_img)

    def __getitem__(self, selected_classes, selected_supports):
        """Build one episode.

        Args:
            selected_classes: class names (as stored in ``self.classes``); the
                position of a name in this list is its episode label.
            selected_supports: for each class, the list of image indices to
                use as support images. Every other image of the class becomes
                a query image.

        Returns:
            ``(support_set, query_set)``; each is a list with one
            ``(images, labels)`` pair per class, where ``labels`` repeats the
            episode label once per image.

        Raises:
            KeyError: if a requested class was not found under ``root``.
            IndexError: if ``selected_supports`` has fewer entries than
                ``selected_classes`` or an index is out of range.
        """
        support_set = []
        query_set = []
        for label_id, class_name in enumerate(selected_classes):
            if class_name not in self.all_imgs:
                raise KeyError(f"Unknown class {class_name!r}: not found under {self.root!r}")
            image_paths = self.all_imgs[class_name]

            support_indices = list(selected_supports[label_id])
            support_lookup = set(support_indices)
            query_indices = [
                idx for idx in self.samples_by_label[class_name]
                if idx not in support_lookup
            ]

            support_images = [self._load_image(image_paths[idx]) for idx in support_indices]
            query_images = [self._load_image(image_paths[idx]) for idx in query_indices]

            # Label lists are sized by the images actually returned, not by
            # `self.n_shot`, so they stay aligned when the caller passes a
            # different number of supports than the dataset was built with.
            support_set.append((support_images, [label_id] * len(support_indices)))
            query_set.append((query_images, [label_id] * len(query_indices)))

        return support_set, query_set

    def __len__(self):
        # Fixed placeholder length; episodes are requested explicitly through
        # __getitem__, not by integer index.
        return 10

# Default preprocessing: force a single channel, then convert the PIL image to a tensor.
# The dataset already opens every image with convert('L'), so the Grayscale step
# leaves the channel count at one.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale
    transforms.ToTensor(),
])

# Build a PrototypicalOmniglotDataset over `root_path` with the default episode settings.
# Despite its name, PrototypicalOmniglotDatasetLoader is the dataset object itself,
# not a torch DataLoader.
# root_path = './assamese'
num_classes = 5  # Set the desired number of classes per episode
n_shot = 5
n_query = 10

PrototypicalOmniglotDatasetLoader = PrototypicalOmniglotDataset(root=root_path, num_classes=num_classes, n_shot=n_shot, n_query=n_query, transform=transform)


In [6]:
import torch
import torch.nn as nn

class PrototypicalNetworks_dynamic_query(nn.Module):
    """Prototypical network that accepts a different number of queries per class.

    The wrapped ``backbone`` turns a batch of images into feature vectors.
    A class prototype is the mean feature vector of that class's support
    images, and every query is scored by the negative Euclidean distance to
    each prototype (higher score = closer prototype).
    """

    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks_dynamic_query, self).__init__()
        self.backbone = backbone

    def _backbone_device(self):
        """Return the device holding the backbone's parameters, or None if it has none."""
        first_param = next(self.backbone.parameters(), None)
        return None if first_param is None else first_param.device

    def _embed(self, images, target_device):
        """Move ``images`` next to the backbone (when its device is known) and embed them."""
        if target_device is not None:
            images = images.to(target_device)
        return self.backbone(images)

    def forward(
        self,
        support_images,
        support_labels,
        query_images,
    ) -> list:
        """
        Predict query labels using labeled support images.

        Args:
            support_images: sequence indexed by class position; entry ``i`` is
                the batch of support images of class ``i``.
            support_labels: sequence with one entry per class; only its length
                is used (it fixes the number of prototypes).
            query_images: iterable of image batches, one batch per group of
                queries; the batches may differ in size.

        Returns:
            A list with one tensor per entry of ``query_images``. Each tensor
            is ``-torch.cdist(query_features, prototypes)``, i.e. one row per
            query and one column per class.
        """
        num_classes = len(list(support_labels))

        # Take the device from the backbone itself instead of a notebook-level
        # `device` global: the module then works on a fresh kernel regardless
        # of cell order, and support and query batches are treated alike.
        target_device = self._backbone_device()

        prototypes = torch.stack([
            self._embed(support_images[class_label], target_device).mean(dim=0)
            for class_label in range(num_classes)
        ])

        # Negated distances so that argmax over the class axis picks the
        # nearest prototype.
        return [
            -torch.cdist(self._embed(each_query_class, target_device), prototypes)
            for each_query_class in query_images
        ]


In [7]:
class PrototypicalNetworks33(nn.Module):
    """Prototypical network that embeds images one at a time.

    Each image is passed through ``backbone`` as its own batch of size one.
    Class prototypes are the mean support embedding per class and queries are
    scored by the negative Euclidean distance to every prototype.
    """

    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks33, self).__init__()
        self.backbone = backbone

    def _embed_each(self, images):
        """Embed ``images`` one by one and return the embeddings as rows.

        Assumes the backbone returns a ``(1, feature_dim)`` tensor for a
        single-image batch — TODO confirm for other backbones.
        """
        return torch.cat([self.backbone(image.unsqueeze(0)) for image in images], dim=0)

    def forward(
        self,
        support_images,
        support_labels,
        query_images,
    ) -> torch.Tensor:
        """Score every query against every class prototype.

        Args:
            support_images: sequence indexed by class position; entry ``i``
                iterates over the support images of class ``i``.
            support_labels: sequence with one entry per class; only its
                length is used.
            query_images: iterable of single query images.

        Returns:
            Tensor of shape ``(num_queries, num_classes)`` holding negated
            distances (higher = closer prototype).
        """
        num_classes = len(support_labels)
        prototypes = torch.stack([
            self._embed_each(support_images[class_label]).mean(dim=0)
            for class_label in range(num_classes)
        ])
        query_features = self._embed_each(query_images)

        # One cdist call on 2-D inputs always yields (num_queries, num_classes).
        # The previous per-query loop relied on a blanket squeeze(), which also
        # removed the class axis in a 1-way episode and produced shape (Q,).
        # The exact (non-matmul) mode keeps the numerics of the per-pair
        # computation for any number of queries.
        distances = torch.cdist(
            query_features,
            prototypes,
            compute_mode='donot_use_mm_for_euclid_dist',
        )
        return -distances


In [9]:
import pandas as pd
import os
import csv
def save_in_csv(cls,flnm,conf_matrix, acciuracy,precision_overall, recall_overall, f_betas):
    """Write one evaluation report to ``flnm`` as CSV, overwriting any existing file.

    Layout: a ``used Class:`` header row, one row with the class names, the
    confusion matrix (one CSV row per matrix row), then a header row and a
    value row with accuracy, precision, recall and one column per F-beta.

    Args:
        cls: class names of the episode, either a sequence (one cell per
            name) or a single pre-formatted string (written as one cell).
        flnm: path of the CSV file to create.
        conf_matrix: iterable of rows.
        acciuracy, precision_overall, recall_overall: scalar metrics.
        f_betas: iterable of F-beta scores; the columns are labelled with
            consecutive integers starting at -5.
    """
    headers = ['Accuracy', 'Precision', 'Recall']
    values = [acciuracy, precision_overall, recall_overall]
    for beta, f_beta in enumerate(f_betas, start=-5):
        headers.append('F-betas (' + str(beta) + ' )')
        values.append(f_beta)

    # csv treats a bare string row as a sequence of characters, which would
    # spread a string `cls` over one cell per character; keep it in one cell.
    class_row = [cls] if isinstance(cls, str) else list(cls)

    with open(flnm, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)

        # Classes used in this episode
        writer.writerow(['used Class:'])
        writer.writerow(class_row)
        writer.writerow([])
        writer.writerow([])

        # Confusion matrix, followed by two empty rows for separation
        writer.writerow(['Confusion Matrix'])
        writer.writerows(conf_matrix)
        writer.writerow([])
        writer.writerow([])

        # Metrics
        writer.writerow(headers)
        writer.writerow(values)


  # print("Metrics have been written to 'metrics.csv' file.")
def save_in_text_file(msg,cls, flnm, conf_matrix, accuracy, precision_overall, recall_overall, f_betas):
    """Append a plain-text evaluation report to ``flnm``.

    The report consists of an indented banner line with ``msg``, the classes
    used, the confusion matrix (one space-separated line per row) and the
    metrics. F-beta lines are numbered with consecutive integers from -5.
    The file is opened in append mode, so earlier reports are kept.
    """
    with open(flnm, 'a') as textfile:
        emit = textfile.write

        # Banner and class list
        emit(str("\n\n\n                                 "+msg))
        emit("\n\n\n")
        emit("used Class:\n")
        emit(str(cls) + "\n\n")

        # Confusion matrix, one line per row
        emit("Confusion Matrix:\n")
        for matrix_row in conf_matrix:
            emit(' '.join(map(str, matrix_row)) + "\n")
        emit("\n")

        # Scalar metrics followed by the F-beta series
        emit("Metrics:\n")
        emit(f"Accuracy: {accuracy}\n")
        emit(f"Precision: {precision_overall}\n")
        emit(f"Recall: {recall_overall}\n")
        for beta, f_beta in enumerate(f_betas, start=-5):
            emit(f"F-betas ({beta}): {f_beta}\n")

    # print("Metrics have been written to '{}' file.".format(flnm))

def save_in_excel_final(msg, cls, flnm, conf_matrix, accuracy, precision_overall, recall_overall):
    """Add one episode's results as a new column of the Excel sheet ``flnm``.

    The results are laid out vertically (one row per field). When ``flnm``
    already exists its content is read back and the new column is appended on
    the right; otherwise the file is created with this single column.
    """
    record = {
        "Used Class": [cls],
        "Used support images idx": [msg],
        "Confusion Matrix": [conf_matrix],
        "Accuracy": [accuracy],
        "Precision": [precision_overall],
        "Recall": [recall_overall]
    }

    # One row per field, one column for this episode.
    episode_column = pd.DataFrame(record).transpose()

    if not os.path.exists(flnm):
        episode_column.to_excel(flnm)
        return

    # Re-read what was stored so far and widen it by the new column.
    stored = pd.read_excel(flnm, index_col=0)
    widened = pd.concat([stored, episode_column], axis=1)
    widened.to_excel(flnm)

    # print("Metrics have been written to '{}' file.".format(flnm))


In [10]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch
import torch.nn as nn
import torch.hub

# Define a modified version of DenseNet169 with dropout after each convolutional layer
class DenseNet169WithDropout(nn.Module):
    """Wrap a torch.hub vision model and append dropout to some of its conv layers.

    The model named by ``modell`` is loaded from the ``pytorch/vision:v0.10.0``
    hub repository. The constructor then looks at the direct children of the
    model's ``features`` container: for every child that is an
    ``nn.Sequential``, each direct ``nn.Conv2d`` inside it is replaced by
    ``nn.Sequential(conv, dropout)``. Convolutions nested deeper than that are
    not touched.

    NOTE(review): the loaded model must expose a ``features`` attribute;
    despite the class name, the architecture is whatever ``modell`` selects.

    Args:
        modell: hub entry-point name passed to ``torch.hub.load``.
        pretrained: forwarded to the hub entry point.
        dr: dropout probability.
    """
    def __init__(self,modell, pretrained=False, dr=0.4):
        super(DenseNet169WithDropout, self).__init__()
        self.densenet169 = torch.hub.load('pytorch/vision:v0.10.0', modell, pretrained=pretrained)
        self.dropout = nn.Dropout(p=dr)  # Single dropout module (probability `dr`) shared by every wrapped conv

        # For each nn.Sequential directly under `features`, wrap its direct Conv2d children with dropout
        for name, module in self.densenet169.features.named_children():
            if isinstance(module, nn.Sequential):
                for sub_name, sub_module in module.named_children():
                    if isinstance(sub_module, nn.Conv2d):
                        # Replacing the child with a Sequential nests the conv one level deeper
                        setattr(module, sub_name, nn.Sequential(sub_module, self.dropout))

    def forward(self, x):
        """Run ``x`` through the wrapped hub model and return its output."""
        return self.densenet169(x)

# Instantiate the modified DenseNet169 model with dropout
# convolutional_network_with_dropout = DenseNet169WithDropout(modell='resnet18',pretrained=False,dr=0.4)

# Example usage:
# output = convolutional_network_with_dropout(input_tensor)

class ResNet18WithDropout(nn.Module):
    """Wrap torchvision's ``resnet18`` and append dropout to some of its conv layers.

    The constructor looks at the direct children of the ResNet: for every
    child that is an ``nn.Sequential``, each direct ``nn.Conv2d`` inside it is
    replaced by ``nn.Sequential(conv, dropout)``.

    NOTE(review): in torchvision's ResNet the direct children of the
    Sequential stages appear to be residual blocks rather than ``Conv2d``
    modules, in which case this loop inserts no dropout at all — confirm.
    Changing the wrapping would rename parameters, so existing checkpoints
    saved from this class would no longer load.

    Args:
        pretrained: forwarded to ``models.resnet18``.
        dr: dropout probability.
    """
    def __init__(self, pretrained=False, dr=0.4):
        super(ResNet18WithDropout, self).__init__()
        self.resnet18 = models.resnet18(pretrained=pretrained)
        self.dropout = nn.Dropout(p=dr)  # Single dropout module shared by every wrapped conv

        # For each nn.Sequential directly under the ResNet, wrap its direct Conv2d children with dropout
        for name, module in self.resnet18.named_children():
            if isinstance(module, nn.Sequential):
                for sub_name, sub_module in module.named_children():
                    if isinstance(sub_module, nn.Conv2d):
                        setattr(module, sub_name, nn.Sequential(sub_module, self.dropout))

    def forward(self, x):
        """Run ``x`` through the wrapped ResNet (including its own ``fc`` layer) and return the output."""
        return self.resnet18(x)

# Instantiate the modified ResNet18 model with dropout
# convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.4)


In [11]:


# Fixed 5-way evaluation episodes: one row per episode, five class names per row.
# Names use the "<alphabet>_<character class>" form that PrototypicalOmniglotDataset
# assigns to each discovered class. Row i is paired with row i of the
# clss_support_imagesss_* tables, which select that episode's support images.
# NOTE(review): the fifth and sixth rows list the same five classes (their
# support indices differ) — confirm this repetition is intended.
clss_5=[
     ['MAPI Mayeek_GHOU', 'MAPI Mayeek_JIL', 'MAPI Mayeek_SAM', 'MAPI Mayeek_KOK', 'MAPI Mayeek_MIT'],
      ['MAPI Mayeek_KOK', 'MAPI Mayeek_DIL', 'MAPI Mayeek_PAA', 'MAPI Mayeek_ATIYA', 'MAPI Mayeek_MIT'],
     ['MAPI Mayeek_DHOU', 'MAPI Mayeek_KOK', 'MAPI Mayeek_RAAI', 'MAPI Mayeek_TIL', 'MAPI Mayeek_GHOU'],
      ['MAPI Mayeek_FAM', 'MAPI Mayeek_JHAM', 'MAPI Mayeek_LAI', 'MAPI Mayeek_PAA', 'MAPI Mayeek_KOK'],
     ['MAPI Mayeek_KOK', 'MAPI Mayeek_GHOU', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_MIT', 'MAPI Mayeek_WAI'],
['MAPI Mayeek_KOK', 'MAPI Mayeek_GHOU', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_MIT', 'MAPI Mayeek_WAI'],
     ['MAPI Mayeek_NAA', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_BAA', 'MAPI Mayeek_LAI', 'MAPI Mayeek_RAAI'],
     ['MAPI Mayeek_RAAI', 'MAPI Mayeek_FAM', 'MAPI Mayeek_GOK', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_SAM'],
     ['MAPI Mayeek_EEE', 'MAPI Mayeek_KOK', 'MAPI Mayeek_JHAM', 'MAPI Mayeek_BHAM', 'MAPI Mayeek_WAI'],
     ['MAPI Mayeek_DHOU', 'MAPI Mayeek_RAAI', 'MAPI Mayeek_NGOU', 'MAPI Mayeek_UUN', 'MAPI Mayeek_LAI'],
     ['MAPI Mayeek_HUK', 'MAPI Mayeek_FAM', 'MAPI Mayeek_BHAM', 'MAPI Mayeek_THOU', 'MAPI Mayeek_SAM'],
     ['MAPI Mayeek_MIT', 'MAPI Mayeek_LAI', 'MAPI Mayeek_DIL', 'MAPI Mayeek_ATIYA', 'MAPI Mayeek_BHAM'],
     ['MAPI Mayeek_NGOU', 'MAPI Mayeek_FAM', 'MAPI Mayeek_ATIYA', 'MAPI Mayeek_MIT', 'MAPI Mayeek_JHAM'],
     ['MAPI Mayeek_GOK', 'MAPI Mayeek_ATIYA', 'MAPI Mayeek_LAI', 'MAPI Mayeek_NGOU', 'MAPI Mayeek_KHOU'],
['MAPI Mayeek_HUK', 'MAPI Mayeek_DIL', 'MAPI Mayeek_KOK', 'MAPI Mayeek_PAA', 'MAPI Mayeek_LAI']

     ]

# 1-shot support selection: row i belongs to episode clss_5[i] and holds, for each
# of its five classes, a one-element list with the index of the support image.
# An index is a position in that class's sorted image-file list; all remaining
# images of the class are used as queries by the dataset.
clss_support_imagesss_1_shot=[[[3], [56], [93], [62], [23]],
             [[34], [18], [31], [40], [39]],
             [[54], [61], [9], [50], [93]],
             [[85], [9], [35], [6], [36]],
             [[93], [87], [60], [26], [35]],
             [[2], [39], [84], [38], [30]],
             [[6], [12], [43], [65], [66]],
             [[6], [66], [6], [13], [89]],
             [[50], [70], [13], [61], [7]],
             [[54], [58], [26], [75], [51]],
             [[14], [27], [92], [76], [27]],
             [[47], [34], [84], [45], [82]],
             [[63], [69], [25], [80], [61]],
             [[13], [41], [57], [13], [46]],
             [[58], [92], [48], [92], [35]]
]

clss_support_imagesss_5_shot=[[[3, 21, 2, 72, 87], [56, 22, 58, 3, 24], [93, 76, 48, 81, 68], [62, 82, 2, 87, 69], [23, 73, 83, 92, 55]], [[34, 22, 25, 30, 74], [18, 70, 45, 20, 46], [31, 83, 42, 49, 14], [40, 87, 93, 12, 92], [39, 20, 70, 74, 84]], [[54, 11, 2, 80, 8], [61, 72, 52, 17, 95], [9, 2, 37, 15, 26], [50, 48, 2, 79, 63], [93, 48, 0, 24, 89]], [[85, 19, 68, 55, 56], [9, 42, 35, 93, 43], [35, 29, 91, 2, 68], [6, 56, 30, 50, 25], [36, 42, 19, 21, 41]], [[93, 9, 38, 61, 68], [87, 33, 27, 36, 53], [60, 21, 53, 49, 17], [26, 2, 11, 62, 4], [35, 78, 30, 79, 84]], [[2, 70, 12, 56, 5], [39, 59, 30, 45, 95], [84, 86, 14, 66, 54], [38, 54, 9, 46, 34], [30, 72, 47, 6, 26]], [[6, 1, 61, 70, 8], [12, 33, 77, 55, 51], [43, 56, 94, 74, 45], [65, 27, 13, 11, 6], [66, 93, 36, 68, 27]], [[6, 3, 81, 24, 20], [66, 73, 68, 55, 76], [6, 14, 55, 3, 38], [13, 86, 45, 89, 18], [89, 65, 74, 8, 42]], [[50, 91, 0, 7, 51], [70, 50, 17, 33, 64], [13, 32, 8, 38, 72], [61, 43, 74, 86, 95], [7, 34, 70, 20, 36]], [[54, 61, 75, 15, 50], [58, 7, 51, 93, 56], [26, 89, 91, 57, 24], [75, 49, 33, 88, 50], [51, 92, 53, 54, 52]], [[14, 80, 78, 27, 3], [27, 69, 59, 63, 93], [92, 60, 56, 86, 64], [76, 30, 64, 49, 12], [27, 30, 35, 64, 56]], [[47, 94, 69, 51, 5], [34, 28, 59, 14, 67], [84, 34, 23, 5, 16], [45, 0, 57, 88, 76], [82, 78, 30, 12, 49]], [[63, 89, 62, 51, 88], [69, 75, 83, 49, 45], [25, 71, 45, 26, 72], [80, 92, 22, 41, 62], [61, 25, 40, 14, 94]], [[13, 82, 90, 45, 14], [41, 60, 79, 40, 67], [57, 94, 27, 77, 5], [13, 24, 54, 1, 94], [46, 62, 69, 2, 85]], [[58, 5, 69, 35, 30], [92, 88, 79, 58, 31], [48, 43, 78, 9, 15], [92, 87, 34, 6, 75], [35, 16, 4, 21, 46]]]

clss_support_imagesss_10_shot=[[[3, 21, 2, 72, 87, 60, 81, 24, 28, 0], [56, 22, 58, 3, 24, 6, 89, 25, 59, 39], [93, 76, 48, 81, 68, 39, 40, 17, 58, 89], [62, 82, 2, 87, 69, 65, 4, 35, 9, 59], [23, 73, 83, 92, 55, 14, 82, 38, 6, 70]], [[34, 22, 25, 30, 74, 12, 26, 79, 63, 78], [18, 70, 45, 20, 46, 77, 23, 35, 29, 22], [31, 83, 42, 49, 14, 74, 32, 41, 72, 55], [40, 87, 93, 12, 92, 32, 73, 16, 44, 83], [39, 20, 70, 74, 84, 73, 87, 88, 66, 54]], [[54, 11, 2, 80, 8, 93, 75, 35, 86, 50], [61, 72, 52, 17, 95, 26, 46, 64, 88, 40], [9, 2, 37, 15, 26, 59, 32, 24, 83, 51], [50, 48, 2, 79, 63, 42, 64, 81, 72, 58], [93, 48, 0, 24, 89, 74, 42, 88, 71, 35]], [[85, 19, 68, 55, 56, 14, 93, 49, 90, 9], [9, 42, 35, 93, 43, 89, 4, 92, 84, 33], [35, 29, 91, 2, 68, 51, 52, 39, 59, 85], [6, 56, 30, 50, 25, 38, 85, 84, 93, 47], [36, 42, 19, 21, 41, 72, 62, 86, 25, 18]], [[93, 9, 38, 61, 68, 22, 48, 77, 45, 76], [87, 33, 27, 36, 53, 77, 90, 93, 39, 9], [60, 21, 53, 49, 17, 41, 36, 81, 63, 7], [26, 2, 11, 62, 4, 3, 66, 82, 7, 0], [35, 78, 30, 79, 84, 67, 65, 19, 66, 41]], [[2, 70, 12, 56, 5, 1, 76, 53, 69, 73], [39, 59, 30, 45, 95, 60, 85, 62, 70, 67], [84, 86, 14, 66, 54, 24, 63, 52, 93, 49], [38, 54, 9, 46, 34, 84, 8, 35, 7, 55], [30, 72, 47, 6, 26, 48, 13, 67, 19, 52]], [[6, 1, 61, 70, 8, 79, 46, 62, 85, 95], [12, 33, 77, 55, 51, 0, 19, 42, 85, 69], [43, 56, 94, 74, 45, 95, 29, 6, 0, 4], [65, 27, 13, 11, 6, 53, 16, 5, 4, 47], [66, 93, 36, 68, 27, 86, 60, 70, 58, 57]], [[6, 3, 81, 24, 20, 59, 32, 73, 91, 22], [66, 73, 68, 55, 76, 41, 91, 88, 59, 74], [6, 14, 55, 3, 38, 19, 85, 34, 62, 10], [13, 86, 45, 89, 18, 22, 59, 11, 73, 72], [89, 65, 74, 8, 42, 85, 48, 29, 50, 13]], [[50, 91, 0, 7, 51, 19, 11, 78, 59, 13], [70, 50, 17, 33, 64, 92, 94, 84, 54, 81], [13, 32, 8, 38, 72, 51, 66, 85, 62, 74], [61, 43, 74, 86, 95, 80, 32, 59, 50, 77], [7, 34, 70, 20, 36, 30, 67, 9, 94, 43]], [[54, 61, 75, 15, 50, 24, 37, 84, 27, 52], [58, 7, 51, 93, 56, 38, 35, 82, 46, 50], [26, 89, 91, 57, 24, 7, 86, 12, 
3, 44], [75, 49, 33, 88, 50, 82, 8, 12, 81, 56], [51, 92, 53, 54, 52, 26, 27, 74, 22, 72]], [[14, 80, 78, 27, 3, 1, 6, 32, 81, 2], [27, 69, 59, 63, 93, 3, 5, 44, 19, 94], [92, 60, 56, 86, 64, 47, 48, 24, 79, 38], [76, 30, 64, 49, 12, 15, 14, 1, 28, 21], [27, 30, 35, 64, 56, 92, 53, 0, 60, 41]], [[47, 94, 69, 51, 5, 93, 22, 78, 73, 36], [34, 28, 59, 14, 67, 48, 74, 6, 46, 10], [84, 34, 23, 5, 16, 40, 28, 92, 20, 12], [45, 0, 57, 88, 76, 69, 17, 23, 77, 61], [82, 78, 30, 12, 49, 10, 79, 2, 50, 19]], [[63, 89, 62, 51, 88, 28, 34, 64, 16, 78], [69, 75, 83, 49, 45, 19, 23, 72, 3, 0], [25, 71, 45, 26, 72, 43, 27, 41, 81, 94], [80, 92, 22, 41, 62, 87, 5, 33, 12, 49], [61, 25, 40, 14, 94, 85, 46, 67, 87, 30]], [[13, 82, 90, 45, 14, 31, 46, 18, 92, 40], [41, 60, 79, 40, 67, 92, 64, 86, 72, 13], [57, 94, 27, 77, 5, 89, 6, 12, 52, 30], [13, 24, 54, 1, 94, 40, 64, 35, 6, 51], [46, 62, 69, 2, 85, 34, 0, 28, 11, 1]], [[58, 5, 69, 35, 30, 28, 4, 22, 21, 25], [92, 88, 79, 58, 31, 25, 15, 34, 9, 38], [48, 43, 78, 9, 15, 73, 91, 65, 52, 13], [92, 87, 34, 6, 75, 44, 83, 55, 47, 29], [35, 16, 4, 21, 46, 56, 14, 59, 43, 19]]]


In [1]:
# !ls /home/asufian/Desktop/output_olchiki/code/

In [26]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shell commands: delete the weight files exported by a previous run of this cell.
!rm 'model_own_path.pt'
!rm 'model_mu_path.pt'

# First model ("own"): ResNet-18 backbone with dropout probability 0, weights loaded from modelr_res18.pth.
# NOTE(review): the next line sets `fc` on the wrapper object; ResNet18WithDropout.forward
# calls self.resnet18(x), so the inner network's own `fc` stays in place — confirm this is intended.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
# Checkpoints are deserialised onto the CPU first; load_state_dict then copies them into the model.
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))

# Second model: same architecture built from a fresh backbone, weights from the 1-shot checkpoint.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/mapi_sce/ressecondary/model_1/model_B_mapi_1-shot_res.pth',map_location=torch.device('cpu')))

# Re-export both state dicts to the working directory:
# the 1-shot model goes to model_mu_path.pt, the first model to model_own_path.pt.
torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [27]:
# !ls '/home/asufian/Desktop/output_olchiki'

In [28]:
# Cross-entropy loss object.
# NOTE(review): the evaluate* functions accept a `criterion` argument but never call it in the visible cells.
criterion = nn.CrossEntropyLoss()

In [29]:
root_path

'/home/asufian/Desktop/output_olchiki/code/MAPI Mayee'

In [30]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
import numpy as np
import numpy as np
import pandas as pd
# root_path = '/content/assamese'
# Episode settings for evaluation (fixed constants, overriding the values set in the dataset cell).
num_classes = 5
n_shot = 1
n_query = 14
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Evaluation preprocessing: resize to 64x64, replicate the gray channel three times,
# convert to a tensor and normalise with the ImageNet mean/std.
transformt_ugh = transforms.Compose([
    transforms.Resize((64, 64)),
    # transforms.RandomResizedCrop(64),  # Randomly crop the image and resize to 224x224
    # transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
    # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),  # Randomly adjust brightness, contrast, saturation, and hue
    # transforms.RandomAffine(degrees=0, translate=(0, 0), scale=(0.9, 1.1), shear=0),  # Random zoom (scaling)
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),  # Convert PIL Image to PyTorch Tensor
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # Normalize the image with ImageNet statistics
])

# NOTE(review): the message says "random", but the three values above are constants.
print(f"Generated random values - num_classes: {num_classes}, n_shot: {n_shot}, n_query: {n_query}")

# Dataset used by the evaluate* functions. Despite the name it is the dataset object, not a DataLoader.
# NOTE(review): n_query=20 here differs from the n_query printed above; the dataset's
# __getitem__ ignores n_query and returns every non-support image as a query.
test_loader = PrototypicalOmniglotDataset(root=root_path, num_classes=num_classes, n_shot=n_shot, n_query=20, transform=transformt_ugh)

def evaluate2(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    """Evaluate ``model`` on the fixed 1-shot episodes and print the averaged metrics.

    For every row of ``clss_5`` an episode is built with the matching row of
    ``clss_support_imagesss_1_shot``, the model scores all query images, and
    accuracy plus a confusion matrix are computed. Per-episode results are
    handed to ``calculate_metrics_get_per`` (which stores them under
    ``fname``). At the end the mean accuracy, precision and recall over all
    episodes are printed. ``criterion`` is accepted but not used.
    """
    model.eval()
    episodes_run = 0
    accuracy_sum = 0
    precision_sum = 0
    recall_sum = 0

    with torch.no_grad():
        for episode_idx in range(len(clss_5)):
            raw_support, raw_query = data_loader.__getitem__(
                clss_5[episode_idx], clss_support_imagesss_1_shot[episode_idx]
            )

            # Text forms of the episode definition, used only for reporting.
            class_text = str(clss_5[episode_idx]).replace(',', ' ').replace('\'', ' ')
            support_text = str(clss_support_imagesss_1_shot[episode_idx])

            # One stacked tensor per class; support keeps a single label per class.
            support_pairs = [(torch.stack(imgs).to(device), labels[0]) for imgs, labels in raw_support]
            query_pairs = [(torch.stack(imgs).to(device), labels) for imgs, labels in raw_query]
            support_images, support_labels = zip(*support_pairs)
            query_images, query_labels = zip(*query_pairs)

            scores_per_class = model(support_images, support_labels, query_images)

            n_correct = 0
            n_seen = 0
            predicted_all = []
            actual_all = []
            for class_pos in range(len(scores_per_class)):
                predicted = torch.argmax(scores_per_class[class_pos], dim=1).tolist()
                actual = query_labels[class_pos]
                predicted_all.extend(predicted)
                actual_all.extend(actual)
                n_correct += sum(1 for k in range(len(predicted)) if predicted[k] == actual[k])
                n_seen += len(predicted)

            episode_accuracy = n_correct / n_seen
            episodes_run += 1
            accuracy_sum += episode_accuracy

            episode_confusion = confusion_matrix(actual_all, predicted_all)
            episode_precision, episode_recall, _episode_f1 = calculate_metrics_get_per(
                support_text,
                episode_confusion,
                acciuracy=n_correct / n_seen,
                cls=class_text,
                flnm=fname,
            )
            precision_sum = precision_sum + episode_precision
            recall_sum = recall_sum + episode_recall

    print("---------------------->",accuracy_sum/episodes_run,precision_sum/episodes_run ,recall_sum/episodes_run)
# %%%%%%%%%%%%%%%%%%%


# NOTE(review): this module-level `pp` is not read by any visible cell — presumably an
# output folder; confirm or remove. Hard-coded absolute path.
pp='/home/asufian/Desktop/output_olchikifaltu'






Generated random values - num_classes: 5, n_shot: 1, n_query: 14


In [31]:
import numpy as np

# Sample 5x5 confusion matrix with all counts on the diagonal (every class fully correct).
# Only the commented-out example call further down in this cell refers to it.
conf_matrix = np.array([[100, 0, 0, 0, 0],
                        [0, 100, 0, 0, 0],
                        [0, 0, 100, 0, 0],
                        [0, 0, 0, 100, 0],
                        [0, 0, 0, 0, 100]])

def calculate_metrics_get_per(msg, conf_matrix,acciuracy, cls, flnm):
    """Compute macro-averaged precision, recall and F1 from a confusion matrix and store them.

    Rows of ``conf_matrix`` are the actual classes and columns the predicted
    ones. Per-class precision, recall and F1 are computed (0 whenever the
    denominator is 0) and averaged with equal weight per class. Accuracy is
    recomputed from the matrix as trace / total. The results are forwarded to
    ``save_in_excel_final`` with ``flnm + ".xlsx"`` as target file.

    Args:
        msg: free text stored alongside the metrics.
        conf_matrix: square confusion matrix; any array-like (NumPy array or
            nested lists) is accepted.
        acciuracy: kept for call compatibility; the value is ignored because
            accuracy is derived from ``conf_matrix``.
        cls: description of the classes used, stored alongside the metrics.
        flnm: output path without extension.

    Returns:
        ``(precision_overall, recall_overall, f1_overall)``.

    Raises:
        ValueError: if ``conf_matrix`` is empty.
    """
    # Coerce first so nested lists work too (they have no `.size`).
    conf_matrix = np.asarray(conf_matrix)
    if conf_matrix.size == 0:
        raise ValueError("Confusion matrix is empty!")

    true_pos = np.diag(conf_matrix).astype(float)
    predicted_totals = conf_matrix.sum(axis=0).astype(float)  # tp + fp per class
    actual_totals = conf_matrix.sum(axis=1).astype(float)     # tp + fn per class

    # Element-wise division that leaves 0 wherever the denominator is 0.
    precision = np.divide(true_pos, predicted_totals,
                          out=np.zeros_like(true_pos), where=predicted_totals > 0)
    recall = np.divide(true_pos, actual_totals,
                       out=np.zeros_like(true_pos), where=actual_totals > 0)
    pr_sum = precision + recall
    f1 = np.divide(2 * (precision * recall), pr_sum,
                   out=np.zeros_like(true_pos), where=pr_sum > 0)

    # Macro averages: every class counts equally.
    precision_overall = np.mean(precision)
    recall_overall = np.mean(recall)
    f1_overall = np.mean(f1)

    total_samples = np.sum(conf_matrix)
    correct_predictions = np.sum(np.diag(conf_matrix))
    accuracy_overall = correct_predictions / total_samples if total_samples > 0 else 0

    save_in_excel_final(msg, cls, flnm + ".xlsx", conf_matrix, accuracy_overall, precision_overall, recall_overall)

    return precision_overall, recall_overall, f1_overall


# precision, recall, f_beta = calculate_metrics(conf_matrix)
# print(f'Overall Metrics for beta={beta}:')
# print(f'Overall Precision: {precision:.4f}')
# print(f'Overall Recall: {recall:.4f}')
# print(f'Overall F-beta ({f_beta}): ')
# print()


In [32]:


def evaluate2(fname,data_loader, model, criterion=nn.CrossEntropyLoss(), episode_classes=None, episode_supports=None):
    """Evaluate ``model`` on a fixed list of episodes and print mean ± std of the metrics.

    For every episode the support/query sets are built through
    ``data_loader.__getitem__``, the model scores all query images, and
    accuracy plus a confusion matrix are computed. Per-episode results are
    handed to ``calculate_metrics_get_per`` (which stores them under
    ``fname``). Finally the mean and sample standard deviation (``ddof=1``)
    of accuracy, precision and recall over the episodes are printed.

    Args:
        fname: output path (without extension) forwarded to
            ``calculate_metrics_get_per``.
        data_loader: dataset whose ``__getitem__(classes, supports)`` returns
            ``(support_set, query_set)``.
        model: network called as ``model(support_images, support_labels,
            query_images)`` and returning one score tensor per class.
        criterion: unused; kept so existing calls keep working.
        episode_classes: list of episodes, each a list of class names.
            Defaults to the notebook-level ``clss_5``.
        episode_supports: per episode, the support image indices of each
            class. Defaults to the notebook-level
            ``clss_support_imagesss_1_shot``; pass the 5-shot or 10-shot
            table to evaluate those settings without editing this function.

    Raises:
        ValueError: if ``episode_supports`` has fewer rows than
            ``episode_classes``.
    """
    if episode_classes is None:
        episode_classes = clss_5
    if episode_supports is None:
        episode_supports = clss_support_imagesss_1_shot
    if len(episode_supports) < len(episode_classes):
        raise ValueError(
            f"{len(episode_classes)} episodes requested but only "
            f"{len(episode_supports)} support selections were given"
        )

    model.eval()
    accuracies = []
    precisions = []
    recalls = []

    with torch.no_grad():
        for episode_idx in range(len(episode_classes)):
            class_names = episode_classes[episode_idx]
            support_indices = episode_supports[episode_idx]
            raw_support, raw_query = data_loader.__getitem__(class_names, support_indices)

            # Text forms of the episode definition, used only for reporting.
            class_text = str(class_names).replace(',', ' ').replace('\'', ' ')
            support_text = str(support_indices)

            # One stacked tensor per class; support keeps a single label per class.
            support_pairs = [(torch.stack(imgs).to(device), labels[0]) for imgs, labels in raw_support]
            query_pairs = [(torch.stack(imgs).to(device), labels) for imgs, labels in raw_query]
            support_images, support_labels = zip(*support_pairs)
            query_images, query_labels = zip(*query_pairs)

            scores_per_class = model(support_images, support_labels, query_images)

            predicted_all = []
            actual_all = []
            for class_pos in range(len(scores_per_class)):
                predicted_all.extend(torch.argmax(scores_per_class[class_pos], dim=1).tolist())
                actual_all.extend(query_labels[class_pos])

            n_seen = len(predicted_all)
            n_correct = sum(1 for pred, act in zip(predicted_all, actual_all) if pred == act)
            # An episode without any query image counts as accuracy 0 instead
            # of aborting the whole run with a ZeroDivisionError.
            episode_accuracy = n_correct / n_seen if n_seen else 0.0
            accuracies.append(episode_accuracy)

            # Fixing the label set keeps the matrix at (n_way x n_way) even
            # when a class has no queries and is never predicted.
            episode_confusion = confusion_matrix(
                actual_all, predicted_all, labels=list(range(len(support_labels)))
            )

            episode_precision, episode_recall, _episode_f1 = calculate_metrics_get_per(
                support_text,
                episode_confusion,
                acciuracy=episode_accuracy,
                cls=class_text,
                flnm=fname,
            )
            precisions.append(episode_precision)
            recalls.append(episode_recall)

    print(f"---Accu-----pre----rec---------> {np.mean(accuracies):.4f} ± {np.std(accuracies, ddof=1):.4f}  "
      f"{np.mean(precisions):.4f} ± {np.std(precisions, ddof=1):.4f}  "
      f"{np.mean(recalls):.4f} ± {np.std(recalls, ddof=1):.4f}")

    # print("---Accu-----pre----rec---------->",np.mean(ttlal),'±',np.std(ttlal, ddof=1), np.mean(precision),'±',np.std(precision, ddof=1),np.mean(precision),'±',np.std(precision, ddof=1),)
# %%%%%%

def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              episode_classes=None, episode_supports=None):
    """Return the mean per-episode query accuracy of ``model``.

    One episode is evaluated for every entry of ``episode_classes``; the
    support selection for episode ``i`` is ``episode_supports[i]``.

    Args:
        fname: Unused; kept so the signature matches ``evaluate2``.
        data_loader: Object whose ``__getitem__(classes, supports)`` returns
            ``(support_set, query_set)``. Both are iterables of
            ``(images, label)`` pairs where ``images`` is a sequence of
            tensors that ``torch.stack`` can combine.
        model: Called as ``model(support_images, support_labels, query_images)``
            and expected to return one 2-D score tensor per query group.
        criterion: Unused; kept for backward compatibility.
        episode_classes: Optional sequence of per-episode class selections.
            Defaults to the notebook-level ``clss_5``.
        episode_supports: Optional sequence of per-episode support selections.
            Defaults to the notebook-level ``clss_support_imagesss_1_shot``.

    Returns:
        The arithmetic mean of the per-episode accuracies.

    Raises:
        ZeroDivisionError: If there are no episodes, or an episode has no
            query predictions.
    """
    # Fall back to the notebook globals so existing four-argument calls
    # behave exactly as before.
    if episode_classes is None:
        episode_classes = clss_5
    if episode_supports is None:
        episode_supports = clss_support_imagesss_1_shot

    model.eval()
    episode_accuracies = []
    with torch.no_grad():
        for episode_idx, classes in enumerate(episode_classes):
            supports = episode_supports[episode_idx]
            # Explicit dunder call: the loader takes two positional arguments,
            # which the subscript syntax cannot express.
            support_set, query_set = data_loader.__getitem__(classes, supports)

            # One stacked tensor per group; the first label of a support group
            # stands for the whole group.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            for group_idx, group_scores in enumerate(classification_scores):
                predicted = torch.argmax(group_scores, dim=1).tolist()
                actual = query_labels[group_idx]
                correct += sum(1 for i, label in enumerate(predicted) if label == actual[i])
                total += len(predicted)
            episode_accuracies.append(correct / total)

    mean_accuracy = sum(episode_accuracies) / len(episode_accuracies)
    print("---------------------->", mean_accuracy)
    return mean_accuracy

In [33]:
def update_model_weights2(model_a, model_b, conv_ratio=0.8, fc_ratio=0.8, bias_ratio=0.8):
    """Blend ``model_b``'s parameters into ``model_a`` in place.

    For every parameter of ``model_a`` whose name also appears in
    ``model_b.state_dict()`` the new value is
    ``ratio * value_b + (1 - ratio) * value_a``, where ``ratio`` is chosen
    from the parameter name:

    * name contains ``'weight'`` and ``'conv'``            -> ``conv_ratio``
    * name contains ``'weight'`` and ``'fc'`` or ``'linear'`` -> ``fc_ratio``
    * name contains ``'bias'`` (and not ``'weight'``)       -> ``bias_ratio``

    Any other parameter (for example a ``'weight'`` whose name matches none of
    the layer keywords) is left untouched. ``model_b`` is never modified.

    Args:
        model_a: Module updated in place.
        model_b: Module supplying the reference values.
        conv_ratio: Share of ``model_b`` in convolution weights.
        fc_ratio: Share of ``model_b`` in fully connected weights.
        bias_ratio: Share of ``model_b`` in biases.
    """
    # Build the reference mapping once instead of on every loop iteration.
    reference = model_b.state_dict()

    with torch.no_grad():
        for name, param in model_a.named_parameters():
            if name not in reference:
                continue

            if 'weight' in name:
                if 'conv' in name:
                    ratio = conv_ratio
                elif 'fc' in name or 'linear' in name:
                    ratio = fc_ratio
                else:
                    # Weight of a layer type that is not blended.
                    continue
            elif 'bias' in name:
                ratio = bias_ratio
            else:
                continue

            # Align devices so the two models may live on different devices.
            weight_b = reference[name].to(param.device)
            # The right-hand side is fully evaluated before the in-place copy.
            param.copy_(ratio * weight_b + (1 - ratio) * param)

In [34]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Function to get model accuracy
def get_model_accuracy(b_a_r, c_a_r, f_a_r):
    """Score one combination of blending ratios.

    Two networks are rebuilt from disk on every call: ``M3`` from
    ``'model_own_path.pt'`` and the reference ``M2`` from
    ``'model_mu_path.pt'``. ``M2`` is blended into ``M3`` with
    ``update_model_weights2`` and the result is scored by ``evaluate3``.

    Args:
        b_a_r: Ratio forwarded as ``bias_ratio``.
        c_a_r: Ratio forwarded as ``conv_ratio``.
        f_a_r: Ratio forwarded as ``fc_ratio``.

    Returns:
        The accuracy reported by ``evaluate3``. Any exception raised along
        the way is printed and converted into ``0.0`` so that a failing trial
        does not abort the surrounding optimisation.
    """
    def _load_protonet(checkpoint_path):
        # Fresh backbone with its `fc` attribute swapped for a flatten layer.
        backbone = ResNet18WithDropout(pretrained=False, dr=0.0)
        backbone.fc = nn.Flatten()
        network = PrototypicalNetworks_dynamic_query(backbone).to(device)
        # map_location keeps loading independent of the device the checkpoint
        # was saved from; without it a device mismatch raises and every trial
        # would silently score 0.0 through the handler below.
        network.load_state_dict(torch.load(checkpoint_path, map_location=device))
        return network

    try:
        print(f"Testing b_a_r = {b_a_r:.4f}, c_a_r = {c_a_r:.4f}, f_a_r = {f_a_r:.4f}")  # Debugging Output
        M3 = _load_protonet('model_own_path.pt')

        # Load Model M2 (Reference)
        M2 = _load_protonet('model_mu_path.pt')

        # Blend M2 into M3 with the ratios under test
        update_model_weights2(M3, M2, conv_ratio=c_a_r, fc_ratio=f_a_r, bias_ratio=b_a_r)

        # Evaluate accuracy
        accuracy = evaluate3(pp+'/_mapi_5-wau-1-shot_matchine_unlearning_own', test_loader, M3, criterion)

        if accuracy is None:
            raise ValueError("Accuracy computation failed. Returning worst case.")

        return accuracy  # Ensure accuracy is valid
    except Exception as e:
        # Deliberate best-effort: report the failure and score the trial as worst case.
        print(f"Error encountered: {e}. Assigning worst-case accuracy of 0.")
        return 0.0  # Assign lowest accuracy to prevent optimization failures


In [35]:
pp

'/home/asufian/Desktop/output_olchikifaltu'

In [36]:
# Loss object passed as the `criterion` argument of the evaluate2/evaluate3 calls in this notebook.
criterion = nn.CrossEntropyLoss()


In [37]:

# Search space: one real-valued dimension per blending ratio.
search_space = [
    Real(0, 1.0000, name=ratio_name)
    for ratio_name in ("b_a_r", "c_a_r", "f_a_r")
]


# Objective handed to the optimiser; use_named_args unpacks each sampled
# point into the keyword arguments named in search_space.
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """Return the negated accuracy obtained with the given blending ratios.

    gp_minimize minimises its objective, so the accuracy returned by
    get_model_accuracy is negated before being handed back.
    """
    score = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {score:.4f}")  # Debugging Output
    return -score


# Run the Gaussian-process based search.
search_iteration = 50
res = gp_minimize(
    objective,
    search_space,
    n_calls=search_iteration,  # total number of objective evaluations
    n_initial_points=5,        # random points sampled before the surrogate is used
    acq_func="EI",             # Expected Improvement
    random_state=42,           # fixed seed for the optimiser
    n_jobs=-1,                 # cores for the acquisition optimiser; objective calls stay sequential
)

# res.x lists the best point in search_space order; res.fun is the negated accuracy.
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print("\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.5785263157894739
Evaluated Accuracy: 0.5785
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.37038596491228076
Evaluated Accuracy: 0.3704
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.4498245614035088
Evaluated Accuracy: 0.4498
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.6471578947368422
Evaluated Accuracy: 0.6472
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.608982456140351
Evaluated Accuracy: 0.6090
Testing b_a_r = 0.0000, c_a_r = 0.0202, f_a_r = 0.0242




----------------------> 0.6512280701754386
Evaluated Accuracy: 0.6512
Testing b_a_r = 0.0000, c_a_r = 1.0000, f_a_r = 0.9437




----------------------> 0.22905263157894737
Evaluated Accuracy: 0.2291
Testing b_a_r = 0.0000, c_a_r = 0.0578, f_a_r = 0.0000




----------------------> 0.6454736842105262
Evaluated Accuracy: 0.6455
Testing b_a_r = 0.0000, c_a_r = 0.0688, f_a_r = 1.0000




----------------------> 0.5990175438596491
Evaluated Accuracy: 0.5990
Testing b_a_r = 1.0000, c_a_r = 0.0127, f_a_r = 0.4027




----------------------> 0.6623157894736842
Evaluated Accuracy: 0.6623
Testing b_a_r = 0.0000, c_a_r = 0.0646, f_a_r = 0.3779




----------------------> 0.6461754385964912
Evaluated Accuracy: 0.6462
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.5463




----------------------> 0.6621754385964914
Evaluated Accuracy: 0.6622
Testing b_a_r = 0.7691, c_a_r = 1.0000, f_a_r = 0.0000




----------------------> 0.2248421052631579
Evaluated Accuracy: 0.2248
Testing b_a_r = 0.9113, c_a_r = 0.0000, f_a_r = 0.2631




----------------------> 0.6548771929824562
Evaluated Accuracy: 0.6549
Testing b_a_r = 0.8864, c_a_r = 0.0000, f_a_r = 0.6079




----------------------> 0.6616140350877192
Evaluated Accuracy: 0.6616
Testing b_a_r = 0.4947, c_a_r = 0.5559, f_a_r = 0.9995




----------------------> 0.3941052631578947
Evaluated Accuracy: 0.3941
Testing b_a_r = 0.9859, c_a_r = 0.0357, f_a_r = 0.5337




----------------------> 0.6614736842105263
Evaluated Accuracy: 0.6615
Testing b_a_r = 0.8492, c_a_r = 0.0032, f_a_r = 0.0024




----------------------> 0.6541754385964913
Evaluated Accuracy: 0.6542
Testing b_a_r = 0.5319, c_a_r = 0.3137, f_a_r = 0.9968




----------------------> 0.42231578947368414
Evaluated Accuracy: 0.4223
Testing b_a_r = 0.6656, c_a_r = 0.7255, f_a_r = 0.0031




----------------------> 0.3367017543859649
Evaluated Accuracy: 0.3367
Testing b_a_r = 0.0246, c_a_r = 0.1737, f_a_r = 0.0015




----------------------> 0.576
Evaluated Accuracy: 0.5760
Testing b_a_r = 0.0687, c_a_r = 0.7619, f_a_r = 0.9980




----------------------> 0.28028070175438596
Evaluated Accuracy: 0.2803
Testing b_a_r = 0.9682, c_a_r = 0.1374, f_a_r = 0.4503




----------------------> 0.6204912280701755
Evaluated Accuracy: 0.6205
Testing b_a_r = 0.9711, c_a_r = 0.0621, f_a_r = 0.1779




----------------------> 0.6457543859649122
Evaluated Accuracy: 0.6458
Testing b_a_r = 0.9981, c_a_r = 0.6193, f_a_r = 0.4876




----------------------> 0.39424561403508773
Evaluated Accuracy: 0.3942
Testing b_a_r = 0.0748, c_a_r = 0.0160, f_a_r = 0.7132




----------------------> 0.6522105263157896
Evaluated Accuracy: 0.6522
Testing b_a_r = 0.9603, c_a_r = 0.2450, f_a_r = 0.4701




----------------------> 0.5389473684210526
Evaluated Accuracy: 0.5389
Testing b_a_r = 0.8759, c_a_r = 0.8543, f_a_r = 0.4162




----------------------> 0.2513684210526316
Evaluated Accuracy: 0.2514
Testing b_a_r = 0.9907, c_a_r = 0.0003, f_a_r = 0.4554




----------------------> 0.6620350877192982
Evaluated Accuracy: 0.6620
Testing b_a_r = 0.1526, c_a_r = 0.0179, f_a_r = 0.5155




----------------------> 0.6620350877192983
Evaluated Accuracy: 0.6620
Testing b_a_r = 0.0734, c_a_r = 0.4332, f_a_r = 0.6359




----------------------> 0.35663157894736847
Evaluated Accuracy: 0.3566
Testing b_a_r = 0.9287, c_a_r = 0.1048, f_a_r = 0.7001




----------------------> 0.6364912280701754
Evaluated Accuracy: 0.6365
Testing b_a_r = 0.9951, c_a_r = 0.0107, f_a_r = 0.5874




----------------------> 0.6628771929824562
Evaluated Accuracy: 0.6629
Testing b_a_r = 0.9112, c_a_r = 0.0344, f_a_r = 0.0013




----------------------> 0.6515087719298247
Evaluated Accuracy: 0.6515
Testing b_a_r = 0.7127, c_a_r = 0.5840, f_a_r = 0.0001




----------------------> 0.4350877192982455
Evaluated Accuracy: 0.4351
Testing b_a_r = 0.8723, c_a_r = 0.1005, f_a_r = 0.0173




----------------------> 0.6301754385964913
Evaluated Accuracy: 0.6302
Testing b_a_r = 0.9989, c_a_r = 0.1528, f_a_r = 0.9907




----------------------> 0.5792280701754386
Evaluated Accuracy: 0.5792
Testing b_a_r = 0.4967, c_a_r = 0.2606, f_a_r = 0.0027




----------------------> 0.512701754385965
Evaluated Accuracy: 0.5127
Testing b_a_r = 0.8886, c_a_r = 0.0293, f_a_r = 0.2931




----------------------> 0.6548771929824563
Evaluated Accuracy: 0.6549
Testing b_a_r = 0.8654, c_a_r = 0.0769, f_a_r = 0.5465




----------------------> 0.6548771929824561
Evaluated Accuracy: 0.6549
Testing b_a_r = 0.0854, c_a_r = 0.0020, f_a_r = 0.4113




----------------------> 0.6607719298245613
Evaluated Accuracy: 0.6608
Testing b_a_r = 0.1607, c_a_r = 0.0458, f_a_r = 0.5777




----------------------> 0.6583859649122805
Evaluated Accuracy: 0.6584
Testing b_a_r = 0.0517, c_a_r = 0.0004, f_a_r = 0.1053




----------------------> 0.6524912280701756
Evaluated Accuracy: 0.6525
Testing b_a_r = 0.9600, c_a_r = 0.0270, f_a_r = 0.4710




----------------------> 0.6607719298245615
Evaluated Accuracy: 0.6608
Testing b_a_r = 0.6982, c_a_r = 0.0004, f_a_r = 0.7781




----------------------> 0.6433684210526316
Evaluated Accuracy: 0.6434
Testing b_a_r = 0.9956, c_a_r = 0.0008, f_a_r = 0.5586




----------------------> 0.6644210526315789
Evaluated Accuracy: 0.6644
Testing b_a_r = 0.4204, c_a_r = 0.8587, f_a_r = 0.0003




----------------------> 0.2488421052631579
Evaluated Accuracy: 0.2488
Testing b_a_r = 0.0965, c_a_r = 0.1158, f_a_r = 0.2027




----------------------> 0.6202105263157895
Evaluated Accuracy: 0.6202
Testing b_a_r = 0.4385, c_a_r = 0.8801, f_a_r = 0.9956




----------------------> 0.24505263157894744
Evaluated Accuracy: 0.2451
Testing b_a_r = 0.9378, c_a_r = 0.0287, f_a_r = 0.6213




----------------------> 0.663438596491228
Evaluated Accuracy: 0.6634

✅ Optimal Values:
   - b_a_r: 0.9956
   - c_a_r: 0.0008
   - f_a_r: 0.5586
📈 Highest Accuracy Achieved: 0.6644


In [39]:
# Step 1 - baseline: rebuild the network, load 'model_own_path.pt' and evaluate it unchanged.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()  # swap the `fc` attribute for a flatten layer

xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_mapii/result_ressnet/base/5_w_1_s_f_20', test_loader, M3, criterion)
# Step 2 - reference: build a second, independent network from 'model_mu_path.pt' and evaluate it.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/secondary/5_w_1_s_f", test_loader, M2, criterion)
# Step 3 - blend M2 into M3 in place, using the ratios found by the gp_minimize cell
# (optimal_c_a_r / optimal_f_a_r / optimal_b_a_r must already be defined).
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Step 4 - evaluate the blended M3; must run after the update above.
# Requires evaluate2 to have been defined by an earlier cell.
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/amul/5_w_1_s_f", test_loader, M3, criterion)


# Assuming you have the evaluate2 function defined
# Evaluate the updated model (M3) using evaluate2 function with some arguments
# evaluate3(pp+'/_mapi_5-wau-1-shot_matchine_unlearning_own', test_loader, M3, criterion)



---Accu-----pre----rec---------> 0.6522 ± 0.0673  0.6838 ± 0.0587  0.6522 ± 0.0673




---Accu-----pre----rec---------> 0.6042 ± 0.0866  0.6366 ± 0.0943  0.6042 ± 0.0866
---Accu-----pre----rec---------> 0.6644 ± 0.0714  0.7026 ± 0.0643  0.6644 ± 0.0714


In [26]:
!mkdir /home/asufian/Desktop/output_mapii/result_ressnet/secondary

In [None]:
########################  5   way    5  shot    #########################

In [40]:
# Re-stage the two working checkpoints ('model_own_path.pt', 'model_mu_path.pt')
# that get_model_accuracy and the evaluation cells load from the current directory.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



# Remove the checkpoints written by the previous experiment before re-creating them below.
!rm model_mu_path.pt
!rm model_own_path.pt

convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

# Loaded onto CPU storage first (map_location); the module itself was already moved to `device`.
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

# NOTE: despite the `_1_shot` suffix, this variable is loaded from the 5-shot checkpoint file.
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/mapi_sce/ressecondary/model_5/model_B_mapi_5-shot_res.pth',map_location=torch.device('cpu')))

# 'model_mu_path.pt' <- model_own_1_shot, 'model_own_path.pt' <- model_own.
torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [41]:


def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              episode_classes=None, episode_supports=None):
    """Evaluate ``model`` episode by episode and print summary statistics.

    For every entry of ``episode_classes`` one episode is fetched, scored and
    passed to ``calculate_metrics_get_per`` together with its confusion
    matrix. After the last episode a single line with mean ± sample standard
    deviation (``ddof=1``) of accuracy, precision and recall is printed.

    Args:
        fname: Forwarded to ``calculate_metrics_get_per`` as ``flnm``.
        data_loader: Object whose ``__getitem__(classes, supports)`` returns
            ``(support_set, query_set)``. Both are iterables of
            ``(images, label)`` pairs where ``images`` is a sequence of
            tensors that ``torch.stack`` can combine.
        model: Called as ``model(support_images, support_labels, query_images)``
            and expected to return one 2-D score tensor per query group.
        criterion: Unused; kept for backward compatibility.
        episode_classes: Optional sequence of per-episode class selections.
            Defaults to the notebook-level ``clss_5``.
        episode_supports: Optional sequence of per-episode support selections.
            Defaults to the notebook-level ``clss_support_imagesss_5_shot``.

    Returns:
        None. Results are printed and handed to ``calculate_metrics_get_per``.
    """
    # Fall back to the notebook globals so existing four-argument calls
    # behave exactly as before.
    if episode_classes is None:
        episode_classes = clss_5
    if episode_supports is None:
        episode_supports = clss_support_imagesss_5_shot

    model.eval()
    accuracies = []
    precision = []
    recall = []

    with torch.no_grad():
        for episode_idx, classes in enumerate(episode_classes):
            supports = episode_supports[episode_idx]
            # Explicit dunder call: the loader takes two positional arguments,
            # which the subscript syntax cannot express.
            support_set, query_set = data_loader.__getitem__(classes, supports)

            # Text labels passed on to the metrics helper.
            class_names = str(classes).replace(',', ' ').replace('\'', ' ')
            support_description = str(supports)

            # One stacked tensor per group; the first label of a support group
            # stands for the whole group.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            all_predicted_labels = []
            all_actual_labels = []
            for group_idx, group_scores in enumerate(classification_scores):
                predicted = torch.argmax(group_scores, dim=1).tolist()
                actual = query_labels[group_idx]
                all_predicted_labels.extend(predicted)
                all_actual_labels.extend(actual)
                correct += sum(1 for i, label in enumerate(predicted) if label == actual[i])

            accuracy = correct / len(all_predicted_labels)
            accuracies.append(accuracy)

            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            precision_overall, recall_overall, _ = calculate_metrics_get_per(
                support_description, confusion_mat, acciuracy=accuracy, cls=class_names, flnm=fname
            )
            precision.append(precision_overall)
            recall.append(recall_overall)

    print(f"---Accu-----pre----rec---------> {np.mean(accuracies):.4f} ± {np.std(accuracies, ddof=1):.4f}  "
      f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
      f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")
# %%%%%%

def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              episode_classes=None, episode_supports=None):
    """Return the mean per-episode query accuracy of ``model``.

    One episode is evaluated for every entry of ``episode_classes``; the
    support selection for episode ``i`` is ``episode_supports[i]``.

    Args:
        fname: Unused; kept so the signature matches ``evaluate2``.
        data_loader: Object whose ``__getitem__(classes, supports)`` returns
            ``(support_set, query_set)``. Both are iterables of
            ``(images, label)`` pairs where ``images`` is a sequence of
            tensors that ``torch.stack`` can combine.
        model: Called as ``model(support_images, support_labels, query_images)``
            and expected to return one 2-D score tensor per query group.
        criterion: Unused; kept for backward compatibility.
        episode_classes: Optional sequence of per-episode class selections.
            Defaults to the notebook-level ``clss_5``.
        episode_supports: Optional sequence of per-episode support selections.
            Defaults to the notebook-level ``clss_support_imagesss_5_shot``.

    Returns:
        The arithmetic mean of the per-episode accuracies.

    Raises:
        ZeroDivisionError: If there are no episodes, or an episode has no
            query predictions.
    """
    # Fall back to the notebook globals so existing four-argument calls
    # behave exactly as before.
    if episode_classes is None:
        episode_classes = clss_5
    if episode_supports is None:
        episode_supports = clss_support_imagesss_5_shot

    model.eval()
    episode_accuracies = []
    with torch.no_grad():
        for episode_idx, classes in enumerate(episode_classes):
            supports = episode_supports[episode_idx]
            # Explicit dunder call: the loader takes two positional arguments,
            # which the subscript syntax cannot express.
            support_set, query_set = data_loader.__getitem__(classes, supports)

            # One stacked tensor per group; the first label of a support group
            # stands for the whole group.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            for group_idx, group_scores in enumerate(classification_scores):
                predicted = torch.argmax(group_scores, dim=1).tolist()
                actual = query_labels[group_idx]
                correct += sum(1 for i, label in enumerate(predicted) if label == actual[i])
                total += len(predicted)
            episode_accuracies.append(correct / total)

    mean_accuracy = sum(episode_accuracies) / len(episode_accuracies)
    print("---------------------->", mean_accuracy)
    return mean_accuracy

In [42]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Pick the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Search space: one real-valued dimension per blending ratio.
search_space = [
    Real(0.0, 1.0, name=ratio_name)
    for ratio_name in ("b_a_r", "c_a_r", "f_a_r")
]


# Objective handed to the optimiser; use_named_args unpacks each sampled
# point into the keyword arguments named in search_space.
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """Return the negated accuracy obtained with the given blending ratios.

    gp_minimize minimises its objective, so the accuracy returned by
    get_model_accuracy is negated before being handed back.
    """
    score = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {score:.4f}")  # Debugging Output
    return -score


# Run the Gaussian-process based search.
search_iteration = 50
res = gp_minimize(
    objective,
    search_space,
    n_calls=search_iteration,  # total number of objective evaluations
    n_initial_points=5,        # random points sampled before the surrogate is used
    acq_func="EI",             # Expected Improvement
    random_state=42,           # fixed seed for the optimiser
    n_jobs=-1,                 # cores for the acquisition optimiser; objective calls stay sequential
)

# res.x lists the best point in search_space order; res.fun is the negated accuracy.
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print("\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.6562637362637361
Evaluated Accuracy: 0.6563
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.5157509157509158
Evaluated Accuracy: 0.5158
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.5284981684981686
Evaluated Accuracy: 0.5285
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.7717216117216118
Evaluated Accuracy: 0.7717
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.7601465201465202
Evaluated Accuracy: 0.7601
Testing b_a_r = 0.0250, c_a_r = 1.0000, f_a_r = 0.4630




----------------------> 0.44336996336996337
Evaluated Accuracy: 0.4434
Testing b_a_r = 0.0805, c_a_r = 0.7144, f_a_r = 0.7016




----------------------> 0.47311355311355313
Evaluated Accuracy: 0.4731
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.7699633699633701
Evaluated Accuracy: 0.7700
Testing b_a_r = 0.0000, c_a_r = 0.0441, f_a_r = 0.0000




----------------------> 0.7598534798534798
Evaluated Accuracy: 0.7599
Testing b_a_r = 0.9995, c_a_r = 0.0000, f_a_r = 0.0823




----------------------> 0.7695238095238095
Evaluated Accuracy: 0.7695
Testing b_a_r = 0.1506, c_a_r = 0.0000, f_a_r = 0.9742




----------------------> 0.7550183150183152
Evaluated Accuracy: 0.7550
Testing b_a_r = 0.9253, c_a_r = 0.0472, f_a_r = 0.1644




----------------------> 0.7626373626373627
Evaluated Accuracy: 0.7626
Testing b_a_r = 1.0000, c_a_r = 0.0606, f_a_r = 1.0000




----------------------> 0.7418315018315019
Evaluated Accuracy: 0.7418
Testing b_a_r = 0.9675, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.7689377289377289
Evaluated Accuracy: 0.7689
Testing b_a_r = 0.5280, c_a_r = 0.0057, f_a_r = 0.0000




----------------------> 0.7696703296703298
Evaluated Accuracy: 0.7697
Testing b_a_r = 0.5546, c_a_r = 0.0008, f_a_r = 0.0389




----------------------> 0.7693772893772894
Evaluated Accuracy: 0.7694
Testing b_a_r = 1.0000, c_a_r = 0.0074, f_a_r = 0.6003




----------------------> 0.7868131868131869
Evaluated Accuracy: 0.7868
Testing b_a_r = 0.3300, c_a_r = 0.4728, f_a_r = 0.9978




----------------------> 0.5343589743589743
Evaluated Accuracy: 0.5344
Testing b_a_r = 0.0735, c_a_r = 0.0309, f_a_r = 0.4697




----------------------> 0.7733333333333334
Evaluated Accuracy: 0.7733
Testing b_a_r = 0.9041, c_a_r = 0.8139, f_a_r = 0.0020




----------------------> 0.3941391941391941
Evaluated Accuracy: 0.3941
Testing b_a_r = 0.6375, c_a_r = 0.9990, f_a_r = 0.9989




----------------------> 0.458021978021978
Evaluated Accuracy: 0.4580
Testing b_a_r = 0.7465, c_a_r = 0.0005, f_a_r = 0.7011




----------------------> 0.7915018315018315
Evaluated Accuracy: 0.7915
Testing b_a_r = 0.9609, c_a_r = 0.0080, f_a_r = 0.7996




----------------------> 0.7790476190476191
Evaluated Accuracy: 0.7790
Testing b_a_r = 0.0969, c_a_r = 0.0005, f_a_r = 0.5807




----------------------> 0.7844688644688645
Evaluated Accuracy: 0.7845
Testing b_a_r = 0.2821, c_a_r = 0.1492, f_a_r = 0.0063




----------------------> 0.6965567765567763
Evaluated Accuracy: 0.6966
Testing b_a_r = 0.9713, c_a_r = 0.0021, f_a_r = 0.4255




----------------------> 0.7771428571428571
Evaluated Accuracy: 0.7771
Testing b_a_r = 0.9981, c_a_r = 0.0980, f_a_r = 0.4525




----------------------> 0.7595604395604397
Evaluated Accuracy: 0.7596
Testing b_a_r = 0.1231, c_a_r = 0.6152, f_a_r = 0.0046




----------------------> 0.43443223443223444
Evaluated Accuracy: 0.4344
Testing b_a_r = 0.8269, c_a_r = 0.3131, f_a_r = 0.9945




----------------------> 0.5399267399267399
Evaluated Accuracy: 0.5399
Testing b_a_r = 0.7934, c_a_r = 0.8305, f_a_r = 0.9986




----------------------> 0.4704761904761906
Evaluated Accuracy: 0.4705
Testing b_a_r = 0.9937, c_a_r = 0.0357, f_a_r = 0.6109




----------------------> 0.77992673992674
Evaluated Accuracy: 0.7799
Testing b_a_r = 0.9908, c_a_r = 0.0003, f_a_r = 0.7087




----------------------> 0.7910622710622711
Evaluated Accuracy: 0.7911
Testing b_a_r = 0.9413, c_a_r = 0.6259, f_a_r = 0.9953




----------------------> 0.48586080586080593
Evaluated Accuracy: 0.4859
Testing b_a_r = 0.1012, c_a_r = 0.9968, f_a_r = 0.0043




----------------------> 0.37069597069597066
Evaluated Accuracy: 0.3707
Testing b_a_r = 0.0337, c_a_r = 0.0016, f_a_r = 0.6737




----------------------> 0.7912087912087913
Evaluated Accuracy: 0.7912
Testing b_a_r = 0.1563, c_a_r = 0.0001, f_a_r = 0.7175




----------------------> 0.7910622710622712
Evaluated Accuracy: 0.7911
Testing b_a_r = 0.1759, c_a_r = 0.0012, f_a_r = 0.2677




----------------------> 0.772893772893773
Evaluated Accuracy: 0.7729
Testing b_a_r = 0.0138, c_a_r = 0.1769, f_a_r = 0.3638




----------------------> 0.669157509157509
Evaluated Accuracy: 0.6692
Testing b_a_r = 0.0206, c_a_r = 0.0014, f_a_r = 0.7663




----------------------> 0.7831501831501833
Evaluated Accuracy: 0.7832
Testing b_a_r = 0.8903, c_a_r = 0.0012, f_a_r = 0.6566




----------------------> 0.7904761904761906
Evaluated Accuracy: 0.7905
Testing b_a_r = 0.9401, c_a_r = 0.5213, f_a_r = 0.5333




----------------------> 0.5006593406593407
Evaluated Accuracy: 0.5007
Testing b_a_r = 0.5842, c_a_r = 0.1392, f_a_r = 0.9979




----------------------> 0.681172161172161
Evaluated Accuracy: 0.6812
Testing b_a_r = 0.1036, c_a_r = 0.0003, f_a_r = 0.5777




----------------------> 0.7844688644688645
Evaluated Accuracy: 0.7845
Testing b_a_r = 0.0647, c_a_r = 0.1029, f_a_r = 0.5959




----------------------> 0.753846153846154
Evaluated Accuracy: 0.7538
Testing b_a_r = 0.0988, c_a_r = 0.0784, f_a_r = 0.3103




----------------------> 0.7558974358974359
Evaluated Accuracy: 0.7559
Testing b_a_r = 0.0352, c_a_r = 0.3601, f_a_r = 0.5816




----------------------> 0.6105494505494504
Evaluated Accuracy: 0.6105
Testing b_a_r = 0.9292, c_a_r = 0.8533, f_a_r = 0.5188




----------------------> 0.4603663003663004
Evaluated Accuracy: 0.4604
Testing b_a_r = 0.9798, c_a_r = 0.0304, f_a_r = 0.3037




----------------------> 0.7702564102564103
Evaluated Accuracy: 0.7703
Testing b_a_r = 0.0607, c_a_r = 0.0229, f_a_r = 0.7064




----------------------> 0.7830036630036631
Evaluated Accuracy: 0.7830
Testing b_a_r = 0.9222, c_a_r = 0.1076, f_a_r = 0.1195




----------------------> 0.7409523809523809
Evaluated Accuracy: 0.7410

✅ Optimal Values:
   - b_a_r: 0.7465
   - c_a_r: 0.0005
   - f_a_r: 0.7011
📈 Highest Accuracy Achieved: 0.7915


In [38]:
# b_a_r: 
#    - c_a_r: 
#    - f_a_r: 

#  pt=20

In [43]:

# Step 1 - baseline: rebuild the network, load 'model_own_path.pt' and evaluate it unchanged.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()  # swap the `fc` attribute for a flatten layer

xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_mapii/result_ressnet/base/5_w_5_s_f', test_loader, M3, criterion)
# Step 2 - reference: build a second, independent network from 'model_mu_path.pt' and evaluate it.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/secondary/5_w_5_s_f", test_loader, M2, criterion)
# Step 3 - blend M2 into M3 in place, using the ratios found by the gp_minimize cell
# (optimal_c_a_r / optimal_f_a_r / optimal_b_a_r must already be defined).
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Step 4 - evaluate the blended M3; must run after the update above.
# Requires evaluate2 to have been defined by an earlier cell.
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/amul/5_w_5_s_f", test_loader, M3, criterion)




# Assuming you have the evaluate2 function defined
# Evaluate the updated model (M3) using evaluate2 function with some arguments
# evaluate3(pp+'/_mapi_5-wau-1-shot_matchine_unlearning_own', test_loader, M3, criterion)



---Accu-----pre----rec---------> 0.7700 ± 0.0453  0.7766 ± 0.0437  0.7700 ± 0.0453




---Accu-----pre----rec---------> 0.5626 ± 0.0760  0.5757 ± 0.0748  0.5626 ± 0.0760
---Accu-----pre----rec---------> 0.7915 ± 0.0375  0.7980 ± 0.0378  0.7915 ± 0.0375


In [40]:
!ls /home/asufian/Desktop/output_olchiki/code/mapi/model_10

ls: cannot access '/home/asufian/Desktop/output_olchiki/code/mapi/model_10': No such file or directory


In [44]:
# Re-stage the two working checkpoints ('model_own_path.pt', 'model_mu_path.pt')
# that get_model_accuracy and the evaluation cells load from the current directory.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



# Remove the checkpoints written by the previous experiment before re-creating them below.
!rm model_mu_path.pt
!rm model_own_path.pt

convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

# Loaded onto CPU storage first (map_location); the module itself was already moved to `device`.
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

# NOTE: despite the `_1_shot` suffix, this variable is loaded from the 10-shot checkpoint file.
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/mapi_sce/ressecondary/model_10/model_B_mapi_10-shot_res.pth',map_location=torch.device('cpu')))
# 'model_mu_path.pt' <- model_own_1_shot, 'model_own_path.pt' <- model_own.
torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [45]:


def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              class_sets=None, support_indices=None):
    """Evaluate ``model`` on a fixed list of few-shot episodes and print summary metrics.

    For each episode the support and query sets are fetched from ``data_loader``,
    the model scores the queries, and accuracy and a confusion matrix are computed.
    ``calculate_metrics_get_per`` is called once per episode; its first two return
    values are collected as precision and recall (the third is ignored). After
    the last episode, the mean and sample standard deviation (``ddof=1``) of
    accuracy, precision and recall across episodes are printed on one line.

    Args:
        fname: Forwarded to ``calculate_metrics_get_per`` as ``flnm``
            (presumably an output file prefix -- confirm against that helper).
        data_loader: Object whose ``__getitem__(classes, support_ids)`` returns
            ``(support_set, query_set)``, each an iterable of ``(images, labels)``
            pairs, one pair per class.
        model: Callable as ``model(support_images, support_labels, query_images)``
            returning one score tensor per query group; it must also provide
            ``eval()``.
        criterion: Unused. Kept so existing positional call sites keep working.
        class_sets: Per-episode class lists. Defaults to the module-level
            ``clss_5``.
        support_indices: Per-episode support selections, indexed in step with
            ``class_sets``. Defaults to the module-level
            ``clss_support_imagesss_10_shot``.

    Returns:
        None. Results are printed and handed to ``calculate_metrics_get_per``.
    """
    # Resolve the defaults at call time: the lists live in other notebook cells
    # and may be (re)defined after this function.
    if class_sets is None:
        class_sets = clss_5
    if support_indices is None:
        support_indices = clss_support_imagesss_10_shot

    model.eval()
    ttlal = []       # per-episode accuracy
    precision = []   # per-episode precision reported by the metrics helper
    recall = []      # per-episode recall reported by the metrics helper

    with torch.no_grad():
        for episode in range(len(class_sets)):
            episode_classes = class_sets[episode]
            episode_support = support_indices[episode]

            # The loader's __getitem__ takes two positional arguments, so it has
            # to be called explicitly; `data_loader[a, b]` would pass one tuple.
            support_set, query_set = data_loader.__getitem__(episode_classes, episode_support)

            # Text labels handed to the metrics helper.
            clss = str(episode_classes).replace(',', ' ').replace('\'', ' ')
            msg = str(episode_support)

            # One stacked tensor per class; the support label is the first entry
            # of that class's label list, query labels are kept per image.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            cortt = 0
            totl = 0
            all_predicted_labels = []
            all_actual_labels = []
            for ei in range(len(classification_scores)):
                pp = torch.argmax(classification_scores[ei], dim=1).tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt += 1
                totl += len(pp)

            accuracy = cortt / totl
            ttlal.append(accuracy)

            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            precision_overall, recall_overall, f1ss = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=clss, flnm=fname
            )
            precision.append(precision_overall)
            recall.append(recall_overall)

    print(f"---Accu-----pre----rec---------> {np.mean(ttlal):.4f} ± {np.std(ttlal, ddof=1):.4f}  "
          f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
          f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")

    # print("---Accu-----pre----rec---------->",np.mean(ttlal),'±',np.std(ttlal, ddof=1), np.mean(precision),'±',np.std(precision, ddof=1),np.mean(precision),'±',np.std(precision, ddof=1),)
# %%%%%%

def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              class_sets=None, support_indices=None):
    """Return the mean per-episode accuracy of ``model`` over a fixed episode list.

    This is the lightweight counterpart of ``evaluate2``: it computes accuracy
    only (no confusion matrix, no metrics helper), prints the mean and returns it.

    Args:
        fname: Unused. Kept so the signature matches ``evaluate2``.
        data_loader: Object whose ``__getitem__(classes, support_ids)`` returns
            ``(support_set, query_set)``, each an iterable of ``(images, labels)``
            pairs, one pair per class.
        model: Callable as ``model(support_images, support_labels, query_images)``
            returning one score tensor per query group; it must also provide
            ``eval()``.
        criterion: Unused. Kept so existing positional call sites keep working.
        class_sets: Per-episode class lists. Defaults to the module-level
            ``clss_5``.
        support_indices: Per-episode support selections, indexed in step with
            ``class_sets``. Defaults to the module-level
            ``clss_support_imagesss_10_shot``.

    Returns:
        The arithmetic mean of the per-episode accuracies.

    Raises:
        ZeroDivisionError: If there are no episodes, or an episode has no queries.
    """
    # Resolve the defaults at call time: the lists live in other notebook cells.
    if class_sets is None:
        class_sets = clss_5
    if support_indices is None:
        support_indices = clss_support_imagesss_10_shot

    model.eval()
    ttla = 0    # number of episodes evaluated
    ttlac = 0   # running sum of per-episode accuracy (plain left-to-right sum)

    with torch.no_grad():
        for episode in range(len(class_sets)):
            # The loader's __getitem__ takes two positional arguments, so it has
            # to be called explicitly; `data_loader[a, b]` would pass one tuple.
            support_set, query_set = data_loader.__getitem__(class_sets[episode], support_indices[episode])

            # One stacked tensor per class; the support label is the first entry
            # of that class's label list, query labels are kept per image.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            cortt = 0
            totl = 0
            for ei in range(len(classification_scores)):
                pp = torch.argmax(classification_scores[ei], dim=1).tolist()
                act = query_labels[ei]
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt += 1
                totl += len(pp)

            ttla += 1
            ttlac += cortt / totl

    mean_accuracy = ttlac / ttla
    print("---------------------->", mean_accuracy)
    return mean_accuracy

In [46]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup (GPU when available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Search space: three blend ratios, each searched independently over [0, 1].
# The dimension names must match the parameter names of `objective`, because
# `use_named_args` passes each sampled value by keyword.
# The optimum is later passed to update_model_weights2 as
# bias_ratio (b_a_r), conv_ratio (c_a_r) and fc_ratio (f_a_r).
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Define the objective function for Bayesian Optimization
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """Score one candidate triple of adaptation ratios for the minimizer.

    The three ratios are forwarded unchanged to ``get_model_accuracy``. The
    accuracy it reports is printed as a progress trace and returned with its
    sign flipped, since the optimizer minimizes its objective.
    """
    measured = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {measured:.4f}")  # Debugging Output
    # Higher accuracy must look like a lower cost to the minimizer.
    return -measured

# Run Gaussian-process Bayesian optimization over the three ratios.
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Total number of objective evaluations
    n_initial_points=5,       # Initial random explorations before GP starts
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Fixes the optimizer's own sampling; the objective may still vary
    n_jobs=-1,                # Cores for the acquisition-function optimizer only; objective calls are not parallelized
)

# res.x holds the best point in search_space order (b_a_r, c_a_r, f_a_r);
# res.fun is the minimized value, i.e. the negated accuracy.
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.7204651162790697
Evaluated Accuracy: 0.7205
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.36635658914728686
Evaluated Accuracy: 0.3664
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.43472868217054267
Evaluated Accuracy: 0.4347
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.8009302325581393
Evaluated Accuracy: 0.8009
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.7550387596899225
Evaluated Accuracy: 0.7550
Testing b_a_r = 0.0250, c_a_r = 1.0000, f_a_r = 0.4630




----------------------> 0.31720930232558137
Evaluated Accuracy: 0.3172
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.6826




----------------------> 0.8026356589147287
Evaluated Accuracy: 0.8026
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.7562




----------------------> 0.7981395348837209
Evaluated Accuracy: 0.7981
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.7882170542635658
Evaluated Accuracy: 0.7882
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.4014




----------------------> 0.797674418604651
Evaluated Accuracy: 0.7977
Testing b_a_r = 1.0000, c_a_r = 0.0505, f_a_r = 0.5306




----------------------> 0.8091472868217053
Evaluated Accuracy: 0.8091
Testing b_a_r = 0.7991, c_a_r = 0.0787, f_a_r = 0.0000




----------------------> 0.7869767441860466
Evaluated Accuracy: 0.7870
Testing b_a_r = 0.8057, c_a_r = 0.6538, f_a_r = 1.0000




----------------------> 0.30000000000000004
Evaluated Accuracy: 0.3000
Testing b_a_r = 0.0000, c_a_r = 0.0984, f_a_r = 0.4132




----------------------> 0.7951937984496124
Evaluated Accuracy: 0.7952
Testing b_a_r = 0.0437, c_a_r = 0.0477, f_a_r = 0.2352




----------------------> 0.7965891472868217
Evaluated Accuracy: 0.7966
Testing b_a_r = 0.5582, c_a_r = 0.7756, f_a_r = 0.0030




----------------------> 0.3046511627906977
Evaluated Accuracy: 0.3047
Testing b_a_r = 0.1695, c_a_r = 0.3486, f_a_r = 0.9975




----------------------> 0.43968992248062017
Evaluated Accuracy: 0.4397
Testing b_a_r = 0.4788, c_a_r = 0.1084, f_a_r = 0.9976




----------------------> 0.7615503875968991
Evaluated Accuracy: 0.7616
Testing b_a_r = 0.0492, c_a_r = 0.0541, f_a_r = 0.0905




----------------------> 0.7924031007751936
Evaluated Accuracy: 0.7924
Testing b_a_r = 0.9728, c_a_r = 0.0906, f_a_r = 0.6195




----------------------> 0.8086821705426356
Evaluated Accuracy: 0.8087
Testing b_a_r = 0.2467, c_a_r = 0.9126, f_a_r = 0.9986




----------------------> 0.3103875968992248
Evaluated Accuracy: 0.3104
Testing b_a_r = 0.1309, c_a_r = 0.1568, f_a_r = 0.0053




----------------------> 0.751937984496124
Evaluated Accuracy: 0.7519
Testing b_a_r = 0.0292, c_a_r = 0.0437, f_a_r = 0.5255




----------------------> 0.8120930232558141
Evaluated Accuracy: 0.8121
Testing b_a_r = 0.0102, c_a_r = 0.0544, f_a_r = 0.5954




----------------------> 0.8136434108527132
Evaluated Accuracy: 0.8136
Testing b_a_r = 0.0145, c_a_r = 0.0941, f_a_r = 0.6418




----------------------> 0.8035658914728682
Evaluated Accuracy: 0.8036
Testing b_a_r = 0.0069, c_a_r = 0.0592, f_a_r = 0.0192




----------------------> 0.7920930232558139
Evaluated Accuracy: 0.7921
Testing b_a_r = 0.9458, c_a_r = 0.0228, f_a_r = 0.5856




----------------------> 0.8099224806201551
Evaluated Accuracy: 0.8099
Testing b_a_r = 0.9730, c_a_r = 0.0584, f_a_r = 0.4665




----------------------> 0.804186046511628
Evaluated Accuracy: 0.8042
Testing b_a_r = 0.0595, c_a_r = 0.0048, f_a_r = 0.5402




----------------------> 0.8088372093023256
Evaluated Accuracy: 0.8088
Testing b_a_r = 0.9211, c_a_r = 0.1522, f_a_r = 0.3751




----------------------> 0.7629457364341087
Evaluated Accuracy: 0.7629
Testing b_a_r = 0.1075, c_a_r = 0.0390, f_a_r = 0.6703




----------------------> 0.8083720930232557
Evaluated Accuracy: 0.8084
Testing b_a_r = 0.0055, c_a_r = 0.0055, f_a_r = 0.2272




----------------------> 0.7953488372093023
Evaluated Accuracy: 0.7953
Testing b_a_r = 0.0042, c_a_r = 0.0274, f_a_r = 0.5839




----------------------> 0.8124031007751937
Evaluated Accuracy: 0.8124
Testing b_a_r = 0.1012, c_a_r = 0.9968, f_a_r = 0.0043




----------------------> 0.3116279069767442
Evaluated Accuracy: 0.3116
Testing b_a_r = 0.0147, c_a_r = 0.1213, f_a_r = 0.7251




----------------------> 0.7844961240310078
Evaluated Accuracy: 0.7845
Testing b_a_r = 0.0320, c_a_r = 0.1090, f_a_r = 0.1493




----------------------> 0.7787596899224807
Evaluated Accuracy: 0.7788
Testing b_a_r = 0.6943, c_a_r = 0.6034, f_a_r = 0.0024




----------------------> 0.3238759689922481
Evaluated Accuracy: 0.3239
Testing b_a_r = 0.0551, c_a_r = 0.0320, f_a_r = 0.4757




----------------------> 0.809922480620155
Evaluated Accuracy: 0.8099
Testing b_a_r = 0.0136, c_a_r = 0.0634, f_a_r = 0.5927




----------------------> 0.8113178294573644
Evaluated Accuracy: 0.8113
Testing b_a_r = 0.8194, c_a_r = 0.1810, f_a_r = 0.9991




----------------------> 0.6951937984496123
Evaluated Accuracy: 0.6952
Testing b_a_r = 0.0222, c_a_r = 0.0467, f_a_r = 0.8442




----------------------> 0.7725581395348837
Evaluated Accuracy: 0.7726
Testing b_a_r = 0.5051, c_a_r = 0.4989, f_a_r = 0.9991




----------------------> 0.3043410852713178
Evaluated Accuracy: 0.3043
Testing b_a_r = 0.9908, c_a_r = 0.1195, f_a_r = 0.5469




----------------------> 0.7924031007751938
Evaluated Accuracy: 0.7924
Testing b_a_r = 0.9865, c_a_r = 0.1007, f_a_r = 0.7599




----------------------> 0.787751937984496
Evaluated Accuracy: 0.7878
Testing b_a_r = 0.0365, c_a_r = 0.7787, f_a_r = 0.5808




----------------------> 0.30542635658914724
Evaluated Accuracy: 0.3054
Testing b_a_r = 0.9956, c_a_r = 0.0008, f_a_r = 0.5586




----------------------> 0.8096124031007751
Evaluated Accuracy: 0.8096
Testing b_a_r = 0.9904, c_a_r = 0.0133, f_a_r = 0.1413




----------------------> 0.7903875968992248
Evaluated Accuracy: 0.7904
Testing b_a_r = 0.9632, c_a_r = 0.0288, f_a_r = 0.6375




----------------------> 0.8114728682170541
Evaluated Accuracy: 0.8115
Testing b_a_r = 0.0177, c_a_r = 0.2131, f_a_r = 0.5109




----------------------> 0.7334883720930232
Evaluated Accuracy: 0.7335
Testing b_a_r = 0.0317, c_a_r = 0.5454, f_a_r = 0.5090




----------------------> 0.3049612403100775
Evaluated Accuracy: 0.3050

✅ Optimal Values:
   - b_a_r: 0.0102
   - c_a_r: 0.0544
   - f_a_r: 0.5954
📈 Highest Accuracy Achieved: 0.8136


In [47]:
# Model M3: fresh backbone (fc replaced by Flatten) loaded from the checkpoint
# staged as 'model_own_path.pt', then evaluated before any blending.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_mapii/result_ressnet/base/5_w_10_s_ff', test_loader, M3, criterion)
# Model M2: same architecture on its own backbone instance, loaded from the
# checkpoint staged as 'model_mu_path.pt', and evaluated on its own.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/secondary/5_w_10_s_ff", test_loader, M2, criterion)
# Update Model M3's weights using a weighted combination of Model M3 and Model M2,
# with the ratios found by the Bayesian search. The return value is not used, so
# M3 is presumably modified in place -- confirm in update_model_weights2.
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Evaluate M3 again after the blend. This must run after update_model_weights2;
# the first evaluate2 call above is the pre-blend baseline for the same object.
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/amul/5_w_5_10_ff", test_loader, M3, criterion)

# Disabled accuracy-only check of the blended model. `pp` is not defined in
# this cell, so it must be set before this line is re-enabled.
# evaluate3(pp+'/_mapi_5-wau-1-shot_matchine_unlearning_own', test_loader, M3, criterion)



---Accu-----pre----rec---------> 0.7930 ± 0.0351  0.7971 ± 0.0334  0.7930 ± 0.0351




---Accu-----pre----rec---------> 0.5476 ± 0.0911  0.5502 ± 0.1025  0.5476 ± 0.0911
---Accu-----pre----rec---------> 0.8136 ± 0.0281  0.8183 ± 0.0288  0.8136 ± 0.0281


In [88]:
ls /home/asufian/Desktop/output_olchikilol

ls: cannot access '/home/asufian/Desktop/output_olchikilol': No such file or directory


In [48]:
# Fixed evaluation episodes for the 8-way setting: 15 episodes, each a list of
# 8 class names. Episode i is paired by index with episode i of the
# support-image lists (evaluate2 reads clss_8[i] together with
# clss_support_imagesss_1_shot[i]).
clss_8=[
    ['MAPI Mayeek_FAM', 'MAPI Mayeek_JHAM', 'MAPI Mayeek_LAI', 'MAPI Mayeek_HUK', 'MAPI Mayeek_KOK', 'MAPI Mayeek_NGOU', 'MAPI Mayeek_THOU', 'MAPI Mayeek_JIL'],
     ['MAPI Mayeek_SAM', 'MAPI Mayeek_HUK', 'MAPI Mayeek_KHOU', 'MAPI Mayeek_NAA', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_THOU', 'MAPI Mayeek_PAA', 'MAPI Mayeek_JHAM'],
    ['MAPI Mayeek_JHAM', 'MAPI Mayeek_NAA', 'MAPI Mayeek_GOK', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_NGOU', 'MAPI Mayeek_KHOU', 'MAPI Mayeek_RAAI', 'MAPI Mayeek_DIL'],
    ['MAPI Mayeek_CHIN', 'MAPI Mayeek_UUN', 'MAPI Mayeek_YANG', 'MAPI Mayeek_KOK', 'MAPI Mayeek_MIT', 'MAPI Mayeek_LAI', 'MAPI Mayeek_JHAM', 'MAPI Mayeek_NAA'],
    ['MAPI Mayeek_DHOU', 'MAPI Mayeek_PAA', 'MAPI Mayeek_MIT', 'MAPI Mayeek_BHAM', 'MAPI Mayeek_EEE', 'MAPI Mayeek_LAI', 'MAPI Mayeek_UUN', 'MAPI Mayeek_SAM'],
    ['MAPI Mayeek_NAA', 'MAPI Mayeek_UUN', 'MAPI Mayeek_FAM', 'MAPI Mayeek_BHAM', 'MAPI Mayeek_BAA', 'MAPI Mayeek_KOK', 'MAPI Mayeek_RAAI', 'MAPI Mayeek_LAI'],
    ['MAPI Mayeek_CHIN', 'MAPI Mayeek_KOK', 'MAPI Mayeek_HUK', 'MAPI Mayeek_PAA', 'MAPI Mayeek_JHAM', 'MAPI Mayeek_BHAM', 'MAPI Mayeek_LAI', 'MAPI Mayeek_FAM'],
    ['MAPI Mayeek_NGOU', 'MAPI Mayeek_JHAM', 'MAPI Mayeek_DIL', 'MAPI Mayeek_KHOU', 'MAPI Mayeek_SAM', 'MAPI Mayeek_UUN', 'MAPI Mayeek_PAA', 'MAPI Mayeek_KOK'],
    ['MAPI Mayeek_GOK', 'MAPI Mayeek_SAM', 'MAPI Mayeek_BHAM', 'MAPI Mayeek_CHIN', 'MAPI Mayeek_GHOU', 'MAPI Mayeek_NAA', 'MAPI Mayeek_HUK', 'MAPI Mayeek_THOU'],
    ['MAPI Mayeek_BHAM', 'MAPI Mayeek_SAM', 'MAPI Mayeek_PAA', 'MAPI Mayeek_UUN', 'MAPI Mayeek_FAM', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_JHAM', 'MAPI Mayeek_RAAI'],
     ['MAPI Mayeek_UUN', 'MAPI Mayeek_FAM', 'MAPI Mayeek_JIL', 'MAPI Mayeek_TIL', 'MAPI Mayeek_BHAM', 'MAPI Mayeek_MIT', 'MAPI Mayeek_GOK', 'MAPI Mayeek_JHAM'],
    ['MAPI Mayeek_KOK', 'MAPI Mayeek_NGOU', 'MAPI Mayeek_JHAM', 'MAPI Mayeek_UUN', 'MAPI Mayeek_GHOU', 'MAPI Mayeek_MIT', 'MAPI Mayeek_YANG', 'MAPI Mayeek_LAI'],
    ['MAPI Mayeek_FAM', 'MAPI Mayeek_BHAM', 'MAPI Mayeek_NGOU', 'MAPI Mayeek_LAI', 'MAPI Mayeek_KOK', 'MAPI Mayeek_CHIN', 'MAPI Mayeek_WAI', 'MAPI Mayeek_GOK'],
    ['MAPI Mayeek_UUN', 'MAPI Mayeek_YANG', 'MAPI Mayeek_FAM', 'MAPI Mayeek_NAA', 'MAPI Mayeek_KOK', 'MAPI Mayeek_JHAM', 'MAPI Mayeek_HUK', 'MAPI Mayeek_WAI'],
    ['MAPI Mayeek_JIL', 'MAPI Mayeek_CHIN', 'MAPI Mayeek_TIL', 'MAPI Mayeek_LAI', 'MAPI Mayeek_PAA', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_MIT', 'MAPI Mayeek_YANG'],
 ]

clss_support_imagesss_10_shot=[[[32, 30, 44, 38, 25, 67, 9, 24, 47, 41], [82, 83, 48, 43, 90, 45, 62, 87, 5, 25], [7, 95, 9, 14, 63, 82, 64, 93, 61, 55], [72, 82, 69, 26, 6, 9, 39, 19, 30, 13], [25, 67, 41, 47, 95, 38, 76, 58, 87, 90], [44, 57, 84, 72, 58, 64, 93, 31, 4, 73], [7, 68, 51, 47, 63, 36, 4, 50, 94, 37], [67, 17, 31, 62, 12, 34, 81, 15, 29, 71]], [[50, 94, 95, 27, 35, 80, 62, 55, 82, 88], [74, 33, 78, 25, 11, 79, 52, 17, 12, 15], [33, 56, 67, 11, 31, 74, 88, 3, 24, 78], [71, 37, 27, 11, 65, 74, 3, 86, 80, 78], [67, 29, 87, 60, 71, 43, 40, 74, 61, 37], [16, 83, 67, 4, 50, 72, 77, 89, 69, 15], [20, 87, 71, 53, 90, 41, 63, 18, 44, 1], [14, 77, 74, 42, 27, 68, 35, 93, 81, 15]], [[30, 48, 0, 84, 3, 13, 36, 86, 38, 2], [93, 69, 13, 90, 78, 74, 20, 73, 33, 65], [6, 20, 80, 56, 33, 48, 75, 47, 18, 72], [20, 50, 8, 76, 15, 24, 80, 73, 95, 52], [54, 46, 37, 77, 38, 43, 27, 25, 3, 33], [50, 72, 23, 28, 85, 56, 17, 14, 70, 34], [8, 45, 72, 36, 93, 89, 40, 84, 54, 17], [71, 91, 30, 49, 66, 35, 84, 5, 63, 48]], [[4, 73, 59, 63, 39, 94, 24, 22, 64, 51], [14, 19, 31, 63, 9, 77, 43, 80, 33, 11], [40, 94, 59, 17, 57, 47, 49, 21, 85, 12], [29, 56, 59, 71, 78, 32, 34, 66, 83, 60], [19, 0, 83, 93, 15, 51, 26, 16, 35, 55], [51, 19, 38, 46, 5, 65, 80, 44, 71, 25], [40, 32, 95, 4, 81, 46, 24, 27, 68, 47], [47, 28, 38, 60, 59, 19, 18, 77, 89, 0]], [[53, 40, 2, 54, 68, 81, 23, 86, 94, 83], [23, 49, 47, 43, 7, 84, 62, 86, 6, 35], [12, 35, 13, 79, 86, 75, 58, 32, 88, 31], [2, 42, 85, 0, 7, 27, 93, 73, 25, 18], [78, 72, 46, 12, 1, 70, 18, 62, 14, 39], [9, 88, 77, 7, 84, 52, 32, 94, 53, 42], [56, 89, 26, 69, 17, 59, 34, 91, 58, 11], [73, 11, 22, 27, 48, 69, 92, 7, 26, 13]], [[62, 40, 87, 85, 93, 37, 52, 84, 30, 76], [55, 83, 60, 17, 19, 73, 36, 13, 77, 80], [76, 79, 90, 65, 60, 56, 73, 95, 47, 67], [65, 91, 34, 14, 10, 64, 93, 73, 2, 60], [44, 45, 50, 35, 13, 30, 6, 71, 7, 95], [82, 44, 14, 77, 19, 50, 10, 28, 29, 58], [18, 90, 85, 7, 16, 50, 6, 59, 93, 75], [74, 76, 91, 47, 63, 88, 
8, 14, 64, 95]], [[19, 26, 77, 36, 52, 84, 8, 70, 90, 37], [23, 34, 84, 62, 4, 2, 68, 77, 1, 40], [75, 31, 91, 54, 51, 72, 73, 23, 3, 70], [58, 69, 74, 50, 0, 52, 75, 51, 94, 60], [45, 70, 35, 93, 4, 43, 67, 17, 1, 86], [67, 53, 77, 95, 1, 28, 33, 47, 38, 78], [7, 66, 33, 67, 85, 59, 5, 93, 41, 29], [19, 17, 57, 70, 5, 75, 28, 47, 32, 31]], [[66, 32, 35, 20, 34, 42, 65, 33, 38, 29], [49, 52, 60, 28, 16, 89, 7, 67, 94, 21], [73, 25, 19, 9, 30, 81, 7, 21, 14, 82], [50, 80, 95, 71, 1, 49, 90, 28, 57, 52], [20, 60, 10, 47, 5, 44, 46, 21, 83, 17], [45, 73, 41, 89, 90, 50, 10, 91, 81, 64], [81, 6, 71, 14, 8, 33, 25, 55, 27, 72], [73, 57, 22, 65, 88, 68, 0, 23, 46, 42]], [[17, 23, 52, 46, 37, 48, 3, 72, 36, 61], [14, 62, 46, 37, 52, 33, 9, 42, 20, 65], [7, 54, 0, 42, 14, 57, 25, 10, 65, 23], [23, 81, 61, 86, 58, 52, 90, 15, 35, 64], [19, 52, 74, 93, 11, 3, 84, 14, 92, 27], [45, 56, 26, 15, 0, 11, 75, 51, 14, 24], [49, 47, 53, 66, 11, 5, 64, 59, 77, 50], [49, 52, 21, 37, 35, 50, 23, 15, 77, 63]], [[52, 44, 70, 14, 4, 48, 76, 63, 1, 11], [45, 21, 82, 19, 93, 75, 73, 4, 43, 91], [19, 68, 60, 84, 21, 79, 44, 65, 20, 1], [73, 49, 67, 71, 14, 40, 3, 87, 9, 8], [81, 20, 77, 15, 47, 60, 71, 5, 92, 7], [83, 87, 44, 73, 85, 60, 56, 28, 3, 33], [11, 25, 67, 93, 27, 76, 70, 61, 15, 72], [36, 49, 80, 35, 21, 22, 32, 12, 69, 71]], [[89, 16, 7, 23, 52, 38, 8, 47, 42, 18], [36, 64, 9, 79, 38, 26, 94, 66, 58, 12], [6, 40, 78, 44, 29, 43, 28, 34, 60, 24], [62, 84, 32, 7, 11, 94, 39, 91, 3, 12], [16, 63, 55, 32, 70, 87, 18, 85, 4, 86], [7, 81, 91, 44, 93, 73, 18, 38, 53, 12], [44, 60, 32, 50, 0, 46, 74, 35, 79, 83], [83, 82, 88, 23, 19, 33, 72, 41, 45, 13]], [[44, 58, 75, 52, 74, 21, 16, 23, 47, 6], [78, 77, 14, 82, 46, 40, 61, 24, 42, 81], [45, 58, 27, 60, 80, 23, 63, 34, 52, 92], [71, 16, 79, 60, 62, 11, 53, 9, 40, 31], [40, 12, 25, 26, 31, 42, 78, 1, 9, 89], [23, 51, 15, 24, 56, 39, 0, 35, 67, 4], [19, 1, 86, 59, 31, 74, 47, 20, 4, 5], [83, 53, 25, 3, 1, 81, 41, 63, 13, 88]], [[33, 75, 
39, 29, 74, 58, 43, 82, 26, 73], [45, 54, 10, 75, 43, 56, 7, 68, 63, 2], [26, 65, 64, 22, 68, 80, 27, 48, 77, 9], [26, 32, 34, 46, 80, 69, 54, 31, 42, 76], [1, 41, 78, 3, 31, 88, 45, 79, 80, 49], [19, 35, 81, 51, 93, 28, 48, 63, 80, 30], [86, 36, 43, 63, 1, 60, 33, 74, 81, 57], [42, 49, 24, 46, 90, 66, 45, 87, 59, 55]], [[4, 67, 64, 24, 56, 73, 55, 43, 54, 92], [30, 2, 73, 4, 1, 34, 9, 11, 95, 29], [5, 10, 88, 42, 86, 67, 80, 25, 53, 38], [59, 66, 52, 77, 73, 35, 89, 26, 45, 72], [56, 83, 24, 73, 28, 61, 50, 93, 64, 13], [82, 73, 11, 45, 31, 41, 80, 58, 1, 5], [17, 91, 25, 2, 42, 56, 8, 40, 15, 33], [57, 84, 95, 4, 64, 83, 26, 66, 15, 30]], [[84, 66, 88, 58, 71, 15, 45, 50, 38, 44], [17, 29, 20, 91, 79, 32, 44, 66, 70, 6], [12, 16, 26, 65, 51, 67, 48, 79, 29, 95], [78, 88, 55, 27, 84, 79, 24, 54, 70, 76], [61, 81, 6, 51, 64, 40, 18, 69, 83, 16], [17, 35, 51, 42, 55, 86, 36, 24, 50, 46], [23, 73, 12, 33, 89, 26, 63, 59, 44, 65], [34, 24, 22, 64, 27, 88, 46, 70, 44, 30]]]

clss_support_imagesss_5_shot=[[[32, 30, 44, 38, 25], [82, 83, 48, 43, 90], [7, 95, 9, 14, 63], [72, 82, 69, 26, 6], [25, 67, 41, 47, 95], [44, 57, 84, 72, 58], [7, 68, 51, 47, 63], [67, 17, 31, 62, 12]], [[50, 94, 95, 27, 35], [74, 33, 78, 25, 11], [33, 56, 67, 11, 31], [71, 37, 27, 11, 65], [67, 29, 87, 60, 71], [16, 83, 67, 4, 50], [20, 87, 71, 53, 90], [14, 77, 74, 42, 27]], [[30, 48, 0, 84, 3], [93, 69, 13, 90, 78], [6, 20, 80, 56, 33], [20, 50, 8, 76, 15], [54, 46, 37, 77, 38], [50, 72, 23, 28, 85], [8, 45, 72, 36, 93], [71, 91, 30, 49, 66]], [[4, 73, 59, 63, 39], [14, 19, 31, 63, 9], [40, 94, 59, 17, 57], [29, 56, 59, 71, 78], [19, 0, 83, 93, 15], [51, 19, 38, 46, 5], [40, 32, 95, 4, 81], [47, 28, 38, 60, 59]], [[53, 40, 2, 54, 68], [23, 49, 47, 43, 7], [12, 35, 13, 79, 86], [2, 42, 85, 0, 7], [78, 72, 46, 12, 1], [9, 88, 77, 7, 84], [56, 89, 26, 69, 17], [73, 11, 22, 27, 48]], [[62, 40, 87, 85, 93], [55, 83, 60, 17, 19], [76, 79, 90, 65, 60], [65, 91, 34, 14, 10], [44, 45, 50, 35, 13], [82, 44, 14, 77, 19], [18, 90, 85, 7, 16], [74, 76, 91, 47, 63]], [[19, 26, 77, 36, 52], [23, 34, 84, 62, 4], [75, 31, 91, 54, 51], [58, 69, 74, 50, 0], [45, 70, 35, 93, 4], [67, 53, 77, 95, 1], [7, 66, 33, 67, 85], [19, 17, 57, 70, 5]], [[66, 32, 35, 20, 34], [49, 52, 60, 28, 16], [73, 25, 19, 9, 30], [50, 80, 95, 71, 1], [20, 60, 10, 47, 5], [45, 73, 41, 89, 90], [81, 6, 71, 14, 8], [73, 57, 22, 65, 88]], [[17, 23, 52, 46, 37], [14, 62, 46, 37, 52], [7, 54, 0, 42, 14], [23, 81, 61, 86, 58], [19, 52, 74, 93, 11], [45, 56, 26, 15, 0], [49, 47, 53, 66, 11], [49, 52, 21, 37, 35]], [[52, 44, 70, 14, 4], [45, 21, 82, 19, 93], [19, 68, 60, 84, 21], [73, 49, 67, 71, 14], [81, 20, 77, 15, 47], [83, 87, 44, 73, 85], [11, 25, 67, 93, 27], [36, 49, 80, 35, 21]], [[89, 16, 7, 23, 52], [36, 64, 9, 79, 38], [6, 40, 78, 44, 29], [62, 84, 32, 7, 11], [16, 63, 55, 32, 70], [7, 81, 91, 44, 93], [44, 60, 32, 50, 0], [83, 82, 88, 23, 19]], [[44, 58, 75, 52, 74], [78, 77, 14, 82, 46], [45, 58, 
27, 60, 80], [71, 16, 79, 60, 62], [40, 12, 25, 26, 31], [23, 51, 15, 24, 56], [19, 1, 86, 59, 31], [83, 53, 25, 3, 1]], [[33, 75, 39, 29, 74], [45, 54, 10, 75, 43], [26, 65, 64, 22, 68], [26, 32, 34, 46, 80], [1, 41, 78, 3, 31], [19, 35, 81, 51, 93], [86, 36, 43, 63, 1], [42, 49, 24, 46, 90]], [[4, 67, 64, 24, 56], [30, 2, 73, 4, 1], [5, 10, 88, 42, 86], [59, 66, 52, 77, 73], [56, 83, 24, 73, 28], [82, 73, 11, 45, 31], [17, 91, 25, 2, 42], [57, 84, 95, 4, 64]], [[84, 66, 88, 58, 71], [17, 29, 20, 91, 79], [12, 16, 26, 65, 51], [78, 88, 55, 27, 84], [61, 81, 6, 51, 64], [17, 35, 51, 42, 55], [23, 73, 12, 33, 89], [34, 24, 22, 64, 27]]]


# 1-shot support selection: one row per episode (15 rows), one single-element
# list per class (8 per row). Row i is passed to the data loader together with
# clss_8[i]. The integers are presumably image indices within each class --
# confirm against the loader's __getitem__.
clss_support_imagesss_1_shot=[
    [[32], [82], [7], [72], [25], [44], [7], [67]],
    [[50], [74], [33], [71], [67], [16], [20], [14]],
    [[30], [93], [6], [20], [54], [50], [8], [71]],
     [[4], [14], [40], [29], [19], [51], [40], [47]],
    [[53], [23], [12], [2], [78], [9], [56], [73]],
    [[62], [55], [76], [65], [44], [82], [18], [74]],
    [[19], [23], [75], [58], [45], [67], [7], [19]],
     [[66], [49], [73], [50], [20], [45], [81], [73]],
    [[17], [14], [7], [23], [19], [45], [49], [49]],
    [[52], [45], [19], [73], [81], [83], [11], [36]],
    [[89], [36], [6], [62], [16], [7], [44], [83]],
    [[44], [78], [45], [71], [40], [23], [19], [83]],
    [[33], [45], [26], [26], [1], [19], [86], [42]],
    [[4], [30], [5], [59], [56], [82], [17], [57]],
    [[84], [17], [12], [78], [61], [17], [23], [34]]

]



In [52]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



!rm model_mu_path.pt
!rm model_own_path.pt

convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/mapi_sce/ressecondary/model_1/model_B_mapi_1-shot_res.pth',map_location=torch.device('cpu')))
torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')

rm: cannot remove 'model_mu_path.pt': No such file or directory
rm: cannot remove 'model_own_path.pt': No such file or directory


In [53]:


def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              class_sets=None, support_indices=None):
    """Evaluate ``model`` on a fixed list of few-shot episodes and print summary metrics.

    For each episode the support and query sets are fetched from ``data_loader``,
    the model scores the queries, and accuracy and a confusion matrix are computed.
    ``calculate_metrics_get_per`` is called once per episode; its first two return
    values are collected as precision and recall (the third is ignored). After
    the last episode, the mean and sample standard deviation (``ddof=1``) of
    accuracy, precision and recall across episodes are printed on one line.

    Args:
        fname: Forwarded to ``calculate_metrics_get_per`` as ``flnm``
            (presumably an output file prefix -- confirm against that helper).
        data_loader: Object whose ``__getitem__(classes, support_ids)`` returns
            ``(support_set, query_set)``, each an iterable of ``(images, labels)``
            pairs, one pair per class.
        model: Callable as ``model(support_images, support_labels, query_images)``
            returning one score tensor per query group; it must also provide
            ``eval()``.
        criterion: Unused. Kept so existing positional call sites keep working.
        class_sets: Per-episode class lists. Defaults to the module-level
            ``clss_8``.
        support_indices: Per-episode support selections, indexed in step with
            ``class_sets``. Defaults to the module-level
            ``clss_support_imagesss_1_shot``.

    Returns:
        None. Results are printed and handed to ``calculate_metrics_get_per``.
    """
    # Resolve the defaults at call time: the lists live in other notebook cells
    # and may be (re)defined after this function.
    if class_sets is None:
        class_sets = clss_8
    if support_indices is None:
        support_indices = clss_support_imagesss_1_shot

    model.eval()
    ttlal = []       # per-episode accuracy
    precision = []   # per-episode precision reported by the metrics helper
    recall = []      # per-episode recall reported by the metrics helper

    with torch.no_grad():
        for episode in range(len(class_sets)):
            episode_classes = class_sets[episode]
            episode_support = support_indices[episode]

            # The loader's __getitem__ takes two positional arguments, so it has
            # to be called explicitly; `data_loader[a, b]` would pass one tuple.
            support_set, query_set = data_loader.__getitem__(episode_classes, episode_support)

            # Text labels handed to the metrics helper.
            clss = str(episode_classes).replace(',', ' ').replace('\'', ' ')
            msg = str(episode_support)

            # One stacked tensor per class; the support label is the first entry
            # of that class's label list, query labels are kept per image.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            cortt = 0
            totl = 0
            all_predicted_labels = []
            all_actual_labels = []
            for ei in range(len(classification_scores)):
                pp = torch.argmax(classification_scores[ei], dim=1).tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt += 1
                totl += len(pp)

            accuracy = cortt / totl
            ttlal.append(accuracy)

            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            precision_overall, recall_overall, f1ss = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=clss, flnm=fname
            )
            precision.append(precision_overall)
            recall.append(recall_overall)

    print(f"---Accu-----pre----rec---------> {np.mean(ttlal):.4f} ± {np.std(ttlal, ddof=1):.4f}  "
          f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
          f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")

    # print("---Accu-----pre----rec---------->",np.mean(ttlal),'±',np.std(ttlal, ddof=1), np.mean(precision),'±',np.std(precision, ddof=1),np.mean(precision),'±',np.std(precision, ddof=1),)
# %%%%%%

def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss(), *,
              episode_classes=None, episode_supports=None):
    """Return the mean per-episode query accuracy of ``model`` (1-shot episodes by default).

    For every episode index ``i`` the support and query sets are fetched with
    ``data_loader.__getitem__(episode_classes[i], episode_supports[i])``. The
    model scores each query group against the support set, and the arg-max of
    the scores along ``dim=1`` is compared with that group's labels. Tensors
    are moved to the notebook-level ``device``.

    Args:
        fname: Unused; kept so the signature matches ``evaluate2``.
        data_loader: Object whose ``__getitem__(classes, supports)`` returns
            ``(support_set, query_set)``; both are iterables of
            ``(images, labels)`` pairs where ``images`` is a sequence of
            tensors accepted by ``torch.stack``.
        model: Callable as ``model(support_images, support_labels, query_images)``
            returning one score tensor per query group; must provide ``eval()``.
        criterion: Unused; kept for backward compatibility.
        episode_classes: Per-episode first argument for ``__getitem__``.
            Defaults to the notebook global ``clss_8``.
        episode_supports: Per-episode second argument for ``__getitem__``.
            Defaults to the notebook global ``clss_support_imagesss_1_shot``.

    Returns:
        float: arithmetic mean of the per-episode accuracies (also printed).

    Raises:
        ZeroDivisionError: if there are no episodes, or an episode yields no
            query predictions.
    """
    # Fall back to the notebook globals so existing positional calls keep working.
    if episode_classes is None:
        episode_classes = clss_8
    if episode_supports is None:
        episode_supports = clss_support_imagesss_1_shot

    model.eval()
    accuracy_sum = 0
    episode_count = 0
    with torch.no_grad():
        for episode in range(len(episode_classes)):
            support_set, query_set = data_loader.__getitem__(
                episode_classes[episode], episode_supports[episode]
            )

            # Stack each group's images into one batch; a support group keeps
            # only its first label, a query group keeps the full label list.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            for group_index, group_scores in enumerate(classification_scores):
                predicted = torch.argmax(group_scores, dim=1).tolist()
                actual = query_labels[group_index]
                correct += sum(1 for i, label in enumerate(predicted) if label == actual[i])
                total += len(predicted)

            accuracy_sum += correct / total
            episode_count += 1

    print("---------------------->",accuracy_sum/episode_count)
    return accuracy_sum/episode_count

In [54]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Search space for the three blending ratios, each a real number in [0, 1].
# The order of the dimensions is the order of `res.x` after the search, which
# the unpacking `optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x` relies on.
# A later cell passes the tuned values to update_model_weights2 as
# bias_ratio (b_a_r), conv_ratio (c_a_r) and fc_ratio (f_a_r).
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Objective handed to gp_minimize; use_named_args turns the sampled point
# (a list ordered like search_space) into the keyword arguments below.
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """Score one candidate (b_a_r, c_a_r, f_a_r) triple for the optimizer.

    Delegates to ``get_model_accuracy`` (defined elsewhere in the notebook),
    echoes the value it returns, and hands back its negation because
    ``gp_minimize`` minimizes the objective.
    """
    measured = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {measured:.4f}")
    return -measured

# Run the Gaussian-process search over the three ratios.
# NOTE(review): get_model_accuracy is defined elsewhere; the best accuracy
# printed below equals the accuracy later reported on test_loader, so the
# ratios appear to be tuned on the same episodes that are reported — confirm
# and, if so, tune on a held-out split instead.
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Total number of objective evaluations
    n_initial_points=5,       # Evaluations at initial points before the GP surrogate is fitted
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Seeds the optimizer's own sampling
    n_jobs=-1,                # Cores for optimizing the acquisition function (lbfgs); objective calls still run one at a time
)

# res.x is ordered like search_space; res.fun is the minimum of the negated
# accuracy, so flip the sign back.
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.4729824561403509
Evaluated Accuracy: 0.4730
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.26929824561403515
Evaluated Accuracy: 0.2693
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.34105263157894744
Evaluated Accuracy: 0.3411
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.5535964912280702
Evaluated Accuracy: 0.5536
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.5080701754385966
Evaluated Accuracy: 0.5081
Testing b_a_r = 0.0000, c_a_r = 0.0035, f_a_r = 0.3872




----------------------> 0.5435964912280702
Evaluated Accuracy: 0.5436
Testing b_a_r = 0.2211, c_a_r = 1.0000, f_a_r = 1.0000




----------------------> 0.14850877192982453
Evaluated Accuracy: 0.1485
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.5348245614035088
Evaluated Accuracy: 0.5348
Testing b_a_r = 0.0000, c_a_r = 0.0729, f_a_r = 1.0000




----------------------> 0.5130701754385965
Evaluated Accuracy: 0.5131
Testing b_a_r = 0.9995, c_a_r = 0.0000, f_a_r = 0.0823




----------------------> 0.5421929824561403
Evaluated Accuracy: 0.5422
Testing b_a_r = 0.1506, c_a_r = 0.0000, f_a_r = 0.9742




----------------------> 0.5081578947368421
Evaluated Accuracy: 0.5082
Testing b_a_r = 1.0000, c_a_r = 0.0801, f_a_r = 0.2542




----------------------> 0.5325438596491229
Evaluated Accuracy: 0.5325
Testing b_a_r = 0.7691, c_a_r = 1.0000, f_a_r = 0.0000




----------------------> 0.1500877192982456
Evaluated Accuracy: 0.1501
Testing b_a_r = 0.0168, c_a_r = 0.0622, f_a_r = 0.0000




----------------------> 0.5322807017543859
Evaluated Accuracy: 0.5323
Testing b_a_r = 0.8357, c_a_r = 0.0461, f_a_r = 0.0021




----------------------> 0.5385087719298246
Evaluated Accuracy: 0.5385
Testing b_a_r = 0.9799, c_a_r = 0.0697, f_a_r = 0.0098




----------------------> 0.533421052631579
Evaluated Accuracy: 0.5334
Testing b_a_r = 1.0000, c_a_r = 0.0220, f_a_r = 0.6264




----------------------> 0.5605263157894737
Evaluated Accuracy: 0.5605
Testing b_a_r = 0.6334, c_a_r = 0.5620, f_a_r = 0.9990




----------------------> 0.306578947368421
Evaluated Accuracy: 0.3066
Testing b_a_r = 0.9940, c_a_r = 0.0447, f_a_r = 0.5401




----------------------> 0.5588596491228071
Evaluated Accuracy: 0.5589
Testing b_a_r = 0.9747, c_a_r = 0.3522, f_a_r = 0.9979




----------------------> 0.30587719298245614
Evaluated Accuracy: 0.3059
Testing b_a_r = 0.3263, c_a_r = 0.7277, f_a_r = 0.0000




----------------------> 0.27289473684210525
Evaluated Accuracy: 0.2729
Testing b_a_r = 0.0687, c_a_r = 0.7619, f_a_r = 0.9980




----------------------> 0.22105263157894733
Evaluated Accuracy: 0.2211
Testing b_a_r = 0.1340, c_a_r = 0.1938, f_a_r = 0.0020




----------------------> 0.44631578947368417
Evaluated Accuracy: 0.4463
Testing b_a_r = 0.9853, c_a_r = 0.0983, f_a_r = 0.5710




----------------------> 0.5464035087719298
Evaluated Accuracy: 0.5464
Testing b_a_r = 0.9935, c_a_r = 0.0026, f_a_r = 0.7491




----------------------> 0.5492105263157896
Evaluated Accuracy: 0.5492
Testing b_a_r = 0.9890, c_a_r = 0.6336, f_a_r = 0.5008




----------------------> 0.30289473684210527
Evaluated Accuracy: 0.3029
Testing b_a_r = 0.8838, c_a_r = 0.0005, f_a_r = 0.5554




----------------------> 0.5569298245614036
Evaluated Accuracy: 0.5569
Testing b_a_r = 0.0370, c_a_r = 0.0475, f_a_r = 0.5776




----------------------> 0.5571929824561404
Evaluated Accuracy: 0.5572
Testing b_a_r = 0.9752, c_a_r = 0.1871, f_a_r = 0.4223




----------------------> 0.4674561403508772
Evaluated Accuracy: 0.4675
Testing b_a_r = 0.9583, c_a_r = 0.8513, f_a_r = 0.4279




----------------------> 0.18570175438596495
Evaluated Accuracy: 0.1857
Testing b_a_r = 0.9530, c_a_r = 0.1640, f_a_r = 0.9984




----------------------> 0.46473684210526317
Evaluated Accuracy: 0.4647
Testing b_a_r = 0.1699, c_a_r = 0.4375, f_a_r = 0.5832




----------------------> 0.297280701754386
Evaluated Accuracy: 0.2973
Testing b_a_r = 0.9402, c_a_r = 0.0992, f_a_r = 0.7832




----------------------> 0.5351754385964912
Evaluated Accuracy: 0.5352
Testing b_a_r = 0.9612, c_a_r = 0.0419, f_a_r = 0.4026




----------------------> 0.5498245614035088
Evaluated Accuracy: 0.5498
Testing b_a_r = 0.7127, c_a_r = 0.5840, f_a_r = 0.0001




----------------------> 0.33903508771929824
Evaluated Accuracy: 0.3390
Testing b_a_r = 0.0339, c_a_r = 0.2887, f_a_r = 0.5786




----------------------> 0.3950877192982455
Evaluated Accuracy: 0.3951
Testing b_a_r = 0.0519, c_a_r = 0.1167, f_a_r = 0.4193




----------------------> 0.5209649122807017
Evaluated Accuracy: 0.5210
Testing b_a_r = 0.0514, c_a_r = 0.0302, f_a_r = 0.1895




----------------------> 0.5376315789473683
Evaluated Accuracy: 0.5376
Testing b_a_r = 0.5424, c_a_r = 0.8691, f_a_r = 0.0003




----------------------> 0.18464912280701753
Evaluated Accuracy: 0.1846
Testing b_a_r = 0.8771, c_a_r = 0.0427, f_a_r = 0.8479




----------------------> 0.5285087719298246
Evaluated Accuracy: 0.5285
Testing b_a_r = 0.0155, c_a_r = 0.0047, f_a_r = 0.5885




----------------------> 0.5560526315789474
Evaluated Accuracy: 0.5561
Testing b_a_r = 0.6121, c_a_r = 0.9993, f_a_r = 0.5275




----------------------> 0.14833333333333334
Evaluated Accuracy: 0.1483
Testing b_a_r = 0.9300, c_a_r = 0.0656, f_a_r = 0.6453




----------------------> 0.5585964912280702
Evaluated Accuracy: 0.5586
Testing b_a_r = 0.9153, c_a_r = 0.1425, f_a_r = 0.1657




----------------------> 0.49236842105263157
Evaluated Accuracy: 0.4924
Testing b_a_r = 0.7933, c_a_r = 0.0024, f_a_r = 0.2267




----------------------> 0.5428070175438596
Evaluated Accuracy: 0.5428
Testing b_a_r = 0.9484, c_a_r = 0.0015, f_a_r = 0.6404




----------------------> 0.5606140350877192
Evaluated Accuracy: 0.5606
Testing b_a_r = 0.9070, c_a_r = 0.0744, f_a_r = 0.4981




----------------------> 0.5508771929824562
Evaluated Accuracy: 0.5509
Testing b_a_r = 0.0091, c_a_r = 0.0877, f_a_r = 0.6657




----------------------> 0.5469298245614035
Evaluated Accuracy: 0.5469
Testing b_a_r = 0.8263, c_a_r = 0.0343, f_a_r = 0.6530




----------------------> 0.5603508771929825
Evaluated Accuracy: 0.5604
Testing b_a_r = 0.0239, c_a_r = 0.5491, f_a_r = 0.3000




----------------------> 0.34438596491228074
Evaluated Accuracy: 0.3444

✅ Optimal Values:
   - b_a_r: 0.9484
   - c_a_r: 0.0015
   - f_a_r: 0.6404
📈 Highest Accuracy Achieved: 0.5606


In [55]:
# --- Baseline: primary model, weights read from model_own_path.pt ---
# NOTE(review): evaluate2 is redefined several times in this notebook; the
# calls below use whichever definition ran last (presumably the 1-shot variant,
# given the "8_w_1_s_f" output prefixes — confirm).
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
# Replace the `fc` attribute with Flatten (presumably drops the classification
# head so the backbone emits feature vectors — confirm against ResNet18WithDropout).
convolutional_network_with_dropout.fc = nn.Flatten()

xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_mapii/result_ressnet/base/8_w_1_s_f', test_loader, M3, criterion)
# --- Secondary model M2, weights read from model_mu_path.pt ---
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/secondary/8_w_1_s_f", test_loader, M2, criterion)
# Update Model M3's weights using a weighted combination of Model M3 and Model M2.
# The ratios are the optimum found by the Bayesian search cell (res.x).
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Re-evaluate M3 after update_model_weights2 has been applied to it; the
# return value of that call is not used, so M3 itself carries the blend.
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/amul/8_w_1_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.5348 ± 0.0390  0.5640 ± 0.0413  0.5348 ± 0.0390




---Accu-----pre----rec---------> 0.4782 ± 0.0603  0.5244 ± 0.0671  0.4782 ± 0.0603
---Accu-----pre----rec---------> 0.5606 ± 0.0388  0.5865 ± 0.0467  0.5606 ± 0.0388


In [56]:
# Stage the two checkpoints for the 5-shot experiment under the fixed local
# file names that the evaluation cells load.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



# Remove the files staged by the previous experiment (IPython shell escapes).
!rm model_mu_path.pt
!rm model_own_path.pt

# Primary model: backbone with its `fc` attribute replaced by Flatten.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
# The checkpoint is deserialized on CPU (map_location) before being loaded into the model.
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


# Secondary model.
# NOTE: despite its name, model_own_1_shot is loaded from the
# "model_B_mapi_5-shot_res.pth" checkpoint in this cell.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/mapi_sce/ressecondary/model_5/model_B_mapi_5-shot_res.pth',map_location=torch.device('cpu')))

# Secondary weights -> model_mu_path.pt, primary weights -> model_own_path.pt.
torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [57]:
def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss(), *,
              episode_classes=None, episode_supports=None):
    """Evaluate ``model`` on the fixed episodes (5-shot by default) and print mean ± std metrics.

    For every episode index ``i`` the support and query sets are fetched with
    ``data_loader.__getitem__(episode_classes[i], episode_supports[i])``. The
    model scores each query group, the arg-max along ``dim=1`` is taken as the
    prediction, and a confusion matrix of actual vs. predicted labels is passed
    to ``calculate_metrics_get_per`` (defined elsewhere in the notebook).
    Tensors are moved to the notebook-level ``device``.

    Args:
        fname: Forwarded as ``flnm=`` to ``calculate_metrics_get_per``
            (presumably an output path prefix — confirm against that helper).
        data_loader: Object whose ``__getitem__(classes, supports)`` returns
            ``(support_set, query_set)``; both are iterables of
            ``(images, labels)`` pairs where ``images`` is a sequence of
            tensors accepted by ``torch.stack``.
        model: Callable as ``model(support_images, support_labels, query_images)``
            returning one score tensor per query group; must provide ``eval()``.
        criterion: Unused; kept for backward compatibility.
        episode_classes: Per-episode first argument for ``__getitem__``.
            Defaults to the notebook global ``clss_8``.
        episode_supports: Per-episode second argument for ``__getitem__``.
            Defaults to the notebook global ``clss_support_imagesss_5_shot``.

    Returns:
        None. Prints one summary line with accuracy, precision and recall as
        ``mean ± sample std`` (``ddof=1``) over the episodes.

    Raises:
        ZeroDivisionError: if an episode yields no query predictions.
    """
    # Fall back to the notebook globals so existing positional calls keep working.
    if episode_classes is None:
        episode_classes = clss_8
    if episode_supports is None:
        episode_supports = clss_support_imagesss_5_shot

    model.eval()
    accuracies = []
    precision = []
    recall = []

    with torch.no_grad():
        for episode in range(len(episode_classes)):
            classes = episode_classes[episode]
            supports = episode_supports[episode]
            support_set, query_set = data_loader.__getitem__(classes, supports)

            # Text labels handed to the metrics helper: the class list with
            # commas and single quotes blanked out, and the support list as-is.
            clss = str(classes).replace(',', ' ').replace('\'', ' ')
            msg = str(supports)

            # Stack each group's images into one batch; a support group keeps
            # only its first label, a query group keeps the full label list.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            all_predicted_labels = []
            all_actual_labels = []
            for group_index, group_scores in enumerate(classification_scores):
                predicted = torch.argmax(group_scores, dim=1).tolist()
                actual = query_labels[group_index]
                all_predicted_labels.extend(predicted)
                all_actual_labels.extend(actual)
                correct += sum(1 for i, label in enumerate(predicted) if label == actual[i])
                total += len(predicted)

            accuracy = correct / total
            accuracies.append(accuracy)

            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # `acciuracy` is the helper's keyword as called throughout the notebook.
            precision_overall, recall_overall, _ = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=clss, flnm=fname
            )
            precision.append(precision_overall)
            recall.append(recall_overall)

    print(f"---Accu-----pre----rec---------> {np.mean(accuracies):.4f} ± {np.std(accuracies, ddof=1):.4f}  "
      f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
      f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")
# %%%%%%

def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss(), *,
              episode_classes=None, episode_supports=None):
    """Return the mean per-episode query accuracy of ``model`` (5-shot episodes by default).

    For every episode index ``i`` the support and query sets are fetched with
    ``data_loader.__getitem__(episode_classes[i], episode_supports[i])``. The
    model scores each query group against the support set, and the arg-max of
    the scores along ``dim=1`` is compared with that group's labels. Tensors
    are moved to the notebook-level ``device``.

    Args:
        fname: Unused; kept so the signature matches ``evaluate2``.
        data_loader: Object whose ``__getitem__(classes, supports)`` returns
            ``(support_set, query_set)``; both are iterables of
            ``(images, labels)`` pairs where ``images`` is a sequence of
            tensors accepted by ``torch.stack``.
        model: Callable as ``model(support_images, support_labels, query_images)``
            returning one score tensor per query group; must provide ``eval()``.
        criterion: Unused; kept for backward compatibility.
        episode_classes: Per-episode first argument for ``__getitem__``.
            Defaults to the notebook global ``clss_8``.
        episode_supports: Per-episode second argument for ``__getitem__``.
            Defaults to the notebook global ``clss_support_imagesss_5_shot``.

    Returns:
        float: arithmetic mean of the per-episode accuracies (also printed).

    Raises:
        ZeroDivisionError: if there are no episodes, or an episode yields no
            query predictions.
    """
    # Fall back to the notebook globals so existing positional calls keep working.
    if episode_classes is None:
        episode_classes = clss_8
    if episode_supports is None:
        episode_supports = clss_support_imagesss_5_shot

    model.eval()
    accuracy_sum = 0
    episode_count = 0
    with torch.no_grad():
        for episode in range(len(episode_classes)):
            support_set, query_set = data_loader.__getitem__(
                episode_classes[episode], episode_supports[episode]
            )

            # Stack each group's images into one batch; a support group keeps
            # only its first label, a query group keeps the full label list.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            for group_index, group_scores in enumerate(classification_scores):
                predicted = torch.argmax(group_scores, dim=1).tolist()
                actual = query_labels[group_index]
                correct += sum(1 for i, label in enumerate(predicted) if label == actual[i])
                total += len(predicted)

            accuracy_sum += correct / total
            episode_count += 1

    print("---------------------->",accuracy_sum/episode_count)
    return accuracy_sum/episode_count

In [58]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Search space for the three blending ratios, each a real number in [0, 1].
# The order of the dimensions is the order of `res.x` after the search, which
# the unpacking `optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x` relies on.
# A later cell passes the tuned values to update_model_weights2 as
# bias_ratio (b_a_r), conv_ratio (c_a_r) and fc_ratio (f_a_r).
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Objective handed to gp_minimize; use_named_args turns the sampled point
# (a list ordered like search_space) into the keyword arguments below.
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """Score one candidate (b_a_r, c_a_r, f_a_r) triple for the optimizer.

    Delegates to ``get_model_accuracy`` (defined elsewhere in the notebook),
    echoes the value it returns, and hands back its negation because
    ``gp_minimize`` minimizes the objective.
    """
    measured = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {measured:.4f}")
    return -measured

# Run the Gaussian-process search over the three ratios.
# NOTE(review): get_model_accuracy is defined elsewhere; the best accuracy
# printed below equals the accuracy later reported on test_loader, so the
# ratios appear to be tuned on the same episodes that are reported — confirm
# and, if so, tune on a held-out split instead.
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Total number of objective evaluations
    n_initial_points=5,       # Evaluations at initial points before the GP surrogate is fitted
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Seeds the optimizer's own sampling
    n_jobs=-1,                # Cores for optimizing the acquisition function (lbfgs); objective calls still run one at a time
)

# res.x is ordered like search_space; res.fun is the minimum of the negated
# accuracy, so flip the sign back.
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.5633699633699634
Evaluated Accuracy: 0.5634
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.426098901098901
Evaluated Accuracy: 0.4261
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.4193223443223443
Evaluated Accuracy: 0.4193
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.6838827838827838
Evaluated Accuracy: 0.6839
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.6538461538461539
Evaluated Accuracy: 0.6538
Testing b_a_r = 0.0250, c_a_r = 1.0000, f_a_r = 0.4630




----------------------> 0.31346153846153846
Evaluated Accuracy: 0.3135
Testing b_a_r = 0.9430, c_a_r = 0.6980, f_a_r = 0.9507




----------------------> 0.34212454212454213
Evaluated Accuracy: 0.3421
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.6737179487179488
Evaluated Accuracy: 0.6737
Testing b_a_r = 0.0000, c_a_r = 0.0594, f_a_r = 0.0000




----------------------> 0.6696886446886448
Evaluated Accuracy: 0.6697
Testing b_a_r = 0.9570, c_a_r = 0.0561, f_a_r = 0.0126




----------------------> 0.6751831501831501
Evaluated Accuracy: 0.6752
Testing b_a_r = 0.0270, c_a_r = 0.0680, f_a_r = 0.9711




----------------------> 0.6239010989010989
Evaluated Accuracy: 0.6239
Testing b_a_r = 1.0000, c_a_r = 0.0046, f_a_r = 0.4483




----------------------> 0.6870879120879121
Evaluated Accuracy: 0.6871
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.6770146520146519
Evaluated Accuracy: 0.6770
Testing b_a_r = 1.0000, c_a_r = 0.0598, f_a_r = 0.4496




----------------------> 0.6839743589743591
Evaluated Accuracy: 0.6840
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.6769




----------------------> 0.6976190476190477
Evaluated Accuracy: 0.6976
Testing b_a_r = 0.7546, c_a_r = 0.7363, f_a_r = 0.0000




----------------------> 0.28608058608058606
Evaluated Accuracy: 0.2861
Testing b_a_r = 0.8735, c_a_r = 0.4286, f_a_r = 1.0000




----------------------> 0.44184981684981683
Evaluated Accuracy: 0.4418
Testing b_a_r = 0.1027, c_a_r = 0.0226, f_a_r = 0.6130




----------------------> 0.692032967032967
Evaluated Accuracy: 0.6920
Testing b_a_r = 0.9514, c_a_r = 0.0187, f_a_r = 0.1875




----------------------> 0.6783882783882783
Evaluated Accuracy: 0.6784
Testing b_a_r = 0.1141, c_a_r = 0.9990, f_a_r = 0.9989




----------------------> 0.3170329670329671
Evaluated Accuracy: 0.3170
Testing b_a_r = 0.9253, c_a_r = 0.1248, f_a_r = 0.2144




----------------------> 0.6248168498168497
Evaluated Accuracy: 0.6248
Testing b_a_r = 0.1054, c_a_r = 0.0000, f_a_r = 0.0081




----------------------> 0.6741758241758242
Evaluated Accuracy: 0.6742
Testing b_a_r = 0.9731, c_a_r = 0.0001, f_a_r = 0.5907




----------------------> 0.6948717948717948
Evaluated Accuracy: 0.6949
Testing b_a_r = 0.9611, c_a_r = 0.0252, f_a_r = 0.7105




----------------------> 0.6937728937728938
Evaluated Accuracy: 0.6938
Testing b_a_r = 0.7342, c_a_r = 0.9968, f_a_r = 0.0033




----------------------> 0.2623626373626374
Evaluated Accuracy: 0.2624
Testing b_a_r = 0.0510, c_a_r = 0.0012, f_a_r = 0.7869




----------------------> 0.6803113553113553
Evaluated Accuracy: 0.6803
Testing b_a_r = 0.0219, c_a_r = 0.5571, f_a_r = 0.5672




----------------------> 0.36959706959706956
Evaluated Accuracy: 0.3696
Testing b_a_r = 0.8358, c_a_r = 0.1511, f_a_r = 0.0048




----------------------> 0.5867216117216117
Evaluated Accuracy: 0.5867
Testing b_a_r = 0.9882, c_a_r = 0.2726, f_a_r = 1.0000




----------------------> 0.4372710622710623
Evaluated Accuracy: 0.4373
Testing b_a_r = 0.8764, c_a_r = 0.0937, f_a_r = 0.5776




----------------------> 0.6756410256410257
Evaluated Accuracy: 0.6756
Testing b_a_r = 0.9795, c_a_r = 0.0410, f_a_r = 0.6012




----------------------> 0.6945970695970696
Evaluated Accuracy: 0.6946
Testing b_a_r = 0.0262, c_a_r = 0.8438, f_a_r = 0.6623




----------------------> 0.32454212454212455
Evaluated Accuracy: 0.3245
Testing b_a_r = 0.2662, c_a_r = 0.5779, f_a_r = 0.0012




----------------------> 0.322985347985348
Evaluated Accuracy: 0.3230
Testing b_a_r = 0.9422, c_a_r = 0.3748, f_a_r = 0.5915




----------------------> 0.4926739926739927
Evaluated Accuracy: 0.4927
Testing b_a_r = 0.9441, c_a_r = 0.2063, f_a_r = 0.4298




----------------------> 0.5402930402930401
Evaluated Accuracy: 0.5403
Testing b_a_r = 0.7680, c_a_r = 0.5510, f_a_r = 0.9986




----------------------> 0.3716117216117217
Evaluated Accuracy: 0.3716
Testing b_a_r = 0.0994, c_a_r = 0.0522, f_a_r = 0.2436




----------------------> 0.6716117216117217
Evaluated Accuracy: 0.6716
Testing b_a_r = 0.8195, c_a_r = 0.8640, f_a_r = 0.9997




----------------------> 0.3232600732600733
Evaluated Accuracy: 0.3233
Testing b_a_r = 0.1049, c_a_r = 0.0005, f_a_r = 0.3424




----------------------> 0.6787545787545786
Evaluated Accuracy: 0.6788
Testing b_a_r = 0.0196, c_a_r = 0.0667, f_a_r = 0.5455




----------------------> 0.6817765567765568
Evaluated Accuracy: 0.6818
Testing b_a_r = 0.2177, c_a_r = 0.8743, f_a_r = 0.0017




----------------------> 0.27097069597069595
Evaluated Accuracy: 0.2710
Testing b_a_r = 0.9361, c_a_r = 0.0999, f_a_r = 0.7515




----------------------> 0.6608058608058608
Evaluated Accuracy: 0.6608
Testing b_a_r = 0.8913, c_a_r = 0.1336, f_a_r = 0.9968




----------------------> 0.5963369963369963
Evaluated Accuracy: 0.5963
Testing b_a_r = 0.0120, c_a_r = 0.0308, f_a_r = 0.4607




----------------------> 0.6856227106227106
Evaluated Accuracy: 0.6856
Testing b_a_r = 0.0163, c_a_r = 0.0007, f_a_r = 0.5384




----------------------> 0.6923992673992673
Evaluated Accuracy: 0.6924
Testing b_a_r = 0.9761, c_a_r = 0.0344, f_a_r = 0.8604




----------------------> 0.6646520146520147
Evaluated Accuracy: 0.6647
Testing b_a_r = 0.9463, c_a_r = 0.0004, f_a_r = 0.6586




----------------------> 0.6974358974358974
Evaluated Accuracy: 0.6974
Testing b_a_r = 0.9441, c_a_r = 0.0217, f_a_r = 0.5360




----------------------> 0.6927655677655677
Evaluated Accuracy: 0.6928
Testing b_a_r = 0.0081, c_a_r = 0.1178, f_a_r = 0.4776




----------------------> 0.6437728937728937
Evaluated Accuracy: 0.6438
Testing b_a_r = 0.9788, c_a_r = 0.6988, f_a_r = 0.4406




----------------------> 0.33452380952380956
Evaluated Accuracy: 0.3345

✅ Optimal Values:
   - b_a_r: 1.0000
   - c_a_r: 0.0000
   - f_a_r: 0.6769
📈 Highest Accuracy Achieved: 0.6976


In [59]:
# --- Baseline: primary model, weights read from model_own_path.pt ---
# The evaluate2 calls below use the definition from the preceding cell, which
# reads the 5-shot support lists (clss_support_imagesss_5_shot).
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
# Replace the `fc` attribute with Flatten (presumably drops the classification
# head so the backbone emits feature vectors — confirm against ResNet18WithDropout).
convolutional_network_with_dropout.fc = nn.Flatten()

xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_mapii/result_ressnet/base/8_w_5_s_f', test_loader, M3, criterion)
# --- Secondary model M2, weights read from model_mu_path.pt ---
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/secondary/8_w_5_s_f", test_loader, M2, criterion)
# Update Model M3's weights using a weighted combination of Model M3 and Model M2.
# The ratios are the optimum found by the Bayesian search cell (res.x).
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Re-evaluate M3 after update_model_weights2 has been applied to it; the
# return value of that call is not used, so M3 itself carries the blend.
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/amul/8_w_5_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.6737 ± 0.0556  0.6849 ± 0.0563  0.6737 ± 0.0556




---Accu-----pre----rec---------> 0.5032 ± 0.0551  0.5033 ± 0.0567  0.5032 ± 0.0551
---Accu-----pre----rec---------> 0.6976 ± 0.0490  0.7105 ± 0.0497  0.6976 ± 0.0490


In [60]:
# Stage the two checkpoints for the 10-shot experiment under the fixed local
# file names that the evaluation cells load.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



# Remove the files staged by the previous experiment (IPython shell escapes).
!rm model_mu_path.pt
!rm model_own_path.pt

# Primary model: backbone with its `fc` attribute replaced by Flatten.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
# The checkpoint is deserialized on CPU (map_location) before being loaded into the model.
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


# Secondary model.
# NOTE: despite its name, model_own_1_shot is loaded from the
# "model_B_mapi_10-shot_res.pth" checkpoint in this cell.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/mapi_sce/ressecondary/model_10/model_B_mapi_10-shot_res.pth',map_location=torch.device('cpu')))

# Secondary weights -> model_mu_path.pt, primary weights -> model_own_path.pt.
torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [61]:
def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss(), *,
              episode_classes=None, episode_supports=None):
    """Evaluate ``model`` on the fixed episodes (10-shot by default) and print mean ± std metrics.

    For every episode index ``i`` the support and query sets are fetched with
    ``data_loader.__getitem__(episode_classes[i], episode_supports[i])``. The
    model scores each query group, the arg-max along ``dim=1`` is taken as the
    prediction, and a confusion matrix of actual vs. predicted labels is passed
    to ``calculate_metrics_get_per`` (defined elsewhere in the notebook).
    Tensors are moved to the notebook-level ``device``.

    Args:
        fname: Forwarded as ``flnm=`` to ``calculate_metrics_get_per``
            (presumably an output path prefix — confirm against that helper).
        data_loader: Object whose ``__getitem__(classes, supports)`` returns
            ``(support_set, query_set)``; both are iterables of
            ``(images, labels)`` pairs where ``images`` is a sequence of
            tensors accepted by ``torch.stack``.
        model: Callable as ``model(support_images, support_labels, query_images)``
            returning one score tensor per query group; must provide ``eval()``.
        criterion: Unused; kept for backward compatibility.
        episode_classes: Per-episode first argument for ``__getitem__``.
            Defaults to the notebook global ``clss_8``.
        episode_supports: Per-episode second argument for ``__getitem__``.
            Defaults to the notebook global ``clss_support_imagesss_10_shot``.

    Returns:
        None. Prints one summary line with accuracy, precision and recall as
        ``mean ± sample std`` (``ddof=1``) over the episodes.

    Raises:
        ZeroDivisionError: if an episode yields no query predictions.
    """
    # Fall back to the notebook globals so existing positional calls keep working.
    if episode_classes is None:
        episode_classes = clss_8
    if episode_supports is None:
        episode_supports = clss_support_imagesss_10_shot

    model.eval()
    accuracies = []
    precision = []
    recall = []

    with torch.no_grad():
        for episode in range(len(episode_classes)):
            classes = episode_classes[episode]
            supports = episode_supports[episode]
            support_set, query_set = data_loader.__getitem__(classes, supports)

            # Text labels handed to the metrics helper: the class list with
            # commas and single quotes blanked out, and the support list as-is.
            clss = str(classes).replace(',', ' ').replace('\'', ' ')
            msg = str(supports)

            # Stack each group's images into one batch; a support group keeps
            # only its first label, a query group keeps the full label list.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            all_predicted_labels = []
            all_actual_labels = []
            for group_index, group_scores in enumerate(classification_scores):
                predicted = torch.argmax(group_scores, dim=1).tolist()
                actual = query_labels[group_index]
                all_predicted_labels.extend(predicted)
                all_actual_labels.extend(actual)
                correct += sum(1 for i, label in enumerate(predicted) if label == actual[i])
                total += len(predicted)

            accuracy = correct / total
            accuracies.append(accuracy)

            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # `acciuracy` is the helper's keyword as called throughout the notebook.
            precision_overall, recall_overall, _ = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=clss, flnm=fname
            )
            precision.append(precision_overall)
            recall.append(recall_overall)

    print(f"---Accu-----pre----rec---------> {np.mean(accuracies):.4f} ± {np.std(accuracies, ddof=1):.4f}  "
      f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
      f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")
# %%%%%%

def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              class_sets=None, support_indices=None):
    """Return the mean per-episode query accuracy of ``model`` over fixed episodes.

    Episode ``i`` is fetched with
    ``data_loader.__getitem__(class_sets[i], support_indices[i])``. The model
    scores the episode's query images against its support images, the arg-max
    over dim 1 of every score tensor is compared with the query labels, and the
    fraction of matches is the episode accuracy. The unweighted mean of the
    episode accuracies is printed and returned.

    Args:
        fname: Unused here; kept so the signature matches ``evaluate2``.
        data_loader: Object whose ``__getitem__(classes, support_idx)`` returns
            ``(support_set, query_set)``. Both are iterables of
            ``(images, labels)`` pairs where ``images`` is accepted by
            ``torch.stack`` and ``labels`` is indexable.
        model: Called as ``model(support_images, support_labels, query_images)``
            and expected to return an indexable sequence of score tensors with
            the class axis on dim 1. It is switched to eval mode as a side effect.
        criterion: Unused here; kept for call compatibility.
        class_sets: Per-episode class lists. Defaults to the notebook global
            ``clss_8``.
        support_indices: Per-episode support selections, parallel to
            ``class_sets``. Defaults to the notebook global
            ``clss_support_imagesss_10_shot``. Passing both explicitly avoids
            depending on whichever cell last rebound those globals.

    Returns:
        float: mean of the per-episode accuracies.

    Raises:
        ZeroDivisionError: if there are no episodes, or an episode produces no
            query predictions (same as the original behaviour).
    """
    # Resolve defaults at call time, exactly like the original global lookups.
    if class_sets is None:
        class_sets = clss_8
    if support_indices is None:
        support_indices = clss_support_imagesss_10_shot

    model.eval()
    episode_accuracies = []
    with torch.no_grad():
        for episode in range(len(class_sets)):
            # Explicit __getitem__ call: the loader takes two arguments, whereas
            # `data_loader[a, b]` would hand it a single tuple.
            support_set, query_set = data_loader.__getitem__(
                class_sets[episode], support_indices[episode]
            )

            # Stack each group's images into one tensor on `device`; a support
            # group keeps only its first label, a query group keeps all labels.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            for group_index, group_scores in enumerate(classification_scores):
                predicted = torch.argmax(group_scores, dim=1).tolist()
                actual = query_labels[group_index]
                # Index `actual` by position so a too-short label list still
                # raises instead of being silently truncated.
                correct += sum(1 for i, guess in enumerate(predicted) if guess == actual[i])
                total += len(predicted)

            episode_accuracies.append(correct / total)

    mean_accuracy = sum(episode_accuracies) / len(episode_accuracies)
    print("---------------------->", mean_accuracy)
    return mean_accuracy

In [62]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


# One continuous dimension on [0, 1] per ratio. The order b_a_r, c_a_r, f_a_r
# is also the order of the values in the optimiser's result vector.
search_space = [
    Real(0.0, 1.0, name=ratio_name)
    for ratio_name in ("b_a_r", "c_a_r", "f_a_r")
]

# Objective handed to gp_minimize. The decorator maps each sampled point onto
# the named keyword arguments declared in search_space.
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """Score one (b_a_r, c_a_r, f_a_r) candidate for the minimiser.

    Forwards the three ratios to ``get_model_accuracy``, logs the accuracy it
    reports, and returns that accuracy with the sign flipped because
    ``gp_minimize`` searches for a minimum.
    """
    ratios = (b_a_r, c_a_r, f_a_r)
    accuracy = get_model_accuracy(*ratios)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # progress log, one line per evaluation

    negated_accuracy = -accuracy
    return negated_accuracy

# Gaussian-process Bayesian optimisation over the three ratios.
search_iteration=80
res = gp_minimize(
    objective,
    search_space,
    n_calls=search_iteration,  # total objective evaluations, initial points included
    n_initial_points=5,        # points evaluated before the surrogate model is used
    acq_func="EI",             # Expected Improvement
    random_state=42,           # seeds skopt's own sampling
    n_jobs=-1,                 # all cores for skopt's acquisition optimiser only;
                               # the objective is still evaluated one point at a time
)

# res.x holds the best point in search_space order (b_a_r, c_a_r, f_a_r).
# res.fun is the minimised value, i.e. the negated accuracy, so flip it back.
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print("\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.6320736434108528
Evaluated Accuracy: 0.6321
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.26889534883720934
Evaluated Accuracy: 0.2689
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.3439922480620155
Evaluated Accuracy: 0.3440
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.7171511627906977
Evaluated Accuracy: 0.7172
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.6761627906976744
Evaluated Accuracy: 0.6762
Testing b_a_r = 0.0250, c_a_r = 1.0000, f_a_r = 0.4630




----------------------> 0.20920542635658915
Evaluated Accuracy: 0.2092
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.6807




----------------------> 0.7307170542635659
Evaluated Accuracy: 0.7307
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.7765




----------------------> 0.7161821705426357
Evaluated Accuracy: 0.7162
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.1948




----------------------> 0.7061046511627908
Evaluated Accuracy: 0.7061
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.4942




----------------------> 0.7289728682170543
Evaluated Accuracy: 0.7290
Testing b_a_r = 0.1080, c_a_r = 0.6058, f_a_r = 1.0000




----------------------> 0.21017441860465116
Evaluated Accuracy: 0.2102
Testing b_a_r = 0.0000, c_a_r = 0.0451, f_a_r = 0.4655




----------------------> 0.7223837209302328
Evaluated Accuracy: 0.7224
Testing b_a_r = 0.8923, c_a_r = 0.0712, f_a_r = 0.0058




----------------------> 0.7060077519379846
Evaluated Accuracy: 0.7060
Testing b_a_r = 0.0652, c_a_r = 0.1183, f_a_r = 0.3546




----------------------> 0.6954457364341086
Evaluated Accuracy: 0.6954
Testing b_a_r = 0.1510, c_a_r = 0.7793, f_a_r = 0.0012




----------------------> 0.21346899224806204
Evaluated Accuracy: 0.2135
Testing b_a_r = 0.2203, c_a_r = 0.1103, f_a_r = 0.9931




----------------------> 0.6683139534883722
Evaluated Accuracy: 0.6683
Testing b_a_r = 0.3346, c_a_r = 0.1436, f_a_r = 0.0005




----------------------> 0.6665697674418604
Evaluated Accuracy: 0.6666
Testing b_a_r = 0.0729, c_a_r = 0.0000, f_a_r = 0.0123




----------------------> 0.7078488372093024
Evaluated Accuracy: 0.7078
Testing b_a_r = 0.0492, c_a_r = 0.0541, f_a_r = 0.0905




----------------------> 0.7090116279069767
Evaluated Accuracy: 0.7090
Testing b_a_r = 0.4386, c_a_r = 0.3428, f_a_r = 0.9983




----------------------> 0.3608527131782946
Evaluated Accuracy: 0.3609
Testing b_a_r = 0.7402, c_a_r = 0.8760, f_a_r = 0.9974




----------------------> 0.20804263565891476
Evaluated Accuracy: 0.2080
Testing b_a_r = 0.0544, c_a_r = 0.0441, f_a_r = 0.9940




----------------------> 0.6730620155038758
Evaluated Accuracy: 0.6731
Testing b_a_r = 0.0479, c_a_r = 0.0002, f_a_r = 0.5427




----------------------> 0.7352713178294574
Evaluated Accuracy: 0.7353
Testing b_a_r = 0.0109, c_a_r = 0.1082, f_a_r = 0.6766




----------------------> 0.7080426356589148
Evaluated Accuracy: 0.7080
Testing b_a_r = 0.7342, c_a_r = 0.9968, f_a_r = 0.0033




----------------------> 0.2041666666666667
Evaluated Accuracy: 0.2042
Testing b_a_r = 0.9459, c_a_r = 0.0170, f_a_r = 0.5886




----------------------> 0.7354651162790696
Evaluated Accuracy: 0.7355
Testing b_a_r = 0.0988, c_a_r = 0.0227, f_a_r = 0.6008




----------------------> 0.733720930232558
Evaluated Accuracy: 0.7337
Testing b_a_r = 0.9808, c_a_r = 0.0668, f_a_r = 0.5532




----------------------> 0.7237403100775194
Evaluated Accuracy: 0.7237
Testing b_a_r = 0.1945, c_a_r = 0.9965, f_a_r = 0.9927




----------------------> 0.2085271317829457
Evaluated Accuracy: 0.2085
Testing b_a_r = 0.9154, c_a_r = 0.0015, f_a_r = 0.5983




----------------------> 0.7363372093023256
Evaluated Accuracy: 0.7363
Testing b_a_r = 0.0934, c_a_r = 0.6481, f_a_r = 0.4202




----------------------> 0.21298449612403106
Evaluated Accuracy: 0.2130
Testing b_a_r = 0.9448, c_a_r = 0.0626, f_a_r = 0.2830




----------------------> 0.7085271317829458
Evaluated Accuracy: 0.7085
Testing b_a_r = 0.9899, c_a_r = 0.1937, f_a_r = 0.4146




----------------------> 0.6430232558139534
Evaluated Accuracy: 0.6430
Testing b_a_r = 0.6198, c_a_r = 0.5964, f_a_r = 0.0010




----------------------> 0.22693798449612404
Evaluated Accuracy: 0.2269
Testing b_a_r = 0.8698, c_a_r = 0.1064, f_a_r = 0.1519




----------------------> 0.6966085271317828
Evaluated Accuracy: 0.6966
Testing b_a_r = 0.8552, c_a_r = 0.2076, f_a_r = 0.9980




----------------------> 0.5817829457364342
Evaluated Accuracy: 0.5818
Testing b_a_r = 0.4197, c_a_r = 0.4596, f_a_r = 0.6094




----------------------> 0.24137596899224809
Evaluated Accuracy: 0.2414
Testing b_a_r = 0.0632, c_a_r = 0.0059, f_a_r = 0.3824




----------------------> 0.719767441860465
Evaluated Accuracy: 0.7198
Testing b_a_r = 0.9363, c_a_r = 0.0319, f_a_r = 0.0021




----------------------> 0.7092054263565892
Evaluated Accuracy: 0.7092
Testing b_a_r = 0.9444, c_a_r = 0.1252, f_a_r = 0.5135




----------------------> 0.6999031007751938
Evaluated Accuracy: 0.6999
Testing b_a_r = 0.9912, c_a_r = 0.2323, f_a_r = 0.0081




----------------------> 0.5677325581395348
Evaluated Accuracy: 0.5677
Testing b_a_r = 0.9166, c_a_r = 0.0222, f_a_r = 0.5418




----------------------> 0.7320736434108528
Evaluated Accuracy: 0.7321
Testing b_a_r = 0.6777, c_a_r = 0.7459, f_a_r = 0.9954




----------------------> 0.20872093023255814
Evaluated Accuracy: 0.2087
Testing b_a_r = 0.0189, c_a_r = 0.0011, f_a_r = 0.5610




----------------------> 0.7367248062015505
Evaluated Accuracy: 0.7367
Testing b_a_r = 0.0884, c_a_r = 0.0241, f_a_r = 0.6641




----------------------> 0.7308139534883722
Evaluated Accuracy: 0.7308
Testing b_a_r = 0.9585, c_a_r = 0.8587, f_a_r = 0.4424




----------------------> 0.20872093023255814
Evaluated Accuracy: 0.2087
Testing b_a_r = 0.0975, c_a_r = 0.2701, f_a_r = 0.6229




----------------------> 0.4967054263565892
Evaluated Accuracy: 0.4967
Testing b_a_r = 0.6874, c_a_r = 0.1670, f_a_r = 0.1771




----------------------> 0.6440891472868217
Evaluated Accuracy: 0.6441
Testing b_a_r = 0.9981, c_a_r = 0.0912, f_a_r = 0.8552




----------------------> 0.6758720930232557
Evaluated Accuracy: 0.6759
Testing b_a_r = 0.0243, c_a_r = 0.0516, f_a_r = 0.6315




----------------------> 0.7248062015503877
Evaluated Accuracy: 0.7248
Testing b_a_r = 0.8629, c_a_r = 0.0066, f_a_r = 0.6215




----------------------> 0.7350775193798449
Evaluated Accuracy: 0.7351
Testing b_a_r = 0.8080, c_a_r = 0.0001, f_a_r = 0.6218




----------------------> 0.7376937984496125
Evaluated Accuracy: 0.7377
Testing b_a_r = 0.5135, c_a_r = 0.0120, f_a_r = 0.5693




----------------------> 0.7343992248062015
Evaluated Accuracy: 0.7344
Testing b_a_r = 0.0378, c_a_r = 0.1726, f_a_r = 0.5979




----------------------> 0.6705426356589147
Evaluated Accuracy: 0.6705
Testing b_a_r = 0.8438, c_a_r = 0.0776, f_a_r = 0.4370




----------------------> 0.7170542635658914
Evaluated Accuracy: 0.7171
Testing b_a_r = 0.0102, c_a_r = 0.0008, f_a_r = 0.1679




----------------------> 0.7092054263565893
Evaluated Accuracy: 0.7092
Testing b_a_r = 0.0172, c_a_r = 0.0730, f_a_r = 0.0004




----------------------> 0.7066860465116279
Evaluated Accuracy: 0.7067
Testing b_a_r = 0.0602, c_a_r = 0.0009, f_a_r = 0.8879




----------------------> 0.686046511627907
Evaluated Accuracy: 0.6860
Testing b_a_r = 0.9948, c_a_r = 0.0015, f_a_r = 0.6680




----------------------> 0.7350775193798449
Evaluated Accuracy: 0.7351
Testing b_a_r = 0.0758, c_a_r = 0.0222, f_a_r = 0.2677




----------------------> 0.7125000000000001
Evaluated Accuracy: 0.7125
Testing b_a_r = 0.0910, c_a_r = 0.0013, f_a_r = 0.5772




----------------------> 0.7364341085271316
Evaluated Accuracy: 0.7364
Testing b_a_r = 0.9609, c_a_r = 0.0285, f_a_r = 0.3901




----------------------> 0.7189922480620156
Evaluated Accuracy: 0.7190
Testing b_a_r = 0.0802, c_a_r = 0.7426, f_a_r = 0.6837




----------------------> 0.20833333333333334
Evaluated Accuracy: 0.2083
Testing b_a_r = 0.9269, c_a_r = 0.0246, f_a_r = 0.7042




----------------------> 0.7271317829457364
Evaluated Accuracy: 0.7271
Testing b_a_r = 0.8687, c_a_r = 0.4722, f_a_r = 0.9966




----------------------> 0.2336240310077519
Evaluated Accuracy: 0.2336
Testing b_a_r = 0.0067, c_a_r = 0.0849, f_a_r = 0.5406




----------------------> 0.7207364341085272
Evaluated Accuracy: 0.7207
Testing b_a_r = 0.0047, c_a_r = 0.0133, f_a_r = 0.4981




----------------------> 0.7287790697674417
Evaluated Accuracy: 0.7288
Testing b_a_r = 0.9722, c_a_r = 0.0770, f_a_r = 0.6475




----------------------> 0.7213178294573644
Evaluated Accuracy: 0.7213
Testing b_a_r = 0.0957, c_a_r = 0.8955, f_a_r = 0.0060




----------------------> 0.20901162790697672
Evaluated Accuracy: 0.2090
Testing b_a_r = 0.0234, c_a_r = 0.0010, f_a_r = 0.6246




----------------------> 0.7375000000000002
Evaluated Accuracy: 0.7375
Testing b_a_r = 0.9148, c_a_r = 0.0221, f_a_r = 0.1124




----------------------> 0.7080426356589147
Evaluated Accuracy: 0.7080
Testing b_a_r = 0.9876, c_a_r = 0.0006, f_a_r = 0.5291




----------------------> 0.7318798449612404
Evaluated Accuracy: 0.7319
Testing b_a_r = 0.8868, c_a_r = 0.0432, f_a_r = 0.5592




----------------------> 0.7315891472868217
Evaluated Accuracy: 0.7316
Testing b_a_r = 0.9408, c_a_r = 0.0024, f_a_r = 0.6184




----------------------> 0.7367248062015503
Evaluated Accuracy: 0.7367
Testing b_a_r = 0.7406, c_a_r = 0.0015, f_a_r = 0.3075




----------------------> 0.7127906976744185
Evaluated Accuracy: 0.7128
Testing b_a_r = 0.7743, c_a_r = 0.0001, f_a_r = 0.5984




----------------------> 0.736046511627907
Evaluated Accuracy: 0.7360
Testing b_a_r = 0.9739, c_a_r = 0.0281, f_a_r = 0.6210




----------------------> 0.7332364341085271
Evaluated Accuracy: 0.7332
Testing b_a_r = 0.3857, c_a_r = 0.0002, f_a_r = 0.6652




----------------------> 0.7335271317829457
Evaluated Accuracy: 0.7335
Testing b_a_r = 0.0466, c_a_r = 0.0248, f_a_r = 0.7781




----------------------> 0.7093023255813954
Evaluated Accuracy: 0.7093
Testing b_a_r = 0.0072, c_a_r = 0.0653, f_a_r = 0.1990




----------------------> 0.7085271317829458
Evaluated Accuracy: 0.7085

✅ Optimal Values:
   - b_a_r: 0.8080
   - c_a_r: 0.0001
   - f_a_r: 0.6218
📈 Highest Accuracy Achieved: 0.7377


In [63]:
# Compare three models on the same test loader: the base model (M3), the
# secondary model (M2), and M3 after M2 has been blended into it with the
# ratios found by the Bayesian search.
#
# NOTE(review): evaluate2 is defined more than once in this notebook; this cell
# uses whichever definition was executed last, so confirm the way/shot setting
# of the active definition matches the "8_w_10_s_f" output names below.

# --- Base model M3 ---
# Backbone built with pretrained=False, dr=0.0; its `fc` attribute is replaced by
# nn.Flatten (presumably so the backbone returns features rather than logits —
# confirm against ResNet18WithDropout).
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
# Checkpoint is read from the working directory; it must already have been
# written by a torch.save cell.
M3.load_state_dict(torch.load('model_own_path.pt'))
# First argument is evaluate2's `fname` (output name for the per-episode metrics).
evaluate2('/home/asufian/Desktop/output_mapii/result_ressnet/base/8_w_10_s_f', test_loader, M3, criterion)
# --- Secondary model M2 ---
# Same architecture, separate backbone instance, weights from 'model_mu_path.pt'.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/secondary/8_w_10_s_f", test_loader, M2, criterion)
# --- Merge ---
# Blend M2 into M3 using the ratios found by the search (optimal_* globals).
# NOTE(review): presumably this mutates M3 in place, since M3 is re-evaluated
# afterwards without reassignment — confirm in update_model_weights2.
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Evaluate the merged M3.
# NOTE(review): in the recorded run this final accuracy equals the best
# objective value reported by the search cell, which suggests the ratios were
# tuned on the same episodes they are reported on — confirm whether a held-out
# split is intended.
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/amul/8_w_10_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.7081 ± 0.0452  0.7141 ± 0.0447  0.7081 ± 0.0452




---Accu-----pre----rec---------> 0.4578 ± 0.0421  0.4678 ± 0.0463  0.4578 ± 0.0421
---Accu-----pre----rec---------> 0.7377 ± 0.0318  0.7442 ± 0.0288  0.7377 ± 0.0318


In [64]:
###############################################   12 way 1 shot

In [65]:
# clss_20=
# clss_1_shot_20_w=
# clss_5_shot_20_w=
# clss_10_shot_20_w=

In [66]:
clss_support_imagesss_10_shot=[[[21, 14, 23, 29, 67, 54, 10, 91, 61, 52], [40, 71, 6, 67, 65, 37, 0, 74, 11, 83], [89, 58, 33, 13, 23, 88, 77, 62, 18, 34], [6, 24, 74, 14, 37, 71, 72, 75, 91, 38], [10, 35, 15, 91, 71, 74, 86, 12, 80, 23], [93, 24, 81, 78, 83, 49, 15, 85, 10, 25], [32, 73, 51, 14, 42, 87, 74, 77, 70, 61], [36, 0, 34, 82, 39, 31, 1, 52, 27, 50], [89, 55, 56, 32, 80, 58, 53, 76, 63, 2], [22, 14, 80, 1, 6, 69, 50, 49, 72, 10], [37, 9, 46, 70, 12, 84, 23, 25, 75, 36], [7, 81, 29, 40, 46, 8, 37, 57, 52, 2]], [[66, 39, 90, 45, 52, 93, 58, 20, 12, 95], [63, 61, 91, 49, 5, 53, 56, 15, 57, 62], [93, 72, 86, 79, 27, 18, 8, 23, 22, 54], [67, 3, 58, 84, 72, 30, 1, 32, 94, 43], [15, 86, 95, 33, 54, 94, 28, 12, 93, 40], [94, 91, 48, 21, 86, 35, 76, 66, 32, 70], [46, 24, 71, 61, 64, 88, 63, 39, 70, 44], [34, 48, 62, 63, 64, 82, 57, 15, 47, 65], [58, 60, 72, 49, 52, 73, 71, 79, 33, 44], [39, 64, 28, 15, 20, 78, 50, 38, 59, 86], [26, 78, 67, 34, 92, 40, 48, 90, 53, 8], [39, 42, 64, 40, 32, 46, 95, 7, 85, 44]], [[64, 80, 7, 65, 40, 5, 13, 49, 58, 19], [71, 66, 67, 84, 64, 19, 5, 82, 33, 4], [38, 82, 13, 91, 14, 34, 67, 37, 58, 85], [26, 64, 89, 84, 95, 41, 66, 2, 77, 8], [17, 83, 77, 37, 39, 25, 86, 92, 79, 22], [52, 2, 60, 37, 73, 17, 89, 8, 4, 64], [71, 9, 59, 87, 49, 31, 4, 79, 22, 0], [21, 6, 79, 86, 74, 34, 19, 82, 62, 33], [45, 36, 26, 57, 22, 93, 3, 27, 2, 89], [4, 71, 51, 25, 84, 24, 10, 5, 85, 27], [29, 61, 79, 82, 18, 40, 70, 26, 32, 68], [56, 63, 64, 66, 70, 45, 87, 48, 88, 18]], [[64, 80, 50, 62, 0, 8, 58, 34, 91, 52], [39, 8, 57, 52, 44, 91, 80, 69, 60, 81], [36, 27, 19, 54, 11, 95, 71, 80, 9, 51], [63, 70, 2, 59, 10, 89, 43, 37, 27, 21], [26, 50, 86, 12, 47, 30, 52, 32, 7, 67], [17, 29, 94, 6, 50, 28, 43, 25, 56, 70], [31, 39, 56, 32, 43, 76, 71, 19, 49, 27], [2, 3, 64, 92, 57, 46, 6, 67, 78, 10], [29, 84, 78, 36, 59, 14, 41, 42, 0, 90], [63, 2, 13, 39, 16, 1, 14, 31, 77, 44], [9, 14, 6, 60, 95, 61, 19, 40, 20, 39], [7, 3, 68, 59, 11, 5, 61, 21, 1, 
14]], [[48, 20, 92, 33, 18, 91, 35, 52, 41, 86], [59, 8, 20, 88, 73, 72, 94, 11, 28, 66], [62, 28, 50, 93, 79, 95, 88, 10, 14, 80], [38, 94, 71, 34, 27, 18, 64, 79, 15, 11], [74, 94, 13, 10, 12, 24, 70, 48, 41, 50], [7, 10, 83, 95, 54, 71, 13, 40, 4, 60], [0, 47, 66, 61, 65, 60, 24, 43, 83, 63], [65, 51, 45, 63, 75, 10, 83, 42, 27, 58], [70, 11, 61, 8, 0, 6, 25, 31, 24, 62], [72, 68, 43, 4, 88, 22, 92, 41, 27, 91], [56, 92, 14, 34, 9, 21, 68, 76, 49, 4], [68, 49, 17, 32, 5, 90, 26, 13, 60, 69]], [[55, 74, 21, 10, 8, 88, 31, 53, 92, 59], [40, 63, 76, 34, 13, 31, 43, 48, 66, 18], [74, 3, 8, 34, 60, 47, 78, 50, 27, 77], [29, 20, 16, 53, 52, 77, 14, 47, 2, 94], [45, 88, 59, 54, 77, 55, 90, 84, 32, 65], [6, 87, 81, 36, 95, 88, 73, 21, 57, 67], [58, 43, 30, 94, 15, 49, 62, 53, 31, 87], [78, 6, 85, 70, 33, 77, 24, 89, 29, 1], [41, 69, 86, 23, 17, 43, 94, 32, 56, 8], [95, 80, 52, 35, 43, 22, 29, 67, 82, 63], [41, 73, 27, 28, 78, 19, 23, 69, 13, 68], [2, 12, 80, 42, 86, 28, 14, 70, 39, 91]], [[76, 3, 44, 35, 67, 32, 63, 27, 26, 68], [75, 63, 0, 71, 23, 45, 36, 81, 69, 72], [44, 19, 88, 73, 83, 90, 33, 45, 93, 11], [3, 27, 4, 48, 74, 50, 78, 80, 18, 95], [35, 78, 45, 70, 43, 65, 0, 11, 28, 37], [10, 8, 60, 84, 12, 36, 88, 41, 31, 46], [26, 54, 82, 20, 41, 51, 48, 92, 31, 49], [39, 82, 95, 38, 12, 15, 63, 83, 21, 90], [61, 48, 54, 83, 46, 23, 44, 22, 38, 32], [27, 92, 58, 35, 76, 8, 89, 23, 83, 91], [84, 7, 5, 83, 18, 3, 76, 21, 44, 20], [70, 59, 37, 4, 56, 80, 84, 76, 50, 38]], [[3, 7, 56, 39, 64, 92, 12, 21, 55, 20], [34, 81, 92, 86, 30, 93, 82, 29, 57, 63], [52, 36, 19, 14, 10, 1, 25, 17, 29, 34], [14, 62, 69, 59, 3, 28, 36, 79, 67, 7], [89, 60, 16, 44, 86, 56, 39, 12, 20, 64], [45, 73, 50, 80, 60, 41, 93, 86, 24, 49], [89, 70, 24, 16, 26, 23, 35, 12, 41, 72], [45, 0, 34, 59, 51, 6, 12, 66, 25, 35], [7, 44, 73, 38, 78, 89, 11, 87, 4, 80], [0, 33, 63, 12, 56, 23, 60, 21, 47, 14], [37, 66, 42, 26, 7, 63, 52, 50, 13, 19], [10, 49, 85, 47, 78, 27, 3, 48, 53, 15]], [[36, 69, 
29, 48, 68, 42, 18, 49, 78, 50], [1, 92, 9, 75, 87, 88, 25, 50, 32, 39], [72, 2, 17, 41, 15, 16, 26, 0, 14, 51], [29, 90, 48, 92, 79, 40, 52, 88, 35, 7], [73, 6, 79, 14, 29, 33, 91, 76, 10, 36], [56, 64, 85, 67, 9, 31, 93, 60, 58, 5], [27, 23, 88, 28, 26, 92, 1, 84, 70, 21], [26, 48, 40, 54, 95, 83, 60, 62, 33, 44], [37, 24, 26, 25, 69, 58, 80, 18, 78, 71], [25, 68, 88, 1, 87, 53, 36, 35, 45, 30], [32, 78, 30, 9, 80, 62, 77, 28, 13, 89], [63, 1, 26, 30, 3, 64, 67, 70, 29, 76]], [[65, 53, 94, 51, 89, 32, 26, 38, 62, 61], [95, 84, 72, 74, 79, 30, 47, 17, 15, 64], [90, 67, 6, 14, 81, 34, 50, 26, 86, 17], [70, 59, 21, 3, 77, 62, 63, 68, 34, 90], [3, 84, 46, 4, 5, 2, 12, 89, 74, 45], [92, 74, 95, 54, 60, 86, 83, 43, 59, 78], [2, 36, 29, 67, 22, 94, 78, 16, 30, 74], [63, 53, 89, 23, 86, 35, 57, 49, 66, 21], [14, 54, 48, 58, 76, 25, 63, 30, 24, 78], [52, 32, 1, 95, 54, 57, 16, 3, 89, 14], [11, 67, 66, 33, 2, 72, 10, 4, 79, 5], [71, 1, 42, 75, 60, 84, 16, 88, 4, 19]], [[32, 70, 12, 11, 93, 36, 68, 69, 23, 78], [44, 83, 87, 53, 94, 31, 66, 95, 65, 55], [32, 46, 89, 41, 62, 81, 52, 92, 1, 47], [25, 7, 24, 85, 81, 93, 58, 72, 34, 67], [64, 74, 12, 41, 11, 52, 48, 26, 93, 34], [33, 58, 64, 56, 38, 37, 61, 14, 3, 50], [22, 93, 23, 70, 64, 63, 5, 40, 33, 9], [86, 21, 93, 1, 26, 32, 18, 81, 52, 10], [63, 89, 24, 45, 1, 51, 36, 6, 95, 57], [83, 13, 40, 31, 92, 88, 82, 51, 95, 52], [86, 62, 33, 31, 72, 18, 74, 12, 0, 68], [11, 20, 70, 34, 64, 25, 30, 9, 46, 18]], [[43, 76, 1, 61, 59, 74, 49, 91, 87, 94], [73, 26, 81, 83, 47, 74, 15, 45, 12, 16], [30, 34, 26, 47, 28, 94, 57, 4, 35, 68], [89, 13, 40, 58, 73, 12, 72, 0, 85, 94], [17, 2, 89, 75, 41, 95, 59, 53, 84, 86], [81, 57, 20, 72, 6, 14, 85, 64, 41, 37], [49, 30, 6, 48, 58, 31, 5, 20, 67, 85], [51, 74, 16, 14, 34, 84, 79, 75, 77, 30], [23, 51, 33, 83, 79, 81, 53, 61, 49, 84], [6, 66, 51, 95, 4, 59, 85, 71, 69, 2], [40, 4, 26, 87, 56, 17, 71, 57, 13, 85], [57, 41, 15, 67, 83, 87, 44, 93, 12, 30]], [[4, 21, 66, 88, 71, 57, 18, 16, 
73, 41], [69, 60, 38, 32, 92, 53, 54, 47, 2, 27], [40, 17, 59, 39, 74, 18, 38, 23, 95, 2], [21, 29, 76, 74, 9, 68, 32, 35, 69, 51], [70, 60, 13, 49, 94, 14, 24, 8, 73, 67], [86, 32, 54, 13, 58, 87, 27, 92, 24, 83], [21, 41, 24, 56, 88, 27, 30, 11, 71, 67], [15, 12, 86, 91, 82, 38, 58, 76, 2, 33], [45, 81, 95, 38, 43, 79, 27, 61, 64, 87], [66, 1, 90, 77, 85, 15, 50, 53, 87, 21], [92, 39, 7, 50, 89, 36, 81, 68, 52, 23], [5, 87, 49, 55, 57, 66, 18, 37, 79, 50]], [[22, 32, 91, 72, 81, 69, 8, 34, 15, 20], [89, 31, 91, 27, 40, 90, 72, 19, 4, 9], [5, 60, 51, 34, 8, 56, 88, 52, 78, 69], [86, 5, 9, 50, 18, 45, 93, 28, 89, 30], [26, 35, 19, 90, 58, 10, 79, 52, 92, 45], [17, 91, 77, 33, 5, 40, 64, 89, 24, 60], [93, 77, 36, 60, 13, 74, 75, 71, 45, 14], [48, 42, 78, 58, 29, 36, 65, 74, 4, 85], [33, 81, 90, 38, 88, 66, 34, 49, 26, 93], [30, 12, 36, 7, 15, 6, 0, 13, 25, 43], [8, 86, 52, 10, 32, 4, 69, 6, 79, 65], [37, 2, 41, 15, 45, 36, 76, 32, 3, 22]], [[50, 81, 36, 1, 39, 63, 90, 72, 58, 77], [0, 30, 38, 83, 13, 14, 59, 90, 23, 22], [84, 60, 45, 23, 0, 48, 32, 18, 24, 29], [92, 19, 29, 89, 25, 75, 53, 24, 45, 66], [90, 92, 53, 85, 0, 23, 70, 88, 51, 39], [75, 14, 70, 24, 10, 2, 66, 28, 18, 68], [67, 52, 46, 45, 60, 34, 59, 91, 21, 77], [29, 2, 94, 65, 75, 87, 33, 25, 55, 6], [60, 34, 75, 74, 18, 20, 31, 10, 22, 83], [18, 77, 61, 24, 53, 60, 54, 28, 39, 21], [55, 87, 43, 12, 83, 57, 10, 23, 90, 93], [33, 32, 13, 52, 16, 54, 84, 91, 81, 38]]]



clss_support_imagesss_5_shot=[[[21, 14, 23, 29, 67], [40, 71, 6, 67, 65], [89, 58, 33, 13, 23], [6, 24, 74, 14, 37], [10, 35, 15, 91, 71], [93, 24, 81, 78, 83], [32, 73, 51, 14, 42], [36, 0, 34, 82, 39], [89, 55, 56, 32, 80], [22, 14, 80, 1, 6], [37, 9, 46, 70, 12], [7, 81, 29, 40, 46]], [[66, 39, 90, 45, 52], [63, 61, 91, 49, 5], [93, 72, 86, 79, 27], [67, 3, 58, 84, 72], [15, 86, 95, 33, 54], [94, 91, 48, 21, 86], [46, 24, 71, 61, 64], [34, 48, 62, 63, 64], [58, 60, 72, 49, 52], [39, 64, 28, 15, 20], [26, 78, 67, 34, 92], [39, 42, 64, 40, 32]], [[64, 80, 7, 65, 40], [71, 66, 67, 84, 64], [38, 82, 13, 91, 14], [26, 64, 89, 84, 95], [17, 83, 77, 37, 39], [52, 2, 60, 37, 73], [71, 9, 59, 87, 49], [21, 6, 79, 86, 74], [45, 36, 26, 57, 22], [4, 71, 51, 25, 84], [29, 61, 79, 82, 18], [56, 63, 64, 66, 70]], [[64, 80, 50, 62, 0], [39, 8, 57, 52, 44], [36, 27, 19, 54, 11], [63, 70, 2, 59, 10], [26, 50, 86, 12, 47], [17, 29, 94, 6, 50], [31, 39, 56, 32, 43], [2, 3, 64, 92, 57], [29, 84, 78, 36, 59], [63, 2, 13, 39, 16], [9, 14, 6, 60, 95], [7, 3, 68, 59, 11]], [[48, 20, 92, 33, 18], [59, 8, 20, 88, 73], [62, 28, 50, 93, 79], [38, 94, 71, 34, 27], [74, 94, 13, 10, 12], [7, 10, 83, 95, 54], [0, 47, 66, 61, 65], [65, 51, 45, 63, 75], [70, 11, 61, 8, 0], [72, 68, 43, 4, 88], [56, 92, 14, 34, 9], [68, 49, 17, 32, 5]], [[55, 74, 21, 10, 8], [40, 63, 76, 34, 13], [74, 3, 8, 34, 60], [29, 20, 16, 53, 52], [45, 88, 59, 54, 77], [6, 87, 81, 36, 95], [58, 43, 30, 94, 15], [78, 6, 85, 70, 33], [41, 69, 86, 23, 17], [95, 80, 52, 35, 43], [41, 73, 27, 28, 78], [2, 12, 80, 42, 86]], [[76, 3, 44, 35, 67], [75, 63, 0, 71, 23], [44, 19, 88, 73, 83], [3, 27, 4, 48, 74], [35, 78, 45, 70, 43], [10, 8, 60, 84, 12], [26, 54, 82, 20, 41], [39, 82, 95, 38, 12], [61, 48, 54, 83, 46], [27, 92, 58, 35, 76], [84, 7, 5, 83, 18], [70, 59, 37, 4, 56]], [[3, 7, 56, 39, 64], [34, 81, 92, 86, 30], [52, 36, 19, 14, 10], [14, 62, 69, 59, 3], [89, 60, 16, 44, 86], [45, 73, 50, 80, 60], [89, 70, 24, 16, 26], 
[45, 0, 34, 59, 51], [7, 44, 73, 38, 78], [0, 33, 63, 12, 56], [37, 66, 42, 26, 7], [10, 49, 85, 47, 78]], [[36, 69, 29, 48, 68], [1, 92, 9, 75, 87], [72, 2, 17, 41, 15], [29, 90, 48, 92, 79], [73, 6, 79, 14, 29], [56, 64, 85, 67, 9], [27, 23, 88, 28, 26], [26, 48, 40, 54, 95], [37, 24, 26, 25, 69], [25, 68, 88, 1, 87], [32, 78, 30, 9, 80], [63, 1, 26, 30, 3]], [[65, 53, 94, 51, 89], [95, 84, 72, 74, 79], [90, 67, 6, 14, 81], [70, 59, 21, 3, 77], [3, 84, 46, 4, 5], [92, 74, 95, 54, 60], [2, 36, 29, 67, 22], [63, 53, 89, 23, 86], [14, 54, 48, 58, 76], [52, 32, 1, 95, 54], [11, 67, 66, 33, 2], [71, 1, 42, 75, 60]], [[32, 70, 12, 11, 93], [44, 83, 87, 53, 94], [32, 46, 89, 41, 62], [25, 7, 24, 85, 81], [64, 74, 12, 41, 11], [33, 58, 64, 56, 38], [22, 93, 23, 70, 64], [86, 21, 93, 1, 26], [63, 89, 24, 45, 1], [83, 13, 40, 31, 92], [86, 62, 33, 31, 72], [11, 20, 70, 34, 64]], [[43, 76, 1, 61, 59], [73, 26, 81, 83, 47], [30, 34, 26, 47, 28], [89, 13, 40, 58, 73], [17, 2, 89, 75, 41], [81, 57, 20, 72, 6], [49, 30, 6, 48, 58], [51, 74, 16, 14, 34], [23, 51, 33, 83, 79], [6, 66, 51, 95, 4], [40, 4, 26, 87, 56], [57, 41, 15, 67, 83]], [[4, 21, 66, 88, 71], [69, 60, 38, 32, 92], [40, 17, 59, 39, 74], [21, 29, 76, 74, 9], [70, 60, 13, 49, 94], [86, 32, 54, 13, 58], [21, 41, 24, 56, 88], [15, 12, 86, 91, 82], [45, 81, 95, 38, 43], [66, 1, 90, 77, 85], [92, 39, 7, 50, 89], [5, 87, 49, 55, 57]], [[22, 32, 91, 72, 81], [89, 31, 91, 27, 40], [5, 60, 51, 34, 8], [86, 5, 9, 50, 18], [26, 35, 19, 90, 58], [17, 91, 77, 33, 5], [93, 77, 36, 60, 13], [48, 42, 78, 58, 29], [33, 81, 90, 38, 88], [30, 12, 36, 7, 15], [8, 86, 52, 10, 32], [37, 2, 41, 15, 45]], [[50, 81, 36, 1, 39], [0, 30, 38, 83, 13], [84, 60, 45, 23, 0], [92, 19, 29, 89, 25], [90, 92, 53, 85, 0], [75, 14, 70, 24, 10], [67, 52, 46, 45, 60], [29, 2, 94, 65, 75], [60, 34, 75, 74, 18], [18, 77, 61, 24, 53], [55, 87, 43, 12, 83], [33, 32, 13, 52, 16]]]


# Class names for the 12-way evaluation episodes: 15 rows (one per episode),
# each listing 12 class-name strings. Row i is paired with row i of the
# matching clss_support_imagesss_*_shot list when an episode is fetched, so the
# order of classes inside a row must not be changed independently.
clss_12=[
   ['MAPI Mayeek_GOK', 'MAPI Mayeek_BHAM', 'MAPI Mayeek_JIL', 'MAPI Mayeek_UUN', 'MAPI Mayeek_BAA', 'MAPI Mayeek_FAM', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_KOK', 'MAPI Mayeek_LAI', 'MAPI Mayeek_DIL', 'MAPI Mayeek_SAM', 'MAPI Mayeek_ATIYA'],
   ['MAPI Mayeek_JIL', 'MAPI Mayeek_EEE', 'MAPI Mayeek_GHOU', 'MAPI Mayeek_MIT', 'MAPI Mayeek_ATIYA', 'MAPI Mayeek_BAA', 'MAPI Mayeek_HUK', 'MAPI Mayeek_THOU', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_UUN', 'MAPI Mayeek_PAA', 'MAPI Mayeek_JHAM'],
     ['MAPI Mayeek_PAA', 'MAPI Mayeek_KOK', 'MAPI Mayeek_BAA', 'MAPI Mayeek_DIL', 'MAPI Mayeek_CHIN', 'MAPI Mayeek_NGOU', 'MAPI Mayeek_NAA', 'MAPI Mayeek_RAAI', 'MAPI Mayeek_WAI', 'MAPI Mayeek_UUN', 'MAPI Mayeek_JHAM', 'MAPI Mayeek_LAI'],
   ['MAPI Mayeek_KHOU', 'MAPI Mayeek_FAM', 'MAPI Mayeek_SAM', 'MAPI Mayeek_JIL', 'MAPI Mayeek_EEE', 'MAPI Mayeek_LAI', 'MAPI Mayeek_HUK', 'MAPI Mayeek_NGOU', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_GOK', 'MAPI Mayeek_TIL', 'MAPI Mayeek_KOK'],
    ['MAPI Mayeek_CHIN', 'MAPI Mayeek_YANG', 'MAPI Mayeek_KOK', 'MAPI Mayeek_PAA', 'MAPI Mayeek_NAA', 'MAPI Mayeek_FAM', 'MAPI Mayeek_RAAI', 'MAPI Mayeek_THOU', 'MAPI Mayeek_EEE', 'MAPI Mayeek_TIL', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_WAI'],
    ['MAPI Mayeek_NAA', 'MAPI Mayeek_EEE', 'MAPI Mayeek_KOK', 'MAPI Mayeek_BAA', 'MAPI Mayeek_WAI', 'MAPI Mayeek_BHAM', 'MAPI Mayeek_JHAM', 'MAPI Mayeek_PAA', 'MAPI Mayeek_MIT', 'MAPI Mayeek_THOU', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_NGOU'],
   ['MAPI Mayeek_NGOU', 'MAPI Mayeek_LAI', 'MAPI Mayeek_BHAM', 'MAPI Mayeek_JIL', 'MAPI Mayeek_SAM', 'MAPI Mayeek_JHAM', 'MAPI Mayeek_EEE', 'MAPI Mayeek_ATIYA', 'MAPI Mayeek_RAAI', 'MAPI Mayeek_MIT', 'MAPI Mayeek_KOK', 'MAPI Mayeek_FAM'],
    ['MAPI Mayeek_SAM', 'MAPI Mayeek_LAI', 'MAPI Mayeek_CHIN', 'MAPI Mayeek_THOU', 'MAPI Mayeek_PAA', 'MAPI Mayeek_KOK', 'MAPI Mayeek_BHAM', 'MAPI Mayeek_NAA', 'MAPI Mayeek_GHOU', 'MAPI Mayeek_MIT', 'MAPI Mayeek_HUK', 'MAPI Mayeek_JHAM'],
   ['MAPI Mayeek_TIL', 'MAPI Mayeek_WAI', 'MAPI Mayeek_JHAM', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_HUK', 'MAPI Mayeek_DIL', 'MAPI Mayeek_BHAM', 'MAPI Mayeek_FAM', 'MAPI Mayeek_LAI', 'MAPI Mayeek_NGOU', 'MAPI Mayeek_GHOU', 'MAPI Mayeek_GOK'],
    ['MAPI Mayeek_BAA', 'MAPI Mayeek_GOK', 'MAPI Mayeek_JHAM', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_LAI', 'MAPI Mayeek_BHAM', 'MAPI Mayeek_TIL', 'MAPI Mayeek_THOU', 'MAPI Mayeek_KOK', 'MAPI Mayeek_FAM', 'MAPI Mayeek_RAAI', 'MAPI Mayeek_SAM'],
   ['MAPI Mayeek_ATIYA', 'MAPI Mayeek_UUN', 'MAPI Mayeek_PAA', 'MAPI Mayeek_NGOU', 'MAPI Mayeek_GHOU', 'MAPI Mayeek_NAA', 'MAPI Mayeek_FAM', 'MAPI Mayeek_LAI', 'MAPI Mayeek_KOK', 'MAPI Mayeek_GOK', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_SAM'],
    ['MAPI Mayeek_BHAM', 'MAPI Mayeek_NGOU', 'MAPI Mayeek_BAA', 'MAPI Mayeek_EEE', 'MAPI Mayeek_CHIN', 'MAPI Mayeek_LAI', 'MAPI Mayeek_FAM', 'MAPI Mayeek_NAA', 'MAPI Mayeek_GOK', 'MAPI Mayeek_THOU', 'MAPI Mayeek_SAM', 'MAPI Mayeek_ATIYA'],
    ['MAPI Mayeek_BHAM', 'MAPI Mayeek_LAI', 'MAPI Mayeek_PAA', 'MAPI Mayeek_NAA', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_SAM', 'MAPI Mayeek_DIL', 'MAPI Mayeek_CHIN', 'MAPI Mayeek_JHAM', 'MAPI Mayeek_NGOU', 'MAPI Mayeek_EEE', 'MAPI Mayeek_GOK'],
   ['MAPI Mayeek_FAM', 'MAPI Mayeek_BHAM', 'MAPI Mayeek_UUN', 'MAPI Mayeek_WAI', 'MAPI Mayeek_PAA', 'MAPI Mayeek_GHOU', 'MAPI Mayeek_KOK', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_GOK', 'MAPI Mayeek_EEE', 'MAPI Mayeek_YANG', 'MAPI Mayeek_JHAM'],
    ['MAPI Mayeek_PAA', 'MAPI Mayeek_BAA', 'MAPI Mayeek_SAM', 'MAPI Mayeek_HUK', 'MAPI Mayeek_NGOU', 'MAPI Mayeek_FAM', 'MAPI Mayeek_LAI', 'MAPI Mayeek_GHOU', 'MAPI Mayeek_DHOU', 'MAPI Mayeek_GOK', 'MAPI Mayeek_NAA', 'MAPI Mayeek_EEE'],
   ]
# 1-shot support selection for the 12-way episodes: 15 rows (one per row of
# clss_12), each holding 12 single-element lists, i.e. one value per class.
# Row i is passed to data_loader.__getitem__ together with clss_12[i].
# NOTE(review): the values are presumably indices into each class's image list
# — confirm against the data loader.
clss_support_imagesss_1_shot=[
 [[21], [40], [89], [6], [10], [93], [32], [36], [89], [22], [37], [7]],
 [[66], [63], [93], [67], [15], [94], [46], [34], [58], [39], [26], [39]],
 [[64], [71], [38], [26], [17], [52], [71], [21], [45], [4], [29], [56]],
 [[64], [39], [36], [63], [26], [17], [31], [2], [29], [63], [9], [7]],
 [[48], [59], [62], [38], [74], [7], [0], [65], [70], [72], [56], [68]],
 [[55], [40], [74], [29], [45], [6], [58], [78], [41], [95], [41], [2]],
 [[76], [75], [44], [3], [35], [10], [26], [39], [61], [27], [84], [70]],
 [[3], [34], [52], [14], [89], [45], [89], [45], [7], [0], [37], [10]],
 [[36], [1], [72], [29], [73], [56], [27], [26], [37], [25], [32], [63]],
 [[65], [95], [90], [70], [3], [92], [2], [63], [14], [52], [11], [71]],
 [[32], [44], [32], [25], [64], [33], [22], [86], [63], [83], [86], [11]],
  [[43], [73], [30], [89], [17], [81], [49], [51], [23], [6], [40], [57]],
 [[4], [69], [40], [21], [70], [86], [21], [15], [45], [66], [92], [5]],
 [[22], [89], [5], [86], [26], [17], [93], [48], [33], [30], [8], [37]],
 [[50], [0], [84], [92], [90], [75], [67], [29], [60], [18], [55], [33]]
]











In [67]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Remove checkpoints left over from a previous configuration before they are
# rewritten at the end of this cell. Done with os (imported near the top of the
# notebook) instead of `!rm`, so the cell stays plain Python, works where `rm`
# is unavailable, and is silent when the files do not exist yet.
if os.path.exists('model_mu_path.pt'):
    os.remove('model_mu_path.pt')
if os.path.exists('model_own_path.pt'):
    os.remove('model_own_path.pt')

# Base model: backbone built with pretrained=False, dr=0.0 and `fc` replaced by
# nn.Flatten, wrapped in the prototypical network. The checkpoint is
# deserialised onto the CPU and then copied into the model's parameters.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18_35.pth',map_location=torch.device('cpu')))


# Secondary (1-shot) model: same architecture on a separate backbone instance.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/mapi_sce/ressecondary/model_1/model_B_mapi_1-shot_res.pth',map_location=torch.device('cpu')))

# Re-save both state dicts under the fixed working-directory names that the
# merge/evaluation cells load with torch.load('model_*_path.pt').
torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')




In [68]:

def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              class_sets=None, support_indices=None):
    """Evaluate `model` on a fixed list of few-shot episodes and print summary metrics.

    For every episode the support and query sets are fetched from
    `data_loader`, the model scores the queries, and the episode accuracy and
    confusion matrix are computed. The confusion matrix is handed to
    `calculate_metrics_get_per` together with `fname`; the precision and recall
    it returns are collected. One summary line (accuracy, precision, recall as
    mean ± sample standard deviation over episodes) is printed at the end.
    Nothing is returned.

    Args:
        fname: Forwarded unchanged to `calculate_metrics_get_per` as `flnm`
            (presumably the output file for per-episode metrics; verify there).
        data_loader: Object whose `__getitem__(classes, support_indices)`
            returns a `(support_set, query_set)` pair. Each set is an iterable
            of `(images, labels)` pairs where `images` is a sequence of tensors
            that can be stacked.
        model: Called as `model(support_images, support_labels, query_images)`
            and expected to return one score tensor per query group. It is
            switched to eval mode.
        criterion: Unused; kept so existing positional calls keep working.
        class_sets: Per-episode class selections. Defaults to the notebook
            global `clss_12`.
        support_indices: Per-episode support-image selections, aligned with
            `class_sets`. Defaults to the notebook global
            `clss_support_imagesss_1_shot`, so other shot counts can reuse this
            function instead of redefining it.
    """
    if class_sets is None:
        class_sets = clss_12
    if support_indices is None:
        support_indices = clss_support_imagesss_1_shot

    model.eval()
    episode_accuracies = []
    precision = []
    recall = []

    with torch.no_grad():
        for episode_idx in range(len(class_sets)):
            episode_classes = class_sets[episode_idx]
            episode_support = support_indices[episode_idx]
            support_set, query_set = data_loader.__getitem__(episode_classes, episode_support)

            # Text labels forwarded to the metrics helper for this episode.
            clss = str(episode_classes).replace(',', ' ').replace('\'', ' ')
            msg = str(episode_support)

            # Stack each group's images into one batch on the target device; a
            # support group keeps only its first label, query groups keep all.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            cortt = 0
            totl = 0
            all_predicted_labels = []
            all_actual_labels = []
            for group_scores, act in zip(classification_scores, query_labels):
                pp = torch.argmax(group_scores, dim=1).tolist()
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for idx, predicted in enumerate(pp):
                    if predicted == act[idx]:
                        cortt += 1
                totl += len(pp)

            accuracy = cortt / totl
            episode_accuracies.append(accuracy)

            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            precision_overall, recall_overall, _f1 = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=clss, flnm=fname
            )
            precision.append(precision_overall)
            recall.append(recall_overall)

    print(f"---Accu-----pre----rec---------> {np.mean(episode_accuracies):.4f} ± {np.std(episode_accuracies, ddof=1):.4f}  "
      f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
      f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")
# %%%%%%

def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              class_sets=None, support_indices=None):
    """Return the mean episode accuracy of `model` over a fixed list of episodes.

    Lightweight counterpart of `evaluate2`: it runs the same per-episode
    scoring but computes only accuracy (no confusion matrix, no call to the
    metrics helper). The mean accuracy is printed and returned.

    Args:
        fname: Unused; kept for signature parity with `evaluate2`.
        data_loader: Object whose `__getitem__(classes, support_indices)`
            returns a `(support_set, query_set)` pair of `(images, labels)`
            iterables.
        model: Called as `model(support_images, support_labels, query_images)`
            and expected to return one score tensor per query group. It is
            switched to eval mode.
        criterion: Unused; kept so existing positional calls keep working.
        class_sets: Per-episode class selections. Defaults to the notebook
            global `clss_12`.
        support_indices: Per-episode support-image selections, aligned with
            `class_sets`. Defaults to the notebook global
            `clss_support_imagesss_1_shot`.

    Returns:
        Sum of per-episode accuracies divided by the number of episodes.
    """
    if class_sets is None:
        class_sets = clss_12
    if support_indices is None:
        support_indices = clss_support_imagesss_1_shot

    model.eval()
    episode_count = 0
    accuracy_sum = 0
    with torch.no_grad():
        for episode_idx in range(len(class_sets)):
            support_set, query_set = data_loader.__getitem__(
                class_sets[episode_idx], support_indices[episode_idx]
            )

            # Stack each group's images into one batch on the target device; a
            # support group keeps only its first label, query groups keep all.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            cortt = 0
            totl = 0
            for group_scores, act in zip(classification_scores, query_labels):
                pp = torch.argmax(group_scores, dim=1).tolist()
                for idx, predicted in enumerate(pp):
                    if predicted == act[idx]:
                        cortt += 1
                totl += len(pp)

            accuracy_sum += cortt / totl
            episode_count += 1

    mean_accuracy = accuracy_sum / episode_count
    print("---------------------->", mean_accuracy)
    return mean_accuracy

In [69]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Search space: the three adaptation ratios b_a_r, c_a_r and f_a_r, each in [0, 1].
# The evaluation cell passes them on as bias_ratio, conv_ratio and fc_ratio.
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Objective function for Bayesian Optimization
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """Return the negated accuracy obtained with one set of adaptation ratios.

    The accuracy comes from `get_model_accuracy`, which is defined elsewhere in
    this notebook (presumably it blends the two saved models with these ratios
    and evaluates the result; verify there). `gp_minimize` minimizes its
    objective, hence the sign flip.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # Debugging Output

    return -accuracy  # We minimize, so return negative accuracy

# NOTE(review): the best accuracy reported by this search equals the accuracy
# later printed for the blended model on test_loader, which suggests the ratios
# are tuned on the evaluation episodes themselves. Confirm, and consider tuning
# on held-out validation episodes instead.
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Number of evaluations
    n_initial_points=5,       # Initial random explorations before GP starts
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Ensure reproducibility
    n_jobs=-1,                # Cores for optimizing the acquisition function; objective calls still run one at a time
)

# Extract optimal results
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print("\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.39421052631578946
Evaluated Accuracy: 0.3942
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.21918128654970762
Evaluated Accuracy: 0.2192
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.25228070175438594
Evaluated Accuracy: 0.2523
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.45099415204678367
Evaluated Accuracy: 0.4510
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.40923976608187135
Evaluated Accuracy: 0.4092
Testing b_a_r = 0.0311, c_a_r = 0.0000, f_a_r = 0.5378




----------------------> 0.4547368421052631
Evaluated Accuracy: 0.4547
Testing b_a_r = 0.2287, c_a_r = 1.0000, f_a_r = 1.0000




----------------------> 0.10853801169590643
Evaluated Accuracy: 0.1085
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.6865




----------------------> 0.453859649122807
Evaluated Accuracy: 0.4539
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.44070175438596493
Evaluated Accuracy: 0.4407
Testing b_a_r = 0.8306, c_a_r = 1.0000, f_a_r = 0.0000




----------------------> 0.10976608187134504
Evaluated Accuracy: 0.1098
Testing b_a_r = 0.0000, c_a_r = 0.0725, f_a_r = 0.0000




----------------------> 0.43251461988304096
Evaluated Accuracy: 0.4325
Testing b_a_r = 0.0121, c_a_r = 0.0689, f_a_r = 0.9289




----------------------> 0.41350877192982455
Evaluated Accuracy: 0.4135
Testing b_a_r = 1.0000, c_a_r = 0.0422, f_a_r = 0.4297




----------------------> 0.45052631578947366
Evaluated Accuracy: 0.4505
Testing b_a_r = 0.9250, c_a_r = 0.5358, f_a_r = 1.0000




----------------------> 0.23280701754385966
Evaluated Accuracy: 0.2328
Testing b_a_r = 0.0655, c_a_r = 0.0000, f_a_r = 0.2200




----------------------> 0.43824561403508777
Evaluated Accuracy: 0.4382
Testing b_a_r = 0.1142, c_a_r = 0.0831, f_a_r = 0.5475




----------------------> 0.44760233918128656
Evaluated Accuracy: 0.4476
Testing b_a_r = 0.6185, c_a_r = 0.0361, f_a_r = 0.6056




----------------------> 0.45766081871345027
Evaluated Accuracy: 0.4577
Testing b_a_r = 0.7125, c_a_r = 0.7413, f_a_r = 0.0005




----------------------> 0.1916374269005848
Evaluated Accuracy: 0.1916
Testing b_a_r = 0.5319, c_a_r = 0.3137, f_a_r = 0.9968




----------------------> 0.2426900584795322
Evaluated Accuracy: 0.2427
Testing b_a_r = 0.6139, c_a_r = 0.7555, f_a_r = 0.9972




----------------------> 0.1561988304093567
Evaluated Accuracy: 0.1562
Testing b_a_r = 0.1496, c_a_r = 0.0004, f_a_r = 0.0633




----------------------> 0.4371345029239766
Evaluated Accuracy: 0.4371
Testing b_a_r = 0.9833, c_a_r = 0.0868, f_a_r = 0.0208




----------------------> 0.4301754385964912
Evaluated Accuracy: 0.4302
Testing b_a_r = 0.9731, c_a_r = 0.0002, f_a_r = 0.5907




----------------------> 0.45748538011695905
Evaluated Accuracy: 0.4575
Testing b_a_r = 0.6131, c_a_r = 0.1618, f_a_r = 0.3920




----------------------> 0.40146198830409363
Evaluated Accuracy: 0.4015
Testing b_a_r = 0.9981, c_a_r = 0.6193, f_a_r = 0.4876




----------------------> 0.23497076023391816
Evaluated Accuracy: 0.2350
Testing b_a_r = 0.3784, c_a_r = 0.1765, f_a_r = 0.0017




----------------------> 0.3761988304093567
Evaluated Accuracy: 0.3762
Testing b_a_r = 0.8623, c_a_r = 0.8642, f_a_r = 0.4550




----------------------> 0.13403508771929826
Evaluated Accuracy: 0.1340
Testing b_a_r = 0.9172, c_a_r = 0.4251, f_a_r = 0.6014




----------------------> 0.21994152046783627
Evaluated Accuracy: 0.2199
Testing b_a_r = 0.8892, c_a_r = 0.1062, f_a_r = 0.7132




----------------------> 0.44269005847953213
Evaluated Accuracy: 0.4427
Testing b_a_r = 0.1213, c_a_r = 0.0791, f_a_r = 0.2338




----------------------> 0.43362573099415197
Evaluated Accuracy: 0.4336
Testing b_a_r = 0.9530, c_a_r = 0.1640, f_a_r = 0.9984




----------------------> 0.386140350877193
Evaluated Accuracy: 0.3861
Testing b_a_r = 0.5067, c_a_r = 0.5918, f_a_r = 0.0042




----------------------> 0.25695906432748544
Evaluated Accuracy: 0.2570
Testing b_a_r = 0.9192, c_a_r = 0.0356, f_a_r = 0.0530




----------------------> 0.44157894736842107
Evaluated Accuracy: 0.4416
Testing b_a_r = 0.8416, c_a_r = 0.0009, f_a_r = 0.4622




----------------------> 0.4536842105263158
Evaluated Accuracy: 0.4537
Testing b_a_r = 0.0319, c_a_r = 0.0204, f_a_r = 0.6819




----------------------> 0.453859649122807
Evaluated Accuracy: 0.4539
Testing b_a_r = 0.8886, c_a_r = 0.2565, f_a_r = 0.4904




----------------------> 0.33456140350877195
Evaluated Accuracy: 0.3346
Testing b_a_r = 0.4409, c_a_r = 0.9999, f_a_r = 0.4896




----------------------> 0.10888888888888888
Evaluated Accuracy: 0.1089
Testing b_a_r = 0.9509, c_a_r = 0.0977, f_a_r = 0.4202




----------------------> 0.4360818713450293
Evaluated Accuracy: 0.4361
Testing b_a_r = 0.2664, c_a_r = 0.6360, f_a_r = 0.9933




----------------------> 0.20502923976608184
Evaluated Accuracy: 0.2050
Testing b_a_r = 0.9303, c_a_r = 0.0176, f_a_r = 0.5820




----------------------> 0.45947368421052637
Evaluated Accuracy: 0.4595
Testing b_a_r = 0.2177, c_a_r = 0.8743, f_a_r = 0.0017




----------------------> 0.1339766081871345
Evaluated Accuracy: 0.1340
Testing b_a_r = 0.6609, c_a_r = 0.8845, f_a_r = 0.9963




----------------------> 0.12520467836257312
Evaluated Accuracy: 0.1252
Testing b_a_r = 0.8963, c_a_r = 0.0173, f_a_r = 0.8103




----------------------> 0.43497076023391806
Evaluated Accuracy: 0.4350
Testing b_a_r = 0.9974, c_a_r = 0.0417, f_a_r = 0.5816




----------------------> 0.45888888888888885
Evaluated Accuracy: 0.4589
Testing b_a_r = 0.9780, c_a_r = 0.1388, f_a_r = 0.6101




----------------------> 0.4315204678362573
Evaluated Accuracy: 0.4315
Testing b_a_r = 0.9828, c_a_r = 0.0414, f_a_r = 0.2885




----------------------> 0.44391812865497066
Evaluated Accuracy: 0.4439
Testing b_a_r = 0.9971, c_a_r = 0.0689, f_a_r = 0.6235




----------------------> 0.4543274853801169
Evaluated Accuracy: 0.4543
Testing b_a_r = 0.9205, c_a_r = 0.0030, f_a_r = 0.3712




----------------------> 0.44766081871345026
Evaluated Accuracy: 0.4477
Testing b_a_r = 0.9261, c_a_r = 0.1445, f_a_r = 0.1831




----------------------> 0.40649122807017546
Evaluated Accuracy: 0.4065
Testing b_a_r = 0.0675, c_a_r = 0.7259, f_a_r = 0.6179




----------------------> 0.16929824561403503
Evaluated Accuracy: 0.1693

✅ Optimal Values:
   - b_a_r: 0.9303
   - c_a_r: 0.0176
   - f_a_r: 0.5820
📈 Highest Accuracy Achieved: 0.4595


In [70]:
# Scratch expression; it binds no name, so no other cell depends on it.
2+3

5

In [71]:
# Model M3: fresh backbone, weights restored from model_own_path.pt, evaluated as the "base" result.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
# map_location lets a checkpoint saved from a GPU session load on a CPU-only machine.
M3.load_state_dict(torch.load('model_own_path.pt', map_location=device))
evaluate2('/home/asufian/Desktop/output_mapii/result_ressnet/base/12_w_1_s_f_20', test_loader, M3, criterion)

# Model M2: same architecture, weights restored from model_mu_path.pt, evaluated as the "secondary" result.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt', map_location=device))
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/secondary/12_w_1_s_f_20", test_loader, M2, criterion)

# Update M3's weights from M2 using the ratios found by the optimisation cell.
# The call's result is not captured, so update_model_weights2 presumably
# modifies M3 in place; verify in its definition.
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Evaluate the updated M3.
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/amul/12_w_1_s_f_20", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.4365 ± 0.0519  0.4690 ± 0.0405  0.4365 ± 0.0519




---Accu-----pre----rec---------> 0.3792 ± 0.0572  0.4317 ± 0.0817  0.3792 ± 0.0572
---Accu-----pre----rec---------> 0.4595 ± 0.0541  0.4978 ± 0.0523  0.4595 ± 0.0541


In [72]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



!rm model_mu_path.pt
!rm model_own_path.pt

convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18_35.pth',map_location=torch.device('cpu')))


convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/mapi_sce/ressecondary/model_5/model_B_mapi_5-shot_res.pth',map_location=torch.device('cpu')))

torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [73]:

def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              class_sets=None, support_indices=None):
    """Evaluate `model` on a fixed list of few-shot episodes and print summary metrics.

    For every episode the support and query sets are fetched from
    `data_loader`, the model scores the queries, and the episode accuracy and
    confusion matrix are computed. The confusion matrix is handed to
    `calculate_metrics_get_per` together with `fname`; the precision and recall
    it returns are collected. One summary line (accuracy, precision, recall as
    mean ± sample standard deviation over episodes) is printed at the end.
    Nothing is returned.

    Args:
        fname: Forwarded unchanged to `calculate_metrics_get_per` as `flnm`
            (presumably the output file for per-episode metrics; verify there).
        data_loader: Object whose `__getitem__(classes, support_indices)`
            returns a `(support_set, query_set)` pair. Each set is an iterable
            of `(images, labels)` pairs where `images` is a sequence of tensors
            that can be stacked.
        model: Called as `model(support_images, support_labels, query_images)`
            and expected to return one score tensor per query group. It is
            switched to eval mode.
        criterion: Unused; kept so existing positional calls keep working.
        class_sets: Per-episode class selections. Defaults to the notebook
            global `clss_12`.
        support_indices: Per-episode support-image selections, aligned with
            `class_sets`. Defaults to the notebook global
            `clss_support_imagesss_5_shot`, so other shot counts can reuse this
            function instead of redefining it.
    """
    if class_sets is None:
        class_sets = clss_12
    if support_indices is None:
        support_indices = clss_support_imagesss_5_shot

    model.eval()
    episode_accuracies = []
    precision = []
    recall = []

    with torch.no_grad():
        for episode_idx in range(len(class_sets)):
            episode_classes = class_sets[episode_idx]
            episode_support = support_indices[episode_idx]
            support_set, query_set = data_loader.__getitem__(episode_classes, episode_support)

            # Text labels forwarded to the metrics helper for this episode.
            clss = str(episode_classes).replace(',', ' ').replace('\'', ' ')
            msg = str(episode_support)

            # Stack each group's images into one batch on the target device; a
            # support group keeps only its first label, query groups keep all.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            cortt = 0
            totl = 0
            all_predicted_labels = []
            all_actual_labels = []
            for group_scores, act in zip(classification_scores, query_labels):
                pp = torch.argmax(group_scores, dim=1).tolist()
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for idx, predicted in enumerate(pp):
                    if predicted == act[idx]:
                        cortt += 1
                totl += len(pp)

            accuracy = cortt / totl
            episode_accuracies.append(accuracy)

            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            precision_overall, recall_overall, _f1 = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=clss, flnm=fname
            )
            precision.append(precision_overall)
            recall.append(recall_overall)

    print(f"---Accu-----pre----rec---------> {np.mean(episode_accuracies):.4f} ± {np.std(episode_accuracies, ddof=1):.4f}  "
      f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
      f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")
# %%%%%%

def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              class_sets=None, support_indices=None):
    """Return the mean episode accuracy of `model` over a fixed list of episodes.

    Lightweight counterpart of `evaluate2`: it runs the same per-episode
    scoring but computes only accuracy (no confusion matrix, no call to the
    metrics helper). The mean accuracy is printed and returned.

    Args:
        fname: Unused; kept for signature parity with `evaluate2`.
        data_loader: Object whose `__getitem__(classes, support_indices)`
            returns a `(support_set, query_set)` pair of `(images, labels)`
            iterables.
        model: Called as `model(support_images, support_labels, query_images)`
            and expected to return one score tensor per query group. It is
            switched to eval mode.
        criterion: Unused; kept so existing positional calls keep working.
        class_sets: Per-episode class selections. Defaults to the notebook
            global `clss_12`.
        support_indices: Per-episode support-image selections, aligned with
            `class_sets`. Defaults to the notebook global
            `clss_support_imagesss_5_shot`.

    Returns:
        Sum of per-episode accuracies divided by the number of episodes.
    """
    if class_sets is None:
        class_sets = clss_12
    if support_indices is None:
        support_indices = clss_support_imagesss_5_shot

    model.eval()
    episode_count = 0
    accuracy_sum = 0
    with torch.no_grad():
        for episode_idx in range(len(class_sets)):
            support_set, query_set = data_loader.__getitem__(
                class_sets[episode_idx], support_indices[episode_idx]
            )

            # Stack each group's images into one batch on the target device; a
            # support group keeps only its first label, query groups keep all.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            cortt = 0
            totl = 0
            for group_scores, act in zip(classification_scores, query_labels):
                pp = torch.argmax(group_scores, dim=1).tolist()
                for idx, predicted in enumerate(pp):
                    if predicted == act[idx]:
                        cortt += 1
                totl += len(pp)

            accuracy_sum += cortt / totl
            episode_count += 1

    mean_accuracy = accuracy_sum / episode_count
    print("---------------------->", mean_accuracy)
    return mean_accuracy

In [74]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Search space: the three adaptation ratios b_a_r, c_a_r and f_a_r, each in [0, 1].
# The evaluation cell passes them on as bias_ratio, conv_ratio and fc_ratio.
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Objective function for Bayesian Optimization
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """Return the negated accuracy obtained with one set of adaptation ratios.

    The accuracy comes from `get_model_accuracy`, which is defined elsewhere in
    this notebook (presumably it blends the two saved models with these ratios
    and evaluates the result; verify there). `gp_minimize` minimizes its
    objective, hence the sign flip.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # Debugging Output

    return -accuracy  # We minimize, so return negative accuracy

# NOTE(review): the best accuracy reported by this search equals the accuracy
# later printed for the blended model on test_loader, which suggests the ratios
# are tuned on the evaluation episodes themselves. Confirm, and consider tuning
# on held-out validation episodes instead.
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Number of evaluations
    n_initial_points=5,       # Initial random explorations before GP starts
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Ensure reproducibility
    n_jobs=-1,                # Cores for optimizing the acquisition function; objective calls still run one at a time
)

# Extract optimal results
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print("\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.46978021978021983
Evaluated Accuracy: 0.4698
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.3518315018315018
Evaluated Accuracy: 0.3518
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.3420024420024419
Evaluated Accuracy: 0.3420
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.6016483516483517
Evaluated Accuracy: 0.6016
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.572039072039072
Evaluated Accuracy: 0.5720
Testing b_a_r = 0.0250, c_a_r = 1.0000, f_a_r = 0.4630




----------------------> 0.24633699633699635
Evaluated Accuracy: 0.2463
Testing b_a_r = 0.9430, c_a_r = 0.6975, f_a_r = 0.9507




----------------------> 0.27429792429792427
Evaluated Accuracy: 0.2743
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.576007326007326
Evaluated Accuracy: 0.5760
Testing b_a_r = 0.0000, c_a_r = 0.0838, f_a_r = 0.0000




----------------------> 0.5678876678876678
Evaluated Accuracy: 0.5679
Testing b_a_r = 0.9724, c_a_r = 0.0458, f_a_r = 0.9447




----------------------> 0.5642857142857143
Evaluated Accuracy: 0.5643
Testing b_a_r = 0.1506, c_a_r = 0.0000, f_a_r = 0.9742




----------------------> 0.5659340659340659
Evaluated Accuracy: 0.5659
Testing b_a_r = 0.9615, c_a_r = 0.0577, f_a_r = 0.0312




----------------------> 0.5841269841269842
Evaluated Accuracy: 0.5841
Testing b_a_r = 0.0494, c_a_r = 0.0589, f_a_r = 0.0320




----------------------> 0.5790598290598291
Evaluated Accuracy: 0.5791
Testing b_a_r = 0.9608, c_a_r = 0.0008, f_a_r = 0.0633




----------------------> 0.5762515262515262
Evaluated Accuracy: 0.5763
Testing b_a_r = 0.0151, c_a_r = 0.0492, f_a_r = 0.9565




----------------------> 0.5466422466422466
Evaluated Accuracy: 0.5466
Testing b_a_r = 1.0000, c_a_r = 0.0145, f_a_r = 0.6097




----------------------> 0.6092185592185593
Evaluated Accuracy: 0.6092
Testing b_a_r = 1.0000, c_a_r = 0.0627, f_a_r = 0.4964




----------------------> 0.5992063492063492
Evaluated Accuracy: 0.5992
Testing b_a_r = 0.8168, c_a_r = 0.7339, f_a_r = 0.0000




----------------------> 0.23253968253968257
Evaluated Accuracy: 0.2325
Testing b_a_r = 0.9546, c_a_r = 0.4428, f_a_r = 1.0000




----------------------> 0.35836385836385837
Evaluated Accuracy: 0.3584
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.5069




----------------------> 0.5973137973137973
Evaluated Accuracy: 0.5973
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.7393




----------------------> 0.6095238095238096
Evaluated Accuracy: 0.6095
Testing b_a_r = 0.0090, c_a_r = 0.9953, f_a_r = 0.9983




----------------------> 0.25293040293040286
Evaluated Accuracy: 0.2529
Testing b_a_r = 0.9730, c_a_r = 0.0449, f_a_r = 0.0645




----------------------> 0.583089133089133
Evaluated Accuracy: 0.5831
Testing b_a_r = 0.0162, c_a_r = 0.0186, f_a_r = 0.6475




----------------------> 0.607020757020757
Evaluated Accuracy: 0.6070
Testing b_a_r = 0.9553, c_a_r = 0.1190, f_a_r = 0.3129




----------------------> 0.551037851037851
Evaluated Accuracy: 0.5510
Testing b_a_r = 0.9426, c_a_r = 0.0240, f_a_r = 0.7217




----------------------> 0.6067765567765568
Evaluated Accuracy: 0.6068
Testing b_a_r = 0.0109, c_a_r = 0.0291, f_a_r = 0.3865




----------------------> 0.586080586080586
Evaluated Accuracy: 0.5861
Testing b_a_r = 0.8445, c_a_r = 0.9969, f_a_r = 0.0047




----------------------> 0.2032356532356533
Evaluated Accuracy: 0.2032
Testing b_a_r = 0.1281, c_a_r = 0.5569, f_a_r = 0.5264




----------------------> 0.2992673992673992
Evaluated Accuracy: 0.2993
Testing b_a_r = 0.0175, c_a_r = 0.0683, f_a_r = 0.5796




----------------------> 0.5976190476190476
Evaluated Accuracy: 0.5976
Testing b_a_r = 0.9937, c_a_r = 0.0357, f_a_r = 0.6109




----------------------> 0.6054334554334554
Evaluated Accuracy: 0.6054
Testing b_a_r = 0.9908, c_a_r = 0.0003, f_a_r = 0.7087




----------------------> 0.6117216117216118
Evaluated Accuracy: 0.6117
Testing b_a_r = 0.0801, c_a_r = 0.0004, f_a_r = 0.6051




----------------------> 0.6105006105006106
Evaluated Accuracy: 0.6105
Testing b_a_r = 0.1069, c_a_r = 0.2957, f_a_r = 0.9966




----------------------> 0.34926739926739925
Evaluated Accuracy: 0.3493
Testing b_a_r = 0.6385, c_a_r = 0.1810, f_a_r = 0.0040




----------------------> 0.44865689865689873
Evaluated Accuracy: 0.4487
Testing b_a_r = 0.9214, c_a_r = 0.1002, f_a_r = 0.6162




----------------------> 0.5923076923076923
Evaluated Accuracy: 0.5923
Testing b_a_r = 0.9753, c_a_r = 0.8454, f_a_r = 0.9903




----------------------> 0.2634310134310134
Evaluated Accuracy: 0.2634
Testing b_a_r = 0.4405, c_a_r = 0.5734, f_a_r = 0.0028




----------------------> 0.26288156288156284
Evaluated Accuracy: 0.2629
Testing b_a_r = 0.8321, c_a_r = 0.1378, f_a_r = 0.9910




----------------------> 0.503113553113553
Evaluated Accuracy: 0.5031
Testing b_a_r = 0.9647, c_a_r = 0.3855, f_a_r = 0.5651




----------------------> 0.4234432234432235
Evaluated Accuracy: 0.4234
Testing b_a_r = 0.0418, c_a_r = 0.8385, f_a_r = 0.4476




----------------------> 0.25567765567765566
Evaluated Accuracy: 0.2557
Testing b_a_r = 0.0105, c_a_r = 0.0008, f_a_r = 0.2332




----------------------> 0.5785714285714286
Evaluated Accuracy: 0.5786
Testing b_a_r = 0.0106, c_a_r = 0.0023, f_a_r = 0.7291




----------------------> 0.6075091575091576
Evaluated Accuracy: 0.6075
Testing b_a_r = 0.4295, c_a_r = 0.5613, f_a_r = 0.9997




----------------------> 0.2993284493284493
Evaluated Accuracy: 0.2993
Testing b_a_r = 0.9239, c_a_r = 0.0711, f_a_r = 0.2667




----------------------> 0.5832112332112332
Evaluated Accuracy: 0.5832
Testing b_a_r = 0.9819, c_a_r = 0.0892, f_a_r = 0.7924




----------------------> 0.5762515262515263
Evaluated Accuracy: 0.5763
Testing b_a_r = 0.9463, c_a_r = 0.0004, f_a_r = 0.6586




----------------------> 0.6129426129426129
Evaluated Accuracy: 0.6129
Testing b_a_r = 0.9401, c_a_r = 0.2371, f_a_r = 0.4316




----------------------> 0.41623931623931637
Evaluated Accuracy: 0.4162
Testing b_a_r = 0.0540, c_a_r = 0.8819, f_a_r = 0.0116




----------------------> 0.2177655677655678
Evaluated Accuracy: 0.2178
Testing b_a_r = 0.0540, c_a_r = 0.1000, f_a_r = 0.4545




----------------------> 0.5741758241758241
Evaluated Accuracy: 0.5742

✅ Optimal Values:
   - b_a_r: 0.9463
   - c_a_r: 0.0004
   - f_a_r: 0.6586
📈 Highest Accuracy Achieved: 0.6129


In [75]:
# Model M3: fresh backbone, weights restored from model_own_path.pt, evaluated as the "base" result.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
# map_location lets a checkpoint saved from a GPU session load on a CPU-only machine.
M3.load_state_dict(torch.load('model_own_path.pt', map_location=device))
evaluate2('/home/asufian/Desktop/output_mapii/result_ressnet/base/12_w_5_s_f', test_loader, M3, criterion)

# Model M2: same architecture, weights restored from model_mu_path.pt, evaluated as the "secondary" result.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt', map_location=device))
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/secondary/12_w_5_s_f", test_loader, M2, criterion)

# Update M3's weights from M2 using the ratios found by the optimisation cell.
# The call's result is not captured, so update_model_weights2 presumably
# modifies M3 in place; verify in its definition.
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Evaluate the updated M3.
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/amul/12_w_5_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.5760 ± 0.0509  0.5864 ± 0.0505  0.5760 ± 0.0509




---Accu-----pre----rec---------> 0.4195 ± 0.0390  0.4205 ± 0.0392  0.4195 ± 0.0390
---Accu-----pre----rec---------> 0.6129 ± 0.0528  0.6278 ± 0.0492  0.6129 ± 0.0528


In [76]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



!rm model_mu_path.pt
!rm model_own_path.pt

convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18_35.pth',map_location=torch.device('cpu')))


convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/mapi_sce/ressecondary/model_10/model_B_mapi_10-shot_res.pth',map_location=torch.device('cpu')))

torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [77]:

def evaluate2(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    """
    Evaluate `model` over the notebook's fixed episodes and print summary metrics.

    One episode is run per entry of the notebook-level list `clss_12`; the
    entry at the same index of `clss_support_imagesss_10_shot` is passed along
    with it to the loader. For every episode the per-query predictions are
    compared with the true labels, a confusion matrix is built and handed to
    `calculate_metrics_get_per`, and the episode accuracy, precision and
    recall are collected. At the end a single line with mean ± sample standard
    deviation (ddof=1) of the three metrics is printed.

    Relies on these names being defined in the notebook at call time:
    `clss_12`, `clss_support_imagesss_10_shot`, `device`, `confusion_matrix`,
    `calculate_metrics_get_per`, `np`.

    Args:
        fname: forwarded unchanged to `calculate_metrics_get_per` as `flnm`.
        data_loader: object whose `__getitem__(classes, support_spec)` returns
            `(support_set, query_set)`; each set is an iterable of
            `(images, label)` pairs where `images` can be passed to
            `torch.stack`.
        model: called as `model(support_images, support_labels, query_images)`
            and expected to return one 2-D score tensor per query group.
        criterion: unused; kept so existing calls keep working.

    Returns:
        None. The summary is printed.
    """
    model.eval()
    episode_accuracies = []
    episode_precisions = []
    episode_recalls = []

    with torch.no_grad():
        for episode in range(len(clss_12)):
            episode_classes = clss_12[episode]
            episode_support = clss_support_imagesss_10_shot[episode]
            support_set, query_set = data_loader.__getitem__(episode_classes, episode_support)

            # Human-readable descriptions handed to the metrics helper.
            class_names = str(episode_classes).replace(',', ' ').replace('\'', ' ')
            support_description = str(episode_support)

            # Stack each group's images into one tensor on the target device.
            # Support groups keep a single label, query groups keep all labels.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            predicted_all = []
            actual_all = []
            for group_idx in range(len(classification_scores)):
                # Highest score wins; column index is the predicted class.
                predicted = torch.argmax(classification_scores[group_idx], dim=1).tolist()
                actual = query_labels[group_idx]
                predicted_all.extend(predicted)
                actual_all.extend(actual)
                correct += sum(1 for i in range(len(predicted)) if predicted[i] == actual[i])
                total += len(predicted)

            accuracy = correct / total
            episode_accuracies.append(accuracy)

            confusion_mat = confusion_matrix(actual_all, predicted_all)
            # The helper returns three values; the third one is not used here.
            precision_overall, recall_overall, _ = calculate_metrics_get_per(
                support_description,
                confusion_mat,
                acciuracy=accuracy,
                cls=class_names,
                flnm=fname,
            )
            episode_precisions.append(precision_overall)
            episode_recalls.append(recall_overall)

    print(f"---Accu-----pre----rec---------> {np.mean(episode_accuracies):.4f} ± {np.std(episode_accuracies, ddof=1):.4f}  "
          f"{np.mean(episode_precisions):.4f} ± {np.std(episode_precisions, ddof=1):.4f}  "
          f"{np.mean(episode_recalls):.4f} ± {np.std(episode_recalls, ddof=1):.4f}")

    # print("---Accu-----pre----rec---------->",np.mean(ttlal),'±',np.std(ttlal, ddof=1), np.mean(precision),'±',np.std(precision, ddof=1),np.mean(precision),'±',np.std(precision, ddof=1),)
# %%%%%%

def evaluate3(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    """
    Return the mean per-episode accuracy of `model` over the fixed episodes.

    One episode is run per entry of the notebook-level list `clss_12`, paired
    with the entry at the same index of `clss_support_imagesss_10_shot`.
    Unlike `evaluate2`, no confusion matrix or precision/recall is computed;
    only accuracy is accumulated, printed and returned.

    Relies on these names being defined in the notebook at call time:
    `clss_12`, `clss_support_imagesss_10_shot`, `device`.

    Args:
        fname: unused; kept so existing calls keep working.
        data_loader: object whose `__getitem__(classes, support_spec)` returns
            `(support_set, query_set)`; each set is an iterable of
            `(images, label)` pairs where `images` can be passed to
            `torch.stack`.
        model: called as `model(support_images, support_labels, query_images)`
            and expected to return one 2-D score tensor per query group.
        criterion: unused; kept so existing calls keep working.

    Returns:
        The arithmetic mean of the episode accuracies.

    Raises:
        ZeroDivisionError: if `clss_12` is empty or an episode has no queries.
    """
    model.eval()
    accuracy_sum = 0
    episode_count = 0

    with torch.no_grad():
        for episode in range(len(clss_12)):
            support_set, query_set = data_loader.__getitem__(
                clss_12[episode], clss_support_imagesss_10_shot[episode]
            )

            # Stack each group's images into one tensor on the target device.
            # Support groups keep a single label, query groups keep all labels.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            for group_idx in range(len(classification_scores)):
                # Highest score wins; column index is the predicted class.
                predicted = torch.argmax(classification_scores[group_idx], dim=1).tolist()
                actual = query_labels[group_idx]
                correct += sum(1 for i in range(len(predicted)) if predicted[i] == actual[i])
                total += len(predicted)

            accuracy_sum += correct / total
            episode_count += 1

    mean_accuracy = accuracy_sum / episode_count
    print("---------------------->",mean_accuracy)
    return mean_accuracy

In [78]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Run on the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Search space: one continuous ratio in [0, 1] per adaptation knob,
# in the order b_a_r, c_a_r, f_a_r
search_space = [
    Real(0.0, 1.0, name=ratio_name)
    for ratio_name in ("b_a_r", "c_a_r", "f_a_r")
]

# Define the objective function for Bayesian Optimization
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Score one candidate triple of adaptation ratios for the optimizer.

    The three ratios are passed straight to get_model_accuracy; the value it
    returns is logged and then negated, because the optimizer minimizes its
    objective while we want the highest accuracy.
    """
    measured = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {measured:.4f}")  # Debugging Output

    loss = -measured  # lower is better for the minimizer
    return loss

# Run the Gaussian-process based search over the three adaptation ratios.
# NOTE(review): get_model_accuracy is defined elsewhere; if it scores on the same
# episodes that are used for the final report, the reported accuracy is
# optimistically biased — confirm it uses a held-out split.
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Total number of objective evaluations
    n_initial_points=5,       # Initial random explorations before GP starts
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Ensure reproducibility
    n_jobs=-1,                # Cores for the acquisition-function optimizer (per skopt docs); objective calls still run one at a time
)

# Extract optimal results: res.x follows the order of search_space
# (b_a_r, c_a_r, f_a_r); res.fun is the minimized value, i.e. the negated
# accuracy returned by objective, so the sign is flipped back.
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.5271317829457365
Evaluated Accuracy: 0.5271
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.20484496124031004
Evaluated Accuracy: 0.2048
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.2718346253229974
Evaluated Accuracy: 0.2718
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.6260335917312662
Evaluated Accuracy: 0.6260
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.5808139534883722
Evaluated Accuracy: 0.5808
Testing b_a_r = 0.0250, c_a_r = 1.0000, f_a_r = 0.4630




----------------------> 0.15729974160206722
Evaluated Accuracy: 0.1573
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.6549




----------------------> 0.6393410852713178
Evaluated Accuracy: 0.6393
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.7532




----------------------> 0.6220284237726098
Evaluated Accuracy: 0.6220
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.6106589147286823
Evaluated Accuracy: 0.6107
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.4056




----------------------> 0.6268087855297159
Evaluated Accuracy: 0.6268
Testing b_a_r = 0.0611, c_a_r = 0.0548, f_a_r = 0.4558




----------------------> 0.6344961240310077
Evaluated Accuracy: 0.6345
Testing b_a_r = 0.1539, c_a_r = 0.6651, f_a_r = 1.0000




----------------------> 0.15574935400516798
Evaluated Accuracy: 0.1557
Testing b_a_r = 1.0000, c_a_r = 0.0282, f_a_r = 0.4693




----------------------> 0.6349483204134369
Evaluated Accuracy: 0.6349
Testing b_a_r = 0.1569, c_a_r = 0.0824, f_a_r = 0.0000




----------------------> 0.6124031007751937
Evaluated Accuracy: 0.6124
Testing b_a_r = 0.2035, c_a_r = 0.0970, f_a_r = 0.4011




----------------------> 0.6263565891472866
Evaluated Accuracy: 0.6264
Testing b_a_r = 0.0076, c_a_r = 0.0525, f_a_r = 0.2878




----------------------> 0.6208010335917314
Evaluated Accuracy: 0.6208
Testing b_a_r = 0.3956, c_a_r = 0.7753, f_a_r = 0.0057




----------------------> 0.16149870801033595
Evaluated Accuracy: 0.1615
Testing b_a_r = 1.0000, c_a_r = 0.0401, f_a_r = 0.0000




----------------------> 0.6147932816537468
Evaluated Accuracy: 0.6148
Testing b_a_r = 0.1751, c_a_r = 0.0388, f_a_r = 0.5893




----------------------> 0.6434754521963825
Evaluated Accuracy: 0.6435
Testing b_a_r = 0.1492, c_a_r = 0.3714, f_a_r = 0.9969




----------------------> 0.24257105943152454
Evaluated Accuracy: 0.2426
Testing b_a_r = 0.6375, c_a_r = 0.9990, f_a_r = 0.9989




----------------------> 0.15717054263565894
Evaluated Accuracy: 0.1572
Testing b_a_r = 0.6264, c_a_r = 0.1068, f_a_r = 0.9985




----------------------> 0.572545219638243
Evaluated Accuracy: 0.5725
Testing b_a_r = 0.0081, c_a_r = 0.0970, f_a_r = 0.6010




----------------------> 0.6374677002583979
Evaluated Accuracy: 0.6375
Testing b_a_r = 0.0102, c_a_r = 0.0544, f_a_r = 0.5954




----------------------> 0.6441860465116277
Evaluated Accuracy: 0.6442
Testing b_a_r = 0.7342, c_a_r = 0.9968, f_a_r = 0.0033




----------------------> 0.14709302325581397
Evaluated Accuracy: 0.1471
Testing b_a_r = 0.0464, c_a_r = 0.0427, f_a_r = 0.8073




----------------------> 0.60671834625323
Evaluated Accuracy: 0.6067
Testing b_a_r = 0.5813, c_a_r = 0.1657, f_a_r = 0.0042




----------------------> 0.5501291989664082
Evaluated Accuracy: 0.5501
Testing b_a_r = 0.9808, c_a_r = 0.0668, f_a_r = 0.5532




----------------------> 0.6386950904392763
Evaluated Accuracy: 0.6387
Testing b_a_r = 0.5852, c_a_r = 0.0002, f_a_r = 0.5439




----------------------> 0.6400516795865633
Evaluated Accuracy: 0.6401
Testing b_a_r = 0.0112, c_a_r = 0.1669, f_a_r = 0.4095




----------------------> 0.5753875968992248
Evaluated Accuracy: 0.5754
Testing b_a_r = 0.0862, c_a_r = 0.1126, f_a_r = 0.1815




----------------------> 0.6062661498708011
Evaluated Accuracy: 0.6063
Testing b_a_r = 0.3891, c_a_r = 0.0008, f_a_r = 0.1844




----------------------> 0.6129198966408268
Evaluated Accuracy: 0.6129
Testing b_a_r = 0.0654, c_a_r = 0.6101, f_a_r = 0.4416




----------------------> 0.16188630490956074
Evaluated Accuracy: 0.1619
Testing b_a_r = 0.6681, c_a_r = 0.8311, f_a_r = 0.7533




----------------------> 0.15904392764857883
Evaluated Accuracy: 0.1590
Testing b_a_r = 0.0378, c_a_r = 0.1030, f_a_r = 0.7432




----------------------> 0.6190568475452197
Evaluated Accuracy: 0.6191
Testing b_a_r = 0.0872, c_a_r = 0.6104, f_a_r = 0.0081




----------------------> 0.17609819121447026
Evaluated Accuracy: 0.1761
Testing b_a_r = 0.9992, c_a_r = 0.0240, f_a_r = 0.6151




----------------------> 0.6439922480620156
Evaluated Accuracy: 0.6440
Testing b_a_r = 0.9784, c_a_r = 0.1189, f_a_r = 0.5725




----------------------> 0.625581395348837
Evaluated Accuracy: 0.6256
Testing b_a_r = 0.0115, c_a_r = 0.0168, f_a_r = 0.5925




----------------------> 0.6443152454780362
Evaluated Accuracy: 0.6443
Testing b_a_r = 0.2099, c_a_r = 0.2001, f_a_r = 0.9962




----------------------> 0.484560723514212
Evaluated Accuracy: 0.4846
Testing b_a_r = 0.1020, c_a_r = 0.5163, f_a_r = 0.9966




----------------------> 0.166795865633075
Evaluated Accuracy: 0.1668
Testing b_a_r = 0.1892, c_a_r = 0.0527, f_a_r = 0.1323




----------------------> 0.6162790697674418
Evaluated Accuracy: 0.6163
Testing b_a_r = 0.0038, c_a_r = 0.3817, f_a_r = 0.5534




----------------------> 0.2361111111111111
Evaluated Accuracy: 0.2361
Testing b_a_r = 0.0073, c_a_r = 0.0357, f_a_r = 0.5389




----------------------> 0.6419250645994833
Evaluated Accuracy: 0.6419
Testing b_a_r = 0.8834, c_a_r = 0.0757, f_a_r = 0.6257




----------------------> 0.6386304909560725
Evaluated Accuracy: 0.6386
Testing b_a_r = 0.9956, c_a_r = 0.0008, f_a_r = 0.5586




----------------------> 0.6427002583979327
Evaluated Accuracy: 0.6427
Testing b_a_r = 0.0396, c_a_r = 0.0803, f_a_r = 0.5135




----------------------> 0.6381136950904392
Evaluated Accuracy: 0.6381
Testing b_a_r = 0.9632, c_a_r = 0.0288, f_a_r = 0.6375




----------------------> 0.640891472868217
Evaluated Accuracy: 0.6409
Testing b_a_r = 0.9988, c_a_r = 0.0020, f_a_r = 0.6173




----------------------> 0.6443798449612402
Evaluated Accuracy: 0.6444
Testing b_a_r = 0.9844, c_a_r = 0.0778, f_a_r = 0.2008




----------------------> 0.6161498708010336
Evaluated Accuracy: 0.6161

✅ Optimal Values:
   - b_a_r: 0.9988
   - c_a_r: 0.0020
   - f_a_r: 0.6173
📈 Highest Accuracy Achieved: 0.6444


In [79]:
# Base model (M3): fresh backbone with the classifier head replaced by a
# flatten layer, restored from the checkpoint saved as 'model_own_path.pt'.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
# map_location keeps the checkpoint loadable even when it was saved from a
# device that is not available in the current session.
M3.load_state_dict(torch.load('model_own_path.pt', map_location=device))
# Baseline numbers for the base model. `criterion` is accepted by evaluate2
# but never used inside it.
evaluate2('/home/asufian/Desktop/output_mapii/result_ressnet/base/12_w_10_s_f', test_loader, M3, criterion)

# Secondary model (M2): same architecture, restored from 'model_mu_path.pt'.
convolutional_network_with_dropout =  ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()

xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt', map_location=device))
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/secondary/12_w_10_s_f", test_loader, M2, criterion)

# Update Model M3's weights using a weighted combination of Model M3 and Model M2.
# The three ratios are the values found by the optimisation cell.
# NOTE(review): update_model_weights2 is defined elsewhere; presumably it
# modifies M3 in place — confirm.
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Evaluate the updated model (M3) with the same loader so the printed line can
# be compared with the two baselines above.
evaluate2("/home/asufian/Desktop/output_mapii/result_ressnet/amul/12_w_10_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.6107 ± 0.0480  0.6174 ± 0.0490  0.6107 ± 0.0480




---Accu-----pre----rec---------> 0.3665 ± 0.0401  0.3719 ± 0.0368  0.3665 ± 0.0401
---Accu-----pre----rec---------> 0.6444 ± 0.0512  0.6542 ± 0.0516  0.6444 ± 0.0512
