In [1]:
import numpy as np
import torch
import torch.nn as nn
from os import listdir
from random import shuffle
import pygame

w, h = 64, 64

pygame 2.5.1 (SDL 2.28.2, Python 3.11.4)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [12]:
class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 8, 3, stride=2, padding=1),  # 8, 32, 32
            nn.ReLU(),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),  # 16, 16, 16
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 32, 8, 8
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # 64, 4, 4
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # 128, 2, 2
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),  # 256, 1, 1
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(8, 3, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [83]:
model = Autoencoder()
criterion = nn.MSELoss()

In [84]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

In [6]:
def img_to_data(file):
    img = pygame.image.load(f'data/{file}')
    r, g, b = pygame.surfarray.array_red(img), pygame.surfarray.array_green(img), pygame.surfarray.array_blue(
        img)
    img = np.array([r, g, b], dtype=np.float32) / 255
    return img


def data_to_img(img):
    res = np.array(img * 255, dtype=np.int32)
    r, g, b = res[0].reshape((3, 64, 64, 1))
    res = np.concatenate([r, g, b], axis=2)
    res = pygame.surfarray.make_surface(res)
    return res

In [85]:
def get():
    files = listdir('data')
    shuffle(files)
    tests = []
    for _ in range(len(files) // 3200):
        imgs = []
        for _ in range(32):
            file = files.pop()
            img = pygame.image.load(f'data/{file}')
            r, g, b = pygame.surfarray.array_red(img), pygame.surfarray.array_green(img), pygame.surfarray.array_blue(
                img)
            img = np.array([r, g, b])
            imgs.append(img / 255)
        imgs = torch.tensor(np.array(imgs, dtype=np.float32))
        tests.append(imgs)
    return tests

In [97]:
num_epochs = 400
for epoch in range(num_epochs):
    losses = []
    for img in get():
        output = model(img)
        loss = criterion(output, img)
        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch:{epoch + 1}, Loss:{sum(losses) / len(losses)}')
zxc = model.encoder(img)
print(zxc.max(), zxc.min())

Epoch:1, Loss:0.014477180590962662
Epoch:2, Loss:0.014556988690267591
Epoch:3, Loss:0.014512542361284004
Epoch:4, Loss:0.014745919307803406
Epoch:5, Loss:0.014562962412395897
Epoch:6, Loss:0.014369944767916904
Epoch:7, Loss:0.014398236594655934
Epoch:8, Loss:0.015047254424323054
Epoch:9, Loss:0.014499389182995348
Epoch:10, Loss:0.01469624469823697
Epoch:11, Loss:0.01458678427426254
Epoch:12, Loss:0.014598760863437373
Epoch:13, Loss:0.014730195460074088
Epoch:14, Loss:0.01473677930805613
Epoch:15, Loss:0.015093446270946194
Epoch:16, Loss:0.014383772552451667
Epoch:17, Loss:0.014654928708777708
Epoch:18, Loss:0.014386252061847378
Epoch:19, Loss:0.01447010286809767
Epoch:20, Loss:0.014687464572489262
Epoch:21, Loss:0.01457569227718255
Epoch:22, Loss:0.014437719796072035
Epoch:23, Loss:0.014604549909777501
Epoch:24, Loss:0.015097227057113367
Epoch:25, Loss:0.01518965994610506
Epoch:26, Loss:0.014894939959049225
Epoch:27, Loss:0.014435376369339578
Epoch:28, Loss:0.014624281567247474
Epoch:2

In [155]:
torch.save(model.state_dict(), 'model1.pt')

In [101]:
files = listdir('data')
shuffle(files)
n = 8
sz = (64 * (n + 2), 64)
w, h = 1, 8
SURF = pygame.Surface((sz[0] * w, sz[1] * h))
for x in range(w):
    for y in range(h):
        file = [files.pop() for _ in range(n)]

        encoded = [model.encoder(torch.tensor(img_to_data(fl).reshape((1, 3, 64, 64)))) for fl in file]
        new = torch.tensor(encoded[0])
        for i in encoded[1:]: new += i
        encoded.append(new / n)

        decoded = [data_to_img(model.decoder(i).detach().numpy()) for i in encoded]
        surf = pygame.Surface(sz)
        for i in range(n + 1):
            surf.blit(decoded[i], (i * 64, 0))
        SURF.blit(surf, (x * sz[0], y * sz[1]))
pygame.image.save(SURF, 'shuffle.jpg')


  new = torch.tensor(encoded[0])


In [98]:
files = listdir('data')

shuffle(files)
n = 0
W, H = 20, 5
w, h = 64, 64
dis = pygame.display.set_mode((W * w, 2 * H * h))
for x in range(W):
    for y in range(H):
        img = pygame.image.load(f'data/{files[n]}')
        dis.blit(img, (x * w, 2 * y * h))
        r, g, b = pygame.surfarray.array_red(img), pygame.surfarray.array_green(img), pygame.surfarray.array_blue(img)
        img = pygame.surfarray.array3d(img)
        res = np.array([r, g, b], dtype=np.float32) / 255
        res = torch.tensor(res.reshape((1, 3, w, h)))
        res = model(res).detach().numpy()
        res = np.array(res * 255, dtype=np.int32)
        r, g, b = res[0].reshape((3, w, h, 1))
        res = np.concatenate([r, g, b], axis=2)
        res = pygame.surfarray.make_surface(res)
        dis.blit(res, (x * w, 2 * y * h + h))
        n += 1
while True:
    for ev in pygame.event.get():
        if ev.type == pygame.QUIT:
            pygame.image.save(dis, 'res.jpg')
            pygame.display.quit()
    pygame.display.update()

error: video system not initialized

In [52]:
files = listdir('data')
shuffle(files)
n = 0
W, H = 16, 8
w, h = 64, 64
mx, const = 30, 1000_000
dis = pygame.display.set_mode((W * w, H * h))
for x in range(W):
    for y in range(H):
        res = torch.randint(-const * mx, const * mx, (1, 256, 1, 1)) / const
        #res = torch.randn((1, 256, 1, 1))
        res = model.decoder(res).detach().numpy()
        res = np.array(res * 255, dtype=np.int32)
        r, g, b = res[0].reshape((3, w, h, 1))
        res = np.concatenate([r, g, b], axis=2)
        res = pygame.surfarray.make_surface(res)
        dis.blit(res, (x * w, y * h))
        n += 1
while True:
    for ev in pygame.event.get():
        if ev.type == pygame.QUIT:
            pygame.display.quit()
    pygame.display.update()

error: video system not initialized