In [1]:
import os
import pandas as pd
import torch
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torchmetrics import Accuracy, Precision, Recall, F1Score, ConfusionMatrix
from torchvision import models, transforms
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR



In [2]:
# Dataset Class
class PreprocessedMushroomDataset(torch.utils.data.Dataset):
    def __init__(self, csv_file, root_dir, has_labels=True):
        self.annotations = pd.read_csv(csv_file, dtype={0: str})
        self.root_dir = root_dir
        self.has_labels = has_labels

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.annotations.iloc[idx, 0] + '.pt')
        image = torch.load(img_name) 
        label = int(self.annotations.iloc[idx, 1]) if self.has_labels else -1
        return image, label

In [3]:
class MushroomClassifier(pl.LightningModule):
    def __init__(self, num_classes, lr=1e-3, weight_decay=0.01):
        super().__init__()
        self.save_hyperparameters()
        
        # Define model
        self.model = models.alexnet(pretrained=False)
        num_ftrs = self.model.classifier[6].in_features
        self.model.classifier[6] = nn.Linear(num_ftrs, num_classes)

        # Define loss and metrics
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr
        self.weight_decay = weight_decay
        self.validation_outputs = []


    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        preds = torch.argmax(outputs, dim=1)

        # Log validation loss
        self.log("val_loss", loss, prog_bar=True)
        self.validation_outputs.append((preds, labels))  

    def on_validation_epoch_end(self):
        preds, labels = zip(*self.validation_outputs)
        preds = torch.cat(preds)
        labels = torch.cat(labels)

        # Compute accuracy manually
        accuracy = (preds == labels).float().mean()

        # Get the actual unique classes from the labels
        unique_classes = torch.unique(labels)
        class_names = [f"Class {i}" for i in unique_classes.tolist()] 

        # Compute metrics
        metrics = classification_report(
            labels.cpu(), preds.cpu(), target_names=class_names, output_dict=True
        )
        cm = confusion_matrix(labels.cpu(), preds.cpu())

        # Log other metrics
        self.log_dict({
            "val_accuracy": accuracy,  # Log the computed accuracy
            "val_macro_precision": metrics["macro avg"]["precision"],
            "val_macro_recall": metrics["macro avg"]["recall"],
            "val_macro_f1": metrics["macro avg"]["f1-score"],
        })

        # Print confusion matrix and metrics
        print("Confusion Matrix:")
        print(cm)
        print("Classification Report:")
        print(classification_report(labels.cpu(), preds.cpu(), target_names=class_names))

        # Clear validation outputs
        self.validation_outputs.clear()

    def configure_optimizers(self):
        optimizer = AdamW(
            self.parameters(), 
            lr=self.lr,  # Learning rate
            weight_decay=0.01  # Weight decay
        )
        
        # Warm-up and scheduling with OneCycleLR
        scheduler = {
            'scheduler': OneCycleLR(
                optimizer, 
                max_lr=self.lr,  # Peak learning rate
                total_steps=self.trainer.estimated_stepping_batches,
                pct_start=0.1,  # Warm-up phase percentage
                anneal_strategy='cos',  # Cosine annealing after warm-up
                div_factor=25.0  # Initial LR = max_lr/div_factor
            ),
            'interval': 'step',  # Apply the scheduler at every step
            'frequency': 1  # Scheduler is updated every step
        }
        
        return [optimizer], [scheduler]



In [4]:
# Paths
root_path = os.path.dirname(os.getcwd())
models_path = os.path.join(root_path, 'models')
dataset_path = os.path.join(root_path, 'dataset')

In [5]:
# Dataset
train_csv_path = os.path.join(dataset_path, 'csv_mappings/train.csv')
test_csv_path = os.path.join(dataset_path, 'csv_mappings/test.csv')
preprocessed_train_path = os.path.join(dataset_path, 'preprocessed/train')
preprocessed_test_path = os.path.join(dataset_path, 'preprocessed/test')

