Install Required Packages

In [None]:
!pip install -q transformers pytorch-lightning

In [None]:
import torch
from torch import nn
import torchvision
import numpy as np
import timm
import os
import matplotlib.pyplot as plt
from datetime import datetime
import pandas as pd
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import operator
from torch.utils.data.sampler import SubsetRandomSampler
from PIL import Image
import torchvision.transforms as transforms 
import seaborn as sns

from sklearn.metrics import accuracy_score,confusion_matrix
from sklearn.metrics import classification_report
import matplotlib.ticker as ticker
import itertools

In [None]:
from transformers import ViTFeatureExtractor

In [None]:
# Locations of the labelled (train) and unlabelled (test) image folders,
# both living under the same Kaggle input directory.
_DATA_ROOT = '/kaggle/input/plant-seedlings-classification'
train_dir = f'{_DATA_ROOT}/train'
test_dir = f'{_DATA_ROOT}/test'

In [None]:
# Species name -> integer label. The position of each name in the tuple is
# its label index.
# NOTE(review): presumably this mirrors the alphabetical class-folder order
# that ImageFolder assigns -- confirm against the dataset's directory names.
class_map = {
    species: label
    for label, species in enumerate((
        'Black-grass',
        'Charlock',
        'Cleavers',
        'Common Chickweed',
        'Common wheat',
        'Fat Hen',
        'Loose Silky-bent',
        'Maize',
        'Scentless Mayweed',
        'Shepherds Purse',
        'Small-flowered Cranesbill',
        'Sugar beet',
    ))
}

# Integer label -> species name (exact inverse of class_map).
id_to_class = {label: species for species, label in class_map.items()}

In [None]:
# Training hyper-parameters: images per mini-batch and the epoch count.
batch_size, epochs = 16, 50
# Checkpoint location; left empty and not referenced by the cells visible here.
CHECKPOINT_PATH = ""

In [None]:
# Load the pre-processing configuration that ships with the pretrained ViT
# checkpoint; its image_mean, image_std and size are read by the transform
# pipelines built from it.
_VIT_CHECKPOINT = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(_VIT_CHECKPOINT)

In [None]:
def _resolve_image_size(size):
    """Convert the feature extractor's ``size`` into a torchvision size.

    Depending on the installed transformers release, ``size`` is either a
    plain int or a dict such as ``{"height": 224, "width": 224}`` or
    ``{"shortest_edge": 224}``. torchvision's crop/resize transforms only
    accept an int or an ``(h, w)`` pair, so dicts are unpacked here.

    Args:
        size: The ``feature_extractor.size`` value.

    Returns:
        An int or an ``(height, width)`` tuple; non-dict values are returned
        unchanged.
    """
    if isinstance(size, dict):
        if "shortest_edge" in size:
            return size["shortest_edge"]
        return (size["height"], size["width"])
    return size


# Target resolution expected by the pretrained checkpoint.
image_size = _resolve_image_size(feature_extractor.size)

# Normalise with the statistics stored alongside the pretrained checkpoint.
normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

# Training pipeline: random crop + horizontal flip as augmentation.
train_transforms = transforms.Compose(
        [
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )

# Evaluation pipeline: deterministic resize + centre crop, no augmentation.
val_transforms = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            normalize,
        ]
    )

In [None]:
# Dataset
# All labelled images under train_dir, decoded through the training pipeline.
train_dataset = ImageFolder(train_dir, transform=train_transforms)
valid_size = 0.10  # fraction of the labelled images held out for validation

# Train-Valid split: shuffle every sample position in place, then give the
# first `split` positions to validation and the remainder to training.
num_train = len(train_dataset)
indices = [*range(num_train)]
np.random.shuffle(indices)
split = int(np.floor(num_train * valid_size))
valid_idx, train_idx = indices[:split], indices[split:]

# Samplers that restrict a DataLoader to one side of the split.
valid_sampler = SubsetRandomSampler(valid_idx)
train_sampler = SubsetRandomSampler(train_idx)

