In [1]:
import os
import math
import cv2
import torch
import numpy as np
from PIL import Image
from io import BytesIO
import requests
import torch.nn.functional as F
from tqdm import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2

from torch.utils.data import Dataset
from datasets import Dataset as HFDataset

from transformers import Trainer, TrainingArguments, AutoImageProcessor
from datasets import load_dataset

# Expose only CUDA device index 4 to this process.
os.environ.update({"CUDA_VISIBLE_DEVICES": "4"})

# Per-channel statistics passed to the image processor as image_mean / image_std.
CUSTOM_MEAN, CUSTOM_STD = (
    [0.1778, 0.2696, 0.1686],
    [0.0942, 0.0915, 0.0762],
)

# Default `patch_size` of ModelWrapper.
PATCH_SIZE = 256
# NOTE(review): presumably the super-resolution upscaling factor (x4) - confirm.
rescale = 4

INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.5 (you have 1.4.7). Upgrade using: pip install --upgrade albumentations
  from .autonotebook import tqdm as notebook_tqdm
INFO:datasets:PyTorch version 2.6.0+cu124 available.


In [2]:
# Build every split through the local dataset loading script.
dataset = load_dataset(
    path="my_superres_dataset.py",  # file that defines the dataset class
    data_dir="./data_overlapped",   # root folder containing data/ and data_x4/
    trust_remote_code=True,
)
train_ds, valid_ds = dataset["train"], dataset["validation"]

Generating train split: 8658 examples [00:01, 7152.40 examples/s]
Generating validation split: 918 examples [00:00, 8015.26 examples/s]


In [3]:
# Display the summary (features and number of rows) of the training split.
train_ds

Dataset({
    features: ['hr_image', 'lr_image', 'filename'],
    num_rows: 8658
})

In [4]:
def compute_metrics(eval_pred):
    """
    Compute validation metrics (MSE and PSNR) for super-resolution output.

    Args:
        eval_pred: a pair ``(logits, labels)`` of same-shaped arrays/tensors
            holding the reconstructed and the reference images. The PSNR
            formula assumes pixel values normalised to [0, 1] (peak value 1).

    Returns:
        dict with
            "mse":  mean squared error as a Python float,
            "psnr": peak signal-to-noise ratio in dB; ``inf`` when the
                    reconstruction is exact (mse == 0).
    """
    logits, labels = eval_pred
    mse = ((logits - labels) ** 2).mean().item()
    # An exact reconstruction has zero error; 1 / 0.0 would raise
    # ZeroDivisionError and abort the whole evaluation loop.
    psnr = math.inf if mse == 0 else 10 * math.log10(1 / mse)
    return {
        "mse": mse,
        "psnr": psnr,
    }

# def data_collator(features):
#     """
#     Функция для формирования батча.
#     Приводит входные данные к нужному формату для Trainer.
#     """
#     lr_batch = []
#     hr_batch = []
#     for f in features:
#         lr_batch.append(f["lr"])
#         hr_batch.append(f["hr"])

#     lr_batch = torch.stack(lr_batch)
#     hr_batch = torch.stack(hr_batch)
    
#     return {
#         "pixel_values": lr_batch.float(),  # входные данные (x_LR)
#         "labels": hr_batch.float(),        # метка (x_HR)
#     }


In [5]:
def _build_paired_pipeline(augment):
    """Create the albumentations pipeline used for an (HR, LR) image pair.

    Args:
        augment: when true, random flips and a random 90-degree rotation are
            placed in front of the normalisation step.

    Returns:
        An ``A.Compose`` that is called with ``image=<HR>`` and
        ``lr_image=<LR>`` and returns both as tensors.
    """
    steps = []
    if augment:
        steps.extend([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
        ])
    steps.extend([
        # mean 0 / std 1 with albumentations' default max_pixel_value (255)
        # only rescales 8-bit pixels to [0, 1]; no statistics are subtracted.
        A.Normalize(mean=[0., 0., 0.], std=[1., 1., 1.]),
        ToTensorV2(),
    ])
    return A.Compose(
        steps,
        # Register "lr_image" as a second image target so it goes through the
        # same pipeline call as "image".
        additional_targets={"lr_image": "image"},
        # NOTE(review): presumably disabled because HR and LR images differ
        # in size - confirm.
        is_check_shapes=False,
    )


def _apply_paired_pipeline(pipeline, examples):
    """Run ``pipeline`` over every (hr_image, lr_image) pair of a batch.

    Args:
        pipeline: object returned by ``_build_paired_pipeline``.
        examples: batch dict with the columns "hr_image" and "lr_image".

    Returns:
        dict with "pixel_values" (transformed LR images, the model input) and
        "labels" (transformed HR images, the target), both as lists.
    """
    pixel_values, labels = [], []
    for hr_image, lr_image in zip(examples["hr_image"], examples["lr_image"]):
        out = pipeline(image=np.array(hr_image), lr_image=np.array(lr_image))
        labels.append(out["image"])
        pixel_values.append(out["lr_image"])
    return {"pixel_values": pixel_values, "labels": labels}


