In [1]:
import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torch import nn
from torch.nn import functional
from PIL import Image
import os
from torchvision import datasets, transforms, models
from collections import defaultdict
import albumentations as A
import numpy as np
import cv2

### Finding different shapes and channels of images in the dataset

In [2]:
folder_path = r'D:\Suchit\Breast-Cancer-Detection\Dataset_BUSI_with_GT'
image_shapes, mask_shapes = set(), set()
channels = set()
images, masks = [], []
for folder in os.listdir(folder_path):
    for image in os.listdir(os.path.join(folder_path, folder)):
        with Image.open(os.path.join(folder_path, folder, image)) as img:
            if 'mask' in image:
                mask_shapes.add(img.size)
                masks.append(os.path.join(folder_path, folder, image))
            else:
                image_shapes.add(img.size)
                images.append(os.path.join(folder_path, folder, image))
            channels.add(img.mode)
print(
    f'image shape length - {len(image_shapes)} and mask shape length - {len(mask_shapes)}')
print(f'number of channels - {len(channels)}')
print(f'channels - {channels}')

image shape length - 639 and mask shape length - 639
number of channels - 3
channels - {'RGBA', 'RGB', '1'}


### Finding the average height and width of the images to resize them

In [3]:
folder_path = r'D:\Suchit\Breast-Cancer-Detection\Dataset_BUSI_with_GT'
height, width, num_samples = 0.0, 0.0, 0.0
for folder in os.listdir(folder_path):
    for image in os.listdir(os.path.join(folder_path, folder)):
        with Image.open(os.path.join(folder_path, folder, image)) as img:
            if 'mask' not in image:
                size = img.size
                height += size[1]
                width += size[0]
                num_samples += 1
height /= num_samples
width /= num_samples
print(f'average height = {height}')
print(f'average width = {width}')

average height = 501.4525641025641
average width = 615.6794871794872


Height and width is too large. Taking 256 * 256

### Finding the mean and standard deviation of the input images

In [4]:
def to_rgb_from_pil(x: np.ndarray) -> np.ndarray:
    if x.ndim == 2:
        return cv2.cvtColor(x, cv2.COLOR_GRAY2RGB)
    if x.ndim == 3:
        c = x.shape[2]
        if c == 1:
            return cv2.cvtColor(x, cv2.COLOR_GRAY2RGB)
        elif c == 3:
            return x  # already RGB
        elif c == 4:
            # RGBA -> RGB (drops alpha)
            return cv2.cvtColor(x, cv2.COLOR_RGBA2RGB)
    return x[..., :3]

In [5]:
from torchvision.transforms import InterpolationMode

