In [None]:
from jax import numpy as jnp
import jax
import transformers
import optax
from flax.training import train_state
import flax.linen as nn

import data
import modeling_bart
import arguments
import datasets

import os

In [None]:
# Report which backend JAX is running on (e.g. 'cpu', 'gpu', 'tpu').
# jax.default_backend() is the public API for this; jax.lib.xla_bridge is an
# internal module that recent JAX releases deprecate.
print(jax.default_backend())

In [None]:
# Data / output locations (all relative to the notebook's working directory).
TRAIN_DIR = './training'
TRAIN_DATA = './training/train_dataset'
DEV_DATA = './training/dev_dataset'
RANK_SCORE_PATH = './training'

# NOTE: despite its name, this points at a JSON file; it is passed to
# data.GroupedTrainDataset as `path_to_tsv`.
PATH_TO_TSV = './training/json_files/sample_train.json'

data_args = arguments.DataArguments(
    train_dir=TRAIN_DIR,
    train_path=TRAIN_DATA,
    dev_path=DEV_DATA,
    rank_score_path=RANK_SCORE_PATH,
    max_len=512,
)
reranker_args = arguments.RerankerTrainingArguments(
    output_dir=os.path.join(TRAIN_DIR, 'output'),
)

In [None]:
# Tokenizer plus two separately loaded copies of the same pretrained checkpoint:
# one is applied to queries, the other to documents. Neither is updated by the
# training cells below (only the IB2 parameters in `state` are trained).
BART_CHECKPOINT = 'facebook/bart-base'

tokenizer = transformers.BartTokenizer.from_pretrained(BART_CHECKPOINT)
query_model = modeling_bart.FlaxBartMoresRanker.from_pretrained(BART_CHECKPOINT)
document_model = modeling_bart.FlaxBartMoresRanker.from_pretrained(BART_CHECKPOINT)

In [None]:
# Training examples read from PATH_TO_TSV; the training loop unpacks each item
# as a (pos, neg) pair.
train_dataset = data.GroupedTrainDataset(
    args=data_args,
    path_to_tsv=PATH_TO_TSV,
    tokenizer=tokenizer,
    train_args=reranker_args,
)

In [None]:
# Not jitted: `pos` is branched on with a plain Python conditional, so under
# jax.jit it would have to be declared a static argument.
def compute_lce(params, input_ids, pos):
    """Run the global ``IB`` model and score its logits against a pos/neg target.

    Args:
        params: the ``'params'`` collection for the module-level ``IB`` model.
        input_ids: model input (in this notebook, stacked query/document embeddings).
        pos: truthy for a positive example (target one-hot ``[1, 0]``, i.e. class 0),
            falsy for a negative one (target ``[0, 1]``, i.e. class 1).

    Returns:
        Scalar mean softmax cross-entropy.
    """
    variables = {'params': params}
    logits = IB.apply(variables, input_ids)
    target = jnp.array([0, 1] if not pos else [1, 0], dtype=jnp.float32)
    return optax.softmax_cross_entropy(logits=logits, labels=target).mean()

def compute_lce_2(logits, pos):
    """Mean softmax cross-entropy of precomputed ``logits`` against a pos/neg target.

    Args:
        logits: classifier logits whose last axis has size 2.
        pos: truthy -> target one-hot ``[1, 0]`` (class 0);
            falsy -> target ``[0, 1]`` (class 1).

    Returns:
        Scalar mean cross-entropy.
    """
    target = jnp.array([1, 0] if pos else [0, 1], dtype=jnp.float32)
    per_example = optax.softmax_cross_entropy(logits=logits, labels=target)
    return per_example.mean()

