In [1]:
import tensorflow as tf
import tensorflow_hub as hub
print("TF",tf.__version__)

import time
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
import pathlib

import os

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam


TF 2.9.1


In [2]:
# Load model into KerasLayer
model_url = "https://tfhub.dev/google/bit/m-r50x1/1"
module = hub.KerasLayer(model_url)

In [3]:
data_dir = 'dataset'
train_dir = os.path.join(data_dir,"data")
train_fnames = os.listdir(train_dir)
NUM_CLASSES = len(train_fnames)
print(NUM_CLASSES)

127


In [4]:
train_datagen = ImageDataGenerator(rescale=1./255.,
                                  validation_split=0.2)

train_data = train_datagen.flow_from_directory(train_dir,
                                              target_size=(224,224),
                                              batch_size=32,
                                              class_mode="categorical",
                                              shuffle = True,
                                              subset='training')
validation_data = train_datagen.flow_from_directory(train_dir,
                                              target_size=(224,224),
                                              batch_size=32,
                                              class_mode="categorical",
                                              shuffle = True,
                                              subset='validation')

Found 20909 images belonging to 127 classes.
Found 5172 images belonging to 127 classes.


In [5]:
# BiT model 구성하기

class MyBiTModel(tf.keras.Model):
  """BiT with a new head."""

  def __init__(self, num_classes, module):
    super().__init__()

    self.num_classes = num_classes
    self.head = tf.keras.layers.Dense(num_classes, kernel_initializer='zeros')
    self.bit_model = module
  
  def call(self, images):
    # No need to cut head off since we are using feature extractor model
    bit_embedding = self.bit_model(images)
    return self.head(bit_embedding)

model = MyBiTModel(num_classes=NUM_CLASSES, module=module)

In [9]:
model.summary()

Model: "my_bi_t_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               multiple                  260223    
                                                                 
 keras_layer (KerasLayer)    multiple                  23500352  
                                                                 
Total params: 23,760,575
Trainable params: 260,223
Non-trainable params: 23,500,352
_________________________________________________________________


In [15]:
print(train_data)

<keras.preprocessing.image.DirectoryIterator object at 0x000001A7DE562CA0>


In [16]:
#loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer = Adam(1e-3),
             loss = "categorical_crossentropy",
             metrics=['accuracy'])

In [17]:
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())


[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 951525658347962674
xla_global_id: -1
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 5738397696
locality {
  bus_id: 1
  links {
  }
}
incarnation: 13974684223565143559
physical_device_desc: "device: 0, name: NVIDIA GeForce RTX 3070, pci bus id: 0000:2b:00.0, compute capability: 8.6"
xla_global_id: 416903419
]


In [None]:
history = model.fit(train_data,
                    epochs = 50,
                    batch_size=32,
                    validation_data = validation_data)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50

In [None]:
his_dict = history.history
loss = his_dict['loss']
val_loss = his_dict['val_loss'] 

epochs = range(1, len(loss) + 1)
fig = plt.figure(figsize = (10, 5))

# 훈련 및 검증 손실 그리기
ax1 = fig.add_subplot(1, 2, 1)
ax1.plot(epochs, loss, color = 'blue', label = 'train_loss')
ax1.plot(epochs, val_loss, color = 'orange', label = 'val_loss')
ax1.set_title('train and val loss')
ax1.set_xlabel('epochs')
ax1.set_ylabel('loss')
ax1.legend()

acc = his_dict['accuracy']
val_acc = his_dict['val_accuracy']

# 훈련 및 검증 정확도 그리기
ax2 = fig.add_subplot(1, 2, 2)
ax2.plot(epochs, acc, color = 'blue', label = 'train_acc')
ax2.plot(epochs, val_acc, color = 'orange', label = 'val_acc')
ax2.set_title('train and val acc')
ax2.set_xlabel('epochs')
ax2.set_ylabel('acc')
ax2.legend()

plt.show()