In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from PIL import Image
from torchvision.transforms import ToTensor, ToPILImage

In [None]:
class UNetEncoder(nn.Module):
    """
    Encoder for the U-Net-based autoencoder.

    Four contracting blocks, each ending in a 2x2 max-pool with stride 2, map an
    input of shape ``(N, input_channels, H, W)`` to
    ``(N, features * 8, H // 16, W // 16)`` (for odd ``kernel_size``, including
    the default 3). Only the deepest feature map is returned; no skip
    connections are passed on.

    Args:
        input_channels: number of channels of the input images.
        features: channel width of the first block; it doubles in each later block.
    """
    def __init__(self, input_channels, features=64):
        super(UNetEncoder, self).__init__()
        self.enc1 = self.contract_block(input_channels, features)
        self.enc2 = self.contract_block(features, features * 2)
        self.enc3 = self.contract_block(features * 2, features * 4)
        self.enc4 = self.contract_block(features * 4, features * 8)

    def contract_block(self, in_channels, out_channels, kernel_size=3):
        """
        Two (conv -> ReLU -> BatchNorm) stages followed by a 2x downsampling max-pool.

        ``padding = kernel_size // 2`` keeps the spatial size unchanged through
        both convolutions for any odd ``kernel_size``; a fixed padding of 1 would
        only do so for ``kernel_size=3``.
        """
        padding = kernel_size // 2
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        # Intermediate maps are not kept: only the bottleneck goes to the decoder.
        x = self.enc1(x)
        x = self.enc2(x)
        x = self.enc3(x)
        return self.enc4(x)

In [None]:
class UNetDecoder(nn.Module):
    """
    Decoder for the U-Net-based autoencoder.

    Four expand blocks each double the spatial size with a stride-2 transposed
    convolution (x16 in total, matching the four 2x max-pools of the encoder in
    this notebook). A final 1x1 convolution maps ``features`` channels to
    ``output_channels``. Input ``(N, features * 8, h, w)`` becomes
    ``(N, output_channels, 16 * h, 16 * w)``. There is no output activation,
    so values are not bounded to [0, 1].

    Args:
        output_channels: number of channels of the reconstructed image.
        features: base channel width; must equal the encoder's ``features``.
    """
    def __init__(self, output_channels, features=64):
        super(UNetDecoder, self).__init__()
        self.dec1 = self.expand_block(features * 8, features * 4)
        self.dec2 = self.expand_block(features * 4, features * 2)
        self.dec3 = self.expand_block(features * 2, features)
        # Fourth upsampling stage. The encoder downsamples by 16, so with only
        # three stages (x8) the output would be half the input size and the
        # reconstruction loss against the input could not be computed.
        self.dec4 = self.expand_block(features, features)
        self.final = nn.Conv2d(features, output_channels, kernel_size=1)

    def expand_block(self, in_channels, out_channels, kernel_size=3):
        """
        Two (conv -> ReLU -> BatchNorm) stages followed by a 2x upsampling
        transposed convolution.

        ``padding = kernel_size // 2`` keeps the spatial size unchanged through
        both convolutions for any odd ``kernel_size``.
        """
        padding = kernel_size // 2
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.ConvTranspose2d(out_channels, out_channels, kernel_size=2, stride=2),
        )

    def forward(self, x):
        x = self.dec1(x)
        x = self.dec2(x)
        x = self.dec3(x)
        x = self.dec4(x)
        return self.final(x)

In [None]:
class Autoencoder(nn.Module):
    """
    Autoencoder combining the U-Net encoder and decoder.

    Args:
        input_channels: number of channels of the input images.
        output_channels: number of channels of the reconstruction.
        features: base channel width passed to both encoder and decoder. It is
            shared so the two always agree, since the decoder consumes the
            encoder's ``features * 8`` output channels.
    """
    def __init__(self, input_channels, output_channels, features=64):
        super(Autoencoder, self).__init__()
        self.encoder = UNetEncoder(input_channels, features=features)
        self.decoder = UNetDecoder(output_channels, features=features)

    def forward(self, x):
        # Encode to the bottleneck, then decode back to image space.
        return self.decoder(self.encoder(x))

In [None]:
# Configuration
SEED = 42
IMAGE_SIZE = (640, 640)  # (height, width); both must be divisible by 16 since the encoder max-pools 4 times
BATCH_SIZE = 32  # NOTE(review): 640x640 inputs with 64 base features are memory-heavy; lower this on OOM
LEARNING_RATE = 0.001
NUM_EPOCHS = 25
MODEL_PATH = 'model.pth'

