In [None]:
import numpy as np
import nltk
from tqdm.notebook import tqdm
from glob import glob
import fasttext
from navec import Navec

from transformers import AutoTokenizer, AutoModel, BertTokenizer, BertModel
from sklearn.model_selection import train_test_split, KFold
import torch
from torch.utils.data import Dataset, DataLoader
import torch.functional as F
from torch import nn
import torchmetrics
import pytorch_lightning as pl

from warnings import filterwarnings
# NOTE(review): "ignore" with no category silences every warning for the whole
# kernel session, including library deprecation notices (e.g. for the Trainer
# arguments used further down) — consider narrowing by category or module.
filterwarnings("ignore")

## Data

In [None]:
from model import CustomDataset, FinalModel

In [None]:
# Pretrained Russian word-embedding models, passed to CustomDataset below.
navec_model = Navec.load("data/navec_hudlit_v1_12B_500K_300d_100q.tar")
fasttext_model = fasttext.load_model("data/cc.ru.300.bin")

# Training files (.npy). glob() returns entries in filesystem order, which is
# not guaranteed to be stable across machines/runs; sorting makes the KFold
# split (random_state fixed) reproducible.
# Alternative source used in earlier experiments (train2 + pseudo-labelled test):
#   sorted(glob("data/augmentations/train2/*.npy") + glob("data/augmentations/test_pseudo/*.npy"))
train_data = sorted(glob("data/augmentations/train/*.npy"))
# Fail fast with a clear message instead of a cryptic KFold error on an empty list.
assert train_data, "no training files found in data/augmentations/train/"

In [None]:
# K-fold cross-validation over `train_data`. Each fold gets fresh datasets,
# a fresh model and fresh callbacks, is trained, and is saved as a checkpoint.
seed = 42
kfold = KFold(n_splits=5, random_state=seed, shuffle=True)
sent_size = 112   # passed to CustomDataset; presumably max sentence length — confirm
batch_size = 128
n_folds_to_run = 1  # only the first fold is trained; set to kfold.get_n_splits() for full CV

# Tuned hyperparameters; identical for every fold, so defined once.
params = {"lr": 0.00068437, "weight_decay": 0.052502,
          "hidden_size": 303, "bidirectional": False,
          "drop_lstm": 0.0048089, "drop_linear": 0.35191,
          "linear1_size": 723, "linear2_size": 494}

for idx, (train_idx, val_idx) in enumerate(kfold.split(train_data)):
    if idx >= n_folds_to_run:
        break

    # Seed python/numpy/torch so weight init and batch order are reproducible per fold.
    pl.seed_everything(seed + idx)

    # split — KFold yields integer index arrays; index directly rather than
    # testing `i in array` for every file (O(n) per test, O(n^2) overall).
    train_files = [train_data[i] for i in train_idx]
    val_files = [train_data[i] for i in val_idx]

    # data
    dataset_train = CustomDataset(train_files, sent_size, True, navec_model, fasttext_model)
    dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
    # NOTE(review): the third argument is True for validation as well — confirm it
    # does not enable training-only behaviour (e.g. augmentation).
    dataset_val = CustomDataset(val_files, sent_size, True, navec_model, fasttext_model)
    # Validation order should not matter (presumably val_f1 is computed over the
    # whole epoch — confirm); shuffling it only adds nondeterminism.
    dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False)

    # model
    model = FinalModel(**params)

    # utils — callbacks keep internal state, so they are recreated per fold.
    lr_monitoring = pl.callbacks.LearningRateMonitor(logging_interval="epoch")
    early_stop_callback = pl.callbacks.early_stopping.EarlyStopping(monitor="val_f1", min_delta=0.0001,
                                                                    patience=3, verbose=False, mode="max")
    logger = pl.loggers.TensorBoardLogger(save_dir="logs", name="final_model", version=f"fold_{idx}")

    # train
    trainer = pl.Trainer(gpus=1, max_epochs=15, logger=logger,
                         callbacks=[lr_monitoring, early_stop_callback], weights_summary=None)
    trainer.fit(model, dataloader_train, dataloader_val)

    # save model
    trainer.save_checkpoint(f"data/models/final_model_{idx}.ckpt", weights_only=True)