In [None]:
import gc
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CyclicLR
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm.auto import tqdm
import logging
import preprocessing
import base_model
# Only show ERROR-and-above messages from transformers' model-loading module
# (presumably to hide warnings about newly initialised classifier weights).
transformers_model_logger = logging.getLogger("transformers.modeling_utils")
transformers_model_logger.setLevel(logging.ERROR)

In [None]:
# Run configuration.
FREEZE_LAYERS = True       # train only the classification head (backbone weights frozen)
TRAINING_MODE = True       # True: fine-tune on train_data.csv; False: predict test_sequences.csv
TEST_CHUNK_SIZE = 100_000  # number of test rows preprocessed/predicted per chunk

In [None]:
if TRAINING_MODE:
    # Labelled training data is only needed when fine-tuning.
    train_csv_path = '../data/train_data.csv'
    train = pd.read_csv(train_csv_path)

In [None]:
# Pretrained backbone (Hugging Face Hub id). Alternative backbone:
#model_name = 'yangheng/RNA-RoBERTa-v0.1'
model_name = 'zhihan1996/DNABERT-2-117M'

# Short identifier used in checkpoint / submission file names. Take the last path
# component so un-namespaced hub ids ('bert-base-uncased') and local directories
# ('../models/foo') work too; `split('/')[1]` raised IndexError or picked the
# wrong component for those.
model_name_end = model_name.rstrip('/').split('/')[-1]

In [None]:
# Output width of the classification head: 457 * 2. Presumably one value per
# sequence position (the tokenizer calls use max_length=457) for each of the two
# reactivity targets (DMS and 2A3) — TODO confirm against base_model.
num_outputs = 2 * 457

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_outputs,
)

In [None]:
# Single project Preprocessor instance. It is used both to build the training
# splits (prepare_xy_split) and to build the prediction datasets
# (prepare_prediction_dataset) later in the notebook.
preprocessor = preprocessing.Preprocessor()

In [None]:
if TRAINING_MODE:
    # Split configuration for prepare_xy_split. Sequences are turned into
    # fixed-length (457) PyTorch tensors by the `categorical` callable below.
    split_options = dict(
        shuffle=True,
        validation_split=None,
        batch_size=16,
        filter_noise=True,
        dual_model=False,
        k_fold=5,
        structure=False,
        clip=True,
        weighted_loss=None,
        additive_weight=False,
    )
    train_sets, preprocessing_config = preprocessor.prepare_xy_split(
        train,
        categorical=lambda seqs: tokenizer(
            seqs, truncation=True, padding='max_length', max_length=457, return_tensors='pt'
        ),
        **split_options,
    )

In [None]:
if TRAINING_MODE:
    # Sanity check: print the shapes of the first three elements of the first
    # sample held by train_sets[0][1][0].dataset.
    first_sample = train_sets[0][1][0].dataset[0]
    print(first_sample[0].shape, first_sample[1].shape, first_sample[2].shape)

In [None]:
if TRAINING_MODE and FREEZE_LAYERS:
    # Transfer learning: freeze every pretrained weight, then install a freshly
    # initialised classification head that is the only trainable part.
    for param in model.parameters():
        param.requires_grad = False

    hidden_size = model.config.hidden_size
    num_labels = model.config.num_labels  # set to 457*2 via from_pretrained(num_labels=...)
    head = model.classifier

    if isinstance(head, nn.Linear):
        # BERT-style head: `classifier` itself is the output projection.
        # Assigning `.dense` / `.out_proj` onto it (as done for RoBERTa heads)
        # would only register sub-modules its forward() never calls. They would
        # still be saved in the checkpoint, and a fresh inference run (which
        # skips this cell) would then see unexpected state-dict keys.
        # Replace the layer as a whole instead.
        model.classifier = nn.Linear(in_features=head.in_features, out_features=num_labels)
    else:
        # RoBERTa-style head with `dense` and `out_proj` linear layers:
        # replace both.
        head.dense = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        head.out_proj = nn.Linear(in_features=hidden_size, out_features=num_labels)

    # Make the (new) head trainable. Fresh nn.Linear params already require
    # grad; this also covers any head parameters that were not replaced.
    for param in model.classifier.parameters():
        param.requires_grad = True

In [None]:
LR = 0.001
NUM_EPOCHS = 3

