In [None]:
import torch
import torch.nn as nn
import math
from torch.nn.parameter import Parameter
import torch.nn.functional as F

# from torch.nn.modules.conv import _ConvNd
from typing import Union, Optional, Tuple, Union, List
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from torch.nn.modules.utils import _single, _pair, _triple, _reverse_repeat_tuple
import torch.nn.init as init

from architectures import SubspaceConv2d, SubspaceLinear

"""
Intrinsic dims of MNIST - Fully-connected and CNN
"""

from timeit import default_timer as timer

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from tqdm import trange

# from data import load_mnist
# from net import SubspaceConv2d, SubspaceLinear


import torch
import torchvision
from torch.utils.data import DataLoader


class SubspaceConstrainedLeNet(nn.Module):
    def __init__(self, intrinsic_dim: int, device="cpu"):
        """
        Subspace constrained version of PyImageSearch's LeNet implementation
        """
        super().__init__()
        self.theta = Parameter(torch.empty((intrinsic_dim, 1), device=device))
        self.theta.data.fill_(0)

        self.conv1 = SubspaceConv2d(
            self.theta,
            in_channels=1,
            out_channels=20,
            kernel_size=(5, 5),
            stride=1,
            device=device,
        )
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.conv2 = SubspaceConv2d(
            self.theta,
            in_channels=20,
            out_channels=50,
            kernel_size=(5, 5),
            stride=1,
            device=device,
        )
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.flatten1 = nn.Flatten()

        self.fc1 = SubspaceLinear(
            self.theta, in_features=800, out_features=500, device=device
        )
        self.relu3 = nn.ReLU()

        self.fc2 = SubspaceLinear(
            self.theta, in_features=500, out_features=10, device=device
        )
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x = self.flatten1(x)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        x = self.logsoftmax(x)

        return x



In [8]:
import torch
import torch.nn as nn

import torch
import torch.nn as nn

def replace_layers(model, intrinsic_dim, device, theta=None):
    # Initialize theta if not provided
    if theta is None:
        theta = nn.Parameter(torch.empty((intrinsic_dim, 1), device=device))
        theta.data.fill_(0)
        print('theta', theta.shape)
        # Register theta as a parameter of the model
        # model.register_parameter('theta', theta)

    for name, module in model._modules.items():
        print(name)
        if isinstance(module, nn.Conv2d):
            # Get parameters from the existing Conv2d layer
            in_channels = module.in_channels
            out_channels = module.out_channels
            kernel_size = module.kernel_size
            stride = module.stride
            padding = module.padding
            dilation = module.dilation
            groups = module.groups
            bias = module.bias is not None
            padding_mode = module.padding_mode

            # Create a new SubspaceConv2d with the same parameters
            new_module = SubspaceConv2d(
                theta=theta,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
                padding_mode=padding_mode,
                device=device
            )
            # Replace the module in the parent module's _modules
            model._modules[name] = new_module

        elif isinstance(module, nn.Linear):
            # Get parameters from the existing Linear layer
            in_features = module.in_features
            out_features = module.out_features
            bias = module.bias is not None

            # Create a new SubspaceLinear with the same parameters
            new_module = SubspaceLinear(
                theta=theta,
                in_features=in_features,
                out_features=out_features,
                bias=bias,
                device=device
            )
            # Replace the module
            model._modules[name] = new_module

        elif isinstance(module, nn.Module):
            # Recursively apply to child modules
            replace_layers(module, intrinsic_dim, device, theta=theta)


from utils import get_args
from architectures import load_architecture

args = get_args()

args.dataset = 'CIFAR10'
args.selection_method = 'random'
args.aug = 'aug'
args.loss_function = 'CLASSIC_AT'

args.iterations = 10
args.pruning_ratio = 0
args.delta = 1
args.batch_size = 24
args.init_lr = 0.001
args.freeze_epochs = 5
args.backbone = 'convnext_tiny' #deit_small_patch16_224.fb_in1k
args.ft_type = 'full_fine_tuning'

model = load_architecture(args, N=10, rank=0 )


replace_layers(model, 10, 'cuda')

./data
theta torch.Size([10, 1])
stem
0
1
stages
0
downsample
blocks
0
conv_dw
norm
mlp
fc1
act
drop1
norm
fc2
drop2
shortcut
drop_path
1
conv_dw


OutOfMemoryError: CUDA out of memory. Tried to allocate 18.00 MiB. GPU 0 has a total capacity of 2.94 GiB of which 10.12 MiB is free. Including non-PyTorch memory, this process has 2.92 GiB memory in use. Of the allocated memory 2.78 GiB is allocated by PyTorch, and 30.44 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

./data


In [4]:
model