torch.manual_seed(SEED)  # makes weight init and DataLoader shuffling reproducible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data preprocessing: resize the PIL image first, then convert it to a [0, 1] float tensor.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

# Data loading. ImageFolder expects one sub-directory per class under each root
# (e.g. train/<class>/<image>); the class labels are ignored by the autoencoder.
train_data = datasets.ImageFolder(root='train', transform=transform)
val_data = datasets.ImageFolder(root='val', transform=transform)
test_data = datasets.ImageFolder(root='test', transform=transform)

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# Model, loss and optimizer. The model is moved to `device` before the optimizer
# is constructed, as the torch.optim documentation recommends.
model = Autoencoder(input_channels=3, output_channels=3).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


# Model training
def train_model(model, criterion, optimizer, num_epochs=25):
    """
    Train ``model`` to reconstruct its input and report per-epoch losses.

    Batches are read from the module-level ``train_loader`` and ``val_loader``.
    Their labels are ignored, and each input batch is also the target. Batches
    are moved to the device holding the model's parameters.

    Args:
        model: network whose output has the same shape as its input.
        criterion: loss callable ``criterion(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        List of ``(train_loss, val_loss)`` tuples, one per epoch. Each value is
        the sample-weighted average of the batch losses over its dataset.
    """
    model_device = next(model.parameters()).device
    history = []
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for inputs, _ in train_loader:
            inputs = inputs.to(model_device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, inputs)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * inputs.size(0)
        train_loss /= len(train_loader.dataset)

        # Validation: eval mode (BatchNorm running stats) and no gradient tracking.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, _ in val_loader:
                inputs = inputs.to(model_device)
                outputs = model(inputs)
                val_loss += criterion(outputs, inputs).item() * inputs.size(0)
        val_loss /= len(val_loader.dataset)

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')
        history.append((train_loss, val_loss))
    return history


history = train_model(model, criterion, optimizer, num_epochs=NUM_EPOCHS)

# Persist the trained weights so the inference step can reload them from MODEL_PATH.
torch.save(model.state_dict(), MODEL_PATH)

In [None]:
# Inference helper: reconstruct every image in a directory and save the results
def infer_and_save(input_dir, output_dir, model, image_size=(640, 640)):
    """
    Run ``model`` on every image file in ``input_dir`` and save each
    reconstruction to ``output_dir`` under the same file name.

    Args:
        input_dir: directory scanned non-recursively. Entries that are not
            regular files (e.g. sub-directories) are skipped.
        output_dir: destination directory; created if it does not exist.
        model: module mapping a ``(1, 3, H, W)`` float batch with values in
            [0, 1] to an image batch of the same layout.
        image_size: ``(height, width)`` each image is resized to before
            inference, matching the ``Resize((640, 640))`` training transform in
            this notebook. Pass ``None`` to feed images at their native size (H
            and W should then be divisible by 16). The saved output is always
            resized back to the original image size.
    """
    os.makedirs(output_dir, exist_ok=True)

    model.eval()  # inference mode: BatchNorm uses its running statistics
    model_device = next(model.parameters()).device  # run where the weights live
    to_tensor, to_pil = ToTensor(), ToPILImage()
    resize = transforms.Resize(image_size) if image_size is not None else None

    for image_name in sorted(os.listdir(input_dir)):
        image_path = os.path.join(input_dir, image_name)
        if not os.path.isfile(image_path):
            continue

        # The model expects 3 input channels, so grayscale / RGBA / palette
        # images are converted to RGB. The context manager closes the file.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        original_size = image.size  # PIL sizes are (width, height)
        if resize is not None:
            image = resize(image)

        input_tensor = to_tensor(image).unsqueeze(0).to(model_device)  # add the batch dimension
        with torch.no_grad():
            output = model(input_tensor)

        # The autoencoder has no output activation, so values can fall outside
        # [0, 1]. ToPILImage scales floats by 255 and casts to uint8, which would
        # wrap out-of-range values, so clamp first.
        output_image = to_pil(output.squeeze(0).clamp(0.0, 1.0).cpu())
        if output_image.size != original_size:
            output_image = output_image.resize(original_size)
        output_image.save(os.path.join(output_dir, image_name))

# Load the trained model (assumes the weights were already saved to 'model.pth',
# e.g. with torch.save(model.state_dict(), 'model.pth')).
model = Autoencoder(input_channels=3, output_channels=3)
# map_location='cpu' lets a checkpoint saved from a GPU model load on a CPU-only
# machine. NOTE: torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load('model.pth', map_location='cpu'))

# Run inference and save the results
infer_and_save('input_images', 'output_images', model)
