In [None]:
# Notebook-level switches read by the rest of this cell.
offline_packages = True # for packages not included in Kaggle env as of 14.04.2023
offline_huggingface = True # when True, the pretrained backbone is read from a local /kaggle/input directory

# Install extra dependencies: from local wheel directories (--no-index / --find-links)
# when offline, otherwise from the default package index.
if offline_packages:
    !pip install faiss-gpu --no-index --find-links=file:///kaggle/input/faissgpu/
    !pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric --no-index --find-links=file:///kaggle/input/torch-geometric/
else:
    !pip install faiss-gpu torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric
# Prefix that is later concatenated with the backbone name to form the model path/identifier.
if offline_huggingface:
    pre_trained_models_dir = r'/kaggle/input/sentence-transformers/minilm-l6-v2/'
else:
    pre_trained_models_dir = r'sentence-transformers/'
# Root directory from which input and working paths are derived.
kaggle_dir = '/kaggle/'
# Sets the TOKENIZERS_PARALLELISM environment variable for this kernel.
%env TOKENIZERS_PARALLELISM=true

In [5]:
import os
import sys
import gc
import shutil
import time
import logging
import inspect
import psutil
import pickle
import numpy as np
import pandas as pd
import faiss
import torch
import torch.nn.functional as F
from torch import nn, Tensor, tensor
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Data, HeteroData
from torch_geometric.loader import NeighborLoader, NeighborSampler
from torch_geometric.nn import SAGEConv
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, AutoConfig, PreTrainedTokenizerFast, BatchEncoding, default_data_collator
from sklearn.model_selection import GroupKFold
from pathlib import Path
from typing import List
from dataclasses import dataclass, field
from itertools import chain, combinations
from functools import wraps
from datetime import datetime
from collections import defaultdict

In [None]:
# Config
@dataclass
class Config:
    """Central run configuration: paths, model choices, training and indexing hyperparameters.

    Requires the notebook-level names ``kaggle_dir`` and ``pre_trained_models_dir`` to
    exist before the class is created. The derived path defaults and the tokenizer are
    evaluated once, while the class body executes, so overriding ``home_dir`` or
    ``encoder_backbone`` on an instance does not update the values computed from them.
    """
    # General
    seed: int =                                17 # seed for all random algorithms
    eps: float =                            1e-12 # clamp zeros to eps where infinities can arise
    device: str =                           "cuda" if torch.cuda.is_available() else "cpu"
    home_dir: Path =                        Path(kaggle_dir)
    input_dir: Path =                       home_dir/'input/learning-equality-curriculum-recommendations'
    output_dir: Path =                      home_dir/'working/output'
    model_outputs_dir: Path =               output_dir/'model_outputs'
    checkpoints_dir: Path =                 output_dir/'checkpoints'
    tb_logs_dir: Path =                     output_dir/'tb_logs'
    logfile: Path =                         output_dir/'logfile.log'

    # Model architecture
    encoder_backbone: str =                  pre_trained_models_dir + 'all-MiniLM-L6-v2'
    neighborhood_sizes: List[int] =          field(default_factory=lambda: [6, 4, 2]) # no. neighboring topics to sample at each graph conv layer -- list length determines no. layers
    rerank_threshold: float =                0.55 # threshold for binarizing reranker classification scores

    # Topic/content representation
    max_seq_length: int =                      64 # truncation length for tokenized text
    use_topic_title: bool =                  True # use 'title' field in topic representation
    use_topic_descr: bool =                  True # use 'description' field in topic representation
    use_topic_level: bool =                  True # etc.
    use_content_title: bool =                True
    use_content_descr: bool =                True
    use_content_text: bool =                 True
    field_sep_token: str =                "[FLD]" # special token used in combining text fields
    tokenizer: PreTrainedTokenizerFast =   AutoTokenizer.from_pretrained(encoder_backbone, additional_special_tokens=[field_sep_token], use_fast=True)

    # Training
    # (the original annotation `int or bool` evaluates to plain `int`; False is itself an int instance)
    max_topics: int =                       False # max no. topics to load (for unit testing) -- set to False to disable sampling limit.
    retriever_batch_size: int =               256 # no. topics in biencoder batch
    reranker_batch_size: int =                128 # no. pairs in cross encoder batch
    content_batch_size_ratio: int =             1 # no. contents per topic in batch (currently deprecated)
    grad_accumulation_steps: int =              8 # no. batches to average before taking optimization step -- set to >1 to simulate larger batch size
    retriever_epochs: int =                    20 # no. training epochs for retriever (stage 1)
    reranker_epochs: int =                     10 # no. training epochs for reranker (stage 2)
    k_folds: int =                              4 # no. folds for GroupKFold cross-validation
    max_grad_norm: float =                    1.0 # clip gradient norms -- set to 0 to disable clipping
    learning_rate: float =                   5e-5 # max learning rate
    lr_decay_factor: float =                  0.5 # decay factor used by lr scheduler
    lr_decay_patience: int =                    3 # no. epochs of no improvement before reducing lr
    source_fold: int =                         99 # dummy fold to assign to topics with 'category'=='source' (must be larger than 'k_folds')
    shuffle: bool =                          True # shuffle samples in dataloaders
    use_amp: bool =                          True # use automatic mixed precision (if GPU available)
    cast_dtype: torch.dtype =                torch.float16 if torch.cuda.is_available() else torch.bfloat16 # cast data type used with AMP
    gradient_checkpointing: bool =           True # use gradient checkpointing

    # Indexing embeddings
    top_k: int =                               50 # no. candidate content recommendations to output at stage 1
    index_nlinks: int =                        32 # no. neighbors in HNSW quantizer graph
    index_efConstruction: int =                64 # depth of HNSW exploration at construction time
    index_efSearch: int =                      32 # depth of HNSW exploration at search time
    index_nclusters: int =                   1024 # no. k-means clusters in IVF index
    index_nprobe: int =                       128 # no. IVF clusters to probe at search time

    def DEV_MODE(self):
        """Shrink the configuration in place for a quick development/smoke-test run."""
        self.max_seq_length = 128
        self.max_topics = 3200
        # The declared field for "no. topics per batch" is `retriever_batch_size`;
        # previously only the undeclared `topic_batch_size` was written here, so the
        # declared field kept its full-size default. The legacy attribute is still
        # set in case other code reads it.
        self.retriever_batch_size = 128
        self.topic_batch_size = 128
        self.content_batch_size_ratio = 1
        self.top_k = 5
        self.grad_accumulation_steps = 2
        self.retriever_epochs = 1
        self.reranker_epochs = 1
        self.k_folds = 2
        self.index_nlinks = 16
        self.index_efConstruction = 16
        self.index_efSearch = 8
        self.index_nclusters = 16
        self.index_nprobe = 16

# Topic & content class definitions from competition hosts
class Topic:
    """Lightweight view of one row of the global ``topics_df``, addressed by topic id.

    Attribute access that is not resolved normally (e.g. ``title``) is forwarded to the
    corresponding column of ``topics_df``. Content lookups read the global
    ``correlations_df``. Instances compare and hash by ``id``.
    """
    def __init__(self, topic_id):
        self.id = topic_id
    @property
    def parent(self):
        """Parent ``Topic``, or ``None`` when the ``parent`` value is missing."""
        parent_id = topics_df.loc[self.id].parent
        if pd.isna(parent_id):
            return None
        else:
            return Topic(parent_id)
    @property
    def ancestors(self):
        """List of ancestor topics, nearest parent first and root last."""
        ancestors = []
        parent = self.parent
        while parent is not None:
            ancestors.append(parent)
            parent = parent.parent
        return ancestors
    @property
    def siblings(self):
        """Other children of this topic's parent (empty for a root topic)."""
        if not self.parent:
            return []
        else:
            return [topic for topic in self.parent.children if topic != self]
    def get_breadcrumbs(self, separator=" >> ", include_self=True, include_root=True):
        """Join the titles on the path from the root down to this topic with ``separator``."""
        ancestors = self.ancestors
        if include_self:
            ancestors = [self] + ancestors
        if not include_root:
            ancestors = ancestors[:-1]
        return separator.join(reversed([a.title for a in ancestors]))
    @property
    def children(self):
        """Topics whose ``parent`` column equals this topic's id."""
        return [Topic(child_id) for child_id in topics_df[topics_df.parent == self.id].index]
    def subtree_markdown(self, depth=0):
        """Render this topic, its descendants and their content as a nested markdown list."""
        markdown = "  " * depth + "- " + self.title + "\n"
        for child in self.children:
            markdown += child.subtree_markdown(depth=depth + 1)
        for content in self.content:
            markdown += ("  " * (depth + 1) + "- " + "[" + content.kind.title() + "] " + content.title) + "\n"
        return markdown
    def __eq__(self, other):
        if not isinstance(other, Topic):
            return False
        return self.id == other.id
    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; hash by id so equal topics
        # can be used in sets and as dict keys.
        return hash(self.id)
    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If `id` itself is missing
        # (e.g. on an instance created without __init__ during copy/unpickling) the
        # `self.id` below would re-enter this method forever, so fail fast instead.
        if name == 'id' or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(name)
        return topics_df.loc[self.id][name]
    def __str__(self):
        return self.title
    def __repr__(self):
        return f"<Topic(id={self.id}, title=\"{self.title}\")>"
    # NEW
    # Get all content items in the subtree of this topic
    @property
    def subtree_content(self):
        """List of content items attached to this topic and to all of its descendants."""
        # `content` can be an empty tuple (see get_content); copy into a list so that
        # extending with the children's lists cannot fail with a tuple/list mismatch.
        content = list(self.content)
        for child in self.children:
            content.extend(child.subtree_content)
        return content
    # Get root topic of this topic
    @property
    def root(self):
        """Top-most ancestor of this topic (the topic itself when it has no parent)."""
        if self.parent is None:
            return self
        else:
            return self.ancestors[-1]
    # EDITED
    # Originally 'content' property (made an instance attribute), now with option to get content ids only
    def get_content(self, ids=False):
        """Return the correlated content of this topic.

        With ``ids=True`` the stored list of content ids is returned as-is; otherwise
        each id is wrapped in a ``ContentItem``. For a topic without a row in
        ``correlations_df`` an empty tuple is returned when ``has_content`` is truthy
        and an empty list otherwise.
        """
        if self.id in correlations_df.index:
            if ids: return correlations_df.loc[self.id].content_ids
            return [ContentItem(content_id) for content_id in correlations_df.loc[self.id].content_ids]
        else:
            return tuple([]) if self.has_content else []
    @property
    def content(self):
        return self.get_content()
    @property
    def content_ids(self):
        return self.get_content(ids=True)
    
