In [2]:
import torch
import torch.nn as nn
# !pip install easyfsl
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Omniglot
from torchvision.models import resnet18
from tqdm import tqdm

from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average

class PrototypicalNetworks_dynamic_query(nn.Module):
    """Prototypical network that allows a different number of queries per class.

    Each class prototype is the mean backbone embedding of that class's support
    images; every query is scored by the negative Euclidean distance to each
    prototype (higher score = closer prototype).

    Args:
        backbone: Module mapping a batch of images to a batch of feature vectors.
    """

    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks_dynamic_query, self).__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images,
        support_labels,
        query_images: torch.Tensor,
    ) -> "list[torch.Tensor]":
        """Score query images against prototypes built from the support images.

        Args:
            support_images: Indexable of per-class image batches;
                ``support_images[c]`` holds the support images of class ``c``.
            support_labels: Iterable with one entry per class; only its length is used.
            query_images: Iterable of query batches, one batch per class; batches may
                have different sizes.

        Returns:
            A list with one tensor per query batch containing negative Euclidean
            distances to every prototype. Assuming the backbone returns
            ``(batch, features)``, each tensor is ``(batch, num_classes)``.
        """
        # Inputs are moved to the backbone's device so this module does not depend
        # on a global ``device`` variable; a parameter-free backbone leaves them as-is.
        first_param = next(self.backbone.parameters(), None)

        def to_backbone(batch):
            return batch if first_param is None else batch.to(first_param.device)

        num_classes = len(list(support_labels))
        prototypes = torch.stack([
            self.backbone(to_backbone(support_images[c])).mean(dim=0)  # mean over the support batch
            for c in range(num_classes)
        ])

        distances = []
        for class_queries in query_images:
            query_features = self.backbone(to_backbone(class_queries))
            distances.append(-torch.cdist(query_features, prototypes))
        return distances


In [1]:
import os

# Dataset root. Set the OLCHIKI_ROOT environment variable to point elsewhere
# instead of editing this absolute, machine-specific default.
root_path = os.environ.get('OLCHIKI_ROOT', '/home/asufian/Desktop/output_olchiki/olchiki')

In [3]:
import os
import random
from collections import defaultdict
from PIL import Image, ImageOps
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class PrototypicalOmniglotDataset(Dataset):
    """Episodic few-shot dataset laid out as ``root/<alphabet>/<character>/<image>``.

    Every character folder containing at least one image file becomes a class
    named ``"<alphabet>_<character>"``. Images of a class are addressed by their
    position in the sorted file listing of that folder, and episodes are built
    explicitly by ``__getitem__`` from caller-chosen classes and support indices.

    Args:
        root: Dataset root directory.
        num_classes: Stored as ``self.num_classes``; not used by the methods below.
        n_shot: Stored as ``self.n_shot``; not used by the methods below.
        n_query: Stored as ``self.n_query``; not used by the methods below (every
            non-support image of a class becomes a query).
        transform: Optional callable applied to each loaded grayscale PIL image.
    """

    # Recognised image file extensions (matched case-insensitively).
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp')

    def __init__(self, root, num_classes=1623, n_shot=5, n_query=10, transform=None):
        self.root = root
        self.num_classes = num_classes
        self.n_shot = n_shot
        self.n_query = n_query
        self.transform = transform
        self.samples_by_label = defaultdict(list)
        self.all_imgs = {}
        self.classes = []

        # Directory listings are sorted so class order does not depend on the filesystem.
        for alphabet in sorted(os.listdir(self.root)):
            alphabet_path = os.path.join(self.root, alphabet)
            if not os.path.isdir(alphabet_path):
                continue

            for char_class in sorted(os.listdir(alphabet_path)):
                char_class_path = os.path.join(alphabet_path, char_class)
                if not os.path.isdir(char_class_path):
                    continue

                all_images = [
                    os.path.join(char_class_path, img_name)
                    for img_name in sorted(os.listdir(char_class_path))
                    if img_name.lower().endswith(self.IMAGE_EXTENSIONS)
                    and os.path.isfile(os.path.join(char_class_path, img_name))
                ]
                if all_images:
                    char_class_name = f"{alphabet}_{char_class}"
                    self.samples_by_label[char_class_name] = list(range(len(all_images)))
                    self.all_imgs[char_class_name] = all_images
                    self.classes.append(char_class_name)

    def transform_image(self, raw_img):
        """Apply ``self.transform`` to ``raw_img`` if a transform is configured."""
        if self.transform is not None:
            return self.transform(raw_img)
        return raw_img

    def _load_image(self, path):
        """Open ``path`` as a grayscale ('L') image, closing the file, then transform it."""
        with Image.open(path) as img:
            gray = img.convert('L')
        return self.transform_image(gray)

    def __getitem__(self, selected_classes, selected_supports):
        """Build one episode from explicit classes and support indices.

        Args:
            selected_classes: Sequence of class names (keys of ``self.all_imgs``);
                the i-th class receives episode label ``i``.
            selected_supports: ``selected_supports[i]`` lists the image indices used
                as support images for ``selected_classes[i]``.

        Returns:
            ``(support_set, query_set)``: two lists with one ``(images, labels)``
            tuple per class. Every image of a class that is not in its support list
            becomes a query.
        """
        support_set = []
        query_set = []
        for label_id, class_name in enumerate(selected_classes):
            image_paths = self.all_imgs[class_name]
            support_idx = selected_supports[label_id]
            support_lookup = set(support_idx)

            support_images = [self._load_image(image_paths[i]) for i in support_idx]
            query_idx = [i for i in self.samples_by_label[class_name] if i not in support_lookup]
            query_images = [self._load_image(image_paths[i]) for i in query_idx]

            # Label lists are sized from the images actually loaded so they always
            # line up with them, even when len(support_idx) != self.n_shot.
            support_set.append((support_images, [label_id] * len(support_images)))
            query_set.append((query_images, [label_id] * len(query_images)))
        return support_set, query_set

    def __len__(self):
        # Fixed value; episodes are requested explicitly via __getitem__, so this
        # is not derived from the data on disk.
        return 10

# Episode configuration for this PrototypicalOmniglotDataset instance.
num_classes = 5  # classes ("ways") per episode
n_shot = 5       # support images per class
n_query = 10     # query images per class

# Single-channel pipeline: grayscale, then PIL image -> float tensor.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
    ]
)

PrototypicalOmniglotDatasetLoader = PrototypicalOmniglotDataset(
    root=root_path,
    num_classes=num_classes,
    n_shot=n_shot,
    n_query=n_query,
    transform=transform,
)


In [4]:
import torch
import torch.nn as nn

class PrototypicalNetworks_dynamic_query(nn.Module):
    """Prototypical network that allows a different number of queries per class.

    Each class prototype is the mean backbone embedding of that class's support
    images; every query is scored by the negative Euclidean distance to each
    prototype (higher score = closer prototype).

    Args:
        backbone: Module mapping a batch of images to a batch of feature vectors.
    """

    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks_dynamic_query, self).__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images,
        support_labels,
        query_images: torch.Tensor,
    ) -> "list[torch.Tensor]":
        """Score query images against prototypes built from the support images.

        Args:
            support_images: Indexable of per-class image batches;
                ``support_images[c]`` holds the support images of class ``c``.
            support_labels: Iterable with one entry per class; only its length is used.
            query_images: Iterable of query batches, one batch per class; batches may
                have different sizes.

        Returns:
            A list with one tensor per query batch containing negative Euclidean
            distances to every prototype. Assuming the backbone returns
            ``(batch, features)``, each tensor is ``(batch, num_classes)``.
        """
        # Inputs are moved to the backbone's device so this module does not depend
        # on a global ``device`` variable; a parameter-free backbone leaves them as-is.
        first_param = next(self.backbone.parameters(), None)

        def to_backbone(batch):
            return batch if first_param is None else batch.to(first_param.device)

        num_classes = len(list(support_labels))
        prototypes = torch.stack([
            self.backbone(to_backbone(support_images[c])).mean(dim=0)  # mean over the support batch
            for c in range(num_classes)
        ])

        distances = []
        for class_queries in query_images:
            query_features = self.backbone(to_backbone(class_queries))
            distances.append(-torch.cdist(query_features, prototypes))
        return distances


In [5]:
class PrototypicalNetworks33(nn.Module):
    """Prototypical network that embeds every image on its own (batch size 1).

    Each class prototype is the mean of its support embeddings; the forward pass
    returns negative Euclidean distances between every query and every prototype.

    Args:
        backbone: Module applied to one image at a time (with a leading batch dim of 1).
    """

    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks33, self).__init__()
        self.backbone = backbone

    def _embed(self, image: torch.Tensor) -> torch.Tensor:
        # Embed a single un-batched image; the result keeps the leading batch dim of 1.
        return self.backbone(image.unsqueeze(0))

    def forward(
        self,
        support_images,
        support_labels,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """Return a ``(num_queries, num_classes)`` tensor of negative distances.

        Args:
            support_images: ``support_images[c]`` iterates over the un-batched support
                images of class ``c``.
            support_labels: Sized collection with one entry per class; only its length is used.
            query_images: Iterable of un-batched query images.
        """
        num_classes = len(support_labels)
        # Assuming the backbone returns (1, D) per image, prototypes is (num_classes, 1, D).
        prototypes = torch.stack([
            torch.stack([self._embed(img) for img in support_images[c]]).mean(dim=0)
            for c in range(num_classes)
        ])
        query_features = torch.stack([self._embed(img) for img in query_images])

        distances = []
        for query_feature in query_features:
            d = torch.cdist(prototypes, query_feature)  # (num_classes, 1, 1) for (1, D) features
            # reshape(-1) rather than squeeze(): squeeze() would also drop the class
            # axis when num_classes == 1, giving a (num_queries,) result.
            distances.append(d.reshape(-1))
        return -torch.stack(distances)


In [45]:
import pandas as pd
import os
import csv
def save_in_csv(cls, flnm, conf_matrix, acciuracy, precision_overall, recall_overall, f_betas):
    """Write the class list, confusion matrix and summary metrics to a CSV file.

    The file at ``flnm`` is overwritten. ``cls`` is written as one row and
    ``conf_matrix`` as one row per matrix row. The last two rows are a header and
    a value row: accuracy, precision, recall, then one column per entry of
    ``f_betas``, labelled with beta values counting up from -5.
    """
    with open(flnm, 'w', newline='') as csvfile:
        header_row = ['Accuracy', 'Precision', 'Recall']
        value_row = [acciuracy, precision_overall, recall_overall]
        for beta, f_beta in enumerate(f_betas, start=-5):
            header_row.append('F-betas (' + str(beta) + ' )')
            value_row.append(f_beta)

        writer = csv.writer(csvfile)
        writer.writerow(['used Class:'])
        writer.writerow(cls)
        writer.writerow([])
        writer.writerow([])
        writer.writerow(['Confusion Matrix'])
        writer.writerows(conf_matrix)
        # Two empty rows separate the matrix from the metrics table.
        writer.writerow([])
        writer.writerow([])
        writer.writerow(header_row)
        writer.writerow(value_row)


  # print("Metrics have been written to 'metrics.csv' file.")
def save_in_text_file(msg, cls, flnm, conf_matrix, accuracy, precision_overall, recall_overall, f_betas):
    """Append a plain-text metrics report for one run to ``flnm``.

    The report holds an indented ``msg`` heading, the class list, the confusion
    matrix (one space-separated row per line), accuracy/precision/recall, and one
    ``F-betas (beta)`` line per entry of ``f_betas`` with beta counting up from -5.
    """
    with open(flnm, 'a') as report:
        report.write("\n\n\n                                 " + msg)
        report.write("\n\n\n")
        report.write("used Class:\n")
        report.write(str(cls) + "\n\n")

        report.write("Confusion Matrix:\n")
        for matrix_row in conf_matrix:
            report.write(' '.join(map(str, matrix_row)) + "\n")
        report.write("\n")

        report.write("Metrics:\n")
        report.write(f"Accuracy: {accuracy}\n")
        report.write(f"Precision: {precision_overall}\n")
        report.write(f"Recall: {recall_overall}\n")
        for beta, f_beta in enumerate(f_betas, start=-5):
            report.write(f"F-betas ({beta}): {f_beta}\n")

    # print("Metrics have been written to '{}' file.".format(flnm))

def save_in_excel_final(msg, cls, flnm, conf_matrix, accuracy, precision_overall, recall_overall):
    """Record one evaluation run as a new column of the Excel file ``flnm``.

    The run becomes a single column whose row labels are the keys below. If
    ``flnm`` already exists, it is read back (first column used as the index),
    the new column is appended to the right, and the whole sheet is rewritten.
    Otherwise a new file is created with just this column.
    """
    record = pd.DataFrame(
        {
            "Used Class": [cls],
            "Used support images idx": [msg],
            "Confusion Matrix": [conf_matrix],
            "Accuracy": [accuracy],
            "Precision": [precision_overall],
            "Recall": [recall_overall],
        }
    ).T  # one row per field, one column per run

    if not os.path.exists(flnm):
        record.to_excel(flnm)
        return

    previous_runs = pd.read_excel(flnm, index_col=0)
    pd.concat([previous_runs, record], axis=1).to_excel(flnm)

    # print("Metrics have been written to '{}' file.".format(flnm))


In [8]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch
import torch.nn as nn
import torch.hub

# Define a modified version of DenseNet169 with dropout after each convolutional layer
class DenseNet169WithDropout(nn.Module):
    """Wrap a ``torch.hub`` vision model and add dropout after some of its convolutions.

    Args:
        modell: Model name passed to ``torch.hub.load('pytorch/vision:v0.10.0', ...)``.
            The loaded model must expose a ``features`` attribute.
        pretrained: Forwarded to ``torch.hub.load``.
        dr: Probability of the single ``nn.Dropout`` shared by every wrapped conv.

    Only ``nn.Conv2d`` layers that are *direct* children of an ``nn.Sequential``
    child of ``features`` are replaced by ``nn.Sequential(conv, dropout)``.
    Convolutions nested deeper, or held by non-Sequential containers, are left
    untouched.
    """

    def __init__(self, modell, pretrained=False, dr=0.4):
        super(DenseNet169WithDropout, self).__init__()
        self.densenet169 = torch.hub.load('pytorch/vision:v0.10.0', modell, pretrained=pretrained)
        self.dropout = nn.Dropout(p=dr)

        sequential_children = [
            child
            for _, child in self.densenet169.features.named_children()
            if isinstance(child, nn.Sequential)
        ]
        for container in sequential_children:
            conv_names = [
                child_name
                for child_name, layer in container.named_children()
                if isinstance(layer, nn.Conv2d)
            ]
            for conv_name in conv_names:
                # Re-registering under the same name keeps the child order intact.
                setattr(container, conv_name, nn.Sequential(getattr(container, conv_name), self.dropout))

    def forward(self, x):
        return self.densenet169(x)

# Instantiate the modified DenseNet169 model with dropout
# convolutional_network_with_dropout = DenseNet169WithDropout(modell='resnet18',pretrained=False,dr=0.4)

# Example usage:
# output = convolutional_network_with_dropout(input_tensor)

class ResNet18WithDropout(nn.Module):
    """torchvision ResNet-18 with a shared dropout applied to every ``nn.Conv2d`` output.

    Dropout is attached with forward hooks instead of wrapping each convolution in
    an ``nn.Sequential``. Wrapping would rename parameters (``conv1.weight`` ->
    ``conv1.0.weight``) and break existing checkpoints; hooks leave
    ``state_dict()`` keys exactly as for a bare ResNet-18 nested under
    ``resnet18.``. Convolutions are found with a recursive ``modules()`` walk, so
    the ones nested inside the stage sub-modules are included too.

    Args:
        pretrained: Forwarded to ``torchvision.models.resnet18``.
        dr: Dropout probability. As with any ``nn.Dropout``, it is only active in
            training mode; ``dr=0.0`` leaves the network's outputs unchanged.
    """

    def __init__(self, pretrained=False, dr=0.4):
        super(ResNet18WithDropout, self).__init__()
        self.resnet18 = models.resnet18(pretrained=pretrained)
        self.dropout = nn.Dropout(p=dr)

        for module in self.resnet18.modules():
            if isinstance(module, nn.Conv2d):
                module.register_forward_hook(self._dropout_hook)

    def _dropout_hook(self, module, inputs, output):
        # Returning a tensor from a forward hook replaces the conv's output.
        return self.dropout(output)

    def forward(self, x):
        return self.resnet18(x)

# Instantiate the modified ResNet18 model with dropout
# convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.4)


In [102]:
# Fixed evaluation episodes. Each inner list names the 5 classes ("ways") of one
# episode, in the order they receive episode labels 0..4 in
# PrototypicalOmniglotDataset.__getitem__. Names must match the
# f"{alphabet}_{char_class}" keys built by that dataset. A class may appear in
# several episodes (e.g. 'Olchiki_2_80_c17' in the first three).
clss_5=[
  ['Olchiki_2_80_c28', 'Olchiki_2_80_c1', 'Olchiki_2_80_c25', 'Olchiki_2_80_c17', 'Olchiki_2_80_c21'],
    ['Olchiki_2_80_c17', 'Olchiki_2_80_c5', 'Olchiki_2_80_c20', 'Olchiki_2_80_c16', 'Olchiki_2_80_c6'],
    ['Olchiki_2_80_c7', 'Olchiki_2_80_c30', 'Olchiki_2_80_c17', 'Olchiki_2_80_c15', 'Olchiki_2_80_c21'],
    ['Olchiki_2_80_c11', 'Olchiki_2_80_c26', 'Olchiki_2_80_c7', 'Olchiki_2_80_c25', 'Olchiki_2_80_c4'],
    ['Olchiki_2_80_c26', 'Olchiki_2_80_c9', 'Olchiki_2_80_c18', 'Olchiki_2_80_c1', 'Olchiki_2_80_c22'],
    ['Olchiki_2_80_c30', 'Olchiki_2_80_c29', 'Olchiki_2_80_c11', 'Olchiki_2_80_c6', 'Olchiki_2_80_c19'],
    ['Olchiki_2_80_c26', 'Olchiki_2_80_c3', 'Olchiki_2_80_c15', 'Olchiki_2_80_c5', 'Olchiki_2_80_c7'],
    ['Olchiki_2_80_c10', 'Olchiki_2_80_c26', 'Olchiki_2_80_c28', 'Olchiki_2_80_c29', 'Olchiki_2_80_c25'],
    ['Olchiki_2_80_c18', 'Olchiki_2_80_c11', 'Olchiki_2_80_c1', 'Olchiki_2_80_c24', 'Olchiki_2_80_c29'],
    ['Olchiki_2_80_c13', 'Olchiki_2_80_c14', 'Olchiki_2_80_c16', 'Olchiki_2_80_c24', 'Olchiki_2_80_c12'],
     ['Olchiki_2_80_c16', 'Olchiki_2_80_c22', 'Olchiki_2_80_c20', 'Olchiki_2_80_c26', 'Olchiki_2_80_c12'],
    ['Olchiki_2_80_c3', 'Olchiki_2_80_c22', 'Olchiki_2_80_c18', 'Olchiki_2_80_c26', 'Olchiki_2_80_c19']

     ]

# Support image indices for 1-shot evaluation, aligned with clss_5:
# clss_support_imagesss_1_shot[e][i] lists the indices (positions in the sorted
# file listing of that class folder) of the support images for class clss_5[e][i].
# Every other image of that class is used as a query by
# PrototypicalOmniglotDataset.__getitem__.
# NOTE(review): indices go up to 79, so each class presumably has >= 80 images — confirm.
clss_support_imagesss_1_shot=[
 [[24], [37], [78], [65], [37]],
    [[15], [4], [39], [60], [23]],
    [[1], [48], [29], [76], [27]],
    [[71], [58], [35], [24], [3]],
    [[23], [40], [58], [35], [56]],
    [[66], [63], [23], [72], [64]],
    [[21], [45], [37], [8], [63]],
     [[9], [19], [67], [51], [20]],
    [[55], [34], [7], [10], [9]],
    [[47], [25], [48], [10], [54]],
    [[13], [33], [72], [51], [20]],
    [[59], [5], [52], [70], [1]]







]

# Support image indices for 5-shot evaluation, same layout as the 1-shot table
# (episode -> class -> list of image indices). Each list appears to start with the
# corresponding 1-shot index, so the 1-shot supports are a subset of these.
clss_support_imagesss_5_shot=[
    [[24, 17, 73, 60, 18], [37, 38, 24, 4, 19], [78, 45, 50, 6, 14], [65, 62, 41, 33, 55], [37, 27, 61, 65, 10]],
    [[15, 30, 21, 54, 33], [4, 62, 38, 57, 34], [39, 8, 75, 37, 10], [60, 67, 14, 10, 53], [23, 42, 44, 1, 5]],
    [[1, 63, 32, 67, 5], [48, 63, 25, 13, 22], [29, 17, 12, 30, 75], [76, 49, 24, 18, 30], [27, 56, 4, 41, 66]],
    [[71, 78, 55, 22, 61], [58, 70, 0, 54, 17], [35, 31, 26, 16, 30], [24, 35, 49, 32, 30], [3, 50, 39, 18, 22]],
    [[23, 2, 59, 68, 74], [40, 14, 4, 3, 45], [58, 51, 31, 45, 75], [35, 46, 61, 55, 79], [56, 15, 64, 33, 14]],
    [[66, 61, 29, 8, 39], [63, 77, 69, 36, 72], [23, 56, 28, 52, 31], [72, 20, 41, 73, 49], [64, 48, 53, 13, 8]],
    [[21, 17, 74, 56, 30], [45, 37, 63, 30, 8], [37, 42, 5, 52, 50], [8, 46, 52, 36, 20], [63, 26, 43, 8, 17]],
    [[9, 47, 25, 35, 46], [19, 27, 35, 34, 40], [67, 4, 16, 61, 60], [51, 53, 4, 32, 20], [20, 40, 25, 29, 39]],
    [[55, 44, 15, 8, 62], [34, 9, 3, 16, 46], [7, 76, 52, 3, 16], [10, 11, 34, 37, 77], [9, 41, 67, 6, 36]],
     [[47, 49, 72, 36, 9], [25, 41, 71, 40, 57], [48, 24, 27, 41, 52], [10, 34, 11, 59, 7], [54, 55, 11, 64, 16]],
    [[13, 35, 18, 11, 42], [33, 50, 28, 27, 53], [72, 58, 64, 33, 73], [51, 13, 31, 66, 58], [20, 54, 45, 27, 0]],
     [[59, 73, 74, 0, 47], [5, 20, 27, 3, 71], [52, 0, 2, 38, 60], [70, 1, 30, 64, 6], [1, 34, 8, 56, 49]],
]

# Support image indices for 10-shot evaluation, same layout as the 1-shot table
# (episode -> class -> list of image indices). Each list appears to extend the
# corresponding 5-shot list with five more indices.
clss_support_imagesss_10_shot=[
    [[24, 17, 73, 60, 18, 68, 23, 39, 41, 20], [37, 38, 24, 4, 19, 3, 17, 35, 66, 56], [78, 45, 50, 6, 14, 54, 32, 25, 13, 55], [65, 62, 41, 33, 55, 25, 51, 74, 48, 35], [37, 27, 61, 65, 10, 63, 70, 24, 57, 8]],
    [[15, 30, 21, 54, 33, 28, 48, 36, 9, 37], [4, 62, 38, 57, 34, 59, 32, 9, 3, 68], [39, 8, 75, 37, 10, 27, 79, 26, 13, 6], [60, 67, 14, 10, 53, 71, 33, 42, 1, 0], [23, 42, 44, 1, 5, 32, 46, 36, 13, 34]],
    [[1, 63, 32, 67, 5, 74, 43, 51, 68, 8], [48, 63, 25, 13, 22, 4, 1, 68, 62, 38], [29, 17, 12, 30, 75, 31, 67, 11, 68, 37], [76, 49, 24, 18, 30, 6, 38, 26, 15, 7], [27, 56, 4, 41, 66, 28, 73, 50, 29, 26]],
     [[71, 78, 55, 22, 61, 69, 79, 33, 24, 13], [58, 70, 0, 54, 17, 62, 47, 7, 60, 78], [35, 31, 26, 16, 30, 2, 44, 14, 10, 12], [24, 35, 49, 32, 30, 48, 14, 72, 29, 68], [3, 50, 39, 18, 22, 63, 25, 77, 47, 26]],
     [[23, 2, 59, 68, 74, 36, 41, 45, 72, 27], [40, 14, 4, 3, 45, 53, 21, 77, 13, 12], [58, 51, 31, 45, 75, 9, 44, 40, 61, 19], [35, 46, 61, 55, 79, 12, 76, 0, 47, 33], [56, 15, 64, 33, 14, 2, 26, 25, 8, 20]],
    [[66, 61, 29, 8, 39, 53, 59, 10, 25, 24], [63, 77, 69, 36, 72, 15, 60, 58, 16, 49], [23, 56, 28, 52, 31, 2, 6, 21, 7, 68], [72, 20, 41, 73, 49, 35, 21, 34, 29, 50], [64, 48, 53, 13, 8, 54, 11, 9, 29, 65]],
    [[21, 17, 74, 56, 30, 23, 64, 38, 67, 1], [45, 37, 63, 30, 8, 12, 70, 4, 1, 75], [37, 42, 5, 52, 50, 76, 43, 46, 22, 61], [8, 46, 52, 36, 20, 5, 1, 31, 45, 67], [63, 26, 43, 8, 17, 22, 16, 54, 70, 32]],
     [[9, 47, 25, 35, 46, 7, 2, 27, 49, 69], [19, 27, 35, 34, 40, 53, 16, 7, 64, 45], [67, 4, 16, 61, 60, 79, 12, 70, 63, 42], [51, 53, 4, 32, 20, 1, 68, 15, 55, 10], [20, 40, 25, 29, 39, 43, 3, 23, 38, 0]],
    [[55, 44, 15, 8, 62, 77, 58, 71, 69, 30], [34, 9, 3, 16, 46, 2, 8, 76, 23, 32], [7, 76, 52, 3, 16, 28, 19, 18, 41, 26], [10, 11, 34, 37, 77, 48, 9, 23, 4, 71], [9, 41, 67, 6, 36, 56, 49, 10, 53, 45]],
    [[47, 49, 72, 36, 9, 61, 56, 40, 30, 68], [25, 41, 71, 40, 57, 44, 37, 45, 56, 3], [48, 24, 27, 41, 52, 58, 32, 74, 9, 12], [10, 34, 11, 59, 7, 30, 24, 68, 69, 20], [54, 55, 11, 64, 16, 51, 33, 22, 39, 35]],
     [[13, 35, 18, 11, 42, 77, 59, 66, 71, 74], [33, 50, 28, 27, 53, 70, 0, 22, 2, 52], [72, 58, 64, 33, 73, 54, 78, 53, 49, 22], [51, 13, 31, 66, 58, 4, 18, 44, 32, 52], [20, 54, 45, 27, 0, 39, 77, 9, 67, 29]],
    [[59, 73, 74, 0, 47, 42, 39, 19, 48, 75], [5, 20, 27, 3, 71, 10, 19, 22, 11, 25], [52, 0, 2, 38, 60, 6, 25, 37, 46, 18], [70, 1, 30, 64, 6, 29, 50, 8, 34, 38], [1, 34, 8, 56, 49, 53, 36, 5, 15, 32]],

]



In [49]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_proto_model(checkpoint_path):
    """Build a ResNet-18 prototypical network on ``device`` and load ``checkpoint_path``.

    Returns:
        ``(backbone, model)`` where ``model`` is the PrototypicalNetworks_dynamic_query
        wrapping ``backbone``.
    """
    backbone = ResNet18WithDropout(pretrained=False, dr=0.0)
    # NOTE(review): this sets ``fc`` on the wrapper, not on ``backbone.resnet18``.
    # ResNet18WithDropout.forward only calls ``self.resnet18(x)``, so this Flatten
    # is never used in forward. It is kept as-is because the checkpoints were
    # presumably saved with this exact module layout — confirm before changing.
    backbone.fc = nn.Flatten()
    model = PrototypicalNetworks_dynamic_query(backbone).to(device)
    # Weights are mapped to CPU first; load_state_dict copies them onto ``device``.
    model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
    return backbone, model


convolutional_network_with_dropout, model_own = _load_proto_model(
    '/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth'
)
convolutional_network_with_dropout, model_own_1_shot = _load_proto_model(
    '/home/asufian/Desktop/output_olchiki/code/olchiki/ressecondary/model_1/model_B_olchiki_1-shot_res.pth'
)

# Re-save both state dicts under local names for the later cells that reload them.
torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')

In [25]:
# !ls '/home/asufian/Desktop/output_olchiki/'

In [50]:
# Cross-entropy loss with the default mean reduction over the batch.
criterion = nn.CrossEntropyLoss(reduction='mean')

In [51]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
import numpy as np
import numpy as np
import pandas as pd
# root_path = '/content/assamese'
# Evaluation episode configuration (5-way, 1-shot).
num_classes = 5
n_shot = 1
# NOTE: PrototypicalOmniglotDataset.__getitem__ does not read n_query; every
# non-support image of a class is used as a query. It is only recorded on the loader.
n_query = 14
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Deterministic evaluation transform: resize, replicate the gray channel to 3
# channels (presumably because the ResNet backbone expects 3-channel input),
# convert to a tensor and normalise with ImageNet statistics.
transformt_ugh = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

print(f"Episode configuration - num_classes: {num_classes}, n_shot: {n_shot}, n_query: {n_query}")

test_loader = PrototypicalOmniglotDataset(
    root=root_path,
    num_classes=num_classes,
    n_shot=n_shot,
    n_query=n_query,
    transform=transformt_ugh,
)

def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss()):
    """Evaluate ``model`` on the fixed 1-shot episodes and print mean metrics.

    Episodes come from the module-level tables ``clss_5`` (class names) and
    ``clss_support_imagesss_1_shot`` (support indices). For every episode the
    confusion matrix is passed to ``calculate_metrics_get_per`` with ``fname`` as
    the output stem. Finally, the mean accuracy, precision and recall over all
    episodes are printed.

    Args:
        fname: Output file stem forwarded to ``calculate_metrics_get_per``.
        data_loader: Object whose ``__getitem__(classes, supports)`` returns
            ``(support_set, query_set)`` lists of ``(images, labels)`` pairs.
        model: Callable ``model(support_images, support_labels, query_images)``
            returning one score matrix per query batch.
        criterion: Unused; kept for interface compatibility.
    """
    model.eval()
    accuracies, precisions, recalls = [], [], []
    with torch.no_grad():
        for episode in range(len(clss_5)):
            episode_classes = clss_5[episode]
            episode_supports = clss_support_imagesss_1_shot[episode]
            support_set, query_set = data_loader.__getitem__(episode_classes, episode_supports)

            # Class list without commas/quotes, used as a label in the saved report.
            cls_label = str(episode_classes).replace(',', ' ').replace('\'', ' ')
            msg = str(episode_supports)

            support_images = tuple(torch.stack(images).to(device) for images, _ in support_set)
            support_labels = tuple(labels[0] for _, labels in support_set)
            query_images = tuple(torch.stack(images).to(device) for images, _ in query_set)
            query_labels = tuple(labels for _, labels in query_set)

            classification_scores = model(support_images, support_labels, query_images)

            all_predicted_labels, all_actual_labels = [], []
            correct = 0
            for scores, actual in zip(classification_scores, query_labels):
                predicted = torch.argmax(scores, dim=1).tolist()
                correct += sum(p == a for p, a in zip(predicted, actual))
                all_predicted_labels.extend(predicted)
                all_actual_labels.extend(actual)
            accuracy = correct / len(all_predicted_labels)

            # Pin the label set to 0..C-1 so the matrix is always C x C with row and
            # column i meaning class i, even if a class is never predicted or has no queries.
            confusion_mat = confusion_matrix(
                all_actual_labels, all_predicted_labels, labels=list(range(len(support_labels)))
            )
            precision_overall, recall_overall, _ = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=cls_label, flnm=fname
            )
            accuracies.append(accuracy)
            precisions.append(precision_overall)
            recalls.append(recall_overall)

    n_episodes = len(accuracies)
    print("---------------------->", sum(accuracies) / n_episodes,
          sum(precisions) / n_episodes, sum(recalls) / n_episodes)
