In [1]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import timm
import torch
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import Subset
from torch.utils.data import DataLoader
from tqdm import tqdm
# Work from the project root: later cells import `src.dataset2` and read data
# stored under this directory. The location can be overridden through the
# FUNGI_PROJECT_ROOT environment variable; without it the original checkout
# path is used, so existing runs behave exactly as before.
PROJECT_ROOT = os.environ.get(
    "FUNGI_PROJECT_ROOT",
    "/Users/czimbermark/Documents/Egyetem/Adatelemzes/Nagyhazi/FungiCLEF2024_ADC/",
)
os.chdir(PROJECT_ROOT)
print(os.getcwd())


  from .autonotebook import tqdm as notebook_tqdm


/Users/czimbermark/Documents/Egyetem/Adatelemzes/Nagyhazi/FungiCLEF2024_ADC


In [2]:
from src.dataset2 import FungiDataset
from src.dataset2 import fungi_collate_fn

In [None]:
# Configuration
# Data locations are resolved against the current working directory, which the
# first cell sets to the project root, instead of repeating the absolute path
# of one particular machine in every entry.
project_root = os.getcwd()
config = {
    "image_dir": os.path.join(project_root, "data", "x_train"),
    "labels_path": os.path.join(project_root, "data", "train_metadata_height.csv"),
    "pre_load": True,
    "batch_size": 32,
    "crop_height": 16,
    "interpolate": "bilinear",
    "out_size": (224, 224)
}

# Species class IDs kept for this experiment (100 unique IDs)
class_ids_to_include = [
    4, 11, 16, 25, 30, 32, 37, 39, 43, 63,
    100, 103, 128, 129, 131, 136, 142, 168, 180, 213,
    214, 223, 252, 266, 309, 366, 389, 413, 473, 478,
    487, 522, 555, 559, 591, 633, 637, 657, 671, 673,
    689, 694, 724, 728, 738, 748, 764, 787, 812, 814,
    830, 837, 845, 856, 884, 908, 909, 912, 967, 975,
    989, 992, 1000, 1005, 1014, 1020, 1052, 1054, 1088, 1093,
    1115, 1121, 1135, 1136, 1141, 1160, 1183, 1207, 1214, 1220,
    1221, 1232, 1239, 1242, 1290, 1302, 1355, 1381, 1395, 1420,
    1438, 1440, 1481, 1484, 1493, 1533, 1537, 1546, 1573, 1603,
]

# Build the full dataset: every dataset option is forwarded from `config`,
# no transform is applied, and only the class IDs listed above are requested.
full_dataset = FungiDataset(
    **{
        option: config[option]
        for option in (
            "image_dir",
            "labels_path",
            "pre_load",
            "crop_height",
            "interpolate",
            "out_size",
        )
    },
    transform=None,
    class_ids_to_include=class_ids_to_include,
)

print(f"Number of samples in dataset: {len(full_dataset)}")

# 80/20 train/validation split over sample indices, seeded for repeatability
train_indices, val_indices = train_test_split(
    list(range(len(full_dataset))),
    test_size=0.2,
    random_state=42,
)

train_dataset, val_dataset = (
    Subset(full_dataset, train_indices),
    Subset(full_dataset, val_indices),
)

# DataLoaders
# Settings shared by the training and validation loaders.
loader_kwargs = dict(
    batch_size=config["batch_size"],
    num_workers=1,
    # Number of batches each worker keeps queued ahead of the training loop.
    # It has to be positive for a worker-backed loader: the previous value of 0
    # contradicted the intent to prefetch and leaves the worker with nothing
    # queued. 2 is PyTorch's default whenever num_workers > 0.
    prefetch_factor=2,
    # Pinned host memory only speeds up host-to-GPU copies, so request it only
    # when CUDA is actually present (avoids a warning on CPU-only machines).
    pin_memory=torch.cuda.is_available(),
    collate_fn=fungi_collate_fn,
    # Drops the last incomplete batch.
    # NOTE(review): this also applies to validation, so up to batch_size - 1
    # validation samples are never scored -- confirm that this is intended.
    drop_last=True,
)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"Length of train_loader: {len(train_loader)} and val_loader: {len(val_loader)}")

