Пожалуйста, перед началом работы создайте файл `config.json` с API-токеном Вашего Telegram-бота:
```json
{
    "telegram_token": "YOUR_API_KEY"
}
```

In [3]:
!pip install torch torchvision --quiet
!pip install python-telegram-bot --quiet
!apt install -y libjpeg-dev zlib1g-dev
!pip install gdown

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m708.7/708.7 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
libjpeg-dev is already the newest version (8c-2ubuntu10).
zlib1g-dev is already the newest version (1:1.2.11.dfsg-2ubuntu9.2).
0 upgraded, 0 newly installed, 0 to remove and 35 not upgraded.


In [4]:
# Imports
import asyncio
import json
import os
from datetime import datetime
from io import BytesIO

import gdown
import torch
import torch.nn as nn
from PIL import Image
from telegram import ReplyKeyboardMarkup, Update
from telegram.ext import (
    ApplicationBuilder,
    CommandHandler,
    ContextTypes,
    MessageHandler,
    filters,
)
from torchvision import transforms

In [5]:
# Pretrained AdaIN weights (~140 MB): base decoder, VGG encoder and per-artist decoders.
url = 'https://drive.google.com/uc?id=1iCbjhPnJ2sJqGzcjU2L5ZJOZZt7JKXyi'
output = 'model_weights.zip'
# Skip the download when the archive is already present so re-running the notebook is cheap.
if not os.path.exists(output):
    gdown.download(url, output, quiet=False)

Downloading...
From (original): https://drive.google.com/uc?id=1iCbjhPnJ2sJqGzcjU2L5ZJOZZt7JKXyi
From (redirected): https://drive.google.com/uc?id=1iCbjhPnJ2sJqGzcjU2L5ZJOZZt7JKXyi&confirm=t&uuid=d38f9e8c-d79e-4588-bab7-7ed98de3a181
To: /content/model_weights.zip
100%|██████████| 140M/140M [00:01<00:00, 94.7MB/s]


'model_weights.zip'

In [6]:
!unzip model_weights.zip

Archive:  model_weights.zip
   creating: model_weights/
  inflating: model_weights/decoder.pth  
  inflating: model_weights/decoder_van_gogh.pth  
  inflating: model_weights/decoder_picasso.pth  
  inflating: model_weights/decoder_monet.pth  
  inflating: model_weights/vgg_normalised.pth  
  inflating: model_weights/decoder_trained.pth  


In [7]:
# Bundled content/style sample images, including the pre-saved artist styles used by the bot.
url = 'https://drive.google.com/uc?id=1TdpdEmBiA267KclBplEvOczgX4HGRbFj'
output = 'test_images.zip'
# Skip the download when the archive is already present so re-running the notebook is cheap.
if not os.path.exists(output):
    gdown.download(url, output, quiet=False)

Downloading...
From: https://drive.google.com/uc?id=1TdpdEmBiA267KclBplEvOczgX4HGRbFj
To: /content/test_images.zip
100%|██████████| 1.92M/1.92M [00:00<00:00, 30.1MB/s]


'test_images.zip'

In [8]:
!unzip test_images.zip

Archive:  test_images.zip
   creating: test_images/
   creating: test_images/content/
  inflating: test_images/content/dancing.jpg  
  inflating: test_images/bot_interface.png  
   creating: test_images/style/
  inflating: test_images/style/picasso.jpg  
  inflating: test_images/style/van_gogh.jpg  
  inflating: test_images/style/monet.jpg  
  inflating: test_images/style/style.jpg  
  inflating: test_images/bot_transfer.png  
  inflating: test_images/bot_van_gogh.png  


In [9]:
# Image I/O: every input image is resized so its shorter side is 512 px and then
# converted to a float CHW tensor with values in [0, 1].
transform = transforms.Compose([transforms.Resize(512), transforms.ToTensor()])


def load_image(image_bytes):
    """
    Decode encoded image bytes (a Telegram download or a file's contents) into a tensor.

    The image is converted to RGB first, so the result always has 3 channels
    regardless of the source mode (RGBA, greyscale, palette, ...).

    Returns:
        Tensor of shape (3, H, W) preprocessed by ``transform``.
    """
    with Image.open(BytesIO(image_bytes)) as img:
        rgb = img.convert("RGB")
    return transform(rgb)


In [10]:
# adain_utils

