### Hyperparameters

In [None]:
# Tunable training settings, gathered in one place so a run can be re-tuned
# without editing the cells below.
hparams = dict(
    batch_size=64,   # samples per mini-batch handed to the DataLoaders
    lr=0.005,        # initial learning rate passed to the model's optimizer
    last_drop=0.2,   # dropout probability in front of the final linear layer
    max_epochs=20,   # upper bound on training epochs given to the Trainer
    patience=5,      # EarlyStopping patience on the monitored metric
)

### Import Statements

In [None]:
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning import Trainer
from torchvision import transforms
import os
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchmetrics
from torchvision.models import mobilenet_v2
import pandas as pd
from typing import Optional
from sklearn.model_selection import train_test_split
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

### Project Properties

In [None]:
# Root folder of the dataset; every immediate sub-directory is one class.
data_dir = 'data'

# Number of classes = number of sub-directories directly under data_dir.
num_classes = len([d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))])

# Seed used for the train/val/test split.
random_state = 42

# Use half of the available cores for data loading. os.cpu_count() may return
# None when the count cannot be determined, so fall back to 1 core (which
# yields 0 workers, i.e. loading in the main process) instead of raising.
num_workers = (os.cpu_count() or 1) // 2

# Fraction of the whole dataset used for validation.
val_split = 0.2

# Fraction of the whole dataset used for testing.
test_split = 0.1

# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalise each channel with the mean/std values below
# (the statistics commonly used with ImageNet-pretrained torchvision models).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

### Data Module

In [None]:
class ImageDataset(Dataset):
    """Map-style dataset that loads the images listed in a DataFrame.

    Args:
        data: Frame with an ``image_path`` column (path of an image file) and
            a ``label`` column (class label of that image). Rows are
            addressed by position, not by index label.
        transform: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, data: pd.DataFrame, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of samples (rows of the frame)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the sample at positional index ``idx``.

        Returns:
            ``(image, label)``: the image converted to RGB (and passed
            through ``transform`` when one was given) and the value of the
            ``label`` column for that row.
        """
        # Fetch the row once instead of doing a separate positional lookup
        # for each column.
        row = self.data.iloc[idx]

        # The context manager closes the underlying file deterministically;
        # convert() returns an independent in-memory image, so it remains
        # usable after the file is closed.
        with Image.open(row['image_path']) as source:
            image = source.convert('RGB')

        if self.transform:
            image = self.transform(image)
        return image, row['label']

