In [1]:

import os
from threading import Thread
from typing import Optional

import pytorch_lightning as pl
import torch
from torch import nn
from torch.func import functional_call, grad
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import *
from torchvision import transforms

from mcdiff_module import MCDiffNet

print(torch.cuda.is_available())
torch.zeros(1).cuda()
torch.set_float32_matmul_precision('high')

True


In [2]:

class MNISTClass(MCDiffNet):
    def __init__(self, sigma, n_mc):
        super().__init__(sigma, n_mc)

        activ = nn.LeakyReLU()
        self.net = nn.Sequential(nn.Flatten(start_dim=1),
                                nn.Linear(784, 50), activ,
                                nn.Linear(50, 50), activ,
                                nn.Linear(50, 10),
                                nn.Softmax(dim=1)
                                )
        
        for layer in self.net.modules():
            if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

        self.lossfunc = F.cross_entropy

    def forward(self, x):
        ypred = self.net(x)
        return ypred
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer
    
    # training_step already defined

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        ypred = self.net(x)

        loss = self.lossfunc(ypred, y)
        self.log('val_loss', loss)
        
        match = torch.eq(y, torch.argmax(ypred, dim=1))
        acc = torch.sum(match)/y.shape[0]
        self.log('val_acc', acc)


In [3]:
def run_case(hyperparams: list, train_loader: DataLoader, val_loader: DataLoader) -> None:
    mymodel = MNISTClass(*hyperparams).cuda()

    # training
    logger = pl.loggers.tensorboard.TensorBoardLogger('.', name='mcdiff_mnist_' + '_'.join([f'{p}' for p in hyperparams]))
    trainer = pl.Trainer(max_epochs=25, accelerator='auto', logger=logger)#gpus=4, num_nodes=8, precision=16, limit_train_batches=0.5)
    trainer.fit(mymodel, train_loader, val_loader)

def sweep_hyperparams(hyperparam_list: list, train_loader: DataLoader, val_loader: DataLoader, n_runs: int, hyperparams: list = []) -> None:
    if len(hyperparams) == len(hyperparam_list):
        for i in range(n_runs):
            print('-'*80)
            print(f'Running case with hyperparams {hyperparams}')
            print('-'*80)

            run_case(hyperparams, train_loader, val_loader)

    else:
        for hyperparam_i in hyperparam_list[len(hyperparams)]:
            new_hyperparams = hyperparams + [hyperparam_i]
            sweep_hyperparams(hyperparam_list, train_loader, val_loader, n_runs, new_hyperparams)
    

In [4]:
def main() -> None:
    # data
    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    size_train = int(len(dataset)*0.9)
    data_train, data_val = random_split(dataset, [size_train, len(dataset) - size_train])

    train_loader = DataLoader(data_train, batch_size=150, num_workers=48)
    val_loader = DataLoader(data_val, batch_size=150, num_workers=48)

    # model
    lambda0 = [1e-4]
    n_mc = [10000]
    n_runs = 1

    hyperparam_list = [lambda0, n_mc]
    sweep_hyperparams(hyperparam_list, train_loader, val_loader, n_runs)

# Entry point: start the sweep when this code runs as the main module, which
# includes executing this notebook cell.
if __name__ == "__main__":
    main()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


--------------------------------------------------------------------------------
Running case with hyperparams [0.0001, 10000]
--------------------------------------------------------------------------------


Missing logger folder: ./mcdiff_mnist_0.0001_10000
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type       | Params
------------------------------------
0 | net  | Sequential | 42.3 K
------------------------------------
42.3 K    Trainable params
0         Non-trainable params
42.3 K    Total params
0.169     Total estimated model params size (MB)


Epoch 3:  96%|█████████▌| 344/360 [2:19:51<06:30, 24.39s/it, v_num=0]      

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
