In [1]:
import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torch import nn
from torch.nn import functional
from PIL import Image
import os
from torchvision import datasets, transforms, models
from collections import defaultdict
import albumentations as A
import numpy as np
import cv2

### Finding different shapes and channels of images in the dataset

In [2]:
folder_path = r'D:\Suchit\Breast-Cancer-Detection\Dataset_BUSI_with_GT'
image_shapes, mask_shapes = set(), set()
channels = set()
images, masks = [], []
for folder in os.listdir(folder_path):
    for image in os.listdir(os.path.join(folder_path, folder)):
        with Image.open(os.path.join(folder_path, folder, image)) as img:
            if 'mask' in image:
                mask_shapes.add(img.size)
                masks.append(os.path.join(folder_path, folder, image))
            else:
                image_shapes.add(img.size)
                images.append(os.path.join(folder_path, folder, image))
            channels.add(img.mode)
print(
    f'image shape length - {len(image_shapes)} and mask shape length - {len(mask_shapes)}')
print(f'number of channels - {len(channels)}')
print(f'channels - {channels}')

image shape length - 639 and mask shape length - 639
number of channels - 3
channels - {'1', 'RGB', 'RGBA'}


### Finding the average height and width of the images to resize them

In [3]:
folder_path = r'D:\Suchit\Breast-Cancer-Detection\Dataset_BUSI_with_GT'
height, width, num_samples = 0.0, 0.0, 0.0
for folder in os.listdir(folder_path):
    for image in os.listdir(os.path.join(folder_path, folder)):
        with Image.open(os.path.join(folder_path, folder, image)) as img:
            if 'mask' not in image:
                size = img.size
                height += size[1]
                width += size[0]
                num_samples += 1
height /= num_samples
width /= num_samples
print(f'average height = {height}')
print(f'average width = {width}')

average height = 501.4525641025641
average width = 615.6794871794872


Height and width is too large. Taking 256 * 256

### Finding the mean and standard deviation of the input images

In [4]:
def calculate_grayscale_stats(folder_path, resize_shape=(256, 256)):
    """Calculate mean and std for grayscale images"""
    print("Calculating grayscale statistics...")

    # Simple transform for grayscale statistics calculation
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(resize_shape),
        transforms.ToTensor()
    ])

    class GrayscaleImageData(Dataset):
        def __init__(self, images, transform):
            self.images = images
            self.transform = transform

        def __len__(self):
            return len(self.images)

        def __getitem__(self, index):
            img_path = self.images[index]
            image = Image.open(img_path)
            return self.transform(image)

    # Get image paths (exclude masks)
    images = []
    for folder in os.listdir(folder_path):
        for image in os.listdir(os.path.join(folder_path, folder)):
            if 'mask' not in image:
                images.append(os.path.join(folder_path, folder, image))

    print(f"Found {len(images)} images for statistics calculation")

    dataset = GrayscaleImageData(images, transform)
    loader = DataLoader(dataset, batch_size=64, shuffle=False)

    mean, std, num_samples = 0.0, 0.0, 0.0

    for data in loader:
        batch_size = data.size(0)
        data = data.view(batch_size, data.size(1), -1)
        mean += data.mean(2).sum(0)
        std += data.std(2).sum(0)
        num_samples += batch_size

    mean /= num_samples
    std /= num_samples

    return mean.item(), std.item()


# Step 2: Calculate grayscale statistics for your dataset
grayscale_mean, grayscale_std = calculate_grayscale_stats(folder_path)
print(f'Grayscale mean = {grayscale_mean:.4f}, std = {grayscale_std:.4f}')

Calculating grayscale statistics...
Found 780 images for statistics calculation
Grayscale mean = 0.3279, std = 0.1998


### Managing the dataset

In [5]:
from albumentations.pytorch import ToTensorV2