In [None]:
# Training batches: augmented images drawn from the training side of the split.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)

# Validation gets its own ImageFolder over the same directory so it can use
# the deterministic val_transforms. Assigning `.transforms` on the shared
# train_dataset (as was done before) had no effect -- ImageFolder applies its
# `transform` attribute when loading a sample -- and any change made on the
# shared object would have altered the training pipeline as well.
val_dataset = ImageFolder(train_dir, transform=val_transforms)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, sampler=valid_sampler)

# No separate labelled test set is built here: testing reuses the validation loader.
test_dataloader = val_dataloader

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("Device:", device)

In [None]:
# Sanity check: pull a single mini-batch from the training loader and print
# the shape of its first element.
_train_iter = iter(train_dataloader)
batch = next(_train_iter)
print(batch[0].shape)

In [None]:
import pytorch_lightning as pl
from transformers import ViTForImageClassification
# transformers' own AdamW is deprecated in favour of the PyTorch optimizer, so
# the name is bound to torch.optim.AdamW instead.
from torch.optim import AdamW
import torch.nn as nn

class ViTLightningModule(pl.LightningModule):
    """Lightning module that fine-tunes a pretrained ViT image classifier.

    Wraps ``ViTForImageClassification`` loaded from
    ``google/vit-base-patch16-224-in21k`` and labels its outputs with the
    module-level ``id_to_class`` / ``class_map`` mappings. Batches are expected
    to be ``(pixel_values, labels)`` pairs, as unpacked in ``common_step``.

    Args:
        num_labels: Size of the classification head. Defaults to the number of
            entries in ``id_to_class``; since those mappings are always passed
            to the model, other values are only meaningful if the mappings are
            changed to match.
    """

    def __init__(self, num_labels=len(id_to_class)):
        super().__init__()
        # ignore_mismatched_sizes lets the checkpoint load even though the
        # classification head is re-created with `num_labels` outputs.
        self.vit = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
                                                             ignore_mismatched_sizes=True,
                                                             num_labels=num_labels,
                                                             id2label=id_to_class,
                                                             label2id=class_map)

    def forward(self, pixel_values):
        """Return the raw class logits for a batch of image tensors."""
        outputs = self.vit(pixel_values=pixel_values)
        return outputs.logits

    def common_step(self, batch, batch_idx):
        """Compute the loss and accuracy shared by train/val/test steps.

        Returns:
            A ``(loss, accuracy)`` tuple: the cross-entropy loss tensor and the
            fraction of correct predictions in the batch as a Python float.
        """
        pixel_values, labels = batch
        logits = self(pixel_values)

        loss = nn.functional.cross_entropy(logits, labels)
        predictions = logits.argmax(-1)
        correct = (predictions == labels).sum().item()
        accuracy = correct / pixel_values.shape[0]

        return loss, accuracy

    def training_step(self, batch, batch_idx):
        """Run one optimisation step and log its loss and accuracy."""
        loss, accuracy = self.common_step(batch, batch_idx)
        # Logged with Lightning's defaults for training_step, i.e. one value
        # per step (no on_epoch aggregation is requested here).
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)

        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and log epoch-aggregated metrics."""
        loss, accuracy = self.common_step(batch, batch_idx)
        # "validation_loss" is the key early stopping has to monitor.
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch and log its metrics."""
        loss, accuracy = self.common_step(batch, batch_idx)
        # Without these calls the computed test metrics were discarded.
        self.log("test_loss", loss)
        self.log("test_accuracy", accuracy)

        return loss

    def configure_optimizers(self):
        """Return the AdamW optimizer over all parameters (lr=5e-5)."""
        # We could make the optimizer more fancy by adding a scheduler and specifying which parameters do
        # not require weight_decay but just using AdamW out-of-the-box works fine.
        # eps and weight_decay are pinned to the defaults of the deprecated
        # transformers implementation to keep the optimisation close to it.
        return AdamW(self.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

    def train_dataloader(self):
        """Return the module-level training DataLoader."""
        return train_dataloader

    def val_dataloader(self):
        """Return the module-level validation DataLoader."""
        return val_dataloader

    def test_dataloader(self):
        """Return the module-level test DataLoader."""
        return test_dataloader

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# for early stopping, see https://pytorch-lightning.readthedocs.io/en/1.0.0/early_stopping.html?highlight=early%20stopping
# Stop training once the validation loss has not improved for `patience`
# validation rounds. The monitored key must match the name logged in
# ViTLightningModule.validation_step ("validation_loss"); the previous
# 'val_loss' key was never logged, and with strict=False that mismatch would
# have been silently ignored. strict=True makes such a mismatch fail loudly.
early_stop_callback = EarlyStopping(
    monitor='validation_loss',
    patience=20,
    strict=True,
    verbose=True,
    mode='min'
)

model = ViTLightningModule()
# Use the configured callback (it was previously defined but never passed to
# the Trainer), cap training at the configured `epochs`, and run on a single
# GPU when one is available, falling back to CPU otherwise.
trainer = Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    max_epochs=epochs,
    callbacks=[early_stop_callback],
)
trainer.fit(model)

