In [40]:
from tensor import TensorSpec, Tensor
from utils.index import Index

In [41]:
fn si8_to_ui8(int_tensor: Tensor[DType.int8]) raises -> Tensor[DType.uint8]:
    """Convert a tensor of int8 values to uint8 by casting each element.

    Args:
        int_tensor: The int8 tensor to convert.

    Returns:
        A new uint8 tensor with the same shape as `int_tensor`, where each
        element is the uint8 cast of the element at the same flat index.
    """
    let element_count = int_tensor.num_elements()
    var converted = Tensor[DType.uint8](TensorSpec(DType.uint8, int_tensor.shape()))
    var position = 0
    # Walk the flat element index so the input's shape does not matter.
    while position < element_count:
        converted[position] = int_tensor[position].cast[DType.uint8]()
        position += 1
    return converted

In [42]:
fn uint8_to_int32(uint8_tensor: Tensor[DType.uint8]) raises -> SIMD[DType.int32, 1]:
    """Combine the first four bytes of a tensor into one big-endian integer.

    Args:
        uint8_tensor: Tensor whose elements 0..3 hold the bytes, most
            significant byte first. Elements past index 3 are ignored.

    Returns:
        (b0 << 24) | (b1 << 16) | (b2 << 8) | b3, converted to an int32 scalar.

    Raises:
        An error if the tensor holds fewer than four elements.
    """
    # Guard the four indexed reads below; the original indexed 0..3 unconditionally.
    if uint8_tensor.num_elements() < 4:
        raise 'Need at least 4 bytes to decode a 32-bit integer.'
    var value = 0
    # Shift the accumulated value up one byte and append the next one.
    for byte_index in range(4):
        value = (value << 8) | uint8_tensor[byte_index].to_int()
    return SIMD[DType.int32, 1](value)

In [43]:
fn get_slice[type: DType](tensor: Tensor[type], start_index: Int, end_index: Int) raises -> Tensor[type]:
    """Copy a run of elements from a tensor into a new one-dimensional tensor.

    Elements are addressed by flat index. For `start_index < end_index` the
    result holds the elements at `start_index .. end_index - 1`.

    Args:
        tensor: The tensor to copy from.
        start_index: Flat index of the first element to copy.
        end_index: Flat index one past the last element to copy.

    Returns:
        A new tensor of length `end_index - start_index`. When both indices
        are equal, a one-element tensor holding `tensor[start_index]` is
        returned instead of an empty tensor.

    Raises:
        An error if `end_index < start_index`, or if the requested elements
        do not all lie inside the tensor.
    """
    if end_index < start_index:
        raise 'End index is less than start index.'

    let available = tensor.num_elements()
    if start_index < 0 or end_index > available:
        raise 'Slice bounds are outside the tensor.'

    if end_index == start_index:
        # Existing behaviour kept for callers: equal indices yield the single
        # element at start_index, so that element must exist.
        if start_index >= available:
            raise 'Slice bounds are outside the tensor.'
        var single = Tensor[type](1)
        single[0] = tensor[start_index]
        return single

    var output_tensor = Tensor[type](end_index - start_index)
    for i in range(start_index, end_index):
        output_tensor[i - start_index] = tensor[i]
    return output_tensor

In [44]:
fn read_data_as_images(images_path: Path) raises -> Tensor[DType.uint8]:
    """Load an IDX3 image file into a (num_images, rows, columns) uint8 tensor.

    The 16-byte header is read as four big-endian 32-bit integers: magic
    number, image count, row count, column count. The pixel bytes that follow
    are copied in file order (image by image, row by row).

    Args:
        images_path: Path of the IDX3 file to read.

    Returns:
        A uint8 tensor of shape (image count, row count, column count).

    Raises:
        An error if the path does not exist, the file is shorter than its
        header, the header declares a negative dimension, or the file holds
        fewer pixel bytes than the header declares.
    """
    if not images_path.exists():
        raise 'The images file does not exist.'

    let image_file = si8_to_ui8(images_path.read_bytes())
    let header_size = 16
    if image_file.num_elements() < header_size:
        raise 'The images file is too short to contain an IDX3 header.'

    # Bytes 0-3 are the magic number and are not inspected here.
    let num_images = uint8_to_int32(get_slice[DType.uint8](image_file, 4, 8)).to_int()
    # IDX3 stores the row count (height) before the column count (width).
    let height = uint8_to_int32(get_slice[DType.uint8](image_file, 8, 12)).to_int()
    let width = uint8_to_int32(get_slice[DType.uint8](image_file, 12, 16)).to_int()
    if num_images < 0 or height < 0 or width < 0:
        raise 'The images file header declares a negative dimension.'

    let pixel_count = num_images * height * width
    if image_file.num_elements() - header_size < pixel_count:
        raise 'The images file holds fewer pixels than its header declares.'

    var images = Tensor[DType.uint8](num_images, height, width)

    # Copy exactly the declared number of pixels; any trailing bytes in the
    # file are ignored rather than written past the end of the tensor.
    var byte_num = header_size
    for sample_num in range(num_images):
        for row_num in range(height):
            for column_num in range(width):
                images[Index(sample_num, row_num, column_num)] = image_file[byte_num]
                byte_num += 1
    return images