def calc_mean_std(feat, eps=1e-5):
    """
    Per-(sample, channel) mean and standard deviation of a 4-D feature map.

    Args:
        feat (Tensor): Tensor of shape (N, C, H, W); must be viewable (contiguous).
        eps (float): Added to the (unbiased) variance before the square root so the
            std is never zero.

    Returns:
        (mean, std): two tensors shaped (N, C, 1, 1), ready to broadcast over H and W.
    """
    assert feat.dim() == 4
    n, c = feat.size()[:2]
    flat = feat.view(n, c, -1)
    mean = flat.mean(dim=2).view(n, c, 1, 1)
    std = (flat.var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    return mean, std


def adaptive_instance_normalization(content_feat, style_feat):
    """
    AdaIN (Huang & Belongie, 2017): give content features the style's channel statistics.

    Every (sample, channel) slice of ``content_feat`` is standardized with its own
    mean/std and then scaled and shifted by the matching style std/mean.

    Args:
        content_feat (Tensor): (N, C, H, W) content features.
        style_feat (Tensor): (N, C, H', W') style features; N and C must match the
            content, the spatial size may differ.

    Returns:
        Tensor: stylized features with the same shape as ``content_feat``.
    """
    assert content_feat.size()[:2] == style_feat.size()[:2]
    shape = content_feat.size()
    s_mean, s_std = calc_mean_std(style_feat)
    c_mean, c_std = calc_mean_std(content_feat)

    standardized = (content_feat - c_mean.expand(shape)) / c_std.expand(shape)
    return standardized * s_std.expand(shape) + s_mean.expand(shape)


def _calc_feat_flatten_mean_std(feat):
    """
    Flatten 3D feature map and compute per-channel mean and std.
    """
    assert (feat.size()[0] == 3)
    assert (isinstance(feat, torch.FloatTensor))
    feat_flatten = feat.view(3, -1)
    mean = feat_flatten.mean(dim=-1, keepdim=True)
    std = feat_flatten.std(dim=-1, keepdim=True)
    return feat_flatten, mean, std


def _mat_sqrt(x):
    """
    Compute the matrix square root using SVD.
    """
    U, D, V = torch.svd(x)
    return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())


def coral(source, target):
    """
    CORAL colour transfer: give ``source`` the colour statistics of ``target``.

    Both inputs are (3, H, W) CPU float tensors; their spatial sizes may differ.
    ``source`` is standardized per channel, whitened with the inverse square root of
    its channel covariance, re-coloured with the square root of ``target``'s
    covariance, and finally scaled/shifted to ``target``'s per-channel std/mean.

    Returns:
        Tensor with the same shape as ``source``.
    """
    def _standardize(img):
        # Per-channel z-scores plus their 3x3 (unnormalized) covariance with the identity
        # added -- presumably to keep the matrix well-conditioned for inverse().
        flat, mu, sigma = _calc_feat_flatten_mean_std(img)
        z = (flat - mu.expand_as(flat)) / sigma.expand_as(flat)
        cov_eye = torch.mm(z, z.t()) + torch.eye(3)
        return z, mu, sigma, cov_eye

    src_z, _, _, src_cov_eye = _standardize(source)
    _, tgt_mean, tgt_std, tgt_cov_eye = _standardize(target)

    whitened = torch.mm(torch.inverse(_mat_sqrt(src_cov_eye)), src_z)
    recoloured = torch.mm(_mat_sqrt(tgt_cov_eye), whitened)

    result = recoloured * tgt_std.expand_as(src_z) + tgt_mean.expand_as(src_z)
    return result.view(source.size())


def style_transfer(vgg, decoder, content, style, alpha):
    """
    Stylize a batch of content images with AdaIN and decode the result.

    Args:
        vgg: Callable mapping images to encoder features (``Net.encode`` in this notebook).
        decoder: Callable mapping features back to images.
        content, style: Image batches accepted by ``vgg``.
        alpha (float): Style strength in [0, 1]; 1 uses the pure AdaIN features,
            0 reconstructs the content features.

    Returns:
        Whatever ``decoder`` returns for the blended features.
    """
    assert 0.0 <= alpha <= 1.0
    content_feat = vgg(content)
    stylized = adaptive_instance_normalization(content_feat, vgg(style))
    # Linear interpolation in feature space controls the stylization strength.
    blended = stylized * alpha + content_feat * (1 - alpha)
    return decoder(blended)


def process_images(net, content_bytes, style_bytes, alpha, preserve_colors=False):
    """
    Run style transfer on two encoded images and return the result as a PIL image.

    Args:
        net: A ``Net`` whose ``encode`` and ``decoder`` are used; its parameters'
            device decides where inference runs.
        content_bytes, style_bytes: Encoded image data (bytes / bytearray).
        alpha (float): Style strength in [0, 1] (checked by ``style_transfer``).
        preserve_colors (bool): If True, recolour the style image to the content's
            colour statistics first, so only texture is transferred.

    Returns:
        PIL.Image built from the decoder output clamped to [0, 1].
    """
    content_img = load_image(content_bytes)
    style_img = load_image(style_bytes)

    if preserve_colors:
        # Must happen before moving to the device: coral() asserts a CPU FloatTensor.
        style_img = coral(style_img, content_img)

    device = next(net.parameters()).device
    batch_content = content_img.to(device).unsqueeze(0)
    batch_style = style_img.to(device).unsqueeze(0)

    with torch.no_grad():
        stylized = style_transfer(
            net.encode,
            net.decoder,
            batch_content,
            batch_style,
            alpha=alpha
        )

    return transforms.ToPILImage()(stylized.clamp(0, 1).squeeze(0).cpu())




In [11]:
# adain_net


