In [1]:
import pandas as pd
import numpy as np
import torch

from transformers import AutoTokenizer, BertTokenizer
from torch.utils.data import DataLoader

In [2]:
from tqdm import tqdm
from utils import kendall_tau
from functools import cmp_to_key


class Tester:
    """Evaluates a pairwise ordering model by sorting cells with it.

    Args:
        model: callable taking six tensors (three per cell) and returning a
            single-element tensor; values ``<= 0.5`` mean the first cell
            should come before the second.
        device: device that the model and every input tensor are moved to.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device

    def test(self, test_dataloader):
        """Sort every item of ``test_dataloader`` with the model and score it.

        Each item is a ``(cells, correct_order)`` pair. Element 0 of a cell
        is collected as its identifier; elements 1-3 are the model inputs.

        Returns:
            The value of ``kendall_tau`` over all true/predicted orders.
        """
        self.model.to(self.device)
        self.model.eval()

        comparator = cmp_to_key(self._custom_compare)
        true_order, predicted_order = [], []

        with torch.no_grad():
            for cells, correct_order in tqdm(test_dataloader):
                ranked_cells = sorted(cells, key=comparator)
                true_order.append(correct_order)
                predicted_order.append([ranked[0] for ranked in ranked_cells])

        return kendall_tau(true_order, predicted_order)

    def _custom_compare(self, cell1, cell2):
        """Comparator for ``sorted``: -1 if the model output is <= 0.5, else 1."""
        # Elements 1..3 of each cell, first cell first, with the leading
        # dimension squeezed away before moving to the device.
        model_inputs = [
            cell[slot].squeeze(0).to(self.device)
            for cell in (cell1, cell2)
            for slot in (1, 2, 3)
        ]
        return -1 if self.model(*model_inputs).item() <= 0.5 else 1
import torch

from tqdm import tqdm
from utils import kendall_tau
from functools import cmp_to_key


class Tester:
    """Evaluates a pairwise ordering model by sorting cells with it.

    Args:
        model: callable taking six tensors (three per cell) and returning a
            single-element tensor; values ``<= 0.5`` mean the first cell
            should come before the second.
        device: device that the model and every input tensor are moved to.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device

    def test(self, test_dataloader):
        """Sort every item of ``test_dataloader`` with the model and score it.

        Each item is a ``(cells, correct_order)`` pair. Element 0 of a cell
        is collected as its identifier; elements 1-3 are the model inputs.

        Returns:
            The value of ``kendall_tau`` over all true/predicted orders.
        """
        self.model.to(self.device)
        self.model.eval()

        comparator = cmp_to_key(self._custom_compare)
        true_order, predicted_order = [], []

        with torch.no_grad():
            for cells, correct_order in tqdm(test_dataloader):
                ranked_cells = sorted(cells, key=comparator)
                true_order.append(correct_order)
                predicted_order.append([ranked[0] for ranked in ranked_cells])

        return kendall_tau(true_order, predicted_order)

    def _custom_compare(self, cell1, cell2):
        """Comparator for ``sorted``: -1 if the model output is <= 0.5, else 1."""
        # Elements 1..3 of each cell, first cell first, with the leading
        # dimension squeezed away before moving to the device.
        model_inputs = [
            cell[slot].squeeze(0).to(self.device)
            for cell in (cell1, cell2)
            for slot in (1, 2, 3)
        ]
        return -1 if self.model(*model_inputs).item() <= 0.5 else 1

import torch
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
from time import time