class Data(Dataset):
    def __init__(self, folder_path, transforms):
        super().__init__()
        self.images, self.masks, self.category = [], defaultdict(list), []
        classes = {'benign': 0, 'malignant': 1, 'normal': 2}
        self.transforms = transforms

        for folder in os.listdir(folder_path):
            for image in os.listdir(os.path.join(folder_path, folder)):
                img_path = os.path.join(folder_path, folder, image)
                if 'mask' in image:
                    self.masks[image[: image.index('_mask')]].append(img_path)
                else:
                    self.images.append(img_path)
                    self.category.append(classes[folder])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = self.images[index]
        key = os.path.basename(image_path)[
            : os.path.basename(image_path).index('.png')]
        mask_paths = self.masks[key]

        # Load image and masks as numpy arrays
        image = np.array(Image.open(image_path).convert('L'))
        masks = [np.array(Image.open(x).convert('L')) for x in mask_paths]

        # Combine masks
        mask = np.sum(np.stack(masks), axis=0)
        mask = np.clip(mask, 0, 1)  # ensure binary

        # Apply albumentations
        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask, self.category[index]


transform = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=[grayscale_mean], std=[grayscale_std]),
    ToTensorV2()
])

dataset = Data(folder_path, transform)

### Calculating positive weight to upweight the positive pixels in BCE loss

In [6]:
def estimate_pos_weight(dataset):
    loader = DataLoader(dataset= dataset, batch_size=16, shuffle= False)
    pos, neg = 0, 0
    with torch.no_grad():
        for _, mask, _ in loader:
            mask = mask.flatten()
            mask = (mask > 0)
            pos += mask.sum().item()
            neg += mask.numel() - mask.sum().item()
    if pos == 0:
        return torch.tensor(1.0)
    return torch.tensor(neg / pos)

# making counting transforms for counting the pos weights

count_transform = A.Compose([
    A.Resize(256, 256, interpolation= cv2.INTER_NEAREST, mask_interpolation= cv2.INTER_NEAREST),
    A.Normalize(mean = [grayscale_mean], std= [grayscale_std]),
    A.pytorch.ToTensorV2()
])
count_dataset = Data(folder_path= folder_path, transforms= count_transform)
pos_weight = estimate_pos_weight(count_dataset)
pos_weight = pos_weight.to('cuda')
print(f'Estimated positive weight = {pos_weight}')

Estimated positive weight = 11.776581764221191


### Model

In [7]:
def init_unet_kaiming(module):
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        if module.weight is not None:
            nn.init.kaiming_normal_(module.weight, nonlinearity= 'relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Linear):
        if module.weight is not None:
            nn.init.kaiming_normal_(module.weight, nonlinearity= 'relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm2d):
        if getattr(module, 'weight', None) is not None:
            nn.init.ones_(module.weight)
        if getattr(module, 'bias', None) is not None:
            nn.init.zeros_(module.bias)

def zero_init_seg_head(head):
    if isinstance(head, nn.Conv2d):
        nn.init.zeros_(head.weight)
        if head.bias is not None:
            nn.init.zeros_(head.bias)

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels,
                      kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.conv(x)


class UNetWithClassifier(nn.Module):
    def __init__(self, in_channels=1, seg_out_channels=1, num_classes=3, p_drop=0.2):
        super().__init__()
        # Encoder
        self.down1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)

        # Decoder (with concatenations)
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.dec4 = DoubleConv(1024, 512)   # 512 up + 512 skip

        self.up3 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.dec3 = DoubleConv(512, 256)    # 256 up + 256 skip

        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.dec2 = DoubleConv(256, 128)    # 128 up + 128 skip

        self.up1 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.dec1 = DoubleConv(128, 64)     # 64 up + 64 skip

        self.seg_head = nn.Conv2d(64, seg_out_channels, kernel_size=1)

        # Classification head on bottleneck features
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.cls_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 512, bias=True),
            nn.Dropout(p_drop),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )
        self.apply(init_unet_kaiming)
        zero_init_seg_head(self.seg_head)

    def forward(self, x):
        d1 = self.down1(x)
        p1 = self.pool1(d1)

        d2 = self.down2(p1)
        p2 = self.pool2(d2)

        d3 = self.down3(p2)
        p3 = self.pool3(d3)

        d4 = self.down4(p3)
        p4 = self.pool4(d4)

        bn = self.bottleneck(p4)

        u4 = self.up4(bn)
        u4 = torch.cat([u4, d4], dim=1)
        u4 = self.dec4(u4)

        u3 = self.up3(u4)
        u3 = torch.cat([u3, d3], dim=1)
        u3 = self.dec3(u3)

        u2 = self.up2(u3)
        u2 = torch.cat([u2, d2], dim=1)
        u2 = self.dec2(u2)

        u1 = self.up1(u2)
        u1 = torch.cat([u1, d1], dim=1)
        u1 = self.dec1(u1)

        seg_logits = self.seg_head(u1)
        cls_logits = self.cls_head(self.gap(bn))
        return seg_logits, cls_logits

