In [None]:
import argparse, datetime, sys, os
def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``type=bool`` is a well-known argparse pitfall: ``bool("False")`` is True
    because any non-empty string is truthy, so ``--graph False`` silently
    enabled the option. This converter accepts the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if token in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface of the training script; every option has a default so
# the notebook also runs without arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default = 1e-4)
parser.add_argument('--bs', type=int, default = 2)
parser.add_argument('--acc_grad', type=int, default = 4)
parser.add_argument('--fold', type=int, default = 0)
parser.add_argument('--kfolds', type=int, default = 5)
parser.add_argument('--drop_last_layers', type=int, default = None)
parser.add_argument('--freeze_layers', type=int, default = None)
parser.add_argument('--freeze_embeddings', type=_str2bool, default = False)
parser.add_argument('--enc_depth', type=int, default = 6)
parser.add_argument('--dec_depth', type=int, default = 6)
parser.add_argument('--dim', type=int, default = 768)
parser.add_argument('--aug_pct', type=float, default = 0.5)
parser.add_argument('--aug_dir', type=str, default = "tokens")
parser.add_argument('--data_dir', type=str, default = "data")
parser.add_argument('--filter', action='store_true')
parser.add_argument('--singlefold', action='store_true') #run 1 fold only (for quick testing)
parser.add_argument('--precision', type=int, default = 16)
parser.add_argument('--epochs', type=int, default = 20)
parser.add_argument('--max_tokens', type=int, default = 224)
parser.add_argument('--max_cells', type=int, default = 128)
parser.add_argument('--graph', type=_str2bool, default = False)
parser.add_argument('--device',type=int, default = 0)
parser.add_argument('--workers', type=int, default = 4) # lower to 2 or 1 (or use --use_cache) if OOM (CPU)
parser.add_argument('--subset', type=int, default = None) #-1 all dataset, otherwise the number of samples
parser.add_argument('--test', type=str, default = None) # path to the .csv file for running inference
parser.add_argument('--load', type=str, default = None)
parser.add_argument('--resume', action='store_true')
parser.add_argument('--verbose', action='store_true')
# The default experiment name is the timestamp at parser-construction time.
parser.add_argument('--experiment', type=str, default = datetime.datetime.now().strftime("%Y%m%d%H%M") )
parser.add_argument('--results_dir', type=str, default = 'results')
parser.add_argument('--use_cache', action='store_true') # cache dataset to disk
parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1")

# Parse the process command line; the dummy "-f/--fff" option declared on the
# parser exists so that running under IPython does not fail.
args= parser.parse_args()
print(args)
# Snapshot of argv taken before any later cell overwrites sys.argv.
original_cmd_line = list(sys.argv)

In [None]:
#noexport
# Interactive-only override: re-parse a fixed argument list and mirror it into
# sys.argv so the rest of the notebook sees these settings.
cmdline = ['--results_dir', 'pruebas', '--singlefold','--kfolds','1000']
sys.argv = ['_'] + cmdline
args= parser.parse_args(cmdline)
print(args)

In [None]:
#noexport
# Interactive-only: an empty CUDA_VISIBLE_DEVICES is meant to hide all GPUs so
# the notebook runs on CPU (must be set before CUDA is initialised).
os.environ['CUDA_VISIBLE_DEVICES']=''

In [None]:
from torch.utils.data import Dataset, DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from tqdm import tqdm
from x_transformers import Decoder, Encoder
import fastcore.parallel
from pathlib import Path
from sklearn.model_selection import GroupShuffleSplit
from functools import partial
from methodtools import lru_cache
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping
from pytorch_lightning.callbacks.progress import TQDMProgressBar, ProgressBar

# Setting the seed
pl.seed_everything(42, workers = True)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility.
# The attribute is spelled `deterministic`; the previous misspelling only created
# an unused attribute on the module and left cuDNN in its default mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Number of CUDA devices visible to this process (0 when running on CPU).
n_gpus = torch.cuda.device_count()

print(f"Number of workers: {args.workers}")
print(f"Number of gpus: {n_gpus}")

In [None]:
#noexport
## Imports for plotting
# Interactive-only cell: configures matplotlib/seaborn defaults for inline figures.
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
# NOTE(review): newer IPython releases reportedly deprecate this import location
# in favour of matplotlib_inline.backend_inline — confirm against the installed version.
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.set()

In [None]:
from transformers import (
    AdamW,
    get_linear_schedule_with_warmup,
    RobertaTokenizer,
    RobertaModel,
)

In [None]:
# Input data root and per-experiment output folder (results_dir/<experiment>).
p_data = Path(args.data_dir)
p_dest_folder = Path(args.results_dir) / f'{args.experiment}'
# Model artefacts share the experiment folder.
p_models_folder = p_dest_folder

In [None]:
# Per-notebook caps used when building fixed-size tensors: at most MAX_CELLS
# code cells and MAX_CELLS markdown cells, each truncated to MAX_TOKENS tokens.
MAX_CELLS  = args.max_cells
MAX_TOKENS = args.max_tokens
pad_token_id = 1 # must equal the pad token id of the pretrained model loaded below (asserted after loading)
codebert_dim = 768  # presumably the hidden size of the pretrained encoder — TODO confirm against its config

In [None]:
import json

In [None]:
# Folder globbed for the per-notebook *.json files.
train_path = p_data / 'train'
def read_notebook(path):
    """Load one notebook JSON file into a DataFrame indexed by ``cell_id``.

    Args:
        path: ``pathlib.Path`` of the notebook file; its stem becomes the
            value of the added ``id`` column.

    Returns:
        DataFrame with the file's columns (``cell_type`` read as category,
        ``source`` as str) plus ``id``; the index is named ``cell_id``.
    """
    return (
        pd.read_json(
            path,
            dtype={'cell_type': 'category', 'source': 'str'})
        # DataFrame.assign only takes keyword arguments; the former positional
        # call `.assign(path.stem)` raised TypeError on every invocation.
        .assign(id=path.stem)
        .rename_axis('cell_id')
    )
