# Plant Classification
> Group: YoHRa  
> Members: ZHENG Yannan, DAI Changjun, HUANG Yueqiao

![avatar](https://pbs.twimg.com/profile_images/960531652119285760/vtcdJZw5_400x400.jpg)

**What we used:**

In [None]:
# List the GPU(s) visible to this session (IPython shell escape).
!nvidia-smi -L

In [None]:
# Install/upgrade the package that provides `efficientnet_pytorch.EfficientNet`.
pip install --upgrade efficientnet-pytorch

In [None]:
# Install/upgrade NVIDIA DALI (CUDA 11.0 build) from NVIDIA's extra package index.
pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda110

In [None]:
import os
import re
import csv

import torch
import torch.nn as nn
import torchmetrics
import torchvision
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from efficientnet_pytorch import EfficientNet
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy

In [None]:
# Run on the first CUDA GPU when one is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

**Global variables:**

In [None]:
# --- Model / training settings -------------------------------------------
NUM_CLASSES = 153             # presumably the number of plant classes -- confirm against the dataset
BATCH_SIZE = 40               # samples per batch for both DALI pipelines
NUM_THREADS = 8               # CPU worker threads handed to each DALI pipeline
NUM_EPOCHS = 14               # upper bound on training epochs (Trainer max_epochs)
VAL_SIZE = 0.1                # fraction of the training images held out for validation
NET_NAME = 'efficientnet-b3'  # model name passed to EfficientNet.from_pretrained

# --- Dataset locations ---------------------------------------------------
_DATA_ROOT = '../input/polytech-nice-data-science-course-2021/polytech'
TRAIN_DIR = _DATA_ROOT + '/train'
TEST_DIR = _DATA_ROOT + '/test'

# Training set and validation set

In [None]:
# Collect every training image as (path, zero-based label) and remember the
# labels separately so the split below can be stratified on them.
IMAGES = []
IMAGES_TARGETS = []

for dirname, _, filenames in os.walk(TRAIN_DIR):
    # The label is the numeric name of the first directory below TRAIN_DIR.
    # Reading it from the path relative to TRAIN_DIR (instead of regex-searching
    # the full path) avoids two failure modes: picking up an unrelated ancestor
    # directory that starts with a digit, and crashing on a directory that has
    # no numeric component at all (e.g. TRAIN_DIR itself).
    label_name = os.path.relpath(dirname, TRAIN_DIR).split(os.sep)[0]
    if not label_name.isdecimal():
        continue
    target = int(label_name) - 1  # directory names are 1-based, targets 0-based
    for filename in filenames:
        if filename.endswith('.jpg'):
            IMAGES.append((os.path.join(dirname, filename), target))
            IMAGES_TARGETS.append(target)

# Stratified hold-out split: each class keeps (approximately) the same share
# of samples in the training and validation subsets.
TRAIN_IMAGES, VAL_IMAGES = train_test_split(IMAGES, test_size=VAL_SIZE,
                                            shuffle=True,
                                            stratify=IMAGES_TARGETS)

# Unzip the (path, label) pairs into the parallel sequences the DALI file
# reader takes as `files` and `labels`.
train_dirs, train_labels = zip(*TRAIN_IMAGES)
val_dirs, val_labels = zip(*VAL_IMAGES)

# DALI from NVIDIA

In [None]:
class MyTrainPipe(Pipeline):
    """DALI pipeline that yields augmented training batches.

    Reads the given files in shuffled order, decodes them to RGB, takes a
    random resized crop of 300x300, randomly mirrors the result and
    normalizes it per channel into a float NCHW tensor.

    Args:
        files: Image file paths.
        labels: Integer labels, parallel to ``files``.
        batch_size: Samples per batch.
        num_threads: CPU threads used by the pipeline.
        device_id: GPU ordinal the pipeline runs on.
    """

    def __init__(self, files, labels, batch_size, num_threads, device_id):
        super().__init__(batch_size, num_threads, device_id, seed=12)

        # Per-channel statistics expressed on the 0-255 pixel scale.
        channel_mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
        channel_std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

        self.input = ops.readers.File(files=files, labels=labels, random_shuffle=True)
        # "mixed" decoding: encoded input on the host, decoded output on the GPU.
        self.decode = ops.decoders.Image(device="mixed", output_type=types.RGB)
        # NOTE(review): the upper bound of random_area is 1.25, i.e. above 1.0;
        # confirm this is intended.
        self.res = ops.RandomResizedCrop(
            device="gpu",
            size=(300, 300),
            random_area=[0.08, 1.25],
        )
        self.cmn = ops.CropMirrorNormalize(
            device="gpu",
            dtype=types.FLOAT,
            output_layout=types.NCHW,
            mean=channel_mean,
            std=channel_std,
        )
        # Per-sample 0/1 draw used as the horizontal-mirror flag.
        self.coin = ops.random.CoinFlip(probability=0.5)

    def define_graph(self):
        """Wire the operators together and return ``(images, labels)``."""
        self.jpegs, self.labels = self.input(name="Reader")
        decoded = self.decode(self.jpegs)
        cropped = self.res(decoded)
        mirror_flag = self.coin()
        normalized = self.cmn(cropped, mirror=mirror_flag)
        return normalized, self.labels


class MyValPipe(Pipeline):
    """DALI pipeline that yields validation batches without augmentation.

    Reads the given files, decodes them to RGB, resizes the shorter side to
    324 pixels, then crops to 300x300 and normalizes per channel into a float
    NCHW tensor.

    Args:
        files: Image file paths.
        labels: Integer labels, parallel to ``files``.
        batch_size: Samples per batch.
        num_threads: CPU threads used by the pipeline.
        device_id: GPU ordinal the pipeline runs on.
    """

    def __init__(self, files, labels, batch_size, num_threads, device_id):
        super().__init__(batch_size, num_threads, device_id, seed=12)

        # Per-channel statistics expressed on the 0-255 pixel scale.
        channel_mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
        channel_std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

        self.input = ops.readers.File(files=files, labels=labels)
        # "mixed" decoding: encoded input on the host, decoded output on the GPU.
        self.decode = ops.decoders.Image(device="mixed", output_type=types.RGB)
        self.res = ops.Resize(
            device="gpu",
            resize_shorter=324,
            interp_type=types.INTERP_TRIANGULAR,
        )
        self.cmn = ops.CropMirrorNormalize(
            device="gpu",
            dtype=types.FLOAT,
            output_layout=types.NCHW,
            crop=(300, 300),
            mean=channel_mean,
            std=channel_std,
        )

    def define_graph(self):
        """Wire the operators together and return ``(images, labels)``."""
        self.jpegs, self.labels = self.input(name="Reader")
        decoded = self.decode(self.jpegs)
        resized = self.res(decoded)
        normalized = self.cmn(resized)
        return normalized, self.labels


def _make_net_iterator(type, files, labels, batch_size, num_threads, device_id=0):
    """Build a DALI pipeline for one data split and wrap it for PyTorch.

    Args:
        type: ``'train'`` selects ``MyTrainPipe``; ``'val'`` selects
            ``MyValPipe``. (The name shadows the builtin but is kept because
            callers pass it by keyword.)
        files: Image file paths.
        labels: Integer labels, parallel to ``files``.
        batch_size: Samples per batch.
        num_threads: CPU threads used by the pipeline.
        device_id: GPU ordinal the pipeline runs on.

    Returns:
        A ``DALIClassificationIterator`` over the built pipeline. The last
        batch of an epoch may be smaller than ``batch_size``
        (``LastBatchPolicy.PARTIAL``) and the iterator resets itself after
        each epoch.

    Raises:
        ValueError: If ``type`` is neither ``'train'`` nor ``'val'``.
    """
    if type == 'train':
        pipe_cls = MyTrainPipe
    elif type == 'val':
        pipe_cls = MyValPipe
    else:
        # Fail loudly here instead of returning None and crashing later
        # inside the trainer with a less helpful message.
        raise ValueError(
            f"unknown iterator type {type!r}; expected 'train' or 'val'"
        )

    ppl = pipe_cls(files=files, labels=labels, batch_size=batch_size,
                   num_threads=num_threads, device_id=device_id)
    ppl.build()
    # reader_name lets the iterator query the epoch size from the reader
    # that the pipelines register under the name "Reader".
    return DALIClassificationIterator(ppl,
                                      reader_name="Reader",
                                      last_batch_policy=LastBatchPolicy.PARTIAL,
                                      auto_reset=True)

# PyTorch Lightning

In [None]:
class MyNetwork(pl.LightningModule):
    """Pretrained EfficientNet fine-tuned with cross-entropy loss and SGD.

    Batches are expected in the layout produced by
    ``DALIClassificationIterator``: a list whose first element is a dict with
    ``"data"`` and ``"label"`` entries.
    """

    def __init__(self):
        super().__init__()
        # Size the classification head for this dataset. Without num_classes
        # the network keeps its default 1000-way head and could predict class
        # indices that do not exist in the data.
        self.net = EfficientNet.from_pretrained(NET_NAME, num_classes=NUM_CLASSES)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy()
        self.lr = 1e-3

    @staticmethod
    def _unpack_batch(batch):
        """Return ``(images, targets)`` tensors from a DALI iterator batch."""
        x = batch[0]["data"].float()
        # Flatten the labels to 1-D (DALI is expected to deliver them as
        # (batch, 1) -- confirm). A bare .squeeze() would also drop the batch
        # axis when a partial last batch holds a single sample, which breaks
        # both the forward pass and the loss.
        y = batch[0]["label"].reshape(-1).long()
        return x, y

    def forward(self, x):
        """Return the class logits for a batch of images."""
        return self.net(x)

    def training_step(self, batch, batch_idx):
        """Log ``train_acc`` / ``train_loss`` and return the loss to optimize."""
        x, y = self._unpack_batch(batch)
        y_hat = self.net(x)
        acc = self.accuracy(y_hat, y)
        self.log('train_acc', acc, prog_bar=True)
        loss = self.criterion(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_acc`` and ``val_loss`` for one validation batch."""
        x, y = self._unpack_batch(batch)
        y_hat = self.net(x)
        acc = self.accuracy(y_hat, y)
        self.log('val_acc', acc, prog_bar=True, logger=True)
        loss = self.criterion(y_hat, y)
        self.log('val_loss', loss, prog_bar=True, logger=True)

    def configure_optimizers(self):
        """Return SGD with momentum 0.9 over the network parameters."""
        return torch.optim.SGD(self.net.parameters(), lr=self.lr, momentum=.9)


# O Captain! My Captain! Our wonderful trip begins!

In [None]:
import shutil

# Build the DALI-backed iterators for the training and validation splits.
train_dl = _make_net_iterator(type='train',
                              files=train_dirs,
                              labels=train_labels,
                              batch_size=BATCH_SIZE,
                              num_threads=NUM_THREADS)

val_dl = _make_net_iterator(type='val',
                            files=val_dirs,
                            labels=val_labels,
                            batch_size=BATCH_SIZE,
                            num_threads=NUM_THREADS)

net = MyNetwork()

# Keep the checkpoint with the highest validation accuracy seen so far.
net_checkpoint = ModelCheckpoint(monitor="val_acc",
                                 mode='max',
                                 verbose=True)

trainer = pl.Trainer(gpus=1,
                     max_epochs=NUM_EPOCHS,
                     accelerator='dp',
                     callbacks=[net_checkpoint],
                     checkpoint_callback=True)
# NOTE(review): no tuner options (e.g. auto_lr_find) are enabled on the
# Trainer, so this call probably does nothing -- confirm before removing it.
trainer.tune(net, train_dl, val_dl)
trainer.fit(net, train_dl, val_dl)

# Export the *best* checkpoint under a stable name. trainer.save_checkpoint()
# would store the weights from the final epoch, which are not necessarily the
# ones with the best validation accuracy.
best_ckpt_path = net_checkpoint.best_model_path
if best_ckpt_path:
    shutil.copyfile(best_ckpt_path, "best_net.ckpt")
else:
    # No checkpoint was recorded by the callback; fall back to the current
    # trainer state so the file is still produced.
    trainer.save_checkpoint("best_net.ckpt")

Finally, classify the test set.

In [None]:
# Test-time preprocessing: resize the shorter side to 324, centre-crop to
# 300x300, convert to a tensor and normalize per channel.
test_transform = transforms.Compose([transforms.Resize(324),
                                     transforms.CenterCrop(300),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                    ])

# Reload the weights of the best checkpoint tracked during training.
net = MyNetwork.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
net.eval()
net.to(device)

headers = ['image_name', 'class']
results = []

# Inference only: disabling autograd avoids building a graph for every image,
# which saves memory and time.
with torch.no_grad():
    # Sorted so that the rows come out in the same order on every run;
    # os.listdir() returns entries in arbitrary order.
    for filename in tqdm(sorted(os.listdir(TEST_DIR))):
        if not filename.endswith(".jpg"):
            continue
        # The context manager closes the underlying file once the pixels
        # have been read and transformed.
        with Image.open(os.path.join(TEST_DIR, filename)) as img:
            data = test_transform(img.convert("RGB"))
        # Add the batch dimension the network expects: (C, H, W) -> (1, C, H, W).
        data = data.to(device).float().unsqueeze(0)
        pred = net(data)
        output = pred.max(1)[1]
        # Model targets are zero-based; the submission uses 1-based classes.
        results.append({'image_name': filename, 'class': output.item() + 1})

In [None]:
# Write the predictions to submission.csv: a header row followed by one row
# per test image.
with open('submission.csv', 'w', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=headers)
    writer.writeheader()
    for row in results:
        writer.writerow(row)