In [28]:
import pickle
import functools
from absl import flags
from absl import logging
import numpy as np

import data_util
import tensorflow.compat.v2 as tf

In [2]:
with open('spectrograms/entire_data.pickle','rb') as f:
      ds=pickle.load(f)

In [3]:
def get_preprocess_fn(is_training, is_pretrain):
      return functools.partial(
          data_util.preprocess_image,
          height=480,
          width=640,
          is_training=True,
          color_jitter_strength=1.0,
          test_crop=True)

In [4]:
preprocess_fn_pretrain = get_preprocess_fn(True, is_pretrain=True)

In [5]:
def map_fn(image, label):
    xs = []
    for _ in range(2):  # Two transformations
        xs.append(preprocess_fn_pretrain(image))
    image = tf.concat(xs, -1)
    label = tf.one_hot(label, 3)
    return image, label

In [6]:
augmented_list=[]

In [7]:
for i in range(len(ds)):
    augmented_list.append(map_fn(ds[i][0],ds[i][1]))   

In [8]:
augmented_list[0][0].shape

TensorShape([480, 640, 6])

In [13]:
augmented_list[0][0].dtype

tf.float32

In [9]:
augmented_list[0][1].shape

TensorShape([3])

In [18]:
len(augmented_list)

55

In [26]:
batch_size=4

In [30]:
augmented_list_batch_wise=[]
curr_batch_data=[]
curr_batch_labels=[]
for sample_idx in range(len(augmented_list)):
    if(sample_idx%batch_size==0 and len(curr_batch_data)>0):
        augmented_list_batch_wise.append((np.stack(curr_batch_data,axis=0),np.stack(curr_batch_labels,axis=0)))
        curr_batch_data=[]
        curr_batch_labels=[]
    curr_batch_data.append(augmented_list[sample_idx][0])
    curr_batch_labels.append(augmented_list[sample_idx][1]) 

In [31]:
len(augmented_list_batch_wise)

13

In [35]:
with open('spectrograms/augmented_list_batch_wise.pickle','wb')as f:
    pickle.dump(augmented_list_batch_wise,f)

In [50]:
with open('spectrograms/entire_data.pickle','rb') as f:
      ds=pickle.load(f)

In [51]:
ds[0][0].shape

(480, 640, 3)

In [52]:
def get_preprocess_fn(is_training, is_pretrain):
      return functools.partial(
          data_util.preprocess_image,
          height=480,
          width=640,
          is_training=True,
          color_jitter_strength=1.0,
          test_crop=True)

In [53]:
preprocess_fn_finetune = get_preprocess_fn(False, is_pretrain=False)

In [54]:
def map_fn(image, label):
    image = preprocess_fn_finetune(image)
    label = tf.one_hot(label, 3)
    return image, label

In [55]:
fine_tune_list=[]

In [56]:
for i in range(len(ds)):
    fine_tune_list.append(map_fn(ds[i][0],ds[i][1]))   

In [57]:
batch_size=4

In [58]:
fine_tune_list_batch_wise=[]

In [59]:
fine_tune_list_batch_wise=[]
curr_batch_data=[]
curr_batch_labels=[]
for sample_idx in range(len(fine_tune_list)):
    if(sample_idx%batch_size==0 and len(curr_batch_data)>0):
        fine_tune_list_batch_wise.append((np.stack(curr_batch_data,axis=0),np.stack(curr_batch_labels,axis=0)))
        curr_batch_data=[]
        curr_batch_labels=[]
    curr_batch_data.append(fine_tune_list[sample_idx][0])
    curr_batch_labels.append(fine_tune_list[sample_idx][1]) 

In [60]:
fine_tune_list_batch_wise[0][0].shape

(4, 480, 640, 3)

In [61]:
with open('spectrograms/fine_tune_list_batch_wise.pickle','wb')as f:
    pickle.dump(fine_tune_list_batch_wise,f)