In [1]:
import torch
from torch import nn
import torchvision

In [2]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from typing import List, Union
from pathlib import Path

In [3]:
# Width of the replacement classifier head and of the accuracy metric; should equal the
# number of class sub-folders discovered under the dataset directory.
NUM_CLASSES: int = 9

In [4]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DEVICE

'cuda'

In [None]:
# Dataset root: expected to contain one sub-folder per class, each holding that class's images.
data_path = Path("./edible_or_not_mushrooms")
assert data_path.is_dir(), f"Dataset directory not found: {data_path.resolve()}"

In [7]:
# Pretrained EfficientNetV2-S plus the preprocessing pipeline that matches its weights.
efficient_net_weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
transforms = efficient_net_weights.transforms()
efficient_net = torchvision.models.efficientnet_v2_s(weights=efficient_net_weights)
efficient_net = efficient_net.to(DEVICE)

Downloading: "https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth" to /home/roka/.var/app/com.visualstudio.code/cache/torch/hub/checkpoints/efficientnet_v2_s-dd5fe13b.pth


100.0%


In [8]:
# Layer-by-layer shapes, trainability and parameter counts of the unmodified network.
# NOTE(review): 200x200 is an illustrative input size; the weights' transforms() may
# produce a different resolution — confirm against `transforms` if exact shapes matter.
from torchinfo import summary

summary(
    efficient_net,
    input_size=(32, 3, 200, 200),
    col_names=["input_size", "output_size", "trainable", "num_params"],
    row_settings=["var_names"],
)

Layer (type (var_name))                                      Input Shape               Output Shape              Trainable                 Param #
EfficientNet (EfficientNet)                                  [32, 3, 200, 200]         [32, 1000]                True                      --
├─Sequential (features)                                      [32, 3, 200, 200]         [32, 1280, 7, 7]          True                      --
│    └─Conv2dNormActivation (0)                              [32, 3, 200, 200]         [32, 24, 100, 100]        True                      --
│    │    └─Conv2d (0)                                       [32, 3, 200, 200]         [32, 24, 100, 100]        True                      648
│    │    └─BatchNorm2d (1)                                  [32, 24, 100, 100]        [32, 24, 100, 100]        True                      48
│    │    └─SiLU (2)                                         [32, 24, 100, 100]        [32, 24, 100, 100]        --                        --


In [9]:
# Show the pretrained head that is about to be replaced.
print("Original final dense layer: " + str(efficient_net.classifier[1]))

Original final dense layer: Linear(in_features=1280, out_features=1000, bias=True)


In [10]:
# Freeze the pretrained feature extractor so only the new classification head is trained.
efficient_net.features.requires_grad_(False)

# Swap the ImageNet head for a NUM_CLASSES-way layer; reuse the existing layer's
# input width instead of hard-coding it.
efficient_net.classifier[1] = nn.Linear(
    in_features=efficient_net.classifier[1].in_features,
    out_features=NUM_CLASSES,
    bias=True,
).to(DEVICE)

summary(model=efficient_net,
        input_size=(32, 3, 200, 200),
        col_names=["input_size", "output_size", "trainable", "num_params"],
        row_settings=["var_names"])

Layer (type (var_name))                                      Input Shape               Output Shape              Trainable                 Param #
EfficientNet (EfficientNet)                                  [32, 3, 200, 200]         [32, 9]                   Partial                   --
├─Sequential (features)                                      [32, 3, 200, 200]         [32, 1280, 7, 7]          False                     --
│    └─Conv2dNormActivation (0)                              [32, 3, 200, 200]         [32, 24, 100, 100]        False                     --
│    │    └─Conv2d (0)                                       [32, 3, 200, 200]         [32, 24, 100, 100]        False                     (648)
│    │    └─BatchNorm2d (1)                                  [32, 24, 100, 100]        [32, 24, 100, 100]        False                     (48)
│    │    └─SiLU (2)                                         [32, 24, 100, 100]        [32, 24, 100, 100]        --                       

In [None]:
# Collect every decodable image under data_path/<class_name>/ together with its class name.
file_paths = []
path_labels = []
classes = []
skipped_paths = []  # files that could not be decoded (truncated/corrupted/unsupported)