class Decoder(nn.Module):
    """
    Decoder network used to reconstruct an image from AdaIN features.

    Takes a 512-channel feature map (the relu4_1 output produced by ``Net.encode``)
    and maps it back to a 3-channel image. Every ReflectionPad + 3x3 conv pair keeps
    the spatial size, and there are three x2 upsamplings, so the output is 8x the
    input's height and width. The final conv has no activation, so values are not
    bounded; process_images clamps them to [0, 1].

    The layers live in ``self.model``: init_model loads the base checkpoint into
    ``decoder.model`` but the fine-tuned ones into the Decoder itself, so the two
    checkpoint families differ by a ``model.`` key prefix. Reordering or inserting
    layers here changes the state_dict keys and breaks loading.
    """
    def __init__(self):
        super(Decoder, self).__init__()
        self.model = nn.Sequential(
            # 512 -> 256 channels, then upsample x2
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 256, (3, 3)),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            # three 256 -> 256 convs, then 256 -> 128 and upsample x2
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, (3, 3)),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, (3, 3)),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, (3, 3)),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 128, (3, 3)),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            # 128 -> 128, then 128 -> 64 and upsample x2
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 128, (3, 3)),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 64, (3, 3)),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            # 64 -> 64, then project to 3 output channels (no activation)
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 64, (3, 3)),
            nn.ReLU(),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 3, (3, 3)),
        )

    def forward(self, x):
        """Decode a (N, 512, h, w) feature map into a (N, 3, 8h, 8w) image tensor."""
        return self.model(x)


class VGG(nn.Module):
    """
    Modified VGG-19 encoder used to extract content and style features.
    Includes convolutional layers up to relu4_1. Extra layers are present but not used.

    The layer order must match ``vgg_normalised.pth`` (loaded into ``self.model`` by
    init_model). init_model keeps only ``self.model`` children [:31], i.e. up to and
    including relu4_1 at index 30, and Net slices those 31 layers at indices 4, 11
    and 18 (relu1_1 / relu2_1 / relu3_1). Do not reorder layers.
    """
    def __init__(self):
        super(VGG, self).__init__()
        self.model = nn.Sequential(
            # 1x1 conv on the RGB input -- presumably folds in the input normalization
            # baked into vgg_normalised.pth; TODO confirm against the checkpoint.
            nn.Conv2d(3, 3, (1, 1)),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(3, 64, (3, 3)),
            nn.ReLU(),  # relu1-1
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 64, (3, 3)),
            nn.ReLU(),  # relu1-2
            nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 128, (3, 3)),
            nn.ReLU(),  # relu2-1
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 128, (3, 3)),
            nn.ReLU(),  # relu2-2
            nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(128, 256, (3, 3)),
            nn.ReLU(),  # relu3-1
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, (3, 3)),
            nn.ReLU(),  # relu3-2
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, (3, 3)),
            nn.ReLU(),  # relu3-3
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 256, (3, 3)),
            nn.ReLU(),  # relu3-4
            nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(256, 512, (3, 3)),
            nn.ReLU(),  # relu4-1, this is the last layer used
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 512, (3, 3)),
            nn.ReLU(),  # relu4-2
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 512, (3, 3)),
            nn.ReLU(),  # relu4-3
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 512, (3, 3)),
            nn.ReLU(),  # relu4-4
            nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 512, (3, 3)),
            nn.ReLU(),  # relu5-1
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 512, (3, 3)),
            nn.ReLU(),  # relu5-2
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 512, (3, 3)),
            nn.ReLU(),  # relu5-3
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(512, 512, (3, 3)),
            nn.ReLU()  # relu5-4
        )

    def forward(self, x):
        """Run the full stack (through relu5_4); the bot uses the truncated copy instead."""
        return self.model(x)


class Net(nn.Module):
    """
    Style transfer network combining a fixed VGG encoder and a trainable decoder.

    The encoder's children are split into four frozen stages ending at relu1_1,
    relu2_1, relu3_1 and relu4_1 (``enc_1`` .. ``enc_4``). The slice indices assume
    ``encoder`` has the layer layout of ``VGG.model`` (at least 31 children, index 30
    being relu4_1).

    Inference (``process_images``) only uses ``encode`` and ``decoder``; ``forward``
    and the loss helpers implement the AdaIN training objective.
    """
    def __init__(self, encoder, decoder):
        super(Net, self).__init__()
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
        self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
        self.decoder = decoder
        self.mse_loss = nn.MSELoss()

        # fix the encoder: only the decoder's parameters are trainable
        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
            for param in getattr(self, name).parameters():
                param.requires_grad = False

    # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image
    def encode_with_intermediate(self, input):
        """
        Extract intermediate features (relu1_1 to relu4_1) from the input image.

        Each stage is fed the previous stage's output.

        Returns:
            List of 4 feature maps: [relu1_1, relu2_1, relu3_1, relu4_1].
        """
        results = [input]
        for i in range(4):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    # extract relu4_1 from input image
    def encode(self, input):
        """
        Encode input image to relu4_1 feature map using the VGG encoder.

        Same as ``encode_with_intermediate(input)[-1]`` without keeping the
        intermediate maps.

        Returns:
            Feature map after relu4_1.
        """
        for i in range(4):
            input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
        return input

    def calc_content_loss(self, input, target):
        """
        Compute content loss as MSE between generated and target feature maps.

        ``target`` must not require grad: in ``forward`` it is the AdaIN target,
        which is produced by the frozen encoder and treated as a constant.

        Returns:
            Scalar content loss.
        """
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target):
        """
        Compute style loss as the sum of MSE between mean and std of input and target.

        Statistics are per (sample, channel), as computed by ``calc_mean_std``.

        Returns:
            Scalar style loss.
        """
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        return (self.mse_loss(input_mean, target_mean) +
                self.mse_loss(input_std, target_std))

    def forward(self, content, style, alpha=1.0):
        """
        Perform forward pass of the style transfer network.

        Steps: AdaIN on relu4_1 features, decode, re-encode the decoded image, then
        compare its relu4_1 features with the AdaIN target (content loss) and its
        relu1_1..relu4_1 statistics with the style's (style loss).

        Args:
            content: content image tensor
            style: style image tensor
            alpha: interpolation factor between content and style features (0 to 1)

        Returns:
            Content loss and total style loss (summed over the four encoder levels)
        """
        assert 0 <= alpha <= 1
        style_feats = self.encode_with_intermediate(style)
        content_feat = self.encode(content)
        t = adaptive_instance_normalization(content_feat, style_feats[-1])
        t = alpha * t + (1 - alpha) * content_feat

        g_t = self.decoder(t)
        g_t_feats = self.encode_with_intermediate(g_t)

        loss_c = self.calc_content_loss(g_t_feats[-1], t)
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for i in range(1, 4):
            loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
        return loss_c, loss_s