class ContentItem:
    """Lightweight view of one row of the global ``content_df``, addressed by content id.

    Attribute access that is not resolved normally (e.g. ``title``) is forwarded to the
    corresponding column of ``content_df``. Instances compare and hash by ``id``.
    """
    def __init__(self, content_id):
        self.id = content_id
    def __getattr__(self, name):
        # Fail fast when `id` itself is missing (instance created without __init__);
        # otherwise `self.id` below would re-enter this method without bound.
        if name == 'id' or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(name)
        return content_df.loc[self.id][name]
    def __str__(self):
        return self.title
    def __repr__(self):
        return f"<ContentItem(id={self.id}, title=\"{self.title}\")>"
    def __eq__(self, other):
        if not isinstance(other, ContentItem):
            return False
        return self.id == other.id
    def __hash__(self):
        # Keep instances hashable despite the custom __eq__.
        return hash(self.id)
    def get_all_breadcrumbs(self, separator=" >> ", include_root=True):
        """Return one breadcrumb string per correlated topic, each ending in this item's title."""
        breadcrumbs = []
        for topic in self.topics:
            new_breadcrumb = topic.get_breadcrumbs(separator=separator, include_root=include_root)
            if new_breadcrumb:
                new_breadcrumb = new_breadcrumb + separator + self.title
            else:
                new_breadcrumb = self.title
            breadcrumbs.append(new_breadcrumb)
        return breadcrumbs
    # EDITED
    @property
    def topics(self):
        """Topics whose ``content_ids`` list in ``correlations_df`` contains this item's id."""
        # Row-wise membership test on the content_ids column; replaces the deprecated
        # DataFrame.applymap + mask + dropna round trip.
        contains_self = correlations_df.content_ids.apply(lambda ids: self.id in ids)
        matching_ids = correlations_df.index[contains_self.to_numpy(dtype=bool)]
        return [Topic(topic_id) for topic_id in matching_ids]

# Global scope helper functions
def set_seed():
    """Seed torch's global RNG with ``cfg.seed`` and publish seeded generators.

    The global dict ``g`` is (re)bound to ``{'pt': torch.Generator, 'np': numpy Generator}``,
    both seeded with ``cfg.seed``.
    """
    global g
    torch_generator = torch.Generator()
    torch_generator.manual_seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    g = dict(pt=torch_generator, np=np.random.default_rng(cfg.seed))
    
def setup_output(clear_model_outputs=False, clear_checkpoints=False, clear_working=False):
    """Create the model-output and checkpoint directories, optionally wiping old data.

    ``clear_model_outputs`` must be exactly ``True`` to wipe the model-outputs directory;
    ``clear_checkpoints`` and ``clear_working`` are tested for truthiness.
    NOTE(review): the working directory is wiped *after* the two directories are
    created, so with ``clear_working=True`` anything under /kaggle/working (which is
    where the default output paths live) is removed again -- confirm this is intended.
    """
    def _recreate(directory, wipe):
        # Remove the directory first when requested, then make sure it exists.
        if wipe and directory.is_dir():
            shutil.rmtree(directory)
        directory.mkdir(parents=True, exist_ok=True)

    _recreate(cfg.model_outputs_dir, clear_model_outputs is True)
    _recreate(cfg.checkpoints_dir, clear_checkpoints)

    if not clear_working:
        return
    # Clear working directory
    for entry in Path('/kaggle/working').iterdir():
        if entry.is_file():
            os.remove(entry)
        elif entry.is_dir():
            shutil.rmtree(entry)

def setup_logger():
    """(Re)create the global ``logger`` with a console handler and a fresh log file.

    Any handlers left on the logger by a previous call are detached and closed first,
    so the old log file handle is released before the file is deleted and repeated
    calls do not accumulate open file descriptors. The log file at ``cfg.logfile`` is
    then removed (if present) and its parent directory created.
    """
    # Init global logger
    global logger
    logger = logging.getLogger(__name__)
    # Release handlers from an earlier call before touching the log file.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    if cfg.logfile.is_file(): os.remove(cfg.logfile)
    cfg.logfile.parents[0].mkdir(parents=True, exist_ok=True)
    logger.setLevel(logging.DEBUG)
    # Create handlers
    c_handler = logging.StreamHandler()
    f_handler = logging.FileHandler(cfg.logfile)
    c_handler.setLevel(logging.DEBUG)
    f_handler.setLevel(logging.DEBUG)
    # Format handlers
    c_format = logging.Formatter('%(levelname)s - %(message)s')
    f_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    c_handler.setFormatter(c_format)
    f_handler.setFormatter(f_format)
    # Add handlers to logger
    logger.addHandler(c_handler)
    logger.addHandler(f_handler)

def print_log():
    """Print the contents of the log file at ``cfg.logfile`` to stdout."""
    with open(cfg.logfile) as f:
        for line in f:
            # Lines read from the file keep their trailing newline; suppress print's
            # own newline so the log is not shown double-spaced.
            print(line, end='')
        
def timeit(func):
    """Decorator that logs (at DEBUG level, via the global ``logger``) how long ``func`` took.

    The duration is only logged when the call returns normally; the wrapped
    function's return value is passed through unchanged.
    """
    @wraps(func)
    def timed(*args, **kwargs):
        started = time.perf_counter()
        result = func(*args, **kwargs)
        finished = time.perf_counter()
        logger.debug(f'{func.__name__} took {finished - started:.6f} s to complete.')
        return result
    return timed

def id_to_int(item_ids, item_type='topic'):
    """Look up the ``num`` column for the given id(s) in ``topics_df`` or ``content_df``.

    A single ``str`` id is wrapped in a list first, so the result is always the
    ``.values`` array of the selected rows. Raises ``Exception`` for an unknown
    ``item_type``.
    """
    lookup_ids = [item_ids] if type(item_ids) is str else item_ids
    if item_type == 'topic':
        frame = topics_df
    elif item_type == 'content':
        frame = content_df
    else:
        raise Exception(f'item_type: {item_type} not implemented')
    return frame.loc[lookup_ids, 'num'].values
        
def int_to_id(item_nums, item_type='topic'):
    """Map positional number(s) back to ids using the index of ``topics_df`` or ``content_df``.

    A single Python ``int`` is wrapped in a list first. Raises ``Exception`` for an
    unknown ``item_type``.
    """
    positions = [item_nums] if type(item_nums) is int else item_nums
    if item_type == 'topic':
        id_index = topics_df.index
    elif item_type == 'content':
        id_index = content_df.index
    else:
        raise Exception(f'item_type: {item_type} not implemented')
    return id_index[positions]

def tensor_dict_to(tensor_dict, pin_memory=False, device=None):
    """Return a new dict with every tensor optionally pinned and/or moved to ``device``.

    When both options are given the tensor is pinned and then transferred with
    ``non_blocking=True``; with neither, the original tensors are reused as-is.
    """
    def _place(t):
        if pin_memory:
            pinned = t.pin_memory()
            return pinned if device is None else pinned.to(device, non_blocking=True)
        return t if device is None else t.to(device)

    return {key: _place(value) for key, value in tensor_dict.items()}
        
def get_ids(item_list):
    """Collect the ``id`` attribute of every item in ``item_list``, preserving order."""
    item_ids = []
    for entry in item_list:
        item_ids.append(entry.id)
    assert len(item_list) == len(item_ids), "Input and output lengths do not match"
    return item_ids
        
