In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torchvision import datasets, transforms

import torchvision
import torchvision.transforms as T

import cv2

import matplotlib.pyplot as plt
import numpy as np
from tqdm import *
import random
from PIL import Image
import traceback

In [None]:
batch_size = 16
cuda = torch.device('cuda:1')
transform_size = 128

In [None]:
# !unzip downloaded.zip -d BigData
import glob

# glob returns paths in arbitrary, filesystem-dependent order; sort them so the
# per-pack sticker lists (and the training pairs built from them) are reproducible.
files = sorted(glob.glob("BigData/downloaded/*/*.webp"))

In [None]:
import pickle

# NOTE: pickle.load can execute arbitrary code - only open metadata you trust.
with open("BigData/downloaded/metadata.pkl", "rb") as metadata_file:
    emojies = pickle.load(metadata_file)

# Replace each pack's per-sticker records with the list of their "emoji" fields,
# keeping the records' iteration order.
# NOTE(review): later cells index these lists by each sticker's numeric file name,
# which assumes the records are stored in sticker-index order - confirm.
for pack_name, stickers in emojies.items():
    emojies[pack_name] = [record["emoji"] for record in stickers.values()]

In [None]:
# Build a deterministic emoji vocabulary (emoji -> integer id).
# `list(set(...))` ordered emojis by string hash, and Python randomizes str hashes
# per process (PYTHONHASHSEED). Ids therefore changed after every kernel restart and
# no longer matched the embedding rows of a pickled model or the ids cached in
# BigData/data.pkl. Sorting yields the same ids in every session.
# NOTE: ids differ from those baked into caches/checkpoints made by the old code;
# regenerate data.pkl and retrain after this change.
# Assumes every emoji value is a str (required for sorting) - TODO confirm.
unique_emojies = sorted({emoji for pack_emojis in emojies.values() for emoji in pack_emojis})

emoji_id = {emoji: index for index, emoji in enumerate(unique_emojies)}

In [None]:
import os
from collections import defaultdict

# Group the sticker files by their sticker-pack directory in a single pass.
# The previous version rescanned every file for every pack (O(packs * files)) and
# matched packs with a '/'-delimited substring test, which breaks with Windows path
# separators.
files_by_pack = defaultdict(list)
for link in files:
    pack_dir = os.path.basename(os.path.dirname(link))
    files_by_pack[pack_dir].append(link)

# all_stickerpacks: pack name -> list of (image path, emoji id)
# link_emoji_stickerpack: flat list of (image path, emoji id, pack name)
all_stickerpacks = {}
link_emoji_stickerpack = []
for stickerpack, pics in emojies.items():
    pairs = []
    for link in files_by_pack.get(stickerpack, []):
        # Files are named "<index>.webp"; the index selects this sticker's emoji.
        sticker_index = int(os.path.splitext(os.path.basename(link))[0])
        label = emoji_id[pics[sticker_index]]
        pairs.append((link, label))
        link_emoji_stickerpack.append((link, label, stickerpack))
    all_stickerpacks[stickerpack] = pairs


