In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import segmentation_models_pytorch as smp

In [17]:
MODEL_PATH = "model_weights.pth"

In [2]:
sample_submission = False

In [3]:
transform = transforms.Compose([
                    transforms.PILToTensor()
])

In [4]:
def encoded_pixels_to_masks(fname: str, df: pd.DataFrame):
    fname_df = df[df['ImageId'] == fname]
    masks = np.zeros((256 * 1600, 4), dtype=int) # float32 is V.Imp

    for i_row, row in fname_df.iterrows():
        cls_id = row['ClassId']
        encoded_pixels = row['EncodedPixels']
        if encoded_pixels is not np.nan:
            pixel_list = list(map(int, encoded_pixels.split(' ')))
            for i in range(0, len(pixel_list), 2):
                start_pixel = pixel_list[i] - 1
                num_pixel = pixel_list[i+1]
                masks[start_pixel:(start_pixel+num_pixel), cls_id-1] = 1
               
    masks = masks.reshape(256, 1600, 4, order='F')

    return masks

def masks_to_encoded_pixels(masks: np.ndarray):
    masks = masks.reshape(256*1600, 4, order='F')
    encoded_pixels_list = []
    for cls_id in range(4):
        cls_mask = masks[:, cls_id]
        cls_mask = cls_mask.reshape(256, 1600, order='F')
        cls_mask = cls_mask.T.flatten()
        prev_pixel = 0
        prev_pixel_val = 0
        encoded_pixels = []
        for i, pixel_val in enumerate(cls_mask):
            if pixel_val != prev_pixel_val:
                if pixel_val == 1:
                    start_pixel = i + 1
                    encoded_pixels.append(start_pixel - prev_pixel)
                else:
                    num_pixel = i - prev_pixel
                    encoded_pixels.append(num_pixel)
                prev_pixel = i
                prev_pixel_val = pixel_val
        encoded_pixels_list.append(encoded_pixels)
    return encoded_pixels_list # shape: 4x[]

### Solution

In [5]:
class SeverstalSteelDataset(Dataset):
    def __init__(self, df, img_dir, transform):
        self.df = df.reset_index(drop=True) 
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        fname = self.df.ImageId[idx]
        img_path = os.path.join(self.img_dir, fname)
        img = Image.open(img_path)
        img = np.array(Image.open(img_path).convert('RGB')) 
        masks = encoded_pixels_to_masks(fname, self.df)
        img = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1)
        masks = torch.tensor(masks, dtype=torch.float32).permute(2, 0, 1)
        return fname, img, masks
    
# collate function if needed
def collate_fn(batch_items):
    batched_fnames = [item[0] for item in batch_items]
    batched_imgs = torch.stack([item[1] for item in batch_items])
    batched_masks = torch.stack([item[2] for item in batch_items])
    return batched_fnames, batched_imgs, batched_masks
    

In [6]:
class SegModel(torch.nn.Module):
    def __init__(self, num_classes=4):
        super(SegModel, self).__init__()
        self.model = smp.Unet(encoder_name='resnet34', encoder_weights='imagenet', classes=num_classes, activation=None)
    def forward(self, x):
        return self.model(x)

In [7]:
def dice_score(preds, targets):
    intersection = np.sum(preds * targets)
    denominator = np.sum(preds) + np.sum(targets)
    if denominator == 0:
        return 1.0 
    return (2 * intersection) / denominator


In [8]:
def load_data(csv_path, img_folder_path, batch_size=4, val_split=0.2):
    if sample_submission:
        df = pd.read_csv(csv_path).sample(frac=0.01, random_state=10)
    else:
        df = pd.read_csv(csv_path)
    train_df, val_df = train_test_split(df, test_size=val_split, random_state=42)
    
    # Создаем датасеты
    train_dataset = SeverstalSteelDataset(train_df, img_folder_path, transform=transform)
    val_dataset = SeverstalSteelDataset(val_df, img_folder_path, transform=transform)
    
    # Создаем загрузчики
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_loader, val_loader

