In [1]:
import logging
from argparse import ArgumentParser

import torch
import torch.nn.functional as F
# import wandb
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader
from tqdm import trange

from experiments.data import INRDataset
from experiments.utils import (
    common_parser,
    count_parameters,
    get_device,
    set_logger,
    set_seed,
    str2bool,
)
from nn.models import DWSModelForClassification, MLPModelForClassification , DWSModel

from experiments.mnist.generate_data_splits import generate_splits
from experiments.mnist.compute_statistics import compute_stats

set_logger()

In [2]:
torch.cuda.empty_cache()

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.enabled = False

In [4]:
# generate_splits(data_path="notebooks/dataset/mnist-inrs", save_path="dataset")

In [5]:
# compute_stats(data_path="notebooks/dataset/mnist_splits.json", save_path="dataset", batch_size=1024)

## INR Dataset

We create INR Datasets and Dataloaders, and visualize some INRs (by reconstruction the images).



In [6]:
#Loading inr data we created while mnist training
import os
current_working_directory = os.getcwd()
print(current_working_directory)
path = current_working_directory + "/notebooks/dataset/mnist_splits.json"
statistics_path = current_working_directory + "/notebooks/dataset/statistics.pth"
normalize = True
augmentation = True

batch_size = 64
num_workers = 1

/work/talisman/sgupta/DWSNets/equivariant-diffusion


In [7]:
# from torchvision.utils import save_image, make_grid
# import torch

# from experiments.data import INRImageDataset
# from experiments.utils import set_seed
# import matplotlib.pyplot as plt
# dataset = INRImageDataset(
#     path=path,  # path to splits json file
#     augmentation=False,
#     split="train",
# )

# loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False)

# dataset_aug = INRImageDataset(
#     path=path,  # path to splits json file
#     augmentation=True,
#     split="train",
# )
# loader_aug = torch.utils.data.DataLoader(dataset_aug, batch_size=64, shuffle=False)

# batch = next(iter(loader))
# batch_aug = next(iter(loader_aug))

# fig, axs = plt.subplots(1, 2, figsize=(10,20)) 

# axs[0].imshow(make_grid(batch.image.squeeze(-1)).permute(1, 2, 0).clip(0, 1))
# axs[0].set_title('Recunstracted Images from INRs')

# axs[1].imshow(make_grid(batch_aug.image.squeeze(-1)).permute(1, 2, 0).clip(0, 1))
# axs[1].set_title('Recunstracted Images from Augmented INRs')
# plt.show()

In [8]:
train_set = INRDataset(
        path=path,
        split="train",
        normalize=normalize,
        augmentation=augmentation,
        statistics_path=statistics_path,
    )

val_set = INRDataset(
    path=path,
    split="val",
    normalize=normalize,
    statistics_path=statistics_path,
)

test_set = INRDataset(
    path=path,
    split="test",
    normalize=normalize,
    statistics_path=statistics_path,
)

train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
val_loader = torch.utils.data.DataLoader(
    dataset=val_set,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
)

logging.info(
    f"train size {len(train_set)}, "
    f"val size {len(val_set)}, "
    f"test size {len(test_set)}"
)

  self.stats = torch.load(statistics_path, map_location="cpu")
2024-10-15 09:42:26,841 - root - INFO - train size 55000, val size 5000, test size 10000


In [9]:
point = train_set.__getitem__(4)
weight_shapes = tuple(w.shape[:2] for w in point.weights)
bias_shapes = tuple(b.shape[:1] for b in point.biases)
print(weight_shapes,bias_shapes)
# print(point.weights,point.biases,point.label)

(torch.Size([2, 32]), torch.Size([32, 32]), torch.Size([32, 1])) (torch.Size([32]), torch.Size([32]), torch.Size([1]))


  state_dict = torch.load(path, map_location=lambda storage, loc: storage)


In [10]:
new_weight_shapes = (torch.Size([2, 8]), torch.Size([8, 8]), torch.Size([8, 1]))
new_bias_shapes = (torch.Size([8]), torch.Size([8]), torch.Size([1]))

In [11]:
from typing import Tuple
import torch
import torch.nn as nn
from nn.layers import BN, DWSLayer,InvariantLayer, Dropout, ReLU
from nn.layers.base import BaseLayer,GeneralSetLayer

