In [37]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import segmentation_models_pytorch as smp

print("GPU: ", torch.cuda.is_available())
print(torch.cuda.empty_cache())
print(torch.__version__)
print(torch.cuda.is_available())

GPU:  True
None
2.5.1+rocm6.2
True


In [38]:
transform = transforms.Compose([
                    transforms.PILToTensor()
])

In [39]:
def encoded_pixels_to_masks(fname: str, df: pd.DataFrame):
    fname_df = df[df['ImageId'] == fname]
    masks = np.zeros((256 * 1600, 4), dtype=int) # float32 is V.Imp

    for i_row, row in fname_df.iterrows():
        cls_id = row['ClassId']
        encoded_pixels = row['EncodedPixels']
        if encoded_pixels is not np.nan:
            pixel_list = list(map(int, encoded_pixels.split(' ')))
            for i in range(0, len(pixel_list), 2):
                start_pixel = pixel_list[i] - 1
                num_pixel = pixel_list[i+1]
                masks[start_pixel:(start_pixel+num_pixel), cls_id-1] = 1
                
    masks = masks.reshape(256, 1600, 4, order='F')

    return masks

def masks_to_encoded_pixels(masks: np.ndarray):
    masks = masks.reshape(256*1600, 4, order='F')
    encoded_pixels_list = []
    for cls_id in range(4):
        cls_mask = masks[:, cls_id]
        cls_mask = cls_mask.reshape(256, 1600, order='F')
        cls_mask = cls_mask.T.flatten()
        prev_pixel = 0
        prev_pixel_val = 0
        encoded_pixels = []
        for i, pixel_val in enumerate(cls_mask):
            if pixel_val != prev_pixel_val:
                if pixel_val == 1:
                    start_pixel = i + 1
                    encoded_pixels.append(start_pixel - prev_pixel)
                else:
                    num_pixel = i - prev_pixel
                    encoded_pixels.append(num_pixel)
                prev_pixel = i
                prev_pixel_val = pixel_val
        encoded_pixels_list.append(encoded_pixels)
    return encoded_pixels_list # shape: 4x[]

### Solution

In [40]:
class SeverstalSteelDataset(Dataset):
    def __init__(self, df, img_dir, transform=None):
        self.df = df.reset_index(drop=True) 
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        fname = self.df.ImageId[idx]
        img_path = os.path.join(self.img_dir, fname)
        img = Image.open(img_path)
        img = np.array(Image.open(img_path).convert('RGB')) 
        masks = encoded_pixels_to_masks(img_path, self.df)
        
        img = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1)
        masks = torch.tensor(masks, dtype=torch.float32).permute(2, 0, 1)
        # if self.transform:
        #     img = self.transform(img)
        
        return fname, img, masks
    
# collate function if needed
def collate_fn(batch_items):
    batched_fnames = [item[0] for item in batch_items]
    batched_imgs = torch.stack([item[1] for item in batch_items])
    batched_masks = torch.stack([item[2] for item in batch_items])
    return batched_fnames, batched_imgs, batched_masks
    

In [41]:
class SegModel(torch.nn.Module):
    def __init__(self):
        super(SegModel, self).__init__()
        self.model = smp.Unet(classes=4)
    def forward(self, x):
        return self.model(x)

# class SegModel(torch.nn.Module):
#     def __init__(self):
#         super(SegModel, self).__init__()
#         self.model = smp.DeepLabV3Plus(
#             encoder_name="mobilenet_v2",  # Лёгкий бэкбон
#             encoder_weights="imagenet",  # Предобученные веса
#             classes=4,                   # Количество классов
#             activation=None              # Без активации
#         )

#     def forward(self, x):
#         return self.model(x)

In [42]:
def dice_score(preds, targets, smooth=1e-6):
    preds = preds.reshape(-1)
    targets = targets.reshape(-1)
    
    intersection = (preds * targets).sum()
    dice = (2.0 * intersection + smooth) / (preds.sum() + targets.sum() + smooth)
    return dice.item()


In [43]:
def load_data(csv_path, img_folder_path, batch_size=4, test_split=0.1, val_split=0.2, frac=1):
    df = pd.read_csv(csv_path).sample(frac=frac, random_state=42).reset_index(drop=True)
    train_val_df, test_df = train_test_split(df, test_size=test_split, random_state=42)
    train_df, val_df = train_test_split(train_val_df, test_size=val_split, random_state=42)
    
    # Создаем датасеты
    train_dataset = SeverstalSteelDataset(train_df, img_folder_path, transform=transform)
    val_dataset = SeverstalSteelDataset(val_df, img_folder_path, transform=transform)
    test_dataset = SeverstalSteelDataset(test_df, img_folder_path, transform=transform)
    
    # Создаем загрузчики
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_loader, val_loader, test_loader