In [9]:
def init_model(device):
    model = SegModel().to(device) 
    model.load_state_dict(torch.load(MODEL_PATH))
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3)
    return model, criterion, optimizer, scheduler

def train(model, loader, optimizer, criterion, device):
    
    model.train()  
    train_loss = 0.0
    for batch_idx, (fnames, imgs, masks) in enumerate(loader):
        imgs = imgs.to(device)
        optimizer.zero_grad()
        
        outputs = model(imgs).cpu()
        
        loss = criterion(outputs, masks)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        if(batch_idx % 10 == 0):
            print("Batch #{0} : [train_loss : {1}]".format(batch_idx, loss.item()))
    # возвращаем средний лосс
    return train_loss / len(loader)

def validate(model, loader, criterion, device):
    model.eval()
    val_loss = 0.0
    
    dice_scores = [[], [], [], []]
                                 
    with torch.no_grad():
        for batch_idx, (fnames, imgs, masks) in enumerate(val_loader):
            imgs = imgs.to(device)
            masks = masks.to(device)
            
            outputs = model(imgs)
            
            preds = (torch.sigmoid(outputs) > 0.5).cpu().numpy()
            
            loss = criterion(outputs, masks)
            
            masks = masks.cpu().numpy()
            
            val_loss += loss.item()
            
            if(batch_idx % 10 == 0):
                print("Batch #{0} : [val_loss : {1}]".format(batch_idx, loss.item()))
                
            for idx in range(len(fnames)):
                for cls_id in range(masks.shape[1]):
                    score = dice_score(preds[idx, cls_id], masks[idx, cls_id])
                    dice_scores[cls_id].append(score)
            
    avg_loss = val_loss / len(loader)
    avg_dice = []
    for class_dice in dice_scores:  
        avg_dice.append(sum(class_dice) / len(class_dice) if class_dice else 0.0)
    
    return avg_loss, avg_dice

def fit(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs=10):
    for epoch in range(num_epochs):
        train_loss = train(model, train_loader, optimizer, criterion, device)
        val_loss, score = validate(model, val_loader, criterion, device)
        print("Epoch #{0}: [val_loss : {1}, train_loss: {2}, dice_score: {3}]".format(epoch, val_loss, train_loss, score))
        scheduler.step(val_loss)


In [14]:
def evaluate(model, test_images_dir, device): 
    model.eval()

    # Список для хранения результатов
    results = []
    
    image_ids = [f for f in os.listdir(test_images_dir)]

    if sample_submission:
        image_ids = image_ids[:10]
    
    with torch.no_grad():
        for idx, image_id in enumerate(image_ids):
            # Чтение изображения
            image_path = os.path.join(test_images_dir, image_id)
            image = np.array(Image.open(image_path).convert('RGB')) 
            image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
            image = image.unsqueeze(0).to(device)

            outputs = model(image)
            
            outputs = (torch.sigmoid(outputs) > 0.5).squeeze(0).cpu().numpy() # Преобразуем к Numpy
            
            encoded_pixels = masks_to_encoded_pixels(outputs) # Преобразование в EncodedPixels
            
            for cls_id in range(4):
                
                enc_pixels = " ".join(str(x) for x in encoded_pixels[cls_id])

                if len(encoded_pixels) > 1: # Если маска непустая
                    results.append({
                    'ImageId': image_id,
                    'EncodedPixels': enc_pixels,
                    'ClassId': cls_id
                })
            if idx % 10 == 0:
                print(idx, " image saved")

    return pd.DataFrame(results, columns=['ImageId', 'EncodedPixels', 'ClassId'])

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Used device: {0}".format(device))
model, criterion, optimizer, scheduler = init_model(device)
train_loader, val_loader = load_data("../data/train.csv", "../data/train_images")

Used device: cuda


In [12]:
fit(model, train_loader, val_loader, criterion, optimizer, scheduler, device)