for class_name in sorted(os.listdir(data_path)):
    class_dir = os.path.join(data_path, class_name)
    if not os.path.isdir(class_dir):  # ignore stray files sitting next to the class folders
        continue
    classes.append(class_name)

    # sorted() makes the file order — and therefore the seeded train/test split — reproducible
    # regardless of the order the filesystem returns directory entries in.
    for image_name in sorted(os.listdir(class_dir)):
        full_path = os.path.join(class_dir, image_name)

        try:  # decode once up-front so unreadable files never reach the DataLoader
            torchvision.io.read_image(full_path, mode=torchvision.io.ImageReadMode.RGB)
        except Exception:
            skipped_paths.append(full_path)
            continue

        file_paths.append(full_path)
        path_labels.append(class_name)

print(f"Kept {len(file_paths)} images in {len(classes)} classes; "
      f"skipped {len(skipped_paths)} unreadable files.")
assert len(classes) == NUM_CLASSES, (
    f"Found {len(classes)} class folders but NUM_CLASSES = {NUM_CLASSES}"
)

In [12]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the images for evaluation; random_state pins the shuffle.
(X_train_paths, X_test_paths,
 y_train, y_test) = train_test_split(file_paths,
                                     path_labels,
                                     random_state=42,
                                     test_size=0.2)

In [13]:
class MushroomDataset(torch.utils.data.Dataset):
    """Image-classification dataset over a list of image files and their class names.

    Args:
        image_paths: paths of the image files to load.
        image_classes: class name of each image, aligned index-for-index with ``image_paths``.
        transform: callable applied to each decoded image tensor, or ``None`` to skip it.
            Defaults to the pretrained EfficientNet weights' ``transforms()``.
        classes: ordered class names; a class's position in this list is its integer label.

    Each item is ``(image, label)``: the RGB image read with ``torchvision.io.read_image``
    (after ``transform``) and the class index from ``class_to_index``.
    """

    # NOTE: the defaults are evaluated once, when the class is defined, so they capture the
    # global ``transforms`` and ``classes`` objects that exist at that moment.
    def __init__(self, image_paths, image_classes, transform=transforms, classes=classes):
        self.image_paths = image_paths
        self.image_classes = image_classes
        self.transform = transform  # honour the caller-supplied transform, not the global one
        self.classes = classes
        self.class_to_index = {name: index for index, name in enumerate(self.classes)}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.class_to_index[self.image_classes[idx]]

        image = torchvision.io.read_image(image_path, mode=torchvision.io.ImageReadMode.RGB)
        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [14]:
# Wrap each split in the Dataset defined above (default transform and class list).
train_dataset, test_dataset = (
    MushroomDataset(X_train_paths, y_train),
    MushroomDataset(X_test_paths, y_test),
)

In [15]:
BATCH_SIZE = 32
# os.cpu_count() returns None when the CPU count cannot be determined, which DataLoader
# rejects; fall back to loading in the main process (0 workers) in that case.
NUM_WORKERS = os.cpu_count() or 0

train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True,
                                               num_workers=NUM_WORKERS)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=BATCH_SIZE,
                                              num_workers=NUM_WORKERS)

In [16]:
# Training hyper-parameters, objective and optimizer.
epochs = 20
loss_fn = nn.CrossEntropyLoss()
# Parameters that never receive a gradient (the frozen backbone) are skipped by the step.
optimizer = torch.optim.Adam(params=efficient_net.parameters(), lr=1e-3)

In [17]:
try:
    import torchmetrics
    from torchmetrics import Accuracy
except ImportError:
    # Install into the interpreter running this kernel (what %pip does), in plain Python so
    # the cell is also valid outside IPython. Only ImportError triggers the install.
    import importlib
    import subprocess
    import sys

    subprocess.check_call([sys.executable, "-m", "pip", "install", "torchmetrics"])
    importlib.invalidate_caches()  # make the freshly installed package visible to import
    from torchmetrics import Accuracy

acc_fn = Accuracy(task="multiclass", num_classes=NUM_CLASSES).to(DEVICE)

