In [None]:
%%capture
!pip install deeplake
!pip install lightning

In [None]:
import deeplake
# Load the PlantVillage (with augmentation) dataset from Activeloop's hub via deeplake.
ds = deeplake.load(
    'hub://activeloop/plantvillage-with-augmentation'
)

-

This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/plantvillage-with-augmentation



\

hub://activeloop/plantvillage-with-augmentation loaded successfully.



  

In [None]:
from torchvision import transforms,models
# Training-time preprocessing for each PIL image coming out of the loaders:
# fixed 256x256 resize, random-rotation augmentation, tensor conversion, normalisation.
tform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.RandomRotation(degrees=20),  # augmentation: random angle in [-20, 20] degrees
        transforms.ToTensor(),  # PIL image -> tensor; the following step operates on tensors
        transforms.Normalize(mean=[0.5], std=[0.5]),  # (x - 0.5) / 0.5, same value for every channel
    ]
)

In [None]:
import random

batch_size = 64
val_fraction = 0.2  # share of samples held out for validation
split_seed = 42     # fixed so the train/val split is reproducible across runs

# Hold out a random, disjoint validation subset. Validating on the same `ds` that is
# used for training would only measure how well the model memorised its own inputs.
shuffled_indices = random.Random(split_seed).sample(range(len(ds)), len(ds))
n_val = int(len(ds) * val_fraction)
val_ds = ds[sorted(shuffled_indices[:n_val])]    # sorted for more sequential reads
train_ds = ds[sorted(shuffled_indices[n_val:])]

# Evaluation/prediction reuse the training preprocessing minus RandomRotation, so
# their inputs are deterministic rather than randomly augmented.
eval_tform = transforms.Compose(
    [t for t in tform.transforms if not isinstance(t, transforms.RandomRotation)]
)

# Since torchvision transforms expect PIL images, we use the 'pil' decode_method for the 'images' tensor. This is much faster than running ToPILImage inside the transform
train_loader = train_ds.pytorch(num_workers = 0, shuffle = True, transform = {'images': tform, 'labels': None}, batch_size = batch_size, decode_method = {'images': 'pil'})
val_loader = val_ds.pytorch(num_workers = 0, transform = {'images': eval_tform, 'labels': None}, batch_size = batch_size, decode_method = {'images': 'pil'})
predict = val_ds.pytorch(num_workers = 0, transform = {'images': eval_tform, 'labels': None}, batch_size = 1, decode_method = {'images': 'pil'})

In [None]:
# Sanity-check one batch from each loader; expected layout is (batch, channels, height, width).
for loader in (train_loader, val_loader):
    first_batch = next(iter(loader))
    print(first_batch["images"].shape)

torch.Size([64, 3, 256, 256])
torch.Size([64, 3, 256, 256])


In [None]:
import os
import torch 
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
import lightning.pytorch as pl
from torchvision.transforms import ToTensor
def get_model(num_classes = 38, in_channels = 3, pretrained = True):
    """Build a torchvision ResNet-18 classifier.

    Args:
        num_classes: size of the final fully-connected layer (one logit per class).
        in_channels: number of channels in the input images. The ImageNet stem expects 3;
            any other value (e.g. 1 for grayscale) swaps in a freshly initialised first conv.
        pretrained: start from torchvision's default ImageNet weights when True.

    Returns:
        A ``ResNet`` module whose output has ``num_classes`` logits per image.
    """
    weights = models.ResNet18_Weights.DEFAULT if pretrained else None
    model = models.resnet18(weights=weights)

    # The loaders feed RGB images, so the pretrained 3-channel stem is kept as-is.
    # Only rebuild it when the channel count really differs, keeping the original
    # 7x7 / stride 2 / padding 3 geometry so the rest of the network sees the same shapes.
    if in_channels != model.conv1.in_channels:
        model.conv1 = torch.nn.Conv2d(
            in_channels, model.conv1.out_channels,
            kernel_size=7, stride=2, padding=3, bias=False,
        )

    # Replace the ImageNet head with one sized for this dataset.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

    return model