Batch #0 : [train_loss : 0.7626004219055176]
Batch #10 : [train_loss : 0.6577332019805908]
Batch #20 : [train_loss : 0.604621410369873]
Batch #30 : [train_loss : 0.5698618292808533]
Batch #40 : [train_loss : 0.5290179252624512]
Batch #50 : [train_loss : 0.49889373779296875]
Batch #60 : [train_loss : 0.4816112518310547]
Batch #70 : [train_loss : 0.44677191972732544]
Batch #80 : [train_loss : 0.42583170533180237]
Batch #90 : [train_loss : 0.39926156401634216]
Batch #100 : [train_loss : 0.3806115686893463]
Batch #110 : [train_loss : 0.38451921939849854]
Batch #120 : [train_loss : 0.3717893958091736]
Batch #130 : [train_loss : 0.34182533621788025]
Batch #140 : [train_loss : 0.35363543033599854]
Batch #150 : [train_loss : 0.3330015540122986]
Batch #160 : [train_loss : 0.2956395447254181]
Batch #170 : [train_loss : 0.30956825613975525]
Batch #180 : [train_loss : 0.26471972465515137]
Batch #190 : [train_loss : 0.26607948541641235]
Batch #200 : [train_loss : 0.24101632833480835]
Batch #210 : [

Batch #300 : [val_loss : 0.04161177948117256]
Batch #310 : [val_loss : 0.037894394248723984]
Batch #320 : [val_loss : 0.029679762199521065]
Batch #330 : [val_loss : 0.029879961162805557]
Batch #340 : [val_loss : 0.028516573831439018]
Batch #350 : [val_loss : 0.027262873947620392]
Epoch #0: [val_loss : 0.03778276397833522, train_loss: 0.13459662556217705, dice_score: [0.8661028893587033, 0.9626497533474278, 0.5216366167253804, 0.9145209756694636]]
Batch #0 : [train_loss : 0.028890933841466904]
Batch #10 : [train_loss : 0.03144841641187668]
Batch #20 : [train_loss : 0.07448337972164154]
Batch #30 : [train_loss : 0.03285969793796539]
Batch #40 : [train_loss : 0.049263812601566315]
Batch #50 : [train_loss : 0.027666324749588966]
Batch #60 : [train_loss : 0.027183866128325462]
Batch #70 : [train_loss : 0.049479227513074875]
Batch #80 : [train_loss : 0.019576529040932655]
Batch #90 : [train_loss : 0.03004414029419422]
Batch #100 : [train_loss : 0.03105245530605316]
Batch #110 : [train_loss :

Batch #190 : [val_loss : 0.04201583191752434]
Batch #200 : [val_loss : 0.02466607093811035]
Batch #210 : [val_loss : 0.016889482736587524]
Batch #220 : [val_loss : 0.014693159610033035]
Batch #230 : [val_loss : 0.010061592794954777]
Batch #240 : [val_loss : 0.06341966241598129]
Batch #250 : [val_loss : 0.022374026477336884]
Batch #260 : [val_loss : 0.025150852277874947]
Batch #270 : [val_loss : 0.02395119145512581]
Batch #280 : [val_loss : 0.014148961752653122]
Batch #290 : [val_loss : 0.011687014251947403]
Batch #300 : [val_loss : 0.0324457511305809]
Batch #310 : [val_loss : 0.011063218116760254]
Batch #320 : [val_loss : 0.020572425797581673]
Batch #330 : [val_loss : 0.012857000343501568]
Batch #340 : [val_loss : 0.013867414556443691]
Batch #350 : [val_loss : 0.025312252342700958]
Epoch #1: [val_loss : 0.027710220514153932, train_loss: 0.03274214466859162, dice_score: [0.8661028893587033, 0.9626497533474278, 0.5570582850356204, 0.8944923164608112]]
Batch #0 : [train_loss : 0.016195895

