# Unzip defungi.zip

In [4]:
%%capture
import os
if not os.path.isdir('data'): 
    !mkdir data
    !tar -xvzf defungi.zip -C data

# Necessary Imports

In [7]:
import os
import torch
from torch.utils.data import random_split

# Our code
from src.dataset import FungiDataset

# Constants

In [9]:
CLASSES = {
    'H1': 0,
    'H2': 1,
    'H3': 2,
    'H5': 3,
    'H6': 4
}

SEED = torch.Generator().manual_seed(42)


# Get all valid images

In [10]:
def load_images_from_folder(folder):
    images = []
    for root, _, files in os.walk(folder):
        if not str(root).startswith("data/H"):
            continue
        for file in files:
            c = file.split('_')[0]
            images.append((root+'/'+file, CLASSES[c]))
    return images

images = load_images_from_folder('data/')

assert(len(images) == 9114)

files, labels = zip(*images)


# Instantiate Torch Dataset

In [16]:
dataset = FungiDataset(files=files, labels=labels)

train, test = random_split(dataset=dataset, lengths=[0.7, 0.3], generator=SEED)

# train_loader = DataLoader(train, batch_size=4)
# test_loader = DataLoader(test, batch_size=4)

assert(len(train) == 6380)
assert(len(test) == 2734)

# Save to machine for use in other notebooks

In [None]:
torch.save(train, './train.pt')
torch.save(test, './test.pt')

# Sanity check, reload those files and compare side by side

In [11]:
train_reload = torch.load('./train.pt')
test_reload = torch.load('./test.pt')

assert(len(train_reload) == 6380)
assert(len(test_reload) == 2734)

  train_reload = torch.load('./train.pt')
  test_reload = torch.load('./test.pt')


In [13]:
import torchvision.transforms as T

transform = T.ToPILImage()

sample_image_name = train_reload[0]['file']
print(sample_image_name)
sample_image = transform(train_reload[0]['image'])
sample_image.show()

data/H3/H3_5c_9.jpg


In [14]:
from PIL import Image

Image.open(sample_image_name).show()

In [26]:
from torch.utils.data import DataLoader

batch_size = 16
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)


In [3]:
import torch.nn as nn
import torch.nn.functional as F

class FungiClassifier(nn.Module):
    def __init__(self, num_classes=5):
        super(FungiClassifier, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        
        # Dynamically determine the input size for fc1
        dummy_input = torch.zeros(1, 3, 500, 500)  # Example input size
        with torch.no_grad():
            out = self.pool(F.relu(self.conv1(dummy_input)))
            out = self.pool(F.relu(self.conv2(out)))
        flattened_size = out.numel()  # Compute number of features
        self.fc1 = nn.Linear(flattened_size, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # Flatten
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [25]:
# Load a specific checkpoint
def load_checkpoint(checkpoint_path, model, optimizer=None):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Loaded checkpoint from {checkpoint_path}, starting at epoch {start_epoch}, loss: {loss:.4f}")
    return start_epoch

In [None]:
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FungiClassifier(num_classes=len(CLASSES)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

from tqdm import tqdm
import os

# Directory to save checkpoints
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
# Evaluation function
def evaluate_model(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in data_loader:
            inputs = batch['image'].to(device, dtype=torch.float32) / 255.0
            labels = batch['label'].to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

checkpoint_path = "checkpoints/checkpoint_epoch_2.pth"  # Replace with the desired checkpoint file
start_epoch = load_checkpoint(checkpoint_path, model, optimizer)

# Resume training from the loaded epoch
num_epochs = 10
for epoch in range(start_epoch, num_epochs):
    model.train()
    running_loss = 0.0

    # Wrap the DataLoader in tqdm
    with tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as tepoch:
        for batch in tepoch:
            inputs = batch['image'].to(device, dtype=torch.float32) / 255.0
            labels = batch['label'].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Update progress bar with running loss
            tepoch.set_postfix(loss=loss.item())

    # Calculate average loss for the epoch
    avg_loss = running_loss / len(train_loader)

    # Evaluate on test set for accuracy
    accuracy = evaluate_model(model, test_loader)

    # Print epoch summary
    print(f"Epoch {epoch+1}/{num_epochs} completed. Loss: {avg_loss:.4f}, Accuracy: {accuracy*100:.2f}%")

    # Save checkpoint
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved: {checkpoint_path}")


  checkpoint = torch.load(checkpoint_path)


Loaded checkpoint from checkpoints/checkpoint_epoch_2.pth, starting at epoch 2, loss: 1.0655


Epoch 3/10: 100%|██████████| 399/399 [09:23<00:00,  1.41s/batch, loss=0.773]


Epoch 3/10 completed. Loss: 1.0073, Accuracy: 60.86%
Checkpoint saved: checkpoints\checkpoint_epoch_3.pth


Epoch 4/10: 100%|██████████| 399/399 [09:13<00:00,  1.39s/batch, loss=0.846]


Epoch 4/10 completed. Loss: 0.9400, Accuracy: 64.16%
Checkpoint saved: checkpoints\checkpoint_epoch_4.pth


Epoch 5/10: 100%|██████████| 399/399 [09:11<00:00,  1.38s/batch, loss=1.92] 


Epoch 5/10 completed. Loss: 0.9200, Accuracy: 62.69%
Checkpoint saved: checkpoints\checkpoint_epoch_5.pth


Epoch 6/10: 100%|██████████| 399/399 [09:09<00:00,  1.38s/batch, loss=0.642]


Epoch 6/10 completed. Loss: 0.8615, Accuracy: 63.61%
Checkpoint saved: checkpoints\checkpoint_epoch_6.pth


Epoch 7/10: 100%|██████████| 399/399 [09:04<00:00,  1.36s/batch, loss=0.84] 


Epoch 7/10 completed. Loss: 0.8122, Accuracy: 61.70%
Checkpoint saved: checkpoints\checkpoint_epoch_7.pth


Epoch 8/10: 100%|██████████| 399/399 [09:11<00:00,  1.38s/batch, loss=1.01] 


Epoch 8/10 completed. Loss: 0.7425, Accuracy: 62.84%
Checkpoint saved: checkpoints\checkpoint_epoch_8.pth


Epoch 9/10: 100%|██████████| 399/399 [09:09<00:00,  1.38s/batch, loss=0.599]


Epoch 9/10 completed. Loss: 0.6690, Accuracy: 60.10%
Checkpoint saved: checkpoints\checkpoint_epoch_9.pth


Epoch 10/10: 100%|██████████| 399/399 [09:04<00:00,  1.36s/batch, loss=0.771]


Epoch 10/10 completed. Loss: 0.6229, Accuracy: 62.40%
Checkpoint saved: checkpoints\checkpoint_epoch_10.pth


In [27]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in test_loader:
        inputs = batch['image'].to(device, dtype=torch.float32) / 255.0
        labels = batch['label'].to(device)

        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy: {100 * correct / total:.2f}%")

Accuracy: 62.40%


# SAVE MODEL:

In [21]:
torch.save(model.state_dict(), 'fungi_classifier.pth')

# Reload Model:

In [23]:
model = FungiClassifier(num_classes=len(CLASSES))  # Initialize the model
model.load_state_dict(torch.load("fungi_classifier.pth"))  # Load the state dictionary
model = model.to(device)  # Move to the same device used for training
model.eval()  # Set to evaluation mode for inference

  model.load_state_dict(torch.load("fungi_classifier.pth"))  # Load the state dictionary


FungiClassifier(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout): Dropout(p=0.25, inplace=False)
  (fc1): Linear(in_features=500000, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=5, bias=True)
)