## **Imports**

In [1]:
import tensorflow as tf
import os
from sklearn.model_selection import train_test_split
import random
import matplotlib.pyplot as plt
import numpy as np

## **Load the data and shuffle it randomly**

In [2]:
image_folder = 'data/all'
label_mapping = {
    'NONE': -1,
    'BACTERIALBLIGHT': 0,
    'BACTERAILBLIGHT': 0,  # some files were labelled with this misspelling
    'BLAST': 1,
    'BROWNSPOT': 2,
    'TUNGRO': 3
}
# Misspelled keywords are reported under their canonical class name.
canonical_label_names = {'BACTERAILBLIGHT': 'BACTERIALBLIGHT'}

# Sort before shuffling and seed the RNG: os.listdir order is arbitrary, so
# without both the shuffle (and every downstream split) is not reproducible.
random.seed(42)
image_files = sorted(f for f in os.listdir(image_folder) if f.lower().endswith(".jpg"))
random.shuffle(image_files)
image_labels = []

output_labels_file = 'data/shuffled_labels.txt'

with open(output_labels_file, 'w') as f:
    for image_filename in image_files:
        label = "NONE"
        label_int = -1
        for keyword, value in label_mapping.items():
            if keyword.lower() in image_filename.lower():
                label = canonical_label_names.get(keyword, keyword)
                label_int = value
                break

        # Append for EVERY file (unmatched files keep label -1) so that
        # image_labels stays index-aligned with image_files.
        image_labels.append(label_int)
        f.write(image_filename + ',' + str(label) + ',' + str(label_int) + '\n')

assert len(image_labels) == len(image_files), "labels must align with image files"

print(f'>> Shuffled labels saved to {output_labels_file}')


>> Shuffled labels saved to data/shuffled_labels.txt


## **Train / Validation / Test Split**

In [3]:
# First hold out 30% of the shuffled images, then split that hold-out into
# validation and test sets. Both splits use the same fixed seed.
VAL_TEST_FRACTION = 0.3
# NOTE(review): scikit-learn rounds a float test_size up to a whole sample,
# so this tiny fraction puts a single image in the test split and the rest
# in validation. If a real test set is intended, use e.g. 0.5 here.
TEST_FRACTION_OF_HOLDOUT = 0.00001
SPLIT_SEED = 42

(train_image_paths, val_test_paths,
 train_labels, val_test_labels) = train_test_split(
    image_files, image_labels,
    test_size=VAL_TEST_FRACTION, random_state=SPLIT_SEED)

(val_image_paths, test_image_paths,
 val_labels, test_labels) = train_test_split(
    val_test_paths, val_test_labels,
    test_size=TEST_FRACTION_OF_HOLDOUT, random_state=SPLIT_SEED)

## **Write TFRecords**

In [4]:
def _int64_feature(value):
    """Wrap a sequence of integers as an int64-list ``tf.train.Feature``."""
    int_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int_list)

def _float_feature(value):
    """Wrap a sequence of floats as a float-list ``tf.train.Feature``."""
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)

def read_and_decode(filename):
    """Read an image file and decode it into a float32 tensor.

    The file is decoded as JPEG with 3 channels. It is then converted to
    float32 with ``tf.image.convert_image_dtype``.
    """
    IMG_CHANNELS = 3
    raw_bytes = tf.io.read_file(filename)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=IMG_CHANNELS)
    return tf.image.convert_image_dtype(decoded, tf.float32)

def write_tfrecords(filename, image_paths, labels, image_dir="data/all"):
    """Serialize decoded images and their labels into a TFRecord file.

    Each record carries three features:
      - 'image': the decoded float32 pixels, flattened
      - 'shape': the image dimensions, needed to un-flatten 'image'
      - 'label': the integer class id

    Args:
        filename: path of the .tfrecords file to write.
        image_paths: image file names, relative to ``image_dir``.
        labels: integer labels, index-aligned with ``image_paths``.
        image_dir: directory containing the images. The default is the
            previously hard-coded 'data/all'.

    Raises:
        ValueError: if ``image_paths`` and ``labels`` differ in length.
            Without this check, ``zip`` would silently drop the extra items.
    """
    if len(image_paths) != len(labels):
        raise ValueError(
            f"image_paths ({len(image_paths)}) and labels ({len(labels)}) "
            "must have the same length")

    with tf.io.TFRecordWriter(filename) as writer:
        for path, label in zip(image_paths, labels):
            image = read_and_decode(os.path.join(image_dir, path))

            # NOTE(review): storing float32 pixels is much larger on disk than
            # the original JPEG bytes. Switching formats would also require
            # changing whatever reads these records.
            example = tf.train.Example(features=tf.train.Features(feature={
                'image': _float_feature(image.numpy().flatten()),
                'shape': _int64_feature(image.shape.as_list()),
                'label': _int64_feature([label]),
            }))

            writer.write(example.SerializeToString())

# One TFRecord file per split. Splits are written in train -> val -> test order.
split_outputs = [
    ("data/train_images.tfrecords", train_image_paths, train_labels),
    ("data/val_images.tfrecords", val_image_paths, val_labels),
    ("data/test_images.tfrecords", test_image_paths, test_labels),
]
for record_file, split_paths, split_labels in split_outputs:
    write_tfrecords(record_file, split_paths, split_labels)

