In [1]:
!wandb login f90bc50f940247a6f1e5b7574e124c960c4cece9

'wandb' is not recognized as an internal or external command,
operable program or batch file.


# Boilerplate

## Import Statements

In [2]:
import wandb
import numpy as np
import pandas as pd
import os
import torch
from PIL import Image
import matplotlib.pyplot as plt

#torch related imports
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

# lightning related imports
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# sklearn related imports
from sklearn.metrics import precision_recall_curve
from sklearn.preprocessing import label_binarize

## Config Class

In [3]:
class config:
    NUM_EPOCHS = 10
    LR = 0.00001
    TRAIN_CNN = False
    BATCH_SIZE = 32
    PIN_MEMORY = True
    TRAIN_BATCH_SIZE = 32
    VALID_BATCH_SIZE = 8

# Data Preparation

## Preprocessing Function

In [4]:
classes = {'0': 'letter',
 '1': 'form',
 '10': 'budget',
 '11': 'invoice',
 '12': 'presentation',
 '13': 'questionnaire',
 '14': 'resume',
 '15': 'memo',
 '2': 'email',
 '3': 'handwritten',
 '4': 'advertisement',
 '5': 'scientific report',
 '6': 'scientific publication',
 '7': 'specification',
 '8': 'file folder',
 '9': 'news article'}

In [5]:
def get_df(fname, options=None):
    df = pd.read_csv(fname, sep=" ", header=None)
    df.columns = ["img_name", "label"]
    df["img_name"] = 'images/' + df["img_name"]
    if options:
        df = df[df["label"].isin(options)]
    
    return df



## Dataset Class

In [6]:
class DocDataset(Dataset):
    def __init__(self, annotation_df, transform=None):
        self.annotations = annotation_df
        self.transform = transform  

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_id = self.annotations.iloc[index, 0]
        img = Image.open(img_id).convert("RGB")
        y_label = torch.tensor(self.annotations.iloc[index, 1])

        if self.transform is not None:
            img = self.transform(img)

        return (img, y_label)

## Dataloader Class

In [7]:
class DocDataModule(pl.LightningDataModule):

    def __init__(self, train_df, val_df, test_df, train_batch, val_batch, pin_memory, options=None, transform=None):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.test_df = test_df
        self.train_batch=train_batch
        self.val_batch=val_batch
        self.pin_memory = pin_memory
        self.transform = transform
        if options:
            self.num_classes = len(options)
        else:
            self.num_classes = 16
    
    def setup(self):
        self.train_dataset = DocDataset(self.train_df, self.transform)
        self.val_dataset = DocDataset(self.val_df, self.transform)
        self.test_dataset = DocDataset(self.test_df, self.transform)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch,
            shuffle=True,
            num_workers=8,
            pin_memory = self.pin_memory
        )
    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch,
            num_workers=8,
            pin_memory = self.pin_memory
        )
    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.val_batch,
            num_workers=8,
            pin_memory = self.pin_memory
        )

## DataPrep Statements

In [8]:
# options = [0,1,2,3,4,5,11,13]
options = None
train_df = get_df(r'labels/train.txt', options)
val_df = get_df(r'labels/val.txt', options)
test_df = get_df(r'labels/test.txt', options)

transform = transform = transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

data_module = DocDataModule(train_df, val_df, test_df, config.TRAIN_BATCH_SIZE, config.VALID_BATCH_SIZE, config.PIN_MEMORY, options, transform)
data_module.setup()


# Modelling

## Modelling Class

In [9]:
class DocModel(pl.LightningModule):

    def __init__(self, num_classes, learning_rate, train_CNN=False):
        super().__init__()

        self.num_classes = num_classes
        self.train_CNN = train_CNN
        self.learning_rate = learning_rate

        self.inception = models.inception_v3(pretrained=True, aux_logits=False)
        self.inception.fc = nn.Linear(self.inception.fc.in_features, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.softmax = nn.LogSoftmax(dim=1)
        self.loss = nn.NLLLoss()

    def forward(self, images):
        features = self.inception(images)
        logits = self.softmax(self.dropout(self.relu(features))).squeeze(1)
        return logits

    def training_step(self, batch, batch_idx):
        x, y = batch

        logits = self(x)
        loss = self.loss(logits, y)
        
        # training metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)

        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        for name, param in self.named_parameters():
            if "fc.weight" in name or "fc.bias" in name:
                param.requires_grad = True
            else:
                param.requires_grad = self.train_CNN
        return optimizer

## Callback class

In [10]:
early_stop_callback = EarlyStopping(
   monitor='val_loss',
   patience=3,
   verbose=False,
   mode='min'
)

class ImagePredictionLogger(Callback):
    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples
        
    def on_validation_epoch_end(self, trainer, pl_module):
        # Bring the tensors to CPU
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_labels = self.val_labels.to(device=pl_module.device)
        # Get model prediction
        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)
        # Log the images as wandb Image
        trainer.logger.experiment.log({
            "examples":[wandb.Image(x, caption=f"Pred:{pred}, Label:{y}") 
                           for x, pred, y in zip(val_imgs[:self.num_samples], 
                                                 preds[:self.num_samples], 
                                                 val_labels[:self.num_samples])]
            })

In [11]:
MODEL_CKPT_PATH = 'model/'
MODEL_CKPT = 'model/model-{epoch:02d}-{val_loss:.2f}'

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    filepath=MODEL_CKPT ,
    save_top_k=3,
    mode='min')

## Training Statements

In [12]:
val_samples = next(iter(data_module.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape


In [None]:
# Init our model
model = DocModel(data_module.num_classes, config.LR)

# Initialize wandb logger
wandb_logger = WandbLogger(project='DocClassifier', job_type='train')

# Initialize a trainer
trainer = pl.Trainer(max_epochs=config.NUM_EPOCHS,
                     progress_bar_refresh_rate=20, 
                     logger=wandb_logger,
                     callbacks=[early_stop_callback,
                                ImagePredictionLogger(val_samples)],
                     checkpoint_callback=checkpoint_callback)

# Train the model ⚡🚅⚡
trainer.fit(model, data_module)

# Evaluate the model on the held out test set ⚡⚡
trainer.test()

# Prediction

## Prediction Statements