# %%%%%%%%%%%%%%%%%%%


# Scratch output directory (absolute, machine-specific path).
pp = '/home/asufian/Desktop/output_olchiki/faltu'






Generated random values - num_classes: 5, n_shot: 1, n_query: 14


In [52]:
import numpy as np

# Example 5-class confusion matrix in which every one of 100 samples per class is
# classified correctly (100 on the diagonal, 0 elsewhere).
conf_matrix = np.diag(np.full(5, 100))

def calculate_metrics_get_per(msg, conf_matrix, acciuracy, cls, flnm):
    """Compute macro-averaged precision, recall and F1 from a confusion matrix and save them.

    Rows of ``conf_matrix`` are true classes and columns are predicted classes.
    A per-class score whose denominator is zero counts as 0. Overall accuracy is
    recomputed from the matrix (``acciuracy`` is accepted but not used). Accuracy,
    precision and recall are passed to ``save_in_excel_final`` together with
    ``msg``, ``cls`` and the file name ``flnm + ".xlsx"``.

    Returns:
        ``(precision_overall, recall_overall, f1_overall)``, each the unweighted
        mean over classes.

    Raises:
        ValueError: if ``conf_matrix`` is empty.
    """
    if conf_matrix.size == 0:
        raise ValueError("Confusion matrix is empty!")

    def ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0

    per_class_precision, per_class_recall, per_class_f1 = [], [], []
    for k in range(conf_matrix.shape[0]):
        true_pos = conf_matrix[k, k]
        false_pos = np.sum(conf_matrix[:, k]) - true_pos  # column k minus the hit
        false_neg = np.sum(conf_matrix[k, :]) - true_pos  # row k minus the hit
        class_precision = ratio(true_pos, true_pos + false_pos)
        class_recall = ratio(true_pos, true_pos + false_neg)
        per_class_precision.append(class_precision)
        per_class_recall.append(class_recall)
        per_class_f1.append(ratio(2 * (class_precision * class_recall), class_precision + class_recall))

    precision_overall = np.mean(per_class_precision)
    recall_overall = np.mean(per_class_recall)
    f1_overall = np.mean(per_class_f1)

    total_samples = np.sum(conf_matrix)
    correct_samples = np.sum(np.diag(conf_matrix))
    accuracy_overall = correct_samples / total_samples if total_samples > 0 else 0

    save_in_excel_final(msg, cls, flnm + ".xlsx", conf_matrix, accuracy_overall, precision_overall, recall_overall)

    return precision_overall, recall_overall, f1_overall


# precision, recall, f_beta = calculate_metrics(conf_matrix)
# print(f'Overall Metrics for beta={beta}:')
# print(f'Overall Precision: {precision:.4f}')
# print(f'Overall Recall: {recall:.4f}')
# print(f'Overall F-beta ({f_beta}): ')
# print()


In [53]:


def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss()):
    """Evaluate ``model`` on the fixed 1-shot episodes and print mean ± std metrics.

    Episodes come from the module-level tables ``clss_5`` (class names) and
    ``clss_support_imagesss_1_shot`` (support indices). For every episode the
    confusion matrix is passed to ``calculate_metrics_get_per`` with ``fname`` as
    the output stem. Finally, the mean and sample standard deviation (ddof=1) of
    accuracy, precision and recall over episodes are printed.

    Args:
        fname: Output file stem forwarded to ``calculate_metrics_get_per``.
        data_loader: Object whose ``__getitem__(classes, supports)`` returns
            ``(support_set, query_set)`` lists of ``(images, labels)`` pairs.
        model: Callable ``model(support_images, support_labels, query_images)``
            returning one score matrix per query batch.
        criterion: Unused; kept for interface compatibility.
    """
    model.eval()
    accuracies, precisions, recalls = [], [], []
    with torch.no_grad():
        for episode in range(len(clss_5)):
            episode_classes = clss_5[episode]
            episode_supports = clss_support_imagesss_1_shot[episode]
            support_set, query_set = data_loader.__getitem__(episode_classes, episode_supports)

            # Class list without commas/quotes, used as a label in the saved report.
            cls_label = str(episode_classes).replace(',', ' ').replace('\'', ' ')
            msg = str(episode_supports)

            support_images = tuple(torch.stack(images).to(device) for images, _ in support_set)
            support_labels = tuple(labels[0] for _, labels in support_set)
            query_images = tuple(torch.stack(images).to(device) for images, _ in query_set)
            query_labels = tuple(labels for _, labels in query_set)

            classification_scores = model(support_images, support_labels, query_images)

            all_predicted_labels, all_actual_labels = [], []
            correct = 0
            for scores, actual in zip(classification_scores, query_labels):
                predicted = torch.argmax(scores, dim=1).tolist()
                correct += sum(p == a for p, a in zip(predicted, actual))
                all_predicted_labels.extend(predicted)
                all_actual_labels.extend(actual)
            accuracy = correct / len(all_predicted_labels)

            # Pin the label set to 0..C-1 so the matrix is always C x C with row and
            # column i meaning class i, even if a class is never predicted or has no queries.
            confusion_mat = confusion_matrix(
                all_actual_labels, all_predicted_labels, labels=list(range(len(support_labels)))
            )
            precision_overall, recall_overall, _ = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=cls_label, flnm=fname
            )
            accuracies.append(accuracy)
            precisions.append(precision_overall)
            recalls.append(recall_overall)

    print(f"---Accu-----pre----rec---------> {np.mean(accuracies):.4f} ± {np.std(accuracies, ddof=1):.4f}  "
          f"{np.mean(precisions):.4f} ± {np.std(precisions, ddof=1):.4f}  "
          f"{np.mean(recalls):.4f} ± {np.std(recalls, ddof=1):.4f}")
# %%%%%%

def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss()):
    """Print and return the mean query accuracy over the fixed 1-shot episodes.

    Episodes come from the module-level tables ``clss_5`` (class names) and
    ``clss_support_imagesss_1_shot`` (support indices). Per-episode accuracy is
    the fraction of query images whose highest-scoring class equals their label.
    The returned value is the unweighted mean over episodes. ``fname`` and
    ``criterion`` are accepted for interface compatibility and not used.
    """
    model.eval()
    episode_accuracies = []
    with torch.no_grad():
        for idx in range(len(clss_5)):
            support_set, query_set = data_loader.__getitem__(clss_5[idx], clss_support_imagesss_1_shot[idx])

            support_images, support_labels = zip(
                *[(torch.stack(imgs).to(device), lbls[0]) for imgs, lbls in support_set]
            )
            query_images, query_labels = zip(
                *[(torch.stack(imgs).to(device), lbls) for imgs, lbls in query_set]
            )

            scores_per_class = model(support_images, support_labels, query_images)

            n_correct = 0
            n_total = 0
            for class_idx, class_scores in enumerate(scores_per_class):
                predictions = torch.argmax(class_scores, dim=1).tolist()
                truth = query_labels[class_idx]
                n_correct += sum(1 for i, pred in enumerate(predictions) if pred == truth[i])
                n_total += len(predictions)
            episode_accuracies.append(n_correct / n_total)

    mean_accuracy = sum(episode_accuracies) / len(episode_accuracies)
    print("---------------------->", mean_accuracy)
    return mean_accuracy

In [54]:
def update_model_weights2(model_a, model_b, conv_ratio=0.8, fc_ratio=0.8, bias_ratio=0.8):
    """Blend model_b's parameters into model_a in place.

    For every parameter of ``model_a`` whose name also appears in
    ``model_b.state_dict()`` the new value is ``ratio * b + (1 - ratio) * a``, where:

      * name contains 'weight' and 'conv'              -> ratio = conv_ratio
      * name contains 'weight' and 'fc' or 'linear'    -> ratio = fc_ratio
      * name contains 'weight' but neither of the above -> left unchanged
      * name contains 'bias' (and not 'weight')         -> ratio = bias_ratio

    So a ratio of 0.0 keeps model_a's value and 1.0 copies model_b's.
    Only ``named_parameters()`` are touched; buffers of model_a
    (e.g. BatchNorm running statistics) keep model_a's values.

    Args:
        model_a: module updated in place.
        model_b: module whose state_dict supplies the blend partner.
        conv_ratio: weight of model_b for convolution weights.
        fc_ratio: weight of model_b for fully connected / linear weights.
        bias_ratio: weight of model_b for every matching bias.

    Returns:
        None.
    """
    # Snapshot once: state_dict() rebuilds a dict over every tensor of the model, so
    # calling it (twice) per parameter made this loop quadratic in the parameter count.
    state_b = model_b.state_dict()

    with torch.no_grad():
        for name, param in model_a.named_parameters():
            weight_b = state_b.get(name)
            if weight_b is None:
                continue

            if 'weight' in name:
                if 'conv' in name:
                    ratio = conv_ratio
                elif 'fc' in name or 'linear' in name:
                    ratio = fc_ratio
                else:
                    # Other weights (e.g. normalisation layers whose names contain
                    # neither 'conv' nor 'fc'/'linear') keep model_a's values.
                    continue
            elif 'bias' in name:
                ratio = bias_ratio
            else:
                continue

            param.data = ratio * weight_b + (1 - ratio) * param.data

In [55]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Function to get model accuracy
def get_model_accuracy(b_a_r, c_a_r, f_a_r):
    """Score one (bias, conv, fc) blending configuration for the Bayesian search.

    Rebuilds the base model M3 from 'model_own_path.pt' and the reference model M2
    from 'model_mu_path.pt'. It then blends M2 into M3 with ``update_model_weights2``
    and returns ``evaluate3``'s mean episode accuracy for the blended M3.

    Args:
        b_a_r: bias blend ratio (``bias_ratio``); 0.0 keeps M3's value.
        c_a_r: convolution-weight blend ratio (``conv_ratio``).
        f_a_r: fc/linear-weight blend ratio (``fc_ratio``).

    Returns:
        The accuracy, or 0.0 if anything raised. This is a deliberate worst-case score
        so one failing evaluation does not abort the optimisation. The error is printed;
        a persistent failure (e.g. a missing notebook global) would otherwise just look
        like an all-zero objective, so check the log.
    """

    def _load_protonet(checkpoint_path):
        # Same construction as the cells that saved these checkpoints: the backbone's
        # `.fc` is replaced by Flatten before wrapping it in the prototypical network.
        backbone = ResNet18WithDropout(pretrained=False, dr=0.0)
        backbone.fc = nn.Flatten()
        net = PrototypicalNetworks_dynamic_query(backbone).to(device)
        # map_location keeps loading working if the file was saved on another device.
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
        return net

    try:
        print(f"Testing b_a_r = {b_a_r:.4f}, c_a_r = {c_a_r:.4f}, f_a_r = {f_a_r:.4f}")  # Debugging Output

        M3 = _load_protonet('model_own_path.pt')  # model that gets adapted
        M2 = _load_protonet('model_mu_path.pt')   # reference model (blend partner)

        # Apply bias/conv/fc adaptation to M3 using M2 (in place).
        update_model_weights2(M3, M2, conv_ratio=c_a_r, fc_ratio=f_a_r, bias_ratio=b_a_r)

        # `pp`, `test_loader` and `criterion` are notebook globals defined elsewhere.
        accuracy = evaluate3(pp+'/_olchiki_5-wau-1-shot_matchine_unlearning_own', test_loader, M3, criterion)

        if accuracy is None:
            raise ValueError("Accuracy computation failed. Returning worst case.")

        return accuracy
    except Exception as e:
        print(f"Error encountered: {e}. Assigning worst-case accuracy of 0.")
        return 0.0  # Assign lowest accuracy to prevent optimization failures


In [31]:

# Bayesian optimisation of the three blend ratios consumed by get_model_accuracy /
# update_model_weights2: bias (b_a_r), conv weights (c_a_r) and fc weights (f_a_r).
# Each ratio is the weight given to M2 ('model_mu_path.pt') when blending into
# M3 ('model_own_path.pt'); 0.0 keeps M3's value, 1.0 copies M2's.
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Objective for gp_minimize; use_named_args maps the point list onto keyword arguments.
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Score one point of the search space.

    Loads, blends and evaluates the models via get_model_accuracy and returns the
    negated accuracy, because gp_minimize minimises its objective.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # Debugging Output
    
    return -accuracy  # We minimize, so return negative accuracy

# Run the search. Each objective call rebuilds and evaluates both models, so
# search_iteration dominates the runtime of this cell.
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Total objective evaluations, initial points included
    n_initial_points=5,       # Points evaluated before the GP surrogate starts guiding the search
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Seeds the optimiser's own sampling; any randomness inside the evaluation is separate
    n_jobs=-1,                # Per skopt docs this parallelises only the acquisition (L-BFGS) search; objective calls still run one at a time (see the sequential log below)
)

# Extract optimal results: res.x is the best point found, res.fun its (negated) score.
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun  # undo the negation applied in objective

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.709282700421941
Evaluated Accuracy: 0.7093
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.6369198312236287
Evaluated Accuracy: 0.6369
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.6913502109704642
Evaluated Accuracy: 0.6914
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.7493670886075949
Evaluated Accuracy: 0.7494
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.7261603375527427
Evaluated Accuracy: 0.7262
Testing b_a_r = 0.2729, c_a_r = 0.0992, f_a_r = 0.0805




----------------------> 0.7147679324894515
Evaluated Accuracy: 0.7148
Testing b_a_r = 0.2213, c_a_r = 1.0000, f_a_r = 0.9428




----------------------> 0.4316455696202531
Evaluated Accuracy: 0.4316
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.6845991561181436
Evaluated Accuracy: 0.6846
Testing b_a_r = 0.5807, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.6837552742616034
Evaluated Accuracy: 0.6838
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.5379




----------------------> 0.7158227848101264
Evaluated Accuracy: 0.7158
Testing b_a_r = 0.0000, c_a_r = 0.0834, f_a_r = 1.0000




----------------------> 0.75
Evaluated Accuracy: 0.7500
Testing b_a_r = 0.0376, c_a_r = 0.3473, f_a_r = 1.0000




----------------------> 0.6390295358649789
Evaluated Accuracy: 0.6390
Testing b_a_r = 0.0263, c_a_r = 0.7089, f_a_r = 0.0000




----------------------> 0.4525316455696202
Evaluated Accuracy: 0.4525
Testing b_a_r = 0.9104, c_a_r = 0.2182, f_a_r = 0.0014




----------------------> 0.7248945147679325
Evaluated Accuracy: 0.7249
Testing b_a_r = 0.0740, c_a_r = 0.0958, f_a_r = 0.7043




----------------------> 0.7533755274261603
Evaluated Accuracy: 0.7534
Testing b_a_r = 0.4247, c_a_r = 0.5644, f_a_r = 0.9947




----------------------> 0.49978902953586496
Evaluated Accuracy: 0.4998
Testing b_a_r = 0.4329, c_a_r = 0.9905, f_a_r = 0.0003




----------------------> 0.4088607594936709
Evaluated Accuracy: 0.4089
Testing b_a_r = 0.0018, c_a_r = 0.1504, f_a_r = 0.3743




----------------------> 0.7512658227848101
Evaluated Accuracy: 0.7513
Testing b_a_r = 0.0000, c_a_r = 0.0800, f_a_r = 1.0000




----------------------> 0.7487341772151899
Evaluated Accuracy: 0.7487
Testing b_a_r = 0.6865, c_a_r = 0.0993, f_a_r = 0.4827




----------------------> 0.7527426160337553
Evaluated Accuracy: 0.7527
Testing b_a_r = 0.0058, c_a_r = 0.2149, f_a_r = 0.3343




----------------------> 0.7445147679324894
Evaluated Accuracy: 0.7445
Testing b_a_r = 0.7525, c_a_r = 0.1733, f_a_r = 0.2161




----------------------> 0.7341772151898733
Evaluated Accuracy: 0.7342
Testing b_a_r = 1.0000, c_a_r = 0.0685, f_a_r = 0.5174




----------------------> 0.7409282700421941
Evaluated Accuracy: 0.7409
Testing b_a_r = 0.0000, c_a_r = 0.1029, f_a_r = 0.5818




----------------------> 0.7571729957805907
Evaluated Accuracy: 0.7572
Testing b_a_r = 0.0112, c_a_r = 0.0567, f_a_r = 0.8664




----------------------> 0.7550632911392405
Evaluated Accuracy: 0.7551
Testing b_a_r = 0.9689, c_a_r = 0.0750, f_a_r = 0.8703




----------------------> 0.7611814345991562
Evaluated Accuracy: 0.7612
Testing b_a_r = 0.9827, c_a_r = 0.1054, f_a_r = 0.7541




----------------------> 0.7611814345991562
Evaluated Accuracy: 0.7612
Testing b_a_r = 0.9169, c_a_r = 0.1354, f_a_r = 0.5615




----------------------> 0.7626582278481013
Evaluated Accuracy: 0.7627
Testing b_a_r = 0.9780, c_a_r = 0.3268, f_a_r = 0.5227




----------------------> 0.6523206751054853
Evaluated Accuracy: 0.6523
Testing b_a_r = 0.9296, c_a_r = 0.1780, f_a_r = 0.4858




----------------------> 0.740506329113924
Evaluated Accuracy: 0.7405
Testing b_a_r = 0.0092, c_a_r = 0.0971, f_a_r = 0.7504




----------------------> 0.7533755274261603
Evaluated Accuracy: 0.7534
Testing b_a_r = 0.1064, c_a_r = 0.7952, f_a_r = 0.9940




----------------------> 0.4432489451476793
Evaluated Accuracy: 0.4432
Testing b_a_r = 0.0018, c_a_r = 0.2523, f_a_r = 0.0065




----------------------> 0.6795358649789028
Evaluated Accuracy: 0.6795
Testing b_a_r = 0.7420, c_a_r = 0.1205, f_a_r = 0.9966




----------------------> 0.7453586497890295
Evaluated Accuracy: 0.7454
Testing b_a_r = 0.9813, c_a_r = 0.3830, f_a_r = 0.0000




----------------------> 0.6959915611814348
Evaluated Accuracy: 0.6960
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 1.0000




----------------------> 0.7238396624472573
Evaluated Accuracy: 0.7238
Testing b_a_r = 0.0000, c_a_r = 0.1642, f_a_r = 0.5552




----------------------> 0.7470464135021097
Evaluated Accuracy: 0.7470
Testing b_a_r = 0.7511, c_a_r = 0.0621, f_a_r = 0.9957




----------------------> 0.7556962025316456
Evaluated Accuracy: 0.7557
Testing b_a_r = 0.9643, c_a_r = 0.1087, f_a_r = 0.6230




----------------------> 0.759493670886076
Evaluated Accuracy: 0.7595
Testing b_a_r = 0.1151, c_a_r = 0.4326, f_a_r = 1.0000




----------------------> 0.6388185654008439
Evaluated Accuracy: 0.6388
Testing b_a_r = 0.5576, c_a_r = 0.5575, f_a_r = 0.0002




----------------------> 0.5050632911392404
Evaluated Accuracy: 0.5051
Testing b_a_r = 0.8976, c_a_r = 0.8592, f_a_r = 0.0014




----------------------> 0.43797468354430374
Evaluated Accuracy: 0.4380
Testing b_a_r = 0.1330, c_a_r = 0.2536, f_a_r = 0.9830




----------------------> 0.6200421940928269
Evaluated Accuracy: 0.6200
Testing b_a_r = 0.9848, c_a_r = 0.0300, f_a_r = 0.9727




----------------------> 0.7474683544303797
Evaluated Accuracy: 0.7475
Testing b_a_r = 0.0143, c_a_r = 0.1279, f_a_r = 0.7010




----------------------> 0.7546413502109705
Evaluated Accuracy: 0.7546
Testing b_a_r = 0.0745, c_a_r = 0.6770, f_a_r = 0.9955




----------------------> 0.45105485232067505
Evaluated Accuracy: 0.4511
Testing b_a_r = 0.0572, c_a_r = 0.3867, f_a_r = 0.5196




----------------------> 0.7067510548523207
Evaluated Accuracy: 0.7068
Testing b_a_r = 0.9587, c_a_r = 0.1266, f_a_r = 0.3702




----------------------> 0.7506329113924051
Evaluated Accuracy: 0.7506
Testing b_a_r = 0.0916, c_a_r = 0.0444, f_a_r = 0.0036




----------------------> 0.6991561181434599
Evaluated Accuracy: 0.6992
Testing b_a_r = 0.1036, c_a_r = 0.9006, f_a_r = 0.9947




----------------------> 0.4198312236286921
Evaluated Accuracy: 0.4198

✅ Optimal Values:
   - b_a_r: 0.9169
   - c_a_r: 0.1354
   - f_a_r: 0.5615
📈 Highest Accuracy Achieved: 0.7627


In [56]:
# Compare the base model (M3), the secondary model (M2) and their blend on test_loader.
# Each evaluate2 call prints mean ± std of accuracy/precision/recall and passes the
# path to calculate_metrics_get_per (presumably as an output file prefix — confirm).
# Load Model M3 (base) from the checkpoint saved by an earlier setup cell.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()  # replace the `.fc` head with Flatten
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_olchiki/resnet_resluts/base/5_w_1_s_f', test_loader, M3, criterion)
# Load Model M2 (secondary / reference model)
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/secondary/5_w_1_s_f", test_loader, M2, criterion)
# Blend M2 into M3 in place, using the ratios found by the gp_minimize search cell
# (optimal_* are module-level results of that cell).
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# M3 now holds the blended weights; this third evaluation scores the merged model
# using whichever evaluate2 definition the kernel currently holds.
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/amul/5_w_1_s_f", test_loader, M3, criterion)


# Accuracy-only alternative (no confusion matrix / per-class metrics):
# evaluate3(pp+'/_olchiki_5-wau-1-shot_matchine_unlearning_own', test_loader, M3, criterion)

---Accu-----pre----rec---------> 0.6846 ± 0.0834  0.7269 ± 0.0951  0.6846 ± 0.0834




---Accu-----pre----rec---------> 0.4696 ± 0.0886  0.5225 ± 0.0912  0.4696 ± 0.0886
---Accu-----pre----rec---------> 0.7627 ± 0.0978  0.8044 ± 0.0788  0.7627 ± 0.0978


In [35]:
# Create the results directory used by the "base" evaluations. Unlike `!mkdir`, this is
# idempotent (no error when it already exists), so the notebook survives Run All.
os.makedirs('/home/asufian/Desktop/output_olchiki/resnet_resluts/base', exist_ok=True)

In [21]:
# Inspect saved 1-shot checkpoints under model_5.
# NOTE(review): the recorded output lists five model_B_olchiki_1-shot_*.pth names, which
# is what `ls` of the model_5 directory prints, not `ls` of the single file named here;
# the output is probably from an earlier version of this command — rerun to confirm.
!ls /home/asufian/Desktop/output_olchiki/code/olchiki/model_5/model_B_olchiki_1-shot_30.pth

model_B_olchiki_1-shot_20.pth  model_B_olchiki_1-shot_80.pth
model_B_olchiki_1-shot_30.pth  model_B_olchiki_1-shot_90.pth
model_B_olchiki_1-shot_50.pth


In [None]:
########################  5   way    5  shot    #########################

In [103]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Delete the scratch checkpoints of the previous configuration before rebuilding them,
# so a failure further down leaves no stale weights for later cells to load silently.
# Pure-Python replacement for `!rm`, which printed an error when a file was absent.
for _stale_checkpoint in ('model_mu_path.pt', 'model_own_path.pt'):
    if os.path.exists(_stale_checkpoint):
        os.remove(_stale_checkpoint)

# Base model. The weights are loaded on CPU; load_state_dict copies them into the
# parameters, which already live on `device`.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))