def transforms(examples):
    """Training transform: random flips / rotation, scaling to [0, 1], tensors.

    The scaling step is the same one ``transforms_val`` applies. Without it the
    training tensors stay uint8 in 0..255 while validation tensors are floats
    in [0, 1].

    Args:
        examples: batch dict with the columns "hr_image" and "lr_image".

    Returns:
        dict with "pixel_values" (LR tensors) and "labels" (HR tensors).
    """
    return _apply_paired_pipeline(_build_paired_pipeline(augment=True), examples)


def transforms_val(examples):
    """Validation transform: scaling to [0, 1] and tensor conversion only.

    Args:
        examples: batch dict with the columns "hr_image" and "lr_image".

    Returns:
        dict with "pixel_values" (LR tensors) and "labels" (HR tensors).
    """
    return _apply_paired_pipeline(_build_paired_pipeline(augment=False), examples)

from torchvision.transforms import ColorJitter

# Colour augmentation object; no active code in the cells shown here calls it.
jitter = ColorJitter(
    brightness=0.25,
    contrast=0.25,
    saturation=0.25,
    hue=0.1,
)

# Image processor of the Swin2SR x4 checkpoint, with CUSTOM_MEAN / CUSTOM_STD
# passed as image_mean / image_std.
image_processor = AutoImageProcessor.from_pretrained(
    "caidas/swin2SR-classical-sr-x4-64",
    image_mean=CUSTOM_MEAN,
    image_std=CUSTOM_STD,
)

def train_transforms(example_batch):
    """Batch transform attached to the training split.

    Args:
        example_batch: raw batch dict with the columns "hr_image" and
            "lr_image".

    Returns:
        The dict produced by ``transforms``: "pixel_values" (LR inputs) and
        "labels" (HR targets).
    """
    # `transforms` already returns the final batch. Looking up "lr_image" /
    # "hr_image" on its result raises KeyError, and the name `inputs` was never
    # defined, so the result is returned directly.
    return transforms(example_batch)


def val_transforms(example_batch):
    """Batch transform attached to evaluation data.

    Delegates to ``transforms_val`` and returns its result unchanged, i.e. a
    dict with "pixel_values" and "labels".
    """
    return transforms_val(example_batch)


# Attach the batch transforms to the splits: the augmenting one for training,
# the deterministic one for validation.
train_ds.set_transform(train_transforms)
valid_ds.set_transform(val_transforms)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [6]:
from sr.TTST.model_archs.TTST_arc import TTST

In [7]:
from transformers.modeling_outputs import ImageSuperResolutionOutput
from typing import Optional, Union, Tuple

class ModelWrapper(torch.nn.Module):
    """Adapter exposing a plain super-resolution network through the
    ``pixel_values`` / ``labels`` calling convention used in this notebook.

    The wrapped network is called with ``pixel_values`` only; when ``labels``
    are supplied an L1 loss between its output and the labels is attached to
    the returned ``ImageSuperResolutionOutput``.

    Args:
        model: network mapping a low-resolution batch to its reconstruction.
        patch_size: stored on the instance; not used by ``forward``.
        accepts_loss_kwargs: stored on the instance.
            NOTE(review): presumably inspected by recent ``transformers``
            Trainer versions - confirm.
    """

    def __init__(self, model, patch_size=PATCH_SIZE, accepts_loss_kwargs=False):
        super().__init__()
        self.model = model
        self.accepts_loss_kwargs = accepts_loss_kwargs
        self.patch_size = patch_size
        self.loss = torch.nn.L1Loss()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageSuperResolutionOutput]:
        """Reconstruct ``pixel_values`` and, if possible, compute the loss.

        Args:
            pixel_values: low-resolution input batch.
            labels: high-resolution targets; optional.
            output_attentions, output_hidden_states, return_dict: accepted for
                signature compatibility and ignored.

        Returns:
            ``ImageSuperResolutionOutput`` whose ``reconstruction`` is the
            network output and whose ``loss`` is the L1 loss, or ``None`` when
            no labels were given.
        """
        reconstruction = self.model(pixel_values)
        # `labels` is optional (label-free inference); computing the loss
        # against None would raise inside L1Loss.
        loss = self.loss(reconstruction, labels) if labels is not None else None
        return ImageSuperResolutionOutput(loss=loss, reconstruction=reconstruction)

In [None]:
# Weights & Biases settings for the Trainer integration.
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
os.environ["WANDB_WATCH"] = "all"
os.environ["WANDB_DISABLED"] = "true"

# HF_TOKEN is deliberately NOT assigned here. Credentials must come from
# outside the notebook (shell environment or `huggingface-cli login`): a
# hard-coded value leaks the secret when the notebook is shared, and a
# placeholder overwrites a valid token that is already in the environment.

# os.environ["WANDB_RESUME"] = "must"
# os.environ["WANDB_RUN_ID"] = "h4h4s471"

In [9]:

