In [1]:
import os
import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import torchinfo
import modules

In [2]:
# os.cpu_count() returns None when the CPU count cannot be determined; fall
# back to 0 so the value is always an integer usable as a worker count.
num_workers = os.cpu_count() or 0
print(num_workers)
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

24
cuda


In [3]:
# Relative locations of the training and test image folders.
train_dir, test_dir = (
    "modules/data/train_organized/",
    "modules/data/test_organized/",
)


In [4]:
# Select the DEFAULT entry of the EfficientNetV2-L weights enum; the bare name
# on the last line displays which entry it resolves to as the cell output.
weights = torchvision.models.EfficientNet_V2_L_Weights.DEFAULT
weights

EfficientNet_V2_L_Weights.IMAGENET1K_V1

In [5]:
# Build the preprocessing pipeline that ships with the selected weights, so the
# data is prepared with the resize/crop/normalisation settings shown in the
# cell output below.
auto_transforms = weights.transforms()
auto_transforms

ImageClassification(
    crop_size=[480]
    resize_size=[480]
    mean=[0.5, 0.5, 0.5]
    std=[0.5, 0.5, 0.5]
    interpolation=InterpolationMode.BICUBIC
)

In [7]:
from modules import data_setup

# Create the train/test dataloaders and the class-name list in one call, using
# the preprocessing pipeline that belongs to the pretrained weights.
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=auto_transforms,
    batch_size=32,
    num_workers=num_workers,
)

FileNotFoundError: [WinError 3] The system cannot find the path specified: '/Users/raghavmehta/Downloads/Fruit-Freshness-Classification/modules/data/train_organized'

In [None]:
# Instantiate EfficientNetV2-L with the selected pretrained weights, then move
# it to the target device. The bare name displays the architecture.
model = torchvision.models.efficientnet_v2_l(weights=weights)
model = model.to(device)
model

