In [None]:
%pip install torch torchvision numpy matplotlib tensorboard kagglehub pillow scikit-learn pandas

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torch.utils.tensorboard import SummaryWriter
import kagglehub
import os
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import pandas as pd
import time

In [None]:
# --- Training configuration: later cells read these names ---
EPOCHS = 15                  # number of passes over the training split
BATCH_SIZE = 16              # samples per DataLoader batch
IMAGE_SIZE = 64              # images are resized to IMAGE_SIZE x IMAGE_SIZE
LEARNING_RATE = 0.002        # initial learning rate handed to the optimizer
MAX_IMAGES_PER_CLASS = 1000  # per-class cap used when sampling the dataset

# Seed NumPy and PyTorch so a "Restart & Run All" repeats the same shuffling
# and weight initialisation (as far as the compute backend is deterministic).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def load_mushroom_data(max_images_per_class=1000, num_classes=10):
    """Download the mushroom dataset and sample image paths from its largest classes.

    The ``train.csv`` and ``val.csv`` indexes are merged, the ``num_classes``
    most frequent labels are kept, and up to ``max_images_per_class`` rows per
    label are taken after a fixed-seed shuffle. Rows whose image file does not
    exist on disk are skipped and counted.

    Args:
        max_images_per_class: Upper bound on rows taken per class (before
            missing files are dropped).
        num_classes: How many of the most frequent labels to keep.

    Returns:
        tuple: ``(image_paths, labels, class_names, dataset_path)`` where
        ``image_paths`` is a list of file paths, ``labels`` the matching list
        of integer class indices, ``class_names`` the kept labels ordered by
        frequency (index == class index) and ``dataset_path`` the download
        directory. If the download fails or an index CSV cannot be read the
        function returns ``(None, None, None, None)``.
    """
    print("Loading dataset...")

    try:
        path = kagglehub.dataset_download("zlatan599/mushroom1")
        print(f"Dataset downloaded to: {path}")
    except Exception as e:
        print(f"Download error: {e}")
        return None, None, None, None

    # A missing/unreadable index is reported the same way as a failed download
    # so the caller only has one failure signal to check.
    try:
        train_df = pd.read_csv(os.path.join(path, 'train.csv'))
        val_df = pd.read_csv(os.path.join(path, 'val.csv'))
    except OSError as e:
        print(f"Could not read dataset index: {e}")
        return None, None, None, None

    all_df = pd.concat([train_df, val_df], ignore_index=True)

    # value_counts() sorts by frequency, so head() yields the largest classes.
    class_counts = all_df['label'].value_counts()
    top_class_names = list(class_counts.head(num_classes).index)

    print(
        f"Selected top {len(top_class_names)} classes "
        f"with up to {max_images_per_class} images each"
    )

    class_to_idx = {class_name: i for i, class_name in enumerate(top_class_names)}

    kaggle_prefix = '/kaggle/working/'
    image_paths = []
    labels = []
    missing_files = 0

    for class_name in top_class_names:
        class_df = all_df[all_df['label'] == class_name]
        # Fixed seed: the same rows are picked on every run.
        shuffled = class_df.sample(frac=1, random_state=42)
        selected_paths = shuffled['image_path'].head(max_images_per_class)

        for img_path in selected_paths:
            # The CSVs store paths as seen inside the Kaggle environment;
            # strip that prefix (and only as a prefix) to make them relative.
            if img_path.startswith(kaggle_prefix):
                img_path = img_path[len(kaggle_prefix):]

            full_path = os.path.join(path, img_path)
            if os.path.exists(full_path):
                image_paths.append(full_path)
                labels.append(class_to_idx[class_name])
            else:
                missing_files += 1

    print(f"Total images loaded: {len(image_paths):,}")
    if missing_files:
        print(f"Skipped {missing_files:,} listed images that were not found on disk")

    return image_paths, labels, top_class_names, path