ConvNeXt(
  (stem): Sequential(
    (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
  )
  (stages): Sequential(
    (0): ConvNeXtStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): ConvNeXtBlock(
          (conv_dw): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (norm): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=96, out_features=384, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (norm): Identity()
            (fc2): Linear(in_features=384, out_features=96, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
          (shortcut): Identity()
          (drop_path): Identity()
        )
        (1): ConvNeXtBlock(
          (conv_dw): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)


In [None]:


def load_mnist(flatten=True):
    if flatten is True:
        dataset_transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Lambda(lambda x: torch.flatten(x)),
            ]
        )
    else:
        dataset_transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )

    train = torchvision.datasets.MNIST(
        root="~/.torchdata/",
        download=True,
        # natively stored as PIL images
        transform=dataset_transform,
    )

    test = torchvision.datasets.MNIST(
        root="~/.torchdata/", download=True, train=False, transform=dataset_transform
    )

    train_loader = DataLoader(train, batch_size=100, shuffle=True)
    # If flatten
    # Returns (torch.Size([100, 784]), torch.Size([100]))
    # Else
    # Returns (torch.Size([100, 1, 28, 28]), torch.Size([100]))

    test_loader = DataLoader(test, batch_size=500, shuffle=False)

    return train_loader, test_loader



## Util functions
from tqdm import tqdm

def train(net, num_epochs, train_loader, device="cuda"):
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)
    net.train()
    loss_history = []
    acc_history = []

    # Single progress bar over all epochs
    pbar = trange(
        len(train_loader) * num_epochs,
        bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}",
        ascii=True,
    )

    for _ in range(num_epochs):
        for batch_id, (features, target) in tqdm(enumerate(train_loader)):
            # forward pass, calculate loss and backprop!
            opt.zero_grad()
            preds = net(features.to(device))
            loss = F.nll_loss(preds, target.to(device))
            loss.backward()
            loss_history.append(loss.item())
            opt.step()

            pbar.update()

    # Verified don't need to return the net
    return loss_history, acc_history


def eval(net, test_loader, device="cuda"):
    net.eval()
    test_loss = 0
    correct = 0

    for features, target in test_loader:
        output = net(features.to(device))
        test_loss += F.nll_loss(output, target.to(device)).item()
        pred = torch.argmax(output, dim=-1)  # get the index of the max log-probability
        correct += pred.eq(target.to(device)).cpu().sum()

    test_loss = test_loss
    test_loss /= len(test_loader)  # loss function already averages over batch size
    accuracy = 100.0 * correct / len(test_loader.dataset)
    acc_history.append(accuracy)
    print(
        "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), accuracy
        )
    )

    return test_loss, correct.item()

In [2]:
import torch
from datasets import IndexedDataset, WeightedDataset, load_data
from torch.utils.data import DataLoader, DistributedSampler

from utils import get_args
from architectures import load_architecture

from tqdm.notebook import tqdm
from architectures import CustomModel, load_architecture, add_lora, set_lora_gradients #load_statedict

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

args = get_args()

args.dataset = 'CIFAR10'
args.selection_method = 'random'
args.aug = 'aug'
args.loss_function = 'CLASSIC_AT'

args.iterations = 10
args.pruning_ratio = 0
args.delta = 1
args.batch_size = 24
args.init_lr = 0.001
args.freeze_epochs = 5
args.backbone = 'convnext_tiny' #deit_small_patch16_224.fb_in1k
args.ft_type = 'full_fine_tuning'



im_train_loader, im_test_loader = load_mnist(flatten=False)
flat_train_loader, flat_test_loader = load_mnist(flatten=True)


# ## Data

# im_train_loader, im_test_loader = load_mnist(flatten=False)
# flat_train_loader, flat_test_loader = load_mnist(flatten=True)

## Config


dims = [10, ] #30, 50, 100, 300, 500, 1000, 3000, 5000

num_reps = 1
device = "cuda" if torch.cuda.is_available() else "cpu"
training_epochs = 5


## Intrinsic dim for fully-connected network on MNIST

fc_corrects = {}

for count, d in enumerate(dims):
    start_ts = timer()
    print(f"Training {num_reps} repetitions for intrinsic dimension: {d}")

    corrects_per_dim = {}

    for i in range(num_reps):
        print(i)
        ssnet = SubspaceConstrainedLeNet(intrinsic_dim=d, device=device)

        loss_history, acc_history = train(ssnet, 20, im_train_loader, device =device)
        test_loss, correct = eval(ssnet, im_test_loader, device=device)

        corrects_per_dim[i] = correct / 10000 * 100

    fc_corrects[d] = corrects_per_dim

    end_ts = timer()
    print(
        f"Time taken for dim {d}: {end_ts - start_ts:.2f}s. Remaining dims: {dims[count+1:]}"
    )