In [6]:
train_dataset = PreprocessedMushroomDataset(csv_file=train_csv_path, root_dir=preprocessed_train_path, has_labels=True)
test_dataset = PreprocessedMushroomDataset(csv_file=test_csv_path, root_dir=preprocessed_test_path, has_labels=False)

In [7]:
# Split Dataset
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_subset, val_subset = random_split(train_dataset, [train_size, val_size])

In [8]:
# Dataloaders
train_dataloader = DataLoader(train_subset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_subset, batch_size=8, shuffle=False)

In [9]:
# Training parameters
num_classes = len(train_dataset.annotations.iloc[:, 1].unique())
learning_rate = 0.001
weight_decay = 0.01

# Model
model = MushroomClassifier(num_classes=num_classes, lr=learning_rate, weight_decay=weight_decay)



In [10]:
# Callbacks
checkpoint_callback = ModelCheckpoint(
    dirpath=models_path, 
    filename='mushroom_model_{epoch:02d}', 
    save_top_k=1, 
    monitor='val_accuracy',  # Monitor val_accuracy instead of val_acc
    mode='max'
)

In [11]:
early_stopping_callback = EarlyStopping(
    monitor='val_accuracy',  # Monitor val_accuracy instead of val_acc
    patience=5, 
    mode='max'
)

In [12]:
trainer = Trainer(
    max_epochs=1, 
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    callbacks=[checkpoint_callback, early_stopping_callback],
    deterministic=True  
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [13]:
# Train
trainer.fit(model, train_dataloader, val_dataloader)

c:\Users\ilian\Documents\Projects\git_projects\university\mushroom_classification\.venv\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:654: Checkpoint directory C:\Users\ilian\Documents\Projects\git_projects\university\mushroom_classification\models exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
c:\Users\ilian\Documents\Projects\git_projects\university\mushroom_classification\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=21` in the `DataLoader` to improve performance.

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | AlexNet          | 57.0 M | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
57.0 M  

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\ilian\Documents\Projects\git_projects\university\mushroom_classification\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=21` in the `DataLoader` to improve performance.
  image = torch.load(img_name)


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  4.40it/s]Confusion Matrix:
[[0 0 0 0 0 0 0 1]
 [0 0 0 0 0 0 0 2]
 [0 0 0 0 0 0 0 2]
 [0 0 0 0 0 0 0 3]
 [0 0 0 0 0 0 0 1]
 [0 0 0 0 0 0 0 3]
 [0 0 0 0 0 0 0 2]
 [0 0 0 0 0 0 0 2]]
Classification Report:
              precision    recall  f1-score   support

     Class 0       0.00      0.00      0.00         1
     Class 1       0.00      0.00      0.00         2
     Class 2       0.00      0.00      0.00         2
     Class 4       0.00      0.00      0.00         3
     Class 6       0.00      0.00      0.00         1
     Class 7       0.00      0.00      0.00         3
     Class 8       0.00      0.00      0.00         2
     Class 9       0.12      1.00      0.22         2

    accuracy                           0.12        16
   macro avg       0.02      0.12      0.03        16
weighted avg       0.02      0.12      0.03        16

                                                                           

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 0:   0%|          | 0/237 [00:00<?, ?it/s] 

  image = torch.load(img_name)


Epoch 0: 100%|██████████| 237/237 [01:17<00:00,  3.05it/s, v_num=9]Confusion Matrix:
[[ 0  0 30  0  0 14  0  0  0  1]
 [ 0  0 28  0  0 13  0  0  0  0]
 [ 0  0 33  0  0  9  0  0  0  1]
 [ 0  0 30  0  0 15  0  0  0  0]
 [ 0  0 38  0  0 18  0  0  0  1]
 [ 0  0 34  0  0 14  0  0  0  1]
 [ 0  0 33  0  0 15  0  0  0  1]
 [ 0  0 32  0  0 12  0  0  0  1]
 [ 0  0 31  0  0 22  0  0  0  0]
 [ 0  0 33  0  0 13  0  0  0  0]]
