In [1]:
import importlib
import labs

importlib.reload(labs)
from labs import *

# INIT. Processor 

In [2]:
from transformers import LayoutLMv3Processor

In [3]:
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

# DEF. Dataset and DataModule 
- TODO) D1~3 동일 구조로 변환한 json만들고 이를 사용하기

In [4]:
class D4Dataset(Dataset):
    def __init__(self, image_paths, label_map, processor):
        self.label2id = label_map
        self.processor = processor
        self.image_paths = image_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        encoding = prepare_example(image_path, self.processor)
        label = self.label2id[encoding['doc_class_str']]

        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "bbox": encoding["bbox"].squeeze(0),
            "pixel_values": encoding["pixel_values"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }

In [5]:
class D4DataModule(LightningDataModule):
    def __init__(
        self,
        train_paths,
        valid_paths,
        trial_paths,
        label_map,
        processor,
        batch_size=32,
        num_workers=4,
    ):
        super().__init__()
        self.train_paths = train_paths
        self.valid_paths = valid_paths
        self.trial_paths = trial_paths
        self.label_map = label_map
        self.processor = processor
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        if stage == "fit":
            self.train_ds = D4Dataset(self.train_paths, self.label_map, self.processor)
            self.valid_ds = D4Dataset(self.valid_paths, self.label_map, self.processor)
        if stage == "test" or stage is None:
            self.trial_ds = D4Dataset(self.trial_paths, self.label_map, self.processor)

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            collate_fn=default_data_collator
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=default_data_collator 
        )

    def test_dataloader(self):
        return DataLoader(
            self.trial_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=default_data_collator 
        )

# INIT. DM

In [6]:
image_paths = grep_files("/data/ephemeral/home/dataset/dataset-d3p/", exts=['jpg'])
label_path = "/data/ephemeral/home/dataset/dataset-d3p/doc_classes.json"
label2id, id2label = make_doc_class_mapper(label_path)

0it [00:00, ?it/s]

In [7]:
train_images, valid_images, trial_images = split_ds(image_paths)

data_module = D4DataModule(
    train_paths=train_images,
    valid_paths=valid_images,
    trial_paths=trial_images,
    label_map=label2id,
    processor=processor,
    batch_size=16,
    num_workers=8
)

In [8]:
# s0 = data_module.train_ds[0]
# print(s0.keys())
# print(s0['input_ids'].shape, s0['attention_mask'].shape, s0['bbox'].shape, s0['pixel_values'].shape, s0['labels'].shape)

# DEF) Model

In [9]:
from transformers import LayoutLMv3ForSequenceClassification as LyLmv3, LayoutLMv3Processor

In [10]:
class Lym(pl.LightningModule):
    def __init__(self, label2id, id2label):
        super().__init__()
        num_classes = len(label2id)
        self.model = LyLmv3.from_pretrained("microsoft/layoutlmv3-base", num_labels=num_classes)
        self.model.config.label2id = label2id
        self.model.config.id2label = id2label

        metrics = {
            "accuracy": Accuracy(task="multiclass", num_classes=num_classes),
            # "top-3 accuracy" : MulticlassAccuracy(num_classes=10, top_k=3),
            "roc_auc": AUROC(task="multiclass", num_classes=num_classes),
            "precision": Precision(task="multiclass", num_classes=num_classes, average="macro"),
            "recall": Recall(task="multiclass", num_classes=num_classes, average="macro"),
            "F1": F1Score(task="multiclass", num_classes=num_classes, average="macro"),
        }

        self.train_metrics = MetricCollection(metrics, prefix="train_")
        self.valid_metrics = MetricCollection(metrics, prefix="valid_")

    def forward(self, input_ids, attention_mask, bbox, pixel_values, labels=None):
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
            pixel_values=pixel_values,
            labels=labels
        )

    def feed(self, batch):
        return self(
            batch["input_ids"],
            batch["attention_mask"],
            batch["bbox"],
            batch["pixel_values"],
            batch["labels"]
        )

    def training_step(self, batch, batch_idx):
        labels = batch["labels"]
        outputs = self.feed(batch)
    
        self.train_metrics.update(outputs.logits, labels)
        
        self.log("train_loss", outputs.loss)
        for name, metric in self.train_metrics.items():
            self.log(name, metric.compute(), prog_bar=True)
        
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        labels = batch["labels"]
        outputs = self.feed(batch)

        self.valid_metrics.update(outputs.logits, labels)
        
        self.log("valid_loss", outputs.loss)
        for name, metric in self.valid_metrics.items():
            self.log(name, metric.compute(), prog_bar=True)
        return outputs.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=1e-5)
        
    def on_train_epoch_start(self):
        self.train_metrics.reset()

    def on_train_epoch_end(self):
        metrics = self.train_metrics.compute()
        for name, value in metrics.items():
            self.log(name, value)

    def on_validation_epoch_start(self):
        self.valid_metrics.reset()
    
    def on_validation_epoch_end(self):
        try:
            metrics = self.valid_metrics.compute()
            for k, v in metrics.items():
                self.log(k, v)
        except Exception as e:
            print(f"Metric compute error: {e}")

# RUN. Train

In [11]:
early_stopping = EarlyStopping(monitor='valid_loss', patience=5, mode='min')
model_checkpoint = ModelCheckpoint(monitor="valid_loss", mode="min", save_top_k=3)

trainer = pl.Trainer(
    accelerator="gpu",
    precision=16,
    max_epochs=20,
    callbacks=[model_checkpoint, early_stopping]
)

model = Lym(label2id, id2label)
trainer.fit(model, datamodule=data_module)

/data/ephemeral/home/.pyenv/versions/catch/lib/python3.12/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Some weights of LayoutLMv3ForSequenceClassification were not initialized from the model checkpoint at microsoft/layoutlmv3-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [12]:
# 13'15 ~ 14:00 15epoch