In [5]:
import os
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from pathlib import Path
from x_transformers import Decoder, Encoder

In [6]:
from transformers import (
    RobertaTokenizer,
    RobertaModel,
    RobertaConfig
    
)

In [7]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [8]:

# Every cell is truncated / padded to exactly this many tokens before CodeBERT.
MAX_TOKENS: int = 256
# Padding token id; hard-coded for the RoBERTa-family CodeBERT tokenizer loaded below.
# NOTE(review): keep in sync with tokenizer.pad_token_id if the checkpoint changes.
pad_token_id: int = 1

In [9]:
test_path = Path('../input/AI4Code/test')
# Bare file names (not full paths) of the test notebooks, sorted for a deterministic order.
test_nb = sorted(entry.name for entry in test_path.iterdir())

In [10]:
def read_notebook_inference(nb):
    """Load one test notebook JSON file into a DataFrame indexed by ``cell_id``.

    The file is read from the module-level ``test_path`` directory with ``cell_type``
    parsed as a category and ``source`` as str. An ``id`` column is added holding
    ``nb`` minus its last five characters (presumably the ``.json`` extension).
    """
    notebook_id = nb[:-5]
    frame = pd.read_json(test_path / nb, dtype={'cell_type': 'category', 'source': 'str'})
    frame = frame.assign(id=notebook_id)
    return frame.rename_axis('cell_id')

In [11]:
# Stack every test notebook into one frame indexed by (notebook id, cell id); within a
# notebook the rows keep the order in which its JSON file lists the cells.
notebooks = [read_notebook_inference(nb) for nb in test_nb]
df = pd.concat(notebooks).reset_index().set_index(['id', 'cell_id'])

# Per-notebook cell counts. The code count sums a boolean mask over *all* rows rather
# than filtering to code rows first, so a notebook without code cells gets 0 instead of
# disappearing from the code counts (which would turn its markdown count into NaN).
total_cells_by_notebook = df['cell_type'].groupby(level='id').count()
code_cells_by_notebook = (df['cell_type'] == 'code').groupby(level='id').sum()
md_cells_by_notebook = total_cells_by_notebook - code_cells_by_notebook

In [12]:
# Distinct notebook ids, in the order they first appear in df.
ids_test = df.index.get_level_values('id').unique()

In [13]:
def get_ids(notebook_id):
    """Return ``(code_cell_ids, markdown_cell_ids)`` for one notebook.

    Both are numpy arrays of cell ids taken from the module-level ``df``, in the order
    the notebook's rows appear there.
    """
    cells = df.loc[notebook_id]
    kinds = cells['cell_type']
    code_ids = cells.index[kinds == 'code'].values
    markdown_ids = cells.index[kinds == 'markdown'].values
    return code_ids, markdown_ids



In [14]:
class TokensDataset(Dataset):
    """Map-style dataset yielding the tokenized cells of one test notebook per item.

    Cell sources come from the module-level ``df``. Each item is
    ``(code_pos_pct, code_tokens, md_tokens, notebook_id)``:

    - ``code_tokens`` / ``md_tokens``: ``(n_cells, MAX_TOKENS)`` int64 token ids, where
      ``n_cells = max(#code cells, #markdown cells)``; missing rows are filled with
      ``pad_token_id``.
    - ``code_pos_pct``: ``(n_cells,)`` float; code cell ``i`` of ``n_c`` gets
      ``i / (n_c - 1)`` (0 when there is a single code cell), padding rows get ``-1``.

    Cells are selected with the same ``cell_type`` masks that ``get_ids`` uses, so row
    ``i`` of each token tensor lines up with the ``i``-th id ``get_ids`` returns.
    """

    def __init__(self, ids__):
        # CodeBERT tokenizer from the local Kaggle input directory.
        self.tokenizer = RobertaTokenizer.from_pretrained('../input/codebert-b')
        self.ids__ = ids__

    def __len__(self):
        return len(self.ids__)

    def _tokenize(self, cells):
        """Tokenize cell sources into an ``(len(cells), MAX_TOKENS)`` int64 tensor.

        An empty list short-circuits to an empty ``(0, MAX_TOKENS)`` tensor, because the
        tokenizer rejects an empty batch (it has no ``input_ids`` to pad).
        """
        if not cells:
            return torch.empty((0, MAX_TOKENS), dtype=torch.int64)
        return self.tokenizer(cells, padding='max_length', truncation=True,
                              max_length=MAX_TOKENS, return_tensors='pt')['input_ids']

    def __getitem__(self, idx):
        notebook_id = self.ids__[idx]
        data = df.loc[notebook_id]

        code_cells = data.loc[data['cell_type'] == 'code', 'source'].tolist()
        md_cells = data.loc[data['cell_type'] == 'markdown', 'source'].tolist()

        code_tokens_raw = self._tokenize(code_cells)
        md_tokens_raw = self._tokenize(md_cells)
        n_c, n_md = len(code_tokens_raw), len(md_tokens_raw)
        max_cells = max(n_c, n_md)

        # Give both cell axes the same length by appending all-padding rows.
        code_tokens = torch.full((max_cells, MAX_TOKENS), pad_token_id, dtype=torch.int64)
        md_tokens = torch.full((max_cells, MAX_TOKENS), pad_token_id, dtype=torch.int64)
        code_tokens[:n_c, :code_tokens_raw.shape[1]] = code_tokens_raw
        md_tokens[:n_md, :md_tokens_raw.shape[1]] = md_tokens_raw

        # Relative code positions in [0, 1]; a single code cell gives 0/0 = NaN, which
        # nan_to_num maps to 0. Padding rows stay at -1.
        code_pos_pct = torch.full((max_cells,), -1.)
        code_pos_pct[:n_c] = (torch.arange(n_c).float() / (n_c - 1)).nan_to_num(posinf=0)

        return code_pos_pct, code_tokens, md_tokens, notebook_id
    
        
            

