## Install dependencies

In [None]:
!pip install lightning

## Import libraries

In [None]:
import os, random
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

## PyTorch & TorchVision

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.utils import make_grid

## Lightning

In [None]:
from lightning.pytorch import LightningModule, LightningDataModule, Trainer
import lightning.pytorch as L
print(L.__version__)

## Sklearn

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

## TensorFlow (unused in current logic, possibly for later)

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models

## LOAD IMAGE PATHS & LABELS FROM DIRECTORY

In [None]:
dir0 = '/kaggle/input/mushroom1/merged_dataset'

# Walk the dataset tree and record, for every file found, its full path and
# the name of the class directory it belongs to (two parallel lists).
classes = []
paths = []
for dirname, _, filenames in os.walk(dir0):
    # The class is the FIRST directory level below dir0, so files stored in
    # nested sub-folders still receive the name of their top-level class
    # folder instead of the name of the innermost folder.
    rel_dir = os.path.relpath(dirname, dir0)
    if rel_dir == os.curdir:
        # Files lying directly in dir0 belong to no class folder; skip them
        # rather than labelling them with the dataset folder's own name.
        continue
    class_name = rel_dir.split(os.sep)[0]
    for filename in filenames:
        classes.append(class_name)
        paths.append(os.path.join(dirname, filename))

## Create ImageFolder dataset to access class names

In [None]:
# ImageFolder is instantiated only to obtain the list of class names it
# discovers under dir0.
dataset0 = datasets.ImageFolder(root=dir0)
class_names = dataset0.classes
print(class_names)
print(f"Number of classes: {len(class_names)}")

# One integer index per class. The range is sized by the number of classes,
# not by the per-image `classes` list, so N holds exactly the valid label ids.
N = list(range(len(class_names)))
normal_mapping = dict(zip(class_names, N))    # class name -> label id
reverse_mapping = dict(zip(N, class_names))   # label id   -> class name

## CREATE A DATAFRAME OF PATHS AND LABELS

In [None]:
# One row per image: file path, class-folder name, and the integer label
# looked up from the class name via normal_mapping.
data = pd.DataFrame({'path': paths, 'class': classes}).assign(
    label=lambda frame: frame['class'].map(normal_mapping)
)
print(f"Total images: {len(data)}")

## DEFINE IMAGE TRANSFORMATIONS

