In [3]:
# from transformers import ViTFeatureExtractor, ViTModel, ViTForImageClassification, ViTConfig
# from PIL import Image
# import requests

# url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# image = Image.open(requests.get(url, stream=True).raw)

# feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
# model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

# inputs = feature_extractor(images=image, return_tensors="pt")
# outputs = model(**inputs)
# last_hidden_states = outputs.last_hidden_state

In [23]:
#!pip install jupyter-lab
#!pip install numpy==1.19.5
#!pip install pandas==1.1.3
#!pip install opencv-python==4.2.0.34
#!pip install torch==1.8.1
#!pip install pytorch-lightning==1.4.6
#!pip install transformers==4.10.2
#!pip install scikit-learn==0.24.2
#!pip install Pillow==8.3.2

Collecting Pillow
[?25l  Downloading https://files.pythonhosted.org/packages/0e/8f/b435e010927ab2e8e7708464e5f47f233f10d8d71d73a3d5c7c456346a4f/Pillow-8.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl (3.0MB)
[K     |████████████████████████████████| 3.0MB 2.7MB/s eta 0:00:01
[?25hInstalling collected packages: Pillow
Successfully installed Pillow-8.3.2
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [24]:
from transformers import ViTFeatureExtractor, ViTModel, ViTForImageClassification, ViTConfig
from PIL import Image
import requests

In [20]:
import os
import re
import numpy as np
import pandas as pd
import cv2

# pytorch related imports
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
#from torchvision.datasets import CIFAR10
#from torchvision import transforms

# lightning related imports
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# sklearn related imports
from sklearn.metrics import precision_recall_curve
from sklearn.preprocessing import label_binarize

# Creating Dataset class

In [3]:
class firesmoke_image_dataset(torch.utils.data.Dataset):
    def __init__(self, csv_file_path, image_dir, transform=None):
        self.data_frame = pd.read_csv(csv_file_path)
        self.image_dir = image_dir
        self.transform = transform
        self.feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
        
    def __len__(self):
        return len(self.data_frame)
        pass
    
    def __getitem__(self, index):
        
        row = self.data_frame.iloc[index]
        image_path = self.image_dir+"/"+row['image_name'].strip()
        
        img = cv2.imread(image_path)
       
        #img = np.einsum("ijk->kij", img)
        img = self.feature_extractor(img)['pixel_values'][0] # Don't forget this to apply while inference
        #img = image = Image.open(image_path)
        label = row[1]#.to_numpy(dtype=np.float32)
        if self.transform:
            img = transform(img)
            #img = feature_extractor(img)['pixel_values'][0]
            
        return [img, label]
    

## test

In [5]:
dataset = firesmoke_image_dataset("/home/anish/multilabel_classification/sample_data/image_label.csv",
                                  "/home/anish/multilabel_classification/sample_data/images/" )
dl = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

b = None
for batch in dl:
    b = batch
    break

#op = model(pixel_values=b[0])

In [6]:
op.pooler_output.shape

torch.Size([2, 768])

## creating pl datamodule

In [6]:
class FireSmokeDataModule(pl.LightningDataModule):
    def __init__(self, csv_file_path, data_dir: str = './', batch_size=5):
        super().__init__()
        self.csv_file_path = csv_file_path
        self.data_dir = data_dir
        self.batch_size = batch_size

#         self.transform = transforms.Compose([
#             transforms.ToTensor(),
#             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
#         ])
        
#         self.dims = (3, 32, 32)
#         self.num_classes = 10

    def prepare_data(self):
        # download 
        pass

    def setup(self, stage=None):
        self.firesmoke_ds = firesmoke_image_dataset(self.csv_file_path, self.data_dir)
            

    def train_dataloader(self):
        return DataLoader(self.firesmoke_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.firesmoke_ds, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.firesmoke_ds, batch_size=self.batch_size)

# Creating pl lightningModule for vision transformer

In [15]:
class ViTfinetune(pl.LightningModule):
    def __init__(self, num_classes, learning_rate=2e-4):
        super().__init__()
        pooler_dims = 768
        self.vit_back_bone = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.fc  = nn.Sequential(
                              nn.Dropout(0.5),
                              nn.Linear(768, num_classes)
                             )
        self.learning_rate = learning_rate
    # will be used during inference
    def forward(self, x):
        
        x = self.vit_back_bone(pixel_values=x)
        x = x.pooler_output
        x = self.fc(x)
        x = F.log_softmax(x, dim=1)
        return x

    # logic for a single training step
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # training metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        
        return loss

    # logic for a single validation step
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    # logic for a single testing step
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [8]:
ft = ViTfinetune(num_classes=5)
lsm = ft(b[0])

In [9]:
torch.argmax(lsm , dim=1)

tensor([3, 0])

In [10]:
lsm

tensor([[-2.1577, -1.5392, -1.9584, -1.0667, -1.6895],
        [-1.4025, -1.6001, -1.7671, -1.9649, -1.4224]],
       grad_fn=<LogSoftmaxBackward>)

# Creating callbacks

In [11]:
early_stop_callback = EarlyStopping(
   monitor='val_loss',
   patience=3,
   verbose=False,
   mode='min'
)

MODEL_CKPT_PATH = 'model/'
MODEL_CKPT = 'model-{epoch:02d}-{val_loss:.2f}'

checkpoint_callback = ModelCheckpoint(
    dirpath=MODEL_CKPT_PATH,
    monitor='val_loss',
    filename=MODEL_CKPT ,
    save_top_k=3,
    mode='min')

# Lightning Trainer 

In [12]:
CSV_FILE="/home/anish/multilabel_classification/sample_data/image_label.csv"
IMAGE_DIR = "/home/anish/multilabel_classification/sample_data/images/"
# Init our data pipeline
dm = FireSmokeDataModule(csv_file_path=CSV_FILE, data_dir=IMAGE_DIR, batch_size=2)
# To access the x_dataloader we need to call prepare_data and setup.
dm.prepare_data()
dm.setup()

# Samples required by the custom ImagePredictionLogger callback to log image predictions.
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape





(torch.Size([2, 3, 224, 224]), torch.Size([2]))

In [16]:
# Init our model
model = ViTfinetune(num_classes=4)
tb_logger = TensorBoardLogger("logs/")

# Initialize a trainer
trainer = pl.Trainer(max_epochs=50,
                     progress_bar_refresh_rate=20, 
                     gpus=0, 
                     logger=tb_logger,
                     callbacks=[early_stop_callback, checkpoint_callback])

# Train the model ⚡🚅⚡
trainer.fit(model, dm)

# Evaluate the model on the held-out test set ⚡⚡
trainer.test()


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_deprecation(

  | Name          | Type       | Params
---------------------------------------------
0 | vit_back_bone | ViTModel   | 86.4 M
1 | fc            | Sequential | 3.1 K 
---------------------------------------------
86.4 M    Trainable params
0         Non-trainable params
86.4 M    Total params
345.569   Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

  rank_zero_warn(
  stream(template_mgs % msg_args)
  rank_zero_warn(
  rank_zero_warn(


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
  rank_zero_warn(


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 1.0, 'test_loss': 0.0026426888071000576}
--------------------------------------------------------------------------------



[{'test_loss': 0.0026426888071000576, 'test_acc': 1.0}]

In [20]:
lst = [m for m in model.modules()]

In [45]:
full_model = nn.Sequential(*lst[2:])

In [2]:
#full_model(b[0])

In [49]:
b[0].shape

torch.Size([2, 3, 224, 224])