In [None]:
def transform_img_cv2(link, size=None):
    """Read an image file and turn it into a normalized 4-channel tensor.

    The image is read with ``cv2.IMREAD_UNCHANGED``, so channels stay in OpenCV's
    BGR(A) order; they are NOT converted to RGB(A).

    Args:
        link: path of the image file.
        size: side length of the square output image; defaults to the global
            ``transform_size`` when ``None``.

    Returns:
        A float32 tensor of shape ``(4, size, size)`` with pixel values divided by 255.
        Images without an alpha channel get a fully opaque alpha channel (the old
        code zero-filled it, i.e. marked the whole image as transparent).

    Raises:
        FileNotFoundError: if OpenCV cannot read or decode the file.
        ValueError: if the decoded image has a channel count other than 1-4.
    """
    if size is None:
        size = transform_size

    img = cv2.imread(link, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports unreadable / undecodable files by returning None.
        raise FileNotFoundError(f"cv2.imread could not read image: {link}")

    resized = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
    if resized.ndim == 2:
        # Single-channel images come back without a channel axis.
        resized = resized[:, :, np.newaxis]

    img_array = resized.astype(np.float32)
    channels = img_array.shape[2]
    if channels in (1, 2):
        # Grayscale (+ optional alpha): replicate gray into the B, G and R channels.
        gray = img_array[:, :, :1]
        img_array = np.concatenate([gray, gray, gray, img_array[:, :, 1:]], axis=2)
        channels = img_array.shape[2]
    if channels == 3:
        # No alpha channel: append an opaque one (255 -> 1.0 after scaling).
        alpha = np.full((size, size, 1), 255.0, dtype=np.float32)
        img_array = np.concatenate([img_array, alpha], axis=2)
    elif channels != 4:
        raise ValueError(f"unsupported number of channels ({channels}) in image: {link}")

    # (H, W, C) -> (C, H, W), the layout nn.Conv2d expects.
    tensor = torch.from_numpy(img_array).permute(2, 0, 1)
    return tensor / 255

In [None]:
class MyDataset(Dataset):
    """Lazily-loaded dataset: images are decoded from disk each time an item is read.

    Args:
        x: sequence of ``(image_path, label_encode, label_decode)`` tuples.
        y: sequence of target image paths, aligned index-by-index with ``x``.

    Items are ``(source_image, label_encode, label_decode, target_image)``.
    """

    def __init__(self, x, y):
        self.X, self.Y = x, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        source_path, label_encode, label_decode = self.X[index]
        source_image = transform_img_cv2(source_path)
        target_image = transform_img_cv2(self.Y[index])
        return source_image, label_encode, label_decode, target_image

class NewDataset(Dataset):
    """In-memory dataset of pre-processed samples (the object pickled to data.pkl).

    Args:
        x: sequence of ``(image, label_encode, label_decode)`` tuples.
        y: sequence of target images, aligned index-by-index with ``x``.

    Items are returned flattened as ``(image, label_encode, label_decode, target)``,
    the same layout as ``MyDataset``. The training loop can then unpack four values
    per batch whichever dataset is in use.
    """

    def __init__(self, x, y):
        self.X = x
        self.Y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        # Returning (x_tuple, y) made the DataLoader yield 2-element batches, which
        # broke the 4-way unpacking in the training loop.
        return (*self.X[index], self.Y[index])


In [None]:
def _pair_with_random_target(load_images):
    """Pair every sticker with a random sticker (the target) from the same pack.

    Args:
        load_images: if True, decode both images now with ``transform_img_cv2``
            (eager, memory hungry); otherwise keep file paths for lazy loading.

    Returns:
        ``(xs, ys)`` where ``xs[i]`` is ``(source image or path, source emoji id,
        target emoji id)`` and ``ys[i]`` is the target image or path.
    """
    xs, ys = [], []
    for picture, emoji_i, stickerpack in link_emoji_stickerpack:
        target_picture, target_emoji = random.choice(all_stickerpacks[stickerpack])
        if load_images:
            picture, target_picture = transform_img_cv2(picture), transform_img_cv2(target_picture)
        xs.append((picture, emoji_i, target_emoji))
        ys.append(target_picture)
    return xs, ys


print("ATTENTION: write YY if you HAVE MORE 12 GB RAM to rewrite data or Y to download data or just skip")
# Case-insensitive so "y"/"yy" behave like "Y"/"YY".
input_ = input().upper()

if "YY" in input_:
    # Decode every image up front and cache the whole dataset on disk. The in-memory
    # copy is used directly; re-reading the file would hold two full copies in RAM.
    data_x, data_y = _pair_with_random_target(load_images=True)
    data = NewDataset(data_x, data_y)
    with open("BigData/data.pkl", "wb") as save_data:
        pickle.dump(data, save_data)
elif "Y" in input_:
    # NOTE: pickle.load can execute arbitrary code - only load caches you created.
    with open("BigData/data.pkl", "rb") as data_file:
        data = pickle.load(data_file)
else:
    # Lazy path: keep file paths, decode images on access.
    data_x, data_y = _pair_with_random_target(load_images=False)
    data = MyDataset(data_x, data_y)

data_loader = DataLoader(data, batch_size=batch_size, shuffle=True)

ATTENTION: write YY if you HAVE MORE 12 GB RAM to rewrite data or Y to download data or just skip


 Y


# Модель


In [None]:
class Autoencoder_Working(nn.Module):
    """Label-conditioned variational autoencoder for 4-channel sticker images.

    The encoder sees the image concatenated with an emoji embedding
    (``label_encode``). The decoder rebuilds an image from the sampled latent
    concatenated with a second emoji embedding (``label_decode``). The spatial
    sizes noted in the layer comments assume a 128x128 input.
    """

    def __init__(self):
        super().__init__()
        # One 64-dim embedding per emoji id, used to condition both encoder and decoder.
        self.embeddings = nn.Embedding(len(unique_emojies), 64)
        self.encoder = nn.Sequential(
            # 4 image channels + 64 label-embedding channels = 68 input channels.
            nn.Conv2d(68, 16, 3, stride=2, padding=1),  # 64
            nn.InstanceNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 64, 3, stride=2, padding=1),  # 32
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 256, 3, stride=2, padding=1),  # 16
            nn.InstanceNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 1024, 3, stride=2, padding=1),  # 8
            nn.InstanceNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(1024, 4096, 3, stride=2, padding=1),  # 4
            nn.InstanceNorm2d(4096),
            # NOTE(review): unlike the other stages there is no ReLU here - confirm intent.
            # Adding one would change the architecture of newly created models.
            nn.Conv2d(4096, 4096, 3, stride=2, padding=1),  # 2
            # The 4096 output channels are split into 2048 means and 2048 log-sigmas.
            nn.InstanceNorm2d(4096)
        )

        self.decoder = nn.Sequential(
            # 2048 latent channels + 64 label-embedding channels.
            nn.ConvTranspose2d(2048 + 64, 4096, 4, stride=2, padding=1),  # 4
            nn.InstanceNorm2d(4096),
            nn.ReLU(),
            nn.ConvTranspose2d(4096, 1024, 4, stride=2, padding=1),  # 8
            nn.InstanceNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 256, 4, stride=2, padding=1),  # 16
            nn.InstanceNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 32, 4, stride=2, padding=1),  # 32
            nn.InstanceNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 8, 4, stride=2, padding=1),  # 64
            nn.InstanceNorm2d(8),
            nn.ReLU(),
            nn.ConvTranspose2d(8, 4, 4, stride=2, padding=1),  # (n,4,128,128)
            nn.InstanceNorm2d(4),
            nn.Sigmoid()
        )

    @staticmethod
    def _broadcast_label(label_emb, x):
        """Tile (N, E) label embeddings over x's spatial dims -> (N, E, H, W).

        ``expand`` creates a view instead of materializing copies like ``repeat``;
        ``torch.cat`` copies the values anyway, so the result is identical.
        """
        return label_emb[:, :, None, None].expand(-1, -1, x.shape[2], x.shape[3])

    def Encoder_func(self, x, label_encode):
        """Encode image batch ``x`` conditioned on ``label_encode`` embeddings (N, 64)."""
        x = torch.cat([x, self._broadcast_label(label_encode, x)], 1)
        return self.encoder(x)  # convolutional encoder

    def Decoder_func(self, x, label_decode):
        """Decode latent batch ``x`` conditioned on ``label_decode`` embeddings (N, 64)."""
        x = torch.cat([x, self._broadcast_label(label_decode, x)], 1)
        return self.decoder(x)

    def _sample_latent(self, h_enc):
        """Split encoder output into (mu, log_sigma) and draw a reparameterized sample.

        Returns:
            ``(sample, mu, sigma)``, where ``sample = mu + sigma * eps`` and
            ``eps ~ N(0, 1)``.
        """
        mu = h_enc[:, :2048]
        log_sigma = h_enc[:, 2048:]
        sigma = torch.exp(log_sigma)

        return mu + sigma * torch.randn_like(sigma), mu, sigma  # Reparameterization trick

    def latent_loss(self, mu, sigma):
        """KL divergence of N(mu, sigma^2) from N(0, 1), averaged over all elements."""
        mean_sq = mu ** 2
        stddev_sq = sigma ** 2
        # NOTE(review): log(sigma**2) is -inf if sigma underflows to 0; computing the
        # term from log_sigma directly would be more stable.
        return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)

    def forward(self, tensor, label_encode, label_decode):
        """Reconstruct ``tensor`` as emoji ``label_decode``.

        Args:
            tensor: image batch (N, 4, H, W).
            label_encode: emoji ids of the input images (tensor or sequence of N ints).
            label_decode: emoji ids to decode into (tensor or sequence of N ints).

        Returns:
            ``(decoded, kl_loss)``.
        """
        device = self.embeddings.weight.device
        # as_tensor avoids the copy and UserWarning that torch.tensor() emits for tensor
        # inputs, and it places plain Python ids on the model's device.
        label_encode = torch.as_tensor(label_encode, dtype=torch.long, device=device)
        label_decode = torch.as_tensor(label_decode, dtype=torch.long, device=device)

        # Encoding
        encoded = self.Encoder_func(tensor, self.embeddings(label_encode))

        # Sample the latent and compute the KL term
        latent, mu, sigma = self._sample_latent(encoded)
        loss = self.latent_loss(mu, sigma)

        # Decoding
        decoded = self.Decoder_func(latent, self.embeddings(label_decode))
        return decoded, loss