In [None]:
# Layer-by-layer summary of the model for a batch of 32 RGB images at 480x480,
# including which parameters are currently trainable.
# `torchinfo` is already imported at the top of the notebook, so the function
# is called through the module instead of re-importing it in this cell.
torchinfo.summary(
    model=model,
    input_size=(32, 3, 480, 480),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

In [None]:
# Freeze the feature extractor: none of its parameters will receive gradients.
for frozen_param in model.features.parameters():
    frozen_param.requires_grad_(False)

In [None]:
# Seed the CPU and CUDA generators so the new head is initialised reproducibly.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# One output logit per class found by the dataloader helper.
output_shape = len(class_names)

# Read the head's input width from the current classifier instead of
# hard-coding 1280, so this cell keeps working if the backbone is swapped.
# Assumes the last classifier layer is an nn.Linear (true for torchvision's
# EfficientNet models and for the head this cell installs, so re-running is safe).
classifier_in_features = model.classifier[-1].in_features

# Replace the pretrained classification head with a fresh one sized for this dataset.
model.classifier = torch.nn.Sequential(
    torch.nn.Dropout(p=0.2, inplace=True),
    torch.nn.Linear(in_features=classifier_in_features, out_features=output_shape, bias=True)
).to(device)

In [None]:
# Summarise the model again after freezing the features and swapping the
# classifier; the "trainable" column shows which layers will be updated.
# `torchinfo` is already imported at the top of the notebook, so the function
# is called through the module instead of re-importing it in this cell.
torchinfo.summary(
    model=model,
    input_size=(32, 3, 480, 480),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

In [None]:
# Cross-entropy loss for multi-class classification, optimised with Adam.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [None]:
from modules import engine

# Run the project's training loop for 5 epochs on the selected device.
# `results` is consumed later by plot_loss_curves, which reads the keys
# "train_loss", "test_loss", "train_acc" and "test_acc" from it.
results = engine.train(
    model=model,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=5,
    device=device
)


In [None]:
def plot_loss_curves(results):
    """Plot train/test loss and train/test accuracy in two side-by-side panels.

    Args:
        results: Mapping with the keys "train_loss", "test_loss", "train_acc"
            and "test_acc". Each value is a sequence that is plotted against
            the index range of "train_loss".
    """
    # Look up every series first so a missing key fails before a figure exists.
    panels = [
        ("Loss", [("train_loss", results["train_loss"]),
                  ("test_loss", results["test_loss"])]),
        ("Accuracy", [("train_accuracy", results["train_acc"]),
                      ("test_accuracy", results["test_acc"])]),
    ]
    epochs = range(len(results["train_loss"]))

    plt.figure(figsize=(15, 7))

    # One subplot per panel, numbered from 1 as pyplot expects.
    for position, (title, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for label, values in curves:
            plt.plot(epochs, values, label=label)
        plt.title(title)
        plt.xlabel("Epochs")
        plt.legend()

In [None]:
# Draw the loss and accuracy curves from the training results.
plot_loss_curves(results)

In [None]:
from typing import List, Tuple
from PIL import Image
def pred_and_plot_image(model: torch.nn.Module,
                        image_path: str, 
                        class_names: List[str],
                        image_size: Tuple[int, int] = (224, 224),
                        transform: torchvision.transforms = None,
                        device: torch.device=device):
    """Predict the class of a single image file and plot it with the result.

    Args:
        model: Classifier to run. It is moved to `device` and switched to
            eval mode as a side effect.
        image_path: Path of the image file to open.
        class_names: Class labels, indexed by the predicted class index.
        image_size: Target size for the fallback transform. Ignored when
            `transform` is given.
        transform: Callable applied to the PIL image to produce an input
            tensor. When None, a Resize/ToTensor/Normalize pipeline with
            ImageNet mean/std is used.
        device: Device on which inference is run.

    Returns:
        None. A matplotlib figure is created showing the image, titled with
        the predicted class name and its softmax probability.
    """
    # Load the pixels eagerly and force three channels: grayscale, palette or
    # RGBA files would otherwise not match the 3-channel normalisation and
    # model input. The context manager closes the file once it is decoded.
    with Image.open(image_path) as raw_image:
        img = raw_image.convert("RGB")

    if transform is not None:
        image_transform = transform
    else:
        image_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    model.to(device)

    model.eval()
    with torch.inference_mode():
        # Add a batch dimension: the model is called with a single-image batch.
        transformed_image = image_transform(img).unsqueeze(dim=0)

        target_image_pred = model(transformed_image.to(device))

        # Logits -> probabilities -> index of the most likely class.
        target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
        # .item() yields a plain int so class_names is indexed with an integer
        # rather than a tensor.
        target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1).item()

    plt.figure()
    plt.imshow(img)
    plt.title(f"Pred: {class_names[target_image_pred_label]} | Prob: {target_image_pred_probs.max():.3f}")
    plt.axis(False)

In [None]:
import random
from pathlib import Path
num_images_to_plot = 3
# Collect every .jpg found one directory level below the test root.
test_image_path_list = list(Path(test_dir).glob("*/*.jpg"))
# random.sample raises ValueError when k exceeds the population size, so clamp
# k to the number of images actually found.
test_image_path_sample = random.sample(population=test_image_path_list,
                                       k=min(num_images_to_plot, len(test_image_path_list)))

# Reuse the preprocessing pipeline built from the pretrained weights instead of
# constructing a new one on every iteration.
for image_path in test_image_path_sample:
    pred_and_plot_image(model=model,
                        image_path=image_path,
                        class_names=class_names,
                        transform=auto_transforms,
                        image_size=(480, 480))

In [None]:
custom_image_path = "banana.jpg"

# Pass the same preprocessing the dataloaders used. Without `transform`,
# pred_and_plot_image falls back to a 224x224 resize with ImageNet mean/std,
# which does not match the 480px / 0.5-mean pipeline of the selected weights.
pred_and_plot_image(model=model,
                    image_path=custom_image_path,
                    class_names=class_names,
                    transform=auto_transforms)

In [None]:
from pathlib import Path
# Target directory and file name for the saved weights.
MODEL_PATH = Path("models")
MODEL_NAME = "fruit_model"
MODEL_SAVE_PATH = MODEL_PATH.joinpath(MODEL_NAME)

# Create the directory if needed, then save only the state dict.
MODEL_PATH.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), MODEL_SAVE_PATH)