In [None]:
# Preprocessing pipeline with light augmentation: random rotation and
# horizontal flip first, then resize + center crop to 224, tensor conversion
# and per-channel normalization.
transform = transforms.Compose(
    [
        transforms.RandomRotation(degrees=10),
        transforms.RandomHorizontalFlip(),
        transforms.Resize(size=224),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

## CONVERT DATAFRAME TO LIST OF (PATH, LABEL) TUPLES

In [None]:
def create_path_label_list(df):
    """
    Convert a DataFrame of image paths and labels into a list of tuples.

    The 'path' and 'label' columns are read directly and zipped together
    instead of iterating with ``DataFrame.iterrows()``, which constructs a
    Series for every row and is needlessly slow on large frames.

    Args:
        df (pd.DataFrame): DataFrame containing 'path' and 'label' columns.
            Any other columns are ignored.

    Returns:
        List[Tuple[str, int]]: List of (image_path, label) tuples in row
        order. Empty when ``df`` has no rows.

    Raises:
        KeyError: If the 'path' or 'label' column is missing.
    """
    return list(zip(df['path'].tolist(), df['label'].tolist()))

path_label = create_path_label_list(data)
# Work on a random subset of at most 20000 images. Capping the sample size
# at the number of available images avoids the ValueError that
# random.sample raises when asked for more items than the population holds.
path_label = random.sample(path_label, min(20000, len(path_label)))
print(len(path_label))
print(path_label[0:3])

## DEFINE CUSTOM DATASET CLASS

In [None]:
class CustomDataset(Dataset):
    """
    Custom PyTorch Dataset that loads images from file paths on demand.

    Args:
        path_label (List[Tuple[str, int]]): List of (image_path, label) tuples.
        transform (callable, optional): Transform applied to the loaded RGB
            PIL image. When None, the PIL image itself is returned.

    Returns:
        Tuple[Tensor, int]: For each index, the (transformed) image and its
        label.
    """
    def __init__(self, path_label, transform=None):
        self.path_label = path_label
        self.transform = transform

    def __len__(self):
        """Return the number of (path, label) samples."""
        return len(self.path_label)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``(image, label)``."""
        path, label = self.path_label[idx]
        # Open the file inside a context manager so the underlying file
        # handle is closed as soon as the RGB copy has been made, instead of
        # whenever the image object happens to be garbage collected.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        # Compare against None explicitly: any supplied transform is applied,
        # regardless of how that object evaluates in a boolean context.
        if self.transform is not None:
            img = self.transform(img)
        return img, label

## DEFINE LIGHTNING DATA MODULE

In [None]:
class DataModule(LightningDataModule):
    """
    PyTorch Lightning DataModule for loading and batching image data.

    Handles both custom (path, label) lists and torchvision ImageFolder
    directories. The validation split doubles as the test split.

    Args:
        data_source (str): Either 'custom' or 'imagefolder'. When omitted it
            is inferred: 'custom' if a non-empty ``path_label`` is given,
            otherwise 'imagefolder'.
        path_label (List[Tuple[str, int]]): Data for custom loader.
        root_dir (str): Directory for ImageFolder loader.
        batch_size (int): Batch size for training/validation.
        train_split (float): Fraction of the samples used for training.
        custom_transform (callable, optional): Custom image transforms. When
            omitted, a resize / center-crop / to-tensor / normalize pipeline
            is used.
        num_workers (int): Worker processes used by every DataLoader.
        seed (int): Seed of the random train/validation split used for the
            'imagefolder' source, so repeated ``setup()`` calls reproduce
            the same split.
    """
    def __init__(self, data_source=None, path_label=None, root_dir=None,
                 batch_size=32, train_split=0.8, custom_transform=None,
                 num_workers=4, seed=42):
        super().__init__()
        self.data_source = data_source or ('custom' if path_label else 'imagefolder')
        self.path_label = path_label
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.train_split = train_split
        self.num_workers = num_workers
        self.seed = seed
        self.transform = custom_transform or transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ])

    def _unknown_source_error(self):
        """Build the error raised when ``data_source`` is not recognised."""
        return ValueError(
            f"Unknown data_source {self.data_source!r}; "
            "expected 'custom' or 'imagefolder'."
        )

    def setup(self, stage=None):
        """
        Prepares datasets for training and validation based on selected data source.

        Raises:
            ValueError: If ``data_source`` is neither 'custom' nor 'imagefolder'.
        """
        if self.data_source == 'custom':
            dataset = CustomDataset(self.path_label, self.transform)
            train_size = int(self.train_split * len(dataset))
            # Sequential split: the first train_size samples train, the rest
            # validate, so the order of path_label decides the split.
            self.train_dataset = torch.utils.data.Subset(dataset, range(train_size))
            self.val_dataset = torch.utils.data.Subset(dataset, range(train_size, len(dataset)))
        elif self.data_source == 'imagefolder':
            dataset = datasets.ImageFolder(root=self.root_dir, transform=self.transform)
            train_size = int(self.train_split * len(dataset))
            val_size = len(dataset) - train_size
            # setup() can run several times (manually, at fit, at test). A
            # freshly seeded generator makes every call produce the same
            # split, so validation/test samples never drift into the set the
            # model was trained on.
            split_rng = torch.Generator().manual_seed(self.seed)
            self.train_dataset, self.val_dataset = torch.utils.data.random_split(
                dataset, [train_size, val_size], generator=split_rng)
        else:
            raise self._unknown_source_error()

    def _make_loader(self, dataset, shuffle):
        """Create a DataLoader over ``dataset`` with the module's shared settings."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=self.num_workers, pin_memory=True)

    def train_dataloader(self):
        """Return the shuffled loader over the training split."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation split."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the validation loader; there is no separate test split."""
        return self.val_dataloader()

    def get_num_classes(self):
        """
        Return the number of classes of the configured data source.

        For 'custom' this is the number of distinct labels present in
        ``path_label``; for 'imagefolder' it is the number of classes found
        under ``root_dir``.

        Raises:
            ValueError: If ``data_source`` is neither 'custom' nor 'imagefolder'.
        """
        if self.data_source == 'imagefolder':
            return len(datasets.ImageFolder(root=self.root_dir).classes)
        if self.data_source == 'custom':
            return len({label for _, label in self.path_label})
        raise self._unknown_source_error()

## DEFINE MODEL USING TIMM AND LIGHTNINGMODULE

In [None]:
import timm

class ConvolutionalNetwork(LightningModule):
    """
    A PyTorch LightningModule using a pretrained ResNet152 from TIMM for image classification.

    Args:
        num_classes (int): Number of target classes for classification.
        lr (float): Learning rate of the Adam optimizer.
    """
    def __init__(self, num_classes, lr=0.001):
        super().__init__()
        self.lr = lr
        self.base_model = timm.create_model('resnet152', pretrained=True, num_classes=num_classes)

    def forward(self, x):
        """Forward pass through the network."""
        return self.base_model(x)

    def configure_optimizers(self):
        """Sets up Adam optimizer over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        """Compute, log ("train_loss") and return the cross-entropy loss of one batch."""
        X, y = batch
        y_hat = self(X)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def _log_eval_metrics(self, batch, prefix):
        """Log cross-entropy loss and accuracy of ``batch`` as "<prefix>_loss" / "<prefix>_acc"."""
        X, y = batch
        y_hat = self(X)
        loss = F.cross_entropy(y_hat, y)
        # Accuracy: fraction of samples whose highest-scoring class equals the target
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log(f"{prefix}_loss", loss)
        self.log(f"{prefix}_acc", acc)

    def validation_step(self, batch, batch_idx):
        """Log "val_loss" and "val_acc" for one validation batch."""
        self._log_eval_metrics(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log "test_loss" and "test_acc" for one test batch."""
        self._log_eval_metrics(batch, "test")

## TRAINING THE MODEL

In [None]:
if __name__ == "__main__":
    # Data: wrap the sampled (path, label) pairs and build the train/val splits
    datamodule = DataModule(path_label=path_label)
    datamodule.setup()

    # Model and trainer: one output per discovered class, 4 epochs on a single CPU device
    model = ConvolutionalNetwork(num_classes=len(class_names))
    trainer = Trainer(accelerator="cpu", devices=1, max_epochs=4)
    trainer.fit(model, datamodule=datamodule)

    # TESTING
    # The test loader is the validation loader (see DataModule.test_dataloader)
    datamodule.setup(stage="test")
    test_loader = datamodule.test_dataloader()
    trainer.test(model=model, dataloaders=test_loader)

## DISPLAY TEST IMAGE GRID

In [None]:
# Fetch only the first batch of the test loader
images, labels = next(iter(datamodule.test_dataloader()))

# First figure: the batch exactly as the model sees it (still normalized)
im = make_grid(images, nrow=8)
plt.figure(figsize=(12, 12))
# make_grid returns a channel-first tensor; imshow expects channel-last
plt.imshow(np.transpose(im.numpy(), (1, 2, 0)))

# Inverse transform for visualization: undoes Normalize(mean, std) by applying
# Normalize(-mean / std, 1 / std)
inv_normalize = transforms.Normalize(mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
                                     std=[1 / 0.229, 1 / 0.224, 1 / 0.225])
# Clamp to [0, 1]: float round-off can leave values marginally outside the
# range imshow accepts for float RGB data
im = inv_normalize(im).clamp(0, 1)

# Second figure: the same grid with the original colors restored
plt.figure(figsize=(12, 12))
plt.imshow(np.transpose(im.numpy(), (1, 2, 0)))

## EVALUATE MODEL WITH CLASSIFICATION REPORT

In [None]:
# Switch to inference mode and make sure model and batches share one device
model.eval()
device = torch.device("cpu")
model.to(device)
y_true, y_pred = [], []

# Collect ground-truth and predicted label ids over the whole test loader
with torch.no_grad():
    for images, labels in datamodule.test_dataloader():
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)
        y_true.extend(labels.tolist())
        y_pred.extend(preds.tolist())

# Pass the full list of label ids explicitly: without it, classification_report
# infers the labels from y_true/y_pred and raises when a class is absent from
# the evaluated split, because target_names would then be longer than the
# inferred label set. zero_division=0 reports 0 for such empty classes.
label_ids = list(range(len(class_names)))
print(classification_report(y_true, y_pred, labels=label_ids,
                            target_names=class_names, digits=4, zero_division=0))