def calculate_rgb_stats(folder_path, resize_shape=(256, 256), batch_size=64, num_workers=0):
    print("Calculating RGB statistics...")

    transform = transforms.Compose([
        transforms.Resize(resize_shape, interpolation=InterpolationMode.BILINEAR),
        transforms.ToTensor(),  # -> [0,1], shape (C,H,W) with C=3 since we .convert('RGB')
    ])

    class ImageData(Dataset):
        def __init__(self, images, transform):
            self.images = images
            self.transform = transform

        def __len__(self):
            return len(self.images)

        def __getitem__(self, index):
            img_path = self.images[index]
            # Force RGB: handles grayscale (L) and RGBA (drops alpha) gracefully
            image = Image.open(img_path).convert("RGB")
            return self.transform(image)

    # Collect image paths (exclude anything with 'mask' in the filename)
    images = []
    for folder in os.listdir(folder_path):
        full = os.path.join(folder_path, folder)
        if not os.path.isdir(full):
            continue
        for fname in os.listdir(full):
            if 'mask' in fname.lower():
                continue
            images.append(os.path.join(full, fname))

    if len(images) == 0:
        raise RuntimeError("No images found (after excluding masks). Check your folder structure.")

    print(f"Found {len(images)} images for statistics calculation")

    dataset = ImageData(images, transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    # Numerically stable: use sums and squared sums
    sum_ = torch.zeros(3)
    sum_sq = torch.zeros(3)
    n_pixels = 0

    for data in loader:
        # data: (B, 3, H, W)
        b, c, h, w = data.shape
        n = b * h * w
        sum_ += data.sum(dim=[0, 2, 3])
        sum_sq += (data ** 2).sum(dim=[0, 2, 3])
        n_pixels += n

    mean = sum_ / n_pixels
    var = (sum_sq / n_pixels) - mean ** 2
    std = var.clamp(min=0).sqrt()

    return mean, std

# Example usage
rgb_mean, rgb_std = calculate_rgb_stats(folder_path, resize_shape=(256, 256))
print(f"RGB mean  = {rgb_mean.tolist()}")
print(f"RGB std   = {rgb_std.tolist()}")
print(f"Per-channel formatted -> "
      f"mean: R {rgb_mean[0]:.4f}, G {rgb_mean[1]:.4f}, B {rgb_mean[2]:.4f} | "
      f"std:  R {rgb_std[0]:.4f}, G {rgb_std[1]:.4f}, B {rgb_std[2]:.4f}")

Calculating RGB statistics...
Found 780 images for statistics calculation
RGB mean  = [0.32795798778533936, 0.3279508650302887, 0.32790568470954895]
RGB std   = [0.21965765953063965, 0.21965673565864563, 0.2196434587240219]
Per-channel formatted -> mean: R 0.3280, G 0.3280, B 0.3279 | std:  R 0.2197, G 0.2197, B 0.2196


### Managing the dataset

In [8]:
from albumentations.pytorch import ToTensorV2


class Data(Dataset):
    def __init__(self, folder_path, transforms):
        super().__init__()
        self.images, self.masks, self.category = [], defaultdict(list), []
        classes = {'benign': 0, 'malignant': 1, 'normal': 2}
        self.transforms = transforms

        for folder in os.listdir(folder_path):
            for image in os.listdir(os.path.join(folder_path, folder)):
                img_path = os.path.join(folder_path, folder, image)
                if 'mask' in image:
                    self.masks[image[: image.index('_mask')]].append(img_path)
                else:
                    self.images.append(img_path)
                    self.category.append(classes[folder])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = self.images[index]
        key = os.path.basename(image_path)[
            : os.path.basename(image_path).index('.png')]
        mask_paths = self.masks[key]

        # Load image and masks as numpy arrays
        image = np.array(Image.open(image_path).convert('RGB'))
        masks = [np.array(Image.open(x).convert('RGB')) for x in mask_paths]

        # Combine masks
        mask = np.sum(np.stack(masks), axis=0)
        mask = np.clip(mask, 0, 1)  # ensure binary

        # Apply albumentations
        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask, self.category[index]


transform = transforms.Compose([
        transforms.Resize((256, 256), interpolation=InterpolationMode.BILINEAR),
        transforms.ToTensor(),  # -> [0,1], shape (C,H,W) with C=3 since we .convert('RGB')
    ])

dataset = Data(folder_path, transform)

### Calculating positive weight to upweight the positive pixels in BCE loss

In [11]:
def estimate_pos_weight(dataset):
    loader = DataLoader(dataset= dataset, batch_size=16, shuffle= False)
    pos, neg = 0, 0
    with torch.no_grad():
        for _, mask, _ in loader:
            mask = mask.flatten()
            mask = (mask > 0)
            pos += mask.sum().item()
            neg += mask.numel() - mask.sum().item()
    if pos == 0:
        return torch.tensor(1.0)
    return torch.tensor(neg / pos)

# making counting transforms for counting the pos weights

count_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation= cv2.INTER_NEAREST),
    transforms.Normalize(mean = rgb_mean.tolist(), std= rgb_std.tolist()),
    transforms.ToTensor()
])
count_dataset = Data(folder_path= folder_path, transforms= count_transform)
pos_weight = estimate_pos_weight(count_dataset)
pos_weight = pos_weight.to('cuda')
print(f'Estimated positive weight = {pos_weight}')

TypeError: Compose.__call__() got an unexpected keyword argument 'image'

### Model

In [None]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels,
                      kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.conv(x)