@timeit
def prepare_data(train=True):
    """Load the competition CSVs and build the module-level data structures.

    Assigns the globals listed in the ``global`` statement below: the topic/content
    frames, their memory-mapped token encodings, the topic-content edge index and
    graph, the topic-topic edge index and the neighbor sampler built from it.

    Args:
        train: when True, ``correlations.csv`` is loaded, optional subsampling is
            applied and cross-validation folds are assigned; the set of columns kept
            on ``topics_df`` also depends on this flag.

    NOTE(review): ``correlations_df`` is only assigned inside the ``train`` branch but
    is read unconditionally when the topic-content edge index is built, so
    ``train=False`` appears to rely on it already existing as a global -- confirm.
    """
    global topics_df, topic_encodings, content_df, content_encodings, correlations_df, sample_submission_df, tc_edge_index, tt_edge_index, neighbor_sampler, tc_graph
    
    # Load data
    logger.info('Loading data.')
    # Topics are indexed by 'topic_id'; missing titles/descriptions become empty strings.
    topics_df = pd.read_csv(cfg.input_dir/'topics.csv').rename(columns={'id': 'topic_id'}).set_index('topic_id').fillna({'title': '', 'description': ''})
    # Turn the level value into text ("Level <x>") so it can be used as a text field.
    topics_df['level'] = topics_df.level.apply(lambda x: f'Level {x}')
    # Content is indexed by 'content_id'; every missing value becomes an empty string.
    content_df = pd.read_csv(cfg.input_dir/'content.csv').rename(columns={'id': 'content_id'}).set_index('content_id').fillna('')
    sample_submission_df = pd.read_csv(cfg.input_dir/'sample_submission.csv')
    if train:
        correlations_df = pd.read_csv(cfg.input_dir/'correlations.csv').set_index('topic_id')
        # 'content_ids' is a space-separated string in the CSV; store it as a list per topic.
        correlations_df.content_ids = correlations_df.content_ids.str.split(' ')
    
        # Optional subsampling
        # NOTE(review): sample() is called without random_state, so the subset is
        # presumably not controlled by cfg.seed -- confirm.
        if cfg.max_topics: 
            topics_df = topics_df.sample(cfg.max_topics) # TODO: sample by channel to make compatible with building tt-graph (need all parents)
            # Keep only correlations of sampled topics that have content, and only the content they reference.
            correlations_df = correlations_df.loc[topics_df.loc[topics_df.has_content].index]
            content_ids = correlations_df.explode('content_ids').content_ids.unique()
            content_df = content_df.loc[content_ids]
        
        # Make GroupKFold
        logger.info('Making CV splits.')
        # Topics with category 'source' get the dummy fold cfg.source_fold; all other
        # topics are split into cfg.k_folds folds grouped by 'channel'.
        source_topics = topics_df.loc[topics_df.category == 'source'].index
        non_source_topics = topics_df.loc[topics_df.category != 'source'].index
        topics_df.loc[source_topics, 'fold'] = cfg.source_fold
        group_kfold = GroupKFold(cfg.k_folds)
        for fold, (_, val_inds) in enumerate(group_kfold.split(X=topics_df.loc[non_source_topics], groups=topics_df.loc[non_source_topics, 'channel'])):
            # A topic's fold is the index of the split in which it is in the validation part.
            topics_df.loc[non_source_topics[val_inds], 'fold'] = fold
        del group_kfold
        gc.collect()
        
    # Make topic/content text representations
    def make_repr(df, use_title=True, use_descr=False, use_text=False, use_level=False):
        """Build one string per row by joining the selected, non-empty fields with the field separator token.

        Field order is fixed: title, description, text, level.
        """
        fields = []
        if use_title: fields.append('title')
        if use_descr: fields.append('description')
        if use_text: fields.append('text')
        if use_level: fields.append('level')
        text = [df[field].to_list() for field in fields]
        # zip(*text) yields one tuple of field values per row; empty strings are skipped.
        text = [f' {cfg.field_sep_token} '.join([f for f in t if f != '']) for t in zip(*text)]
        return text
    
    # Tokenize text (in chunks) and map to disk (int16 ok since vocab_size < 32768)
    def memmap_encodings(df, path, **make_repr_args):
        """Tokenize all rows of ``df`` chunk by chunk and store the result in a read-only int16 memmap at ``path``.

        The tokenizer outputs are stacked along axis 0 and rows along axis 1
        (chunks are concatenated on axis 1).
        """
        def encode_chunk(offset=0, chunk_size=16384):
            text = make_repr(df.iloc[offset:offset+chunk_size], **make_repr_args)
            # Every sequence is padded/truncated to cfg.max_seq_length, so chunks stack cleanly.
            encodings = cfg.tokenizer(text, padding='max_length', truncation=True, max_length=cfg.max_seq_length, return_tensors='np')
            return np.array(list(encodings.values()), dtype=np.int16)

        values = encode_chunk()
        # values.shape[1] is the number of rows encoded so far; it doubles as the next offset.
        while values.shape[1] < len(df):
            values = np.concatenate((values, encode_chunk(offset=values.shape[1])), axis=1)
        encodings = np.memmap(path, mode='w+', shape=values.shape, dtype=np.int16)
        encodings[:] = values[:]
        encodings.flush()
        # Guard against accidental modification of the stored encodings.
        encodings.setflags(write=False)
        return encodings
    logger.info(f'Encoding topics ({len(topics_df)}).')
    topic_encodings = memmap_encodings(topics_df, cfg.output_dir/'topic_enc.bin', use_title=cfg.use_topic_title, use_descr=cfg.use_topic_descr, use_level=cfg.use_topic_level)
    logger.info(f'Encoding contents ({len(content_df)}).')
    content_encodings = memmap_encodings(content_df, cfg.output_dir/'content_enc.bin', use_title=cfg.use_content_title, use_descr=cfg.use_content_descr, use_text=cfg.use_content_text)
    
    # The text columns are no longer needed once encoded; keep only the columns used later.
    if train:
        # Keep training columns
        topics_df = topics_df.loc[:, ['fold', 'parent', 'language', 'has_content', 'title']]
        content_df = content_df.loc[:, ['kind', 'language', 'title']]
    else:
        # Keep inference columns
        topics_df = topics_df.loc[:, ['parent', 'language', 'title']]
        content_df = content_df.loc[:, ['kind', 'language', 'title']]
    
    # Build edge index for topic-content graph
    # 'num' is the row position of each item; it is what id_to_int() returns and it
    # matches the row order used for the encodings above.
    topics_df['num'] = range(len(topics_df))
    content_df['num'] = range(len(content_df))
    tc_edge_index = []
    # One edge per (topic, content) correlation: row 0 holds topic nums, row 1 content nums.
    tc_edge_index.append(id_to_int(correlations_df.explode('content_ids').index, item_type='topic'))
    tc_edge_index.append(id_to_int(correlations_df.explode('content_ids').content_ids, item_type='content'))
    tc_edge_index = tensor(np.vstack(tc_edge_index), dtype=torch.int)
    # Reorder the edge columns so that topic nums are ascending.
    _, sorted_idx = tc_edge_index[0].sort()
    tc_edge_index = tc_edge_index.gather(1, sorted_idx.expand(tc_edge_index.size()))

    # Heterogeneous graph whose node features are simply the item nums.
    tc_graph = HeteroData()
    tc_graph['topic'].x = torch.tensor(topics_df.num.values, dtype=torch.long).view(-1, 1)
    tc_graph['content'].x = torch.tensor(content_df.num.values, dtype=torch.long).view(-1, 1)
    tc_graph['topic', 'content'].edge_index = tc_edge_index.type(torch.long)
    
    # Build edge index and neighbor sampler for topic-topic graph (global ok since CV split by 'channel', i.e. train and val tt-graphs are disjoint)
    # Edges run child -> parent for every topic that has a parent.
    topic_parents = topics_df.loc[~topics_df['parent'].isna(), 'parent']
    tt_edge_index = [id_to_int(topic_parents.index), id_to_int(topic_parents.values)]
    tt_edge_index = tensor(np.vstack(tt_edge_index), dtype=torch.int)
    _, sorted_idx = tt_edge_index[0].sort()
    tt_edge_index = tt_edge_index.gather(1, sorted_idx.expand(tt_edge_index.size()))
    
    # Add the reversed edges so that neighbors are sampled in both directions.
    undirected_tt_edge_index = torch.cat((tt_edge_index, tt_edge_index.flip(0)), 1).type(torch.long)
    neighbor_sampler = NeighborSampler(edge_index=undirected_tt_edge_index, sizes=cfg.neighborhood_sizes, shuffle=True, return_e_id=False)
    
    logger.info(f"Successfully loaded and processed data.")

# Main
class MeanPooling(nn.Module):
    """Attention-mask-weighted mean of token embeddings over the sequence dimension (dim 1)."""
    def forward(self, token_embeddings, attention_mask):
        # Broadcast the mask over the embedding dimension so masked tokens contribute zero.
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        masked_sum = torch.sum(token_embeddings * mask, 1)
        # Clamp the token count to cfg.eps to avoid dividing by zero for fully masked rows.
        token_count = torch.clamp(mask.sum(1), min=cfg.eps)
        return masked_sum / token_count

class SentenceEncoder(nn.Module):
    """Pretrained transformer backbone (``cfg.encoder_backbone``) followed by mean pooling.

    The token embedding matrix is resized to the length of ``cfg.tokenizer``.
    """
    def __init__(self):
        super().__init__()
        hf_config = AutoConfig.from_pretrained(cfg.encoder_backbone)
        if cfg.gradient_checkpointing:
            hf_config.gradient_checkpointing = True
        model = AutoModel.from_pretrained(cfg.encoder_backbone, config=hf_config)
        model.resize_token_embeddings(len(cfg.tokenizer))
        self.config = hf_config
        self.backbone = model
        self.pool = MeanPooling()
    def forward(self, encodings):
        """Return one pooled embedding per input sequence."""
        token_states = self.backbone(**encodings).last_hidden_state
        return self.pool(token_states, encodings.attention_mask)

class SAGE(torch.nn.Module):
    """Stack of ``SAGEConv`` layers mapping ``in_out_channels`` -> hidden -> ``in_out_channels``.

    At least two layers (input and output) are always created; ``num_layers - 2``
    hidden-to-hidden layers are inserted between them.
    """
    def __init__(self, in_out_channels, hidden_channels, num_layers=3, aggr='mean'):
        super().__init__()
        self.num_layers = num_layers
        # Channel widths at the layer boundaries, e.g. [D, H, H, D] for three layers.
        widths = [in_out_channels] + [hidden_channels] * max(num_layers - 1, 1) + [in_out_channels]
        self.convs = torch.nn.ModuleList()
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.convs.append(SAGEConv(fan_in, fan_out, aggr))
    def reset_parameters(self):
        """Re-initialise every convolution layer."""
        for layer in self.convs:
            layer.reset_parameters()
    def forward(self, x, adjs):
        """Apply one convolution per sampled adjacency; activation and dropout follow all but the last layer."""
        last = self.num_layers - 1
        for depth, (edge_index, _, size) in enumerate(adjs):
            targets = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[depth]((x, targets), edge_index)
            if depth == last:
                continue
            x = F.gelu(x)
            x = F.dropout(x, p=0.1, training=self.training)
        return x
    
class TopicEncoder(nn.Module):
    """Sentence encoder followed by graph convolutions over sampled topic neighborhoods."""
    def __init__(self):
        super().__init__()
        self.sentence_encoder = SentenceEncoder()
        width = self.sentence_encoder.config.hidden_size
        n_layers = len(cfg.neighborhood_sizes)
        self.sage_conv = SAGE(width, hidden_channels=width, num_layers=n_layers, aggr='mean')
    def forward(self, encodings, adjs):
        """Return ``{'topic_emb': ...}`` with L2-normalised embeddings."""
        text_emb = self.sentence_encoder(encodings)
        graph_emb = self.sage_conv(text_emb, adjs)
        return dict(topic_emb=F.normalize(graph_emb, p=2, dim=1))
    
class ContentEncoder(nn.Module):
    """Sentence encoder followed by a square linear layer initialised to the identity map."""
    def __init__(self):
        super().__init__()
        self.sentence_encoder = SentenceEncoder()
        width = self.sentence_encoder.config.hidden_size
        projection = nn.Linear(width, width, bias=True)
        # Identity weight and zero bias: the layer initially passes embeddings through unchanged.
        nn.init.eye_(projection.weight)
        nn.init.zeros_(projection.bias)
        self.dense = projection
    def forward(self, encodings):
        """Return ``{'content_emb': ...}`` with L2-normalised embeddings."""
        projected = self.dense(self.sentence_encoder(encodings))
        return dict(content_emb=F.normalize(projected, p=2, dim=1))

