In [2]:
import torch
import torchvision
from torchvision import datasets
from torchvision import transforms as T
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, sampler, random_split
from torchvision import models

In [3]:
import numpy as np
import pandas as pd
import os


In [4]:

!pip install timm
import timm
from timm.loss import LabelSmoothingCrossEntropy

Defaulting to user installation because normal site-packages is not writeable


In [5]:
# remove warnings
import warnings
warnings.filterwarnings("ignore")

In [6]:
import matplotlib.pyplot as plt
%matplotlib inline

In [7]:
import sys
from tqdm import tqdm
import time
import copy

In [8]:
def get_classes(data_dir):
    all_data = datasets.ImageFolder(data_dir)
    return all_data.classes

In [13]:
def get_data_loaders(data_dir, batch_size, train = False):
    if train:
        #train
        transform = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.RandomApply(torch.nn.ModuleList([T.ColorJitter()]), p=0.25),
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD), # imagenet means
            T.RandomErasing(p=0.1, value='random')
        ])
        train_data = datasets.ImageFolder(os.path.join(data_dir, "train"), transform = transform)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
        return train_loader, len(train_data)
    else:
        # val/test
        transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD),
        ])
        val_data = datasets.ImageFolder(os.path.join(data_dir, "valid"), transform=transform)
        test_data = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=transform)
        val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=4)
        return val_loader, test_loader, len(val_data), len(test_data)

In [14]:
from sklearn.model_selection import train_test_split

In [19]:
dataset_path = ""

In [25]:
(train_loader, train_data_len) = get_data_loaders(dataset_path, 128, train=True)
(val_loader, test_loader, valid_data_len, test_data_len) = get_data_loaders(dataset_path, 32, train=False)


In [22]:
classes = get_classes("train")
print(classes, len(classes))

['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Sp

In [16]:
dataloaders = {
    "train": train_loader,
    "val": val_loader
}
dataset_sizes = {
    "train": train_data_len,
    "val": valid_data_len
}

In [17]:
print(len(train_loader), len(val_loader), len(test_loader))

616 616 309


In [18]:
print(train_data_len, valid_data_len, test_data_len)

78766 19698 9873


In [19]:
# now, for the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [20]:
HUB_URL = "SharanSMenon/swin-transformer-hub:main"
MODEL_NAME = "swin_tiny_patch4_window7_224"
# check hubconf for more models.
model = torch.hub.load(HUB_URL, MODEL_NAME, pretrained=True)

Downloading: "https://github.com/SharanSMenon/swin-transformer-hub/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth" to /root/.cache/torch/hub/checkpoints/swin_tiny_patch4_window7_224.pth
100%|██████████| 109M/109M [00:00<00:00, 210MB/s]


In [1]:
for param in model.parameters(): #freeze model
    param.requires_grad = False

n_inputs = model.head.in_features
model.head = nn.Sequential(
    nn.Linear(n_inputs, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, len(classes))
)
model = model.to(device)
print(model.head)

NameError: name 'model' is not defined

In [22]:
criterion = LabelSmoothingCrossEntropy()
criterion = criterion.to(device)
optimizer = optim.AdamW(model.head.parameters(), lr=0.001)

In [23]:
# lr scheduler
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.97)

In [24]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=20):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print("-"*10)

        for phase in ['train', 'val']: # We do training and validation phase per epoch
            if phase == 'train':
                model.train() # model to training mode
            else:
                model.eval() # model to evaluate

            running_loss = 0.0
            running_corrects = 0.0

            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'): # no autograd makes validation go faster
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1) # used for accuracy
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step() # step at end of epoch

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc =  running_corrects.double() / dataset_sizes[phase]

            print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict()) # keep the best validation accuracy model
        print()
    time_elapsed = time.time() - since 
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print("Best Val Acc: {:.4f}".format(best_acc))

    model.load_state_dict(best_model_wts)
    return model

In [25]:
model_ft = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=20) 


Epoch 0/19
----------


100%|██████████| 616/616 [1:38:35<00:00,  9.60s/it]


train Loss: 1.0908 Acc: 0.9037


100%|██████████| 616/616 [24:04<00:00,  2.34s/it]


val Loss: 0.8795 Acc: 0.9660

