In [None]:
!pip install -q timm

In [None]:
import numpy as np
import pandas as pd
import os

from sklearn import metrics

import torch
import torch.nn as nn
import torch.optim as optim
import albumentations
import torchvision
from torchvision import transforms
import timm


import PIL
import glob
from tqdm import tqdm

In [None]:
# Train/ and Test/ splits of the MangoLeafBD dataset as mounted on Kaggle.
TRAIN_DIR = '/kaggle/input/mango-dataset-bangladesh/MangoLeafBD_Without_Testset_Augmentation/Train/'
TEST_DIR = '/kaggle/input/mango-dataset-bangladesh/MangoLeafBD_Without_Testset_Augmentation/Test/'

# Seed the RNGs behind DataLoader shuffling, the random augmentations and model
# initialisation so that Restart & Run All starts from the same random state.
# (Some GPU kernels may still be non-deterministic.)
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)


## Calculate the per-channel mean and standard deviation of the training images (raw 0–255 pixel scale)

In [None]:
# Load every training .jpg at a common 256x256 size as a float32 (H, W, C)
# tensor, used below to compute per-channel statistics.
common_size = (256, 256)
imgs = []
for img_path in sorted(glob.glob(os.path.join(TRAIN_DIR, '**/*.jpg'), recursive=True)):
    # `with` releases the file handle. convert('RGB') guarantees 3 channels
    # (as torchvision's ImageFolder loader does), so a grayscale or RGBA file
    # cannot produce a differently shaped tensor.
    with PIL.Image.open(img_path) as pil_img:
        resized = pil_img.convert('RGB').resize(common_size, PIL.Image.LANCZOS)
    imgs.append(torch.tensor(np.array(resized), dtype=torch.float32))

In [None]:
# Shape of the first loaded image, expected (256, 256, 3) i.e. (H, W, C).
imgs[0].size()

In [None]:
# Per-channel mean/std over every pixel of every training image, on the raw
# 0-255 scale produced by np.array(PIL image).
# Sums are accumulated one image at a time in float64 instead of torch.stack-ing
# the list, which would allocate a second full-size float32 copy of the dataset.
if not imgs:
    raise ValueError(f"No training images were loaded from {TRAIN_DIR}")

num_channels = imgs[0].shape[-1]
channel_sum = torch.zeros(num_channels, dtype=torch.float64)
channel_sq_sum = torch.zeros(num_channels, dtype=torch.float64)
pixel_count = 0
for img_tensor in imgs:
    # (H, W, C) -> (H*W, C); fails loudly if an image has a different channel count.
    pixels = img_tensor.reshape(-1, num_channels).double()
    channel_sum += pixels.sum(dim=0)
    channel_sq_sum += pixels.square().sum(dim=0)
    pixel_count += pixels.shape[0]

mean_f64 = channel_sum / pixel_count
# Unbiased (n - 1) variance: the same estimator torch.std uses by default.
var_f64 = (channel_sq_sum - pixel_count * mean_f64.square()) / (pixel_count - 1)
mean = mean_f64.float()
std = var_f64.clamp_min(0).sqrt().float()

print(f"Mean of images: {mean}")
print(f"Std of images: {std}")
# transforms.ToTensor() rescales pixels to [0, 1], so Normalize needs these values.
print(f"Mean/std on the [0, 1] ToTensor scale: {mean / 255} / {std / 255}")


## Define augmentations, datasets and data loaders

In [None]:
# Per-channel statistics of the training images, measured on raw 0-255 pixels.
PIXEL_MEAN_255 = [171.7605, 176.1683, 173.1120]
PIXEL_STD_255 = [52.8552, 50.4523, 68.5839]
# ToTensor() rescales images to [0, 1] *before* Normalize runs, so the stats
# must be rescaled too. With the raw 0-255 values every input would collapse
# to roughly -3.2 with almost no variance.
norm_mean = [v / 255.0 for v in PIXEL_MEAN_255]
norm_std = [v / 255.0 for v in PIXEL_STD_255]

