In [None]:
import PIL
from PIL import Image
import torch
import numpy as np
from tqdm import tqdm

# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [None]:
# Open a single berry sprite from the local sprites directory
# (presumably a quick check that the data path resolves - confirm).
im = Image.open("../data/external/sprites/items/berries/aguav-berry.png")

In [None]:
#torch.tensor(im)

In [None]:
from torch.utils.data import Dataset, DataLoader
import torchvision
import os
from skimage import io, transform


#ds = datasets.ImageFolder('../data/external/')
class PokemonDataset(Dataset):
    """Dataset of sprite images read from ``<sprites_path>/pokemon``.

    Every regular file directly inside that directory is one sample.
    Sub-directories found there are skipped because they cannot be opened as
    images, and the file names are sorted so that an index always maps to the
    same file regardless of the order ``os.listdir`` happens to return.
    """

    normal_sprites_sub_dir = "pokemon"
    female_sub_dir = "female"  # not read anywhere in this class

    def __init__(self, sprites_path, transform=None):
        """Collect the sprite file names.

        Args:
            sprites_path: Root directory that contains the ``pokemon``
                sub-directory.
            transform: Optional callable applied to each PIL image when a
                sample is fetched.
        """
        self.sprites_path = sprites_path
        self.transform = transform

        sprites_dir = os.path.join(sprites_path, self.normal_sprites_sub_dir)
        # Keep regular files only: listdir() also reports directories, which
        # would be counted by __len__ and then fail inside Image.open().
        self.files = sorted(
            name
            for name in os.listdir(sprites_dir)
            if os.path.isfile(os.path.join(sprites_dir, name))
        )
        # Show a few of the collected names as a quick visual check.
        print(self.files[:10])

    def __len__(self):
        """Return the number of sprite files found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load one sprite.

        Args:
            idx: Integer index, or a tensor holding one.

        Returns:
            ``{'image': image}`` where ``image`` is the RGB PIL image, or
            whatever ``transform`` returns for it when a transform is set.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        path = os.path.join(
            self.sprites_path,
            self.normal_sprites_sub_dir,
            self.files[idx],
        )
        # convert() reads the pixel data, so the file can be closed right after.
        with Image.open(path) as raw:
            image = raw.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return {'image': image}
from torchvision import transforms
# Sprite dataset: every image is resized to 96x96 and converted to a tensor.
ds = PokemonDataset(
    sprites_path='../data/external/sprites',
    transform=transforms.Compose(
        [transforms.Resize((96, 96)), transforms.ToTensor()]
    ),
)

In [None]:
# Size of the first sample's image after the dataset's transform has been applied.
ds[0]['image'].size()

In [None]:
import torch
from torch import nn
from torch.nn import functional as F

class Encoder(torch.nn.Module):
    """Fully connected encoder that flattens its input and maps it to a latent vector.

    Layer widths are ``prod(input_shape) -> 256 -> 128 -> 64 -> 32 -> latent_shape``.
    A ReLU follows every layer, including the last one, so latent values are
    never negative.

    Args:
        input_shape: Shape of one un-batched input, e.g. ``(1, 28, 28)``.
        latent_shape: Number of latent features.
    """

    def __init__(self, input_shape, latent_shape):
        super().__init__()
        # Convert to a plain Python int so nn.Linear stores an int in
        # ``in_features`` instead of a 0-dim tensor.
        flattened_size = int(torch.prod(torch.tensor(input_shape)).item())
        self.dense1 = torch.nn.Linear(flattened_size, 256)
        self.dense2 = torch.nn.Linear(256, 128)
        self.dense3 = torch.nn.Linear(128, 64)
        self.dense4 = torch.nn.Linear(64, 32)
        self.dense5 = torch.nn.Linear(32, latent_shape)
        self.f = nn.Flatten()

    def forward(self, x):
        """Encode ``x``.

        Args:
            x: Tensor whose dimensions after the first flatten to
                ``prod(input_shape)`` values.

        Returns:
            Tensor of shape ``(x.size(0), latent_shape)``.
        """
        x = self.f(x)
        for layer in (self.dense1, self.dense2, self.dense3, self.dense4, self.dense5):
            x = F.relu(layer(x))
        return x

class Decoder(torch.nn.Module):
    """Fully connected decoder that maps a latent vector back to ``output_shape``.

    Layer widths are ``latent_shape -> 32 -> 64 -> 128 -> 256 -> prod(output_shape)``.
    A ReLU follows every layer, including the last one, so outputs are never
    negative and have no upper bound.

    Args:
        latent_shape: Number of latent features.
        output_shape: Shape of one un-batched output, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, latent_shape, output_shape):
        super().__init__()
        self.output_flatten_shape = int(torch.prod(torch.tensor(output_shape)).item())
        self.output_shape = output_shape

        # Only layers that forward() actually uses are registered; an unused
        # layer would still show up in parameters() and state_dict().
        self.dense1 = nn.Linear(latent_shape, 32)
        self.dense2 = nn.Linear(32, 64)
        self.dense3 = nn.Linear(64, 128)
        self.dense4 = nn.Linear(128, 256)
        self.dense5 = nn.Linear(256, self.output_flatten_shape)

    def forward(self, x):
        """Decode ``x``.

        Args:
            x: Tensor whose last dimension has ``latent_shape`` features.

        Returns:
            Tensor of shape ``(-1, *output_shape)``; a 1-D latent vector
            therefore yields a batch of one.
        """
        for layer in (self.dense1, self.dense2, self.dense3, self.dense4, self.dense5):
            x = F.relu(layer(x))
        return torch.reshape(x, [-1, *self.output_shape])