In [None]:
class MushroomDataset(Dataset):
    """Map-style dataset that loads RGB images from disk on access.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of class indices, aligned with ``image_paths``.
        transform: Optional callable applied to each loaded PIL image.
        fallback_size: Side length of the blank placeholder tensor returned
            when an image cannot be read. Set it to the size the transform
            produces so placeholders can be batched with real samples.
    """

    def __init__(self, image_paths, labels, transform=None, fallback_size=64):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.fallback_size = fallback_size
        # Indices already reported as unreadable, so each is logged only once.
        self._reported_failures = set()

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        An unreadable image is replaced by an all-zero ``(3, S, S)`` tensor
        (``S = fallback_size``) that keeps the sample's real label. Errors
        raised by the transform are NOT swallowed: they indicate a bug in the
        pipeline rather than a bad file.
        """
        label = self.labels[idx]

        # Only the file read/decode is best-effort.
        try:
            image = Image.open(self.image_paths[idx]).convert('RGB')
        except (OSError, ValueError) as err:
            if idx not in self._reported_failures:
                self._reported_failures.add(idx)
                print(
                    f"Warning: could not load {self.image_paths[idx]!r} ({err}); "
                    "using a blank placeholder image"
                )
            placeholder = torch.zeros(3, self.fallback_size, self.fallback_size)
            return placeholder, label

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
class MushroomCNN(nn.Module):
    """Small three-stage convolutional classifier for square RGB images.

    Each stage is Conv(3x3, padding 1) -> BatchNorm -> ReLU -> MaxPool(2), so
    the spatial size is halved three times while channels grow 3 -> 16 -> 32
    -> 64. A two-layer MLP head with dropout produces the class logits.

    Args:
        num_classes: Number of output logits.
        image_size: Side length (pixels) of the square input images. The
            default of 64 reproduces the original fixed ``64 * 8 * 8`` head.

    Raises:
        ValueError: If ``image_size`` is smaller than 8 (the feature map
            would be empty after three 2x2 poolings).
    """

    def __init__(self, num_classes=10, image_size=64):
        super().__init__()

        if image_size < 8:
            raise ValueError(f"image_size must be at least 8, got {image_size}")

        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(16, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        # Three floor-halvings of the side length equal one floor division by 8.
        feature_side = image_size // 8

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * feature_side * feature_side, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Map a ``(N, 3, image_size, image_size)`` batch to ``(N, num_classes)`` logits."""
        x = self.features(x)
        x = self.classifier(x)
        return x

In [None]:
image_paths, labels, class_names, dataset_path = load_mushroom_data(MAX_IMAGES_PER_CLASS)

# Stop the notebook with a normal exception instead of exit(), which would
# shut the kernel down. An empty list (no usable images) is treated as a
# failure too, because every later cell needs at least some data.
if not image_paths:
    raise RuntimeError("Failed to load data!")

In [None]:
# Training pipeline: resize, augment, convert to a normalised tensor.
# Colour augmentations (brightness/contrast/hue) must run BEFORE
# ToTensor/Normalize: they assume ordinary image intensities, and after
# normalisation the values are no longer in the [0, 1] range.
transform_train = T.Compose([
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.2, contrast=0.2),
    T.RandomApply([T.ColorJitter(hue=0.1)], p=0.8),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # The remaining augmentations are purely geometric/linear filters and are
    # kept after normalisation, as in the original pipeline.
    T.RandomVerticalFlip(p=0.5),
    T.RandomRotation(degrees=360),
    T.RandomApply([T.GaussianBlur(3, sigma=(0.1, 1.0))], p=0.3)
])