# Training: random crop/flip/rotation/colour jitter for augmentation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(common_size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])
# Evaluation: deterministic resize to the same size, same normalisation.
test_transform = transforms.Compose([
    transforms.Resize(common_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
])

In [None]:
# One sub-folder per class under each split root.
train_dataset = torchvision.datasets.ImageFolder(TRAIN_DIR, transform=train_transform)
test_dataset = torchvision.datasets.ImageFolder(TEST_DIR, transform=test_transform)

# ImageFolder derives label indices from each root's sub-folder names
# independently. If the splits do not contain exactly the same classes, the
# same index would mean different classes in train and test.
assert train_dataset.classes == test_dataset.classes, (
    f"Train/test class folders differ: {train_dataset.classes} vs {test_dataset.classes}"
)

In [None]:
BATCH_SIZE = 32
# Only the training split is shuffled; the test split is iterated in dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
)

## Model Training and testing

In [None]:
class MangoClassifier(nn.Module):
    """timm backbone with a classification head sized for the mango leaf classes.

    ``forward`` returns raw logits, which is what ``nn.CrossEntropyLoss``
    expects (it applies log-softmax internally). Applying softmax here as well
    would squash the gradients. Use ``predict_proba`` when probabilities are
    needed; ``argmax`` gives the same class either way.

    Args:
        num_classes: number of output classes.
        model_name: timm architecture to instantiate.
        pretrained: forwarded to ``timm.create_model``. Set True only for a
            ``model_name`` for which timm provides weights.
    """

    def __init__(self, num_classes, model_name='vit_base_patch32_plus_256', pretrained=False):
        super().__init__()
        # Passing num_classes lets timm build the head with the right output
        # size, instead of swapping `head` by hand after construction.
        self.base_model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)

    def forward(self, x):
        """Return unnormalised class scores (logits), one row per input."""
        return self.base_model(x)

    def predict_proba(self, x):
        """Return per-class probabilities (softmax over the logits along dim 1)."""
        return torch.softmax(self.forward(x), dim=1)




In [None]:
# One output per class folder found by ImageFolder (8 for MangoLeafBD), so the
# head can never disagree with the label indices used in training.
num_classes = len(train_dataset.classes)
model_fn = MangoClassifier(num_classes)

# Smoke test: one random 256x256 RGB image through the model, without building
# an autograd graph.
rand_img = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    smoke_out = model_fn(rand_img)
assert smoke_out.shape == (1, num_classes), smoke_out.shape
smoke_out

In [None]:
num_epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model first, then build the optimizer over its (device) parameters,
# as recommended in the torch.optim documentation.
model_fn.to(device)
print(f"Device is {device}")
optimizer = optim.Adam(model_fn.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
model_fn.train()

for epoch in range(num_epochs):
    loss_cum = 0.0
    # Reset every epoch so the printed F1 describes this epoch only, rather
    # than a running total over all previous epochs.
    all_preds = []
    all_labels = []
    for img, label in train_loader:
        img, label = img.to(device), label.to(device)

        out = model_fn(img)
        loss = loss_fn(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_cum += loss.item()

        predicted = out.argmax(dim=1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(label.cpu().numpy())

    epoch_loss = loss_cum / len(train_loader)
    # sklearn's signature is f1_score(y_true, y_pred).
    epoch_f1 = metrics.f1_score(all_labels, all_preds, average='macro')
    print(f"For epoch {epoch} the loss is {epoch_loss} the F1 score is {epoch_f1}")




In [None]:
# Evaluate on the held-out test split: eval mode, no gradient tracking.
model_fn.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for img, labels in test_loader:
        # Only the images go to the model's device; this batch's labels stay
        # on the CPU for sklearn. (The old line moved the stale `label` left
        # over from the training loop instead.)
        img = img.to(device)
        out = model_fn(img)
        predicted = out.argmax(dim=1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.numpy())

f1 = metrics.f1_score(all_labels, all_preds, average='weighted')
# The training loop reports macro F1; show both so the numbers are comparable.
f1_macro = metrics.f1_score(all_labels, all_preds, average='macro')
print(f"Test F1 (weighted): {f1:.4f}  (macro): {f1_macro:.4f}")


