In [1]:
# Move the kernel's working directory two levels up (presumably the repository root, so that
# the `utils.*` imports and the relative 'config.yaml' path below resolve — TODO confirm).
# The path is relative, so re-running this cell in the same kernel climbs two more levels.
%cd ../..

/home/shapkin/effective-inference


## Import libs

In [2]:
import os

import yaml
import h5py
import numpy as np
import torch
import torch.nn as nn
import json
import seaborn as sns
import matplotlib.pyplot as plt

from utils.dataset_cache import cache_embeddings, get_dataset_for_regression, build_dataset_from_cached, load_cached_dataset
from utils.dataset_cache import build_dict_dataset_from_cached
from utils.prepare_dataset import load_datasets, cut_datasets
from utils.config import ConfigWrapper
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel
from typing import Tuple, List, Dict, Optional, Union
from numpy.random import shuffle
from sklearn.metrics import r2_score

from IPython.display import clear_output

In [3]:
from utils.attentions.bert.linear import BertWrapperLin, LinearClassifierBertAttention, LinearAttention
from utils.dataset_utils import get_dict_batch, prepare_batches
from utils.train_linear_utils import train_epoch, eval_epoch, plot_history

## Project configuration

In [37]:
# Relative path: resolved against the kernel's current working directory.
config_path = 'config.yaml'

# Parse the YAML while the file is open, then wrap the parsed object in ConfigWrapper;
# the rest of the notebook reads settings through attribute access (e.g. config.general.device).
with open(config_path, "r") as config_file:
    raw_config = yaml.load(config_file, Loader=yaml.FullLoader)

config = ConfigWrapper(raw_config)

In [35]:
def train_linear_model(X_train, X_test, y_train, y_test, config, save_pattern='',
                       use_plots=False, save_final_results=False,
                       verbose=False, use_pbars=False, save_model=False):
    """Fit a ``LinearAttention`` model with MSE loss, evaluating after every epoch.

    Args:
        X_train, y_train: Training inputs and targets, forwarded unchanged to ``train_epoch``.
        X_test, y_test: Validation inputs and targets, forwarded unchanged to ``eval_epoch``.
            ``y_test`` is additionally passed to ``r2_score`` and ``np.save``.
        config: Project configuration. This function reads ``config.attention_config``,
            ``config.general.device``, ``config.general.num_epochs`` and
            ``config.data.data_path``; the whole object is also handed to
            ``train_epoch`` / ``eval_epoch``.
        save_pattern: Name of the sub-directory under ``{data_path}/linear_models`` that
            receives saved artefacts. An empty string disables saving of predictions.
        use_plots: Clear the cell output each epoch, print the validation R2 score and
            call ``plot_history``. Takes precedence over ``verbose`` for per-epoch output.
        save_final_results: After the last epoch, store the validation predictions and
            ``y_test`` (only when ``save_pattern`` is non-empty).
        verbose: Print the initial model parameters and per-epoch train/val loss and R2.
        use_pbars: Forwarded as ``use_pbar`` to ``train_epoch`` / ``eval_epoch``.
        save_model: Move the model to CPU and save its ``state_dict`` as ``model.pth``.

    Returns:
        The trained model. It lives on the CPU when ``save_model`` is true, otherwise on
        ``config.general.device``.
    """
    num_epochs = config.general.num_epochs
    output_dir = f'{config.data.data_path}/linear_models/{save_pattern}'

    model = LinearAttention(config.attention_config).to(config.general.device)

    # Dumping every parameter tensor is debugging output, so it is tied to `verbose`.
    if verbose:
        for param_name, param in model.named_parameters():
            print(param_name, param)

    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.01, weight_decay=1)

    # No learning-rate schedule is used. Previously tried alternatives:
    #   torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001,
    #       steps_per_epoch=<ceil(len(X_train) / train_batch_size)>, epochs=num_epochs)
    #   torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    scheduler = None

    train_log = []
    val_log = []
    # Pre-bind so the code after the loop is well defined even when zero epochs are run.
    val_loss = val_preds = None
    ran_any_epoch = False

    for epoch in range(num_epochs):
        if use_plots:
            clear_output()

        train_loss, _ = train_epoch(model, optimizer, criterion, X_train, y_train, config,
                                    scheduler=scheduler, use_pbar=use_pbars)
        val_loss, val_preds = eval_epoch(model, criterion, X_test, y_test, config,
                                         use_pbar=use_pbars)
        ran_any_epoch = True

        train_log.extend(train_loss)
        # x-position of the validation point = number of train steps done so far; this stays
        # correct even if epochs do not all contain the same number of steps.
        val_log.append((len(train_log), np.mean(val_loss)))

        if use_plots:
            print(f'{epoch} -- VAL R2 score:', r2_score(y_test, val_preds))
            plot_history(train_log, val_log)
        elif verbose:
            print(f'{epoch} -- Mean train loss:', np.mean(train_loss))
            print(f'{epoch} -- Mean val loss:', np.mean(val_loss))
            print(f'{epoch} -- VAL R2 score:', r2_score(y_test, val_preds))
            print()

    if ran_any_epoch and save_final_results and save_pattern != '':
        os.makedirs(output_dir, exist_ok=True)
        # NOTE(review): these files hold NumPy binary data (np.save) despite the `.json`
        # suffix; the names are kept so existing readers of these paths keep working.
        with open(f'{output_dir}/preds.json', 'wb') as f:
            np.save(f, val_preds)
        with open(f'{output_dir}/true.json', 'wb') as f:
            np.save(f, y_test)

    if save_model:
        os.makedirs(output_dir, exist_ok=True)
        model.to('cpu')
        torch.save(model.state_dict(), f'{output_dir}/model.pth')

    # Print a final summary only when nothing was reported per epoch.
    if ran_any_epoch and not verbose and not use_plots:
        print('Final val loss:', np.mean(val_loss))
        print('Final val R2 score:', r2_score(y_test, val_preds))
    return model
        