# Secondary model. NOTE: despite its name, model_own_1_shot holds the 5-shot checkpoint
# in this cell (name kept for compatibility with other cells).
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/olchiki/ressecondary/model_5/model_B_olchiki_5-shot_res.pth',map_location=torch.device('cpu')))

# These two files are what get_model_accuracy and the comparison cells reload.
torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [104]:


def _run_fsl_episode(data_loader, model, episode_classes, support_spec):
    """Run one few-shot episode and score its query predictions.

    Fetches the episode with ``data_loader.__getitem__(episode_classes, support_spec)``
    and stacks each class's support/query images onto the global ``device``. It then
    calls ``model(support_images, support_labels, query_images)`` and takes the arg-max
    over dim 1 of each returned score tensor as the predicted class indices.

    Returns:
        (actual_labels, predicted_labels, correct, total): flat label lists over all
        query classes, the number of matching predictions, and the number of predictions.
    """
    support_set, query_set = data_loader.__getitem__(episode_classes, support_spec)

    # Each support entry is (images, labels); the class label is its first label.
    support_images, support_labels = zip(
        *[(torch.stack(images).to(device), labels[0]) for images, labels in support_set]
    )
    query_images, query_labels = zip(
        *[(torch.stack(images).to(device), labels) for images, labels in query_set]
    )

    scores_per_class = model(support_images, support_labels, query_images)

    actual_labels, predicted_labels = [], []
    correct = total = 0
    for scores, labels in zip(scores_per_class, query_labels):
        predicted = torch.argmax(scores, dim=1).tolist()
        predicted_labels.extend(predicted)
        actual_labels.extend(labels)
        # Position-wise comparison over this query class's predictions.
        correct += sum(1 for i, p in enumerate(predicted) if p == labels[i])
        total += len(predicted)
    return actual_labels, predicted_labels, correct, total


def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss(), support_specs=None):
    """Evaluate ``model`` episode by episode and print mean ± std of accuracy/precision/recall.

    One episode is run per entry of the global ``clss_5``; the i-th episode uses
    ``support_specs[i]``. For each episode:
      * a confusion matrix is built with ``confusion_matrix``;
      * it is passed, together with ``fname``, to ``calculate_metrics_get_per``;
      * the precision and recall that helper returns are collected.

    Args:
        fname: forwarded to ``calculate_metrics_get_per`` as ``flnm``.
        data_loader: object exposing ``__getitem__(episode_classes, support_spec)``.
        model: callable as ``model(support_images, support_labels, query_images)``.
        criterion: unused; kept for call compatibility.
        support_specs: per-episode support selections, indexed in step with ``clss_5``.
            Defaults to the global ``clss_support_imagesss_5_shot``, looked up at call time.

    Returns:
        None (results are printed). With a single episode the ddof=1 stds print as nan.
    """
    if support_specs is None:
        support_specs = clss_support_imagesss_5_shot
    model.eval()
    accuracies, precision, recall = [], [], []

    with torch.no_grad():
        for episode_idx, episode_classes in enumerate(clss_5):
            support_spec = support_specs[episode_idx]
            actual, predicted, correct, total = _run_fsl_episode(
                data_loader, model, episode_classes, support_spec
            )
            accuracy = correct / total
            accuracies.append(accuracy)

            # Human-readable descriptors handed to the metrics helper.
            clss = str(episode_classes).replace(',', ' ').replace("'", ' ')
            msg = str(support_spec)

            confusion_mat = confusion_matrix(actual, predicted)
            precision_overall, recall_overall, _f1 = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=clss, flnm=fname
            )
            precision.append(precision_overall)
            recall.append(recall_overall)

    print(f"---Accu-----pre----rec---------> {np.mean(accuracies):.4f} ± {np.std(accuracies, ddof=1):.4f}  "
          f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
          f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")


def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss(), support_specs=None):
    """Return (and print) the mean per-episode query accuracy of ``model``.

    Uses the same episode loop as ``evaluate2`` (one episode per entry of ``clss_5``,
    the i-th using ``support_specs[i]``). It builds no confusion matrix and computes no
    per-class metrics.

    Args:
        fname: unused; kept for call compatibility.
        data_loader: object exposing ``__getitem__(episode_classes, support_spec)``.
        model: callable as ``model(support_images, support_labels, query_images)``.
        criterion: unused; kept for call compatibility.
        support_specs: defaults to the global ``clss_support_imagesss_5_shot``.

    Returns:
        Mean of the per-episode accuracies (raises ZeroDivisionError if ``clss_5`` is empty).
    """
    if support_specs is None:
        support_specs = clss_support_imagesss_5_shot
    model.eval()
    accuracies = []

    with torch.no_grad():
        for episode_idx, episode_classes in enumerate(clss_5):
            _, _, correct, total = _run_fsl_episode(
                data_loader, model, episode_classes, support_specs[episode_idx]
            )
            accuracies.append(correct / total)

    mean_accuracy = sum(accuracies) / len(accuracies)
    print("---------------------->", mean_accuracy)
    return mean_accuracy

In [105]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup (same expression as the earlier search cell)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Repeat of the blend-ratio search for the 5-shot configuration.
# get_model_accuracy looks up evaluate3 and reads 'model_own_path.pt' / 'model_mu_path.pt'
# at call time. This run therefore scores the 5-shot evaluate3 defined in the previous cell
# and the checkpoints written by the 5-shot setup cell.
# Each ratio is the weight given to M2 when blending into M3; 0.0 keeps M3's value.
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Objective for gp_minimize; use_named_args maps the point list onto keyword arguments.
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Score one point of the search space.

    Loads, blends and evaluates the models via get_model_accuracy and returns the
    negated accuracy, because gp_minimize minimises its objective.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # Debugging Output
    
    return -accuracy  # We minimize, so return negative accuracy

# Run the search. Each objective call rebuilds and evaluates both models, so
# search_iteration dominates the runtime of this cell.
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Total objective evaluations, initial points included
    n_initial_points=5,       # Points evaluated before the GP surrogate starts guiding the search
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Seeds the optimiser's own sampling; any randomness inside the evaluation is separate
    n_jobs=-1,                # Per skopt docs this parallelises only the acquisition (L-BFGS) search; objective calls still run one at a time (see the sequential log below)
)

# Extract optimal results: res.x is the best point found, res.fun its (negated) score.
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun  # undo the negation applied in objective

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.8915555555555555
Evaluated Accuracy: 0.8916
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.6746666666666666
Evaluated Accuracy: 0.6747
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.7553333333333333
Evaluated Accuracy: 0.7553
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.8220000000000001
Evaluated Accuracy: 0.8220
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.7737777777777778
Evaluated Accuracy: 0.7738
Testing b_a_r = 0.2849, c_a_r = 0.9362, f_a_r = 0.8249




----------------------> 0.39888888888888885
Evaluated Accuracy: 0.3989
Testing b_a_r = 0.0000, c_a_r = 0.1520, f_a_r = 0.0000




----------------------> 0.8295555555555554
Evaluated Accuracy: 0.8296
Testing b_a_r = 0.0000, c_a_r = 0.1871, f_a_r = 1.0000




----------------------> 0.8804444444444445
Evaluated Accuracy: 0.8804
Testing b_a_r = 1.0000, c_a_r = 0.1743, f_a_r = 1.0000




----------------------> 0.8857777777777777
Evaluated Accuracy: 0.8858
Testing b_a_r = 1.0000, c_a_r = 0.2652, f_a_r = 1.0000




----------------------> 0.7804444444444445
Evaluated Accuracy: 0.7804
Testing b_a_r = 0.0000, c_a_r = 0.1521, f_a_r = 0.7508




----------------------> 0.9024444444444445
Evaluated Accuracy: 0.9024
Testing b_a_r = 0.0391, c_a_r = 0.6947, f_a_r = 0.0000




----------------------> 0.5213333333333333
Evaluated Accuracy: 0.5213
Testing b_a_r = 0.0949, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.7597777777777778
Evaluated Accuracy: 0.7598
Testing b_a_r = 0.5598, c_a_r = 0.5713, f_a_r = 0.9978




----------------------> 0.48266666666666663
Evaluated Accuracy: 0.4827
Testing b_a_r = 0.0682, c_a_r = 0.1295, f_a_r = 0.9830




----------------------> 0.8877777777777777
Evaluated Accuracy: 0.8878
Testing b_a_r = 0.0707, c_a_r = 0.1666, f_a_r = 0.4830




----------------------> 0.8817777777777779
Evaluated Accuracy: 0.8818
Testing b_a_r = 0.4329, c_a_r = 0.9905, f_a_r = 0.0003




----------------------> 0.42333333333333334
Evaluated Accuracy: 0.4233
Testing b_a_r = 0.0669, c_a_r = 0.2332, f_a_r = 0.0068




----------------------> 0.8011111111111111
Evaluated Accuracy: 0.8011
Testing b_a_r = 0.9704, c_a_r = 0.1249, f_a_r = 0.5790




----------------------> 0.8702222222222221
Evaluated Accuracy: 0.8702
Testing b_a_r = 0.0380, c_a_r = 0.1504, f_a_r = 0.8636




----------------------> 0.8993333333333333
Evaluated Accuracy: 0.8993
Testing b_a_r = 0.0576, c_a_r = 0.2224, f_a_r = 0.5273




----------------------> 0.8611111111111112
Evaluated Accuracy: 0.8611
Testing b_a_r = 0.0316, c_a_r = 0.1690, f_a_r = 0.7041




----------------------> 0.9015555555555554
Evaluated Accuracy: 0.9016
Testing b_a_r = 0.0032, c_a_r = 0.0957, f_a_r = 0.3067




----------------------> 0.8077777777777779
Evaluated Accuracy: 0.8078
Testing b_a_r = 0.6267, c_a_r = 0.7597, f_a_r = 0.9937




----------------------> 0.43977777777777766
Evaluated Accuracy: 0.4398
Testing b_a_r = 0.0202, c_a_r = 0.3942, f_a_r = 0.6623




----------------------> 0.7928888888888888
Evaluated Accuracy: 0.7929
Testing b_a_r = 0.4951, c_a_r = 0.4127, f_a_r = 0.9965




----------------------> 0.7357777777777778
Evaluated Accuracy: 0.7358
Testing b_a_r = 0.0702, c_a_r = 0.8360, f_a_r = 0.0027




----------------------> 0.49222222222222217
Evaluated Accuracy: 0.4922
Testing b_a_r = 0.2489, c_a_r = 0.0837, f_a_r = 0.9916




----------------------> 0.8466666666666667
Evaluated Accuracy: 0.8467
Testing b_a_r = 0.9907, c_a_r = 0.0003, f_a_r = 0.4554




----------------------> 0.7800000000000001
Evaluated Accuracy: 0.7800
Testing b_a_r = 0.1755, c_a_r = 0.3092, f_a_r = 0.6061




----------------------> 0.7497777777777777
Evaluated Accuracy: 0.7498
Testing b_a_r = 0.9723, c_a_r = 0.1554, f_a_r = 0.7745




----------------------> 0.8935555555555555
Evaluated Accuracy: 0.8936
Testing b_a_r = 0.0099, c_a_r = 0.1489, f_a_r = 0.9788




----------------------> 0.8973333333333332
Evaluated Accuracy: 0.8973
Testing b_a_r = 0.8877, c_a_r = 0.5632, f_a_r = 0.0014




----------------------> 0.5482222222222222
Evaluated Accuracy: 0.5482
Testing b_a_r = 0.9836, c_a_r = 0.1947, f_a_r = 0.3364




----------------------> 0.846
Evaluated Accuracy: 0.8460
Testing b_a_r = 0.0568, c_a_r = 0.1257, f_a_r = 0.7582




----------------------> 0.8895555555555555
Evaluated Accuracy: 0.8896
Testing b_a_r = 0.8867, c_a_r = 0.4632, f_a_r = 0.6217




----------------------> 0.6113333333333334
Evaluated Accuracy: 0.6113
Testing b_a_r = 0.5786, c_a_r = 0.3535, f_a_r = 0.9981




----------------------> 0.7831111111111112
Evaluated Accuracy: 0.7831
Testing b_a_r = 0.0189, c_a_r = 0.1686, f_a_r = 0.8311




----------------------> 0.8968888888888888
Evaluated Accuracy: 0.8969
Testing b_a_r = 0.9414, c_a_r = 0.1351, f_a_r = 0.9579




----------------------> 0.8831111111111113
Evaluated Accuracy: 0.8831
Testing b_a_r = 0.0052, c_a_r = 0.1477, f_a_r = 0.7102




----------------------> 0.900888888888889
Evaluated Accuracy: 0.9009
Testing b_a_r = 0.3607, c_a_r = 0.3856, f_a_r = 0.0337




----------------------> 0.7826666666666667
Evaluated Accuracy: 0.7827
Testing b_a_r = 0.9629, c_a_r = 0.2089, f_a_r = 0.9346




----------------------> 0.8624444444444445
Evaluated Accuracy: 0.8624
Testing b_a_r = 0.5581, c_a_r = 0.9981, f_a_r = 0.9764




----------------------> 0.3908888888888889
Evaluated Accuracy: 0.3909
Testing b_a_r = 0.0200, c_a_r = 0.1938, f_a_r = 0.6450




----------------------> 0.8815555555555555
Evaluated Accuracy: 0.8816
Testing b_a_r = 0.0134, c_a_r = 0.1476, f_a_r = 0.5540




----------------------> 0.8871111111111111
Evaluated Accuracy: 0.8871
Testing b_a_r = 0.0038, c_a_r = 0.1600, f_a_r = 0.9287




----------------------> 0.8980000000000002
Evaluated Accuracy: 0.8980
Testing b_a_r = 0.9128, c_a_r = 0.0559, f_a_r = 0.0033




----------------------> 0.7793333333333333
Evaluated Accuracy: 0.7793
Testing b_a_r = 0.8765, c_a_r = 0.6686, f_a_r = 0.9963




----------------------> 0.44511111111111107
Evaluated Accuracy: 0.4451
Testing b_a_r = 0.0069, c_a_r = 0.1618, f_a_r = 0.7130




----------------------> 0.9017777777777777
Evaluated Accuracy: 0.9018
Testing b_a_r = 0.0415, c_a_r = 0.1441, f_a_r = 0.8108




----------------------> 0.8984444444444445
Evaluated Accuracy: 0.8984

✅ Optimal Values:
   - b_a_r: 0.0000
   - c_a_r: 0.1521
   - f_a_r: 0.7508
📈 Highest Accuracy Achieved: 0.9024


In [None]:
# b_a_r: 
#    - c_a_r: 
#    - f_a_r: 

#  pt=20

In [106]:

# 5-shot comparison: base model (M3), secondary model (M2) and their blend on test_loader.
# Each evaluate2 call prints mean ± std of accuracy/precision/recall and passes the
# path to calculate_metrics_get_per (presumably as an output file prefix — confirm).
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()  # replace the `.fc` head with Flatten
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_olchiki/resnet_resluts/base/5_w_5_s_f_ff', test_loader, M3, criterion)
# Load Model M2 (secondary / reference model, 5-shot checkpoint)
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/secondary/5_w_5_s_f_ff", test_loader, M2, criterion)
# Blend M2 into M3 in place, using the ratios from the 5-shot gp_minimize search cell.
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# M3 now holds the blended weights; this third evaluation scores the merged model.
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/amul/5_w_5_s_f_ff", test_loader, M3, criterion)




# Accuracy-only alternative (no confusion matrix / per-class metrics):
# evaluate3(pp+'/_olchiki_5-wau-1-shot_matchine_unlearning_own', test_loader, M3, criterion)



---Accu-----pre----rec---------> 0.7591 ± 0.0982  0.7797 ± 0.0954  0.7591 ± 0.0982




---Accu-----pre----rec---------> 0.6169 ± 0.1020  0.6402 ± 0.1033  0.6169 ± 0.1020
---Accu-----pre----rec---------> 0.9024 ± 0.0558  0.9085 ± 0.0513  0.9024 ± 0.0558


In [33]:
# List the checkpoints under olchiki/model_10.
# NOTE(review): the listing shows only 1-shot files; the 10-shot setup cell below loads its
# secondary checkpoint from olchiki/ressecondary/model_10 instead.
!ls /home/asufian/Desktop/output_olchiki/code/olchiki/model_10

model_B_olchiki_1-shot_20.pth  model_B_olchiki_1-shot_70.pth
model_B_olchiki_1-shot_40.pth  model_B_olchiki_1-shot_90.pth
model_B_olchiki_1-shot_50.pth


In [71]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Delete the scratch checkpoints of the previous configuration before rebuilding them,
# so a failure further down leaves no stale weights for later cells to load silently.
# Pure-Python replacement for `!rm`, which printed an error when a file was absent.
for _stale_checkpoint in ('model_mu_path.pt', 'model_own_path.pt'):
    if os.path.exists(_stale_checkpoint):
        os.remove(_stale_checkpoint)

# Base model. The weights are loaded on CPU; load_state_dict copies them into the
# parameters, which already live on `device`.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


# Secondary model. NOTE: despite its name, model_own_1_shot holds the 10-shot checkpoint
# in this cell (name kept for compatibility with other cells).
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/olchiki/ressecondary/model_10/model_B_olchiki_10-shot_res.pth',map_location=torch.device('cpu')))


# These two files are what get_model_accuracy and the comparison cells reload.
torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')

In [72]:


def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              class_sets=None, support_sets=None):
    """Evaluate a prototypical network on a fixed list of few-shot episodes.

    For each episode ``e`` the loader builds the support/query split from the
    class names ``class_sets[e]`` and the support indices ``support_sets[e]``.
    Query images are scored by ``model`` (one score tensor per query class).
    Predictions are the argmax over prototypes, and from them the per-episode
    accuracy and confusion matrix are computed. These are passed to
    ``calculate_metrics_get_per``, which receives ``fname`` as ``flnm`` and
    returns ``(precision, recall, f1)``.

    After all episodes, the mean ± sample std of accuracy, precision and
    recall is printed.

    Args:
        fname: output path prefix forwarded to ``calculate_metrics_get_per``.
        data_loader: object whose ``__getitem__(class_names, support_indices)``
            returns ``(support_set, query_set)``, each iterated as
            ``(images, labels)`` pairs.
        model: callable as ``model(support_images, support_labels, query_images)``.
        criterion: unused; kept so existing call sites keep working.
        class_sets: per-episode class-name lists. Defaults to the global
            ``clss_5`` (looked up at call time).
        support_sets: per-episode support indices aligned with ``class_sets``.
            Defaults to the global ``clss_support_imagesss_10_shot``.
    """
    if class_sets is None:
        class_sets = clss_5
    if support_sets is None:
        support_sets = clss_support_imagesss_10_shot

    def _mean_std(values):
        # Sample std (ddof=1) is undefined for fewer than two episodes; report
        # 0.0 there instead of printing nan with a RuntimeWarning.
        spread = np.std(values, ddof=1) if len(values) > 1 else 0.0
        return np.mean(values), spread

    model.eval()
    accuracies = []
    precision = []
    recall = []

    with torch.no_grad():
        for episode, class_names in enumerate(class_sets):
            support_indices = support_sets[episode]
            support_set, query_set = data_loader.__getitem__(class_names, support_indices)

            # Human-readable tags handed to the metrics writer.
            clss = str(class_names).replace(',', ' ').replace('\'', ' ')
            msg = str(support_indices)

            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            all_predicted_labels = []
            all_actual_labels = []
            for scores, actual in zip(classification_scores, query_labels):
                all_predicted_labels.extend(torch.argmax(scores, dim=1).tolist())
                all_actual_labels.extend(actual)

            total = len(all_predicted_labels)
            correct = sum(int(p == a) for p, a in zip(all_predicted_labels, all_actual_labels))
            accuracy = correct / total
            accuracies.append(accuracy)

            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            precision_overall, recall_overall, _ = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=clss, flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)

    acc_mean, acc_std = _mean_std(accuracies)
    pre_mean, pre_std = _mean_std(precision)
    rec_mean, rec_std = _mean_std(recall)
    print(f"---Accu-----pre----rec---------> {acc_mean:.4f} ± {acc_std:.4f}  "
          f"{pre_mean:.4f} ± {pre_std:.4f}  "
          f"{rec_mean:.4f} ± {rec_std:.4f}")
# %%%%%%

def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              class_sets=None, support_sets=None):
    """Return the mean per-episode query accuracy over a fixed list of episodes.

    This uses the same episode loop as ``evaluate2``, but only accuracy is
    computed. No confusion matrix is built and no metrics file is written.
    The mean accuracy is printed and returned.

    Args:
        fname: unused; kept so existing call sites keep working.
        data_loader: object whose ``__getitem__(class_names, support_indices)``
            returns ``(support_set, query_set)``, each iterated as
            ``(images, labels)`` pairs.
        model: callable as ``model(support_images, support_labels, query_images)``.
        criterion: unused; kept for signature compatibility.
        class_sets: per-episode class-name lists. Defaults to the global
            ``clss_5`` (looked up at call time).
        support_sets: per-episode support indices aligned with ``class_sets``.
            Defaults to the global ``clss_support_imagesss_10_shot``.

    Returns:
        The arithmetic mean of the per-episode accuracies.

    Raises:
        ValueError: if there are no episodes to evaluate.
    """
    if class_sets is None:
        class_sets = clss_5
    if support_sets is None:
        support_sets = clss_support_imagesss_10_shot

    model.eval()
    accuracies = []
    with torch.no_grad():
        for episode, class_names in enumerate(class_sets):
            support_set, query_set = data_loader.__getitem__(class_names, support_sets[episode])

            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            for scores, actual in zip(classification_scores, query_labels):
                predicted = torch.argmax(scores, dim=1).tolist()
                correct += sum(int(p == a) for p, a in zip(predicted, actual))
                total += len(predicted)
            accuracies.append(correct / total)

    if not accuracies:
        raise ValueError("evaluate3: no episodes to evaluate (class_sets is empty)")
    # Sequential sum / count, matching the original running-total arithmetic.
    mean_accuracy = sum(accuracies) / len(accuracies)
    print("---------------------->", mean_accuracy)
    return mean_accuracy

In [73]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Search space: one ratio in [0, 1] per parameter group blended by
# update_model_weights2 (b = bias, c = convolution, f = fully-connected).
# The names must match objective()'s keyword arguments for use_named_args.
search_space = [Real(0.0, 1.0, name=ratio_name) for ratio_name in ("b_a_r", "c_a_r", "f_a_r")]

