In [1]:
import importlib
import ds
import augs
import labs

importlib.reload(ds)
importlib.reload(labs)
importlib.reload(augs)

from ds import *
from labs import *
from augs import *

In [2]:
pl.seed_everything(777)

Seed set to 777


777

# ConvNextV2
- https://huggingface.co/facebook/convnextv2-large-22k-384

In [3]:
model_name = 'facebook/convnextv2-large-22k-384'

# DEF. Dataset and DataModule 

In [4]:
def prepare_example(image_path, processor, transform):
    # load image
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # load metas
    json_path = Path(image_path).with_suffix(".json")
    meta = load_json(json_path)

    augmented = transform(image=image)
    image = augmented['image']
    
    return processor(image, return_tensors="pt")

In [5]:
class D4Dataset(Dataset):
    def __init__(self, image_paths, targets, processor, transform=None):
        self.targets = targets
        self.processor = processor
        self.transform = transform
        self.image_paths = image_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        encoding = prepare_example(image_path, self.processor, self.transform)
        target = int(self.targets[os.path.basename(image_path)])

        return {
            "pixel_values": encoding["pixel_values"].squeeze(0),
            "labels": torch.tensor(target, dtype=torch.long)
        }

In [6]:
class D4DataModule(LightningDataModule):
    def __init__(
        self,
        train_paths,
        valid_paths,
        trial_paths,
        target_dict,
        processor,
        batch_size=24,
        num_workers=6,
    ):
        super().__init__()
        self.train_paths = train_paths
        self.valid_paths = valid_paths
        self.trial_paths = trial_paths
        self.targets = target_dict
        self.processor = processor
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transforms = Transforms(target_size=384)

    def setup(self, stage=None):
        if stage == "fit":
            self.train_ds = D4Dataset(self.train_paths, 
                                      self.targets, 
                                      self.processor,
                                      self.transforms.make(100))
            self.valid_ds = D4Dataset(self.valid_paths, 
                                      self.targets, 
                                      self.processor,
                                      self.transforms.make(100))
        if stage == "test" or stage is None:
            self.trial_ds = D4Dataset(self.trial_paths, self.targets, self.processor)

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            collate_fn=default_data_collator
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=default_data_collator 
        )

    def test_dataloader(self):
        return DataLoader(
            self.trial_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=default_data_collator 
        )

# INIT. DM

In [7]:
image_paths = grep_files("/data/ephemeral/home/dataset/dtc/train", exts=['jpg'])
target_dict = load_csv_targets("/data/ephemeral/home/dataset/dtc/train.csv")
label_path = "/data/ephemeral/home/dataset/dtc/doc_classes.json"
label2id, id2label = make_doc_class_mapper(label_path)

0it [00:00, ?it/s]

In [8]:
from transformers import AutoImageProcessor
processor = AutoImageProcessor.from_pretrained('facebook/convnextv2-large-22k-384')

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [9]:
train_images, valid_images, trial_images = split_ds(image_paths,  train_ratio=0.90,  valid_ratio=0.10, trial_ratio=0, seed=777)

data_module = D4DataModule(
    train_paths=train_images,
    valid_paths=valid_images,
    trial_paths=trial_images,
    target_dict=target_dict,
    processor=processor,
    batch_size=16,
    num_workers=8
)

# DEF) Model

In [11]:
from transformers import ConvNextV2ForImageClassification

In [12]:
class CNN(pl.LightningModule):
    def __init__(self, label2id, id2label):
        super().__init__()
        num_classes = len(label2id)
        self.model = ConvNextV2ForImageClassification.from_pretrained('facebook/convnextv2-large-22k-384', num_labels=num_classes, ignore_mismatched_sizes=True)
        self.model.train()
        self.model.config.label2id = label2id
        self.model.config.id2label = id2label

        self.loss_fn = DiceLoss()
        self.dice_weight = 0.4

        metrics = {
            "accuracy": Accuracy(task="multiclass", num_classes=num_classes),
            "per-class-accuracy" : MulticlassAccuracy(num_classes=num_classes, average=None),
            "roc_auc": AUROC(task="multiclass", num_classes=num_classes),
            "precision": Precision(task="multiclass", num_classes=num_classes, average="macro"),
            "recall": Recall(task="multiclass", num_classes=num_classes, average="macro"),
            "F1": F1Score(task="multiclass", num_classes=num_classes, average="macro"),
        }

        self.train_metrics = MetricCollection(metrics, prefix="train_")
        self.valid_metrics = MetricCollection(metrics, prefix="valid_")

    def forward(self, pixel_values, labels=None):
        return self.model(pixel_values=pixel_values, labels=labels)
        
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5, weight_decay=1e-4)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.25, patience=3)

        scheduler_delay = ConstantLR(optimizer, factor=1.0, total_iters=30)
        scheduler_warmup = LinearLR(optimizer, start_factor=0.01, total_iters=10)
        total_training_steps = 1000 # 총 학습 스텝 수
        scheduler_main = CosineAnnealingLR(optimizer, T_max=total_training_steps - 40, eta_min=1e-6)
        final_scheduler = SequentialLR(
            optimizer,
            schedulers=[scheduler_delay, scheduler_warmup, scheduler_main],
            milestones=[40, 50]
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": final_scheduler,
                "interval": "step",
                "monitor": "valid_loss"
            }
        }
     
    def feed(self, batch):
        return self(batch["pixel_values"], batch["labels"])
        
    def training_step(self, batch, batch_idx):
        labels = batch["labels"]
        outputs = self.feed(batch)
    
        self.train_metrics.update(outputs.logits, labels)
        
        self.log("train_loss", outputs.loss, prog_bar=True)
        for name, metric in self.train_metrics.items():
            if name == 'train_per-class-accuracy':
                continue
            self.log(name, metric.compute(), prog_bar=True)

        # lr 기록
        optimizer = self.optimizers().optimizer
        current_lr = optimizer.param_groups[0]['lr']
        self.log('lr', current_lr)

        # custom loss
        loss = outputs.loss + self.loss_fn(outputs.logits, labels) * self.dice_weight
        
        return loss

    def validation_step(self, batch, batch_idx):
        labels = batch["labels"]
        outputs = self.feed(batch)

        self.valid_metrics.update(outputs.logits, labels)
        self.log("valid_loss", outputs.loss, prog_bar=True)
        for name, metric in self.valid_metrics.items():
            if name == 'valid_per-class-accuracy':
                continue
            self.log(name, metric.compute(), prog_bar=True)
            
        return outputs.loss
   
    def on_train_epoch_start(self):
        self.train_metrics.reset()
        
    def on_validation_epoch_start(self):
        self.valid_metrics.reset()

    def on_train_epoch_end(self):
        metrics = self.train_metrics.compute()
        for name, value in metrics.items():
            if name == 'train_per-class-accuracy':
                continue
            self.log(name, value)
                
        per_class_acc = metrics['train_per-class-accuracy']
        for i, acc in enumerate(per_class_acc):
            label_name = self.model.config.id2label[i]
            self.log(f'train_acc_class_{label_name}', acc)
    
    def on_validation_epoch_end(self):
        try:
            metrics = self.valid_metrics.compute()
            for name, value in metrics.items():
                if name == 'valid_per-class-accuracy':
                    continue
                self.log(name, value)
                
            per_class_acc = metrics['valid_per-class-accuracy']
            for i, acc in enumerate(per_class_acc):
                label_name = self.model.config.id2label[i]
                self.log(f'valid_acc_class_{label_name}', acc)
        except Exception as e:
            print(f"Metric compute error: {e}")

