In [None]:
import os
import io
import base64
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from tqdm.auto import tqdm

# Disable PIL's maximum-pixel safety limit so very large images can be opened.
Image.MAX_IMAGE_PIXELS = None

# Fix the random seeds (Python, NumPy, PyTorch) for reproducibility.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths and basic settings.
# NOTE(review): the three `...` values below are placeholders and must be
# replaced with real `Path` objects before this cell can run -- `.mkdir()`
# is called on MODEL_DIR right after it is assigned.
DATA_ROOT = ...  # training root; the loader code reads "gt" and "im" subfolders from it
TEST_ROOT = ...  # folder of test images used by the inference cell
MODEL_DIR = ...  # where the checkpoint and submission.csv are written
MODEL_DIR.mkdir(parents=True, exist_ok=True)

IMG_SIZE = (1024, 1024)  # size every image/mask is resized to
BATCH_SIZE = 64          # training/validation batch size
NUM_EPOCHS = 5           # number of training epochs
LR = 1e-4                # Adam learning rate
RUN_TRAINING = True      # False: skip training and load the saved checkpoint instead



In [None]:
class SegmentationDataset(Dataset):
    """Paired image/mask dataset read from two subfolders of one root directory.

    Samples are discovered by listing the mask subfolder; each mask name (with
    ``mask_format`` stripped) is the key used to locate the matching image.

    Args:
        data_root: Directory containing the mask and image subfolders.
        size: Target size passed to ``transforms.Resize`` for both image and mask.
        mask_subfolder: Name of the subfolder holding masks.
        image_subfolder: Name of the subfolder holding images.
        image_format: Image file suffix (e.g. ``'.jpg'``), or ``None`` if image
            files carry no suffix beyond the sample key.
        mask_format: Mask file suffix (e.g. ``'.png'``), or ``None`` to use
            every entry of the mask folder as-is.
        num_mask_channels: ``3`` loads masks as RGB normalised to [-1, 1];
            ``1`` loads them as single-channel grayscale in [0, 1]; any other
            value leaves the mask mode untouched.

    Raises:
        ValueError: If ``data_root`` does not exist.
    """

    def __init__(
        self,
        data_root,
        size = (1024, 1024),
        mask_subfolder = 'gt',
        image_subfolder = 'im',
        image_format = '.jpg',
        mask_format = '.png',
        num_mask_channels = 1
    ):
        self.size = size

        self.data_root = Path(data_root)
        if not self.data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        self.mask_subfolder = mask_subfolder
        self.image_subfolder = image_subfolder
        self.image_format = image_format
        self.mask_format = mask_format
        self.num_mask_channels = num_mask_channels

        self.data = self._list_sample_names(
            self.data_root / self.mask_subfolder, self.mask_format
        )
        self._length = len(self.data)

        resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)

        self.image_transforms = transforms.Compose(
            [
                resize,
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]
            ]
        )

        if self.num_mask_channels == 3:
            # RGB masks are normalised exactly like images.
            self.mask_transforms = transforms.Compose(
                [
                    resize,
                    transforms.ToTensor(),
                    transforms.Normalize([0.5], [0.5]),
                ]
            )
        else:
            # Masks stay in [0, 1]; bilinear resizing can produce soft
            # (non-binary) values along object borders.
            self.mask_transforms = transforms.Compose(
                [
                    resize,
                    transforms.ToTensor(),
                ]
            )

    @staticmethod
    def _list_sample_names(mask_dir, mask_format):
        """Return the sorted sample keys found in ``mask_dir``.

        With a falsy ``mask_format`` every directory entry is returned
        unchanged. Otherwise only entries ending in ``mask_format`` are kept
        and exactly that suffix is removed, so names containing extra dots
        (``"a.v2.png"`` -> ``"a.v2"``) are preserved.
        """
        # Sort: os.listdir order is arbitrary, which would make a seeded
        # random split differ between machines/filesystems.
        entries = sorted(os.listdir(str(mask_dir)))
        if not mask_format:
            return entries
        suffix_len = len(mask_format)
        return [name[:-suffix_len] for name in entries if name.endswith(mask_format)]

    def __len__(self):
        """Number of samples discovered in the mask folder."""
        return self._length

    def __getitem__(self, index):
        """Load one sample and return ``{"mask": tensor, "img": tensor}``."""
        obj = self.data[index]

        mask_path = self.data_root / self.mask_subfolder / f"{obj}{self.mask_format or ''}"
        image_path = self.data_root / self.image_subfolder / f"{obj}{self.image_format or ''}"

        mask = Image.open(mask_path)
        if self.num_mask_channels == 3:
            mask = mask.convert('RGB')
        elif self.num_mask_channels == 1:
            # Force a single channel so palette/RGB/RGBA mask files do not
            # yield a tensor whose channel count differs from the model output.
            mask = mask.convert('L')
        img = Image.open(image_path).convert('RGB')

        return {
            "mask": self.mask_transforms(mask),
            "img": self.image_transforms(img),
        }

