In [1]:
"""
Wrapper used to train VAE model for WfieldMouse data and plotting latent space
UMAPPING for it.
We can add further functionality here as needed.
"""
import os, sys
import argparse
import numpy as np
import random
import torch
import time
import import_ipynb
import DataClass as data
import vae as vae
import glob
from wfield import *
#import tooltip_plot as tooltip

importing Jupyter notebook from DataClass.ipynb
importing Jupyter notebook from vae.ipynb


In [None]:
#def function to convert bool flag/arg
def str2bool(v):
    """
    Str to Bool converter for wrapper script.
    This is used both for --from_ckpt flag, which
    is False by default but can be turned on either by listing the flag (without args)
    or by listing with an appropriate arg (which can be converted to a corresponding boolean)
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

#now set up args for model
parser = argparse.ArgumentParser(description='user args for vae model')

parser.add_argument('--data_dir', type=str, metavar='N', default='', \
help='Path to directory where data lives.')
parser.add_argument('--save_dir', type=str, metavar='N', default='', \
help='Dir where model params, latent projection maps and TB logs are saved to. Default is to save files to current dir.')
parser.add_argument('--batch-size', type=int, default=32, metavar='N', \
help='Input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',\
help='Number of epochs to train (default: 100)')
parser.add_argument('--seed', type=int, default=1, metavar='S', \
help='Random seed (default: 1)')
parser.add_argument('--save_freq', type=int, default=1, metavar='N', \
help='How many epochs to wait before saving training status.')
parser.add_argument('--test_freq', type=int, default=2, metavar='N', \
help='How many epochs to wait before testing.')
parser.add_argument('--from_ckpt', type=str2bool, nargs='?', const=True, default=False, \
help='Boolean flag indicating if training and/or reconstruction should be carried using a pre-trained model state.')
parser.add_argument('--ckpt_path', type=str, metavar='N', default='', \
help='Path to ckpt with saved model state to be loaded. Only effective if --from_ckpt == True.')

args = parser.parse_args("")
torch.manual_seed(args.seed)

#set up saving directory
if args.save_dir =='':
	args.save_dir = os.getcwd()
if args.save_dir != '' and not os.path.exists(args.save_dir):
	os.makedirs(args.save_dir)
else:
	pass

if __name__ == "__main__":
    args.data_dir = "/hdd/achint_files/wfield_data/frames_2_540_640_uint16.dat"
    print('hello1')
    args.save_dir = "/home/achint/Practice_code/Synthetic_dataset/Daniela_POISE_VAE/wfield"
    #data_path = glob(pjoin(args.data_dir,'*.dat'))[0] # do '*.dat' for the 50 GB file
    print('hello2')
    main_start = time.time()
    print('hello3')
    loaders_dict = data.setup_data_loaders(batch_size=args.batch_size, file_path=args.data_dir)
    model = vae.VAE(save_dir = args.save_dir)
    print('hello4')
    if args.from_ckpt == True:
        assert os.path.exists(args.ckpt_path), 'Oops, looks like ckpt file given does NOT exist!'
        print('='*40)
        print('Loading model state from: {}'.format(args.ckpt_path))
        model.load_state(filename = args.ckpt_path)
    model.train_loop(loaders_dict, epochs = args.epochs, test_freq = args.test_freq, save_freq = args.save_freq)
    model.get_latent_umap(loaders_dict, save_dir=args.save_dir, title = "Latent Space plot")
    #recons = model.get_recons(loaders_dict['dset'], [0, len(loaders_dict['train'].dataset)]) #am reconstructing entire dset here
    #n=len(loaders_dict['train'].dataset), img_range=img_range)
    main_end = time.time()
    print('Total model runtime (seconds): {}'.format(main_end - main_start))

hello1
hello2
hello3
Dataclass_here3
Dataclass_here
Dataclass_here1 (39123, 2, 540, 640)
Dataclass_here1 (100, 2, 540, 640)


In [3]:
data

memmap([[[[1491, 1509, 1478, ..., 1459, 1471, 1440],
          [1415, 1333, 1379, ..., 1414, 1348, 1342],
          [1417, 1355, 1366, ..., 1404, 1340, 1406],
          ...,
          [1500, 1533, 1505, ..., 1514, 1525, 1445],
          [1490, 1573, 1498, ..., 1482, 1477, 1480],
          [1663, 1678, 1665, ..., 1615, 1548, 1566]],

         [[1494, 1516, 1468, ..., 1459, 1444, 1433],
          [1404, 1362, 1378, ..., 1392, 1350, 1373],
          [1369, 1390, 1365, ..., 1440, 1419, 1399],
          ...,
          [1834, 1763, 1798, ..., 1594, 1595, 1525],
          [1792, 1806, 1800, ..., 1562, 1530, 1529],
          [1892, 1895, 1910, ..., 1658, 1628, 1612]]],


        [[[1486, 1503, 1453, ..., 1434, 1445, 1429],
          [1417, 1370, 1336, ..., 1415, 1388, 1327],
          [1411, 1420, 1400, ..., 1375, 1388, 1380],
          ...,
          [1499, 1558, 1565, ..., 1503, 1532, 1463],
          [1493, 1530, 1549, ..., 1465, 1460, 1460],
          [1725, 1595, 1631, ..., 1616, 1609, 15

In [None]:
2* 540* 640