In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score, jaccard_score

In [2]:
class RetinalDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(self.image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_name = self.images[index]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")
        if self.transform:
            image = self.transform(image)
        mask = transforms.ToTensor()(mask)
        return image, mask  # Masks are returned as [1, H, W] if ToTensor is used correctly


In [3]:
# Define data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create dataset instance
dataset = RetinalDataset("D:/junotbok/10th week/Data/train/image", "D:/junotbok/10th week/Data/train/mask", transform=transform)

# Create data loader
batch_size = 8
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)  # Adjust num_workers based on your system


In [4]:

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.enc1 = self.encoder_block(3, 64)
        self.enc2 = self.encoder_block(64, 128)
        self.enc3 = self.encoder_block(128, 256)
        self.enc4 = self.encoder_block(256, 512)
        self.bottleneck = self.encoder_block(512, 1024)
        self.upconv4 = self.upconv_block(1024, 512)
        self.upconv3 = self.upconv_block(512, 256)
        self.upconv2 = self.upconv_block(256, 128)
        self.upconv1 = self.upconv_block(128, 64)
        self.final_conv = nn.Conv2d(64, 1, kernel_size=1)  # 1 channel for binary segmentation

    def encoder_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

    def upconv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        bottleneck = self.bottleneck(enc4)
        up4 = self.upconv4(bottleneck)
        up3 = self.upconv3(up4)
        up2 = self.upconv2(up3)
        up1 = self.upconv1(up2)
        out = self.final_conv(up1)
        out = torch.sigmoid(out)  # Apply sigmoid activation
        return out


In [5]:
# Create model instance
model = UNet()


In [6]:
# Set the number of CPU threads
torch.set_num_threads(4)


In [7]:
# Dataset loading with optimizations
data_loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)


In [8]:

torch.manual_seed(0)  # Ensure reproducibility before training


<torch._C.Generator at 0x1e0adf48670>

In [9]:
# Train the model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)


In [11]:
# Train the model
for epoch in range(10):
    model.train()
    total_loss = 0
    with tqdm(data_loader, unit="batch") as tepoch:
        for batch in tepoch:
            images, masks = batch
            images, masks = images.to(device, non_blocking=True), masks.to(device, non_blocking=True)
            
            # Check shapes
            print(f"Image shape: {images.shape}")
            print(f"Mask shape: {masks.shape}")
            
            # Forward pass: compute predicted outputs by passing inputs to the model
            outputs = model(images)
            # Optional: If you decide to upsample the model output to match the target masks:
            outputs = torch.nn.functional.interpolate(outputs, size=(512, 512), mode='bilinear', align_corners=False)
            
            # Ensure masks have the same shape as outputs
            #masks = masks.unsqueeze(1)  # Add channel dimension if not present
            
            # Check shapes
            print(f"Output shape: {outputs.shape}")
            print(f"Masked shape: {masks.shape}")
            
            # Calculate loss
            loss = criterion(outputs, masks)
            # Backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # Perform a single optimization step (parameter update)
            optimizer.step()
            # Zero the gradients for the next iteration
            optimizer.zero_grad()

            total_loss += loss.item()
            tepoch.set_postfix(loss=loss.item())
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(data_loader)}")


  0%|                                                                                        | 0/10 [00:00<?, ?batch/s]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 10%|██████▉                                                              | 1/10 [01:04<09:42, 64.68s/batch, loss=0.92]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 20%|█████████████▌                                                      | 2/10 [02:06<08:25, 63.23s/batch, loss=0.921]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 30%|████████████████████▍                                               | 3/10 [03:09<07:21, 63.01s/batch, loss=0.919]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 40%|███████████████████████████▏                                        | 4/10 [04:12<06:16, 62.79s/batch, loss=0.915]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 50%|██████████████████████████████████                                  | 5/10 [05:15<05:15, 63.11s/batch, loss=0.914]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 60%|████████████████████████████████████████▊                           | 6/10 [06:20<04:14, 63.51s/batch, loss=0.914]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 70%|███████████████████████████████████████████████▌                    | 7/10 [07:22<03:09, 63.31s/batch, loss=0.907]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 80%|██████████████████████████████████████████████████████▍             | 8/10 [08:25<02:06, 63.14s/batch, loss=0.906]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 90%|█████████████████████████████████████████████████████████████▏      | 9/10 [09:29<01:03, 63.46s/batch, loss=0.902]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