In [15]:
# One notebook per batch, in ids_test order, loaded in the main process.
test_ds = TokensDataset(ids_test)
test_dl = DataLoader(dataset=test_ds, batch_size=1, shuffle=False, num_workers=0)

In [16]:
# Sanity check: display the first item, i.e. (code_pos_pct, code_tokens, md_tokens, notebook_id).
test_ds[0]

(tensor([0.0000, 0.1667, 0.3333, 0.5000, 0.6667, 0.8333, 1.0000]),
 tensor([[    0, 41975,   295,  ...,     1,     1,     1],
         [    0, 36807,  5457,  ...,     1,     1,     1],
         [    0,   282, 26229,  ...,     1,     1,     1],
         ...,
         [    0, 10431,  3037,  ...,  1297,    22,     2],
         [    0, 10431,  2741,  ...,     1,     1,     1],
         [    0, 10431,   155,  ...,     1,     1,     1]]),
 tensor([[    0, 10431,  2741,  ...,     1,     1,     1],
         [    0, 48342, 25980,  ...,     1,     1,     1],
         [    0, 48342, 39154,  ...,     1,     1,     1],
         ...,
         [    0, 10431, 20891,  ...,     1,     1,     1],
         [    0, 10431,  4913,  ...,     1,     1,     1],
         [    1,     1,     1,  ...,     1,     1,     1]]),
 '0009d135ece78d')

In [17]:
# Pretrained CodeBERT weights from the local Kaggle input. Orderer's constructor reads
# this module-level object, so this cell must run before the model is built.
code_bert = RobertaModel.from_pretrained('../input/codebert-b')

In [18]:
import math
def positionalencoding1d(positions_pct, d_model, scale=50., base=10000.0):
    """Sinusoidal encoding of fractional (relative) positions.

    Each position ``p`` is stretched to ``p * scale`` and encoded with the Transformer
    sin/cos scheme: feature ``2i`` is ``sin(p * scale / base**(2i / d_model))`` and
    feature ``2i + 1`` is the matching ``cos``.

    :param positions_pct: float tensor of relative positions, any shape (the model passes
        ``(batch, n_cells)``). Padding entries are encoded like any other value.
    :param d_model: encoding dimension; must be even.
    :param scale: factor applied to the positions before encoding; with positions in
        ``[0, 1]`` the default spreads them over ``[0, 50]``.
    :param base: wavelength base of the sinusoids.
    :return: tensor of shape ``(*positions_pct.shape, d_model)`` on the same device and
        with the same dtype as ``positions_pct``.
    :raises ValueError: if ``d_model`` is odd.
    """
    if d_model % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))
    device, dtype = positions_pct.device, positions_pct.dtype
    # Frequencies 1 / base**(2i / d_model), computed in log space.
    div_term = torch.exp(torch.arange(0, d_model, 2, device=device, dtype=dtype)
                         * -(math.log(base) / d_model))
    # A trailing singleton dim broadcasts every position against every frequency,
    # so inputs of any rank work.
    angles = (positions_pct * scale).unsqueeze(-1) * div_term
    pe = torch.empty((*positions_pct.shape, d_model), device=device, dtype=dtype)
    pe[..., 0::2] = torch.sin(angles)
    pe[..., 1::2] = torch.cos(angles)
    return pe

