In [1]:

import torch
from torch import nn
from torch.func import functional_call, grad
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import *
from torchvision import transforms
import pytorch_lightning as pl
import os
import numpy as np
from threading import Thread

from gc_module import ContNet

# Report whether PyTorch can see a CUDA device (the captured output shows True).
print(torch.cuda.is_available())
# Allocate a one-element tensor on the GPU. Presumably a warm-up so that CUDA is
# initialised here and the notebook fails early if the GPU is unusable -- TODO confirm intent.
# NOTE(review): this line raises on a machine without CUDA.
torch.zeros(1).cuda()
# Select the 'high' float32 matmul precision mode; see the PyTorch docs for
# torch.set_float32_matmul_precision for the exact speed/accuracy trade-off.
torch.set_float32_matmul_precision('high')

True


In [2]:

class MNISTClass(ContNet):
    """Deep fully connected MNIST classifier trained through ``ContNet``.

    The model flattens each image and passes it through 30 ``nn.Linear``
    layers (784 -> 50, 28 x (50 -> 50), 50 -> 10) with a ReLU after every
    layer except the last.  The network returns raw class *logits*:
    ``F.cross_entropy`` applies log-softmax internally, so a trailing
    ``nn.Softmax`` would normalise twice, bound the loss and flatten the
    gradients.  Use ``torch.softmax(model(x), dim=1)`` if probabilities
    are needed; ``argmax`` predictions are the same either way.

    Training uses manual optimisation calls (``self.optimizers()``,
    ``self.manual_backward``).
    NOTE(review): this assumes ``ContNet`` disables Lightning's automatic
    optimisation and provides ``logcontvar``, ``cont_lr``,
    ``perturb_params`` and ``contvar_grad`` -- confirm against gc_module.

    Args:
        loglambda0: Forwarded unchanged to ``ContNet.__init__``.
        cont_lr: Forwarded unchanged to ``ContNet.__init__``; ``self.cont_lr``
            is used as the learning rate of the ``logcontvar`` parameter group.
        cont_reg: Forwarded unchanged to ``ContNet.__init__``.
        warmup_epochs: Forwarded unchanged to ``ContNet.__init__``.
    """

    # Architecture and training constants (formerly inline literals).
    N_INPUTS = 784            # flattened 28 x 28 MNIST image
    HIDDEN_WIDTH = 50
    N_HIDDEN_LINEAR = 28      # number of 50 -> 50 layers between first and last layer
    N_CLASSES = 10
    TRAIN_FRACTION = 0.9      # share of the dataset used for training
    BATCH_SIZE = 250
    NUM_WORKERS = 16
    NET_LR = 1e-4             # Adam learning rate for the network weights

    def __init__(self, loglambda0: float, cont_lr: float, cont_reg: float, warmup_epochs: int):
        super().__init__(loglambda0, cont_lr, cont_reg, warmup_epochs)
        self.lossfunc = F.cross_entropy

        # Download MNIST into the current working directory and hold out the
        # remainder of the samples (10 %) for validation.
        self.dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
        size_train = int(len(self.dataset) * self.TRAIN_FRACTION)
        self.data_train, self.data_val = random_split(
            self.dataset, [size_train, len(self.dataset) - size_train])

        self.net = self._build_net()

        # Xavier-uniform weights and zero biases for every affine layer.
        for layer in self.net.modules():
            if isinstance(layer, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

    @classmethod
    def _build_net(cls) -> nn.Sequential:
        """Assemble the MLP; module indices match the original hand-written stack."""
        # One ReLU instance is shared between all positions, as before.
        activ = nn.ReLU()
        layers = [nn.Flatten(start_dim=1),
                  nn.Linear(cls.N_INPUTS, cls.HIDDEN_WIDTH), activ]
        for _ in range(cls.N_HIDDEN_LINEAR):
            layers += [nn.Linear(cls.HIDDEN_WIDTH, cls.HIDDEN_WIDTH), activ]
        # No Softmax after the output layer: the loss expects logits.
        layers.append(nn.Linear(cls.HIDDEN_WIDTH, cls.N_CLASSES))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits produced by ``self.net`` for the batch ``x``."""
        return self.net(x)

    def configure_optimizers(self):
        """Create one Adam optimiser with two parameter groups.

        The network weights use ``NET_LR``; ``self.logcontvar`` is included as
        a second group with its own learning rate ``self.cont_lr``.
        """
        optimizer = torch.optim.Adam([{'params': self.net.parameters()},
                                      {'params': (self.logcontvar,), 'lr': self.cont_lr}],
                                     lr=self.NET_LR)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        """Run one manually optimised step on perturbed parameters.

        Order of operations: log the squared parameter norm, perturb the
        parameters, compute loss and gradients on the perturbed weights,
        compute the ``logcontvar`` gradient, restore the reference weights,
        then step the optimiser.  Logs ``param_norm``, ``train_loss`` and
        ``train_acc``.
        """
        # Pure diagnostic: no autograd graph is needed for the norm.
        with torch.no_grad():
            self.log('param_norm', sum(p.pow(2.0).sum() for p in self.parameters()))

        opt = self.optimizers()
        opt.zero_grad()

        # add gaussian noise to parameters; ref_params is what gets restored below
        rand_samp, ref_params = self.perturb_params()

        # compute loss on the perturbed parameters
        x, y = train_batch
        ypred = self(x)
        loss = self.lossfunc(ypred, y)
        self.manual_backward(loss)

        # calculate accuracy (fraction of correct argmax predictions)
        with torch.no_grad():
            acc = (torch.argmax(ypred, dim=1) == y).float().mean()

        # compute contvar gradient
        self.contvar_grad(rand_samp, loss)

        # reload reference parameters so the update is applied to them
        self.load_state_dict(ref_params)

        opt.step()
        self.log('train_loss', loss.detach())
        self.log('train_acc', acc)

    def validation_step(self, val_batch, batch_idx):
        """Log cross-entropy loss (``val_loss``) and accuracy (``val_acc``) for one batch."""
        x, y = val_batch
        ypred = self(x)

        loss = self.lossfunc(ypred, y)
        self.log('val_loss', loss)

        acc = (torch.argmax(ypred, dim=1) == y).float().mean()
        self.log('val_acc', acc)

    def train_dataloader(self):
        """Return the shuffled training loader with persistent worker processes."""
        return DataLoader(self.data_train, batch_size=self.BATCH_SIZE, shuffle=True,
                          num_workers=self.NUM_WORKERS, persistent_workers=True)

    def val_dataloader(self):
        """Return the (unshuffled) validation loader with persistent worker processes."""
        return DataLoader(self.data_val, batch_size=self.BATCH_SIZE,
                          num_workers=self.NUM_WORKERS, persistent_workers=True)


In [3]:
def run_case(hyperparams: list, epochs: int = 200) -> None:
    """Train one freshly constructed ``MNISTClass`` and log it to TensorBoard.

    Args:
        hyperparams: Positional arguments unpacked into ``MNISTClass(...)``.
            Their string forms are also joined with ``_`` into the run name.
        epochs: Maximum number of training epochs.  It is embedded in the
            logger name as well.  Defaults to 200, the value that used to be
            hard-coded, so existing callers are unaffected.
    """
    mymodel = MNISTClass(*hyperparams)

    # training
    run_name = f'l30_dcont_mnist_{epochs}_' + '_'.join([f'{p}' for p in hyperparams])
    logger = pl.loggers.tensorboard.TensorBoardLogger('.', name=run_name)
    trainer = pl.Trainer(max_epochs=epochs, accelerator='gpu', logger=logger)
    trainer.fit(mymodel)

def sweep_hyperparams(hyperparam_list: list, n_runs: int, hyperparams: "list | None" = None) -> None:
    """Run ``run_case`` ``n_runs`` times for every point of a hyperparameter grid.

    The grid is the Cartesian product of the candidate lists in
    ``hyperparam_list``; it is walked recursively, one position per level,
    with the leftmost position varying slowest.

    Args:
        hyperparam_list: One list of candidate values per hyperparameter
            position.
        n_runs: Number of repeated runs for each complete combination.
        hyperparams: Values already chosen for the leading positions.  Used
            by the recursion; external callers normally omit it.  ``None``
            (the default) means "nothing chosen yet" -- a ``None`` sentinel is
            used instead of a shared mutable ``[]`` default.
    """
    if hyperparams is None:
        hyperparams = []

    if len(hyperparams) == len(hyperparam_list):
        # Every position is fixed: execute the repeated runs for this point.
        for _ in range(n_runs):
            print('-'*80)
            print(f'Running case with hyperparams {hyperparams}')
            print('-'*80)

            run_case(hyperparams)

    else:
        # Fix the next open position to each candidate in turn and recurse.
        for hyperparam_i in hyperparam_list[len(hyperparams)]:
            sweep_hyperparams(hyperparam_list, n_runs, [*hyperparams, hyperparam_i])
    

In [4]:
def main() -> None:
    """Launch the sweep: a single grid point, repeated five times.

    Each inner list holds the candidate values for one constructor argument
    of the model, in positional order.
    """
    repeats = 5
    grid = [
        [np.log(1e-8)],  # loglambda0
        [1e-2],          # cont_lr
        [0.0],           # cont_reg
        [-1],            # warmup_epochs
    ]
    sweep_hyperparams(grid, repeats)

# Entry-point guard: start the sweep only when this code executes as __main__
# (which is the case when the notebook cell is run, as the output below shows).
if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
Running case with hyperparams [-18.420680743952367, 0.01, 0.0, -1]
--------------------------------------------------------------------------------


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: ./l30_dcont_mnist_200_-18.420680743952367_0.01_0.0_-1
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type       | Params
------------------------------------
0 | net  | Sequential | 111 K 
------------------------------------
111 K     Trainable params
0         Non-trainable params
111 K     Total params
0.445     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]torch.Size([250, 10]) torch.Size([250])
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  8.59it/s]torch.Size([250, 10]) torch.Size([250])
Epoch 0:  33%|███▎      | 71/216 [00:03<00:06, 21.91it/s, v_num=0]         