class BiEncoder(nn.Module):
    """Pair of independent encoders: one for topics, one for contents."""
    def __init__(self):
        super().__init__()
        self.topic_encoder = TopicEncoder()
        self.content_encoder = ContentEncoder()
    def forward(self, topic_enc, adjs, content_enc):
        """Encode both sides and return ``{'topic_emb': ..., 'content_emb': ...}``."""
        topic_out = self.topic_encoder(topic_enc, adjs)
        content_out = self.content_encoder(content_enc)
        return {'topic_emb': topic_out['topic_emb'], 'content_emb': content_out['content_emb']}
    
class CrossEncoderClassifier(nn.Module):
    """Sequence-classification backbone with a single output logit per input pair."""
    def __init__(self):
        super().__init__()
        hf_config = AutoConfig.from_pretrained(cfg.encoder_backbone)
        hf_config.num_labels = 1
        if cfg.gradient_checkpointing:
            hf_config.gradient_checkpointing = True
        model = AutoModelForSequenceClassification.from_pretrained(cfg.encoder_backbone, config=hf_config)
        model.resize_token_embeddings(len(cfg.tokenizer))
        self.config = hf_config
        self.backbone = model
    def forward(self, cross_enc):
        """Return ``{'logits': ...}`` flattened to one dimension."""
        raw_logits = self.backbone(**cross_enc).logits
        return {'logits': raw_logits.view(-1)}

class RetrieverTrainingSet(Dataset):
    """Dataset of topic numbers; defaults to every ``num`` in the global ``topics_df``."""
    def __init__(self, topic_index=None):
        if topic_index is None:
            topic_index = topics_df.num.values
        self.topic_index = topic_index
    def __len__(self):
        return len(self.topic_index)
    def __getitem__(self, item_num):
        return self.topic_index[item_num]

class RerankerTrainingSet(Dataset):
    """Dataset of (topic-content pair, label) tuples for training the reranker."""
    def __init__(self, pairs, labels):
        assert pairs is not None, "Cannot create reranker train set without an index of topic-content pairs."
        self.pairs, self.labels = pairs, labels
    def __len__(self):
        return len(self.pairs)
    def __getitem__(self, item_num):
        pair = self.pairs[item_num]
        label = self.labels[item_num]
        return (pair, label)
    
class RetrieverTestSet(Dataset):
    """Dataset that serves entries of a given topic or content index one at a time."""
    def __init__(self, item_index):
        index_given = item_index is not None
        assert index_given, "Cannot create retriever test set without an index of topics or contents."
        self.item_index = item_index
    def __len__(self):
        return len(self.item_index)
    def __getitem__(self, item_num):
        entry = self.item_index[item_num]
        return entry
            
class RerankerTestSet(Dataset):
    """Dataset that serves topic-content pairs one at a time (no labels)."""
    def __init__(self, pairs):
        pairs_given = pairs is not None
        assert pairs_given, "Cannot create reranker test set without an index of topic-content pairs."
        self.pairs = pairs
    def __len__(self):
        return len(self.pairs)
    def __getitem__(self, item_num):
        pair = self.pairs[item_num]
        return pair

def get_pairs(content_matrix, topic_index=None):
    """Flatten a per-topic matrix of content nums into an int32 tensor of (topic, content) rows.

    Row ``i`` of ``content_matrix`` holds the contents paired with ``topic_index[i]``;
    ``topic_index`` defaults to every ``num`` in the global ``topics_df``.
    """
    if topic_index is None:
        topic_index = topics_df.num.values
    topics = topic_index if isinstance(topic_index, Tensor) else tensor(topic_index)
    contents = content_matrix if isinstance(content_matrix, Tensor) else tensor(content_matrix)
    topics = topics.type(torch.int32)
    contents = contents.type(torch.int32)
    # Repeat each topic num once per column of the content matrix, then pair element-wise.
    topic_column = topics.view(-1, 1, 1).expand(-1, contents.size(1), 1)
    content_column = contents.unsqueeze(-1)
    stacked = torch.cat((topic_column, content_column), dim=-1)
    return stacked.reshape(contents.numel(), 2)
    
def label_pairs(pairs):
    """Return a float label per pair: 1 when the (topic, content) pair is an edge of ``tc_edge_index``, else 0.

    Pairs are processed sorted by topic in chunks; the labels are returned in the
    original order of ``pairs``.
    """
    order = pairs.T[0].sort()[1]
    pairs = pairs[order]
    def label_chunk(offset=0, chunk_size=8192):
        chunk = pairs[offset:offset+chunk_size]
        # search only in subspace of edges containing query topics
        candidate_edges = get_subgraph_edge_index(chunk.T[0].unique(), tc_edge_index)
        matches = (chunk.unsqueeze(1) == candidate_edges.T).all(-1)
        return matches.any(1).type(torch.int8)
    labels = label_chunk()
    while labels.size(0) < pairs.size(0):
        next_chunk = label_chunk(offset=labels.size(0))
        labels = torch.cat((labels, next_chunk), axis=0)
    # Scatter the labels back to the positions the pairs had before sorting.
    restored = torch.zeros(labels.size())
    restored[order] = labels.type(torch.float32)
    return restored
    
def get_enc(item_nums, item_type='topic', out_type='tensor'):
    """Fetch the stored token encodings of the given items as an int tensor.

    Items are selected along axis 1 of the global ``topic_encodings`` /
    ``content_encodings`` array. With ``out_type='dict'`` the slices along axis 0 are
    wrapped in a ``BatchEncoding`` keyed ``input_ids``, ``token_type_ids`` and
    ``attention_mask``. Raises ``Exception`` for an unknown ``item_type``.
    """
    if item_type == 'topic':
        store = topic_encodings
    elif item_type == 'content':
        store = content_encodings
    else:
        raise Exception(f'item_type: {item_type} not implemented')
    enc_vals = tensor(store[:, item_nums, :], dtype=torch.int)
    if out_type != 'dict':
        return enc_vals
    named = dict(zip(['input_ids', 'token_type_ids', 'attention_mask'], enc_vals))
    return BatchEncoding(data=named)

def get_cross_enc(topic_nums, content_nums, out_type='tensor'):
    """Build joint encodings for aligned (topic, content) pairs.

    The topic and content encodings are concatenated along the token axis, the
    content tokens get token type 1 wherever their attention mask is 1, and tokens
    are then stably reordered so that attended tokens come first in each row.
    With ``out_type='dict'`` the result is wrapped in a ``BatchEncoding``.
    """
    assert len(content_nums) == len(topic_nums), 'No. topics and no. contents must match for cross encoding of topic-content pairs'
    topic_part = get_enc(topic_nums, 'topic', out_type='tensor')
    content_part = get_enc(content_nums, 'content', out_type='tensor')
    content_part[1] = content_part[2] # token type 1 for content
    joined = torch.cat((topic_part, content_part), dim=2)
    # Stable descending sort of the attention mask keeps token order while moving padding to the end.
    order = joined[2].sort(dim=1, descending=True, stable=True)[1]
    joined = joined.gather(2, order.expand(joined.size()))
    if out_type != 'dict':
        return joined
    named = dict(zip(['input_ids', 'token_type_ids', 'attention_mask'], joined))
    return BatchEncoding(data=named)
    
def get_subgraph_edge_index(sub_nodes, edge_index, node_type='source', return_mask=False):
    """Keep only the edges whose source (or target) node is in ``sub_nodes``.

    Args:
        sub_nodes: node numbers to keep (tensor or anything ``tensor()`` accepts).
            Duplicate entries are allowed.
        edge_index: tensor with one edge per column; row 0 holds sources, row 1 targets.
        node_type: ``'source'`` to filter on row 0, ``'target'`` to filter on row 1.
        return_mask: when True, also return the boolean column mask.

    Returns:
        The filtered edge index, or ``(filtered_edge_index, mask)`` when ``return_mask``.

    Raises:
        Exception: for an unknown ``node_type``.
    """
    if node_type == 'source':
        node_type = 0
    elif node_type == 'target':
        node_type = 1
    else:
        raise Exception(f'node_type: {node_type} not implemented.')
    if not isinstance(sub_nodes, Tensor):
        sub_nodes = tensor(sub_nodes)
    # Plain membership test. The previous "exactly one match" count dropped edges
    # whenever a node appeared more than once in sub_nodes, and materialised an
    # (edges x sub_nodes) temporary.
    idx_mask = torch.isin(edge_index[node_type], sub_nodes.to(edge_index.dtype))
    if return_mask:
        return edge_index[:, idx_mask], idx_mask
    return edge_index[:, idx_mask]

def sample_edges(edge_index):
    """Sample one random edge per distinct source node.

    Args:
        edge_index: tensor with one edge per column; row 0 holds the source nodes.

    Returns:
        An edge index with exactly one column per distinct source node, ordered by
        ascending source. An input without edges yields an empty edge index instead
        of raising.
    """
    # Nothing to sample from: the group-start lookup below would index into an empty tensor.
    if edge_index[0].nelement() == 0:
        return edge_index[:, :0].clone()
    # Shuffle the edges so that "first edge of each group" below is a random pick.
    idx = torch.randperm(edge_index[0].nelement())
    edge_perm = edge_index[:, idx]
    _, inv_idx, counts = edge_perm[0].unique(sorted=True, return_inverse=True, return_counts=True)
    # A stable sort of the group ids lists edge positions group by group, keeping the shuffled order inside each group.
    _, ind_sorted = inv_idx.sort(stable=True)
    # Offset of the first edge of every group within that grouped ordering.
    group_starts = counts.cumsum(0) - counts
    return edge_perm[:, ind_sorted[group_starts]]