Length of train_loader: 644 and val_loader: 161


In [4]:
import torch.optim as optim

# Sizes of the two label spaces
num_species_classes = full_dataset.num_species_classes  # reported by the dataset
print(f"Number of species classes: {num_species_classes}")  # Should print 100
num_toxicity_classes = 2   # toxicity is assumed to be binary (edible or poisonous)

# Run on the GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained CAFormer backbone from timm; its classifier is sized for the
# species task and the whole network is moved to the selected device.
base_model = timm.create_model(
    "caformer_s18.sail_in22k", pretrained=True, num_classes=num_species_classes
).to(device)

class MultiTaskModel(nn.Module):
    """
    Two-headed classifier sharing one backbone.

    Species logits come from the backbone's own classification head; toxicity
    logits come from an additional linear layer applied to spatially averaged
    backbone features.

    Args:
        base_model: Backbone exposing `reset_classifier`, `num_features`,
            `forward_features` and `forward_head` (the notebook passes a timm model).
        num_species_classes: Output size of the species head.
        num_toxicity_classes: Output size of the toxicity head.
    """

    def __init__(self, base_model, num_species_classes, num_toxicity_classes):
        super().__init__()
        self.base_model = base_model

        # Species head: the backbone rebuilds its classifier for the requested size
        self.base_model.reset_classifier(num_species_classes)

        # Toxicity head: one linear layer over `num_features` pooled channels
        self.fc_toxicity = nn.Linear(base_model.num_features, num_toxicity_classes)

        # Spatial average pooling that turns the feature map into a vector
        self.global_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Return a dict with 'species' and 'toxicity' logits for the batch `x`."""
        # Shared backbone features; assumed [batch_size, channels, height, width]
        feature_map = self.base_model.forward_features(x)

        # Species branch goes through the backbone's own head
        species_out = self.base_model.forward_head(feature_map)

        # Toxicity branch: average over the spatial dims, flatten, then classify
        pooled = self.global_pool(feature_map).flatten(1)
        toxicity_out = self.fc_toxicity(pooled)

        return {'species': species_out, 'toxicity': toxicity_out}

# Wrap the backbone in the two-headed model and place it on the selected device
model = MultiTaskModel(base_model, num_species_classes, num_toxicity_classes).to(device)

# Adam over every model parameter; adjust the learning rate as needed
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# One cross-entropy criterion per task
criterion_species, criterion_toxicity = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()

print(model)

Number of species classes: 100
MultiTaskModel(
  (base_model): MetaFormer(
    (stem): Stem(
      (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
      (norm): LayerNorm2dNoBias((64,), eps=1e-06, elementwise_affine=True)
    )
    (stages): Sequential(
      (0): MetaFormerStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): MetaFormerBlock(
            (norm1): LayerNorm2dNoBias((64,), eps=1e-06, elementwise_affine=True)
            (token_mixer): SepConv(
              (pwconv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (act1): StarReLU(
                (relu): ReLU()
              )
              (dwconv): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128, bias=False)
              (act2): Identity()
              (pwconv2): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            )
            (drop_path1): Identity()
            (layer_scale1

In [5]:
# Get a batch of data
images, (class_ids, toxicities), _ = next(iter(train_loader))

# Check the range of class_ids
print(f"Class IDs in batch: {class_ids}")
print(f"Min class ID: {class_ids.min()}, Max class ID: {class_ids.max()}")

# Ensure class IDs are in the range [0, num_species_classes - 1]
assert class_ids.min() >= 0 and class_ids.max() < num_species_classes, (
    f"species labels must lie in [0, {num_species_classes - 1}], "
    f"got [{class_ids.min()}, {class_ids.max()}]"
)

# The toxicity labels feed a cross-entropy loss as well, so they need the same
# guarantee: every value must index one of the num_toxicity_classes outputs.
assert toxicities.min() >= 0 and toxicities.max() < num_toxicity_classes, (
    f"toxicity labels must lie in [0, {num_toxicity_classes - 1}], "
    f"got [{toxicities.min()}, {toxicities.max()}]"
)

Class IDs in batch: tensor([17, 47, 69,  6, 81, 23, 98, 91, 93, 50, 89, 85, 56, 91, 34, 78, 12,  5,
        41, 88, 80, 47, 12, 33, 41, 44, 53, 86, 63, 71, 39, 40])
Min class ID: 5, Max class ID: 98


In [None]:
def conf_matrix(targets, predictions, task_name, matrix_type='species', colors=None):
    """
    Compute a confusion matrix and plot it with one fixed color per cell.

    Cells are colored by their position (their meaning), never by their count,
    so a given outcome always has the same color regardless of how many
    samples fall into it.

    Args:
        targets: True labels. Both plotted layouts expect binary labels (0/1).
        predictions: Predicted labels, aligned with `targets`.
        task_name: Name of the task, used in the plot title (e.g., 'Species', 'Toxicity').
        matrix_type: 'species' plots only the first matrix row (true label 0) as a
            two-cell strip labelled misclassified / recognized; 'toxicity' plots
            the full 2x2 matrix. Any other value computes the matrix without plotting.
        colors: Optional list of color names. For 'species': one color per column,
            misclassified first (default ['red', 'green']). For 'toxicity': one
            color per cell in TN, FP, FN, TP order
            (default ['black', 'yellow', 'purple', 'black']).

    Returns:
        The confusion matrix computed by scikit-learn. For 'species' and
        'toxicity' it is always 2x2, even when one label never occurs.
    """
    # Imported here so the helper works on its own, independent of other cells.
    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    import numpy as np

    # Compute confusion matrix
    if matrix_type in ('species', 'toxicity'):
        # Pin the label set: without it the matrix shrinks to 1x1 whenever only
        # one label occurs (e.g. every species prediction is wrong), which no
        # longer matches the two tick labels / the 2x2 annotation grid below.
        cm = confusion_matrix(targets, predictions, labels=[0, 1])
    else:
        cm = confusion_matrix(targets, predictions)

    if matrix_type == 'species':
        # For Species Misclassification Matrix
        plt.figure(figsize=(6, 1))  # Wide and short, single row
        sns.set(font_scale=1.4)

        # Red for misclassified (first column), green for recognized (second column)
        palette = list(colors) if colors else ['red', 'green']

        # The heatmap data holds column indices, not counts, so the color of a
        # cell depends only on which column it is; the counts are shown through
        # `annot`. Coloring by count would swap the two colors whenever
        # misclassifications outnumber correct predictions.
        sns.heatmap(
            np.array([[0, 1]]),
            annot=cm[:1],  # Only show the first row
            fmt='d',
            cmap=ListedColormap(palette),
            vmin=0,
            vmax=1,
            cbar=False,
            xticklabels=['Missclassified', 'Recognized'],  # Bottom labels
            yticklabels=[]  # Remove true labels from the side
        )
        plt.xlabel('Predicted')
        plt.title(f'{task_name} Misclassification Matrix')
        plt.show()

    elif matrix_type == 'toxicity':
        # For Toxicity Matrix
        plt.figure(figsize=(6, 5))
        sns.set(font_scale=1.4)

        # Assign unique indices to each cell type
        cell_type_indices = np.array([[0, 1], [2, 3]])  # TN, FP, FN, TP

        # Define custom colors for each cell type
        if colors is None:
            colors = ['black', 'yellow', 'purple', 'black']  # TN, FP, FN, TP

        # Plot the heatmap using the cell_type_indices to map colors; the
        # explicit value range ties index i to the i-th color of the list.
        sns.heatmap(
            cell_type_indices,
            annot=cm,
            fmt='d',
            cmap=ListedColormap(colors),
            vmin=0,
            vmax=3,
            cbar=False,
            xticklabels=['Negative', 'Positive'],
            yticklabels=['Negative', 'Positive']
        )
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'{task_name} Classification Confusion Matrix')
        plt.show()

    return cm

In [None]:
def _run_epoch(model, loader, species_criterion, toxicity_criterion, device,
               alpha, beta, optimizer=None, phase="Train", log_every=1):
    """
    Run one full pass over `loader` and return its aggregated metrics.

    The pass is a training pass when `optimizer` is given (gradients enabled,
    one optimizer step per batch) and an evaluation pass otherwise (model in
    eval mode, gradients disabled, predictions collected).

    Returns:
        dict with 'loss' (sample-weighted mean of the combined loss),
        'species_acc' and 'toxicity_acc' (percent) and, for evaluation passes,
        the per-sample 'species_preds', 'species_targets', 'toxicity_preds',
        'toxicity_targets' lists (empty for training passes).

    Raises:
        ValueError: if the loader yields no samples.
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is the same switch as eval()

    loss_sum = 0.0
    species_hits = 0
    toxicity_hits = 0
    samples_seen = 0
    collected = {
        'species_preds': [],
        'species_targets': [],
        'toxicity_preds': [],
        'toxicity_targets': [],
    }

    with torch.set_grad_enabled(is_training):
        for batch_idx, (images, (class_ids, toxicities), _) in enumerate(loader):
            images = images.to(device)
            class_ids = class_ids.to(device, dtype=torch.long)
            toxicities = toxicities.to(device, dtype=torch.long)

            if is_training:
                optimizer.zero_grad()
            outputs = model(images)

            loss_species = species_criterion(outputs['species'], class_ids)
            loss_toxicity = toxicity_criterion(outputs['toxicity'], toxicities)
            loss = alpha * loss_species + beta * loss_toxicity  # weighting

            if is_training:
                loss.backward()
                optimizer.step()

            # Weight the batch-mean loss by the batch size so the epoch
            # average is a per-sample mean even with uneven batches.
            batch_samples = images.size(0)
            batch_loss = loss.item()
            loss_sum += batch_loss * batch_samples

            preds_species = outputs['species'].argmax(dim=1)
            preds_toxicity = outputs['toxicity'].argmax(dim=1)
            batch_correct_species = (preds_species == class_ids).sum().item()
            batch_correct_toxicity = (preds_toxicity == toxicities).sum().item()
            species_hits += batch_correct_species
            toxicity_hits += batch_correct_toxicity
            samples_seen += batch_samples

            if not is_training:
                # Keep predictions and targets for the confusion matrices
                collected['species_preds'].extend(preds_species.cpu().tolist())
                collected['species_targets'].extend(class_ids.cpu().tolist())
                collected['toxicity_preds'].extend(preds_toxicity.cpu().tolist())
                collected['toxicity_targets'].extend(toxicities.cpu().tolist())

            if log_every and (batch_idx + 1) % log_every == 0:
                batch_accuracy_species = batch_correct_species / batch_samples * 100
                batch_accuracy_toxicity = batch_correct_toxicity / batch_samples * 100
                print(f"{phase} Batch {batch_idx+1}/{len(loader)}: Loss = {batch_loss:.4f}, Species Acc = {batch_accuracy_species:.2f}%, Toxicity Acc = {batch_accuracy_toxicity:.2f}%")

    if samples_seen == 0:
        # Typical cause: drop_last=True on a split smaller than one batch.
        raise ValueError(f"{phase} loader produced no samples; cannot compute epoch metrics")

    return {
        'loss': loss_sum / samples_seen,
        'species_acc': species_hits / samples_seen * 100,
        'toxicity_acc': toxicity_hits / samples_seen * 100,
        **collected,
    }


