In [None]:
import torch
import os
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from tqdm import tqdm

from torch.utils.data import DataLoader
from torch import nn
import torch
# Register tqdm's pandas integration so Series.progress_apply (used by classificationDataset) exists.
tqdm.pandas()

In [None]:
import typing as tp
import contextlib
import pathlib
import json

class Logger:
    """Small file-based experiment tracker.

    Every ``*.json`` file found (recursively) under ``logs_path`` is loaded as one run, keyed by
    the file name without extension, into ``self.leaderboard`` (a DataFrame indexed by ``'id'``).
    New runs are opened with :meth:`run` and saved to ``<logs_path>/<run name>/<run name>.json``.
    """

    def __init__(self, logs_path: tp.Union[str, os.PathLike]):
        self.path = pathlib.Path(logs_path)

        records = []
        for root, _dirs, files in os.walk(self.path):
            for file in files:
                if not file.lower().endswith('.json'):
                    continue
                run_id = os.path.splitext(file)[0]
                with open(os.path.join(root, file), 'r') as f:
                    try:
                        logged_data = json.load(f)
                    except json.JSONDecodeError:
                        # Best effort: skip corrupted / partially written log files.
                        continue
                # Only JSON objects can be turned into a leaderboard row.
                if isinstance(logged_data, dict):
                    records.append({'id': run_id, **logged_data})
        if records:
            self.leaderboard = pd.DataFrame.from_records(records, index='id')
        else:
            self.leaderboard = pd.DataFrame(index=pd.Index([], name='id'))

        self._current_run = None

    class Run:
        """A single run; logged values are written straight into the owning Logger's leaderboard."""

        def __init__(self, name, storage, path):
            self.name = name
            self._storage = storage  # the Logger's leaderboard DataFrame, mutated in place by log()
            self._path = path
            # No row is pre-created: the first log() call adds it through .loc enlargement.
            # (DataFrame.append returned a copy, so it never added the row, and it was removed
            # in pandas 2.0.)

        def log(self, key, value):
            """Store ``value`` in column ``key`` of this run's leaderboard row."""
            self._storage.loc[self.name, key] = value

        def log_values(self, log_values: tp.Dict[str, tp.Any]):
            """Log every ``key: value`` pair of ``log_values``."""
            for key, value in log_values.items():
                self.log(key, value)

        def save_logs(self):
            """Write this run's non-missing values to ``<path>/<name>.json`` (``{}`` if nothing was logged)."""
            if self.name in self._storage.index:
                # dropna(): the row also holds NaN for columns that only other runs logged.
                record = self._storage.loc[self.name].dropna().to_dict()
            else:
                record = {}
            with open(self._path / f'{self.name}.json', 'w') as f:
                json.dump(record, f, default=self._to_builtin)

        @staticmethod
        def _to_builtin(obj):
            # numpy / torch scalars such as np.int64 are not JSON serialisable; unwrap via .item().
            if hasattr(obj, 'item'):
                return obj.item()
            raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')

        def log_artifact(self, fname: str, writer: tp.Callable):
            """Open ``<path>/<fname>`` for binary writing and hand the file object to ``writer``."""
            with open(self._path / fname, 'wb+') as f:
                writer(f)

    @contextlib.contextmanager
    def run(self, name: tp.Optional[str] = None):
        """Open a run, yield its :class:`Logger.Run` handle and save its logs on exit.

        Spaces in ``name`` are replaced by underscores; a random UUID is used when ``name`` is
        None. Raises ``NameError`` if a run with the normalised name is already on the
        leaderboard. Logs are saved even when the ``with`` body raises.
        """
        import uuid  # not among the notebook's imports; needed for auto-generated run names

        if name is None:
            name = str(uuid.uuid4())
        else:
            # Normalise first so that "my run" collides with an existing "my_run".
            name = name.replace(' ', '_')
            if name in self.leaderboard.index:
                raise NameError("Run with given name already exists, name should be unique")

        run_dir = self.path / name
        os.makedirs(run_dir, exist_ok=True)
        current = Logger.Run(name, self.leaderboard, run_dir)
        self._current_run = current
        try:
            yield current
        finally:
            current.save_logs()


logger = Logger('./logs')