class UNetWithClassifier(nn.Module):
    def __init__(self, in_channels=3, seg_out_channels=3, num_classes=3, p_drop=0.2):
        super().__init__()
        # Encoder
        self.down1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)

        # Decoder (with concatenations)
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.dec4 = DoubleConv(1024, 512)   # 512 up + 512 skip

        self.up3 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.dec3 = DoubleConv(512, 256)    # 256 up + 256 skip

        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.dec2 = DoubleConv(256, 128)    # 128 up + 128 skip

        self.up1 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.dec1 = DoubleConv(128, 64)     # 64 up + 64 skip

        self.seg_head = nn.Conv2d(64, seg_out_channels, kernel_size=1)

        # Classification head on bottleneck features
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.cls_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 512, bias=True),
            nn.Dropout(p_drop),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        d1 = self.down1(x)
        p1 = self.pool1(d1)

        d2 = self.down2(p1)
        p2 = self.pool2(d2)

        d3 = self.down3(p2)
        p3 = self.pool3(d3)

        d4 = self.down4(p3)
        p4 = self.pool4(d4)

        bn = self.bottleneck(p4)

        u4 = self.up4(bn)
        u4 = torch.cat([u4, d4], dim=1)
        u4 = self.dec4(u4)

        u3 = self.up3(u4)
        u3 = torch.cat([u3, d3], dim=1)
        u3 = self.dec3(u3)

        u2 = self.up2(u3)
        u2 = torch.cat([u2, d2], dim=1)
        u2 = self.dec2(u2)

        u1 = self.up1(u2)
        u1 = torch.cat([u1, d1], dim=1)
        u1 = self.dec1(u1)

        seg_logits = self.seg_head(u1)
        cls_logits = self.cls_head(self.gap(bn))
        return seg_logits, cls_logits

### Loss function

In [None]:
class DiceLoss(nn.Module):
    def __init__(self, smooth = 1e-6):
        super().__init__()
        self.smooth = smooth
    def forward(self, output_mask, targets):
        targets = targets.float()
        probs = torch.sigmoid(output_mask)
        # no. of batches
        B = output_mask.shape[0]
        probs_flatten = probs.view(B, -1)
        targs_flatten = targets.view(B, -1)
        intersection = (probs_flatten * targs_flatten).sum(dim = 1)
        p_sum = probs_flatten.sum(dim = 1)
        t_sum = targs_flatten.sum(dim = 1)
        # identifying empty target images
        empty = (t_sum == 0)
        # standard dice for non empty
        dice_non_empty = (2 * intersection + self.smooth) / (p_sum + t_sum + self.smooth)
        # defining dice for empty target images
        # if predictions also empty -> 1, else -> (smooth) / (psum + smooth)
        dice_empty = torch.where((p_sum == 0), torch.ones_like(p_sum), (self.smooth) / (p_sum + self.smooth))
        dice = torch.where(empty, dice_empty, dice_non_empty)
        return 1 - dice.mean()
    
class MultiTaskLoss(nn.Module):
    def __init__(self, bce_pos_weight):
        super().__init__()
        self.seg_weight = 1.0
        self.cls_weight = 0.5
        self.bce_weight = 0.5
        self.dice_weight = 0.5
        self.bce = nn.BCEWithLogitsLoss(pos_weight= bce_pos_weight)
        self.dice = DiceLoss()
        self.crossentropy = nn.CrossEntropyLoss()
    def forward(self, output_mask, target_mask, output_class_logits, class_logits):
        loss_bce = self.bce(output_mask, target_mask)
        loss_dice = self.dice(output_mask, target_mask)
        
        loss_mask = self.bce_weight * loss_bce + self.dice_weight * loss_dice

        loss_classification = self.crossentropy(output_class_logits  , class_logits.long())

        total = self.seg_weight * loss_mask + self.cls_weight * loss_classification

        return total



### Initializations

In [None]:
# model
model = UNetWithClassifier()
model = model.to('cuda')

# loss function
criterion = MultiTaskLoss(pos_weight)

# optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr= 0.001, weight_decay= 0.0001)

# train test split
n_total = len(dataset)
n_val = int(0.2 * n_total)
n_train = n_total - n_val
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [n_train, n_val], generator= torch.Generator().manual_seed(42))

# weighted random sampler
labels = np.array([dataset.category[i] for i in train_dataset.indices])
class_counts = np.bincount(labels, minlength= 3)
class_weights = class_counts.sum() / (len(class_counts) * class_counts.clip(min= 1))
sample_weights = [class_weights[x] for x in labels]

sampler = WeightedRandomSampler(weights= torch.tensor(sample_weights), num_samples= len(sample_weights), replacement= True)

# DataLoaders
train_loader = DataLoader(dataset= train_dataset, batch_size= 16, sampler= sampler, pin_memory= True)
test_loader = DataLoader(dataset= test_dataset, batch_size= 16, shuffle= False, pin_memory= True)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer= optimizer, mode= 'min', factor = 0.1, patience= 10, cooldown= 4)