class Trainer:
    """Trains a pairwise ordering model with BCE loss, early stopping and checkpoints.

    Args:
        model: module called with six tensors (three per cell) that returns
            one probability per pair; it is moved to ``device`` on construction.
        train_dataloader: yields ``((first_cell, second_cell), label)`` batches.
        valid_dataloader: yields batches of the same structure for validation.
        savedir: string prefix for saved files. File names are appended to it
            directly, so a directory path must end with a separator.
        device: device that the model and each batch are moved to.
        epochs: maximum number of epochs to run.
        early_stopping: number of non-improving epochs tolerated before stopping.
        saving_freq: a checkpoint is written every ``saving_freq`` epochs.
        lr: learning rate of the NAdam optimizer.
    """

    def __init__(
        self,
        model,
        train_dataloader,
        valid_dataloader,
        savedir,
        device,
        epochs=10,
        early_stopping=5,
        saving_freq=5,
        lr=1e-4,
    ):

        self.device = device

        self.model = model.to(device)
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.criterion = nn.BCELoss()
        self.optimizer = optim.NAdam(self.model.parameters(), lr=lr)
        self.epochs = epochs
        self.early_stopping = early_stopping

        self.best_score = -float("inf")
        self.best_model = None

        self.savedir = savedir
        self.saving_freq = saving_freq

    def train(self):
        """Run the training loop and save the best weights as ``best_model.pt``.

        After every epoch the model is validated. A new best validation score
        resets the early-stopping counter and snapshots the weights; otherwise
        the counter is decremented and training stops when it reaches zero.
        """
        early_stopping_remaining = self.early_stopping
        print("*" * 80)
        print("Train model")

        for epoch in range(1, self.epochs + 1):
            print("*" * 80)
            print(f"Epoch {epoch}/{self.epochs}")
            start_time = time()
            train_loss = self._train_one_epoch()
            valid_score = self._validate()

            print(f"Train loss: {train_loss:.4f}, Valid accuracy: {valid_score:.4f}")
            print(f"Epoch execution time: {time() - start_time:.2f} seconds")

            if valid_score > self.best_score:
                early_stopping_remaining = self.early_stopping
                self.best_score = valid_score
                # .cpu() returns the very same tensor when the model already
                # lives on the CPU, so without clone() the "best" snapshot would
                # keep changing as later epochs update the parameters in place.
                self.best_model = {
                    name: tensor.detach().cpu().clone()
                    for name, tensor in self.model.state_dict().items()
                }
                print(f"New best model saved with valid accuracy: {valid_score:.4f}")
            else:
                early_stopping_remaining -= 1

            if epoch % self.saving_freq == 0:
                self._save_checkpoint(epoch, train_loss)

            if not early_stopping_remaining:
                print(f"Training stopped at {epoch} epoch")
                break

        if self.best_model is not None:
            torch.save(self.best_model, f"{self.savedir}best_model.pt")
            print("Best model saved as 'best_model.pt'.")

    def _forward_pair(self, first_cell, second_cell):
        """Run the model on one batch of cell pairs.

        Elements 0..2 of each cell get dimension 1 squeezed and are moved to
        the device; the first cell's tensors are passed before the second's.
        """
        model_inputs = [
            cell[slot].squeeze(1).to(self.device)
            for cell in (first_cell, second_cell)
            for slot in (0, 1, 2)
        ]
        return self.model(*model_inputs)

    def _train_one_epoch(self):
        """Optimize over the training dataloader once.

        Returns:
            The mean batch loss as a float.

        Raises:
            ZeroDivisionError: if the dataloader yields no batches.
        """
        self.model.train()
        train_loss = 0
        n_batches = 0

        for (first_cell, second_cell), train_label in tqdm(self.train_dataloader):
            self.optimizer.zero_grad()
            output = self._forward_pair(first_cell, second_cell)
            loss = self.criterion(output, train_label.float().to(self.device))
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            n_batches += 1

        return train_loss / n_batches

    def _validate(self):
        """Compute the mean per-batch accuracy on the validation dataloader.

        Returns:
            The accuracy averaged over batches, as a Python float.

        Raises:
            ZeroDivisionError: if the dataloader yields no batches.
        """
        self.model.eval()
        score = 0.0
        n_batches = 0

        with torch.no_grad():
            for (first_cell, second_cell), correct_order in tqdm(self.valid_dataloader):
                n_batches += 1
                output = self._forward_pair(first_cell, second_cell)

                # Adding 0.5 and truncating rounds each probability to 0 or 1.
                order = (output + 0.5).to(dtype=torch.int32).cpu()
                # Vectorized mean instead of Python's sum() over tensor elements.
                score += (order == correct_order).float().mean().item()

        return score / n_batches

    def _save_checkpoint(self, epoch, train_loss):
        """Write model weights, optimizer state, epoch and loss to ``savedir``."""
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": {k: v.cpu() for k, v in self.model.state_dict().items()},
            "optimizer_state_dict": self.optimizer.state_dict(),
            "train_loss": train_loss,
        }
        checkpoint_path = f"{self.savedir}checkpoint_epoch_{epoch}.pt"
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}.")

