# Install everything needed for training

In [None]:
!nvidia-smi

In [None]:
# Before installing these libraries, install build-essential, libopenmpi-dev, and PyTorch with CUDA support.
%pip install transformers
%pip install mpi4py
%pip install transformers[deepspeed]
%pip install deepspeed
%pip install wandb

In [None]:
%%writefile train.py

import os
import pickle


# Defaults for the torch.distributed rendezvous settings (the HF Trainer / DeepSpeed
# integration may read these). setdefault() keeps any values already exported by a
# distributed launcher (e.g. torchrun) instead of overwriting them, so a plain
# `python3 train.py` run still gets single-process defaults while multi-process launches
# keep their own ranks. For DDP across machines, MASTER_PORT must be open on MASTER_ADDR.
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '9994')
os.environ.setdefault('RANK', '0')        # global rank of this process; rank 0 is the main process
os.environ.setdefault('LOCAL_RANK', '0')  # device index on this machine (used by DDP)
os.environ.setdefault('WORLD_SIZE', '1')  # total number of processes
# Explicitly allow the Rust `tokenizers` library to use its own parallelism
# (useful when tokenizing large files).
os.environ.setdefault('TOKENIZERS_PARALLELISM', 'true')

import pandas as pd
import torch
from torch.utils.data import Dataset, random_split
from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForCausalLM
from transformers import TextDataset,DataCollatorForLanguageModeling
from transformers import GPT2TokenizerFast
from transformers import GPT2Tokenizer, TrainingArguments, Trainer, GPT2LMHeadModel, PreTrainedTokenizerFast
from multiprocessing import Pool
from tqdm import tqdm
import gc
# Paths and model identifiers used by the training run.
MODEL_DIR = '/workspace/models/mailqa_large'  # Trainer output_dir
train_path = '/workspace/qna.txt'  # raw text the dataset is built from
backbone = 'sberbank-ai/rugpt3large_based_on_gpt2'  # checkpoint passed to from_pretrained()
device = 'cuda:0'

# Seed torch's random number generator.
torch.manual_seed(42)

# Module-level tokenizer: the tokenize() helper below reads this global.
tokenizer = GPT2TokenizerFast.from_pretrained(backbone, use_fast=True)

def tokenize(text):
    """Return the token ids of *text*, encoded with the module-level ``tokenizer``.

    It is mapped over text chunks by worker processes in ``MailRuDataset``, and it reads
    the global ``tokenizer`` rather than taking one as an argument. No special tokens are
    added.
    """
    print(f'Tokenizing text length {len(text)}')
    # encode() returns ids directly. tokenize() + convert_tokens_to_ids() yields the same ids
    # but turns every id into a string and then looks it up again in Python.
    return tokenizer.encode(text, add_special_tokens=False)