def prepare_retriever_batch(topic_nums):
    """Collate a batch of topic nums into bi-encoder inputs and an in-batch label matrix.

    For every topic in the batch that has content, one correlated content item is
    sampled. The returned dict holds ``inputs=(topic_enc, adjs, content_enc)`` and
    ``labels``, a dense (no. topics x no. sampled contents) matrix with 1 where the
    topic-content pair is an edge of ``tc_graph``. Everything is moved to ``cfg.device``.
    """
    if not isinstance(topic_nums, Tensor):
        topic_nums = tensor(topic_nums, dtype=torch.long)
    else:
        topic_nums = topic_nums.type(torch.long)

    # Make in-batch positive and negative labels
    # Edges of the batch topics to any content item.
    sub_edge_index = tc_graph.subgraph({'topic': topic_nums,
                                        'content': tc_graph['content'].x.squeeze()})\
                                            ['topic', 'content'].edge_index
    # Map the subgraph's topic indices back to global topic nums.
    # NOTE(review): presumably subgraph() relabels nodes by ascending original index,
    # which is why the sorted nums are used as the lookup table -- confirm against
    # the torch_geometric documentation.
    sub_edge_index[0] = topic_nums.sort()[0][sub_edge_index[0]]
    # sub_edge_index2 = get_subgraph_edge_index(topic_nums, tc_edge_index) # edges /\ topic subset
    # print(torch.equal(sub_edge_index,sub_edge_index2))

    # Either:
    topics_w, content_nums = sample_edges(sub_edge_index).type(torch.long) # one edge per topic with content
    topics_wo = sorted(list(set(topic_nums.tolist()) - set(topics_w.tolist()))) # topics w/o content
    topic_nums = torch.cat((topics_w, tensor(topics_wo, dtype=torch.long))) # topics w/ content + topics w/o content
    # Or: (assumes all topics have content and performs content sampling randomly)
#     unique_contents = sub_edge_index[1].unique()
#     rand_nums = torch.randperm(unique_contents.size(0))[:int(topic_nums.size(0)*cfg.content_batch_size_ratio)] # sample contents randomly (can control total number)
#     content_nums = unique_contents[rand_nums]
    #

    # All edges between the batch topics and the sampled contents (not only the sampled pairs).
    sub_edge_index = tc_graph.subgraph({'topic': topic_nums,
                                        'content': content_nums})\
                                            ['topic', 'content'].edge_index
    # Map both rows back to global nums, as above.
    # NOTE(review): content_nums can contain the same content more than once (two
    # topics may sample the same item). If subgraph() relabels over the distinct
    # nodes, indexing into the sorted list *with duplicates* would be misaligned --
    # verify.
    sub_edge_index[0] = topic_nums.sort()[0][sub_edge_index[0]]
    sub_edge_index[1] = content_nums.sort()[0][sub_edge_index[1]]
    # sub_edge_index2 = get_subgraph_edge_index(content_nums, sub_edge_index2, 'target') # edges /\ content subset
    # print(torch.equal(sub_edge_index,sub_edge_index2))
    
    # Lookup tables from global num to position inside this batch (-1 for items not in the batch).
    dummy_t_idx = -torch.ones(len(topics_df), dtype=torch.long)
    dummy_c_idx = -torch.ones(len(content_df), dtype=torch.long)
    dummy_t_idx[topic_nums.type(torch.long)] = torch.arange(0, end=topic_nums.size(0), dtype=torch.long)
    dummy_c_idx[content_nums.type(torch.long)] = torch.arange(0, end=content_nums.size(0), dtype=torch.long)
    sub_edge_index[0] = dummy_t_idx[sub_edge_index[0].type(torch.long)]
    sub_edge_index[1] = dummy_c_idx[sub_edge_index[1].type(torch.long)] # w/ dummy indices up to (topic_nums.size(0), content_nums.size(0)) s.t. sparse matrix is not of size(len(topics_df), len(content_df))
    # Dense 0/1 label matrix built from the in-batch edge list.
    labels = torch.sparse_coo_tensor(sub_edge_index, torch.ones(sub_edge_index.size(1)), (topic_nums.size(0), content_nums.size(0))).to_dense()
    
    _, n_id, adjs = neighbor_sampler.sample(topic_nums.type(torch.long)) # n_id: topic_nums (first) + neighbors at all depths required for graph conv (after)
    # Encodings are fetched for the batch topics and all of their sampled neighbors.
    topic_enc = get_enc(n_id, item_type='topic', out_type='dict')
    content_enc = get_enc(content_nums, item_type='content', out_type='dict')
    
    # Pinned memory is only requested when the target device is CUDA.
    topic_enc.data = tensor_dict_to(topic_enc, pin_memory=(cfg.device == 'cuda'), device=cfg.device)
    content_enc.data = tensor_dict_to(content_enc, pin_memory=(cfg.device == 'cuda'), device=cfg.device)     
    if cfg.device == 'cuda':
        labels = labels.to_dense().pin_memory().to('cuda', non_blocking=True)
    else:
        labels = labels.to(cfg.device) 
    adjs = [adj.to(cfg.device) for adj in adjs]
    return dict(inputs=(topic_enc, adjs, content_enc), labels=labels)

def prepare_reranker_batch(tuples):
    """Collate reranker training samples into a single batch on ``cfg.device``.

    Args:
        tuples: list of ``(pair, label)`` samples, i.e.
            ``[([t_num, c_num], label), ...]``. Both elements must be tensors,
            since they are combined with ``torch.stack``.

    Returns:
        dict with
        ``inputs``: 1-tuple holding the cross-encoder encoding of all pairs,
        ``labels``: float32 tensor of the stacked labels.
    """
    pair_list, label_list = zip(*tuples)
    on_cuda = cfg.device == 'cuda'

    # Stack the pairs and hand the columns of the stacked matrix
    # (presumably topic numbers, content numbers) to the cross-encoder tokenizer.
    pair_matrix = torch.stack(pair_list)
    encoding = get_cross_enc(*pair_matrix.T, out_type='dict')
    encoding.data = tensor_dict_to(encoding, pin_memory=on_cuda, device=cfg.device)

    targets = torch.stack(label_list).type(torch.float32)
    if on_cuda:
        # Pinned host memory allows the asynchronous host-to-device copy.
        targets = targets.pin_memory().to(cfg.device, non_blocking=True)
    else:
        targets = targets.to(cfg.device)

    return dict(inputs=(encoding,), labels=targets)

def prepare_retriever_test_topic_batch(topic_nums):
    """Collate topic numbers into an inference batch for the topic encoder.

    Args:
        topic_nums: tensor of topic numbers, or anything ``torch.tensor`` can
            convert (converted to int32).

    Returns:
        dict with ``inputs``: 1-tuple holding the topic encoding on ``cfg.device``.
    """
    nums = topic_nums if isinstance(topic_nums, Tensor) else tensor(topic_nums, dtype=torch.int32)
    encoding = get_enc(nums, item_type='topic', out_type='dict')
    use_pinned = cfg.device == 'cuda'
    encoding.data = tensor_dict_to(encoding, pin_memory=use_pinned, device=cfg.device)
    return {'inputs': (encoding,)}

def prepare_retriever_test_content_batch(content_nums):
    """Collate content numbers into an inference batch for the content encoder.

    Args:
        content_nums: tensor of content numbers, or anything ``torch.tensor``
            can convert (converted to int32).

    Returns:
        dict with ``inputs``: 1-tuple holding the content encoding on ``cfg.device``.
    """
    nums = content_nums if isinstance(content_nums, Tensor) else tensor(content_nums, dtype=torch.int32)
    encoding = get_enc(nums, item_type='content', out_type='dict')
    use_pinned = cfg.device == 'cuda'
    encoding.data = tensor_dict_to(encoding, pin_memory=use_pinned, device=cfg.device)
    return {'inputs': (encoding,)}

def prepare_reranker_test_batch(pairs):
    """Collate candidate pairs into an inference batch for the reranker.

    Args:
        pairs: sequence of pair tensors (stacked with ``torch.stack``).

    Returns:
        dict with ``inputs``: 1-tuple holding the cross-encoder encoding on
        ``cfg.device``.
    """
    pair_matrix = torch.stack(pairs)
    # Columns of the stacked matrix are passed as separate positional arguments.
    encoding = get_cross_enc(*pair_matrix.T, out_type='dict')
    use_pinned = cfg.device == 'cuda'
    encoding.data = tensor_dict_to(encoding, pin_memory=use_pinned, device=cfg.device)
    return {'inputs': (encoding,)}

@torch.no_grad()
def infer_pred(loader, model):
    """Run ``model`` in eval mode over every batch of ``loader`` and gather its outputs.

    Args:
        loader: iterable of batches; each batch is a dict whose ``'inputs'``
            entry is a tuple of positional arguments for ``model``.
        model: callable module returning a dict of tensors per batch.

    Returns:
        defaultdict mapping each output name to a NumPy array holding the
        per-batch outputs concatenated along dim 0 (empty if the loader
        yields no batches).

    Note:
        The model is left in eval mode afterwards.
    """
    pred = defaultdict(list)
    model.eval()
    for batch in loader:
        with torch.autocast(device_type=cfg.device, dtype=cfg.cast_dtype):
            outputs = model(*batch['inputs'])
        for name, value in outputs.items():
            # Offload each batch right away so predictions for the whole
            # dataset are never held in device memory at once.
            value = value.detach().cpu()
            # NumPy has no bfloat16; upcast so the final .numpy() cannot fail
            # when autocast produces bfloat16 outputs.
            if value.dtype == torch.bfloat16:
                value = value.float()
            pred[name].append(value)
    for name, chunks in pred.items():
        pred[name] = torch.cat(chunks).numpy()
    return pred

def similarity_nll(topic_emb, content_emb, labels):
    """Negative log-likelihood of the labelled contents under a softmax over similarities.

    For every topic (row) the likelihood is
    ``sum_c labels[t, c] * exp(s[t, c]) / sum_c exp(s[t, c])`` with
    ``s = topic_emb @ content_emb.T``, clamped from below at ``cfg.eps``.

    Args:
        topic_emb: (T, D) topic embeddings.
        content_emb: (C, D) content embeddings.
        labels: weights multiplied element-wise with the (T, C) similarity
            matrix (expected to be the dense topic-content label matrix).

    Returns:
        Scalar tensor: mean over topics of ``-log(likelihood)``.
    """
    _, emb_dim = content_emb.size()
    assert topic_emb.size()[1] == emb_dim, "Topic and content embedding dimensionalities do not match"
    logits = topic_emb @ content_emb.T
    # The ratio below is invariant to a per-row shift, so subtract the row
    # maximum before exponentiating: exp() can then never overflow to inf
    # (which would turn the likelihood into inf/inf = NaN).
    logits = logits - logits.max(dim=1, keepdim=True).values.detach()
    similarity = torch.exp(logits)
    likelihood = ((similarity * labels).sum(1) / similarity.sum(1)).clamp(cfg.eps)
    return -torch.log(likelihood).mean()

