Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor #3

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ data/
*.jpg
node_modules/
.env
.DS_Store
.DS_Store
models/
15 changes: 15 additions & 0 deletions ImageTransformer.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import Jimp from 'jimp'

/**
 * Converts pixel arrays to real images which can be saved to disk
 */
export class ImageTransformer {
  /**
   * Writes a single flat RGBA pixel buffer to disk as an image file.
   * @param {number[]} img - flat RGBA byte array (4 bytes per pixel)
   * @param {number} [width=28] - image width in pixels
   * @param {number} [height=28] - image height in pixels
   * @param {string} path - destination file path
   */
  saveImage(img, width = 28, height = 28, path) {
    // Callback parameter renamed so it no longer shadows `img`, and the
    // error argument is reported instead of being silently discarded.
    new Jimp({ width, height, data: Buffer.from(img) }, (err, image) => {
      if (err) {
        console.error(`Failed to create image for ${path}:`, err)
        return
      }
      image.write(path)
    })
  }

  /**
   * Converts normalized grayscale pixel arrays (values in [0, 1]) into
   * RGBA image files written to the output directory.
   * @param {number[][]} data - one flat grayscale array per image
   * @param {string} [filePrefix='processed'] - output filename prefix
   * @param {number} [width=28] - image width in pixels
   * @param {number} [height=28] - image height in pixels
   */
  toImages(data, filePrefix = 'processed', width = 28, height = 28) {
    // Expand each grayscale value to an opaque RGBA quadruple.
    const imgs = data.map(img => img.flatMap(val => [val * 255, val * 255, val * 255, 255]))
    imgs.forEach((img, i) => this.saveImage(img, width, height, `output/${filePrefix}_${i}.png`))
  }
}
44 changes: 44 additions & 0 deletions arbitraryImageDataSource.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import fs from 'fs'
import sharp from 'sharp'
import tf from '@tensorflow/tfjs-node'

/**
 * Loads images from the images directory and processes them to fit the model
 */
export class ArbitraryImageDataSource {
  /**
   * Scans the images directory and splits the files into disjoint
   * training and test sets. The previous version drew the test set from
   * a second independent shuffle of the same list, so test images could
   * also appear in the training set (data leakage); now we shuffle once
   * and take non-overlapping slices.
   * @param {number} [countTraining=1000] - maximum number of training files
   * @param {number} [countTest=10] - maximum number of test files
   */
  constructor(countTraining = 1000, countTest = 10) {
    const files = fs.readdirSync('images')
      .filter(f => f.endsWith('.jpg') || f.endsWith('.jpeg') || f.endsWith('.png'))
      .map(f => `images/${f}`)

    const shuffled = ArbitraryImageDataSource.#shuffle(files)
    this.trainingFiles = shuffled.slice(0, countTraining)
    this.testFiles = shuffled.slice(countTraining, countTraining + countTest)
  }

  /**
   * Returns a new array with the elements of `arr` in random order
   * (unbiased Fisher-Yates). Kept private so the class no longer depends
   * on a patched Array.prototype.
   * @param {Array} arr
   * @returns {Array} shuffled copy; `arr` is not mutated
   */
  static #shuffle(arr) {
    const copy = [...arr]
    for (let i = copy.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1))
      ;[copy[i], copy[j]] = [copy[j], copy[i]]
    }
    return copy
  }

  /**
   * Loads and normalizes all training images.
   * @returns {Promise<tf.Tensor>} pixel values scaled to [0, 1]
   */
  async getTrainingData() {
    const data = await Promise.all(this.trainingFiles.map(f => this._processImageFile(f)))
    return tf.tensor(data).div(255)
  }

  /**
   * Loads and normalizes all test images.
   * @returns {Promise<tf.Tensor>} pixel values scaled to [0, 1]
   */
  async getTestData() {
    const data = await Promise.all(this.testFiles.map(f => this._processImageFile(f)))
    return tf.tensor(data).div(255)
  }

  /**
   * Resizes an image to 28x28 (cropping to cover), gamma-corrects,
   * converts to grayscale and returns the raw pixel buffer.
   * @param {string} filename - path of the image to process
   * @returns {Promise<Buffer>} raw grayscale pixel data
   */
  _processImageFile(filename) {
    return sharp(filename)
      .resize(28, 28, {
        fit: 'cover'
      })
      .gamma()
      .greyscale()
      .raw()
      .toBuffer()
  }
}