In [19]:
class Orderer(nn.Module):
    """Scores markdown cells against the known order of a notebook's code cells.

    Every cell (code and markdown) is embedded by the shared ``code_bert`` model, keeping
    the hidden state of each cell's first token. Code-cell embeddings get a sinusoidal
    encoding of their relative position and go through ``self.encoder``; markdown
    embeddings go through ``self.decoder``, which cross-attends to the encoded code cells.
    ``md_regressor`` turns each markdown embedding into a score in (-0.2, 1.2).

    NOTE(review): the constructor reads the module-level ``code_bert`` object, so the cell
    that loads it must run before ``Orderer()`` is instantiated. Attribute names must stay
    as they are so the checkpoint's state_dict keys keep matching.
    """
    def __init__(self):
        super().__init__()
                
        # Both widths are 768, so codebert_proj below is nn.Identity; the bias-free
        # Linear is only used if the encoder width is changed.
        encoder_dim = codebert_dim = 768 
        
        self.codebert_proj = nn.Linear(codebert_dim,encoder_dim,bias=False) \
            if codebert_dim!=encoder_dim else nn.Identity()
        
        # Self-attention stack over the sequence of code-cell embeddings.
        # NOTE(review): attn_num_mem_kv is x-transformers' learned memory key/value
        # option — confirm against the installed x_transformers version.
        self.encoder = Encoder(
            dim = encoder_dim,
            depth = 6,
            heads = 8,
            attn_num_mem_kv = 16,
            rotary_pos_emb = False,
            #ff_glu = True,
        )
        
        # Stack over markdown-cell embeddings; cross_attend=True lets it attend to the
        # encoded code cells, which forward() passes as ``context``.
        self.decoder = Encoder(
            dim = encoder_dim,
            depth = 8,
            heads = 8,
            attn_num_mem_kv = 16,
            rotary_pos_emb = False,
            cross_attend = True,
            #ff_glu = True,
        )

        # Disabled code-cell regressor (kept commented out).
        #self.code_regressor = nn.Sequential(nn.Linear(encoder_dim, encoder_dim, bias = True), nn.GLU(),
        #                                    nn.Linear(encoder_dim//2, 2, bias = True), nn.GLU())
        # Linear to 2 features followed by GLU (which halves the last dim): one scalar
        # logit per markdown cell.
        self.md_regressor   = nn.Sequential(nn.Linear(encoder_dim, 2, bias = True), nn.GLU())
        
        # Shared CodeBERT used for both code and markdown cells (module-level global).
        self.code_bert = code_bert
        
    def forward(self, code_pos_pct, code_tokens, md_tokens):
        """Predict positions for a batch of notebooks.

        Args:
            code_pos_pct: (bs, n_cells) relative positions of the code cells; padding
                rows hold -1 in the tensors built by TokensDataset.
            code_tokens: (bs, n_cells, n_tokens) token ids of the code cells; padding
                cells are filled with ``pad_token_id``.
            md_tokens: (bs, n_cells, n_tokens) token ids of the markdown cells, same
                layout (the torch.cat / torch.stack below require equal n_cells).

        Returns:
            Tuple ``(encoded_all, masks)``:
            - encoded_all: (bs, 2, n_cells); row 0 is ``code_pos_pct`` and row 1 the
              predicted markdown scores, both multiplied by the notebook's number of real
              cells (nc + nm). Values in padding slots are meaningless.
            - masks: (bs, 2, n_cells) bool; True marks real (non-padding) code cells
              (row 0) and markdown cells (row 1).
        """
        bs = code_tokens.shape[0]
        device = code_tokens.device
        
        # Token-level attention mask: ignore padding tokens inside each cell.
        code_mask = code_tokens != pad_token_id
        # Run all cells as one flat batch through CodeBERT, restore the
        # (bs, n_cells, n_tokens, hidden) layout and keep each cell's first-token state.
        x_code = self.code_bert(code_tokens.flatten(0,1),
                                attention_mask=code_mask.flatten(0,1).to(device))\
                .last_hidden_state.view(*code_tokens.shape[:3],-1)[:,:,0,:]
        x_code = self.codebert_proj(x_code)
        
        # Same embedding step for the markdown cells.
        md_mask = md_tokens != pad_token_id
        x_md = self.code_bert(md_tokens.flatten(0,1),
                                attention_mask=md_mask.flatten(0,1).to(device))\
                .last_hidden_state.view(*md_tokens.shape[:3],-1)[:,:,0,:]
        x_md = self.codebert_proj(x_md)
        
        # Cell-level masks: a cell is real unless every one of its tokens is padding.
        m_code = ~(code_tokens == pad_token_id).all(dim=-1)
        m_md   = ~(md_tokens == pad_token_id).all(dim=-1)
                
        # Number of real code / markdown cells per notebook, shape (bs,).
        nc = m_code.sum(dim=1)
        nm = m_md.sum(dim=1)
        
        #print(code_pos_pct)
                
        # Inject the known code order (in-place add).
        x_code += positionalencoding1d(code_pos_pct,x_code.shape[-1])
        
        x_code = self.encoder(x_code, mask = m_code)
        # Markdown cells cross-attend to the encoded code cells; the real-cell masks are
        # passed for both the markdown sequence and the code context.
        x_md   = self.decoder(x_md, context = x_code, mask = m_md, context_mask = m_code)
                
        # Stack code positions and markdown predictions, then scale both by the real
        # cell count. -0.2 + 1.4 * sigmoid maps markdown scores into (-0.2, 1.2),
        # presumably so markdown can land before the first / after the last code cell.
        encoded_all = (nc+nm).view(bs,1,1) * torch.cat((
            code_pos_pct.unsqueeze(1),
         #   self.code_regressor(x_code).sigmoid().view(bs,1,-1),
            -0.2 + 1.4 * self.md_regressor(x_md).sigmoid().view(bs,1,-1)
        ),1)
                
        return encoded_all, torch.stack((m_code,m_md),1)