Epoch 1/19
----------


100%|██████████| 616/616 [07:08<00:00,  1.44it/s]


train Loss: 0.8926 Acc: 0.9621


100%|██████████| 616/616 [01:47<00:00,  5.71it/s]


val Loss: 0.8350 Acc: 0.9738

Epoch 2/19
----------


100%|██████████| 616/616 [07:09<00:00,  1.43it/s]


train Loss: 0.8654 Acc: 0.9690


100%|██████████| 616/616 [01:47<00:00,  5.74it/s]


val Loss: 0.8182 Acc: 0.9764

Epoch 3/19
----------


100%|██████████| 616/616 [07:08<00:00,  1.44it/s]


train Loss: 0.8507 Acc: 0.9727


100%|██████████| 616/616 [01:46<00:00,  5.79it/s]


val Loss: 0.8033 Acc: 0.9821

Epoch 4/19
----------


100%|██████████| 616/616 [07:07<00:00,  1.44it/s]


train Loss: 0.8409 Acc: 0.9765


100%|██████████| 616/616 [01:48<00:00,  5.67it/s]


val Loss: 0.7964 Acc: 0.9828

Epoch 5/19
----------


100%|██████████| 616/616 [07:09<00:00,  1.43it/s]


train Loss: 0.8349 Acc: 0.9778


100%|██████████| 616/616 [01:47<00:00,  5.75it/s]


val Loss: 0.7947 Acc: 0.9838

Epoch 6/19
----------


100%|██████████| 616/616 [07:10<00:00,  1.43it/s]


train Loss: 0.8295 Acc: 0.9789


100%|██████████| 616/616 [01:48<00:00,  5.70it/s]


val Loss: 0.7892 Acc: 0.9841

Epoch 7/19
----------


100%|██████████| 616/616 [07:12<00:00,  1.43it/s]


train Loss: 0.8257 Acc: 0.9801


100%|██████████| 616/616 [01:49<00:00,  5.64it/s]


val Loss: 0.7852 Acc: 0.9849

Epoch 8/19
----------


100%|██████████| 616/616 [07:09<00:00,  1.44it/s]


train Loss: 0.8222 Acc: 0.9809


100%|██████████| 616/616 [01:46<00:00,  5.79it/s]


val Loss: 0.7811 Acc: 0.9855

Epoch 9/19
----------


100%|██████████| 616/616 [07:11<00:00,  1.43it/s]


train Loss: 0.8175 Acc: 0.9821


100%|██████████| 616/616 [01:46<00:00,  5.80it/s]


val Loss: 0.7797 Acc: 0.9874

Epoch 10/19
----------


100%|██████████| 616/616 [07:14<00:00,  1.42it/s]


train Loss: 0.8159 Acc: 0.9829


100%|██████████| 616/616 [01:48<00:00,  5.66it/s]


val Loss: 0.7757 Acc: 0.9873

Epoch 11/19
----------


100%|██████████| 616/616 [07:10<00:00,  1.43it/s]


train Loss: 0.8144 Acc: 0.9831


100%|██████████| 616/616 [01:46<00:00,  5.77it/s]


val Loss: 0.7752 Acc: 0.9874

Epoch 12/19
----------


100%|██████████| 616/616 [07:12<00:00,  1.42it/s]


train Loss: 0.8097 Acc: 0.9839


100%|██████████| 616/616 [01:46<00:00,  5.79it/s]


val Loss: 0.7716 Acc: 0.9889

Epoch 13/19
----------


100%|██████████| 616/616 [07:12<00:00,  1.42it/s]


train Loss: 0.8091 Acc: 0.9838


100%|██████████| 616/616 [01:48<00:00,  5.66it/s]


val Loss: 0.7773 Acc: 0.9866

Epoch 14/19
----------


100%|██████████| 616/616 [07:11<00:00,  1.43it/s]


train Loss: 0.8075 Acc: 0.9848


100%|██████████| 616/616 [01:49<00:00,  5.65it/s]


val Loss: 0.7741 Acc: 0.9883

Epoch 15/19
----------


100%|██████████| 616/616 [07:12<00:00,  1.42it/s]


train Loss: 0.8069 Acc: 0.9841