class Auto(torch.nn.Module):
    """Autoencoder that chains an ``Encoder`` and a ``Decoder`` of matching shapes.

    Args:
        input_shape: Shape of one un-batched input; also used as the
            decoder's output shape.
        latent_shape: Number of latent features between the two halves.
    """

    def __init__(self, input_shape, latent_shape):
        super().__init__()
        self.encoder = Encoder(input_shape, latent_shape)
        self.decoder = Decoder(latent_shape, input_shape)

    def forward(self, x):
        """Encode ``x`` and decode the resulting latent vector."""
        return self.decoder(self.encoder(x))

In [None]:
#Image.fromarray(ds[0]['image'])

In [None]:
# Batches of 4 sprites, reshuffled on every pass over the dataset.
dl = DataLoader(dataset=ds, batch_size=4, shuffle=True)

In [None]:
# Size of the 'image' tensor in one batch drawn from the sprite loader.
next(iter(dl))['image'].size()

In [None]:
# Encode one batch of sprites into 8 latent features per image. The input
# shape is written channels-first, (3, 96, 96), the same layout the Decoder
# cell uses; the Encoder only needs the product of the dimensions.
encoder = Encoder((3, 96, 96), 8)
x = encoder(next(iter(dl))['image'])

In [None]:
# Decode the latent batch x; the Decoder reshapes its output to (-1, 3, 96, 96).
decoder = Decoder(8, (3, 96, 96))
y = decoder(x)

In [None]:
# Size of the first decoded sample.
y[0].size()

In [None]:
# Detach the first decoded sample from autograd, move it to the CPU, convert
# it to a PIL image and open it with PIL's show().
im = transforms.ToPILImage()(y[0].detach().cpu().data)
im.show()

In [None]:
# Size of the first decoded sample (same check as an earlier cell).
y[0].size()

In [None]:
# FashionMNIST training split with ToTensor as the only preprocessing step;
# transforms.Normalize((0.5), (0.5)) is currently disabled.
transform = transforms.Compose([transforms.ToTensor()])
train = torchvision.datasets.FashionMNIST(
    root='/tmp',
    train=True,
    transform=transform,
    download=True,
)
trainloader = DataLoader(
    dataset=train,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
    persistent_workers=True,  # keep the worker processes alive between epochs
)



In [None]:
def display(tensor):
    """Convert an image tensor to a PIL image and open it with PIL's ``show()``."""
    to_pil = transforms.ToPILImage()
    to_pil(tensor).show()

In [None]:
latent_size = 4
ae = Auto((1, 28, 28), latent_size).to(device)

# Smoke test of the untrained model on the first training image. No gradients
# are needed here, and the result is moved to the CPU before it is handed to
# display(), like every other display call in this notebook.
with torch.no_grad():
    tensor = ae(train[0][0].to(device))

display(train[0][0])
print(tensor.size())
display(tensor[0].cpu())



In [None]:
rt = torch.rand(latent_size)  # random latent vector; TODO: fix values
print(rt.size())
# Decode once and reuse the result. A bare expression in the middle of a cell
# is not shown, so the size is printed explicitly.
with torch.no_grad():
    decoded = ae.decoder(rt.to(device)).cpu()
print(decoded.size())
display(decoded[0])

In [None]:

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(ae.parameters(), lr=0.001)
for epoch in range(100):
    running_loss = 0.0
    total = 0  # samples actually seen; drop_last=True skips the final partial batch
    ae.train()
    for image, label in tqdm(trainloader):  # label is unused: the target is the image itself
        image = image.to(device)

        optimizer.zero_grad()
        y_pred = ae(image)
        loss = criterion(y_pred, image)
        loss.backward()
        optimizer.step()

        # MSELoss returns the mean over the batch, so weight it by the batch
        # size; dividing the sum by `total` then gives a true per-sample mean.
        batch_size = image.size(0)
        running_loss += loss.item() * batch_size
        total += batch_size
    print(f"loss: {running_loss/total}")

    # Visual check after each epoch: three fixed training images are
    # reconstructed and one random latent vector is decoded.
    ae.eval()
    with torch.no_grad():
        for idx in [0, 100, 1000]:
            display(ae(train[idx][0].to(device))[0].cpu().data)
        rt = torch.rand(latent_size)  # new random latent each epoch; TODO: fix values

        display(ae.decoder(rt.to(device))[0].cpu().data)

In [None]:
# Reconstruct training image 1 with the model in eval mode and gradients disabled.
ae.eval()
with torch.no_grad():
    reconstruction = ae(train[1][0].to(device))
    im = transforms.ToPILImage()(reconstruction[0].cpu().data)
    im.show()

In [None]:
# Size of the autoencoder output for training image 1.
ae(train[1][0].to(device)).size()