def read_notebook_and_empty(path):
    """Load a notebook's cell table but blank out every cell's source text.

    The JSON file is parsed with the stdlib, an ``id`` column holding the file
    stem is added, the index is named ``cell_id``, ``source`` is overwritten
    with empty strings and ``cell_type`` is stored as a categorical.
    """
    with open(path) as handle:
        raw = json.load(handle)
    frame = pd.DataFrame(raw)
    frame = frame.assign(id=path.stem)
    frame = frame.rename_axis('cell_id')
    # Only ids and cell types are needed downstream, so the text is dropped.
    frame['source'] = ""
    frame['cell_type'] = frame['cell_type'].astype('category')
    return frame

In [None]:
# Cell table for all training notebooks (sources blanked), cached as feather so
# the JSON files are only parsed once.
p_notebooks = Path('cache/notebooks-empty.feather')
if p_notebooks.exists():
    df = pd.read_feather(p_notebooks)
else:
    iterable = list(train_path.glob("*.json"))
    notebooks = fastcore.parallel.parallel(
        read_notebook_and_empty, iterable, n_workers=63,
        # With fewer than 63 files the integer division yields 0, which the
        # process pool rejects; always hand out at least one item per chunk.
        chunksize=max(1, len(iterable)//63), progress=True)
    df = pd.concat(notebooks).reset_index()
    # The cache folder is not created anywhere else, so make sure it exists.
    p_notebooks.parent.mkdir(parents=True, exist_ok=True)
    df.to_feather(p_notebooks)

In [None]:
# Index cells by (notebook id, cell id). NOTE: rebinding `df` makes this cell
# non-idempotent — running it twice fails because the columns are already the index.
df=df.set_index(['id','cell_id'])

In [None]:
# Per-notebook cell counts: all cells, code cells only, and (by subtraction) the rest.
total_cells_by_notebook = df['cell_type'].groupby(by='id').count()
code_cells_by_notebook = df[df['cell_type']=='code']['cell_type'].groupby(by='id').count()
# NOTE(review): a notebook with no code cells would be missing from the second
# series and yield NaN here — confirm the data never contains such notebooks.
md_cells_by_notebook=total_cells_by_notebook-code_cells_by_notebook

In [None]:
# Hard override: replaces whatever --kfolds was parsed from the command line.
args.kfolds=1000 #use as much training data as possible without changing code for final submission

In [None]:
# Group-aware train/validation split: notebooks sharing an ancestor stay on the
# same side of the split.
df_ancestors = pd.read_csv(p_data /'train_ancestors.csv',  index_col='id')
splitter = GroupShuffleSplit(n_splits=args.kfolds, random_state=0,test_size=1/args.kfolds)
ids = df_ancestors.index.unique('id')
ancestors = df_ancestors.loc[ids, 'ancestor_id']
if not 0 <= args.fold < args.kfolds:
    raise ValueError(f"--fold must be in [0, {args.kfolds - 1}], got {args.fold}")
# Create the split generator ONCE and advance it. Calling splitter.split(...)
# inside the loop restarted it every time (the seed is fixed), so every value
# of --fold used to return the very first split.
split_iter = splitter.split(ids, groups=ancestors)
for _ in range(1+args.fold):
    ids_train, ids_valid = next(split_iter)
# Translate positional indices back to notebook ids.
ids_train, ids_valid = ids[ids_train], ids[ids_valid]
print(f"Valid items: {len(ids_valid)} {len(ids_valid)/(len(ids_train)+len(ids_valid))*100:0.3f}% of total")

In [None]:
# Optionally keep only notebooks whose code-cell count AND markdown-cell count
# both fit within MAX_CELLS.
if args.filter:
    def filter_cells(xids):
        """Return the subset of ``xids`` whose notebooks respect MAX_CELLS."""
        code_fits = code_cells_by_notebook[xids.values] <= MAX_CELLS
        md_fits = md_cells_by_notebook[xids.values] <= MAX_CELLS
        both_fit = code_fits.combine(md_fits, lambda x, y: x and y)
        return xids[both_fit]

    ids_train = filter_cells(ids_train)
    ids_valid = filter_cells(ids_valid)

In [None]:
def read_df_order(path):
    """Read an orders CSV into a Series mapping notebook id -> list of tokens.

    The CSV is indexed by its ``id`` column; the remaining column is squeezed
    to a Series and each string is split on whitespace.
    """
    table = pd.read_csv(path, index_col='id')
    as_series = table.squeeze("columns")
    return as_series.str.split()

In [None]:
# Folder searched below for augmented order files and token folders.
p_aug = Path(args.aug_dir)

In [None]:
# Folder holding the pre-tokenised cells of the original (non-augmented) notebooks.
p_tokens = Path(f'microsoft_graphcodebert-base_tokens')

In [None]:
# One orders Series per augmented dataset (sorted by file name); empty when augmentation is off.
df_orders_aug = [read_df_order(p) for p in sorted(p_aug.glob("train_orders*.csv"))] if args.aug_pct else []

In [None]:
# Feather cache path for each augmented dataset, in the same order as df_orders_aug.
p_notebooks_aug = [Path(f'cache/notebooks{i}.feather') for i in range(len(df_orders_aug))] if args.aug_pct else []

In [None]:
#cache (locally)
# Build (or load) the cell table of every augmented dataset; df_aug[i] matches
# df_orders_aug[i].
df_aug = []
for i,_p_notebooks_aug in tqdm(enumerate(p_notebooks_aug), total=len(p_notebooks_aug)):
    if not _p_notebooks_aug.exists():
        print(f"Caching {_p_notebooks_aug}")
        tp = p_data / f'mut{i}' / 'train'
        iterable = list(tp.glob("*.json"))
        notebooks_aug = fastcore.parallel.parallel(
            read_notebook_and_empty, iterable, n_workers=63,
            # Fewer than 63 files would give chunksize 0, which the process
            # pool rejects; always use at least one item per chunk.
            chunksize=max(1, len(iterable)//63), progress=True)
        _df_aug = pd.concat(notebooks_aug).reset_index()
        # Make sure the cache folder exists before writing into it.
        _p_notebooks_aug.parent.mkdir(parents=True, exist_ok=True)
        _df_aug.to_feather(_p_notebooks_aug)
    else:
        print(f"Reading {_p_notebooks_aug}")
        _df_aug = pd.read_feather(_p_notebooks_aug)
    _df_aug=_df_aug.set_index(['id','cell_id'])
    df_aug.append(_df_aug)

In [None]:
# Token folders of the augmented datasets, sorted by name; empty when augmentation is off.
p_tokens_aug = sorted(p_aug.glob("microsoft_graphcodebert-base_tokens*")) if args.aug_pct else []

In [None]:
# Sanity check: every augmented orders file needs a matching token folder,
# because the two lists are indexed with the same position later on.
if args.aug_pct:
    assert len(p_tokens_aug) == len(df_orders_aug)

In [None]:
# Ground-truth cell orders are only loaded when --test is not given.
if args.test is None:
    df_orders = read_df_order(p_data / 'train_orders.csv')
else: 
    df_orders = None

In [None]:
def filter_df(_df_orders):
    """Keep only the rows whose index is one of the train/validation notebook ids.

    Args:
        _df_orders: object with an ``index`` supporting ``isin`` (the orders
            Series), or ``None`` — the value ``df_orders`` takes when ``--test``
            is given.

    Returns:
        The filtered object, or ``None`` when ``None`` was passed in (the
        unguarded version raised AttributeError in that case).
    """
    if _df_orders is None:
        return None
    wanted_ids = np.concatenate((ids_train.values, ids_valid.values))
    return _df_orders[_df_orders.index.isin(wanted_ids)]

In [None]:
# Restrict the ground-truth orders to the notebooks selected for this fold.
df_orders=filter_df(df_orders)

In [None]:
# Apply the same notebook filter to every augmented orders Series.
if args.aug_pct:
    df_orders_aug = [filter_df(_d) for _d in df_orders_aug]

In [None]:
#noexport
# Quick visual check of the cell table.
df.head()

In [None]:
# Element-wise conversion of hexadecimal id strings to Python ints.
ids_to_ints = np.vectorize(partial(int,base=16))
# Same value as the earlier definition of p_tokens (redefined here).
p_tokens = Path(f'microsoft_graphcodebert-base_tokens')

def get_ids(notebook_id,df):
    """Return ``(code_cell_ids, markdown_cell_ids)`` of one notebook.

    ``df`` is looked up with ``df.loc[notebook_id]``; the resulting rows are
    split by their ``cell_type`` value and the matching index labels are
    returned as two numpy arrays, in stored order.
    """
    cells = df.loc[notebook_id]
    kinds = cells['cell_type']
    code_cells = cells.index[kinds == 'code']
    markdown_cells = cells.index[kinds == 'markdown']
    return code_cells.values, markdown_cells.values

def get_index_gt(gt, cc, md):
    """Build the ``(2, MAX_CELLS)`` tensor of ground-truth positions.

    Row 0 holds, for every code cell in ``cc``, its position inside ``gt``;
    row 1 does the same for the markdown cells in ``md``. Each row is truncated
    to ``MAX_CELLS`` entries and unused slots are filled with ``-1``.
    """
    position_of = {cell: pos for pos, cell in enumerate(gt)}

    def _padded_positions(cells):
        found = torch.tensor([position_of[cell] for cell in cells])
        row = torch.full((MAX_CELLS,), -1)
        kept = min(MAX_CELLS, len(found))
        row[:kept] = found[:kept]
        return row

    code_row = _padded_positions(cc)
    md_row = _padded_positions(md)
    return torch.stack((code_row, md_row))

class TokensDataset(Dataset):
    """Dataset yielding padded token tensors and ground-truth positions per notebook.

    Reads the module-level ``df_orders`` / ``df_orders_aug`` (orders),
    ``df`` / ``df_aug`` (cell tables) and ``p_tokens`` / ``p_tokens_aug``
    (folders containing ``<id>_code.npy`` and ``<id>_md.npy``).

    Args:
        idx: notebook ids selected for this dataset.
        aug_pct: probability of serving an augmented variant of a notebook.
    """

    def __init__(self,idx,aug_pct):
        self.df_orders = df_orders.loc[idx]
        self.df_orders_aug = [_d.loc[idx] for _d in df_orders_aug]
        self.aug_pct = aug_pct

    def __len__(self):
        return len(self.df_orders)

    def __getitem__(self, idx):
        """Return ``(code_pos_pct, code_tokens, md_tokens, index_gt, notebook_id)``.

        ``code_tokens`` and ``md_tokens`` are ``(MAX_CELLS, MAX_TOKENS)`` int64
        tensors padded with ``pad_token_id``; ``code_pos_pct`` is ``(MAX_CELLS,)``
        with ``-1`` in unused slots; ``index_gt`` comes from ``get_index_gt``.
        """
        # The random draw is kept first so RNG consumption is unchanged. The
        # extra emptiness check avoids torch.randint(0, ...) raising when
        # aug_pct > 0 but no augmented datasets were found.
        wants_aug = torch.rand((1,)).item() < self.aug_pct
        if wants_aug and self.df_orders_aug:
            # augment
            aug_idx = torch.randint(len(self.df_orders_aug),(1,)).item()
            df_orders = self.df_orders_aug[aug_idx]
            pp_tokens = p_tokens_aug[aug_idx]
            dff = df_aug[aug_idx]
        else:
            # not augment
            df_orders = self.df_orders
            pp_tokens = p_tokens
            dff = df

        notebook_id = df_orders.index[idx]

        code_tokens = torch.full((MAX_CELLS,MAX_TOKENS),pad_token_id,dtype=torch.int64)
        md_tokens   = torch.full((MAX_CELLS,MAX_TOKENS),pad_token_id,dtype=torch.int64)

        _code_tokens = torch.from_numpy(np.load(str(pp_tokens / notebook_id)+"_code.npy"))[:,:MAX_TOKENS]

        n_c = len(_code_tokens)
        # More code cells than MAX_CELLS: always keep the first and the last
        # cell and sample the remaining slots at random.
        code_pos = torch.arange(n_c) if n_c <= MAX_CELLS else torch.concat((torch.IntTensor([0]),
                                 (torch.randperm(n_c-2)+1)[:MAX_CELLS-2],
                                 torch.IntTensor([n_c-1])))

        #to make sure no cheating
        # (cells are shuffled so their slot carries no ordering information)
        code_pos = code_pos[torch.randperm(len(code_pos))]

        _code_tokens = _code_tokens[code_pos]

        # Relative position of each kept code cell in [0, 1]; -1 marks padding.
        code_pos_pct = torch.full((MAX_CELLS,),-1.)
        code_pos_pct[:n_c] = (code_pos.float()/(n_c-1)).nan_to_num(posinf=0) # set n_c = 1 with 0.

        # TODO: Deal with _md_tokens > MAX_CELLS in inference
        _md_tokens = torch.from_numpy(np.load(str(pp_tokens / notebook_id)+"_md.npy"))[:,:MAX_TOKENS]
        n_md = len(_md_tokens)
        md_pos = torch.randperm(n_md)[:MAX_CELLS]
        _md_tokens = _md_tokens[md_pos]

        # filter ids by sampled positions
        gt_ids = df_orders[notebook_id]
        code_ids, md_ids = get_ids(notebook_id,dff)
        code_ids = code_ids[code_pos.tolist()]
        md_ids   = md_ids[md_pos.tolist()]
        # Set membership replaces repeated linear scans over numpy arrays.
        kept_ids = set(code_ids.tolist()) | set(md_ids.tolist())
        gt_ids = [i for i in gt_ids if i in kept_ids]
        index_gt = get_index_gt(gt_ids, code_ids, md_ids)

        code_tokens[:len(_code_tokens),:_code_tokens.shape[1]] = _code_tokens
        md_tokens  [:len(_md_tokens),:_md_tokens.shape[1]]     = _md_tokens

        return code_pos_pct, code_tokens, md_tokens, index_gt, notebook_id
    
# Training data uses augmentation with probability args.aug_pct; validation never augments.
train_ds,valid_ds = TokensDataset(ids_train,args.aug_pct),TokensDataset(ids_valid,0.)

In [None]:
#noexport
# Number of training notebooks.
len(train_ds)

In [None]:
#noexport
# Interactive-only loaders (validation uses twice the batch size here); they are
# redefined later for the actual training run.
train_dl = DataLoader(train_ds,args.bs,   shuffle=True,  num_workers=args.workers, pin_memory=True, drop_last=True)
valid_dl = DataLoader(valid_ds,args.bs*2, shuffle=False, num_workers=args.workers, pin_memory=True)

In [None]:
#noexport
# Pull one validation batch for interactive inspection.
code_pos_pct, code_tokens, md_tokens, index_gt, notebook_ids = next(iter(valid_dl))

In [None]:
# Pretrained encoder shared by every Orderer instance created below.
code_bert=RobertaModel.from_pretrained("microsoft/graphcodebert-base")

In [None]:
# The hard-coded padding id used to build the token tensors must match the model's.
assert pad_token_id == code_bert.config.pad_token_id

In [None]:
class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its positional inputs as a tuple.

    Used below as a drop-in replacement for encoder layers that are removed.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *inputs):
        return inputs

In [None]:
# Optionally replace the last N encoder layers with pass-through modules.
if args.drop_last_layers:
    n_layers = len(code_bert.encoder.layer)
    for layer_idx in range(n_layers - args.drop_last_layers, n_layers):
        code_bert.encoder.layer[layer_idx] = Identity()

# Collect the sub-modules whose weights must stay fixed during training.
modules = []
if args.freeze_layers:
    modules.append(code_bert.encoder.layer[:args.freeze_layers])
if args.freeze_embeddings:
    modules.append(code_bert.embeddings.word_embeddings)
for module in modules:
    for param in module.parameters():
        param.requires_grad = False
# Display the resulting architecture.
code_bert

In [None]:
import math
def positionalencoding1d(positions_pct, d_model):
    """Sinusoidal encoding of fractional positions.

    :param positions_pct: tensor whose first two dimensions index the items to
        encode; values are multiplied by 50 before being encoded
    :param d_model: size of the encoding; must be even
    :return: tensor of shape ``(*positions_pct.shape[:2], d_model)`` with sines
        in the even channels and cosines in the odd channels, on the same
        device and with the same dtype as ``positions_pct``
    :raises ValueError: if ``d_model`` is odd
    """
    device, dtype = positions_pct.device, positions_pct.dtype
    if d_model % 2:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))
    leading = positions_pct.shape[:2]
    scaled = (positions_pct * 50.).flatten(0).unsqueeze(1)
    channel_idx = torch.arange(0, d_model, 2, device=device, dtype=dtype)
    inv_freq = torch.exp(channel_idx * -(math.log(10000.0) / d_model))

    encoding = torch.empty((*leading, d_model), device=device, dtype=dtype)
    encoding[..., 0::2] = torch.sin(scaled * inv_freq).view(*leading, -1)
    encoding[..., 1::2] = torch.cos(scaled * inv_freq).view(*leading, -1)
    return encoding

In [None]:
class Orderer(nn.Module):
    """Scores every cell of a notebook so that sorting the scores gives an order.

    Each cell is embedded with the shared pretrained ``code_bert`` (first-token
    hidden state). Code-cell embeddings receive a positional encoding and go
    through ``encoder``; markdown embeddings cross-attend to them in ``decoder``
    and are mapped to one scalar each by ``md_regressor``.
    """

    def __init__(self):
        super().__init__()

        width = args.dim

        # Only project when the pretrained hidden size differs from the model width.
        if codebert_dim != width:
            self.codebert_proj = nn.Linear(codebert_dim, width, bias=False)
        else:
            self.codebert_proj = nn.Identity()

        shared = dict(dim=width, heads=8, attn_num_mem_kv=16, rotary_pos_emb=False)
        self.encoder = Encoder(depth=args.enc_depth, **shared)
        # The "decoder" is also an Encoder block, with cross-attention enabled.
        self.decoder = Encoder(depth=args.dec_depth, cross_attend=True, **shared)

        self.md_regressor = nn.Sequential(nn.Linear(width, 2, bias=True), nn.GLU())

        self.code_bert = code_bert

    def _embed_cells(self, tokens):
        """Run ``code_bert`` on every cell and return the projected first-token state."""
        attention = tokens != pad_token_id
        hidden = self.code_bert(
            tokens.flatten(0, 1),
            attention_mask=attention.flatten(0, 1).to(tokens.device),
        ).last_hidden_state
        first_token = hidden.view(*tokens.shape[:3], -1)[:, :, 0, :]
        return self.codebert_proj(first_token)

    def forward(self, code_pos_pct, code_tokens, md_tokens):
        """Return a ``(batch, 2, cells)`` tensor: row 0 code scores, row 1 markdown scores."""
        bs = code_tokens.shape[0]

        x_code = self._embed_cells(code_tokens)
        x_md = self._embed_cells(md_tokens)

        # A slot is a real cell when at least one of its tokens is not padding.
        m_code = (code_tokens != pad_token_id).any(dim=-1)
        m_md = (md_tokens != pad_token_id).any(dim=-1)
        cells_per_notebook = m_code.sum(dim=1) + m_md.sum(dim=1)

        x_code += positionalencoding1d(code_pos_pct, x_code.shape[-1])

        x_code = self.encoder(x_code, mask=m_code)
        x_md = self.decoder(x_md, context=x_code, mask=m_md, context_mask=m_code)

        # Markdown scores are squashed into (-0.2, 1.2) before scaling.
        md_scores = -0.2 + 1.4 * self.md_regressor(x_md).sigmoid().view(bs, 1, -1)
        stacked = torch.cat((code_pos_pct.unsqueeze(1), md_scores), 1)

        return cells_per_notebook.view(bs, 1, 1) * stacked

In [None]:
def preds_to_ids_tensor(preds,masks,notebook_ids):
    """Turn predicted scores into per-notebook tensors of integer cell ids.

    Args:
        preds: score tensor, one row of code scores and one of markdown scores
            per notebook (same layout as the model output).
        masks: boolean tensor of the same shape; False marks padded slots.
        notebook_ids: notebook id for each batch element.

    Returns:
        ``(predictions, ground_truth)``: lists of int64 tensors with cell ids
        (hexadecimal ids converted with ``ids_to_ints``) in predicted and true
        order. ``ground_truth`` only has entries for notebooks with a known order.

    NOTE(review): slicing with ``[:nc]`` / ``[:nm]`` assumes a notebook has at
    most as many code/markdown cells as there are slots — confirm for notebooks
    exceeding MAX_CELLS.
    """
    device = preds.device
    preds=preds.clone()
    preds[~masks]=np.inf
    # Double argsort converts scores to ranks over code and markdown jointly.
    orders=preds.detach().flatten(1).argsort().argsort().view(*preds.shape).cpu().numpy()
    predictions,ground_truth=[],[]
    for notebook_id,order,_mask in zip(notebook_ids,orders,masks):
        # get_ids needs the cell table; calling it with one argument raised TypeError.
        code_ids,md_ids=get_ids(notebook_id,df)
        nc,nm=len(code_ids),len(md_ids)
        has_gt = df_orders is not None and notebook_id in df_orders
        gt = df_orders[notebook_id] if has_gt else None
        if gt is not None:
            assert (nc+nm) == len(gt)
        code_order,md_order=order[0],order[1]
        # Size by the cell count so notebooks without ground truth work too.
        pred=torch.full((nc+nm,),-1,dtype=torch.int64,device=device)
        pred[code_order[:nc]] =  torch.from_numpy(ids_to_ints(code_ids)).to(device)
        pred[md_order[:nm]] =  torch.from_numpy(ids_to_ints(md_ids)).to(device)
        predictions.append(pred)
        if gt: ground_truth.append(torch.from_numpy(ids_to_ints(gt)).to(device))
    return predictions,ground_truth

In [None]:
#noexport
# Interactive smoke test: run an untrained model on the previewed validation batch.
m=Orderer()
preds = m(code_pos_pct, code_tokens, md_tokens)
print(preds[0],md_tokens[0],index_gt[0])


In [None]:
from torchsort import soft_rank

In [None]:
from bisect import bisect

def count_inversions(a):
    """Count pairs ``i < j`` with ``a[i] > a[j]`` using a sorted prefix and bisection."""
    inversions = 0
    sorted_so_far = []
    for i, u in enumerate(a):
        # j = number of earlier elements <= u, so i - j earlier elements exceed u.
        j = bisect(sorted_so_far, u)
        inversions += i - j
        sorted_so_far.insert(j, u)
    return inversions

def kendall_tau(ground_truth, predictions):
    """Kendall-tau style score aggregated over several orderings.

    Args:
        ground_truth: list of true orderings (sequences of hashable ids).
        predictions: list of predicted orderings over the same ids.

    Returns:
        ``1 - 4 * inversions / sum(n * (n - 1))``: 1.0 for a perfect prediction,
        -1.0 for a fully reversed one.

    Raises:
        ValueError: if a predicted id does not occur in its ground truth.
        ZeroDivisionError: if no ordering has at least two elements.
    """
    total_inversions = 0
    total_2max = 0  # twice the maximum possible inversions across all instances
    for gt, pred in zip(ground_truth, predictions):
        # First position of every id (same result as list.index, without the scan).
        position = {}
        for i, x in enumerate(gt):
            position.setdefault(x, i)
        try:
            ranks = [position[x] for x in pred]  # rank predicted order in terms of ground truth
        except KeyError as missing:
            raise ValueError(f"{missing.args[0]!r} is not in list") from None
        # Use the bisection-based counter instead of the quadratic reference
        # implementation; the leftover debug print of `ranks` is gone.
        total_inversions += count_inversions(ranks)
        n = len(gt)
        total_2max += n * (n - 1)
    return 1 - 4 * total_inversions / total_2max

def count_inversions_tensor(r):
    """Count pairs ``i < j`` with ``r[i] > r[j]`` for a 1-D tensor (returns a 0-dim tensor)."""
    as_column = r.unsqueeze(0).t()
    # Entry (i, j) is True when r[i] > r[j]; the upper triangle keeps i <= j.
    greater = as_column > r
    return greater.triu().sum()

def count_inversions_slowly(ranks):
    """Reference O(n^2) inversion count: pairs ``i < j`` with ``ranks[i] > ranks[j]``."""
    size = len(ranks)
    return sum(
        1
        for first in range(size)
        for second in range(first + 1, size)
        if ranks[first] > ranks[second]
    )


def kendall_tau_tensor(ground_truth, predictions):
    """Tensor version of the Kendall-tau style score.

    For each pair of 1-D id tensors, every predicted id is located inside the
    ground truth, inversions of those positions are counted with
    ``count_inversions_tensor`` and the result is normalised by ``n * (n - 1)``
    summed over all pairs.
    """
    inversions = 0
    pairs_x2 = 0  # twice the maximum possible inversions across all instances
    for gt, pred in zip(ground_truth, predictions):
        # Row k of `hits` is (index in pred, index in gt) of a matching id.
        hits = torch.nonzero(pred.unsqueeze(-1) == gt)
        inversions = inversions + count_inversions_tensor(hits[:, 1])
        size = len(gt)
        pairs_x2 = pairs_x2 + size * (size - 1)
    return 1 - 4 * inversions / pairs_x2

def soft_kendall_tau_tensor(ground_truth, predictions):
    """Differentiable surrogate built from soft ranks of the predictions.

    For every pair, the predictions are soft-ranked (``soft_rank`` with
    ``regularization_strength=1.``), reordered by the rank of the ground truth,
    and the signed pairwise differences in the upper triangle are summed. The
    total is scaled by ``4 / sum(n * (n - 1))``.

    NOTE(review): unlike the hard metrics this returns the scaled sum itself
    (not ``1 - ...``) — confirm the intended sign convention before using it
    as a loss.
    """
    total_inversions = 0
    total_2max = 0  # twice the maximum possible inversions across all instances
    for gt, pred in zip(ground_truth, predictions):
        rank_gt = gt.argsort().argsort()
        ranks = soft_rank(pred.unsqueeze(0),regularization_strength=1.)[0][rank_gt]-1
        # The per-item debug print of the running total was removed.
        total_inversions += (ranks.unsqueeze(0).t()-ranks).triu().sum()
        n = len(gt)
        total_2max += n * (n - 1)
    return 4 * total_inversions / total_2max

In [None]:
import torchmetrics
from typing import Any, List, Optional

class KendallTau(torchmetrics.Metric):
    """Streaming Kendall-tau style metric.

    Accumulates the number of inversions and twice the maximum possible number
    of inversions; ``compute`` returns ``1 - 4 * inversions / total_2max``.
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    # Set to True if the metric during 'update' requires access to the global metric
    # state for its calculations. If not, setting this to False indicates that all
    # batch states are independent and we will optimize the runtime of 'forward'
    full_state_update: bool = False
    inversions: torch.LongTensor
    total_2max: torch.LongTensor

    def __init__(self):
        super().__init__()
        # Both counters are summed across processes.
        for state_name in ("inversions", "total_2max"):
            self.add_state(state_name, default=torch.LongTensor([0]), dist_reduce_fx="sum")

    def update_slowly(self, preds: list, target: list):
        """Reference update working on lists of id tensors (quadratic counter)."""
        for gt, pred in zip(target, preds):
            positions = torch.nonzero(pred.unsqueeze(-1) == gt)[:, 1]
            self.inversions += count_inversions_slowly(positions)
            size = len(gt)
            self.total_2max += size * (size - 1)

    def update(self,preds,index_gt):
        """Batched update from model scores and ground-truth positions (-1 = padding)."""
        valid = index_gt != -1
        padding = ~valid
        # Padded slots get the largest score and a position beyond any real one,
        # so they sort last and never count as inversions.
        gt_positions = index_gt.masked_fill(padding, MAX_CELLS*2)
        scores = preds.masked_fill(padding, np.inf)
        predicted_order = scores.flatten(1).argsort(dim=1)
        ranks = torch.gather(gt_positions.flatten(1), 1, predicted_order)
        # Entry (b, i, j) is True when ranks[b, i] > ranks[b, j].
        out_of_order = ranks.unsqueeze(2) > ranks.unsqueeze(1)
        self.inversions += torch.triu(out_of_order).sum()
        cells = valid.sum(dim=(1,2))
        self.total_2max += (cells * (cells - 1)).sum()

    def compute(self):
        return  1. - 4. * self.inversions /  self.total_2max

In [None]:
# Loaders used for training: shuffled train batches (incomplete last batch
# dropped) and ordered validation batches, both with batch size args.bs.
train_dl = DataLoader(train_ds,args.bs, shuffle=True,  num_workers=args.workers, pin_memory=True, drop_last=True)
valid_dl = DataLoader(valid_ds,args.bs, shuffle=False, num_workers=args.workers, pin_memory=True)

In [None]:
#noexport
# Pull one training batch for interactive inspection.
example = next(iter(train_dl))
example

In [None]:
def preds_to_ids(preds,masks,notebook_ids):
    """Turn predicted scores into per-notebook lists of cell ids.

    Args:
        preds: score tensor, one row of code scores and one of markdown scores
            per notebook (same layout as the model output). Not modified.
        masks: boolean tensor of the same shape; False marks padded slots.
        notebook_ids: notebook id for each batch element.

    Returns:
        ``(predictions, ground_truth)``: lists of cell-id lists in predicted and
        true order. ``ground_truth`` only has entries for notebooks with a
        known order.

    NOTE(review): slicing with ``[:nc]`` / ``[:nm]`` assumes a notebook has at
    most as many code/markdown cells as there are slots — confirm for notebooks
    exceeding MAX_CELLS.
    """
    # Work on a copy: the in-place write used to overwrite the caller's tensor.
    preds=preds.clone()
    preds[~masks]=np.inf
    # Double argsort converts scores to ranks over code and markdown jointly.
    orders=preds.detach().flatten(1).argsort().argsort().view(*preds.shape).cpu().numpy()
    predictions,ground_truth=[],[]
    for notebook_id,order,_mask in zip(notebook_ids,orders,masks):
        # get_ids needs the cell table; calling it with one argument raised TypeError.
        code_ids,md_ids=get_ids(notebook_id,df)
        nc,nm=len(code_ids),len(md_ids)
        has_gt = df_orders is not None and notebook_id in df_orders
        gt = df_orders[notebook_id] if has_gt else None
        if gt is not None:
            assert (nc+nm) == len(gt)
        code_order,md_order=order[0],order[1]
        # Size by the cell count so notebooks without ground truth work too.
        pred=np.full((nc+nm),None,dtype=object)
        pred[code_order[:nc]] =  code_ids
        pred[md_order[:nm]] =  md_ids
        predictions.append(pred.tolist())
        if gt: ground_truth.append(gt)
    return predictions,ground_truth

In [None]:
#noexport
# Interactive smoke test: run an untrained model on the previewed training batch.
code_pos_pct, code_tokens, md_tokens,index_gt,notebook_ids = example
m=Orderer()
preds = m(code_pos_pct, code_tokens, md_tokens)
print(preds,index_gt[0])


In [None]:
#noexport
# Stand-alone metric instance for interactive experiments.
k=KendallTau()


In [None]:
def nannorm(t,**kw):
    """Euclidean norm that skips NaN entries; ``kw`` is forwarded to ``nansum``."""
    squared = t * t
    return squared.nansum(**kw).sqrt()

def spearman_rho(pred, target, **kw):
    """Correlation between the soft ranks of ``pred`` and ``target``.

    Both inputs are soft-ranked (``kw`` is forwarded to ``soft_rank``), centred,
    scaled to unit norm and multiplied element-wise; the sum is returned.
    """
    def _unit_centred(values):
        centred = values - values.mean()
        return centred / centred.norm()

    pred_ranks = soft_rank(pred, **kw)
    target_ranks = soft_rank(target, **kw)
    return (_unit_centred(pred_ranks) * _unit_centred(target_ranks)).sum()

def spearman_rho_loss(pred, target, **kw):
    """Loss ``1 - mean(spearman_rho)`` over the batch, ignoring padded targets.

    Args:
        pred: batch of score vectors.
        target: batch of target vectors; entries equal to ``-1`` are padding and
            are removed (together with the matching scores) before scoring.
        **kw: forwarded to ``spearman_rho``.

    NOTE(review): ``spearman_rho`` soft-ranks the target as well — confirm that
    ``soft_rank`` accepts the dtype of ``target`` used by the caller.
    """
    loss = 0.
    for p,t in zip(pred,target):
        padding = t == -1
        # The helper defined in this file is `spearman_rho`; the former call to
        # the undefined name `spearmanr_` raised NameError.
        loss += spearman_rho(p[~padding].unsqueeze(0), t[~padding].unsqueeze(0), **kw)
    return 1. - loss/pred.shape[0]

# For every n below 2*MAX_CELLS: sum over the upper triangle of the matrix of
# pairwise differences i - j of arange(n).
# NOTE(review): no use of this table is visible in this notebook — confirm it is still needed.
max_non_inversions = \
    [(torch.arange(n).unsqueeze(0).t()-torch.arange(n)).triu().sum().item() for n in range(2*MAX_CELLS)]

# https://en.wikipedia.org/wiki/Rank_correlation#Spearman%E2%80%99s_%CF%81_as_a_particular_case
def spearman_p_loss(pred, target_ranks, **kw):
    """Squared distance between soft ranks of ``pred`` and the target ranks.

    Each batch element is flattened; slots whose target is ``-1`` are skipped.
    The zero-based soft ranks of the remaining scores are compared with the
    target ranks and the squared differences are summed over the whole batch,
    then scaled by ``4 / (n * (n - 1))`` where ``n`` is the total number of
    valid slots in the batch. ``kw`` is forwarded to ``soft_rank``.
    """
    squared_distance = 0.
    valid_total = 0.
    for scores, ranks in zip(pred.flatten(1), target_ranks.flatten(1)):
        valid = ranks != -1
        soft = soft_rank(scores[valid].unsqueeze(0), **kw) - 1.
        wanted = ranks[valid].unsqueeze(0)
        squared_distance = squared_distance + (wanted - soft).square().sum()
        valid_total = valid_total + valid.sum()
    return 4. * squared_distance / (valid_total * (valid_total - 1))

In [None]:
class OrdererModule(pl.LightningModule):
    """Lightning wrapper around ``Orderer``: optimiser, loss and validation metric.

    Args:
        lr: learning rate used for every parameter group.
        adam_epsilon: epsilon passed to AdamW.
        weight_decay: decay applied to all parameters except biases and
            ``LayerNorm.weight``.
    """

    def __init__(self,
                 lr: float = 2e-5,
                 adam_epsilon: float = 1e-8,
                 weight_decay: float = 0.01,
                ):

        super().__init__()
        # Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
        self.save_hyperparameters()

        # Create model
        self.model = Orderer()
        # Create loss module
        self.loss_module = spearman_p_loss

        #self.mae = torchmetrics.MeanAbsoluteError()
        self.kendall_tau = KendallTau()

        self.batch_size = args.bs
        self.automatic_optimization = True

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        lr = self.hparams.lr
        weight_decay = self.hparams.weight_decay

        def _params(in_code_bert, decayed):
            # Parameters selected by (belongs to code_bert?, receives weight decay?).
            return [p for n, p in model.named_parameters()
                    if (('code_bert' in n) == in_code_bert)
                    and ((not any(nd in n for nd in no_decay)) == decayed)]

        # Group order is kept stable (pretrained decay / no-decay, then the rest)
        # so optimizer states saved in existing checkpoints still line up.
        optimizer_grouped_parameters = [
            {"params": _params(True, True), "weight_decay": weight_decay, "lr": lr},
            {"params": _params(True, False), "weight_decay": 0.0, "lr": lr},
            # The 1. factor is the hook for a separate learning rate on the new layers.
            {"params": _params(False, True), "weight_decay": weight_decay, "lr": 1. * lr},
            {"params": _params(False, False), "weight_decay": 0.0, "lr": 1. * lr},
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=self.hparams.adam_epsilon)

        # Warm up over the first 5% of the estimated optimisation steps.
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=0.05 * self.trainer.estimated_stepping_batches,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )

        print(optimizer,scheduler)
        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

    def forward(self, *imgs):
        # Forward function that is run when visualizing the graph
        return self.model(*imgs)

    def training_step(self, batch, batch_idx):
        # "batch" is the output of the training data loader.
        code_pos_pct, code_tokens, md_tokens,index_gt, notebook_ids = batch

        preds = self.model(code_pos_pct, code_tokens, md_tokens)
        loss = self.loss_module(preds, index_gt)

        self.log('train_loss', loss.detach(), on_step=True, on_epoch=True, sync_dist=True)
        return loss  # Return tensor to call ".backward" on

    def validation_step(self, batch, batch_idx):
        code_pos_pct, code_tokens, md_tokens, index_gt, notebook_ids = batch

        preds = self.model(code_pos_pct, code_tokens, md_tokens)
        loss = self.loss_module(preds, index_gt)

        # Accumulated over the epoch; reported in validation_epoch_end.
        self.kendall_tau.update(preds.detach(),index_gt)

        self.log('val_loss', loss.detach(), on_step=True, on_epoch=True, sync_dist=True)

    def validation_epoch_end(self, outputs):
        kendall_tau=self.kendall_tau.compute()
        print(kendall_tau)
        self.log('val_kendall_tau', kendall_tau,sync_dist=True,prog_bar=True)
        self.kendall_tau.reset()

    def test_step(self, batch, batch_idx):
        # Only runs the forward pass; nothing is logged or returned.
        code_pos_pct, code_tokens, md_tokens,index_gt,notebook_ids = batch

        preds = self.model(code_pos_pct, code_tokens, md_tokens)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the raw model scores for one batch."""
        code_pos_pct, code_tokens, md_tokens,index_gt,notebook_ids = batch

        # The former `mask = labels != -1` referenced an undefined name (NameError
        # on every call) and its result was never used, so it was removed.
        preds = self.model(code_pos_pct, code_tokens, md_tokens)

        return  preds

    def training_epoch_end(self, outputs):
        sch = self.lr_schedulers()

        # If the selected scheduler is a ReduceLROnPlateau scheduler.
        if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
            sch.step(self.trainer.callback_metrics["train_loss"])
        

In [None]:
class LitProgressBar(TQDMProgressBar):
    """Progress bar whose validation bar carries a custom description."""

    def init_validation_tqdm(self):
        validation_bar = super().init_validation_tqdm()
        validation_bar.set_description('running validation ...')
        return validation_bar

# Instance available for use as a Trainer callback.
bar = LitProgressBar()

In [None]:
from pytorch_lightning.profiler import AdvancedProfiler
# Profiler instance writing to 'advanced_profiler'; it only takes effect if passed to the Trainer.
ap = AdvancedProfiler(filename = 'advanced_profiler')

In [None]:
# --load without --resume: start a new run from the checkpoint's weights (with
# the requested learning rate). With --resume the checkpoint is instead handed
# to the Trainer, so a fresh module is created here.
if args.load and not args.resume:
    module = OrdererModule.load_from_checkpoint(args.load,lr=args.lr)
else:
    module = OrdererModule(lr=args.lr)

In [None]:
# Checkpoint name template: experiment and fold are filled in now, the
# brace-enclosed metric fields are left for the checkpoint callback.
checkpointfilename = f'{args.experiment}-{args.fold+1}of{args.kfolds}-' + \
    '{epoch}-{train_loss:.2f}-{val_loss:.2f}-{val_kendall_tau:.4f}'

# save_top_k=-1 keeps every checkpoint; val_kendall_tau (higher is better) is the monitored metric.
modelcheckpoint = ModelCheckpoint( save_weights_only=False, mode="max", monitor="val_kendall_tau", 
                                  save_last=True,save_top_k=-1,
                                  filename= checkpointfilename,verbose=True)

# Use the same template for the "last" checkpoint as for the regular ones.
modelcheckpoint.CHECKPOINT_NAME_LAST=checkpointfilename

In [None]:
# Trainer configuration. DDP is only used with more than one visible GPU;
# --resume continues from the checkpoint given with --load.
# NOTE(review): accelerator is fixed to 'gpu', so this presumably fails when no
# GPU is visible (n_gpus == 0) — confirm before running on CPU.
trainer = pl.Trainer(logger=None,
    precision=args.precision,
    accelerator = 'gpu',
    strategy=pl.strategies.DDPStrategy(find_unused_parameters=True) if n_gpus >1 else None,
    gpus = n_gpus,
    resume_from_checkpoint=args.load if args.resume else None,
    max_epochs= args.epochs, 
    deterministic = False, 
    callbacks=[modelcheckpoint,LearningRateMonitor("step"),],
    accumulate_grad_batches=args.acc_grad,
    )             

In [None]:
# Run training with validation on the held-out notebooks of this fold.
trainer.fit(module, train_dataloaders=train_dl, val_dataloaders=valid_dl)

In [None]:
#noexport
# Interactive-only: stop execution here so "Run All" does not continue past training.
sys.exit()