# Objective for Bayesian optimisation: gp_minimize minimises, so the
# function returns the negated accuracy obtained with the proposed ratios.
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Score one candidate set of adaptation ratios.

    use_named_args maps the point proposed by gp_minimize onto the keyword
    arguments b_a_r / c_a_r / f_a_r (the names declared in search_space).
    Returns -accuracy so that minimising the objective maximises accuracy.
    """
    score = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {score:.4f}")  # one progress line per evaluation
    loss = -score
    return loss

# Run Bayesian optimisation over the three adaptation ratios.
# NOTE(review): get_model_accuracy is defined outside this view. If it scores
# the blended model on the same test_loader used for the final evaluate2
# reports, the ratios are selected on the episodes they are reported on,
# which inflates the reported accuracy. Confirm this, and consider tuning on
# a separate validation split.
search_iteration=50  # total objective evaluations, including the initial random points
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Number of evaluations
    n_initial_points=5,       # Initial random explorations before GP starts
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Seeds the optimiser's random sampling (the initial points repeat across runs)
    n_jobs=-1,                # NOTE(review): per skopt docs this parallelises only the acquisition optimiser; objective() calls still run one at a time
)

# res.x is the best (b_a_r, c_a_r, f_a_r) point found, in search_space order;
# res.fun is its objective value, i.e. the negated accuracy.
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.9226190476190478
Evaluated Accuracy: 0.9226
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.6683333333333334
Evaluated Accuracy: 0.6683
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.7997619047619047
Evaluated Accuracy: 0.7998
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.9140476190476191
Evaluated Accuracy: 0.9140
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.8707142857142857
Evaluated Accuracy: 0.8707
Testing b_a_r = 0.0250, c_a_r = 1.0000, f_a_r = 0.4630




----------------------> 0.4033333333333334
Evaluated Accuracy: 0.4033
Testing b_a_r = 0.0000, c_a_r = 0.1267, f_a_r = 0.0000




----------------------> 0.8688095238095238
Evaluated Accuracy: 0.8688
Testing b_a_r = 0.0000, c_a_r = 0.1547, f_a_r = 1.0000




----------------------> 0.9216666666666667
Evaluated Accuracy: 0.9217
Testing b_a_r = 0.0000, c_a_r = 0.1296, f_a_r = 0.7093




----------------------> 0.9314285714285714
Evaluated Accuracy: 0.9314
Testing b_a_r = 0.1898, c_a_r = 0.3402, f_a_r = 1.0000




----------------------> 0.755952380952381
Evaluated Accuracy: 0.7560
Testing b_a_r = 0.1925, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.809047619047619
Evaluated Accuracy: 0.8090
Testing b_a_r = 1.0000, c_a_r = 0.1202, f_a_r = 0.8362




----------------------> 0.9276190476190475
Evaluated Accuracy: 0.9276
Testing b_a_r = 0.9068, c_a_r = 0.1426, f_a_r = 0.4617




----------------------> 0.91
Evaluated Accuracy: 0.9100
Testing b_a_r = 0.0560, c_a_r = 0.1050, f_a_r = 1.0000




----------------------> 0.9321428571428573
Evaluated Accuracy: 0.9321
Testing b_a_r = 0.7112, c_a_r = 0.1202, f_a_r = 0.9946




----------------------> 0.9250000000000002
Evaluated Accuracy: 0.9250
Testing b_a_r = 0.0187, c_a_r = 0.1152, f_a_r = 0.8502




----------------------> 0.9319047619047621
Evaluated Accuracy: 0.9319
Testing b_a_r = 0.0290, c_a_r = 0.1331, f_a_r = 0.7392




----------------------> 0.9316666666666666
Evaluated Accuracy: 0.9317
Testing b_a_r = 0.2356, c_a_r = 0.7142, f_a_r = 0.9959




----------------------> 0.46309523809523806
Evaluated Accuracy: 0.4631
Testing b_a_r = 0.0090, c_a_r = 0.2089, f_a_r = 0.4416




----------------------> 0.9126190476190477
Evaluated Accuracy: 0.9126
Testing b_a_r = 0.4374, c_a_r = 0.7738, f_a_r = 0.0021




----------------------> 0.49523809523809526
Evaluated Accuracy: 0.4952
Testing b_a_r = 0.0903, c_a_r = 0.0758, f_a_r = 0.9951




----------------------> 0.9335714285714284
Evaluated Accuracy: 0.9336
Testing b_a_r = 0.5260, c_a_r = 0.2268, f_a_r = 0.0036




----------------------> 0.8754761904761906
Evaluated Accuracy: 0.8755
Testing b_a_r = 0.0991, c_a_r = 0.2232, f_a_r = 0.9552




----------------------> 0.8971428571428574
Evaluated Accuracy: 0.8971
Testing b_a_r = 0.0537, c_a_r = 0.0917, f_a_r = 0.5867




----------------------> 0.9221428571428573
Evaluated Accuracy: 0.9221
Testing b_a_r = 0.8187, c_a_r = 0.5421, f_a_r = 0.9992




----------------------> 0.64
Evaluated Accuracy: 0.6400
Testing b_a_r = 0.5503, c_a_r = 0.6149, f_a_r = 0.0004




----------------------> 0.6852380952380952
Evaluated Accuracy: 0.6852
Testing b_a_r = 0.9774, c_a_r = 0.2527, f_a_r = 0.5390




----------------------> 0.8890476190476191
Evaluated Accuracy: 0.8890
Testing b_a_r = 0.1453, c_a_r = 0.0908, f_a_r = 0.8341




----------------------> 0.9330952380952381
Evaluated Accuracy: 0.9331
Testing b_a_r = 0.7326, c_a_r = 0.8981, f_a_r = 0.9970




----------------------> 0.3928571428571428
Evaluated Accuracy: 0.3929
Testing b_a_r = 0.2564, c_a_r = 0.0456, f_a_r = 0.9987




----------------------> 0.907857142857143
Evaluated Accuracy: 0.9079
Testing b_a_r = 0.9785, c_a_r = 0.0870, f_a_r = 0.9764




----------------------> 0.9283333333333332
Evaluated Accuracy: 0.9283
Testing b_a_r = 0.0125, c_a_r = 0.0899, f_a_r = 0.8741




----------------------> 0.9352380952380953
Evaluated Accuracy: 0.9352
Testing b_a_r = 0.0180, c_a_r = 0.1585, f_a_r = 0.7272




----------------------> 0.9240476190476189
Evaluated Accuracy: 0.9240
Testing b_a_r = 0.0006, c_a_r = 0.1012, f_a_r = 0.9045




----------------------> 0.9345238095238094
Evaluated Accuracy: 0.9345
Testing b_a_r = 0.9946, c_a_r = 0.0888, f_a_r = 0.8290




----------------------> 0.9290476190476191
Evaluated Accuracy: 0.9290
Testing b_a_r = 0.2527, c_a_r = 0.0650, f_a_r = 0.0147




----------------------> 0.844047619047619
Evaluated Accuracy: 0.8440
Testing b_a_r = 0.1945, c_a_r = 0.0033, f_a_r = 0.6007




----------------------> 0.8557142857142858
Evaluated Accuracy: 0.8557
Testing b_a_r = 0.0446, c_a_r = 0.0881, f_a_r = 0.9712




----------------------> 0.9345238095238096
Evaluated Accuracy: 0.9345
Testing b_a_r = 0.9333, c_a_r = 0.1847, f_a_r = 0.9966




----------------------> 0.9166666666666666
Evaluated Accuracy: 0.9167
Testing b_a_r = 0.9319, c_a_r = 0.1846, f_a_r = 0.3268




----------------------> 0.9030952380952381
Evaluated Accuracy: 0.9031
Testing b_a_r = 0.0706, c_a_r = 0.2769, f_a_r = 0.0061




----------------------> 0.860952380952381
Evaluated Accuracy: 0.8610
Testing b_a_r = 0.2778, c_a_r = 0.9150, f_a_r = 0.0167




----------------------> 0.4423809523809523
Evaluated Accuracy: 0.4424
Testing b_a_r = 0.0118, c_a_r = 0.1040, f_a_r = 0.7799




----------------------> 0.9364285714285715
Evaluated Accuracy: 0.9364
Testing b_a_r = 0.0020, c_a_r = 0.0728, f_a_r = 0.8171




----------------------> 0.9319047619047618
Evaluated Accuracy: 0.9319
Testing b_a_r = 0.5542, c_a_r = 0.4446, f_a_r = 0.9982




----------------------> 0.6404761904761905
Evaluated Accuracy: 0.6405
Testing b_a_r = 0.0222, c_a_r = 0.1825, f_a_r = 0.6951




----------------------> 0.9171428571428573
Evaluated Accuracy: 0.9171
Testing b_a_r = 0.9805, c_a_r = 0.1054, f_a_r = 0.7276




----------------------> 0.9321428571428573
Evaluated Accuracy: 0.9321
Testing b_a_r = 0.0067, c_a_r = 0.0978, f_a_r = 0.7058




----------------------> 0.9330952380952383
Evaluated Accuracy: 0.9331
Testing b_a_r = 0.9910, c_a_r = 0.0809, f_a_r = 0.9840




----------------------> 0.9288095238095239
Evaluated Accuracy: 0.9288
Testing b_a_r = 0.0145, c_a_r = 0.1022, f_a_r = 0.9654




----------------------> 0.9340476190476191
Evaluated Accuracy: 0.9340

✅ Optimal Values:
   - b_a_r: 0.0118
   - c_a_r: 0.1040
   - f_a_r: 0.7799
📈 Highest Accuracy Achieved: 0.9364


In [74]:
# Reload the base-model snapshot as M3 and report its test metrics.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
# map_location lets the snapshot load even if it was written from a different
# device than the one this kernel is using.
M3.load_state_dict(torch.load('model_own_path.pt', map_location=device))
evaluate2('/home/asufian/Desktop/output_olchiki/resnet_resluts/base/5_w_10_s_f', test_loader, M3, criterion)

# Reload the secondary-model snapshot as M2 and report its test metrics.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt', map_location=device))
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/secondary/5_w_10_s_f", test_loader, M2, criterion)

# Update M3's weights with a weighted combination of M3 and M2, using the
# ratios found by the Bayesian optimisation.
update_model_weights2(M3, M2, conv_ratio=optimal_c_a_r, fc_ratio=optimal_f_a_r, bias_ratio=optimal_b_a_r)

# Evaluate the blended model M3.
# NOTE(review): the base/secondary reports use ".../5_w_10_s_f", but this one
# uses ".../5_w_5_10_f", possibly a typo for 5_w_10_s_f. Confirm before
# relying on the output location.
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/amul/5_w_5_10_f", test_loader, M3, criterion)



---Accu-----pre----rec---------> 0.8098 ± 0.0727  0.8214 ± 0.0711  0.8098 ± 0.0727




---Accu-----pre----rec---------> 0.6790 ± 0.0583  0.6879 ± 0.0631  0.6790 ± 0.0583
---Accu-----pre----rec---------> 0.9364 ± 0.0456  0.9386 ± 0.0436  0.9364 ± 0.0456


In [50]:
# List the files in output_olchiki/lol (relies on IPython's automagic `ls`).
ls /home/asufian/Desktop/output_olchiki/lol

5_w_10_s.txt  5_w_10_s.xlsx


In [76]:
# Fixed evaluation episodes for the 8-way experiments: clss_8[e] lists the
# 8 class names of episode e (12 episodes).
clss_8 = [
    ['Olchiki_2_80_c10', 'Olchiki_2_80_c25', 'Olchiki_2_80_c27', 'Olchiki_2_80_c13', 'Olchiki_2_80_c8', 'Olchiki_2_80_c11', 'Olchiki_2_80_c15', 'Olchiki_2_80_c5'],
    ['Olchiki_2_80_c1', 'Olchiki_2_80_c27', 'Olchiki_2_80_c8', 'Olchiki_2_80_c4', 'Olchiki_2_80_c18', 'Olchiki_2_80_c2', 'Olchiki_2_80_c29', 'Olchiki_2_80_c26'],
    ['Olchiki_2_80_c23', 'Olchiki_2_80_c11', 'Olchiki_2_80_c25', 'Olchiki_2_80_c29', 'Olchiki_2_80_c21', 'Olchiki_2_80_c27', 'Olchiki_2_80_c3', 'Olchiki_2_80_c7'],
    ['Olchiki_2_80_c23', 'Olchiki_2_80_c4', 'Olchiki_2_80_c26', 'Olchiki_2_80_c2', 'Olchiki_2_80_c17', 'Olchiki_2_80_c30', 'Olchiki_2_80_c18', 'Olchiki_2_80_c9'],
    ['Olchiki_2_80_c6', 'Olchiki_2_80_c12', 'Olchiki_2_80_c24', 'Olchiki_2_80_c18', 'Olchiki_2_80_c22', 'Olchiki_2_80_c10', 'Olchiki_2_80_c26', 'Olchiki_2_80_c16'],
    ['Olchiki_2_80_c6', 'Olchiki_2_80_c26', 'Olchiki_2_80_c8', 'Olchiki_2_80_c12', 'Olchiki_2_80_c14', 'Olchiki_2_80_c16', 'Olchiki_2_80_c25', 'Olchiki_2_80_c4'],
    ['Olchiki_2_80_c24', 'Olchiki_2_80_c26', 'Olchiki_2_80_c21', 'Olchiki_2_80_c7', 'Olchiki_2_80_c12', 'Olchiki_2_80_c25', 'Olchiki_2_80_c2', 'Olchiki_2_80_c5'],
    ['Olchiki_2_80_c5', 'Olchiki_2_80_c18', 'Olchiki_2_80_c30', 'Olchiki_2_80_c7', 'Olchiki_2_80_c9', 'Olchiki_2_80_c16', 'Olchiki_2_80_c6', 'Olchiki_2_80_c24'],
    ['Olchiki_2_80_c27', 'Olchiki_2_80_c5', 'Olchiki_2_80_c3', 'Olchiki_2_80_c12', 'Olchiki_2_80_c30', 'Olchiki_2_80_c6', 'Olchiki_2_80_c18', 'Olchiki_2_80_c10'],
    ['Olchiki_2_80_c26', 'Olchiki_2_80_c20', 'Olchiki_2_80_c1', 'Olchiki_2_80_c6', 'Olchiki_2_80_c13', 'Olchiki_2_80_c23', 'Olchiki_2_80_c30', 'Olchiki_2_80_c5'],
    ['Olchiki_2_80_c12', 'Olchiki_2_80_c8', 'Olchiki_2_80_c11', 'Olchiki_2_80_c14', 'Olchiki_2_80_c25', 'Olchiki_2_80_c4', 'Olchiki_2_80_c2', 'Olchiki_2_80_c29'],
    ['Olchiki_2_80_c17', 'Olchiki_2_80_c14', 'Olchiki_2_80_c20', 'Olchiki_2_80_c29', 'Olchiki_2_80_c8', 'Olchiki_2_80_c11', 'Olchiki_2_80_c2', 'Olchiki_2_80_c28'],
]

# clss_support_imagesss_10_shot[e][c] holds the 10 support-image indices for
# class position c of episode e (presumably aligned with clss_8[e][c]; the
# evaluate functions pair the two lists episode by episode).
clss_support_imagesss_10_shot = [
    [[6, 16, 45, 19, 27, 44, 23, 60, 42, 13], [60, 32, 27, 28, 65, 66, 44, 3, 21, 34], [8, 74, 70, 6, 14, 23, 69, 75, 39, 21], [42, 58, 71, 17, 35, 1, 75, 40, 56, 44], [10, 51, 55, 15, 75, 48, 6, 56, 63, 19], [11, 47, 24, 10, 15, 37, 63, 48, 33, 22], [72, 58, 3, 32, 41, 51, 15, 46, 0, 52], [61, 6, 76, 62, 45, 51, 55, 72, 53, 40]],
    [[3, 23, 33, 18, 2, 25, 50, 6, 36, 68], [55, 79, 3, 0, 15, 53, 4, 46, 2, 44], [11, 9, 62, 15, 36, 37, 24, 38, 13, 49], [34, 26, 31, 70, 44, 17, 5, 12, 67, 15], [23, 72, 65, 58, 5, 42, 11, 22, 15, 69], [40, 65, 22, 56, 7, 69, 66, 57, 54, 48], [13, 40, 20, 75, 71, 45, 12, 32, 56, 79], [22, 55, 51, 69, 64, 18, 60, 31, 30, 11]],
    [[62, 36, 4, 1, 3, 32, 31, 33, 59, 7], [25, 62, 6, 45, 53, 7, 71, 72, 61, 23], [21, 68, 67, 64, 36, 13, 33, 2, 25, 24], [59, 76, 28, 11, 10, 39, 51, 7, 22, 25], [11, 29, 59, 64, 40, 63, 27, 37, 22, 9], [42, 51, 50, 26, 56, 79, 8, 7, 15, 0], [28, 43, 56, 19, 48, 20, 50, 6, 73, 10], [74, 70, 67, 65, 72, 47, 62, 26, 49, 10]],
    [[22, 58, 20, 1, 55, 40, 6, 18, 62, 33], [9, 28, 77, 31, 4, 37, 17, 29, 64, 40], [30, 7, 9, 20, 74, 79, 43, 2, 24, 55], [18, 13, 20, 19, 5, 4, 0, 39, 25, 34], [11, 28, 38, 72, 9, 47, 6, 62, 43, 64], [23, 26, 65, 54, 24, 27, 15, 48, 41, 28], [20, 55, 53, 3, 24, 8, 13, 50, 68, 35], [32, 33, 24, 51, 79, 39, 57, 37, 42, 60]],
    [[38, 30, 53, 3, 0, 68, 18, 47, 28, 32], [55, 22, 0, 37, 75, 7, 64, 12, 34, 62], [6, 25, 44, 67, 77, 8, 70, 76, 3, 53], [70, 23, 59, 30, 37, 39, 12, 34, 53, 56], [77, 25, 60, 33, 43, 52, 75, 26, 36, 61], [29, 76, 1, 56, 55, 0, 3, 19, 18, 40], [40, 41, 45, 19, 6, 27, 18, 15, 43, 56], [72, 52, 24, 42, 66, 60, 79, 10, 11, 73]],
    [[37, 66, 29, 79, 19, 72, 63, 49, 20, 4], [51, 47, 75, 16, 17, 61, 3, 78, 77, 24], [1, 27, 34, 16, 30, 48, 75, 12, 10, 33], [24, 66, 65, 71, 8, 12, 56, 58, 16, 30], [26, 23, 74, 35, 40, 8, 38, 37, 42, 66], [30, 21, 33, 0, 28, 20, 19, 79, 12, 45], [22, 5, 32, 77, 40, 67, 46, 7, 42, 56], [14, 70, 63, 6, 23, 5, 30, 72, 26, 41]],
    [[29, 43, 70, 46, 11, 58, 37, 48, 21, 23], [18, 42, 64, 14, 38, 6, 16, 21, 29, 67], [15, 72, 56, 69, 78, 39, 29, 49, 27, 36], [67, 54, 16, 26, 58, 17, 68, 42, 19, 76], [60, 75, 74, 7, 2, 47, 34, 25, 44, 63], [28, 62, 59, 68, 65, 16, 72, 48, 42, 10], [44, 55, 58, 42, 78, 24, 9, 32, 16, 73], [17, 34, 19, 32, 69, 75, 46, 45, 39, 18]],
    [[58, 54, 29, 40, 11, 43, 30, 44, 70, 3], [75, 21, 6, 36, 48, 70, 64, 30, 32, 59], [1, 70, 41, 69, 67, 19, 16, 77, 45, 53], [19, 5, 11, 16, 34, 14, 70, 71, 69, 33], [3, 14, 75, 28, 1, 5, 15, 72, 40, 60], [16, 22, 38, 77, 39, 7, 79, 9, 58, 46], [53, 36, 62, 75, 37, 45, 76, 74, 58, 70], [74, 33, 62, 8, 73, 75, 38, 0, 20, 12]],
    [[79, 67, 1, 32, 11, 7, 59, 26, 50, 57], [11, 24, 16, 74, 39, 78, 41, 6, 36, 27], [27, 41, 58, 71, 39, 42, 25, 13, 24, 73], [47, 67, 42, 35, 3, 24, 53, 43, 39, 16], [15, 14, 2, 23, 53, 18, 25, 48, 4, 66], [66, 36, 68, 14, 63, 13, 48, 15, 56, 20], [13, 72, 12, 4, 76, 29, 45, 53, 59, 70], [73, 52, 24, 26, 42, 5, 0, 63, 2, 45]],
    [[42, 35, 43, 7, 20, 30, 17, 2, 76, 45], [6, 24, 17, 30, 68, 12, 78, 7, 50, 55], [12, 30, 1, 74, 69, 16, 15, 17, 34, 9], [64, 37, 66, 77, 45, 49, 78, 75, 31, 57], [21, 49, 36, 12, 52, 10, 60, 50, 31, 77], [28, 7, 72, 49, 1, 67, 10, 0, 38, 32], [69, 59, 58, 22, 15, 14, 60, 7, 21, 35], [45, 1, 62, 46, 35, 12, 66, 17, 64, 27]],
    [[14, 68, 70, 55, 67, 58, 79, 62, 2, 31], [9, 1, 10, 26, 25, 49, 35, 6, 0, 33], [56, 62, 75, 31, 25, 71, 40, 12, 39, 6], [72, 75, 40, 34, 8, 63, 79, 20, 73, 12], [36, 70, 50, 37, 69, 66, 54, 75, 8, 55], [18, 27, 8, 65, 33, 57, 28, 56, 0, 60], [51, 41, 4, 10, 57, 21, 11, 3, 44, 79], [13, 52, 5, 71, 9, 68, 28, 46, 63, 12]],
    [[25, 79, 77, 21, 73, 38, 52, 57, 75, 31], [61, 76, 63, 53, 27, 49, 0, 38, 21, 31], [77, 53, 4, 0, 14, 29, 34, 30, 39, 3], [36, 50, 8, 30, 21, 55, 54, 71, 4, 12], [16, 4, 44, 36, 66, 30, 79, 58, 7, 24], [45, 42, 56, 59, 57, 22, 79, 69, 36, 70], [27, 77, 31, 32, 61, 55, 9, 10, 44, 50], [71, 25, 45, 11, 16, 72, 79, 47, 32, 63]],
]

# The 5-shot and 1-shot support sets are the leading 5 / 1 indices of each
# 10-shot list. Deriving them here, instead of keeping hand-copied literals,
# guarantees that the three shot settings stay nested and cannot drift apart.
clss_support_imagesss_5_shot = [
    [indices[:5] for indices in episode] for episode in clss_support_imagesss_10_shot
]
clss_support_imagesss_1_shot = [
    [indices[:1] for indices in episode] for episode in clss_support_imagesss_10_shot
]

# Sanity checks: one support list per episode, and one per class in each episode.
assert len(clss_8) == len(clss_support_imagesss_10_shot)
assert all(
    len(class_names) == len(supports)
    for class_names, supports in zip(clss_8, clss_support_imagesss_10_shot)
)



In [77]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the two prototypical networks used by the weight-blending experiment.
# Their state_dicts are snapshotted to model_own_path.pt / model_mu_path.pt so
# later cells can reload fresh copies.
# torch.save overwrites an existing file, so no pre-delete of the snapshots is
# needed. A bare `rm` would also error whenever a snapshot does not exist yet.

# Base ("own") model.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
# Replace the .fc head with Flatten so the backbone emits feature vectors
# (assumes ResNet18WithDropout exposes a torchvision-style .fc — confirm).
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth', map_location=device))

# Secondary model (checkpoint model_B_olchiki_1-shot_res.pth).
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/olchiki/ressecondary/model_1/model_B_olchiki_1-shot_res.pth', map_location=device))

torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [78]:


def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              class_sets=None, support_sets=None):
    """Evaluate a prototypical network on a fixed list of few-shot episodes.

    For each episode ``e`` the loader builds the support/query split from the
    class names ``class_sets[e]`` and the support indices ``support_sets[e]``.
    Query images are scored by ``model`` (one score tensor per query class).
    Predictions are the argmax over prototypes, and from them the per-episode
    accuracy and confusion matrix are computed. These are passed to
    ``calculate_metrics_get_per``, which receives ``fname`` as ``flnm`` and
    returns ``(precision, recall, f1)``.

    After all episodes, the mean ± sample std of accuracy, precision and
    recall is printed.

    Args:
        fname: output path prefix forwarded to ``calculate_metrics_get_per``.
        data_loader: object whose ``__getitem__(class_names, support_indices)``
            returns ``(support_set, query_set)``, each iterated as
            ``(images, labels)`` pairs.
        model: callable as ``model(support_images, support_labels, query_images)``.
        criterion: unused; kept so existing call sites keep working.
        class_sets: per-episode class-name lists. Defaults to the global
            ``clss_8`` (looked up at call time).
        support_sets: per-episode support indices aligned with ``class_sets``.
            Defaults to the global ``clss_support_imagesss_1_shot``.
    """
    if class_sets is None:
        class_sets = clss_8
    if support_sets is None:
        support_sets = clss_support_imagesss_1_shot

    def _mean_std(values):
        # Sample std (ddof=1) is undefined for fewer than two episodes; report
        # 0.0 there instead of printing nan with a RuntimeWarning.
        spread = np.std(values, ddof=1) if len(values) > 1 else 0.0
        return np.mean(values), spread

    model.eval()
    accuracies = []
    precision = []
    recall = []

    with torch.no_grad():
        for episode, class_names in enumerate(class_sets):
            support_indices = support_sets[episode]
            support_set, query_set = data_loader.__getitem__(class_names, support_indices)

            # Human-readable tags handed to the metrics writer.
            clss = str(class_names).replace(',', ' ').replace('\'', ' ')
            msg = str(support_indices)

            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            all_predicted_labels = []
            all_actual_labels = []
            for scores, actual in zip(classification_scores, query_labels):
                all_predicted_labels.extend(torch.argmax(scores, dim=1).tolist())
                all_actual_labels.extend(actual)

            total = len(all_predicted_labels)
            correct = sum(int(p == a) for p, a in zip(all_predicted_labels, all_actual_labels))
            accuracy = correct / total
            accuracies.append(accuracy)

            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            precision_overall, recall_overall, _ = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=clss, flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)

    acc_mean, acc_std = _mean_std(accuracies)
    pre_mean, pre_std = _mean_std(precision)
    rec_mean, rec_std = _mean_std(recall)
    print(f"---Accu-----pre----rec---------> {acc_mean:.4f} ± {acc_std:.4f}  "
          f"{pre_mean:.4f} ± {pre_std:.4f}  "
          f"{rec_mean:.4f} ± {rec_std:.4f}")
# %%%%%%

def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              class_sets=None, support_sets=None):
    """Return the mean per-episode query accuracy over a fixed list of episodes.

    This uses the same episode loop as ``evaluate2``, but only accuracy is
    computed. No confusion matrix is built and no metrics file is written.
    The mean accuracy is printed and returned.

    Args:
        fname: unused; kept so existing call sites keep working.
        data_loader: object whose ``__getitem__(class_names, support_indices)``
            returns ``(support_set, query_set)``, each iterated as
            ``(images, labels)`` pairs.
        model: callable as ``model(support_images, support_labels, query_images)``.
        criterion: unused; kept for signature compatibility.
        class_sets: per-episode class-name lists. Defaults to the global
            ``clss_8`` (looked up at call time).
        support_sets: per-episode support indices aligned with ``class_sets``.
            Defaults to the global ``clss_support_imagesss_1_shot``.

    Returns:
        The arithmetic mean of the per-episode accuracies.

    Raises:
        ValueError: if there are no episodes to evaluate.
    """
    if class_sets is None:
        class_sets = clss_8
    if support_sets is None:
        support_sets = clss_support_imagesss_1_shot

    model.eval()
    accuracies = []
    with torch.no_grad():
        for episode, class_names in enumerate(class_sets):
            support_set, query_set = data_loader.__getitem__(class_names, support_sets[episode])

            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            for scores, actual in zip(classification_scores, query_labels):
                predicted = torch.argmax(scores, dim=1).tolist()
                correct += sum(int(p == a) for p, a in zip(predicted, actual))
                total += len(predicted)
            accuracies.append(correct / total)

    if not accuracies:
        raise ValueError("evaluate3: no episodes to evaluate (class_sets is empty)")
    # Sequential sum / count, matching the original running-total arithmetic.
    mean_accuracy = sum(accuracies) / len(accuracies)
    print("---------------------->", mean_accuracy)
    return mean_accuracy

In [79]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Search space: one ratio in [0, 1] per parameter group blended by
# update_model_weights2 (b = bias, c = convolution, f = fully-connected).
# The names must match objective()'s keyword arguments for use_named_args.
search_space = [Real(0.0, 1.0, name=ratio_name) for ratio_name in ("b_a_r", "c_a_r", "f_a_r")]

# Objective for Bayesian optimisation: gp_minimize minimises, so the
# function returns the negated accuracy obtained with the proposed ratios.
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Score one candidate set of adaptation ratios.

    use_named_args maps the point proposed by gp_minimize onto the keyword
    arguments b_a_r / c_a_r / f_a_r (the names declared in search_space).
    Returns -accuracy so that minimising the objective maximises accuracy.
    """
    score = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {score:.4f}")  # one progress line per evaluation
    loss = -score
    return loss

# Perform Bayesian Optimization with additional settings
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Number of evaluations
    n_initial_points=5,       # Initial random explorations before GP starts
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Ensure reproducibility
    n_jobs=-1,                # Parallel execution for faster optimization
)

# Extract optimal results
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.636867088607595
Evaluated Accuracy: 0.6369
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.5644778481012659
Evaluated Accuracy: 0.5645
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.6181434599156118
Evaluated Accuracy: 0.6181
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.6971255274261603
Evaluated Accuracy: 0.6971
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.6824894514767932
Evaluated Accuracy: 0.6825
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 1.0000




----------------------> 0.6760284810126581
Evaluated Accuracy: 0.6760
Testing b_a_r = 0.2213, c_a_r = 1.0000, f_a_r = 0.9428




----------------------> 0.3154008438818565
Evaluated Accuracy: 0.3154
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.5961234177215191
Evaluated Accuracy: 0.5961
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.6156




----------------------> 0.6627109704641351
Evaluated Accuracy: 0.6627
Testing b_a_r = 1.0000, c_a_r = 0.0692, f_a_r = 1.0000




----------------------> 0.7038502109704642
Evaluated Accuracy: 0.7039
Testing b_a_r = 0.9028, c_a_r = 0.6624, f_a_r = 1.0000




----------------------> 0.37223101265822794
Evaluated Accuracy: 0.3722
Testing b_a_r = 1.0000, c_a_r = 0.1644, f_a_r = 0.0000




----------------------> 0.6334388185654007
Evaluated Accuracy: 0.6334
Testing b_a_r = 0.9509, c_a_r = 0.3766, f_a_r = 1.0000




----------------------> 0.548918776371308
Evaluated Accuracy: 0.5489
Testing b_a_r = 0.8807, c_a_r = 0.7938, f_a_r = 0.0007




----------------------> 0.3326740506329114
Evaluated Accuracy: 0.3327
Testing b_a_r = 0.9930, c_a_r = 0.0964, f_a_r = 0.7615




----------------------> 0.7194092827004219
Evaluated Accuracy: 0.7194
Testing b_a_r = 0.0881, c_a_r = 0.0878, f_a_r = 1.0000




----------------------> 0.7022679324894515
Evaluated Accuracy: 0.7023
Testing b_a_r = 0.1821, c_a_r = 0.0913, f_a_r = 0.5732




----------------------> 0.7050369198312237
Evaluated Accuracy: 0.7050
Testing b_a_r = 0.0400, c_a_r = 0.5894, f_a_r = 0.0016




----------------------> 0.40664556962025317
Evaluated Accuracy: 0.4066
Testing b_a_r = 0.9171, c_a_r = 0.1239, f_a_r = 0.6504




----------------------> 0.7137394514767932
Evaluated Accuracy: 0.7137
Testing b_a_r = 0.9619, c_a_r = 0.1195, f_a_r = 0.9545




----------------------> 0.7097837552742616
Evaluated Accuracy: 0.7098
Testing b_a_r = 0.1065, c_a_r = 0.5088, f_a_r = 0.9983




----------------------> 0.4808808016877637
Evaluated Accuracy: 0.4809
Testing b_a_r = 0.0638, c_a_r = 0.9980, f_a_r = 0.0311




----------------------> 0.31526898734177217
Evaluated Accuracy: 0.3153
Testing b_a_r = 0.8513, c_a_r = 0.0929, f_a_r = 0.0103




----------------------> 0.6420094936708861
Evaluated Accuracy: 0.6420
Testing b_a_r = 0.9109, c_a_r = 0.2526, f_a_r = 0.0106




----------------------> 0.5849156118143459
Evaluated Accuracy: 0.5849
Testing b_a_r = 0.5014, c_a_r = 0.8547, f_a_r = 0.9903




----------------------> 0.3321466244725738
Evaluated Accuracy: 0.3321
Testing b_a_r = 0.2121, c_a_r = 0.2839, f_a_r = 0.9929




----------------------> 0.5415348101265823
Evaluated Accuracy: 0.5415
Testing b_a_r = 0.0060, c_a_r = 0.1108, f_a_r = 0.7467




