In [2]:
# 2024-09-08
# Wonseok Hwang
# CC-BY-NC 4.0 International Licens

!pip install --quiet datasets
!pip install --quiet transformers
!pip install --quiet lightning

import datasets
import torch
from transformers import AutoModel, AutoTokenizer

dataset_card = "lbox/lbox_open"
task = "casename_classification"
data = datasets.load_dataset(dataset_card, task)
data_t = data["train"]
CASENAMES = tuple(sorted(set(data_t["casename"])))

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading data:   0%|          | 0.00/5.66M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/739k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/743k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/890k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating test2 split:   0%|          | 0/1294 [00:00<?, ? examples/s]

In [6]:
import pytorch_lightning as pl


class CasenameDataModule(pl.LightningDataModule):
  def __init__(self, casenames, tokenizer, data, batch_size=36, max_input_len=512):
    super().__init__()
    self.casenames = casenames
    self.tokenizer = tokenizer
    self.data = data
    self.batch_size = batch_size
    self.max_input_len = max_input_len


  def setup(self, stage):
    pass

  def train_dataloader(self):
    return torch.utils.data.DataLoader(
        self.data["train"],
        batch_size=self.batch_size,
        shuffle=True,
        collate_fn=self._collate_fn
    )

  def val_dataloader(self):
    return torch.utils.data.DataLoader(
        self.data["validation"],
        batch_size=self.batch_size * 2,
        shuffle=False,
        collate_fn=self._collate_fn
    )

  def test_dataloader(self):
    return torch.utils.data.DataLoader(
        self.data["test"],
        batch_size=self.batch_size * 2,
        shuffle=False,
        collate_fn=self._collate_fn
    )

  # _collate_fn : batch로 묶을 때 필요한 전처리를 수행하게 해주는 함수
  def _collate_fn(self, batch):
    xs = [x["facts"] for x in batch]
    ys = [x["casename"] for x in batch]
    ys = [CASENAMES.index(y) for y in ys]


    input_features = self.tokenizer(xs, padding=True, truncation=True, return_tensors="pt")
    input_ids = input_features["input_ids"]
    attention_mask = input_features["attention_mask"]
    label = torch.tensor(ys)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "label": label,
    }

In [7]:
import numpy as np
from functools import reduce
import random

class CasenameClassifier(pl.LightningModule):
  def __init__(self, casenames, backbone, learning_rate=0.0001):
    super().__init__()
    self.casenames = casenames
    self.backbone = backbone
    self.pooler = torch.nn.Linear(768, 100)
    self.learning_rate = learning_rate
    self.loss_fn = torch.nn.CrossEntropyLoss()

    self.val_outputs = []
    self.test_outputs = []

  def forward(self, batch):
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]

    encoded = self.backbone(input_ids, attention_mask)
    last = encoded.last_hidden_state # [batch, 768]
    logit = self.pooler(last[:, 0]) # [batch, 100]

    return logit

  def training_step(self, batch, batch_idx):
    logit = self.forward(batch)

    label = batch["label"]
    loss = self.loss_fn(logit, label)
    return {"loss": loss}

  def validation_step(self, batch, batch_idx):
    out = self._evaluation_step(batch)
    self.val_outputs.append(out)

  def on_validation_epoch_end(self):
    outputs = self.val_outputs
    self._evaluation_epoch_end(outputs)
    self.val_outputs = []

  def test_step(self, batch, batch_idx):
    out = self._evaluation_step(batch)
    self.test_outputs.append(out)

  def on_test_epoch_end(self):
    outputs = self.test_outputs
    self._evaluation_epoch_end(outputs)
    self.test_outputs = []

  def _evaluation_step(self, batch):
    logit = self.forward(batch)
    label = batch["label"]
    loss = self.loss_fn(logit, label)
    gts = [self.casenames[label_id] for label_id in label] # groun truth
    pr_label_ids = logit.argmax(-1)
    prs = [self.casenames[label_id] for label_id in pr_label_ids] # predict

    return {
        "loss": loss.item(),
        "gts": gts,
        "prs": prs,
    }

  def _evaluation_epoch_end(self, outputs):
    ave_loss = np.mean([x["loss"] for x in outputs])
    gts = reduce(lambda x,y: x + y, [x["gts"] for x in outputs], [])
    prs = reduce(lambda x,y: x + y, [x["prs"] for x in outputs], [])
    acc = sum([gt==pr for gt, pr in zip(gts, prs)]) / len(gts)

    target_ids = random.sample(range(len(gts)), 5)
    # target_ids = [0,1,2,3,4]
    print("="*50)
    print(f"ave_loss: {ave_loss}")
    print(f"ACC: {acc}")
    print("GT" + "-"*40)
    for target_id in target_ids:
      print(f"GT: {gts[target_id]}\t\t\tPR: {prs[target_id]}")


  def configure_optimizers(self):
      grouped_params = [
          {
              "params": list(filter(lambda p: p.requires_grad, self.parameters())),
              "lr": self.learning_rate,
          },

      ]

      optimizer = torch.optim.AdamW(
          grouped_params,
          lr=self.learning_rate,
      )
      return {"optimizer": optimizer}