import os
import torch

from time import localtime, strftime
from bisect import bisect


def prepare_folders():
    """Create an empty, timestamped checkpoint directory and return its path.

    The directory is ``./checkpoints/<dd.mm.YYYY-HH.MM>/`` relative to the
    current working directory. If it already exists (a run started within
    the same minute), everything inside it is deleted.

    Returns:
        The directory path as a string ending with ``/``.
    """
    current_time = strftime("%d.%m.%Y-%H.%M", localtime())
    savedir = f"./checkpoints/{current_time}/"

    # makedirs creates the parent "./checkpoints" when needed and does not
    # fail if the directory appeared between a check and the creation.
    os.makedirs(savedir, exist_ok=True)

    # Bottom-up walk so that directories are already empty when removed;
    # for a freshly created directory there is nothing to delete.
    for root, dirs, files in os.walk(savedir, topdown=False):
        for name in files:
            os.remove(os.path.join(root, name))
        for name in dirs:
            os.rmdir(os.path.join(root, name))

    return savedir


def get_device():
    """Return the preferred torch device name: "cuda", then "mps", then "cpu"."""
    if torch.cuda.is_available():
        return "cuda"

    # torch.backends.mps.is_available() is the documented availability check.
    # NOTE(review): torch.mps.is_available appears to be missing from some
    # torch releases, which would raise AttributeError on machines without CUDA.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    return "cpu"


def count_inversions(a):
    """Count the pairs of elements of ``a`` that appear in descending order.

    Elements are inserted one by one into a sorted list; every already-seen
    element located after the insertion point is strictly greater than the
    new one and therefore forms an inversion with it.
    """
    seen_sorted = []
    total = 0
    for value in a:
        insert_at = bisect(seen_sorted, value)
        total += len(seen_sorted) - insert_at
        seen_sorted.insert(insert_at, value)
    return total


def kendall_tau(ground_truth, predictions):
    """Compute a Kendall-tau style score over several orderings at once.

    Args:
        ground_truth: sequence of lists, each the correct order of its items.
        predictions: sequence of predicted orders, aligned with ``ground_truth``.

    Returns:
        ``1 - 4 * inversions / sum(n * (n - 1))`` accumulated over all pairs
        of orderings: 1 for no inversions, -1 when every order is reversed.

    Raises:
        ZeroDivisionError: if no ordering has at least two items.
        ValueError: if a predicted item is absent from its ground truth.
    """
    inversion_sum = 0
    pair_count_doubled = 0
    for gt, pred in zip(ground_truth, predictions):
        # Rank every predicted item by its position in the true order.
        inversion_sum += count_inversions([gt.index(item) for item in pred])
        pair_count_doubled += len(gt) * (len(gt) - 1)
    return 1 - 4 * inversion_sum / pair_count_doubled

import torch
import torch.nn as nn

from transformers import BertModel, AutoModel

import numpy as np

from torch.utils.data import Sampler