### Loss function

In [8]:
class DiceLoss(nn.Module):
    def __init__(self, smooth = 1e-6):
        super().__init__()
        self.smooth = smooth
    def forward(self, output_mask, targets):
        targets = targets.float()
        probs = torch.sigmoid(output_mask)
        # no. of batches
        B = output_mask.shape[0]
        probs_flatten = probs.view(B, -1)
        targs_flatten = targets.view(B, -1)
        intersection = (probs_flatten * targs_flatten).sum(dim = 1)
        p_sum = probs_flatten.sum(dim = 1)
        t_sum = targs_flatten.sum(dim = 1)
        # identifying empty target images
        empty = (t_sum == 0)
        # standard dice for non empty
        dice_non_empty = (2 * intersection + self.smooth) / (p_sum + t_sum + self.smooth)
        # defining dice for empty target images
        # if predictions also empty -> 1, else -> (smooth) / (psum + smooth)
        dice_empty = torch.where((p_sum == 0), torch.ones_like(p_sum), (self.smooth) / (p_sum + self.smooth))
        dice = torch.where(empty, dice_empty, dice_non_empty)
        return 1 - dice.mean()
    
class MultiTaskLoss(nn.Module):
    def __init__(self, bce_pos_weight):
        super().__init__()
        self.seg_weight = 1.0
        self.cls_weight = 0.5
        self.bce_weight = 0.5
        self.dice_weight = 0.5
        self.bce = nn.BCEWithLogitsLoss(pos_weight= bce_pos_weight)
        self.dice = DiceLoss()
        self.crossentropy = nn.CrossEntropyLoss()
    def forward(self, output_mask, target_mask, output_class_logits, class_logits):
        loss_bce = self.bce(output_mask, target_mask)
        loss_dice = self.dice(output_mask, target_mask)
        
        loss_mask = self.bce_weight * loss_bce + self.dice_weight * loss_dice

        loss_classification = self.crossentropy(output_class_logits  , class_logits.long())

        total = self.seg_weight * loss_mask + self.cls_weight * loss_classification

        return total



### Initializations

In [9]:
# model
model = UNetWithClassifier()
model = model.to('cuda')

# loss function
criterion = MultiTaskLoss(pos_weight)

# optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr= 0.001, weight_decay= 0.0001)

# train test split
n_total = len(dataset)
n_val = int(0.2 * n_total)
n_train = n_total - n_val
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [n_train, n_val], generator= torch.Generator().manual_seed(42))

# weighted random sampler
labels = np.array([dataset.category[i] for i in train_dataset.indices])
class_counts = np.bincount(labels, minlength= 3)
class_weights = class_counts.sum() / (len(class_counts) * class_counts.clip(min= 1))
sample_weights = [class_weights[x] for x in labels]

sampler = WeightedRandomSampler(weights= torch.tensor(sample_weights), num_samples= len(sample_weights), replacement= True)

# DataLoaders
train_loader = DataLoader(dataset= train_dataset, batch_size= 16, sampler= sampler, pin_memory= True)
test_loader = DataLoader(dataset= test_dataset, batch_size= 16, shuffle= False, pin_memory= True)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer= optimizer, mode= 'min', factor = 0.1, patience= 10, cooldown= 4)

### Training

In [10]:
from tqdm import trange

train_loss, test_loss = [], []

num_epochs = 500
patience = 20
best_val_loss = float('inf')
epochs_no_improve = 0