In [None]:
def create_directory_dataFrame(basedir):
    """List the entries of ``basedir`` as a one-column DataFrame.

    Args:
        basedir: Directory whose entries should be listed.

    Returns:
        A DataFrame with a single ``'Location'`` column holding
        ``basedir + '/' + entry`` for every entry, in sorted order so the
        result is deterministic. An empty directory yields an empty frame
        that still has the ``'Location'`` column.
    """
    # Build the whole column at once: row-by-row DataFrame.append copies the
    # frame on every call and no longer exists in pandas 2.x.
    locations = [basedir + '/' + entry for entry in sorted(os.listdir(basedir + '/'))]
    return pd.DataFrame({'Location': locations})

In [None]:
def pred_class(img):
    """Predict the class index for a single PIL image.

    Uses the module-level ``model`` and ``val_transforms``. The caller is
    expected to have put ``model`` into eval mode beforehand.

    Args:
        img: A PIL image.

    Returns:
        The index of the highest-scoring class as a plain ``int``.
    """
    # Force three channels: the normalisation in val_transforms is built from
    # the checkpoint's per-channel mean/std, so non-RGB files (e.g. RGBA or
    # greyscale) would otherwise fail.
    img_tens = val_transforms(img.convert("RGB"))
    # Send the input to wherever the model lives instead of assuming CUDA.
    target_device = next(model.parameters()).device
    pixel_values = img_tens.unsqueeze(0).to(target_device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        out = model(pixel_values)
    # Single-image batch, so the argmax over the class axis is the prediction.
    return int(out.argmax(dim=-1).item())

In [None]:
# Switch to inference mode and place the model on the selected device
# (GPU when available, otherwise CPU) rather than assuming CUDA exists.
model.eval()
model.to(device)

test_df = create_directory_dataFrame(test_dir)

# Collect one row per test image and build the frame once at the end;
# DataFrame.append copies the frame per call and no longer exists in pandas 2.x.
rows = []
for image_path in test_df['Location']:
    # The context manager closes the underlying file once it has been classified.
    with Image.open(image_path) as img:
        index = pred_class(img)
    rows.append({
        'file': image_path.split('/')[-1],  # bare file name, without the directory
        'species': id_to_class[index],
    })

submission = pd.DataFrame(rows, columns=['file', 'species'])
submission.to_csv('vit-base-patch16-384_submission.csv', index=False)

In [None]:
# Write the model weights to disk. Note that 'epoch' records the configured
# `epochs` value, not a counter read back from the trainer.
checkpoint = dict(
    epoch=epochs,
    model_state_dict=model.state_dict(),
)
torch.save(checkpoint, 'vit-base-patch16-224-in21k.pth')
