In [None]:
import matplotlib.pyplot as plt
import numpy as np

class TaskAccuracyPlotter:
    """Accumulate accuracy curves of one or more runs on three reusable figures.

    The first call of each ``plot_*`` method creates its figure; later calls
    draw onto the same axes, so calling the plotter once per strategy overlays
    the strategies for comparison:

    * ``fig1`` - one subplot per task (``plot_individual_task_accuracy``)
    * ``fig2`` - mean accuracy over all tasks (``plot_average_accuracy``)
    * ``fig3`` - mean accuracy over the tasks encountered so far
      (``plot_encountered_tasks_accuracy``)

    Every ``task_acc`` argument is a dict mapping a task key to a sequence of
    accuracies, one entry per evaluation step. The y-axis is fixed to 0-100,
    so accuracies are expected on that scale.
    """

    def __init__(self):
        self.fig1 = None  # grid of per-task subplots, created on first use
        self.fig2 = None  # average accuracy over all tasks
        self.fig3 = None  # average accuracy over encountered tasks
        self.label = None  # default legend label, set by plot_task_accuracy

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _average_accuracy(task_acc):
        """Return the mean accuracy over *all* tasks at each evaluation step.

        Args:
            task_acc (dict): Task key -> sequence of per-step accuracies.

        Returns:
            list[float]: One mean per step (``zip`` truncates to the shortest
            sequence). Empty when ``task_acc`` is empty.
        """
        num_tasks = len(task_acc)
        return [sum(step_values) / num_tasks for step_values in zip(*task_acc.values())]

    @staticmethod
    def _encountered_average_accuracy(task_acc):
        """Return, for each step ``t``, the mean accuracy of tasks with ``int(key) <= t``.

        A task whose key is ``k`` counts as encountered from step ``k`` onward,
        so keys must be convertible with ``int``. The number of steps is taken
        from the first sequence in ``task_acc``.

        Args:
            task_acc (dict): Task key -> sequence of per-step accuracies.

        Returns:
            list[float]: One mean per step. Steps at which no task has been
            encountered yet (e.g. keys starting at 1) yield ``nan``, which
            matplotlib draws as a gap. Empty when ``task_acc`` is empty.
        """
        if not task_acc:
            return []
        num_steps = len(next(iter(task_acc.values())))
        averages = []
        for step in range(num_steps):
            seen = [values[step] for key, values in task_acc.items() if int(key) <= step]
            averages.append(sum(seen) / len(seen) if seen else float('nan'))
        return averages

    @staticmethod
    def _format_accuracy_axis(ax, title, task_keys):
        """Apply the grid, labels, 0-100 y-scale, x-ticks, title and legend shared by all plots."""
        ax.grid(True)
        ax.set_xlabel('Task')
        ax.set_ylabel('Accuracy')
        ax.set_yticks(np.arange(0, 101, 5))
        ax.set_ylim(-1, 101)  # small margin so 0 and 100 are not clipped
        ax.set_xticks(list(task_keys))
        ax.set_title(title)
        # Only draw a legend when some line carries a label; otherwise
        # matplotlib warns "No artists with labels found to put in legend".
        if ax.get_legend_handles_labels()[1]:
            ax.legend()

    # ----------------------------------------------------------------- plotting

    def plot_individual_task_accuracy(self, task_acc, label=None):
        """
        Plot the accuracy curve of each task in its own subplot.

        Args:
            task_acc (dict): Task key -> sequence of per-step accuracies.
            label (str, optional): Legend label for this run. Defaults to ``self.label``.

        Returns:
            matplotlib.figure.Figure: The updated per-task figure (``self.fig1``).

        Raises:
            ValueError: If ``task_acc`` is empty.
        """
        if not task_acc:
            raise ValueError("task_acc must contain at least one task")
        label = self.label if label is None else label

        created = self.fig1 is None
        if created:
            # Two columns; round up so an odd task count still fits.
            num_rows = (len(task_acc) + 1) // 2
            self.fig1, axes = plt.subplots(num_rows, 2, figsize=(15, 5 * num_rows))
            axes = list(axes.flatten())
        else:
            axes = self.fig1.get_axes()

        # zip stops at the shorter input: on a reused figure, tasks beyond the
        # existing subplots are not drawn.
        plot_num = 0
        for key, ax in zip(task_acc.keys(), axes):
            ax.plot(task_acc[key], label=label, marker='.')
            self._format_accuracy_axis(ax, f"Task {key}", task_acc.keys())
            plot_num += 1

        # Drop the spare subplot of a freshly created grid (odd task count).
        # Never delete axes of a reused figure: they hold earlier runs' curves.
        if created:
            for ax in axes[plot_num:]:
                self.fig1.delaxes(ax)

        # Lay out this figure explicitly; plt.tight_layout() would act on the
        # *current* pyplot figure, which may be fig2/fig3 after earlier calls.
        self.fig1.tight_layout()
        return self.fig1

    def plot_average_accuracy(self, task_acc, label=None):
        """
        Plot the mean accuracy over all tasks at each evaluation step.

        Args:
            task_acc (dict): Task key -> sequence of per-step accuracies.
            label (str, optional): Legend label for this run. Defaults to ``self.label``.

        Returns:
            matplotlib.figure.Figure: The updated figure (``self.fig2``).
        """
        label = self.label if label is None else label
        if self.fig2 is None:
            self.fig2, ax = plt.subplots(figsize=(10, 6))
        else:
            ax = self.fig2.get_axes()[0]

        ax.plot(self._average_accuracy(task_acc), marker='.', label=label)
        self._format_accuracy_axis(ax, 'Average Accuracy', task_acc.keys())

        self.fig2.tight_layout()  # not plt.tight_layout(): see plot_individual_task_accuracy
        return self.fig2

    def plot_encountered_tasks_accuracy(self, task_acc, label=None):
        """
        Plot, per evaluation step, the mean accuracy of the tasks encountered so far.

        Task ``k`` is considered encountered from step ``int(k)`` onward.

        Args:
            task_acc (dict): Task key -> sequence of per-step accuracies.
            label (str, optional): Legend label for this run. Defaults to ``self.label``.

        Returns:
            matplotlib.figure.Figure: The updated figure (``self.fig3``).
        """
        label = self.label if label is None else label
        if self.fig3 is None:
            self.fig3, ax = plt.subplots(figsize=(10, 6))
        else:
            ax = self.fig3.get_axes()[0]

        ax.plot(self._encountered_average_accuracy(task_acc), marker='.', label=label)
        self._format_accuracy_axis(ax, 'Average Accuracy of Encountered Tasks', task_acc.keys())

        self.fig3.tight_layout()  # lay out fig3 itself, not the current figure
        return self.fig3

    def show_figures(self):
        """
        Show every figure that has been created so far.
        """
        for fig in (self.fig1, self.fig2, self.fig3):
            if fig is not None:
                fig.show()

    def plot_task_accuracy(self, task_acc,  label=None, plot_task_acc=True, plot_avg_acc=True, plot_encountered_avg=True):
        """
        Plot per-task accuracy, average accuracy and encountered-task average accuracy.

        Args:
            task_acc (dict): Task key -> sequence of per-step accuracies.
            label (str, optional): Legend label for this run; stored as ``self.label``.
            plot_task_acc (bool): Whether to plot individual task accuracies. Default is True.
            plot_avg_acc (bool): Whether to plot average accuracy of all tasks. Default is True.
            plot_encountered_avg (bool): Whether to plot average accuracy of only encountered tasks. Default is True.

        Returns:
            tuple: ``(fig1, fig2, fig3)``; an entry is ``None`` if that figure was never created.
        """
        self.label = label

        if plot_task_acc:
            self.plot_individual_task_accuracy(task_acc)
        if plot_avg_acc:
            self.plot_average_accuracy(task_acc)
        if plot_encountered_avg:
            self.plot_encountered_tasks_accuracy(task_acc)

        return self.fig1, self.fig2, self.fig3


