In [1]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np

class DAGMDataset(Dataset):
    """Segmentation dataset for one DAGM class: grayscale images plus binary masks.

    Samples are read from ``mapping_file``. Each whitespace-separated row is
    parsed as ``<?> <label_flag> <image_file> <?> <label_file>``; rows with
    fewer than five fields (e.g. a header line) are skipped. When
    ``label_flag`` is ``1`` the row's ``label_file`` is loaded as the mask,
    otherwise the sample gets an all-zero mask.

    Args:
        image_dir: Directory containing the input images.
        label_dir: Directory containing the mask images.
        mapping_file: Path to the mapping text file (e.g. ``Labels.txt``).
        transform: Optional joint transform, called as
            ``image, mask = transform(image, mask)`` on the two ``[1, H, W]``
            float tensors so geometric augmentations keep them aligned.
            Previously this argument was accepted but silently ignored.
        image_size: ``(width, height)`` that both image and mask are resized
            to (PIL ordering). Defaults to ``(256, 256)``.

    Each item is ``(image, mask)``: float32 tensors of shape ``[1, H, W]``.
    ``image`` is scaled to [0, 1]; ``mask`` holds 0.0 / 1.0
    (mask pixels > 127 after resizing become 1.0).
    """

    def __init__(self, image_dir, label_dir, mapping_file, transform=None, image_size=(256, 256)):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        # (width, height) -- the order PIL's Image.resize expects.
        self.image_size = tuple(image_size)

        self.samples = []
        with open(mapping_file, 'r') as file:
            for line in file:
                parts = line.strip().split()
                # Rows without all five columns (header / blank lines) are ignored.
                if len(parts) < 5:
                    continue
                img_file = parts[2]
                label_flag = int(parts[1])
                # Only rows flagged 1 reference a mask file.
                label_file = parts[4] if label_flag == 1 else None
                self.samples.append((img_file, label_file))

    def __len__(self):
        return len(self.samples)

    def _load_gray(self, path):
        """Open ``path``, convert it to 8-bit grayscale and resize it to ``self.image_size``.

        The ``with`` block closes the file handle right away instead of
        leaving it open until the PIL image is garbage-collected.
        """
        with Image.open(path) as im:
            return im.convert('L').resize(self.image_size)

    def __getitem__(self, idx):
        img_file, label_file = self.samples[idx]
        width, height = self.image_size

        image = self._load_gray(os.path.join(self.image_dir, img_file))
        image = np.asarray(image, dtype=np.float32) / 255.0
        image = torch.from_numpy(image).unsqueeze(0)  # Shape: [1, H, W]

        # A literal "0" in the label column is treated as "no mask file".
        if label_file and label_file != "0":
            mask = self._load_gray(os.path.join(self.label_dir, label_file))
            mask = (np.asarray(mask, dtype=np.uint8) > 127).astype(np.float32)
        else:
            mask = np.zeros((height, width), dtype=np.float32)
        mask = torch.from_numpy(mask).unsqueeze(0)  # Shape: [1, H, W]

        if self.transform is not None:
            image, mask = self.transform(image, mask)

        return image, mask


In [2]:
from torch.utils.data import DataLoader

# Every Class1 training asset sits under a single root; derive the other paths from it.
TRAIN_ROOT = "../static/DAGM_KaggleUpload/Class1/Train"
image_dir = TRAIN_ROOT
label_dir = f"{TRAIN_ROOT}/Label"
mapping_file = f"{label_dir}/Labels.txt"

dataset = DAGMDataset(image_dir, label_dir, mapping_file)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Sanity check: pull the first (image, mask) pair and report both tensor shapes.
img, mask = dataset[0]
for caption, tensor in (("Image shape:", img), ("Mask shape :", mask)):
    print(caption, tensor.shape)