class DWSEncoder(BaseLayer):
    def __init__(
       self,
        weight_shapes,
        bias_shapes,
        input_features,
        hidden_dims,
        downsample_dim,
        n_hidden=2,
        reduction="max",
        bias=True,
        n_fc_layers=1,
        num_heads=4,
        set_layer="sab",
        add_layer_skip=False,
        input_dim_downsample=None,
        init_scale=1.,
        init_off_diag_scale_penalty=1.,
        bn=False,
        dropout_rate = 0.001,
        diagonal=False,
    ):
        super().__init__(
            in_features=input_features,
            out_features=hidden_dims,
            bias=bias,
            reduction=reduction,
            n_fc_layers=n_fc_layers,
            num_heads=num_heads,
            set_layer=set_layer,
        )

        assert len(weight_shapes) > 2, "The implementation only supports networks with more than 2 layers."

        self.downsample_dim = downsample_dim
        self.bias = bias
        self.n_fc_layers = n_fc_layers
        self.num_heads = num_heads
        
        self.skip = self._get_mlp(
            in_features=input_features,
            out_features=input_features,
            bias=bias,
        )
        
        self.InitialLayer = DWSModel(
            weight_shapes= weight_shapes,
            bias_shapes= bias_shapes,
            input_features=input_features,
            hidden_dim=hidden_dims,
            n_hidden=n_hidden,
            reduction=reduction,
            bias=bias,
            output_features=input_features,
            n_fc_layers=n_fc_layers,
            num_heads=num_heads,
            set_layer=set_layer,
            dropout_rate=dropout_rate,
            input_dim_downsample=input_dim_downsample,
            init_scale=init_scale,
            init_off_diag_scale_penalty=init_off_diag_scale_penalty,
            bn=bn,
            add_skip=False,
            add_layer_skip=add_layer_skip,
            diagonal=diagonal,
        )   

    def downsample_input_weights(self, inputs, downsample_dim):
        """Downsample the input weights to the specified dimensions."""
        inputs = list(inputs)

        # Downsample first weight dimension [32,2,32,1] -> [32,2,8,1]
        inputs[0] = self._downsample_weight(inputs[0], dim=2, downsample_dim = downsample_dim)
        # Downsample second weight dimension [32,32,32,1] -> [32,8,8,1]
        inputs[1] = self._downsample_weight(inputs[1], dim=1, downsample_dim = downsample_dim)
        inputs[1] = self._downsample_weight(inputs[1], dim=2, downsample_dim = downsample_dim)

        # Downsample third weight dimension [32,32,1,1] -> [32,8,1,1]
        inputs[2] = self._downsample_weight(inputs[2], dim=1, downsample_dim = downsample_dim)

        return tuple(inputs)

    def downsample_input_biases(self, inputs, downsample_dim):
        """Downsample the input biases to the specified dimensions."""
        inputs = list(inputs)

        # Downsample first bias dimension [32,32,1] -> [32,8,1]
        inputs[0] = self._downsample_bias(inputs[0], downsample_dim= downsample_dim)

        # Downsample second bias dimension [32,32,1] -> [32,8,1]
        inputs[1] = self._downsample_bias(inputs[1], downsample_dim = downsample_dim)

        return tuple(inputs)
    
    def batchNormLayer(self,weights, biases):
        relu = nn.ReLU()
        weights = tuple(relu(nn.BatchNorm2d(w.shape[1]).to(device)(w)) for w in weights)
        biases = tuple(relu(nn.BatchNorm1d(b.shape[1]).to(device)(b.squeeze(-1)).unsqueeze(-1)) for b in biases)
        return weights, biases

    def forward(self, x: Tuple[Tuple[torch.tensor], Tuple[torch.tensor]]):
        """Forward pass through the encoder."""
        x = self.InitialLayer(x)
        weights = self.downsample_input_weights(x[0], 24)
        biases = self.downsample_input_biases(x[1], 24)
        weights= (weights[0] + self.skip(weights[0]), weights[1] + self.skip(weights[1]) ,weights[2] + self.skip(weights[2]))
        biases = (biases[0] + self.skip(biases[0]) , biases[1] + self.skip(biases[1]) ,biases[2] + self.skip(biases[2]))
        weights, biases = self.batchNormLayer( weights, biases)
        weights = self.downsample_input_weights(weights, 16)
        biases = self.downsample_input_biases(biases, 16)
        weights= (weights[0] + self.skip(weights[0]), weights[1] + self.skip(weights[1]) ,weights[2] + self.skip(weights[2]))
        biases = (biases[0] + self.skip(biases[0]) , biases[1] + self.skip(biases[1]) ,biases[2] + self.skip(biases[2]))
        weights, biases = self.batchNormLayer( weights, biases)
        weights = self.downsample_input_weights(weights, self.downsample_dim)
        biases = self.downsample_input_biases(biases, self.downsample_dim)
        weights= (weights[0] + self.skip(weights[0]), weights[1] + self.skip(weights[1]) ,weights[2] + self.skip(weights[2]))
        biases = (biases[0] + self.skip(biases[0]) , biases[1] + self.skip(biases[1]) ,biases[2] + self.skip(biases[2]))
        out = (weights, biases)
        return out

    def _downsample_weight(self, weight, dim,downsample_dim):
        d0 = weight.shape[dim]
        down_sample = GeneralSetLayer(
            in_features=d0,
            out_features= downsample_dim,
            reduction="attn",
            bias=self.bias,
            n_fc_layers=self.n_fc_layers,
            num_heads=self.num_heads,
            set_layer="ds",
        ).to(device)
        
        wi = weight.permute(0, 3, 1, 2) if dim == 2 else weight.permute(0, 3, 2, 1)
        wi = down_sample(wi)
        return wi.permute(0, 2, 3, 1) if dim == 2 else wi.permute(0, 3, 2, 1)

    def _downsample_bias(self, bias, downsample_dim):
        d0 = bias.shape[1]
        down_sample = GeneralSetLayer(
            in_features=d0,
            out_features=downsample_dim,
            reduction="attn",
            bias=self.bias,
            n_fc_layers=self.n_fc_layers,
            num_heads=self.num_heads,
            set_layer="ds",
        ).to(device)
        
        wi = bias.permute(0, 2, 1)
        wi = down_sample(wi)
        return wi.permute(0, 2, 1)