class ImageDataModule(pl.LightningDataModule):
    """LightningDataModule for an image-folder dataset (one sub-directory per class).

    Args:
        data_dir: Root directory; each immediate sub-directory is one class.
        batch_size: Batch size used by all three DataLoaders.
        num_workers: Worker processes used by all three DataLoaders.
        transform: Optional callable applied to every loaded image.
        val_split: Fraction of the whole dataset used for validation.
        test_split: Fraction of the whole dataset used for testing.
        random_state: Seed passed to both ``train_test_split`` calls.
    """

    def __init__(self, data_dir, batch_size=32, num_workers=1, transform=None, val_split=0.2, test_split=0.2, random_state=42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform
        self.val_split = val_split
        self.test_split = test_split
        self.random_state = random_state

    def setup(self, stage: Optional[str] = None):
        """Scan the dataset, split it, and build the datasets for ``stage``.

        Args:
            stage: ``'fit'`` builds the train and validation datasets,
                ``'test'`` builds the test dataset, ``None`` builds all of
                them.
        """
        data = self.load_data()

        # First carve off the combined validation+test share (stratified by
        # label), then divide that hold-out between validation and test so
        # each ends up with its requested fraction of the whole dataset.
        holdout_fraction = self.val_split + self.test_split
        train_data, test_val_data = train_test_split(
            data,
            test_size=holdout_fraction,
            random_state=self.random_state,
            stratify=data['label'],
        )
        val_data, test_data = train_test_split(
            test_val_data,
            test_size=self.test_split / holdout_fraction,
            random_state=self.random_state,
            stratify=test_val_data['label'],
        )

        if stage == 'fit' or stage is None:
            self.train_dataset = ImageDataset(train_data, transform=self.transform)
            self.val_dataset = ImageDataset(val_data, transform=self.transform)

        if stage == 'test' or stage is None:
            self.test_dataset = ImageDataset(test_data, transform=self.transform)

    def train_dataloader(self):
        """Return the shuffled DataLoader over the training dataset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        """Return the DataLoader over the validation dataset."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the DataLoader over the test dataset."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def load_data(self):
        """Build the sample table by walking ``data_dir``.

        Class sub-directories are sorted by name and numbered from 0; that
        number is the label of every file inside the directory.

        Returns:
            DataFrame with columns ``image_path`` and ``label``, one row per
            file found.
        """
        class_names = sorted(
            entry for entry in os.listdir(self.data_dir)
            if os.path.isdir(os.path.join(self.data_dir, entry))
        )

        image_paths = []
        labels = []
        for idx, class_name in enumerate(class_names):
            class_dir = os.path.join(self.data_dir, class_name)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # row order -- and therefore the seeded split -- identical across
            # machines and file systems.
            for file_name in sorted(os.listdir(class_dir)):
                if file_name == '.DS_Store':  # Skip .DS_Store files
                    continue
                image_path = os.path.join(class_dir, file_name)
                if not os.path.isfile(image_path):
                    # Nested directories cannot be opened as images.
                    continue
                image_paths.append(image_path)
                labels.append(idx)

        return pd.DataFrame({'image_path': image_paths, 'label': labels})


### Model Architecture

In [None]:
class MobileNetV2Lightning(pl.LightningModule):
    """MobileNetV2 classifier wrapped as a LightningModule.

    The torchvision MobileNetV2 backbone is kept, and its classifier head is
    replaced by dropout followed by a linear layer with ``num_classes``
    outputs. Cross-entropy is the training loss, and accuracy is tracked
    separately for the train, validation and test stages.

    Args:
        num_classes: Number of output classes.
        pretrained: Forwarded to ``mobilenet_v2(pretrained=...)``.
        lr: Initial learning rate for Adam.
        last_drop: Dropout probability in the classifier head.
    """

    def __init__(self, num_classes, pretrained=True, lr=0.001, last_drop=0.2):
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = lr
        # NOTE(review): recent torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; kept as-is for compatibility -- confirm
        # against the pinned torchvision version.
        self.model = mobilenet_v2(pretrained=pretrained)
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=last_drop, inplace=False),
            nn.Linear(self.model.last_channel, num_classes),
        )
        self.criterion = nn.CrossEntropyLoss()

        # One metric instance per stage so the accumulated state of one stage
        # never leaks into another. `accuracy` keeps its original name and
        # now serves the training stage only.
        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.test_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        return self.model(x)

    def _shared_step(self, batch, stage, metric):
        """Run one batch and log ``{stage}_loss`` and ``{stage}_acc``.

        Args:
            batch: ``(inputs, targets)`` pair from the DataLoader.
            stage: Prefix of the logged keys (``'train'``, ``'val'`` or ``'test'``).
            metric: The accuracy metric that belongs to this stage.

        Returns:
            The cross-entropy loss of the batch.
        """
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        preds = torch.argmax(logits, dim=1)
        metric(preds, y)

        # Training logs per step, evaluation logs once per epoch. Only one of
        # the two flags is ever set, so the logged key names stay exactly
        # `{stage}_loss` / `{stage}_acc`. Logging the metric object (rather
        # than a tensor) lets Lightning compute and reset it per epoch.
        is_train = stage == 'train'
        self.log(f'{stage}_loss', loss)
        self.log(f'{stage}_acc', metric, on_step=is_train, on_epoch=not is_train)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log ``train_loss`` / ``train_acc``."""
        return self._shared_step(batch, 'train', self.accuracy)

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log ``val_loss`` / ``val_acc``."""
        return self._shared_step(batch, 'val', self.val_accuracy)

    def test_step(self, batch, batch_idx):
        """Compute the test loss and log ``test_loss`` / ``test_acc``."""
        return self._shared_step(batch, 'test', self.test_accuracy)

    def configure_optimizers(self):
        """Return Adam plus an exponential learning-rate decay (gamma=0.90)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # Each scheduler step multiplies the learning rate by 0.90.
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)
        return [optimizer], [lr_scheduler]

### Main Initializer

In [None]:
# Unpack the tunables from the hyperparameter dict into module-level names.
batch_size, lr, last_drop, max_epochs, patience = (
    hparams[key]
    for key in ('batch_size', 'lr', 'last_drop', 'max_epochs', 'patience')
)

# Data pipeline: directory scan, stratified split and DataLoaders.
dataModule = ImageDataModule(
    data_dir=data_dir,
    batch_size=batch_size,
    num_workers=num_workers,
    transform=transform,
    val_split=val_split,
    test_split=test_split,
    random_state=random_state,
)

# Network to train.
model = MobileNetV2Lightning(num_classes=num_classes, lr=lr, last_drop=last_drop)

# Callbacks: stop once `val_loss` has not improved for `patience` checks, and
# record the learning rate once per epoch.
early_stopping = EarlyStopping(monitor='val_loss', patience=patience)
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Trainer on automatically selected hardware, with logging and checkpointing
# switched on.
trainer = Trainer(
    accelerator='auto',
    devices='auto',
    max_epochs=max_epochs,
    logger=True,
    enable_checkpointing=True,
    callbacks=[lr_monitor, early_stopping],
)


### Initiate Training/Validating

In [None]:
# Start the training/validation loop; the datamodule supplies the loaders.
trainer.fit(model=model, datamodule=dataModule)