In [22]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from skimage import color
import random


In [23]:

class ColorizationModel(nn.Module):
    """Small encoder-decoder CNN that predicts ab chroma from a single L channel.

    Expects an ``(N, 1, H, W)`` batch (this notebook feeds ``L / 50 - 1``) and
    returns an ``(N, 2, H, W)`` tensor in ``[-1, 1]`` (final ``Tanh``), which the
    rest of the notebook interprets as ab divided by 110.

    The encoder downsamples by 4 (two stride-2 convolutions) and the decoder
    upsamples by 4, so for H and W divisible by 4 the output already has the
    input's spatial size; any other size is resized back to ``(H, W)``.
    """

    def __init__(self):
        super().__init__()
        # No fixed-size pooling at the end: the bottleneck keeps H/4 x W/4, so the
        # output resolution follows the input instead of always being 256x256.
        # The removed pool was the last encoder entry and had no parameters, so
        # the state_dict keys (encoder.0/2/4/6/8, decoder.1/4/6/8) are unchanged
        # and checkpoints saved with the previous definition still load.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=2, padding=1),    # H/2
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, stride=2, padding=1),  # H/4
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
        )

        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2),                  # H/2
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),                  # H
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        """Return the ab prediction for ``x`` with the same spatial size as ``x``."""
        out = self.decoder(self.encoder(x))
        # Sizes not divisible by 4 lose a pixel to the stride-2 rounding; resize
        # back so the prediction always aligns with the L channel. Bilinear
        # interpolation is a convex combination, so values stay within [-1, 1].
        if out.shape[-2:] != x.shape[-2:]:
            out = nn.functional.interpolate(
                out, size=tuple(x.shape[-2:]), mode='bilinear', align_corners=False
            )
        return out





In [40]:
class Trainer:
    """Trains a colorization model on ``(L, ab)`` batches and colorizes single images.

    The normalization matches ``ColorizationDataset``: the model input is
    ``L / 50 - 1`` and its output is treated as ab divided by 110.
    """

    def __init__(self, model, device):
        """Move ``model`` to ``device`` and keep references to both.

        Args:
            model: ``nn.Module`` mapping ``(N, 1, H, W)`` L tensors to ``(N, 2, H, W)`` ab tensors.
            device: ``torch.device`` used for training and inference.
        """
        self.model = model.to(device)
        self.device = device
        # NOTE(review): this pipeline is not used by fit() or predict(); augmentation
        # currently happens inside the Dataset. Kept so existing references still work.
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.RandomResizedCrop(256, scale=(0.8, 1.0))
        ])

    def fit(self, train_loader, val_loader=None, learning_rate=1e-4, epochs=100,
            save_path='colorization_model.pth'):
        """Train with MSE loss and Adam, print per-epoch losses, then save the weights.

        Args:
            train_loader: iterable of ``(L, ab)`` batches with ``len()`` = number of batches.
            val_loader: optional iterable of ``(L, ab)`` batches; when ``None``,
                validation is skipped and reported as ``n/a``.
            learning_rate: Adam learning rate.
            epochs: number of passes over ``train_loader``.
            save_path: file the final ``state_dict`` is written to.
        """
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)

        for epoch in range(epochs):
            self.model.train()
            train_loss = 0.0
            for L, ab in train_loader:
                L = L.to(self.device)
                ab = ab.to(self.device)

                optimizer.zero_grad()
                loss = criterion(self.model(L), ab)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
            # Mean per-batch loss; max(..., 1) avoids ZeroDivisionError on an empty loader.
            avg_train = train_loss / max(len(train_loader), 1)

            self.model.eval()
            if val_loader is not None:
                val_text = f'{self._evaluate(val_loader, criterion):.4f}'
            else:
                val_text = 'n/a'

            print(f'Epoch {epoch+1}/{epochs}')
            print(f'Train Loss: {avg_train:.4f} | Val Loss: {val_text}')

        torch.save(self.model.state_dict(), save_path)

    def _evaluate(self, loader, criterion):
        """Return the mean per-batch ``criterion`` value over ``loader`` without gradients."""
        total = 0.0
        with torch.no_grad():
            for L, ab in loader:
                L = L.to(self.device)
                ab = ab.to(self.device)
                total += criterion(self.model(L), ab).item()
        return total / max(len(loader), 1)

    def predict(self, image_path):
        """Colorize the image at ``image_path`` and return it as a 256x256 RGB PIL image.

        The image is resized to 256x256 and only its L channel is kept. The model
        predicts ab, and the result is converted back to RGB.
        """
        self.model.eval()

        # Load as RGB; the context manager closes the underlying file handle.
        with Image.open(image_path) as src:
            image = src.convert("RGB")

        # Deterministic preprocessing: no random augmentation at inference time.
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])
        image_tensor = transform(image)  # [3, H, W], range [0, 1]

        # Convert to LAB: L in [0, 100], a/b roughly in [-128, 127].
        img_np = image_tensor.permute(1, 2, 0).numpy()  # [H, W, 3]
        lab_img = color.rgb2lab(img_np)

        # Keep only L and normalize it to [-1, 1] exactly as during training.
        L = lab_img[:, :, 0]
        L_tensor = torch.from_numpy(L).unsqueeze(0).unsqueeze(0).float().to(self.device)  # [1, 1, H, W]
        L_normalized = L_tensor / 50.0 - 1.0

        with torch.no_grad():
            pred_ab = self.model(L_normalized).cpu()  # [1, 2, H, W]

        # Undo the /110 ab normalization and rebuild the LAB image from the original L.
        ab = pred_ab.squeeze(0).permute(1, 2, 0).numpy().astype(np.float64) * 110.0  # [H, W, 2]
        lab_pred = np.concatenate([L[:, :, None], ab], axis=2)

        # Clip into the valid LAB ranges before converting back.
        lab_pred[:, :, 0] = np.clip(lab_pred[:, :, 0], 0, 100)
        lab_pred[:, :, 1:] = np.clip(lab_pred[:, :, 1:], -128, 127)

        rgb_pred = color.lab2rgb(lab_pred)  # float RGB in [0, 1]
        # Round to the nearest 8-bit value (astype alone truncates); clip guards float overshoot.
        rgb_u8 = np.clip(np.rint(rgb_pred * 255.0), 0, 255).astype(np.uint8)

        return Image.fromarray(rgb_u8)