In [12]:
class ResidualBlock(nn.Module):
    def __init__(self, input_dim, output_dim=None):
        super(ResidualBlock, self).__init__()
        if output_dim is None:
            output_dim = input_dim 

        self.fc1 = nn.Linear(input_dim, output_dim)
        self.fc2 = nn.Linear(output_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        x = self.relu(self.fc1(x))  # First linear transformation + ReLU
        x = self.fc2(x)  # Second linear transformation
        return self.relu(x + residual)

class BiasResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return self.relu(x + residual)

In [13]:
# import torch
# import torch.nn as nn

# class Decoder(nn.Module):
#     def __init__(self):
#         super(Decoder, self).__init__()

#         # For Weight Reconstruction
#         self.fc_weight_1 = nn.Sequential(
#             nn.Linear(8 * 2 * 1, 16 * 2 * 1),  # Assuming input size [batch, 8, 1, 1], adjust for your dimensions
#             ResidualBlock(32),
#             nn.Linear(16 * 2 * 1, 24 * 2 * 1),
#             ResidualBlock(48),
#             nn.Linear(24 * 2 * 1, 32 * 2 * 1),
#         )
#         self.fc_weight_21 = nn.Sequential(
#             nn.Linear(8 * 8 * 1, 16 * 8 * 1),
#             ResidualBlock(128),
#             nn.Linear(16 * 8 * 1, 24 * 8 * 1),
#             ResidualBlock(192),
#             nn.Linear(24 * 8 * 1, 32 * 8 * 1),
#         )
#         self.fc_weight_22 = nn.Sequential(
#             nn.Linear(8 * 32 * 1, 16 * 32 * 1),
#             ResidualBlock(512),
#             nn.Linear(16 * 32 * 1, 24 * 32 * 1),
#             ResidualBlock(768),
#             nn.Linear(24 * 32 * 1, 32 * 32 * 1),
#         )
#         self.fc_weight_3 = nn.Sequential(
#             nn.Linear(8 * 1 * 1, 16 * 1 * 1),
#             ResidualBlock(16),
#             nn.Linear(16 * 1 * 1, 24 * 1 * 1),
#             ResidualBlock(24),
#             nn.Linear(24 * 1 * 1, 32 * 1 * 1),
#         )

#         # For Bias Reconstruction
#         self.fc_bias_1 = nn.Sequential(
#             nn.Linear(8, 16),
#             BiasResidualBlock(16),
#             nn.Linear(16, 24),
#             BiasResidualBlock(24),
#             nn.Linear(24, 32),
#         )
#         self.fc_bias_2 = nn.Sequential(
#             nn.Linear(8, 16),
#             BiasResidualBlock(16),
#             nn.Linear(16, 24),
#             BiasResidualBlock(24),
#             nn.Linear(24, 32),
#         )
#         self.fc_bias_3 = nn.Sequential(
#             nn.Linear(1, 1),
#             BiasResidualBlock(1),
#             nn.Linear(1, 1),
#         )

#     def forward(self, x):
#         # Decode weight tensors
#         weight_space = x[0]
#         bias_space = x[1]
#         new_weight_space = []

#         # Flatten and apply linear layers for weight reconstruction
#         weight1_flat = weight_space[0].view(weight_space[0].size(0), -1)  # Flatten
#         new_weight_space.append(self.fc_weight_1(weight1_flat).view(weight_space[0].size(0), 2, 32, 1))

#         weight2_flat = weight_space[1].view(weight_space[1].size(0), -1)
#         weight2_intermediate = self.fc_weight_21(weight2_flat)
#         new_weight_space.append(self.fc_weight_22(weight2_intermediate).view(weight_space[1].size(0), 32, 32, 1))

#         weight3_flat = weight_space[2].view(weight_space[2].size(0), -1)
#         new_weight_space.append(self.fc_weight_3(weight3_flat).view(weight_space[2].size(0), 32, 1, 1))

#         new_bias_space = []
#         # Decode bias tensors
#         new_bias_space.append(self.fc_bias_1(bias_space[0].squeeze(-1)).view(-1, 32, 1))
#         new_bias_space.append(self.fc_bias_2(bias_space[1].squeeze(-1)).view(-1, 32, 1))
#         new_bias_space.append(self.fc_bias_3(bias_space[2].squeeze(-1)).view(-1, 1, 1))
        
#         return (tuple(new_weight_space), tuple(new_bias_space))

In [14]:
# import torch
# import torch.nn as nn

# class Decoder(nn.Module):
#     def __init__(self, upsample_dim,
#                  bias=True,
#                  n_fc_layers=1,
#                  num_heads=4):
        
#         super(Decoder, self).__init__()
        
#         self.fc_weight_1 = nn.Sequential(
#             nn.Linear(32 * 2 * 1, 16 * 2 * 1),  # Assuming input size [batch, 8, 1, 1], adjust for your dimensions
#             nn.ReLU(),
#             nn.Linear(16 * 2 * 1, 32 * 2 * 1),
#         )
        
#         self.fc_weight_2 = nn.Sequential(
#             nn.Linear(32 * 32 * 1, 16 * 32 * 1),
#             nn.ReLU(),
#             nn.Linear(16 * 32 * 1, 32 * 32 * 1),
#         )
#         self.fc_weight_3 = nn.Sequential(
#             nn.Linear(32 * 1 * 1, 16 * 1 * 1),
#             nn.ReLU(),
#             nn.Linear(16 * 1 * 1, 32 * 1 * 1),
#         )

#         # For Bias Reconstruction
#         self.fc_bias_1 = nn.Sequential(
#             nn.Linear(32, 16),
#             nn.ReLU(),
#             nn.Linear(16, 32),
#         )
#         self.fc_bias_2 = nn.Sequential(
#             nn.Linear(32, 16),
#             nn.ReLU(),
#             nn.Linear(16, 32),
#         )
#         self.fc_bias_3 = nn.Sequential(
#             nn.Linear(1, 1),
#             nn.ReLU(),
#             nn.Linear(1, 1),
#         )
        
#         self.upsample_dim = upsample_dim
#         self.bias = bias
#         self.n_fc_layers = n_fc_layers
#         self.num_heads = num_heads
        
#     def upsample_input_weights(self, inputs):
#         """Upsample the input weights to the specified dimensions."""
#         inputs = list(inputs)

#         # Downsample first weight dimension [32,2,8,1] -> [32,2,32,1]
#         inputs[0] = self._upsample_weight(inputs[0], dim=2 , index = 0)

#         # Downsample second weight dimension [32,8,8,1] -> [32,32,32,1]
#         inputs[1] = self._upsample_weight(inputs[1], dim=1, index = 1)
#         inputs[1] = self._upsample_weight(inputs[1], dim=2, index = 1)


#         # Downsample third weight dimension [32,8,1,1] -> [32,32,1,1]
#         inputs[2] = self._upsample_weight(inputs[2], dim=1, index = 2)

#         return tuple(inputs)

#     def upsample_input_biases(self, inputs):
#         """Upsample the input biases to the specified dimensions."""
#         inputs = list(inputs)

#         # Downsample first bias dimension [32,8,1] -> [32,32,1]
#         inputs[0] = self._upsample_bias(inputs[0], index = 0)

#         # Downsample second bias dimension [32,8,1] -> [32,32,1]
#         inputs[1] = self._upsample_bias(inputs[1], index = 1)

#         return tuple(inputs)
    
#     def _upsample_weight(self, weight, dim, index):
#         d0 = weight.shape[dim]
#         up_sample = GeneralSetLayer(
#             in_features=d0,
#             out_features=self.upsample_dim,
#             reduction="attn",
#             bias=self.bias,
#             n_fc_layers=self.n_fc_layers,
#             num_heads=self.num_heads,
#             set_layer="ds",
#         ).to(device)
        
#         wi = weight.permute(0, 3, 1, 2) if dim == 2 else weight.permute(0, 3, 2, 1)
#         wi = up_sample(wi)
#         wi = wi.permute(0, 2, 3, 1) if dim == 2 else wi.permute(0, 3, 2, 1)
#         return wi 

#     def _upsample_bias(self, bias, index):
#         d0 = bias.shape[1]
#         up_sample = GeneralSetLayer(
#             in_features=d0,
#             out_features=self.upsample_dim,
#             reduction="attn",
#             bias=self.bias,
#             n_fc_layers=self.n_fc_layers,
#             num_heads=self.num_heads,
#             set_layer="ds",
#         ).to(device)
        
#         bi = bias.permute(0, 2, 1)
#         bi = up_sample(bi)
#         bi = bi.permute(0, 2, 1)
#         return bi


#     def forward(self, x):
#         # Decode weight tensors
#         weight_space = self.upsample_input_weights(x[0])
#         bias_space = self.upsample_input_biases(x[1])
        
#         # Process weight input 1 (32, 2, 32, 1)
#         x1 = weight_space[0].view(weight_space[0].size(0), -1) 
#         x1 = self.fc_weight_1(x1)
#         x1 = x1.view(weight_space[0].size())
        
#         # Process weight input 2 (32, 32, 32, 1)
#         x2 = weight_space[1].view(weight_space[1].size(0), -1) 
#         x2 = self.fc_weight_2(x2)
#         x2 = x2.view(weight_space[1].size())  # Output similar to (32, 32, 32, 1)
        
#         # Process weight input 3 (32, 32, 1, 1)
#         x3 = weight_space[2].view(weight_space[2].size(0), -1) 
#         x3 = self.fc_weight_3(x3)
#         x3 = x3.view(weight_space[2].size())  # Output similar to (32, 32, 1, 1)
        
#         # Process bias input 1 (32, 32, 1)
#         b1 = bias_space[0].view(bias_space[0].size(0), -1)  # Flatten (32, 32, 1) to (32, 32)
#         b1 = self.fc_bias_1(b1)
#         b1 = b1.view(bias_space[0].size())  # Reshape back to (32, 32, 1)
        
#         # Process bias input 2 (32, 1, 1)
#         b2 = bias_space[1].view(bias_space[1].size(0), -1)  # Flatten (32, 32, 1) to (32, 32)
#         b2 = self.fc_bias_2(b2)
#         b2 = b2.view(bias_space[1].size())  # Reshape back to (32, 1, 1)
        
#         # Process bias input 3 (32, 1, 1)
#         b3 = bias_space[2].view(bias_space[2].size(0), -1)  # Flatten (32, 1, 1) to (32, 1)
#         b3 = self.fc_bias_3(b3)  # Using the same fc2 layer as bias_input2, or define another if needed
#         b3 = b3.view(bias_space[2].size())  # Reshape back to (32, 1, 1)
        
#         new_weight_space = (x1, x2, x3)
#         new_bias_space = (b1, b2, b3)

#         return (new_weight_space, new_bias_space)

In [15]:
import torch
import torch.nn as nn

class Decoder(nn.Module,
             ):
    def __init__(self, upsample_dim,
                weight_shapes,
                bias_shapes,
                input_features,
                hidden_dims,
                downsample_dim,
                n_hidden=2,
                reduction="max",
                set_layer="sab",
                add_layer_skip=False,
                input_dim_downsample=None,
                init_scale=1.,
                init_off_diag_scale_penalty=1.,
                bn=False,
                dropout_rate = 0.001,
                diagonal=False,
                 bias=True,
                 n_fc_layers=1,
                 num_heads=4):
        
        super(Decoder, self).__init__()
        
        self.fc_weight_1 = nn.Sequential(
            nn.Linear(32 * 2 * 1, 16 * 2 * 1),  # Assuming input size [batch, 8, 1, 1], adjust for your dimensions
            nn.ReLU(),
            nn.Linear(16 * 2 * 1, 32 * 2 * 1),
        )
        
        self.fc_weight_2 = nn.Sequential(
            nn.Linear(32 * 32 * 1, 16 * 32 * 1),
            nn.ReLU(),
            nn.Linear(16 * 32 * 1, 32 * 32 * 1),
        )
        self.fc_weight_3 = nn.Sequential(
            nn.Linear(32 * 1 * 1, 16 * 1 * 1),
            nn.ReLU(),
            nn.Linear(16 * 1 * 1, 32 * 1 * 1),
        )

        # For Bias Reconstruction
        self.fc_bias_1 = nn.Sequential(
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
        )
        self.fc_bias_2 = nn.Sequential(
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
        )
        self.fc_bias_3 = nn.Sequential(
            nn.Linear(1, 1),
            nn.ReLU(),
            nn.Linear(1, 1),
        )
        
        self.upsample_dim = upsample_dim
        self.bias = bias
        self.n_fc_layers = n_fc_layers
        self.num_heads = num_heads
        
        self.InitialLayer = DWSModel(
            weight_shapes= weight_shapes,
            bias_shapes= bias_shapes,
            input_features=input_features,
            hidden_dim=hidden_dims,
            n_hidden=n_hidden,
            reduction=reduction,
            bias=bias,
            output_features=input_features,
            n_fc_layers=n_fc_layers,
            num_heads=num_heads,
            set_layer=set_layer,
            dropout_rate=dropout_rate,
            input_dim_downsample=input_dim_downsample,
            init_scale=init_scale,
            init_off_diag_scale_penalty=init_off_diag_scale_penalty,
            bn=bn,
            add_skip=False,
            add_layer_skip=add_layer_skip,
            diagonal=diagonal,
        )   
        
    def upsample_input_weights(self, inputs):
        """Upsample the input weights to the specified dimensions."""
        inputs = list(inputs)

        # Downsample first weight dimension [32,2,8,1] -> [32,2,32,1]
        inputs[0] = self._upsample_weight(inputs[0], dim=2 , index = 0)

        # Downsample second weight dimension [32,8,8,1] -> [32,32,32,1]
        inputs[1] = self._upsample_weight(inputs[1], dim=1, index = 1)
        inputs[1] = self._upsample_weight(inputs[1], dim=2, index = 1)


        # Downsample third weight dimension [32,8,1,1] -> [32,32,1,1]
        inputs[2] = self._upsample_weight(inputs[2], dim=1, index = 2)

        return tuple(inputs)

    def upsample_input_biases(self, inputs):
        """Upsample the input biases to the specified dimensions."""
        inputs = list(inputs)

        # Downsample first bias dimension [32,8,1] -> [32,32,1]
        inputs[0] = self._upsample_bias(inputs[0], index = 0)

        # Downsample second bias dimension [32,8,1] -> [32,32,1]
        inputs[1] = self._upsample_bias(inputs[1], index = 1)

        return tuple(inputs)
    
    def _upsample_weight(self, weight, dim, index):
        d0 = weight.shape[dim]
        up_sample = GeneralSetLayer(
            in_features=d0,
            out_features=self.upsample_dim,
            reduction="attn",
            bias=self.bias,
            n_fc_layers=self.n_fc_layers,
            num_heads=self.num_heads,
            set_layer="ds",
        ).to(device)
        
        wi = weight.permute(0, 3, 1, 2) if dim == 2 else weight.permute(0, 3, 2, 1)
        wi = up_sample(wi)
        wi = wi.permute(0, 2, 3, 1) if dim == 2 else wi.permute(0, 3, 2, 1)
        return wi 

    def _upsample_bias(self, bias, index):
        d0 = bias.shape[1]
        up_sample = GeneralSetLayer(
            in_features=d0,
            out_features=self.upsample_dim,
            reduction="attn",
            bias=self.bias,
            n_fc_layers=self.n_fc_layers,
            num_heads=self.num_heads,
            set_layer="ds",
        ).to(device)
        
        bi = bias.permute(0, 2, 1)
        bi = up_sample(bi)
        bi = bi.permute(0, 2, 1)
        return bi


    def forward(self, x):
        # Decode weight tensors
        weight_space = self.upsample_input_weights(x[0])
        bias_space = self.upsample_input_biases(x[1])

        output = self.InitialLayer((weight_space , bias_space))
        # Process weight input 1 (32, 2, 32, 1)
#         x1 = weight_space[0].reshape(weight_space[0].size(0), -1) 
#         x1 = self.fc_weight_1(x1)
#         x1 = x1.reshape(weight_space[0].size())
        
#         # Process weight input 2 (32, 32, 32, 1)
#         x2 = weight_space[1].reshape(weight_space[1].size(0), -1) 
#         x2 = self.fc_weight_2(x2)
#         x2 = x2.reshape(weight_space[1].size())  # Output similar to (32, 32, 32, 1)
        
#         # Process weight input 3 (32, 32, 1, 1)
#         x3 = weight_space[2].reshape(weight_space[2].size(0), -1) 
#         x3 = self.fc_weight_3(x3)
#         x3 = x3.reshape(weight_space[2].size())  # Output similar to (32, 32, 1, 1)
        
#         # Process bias input 1 (32, 32, 1)
#         b1 = bias_space[0].reshape(bias_space[0].size(0), -1)  # Flatten (32, 32, 1) to (32, 32)
#         b1 = self.fc_bias_1(b1)
#         b1 = b1.reshape(bias_space[0].size())  # Reshape back to (32, 32, 1)
        
#         # Process bias input 2 (32, 1, 1)
#         b2 = bias_space[1].reshape(bias_space[1].size(0), -1)  # Flatten (32, 32, 1) to (32, 32)
#         b2 = self.fc_bias_2(b2)
#         b2 = b2.reshape(bias_space[1].size())  # Reshape back to (32, 1, 1)
        
#         # Process bias input 3 (32, 1, 1)
#         b3 = bias_space[2].reshape(bias_space[2].size(0), -1)  # Flatten (32, 1, 1) to (32, 1)
#         b3 = self.fc_bias_3(b3)  # Using the same fc2 layer as bias_input2, or define another if needed
#         b3 = b3.reshape(bias_space[2].size())  # Reshape back to (32, 1, 1)
        
#         new_weight_space = (x1, x2, x3)
#         new_bias_space = (b1, b2, b3)

        return output

In [16]:
# Our AutoEncoder using DWSModel
import numpy as np
class AutoEncoder(nn.Module):
    def __init__(self,
            input_features,
            weight_shapes,
            bias_shapes,
            hidden_dims,
            downsample_dim,
            n_hidden=2,
            reduction = "attn",
            input_dim_downsample=None,
            bn = False,
    ):
        super().__init__()
        self.encoder = DWSEncoder(weight_shapes=weight_shapes,
                                bias_shapes=bias_shapes,
                                input_features=input_features,
                                hidden_dims=hidden_dims,
                                downsample_dim = downsample_dim,
                                n_hidden=n_hidden,
                                reduction= reduction,
                                bn=bn).to(device)
        self.decoder = Decoder(upsample_dim= 32,weight_shapes=weight_shapes,
                                bias_shapes=bias_shapes,
                                input_features=input_features,
                                hidden_dims=hidden_dims,
                                downsample_dim = downsample_dim,
                                n_hidden=n_hidden,
                                reduction= reduction,
                                bn=bn).to(device)
        
    def forward(self,inputs):
        encoded_data = self.encoder(inputs)
        output = self.decoder(encoded_data)
        return encoded_data,output

Training MNIST

In [17]:
# import warnings
# warnings.filterwarnings("ignore")

# @torch.no_grad()
# def evaluate(model, loader):
#     model.eval()
#     loss = 0.0
#     correct = 0.0
#     total = 0.0
#     predicted, gt = [], []
#     for batch in loader:
#         batch = batch.to(device)
#         inputs = (batch.weights, batch.biases)
#         out = model(inputs)
#         loss += F.cross_entropy(out, batch.label, reduction="sum")
#         total += len(batch.label)
#         pred = out.argmax(1)
#         correct += pred.eq(batch.label).sum()
#         predicted.extend(pred.cpu().numpy().tolist())
#         gt.extend(batch.label.cpu().numpy().tolist())

#     model.train()
#     avg_loss = loss / total
#     avg_acc = correct / total

#     return dict(avg_loss=avg_loss, avg_acc=avg_acc, predicted=predicted, gt=gt)

# model = AutoEncoder(
#     input_features=1,
#     weight_shapes = weight_shapes, 
#     bias_shapes = bias_shapes,
#     downsample_dim = 8,
#     hidden_dims=32,
#     reduction = "max",
#     n_hidden=8,
#     bn=False,
# ).to(device)

# optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3, amsgrad=True, weight_decay=5e-4)
# epochs = 5

# epoch_iter = trange(epochs)
# criterion = nn.CrossEntropyLoss()

# for epoch in epoch_iter:
#     for i, batch in enumerate(train_loader):
#         model.train()
#         optimizer.zero_grad()

#         batch = batch.to(device)
#         inputs = (batch.weights, batch.biases)
#         out = model(inputs)

#         loss = criterion(out, batch.label)
#         loss.backward()
#         optimizer.step()

#         epoch_iter.set_description(  
#             f"[{epoch} {i+1}], train loss: {loss.item():.3f}"
#         )
#     test_loss_dict = evaluate(model, test_loader)
#     test_acc = test_loss_dict['avg_acc'].item()
#     print(f"test accuracy:{test_acc:.4f}")
#     torch.save(model.state_dict(), f"Outputs/model_epoch_{epoch}.pth")

Training Autoencoder

In [18]:
def sinkhorn_loss(preds, targets, epsilon=0.1, num_iters=100):
   
    # Determine the dimensions of preds and targets
    preds_dims = preds.dim()
    targets_dims = targets.dim()
    
    if preds_dims == 4:  # preds is of shape (batch_size, n_sets, n_points_pred, dim)
        batch_size, n_sets, n_points_pred, _ = preds.shape
        n_points_target = targets.shape[2]
        
        # Compute the pairwise distance matrix (cost_matrix) between preds and targets
        cost_matrix = torch.cdist(preds.reshape(batch_size * n_sets, n_points_pred, -1), 
                                   targets.reshape(batch_size * n_sets, n_points_target, -1), 
                                   p=2).reshape(batch_size, n_sets, n_points_pred, n_points_target)

        # Initialize uniform marginals for both distributions
        mu = torch.full((batch_size, n_sets, n_points_pred), 1.0 / n_points_pred, device=preds.device)
        nu = torch.full((batch_size, n_sets, n_points_target), 1.0 / n_points_target, device=targets.device)
        
    elif preds_dims == 3:  # preds is of shape (batch_size, n_points_pred, dim)
        batch_size, n_points_pred, _ = preds.shape
        n_points_target = targets.shape[1]
        
        # Compute the pairwise distance matrix (cost_matrix) between preds and targets
        cost_matrix = torch.cdist(preds.reshape(batch_size, n_points_pred, -1), 
                                   targets.reshape(batch_size, n_points_target, -1), 
                                   p=2).reshape(batch_size, 1, n_points_pred, n_points_target)

        # Initialize uniform marginals for both distributions
        mu = torch.full((batch_size, 1, n_points_pred), 1.0 / n_points_pred, device=preds.device)
        nu = torch.full((batch_size, 1, n_points_target), 1.0 / n_points_target, device=targets.device)
    
    else:
        raise ValueError("Input tensors must be either 3D or 4D.")

    # Initialize dual variables (log scale to ensure positivity)
    u = torch.zeros_like(mu)
    v = torch.zeros_like(nu)

    # Perform Sinkhorn iterations
    for _ in range(num_iters):
        u = epsilon * (torch.log(mu + 1e-8) - torch.logsumexp(-cost_matrix / epsilon + v.unsqueeze(-2), dim=-1))
        v = epsilon * (torch.log(nu + 1e-8) - torch.logsumexp(-cost_matrix / epsilon + u.unsqueeze(-1), dim=-2))

    # Compute the optimal transport cost
    transport_plan = torch.exp((-cost_matrix + u.unsqueeze(-1) + v.unsqueeze(-2)) / epsilon)
    loss = torch.sum(transport_plan * cost_matrix, dim=[2, 3]).mean()
    
    return loss

# Chamfer Loss implementation
def chamfer_loss(x, y):
    x = x.unsqueeze(1)  # (n, 1, d)
    y = y.unsqueeze(0)  # (1, m, d)

    dist = torch.norm(x - y, dim=2)  # Shape (n, m)

    min_dist_x = dist.min(dim=1)[0]  # Min distance for each point in x
    min_dist_y = dist.min(dim=0)[0]  # Min distance for each point in y

    loss = min_dist_x.mean() + min_dist_y.mean()
    return loss

# Set Mean Squared Error implementation
def set_mse_loss(x, y):
    x_expanded = x.unsqueeze(1)  # (n, 1, d)
    y_expanded = y.unsqueeze(0)  # (1, m, d)

    squared_diff = (x_expanded - y_expanded) ** 2  # Shape (n, m, d)

    mse = squared_diff.mean(dim=-1)  # Shape (n, m)

    min_mse_x = mse.min(dim=1)[0]  # Min MSE for each point in x
    min_mse_y = mse.min(dim=0)[0]  # Min MSE for each point in y

    loss = (min_mse_x.mean() + min_mse_y.mean()) / 2
    return loss

def hungarian_loss(preds, targets):
    # Calculate the cost matrix (L2 distance between preds and targets)
    cost_matrix = torch.cdist(preds.unsqueeze(1), targets.unsqueeze(1), p=2).squeeze(1)

    # Solve the assignment problem
    row_ind, col_ind = torch.min(cost_matrix, dim=1)

    # Ensure indices are long tensors (required for indexing)
    row_ind = row_ind.long()
    col_ind = col_ind.long()

    # Calculate the total loss
    loss = cost_matrix[row_ind, col_ind].mean()

    return loss


def rotation_matrix(theta):
    """
    Create a 2D rotation matrix for a given angle theta (in radians).
    theta: float or tensor
    Returns a 2x2 tensor rotation matrix.
    """
    theta_tensor = torch.tensor(theta)  # Convert theta to tensor
    return torch.tensor([
        [torch.cos(theta_tensor), -torch.sin(theta_tensor)],
        [torch.sin(theta_tensor), torch.cos(theta_tensor)]
    ])

# Define the MSE loss with rotation
def mse_loss_with_rotation(x, y, num_rotations=36):
    """
    Compute the MSE reconstruction loss with discrete rotations applied.
    
    Args:
    - x (torch.Tensor): Original input sample (e.g., MNIST image).
    - y (torch.Tensor): Reconstructed output from the autoencoder.
    - num_rotations (int): Number of discrete rotations to consider (e.g., 36 means 10-degree steps).
    
    Returns:
    - loss (torch.Tensor): The minimum MSE loss across all rotations.
    """
    batch_size = x.size(0)
    
    # Initialize the minimum MSE as a large value
    min_mse = torch.full((batch_size,), float('inf'), device=x.device)
    
    for i in range(num_rotations):
        # Compute the rotation angle (in radians)
        theta = i * (2 * torch.pi / num_rotations)
        
        # Apply rotation to the reconstructed output y
        rot_matrix = rotation_matrix(theta).to(x.device)
        
        # Assume inputs have at least two dimensions (for rotation)
        if len(y.shape) >= 3 and y.shape[-1] == 2:  # Check if it's a 2D vector
            y_rot = torch.einsum('ij,bjk->bik', rot_matrix, y)  # Apply rotation to each sample
        else:
            # Skip rotation if it's not a 2D vector
            y_rot = y
        
        # Compute the MSE loss between the original input x and the rotated reconstruction y_rot
        mse = F.mse_loss(x, y_rot, reduction='none').mean(dim=[1])  # Mean over spatial dimensions
        
        # Track the minimum MSE over all rotations
        min_mse = torch.min(min_mse, mse)
    
    return min_mse.mean()

# Custom loss function for tuples of tensors
class TupleLoss(nn.Module):
    def __init__(self, loss_type='hungarian_loss'):
        super(TupleLoss, self).__init__()
        self.loss_type = loss_type  # Type of loss to compute

    def forward(self, output, target):
        weights1, biases1 = output
        weights2, biases2 = target

        # Calculate weight loss
        weight_loss = torch.mean(torch.stack(
            [self._calculate_loss(w1, w2) for w1, w2 in zip(weights1, weights2)]
        ))

        # Calculate bias loss
        bias_loss = torch.mean(torch.stack(
            [self._calculate_loss(b1, b2) for b1, b2 in zip(biases1, biases2)]
        ))

        # Total loss calculation
        total_loss = 0.5 * weight_loss + 0.5 * bias_loss
        return total_loss

    def _calculate_loss(self, w1, w2):
        # Calculate weight loss based on the specified loss type
        if self.loss_type == 'chamfer':
            return chamfer_loss(w1, w2)
        elif self.loss_type == 'sinkhorn':
            return sinkhorn_loss(w1, w2)
        elif self.loss_type == 'set_mse':
            return set_mse_loss(w1, w2)
        elif self.loss_type == 'hungarian':
            return hungarian_loss(w1, w2)
        elif self.loss_type == 'mse_loss_with_rotation':
            return mse_loss_with_rotation(w1,w2)
        else:
            raise ValueError("Invalid loss type specified")


In [19]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class TupleCosineLoss(nn.Module):
#     def __init__(self):
#         super(TupleCosineLoss, self).__init__()

#     def forward(self, output, target):
#         weights1, biases1 = output
#         weights2, biases2 = target

#         # Compute cosine similarity for weights
#         weight_loss = [1 - F.cosine_similarity(w1.reshape(1, -1), w2.reshape(1, -1), dim=1) for w1, w2 in zip(weights1, weights2)]
#         weight_loss = torch.mean(torch.stack(weight_loss))
        
#         # Compute cosine similarity for biases
#         bias_loss = [1 - F.cosine_similarity(b1.reshape(1, -1), b2.reshape(1, -1), dim=1) for b1, b2 in zip(biases1, biases2)]
#         bias_loss = torch.mean(torch.stack(bias_loss))
                
#         # Combine the losses
#         total_loss = 0.5 * weight_loss + 0.5 * bias_loss
        
#         return total_loss

In [20]:
# import matplotlib.pyplot as plt
# from sklearn.manifold import TSNE

# def plot_tsne(inputs, out):
#     weights_out, biases_out = out[0], out[1]
#     weights_inputs, biases_inputs = inputs[0], inputs[1]

#     # Function to flatten and pad data to ensure consistent dimensions
#     def flatten_and_pad(data):
#         max_length = max(len(d.cpu().detach().flatten().numpy()) for d in data)
#         return [np.pad(d.cpu().detach().flatten().numpy(), (0, max_length - len(d.cpu().detach().flatten().numpy())), 'constant') for d in data]

#     # Flattening and padding weights and biases for outputs and inputs
#     flattened_weights_out = flatten_and_pad(weights_out)
#     flattened_biases_out = flatten_and_pad(biases_out)
#     flattened_weights_inputs = flatten_and_pad(weights_inputs)
#     flattened_biases_inputs = flatten_and_pad(biases_inputs)

#     # Combining weights and biases
#     combined_data_out = [np.concatenate([w, b]) for w, b in zip(flattened_weights_out, flattened_biases_out)]
#     combined_data_inputs = [np.concatenate([w, b]) for w, b in zip(flattened_weights_inputs, flattened_biases_inputs)]

#     # Ensure all combined data has the same length
#     max_length = max(len(d) for d in combined_data_out + combined_data_inputs)
#     combined_data_out = [np.pad(d, (0, max_length - len(d)), 'constant') for d in combined_data_out]
#     combined_data_inputs = [np.pad(d, (0, max_length - len(d)), 'constant') for d in combined_data_inputs]

#     # Combine both inputs and outputs for t-SNE
#     combined_data = np.vstack([combined_data_inputs, combined_data_out])

#     # Applying t-SNE
#     tsne = TSNE(n_components=2, random_state=0)
#     tsne_results = tsne.fit_transform(combined_data)

#     # Split results back into inputs and outputs
#     tsne_results_inputs = tsne_results[:len(combined_data_inputs)]
#     tsne_results_out = tsne_results[len(combined_data_inputs):]

#     # Plotting
#     plt.figure(figsize=(10, 6))
    
#     # Plot t-SNE results for outputs
#     plt.scatter(tsne_results_out[:, 0], tsne_results_out[:, 1], label='Outputs', marker='o')
    
#     # Plot t-SNE results for inputs
#     plt.scatter(tsne_results_inputs[:, 0], tsne_results_inputs[:, 1], label='Inputs', marker='x')

#     # Adding titles and labels
#     plt.title('t-SNE of Combined Weights and Biases')
#     plt.xlabel('t-SNE Dimension 1')
#     plt.ylabel('t-SNE Dimension 2')
#     plt.legend()
#     plt.show()

In [21]:
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    loss = 0.0
    total = 0.0
    criterion = TupleLoss(loss_type='mse_loss_with_rotation')
    for batch in loader:
        batch = batch.to(device)
        inputs = (batch.weights, batch.biases)
        _,out = model(inputs)
        loss += criterion(out, inputs)
        total += 1

    model.train()
    avg_loss = loss / total

    return avg_loss

In [22]:
import logging
import torch
from tqdm import trange
    
def train_model(model):
    learning_rate = 1e-3
    num_epochs = 500
    criterion = TupleLoss(loss_type='mse_loss_with_rotation') 
    epoch_iter = trange(num_epochs)
    epoch_loss = -1
    
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.01, patience=5)
    previous_epoch_loss = 1000

    for epoch in epoch_iter:
        total_loss = 0
        counter = 0
        
        for i, batch in enumerate(train_loader):
            model.train() 
            optimizer.zero_grad()

            batch = batch.to(device)
            inputs = (batch.weights, batch.biases)
            _,out = model(inputs)

            loss = criterion(out, inputs)
            
            loss.backward()
            optimizer.step() 

            total_loss += loss.item()
            counter += 1

            epoch_iter.set_description(
                f"[{epoch} {i+1}], train loss: {loss.item():.5f}, epoch loss: {epoch_loss:.5f}"
            )
            
            
        epoch_loss = total_loss / counter

        scheduler.step(epoch_loss)
            
        if epoch_loss<previous_epoch_loss:
            model_path = f"Outputs/model_epoch_withSkip_mse_loss_with_rotation.pth"
            torch.save(model.state_dict(), model_path)
            previous_epoch_loss = epoch_loss
            
        if (epoch+1)%25 == 0:
             print(evaluate(model, test_loader))

    print("Training complete!")