class CellSampler(Sampler):
    """Yields ``[row_index, cell_a, cell_b]`` keys for adjacent shuffled cells.

    For every row of ``data`` a copy of its ``"cell_order"`` list is shuffled
    and each pair of neighbours in the shuffled list becomes one key.

    Args:
        data: DataFrame whose ``"cell_order"`` column holds a list per row.
        seed: seed for the shuffling. ``None`` (the default) uses numpy's
            global random state; any integer, including 0, gives the same
            pairs on every iteration.
    """

    def __init__(self, data, seed=None):
        self.data = data
        self.seed = seed
        # A list of n cells produces n - 1 adjacent pairs.
        self.n_pair = sum(
            len(self.data.loc[row_index, "cell_order"]) - 1 for row_index in self.data.index
        )

    def __len__(self):
        return self.n_pair

    def __iter__(self):
        pairs = []
        for row_index in self.data.index:
            # Shuffle a copy so the stored order is left untouched.
            cells = self.data.loc[row_index, "cell_order"].copy()
            # Compare against None: a plain truthiness test silently ignored seed=0.
            if self.seed is not None:
                # The generator is re-created for every row, so previously
                # used seeds keep producing exactly the same pairs.
                np.random.default_rng(self.seed).shuffle(cells)
            else:
                np.random.shuffle(cells)
            pairs.extend([row_index, first, second] for first, second in zip(cells, cells[1:]))

        yield from pairs