for epoch in trange(num_epochs, desc="Epochs"):
    model.train()
    running_loss = 0.0
    for images, masks, labels in train_loader:
        images = images.to('cuda')
        masks = masks.to('cuda').view(-1, 1, 256, 256).float()
        labels = labels.to('cuda')

        optimizer.zero_grad()
        mask_pred, output_class_logits = model(images)
        loss = criterion(mask_pred, masks, output_class_logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, masks, labels in test_loader:
            images = images.to('cuda')
            masks = masks.to('cuda').view(-1, 1, 256, 256).float()
            labels = labels.to('cuda')
            mask_pred, output_class_logits = model(images)
            loss = criterion(mask_pred, masks, output_class_logits, labels)
            val_loss += loss.item()
            output_classes = output_class_logits.argmax(dim= 1)
            correct += (output_classes == labels).sum().item()
            total += labels.size(0)
    val_loss /= len(test_loader)

    scheduler.step(val_loss)

    val_acc = 100 * correct / total

    print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {avg_loss:.6f} | Val Loss: {val_loss:.6f} | Val Acc: {val_acc:.2f}%")
    
    train_loss.append(avg_loss)
    test_loss.append(val_loss)

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), 'model.pt')
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

Epochs:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch [1/500] Train Loss: 1.700441 | Val Loss: 1.677556 | Val Acc: 47.44%


Epochs:   0%|          | 1/500 [00:36<5:00:13, 36.10s/it]

Epoch [2/500] Train Loss: 1.464693 | Val Loss: 1.584409 | Val Acc: 54.49%


Epochs:   0%|          | 2/500 [01:11<4:56:08, 35.68s/it]

Epoch [3/500] Train Loss: 1.400320 | Val Loss: 1.578504 | Val Acc: 37.18%


Epochs:   1%|          | 3/500 [01:46<4:52:25, 35.30s/it]

Epoch [4/500] Train Loss: 1.359732 | Val Loss: 1.498771 | Val Acc: 51.92%


Epochs:   1%|          | 4/500 [02:21<4:50:02, 35.08s/it]

Epoch [5/500] Train Loss: 1.234397 | Val Loss: 1.299418 | Val Acc: 68.59%


Epochs:   1%|          | 5/500 [02:55<4:47:41, 34.87s/it]

Epoch [6/500] Train Loss: 1.276256 | Val Loss: 1.198629 | Val Acc: 73.72%


Epochs:   1%|▏         | 7/500 [04:04<4:44:36, 34.64s/it]

Epoch [7/500] Train Loss: 1.211137 | Val Loss: 1.947131 | Val Acc: 58.97%


Epochs:   2%|▏         | 8/500 [04:38<4:43:10, 34.53s/it]

Epoch [8/500] Train Loss: 1.174480 | Val Loss: 1.267398 | Val Acc: 73.72%


Epochs:   2%|▏         | 9/500 [05:13<4:41:54, 34.45s/it]

Epoch [9/500] Train Loss: 1.195057 | Val Loss: 1.819574 | Val Acc: 44.23%
Epoch [10/500] Train Loss: 1.121153 | Val Loss: 1.175433 | Val Acc: 71.79%


Epochs:   2%|▏         | 10/500 [05:48<4:42:26, 34.59s/it]

Epoch [11/500] Train Loss: 1.160682 | Val Loss: 1.142873 | Val Acc: 76.92%


Epochs:   2%|▏         | 12/500 [06:56<4:40:27, 34.48s/it]

Epoch [12/500] Train Loss: 1.072226 | Val Loss: 1.681926 | Val Acc: 56.41%


Epochs:   3%|▎         | 13/500 [07:31<4:39:20, 34.42s/it]

Epoch [13/500] Train Loss: 1.085912 | Val Loss: 1.145668 | Val Acc: 69.23%
Epoch [14/500] Train Loss: 1.082605 | Val Loss: 1.092855 | Val Acc: 73.72%


Epochs:   3%|▎         | 15/500 [08:40<4:38:23, 34.44s/it]

Epoch [15/500] Train Loss: 1.018676 | Val Loss: 1.158118 | Val Acc: 75.64%


Epochs:   3%|▎         | 16/500 [09:14<4:37:28, 34.40s/it]

Epoch [16/500] Train Loss: 1.018000 | Val Loss: 1.151738 | Val Acc: 71.79%


Epochs:   3%|▎         | 17/500 [09:48<4:35:52, 34.27s/it]