### Training

In [None]:
from tqdm import trange

train_loss, test_loss = [], []

num_epochs = 500
patience = 20
best_val_loss = float('inf')
epochs_no_improve = 0

for epoch in trange(num_epochs, desc="Epochs"):
    model.train()
    running_loss = 0.0
    for images, masks, labels in train_loader:
        images = images.to('cuda')
        masks = masks.to('cuda').view(-1, 3, 256, 256).float()
        labels = labels.to('cuda')

        optimizer.zero_grad()
        mask_pred, output_class_logits = model(images)
        loss = criterion(mask_pred, masks, output_class_logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, masks, labels in test_loader:
            images = images.to('cuda')
            masks = masks.to('cuda').view(-1, 3, 256, 256).float()
            labels = labels.to('cuda')
            mask_pred, output_class_logits = model(images)
            loss = criterion(mask_pred, masks, output_class_logits, labels)
            val_loss += loss.item()
            output_classes = output_class_logits.argmax(dim= 1)
            correct += (output_classes == labels).sum().item()
            total += labels.size(0)
    val_loss /= len(test_loader)

    scheduler.step(val_loss)

    val_acc = 100 * correct / total

    print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {avg_loss:.6f} | Val Loss: {val_loss:.6f} | Val Acc: {val_acc:.2f}%")
    
    train_loss.append(avg_loss)
    test_loss.append(val_loss)

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), 'model.pt')
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

Epochs:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch [1/500] Train Loss: 1.649261 | Val Loss: 1.540724 | Val Acc: 30.77%


Epochs:   0%|          | 2/500 [01:11<4:54:35, 35.49s/it]

Epoch [2/500] Train Loss: 1.478876 | Val Loss: 1.589297 | Val Acc: 49.36%


Epochs:   1%|          | 3/500 [01:46<4:51:28, 35.19s/it]

Epoch [3/500] Train Loss: 1.348680 | Val Loss: 3.359626 | Val Acc: 50.64%


Epochs:   1%|          | 4/500 [02:21<4:50:48, 35.18s/it]

Epoch [4/500] Train Loss: 1.336789 | Val Loss: 1.594668 | Val Acc: 43.59%
Epoch [5/500] Train Loss: 1.278811 | Val Loss: 1.482488 | Val Acc: 64.74%


Epochs:   1%|          | 5/500 [02:56<4:50:34, 35.22s/it]

Epoch [6/500] Train Loss: 1.232880 | Val Loss: 1.167382 | Val Acc: 73.08%


Epochs:   1%|▏         | 7/500 [04:06<4:48:14, 35.08s/it]

Epoch [7/500] Train Loss: 1.076600 | Val Loss: 2.375826 | Val Acc: 52.56%


Epochs:   2%|▏         | 8/500 [04:41<4:46:52, 34.98s/it]

Epoch [8/500] Train Loss: 1.043973 | Val Loss: 1.185386 | Val Acc: 67.31%


Epochs:   2%|▏         | 9/500 [05:15<4:44:36, 34.78s/it]

Epoch [9/500] Train Loss: 1.005608 | Val Loss: 1.462796 | Val Acc: 50.00%
Epoch [10/500] Train Loss: 1.043314 | Val Loss: 1.125653 | Val Acc: 66.67%


Epochs:   2%|▏         | 11/500 [06:22<4:38:48, 34.21s/it]

Epoch [11/500] Train Loss: 1.111352 | Val Loss: 1.624269 | Val Acc: 56.41%


Epochs:   2%|▏         | 12/500 [06:56<4:37:02, 34.06s/it]

Epoch [12/500] Train Loss: 1.006167 | Val Loss: 1.138562 | Val Acc: 65.38%
Epoch [13/500] Train Loss: 1.016001 | Val Loss: 1.091720 | Val Acc: 73.72%


Epochs:   3%|▎         | 13/500 [07:30<4:35:46, 33.98s/it]

Epoch [14/500] Train Loss: 0.936975 | Val Loss: 1.039198 | Val Acc: 76.28%


Epochs:   3%|▎         | 15/500 [08:37<4:32:52, 33.76s/it]