./data
Training 1 repetitions for intrinsic dimension: 10
0


  0%|          | 0/12000 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  5%|4         | 595/12000 [00:09<02:55, 64.83it/s]

0it [00:00, ?it/s]

 10%|9         | 1197/12000 [00:18<02:41, 66.90it/s]

0it [00:00, ?it/s]

 15%|#4        | 1798/12000 [00:27<02:38, 64.33it/s]

0it [00:00, ?it/s]

 20%|##        | 2400/12000 [00:36<02:27, 65.15it/s]

0it [00:00, ?it/s]

 25%|##4       | 2994/12000 [00:46<02:15, 66.33it/s]

0it [00:00, ?it/s]

 30%|##9       | 3596/12000 [00:55<02:06, 66.34it/s]

0it [00:00, ?it/s]

 35%|###4      | 4198/12000 [01:04<02:01, 63.96it/s]

0it [00:00, ?it/s]

 40%|###9      | 4799/12000 [01:14<01:49, 65.78it/s]

0it [00:00, ?it/s]

 45%|####5     | 5400/12000 [01:23<01:40, 65.67it/s]

0it [00:00, ?it/s]

 50%|####9     | 5995/12000 [01:32<01:33, 64.45it/s]

0it [00:00, ?it/s]

 55%|#####4    | 6595/12000 [01:42<01:28, 61.40it/s]

0it [00:00, ?it/s]

 60%|#####9    | 7196/12000 [01:52<01:25, 56.11it/s]

0it [00:00, ?it/s]

 65%|######4   | 7796/12000 [02:01<01:04, 65.55it/s]

0it [00:00, ?it/s]

 70%|######9   | 8398/12000 [02:10<00:54, 65.51it/s]

0it [00:00, ?it/s]

 75%|#######5  | 9000/12000 [02:19<00:43, 68.63it/s]

0it [00:00, ?it/s]

 80%|#######9  | 9595/12000 [02:28<00:37, 63.51it/s]

0it [00:00, ?it/s]

 85%|########4 | 10199/12000 [02:38<00:28, 64.04it/s]

0it [00:00, ?it/s]

 90%|########9 | 10794/12000 [02:47<00:18, 66.07it/s]

0it [00:00, ?it/s]

 95%|#########4| 11396/12000 [02:56<00:09, 61.13it/s]

0it [00:00, ?it/s]

100%|##########| 12000/12000 [03:06<00:00, 64.36it/s]


Test set: Average loss: 2.2977, Accuracy: 1137/10000 (11%)

Time taken for dim 10: 187.64s. Remaining dims: []


In [None]:

# Save results
df = pd.DataFrame(fc_corrects)
df.to_csv("fc-mnist-accuracy.csv")

tidydf = df.melt()
tidydf = tidydf.rename(columns={"variable": "num_id", "value": "acc"})

# Generate and save plot
fig = plt.figure(figsize=(14, 6))
sns.boxplot(data=tidydf, x="num_id", y="acc")
plt.xlabel("Intrinsic dimension")
plt.ylabel("Accuracy")
plt.title(
    "Accuracy of fully-connected network on MNIST, by constrained intrinsic dimension"
)
plt.savefig("fc-results.PNG", bbox_inches="tight")
plt.close(fig)

## Intrinsic dim for convolutional network on MNIST

conv_corrects = {}

for count, d in enumerate(dims):
    start_ts = timer()
    print(f"Training {num_reps} repetitions for intrinsic dimension: {d}")

    corrects_per_dim = {}

    for i in range(num_reps):
        ssnet = SubspaceConstrainedLeNet(intrinsic_dim=d, device=device)

        loss_history, acc_history = train(ssnet, 20, im_train_loader, device=device)
        test_loss, correct = eval(ssnet, im_test_loader, device=device)

        corrects_per_dim[i] = correct / 10000 * 100

    conv_corrects[d] = corrects_per_dim

    end_ts = timer()
    print(
        f"Time taken for dim {d}: {end_ts - start_ts:.2f}s. Remaining dims: {dims[count+1:]}!"
    )

# Save results
df = pd.DataFrame(conv_corrects)
df.to_csv("conv-mnist-accuracy.csv")

tidydf = df.melt()
tidydf = tidydf.rename(columns={"variable": "num_id", "value": "acc"})

# Generate and save plot
fig = plt.figure(figsize=(14, 6))
sns.boxplot(data=tidydf, x="num_id", y="acc")
plt.xlabel("Intrinsic dimension")
plt.ylabel("Accuracy")
plt.title("Accuracy of conv network on MNIST, by constrained intrinsic dimension")
plt.savefig("conv-results.PNG", bbox_inches="tight")
plt.close(fig)