In [None]:
class InteractionBlock(nn.Module):
    """Updates the query tokens with context attended from all document tokens.

    The input is indexed as ``x[:, 0]`` (query) and ``x[:, 1:]`` (documents), i.e. it
    is laid out as ``(batch, 1 + num_docs, seq_len, hidden)`` -- ``(1, 4, 512, 768)``
    in this notebook. All document tokens are flattened into a single key/value
    sequence of length ``num_docs * seq_len``.

    Returns an array shaped like the query slice, ``(batch, seq_len, hidden)``.
    """

    @nn.compact
    def __call__(self, x):
        query = x[:, 0]
        # (batch, num_docs, seq_len, hidden) -> (batch, num_docs * seq_len, hidden),
        # derived from the input instead of a hard-coded (1, 3*512, 768) so other
        # batch sizes / document counts / lengths work too.
        docs = x[:, 1:].reshape(x.shape[0], -1, x.shape[-1])

        # Cross-attention from query tokens over all document tokens, post-norm residual.
        h = nn.MultiHeadDotProductAttention(num_heads=6)(inputs_q=query, inputs_kv=docs)
        h = nn.LayerNorm()(h + query)
        # NOTE(review): unlike the other two sub-layers, LayerNorm here is applied to the
        # self-attention output *before* the residual add. Kept unchanged to preserve
        # the model and its parameter layout -- confirm whether LayerNorm(attn + h) was meant.
        h = nn.LayerNorm()(nn.SelfAttention(num_heads=6)(h)) + h
        # Position-wise projection back to the query width, post-norm residual.
        h = nn.LayerNorm()(nn.Dense(features=query.shape[-1])(h) + h)
        return h

class IB2(nn.Module):
    """Two stacked InteractionBlocks followed by a 2-way classification head.

    Input layout is ``(batch, 1 + num_docs, seq_len, hidden)`` with the query at
    index 0 along axis 1 (``(1, 4, 512, 768)`` in this notebook). The first block
    refines the query against the documents; the refined query is put back in slot 0
    next to the original documents and passed through a second block. Logits are
    computed from the first token of the result.

    Returns logits of shape ``(batch, 2)``.
    """

    @nn.compact
    def __call__(self, x):
        docs = x[:, 1:]
        refined_query = InteractionBlock()(x)
        # (batch, seq_len, hidden) -> (batch, 1, seq_len, hidden). Inserting the new axis
        # at position 1 keeps batch first, so any batch size works without a hard-coded
        # reshape; for batch 1 this is identical to expanding at axis 0.
        stacked = jnp.concatenate((jnp.expand_dims(refined_query, axis=1), docs), axis=1)
        h = InteractionBlock()(stacked)
        cls_tok = h[:, 0]
        return nn.Dense(features=2)(cls_tok)


In [None]:
# Module-level model instance: compute_lce reads this global `IB` when computing
# the training loss. `params` is the full variables dict returned by init
# ({'params': ...}); NOTE(review): it is not used by the visible training code,
# which builds its own parameters in create_train_state.
IB = IB2()
key = jax.random.PRNGKey(0)
params = IB.init(key, jnp.ones((1, 4, 512, 768)))

In [None]:
def compute_metrics(logits):
  """Loss and accuracy for one (positive, negative) pair of logit batches.

  Args:
    logits: ``(pos_logits, neg_logits)``, each with a trailing class axis of size 2
      (``(batch, 2)`` as produced by ``IB2``).

  Returns:
    dict with ``'loss'`` (sum of the positive and negative mean cross-entropies) and
    ``'accuracy'`` (fraction of examples whose argmax is the target class).
  """
  pos, neg = logits
  loss = compute_lce_2(logits=pos, pos=1)
  loss += compute_lce_2(logits=neg, pos=0)

  # compute_lce_2 trains positives toward one-hot [1, 0] (class 0) and negatives
  # toward [0, 1] (class 1), so a correct prediction is argmax 0 for positives and
  # argmax 1 for negatives. Comparing against [1, 0] would report 1 - accuracy.
  pos_correct = jnp.argmax(pos, -1) == 0
  neg_correct = jnp.argmax(neg, -1) == 1
  # Concatenate per-example hits so this works for any batch size, not only 1.
  accuracy = jnp.mean(jnp.concatenate([pos_correct.ravel(), neg_correct.ravel()]))

  return {
      'loss': loss,
      'accuracy': accuracy,
  }

# Not jitted (kept as plain Python, matching the training helpers).
def eval_step(params, batch):
  """Score a (positive, negative) input pair with IB2 and compute metrics.

  Args:
    params: the ``'params'`` collection for ``IB2``.
    batch: ``(pos_inputs, neg_inputs)``, each a model input for ``IB2``.

  Returns:
    The metrics dict produced by ``compute_metrics``.
  """
  pos_inputs, neg_inputs = batch
  model = IB2()
  variables = {'params': params}
  scored = tuple(model.apply(variables, inputs) for inputs in (pos_inputs, neg_inputs))
  return compute_metrics(logits=scored)
  