In [None]:
# cuda:4 is the GPU of the original training machine. Override with the DEVICE environment
# variable (e.g. DEVICE=cuda:0); without CUDA the default falls back to the CPU.
default_device = "cuda:4" if torch.cuda.is_available() else "cpu"
device = torch.device(os.environ.get("DEVICE", default_device))

In [None]:
from torch.utils.data import Dataset, DataLoader

In [None]:
import transformers

In [None]:
# Hugging Face Hub id of the checkpoint; used for both the tokenizer and the model below.
model_name: str = "cointegrated/rubert-tiny2"

In [None]:
from transformers import BertTokenizer

In [None]:
# do_lower_case=False keeps the input text's original casing during tokenization.
tokenizer = BertTokenizer.from_pretrained(
    model_name,
    do_lower_case=False,
)

In [None]:
class classificationDataset(Dataset):
    """Text-classification dataset that tokenizes a whole CSV file up front.

    The CSV must have a ``text`` column (model input) and a ``new_target`` column (label).
    Items are ``(input_ids, token_type_ids, attention_mask, label)``: the three tensors are the
    tokenizer outputs with size-1 dims squeezed away (with ``padding='max_length'`` they are
    expected to have length ``MAX_LEN``).
    """

    def __init__(self, data_path, tokenizer, MAX_LEN):
        # All preprocessing (CSV loading + tokenization) is done here, once.
        super().__init__()
        self.df = pd.read_csv(data_path)
        self.labels = self.df.new_target

        def encode(text):
            return tokenizer.encode_plus(
                str(text),  # str() also turns missing values (NaN) into text instead of failing
                add_special_tokens=True,
                truncation=True,
                padding='max_length',
                max_length=MAX_LEN,
                return_attention_mask=True,
                return_tensors='pt',
            )

        texts = self.df.text
        # progress_apply only exists after tqdm.pandas(); fall back to a plain apply otherwise.
        apply = getattr(texts, 'progress_apply', texts.apply)
        self.token = apply(encode)

    def __getitem__(self, index):
        encoded = self.token[index]
        input_ids = torch.squeeze(encoded['input_ids'])
        token_type_ids = torch.squeeze(encoded['token_type_ids'])
        attention_mask = torch.squeeze(encoded['attention_mask'])
        return input_ids, token_type_ids, attention_mask, self.labels[index]

    def __len__(self):
        return len(self.df)

In [None]:
# NOTE(review): `train` / `val` are rebound to the training/validation functions in a later cell;
# the DataLoaders capture these dataset objects before that happens.
train = classificationDataset(data_path="labeled_train.csv", tokenizer=tokenizer, MAX_LEN=128)
val = classificationDataset(data_path="labeled_val.csv", tokenizer=tokenizer, MAX_LEN=128)

In [None]:
batch_size = 32

# NOTE(review): `train` / `val` must still be the datasets here; a later cell rebinds both names
# to functions, so re-running this cell after that one fails.

# Training: shuffled batches, the incomplete last batch is dropped.
dataloader_train = DataLoader(
    dataset=train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# Validation: every sample in a fixed order. Shuffling cannot improve evaluation; it only made
# batch composition (and any per-batch statistic) vary between passes.
dataloader_val = DataLoader(
    dataset=val,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)

In [None]:
from transformers import BertForSequenceClassification, AdamW, BertConfig, get_linear_schedule_with_warmup


model = BertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,  # explicit 2-class head, matching the binary metrics used for evaluation
    output_attentions=False,
    output_hidden_states=False,
)

# Freeze the BERT backbone so only the classification head trains at first;
# learning_loop re-enables gradients for model.bert at a later epoch.
for param in model.bert.parameters():
    param.requires_grad = False

# torch.optim.AdamW replaces the deprecated transformers.AdamW. weight_decay=0.0 keeps the
# transformers default (torch's default is 0.01). All parameters are registered so the backbone
# starts updating once unfrozen; AdamW skips parameters whose .grad is None until then.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,
    eps=1e-8,
    weight_decay=0.0,
)

epochs = 4

# The scheduler is stepped once per optimizer step (train() steps it after every batch).
total_steps = len(dataloader_train) * epochs

scheduler = transformers.get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