# define the LightningModule
class Model(pl.LightningModule):
    """LightningModule that trains the ResNet-18 from `get_model` with cross-entropy loss."""

    def __init__(self, num_classes = 38, lr = 0.001, momentum = 0.1):
        """
        Inputs:
            num_classes: Number of classes in the dataset and model
            lr: SGD learning rate
            momentum: SGD momentum
        """
        super().__init__()
        self.lr = lr
        self.momentum = momentum

        # Create the model
        self.model = get_model(num_classes)

        # Create loss module
        self.loss_module = torch.nn.CrossEntropyLoss()

    def forward(self, imgs):
        return self.model(imgs)

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum)

    @staticmethod
    def _unpack_batch(batch):
        """Return (float images, 1-D int64 labels) from a loader batch dict."""
        images = batch['images'].float()
        # reshape(-1) instead of squeeze(): labels presumably arrive as (batch, 1), and
        # squeeze() turns a final batch of one sample into a 0-d tensor, which
        # CrossEntropyLoss rejects. CrossEntropyLoss also needs int64 class indices.
        labels = batch['labels'].reshape(-1).long()
        return images, labels

    def training_step(self, batch, batch_idx):
        images, labels = self._unpack_batch(batch)

        preds = self(images)
        loss = self.loss_module(preds, labels)

        acc = (preds.argmax(dim=-1) == labels).float().mean()

        # Pass the batch size explicitly so epoch-level means are weighted correctly
        # (the last batch may be smaller) instead of Lightning guessing it from the dict.
        batch_size = labels.shape[0]
        self.log("train_acc", acc, on_step=True, on_epoch=True, batch_size=batch_size)
        self.log("train_loss", loss, batch_size=batch_size)

        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = self._unpack_batch(batch)
        logits = self(images)
        acc = (logits.argmax(dim=-1) == labels).float().mean()

        batch_size = labels.shape[0]
        self.log("val_loss", self.loss_module(logits, labels), on_epoch=True, batch_size=batch_size)
        # Log the validation accuracy to the progress bar at the end of each epoch
        self.log("val_acc", acc, on_epoch=True, prog_bar=True, batch_size=batch_size)

In [None]:
# Fit the same instance that the ONNX cells below export, so `model` holds trained weights.
model = Model(39)  # 39 output classes — NOTE(review): confirm against len(ds.labels.info.class_names)
trainer = pl.Trainer(max_epochs = 10)
trainer.fit(model=model, train_dataloaders = train_loader, val_dataloaders = val_loader)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 204MB/s]
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name        | Type             | Params
-------------------------------------------------
0 | model       | ResNet           | 11.2 M
1 | loss_module | CrossEntropyLoss | 0 

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
!pip install onnx
!pip install torch
import torch
filepath = "model.onnx"
input_sample = torch.randn((1,3,256,256))
model.to_onnx(filepath, input_sample, export_params=True)

In [None]:
!pip install onnxruntime
import onnxruntime
import numpy as np
model.load_state_dict(torchmap_location=device))
ort_session = onnxruntime.InferenceSession(filepath)
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: np.random.random((1,3,256,256)).astype(np.float32)}
ort_outs = ort_session.run(None, ort_inputs)

In [None]:
ort_outs[0].shape[-1]  # number of class logits for the single input image

In [None]:
# Class index -> class name, keyed by the 0-based index the model outputs (argmax).
# Read from the dataset's own class_label metadata so the mapping matches how the
# `labels` tensor was encoded. A hand-written table previously used 1-based keys,
# listed only 37 classes for a 39-output model and repeated "Apple Cedar Rust".
labels_dict = dict(enumerate(ds.labels.info.class_names))

In [None]:

# Top-scoring class of the single-image ONNX run, as a plain int for the dict lookup.
lab = int(np.argmax(ort_outs[0][0]))
# Fall back to a readable message instead of raising KeyError when the index has no name.
labels_dict.get(lab, f"unknown class index {lab}")