Epoch [17/500] Train Loss: 1.121794 | Val Loss: 1.883450 | Val Acc: 51.28%
Epoch [18/500] Train Loss: 1.010612 | Val Loss: 1.029896 | Val Acc: 79.49%


Epochs:   4%|▍         | 19/500 [10:56<4:34:26, 34.23s/it]

Epoch [19/500] Train Loss: 0.973401 | Val Loss: 1.113977 | Val Acc: 76.92%


Epochs:   4%|▍         | 20/500 [11:31<4:34:02, 34.26s/it]

Epoch [20/500] Train Loss: 1.016933 | Val Loss: 1.116111 | Val Acc: 73.08%


Epochs:   4%|▍         | 21/500 [12:05<4:33:20, 34.24s/it]

Epoch [21/500] Train Loss: 0.979032 | Val Loss: 1.430634 | Val Acc: 50.00%


Epochs:   4%|▍         | 22/500 [12:39<4:32:49, 34.25s/it]

Epoch [22/500] Train Loss: 0.918728 | Val Loss: 1.151081 | Val Acc: 73.08%


Epochs:   5%|▍         | 23/500 [13:14<4:32:45, 34.31s/it]

Epoch [23/500] Train Loss: 0.998915 | Val Loss: 1.129789 | Val Acc: 75.64%


Epochs:   5%|▍         | 24/500 [13:48<4:32:05, 34.30s/it]

Epoch [24/500] Train Loss: 0.926486 | Val Loss: 1.207250 | Val Acc: 67.31%


Epochs:   5%|▌         | 25/500 [14:22<4:31:46, 34.33s/it]

Epoch [25/500] Train Loss: 0.895685 | Val Loss: 1.494757 | Val Acc: 55.77%


Epochs:   5%|▌         | 26/500 [14:56<4:30:51, 34.28s/it]

Epoch [26/500] Train Loss: 0.940578 | Val Loss: 1.604716 | Val Acc: 62.82%


Epochs:   5%|▌         | 27/500 [15:31<4:30:37, 34.33s/it]

Epoch [27/500] Train Loss: 0.948557 | Val Loss: 1.103533 | Val Acc: 73.72%


Epochs:   6%|▌         | 28/500 [16:05<4:29:18, 34.23s/it]

Epoch [28/500] Train Loss: 0.898022 | Val Loss: 1.227427 | Val Acc: 64.74%
Epoch [29/500] Train Loss: 0.846633 | Val Loss: 0.996716 | Val Acc: 76.28%


Epochs:   6%|▌         | 30/500 [17:13<4:28:19, 34.25s/it]

Epoch [30/500] Train Loss: 0.900557 | Val Loss: 1.079801 | Val Acc: 76.92%
Epoch [31/500] Train Loss: 0.843697 | Val Loss: 0.987540 | Val Acc: 81.41%


Epochs:   6%|▋         | 32/500 [18:22<4:27:40, 34.32s/it]

Epoch [32/500] Train Loss: 0.873842 | Val Loss: 1.567691 | Val Acc: 70.51%


Epochs:   7%|▋         | 33/500 [18:57<4:27:55, 34.42s/it]

Epoch [33/500] Train Loss: 0.828916 | Val Loss: 1.081380 | Val Acc: 75.64%


Epochs:   7%|▋         | 34/500 [19:31<4:26:59, 34.38s/it]

Epoch [34/500] Train Loss: 0.838750 | Val Loss: 1.226299 | Val Acc: 62.18%


Epochs:   7%|▋         | 35/500 [20:06<4:26:37, 34.40s/it]

Epoch [35/500] Train Loss: 0.865840 | Val Loss: 1.075177 | Val Acc: 74.36%


Epochs:   7%|▋         | 36/500 [20:40<4:26:24, 34.45s/it]

Epoch [36/500] Train Loss: 0.827207 | Val Loss: 1.018391 | Val Acc: 78.21%
Epoch [37/500] Train Loss: 0.793512 | Val Loss: 0.968278 | Val Acc: 82.05%


Epochs:   8%|▊         | 38/500 [21:49<4:25:32, 34.49s/it]

Epoch [38/500] Train Loss: 0.743244 | Val Loss: 1.091124 | Val Acc: 77.56%