def train_model(model, train_loader, val_loader, optimizer, num_epochs=10, alpha = 1.0, beta = 0.6,
                species_criterion=None, toxicity_criterion=None, log_every=1, plot_confusion=True):
    """
    Train the two-headed model and evaluate it after every epoch.

    Each batch is expected to unpack as `(images, (class_ids, toxicities), _)`
    and the model is expected to return a dict with 'species' and 'toxicity'
    logits. The optimized loss is
    `alpha * species_loss + beta * toxicity_loss`.

    Args:
        model: Model to train; batches are moved to the device of its parameters.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        optimizer: Optimizer stepping the model parameters.
        num_epochs: Number of passes over `train_loader`.
        alpha: Weight of the species loss.
        beta: Weight of the toxicity loss.
        species_criterion: Loss for the species head. When omitted, the
            notebook-level `criterion_species` is used.
        toxicity_criterion: Loss for the toxicity head. When omitted, the
            notebook-level `criterion_toxicity` is used.
        log_every: Print batch metrics every this many batches (1 prints every
            batch; 0 or None silences the per-batch lines).
        plot_confusion: When True, plot the species hit/miss strip and the
            toxicity confusion matrix after each epoch via `conf_matrix`.

    Returns:
        dict of per-epoch lists: 'train_losses', 'val_losses',
        'train_accuracies_species', 'train_accuracies_toxicity',
        'val_accuracies_species', 'val_accuracies_toxicity'.

    Raises:
        ValueError: if either loader yields no samples.
    """
    # Fall back to the criteria defined next to the model when none are passed in.
    if species_criterion is None:
        species_criterion = criterion_species
    if toxicity_criterion is None:
        toxicity_criterion = criterion_toxicity

    # Send batches wherever the model lives instead of relying on a global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    history = {
        'train_losses': [],
        'val_losses': [],
        'train_accuracies_species': [],
        'train_accuracies_toxicity': [],
        'val_accuracies_species': [],
        'val_accuracies_toxicity': []
    }

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        # Training Phase
        train_metrics = _run_epoch(
            model, train_loader, species_criterion, toxicity_criterion, run_device,
            alpha, beta, optimizer=optimizer, phase="Train", log_every=log_every,
        )

        # Validation Phase
        val_metrics = _run_epoch(
            model, val_loader, species_criterion, toxicity_criterion, run_device,
            alpha, beta, optimizer=None, phase="Val", log_every=log_every,
        )

        history['train_losses'].append(train_metrics['loss'])
        history['val_losses'].append(val_metrics['loss'])
        history['train_accuracies_species'].append(train_metrics['species_acc'])
        history['train_accuracies_toxicity'].append(train_metrics['toxicity_acc'])
        history['val_accuracies_species'].append(val_metrics['species_acc'])
        history['val_accuracies_toxicity'].append(val_metrics['toxicity_acc'])

        # Print epoch summary
        print(f"\nEpoch {epoch+1}/{num_epochs} Summary:")
        print(f"  Train Loss: {train_metrics['loss']:.4f}, Species Acc: {train_metrics['species_acc']:.2f}%, Toxicity Acc: {train_metrics['toxicity_acc']:.2f}%")
        print(f"  Val Loss:   {val_metrics['loss']:.4f}, Species Acc: {val_metrics['species_acc']:.2f}%, Toxicity Acc: {val_metrics['toxicity_acc']:.2f}%")

        if plot_confusion:
            # Species: reduce the many-class result to hit/miss per sample.
            # The targets are all zeros so the matrix has a single populated
            # row: [misclassified, recognized].
            species_correct = np.array(val_metrics['species_preds']) == np.array(val_metrics['species_targets'])
            species_binary_predictions = species_correct.astype(int)  # 1 if correct, 0 if incorrect
            species_binary_targets = np.zeros_like(species_binary_predictions)

            conf_matrix(
                species_binary_targets,
                species_binary_predictions,
                task_name='Species',
                matrix_type='species',
                colors=['red', 'green']  # Red for misclassified, Green for recognized
            )

            # Toxicity: regular confusion matrix
            conf_matrix(
                np.array(val_metrics['toxicity_targets']),
                np.array(val_metrics['toxicity_preds']),
                task_name='Toxicity',
                matrix_type='toxicity',
                colors=['orange', 'yellow', 'purple', 'black']  # TN, FP, FN, TP
            )

    return history