----------------------> 0.7137394514767933
Evaluated Accuracy: 0.7137
Testing b_a_r = 0.9226, c_a_r = 0.3872, f_a_r = 0.0082




----------------------> 0.620253164556962
Evaluated Accuracy: 0.6203
Testing b_a_r = 0.0003, c_a_r = 0.3938, f_a_r = 0.4634




----------------------> 0.6075949367088608
Evaluated Accuracy: 0.6076
Testing b_a_r = 0.1242, c_a_r = 0.1349, f_a_r = 0.3544




----------------------> 0.6813027426160337
Evaluated Accuracy: 0.6813
Testing b_a_r = 0.7848, c_a_r = 0.9032, f_a_r = 0.0024




----------------------> 0.31500527426160335
Evaluated Accuracy: 0.3150
Testing b_a_r = 0.7619, c_a_r = 0.0523, f_a_r = 0.3563




----------------------> 0.6582278481012657
Evaluated Accuracy: 0.6582
Testing b_a_r = 0.9598, c_a_r = 0.0310, f_a_r = 0.9814




----------------------> 0.702795358649789
Evaluated Accuracy: 0.7028
Testing b_a_r = 0.8730, c_a_r = 0.1489, f_a_r = 0.9986




----------------------> 0.6794567510548523
Evaluated Accuracy: 0.6795
Testing b_a_r = 0.9070, c_a_r = 0.6878, f_a_r = 0.0206




----------------------> 0.3640559071729958
Evaluated Accuracy: 0.3641
Testing b_a_r = 0.9918, c_a_r = 0.1033, f_a_r = 0.5175




----------------------> 0.7132120253164557
Evaluated Accuracy: 0.7132
Testing b_a_r = 0.7415, c_a_r = 0.4441, f_a_r = 0.9936




----------------------> 0.5305907172995782
Evaluated Accuracy: 0.5306
Testing b_a_r = 0.8994, c_a_r = 0.7592, f_a_r = 0.9946




----------------------> 0.3544303797468354
Evaluated Accuracy: 0.3544
Testing b_a_r = 0.9660, c_a_r = 0.2879, f_a_r = 0.4943




----------------------> 0.5746308016877638
Evaluated Accuracy: 0.5746
Testing b_a_r = 0.1730, c_a_r = 0.5112, f_a_r = 0.0043




----------------------> 0.4856276371308017
Evaluated Accuracy: 0.4856
Testing b_a_r = 0.0060, c_a_r = 0.2083, f_a_r = 0.3047




----------------------> 0.6257911392405063
Evaluated Accuracy: 0.6258
Testing b_a_r = 0.0761, c_a_r = 0.1408, f_a_r = 0.7627




----------------------> 0.6876318565400844
Evaluated Accuracy: 0.6876
Testing b_a_r = 0.9567, c_a_r = 0.1018, f_a_r = 0.7030




----------------------> 0.7231012658227848
Evaluated Accuracy: 0.7231
Testing b_a_r = 0.9865, c_a_r = 0.1007, f_a_r = 0.7599




----------------------> 0.721123417721519
Evaluated Accuracy: 0.7211
Testing b_a_r = 0.9773, c_a_r = 0.0982, f_a_r = 0.8923




----------------------> 0.7180907172995781
Evaluated Accuracy: 0.7181
Testing b_a_r = 0.7774, c_a_r = 0.2251, f_a_r = 0.9942




----------------------> 0.5814873417721519
Evaluated Accuracy: 0.5815
Testing b_a_r = 0.9926, c_a_r = 0.0838, f_a_r = 0.7740




----------------------> 0.7140031645569621
Evaluated Accuracy: 0.7140
Testing b_a_r = 0.7576, c_a_r = 0.5771, f_a_r = 0.9972




----------------------> 0.4030854430379747
Evaluated Accuracy: 0.4031
Testing b_a_r = 0.0037, c_a_r = 0.0254, f_a_r = 0.8168




----------------------> 0.70042194092827
Evaluated Accuracy: 0.7004
Testing b_a_r = 0.9941, c_a_r = 0.1123, f_a_r = 0.6539




----------------------> 0.7213871308016877
Evaluated Accuracy: 0.7214

✅ Optimal Values:
   - b_a_r: 0.9567
   - c_a_r: 0.1018
   - f_a_r: 0.7030
📈 Highest Accuracy Achieved: 0.7231


In [80]:
# Evaluate the base model (M3) and the secondary model (M2), then blend M2 into M3
# with the ratios selected by the search above and evaluate the blended model.
# map_location=device keeps loading working even if the checkpoints were written
# from a different device than the one active now.

# Base model M3 (classification head replaced by Flatten so the backbone returns features).
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt', map_location=device))
evaluate2('/home/asufian/Desktop/output_olchiki/resnet_resluts/base/8_w_1_s_f', test_loader, M3, criterion)

# Secondary model M2.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt', map_location=device))
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/secondary/8_w_1_s_f", test_loader, M2, criterion)

# Update M3's weights using a weighted combination of M3 and M2 (modifies M3).
update_model_weights2(M3, M2, conv_ratio=optimal_c_a_r, fc_ratio=optimal_f_a_r, bias_ratio=optimal_b_a_r)

# NOTE(review): optimal_* were picked to maximize accuracy in the search cell; if that
# search scored these same test episodes, this number is optimistically biased. TODO confirm.
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/amul/8_w_1_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.5993 ± 0.0565  0.6236 ± 0.0574  0.5993 ± 0.0565




---Accu-----pre----rec---------> 0.4285 ± 0.0644  0.4663 ± 0.0691  0.4285 ± 0.0644
---Accu-----pre----rec---------> 0.7231 ± 0.0661  0.7440 ± 0.0758  0.7231 ± 0.0661


In [81]:
from pathlib import Path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Remove the checkpoints written for the previous configuration first, so that if a
# load below fails, later cells cannot silently reuse stale files. Unlike `!rm`,
# unlink(missing_ok=True) is a silent no-op when the file does not exist yet.
Path('model_mu_path.pt').unlink(missing_ok=True)
Path('model_own_path.pt').unlink(missing_ok=True)

# Base ("own") model.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth', map_location=torch.device('cpu')))

# Secondary model. NOTE: despite its name, model_own_1_shot holds the 5-shot
# secondary checkpoint; the name is kept in case other cells refer to it.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/olchiki/ressecondary/model_5/model_B_olchiki_5-shot_res.pth', map_location=torch.device('cpu')))

# Working copies read back by the evaluation cell.
torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [82]:
def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss(), support_sets=None):
    """Evaluate ``model`` on every ``clss_8`` episode and print accuracy/precision/recall.

    For episode ``i`` the support and query sets come from
    ``data_loader.__getitem__(clss_8[i], support_sets[i])``. Each episode's query
    predictions (argmax over the model's per-group scores) are compared with the
    true labels, a confusion matrix is built, and ``calculate_metrics_get_per``
    (defined elsewhere in the notebook) is called with ``flnm=fname``; its first two
    return values are collected as the episode's precision and recall.

    Args:
        fname: output path forwarded to ``calculate_metrics_get_per`` as ``flnm``.
        data_loader: episode source whose ``__getitem__(classes, support_spec)``
            returns ``(support_set, query_set)``.
        model: called as ``model(support_images, support_labels, query_images)``;
            must return one score tensor per query group.
        criterion: unused; kept so existing calls keep working.
        support_sets: per-episode support-image specs aligned with ``clss_8``.
            Defaults to the global ``clss_support_imagesss_5_shot`` so the same
            definition can be reused for other shot settings.

    Prints mean ± sample standard deviation (ddof=1) over episodes; returns None.
    """
    if support_sets is None:
        support_sets = clss_support_imagesss_5_shot

    model.eval()
    episode_accuracies = []
    precision = []
    recall = []

    with torch.no_grad():
        for episode_idx in range(len(clss_8)):
            episode_classes = clss_8[episode_idx]
            episode_support = support_sets[episode_idx]
            support_set, query_set = data_loader.__getitem__(episode_classes, episode_support)

            # Human-readable descriptors forwarded to calculate_metrics_get_per.
            clss = str(episode_classes).replace(',', ' ').replace('\'', ' ')
            msg = str(episode_support)

            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            # One score tensor per query group (negated prototype distances).
            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            all_predicted_labels = []
            all_actual_labels = []
            for group_scores, actual in zip(classification_scores, query_labels):
                predicted = torch.argmax(group_scores, dim=1).tolist()
                all_predicted_labels.extend(predicted)
                all_actual_labels.extend(actual)
                correct += sum(1 for p, a in zip(predicted, actual) if p == a)
                total += len(predicted)

            accuracy = correct / total
            episode_accuracies.append(accuracy)

            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            precision_overall, recall_overall, _f1 = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=clss, flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)

    print(f"---Accu-----pre----rec---------> {np.mean(episode_accuracies):.4f} ± {np.std(episode_accuracies, ddof=1):.4f}  "
          f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
          f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")
# %%%%%%

def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss(), support_sets=None):
    """Return the mean per-episode query accuracy of ``model`` over the ``clss_8`` episodes.

    For episode ``i`` the support and query sets come from
    ``data_loader.__getitem__(clss_8[i], support_sets[i])``; the episode accuracy is
    the fraction of query images whose argmax score matches the true label. The
    mean over episodes is printed after a ``---------------------->`` marker and returned.

    Args:
        fname: unused; kept so existing calls keep working.
        data_loader: episode source whose ``__getitem__(classes, support_spec)``
            returns ``(support_set, query_set)``.
        model: called as ``model(support_images, support_labels, query_images)``;
            must return one score tensor per query group.
        criterion: unused; kept so existing calls keep working.
        support_sets: per-episode support-image specs aligned with ``clss_8``.
            Defaults to the global ``clss_support_imagesss_5_shot``.

    Returns:
        Mean of the per-episode accuracies (float).
    """
    if support_sets is None:
        support_sets = clss_support_imagesss_5_shot

    model.eval()
    episode_accuracies = []

    with torch.no_grad():
        for episode_idx in range(len(clss_8)):
            support_set, query_set = data_loader.__getitem__(clss_8[episode_idx], support_sets[episode_idx])

            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            # One score tensor per query group (negated prototype distances).
            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            for group_scores, actual in zip(classification_scores, query_labels):
                predicted = torch.argmax(group_scores, dim=1).tolist()
                correct += sum(1 for p, a in zip(predicted, actual) if p == a)
                total += len(predicted)
            episode_accuracies.append(correct / total)

    mean_accuracy = sum(episode_accuracies) / len(episode_accuracies)
    print("---------------------->", mean_accuracy)
    return mean_accuracy

In [83]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Search space for the three adaptation ratios. They are consumed later as
# update_model_weights2(..., conv_ratio=c_a_r, fc_ratio=f_a_r, bias_ratio=b_a_r);
# each ratio is searched over [0, 1].
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r"),
]