In [44]:
def init_model():
    model = SegModel() 
    return model, torch.nn.BCEWithLogitsLoss(), torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

def train(model, criterion, optimizer, loader, device, epochs=10):
    model.to(device)
    model.train()  # Устанавливаем модель в режим обучения
    for epoch in range(epochs):
        running_loss = 0.0
        batch_idx = 0
        for fname, imgs, masks in loader:
            batch_idx += 1
            imgs = imgs.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)

            loss = criterion(outputs, masks)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Печатаем статистику по ходу обучения
            if batch_idx % 10 == 0:  # Печатаем каждый 10-й батч
                print(f'Epoch [{epoch+1}/{epochs}], Step [{batch_idx+1}/{len(loader)}], Loss: {loss.item():.4f}')

        # Выводим средний loss за эпоху
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(loader):.4f}')

    return model

def validate(model, val_loader, device):
    model.eval()
    dice_scores = []
    with torch.no_grad():
        for fnames, imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            outputs = model(imgs)
            preds = torch.sigmoid(outputs) > 0.5 
            preds = preds.cpu().numpy()
            masks = batch_masks.cpu().numpy()

            for i in range(len(fnames)):
                for cls in range(4):
                    dice = dice_score(preds[i, cls_id], masks[i, cls_id])
                    dice_scores.append(dice)

    mean_dice = np.mean(dice_scores)
    print(f"Avg dice score: {mean_dice:.4f}")
    for cls_id in range(4):
        cls_dice = np.mean([score for i, score in enumerate(dice_scores) if i % 4 == cls_id])
        print(f"Dice Score for Class {cls_id + 1}: {cls_dice:.4f}")

def evaluate(model, loader, device): 
    model.eval()
    submission = []

    with torch.no_grad():
        for fnames, imgs, masks in test_loader:
            imgs = imgs.to(device)
            outputs = model(imgs)
            preds = torch.sigmoid(outputs) 

            preds = preds.cpu().numpy()
            for i, fname in enumerate(fnames):
                cur_submission = []
                for cls in range(4):
                    encoded_pixels = masks_to_encoded_pixels(preds[i, cls])
                    cur_submission.append((fname, cls + 1, encoded_pixels))

                submission.extend(cur_submission)
    return pd.DataFrame(submission, columns=['ImageId', 'ClassId', 'EncodedPixels'])

In [45]:
model, criterion, optimizer = init_model()
train_loader, val_loader, test_loader = load_data("../data/train.csv", "../data/train_images", frac=0.1)

In [None]:
device = torch.device("cpu")
model = train(model, criterion, optimizer, train_loader, device)
validate(model, test_loader, device)
submission_df = evaluate(model, test_loader, device)
submission_df.to_csv("my_submission.csv", index=False)

Epoch [1/10], Step [11/128], Loss: 0.6258
Epoch [1/10], Step [21/128], Loss: 0.5707
Epoch [1/10], Step [31/128], Loss: 0.5286
Epoch [1/10], Step [41/128], Loss: 0.4897
Epoch [1/10], Step [51/128], Loss: 0.4580
Epoch [1/10], Step [61/128], Loss: 0.4294
Epoch [1/10], Step [71/128], Loss: 0.4073
Epoch [1/10], Step [81/128], Loss: 0.3862
Epoch [1/10], Step [91/128], Loss: 0.3693
Epoch [1/10], Step [101/128], Loss: 0.3544
Epoch [1/10], Step [111/128], Loss: 0.3358
Epoch [1/10], Step [121/128], Loss: 0.3193
Epoch [1/10], Loss: 0.4439
Epoch [2/10], Step [11/128], Loss: 0.2948
Epoch [2/10], Step [21/128], Loss: 0.2822
Epoch [2/10], Step [31/128], Loss: 0.2685
Epoch [2/10], Step [41/128], Loss: 0.2563
Epoch [2/10], Step [51/128], Loss: 0.2450
Epoch [2/10], Step [61/128], Loss: 0.2341
Epoch [2/10], Step [71/128], Loss: 0.2228
Epoch [2/10], Step [81/128], Loss: 0.2137
Epoch [2/10], Step [91/128], Loss: 0.2039
Epoch [2/10], Step [101/128], Loss: 0.1956
Epoch [2/10], Step [111/128], Loss: 0.1867
Ep