// FIXME: extending a native prototype is an anti-pattern; kept only for
// backward compatibility with existing callers in this file. Prefer a
// standalone helper function in new code.
/**
 * Returns a new array containing this array's elements in random order.
 * Replaces the previous sort-by-random-key implementation (O(n log n))
 * with an unbiased O(n) Fisher-Yates shuffle.
 * @returns {Array} shuffled copy; the receiver is not mutated
 */
Array.prototype.shuffle = function () {
  const copy = [...this]
  for (let i = copy.length - 1; i > 0; i--) {
    const j = Math.floor(Math.random() * (i + 1))
    ;[copy[i], copy[j]] = [copy[j], copy[i]]
  }
  return copy
}
144 changes: 31 additions & 113 deletions index.js
Original file line number Diff line number Diff line change
@@ -1,119 +1,37 @@
console.log("Hello Autoencoder 🚂");
import { Model } from './model.js'
import { MnistDataSource } from './mnistDataSource.js'
import { ImageTransformer } from './ImageTransformer.js'
import { RandomDataSource } from './randomDataSource.js'
import { ArbitraryImageDataSource } from './arbitraryImageDataSource.js'

import * as tf from "@tensorflow/tfjs-node";
// import canvas from "canvas";
// const { loadImage } = canvas;
import Jimp from "jimp";
import numeral from "numeral";

main();
main()

async function main() {
// Build the model
const autoencoder = buildModel();
// load all image data
const images = await loadImages(550);

// train the model
const x_train = tf.tensor2d(images.slice(0, 500));
await trainModel(autoencoder, x_train, 250);

// test the model
const x_test = tf.tensor2d(images.slice(500));
await generateTests(autoencoder, x_test);
}

