In [None]:
# Imports
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from transformers import AutoTokenizer, BertTokenizer, AutoModel, BertModel
from tqdm import tqdm
import os
import json
from bisect import bisect
from time import localtime, strftime
import matplotlib.pyplot as plt
import seaborn as sns

# Central settings dictionary read by the cells below.
CONFIG = {
    # Input locations. Notebooks are opened as <train_notebooks_path><id>.json.
    # "data_path" is not read by any code visible in this notebook.
    "data_path": "../AI4Code_data/",
    "train_orders_path": "../AI4Code_data/train_orders.csv",
    "train_notebooks_path": "../AI4Code_data/train/",

    # Seed for the NumPy generator that shuffles notebook ids before splitting.
    "random_seed": 42,
    # Split fractions. Only train_size and valid_size are used to compute the
    # borders; the test split is whatever remains, so "test_size" is informational.
    "train_size": 0.7,
    "valid_size": 0.2,
    "test_size": 0.1,

    # When debug_mode is true each split is cut down to its first *_sample_size rows.
    "debug_mode": True,
    "train_sample_size": 40,
    "valid_sample_size": 10,
    "test_sample_size": 10,

    # Scoring head: width of the projection layer and its dropout probability.
    "hidden_dim": 128,
    "dropout_prob": 0.1,
    # Every cell is padded/truncated to this many tokens.
    "max_length": 128,

    # Optimisation settings.
    "learning_rate": 1e-4,
    "epochs": 5,
    # NOTE(review): not read by the visible code; the trainer and tester build
    # their DataLoaders with a hard-coded batch_size=1.
    "batch_size": 1,

    # Parent folder (relative to the working directory) for checkpoints.
    "savedir_name": "checkpoints_listwise"
}


In [None]:
def prepare_folders(savedir_name="checkpoints_listwise"):
    """Create (or empty) a timestamped checkpoint directory and return its path.

    The directory is ``./<savedir_name>/<dd.mm.YYYY-HH.MM>/`` using the local
    time at the moment of the call. Missing parent directories are created as
    needed, so ``savedir_name`` may itself be a nested path such as
    ``"runs/listwise"``. If the timestamped directory already exists (for
    example a second call within the same minute), its contents are deleted so
    the run starts from an empty folder.

    Args:
        savedir_name: Folder, relative to the current working directory, that
            holds all timestamped run directories.

    Returns:
        str: Path of the run directory, with a trailing slash.
    """
    current_time = strftime("%d.%m.%Y-%H.%M", localtime())
    savedir = f"./{savedir_name}/{current_time}/"

    if os.path.exists(savedir):
        # Walk bottom-up so every directory is already empty when rmdir reaches it.
        for root, dirs, files in os.walk(savedir, topdown=False):
            for name in files:
                os.remove(os.path.join(root, name))
            for name in dirs:
                os.rmdir(os.path.join(root, name))
    else:
        # makedirs creates all missing parents and tolerates a concurrent creator.
        os.makedirs(savedir, exist_ok=True)

    return savedir

