In [1]:
import matplotlib.pyplot as plt
import torch
import os
import torch.nn as nn
import ipywidgets as widgets
import torch.nn.functional as F
from IPython.display import display
import math
import random
import numpy as np

def set_all_seeds(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's legacy global generator, and PyTorch
    (CPU and all CUDA devices), then configures cuDNN for deterministic
    behaviour with benchmarking switched off.

    Args:
        seed: Integer seed handed to every generator (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_all_seeds(seed=42)

In [2]:
# Maps a feature column index to its display name; FScorePlotter uses these
# names as subplot titles. The list position is the column index.
feature_names = dict(enumerate([
    'hair',
    'feathers',
    'eggs',
    'milk',
    'airborne',
    'aquatic',
    'predator',
    'toothed',
    'backbone',
    'breathes',
    'venomous',
    'fins',
    'legs',
    'tail',
    'domestic',
    'catsize',
]))

In [3]:
class HGNAM(nn.Module):
    """Additive model over node features with distance- or neighbor-based aggregation.

    Every input feature ``k`` has its own small MLP ``fs[k]`` that maps the
    scalar feature value to ``out_channels`` outputs. The per-feature outputs of
    a node are summed (optionally scaled by learned per-feature weights) and the
    resulting node embeddings are combined according to ``aggregation``:

    * ``"overall"``: every pair of nodes is weighted by a learned function
      ``m`` of the corresponding entry of ``inputs.dist_mat``.
    * ``"neighbor"``: each node's embedding is added to the mean embedding of
      the nodes whose ``dist_mat`` entry equals 0.5.

    Args:
        in_channels: Number of input features (one shape function each).
        out_channels: Number of outputs per node.
        num_layers: Number of linear layers in every MLP; 1 means a single
            ``nn.Linear``.
        hidden_channels: Hidden width; required when ``num_layers`` is not 1.
        bias: Whether the linear layers use a bias term.
        dropout: Dropout probability used inside the shape functions only.
        device: Device the inputs are moved to in ``forward``. The module's own
            parameters are NOT moved by this class.
        limited_m: When true and ``num_layers`` is not 1, distance networks
            emit a single channel instead of ``out_channels``.
        normalize_m: Divide the distance weights by ``inputs.norm_mat`` in the
            ``"overall"`` aggregation.
        m_per_feature: Build a list of distance networks (``self.ms``) instead
            of a single one (``self.m``). ``forward`` does not support this
            together with ``aggregation="overall"``.
        weight: When truthy, learn one weight per input feature.
        aggregation: ``"overall"`` or ``"neighbor"``.

    Raises:
        ValueError: If ``num_layers`` is not 1 and ``hidden_channels`` is None.
    """

    def __init__(
          self,
          in_channels,
          out_channels,
          num_layers,
          hidden_channels=None,
          bias=True,
          dropout=0.0,
          device='cuda',
          limited_m=True,
          normalize_m=True,
          m_per_feature=False,
          weight = False,
          aggregation = "overall"
    ):

        super().__init__()
        if num_layers != 1 and hidden_channels is None:
            raise ValueError("hidden_channels must be set when num_layers != 1")

        self.device = device
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.bias = bias
        self.dropout = dropout
        self.limited_m = limited_m
        self.normalize_m = normalize_m
        self.m_per_feature = m_per_feature
        self.weight = weight
        self.aggregation = aggregation
        if self.weight:
            self.feature_weights = nn.Parameter(torch.rand(self.in_channels))

        # shape functions f_k (one per input feature, with dropout)
        self.fs = nn.ModuleList(
            self._build_mlp(num_layers, hidden_channels, out_channels, bias, dropout=dropout)
            for _ in range(in_channels)
        )

        # distance functions \rho (no dropout). A single-layer network always
        # emits out_channels; deeper ones emit 1 channel when limited_m is set.
        m_out = out_channels if (num_layers == 1 or not limited_m) else 1
        if m_per_feature:
            self.ms = nn.ModuleList(
                self._build_mlp(num_layers, hidden_channels, m_out, bias)
                for _ in range(out_channels if limited_m else in_channels)
            )
        else:
            self.m = self._build_mlp(num_layers, hidden_channels, m_out, bias)

    @staticmethod
    def _build_mlp(num_layers, hidden_channels, out_dim, bias, dropout=None):
        """Build a scalar-input MLP; layer order is fixed so state_dict keys stay stable.

        With ``dropout=None`` no ``nn.Dropout`` modules are inserted.
        """
        if num_layers == 1:
            return nn.Sequential(nn.Linear(1, out_dim, bias=bias))

        def block(in_dim):
            modules = [nn.Linear(in_dim, hidden_channels, bias=bias), nn.ReLU()]
            if dropout is not None:
                modules.append(nn.Dropout(p=dropout))
            return modules

        layers = block(1)
        for _ in range(1, num_layers - 1):
            layers += block(hidden_channels)
        layers.append(nn.Linear(hidden_channels, out_dim, bias=bias))
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Compute per-node outputs.

        Args:
            inputs: Object exposing tensors ``x`` (indexed as nodes x features),
                ``dist_mat`` and ``norm_mat`` (both indexed as nodes x nodes).
                All three are moved to ``self.device``.

        Returns:
            Tensor of shape ``(num_nodes, out_channels)``.

        Raises:
            NotImplementedError: ``aggregation="overall"`` with ``m_per_feature``.
            ValueError: Unknown ``aggregation``.
        """
        x = inputs.x.to(self.device)
        distances = inputs.dist_mat.to(self.device)
        normalization_matrix = inputs.norm_mat.to(self.device)

        # fx[n, k, :] = f_k(x[n, k])
        fx = torch.empty(x.size(0), x.size(1), self.out_channels, device=self.device)
        for feature_index in range(x.size(1)):
            feature_col = x[:, feature_index].view(-1, 1)
            fx[:, feature_index] = self.fs[feature_index](feature_col)

        if self.weight:
            attention_weights = F.softmax(torch.exp(self.feature_weights), dim=0)
            fx_weighted = fx * attention_weights.unsqueeze(0).unsqueeze(-1)  # (N, num_features, out_channels)
            f_sums = fx_weighted.sum(dim=1)
        else:
            f_sums = fx.sum(dim=1)

        if self.aggregation == "overall":
            if self.m_per_feature:
                # Only self.ms exists in this configuration and no forward rule
                # for it is defined here; fail with a clear message.
                raise NotImplementedError(
                    "aggregation='overall' is not implemented for m_per_feature=True"
                )
            m_dist = self.m(distances.flatten().view(-1, 1))
            # Last dim is out_channels, or 1 when limited_m applies; a size-1
            # channel broadcasts against f_sums below.
            m_dist = m_dist.view(distances.size(0), distances.size(1), -1)

            if self.normalize_m:
                m_dist = m_dist / normalization_matrix.unsqueeze(-1)

            output = torch.sum(m_dist * f_sums.unsqueeze(0), dim=1)

        elif self.aggregation == "neighbor":
            N = distances.size(0)
            out_channels = f_sums.size(1)

            # distinguish neighbor(distances==0.5 because distances = 1/(real distances + 1))
            neighbor_mask = (distances == 0.5)
            neighbor_indices = neighbor_mask.nonzero(as_tuple=False)

            # Sum the embeddings of each node's neighbours.
            neighbor_agg = torch.zeros((N, out_channels), device=f_sums.device, dtype=f_sums.dtype)
            neighbor_agg.index_add_(0, neighbor_indices[:, 0], f_sums[neighbor_indices[:, 1]])

            # Nodes without neighbours have a zero sum, so dividing by a count
            # clamped to 1 leaves their average at zero.
            neighbor_counts = neighbor_mask.float().sum(dim=1, keepdim=True)
            avg_neighbors = neighbor_agg / neighbor_counts.clamp(min=1)
            output = f_sums + avg_neighbors

        else:
            raise ValueError("Unknown aggregation type: {}".format(self.aggregation))
        return output

    def print_m_params(self):
        """Print the parameters of the distance network(s), if any exist."""
        if hasattr(self, 'm'):
            print("Single m network parameters:")
            for name, param in self.m.named_parameters():
                print(name, param)
        elif hasattr(self, 'ms'):
            print("Separate m networks per dimension:")
            for idx, module in enumerate(self.ms):
                for name, param in module.named_parameters():
                    print(f"ms[{idx}].{name}", param)
        else:
            print("No m parameters found.")

In [28]:
class FScorePlotter:
    """Interactive bar-chart viewer for a model's per-feature shape functions.

    For every column of ``data`` that holds at most two distinct values, the
    plot shows the selected class's output of ``model.fs[k]`` at inputs 1 and
    0. When the model has a ``feature_weights`` attribute, both values are
    scaled by that feature's softmax weight.

    Args:
        model: Object exposing ``fs`` (indexable shape functions) and,
            optionally, ``feature_weights``.
        data: 2-D tensor indexed as samples x features.
        data_name: Dataset name, used to build the default file name.
        feature_names: Mapping from feature column index to a subplot title.
        class_names: Mapping from class id to a display name.
    """

    def __init__(self, model, data, data_name, feature_names, class_names):
        self.model = model
        self.data = data
        self.data_name = data_name
        self.feature_names = feature_names
        self.class_names = class_names
        self.num_features = data.size(1)
        self.binary_indices = self.compute_binary_indices()
        self.last_fig = None

        # Keep class 0 as the initial selection when it exists; otherwise fall
        # back to the first available class id.
        initial_class = 0 if 0 in self.class_names else next(iter(self.class_names), None)
        self.dropdown = widgets.Dropdown(
            options=[(f"{class_id}: {name}", class_id) for class_id, name in self.class_names.items()],
            value=initial_class,
            description='Class:',
            disabled=False,
        )
        self.filename_text = widgets.Text(
            value='',
            placeholder='Enter file name',
            description='File Name:',
            disabled=False,
        )
        self.save_button = widgets.Button(
            description="Save Plot",
            disabled=False,
            button_style='',
            tooltip='Click to save the current plot',
            icon='save'
        )
        self.save_button.on_click(self.save_plot)

    def compute_binary_indices(self):
        """Return the column indices of ``self.data`` with at most two unique values."""
        return [
            i for i in range(self.data.size(1))
            if torch.unique(self.data[:, i]).numel() <= 2
        ]

    def _compute_feature_scores(self, selected_class):
        """Return ``[(feature_idx, f1, f0), ...]`` for every binary feature.

        ``f1`` / ``f0`` are the ``selected_class`` outputs of the feature's
        shape function at inputs 1 and 0, scaled by the feature's weight when
        the model defines ``feature_weights``.
        """
        raw_weights = getattr(self.model, 'feature_weights', None)
        attention_weights = None
        if raw_weights is not None:
            # Same transform the model applies to feature_weights in forward().
            attention_weights = F.softmax(torch.exp(raw_weights.detach()), dim=0)

        scores = []
        for feature_idx in self.binary_indices:
            shape_fn = self.model.fs[feature_idx]

            # Build the probe inputs where the shape function's parameters live.
            first_param = next(shape_fn.parameters(), None)
            placement = {}
            if first_param is not None:
                placement = {'device': first_param.device, 'dtype': first_param.dtype}
            probe = torch.tensor([[1.0], [0.0]], **placement)

            # Evaluate without autograd and with dropout disabled, then restore
            # whatever train/eval mode the module was in.
            was_training = shape_fn.training
            shape_fn.eval()
            try:
                with torch.no_grad():
                    out = shape_fn(probe)
            finally:
                shape_fn.train(was_training)

            f1 = out[0, selected_class].item()
            f0 = out[1, selected_class].item()
            if attention_weights is not None:
                # Index by the feature's column, not by its position among the
                # binary features.
                weight = float(attention_weights[feature_idx])
                f1 *= weight
                f0 *= weight
            scores.append((feature_idx, f1, f0))
        return scores

    def plot_f_scores(self, selected_class):
        """Draw one f(1)/f(0) bar pair per binary feature for ``selected_class``."""
        plt.rcParams['font.family'] = 'STIXGeneral'
        plt.rcParams['mathtext.fontset'] = 'stix'
        plt.rcParams.update({
            'font.size': 20,
            'axes.titlesize': 20,
            'axes.labelsize': 20,
            'xtick.labelsize': 20,
            'ytick.labelsize': 20,
        })

        scores = self._compute_feature_scores(selected_class)
        if not scores:
            print("No binary features available to plot.")
            return

        # Release the previously drawn figure so repeated selections do not
        # accumulate open figures.
        if self.last_fig is not None:
            plt.close(self.last_fig)

        nrows = 3
        ncols = math.ceil(len(scores) / nrows)

        # squeeze=False keeps `axes` 2-D even for a single column.
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 8), squeeze=False)
        axes = axes.flatten()

        positions = [0, 0.5]
        bar_width = 0.3
        for ax, (feature_idx, f1, f0) in zip(axes, scores):
            ax.bar(positions, [f1, f0], color=['#4c72b0', '#dd8452'], width=bar_width, align='center')
            ax.set_xticks(positions)
            ax.set_xticklabels(['$f(1)$', '$f(0)$'])
            ax.axhline(0, color='black', linewidth=1)

            ax.grid(False)
            ax.set_title(self.feature_names[feature_idx])

            # Symmetric y-range around zero with a 10% margin.
            max_val = max(abs(f1), abs(f0))
            if max_val == 0:
                max_val = 1
            margin = max_val * 0.1
            ax.set_ylim([-max_val - margin, max_val + margin])

        # Remove grid cells that have no feature to show.
        for unused_ax in axes[len(scores):]:
            fig.delaxes(unused_ax)

        fig.suptitle(f'Class: {self.class_names[selected_class]}', fontsize=22, fontweight='bold', y=0.85)
        plt.tight_layout(rect=[0, 0, 1, 0.90], h_pad=0.5, w_pad=0.5)
        plt.subplots_adjust(wspace=0.75, hspace=0.5)

        self.last_fig = fig
        plt.show()

    def save_plot(self, b):
        """Button callback: save the last drawn figure as a PDF under ``./plot``.

        Uses the text-box value as the file name, or ``<data_name>_<class name>``
        when the box is empty. Does nothing but print a message when no figure
        has been drawn yet.
        """
        if self.last_fig is None:
            print("No plot is available to save.")
            return

        fname = self.filename_text.value.strip()
        if not fname:
            fname = str(self.data_name) + '_' + str(self.class_names[self.dropdown.value])
            print(fname)

        directory = "./plot"
        os.makedirs(directory, exist_ok=True)
        self.last_fig.savefig(f"{directory}/{fname}.pdf", format="pdf", bbox_inches="tight")

        print(f"Plot saved as {fname}.pdf")

    def display(self):
        """Show the class selector with its plot, then the file-name box and save button."""
        # Use interact to update the plot based on dropdown selection
        widgets.interact(self.plot_f_scores, selected_class=self.dropdown)
        # Display the filename textbox and save button below the plot widget
        display(self.filename_text, self.save_button)

In [29]:
data_name = 'zoo'
data_path = f'processed_data/{data_name}.pt'
# NOTE(review): weights_only=False unpickles arbitrary Python objects, so only
# load dataset files from a trusted source.
data = torch.load(data_path, weights_only=False)

model_dict = 'models/zoo_HGNAM_2732_best_val_acc.pt'
model = HGNAM(in_channels=16, hidden_channels=256, num_layers=5, out_channels=7, dropout=0.0, limited_m=0, bias=True, normalize_m=1, weight=True, aggregation='neighbor')
# The model is only inspected here, never trained, so switch to inference mode.
model.eval()
# The model is never moved off the CPU, so map the checkpoint to CPU as well;
# this keeps the cell runnable on machines without a GPU. A state dict holds
# only tensors, so the restricted (weights_only=True) unpickler is sufficient.
model.load_state_dict(torch.load(model_dict, map_location=torch.device('cpu'), weights_only=True))

<All keys matched successfully>

In [30]:
# Class id -> display name; the list position is the class id.
class_names = dict(enumerate([
    'Mammals',
    'Birds',
    'Reptiles',
    'Fish',
    'Amphibians',
    'Insects',
    'Invertebrates',
]))

# Build the interactive viewer on the feature matrix and show its widgets.
plotter = FScorePlotter(
    model,
    data.x,
    data_name,
    feature_names,
    class_names,
)
plotter.display()

interactive(children=(Dropdown(description='Class:', options=(('0: Mammals', 0), ('1: Birds', 1), ('2: Reptile…

Text(value='', description='File Name:', placeholder='Enter file name')

Button(description='Save Plot', icon='save', style=ButtonStyle(), tooltip='Click to save the current plot')