gpus_list = [0]
model_dir = "sr/TTST/saved_models/ttst_4x.pth"

# The network is wrapped in DataParallel and moved to the GPU before the
# weights are loaded.
# NOTE(review): presumably the checkpoint was saved from a DataParallel model
# (keys prefixed with "module.") - confirm.
model = torch.nn.DataParallel(TTST(), device_ids=gpus_list).cuda()
model.load_state_dict(torch.load(model_dir))

# Adapter that adds the pixel_values / labels interface and the L1 loss.
modelwr = ModelWrapper(model).cuda()

print(f"num_model_parameters:{sum(map(torch.numel, modelwr.parameters()))}")

output_dir = "outputs_sr/TTST"

# Trainer configuration.
training_args = TrainingArguments(
    output_dir=os.path.join(output_dir, "checkpoints"),
    learning_rate=2e-4,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=1000,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50,
    save_total_limit=2,
    fp16=True,  # requires a GPU with fp16 support
    # The batch transforms read the raw "hr_image" / "lr_image" columns, so the
    # Trainer must not drop columns that are missing from forward().
    remove_unused_columns=False,
    # Switch the logging integrations off explicitly. Relying on the
    # deprecated WANDB_DISABLED variable makes this cell depend on another
    # cell having been executed first.
    report_to="none",
)

# The Trainer is used for training as well as for batched prediction.
trainer = Trainer(
    model=modelwr,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    compute_metrics=compute_metrics,
)

# Запуск обучения
# trainer.train()
    
    # # Сохраняем итоговую модель
    # trainer.save_model("./models/swinir_sr_x4_final")
    # print("Модель успешно обучена и сохранена!")
    
    # # Пример инференса на каком-нибудь LR-изображении
    # # Для Swin2SR можно использовать Swin2SRImageProcessor,
    # # но ниже пример универсальный с использованием уже 
    # # готового processor из super_image.
    # print("Делаем пример инференса на тестовом изображении...")
    
    # image_processor = Swin2SRImageProcessor(scale=4)
    
    # url = "https://huggingface.co/datasets/super_image/test_images/resolve/main/butterfly.png"
    # response = requests.get(url)
    # lr_image_pil = Image.open(BytesIO(response.content)).convert("RGB")
    
    # # Преобразуем LR-картинку в модельный формат
    # input_tensor = image_processor(lr_image_pil, return_tensors="pt").pixel_values
    
    # # Получаем супер-разрешённую картинку
    # with torch.no_grad():
    #     preds = model.generate(input_tensor)
    
    # sr_images = image_processor.postprocess(preds)  # Список PIL.Image
    # sr_image = sr_images[0]
    # sr_image.save("sr_result.png")
    
    # print("Супер-разрешённое изображение сохранено как sr_result.png")


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


num_model_parameters:18366555


In [10]:
# Destination folder for the super-resolved patches written by a later cell.
eval_folder = "outputs_sr/data_overlapped_sr_TTST/train_prep"

# Predict on the training split with the deterministic validation transform.
# `eval_dataset` is the same object as `train_ds`, not a copy.
eval_dataset = train_ds
eval_dataset.reset_format()
eval_dataset.set_transform(val_transforms)

pred_res = trainer.predict(eval_dataset)

In [11]:
# Inspect the shape of the stacked predictions.
pred_res.predictions.shape

(8658, 3, 256, 256)

In [12]:

# Quantise the float predictions, shape (N, C, H, W) and nominally in [0, 1],
# to 8 bit. The values are clipped before the cast because the network can
# overshoot [0, 1] and an unclipped uint8 cast wraps around (e.g. 260 -> 4).
# Rounding instead of truncating avoids a systematic darkening.
pred_labels = np.clip(np.rint(pred_res.predictions * 255.0), 0, 255).astype(np.uint8)

os.makedirs(eval_folder, exist_ok=True)

# Save every reconstructed patch as <eval_folder>/<image name>/img/<filename>.
# The batch transform hides the "filename" column, so it is removed while the
# names are read and restored afterwards, whatever happens in between.
eval_dataset.reset_format()
try:
    # Read the column once rather than fetching a full row per index.
    filenames = eval_dataset["filename"]
    for filename, chw_image in tqdm(zip(filenames, pred_labels), total=len(pred_labels)):
        img_name = filename.rsplit("_", 2)[0]
        target_dir = os.path.join(eval_folder, img_name, "img")
        # The per-image sub-folder may not exist yet.
        os.makedirs(target_dir, exist_ok=True)
        # (C, H, W) -> (H, W, C), the layout PIL expects.
        label_img = Image.fromarray(np.ascontiguousarray(chw_image.transpose(1, 2, 0)))
        label_img.save(os.path.join(target_dir, filename))
finally:
    eval_dataset.set_transform(val_transforms)

print(f"Saved evaluation predictions to: {eval_folder}")

8658it [05:40, 25.44it/s]

Saved evaluation predictions to: outputs_sr/data_overlapped_sr_TTST/train_prep