class OrderPredictionModel(nn.Module):
    """Pairwise classifier over two notebook cells.

    Each cell is encoded by CodeBERT (cell type 1) or multilingual BERT
    (cell type 0). The pooled outputs are concatenated with learned cell-type
    embeddings and passed through a small MLP ending in a sigmoid.

    Args:
        hidden_dim: width of the hidden fully connected layer.
        dropout_prob: dropout probability applied after the hidden layer.
    """

    def __init__(self, hidden_dim, dropout_prob=0.1):
        super().__init__()

        # Pretrained encoders: one for markdown/text cells, one for code cells.
        self.bert_text = BertModel.from_pretrained("bert-base-multilingual-uncased")
        self.codebert = AutoModel.from_pretrained("microsoft/codebert-base")

        # Classification head over [emb1, type1, emb2, type2].
        self.type_embedding = nn.Embedding(2, 8)
        self.fc1 = nn.Linear(768 * 2 + 8 * 2, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc2 = nn.Linear(hidden_dim, 1)

    @staticmethod
    def _get_batch_embeddings(input_ids, attention_mask, cell_type, code_model, text_model):
        """Encode a batch row by row with the encoder matching each cell type.

        Rows with ``cell_type == 1`` go through ``code_model`` and rows with
        ``cell_type == 0`` through ``text_model``; rows of any other type
        keep an all-zero embedding.

        Returns:
            A float32 tensor of shape ``(batch, code_model.config.hidden_size)``
            on the device of ``input_ids``.
        """
        pooled = torch.zeros(
            input_ids.size(0),
            code_model.config.hidden_size,
            device=input_ids.device,
            dtype=torch.float32,
        )

        # Code rows are encoded first, then text rows.
        for type_label, encoder in ((1, code_model), (0, text_model)):
            selected = cell_type == type_label
            if not selected.any():
                continue
            rows = selected.nonzero(as_tuple=True)[0]
            encoded = encoder(input_ids[rows], attention_mask=attention_mask[rows])
            pooled[rows] = encoded.pooler_output

        return pooled

    def forward(self, input_ids1, att_mask1, cell_type1, input_ids2, att_mask2, cell_type2):
        """Return one sigmoid output per pair of cells.

        The training and evaluation code in this file uses 0 as the label
        when the first cell precedes the second.
        """
        encoders = {"code_model": self.codebert, "text_model": self.bert_text}
        first = self._get_batch_embeddings(input_ids1, att_mask1, cell_type1, **encoders)
        second = self._get_batch_embeddings(input_ids2, att_mask2, cell_type2, **encoders)

        features = torch.cat(
            [first, self.type_embedding(cell_type1), second, self.type_embedding(cell_type2)],
            dim=1,
        )
        hidden = self.dropout(torch.relu(self.bn1(self.fc1(features))))

        return torch.sigmoid(self.fc2(hidden)).squeeze(1)

import numpy as np

from Datasets.cell_dataset import CellDataset


class TestCellDataset(CellDataset):
    """Dataset whose items are whole notebooks with their cells shuffled."""

    def __init__(self, path, data, code_tokenizer, text_tokenizer, max_length):
        super().__init__(path, data, code_tokenizer, text_tokenizer, max_length)

    def __len__(self):
        """Number of rows in ``data``."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(cells, correct_order)`` for the row at position ``idx``.

        ``cells`` lists ``[cell_id, input_ids, att_mask, cell_type]`` entries
        in a random order drawn from numpy's global random state;
        ``correct_order`` is the row's stored list of cell ids.
        """
        row = self.data.iloc[idx]
        notebook_id = row.name
        true_order = row.item()

        # Shuffle a copy so the stored order is not modified.
        shuffled_ids = true_order.copy()
        np.random.shuffle(shuffled_ids)

        notebook_cells = self.files[notebook_id]
        cells = [[cell_id, *notebook_cells[cell_id].get()] for cell_id in shuffled_ids]

        return cells, true_order
from Datasets.cell_dataset import CellDataset


class TrainValCellDataset(CellDataset):
    """Dataset of cell pairs labelled by their relative order in the notebook.

    Items are addressed by ``[notebook id, first cell id, second cell id]``
    keys rather than integer positions.
    """

    def __init__(self, path, data, code_tokenizer, text_tokenizer, max_length):
        super().__init__(path, data, code_tokenizer, text_tokenizer, max_length)

        # A notebook with n cells contributes n - 1 pairs.
        self.n_pair = sum(
            len(self.data.loc[row_index, "cell_order"]) - 1 for row_index in self.data.index
        )

    def __len__(self):
        return self.n_pair

    def __getitem__(self, idx):
        """Return ``((first_cell, second_cell), label)`` for a pair key.

        ``label`` is 0 when the first cell precedes the second in the
        notebook's stored order and 1 otherwise.
        """
        filename, first_cell_id, second_cell_id = idx[0], idx[1], idx[2]

        notebook_order = self.data.loc[filename, "cell_order"]
        first_position = notebook_order.index(first_cell_id)
        second_position = notebook_order.index(second_cell_id)
        label = int(first_position >= second_position)

        notebook_cells = self.files[filename]
        pair = (notebook_cells[first_cell_id].get(), notebook_cells[second_cell_id].get())
        return (pair, label)

import numpy as np

from Datasets.cell_dataset import CellDataset


class TestCellDataset(CellDataset):
    """Dataset whose items are whole notebooks with their cells shuffled."""

    def __init__(self, path, data, code_tokenizer, text_tokenizer, max_length):
        super().__init__(path, data, code_tokenizer, text_tokenizer, max_length)

    def __len__(self):
        """Number of rows in ``data``."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(cells, correct_order)`` for the row at position ``idx``.

        ``cells`` lists ``[cell_id, input_ids, att_mask, cell_type]`` entries
        in a random order drawn from numpy's global random state;
        ``correct_order`` is the row's stored list of cell ids.
        """
        row = self.data.iloc[idx]
        notebook_id = row.name
        true_order = row.item()

        # Shuffle a copy so the stored order is not modified.
        shuffled_ids = true_order.copy()
        np.random.shuffle(shuffled_ids)

        notebook_cells = self.files[notebook_id]
        cells = [[cell_id, *notebook_cells[cell_id].get()] for cell_id in shuffled_ids]

        return cells, true_order

class CellSampler(Sampler):
    """Yields ``[row_index, cell_a, cell_b]`` keys for adjacent shuffled cells.

    For every row of ``data`` a copy of its ``"cell_order"`` list is shuffled
    and each pair of neighbours in the shuffled list becomes one key.

    Args:
        data: DataFrame whose ``"cell_order"`` column holds a list per row.
        seed: seed for the shuffling. ``None`` (the default) uses numpy's
            global random state; any integer, including 0, gives the same
            pairs on every iteration.
    """

    def __init__(self, data, seed=None):
        self.data = data
        self.seed = seed
        # A list of n cells produces n - 1 adjacent pairs.
        self.n_pair = sum(
            len(self.data.loc[row_index, "cell_order"]) - 1 for row_index in self.data.index
        )

    def __len__(self):
        return self.n_pair

    def __iter__(self):
        pairs = []
        for row_index in self.data.index:
            # Shuffle a copy so the stored order is left untouched.
            cells = self.data.loc[row_index, "cell_order"].copy()
            # Compare against None: a plain truthiness test silently ignored seed=0.
            if self.seed is not None:
                # The generator is re-created for every row, so previously
                # used seeds keep producing exactly the same pairs.
                np.random.default_rng(self.seed).shuffle(cells)
            else:
                np.random.shuffle(cells)
            pairs.extend([row_index, first, second] for first, second in zip(cells, cells[1:]))

        yield from pairs

class Cell:
    """Container for the three values stored per notebook cell."""

    def __init__(self, input_ids, att_mask, cell_type):
        self.input_ids, self.att_mask, self.cell_type = input_ids, att_mask, cell_type

    def get(self):
        """Return ``(input_ids, att_mask, cell_type)`` as a tuple."""
        return self.input_ids, self.att_mask, self.cell_type
import torch
import json

from tqdm import tqdm
from torch.utils.data import Dataset

from Datasets.cell import Cell


class CellDataset(Dataset):
    """Base dataset that tokenizes every cell of every notebook up front.

    Subclasses must implement ``__len__`` and ``__getitem__``.

    Args:
        path: string prefix of the notebook files; ``<path><id>.json`` is
            opened for every id in ``data.index``.
        data: DataFrame indexed by notebook id with a ``"cell_order"`` column
            holding the list of cell ids of each notebook.
        code_tokenizer: tokenizer applied to cells of type ``"code"``.
        text_tokenizer: tokenizer applied to all other cells.
        max_length: length every cell is padded or truncated to.

    Attributes:
        files: ``{notebook id: {cell id: Cell}}`` with the tokenized cells.
    """

    def __init__(self, path, data, code_tokenizer, text_tokenizer, max_length):
        self.data = data
        self.code_tokenizer = code_tokenizer
        self.text_tokenizer = text_tokenizer
        self.max_length = max_length
        self.files = {}

        for filename in tqdm(self.data.index):
            cells_dict = {}
            cells = self.data.loc[filename, "cell_order"]
            # JSON text is UTF-8; do not depend on the platform's locale encoding.
            with open(f"{path}{filename}.json", encoding="utf-8") as file:
                json_code = json.load(file)
            for cell in cells:
                input_ids, att_mask, cell_type = self.prepare_data(
                    json_code["cell_type"][cell], json_code["source"][cell]
                )
                cells_dict[cell] = Cell(input_ids, att_mask, cell_type)
            self.files[filename] = cells_dict

    def __len__(self):
        # Fail loudly instead of returning None when a subclass forgets to override.
        raise NotImplementedError

    def __getitem__(self, idx):
        # Fail loudly instead of returning None when a subclass forgets to override.
        raise NotImplementedError

    def prepare_data(self, cell_type, cell_content):
        """Tokenize one cell.

        Args:
            cell_type: ``"code"`` selects the code tokenizer and type label 1;
                anything else selects the text tokenizer and type label 0.
            cell_content: source text of the cell.

        Returns:
            ``(input_ids, attention_mask, type_tensor)`` where the first two
            come from the tokenizer output and ``type_tensor`` is a
            one-element long tensor holding the type label.
        """
        if cell_type == "code":
            tokenizer = self.code_tokenizer
            type_label = 1
        else:
            tokenizer = self.text_tokenizer
            type_label = 0

        tokens = tokenizer(
            cell_content,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        type_tensor = torch.tensor([type_label], dtype=torch.long)

        return (tokens["input_ids"], tokens["attention_mask"], type_tensor)

    

In [5]:
# Tokenizers handed to the datasets below: CodeBERT's for code cells and
# multilingual BERT's for every other cell.
code_tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
text_tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-uncased")
# Seeds numpy's global random state (used by np.random.shuffle in the datasets/samplers).
np.random.seed(42)

print("*" * 80)
print("Reading data")
# One row per notebook id; "cell_order" is a space-separated string of cell ids
# that is split into a list.
info = pd.read_csv("/home/drkocharyan/ai4code/AI4Code/train_orders.csv", index_col="id")
info["cell_order"] = info["cell_order"].apply(lambda x: x.split())
indeces = list(info.index)

# Shuffle the notebook ids with a dedicated, seeded generator so the split is reproducible.
rng = np.random.default_rng(42)
rng.shuffle(indeces)

# Split fractions. test_size is not used below: the test split is whatever
# remains after the train and validation slices.
train_size = 0.7
valid_size = 0.2
test_size = 0.1

train_border = int(train_size * len(indeces))
valid_border = int((train_size + valid_size) * len(indeces))

train_data = info.loc[indeces[:train_border]]
valid_data = info.loc[indeces[train_border:valid_border]]
test_data = info.loc[indeces[valid_border:]]

# The "_short" names are plain aliases of the full splits (no subsampling is applied here).
train_data_short = train_data
valid_data_short = valid_data
test_data_short = test_data




********************************************************************************
Reading data


In [None]:
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor

class TrainValCellDataset(CellDataset):
    """Dataset of cell pairs labelled by their relative order in the notebook.

    Items are addressed by ``[notebook id, first cell id, second cell id]``
    keys rather than integer positions.

    Position lookups run in the calling process. A process pool is
    deliberately not used: submitting a bound method pickles the whole
    dataset (and the pool itself) for each lookup, and a pool attribute
    prevents the dataset from being pickled.
    """

    def __init__(self, path, data, code_tokenizer, text_tokenizer, max_length):
        super().__init__(path, data, code_tokenizer, text_tokenizer, max_length)

        # A notebook with n cells contributes n - 1 pairs.
        self.n_pair = sum(
            len(self.data.loc[row_index, "cell_order"]) - 1 for row_index in self.data.index
        )

    def __len__(self):
        return self.n_pair

    def __getitem__(self, idx):
        """Return ``((first_cell, second_cell), label)`` for a pair key.

        ``label`` is 0 when the first cell precedes the second in the
        notebook's stored order and 1 otherwise.
        """
        filename = idx[0]
        first_cell_id = idx[1]
        second_cell_id = idx[2]

        first_position = self.get_position(filename, first_cell_id)
        second_position = self.get_position(filename, second_cell_id)

        order = 0 if first_position < second_position else 1

        return ((self.files[filename][first_cell_id].get(), self.files[filename][second_cell_id].get()), order)

    def get_position(self, filename, cell_id):
        """Return the index of ``cell_id`` in the notebook's stored cell order."""
        return self.data.loc[filename, "cell_order"].index(cell_id)

# Create the training dataset instance (the last argument, 128, is max_length).
train_dataset = TrainValCellDataset('/home/drkocharyan/ai4code/AI4Code/train/', train_data_short, code_tokenizer, text_tokenizer, 128)

In [None]:
import multiprocessing as mp

class TrainValCellDataset(CellDataset):
    """Dataset of cell pairs labelled by their relative order in the notebook.

    Items are addressed by ``[notebook id, first cell id, second cell id]``
    keys rather than integer positions. Cell positions are memoized in
    ``position_cache``.

    No worker pool is created: the lookups run in the calling process, and a
    pool attribute would leave idle processes behind and prevent the dataset
    from being pickled.
    """

    def __init__(self, path, data, code_tokenizer, text_tokenizer, max_length):
        super().__init__(path, data, code_tokenizer, text_tokenizer, max_length)

        # A notebook with n cells contributes n - 1 pairs.
        self.n_pair = sum(
            len(self.data.loc[row_index, "cell_order"]) - 1 for row_index in self.data.index
        )

        # Cache of already computed positions, keyed by (filename, cell_id).
        self.position_cache = {}

    def __len__(self):
        return self.n_pair

    def __getitem__(self, idx):
        """Return ``((first_cell, second_cell), label)`` for a pair key.

        ``label`` is 0 when the first cell precedes the second in the
        notebook's stored order and 1 otherwise.
        """
        filename = idx[0]
        first_cell_id = idx[1]
        second_cell_id = idx[2]

        # Positions come from the cache or are computed on first use.
        first_position = self.get_cached_position(filename, first_cell_id)
        second_position = self.get_cached_position(filename, second_cell_id)

        order = 0 if first_position < second_position else 1

        return ((self.files[filename][first_cell_id].get(), self.files[filename][second_cell_id].get()), order)

    def get_cached_position(self, filename, cell_id):
        """Return the index of ``cell_id`` in the notebook's order, memoized."""
        key = (filename, cell_id)
        if key not in self.position_cache:
            # Not cached yet: compute it and remember the result.
            self.position_cache[key] = self.data.loc[filename, "cell_order"].index(cell_id)
        return self.position_cache[key]

# Create the training dataset instance (the last argument, 128, is max_length).
train_dataset = TrainValCellDataset('/home/drkocharyan/ai4code/AI4Code/train/', train_data_short, code_tokenizer, text_tokenizer, 128)

In [None]:
import os
# NOTE(review): this is not the last expression of the cell, so its value is
# neither displayed nor used; num_workers below is hard-coded independently.
os.cpu_count()
# Worker-process count passed to the DataLoader constructors as num_workers.
num_workers=12

In [None]:
# train_dataset = TrainValCellDataset('/home/drkocharyan/ai4code/AI4Code/train/', train_data_short, code_tokenizer, text_tokenizer, 128)


In [None]:

# NOTE(review): train_sampler, valid_dataset, valid_sampler and test_dataset are
# first assigned in the next cell (In [6]), so this cell fails on a fresh
# kernel when cells are run top to bottom.
# Train/valid loaders draw pair keys from their samplers and drop the last incomplete batch.
train_dataloader = DataLoader(train_dataset, batch_size=64, drop_last=True, sampler=train_sampler, num_workers=num_workers)
valid_dataloader = DataLoader(valid_dataset, batch_size=64, drop_last=True, sampler=valid_sampler, num_workers=num_workers)
# The test loader yields one notebook per batch, in dataset order.
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=num_workers)

In [6]:
# All three datasets read notebook JSON files from the same train/ directory;
# the split is made by notebook id. The second positional DataLoader argument
# is the batch size, and 128 is the tokenizer max_length.
# Training pairs: the sampler is unseeded, so it shuffles with numpy's global random state.
train_dataset = TrainValCellDataset('/home/drkocharyan/ai4code/AI4Code/train/', train_data_short, code_tokenizer, text_tokenizer, 128)
train_sampler = CellSampler(train_data_short)
train_dataloader = DataLoader(train_dataset, 64, drop_last=True, sampler=train_sampler)

# Validation pairs: the sampler is seeded with 42, so the same pairs are produced on every pass.
valid_dataset = TrainValCellDataset("/home/drkocharyan/ai4code/AI4Code/train/", valid_data_short, code_tokenizer, text_tokenizer, 128)
valid_sampler = CellSampler(valid_data_short, 42)
valid_dataloader = DataLoader(valid_dataset, 64, drop_last=True, sampler=valid_sampler)

# Test items are whole notebooks, loaded one per batch in dataset order.
test_dataset = TestCellDataset("/home/drkocharyan/ai4code/AI4Code/train/", test_data_short, code_tokenizer, text_tokenizer, 128)
test_dataloader = DataLoader(test_dataset, 1, shuffle=False)


100%|██████████| 97479/97479 [1:48:13<00:00, 15.01it/s]   
100%|██████████| 27851/27851 [33:04<00:00, 14.04it/s]  
100%|██████████| 13926/13926 [16:24<00:00, 14.14it/s] 


In [None]:
import pickle
from pathlib import Path

# Helper for saving tokenized data
def save_tokenized_data(dataset, filename):
    """Pickle ``dataset`` into the file ``filename``, overwriting it if present."""
    with open(filename, 'wb') as handle:
        pickle.Pickler(handle).dump(dataset)

# Save the tokenized training, validation and test datasets.
# The files are written relative to the current working directory.
save_tokenized_data(train_dataset, 'train_tokenized_full.pkl')
save_tokenized_data(valid_dataset, 'valid_tokenized_full.pkl')
save_tokenized_data(test_dataset, 'test_tokenized_full.pkl')