In [1]:
import torch
import torch.nn as nn
# !pip install easyfsl
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Omniglot
from torchvision.models import resnet18
from tqdm import tqdm

from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average

class PrototypicalNetworks_dynamic_query(nn.Module):
    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks_dynamic_query, self).__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images,
        support_labels,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.
        """
        num_classes=len(list(support_labels))
        prototypes = []
        for class_label in range(num_classes):
            # print(type(support_images[class_label]),torch.cat(support_images[class_label].shape)
            support_images_eachclass=(support_images[class_label])
            # print("kk", support_images_eachclass.shape)
            class_features = self.backbone(support_images_eachclass.to(device))  # Select features for the current class
            class_prototype = class_features.mean(dim=0)  # Compute the mean along the batch dimension
            prototypes.append(class_prototype)
        prototypes = torch.stack(prototypes)


        distances=[]
        for each_query_class in query_images:

          query_features = self.backbone(each_query_class)

          # Compute the distance between query features and prototypes
          distance = torch.cdist(query_features, prototypes)
          distances.append(-distance)

        # print(distances.shape)

        return distances


In [2]:
root_path = '/home/asufian/Desktop/output_olchiki/code/asamese'

In [3]:
import os
import random
from collections import defaultdict
from PIL import Image, ImageOps
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class PrototypicalOmniglotDataset(Dataset):
    def __init__(self, root, num_classes=1623, n_shot=5, n_query=10, transform=None):
        self.root = root
        self.num_classes = num_classes
        self.n_shot = n_shot
        self.n_query = n_query
        self.transform = transform
        self.samples_by_label = defaultdict(list)
        self.all_imgs = {}
        self.classes = []

        # Common image file extensions
        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp']

        # Organize data into classes and image paths
        for alphabet in os.listdir(self.root):
            alphabet_path = os.path.join(self.root, alphabet)
            if not os.path.isdir(alphabet_path):
                continue

            for char_class in os.listdir(alphabet_path):
                char_class_path = os.path.join(alphabet_path, char_class)

                # Skip if the entry is not a directory
                if not os.path.isdir(char_class_path):
                    continue

                all_images = []
                for img_name in sorted(os.listdir(char_class_path)):
                    img_path = os.path.join(char_class_path, img_name)

                    # Check if it's a file with a recognized image extension
                    if os.path.isfile(img_path) and any(img_name.lower().endswith(ext) for ext in image_extensions):
                        all_images.append(img_path)

                if all_images:
                    char_class_name = f"{alphabet}_{char_class}"
                    self.samples_by_label[char_class_name] = list(range(len(all_images)))
                    self.all_imgs[char_class_name] = all_images
                    self.classes.append(char_class_name)
                    # print(char_class_name)

    def transform_image(self, raw_img):
        img = raw_img#ImageOps.invert(raw_img)
        if self.transform is not None:
            # print(self.transform)
            img = self.transform(img)
        # print(img.shape)
        return img

    def __getitem__(self, selected_classes,selected_supports):#
        # selected_classes = random.sample(self.classes, self.num_classes)
        class_indices = [self.samples_by_label[each_cls] for each_cls in selected_classes]
        # print("selected_classes: ",selected_classes)
        support_set = []
        query_set = []
        label_id = 0
        qs=[]
        ko=0
        ghj=[]
        for idx_set in class_indices:
            # Creating support set
            selected_support =selected_supports[ko]# random.sample(idx_set, self.n_shot)
            ko=ko+1
            ghj.append(selected_support)
            # print("selected_support images: ", ghj)
            support_images = [self.transform_image(Image.open(self.all_imgs[selected_classes[label_id]][each]).convert('L')) for each in selected_support]

            # Creating query set
            selected_query = [item for item in idx_set if item not in selected_support]
            query_images = [self.transform_image(Image.open(self.all_imgs[selected_classes[label_id]][each]).convert('L')) for each in selected_query]
            support_set.append((support_images, [label_id for _ in range(self.n_shot)]))
            # print("aaaaaaaaaa",label_id)
            query_set.append((query_images,  [label_id for _ in range(len(selected_query))]))

            label_id += 1
            # qs= [item for item in a for _ in range(self.n_query)]
        return support_set, query_set

    def __len__(self):
        return 10

# Set up the transformations
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale
    transforms.ToTensor(),
])

# Set up the SiameseOmniglotDataset
# Training == evaluation
# root_path = './assamese'
num_classes = 5  # Set the desired number of classes per episode
n_shot = 5
n_query = 10

PrototypicalOmniglotDatasetLoader = PrototypicalOmniglotDataset(root=root_path, num_classes=num_classes, n_shot=n_shot, n_query=n_query, transform=transform)


In [4]:
import torch
import torch.nn as nn

class PrototypicalNetworks_dynamic_query(nn.Module):
    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks_dynamic_query, self).__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images,
        support_labels,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.
        """
        num_classes=len(list(support_labels))
        prototypes = []
        for class_label in range(num_classes):
            # print(type(support_images[class_label]),torch.cat(support_images[class_label].shape)
            support_images_eachclass=(support_images[class_label])
            # print("kk", support_images_eachclass.shape)
            class_features = self.backbone(support_images_eachclass.to(device))  # Select features for the current class
            class_prototype = class_features.mean(dim=0)  # Compute the mean along the batch dimension
            prototypes.append(class_prototype)
        prototypes = torch.stack(prototypes)


        distances=[]
        for each_query_class in query_images:

          query_features = self.backbone(each_query_class)

          # Compute the distance between query features and prototypes
          distance = torch.cdist(query_features, prototypes)
          distances.append(-distance)

        # print(distances.shape)

        return distances


In [5]:
class PrototypicalNetworks33(nn.Module):
    def __init__(self, backbone: nn.Module):
        super(PrototypicalNetworks33, self).__init__()
        self.backbone = backbone

    def forward(
        self,
        support_images,
        support_labels,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        num_classes = len(support_labels)
        prototypes = []
        for class_label in range(num_classes):
            support_images_eachclass = support_images[class_label]
            class_features = torch.stack([self.backbone(each.unsqueeze(0)) for each in support_images_eachclass])
            prototype_bymean = class_features.mean(dim=0)
            # print(prototype_bymean.shape)
            prototypes.append(prototype_bymean)

        query_features = torch.stack([self.backbone(each.unsqueeze(0)) for each in query_images])
        # print(query_features.shape,query_features[0].shape,torch.stack(prototypes).shape,torch.stack(prototypes)[0].shape)
        # distances = -torch.cdist(query_features, torch.stack(prototypes))
        distances = []  # Initialize list to store distances
        prototypes=torch.stack(prototypes)
        for each_query_feature in query_features:
            d = torch.cdist(prototypes, each_query_feature)  # Compute distances
            distances.append(d.squeeze())  # Remove singleton dimension and store
            # print(d.shape)  # Output: [5, 1], shape of each distance matrix

        distances = torch.stack(distances)

        return -distances


In [7]:
import pandas as pd
import os
import csv
def save_in_csv(cls,flnm,conf_matrix, acciuracy,precision_overall, recall_overall, f_betas):
  with open(flnm, 'w', newline='') as csvfile:
      writer = csv.writer(csvfile)
      headers=['Accuracy', 'Precision', 'Recall']
      beta=-5
      values=[]
      values.append(acciuracy)
      values.append(precision_overall)
      values.append(recall_overall)
      for each in f_betas:
              headers.append('F-betas (' +str(beta)+' )')
              beta=beta+1
              values.append(each)

    # Write confusion matrix
      writer.writerow(['used Class:'])
      writer.writerows([cls])
      writer.writerow([])
      writer.writerow([])
      writer.writerow(['Confusion Matrix'])
      writer.writerows(conf_matrix)
      writer.writerow([])
      writer.writerow([])  # Add an empty row for separation

      # Write metrics

      writer.writerow(headers)
      writer.writerow(values)


  # print("Metrics have been written to 'metrics.csv' file.")
def save_in_text_file(msg,cls, flnm, conf_matrix, accuracy, precision_overall, recall_overall, f_betas):
    with open(flnm, 'a') as textfile:
        textfile.write(str("\n\n\n                                 "+msg))
        textfile.write("\n\n\n")
        textfile.write("used Class:\n")
        textfile.write(str(cls) + "\n\n")

        textfile.write("Confusion Matrix:\n")
        for row in conf_matrix:
            textfile.write(' '.join([str(elem) for elem in row]) + "\n")
        textfile.write("\n")

        textfile.write("Metrics:\n")
        textfile.write("Accuracy: {}\n".format(accuracy))
        textfile.write("Precision: {}\n".format(precision_overall))
        textfile.write("Recall: {}\n".format(recall_overall))

        beta = -5
        for f_beta in f_betas:
            textfile.write("F-betas ({}): {}\n".format(beta, f_beta))
            beta += 1

    # print("Metrics have been written to '{}' file.".format(flnm))

def save_in_excel_final(msg, cls, flnm, conf_matrix, accuracy, precision_overall, recall_overall):
    data = {
        "Used Class": [cls],
        "Used support images idx": [msg],
        "Confusion Matrix": [conf_matrix],
        "Accuracy": [accuracy],
        "Precision": [precision_overall],
        "Recall": [recall_overall]
    }
    # for i, f_beta in enumerate(f_betas):
    #     data[f"F-beta ({i - 5})"] = [f_beta]

    df_new = pd.DataFrame(data)
    df_new = df_new.transpose()  # Transpose the DataFrame

    if os.path.exists(flnm):
        df_existing = pd.read_excel(flnm, index_col=0)  # Read existing file
        df_combined = pd.concat([df_existing, df_new], axis=1)  # Concatenate along columns
        df_combined.to_excel(flnm)  # Write to Excel
    else:
        df_new.to_excel(flnm)  # Write new DataFrame to Excel

    # print("Metrics have been written to '{}' file.".format(flnm))


In [8]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch
import torch.nn as nn
import torch.hub

# Define a modified version of DenseNet169 with dropout after each convolutional layer
class DenseNet169WithDropout(nn.Module):
    def __init__(self,modell, pretrained=False, dr=0.4):
        super(DenseNet169WithDropout, self).__init__()
        self.densenet169 = torch.hub.load('pytorch/vision:v0.10.0', modell, pretrained=pretrained)
        self.dropout = nn.Dropout(p=dr)  # Dropout with probability 0.5

        # Modify each dense block to include dropout after each convolutional layer
        for name, module in self.densenet169.features.named_children():
            if isinstance(module, nn.Sequential):
                for sub_name, sub_module in module.named_children():
                    if isinstance(sub_module, nn.Conv2d):
                        setattr(module, sub_name, nn.Sequential(sub_module, self.dropout))

    def forward(self, x):
        return self.densenet169(x)

# Instantiate the modified DenseNet169 model with dropout
# convolutional_network_with_dropout = DenseNet169WithDropout(modell='resnet18',pretrained=False,dr=0.4)

# Example usage:
# output = convolutional_network_with_dropout(input_tensor)

class ResNet18WithDropout(nn.Module):
    def __init__(self, pretrained=False, dr=0.4):
        super(ResNet18WithDropout, self).__init__()
        self.resnet18 = models.resnet18(pretrained=pretrained)
        self.dropout = nn.Dropout(p=dr)

        # Modify each residual block to include dropout after each convolutional layer
        for name, module in self.resnet18.named_children():
            if isinstance(module, nn.Sequential):
                for sub_name, sub_module in module.named_children():
                    if isinstance(sub_module, nn.Conv2d):
                        setattr(module, sub_name, nn.Sequential(sub_module, self.dropout))

    def forward(self, x):
        return self.resnet18(x)

# Instantiate the modified ResNet18 model with dropout
# convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.4)


In [48]:
clss_5=[
    ['assamese1_BXG', 'assamese1_NAA', 'assamese1_KHYA', 'assamese1_CBN', 'assamese1_DHA'],
      ['assamese1_DRA', 'assamese1_OI', 'assamese1_CA', 'assamese1_JHA', 'assamese1_MA'],
     ['assamese1_KHA', 'assamese1_CARI', 'assamese1_CA', 'assamese1_CAY', 'assamese1_TA'],
    ['assamese1_CA', 'assamese1_MNA', 'assamese1_TR', 'assamese1_CAY', 'assamese1_KHYA'],
    ['assamese1_NA', 'assamese1_KTA', 'assamese1_MDA', 'assamese1_MDHA', 'assamese1_MTA'],
     ['assamese1_OI', 'assamese1_MXA', 'assamese1_ATH', 'assamese1_DA', 'assamese1_BHA'],
     ['assamese1_PAC', 'assamese1_ANSR', 'assamese1_E', 'assamese1_TR', 'assamese1_BHA'],
     ['assamese1_CA', 'assamese1_CARI', 'assamese1_A', 'assamese1_AYA', 'assamese1_NIYA'],
    ['assamese1_MDHA', 'assamese1_BHA', 'assamese1_MTA', 'assamese1_EK', 'assamese1_GHA'],
    ['assamese1_ATH', 'assamese1_CAY', 'assamese1_BHA', 'assamese1_TXA', 'assamese1_MDHA'],
    ['assamese1_OU', 'assamese1_AE', 'assamese1_BXG', 'assamese1_JA', 'assamese1_MDHA'],
    ['assamese1_MXA', 'assamese1_ANSR', 'assamese1_CA', 'assamese1_MTA', 'assamese1_THA'],
     ['assamese1_HA', 'assamese1_KHA', 'assamese1_OI', 'assamese1_NG', 'assamese1_MNA'],
    ['assamese1_NG', 'assamese1_MDHA', 'assamese1_KA', 'assamese1_THA', 'assamese1_BA'],
    ['assamese1_UU', 'assamese1_EK', 'assamese1_SUNYA', 'assamese1_CA', 'assamese1_TR'],

     ]

clss_support_imagesss_1_shot=[
    [[35], [11], [7], [4], [12]],
    [[1], [15], [2], [23], [43]],
    [[20], [44], [5], [15], [37]],
    [[37], [33], [44], [6], [17]],
    [[27], [33], [9], [40], [40]],
    [[22], [6], [18], [37], [7]],
    [[9], [40], [34], [8], [1]],
    [[9], [5], [10], [2], [9]],
    [[17], [10], [27], [39], [6]],
    [[28], [38], [38], [26], [43]],
    [[35], [2], [12], [27], [38]],
    [[12], [11], [2], [3], [19]],
     [[39], [1], [2], [2], [25]],
    [[19], [0], [10], [19], [3]],
    [[37], [1], [1], [7], [9]]







]


clss_support_imagesss_5_shot=[[[35, 8, 42, 10, 39], [11, 28, 8, 10, 26], [7, 10, 6, 21, 36], [4, 9, 24, 11, 1], [12, 34, 2, 40, 18]], [[1, 8, 13, 24, 7], [15, 42, 34, 19, 13], [2, 8, 1, 19, 40], [23, 14, 17, 11, 28], [43, 25, 2, 13, 21]], [[20, 30, 9, 22, 13], [44, 21, 13, 20, 36], [5, 20, 15, 41, 19], [15, 40, 33, 25, 14], [37, 15, 27, 7, 41]], [[37, 31, 39, 42, 15], [33, 15, 34, 19, 3], [44, 0, 24, 25, 26], [6, 17, 8, 15, 9], [17, 13, 15, 5, 44]], [[27, 17, 4, 12, 42], [33, 0, 30, 31, 15], [9, 11, 39, 41, 34], [40, 43, 9, 32, 41], [40, 3, 18, 38, 4]], [[22, 8, 10, 25, 15], [6, 28, 5, 3, 31], [18, 32, 23, 38, 4], [37, 1, 24, 13, 30], [7, 11, 38, 32, 5]], [[9, 28, 6, 5, 18], [40, 29, 12, 5, 39], [34, 22, 25, 23, 8], [8, 32, 41, 2, 30], [1, 18, 41, 39, 40]], [[9, 24, 28, 5, 14], [5, 28, 36, 4, 27], [10, 24, 17, 12, 14], [2, 17, 20, 9, 7], [9, 29, 3, 28, 37]], [[17, 6, 31, 39, 36], [10, 34, 8, 37, 14], [27, 38, 14, 15, 28], [39, 5, 38, 4, 2], [6, 26, 27, 31, 43]], [[28, 12, 25, 38, 1], [38, 15, 37, 43, 29], [38, 29, 2, 4, 11], [26, 4, 27, 13, 14], [43, 20, 11, 38, 16]], [[35, 18, 1, 11, 17], [2, 18, 25, 1, 12], [12, 25, 18, 24, 40], [27, 16, 30, 25, 35], [38, 39, 9, 22, 24]], [[12, 2, 27, 4, 17], [11, 16, 36, 20, 8], [2, 28, 17, 32, 26], [3, 20, 28, 22, 4], [19, 1, 9, 0, 11]], [[39, 16, 19, 32, 34], [1, 9, 20, 0, 8], [2, 22, 20, 43, 21], [2, 13, 44, 21, 19], [25, 1, 41, 3, 33]], [[19, 40, 42, 43, 27], [0, 1, 30, 12, 40], [10, 43, 15, 2, 20], [19, 43, 38, 16, 33], [3, 30, 17, 1, 41]], [[37, 11, 25, 28, 35], [1, 3, 35, 4, 33], [1, 31, 25, 32, 13], [7, 3, 10, 11, 35], [9, 19, 1, 22, 8]]]

clss_support_imagesss_10_shot=[[[35, 8, 42, 10, 39, 30, 15, 11, 44, 7], [11, 28, 8, 10, 26, 16, 12, 7, 44, 37], [7, 10, 6, 21, 36, 24, 22, 38, 9, 39], [4, 9, 24, 11, 1, 32, 13, 28, 30, 17], [12, 34, 2, 40, 18, 38, 1, 14, 22, 36]], [[1, 8, 13, 24, 7, 12, 44, 29, 22, 23], [15, 42, 34, 19, 13, 30, 0, 23, 1, 9], [2, 8, 1, 19, 40, 44, 32, 26, 27, 9], [23, 14, 17, 11, 28, 16, 24, 35, 6, 34], [43, 25, 2, 13, 21, 17, 20, 38, 24, 35]], [[20, 30, 9, 22, 13, 36, 18, 21, 3, 37], [44, 21, 13, 20, 36, 33, 22, 0, 34, 12], [5, 20, 15, 41, 19, 10, 44, 12, 28, 26], [15, 40, 33, 25, 14, 44, 19, 32, 24, 6], [37, 15, 27, 7, 41, 9, 43, 33, 35, 23]], [[37, 31, 39, 42, 15, 11, 25, 1, 36, 21], [33, 15, 34, 19, 3, 28, 25, 29, 40, 37], [44, 0, 24, 25, 26, 6, 4, 19, 12, 14], [6, 17, 8, 15, 9, 25, 27, 29, 7, 0], [17, 13, 15, 5, 44, 33, 8, 6, 34, 31]], [[27, 17, 4, 12, 42, 11, 6, 13, 33, 3], [33, 0, 30, 31, 15, 34, 17, 14, 36, 20], [9, 11, 39, 41, 34, 10, 8, 15, 31, 4], [40, 43, 9, 32, 41, 14, 35, 10, 20, 23], [40, 3, 18, 38, 4, 41, 22, 29, 15, 27]], [[22, 8, 10, 25, 15, 14, 29, 2, 41, 24], [6, 28, 5, 3, 31, 2, 32, 40, 16, 13], [18, 32, 23, 38, 4, 2, 34, 22, 16, 36], [37, 1, 24, 13, 30, 5, 36, 27, 39, 35], [7, 11, 38, 32, 5, 15, 8, 41, 10, 2]], [[9, 28, 6, 5, 18, 44, 3, 8, 2, 34], [40, 29, 12, 5, 39, 18, 9, 43, 19, 28], [34, 22, 25, 23, 8, 18, 17, 24, 19, 6], [8, 32, 41, 2, 30, 5, 9, 17, 44, 1], [1, 18, 41, 39, 40, 7, 0, 25, 24, 15]], [[9, 24, 28, 5, 14, 16, 25, 22, 32, 37], [5, 28, 36, 4, 27, 18, 32, 2, 10, 31], [10, 24, 17, 12, 14, 41, 4, 32, 42, 37], [2, 17, 20, 9, 7, 31, 44, 38, 1, 22], [9, 29, 3, 28, 37, 17, 41, 13, 21, 44]], [[17, 6, 31, 39, 36, 8, 11, 24, 10, 7], [10, 34, 8, 37, 14, 16, 29, 44, 31, 33], [27, 38, 14, 15, 28, 23, 7, 42, 35, 41], [39, 5, 38, 4, 2, 42, 6, 22, 21, 9], [6, 26, 27, 31, 43, 41, 10, 12, 36, 32]], [[28, 12, 25, 38, 1, 41, 15, 42, 33, 20], [38, 15, 37, 43, 29, 44, 41, 39, 12, 21], [38, 29, 2, 4, 11, 43, 33, 24, 7, 21], [26, 4, 27, 13, 14, 16, 28, 0, 3, 36], [43, 20, 11, 38, 16, 14, 27, 13, 19, 0]], [[35, 18, 1, 11, 17, 19, 31, 16, 7, 15], [2, 18, 25, 1, 12, 16, 9, 23, 11, 44], [12, 25, 18, 24, 40, 35, 17, 44, 28, 22], [27, 16, 30, 25, 35, 18, 26, 42, 33, 40], [38, 39, 9, 22, 24, 18, 29, 21, 1, 7]], [[12, 2, 27, 4, 17, 42, 37, 3, 18, 10], [11, 16, 36, 20, 8, 4, 42, 28, 44, 7], [2, 28, 17, 32, 26, 3, 37, 16, 10, 14], [3, 20, 28, 22, 4, 43, 17, 9, 0, 11], [19, 1, 9, 0, 11, 7, 27, 15, 18, 6]], [[39, 16, 19, 32, 34, 31, 13, 4, 37, 6], [1, 9, 20, 0, 8, 29, 4, 27, 26, 13], [2, 22, 20, 43, 21, 4, 15, 13, 14, 10], [2, 13, 44, 21, 19, 5, 14, 23, 43, 26], [25, 1, 41, 3, 33, 16, 9, 37, 7, 30]], [[19, 40, 42, 43, 27, 30, 3, 10, 21, 28], [0, 1, 30, 12, 40, 33, 11, 22, 44, 27], [10, 43, 15, 2, 20, 32, 14, 30, 11, 42], [19, 43, 38, 16, 33, 17, 35, 14, 10, 5], [3, 30, 17, 1, 41, 29, 40, 43, 8, 31]], [[37, 11, 25, 28, 35, 13, 5, 9, 31, 6], [1, 3, 35, 4, 33, 8, 41, 26, 12, 37], [1, 31, 25, 32, 13, 14, 22, 9, 26, 19], [7, 3, 10, 11, 35, 13, 41, 26, 43, 20], [9, 19, 1, 22, 8, 17, 40, 3, 32, 14]]]


In [49]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")




convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/asamese_sce/ressecondary/model_1/model_B_olchiki_1-shot_res.pth',map_location=torch.device('cpu')))

torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [25]:
# !ls '/home/asufian/Desktop/output_olchiki/'

In [50]:
criterion = nn.CrossEntropyLoss()

In [51]:
root_path

'/home/asufian/Desktop/output_olchiki/code/asamese'

In [55]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
import numpy as np
import numpy as np
import pandas as pd
# root_path = '/content/assamese'
num_classes = 5
n_shot = 1
n_query = 14
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transformt_ugh = transforms.Compose([
    transforms.Resize((64, 64)),
    # transforms.RandomResizedCrop(64),  # Randomly crop the image and resize to 224x224
    # transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
    # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),  # Randomly adjust brightness, contrast, saturation, and hue
    # transforms.RandomAffine(degrees=0, translate=(0, 0), scale=(0.9, 1.1), shear=0),  # Random zoom (scaling)
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),  # Convert PIL Image to PyTorch Tensor
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # Normalize the image with ImageNet statistics
])

