##### Imports

In [None]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision.transforms import functional as F
from torch.utils.data import random_split
import random
from torchvision.transforms import ToPILImage

##### Dataset class (for preprocessed tensors)

In [None]:
class PreprocessedMushroomDataset(Dataset):
    """Dataset of samples stored as pre-processed ``.pt`` files.

    The first CSV column holds the file stem of each sample (read as a
    string) and, when ``has_labels`` is true, the second column holds the
    integer class label.

    Args:
        csv_file: Path to the CSV mapping file.
        root_dir: Directory containing one ``<stem>.pt`` file per CSV row.
        has_labels: Whether the CSV carries a label in its second column.
            When false, every sample is returned with the placeholder
            label ``-1``.
    """

    def __init__(self, csv_file, root_dir, has_labels=True):
        # Column 0 is forced to str so the stem can be concatenated with '.pt'.
        self.annotations = pd.read_csv(csv_file, dtype={0: str})
        self.root_dir = root_dir
        self.has_labels = has_labels

    def __len__(self):
        """Return the number of rows in the CSV mapping."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the CSV row at position ``idx``."""
        img_name = os.path.join(self.root_dir, self.annotations.iloc[idx, 0] + '.pt')
        # Always materialise the tensor on the CPU: a file saved from a GPU
        # would otherwise fail to load on a CPU-only machine. Callers move
        # batches to the target device themselves.
        image = torch.load(img_name, map_location='cpu')
        label = int(self.annotations.iloc[idx, 1]) if self.has_labels else -1
        return image, label

##### Define paths

In [None]:
# Directory layout, resolved relative to the parent of the current working
# directory.
root_path = os.path.dirname(os.getcwd())
dataset_path = os.path.join(root_path, 'dataset')
models_path = os.path.join(root_path, 'models')

# Pre-processed data, split into train and test folders.
dataset_preprocessed_path = os.path.join(dataset_path, 'preprocessed')
preprocessed_train_path, preprocessed_test_path = (
    os.path.join(dataset_preprocessed_path, split) for split in ('train', 'test')
)

# CSV mapping files for the two splits.
csv_path = os.path.join(dataset_path, 'csv_mappings')
train_csv_path = os.path.join(csv_path, 'train.csv')
test_csv_path = os.path.join(csv_path, 'test.csv')


##### Load datasets

In [None]:
# The training split is loaded with labels; the test split is loaded without
# them (has_labels=False).
train_dataset = PreprocessedMushroomDataset(
    csv_file=train_csv_path,
    root_dir=preprocessed_train_path,
    has_labels=True,
)
test_dataset = PreprocessedMushroomDataset(
    csv_file=test_csv_path,
    root_dir=preprocessed_test_path,
    has_labels=False,
)

##### Split dataset into training and validation

In [None]:
# 80% of the labelled data is used for training; the remainder (which absorbs
# any rounding) is held out for validation.
n_labelled = len(train_dataset)
train_size = int(0.8 * n_labelled)
val_size = n_labelled - train_size
train_subset, val_subset = random_split(
    train_dataset,
    [train_size, val_size],
)


##### Define Dataloaders

In [None]:
# Only the training loader shuffles; validation and test keep dataset order.
train_dataloader = DataLoader(
    train_subset,
    batch_size=8,
    shuffle=True,
)
val_dataloader = DataLoader(
    val_subset,
    batch_size=8,
    shuffle=False,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
)

##### Load model

In [None]:
# AlexNet without pretrained weights; its last classifier layer is replaced by
# one with an output per distinct value in the training CSV's 'Mushroom' column.
num_classes = len(train_dataset.annotations['Mushroom'].unique())
model = models.alexnet(pretrained=False)
num_ftrs = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_ftrs, num_classes)

In [None]:
# Show the network's layer structure.
print(model)

##### Params

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = model.to(device)

# Optimisation settings.
num_epochs = 5
lr = 0.0001
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

##### Train model

In [None]:
# Per-epoch metric history.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

for epoch in range(num_epochs):
    # Training
    model.train()
    running_train_loss = 0.0
    correct_train = 0
    total_train = 0

    for images, labels in train_dataloader:
        images, labels = images.to(device), labels.to(device)
        batch_size = labels.size(0)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The loss is a per-batch value; weight it by the batch size so that
        # a smaller final batch does not skew the epoch average.
        running_train_loss += loss.item() * batch_size
        predicted = outputs.argmax(dim=1)
        total_train += batch_size
        correct_train += (predicted == labels).sum().item()

    train_loss = running_train_loss / total_train
    train_accuracy = 100 * correct_train / total_train
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    # Validation (no gradient tracking, model in eval mode)
    model.eval()
    running_val_loss = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for images, labels in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            batch_size = labels.size(0)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Same sample-weighted accumulation as in the training pass.
            running_val_loss += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            total_val += batch_size
            correct_val += (predicted == labels).sum().item()

    val_loss = running_val_loss / total_val
    val_accuracy = 100 * correct_val / total_val
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

