In [None]:
import os
from pathlib import Path
import zipfile

# Location of the dataset archive on Google Drive and of the local extraction target.
zip_path = Path(os.getcwd(), 'drive', 'MyDrive', 'Colab Notebooks', 'OHT', '4_classes.zip')
data_path = Path(os.getcwd(), 'data')

# The archive is read from Google Drive, so Drive has to be mounted before this
# cell can work on a fresh kernel. Mount it on demand when the archive is not
# visible yet instead of relying on a later cell having been run first.
if not zip_path.is_file():
    try:
        from google.colab import drive
    except ImportError as exc:
        raise FileNotFoundError(f'Dataset archive not found: {zip_path}') from exc
    drive.mount(str(Path(os.getcwd(), 'drive')))

if not zip_path.is_file():
    raise FileNotFoundError(f'Dataset archive not found: {zip_path}')

# Create the target directory only once the archive is known to exist.
data_path.mkdir(parents=True, exist_ok=True)

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(data_path)

In [None]:
!pip install lightning
!pip install torch
!pip install transformers

Collecting lightning
  Downloading lightning-2.1.0-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities<2.0,>=0.8.0 (from lightning)
  Downloading lightning_utilities-0.9.0-py3-none-any.whl (23 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.2.0-py3-none-any.whl (805 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m805.2/805.2 kB[0m [31m18.2 MB/s[0m eta [36m0:00:00[0m
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.1.0-py3-none-any.whl (774 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m774.6/774.6 kB[0m [31m20.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: lightning-utilities, torchmetrics, pytorch-lightning, lightning
Successfully installed lightning-2.1.0 lightning-utilities-0.9.0 pytorch-lightning-2.1.0 torchmetrics-1.2.0
Collecti

In [None]:
from typing import Union
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import os
from PIL import Image
import lightning.pytorch as pl

class OHTDataset(Dataset):
    """Image-classification dataset read from a ``root_dir/<class_name>/*.jpg`` tree.

    Every sub-directory of ``root_dir`` is treated as one class. Class ids are
    assigned following the sorted order of the directory names.

    Args:
        root_dir: Directory that contains one sub-directory per class.
        processor: Callable invoked as
            ``processor(images=<PIL image>, return_tensors="pt")``; its result
            must expose a ``'pixel_values'`` entry.

    Attributes:
        dataset: List of ``(image_path, label_id)`` tuples.
        config: Dict holding the ``'id2label'`` and ``'label2id'`` mappings.
    """

    def __init__(self, root_dir, processor):
        super().__init__()
        self.processor = processor
        root_dir = Path(root_dir)
        # Only directories are classes: a stray file (e.g. an OS metadata file)
        # would otherwise become an empty class and shift every label id.
        classes = sorted(entry.name for entry in root_dir.iterdir() if entry.is_dir())
        self.config = {
            'id2label': dict(enumerate(classes)),
            'label2id': {name: idx for idx, name in enumerate(classes)},
        }
        # Sort the files so that index -> sample is stable across runs and machines.
        self.dataset = [
            (image_path, label)
            for label, name in enumerate(classes)
            for image_path in sorted(Path(root_dir, name).glob("*.jpg"))
        ]

    def __len__(self):
        """Return the number of ``(image_path, label_id)`` samples."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(processor_output, label_tensor)``."""
        image_path, label = self.dataset[index]
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        # Use the processor handed to this instance, not a notebook-level global.
        inputs = self.processor(images=image, return_tensors="pt")
        # Drop only the leading axis (presumably the batch axis the processor adds
        # for a single image); a bare squeeze() would also drop any other size-1 axis.
        inputs['pixel_values'] = inputs['pixel_values'].squeeze(0)
        return inputs, torch.tensor(label)

class OHTDataModule(pl.LightningDataModule):
    """Lightning data module over a ``root_dir/{train,val[,test]}/<class>/*.jpg`` tree.

    Args:
        root_dir: Directory containing the ``train`` and ``val`` sub-directories
            (and optionally ``test``).
        batch_size: Batch size used by every dataloader.
        processor: Image processor forwarded to each ``OHTDataset``.
        num_workers: Number of worker processes used by every dataloader.
    """

    def __init__(self, root_dir, batch_size, processor, num_workers):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.processor = processor
        # Split locations are plain state, so they are set here. Lightning runs
        # prepare_data() on a single process only, which means attributes
        # assigned there are missing on the other ranks of a multi-device run.
        self.train_dir = Path(self.root_dir, 'train')
        self.val_dir = Path(self.root_dir, 'val')
        self.test_dir = Path(self.root_dir, 'test')

    def prepare_data(self):
        """Nothing to download or preprocess; kept so existing callers still work."""

    def setup(self, stage: str = None):
        """Build the datasets for every split.

        ``stage`` is accepted to match Lightning's hook signature but is not used
        to select splits: all of them are built on every call.
        """
        self.train_ds = OHTDataset(root_dir=self.train_dir, processor=self.processor)
        self.val_ds = OHTDataset(root_dir=self.val_dir, processor=self.processor)
        # Without a dedicated test split, evaluate on the validation images.
        eval_dir = self.test_dir if self.test_dir.is_dir() else self.val_dir
        self.test_ds = OHTDataset(root_dir=eval_dir, processor=self.processor)

    def train_dataloader(self):
        """Return the shuffled dataloader over the training split."""
        return DataLoader(dataset=self.train_ds,
                          batch_size=self.batch_size,
                          shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        """Return the ordered dataloader over the validation split."""
        return DataLoader(dataset=self.val_ds,
                          batch_size=self.batch_size,
                          shuffle=False,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the ordered dataloader over the test split (or validation fallback)."""
        return DataLoader(dataset=self.test_ds,
                          batch_size=self.batch_size,
                          shuffle=False,
                          num_workers=self.num_workers)

In [None]:
import lightning.pytorch as pl
from transformers import AutoModelForImageClassification
from torch.optim import Adam
import torch.nn.functional as F
import torch.nn as nn
import torchmetrics
from torchmetrics import Metric


class DiTModel(pl.LightningModule):
    """Fine-tunes ``microsoft/dit-base-finetuned-rvlcdip`` for ``n_classes``-way classification.

    Args:
        n_classes: Number of target classes; sizes the classification head and
            the metrics.
        learning_rate: Learning rate handed to Adam (default ``1e-3``).
    """

    def __init__(self, n_classes, learning_rate=1e-3):
        super().__init__()
        self.learning_rate = learning_rate
        # Let transformers build the new head: passing num_labels also updates
        # the model config, so a later save_pretrained() writes a config that
        # matches the weights. ignore_mismatched_sizes skips the checkpoint's
        # own head, whose shape differs from the new one.
        self.model = AutoModelForImageClassification.from_pretrained(
            "microsoft/dit-base-finetuned-rvlcdip",
            num_labels=n_classes,
            ignore_mismatched_sizes=True,
        )
        # One metric instance per stage so accumulated state never mixes between
        # training, validation and testing.
        for stage in ('train', 'val', 'test'):
            setattr(self, f'{stage}_accuracy',
                    torchmetrics.Accuracy(task='multiclass', num_classes=n_classes))
            # Macro averaging: micro-averaged F1 is identical to accuracy for
            # single-label multiclass data and would add no information.
            setattr(self, f'{stage}_f1_score',
                    torchmetrics.F1Score(task='multiclass', num_classes=n_classes, average='macro'))

    def forward(self, inputs):
        """Run the wrapped model on ``inputs`` (a mapping unpacked as keyword arguments)."""
        outputs = self.model(**inputs)
        return outputs

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch and log the ``train_*`` values."""
        return self._step_and_log('train', batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and log the ``val_*`` values."""
        return self._step_and_log('val', batch, batch_idx)

    def test_step(self, batch, batch_idx):
        """Compute the loss for one test batch and log the ``test_*`` values."""
        return self._step_and_log('test', batch, batch_idx)

    def _step_and_log(self, stage, batch, batch_idx):
        """Run the shared step for ``stage`` and log its loss, accuracy and F1 per epoch."""
        loss, logits, labels = self._common_step(batch=batch, batch_idx=batch_idx)
        accuracy = getattr(self, f'{stage}_accuracy')
        f1_score = getattr(self, f'{stage}_f1_score')
        accuracy(logits, labels)
        f1_score(logits, labels)
        # Logging the metric objects (not per-batch values) makes Lightning
        # compute them over the whole epoch and reset them afterwards.
        self.log_dict({f'{stage}_accuracy': accuracy, f'{stage}_f1_score': f1_score, f'{stage}_loss': loss},
                      on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def _common_step(self, batch, batch_idx):
        """Return ``(cross-entropy loss, logits, labels)`` for an ``(inputs, labels)`` batch."""
        inputs, labels = batch
        outputs = self.forward(inputs=inputs)
        logits = outputs.logits
        loss = F.cross_entropy(logits, labels)
        return loss, logits, labels

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return Adam(self.parameters(), lr=self.learning_rate)

In [None]:
from transformers import AutoImageProcessor
from torch.utils.data import DataLoader
import tqdm as notebook_tqdm
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

# Tunable run settings.
SEED = 42
BATCH_SIZE = 2
NUM_WORKERS = 1

# Seed Python, NumPy and torch (and dataloader workers) so that the randomly
# initialised classification head and the shuffling are repeatable.
pl.seed_everything(SEED, workers=True)

root_dir = Path(os.getcwd(), 'data')
processor = AutoImageProcessor.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
data_module = OHTDataModule(root_dir=root_dir, batch_size=BATCH_SIZE, processor=processor, num_workers=NUM_WORKERS)
train_dir = Path(root_dir, 'train')

# Take the class list from the dataset's own label map so the size of the
# classification head always agrees with the label ids the dataset emits.
class_names = list(OHTDataset(root_dir=train_dir, processor=processor).config['label2id'])
n_class = class_names  # previous name of this list, kept for cells that still use it
pl_model = DiTModel(len(class_names))

model_chk_dir = Path(os.getcwd(), 'model')
model_chk_dir.mkdir(parents=True, exist_ok=True)
callbacks = [
    EarlyStopping(monitor='val_loss', mode='min', patience=10),
    ModelCheckpoint(dirpath=model_chk_dir,
                    monitor='val_loss',
                    mode='min',
                    save_top_k=3,
                    # {epoch} keeps checkpoints from different epochs apart even
                    # when their rounded metrics are identical.
                    filename='model-{epoch}-{val_accuracy:.2f}-{val_loss:.2f}',
                    save_on_train_epoch_end=True)
]
trainer = pl.Trainer(accelerator='auto', devices='auto', min_epochs=1, max_epochs=50, callbacks=callbacks)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
# Build a dataset over the training split only to inspect its label maps.
d = OHTDataset(root_dir=Path(train_dir), processor=processor)
d.config

{'id2label': {0: 'A', 1: 'B', 2: 'C', 3: 'O'},
 'label2id': {'A': 0, 'B': 1, 'C': 2, 'O': 3}}

In [None]:
# Train the model on the splits provided by the data module.
trainer.fit(model=pl_model, datamodule=data_module)

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:630: Checkpoint directory /content/model exists and is not empty.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name     | Type                       | Params
--------------------------------------------------------
0 | model    | BeitForImageClassification | 85.8 M
1 | accuracy | MulticlassAccuracy         | 0     
2 | f1_score | MulticlassF1Score          | 0     
--------------------------------------------------------
85.8 M    Trainable params
0         Non-trainable params
85.8 M    Total params
343.244   Total estimated model params size (MB)
INFO:lightning.pytorch.callbacks.model_summary:
  | Name     | Type                       | Params
--------------------------------------------------------
0 | model    | BeitForImageClassification | 85.8 M
1 | accuracy | MulticlassAccuracy         | 0     
2 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
# Score the model currently held in memory on the validation split.
trainer.validate(model=pl_model, datamodule=data_module)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_accuracy': 0.904411792755127,
  'val_f1_score': 0.904411792755127,
  'val_loss': 0.3112561106681824}]

In [None]:
# The name `test_model` is never defined in this notebook; the model trained
# above is `pl_model`, so that is the one evaluated here.
if type(data_module).test_dataloader is not pl.LightningDataModule.test_dataloader:
    test_results = trainer.test(model=pl_model, datamodule=data_module)
else:
    # The data module provides no test loader: score the validation split instead.
    data_module.prepare_data()
    data_module.setup('test')
    test_results = trainer.test(model=pl_model, dataloaders=data_module.val_dataloader())
test_results

In [None]:
version = 'v1'

# Make the exported Hugging Face config describe the fine-tuned head. Otherwise
# it may still carry the label set of the pretrained checkpoint, which no longer
# matches the classifier weights being saved.
label_maps = OHTDataset(root_dir=train_dir, processor=processor).config
hf_config = pl_model.model.config
n_outputs = pl_model.model.classifier.out_features
if len(label_maps['id2label']) == n_outputs:
    hf_config.id2label = dict(label_maps['id2label'])
    hf_config.label2id = dict(label_maps['label2id'])
else:
    # Label map and head disagree: at least record the right number of labels.
    hf_config.num_labels = n_outputs

# NOTE(review): this exports the weights currently held by `pl_model`, which
# after `trainer.fit` are those of the last epoch, not necessarily those of the
# best checkpoint. Confirm this is intended.
pl_model.model.save_pretrained(Path(model_chk_dir, f'model_{version}'))
torch.save(pl_model.state_dict(), Path(model_chk_dir, f'state_dict{version}'))

In [None]:
import shutil

from google.colab import drive
drive.mount('/content/drive')

# Copy the best checkpoint recorded by the ModelCheckpoint callback rather than
# a hand-typed file name, which only exists for one particular past run.
drive_dir = Path('/content/drive/MyDrive/Colab Notebooks/OHT')
drive_dir.mkdir(parents=True, exist_ok=True)

best_ckpt = trainer.checkpoint_callback.best_model_path
if not best_ckpt:
    raise RuntimeError('No checkpoint has been saved yet; run trainer.fit first.')
shutil.copy(best_ckpt, drive_dir)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