Epochs:   8%|▊         | 39/500 [22:24<4:24:35, 34.44s/it]

Epoch [39/500] Train Loss: 0.733150 | Val Loss: 1.090485 | Val Acc: 76.28%


Epochs:   8%|▊         | 40/500 [22:58<4:23:11, 34.33s/it]

Epoch [40/500] Train Loss: 0.780034 | Val Loss: 1.024496 | Val Acc: 78.85%


Epochs:   8%|▊         | 41/500 [23:32<4:22:29, 34.31s/it]

Epoch [41/500] Train Loss: 0.775828 | Val Loss: 1.114881 | Val Acc: 76.28%


Epochs:   8%|▊         | 42/500 [24:06<4:22:10, 34.35s/it]

Epoch [42/500] Train Loss: 0.746993 | Val Loss: 1.394185 | Val Acc: 53.85%


Epochs:   9%|▊         | 43/500 [24:41<4:21:43, 34.36s/it]

Epoch [43/500] Train Loss: 0.737345 | Val Loss: 1.510013 | Val Acc: 65.38%


Epochs:   9%|▉         | 44/500 [25:15<4:20:44, 34.31s/it]

Epoch [44/500] Train Loss: 0.790830 | Val Loss: 1.124319 | Val Acc: 73.08%


Epochs:   9%|▉         | 45/500 [25:49<4:20:25, 34.34s/it]

Epoch [45/500] Train Loss: 0.710354 | Val Loss: 1.077346 | Val Acc: 82.69%


Epochs:   9%|▉         | 46/500 [26:24<4:19:27, 34.29s/it]

Epoch [46/500] Train Loss: 0.689717 | Val Loss: 1.076236 | Val Acc: 81.41%


Epochs:   9%|▉         | 47/500 [26:58<4:19:09, 34.33s/it]

Epoch [47/500] Train Loss: 0.729000 | Val Loss: 1.057510 | Val Acc: 77.56%


Epochs:  10%|▉         | 48/500 [27:32<4:17:56, 34.24s/it]

Epoch [48/500] Train Loss: 0.737779 | Val Loss: 1.012191 | Val Acc: 74.36%
Epoch [49/500] Train Loss: 0.693318 | Val Loss: 0.924773 | Val Acc: 79.49%


Epochs:  10%|▉         | 49/500 [28:07<4:18:10, 34.35s/it]

Epoch [50/500] Train Loss: 0.663049 | Val Loss: 0.914933 | Val Acc: 82.05%


Epochs:  10%|█         | 50/500 [28:41<4:18:20, 34.45s/it]

Epoch [51/500] Train Loss: 0.635165 | Val Loss: 0.913651 | Val Acc: 83.33%


Epochs:  10%|█         | 52/500 [29:50<4:16:32, 34.36s/it]

Epoch [52/500] Train Loss: 0.636922 | Val Loss: 0.914712 | Val Acc: 82.69%


Epochs:  11%|█         | 53/500 [30:24<4:16:07, 34.38s/it]

Epoch [53/500] Train Loss: 0.620401 | Val Loss: 0.943523 | Val Acc: 82.69%


Epochs:  11%|█         | 54/500 [30:59<4:15:38, 34.39s/it]

Epoch [54/500] Train Loss: 0.610525 | Val Loss: 0.924923 | Val Acc: 82.69%
Epoch [55/500] Train Loss: 0.607270 | Val Loss: 0.903259 | Val Acc: 83.33%


Epochs:  11%|█         | 56/500 [32:07<4:14:31, 34.40s/it]

Epoch [56/500] Train Loss: 0.603466 | Val Loss: 0.946217 | Val Acc: 83.33%


Epochs:  11%|█▏        | 57/500 [32:42<4:13:21, 34.31s/it]

Epoch [57/500] Train Loss: 0.617593 | Val Loss: 0.927084 | Val Acc: 82.05%


Epochs:  12%|█▏        | 58/500 [33:16<4:12:29, 34.28s/it]

Epoch [58/500] Train Loss: 0.604987 | Val Loss: 0.910957 | Val Acc: 82.69%


Epochs:  12%|█▏        | 59/500 [33:50<4:11:56, 34.28s/it]

