In [1]:
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
from torchvision import transforms
from pathlib import Path
from dataclasses import dataclass
import torch.nn.functional as F
from torch import nn
from utils import print_trainable_parameters
import os
import keras

In [2]:
@dataclass
class Config:
    """Notebook-wide settings: data locations, image geometry and loader options.

    NOTE: none of the attributes below carry a type annotation, so ``@dataclass``
    generates no fields for them; they are plain class attributes shared by every
    ``Config()`` instance. A later cell attaches ``class_names`` to the instance
    at runtime.
    """

    data_dir = Path('data')
    # Split directories consumed by torchvision's ImageFolder.
    train_dir = data_dir / 'train'
    validation_dir = data_dir / 'validation'
    image_shape = (3, 224, 224)  # (channels, height, width)
    image_size = (224, 224)      # (height, width) target size for transforms.Resize
    # os.cpu_count() returns None when the count cannot be determined; the
    # DataLoader needs an int, so fall back to loading in the main process (0).
    num_workers = os.cpu_count() or 0
    batch_size = 32


config = Config()

In [3]:
# Image preprocessing pipelines. Both resize to config.image_size, convert to a
# tensor, then map each channel through (x - 0.5) / 0.5. Only the training
# pipeline applies random flips as augmentation.
train_transforms = transforms.Compose(
    [
        transforms.Resize(config.image_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        # transforms.RandomRotation(degrees=45),  # currently disabled augmentation
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Deterministic pipeline for validation: no augmentation.
validation_transforms = transforms.Compose(
    [
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [4]:
# Build labelled datasets from the split directories. ImageFolder derives each
# split's label indices from that split's own sorted sub-folder names.
train_dataset = ImageFolder(config.train_dir, transform=train_transforms)
validation_dataset = ImageFolder(config.validation_dir, transform=validation_transforms)

# Because the two splits are indexed independently, a class folder missing from
# one split would silently shift every later label index. Fail loudly instead.
if train_dataset.class_to_idx != validation_dataset.class_to_idx:
    raise ValueError(
        "train and validation class folders differ: "
        f"{sorted(set(train_dataset.classes) ^ set(validation_dataset.classes))}"
    )

len(train_dataset), len(validation_dataset)

(2936, 734)

In [5]:
# Wrap the datasets in DataLoaders. Only the training split is shuffled, so the
# validation order (and therefore per-epoch validation metrics) stays fixed.
# Pinned host memory only helps when batches are copied to a CUDA device;
# enabling it without one just triggers a warning.

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=torch.cuda.is_available(),
    # The model's classifier head contains BatchNorm1d, which raises in training
    # mode on a batch of a single sample; dropping the ragged final batch rules
    # that out regardless of the dataset size.
    drop_last=True,
)

val_dataloader = torch.utils.data.DataLoader(
    validation_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=torch.cuda.is_available(),
)

In [6]:
class Net(nn.Module):
    """Seven-stage convolutional classifier with a small MLP head.

    The trunk is a Conv2d(3, 32, 5) + LeakyReLU stem followed by seven stages
    of Conv2d -> LeakyReLU -> BatchNorm2d -> MaxPool2d(2, 2), then global
    average pooling down to a 512-channel 1x1 map. The head maps those 512
    features to ``num_classes`` raw logits (no softmax is applied).

    Args:
        num_classes: number of logits returned by ``forward``.
    """

    def __init__(self, num_classes):
        super(Net, self).__init__()

        # (in_channels, out_channels, kernel_size) of each pooled stage.
        stage_specs = [
            (32, 32, 5),
            (32, 64, 3),
            (64, 64, 3),
            (64, 128, 3),
            (128, 256, 3),
            (256, 512, 3),
            (512, 512, 3),
        ]
        layers = [nn.Conv2d(3, 32, kernel_size=5, padding='same'), nn.LeakyReLU()]
        for in_ch, out_ch, ksize in stage_specs:
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=ksize, padding='same'),
                nn.LeakyReLU(),
                nn.BatchNorm2d(out_ch),
                nn.MaxPool2d(2, 2),
            ]
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )
        # Optional experiments: freeze the convolutional trunk and report the
        # trainable parameter count via the project helper.
        # self.features.requires_grad_(False)
        # print_trainable_parameters(self)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for the image batch ``x``."""
        pooled = self.features(x)
        return self.classifier(pooled.flatten(start_dim=1))
# Store the label names ImageFolder found for the training split on the shared
# config; the Keras wrapper sizes the network's output layer from their count.
setattr(config, "class_names", train_dataset.classes)

In [None]:
class Net_1(nn.Module):
    """Lighter CNN variant whose trunk mixes 3x3 and 1x1 convolutions.

    Layout: a Conv2d(3, 32, 5) + LeakyReLU stem, then five stages of
    Conv2d -> LeakyReLU -> BatchNorm2d -> MaxPool2d(2, 2) (two of which use
    1x1 kernels), global average pooling to a 512-channel 1x1 map, and an MLP
    head producing ``num_classes`` raw logits.

    Args:
        num_classes: number of logits returned by ``forward``.
    """

    def __init__(self, num_classes):
        # super() must name this class: Net_1 does not inherit from Net, so
        # super(Net, self) raises TypeError on construction.
        super(Net_1, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, padding='same'),
            nn.LeakyReLU(),
            nn.Conv2d(32, 32, kernel_size=5, padding='same'),
            nn.LeakyReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2, 2),
            # 1x1 convolution; kernel_size must be an integer ('same' is only a
            # valid value for `padding`, and a 1x1 kernel needs no padding).
            nn.Conv2d(32, 64, kernel_size=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, padding='same'),
            nn.LeakyReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=1),  # 1x1 convolution
            nn.LeakyReLU(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(256, 512, kernel_size=3, padding='same'),
            nn.LeakyReLU(),
            nn.BatchNorm2d(512),
            nn.MaxPool2d(2, 2),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.classifier = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
        # self.features.requires_grad_(False)

        # print_trainable_parameters(self)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for the image batch ``x``."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
# Re-record the training split's label names on the shared config (same value
# an earlier cell already stores; harmless if that cell has run).
setattr(config, "class_names", train_dataset.classes)

In [7]:
# Reset Keras' global state (e.g. auto-generated layer-name counters) so that
# re-running this cell builds the model from a clean slate. This used to live
# inside the class body, where it silently ran at class-definition time.
keras.backend.clear_session()


class KerasModel(keras.Model):
    """Keras model whose forward pass is the PyTorch ``Net``.

    ``keras.layers.TorchModuleWrapper`` only works with the ``torch`` Keras
    backend, so the constructor checks for it and raises a readable error
    instead of failing deep inside Keras.

    Args:
        num_classes: size of the output layer; defaults to
            ``len(config.class_names)`` when omitted.
        **kwargs: forwarded to ``keras.Model`` (e.g. ``name``).

    Raises:
        RuntimeError: if the active Keras backend is not ``torch``.
    """

    def __init__(self, num_classes=None, **kwargs):
        backend = keras.config.backend()
        if backend != "torch":
            raise RuntimeError(
                "KerasModel wraps a PyTorch module and requires the 'torch' Keras "
                "backend (set KERAS_BACKEND=torch before importing keras); "
                f"current backend: {backend!r}"
            )
        super().__init__(**kwargs)
        if num_classes is None:
            num_classes = len(config.class_names)
        self.model = keras.layers.TorchModuleWrapper(Net(num_classes), name='torch_model')

    def call(self, x):
        """Run the wrapped PyTorch network and return its raw logits."""
        return self.model(x)


keras_model = KerasModel()

In [8]:
# Print the layer table and parameter counts of the Keras-wrapped model.
keras_model.summary()

In [11]:
# Net's classifier ends in a bare nn.Linear (no softmax), so the model emits raw
# logits. The loss must be told so; the string "sparse_categorical_crossentropy"
# defaults to from_logits=False and would treat the logits as probabilities.
# "Sparse" because ImageFolder yields integer class indices as labels.
keras_model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="adamW",
    metrics=["accuracy"],
)

In [None]:
# Train the model, validating after every epoch. Keep the returned History:
# its `.history` dict holds the per-epoch loss/metric values for later
# plotting or inspection, which were previously discarded.
history = keras_model.fit(
    train_dataloader,
    epochs=10,
    validation_data=val_dataloader,
)