In [12]:
def init_model():
    """
    Build the four style-transfer networks used by the bot.

    All four share a single truncated VGG encoder (up to relu4_1) and differ only in
    the decoder: the general-purpose one plus decoders fine-tuned for Picasso,
    Van Gogh and Monet. Everything is placed on CUDA when available and put in eval mode.

    Returns:
        Tuple ``(net, net_picasso, net_van_gogh, net_monet)`` of ``Net`` instances.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Using GPU:", use_cuda)
    print("GPU name:", torch.cuda.get_device_name(0) if use_cuda else "None")

    def _load(path):
        # The checkpoints come from a downloaded archive: weights_only=True makes
        # torch.load refuse arbitrary pickled objects and accept plain tensors/state dicts only.
        return torch.load(path, map_location=device, weights_only=True)

    # The base checkpoints are state dicts of bare nn.Sequential stacks, so they go into `.model`.
    vgg = VGG()
    vgg.model.load_state_dict(_load('model_weights/vgg_normalised.pth'))
    # Keep only the layers up to relu4_1 (indices 0..30); that's all Net uses.
    encoder = nn.Sequential(*list(vgg.model.children())[:31])
    encoder.to(device).eval()

    base_decoder = Decoder()
    base_decoder.model.load_state_dict(_load('model_weights/decoder.pth'))

    def _fine_tuned_decoder(artist):
        # Fine-tuned checkpoints were saved from a whole Decoder (keys prefixed with "model.").
        dec = Decoder()
        dec.load_state_dict(_load(f'model_weights/decoder_{artist}.pth'))
        return dec

    decoders = [base_decoder] + [_fine_tuned_decoder(a) for a in ('picasso', 'van_gogh', 'monet')]
    return tuple(Net(encoder, dec).to(device).eval() for dec in decoders)


In [None]:
# Root folder for everything the bot persists: per-user settings and result archives.
USER_DATA_DIR = 'user_data'
# JSON mapping of str(user_id) -> settings dict (e.g. 'lang', 'alpha').
USER_DATA_FILE = f"{USER_DATA_DIR}/user_preferences.json"


def load_user_data(path=None):
    """
    Load the persisted per-user settings, creating an empty store on first use.

    Args:
        path: JSON file to read; defaults to ``USER_DATA_FILE``.

    Returns:
        The decoded JSON object (a dict of str(user_id) -> settings), or ``{}`` if the
        file did not exist or was not valid JSON.

    A file that cannot be parsed (e.g. truncated by an interrupted write) is moved
    aside to ``<path>.corrupt`` instead of crashing the notebook at startup, so
    the old content is kept for manual recovery.
    """
    if path is None:
        path = USER_DATA_FILE

    # dirname() is "" for a bare filename, and os.makedirs("") would raise.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    if not os.path.exists(path):
        with open(path, "w", encoding="utf-8") as f:
            json.dump({}, f)
        return {}

    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as exc:
        backup = path + ".corrupt"
        os.replace(path, backup)
        print(f"Warning: {path} is not valid JSON ({exc}); moved to {backup}, starting with empty user data.")
        return {}


def save_user_data(data, path=None):
    """
    Persist the per-user settings store as pretty-printed UTF-8 JSON.

    The data is written to ``<path>.tmp`` first and then moved over ``path`` with
    ``os.replace``, so a crash or a serialization error mid-write leaves the previous
    file intact instead of a truncated one. The temporary file is always removed.

    Args:
        data: JSON-serializable object (the settings dict).
        path: Target file; defaults to ``USER_DATA_FILE``.

    Raises:
        TypeError: if ``data`` is not JSON-serializable (from ``json.dump``).
    """
    if path is None:
        path = USER_DATA_FILE
    tmp_path = path + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, path)
    finally:
        # Only still present if writing or replacing failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def get_user_settings(user_data, user_id):
    """
    Return the stored settings dict for ``user_id`` (keys are stringified ids).

    The returned dict is the live entry of ``user_data`` when it exists; a fresh
    empty dict (not stored anywhere) is returned for unknown users.
    """
    key = str(user_id)
    if key in user_data:
        return user_data[key]
    return {}


def update_user_settings(user_data, user_id, updates: dict):
    """
    Merge ``updates`` into the settings of ``user_id`` and persist the whole store.

    A settings entry is created for users that have none yet. Existing keys not
    mentioned in ``updates`` are kept.
    """
    user_data.setdefault(str(user_id), {}).update(updates)
    save_user_data(user_data)


def save_user_images(user_id: str, content: bytes, style: bytes, output: BytesIO):
    """
    Archive one style-transfer run on disk.

    Files are written to ``USER_DATA_DIR/<user_id>/result_<YYYY-mm-dd_HH-MM-SS>/`` as
    ``content.jpg``, ``style.jpg`` and ``output.jpg``. The input bytes are stored
    unchanged. ``output`` is written in full, independent of its current position.
    """
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    target_dir = os.path.join(USER_DATA_DIR, user_id, f"result_{stamp}")
    os.makedirs(target_dir, exist_ok=True)

    files = (
        ("content.jpg", content),
        ("style.jpg", style),
        ("output.jpg", output.getvalue()),
    )
    for filename, payload in files:
        with open(os.path.join(target_dir, filename), "wb") as fh:
            fh.write(payload)

    print(f"Saved images for user {user_id} in {target_dir}")

In [14]:
# Localized UI strings, keyed by language code and then message key.
# Texts sent with parse_mode="Markdown" use Telegram's legacy Markdown, where bold
# is a *single* asterisk pair (double asterisks do not render as bold).
# Both languages must define the same keys.
MESSAGES = {
    "en": {
        "welcome": """