# Evaluation pipeline: deterministic resize + the same normalisation, no augmentation.
transform_test = T.Compose([
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


In [None]:
# Stage 1: hold out 30% of the samples, keeping class proportions intact.
X_train, X_temp, y_train, y_temp = train_test_split(
    image_paths,
    labels,
    test_size=0.3,
    stratify=labels,
    random_state=42,
)

# Stage 2: cut the held-out part in half -> validation and test splits.
X_val, X_test, y_val, y_test = train_test_split(
    X_temp,
    y_temp,
    test_size=0.5,
    stratify=y_temp,
    random_state=42,
)

print(f"Data split - Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")

In [None]:
# Only the training split is augmented; validation and test share the
# deterministic evaluation transform.
train_dataset = MushroomDataset(X_train, y_train, transform=transform_train)
val_dataset = MushroomDataset(X_val, y_val, transform=transform_test)
test_dataset = MushroomDataset(X_test, y_test, transform=transform_test)

# Training batches are reshuffled every epoch; evaluation order stays fixed.
# num_workers=0 keeps all loading in the main process.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"Using device: {device}")

In [None]:
# Network, moved onto the selected device.
model = MushroomCNN(num_classes=10)
model = model.to(device)

# Multi-class cross-entropy on raw logits.
criterion = nn.CrossEntropyLoss()

# Adam with a small L2 penalty.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-4,
)

# Halve the learning rate once the monitored value has failed to decrease
# for 3 consecutive scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
)

