In [None]:
# Checkpoints produced by the training runs (read from the ai4code-train-ds input dir).
GRAPH_MODEL_NAME = (
    "../input/ai4code-train-ds/graph-model-100k-ncs2.bin"
)
UNIX_MODEL_NAME = (
    "../input/ai4code-train-ds/model-epoch1.5.bin"
)
# Blend weight for the graph model's predictions — presumably vs. UniXcoder; TODO confirm.
GRAPH_WEIGHT = 0.4
# Maximum tokenized sequence length passed to the encoders' tokenizers.
MODEL_MAX_LEN = 512

In [None]:
# common.py:

from dataclasses import dataclass
import time
from pathlib import Path
from tqdm import tqdm
import os
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
import numpy as np  # linear algebra
import torch
import math

# Wider console rendering so notebook cell sources are readable when printed.
pd.set_option("display.width", 180)
pd.set_option("display.max_colwidth", 120)


def is_interactive_mode():
    """Return True unless KAGGLE_KERNEL_RUN_TYPE names a non-interactive run.

    When the variable is absent (e.g. outside Kaggle) the session counts as
    interactive.
    """
    run_type = os.environ.get('KAGGLE_KERNEL_RUN_TYPE', 'Interactive')
    return run_type == 'Interactive'


def read_notebook(path):
    """Load one notebook JSON file into a DataFrame indexed by ``cell_id``.

    Parameters
    ----------
    path : pathlib.Path
        Notebook file; its stem (file name without extension) becomes the
        value of the added ``id`` column.

    Returns
    -------
    pandas.DataFrame
        The JSON columns (``cell_type`` as category, ``source`` as str) plus
        ``id``; the index axis is named ``cell_id``.
    """
    frame = pd.read_json(
        path,
        dtype={'cell_type': 'category', 'source': 'str'},
    )
    frame = frame.assign(id=path.stem)
    return frame.rename_axis('cell_id')


def save_model(model, suffix):
    """Save the weights of ``model.encoder.model`` to ``./model-<suffix>.bin``.

    Only that inner module's state dict is written, not the wrapping modules.
    """
    target_path = os.path.join(Path("."), f'model-{suffix}.bin')
    torch.save(model.encoder.model.state_dict(), target_path)
    print(f"Saved model to {target_path}")


def save_roberta_model(model, suffix):
    """Save ``model``'s full state dict to ``./roberta-model-<suffix>.bin``."""
    target_path = os.path.join(Path("."), f'roberta-model-{suffix}.bin')
    torch.save(model.state_dict(), target_path)
    print(f"Saved model to {target_path}")


def get_code_cells(nb):
    """Index labels of the rows of ``nb`` whose ``cell_type`` is ``'code'``."""
    is_code = nb['cell_type'] == 'code'
    return nb.loc[is_code].index


def get_markdown_cells(nb):
    """Index labels of the rows of ``nb`` whose ``cell_type`` is ``'markdown'``."""
    is_markdown = nb['cell_type'] == 'markdown'
    return nb.loc[is_markdown].index