# Init Dashboard

In [13]:
exp_name = 'exp-convnext-large-384-easy-rotator-dice-loss-add'
wandb.init(project='docsy', name=exp_name)
wandb_logger = WandbLogger()

[34m[1mwandb[0m: Currently logged in as: [33mcatchy[0m ([33mcat2oon[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


# RUN. Train

In [14]:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torchmetrics import MetricCollection, Accuracy, F1Score, Precision, Recall, AUROC
from torchmetrics.classification import MulticlassAccuracy
from torch.optim.lr_scheduler import ReduceLROnPlateau, ConstantLR, LinearLR, CosineAnnealingLR, CosineAnnealingWarmRestarts, SequentialLR

In [15]:
early_stopping = EarlyStopping(monitor='valid_loss', patience=10, mode='min')
model_checkpoint = ModelCheckpoint(monitor="valid_loss", mode="min", save_top_k=2)

trainer = pl.Trainer(
    accelerator="gpu",
    precision="16-mixed",
    max_epochs=100,
    logger=wandb_logger,
    reload_dataloaders_every_n_epochs=1, 
    callbacks=[model_checkpoint]
    # callbacks=[model_checkpoint, early_stopping]
)

model = CNN(label2id, id2label)
trainer.fit(model, datamodule=data_module)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Some weights of ConvNextV2ForImageClassification were not initialized from the model checkpoint at facebook/convnextv2-large-22k-384 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([17]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 1536]) in the checkpoint and torch.Size([17, 1536]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/tor

Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [16]:
trainer.save_checkpoint(f"./{exp_name}-last_epoch.ckpt")

In [17]:
wandb.finish()

0,1
epoch,▁▁▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇████
lr,█▇▄▁▂▄▆▇▇█▅▄▃▆██▇▄▃▁▁▃▅██▆▄▄▃▂▂▁▁▁▄▇█▆▃▁
train_F1,▁▁▃▃▆▅▂▇▆▇▄▇▇▄█▅▇▇▄█▂▇▇▃▇▄▄█▄▃█▅▇█▇█▄███
train_acc_class_account_number,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_acc_class_application_for_payment_of_pregnancy_medical_expenses,▁▆▆▆▆▆▆▆▆████████████████▆████████████▆█
train_acc_class_car_dashboard,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_acc_class_confirmation_of_admission_and_discharge,▅▃▇▆▆▅▆▁▅▃▆▆█▆▆█▂▆▇▅▇▆▇▇▆▃▅██████▅▇▇▇▇▇█
train_acc_class_diagnosis,▅▁▅▅▅▁▁▁▁▅▁▅▁▅█▁▅▁▁▁▁▁▅▁▁▁▁▅▁▁▅▅▅▅▁▁▁▁▁▁
train_acc_class_driver_lisence,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_acc_class_medical_bill_receipts,▁████████████████████▁██▁███████████████

0,1
epoch,99.0
lr,0.0
train_F1,0.91176
train_acc_class_account_number,1.0
train_acc_class_application_for_payment_of_pregnancy_medical_expenses,0.83333
train_acc_class_car_dashboard,1.0
train_acc_class_confirmation_of_admission_and_discharge,0.875
train_acc_class_diagnosis,0.625
train_acc_class_driver_lisence,1.0
train_acc_class_medical_bill_receipts,1.0


In [18]:
raise Exception("STOP HERE")

Exception: STOP HERE

---