Batch #70 : [val_loss : 0.011677665635943413]
Batch #80 : [val_loss : 0.01957242749631405]
Batch #90 : [val_loss : 0.08672097325325012]
Batch #100 : [val_loss : 0.03454786539077759]
Batch #110 : [val_loss : 0.014654017053544521]
Batch #120 : [val_loss : 0.025693945586681366]
Batch #130 : [val_loss : 0.019027888774871826]
Batch #140 : [val_loss : 0.01042219065129757]
Batch #150 : [val_loss : 0.011385872960090637]
Batch #160 : [val_loss : 0.016724467277526855]
Batch #170 : [val_loss : 0.01547744870185852]
Batch #180 : [val_loss : 0.009884670376777649]
Batch #190 : [val_loss : 0.04975428432226181]
Batch #200 : [val_loss : 0.017498785629868507]
Batch #210 : [val_loss : 0.015784990042448044]
Batch #220 : [val_loss : 0.016220424324274063]
Batch #230 : [val_loss : 0.006781738251447678]
Batch #240 : [val_loss : 0.04433970898389816]
Batch #250 : [val_loss : 0.032771166414022446]
Batch #260 : [val_loss : 0.022991737350821495]
Batch #270 : [val_loss : 0.02698502317070961]
Batch #280 : [val_loss :

Batch #1370 : [train_loss : 0.020183200016617775]
Batch #1380 : [train_loss : 0.01692214235663414]
Batch #1390 : [train_loss : 0.013176705688238144]
Batch #1400 : [train_loss : 0.008166421204805374]
Batch #1410 : [train_loss : 0.026831990107893944]
Batch #0 : [val_loss : 0.012083376757800579]
Batch #10 : [val_loss : 0.021646656095981598]
Batch #20 : [val_loss : 0.040362633764743805]
Batch #30 : [val_loss : 0.03983332961797714]
Batch #40 : [val_loss : 0.020985670387744904]
Batch #50 : [val_loss : 0.015129650942981243]
Batch #60 : [val_loss : 0.02437410317361355]
Batch #70 : [val_loss : 0.012634112499654293]
Batch #80 : [val_loss : 0.018606875091791153]
Batch #90 : [val_loss : 0.09323110431432724]
Batch #100 : [val_loss : 0.02970234304666519]
Batch #110 : [val_loss : 0.013072889298200607]
Batch #120 : [val_loss : 0.029631977900862694]
Batch #130 : [val_loss : 0.013017160817980766]
Batch #140 : [val_loss : 0.011133487336337566]
Batch #150 : [val_loss : 0.009568369947373867]
Batch #160 : [

Batch #1260 : [train_loss : 0.03826487436890602]
Batch #1270 : [train_loss : 0.01087859831750393]
Batch #1280 : [train_loss : 0.021746259182691574]
Batch #1290 : [train_loss : 0.009845632128417492]
Batch #1300 : [train_loss : 0.042508259415626526]
Batch #1310 : [train_loss : 0.010984268970787525]
Batch #1320 : [train_loss : 0.00872707273811102]
Batch #1330 : [train_loss : 0.008875473402440548]
Batch #1340 : [train_loss : 0.01180780865252018]
Batch #1350 : [train_loss : 0.008335242047905922]
Batch #1360 : [train_loss : 0.015868790447711945]
Batch #1370 : [train_loss : 0.0425594225525856]
Batch #1380 : [train_loss : 0.016439639031887054]
Batch #1390 : [train_loss : 0.020961813628673553]
Batch #1400 : [train_loss : 0.026019711047410965]
Batch #1410 : [train_loss : 0.011651448905467987]
Batch #0 : [val_loss : 0.008898228406906128]
Batch #10 : [val_loss : 0.019073486328125]
Batch #20 : [val_loss : 0.02452024444937706]
Batch #30 : [val_loss : 0.033770348876714706]
Batch #40 : [val_loss : 0.0