100%|███████████████████████████████████████████████████████████████████| 10/10 [10:33<00:00, 63.34s/batch, loss=0.892]


Epoch 1, Loss: 0.9109952569007873


  0%|                                                                                        | 0/10 [00:00<?, ?batch/s]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 10%|██████▊                                                             | 1/10 [01:02<09:25, 62.81s/batch, loss=0.851]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 20%|█████████████▌                                                      | 2/10 [02:06<08:26, 63.35s/batch, loss=0.697]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 30%|████████████████████▍                                               | 3/10 [03:27<08:21, 71.59s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 40%|███████████████████████████▏                                        | 4/10 [04:47<07:28, 74.72s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 50%|██████████████████████████████████                                  | 5/10 [05:51<05:53, 70.80s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 60%|████████████████████████████████████████▊                           | 6/10 [06:54<04:33, 68.36s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 70%|███████████████████████████████████████████████▌                    | 7/10 [08:01<03:23, 67.87s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 80%|██████████████████████████████████████████████████████▍             | 8/10 [09:08<02:15, 67.64s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 90%|█████████████████████████████████████████████████████████████▏      | 9/10 [10:14<01:07, 67.06s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


100%|███████████████████████████████████████████████████████████████████| 10/10 [11:19<00:00, 67.95s/batch, loss=0.693]


Epoch 2, Loss: 0.7092311918735504


  0%|                                                                                        | 0/10 [00:00<?, ?batch/s]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 10%|██████▊                                                             | 1/10 [01:07<10:08, 67.66s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 20%|█████████████▌                                                      | 2/10 [02:14<08:59, 67.45s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 30%|████████████████████▍                                               | 3/10 [03:21<07:50, 67.19s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 40%|███████████████████████████▏                                        | 4/10 [04:25<06:35, 65.91s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 50%|██████████████████████████████████                                  | 5/10 [05:29<05:26, 65.29s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 60%|████████████████████████████████████████▊                           | 6/10 [06:37<04:23, 66.00s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 70%|███████████████████████████████████████████████▌                    | 7/10 [07:45<03:19, 66.60s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 80%|██████████████████████████████████████████████████████▍             | 8/10 [08:53<02:14, 67.08s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 90%|█████████████████████████████████████████████████████████████▏      | 9/10 [10:01<01:07, 67.28s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


100%|███████████████████████████████████████████████████████████████████| 10/10 [11:09<00:00, 66.94s/batch, loss=0.693]


Epoch 3, Loss: 0.6931473016738892


  0%|                                                                                        | 0/10 [00:00<?, ?batch/s]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 10%|██████▊                                                             | 1/10 [01:04<09:37, 64.19s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 20%|█████████████▌                                                      | 2/10 [02:12<08:52, 66.53s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 30%|████████████████████▍                                               | 3/10 [03:20<07:49, 67.07s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 40%|███████████████████████████▏                                        | 4/10 [04:27<06:43, 67.21s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 50%|██████████████████████████████████                                  | 5/10 [05:35<05:37, 67.47s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 60%|████████████████████████████████████████▊                           | 6/10 [06:43<04:30, 67.62s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 70%|███████████████████████████████████████████████▌                    | 7/10 [07:51<03:23, 67.88s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 80%|██████████████████████████████████████████████████████▍             | 8/10 [08:59<02:15, 67.99s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 90%|█████████████████████████████████████████████████████████████▏      | 9/10 [10:07<01:07, 67.95s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


100%|███████████████████████████████████████████████████████████████████| 10/10 [11:13<00:00, 67.35s/batch, loss=0.693]


Epoch 4, Loss: 0.6931473016738892


  0%|                                                                                        | 0/10 [00:00<?, ?batch/s]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 10%|██████▊                                                             | 1/10 [01:07<10:04, 67.20s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 20%|█████████████▌                                                      | 2/10 [02:14<08:58, 67.31s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 30%|████████████████████▍                                               | 3/10 [03:22<07:52, 67.52s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 40%|███████████████████████████▏                                        | 4/10 [04:30<06:45, 67.65s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 50%|██████████████████████████████████                                  | 5/10 [05:37<05:37, 67.47s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 60%|████████████████████████████████████████▊                           | 6/10 [06:44<04:29, 67.42s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 70%|███████████████████████████████████████████████▌                    | 7/10 [07:52<03:23, 67.67s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 80%|██████████████████████████████████████████████████████▍             | 8/10 [09:00<02:15, 67.55s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 90%|█████████████████████████████████████████████████████████████▏      | 9/10 [10:07<01:07, 67.59s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


100%|███████████████████████████████████████████████████████████████████| 10/10 [11:15<00:00, 67.56s/batch, loss=0.693]


Epoch 5, Loss: 0.6931473016738892


  0%|                                                                                        | 0/10 [00:00<?, ?batch/s]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 10%|██████▊                                                             | 1/10 [01:07<10:05, 67.28s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 20%|█████████████▌                                                      | 2/10 [02:14<08:57, 67.13s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 30%|████████████████████▍                                               | 3/10 [03:21<07:51, 67.30s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 40%|███████████████████████████▏                                        | 4/10 [04:30<06:46, 67.76s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 50%|██████████████████████████████████                                  | 5/10 [05:37<05:38, 67.73s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 60%|████████████████████████████████████████▊                           | 6/10 [06:44<04:29, 67.34s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 70%|███████████████████████████████████████████████▌                    | 7/10 [07:52<03:22, 67.48s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 80%|██████████████████████████████████████████████████████▍             | 8/10 [08:59<02:15, 67.53s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 90%|█████████████████████████████████████████████████████████████▏      | 9/10 [10:07<01:07, 67.61s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


100%|███████████████████████████████████████████████████████████████████| 10/10 [11:15<00:00, 67.55s/batch, loss=0.693]


Epoch 6, Loss: 0.6931473016738892


  0%|                                                                                        | 0/10 [00:00<?, ?batch/s]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 10%|██████▊                                                             | 1/10 [01:07<10:06, 67.34s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 20%|█████████████▌                                                      | 2/10 [02:12<08:48, 66.11s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 30%|████████████████████▍                                               | 3/10 [03:20<07:48, 66.89s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 40%|███████████████████████████▏                                        | 4/10 [04:27<06:41, 66.92s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 50%|██████████████████████████████████                                  | 5/10 [05:35<05:37, 67.42s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 60%|████████████████████████████████████████▊                           | 6/10 [06:42<04:29, 67.33s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 70%|███████████████████████████████████████████████▌                    | 7/10 [07:50<03:22, 67.42s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 80%|██████████████████████████████████████████████████████▍             | 8/10 [08:58<02:15, 67.65s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 90%|█████████████████████████████████████████████████████████████▏      | 9/10 [10:06<01:07, 67.73s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


100%|███████████████████████████████████████████████████████████████████| 10/10 [11:13<00:00, 67.33s/batch, loss=0.693]


Epoch 7, Loss: 0.6931473016738892


  0%|                                                                                        | 0/10 [00:00<?, ?batch/s]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 10%|██████▊                                                             | 1/10 [01:06<09:55, 66.21s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 20%|█████████████▌                                                      | 2/10 [02:13<08:56, 67.12s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 30%|████████████████████▍                                               | 3/10 [03:21<07:50, 67.24s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 40%|███████████████████████████▏                                        | 4/10 [04:28<06:43, 67.28s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 50%|██████████████████████████████████                                  | 5/10 [05:36<05:37, 67.44s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 60%|████████████████████████████████████████▊                           | 6/10 [06:44<04:30, 67.50s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 70%|███████████████████████████████████████████████▌                    | 7/10 [07:51<03:22, 67.50s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 80%|██████████████████████████████████████████████████████▍             | 8/10 [08:58<02:14, 67.38s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 90%|█████████████████████████████████████████████████████████████▏      | 9/10 [10:04<01:06, 66.85s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


100%|███████████████████████████████████████████████████████████████████| 10/10 [11:09<00:00, 66.93s/batch, loss=0.693]


Epoch 8, Loss: 0.6931473016738892


  0%|                                                                                        | 0/10 [00:00<?, ?batch/s]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 10%|██████▊                                                             | 1/10 [01:07<10:10, 67.88s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 20%|█████████████▌                                                      | 2/10 [02:15<09:01, 67.75s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 30%|████████████████████▍                                               | 3/10 [02:57<06:30, 55.82s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 40%|███████████████████████████▏                                        | 4/10 [03:38<05:00, 50.04s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 50%|██████████████████████████████████                                  | 5/10 [04:12<03:41, 44.25s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 60%|████████████████████████████████████████▊                           | 6/10 [04:45<02:41, 40.46s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 70%|███████████████████████████████████████████████▌                    | 7/10 [05:18<01:53, 37.93s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 80%|██████████████████████████████████████████████████████▍             | 8/10 [05:49<01:11, 35.99s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 90%|█████████████████████████████████████████████████████████████▏      | 9/10 [06:21<00:34, 34.72s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


100%|███████████████████████████████████████████████████████████████████| 10/10 [06:53<00:00, 41.32s/batch, loss=0.693]


Epoch 9, Loss: 0.6931473016738892


  0%|                                                                                        | 0/10 [00:00<?, ?batch/s]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 10%|██████▊                                                             | 1/10 [00:30<04:31, 30.22s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 20%|█████████████▌                                                      | 2/10 [01:00<04:01, 30.20s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 30%|████████████████████▍                                               | 3/10 [01:31<03:35, 30.82s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 40%|███████████████████████████▏                                        | 4/10 [02:01<03:02, 30.43s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 50%|██████████████████████████████████                                  | 5/10 [02:32<02:31, 30.36s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 60%|████████████████████████████████████████▊                           | 6/10 [03:03<02:03, 30.89s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 70%|███████████████████████████████████████████████▌                    | 7/10 [03:34<01:32, 30.82s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 80%|██████████████████████████████████████████████████████▍             | 8/10 [04:04<01:00, 30.48s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


 90%|█████████████████████████████████████████████████████████████▏      | 9/10 [04:33<00:29, 29.97s/batch, loss=0.693]

Image shape: torch.Size([8, 3, 512, 512])
Mask shape: torch.Size([8, 1, 512, 512])
Output shape: torch.Size([8, 1, 512, 512])
Masked shape: torch.Size([8, 1, 512, 512])


100%|███████████████████████████████████████████████████████████████████| 10/10 [05:05<00:00, 30.53s/batch, loss=0.693]

Epoch 10, Loss: 0.6931473016738892





In [12]:
# Save the model
torch.save(model.state_dict(), 'Retena_model_Binary.pth')
print("Model saved successfully!")

Model saved successfully!


In [13]:

# Assuming your dataset is for binary segmentation
def calculate_metrics(preds, targets):
    # Convert to numpy arrays for metric calculations
    preds = preds.cpu().numpy().flatten()
    targets = targets.cpu().numpy().flatten()
    
    # Binarize predictions
    preds_binary = (preds > 0.5).astype(int)

    # Calculate metrics
    precision = precision_score(targets, preds_binary)
    recall = recall_score(targets, preds_binary)
    f1 = f1_score(targets, preds_binary)
    iou = jaccard_score(targets, preds_binary)
    
    return precision, recall, f1, iou

In [14]:
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_dataset = RetinalDataset("D:/junotbok/10th week/Data/test/image", "D:/junotbok/10th week/Data/test/mask", transform=val_transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# Load your best saved model
model.load_state_dict(torch.load('Retena_model_Binary.pth'))
model.eval()  # Set the model to evaluation mode

# Evaluate the model
test_loss = 0
all_preds = []
all_targets = []

with torch.no_grad():
    for test_batch in test_loader:
        test_images, test_masks = test_batch
        test_images, test_masks = test_images.to(device), test_masks.to(device)
        test_outputs = model(test_images)
        test_outputs = F.interpolate(test_outputs, size=(512, 512), mode='bilinear', align_corners=False)
        
        # Compute loss
        t_loss = criterion(test_outputs, test_masks)
        test_loss += t_loss.item()

        # Collect predictions and targets
        all_preds.append(test_outputs)
        all_targets.append(test_masks)

avg_test_loss = test_loss / len(test_loader)



In [15]:
# Concatenate all batches
all_preds = torch.cat(all_preds)
all_targets = torch.cat(all_targets)

# Calculate metrics
precision, recall, f1, iou = calculate_metrics(all_preds, all_targets)

print(f"Test Loss: {avg_test_loss:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"IoU: {iou:.4f}")

ValueError: Classification metrics can't handle a mix of continuous and binary targets