In [1]:
import numpy as np
from PIL import Image

In [2]:
MAX_RGB_VALUE = 255
COLOR_CHANNELS_AMOUNT = 3

In [3]:
def image_to_blocks(image, b_h, b_w, overlap = 0):
    i_h, i_w = image.shape[:2]

    step_h = int(b_h * (1 - overlap))
    step_w = int(b_w * (1 - overlap))

    blocks = []

    for i in range(0, i_h - b_h + 1, step_h):
        for j in range(0, i_w - b_w + 1, step_w):
            block = image[i:i+b_h, j:j+b_w]                                  
            blocks.append(block)    
    
    if i_h % b_h != 0:
        for j in range(0, i_w - b_w + 1, step_w):
            block = image[i_h-b_h:i_h, j:j+b_w]
            blocks.append(block)    
    
    if i_w % b_w != 0:
        for i in range(0, i_h - b_h + 1, step_h):
            block = image[i:i+b_h, i_w-b_w:i_w]
            blocks.append(block)    
    
    if i_h % b_h != 0 and i_w % b_w != 0:
        block = image[i_h-b_h:i_h, i_w-b_w:i_w]
        blocks.append(block)
    
    return np.asarray(blocks)


def blocks_to_image(image_blocks, image_shape, b_h, b_w, overlap = 0):
    i_h, i_w = image_shape[:2]
    c = image_shape[2] if len(image_shape) == 3 else 1

    restored_image = np.zeros((i_h, i_w, c), dtype=np.float64)
    count_matrix = np.zeros((i_h, i_w), dtype=np.float64)
    
    step_h = int(b_h * (1 - overlap))
    step_w = int(b_w * (1 - overlap))
    
    block_index = 0
    
    for i in range(0, i_h - b_h + 1, step_h):
        for j in range(0, i_w - b_w + 1, step_w):
            block = image_blocks[block_index]            
            restored_image[i:i+b_h, j:j+b_w] += block
            count_matrix[i:i+b_h, j:j+b_w] += 1
            block_index += 1    
    
    if i_h % b_h != 0:
        for j in range(0, i_w - b_w + 1, step_w):
            block = image_blocks[block_index]
            restored_image[i_h-b_h:i_h, j:j+b_w] += block
            count_matrix[i_h-b_h:i_h, j:j+b_w] += 1
            block_index += 1    
    
    if i_w % b_w != 0:
        for i in range(0, i_h - b_h + 1, step_h):
            block = image_blocks[block_index]
            restored_image[i:i+b_h, i_w-b_w:i_w] += block
            count_matrix[i:i+b_h, i_w-b_w:i_w] += 1
            block_index += 1    
    
    if i_h % b_h != 0 and i_w % b_w != 0:
        block = image_blocks[block_index]
        restored_image[i_h-b_h:i_h, i_w-b_w:i_w] += block
        count_matrix[i_h-b_h:i_h, i_w-b_w:i_w] += 1    
    
    count_matrix[count_matrix == 0] = 1    
    restored_image = restored_image / count_matrix[..., np.newaxis]    
    
    return restored_image.astype(np.uint8)

In [4]:
def normalize_weights(weights):
    norms = np.linalg.norm(weights, axis=0)
    return weights / norms

# Функция активации
def linear_activation(x):
    return x

class LRNN:
    def __init__(self, input_dim, latent_dim, learning_rate=0.001):        
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.learning_rate = learning_rate        
        
        self.W_enc = normalize_weights(np.random.randn(self.input_dim, self.latent_dim))
        self.W_dec = normalize_weights(np.random.randn(self.latent_dim, self.input_dim))
    
    def forward(self, x):
        z = linear_activation(x @ self.W_enc)
        x_reconstructed = linear_activation(z @ self.W_dec)
        return z, x_reconstructed
    
    def backward(self, x, x_reconstructed):
        error = x_reconstructed - x        
        
        dW_enc = (x.T @ error) @ self.W_dec.T
        dW_dec = (x @ self.W_enc).T @ error               
        
        self.W_dec -= self.learning_rate * dW_dec
        self.W_enc -= self.learning_rate * dW_enc        
        
        self.W_dec = normalize_weights(self.W_dec)
        self.W_enc = normalize_weights(self.W_enc)
    
    def train(self, data, epochs=1000):
        for epoch in range(epochs):
            total_loss = 0
            for x in data:                
                x = np.matrix(x)
                _, x_reconstructed = self.forward(x)
                self.backward(x, x_reconstructed)
                total_loss += np.sum(np.array(x - x_reconstructed) ** 2)            
            print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss}')