# Optimizer and cyclic LR schedule are built on the raw HF model's parameters
# *before* `model` is rebound to the BaseModel wrapper below.
optimizer = AdamW(params=model.parameters(), lr=LR)
scheduler = CyclicLR(
    optimizer,
    base_lr=LR / 10,
    max_lr=LR,
    mode='triangular2',
    cycle_momentum=False,
)

# From here on `model` refers to the project wrapper, not the HF model.
checkpoint_name = f'PRETRAINED-{model_name_end}.pth'
model = base_model.BaseModel(
    optimizer,
    model,
    checkpoint_name,
    scheduler=scheduler,
    enable_wandb=True,
    pretrained=True,
)

In [None]:
if TRAINING_MODE:
    # Only the first entry of train_sets is trained on (presumably the first of
    # the k_fold=5 folds — confirm against preprocessing.prepare_xy_split).
    first_split = train_sets[0]
    experiment_type, train_data_loader, validation_data_loader = first_split
    print(f'Model fit {experiment_type}')

In [None]:
if TRAINING_MODE:
    # Run tags handed to BaseModel.fit (presumably logged to W&B, since the
    # wrapper was created with enable_wandb=True).
    tags = ['Transferlearning'] + (['Frozen'] if FREEZE_LAYERS else [])

    model.fit(
        train_data_loader,
        validation_data_loader,
        experiment_type=experiment_type,
        epochs=NUM_EPOCHS,
        verbose=True,
        preprocessing_config=preprocessing_config,
        tags=tags,
    )

In [None]:
if not TRAINING_MODE:
    # Unlabelled sequences to predict; only needed for inference runs.
    # For a quick smoke run, restrict to a few rows, e.g. test = test.iloc[:10].
    test_csv_path = '../data/test_sequences.csv'
    test = pd.read_csv(test_csv_path)

In [None]:
if not TRAINING_MODE:
    # Inference run: restore trained weights through the BaseModel wrapper.
    # Presumably loads the checkpoint named when the wrapper was built
    # (PRETRAINED-<model_name_end>.pth) — confirm in base_model.load_model.
    model.load_model()

In [None]:
if not TRAINING_MODE:
    # Predict the test set in chunks of TEST_CHUNK_SIZE rows, so that only one
    # chunk's tokenized dataset is alive at a time. Then write a clipped
    # submission CSV with one column per reactivity target.
    if test.empty:
        # np.vstack below would otherwise fail with the opaque
        # "need at least one array to concatenate".
        raise ValueError('Test set is empty: nothing to predict.')

    def tokenize_sequences(sequences):
        """Tokenize sequences to fixed-length (457) PyTorch tensors."""
        return tokenizer(sequences, truncation=True, padding='max_length',
                         max_length=457, return_tensors='pt')

    final_outputs = pd.DataFrame()
    final_outputs.index.name = 'id'
    experiment_types = ['DMS_AND_2A3_MaP']
    test_size = test.shape[0]

    for experiment_type in experiment_types:
        print(f'Model prediction {experiment_type}')

        all_predictions = []
        for chunk_no, s_index in enumerate(tqdm(range(0, test_size, TEST_CHUNK_SIZE))):
            e_index = min(s_index + TEST_CHUNK_SIZE, test_size)

            test_set = preprocessor.prepare_prediction_dataset(
                test.iloc[s_index:e_index],
                batch_size=256,
                categorical=tokenize_sequences,
                structure=False,
                verbose=False,
            )
            # new_init only for the first chunk of each experiment type;
            # finish the W&B run after the last chunk.
            chunk_predictions = model.predict(
                test_set,
                single_model_mode=True,
                new_init=(chunk_no == 0),
                finish_wandb=(e_index == test_size),
            )
            del test_set

            all_predictions.append(chunk_predictions.cpu().numpy())
            del chunk_predictions
            gc.collect()

        final_predictions = np.vstack(all_predictions)
        final_outputs['reactivity_DMS_MaP'] = final_predictions[:, 0]
        final_outputs['reactivity_2A3_MaP'] = final_predictions[:, 1]
        del final_predictions
        gc.collect()

    # Release the model only after *all* experiment types are predicted.
    # Deleting it inside the loop raised NameError on a second iteration.
    del model
    gc.collect()

    final_outputs = final_outputs.clip(0.0, 1.0)
    final_outputs.to_csv(f'PRETRAINING-{model_name_end}.csv')

# NOTE(review): development leftover. It re-executes base_model so that edits to
# the module take effect without restarting the kernel. Consider `%autoreload 2`
# in a config cell instead and deleting this.
import importlib
importlib.reload(base_model)