In [None]:
def train(train_loader, model, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        train_loader: Iterable of ``(inputs, targets)`` tensor batches.
        model: Network to optimise; switched to training mode.
        criterion: Loss callable taking ``(logits, targets)``.
        optimizer: Optimizer holding the model's parameters.
        device: Device the batches are moved to before the forward pass.

    Returns:
        tuple: ``(avg_loss, accuracy)`` over all samples seen, where each
        batch's loss is weighted by its size. Both are measured on the forward
        pass made before that batch's parameter update.
    """
    model.train()

    loss_sum = 0.0
    hits = 0
    seen = 0

    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        count = targets.size(0)

        logits = model(inputs)
        batch_loss = criterion(logits, targets)

        # Bookkeeping: criterion returns a batch mean, so re-weight by size.
        loss_sum += batch_loss.item() * count
        seen += count
        hits += (logits.argmax(1) == targets).sum().item()

        # Parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    return loss_sum / seen, hits / seen

In [None]:
def test(test_loader, model, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            y_pred = model(X)
            loss = criterion(y_pred, y)
            
            total_loss += loss.item() * y.size(0)
            preds = y_pred.argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    
    avg_loss = total_loss / total
    accuracy = correct / total
    print(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.4f}")
    
    return avg_loss, accuracy

In [None]:
writer = SummaryWriter()

print("Starting training...")
best_val_acc = 0

# try/finally: the TensorBoard writer is closed even when training is
# interrupted (e.g. by stopping the cell) or a batch raises.
try:
    for epoch in range(EPOCHS):
        print(f"\nEpoch {epoch+1}/{EPOCHS}")
        print("-" * 20)

        train_loss, train_acc = train(train_dataloader, model, criterion, optimizer, device)
        # test() is reused for validation, so its printed line is labelled "Test".
        val_loss, val_acc = test(val_dataloader, model, criterion, device)

        # The plateau scheduler monitors the validation loss.
        scheduler.step(val_loss)

        writer.add_scalar('Loss/Train', train_loss, epoch)
        writer.add_scalar('Accuracy/Train', train_acc, epoch)
        writer.add_scalar('Loss/Val', val_loss, epoch)
        writer.add_scalar('Accuracy/Val', val_acc, epoch)

        # Checkpoint only when validation accuracy strictly improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_mushroom_model.pth')
            print(f"New best model saved! (Val Acc: {val_acc:.4f})")
finally:
    writer.close()

print("\nTraining completed!")

Test Loss: 2.0431, Test Accuracy: 0.2353
New best model saved! (Val Acc: 0.2353)

Epoch 2/15
--------------------
Test Loss: 1.8711, Test Accuracy: 0.2633
New best model saved! (Val Acc: 0.2633)

Epoch 3/15
--------------------


In [None]:
# Restore the weights of the best-validation epoch. map_location places the
# tensors on the current device, so a checkpoint written on a GPU machine can
# still be loaded in a CPU-only session.
best_state = torch.load('best_mushroom_model.pth', map_location=device)
model.load_state_dict(best_state)
print("Best model loaded for final evaluation")

# Single evaluation on the held-out test split.
final_loss, final_accuracy = test(test_dataloader, model, criterion, device)

In [None]:
def print_model_info(model):
    """Print parameter counts and the in-memory size of ``model``'s parameters.

    The size is computed from each parameter's actual element size, so it is
    also correct for models that are not stored as 32-bit floats. Buffers
    (e.g. batch-norm running statistics) are not included.

    Args:
        model: Any ``torch.nn.Module``.
    """
    params = list(model.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)
    size_bytes = sum(p.numel() * p.element_size() for p in params)

    print("\nModel information:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print(f"  Model size: {size_bytes / 1024 / 1024:.2f} MB")

print_model_info(model)

# Headline numbers: best validation accuracy seen during training and the
# accuracy of the final test-split evaluation.
print(
    "\nFinal Results:",
    f"  Best validation accuracy: {best_val_acc:.4f}",
    f"  Final test accuracy: {final_accuracy:.4f}",
    sep="\n",
)

In [None]:
# Export in evaluation mode so dropout is disabled and batch-norm uses its
# running statistics.
model.eval()

torch.save(model.state_dict(), 'mushroom_model_final.pth')
print("✅ PyTorch model saved!")

# Example input used to trace the model; only its shape/dtype matter.
dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=device)

torch.onnx.export(
    model,
    dummy_input,
    'mushroom_model.onnx',
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    # Leave the batch dimension dynamic so the exported model accepts any batch size.
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)

print("✅ ONNX model exported to mushroom_model.onnx")

# Optional sanity check: run the same input through ONNX Runtime and compare
# the raw outputs with PyTorch's. onnxruntime is imported here because it is
# an optional dependency that is only needed for this check.
try:
    import onnxruntime as ort
except ImportError:
    print("💡 onnxruntime not installed - verification skipped")
else:
    ort_session = ort.InferenceSession('mushroom_model.onnx')
    ort_inputs = {ort_session.get_inputs()[0].name: dummy_input.cpu().numpy()}
    ort_outputs = ort_session.run(None, ort_inputs)

    with torch.no_grad():
        pytorch_output = model(dummy_input)

    print(f"PyTorch prediction: {torch.argmax(pytorch_output, dim=1).item()}")
    print(f"ONNX prediction: {np.argmax(ort_outputs[0], axis=1)[0]}")

    # Actually compare the two outputs instead of assuming they agree.
    pytorch_logits = pytorch_output.cpu().numpy()
    if np.allclose(pytorch_logits, ort_outputs[0], rtol=1e-3, atol=1e-5):
        print("✅ ONNX verification successful!")
    else:
        max_abs_diff = float(np.max(np.abs(pytorch_logits - ort_outputs[0])))
        print(f"❌ ONNX verification failed! (max abs difference: {max_abs_diff:.6f})")

In [None]:
def analyze_results(model, test_loader, class_names, device):
    """Print a per-class classification report for ``model`` on ``test_loader``.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        test_loader: Iterable of ``(inputs, targets)`` tensor batches.
        class_names: Display names, where position ``i`` names class index ``i``.
        device: Device the batches are moved to before the forward pass.
    """
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(1)

            all_preds.extend(predicted.cpu().tolist())
            all_targets.extend(target.cpu().tolist())

    print("\nClassification Report:")
    # Explicit labels keep the report aligned with class_names even when a
    # class never occurs in the evaluated data (otherwise the name list and
    # the labels found in the data would differ in length and the call fails).
    # zero_division=0 reports 0 for such classes instead of emitting warnings.
    print(classification_report(
        all_targets,
        all_preds,
        labels=list(range(len(class_names))),
        target_names=list(class_names),
        digits=4,
        zero_division=0,
    ))

# Print the classification report for the current model on the test loader.
analyze_results(model, test_dataloader, class_names, device)