In [45]:
fn read_data_as_labels(labels_path: Path) raises -> Tensor[DType.uint8]:
    """Load an IDX1 label file into a one-dimensional uint8 tensor.

    The 8-byte header is read as two big-endian 32-bit integers: magic number
    and label count. One label byte per item follows.

    Args:
        labels_path: Path of the IDX1 file to read.

    Returns:
        A uint8 tensor holding the declared number of labels, in file order.

    Raises:
        An error if the path does not exist, the file is shorter than its
        header, the declared label count is not positive, or the file holds
        fewer label bytes than the header declares.
    """
    if not labels_path.exists():
        raise 'The labels file does not exist.'

    let label_file = si8_to_ui8(labels_path.read_bytes())
    let header_size = 8
    if label_file.num_elements() < header_size:
        raise 'The labels file is too short to contain an IDX1 header.'

    # Bytes 0-3 are the magic number and are not inspected here.
    let num_labels = uint8_to_int32(get_slice[DType.uint8](label_file, 4, 8)).to_int()
    # A zero count would make get_slice return the byte at index 8 anyway,
    # so reject it instead of handing back a bogus label.
    if num_labels <= 0:
        raise 'The labels file header declares a non-positive label count.'
    if label_file.num_elements() - header_size < num_labels:
        raise 'The labels file holds fewer labels than its header declares.'

    return get_slice(label_file, header_size, header_size + num_labels)

In [51]:
fn tensor_print[type: DType](index: Int, tensor: Tensor[type]):
    """Print one 2-D slice of a three-index tensor as ASCII art.

    The slice `tensor[index, :, :]` is printed one line per step along
    dimension 1. Every element along dimension 2 becomes a glyph chosen by
    its integer value, followed by a single space.

    Args:
        index: Position along dimension 0 of the slice to print.
        tensor: The tensor to read; it is addressed with three indices.
    """
    fn glyph_for(level: Int) -> String:
        # Zero is blank; larger values map to visually denser glyphs in
        # steps of 32.
        if level == 0:
            return ' '
        if level < 32:
            return "."
        if level < 64:
            return ','
        if level < 96:
            return ':'
        if level < 128:
            return ';'
        if level < 160:
            return 'o'
        if level < 192:
            return 'O'
        if level < 224:
            return 'X'
        return '#'

    let line_count = tensor.shape()[1]
    let glyphs_per_line = tensor.shape()[2]
    for line_number in range(line_count):
        var text: String = ''
        for glyph_number in range(glyphs_per_line):
            text += glyph_for(tensor[Index(index, line_number, glyph_number)].to_int()) + ' '
        print(text)

In [52]:
# Dataset locations, built from an absolute base directory that is specific
# to one machine; change base_dir to run this notebook elsewhere.
let base_dir = '/Users/tprazak/Documents/seminary_work_nn/MNIST_in_mojo/'
# NOTE(review): the MNIST file names use '.idx3-ubyte' / '.idx1-ubyte' while
# the Fashion-MNIST ones use '-idx3-ubyte' / '-idx1-ubyte' — confirm that both
# spellings match the files on disk.
let image_path_mnist = base_dir + 'mnist/train-images.idx3-ubyte'
let labels_path_mnist = base_dir + 'mnist/train-labels.idx1-ubyte'
let image_path_fashion = base_dir + 'fashion_mnist/train-images-idx3-ubyte'
let labels_path_fashion = base_dir + 'fashion_mnist/train-labels-idx1-ubyte'
# Load the image and label tensors for both training sets.
let images = read_data_as_images(image_path_mnist)
let labels = read_data_as_labels(labels_path_mnist)
let fashion_images = read_data_as_images(image_path_fashion)
let fashion_labels = read_data_as_labels(labels_path_fashion)

In [61]:
# Render the image at index 5232 of `images` as ASCII art.
tensor_print(5232, images)

                                                        
                                                        
                                                        
                                                        
                          : X # # X o                   
                        : # # O : X # .                 
                        X X ,     : # X                 
                        : o       : # O                 
                          .       : # ,                 
                                . X #                   
                                o # ,                   
                              : # O                     
                            . # # ,                     
                            O # :                       
                          o # o                         
                        . # X                           
                        O # ,                           
                      ; # o    

In [58]:
# Show the label stored at index 3 of `labels`.
print(labels[3])

1
