In [None]:
!pip install --upgrade torch torchvision

In [None]:
# Importing PyTorch, OpenCV, Albumentations
import torch
import cv2
import albumentations as A
from glob import glob
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, random_split
from albumentations.pytorch.transforms import ToTensorV2
from torchvision import models
import torch.nn as nn
import torch.optim as optim
from torchmetrics import (Accuracy, ConfusionMatrix,)

In [None]:
# Lookup table: crop name -> (crop index, {disease name -> class label}).
# Crop indices run from 0 to 13 and class labels from 0 to 37. The dataset
# class matches both the crop names and the disease names as substrings of
# each image's file path, so the spelling here has to follow the paths.
category = {
    'Apple': (0, {
        'Apple_scab': 0,
        'Black_rot': 1,
        'Cedar_apple_rust': 2,
        'healthy': 3,
    }),
    'Blueberry': (1, {
        'healthy': 4,
    }),
    'Cherry': (2, {
        'Powdery_mildew': 5,
        'healthy': 6,
    }),
    'Corn': (3, {
        'Cercospora_leaf_spot': 7,
        'Common_rust': 8,
        'Northern_Leaf_Blight': 9,
        'healthy': 10,
    }),
    'Grape': (4, {
        'Black_rot': 11,
        'Esca': 12,
        'Leaf_blight': 13,
        'healthy': 14,
    }),
    'Orange': (5, {
        'Haunglongbing': 15,
    }),
    'Peach': (6, {
        'Bacterial_spot': 16,
        'healthy': 17,
    }),
    'Pepper': (7, {
        'Bacterial_spot': 18,
        'healthy': 19,
    }),
    'Potato': (8, {
        'Early_blight': 20,
        'Late_blight': 21,
        'healthy': 22,
    }),
    'Raspberry': (9, {
        'healthy': 23,
    }),
    'Soybean': (10, {
        'healthy': 24,
    }),
    'Squash': (11, {
        'Powdery_mildew': 25,
    }),
    'Strawberry': (12, {
        'Leaf_scorch': 26,
        'healthy': 27,
    }),
    'Tomato': (13, {
        'Bacterial_spot': 28,
        'Early_blight': 29,
        'Late_blight': 30,
        'Leaf_Mold': 31,
        'Septoria_leaf_spot': 32,
        'Spider_mites': 33,
        'Target_Spot': 34,
        'Yellow_Leaf_Curl_Virus': 35,
        'mosaic_virus': 36,
        'healthy': 37,
    }),
}

In [None]:
class PlantDiseasesDataset(Dataset):
    """Image dataset that derives its labels from the image file paths.

    Every sample is the transformed image with one extra constant channel
    appended; that channel holds the crop (fruit) index scaled into [0, 1].

    Args:
        path: Glob pattern selecting the image files (``**`` is honoured
            because the pattern is expanded with ``recursive=True``).
        transform: Optional callable invoked as
            ``transform(image=img)['image']``. It must return a ``(C, H, W)``
            tensor. Defaults to ``A.Normalize()`` followed by ``ToTensorV2()``.
    """

    def __init__(self, path, transform=None):
        self.files = glob(path, recursive=True)

        if transform is not None:
            self.transform = transform
        else:
            self.transform = A.Compose([
                A.Normalize(),
                ToTensorV2(),
            ])

    def __len__(self):
        """Return the number of files matched by the glob pattern."""
        return len(self.files)

    @staticmethod
    def _resolve_labels(path, table):
        """Map a file path to ``(fruit_index, disease_label)`` using ``table``.

        The first fruit name of ``table`` found as a substring of ``path``
        wins; the first disease name of that fruit found in ``path`` gives
        the label.

        Raises:
            ValueError: If no fruit name, or no disease name of the matched
                fruit, occurs in ``path``.
        """
        for fruit_name, (fruit_idx, diseases) in table.items():
            if fruit_name not in path:
                continue
            for disease_name, label in diseases.items():
                if disease_name in path:
                    return fruit_idx, label
            raise ValueError(
                f"No known disease of {fruit_name!r} found in path: {path!r}"
            )
        raise ValueError(f"No known fruit found in path: {path!r}")

    def __getitem__(self, idx):
        """Return ``(image, fruit_index, label)`` for the file at ``idx``.

        ``image`` is the transformed picture with the fruit channel stacked
        on as an additional last channel.

        Raises:
            ValueError: If the path cannot be mapped to a fruit and disease.
            OSError: If OpenCV cannot read the image file.
        """
        item = self.files[idx]

        # Resolve the labels first so that an unexpected path fails with a
        # clear message instead of a TypeError on ``None``.
        fruit_idx, label = self._resolve_labels(item, category)

        # NOTE(review): cv2.imread yields BGR channel order, which is passed
        # to the transform unchanged -- confirm this is intended. It is kept
        # as-is so existing checkpoints see the same inputs.
        img = cv2.imread(item)
        if img is None:
            raise OSError(f"cv2.imread could not read image: {item!r}")

        img = self.transform(image=img)['image']

        # Scale the fruit index by the largest index in the table (13 for the
        # table in this notebook) so the channel lies in [0, 1].
        largest_idx = max(index for index, _ in category.values())
        fruit_value = fruit_idx / largest_idx if largest_idx else 0.0

        # Build the channel from the *transformed* image size so transforms
        # that resize or crop still produce matching spatial dimensions.
        fruit_channel = torch.full((1, img.shape[-2], img.shape[-1]), fruit_value)

        img = torch.cat([img, fruit_channel], dim=0)

        return img, fruit_idx, label