def get_device():
    """Return the name of the preferred torch device.

    Preference order: ``"cuda"`` when CUDA is available, otherwise ``"mps"``
    when the MPS backend is available, otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def count_inversions(a):
    """Return how many pairs ``(i, j)`` with ``i < j`` satisfy ``a[i] > a[j]``.

    Elements are inserted one by one into a sorted list; each new element
    contributes one inversion per previously seen element that is strictly
    greater than it. Equal elements are not counted as inverted.
    """
    seen_sorted = []
    total = 0
    for value in a:
        # bisect gives the right-most insertion point, i.e. the number of
        # earlier elements that are <= value.
        slot = bisect(seen_sorted, value)
        total += len(seen_sorted) - slot
        seen_sorted.insert(slot, value)
    return total

def kendall_tau(ground_truth, predictions):
    """Compute the pooled Kendall tau score over several orderings.

    For every (ground truth, prediction) pair the predicted items are mapped to
    their ground-truth positions and the inversions of that rank list are
    counted. The result is ``1 - 4 * inversions / sum(n * (n - 1))``, which is
    1.0 for a perfect ordering and -1.0 for a fully reversed one.

    Args:
        ground_truth: Sequence of lists, each holding the items in their
            correct order.
        predictions: Sequence of lists, aligned with ``ground_truth``, holding
            the same items in predicted order.

    Returns:
        float: The Kendall tau score. When there is no comparable pair at all
        (no orderings, or every ordering has at most one item) nothing can be
        out of order, so 1.0 is returned instead of dividing by zero.

    Raises:
        ValueError: If a predicted item does not occur in its ground truth.
    """
    total_inversions = 0
    total_2max = 0  # twice the maximum possible number of inversions
    for gt, pred in zip(ground_truth, predictions):
        ranks = [gt.index(x) for x in pred]
        total_inversions += count_inversions(ranks)
        n = len(gt)
        total_2max += n * (n - 1)

    if total_2max == 0:
        # No pair of items exists, hence no inversion is possible.
        return 1.0
    return 1 - 4 * total_inversions / total_2max


In [None]:
class ListWiseCellDataset(Dataset):
    """Dataset in which one item is one whole notebook.

    Each item holds the tokenized cells of a notebook, listed in the order
    given by the ``cell_order`` column of ``data``.

    Args:
        path: Directory that contains ``<notebook_id>.json`` files. A trailing
            separator is optional.
        data: DataFrame indexed by notebook id with a ``cell_order`` column
            holding either a list of cell ids or a whitespace-separated string.
        code_tokenizer: Tokenizer applied to cells whose type is ``"code"``.
        text_tokenizer: Tokenizer applied to every other cell type.
        max_length: Length every cell is padded or truncated to.
    """

    def __init__(self, path, data, code_tokenizer, text_tokenizer, max_length=128):
        super().__init__()
        self.path = path
        self.data = data
        self.notebook_ids = list(data.index)
        self.code_tokenizer = code_tokenizer
        self.text_tokenizer = text_tokenizer
        self.max_length = max_length

    def __len__(self):
        """Number of notebooks."""
        return len(self.notebook_ids)

    def __getitem__(self, idx):
        """Load and tokenize the notebook at position ``idx``.

        Returns:
            dict with keys ``id_notebook`` (index label), ``cell_ids`` (list of
            cell ids in ``cell_order`` order), ``input_ids`` and
            ``attention_mask`` (one row of ``max_length`` tokens per cell) and
            ``cell_types`` (long tensor, 1 for code cells and 0 otherwise).
        """
        notebook_id = self.notebook_ids[idx]
        cell_order = self.data.loc[notebook_id, "cell_order"]

        if isinstance(cell_order, str):
            cell_order = cell_order.split()

        # os.path.join works with or without a trailing separator on self.path;
        # the encoding is pinned so reading does not depend on the OS locale.
        notebook_file = os.path.join(self.path, f"{notebook_id}.json")
        with open(notebook_file, "r", encoding="utf-8") as f:
            nb_json = json.load(f)

        input_ids_list = []
        attn_mask_list = []
        cell_type_list = []

        for cell_id in cell_order:
            ctype = nb_json["cell_type"][cell_id]
            csource = nb_json["source"][cell_id]

            # Code and non-code cells use different vocabularies.
            is_code = ctype == "code"
            tokenizer = self.code_tokenizer if is_code else self.text_tokenizer
            tok = tokenizer(
                csource,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )

            cell_type_list.append(1 if is_code else 0)
            # The tokenizer returns a leading batch axis of size 1; drop it.
            input_ids_list.append(tok["input_ids"].squeeze(0))
            attn_mask_list.append(tok["attention_mask"].squeeze(0))

        input_ids_tensor = torch.stack(input_ids_list, dim=0)
        attn_mask_tensor = torch.stack(attn_mask_list, dim=0)
        cell_type_tensor = torch.tensor(cell_type_list, dtype=torch.long)

        return {
            "id_notebook": notebook_id,
            "cell_ids": cell_order,
            "input_ids": input_ids_tensor,
            "attention_mask": attn_mask_tensor,
            "cell_types": cell_type_tensor
        }

In [None]:
class ListWiseOrderPredictionModel(nn.Module):
    """Assign one ordering score to every cell of a notebook.

    Code cells are encoded with CodeBERT and all other cells with multilingual
    BERT. The pooled encoder output is concatenated with a learned cell-type
    embedding and passed through a small feed-forward head that emits a single
    scalar per cell.
    """

    def __init__(self, hidden_dim=768, dropout_prob=0.1):
        """
        Args:
            hidden_dim: Width of the projection layer in the scoring head.
            dropout_prob: Dropout probability applied after the projection.
        """
        super().__init__()
        # Pretrained encoders, one per cell kind.
        self.codebert = AutoModel.from_pretrained("microsoft/codebert-base")
        self.bert_text = BertModel.from_pretrained("bert-base-multilingual-uncased")
        # Two cell types (0 = non-code, 1 = code), 8-dimensional embedding each.
        self.type_embedding = nn.Embedding(2, 8)

        # Head input = 768-d pooled encoder output + 8-d type embedding.
        self.proj = nn.Linear(768 + 8, hidden_dim)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(hidden_dim, 1)

    def forward(self, input_ids, attention_mask, cell_types):
        """Score all cells of one notebook.

        Args:
            input_ids: Token ids, one row per cell.
            attention_mask: Attention mask aligned with ``input_ids``.
            cell_types: Long tensor with 1 for code cells and 0 otherwise.

        Returns:
            Tensor with one score per cell (last axis of size 1 squeezed away).
        """
        n_cells = input_ids.size(0)
        pooled = torch.zeros(n_cells, 768, device=input_ids.device, dtype=torch.float32)

        # Route each group of cells through its own encoder: code first, then text.
        for type_label, encoder in ((1, self.codebert), (0, self.bert_text)):
            selected = cell_types == type_label
            if not selected.any():
                continue
            rows = selected.nonzero(as_tuple=True)[0]
            encoded = encoder(input_ids[rows], attention_mask=attention_mask[rows])
            pooled[rows] = encoded.pooler_output

        features = torch.cat([pooled, self.type_embedding(cell_types)], dim=1)
        hidden = self.dropout(self.act(self.proj(features)))
        return self.classifier(hidden).squeeze(-1)


In [None]:
def listnet_loss(scores, n_items):
    """Top-one ListNet loss for items that are supplied in their correct order.

    The target distribution puts mass proportional to ``exp(-i)`` on the item
    at position ``i`` (so position 0 gets the most), and the model distribution
    is the softmax of ``scores``. The loss is the cross entropy between them.

    ``log_softmax`` is used rather than ``log(softmax + eps)``: the latter
    saturates at ``log(eps)`` and yields a zero gradient once a probability
    underflows, which hides exactly the badly ranked items that need the
    largest correction.

    Args:
        scores: 1-D tensor of model scores, one per item, in ground-truth order.
        n_items: Number of items; must equal ``scores.size(0)``.

    Returns:
        Scalar tensor holding the loss.
    """
    rank_tensor = torch.arange(n_items, device=scores.device, dtype=torch.float32)
    # softmax(-rank) is exp(-rank) normalised to sum to one.
    q = F.softmax(-rank_tensor, dim=0)
    log_p = F.log_softmax(scores, dim=0)
    return -(q * log_p).sum()

def custom_collate_fn(batch):
    """Collate samples into per-field Python lists without stacking tensors.

    Args:
        batch: List of sample dicts, each with the keys ``id_notebook``,
            ``cell_ids``, ``input_ids``, ``attention_mask`` and ``cell_types``.

    Returns:
        dict: The same five keys, each mapped to the list of that field's
        values in batch order. Any other keys of the samples are dropped.
    """
    field_names = ("id_notebook", "cell_ids", "input_ids", "attention_mask", "cell_types")
    return {field: [sample[field] for sample in batch] for field in field_names}

class ListWiseTrainer:
    """Train a list-wise cell-ordering model and keep the best validation weights.

    One notebook is one training example: both loaders use ``batch_size=1`` and
    only element 0 of each collated field is consumed.
    """

    def __init__(self, model, train_dataset, valid_dataset, device, save_dir, lr=1e-4, epochs=5):
        """
        Args:
            model: Module called as ``model(input_ids, attention_mask, cell_types)``
                that returns one score per cell.
            train_dataset: Dataset used for optimisation (shuffled every epoch).
            valid_dataset: Dataset used for the per-epoch Kendall tau.
            device: Device the model and the inputs are moved to.
            save_dir: Directory that receives ``best_model.pt``.
            lr: AdamW learning rate.
            epochs: Number of passes over ``train_dataset``.
        """
        self.model = model
        self.device = device
        self.epochs = epochs
        self.save_dir = save_dir

        self.train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=custom_collate_fn)
        self.valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, collate_fn=custom_collate_fn)

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        self.best_kendall = -999.0  # sentinel below any reachable score
        self.best_model_state = None

        self.train_losses = []
        self.valid_kendalls = []

    def train(self):
        """Run all epochs and write the best weights to ``<save_dir>/best_model.pt``.

        Returns:
            tuple: ``(train_losses, valid_kendalls)``, one entry per epoch.
        """
        self.model.to(self.device)
        for epoch in range(1, self.epochs+1):
            print(f"\nEpoch [{epoch}/{self.epochs}]")
            train_loss = self._train_one_epoch()
            val_kendall = self._validate()

            self.train_losses.append(train_loss)
            self.valid_kendalls.append(val_kendall)

            print(f"Train loss: {train_loss:.4f}, Valid Kendall Tau: {val_kendall:.4f}")

            if val_kendall > self.best_kendall:
                self.best_kendall = val_kendall

                # clone() is essential: for a tensor that already lives on the
                # CPU, .cpu() returns the very same tensor, so without a copy
                # the "best" snapshot would keep changing as training goes on.
                self.best_model_state = {
                    k: v.detach().cpu().clone() for k, v in self.model.state_dict().items()
                }
                print("New best model saved.")

        if self.best_model_state is not None:
            torch.save(self.best_model_state, os.path.join(self.save_dir, "best_model.pt"))
            print(f"Best model with Kendall Tau={self.best_kendall:.4f} saved.")

        return self.train_losses, self.valid_kendalls

    def _train_one_epoch(self):
        """Do one optimisation pass over the training loader.

        Returns:
            float: Mean loss per notebook (0.0 when the loader is empty).
        """
        self.model.train()
        total_loss = 0.0
        n_batches = 0

        for batch in tqdm(self.train_loader, desc="Training"):
            self.optimizer.zero_grad()
            # batch_size is 1, so element 0 is the whole notebook.
            input_ids = batch["input_ids"][0].to(self.device)
            att_mask  = batch["attention_mask"][0].to(self.device)
            cell_types= batch["cell_types"][0].to(self.device)
            N = input_ids.size(0)

            scores = self.model(input_ids, att_mask, cell_types)
            loss = listnet_loss(scores, N)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        return total_loss / max(1, n_batches)

    def _validate(self):
        """Rank the cells of every validation notebook and score the rankings.

        Returns:
            float: Kendall tau pooled over all validation notebooks.
        """
        self.model.eval()
        all_gt_orders = []
        all_pred_orders = []

        with torch.no_grad():
            for batch in self.valid_loader:
                # The cell ids as delivered by the dataset are taken as the
                # ground-truth order.
                cell_ids = list(batch["cell_ids"][0])

                input_ids = batch["input_ids"][0].to(self.device)
                att_mask = batch["attention_mask"][0].to(self.device)
                cell_types = batch["cell_types"][0].to(self.device)

                scores = self.model(input_ids, att_mask, cell_types)
                # Highest score is predicted to come first.
                idx_sorted = np.argsort(-scores.cpu().numpy())
                predicted_ids = [cell_ids[i] for i in idx_sorted]

                all_gt_orders.append(cell_ids)
                all_pred_orders.append(predicted_ids)

        return kendall_tau(all_gt_orders, all_pred_orders)

class ListWiseTester:
    """Evaluate a trained list-wise ordering model with the Kendall tau metric."""

    def __init__(self, model, device):
        """
        Args:
            model: Module called as ``model(input_ids, attention_mask, cell_types)``
                that returns one score per cell.
            device: Device the model and the inputs are moved to.
        """
        self.model = model
        self.device = device

    def test(self, test_dataset):
        """Rank the cells of every notebook in ``test_dataset`` and score the result.

        Args:
            test_dataset: Dataset whose items carry ``cell_ids``, ``input_ids``,
                ``attention_mask`` and ``cell_types``.

        Returns:
            float: Kendall tau pooled over all notebooks.
        """
        loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=custom_collate_fn)
        all_gt_orders = []
        all_pred_orders = []

        # Inputs are moved to self.device below, so the model has to live there
        # too; do not rely on the caller having moved it beforehand.
        self.model.to(self.device)
        self.model.eval()
        with torch.no_grad():
            for batch in loader:
                # batch_size is 1, so element 0 is the whole notebook. The ids
                # as delivered by the dataset are taken as the ground truth.
                cell_ids = list(batch["cell_ids"][0])

                input_ids = batch["input_ids"][0].to(self.device)
                att_mask = batch["attention_mask"][0].to(self.device)
                cell_types = batch["cell_types"][0].to(self.device)

                scores = self.model(input_ids, att_mask, cell_types)
                # Highest score is predicted to come first.
                idx_sorted = np.argsort(-scores.cpu().numpy())
                predicted_ids = [cell_ids[i] for i in idx_sorted]

                all_gt_orders.append(cell_ids)
                all_pred_orders.append(predicted_ids)

        return kendall_tau(all_gt_orders, all_pred_orders)


In [None]:
# Each encoder gets the tokenizer that belongs to its pretrained checkpoint.
code_tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
text_tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-uncased")

print("*"*80)
print("Reading data")

# One row per notebook id; cell_order is turned into a list of cell ids.
info = pd.read_csv(CONFIG["train_orders_path"], index_col="id")
info["cell_order"] = info["cell_order"].apply(str.split)
indeces = list(info.index)

# Seeded in-place shuffle of the notebook ids, so the split is repeatable.
rng = np.random.default_rng(CONFIG["random_seed"])
rng.shuffle(indeces)

# Consecutive slices of the shuffled ids: train | valid | test (the remainder).
n_notebooks = len(indeces)
train_border = int(CONFIG["train_size"] * n_notebooks)
valid_border = int((CONFIG["train_size"] + CONFIG["valid_size"]) * n_notebooks)

train_ids = indeces[:train_border]
valid_ids = indeces[train_border:valid_border]
test_ids = indeces[valid_border:]

train_data = info.loc[train_ids]
valid_data = info.loc[valid_ids]
test_data  = info.loc[test_ids]

# Debug mode keeps only the first few notebooks of every split.
if CONFIG["debug_mode"]:
    train_data = train_data.head(CONFIG["train_sample_size"])
    valid_data = valid_data.head(CONFIG["valid_sample_size"])
    test_data  = test_data.head(CONFIG["test_sample_size"])

print(f"Train size: {len(train_data)}")
print(f"Valid size: {len(valid_data)}")
print(f"Test size: {len(test_data)}")

# All three splits read from the same notebook folder with the same settings.
train_dataset, valid_dataset, test_dataset = (
    ListWiseCellDataset(
        path=CONFIG["train_notebooks_path"],
        data=split_frame,
        code_tokenizer=code_tokenizer,
        text_tokenizer=text_tokenizer,
        max_length=CONFIG["max_length"]
    )
    for split_frame in (train_data, valid_data, test_data)
)

# Seed PyTorch's global RNG so that the initialisation of the scoring head,
# dropout and the DataLoader shuffle order are repeatable between runs. The
# NumPy generator used for the split does not affect torch.
torch.manual_seed(CONFIG["random_seed"])

model = ListWiseOrderPredictionModel(
    hidden_dim=CONFIG["hidden_dim"],
    dropout_prob=CONFIG["dropout_prob"]
)

savedir = prepare_folders(savedir_name=CONFIG["savedir_name"])
device = get_device()
print(f"Using device: {device}")

trainer = ListWiseTrainer(
    model=model,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
    device=device,
    save_dir=savedir,
    lr=CONFIG["learning_rate"],
    epochs=CONFIG["epochs"]
)
train_losses, valid_kendalls = trainer.train()

# Learning curves: one point per epoch, saved next to the checkpoint.
fig, (loss_ax, tau_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(train_losses)
loss_ax.set_title('Training Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')

tau_ax.plot(valid_kendalls)
tau_ax.set_title('Validation Kendall Tau')
tau_ax.set_xlabel('Epoch')
tau_ax.set_ylabel('Kendall Tau')

fig.tight_layout()
fig.savefig(os.path.join(savedir, 'training_curves.png'))
plt.show()

# Evaluate with the weights of the best validation epoch, not the last epoch.
best_model_path = os.path.join(savedir, "best_model.pt")
best_weights = torch.load(best_model_path, map_location="cpu")
model.load_state_dict(best_weights)
model.to(device)

tester = ListWiseTester(model, device)
result = tester.test(test_dataset)
print("*"*80)
print(f"Test Kendall Tau score: {result:.4f}")