@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """Return the negated accuracy obtained with one set of adaptation ratios.

    Delegates the actual work to ``get_model_accuracy`` (defined elsewhere in the
    notebook). gp_minimize minimizes its objective, hence the negation.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # per-call progress log
    return -accuracy


# NOTE(review): if get_model_accuracy scores the same test episodes that evaluate2
# later reports on, the selected ratios (and the final blended accuracy) are tuned
# on test data — consider a held-out validation split. TODO confirm.
search_iteration = 50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,  # total objective evaluations, initial random ones included
    n_initial_points=5,        # random points evaluated before the GP surrogate is used
    acq_func="EI",             # acquisition function: Expected Improvement
    random_state=42,           # seeds the optimizer only; a repeatable run also needs a deterministic objective
    n_jobs=-1,                 # parallelizes only the L-BFGS acquisition optimization; objective calls stay sequential
)

# res.x / res.fun are the best point actually evaluated and its (negated) accuracy.
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print("\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.8236111111111111
Evaluated Accuracy: 0.8236
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.6056944444444444
Evaluated Accuracy: 0.6057
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.7254166666666667
Evaluated Accuracy: 0.7254
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.7895833333333333
Evaluated Accuracy: 0.7896
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.7531944444444445
Evaluated Accuracy: 0.7532
Testing b_a_r = 0.0250, c_a_r = 1.0000, f_a_r = 0.4630




----------------------> 0.3630555555555555
Evaluated Accuracy: 0.3631
Testing b_a_r = 0.0000, c_a_r = 0.1539, f_a_r = 0.0000




----------------------> 0.7712500000000001
Evaluated Accuracy: 0.7713
Testing b_a_r = 0.0000, c_a_r = 0.1925, f_a_r = 1.0000




----------------------> 0.8240277777777778
Evaluated Accuracy: 0.8240
Testing b_a_r = 1.0000, c_a_r = 0.2729, f_a_r = 1.0000




----------------------> 0.7247222222222223
Evaluated Accuracy: 0.7247
Testing b_a_r = 0.9542, c_a_r = 0.1278, f_a_r = 0.0518




----------------------> 0.7725
Evaluated Accuracy: 0.7725
Testing b_a_r = 1.0000, c_a_r = 0.1361, f_a_r = 1.0000




----------------------> 0.82375
Evaluated Accuracy: 0.8237
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.711388888888889
Evaluated Accuracy: 0.7114
Testing b_a_r = 0.0000, c_a_r = 0.1487, f_a_r = 0.8523




----------------------> 0.8405555555555556
Evaluated Accuracy: 0.8406
Testing b_a_r = 0.1384, c_a_r = 0.7073, f_a_r = 0.9990




----------------------> 0.3973611111111111
Evaluated Accuracy: 0.3974
Testing b_a_r = 0.0008, c_a_r = 0.1792, f_a_r = 0.5025




----------------------> 0.8229166666666669
Evaluated Accuracy: 0.8229
Testing b_a_r = 0.0883, c_a_r = 0.0801, f_a_r = 0.9982




----------------------> 0.8054166666666666
Evaluated Accuracy: 0.8054
Testing b_a_r = 0.3956, c_a_r = 0.7753, f_a_r = 0.0057




----------------------> 0.4630555555555555
Evaluated Accuracy: 0.4631
Testing b_a_r = 0.0181, c_a_r = 0.1360, f_a_r = 0.6417




----------------------> 0.8363888888888891
Evaluated Accuracy: 0.8364
Testing b_a_r = 0.9885, c_a_r = 0.2554, f_a_r = 0.0311




----------------------> 0.7526388888888887
Evaluated Accuracy: 0.7526
Testing b_a_r = 0.3692, c_a_r = 0.4531, f_a_r = 0.9961




----------------------> 0.5629166666666667
Evaluated Accuracy: 0.5629
Testing b_a_r = 0.3619, c_a_r = 0.6192, f_a_r = 0.0008




----------------------> 0.5031944444444445
Evaluated Accuracy: 0.5032
Testing b_a_r = 0.0612, c_a_r = 0.2461, f_a_r = 0.5163




----------------------> 0.7790277777777778
Evaluated Accuracy: 0.7790
Testing b_a_r = 0.0032, c_a_r = 0.0957, f_a_r = 0.3067




----------------------> 0.7719444444444444
Evaluated Accuracy: 0.7719
Testing b_a_r = 0.0644, c_a_r = 0.1457, f_a_r = 0.9996




----------------------> 0.8366666666666668
Evaluated Accuracy: 0.8367
Testing b_a_r = 0.0067, c_a_r = 0.1555, f_a_r = 0.7886




----------------------> 0.8418055555555557
Evaluated Accuracy: 0.8418
Testing b_a_r = 0.9693, c_a_r = 0.1333, f_a_r = 0.6920




----------------------> 0.8304166666666667
Evaluated Accuracy: 0.8304
Testing b_a_r = 0.9189, c_a_r = 0.9925, f_a_r = 0.9835




----------------------> 0.3543055555555556
Evaluated Accuracy: 0.3543
Testing b_a_r = 0.1363, c_a_r = 0.0005, f_a_r = 0.6201




----------------------> 0.7556944444444444
Evaluated Accuracy: 0.7557
Testing b_a_r = 0.0217, c_a_r = 0.1213, f_a_r = 0.8240




----------------------> 0.8352777777777778
Evaluated Accuracy: 0.8353
Testing b_a_r = 0.9645, c_a_r = 0.1857, f_a_r = 0.2300




----------------------> 0.7841666666666667
Evaluated Accuracy: 0.7842
Testing b_a_r = 0.5628, c_a_r = 0.9996, f_a_r = 0.0123




----------------------> 0.39166666666666666
Evaluated Accuracy: 0.3917
Testing b_a_r = 0.0137, c_a_r = 0.3669, f_a_r = 0.6080




----------------------> 0.7672222222222222
Evaluated Accuracy: 0.7672
Testing b_a_r = 0.0446, c_a_r = 0.5381, f_a_r = 0.5598




----------------------> 0.4724999999999999
Evaluated Accuracy: 0.4725
Testing b_a_r = 0.0424, c_a_r = 0.3523, f_a_r = 0.9517




----------------------> 0.7456944444444442
Evaluated Accuracy: 0.7457
Testing b_a_r = 0.4820, c_a_r = 0.8567, f_a_r = 0.9993




----------------------> 0.37916666666666665
Evaluated Accuracy: 0.3792
Testing b_a_r = 0.9337, c_a_r = 0.3203, f_a_r = 0.5827




----------------------> 0.7326388888888888
Evaluated Accuracy: 0.7326
Testing b_a_r = 0.9841, c_a_r = 0.0024, f_a_r = 0.4472




----------------------> 0.7397222222222221
Evaluated Accuracy: 0.7397
Testing b_a_r = 0.0133, c_a_r = 0.2290, f_a_r = 0.0242




----------------------> 0.7425
Evaluated Accuracy: 0.7425
Testing b_a_r = 0.4257, c_a_r = 0.1447, f_a_r = 0.7865




----------------------> 0.8341666666666666
Evaluated Accuracy: 0.8342
Testing b_a_r = 0.2627, c_a_r = 0.1712, f_a_r = 0.9894




----------------------> 0.8329166666666666
Evaluated Accuracy: 0.8329
Testing b_a_r = 0.9392, c_a_r = 0.3852, f_a_r = 0.0529




----------------------> 0.7241666666666666
Evaluated Accuracy: 0.7242
Testing b_a_r = 0.7994, c_a_r = 0.8903, f_a_r = 0.0092




----------------------> 0.43249999999999994
Evaluated Accuracy: 0.4325
Testing b_a_r = 0.1077, c_a_r = 0.0334, f_a_r = 0.9507




----------------------> 0.7705555555555555
Evaluated Accuracy: 0.7706
Testing b_a_r = 0.9865, c_a_r = 0.1007, f_a_r = 0.7599




----------------------> 0.8204166666666667
Evaluated Accuracy: 0.8204
Testing b_a_r = 0.9828, c_a_r = 0.3921, f_a_r = 0.9299




----------------------> 0.7033333333333333
Evaluated Accuracy: 0.7033
Testing b_a_r = 0.8475, c_a_r = 0.1549, f_a_r = 0.5164




----------------------> 0.8245833333333333
Evaluated Accuracy: 0.8246
Testing b_a_r = 0.9128, c_a_r = 0.0559, f_a_r = 0.0033




----------------------> 0.747638888888889
Evaluated Accuracy: 0.7476
Testing b_a_r = 0.9935, c_a_r = 0.2200, f_a_r = 0.7709




----------------------> 0.7975
Evaluated Accuracy: 0.7975
Testing b_a_r = 0.8591, c_a_r = 0.6123, f_a_r = 0.9993




----------------------> 0.4266666666666667
Evaluated Accuracy: 0.4267
Testing b_a_r = 0.0007, c_a_r = 0.1724, f_a_r = 0.8439




----------------------> 0.8361111111111109
Evaluated Accuracy: 0.8361

✅ Optimal Values:
   - b_a_r: 0.0067
   - c_a_r: 0.1555
   - f_a_r: 0.7886
📈 Highest Accuracy Achieved: 0.8418


In [84]:
# Evaluate the base model (M3) and the secondary model (M2), then blend M2 into M3
# with the ratios selected by the search above and evaluate the blended model.
# map_location=device keeps loading working even if the checkpoints were written
# from a different device than the one active now.

# Base model M3 (classification head replaced by Flatten so the backbone returns features).
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt', map_location=device))
evaluate2('/home/asufian/Desktop/output_olchiki/resnet_resluts/base/8_w_5_s_f', test_loader, M3, criterion)

# Secondary model M2.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt', map_location=device))
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/secondary/8_w_5_s_f", test_loader, M2, criterion)

# Update M3's weights using a weighted combination of M3 and M2 (modifies M3).
update_model_weights2(M3, M2, conv_ratio=optimal_c_a_r, fc_ratio=optimal_f_a_r, bias_ratio=optimal_b_a_r)

# NOTE(review): optimal_* were picked to maximize accuracy in the search cell; if that
# search scored these same test episodes, this number is optimistically biased. TODO confirm.
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/amul/8_w_5_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.7114 ± 0.0670  0.7267 ± 0.0709  0.7114 ± 0.0670




---Accu-----pre----rec---------> 0.5714 ± 0.0683  0.5837 ± 0.0774  0.5714 ± 0.0683
---Accu-----pre----rec---------> 0.8418 ± 0.0481  0.8504 ± 0.0503  0.8418 ± 0.0481


In [85]:
from pathlib import Path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Remove the checkpoints written for the previous configuration first, so that if a
# load below fails, later cells cannot silently reuse stale files. Unlike `!rm`,
# unlink(missing_ok=True) is a silent no-op when the file does not exist yet.
Path('model_mu_path.pt').unlink(missing_ok=True)
Path('model_own_path.pt').unlink(missing_ok=True)

# Base ("own") model.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth', map_location=torch.device('cpu')))

# Secondary model. NOTE: despite its name, model_own_1_shot holds the 10-shot
# secondary checkpoint; the name is kept in case other cells refer to it.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/olchiki/ressecondary/model_10/model_B_olchiki_10-shot_res.pth', map_location=torch.device('cpu')))

# Working copies read back by the evaluation cell.
torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [86]:
def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss(), support_sets=None):
    """Evaluate ``model`` on every ``clss_8`` episode and print accuracy/precision/recall.

    For episode ``i`` the support and query sets come from
    ``data_loader.__getitem__(clss_8[i], support_sets[i])``. Each episode's query
    predictions (argmax over the model's per-group scores) are compared with the
    true labels, a confusion matrix is built, and ``calculate_metrics_get_per``
    (defined elsewhere in the notebook) is called with ``flnm=fname``; its first two
    return values are collected as the episode's precision and recall.

    Args:
        fname: output path forwarded to ``calculate_metrics_get_per`` as ``flnm``.
        data_loader: episode source whose ``__getitem__(classes, support_spec)``
            returns ``(support_set, query_set)``.
        model: called as ``model(support_images, support_labels, query_images)``;
            must return one score tensor per query group.
        criterion: unused; kept so existing calls keep working.
        support_sets: per-episode support-image specs aligned with ``clss_8``.
            Defaults to the global ``clss_support_imagesss_10_shot`` so the same
            definition can be reused for other shot settings.

    Prints mean ± sample standard deviation (ddof=1) over episodes; returns None.
    """
    if support_sets is None:
        support_sets = clss_support_imagesss_10_shot

    model.eval()
    episode_accuracies = []
    precision = []
    recall = []

    with torch.no_grad():
        for episode_idx in range(len(clss_8)):
            episode_classes = clss_8[episode_idx]
            episode_support = support_sets[episode_idx]
            support_set, query_set = data_loader.__getitem__(episode_classes, episode_support)

            # Human-readable descriptors forwarded to calculate_metrics_get_per.
            clss = str(episode_classes).replace(',', ' ').replace('\'', ' ')
            msg = str(episode_support)

            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            # One score tensor per query group (negated prototype distances).
            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            all_predicted_labels = []
            all_actual_labels = []
            for group_scores, actual in zip(classification_scores, query_labels):
                predicted = torch.argmax(group_scores, dim=1).tolist()
                all_predicted_labels.extend(predicted)
                all_actual_labels.extend(actual)
                correct += sum(1 for p, a in zip(predicted, actual) if p == a)
                total += len(predicted)

            accuracy = correct / total
            episode_accuracies.append(accuracy)

            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            precision_overall, recall_overall, _f1 = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=clss, flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)

    print(f"---Accu-----pre----rec---------> {np.mean(episode_accuracies):.4f} ± {np.std(episode_accuracies, ddof=1):.4f}  "
          f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
          f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")
# %%%%%%

def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss(), support_sets=None):
    """Return the mean per-episode query accuracy of ``model`` over the ``clss_8`` episodes.

    For episode ``i`` the support and query sets come from
    ``data_loader.__getitem__(clss_8[i], support_sets[i])``; the episode accuracy is
    the fraction of query images whose argmax score matches the true label. The
    mean over episodes is printed after a ``---------------------->`` marker and returned.

    Args:
        fname: unused; kept so existing calls keep working.
        data_loader: episode source whose ``__getitem__(classes, support_spec)``
            returns ``(support_set, query_set)``.
        model: called as ``model(support_images, support_labels, query_images)``;
            must return one score tensor per query group.
        criterion: unused; kept so existing calls keep working.
        support_sets: per-episode support-image specs aligned with ``clss_8``.
            Defaults to the global ``clss_support_imagesss_10_shot``.

    Returns:
        Mean of the per-episode accuracies (float).
    """
    if support_sets is None:
        support_sets = clss_support_imagesss_10_shot

    model.eval()
    episode_accuracies = []

    with torch.no_grad():
        for episode_idx in range(len(clss_8)):
            support_set, query_set = data_loader.__getitem__(clss_8[episode_idx], support_sets[episode_idx])

            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            # One score tensor per query group (negated prototype distances).
            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            for group_scores, actual in zip(classification_scores, query_labels):
                predicted = torch.argmax(group_scores, dim=1).tolist()
                correct += sum(1 for p, a in zip(predicted, actual) if p == a)
                total += len(predicted)
            episode_accuracies.append(correct / total)

    mean_accuracy = sum(episode_accuracies) / len(episode_accuracies)
    print("---------------------->", mean_accuracy)
    return mean_accuracy

In [87]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Search space for the three merge ratios that are later handed to
# update_model_weights2:  b_a_r -> bias_ratio, c_a_r -> conv_ratio, f_a_r -> fc_ratio.
# Every ratio is searched over [0, 1] with skopt's default prior for Real dimensions.
search_space = [
    Real(low=0.0, high=1.0, name=ratio_name)
    for ratio_name in ("b_a_r", "c_a_r", "f_a_r")
]

# Define the objective function for Bayesian Optimization
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """Loss for one candidate set of merge ratios proposed by ``gp_minimize``.

    ``use_named_args`` unpacks the proposed point into keyword arguments named
    after the dimensions of ``search_space``. The ratios are scored with
    ``get_model_accuracy`` and the accuracy is logged.

    Returns:
        The negated accuracy. ``gp_minimize`` minimises, so a lower loss means
        a more accurate merged model.
    """
    score = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    loss = -score
    print(f"Evaluated Accuracy: {score:.4f}")  # progress log for every BO evaluation
    return loss

# Perform Bayesian Optimization with additional settings
# Total number of objective evaluations. Each one merges the two models with the
# proposed ratios and re-scores the result, so this dominates the cell's runtime.
search_iteration = 80
# NOTE(review): the ratios picked here are reused for the final merged model.
# Make sure get_model_accuracy scores a validation split, not the test episodes
# reported afterwards; otherwise the final test accuracy is optimistically biased.
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,   # total evaluations, including the random initial ones
    n_initial_points=5,         # random points evaluated before the GP surrogate is used
    acq_func="EI",              # acquisition function: Expected Improvement
    random_state=42,            # seeds the optimiser's own sampling, not the model evaluation
    n_jobs=-1,                  # only parallelises the lbfgs acquisition optimisation;
                                # objective evaluations still run one after another
)

# res.x follows the order of search_space; res.fun is the best (lowest) loss,
# i.e. the negated best accuracy.
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print("\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.8497023809523809
Evaluated Accuracy: 0.8497
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.5840773809523809
Evaluated Accuracy: 0.5841
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.747470238095238
Evaluated Accuracy: 0.7475
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.8492559523809523
Evaluated Accuracy: 0.8493
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.7992559523809524
Evaluated Accuracy: 0.7993
Testing b_a_r = 0.0250, c_a_r = 1.0000, f_a_r = 0.4630




----------------------> 0.2921130952380952
Evaluated Accuracy: 0.2921
Testing b_a_r = 0.0000, c_a_r = 0.1302, f_a_r = 0.0000




----------------------> 0.8092261904761905
Evaluated Accuracy: 0.8092
Testing b_a_r = 1.0000, c_a_r = 0.1472, f_a_r = 1.0000




----------------------> 0.846279761904762
Evaluated Accuracy: 0.8463
Testing b_a_r = 0.0000, c_a_r = 0.1558, f_a_r = 1.0000




----------------------> 0.8449404761904762
Evaluated Accuracy: 0.8449
Testing b_a_r = 0.0000, c_a_r = 0.1216, f_a_r = 0.6548




----------------------> 0.8543154761904762
Evaluated Accuracy: 0.8543
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.7342261904761904
Evaluated Accuracy: 0.7342
Testing b_a_r = 1.0000, c_a_r = 0.2963, f_a_r = 1.0000




----------------------> 0.7639880952380952
Evaluated Accuracy: 0.7640
Testing b_a_r = 1.0000, c_a_r = 0.1289, f_a_r = 0.6469




----------------------> 0.8517857142857143
Evaluated Accuracy: 0.8518
Testing b_a_r = 0.0000, c_a_r = 0.0828, f_a_r = 1.0000




----------------------> 0.8720238095238096
Evaluated Accuracy: 0.8720
Testing b_a_r = 0.0000, c_a_r = 0.2320, f_a_r = 0.0000




----------------------> 0.8235119047619047
Evaluated Accuracy: 0.8235
Testing b_a_r = 0.0777, c_a_r = 0.7080, f_a_r = 1.0000




----------------------> 0.4092261904761904
Evaluated Accuracy: 0.4092
Testing b_a_r = 0.0371, c_a_r = 0.2355, f_a_r = 0.5547




----------------------> 0.8541666666666665
Evaluated Accuracy: 0.8542
Testing b_a_r = 0.9193, c_a_r = 0.7861, f_a_r = 0.0011




----------------------> 0.3840773809523809
Evaluated Accuracy: 0.3841
Testing b_a_r = 0.1773, c_a_r = 0.1878, f_a_r = 0.4113




----------------------> 0.8395833333333335
Evaluated Accuracy: 0.8396
Testing b_a_r = 0.3641, c_a_r = 0.2232, f_a_r = 0.9956




----------------------> 0.8416666666666667
Evaluated Accuracy: 0.8417
Testing b_a_r = 0.6898, c_a_r = 0.6019, f_a_r = 0.0045




----------------------> 0.5541666666666668
Evaluated Accuracy: 0.5542
Testing b_a_r = 0.0759, c_a_r = 0.5222, f_a_r = 0.9983




----------------------> 0.5193452380952381
Evaluated Accuracy: 0.5193
Testing b_a_r = 0.4075, c_a_r = 0.0590, f_a_r = 0.9981




----------------------> 0.8556547619047619
Evaluated Accuracy: 0.8557
Testing b_a_r = 0.8523, c_a_r = 0.0962, f_a_r = 0.9682




----------------------> 0.8641369047619047
Evaluated Accuracy: 0.8641
Testing b_a_r = 0.0899, c_a_r = 0.0918, f_a_r = 0.7251




----------------------> 0.8650297619047618
Evaluated Accuracy: 0.8650
Testing b_a_r = 0.1076, c_a_r = 0.0734, f_a_r = 0.2248




----------------------> 0.7985119047619046
Evaluated Accuracy: 0.7985
Testing b_a_r = 0.0408, c_a_r = 0.1023, f_a_r = 0.9482




----------------------> 0.8654761904761905
Evaluated Accuracy: 0.8655
Testing b_a_r = 0.9603, c_a_r = 0.2583, f_a_r = 0.4000




----------------------> 0.8386904761904762
Evaluated Accuracy: 0.8387
Testing b_a_r = 0.9190, c_a_r = 0.0861, f_a_r = 0.8895




----------------------> 0.8654761904761905
Evaluated Accuracy: 0.8655
Testing b_a_r = 0.2298, c_a_r = 0.8881, f_a_r = 0.9997




----------------------> 0.29523809523809524
Evaluated Accuracy: 0.2952
Testing b_a_r = 0.0162, c_a_r = 0.1080, f_a_r = 0.9957




----------------------> 0.8636904761904763
Evaluated Accuracy: 0.8637
Testing b_a_r = 0.2988, c_a_r = 0.2782, f_a_r = 0.0094




----------------------> 0.8123511904761905
Evaluated Accuracy: 0.8124
Testing b_a_r = 0.0638, c_a_r = 0.3998, f_a_r = 0.9936




----------------------> 0.5770833333333334
Evaluated Accuracy: 0.5771
Testing b_a_r = 0.0062, c_a_r = 0.0774, f_a_r = 0.8099




----------------------> 0.8688988095238095
Evaluated Accuracy: 0.8689
Testing b_a_r = 0.9849, c_a_r = 0.2191, f_a_r = 0.6975




----------------------> 0.8602678571428571
Evaluated Accuracy: 0.8603
Testing b_a_r = 0.0305, c_a_r = 0.0003, f_a_r = 0.6188




----------------------> 0.7830357142857144
Evaluated Accuracy: 0.7830
Testing b_a_r = 0.0033, c_a_r = 0.2407, f_a_r = 0.7512




----------------------> 0.8379464285714285
Evaluated Accuracy: 0.8379
Testing b_a_r = 0.2070, c_a_r = 0.0770, f_a_r = 0.9977




----------------------> 0.86875
Evaluated Accuracy: 0.8688
Testing b_a_r = 0.9370, c_a_r = 0.0909, f_a_r = 0.9898




----------------------> 0.8641369047619047
Evaluated Accuracy: 0.8641
Testing b_a_r = 0.9752, c_a_r = 0.2218, f_a_r = 0.3142




----------------------> 0.8401785714285714
Evaluated Accuracy: 0.8402
Testing b_a_r = 0.9856, c_a_r = 0.0859, f_a_r = 0.8223




----------------------> 0.8657738095238096
Evaluated Accuracy: 0.8658
Testing b_a_r = 0.8580, c_a_r = 0.1283, f_a_r = 0.3596




----------------------> 0.8261904761904763
Evaluated Accuracy: 0.8262
Testing b_a_r = 0.8613, c_a_r = 0.1760, f_a_r = 0.0031




----------------------> 0.8144345238095237
Evaluated Accuracy: 0.8144
Testing b_a_r = 0.5843, c_a_r = 0.9837, f_a_r = 0.0008




----------------------> 0.3306547619047619
Evaluated Accuracy: 0.3307
Testing b_a_r = 0.0497, c_a_r = 0.0950, f_a_r = 0.8013




----------------------> 0.8680059523809525
Evaluated Accuracy: 0.8680
Testing b_a_r = 0.9855, c_a_r = 0.2051, f_a_r = 0.5663




----------------------> 0.8641369047619047
Evaluated Accuracy: 0.8641
Testing b_a_r = 0.8771, c_a_r = 0.5941, f_a_r = 0.5338




----------------------> 0.5345238095238095
Evaluated Accuracy: 0.5345
Testing b_a_r = 0.9947, c_a_r = 0.0850, f_a_r = 0.0061




----------------------> 0.8007440476190476
Evaluated Accuracy: 0.8007
Testing b_a_r = 0.9058, c_a_r = 0.3176, f_a_r = 0.6050




----------------------> 0.7566964285714287
Evaluated Accuracy: 0.7567
Testing b_a_r = 0.9777, c_a_r = 0.0834, f_a_r = 0.5506




----------------------> 0.8465773809523808
Evaluated Accuracy: 0.8466
Testing b_a_r = 0.9860, c_a_r = 0.2453, f_a_r = 0.6289




----------------------> 0.8494047619047619
Evaluated Accuracy: 0.8494
Testing b_a_r = 0.0049, c_a_r = 0.9976, f_a_r = 0.9634




----------------------> 0.2848214285714286
Evaluated Accuracy: 0.2848
Testing b_a_r = 0.0445, c_a_r = 0.6964, f_a_r = 0.3529




----------------------> 0.46532738095238096
Evaluated Accuracy: 0.4653
Testing b_a_r = 0.8398, c_a_r = 0.1934, f_a_r = 0.9943




----------------------> 0.8453869047619048
Evaluated Accuracy: 0.8454
Testing b_a_r = 0.0586, c_a_r = 0.1685, f_a_r = 0.5877




----------------------> 0.8476190476190477
Evaluated Accuracy: 0.8476
Testing b_a_r = 0.9476, c_a_r = 0.4715, f_a_r = 0.5466




----------------------> 0.5547619047619047
Evaluated Accuracy: 0.5548
Testing b_a_r = 0.9744, c_a_r = 0.0222, f_a_r = 0.3449




----------------------> 0.7686011904761904
Evaluated Accuracy: 0.7686
Testing b_a_r = 0.0585, c_a_r = 0.2771, f_a_r = 0.2295




----------------------> 0.8200892857142857
Evaluated Accuracy: 0.8201
Testing b_a_r = 0.0625, c_a_r = 0.1345, f_a_r = 0.8400




----------------------> 0.8507440476190475
Evaluated Accuracy: 0.8507
Testing b_a_r = 0.0448, c_a_r = 0.0332, f_a_r = 0.9016




----------------------> 0.8291666666666666
Evaluated Accuracy: 0.8292
Testing b_a_r = 0.3511, c_a_r = 0.5234, f_a_r = 0.0067




----------------------> 0.5458333333333333
Evaluated Accuracy: 0.5458
Testing b_a_r = 0.1652, c_a_r = 0.8777, f_a_r = 0.3185




----------------------> 0.31428571428571433
Evaluated Accuracy: 0.3143
Testing b_a_r = 0.2377, c_a_r = 0.0456, f_a_r = 0.0019




----------------------> 0.7776785714285714
Evaluated Accuracy: 0.7777
Testing b_a_r = 0.9971, c_a_r = 0.2445, f_a_r = 0.0851




----------------------> 0.8136904761904761
Evaluated Accuracy: 0.8137
Testing b_a_r = 0.0130, c_a_r = 0.2054, f_a_r = 0.6395




----------------------> 0.855357142857143
Evaluated Accuracy: 0.8554
Testing b_a_r = 0.0145, c_a_r = 0.0564, f_a_r = 0.5448




----------------------> 0.8325892857142856
Evaluated Accuracy: 0.8326
Testing b_a_r = 0.9546, c_a_r = 0.2453, f_a_r = 0.9346




----------------------> 0.8239583333333332
Evaluated Accuracy: 0.8240
Testing b_a_r = 0.0051, c_a_r = 0.2817, f_a_r = 0.5338




----------------------> 0.8257440476190477
Evaluated Accuracy: 0.8257
Testing b_a_r = 0.9583, c_a_r = 0.6217, f_a_r = 0.9948




----------------------> 0.5053571428571427
Evaluated Accuracy: 0.5054
Testing b_a_r = 0.0889, c_a_r = 0.3638, f_a_r = 0.0035




----------------------> 0.6921130952380953
Evaluated Accuracy: 0.6921
Testing b_a_r = 0.0306, c_a_r = 0.3842, f_a_r = 0.4256




----------------------> 0.644047619047619
Evaluated Accuracy: 0.6440
Testing b_a_r = 0.0214, c_a_r = 0.1759, f_a_r = 0.1715




----------------------> 0.8133928571428571
Evaluated Accuracy: 0.8134
Testing b_a_r = 0.0094, c_a_r = 0.0808, f_a_r = 0.9170




----------------------> 0.8718749999999998
Evaluated Accuracy: 0.8719
Testing b_a_r = 0.9831, c_a_r = 0.1685, f_a_r = 0.4819




----------------------> 0.8464285714285714
Evaluated Accuracy: 0.8464
Testing b_a_r = 0.9547, c_a_r = 0.7836, f_a_r = 0.6203




----------------------> 0.30565476190476193
Evaluated Accuracy: 0.3057
Testing b_a_r = 0.0462, c_a_r = 0.1005, f_a_r = 0.5241




----------------------> 0.8495535714285714
Evaluated Accuracy: 0.8496
Testing b_a_r = 0.9987, c_a_r = 0.1107, f_a_r = 0.7865




----------------------> 0.856547619047619
Evaluated Accuracy: 0.8565
Testing b_a_r = 0.9985, c_a_r = 0.0158, f_a_r = 0.7787




----------------------> 0.8101190476190475
Evaluated Accuracy: 0.8101
Testing b_a_r = 0.9927, c_a_r = 0.2202, f_a_r = 0.5487




----------------------> 0.8642857142857144
Evaluated Accuracy: 0.8643
Testing b_a_r = 0.0380, c_a_r = 0.2683, f_a_r = 0.9678




----------------------> 0.7895833333333333
Evaluated Accuracy: 0.7896

✅ Optimal Values:
   - b_a_r: 0.0000
   - c_a_r: 0.0828
   - f_a_r: 1.0000
📈 Highest Accuracy Achieved: 0.8720


In [88]:
# Rebuild the base (M3) and secondary (M2) prototypical networks from the staged
# checkpoints and score each one. Then blend M2's weights into M3 using the ratios
# found by the Bayesian search above, and score the merged model.
# map_location=device: a checkpoint may hold tensors from whatever device it was
# saved on (e.g. CUDA). Mapping to `device` keeps this cell runnable on a CPU-only host.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()  # final fc layer replaced by a pass-through flatten
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt', map_location=device))
evaluate2('/home/asufian/Desktop/output_olchiki/resnet_resluts/base/8_w_10_s_f', test_loader, M3, criterion)

# Load Model M2, the secondary network whose weights are blended into M3.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt', map_location=device))
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/secondary/8_w_10_s_f", test_loader, M2, criterion)

# Blend M2 into M3. Judging by the keyword names, conv, fc and bias parameters
# each get their own ratio. M3 is re-evaluated below without being reassigned,
# so update_model_weights2 presumably modifies M3 in place.
update_model_weights2(M3, M2, conv_ratio=optimal_c_a_r, fc_ratio=optimal_f_a_r, bias_ratio=optimal_b_a_r)

# NOTE(review): in the recorded outputs, this merged accuracy (0.8720) equals the best
# Bayesian-search objective exactly. That suggests the ratios were tuned on these same
# test episodes; confirm that get_model_accuracy uses a separate validation split.
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/amul/8_w_10_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.7259 ± 0.0547  0.7357 ± 0.0563  0.7259 ± 0.0547




---Accu-----pre----rec---------> 0.5624 ± 0.0649  0.5727 ± 0.0628  0.5624 ± 0.0649
---Accu-----pre----rec---------> 0.8720 ± 0.0406  0.8778 ± 0.0412  0.8720 ± 0.0406


In [None]:
###############################################   12 way 1 shot

In [89]:
# Support-image ids for the 12 fixed evaluation episodes, 10 ids per class.
# Layout: 12 episodes x 12 classes x 10 ints. Episode k lines up with row k of
# clss_12, because evaluate2 indexes the class-name list and the support-id list
# with the same episode counter.
# The ints are presumably positions of images inside each class folder, consumed
# by data_loader.__getitem__ (TODO confirm against the loader).
# The 5-shot and 1-shot lists in this cell hold the first 5 / first 1 ids of each class here.
clss_support_imagesss_10_shot=[
[[5, 47, 72, 9, 32, 44, 13, 55, 77, 0], [13, 10, 72, 9, 71, 23, 18, 20, 46, 64], [10, 69, 21, 5, 38, 11, 7, 64, 59, 4], [16, 18, 21, 29, 43, 9, 19, 66, 63, 57], [69, 40, 4, 60, 23, 0, 30, 74, 17, 35], [48, 35, 36, 64, 0, 26, 24, 71, 8, 76], [0, 52, 22, 77, 9, 59, 29, 21, 75, 53], [17, 8, 69, 11, 44, 73, 77, 0, 59, 46], [31, 41, 17, 37, 16, 54, 40, 39, 25, 42], [19, 65, 62, 13, 5, 67, 4, 69, 76, 0], [44, 74, 53, 28, 58, 15, 43, 55, 34, 50], [29, 63, 21, 27, 39, 67, 33, 71, 1, 19]],
    [[79, 47, 13, 50, 68, 9, 4, 12, 62, 24], [16, 54, 8, 23, 12, 37, 31, 72, 48, 52], [63, 70, 20, 58, 12, 26, 27, 59, 76, 52], [42, 15, 55, 59, 22, 76, 79, 6, 11, 58], [60, 62, 15, 54, 72, 1, 38, 77, 66, 22], [63, 36, 72, 56, 10, 24, 55, 4, 6, 57], [8, 64, 61, 26, 77, 56, 43, 20, 41, 54], [66, 25, 52, 23, 60, 33, 31, 12, 67, 5], [62, 13, 77, 21, 65, 59, 41, 45, 58, 71], [20, 10, 69, 63, 11, 9, 68, 57, 74, 75], [40, 8, 66, 31, 13, 54, 63, 4, 64, 29], [7, 14, 42, 66, 73, 41, 75, 44, 5, 70]],
    [[51, 46, 56, 19, 16, 36, 25, 29, 53, 12], [6, 11, 68, 48, 35, 46, 78, 23, 39, 15], [34, 2, 32, 21, 18, 52, 74, 71, 12, 24], [64, 13, 29, 43, 52, 11, 37, 50, 49, 17], [24, 74, 20, 25, 61, 45, 54, 4, 78, 59], [53, 50, 12, 39, 66, 9, 73, 15, 11, 44], [38, 76, 65, 28, 24, 58, 36, 56, 18, 68], [43, 72, 4, 8, 73, 63, 19, 51, 60, 32], [19, 44, 65, 7, 2, 71, 66, 35, 50, 41], [44, 35, 14, 34, 29, 45, 23, 26, 41, 43], [35, 59, 1, 46, 37, 41, 12, 38, 44, 56], [28, 37, 47, 19, 59, 4, 52, 63, 0, 7]],
    [[79, 60, 39, 74, 70, 48, 8, 11, 66, 33], [6, 74, 64, 9, 27, 70, 47, 19, 24, 63], [48, 62, 69, 51, 55, 41, 6, 59, 50, 18], [4, 54, 72, 66, 20, 69, 62, 58, 6, 53], [28, 20, 46, 52, 72, 11, 73, 9, 22, 57], [13, 33, 62, 59, 72, 37, 76, 5, 16, 22], [67, 11, 70, 58, 50, 6, 5, 17, 38, 28], [16, 27, 41, 2, 0, 17, 8, 1, 32, 71], [70, 52, 0, 67, 74, 23, 62, 51, 55, 63], [27, 46, 79, 43, 44, 67, 55, 72, 58, 54], [9, 48, 37, 68, 56, 50, 61, 27, 47, 77], [15, 37, 63, 69, 58, 70, 19, 62, 17, 78]],
    [[9, 35, 16, 21, 28, 47, 17, 6, 53, 68], [7, 53, 31, 51, 30, 72, 42, 9, 62, 34], [29, 59, 33, 75, 45, 0, 28, 76, 51, 66], [50, 67, 39, 68, 1, 47, 46, 58, 11, 34], [71, 14, 40, 15, 32, 70, 73, 57, 6, 67], [7, 63, 20, 6, 8, 5, 42, 68, 50, 59], [6, 65, 2, 43, 69, 59, 37, 16, 1, 20], [3, 12, 37, 33, 17, 4, 24, 18, 46, 41], [26, 9, 35, 39, 74, 8, 16, 22, 64, 53], [70, 56, 67, 6, 10, 4, 19, 57, 11, 31], [19, 32, 57, 39, 54, 55, 18, 74, 51, 35], [64, 10, 58, 18, 23, 8, 31, 37, 20, 9]],
    [[4, 50, 41, 48, 72, 7, 79, 43, 65, 69], [38, 51, 64, 35, 43, 0, 5, 58, 67, 34], [65, 16, 1, 6, 52, 34, 3, 44, 38, 31], [32, 12, 74, 14, 29, 13, 31, 1, 27, 76], [78, 54, 53, 33, 67, 71, 22, 31, 25, 45], [65, 64, 43, 49, 74, 8, 45, 48, 39, 69], [12, 73, 32, 10, 63, 79, 11, 69, 54, 35], [33, 48, 70, 71, 14, 26, 28, 17, 3, 67], [17, 22, 29, 77, 14, 75, 52, 31, 50, 26], [19, 41, 67, 4, 53, 77, 17, 43, 38, 16], [75, 38, 24, 18, 23, 39, 5, 4, 72, 50], [27, 43, 64, 47, 50, 38, 14, 2, 33, 37]],
    [[47, 13, 68, 62, 58, 39, 24, 26, 45, 74], [3, 2, 76, 47, 7, 17, 24, 10, 38, 59], [5, 10, 20, 4, 1, 38, 24, 64, 3, 57], [39, 18, 64, 20, 78, 56, 76, 28, 36, 24], [35, 38, 48, 18, 8, 25, 75, 69, 46, 79], [45, 36, 60, 77, 4, 27, 10, 29, 68, 37], [35, 58, 31, 79, 28, 74, 71, 15, 69, 11], [29, 21, 70, 9, 26, 51, 6, 31, 76, 65], [33, 4, 20, 49, 9, 2, 32, 26, 59, 64], [25, 19, 52, 22, 0, 4, 10, 2, 73, 71], [5, 9, 54, 1, 14, 45, 25, 7, 37, 17], [0, 39, 53, 31, 78, 37, 27, 9, 71, 23]],
    [[31, 77, 36, 11, 43, 17, 7, 34, 9, 32], [58, 38, 24, 68, 5, 76, 63, 65, 42, 48], [31, 12, 65, 26, 15, 56, 30, 69, 7, 53], [46, 71, 35, 37, 33, 75, 4, 9, 76, 5], [66, 60, 53, 32, 31, 78, 10, 48, 39, 62], [7, 25, 51, 27, 19, 12, 17, 30, 79, 33], [69, 39, 9, 17, 8, 41, 70, 5, 61, 4], [26, 58, 54, 51, 23, 49, 13, 34, 72, 10], [49, 76, 39, 26, 46, 70, 69, 3, 21, 12], [15, 20, 24, 53, 65, 54, 0, 69, 22, 40], [0, 7, 34, 69, 67, 57, 45, 76, 25, 30], [9, 8, 16, 12, 11, 22, 58, 26, 34, 37]],
    [[5, 34, 32, 26, 35, 36, 31, 21, 56, 8], [61, 56, 25, 29, 64, 45, 74, 24, 18, 0], [22, 46, 10, 17, 66, 50, 21, 42, 40, 73], [58, 2, 67, 78, 26, 48, 57, 41, 45, 59], [48, 15, 3, 23, 26, 37, 35, 49, 12, 39], [20, 68, 56, 30, 33, 77, 16, 17, 10, 35], [0, 16, 1, 77, 48, 68, 67, 14, 8, 22], [74, 63, 64, 21, 10, 16, 66, 54, 28, 23], [8, 40, 35, 60, 71, 34, 54, 26, 64, 37], [70, 23, 58, 25, 65, 30, 12, 31, 4, 61], [33, 51, 74, 22, 9, 53, 72, 41, 69, 60], [76, 1, 28, 68, 49, 14, 4, 12, 48, 36]],
    [[46, 29, 28, 79, 53, 70, 38, 15, 37, 17], [38, 43, 31, 51, 28, 54, 62, 59, 21, 49], [19, 51, 49, 30, 45, 15, 64, 46, 56, 75], [79, 9, 53, 35, 73, 44, 22, 75, 68, 63], [69, 46, 53, 58, 21, 2, 44, 43, 11, 41], [70, 71, 44, 45, 16, 18, 5, 53, 7, 38], [78, 30, 0, 43, 56, 22, 11, 29, 21, 23], [62, 76, 5, 47, 3, 40, 2, 68, 34, 65], [68, 56, 40, 64, 42, 8, 52, 73, 63, 2], [24, 4, 28, 66, 60, 48, 70, 72, 79, 23], [67, 24, 70, 77, 16, 43, 14, 64, 26, 17], [79, 6, 18, 4, 63, 55, 77, 58, 47, 54]],
    [[74, 79, 0, 5, 19, 73, 37, 32, 42, 58], [15, 44, 17, 3, 14, 36, 24, 6, 56, 77], [13, 22, 7, 69, 10, 41, 56, 25, 15, 48], [17, 57, 25, 60, 43, 24, 53, 9, 10, 54], [71, 38, 25, 4, 64, 78, 39, 53, 41, 69], [57, 63, 6, 32, 34, 25, 20, 23, 39, 76], [69, 48, 43, 7, 19, 68, 12, 77, 36, 65], [44, 55, 74, 46, 51, 21, 38, 7, 11, 61], [54, 12, 78, 8, 44, 49, 46, 35, 45, 72], [69, 8, 35, 47, 49, 62, 42, 32, 4, 70], [76, 11, 37, 60, 57, 18, 30, 7, 5, 39], [67, 37, 63, 68, 20, 24, 51, 62, 28, 2]],
    [[66, 55, 74, 37, 40, 1, 21, 46, 34, 28], [31, 59, 66, 69, 16, 6, 37, 11, 60, 58], [58, 48, 35, 65, 67, 2, 60, 45, 17, 6], [49, 63, 9, 74, 26, 79, 11, 46, 56, 59], [22, 35, 66, 42, 10, 75, 28, 1, 39, 67], [23, 50, 41, 17, 71, 32, 53, 27, 56, 21], [29, 67, 37, 69, 50, 8, 40, 17, 59, 32], [0, 79, 20, 68, 63, 66, 50, 26, 57, 44], [38, 46, 70, 61, 4, 43, 39, 3, 50, 55], [65, 58, 51, 59, 71, 74, 39, 34, 60, 23], [30, 26, 72, 25, 27, 23, 42, 24, 53, 1], [31, 22, 34, 7, 0, 5, 42, 36, 49, 45]]
    
  ]


# 5-shot support ids: the first five ids of every class in
# clss_support_imagesss_10_shot (same 12 episodes x 12 classes layout).
# Deriving them by slicing keeps the 5-shot selection nested inside the 10-shot
# one and prevents the hand-copied table from drifting out of sync. The values
# are identical to the previous literal.
clss_support_imagesss_5_shot = [
    [class_ids[:5] for class_ids in episode]
    for episode in clss_support_imagesss_10_shot
]

# Class folders for the 12 fixed episodes of the 12-way experiment. Row k pairs
# with row k of the support-id lists, because evaluate2 indexes both with the
# same episode counter. Every name has the form "Olchiki_2_80_c<N>", so each
# row is written as its tuple of class numbers N.
clss_12 = [
    [f"Olchiki_2_80_c{class_number}" for class_number in episode_class_numbers]
    for episode_class_numbers in (
        (23, 17, 15, 18, 12, 19, 16, 6, 27, 25, 7, 22),
        (14, 24, 20, 11, 4, 17, 28, 8, 6, 13, 5, 12),
        (1, 5, 8, 12, 17, 18, 13, 22, 24, 28, 27, 15),
        (7, 12, 18, 8, 23, 4, 30, 3, 24, 28, 22, 15),
        (8, 11, 9, 6, 21, 16, 30, 4, 5, 12, 15, 22),
        (5, 10, 21, 18, 23, 24, 9, 14, 26, 7, 13, 2),
        (22, 18, 5, 25, 26, 13, 7, 6, 11, 10, 4, 8),
        (26, 15, 5, 20, 7, 9, 18, 23, 12, 25, 6, 24),
        (3, 14, 11, 22, 26, 24, 12, 9, 28, 4, 25, 17),
        (2, 10, 7, 29, 24, 18, 19, 8, 28, 6, 5, 12),
        (21, 4, 15, 12, 26, 17, 8, 7, 10, 14, 19, 30),
        (14, 8, 15, 16, 1, 6, 28, 12, 17, 23, 5, 27),
    )
]
# 1-shot support ids: the first id of every class in clss_support_imagesss_10_shot.
# Each id stays in a one-element list, so the layout (12 episodes x 12 classes x
# shots) matches the 5-shot and 10-shot lists. Deriving by slicing keeps the
# selection consistent with the 10-shot table. The values are identical to the
# previous literal.
clss_support_imagesss_1_shot = [
    [class_ids[:1] for class_ids in episode]
    for episode in clss_support_imagesss_10_shot
]















In [90]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Source checkpoints for the 12-way 1-shot experiment.
BASE_CHECKPOINT = '/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth'
SECONDARY_1_SHOT_CHECKPOINT = '/home/asufian/Desktop/output_olchiki/code/olchiki/ressecondary/model_1/model_B_olchiki_1-shot_res.pth'

# Re-stage the checkpoints under the fixed names that the merge cells read.
# Stale copies from the previous experiment are deleted first. If a load below
# then fails, later cells error out instead of silently reusing old weights.
# Unlike `!rm`, this is portable and prints nothing when a file is already absent.
import os

for _stale_checkpoint in ("model_mu_path.pt", "model_own_path.pt"):
    if os.path.exists(_stale_checkpoint):
        os.remove(_stale_checkpoint)

# Base ("own") network: the backbone's fc layer is replaced by nn.Flatten().
# Weights are read onto the CPU, and load_state_dict copies them into the
# parameters that already live on `device`.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load(BASE_CHECKPOINT, map_location=torch.device('cpu')))

# Secondary ("mu") network trained for the 1-shot setting, built the same way.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load(SECONDARY_1_SHOT_CHECKPOINT, map_location=torch.device('cpu')))

torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')




In [91]:

def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss(),
              episode_classes=None, episode_support_ids=None):
    """Evaluate ``model`` on a fixed list of few-shot episodes and print a summary.

    For every episode:

    1. The support/query split is fetched with
       ``data_loader.__getitem__(class_names, support_ids)``.
    2. The model scores each query group.
    3. The confusion matrix and accuracy are passed to
       ``calculate_metrics_get_per``, which returns precision, recall and F1.

    Finally the mean ± sample standard deviation (ddof=1) of accuracy,
    precision and recall over all episodes is printed.

    Args:
        fname: forwarded to ``calculate_metrics_get_per`` as ``flnm``
            (presumably an output path prefix).
        data_loader: object whose ``__getitem__(class_names, support_ids)``
            returns ``(support_set, query_set)``, each a sequence of
            ``(images, labels)`` pairs.
        model: called as ``model(support_images, support_labels, query_images)``.
            It must return one 2-D score matrix per query group; the argmax over
            dim 1 is the predicted class position.
        criterion: unused; kept so existing positional calls keep working.
        episode_classes: per-episode lists of class names. Defaults to the
            global ``clss_12`` (12-way setting).
        episode_support_ids: per-episode support-id lists aligned with
            ``episode_classes``. Defaults to the global
            ``clss_support_imagesss_1_shot``.
    """
    if episode_classes is None:
        episode_classes = clss_12
    if episode_support_ids is None:
        episode_support_ids = clss_support_imagesss_1_shot

    model.eval()
    episode_accuracies = []
    precision = []
    recall = []

    with torch.no_grad():
        for class_names, support_ids in zip(episode_classes, episode_support_ids):
            support_set, query_set = data_loader.__getitem__(class_names, support_ids)

            # Tags forwarded to the metrics writer: class names with commas and
            # quotes blanked out, and the raw support-id list.
            clss = str(class_names).replace(',', ' ').replace('\'', ' ')
            msg = str(support_ids)

            # Support groups keep only their first label (one label per group);
            # query groups keep all their labels for per-image scoring.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            all_predicted_labels = []
            all_actual_labels = []
            for scores, actual in zip(classification_scores, query_labels):
                predicted = torch.argmax(scores, dim=1).tolist()
                all_predicted_labels.extend(predicted)
                all_actual_labels.extend(actual)
                correct += sum(1 for p, a in zip(predicted, actual) if p == a)
                total += len(predicted)

            accuracy = correct / total
            episode_accuracies.append(accuracy)

            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            precision_overall, recall_overall, _f1 = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=clss, flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)

    print(f"---Accu-----pre----rec---------> {np.mean(episode_accuracies):.4f} ± {np.std(episode_accuracies, ddof=1):.4f}  "
      f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
      f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")
# %%%%%%

def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss(), support_lists=None):
    """Return the mean per-episode query accuracy of ``model`` on the 1-shot episodes.

    One episode is evaluated for every entry of the global ``clss_12``. The
    episode's classes and the matching entry of ``support_lists`` are passed to
    ``data_loader.__getitem__`` to obtain ``(support_set, query_set)``. Unlike
    ``evaluate2``, no metrics are computed or written, which keeps this cheap
    enough to use as an optimisation objective.

    Args:
        fname: Unused; accepted so the signature matches ``evaluate2``.
        data_loader: Object exposing ``__getitem__(classes, support_ids)`` that
            returns ``(support_set, query_set)``, each a list of
            ``(list_of_image_tensors, labels)`` pairs.
        model: Callable as ``model(support_images, support_labels, query_images)``
            returning one score tensor per query group (rows = query images,
            columns = prototypes); must also provide ``eval()``.
        criterion: Unused; kept for backward compatibility.
        support_lists: Per-episode support specifications, indexed like
            ``clss_12``. Defaults to the global ``clss_support_imagesss_1_shot``
            (looked up at call time, so later reassignment is honoured).

    Returns:
        float: mean of the per-episode accuracies (also printed).

    Raises:
        ValueError: if ``clss_12`` is empty or an episode yields no query
            predictions (previously a bare ``ZeroDivisionError``).
    """
    if support_lists is None:
        support_lists = clss_support_imagesss_1_shot

    model.eval()
    episode_accuracies = []
    with torch.no_grad():
        for episode, classes in enumerate(clss_12):
            # Explicit __getitem__ call: the loader takes two positional arguments,
            # which subscript syntax cannot pass.
            support_set, query_set = data_loader.__getitem__(classes, support_lists[episode])

            # Stack each group's images into one tensor on the evaluation device.
            # Only the first entry of each support `label` is kept as the class label
            # (presumably all entries agree — confirm against the data loader).
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            scores_per_group = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            for group, scores in enumerate(scores_per_group):
                # argmax over prototypes yields the index of the predicted support
                # class; it is compared directly with the query labels, so this
                # assumes labels use the support-list order 0..N-1 (TODO confirm).
                predicted = torch.argmax(scores, dim=1).tolist()
                actual = query_labels[group]
                correct += sum(1 for i, pred in enumerate(predicted) if pred == actual[i])
                total += len(predicted)

            if total == 0:
                raise ValueError(f"evaluate3: episode {episode} produced no query predictions")
            episode_accuracies.append(correct / total)

    if not episode_accuracies:
        raise ValueError("evaluate3: clss_12 contains no episodes to evaluate")

    mean_accuracy = sum(episode_accuracies) / len(episode_accuracies)
    print("---------------------->", mean_accuracy)
    return mean_accuracy

In [92]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Bayesian search (scikit-optimize) over the three blending ratios that the merge
# cell below passes to update_model_weights2: b_a_r -> bias_ratio,
# c_a_r -> conv_ratio, f_a_r -> fc_ratio.
# NOTE(review): get_model_accuracy is defined in another cell (not visible here).
# Judging from the output, it prints "Testing ..." and then evaluate3's "---->" line,
# so it presumably evaluates with whichever evaluate3 was defined last. That is
# hidden state (here the 1-shot variant), so confirm before running cells out of order.

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define the search space for b_a_r, c_a_r, and f_a_r
# (each ratio is searched over the interval [0, 1])
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Define the objective function for Bayesian Optimization.
# use_named_args maps the optimiser's point (a list ordered like search_space)
# onto the keyword arguments below.
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Objective for gp_minimize: score one candidate set of blending ratios.

    Args:
        b_a_r, c_a_r, f_a_r: ratios in [0, 1] proposed by the optimiser
            (bias, conv and fc ratios respectively, as used by the merge cell).

    Returns:
        The negated accuracy reported by get_model_accuracy, because
        gp_minimize minimises its objective.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # Debugging Output
    
    return -accuracy  # We minimize, so return negative accuracy

# Perform Bayesian Optimization with additional settings
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Total objective evaluations, including the initial random ones
    n_initial_points=5,       # Random points evaluated before the GP surrogate drives the search
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Seeds the optimiser's proposals only, not any sampling done inside the objective
    n_jobs=-1,                # Parallelises only the L-BFGS acquisition optimisation; objective calls still run one at a time
)

# Extract optimal results: res.x is the best point in search_space order (b, c, f);
# res.fun is the lowest objective value, i.e. the negated best accuracy.
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.5992440225035162
Evaluated Accuracy: 0.5992
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.5094936708860759
Evaluated Accuracy: 0.5095
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.5636427566807314
Evaluated Accuracy: 0.5636
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.67457805907173
Evaluated Accuracy: 0.6746
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.6393284106891703
Evaluated Accuracy: 0.6393
Testing b_a_r = 0.1324, c_a_r = 0.0000, f_a_r = 0.9996




----------------------> 0.6295710267229254
Evaluated Accuracy: 0.6296
Testing b_a_r = 1.0000, c_a_r = 0.0355, f_a_r = 0.3122




----------------------> 0.6195499296765119
Evaluated Accuracy: 0.6195
Testing b_a_r = 1.0000, c_a_r = 0.0010, f_a_r = 0.6808




----------------------> 0.6452180028129396
Evaluated Accuracy: 0.6452
Testing b_a_r = 1.0000, c_a_r = 0.0793, f_a_r = 0.9945




----------------------> 0.6819620253164556
Evaluated Accuracy: 0.6820
Testing b_a_r = 0.0467, c_a_r = 1.0000, f_a_r = 0.9647




----------------------> 0.27786568213783397
Evaluated Accuracy: 0.2779
Testing b_a_r = 0.4083, c_a_r = 0.0007, f_a_r = 0.0883




----------------------> 0.583245428973277
Evaluated Accuracy: 0.5832
Testing b_a_r = 1.0000, c_a_r = 0.6415, f_a_r = 1.0000




----------------------> 0.312851617440225
Evaluated Accuracy: 0.3129
Testing b_a_r = 0.9504, c_a_r = 0.3625, f_a_r = 1.0000




----------------------> 0.518635724331927
Evaluated Accuracy: 0.5186
Testing b_a_r = 0.9768, c_a_r = 0.1764, f_a_r = 0.0000




----------------------> 0.6177039381153305
Evaluated Accuracy: 0.6177
Testing b_a_r = 0.1505, c_a_r = 0.8105, f_a_r = 0.0000




----------------------> 0.2782172995780591
Evaluated Accuracy: 0.2782
Testing b_a_r = 0.0000, c_a_r = 0.0924, f_a_r = 1.0000




----------------------> 0.6815225035161744
Evaluated Accuracy: 0.6815
Testing b_a_r = 0.7393, c_a_r = 0.0789, f_a_r = 1.0000




----------------------> 0.6794127988748242
Evaluated Accuracy: 0.6794
Testing b_a_r = 0.9509, c_a_r = 0.0910, f_a_r = 0.8033




----------------------> 0.6915436005625879
Evaluated Accuracy: 0.6915
Testing b_a_r = 0.0000, c_a_r = 0.1123, f_a_r = 0.5552




----------------------> 0.6996308016877637
Evaluated Accuracy: 0.6996
Testing b_a_r = 0.0000, c_a_r = 0.1139, f_a_r = 0.0000




----------------------> 0.6392405063291139
Evaluated Accuracy: 0.6392
Testing b_a_r = 1.0000, c_a_r = 0.1195, f_a_r = 0.4950




----------------------> 0.6889064697609002
Evaluated Accuracy: 0.6889
Testing b_a_r = 0.7370, c_a_r = 0.5947, f_a_r = 0.0099




----------------------> 0.337289029535865
Evaluated Accuracy: 0.3373
Testing b_a_r = 0.7174, c_a_r = 0.4792, f_a_r = 1.0000




----------------------> 0.45648734177215183
Evaluated Accuracy: 0.4565
Testing b_a_r = 0.0308, c_a_r = 0.1181, f_a_r = 0.7094




----------------------> 0.6880274261603376
Evaluated Accuracy: 0.6880
Testing b_a_r = 0.0120, c_a_r = 0.0928, f_a_r = 0.5693




----------------------> 0.6911040787623065
Evaluated Accuracy: 0.6911
Testing b_a_r = 0.2818, c_a_r = 0.9983, f_a_r = 0.0041




----------------------> 0.26300984528832627
Evaluated Accuracy: 0.2630
Testing b_a_r = 0.9411, c_a_r = 0.1175, f_a_r = 0.9881




----------------------> 0.6779184247538678
Evaluated Accuracy: 0.6779
Testing b_a_r = 0.9475, c_a_r = 0.0984, f_a_r = 0.6941




----------------------> 0.6996308016877636
Evaluated Accuracy: 0.6996
Testing b_a_r = 0.5350, c_a_r = 0.8314, f_a_r = 0.9865




----------------------> 0.28182137834036575
Evaluated Accuracy: 0.2818
Testing b_a_r = 0.3006, c_a_r = 0.2560, f_a_r = 0.0052




----------------------> 0.570675105485232
Evaluated Accuracy: 0.5707
Testing b_a_r = 0.9724, c_a_r = 0.0929, f_a_r = 0.6507




----------------------> 0.6979606188466949
Evaluated Accuracy: 0.6980
Testing b_a_r = 0.9116, c_a_r = 0.1001, f_a_r = 0.5046




----------------------> 0.6930379746835443
Evaluated Accuracy: 0.6930
Testing b_a_r = 0.3103, c_a_r = 0.2681, f_a_r = 0.9936




----------------------> 0.49384669479606186
Evaluated Accuracy: 0.4938
Testing b_a_r = 0.9432, c_a_r = 0.1029, f_a_r = 0.7137




----------------------> 0.699367088607595
Evaluated Accuracy: 0.6994
Testing b_a_r = 0.0378, c_a_r = 0.1030, f_a_r = 0.7432




----------------------> 0.693301687763713
Evaluated Accuracy: 0.6933
Testing b_a_r = 0.9982, c_a_r = 0.1102, f_a_r = 0.6945




----------------------> 0.6964662447257384
Evaluated Accuracy: 0.6965
Testing b_a_r = 0.1510, c_a_r = 0.1378, f_a_r = 0.4190




----------------------> 0.6759845288326299
Evaluated Accuracy: 0.6760
Testing b_a_r = 0.9640, c_a_r = 0.1032, f_a_r = 0.4723




----------------------> 0.68820323488045
Evaluated Accuracy: 0.6882
Testing b_a_r = 0.9758, c_a_r = 0.1218, f_a_r = 0.6786




----------------------> 0.6912798874824192
Evaluated Accuracy: 0.6913
Testing b_a_r = 0.0336, c_a_r = 0.3795, f_a_r = 0.0007




----------------------> 0.5809599156118143
Evaluated Accuracy: 0.5810
Testing b_a_r = 0.0617, c_a_r = 0.0954, f_a_r = 0.7132




----------------------> 0.6926863572433194
Evaluated Accuracy: 0.6927
Testing b_a_r = 0.0041, c_a_r = 0.3990, f_a_r = 0.4708




----------------------> 0.56179676511955
Evaluated Accuracy: 0.5618
Testing b_a_r = 0.8466, c_a_r = 0.6991, f_a_r = 0.0316




----------------------> 0.2957102672292545
Evaluated Accuracy: 0.2957
Testing b_a_r = 0.9363, c_a_r = 0.0492, f_a_r = 0.9993




----------------------> 0.6690400843881856
Evaluated Accuracy: 0.6690
Testing b_a_r = 0.0499, c_a_r = 0.1044, f_a_r = 0.4868




----------------------> 0.6903129395218003
Evaluated Accuracy: 0.6903
Testing b_a_r = 0.0468, c_a_r = 0.1489, f_a_r = 0.9825




----------------------> 0.6362517580872011
Evaluated Accuracy: 0.6363
Testing b_a_r = 0.9766, c_a_r = 0.0837, f_a_r = 0.6581




----------------------> 0.6933895921237693
Evaluated Accuracy: 0.6934
Testing b_a_r = 0.9980, c_a_r = 0.0972, f_a_r = 0.7833




----------------------> 0.6953234880450071
Evaluated Accuracy: 0.6953
Testing b_a_r = 0.0270, c_a_r = 0.1925, f_a_r = 0.3023




----------------------> 0.6308016877637131
Evaluated Accuracy: 0.6308
Testing b_a_r = 0.9757, c_a_r = 0.1277, f_a_r = 0.2777




----------------------> 0.6567334739803093
Evaluated Accuracy: 0.6567

✅ Optimal Values:
   - b_a_r: 0.0000
   - c_a_r: 0.1123
   - f_a_r: 0.5552
📈 Highest Accuracy Achieved: 0.6996


In [93]:
# Evaluate the base model (M3), the secondary model (M2) and their weighted blend,
# using the ratios found by the Bayesian search above. The output directories
# base/secondary/amul and the '12_w_1_s_f' names suggest a 12-way 1-shot run.
# NOTE(review): relies on globals from other cells (ResNet18WithDropout, test_loader,
# criterion, update_model_weights2, evaluate2, optimal_*_a_r) and on the temp
# checkpoints 'model_own_path.pt' / 'model_mu_path.pt' written by the setup cell.
# evaluate2 is whichever definition ran last (presumably the 1-shot variant).

# Load Model M3 (base checkpoint). The backbone's `fc` is replaced by nn.Flatten,
# presumably so the backbone emits feature vectors for the prototypes.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
# Metrics for the unmodified base model.
evaluate2('/home/asufian/Desktop/output_olchiki/resnet_resluts/base/12_w_1_s_f', test_loader, M3, criterion)
# Load Model M2 (secondary checkpoint) with the same architecture.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
# Metrics for the secondary model on its own.
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/secondary/12_w_1_s_f", test_loader, M2, criterion)
# Update Model M3's weights using a weighted combination of Model M3 and Model M2.
# After this line M3 no longer holds the base weights, so re-run the whole cell
# (which reloads M3) rather than only the evaluation below.
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Metrics for the blended model M3, written under the 'amul' results directory.
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/amul/12_w_1_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.5839 ± 0.0553  0.6184 ± 0.0668  0.5839 ± 0.0553




---Accu-----pre----rec---------> 0.3717 ± 0.0468  0.3943 ± 0.0665  0.3717 ± 0.0468
---Accu-----pre----rec---------> 0.6996 ± 0.0520  0.7226 ± 0.0524  0.6996 ± 0.0520


In [94]:
# Prepare the two temp checkpoints read by the following evaluation cells for the
# 5-shot run: 'model_own_path.pt' (base model) and 'model_mu_path.pt' (secondary
# 5-shot model). NOTE(review): the source paths below are absolute and machine-specific.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



# Remove the temp checkpoints from the previous run. A failing `!rm` (missing file)
# only prints an error, and torch.save below overwrites both files anyway.
!rm model_mu_path.pt
!rm model_own_path.pt

# Base model. map_location='cpu' loads the tensors onto the CPU; load_state_dict then
# copies them into the parameters, which were already moved to `device`.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


# Secondary model. NOTE(review): despite its name, `model_own_1_shot` holds the
# 5-shot checkpoint (model_5/model_B_olchiki_5-shot_res.pth); the name is kept in
# case other cells reference it.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/olchiki/ressecondary/model_5/model_B_olchiki_5-shot_res.pth',map_location=torch.device('cpu')))


# Persist both state dicts under the fixed names loaded by the merge/evaluation cell
# (and presumably by get_model_accuracy).
torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [95]:

def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss(), support_lists=None):
    """Evaluate ``model`` on the 5-shot test episodes and print mean ± std metrics.

    One episode is run for every entry of the global ``clss_12``. The episode's
    classes and the matching entry of ``support_lists`` are passed to
    ``data_loader.__getitem__`` to obtain ``(support_set, query_set)``. For each
    episode:
      - the query accuracy is computed;
      - a confusion matrix is built with ``confusion_matrix``;
      - the matrix is handed to ``calculate_metrics_get_per`` together with
        ``fname`` (as ``flnm``). That call returns precision, recall and a third
        value (presumably F1) that is not used here.

    Finally, one line with accuracy, precision and recall is printed as
    ``mean ± sample std`` (``ddof=1``). With a single episode the spread is
    therefore printed as ``nan``.

    Args:
        fname: Output name/prefix forwarded to ``calculate_metrics_get_per``.
        data_loader: Object exposing ``__getitem__(classes, support_ids)`` that
            returns ``(support_set, query_set)``, each a list of
            ``(list_of_image_tensors, labels)`` pairs.
        model: Callable as ``model(support_images, support_labels, query_images)``
            returning one score tensor per query group; must provide ``eval()``.
        criterion: Unused; kept for backward compatibility.
        support_lists: Per-episode support specifications indexed like
            ``clss_12``. Defaults to the global ``clss_support_imagesss_5_shot``
            (looked up at call time).

    Returns:
        None. Results are printed.

    Raises:
        ValueError: if an episode yields no query predictions (previously a
            bare ``ZeroDivisionError``).
    """
    if support_lists is None:
        support_lists = clss_support_imagesss_5_shot

    model.eval()
    accuracies = []
    precision = []
    recall = []
    with torch.no_grad():
        for episode, classes in enumerate(clss_12):
            support_ids = support_lists[episode]
            # Explicit __getitem__ call: the loader takes two positional arguments.
            support_set, query_set = data_loader.__getitem__(classes, support_ids)

            # Class names with commas/quotes blanked out, and the support spec as
            # text; both are forwarded to the metrics writer.
            clss = str(classes).replace(',', ' ').replace('\'', ' ')
            msg = str(support_ids)

            # Stack each group's images on the evaluation device; only the first
            # entry of each support `label` is kept as the class label.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            scores_per_group = model(support_images, support_labels, query_images)

            correct = 0
            all_predicted_labels = []
            all_actual_labels = []
            for group, scores in enumerate(scores_per_group):
                # Predicted prototype index, compared directly with the query label:
                # assumes labels follow the support-list order 0..N-1 (TODO confirm).
                predicted = torch.argmax(scores, dim=1).tolist()
                actual = query_labels[group]
                all_predicted_labels.extend(predicted)
                all_actual_labels.extend(actual)
                correct += sum(1 for i, pred in enumerate(predicted) if pred == actual[i])

            total = len(all_predicted_labels)
            if total == 0:
                raise ValueError(f"evaluate2: episode {episode} produced no query predictions")
            accuracy = correct / total
            accuracies.append(accuracy)

            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            precision_overall, recall_overall, _ = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=clss, flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)

    print(f"---Accu-----pre----rec---------> {np.mean(accuracies):.4f} ± {np.std(accuracies, ddof=1):.4f}  "
          f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
          f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")
# %%%%%%

def evaluate3(fname, data_loader, model, criterion=nn.CrossEntropyLoss(), support_lists=None):
    """Return the mean per-episode query accuracy of ``model`` on the 5-shot episodes.

    One episode is evaluated for every entry of the global ``clss_12``. The
    episode's classes and the matching entry of ``support_lists`` are passed to
    ``data_loader.__getitem__`` to obtain ``(support_set, query_set)``. Unlike
    ``evaluate2``, no metrics are computed or written, which keeps this cheap
    enough to use as an optimisation objective.

    Args:
        fname: Unused; accepted so the signature matches ``evaluate2``.
        data_loader: Object exposing ``__getitem__(classes, support_ids)`` that
            returns ``(support_set, query_set)``, each a list of
            ``(list_of_image_tensors, labels)`` pairs.
        model: Callable as ``model(support_images, support_labels, query_images)``
            returning one score tensor per query group (rows = query images,
            columns = prototypes); must also provide ``eval()``.
        criterion: Unused; kept for backward compatibility.
        support_lists: Per-episode support specifications, indexed like
            ``clss_12``. Defaults to the global ``clss_support_imagesss_5_shot``
            (looked up at call time, so later reassignment is honoured).

    Returns:
        float: mean of the per-episode accuracies (also printed).

    Raises:
        ValueError: if ``clss_12`` is empty or an episode yields no query
            predictions (previously a bare ``ZeroDivisionError``).
    """
    if support_lists is None:
        support_lists = clss_support_imagesss_5_shot

    model.eval()
    episode_accuracies = []
    with torch.no_grad():
        for episode, classes in enumerate(clss_12):
            # Explicit __getitem__ call: the loader takes two positional arguments,
            # which subscript syntax cannot pass.
            support_set, query_set = data_loader.__getitem__(classes, support_lists[episode])

            # Stack each group's images into one tensor on the evaluation device.
            # Only the first entry of each support `label` is kept as the class label
            # (presumably all entries agree — confirm against the data loader).
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            scores_per_group = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            for group, scores in enumerate(scores_per_group):
                # argmax over prototypes yields the index of the predicted support
                # class; it is compared directly with the query labels, so this
                # assumes labels use the support-list order 0..N-1 (TODO confirm).
                predicted = torch.argmax(scores, dim=1).tolist()
                actual = query_labels[group]
                correct += sum(1 for i, pred in enumerate(predicted) if pred == actual[i])
                total += len(predicted)

            if total == 0:
                raise ValueError(f"evaluate3: episode {episode} produced no query predictions")
            episode_accuracies.append(correct / total)

    if not episode_accuracies:
        raise ValueError("evaluate3: clss_12 contains no episodes to evaluate")

    mean_accuracy = sum(episode_accuracies) / len(episode_accuracies)
    print("---------------------->", mean_accuracy)
    return mean_accuracy

In [96]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Bayesian search (scikit-optimize) over the three blending ratios that the merge
# cell below passes to update_model_weights2: b_a_r -> bias_ratio,
# c_a_r -> conv_ratio, f_a_r -> fc_ratio.
# NOTE(review): this cell is a verbatim copy of the earlier 1-shot search. It
# presumably targets the 5-shot setup only because evaluate3 was just redefined with
# clss_support_imagesss_5_shot (hidden state). get_model_accuracy is defined in
# another cell; confirm it calls evaluate3 before running cells out of order.

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define the search space for b_a_r, c_a_r, and f_a_r
# (each ratio is searched over the interval [0, 1])
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Define the objective function for Bayesian Optimization.
# use_named_args maps the optimiser's point (a list ordered like search_space)
# onto the keyword arguments below.
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Objective for gp_minimize: score one candidate set of blending ratios.

    Args:
        b_a_r, c_a_r, f_a_r: ratios in [0, 1] proposed by the optimiser
            (bias, conv and fc ratios respectively, as used by the merge cell).

    Returns:
        The negated accuracy reported by get_model_accuracy, because
        gp_minimize minimises its objective.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # Debugging Output
    
    return -accuracy  # We minimize, so return negative accuracy

# Perform Bayesian Optimization with additional settings
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Total objective evaluations, including the initial random ones
    n_initial_points=5,       # Random points evaluated before the GP surrogate drives the search
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Seeds the optimiser's proposals only, not any sampling done inside the objective
    n_jobs=-1,                # Parallelises only the L-BFGS acquisition optimisation; objective calls still run one at a time
)

# Extract optimal results: res.x is the best point in search_space order (b, c, f);
# res.fun is the lowest objective value, i.e. the negated best accuracy.
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.7956481481481482
Evaluated Accuracy: 0.7956
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.5434259259259259
Evaluated Accuracy: 0.5434
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.6575925925925926
Evaluated Accuracy: 0.6576
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.7605555555555555
Evaluated Accuracy: 0.7606
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.7197222222222223
Evaluated Accuracy: 0.7197
Testing b_a_r = 0.0250, c_a_r = 1.0000, f_a_r = 0.4630




----------------------> 0.28259259259259256
Evaluated Accuracy: 0.2826
Testing b_a_r = 0.0000, c_a_r = 0.1462, f_a_r = 0.0000




----------------------> 0.7471296296296295
Evaluated Accuracy: 0.7471
Testing b_a_r = 0.0000, c_a_r = 0.1595, f_a_r = 1.0000




----------------------> 0.8089814814814814
Evaluated Accuracy: 0.8090
Testing b_a_r = 0.1311, c_a_r = 0.2340, f_a_r = 1.0000




----------------------> 0.7446296296296296
Evaluated Accuracy: 0.7446
Testing b_a_r = 1.0000, c_a_r = 0.1238, f_a_r = 1.0000




----------------------> 0.8052777777777779
Evaluated Accuracy: 0.8053
Testing b_a_r = 0.9664, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.6887037037037037
Evaluated Accuracy: 0.6887
Testing b_a_r = 0.4127, c_a_r = 0.7104, f_a_r = 0.9920




----------------------> 0.3232407407407407
Evaluated Accuracy: 0.3232
Testing b_a_r = 1.0000, c_a_r = 0.1327, f_a_r = 1.0000




----------------------> 0.8030555555555555
Evaluated Accuracy: 0.8031
Testing b_a_r = 0.0798, c_a_r = 0.1426, f_a_r = 0.6975




----------------------> 0.8203703703703703
Evaluated Accuracy: 0.8204
Testing b_a_r = 0.0383, c_a_r = 0.1101, f_a_r = 0.9715




----------------------> 0.8037962962962965
Evaluated Accuracy: 0.8038
Testing b_a_r = 0.0073, c_a_r = 0.1441, f_a_r = 0.6410




----------------------> 0.8219444444444445
Evaluated Accuracy: 0.8219
Testing b_a_r = 0.2951, c_a_r = 0.4298, f_a_r = 0.9970




----------------------> 0.5656481481481482
Evaluated Accuracy: 0.5656
Testing b_a_r = 0.9469, c_a_r = 0.1303, f_a_r = 0.5488




----------------------> 0.8052777777777779
Evaluated Accuracy: 0.8053
Testing b_a_r = 0.0208, c_a_r = 0.1718, f_a_r = 0.4603




----------------------> 0.7846296296296296
Evaluated Accuracy: 0.7846
Testing b_a_r = 0.0407, c_a_r = 0.1389, f_a_r = 0.8191




----------------------> 0.8167592592592593
Evaluated Accuracy: 0.8168
Testing b_a_r = 0.5100, c_a_r = 0.7752, f_a_r = 0.0029




----------------------> 0.3948148148148148
Evaluated Accuracy: 0.3948
Testing b_a_r = 0.0430, c_a_r = 0.1286, f_a_r = 0.7701




----------------------> 0.8165740740740741
Evaluated Accuracy: 0.8166
Testing b_a_r = 0.0032, c_a_r = 0.0957, f_a_r = 0.3067




----------------------> 0.752685185185185
Evaluated Accuracy: 0.7527
Testing b_a_r = 0.6634, c_a_r = 0.2308, f_a_r = 0.0090




----------------------> 0.7152777777777777
Evaluated Accuracy: 0.7153
Testing b_a_r = 0.9913, c_a_r = 0.1390, f_a_r = 0.7865




----------------------> 0.8089814814814814
Evaluated Accuracy: 0.8090
Testing b_a_r = 0.2797, c_a_r = 0.9974, f_a_r = 0.9984




----------------------> 0.2742592592592593
Evaluated Accuracy: 0.2743
Testing b_a_r = 0.1134, c_a_r = 0.1521, f_a_r = 0.8497




----------------------> 0.8169444444444444
Evaluated Accuracy: 0.8169
Testing b_a_r = 0.0264, c_a_r = 0.1332, f_a_r = 0.8203




----------------------> 0.8148148148148148
Evaluated Accuracy: 0.8148
Testing b_a_r = 0.0174, c_a_r = 0.2952, f_a_r = 0.5726




----------------------> 0.6736111111111112
Evaluated Accuracy: 0.6736
Testing b_a_r = 0.5116, c_a_r = 0.0008, f_a_r = 0.4345




----------------------> 0.7131481481481482
Evaluated Accuracy: 0.7131
Testing b_a_r = 0.5087, c_a_r = 0.6094, f_a_r = 0.0010




----------------------> 0.43064814814814817
Evaluated Accuracy: 0.4306
Testing b_a_r = 0.9538, c_a_r = 0.1444, f_a_r = 0.7631




----------------------> 0.8118518518518517
Evaluated Accuracy: 0.8119
Testing b_a_r = 0.9189, c_a_r = 0.9881, f_a_r = 0.0074




----------------------> 0.3211111111111111
Evaluated Accuracy: 0.3211
Testing b_a_r = 0.0596, c_a_r = 0.1661, f_a_r = 0.9950




----------------------> 0.8058333333333333
Evaluated Accuracy: 0.8058
Testing b_a_r = 0.0251, c_a_r = 0.1107, f_a_r = 0.5879




----------------------> 0.8032407407407408
Evaluated Accuracy: 0.8032
Testing b_a_r = 0.3219, c_a_r = 0.5505, f_a_r = 0.6966




----------------------> 0.367037037037037
Evaluated Accuracy: 0.3670
Testing b_a_r = 0.5992, c_a_r = 0.3259, f_a_r = 0.9937




----------------------> 0.6490740740740741
Evaluated Accuracy: 0.6491
Testing b_a_r = 0.0281, c_a_r = 0.1556, f_a_r = 0.7083




----------------------> 0.8204629629629628
Evaluated Accuracy: 0.8205
Testing b_a_r = 0.9671, c_a_r = 0.0835, f_a_r = 0.0106




----------------------> 0.7337037037037039
Evaluated Accuracy: 0.7337
Testing b_a_r = 0.9963, c_a_r = 0.1602, f_a_r = 0.2086




----------------------> 0.7554629629629629
Evaluated Accuracy: 0.7555
Testing b_a_r = 0.9182, c_a_r = 0.0747, f_a_r = 0.9977




----------------------> 0.763888888888889
Evaluated Accuracy: 0.7639
Testing b_a_r = 0.9353, c_a_r = 0.8482, f_a_r = 0.6845




----------------------> 0.30518518518518517
Evaluated Accuracy: 0.3052
Testing b_a_r = 0.0299, c_a_r = 0.1675, f_a_r = 0.6392




----------------------> 0.8119444444444445
Evaluated Accuracy: 0.8119
Testing b_a_r = 0.9416, c_a_r = 0.3979, f_a_r = 0.5402




----------------------> 0.6849999999999999
Evaluated Accuracy: 0.6850
Testing b_a_r = 0.9542, c_a_r = 0.2351, f_a_r = 0.3896




----------------------> 0.7492592592592593
Evaluated Accuracy: 0.7493
Testing b_a_r = 0.0081, c_a_r = 0.2186, f_a_r = 0.7032




----------------------> 0.7675925925925925
Evaluated Accuracy: 0.7676
Testing b_a_r = 0.9917, c_a_r = 0.0572, f_a_r = 0.4008




----------------------> 0.7444444444444445
Evaluated Accuracy: 0.7444
Testing b_a_r = 0.9800, c_a_r = 0.0996, f_a_r = 0.7820




----------------------> 0.797222222222222
Evaluated Accuracy: 0.7972
Testing b_a_r = 0.0540, c_a_r = 0.8819, f_a_r = 0.0116




----------------------> 0.3552777777777778
Evaluated Accuracy: 0.3553
Testing b_a_r = 0.0199, c_a_r = 0.0477, f_a_r = 0.0729




----------------------> 0.7148148148148148
Evaluated Accuracy: 0.7148

✅ Optimal Values:
   - b_a_r: 0.0073
   - c_a_r: 0.1441
   - f_a_r: 0.6410
📈 Highest Accuracy Achieved: 0.8219


In [97]:
# Evaluate the base model (M3), the secondary model (M2) and their weighted blend,
# using the ratios found by the Bayesian search above. The output directories
# base/secondary/amul and the '12_w_5_s_f' names suggest a 12-way 5-shot run.
# NOTE(review): relies on globals from other cells (ResNet18WithDropout, test_loader,
# criterion, update_model_weights2, optimal_*_a_r) and on the temp checkpoints
# written by the setup cell. evaluate2 here is the 5-shot definition from the
# preceding cell.

# Load Model M3 (base checkpoint). The backbone's `fc` is replaced by nn.Flatten,
# presumably so the backbone emits feature vectors for the prototypes.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
# Metrics for the unmodified base model.
evaluate2('/home/asufian/Desktop/output_olchiki/resnet_resluts/base/12_w_5_s_f', test_loader, M3, criterion)
# Load Model M2 (secondary 5-shot checkpoint) with the same architecture.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
# Metrics for the secondary model on its own.
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/secondary/12_w_5_s_f", test_loader, M2, criterion)
# Update Model M3's weights using a weighted combination of Model M3 and Model M2.
# After this line M3 no longer holds the base weights, so re-run the whole cell
# (which reloads M3) rather than only the evaluation below.
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Metrics for the blended model M3, written under the 'amul' results directory.
evaluate2("/home/asufian/Desktop/output_olchiki/resnet_resluts/amul/12_w_5_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.6856 ± 0.0393  0.6960 ± 0.0462  0.6856 ± 0.0393




---Accu-----pre----rec---------> 0.5194 ± 0.0734  0.5270 ± 0.0659  0.5194 ± 0.0734
---Accu-----pre----rec---------> 0.8219 ± 0.0415  0.8282 ± 0.0432  0.8219 ± 0.0415


In [98]:
# Prepare the two temp checkpoints read by the following evaluation cells for the
# 10-shot run: 'model_own_path.pt' (base model) and 'model_mu_path.pt' (secondary
# 10-shot model). NOTE(review): the source paths below are absolute and machine-specific.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



# Remove the temp checkpoints from the previous run. A failing `!rm` (missing file)
# only prints an error, and torch.save below overwrites both files anyway.
!rm model_mu_path.pt
!rm model_own_path.pt

# Base model. map_location='cpu' loads the tensors onto the CPU; load_state_dict then
# copies them into the parameters, which were already moved to `device`.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


# Secondary model. NOTE(review): despite its name, `model_own_1_shot` holds the
# 10-shot checkpoint (model_10/model_B_olchiki_10-shot_res.pth); the name is kept in
# case other cells reference it.
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/olchiki/ressecondary/model_10/model_B_olchiki_10-shot_res.pth',map_location=torch.device('cpu')))


# Persist both state dicts under the fixed names loaded by the merge/evaluation cell
# (and presumably by get_model_accuracy).
torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [99]:

def evaluate2(fname, data_loader, model, criterion=nn.CrossEntropyLoss(), support_lists=None):
    """Evaluate ``model`` on the 10-shot test episodes and print mean ± std metrics.

    One episode is run for every entry of the global ``clss_12``. The episode's
    classes and the matching entry of ``support_lists`` are passed to
    ``data_loader.__getitem__`` to obtain ``(support_set, query_set)``. For each
    episode:
      - the query accuracy is computed;
      - a confusion matrix is built with ``confusion_matrix``;
      - the matrix is handed to ``calculate_metrics_get_per`` together with
        ``fname`` (as ``flnm``). That call returns precision, recall and a third
        value (presumably F1) that is not used here.

    Finally, one line with accuracy, precision and recall is printed as
    ``mean ± sample std`` (``ddof=1``). With a single episode the spread is
    therefore printed as ``nan``.

    Args:
        fname: Output name/prefix forwarded to ``calculate_metrics_get_per``.
        data_loader: Object exposing ``__getitem__(classes, support_ids)`` that
            returns ``(support_set, query_set)``, each a list of
            ``(list_of_image_tensors, labels)`` pairs.
        model: Callable as ``model(support_images, support_labels, query_images)``
            returning one score tensor per query group; must provide ``eval()``.
        criterion: Unused; kept for backward compatibility.
        support_lists: Per-episode support specifications indexed like
            ``clss_12``. Defaults to the global ``clss_support_imagesss_10_shot``
            (looked up at call time).

    Returns:
        None. Results are printed.

    Raises:
        ValueError: if an episode yields no query predictions (previously a
            bare ``ZeroDivisionError``).
    """
    if support_lists is None:
        support_lists = clss_support_imagesss_10_shot

    model.eval()
    accuracies = []
    precision = []
    recall = []
    with torch.no_grad():
        for episode, classes in enumerate(clss_12):
            support_ids = support_lists[episode]
            # Explicit __getitem__ call: the loader takes two positional arguments.
            support_set, query_set = data_loader.__getitem__(classes, support_ids)

            # Class names with commas/quotes blanked out, and the support spec as
            # text; both are forwarded to the metrics writer.
            clss = str(classes).replace(',', ' ').replace('\'', ' ')
            msg = str(support_ids)

            # Stack each group's images on the evaluation device; only the first
            # entry of each support `label` is kept as the class label.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            scores_per_group = model(support_images, support_labels, query_images)

            correct = 0
            all_predicted_labels = []
            all_actual_labels = []
            for group, scores in enumerate(scores_per_group):
                # Predicted prototype index, compared directly with the query label:
                # assumes labels follow the support-list order 0..N-1 (TODO confirm).
                predicted = torch.argmax(scores, dim=1).tolist()
                actual = query_labels[group]
                all_predicted_labels.extend(predicted)
                all_actual_labels.extend(actual)
                correct += sum(1 for i, pred in enumerate(predicted) if pred == actual[i])

            total = len(all_predicted_labels)
            if total == 0:
                raise ValueError(f"evaluate2: episode {episode} produced no query predictions")
            accuracy = correct / total
            accuracies.append(accuracy)

            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            precision_overall, recall_overall, _ = calculate_metrics_get_per(
                msg, confusion_mat, acciuracy=accuracy, cls=clss, flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)

    print(f"---Accu-----pre----rec---------> {np.mean(accuracies):.4f} ± {np.std(accuracies, ddof=1):.4f}  "
          f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
          f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")
# %%%%%%

def evaluate3(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    """Return the mean per-episode query accuracy of ``model`` over the fixed episode lists.

    One episode is run for every entry of the module-level list ``clss_12``. For
    episode ``i`` the support/query sets are fetched with
    ``data_loader.__getitem__(clss_12[i], clss_support_imagesss_10_shot[i])``.
    The mean accuracy is printed and returned.

    Args:
        fname: Unused; accepted so callers can pass the same arguments as to
            ``evaluate2``.
        data_loader: Object whose ``__getitem__(classes, support_spec)`` returns
            ``(support_set, query_set)``. Each is a sequence of
            ``(list_of_image_tensors, labels)`` pairs.
        model: Callable as ``model(support_images, support_labels, query_images)``.
            It must return one score tensor per query group; ``argmax(dim=1)`` of
            each tensor is taken as the predicted labels for that group.
        criterion: Unused; kept for call compatibility.

    Returns:
        float: Mean over episodes of ``correct / total`` query predictions.
        Episodes that produced no predictions are skipped. Returns 0.0 when no
        episode could be scored (e.g. ``clss_12`` is empty), instead of raising
        ZeroDivisionError.
    """
    model.eval()
    episode_accuracies = []
    with torch.no_grad():
        for episode, classes in enumerate(clss_12):
            support_spec = clss_support_imagesss_10_shot[episode]
            support_set, query_set = data_loader.__getitem__(classes, support_spec)

            # Stack each group's images into one batch on the target device.
            # Support entries keep only their first label. Query entries keep the
            # full label sequence, compared element-wise with the predictions below.
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]
            support_images, support_labels = zip(*support_set)
            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)

            correct = 0
            total = 0
            for group, scores in enumerate(classification_scores):
                predicted = torch.argmax(scores, dim=1).tolist()
                actual = query_labels[group]
                correct += sum(1 for i, pred in enumerate(predicted) if pred == actual[i])
                total += len(predicted)

            if total == 0:
                # Nothing was predicted in this episode: skip it rather than
                # dividing by zero.
                continue
            episode_accuracies.append(correct / total)

    # Guard against an empty episode list (or episodes that were all empty).
    mean_accuracy = sum(episode_accuracies) / len(episode_accuracies) if episode_accuracies else 0.0
    print("---------------------->", mean_accuracy)
    return mean_accuracy

In [100]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Bayesian-search space: one continuous dimension in [0, 1] for each adaptation
# ratio (bias, conv, fc), listed in the order b_a_r, c_a_r, f_a_r.
search_space = [Real(0.0, 1.0, name=ratio_name) for ratio_name in ("b_a_r", "c_a_r", "f_a_r")]

# Objective for gp_minimize: receives the three ratios by name from search_space.
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Score one candidate set of adaptation ratios.

    The ratios are forwarded to get_model_accuracy. The resulting accuracy is
    printed, then returned negated, because gp_minimize looks for a minimum.
    """
    evaluated = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {evaluated:.4f}")  # Debugging Output
    negated = -evaluated  # minimizing -accuracy is maximizing accuracy
    return negated

# Bayesian optimization: gp_minimize fits a Gaussian-process surrogate to the
# objective values seen so far and picks the next ratios via Expected Improvement.
# `search_iteration` is the total number of objective evaluations.
search_iteration = 50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,  # total objective evaluations, including the initial ones
    n_initial_points=5,        # random evaluations before the GP surrogate is used
    acq_func="EI",             # acquisition function: Expected Improvement
    random_state=42,           # seeds the optimizer's own random sampling
    n_jobs=-1,                 # cores for the L-BFGS search over the acquisition
                               # function only; objective calls still run one at a time
)