async function generateTests(autoencoder, x_test) {
const output = autoencoder.predict(x_test);
// output.print();

const newImages = await output.array();
for (let i = 0; i < newImages.length; i++) {
const img = newImages[i];
const buffer = [];
for (let n = 0; n < img.length; n++) {
const val = Math.floor(img[n] * 255);
buffer[n * 4 + 0] = val;
buffer[n * 4 + 1] = val;
buffer[n * 4 + 2] = val;
buffer[n * 4 + 3] = 255;
// Instantiate the model
const model = new Model()

// Instantiate a data source
const dataSource = new MnistDataSource()
// const dataSource = new RandomDataSource()
// const dataSource = new ArbitraryImageDataSource()

// Instantiate the Image transformer
const transformer = new ImageTransformer()

// Check if there is a pretrained model. If it exists load it, or train the model
if (model.pretrainedModelExists()) {
await model.load()
} else {
// Create the layers
model.configure()
// and train
await model.train(await dataSource.getTrainingData(), 200)
}
const image = new Jimp(
{
data: Buffer.from(buffer),
width: 28,
height: 28,
},
(err, image) => {
const num = numeral(i).format("000");
image.write(`output/square${num}.png`);
}
);
}
}

function buildModel() {
const autoencoder = tf.sequential();
// Build the model
autoencoder.add(
tf.layers.dense({
units: 256,
inputShape: [784],
activation: "relu",
})
);
autoencoder.add(
tf.layers.dense({
units: 128,
activation: "relu",
})
);
// Test the model with testing data from the data source
const testData = await dataSource.getTestData()
const autoEncodedImages = model.autoencode(testData)

autoencoder.add(
tf.layers.dense({
units: 256,
activation: "sigmoid",
})
);

autoencoder.add(
tf.layers.dense({
units: 784,
activation: "sigmoid",
})
);
autoencoder.compile({
optimizer: "adam",
loss: "binaryCrossentropy",
metrics: ["accuracy"],
});
return autoencoder;
}

async function trainModel(autoencoder, x_train, epochs) {
await autoencoder.fit(x_train, x_train, {
epochs: epochs,
batch_size: 32,
shuffle: true,
verbose: true,
});
}

async function loadImages(total) {
const allImages = [];
for (let i = 0; i < total; i++) {
const num = numeral(i).format("000");
const img = await Jimp.read(`data/square${num}.png`);

let rawData = [];
for (let n = 0; n < 28 * 28; n++) {
let index = n * 4;
let r = img.bitmap.data[index + 0];
// let g = img.bitmap.data[n + 1];
// let b = img.bitmap.data[n + 2];
rawData[n] = r / 255.0;
}
allImages[i] = rawData;
}
return allImages;
// save the images to disk
transformer.toImages(testData.arraySync(), 'org')
transformer.toImages(autoEncodedImages.arraySync())
}
20 changes: 20 additions & 0 deletions mnistDataSource.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import mnist from 'mnist'
import tf from '@tensorflow/tfjs-node'

/**
 * Load data from the mnist data set
 */
export class MnistDataSource {
  /**
   * Pulls a fresh sample from the mnist package and keeps only the
   * input vectors (labels are not needed for an autoencoder).
   * @param {number} [countTraining=1000] - number of training samples
   * @param {number} [countTest=10] - number of test samples
   */
  constructor(countTraining = 1000, countTest = 10) {
    const dataset = mnist.set(countTraining, countTest)
    this.training = dataset.training.map(({ input }) => input)
    this.test = dataset.test.map(({ input }) => input)
  }

  /** @returns {tf.Tensor} the training inputs as a tensor */
  getTrainingData() {
    return tf.tensor(this.training)
  }

  /** @returns {tf.Tensor} the test inputs as a tensor */
  getTestData() {
    return tf.tensor(this.test)
  }
}
67 changes: 67 additions & 0 deletions model.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import fs from 'fs'
import tf from '@tensorflow/tfjs-node'

/**
 * Abstraction for tfjs
 */
export class Model {
  /**
   * Checks whether all three saved model files are present on disk.
   * @returns {boolean} true when autoencoder, encoder and decoder can all be loaded
   */
  pretrainedModelExists() {
    return fs.existsSync('models/autoencoder/model.json') &&
      fs.existsSync('models/encoder/model.json') &&
      fs.existsSync('models/decoder/model.json')
  }

  /** Loads the pretrained autoencoder, encoder and decoder models from disk. */
  async load() {
    this.autoencoder = await tf.loadLayersModel('file://models/autoencoder/model.json')
    this.encoder = await tf.loadLayersModel('file://models/encoder/model.json')
    this.decoder = await tf.loadLayersModel('file://models/decoder/model.json')
  }

  /**
   * Builds the autoencoder (784 -> 32 -> 784) together with standalone
   * encoder and decoder models that share the autoencoder's layer
   * instances, so training the autoencoder also trains both halves.
   */
  configure() {
    const encoded = [
      tf.layers.dense({ units: 128, inputShape: [784], activation: "relu" }),
      tf.layers.dense({ units: 64, activation: "relu" }),
      tf.layers.dense({ units: 32, activation: "relu" }),
    ]
    const decoded = [
      tf.layers.dense({ units: 64, activation: "relu" }),
      tf.layers.dense({ units: 128, activation: "relu" }),
      tf.layers.dense({ units: 784, activation: "sigmoid" }),
    ]

    this.autoencoder = tf.sequential({ layers: [...encoded, ...decoded] })
    this.encoder = tf.sequential({ layers: encoded })

    // The decoder needs its own input layer matching the 32-unit latent space.
    const encodedInput = tf.layers.inputLayer({ inputShape: [32] })
    this.decoder = tf.sequential({ layers: [encodedInput, ...decoded] })

    // Only the autoencoder is trained directly, so only it needs compiling.
    this.autoencoder.compile({
      optimizer: 'adam',
      loss: 'binaryCrossentropy',
    })
  }

  /**
   * Trains the autoencoder on its own input (reconstruction objective)
   * and persists all three models to disk.
   * @param {tf.Tensor} x_train - normalized training inputs
   * @param {number} [epochs=100] - number of training epochs
   */
  async train(x_train, epochs = 100) {
    await this.autoencoder.fit(x_train, x_train, {
      epochs,
      batchSize: 32,
      shuffle: true,
    })
    // recursive: true so an already-existing models/ directory (e.g. from
    // a partial previous save) no longer crashes the save with EEXIST.
    fs.mkdirSync('models', { recursive: true })
    await this.autoencoder.save('file://models/autoencoder')
    await this.encoder.save('file://models/encoder')
    await this.decoder.save('file://models/decoder')
  }

  /**
   * Runs data through the encoder and then the decoder.
   * @param {tf.Tensor} data - model inputs
   * @returns {tf.Tensor} the reconstructed inputs
   */
  autoencode(data) {
    return this.decode(this.encode(data))
  }

  /** @returns {tf.Tensor} the 32-dimensional latent representation of data */
  encode(data) {
    return this.encoder.predict(data)
  }

  /** @returns {tf.Tensor} the reconstruction decoded from a latent representation */
  decode(encoded) {
    return this.decoder.predict(encoded)
  }
}