In [26]:
# #Model with DWS DownSample Layers
weight_shapes = tuple(w.shape[:2] for w in point.weights)
bias_shapes = tuple(b.shape[:1] for b in point.biases)

model = AutoEncoder(
    input_features=1,
    weight_shapes = weight_shapes, 
    bias_shapes = bias_shapes,
    downsample_dim = 8,
    hidden_dims=32,
    reduction = "max",
    n_hidden=4,
    bn=True,
).to(device)


# for name, param in model.named_parameters():
#     if param.dim() >= 2:  # Initialize weights (only for tensors with 2 or more dimensions)
#         # Initialize convolutional layer weights using Kaiming normal initialization
#         if 'conv' in name and 'weight' in name:
#             nn.init.kaiming_normal_(param)
#         # Initialize linear layer weights using Xavier uniform initialization
#         elif 'fc' in name and 'weight' in name:
#             nn.init.xavier_uniform_(param)

model.load_state_dict(torch.load('Outputs/model_epoch_withSkip_mse_loss_with_rotation.pth'))
train_model(model)
print(evaluate(model, train_loader))

  model.load_state_dict(torch.load('Outputs/model_epoch_withSkip_mse_loss_with_rotation.pth'))
  state_dict = torch.load(path, map_location=lambda storage, loc: storage)
  state_dict = torch.load(path, map_location=lambda storage, loc: storage)
  state_dict = torch.load(path, map_location=lambda storage, loc: storage)
  state_dict = torch.load(path, map_location=lambda storage, loc: storage)
  state_dict = torch.load(path, map_location=lambda storage, loc: storage)
  state_dict = torch.load(path, map_location=lambda storage, loc: storage)
  state_dict = torch.load(path, map_location=lambda storage, loc: storage)
  state_dict = torch.load(path, map_location=lambda storage, loc: storage)
  state_dict = torch.load(path, map_location=lambda storage, loc: storage)
  state_dict = torch.load(path, map_location=lambda storage, loc: storage)
  state_dict = torch.load(path, map_location=lambda storage, loc: storage)
  state_dict = torch.load(path, map_location=lambda storage, loc: storage)
  sta