In [41]:

class ColorizationDataset(Dataset):
    def __init__(self, paths, output_size=(256, 256)):  # <-- Добавлен конструктор
        self.paths = paths
        self.output_size = output_size
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2)
        ])

    def __getitem__(self, idx):
        try:
            img = Image.open(self.paths[idx]).convert('RGB')
            
            # Адаптивное изменение размера
            if min(img.size) < 256:
                img = transforms.Resize(256)(img)
                
            # Случайное кадрирование
            i, j, h, w = transforms.RandomCrop.get_params(
                img, output_size=self.output_size
            )
            img = transforms.functional.crop(img, i, j, h, w)
            
            # Применение аугментаций
            img = self.transform(img)
            
            # Преобразование в LAB
            img_lab = color.rgb2lab(np.asarray(img))  # <-- Используем np.asarray
            img_lab = torch.from_numpy(img_lab).float()  # <-- Явное преобразование
            img_lab = img_lab.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
            
            L = img_lab[0:1, :, :] / 50.0 - 1.0
            ab = img_lab[1:3, :, :] / 110.0
            return L.float(), ab.float()
            
        except Exception as e:
            print(f"Ошибка загрузки {self.paths[idx]}: {e}")
            return None

    def __len__(self):
        return len(self.paths)

In [42]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Используется устройство: {device}")

Используется устройство: cpu


In [43]:
# Dataset location: set the COLORIZATION_DATA_ROOT environment variable instead of
# editing this machine-specific default path.
data_root = os.environ.get(
    "COLORIZATION_DATA_ROOT",
    "C:/Users/mr_bs/.cache/kagglehub/datasets/nelyg8002000/commercial-aircraft-dataset/versions/1/1_Liner TF",
)

# File types handed to the dataset; anything else in the folder is ignored.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

# os.listdir returns entries in arbitrary order; sorting makes slices such as
# train_paths[:100] select the same files on every run.
train_paths = sorted(
    os.path.join(data_root, fname)
    for fname in os.listdir(data_root)
    if os.path.splitext(fname)[1].lower() in IMAGE_EXTENSIONS
)

In [44]:
# Small subset for quick experiments; set N_SAMPLES = None to use every image.
N_SAMPLES = 100
BATCH_SIZE = 8

train_dataset = ColorizationDataset(train_paths[:N_SAMPLES])
# Fail fast here rather than "training" on an empty loader.
assert len(train_dataset) > 0, f"no training images found under {data_root!r}"
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [45]:
# Build the network and wrap it in a Trainer bound to the selected device.
model = ColorizationModel()
trainer = Trainer(model=model, device=device)

In [46]:
trainer.fit(
    train_loader,
    val_loader=None,  # TODO: pass a validation DataLoader here to get a real validation loss
    learning_rate=1e-4,
    epochs=3,
)

Epoch 1/3
Train Loss: 0.0194 | Val Loss: 0
Epoch 2/3
Train Loss: 0.0129 | Val Loss: 0
Epoch 3/3
Train Loss: 0.0116 | Val Loss: 0


In [None]:
# Single-image dataset for ad-hoc inspection; not used by later cells.
example = ColorizationDataset(paths=["TF_PLANE_198.jpg"])

In [49]:
colorized = trainer.predict("TF_PLANE_100.jpg")
# As the cell's last expression, the PIL image renders inline in Jupyter;
# Image.show() would instead hand the image to an external viewer.
colorized