Epoch [15/500] Train Loss: 0.899358 | Val Loss: 1.049875 | Val Acc: 78.85%


Epochs:   3%|▎         | 16/500 [09:11<4:32:13, 33.75s/it]

Epoch [16/500] Train Loss: 0.935630 | Val Loss: 1.079635 | Val Acc: 69.23%


Epochs:   3%|▎         | 17/500 [09:44<4:30:07, 33.56s/it]

Epoch [17/500] Train Loss: 0.954444 | Val Loss: 1.179986 | Val Acc: 62.18%


Epochs:   4%|▎         | 18/500 [10:18<4:29:46, 33.58s/it]

Epoch [18/500] Train Loss: 0.900169 | Val Loss: 1.548650 | Val Acc: 58.97%


Epochs:   4%|▍         | 19/500 [10:51<4:29:13, 33.58s/it]

Epoch [19/500] Train Loss: 0.952169 | Val Loss: 3.620813 | Val Acc: 50.00%


Epochs:   4%|▍         | 20/500 [11:25<4:28:17, 33.54s/it]

Epoch [20/500] Train Loss: 0.896660 | Val Loss: 1.144645 | Val Acc: 74.36%


Epochs:   4%|▍         | 21/500 [11:58<4:27:45, 33.54s/it]

Epoch [21/500] Train Loss: 0.854285 | Val Loss: 1.057899 | Val Acc: 73.72%
Epoch [22/500] Train Loss: 0.868884 | Val Loss: 1.019099 | Val Acc: 71.79%


Epochs:   4%|▍         | 22/500 [12:32<4:27:44, 33.61s/it]

Epoch [23/500] Train Loss: 0.824946 | Val Loss: 1.001656 | Val Acc: 75.64%


Epochs:   5%|▍         | 24/500 [13:39<4:26:13, 33.56s/it]

Epoch [24/500] Train Loss: 0.809234 | Val Loss: 1.562675 | Val Acc: 63.46%


Epochs:   5%|▌         | 25/500 [14:12<4:25:09, 33.49s/it]

Epoch [25/500] Train Loss: 0.845743 | Val Loss: 1.151263 | Val Acc: 73.08%


Epochs:   5%|▌         | 26/500 [14:46<4:24:27, 33.47s/it]

Epoch [26/500] Train Loss: 0.818014 | Val Loss: 1.075154 | Val Acc: 75.64%


Epochs:   5%|▌         | 27/500 [15:19<4:23:51, 33.47s/it]

Epoch [27/500] Train Loss: 0.842924 | Val Loss: 1.208645 | Val Acc: 71.79%


Epochs:   6%|▌         | 28/500 [15:53<4:23:15, 33.47s/it]

Epoch [28/500] Train Loss: 0.840457 | Val Loss: 1.273202 | Val Acc: 67.31%
Epoch [29/500] Train Loss: 0.786504 | Val Loss: 0.999217 | Val Acc: 70.51%


Epochs:   6%|▌         | 29/500 [16:27<4:23:35, 33.58s/it]

Epoch [30/500] Train Loss: 0.751006 | Val Loss: 0.944211 | Val Acc: 78.21%


Epochs:   6%|▌         | 31/500 [17:34<4:22:22, 33.57s/it]

Epoch [31/500] Train Loss: 0.777555 | Val Loss: 1.384457 | Val Acc: 56.41%


Epochs:   6%|▋         | 32/500 [18:07<4:21:09, 33.48s/it]

Epoch [32/500] Train Loss: 0.720574 | Val Loss: 1.319775 | Val Acc: 72.44%


Epochs:   7%|▋         | 33/500 [18:40<4:20:33, 33.48s/it]

Epoch [33/500] Train Loss: 0.730199 | Val Loss: 0.963314 | Val Acc: 82.05%


Epochs:   7%|▋         | 34/500 [19:14<4:19:43, 33.44s/it]

Epoch [34/500] Train Loss: 0.754781 | Val Loss: 1.164888 | Val Acc: 80.77%


Epochs:   7%|▋         | 35/500 [19:47<4:19:16, 33.45s/it]

Epoch [35/500] Train Loss: 0.776586 | Val Loss: 1.030581 | Val Acc: 78.85%


Epochs:   7%|▋         | 36/500 [20:21<4:19:07, 33.51s/it]

Epoch [36/500] Train Loss: 0.724101 | Val Loss: 1.075791 | Val Acc: 67.31%