Image shape: torch.Size([1, 256, 256])
Mask shape : torch.Size([1, 256, 256])


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    """Four-level U-Net producing a per-pixel probability map.

    Encoder: four double-conv stages of width ``base_channels`` x 1, 2, 4, 8,
    each followed by 2x2 max-pooling, then a bottleneck of width
    ``base_channels`` x 16. Decoder: four stride-2 transposed convolutions, each
    concatenated (upsampled features first) with the matching encoder output
    and refined by a double-conv. A 1x1 conv maps to ``out_channels`` and a
    sigmoid squashes the result into [0, 1], so the output is meant for
    ``nn.BCELoss`` rather than ``BCEWithLogitsLoss``.

    Input is ``(N, in_channels, H, W)``; output is ``(N, out_channels, H, W)``.
    H and W no longer need to be multiples of 16: when pooling floors an odd
    size, the upsampled map is zero-padded to its skip connection's size before
    concatenation. H and W must still be >= 16 so the bottleneck is non-empty.

    Submodule names and registration order match the original layout, so
    existing ``state_dict`` checkpoints still load.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output map.
        base_channels: Width of the first encoder stage (default 64).
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=64):
        super().__init__()
        c = base_channels

        def conv_block(in_ch, out_ch):
            # Two 3x3 convs with padding=1: spatial size is preserved.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.ReLU(inplace=True),
            )

        self.down1 = conv_block(in_channels, c)
        self.pool1 = nn.MaxPool2d(2)

        self.down2 = conv_block(c, 2 * c)
        self.pool2 = nn.MaxPool2d(2)

        self.down3 = conv_block(2 * c, 4 * c)
        self.pool3 = nn.MaxPool2d(2)

        self.down4 = conv_block(4 * c, 8 * c)
        self.pool4 = nn.MaxPool2d(2)

        self.bottleneck = conv_block(8 * c, 16 * c)

        # Each decoder block sees upsampled + skip channels (2x its output width).
        self.up4 = nn.ConvTranspose2d(16 * c, 8 * c, kernel_size=2, stride=2)
        self.dec4 = conv_block(16 * c, 8 * c)

        self.up3 = nn.ConvTranspose2d(8 * c, 4 * c, kernel_size=2, stride=2)
        self.dec3 = conv_block(8 * c, 4 * c)

        self.up2 = nn.ConvTranspose2d(4 * c, 2 * c, kernel_size=2, stride=2)
        self.dec2 = conv_block(4 * c, 2 * c)

        self.up1 = nn.ConvTranspose2d(2 * c, c, kernel_size=2, stride=2)
        self.dec1 = conv_block(2 * c, c)

        self.final = nn.Conv2d(c, out_channels, kernel_size=1)

    @staticmethod
    def _pad_to(x, ref):
        """Zero-pad ``x`` on its last two dims so they match ``ref``'s.

        MaxPool2d floors odd sizes, so an upsampled map can be one pixel
        smaller than its skip connection; the padding is split as evenly as
        possible between the two sides.
        """
        diff_h = ref.size(-2) - x.size(-2)
        diff_w = ref.size(-1) - x.size(-1)
        if diff_h == 0 and diff_w == 0:
            return x
        return F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                         diff_h // 2, diff_h - diff_h // 2])

    def _up_merge(self, x, skip, up, dec):
        """Upsample ``x`` with ``up``, align to ``skip``, concat on channels, refine with ``dec``."""
        upsampled = self._pad_to(up(x), skip)
        return dec(torch.cat([upsampled, skip], dim=1))

    def forward(self, x):
        # Encoder: keep each stage's output for the skip connections.
        d1 = self.down1(x)
        d2 = self.down2(self.pool1(d1))
        d3 = self.down3(self.pool2(d2))
        d4 = self.down4(self.pool3(d3))

        bn = self.bottleneck(self.pool4(d4))

        # Decoder: mirror the encoder, merging in the matching skip at each level.
        dec4 = self._up_merge(bn, d4, self.up4, self.dec4)
        dec3 = self._up_merge(dec4, d3, self.up3, self.dec3)
        dec2 = self._up_merge(dec3, d2, self.up2, self.dec2)
        dec1 = self._up_merge(dec2, d1, self.up1, self.dec1)

        return torch.sigmoid(self.final(dec1))


In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = UNet(in_channels=1, out_channels=1)
model = model.to(device)
print("Model initialized on:", device)


Model initialized on: cuda


In [5]:
import torch.optim as optim

# The UNet already ends in a sigmoid, so the loss consumes probabilities (BCELoss).
LEARNING_RATE = 1e-4
criterion = nn.BCELoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)


In [6]:
from torch.utils.data import random_split, DataLoader
# 1. Prepare Train & Validation Loaders

# Split dataset 80% train, 20% val.
# A fixed-seed generator makes the split identical across kernel restarts, so a
# saved model can later be re-evaluated on the same held-out images instead of
# a fresh random subset that may overlap its training data.
SPLIT_SEED = 42
BATCH_SIZE = 8

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)


In [7]:
num_epochs = 50
# 2. Training + Validation Loop


def run_epoch(net, loader, loss_fn, device, opt=None):
    """Make one pass over ``loader`` and return the per-sample mean loss.

    With ``opt`` given, the network is put in train mode and updated after every
    batch; without it, the network is put in eval mode and gradients are
    disabled. Each batch loss is weighted by its batch size (assuming
    ``loss_fn`` averages over the batch, as ``nn.BCELoss()`` does by default),
    so a short final batch does not skew the mean. An empty loader yields
    ``float("nan")`` instead of a ZeroDivisionError.
    """
    is_train = opt is not None
    net.train(is_train)
    total_loss = 0.0

    with torch.set_grad_enabled(is_train):
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device)

            loss = loss_fn(net(images), masks)

            if is_train:
                opt.zero_grad()
                loss.backward()
                opt.step()

            total_loss += loss.item() * images.size(0)

    n_samples = len(loader.dataset)
    return total_loss / n_samples if n_samples else float("nan")


for epoch in range(num_epochs):
    avg_train_loss = run_epoch(model, train_loader, criterion, device, opt=optimizer)
    avg_val_loss = run_epoch(model, val_loader, criterion, device)

    print(f"[{epoch+1}/{num_epochs}] 🏋️ Train Loss: {avg_train_loss:.4f} | 🧪 Val Loss: {avg_val_loss:.4f}")


[1/50] 🏋️ Train Loss: 0.3119 | 🧪 Val Loss: 0.0375
[2/50] 🏋️ Train Loss: 0.0410 | 🧪 Val Loss: 0.0328
[3/50] 🏋️ Train Loss: 0.0358 | 🧪 Val Loss: 0.0288
[4/50] 🏋️ Train Loss: 0.0346 | 🧪 Val Loss: 0.0273
[5/50] 🏋️ Train Loss: 0.0320 | 🧪 Val Loss: 0.0278
[6/50] 🏋️ Train Loss: 0.0327 | 🧪 Val Loss: 0.0264
[7/50] 🏋️ Train Loss: 0.0308 | 🧪 Val Loss: 0.0260
[8/50] 🏋️ Train Loss: 0.0309 | 🧪 Val Loss: 0.0344
[9/50] 🏋️ Train Loss: 0.0308 | 🧪 Val Loss: 0.0302
[10/50] 🏋️ Train Loss: 0.0292 | 🧪 Val Loss: 0.0253
[11/50] 🏋️ Train Loss: 0.0298 | 🧪 Val Loss: 0.0260
[12/50] 🏋️ Train Loss: 0.0302 | 🧪 Val Loss: 0.0266
[13/50] 🏋️ Train Loss: 0.0293 | 🧪 Val Loss: 0.0249
[14/50] 🏋️ Train Loss: 0.0284 | 🧪 Val Loss: 0.0288
[15/50] 🏋️ Train Loss: 0.0286 | 🧪 Val Loss: 0.0242
[16/50] 🏋️ Train Loss: 0.0285 | 🧪 Val Loss: 0.0239
[17/50] 🏋️ Train Loss: 0.0274 | 🧪 Val Loss: 0.0241
[18/50] 🏋️ Train Loss: 0.0275 | 🧪 Val Loss: 0.0244
[19/50] 🏋️ Train Loss: 0.0278 | 🧪 Val Loss: 0.0257
[20/50] 🏋️ Train Loss: 0.0272 | 🧪 Val Lo

In [8]:
# Persist the trained weights (the model's state_dict) to a file in the working directory.
CHECKPOINT_PATH = "unet_dagm_class1.pth"
weights = model.state_dict()
torch.save(weights, CHECKPOINT_PATH)

In [22]:
import os

# torch.save was given a relative path, so resolve it against the current working
# directory and confirm the file really exists before reporting where it is.
checkpoint_path = os.path.abspath("unet_dagm_class1.pth")
assert os.path.isfile(checkpoint_path), f"checkpoint not found: {checkpoint_path}"
print("Saved at:", checkpoint_path)

Saved at: /home/emre/Documents/GitHub/Digital-Image-Fault-Detection/src
