In [2]:
from transformers import ViTFeatureExtractor, ViTModel, ViTForImageClassification, ViTConfig
from PIL import Image
import requests

In [3]:
import os
import re
import numpy as np
import pandas as pd
import cv2

# pytorch related imports
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
#from torchvision.datasets import CIFAR10
#from torchvision import transforms

# lightning related imports
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# sklearn related imports
from sklearn.metrics import precision_recall_curve
from sklearn.preprocessing import label_binarize

# Creating Dataset class

In [21]:
class firesmoke_image_dataset(torch.utils.data.Dataset):
    def __init__(self, csv_file_path, image_dir, transform=None):
        self.data_frame = pd.read_csv(csv_file_path)
        self.image_dir = image_dir
        self.transform = transform
        self.feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
        
    def __len__(self):
        return len(self.data_frame)
        pass
    
    def __getitem__(self, index):
        
        row = self.data_frame.iloc[index]
        image_path = self.image_dir+"/"+row['image_name'].strip()
        
        img = cv2.imread(image_path)
       
        #img = np.einsum("ijk->kij", img)
        img = self.feature_extractor(img)['pixel_values'][0] # Don't forget this to apply while inference
        #img = image = Image.open(image_path)
        label = row[1]#.to_numpy(dtype=np.float32)
        if self.transform:
            img = transform(img)
            #img = feature_extractor(img)['pixel_values'][0]
            
        return [img, label]
    

## test

In [22]:
dataset = firesmoke_image_dataset("./raw_data/image_label.csv",
                                  "./raw_data/images/" )
dl = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

b = None
for batch in dl:
    b = batch
    break

#op = model(pixel_values=b[0])

In [23]:
#op.pooler_output.shape

## creating pl datamodule

In [24]:
class FireSmokeDataModule(pl.LightningDataModule):
    def __init__(self, csv_file_path, data_dir: str = './', batch_size=5):
        super().__init__()
        self.csv_file_path = csv_file_path
        self.data_dir = data_dir
        self.batch_size = batch_size

#         self.transform = transforms.Compose([
#             transforms.ToTensor(),
#             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
#         ])
        
#         self.dims = (3, 32, 32)
#         self.num_classes = 10

    def prepare_data(self):
        # download 
        pass

    def setup(self, stage=None):
        self.firesmoke_ds = firesmoke_image_dataset(self.csv_file_path, self.data_dir)
            

    def train_dataloader(self):
        return DataLoader(self.firesmoke_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.firesmoke_ds, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.firesmoke_ds, batch_size=self.batch_size)

# Creating pl lightningModule for vision transformer

In [25]:
class ViTfinetune(pl.LightningModule):
    def __init__(self, num_classes, learning_rate=2e-4):
        super().__init__()
        self.save_hyperparameters()
        pooler_dims = 768
        self.vit_back_bone = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.fc  = nn.Sequential(
                              nn.Dropout(0.5),
                              nn.Linear(768, num_classes)
                             )
        self.learning_rate = learning_rate
    # will be used during inference
    def forward(self, x):
        
        x = self.vit_back_bone(pixel_values=x)
        x = x.pooler_output
        x = self.fc(x)
        x = F.log_softmax(x, dim=1)
        return x

    # logic for a single training step
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # training metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        
        return loss

    # logic for a single validation step
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    # logic for a single testing step
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [26]:
ft = ViTfinetune(num_classes=5)
lsm = ft(b[0])

In [27]:
torch.argmax(lsm , dim=1)

tensor([1, 2])

In [28]:
lsm

tensor([[-1.3756, -1.3221, -1.7968, -1.8398, -1.8574],
        [-1.5885, -1.4880, -1.4185, -1.8343, -1.7829]],
       grad_fn=<LogSoftmaxBackward>)

# Creating callbacks

In [39]:
early_stop_callback = EarlyStopping(
   monitor='val_loss',
   patience=3,
   verbose=False,
   mode='min'
)

MODEL_CKPT_PATH = 'model/'
MODEL_CKPT = 'model-{epoch:02d}-{val_loss:.2f}'

checkpoint_callback = ModelCheckpoint(
    dirpath=MODEL_CKPT_PATH,
    monitor='val_loss',
    filename=MODEL_CKPT ,
    save_top_k=3,
    mode='min')

# Lightning Trainer 

In [32]:
CSV_FILE="./raw_data/image_label.csv"
IMAGE_DIR = "./raw_data/images/"
# Init our data pipeline
dm = FireSmokeDataModule(csv_file_path=CSV_FILE, data_dir=IMAGE_DIR, batch_size=32)
# To access the x_dataloader we need to call prepare_data and setup.
dm.prepare_data()
dm.setup()

# Samples required by the custom ImagePredictionLogger callback to log image predictions.
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape





(torch.Size([32, 3, 224, 224]), torch.Size([32]))

In [37]:
model=None

In [None]:
# Init our model
model = ViTfinetune(num_classes=4)
tb_logger = TensorBoardLogger("logs/")

# Initialize a trainer
trainer = pl.Trainer(max_epochs=50,
                     progress_bar_refresh_rate=20, 
                     gpus=1, 
                     logger=tb_logger,
                     callbacks=[early_stop_callback, checkpoint_callback])

# Train the model ⚡🚅⚡
trainer.fit(model, dm)

# Evaluate the model on the held-out test set ⚡⚡
trainer.test()


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type       | Params
---------------------------------------------
0 | vit_back_bone | ViTModel   | 86.4 M
1 | fc            | Sequential | 3.1 K 
---------------------------------------------
86.4 M    Trainable params
0         Non-trainable params
86.4 M    Total params
345.569   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [54]:
def load_model(checkpoint_file_path=None):
    model = ViTfinetune.load_from_checkpoint(checkpoint_file_path)
    #model.load_from_checkpoint(checkpoint_file_path)
    return model
    pass

In [55]:
model = load_model("~/ViT_clf/model/model-epoch=47-val_loss=0.00-v2.ckpt")

In [113]:
def inference(model, image: np.array, transformer_feature_extractor, class_list):
    result = []
    model_input = transformer_feature_extractor(image, return_tensors='pt')['pixel_values']
    log_prob = model(model_input)
    prob=torch.exp(op)
    for each_img_in_batch in range(prob.shape[0]):
        tmp_prob = prob[each_img_in_batch].tolist()
        tmp_out_dict = dict(zip(class_list, tmp_prob))
        result.append(tmp_out_dict)
    return result
        
        
    
    
    pass
    

## Inference 

In [103]:
image_path = "./raw_data/images/airplane_00001.jpg"
img = cv2.imread(image_path)
ip = [img, img, img]
feature_extractor  = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

In [114]:
inference(model, ip, feature_extractor, ['a', 'b', 'c', 'd'])

[{'a': 0.9979692101478577,
  'b': 0.0009292723843827844,
  'c': 0.0006120629259385169,
  'd': 0.0004893728182651103},
 {'a': 0.9981659054756165,
  'b': 0.000683989783283323,
  'c': 0.0008381748339161277,
  'd': 0.0003120519104413688},
 {'a': 0.999303936958313,
  'b': 0.00018437736434862018,
  'c': 0.0003101338807027787,
  'd': 0.0002016006037592888}]