In [None]:
# Imports
from importlib import reload
import os
import shutil
from pathlib import Path
import numpy as np
import torch
from datasets import load_from_disk
from transformers import AutoTokenizer
from transformers import AutoConfig
from transformers import OPTForSequenceClassification
from transformers import DataCollatorWithPadding
from transformers import TrainingArguments
from transformers import Trainer

import utils.preprocess_data
reload(utils.preprocess_data)
from utils.preprocess_data import preprocess_orig_data

#### Setup

In [None]:
# Set up DL framework and device
dl_framework = 'pt'

# Use the GPU when CUDA is available, otherwise fall back to the CPU
is_gpu = torch.cuda.is_available()
device = torch.device('cuda' if is_gpu else 'cpu')

# Set up file and dir paths
clean = True # Whether to clean the raw data directory
data_type = 'applications' # 'is_experimental' or 'applications'

path_to_galactica_folder = Path(r'../galactica')

if data_type == 'applications':
    path_to_raw = Path(r'./data/raw_applications.json')
elif data_type == 'is_experimental':
    path_to_raw = Path(r'./data/raw_is_experimental.json')
else:
    # Fail loudly: without this branch a typo in `data_type` would either raise a
    # NameError below or silently reuse a stale `path_to_raw` left in the kernel.
    raise ValueError(
        f"Unknown data_type {data_type!r}; expected 'applications' or 'is_experimental'"
    )
path_to_data = Path(path_to_galactica_folder, path_to_raw)

# Start from an empty directory for the saved model state dict.
# NOTE: this deletes any state dict left over from a previous run.
path_to_state_dict = Path("./state_dict/model_state_dict.pt")
dir_to_state_dict = path_to_state_dict.parent
shutil.rmtree(dir_to_state_dict, ignore_errors=True)
# exist_ok: rmtree ignores errors, so the directory may still be there
dir_to_state_dict.mkdir(parents=True, exist_ok=True)

# Set up checkpoint
checkpoint = "facebook/galactica-125m"
print("\nSet-up completed")

#### Preprocess the data to Datasets

In [None]:
# Preprocess the data to get raw data into /galactica/data
# Pick the original (un-processed) JSON file that matches `data_type`.
if data_type == 'applications':
    path_to_orig = Path(r'./data/orig_applications.json')
elif data_type == 'is_experimental':
    path_to_orig = Path(r'./data/orig_is_experimental.json')
else:
    # Without this branch a stale `path_to_orig` from an earlier run of this cell
    # could be preprocessed silently after `data_type` was changed to a bad value.
    raise ValueError(
        f"Unknown data_type {data_type!r}; expected 'applications' or 'is_experimental'"
    )

# `clean` is forwarded unchanged; it is defined in the set-up cell.
preprocess_orig_data(str(path_to_galactica_folder), str(path_to_orig), clean=clean)

#### Load the data, model, tokenizer and tokenize

In [None]:
# Load the data
raw_dataset = load_from_disk(str(path_to_data))
print("\nDataset loaded: ", raw_dataset)

# Get the number of labels
num_labels = len(raw_dataset['train'].features['label'].names)
print("\nNumber of labels: ", num_labels)

# Load the Model
# The pretrained weights (plus the freshly initialised classification head) are
# saved to disk so that, on GPU, the very same state dict can be dispatched into
# an empty model by accelerate.
model = OPTForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels)
state_dict = model.state_dict()
torch.save(state_dict, str(path_to_state_dict))

if device.type == 'cuda':
    # Imported here so that accelerate is only required when a GPU is used
    from accelerate import init_empty_weights, load_checkpoint_and_dispatch
    config = AutoConfig.from_pretrained(checkpoint, num_labels=num_labels)
    with init_empty_weights():
        model = OPTForSequenceClassification._from_config(config)
    model.tie_weights()
    # Modules with a residual connection must stay on a single device
    no_split_module_classes = ["OPTDecoderLayer"]
    # Load the state dict saved above (same path, not a hard-coded one)
    model = load_checkpoint_and_dispatch(
        model,
        str(path_to_state_dict),
        device_map="auto",
        no_split_module_classes=no_split_module_classes,
    )
print("\nModel instantiated")

# Freeze the model but the last layer
# TODO: how to programmatically get the name of the last layer
for name, parameter in model.named_parameters():
    if name != 'score.weight':
        parameter.requires_grad = False

# Longest sequence the model can attend to (number of position embeddings);
# `word_embed_proj_dim` is the embedding width, not a sequence length.
max_length = model.config.max_position_embeddings
print("\nMax length = ", max_length)

# Load the Tokenizer
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True) # use_fast is recommended to be False, set to True for testing purposes
# use_fast argument; check https://huggingface.co/docs/transformers/model_doc/opt#overview
# Register as pad token the token whose id the model config declares as padding
pad_token_id = model.config.pad_token_id
tokenizer.add_special_tokens({'pad_token': tokenizer.convert_ids_to_tokens(pad_token_id)})
print("\nTokenizer instantiated")

def tokenize_function(sequences):
    """Tokenize the 'text' field of a batch, truncating to `max_length` tokens.

    Padding is not applied here; it is left to the data collator.
    """
    return tokenizer(sequences['text'], max_length=max_length, truncation=True)

#TODO: save the tokenized datasets to disk
tokenized_datasets = raw_dataset.map(tokenize_function, batched=True)
print("\nTokenized datasets: ", tokenized_datasets)

#### Define the training parameters

In [None]:
import gc

# Collator that pads each batch with the tokenizer
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
print("\nDataCollator instantiated")

# Hyperparameters for the run; the trailing comment on each entry is the value
# noted in the original cell next to that argument.
training_hyperparameters = {
    "output_dir": "test-trainer",
    "overwrite_output_dir": True,        # False
    "num_train_epochs": 3,               # 3
    "per_device_train_batch_size": 1,    # 8
    "per_device_eval_batch_size": 1,     # 8
    "gradient_accumulation_steps": 1,    # 1
    "learning_rate": 5e-5,               # 5e-5
    "weight_decay": 0,                   # 0
    "warmup_steps": 0,                   # 0
    "evaluation_strategy": 'no',         # 'no'
}
training_args = TrainingArguments(**training_hyperparameters)
print("\nTrainingArguments instantiated")

# Release unreferenced Python objects and cached CUDA memory before training
gc.collect()
torch.cuda.empty_cache()
print("\nCleaning done")

# Wire model, arguments, data and tokenizer together
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
)
print("\nTrainer instantiated")

#### Train

In [None]:
# Quick training
# Runs the training loop on the Trainer built in the previous cell, using the
# TrainingArguments defined there.
print("\nStarting training")
trainer.train()