def logits_bce(logits, labels, pos_weight=2.0):
    """Binary cross-entropy on raw logits with up-weighted positive samples.

    Args:
        logits: raw (pre-sigmoid) scores.
        labels: float targets with the same shape as ``logits``.
        pos_weight: multiplier applied to the loss of positive samples.
            Defaults to 2, the value that was previously hard-coded.

    Returns:
        Scalar tensor with the mean weighted loss.
    """
    # Build the weight on the logits' device/dtype instead of instantiating a
    # loss module and a fresh CPU integer tensor on every call.
    weight = torch.as_tensor(pos_weight, dtype=logits.dtype, device=logits.device)
    return F.binary_cross_entropy_with_logits(logits, labels, pos_weight=weight)

def retrieve_top_k_contents(topic_emb, content_emb):
    """Find the ``cfg.top_k`` most similar contents for every topic with a faiss index.

    The index is an inverted-file index (``cfg.index_nclusters`` lists,
    ``cfg.index_nprobe`` probed per query) using inner-product similarity and
    an HNSW coarse quantizer.

    Args:
        topic_emb: (T, D) array of query embeddings.
        content_emb: (C, D) array of embeddings to index.

    Returns:
        (T, cfg.top_k) integer array of row positions into ``content_emb``.
        NOTE(review): faiss pads a row with -1 when fewer than ``top_k``
        neighbours are found - confirm that callers handle this.
    """
    def as_faiss_matrix(vectors):
        # faiss operates on C-contiguous float32 matrices; embeddings produced
        # under autocast may be half precision, so coerce explicitly.
        return np.ascontiguousarray(vectors, dtype=np.float32)

    def index_vectors(vectors):
        d = vectors.shape[-1]
        quantizer = faiss.IndexHNSWFlat(d, cfg.index_nlinks, faiss.METRIC_INNER_PRODUCT)
        quantizer.hnsw.efConstruction = cfg.index_efConstruction
        quantizer.hnsw.efSearch = cfg.index_efSearch
        index = faiss.IndexIVFFlat(quantizer, d, cfg.index_nclusters, faiss.METRIC_INNER_PRODUCT)
        index.train(vectors)
        index.add(vectors)
        index.nprobe = cfg.index_nprobe
        return index

    index = index_vectors(as_faiss_matrix(content_emb))
    _, top_k_idx = index.search(as_faiss_matrix(topic_emb), cfg.top_k)
    return top_k_idx

def rerank_top_k_contents(logits):
    """Turn reranker logits into a boolean keep-mask.

    Args:
        logits: tensor, NumPy array or sequence of raw reranker scores.

    Returns:
        Boolean NumPy array of the same shape as ``logits``: True where
        ``sigmoid(logit) > cfg.rerank_threshold``.
    """
    # as_tensor avoids copying inputs that already are tensors / arrays.
    logits = torch.as_tensor(logits)
    # Upcast first: logits produced under autocast can be half precision, and
    # sigmoid is not guaranteed to be implemented for half tensors on CPU.
    probabilities = torch.sigmoid(logits.float())
    return (probabilities > cfg.rerank_threshold).cpu().numpy()

def reco_pairs_to_series(pairs):
    """Group recommended (topic, content) pairs into one list of content ids per topic.

    Args:
        pairs: 2-column array/tensor; column 0 holds topic numbers and column 1
            content numbers (both converted to ids with ``int_to_id``).

    Returns:
        Series named ``content_ids`` mapping topic id -> list of recommended
        content ids. Topics from ``topics_df`` without any recommendation are
        appended with an empty list.
    """
    topic_ids, content_ids = int_to_id(pairs[:, 0], 'topic'), int_to_id(pairs[:, 1], 'content')
    recos = pd.Series(content_ids, index=topic_ids, name='content_ids').groupby(level=0).agg(list)

    topics_wo = topics_df.index[~topics_df.index.isin(topic_ids)].values
    # Build the empty-list series directly. The former DataFrame(...).squeeze()
    # collapsed to a scalar when exactly one topic had no recommendations,
    # which made the subsequent .apply() call fail.
    wo_recos = pd.Series(
        [[] for _ in range(len(topics_wo))],
        index=pd.Index(topics_wo, name='topic_id'),
        name='content_ids',
        dtype=object,
    )
    return pd.concat([recos, wo_recos], axis=0)

def evaluate_recommendations(recommendations: pd.Series):
    """Compute mean recall, precision and F2 score of per-topic recommendations.

    Args:
        recommendations: Series indexed by topic id whose values are the
            collections of recommended content ids.

    Returns:
        dict with the means over all topics of ``recall``, ``precision`` and
        ``f2score`` (``5 * P * R / (4 * P + R)`` computed per topic).
    """
    recall, precision = [], []
    for topic_id, recommended in recommendations.items():
        # Look the ground truth up once per topic and reuse it for both metrics.
        relevant = Topic(topic_id).content_ids
        n_hits = float(len(set(relevant).intersection(set(recommended))))  # relevant /\ recommended
        # eps in numerator and denominator: recall is 1 for topics without
        # relevant content and never exactly 0, which keeps the F2
        # denominator strictly positive.
        recall.append((n_hits + cfg.eps) / (len(relevant) + cfg.eps))
        precision.append(n_hits / (len(recommended) + cfg.eps))
    recall, precision = np.array(recall), np.array(precision)
    f2score = 5 * precision * recall / (4 * precision + recall)
    return dict(recall=recall.mean(), precision=precision.mean(), f2score=f2score.mean())

class StagedModel:
    """Model, loss, optimizer, AMP grad scaler and LR scheduler of one training stage.

    Stage 1 is the ``'retriever'`` (``BiEncoder`` trained with ``similarity_nll``),
    stage 2 the ``'reranker'`` (``CrossEncoderClassifier`` trained with ``logits_bce``).
    """
    def __init__(self, stage):
        """Build everything needed to train ``stage`` (1 or 2) on ``cfg.device``.

        Raises:
            NotImplementedError: for any other stage number.
        """
        if stage == 1:
            self.name = 'retriever'
            self.model = BiEncoder().to(cfg.device)
            self.loss_fn = similarity_nll
        elif stage == 2: # Reranker
            self.name = 'reranker'
            self.model = CrossEncoderClassifier().to(cfg.device)
            self.loss_fn = logits_bce
        else:
            raise NotImplementedError(f'Stage {stage} not implemented.')
        self.optimizer = torch.optim.Adam(self.model.parameters(), cfg.learning_rate)
        # Gradient scaling is only active for mixed precision on CUDA.
        self.scaler = torch.cuda.amp.GradScaler(enabled=bool(cfg.use_amp and cfg.device == "cuda"))
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=cfg.lr_decay_factor, patience=cfg.lr_decay_patience)
        self.stage_num = stage

    def load_checkpoint(self, checkpoint):
        """Restore model, optimizer and scheduler state from a checkpoint dict."""
        self.model.load_state_dict(checkpoint['model'])
        self.model.to(cfg.device)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])

    def load_model_outputs(self, path):
        """Read all batches appended to ``path`` by ``save_model_outputs`` and merge them.

        Args:
            path: file containing one or more consecutively pickled dicts.

        Returns:
            dict with, for every key of the first batch, the values of all
            batches concatenated along axis 0 (a single-batch file is returned
            as stored).

        Raises:
            AssertionError: if embeddings/logits and their indices differ in length.
        """
        # SECURITY: pickle can execute arbitrary code - only load files written by this notebook.
        with open(path, 'rb') as f:
            data = pickle.load(f)
            chunks = None
            while True:
                try:
                    next_batch = pickle.load(f)
                except EOFError:
                    break
                if chunks is None:
                    chunks = {k: [data[k]] for k in data}
                for k in chunks:
                    chunks[k].append(next_batch[k])
        if chunks is not None:
            # Concatenate once per key instead of re-copying the growing array
            # for every batch read.
            for k, parts in chunks.items():
                data[k] = np.concatenate(parts, axis=0)
        length = {k: len(data[k]) for k in data}
        if self.name == 'retriever':
            assert length['topic_emb'] == length['topic_ind'], f"Number of topic embeddings does not match size of topic index.\nLengths: {length}.\nData: {data}"
            assert length['content_emb'] == length['content_ind'], f"Number of content embeddings does not match size of content index.\nLengths: {length}.\nData: {data}"
        elif self.name == 'reranker':
            assert length['logits'] == length['topic_ind'] == length['content_ind'], f"Number of logits does not match size of topic-content pairs index.\nLengths: {length}.\nData: {data}"
        return data

    def save_model_outputs(self, path, **data):
        """Append one batch of model outputs to ``path`` as a pickled dict.

        Tensor values are stored as NumPy arrays under their own key. Values
        without ``detach`` are treated as collections of items and converted
        with ``get_ids``: for the retriever under their own key, for the
        reranker as a sequence of pairs split into ``topic_ind`` / ``content_ind``.
        """
        data_dump = {}
        for k, v in data.items():
            try:
                data_dump[k] = v.detach().cpu().numpy()
            except AttributeError:
                if self.name == 'retriever':
                    data_dump[k] = np.array(get_ids(v))
                elif self.name == 'reranker':
                    data_dump['topic_ind'], data_dump['content_ind'] = [np.array(get_ids(items)) for items in zip(*v)]
        with open(path, mode='ab') as f:
            pickle.dump(data_dump, f)

    def make_recommendations(self, model_outputs, pairs=None):
        """Convert model outputs into recommended (topic, content) pairs.

        Args:
            model_outputs: dict of outputs, passed as keyword arguments to
                ``retrieve_top_k_contents`` (retriever) or
                ``rerank_top_k_contents`` (reranker).
            pairs: candidate pairs to filter; required for the reranker.

        Returns:
            The retrieved top-k pairs (retriever) or the subset of ``pairs``
            kept by the rerank mask (reranker).
        """
        if self.name == 'retriever':
            top_k_content_matrix = retrieve_top_k_contents(**model_outputs)
            return get_pairs(top_k_content_matrix)
        elif self.name == 'reranker':
            assert pairs is not None, "Missing candidate pairs to rerank."
            reco_mask = rerank_top_k_contents(**model_outputs)
            return pairs[reco_mask]
        