# Move the model to the selected device; ';' suppresses the (very long) model repr.
model.to(device);

In [None]:
import random
from torchmetrics import Accuracy, Precision, Recall
seed_val = 42

# Seed Python's, NumPy's and PyTorch's (CPU + all CUDA devices) global RNGs.
# NOTE(review): the model, including its randomly initialised classification head, is created
# in an earlier cell, so that initialisation is not covered by this seed.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed_val)
del _seed_fn

def _as_predictions(logits):
    """Class ids: argmax over dim 1 for (N, C) scores; 1-D input is taken as class ids already."""
    return logits.argmax(dim=1) if logits.dim() > 1 else logits


def _safe_ratio(numerator, denominator):
    """``numerator / denominator`` as a float tensor, or 0 when the denominator is 0."""
    if int(denominator) == 0:
        return torch.zeros((), device=numerator.device)
    return numerator / denominator


def flat_accuracy(logits, labels):
    """Fraction of correctly classified samples, as a 0-d tensor.

    ``logits`` may be (N, C) class scores or (N,) predicted class ids.
    """
    preds = _as_predictions(logits)
    return (preds == labels).sum() / len(labels)


def flat_recall(logits, labels):
    """Binary recall TP / (TP + FN), class 1 positive; 0 when ``labels`` has no positives."""
    preds = _as_predictions(logits)
    true_pos = ((preds == 1) & (labels == 1)).sum()
    return _safe_ratio(true_pos, (labels == 1).sum())


def flat_precision(logits, labels):
    """Binary precision TP / (TP + FP), class 1 positive; 0 when nothing is predicted positive."""
    preds = _as_predictions(logits)
    true_pos = ((preds == 1) & (labels == 1)).sum()
    return _safe_ratio(true_pos, (preds == 1).sum())


In [None]:
from collections import defaultdict

def _mean_or_nan(values):
    """Mean of a list of floats, NaN for an empty list (e.g. an empty loader)."""
    return float(np.mean(values)) if values else float('nan')


def _binary_metrics(pred_chunks, label_chunks, metric_names):
    """Dataset-level binary accuracy / precision / recall (class 1 positive, 0 on zero division).

    Only the names present in ``metric_names`` are returned, in the order
    accuracy, precision, recall.
    """
    preds = torch.cat(pred_chunks) if pred_chunks else torch.empty(0, dtype=torch.long)
    labels = torch.cat(label_chunks) if label_chunks else torch.empty(0, dtype=torch.long)
    n_samples = labels.numel()
    correct = int((preds == labels).sum())
    true_pos = int(((preds == 1) & (labels == 1)).sum())
    predicted_pos = int((preds == 1).sum())
    actual_pos = int((labels == 1).sum())
    values = {
        'accuracy': correct / n_samples if n_samples else 0.0,
        'precision': true_pos / predicted_pos if predicted_pos else 0.0,
        'recall': true_pos / actual_pos if actual_pos else 0.0,
    }
    return {name: value for name, value in values.items() if name in metric_names}