👋 Welcome to the Fast Style Transfer Bot!

✨ Choose one of the image generation options:

1️⃣  *Style Transfer* 🎨
Apply a style from any image to your content photo while keeping the original composition.

2️⃣  *Color-Preserving* 🌈
Keep your original photo's colors while applying the style's textures and patterns.

3️⃣  *Select a Style* 🖼️
Select a famous painting style (Van Gogh, Monet, Picasso) and apply it to your photo.

⚙️ You can also adjust the strength of the style transfer using the *"Set alpha"* command.
Choose a value between 0 and 1 — higher values mean stronger stylization.

🌐 To change the bot's language, use the *"Language"* command and select your preferred language.
""",
        "standard_instructions": """
📌 Please follow these steps:

1️⃣ Send the *content* image first
2️⃣ Then send the *style* image

💡 You can also send both images in one message!
""",
        "content_received": "✅ Content image received! Now send the style image.",
        "style_received": "✅ Style image received! Starting style transfer...",
        "processing": "🔄 Performing style transfer...",
        "success": "🎨 Style transfer complete!",
        "error": "⚠️ An error occurred during processing. Please try again.",
        "mode_not_selected": "❌ Please first select a style transfer mode from the menu.",
        "invalid_option": "❌ Please choose one of the available options.",
        "alpha_prompt": "🔧 Please enter a value for alpha (between 0 and 1):",
        "alpha_set": "✅ Alpha set to {alpha}.",
        "alpha_invalid": "❌ Invalid value. Please enter a number between 0 and 1.",
        "language_prompt": "🌍 Please choose your language:",
        "language_set": "✅ Language set to English.",
        "language_invalid": "❌ Invalid choice, please select a language from the keyboard.",
        "choose_style_prompt": "🖼️ Please select a style from the list below:",
        "style_selected": "🎨 Style {style} selected! Now please send the content image.",
        # Used by handle_image when the pre-saved style file cannot be loaded.
        "style_not_selected": "❌ Could not load the selected style. Please choose a style again.",
        "choose_option": "📋 Please select an option from the menu."
    },

    "ru": {
        "welcome": """
👋 Добро пожаловать в бота для переноса стиля!

✨ Выберите один из вариантов:

1️⃣  *Перенос стиля* 🎨
Примените стиль из любого изображения к вашему фото, сохранив оригинальную композицию.

2️⃣  *Сохранение цветов* 🌈
Сохраните оригинальные цвета вашего фото, применяя только текстуры и паттерны стиля.

3️⃣  *Выбрать готовый стиль* 🖼️
Выберите стиль известного художника (Ван Гог, Моне, Пикассо) и примените его к своей фотографии.

⚙️ Вы также можете настроить силу переноса стиля с помощью команды *"Установить alpha"*.
Укажите значение от 0 до 1 — чем больше значение, тем сильнее эффект стилизации.

🌐 Чтобы изменить язык бота, используйте команду *"Язык"* и выберите предпочитаемый язык.
""",
        "standard_instructions": """
📌 Инструкция:

1️⃣ Сначала отправьте *контентное* изображение
2️⃣ Затем отправьте *стилевое* изображение