Epoch [59/500] Train Loss: 0.586503 | Val Loss: 0.938828 | Val Acc: 83.97%


Epochs:  12%|█▏        | 60/500 [34:24<4:10:58, 34.22s/it]

Epoch [60/500] Train Loss: 0.588000 | Val Loss: 0.952171 | Val Acc: 83.97%


Epochs:  12%|█▏        | 61/500 [34:58<4:10:23, 34.22s/it]

Epoch [61/500] Train Loss: 0.573666 | Val Loss: 0.940851 | Val Acc: 85.26%


Epochs:  12%|█▏        | 62/500 [35:33<4:09:59, 34.25s/it]

Epoch [62/500] Train Loss: 0.580069 | Val Loss: 0.954661 | Val Acc: 83.33%


Epochs:  13%|█▎        | 63/500 [36:07<4:09:11, 34.21s/it]

Epoch [63/500] Train Loss: 0.565297 | Val Loss: 0.934039 | Val Acc: 83.33%


Epochs:  13%|█▎        | 64/500 [36:41<4:08:34, 34.21s/it]

Epoch [64/500] Train Loss: 0.571838 | Val Loss: 0.972566 | Val Acc: 84.62%


Epochs:  13%|█▎        | 65/500 [37:15<4:08:01, 34.21s/it]

Epoch [65/500] Train Loss: 0.569971 | Val Loss: 1.007251 | Val Acc: 84.62%


Epochs:  13%|█▎        | 66/500 [37:49<4:07:19, 34.19s/it]

Epoch [66/500] Train Loss: 0.549311 | Val Loss: 0.981265 | Val Acc: 83.97%


Epochs:  13%|█▎        | 67/500 [38:24<4:06:50, 34.20s/it]

Epoch [67/500] Train Loss: 0.559916 | Val Loss: 0.973305 | Val Acc: 84.62%


Epochs:  14%|█▎        | 68/500 [38:58<4:06:13, 34.20s/it]

Epoch [68/500] Train Loss: 0.545910 | Val Loss: 0.980693 | Val Acc: 85.26%


Epochs:  14%|█▍        | 69/500 [39:32<4:05:56, 34.24s/it]

Epoch [69/500] Train Loss: 0.569734 | Val Loss: 0.970382 | Val Acc: 83.97%


Epochs:  14%|█▍        | 70/500 [40:07<4:06:01, 34.33s/it]

Epoch [70/500] Train Loss: 0.593091 | Val Loss: 0.966506 | Val Acc: 82.69%


Epochs:  14%|█▍        | 71/500 [40:41<4:06:05, 34.42s/it]

Epoch [71/500] Train Loss: 0.536049 | Val Loss: 0.977694 | Val Acc: 83.97%


Epochs:  14%|█▍        | 72/500 [41:16<4:05:55, 34.48s/it]

Epoch [72/500] Train Loss: 0.544787 | Val Loss: 0.965431 | Val Acc: 84.62%


Epochs:  15%|█▍        | 73/500 [41:51<4:05:39, 34.52s/it]

Epoch [73/500] Train Loss: 0.563124 | Val Loss: 0.986946 | Val Acc: 83.97%


Epochs:  15%|█▍        | 74/500 [42:25<4:05:16, 34.55s/it]

Epoch [74/500] Train Loss: 0.576185 | Val Loss: 0.972137 | Val Acc: 83.97%


Epochs:  15%|█▍        | 74/500 [43:00<4:07:32, 34.87s/it]

Epoch [75/500] Train Loss: 0.567214 | Val Loss: 0.986857 | Val Acc: 85.26%
Early stopping at epoch 75





### Saving loss lists

In [11]:
import pandas as pd
train_pd = pd.DataFrame(train_loss)
train_pd.to_csv('train_loss.csv', index= False, header= False)
test_pd = pd.DataFrame(test_loss)
test_pd.to_csv('test_loss.csv', index= False, header= False)

In [12]:
from torchmetrics.classification import MulticlassPrecision

model.eval()
preds, targets = [], []

