In [5]:
import os
import subprocess

# Fetch the MNIST image repository next to this notebook's folder if it is not
# already there, then list the training image files.
MNIST_DIR = os.path.join("..", "MNIST_dataset")
MNIST_REPO_URL = "https://github.com/DeepTrackAI/MNIST_dataset"

if not os.path.exists(MNIST_DIR):
    # Clone into exactly the path checked above and read below; a bare
    # `git clone URL` would clone into the current working directory instead.
    # check=True surfaces a failed clone instead of failing later in listdir.
    subprocess.run(["git", "clone", MNIST_REPO_URL, MNIST_DIR], check=True)

train_path = os.path.join(MNIST_DIR, "mnist", "train")
train_images_files = sorted(os.listdir(train_path))

print(len(train_images_files))

60000


In [6]:
import random
import numpy as np
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter

# Spelled-out forms of the digits 0-10 and of the rotation angles that may be
# mentioned in generated descriptions.
_NUMBER_WORDS = {
    0: 'zero',
    1: 'one',
    2: 'two',
    3: 'three',
    4: 'four',
    5: 'five',
    6: 'six',
    7: 'seven',
    8: 'eight',
    9: 'nine',
    10: 'ten',
    30: 'thirty',
    45: 'forty-five',
    60: 'sixty',
    90: 'ninety',
    120: 'one hundred-twenty',
    135: 'one hundred-thirty-five',
    150: 'one hundred-fifty',
    180: 'one hundred-eighty',
}


def number_to_string(number):
    """Return the English word form of ``number``.

    Only the values present in ``_NUMBER_WORDS`` (0-10 and a fixed set of
    angles) are supported.

    Raises:
        KeyError: if ``number`` has no entry in the table.
    """
    return _NUMBER_WORDS[number]


def number_representation(number):
    """Render ``number`` either as digits or spelled out, picked at random with equal odds.

    The choice is drawn from the module-level ``random`` generator; the spelled-out
    form comes from ``number_to_string``.
    """
    spell_out = not random.choice([True, False])
    if spell_out:
        return number_to_string(number)
    return str(number)


def apply_color_transform(image, color, bkg_color):
    """Recolor a digit image: bright (foreground) pixels get ``color``, the rest ``bkg_color``.

    Args:
        image: PIL image of the digit; it is converted to RGB first.
        color: PIL color spec used to tint the foreground (multiplied into it).
        bkg_color: PIL color spec used for every non-foreground pixel.

    Returns:
        A new RGB PIL image of the same size as ``image``.
    """
    rgb = image.convert("RGB")
    background = Image.new("RGB", image.size, bkg_color)
    tint = Image.new("RGB", rgb.size, color)

    # Hard binary mask: 255 where the grayscale value exceeds 128, 0 elsewhere.
    foreground_mask = rgb.convert("L").point(lambda p: 255 if p > 128 else 0)

    tinted = ImageChops.multiply(rgb, tint)
    return Image.composite(tinted, background, foreground_mask)

# Opening phrases for the generated instruction sentence.
_START_PHRASES = [
    'Create an image of', 'Display the digit', 'Represent the handwritten digit', 'Generate an illustration of',
    'Show the number', 'Depict the handwritten numeral', 'Produce an image of', 'Illustrate the digit', 'Craft a visual of',
    'Render the number', 'Display the handwritten number', 'Create a depiction of', 'Form an image of',
]


def _random_geometric_transform():
    """Pick one random orientation change (mirror, flip, rotate, or none).

    Returns:
        (transform, description): ``transform`` maps a PIL image to a PIL image,
        or is ``None`` when the image is left unchanged; ``description`` is the
        matching text fragment, or ``None`` when there is nothing to describe.
    """
    choice = random.choice(['mirror', 'flip', 'rotate', 'nothing'])
    if choice == 'rotate':
        rot_angle = random.choice([-90, 90, 180])

        def transform(img):
            # PIL's Image.rotate turns counter-clockwise for positive angles;
            # negate so that positive rot_angle means clockwise, as the text says.
            return img.rotate(-rot_angle)

        if rot_angle > 0:
            description = f'rotated by {number_representation(rot_angle)} degrees clockwise'
        else:
            description = f'rotated by {number_representation(-rot_angle)} degrees anticlockwise'
        return transform, description
    if choice == 'mirror':
        return ImageOps.mirror, random.choice(['flipped horizontally', 'mirrored', 'reflected horizontally'])
    if choice == 'flip':
        return ImageOps.flip, random.choice(['flipped upside-down', 'reflected vertically'])
    return None, None


