In [1]:
from zipfile import ZipFile
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.optim as optim
import numpy as np
from torch.utils.data import Subset
import matplotlib.pyplot as plt
import os
import cv2
import shutil

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.layers import Dense, Flatten, Dropout, GlobalAveragePooling1D
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam

from tensorflow.keras import mixed_precision
from tensorflow.keras.layers import LayerNormalization, MultiHeadAttention, Add

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
train_dir = '/content/drive/MyDrive/APS360 Project/Data/train'
test_dir = '/content/drive/MyDrive/APS360 Project/Data/test'

In [4]:
batch_size = 32
img_size = (224, 224)

train_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    train_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=img_size,
    batch_size=batch_size
)

val_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    train_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=img_size,
    batch_size=batch_size
)

test_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    test_dir,
    image_size=img_size,
    batch_size=batch_size
)

Found 28710 files belonging to 7 classes.
Using 22968 files for training.
Found 28710 files belonging to 7 classes.
Using 5742 files for validation.
Found 7178 files belonging to 7 classes.


In [5]:
# Define the CNN branch
def create_cnn_branch(input_shape=(224, 224, 3)):
    cnn_input = layers.Input(shape=input_shape)
    x = layers.Conv2D(32, kernel_size=(3, 3), activation='relu')(cnn_input)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Conv2D(64, kernel_size=(3, 3), activation='relu')(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Conv2D(128, kernel_size=(3, 3), activation='relu')(x)
    x = layers.GlobalAveragePooling2D()(x)
    cnn_output = layers.Dense(128, activation='relu')(x)
    return Model(cnn_input, cnn_output, name="CNN_Branch")

# Define ViT branch
def create_vit_branch(input_shape=(224, 224, 3), patch_size=16, num_layers=4, num_heads=8, projection_dim=128, mlp_dim=256, num_classes=7):
    vit_input = layers.Input(shape=input_shape)

    #Extract patches
    patches = layers.Conv2D(filters=projection_dim, kernel_size=(patch_size, patch_size), strides=(patch_size, patch_size))(vit_input)
    patches = layers.Reshape((-1, projection_dim))(patches)

    # Positional encoding
    patch_positions = tf.range(start=0, limit=patches.shape[1], delta=1)
    positional_encoding = layers.Embedding(input_dim=patches.shape[1], output_dim=projection_dim)(patch_positions)
    positional_encoding = tf.expand_dims(positional_encoding, axis=0)
    patches += positional_encoding

    # Transformer layers
    for _ in range(num_layers):
        #Multi head Self Attention
        attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim)(patches, patches)
        x = layers.Add()([attention_output, patches])
        x = layers.LayerNormalization()(x)

        #MLP
        x = layers.Dense(mlp_dim, activation='relu')(x)
        x = layers.Dense(projection_dim)(x)
        patches = layers.Add()([x, patches])
        patches = layers.LayerNormalization()(patches)

    vit_output = layers.GlobalAveragePooling1D()(patches)
    vit_output = layers.Dense(128, activation='relu')(vit_output)

    return Model(vit_input, vit_output, name="ViT_Branch")

# Combine CNN and ViT
def create_hybrid_model(input_shape=(224, 224, 3), num_classes=7):
    cnn_branch = create_cnn_branch(input_shape=input_shape)
    vit_branch = create_vit_branch(input_shape=input_shape)

    combined = layers.concatenate([cnn_branch.output, vit_branch.output])
    x = layers.Dense(256, activation='relu')(combined)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(128, activation='relu')(x)
    output = layers.Dense(num_classes, activation='softmax')(x)

    hybrid_model = Model(inputs=[cnn_branch.input, vit_branch.input], outputs=output, name="Hybrid_Model")
    return hybrid_model

#define model
model = create_hybrid_model()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

#model.summary()

In [None]:
def duplicate_input(image, label):
    return (image, image), label

train_dataset = train_dataset.map(duplicate_input)
val_dataset = val_dataset.map(duplicate_input)
test_dataset = test_dataset.map(duplicate_input)

epochs = 20
batch_size = 32

history = model.fit(
    train_dataset,
    epochs=20,
    validation_data=val_dataset
)

val_loss, val_accuracy = model.evaluate([val_dataset, val_dataset])
print(f'Validation Loss: {val_loss}')
print(f'Validation Accuracy: {val_accuracy}')

Epoch 1/20
[1m718/718[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1692s[0m 2s/step - accuracy: 0.2110 - loss: 2.2946 - val_accuracy: 0.2489 - val_loss: 1.8065
Epoch 2/20
[1m718/718[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1470s[0m 2s/step - accuracy: 0.2359 - loss: 1.8191 - val_accuracy: 0.2513 - val_loss: 1.7996
Epoch 3/20
[1m718/718[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1444s[0m 2s/step - accuracy: 0.2450 - loss: 1.8028 - val_accuracy: 0.2492 - val_loss: 1.7922
Epoch 4/20
[1m718/718[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1412s[0m 2s/step - accuracy: 0.2481 - loss: 1.7952 - val_accuracy: 0.2492 - val_loss: 1.7891
Epoch 5/20
[1m718/718[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1426s[0m 2s/step - accuracy: 0.2502 - loss: 1.7913 - val_accuracy: 0.2396 - val_loss: 1.8134
Epoch 6/20
[1m718/718[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1479s[0m 2s/step - accuracy: 0.2547 - loss: 1.7869 - val_accuracy: 0.2534 - val_loss: 1.7790
Epoch 7/20
[1m7