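`PROJECT_NAME` (next cell) is used both as the Wandb project name and as the switch that decides what Colab's 'Run all' command executes. The cells below look for these keywords (case-insensitive) inside the name:
- CIFAR or MNIST: selects the CIFAR10 or the MNIST dataset (loaded using code)
- FC or CNN: selects the MLP or the CNN model (created through code)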

In [None]:
# Experiment / W&B project name. Later cells pick the dataset and the model by
# case-insensitive substring match on this value: "mnist" / "cifar" choose the
# data-loading cell, "fc" / "cnn" choose which training cell runs.
PROJECT_NAME = 'CIFAR-FC-SparseVsDense'

## Setup

Installs PyTorch Lightning and Wandb on Colab.

In [None]:
!pip install pytorch-lightning wandb

In [None]:
# @title Import dependencies

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
import torch.optim as optim
import math

from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

In [None]:
# @title Reproducibility stuff

import random

# Seed every random number generator the notebook draws from.
# Note: Python's `random` module is seeded with 0, NumPy and PyTorch with 42.
random.seed(0)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Disable cuDNN autotuning and request deterministic kernels
# (deterministic mode can have a performance impact).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
 # @title Setup Wandb

import wandb
from pytorch_lightning.loggers import WandbLogger

# Authenticate this session with Weights & Biases.
wandb.login()

# The custom widget manager only exists on Google Colab; skip it elsewhere so
# this cell also runs in a plain Jupyter kernel.
try:
    from google.colab import output
except ImportError:
    print("google.colab not available: skipping custom widget manager setup")
else:
    output.enable_custom_widget_manager()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33madrianrob[0m ([33msapienza-ml[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
print(f'Using device: {DEVICE}')

Using device: cuda:0


In [None]:
# Lightning logger bound to the W&B project named by PROJECT_NAME. The same
# instance is used by LogPredictionsCallback and by both training cells.
wandb_logger = WandbLogger(project=PROJECT_NAME)

### Base Model

In [None]:
class BaseModel(pl.LightningModule):
    """Shared Lightning train/validation/test logic for the classifiers in this notebook.

    Subclasses must set the following attributes in their own ``__init__``:

    * ``self.model``: the network invoked by :meth:`forward`;
    * ``self.task`` and ``self.output_size``: passed to
      ``torchmetrics.functional.accuracy`` as ``task`` and ``num_classes``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        res = self.model(x)
        return res

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        preds, loss, acc, ts = self._get_preds_metrics(batch)
        # logging
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        self.log("train_time_step", ts)
        return loss

    def validation_step(self, batch, batch_idx):
        # validation_step defines the validation loop
        preds, loss, acc, ts = self._get_preds_metrics(batch)
        self.log("val_loss", loss)
        self.log("val_acc", acc)
        self.log("val_time_step", ts)
        # Predicted class indices are returned (not the loss) so that callbacks
        # can caption sample images with them.
        return preds

    def test_step(self, batch, batch_idx):
        # test defines the test loop
        preds, loss, acc, ts = self._get_preds_metrics(batch)
        self.log("test_loss", loss)
        self.log("test_acc", acc)
        self.log("test_time_step", ts)
        return loss

    def _get_preds_metrics(self, batch):
        """Run one forward pass on ``batch`` and compute its metrics.

        Args:
            batch: an ``(inputs, targets)`` pair.

        Returns:
            ``(preds, loss, acc, time_step)`` where ``preds`` is the argmax of the
            model output over dim 1, ``loss`` is the cross-entropy, ``acc`` is the
            torchmetrics accuracy and ``time_step`` is the time in milliseconds
            spent moving the batch to ``DEVICE`` and running the forward pass.
        """
        # time step (ms) calculation: https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964/6
        import time
        from torchmetrics.functional import accuracy

        # CUDA events can only be recorded when a CUDA device is in use; on CPU
        # (DEVICE falls back to it when CUDA is unavailable) use a wall clock.
        use_cuda_timer = DEVICE.type == "cuda"
        if use_cuda_timer:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
        else:
            wall_start = time.perf_counter()

        x, y = batch
        xd, yd = x.to(DEVICE), y.to(DEVICE)

        preds = self(xd)

        if use_cuda_timer:
            end.record()
            # Waits for everything to finish running
            torch.cuda.synchronize()
            time_step = start.elapsed_time(end)
        else:
            time_step = (time.perf_counter() - wall_start) * 1000.0

        loss = F.cross_entropy(preds, yd)
        acc = accuracy(preds, yd, self.task, num_classes=self.output_size)
        preds = torch.argmax(preds, dim=1)

        return preds, loss, acc, time_step

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), 1e-3)
        return optimizer

## MNIST Data Setup

In [None]:
if "mnist" in PROJECT_NAME.lower():
  print("loading mnist dataset")

  batch_size = 64
  # No class-name lookup for MNIST: the logging callback then captions images
  # with the raw label values.
  classes = None
  # Defined here but not passed to the DataLoaders below.
  kwargs = {'num_workers': 1, 'pin_memory': True}

  # Both splits get the same preprocessing: tensor conversion followed by
  # normalisation with mean 0.1307 and std 0.3081.
  train_set = MNIST(
      'data',
      train=True,
      download=True,
      transform=transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize((0.1307,), (0.3081,)),
      ]),
  )
  test_set = MNIST(
      'data',
      train=False,
      transform=transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize((0.1307,), (0.3081,)),
      ]),
  )

  # use 20% of training data for validation
  train_set_size = int(len(train_set) * 0.8)
  valid_set_size = len(train_set) - train_set_size
  train_set, valid_set = random_split(train_set, [train_set_size, valid_set_size])

  # From here on the *_set names refer to DataLoaders; only training data is shuffled.
  train_set = DataLoader(train_set, batch_size=batch_size, shuffle=True)
  val_set = DataLoader(valid_set, batch_size=batch_size, shuffle=False)
  test_set = DataLoader(test_set, batch_size=batch_size, shuffle=False)

## CIFAR10 Data Setup

In [None]:
if "cifar" in PROJECT_NAME.lower():
  print("loading cifar dataset")
  # Source: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

  batch_size = 64
  # Human-readable names used by the logging callback to caption sample images.
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
             'dog', 'frog', 'horse', 'ship', 'truck')

  # Images are resized to 28x28 before tensor conversion and per-channel
  # normalisation with mean 0.5 and std 0.5.
  transform = transforms.Compose([
      transforms.Resize((28, 28)),
      transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
  ])

  train_set = CIFAR10(root='./data', train=True, download=True, transform=transform)
  test_set = CIFAR10(root='./data', train=False, download=True, transform=transform)

  # use 20% of training data for validation
  train_set_size = int(len(train_set) * 0.8)
  valid_set_size = len(train_set) - train_set_size
  train_set, valid_set = random_split(train_set, [train_set_size, valid_set_size])

  # From here on the *_set names refer to DataLoaders; only training data is shuffled.
  train_set = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=1)
  val_set = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=1)
  test_set = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=1)

loading cifar dataset
Files already downloaded and verified
Files already downloaded and verified


## Callbacks

In [None]:
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import ModelCheckpoint

# Checkpointing driven by the logged "val_acc" metric (higher is better). This
# single instance is passed to the Trainer in both training cells.
checkpoint_callback = ModelCheckpoint(monitor='val_acc', mode='max')
 
class LogPredictionsCallback(Callback):
    '''
    Wandb logging of GPU memory usage (after every training batch) and of
    sample predictions (for the first batch of each validation loop).
    '''
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # The memory counters below only exist for CUDA devices.
        if not torch.cuda.is_available():
            return

        total = torch.cuda.get_device_properties(0).total_memory
        reserved = torch.cuda.memory_reserved(0)
        allocated = torch.cuda.memory_allocated(0)

        wandb_logger.log_metrics({'Train_Total_VRAM': total,
                'Train_Reserved_VRAM': reserved,
                'Train_Allocated_VRAM': allocated,
                # bytes reserved by the allocator but not currently allocated
                'Train_Free_VRAM': reserved - allocated})

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        # `dataloader_idx` has a default so the hook works whether or not the
        # installed Lightning version passes it for a single dataloader.
        if batch_idx != 0:
            return

        n = 20
        x, y = batch
        images = list(x[:n])

        # `outputs` is expected to hold the predicted class indices returned by
        # the module's validation_step.
        if classes is not None:
            captions = [f'Ground Truth: {classes[y_i]} - Prediction: {classes[y_pred]}' for y_i, y_pred in zip(y[:n], outputs[:n])]
        else:
            captions = [f'Ground Truth: {y_i} - Prediction: {y_pred}' for y_i, y_pred in zip(y[:n], outputs[:n])]

        # Option 1: log images with `WandbLogger.log_image`
        wandb_logger.log_image(key='sample_images', images=images, caption=captions)

## Random Projections

In [None]:
def generate_gaussian_rp(N, M, ortho=False):
  '''
  Builds an N x M Gaussian random projection matrix on DEVICE.

  @N: number of rows
  @M: number of columns
  @ortho: if true, the Q factor of a QR decomposition of the Gaussian matrix is
          returned; otherwise the Gaussian matrix is scaled by 1 / sqrt(M)
  '''
  gaussian = torch.randn(N, M, device=DEVICE)

  if not ortho:
    return gaussian / math.sqrt(M)

  # A Gaussian matrix is already approximately orthogonal, so this explicit
  # orthogonalization may not be necessary. Q is the orthogonal factor of QR.
  q_factor = torch.linalg.qr(gaussian)[0]
  return q_factor

def generate_sparse_rp(n_features, n_components, original=False, sparse_type=None, fullRange=True, variation=False):
  '''
  Generates an n_components x n_features matrix on DEVICE
  (sparse and very sparse random projections).

  Sparse Random Projection algorithm (Dimitris Achlioptas):
  start from an all-zero matrix, pick some (row, col) elements and assign them a
  non-zero value, then multiply by a constant which compensates for the fact
  that most entries are 0. The resulting matrix should approximate an orthogonal
  matrix for a smoother energy landscape (subspace).

  Very Sparse Random Projections (Li et al.): the density of the matrix can be
  reduced even more.

  @n_features: number of columns of the returned matrix
  @n_components: number of rows of the returned matrix
  @original: if true uses code from bit.ly/3ZOt9S0, otherwise we adjust by swapping d (subspace) and D (full space)
  @fullRange: if false the sampled values are either 0 or 1 before scaling, otherwise they are -1 or 1
  @variation: sparse2 type which scales the sampled values by sqrt(density) instead
  @sparse_type:
    - None: dense representation w/ dense operations
    - coo: sparse coo representation w/ sparse ops
    - csr: sparse csr representation w/ sparse ops
    - csc: sparse csc representation w/ sparse ops

  Column indices are drawn from one NumPy RandomState seeded with 42 that is
  shared by all rows, so the result is reproducible while every row still gets
  its own set of columns.
  '''

  from sklearn.utils.random import sample_without_replacement

  eps = 0.1
  denominator = (eps**2 / 2) - (eps**3 / 3)
  johnson_lindenstrauss_min_dim = (4 * np.log(n_components) / denominator).astype(np.int64)

  print("min dim should be: ", johnson_lindenstrauss_min_dim)

  density = 1 / np.sqrt(n_features) if original else 1 / np.sqrt(n_components)

  # Fair coin used to draw the value of every selected entry.
  value_dist = torch.distributions.Binomial(total_count=1, probs=0.5)

  if density == 1:
    # skip index generation if totally dense
    components = value_dist.sample((n_components, n_features))
    if fullRange:
      components = components * 2 - 1
    # Returned on DEVICE like every other branch so it can be multiplied with
    # parameters living there. Note: sparse_type is not applied in this branch.
    return (components * float(1 / np.sqrt(n_components))).to(DEVICE)

  # Constant applied to every sampled value.
  if variation:
    scale = float(np.sqrt(density)) # sparse2
  elif original:
    scale = float(np.sqrt(1 / density) / np.sqrt(n_components))
  else:
    scale = float(np.sqrt(1 / density) / np.sqrt(n_features))

  use_sparse = sparse_type is not None
  if use_sparse:
    # Per-row pieces of the coo representation, concatenated once after the loop.
    row_chunks, col_chunks, value_chunks = [], [], []
  else:
    # The dense matrix is only allocated when a dense result is requested.
    components = torch.zeros((n_components, n_features), dtype=torch.float, device=DEVICE)

  # One generator for the whole matrix: re-seeding for every row would make all
  # rows select (almost) the same columns.
  column_rng = np.random.RandomState(42)
  nnz_dist = torch.distributions.Binomial(total_count=n_features, probs=float(density))

  for i in range(n_components):
    # number of non-zero components for row i (as a plain int for scikit-learn)
    nnz = int(nnz_dist.sample().item())
    # get nnz distinct column indices
    c_idx = torch.tensor(
        sample_without_replacement(
            n_population=n_features, n_samples=nnz, random_state=column_rng
        ),
        dtype=torch.long,
        device=DEVICE
    )

    data = value_dist.sample(sample_shape=c_idx.size()).to(DEVICE) # values 0 or 1
    if fullRange:
      data = data * 2 - 1 # values -1 or 1

    if use_sparse:
      # for sparse representations we first get coo represent. then convert it to the other types
      row_chunks.append(torch.full((nnz,), i, dtype=torch.long, device=DEVICE))
      col_chunks.append(c_idx)
      # the scaling is applied to the values here, before the sparse tensor is built
      value_chunks.append(data * scale)
    else:
      # assign data only to those columns
      components[i, c_idx] = data.float()

  if use_sparse:
    idx = torch.stack([torch.cat(row_chunks), torch.cat(col_chunks)], dim=0)
    values = torch.cat(value_chunks)
    components = torch.sparse_coo_tensor(idx, values, size=(n_components, n_features), device=DEVICE) # sparse coo matrix
    if sparse_type == 'csc': # representation conversion
      components = components.to_sparse_csc()
    elif sparse_type == 'csr':
      components = components.to_sparse_csr()
  else:
    components *= scale

  return components

## FC Layer with Random Projection

In [None]:
class Linear_Projected(nn.Module):
  '''
  Fully connected layer trained in a random subspace.

  The flattened weights and bias (N = (input_size + 1) * output_size values) are
  computed as theta_0 + projection of theta_d, where theta_0 is a fixed random
  vector, theta_d is the only trainable parameter (d_units values) and P is a
  fixed random projection matrix selected by rp_gen_algorithm.

  NOTE(review): theta_0 and P are plain tensor attributes, not buffers or
  parameters, so they do not appear in state_dict().
  '''

  # rp_gen_algorithm name -> value of `ortho` passed to generate_gaussian_rp
  _GAUSSIAN_RP_ORTHO = {
      'gaussian': False,
      'ortho': True,
  }
  # rp_gen_algorithm name -> keyword arguments passed to generate_sparse_rp
  _SPARSE_RP_KWARGS = {
      'sparse': {},
      'sparse_coo': {'sparse_type': 'coo'},
      'sparse_csr': {'sparse_type': 'csr'},
      'sparse_csc': {'sparse_type': 'csc'},
      'sparse_original': {'original': True},
      'sparse2': {'variation': True, 'fullRange': False},
      'sparse2full': {'variation': True, 'fullRange': True},
      'sparse2_csc': {'variation': True, 'fullRange': False, 'sparse_type': 'csc'},
      'sparse2full_csc': {'variation': True, 'fullRange': True, 'sparse_type': 'csc'},
  }

  def __init__(self, input_size, output_size, d_units, rp_gen_algorithm='gaussian'):
    super().__init__()
    d = d_units
    self.input_size = input_size
    self.output_size = output_size
    self.rp_gen_algorithm = rp_gen_algorithm

    # Fixed starting point in the full parameter space (weights followed by bias).
    self.theta_0 = torch.randn((input_size + 1) * output_size, device=DEVICE) / math.sqrt(input_size)

    # Trainable coordinates in the d-dimensional subspace.
    self.theta_d = nn.Parameter(torch.randn(d, device=DEVICE) / math.sqrt(d))

    N = (input_size + 1) * output_size
    M = d

    print("No. original params: " + str(N))
    print("No. params used through projection: " + str(M))

    self.P = self._build_projection(N, M)

  def _build_projection(self, N, M):
    '''
    Creates the fixed projection matrix for self.rp_gen_algorithm.

    The result has shape (M, N) so that theta_d @ P has N entries, except for
    'sparse_coo' where it has shape (N, M) and forward computes P @ theta_d.

    Raises ValueError when the algorithm name is not supported.
    '''
    name = self.rp_gen_algorithm
    if name in self._GAUSSIAN_RP_ORTHO:
      return generate_gaussian_rp(N, M, ortho=self._GAUSSIAN_RP_ORTHO[name]).T
    if name in self._SPARSE_RP_KWARGS:
      projection = generate_sparse_rp(N, M, **self._SPARSE_RP_KWARGS[name])
      return projection.T if name == 'sparse_coo' else projection

    supported = list(self._GAUSSIAN_RP_ORTHO) + list(self._SPARSE_RP_KWARGS)
    raise ValueError("Unknown rp_gen_algorithm " + repr(name) + ". " +
                     "Supported random projections are: " + ", ".join(supported))

  def forward(self, xb):
    if(self.rp_gen_algorithm == 'sparse_coo'): # trick for mv multiplication pytorch support (vm mult doesn't work)
      temp = self.theta_0 + self.P @ self.theta_d
    else:
      temp = self.theta_0 + self.theta_d @ self.P
    # All entries except the last output_size ones are the weight matrix, the
    # last output_size entries are the bias.
    t = xb @ (temp[:-self.output_size]).reshape(self.input_size, self.output_size)
    res =  t + temp[-self.output_size:]
    return res

class Linear(nn.Module):
  '''Dense (non-projected) fully connected layer: a thin wrapper around nn.Linear.'''

  def __init__(self, input_size, output_size):
    super().__init__()
    self.linear = nn.Linear(in_features=input_size, out_features=output_size)
    self.input_size = input_size

  def forward(self, xb):
    out = self.linear(xb)
    return out

## Fully Connected Neural Networks

### Model

In [None]:
class ParamRegularizedMLP(BaseModel):
    """Multi-layer perceptron built from projected or dense linear layers.

    Args:
        input_size: number of features of the flattened input.
        output_size: number of classes.
        num_mid_units: width of every hidden layer.
        projected: use ``Linear_Projected`` layers when true, ``Linear`` otherwise.
        d_units: subspace sizes for the projected layers. With a single value,
            every layer uses it. With several values, the first layer uses
            ``d_units[0]``, hidden layer ``i`` (1-based) uses ``d_units[i]`` and
            the final layer uses ``d_units[-1]``. The list is only read.
        depth: number of linear layers; values below 2 build a single layer
            mapping ``input_size`` directly to ``output_size``.
        rp_gen_algorithm: random projection used by every projected layer.
    """

    def __init__(self, input_size, output_size, num_mid_units, projected=True, d_units=[750, ], depth=3, rp_gen_algorithm='sparse_csc'):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.activation = nn.ReLU()

        self.task="multiclass"

        # With a single value in d_units every layer shares it.
        shared_d = not (len(d_units) > 1)

        if(projected):
            if (depth < 2):
                # The single-layer model honours rp_gen_algorithm like the deeper ones.
                self.model = nn.Sequential(nn.Flatten(), Linear_Projected(input_size, output_size, d_units[0], rp_gen_algorithm))
            else:
                print("#layer_0")
                self.model = nn.Sequential(nn.Flatten(), Linear_Projected(input_size, num_mid_units, d_units[0], rp_gen_algorithm), self.activation)
                for i in range(depth-2):
                    print("#layer_" + str(i + 1))

                    self.model.append(Linear_Projected(num_mid_units, num_mid_units, d_units[0 if shared_d else i+1], rp_gen_algorithm))

                    self.model.append(self.activation)

                print("#layer_final")
                self.model.append(Linear_Projected(num_mid_units, output_size, d_units[0 if shared_d else -1], rp_gen_algorithm))
        else:
            if (depth < 2):
                self.model = nn.Sequential(nn.Flatten(), Linear(input_size, output_size))
            else:
                self.model = nn.Sequential(nn.Flatten(), Linear(input_size, num_mid_units), self.activation)
                for _ in range(depth-2):
                    self.model.append(Linear(num_mid_units, num_mid_units))
                    self.model.append(self.activation)

                self.model.append(Linear(num_mid_units, output_size))

        # Reached for every configuration, including depth < 2, so the
        # constructor arguments are always recorded.
        self.save_hyperparameters()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), 1e-3)
        return optimizer

### Run

In [None]:
if("fc" in PROJECT_NAME.lower()):
  # Input size is 784 (28 * 28) values per channel: 1 channel for MNIST, 3 otherwise.
  model = ParamRegularizedMLP(784 * (1 if "mnist" in PROJECT_NAME.lower() else 3), 10, 200, True, [2000, 1050], 2, rp_gen_algorithm='ortho')

  wandb_logger.watch(model, log="all")

  # The VRAM report and the GPU accelerator only make sense when CUDA is
  # present; otherwise train on the CPU, matching the DEVICE fallback.
  cuda_available = torch.cuda.is_available()

  if cuda_available:
    t = torch.cuda.get_device_properties(0).total_memory
    r = torch.cuda.memory_reserved(0)
    a = torch.cuda.memory_allocated(0)

    print('Total VRAM:', t)
    print('Reserved VRAM:', r)
    print('Allocated VRAM:', a)
    print('Free VRAM:', r-a)
  else:
    print('CUDA not available: skipping VRAM report')

  trainer = pl.Trainer(max_epochs=30, accelerator="gpu" if cuda_available else "cpu", logger=wandb_logger,
                      callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=5),
                                  LogPredictionsCallback(), checkpoint_callback])
  trainer.fit(model, train_dataloaders=train_set, val_dataloaders=val_set)

  wandb.finish()

#layer_0
No. original params: 470600
No. params used through projection: 2000
#layer_final
No. original params: 2010
No. params used through projection: 1050


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Total VRAM: 15843721216
Reserved VRAM: 7820279808
Allocated VRAM: 3775467520
Free VRAM: 4044812288


INFO:pytorch_lightning.callbacks.model_summary:
  | Name       | Type       | Params
------------------------------------------
0 | activation | ReLU       | 0     
1 | model      | Sequential | 3.1 K 
------------------------------------------
3.1 K     Trainable params
0         Non-trainable params
3.1 K     Total params
0.012     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


0,1
Train_Allocated_VRAM,▁███████████████████████████████████████
Train_Free_VRAM,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Train_Reserved_VRAM,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Train_Total_VRAM,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁
train_acc,▁
train_loss,▁
train_time_step,▁
trainer/global_step,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Train_Allocated_VRAM,3776123392.0
Train_Free_VRAM,13430272.0
Train_Reserved_VRAM,3789553664.0
Train_Total_VRAM,15843721216.0
epoch,0.0
train_acc,0.17188
train_loss,2.2121
train_time_step,15.2384
trainer/global_step,49.0


## Convolutional Neural Networks

### Model

In [None]:
class Conv2dProj(nn.Module):
  '''
  2D convolution trained in a random subspace.

  The flattened kernel (and bias, when enabled) is computed as theta_0 +
  projection of theta_d, where theta_0 is a fixed random vector, theta_d is the
  only trainable parameter (d_units values) and P is a fixed random projection
  matrix selected by rp_gen_algorithm.

  NOTE(review): theta_0 and P are plain tensor attributes, not buffers or
  parameters, so they do not appear in state_dict().
  '''

  # rp_gen_algorithm name -> value of `ortho` passed to generate_gaussian_rp
  _GAUSSIAN_RP_ORTHO = {
      'gaussian': False,
      'ortho': True,
  }
  # rp_gen_algorithm name -> keyword arguments passed to generate_sparse_rp
  _SPARSE_RP_KWARGS = {
      'sparse': {},
      'sparse_coo': {'sparse_type': 'coo'},
      'sparse_csr': {'sparse_type': 'csr'},
      'sparse_csc': {'sparse_type': 'csc'},
      'sparse_original': {'original': True},
      'sparse2': {'variation': True, 'fullRange': False},
      'sparse2full': {'variation': True, 'fullRange': True},
      'sparse2_csc': {'variation': True, 'fullRange': False, 'sparse_type': 'csc'},
      'sparse2full_csc': {'variation': True, 'fullRange': True, 'sparse_type': 'csc'},
  }

  def __init__(self, in_channels, kernel_size, out_channels, d_units, padding, stride, bias=True, rp_gen_algorithm='gaussian'):
    super(Conv2dProj, self).__init__()

    self.bias = bias
    # forward() reads this to choose the multiplication order, so it must be stored.
    self.rp_gen_algorithm = rp_gen_algorithm

    n_kernel_params = in_channels * kernel_size**2 * out_channels

    # Size of the full parameter vector: kernel entries plus one bias per output channel.
    if(bias):
      N = n_kernel_params + out_channels
    else:
      N = n_kernel_params

    d = d_units
    M = d

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.n_kernel_params = n_kernel_params

    # Fixed starting point in the full parameter space.
    self.theta_0 = torch.randn(N, device=DEVICE) / math.sqrt(N)

    # Trainable coordinates in the d-dimensional subspace.
    self.theta_d = nn.Parameter(torch.randn(d, device=DEVICE) / math.sqrt(d))

    print("No. original params: " + str(N))
    print("No. params used by projection: " + str(d))

    self.P = self._build_projection(N, M)

    self.padding = padding
    self.stride = stride

  def _build_projection(self, N, M):
    '''
    Creates the fixed projection matrix for self.rp_gen_algorithm.

    The result has shape (M, N) so that theta_d @ P has N entries, except for
    'sparse_coo' where it has shape (N, M) and forward computes P @ theta_d.

    Raises ValueError when the algorithm name is not supported.
    '''
    name = self.rp_gen_algorithm
    if name in self._GAUSSIAN_RP_ORTHO:
      return generate_gaussian_rp(N, M, self._GAUSSIAN_RP_ORTHO[name]).T
    if name in self._SPARSE_RP_KWARGS:
      projection = generate_sparse_rp(N, M, **self._SPARSE_RP_KWARGS[name])
      return projection.T if name == 'sparse_coo' else projection

    supported = list(self._GAUSSIAN_RP_ORTHO) + list(self._SPARSE_RP_KWARGS)
    raise ValueError("Unknown rp_gen_algorithm " + repr(name) + ". " +
                     "Supported random projections are: " + ", ".join(supported))

  def forward(self, xb):
    if (self.rp_gen_algorithm == 'sparse_coo'): # trick for mv multiplication pytorch support (vm mult doesn't work)
      kernel = self.theta_0 +  self.P @ self.theta_d
    else:
      kernel = self.theta_0 + self.theta_d @ self.P
    # The first n_kernel_params entries are the convolution kernel, the
    # remaining ones (if any) are the bias.
    kernel_params = kernel[:self.n_kernel_params].reshape(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)

    bias = kernel[self.n_kernel_params:] if self.bias else None

    res = torch.nn.functional.conv2d(xb, kernel_params, bias, stride=self.stride, padding=self.padding)
    return res

In [None]:
class LeNet(BaseModel):
    """LeNet-style CNN with either projected or dense layers.

    Args:
        in_channels: number of channels of the input images.
        output_size: number of classes.
        projected: use ``Conv2dProj`` / ``Linear_Projected`` layers when true,
            ``nn.Conv2d`` / ``nn.Linear`` otherwise.
        d_units: subspace sizes for the five projected layers (two
            convolutions, three linear layers). A single value is shared by
            all layers; otherwise the first five values are used in layer
            order. The list is only read.
        rp_gen_algorithm: random projection used by every projected layer.

    Raises:
        IndexError: if ``d_units`` is empty or has between 2 and 4 values.

    The first linear layer expects 400 features (16 channels of 5 x 5), which
    is what the two convolution + pooling stages produce for 28 x 28 inputs.
    """

    # Number of layers that take a subspace size when `projected` is true.
    _N_PROJECTED_LAYERS = 5

    def __init__(self, in_channels, output_size, projected=True, d_units=[4000, ], rp_gen_algorithm='gaussian'):
        super().__init__()

        self.activation = nn.ReLU()
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.task = "multiclass"
        self.output_size = output_size
        if(projected):
            # Resolve one subspace size per layer up front so a wrong-length
            # d_units fails before any layer is built.
            if len(d_units) > 1:
                if len(d_units) < self._N_PROJECTED_LAYERS:
                    raise IndexError("d_units must contain either 1 value or at least " +
                                     str(self._N_PROJECTED_LAYERS) + " values, got " + str(len(d_units)))
                layer_d = list(d_units[:self._N_PROJECTED_LAYERS])
            else:
                layer_d = [d_units[0]] * self._N_PROJECTED_LAYERS

            print("#layer_fe_0")
            self.model = nn.Sequential(Conv2dProj(in_channels, 5, 6, layer_d[0], padding=2, stride=1, rp_gen_algorithm=rp_gen_algorithm), self.activation, self.pool)

            print("#layer_fe_1")
            self.model.append(Conv2dProj(6, 5, 16, layer_d[1], rp_gen_algorithm=rp_gen_algorithm, padding=0, stride=1))
            self.model.append(self.activation)
            self.model.append(self.pool)
            self.model.append(nn.Flatten())

            print("#layer_classifier_0")
            self.model.append(Linear_Projected(400, 120, layer_d[2], rp_gen_algorithm=rp_gen_algorithm))
            self.model.append(self.activation)

            print("#layer_classifier_1")
            self.model.append(Linear_Projected(120, 84, layer_d[3], rp_gen_algorithm=rp_gen_algorithm))
            self.model.append(self.activation)

            print("#layer_classifier_final")
            self.model.append(Linear_Projected(84, output_size, layer_d[4], rp_gen_algorithm=rp_gen_algorithm))
        else:
            self.model = nn.Sequential(nn.Conv2d(in_channels, 6, 5, padding=2, stride=1), self.activation, self.pool)

            self.model.append(nn.Conv2d(6, 16, 5, padding=0, stride=1))
            self.model.append(self.activation)
            self.model.append(self.pool)
            self.model.append(nn.Flatten())

            self.model.append(nn.Linear(400, 120))
            self.model.append(self.activation)

            self.model.append(nn.Linear(120, 84))
            self.model.append(self.activation)

            self.model.append(nn.Linear(84, output_size))

        self.save_hyperparameters()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), 1e-3)
        #optimizer = torch.optim.SGD(self.parameters(), 1e-2, momentum=0.9)
        return optimizer

In [None]:
class CNN(BaseModel):
    """Small dense CNN: two strided convolutions followed by one linear classifier.

    The classifier expects 63 flattened features (7 channels of 3 x 3), which is
    what the convolution + pooling stages produce for 28 x 28 inputs.
    """

    def __init__(self, in_channels, output_size):
        super().__init__()

        self.activation = nn.ReLU()
        self.pool = nn.AvgPool2d(kernel_size=2, stride=1)
        self.output_size = output_size
        self.task = "multiclass"

        # The activation and pooling modules are shared between both stages.
        layers = [
            nn.Conv2d(in_channels, 12, 5, padding=2, stride=2),
            self.activation,
            self.pool,
            nn.Conv2d(12, 7, 7, padding=0, stride=2),
            self.activation,
            self.pool,
            nn.Flatten(),
            nn.Linear(63, output_size),
        ]
        self.model = nn.Sequential(*layers)

        # Total number of values across all parameters of the network.
        self.n_parameters = sum(param.numel() for param in self.model.parameters())
        print("Num. parameters: ", self.n_parameters)

        self.save_hyperparameters()

    def configure_optimizers(self):
        # Alternative tried previously: torch.optim.SGD(self.parameters(), 1e-2, momentum=0.9)
        return torch.optim.Adam(self.parameters(), 1e-3)

### Run

In [None]:
if("cnn" in PROJECT_NAME.lower()):
  # 1 input channel for MNIST, 3 otherwise; one subspace size per projected layer.
  model = LeNet(1 if "mnist" in PROJECT_NAME.lower() else 3, 10, True, [200,500,2000,500,300], rp_gen_algorithm="gaussian")
  #model = CNN(1 if "mnist" in PROJECT_NAME.lower() else 3, 10)
  wandb_logger.watch(model, log="all")

  # The VRAM report and the GPU accelerator only make sense when CUDA is
  # present; otherwise train on the CPU, matching the DEVICE fallback.
  cuda_available = torch.cuda.is_available()

  if cuda_available:
    t = torch.cuda.get_device_properties(0).total_memory
    r = torch.cuda.memory_reserved(0)
    a = torch.cuda.memory_allocated(0)

    print('Total VRAM:', t)
    print('Reserved VRAM:', r)
    print('Allocated VRAM:', a)
    print('Free VRAM:', r-a)
  else:
    print('CUDA not available: skipping VRAM report')

  trainer = pl.Trainer(max_epochs=100, accelerator="gpu" if cuda_available else "cpu", logger=wandb_logger,
                      callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=5),
                                  LogPredictionsCallback(), checkpoint_callback])
  trainer.fit(model, train_dataloaders=train_set, val_dataloaders=val_set)

  wandb.finish()