In [5]:
# Image compression/decompression pipeline
def compress_image(compression_weights, img_array, channels_amount: int,
                   block_height: int, block_width: int, overlap: float = 0):    
    normalized = (2.0 * img_array.astype(np.float32) / MAX_RGB_VALUE) - 1.0
    blocks = image_to_blocks(normalized, block_height, block_width, overlap)
    blocks = blocks.reshape((len(blocks), block_height * block_width, channels_amount))
    if channels_amount == 3:
        blocks = blocks.transpose(0, 2, 1)    
    blocks = np.einsum('ijk,kl->ijl', blocks, compression_weights)     
    return blocks
    

def decompress_image(decompression_weights, compressed_img, img_shape, channels_amount: int,
                     block_height: int, block_width: int, overlap: float = 0) -> Image.Image:
    compressed_img = np.einsum('ijk,kl->ijl', compressed_img, decompression_weights)
    compressed_img = MAX_RGB_VALUE * (compressed_img + 1.0) / 2.0
    if channels_amount == 3:
        compressed_img = compressed_img.transpose(0, 2, 1)
    compressed_img = compressed_img.reshape((len(compressed_img), block_height, block_width, channels_amount))    
    img_array = blocks_to_image(compressed_img, img_shape, block_height, block_width, overlap)    
    return Image.fromarray(img_array).convert('RGB' if channels_amount == 3 else 'L')

In [7]:
# Collecting everything

block_width = 10
block_height = 10

img = Image.open('test_cat.jpg')
img_array = np.asarray(img)
shape = img_array.shape
blocks = image_to_blocks(img_array, block_height, block_width, overlap=0)

one_color = blocks[:, : ,:, 0]
one_color = (2 * one_color / MAX_RGB_VALUE) - 1
one_color = one_color.reshape((len(blocks), 10 * 10))

train = np.matrix(one_color[np.random.choice(one_color.shape[0], int(one_color.shape[0] * 0.2))])
train.shape
train[0].shape

network = LRNN(100, 64, 0.001)
network.train(train, 150)

compressed = compress_image(network.W_enc, img_array, COLOR_CHANNELS_AMOUNT, 10, 10, 0)
dimg = decompress_image(network.W_dec, compressed, shape, COLOR_CHANNELS_AMOUNT, 10, 10, 0)
dimg_array = np.asarray(dimg)
dimg.save('compression-decompression_test.jpg')

Epoch 1/150, Loss: 3065.939340363546
Epoch 2/150, Loss: 1692.9219769108913
Epoch 3/150, Loss: 1422.8208767135925
Epoch 4/150, Loss: 1247.0073156453743
Epoch 5/150, Loss: 1122.6236225121174
Epoch 6/150, Loss: 1031.2792865459267
Epoch 7/150, Loss: 961.5725316438369
Epoch 8/150, Loss: 906.0676450035093
Epoch 9/150, Loss: 860.0709914762871
Epoch 10/150, Loss: 820.6942862532896
Epoch 11/150, Loss: 786.1590530818236
Epoch 12/150, Loss: 755.342457309445
Epoch 13/150, Loss: 727.5051744111995
Epoch 14/150, Loss: 702.1352583793844
Epoch 15/150, Loss: 678.859746659685
Epoch 16/150, Loss: 657.3942512346034
Epoch 17/150, Loss: 637.5135895583261
Epoch 18/150, Loss: 619.034126988028
Epoch 19/150, Loss: 601.802743559586
Epoch 20/150, Loss: 585.6896366828572
Epoch 21/150, Loss: 570.5834105074372
Epoch 22/150, Loss: 556.3875751192836
Epoch 23/150, Loss: 543.0179483582853
Epoch 24/150, Loss: 530.4006595056763
Epoch 25/150, Loss: 518.4705714097863
Epoch 26/150, Loss: 507.1700054979708
Epoch 27/150, Loss: 