In [None]:
!pip install git+https://github.com/deepmind/dm-haiku

Collecting git+https://github.com/deepmind/dm-haiku
  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-1v_kml_9
  Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-1v_kml_9


In [None]:
pip install optax



In [None]:
# Mount Google Drive; the dependency and dataset paths used in this notebook live under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
import optax
import sys
from math import ceil

In [None]:
# Put the project's Dependencies folder at the front of the import path so that
# `import model_fashion` can resolve it from Drive.
DEPENDENCIES_DIR = '/content/gdrive/MyDrive/Colab Notebooks/Capstone Project/Dependencies'
sys.path.insert(0, DEPENDENCIES_DIR)

In [None]:
import model_fashion

In [None]:
from typing import Iterator, Mapping
import json
import os
import numpy as np
import haiku as hk
import jax
import jax.numpy as jnp

In [None]:
def batching(tensor, batch_size):
    """Lazily split `tensor` into consecutive mini-batches along its first axis.

    Args:
        tensor: array-like with a leading batch axis that supports ``.shape``
            and slicing along axis 0 (e.g. a NumPy array). Any rank >= 1 is
            accepted; trailing axes are kept intact.
        batch_size: positive int, number of rows per batch. The last batch
            holds the remainder and may be smaller.

    Returns:
        A generator of ``tensor[start:start + batch_size]`` slices. It is
        single-use: once consumed it yields nothing, so build a fresh one for
        every pass over the data.

    Raises:
        ValueError: if ``batch_size`` is not positive (checked eagerly, at call
            time, not when the generator is first advanced).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got {}".format(batch_size))
    num_rows = tensor.shape[0]
    # Integer stepping avoids float division, and slicing only axis 0 removes the
    # old 3-way shape unpacking that restricted inputs to exactly three dimensions.
    return (tensor[start:start + batch_size] for start in range(0, num_rows, batch_size))

LOAD TRAINING DATA

In [None]:
# Folder containing the training-split .npy files loaded in the next cell.
path_T = os.path.join(
    '/content/gdrive/MyDrive/Colab Notebooks/Capstone Project',
    'Outfit_Processing/Archive and Analysis/pkl files/TrainingDataset',
)

In [None]:
# Load the training image and caption sequence arrays, in that order.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute
# arbitrary code — keep it only if these files are trusted and really hold object arrays.
image_np_T, caption_np_T = (
    np.load(os.path.join(path_T, file_name), allow_pickle=True)
    for file_name in ('outfitSequencesImage_50.npy', 'outfitSequencesCaption_50.npy')
)

In [None]:
# Image and caption batches are paired by position during training, so both arrays must
# hold the same number of sequences; presumably the model also needs the same sequence
# length (both are 8 in the recorded output). Fail fast here rather than mid-training.
assert image_np_T.shape[:2] == caption_np_T.shape[:2], (image_np_T.shape, caption_np_T.shape)
image_np_T.shape, caption_np_T.shape

((50, 8, 512), (50, 8, 2756))

In [None]:
# Wrap the training arrays in batch generators of 5 sequences each.
# These generators are single-use: train_model consumes them, so re-run this cell
# before training again.
image_itr_T, caption_itr_T = (batching(arr, 5) for arr in (image_np_T, caption_np_T))

LOAD VALIDATION DATA

In [None]:
# Folder containing the validation-split .npy files loaded in the next cell.
path_V = os.path.join(
    '/content/gdrive/MyDrive/Colab Notebooks/Capstone Project',
    'Outfit_Processing/Archive and Analysis/pkl files/ValidationDataset',
)

In [None]:
# Load the validation image and caption sequence arrays, in that order.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute
# arbitrary code — keep it only if these files are trusted and really hold object arrays.
image_np_V, caption_np_V = (
    np.load(os.path.join(path_V, file_name), allow_pickle=True)
    for file_name in ('outfitSequencesImage_validation.npy', 'outfitSequencesCaption_validation.npy')
)

In [None]:
# Image and caption batches are paired by position during validation, so both arrays must
# hold the same number of sequences; presumably the model also needs the same sequence
# length (both are 8 in the recorded output). Fail fast here rather than mid-training.
assert image_np_V.shape[:2] == caption_np_V.shape[:2], (image_np_V.shape, caption_np_V.shape)
image_np_V.shape, caption_np_V.shape

((1497, 8, 512), (1497, 8, 2756))

In [None]:
# Validate on only the first 20 sequences (4 batches of 5) to keep evaluation cheap.
# These generators are single-use: re-run this cell before training again.
image_itr_V, caption_itr_V = (batching(arr[:20], 5) for arr in (image_np_V, caption_np_V))

TRAINING

In [None]:
# Turn the Haiku loss function into a pure (init, apply) pair.
# apply still takes an rng argument; train_model passes None for it, which assumes
# total_loss draws no Haiku randomness (e.g. dropout) — TODO confirm in model_fashion.
model = hk.transform(model_fashion.total_loss)

In [None]:
def _next_batch(image_itr, caption_itr):
  """Pull the next aligned image/caption batch as the dict the model consumes.

  Returns None once either iterator is exhausted. Only StopIteration means
  "no more data"; any other error propagates to the caller.
  """
  try:
    return {'outfitSequencesImage': next(image_itr),
            'outfitSequencesCaption': next(caption_itr)}
  except StopIteration:
    return None


def train_model(train_ds, valid_ds, num_steps=200, valid_every=3,
                learning_rate=1e-3, seed=428) -> hk.Params:
  """Initializes and trains a model on train_ds, returning the final params.

  Args:
    train_ds: pair ``(image_itr, caption_itr)`` of iterators yielding aligned
      training batches (e.g. from ``batching``). They are consumed.
    valid_ds: pair ``(image_itr, caption_itr)`` of validation batch iterators.
      One validation batch is evaluated every ``valid_every`` steps until they
      run out; training then continues without validation.
    num_steps: upper bound on training steps. Training stops early once the
      training iterators are exhausted, so set this >= the number of training
      batches to make a full pass.
    valid_every: validate on steps where ``step % valid_every == 0``; 0 or
      None disables validation.
    learning_rate: Adam learning rate.
    seed: seed of the PRNG key passed to ``model.init``.

  Returns:
    The trained Haiku params.

  Raises:
    ValueError: if ``train_ds`` yields no batch at all, so there is nothing
      to initialize the parameters from.
  """
  rng = jax.random.PRNGKey(seed)
  opt = optax.adam(learning_rate)

  image_itr_T, caption_itr_T = train_ds
  image_itr_V, caption_itr_V = valid_ds

  @jax.jit
  def loss(params, x):
    # rng=None: presumably total_loss uses no Haiku randomness.
    return model.apply(params, None, x)

  @jax.jit
  def update(params, opt_state, x):
    l, grads = jax.value_and_grad(loss)(params, x)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return l, params, opt_state

  # Initialize parameters and optimizer state from the first training batch.
  x_T = _next_batch(image_itr_T, caption_itr_T)
  if x_T is None:
    raise ValueError("train_ds yielded no batches; cannot initialize params")
  params = model.init(rng, x_T)
  opt_state = opt.init(params)

  for step in range(num_steps):
    if valid_every and step % valid_every == 0:
      x_V = _next_batch(image_itr_V, caption_itr_V)
      if x_V is not None:
        print("Step {}: valid loss {}".format(step, loss(params, x_V)))

    # Step 0 trains on the batch already pulled for init (previously it was
    # silently dropped); later steps pull fresh batches until data runs out.
    if step > 0:
      x_T = _next_batch(image_itr_T, caption_itr_T)
      if x_T is None:
        break
    train_loss, params, opt_state = update(params, opt_state, x_T)
    print("Step {}: train loss {}".format(step, train_loss))

  return params

In [None]:
# Build fresh batch generators here: train_model consumes its iterators, so reusing the
# ones created in the data cells would leave a re-run of this cell with no data at all.
train_ds = (batching(image_np_T, 5), batching(caption_np_T, 5))
valid_ds = (batching(image_np_V[:20], 5), batching(caption_np_V[:20], 5))
params = train_model(train_ds, valid_ds)
params.keys()



Step 0: valid loss 15.300223350524902
Step 0: train loss 18.58026695251465
Step 1: train loss 42.18431854248047
Step 2: train loss 43.62986373901367
Step 3: valid loss 41.84905242919922
Step 3: train loss 38.552059173583984
Step 4: train loss 42.10193634033203
Step 5: train loss 48.29399108886719
Step 6: valid loss 44.854408264160156
Step 6: train loss 44.37632751464844
Step 7: train loss 48.59300994873047
Step 8: train loss 35.855892181396484
Step 9: valid loss 36.57114791870117


dict_keys(['linear', 'lstm/linear', 'lstm_1/linear', 'mlp/~/linear_0', 'mlp_1/~/linear_0', 'visual_semantic'])

SAVE PARAMETERS

In [None]:
import pickle

param_path = '/content/gdrive/MyDrive/Colab Notebooks/Capstone Project/params.p'

# A context manager guarantees the file is flushed and closed, even if pickling fails;
# the previous bare open(...) left the handle open until garbage collection.
with open(param_path, "wb") as params_file:
    pickle.dump(params, params_file)