class StagedTrainer:
    """Runs the two training stages (1: retriever, 2: reranker) for one CV fold.

    Keeps the per-stage loss history, tracks the best checkpoint by
    ``self.tracked`` (validation loss) and passes the recommendations of
    stage 1 on to stage 2 through ``self.recos``.
    """
    def __init__(self, fold=None):
        """Select train/validation topics for ``fold`` and reset the output directories."""
        self.fold = fold
        self.epoch = 0
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        # self.train_ids = topics_df.loc[topics_df.fold != fold].num.values
        self.train_ids = topics_df.loc[(topics_df.has_content==True) & (topics_df.fold != fold)].num.values # train only with topics which have content (more +ive samples)
        self.val_ids = topics_df.loc[topics_df.fold == fold].num.values
        self.loader_args = dict(
            shuffle=cfg.shuffle,
            generator=g['pt'],
            # num_workers=4,
        )
        self.history = dict(retriever=defaultdict(list), reranker=defaultdict(list))
        self.history_idx_offset = 0
        # No default factory: behaves like a plain dict (missing keys raise KeyError).
        self.recos = defaultdict(None)
        # Missing entries read as +inf so that the first tracked value always
        # counts as an improvement; checkpoint paths are stored here as well.
        self.best_model = defaultdict(lambda: float('inf'))
        self.tracked = 'val_loss'
        self.log_precision = 5
        self.keep_best_only = True
        self.stage_epochs = {1: cfg.retriever_epochs,
                             2: cfg.reranker_epochs}
        setup_output(clear_model_outputs=True)

    @staticmethod
    def _merge_true_pairs(pairs, labels, true_pairs):
        """Add ``true_pairs`` (label 1) to ``pairs`` and deduplicate, keeping labels aligned.

        ``Tensor.unique(dim=0)`` returns the rows sorted, so the labels are
        re-assembled through the inverse index instead of being appended in
        the old order. Labels are assumed to be 0/1 (a duplicated positive
        stays 1).

        Returns:
            (unique pairs, labels aligned with them, in the dtype of ``labels``).
        """
        all_pairs = torch.cat((pairs, true_pairs.type(pairs.dtype)), dim=0)
        all_labels = torch.cat(
            (labels.type(torch.float32),
             torch.ones(true_pairs.size(0), dtype=torch.float32, device=labels.device)),
            dim=0,
        )
        merged_pairs, inverse = all_pairs.unique(dim=0, return_inverse=True)
        merged_labels = torch.zeros(merged_pairs.size(0), dtype=torch.float32, device=all_labels.device)
        merged_labels.index_add_(0, inverse, all_labels).clamp_(max=1)
        return merged_pairs, merged_labels.type(labels.dtype)

    def setup_stage(self, stage, from_ckpt=None):
        """Create datasets, loaders, model and TensorBoard writer for ``stage``.

        Args:
            stage: 1 (retriever) or 2 (reranker; needs the stage 1 results in
                ``self.recos``).
            from_ckpt: optional checkpoint path to resume from. Its metrics are
                put into the history and recommendations are regenerated from it.
        """
        # Prepare data
        if stage == 1:
            self.train_set = RetrieverTrainingSet(self.train_ids)
            self.val_set = RetrieverTrainingSet(self.val_ids)
            self.loader_args.update(batch_size=cfg.retriever_batch_size, collate_fn=prepare_retriever_batch)
        elif stage == 2:
            # Get reco pairs from stage 1
            pairs = self.recos.get('stage_1_pairs')
            labels = self.recos.get('stage_1_labels')
            assert pairs is not None, "Attempting stage 2 with no results from stage 1. Run stage 1 first."
            _, train_mask = get_subgraph_edge_index(self.train_ids, pairs.T, return_mask=True)
            train_pairs, train_labels = pairs[train_mask], labels[train_mask]
            val_pairs, val_labels = pairs[~train_mask], labels[~train_mask]
            # Add any missing true pairs to train subset (s.t. no. positive samples independent of stage 1 recall)
            true_train_edge_index = get_subgraph_edge_index(self.train_ids, tc_edge_index)
            train_pairs, train_labels = self._merge_true_pairs(train_pairs, train_labels, true_train_edge_index.T)
            # Create train and val sets
            self.train_set = RerankerTrainingSet(train_pairs, train_labels)
            self.val_set = RerankerTrainingSet(val_pairs, val_labels)
            self.loader_args.update(batch_size=cfg.reranker_batch_size, collate_fn=prepare_reranker_batch)
        self.train_loader = DataLoader(self.train_set, **self.loader_args)
        self.val_loader = DataLoader(self.val_set, **self.loader_args)
        # Load model
        self.stage = StagedModel(stage)
        if from_ckpt is not None:
            # map_location lets checkpoints written on GPU be resumed on a CPU-only machine.
            checkpoint = torch.load(from_ckpt, map_location=cfg.device)
            for key, value in checkpoint.items():
                if key.startswith('curr_'):
                    self.history[self.stage.name][key[5:]].append(value)
            # The checkpoint metrics occupy history slot 0, epoch 1 goes to slot 1.
            self.history_idx_offset = 1
            self.update_best_model()
            self.best_model[f'{self.stage.name}_ckpt_path'] = from_ckpt # override update method
            self.best_model[f'{self.stage.name}_prev_ckpt_path'] = from_ckpt # override update method
            best_tracked_val = round(self.best_model[f'{self.stage.name}_{self.tracked}'], self.log_precision)
            logger.info(f'Resuming from checkpoint: {from_ckpt} ({self.tracked}: {best_tracked_val})')
            self.make_recos()
        else:
            self.history_idx_offset = 0
        # Logs
        self.tb_writer = SummaryWriter(self.tblg_path)
        logger.info(f'Stage {stage} ({self.stage.name}) setup complete. Training size: {len(self.train_set)}. Validation size: {len(self.val_set)}. Batch size: {self.loader_args["batch_size"]}.')

    @property
    def mode(self):
        """``'train'`` while the stage model is in training mode, else ``'val'``."""
        return 'train' if self.stage.model.training else 'val'

    @property
    def mout_path(self):
        """Path of the model-outputs pickle for the current stage, epoch and mode."""
        return cfg.model_outputs_dir/f'{self.stage.name}_outs_{self.timestamp}_{self.epoch}_{self.mode}.pkl'

    @property
    def tblg_path(self):
        """TensorBoard log directory of the current stage and fold."""
        return cfg.tb_logs_dir/f'{self.stage.name}_trainer_fold{self.fold}_{self.timestamp}'

    def ckpt_path(self, stage=None, epoch=None):
        """Checkpoint path for ``stage`` (name) and ``epoch``; both default to the current ones."""
        if stage is None:
            stage = self.stage.name
        if epoch is None:
            epoch = self.epoch
        return cfg.checkpoints_dir/f'{stage}_ckpt_{self.timestamp}_{epoch}.pt'

    # @timeit
    def run_epoch(self):
        """Run one pass over the train or validation loader, depending on ``self.mode``.

        In train mode gradients are accumulated over
        ``cfg.grad_accumulation_steps`` batches before each optimizer step.

        Returns:
            Mean batch loss of the epoch (also appended to the history).

        Raises:
            ValueError: if the selected loader yields no batches.
        """
        # Initialize
        mode = self.mode
        loader = self.train_loader if mode == 'train' else self.val_loader
        n_batches = len(loader)
        if n_batches == 0:
            raise ValueError(f'Cannot run a {mode} epoch: the {mode} loader is empty.')
        loader = iter(loader)
        i, batch = 1, next(loader)
        total_loss = 0.

        while i <= n_batches:
            # Forward-backward update with gradient accumulation
            # NOTE: the loss is not divided by the number of accumulation
            # steps, i.e. accumulated gradients are summed, not averaged.
            for _ in range(cfg.grad_accumulation_steps):
                if i > n_batches: break # catch overflow in last accumulation round

                # Forward pass
                with torch.autocast(device_type=cfg.device, dtype=cfg.cast_dtype):
                    outputs = self.stage.model(*batch['inputs'])
                    loss = self.stage.loss_fn(**outputs, labels=batch['labels'])
                    total_loss += loss.detach()

                # Prefetch next batch
                if i < n_batches: batch = next(loader)

                # Backward pass
                if mode == 'train':
                    self.stage.scaler.scale(loss).backward()

                i += 1

            # Optimizer step
            if mode == 'train':
                if cfg.max_grad_norm != 0:
                    # Gradients must be unscaled before clipping by norm.
                    self.stage.scaler.unscale_(self.stage.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.stage.model.parameters(), cfg.max_grad_norm)
                self.stage.scaler.step(self.stage.optimizer)
                self.stage.scaler.update()
                self.stage.optimizer.zero_grad(set_to_none=True)

        # Self and global log
        epoch_loss = total_loss.item()/n_batches
        self.history[self.stage.name][f'{mode}_loss'].append(epoch_loss)
        return epoch_loss

    def metrics(self, split=None):
        """List the metric keys recorded for the current stage.

        Args:
            split: ``None`` -> all keys (e.g. ``train_loss``);
                ``'train'`` / ``'val'`` -> keys of that split;
                ``'name'`` -> unique metric names without split prefix, in
                first-seen order;
                any other string -> every key split on that string.
        """
        all_metrics = list(self.history[self.stage.name].keys())
        if split is None:
            return all_metrics
        if split in ('train', 'val'):
            return [m for m in all_metrics if f'{split}_' in m]
        if split == 'name':
            # train_x and val_x share the name x: report it once so every
            # scalar is written to TensorBoard only once per epoch.
            return list(dict.fromkeys(m.split('_')[-1] for m in all_metrics))
        else:
            return [m.split(split) for m in all_metrics]

    def get_metric(self, metric):
        """Value of ``metric`` for the current epoch of the current stage."""
        return self.history[self.stage.name].get(metric)[self.epoch - 1 + self.history_idx_offset]

    def update_best_model(self):
        """Record the current epoch as best if the tracked metric strictly improved."""
        if self.get_metric(self.tracked) < self.best_model[f'{self.stage.name}_{self.tracked}']:
            self.best_model[f'{self.stage.name}_{self.tracked}'] = self.get_metric(self.tracked)
            self.best_model[f'{self.stage.name}_prev_ckpt_path'] = self.best_model[f'{self.stage.name}_ckpt_path']
            self.best_model[f'{self.stage.name}_ckpt_path'] = self.ckpt_path()

    def save_checkpoint(self):
        """Write a checkpoint for the current epoch.

        With ``keep_best_only`` nothing is written unless the current epoch
        holds the best tracked value, and the previous best checkpoint file is
        deleted first.
        """
        if self.keep_best_only:
            if self.best_model[f'{self.stage.name}_{self.tracked}'] != self.get_metric(self.tracked):
                return
            prev_best = self.best_model[f'{self.stage.name}_prev_ckpt_path']
            # prev_best is +inf (the defaultdict default) until a checkpoint exists.
            if isinstance(prev_best, Path) and prev_best.is_file():
                os.remove(prev_best)
        checkpoint = {
            'model': self.stage.model.state_dict(),
            'optimizer': self.stage.optimizer.state_dict(),
            'scheduler': self.stage.scheduler.state_dict(),
            'fold': self.fold,
            'stage': self.stage.name,
            'epoch': self.epoch,
            f'best_{self.tracked}': self.best_model.get(f'{self.stage.name}_{self.tracked}'),
            **{f'curr_{m}': self.get_metric(m) for m in self.metrics()},
            'config': cfg,
        }
        torch.save(checkpoint, self.ckpt_path())
        logger.info(f'Checkpoint saved to: {self.ckpt_path()}')
        return

    def write_tb_log(self):
        """Write train/val values of every metric of the current epoch to TensorBoard."""
        metric_logs = [
            (f"Epochs/{self.stage.name}_{m}", {'train': self.get_metric(f'train_{m}'),
                                              'val':   self.get_metric(f'val_{m}')})
                for m in self.metrics('name')
        ]
        for named_dict in metric_logs: self.tb_writer.add_scalars(*named_dict, self.epoch)
        self.tb_writer.flush()

    def make_recos(self):
        """Reload the best checkpoint of the stage and store its recommendations in ``self.recos``."""
        # Make recommendations based on best model predictions
        best_model_ckpt = torch.load(self.best_model[f'{self.stage.name}_ckpt_path'], map_location=cfg.device)
        self.stage.load_checkpoint(best_model_ckpt)
        if self.stage.name == 'retriever':
            pairs = None
            test_set = RetrieverTrainingSet(topics_df.num)
        elif self.stage.name == 'reranker':
            pairs = self.recos['stage_1_pairs']
            labels = self.recos['stage_1_labels']
            test_set = RerankerTrainingSet(pairs, labels)
        # Work on a copy: updating self.loader_args in place would silently
        # disable shuffling for every loader created afterwards.
        test_loader_args = {**self.loader_args, 'shuffle': False}
        test_loader = DataLoader(test_set, **test_loader_args)

        logger.info(f'Inferring {self.stage.name} predictions for all {len(test_set)} samples.')
        model_outputs = infer_pred(test_loader, self.stage.model)
        reco_pairs = self.stage.make_recommendations(model_outputs, pairs)
        self.recos[f'stage_{self.stage.stage_num}_pairs'] = reco_pairs
        self.recos[f'stage_{self.stage.stage_num}_labels'] = label_pairs(reco_pairs)

    def eval_recos(self):
        """Score the recommendations of the current stage and log/store the metrics."""
        # Evaluate recos
        reco_pairs = self.recos[f'stage_{self.stage.stage_num}_pairs']
        recos = reco_pairs_to_series(reco_pairs)
        eval_metrics = evaluate_recommendations(recos) # double (and possibly slower//TESTED: no bottleneck here) computation of precision (can get from label_pairs())
        to_logger = []
        for m, v in eval_metrics.items():
            self.best_model[f'{self.stage.name}_{m}'] = v
            to_logger.append(f'{m}: {round(v, self.log_precision)}.')
        logger.info(' '.join(to_logger))

    def train(self):
        """Train every stage with a non-zero epoch count, then make and evaluate its recommendations."""
        logger.info(f'Begin training. Fold: {self.fold}. Top-k: {cfg.top_k}. Gradient checkpointing: {cfg.gradient_checkpointing}. Automatic mixed precision: {cfg.use_amp}. GPU available: {cfg.device=="cuda"}.')
        for stage, n_epochs in self.stage_epochs.items():
            if n_epochs == 0: continue
            # Setup stage (if not already done manually)
            current_stage = getattr(getattr(self, 'stage', None), 'stage_num', None)
            if current_stage != stage:
                self.setup_stage(stage)
            # Epoch loop
            for epoch in range(1, 1 + int(n_epochs)):
                self.epoch = epoch
                # Train
                self.stage.model.train(True)
                train_loss = self.run_epoch()
                # Validate
                self.stage.model.train(False)
                with torch.no_grad():
                    val_loss = self.run_epoch()
                logger.info(f'Epoch {self.epoch} train loss: {round(train_loss, self.log_precision)}, val loss: {round(val_loss, self.log_precision)}')
                self.update_best_model()
                # TB log
                self.write_tb_log()
                # Save checkpoint
                self.save_checkpoint()
                # Scheduler step
                self.stage.scheduler.step(self.get_metric(self.tracked))
                # Clear memory
                if cfg.device=='cuda': torch.cuda.empty_cache()
            # Make recommendations for next stage and evaluate current results
            self.make_recos()
            self.eval_recos()
        self.epoch = 0
            