💡 Можно отправить оба изображения одним сообщением!
""",
        "content_received": "✅ Контентное изображение получено! Теперь отправьте стилевое.",
        "style_received": "✅ Стилевое изображение получено! Начинаю перенос стиля...",
        "processing": "🔄 Выполняю перенос стиля...",
        "success": "🎨 Готово! Перенос стиля выполнен.",
        "error": "⚠️ Произошла ошибка при обработке. Пожалуйста, попробуйте ещё раз.",
        "mode_not_selected": "❌ Сначала выберите режим переноса стиля.",
        "invalid_option": "❌ Пожалуйста, выберите один из доступных вариантов.",
        "alpha_prompt": "🔧 Пожалуйста, введите значение alpha (от 0 до 1):",
        "alpha_set": "✅ Значение alpha установлено на {alpha}.",
        "alpha_invalid": "❌ Неверное значение. Введите число от 0 до 1.",
        "language_prompt": "🌍 Пожалуйста, выберите язык:",
        "language_set": "✅ Язык установлен на русский.",
        "language_invalid": "❌ Неверный выбор, пожалуйста, выберите язык с клавиатуры.",
        "choose_style_prompt": "🖼️ Пожалуйста, выберите стиль из списка ниже:",
        "style_selected": "🎨 Стиль {style} выбран! Теперь отправьте фото для переноса стиля.",
        # Used by handle_image when the pre-saved style file cannot be loaded.
        "style_not_selected": "❌ Не удалось загрузить выбранный стиль. Пожалуйста, выберите стиль заново.",
        "choose_option": "📋 Выберите опцию из списка."
    }
}


def get_message(key, lang='en'):
    """
    Retrieve a localized message by key and language code.

    Falls back to English when the language (or the key within it) is missing, and
    finally to the key itself. The bot must never send an empty text, because
    Telegram rejects those ("message text is empty").
    """
    localized = MESSAGES.get(lang, {})
    if key in localized:
        return localized[key]
    return MESSAGES['en'].get(key, key)


In [15]:
# Reply-keyboard layouts per language. Row 0 holds the three generation modes and
# row 1 the settings; handle_message matches incoming text by these positions,
# so keep the order identical across languages.
KEYBOARD_OPTIONS = {
    'en': [
        ["Style Transfer", "Color-Preserving", "Select a Style"],
        ["Set alpha", "Language"],
    ],
    'ru': [
        ["Перенос стиля", "Сохранение цветов", "Выбрать готовый стиль"],
        ["Установить alpha", "Язык"],
    ],
}

# Button label -> bundled style image; used with the matching fine-tuned decoder.
PRE_SAVED_STYLES = {
    "Van Gogh": "test_images/style/van_gogh.jpg",
    "Monet": "test_images/style/monet.jpg",
    "Picasso": "test_images/style/picasso.jpg",
}

# Telegram Application handle; assigned by main() so the stop cell can reach it.
app = None
# Persisted per-user settings (language, alpha), loaded once when this cell runs.
user_data_store = load_user_data()


# --- Keyboard Helpers ---

def get_language_keyboard():
    """Build the one-shot reply keyboard listing the supported interface languages."""
    language_row = ["English", "Русский"]
    return ReplyKeyboardMarkup([language_row], resize_keyboard=True, one_time_keyboard=True)


def get_styles_keyboard(lang='en'):
    """
    Build a one-shot reply keyboard with the pre-saved style names, two per row.

    ``lang`` is accepted for parity with ``get_keyboard``; style names are not localized.
    """
    rows = []
    for name in PRE_SAVED_STYLES:
        if not rows or len(rows[-1]) == 2:
            rows.append([])
        rows[-1].append(name)
    return ReplyKeyboardMarkup(rows, resize_keyboard=True, one_time_keyboard=True)


def get_keyboard(lang='en'):
    """Build the one-shot main-menu keyboard for ``lang`` (English for unknown codes)."""
    layout = KEYBOARD_OPTIONS[lang] if lang in KEYBOARD_OPTIONS else KEYBOARD_OPTIONS['en']
    return ReplyKeyboardMarkup(layout, resize_keyboard=True, one_time_keyboard=True)

# --- Command Handlers ---


async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """
    Handle the /start command: resolve the language, cancel pending prompts, show the menu.

    Language priority: the saved preference, otherwise Telegram's client language
    ('ru' -> Russian, anything else -> English). A language chosen this way is
    persisted so later messages use it too.
    """
    user_id = str(update.effective_user.id)
    settings = get_user_settings(user_data_store, user_id)

    lang = settings.get('lang')
    if lang is None:
        # Settings may exist without a language (e.g. alpha was set before /start
        # was ever used), so fall back to the client language in that case too.
        lang = 'ru' if update.effective_user.language_code == 'ru' else 'en'
        update_user_settings(user_data_store, user_id, {'lang': lang})

    context.user_data['lang'] = lang
    # /start is the escape hatch from a pending alpha/language prompt; otherwise the
    # next menu button press would still be parsed as an alpha value or language name.
    context.user_data.pop('awaiting_alpha', None)
    context.user_data.pop('awaiting_language', None)

    await update.message.reply_text(
        get_message("welcome", lang),
        reply_markup=get_keyboard(lang),
        parse_mode="Markdown"
    )


# --- Message Handlers ---

async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """
    Handle all non-command text messages.

    Dispatch order:
      1. If an alpha value was requested, parse it (accepts "0.5" as well as "0,5").
      2. If a language choice was requested, apply it.
      3. Otherwise match the text against the main-menu buttons of the user's
         language, or against a pre-saved style name.

    Selecting a generation mode discards any half-collected image pair, so a content
    image sent under one mode is never reused as the style image in another.
    """
    user_id = str(update.effective_user.id)
    user_data = context.user_data
    lang = (user_data.get('lang') or
            get_user_settings(user_data_store, user_id).get('lang', 'en'))
    user_data['lang'] = lang
    text = update.message.text
    keyboard = KEYBOARD_OPTIONS.get(lang, KEYBOARD_OPTIONS['en'])

    def select_mode(mode):
        # A new mode invalidates images collected for the previous one; otherwise a
        # leftover content image would be paired with the next photo as its style.
        user_data['mode'] = mode
        user_data.pop('content_image', None)
        user_data.pop('style_image', None)

    # Alpha input mode
    if user_data.get("awaiting_alpha"):
        try:
            # Accept a decimal comma too ("0,5"), the norm for Russian-speaking users.
            alpha = float(text.replace(',', '.'))
            # NaN/inf also fail this range check.
            if not (0 <= alpha <= 1):
                raise ValueError
            user_data["awaiting_alpha"] = False
            update_user_settings(user_data_store, user_id, {"alpha": alpha})
            await update.message.reply_text(get_message("alpha_set", lang).format(alpha=alpha))
        except ValueError:
            await update.message.reply_text(get_message("alpha_invalid", lang))
        return

    # Language selection mode
    if user_data.get("awaiting_language"):
        lang_map = {
            "english": "en", "английский": "en", "en": "en",
            "русский": "ru", "russian": "ru", "ru": "ru"
        }
        selected = lang_map.get(text.lower())
        if selected:
            lang = selected
            user_data["awaiting_language"] = False
            user_data["lang"] = lang
            update_user_settings(user_data_store, user_id, {"lang": lang})
            await update.message.reply_text(get_message("language_set", lang), reply_markup=get_keyboard(lang))
        else:
            await update.message.reply_text(get_message("language_invalid", lang), reply_markup=get_language_keyboard())
        return

    # Main keyboard options (positions follow KEYBOARD_OPTIONS)
    if text == keyboard[0][0]:  # Style Transfer
        select_mode('standard')
        await update.message.reply_text(get_message("standard_instructions", lang), parse_mode="Markdown")
    elif text == keyboard[0][1]:  # Color-Preserving
        select_mode('color_preserving')
        await update.message.reply_text(get_message("standard_instructions", lang), parse_mode="Markdown")
    elif text == keyboard[0][2]:  # Select Style
        await update.message.reply_text(get_message("choose_style_prompt", lang),
                                        reply_markup=get_styles_keyboard(lang))
    elif text in PRE_SAVED_STYLES:  # Pre-saved style selected
        user_data['selected_style_path'] = PRE_SAVED_STYLES[text]
        select_mode('selected_style')
        await update.message.reply_text(get_message("style_selected", lang).format(style=text), parse_mode="Markdown")
    elif text == keyboard[1][0]:  # Set alpha
        user_data["awaiting_alpha"] = True
        await update.message.reply_text(get_message("alpha_prompt", lang))
    elif text == keyboard[1][1]:  # Change language
        user_data["awaiting_language"] = True
        await update.message.reply_text(get_message("language_prompt", lang), reply_markup=get_language_keyboard())
    else:
        await update.message.reply_text(get_message("invalid_option", lang))


async def handle_image(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """
    Handle incoming photo messages.

    * Pre-saved style mode: every photo is a content image; the style comes from
      the bundled file and the transfer runs immediately.
    * Standard / colour-preserving modes: the first photo is stored as the content
      image and the second as the style image, which triggers the transfer.

    Photos sent before any mode was chosen are answered with a hint.
    """
    user_data = context.user_data
    user_id = str(update.effective_user.id)
    lang = user_data.get('lang') or get_user_settings(user_data_store, user_id).get('lang', 'en')
    user_data['lang'] = lang

    if 'mode' not in user_data:
        await update.message.reply_text(get_message("mode_not_selected", lang))
        return

    photo = await update.message.photo[-1].get_file()
    byte_img = await photo.download_as_bytearray()

    if user_data['mode'] == 'selected_style':
        # Always treat the photo as content here: an image left over from another
        # mode must not turn this photo into a "style" image.
        user_data['content_image'] = byte_img
        user_data['media_group_id'] = update.message.media_group_id
        try:
            with open(user_data['selected_style_path'], 'rb') as f:
                user_data['style_image'] = f.read()
        except (KeyError, OSError) as e:
            print(f"Error reading style file: {e}")
            # Drop the half-built pair so the next photo starts clean.
            user_data.pop('content_image', None)
            user_data.pop('style_image', None)
            # "error" is defined for every language, so the reply text is never empty.
            await update.message.reply_text(get_message("error", lang))
            return
        await update.message.reply_text(get_message("processing", lang))
        await perform_style_transfer(update, context)
        await update.message.reply_text(get_message("choose_option", lang), reply_markup=get_keyboard(lang))
        return

    if 'content_image' not in user_data:
        user_data['content_image'] = byte_img
        user_data['media_group_id'] = update.message.media_group_id
        # In an album the style photo follows as its own update, so no prompt is needed.
        if not update.message.media_group_id:
            await update.message.reply_text(get_message("content_received", lang))
    else:
        user_data['style_image'] = byte_img
        await update.message.reply_text(get_message("processing", lang))
        await perform_style_transfer(update, context)
        await update.message.reply_text(get_message("choose_option", lang), reply_markup=get_keyboard(lang))


# --- Style Transfer Core ---

async def perform_style_transfer(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """
    Run style transfer on the pending image pair and reply with the result.

    Expects ``user_data['content_image']`` and ``user_data['style_image']`` to hold
    encoded image bytes. The network is the general one, or the fine-tuned artist
    decoder when a pre-saved style is selected. Both pending images are removed
    afterwards, whether or not the transfer succeeded.
    """
    user_data = context.user_data
    user_id = str(update.effective_user.id)
    lang = user_data.get('lang') or get_user_settings(user_data_store, user_id).get('lang', 'en')
    user_data['lang'] = lang

    try:
        preserve_colors = (user_data.get('mode') == 'color_preserving')
        alpha = get_user_settings(user_data_store, user_id).get("alpha", 1.0)

        # Default to general style net
        style_net = context.bot_data['net']

        # If user selected a predefined style, use the corresponding model
        if user_data.get('mode') == 'selected_style':
            selected_path = user_data.get('selected_style_path', '').lower()
            if 'picasso' in selected_path:
                style_net = context.bot_data.get('net_picasso', style_net)
            elif 'van_gogh' in selected_path:
                style_net = context.bot_data.get('net_van_gogh', style_net)
            elif 'monet' in selected_path:
                style_net = context.bot_data.get('net_monet', style_net)

        # Inference can take seconds (especially on CPU). Run it in a worker thread so
        # the event loop -- in this notebook, the kernel's own loop -- keeps running.
        result_image = await asyncio.to_thread(
            process_images,
            net=style_net,
            content_bytes=user_data['content_image'],
            style_bytes=user_data['style_image'],
            preserve_colors=preserve_colors,
            alpha=alpha
        )

        img_bytes = BytesIO()
        result_image.save(img_bytes, format='JPEG')

        # Archiving is best-effort: a disk problem must not cost the user their result.
        try:
            save_user_images(
                user_id=user_id,
                content=user_data['content_image'],
                style=user_data['style_image'],
                output=img_bytes
            )
        except OSError as e:
            print(f"Could not save images for user {user_id}: {e}")

        img_bytes.seek(0)
        await update.message.reply_photo(
            photo=img_bytes,
            caption=get_message("success", lang)
        )
    except Exception as e:
        print(f"Error: {e!r}")
        await update.message.reply_text(get_message("error", lang))
    finally:
        user_data.pop('content_image', None)
        user_data.pop('style_image', None)


# --- Entry Point ---

async def main():
    """
    Start the Telegram bot on the current event loop.

    Reads the bot token from ``config.json``, loads the style-transfer models,
    registers the handlers and starts long polling. Returns once polling has
    started; the module-level ``app`` keeps the Application so it can be stopped
    later (see the stop cell).

    If a bot started by an earlier call is still running (e.g. this cell was run
    twice), it is shut down first. Two pollers with the same token make Telegram
    reject getUpdates with a Conflict error.
    """
    global app
    if app is not None and app.running:
        if app.updater is not None and app.updater.running:
            await app.updater.stop()
        await app.stop()
        await app.shutdown()

    # Read the token before the slow model load so a missing/invalid config fails fast.
    with open('config.json', encoding='utf-8') as f:
        config = json.load(f)
    token = config['telegram_token']

    print("Initializing style transfer models...")
    net, net_picasso, net_van_gogh, net_monet = init_model()

    app = ApplicationBuilder().token(token).build()
    app.bot_data['net'] = net
    app.bot_data['net_picasso'] = net_picasso
    app.bot_data['net_van_gogh'] = net_van_gogh
    app.bot_data['net_monet'] = net_monet

    app.add_handler(CommandHandler("start", start))
    app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message))
    app.add_handler(MessageHandler(filters.PHOTO, handle_image))

    print("Bot is running...")

    await app.initialize()
    await app.start()
    await app.updater.start_polling()


In [16]:
# Start the bot. main() returns as soon as polling has started; the bot then keeps
# serving updates in the background on the kernel's event loop until the stop cell
# below is run.
await main()

Initializing style transfer models...
Using GPU: False
GPU name: None
Bot is running...


In [17]:
# To stop the bot
# await app.updater.stop()
# await app.stop()
# await app.shutdown()