In [9]:
# Set the number of epochs
# (this value is passed to train_model as `num_epochs` in the training cell)
num_epochs = 10

In [None]:
# Start training
training_stats = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    num_epochs=num_epochs
)

# Save the model and statistics
# The weights go to the .pth file; the per-epoch history returned by
# train_model is written next to it as JSON, so the curves can be re-plotted
# without retraining (previously only the weights were persisted).
torch.save(model.state_dict(), "multi_fungi_model.pth")
with open("multi_fungi_training_stats.json", "w") as stats_file:
    json.dump(training_stats, stats_file, indent=2)
print("Training completed!")


Epoch 1/10
Train Batch 1/644: Loss = 4.9902, Species Acc = 0.00%, Toxicity Acc = 84.38%
Train Batch 2/644: Loss = 5.2633, Species Acc = 0.00%, Toxicity Acc = 78.12%
Train Batch 3/644: Loss = 4.6822, Species Acc = 0.00%, Toxicity Acc = 93.75%
Train Batch 4/644: Loss = 5.0829, Species Acc = 3.12%, Toxicity Acc = 78.12%
Train Batch 5/644: Loss = 4.8650, Species Acc = 0.00%, Toxicity Acc = 87.50%
Train Batch 6/644: Loss = 4.8097, Species Acc = 0.00%, Toxicity Acc = 93.75%
Train Batch 7/644: Loss = 4.6622, Species Acc = 3.12%, Toxicity Acc = 93.75%
Train Batch 8/644: Loss = 4.8120, Species Acc = 0.00%, Toxicity Acc = 96.88%
Train Batch 9/644: Loss = 4.9224, Species Acc = 0.00%, Toxicity Acc = 90.62%
Train Batch 10/644: Loss = 5.1301, Species Acc = 0.00%, Toxicity Acc = 84.38%
Train Batch 11/644: Loss = 4.9194, Species Acc = 0.00%, Toxicity Acc = 96.88%
Train Batch 12/644: Loss = 5.1352, Species Acc = 0.00%, Toxicity Acc = 93.75%
Train Batch 13/644: Loss = 5.0117, Species Acc = 3.12%, Toxic

