In [1]:
import multiprocessing as mp
from pathlib import Path

import pandas as pd
import torch
import torchvision
from IPython import display
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook

if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
    print("Using GPU")
else:
    DEVICE = torch.device("cpu")
    print("Using CPU")

Using GPU


In [2]:
root_path = Path("/kaggle/input/vegetable-image-dataset/Vegetable Images/")

data_train = root_path / "train"
data_test = root_path / "test"
data_val = root_path / "validation"

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((299, 299)),
        torchvision.transforms.ToTensor(),
    ]
)

ds_train = torchvision.datasets.ImageFolder(root=data_train, transform=transform)
ds_test = torchvision.datasets.ImageFolder(root=data_test, transform=transform)
ds_val = torchvision.datasets.ImageFolder(root=data_val, transform=transform)

In [3]:
model = torch.hub.load("pytorch/vision:v0.10.0", "inception_v3", weights="DEFAULT")

for param in model.parameters():
    param.requires_grad = False
    
classifier = torch.nn.Linear(model.fc.in_features, len(ds_train.classes))
model.fc = classifier

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

In [4]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
NUM_EPOCHS, BATCH_SIZE, NUM_WORKERS = 5, 64, mp.cpu_count()

dl_train = DataLoader(ds_train, BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
dl_val = DataLoader(ds_val, BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
dl_test = DataLoader(ds_test, BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [5]:
class Solver:
    def __init__(self, model, criterion, optimizer, scheduler):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.stats = pd.DataFrame(columns=["accuracy_train", "accuracy_validation"])

    def _display(self):
        display.clear_output()

        display.display(
            display.HTML(
                self.stats.to_html(
                    formatters={
                        "accuracy_train": lambda x: f"{x*100:.2f}%",
                        "accuracy_validation": lambda x: f"{x*100:.2f}%",
                    }
                )
            )
        )

    def _append_stats(self, epoch, acc_train, acc_val):
        self.stats = pd.concat(
            (
                self.stats,
                pd.DataFrame.from_dict(
                    {f"epoch_{epoch}": [acc_train, acc_val]},
                    orient="index",
                    columns=self.stats.columns,
                ),
            )
        )

    def train(
        self,
        dl_train,
        dl_val,
        usage_percentage=1,
        num_epochs=NUM_EPOCHS,
        device=DEVICE,
    ):
        batches_to_train = int(round(len(dl_train) * usage_percentage))
        self.model.to(device)

        for epoch in range(num_epochs):
            self.model.train()
            batch_count, acc_train, acc_val = 0, 0, 0
            self._display()

            for x_batch, y_batch in tqdm_notebook(dl_train, desc="Train batches done"):
                if usage_percentage != 1 and batch_count > batches_to_train:
                    break

                batch_count += 1

                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                pred, _ = self.model(x_batch)
                loss = self.criterion(pred, y_batch)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                is_correct = (torch.argmax(pred, dim=1) == y_batch).float()
                acc_train += is_correct.sum().cpu().numpy()

            self.model.eval()

            with torch.no_grad():
                for x_batch, y_batch in tqdm_notebook(dl_val, desc="Validation batches done"):
                    x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                    pred = self.model(x_batch)
                    loss = self.criterion(pred, y_batch)

                    is_correct = (torch.argmax(pred, dim=1) == y_batch).float()
                    acc_val += is_correct.sum().cpu().numpy()

            self.scheduler.step()

            self._append_stats(
                epoch + 1,
                acc_train / len(dl_train.dataset),
                acc_val / len(dl_val.dataset),
            )

        self._display()

In [6]:
solver = Solver(model, criterion, optimizer, scheduler)
solver.train(dl_train, dl_val)

Unnamed: 0,accuracy_train,accuracy_validation
epoch_1,89.47%,98.00%
epoch_2,97.64%,99.03%
epoch_3,98.25%,99.40%
epoch_4,98.70%,99.50%
epoch_5,98.78%,99.57%


In [7]:
torch.onnx.export(
    model=solver.model,
    args=torch.randn(1, 3, 299, 299).to(DEVICE),
    f="/kaggle/working/model.onnx",
    input_names=["input"],
    output_names=["output"],
)