if __name__ == '__main__':
    # Init config
    cfg = Config()
    set_seed()
    setup_logger()
    profiler = False
    tensorboard = False
    
    # Quick overrides
#     cfg.DEV_MODE()
#     cfg.max_topics = False
#     prepare_data(train=True)
    
    # Sessions
    training = False
    inference = True
    
    logger.info('START MAIN')
        
    # Training
    if training:
        setup_output(
            clear_model_outputs=False,
            clear_checkpoints=False,
            clear_working=False, # NOTE: moved output/ into working/
        )
        prepare_data(train=True)
        if tensorboard:
            %reload_ext tensorboard
            %tensorboard --logdir cfg.tb_logs_dir
        START_TIME = datetime.now()
        logger.info(f'START TRAINING - {START_TIME}')
        # K-folds loop
        for fold in range(cfg.k_folds):
            fold = 1 # manual fold
            trainer = StagedTrainer(fold)
            # Trainer overrides for resuming training
#             trainer.setup_stage(1, cfg.checkpoints_dir/'retriever_ckpt_20230314_192550_1.pt')
#             trainer.setup_stage(2)
            trainer.stage_epochs = {
                1: 0, # no. epochs for stage 1
                2: 1, # no. epochs for stage 2 
            }
            if profiler:
                %reload_ext line_profiler
                %lprun -f prepare_retriever_batch -f prepare_reranker_batch -f trainer.run_epoch -f trainer.make_eval_recos -f trainer.train trainer.train()
            else:
                trainer.train()
            best_model_str = "\n".join([f"    {k}: {v}" for k,v in trainer.best_model.items()])
            logger.info(f"################ Fold {fold} best performance: ################\n{best_model_str}")
            break # one fold only
        logger.info('Checkpoints stored:\n' + '\n'.join(['    ' + p for p in sorted(os.listdir(cfg.checkpoints_dir))]))
        
        END_TIME = datetime.now()
        logger.info(f'END TRAINING  - {END_TIME} (+{END_TIME - START_TIME})')
    
    # Inference
    if inference:
        # Setup
        setup_output()
        prepare_data(train=False)
        retriever_ckpt = torch.load(Path(r'/kaggle/input/lecr-tuned-models/...BEST_STAGE_2_CKPT...'))
        reranker_ckpt = torch.load(Path(r'/kaggle/input/lecr-tuned-models/...BEST_STAGE_2_CKPT...'))
        
        # Stage 1
        retriever_model = BiEncoder()
#         retriever_model.load_state_dict(retriever_ckpt['model'])
        retriever_model.to(cfg.device)

        topic_set = RetrieverTestSet(topics_df.num.values)
        content_set = RetrieverTestSet(content_df.num.values)
        topic_loader = DataLoader(topic_set, shuffle=False, batch_size=cfg.retriever_batch_size, collate_fn=prepare_retriever_test_topic_batch)
        content_loader = DataLoader(content_set, shuffle=False, batch_size=cfg.retriever_batch_size, collate_fn=prepare_retriever_test_content_batch)

        topic_emb = infer_pred(topic_loader, retriever_model.topic_encoder)['topic_emb']
        content_emb = infer_pred(content_loader, retriever_model.content_encoder)['content_emb']
        top_k_content_matrix = retrieve_top_k_contents(topic_emb, content_emb)
        stage_1_pairs = get_pairs(top_k_content_matrix, topic_set.item_index)

        # Stage 2
        reranker_model = CrossEncoderClassifier()
#         reranker_model.load_state_dict(reranker_ckpt['model'])
        reranker_model.to(cfg.device)

        pair_set = RerankerTestSet(stage_1_pairs)
        pair_loader = DataLoader(topic_set, shuffle=False, batch_size=cfg.reranker_batch_size, collate_fn=prepare_reranker_test_batch)

        logits = infer_pred(pair_loader, reranker_model)['logits']
        reco_mask = rerank_top_k_contents(logits)
        stage_2_pairs = stage_1_pairs[reco_mask]

        # Final output
        recos = reco_pairs_to_series(stage_2_pairs).apply(lambda x: ' '.join(x))
        recos = pd.DataFrame({'topic_id': recos.index, 'content_ids': recos.values})
        recos = recos.loc[recos.topic_id.isin(sample_submission_df.topic_id)].reset_index(drop=True)
        if 1: # Submission file
            recos.to_csv(output_dir/'submission.csv', index=False)
        else:
            recos
    
    logger.info('END MAIN')

In [None]:
# !zip -r checkpoints.zip {cfg.checkpoints_dir}