Batch #1150 : [train_loss : 0.03393902629613876]
Batch #1160 : [train_loss : 0.010685627348721027]
Batch #1170 : [train_loss : 0.020422449335455894]
Batch #1180 : [train_loss : 0.006502125412225723]
Batch #1190 : [train_loss : 0.019199008122086525]
Batch #1200 : [train_loss : 0.017160629853606224]
Batch #1210 : [train_loss : 0.028772158548235893]
Batch #1220 : [train_loss : 0.02810378558933735]
Batch #1230 : [train_loss : 0.0037524986546486616]
Batch #1240 : [train_loss : 0.01931942068040371]
Batch #1250 : [train_loss : 0.014380257576704025]
Batch #1260 : [train_loss : 0.030291475355625153]
Batch #1270 : [train_loss : 0.014339889399707317]
Batch #1280 : [train_loss : 0.019069544970989227]
Batch #1290 : [train_loss : 0.018326206132769585]
Batch #1300 : [train_loss : 0.020561328157782555]
Batch #1310 : [train_loss : 0.01863209530711174]
Batch #1320 : [train_loss : 0.021614762023091316]
Batch #1330 : [train_loss : 0.01480154786258936]
Batch #1340 : [train_loss : 0.010750816203653812]
Batc

Batch #1040 : [train_loss : 0.006651661824434996]
Batch #1050 : [train_loss : 0.007827481254935265]
Batch #1060 : [train_loss : 0.01288979034870863]
Batch #1070 : [train_loss : 0.020734870806336403]
Batch #1080 : [train_loss : 0.024761946871876717]
Batch #1090 : [train_loss : 0.0052855354733765125]
Batch #1100 : [train_loss : 0.008844631724059582]
Batch #1110 : [train_loss : 0.019330281764268875]
Batch #1120 : [train_loss : 0.027378344908356667]
Batch #1130 : [train_loss : 0.016946015879511833]
Batch #1140 : [train_loss : 0.012909449636936188]
Batch #1150 : [train_loss : 0.01310032606124878]
Batch #1160 : [train_loss : 0.010933343321084976]
Batch #1170 : [train_loss : 0.013892863877117634]
Batch #1180 : [train_loss : 0.018970148637890816]
Batch #1190 : [train_loss : 0.013681484386324883]
Batch #1200 : [train_loss : 0.015267877839505672]
Batch #1210 : [train_loss : 0.02529708296060562]
Batch #1220 : [train_loss : 0.0056982822716236115]
Batch #1230 : [train_loss : 0.008817337453365326]
B

Batch #920 : [train_loss : 0.018374701961874962]
Batch #930 : [train_loss : 0.01819668337702751]
Batch #940 : [train_loss : 0.009809855371713638]
Batch #950 : [train_loss : 0.005521182902157307]
Batch #960 : [train_loss : 0.010868269018828869]
Batch #970 : [train_loss : 0.013059007935225964]
Batch #980 : [train_loss : 0.004579611588269472]
Batch #990 : [train_loss : 0.003536611096933484]
Batch #1000 : [train_loss : 0.021525759249925613]
Batch #1010 : [train_loss : 0.010263088159263134]
Batch #1020 : [train_loss : 0.011089924722909927]
Batch #1030 : [train_loss : 0.01160399243235588]
Batch #1040 : [train_loss : 0.02175576612353325]
Batch #1050 : [train_loss : 0.00877447985112667]
Batch #1060 : [train_loss : 0.009869960136711597]
Batch #1070 : [train_loss : 0.012307289987802505]
Batch #1080 : [train_loss : 0.009233545511960983]
Batch #1090 : [train_loss : 0.0161836426705122]
Batch #1100 : [train_loss : 0.015932615846395493]
Batch #1110 : [train_loss : 0.020704740658402443]
Batch #1120 : 