KeyboardInterrupt: 

In [28]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

def knn_classifier(embeddings_train, labels_train, embeddings_test, labels_test, k=5):
    # Convert to CPU and numpy for sklearn
    embeddings_train_np = embeddings_train.cpu().numpy()
    labels_train_np = labels_train.cpu().numpy()
    embeddings_test_np = embeddings_test.cpu().numpy()
    labels_test_np = labels_test.cpu().numpy()

    # Initialize the KNN classifier
    knn = KNeighborsClassifier(n_neighbors=k)
    
    # Train the KNN classifier
    knn.fit(embeddings_train_np, labels_train_np)
    
    # Predict on the test embeddings
    predictions = knn.predict(embeddings_test_np)
    
    # Calculate accuracy
    accuracy = accuracy_score(labels_test_np, predictions)
    print(f"KNN Accuracy: {accuracy * 100:.2f}%")
    
    return accuracy

model.eval()
with torch.no_grad():
    embeddings_train, labels_train = [], []
    embeddings_test, labels_test = [], []

    # Collect embeddings and labels for train and test sets
    for batch in train_loader:
        batch = batch.to(device)
        inputs = (batch.weights, batch.biases)
        labels = batch.label
        
        # Get embeddings as a tuple of tuples: ((w1, w2, w3), biases)
        embeddings, _ = model(inputs)
        weight_embeddings, biases_embeddings = embeddings
        
        # Unpack the weight embeddings tuple (w1, w2, w3)
        w1, w2, w3 = weight_embeddings
        
        # Flatten each embedding (w1, w2, w3) and the biases
        w1_flat = w1.view(w1.size(0), -1)
        w2_flat = w2.view(w2.size(0), -1)
        w3_flat = w3.view(w3.size(0), -1)
        
        # Unpack the weight embeddings tuple (w1, w2, w3)
        b1, b2, b3 = biases_embeddings
        
        # Flatten each embedding (w1, w2, w3) and the biases
        b1_flat = b1.view(b1.size(0), -1)
        b2_flat = b2.view(b2.size(0), -1)
        b3_flat = b3.view(b3.size(0), -1)
        
        # Concatenate all flattened embeddings into a single vector per sample
        flattened_embeddings = torch.cat([w1_flat, w2_flat, w3_flat, b1_flat, b2_flat, b3_flat], dim=1)
        
        embeddings_train.append(flattened_embeddings)
        labels_train.append(labels)

    for batch in test_loader:
        batch = batch.to(device)
        inputs = (batch.weights, batch.biases)
        labels = batch.label
        
        # Get embeddings as a tuple of tuples: ((w1, w2, w3), biases)
        embeddings, _ = model(inputs)
        weight_embeddings, biases_embeddings = embeddings
        
        # Unpack the weight embeddings tuple (w1, w2, w3)
        w1, w2, w3 = weight_embeddings
        
        # Flatten each embedding (w1, w2, w3) and the biases
        w1_flat = w1.view(w1.size(0), -1)
        w2_flat = w2.view(w2.size(0), -1)
        w3_flat = w3.view(w3.size(0), -1)
        
        # Unpack the weight embeddings tuple (w1, w2, w3)
        b1, b2, b3 = biases_embeddings
        
        # Flatten each embedding (w1, w2, w3) and the biases
        b1_flat = b1.view(b1.size(0), -1)
        b2_flat = b2.view(b2.size(0), -1)
        b3_flat = b3.view(b3.size(0), -1)
        
        # Concatenate all flattened embeddings into a single vector per sample
        flattened_embeddings = torch.cat([w1_flat, w2_flat, w3_flat, b1_flat, b2_flat, b3_flat], dim=1)
        
        embeddings_test.append(flattened_embeddings)
        labels_test.append(labels)

    # Concatenate all the batches together
    embeddings_train = torch.cat(embeddings_train, dim=0)
    labels_train = torch.cat(labels_train, dim=0)
    embeddings_test = torch.cat(embeddings_test, dim=0)
    labels_test = torch.cat(labels_test, dim=0)

    # Call the KNN classifier
    knn_classifier(embeddings_train, labels_train, embeddings_test, labels_test)

  state_dict = torch.load(path, map_location=lambda storage, loc: storage)
  state_dict = torch.load(path, map_location=lambda storage, loc: storage)