Epochs:   7%|▋         | 37/500 [20:54<4:18:22, 33.48s/it]

Epoch [37/500] Train Loss: 0.767840 | Val Loss: 1.176928 | Val Acc: 70.51%
Epoch [38/500] Train Loss: 0.754192 | Val Loss: 0.920336 | Val Acc: 75.00%


Epochs:   8%|▊         | 39/500 [22:01<4:17:32, 33.52s/it]

Epoch [39/500] Train Loss: 0.676148 | Val Loss: 0.963385 | Val Acc: 82.05%


Epochs:   8%|▊         | 40/500 [22:35<4:16:24, 33.44s/it]

Epoch [40/500] Train Loss: 0.692632 | Val Loss: 0.964287 | Val Acc: 78.85%


Epochs:   8%|▊         | 41/500 [23:08<4:15:30, 33.40s/it]

Epoch [41/500] Train Loss: 0.682660 | Val Loss: 1.015257 | Val Acc: 76.28%


Epochs:   8%|▊         | 42/500 [23:41<4:15:01, 33.41s/it]

Epoch [42/500] Train Loss: 0.665139 | Val Loss: 1.053808 | Val Acc: 67.31%


Epochs:   9%|▊         | 43/500 [24:15<4:14:18, 33.39s/it]

Epoch [43/500] Train Loss: 0.625454 | Val Loss: 1.182174 | Val Acc: 74.36%


Epochs:   9%|▉         | 44/500 [24:48<4:13:54, 33.41s/it]

Epoch [44/500] Train Loss: 0.686373 | Val Loss: 1.292977 | Val Acc: 67.31%


Epochs:   9%|▉         | 45/500 [25:22<4:13:34, 33.44s/it]

Epoch [45/500] Train Loss: 0.642957 | Val Loss: 1.182071 | Val Acc: 76.28%


Epochs:   9%|▉         | 46/500 [25:55<4:12:32, 33.37s/it]

Epoch [46/500] Train Loss: 0.613309 | Val Loss: 1.127655 | Val Acc: 69.87%


Epochs:   9%|▉         | 47/500 [26:28<4:11:56, 33.37s/it]

Epoch [47/500] Train Loss: 0.653638 | Val Loss: 0.982970 | Val Acc: 81.41%


Epochs:  10%|▉         | 48/500 [27:02<4:11:43, 33.41s/it]

Epoch [48/500] Train Loss: 0.616873 | Val Loss: 1.012252 | Val Acc: 80.13%


Epochs:  10%|▉         | 49/500 [27:35<4:11:06, 33.41s/it]

Epoch [49/500] Train Loss: 0.606019 | Val Loss: 1.187934 | Val Acc: 74.36%


Epochs:  10%|█         | 50/500 [28:09<4:10:21, 33.38s/it]

Epoch [50/500] Train Loss: 0.592003 | Val Loss: 1.006489 | Val Acc: 75.64%


Epochs:  10%|█         | 51/500 [28:42<4:09:46, 33.38s/it]

Epoch [51/500] Train Loss: 0.553392 | Val Loss: 0.946103 | Val Acc: 78.21%
Epoch [52/500] Train Loss: 0.505081 | Val Loss: 0.907238 | Val Acc: 80.77%


Epochs:  11%|█         | 53/500 [29:49<4:09:13, 33.45s/it]

Epoch [53/500] Train Loss: 0.526239 | Val Loss: 0.980119 | Val Acc: 78.21%


Epochs:  11%|█         | 54/500 [30:22<4:08:23, 33.42s/it]

Epoch [54/500] Train Loss: 0.510860 | Val Loss: 0.947907 | Val Acc: 82.69%


Epochs:  11%|█         | 55/500 [30:56<4:07:54, 33.43s/it]

Epoch [55/500] Train Loss: 0.501557 | Val Loss: 0.974718 | Val Acc: 80.13%


Epochs:  11%|█         | 56/500 [31:29<4:07:20, 33.42s/it]

Epoch [56/500] Train Loss: 0.490652 | Val Loss: 0.929880 | Val Acc: 80.77%


Epochs:  11%|█▏        | 57/500 [32:02<4:06:25, 33.38s/it]

Epoch [57/500] Train Loss: 0.485907 | Val Loss: 0.950449 | Val Acc: 81.41%