# res.x follows the order of search_space. objective returned -accuracy, so
# negate res.fun to get the best accuracy back.
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print("\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.8166666666666668
Evaluated Accuracy: 0.8167
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.5342261904761904
Evaluated Accuracy: 0.5342
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.7042658730158732
Evaluated Accuracy: 0.7043
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.8261904761904763
Evaluated Accuracy: 0.8262
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.7756944444444444
Evaluated Accuracy: 0.7757
Testing b_a_r = 0.0250, c_a_r = 1.0000, f_a_r = 0.4630




----------------------> 0.2621031746031746
Evaluated Accuracy: 0.2621
Testing b_a_r = 0.0000, c_a_r = 0.1188, f_a_r = 0.0000




----------------------> 0.780654761904762
Evaluated Accuracy: 0.7807
Testing b_a_r = 1.0000, c_a_r = 0.1373, f_a_r = 1.0000




----------------------> 0.817361111111111
Evaluated Accuracy: 0.8174
Testing b_a_r = 0.0000, c_a_r = 0.1446, f_a_r = 1.0000




----------------------> 0.8189484126984127
Evaluated Accuracy: 0.8189
Testing b_a_r = 0.0000, c_a_r = 0.1128, f_a_r = 0.6408




----------------------> 0.8281746031746035
Evaluated Accuracy: 0.8282
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.1375




----------------------> 0.7182539682539683
Evaluated Accuracy: 0.7183
Testing b_a_r = 1.0000, c_a_r = 0.1124, f_a_r = 0.7035




----------------------> 0.8284722222222222
Evaluated Accuracy: 0.8285
Testing b_a_r = 0.0000, c_a_r = 0.2914, f_a_r = 1.0000




----------------------> 0.7295634920634919
Evaluated Accuracy: 0.7296
Testing b_a_r = 0.0000, c_a_r = 0.0859, f_a_r = 0.8561




----------------------> 0.8365079365079365
Evaluated Accuracy: 0.8365
Testing b_a_r = 0.0000, c_a_r = 0.2214, f_a_r = 0.0000




----------------------> 0.7871031746031746
Evaluated Accuracy: 0.7871
Testing b_a_r = 0.8669, c_a_r = 0.7217, f_a_r = 0.9974




----------------------> 0.31884920634920627
Evaluated Accuracy: 0.3188
Testing b_a_r = 0.4694, c_a_r = 0.0633, f_a_r = 0.9922




----------------------> 0.8297619047619048
Evaluated Accuracy: 0.8298
Testing b_a_r = 0.9193, c_a_r = 0.7861, f_a_r = 0.0011




----------------------> 0.3343253968253968
Evaluated Accuracy: 0.3343
Testing b_a_r = 0.0034, c_a_r = 0.2257, f_a_r = 0.4828




----------------------> 0.828968253968254
Evaluated Accuracy: 0.8290
Testing b_a_r = 0.1543, c_a_r = 0.1812, f_a_r = 0.3899




----------------------> 0.8103174603174602
Evaluated Accuracy: 0.8103
Testing b_a_r = 0.2665, c_a_r = 0.2192, f_a_r = 0.9863




----------------------> 0.8088293650793652
Evaluated Accuracy: 0.8088
Testing b_a_r = 0.4320, c_a_r = 0.6142, f_a_r = 0.0023




----------------------> 0.5030753968253968
Evaluated Accuracy: 0.5031
Testing b_a_r = 0.9943, c_a_r = 0.2577, f_a_r = 0.4408




----------------------> 0.8026785714285714
Evaluated Accuracy: 0.8027
Testing b_a_r = 0.9995, c_a_r = 0.0810, f_a_r = 0.4852




----------------------> 0.8156746031746032
Evaluated Accuracy: 0.8157
Testing b_a_r = 0.3251, c_a_r = 0.5161, f_a_r = 0.9850




----------------------> 0.45049603174603176
Evaluated Accuracy: 0.4505
Testing b_a_r = 0.0326, c_a_r = 0.0981, f_a_r = 0.9665




----------------------> 0.8307539682539682
Evaluated Accuracy: 0.8308
Testing b_a_r = 0.0452, c_a_r = 0.2372, f_a_r = 0.7590




----------------------> 0.8065476190476192
Evaluated Accuracy: 0.8065
Testing b_a_r = 0.9631, c_a_r = 0.0722, f_a_r = 0.9902




----------------------> 0.8296626984126984
Evaluated Accuracy: 0.8297
Testing b_a_r = 0.0042, c_a_r = 0.0874, f_a_r = 0.7966




----------------------> 0.8378968253968254
Evaluated Accuracy: 0.8379
Testing b_a_r = 0.5887, c_a_r = 0.0671, f_a_r = 0.0070




----------------------> 0.7702380952380953
Evaluated Accuracy: 0.7702
Testing b_a_r = 0.0760, c_a_r = 0.2758, f_a_r = 0.0097




----------------------> 0.7671626984126986
Evaluated Accuracy: 0.7672
Testing b_a_r = 0.5288, c_a_r = 0.8979, f_a_r = 0.9995




----------------------> 0.255952380952381
Evaluated Accuracy: 0.2560
Testing b_a_r = 0.3398, c_a_r = 0.3911, f_a_r = 0.9903




----------------------> 0.5190476190476191
Evaluated Accuracy: 0.5190
Testing b_a_r = 0.0388, c_a_r = 0.2229, f_a_r = 0.2807




----------------------> 0.8035714285714285
Evaluated Accuracy: 0.8036
Testing b_a_r = 0.0904, c_a_r = 0.0796, f_a_r = 0.9921




----------------------> 0.8357142857142859
Evaluated Accuracy: 0.8357
Testing b_a_r = 0.0041, c_a_r = 0.0570, f_a_r = 0.9972




----------------------> 0.8294642857142857
Evaluated Accuracy: 0.8295
Testing b_a_r = 0.9498, c_a_r = 0.0865, f_a_r = 0.9638




----------------------> 0.830456349206349
Evaluated Accuracy: 0.8305
Testing b_a_r = 0.0246, c_a_r = 0.0810, f_a_r = 0.7562




----------------------> 0.8370039682539683
Evaluated Accuracy: 0.8370
Testing b_a_r = 0.0564, c_a_r = 0.0661, f_a_r = 0.7631




----------------------> 0.8328373015873015
Evaluated Accuracy: 0.8328
Testing b_a_r = 0.9591, c_a_r = 0.2010, f_a_r = 0.5793




----------------------> 0.8334325396825397
Evaluated Accuracy: 0.8334
Testing b_a_r = 0.9789, c_a_r = 0.0884, f_a_r = 1.0000




----------------------> 0.8300595238095237
Evaluated Accuracy: 0.8301
Testing b_a_r = 0.0102, c_a_r = 0.1449, f_a_r = 0.6177




----------------------> 0.8264880952380952
Evaluated Accuracy: 0.8265
Testing b_a_r = 0.0118, c_a_r = 0.1040, f_a_r = 0.7799




----------------------> 0.8318452380952382
Evaluated Accuracy: 0.8318
Testing b_a_r = 0.9760, c_a_r = 0.0950, f_a_r = 0.7554




----------------------> 0.8327380952380953
Evaluated Accuracy: 0.8327
Testing b_a_r = 0.0021, c_a_r = 0.0480, f_a_r = 0.7863




----------------------> 0.8231150793650793
Evaluated Accuracy: 0.8231
Testing b_a_r = 0.0538, c_a_r = 0.0908, f_a_r = 0.9885




----------------------> 0.8341269841269843
Evaluated Accuracy: 0.8341
Testing b_a_r = 0.0450, c_a_r = 0.0061, f_a_r = 0.5420




----------------------> 0.7648809523809524
Evaluated Accuracy: 0.7649
Testing b_a_r = 0.4008, c_a_r = 0.9991, f_a_r = 0.0006




----------------------> 0.2876984126984127
Evaluated Accuracy: 0.2877
Testing b_a_r = 0.8966, c_a_r = 0.1168, f_a_r = 0.2691




----------------------> 0.7933531746031747
Evaluated Accuracy: 0.7934
Testing b_a_r = 0.0414, c_a_r = 0.1929, f_a_r = 0.6115




----------------------> 0.825595238095238
Evaluated Accuracy: 0.8256

✅ Optimal Values:
   - b_a_r: 0.0042
   - c_a_r: 0.0874
   - f_a_r: 0.7966
📈 Highest Accuracy Achieved: 0.8379


In [101]:
# Rebuild both trained prototypical networks and score each on test_loader.
# Then update M3's weights as a weighted combination of M3 and M2, using the
# ratios found by gp_minimize, and score the updated M3. Each evaluate2 call
# gets its own results path.
results_root = '/home/asufian/Desktop/output_olchiki/resnet_resluts'


def _load_protonet(state_dict_path):
    """Return a PrototypicalNetworks_dynamic_query on ``device`` loaded from ``state_dict_path``.

    The backbone is ResNet18WithDropout(pretrained=False, dr=0.0) with its ``fc``
    replaced by ``nn.Flatten()``.
    """
    backbone = ResNet18WithDropout(pretrained=False, dr=0.0)
    backbone.fc = nn.Flatten()
    protonet = PrototypicalNetworks_dynamic_query(backbone).to(device)
    # map_location lets a checkpoint saved from GPU tensors load when `device` is the CPU.
    protonet.load_state_dict(torch.load(state_dict_path, map_location=device))
    return protonet


# Model M3 (model_own_path.pt): base score.
M3 = _load_protonet('model_own_path.pt')
xyz = M3.backbone
evaluate2(f'{results_root}/base/12_w_10_s_f', test_loader, M3, criterion)

# Model M2 (model_mu_path.pt): secondary score.
M2 = _load_protonet('model_mu_path.pt')
convolutional_network_with_dropout = xyz2 = M2.backbone
evaluate2(f'{results_root}/secondary/12_w_10_s_f', test_loader, M2, criterion)

# Combine M2 into M3 with the optimal ratios, then score the updated M3.
update_model_weights2(M3, M2, conv_ratio=optimal_c_a_r, fc_ratio=optimal_f_a_r, bias_ratio=optimal_b_a_r)
evaluate2(f'{results_root}/amul/12_w_10_s_f', test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.7173 ± 0.0372  0.7267 ± 0.0438  0.7173 ± 0.0372




---Accu-----pre----rec---------> 0.5036 ± 0.0423  0.5161 ± 0.0440  0.5036 ± 0.0423
---Accu-----pre----rec---------> 0.8379 ± 0.0539  0.8417 ± 0.0547  0.8379 ± 0.0539