In [None]:
print("ATTENTION: write Y if you want to rewrite model else just skip or any symbol")
input_ = input()
# Tolerate surrounding whitespace and lowercase "y".
if input_.strip().upper() == "Y":
    model = Autoencoder_Working()
else:
    # NOTE: pickle.load can execute arbitrary code - only load checkpoints you created.
    # A pickled model keeps the device it was saved on.
    with open("BigData/model.pkl", "rb") as md:
        model = pickle.load(md)
# Move to the training device in both cases so parameters match the inputs.
model = model.to(cuda)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)

ATTENTION: write Y if you want to rewrite model else just skip or any symbol


 Y


In [None]:
num_epochs = 5
outputs = []  # mean training loss of each epoch


def _to_display(img_chw):
    """Convert a (4, H, W) BGRA tensor in [0, 1] into an (H, W, 4) RGBA array for imshow."""
    img_hwc = img_chw.detach().permute(1, 2, 0).cpu().numpy()
    # Images were loaded by OpenCV in BGR(A) order; matplotlib expects RGB(A).
    return np.ascontiguousarray(img_hwc[:, :, [2, 1, 0, 3]])


def _save_checkpoint(model_to_save, path="BigData/model.pkl"):
    """Pickle the model to ``path``.

    The model is serialized before the file is opened, so a pickling error cannot
    truncate the last good checkpoint.
    """
    payload = pickle.dumps(model_to_save)
    with open(path, "wb") as md:
        md.write(payload)