def split_into_batches(lst, batch_size):
    """Split ``lst`` into the fewest near-equal chunks of at most ``batch_size``.

    ``np.array_split`` spreads elements evenly, so chunk sizes differ by at
    most one and none exceeds ``batch_size``.

    Parameters
    ----------
    lst : sequence
        Items to batch (each chunk comes back as a numpy array).
    batch_size : int
        Maximum number of items per chunk; must be positive.

    Returns
    -------
    list of numpy.ndarray
        Chunks in input order; an empty list when ``lst`` is empty.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got {}'.format(batch_size))
    if len(lst) == 0:
        # np.array_split rejects a section count of 0, which an empty input
        # would otherwise produce.
        return []
    num_chunks = (len(lst) + batch_size - 1) // batch_size  # ceil division
    return list(np.array_split(lst, num_chunks))


def sim(emb1, emb2):
    """Dot product of two 1-D tensors, detached and returned as a 0-d numpy array."""
    dot = torch.einsum("i,i->", emb1, emb2)
    return dot.detach().numpy()


def get_probs_by_embeddings(embeddings, m_cell_id, code_cell_ids, coef_mul):
    """Softmax over similarities between one markdown cell and each code cell.

    Parameters
    ----------
    embeddings : mapping
        Cell id -> 1-D embedding tensor.
    m_cell_id : hashable
        Key of the markdown cell.
    code_cell_ids : iterable
        Keys of the candidate code cells; the output follows this order.
    coef_mul : float
        Multiplier applied to similarity differences (inverse temperature):
        larger values sharpen the distribution.

    Returns
    -------
    list of float
        One probability per entry of ``code_cell_ids``, normalised to sum to 1.
    """
    anchor = embeddings[m_cell_id]
    scores = [sim(anchor, embeddings[cid]) for cid in code_cell_ids]
    # Shift by the maximum before exponentiating so exp() cannot overflow.
    top = max(scores)
    weights = [math.exp((s - top) * coef_mul) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]


def get_best_pos_by_probs(probs):
    """Position minimising the expected absolute distance under ``probs``.

    For each candidate ``j`` the score is ``sum_i |i - j| * probs[i]``; the
    first position with the smallest score is returned. O(n^2).
    """
    expected_dist = [
        sum(abs(i - j) * p for i, p in enumerate(probs))
        for j in range(len(probs))
    ]
    return expected_dist.index(min(expected_dist))


@dataclass
class OneCell:
    """A notebook cell paired with a numeric score (presumably its predicted position — TODO confirm)."""

    score: float
    cell_id: str
    cell_type: str


# Sentinel id/text for the virtual slot after the last code cell.
end_token = 'END'

In [None]:
# config.py:

from dataclasses import dataclass
from pathlib import Path
import os


@dataclass
class Config:
    """Environment-specific locations and batch sizes for the pipeline."""

    data_dir: Path  # holds train/, test/, train_orders.csv, train_ancestors.csv
    unixcoder_model_path: str  # UniXcoder checkpoint directory / model name
    wandb_key: str  # looks like a path to a W&B key file ('' on Kaggle) — TODO confirm
    batch_size: int
    batch_size_graph2: int
    cosine_minibatch_size: int
    cosine_batch_size: int
    # Deliberately unannotated: a plain class attribute, not a dataclass
    # field, so it is not a constructor argument.
    use_simple_ensemble_model = True


def get_local_config():
    """Config for the local machine of user 'borys' (tiny batches)."""
    settings = dict(
        data_dir=Path('/home/borys/AI4Code/input/AI4Code'),
        unixcoder_model_path='/home/borys/AI4Code/input/unixcoderbase',
        wandb_key='/home/borys/wandb_key',
        batch_size=2,
        batch_size_graph2=2,
        cosine_minibatch_size=2,
        cosine_batch_size=4,
    )
    return Config(**settings)


def get_jarvis_config():
    """Config for the "jarvis" training machine (large batches)."""
    settings = dict(
        data_dir=Path('/home/input/AI4Code'),
        unixcoder_model_path='/home/unixcoderbase',
        wandb_key='/home/wandb_key',
        batch_size=60,
        batch_size_graph2=30,
        cosine_minibatch_size=8,
        cosine_batch_size=60,
    )
    return Config(**settings)

def get_kaggle_config():
    """Config for a Kaggle kernel: inputs under ../input, no W&B key."""
    settings = dict(
        data_dir=Path('../input/AI4Code'),
        unixcoder_model_path='../input/unixcoderbase',
        wandb_key='',
        batch_size=30,
        batch_size_graph2=30,
        cosine_minibatch_size=8,
        cosine_batch_size=30,
    )
    return Config(**settings)



def get_default_config():
    """Choose the Config that matches the current machine.

    Checks, in order:

    1. ``LOGNAME == 'borys'`` -> ``get_local_config()``.
    2. ``KAGGLE_KERNEL_RUN_TYPE`` present in the environment (the variable
       ``is_interactive_mode`` reads on Kaggle) -> ``get_kaggle_config()``.
    3. otherwise -> ``get_jarvis_config()``.
    """
    if os.getenv('LOGNAME') == 'borys':
        print('Get local config')
        return get_local_config()
    # Without this branch a Kaggle session received the jarvis config, whose
    # absolute /home/... paths are not where Kaggle inputs live (../input).
    if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
        print('Get kaggle config')
        return get_kaggle_config()
    print('Get jarvis config')
    return get_jarvis_config()


In [None]:
from dataclasses import dataclass
import pandas as pd
from tqdm import tqdm
import torch


@dataclass
class State:
    """Shared configuration and data for the AI4Code pipeline.

    Construction reads ``train_orders.csv`` and ``train_ancestors.csv`` from
    ``config.data_dir`` and derives a seeded 90/10 train/validation split of
    notebook ids. Notebook contents are loaded on demand by the ``load_*``
    methods.
    """

    # Annotations describe what the loaders assign. ``@dataclass`` only adds
    # __repr__/__eq__ here because __init__ is hand-written; attributes such as
    # test_df / cur_train_nbs exist only after their loader has run.
    df_orders: pd.Series  # notebook id -> list of cell ids in the correct order
    test_df: pd.DataFrame  # test-notebook cells, MultiIndex (id, cell_id)
    df_ancestors: pd.DataFrame  # train_ancestors.csv indexed by notebook id
    all_train_nb: pd.Series  # notebook ids in the training split
    all_validate_nb: pd.Series  # notebook ids in the validation split
    cur_train_nbs: pd.DataFrame  # loaded train-notebook cells, MultiIndex (id, cell_id)
    config: 'Config'
    device: torch.device

    def __init__(self, config: 'Config'):
        self.config = config
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.load_df_orders()
        self.load_df_ancestors()

    def load_df_orders(self):
        """Read train_orders.csv into a Series: notebook id -> ordered cell ids."""
        self.df_orders = pd.read_csv(
            self.config.data_dir / 'train_orders.csv',
            index_col='id',
        ).squeeze("columns").str.split()  # Split the string representation of cell_ids into a list

    def load_test_nbs(self):
        """Load every test/*.json notebook into ``test_df``, sorted by notebook id."""
        paths_test = list((self.config.data_dir / 'test').glob('*.json'))
        notebooks_test = [
            read_notebook(path) for path in tqdm(paths_test, desc='Test NBs')
        ]
        self.test_df = (
            pd.concat(notebooks_test)
            .set_index('id', append=True)
            .swaplevel()
            .sort_index(level='id', sort_remaining=False)
        )

    @staticmethod
    def _select_good_notebooks(df_ancestors):
        """Count notebooks per ancestor and keep those alone in their ancestor group.

        Parameters
        ----------
        df_ancestors : pandas.DataFrame
            Indexed by notebook id, with ``ancestor_id`` and ``parent_id`` columns.

        Returns
        -------
        (pandas.Series, pandas.Series)
            Group size per ``ancestor_id``, and the selected notebook ids in file
            order with a fresh RangeIndex. Rows with a missing ``ancestor_id``
            are never selected (``value_counts`` drops NaN).
        """
        group_sizes = df_ancestors['ancestor_id'].value_counts()
        is_alone = df_ancestors['ancestor_id'].map(group_sizes).eq(1).to_numpy()
        # NOTE(review): the previous row-wise filter also required
        # `row['parent_id'] != None`. That is True for every value read_csv
        # produces (missing parents load as NaN, and NaN != None), so it never
        # excluded anything and is omitted here. Switching to pd.notna(...)
        # would change the seeded train/validation split that existing
        # checkpoints were trained against — decide that deliberately.
        good_ids = df_ancestors.index[is_alone].tolist()
        return group_sizes, pd.Series(good_ids)

    def load_df_ancestors(self):
        """Read train_ancestors.csv and build the train/validation id split.

        Only notebooks that are the sole member of their ancestor group are
        kept (presumably so related forks cannot appear in both splits). The
        90/10 split uses a fixed seed, so it is reproducible.
        """
        self.df_ancestors = pd.read_csv(
            self.config.data_dir / 'train_ancestors.csv', index_col='id')

        group_sizes, good_notebooks = self._select_good_notebooks(self.df_ancestors)
        print('only one:', group_sizes[group_sizes == 1].count())
        # Diagnostic histogram of ancestor-group sizes (rendered by the notebook).
        group_sizes.plot.hist(grid=True, bins=20, rwidth=0.9,
                              color='#607c8e')
        print('good notebooks', len(good_notebooks))

        self.all_train_nb = good_notebooks.sample(
            frac=0.9, random_state=787788)
        self.all_validate_nb = good_notebooks.drop(self.all_train_nb.index)

    def load_train_nbs_helper(self, ids):
        """Load train/<id>.json for each id into ``cur_train_nbs`` (sorted by id).

        Raises ValueError (from pd.concat) when ``ids`` is empty.
        """
        paths_train = [self.config.data_dir / 'train' /
                       '{}.json'.format(id) for id in ids]
        notebooks_train = [
            read_notebook(path) for path in tqdm(paths_train, desc='Train NBs')
        ]
        self.cur_train_nbs = pd.concat(notebooks_train).set_index(
            'id', append=True).swaplevel().sort_index(level='id', sort_remaining=False)

    def load_train_nbs(self, num: int):
        """Load the first ``num`` notebooks of the training split."""
        self.load_train_nbs_helper(self.all_train_nb.head(num))

    def load_train_nbs_range(self, from_: int, to_: int):
        """Load training-split notebooks at positions ``from_`` .. ``to_ - 1``."""
        self.load_train_nbs_helper(self.all_train_nb[from_:to_])

    def load_train_nbs_tail(self, num: int):
        """Load the last ``num`` notebooks of the training split."""
        self.load_train_nbs_helper(self.all_train_nb.tail(num))



In [None]:
# metric.py:

from dataclasses import dataclass
from bisect import bisect


# Actually O(N^2), but fast in practice for our data
def count_inversions(a):
    """Number of index pairs i < j with a[i] > a[j] (equal values do not count).

    Maintains the sorted prefix seen so far; for each new value, the elements
    to the right of its bisect_right position are exactly the earlier, strictly
    greater ones. list.insert makes the worst case O(N^2), which is fast enough
    for notebook-sized inputs.
    """
    seen = []
    total = 0
    for value in a:
        pos = bisect(seen, value)
        total += len(seen) - pos
        seen.insert(pos, value)
    return total


@dataclass
class Score:
    """Kendall-tau statistics that can be accumulated across notebooks.

    ``total_pairs`` sums n * (n - 1) per notebook, which is twice the number of
    pairs; hence the factor 4 in ``1 - 4 * inversions / total_pairs``. With no
    pairs the score is defined as 0.0.
    """

    cur_score: float
    total_inversions: int
    total_pairs: int

    def __init__(self, total_inversions, total_pairs):
        self.total_inversions = total_inversions
        self.total_pairs = total_pairs
        self.cur_score = (
            0.0 if total_pairs == 0
            else 1 - 4 * total_inversions / total_pairs
        )

    def merge(a, b):
        """Combine two scores by summing their raw counts (``a`` plays ``self``)."""
        return Score(
            total_inversions=a.total_inversions + b.total_inversions,
            total_pairs=a.total_pairs + b.total_pairs,
        )


def kendall_tau_typed(ground_truth, predictions):
    """Kendall-tau style score of predicted cell orders against the true orders.

    Parameters
    ----------
    ground_truth : iterable of sequences
        Correct order of (hashable) cell ids for each notebook.
    predictions : iterable of sequences
        Predicted order for each notebook, a permutation of the matching
        ground-truth sequence. Extra items in the longer iterable are ignored
        (``zip`` semantics).

    Returns
    -------
    Score
        Inversions and n * (n - 1) summed over all notebooks.

    Raises
    ------
    ValueError
        If a predicted id does not occur in its ground-truth order.
    """
    total_inversions = 0  # total inversions in predicted ranks across all instances
    total_2max = 0  # sum of n*(n-1), i.e. twice the number of pairs
    for gt, pred in zip(ground_truth, predictions):
        # Map each id to its first position once (O(n)) rather than calling
        # gt.index(x) per element (O(n^2)); setdefault keeps list.index's
        # first-occurrence semantics.
        position = {}
        for i, cell_id in enumerate(gt):
            position.setdefault(cell_id, i)
        try:
            # rank predicted order in terms of ground truth
            ranks = [position[x] for x in pred]
        except KeyError as err:
            raise ValueError(
                '{!r} is not in the ground-truth order'.format(err.args[0])) from None
        total_inversions += count_inversions(ranks)
        n = len(gt)
        total_2max += n * (n - 1)
    return Score(total_inversions=total_inversions, total_pairs=total_2max)


def kendall_tau(ground_truth, predictions):
    """List form of ``kendall_tau_typed``: ``[score, inversions, total_pairs]``."""
    result = kendall_tau_typed(ground_truth, predictions)
    return [
        result.cur_score,
        result.total_inversions,
        result.total_pairs,
    ]


def sum_scores(a, b):
    """Merge two ``[score, inversions, total_pairs]`` lists (the kendall_tau format).

    Returns a new list in the same format. When the merged pair count is 0 the
    score is 0.0, the same convention ``Score`` uses, instead of raising
    ZeroDivisionError.
    """
    total_inversions = a[1] + b[1]
    total_2max = a[2] + b[2]
    if total_2max == 0:
        return [0.0, total_inversions, total_2max]
    return [1 - 4 * total_inversions / total_2max, total_inversions, total_2max]


def calc_nb_score(my_order, correct_order):
    """Score one notebook's predicted cell order against its correct order."""
    return kendall_tau_typed(ground_truth=[correct_order],
                             predictions=[my_order])


In [None]:
# unixcoder.py:

# Copied from: https://github.com/microsoft/CodeBERT/blob/master/UniXcoder/unixcoder.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import numpy as np
import torch
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig
import os
from pathlib import Path


class UniXcoder(nn.Module):
    def __init__(self, model_name, state_dict=None):
        """
            Build UniXcoder.
            Parameters:
            * `model_name`- huggingface model card name. e.g. microsoft/unixcoder-base
            * `state_dict`- optional path to a RobertaModel state dict written by
              torch.save; it is loaded onto the CPU and copied into `self.model`.
        """
        super(UniXcoder, self).__init__()
        # NOTE(review): `use_fast` is an AutoTokenizer option; RobertaTokenizer is
        # the slow tokenizer class, so this flag is most likely ignored — confirm.
        self.tokenizer = RobertaTokenizer.from_pretrained(
            model_name, use_fast=True)
        self.config = RobertaConfig.from_pretrained(model_name)
        # Upstream setting; presumably what makes the model return the
        # past_key_values that generate() reuses as decoding context.
        self.config.is_decoder = True
        self.model = RobertaModel.from_pretrained(
            model_name, config=self.config)

        if state_dict is not None:
            # map_location='cpu': the freshly built model lives on the CPU at this
            # point, and it lets GPU-saved checkpoints load on CPU-only machines.
            self.model.load_state_dict(
                torch.load(state_dict, map_location='cpu'))

        # Lower-triangular (causal) mask for up to 1024 positions, used by generate().
        self.register_buffer("bias", torch.tril(torch.ones(
            (1024, 1024), dtype=torch.uint8)).view(1, 1024, 1024))
        # LM head tied to the input word-embedding matrix.
        self.lm_head = nn.Linear(
            self.config.hidden_size, self.config.vocab_size, bias=False)
        self.lm_head.weight = self.model.embeddings.word_embeddings.weight
        self.lsm = nn.LogSoftmax(dim=-1)

        self.tokenizer.add_tokens(["<mask0>"], special_tokens=True)

    def tokenize(self, inputs, mode="<encoder-only>", max_length=512, padding=False):
        """
        Convert strings to token ids.

        Parameters:
        * `inputs`- iterable of input strings.
        * `max_length`- maximum total length per sequence, special tokens included.
        * `padding`- if True, right-pad every sequence with `pad_token_id` to the
          longest sequence in this call (not to `max_length`).
        * `mode`- one of <encoder-only>, <decoder-only>, <encoder-decoder>.

        Returns a list of lists of token ids. Encoder modes keep the head of the
        text; <decoder-only> keeps the tail.
        """
        assert mode in ["<encoder-only>",
                        "<decoder-only>", "<encoder-decoder>"]

        tokenizer = self.tokenizer

        tokens_ids = []
        for x in inputs:
            tokens = tokenizer.tokenize(x)
            if mode == "<encoder-only>":
                # 4 special tokens: cls, mode, sep ... sep
                tokens = tokens[:max_length-4]
                tokens = [tokenizer.cls_token, mode,
                          tokenizer.sep_token] + tokens + [tokenizer.sep_token]
            elif mode == "<decoder-only>":
                tokens = tokens[-(max_length-3):]
                tokens = [tokenizer.cls_token, mode,
                          tokenizer.sep_token] + tokens
            else:
                tokens = tokens[:max_length-5]
                tokens = [tokenizer.cls_token, mode,
                          tokenizer.sep_token] + tokens + [tokenizer.sep_token]

            tokens_id = tokenizer.convert_tokens_to_ids(tokens)
            tokens_ids.append(tokens_id)

        if padding:
            cur_max_length = len(max(tokens_ids, key=len))
            tokens_ids = list(map(
                lambda l: l + [self.config.pad_token_id] * (cur_max_length-len(l)), tokens_ids))
        return tokens_ids

    def decode(self, source_ids):
        """ Convert token ids (as produced by generate) to strings; ids after the first 0 are dropped """
        predictions = []
        for x in source_ids:
            prediction = []
            for y in x:
                t = y.cpu().numpy()
                t = list(t)
                if 0 in t:
                    t = t[:t.index(0)]
                text = self.tokenizer.decode(
                    t, clean_up_tokenization_spaces=False)
                prediction.append(text)
            predictions.append(prediction)
        return predictions

    def forward(self, source_ids):
        """ Return (token_embeddings, sentence_embeddings).

        Attention is bidirectional among non-pad tokens; the sentence embedding
        is the mean of token embeddings over non-pad positions.
        """
        mask = source_ids.ne(self.config.pad_token_id)
        token_embeddings = self.model(
            source_ids, attention_mask=mask.unsqueeze(1) * mask.unsqueeze(2))[0]
        sentence_embeddings = (
            token_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1)
        return token_embeddings, sentence_embeddings

    def generate(self, source_ids, decoder_only=True, eos_id=None, beam_size=5, max_length=64):
        """ Generate sequences given context (source_ids) with beam search.

        Returns a tensor of shape (batch, beam_size, max_length), right-padded with 0.
        """

        # Set encoder mask attention matrix: bidirectional for <encoder-decoder>, unidirectional for <decoder-only>
        if decoder_only:
            mask = self.bias[:, :source_ids.size(-1), :source_ids.size(-1)]
        else:
            mask = source_ids.ne(self.config.pad_token_id)
            mask = mask.unsqueeze(1) * mask.unsqueeze(2)

        if eos_id is None:
            eos_id = self.config.eos_token_id

        device = source_ids.device

        # Decoding using beam search
        preds = []
        zero = torch.LongTensor(1).fill_(0).to(device)
        # NOTE(review): counts non-pad tokens assuming pad id 1 (RoBERTa's default) — confirm.
        source_len = list(source_ids.ne(1).sum(-1).cpu().numpy())
        length = source_ids.size(-1)
        encoder_output = self.model(source_ids, attention_mask=mask)
        for i in range(source_ids.shape[0]):
            # Replicate this example's cached keys/values once per beam.
            context = [[x[i:i+1, :, :source_len[i]].repeat(beam_size, 1, 1, 1) for x in y]
                       for y in encoder_output.past_key_values]
            beam = Beam(beam_size, eos_id, device)
            input_ids = beam.getCurrentState().clone()
            context_ids = source_ids[i:i+1,
                                     :source_len[i]].repeat(beam_size, 1)
            out = encoder_output.last_hidden_state[i:i +
                                                   1, :source_len[i]].repeat(beam_size, 1, 1)
            for _ in range(max_length):
                if beam.done():
                    break
                if _ == 0:
                    # First step predicts from the last context position.
                    hidden_states = out[:, -1, :]
                    out = self.lsm(self.lm_head(hidden_states)).data
                    beam.advance(out)
                    input_ids.data.copy_(input_ids.data.index_select(
                        0, beam.getCurrentOrigin()))
                    input_ids = beam.getCurrentState().clone()
                else:
                    length = context_ids.size(-1)+input_ids.size(-1)
                    out = self.model(input_ids, attention_mask=self.bias[:, context_ids.size(-1):length, :length],
                                     past_key_values=context).last_hidden_state
                    hidden_states = out[:, -1, :]
                    out = self.lsm(self.lm_head(hidden_states)).data
                    beam.advance(out)
                    input_ids.data.copy_(input_ids.data.index_select(
                        0, beam.getCurrentOrigin()))
                    input_ids = torch.cat(
                        (input_ids, beam.getCurrentState().clone()), -1)
            hyp = beam.getHyp(beam.getFinal())
            pred = beam.buildTargetTokens(hyp)[:beam_size]
            pred = [torch.cat([x.view(-1) for x in p]+[zero]
                              * (max_length-len(p))).view(1, -1) for p in pred]
            preds.append(torch.cat(pred, 0).unsqueeze(0))

        preds = torch.cat(preds, 0)

        return preds


class Beam(object):
    """Beam-search bookkeeping for a single example (used by UniXcoder.generate).

    Keeps, per time step, the chosen token ids (`nextYs`) and the index of the
    beam each came from (`prevKs`), plus finished hypotheses that hit EOS.
    """

    def __init__(self, size, eos, device):
        self.size = size
        self.device = device
        # The score for each translation on the beam.
        self.scores = torch.FloatTensor(size).zero_().to(device)
        # The backpointers at each time-step.
        self.prevKs = []
        # The outputs at each time-step.
        self.nextYs = [torch.LongTensor(size).fill_(0).to(device)]
        # Has EOS topped the beam yet.
        self._eos = eos
        self.eosTop = False
        # (score, timestep, beam index) triples for finished hypotheses.
        self.finished = []

    def getCurrentState(self):
        "Get the outputs for the current timestep, shaped (size, 1)."
        batch = self.nextYs[-1].view(-1, 1)
        return batch

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        """
        Advance the beam by one step.
        Parameters:
        * `wordLk`- scores of every next word for each beam (K x words); generate()
          passes log-softmax outputs. On the first step only row 0 is used.
        Returns: None. Appends the new backpointers/tokens, records hypotheses that
        emitted EOS in `finished`, and sets `eosTop` when beam 0 ends in EOS.
        """
        numWords = wordLk.size(1)

        # Sum the previous scores.
        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)

            # Don't let EOS have children.
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
        else:
            beamLk = wordLk[0]
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)

        self.scores = bestScores

        # bestScoresId is flattened beam x word array, so calculate which
        # word and beam each score came from
        prevK = bestScoresId // numWords
        self.prevKs.append(prevK)
        self.nextYs.append((bestScoresId - prevK * numWords))

        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                s = self.scores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))

        # End condition is when top-of-beam is EOS and no global score.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        """True once beam 0 has emitted EOS and at least `size` hypotheses finished."""
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        """Return up to `size` (score, timestep, beam index) triples, best first.

        If fewer than `size` hypotheses finished, the best unfinished beams at the
        last step are appended.
        """
        if len(self.finished) == 0:
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
        self.finished.sort(key=lambda a: -a[0])
        if len(self.finished) != self.size:
            unfinished = []
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    s = self.scores[i]
                    unfinished.append((s, len(self.nextYs) - 1, i))
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[:self.size-len(self.finished)]
        return self.finished[:self.size]

    def getHyp(self, beam_res):
        """
        Walk back to construct the full hypothesis.
        Follows `prevKs` backpointers from each (score, timestep, k) triple and
        returns the token tensors of each hypothesis in forward order.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                hyp.append(self.nextYs[j+1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        """Truncate each hypothesis at its first EOS token (EOS excluded)."""
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence


# Partially copied from: https://github.com/microsoft/CodeBERT/blob/567dd49a4b916835f93fb95709de714b8772fea2/UniXcoder/downstream-tasks/code-search/model.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


class Model(nn.Module):
    """Wrap an encoder and return its L2-normalised sentence embeddings.

    The encoder must return an indexable result whose element 1 is a
    (batch, hidden) embedding, as UniXcoder.forward does.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, inputs):
        sentence_embeddings = self.encoder(inputs)[1]
        # Unit-length rows so dot products are cosine similarities.
        return torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)


def reload_model(state: State, state_dict):
    """Build a UniXcoder from ``state.config.unixcoder_model_path`` on ``state.device``.

    ``state_dict`` is an optional weights path forwarded to UniXcoder.
    """
    # nn.Module.to returns the module itself, so this is the same object.
    return UniXcoder(
        model_name=state.config.unixcoder_model_path,
        state_dict=state_dict,
    ).to(state.device)


def get_text_tokens(state: State, model, text):
    """Tokenize one string for UniXcoder (encoder-only) as a (1, length) tensor.

    Truncates with ``MODEL_MAX_LEN``, the same limit get_unix_nb_embeddings
    uses, instead of a separately hard-coded 512 that could silently drift.
    """
    tokens = model.tokenize(
        [text], max_length=MODEL_MAX_LEN, mode="<encoder-only>")
    return torch.tensor(tokens).to(state.device)


def get_text_embedding(state: State, model, text):
    """L2-normalised UniXcoder sentence embedding of one string, as a CPU 1-D tensor."""
    _, sentence = model(get_text_tokens(state, model, text))
    unit = torch.nn.functional.normalize(sentence, p=2, dim=1)
    return unit.cpu()[0]


@torch.no_grad()
def get_unix_nb_embeddings(state: State, model, nb):
    """Embed every cell of one notebook with UniXcoder.

    Cells are sorted by source length and encoded in slices of at most
    ``state.config.batch_size``, so each padded batch holds similar lengths.

    Parameters
    ----------
    state : State
        Supplies ``config.batch_size`` and ``device``.
    model : UniXcoder
    nb : pandas.DataFrame
        One notebook's cells, indexed by cell id, with a ``source`` column.

    Returns
    -------
    dict
        Cell id -> L2-normalised CPU embedding, plus an ``end_token`` entry for
        the end-of-notebook sentinel. An empty notebook yields only that entry.
    """
    res = {}

    batch_size = state.config.batch_size
    by_length = nb.sort_values(by="source", key=lambda x: x.str.len())
    # Explicit fixed-size slices guarantee no batch exceeds batch_size (a float
    # section count given to np.array_split is truncated, producing oversized
    # batches) and an empty notebook simply produces no batches.
    for start in range(0, len(by_length), batch_size):
        chunk = by_length.iloc[start:start + batch_size]
        texts = chunk['source'].to_numpy()

        tokens = model.tokenize(texts, max_length=MODEL_MAX_LEN,
                                mode="<encoder-only>", padding=True)
        source_ids = torch.tensor(tokens).to(state.device)
        _, embeddings = model(source_ids)
        normalized = torch.nn.functional.normalize(
            embeddings, p=2, dim=1).cpu()

        for key, val in zip(chunk.index, normalized):
            res[key] = val

    res[end_token] = get_text_embedding(state, model, end_token)

    return res


class EnsembleModel(nn.Module):
    """UniXcoder sentence embedding + 6 extra features -> softmax over 2 classes.

    ``state_dict`` is an optional path to a checkpoint of the whole
    EnsembleModel (encoder included); ``name`` records that path ("" if none).
    """

    def __init__(self, state, state_dict=None):
        super().__init__()
        # Encoder weights, if any, come from the full checkpoint loaded below,
        # so the encoder itself is built without a state dict.
        self.encoder = reload_model(state, state_dict=None)
        self.top = nn.Linear(768 + 6, 2)
        self.softmax = nn.Softmax(dim=1)
        self.name = "" if state_dict is None else state_dict
        if state_dict is not None:
            checkpoint = torch.load(state_dict, map_location=state.device)
            self.load_state_dict(checkpoint)
        self.to(state.device)

    def forward(self, inputs, additional_features, device):
        sentence = self.encoder(inputs)[1]
        features = torch.cat((sentence, additional_features), 1).to(device)
        return self.softmax(self.top(features))

    def save(self, suffix):
        """Write the state dict to ./ensemble-model-<suffix>.bin."""
        output_path = os.path.join(Path("."), f'ensemble-model-{suffix}.bin')
        torch.save(self.state_dict(), output_path)
        print(f"Saved model to {output_path}")


In [None]:
# graph_model.py:
import math
import torch.nn.functional as F
import os
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoModel, AutoTokenizer
from torch.optim import AdamW
import wandb
from tqdm import tqdm
from dataclasses import dataclass

import torch

from dataclasses import dataclass
import itertools

from torch.optim.lr_scheduler import CosineAnnealingLR


@dataclass
class Sample:
    """A (markdown, code) text pair fed to MyGraphModel.encode_sample.

    NOTE: the ensembles.py cell later defines another class named ``Sample``,
    which rebinds this global name for code that looks it up at call time.
    """

    markdown: str  # markdown cell source
    code: str  # code text paired with it


@dataclass
class SampleWithLabel:
    """A Sample plus its binary target, as consumed by ``train`` (BCE loss)."""

    sample: Sample
    label: float  # 1.0 / 0.0 target for the pair


@dataclass
class TwoSamples:
    """A markdown cell with one matching and one non-matching code cell (train2 input)."""

    markdown: str
    correct_code: str
    wrong_code: str

    def max_len(self):
        """Character length of the longer (markdown + code) combination."""
        longest_code = max(len(self.correct_code), len(self.wrong_code))
        return len(self.markdown) + longest_code


# Token budget for one (markdown, code) pair: MyGraphModel.encode_sample truncates
# each side to max_tokenizer_len // 2 - 3 tokens.
max_tokenizer_len = 256


class MyGraphModel(nn.Module):
    """GraphCodeBERT encoder with dropout and a 768 -> 1 linear head.

    ``forward`` returns either one scalar score per sequence (via ``top``) or
    the L2-normalised first-token embedding. ``next_code_cells`` and
    ``coef_mul`` are stored alongside the weights by ``save``, and restored
    from such checkpoints.
    """

    def __init__(self, state: State, preload_state=None, next_code_cells=1, coef_mul=25):
        super(MyGraphModel, self).__init__()
        self.graph = AutoModel.from_pretrained(
            '../input/graphcodebert-base/graphcodebert-base')
        self.tokenizer = AutoTokenizer.from_pretrained(
            '../input/graphcodebert-base/graphcodebert-base')
        self.top = nn.Linear(768, 1)
        self.dropout = nn.Dropout(0.2)
        self.next_code_cells = next_code_cells
        self.coef_mul = coef_mul
        if preload_state is not None:
            print('Preloading state:', preload_state)
            # Load onto the CPU: the parameters are still there at this point
            # (they move with self.to(...) below), and without map_location a
            # checkpoint saved from a GPU cannot be opened on a CPU-only machine.
            cur_state = torch.load(preload_state, map_location='cpu')
            if 'state_dict' in cur_state:
                # Format written by save(): weights plus hyper-parameters.
                self.load_state_dict(cur_state['state_dict'])
                self.next_code_cells = cur_state['next_code_cells']
                self.coef_mul = cur_state['coef_mul']
            else:
                # Older format: a bare state dict.
                self.load_state_dict(cur_state)
        self.name = preload_state if preload_state is not None else "0"
        self.name += ";ncs=" + str(self.next_code_cells)
        self.to(state.device)

    def forward(self, input_ids, attention_mask, use_sigmoid, return_scalar):
        """Encode a padded batch.

        Parameters
        ----------
        input_ids, attention_mask : torch.LongTensor
            (batch, seq_len) tensors, e.g. from ``encode`` / ``encode_texts``.
        use_sigmoid : bool
            Apply a sigmoid to the output.
        return_scalar : bool
            True: (batch, 1) scores from the linear head. False: (batch, 768)
            L2-normalised first-token embeddings.
        """
        x = self.graph(input_ids=input_ids, attention_mask=attention_mask)[0]
        # Representation of the first token (the cls token prepended by encode*).
        x = x[:, 0, :]
        x = self.dropout(x)  # only active while the module is in train() mode
        if return_scalar:
            x = self.top(x)
        else:
            x = torch.nn.functional.normalize(x, p=2, dim=1)
        if use_sigmoid:
            x = torch.sigmoid(x)
        return x

    def encode_sample(self, sample):
        """Token ids for ``[cls] markdown [sep] code [sep]``.

        Each side is truncated to ``max_tokenizer_len // 2 - 3`` tokens.
        Returns a dict with ``input_ids`` and an all-ones ``attention_mask``.
        """
        max_length = (max_tokenizer_len // 2 - 3)
        code_tokens = self.tokenizer.tokenize(
            sample.code, max_length=max_length, truncation=True)
        markdown_tokens = self.tokenizer.tokenize(
            sample.markdown, max_length=max_length, truncation=True)
        tokens = [self.tokenizer.cls_token] + markdown_tokens + \
            [self.tokenizer.sep_token] + code_tokens + [self.tokenizer.sep_token]
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        return {'input_ids': token_ids, 'attention_mask': [1] * len(token_ids)}

    def _pad_batch(self, state: State, input_ids, attention_mask):
        """Right-pad the id/mask lists in place to a common length; return device tensors."""
        max_len = max(len(ids) for ids in input_ids)
        for ids, mask in zip(input_ids, attention_mask):
            more = max_len - len(ids)
            ids += [self.tokenizer.pad_token_id] * more
            mask += [0] * more

        return {'input_ids': torch.LongTensor(input_ids).to(state.device),
                'attention_mask': torch.LongTensor(attention_mask).to(state.device)
                }

    def encode(self, state: State, samples):
        """Encode (markdown, code) samples into a padded batch on ``state.device``."""
        input_ids = []
        attention_mask = []

        for sample in samples:
            encoded = self.encode_sample(sample)
            input_ids.append(encoded['input_ids'])
            attention_mask.append(encoded['attention_mask'])

        return self._pad_batch(state, input_ids, attention_mask)

    def encode_texts(self, state: State, texts):
        """Encode single texts as ``[cls] text [sep]`` (at most MODEL_MAX_LEN tokens), padded."""
        input_ids = []
        attention_mask = []

        for text in texts:
            tokens = [self.tokenizer.cls_token_id] + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(
                text, max_length=MODEL_MAX_LEN-2, truncation=True)) + [self.tokenizer.sep_token_id]
            input_ids.append(tokens)
            attention_mask.append([1] * len(tokens))

        return self._pad_batch(state, input_ids, attention_mask)

    @torch.no_grad()
    def predict(self, state: State, samples, use_sigmoid):
        """Score (markdown, code) samples in batches; returns one float per sample."""
        result = []
        batches = split_into_batches(samples, state.config.batch_size)
        for batch in batches:
            encoded = self.encode(state, batch)
            # forward() has no default for return_scalar; one score per row is
            # needed here, i.e. the linear head.
            pred = self(encoded['input_ids'],
                        encoded['attention_mask'], use_sigmoid,
                        return_scalar=True)
            result += [x[0].item() for x in pred]
        return result

    def save(self, suffix, optimizer=None):
        """Write weights + hyper-parameters to ./graph-model-<suffix>.bin.

        If ``optimizer`` is given, its state goes to ./graph-model-<suffix>.opt.bin.
        """
        output_dir = Path(".")
        output_path = os.path.join(
            output_dir, 'graph-model-{}.bin'.format(suffix))
        torch.save({'state_dict': self.state_dict(), 'next_code_cells': self.next_code_cells, 'coef_mul': self.coef_mul}, output_path)
        print("Saved model to {}".format(output_path))
        if optimizer is not None:
            output_path = os.path.join(
                output_dir, 'graph-model-{}.opt.bin'.format(suffix))
            torch.save(optimizer.state_dict(), output_path)


def train(state, model, dataset, save_to_wandb=False, optimizer_state=None):
    """Fine-tune ``model`` as a binary classifier over (markdown, code) pairs.

    Parameters
    ----------
    state : State
        Supplies ``device``.
    model : MyGraphModel
    dataset : sized iterable of batches
        Each batch is a list of SampleWithLabel-like items (``.sample``, ``.label``).
    save_to_wandb : bool
        Log the loss to W&B. NOTE(review): relies on ``init_wandb``, which is not
        defined in this notebook section — confirm it exists before enabling.
    optimizer_state : str, optional
        Path to an AdamW state dict to resume from.

    A checkpoint is written every 1000 batches and at the end.
    """
    print('start training...')
    np.random.seed(123)
    learning_rate = 5e-5
    optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-8)
    if optimizer_state is not None:
        print('loading optimizer state...')
        optimizer.load_state_dict(
            torch.load(optimizer_state, map_location=state.device))
    scheduler = CosineAnnealingLR(optimizer, T_max=len(dataset))
    model.train()
    print('training... num batches:', len(dataset))
    if save_to_wandb:
        init_wandb(name='graph-training')

    criterion = torch.nn.BCELoss()
    for b_id, batch in enumerate(tqdm(dataset)):
        samples = list(map(lambda x: x.sample, batch))
        encoded = model.encode(state, samples)
        input_ids = encoded['input_ids']
        attention_mask = encoded['attention_mask']
        target = list(map(lambda x: [x.label], batch))
        target = torch.FloatTensor(target).to(state.device)

        optimizer.zero_grad()
        # forward() requires both flags. BCELoss needs one probability per pair,
        # i.e. the scalar head followed by a sigmoid -> shape (batch, 1) like target.
        pred = model(input_ids, attention_mask,
                     use_sigmoid=True, return_scalar=True)

        loss = criterion(pred, target)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()
        if save_to_wandb:
            wandb.log({'graph_loss': loss.item()})

        if (b_id % 1000 == 999):
            print('Saving model after', b_id)
            model.save('step-batch-' + str(b_id), optimizer=optimizer)

    if save_to_wandb:
        wandb.finish()

    model.save('cur-final', optimizer=optimizer)


def train2(state, model, dataset, save_to_wandb=False, optimizer_state=None,
           sample_cls=Sample):
    """Fine-tune ``model`` with a pairwise loss over TwoSamples batches (mixed precision).

    For each item the model scores (markdown, correct_code) and
    (markdown, wrong_code); the loss is the softmax probability (logits scaled
    by 10) given to the wrong pair, averaged over the batch.

    Parameters
    ----------
    state : State
        Supplies ``device``.
    model : MyGraphModel
    dataset : sized iterable of batches of TwoSamples
    save_to_wandb : bool
        Log the loss to W&B. NOTE(review): relies on ``init_wandb``, which is not
        defined in this notebook section — confirm it exists before enabling.
    optimizer_state : str, optional
        Path to an AdamW state dict to resume from.
    sample_cls : type
        Class used to build (markdown, code) pairs. It is bound when this
        function is defined because the ensembles.py cell later rebinds the
        global name ``Sample`` to an unrelated dataclass, which would otherwise
        make the construction below fail at call time.

    A checkpoint is written every 5000 batches and at the end.
    """
    print('start training...')
    np.random.seed(123)
    learning_rate = 3e-5
    optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-8)
    if optimizer_state is not None:
        print('loading optimizer state...')
        optimizer.load_state_dict(
            torch.load(optimizer_state, map_location=state.device))
    model.train()
    print('training... num batches:', len(dataset))
    if save_to_wandb:
        init_wandb(name='graph2-training')

    scaler = torch.cuda.amp.GradScaler()
    accumulation_steps = 1
    scheduler = CosineAnnealingLR(
        optimizer, T_max=len(dataset) // accumulation_steps)

    for b_id, batch in enumerate(tqdm(dataset)):
        # Rows 2i and 2i+1 hold the correct and the wrong pair of item i.
        samples = list(map(lambda x: [sample_cls(markdown=x.markdown, code=x.correct_code), sample_cls(
            markdown=x.markdown, code=x.wrong_code)], batch))
        samples = list(itertools.chain(*samples))
        encoded = model.encode(state, samples)
        input_ids = encoded['input_ids']
        attention_mask = encoded['attention_mask']

        with torch.cuda.amp.autocast():
            # forward() requires return_scalar: one score per sequence is needed.
            pred = model(input_ids, attention_mask,
                         use_sigmoid=False, return_scalar=True)

            losses = []
            for i in range(len(batch)):
                sm = F.softmax(pred[i*2:i*2+2] * 10.0, dim=0)
                losses.append(sm[1])

            total_loss = sum(losses) / len(batch)
        # Average over accumulated batches (a no-op when accumulation_steps == 1).
        scaler.scale(total_loss / accumulation_steps).backward()
        if (b_id + 1) % accumulation_steps == 0:
            # Gradients are still multiplied by the loss scale; unscale them so
            # clipping compares the true gradient norm against 1.0.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

        if save_to_wandb:
            wandb.log({'graph2_loss': total_loss.item()})

        if (b_id % 5000 == 4999):
            print('Saving model after', b_id)
            model.save('2-step-batch-' + str(b_id), optimizer=optimizer)

    if save_to_wandb:
        wandb.finish()

    model.save('2-cur-final', optimizer=optimizer)


@dataclass
class Embedding:
    """A piece of text waiting to be embedded, tagged with the cell it belongs to."""

    cell_id: str  # notebook cell id (or the end-of-notebook token)
    text: str     # text fed to the encoder for this cell

@torch.no_grad()
def get_graph_nb_embeddings(state: State, model, nb):
    """Embed every cell of ``nb``, plus the end-of-notebook token, with the graph model.

    The text for a code cell is its own source followed by the sources of the
    next code cells in ``code_cells[idx + 1 : idx + model.next_code_cells]``
    (the end token counts as a code cell). Each is appended after
    ``model.tokenizer.sep_token``. Non-code cells are embedded from their own
    source only.

    Args:
        state: runtime state; ``state.config.batch_size`` sets how many texts go
            into one forward pass.
        model: graph model exposing ``next_code_cells``, ``tokenizer.sep_token``,
            ``encode_texts(state, texts)`` and a forward pass that returns one
            embedding row per input text.
        nb: one notebook, indexed by cell id, with ``cell_type``/``source`` columns.

    Returns:
        dict mapping each cell id (and ``end_token``) to its embedding, moved to CPU.
    """
    def get_code(cell_id):
        if cell_id == end_token:
            return end_token
        return nb.loc[cell_id]['source']

    code_cells = get_code_cells(nb).tolist()
    code_cells.append(end_token)
    # Position of each code cell (first occurrence, matching list.index) so the
    # per-cell lookup below is O(1) instead of scanning the list twice.
    code_pos = {}
    for pos, code_id in enumerate(code_cells):
        code_pos.setdefault(code_id, pos)

    to_convert = [Embedding(cell_id=end_token, text=end_token)]
    for cell_id in nb.index:
        text = get_code(cell_id)
        idx = code_pos.get(cell_id)
        if idx is not None:
            for next_id in code_cells[idx + 1:idx + model.next_code_cells]:
                text += model.tokenizer.sep_token + get_code(next_id)
        to_convert.append(Embedding(cell_id=cell_id, text=text))
    # Sorting by length groups similarly sized texts into the same batch
    # (presumably to reduce padding in encode_texts).
    to_convert.sort(key=lambda x: len(x.text))

    result = {}
    for batch in split_into_batches(to_convert, state.config.batch_size):
        all_texts = [item.text for item in batch]
        encoded = model.encode_texts(state, all_texts)
        embeddings = model(
            input_ids=encoded['input_ids'], attention_mask=encoded['attention_mask'],
            use_sigmoid=False, return_scalar=False)
        for i in range(len(batch)):
            result[batch[i].cell_id] = embeddings[i].cpu()
    return result


In [None]:
# ensembles.py:

from dataclasses import dataclass
import torch

@dataclass
class Sample:
    """Per-markdown-cell features used to train/evaluate the ensemble blender.

    NOTE(review): this re-uses the name ``Sample`` already used by the training
    loop earlier in the notebook, which constructs ``Sample(markdown=...,
    code=...)``. After this cell runs, that earlier class is shadowed and the
    training loop can no longer build its samples; consider renaming one of them.
    """

    text: str               # markdown cell source
    md_cell_id: str         # id of the markdown cell
    graph3_pos: float       # position picked from the graph model's probabilities
    unix_pos: float         # position picked from the UniXcoder model's probabilities
    total_cells: int        # md_cells + code_cells
    md_cells: int           # count of non-code cells in the notebook
    code_cells: int         # count of code cells in the notebook
    part_code_cells: float  # code_cells / total_cells
    target_pos: float       # NOTE(review): gen_nb_samples stores the best blend coefficient here, not a position


def _next_code_cell_index(correct_order, m_cell_id, code_cell_ids, code_cell_id_set):
    """Index in ``code_cell_ids`` of the first code cell after ``m_cell_id`` in ``correct_order`` ('END' if none)."""
    start = correct_order.index(m_cell_id) + 1
    next_code_cell = next(
        (cell_id for cell_id in correct_order[start:] if cell_id in code_cell_id_set),
        'END')
    return code_cell_ids.index(next_code_cell)


def _best_blend_coef(graph_sims_probs, unix_sims_probs, target_score, options=20):
    """Mean graph weight in {0, 1/options, ..., 1} whose blended prediction lands closest to ``target_score``.

    Returns None when the blended prediction at weight 0 equals the one at
    weight 1, i.e. the sample is treated as uninformative about the blend.
    """
    # Infinite sentinel: the first option always initializes the best. A finite
    # sentinel leaves cnt_best_coefs at 0 when every diff exceeds it.
    best_diff = float('inf')
    sum_best_coefs = 0.0
    cnt_best_coefs = 0
    all_possible_positions = []
    for o in range(options + 1):
        coef = o / options
        sim_probs = [g * coef + u * (1 - coef) for g, u in zip(graph_sims_probs, unix_sims_probs)]
        pos = get_best_pos_by_probs(sim_probs)
        all_possible_positions.append(pos)
        diff = abs(pos - target_score)
        if diff < best_diff:
            best_diff = diff
            sum_best_coefs = 0.0
            cnt_best_coefs = 0
        if diff == best_diff:
            sum_best_coefs += coef
            cnt_best_coefs += 1
    # NOTE(review): only the two endpoints are compared, so a sample whose
    # intermediate weights move the prediction while the endpoints agree is skipped.
    if all_possible_positions[0] == all_possible_positions[-1]:
        return None
    return sum_best_coefs / cnt_best_coefs


def gen_nb_samples(nb, graph3_embeddings, unix_embeddings, correct_order,
                   graph_coef_mul=25.0, unix_coef_mul=1000.0):
    """Build one ``Sample`` per non-code cell of ``nb`` for the ensemble blender.

    For each markdown cell, both models' embeddings are turned into probability
    distributions over the code cells (plus ``'END'``) via
    ``get_probs_by_embeddings``. ``get_best_pos_by_probs`` then picks a position
    from each distribution.

    Args:
        nb: one notebook indexed by ``cell_id`` with ``cell_type``/``source`` columns.
        graph3_embeddings / unix_embeddings: dicts mapping cell id (and 'END')
            to an embedding.
        correct_order: list of cell ids in the true order, or None at inference time.
        graph_coef_mul / unix_coef_mul: temperatures passed to
            ``get_probs_by_embeddings`` for each model.

    Returns:
        list of ``Sample``. With ``correct_order`` given, ``target_pos`` holds the
        best blend coefficient, and cells whose predicted position at weight 0
        equals the one at weight 1 are omitted. Without it, every non-code cell is
        returned with ``target_pos=0``.
    """
    code_df = nb[nb['cell_type'] == 'code'].reset_index(level='cell_id')
    markdown_df = nb[nb['cell_type'] != 'code'].reset_index(level='cell_id')

    code_cell_ids = code_df['cell_id'].values.tolist()
    code_cell_ids.append('END')
    code_cell_id_set = set(code_cell_ids)

    md_cells = len(markdown_df)
    code_cells = len(code_df)
    total_cells = md_cells + code_cells
    # Guard the empty notebook (0 cells) against ZeroDivisionError.
    part_code_cells = code_cells / total_cells if total_cells else 0.0

    samples = []
    for m_cell_id in markdown_df['cell_id'].values:
        text = nb.loc[m_cell_id]['source']
        graph_sims_probs = get_probs_by_embeddings(
            graph3_embeddings, m_cell_id, code_cell_ids, graph_coef_mul)
        unix_sims_probs = get_probs_by_embeddings(
            unix_embeddings, m_cell_id, code_cell_ids, unix_coef_mul)

        graph3_pos = get_best_pos_by_probs(graph_sims_probs)
        unix_pos = get_best_pos_by_probs(unix_sims_probs)

        best_coef = 0
        if correct_order is not None:
            target_score = _next_code_cell_index(
                correct_order, m_cell_id, code_cell_ids, code_cell_id_set)
            best_coef = _best_blend_coef(graph_sims_probs, unix_sims_probs, target_score)
            if best_coef is None:
                continue

        samples.append(Sample(md_cell_id=m_cell_id, text=text, graph3_pos=graph3_pos,
                              unix_pos=unix_pos, total_cells=total_cells, md_cells=md_cells,
                              code_cells=code_cells, part_code_cells=part_code_cells,
                              target_pos=best_coef))

    return samples


@torch.no_grad()
def gen_samples(state: State, nb, graph3_model: MyGraphModel, unixcoder_model, correct_order):
    """Embed ``nb`` with both models and turn the embeddings into ensemble samples.

    Graph embeddings are computed first, then UniXcoder embeddings; both are
    handed to ``gen_nb_samples`` together with ``correct_order`` (None at inference).
    """
    return gen_nb_samples(
        nb,
        get_graph_nb_embeddings(state, graph3_model, nb),
        get_unix_nb_embeddings(state, unixcoder_model, nb),
        correct_order,
    )


def predict(state: State, ensemble_model, samples):
    """Blend the graph and UniXcoder positions with per-sample coefficients from ``ensemble_model``.

    Args:
        state: runtime state; ``state.device`` and
            ``state.config.use_simple_ensemble_model`` are read.
        ensemble_model: in simple mode, called on the feature matrix alone.
            Otherwise it is called on (markdown token ids, features, device) and
            must expose ``encoder.tokenize``.
        samples: non-empty sequence of ``Sample`` (``torch.stack`` rejects an empty list).

    Returns:
        dict with ``'coefs'`` (the model output, same shape as the
        ``(len(samples), 2)`` position matrix) and ``'preds'`` (their row-wise
        dot product).
    """
    texts = [sample.text for sample in samples]

    # Six numeric features per sample: the two position estimates plus notebook-size statistics.
    feature_rows = [
        torch.FloatTensor([sample.graph3_pos, sample.unix_pos, sample.total_cells,
                           sample.md_cells, sample.code_cells, sample.part_code_cells])
        for sample in samples
    ]
    additional_features = torch.stack(feature_rows).to(state.device)

    # The two position estimates that the predicted coefficients weight.
    position_rows = [torch.FloatTensor([sample.graph3_pos, sample.unix_pos]) for sample in samples]
    to_mul = torch.stack(position_rows).to(state.device)

    if not state.config.use_simple_ensemble_model:
        token_ids = ensemble_model.encoder.tokenize(
            texts, max_length=512, mode="<encoder-only>", padding=True)
        token_ids = torch.tensor(token_ids).to(state.device)
        coefs = ensemble_model(token_ids, additional_features, state.device)
    else:
        coefs = ensemble_model(additional_features)

    # Row-wise dot product: preds[a] = sum_b coefs[a, b] * to_mul[a, b].
    preds = torch.einsum("ab,ab->a", coefs, to_mul)
    return {'coefs': coefs, 'preds': preds}


In [None]:
# Build the runtime configuration and state object shared by every helper below.
config = get_kaggle_config()
state = State(config)


# Construct the graph model from the GRAPH_MODEL_NAME checkpoint (passed as
# preload_state) and move it to the configured device.
# NOTE(review): the return value of .to() is discarded — this assumes .to()
# mutates the model in place (true for torch.nn.Module); confirm for MyGraphModel.
graph3_model = MyGraphModel(state, preload_state=GRAPH_MODEL_NAME)
graph3_model.to(state.device)
print('Graph3 model loaded')

# Presumably loads the UniXcoder checkpoint at UNIX_MODEL_NAME; device placement
# is up to reload_model — TODO confirm.
unixcoder_model = reload_model(state, UNIX_MODEL_NAME)
print('Unixcoder model loaded')

from tqdm import tqdm
from dataclasses import dataclass
import wandb
import numpy as np
import torch
import math

def get_probs_by_embeddings(embeddings, m_cell_id, code_cell_ids, coef_mul):
    """Softmax over ``code_cell_ids`` of ``coef_mul * sim(markdown, code)``.

    NOTE(review): a function with this name is also defined in the common.py
    cell. This later definition replaces it in the notebook namespace, so every
    caller (including ones defined earlier) uses whichever ran last.

    Args:
        embeddings: mapping from cell id to embedding vector.
        m_cell_id: id of the markdown cell being placed.
        code_cell_ids: candidate code cell ids (non-empty, so ``max`` is defined).
        coef_mul: inverse temperature; larger values sharpen the distribution.

    Returns:
        list of floats summing to 1, aligned with ``code_cell_ids``.
    """
    anchor = embeddings[m_cell_id]
    scores = [sim(anchor, embeddings[code_id]) for code_id in code_cell_ids]
    top = max(scores)
    # Shift by the maximum before exponentiating so the largest term is exp(0) and nothing overflows.
    weights = [math.exp((score - top) * coef_mul) for score in scores]
    total = sum(weights)
    return [weight / total for weight in weights]


@torch.no_grad()
def predict_order(state: State, nb, graph3_model: MyGraphModel, unixcoder_model, graph3_embeddings, unix_embeddings, graph_weight):
    """Predict the full cell order of ``nb`` from precomputed embeddings.

    Code cells keep their given order: the k-th code cell gets score ``k + 0.5``.
    For each markdown cell, the graph and UniXcoder distributions over code
    slots (plus 'END') are mixed as ``graph_weight * graph + (1 - graph_weight) * unix``.
    The cell receives the integer slot ``j`` minimizing the expected absolute
    distance ``sum_i |i - j| * p_i`` (a weighted median), which places it just
    before the j-th code cell. A stable sort by score produces the final order.

    NOTE(review): ``state`` and ``unixcoder_model`` are accepted but not used here;
    the graph model is used only for ``coef_mul``.

    Returns:
        list of cell ids in predicted order.
    """
    code_df = nb[nb['cell_type'] == 'code'].reset_index(level='cell_id')
    code_cell_ids = code_df['cell_id'].values.tolist()
    code_cell_ids.append('END')

    ordered = [
        OneCell(score=pos + 0.5, cell_id=code_id, cell_type="code")
        for pos, code_id in enumerate(get_code_cells(nb))
    ]

    for md_id in get_markdown_cells(nb):
        graph_probs = get_probs_by_embeddings(graph3_embeddings, md_id, code_cell_ids, graph3_model.coef_mul)
        unix_probs = get_probs_by_embeddings(unix_embeddings, md_id, code_cell_ids, 1000.0)
        blended = [g * graph_weight + u * (1 - graph_weight) for g, u in zip(graph_probs, unix_probs)]

        n_slots = len(blended)
        # Expected |distance| if the cell is put at slot j, summed in the same i-order as before.
        expected_dist = [
            sum(abs(i - j) * blended[i] for i in range(n_slots))
            for j in range(n_slots)
        ]
        best_slot = min(range(n_slots), key=expected_dist.__getitem__)

        ordered.append(OneCell(score=best_slot, cell_id=md_id, cell_type="markdown"))

    ordered.sort(key=lambda cell: cell.score)
    return [cell.cell_id for cell in ordered]


In [None]:
# Load the notebooks to score (presumably populates state.test_df, which save_results reads — confirm).
state.load_test_nbs()
# Alternative for local validation, presumably loading the last 100 training notebooks instead:
# state.load_train_nbs_tail(100)

In [None]:
@torch.no_grad()
def save_results(state, graph3_model, unixcoder_model, df=None, output_path='submission.csv'):
    """Predict a cell order for every notebook and write a submission CSV.

    Both models are switched to eval mode. Each notebook (first index level of
    ``df``) is embedded with both models and ordered via ``predict_order`` using
    ``GRAPH_WEIGHT``. The results are written as ``id,cell_order`` rows, with
    cell ids space-separated.

    Args:
        state: runtime state; ``state.test_df`` is used when ``df`` is None.
        graph3_model / unixcoder_model: the two embedding models.
        df: notebooks to score (multi-indexed by notebook id then cell id);
            defaults to ``state.test_df``. Pass e.g. a training slice for local validation.
        output_path: CSV destination; defaults to ``'submission.csv'``.
    """
    graph3_model.eval()
    unixcoder_model.eval()

    print('Start using the model:')

    if df is None:
        df = state.test_df
    nb_ids = df.index.get_level_values(0).unique()

    rows = []
    for nb_id in tqdm(nb_ids):
        nb = df.loc[nb_id]
        graph3_embeddings = get_graph_nb_embeddings(state, graph3_model, nb)
        unix_embeddings = get_unix_nb_embeddings(state, unixcoder_model, nb)

        order = predict_order(state, nb, graph3_model, unixcoder_model,
                              graph3_embeddings, unix_embeddings, graph_weight=GRAPH_WEIGHT)
        rows.append({'id': nb_id, 'cell_order': " ".join(order)})

    # Explicit columns keep the header (and column order) even when no notebooks were scored.
    submission = pd.DataFrame(rows, columns=['id', 'cell_order'])
    submission.to_csv(output_path, index=False)

    display(submission.head())


save_results(state, graph3_model, unixcoder_model)