print(f"Generated random values - num_classes: {num_classes}, n_shot: {n_shot}, n_query: {n_query}")

test_loader = PrototypicalOmniglotDataset(root=root_path, num_classes=num_classes, n_shot=n_shot, n_query=20, transform=transformt_ugh)

def evaluate2(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlac=0
    precision=0
    recall=0
    
    with torch.no_grad():
        for episode in range(len(clss_5)):
            support_set, query_set = data_loader.__getitem__(clss_5[tlv_cls] , clss_support_imagesss_1_shot[tlv_cls])

            # msg="\n\n       Way: "+str(len(clss_5[tlv_cls]))+"  Shot: "+str(len(clss_support_imagesss_1_shot[tlv_cls][0]))    +" \n\n"
            # msg=str(clss_5[tlv_cls])
            clssa=str(clss_5[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_1_shot[tlv_cls])
            str(clss_5[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            precision_overall, recall_overall, f1ss= calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)
            precision=precision+precision_overall
            recall=recall+recall_overall
    print("---------------------->",ttlac/ttla,precision/ttla ,recall/ttla)
# %%%%%%%%%%%%%%%%%%%


pp='/home/asufian/Desktop/output_olchiki/faltu'






Generated random values - num_classes: 5, n_shot: 1, n_query: 14


In [56]:
import numpy as np

conf_matrix = np.array([[100, 0, 0, 0, 0],
                        [0, 100, 0, 0, 0],
                        [0, 0, 100, 0, 0],
                        [0, 0, 0, 100, 0],
                        [0, 0, 0, 0, 100]])

def calculate_metrics_get_per(msg, conf_matrix,acciuracy, cls, flnm):
    if conf_matrix.size == 0:
        raise ValueError("Confusion matrix is empty!")

    metrics = []
    for i in range(conf_matrix.shape[0]):
        tp = conf_matrix[i, i]
        fp = np.sum(conf_matrix[:, i]) - tp
        fn = np.sum(conf_matrix[i, :]) - tp

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        metrics.append((precision, recall, f1))

    # Compute overall scores
    precision_overall = np.mean([m[0] for m in metrics])
    recall_overall = np.mean([m[1] for m in metrics])
    f1_overall = np.mean([m[2] for m in metrics])

    # Compute overall accuracy
    total_samples = np.sum(conf_matrix)
    correct_predictions = np.sum(np.diag(conf_matrix))
    accuracy_overall = correct_predictions / total_samples if total_samples > 0 else 0
    # print(accuracy_overall)

    # Save results (Assuming these functions exist)
    save_in_excel_final(msg, cls, flnm + ".xlsx", conf_matrix, accuracy_overall, precision_overall, recall_overall)
    # save_in_text_file(msg, cls, flnm + ".txt", conf_matrix, accuracy_overall, precision_overall, recall_overall, f_betas)

    return precision_overall, recall_overall, f1_overall


# precision, recall, f_beta = calculate_metrics(conf_matrix)
# print(f'Overall Metrics for beta={beta}:')
# print(f'Overall Precision: {precision:.4f}')
# print(f'Overall Recall: {recall:.4f}')
# print(f'Overall F-beta ({f_beta}): ')
# print()


In [57]:


def evaluate2(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlal=[]
    ttlac=0
    precision=[]
    recall=[]
    
    with torch.no_grad():
        for episode in range(len(clss_5)):
            support_set, query_set = data_loader.__getitem__(clss_5[tlv_cls] , clss_support_imagesss_1_shot[tlv_cls])

           
            clssa=str(clss_5[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_1_shot[tlv_cls])
            str(clss_5[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            ttlal.append(accuracy)
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            precision_overall, recall_overall, f1ss= calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)
    print(f"---Accu-----pre----rec---------> {np.mean(ttlal):.4f} ± {np.std(ttlal, ddof=1):.4f}  "
      f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
      f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")

    # print("---Accu-----pre----rec---------->",np.mean(ttlal),'±',np.std(ttlal, ddof=1), np.mean(precision),'±',np.std(precision, ddof=1),np.mean(precision),'±',np.std(precision, ddof=1),)
# %%%%%%

def evaluate3(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlac=0
    with torch.no_grad():
        for episode in range(len(clss_5)):
            support_set, query_set = data_loader.__getitem__(clss_5[tlv_cls] , clss_support_imagesss_1_shot[tlv_cls])

            
            clssa=str(clss_5[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_1_shot[tlv_cls])
            str(clss_5[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            # confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            # calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)

    print("---------------------->",ttlac/ttla)
    return ttlac/ttla

In [58]:
def update_model_weights2(model_a, model_b, conv_ratio=0.8, fc_ratio=0.8, bias_ratio=0.8):
    """Updates model_a's weights using a weighted combination of model_a and model_b."""
    for name, param in model_a.named_parameters():
        if name in model_b.state_dict():
            weight_b = model_b.state_dict()[name]
            
            if 'weight' in name:
                # Separate conditions for convolutional and fully connected layers
                if 'conv' in name:  # Convolutional layer weights
                    # print('conv')
                    param.data = conv_ratio * weight_b + (1 - conv_ratio) * param.data
                elif 'fc' in name or 'linear' in name:  # Fully connected layer weights
                    # print('fc')
                    param.data = fc_ratio * weight_b + (1 - fc_ratio) * param.data
            
            elif 'bias' in name:
                param.data = bias_ratio * weight_b + (1 - bias_ratio) * param.data

In [59]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Function to get model accuracy
def get_model_accuracy(b_a_r, c_a_r, f_a_r):
    try:
        print(f"Testing b_a_r = {b_a_r:.4f}, c_a_r = {c_a_r:.4f}, f_a_r = {f_a_r:.4f}")  # Debugging Output

        convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
        convolutional_network_with_dropout.fc = nn.Flatten()
        xyz = convolutional_network_with_dropout
        M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
        M3.load_state_dict(torch.load('model_own_path.pt'))

        # Load Model M2 (Reference)
        convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
        convolutional_network_with_dropout.fc = nn.Flatten()
        xyz2 = convolutional_network_with_dropout
        M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
        M2.load_state_dict(torch.load('model_mu_path.pt'))

        # Apply bias adaptation to M3 using M2
        update_model_weights2(M3, M2, conv_ratio=c_a_r, fc_ratio=f_a_r, bias_ratio=b_a_r)
        
        # Evaluate accuracy
        accuracy = evaluate3(pp+'/_olchiki_5-wau-1-shot_matchine_unlearning_own', test_loader, M3, criterion)
        
        if accuracy is None:
            raise ValueError("Accuracy computation failed. Returning worst case.")

        return accuracy  # Ensure accuracy is valid
    except Exception as e:
        print(f"Error encountered: {e}. Assigning worst-case accuracy of 0.")
        return 0.0  # Assign lowest accuracy to prevent optimization failures


In [60]:
criterion = nn.CrossEntropyLoss()


In [62]:

# Define the search space for b_a_r, c_a_r, and f_a_r
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Define the objective function for Bayesian Optimization
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Objective function to optimize the adaptation ratios.
    It applies adaptation, evaluates the model, and returns the negative accuracy.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # Debugging Output
    
    return -accuracy  # We minimize, so return negative accuracy

# Perform Bayesian Optimization with additional settings
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Number of evaluations
    n_initial_points=5,       # Initial random explorations before GP starts
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Ensure reproducibility
    n_jobs=-1,                # Parallel execution for faster optimization
)

# Extract optimal results
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.5726996801211145
Evaluated Accuracy: 0.5727
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.5136184193637859
Evaluated Accuracy: 0.5136
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.4681952732557084
Evaluated Accuracy: 0.4682
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.6002879661300291
Evaluated Accuracy: 0.6003
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.5721419363159896
Evaluated Accuracy: 0.5721
Testing b_a_r = 0.0311, c_a_r = 0.0000, f_a_r = 0.5141




----------------------> 0.5709090846480048
Evaluated Accuracy: 0.5709
Testing b_a_r = 1.0000, c_a_r = 0.1083, f_a_r = 1.0000




----------------------> 0.5735935753743974
Evaluated Accuracy: 0.5736
Testing b_a_r = 0.3222, c_a_r = 0.1162, f_a_r = 0.6574




----------------------> 0.5969325311629904
Evaluated Accuracy: 0.5969
Testing b_a_r = 0.9612, c_a_r = 0.7433, f_a_r = 0.0000




----------------------> 0.41121751200719775
Evaluated Accuracy: 0.4112
Testing b_a_r = 0.3582, c_a_r = 1.0000, f_a_r = 0.8974




----------------------> 0.3675424799034791
Evaluated Accuracy: 0.3675
Testing b_a_r = 0.2385, c_a_r = 0.4951, f_a_r = 1.0000




----------------------> 0.4809573137937361
Evaluated Accuracy: 0.4810
Testing b_a_r = 0.2695, c_a_r = 0.0817, f_a_r = 0.0389




----------------------> 0.5460674556887288
Evaluated Accuracy: 0.5461
Testing b_a_r = 0.0157, c_a_r = 0.0000, f_a_r = 1.0000




----------------------> 0.579406386432978
Evaluated Accuracy: 0.5794
Testing b_a_r = 0.0855, c_a_r = 0.0846, f_a_r = 0.5308




----------------------> 0.5933237163051828
Evaluated Accuracy: 0.5933
Testing b_a_r = 0.0000, c_a_r = 0.0761, f_a_r = 0.7687




----------------------> 0.5921074815595364
Evaluated Accuracy: 0.5921
Testing b_a_r = 0.0315, c_a_r = 0.3039, f_a_r = 1.0000




----------------------> 0.39456701146467454
Evaluated Accuracy: 0.3946
Testing b_a_r = 0.9591, c_a_r = 0.1754, f_a_r = 0.1943




----------------------> 0.5873015407906623
Evaluated Accuracy: 0.5873
Testing b_a_r = 0.9842, c_a_r = 0.5646, f_a_r = 0.0000




----------------------> 0.45635671160248115
Evaluated Accuracy: 0.4564
Testing b_a_r = 0.9602, c_a_r = 0.1472, f_a_r = 0.4698




----------------------> 0.5978733094285069
Evaluated Accuracy: 0.5979
Testing b_a_r = 0.2317, c_a_r = 0.6717, f_a_r = 1.0000




----------------------> 0.4245370333847369
Evaluated Accuracy: 0.4245
Testing b_a_r = 0.1413, c_a_r = 0.2208, f_a_r = 0.0000




----------------------> 0.563654069987672
Evaluated Accuracy: 0.5637
Testing b_a_r = 0.9834, c_a_r = 0.9093, f_a_r = 0.0000




----------------------> 0.36453271949645843
Evaluated Accuracy: 0.3645
Testing b_a_r = 0.1299, c_a_r = 0.0460, f_a_r = 0.9771




----------------------> 0.5893760952987382
Evaluated Accuracy: 0.5894
Testing b_a_r = 0.9819, c_a_r = 0.8257, f_a_r = 0.9961




----------------------> 0.4090933822198931
Evaluated Accuracy: 0.4091
Testing b_a_r = 0.2821, c_a_r = 0.1492, f_a_r = 0.0063




----------------------> 0.5775950855483492
Evaluated Accuracy: 0.5776
Testing b_a_r = 0.3190, c_a_r = 0.0397, f_a_r = 0.6622




----------------------> 0.5987700346926776
Evaluated Accuracy: 0.5988
Testing b_a_r = 0.0412, c_a_r = 0.1506, f_a_r = 0.5810




----------------------> 0.6054338963928004
Evaluated Accuracy: 0.6054
Testing b_a_r = 0.0140, c_a_r = 0.1735, f_a_r = 0.4436




----------------------> 0.6081735973758537
Evaluated Accuracy: 0.6082
Testing b_a_r = 0.0903, c_a_r = 0.1470, f_a_r = 0.3623




----------------------> 0.5996956423467302
Evaluated Accuracy: 0.5997
Testing b_a_r = 0.9539, c_a_r = 0.0107, f_a_r = 0.0116




----------------------> 0.5315219510545861
Evaluated Accuracy: 0.5315
Testing b_a_r = 0.0800, c_a_r = 0.2042, f_a_r = 0.4701




----------------------> 0.6051597860461679
Evaluated Accuracy: 0.6052
Testing b_a_r = 0.9655, c_a_r = 0.0469, f_a_r = 0.4356




----------------------> 0.5793967130550531
Evaluated Accuracy: 0.5794
Testing b_a_r = 0.0503, c_a_r = 0.1889, f_a_r = 0.3103




----------------------> 0.6112148447845466
Evaluated Accuracy: 0.6112
Testing b_a_r = 0.9885, c_a_r = 0.2167, f_a_r = 0.3541




----------------------> 0.5660700728226916
Evaluated Accuracy: 0.5661
Testing b_a_r = 0.0006, c_a_r = 0.1421, f_a_r = 0.9657




----------------------> 0.5799504435040291
Evaluated Accuracy: 0.5800
Testing b_a_r = 0.0063, c_a_r = 0.1693, f_a_r = 0.5424




----------------------> 0.6018044387343341
Evaluated Accuracy: 0.6018
Testing b_a_r = 0.0572, c_a_r = 0.1869, f_a_r = 0.2818




----------------------> 0.6103043576532698
Evaluated Accuracy: 0.6103
Testing b_a_r = 0.9051, c_a_r = 0.5750, f_a_r = 0.9848




----------------------> 0.44275333136735234
Evaluated Accuracy: 0.4428
Testing b_a_r = 0.2397, c_a_r = 0.9987, f_a_r = 0.0089




----------------------> 0.33389913515618586
Evaluated Accuracy: 0.3339
Testing b_a_r = 0.9968, c_a_r = 0.1108, f_a_r = 0.4267




----------------------> 0.5766624967833671
Evaluated Accuracy: 0.5767
Testing b_a_r = 0.0072, c_a_r = 0.1677, f_a_r = 0.3281




----------------------> 0.6078815490177296
Evaluated Accuracy: 0.6079
Testing b_a_r = 0.0262, c_a_r = 0.1823, f_a_r = 0.3737




----------------------> 0.6124104617738784
Evaluated Accuracy: 0.6124
Testing b_a_r = 0.9410, c_a_r = 0.4151, f_a_r = 0.9493




----------------------> 0.4515297774121303
Evaluated Accuracy: 0.4515
Testing b_a_r = 0.0012, c_a_r = 0.6618, f_a_r = 0.0028




----------------------> 0.42877260173795223
Evaluated Accuracy: 0.4288
Testing b_a_r = 0.0239, c_a_r = 0.1805, f_a_r = 0.1755




----------------------> 0.6003098674089811
Evaluated Accuracy: 0.6003
Testing b_a_r = 0.0091, c_a_r = 0.1885, f_a_r = 0.5335




----------------------> 0.6060468629767581
Evaluated Accuracy: 0.6060
Testing b_a_r = 0.9867, c_a_r = 0.0233, f_a_r = 0.7621




----------------------> 0.5921226772153445
Evaluated Accuracy: 0.5921
Testing b_a_r = 0.0081, c_a_r = 0.4886, f_a_r = 0.4008




----------------------> 0.5166818310170446
Evaluated Accuracy: 0.5167
Testing b_a_r = 0.0033, c_a_r = 0.0363, f_a_r = 0.7031




----------------------> 0.5996763957682573
Evaluated Accuracy: 0.5997
Testing b_a_r = 0.0001, c_a_r = 0.2041, f_a_r = 0.2975




----------------------> 0.6112382675074054
Evaluated Accuracy: 0.6112

✅ Optimal Values:
   - b_a_r: 0.0262
   - c_a_r: 0.1823
   - f_a_r: 0.3737
📈 Highest Accuracy Achieved: 0.6124


In [63]:
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_assamese/resnet_resluts/base/5_w_1_s_f', test_loader, M3, criterion)
# Load Model M2
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/secondary/5_w_1_s_f", test_loader, M2, criterion)
# Update Model M3's weights using a weighted combination of Model M3 and Model M2
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Assuming you have the evaluate2 function defined
# Evaluate the updated model (M3) using evaluate2 function with some arguments
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/amul/5_w_1_s_f", test_loader, M3, criterion)


# Assuming you have the evaluate2 function defined
# Evaluate the updated model (M3) using evaluate2 function with some arguments
# evaluate3(pp+'/_olchiki_5-wau-1-shot_matchine_unlearning_own', test_loader, M3, criterion)



---Accu-----pre----rec---------> 0.5330 ± 0.0913  0.5527 ± 0.0991  0.5330 ± 0.0913




---Accu-----pre----rec---------> 0.3655 ± 0.0640  0.4139 ± 0.0741  0.3655 ± 0.0639
---Accu-----pre----rec---------> 0.6124 ± 0.0742  0.6468 ± 0.0701  0.6125 ± 0.0741


In [21]:
!ls /home/asufian/Desktop/output_olchiki/code/olchiki/model_5/model_B_olchiki_1-shot_30.pth

model_B_olchiki_1-shot_20.pth  model_B_olchiki_1-shot_80.pth
model_B_olchiki_1-shot_30.pth  model_B_olchiki_1-shot_90.pth
model_B_olchiki_1-shot_50.pth


In [None]:
########################  5   way    5  shot    #########################

In [64]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



!rm model_mu_path.pt
!rm model_own_path.pt

convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/asamese_sce/ressecondary/model_5/model_B_olchiki_5-shot_res.pth',map_location=torch.device('cpu')))

torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [65]:


def evaluate2(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlal=[]
    ttlac=0
    precision=[]
    recall=[]
    
    with torch.no_grad():
        for episode in range(len(clss_5)):
            support_set, query_set = data_loader.__getitem__(clss_5[tlv_cls] , clss_support_imagesss_5_shot[tlv_cls])

           
            clssa=str(clss_5[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_5_shot[tlv_cls])
            str(clss_5[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            ttlal.append(accuracy)
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            precision_overall, recall_overall, f1ss= calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)
    print(f"---Accu-----pre----rec---------> {np.mean(ttlal):.4f} ± {np.std(ttlal, ddof=1):.4f}  "
      f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
      f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")

    # print("---Accu-----pre----rec---------->",np.mean(ttlal),'±',np.std(ttlal, ddof=1), np.mean(precision),'±',np.std(precision, ddof=1),np.mean(precision),'±',np.std(precision, ddof=1),)
# %%%%%%

def evaluate3(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlac=0
    with torch.no_grad():
        for episode in range(len(clss_5)):
            support_set, query_set = data_loader.__getitem__(clss_5[tlv_cls] , clss_support_imagesss_5_shot[tlv_cls])

            
            clssa=str(clss_5[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_5_shot[tlv_cls])
            str(clss_5[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            # confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            # calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)

    print("---------------------->",ttlac/ttla)
    return ttlac/ttla

In [66]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define the search space for b_a_r, c_a_r, and f_a_r
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Define the objective function for Bayesian Optimization
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Objective function to optimize the adaptation ratios.
    It applies adaptation, evaluates the model, and returns the negative accuracy.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # Debugging Output
    
    return -accuracy  # We minimize, so return negative accuracy

# Perform Bayesian Optimization with additional settings
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Number of evaluations
    n_initial_points=5,       # Initial random explorations before GP starts
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Ensure reproducibility
    n_jobs=-1,                # Parallel execution for faster optimization
)

# Extract optimal results
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.7869641574372693
Evaluated Accuracy: 0.7870
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.6160137253431335
Evaluated Accuracy: 0.6160
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.6223353167162512
Evaluated Accuracy: 0.6223
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.777967390851438
Evaluated Accuracy: 0.7780
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.7779624490612266
Evaluated Accuracy: 0.7780
Testing b_a_r = 0.1143, c_a_r = 0.9973, f_a_r = 0.8760




----------------------> 0.5450052334641701
Evaluated Accuracy: 0.5450
Testing b_a_r = 0.4116, c_a_r = 0.2243, f_a_r = 1.0000




----------------------> 0.7362672066801669
Evaluated Accuracy: 0.7363
Testing b_a_r = 1.0000, c_a_r = 0.0541, f_a_r = 0.7607




----------------------> 0.7856390326424827
Evaluated Accuracy: 0.7856
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.8706




----------------------> 0.7756274240189337
Evaluated Accuracy: 0.7756
Testing b_a_r = 1.0000, c_a_r = 1.0000, f_a_r = 0.0000




----------------------> 0.5260136586748003
Evaluated Accuracy: 0.5260
Testing b_a_r = 1.0000, c_a_r = 0.1346, f_a_r = 1.0000




----------------------> 0.781627390684767
Evaluated Accuracy: 0.7816
Testing b_a_r = 0.6754, c_a_r = 0.1567, f_a_r = 0.0000




----------------------> 0.6613321166362492
Evaluated Accuracy: 0.6613
Testing b_a_r = 1.0000, c_a_r = 0.1058, f_a_r = 0.8544




----------------------> 0.7926140570180923
Evaluated Accuracy: 0.7926
Testing b_a_r = 0.7971, c_a_r = 0.6057, f_a_r = 1.0000




----------------------> 0.5560020000500013
Evaluated Accuracy: 0.5560
Testing b_a_r = 0.7575, c_a_r = 0.0000, f_a_r = 0.2936




----------------------> 0.7023138411793628
Evaluated Accuracy: 0.7023
Testing b_a_r = 0.0059, c_a_r = 0.1424, f_a_r = 0.6645




----------------------> 0.7766540913522839
Evaluated Accuracy: 0.7767
Testing b_a_r = 0.0000, c_a_r = 0.1107, f_a_r = 0.8586




----------------------> 0.7822872905155961
Evaluated Accuracy: 0.7823
Testing b_a_r = 0.9254, c_a_r = 0.2369, f_a_r = 0.6193




----------------------> 0.7213021492203971
Evaluated Accuracy: 0.7213
Testing b_a_r = 0.0505, c_a_r = 0.7020, f_a_r = 0.0000




----------------------> 0.5626734918372959
Evaluated Accuracy: 0.5627
Testing b_a_r = 0.9931, c_a_r = 0.0734, f_a_r = 1.0000




----------------------> 0.7832957323933097
Evaluated Accuracy: 0.7833
Testing b_a_r = 1.0000, c_a_r = 0.1451, f_a_r = 0.7876




----------------------> 0.78830410760269
Evaluated Accuracy: 0.7883
Testing b_a_r = 0.9725, c_a_r = 0.1034, f_a_r = 0.8487




----------------------> 0.7906157153928848
Evaluated Accuracy: 0.7906
Testing b_a_r = 0.9822, c_a_r = 0.1275, f_a_r = 0.7690




----------------------> 0.786967440852688
Evaluated Accuracy: 0.7870
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.5709




----------------------> 0.7553023492253973
Evaluated Accuracy: 0.7553
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.6926538246789504
Evaluated Accuracy: 0.6927
Testing b_a_r = 0.0000, c_a_r = 0.2161, f_a_r = 0.7697




----------------------> 0.7552923156412242
Evaluated Accuracy: 0.7553
Testing b_a_r = 0.0115, c_a_r = 0.1016, f_a_r = 0.8324




----------------------> 0.7802856488078871
Evaluated Accuracy: 0.7803
Testing b_a_r = 0.9529, c_a_r = 0.0050, f_a_r = 0.7912




----------------------> 0.7809640741018525
Evaluated Accuracy: 0.7810
Testing b_a_r = 0.9245, c_a_r = 0.1535, f_a_r = 0.8490




----------------------> 0.7873041076026902
Evaluated Accuracy: 0.7873
Testing b_a_r = 0.9800, c_a_r = 0.7621, f_a_r = 0.5515




----------------------> 0.5583684508779386
Evaluated Accuracy: 0.5584
Testing b_a_r = 0.9627, c_a_r = 0.0904, f_a_r = 0.5772




----------------------> 0.7569689492237306
Evaluated Accuracy: 0.7570
Testing b_a_r = 0.9903, c_a_r = 0.0673, f_a_r = 0.8988




----------------------> 0.7873040742685234
Evaluated Accuracy: 0.7873
Testing b_a_r = 0.0618, c_a_r = 0.0797, f_a_r = 0.9947




----------------------> 0.7772906906005983
Evaluated Accuracy: 0.7773
Testing b_a_r = 0.8575, c_a_r = 0.4146, f_a_r = 0.9912




----------------------> 0.6839906080985357
Evaluated Accuracy: 0.6840
Testing b_a_r = 0.4820, c_a_r = 0.8567, f_a_r = 0.9993




----------------------> 0.5537234930873272
Evaluated Accuracy: 0.5537
Testing b_a_r = 0.1566, c_a_r = 0.8605, f_a_r = 0.0002




----------------------> 0.5603852679650324
Evaluated Accuracy: 0.5604
Testing b_a_r = 0.5992, c_a_r = 0.3259, f_a_r = 0.9937




----------------------> 0.5743135911731126
Evaluated Accuracy: 0.5743
Testing b_a_r = 0.0782, c_a_r = 0.4886, f_a_r = 1.0000




----------------------> 0.6576606081818711
Evaluated Accuracy: 0.6577
Testing b_a_r = 0.9258, c_a_r = 0.1800, f_a_r = 0.5464




----------------------> 0.7593174079351984
Evaluated Accuracy: 0.7593
Testing b_a_r = 0.1006, c_a_r = 0.5691, f_a_r = 0.0126




----------------------> 0.583337075093544
Evaluated Accuracy: 0.5833
Testing b_a_r = 0.9905, c_a_r = 0.0320, f_a_r = 0.9524




----------------------> 0.7809574239355984
Evaluated Accuracy: 0.7810
Testing b_a_r = 0.4942, c_a_r = 0.2403, f_a_r = 0.0017




----------------------> 0.6526302574231023
Evaluated Accuracy: 0.6526
Testing b_a_r = 0.9090, c_a_r = 0.1791, f_a_r = 0.9977




----------------------> 0.7792874571864296
Evaluated Accuracy: 0.7793
Testing b_a_r = 0.9407, c_a_r = 0.4453, f_a_r = 0.6277




----------------------> 0.6863506754335524
Evaluated Accuracy: 0.6864
Testing b_a_r = 0.7125, c_a_r = 0.0769, f_a_r = 0.0009




----------------------> 0.6869988333041661
Evaluated Accuracy: 0.6870
Testing b_a_r = 0.1618, c_a_r = 0.7267, f_a_r = 0.9954




----------------------> 0.5463617173762678
Evaluated Accuracy: 0.5464
Testing b_a_r = 0.0395, c_a_r = 0.1744, f_a_r = 0.8791




----------------------> 0.7889674741868546
Evaluated Accuracy: 0.7890
Testing b_a_r = 0.9773, c_a_r = 0.5295, f_a_r = 0.5246




----------------------> 0.609652132969991
Evaluated Accuracy: 0.6097
Testing b_a_r = 0.8827, c_a_r = 0.0977, f_a_r = 0.9663




----------------------> 0.7879523988099703
Evaluated Accuracy: 0.7880
Testing b_a_r = 0.0225, c_a_r = 0.1705, f_a_r = 0.7117




----------------------> 0.7866574747702025
Evaluated Accuracy: 0.7867

✅ Optimal Values:
   - b_a_r: 1.0000
   - c_a_r: 0.1058
   - f_a_r: 0.8544
📈 Highest Accuracy Achieved: 0.7926


In [None]:
# b_a_r: 
#    - c_a_r: 
#    - f_a_r: 

#  pt=20

In [67]:

convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_assamese/resnet_resluts/base/5_w_5_s_f', test_loader, M3, criterion)
# Load Model M2
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/secondary/5_w_5_s_f", test_loader, M2, criterion)
# Update Model M3's weights using a weighted combination of Model M3 and Model M2
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Assuming you have the evaluate2 function defined
# Evaluate the updated model (M3) using evaluate2 function with some arguments
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/amul/5_w_5_s_f", test_loader, M3, criterion)




# Assuming you have the evaluate2 function defined
# Evaluate the updated model (M3) using evaluate2 function with some arguments
# evaluate3(pp+'/_olchiki_5-wau-1-shot_matchine_unlearning_own', test_loader, M3, criterion)



---Accu-----pre----rec---------> 0.6927 ± 0.0745  0.7005 ± 0.0729  0.6927 ± 0.0745




---Accu-----pre----rec---------> 0.5787 ± 0.0772  0.5853 ± 0.0762  0.5787 ± 0.0771
---Accu-----pre----rec---------> 0.7926 ± 0.0715  0.8073 ± 0.0675  0.7927 ± 0.0713


In [33]:
!ls /home/asufian/Desktop/output_olchiki/code/olchiki/model_10

model_B_olchiki_1-shot_20.pth  model_B_olchiki_1-shot_70.pth
model_B_olchiki_1-shot_40.pth  model_B_olchiki_1-shot_90.pth
model_B_olchiki_1-shot_50.pth


In [68]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



!rm model_mu_path.pt
!rm model_own_path.pt

convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/asamese_sce/ressecondary/model_10/model_B_olchiki_10-shot_res.pth',map_location=torch.device('cpu')))

torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [70]:


def evaluate2(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlal=[]
    ttlac=0
    precision=[]
    recall=[]
    
    with torch.no_grad():
        for episode in range(len(clss_5)):
            support_set, query_set = data_loader.__getitem__(clss_5[tlv_cls] , clss_support_imagesss_10_shot[tlv_cls])

           
            clssa=str(clss_5[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_10_shot[tlv_cls])
            str(clss_5[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            ttlal.append(accuracy)
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            precision_overall, recall_overall, f1ss= calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)
    print(f"---Accu-----pre----rec---------> {np.mean(ttlal):.4f} ± {np.std(ttlal, ddof=1):.4f}  "
      f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
      f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")

    # print("---Accu-----pre----rec---------->",np.mean(ttlal),'±',np.std(ttlal, ddof=1), np.mean(precision),'±',np.std(precision, ddof=1),np.mean(precision),'±',np.std(precision, ddof=1),)
# %%%%%%

def evaluate3(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlac=0
    with torch.no_grad():
        for episode in range(len(clss_5)):
            support_set, query_set = data_loader.__getitem__(clss_5[tlv_cls] , clss_support_imagesss_10_shot[tlv_cls])

            
            clssa=str(clss_5[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_10_shot[tlv_cls])
            str(clss_5[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            # confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            # calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)

    print("---------------------->",ttlac/ttla)
    return ttlac/ttla

In [71]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define the search space for b_a_r, c_a_r, and f_a_r
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Define the objective function for Bayesian Optimization
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Objective function to optimize the adaptation ratios.
    It applies adaptation, evaluates the model, and returns the negative accuracy.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # Debugging Output
    
    return -accuracy  # We minimize, so return negative accuracy

# Perform Bayesian Optimization with additional settings
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Number of evaluations
    n_initial_points=5,       # Initial random explorations before GP starts
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Ensure reproducibility
    n_jobs=-1,                # Parallel execution for faster optimization
)

# Extract optimal results
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.7611284768870975
Evaluated Accuracy: 0.7611
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.5592098571926158
Evaluated Accuracy: 0.5592
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.6091256904015524
Evaluated Accuracy: 0.6091
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.7695269443200478
Evaluated Accuracy: 0.7695
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.7432344379758172
Evaluated Accuracy: 0.7432
Testing b_a_r = 0.3930, c_a_r = 1.0000, f_a_r = 1.0000




----------------------> 0.4003890879235707
Evaluated Accuracy: 0.4004
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.3125




----------------------> 0.7219381250932976
Evaluated Accuracy: 0.7219
Testing b_a_r = 0.3232, c_a_r = 0.1123, f_a_r = 1.0000




----------------------> 0.7573277106035726
Evaluated Accuracy: 0.7573
Testing b_a_r = 0.1969, c_a_r = 0.3437, f_a_r = 1.0000




----------------------> 0.5908444046375081
Evaluated Accuracy: 0.5908
Testing b_a_r = 1.0000, c_a_r = 0.1238, f_a_r = 1.0000




----------------------> 0.7584661641040952
Evaluated Accuracy: 0.7585
Testing b_a_r = 0.9209, c_a_r = 0.6889, f_a_r = 0.0000




----------------------> 0.4678112404836543
Evaluated Accuracy: 0.4678
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 1.0000




----------------------> 0.7432192615813304
Evaluated Accuracy: 0.7432
Testing b_a_r = 1.0000, c_a_r = 0.1197, f_a_r = 0.6101




----------------------> 0.7862997462307807
Evaluated Accuracy: 0.7863
Testing b_a_r = 0.4679, c_a_r = 0.1303, f_a_r = 0.0000




----------------------> 0.7436242722794446
Evaluated Accuracy: 0.7436
Testing b_a_r = 0.9952, c_a_r = 0.1452, f_a_r = 0.4260




----------------------> 0.7683949594466836
Evaluated Accuracy: 0.7684
Testing b_a_r = 0.6409, c_a_r = 1.0000, f_a_r = 0.0000




----------------------> 0.4415081853012887
Evaluated Accuracy: 0.4415
Testing b_a_r = 0.0631, c_a_r = 0.0897, f_a_r = 0.4180




----------------------> 0.7646051898293277
Evaluated Accuracy: 0.7646
Testing b_a_r = 0.6714, c_a_r = 0.6368, f_a_r = 0.9995




----------------------> 0.4323738866497487
Evaluated Accuracy: 0.4324
Testing b_a_r = 0.6492, c_a_r = 0.0090, f_a_r = 0.0095




----------------------> 0.7028904562870079
Evaluated Accuracy: 0.7029
Testing b_a_r = 0.4451, c_a_r = 0.1224, f_a_r = 0.7440




----------------------> 0.7649466835846146
Evaluated Accuracy: 0.7649
Testing b_a_r = 0.0123, c_a_r = 0.1373, f_a_r = 0.5519




----------------------> 0.7790594864905209
Evaluated Accuracy: 0.7791
Testing b_a_r = 0.0002, c_a_r = 0.1830, f_a_r = 0.4031




----------------------> 0.7554446683584614
Evaluated Accuracy: 0.7554
Testing b_a_r = 0.9836, c_a_r = 0.0859, f_a_r = 0.6356




----------------------> 0.7836223316912974
Evaluated Accuracy: 0.7836
Testing b_a_r = 0.9891, c_a_r = 0.0988, f_a_r = 0.5564




----------------------> 0.7786807981290739
Evaluated Accuracy: 0.7787
Testing b_a_r = 0.0296, c_a_r = 0.1104, f_a_r = 0.6652




----------------------> 0.7767760113449769
Evaluated Accuracy: 0.7768
Testing b_a_r = 0.9474, c_a_r = 0.1300, f_a_r = 0.4571




----------------------> 0.7763928695825248
Evaluated Accuracy: 0.7764
Testing b_a_r = 0.9587, c_a_r = 0.0805, f_a_r = 0.7921




----------------------> 0.7691351445489377
Evaluated Accuracy: 0.7691
Testing b_a_r = 0.9975, c_a_r = 0.1007, f_a_r = 0.4351




----------------------> 0.7714514852963128
Evaluated Accuracy: 0.7715
Testing b_a_r = 0.9791, c_a_r = 0.1914, f_a_r = 0.0077




----------------------> 0.7226847788227099
Evaluated Accuracy: 0.7227
Testing b_a_r = 0.8824, c_a_r = 0.1939, f_a_r = 0.9969




----------------------> 0.7447496890083097
Evaluated Accuracy: 0.7447
Testing b_a_r = 0.9892, c_a_r = 0.1411, f_a_r = 0.6940




----------------------> 0.7748644573816986
Evaluated Accuracy: 0.7749
Testing b_a_r = 0.0506, c_a_r = 0.0507, f_a_r = 0.6160




----------------------> 0.7664619843757774
Evaluated Accuracy: 0.7665
Testing b_a_r = 0.9718, c_a_r = 0.0664, f_a_r = 0.9993




----------------------> 0.7542887744439467
Evaluated Accuracy: 0.7543
Testing b_a_r = 0.0057, c_a_r = 0.1394, f_a_r = 0.6873




----------------------> 0.7752542419266556
Evaluated Accuracy: 0.7753
Testing b_a_r = 0.3170, c_a_r = 0.8282, f_a_r = 0.9919




----------------------> 0.4106724884311091
Evaluated Accuracy: 0.4107
Testing b_a_r = 0.9552, c_a_r = 0.0826, f_a_r = 0.2872




----------------------> 0.7512760362243122
Evaluated Accuracy: 0.7513
Testing b_a_r = 0.7598, c_a_r = 0.0005, f_a_r = 0.6893




----------------------> 0.7618947604120018
Evaluated Accuracy: 0.7619
Testing b_a_r = 0.9998, c_a_r = 0.2212, f_a_r = 0.5878




----------------------> 0.740180375180375
Evaluated Accuracy: 0.7402
Testing b_a_r = 0.9722, c_a_r = 0.1395, f_a_r = 0.5499




----------------------> 0.7817195850126883
Evaluated Accuracy: 0.7817
Testing b_a_r = 0.7855, c_a_r = 0.5333, f_a_r = 0.5058




----------------------> 0.49903649798477384
Evaluated Accuracy: 0.4990
Testing b_a_r = 0.8084, c_a_r = 0.1517, f_a_r = 0.2143




----------------------> 0.7504901975419216
Evaluated Accuracy: 0.7505
Testing b_a_r = 0.1874, c_a_r = 0.8532, f_a_r = 0.4010




----------------------> 0.4190469473055681
Evaluated Accuracy: 0.4190
Testing b_a_r = 0.9623, c_a_r = 0.1494, f_a_r = 0.5629




----------------------> 0.7813342289894013
Evaluated Accuracy: 0.7813
Testing b_a_r = 0.9883, c_a_r = 0.0851, f_a_r = 0.5930




----------------------> 0.7813366174055828
Evaluated Accuracy: 0.7813
Testing b_a_r = 0.0552, c_a_r = 0.1281, f_a_r = 0.5638




----------------------> 0.7763884659401901
Evaluated Accuracy: 0.7764
Testing b_a_r = 0.9975, c_a_r = 0.0544, f_a_r = 0.5322




----------------------> 0.76763096482062
Evaluated Accuracy: 0.7676
Testing b_a_r = 0.9963, c_a_r = 0.1236, f_a_r = 0.6114




----------------------> 0.7843883913021844
Evaluated Accuracy: 0.7844
Testing b_a_r = 0.9212, c_a_r = 0.0002, f_a_r = 0.8169




----------------------> 0.7508449271035478
Evaluated Accuracy: 0.7508
Testing b_a_r = 0.9785, c_a_r = 0.1568, f_a_r = 0.5808




----------------------> 0.7824705677464298
Evaluated Accuracy: 0.7825
Testing b_a_r = 0.9941, c_a_r = 0.1123, f_a_r = 0.6539




----------------------> 0.7866742051052396
Evaluated Accuracy: 0.7867

✅ Optimal Values:
   - b_a_r: 0.9941
   - c_a_r: 0.1123
   - f_a_r: 0.6539
📈 Highest Accuracy Achieved: 0.7867


In [72]:
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_assamese/resnet_resluts/base/5_w_10_s_f', test_loader, M3, criterion)
# Load Model M2
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/secondary/5_w_10_s_f", test_loader, M2, criterion)
# Update Model M3's weights using a weighted combination of Model M3 and Model M2
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Assuming you have the evaluate2 function defined
# Evaluate the updated model (M3) using evaluate2 function with some arguments
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/amul/5_w_5_10_f", test_loader, M3, criterion)

# Assuming you have the evaluate2 function defined
# Evaluate the updated model (M3) using evaluate2 function with some arguments
# evaluate3(pp+'/_olchiki_5-wau-1-shot_matchine_unlearning_own', test_loader, M3, criterion)



---Accu-----pre----rec---------> 0.7033 ± 0.0753  0.7140 ± 0.0701  0.7033 ± 0.0752




---Accu-----pre----rec---------> 0.4964 ± 0.0800  0.5000 ± 0.0754  0.4964 ± 0.0799
---Accu-----pre----rec---------> 0.7867 ± 0.0592  0.7975 ± 0.0538  0.7868 ± 0.0591


In [50]:
ls /home/asufian/Desktop/output_olchiki/lol

5_w_10_s.txt  5_w_10_s.xlsx


In [75]:
clss_8=[
       ['assamese1_DXA', 'assamese1_SUNYA', 'assamese1_KTA', 'assamese1_CARI', 'assamese1_AA', 'assamese1_ATH', 'assamese1_DUI', 'assamese1_MNA'],
     ['assamese1_BA', 'assamese1_OU', 'assamese1_NA', 'assamese1_TINI', 'assamese1_MDHA', 'assamese1_E', 'assamese1_A', 'assamese1_TXA'],
    ['assamese1_ATH', 'assamese1_CA', 'assamese1_NG', 'assamese1_EK', 'assamese1_ANSR', 'assamese1_AA', 'assamese1_KTA', 'assamese1_DHA'],
    ['assamese1_KA', 'assamese1_XAT', 'assamese1_E', 'assamese1_MA', 'assamese1_CCA', 'assamese1_MTA', 'assamese1_A', 'assamese1_AE'],
    ['assamese1_TINI', 'assamese1_KHYA', 'assamese1_NAA', 'assamese1_HA', 'assamese1_DUI', 'assamese1_CBN', 'assamese1_TR', 'assamese1_AE'],
     ['assamese1_DA', 'assamese1_SUNYA', 'assamese1_AJA', 'assamese1_ANSR', 'assamese1_PHA', 'assamese1_NAA', 'assamese1_MDHA', 'assamese1_NG'],
    ['assamese1_DUI', 'assamese1_E', 'assamese1_AYA', 'assamese1_MA', 'assamese1_THA', 'assamese1_MTA', 'assamese1_ATH', 'assamese1_NIYA'],
    ['assamese1_MA', 'assamese1_AYA', 'assamese1_NIYA', 'assamese1_KTA', 'assamese1_ATH', 'assamese1_PAC', 'assamese1_KHYA', 'assamese1_DUI'],
     ['assamese1_DXA', 'assamese1_CAY', 'assamese1_MNA', 'assamese1_ANSR', 'assamese1_AA', 'assamese1_TXA', 'assamese1_GA', 'assamese1_ATH'],
    ['assamese1_RA', 'assamese1_ATH', 'assamese1_CARI', 'assamese1_NG', 'assamese1_DA', 'assamese1_REE', 'assamese1_AJA', 'assamese1_A'],
    ['assamese1_KHA', 'assamese1_THA', 'assamese1_AA', 'assamese1_MNA', 'assamese1_SUNYA', 'assamese1_NIYA', 'assamese1_LA', 'assamese1_MDHA'],
     ['assamese1_KA', 'assamese1_AYA', 'assamese1_LA', 'assamese1_TINI', 'assamese1_MA', 'assamese1_KTA', 'assamese1_MDHA', 'assamese1_PA'],
    ['assamese1_DHRA', 'assamese1_CCA', 'assamese1_DXA', 'assamese1_TXA', 'assamese1_SUNYA', 'assamese1_ANSR', 'assamese1_HA', 'assamese1_OU'],
    ['assamese1_NG', 'assamese1_JHA', 'assamese1_GHA', 'assamese1_MTHA', 'assamese1_ANSR', 'assamese1_ATH', 'assamese1_PA', 'assamese1_MDHA'],
    ['assamese1_ATH', 'assamese1_CAY', 'assamese1_DUI', 'assamese1_AJA', 'assamese1_AA', 'assamese1_THA', 'assamese1_BHA', 'assamese1_HA'],
    ]

clss_support_imagesss_10_shot=[[[24, 30, 13, 26, 3, 2, 42, 12, 29, 33], [16, 31, 30, 15, 22, 33, 6, 41, 0, 12], [7, 31, 42, 1, 30, 35, 20, 39, 0, 3], [39, 41, 6, 34, 16, 36, 17, 1, 43, 44], [42, 5, 38, 41, 7, 37, 6, 1, 28, 2], [2, 10, 15, 24, 38, 5, 29, 18, 9, 32], [34, 38, 18, 5, 25, 14, 19, 29, 0, 2], [7, 13, 26, 43, 4, 41, 8, 2, 18, 30]], [[14, 31, 35, 11, 44, 40, 9, 20, 12, 24], [43, 35, 29, 0, 10, 24, 34, 27, 11, 38], [40, 42, 34, 41, 12, 31, 15, 10, 44, 5], [21, 29, 4, 10, 19, 34, 5, 2, 41, 0], [18, 25, 11, 5, 43, 8, 19, 13, 42, 39], [18, 41, 2, 3, 33, 6, 21, 36, 32, 23], [44, 1, 16, 29, 35, 26, 19, 18, 36, 37], [36, 12, 22, 44, 29, 1, 2, 39, 34, 35]], [[10, 26, 40, 36, 38, 16, 2, 39, 6, 11], [41, 23, 19, 24, 6, 34, 13, 35, 22, 44], [29, 0, 40, 32, 41, 3, 11, 34, 39, 13], [40, 6, 0, 24, 14, 41, 37, 25, 27, 36], [36, 35, 32, 42, 7, 19, 33, 13, 17, 22], [13, 7, 12, 44, 19, 39, 40, 37, 29, 36], [17, 19, 23, 8, 31, 18, 32, 26, 22, 34], [19, 17, 0, 13, 28, 32, 20, 22, 1, 24]], [[23, 2, 25, 26, 12, 4, 34, 6, 13, 19], [15, 14, 24, 1, 26, 35, 28, 5, 13, 34], [28, 30, 25, 41, 33, 37, 20, 4, 3, 15], [39, 13, 41, 6, 35, 38, 29, 20, 14, 19], [38, 30, 10, 12, 3, 40, 25, 34, 0, 27], [0, 16, 27, 13, 40, 43, 41, 1, 39, 9], [27, 42, 35, 38, 40, 20, 15, 2, 26, 6], [14, 22, 29, 10, 37, 33, 28, 27, 43, 13]], [[7, 34, 0, 37, 21, 33, 44, 2, 24, 26], [28, 21, 36, 32, 33, 9, 1, 13, 5, 29], [28, 30, 18, 8, 27, 13, 22, 36, 33, 4], [32, 22, 36, 42, 0, 11, 39, 7, 37, 38], [16, 11, 40, 14, 38, 7, 33, 28, 26, 20], [36, 44, 43, 12, 38, 9, 13, 34, 30, 5], [22, 30, 25, 4, 44, 40, 16, 21, 36, 18], [11, 12, 26, 9, 29, 20, 37, 19, 30, 0]], [[37, 8, 41, 16, 27, 2, 13, 39, 38, 24], [10, 43, 42, 5, 8, 16, 34, 23, 9, 1], [13, 11, 41, 44, 32, 17, 3, 9, 34, 28], [36, 1, 4, 18, 24, 42, 22, 17, 10, 15], [37, 35, 40, 12, 41, 28, 22, 18, 4, 17], [8, 33, 15, 1, 3, 0, 17, 44, 20, 12], [4, 19, 20, 24, 12, 15, 6, 14, 10, 17], [21, 13, 38, 24, 9, 26, 22, 23, 41, 6]], [[42, 30, 16, 39, 41, 8, 23, 37, 9, 12], [13, 3, 29, 10, 25, 38, 40, 14, 11, 5], [36, 26, 14, 27, 6, 0, 38, 18, 42, 44], [20, 38, 33, 6, 36, 8, 40, 4, 18, 44], [30, 26, 38, 24, 8, 14, 0, 6, 19, 37], [8, 27, 34, 17, 43, 23, 11, 30, 21, 44], [12, 36, 35, 25, 39, 43, 29, 44, 10, 21], [12, 7, 22, 0, 36, 8, 28, 29, 15, 26]], [[40, 8, 43, 38, 23, 3, 24, 12, 25, 22], [27, 20, 14, 31, 9, 11, 42, 23, 5, 38], [22, 12, 27, 1, 17, 38, 41, 21, 13, 2], [6, 16, 8, 42, 3, 7, 35, 22, 44, 21], [2, 12, 30, 44, 4, 10, 38, 6, 36, 42], [32, 17, 21, 38, 36, 43, 10, 29, 20, 4], [15, 25, 43, 10, 5, 35, 19, 20, 12, 0], [2, 13, 15, 16, 43, 8, 9, 5, 41, 24]], [[1, 26, 39, 8, 5, 22, 16, 38, 29, 19], [12, 11, 43, 28, 40, 6, 38, 24, 8, 16], [6, 14, 20, 2, 31, 28, 35, 19, 34, 33], [29, 23, 1, 25, 3, 31, 18, 27, 9, 12], [22, 42, 35, 3, 37, 6, 31, 38, 7, 41], [14, 23, 37, 44, 30, 18, 24, 32, 4, 25], [26, 10, 13, 32, 25, 36, 7, 29, 41, 15], [31, 16, 35, 9, 6, 11, 3, 41, 2, 25]], [[25, 37, 5, 10, 33, 36, 3, 19, 43, 29], [11, 2, 23, 29, 3, 43, 37, 21, 44, 7], [40, 17, 11, 14, 31, 38, 7, 3, 4, 42], [39, 29, 16, 19, 6, 37, 4, 23, 5, 15], [22, 5, 34, 36, 13, 44, 29, 16, 33, 25], [20, 15, 17, 38, 33, 29, 30, 13, 1, 8], [42, 44, 22, 0, 29, 19, 6, 13, 25, 1], [4, 38, 3, 35, 19, 34, 20, 29, 42, 6]], [[7, 29, 19, 4, 40, 6, 30, 12, 41, 31], [38, 3, 29, 17, 14, 28, 37, 33, 36, 40], [22, 30, 11, 17, 10, 32, 39, 12, 2, 6], [29, 17, 9, 44, 20, 21, 12, 19, 25, 4], [5, 21, 37, 23, 41, 40, 38, 16, 12, 25], [23, 8, 21, 15, 43, 44, 7, 12, 26, 16], [6, 4, 26, 38, 10, 20, 28, 13, 0, 39], [27, 36, 2, 1, 28, 4, 5, 24, 13, 33]], [[42, 1, 20, 12, 34, 24, 8, 26, 15, 18], [40, 41, 43, 11, 24, 27, 10, 21, 29, 38], [16, 34, 19, 39, 7, 22, 29, 5, 35, 17], [24, 34, 37, 16, 35, 31, 20, 1, 42, 38], [33, 2, 16, 7, 3, 4, 39, 20, 11, 35], [43, 17, 24, 14, 40, 5, 29, 39, 13, 38], [6, 3, 21, 27, 19, 10, 44, 9, 30, 26], [16, 3, 25, 22, 38, 42, 27, 23, 20, 18]], [[40, 30, 1, 23, 41, 0, 12, 25, 43, 35], [18, 7, 20, 31, 21, 30, 43, 11, 42, 12], [25, 14, 2, 28, 24, 3, 39, 9, 36, 33], [23, 14, 32, 35, 17, 20, 29, 5, 30, 16], [19, 17, 34, 18, 20, 39, 7, 24, 29, 35], [4, 39, 15, 25, 44, 7, 6, 23, 29, 37], [39, 34, 23, 13, 2, 6, 42, 25, 7, 1], [7, 9, 19, 34, 39, 28, 4, 41, 29, 33]], [[1, 29, 19, 21, 43, 41, 36, 15, 14, 18], [36, 33, 23, 31, 19, 25, 11, 32, 21, 28], [43, 26, 42, 29, 2, 13, 39, 41, 33, 19], [36, 27, 34, 20, 43, 10, 33, 7, 37, 28], [33, 25, 9, 16, 2, 26, 39, 41, 11, 8], [16, 39, 40, 12, 13, 0, 1, 27, 5, 9], [31, 13, 11, 15, 27, 32, 4, 8, 43, 41], [11, 33, 40, 28, 20, 30, 22, 15, 12, 21]], [[18, 24, 21, 32, 19, 40, 35, 20, 9, 38], [5, 28, 0, 19, 40, 33, 29, 10, 24, 27], [11, 22, 27, 14, 40, 9, 36, 16, 41, 5], [9, 28, 34, 15, 39, 20, 4, 33, 43, 16], [8, 25, 26, 31, 27, 20, 6, 40, 2, 30], [3, 15, 33, 28, 44, 0, 36, 12, 21, 18], [39, 17, 22, 9, 1, 33, 21, 29, 15, 36], [38, 30, 36, 40, 21, 4, 5, 22, 7, 3]]]

clss_support_imagesss_5_shot=[[[24, 30, 13, 26, 3], [16, 31, 30, 15, 22], [7, 31, 42, 1, 30], [39, 41, 6, 34, 16], [42, 5, 38, 41, 7], [2, 10, 15, 24, 38], [34, 38, 18, 5, 25], [7, 13, 26, 43, 4]], [[14, 31, 35, 11, 44], [43, 35, 29, 0, 10], [40, 42, 34, 41, 12], [21, 29, 4, 10, 19], [18, 25, 11, 5, 43], [18, 41, 2, 3, 33], [44, 1, 16, 29, 35], [36, 12, 22, 44, 29]], [[10, 26, 40, 36, 38], [41, 23, 19, 24, 6], [29, 0, 40, 32, 41], [40, 6, 0, 24, 14], [36, 35, 32, 42, 7], [13, 7, 12, 44, 19], [17, 19, 23, 8, 31], [19, 17, 0, 13, 28]], [[23, 2, 25, 26, 12], [15, 14, 24, 1, 26], [28, 30, 25, 41, 33], [39, 13, 41, 6, 35], [38, 30, 10, 12, 3], [0, 16, 27, 13, 40], [27, 42, 35, 38, 40], [14, 22, 29, 10, 37]], [[7, 34, 0, 37, 21], [28, 21, 36, 32, 33], [28, 30, 18, 8, 27], [32, 22, 36, 42, 0], [16, 11, 40, 14, 38], [36, 44, 43, 12, 38], [22, 30, 25, 4, 44], [11, 12, 26, 9, 29]], [[37, 8, 41, 16, 27], [10, 43, 42, 5, 8], [13, 11, 41, 44, 32], [36, 1, 4, 18, 24], [37, 35, 40, 12, 41], [8, 33, 15, 1, 3], [4, 19, 20, 24, 12], [21, 13, 38, 24, 9]], [[42, 30, 16, 39, 41], [13, 3, 29, 10, 25], [36, 26, 14, 27, 6], [20, 38, 33, 6, 36], [30, 26, 38, 24, 8], [8, 27, 34, 17, 43], [12, 36, 35, 25, 39], [12, 7, 22, 0, 36]], [[40, 8, 43, 38, 23], [27, 20, 14, 31, 9], [22, 12, 27, 1, 17], [6, 16, 8, 42, 3], [2, 12, 30, 44, 4], [32, 17, 21, 38, 36], [15, 25, 43, 10, 5], [2, 13, 15, 16, 43]], [[1, 26, 39, 8, 5], [12, 11, 43, 28, 40], [6, 14, 20, 2, 31], [29, 23, 1, 25, 3], [22, 42, 35, 3, 37], [14, 23, 37, 44, 30], [26, 10, 13, 32, 25], [31, 16, 35, 9, 6]], [[25, 37, 5, 10, 33], [11, 2, 23, 29, 3], [40, 17, 11, 14, 31], [39, 29, 16, 19, 6], [22, 5, 34, 36, 13], [20, 15, 17, 38, 33], [42, 44, 22, 0, 29], [4, 38, 3, 35, 19]], [[7, 29, 19, 4, 40], [38, 3, 29, 17, 14], [22, 30, 11, 17, 10], [29, 17, 9, 44, 20], [5, 21, 37, 23, 41], [23, 8, 21, 15, 43], [6, 4, 26, 38, 10], [27, 36, 2, 1, 28]], [[42, 1, 20, 12, 34], [40, 41, 43, 11, 24], [16, 34, 19, 39, 7], [24, 34, 37, 16, 35], [33, 2, 16, 7, 3], [43, 17, 24, 14, 40], [6, 3, 21, 27, 19], [16, 3, 25, 22, 38]], [[40, 30, 1, 23, 41], [18, 7, 20, 31, 21], [25, 14, 2, 28, 24], [23, 14, 32, 35, 17], [19, 17, 34, 18, 20], [4, 39, 15, 25, 44], [39, 34, 23, 13, 2], [7, 9, 19, 34, 39]], [[1, 29, 19, 21, 43], [36, 33, 23, 31, 19], [43, 26, 42, 29, 2], [36, 27, 34, 20, 43], [33, 25, 9, 16, 2], [16, 39, 40, 12, 13], [31, 13, 11, 15, 27], [11, 33, 40, 28, 20]], [[18, 24, 21, 32, 19], [5, 28, 0, 19, 40], [11, 22, 27, 14, 40], [9, 28, 34, 15, 39], [8, 25, 26, 31, 27], [3, 15, 33, 28, 44], [39, 17, 22, 9, 1], [38, 30, 36, 40, 21]]]


clss_support_imagesss_1_shot=[
   [[24], [16], [7], [39], [42], [2], [34], [7]],
   [[14], [43], [40], [21], [18], [18], [44], [36]],
   [[10], [41], [29], [40], [36], [13], [17], [19]],
   [[23], [15], [28], [39], [38], [0], [27], [14]],
   [[7], [28], [28], [32], [16], [36], [22], [11]],
   [[37], [10], [13], [36], [37], [8], [4], [21]],
   [[42], [13], [36], [20], [30], [8], [12], [12]],
   [[40], [27], [22], [6], [2], [32], [15], [2]],
   [[1], [12], [6], [29], [22], [14], [26], [31]],
   [[25], [11], [40], [39], [22], [20], [42], [4]],
   [[7], [38], [22], [29], [5], [23], [6], [27]],
   [[42], [40], [16], [24], [33], [43], [6], [16]],
   [[40], [18], [25], [23], [19], [4], [39], [7]],
   [[1], [36], [43], [36], [33], [16], [31], [11]],
   [[18], [5], [11], [9], [8], [3], [39], [38]]

]



In [77]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



!rm model_mu_path.pt
!rm model_own_path.pt

convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/asamese_sce/ressecondary/model_1/model_B_olchiki_1-shot_res.pth',map_location=torch.device('cpu')))

torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [78]:


def evaluate2(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlal=[]
    ttlac=0
    precision=[]
    recall=[]
    
    with torch.no_grad():
        for episode in range(len(clss_8)):
            support_set, query_set = data_loader.__getitem__(clss_8[tlv_cls] , clss_support_imagesss_1_shot[tlv_cls])

           
            clssa=str(clss_8[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_1_shot[tlv_cls])
            str(clss_8[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            ttlal.append(accuracy)
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            precision_overall, recall_overall, f1ss= calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)
    print(f"---Accu-----pre----rec---------> {np.mean(ttlal):.4f} ± {np.std(ttlal, ddof=1):.4f}  "
      f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
      f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")

    # print("---Accu-----pre----rec---------->",np.mean(ttlal),'±',np.std(ttlal, ddof=1), np.mean(precision),'±',np.std(precision, ddof=1),np.mean(precision),'±',np.std(precision, ddof=1),)
# %%%%%%

def evaluate3(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlac=0
    with torch.no_grad():
        for episode in range(len(clss_8)):
            support_set, query_set = data_loader.__getitem__(clss_8[tlv_cls] , clss_support_imagesss_1_shot[tlv_cls])

            
            clssa=str(clss_8[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_1_shot[tlv_cls])
            str(clss_8[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            # confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            # calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)

    print("---------------------->",ttlac/ttla)
    return ttlac/ttla

In [79]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define the search space for b_a_r, c_a_r, and f_a_r
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Define the objective function for Bayesian Optimization
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Objective function to optimize the adaptation ratios.
    It applies adaptation, evaluates the model, and returns the negative accuracy.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # Debugging Output
    
    return -accuracy  # We minimize, so return negative accuracy

# Perform Bayesian Optimization with additional settings
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Number of evaluations
    n_initial_points=5,       # Initial random explorations before GP starts
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Ensure reproducibility
    n_jobs=-1,                # Parallel execution for faster optimization
)

# Extract optimal results
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.4725345144136457
Evaluated Accuracy: 0.4725
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.457953281330288
Evaluated Accuracy: 0.4580
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.4206506345139489
Evaluated Accuracy: 0.4207
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.5252071436502505
Evaluated Accuracy: 0.5252
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.5100297743298216
Evaluated Accuracy: 0.5100
Testing b_a_r = 0.9842, c_a_r = 0.7728, f_a_r = 0.5635




----------------------> 0.344621843724771
Evaluated Accuracy: 0.3446
Testing b_a_r = 0.2146, c_a_r = 1.0000, f_a_r = 0.1928




----------------------> 0.26091172799152024
Evaluated Accuracy: 0.2609
Testing b_a_r = 0.2635, c_a_r = 0.5794, f_a_r = 0.9851




----------------------> 0.39217049584731345
Evaluated Accuracy: 0.3922
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.47349068455064675
Evaluated Accuracy: 0.4735
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.47670931152399515
Evaluated Accuracy: 0.4767
Testing b_a_r = 1.0000, c_a_r = 0.0795, f_a_r = 1.0000




----------------------> 0.48787220128697933
Evaluated Accuracy: 0.4879
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.5923




----------------------> 0.5301260238642203
Evaluated Accuracy: 0.5301
Testing b_a_r = 0.9316, c_a_r = 0.0636, f_a_r = 0.4486




----------------------> 0.5303347877946557
Evaluated Accuracy: 0.5303
Testing b_a_r = 0.9103, c_a_r = 0.6030, f_a_r = 0.0000




----------------------> 0.3864838428105661
Evaluated Accuracy: 0.3865
Testing b_a_r = 0.6517, c_a_r = 0.3789, f_a_r = 1.0000




----------------------> 0.38957764859558996
Evaluated Accuracy: 0.3896
Testing b_a_r = 1.0000, c_a_r = 0.1363, f_a_r = 0.0000




----------------------> 0.4842995522646137
Evaluated Accuracy: 0.4843
Testing b_a_r = 0.9225, c_a_r = 0.1128, f_a_r = 0.4771




----------------------> 0.5248407356292153
Evaluated Accuracy: 0.5248
Testing b_a_r = 0.1348, c_a_r = 0.9906, f_a_r = 0.9846




----------------------> 0.2825139622294957
Evaluated Accuracy: 0.2825
Testing b_a_r = 0.0030, c_a_r = 0.0223, f_a_r = 0.5364




----------------------> 0.5369619936074137
Evaluated Accuracy: 0.5370
Testing b_a_r = 0.3413, c_a_r = 0.0034, f_a_r = 0.4036




----------------------> 0.510248802582372
Evaluated Accuracy: 0.5102
Testing b_a_r = 0.7589, c_a_r = 0.0012, f_a_r = 0.7557




----------------------> 0.5250064612493791
Evaluated Accuracy: 0.5250
Testing b_a_r = 0.1042, c_a_r = 0.0526, f_a_r = 0.5498




----------------------> 0.5382920403399629
Evaluated Accuracy: 0.5383
Testing b_a_r = 0.9367, c_a_r = 0.4990, f_a_r = 0.4815




----------------------> 0.4605682117780796
Evaluated Accuracy: 0.4606
Testing b_a_r = 0.0958, c_a_r = 0.1892, f_a_r = 0.3720




----------------------> 0.5104580694110913
Evaluated Accuracy: 0.5105
Testing b_a_r = 0.1348, c_a_r = 0.1134, f_a_r = 0.2737




----------------------> 0.5064614144642473
Evaluated Accuracy: 0.5065
Testing b_a_r = 0.1829, c_a_r = 0.3014, f_a_r = 0.5625




----------------------> 0.38368699922642324
Evaluated Accuracy: 0.3837
Testing b_a_r = 0.8339, c_a_r = 0.2234, f_a_r = 0.0000




----------------------> 0.4320278002171299
Evaluated Accuracy: 0.4320
Testing b_a_r = 0.8891, c_a_r = 0.7848, f_a_r = 0.0034




----------------------> 0.3266094946217704
Evaluated Accuracy: 0.3266
Testing b_a_r = 0.1315, c_a_r = 0.7545, f_a_r = 0.9972




----------------------> 0.34652063020216944
Evaluated Accuracy: 0.3465
Testing b_a_r = 0.8834, c_a_r = 0.0297, f_a_r = 0.6127




----------------------> 0.533734742463496
Evaluated Accuracy: 0.5337
Testing b_a_r = 0.0609, c_a_r = 0.1106, f_a_r = 0.6773




----------------------> 0.5049403030694344
Evaluated Accuracy: 0.5049
Testing b_a_r = 0.2783, c_a_r = 0.4783, f_a_r = 0.9973




----------------------> 0.43500049005832764
Evaluated Accuracy: 0.4350
Testing b_a_r = 0.2893, c_a_r = 0.0656, f_a_r = 0.0176




----------------------> 0.4892216852674833
Evaluated Accuracy: 0.4892
Testing b_a_r = 0.0007, c_a_r = 0.0317, f_a_r = 0.7293




----------------------> 0.5276660457019287
Evaluated Accuracy: 0.5277
Testing b_a_r = 0.1070, c_a_r = 0.6228, f_a_r = 0.4985




----------------------> 0.3919746513585041
Evaluated Accuracy: 0.3920
Testing b_a_r = 0.6847, c_a_r = 0.1582, f_a_r = 0.5019




----------------------> 0.5140420282948612
Evaluated Accuracy: 0.5140
Testing b_a_r = 0.7760, c_a_r = 0.4239, f_a_r = 0.4708




----------------------> 0.4606064060620246
Evaluated Accuracy: 0.4606
Testing b_a_r = 0.9989, c_a_r = 0.0475, f_a_r = 0.4801




----------------------> 0.529580446483185
Evaluated Accuracy: 0.5296
Testing b_a_r = 0.0136, c_a_r = 0.0634, f_a_r = 0.5927




----------------------> 0.5394176306416628
Evaluated Accuracy: 0.5394
Testing b_a_r = 0.8837, c_a_r = 0.2280, f_a_r = 0.9949




----------------------> 0.42271947856815656
Evaluated Accuracy: 0.4227
Testing b_a_r = 0.0171, c_a_r = 0.8644, f_a_r = 0.9975




----------------------> 0.3095908388991486
Evaluated Accuracy: 0.3096
Testing b_a_r = 0.3537, c_a_r = 0.5180, f_a_r = 0.0011




----------------------> 0.41474183189244573
Evaluated Accuracy: 0.4147
Testing b_a_r = 0.0895, c_a_r = 0.0775, f_a_r = 0.5371




----------------------> 0.5346849252066439
Evaluated Accuracy: 0.5347
Testing b_a_r = 0.6403, c_a_r = 0.1505, f_a_r = 0.9976




----------------------> 0.47424710165503364
Evaluated Accuracy: 0.4742
Testing b_a_r = 0.1744, c_a_r = 0.1610, f_a_r = 0.2652




----------------------> 0.49207337686445335
Evaluated Accuracy: 0.4921
Testing b_a_r = 0.2262, c_a_r = 0.0450, f_a_r = 0.5905




----------------------> 0.5415117434239248
Evaluated Accuracy: 0.5415
Testing b_a_r = 0.8105, c_a_r = 0.4630, f_a_r = 0.7223




----------------------> 0.4429679946972204
Evaluated Accuracy: 0.4430
Testing b_a_r = 0.2782, c_a_r = 0.0502, f_a_r = 0.2587




----------------------> 0.4967980191127027
Evaluated Accuracy: 0.4968
Testing b_a_r = 0.3534, c_a_r = 0.8890, f_a_r = 0.0007




----------------------> 0.27927699857586535
Evaluated Accuracy: 0.2793
Testing b_a_r = 0.6945, c_a_r = 0.3872, f_a_r = 0.0091




----------------------> 0.43408684921904944
Evaluated Accuracy: 0.4341

✅ Optimal Values:
   - b_a_r: 0.2262
   - c_a_r: 0.0450
   - f_a_r: 0.5905
📈 Highest Accuracy Achieved: 0.5415


In [80]:
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_assamese/resnet_resluts/base/8_w_1_s_f', test_loader, M3, criterion)
# Load Model M2
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/secondary/8_w_1_s_f", test_loader, M2, criterion)
# Update Model M3's weights using a weighted combination of Model M3 and Model M2
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Assuming you have the evaluate2 function defined
# Evaluate the updated model (M3) using evaluate2 function with some arguments
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/amul/8_w_1_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.4735 ± 0.0766  0.5124 ± 0.0714  0.4737 ± 0.0766




---Accu-----pre----rec---------> 0.3708 ± 0.0661  0.3829 ± 0.0569  0.3707 ± 0.0661
---Accu-----pre----rec---------> 0.5415 ± 0.0628  0.5851 ± 0.0577  0.5417 ± 0.0627


In [81]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



!rm model_mu_path.pt
!rm model_own_path.pt

convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/asamese_sce/ressecondary/model_5/model_B_olchiki_5-shot_res.pth',map_location=torch.device('cpu')))

torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [82]:
def evaluate2(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlal=[]
    ttlac=0
    precision=[]
    recall=[]
    
    with torch.no_grad():
        for episode in range(len(clss_8)):
            support_set, query_set = data_loader.__getitem__(clss_8[tlv_cls] , clss_support_imagesss_5_shot[tlv_cls])

           
            clssa=str(clss_8[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_5_shot[tlv_cls])
            str(clss_8[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            ttlal.append(accuracy)
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            precision_overall, recall_overall, f1ss= calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)
    print(f"---Accu-----pre----rec---------> {np.mean(ttlal):.4f} ± {np.std(ttlal, ddof=1):.4f}  "
      f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
      f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")

    # print("---Accu-----pre----rec---------->",np.mean(ttlal),'±',np.std(ttlal, ddof=1), np.mean(precision),'±',np.std(precision, ddof=1),np.mean(precision),'±',np.std(precision, ddof=1),)
# %%%%%%

def evaluate3(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlac=0
    with torch.no_grad():
        for episode in range(len(clss_8)):
            support_set, query_set = data_loader.__getitem__(clss_8[tlv_cls] , clss_support_imagesss_5_shot[tlv_cls])

            
            clssa=str(clss_8[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_5_shot[tlv_cls])
            str(clss_8[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            # confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            # calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)

    print("---------------------->",ttlac/ttla)
    return ttlac/ttla

In [83]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define the search space for b_a_r, c_a_r, and f_a_r
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Define the objective function for Bayesian Optimization
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Objective function to optimize the adaptation ratios.
    It applies adaptation, evaluates the model, and returns the negative accuracy.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # Debugging Output
    
    return -accuracy  # We minimize, so return negative accuracy

# Perform Bayesian Optimization with additional settings
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Number of evaluations
    n_initial_points=5,       # Initial random explorations before GP starts
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Ensure reproducibility
    n_jobs=-1,                # Parallel execution for faster optimization
)

# Extract optimal results
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.6948963653616407
Evaluated Accuracy: 0.6949
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.5291824037181352
Evaluated Accuracy: 0.5292
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.55478062326455
Evaluated Accuracy: 0.5548
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.7230227162862918
Evaluated Accuracy: 0.7230
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.7171867950370607
Evaluated Accuracy: 0.7172
Testing b_a_r = 0.3930, c_a_r = 1.0000, f_a_r = 1.0000




----------------------> 0.449783113523244
Evaluated Accuracy: 0.4498
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.7903




----------------------> 0.7221880889950097
Evaluated Accuracy: 0.7222
Testing b_a_r = 0.8260, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.6234111661246692
Evaluated Accuracy: 0.6234
Testing b_a_r = 0.3094, c_a_r = 0.1011, f_a_r = 1.0000




----------------------> 0.7105194223901925
Evaluated Accuracy: 0.7105
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.5711




----------------------> 0.6934354156290587
Evaluated Accuracy: 0.6934
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.8546




----------------------> 0.7153104725957611
Evaluated Accuracy: 0.7153
Testing b_a_r = 0.3823, c_a_r = 1.0000, f_a_r = 0.0000




----------------------> 0.43540156642154704
Evaluated Accuracy: 0.4354
Testing b_a_r = 0.0150, c_a_r = 0.0000, f_a_r = 1.0000




----------------------> 0.7203137420775594
Evaluated Accuracy: 0.7203
Testing b_a_r = 0.0232, c_a_r = 0.0055, f_a_r = 0.8817




----------------------> 0.7228143870220086
Evaluated Accuracy: 0.7228
Testing b_a_r = 0.4790, c_a_r = 0.0000, f_a_r = 0.8409




----------------------> 0.7180220672727925
Evaluated Accuracy: 0.7180
Testing b_a_r = 0.0000, c_a_r = 0.0860, f_a_r = 0.7366




----------------------> 0.7184393585386576
Evaluated Accuracy: 0.7184
Testing b_a_r = 1.0000, c_a_r = 0.5044, f_a_r = 1.0000




----------------------> 0.5510287535360046
Evaluated Accuracy: 0.5510
Testing b_a_r = 1.0000, c_a_r = 0.0868, f_a_r = 0.8658




----------------------> 0.7286477203553421
Evaluated Accuracy: 0.7286
Testing b_a_r = 0.9875, c_a_r = 0.1173, f_a_r = 0.5117




----------------------> 0.6723950144369899
Evaluated Accuracy: 0.6724
Testing b_a_r = 0.0211, c_a_r = 0.0767, f_a_r = 0.8601




----------------------> 0.7215611012640099
Evaluated Accuracy: 0.7216
Testing b_a_r = 0.7274, c_a_r = 0.2225, f_a_r = 0.9981




----------------------> 0.6411358143145929
Evaluated Accuracy: 0.6411
Testing b_a_r = 0.0101, c_a_r = 0.6924, f_a_r = 0.5608




----------------------> 0.47770306187885947
Evaluated Accuracy: 0.4777
Testing b_a_r = 0.9583, c_a_r = 0.3834, f_a_r = 0.5936




----------------------> 0.5716693949647946
Evaluated Accuracy: 0.5717
Testing b_a_r = 0.0236, c_a_r = 0.0100, f_a_r = 0.3792




----------------------> 0.6488343921001832
Evaluated Accuracy: 0.6488
Testing b_a_r = 0.8241, c_a_r = 0.0510, f_a_r = 0.9027




----------------------> 0.727818293310156
Evaluated Accuracy: 0.7278
Testing b_a_r = 0.0662, c_a_r = 0.7178, f_a_r = 0.0042




----------------------> 0.4875006144265733
Evaluated Accuracy: 0.4875
Testing b_a_r = 0.0443, c_a_r = 0.0666, f_a_r = 0.8365




----------------------> 0.7253111094021102
Evaluated Accuracy: 0.7253
Testing b_a_r = 0.1205, c_a_r = 0.1422, f_a_r = 0.0008




----------------------> 0.5934078396761688
Evaluated Accuracy: 0.5934
Testing b_a_r = 0.9519, c_a_r = 0.7649, f_a_r = 0.9972




----------------------> 0.47145563758760667
Evaluated Accuracy: 0.4715
Testing b_a_r = 0.9406, c_a_r = 0.0963, f_a_r = 0.7682




----------------------> 0.7301092987561728
Evaluated Accuracy: 0.7301
Testing b_a_r = 0.9577, c_a_r = 0.0645, f_a_r = 0.9867




----------------------> 0.7317655486544467
Evaluated Accuracy: 0.7318
Testing b_a_r = 0.8110, c_a_r = 0.9983, f_a_r = 0.5229




----------------------> 0.4516594237573935
Evaluated Accuracy: 0.4517
Testing b_a_r = 0.9445, c_a_r = 0.0650, f_a_r = 0.8549




----------------------> 0.7315630889950097
Evaluated Accuracy: 0.7316
Testing b_a_r = 0.9923, c_a_r = 0.1132, f_a_r = 0.9809




----------------------> 0.7128065093409114
Evaluated Accuracy: 0.7128
Testing b_a_r = 0.0097, c_a_r = 0.2254, f_a_r = 0.6047




----------------------> 0.663234197029919
Evaluated Accuracy: 0.6632
Testing b_a_r = 0.0041, c_a_r = 0.0570, f_a_r = 0.9972




----------------------> 0.7236450958179931
Evaluated Accuracy: 0.7236
Testing b_a_r = 0.0573, c_a_r = 0.3589, f_a_r = 0.9608




----------------------> 0.5460586366403318
Evaluated Accuracy: 0.5461
Testing b_a_r = 0.9815, c_a_r = 0.2377, f_a_r = 0.3776




----------------------> 0.6175804003619826
Evaluated Accuracy: 0.6176
Testing b_a_r = 0.9750, c_a_r = 0.0402, f_a_r = 0.9903




----------------------> 0.728649022451391
Evaluated Accuracy: 0.7286
Testing b_a_r = 0.9748, c_a_r = 0.1365, f_a_r = 0.7107




----------------------> 0.7103156565656564
Evaluated Accuracy: 0.7103
Testing b_a_r = 0.0198, c_a_r = 0.0330, f_a_r = 0.6575




----------------------> 0.7103052194520129
Evaluated Accuracy: 0.7103
Testing b_a_r = 0.9357, c_a_r = 0.0523, f_a_r = 0.7816




----------------------> 0.7301093150323735
Evaluated Accuracy: 0.7301
Testing b_a_r = 0.0274, c_a_r = 0.5225, f_a_r = 0.5275




----------------------> 0.5483379964648092
Evaluated Accuracy: 0.5483
Testing b_a_r = 0.0960, c_a_r = 0.6136, f_a_r = 0.9792




----------------------> 0.47998235659853455
Evaluated Accuracy: 0.4800
Testing b_a_r = 0.9041, c_a_r = 0.5907, f_a_r = 0.0027




----------------------> 0.5027070414099096
Evaluated Accuracy: 0.5027
Testing b_a_r = 0.9946, c_a_r = 0.8537, f_a_r = 0.3061




----------------------> 0.47729162801069014
Evaluated Accuracy: 0.4773
Testing b_a_r = 0.0055, c_a_r = 0.1370, f_a_r = 0.8369




----------------------> 0.709487511271269
Evaluated Accuracy: 0.7095
Testing b_a_r = 0.0005, c_a_r = 0.0327, f_a_r = 0.9225




----------------------> 0.7267727284934423
Evaluated Accuracy: 0.7268
Testing b_a_r = 0.9710, c_a_r = 0.0682, f_a_r = 0.7589




----------------------> 0.7292785818220882
Evaluated Accuracy: 0.7293
Testing b_a_r = 0.9844, c_a_r = 0.0778, f_a_r = 0.2008




----------------------> 0.604247687558798
Evaluated Accuracy: 0.6042

✅ Optimal Values:
   - b_a_r: 0.9577
   - c_a_r: 0.0645
   - f_a_r: 0.9867
📈 Highest Accuracy Achieved: 0.7318


In [84]:
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_assamese/resnet_resluts/base/8_w_5_s_f', test_loader, M3, criterion)
# Load Model M2
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/secondary/8_w_5_s_f", test_loader, M2, criterion)
# Update Model M3's weights using a weighted combination of Model M3 and Model M2
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Assuming you have the evaluate2 function defined
# Evaluate the updated model (M3) using evaluate2 function with some arguments
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/amul/8_w_5_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.6251 ± 0.0708  0.6429 ± 0.0746  0.6252 ± 0.0709




---Accu-----pre----rec---------> 0.4814 ± 0.0716  0.4969 ± 0.0809  0.4813 ± 0.0718
---Accu-----pre----rec---------> 0.7318 ± 0.0532  0.7436 ± 0.0530  0.7318 ± 0.0532


In [85]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



!rm model_mu_path.pt
!rm model_own_path.pt

convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/asamese_sce/ressecondary/model_10/model_B_olchiki_10-shot_res.pth',map_location=torch.device('cpu')))

torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [86]:
def evaluate2(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlal=[]
    ttlac=0
    precision=[]
    recall=[]
    
    with torch.no_grad():
        for episode in range(len(clss_8)):
            support_set, query_set = data_loader.__getitem__(clss_8[tlv_cls] , clss_support_imagesss_10_shot[tlv_cls])

           
            clssa=str(clss_8[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_10_shot[tlv_cls])
            str(clss_8[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            ttlal.append(accuracy)
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            precision_overall, recall_overall, f1ss= calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)
    print(f"---Accu-----pre----rec---------> {np.mean(ttlal):.4f} ± {np.std(ttlal, ddof=1):.4f}  "
      f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
      f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")

    # print("---Accu-----pre----rec---------->",np.mean(ttlal),'±',np.std(ttlal, ddof=1), np.mean(precision),'±',np.std(precision, ddof=1),np.mean(precision),'±',np.std(precision, ddof=1),)
# %%%%%%

def evaluate3(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlac=0
    with torch.no_grad():
        for episode in range(len(clss_8)):
            support_set, query_set = data_loader.__getitem__(clss_8[tlv_cls] , clss_support_imagesss_10_shot[tlv_cls])

            
            clssa=str(clss_8[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_10_shot[tlv_cls])
            str(clss_8[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            # confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            # calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)

    print("---------------------->",ttlac/ttla)
    return ttlac/ttla

In [87]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define the search space for b_a_r, c_a_r, and f_a_r
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Define the objective function for Bayesian Optimization
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Objective function to optimize the adaptation ratios.
    It applies adaptation, evaluates the model, and returns the negative accuracy.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # Debugging Output
    
    return -accuracy  # We minimize, so return negative accuracy

# Perform Bayesian Optimization with additional settings
search_iteration=80
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Number of evaluations
    n_initial_points=5,       # Initial random explorations before GP starts
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Ensure reproducibility
    n_jobs=-1,                # Parallel execution for faster optimization
)

# Extract optimal results
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.6939105242474546
Evaluated Accuracy: 0.6939
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.4918843686660239
Evaluated Accuracy: 0.4919
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.5509900120203185
Evaluated Accuracy: 0.5510
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.7108349687404903
Evaluated Accuracy: 0.7108
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.6808127077665593
Evaluated Accuracy: 0.6808
Testing b_a_r = 0.0250, c_a_r = 1.0000, f_a_r = 0.4630




----------------------> 0.2786863049152109
Evaluated Accuracy: 0.2787
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.4369




----------------------> 0.6812846313029989
Evaluated Accuracy: 0.6813
Testing b_a_r = 1.0000, c_a_r = 0.1096, f_a_r = 0.0000




----------------------> 0.6765431531864776
Evaluated Accuracy: 0.6765
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 1.0000




----------------------> 0.6886766959491102
Evaluated Accuracy: 0.6887
Testing b_a_r = 0.0000, c_a_r = 0.1095, f_a_r = 1.0000




----------------------> 0.7044080797920771
Evaluated Accuracy: 0.7044
Testing b_a_r = 1.0000, c_a_r = 0.1104, f_a_r = 1.0000




----------------------> 0.7089369914217808
Evaluated Accuracy: 0.7089
Testing b_a_r = 0.0000, c_a_r = 0.1010, f_a_r = 0.6484




----------------------> 0.7210807383962017
Evaluated Accuracy: 0.7211
Testing b_a_r = 0.0104, c_a_r = 0.5480, f_a_r = 0.9989




----------------------> 0.38018318989734445
Evaluated Accuracy: 0.3802
Testing b_a_r = 0.8299, c_a_r = 0.1161, f_a_r = 0.5497




----------------------> 0.7241759947132465
Evaluated Accuracy: 0.7242
Testing b_a_r = 0.0161, c_a_r = 0.1132, f_a_r = 0.4866




----------------------> 0.7187006333292639
Evaluated Accuracy: 0.7187
Testing b_a_r = 0.5665, c_a_r = 0.7353, f_a_r = 0.0004




----------------------> 0.35635154481440784
Evaluated Accuracy: 0.3564
Testing b_a_r = 0.9818, c_a_r = 0.1241, f_a_r = 0.6597




----------------------> 0.7248766565900076
Evaluated Accuracy: 0.7249
Testing b_a_r = 0.0228, c_a_r = 0.0071, f_a_r = 0.0050




----------------------> 0.6476969276211614
Evaluated Accuracy: 0.6477
Testing b_a_r = 0.9185, c_a_r = 0.1657, f_a_r = 0.3684




----------------------> 0.6946333499151777
Evaluated Accuracy: 0.6946
Testing b_a_r = 0.9772, c_a_r = 0.3131, f_a_r = 0.9891




----------------------> 0.5769458794117273
Evaluated Accuracy: 0.5769
Testing b_a_r = 0.8938, c_a_r = 0.8322, f_a_r = 0.9952




----------------------> 0.2949083382380363
Evaluated Accuracy: 0.2949
Testing b_a_r = 0.9441, c_a_r = 0.0733, f_a_r = 0.4123




----------------------> 0.7094105852905073
Evaluated Accuracy: 0.7094
Testing b_a_r = 0.1340, c_a_r = 0.1938, f_a_r = 0.0020




----------------------> 0.6403151036304521
Evaluated Accuracy: 0.6403
Testing b_a_r = 0.0529, c_a_r = 0.1427, f_a_r = 0.6743




----------------------> 0.7203468733505469
Evaluated Accuracy: 0.7203
Testing b_a_r = 0.9819, c_a_r = 0.1124, f_a_r = 0.6765




----------------------> 0.7215450178847033
Evaluated Accuracy: 0.7215
Testing b_a_r = 0.0465, c_a_r = 0.1020, f_a_r = 0.7162




----------------------> 0.7160756241424362
Evaluated Accuracy: 0.7161
Testing b_a_r = 0.9731, c_a_r = 0.1124, f_a_r = 0.6828




----------------------> 0.7215467246606038
Evaluated Accuracy: 0.7215
Testing b_a_r = 0.0505, c_a_r = 0.1088, f_a_r = 0.5126




----------------------> 0.7189378812533445
Evaluated Accuracy: 0.7189
Testing b_a_r = 0.9920, c_a_r = 0.0980, f_a_r = 0.7289




----------------------> 0.7177380420911589
Evaluated Accuracy: 0.7177
Testing b_a_r = 0.4304, c_a_r = 0.1721, f_a_r = 0.9943




----------------------> 0.6955738077319985
Evaluated Accuracy: 0.6956
Testing b_a_r = 0.0101, c_a_r = 0.1241, f_a_r = 0.7050




----------------------> 0.7220228969149873
Evaluated Accuracy: 0.7220
Testing b_a_r = 0.0046, c_a_r = 0.1116, f_a_r = 0.4233




----------------------> 0.7113196384307623
Evaluated Accuracy: 0.7113
Testing b_a_r = 0.0011, c_a_r = 0.0645, f_a_r = 0.6022




----------------------> 0.7170280172426884
Evaluated Accuracy: 0.7170
Testing b_a_r = 0.9831, c_a_r = 0.1148, f_a_r = 0.5986




----------------------> 0.7241708804594811
Evaluated Accuracy: 0.7242
Testing b_a_r = 0.9614, c_a_r = 0.0735, f_a_r = 0.9968




----------------------> 0.7039250682862209
Evaluated Accuracy: 0.7039
Testing b_a_r = 0.0088, c_a_r = 0.1246, f_a_r = 0.7813




----------------------> 0.7191674608337448
Evaluated Accuracy: 0.7192
Testing b_a_r = 0.9821, c_a_r = 0.1124, f_a_r = 0.5921




----------------------> 0.722742302956974
Evaluated Accuracy: 0.7227
Testing b_a_r = 0.9523, c_a_r = 0.1354, f_a_r = 0.7393




----------------------> 0.7227278257314992
Evaluated Accuracy: 0.7227
Testing b_a_r = 0.9905, c_a_r = 0.0952, f_a_r = 0.5205




----------------------> 0.7217984437362236
Evaluated Accuracy: 0.7218
Testing b_a_r = 0.0005, c_a_r = 0.1347, f_a_r = 0.5869




----------------------> 0.7208341275004115
Evaluated Accuracy: 0.7208
Testing b_a_r = 0.0027, c_a_r = 0.0783, f_a_r = 0.7467




----------------------> 0.708682748625924
Evaluated Accuracy: 0.7087
Testing b_a_r = 0.9890, c_a_r = 0.1309, f_a_r = 0.7268




----------------------> 0.7213009367830858
Evaluated Accuracy: 0.7213
Testing b_a_r = 0.9976, c_a_r = 0.1436, f_a_r = 0.6180




----------------------> 0.7262950116589194
Evaluated Accuracy: 0.7263
Testing b_a_r = 0.0513, c_a_r = 0.0655, f_a_r = 0.3341




----------------------> 0.6936928738765498
Evaluated Accuracy: 0.6937
Testing b_a_r = 0.1412, c_a_r = 0.9985, f_a_r = 0.0023




----------------------> 0.32751900989990756
Evaluated Accuracy: 0.3275
Testing b_a_r = 0.3586, c_a_r = 0.0001, f_a_r = 0.7516




----------------------> 0.698434364140942
Evaluated Accuracy: 0.6984
Testing b_a_r = 0.0038, c_a_r = 0.3948, f_a_r = 0.6321




----------------------> 0.46211131519534693
Evaluated Accuracy: 0.4621
Testing b_a_r = 0.5193, c_a_r = 0.9949, f_a_r = 0.9955




----------------------> 0.27249748690915027
Evaluated Accuracy: 0.2725
Testing b_a_r = 0.0446, c_a_r = 0.5843, f_a_r = 0.0001




----------------------> 0.40782879579975206
Evaluated Accuracy: 0.4078
Testing b_a_r = 0.0058, c_a_r = 0.2141, f_a_r = 0.5011




----------------------> 0.6800864867688422
Evaluated Accuracy: 0.6801
Testing b_a_r = 0.9976, c_a_r = 0.1427, f_a_r = 0.5561




----------------------> 0.7255858341244632
Evaluated Accuracy: 0.7256
Testing b_a_r = 0.9287, c_a_r = 0.1560, f_a_r = 0.5996




----------------------> 0.7189234222496764
Evaluated Accuracy: 0.7189
Testing b_a_r = 0.8796, c_a_r = 0.1174, f_a_r = 0.8736




----------------------> 0.7117899037827864
Evaluated Accuracy: 0.7118
Testing b_a_r = 0.8947, c_a_r = 0.6590, f_a_r = 0.4993




----------------------> 0.34493056276835404
Evaluated Accuracy: 0.3449
Testing b_a_r = 0.8956, c_a_r = 0.0458, f_a_r = 0.5850




----------------------> 0.7132210110794658
Evaluated Accuracy: 0.7132
Testing b_a_r = 0.9085, c_a_r = 0.1141, f_a_r = 0.1870




----------------------> 0.6829709076707126
Evaluated Accuracy: 0.6830
Testing b_a_r = 0.9738, c_a_r = 0.1366, f_a_r = 0.5119




----------------------> 0.7274965636709411
Evaluated Accuracy: 0.7275
Testing b_a_r = 0.9336, c_a_r = 0.0502, f_a_r = 0.8695




----------------------> 0.707500782019207
Evaluated Accuracy: 0.7075
Testing b_a_r = 0.9937, c_a_r = 0.1499, f_a_r = 0.5465




----------------------> 0.7215399097048735
Evaluated Accuracy: 0.7215
Testing b_a_r = 0.9517, c_a_r = 0.0854, f_a_r = 0.5793




----------------------> 0.7208332923342682
Evaluated Accuracy: 0.7208
Testing b_a_r = 0.1048, c_a_r = 0.1506, f_a_r = 0.5937




----------------------> 0.7179761312553186
Evaluated Accuracy: 0.7180
Testing b_a_r = 0.9665, c_a_r = 0.1240, f_a_r = 0.5024




----------------------> 0.7263188579300391
Evaluated Accuracy: 0.7263
Testing b_a_r = 0.2529, c_a_r = 0.0018, f_a_r = 0.6432




----------------------> 0.7048697383773723
Evaluated Accuracy: 0.7049
Testing b_a_r = 0.9976, c_a_r = 0.0034, f_a_r = 0.2266




----------------------> 0.6591383241647276
Evaluated Accuracy: 0.6591
Testing b_a_r = 0.9598, c_a_r = 0.1410, f_a_r = 0.4892




----------------------> 0.724878345144101
Evaluated Accuracy: 0.7249
Testing b_a_r = 0.0331, c_a_r = 0.0876, f_a_r = 0.8686




----------------------> 0.7029684264680248
Evaluated Accuracy: 0.7030
Testing b_a_r = 0.9916, c_a_r = 0.1184, f_a_r = 0.5605




----------------------> 0.7251258276496484
Evaluated Accuracy: 0.7251
Testing b_a_r = 0.0101, c_a_r = 0.1296, f_a_r = 0.5497




----------------------> 0.721316279544382
Evaluated Accuracy: 0.7213
Testing b_a_r = 0.9926, c_a_r = 0.0854, f_a_r = 0.5326




----------------------> 0.7206062546959113
Evaluated Accuracy: 0.7206
Testing b_a_r = 0.0951, c_a_r = 0.2315, f_a_r = 0.9632




----------------------> 0.6310114621238488
Evaluated Accuracy: 0.6310
Testing b_a_r = 0.9946, c_a_r = 0.2509, f_a_r = 0.0067




----------------------> 0.6157520807784843
Evaluated Accuracy: 0.6158
Testing b_a_r = 0.9800, c_a_r = 0.2776, f_a_r = 0.5598




----------------------> 0.5964680186032499
Evaluated Accuracy: 0.5965
Testing b_a_r = 0.9853, c_a_r = 0.0421, f_a_r = 0.0044




----------------------> 0.6648543227288493
Evaluated Accuracy: 0.6649
Testing b_a_r = 0.0254, c_a_r = 0.0177, f_a_r = 0.4538




----------------------> 0.6927243210706648
Evaluated Accuracy: 0.6927
Testing b_a_r = 0.9382, c_a_r = 0.8814, f_a_r = 0.0020




----------------------> 0.33871381984342613
Evaluated Accuracy: 0.3387
Testing b_a_r = 0.1080, c_a_r = 0.6929, f_a_r = 0.9883




----------------------> 0.3141991424817736
Evaluated Accuracy: 0.3142
Testing b_a_r = 0.0553, c_a_r = 0.1000, f_a_r = 0.0005




----------------------> 0.6793994426556703
Evaluated Accuracy: 0.6794
Testing b_a_r = 0.7750, c_a_r = 0.4316, f_a_r = 0.9986




----------------------> 0.4266162226316055
Evaluated Accuracy: 0.4266
Testing b_a_r = 0.0520, c_a_r = 0.0454, f_a_r = 0.9824




----------------------> 0.7024973502455996
Evaluated Accuracy: 0.7025
Testing b_a_r = 0.0273, c_a_r = 0.2547, f_a_r = 0.2289




----------------------> 0.6264817214019374
Evaluated Accuracy: 0.6265

✅ Optimal Values:
   - b_a_r: 0.9738
   - c_a_r: 0.1366
   - f_a_r: 0.5119
📈 Highest Accuracy Achieved: 0.7275


In [88]:
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_assamese/resnet_resluts/base/8_w_10_s_f', test_loader, M3, criterion)
# Load Model M2
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/secondary/8_w_10_s_f", test_loader, M2, criterion)
# Update Model M3's weights using a weighted combination of Model M3 and Model M2
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Assuming you have the evaluate2 function defined
# Evaluate the updated model (M3) using evaluate2 function with some arguments
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/amul/8_w_10_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.6472 ± 0.0532  0.6571 ± 0.0548  0.6473 ± 0.0533




---Accu-----pre----rec---------> 0.4293 ± 0.0635  0.4302 ± 0.0706  0.4294 ± 0.0635
---Accu-----pre----rec---------> 0.7275 ± 0.0492  0.7375 ± 0.0493  0.7276 ± 0.0492


In [None]:
###############################################   12 way 1 shot

In [90]:
# clss_20=
# clss_1_shot_20_w=
# clss_5_shot_20_w=
# clss_10_shot_20_w=

In [91]:
clss_support_imagesss_10_shot=[[[4, 37, 38, 22, 13, 16, 5, 43, 27, 21], [23, 20, 4, 37, 0, 3, 34, 29, 6, 43], [38, 41, 12, 27, 43, 24, 22, 23, 8, 26], [41, 37, 8, 0, 20, 14, 21, 39, 7, 35], [43, 33, 21, 37, 25, 32, 34, 9, 11, 28], [25, 13, 27, 39, 1, 16, 28, 15, 38, 41], [1, 44, 32, 24, 12, 3, 40, 26, 33, 39], [37, 32, 30, 26, 39, 24, 7, 17, 36, 35], [40, 37, 16, 5, 25, 1, 4, 30, 38, 33], [8, 44, 10, 31, 3, 14, 0, 16, 15, 21], [27, 15, 31, 14, 1, 6, 22, 39, 21, 35], [35, 31, 12, 8, 22, 2, 30, 37, 21, 36]], [[20, 4, 39, 43, 12, 31, 14, 8, 30, 38], [5, 9, 17, 41, 37, 1, 34, 32, 21, 25], [15, 5, 35, 41, 42, 34, 4, 1, 37, 22], [36, 21, 9, 26, 5, 32, 24, 35, 11, 7], [40, 2, 30, 43, 14, 32, 0, 13, 12, 4], [29, 42, 32, 34, 8, 26, 3, 7, 33, 25], [22, 36, 25, 17, 28, 20, 39, 38, 27, 9], [1, 13, 20, 9, 40, 44, 35, 37, 10, 12], [10, 39, 27, 28, 26, 1, 23, 12, 40, 18], [22, 4, 37, 17, 18, 0, 44, 39, 24, 33], [35, 36, 13, 28, 26, 6, 0, 39, 43, 7], [42, 20, 25, 11, 8, 16, 22, 35, 31, 13]], [[21, 6, 0, 43, 34, 29, 33, 19, 2, 27], [7, 6, 38, 18, 11, 20, 30, 1, 33, 21], [25, 24, 32, 35, 28, 27, 2, 8, 31, 9], [24, 0, 28, 5, 6, 22, 32, 3, 2, 12], [17, 19, 32, 13, 0, 39, 25, 37, 15, 9], [14, 40, 13, 32, 33, 28, 5, 39, 16, 12], [31, 7, 42, 25, 39, 19, 27, 18, 15, 32], [27, 33, 28, 31, 38, 22, 42, 20, 19, 34], [40, 27, 43, 42, 15, 4, 35, 34, 20, 21], [12, 41, 27, 36, 9, 16, 10, 2, 43, 32], [42, 30, 41, 34, 7, 20, 9, 28, 36, 40], [13, 42, 36, 41, 44, 5, 24, 7, 9, 4]], [[33, 28, 15, 41, 3, 24, 4, 32, 39, 8], [29, 36, 5, 32, 34, 7, 15, 1, 10, 44], [6, 25, 21, 7, 0, 8, 40, 14, 12, 41], [37, 27, 10, 40, 12, 42, 7, 24, 21, 44], [43, 18, 12, 42, 25, 26, 32, 21, 13, 29], [10, 28, 27, 8, 32, 13, 16, 3, 30, 22], [6, 31, 27, 7, 1, 20, 30, 25, 38, 18], [44, 30, 32, 25, 13, 20, 16, 4, 9, 31], [17, 21, 6, 41, 20, 36, 43, 1, 9, 44], [12, 25, 18, 35, 44, 33, 10, 41, 21, 37], [21, 12, 29, 44, 30, 5, 34, 40, 42, 15], [10, 42, 41, 9, 26, 25, 19, 27, 37, 34]], [[1, 39, 41, 30, 6, 3, 10, 14, 25, 16], [6, 9, 28, 25, 5, 20, 4, 43, 0, 24], [21, 42, 7, 18, 2, 28, 9, 32, 1, 20], [43, 5, 37, 13, 25, 31, 12, 19, 26, 33], [13, 35, 29, 44, 31, 12, 34, 10, 43, 33], [17, 21, 20, 34, 8, 6, 35, 37, 1, 33], [38, 32, 13, 23, 19, 24, 2, 31, 43, 36], [35, 37, 12, 7, 24, 4, 38, 21, 16, 8], [10, 15, 32, 38, 33, 24, 18, 39, 1, 34], [4, 17, 8, 5, 43, 38, 22, 3, 13, 11], [17, 2, 1, 32, 34, 39, 7, 13, 16, 41], [28, 33, 16, 43, 29, 5, 3, 14, 13, 26]], [[25, 9, 3, 12, 33, 30, 2, 21, 43, 41], [37, 4, 27, 39, 2, 16, 1, 11, 10, 35], [24, 22, 10, 7, 18, 40, 31, 25, 14, 32], [3, 35, 26, 4, 21, 5, 30, 8, 38, 29], [12, 14, 24, 0, 16, 3, 43, 5, 41, 9], [20, 4, 27, 0, 2, 36, 12, 32, 29, 40], [21, 7, 24, 8, 36, 29, 17, 35, 14, 11], [40, 17, 38, 29, 2, 14, 36, 22, 37, 35], [30, 16, 1, 36, 18, 13, 28, 43, 39, 5], [11, 0, 19, 14, 33, 40, 30, 20, 13, 34], [16, 20, 5, 9, 37, 25, 23, 10, 17, 14], [20, 7, 44, 30, 29, 4, 39, 28, 26, 22]], [[34, 22, 31, 32, 11, 23, 42, 20, 15, 5], [1, 3, 36, 23, 43, 34, 16, 14, 21, 40], [20, 6, 42, 23, 25, 8, 31, 22, 7, 26], [37, 27, 35, 8, 42, 32, 16, 43, 36, 24], [44, 40, 20, 1, 14, 13, 3, 22, 33, 23], [18, 27, 34, 7, 11, 33, 6, 30, 43, 31], [44, 2, 0, 31, 13, 16, 9, 11, 20, 29], [25, 22, 28, 3, 4, 18, 31, 30, 23, 34], [38, 2, 18, 33, 17, 12, 32, 43, 34, 5], [13, 23, 19, 5, 36, 35, 1, 34, 32, 6], [44, 15, 25, 2, 43, 39, 24, 30, 42, 31], [20, 31, 30, 38, 29, 35, 16, 9, 26, 22]], [[42, 5, 41, 16, 24, 28, 26, 36, 17, 20], [39, 27, 33, 22, 37, 9, 11, 31, 38, 14], [44, 3, 17, 4, 35, 39, 10, 21, 18, 32], [26, 3, 41, 12, 38, 24, 25, 10, 0, 37], [29, 24, 32, 0, 36, 25, 19, 6, 3, 10], [31, 5, 21, 26, 14, 39, 22, 18, 15, 27], [3, 42, 26, 28, 41, 5, 10, 1, 15, 32], [41, 44, 12, 27, 21, 33, 38, 9, 13, 0], [9, 32, 22, 6, 20, 16, 7, 41, 24, 8], [30, 37, 11, 35, 23, 1, 10, 41, 28, 6], [1, 8, 12, 0, 4, 11, 16, 40, 2, 22], [38, 31, 2, 28, 33, 23, 44, 16, 43, 19]], [[41, 13, 0, 2, 42, 30, 43, 10, 14, 37], [38, 33, 9, 12, 21, 1, 8, 23, 13, 37], [33, 42, 24, 5, 36, 13, 11, 16, 19, 2], [29, 12, 31, 10, 32, 7, 41, 27, 33, 39], [3, 13, 41, 16, 10, 28, 42, 12, 22, 14], [5, 35, 28, 0, 33, 22, 40, 30, 36, 24], [16, 6, 38, 1, 26, 24, 39, 18, 7, 37], [37, 18, 12, 27, 24, 11, 29, 4, 21, 30], [18, 30, 0, 19, 34, 43, 22, 33, 3, 36], [30, 44, 38, 4, 40, 0, 32, 19, 23, 27], [26, 7, 22, 41, 4, 9, 5, 29, 32, 43], [44, 40, 5, 10, 11, 6, 25, 27, 23, 34]], [[28, 37, 1, 16, 19, 34, 3, 30, 44, 2], [10, 1, 8, 7, 12, 34, 20, 11, 32, 25], [2, 5, 16, 37, 32, 17, 6, 7, 28, 4], [15, 0, 12, 9, 19, 41, 1, 38, 28, 31], [15, 33, 17, 41, 9, 39, 22, 29, 10, 14], [44, 8, 41, 39, 4, 19, 29, 3, 12, 30], [33, 36, 26, 3, 9, 7, 21, 31, 8, 34], [36, 9, 42, 40, 33, 30, 6, 17, 1, 3], [21, 40, 39, 28, 6, 27, 18, 7, 20, 36], [42, 7, 44, 35, 28, 6, 8, 30, 1, 33], [11, 35, 26, 6, 44, 32, 8, 39, 13, 27], [10, 40, 28, 20, 22, 8, 29, 2, 33, 21]], [[20, 15, 39, 16, 5, 42, 24, 3, 23, 12], [29, 40, 21, 44, 36, 16, 33, 38, 15, 31], [19, 4, 9, 43, 11, 3, 22, 27, 5, 24], [36, 15, 20, 37, 25, 41, 16, 17, 21, 35], [10, 16, 41, 29, 14, 22, 37, 38, 31, 9], [9, 13, 21, 11, 4, 39, 5, 28, 37, 16], [41, 16, 21, 3, 24, 32, 19, 1, 34, 44], [3, 20, 0, 30, 41, 5, 4, 6, 28, 9], [32, 20, 27, 34, 8, 31, 0, 17, 18, 13], [21, 2, 15, 20, 32, 38, 12, 33, 17, 23], [33, 36, 31, 34, 22, 6, 12, 11, 20, 29], [13, 27, 32, 34, 26, 44, 17, 8, 21, 0]], [[34, 31, 35, 29, 42, 13, 41, 32, 37, 43], [17, 9, 0, 1, 26, 44, 16, 25, 12, 31], [23, 27, 6, 25, 8, 37, 15, 43, 30, 32], [1, 5, 16, 21, 15, 44, 35, 4, 28, 20], [23, 44, 35, 4, 12, 22, 34, 29, 26, 37], [17, 9, 20, 44, 21, 41, 13, 3, 35, 5], [17, 37, 41, 24, 28, 14, 16, 23, 44, 30], [26, 25, 9, 5, 44, 35, 16, 2, 6, 15], [42, 39, 0, 36, 11, 12, 37, 4, 32, 19], [35, 21, 31, 44, 10, 13, 0, 34, 29, 7], [14, 8, 6, 13, 5, 28, 35, 39, 1, 38], [2, 18, 26, 34, 17, 29, 24, 0, 11, 7]], [[8, 32, 11, 36, 17, 1, 24, 21, 6, 20], [40, 23, 13, 35, 37, 3, 22, 24, 21, 28], [3, 35, 8, 15, 25, 28, 37, 20, 5, 43], [41, 44, 43, 39, 29, 1, 7, 3, 6, 13], [40, 21, 19, 23, 7, 10, 14, 32, 24, 15], [33, 23, 21, 35, 19, 39, 28, 13, 25, 0], [5, 1, 40, 35, 11, 13, 34, 20, 10, 33], [24, 29, 27, 39, 44, 30, 18, 26, 1, 33], [0, 16, 34, 5, 8, 2, 44, 20, 38, 42], [33, 13, 20, 34, 31, 41, 6, 17, 0, 26], [7, 29, 28, 33, 34, 21, 26, 2, 40, 36], [16, 28, 44, 24, 29, 42, 25, 6, 10, 11]], [[7, 1, 13, 3, 41, 10, 11, 12, 25, 27], [15, 35, 26, 29, 21, 10, 27, 20, 18, 4], [28, 20, 23, 34, 39, 36, 0, 31, 15, 18], [24, 27, 33, 11, 2, 18, 42, 14, 26, 9], [9, 30, 35, 37, 24, 19, 44, 8, 10, 33], [7, 19, 25, 36, 26, 4, 16, 8, 15, 12], [24, 9, 15, 44, 22, 8, 19, 5, 27, 20], [34, 15, 35, 24, 27, 10, 25, 19, 12, 30], [27, 28, 12, 34, 0, 9, 1, 33, 42, 2], [7, 8, 36, 42, 25, 23, 17, 31, 33, 41], [15, 31, 0, 6, 25, 20, 32, 12, 35, 43], [6, 14, 28, 0, 26, 8, 9, 35, 5, 1]], [[31, 37, 25, 10, 28, 14, 34, 21, 39, 7], [19, 25, 7, 29, 18, 4, 32, 40, 42, 30], [44, 26, 11, 32, 19, 8, 16, 37, 29, 40], [13, 44, 38, 33, 0, 40, 22, 15, 1, 32], [34, 2, 28, 14, 23, 36, 39, 38, 5, 13], [7, 0, 26, 40, 25, 43, 37, 35, 15, 11], [35, 8, 30, 20, 43, 32, 29, 0, 12, 18], [33, 20, 25, 36, 7, 14, 21, 42, 37, 30], [41, 37, 13, 25, 38, 20, 29, 15, 30, 17], [42, 4, 14, 43, 12, 0, 13, 11, 40, 29], [41, 23, 22, 40, 21, 12, 0, 37, 16, 44], [41, 43, 2, 32, 16, 11, 8, 4, 1, 17]]]



clss_support_imagesss_5_shot=[[[4, 37, 38, 22, 13], [23, 20, 4, 37, 0], [38, 41, 12, 27, 43], [41, 37, 8, 0, 20], [43, 33, 21, 37, 25], [25, 13, 27, 39, 1], [1, 44, 32, 24, 12], [37, 32, 30, 26, 39], [40, 37, 16, 5, 25], [8, 44, 10, 31, 3], [27, 15, 31, 14, 1], [35, 31, 12, 8, 22]], [[20, 4, 39, 43, 12], [5, 9, 17, 41, 37], [15, 5, 35, 41, 42], [36, 21, 9, 26, 5], [40, 2, 30, 43, 14], [29, 42, 32, 34, 8], [22, 36, 25, 17, 28], [1, 13, 20, 9, 40], [10, 39, 27, 28, 26], [22, 4, 37, 17, 18], [35, 36, 13, 28, 26], [42, 20, 25, 11, 8]], [[21, 6, 0, 43, 34], [7, 6, 38, 18, 11], [25, 24, 32, 35, 28], [24, 0, 28, 5, 6], [17, 19, 32, 13, 0], [14, 40, 13, 32, 33], [31, 7, 42, 25, 39], [27, 33, 28, 31, 38], [40, 27, 43, 42, 15], [12, 41, 27, 36, 9], [42, 30, 41, 34, 7], [13, 42, 36, 41, 44]], [[33, 28, 15, 41, 3], [29, 36, 5, 32, 34], [6, 25, 21, 7, 0], [37, 27, 10, 40, 12], [43, 18, 12, 42, 25], [10, 28, 27, 8, 32], [6, 31, 27, 7, 1], [44, 30, 32, 25, 13], [17, 21, 6, 41, 20], [12, 25, 18, 35, 44], [21, 12, 29, 44, 30], [10, 42, 41, 9, 26]], [[1, 39, 41, 30, 6], [6, 9, 28, 25, 5], [21, 42, 7, 18, 2], [43, 5, 37, 13, 25], [13, 35, 29, 44, 31], [17, 21, 20, 34, 8], [38, 32, 13, 23, 19], [35, 37, 12, 7, 24], [10, 15, 32, 38, 33], [4, 17, 8, 5, 43], [17, 2, 1, 32, 34], [28, 33, 16, 43, 29]], [[25, 9, 3, 12, 33], [37, 4, 27, 39, 2], [24, 22, 10, 7, 18], [3, 35, 26, 4, 21], [12, 14, 24, 0, 16], [20, 4, 27, 0, 2], [21, 7, 24, 8, 36], [40, 17, 38, 29, 2], [30, 16, 1, 36, 18], [11, 0, 19, 14, 33], [16, 20, 5, 9, 37], [20, 7, 44, 30, 29]], [[34, 22, 31, 32, 11], [1, 3, 36, 23, 43], [20, 6, 42, 23, 25], [37, 27, 35, 8, 42], [44, 40, 20, 1, 14], [18, 27, 34, 7, 11], [44, 2, 0, 31, 13], [25, 22, 28, 3, 4], [38, 2, 18, 33, 17], [13, 23, 19, 5, 36], [44, 15, 25, 2, 43], [20, 31, 30, 38, 29]], [[42, 5, 41, 16, 24], [39, 27, 33, 22, 37], [44, 3, 17, 4, 35], [26, 3, 41, 12, 38], [29, 24, 32, 0, 36], [31, 5, 21, 26, 14], [3, 42, 26, 28, 41], [41, 44, 12, 27, 21], [9, 32, 22, 6, 20], [30, 37, 11, 35, 23], [1, 8, 12, 0, 4], [38, 31, 2, 28, 33]], [[41, 13, 0, 2, 42], [38, 33, 9, 12, 21], [33, 42, 24, 5, 36], [29, 12, 31, 10, 32], [3, 13, 41, 16, 10], [5, 35, 28, 0, 33], [16, 6, 38, 1, 26], [37, 18, 12, 27, 24], [18, 30, 0, 19, 34], [30, 44, 38, 4, 40], [26, 7, 22, 41, 4], [44, 40, 5, 10, 11]], [[28, 37, 1, 16, 19], [10, 1, 8, 7, 12], [2, 5, 16, 37, 32], [15, 0, 12, 9, 19], [15, 33, 17, 41, 9], [44, 8, 41, 39, 4], [33, 36, 26, 3, 9], [36, 9, 42, 40, 33], [21, 40, 39, 28, 6], [42, 7, 44, 35, 28], [11, 35, 26, 6, 44], [10, 40, 28, 20, 22]], [[20, 15, 39, 16, 5], [29, 40, 21, 44, 36], [19, 4, 9, 43, 11], [36, 15, 20, 37, 25], [10, 16, 41, 29, 14], [9, 13, 21, 11, 4], [41, 16, 21, 3, 24], [3, 20, 0, 30, 41], [32, 20, 27, 34, 8], [21, 2, 15, 20, 32], [33, 36, 31, 34, 22], [13, 27, 32, 34, 26]], [[34, 31, 35, 29, 42], [17, 9, 0, 1, 26], [23, 27, 6, 25, 8], [1, 5, 16, 21, 15], [23, 44, 35, 4, 12], [17, 9, 20, 44, 21], [17, 37, 41, 24, 28], [26, 25, 9, 5, 44], [42, 39, 0, 36, 11], [35, 21, 31, 44, 10], [14, 8, 6, 13, 5], [2, 18, 26, 34, 17]], [[8, 32, 11, 36, 17], [40, 23, 13, 35, 37], [3, 35, 8, 15, 25], [41, 44, 43, 39, 29], [40, 21, 19, 23, 7], [33, 23, 21, 35, 19], [5, 1, 40, 35, 11], [24, 29, 27, 39, 44], [0, 16, 34, 5, 8], [33, 13, 20, 34, 31], [7, 29, 28, 33, 34], [16, 28, 44, 24, 29]], [[7, 1, 13, 3, 41], [15, 35, 26, 29, 21], [28, 20, 23, 34, 39], [24, 27, 33, 11, 2], [9, 30, 35, 37, 24], [7, 19, 25, 36, 26], [24, 9, 15, 44, 22], [34, 15, 35, 24, 27], [27, 28, 12, 34, 0], [7, 8, 36, 42, 25], [15, 31, 0, 6, 25], [6, 14, 28, 0, 26]], [[31, 37, 25, 10, 28], [19, 25, 7, 29, 18], [44, 26, 11, 32, 19], [13, 44, 38, 33, 0], [34, 2, 28, 14, 23], [7, 0, 26, 40, 25], [35, 8, 30, 20, 43], [33, 20, 25, 36, 7], [41, 37, 13, 25, 38], [42, 4, 14, 43, 12], [41, 23, 22, 40, 21], [41, 43, 2, 32, 16]]]


clss_12=[['assamese1_PAC', 'assamese1_KTA', 'assamese1_A', 'assamese1_JA', 'assamese1_DXA', 'assamese1_KHA', 'assamese1_OI', 'assamese1_OU', 'assamese1_E', 'assamese1_CAY', 'assamese1_DHRA', 'assamese1_XAT'],
         ['assamese1_CCA', 'assamese1_PA', 'assamese1_DHA', 'assamese1_A', 'assamese1_PAC', 'assamese1_GA', 'assamese1_MTA', 'assamese1_CA', 'assamese1_EE', 'assamese1_CAY', 'assamese1_NIYA', 'assamese1_AA'],
        ['assamese1_OI', 'assamese1_EK', 'assamese1_MNA', 'assamese1_GA', 'assamese1_TINI', 'assamese1_CARI', 'assamese1_KHA', 'assamese1_DXA', 'assamese1_ANSR', 'assamese1_CA', 'assamese1_KTA', 'assamese1_AJA'],
         ['assamese1_ANSR', 'assamese1_HA', 'assamese1_AE', 'assamese1_AA', 'assamese1_MDHA', 'assamese1_DHA', 'assamese1_CAY', 'assamese1_O', 'assamese1_ATH', 'assamese1_KHA', 'assamese1_BXG', 'assamese1_NAA'],
          ['assamese1_E', 'assamese1_OU', 'assamese1_TINI', 'assamese1_GHA', 'assamese1_CCA', 'assamese1_NA', 'assamese1_NG', 'assamese1_MNA', 'assamese1_PAC', 'assamese1_AE', 'assamese1_TXA', 'assamese1_BXG'],
          ['assamese1_TA', 'assamese1_MDHA', 'assamese1_DA', 'assamese1_TINI', 'assamese1_KA', 'assamese1_EE', 'assamese1_GHA', 'assamese1_OI', 'assamese1_AA', 'assamese1_DXA', 'assamese1_DHRA', 'assamese1_CCA'],
          ['assamese1_SUNYA', 'assamese1_KHA', 'assamese1_ATH', 'assamese1_AA', 'assamese1_CAY', 'assamese1_DUI', 'assamese1_NG', 'assamese1_MA', 'assamese1_CA', 'assamese1_MNA', 'assamese1_REE', 'assamese1_KHYA'],
         ['assamese1_KHA', 'assamese1_NAA', 'assamese1_OU', 'assamese1_DRA', 'assamese1_DXA', 'assamese1_CBN', 'assamese1_SUNYA', 'assamese1_BXG', 'assamese1_AYA', 'assamese1_BHA', 'assamese1_JHA', 'assamese1_AE'],
        ['assamese1_CAY', 'assamese1_E', 'assamese1_XAT', 'assamese1_KTA', 'assamese1_JHA', 'assamese1_PHA', 'assamese1_KHA', 'assamese1_TXA', 'assamese1_UU', 'assamese1_CA', 'assamese1_DUI', 'assamese1_AJA'],
         ['assamese1_EK', 'assamese1_CA', 'assamese1_TINI', 'assamese1_DRA', 'assamese1_DXA', 'assamese1_CAY', 'assamese1_JA', 'assamese1_HA', 'assamese1_CCA', 'assamese1_O', 'assamese1_NIYA', 'assamese1_ATH'],
        ['assamese1_THA', 'assamese1_DRA', 'assamese1_DHRA', 'assamese1_BA', 'assamese1_BHA', 'assamese1_KTA', 'assamese1_NAA', 'assamese1_PAC', 'assamese1_E', 'assamese1_CAY', 'assamese1_JHA', 'assamese1_ANSR'],
         ['assamese1_BXG', 'assamese1_REE', 'assamese1_CA', 'assamese1_TR', 'assamese1_NAA', 'assamese1_THA', 'assamese1_LA', 'assamese1_U', 'assamese1_CCA', 'assamese1_DUI', 'assamese1_MNA', 'assamese1_MDA'],
        ['assamese1_AE', 'assamese1_MXA', 'assamese1_O', 'assamese1_TR', 'assamese1_E', 'assamese1_PA', 'assamese1_CCA', 'assamese1_GA', 'assamese1_MTHA', 'assamese1_ANSR', 'assamese1_MDHA', 'assamese1_HA'],
         ['assamese1_EK', 'assamese1_CAY', 'assamese1_U', 'assamese1_CCA', 'assamese1_BXG', 'assamese1_ANSR', 'assamese1_NIYA', 'assamese1_PHA', 'assamese1_GA', 'assamese1_DHA', 'assamese1_A', 'assamese1_MDHA'],
         ['assamese1_NG', 'assamese1_NA', 'assamese1_NIYA', 'assamese1_ATH', 'assamese1_MNA', 'assamese1_DXA', 'assamese1_O', 'assamese1_MTA', 'assamese1_CAY', 'assamese1_SUNYA', 'assamese1_AE', 'assamese1_XAT'],
    ]
clss_support_imagesss_1_shot=[
    [[4], [23], [38], [41], [43], [25], [1], [37], [40], [8], [27], [35]],
    [[20], [5], [15], [36], [40], [29], [22], [1], [10], [22], [35], [42]],
    [[21], [7], [25], [24], [17], [14], [31], [27], [40], [12], [42], [13]],
    [[33], [29], [6], [37], [43], [10], [6], [44], [17], [12], [21], [10]],
    [[1], [6], [21], [43], [13], [17], [38], [35], [10], [4], [17], [28]],
    [[25], [37], [24], [3], [12], [20], [21], [40], [30], [11], [16], [20]],
    [[34], [1], [20], [37], [44], [18], [44], [25], [38], [13], [44], [20]],
    [[42], [39], [44], [26], [29], [31], [3], [41], [9], [30], [1], [38]],
     [[41], [38], [33], [29], [3], [5], [16], [37], [18], [30], [26], [44]],
    [[28], [10], [2], [15], [15], [44], [33], [36], [21], [42], [11], [10]],
    [[20], [29], [19], [36], [10], [9], [41], [3], [32], [21], [33], [13]],
     [[34], [17], [23], [1], [23], [17], [17], [26], [42], [35], [14], [2]],
    [[8], [40], [3], [41], [40], [33], [5], [24], [0], [33], [7], [16]],
    [[7], [15], [28], [24], [9], [7], [24], [34], [27], [7], [15], [6]],
    [[31], [19], [44], [13], [34], [7], [35], [33], [41], [42], [41], [41]]
 ]














In [92]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



!rm model_mu_path.pt
!rm model_own_path.pt

convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/asamese_sce/ressecondary/model_1/model_B_olchiki_1-shot_res.pth',map_location=torch.device('cpu')))

torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')




In [93]:

def evaluate2(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlal=[]
    ttlac=0
    precision=[]
    recall=[]
    
    with torch.no_grad():
        for episode in range(len(clss_12)):
            support_set, query_set = data_loader.__getitem__(clss_12[tlv_cls] , clss_support_imagesss_1_shot[tlv_cls])

           
            clssa=str(clss_12[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_1_shot[tlv_cls])
            str(clss_12[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            ttlal.append(accuracy)
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            precision_overall, recall_overall, f1ss= calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)
    print(f"---Accu-----pre----rec---------> {np.mean(ttlal):.4f} ± {np.std(ttlal, ddof=1):.4f}  "
      f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
      f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")

    # print("---Accu-----pre----rec---------->",np.mean(ttlal),'±',np.std(ttlal, ddof=1), np.mean(precision),'±',np.std(precision, ddof=1),np.mean(precision),'±',np.std(precision, ddof=1),)
# %%%%%%

def evaluate3(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlac=0
    with torch.no_grad():
        for episode in range(len(clss_12)):
            support_set, query_set = data_loader.__getitem__(clss_12[tlv_cls] , clss_support_imagesss_1_shot[tlv_cls])

            
            clssa=str(clss_12[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_1_shot[tlv_cls])
            str(clss_12[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            # confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            # calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)

    print("---------------------->",ttlac/ttla)
    return ttlac/ttla

In [94]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define the search space for b_a_r, c_a_r, and f_a_r
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Define the objective function for Bayesian Optimization
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Objective function to optimize the adaptation ratios.
    It applies adaptation, evaluates the model, and returns the negative accuracy.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # Debugging Output
    
    return -accuracy  # We minimize, so return negative accuracy

# Perform Bayesian Optimization with additional settings
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Number of evaluations
    n_initial_points=5,       # Initial random explorations before GP starts
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Ensure reproducibility
    n_jobs=-1,                # Parallel execution for faster optimization
)

# Extract optimal results
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.41042746791273654
Evaluated Accuracy: 0.4104
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.3542030418383702
Evaluated Accuracy: 0.3542
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.32782362969277856
Evaluated Accuracy: 0.3278
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.45600101697944606
Evaluated Accuracy: 0.4560
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.41774503422212944
Evaluated Accuracy: 0.4177
Testing b_a_r = 0.0647, c_a_r = 0.0000, f_a_r = 0.5869




----------------------> 0.4332703665013071
Evaluated Accuracy: 0.4333
Testing b_a_r = 0.0928, c_a_r = 0.0751, f_a_r = 0.0000




----------------------> 0.41103539911483583
Evaluated Accuracy: 0.4110
Testing b_a_r = 0.5645, c_a_r = 0.9982, f_a_r = 0.9612




----------------------> 0.19589275389226612
Evaluated Accuracy: 0.1959
Testing b_a_r = 0.9338, c_a_r = 0.0000, f_a_r = 0.8520




----------------------> 0.422417480110113
Evaluated Accuracy: 0.4224
Testing b_a_r = 0.0811, c_a_r = 0.6253, f_a_r = 1.0000




----------------------> 0.26016036291026123
Evaluated Accuracy: 0.2602
Testing b_a_r = 0.4083, c_a_r = 0.0000, f_a_r = 0.0883




----------------------> 0.3826210299823944
Evaluated Accuracy: 0.3826
Testing b_a_r = 0.6974, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.38186178344257204
Evaluated Accuracy: 0.3819
Testing b_a_r = 0.0500, c_a_r = 0.0939, f_a_r = 0.5471




----------------------> 0.44741251482975614
Evaluated Accuracy: 0.4474
Testing b_a_r = 0.0000, c_a_r = 0.0991, f_a_r = 1.0000




----------------------> 0.42154389158605465
Evaluated Accuracy: 0.4215
Testing b_a_r = 0.8720, c_a_r = 0.6998, f_a_r = 0.0000




----------------------> 0.2475280816764854
Evaluated Accuracy: 0.2475
Testing b_a_r = 0.9978, c_a_r = 0.3833, f_a_r = 1.0000




----------------------> 0.30801302133431296
Evaluated Accuracy: 0.3080
Testing b_a_r = 0.8556, c_a_r = 0.0647, f_a_r = 0.5249




----------------------> 0.4570073062978766
Evaluated Accuracy: 0.4570
Testing b_a_r = 0.9698, c_a_r = 0.1739, f_a_r = 0.0046




----------------------> 0.40750197313406494
Evaluated Accuracy: 0.4075
Testing b_a_r = 0.5579, c_a_r = 0.9976, f_a_r = 0.0232




----------------------> 0.17327450073587938
Evaluated Accuracy: 0.1733
Testing b_a_r = 0.9244, c_a_r = 0.1069, f_a_r = 0.4990




----------------------> 0.44413421428373767
Evaluated Accuracy: 0.4441
Testing b_a_r = 0.0951, c_a_r = 0.0534, f_a_r = 0.6396




----------------------> 0.45587332768240046
Evaluated Accuracy: 0.4559
Testing b_a_r = 0.5859, c_a_r = 0.5481, f_a_r = 0.0188




----------------------> 0.30243976891621854
Evaluated Accuracy: 0.3024
Testing b_a_r = 0.9121, c_a_r = 0.8120, f_a_r = 0.9988




----------------------> 0.21382402139125212
Evaluated Accuracy: 0.2138
Testing b_a_r = 0.8857, c_a_r = 0.0530, f_a_r = 0.5143




----------------------> 0.4576386176175638
Evaluated Accuracy: 0.4576
Testing b_a_r = 0.6821, c_a_r = 0.2501, f_a_r = 0.9961




----------------------> 0.32064599330686533
Evaluated Accuracy: 0.3206
Testing b_a_r = 0.4494, c_a_r = 0.4994, f_a_r = 0.9993




----------------------> 0.3272042181762928
Evaluated Accuracy: 0.3272
Testing b_a_r = 0.9827, c_a_r = 0.1054, f_a_r = 0.7541




----------------------> 0.4351805445527482
Evaluated Accuracy: 0.4352
Testing b_a_r = 0.2438, c_a_r = 0.1558, f_a_r = 0.4013




----------------------> 0.44021337617905165
Evaluated Accuracy: 0.4402
Testing b_a_r = 0.0540, c_a_r = 0.0358, f_a_r = 0.6191




----------------------> 0.450194599701981
Evaluated Accuracy: 0.4502
Testing b_a_r = 0.6291, c_a_r = 0.2164, f_a_r = 0.3175




----------------------> 0.4081411411095648
Evaluated Accuracy: 0.4081
Testing b_a_r = 0.2771, c_a_r = 0.0592, f_a_r = 0.3954




----------------------> 0.443744937868584
Evaluated Accuracy: 0.4437
Testing b_a_r = 0.9250, c_a_r = 0.0516, f_a_r = 0.8186




----------------------> 0.44867561751923224
Evaluated Accuracy: 0.4487
Testing b_a_r = 0.6855, c_a_r = 0.8500, f_a_r = 0.0019




----------------------> 0.19966894442015873
Evaluated Accuracy: 0.1997
Testing b_a_r = 0.6781, c_a_r = 0.1183, f_a_r = 0.2514




----------------------> 0.4212653299450925
Evaluated Accuracy: 0.4213
Testing b_a_r = 0.0047, c_a_r = 0.0669, f_a_r = 0.6199




----------------------> 0.4538552738181258
Evaluated Accuracy: 0.4539
Testing b_a_r = 0.0681, c_a_r = 0.4583, f_a_r = 0.5196




----------------------> 0.3446316611419151
Evaluated Accuracy: 0.3446
Testing b_a_r = 0.1761, c_a_r = 0.0514, f_a_r = 0.9961




----------------------> 0.43528982355188217
Evaluated Accuracy: 0.4353
Testing b_a_r = 0.9994, c_a_r = 0.0340, f_a_r = 0.5511




----------------------> 0.4496881125775867
Evaluated Accuracy: 0.4497
Testing b_a_r = 0.4245, c_a_r = 0.2510, f_a_r = 0.0026




----------------------> 0.33980807447491
Evaluated Accuracy: 0.3398
Testing b_a_r = 0.9917, c_a_r = 0.1459, f_a_r = 0.5966




----------------------> 0.45513483215857836
Evaluated Accuracy: 0.4551
Testing b_a_r = 0.3787, c_a_r = 0.1762, f_a_r = 0.5375




----------------------> 0.4512194729148121
Evaluated Accuracy: 0.4512
Testing b_a_r = 0.9958, c_a_r = 0.2790, f_a_r = 0.5777




----------------------> 0.3160940672012758
Evaluated Accuracy: 0.3161
Testing b_a_r = 0.9250, c_a_r = 0.5694, f_a_r = 0.5213




----------------------> 0.294869737869442
Evaluated Accuracy: 0.2949
Testing b_a_r = 0.0199, c_a_r = 0.7193, f_a_r = 0.5739




----------------------> 0.23617541929317645
Evaluated Accuracy: 0.2362
Testing b_a_r = 0.0489, c_a_r = 0.1527, f_a_r = 0.6216




----------------------> 0.44944295763622677
Evaluated Accuracy: 0.4494
Testing b_a_r = 0.9641, c_a_r = 0.1509, f_a_r = 0.9936




----------------------> 0.41940552769661166
Evaluated Accuracy: 0.4194
Testing b_a_r = 0.9901, c_a_r = 0.0382, f_a_r = 0.6764




----------------------> 0.45713882536696104
Evaluated Accuracy: 0.4571
Testing b_a_r = 0.9848, c_a_r = 0.0696, f_a_r = 0.6577




----------------------> 0.45423476779861977
Evaluated Accuracy: 0.4542
Testing b_a_r = 0.0353, c_a_r = 0.0393, f_a_r = 0.7267




----------------------> 0.45195630217007055
Evaluated Accuracy: 0.4520
Testing b_a_r = 0.9663, c_a_r = 0.3893, f_a_r = 0.4579




----------------------> 0.34930265865549565
Evaluated Accuracy: 0.3493

✅ Optimal Values:
   - b_a_r: 0.8857
   - c_a_r: 0.0530
   - f_a_r: 0.5143
📈 Highest Accuracy Achieved: 0.4576


In [95]:
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_assamese/resnet_resluts/base/12_w_1_s_f', test_loader, M3, criterion)
# Load Model M2
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/secondary/12_w_1_s_f", test_loader, M2, criterion)
# Update Model M3's weights using a weighted combination of Model M3 and Model M2
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Assuming you have the evaluate2 function defined
# Evaluate the updated model (M3) using evaluate2 function with some arguments
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/amul/12_w_1_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.3774 ± 0.0385  0.4130 ± 0.0548  0.3775 ± 0.0386




---Accu-----pre----rec---------> 0.2437 ± 0.0392  0.2608 ± 0.0470  0.2437 ± 0.0392
---Accu-----pre----rec---------> 0.4576 ± 0.0416  0.4981 ± 0.0492  0.4577 ± 0.0415


In [96]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



!rm model_mu_path.pt
!rm model_own_path.pt

convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/asamese_sce/ressecondary/model_5/model_B_olchiki_5-shot_res.pth',map_location=torch.device('cpu')))


torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [97]:

def evaluate2(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlal=[]
    ttlac=0
    precision=[]
    recall=[]
    
    with torch.no_grad():
        for episode in range(len(clss_12)):
            support_set, query_set = data_loader.__getitem__(clss_12[tlv_cls] , clss_support_imagesss_5_shot[tlv_cls])

           
            clssa=str(clss_12[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_5_shot[tlv_cls])
            str(clss_12[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            ttlal.append(accuracy)
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            precision_overall, recall_overall, f1ss= calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)
    print(f"---Accu-----pre----rec---------> {np.mean(ttlal):.4f} ± {np.std(ttlal, ddof=1):.4f}  "
      f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
      f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")

    # print("---Accu-----pre----rec---------->",np.mean(ttlal),'±',np.std(ttlal, ddof=1), np.mean(precision),'±',np.std(precision, ddof=1),np.mean(precision),'±',np.std(precision, ddof=1),)
# %%%%%%

def evaluate3(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlac=0
    with torch.no_grad():
        for episode in range(len(clss_12)):
            support_set, query_set = data_loader.__getitem__(clss_12[tlv_cls] , clss_support_imagesss_5_shot[tlv_cls])

            
            clssa=str(clss_12[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_5_shot[tlv_cls])
            str(clss_12[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            # confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            # calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)

    print("---------------------->",ttlac/ttla)
    return ttlac/ttla

In [98]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define the search space for b_a_r, c_a_r, and f_a_r
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Define the objective function for Bayesian Optimization
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Objective function to optimize the adaptation ratios.
    It applies adaptation, evaluates the model, and returns the negative accuracy.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # Debugging Output
    
    return -accuracy  # We minimize, so return negative accuracy

# Perform Bayesian Optimization with additional settings
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Number of evaluations
    n_initial_points=5,       # Initial random explorations before GP starts
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Ensure reproducibility
    n_jobs=-1,                # Parallel execution for faster optimization
)

# Extract optimal results
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.6210518983068664
Evaluated Accuracy: 0.6211
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.46508211217117
Evaluated Accuracy: 0.4651
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.47105207488530076
Evaluated Accuracy: 0.4711
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.6218716205905761
Evaluated Accuracy: 0.6219
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.6349465586596562
Evaluated Accuracy: 0.6349
Testing b_a_r = 0.3930, c_a_r = 1.0000, f_a_r = 1.0000




----------------------> 0.3528240616283219
Evaluated Accuracy: 0.3528
Testing b_a_r = 0.3418, c_a_r = 0.1556, f_a_r = 1.0000




----------------------> 0.6220307774815351
Evaluated Accuracy: 0.6220
Testing b_a_r = 0.0000, c_a_r = 0.0000, f_a_r = 0.8772




----------------------> 0.633421940301147
Evaluated Accuracy: 0.6334
Testing b_a_r = 0.9262, c_a_r = 1.0000, f_a_r = 0.0000




----------------------> 0.34657950582905533
Evaluated Accuracy: 0.3466
Testing b_a_r = 0.2485, c_a_r = 0.6817, f_a_r = 1.0000




----------------------> 0.3882497975713874
Evaluated Accuracy: 0.3882
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.0000




----------------------> 0.5025755358989145
Evaluated Accuracy: 0.5026
Testing b_a_r = 0.9974, c_a_r = 0.0745, f_a_r = 0.9059




----------------------> 0.629105450649946
Evaluated Accuracy: 0.6291
Testing b_a_r = 0.8848, c_a_r = 0.3120, f_a_r = 0.9967




----------------------> 0.4649403670134208
Evaluated Accuracy: 0.4649
Testing b_a_r = 0.0000, c_a_r = 0.1354, f_a_r = 0.6577




----------------------> 0.6145165696111109
Evaluated Accuracy: 0.6145
Testing b_a_r = 0.9851, c_a_r = 0.6716, f_a_r = 0.0000




----------------------> 0.41200821592335723
Evaluated Accuracy: 0.4120
Testing b_a_r = 0.0000, c_a_r = 0.0337, f_a_r = 1.0000




----------------------> 0.6385513292382505
Evaluated Accuracy: 0.6386
Testing b_a_r = 0.8798, c_a_r = 0.1855, f_a_r = 0.0000




----------------------> 0.4986554436916641
Evaluated Accuracy: 0.4987
Testing b_a_r = 0.5176, c_a_r = 0.5126, f_a_r = 0.9998




----------------------> 0.4646651458734572
Evaluated Accuracy: 0.4647
Testing b_a_r = 0.8777, c_a_r = 0.8262, f_a_r = 0.0015




----------------------> 0.3857604573318717
Evaluated Accuracy: 0.3858
Testing b_a_r = 0.4975, c_a_r = 0.0006, f_a_r = 0.6157




----------------------> 0.6014655546906781
Evaluated Accuracy: 0.6015
Testing b_a_r = 0.9939, c_a_r = 0.0256, f_a_r = 0.9108




----------------------> 0.6399376229689425
Evaluated Accuracy: 0.6399
Testing b_a_r = 0.1973, c_a_r = 0.2126, f_a_r = 0.5144




----------------------> 0.5678442697812236
Evaluated Accuracy: 0.5678
Testing b_a_r = 0.9643, c_a_r = 0.2009, f_a_r = 0.9977




----------------------> 0.5996724792663338
Evaluated Accuracy: 0.5997
Testing b_a_r = 0.3080, c_a_r = 0.8329, f_a_r = 0.9992




----------------------> 0.36741753380400205
Evaluated Accuracy: 0.3674
Testing b_a_r = 0.1018, c_a_r = 0.5601, f_a_r = 0.4521




----------------------> 0.45007774235143233
Evaluated Accuracy: 0.4501
Testing b_a_r = 0.9642, c_a_r = 0.0935, f_a_r = 0.2596




----------------------> 0.5090854218594518
Evaluated Accuracy: 0.5091
Testing b_a_r = 0.1111, c_a_r = 0.1322, f_a_r = 0.8596




----------------------> 0.6213279447435256
Evaluated Accuracy: 0.6213
Testing b_a_r = 0.7872, c_a_r = 0.4249, f_a_r = 0.6261




----------------------> 0.5405163230336374
Evaluated Accuracy: 0.5405
Testing b_a_r = 0.9539, c_a_r = 0.1078, f_a_r = 0.9942




----------------------> 0.624941075931036
Evaluated Accuracy: 0.6249
Testing b_a_r = 0.2837, c_a_r = 0.3113, f_a_r = 0.5885




----------------------> 0.495077022918253
Evaluated Accuracy: 0.4951
Testing b_a_r = 0.7271, c_a_r = 0.7442, f_a_r = 0.4806




----------------------> 0.3971447562895536
Evaluated Accuracy: 0.3971
Testing b_a_r = 0.8317, c_a_r = 0.4241, f_a_r = 0.9860




----------------------> 0.5437119393356562
Evaluated Accuracy: 0.5437
Testing b_a_r = 0.3666, c_a_r = 0.5433, f_a_r = 0.0003




----------------------> 0.4343743056338345
Evaluated Accuracy: 0.4344
Testing b_a_r = 0.0405, c_a_r = 0.0467, f_a_r = 0.9062




----------------------> 0.6331329192341177
Evaluated Accuracy: 0.6331
Testing b_a_r = 0.7847, c_a_r = 0.0599, f_a_r = 0.9957




----------------------> 0.63244225742621
Evaluated Accuracy: 0.6324
Testing b_a_r = 0.0663, c_a_r = 0.0219, f_a_r = 0.9990




----------------------> 0.6386925137218259
Evaluated Accuracy: 0.6387
Testing b_a_r = 0.5840, c_a_r = 0.2565, f_a_r = 0.0014




----------------------> 0.48172082160210083
Evaluated Accuracy: 0.4817
Testing b_a_r = 0.3940, c_a_r = 0.9088, f_a_r = 0.4908




----------------------> 0.361300618759309
Evaluated Accuracy: 0.3613
Testing b_a_r = 0.1106, c_a_r = 0.0196, f_a_r = 0.7818




----------------------> 0.633548981913692
Evaluated Accuracy: 0.6335
Testing b_a_r = 0.3958, c_a_r = 0.1660, f_a_r = 0.5479




----------------------> 0.5907641849793117
Evaluated Accuracy: 0.5908
Testing b_a_r = 0.9868, c_a_r = 0.0986, f_a_r = 0.7270




----------------------> 0.6259086769279923
Evaluated Accuracy: 0.6259
Testing b_a_r = 0.8013, c_a_r = 0.4838, f_a_r = 0.5825




----------------------> 0.5105118798050563
Evaluated Accuracy: 0.5105
Testing b_a_r = 0.0024, c_a_r = 0.2319, f_a_r = 0.7954




----------------------> 0.5704981918372324
Evaluated Accuracy: 0.5705
Testing b_a_r = 0.0970, c_a_r = 0.0375, f_a_r = 0.4761




----------------------> 0.55660779182375
Evaluated Accuracy: 0.5566
Testing b_a_r = 0.2061, c_a_r = 0.0809, f_a_r = 0.0010




----------------------> 0.4979797777736516
Evaluated Accuracy: 0.4980
Testing b_a_r = 0.2533, c_a_r = 0.9970, f_a_r = 0.4830




----------------------> 0.35671382039249483
Evaluated Accuracy: 0.3567
Testing b_a_r = 0.9796, c_a_r = 0.0005, f_a_r = 0.7514




----------------------> 0.6209204676519122
Evaluated Accuracy: 0.6209
Testing b_a_r = 0.8600, c_a_r = 0.0257, f_a_r = 0.9148




----------------------> 0.639522108899659
Evaluated Accuracy: 0.6395
Testing b_a_r = 0.9615, c_a_r = 0.1541, f_a_r = 0.8196




----------------------> 0.6254960696367955
Evaluated Accuracy: 0.6255
Testing b_a_r = 0.9838, c_a_r = 0.0180, f_a_r = 0.9208




----------------------> 0.639246052810392
Evaluated Accuracy: 0.6392

✅ Optimal Values:
   - b_a_r: 0.9939
   - c_a_r: 0.0256
   - f_a_r: 0.9108
📈 Highest Accuracy Achieved: 0.6399


In [99]:
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_assamese/resnet_resluts/base/12_w_5_s_f', test_loader, M3, criterion)
# Load Model M2
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/secondary/12_w_5_s_f", test_loader, M2, criterion)
# Update Model M3's weights using a weighted combination of Model M3 and Model M2
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Assuming you have the evaluate2 function defined
# Evaluate the updated model (M3) using evaluate2 function with some arguments
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/amul/12_w_5_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.5090 ± 0.0542  0.5168 ± 0.0503  0.5090 ± 0.0545




---Accu-----pre----rec---------> 0.3862 ± 0.0648  0.3871 ± 0.0668  0.3861 ± 0.0648
---Accu-----pre----rec---------> 0.6399 ± 0.0515  0.6484 ± 0.0494  0.6399 ± 0.0516


In [100]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



!rm model_mu_path.pt
!rm model_own_path.pt

convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/modelr_res18.pth',map_location=torch.device('cpu')))


convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
model_own_1_shot = PrototypicalNetworks_dynamic_query(convolutional_network_with_dropout).to(device)
model_own_1_shot.load_state_dict(torch.load('/home/asufian/Desktop/output_olchiki/code/asamese_sce/ressecondary/model_10/model_B_olchiki_10-shot_res.pth',map_location=torch.device('cpu')))


torch.save(model_own_1_shot.state_dict(), 'model_mu_path.pt')
torch.save(model_own.state_dict(), 'model_own_path.pt')



In [101]:

def evaluate2(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlal=[]
    ttlac=0
    precision=[]
    recall=[]
    
    with torch.no_grad():
        for episode in range(len(clss_12)):
            support_set, query_set = data_loader.__getitem__(clss_12[tlv_cls] , clss_support_imagesss_10_shot[tlv_cls])

           
            clssa=str(clss_12[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_10_shot[tlv_cls])
            str(clss_12[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            ttlal.append(accuracy)
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            precision_overall, recall_overall, f1ss= calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)
            precision.append(precision_overall)
            recall.append(recall_overall)
    print(f"---Accu-----pre----rec---------> {np.mean(ttlal):.4f} ± {np.std(ttlal, ddof=1):.4f}  "
      f"{np.mean(precision):.4f} ± {np.std(precision, ddof=1):.4f}  "
      f"{np.mean(recall):.4f} ± {np.std(recall, ddof=1):.4f}")

    # print("---Accu-----pre----rec---------->",np.mean(ttlal),'±',np.std(ttlal, ddof=1), np.mean(precision),'±',np.std(precision, ddof=1),np.mean(precision),'±',np.std(precision, ddof=1),)
# %%%%%%

def evaluate3(fname,data_loader, model, criterion=nn.CrossEntropyLoss()):
    total_predictions = 0
    correct_predictions = 0
    total_loss = 0.0
    accuracy = 0
    datta = ""
    model.eval()
    tlv_cls=0
    ttla=0
    ttlac=0
    with torch.no_grad():
        for episode in range(len(clss_12)):
            support_set, query_set = data_loader.__getitem__(clss_12[tlv_cls] , clss_support_imagesss_10_shot[tlv_cls])

            
            clssa=str(clss_12[tlv_cls])
            clss=str(clssa).replace(',',' ').replace('\'',' ').replace('\'',' ')
            msg=str(clss_support_imagesss_10_shot[tlv_cls])
            str(clss_12[tlv_cls])

            tlv_cls+=1
            # print(len(support_set),len(query_set))
            support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
            query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

            support_images, support_labels = zip(*support_set)

            query_images, query_labels = zip(*query_set)

            classification_scores = model(support_images, support_labels, query_images)
            cortt = 0
            totl = 0
            confusion_mat=None
            all_predicted_labels = []
            all_actual_labels = []
            # print(len(classification_scores),len(query_labels))
            for ei in range(len(classification_scores)):
                classification_scores_each_class = classification_scores[ei]
                # print(classification_scores_each_class)
                predicted_labels_eachclas = torch.argmax(classification_scores_each_class, dim=1)
                pp = predicted_labels_eachclas.tolist()
                act = query_labels[ei]
                all_predicted_labels.extend(pp)
                all_actual_labels.extend(act)
                for iiii in range(len(pp)):
                    if pp[iiii] == act[iiii]:
                        cortt = cortt + 1
                totl = totl + len(pp)
                # print(len(act))
            accuracy = cortt / totl
            ttla+=1
            ttlac+=accuracy
            # print(cortt , totl,"Validation Accuracy:              ", cortt / totl)
            # confusion_mat = confusion_matrix(all_actual_labels, all_predicted_labels)
            # print("Confusion Matrix:")
            # print(confusion_mat)

            # fname='./own35_mapi_5_way_1_shot'

            # calculate_metrics_get_per(msg,confusion_mat,acciuracy=cortt / totl,cls=clss,flnm=fname)

    print("---------------------->",ttlac/ttla)
    return ttlac/ttla

In [102]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from skopt import gp_minimize
from skopt.space import Real
from skopt.utils import use_named_args
from skopt.plots import plot_convergence

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define the search space for b_a_r, c_a_r, and f_a_r
search_space = [
    Real(0.0, 1.0, name="b_a_r"),
    Real(0.0, 1.0, name="c_a_r"),
    Real(0.0, 1.0, name="f_a_r")
]

# Define the objective function for Bayesian Optimization
@use_named_args(search_space)
def objective(b_a_r, c_a_r, f_a_r):
    """
    Objective function to optimize the adaptation ratios.
    It applies adaptation, evaluates the model, and returns the negative accuracy.
    """
    accuracy = get_model_accuracy(b_a_r, c_a_r, f_a_r)
    print(f"Evaluated Accuracy: {accuracy:.4f}")  # Debugging Output
    
    return -accuracy  # We minimize, so return negative accuracy

# Perform Bayesian Optimization with additional settings
search_iteration=50
res = gp_minimize(
    func=objective,
    dimensions=search_space,
    n_calls=search_iteration,               # Number of evaluations
    n_initial_points=5,       # Initial random explorations before GP starts
    acq_func="EI",            # Acquisition function: Expected Improvement
    random_state=42,          # Ensure reproducibility
    n_jobs=-1,                # Parallel execution for faster optimization
)

# Extract optimal results
optimal_b_a_r, optimal_c_a_r, optimal_f_a_r = res.x
best_accuracy = -res.fun

print(f"\n✅ Optimal Values:")
print(f"   - b_a_r: {optimal_b_a_r:.4f}")
print(f"   - c_a_r: {optimal_c_a_r:.4f}")
print(f"   - f_a_r: {optimal_f_a_r:.4f}")
print(f"📈 Highest Accuracy Achieved: {best_accuracy:.4f}")




Testing b_a_r = 0.7965, c_a_r = 0.1834, f_a_r = 0.7797




----------------------> 0.6166421352499064
Evaluated Accuracy: 0.6166
Testing b_a_r = 0.5969, c_a_r = 0.4458, f_a_r = 0.1000
----------------------> 0.4032386603964717
Evaluated Accuracy: 0.4032
Testing b_a_r = 0.4592, c_a_r = 0.3337, f_a_r = 0.1429
----------------------> 0.47771542046966065
Evaluated Accuracy: 0.4777
Testing b_a_r = 0.6509, c_a_r = 0.0564, f_a_r = 0.7220
----------------------> 0.6315349467957099
Evaluated Accuracy: 0.6315
Testing b_a_r = 0.9386, c_a_r = 0.0008, f_a_r = 0.9922
----------------------> 0.6127922058388156
Evaluated Accuracy: 0.6128
Testing b_a_r = 0.5365, c_a_r = 1.0000, f_a_r = 0.6493




----------------------> 0.21873734972423808
Evaluated Accuracy: 0.2187
Testing b_a_r = 0.0000, c_a_r = 0.1123, f_a_r = 0.0000




----------------------> 0.5771068601445783
Evaluated Accuracy: 0.5771
Testing b_a_r = 1.0000, c_a_r = 0.0000, f_a_r = 0.3905




----------------------> 0.5775732046799604
Evaluated Accuracy: 0.5776
Testing b_a_r = 0.0000, c_a_r = 0.1056, f_a_r = 1.0000




----------------------> 0.614079453560531
Evaluated Accuracy: 0.6141
Testing b_a_r = 1.0000, c_a_r = 0.1099, f_a_r = 1.0000




----------------------> 0.6142397044163999
Evaluated Accuracy: 0.6142
Testing b_a_r = 0.1488, c_a_r = 0.4018, f_a_r = 1.0000




----------------------> 0.3625809990438851
Evaluated Accuracy: 0.3626
Testing b_a_r = 1.0000, c_a_r = 0.1241, f_a_r = 0.5941




----------------------> 0.6450507994166522
Evaluated Accuracy: 0.6451
Testing b_a_r = 0.9795, c_a_r = 0.1614, f_a_r = 0.4082




----------------------> 0.6207547759025629
Evaluated Accuracy: 0.6208
Testing b_a_r = 0.8807, c_a_r = 0.7938, f_a_r = 0.0007




----------------------> 0.2773403021455534
Evaluated Accuracy: 0.2773
Testing b_a_r = 0.7938, c_a_r = 0.1182, f_a_r = 0.7426




----------------------> 0.6317064925201479
Evaluated Accuracy: 0.6317
Testing b_a_r = 0.0283, c_a_r = 0.1090, f_a_r = 0.5634




----------------------> 0.6356754862430957
Evaluated Accuracy: 0.6357
Testing b_a_r = 0.7752, c_a_r = 0.7311, f_a_r = 0.9974




----------------------> 0.24319539985430852
Evaluated Accuracy: 0.2432
Testing b_a_r = 0.9886, c_a_r = 0.1061, f_a_r = 0.4557




----------------------> 0.6234611350265112
Evaluated Accuracy: 0.6235
Testing b_a_r = 0.0022, c_a_r = 0.1267, f_a_r = 0.6767




----------------------> 0.6388516316725183
Evaluated Accuracy: 0.6389
Testing b_a_r = 0.0206, c_a_r = 0.0007, f_a_r = 0.0121




----------------------> 0.5388300632704881
Evaluated Accuracy: 0.5388
Testing b_a_r = 0.0176, c_a_r = 0.0760, f_a_r = 0.7846




----------------------> 0.620911671155093
Evaluated Accuracy: 0.6209
Testing b_a_r = 0.0025, c_a_r = 0.1651, f_a_r = 0.5373




----------------------> 0.6312468611062602
Evaluated Accuracy: 0.6312
Testing b_a_r = 0.9909, c_a_r = 0.1396, f_a_r = 0.6801




----------------------> 0.6463161126992854
Evaluated Accuracy: 0.6463
Testing b_a_r = 0.9656, c_a_r = 0.1260, f_a_r = 0.7115




----------------------> 0.6388497590499725
Evaluated Accuracy: 0.6388
Testing b_a_r = 0.7342, c_a_r = 0.9968, f_a_r = 0.0033




----------------------> 0.25558288570088616
Evaluated Accuracy: 0.2556
Testing b_a_r = 0.8809, c_a_r = 0.1994, f_a_r = 0.0017




----------------------> 0.5539260878995212
Evaluated Accuracy: 0.5539
Testing b_a_r = 0.8972, c_a_r = 0.0010, f_a_r = 0.7582




----------------------> 0.6256535607511972
Evaluated Accuracy: 0.6257
Testing b_a_r = 0.9787, c_a_r = 0.2827, f_a_r = 0.5720




----------------------> 0.5297930442517772
Evaluated Accuracy: 0.5298
Testing b_a_r = 0.7079, c_a_r = 0.1967, f_a_r = 0.9866




----------------------> 0.5966439252666604
Evaluated Accuracy: 0.5966
Testing b_a_r = 0.1466, c_a_r = 0.1513, f_a_r = 0.6627




----------------------> 0.6366346574341897
Evaluated Accuracy: 0.6366
Testing b_a_r = 0.9627, c_a_r = 0.0904, f_a_r = 0.5772




----------------------> 0.6255155854234046
Evaluated Accuracy: 0.6255
Testing b_a_r = 0.9685, c_a_r = 0.1352, f_a_r = 0.7728




----------------------> 0.6348871408565174
Evaluated Accuracy: 0.6349
Testing b_a_r = 0.1132, c_a_r = 0.1464, f_a_r = 0.5961




----------------------> 0.6372661991084189
Evaluated Accuracy: 0.6373
Testing b_a_r = 0.8699, c_a_r = 0.0497, f_a_r = 0.9958




----------------------> 0.6207291626271337
Evaluated Accuracy: 0.6207
Testing b_a_r = 0.1147, c_a_r = 0.0009, f_a_r = 0.8337




----------------------> 0.6177087113355911
Evaluated Accuracy: 0.6177
Testing b_a_r = 0.9563, c_a_r = 0.1648, f_a_r = 0.5774




----------------------> 0.6385409382364186
Evaluated Accuracy: 0.6385
Testing b_a_r = 0.0002, c_a_r = 0.1252, f_a_r = 0.4153




----------------------> 0.622350787001859
Evaluated Accuracy: 0.6224
Testing b_a_r = 0.1087, c_a_r = 0.6115, f_a_r = 0.0033




----------------------> 0.31178525655388806
Evaluated Accuracy: 0.3118
Testing b_a_r = 0.9488, c_a_r = 0.1305, f_a_r = 0.6177




----------------------> 0.6479030444412247
Evaluated Accuracy: 0.6479
Testing b_a_r = 0.9917, c_a_r = 0.1459, f_a_r = 0.5966




----------------------> 0.6490186745923602
Evaluated Accuracy: 0.6490
Testing b_a_r = 0.8917, c_a_r = 0.5688, f_a_r = 0.5614




----------------------> 0.3033698255185758
Evaluated Accuracy: 0.3034
Testing b_a_r = 0.9532, c_a_r = 0.1765, f_a_r = 0.5905




----------------------> 0.6340961023761371
Evaluated Accuracy: 0.6341
Testing b_a_r = 0.9367, c_a_r = 0.0942, f_a_r = 0.2404




----------------------> 0.5914001492747637
Evaluated Accuracy: 0.5914
Testing b_a_r = 0.0032, c_a_r = 0.2025, f_a_r = 0.3576




----------------------> 0.5899760843013077
Evaluated Accuracy: 0.5900
Testing b_a_r = 0.1724, c_a_r = 0.9893, f_a_r = 0.9999




----------------------> 0.21841763259267025
Evaluated Accuracy: 0.2184
Testing b_a_r = 0.0293, c_a_r = 0.0118, f_a_r = 0.6265




----------------------> 0.6247068209598368
Evaluated Accuracy: 0.6247
Testing b_a_r = 0.9854, c_a_r = 0.1398, f_a_r = 0.5926




----------------------> 0.6498134564925724
Evaluated Accuracy: 0.6498
Testing b_a_r = 0.9851, c_a_r = 0.0445, f_a_r = 0.6514




----------------------> 0.632644521063811
Evaluated Accuracy: 0.6326
Testing b_a_r = 0.9780, c_a_r = 0.2654, f_a_r = 0.9943




----------------------> 0.5298114669254725
Evaluated Accuracy: 0.5298
Testing b_a_r = 0.9429, c_a_r = 0.0369, f_a_r = 0.8533




----------------------> 0.6288285589153931
Evaluated Accuracy: 0.6288

✅ Optimal Values:
   - b_a_r: 0.9854
   - c_a_r: 0.1398
   - f_a_r: 0.5926
📈 Highest Accuracy Achieved: 0.6498


In [103]:
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz = convolutional_network_with_dropout
M3 = PrototypicalNetworks_dynamic_query(xyz).to(device)
M3.load_state_dict(torch.load('model_own_path.pt'))
evaluate2('/home/asufian/Desktop/output_assamese/resnet_resluts/base/12_w_10_s_f', test_loader, M3, criterion)
# Load Model M2
convolutional_network_with_dropout = ResNet18WithDropout(pretrained=False, dr=0.0)
convolutional_network_with_dropout.fc = nn.Flatten()
xyz2 = convolutional_network_with_dropout
M2 = PrototypicalNetworks_dynamic_query(xyz2).to(device)
M2.load_state_dict(torch.load('model_mu_path.pt'))
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/secondary/12_w_10_s_f", test_loader, M2, criterion)
# Update Model M3's weights using a weighted combination of Model M3 and Model M2
update_model_weights2(M3, M2, conv_ratio= optimal_c_a_r , fc_ratio=optimal_f_a_r , bias_ratio=optimal_b_a_r)

# Assuming you have the evaluate2 function defined
# Evaluate the updated model (M3) using evaluate2 function with some arguments
evaluate2("/home/asufian/Desktop/output_assamese/resnet_resluts/amul/12_w_10_s_f", test_loader, M3, criterion)




---Accu-----pre----rec---------> 0.5376 ± 0.0625  0.5420 ± 0.0622  0.5377 ± 0.0628




---Accu-----pre----rec---------> 0.3141 ± 0.0518  0.3140 ± 0.0557  0.3141 ± 0.0517
---Accu-----pre----rec---------> 0.6498 ± 0.0486  0.6591 ± 0.0478  0.6500 ± 0.0487