In [6]:
# The tokenizer and the base model are both loaded from the same configured checkpoint name.
model_name = config.model.model_name

tokenizer = AutoTokenizer.from_pretrained(model_name, max_length=config.general.max_len)

initial_model = AutoModel.from_pretrained(model_name)
initial_model = initial_model.to(config.general.device)

In [7]:
# Load the datasets named in the config; `cut_size` is forwarded as the second argument.
train_datasets = load_datasets(
    config.data.train_datasets,
    config.data.cut_size,
)
# Bare last expression: display the loaded datasets as the cell output.
train_datasets

{'imdb': DatasetDict({
     train: Dataset({
         features: ['text', 'label'],
         num_rows: 25000
     })
     test: Dataset({
         features: ['text', 'label'],
         num_rows: 25000
     })
     unsupervised: Dataset({
         features: ['text', 'label'],
         num_rows: 50000
     })
 })}

## Build per-layer datasets and train linear models

In [38]:
# Values used by the "one model per layer" branch below.
# NOTE(review): presumably the number of encoder layers and the head subset to fit jointly —
# TODO confirm and consider moving both into the config.
NUM_LAYERS = 12
SHARED_HEADS = [0, 1, 2]

attention_cfg = config.attention_config

if attention_cfg.split_heads or attention_cfg.model_for_each_head:
    # One dataset per (layer, head) pair.
    layers_to_train = attention_cfg.layers_to_train
    heads_to_train = attention_cfg.heads_to_train
    # The context manager closes the bar even if the cell is interrupted or raises.
    with tqdm(total=len(layers_to_train) * len(heads_to_train), position=0, leave=True) as pbar:
        for layer_N in layers_to_train:
            for head_N in heads_to_train:
                print(f'Training {layer_N} layer, {head_N} head')
                X_train, y_train, X_test, y_test = build_dict_dataset_from_cached(
                    config, train_datasets, layer=layer_N, heads=[head_N],
                    features=attention_cfg.features,
                    split_hidden=attention_cfg.split_heads_in_data)
                print('Train size:', len(X_train))
                # Shape probe of one sample; skipped when the split is too small to index.
                if len(X_train) > 10:
                    print(X_train[10]['hidden_to'].shape)
                # NOTE(review): training is currently disabled in this branch, so only the
                # dataset sizes/shapes are inspected. Re-enable when ready:
                #train_linear_model(X_train, X_test, y_train, y_test, config, save_pattern=f'{config.data.model_save_pattern}_{layer_N}_{head_N}',
                #                   use_plots=False, save_final_results=True,
                #                   verbose=True, use_pbars=False, save_model=True)
                pbar.update(1)

else:
    # One model per layer, built from the SHARED_HEADS heads without splitting hidden states.
    with tqdm(total=NUM_LAYERS, position=0, leave=True) as pbar:
        for layer_N in range(NUM_LAYERS):
            X_train, y_train, X_test, y_test = build_dict_dataset_from_cached(
                config, train_datasets, layer=layer_N, heads=SHARED_HEADS,
                features=attention_cfg.features,
                split_hidden=False)
            print('Train size:', len(X_train))
            train_linear_model(X_train, X_test, y_train, y_test, config,
                               save_pattern=f'{config.data.model_save_pattern}_{layer_N}',
                               use_plots=False, save_final_results=True,
                               verbose=True, use_pbars=False, save_model=True)
            pbar.update(1)

  0%|          | 0/12 [00:00<?, ?it/s]

KeyboardInterrupt: 