class TestImageDataset(Dataset):
    """Unlabelled images taken from a single flat directory.

    Every regular file directly under ``root`` is treated as an image; the
    files are served in sorted path order. Each item is a dict holding the
    file name (``"path"``) and the resized, normalised RGB tensor (``"img"``).
    """

    def __init__(self, root: Path, size=(1024, 1024)):
        self.root = Path(root)

        files = [entry for entry in self.root.iterdir() if entry.is_file()]
        files.sort()
        self.images = files

        resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize([0.5], [0.5])
        self.transform = transforms.Compose([resize, to_tensor, normalize])

    def __len__(self):
        """Number of files found under the root directory."""
        return len(self.images)

    def __getitem__(self, idx):
        """Open the idx-th file as RGB and return its name and tensor."""
        image_path = self.images[idx]
        rgb = Image.open(image_path).convert("RGB")
        return {"path": image_path.name, "img": self.transform(rgb)}


def get_dataloaders(train_root: Path, size=(1024, 1024), batch_size=4, val_fraction=0.1, num_workers=0):
    """Build the training and validation loaders from one labelled folder.

    The folder is read with ``SegmentationDataset`` (``gt``/``im`` subfolders,
    ``.png`` masks, ``.jpg`` images, single-channel masks) and randomly split
    into a training part and a validation part.

    Args:
        train_root: Root directory of the labelled data.
        size: Size every image and mask is resized to.
        batch_size: Batch size used by both loaders.
        val_fraction: Share of samples held out for validation, strictly
            between 0 and 1. At least one sample is always held out.
        num_workers: Worker processes used by each ``DataLoader``.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.

    Raises:
        ValueError: If ``val_fraction`` is outside (0, 1) or the dataset has
            fewer than two samples, so both splits could not be non-empty.
    """
    if not 0.0 < val_fraction < 1.0:
        raise ValueError(f"val_fraction must be in (0, 1), got {val_fraction}")

    train_dataset = SegmentationDataset(
        data_root=train_root,
        size=size,
        mask_subfolder="gt",
        image_subfolder="im",
        image_format=".jpg",
        mask_format=".png",
        num_mask_channels=1,
    )

    total = len(train_dataset)
    if total < 2:
        # Fail early with a readable message instead of an obscure error from
        # random_split (negative length) or from shuffling an empty split.
        raise ValueError(
            f"Need at least 2 samples under {train_root} to build train and "
            f"validation splits, found {total}."
        )

    # Hold out at least one sample, but never the whole dataset.
    val_size = min(max(1, int(val_fraction * total)), total - 1)
    train_size = total - val_size
    train_ds, val_ds = torch.utils.data.random_split(train_dataset, [train_size, val_size])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader



