- Use lightning-transformers for the datamodule.
- For the model, simply adapt the PyTorch Lightning example (https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/text-transformers.html).
    - It is more intuitive and a good way to understand PyTorch Lightning.
- Walk through it side by side with plain PyTorch.

---

## 1. From `nn.Module` to `pl.LightningModule`

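### Basic modeling in PyTorch
- A neural network is composed of layers/modules that perform operations on data.
    - `torch.nn` provides the building blocks for assembling neural networks.
- Every PyTorch module is defined by subclassing `nn.Module`.
    - `__init__` initializes the layers used as building blocks.
    - `forward` defines the operations applied to the input.
- The module instance is callable: `model(x)` runs `forward(x)` (together with any registered hooks), so call the model itself rather than `forward` directly.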
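The last point can be checked with a small stand-in module (`TinyNet` is hypothetical, not part of this notebook): calling the instance runs `forward`, and unlike calling `forward` directly, it also fires any registered hooks.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in with the same Flatten -> Linear pattern as NeuralNetwork below."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.linear(self.flatten(x))


model = TinyNet()
calls = []
# A forward hook only runs when the module is invoked as model(x), i.e. through nn.Module.__call__
model.register_forward_hook(lambda module, inputs, output: calls.append(output.shape))

x = torch.rand(4, 28, 28)            # a batch of four 28x28 "images"
logits = model(x)                    # __call__ -> forward, plus hooks
same = model.forward(x)              # direct call: same math, but hooks are skipped

print(logits.shape)                  # torch.Size([4, 10])
print(torch.allclose(logits, same))  # True
print(len(calls))                    # 1
```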

```python
from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),  # raw logits: no activation after the last layer
        )

    def forward(self, x):
        x = self.flatten(x)                 # (N, 28, 28) -> (N, 784)
        logits = self.linear_relu_stack(x)  # (N, 10)
        return logits
```

### pl.LightningModule
- Subclass `pl.LightningModule` instead, and carry `__init__` and `forward` over unchanged.

```python
from torch import nn
import pytorch_lightning as pl

class NeuralNetwork(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),  # raw logits: no activation after the last layer
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
```
- `pl.LightningModule` is itself a subclass of `nn.Module`; on top of the basic structure above it adds training-related methods (hooks) such as:
    - `training_step`
    - `validation_step`
    - `test_step`
    - `configure_optimizers`
    - and more...
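For comparison with plain PyTorch, here is a minimal manual training loop whose comments mark the part each Lightning hook takes over. It is a sketch on random data; the model, data shapes, and learning rate are illustrative assumptions. In Lightning, the `Trainer` owns the loop, `backward()`, and the optimizer calls.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Illustrative random data: 64 "images" of 28x28 with labels from 10 classes
dataset = TensorDataset(torch.rand(64, 28, 28), torch.randint(0, 10, (64,)))
loader = DataLoader(dataset, batch_size=16)

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
loss_fn = nn.CrossEntropyLoss()

# In Lightning: built and returned by configure_optimizers
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

step_losses = []
for epoch in range(2):                      # In Lightning: Trainer(max_epochs=2)
    model.train()
    for batch_idx, (x, y) in enumerate(loader):
        # In Lightning: the body of training_step, which returns this loss
        loss = loss_fn(model(x), y)
        # In Lightning: the Trainer runs these three calls for you
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step_losses.append(loss.item())

print(len(step_losses))  # 8 (4 batches x 2 epochs)
```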

---
The main `pl.LightningModule` methods can return values in several different forms. Understanding these return forms makes PyTorch Lightning code much easier to follow.

### `configure_optimizers`
- `configure_optimizers` can return its result in **six different forms**.
    - https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
- Here we use the "two lists" form.
    - The first list holds the optimizer(s).
    - The second list holds the learning-rate scheduler(s); each entry is either a scheduler object or a dictionary that configures it, e.g. `{'scheduler': ..., 'interval': 'step', 'frequency': 1}`.

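A minimal sketch of the two-list return value, written as a plain function so it runs without a `Trainer`. Assumptions: the model is a placeholder `nn.Linear`, the step counts are illustrative, and the schedule is built with `torch.optim.lr_scheduler.LambdaLR` using the same linear warmup/decay multiplier as `transformers.get_linear_schedule_with_warmup`, which the model below uses.

```python
import torch
from torch import nn


def configure_optimizers_sketch(model, lr=5e-5, warmup_steps=10, total_steps=100):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    def linear_warmup_then_decay(step):
        # LR multiplier: ramps 0 -> 1 over warmup_steps, then decays linearly to 0 at total_steps
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, linear_warmup_then_decay)
    # 'interval': 'step' asks Lightning to call scheduler.step() after every optimizer step
    scheduler_config = {"scheduler": scheduler, "interval": "step", "frequency": 1}
    # Two lists: [optimizers], [schedulers]
    return [optimizer], [scheduler_config]


optimizers, schedulers = configure_optimizers_sketch(nn.Linear(4, 2))
print(type(optimizers[0]).__name__, schedulers[0]["interval"])  # AdamW step
```

Using `'interval': 'step'` matters for warmup schedules: with the default epoch interval the warmup would advance only once per epoch.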
### `training_step`
- Returns either the loss tensor, or a dictionary that holds the loss under the key `'loss'` (other keys may be added alongside it).
- https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#training-step
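Both return forms, sketched as plain functions with hypothetical names so they run without a `Trainer`: one returns the bare scalar loss, the other a dictionary whose `'loss'` entry is what gets backpropagated while extra keys travel with it. The feature-based model below uses the dictionary form to carry `y_true`/`y_pred` into its epoch-end hooks.

```python
import torch
from torch import nn

model = nn.Linear(8, 3)          # placeholder classifier
loss_fn = nn.CrossEntropyLoss()


def training_step_tensor(batch, batch_idx):
    x, y = batch
    return loss_fn(model(x), y)  # form 1: the scalar loss tensor itself


def training_step_dict(batch, batch_idx):
    x, y = batch
    logits = model(x)
    return {
        "loss": loss_fn(logits, y),       # form 2: required key, used for backward
        "y_true": y,                      # extra keys are passed along with it
        "y_pred": logits.argmax(dim=-1),
    }


batch = (torch.randn(5, 8), torch.tensor([0, 1, 2, 1, 0]))
out_tensor = training_step_tensor(batch, 0)
out_dict = training_step_dict(batch, 0)
print(out_tensor.shape, sorted(out_dict))  # torch.Size([]) ['loss', 'y_pred', 'y_true']
```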


## 2. Feature-based Approach with a Pretrained Language Model
- A PyTorch Lightning model built on Hugging Face `AutoTokenizer` + `AutoModel` + `AutoConfig`

### define config arguments

**data_args**

```python
from omegaconf import OmegaConf
from lightning_transformers.core.nlp import HFTransformerDataConfig

# load the data config from YAML
args = OmegaConf.load('dm_config/ynat_base.yaml')

# start from the default HFTransformerDataConfig fields, then let the YAML values override them
data_args = HFTransformerDataConfig(batch_size=args.batch_size)
data_args = OmegaConf.create(vars(data_args))
data_args = OmegaConf.merge(data_args, args)
```


**model_args and training_args**
- Reuse the Hugging Face Trainer's `TrainingArguments`, plus a `ModelArguments` dataclass modeled on the Hugging Face example scripts.

```python
from typing import Optional
from dataclasses import dataclass, field
from transformers import TrainingArguments

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )

model_args = ModelArguments(model_name_or_path='klue/roberta-small')
training_args = TrainingArguments(
    output_dir='ckpt/ynat',
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    do_predict=False,
    evaluation_strategy='steps',
    logging_strategy='steps',
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    seed=42,
    metric_for_best_model='macro-f1',
    greater_is_better=True,
    report_to="none"
)
```

**Load DataModules**

```python
from transformers import AutoTokenizer
from src.datamodules.task.nlp import TextClassificationDataModule

tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
dm = TextClassificationDataModule(tokenizer, data_args)
```

**Define the PyTorch Lightning model**

```python
from typing import Dict
from datetime import datetime

import numpy as np
import torch
from torch import nn

import pytorch_lightning as pl
from datasets import load_metric
from transformers import (
    AutoConfig,
    AutoModel,
    AdamW,
    get_linear_schedule_with_warmup
)
from transformers.modeling_outputs import SequenceClassifierOutput
```

```python
class FeatureBasedSequneceClassification(pl.LightningModule):
    """
    Inspired by the BERT paper's feature-based approach.
    The API is built on top of AutoModel and AutoConfig, provided by HuggingFace.

    see: https://arxiv.org/pdf/1810.04805.pdf

    Args:
        model_args: ModelArguments; `model_name_or_path` selects the pretrained model.
        training_args: transformers.TrainingArguments; supplies the optimizer/scheduler settings.
        id2label: mapping from class index to label name.
        task_name: task name, passed to `load_metric` as the metric config name.
    """
    def __init__(
        self,
        model_args,
        training_args,
        id2label: Dict,
        task_name: str,
    ):
        super().__init__()
        self.save_hyperparameters()

        # init model
        self.config = AutoConfig.from_pretrained(
            self.hparams.model_args.model_name_or_path,
            num_labels=len(self.hparams.id2label),
            id2label=self.hparams.id2label,
            label2id={l: i for i, l in self.hparams.id2label.items()},
            output_hidden_states=True  # get all hidden states
        )
        self.plm = AutoModel.from_pretrained(
            self.hparams.model_args.model_name_or_path,
            config=self.config,
            add_pooling_layer=False  # drop the pooling layer
        )

        self.num_labels = self.config.num_labels
        for param in self.plm.parameters():  # freeze all pretrained layers
            param.requires_grad = False

        self.half_num_encoder = len(self.plm.encoder.layer) // 2
        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        # classifier over the concatenated [CLS] vectors of the last half of the encoder layers
        self.classifier = nn.Linear(self.half_num_encoder * self.config.hidden_size, self.num_labels)
        # Initialize only the new head. Calling self.plm.init_weights() after from_pretrained
        # can re-initialize the pretrained weights, depending on the transformers version.
        self.classifier.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        self.classifier.bias.data.zero_()

        # init metric
        self.metric = load_metric('f1', self.hparams.task_name, experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S"))

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        outputs = self.plm(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # hidden_states = (embeddings, layer_1, ..., layer_L): concatenate the last L // 2 on the feature axis
        concatenated_hidden_states = torch.cat(outputs.hidden_states[-self.half_num_encoder:], -1)
        first_token_tensor = concatenated_hidden_states[:, 0]  # first (<s>/[CLS]) token of each sequence
        logits = self.classifier(self.dropout(first_token_tensor))

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs.loss
        logits = outputs.logits

        preds = logits.argmax(dim=-1)
        labels = batch['labels']

        return {
            "loss": loss,
            "y_true": labels,
            "y_pred": preds
        }

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx)

#     def training_step_end(self, batch_parts):
#         losses = torch.stack(batch_parts['loss']).mean()
#         self.log('tr_loss', losses, on_step=True, prog_bar=True)
#         return losses

    def training_epoch_end(self, outputs):
        loss = torch.stack([x['loss'] for x in outputs]).mean()
        self.log('tr_avg_loss', loss, on_epoch=True, prog_bar=True)

    def validation_epoch_end(self, outputs):
        y_true = torch.cat([x['y_true'] for x in outputs]).detach().cpu().numpy()
        y_pred = torch.cat([x['y_pred'] for x in outputs]).detach().cpu().numpy()
        loss = torch.stack([x['loss'] for x in outputs]).mean()

        self.log('val_avg_loss', loss, on_epoch=True, prog_bar=True)
        self.log_dict(self.metric.compute(predictions=y_pred, references=y_true, average='macro'), on_epoch=True, prog_bar=True)
        return loss

    def setup(self, stage=None) -> None:
        if stage == 'fit':
            # Get dataloader by calling it - train_dataloader() is called after setup() by default
            train_loader = self.train_dataloader()

            # Total optimizer steps = (global batches per epoch // gradient accumulation) * epochs
            tb_size = train_loader.batch_size * max(1, self.trainer.gpus)  # batch size the loader actually uses
            steps_per_epoch = (len(train_loader.dataset) // tb_size) // self.trainer.accumulate_grad_batches
            self.total_steps = steps_per_epoch * self.trainer.max_epochs

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.training_args.weight_decay,
            },
            {
                "params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters, lr=self.hparams.training_args.learning_rate, eps=self.hparams.training_args.adam_epsilon
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.hparams.training_args.warmup_steps, num_training_steps=self.total_steps
        )
        scheduler = {'scheduler': scheduler, 'interval': 'step', 'frequency': 1}
        return [optimizer], [scheduler]
```
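The pooling in `forward` can be checked with dummy tensors. This sketch assumes a 6-layer encoder with hidden size 768 and a batch of 2 sequences of length 16 (illustrative values); as with a Hugging Face model run with `output_hidden_states=True`, the tuple holds the embedding output plus one tensor per encoder layer.

```python
import torch

num_layers, hidden_size, batch_size, seq_len = 6, 768, 2, 16   # illustrative sizes
# Embedding output + one output per encoder layer, each (batch, seq_len, hidden)
hidden_states = tuple(torch.randn(batch_size, seq_len, hidden_size) for _ in range(num_layers + 1))

half_num_encoder = num_layers // 2
concatenated = torch.cat(hidden_states[-half_num_encoder:], dim=-1)  # (2, 16, 3 * 768)
first_token = concatenated[:, 0]                                     # one 2304-dim vector per sequence

print(concatenated.shape, first_token.shape)
# torch.Size([2, 16, 2304]) torch.Size([2, 2304])
```

Concatenating first and then taking position 0 gives the same vector as taking the first token of each of the last three layers and concatenating those.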

```python
from pytorch_lightning import seed_everything
seed_everything(training_args.seed)
```

```
Global seed set to 42
42
```

```python
dm.setup(stage='fit')
model = FeatureBasedSequneceClassification(model_args, training_args, dm.id2label, data_args.finetuning_task)
```

```
Using custom data configuration default-483d06c09187902b
Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-483d06c09187902b/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0)
Some weights of the model checkpoint at klue/roberta-small were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.decoder.bias', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
```


```python
from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=3, gpus=1)
```

```
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
```
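The number of optimizer steps the linear scheduler has to cover follows from the dataset size, batch size, device count, gradient accumulation, and epochs, with the epoch count multiplying the per-epoch steps. A worked example in plain Python; `total_training_steps` is a hypothetical helper, the dataset size and batch size are made up, and `max_epochs=3` with one GPU matches the `Trainer` above.

```python
def total_training_steps(num_examples, batch_size, num_devices, accumulate_grad_batches, max_epochs):
    # Optimizer steps per epoch: full global batches, then grouped by gradient accumulation
    global_batch = batch_size * max(1, num_devices)
    steps_per_epoch = (num_examples // global_batch) // accumulate_grad_batches
    return steps_per_epoch * max_epochs


# Hypothetical dataset of 10,000 examples, batch size 32, 1 GPU, no accumulation, 3 epochs
steps = total_training_steps(10_000, 32, 1, 1, 3)
print(steps)  # 936
```

Floor division drops a trailing partial batch, so the count can be slightly low; the only effect is that the learning rate reaches zero a few steps before training ends.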


```python
trainer.fit(model, dm)
```

```
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name       | Type         | Params
--------------------------------------------
0 | plm        | RobertaModel | 67.5 M
1 | dropout    | Dropout      | 0     
2 | classifier | Linear       | 16.1 K
--------------------------------------------
16.1 K    Trainable params
67.5 M    Non-trainable params
67.5 M    Total params
270.060   Total estimated model params size (MB)
```

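The 16.1 K trainable parameters are exactly the linear head. Assuming `klue/roberta-small` has 6 encoder layers with hidden size 768 and YNAT has 7 labels (values consistent with the summary above, but not printed in this notebook), the count works out as follows; the size estimate is roughly the parameter count times 4 bytes per float32.

```python
num_layers, hidden_size, num_labels = 6, 768, 7   # assumed config values

in_features = (num_layers // 2) * hidden_size      # concatenated [CLS] vectors: 3 * 768 = 2304
classifier_params = in_features * num_labels + num_labels  # weight + bias
print(classifier_params)  # 16135, i.e. the "16.1 K" above

# ~67.5 M float32 parameters at 4 bytes each ~ 270 MB, in line with the estimated size
print(67.5e6 * 4 / 1e6)  # 270.0
```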