In [None]:
# Example usage: overlay several strategies on the same average-accuracy plots.
# Each strategy maps a task id to that task's accuracy after each training step;
# the leading zeros presumably mean the task has not been trained on yet.
strategy_dict = {
    "Strategy 1": {
        0: [85, 90, 92],
        1: [0, 80, 85],
        2: [0, 0, 90],
        # Add more tasks as needed...
    },
    "Strategy 2": {
        0: [90, 92, 88],
        1: [0, 95, 89],
        2: [0, 0, 78],
        # Add more tasks as needed...
    },
    "Strategy 3": {
        0: [80, 78, 82],
        1: [0, 90, 85],
        2: [0, 0, 97],
        # Add more tasks as needed...
    },
}
plotter = TaskAccuracyPlotter()

for key, value in strategy_dict.items():
    print(f"Strategy: {key}, Tasks: {list(value.keys())}")
    print(f"Task Accuracy: {value}")
    # Only the two average plots are drawn, one curve per strategy.
    plotter.plot_task_accuracy(value, label=key, plot_task_acc=False, plot_avg_acc=True, plot_encountered_avg=True)

plotter.show_figures()

# A single evaluation step (one accuracy per task), drawn with the SAME plotter:
# its unlabeled curves are overlaid on the strategy average plots above, and it
# creates the per-task subplot grid, which the loop above skipped.
a = {
    0: [85],
    1: [0],
    2: [0],
    # Add more tasks as needed...
}
plotter.plot_task_accuracy(a, plot_task_acc=True, plot_avg_acc=True, plot_encountered_avg=True)
plotter.show_figures()

In [4]:
from micromind import PhiNet
import torch
import torch.nn as nn
from torch.optim import SGD, Adam

from model.phinet_v2 import PhiNet_v2
from model.phinet_v3 import PhiNetV3

# PhiNet backbone for 1x28x28 inputs; include_top=False presumably drops the
# classification head. It is then wrapped by PhiNetV3, whose printed structure
# (below) exposes a `lat_features` Sequential split at latent_layer_num.
backbone = PhiNet(input_shape=(1, 28, 28), alpha=3, beta=0.75, t_zero=6, num_layers=7, include_top=False)
model1 = PhiNetV3(backbone, latent_layer_num=9)

print(model1)




PhiNetV3(
  (lat_features): Sequential(
    (0): ZeroPad2d((0, 1, 0, 1))
    (1): SeparableConv2d(
      (_layers): ModuleList(
        (0): Conv2d(1, 1, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (1): Conv2d(1, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(144, eps=0.001, momentum=0.999, affine=True, track_running_stats=True)
        (3): HSwish()
      )
    )
    (2): PhiNetConvBlock(
      (_layers): ModuleList(
        (0): Dropout2d(p=0.05, inplace=False)
        (1): DepthwiseConv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)
        (2): BatchNorm2d(144, eps=0.001, momentum=0.999, affine=True, track_running_stats=True)
        (3): HSwish()
        (4): Conv2d(144, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (5): BatchNorm2d(72, eps=0.001, momentum=0.999, affine=True, track_running_stats=True)
      )
    )
    (3): PhiNetConvBlock(
      (_layers): ModuleList(
        (0): Conv2d