Batch #800 : [train_loss : 0.010967868380248547]
Batch #810 : [train_loss : 0.016305720433592796]
Batch #820 : [train_loss : 0.01508375070989132]
Batch #830 : [train_loss : 0.015281030908226967]
Batch #840 : [train_loss : 0.011941354721784592]
Batch #850 : [train_loss : 0.018716227263212204]
Batch #860 : [train_loss : 0.008529397659003735]
Batch #870 : [train_loss : 0.0034706429578363895]
Batch #880 : [train_loss : 0.01587763987481594]
Batch #890 : [train_loss : 0.02137073315680027]
Batch #900 : [train_loss : 0.031760323792696]
Batch #910 : [train_loss : 0.011264965869486332]
Batch #920 : [train_loss : 0.018937556073069572]
Batch #930 : [train_loss : 0.017437128350138664]
Batch #940 : [train_loss : 0.016140326857566833]
Batch #950 : [train_loss : 0.0293872132897377]
Batch #960 : [train_loss : 0.009617308154702187]
Batch #970 : [train_loss : 0.01749020256102085]
Batch #980 : [train_loss : 0.00616403529420495]
Batch #990 : [train_loss : 0.014635932631790638]
Batch #1000 : [train_loss : 0

Batch #680 : [train_loss : 0.012073525227606297]
Batch #690 : [train_loss : 0.021978436037898064]
Batch #700 : [train_loss : 0.009431997314095497]
Batch #710 : [train_loss : 0.021817907691001892]
Batch #720 : [train_loss : 0.00966790970414877]
Batch #730 : [train_loss : 0.013509259559214115]
Batch #740 : [train_loss : 0.010914984159171581]
Batch #750 : [train_loss : 0.011276495642960072]
Batch #760 : [train_loss : 0.005025416612625122]
Batch #770 : [train_loss : 0.008777888491749763]
Batch #780 : [train_loss : 0.007618836127221584]
Batch #790 : [train_loss : 0.010167330503463745]
Batch #800 : [train_loss : 0.009150411933660507]
Batch #810 : [train_loss : 0.010707489214837551]
Batch #820 : [train_loss : 0.012420717626810074]
Batch #830 : [train_loss : 0.006156684830784798]
Batch #840 : [train_loss : 0.012803714722394943]
Batch #850 : [train_loss : 0.012558000162243843]
Batch #860 : [train_loss : 0.008732673712074757]
Batch #870 : [train_loss : 0.011761240661144257]
Batch #880 : [train_l

In [15]:
submission_df = evaluate(model, "../data/test_images", device)
submission_df.to_csv("my_submission.csv", index=False)
# submission_df

0  image saved
10  image saved
20  image saved
30  image saved
40  image saved
50  image saved
60  image saved
70  image saved
80  image saved
90  image saved
100  image saved
110  image saved
120  image saved
130  image saved
140  image saved
150  image saved
160  image saved
170  image saved
180  image saved
190  image saved
200  image saved
210  image saved
220  image saved
230  image saved
240  image saved
250  image saved
260  image saved
270  image saved
280  image saved
290  image saved
300  image saved
310  image saved
320  image saved
330  image saved
340  image saved
350  image saved
360  image saved
370  image saved
380  image saved
390  image saved
400  image saved
410  image saved
420  image saved
430  image saved
440  image saved
450  image saved
460  image saved
470  image saved
480  image saved
490  image saved
500  image saved
510  image saved
520  image saved
530  image saved
540  image saved
550  image saved
560  image saved
570  image saved
580  image saved
590  ima

4620  image saved
4630  image saved
4640  image saved
4650  image saved
4660  image saved
4670  image saved
4680  image saved
4690  image saved
4700  image saved
4710  image saved
4720  image saved
4730  image saved
4740  image saved
4750  image saved
4760  image saved
4770  image saved
4780  image saved
4790  image saved
4800  image saved
4810  image saved
4820  image saved
4830  image saved
4840  image saved
4850  image saved
4860  image saved
4870  image saved
4880  image saved
4890  image saved
4900  image saved
4910  image saved
4920  image saved
4930  image saved
4940  image saved
4950  image saved
4960  image saved
4970  image saved
4980  image saved
4990  image saved
5000  image saved
5010  image saved
5020  image saved
5030  image saved
5040  image saved
5050  image saved
5060  image saved
5070  image saved
5080  image saved
5090  image saved
5100  image saved
5110  image saved
5120  image saved
5130  image saved
5140  image saved
5150  image saved
5160  image saved
5170  imag

In [18]:
torch.save(model.state_dict(), MODEL_PATH)