print('Training finished.')


##### Save trained model

In [None]:
# torch.save does not create missing directories, so make sure the target
# folder exists before writing the weights.
os.makedirs(models_path, exist_ok=True)
model_save_path = os.path.join(models_path, 'alexnet_model.pth')
torch.save(model.state_dict(), model_save_path)
print(f'Model saved to {model_save_path}')

##### Evaluation

In [None]:
# Predict a class for every test sample and pair it with the name found in
# the matching CSV row.
model.eval()
test_predictions = []

with torch.no_grad():
    for idx, (images, labels) in enumerate(test_dataloader):
        images = images.to(device)
        outputs = model(images)
        predicted = torch.max(outputs.data, 1)[1]

        # Relies on test_dataloader yielding samples in CSV order: batch
        # number `idx` then starts at this row of the annotations table.
        first_row = idx * test_dataloader.batch_size
        batch_names = test_dataset.annotations.iloc[first_row:first_row + images.size(0), 0]
        test_predictions.extend(zip(batch_names, predicted.tolist()))

for img_name, pred in test_predictions:
    print(f'Image: {img_name}, Predicted Label: {pred}')


In [None]:
def denormalize(image, mean, std):
    """Return a copy of ``image`` with per-channel normalisation undone.

    Each slice along the first dimension is multiplied by its ``std`` entry
    and then shifted by its ``mean`` entry. The input tensor is left
    untouched; iteration stops at the shortest of the three arguments.
    """
    restored = image.clone()
    for channel, (channel_mean, channel_std) in zip(restored, zip(mean, std)):
        # `channel` is a view into `restored`, so the in-place ops update it.
        channel.mul_(channel_std)
        channel.add_(channel_mean)
    return restored


In [None]:
def show_sample_predictions(dataset, predictions, num_samples=10):
    """Plot a random selection of images together with their predicted labels.

    Args:
        dataset: Dataset the predictions were made on. Its ``root_dir``
            attribute, when present, is used to locate the ``.pt`` files;
            otherwise the module-level ``preprocessed_test_path`` is used.
        predictions: Sequence of ``(image_name, predicted_label)`` pairs.
        num_samples: Maximum number of images to show. Fewer are shown when
            ``predictions`` is shorter; nothing is plotted when it is empty.
    """
    # Assumed to match the normalisation applied during preprocessing —
    # TODO confirm against the preprocessing step.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # random.sample raises ValueError when asked for more items than exist.
    num_samples = min(num_samples, len(predictions))
    if num_samples <= 0:
        return

    image_dir = getattr(dataset, 'root_dir', preprocessed_test_path)
    indices = random.sample(range(len(predictions)), num_samples)

    # squeeze=False always yields a 2-D array of Axes, so a single sample
    # is handled the same way as several.
    fig, axes = plt.subplots(1, num_samples, figsize=(20, 4), squeeze=False)
    for ax, sample_idx in zip(axes.flat, indices):
        img_name, pred = predictions[sample_idx]
        img_path = os.path.join(image_dir, img_name + '.pt')
        image = torch.load(img_path).cpu()
        # Denormalised values can land outside [0, 1]; clamp them so the
        # conversion to an 8-bit image does not wrap around.
        image = denormalize(image, mean, std).clamp(0, 1)
        image = ToPILImage()(image)

        ax.imshow(image)
        ax.set_title(f'Predicted: {pred}')
        ax.axis('off')
    plt.show()

show_sample_predictions(test_dataset, test_predictions, num_samples=10)


In [None]:
# Accuracy of the trained model on the validation loader.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in val_dataloader:  # swap in test_dataloader to run on unlabelled data
        images = images.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)

        # Unlabelled samples carry the placeholder label -1. Score only the
        # labelled ones, checking every sample rather than just the first.
        labels = labels.to(device)
        labelled = labels != -1
        total += labelled.sum().item()
        correct += (predicted[labelled] == labels[labelled]).sum().item()

if total > 0:
    accuracy = 100 * correct / total
    # The loop above scores val_dataloader, so report it as validation accuracy.
    print(f'Accuracy on the validation dataset: {accuracy:.2f}%')
else:
    print('No labels available for evaluation.')