In [None]:
class ConvBlock(nn.Module):
    """Two consecutive 3x3 convolution -> batch norm -> ReLU stages.

    The first stage maps ``in_ch`` channels to ``out_ch``; the second keeps
    ``out_ch``. Padding of 1 keeps the spatial size unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # Only the first stage changes the channel count.
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both stages to ``x`` and return the result."""
        out = self.block(x)
        return out


class UNet(nn.Module):
    """Four-level U-Net built from ``ConvBlock`` stages.

    The encoder doubles the channel width at every level (``base_ch`` up to
    ``base_ch * 8``) with 2x max-pooling in between, the bottleneck uses
    ``base_ch * 16`` channels, and the decoder mirrors the encoder using
    transposed convolutions plus skip connections. A final 1x1 convolution
    produces ``out_channels`` raw (un-activated) output maps.
    """

    def __init__(self, in_channels=3, out_channels=1, base_ch=32):
        super().__init__()
        # Channel widths of the four encoder levels and the bottleneck.
        c1, c2, c3, c4, c5 = (base_ch * factor for factor in (1, 2, 4, 8, 16))

        # Encoder.
        self.enc1 = ConvBlock(in_channels, c1)
        self.enc2 = ConvBlock(c1, c2)
        self.enc3 = ConvBlock(c2, c3)
        self.enc4 = ConvBlock(c3, c4)

        self.pool = nn.MaxPool2d(2)
        self.bottleneck = ConvBlock(c4, c5)

        # Decoder: each ConvBlock receives the upsampled features concatenated
        # with the same-level encoder output, hence twice the level width.
        self.up4 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)
        self.dec4 = ConvBlock(c5, c4)
        self.up3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(c4, c3)
        self.up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(c3, c2)
        self.up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(c2, c1)

        self.head = nn.Conv2d(c1, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the head's logits."""
        # Encoder pass: remember every level's output for the skip connections.
        skips = []
        feat = x
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            feat = encoder(feat)
            skips.append(feat)
            feat = self.pool(feat)

        feat = self.bottleneck(feat)

        # Decoder pass: deepest level first, paired with the matching skip.
        decoder_stages = (
            (self.up4, self.dec4),
            (self.up3, self.dec3),
            (self.up2, self.dec2),
            (self.up1, self.dec1),
        )
        for (upsample, decoder), skip in zip(decoder_stages, reversed(skips)):
            feat = decoder(torch.cat([upsample(feat), skip], dim=1))

        return self.head(feat)



In [None]:
def mse_metric(logits, target):
    """Mean squared error between ``sigmoid(logits)`` and ``target``.

    Returns a scalar tensor averaged over all elements.
    """
    return F.mse_loss(logits.sigmoid(), target)


def train_epoch(model, loader, optimizer, criterion):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network called on the image batch; its output is passed to
            ``criterion`` together with the mask batch.
        loader: Iterable of dict batches with ``"img"`` and ``"mask"`` tensors.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss taking ``(logits, masks)``.

    Returns:
        ``(mean_loss, mean_mse)``, both averaged per sample over the samples
        actually processed.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_mse = 0.0
    seen = 0
    for batch in tqdm(loader, desc="train", leave=False):
        imgs = batch["img"].to(device)
        masks = batch["mask"].to(device)
        batch_size = imgs.size(0)

        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch_size
        # The metric is for monitoring only: detach so no autograd graph is
        # built for the extra sigmoid/MSE computation.
        total_mse += mse_metric(logits.detach(), masks).item() * batch_size
        seen += batch_size

    # Average over the samples actually seen rather than len(loader.dataset),
    # which is wrong with drop_last=True and undefined for iterable datasets.
    if seen == 0:
        raise ValueError("train_epoch received a loader that yielded no samples")
    return total_loss / seen, total_mse / seen


def eval_epoch(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network called on the image batch.
        loader: Iterable of dict batches with ``"img"`` and ``"mask"`` tensors.
        criterion: Loss taking ``(logits, masks)``.

    Returns:
        ``(mean_loss, mean_mse)``, both averaged per sample over the samples
        actually processed.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_mse = 0.0
    seen = 0
    with torch.no_grad():
        for batch in tqdm(loader, desc="val", leave=False):
            imgs = batch["img"].to(device)
            masks = batch["mask"].to(device)
            batch_size = imgs.size(0)

            logits = model(imgs)
            total_loss += criterion(logits, masks).item() * batch_size
            total_mse += mse_metric(logits, masks).item() * batch_size
            seen += batch_size

    # Average over the samples actually seen rather than len(loader.dataset),
    # which is wrong with drop_last=True and undefined for iterable datasets.
    if seen == 0:
        raise ValueError("eval_epoch received a loader that yielded no samples")
    return total_loss / seen, total_mse / seen



In [None]:
# Build the data loaders, model, loss and optimizer from the config constants.
train_loader, val_loader = get_dataloaders(DATA_ROOT, size=IMG_SIZE, batch_size=BATCH_SIZE)

model = UNet(in_channels=3, out_channels=1, base_ch=32).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

best_val_mse = float("inf")
best_path = MODEL_DIR / "unet_dis_best.pth"

if RUN_TRAINING:
    # Tracks whether THIS run wrote a checkpoint, so a stale file from an
    # earlier run is never mistaken for the current best weights.
    checkpoint_saved = False
    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss, train_mse = train_epoch(model, train_loader, optimizer, criterion)
        val_loss, val_mse = eval_epoch(model, val_loader, criterion)
        print(f"Epoch {epoch}: train_loss={train_loss:.4f} val_loss={val_loss:.4f} val_mse={val_mse:.4f}")
        if val_mse < best_val_mse:
            best_val_mse = val_mse
            torch.save({"model_state": model.state_dict()}, best_path)
            checkpoint_saved = True

    # After the loop the model holds the LAST epoch's weights; restore the
    # best-validation checkpoint so later cells predict with the best model.
    if checkpoint_saved:
        model.load_state_dict(torch.load(best_path, map_location=device)["model_state"])
else:
    if best_path.exists():
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
        model.load_state_dict(torch.load(best_path, map_location=device)["model_state"])
    else:
        print(f"WARNING: no checkpoint found at {best_path}; the model keeps its random initial weights.")



In [None]:
# Predict a mask for every test image and write the submission CSV.
model.eval()

# Use the same resolution as training (IMG_SIZE) instead of a second literal.
test_dataset = TestImageDataset(TEST_ROOT, size=IMG_SIZE)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

rows = []
with torch.no_grad():
    for batch in tqdm(test_loader, desc="test", leave=False):
        imgs = batch["img"].to(device)
        names = batch["path"]
        probs = torch.sigmoid(model(imgs))

        # Quantise probabilities to 8-bit grayscale; round to nearest instead
        # of truncating so e.g. 0.999 maps to 255 rather than 254.
        mask = np.rint(probs[0, 0].cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)

        # A 2-D uint8 array is interpreted as mode "L"; the explicit `mode`
        # argument of Image.fromarray is deprecated in recent Pillow releases.
        # NOTE(review): masks are saved at the resized resolution, not the
        # original image size -- confirm this matches the expected format.
        pil_mask = Image.fromarray(mask)
        buf = io.BytesIO()
        pil_mask.save(buf, format="PNG")
        image_utf = base64.b64encode(buf.getvalue()).decode("utf-8")

        # Path.stem drops only the final extension, so file names that contain
        # extra dots are not truncated at the first one.
        rows.append({"filename": Path(names[0]).stem, "image_utf": image_utf})

# Explicit columns keep the CSV header present even if no test image was found.
submission = pd.DataFrame(rows, columns=["filename", "image_utf"])
submission_path = MODEL_DIR / "submission.csv"
submission.to_csv(submission_path, index=False)
print(f"Saved submission to {submission_path}")