In [None]:
class NeuralNet(nn.Module):
    """EfficientNet-B7 backbone with a channel adapter and a small MLP head.

    Args:
        in_channels: Number of channels of the input images. When it is not 3,
            a 3x3 convolution first maps the input to 3 channels.
        num_classes: Number of output logits.
        pretrained: If True (default) the backbone is created with pretrained
            weights. Pass False to skip loading them, e.g. when a checkpoint
            is restored immediately afterwards.
    """

    def __init__(self, in_channels, num_classes, pretrained=True):
        super().__init__()
        self.in_channels = in_channels

        # Always created (even for 3-channel input) and kept inside a
        # Sequential so the parameter names in state_dict stay the same.
        self.input_layer = nn.Sequential(
            nn.Conv2d(in_channels, 3, kernel_size=3, stride=1, padding=1),
        )

        self.efficientnet = self._build_backbone(pretrained)

        # Head that reduces the backbone's 1000 outputs to num_classes logits.
        self.last = nn.Sequential(
            nn.Linear(1000, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    @staticmethod
    def _build_backbone(pretrained):
        """Create EfficientNet-B7 using whichever weights API torchvision offers."""
        weights_enum = getattr(models, 'EfficientNet_B7_Weights', None)
        if weights_enum is None:
            # Older torchvision releases only understand the boolean flag.
            return models.efficientnet_b7(pretrained=pretrained)
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.efficientnet_b7(weights=weights)

    def forward(self, x,):
        """Return class logits for the image batch ``x``."""
        if self.in_channels != 3:
            x = self.input_layer(x)

        x = self.efficientnet(x)
        x = self.last(x)

        return x

In [None]:
def train(model, loader, crit, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to optimise; it is switched to training mode.
        loader: Iterable yielding ``(image, fruit, label)`` batches; the
            fruit entry is ignored.
        crit: Loss function called as ``crit(predictions, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        List with the loss value of every batch, in iteration order.
    """
    # Ensure training mode even if the model was left in eval mode earlier.
    model.train()

    loop = tqdm(loader)
    losses = []

    for img, _, label in loop:
        img = img.to(device)
        label = label.to(device).long().view(-1)

        optimizer.zero_grad()
        preds = model(img)
        loss = crit(preds, label)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loop.set_postfix({
            'loss': batch_loss
        })
        losses.append(batch_loss)

    return losses
        
def valid(model, loader, crit, device, num_classes=38):
    """Evaluate ``model`` on ``loader`` and print accuracy and confusion matrix.

    The model runs in eval mode with gradient tracking disabled; its previous
    train/eval mode is restored before returning.

    Args:
        model: Network to evaluate.
        loader: Iterable yielding ``(image, fruit, label)`` batches; the
            fruit entry is ignored.
        crit: Loss function called as ``crit(predictions, labels)``.
        device: Device the batches are moved to.
        num_classes: Number of classes passed to the metrics (default 38).

    Returns:
        List with the loss value of every batch, in iteration order.
    """
    loop = tqdm(loader)

    # NOTE(review): newer torchmetrics releases may expect an additional
    # `task=` argument here -- confirm against the installed version.
    acc = Accuracy(num_classes=num_classes)
    cmatrix = ConfusionMatrix(num_classes=num_classes)
    losses = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for img, _, label in loop:
                img = img.to(device)
                label = label.to(device).long().view(-1)

                preds = model(img)
                loss = crit(preds, label)

                # Metrics are kept on the CPU, so move their inputs there.
                probs = torch.softmax(preds, dim=1).cpu()
                target = label.cpu()
                acc.update(probs, target)
                cmatrix.update(probs, target)

                batch_loss = loss.item()
                loop.set_postfix({
                    'loss': batch_loss
                })
                losses.append(batch_loss)
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)

    print(f"Accuracy: {acc.compute()}")
    print("Confusion Matrix:")
    print(cmatrix.compute())

    return losses

In [None]:
# --- Run configuration -------------------------------------------------------
NUM_EPOCHS = 5
LRN_RATE = 1e-5
BATCH_SIZE = 8
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
LOAD_MODEL = True
CHECKPOINT_PATH = './model_checkpoint_1.pt'
DATA_GLOB = '../input/new-plant-diseases-dataset/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)/train/*/*.JPG'
VAL_SIZE = 5000   # number of images held out for validation
SPLIT_SEED = 42   # fixes the train/validation split across kernel restarts

# --- Model, loss and optimizer -----------------------------------------------
crit = nn.CrossEntropyLoss()
model = NeuralNet(4, 38).to(DEVICE)

if LOAD_MODEL:
    # map_location lets a checkpoint written on a GPU be restored on a
    # CPU-only machine (and vice versa).
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE)['model'])

optimizer = optim.AdamW(model.parameters(), LRN_RATE)

# --- Data --------------------------------------------------------------------
data = PlantDiseasesDataset(DATA_GLOB)

if len(data) <= VAL_SIZE:
    raise ValueError(
        f"Found {len(data)} images for pattern {DATA_GLOB!r}; "
        f"need more than {VAL_SIZE} to keep a non-empty training split."
    )

train_len = len(data) - VAL_SIZE
val_len = VAL_SIZE

# A seeded generator keeps the split identical between runs, so a model
# resumed from a checkpoint is validated on the same held-out images.
train_set, val_set = random_split(
    data,
    [train_len, val_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

In [None]:
# Per-batch loss histories, accumulated over all epochs.
train_losses, val_losses = [], []

for epoch in range(NUM_EPOCHS):
    print(f'Epoch #{epoch}')

    print('Training')
    train_losses.extend(train(model, train_loader, crit, optimizer, DEVICE))

    print('Validation')
    val_losses.extend(valid(model, val_loader, crit, DEVICE))

    # Save model and optimizer state after every epoch.
    checkpoint = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(checkpoint, f'model_checkpoint_{epoch}.pt')

In [None]:
import matplotlib.pyplot as plt

# Both histories hold one value per batch, and the two loaders generally yield
# a different number of batches, so each curve gets its own labelled panel
# instead of sharing one unlabeled x-axis.
fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(12, 4))

ax_train.plot(train_losses)
ax_train.set(title='Training loss', xlabel='Batch', ylabel='Loss')

ax_val.plot(val_losses)
ax_val.set(title='Validation loss', xlabel='Batch', ylabel='Loss')

fig.tight_layout()
plt.show()

In [None]:
# Final evaluation pass on the validation loader; metrics are printed by
# valid() and the returned per-batch losses are discarded.
_ = valid(model=model, loader=val_loader, crit=crit, device=DEVICE)

In [None]:
def predict(image):
    """Predict the class label of one image or a batch of images.

    A fresh ``NeuralNet(4, 38)`` is built on the CPU and restored from
    ``./model_checkpoint_1.pt`` on every call.

    Args:
        image: Float tensor of shape ``(4, H, W)`` or ``(N, 4, H, W)``,
            prepared the same way as the samples produced by the dataset
            (image channels plus the fruit channel).

    Returns:
        1-D integer tensor with one predicted class index per input image.
    """
    model = NeuralNet(4, 38)
    # map_location keeps this working when the checkpoint was saved on a GPU.
    state = torch.load('./model_checkpoint_1.pt', map_location='cpu')
    model.load_state_dict(state['model'])
    model.eval()  # inference mode for dropout / batch-norm layers

    # Accept a single unbatched image as well.
    if image.dim() == 3:
        image = image.unsqueeze(0)

    with torch.no_grad():
        logits = model(image)

    # Softmax is monotonic, so argmax over the logits gives the same class.
    prediction = torch.argmax(logits, dim=1)

    return prediction