Collecting torchmetrics
  Downloading torchmetrics-1.7.3-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Downloading torchmetrics-1.7.3-py3-none-any.whl (962 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m962.6/962.6 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading lightning_utilities-0.14.3-py3-none-any.whl (28 kB)
Installing collected packages: lightning-utilities, torchmetrics
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2/2[0m [torchmetrics][0m [torchmetrics]
[1A[2KSuccessfully installed lightning-utilities-0.14.3 torchmetrics-1.7.3


In [18]:
# Best test accuracy seen so far; starts below any attainable accuracy so the first
# evaluated epoch always writes a checkpoint.
max_test_acc = -1

In [19]:
# Set the random seeds
torch.manual_seed(42)
torch.cuda.manual_seed(42)

for epoch in range(epochs):
    efficient_net.train()
    train_loss = 0
    train_acc = 0
   
    for (X, y) in train_dataloader:
        # Forward pass
        X, y = X.to(DEVICE), y.to(DEVICE)
        logits = efficient_net(X)
       
        # Calculate loss
        loss = loss_fn(logits, y)
        train_loss += loss  # adding loss cumulatively
       
        # Calculate train accuracy
        train_acc += acc_fn(logits, y)
       
        # Optimizer zero grad
        optimizer.zero_grad()
       
        # Loss backward
        loss.backward()
       
        # Optimizer step
        optimizer.step()
       
    train_loss /= len(train_dataloader)  # divide by number of batches
    train_acc /= len(train_dataloader)  # divide by number of batches
   
   
    efficient_net.eval()
    with torch.inference_mode():
        test_loss = 0
        test_acc = 0
       
        for (X, y) in test_dataloader:
            # Forward pass
            X, y = X.to(DEVICE), y.to(DEVICE)
            test_logits = efficient_net(X)
           
            # Calculate test loss
            test_loss += loss_fn(test_logits, y)
           
            # Calculate test accuracy
            test_acc += acc_fn(test_logits, y)
           
           
        test_loss /= len(test_dataloader)
        test_acc /= len(test_dataloader)
       
    print(f"Train loss: {train_loss:.5f} | Train acc: {train_acc:.3f} | Test loss: {test_loss:.5f} | Test acc: {test_acc:.5f}")
   
    if test_acc > max_test_acc:
        max_test_acc = test_acc
       
        torch.save(efficient_net.state_dict(), "model.pt")
        with open("accuracy.txt", "w") as f:
            f.write(str(max_test_acc.cpu().numpy().item()))

Train loss: 1.45787 | Train acc: 0.536 | Test loss: 3.30161 | Test acc: 0.66794
Train loss: 1.03508 | Train acc: 0.666 | Test loss: 1.82583 | Test acc: 0.69844
Train loss: 0.90009 | Train acc: 0.704 | Test loss: 4.39971 | Test acc: 0.70442
Train loss: 0.84082 | Train acc: 0.722 | Test loss: 12.43953 | Test acc: 0.73123
Train loss: 0.79673 | Train acc: 0.738 | Test loss: 9.64139 | Test acc: 0.74311
Train loss: 0.76936 | Train acc: 0.741 | Test loss: 7.41293 | Test acc: 0.74093
Train loss: 0.74048 | Train acc: 0.745 | Test loss: 15.95483 | Test acc: 0.74165
Train loss: 0.71787 | Train acc: 0.754 | Test loss: 1.21613 | Test acc: 0.74088
Train loss: 0.71606 | Train acc: 0.751 | Test loss: 29.84927 | Test acc: 0.75432
Train loss: 0.68717 | Train acc: 0.767 | Test loss: 26.04865 | Test acc: 0.76697
Train loss: 0.67670 | Train acc: 0.773 | Test loss: 5.55097 | Test acc: 0.76620
Train loss: 0.65766 | Train acc: 0.771 | Test loss: 13.06879 | Test acc: 0.76692
Train loss: 0.66917 | Train acc: 0.

In [20]:
import json

# Save the training dataset's class-name -> label-index mapping as JSON.
with open("class_to_index.json", "w") as f:
    json.dump(obj=train_dataset.class_to_index, fp=f)