In [1]:
import pickle
import torch, argparse, gzip, os, warnings, copy, time, mlflow
import numpy as np, pytorch_lightning as pl
from tqdm.notebook import tqdm
from pytorch_lightning.loggers import MLFlowLogger
from mlflow.tracking.artifact_utils import get_artifact_uri, _get_root_uri_and_artifact_path
from data_loader import load_train_data, load_test_data
from models import S2ConvNet
# from s2cnn import s2_near_identity_grid, so3_near_identity_grid, SO3Convolution, S2Convolution, so3_integrate
from mlflow_helper import init_mlf_logger

In [2]:
# Run configuration: everything a reader is expected to tune lives here.

# Upper bound on training epochs (early stopping may end the run sooner).
# MAX_EPOCHS = 20  (alternative setting, currently disabled)
MAX_EPOCHS = 3

# Number of training samples; also part of the training archive's file name.
TRAIN_SAMPLES = 10_000

# Whether the training / test sets are the rotated variants.
TRAIN_ROT, TEST_ROT = True, True

In [3]:
# Resolve the dataset archive names. Only the rotated variants are supported,
# so any other configuration is rejected up front.
if not TRAIN_ROT:
    raise NotImplementedError('A non-rotated training set does not exist yet.')
TRAIN_PATH = f"s2_mnist_train_dwr_{TRAIN_SAMPLES}.gz"

if not TEST_ROT:
    raise NotImplementedError('A non-rotated test set does not exist yet.')
TEST_PATH = "s2_mnist_cs1.gz"

# The trainer further down requests one GPU, so fail fast when CUDA is missing.
if not torch.cuda.is_available():
    raise RuntimeError('No GPU found.')
print('GPU available: ' + torch.cuda.get_device_name())

GPU available: NVIDIA GeForce RTX 2070 SUPER


In [4]:
# Load both datasets from the archives selected by TRAIN_PATH / TEST_PATH.
train_data, test_data = (
    load_train_data(TRAIN_PATH),
    load_test_data(TEST_PATH),
)

print(f"Total training examples: {len(train_data)}")
print(f"Total test examples: {len(test_data)}")

# Hyperparameters handed to the model constructor as a single namespace.
# The three convolution lists all have seven entries -- presumably one per
# convolutional layer; `nodes` presumably sizes the dense layers.
# NOTE(review): confirm both against S2ConvNet.
hparams = argparse.Namespace(
    name='test_model',
    # Data loading and optimisation settings.
    train_batch_size=32,
    test_batch_size=32,
    num_workers=0,
    lr=1e-4,
    weight_decay=0.,
    # Architecture settings.
    channels=[8, 16, 16, 24, 24, 32, 64],
    bandlimit=[30, 15, 15, 8, 8, 4, 2],
    kernel_max_beta=[0.0625, 0.0625, 0.125, 0.125, 0.25, 0.25, 0.5],
    activation_fn='ReLU',
    batch_norm=True,
    nodes=[64, 32],
)

Total training examples: 10000
Total test examples: 10000


In [5]:
# MLflow setup: the tracking URI points at a SQLite database under mlruns/.
tracking_uri = 'sqlite:///mlruns/database.db'

# Tags attached to the run; the run name is the current Unix time rounded to
# whole seconds.
tag_dict = {
    "mlflow.runName": round(time.time()),
    "mlflow.user": "dschuh",
}

# The project helper returns the logger together with the artifact path that
# later cells use for checkpoints and result files.
mlf_logger, artifact_path = init_mlf_logger(
    experiment_name='test_log',
    tracking_uri=tracking_uri,
    tags=tag_dict,
    verbose=True,
)

run ID directory created
artifact directory created


In [6]:
# Build the network and tag the MLflow run with the model's class name.
model = S2ConvNet(hparams, train_data, test_data)
mlf_logger.experiment.set_tag(
    run_id=mlf_logger.run_id,
    key="model",
    value=model.__class__.__name__,
)

print(f"Number of trainable / total parameters: {model.count_trainable_parameters(), model.count_parameters()}")

# Early stopping and checkpointing both track the same metric, maximised.
monitor, mode = 'val_acc', 'max'
early_stopping = pl.callbacks.EarlyStopping(
    monitor=monitor,
    min_delta=0.,
    patience=10,
    mode=mode,
)
checkpoint = pl.callbacks.model_checkpoint.ModelCheckpoint(
    filepath=artifact_path,
    monitor=monitor,
    mode=mode,
)

# Record the early-stopping settings and the run configuration with the run.
log_dict = dict(
    es_min_delta=early_stopping.min_delta,
    es_mode=early_stopping.mode,
    es_monitor=early_stopping.monitor,
    es_patience=early_stopping.patience,
    max_epochs=MAX_EPOCHS,
    train_samples=len(train_data),
    train_rot=TRAIN_ROT,
    test_rot=TEST_ROT,
)
mlf_logger.log_hyperparams(log_dict)

# Train on a single GPU with the callbacks configured above.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=MAX_EPOCHS,
    logger=mlf_logger,
    early_stop_callback=early_stopping,
    checkpoint_callback=checkpoint,
)
trainer.fit(model)

# Log the epoch at which early stopping reported a stop.
mlf_logger.experiment.log_param(
    run_id=mlf_logger.run_id,
    key='es_stopped_epoch',
    value=early_stopping.stopped_epoch,
)

# Restore the weights of the best checkpoint before running the test loop.
best_model = torch.load(checkpoint.best_model_path)
model.load_state_dict(best_model['state_dict'])
model.eval()
test_results = trainer.test(model)

Number of trainable / total parameters: (156882, 156882)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | loss_function | CrossEntropyLoss | 0     
1 | conv          | Sequential       | 149 K 
2 | dense         | Sequential       | 6 K   


Validation sanity check: 0it [00:00, ?it/s]

load 0.pkl.gz... done
load 0.pkl.gz... done
load 1.pkl.gz... done
load 2.pkl.gz... done
load 1.pkl.gz... done
load 3.pkl.gz... done
load 4.pkl.gz... done
load 2.pkl.gz... done
load 14.pkl.gz... done
load 14.pkl.gz... done
load 15.pkl.gz... done
load 16.pkl.gz... done
load 15.pkl.gz... done
load 17.pkl.gz... done
load 16.pkl.gz... done
load 18.pkl.gz... done
load 19.pkl.gz... done
load 17.pkl.gz... done
load 20.pkl.gz... done




Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]



Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
TEST RESULTS
{'test_acc': tensor(0.8193), 'test_loss': 0.902073860168457}
--------------------------------------------------------------------------------


{'test_loss': 0.902073860168457, 'test_acc': 0.8192999958992004}

In [None]:
# Persist the test metrics as a pickle file inside the run's artifact directory.
filename = 'test_results.pickle'
results_path = os.path.join(artifact_path, filename)

# Never overwrite an earlier result file: when the default name is taken,
# prepend the current timestamp and write to that new name instead.
if os.path.isfile(results_path):
    filename = str(time.time()) + filename
    results_path = os.path.join(artifact_path, filename)
    print('File already existed, timestamp was prepended to filename.')

# The with-block closes the file even if serialisation fails.
with open(results_path, 'wb') as file:
    pickle.dump(test_results, file)