KeyboardInterrupt: 

In [None]:

def plot_history(train_values, val_values, train_label, val_label, ylabel, title):
    """
    Plot one training curve and one validation curve against the epoch number.

    The x-axis of each curve is derived from the number of recorded values
    (1..len), so the figure always matches what was actually logged, even if
    `num_epochs` was changed after training.

    Args:
        train_values: Per-epoch values of the training series.
        val_values: Per-epoch values of the validation series.
        train_label: Legend label for the training series.
        val_label: Legend label for the validation series.
        ylabel: Label of the y-axis.
        title: Figure title.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(range(1, len(train_values) + 1), train_values, label=train_label)
    ax.plot(range(1, len(val_values) + 1), val_values, label=val_label)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    plt.show()


# Plot Training and Validation Loss
plot_history(
    training_stats['train_losses'],
    training_stats['val_losses'],
    train_label="Train Loss",
    val_label="Validation Loss",
    ylabel="Loss",
    title="Training and Validation Loss",
)

# Plot Species Classification Accuracy
plot_history(
    training_stats['train_accuracies_species'],
    training_stats['val_accuracies_species'],
    train_label="Train Species Accuracy",
    val_label="Validation Species Accuracy",
    ylabel="Accuracy (%)",
    title="Species Classification Accuracy",
)

# Plot Toxicity Classification Accuracy
plot_history(
    training_stats['train_accuracies_toxicity'],
    training_stats['val_accuracies_toxicity'],
    train_label="Train Toxicity Accuracy",
    val_label="Validation Toxicity Accuracy",
    ylabel="Accuracy (%)",
    title="Toxicity Classification Accuracy",
)

In [None]:
# Plot User-Focused Loss (if included)
# NOTE: the train_model defined in this notebook does not return this key, so
# this cell only draws something when the statistics come from a training
# routine that records 'val_user_focused_losses'.
if 'val_user_focused_losses' in training_stats:
    user_focused_losses = training_stats['val_user_focused_losses']
    # Derive the x-axis from the recorded series so it always matches its length
    epochs_recorded = range(1, len(user_focused_losses) + 1)

    plt.figure(figsize=(10, 5))
    plt.plot(epochs_recorded, user_focused_losses, label="Validation User-Focused Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Validation User-Focused Loss")
    plt.legend()
    plt.show()