100%|██████████| 616/616 [01:48<00:00,  5.69it/s]


val Loss: 0.7687 Acc: 0.9885

Epoch 16/19
----------


100%|██████████| 616/616 [07:13<00:00,  1.42it/s]


train Loss: 0.8051 Acc: 0.9847


100%|██████████| 616/616 [01:49<00:00,  5.63it/s]


val Loss: 0.7658 Acc: 0.9890

Epoch 17/19
----------


100%|██████████| 616/616 [07:24<00:00,  1.39it/s]


train Loss: 0.8040 Acc: 0.9847


100%|██████████| 616/616 [01:51<00:00,  5.53it/s]


val Loss: 0.7674 Acc: 0.9886

Epoch 18/19
----------


100%|██████████| 616/616 [07:14<00:00,  1.42it/s]


train Loss: 0.8015 Acc: 0.9856


100%|██████████| 616/616 [01:48<00:00,  5.66it/s]


val Loss: 0.7677 Acc: 0.9890

Epoch 19/19
----------


100%|██████████| 616/616 [07:14<00:00,  1.42it/s]


train Loss: 0.7995 Acc: 0.9859


100%|██████████| 616/616 [01:48<00:00,  5.66it/s]


val Loss: 0.7651 Acc: 0.9894

Training complete in 293m 45s
Best Val Acc: 0.9894


In [32]:
torch.save(model_ft,"/content/Untitled Folder/model.pt")



In [33]:
test_loss = 0.0
class_correct = list(0 for i in range(len(classes)))
class_total = list(0 for i in range(len(classes)))
model_ft.eval()

for data, target in tqdm(test_loader):
    data, target = data.to(device), target.to(device)
    with torch.no_grad(): # turn off autograd for faster testing
        output = model_ft(data)
        loss = criterion(output, target)
    test_loss = loss.item() * data.size(0)
    _, pred = torch.max(output, 1)
    correct_tensor = pred.eq(target.data.view_as(pred))
    correct = np.squeeze(correct_tensor.cpu().numpy())
    if len(target) == 32:
        for i in range(32):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

test_loss = test_loss / test_data_len
print('Test Loss: {:.4f}'.format(test_loss))
for i in range(len(classes)):
    if class_total[i] > 0:
        print("Test Accuracy of %5s: %2d%% (%2d/%2d)" % (
            classes[i], 100*class_correct[i]/class_total[i], np.sum(class_correct[i]), np.sum(class_total[i])
        ))
    else:
        print("Test accuracy of %5s: NA" % (classes[i]))
print("Test Accuracy of %2d%% (%2d/%2d)" % (
            100*np.sum(class_correct)/np.sum(class_total), np.sum(class_correct), np.sum(class_total)
        ))

100%|██████████| 309/309 [13:23<00:00,  2.60s/it]

Test Loss: 0.0012
Test Accuracy of Apple__Apple_scab: 99% (262/263)
Test Accuracy of Apple__Black_rot: 100% (258/258)
Test Accuracy of Apple__Cedar_apple_rust: 100% (226/226)
Test Accuracy of Apple__healthy: 100% (279/279)
Test Accuracy of Blueberry__healthy: 100% (251/251)
Test Accuracy of Cherry_(including_sour)___Powdery_mildew: 100% (232/232)
Test Accuracy of Cherry_(including_sour)___healthy: 99% (248/249)
Test Accuracy of Corn_(maize)__Common_rust_: 99% (260/262)
Test Accuracy of Corn_(maize)__Northern_Leaf_Blight: 98% (258/262)
Test Accuracy of Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot: 98% (213/217)
Test Accuracy of Corn_(maize)__healthy: 100% (255/255)
Test Accuracy of Grape__Black_rot: 99% (254/255)
Test Accuracy of Grape__Esca_(Black_Measles): 100% (264/264)
Test Accuracy of Grape___Leaf_blight_(Isariopsis_Leaf_Spot): 100% (238/238)
Test Accuracy of Grape__healthy: 100% (220/220)
Test Accuracy of Orange___Haunglongbing_(Citrus_greening): 100% (369/369)
Test Accuracy




In [34]:

example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model.cpu(), example)
traced_script_module.save("Apple_Disease_swin_transformer.pt")