with torch.no_grad():
    for images, masks, labels in test_loader:
        images = images.to('cuda')
        masks = masks.to('cuda').view(-1, 1, 256, 256).float()
        labels = labels.to('cuda')

        seg_logits, cls_logits = model(images)
        pred_cls = cls_logits.argmax(dim=1)  # (B,)

        preds.append(pred_cls.cpu())
        targets.append(labels.cpu())

preds = torch.cat(preds)      # shape (N,)
targets = torch.cat(targets)  # shape (N,)

metric = MulticlassPrecision(num_classes=3, average="macro")
precision_macro = metric(preds, targets).item()
print("Macro Precision:", precision_macro)

Macro Precision: 0.8583203554153442


In [13]:
from torchmetrics.classification import MulticlassPrecision

prec_macro = MulticlassPrecision(num_classes=3, average="macro")(preds, targets).item()
prec_micro = MulticlassPrecision(num_classes=3, average="micro")(preds, targets).item()
prec_weighted = MulticlassPrecision(num_classes=3, average="weighted")(preds, targets).item()
print(f"macro={prec_macro:.4f}  micro={prec_micro:.4f}  weighted={prec_weighted:.4f}")


macro=0.8583  micro=0.8462  weighted=0.8479


In [14]:
statedict = torch.load('model.pt')
model = UNetWithClassifier()
model = model.to('cuda')
model.load_state_dict(statedict)
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassF1Score,
)
from tabulate import tabulate

def evaluate_metrics(model, dataloader, device="cuda", num_classes=3):
    model.eval()
    preds, targets = [], []

    with torch.no_grad():
        for images, masks, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            _, cls_logits = model(images)
            pred_cls = cls_logits.argmax(dim=1)

            preds.append(pred_cls.cpu())
            targets.append(labels.cpu())

    preds = torch.cat(preds)
    targets = torch.cat(targets)

    # Macro / Micro / Weighted
    metrics = {
        "Accuracy": MulticlassAccuracy(num_classes=num_classes, average="micro")(preds, targets).item(),
        "Precision (macro)": MulticlassPrecision(num_classes=num_classes, average="macro")(preds, targets).item(),
        "Recall (macro)": MulticlassRecall(num_classes=num_classes, average="macro")(preds, targets).item(),
        "F1 (macro)": MulticlassF1Score(num_classes=num_classes, average="macro")(preds, targets).item(),
    }

    # Per-class metrics
    per_class_prec = MulticlassPrecision(num_classes=num_classes, average=None)(preds, targets)
    per_class_rec  = MulticlassRecall(num_classes=num_classes, average=None)(preds, targets)
    per_class_f1   = MulticlassF1Score(num_classes=num_classes, average=None)(preds, targets)

    # Print summary table
    summary_table = [[k, f"{v:.4f}"] for k, v in metrics.items()]
    print(tabulate(summary_table, headers=["Metric", "Score"], tablefmt="fancy_grid"))

    # Print per-class table
    class_table = []
    for i in range(num_classes):
        class_table.append([i, f"{per_class_prec[i]:.4f}", f"{per_class_rec[i]:.4f}", f"{per_class_f1[i]:.4f}"])
    print(tabulate(class_table, headers=["Class", "Precision", "Recall", "F1"], tablefmt="fancy_grid"))
evaluate_metrics(model, test_loader)

╒═══════════════════╤═════════╕
│ Metric            │   Score │
╞═══════════════════╪═════════╡
│ Accuracy          │  0.8269 │
├───────────────────┼─────────┤
│ Precision (macro) │  0.8335 │
├───────────────────┼─────────┤
│ Recall (macro)    │  0.83   │
├───────────────────┼─────────┤
│ F1 (macro)        │  0.8316 │
╘═══════════════════╧═════════╛
╒═════════╤═════════════╤══════════╤════════╕
│   Class │   Precision │   Recall │     F1 │
╞═════════╪═════════════╪══════════╪════════╡
│       0 │      0.8333 │   0.8333 │ 0.8333 │
├─────────┼─────────────┼──────────┼────────┤
│       1 │      0.7609 │   0.7778 │ 0.7692 │
├─────────┼─────────────┼──────────┼────────┤
│       2 │      0.9062 │   0.8788 │ 0.8923 │
╘═════════╧═════════════╧══════════╧════════╛