In [20]:
def preds_to_ids(preds, masks, notebook_ids, id_lookup=None):
    """Turn model outputs into space-separated predicted cell orders.

    All real (unmasked) code and markdown cells of a notebook are sorted together by
    their predicted position; padding slots are ignored. Ties are broken
    deterministically: code cells before markdown cells, then by slot order. NaN scores
    are placed last instead of corrupting the ranking.

    :param preds: (bs, 2, n_cells) positions; row 0 code cells, row 1 markdown cells.
        The tensor is not modified.
    :param masks: (bs, 2, n_cells) bool, True for real cells, same layout as ``preds``.
    :param notebook_ids: sequence of ``bs`` notebook ids aligned with the batch axis.
    :param id_lookup: callable mapping a notebook id to ``(code_ids, md_ids)`` arrays in
        the same order as the model's cell slots; defaults to ``get_ids``.
    :return: dict ``{notebook_id: "cell_id cell_id ..."}`` in predicted order.
    :raises ValueError: if a notebook's number of real slots differs from its number of ids.
    """
    if id_lookup is None:
        id_lookup = get_ids
    # Read into a float32 numpy copy: handles CUDA / half-precision outputs and leaves
    # the caller's tensor untouched (the previous version wrote inf into it).
    scores = preds.detach().float().cpu().numpy()
    valid = masks.detach().cpu().numpy().astype(bool)
    predictions = {}
    for notebook_id, nb_scores, nb_valid in zip(notebook_ids, scores, valid):
        code_ids, md_ids = id_lookup(notebook_id)
        cell_ids = np.concatenate((code_ids, md_ids))
        cell_scores = np.concatenate((nb_scores[0][nb_valid[0]], nb_scores[1][nb_valid[1]]))
        if len(cell_scores) != len(cell_ids):
            raise ValueError(
                f"notebook {notebook_id}: {len(cell_scores)} scored cells but "
                f"{len(cell_ids)} cell ids")
        # Stable sort over real cells only: always a permutation of cell_ids, ties keep
        # code-before-markdown order, NaN scores sort to the end.
        order = np.argsort(cell_scores, kind='stable')
        predictions[notebook_id] = ' '.join(cell_ids[order].tolist())
    return predictions

In [21]:
# Build the model and load trained weights. Every checkpoint key carries a 6-character
# prefix (presumably 'model.' from a training wrapper — confirm) that is sliced off so
# the keys match Orderer's own parameter names.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = Orderer()
model.load_state_dict({
    key[6:]: weight
    for key, weight in torch.load(ModelPath, map_location='cpu')['state_dict'].items()
})

<All keys matched successfully>

In [22]:
# Predict a cell order for every test notebook; d maps notebook id -> "cell ids in order".
d = {}
model = model.to(device)
model.eval()
with torch.no_grad():
    for code_pos_pct, code_tokens, md_tokens, notebook_id in test_dl:
        code_pos_pct, code_tokens, md_tokens = (
            batch_tensor.to(device) for batch_tensor in (code_pos_pct, code_tokens, md_tokens)
        )
        output, masking = model(code_pos_pct, code_tokens, md_tokens)
        # notebook_id holds the batch's ids as collated by the DataLoader.
        d.update(preds_to_ids(output, masking, notebook_id))

In [23]:
# Submission: one row per notebook with its id and space-separated predicted cell order.
final = (
    pd.DataFrame.from_dict(d, orient='index')
    .reset_index()
    .rename(columns={'index': 'id', 0: 'cell_order'})
)
final.to_csv('submission.csv', index=False)
final

Unnamed: 0,id,cell_order
0,0009d135ece78d,0a226b6a ddfd239c 8cb8d28a c6cd22db 1372ae9b e...
1,0010483c12ba9b,7f270e34 54c7cab3 fe66203e 7844d5f8 5ce8863c 4...
2,0010a919d60e4f,23607d04 b7578789 bbff12d4 aafc3d23 89b1fdd2 8...
3,0028856e09c5b7,eb293dfc 012c9d02 d22526d1 3ae7ece3