def train(model, opt, loader, scheduler):
    """Run one training epoch.

    Args:
        model: classifier called as ``model(input_ids, token_type_ids=..., attention_mask=...,
            labels=...)``; its output must expose ``.loss``.
        opt: optimizer to zero/step (the argument is used, not a global).
        loader: yields ``(input_ids, token_type_ids, attention_mask, labels)``, the order
            produced by ``classificationDataset.__getitem__``.
        scheduler: per-batch LR scheduler stepped after every optimizer step, or None.
            A ``ReduceLROnPlateau`` needs a validation loss, so it is not stepped here.

    Returns:
        ``(model, opt, mean training loss over the epoch)``.
    """
    model.train()
    batch_losses = []
    step_per_batch = scheduler is not None and not isinstance(
        scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
    )
    for batch in tqdm(loader):
        # Unpack in the dataset's order; the attention mask and token type ids were swapped before.
        input_ids, token_type_ids, attention_mask, labels = (t.to(device) for t in batch)

        opt.zero_grad()
        outputs = model(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        loss.backward()
        opt.step()
        if step_per_batch:
            scheduler.step()
        batch_losses.append(loss.item())

    return model, opt, _mean_or_nan(batch_losses)


def val(model, loader, metric_names=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: same calling convention as in ``train``; the output must expose ``.loss`` and
            ``.logits``.
        loader: yields ``(input_ids, token_type_ids, attention_mask, labels)``.
        metric_names: container of metric names ('accuracy', 'precision', 'recall') or None.

    Returns:
        ``(mean validation loss, metrics)``. ``metrics`` maps each requested name to a float
        computed over the whole loader, not averaged per batch; it is None when
        ``metric_names`` is empty or None.
    """
    model.eval()
    batch_losses = []
    pred_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch in tqdm(loader):
            input_ids, token_type_ids, attention_mask, labels = (t.to(device) for t in batch)
            outputs = model(
                input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            batch_losses.append(outputs.loss.item())
            if metric_names:
                pred_chunks.append(outputs.logits.argmax(dim=1).cpu())
                label_chunks.append(labels.cpu())

    metrics = _binary_metrics(pred_chunks, label_chunks, metric_names) if metric_names else None
    return _mean_or_nan(batch_losses), metrics

In [None]:
# Binary torchmetrics objects keyed by display name.
# NOTE(review): no later cell reads this global; val() and learning_loop keep their own metrics.
metrics = dict(
    Accuracy=Accuracy(task="binary"),
    Precision=Precision(task="binary"),
    Recall=Recall(task="binary"),
)

In [None]:
# Metrics to compute during validation, each drawn on its own subplot (plot ids 1, 2, 3).
metrics_names = {
    name: {"plot_id": plot_id}
    for plot_id, name in enumerate(("accuracy", "precision", "recall"), start=1)
}

In [None]:
from IPython.display import clear_output
import warnings

def get_lr(optimizer):
    """Learning rate of the optimizer's first parameter group, or None if it has no groups."""
    first_group = next(iter(optimizer.param_groups), None)
    return None if first_group is None else first_group['lr']


def _metric_plot_groups(metrics_names):
    """Group metric names by subplot id; returns ``(number of metric subplots, {plot_id: [names]})``."""
    groups = defaultdict(list)
    for key, meta in metrics_names.items():
        # The metrics_names config uses "plot_id"; the old "plot id" spelling is still accepted.
        plot_id = meta.get("plot_id", meta.get("plot id", 1))
        groups[plot_id].append(key)
    n_plots = len(groups)
    assert all(plot_id <= n_plots for plot_id in groups), "plot ids must be 1..number of metric plots"
    return n_plots, groups


def _draw_progress(epoch, epochs, losses, lrs, metrics, metrics_names, separate_show):
    """Redraw the loss, learning-rate and (optional) metric curves in the notebook output."""
    clear_output(True)
    n_base = 3 if separate_show else 2
    if metrics_names is not None:
        n_metric_plots, plot_groups = _metric_plot_groups(metrics_names)
    else:
        n_metric_plots, plot_groups = 0, {}
    n_cols = n_base + n_metric_plots

    fig, _ = plt.subplots(1, n_cols, figsize=(20, 10))
    fig.suptitle(f'#{epoch}/{epochs}:')

    plt.subplot(1, n_cols, 1)
    plt.plot(losses['train'], 'r.-', label='train')
    if separate_show:
        plt.title('loss on train')
        plt.legend()
    plt.grid()

    if separate_show:
        plt.subplot(1, n_cols, 2)
        plt.title('loss on validation')
        plt.grid()
    else:
        plt.title('losses')
    plt.plot(losses['val'], 'g.-', label='val')
    plt.legend()

    plt.subplot(1, n_cols, n_base)
    plt.title('learning rate')
    plt.plot(lrs, 'g.-', label='lr')
    plt.legend()
    plt.grid()

    for plot_id, keys in plot_groups.items():
        for key in keys:
            plt.subplot(1, n_cols, n_base + plot_id)
            plt.title(f'additional metrics #{plot_id}')
            for name in metrics:
                if key in name:
                    plt.plot(metrics[name], '.-', label=name)
            plt.legend()
            plt.grid()

    plt.show()


def learning_loop(
    model,
    optimizer,
    train_loader,
    val_loader,
    scheduler=None,
    min_lr=None,
    epochs=10,
    val_every=1,
    draw_every=1,
    separate_show=False,
    model_name=None,
    metrics_names=None,
    unfreeze_epoch=2,
):
    """Train ``model`` for up to ``epochs`` epochs with periodic validation and live plots.

    Args:
        model: model passed to ``train`` / ``val``. If it has a ``bert`` attribute, those
            parameters get ``requires_grad=True`` at the start of epoch ``unfreeze_epoch``.
        optimizer: optimizer passed to ``train``; its first param group's LR is recorded.
        train_loader, val_loader: loaders for ``train`` and ``val``.
        scheduler: LR scheduler or None. Per-batch schedulers are stepped inside ``train``
            and are NOT stepped again here; a ``ReduceLROnPlateau`` is stepped here with the
            validation loss.
        min_lr: stop early once the LR is ``<= min_lr`` (disabled when None or 0).
        epochs: maximum number of epochs.
        val_every: validate every ``val_every`` epochs.
        draw_every: redraw the progress figure every ``draw_every`` epochs.
        separate_show: draw the train and validation losses on separate subplots.
        model_name: unused; kept for backward compatibility.
        metrics_names: mapping metric name -> ``{"plot_id": int}``; forwarded to ``val``.
        unfreeze_epoch: epoch at which ``model.bert`` is unfrozen (default 2, as before).

    Returns:
        ``(model, optimizer, {'train': [...], 'val': [...]})`` with per-epoch mean losses.
    """
    losses = {'train': [], 'val': []}
    lrs = []
    metrics = defaultdict(list)

    for epoch in range(1, epochs + 1):
        backbone = getattr(model, 'bert', None)
        if epoch == unfreeze_epoch and backbone is not None:
            # From here on the whole network is fine-tuned, not just the classification head.
            for param in backbone.parameters():
                param.requires_grad = True
        print(f'#{epoch}/{epochs}:')

        lrs.append(get_lr(optimizer))

        model, optimizer, loss = train(model, optimizer, train_loader, scheduler)
        losses['train'].append(loss)

        if epoch % val_every == 0:
            loss, epoch_metrics = val(model, val_loader, metric_names=metrics_names)
            losses['val'].append(loss)
            for name, value in (epoch_metrics or {}).items():
                metrics[name].append(value)
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(loss)

        if epoch % draw_every == 0:
            _draw_progress(epoch, epochs, losses, lrs, metrics, metrics_names, separate_show)

        if min_lr and get_lr(optimizer) <= min_lr:
            print(f'Learning process ended with early stop after epoch {epoch}')
            break

    return model, optimizer, losses

In [None]:
run_name = "-".join([model.name_or_path.replace("/", "-"), str(len(dataloader_train))])

# Fine-tune with validation and live plots every epoch; training stops early if the learning
# rate falls to min_lr or below.
model, optimizer, losses = learning_loop(
    model=model,
    optimizer=optimizer,
    train_loader=dataloader_train,
    val_loader=dataloader_val,
    scheduler=scheduler,
    epochs=epochs,
    min_lr=1e-7,
    val_every=1,
    draw_every=1,
    separate_show=False,
    metrics_names=metrics_names,
)

In [None]:
# Pickles the whole nn.Module (not just its state_dict): loading it back needs the same classes
# importable, and such a file must never be torch.load-ed from an untrusted source.
torch.save(obj=model, f="small_model.bin")

In [None]:
# Run id: checkpoint name with '/' made path-safe, suffixed with the number of training batches.
run_name = "-".join([model.name_or_path.replace("/", "-"), str(len(dataloader_train))])

In [None]:
# `result_metrics` is not defined anywhere in the visible notebook (NameError on Restart & Run All);
# build it from the final training loss plus one validation pass over the trained model.
final_val_loss, final_val_metrics = val(model, dataloader_val, metric_names=metrics_names)
result_metrics = {
    "train_loss": float(losses["train"][-1]),
    "val_loss": float(final_val_loss),
    **{name: float(value) for name, value in (final_val_metrics or {}).items()},
}

with logger.run(name=run_name) as run:
    run.log_values(result_metrics)

In [None]:
# Display all logged runs: one row per run name, one column per logged key.
logger.leaderboard