In [1]:
import torch, argparse, gzip, os, warnings, copy, time, mlflow, pickle
import numpy as np, pytorch_lightning as pl
from tqdm.notebook import tqdm
from pytorch_lightning.loggers import MLFlowLogger
from mlflow.tracking.artifact_utils import get_artifact_uri, _get_root_uri_and_artifact_path
from data_loader import load_test_data
from models import ConvNet
from mlflow_helper import init_mlf_logger

In [2]:
# Use the rotated test set; the next cell raises NotImplementedError for the non-rotated variant.
TEST_ROT: bool = True

In [3]:
# Resolve the test-set file. Only the rotated set exists so far, so bail out early otherwise.
if not TEST_ROT:
    raise NotImplementedError('A non-rotated test set does not exist yet.')
TEST_PATH = "s2_mnist_cs1.gz"

# Fail fast without CUDA: later cells configure the Lightning trainer with gpus=1.
if not torch.cuda.is_available():
    raise RuntimeError('No GPU found.')
print('GPU available: ' + torch.cuda.get_device_name())

GPU available: NVIDIA GeForce RTX 2070 SUPER


In [4]:
# Load the evaluation split and report its size.
test_data = load_test_data(TEST_PATH)
print("Total test examples: {}".format(len(test_data)))

# Hyperparameters handed to ConvNet. Keyword order matters: vars(hparams) is later
# serialized into the untrained checkpoint, so it is kept identical to the original order.
hparams = argparse.Namespace(
    # run / data-loading settings
    name='test_model',
    test_batch_size=32,
    num_workers=0,  # presumably forwarded to the DataLoader (0 = load in main process) — confirm in ConvNet
    # architecture: the five lists below are parallel (length 5) — presumably one entry per conv layer, confirm
    channels=[13, 15, 22, 31, 141],
    kernels=[5, 3, 9, 7, 3],
    strides=[1, 1, 1, 1, 2],
    activation_fn='ReLU',
    batch_norm=True,
    nodes=[64, 32],  # presumably widths of the fully connected head — confirm
    # optimiser settings, presumably read by ConvNet's optimiser setup — confirm
    lr=1e-3,
    weight_decay=0,
)

Total test examples: 10000


In [5]:
# MLflow tracking backend: an SQLite database at mlruns/database.db,
# relative to the notebook's working directory.
tracking_uri = 'sqlite:///mlruns/database.db'

# Tags attached to the new run. MLflow tag values are strings, so the
# Unix-timestamp run name is converted explicitly instead of being passed
# as an int (newer MLflow releases validate tag values with len()).
tag_dict = {"mlflow.runName": str(round(time.time())),
            "mlflow.user": "dschuh"}

# Create the Lightning MLflow logger and obtain the run's artifact directory.
mlf_logger, artifact_path = init_mlf_logger(
    experiment_name='model_training',
    tracking_uri=tracking_uri,
    tags=tag_dict,
    verbose=True,
)

run ID directory created
artifact directory created


In [6]:
# Build the network. The second argument is None — presumably the training set,
# which this test-only notebook does not need (confirm against ConvNet's signature).
model = ConvNet(hparams, None, test_data)
# Record the model class on the MLflow run; the '_u' suffix presumably marks
# the untrained variant — confirm.
mlf_logger.experiment.set_tag(run_id=mlf_logger.run_id, key="model", value=model.__class__.__name__+'_u')

print(f"Number of total parameters: {model.count_parameters()}")

# Checkpoint callback keyed on validation accuracy (higher is better), writing into the run's artifacts.
monitor = 'val_acc'
mode = 'max'
checkpoint = pl.callbacks.model_checkpoint.ModelCheckpoint(filepath=artifact_path, monitor=monitor, mode=mode)

trainer = pl.Trainer(gpus=1, logger=mlf_logger, checkpoint_callback=checkpoint)

# Snapshot the freshly constructed (untrained) weights plus hyperparameters using
# the 'state_dict' / 'hyper_parameters' keys of a Lightning checkpoint.
untrained_ckpt_path = os.path.join(artifact_path, 'untrained.ckpt')
checkpoint_dict = {'state_dict': copy.deepcopy(model.state_dict()),
                   'hyper_parameters': pl.utilities.parsing.AttributeDict(vars(hparams))
                   }

# Never overwrite an existing checkpoint. An explicit raise (instead of `assert`)
# keeps this guard active even when Python runs with -O.
if os.path.isfile(untrained_ckpt_path):
    raise FileExistsError(f'Refusing to overwrite existing checkpoint: {untrained_ckpt_path}')
torch.save(checkpoint_dict, untrained_ckpt_path)




GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


Number of total parameters: 113761


In [7]:
# Run the test loop through the Lightning trainer configured above.
# NOTE(review): the saved output of this cell is `AttributeError: Missing attribute "lr"`
# even though hparams.lr is set earlier; the cause is presumably in how ConvNet
# stores/reads its hyperparameters — TODO investigate in models.py.
model.eval()
test_results = trainer.test(model=model)

AttributeError: Missing attribute "lr"

In [None]:
# Pickle the test metrics into the run's artifact directory, never
# overwriting an earlier result file.
filename = 'test_results.pickle'

if os.path.isfile(os.path.join(artifact_path, filename)):
    # Separate timestamp and base name ("<ts>_test_results.pickle") rather than
    # fusing them into "<ts>test_results.pickle".
    filename = f'{time.time()}_{filename}'
    print('File already existed, timestamp was prepended to filename.')

# 'xb' = exclusive create: raise FileExistsError instead of silently
# overwriting if the target path exists after all.
with open(os.path.join(artifact_path, filename), 'xb') as file:
    pickle.dump(test_results, file)

In [None]:
# dummy_checkpoint_path = 'mlruns/2/ff4a235add69448d9dca68d961c84955/artifacts/epoch=44.ckpt'
# dummy_checkpoint = torch.load(dummy_checkpoint_path)
# print(dummy_checkpoint.keys())
# print(type(dummy_checkpoint))

# my_dict = {'state_dict': copy.deepcopy(model.state_dict()),
#           'hyper_parameters': pl.utilities.parsing.AttributeDict(vars(hparams))
#           }

# print(my_dict.keys())
# print(type(my_dict['state_dict']))

# print(type(my_dict['hyper_parameters']))
# print(my_dict['hyper_parameters'])

# print(dummy_checkpoint['hyper_parameters'])
# print(type(dummy_checkpoint['hyper_parameters']))

# print(pl.utilities.parsing.AttributeDict(vars(hparams)))
# print(type(dummy_checkpoint['state_dict']))