KNN Accuracy: 10.09%


In [None]:
from nn.models import DWSModelForClassification


classfication_model = DWSModelForClassification(
    weight_shapes=weight_shapes,
    bias_shapes=bias_shapes,
    input_features=1,
    hidden_dim=32,
    n_hidden=4,
    bn=True,
).to(device)

classfication_model.load_state_dict(torch.load('Outputs/model_dws_classification.pth'))
epochs = 25 
@torch.no_grad()
def evaluate(classfication_model, loader):
    classfication_model.eval()
    loss = 0.0
    correct = 0.0
    total = 0.0
    predicted, gt = [], []
    for batch in loader:
        batch = batch.to(device)
        inputs = (batch.weights, batch.biases)
        out = classfication_model(model(inputs))
        loss += F.cross_entropy(out, batch.label, reduction="sum")
        total += len(batch.label)
        pred = out.argmax(1)
        correct += pred.eq(batch.label).sum()
        predicted.extend(pred.cpu().numpy().tolist())
        gt.extend(batch.label.cpu().numpy().tolist())

    classfication_model.train()
    avg_loss = loss / total
    avg_acc = correct / total

    return dict(avg_loss=avg_loss, avg_acc=avg_acc, predicted=predicted, gt=gt)

# optimizer = torch.optim.AdamW(params=classfication_model.parameters(), lr=1e-3, amsgrad=True, weight_decay=5e-4)

# epoch_iter = trange(epochs)
# criterion = nn.CrossEntropyLoss()

# for epoch in epoch_iter:
#     for i, batch in enumerate(train_loader):
#         classfication_model.train()
#         optimizer.zero_grad()

#         batch = batch.to(device)
#         inputs = (batch.weights, batch.biases)
#         out = classfication_model(model(inputs))

#         loss = criterion(out, batch.label)
#         loss.backward()
#         optimizer.step()

#         epoch_iter.set_description(
#             f"[{epoch} {i+1}], train loss: {loss.item():.3f}"
#         )
        
#     model_path = f"Outputs/model_dws_classification_with_autoencoder.pth"
#     torch.save(classfication_model.state_dict(), model_path)
#     test_loss_dict = evaluate(classfication_model, test_loader)
#     test_acc = test_loss_dict['avg_acc'].item()
#     print("accuracy for DWS Model for classification", test_acc)
    
test_loss_dict = evaluate(classfication_model, test_loader)
test_acc = test_loss_dict['avg_acc'].item()
print("accuracy for new Autoencoder", test_acc)