In [1]:
import pandas as pd

from zoobot.data_utils import image_datasets, create_shards

In [2]:
labelled_catalog = (
    pd.read_csv('C:/Users/oryan/Documents/zoobot_new/data/extracted-coords-zoobot.csv')
    .rename(columns={'merging_merger':'merger'})
)

In [3]:
labelled_catalog_float = (
    labelled_catalog
    .assign(merger_float = labelled_catalog.merger.astype('float64'))
    .assign(id_str_str = labelled_catalog.id_str.astype('str'))
    .drop(columns=['merger','id_str'])
    .rename(columns={'merger_float':'merger','id_str_str':'id_str'})
)

In [4]:
labelled_catalog_float

Unnamed: 0.1,Unnamed: 0,RA,DEC,file_loc,merger,id_str
0,0,01:30:37.76,+13:12:52.0,C:\Users\oryan\Documents\zoobot\dataset-creati...,1.0,587724197207212176
1,1,01:58:16.35,-00:31:18.8,C:\Users\oryan\Documents\zoobot\dataset-creati...,1.0,587731512073650262
2,2,07:34:03.45,+43:32:41.3,C:\Users\oryan\Documents\zoobot\dataset-creati...,1.0,587738195036799112
3,3,07:57:33.42,+25:34:36.5,C:\Users\oryan\Documents\zoobot\dataset-creati...,1.0,587732156315402630
4,4,08:04:22.68,+40:38:55.9,C:\Users\oryan\Documents\zoobot\dataset-creati...,1.0,587728669878845746
...,...,...,...,...,...,...
995,995,09:40:10.78,+26:31:41.2,C:\Users\oryan\Documents\zoobot\dataset-creati...,0.0,587741392109895883
996,996,08:02:24.10,+36:08:53.6,C:\Users\oryan\Documents\zoobot\dataset-creati...,0.0,587728905564389735
997,997,12:43:48.13,+06:14:30.1,C:\Users\oryan\Documents\zoobot\dataset-creati...,0.0,588017723866218632
998,998,10:06:46.33,+08:36:23.5,C:\Users\oryan\Documents\zoobot\dataset-creati...,0.0,587734863219130540


In [5]:
shard_config = create_shards.ShardConfig(shard_dir = 'C:\\Users\\oryan\\Documents\\zoobot_new\\tfrecords\\',size=300)

In [6]:
shard_config.prepare_shards(
    labelled_catalog_float,
    unlabelled_catalog = None,
    test_fraction=0.20,
    val_fraction=0.10,
    labelled_columns_to_save=['id_str','file_loc','merger']
)

100%|██████████| 1000/1000 [00:00<00:00, 14918.65it/s]
  0%|          | 0/1 [00:00<?, ?shards/s]

Checking no missing files


100%|██████████| 1/1 [00:03<00:00,  3.87s/shards]
100%|██████████| 1/1 [00:00<00:00,  1.85shards/s]
100%|██████████| 1/1 [00:01<00:00,  1.04s/shards]


In [7]:
from zoobot.data_utils import tfrecord_datasets

In [8]:
train_records = 'C:\\Users\\oryan\\Documents\\zoobot_new\\tfrecords\\train_shards\\s300_shard_0.tfrecord'
test_records = 'C:\\Users\\oryan\\Documents\\zoobot_new\\tfrecords\\test_shards\\s300_shard_0.tfrecord'
val_records = 'C:\\Users\\oryan\\Documents\\zoobot_new\\tfrecords\\val_shards\\s300_shard_0.tfrecord'

In [9]:
columns_to_save = ['merger']
batch_size = 64
raw_train_dataset = tfrecord_datasets.get_tfrecord_dataset(train_records,columns_to_save,batch_size,shuffle=True)
raw_test_dataset = tfrecord_datasets.get_tfrecord_dataset(test_records,columns_to_save,batch_size,shuffle=False)
raw_val_dataset = tfrecord_datasets.get_tfrecord_dataset(val_records,columns_to_save,batch_size,shuffle=False)



In [10]:
raw_train_dataset

<PrefetchDataset element_spec={'id_str': TensorSpec(shape=(None,), dtype=tf.string, name=None), 'matrix': TensorSpec(shape=(None, None), dtype=tf.float32, name=None), 'merger': TensorSpec(shape=(None,), dtype=tf.float32, name=None)}>

In [11]:
from zoobot.estimators import preprocess

In [12]:
preprocess_config = preprocess.PreprocessingConfig(
    label_cols = ['merger'],
    input_size = 300,
    normalise_from_uint8=True,
    make_greyscale=True,
    permute_channels=False
)

In [13]:
train_dataset = preprocess.preprocess_dataset(raw_train_dataset,preprocess_config)
test_dataset = preprocess.preprocess_dataset(raw_test_dataset,preprocess_config)
val_dataset = preprocess.preprocess_dataset(raw_val_dataset,preprocess_config)

In [14]:
from zoobot.estimators import define_model

In [15]:
model = define_model.get_model(
    output_dim = 2,
    input_size = 300,
    crop_size = int(300 * 0.75),
    resize_size = 224,
    channels=1
)



In [16]:
from zoobot import schemas

In [17]:
schema = schemas.Schema({'merger':['merger_merger','merger_not_merger']}, {'merger'})

{merger, indices 0 to 1, asked after None: (0, 1)}


In [18]:
schema.question_index_groups

[(0, 1)]

In [19]:
from zoobot.training import losses

In [20]:
multiquestion_loss = losses.get_multiquestion_loss(schema.question_index_groups)
loss = lambda x, y: multiquestion_loss(x,y)/batch_size

In [21]:
import tensorflow as tf

In [22]:
model.compile(
    loss=loss,
    optimizer=tf.keras.optimizers.Adam()
)

In [23]:
model

<keras.engine.sequential.Sequential at 0x2896d88a508>

In [24]:
from zoobot.training import training_config

In [25]:
train_config = training_config.TrainConfig(
    log_dir = 'C:\\Users\\oryan\\Documents\\zoobot_tests\\model\\',
    epochs=50,
    patience=10
)

In [26]:
training_config.train_estimator(
    model,
    train_config,
    train_dataset,
    test_dataset,
    val_dataset,
    eager=True
)



Epoch 1/50

 Ending step:  0.0
11/11 - 156s - loss: 0.1834 - val_loss: 0.1907 - 156s/epoch - 14s/step
Epoch 2/50


KeyboardInterrupt: 