In [8]:
backbone_card = "distilbert-base-multilingual-cased"
backbone = AutoModel.from_pretrained(backbone_card)
tokenizer = AutoTokenizer.from_pretrained(backbone_card)

data_module = CasenameDataModule(CASENAMES, tokenizer, data, batch_size=24, max_input_len=256)
model = CasenameClassifier(CASENAMES, backbone, learning_rate=0.0001)

# trainer
n_gpus = torch.cuda.device_count()
trainer = pl.Trainer(
    max_epochs=3,
    devices=n_gpus,
    fast_dev_run=not True,
    limit_train_batches=0.2,
    limit_val_batches=0.2,
)



config.json:   0%|          | 0.00/466 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/542M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [9]:
n_gpus

1

In [10]:
trainer.fit(model, data_module)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name     | Type             | Params | Mode 
------------------------------------------------------
0 | backbone | DistilBertModel  | 134 M  | eval 
1 | pooler   | Linear           | 76.9 K | train
2 | loss_fn  | CrossEntropyLoss | 0      | train
------------------------------------------------------
134 M     Trainable params
0         Non-trainable params
134 M     Total params
539.244   Total estimated model params size (MB)
2         Modules in train mode
92        Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

ave_loss: 4.635679721832275
ACC: 0.0
GT----------------------------------------
GT: 공연음란			PR: 매매대금
GT: 건물명도(인도)			PR: 도로교통법위반(무면허운전), 도로교통법위반(음주운전)
GT: 건물명도(인도)			PR: 도로교통법위반(무면허운전), 도로교통법위반(음주운전)
GT: 공무집행방해, 업무방해			PR: 매매대금
GT: 공사대금			PR: 교통사고처리특례법위반(치상)


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

ave_loss: 3.6904362440109253
ACC: 0.11458333333333333
GT----------------------------------------
GT: 근저당권말소			PR: 매매대금반환
GT: 게임산업진흥에관한법률위반			PR: 도로교통법위반(사고후미조치), 특정범죄가중처벌등에관한법률위반(도주치상)
GT: 공무집행방해			PR: 도로교통법위반(사고후미조치), 특정범죄가중처벌등에관한법률위반(도주치상)
GT: 교통사고처리특례법위반(치상), 도로교통법위반(음주운전)			PR: 교통사고처리특례법위반(치상)
GT: 근로기준법위반			PR: 매매대금반환


Validation: |          | 0/? [00:00<?, ?it/s]

ave_loss: 2.409738838672638
ACC: 0.3645833333333333
GT----------------------------------------
GT: 교통사고처리특례법위반(치상)			PR: 교통사고처리특례법위반(치상)
GT: 건물명도(인도)			PR: 소유권말소등기
GT: 공무집행방해			PR: 공무집행방해
GT: 공무집행방해, 업무방해			PR: 공무집행방해, 업무방해
GT: 건물인도			PR: 소유권말소등기


Validation: |          | 0/? [00:00<?, ?it/s]

ave_loss: 1.7840320467948914
ACC: 0.5
GT----------------------------------------
GT: 공무집행방해, 업무방해			PR: 공무집행방해, 업무방해
GT: 근로기준법위반, 근로자퇴직급여보장법위반			PR: 근로기준법위반, 근로자퇴직급여보장법위반
GT: 교통사고처리특례법위반(치상)			PR: 교통사고처리특례법위반(치상)
GT: 강제추행			PR: 강제추행
GT: 근로기준법위반, 근로자퇴직급여보장법위반			PR: 근로기준법위반, 근로자퇴직급여보장법위반


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=3` reached.