def transform_image(image, label):
    """Randomly recolor and reorient a digit image and describe the result in text.

    Args:
        image: Digit image as a numpy array (e.g. from ``plt.imread``). Float
            arrays are assumed to lie in [0, 1] and are scaled to 0-255; other
            dtypes are cast to uint8 as is.
        label: Number shown in the image; must be supported by
            ``number_representation`` (e.g. a digit 0-9).

    Returns:
        (trans_image, sentence): the transformed 28x28 RGB PIL image and an
        instruction sentence such as
        "Show the number 7 in red on a black background, mirrored."
    """
    pixels = np.asarray(image)
    if np.issubdtype(pixels.dtype, np.floating):
        pixels = pixels * 255
    pil_image = Image.fromarray(np.uint8(pixels))

    description_parts = []

    # Foreground and background colors are drawn without replacement so they differ.
    colors = ['black', 'red', 'blue', 'green', 'white']
    color = random.choice(colors)
    colors.remove(color)
    bkg_color = random.choice(colors)
    description_parts.append(f'in {color} on a {bkg_color} background')

    transform, transform_text = _random_geometric_transform()
    trans_image = pil_image if transform is None else transform(pil_image)
    if transform_text is not None:
        description_parts.append(transform_text)

    trans_image = apply_color_transform(trans_image, color, bkg_color)
    trans_image = trans_image.resize((28, 28))
    # Stretch pixel values to the full 0-255 range (cutoff=0: no histogram clipping).
    trans_image = ImageOps.autocontrast(trans_image, cutoff=0)

    start_sent = random.choice(_START_PHRASES)
    digit_repr = number_representation(label)
    sentence = f"{start_sent} {digit_repr} {', '.join(description_parts)}."

    return trans_image, sentence


In [7]:
import json
import random

import matplotlib.pyplot as plt

# Generate one recolored/reoriented image plus a text description for every
# MNIST training file, and save both under ../text2image_dataset/train.
viz = False  # set True to plot each original/transformed pair (slow for the full set)
SEED = 0
random.seed(SEED)  # makes the sampled colors, orientations and phrasings reproducible

OUTPUT_DIR = os.path.join("..", "text2image_dataset", "train")
IMAGES_DIR = os.path.join(OUTPUT_DIR, "images")
os.makedirs(IMAGES_DIR, exist_ok=True)  # save() fails if the folder is missing

# NOTE(review): not populated in this cell — remove if no later cell uses them.
MNIST_images, MNIST_labels, MNIST_sentences = [], [], []

# Maps each MNIST filename to its generated description.
mnist_descriptions_dict = {}

for filename in train_images_files:
    image = plt.imread(os.path.join(train_path, filename))
    # Assumes the first character of each filename is the digit label — TODO confirm naming scheme.
    label = int(filename[0])

    trans_image, description = transform_image(image, label)
    mnist_descriptions_dict[filename] = description
    trans_image.save(os.path.join(IMAGES_DIR, filename))

    if viz:
        fig, axes = plt.subplots(1, 2, figsize=(8, 4))
        fig.suptitle(description, fontsize=12)

        axes[0].imshow(image, cmap='gray')
        axes[0].set_title('Original Image')
        axes[0].axis('off')

        axes[1].imshow(trans_image)
        axes[1].set_title('Transformed Image')
        axes[1].axis('off')

        plt.show()
        plt.close(fig)  # avoid accumulating open figures across iterations

with open(os.path.join(OUTPUT_DIR, "image_descriptions.json"), "w") as jsonfile:
    json.dump(mnist_descriptions_dict, jsonfile)