Classification Report:
              precision    recall  f1-score   support

     Class 0       0.00      0.00      0.00        45
     Class 1       0.00      0.00      0.00        41
     Class 2       0.10      0.77      0.18        43
     Class 3       0.00      0.00      0.00        45
     Class 4       0.00      0.00      0.00        57
     Class 5       0.10      0.29      0.14        49
     Class 6       0.00      0.00      0.00        49
     Class 7       0.00      0.00      0.00        45
     Class 8       0.00      0.00      0.00        53
     Class 9       0

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 237/237 [01:27<00:00,  2.71it/s, v_num=9, val_loss=2.300]


In [14]:
# Global Evaluation
all_preds = []
all_labels = []
model.eval()
with torch.no_grad():
    for images, labels in DataLoader(val_subset, batch_size=8):
        outputs = model(images)
        preds = torch.argmax(outputs, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

class_names = train_dataset.annotations.iloc[:, 1].unique().tolist()
report = classification_report(all_labels, all_preds, target_names=class_names, output_dict=True)
print("Classification Report:\n", classification_report(all_labels, all_preds, target_names=class_names))


  image = torch.load(img_name)
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


TypeError: object of type 'int' has no len()

In [None]:
import matplotlib.pyplot as plt

optimizer = AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
scheduler = OneCycleLR(
    optimizer,
    max_lr=0.001,
    total_steps=1000,  # Simulate 1000 training steps
    pct_start=0.1,
    anneal_strategy='cos',
    div_factor=25.0
)

lrs = []
for step in range(1000):
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])

plt.plot(lrs)
plt.title("Learning Rate Schedule with Warm-Up")
plt.xlabel("Training Step")
plt.ylabel("Learning Rate")
plt.show()


In [None]:
# Per-Class Metrics
print("Per-Class Metrics:")
print(f"{'Class':<15}{'Precision':<10}{'Recall':<10}{'F1-Score':<10}{'Support':<10}")
print("-" * 55)
for class_name, metrics in report.items():
    if class_name in class_names:
        print(f"{class_name:<15}{metrics['precision']:<10.2f}{metrics['recall']:<10.2f}{metrics['f1-score']:<10.2f}{metrics['support']:<10}")


In [None]:
# After collecting predictions and labels from validation or test phase
all_preds = torch.cat(all_preds)  
all_labels = torch.cat(all_labels) 

# Compute global confusion matrix
cm = confusion_matrix(all_labels.cpu(), all_preds.cpu())

# Function to plot confusion matrix per class
def plot_per_class_confusion_matrix(cm, class_names):
    num_classes = cm.shape[0]
    
    # Iterate over each class
    for i in range(num_classes):
        # Extract the row for class `i`
        tp = cm[i, i]
        fp = cm[:, i].sum() - tp
        fn = cm[i, :].sum() - tp
        tn = cm.sum() - (tp + fp + fn)

        # Print per-class confusion matrix details
        print(f"Confusion Matrix for Class {i} ({class_names[i]}):")
        print(f"True Positive (TP): {tp}")
        print(f"False Positive (FP): {fp}")
        print(f"False Negative (FN): {fn}")
        print(f"True Negative (TN): {tn}")
        print("-" * 50)

        # Plot confusion matrix for class `i`
        per_class_cm = np.array([[tp, fp], [fn, tn]])

        plt.figure(figsize=(5, 5))
        sns.heatmap(per_class_cm, annot=True, fmt="d", cmap="Blues", xticklabels=["Pred Class", "Other Class"], yticklabels=["True Class", "Other Class"])
        plt.title(f"Confusion Matrix for Class {i} ({class_names[i]})")
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.show()

# Assuming `class_names` is a list of the class names
class_names = [f"Class {i}" for i in range(len(torch.unique(all_labels)))]

# Plot per-class confusion matrices
plot_per_class_confusion_matrix(cm, class_names)