Epochs:  12%|█▏        | 58/500 [32:36<4:06:03, 33.40s/it]

Epoch [58/500] Train Loss: 0.465906 | Val Loss: 0.935986 | Val Acc: 80.77%


Epochs:  12%|█▏        | 59/500 [33:09<4:05:33, 33.41s/it]

Epoch [59/500] Train Loss: 0.480837 | Val Loss: 0.948123 | Val Acc: 82.05%


Epochs:  12%|█▏        | 60/500 [33:43<4:04:35, 33.35s/it]

Epoch [60/500] Train Loss: 0.472331 | Val Loss: 0.996262 | Val Acc: 79.49%


Epochs:  12%|█▏        | 61/500 [34:16<4:04:52, 33.47s/it]

Epoch [61/500] Train Loss: 0.448774 | Val Loss: 0.944911 | Val Acc: 82.05%


Epochs:  12%|█▏        | 62/500 [34:50<4:04:32, 33.50s/it]

Epoch [62/500] Train Loss: 0.447713 | Val Loss: 1.000380 | Val Acc: 80.77%


Epochs:  13%|█▎        | 63/500 [35:23<4:03:17, 33.40s/it]

Epoch [63/500] Train Loss: 0.470841 | Val Loss: 0.960368 | Val Acc: 80.77%


Epochs:  13%|█▎        | 64/500 [35:57<4:02:54, 33.43s/it]

Epoch [64/500] Train Loss: 0.462609 | Val Loss: 1.015559 | Val Acc: 82.05%


Epochs:  13%|█▎        | 65/500 [36:30<4:02:37, 33.47s/it]

Epoch [65/500] Train Loss: 0.465878 | Val Loss: 0.977501 | Val Acc: 80.77%


Epochs:  13%|█▎        | 66/500 [37:03<4:01:36, 33.40s/it]

Epoch [66/500] Train Loss: 0.444459 | Val Loss: 0.972981 | Val Acc: 79.49%


Epochs:  13%|█▎        | 67/500 [37:37<4:00:42, 33.35s/it]

Epoch [67/500] Train Loss: 0.459915 | Val Loss: 0.997578 | Val Acc: 80.77%


Epochs:  14%|█▎        | 68/500 [38:10<4:00:09, 33.35s/it]

Epoch [68/500] Train Loss: 0.446898 | Val Loss: 0.939826 | Val Acc: 84.62%


Epochs:  14%|█▍        | 69/500 [38:43<3:59:22, 33.32s/it]

Epoch [69/500] Train Loss: 0.435767 | Val Loss: 0.951759 | Val Acc: 80.13%


Epochs:  14%|█▍        | 70/500 [39:17<3:58:48, 33.32s/it]

Epoch [70/500] Train Loss: 0.449346 | Val Loss: 0.970975 | Val Acc: 82.69%


Epochs:  14%|█▍        | 71/500 [39:50<3:58:46, 33.40s/it]

Epoch [71/500] Train Loss: 0.459629 | Val Loss: 0.985116 | Val Acc: 81.41%


Epochs:  14%|█▍        | 71/500 [40:23<4:04:05, 34.14s/it]

Epoch [72/500] Train Loss: 0.431892 | Val Loss: 0.996667 | Val Acc: 78.85%
Early stopping at epoch 72





### Saving loss lists

In [None]:
import pandas as pd
train_pd = pd.DataFrame(train_loss)
train_pd.to_csv('train_loss.csv', index= False, header= False)
test_pd = pd.DataFrame(test_loss)
test_pd.to_csv('test_loss.csv', index= False, header= False)

In [None]:
from torchmetrics.classification import MulticlassPrecision

model.eval()
preds, targets = [], []

with torch.no_grad():
    for images, masks, labels in test_loader:
        images = images.to('cuda')
        masks = masks.to('cuda').view(-1, 3, 256, 256).float()
        labels = labels.to('cuda')

        seg_logits, cls_logits = model(images)
        pred_cls = cls_logits.argmax(dim=1)  # (B,)

        preds.append(pred_cls.cpu())
        targets.append(labels.cpu())

preds = torch.cat(preds)      # shape (N,)
targets = torch.cat(targets)  # shape (N,)

metric = MulticlassPrecision(num_classes=3, average="macro")
precision_macro = metric(preds, targets).item()
print("Macro Precision:", precision_macro)