class MailRuDataset(Dataset):
    """Causal-LM dataset of fixed-length token-id blocks built from one text file.

    Tokenization is parallelized with ``multiprocessing``. On the first run the file is:

    * read;
    * split into chunks on line boundaries;
    * tokenized in worker processes (via the module-level ``tokenize`` helper);
    * cut into non-overlapping blocks of ``block_size`` token ids, dropping any trailing
      partial block.

    The blocks are then pickled to ``cache_path``. If ``cache_path`` already exists, it is
    loaded instead and ``file_path`` / ``block_size`` are NOT consulted. Delete the cache,
    or pass a different ``cache_path``, after changing either of them.

    Args:
        tokenizer: Kept for API compatibility. Tokenization uses the module-level
            ``tokenizer`` through ``tokenize``, not this argument.
        file_path: Text file to read as UTF-8; undecodable bytes are ignored.
        block_size: Number of token ids per sample.
        cache_path: File the pickled samples are written to and read from.
        num_workers: Number of tokenization worker processes.
        chunk_size: Maximum number of characters per tokenization task.
    """

    def __init__(self, tokenizer: 'PreTrainedTokenizerFast', file_path: str, block_size: int = 1024,
                 cache_path: str = './cached_dataset', num_workers: int = 8,
                 chunk_size: int = 2097152):
        if os.path.exists(cache_path):
            print('Loading cached dataset')
            # Only load caches produced by this class: unpickling untrusted data can run
            # arbitrary code.
            with open(cache_path, 'rb') as cache_file:
                self.samples = pickle.load(cache_file)
            return

        self.samples = self._tokenize_file(file_path, block_size, num_workers, chunk_size)
        gc.collect()

        # Write to a temporary file and rename it into place. An interrupted dump then
        # cannot leave a truncated cache that later runs would try (and fail) to load.
        tmp_path = cache_path + '.tmp'
        with open(tmp_path, 'wb') as cache_file:
            pickle.dump(self.samples, cache_file)
        os.replace(tmp_path, cache_path)

        print('Dataset loaded and cached to disk')

    @staticmethod
    def _split_into_chunks(text, chunk_size):
        """Split *text* into consecutive pieces of at most *chunk_size* characters.

        Each cut is moved back to just after the last newline inside the window, so words
        are not split across two tokenization tasks. A window without any newline is cut
        at exactly *chunk_size* characters. ``''.join(result) == text`` always holds.
        """
        if chunk_size <= 0:
            raise ValueError(f'chunk_size must be positive, got {chunk_size}')
        chunks = []
        start, length = 0, len(text)
        while start < length:
            end = min(start + chunk_size, length)
            if end < length:
                newline = text.rfind('\n', start, end)
                if newline != -1:
                    end = newline + 1  # always > start, so the loop makes progress
            chunks.append(text[start:end])
            start = end
        return chunks

    @classmethod
    def _tokenize_file(cls, file_path, block_size, num_workers, chunk_size):
        """Read *file_path*, tokenize it in parallel and return the list of token blocks."""
        if block_size <= 0:
            raise ValueError(f'block_size must be positive, got {block_size}')

        print('Reading dataset file')
        with open(file_path, encoding='utf-8', errors='ignore') as data_file:
            data = data_file.read()
        print(f'Data size: {len(data)}')

        print('Chunking dataset file')
        data_chunks = cls._split_into_chunks(data, chunk_size)
        del data
        gc.collect()
        print(f'Chunk count: {len(data_chunks)}')

        print('Starting tokenization')
        with Pool(num_workers) as pool:
            # map() returns results in chunk order, so the flattened ids follow the file order.
            tokenized_text = [token for tokens in pool.map(tokenize, data_chunks) for token in tokens]
            pool.close()
            pool.join()
        del data_chunks
        gc.collect()
        print(f'Tokenized text size: {len(tokenized_text)}')

        print('Splitting by block size, ignoring last sample')
        samples = [tokenized_text[i:i + block_size]
                   for i in range(0, len(tokenized_text) - block_size + 1, block_size)]
        print(f'Sample count: {len(samples)}')
        del tokenized_text
        return samples

    def __getitem__(self, idx):
        """Return sample *idx* as a 1-D tensor of token ids."""
        return torch.tensor(self.samples[idx])

    def __len__(self):
        """Return the number of ``block_size``-token samples."""
        return len(self.samples)


def load_dataset(train_path, tokenizer):
    """Build the training dataset and its matching collator.

    Args:
        train_path: Text file passed to ``MailRuDataset`` as ``file_path``.
        tokenizer: Tokenizer given to both the dataset and the collator.

    Returns:
        ``(train_dataset, data_collator)``: a ``MailRuDataset`` of 1024-token blocks, and a
        ``DataCollatorForLanguageModeling`` with ``mlm=False`` (causal language modeling).
    """
    dataset = MailRuDataset(tokenizer=tokenizer, file_path=train_path, block_size=1024)
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    return dataset, collator

if __name__ == "__main__":
    # Build (or load from cache) the token-block dataset and the causal-LM collator.
    train_dataset, data_collator = load_dataset(train_path, tokenizer)

    model = GPT2LMHeadModel.from_pretrained(backbone).to(device)

    training_args = TrainingArguments(
        output_dir=MODEL_DIR,
        num_train_epochs=1,
        logging_steps=50,
        save_steps=2000,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        warmup_steps=100,
        weight_decay=0.01,
        fp16=True,  # fp16_opt_level is an Apex AMP setting and is not used here
        report_to="wandb",
        save_total_limit=5,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )
    trainer.train()

    # Checkpoints are only written every `save_steps` steps, so steps after the last
    # checkpoint would be lost. Persist the final weights explicitly, plus the tokenizer,
    # so MODEL_DIR is loadable with from_pretrained(). Only the main process writes the
    # tokenizer.
    trainer.save_model(MODEL_DIR)
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(MODEL_DIR)

In [None]:
!python3 train.py