def eval_model(params, batch):
  """Evaluate ``params`` on one (positive, negative) batch.

  Thin wrapper around ``eval_step``. Metric values come back as JAX arrays; the
  training loop converts them with ``.item()``.
  """
  return eval_step(params, batch)

In [None]:
def create_train_state(rng, learning_rate, momentum):
  """Build a fresh ``TrainState`` for ``IB2`` optimised with SGD + momentum.

  Args:
    rng: PRNG key used to initialise the parameters.
    learning_rate: SGD learning rate.
    momentum: SGD momentum coefficient.

  Returns:
    A ``flax.training.train_state.TrainState`` with the initial params and optimizer.
  """
  model = IB2()
  # Dummy all-ones input used only to trace shapes for init: batch 1, the query
  # plus 3 documents along axis 1, then 512 x 768 per slot.
  dummy_input = jnp.ones([1, 4, 512, 768])
  initial_params = model.init(rng, dummy_input)['params']
  optimizer = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=model.apply,
      params=initial_params,
      tx=optimizer,
  )

In [None]:
# Fixed seed for reproducible initialisation; the key is deleted right after so it
# is not reused by later cells.
key = jax.random.PRNGKey(0)
state = create_train_state(rng=key, learning_rate=2e-4, momentum=0.9)
del key

In [None]:
# Not jitted: `label` is a plain Python value that compute_lce branches on.
def train_step(state, batch, label):
  """Apply one gradient update to ``state`` from a single labelled example.

  Args:
    state: current ``TrainState``.
    batch: model input passed to ``compute_lce``.
    label: truthy for a positive example, falsy for a negative one.

  Returns:
    The updated ``TrainState``.
  """
  grads = jax.grad(compute_lce)(state.params, batch, label)
  return state.apply_gradients(grads=grads)

def train_epoch(state, batch):
  """Run the two gradient steps for one (positive, negative) pair.

  Despite the name, this handles a single pair, not a whole epoch: one update on
  the positive input (label 1), then one on the negative input (label 0).

  Returns:
    The ``TrainState`` after both updates.
  """
  pos_batch, neg_batch = batch
  for example, label in ((pos_batch, 1), (neg_batch, 0)):
    state = train_step(state, example, label=label)
  return state

In [None]:
num_epoch = 10
num_examples = len(train_dataset)
if num_examples == 0:
    # Fail fast with a clear message instead of a ZeroDivisionError when averaging.
    raise ValueError('train_dataset is empty; nothing to train on')


def build_pair_inputs(pos, neg):
    """Embed one training pair into the stacked inputs fed to IB2.

    Index 0 of ``pos`` goes through ``query_model`` and index 1 of ``pos``/``neg``
    through ``document_model``; element ``[1]`` of each model output is used as the
    embedding. Only ``pos[0]`` is encoded as the query -- presumably ``neg[0]`` holds
    the same query (TODO confirm in data.py).

    Returns:
        ``(pos_inputs, neg_inputs)``: the query embedding concatenated with the
        document embeddings along axis 0, plus a leading batch axis of 1 --
        presumably ``(1, 4, 512, 768)`` to match the shape IB2 was initialised with.
    """
    # NOTE: these models are not updated by this loop, so the embeddings are the same
    # every epoch; caching them would save compute at the cost of memory.
    query = query_model(jnp.expand_dims(pos[0], axis=0))[1]
    pos_emb = document_model(pos[1])[1]
    neg_emb = document_model(neg[1])[1]
    pos_inputs = jnp.expand_dims(jnp.concatenate((query, pos_emb), axis=0), axis=0)
    neg_inputs = jnp.expand_dims(jnp.concatenate((query, neg_emb), axis=0), axis=0)
    return pos_inputs, neg_inputs


for epoch in range(1, num_epoch + 1):
    loss = 0.0
    acc = 0.0
    for pos, neg in train_dataset:
        pair = build_pair_inputs(pos, neg)
        state = train_epoch(state, pair)
        # Metrics are measured on the same pair right after updating on it, so they
        # are (optimistic) training metrics, not held-out evaluation.
        metrics = eval_model(state.params, pair)
        acc += metrics['accuracy'].item()
        loss += metrics['loss'].item()

    print(f' train epoch: {epoch}, loss: {loss/num_examples}, accuracy: {acc/num_examples * 100}')