try:
    for epoch in range(num_epochs):
        model.train()
        epoch_loss_sum = 0.0
        num_batches = 0
        for tensor, label_encode, label_decode, y in tqdm(data_loader):
            tensor = tensor.to(cuda)
            label_encode = label_encode.to(cuda)
            label_decode = label_decode.to(cuda)
            y = y.to(cuda)

            recon, loss_kld = model(tensor, label_encode, label_decode)
            loss_rec = criterion(recon, y)
            loss = loss_rec + loss_kld

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # detach() keeps the running sum out of the autograd graph.
            epoch_loss_sum += loss.detach()
            num_batches += 1

        if num_batches == 0:
            raise RuntimeError("data_loader produced no batches; nothing to train on")

        # Report the epoch mean rather than the (noisy) loss of the last batch.
        epoch_loss = float(epoch_loss_sum / num_batches)
        print(f'Epoch:{epoch + 1}, Loss:{epoch_loss:.4f}')
        outputs.append(epoch_loss)

        # Save the model
        _save_checkpoint(model)

        # ---- visual check: last reconstruction of the epoch vs. its target ----
        fig, axarr = plt.subplots(1, 2)
        resized = cv2.resize(_to_display(recon[-1]), (512, 512), interpolation=cv2.INTER_AREA)
        axarr[0].imshow(resized)
        axarr[0].set_title("reconstruction")
        axarr[1].imshow(_to_display(y[-1]))
        axarr[1].set_title("target")
        plt.show()

except Exception:
    # Best effort: report the failure but still plot the losses collected so far.
    print('FATAL ERROR:\n', traceback.format_exc())

plt.figure(figsize=(12, 4))
plt.plot(range(len(outputs)), outputs)
plt.xlabel("epoch")
plt.ylabel("mean training loss")
plt.tight_layout()
plt.show()

SyntaxError: ignored

In [None]:
# Serialize first, then write. If pickling fails, the previous checkpoint on disk
# is left intact instead of being truncated by open(..., "wb").
serialized_model = pickle.dumps(model)
with open("BigData/model.pkl", "wb") as md:
    md.write(serialized_model)

In [None]:
def make_stickers(x, y, z, model_path="BigData/model.pkl", size=512):
    """Generate sticker image(s) with the autoencoder pickled at ``model_path``.

    Args:
        x: source image(s) in the format produced by ``transform_img_cv2``: a
            ``(4, H, W)`` tensor or a batch ``(N, 4, H, W)``.
        y: emoji id(s) of the source image(s) (``label_encode``): an int, or a
            sequence/tensor of N ids.
        z: emoji id(s) to generate (``label_decode``): an int, or a sequence/tensor
            of N ids.
        model_path: path of the pickled ``Autoencoder_Working``.
        size: side length of the returned (resized) image(s).

    Returns:
        For a single input image, an ``(size, size, 4)`` numpy array in the same
        BGRA channel order as the training images. For a batched input, a list of
        such arrays.
    """
    # Open read-only: opening with "wb" would truncate the saved model.
    # NOTE: pickle.load can execute arbitrary code - only load checkpoints you created.
    with open(model_path, "rb") as md:
        model = pickle.load(md)
    model.eval()
    device = next(model.parameters()).device

    images = torch.as_tensor(x, dtype=torch.float32, device=device)
    single = images.dim() == 3
    if single:
        images = images.unsqueeze(0)  # the model expects a batch axis
    label_encode = torch.as_tensor(y, dtype=torch.long, device=device).reshape(-1)
    label_decode = torch.as_tensor(z, dtype=torch.long, device=device).reshape(-1)

    with torch.no_grad():
        # forward returns (decoded_images, kl_loss); only the images are needed.
        decoded, _ = model(images, label_encode, label_decode)

    stickers = [
        cv2.resize(img.permute(1, 2, 0).cpu().numpy(), (size, size), interpolation=cv2.INTER_AREA)
        for img in decoded
    ]
    return stickers[0] if single else stickers