In [1]:
import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torch import nn
from PIL import Image
import os
from torchvision import datasets, transforms
from collections import defaultdict
import albumentations as A
import numpy as np

### Finding different shapes and channels of images in the dataset

In [2]:
folder_path = r'D:\Suchit\Breast-Cancer-Detection\Dataset_BUSI_with_GT'
image_shapes, mask_shapes = set(), set()
channels = set()
images, masks = [], []
for folder in os.listdir(folder_path):
    for image in os.listdir(os.path.join(folder_path, folder)):
        with Image.open(os.path.join(folder_path, folder, image)) as img:
            if 'mask' in image:
                mask_shapes.add(img.size)
                masks.append(os.path.join(folder_path, folder, image))
            else:
                image_shapes.add(img.size)
                images.append(os.path.join(folder_path, folder, image))
            channels.add(img.mode)
print(f'image shape length - {len(image_shapes)} and mask shape length - {len(mask_shapes)}')
print(f'number of channels - {len(channels)}')

image shape length - 639 and mask shape length - 639
number of channels - 3


### Finding the average height and width of the images to resize them

In [3]:
folder_path = r'D:\Suchit\Breast-Cancer-Detection\Dataset_BUSI_with_GT'
height, width, num_samples = 0.0, 0.0, 0.0
for folder in os.listdir(folder_path):
    for image in os.listdir(os.path.join(folder_path, folder)):
        with Image.open(os.path.join(folder_path, folder, image)) as img:
            if 'mask' not in image:
                size = img.size
                height += size[1]
                width += size[0]
                num_samples += 1
height /= num_samples
width /= num_samples
print(f'average height = {height}')
print(f'average width = {width}')

average height = 501.4525641025641
average width = 615.6794871794872


Height and width is too large. Taking 256 * 256

### Finding the mean and standard deviation of the input images

In [4]:
def calculate_grayscale_stats(folder_path, resize_shape=(256, 256)):
    """Calculate mean and std for grayscale images"""
    print("Calculating grayscale statistics...")
    
    # Simple transform for grayscale statistics calculation
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(resize_shape),
        transforms.ToTensor()
    ])
    
    class GrayscaleImageData(Dataset):
        def __init__(self, images, transform):
            self.images = images
            self.transform = transform
        
        def __len__(self):
            return len(self.images)
        
        def __getitem__(self, index):
            img_path = self.images[index]
            image = Image.open(img_path)
            return self.transform(image)
    
    # Get image paths (exclude masks)
    images = []
    for folder in os.listdir(folder_path):
        for image in os.listdir(os.path.join(folder_path, folder)):
            if 'mask' not in image:
                images.append(os.path.join(folder_path, folder, image))
    
    print(f"Found {len(images)} images for statistics calculation")
    
    dataset = GrayscaleImageData(images, transform)
    loader = DataLoader(dataset, batch_size=64, shuffle=False)
    
    mean, std, num_samples = 0.0, 0.0, 0.0
    
    for data in loader:
        batch_size = data.size(0)
        data = data.view(batch_size, data.size(1), -1)
        mean += data.mean(2).sum(0)
        std += data.std(2).sum(0)
        num_samples += batch_size
    
    mean /= num_samples
    std /= num_samples
    
    return mean.item(), std.item()

# Step 2: Calculate grayscale statistics for your dataset
grayscale_mean, grayscale_std = calculate_grayscale_stats(folder_path)
print(f'Grayscale mean = {grayscale_mean:.4f}, std = {grayscale_std:.4f}')

Calculating grayscale statistics...
Found 780 images for statistics calculation
Grayscale mean = 0.3279, std = 0.1998


### Managing the dataset

In [5]:
class Data(Dataset):
    def __init__(self, folder_path, transforms):
        super().__init__()
        self.images, self.masks, self.category = [], defaultdict(list), []
        classes = {'benign': 0, 'malignant': 1, 'normal': 2}
        self.transforms = transforms

        for folder in os.listdir(folder_path):
            for image in os.listdir(os.path.join(folder_path, folder)):
                img_path = os.path.join(folder_path, folder, image)
                if 'mask' in image:
                    self.masks[image[: image.index('_mask')]].append(img_path)
                else:
                    self.images.append(img_path)
                    self.category.append(classes[folder])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = self.images[index]
        key = os.path.basename(image_path)[: os.path.basename(image_path).index('.png')]
        mask_paths = self.masks[key]

        # Load image and masks as numpy arrays
        image = np.array(Image.open(image_path).convert('L'))
        masks = [np.array(Image.open(x).convert('L')) for x in mask_paths]

        # Combine masks
        mask = np.sum(np.stack(masks), axis=0)
        mask = np.clip(mask, 0, 1)  # ensure binary

        # Apply albumentations
        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask, self.category[index]

from albumentations.pytorch import ToTensorV2

transform = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=[grayscale_mean], std=[grayscale_std]),
    ToTensorV2()
])

dataset = Data(folder_path, transform)


### Model

In [6]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(512, 1024)

        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        d1 = self.down1(x)
        p1 = self.pool1(d1)
        d2 = self.down2(p1)
        p2 = self.pool2(d2)
        d3 = self.down3(p2)
        p3 = self.pool3(d3)
        d4 = self.down4(p3)
        p4 = self.pool4(d4)
        bn = self.bottleneck(p4)

        u4 = self.up4(bn)
        u4 = torch.cat([u4, d4], dim=1)
        u4 = self.conv4(u4)
        u3 = self.up3(u4)
        u3 = torch.cat([u3, d3], dim=1)
        u3 = self.conv3(u3)
        u2 = self.up2(u3)
        u2 = torch.cat([u2, d2], dim=1)
        u2 = self.conv2(u2)
        u1 = self.up1(u2)
        u1 = torch.cat([u1, d1], dim=1)
        u1 = self.conv1(u1)

        out = self.final_conv(u1)
        return out

class Predictor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size= 3),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size= 3),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size= 3),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size= 3),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(256 * 14 * 14, 128),
            nn.ReLU(),
            nn.Linear(128, 3)
        )
    def forward(self, x):
        return self.conv(x)

class UNetPredictor(nn.Module):
    def __init__(self, unet, predictor):
        super().__init__()
        self.unet = unet
        for param in self.unet.parameters():
            param.requires_grad = False
        self.predictor = predictor

    def forward(self, x):
        seg_out = self.unet(x)
        pred_out = self.predictor(seg_out)
        return pred_out
    

unet = UNet()
unet = unet.to('cuda')
predictor = Predictor()
predictor = predictor.to('cuda')

### Weighted Random Sampling

In [7]:
targets = dataset.category
class_counts = np.bincount(targets)
class_weight = 1.0 / class_counts
sample_weight = [class_weight[label] for label in targets]

sampler = WeightedRandomSampler(
    weights= sample_weight,
    num_samples= len(sample_weight),
    replacement= True
)

### DataLoader, optimzer and loss function

In [None]:
trian_loader = DataLoader(dataset, batch_size= 16, sampler= sampler)
test_loader = DataLoader(dataset, batch_size=16, sampler= sampler)

criterion_unet = nn.BCEWithLogitsLoss()
criterion_prediction = nn.CrossEntropyLoss()

optimizer_unet = torch.optim.AdamW(unet.parameters(), lr= 0.001, weight_decay= 0.001)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_unet,
    mode='min',       # 'min' for loss, 'max' for accuracy/metric
    factor=0.1,       # multiply LR by this factor
    patience=10,       # wait for 5 epochs before reducing LR
)

In [None]:
from tqdm import trange

train_loss_unet, test_loss_unet = [], []

num_epochs = 500
patience = 20
best_val_loss = float('inf')
epochs_no_improve = 0

for epoch in trange(num_epochs, desc="Epochs"):
    unet.train()
    running_loss = 0.0
    for images, masks, labels in trian_loader:
        images = images.to('cuda')
        masks = masks.to('cuda').view(-1, 1, 256, 256).float()
        # labels = labels.to('cuda')

        # --- UNet mask loss ---
        optimizer_unet.zero_grad()
        mask_pred = unet(images)
        loss_mask = criterion_unet(mask_pred, masks)
        loss_mask.backward()
        optimizer_unet.step()

        # --- UNetPredictor classification loss ---
        # optimizer_unet_predictor.zero_grad()
        # class_pred = model(images)
        # loss_class = criterion_prediction(class_pred, labels)
        # loss_class.backward()
        # optimizer_unet_predictor.step()

        running_loss += loss_mask.item()
        # running_loss += loss_mask.item() + loss_class.item()

    avg_loss = running_loss / len(trian_loader)

    # Validation
    unet.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, masks, labels in test_loader:
            images = images.to('cuda')
            masks = masks.to('cuda').view(-1, 1, 256, 256).float()
            # labels = labels.to('cuda')
            mask_pred = unet(images)
            # class_pred = model(images)
            loss_mask = criterion_unet(mask_pred, masks)
            # loss_class = criterion_prediction(class_pred, labels)
            val_loss += loss_mask.item()
            # val_loss += loss_mask.item() + loss_class.item()
            # _, predicted = torch.max(class_pred, 1)
            # correct += (predicted == labels).sum().item()
            total += labels.size(0)
    val_loss /= len(test_loader)

    scheduler.step(val_loss)

    # val_acc = 100 * correct / total

    # print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {avg_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
    print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {avg_loss:.4f} | Val Loss: {val_loss:.4f}")
    train_loss_unet.append(avg_loss)
    test_loss_unet.append(val_loss)

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(unet.state_dict(), 'unet1.pt')
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

import pandas as pd
train_loss_unet = pd.DataFrame(train_loss_unet)
train_loss_unet.to_csv('train_loss_unet1.csv', index= False, header= False)
test_loss_unet = pd.DataFrame(test_loss_unet)
test_loss_unet.to_csv('test_loss_unet1.csv', index= False, header= False)

Epochs:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch [1/500] Train Loss: 0.2007 | Val Loss: 0.2183


Epochs:   0%|          | 1/500 [00:59<8:17:45, 59.85s/it]

Epoch [2/500] Train Loss: 0.1837 | Val Loss: 0.1767


Epochs:   0%|          | 2/500 [01:59<8:13:48, 59.50s/it]

Epoch [3/500] Train Loss: 0.1627 | Val Loss: 0.1731


Epochs:   1%|          | 3/500 [02:58<8:10:35, 59.23s/it]

Epoch [4/500] Train Loss: 0.1450 | Val Loss: 0.1582


Epochs:   1%|          | 5/500 [04:54<8:04:28, 58.73s/it]

Epoch [5/500] Train Loss: 0.1493 | Val Loss: 0.4254


Epochs:   1%|          | 6/500 [05:52<8:01:22, 58.47s/it]

Epoch [6/500] Train Loss: 0.1411 | Val Loss: 0.1911
Epoch [7/500] Train Loss: 0.1266 | Val Loss: 0.1572


Epochs:   1%|▏         | 7/500 [06:51<8:01:24, 58.59s/it]

Epoch [8/500] Train Loss: 0.1347 | Val Loss: 0.1471


Epochs:   2%|▏         | 8/500 [07:51<8:02:51, 58.88s/it]

Epoch [9/500] Train Loss: 0.1385 | Val Loss: 0.1317


Epochs:   2%|▏         | 10/500 [09:49<8:02:16, 59.05s/it]

Epoch [10/500] Train Loss: 0.1178 | Val Loss: 0.1392


Epochs:   2%|▏         | 11/500 [10:48<8:01:02, 59.02s/it]

Epoch [11/500] Train Loss: 0.1301 | Val Loss: 0.1317


Epochs:   2%|▏         | 12/500 [11:47<7:59:58, 59.01s/it]

Epoch [12/500] Train Loss: 0.1209 | Val Loss: 0.1668


Epochs:   3%|▎         | 13/500 [12:46<7:59:11, 59.04s/it]

Epoch [13/500] Train Loss: 0.1141 | Val Loss: 0.1592
Epoch [14/500] Train Loss: 0.1161 | Val Loss: 0.1227


Epochs:   3%|▎         | 15/500 [14:44<7:57:27, 59.07s/it]

Epoch [15/500] Train Loss: 0.1228 | Val Loss: 0.2017


Epochs:   3%|▎         | 16/500 [15:43<7:56:08, 59.03s/it]

Epoch [16/500] Train Loss: 0.1264 | Val Loss: 0.1482
Epoch [17/500] Train Loss: 0.1284 | Val Loss: 0.1223


Epochs:   3%|▎         | 17/500 [16:42<7:55:18, 59.04s/it]

Epoch [18/500] Train Loss: 0.1190 | Val Loss: 0.1131


Epochs:   4%|▎         | 18/500 [17:43<7:58:07, 59.52s/it]

Epoch [19/500] Train Loss: 0.1111 | Val Loss: 0.1112


Epochs:   4%|▍         | 20/500 [19:42<7:55:23, 59.42s/it]

Epoch [20/500] Train Loss: 0.1077 | Val Loss: 0.1140
Epoch [21/500] Train Loss: 0.1081 | Val Loss: 0.1001


Epochs:   4%|▍         | 22/500 [21:39<7:50:29, 59.06s/it]

Epoch [22/500] Train Loss: 0.1111 | Val Loss: 0.1085


Epochs:   5%|▍         | 23/500 [22:38<7:49:13, 59.02s/it]

Epoch [23/500] Train Loss: 0.1034 | Val Loss: 0.1016


Epochs:   5%|▍         | 24/500 [23:37<7:47:52, 58.98s/it]

Epoch [24/500] Train Loss: 0.1104 | Val Loss: 0.1313
Epoch [25/500] Train Loss: 0.1101 | Val Loss: 0.0966


Epochs:   5%|▌         | 26/500 [25:35<7:46:03, 59.00s/it]

Epoch [26/500] Train Loss: 0.0987 | Val Loss: 0.1055


Epochs:   5%|▌         | 27/500 [26:34<7:44:52, 58.97s/it]

Epoch [27/500] Train Loss: 0.1002 | Val Loss: 0.1181
Epoch [28/500] Train Loss: 0.0973 | Val Loss: 0.0891


Epochs:   6%|▌         | 29/500 [28:33<7:44:06, 59.12s/it]

Epoch [29/500] Train Loss: 0.1016 | Val Loss: 0.1142
Epoch [30/500] Train Loss: 0.0989 | Val Loss: 0.0855


Epochs:   6%|▌         | 31/500 [30:30<7:39:31, 58.79s/it]

Epoch [31/500] Train Loss: 0.0959 | Val Loss: 0.0999


Epochs:   6%|▋         | 32/500 [31:28<7:36:51, 58.57s/it]

Epoch [32/500] Train Loss: 0.0927 | Val Loss: 0.0942


Epochs:   7%|▋         | 33/500 [32:27<7:35:48, 58.56s/it]

Epoch [33/500] Train Loss: 0.1015 | Val Loss: 0.0983


Epochs:   7%|▋         | 34/500 [33:25<7:35:01, 58.59s/it]

Epoch [34/500] Train Loss: 0.0882 | Val Loss: 0.1257


Epochs:   7%|▋         | 35/500 [34:24<7:34:18, 58.62s/it]

Epoch [35/500] Train Loss: 0.0981 | Val Loss: 0.0939


Epochs:   7%|▋         | 36/500 [35:22<7:33:09, 58.60s/it]

Epoch [36/500] Train Loss: 0.0846 | Val Loss: 0.1146


Epochs:   7%|▋         | 37/500 [36:22<7:33:28, 58.77s/it]

Epoch [37/500] Train Loss: 0.0858 | Val Loss: 0.0890
Epoch [38/500] Train Loss: 0.0878 | Val Loss: 0.0787


Epochs:   8%|▊         | 39/500 [38:20<7:33:19, 59.00s/it]

Epoch [39/500] Train Loss: 0.0801 | Val Loss: 0.0906
Epoch [40/500] Train Loss: 0.0889 | Val Loss: 0.0751


Epochs:   8%|▊         | 40/500 [39:20<7:33:35, 59.16s/it]

Epoch [41/500] Train Loss: 0.0718 | Val Loss: 0.0682


Epochs:   8%|▊         | 42/500 [41:18<7:31:49, 59.19s/it]

Epoch [42/500] Train Loss: 0.0751 | Val Loss: 0.0927


Epochs:   9%|▊         | 43/500 [42:17<7:30:22, 59.13s/it]

Epoch [43/500] Train Loss: 0.0810 | Val Loss: 0.0807


Epochs:   9%|▉         | 44/500 [43:16<7:29:11, 59.10s/it]

Epoch [44/500] Train Loss: 0.0764 | Val Loss: 0.0723


Epochs:   9%|▉         | 45/500 [44:15<7:28:28, 59.14s/it]

Epoch [45/500] Train Loss: 0.0730 | Val Loss: 0.0937
Epoch [46/500] Train Loss: 0.0674 | Val Loss: 0.0639


Epochs:   9%|▉         | 47/500 [46:13<7:24:40, 58.90s/it]

Epoch [47/500] Train Loss: 0.0719 | Val Loss: 0.0739


Epochs:  10%|▉         | 48/500 [47:11<7:22:03, 58.68s/it]

Epoch [48/500] Train Loss: 0.0728 | Val Loss: 0.0749


Epochs:  10%|▉         | 49/500 [48:09<7:19:36, 58.48s/it]

Epoch [49/500] Train Loss: 0.0689 | Val Loss: 0.0705
Epoch [50/500] Train Loss: 0.0622 | Val Loss: 0.0601


Epochs:  10%|█         | 50/500 [49:07<7:17:12, 58.30s/it]

Epoch [51/500] Train Loss: 0.0600 | Val Loss: 0.0527


Epochs:  10%|█         | 52/500 [51:01<7:10:38, 57.67s/it]

Epoch [52/500] Train Loss: 0.0656 | Val Loss: 0.0914


Epochs:  11%|█         | 53/500 [51:58<7:08:50, 57.56s/it]

Epoch [53/500] Train Loss: 0.0647 | Val Loss: 0.0667


Epochs:  11%|█         | 54/500 [52:55<7:06:39, 57.40s/it]

Epoch [54/500] Train Loss: 0.0622 | Val Loss: 0.0666


Epochs:  11%|█         | 55/500 [53:52<7:04:49, 57.28s/it]

Epoch [55/500] Train Loss: 0.0570 | Val Loss: 0.0738


Epochs:  11%|█         | 56/500 [54:50<7:03:45, 57.26s/it]

Epoch [56/500] Train Loss: 0.0548 | Val Loss: 0.0645
Epoch [57/500] Train Loss: 0.0578 | Val Loss: 0.0523


Epochs:  12%|█▏        | 58/500 [56:45<7:04:18, 57.60s/it]

Epoch [58/500] Train Loss: 0.0551 | Val Loss: 0.0695


Epochs:  12%|█▏        | 59/500 [57:43<7:04:28, 57.75s/it]

Epoch [59/500] Train Loss: 0.0507 | Val Loss: 0.0828
Epoch [60/500] Train Loss: 0.0523 | Val Loss: 0.0522


Epochs:  12%|█▏        | 60/500 [58:42<7:04:38, 57.91s/it]

Epoch [61/500] Train Loss: 0.0463 | Val Loss: 0.0468


Epochs:  12%|█▏        | 62/500 [1:00:39<7:05:40, 58.31s/it]

Epoch [62/500] Train Loss: 0.0515 | Val Loss: 0.0997


Epochs:  13%|█▎        | 63/500 [1:01:37<7:04:34, 58.29s/it]

Epoch [63/500] Train Loss: 0.0514 | Val Loss: 0.0488


Epochs:  13%|█▎        | 64/500 [1:02:36<7:03:35, 58.29s/it]

Epoch [64/500] Train Loss: 0.0516 | Val Loss: 0.0519
Epoch [65/500] Train Loss: 0.0506 | Val Loss: 0.0388


Epochs:  13%|█▎        | 66/500 [1:04:33<7:02:38, 58.43s/it]

Epoch [66/500] Train Loss: 0.0416 | Val Loss: 0.0415


Epochs:  13%|█▎        | 67/500 [1:05:31<7:01:36, 58.42s/it]

Epoch [67/500] Train Loss: 0.0477 | Val Loss: 0.0398


Epochs:  14%|█▎        | 68/500 [1:06:30<7:00:56, 58.47s/it]

Epoch [68/500] Train Loss: 0.0429 | Val Loss: 0.0532
Epoch [69/500] Train Loss: 0.0403 | Val Loss: 0.0381


Epochs:  14%|█▍        | 70/500 [1:08:28<7:00:39, 58.70s/it]

Epoch [70/500] Train Loss: 0.0428 | Val Loss: 0.0416


Epochs:  14%|█▍        | 71/500 [1:09:27<7:00:53, 58.87s/it]

Epoch [71/500] Train Loss: 0.0379 | Val Loss: 0.0410


Epochs:  14%|█▍        | 72/500 [1:10:25<6:58:56, 58.73s/it]

Epoch [72/500] Train Loss: 0.0446 | Val Loss: 0.0389
Epoch [73/500] Train Loss: 0.0299 | Val Loss: 0.0281


Epochs:  15%|█▍        | 74/500 [1:12:23<6:56:43, 58.69s/it]

Epoch [74/500] Train Loss: 0.0370 | Val Loss: 0.0322


Epochs:  15%|█▌        | 75/500 [1:13:21<6:55:19, 58.63s/it]

Epoch [75/500] Train Loss: 0.0417 | Val Loss: 0.0369


Epochs:  15%|█▌        | 76/500 [1:14:20<6:54:31, 58.66s/it]

Epoch [76/500] Train Loss: 0.0359 | Val Loss: 0.0322


Epochs:  15%|█▌        | 77/500 [1:15:18<6:53:26, 58.64s/it]

Epoch [77/500] Train Loss: 0.0333 | Val Loss: 0.0369


Epochs:  16%|█▌        | 78/500 [1:16:17<6:51:43, 58.54s/it]

Epoch [78/500] Train Loss: 0.0295 | Val Loss: 0.0317
Epoch [79/500] Train Loss: 0.0293 | Val Loss: 0.0278


Epochs:  16%|█▌        | 79/500 [1:17:15<6:50:57, 58.57s/it]

Epoch [80/500] Train Loss: 0.0312 | Val Loss: 0.0265


Epochs:  16%|█▌        | 81/500 [1:19:13<6:49:20, 58.62s/it]

Epoch [81/500] Train Loss: 0.0308 | Val Loss: 0.0540


Epochs:  16%|█▋        | 82/500 [1:20:11<6:47:03, 58.43s/it]

Epoch [82/500] Train Loss: 0.0341 | Val Loss: 0.0317


Epochs:  17%|█▋        | 83/500 [1:21:09<6:46:29, 58.49s/it]

Epoch [83/500] Train Loss: 0.0316 | Val Loss: 0.0301


Epochs:  17%|█▋        | 84/500 [1:22:08<6:44:52, 58.39s/it]

Epoch [84/500] Train Loss: 0.0341 | Val Loss: 0.0287


Epochs:  17%|█▋        | 85/500 [1:23:06<6:44:03, 58.42s/it]

Epoch [85/500] Train Loss: 0.0328 | Val Loss: 0.0335


Epochs:  17%|█▋        | 86/500 [1:24:04<6:42:52, 58.39s/it]

Epoch [86/500] Train Loss: 0.0316 | Val Loss: 0.0338
Epoch [87/500] Train Loss: 0.0294 | Val Loss: 0.0261


Epochs:  17%|█▋        | 87/500 [1:25:03<6:42:11, 58.43s/it]

Epoch [88/500] Train Loss: 0.0264 | Val Loss: 0.0239


Epochs:  18%|█▊        | 89/500 [1:27:00<6:40:33, 58.47s/it]

Epoch [89/500] Train Loss: 0.0256 | Val Loss: 0.0263
Epoch [90/500] Train Loss: 0.0250 | Val Loss: 0.0220


Epochs:  18%|█▊        | 90/500 [1:27:59<6:40:07, 58.56s/it]

Epoch [91/500] Train Loss: 0.0217 | Val Loss: 0.0198


Epochs:  18%|█▊        | 91/500 [1:28:58<6:40:22, 58.74s/it]

Epoch [92/500] Train Loss: 0.0212 | Val Loss: 0.0196


Epochs:  19%|█▊        | 93/500 [1:30:55<6:38:03, 58.68s/it]

Epoch [93/500] Train Loss: 0.0239 | Val Loss: 0.0218


Epochs:  19%|█▉        | 94/500 [1:31:53<6:35:38, 58.47s/it]

Epoch [94/500] Train Loss: 0.0234 | Val Loss: 0.0284


Epochs:  19%|█▉        | 95/500 [1:32:52<6:35:16, 58.56s/it]

Epoch [95/500] Train Loss: 0.0270 | Val Loss: 0.0263


Epochs:  19%|█▉        | 96/500 [1:33:50<6:33:38, 58.46s/it]

Epoch [96/500] Train Loss: 0.0293 | Val Loss: 0.0288


Epochs:  19%|█▉        | 97/500 [1:34:49<6:33:38, 58.61s/it]

Epoch [97/500] Train Loss: 0.0311 | Val Loss: 0.0273


Epochs:  20%|█▉        | 98/500 [1:35:48<6:32:59, 58.65s/it]

Epoch [98/500] Train Loss: 0.0271 | Val Loss: 0.0265


Epochs:  20%|█▉        | 99/500 [1:36:47<6:32:37, 58.75s/it]

Epoch [99/500] Train Loss: 0.0259 | Val Loss: 0.0245


Epochs:  20%|██        | 100/500 [1:37:45<6:31:20, 58.70s/it]

Epoch [100/500] Train Loss: 0.0313 | Val Loss: 0.0697


Epochs:  20%|██        | 101/500 [1:38:44<6:29:56, 58.64s/it]

Epoch [101/500] Train Loss: 0.0378 | Val Loss: 0.0291


Epochs:  20%|██        | 102/500 [1:39:43<6:29:44, 58.76s/it]

Epoch [102/500] Train Loss: 0.0266 | Val Loss: 0.0687


Epochs:  21%|██        | 103/500 [1:40:43<6:30:56, 59.09s/it]

Epoch [103/500] Train Loss: 0.0265 | Val Loss: 0.0204
Epoch [104/500] Train Loss: 0.0208 | Val Loss: 0.0185


Epochs:  21%|██        | 105/500 [1:42:41<6:29:14, 59.13s/it]

Epoch [105/500] Train Loss: 0.0194 | Val Loss: 0.0191
Epoch [106/500] Train Loss: 0.0215 | Val Loss: 0.0177


Epochs:  21%|██        | 106/500 [1:43:39<6:26:46, 58.90s/it]

Epoch [107/500] Train Loss: 0.0174 | Val Loss: 0.0173


Epochs:  22%|██▏       | 108/500 [1:45:36<6:22:43, 58.58s/it]

Epoch [108/500] Train Loss: 0.0172 | Val Loss: 0.0178


Epochs:  22%|██▏       | 109/500 [1:46:35<6:21:39, 58.57s/it]

Epoch [109/500] Train Loss: 0.0172 | Val Loss: 0.0180
Epoch [110/500] Train Loss: 0.0187 | Val Loss: 0.0170


Epochs:  22%|██▏       | 111/500 [1:49:00<7:09:49, 66.30s/it]

Epoch [111/500] Train Loss: 0.0170 | Val Loss: 0.0179
Epoch [112/500] Train Loss: 0.0184 | Val Loss: 0.0167


Epochs:  22%|██▏       | 112/500 [1:50:17<7:29:02, 69.44s/it]

Epoch [113/500] Train Loss: 0.0174 | Val Loss: 0.0156


Epochs:  23%|██▎       | 114/500 [1:52:50<7:50:53, 73.19s/it]

Epoch [114/500] Train Loss: 0.0166 | Val Loss: 0.0167


Epochs:  23%|██▎       | 115/500 [1:54:08<7:59:08, 74.67s/it]

Epoch [115/500] Train Loss: 0.0175 | Val Loss: 0.0166


Epochs:  23%|██▎       | 116/500 [1:55:25<8:01:39, 75.26s/it]

Epoch [116/500] Train Loss: 0.0186 | Val Loss: 0.0166


Epochs:  23%|██▎       | 117/500 [1:56:41<8:02:46, 75.63s/it]

Epoch [117/500] Train Loss: 0.0172 | Val Loss: 0.0168


Epochs:  24%|██▎       | 118/500 [1:57:58<8:03:21, 75.92s/it]

Epoch [118/500] Train Loss: 0.0163 | Val Loss: 0.0176
Epoch [119/500] Train Loss: 0.0161 | Val Loss: 0.0150


Epochs:  24%|██▍       | 120/500 [2:00:32<8:04:47, 76.55s/it]

Epoch [120/500] Train Loss: 0.0166 | Val Loss: 0.0160
Epoch [121/500] Train Loss: 0.0163 | Val Loss: 0.0147


Epochs:  24%|██▍       | 122/500 [2:03:07<8:05:08, 77.01s/it]

Epoch [122/500] Train Loss: 0.0152 | Val Loss: 0.0149


Epochs:  25%|██▍       | 123/500 [2:04:23<8:02:28, 76.79s/it]

Epoch [123/500] Train Loss: 0.0157 | Val Loss: 0.0152


Epochs:  25%|██▍       | 124/500 [2:05:38<7:57:27, 76.19s/it]

Epoch [124/500] Train Loss: 0.0160 | Val Loss: 0.0153


Epochs:  25%|██▌       | 125/500 [2:06:56<7:59:27, 76.71s/it]

Epoch [125/500] Train Loss: 0.0141 | Val Loss: 0.0150
Epoch [126/500] Train Loss: 0.0160 | Val Loss: 0.0145


Epochs:  25%|██▌       | 127/500 [2:09:29<7:57:30, 76.81s/it]

Epoch [127/500] Train Loss: 0.0162 | Val Loss: 0.0157


Epochs:  26%|██▌       | 128/500 [2:10:47<7:57:06, 76.95s/it]

Epoch [128/500] Train Loss: 0.0154 | Val Loss: 0.0146
Epoch [129/500] Train Loss: 0.0150 | Val Loss: 0.0133


Epochs:  26%|██▌       | 130/500 [2:13:19<7:51:20, 76.43s/it]

Epoch [130/500] Train Loss: 0.0153 | Val Loss: 0.0144


Epochs:  26%|██▌       | 131/500 [2:14:38<7:55:39, 77.34s/it]

Epoch [131/500] Train Loss: 0.0153 | Val Loss: 0.0150


Epochs:  26%|██▋       | 132/500 [2:15:54<7:51:12, 76.83s/it]

Epoch [132/500] Train Loss: 0.0153 | Val Loss: 0.0141


Epochs:  27%|██▋       | 133/500 [2:17:11<7:50:10, 76.87s/it]

Epoch [133/500] Train Loss: 0.0150 | Val Loss: 0.0138
Epoch [134/500] Train Loss: 0.0155 | Val Loss: 0.0133


Epochs:  27%|██▋       | 135/500 [2:19:18<7:03:14, 69.57s/it]

Epoch [135/500] Train Loss: 0.0142 | Val Loss: 0.0142


Epochs:  27%|██▋       | 136/500 [2:20:34<7:12:57, 71.37s/it]

Epoch [136/500] Train Loss: 0.0142 | Val Loss: 0.0136


Epochs:  27%|██▋       | 137/500 [2:21:50<7:20:32, 72.82s/it]

Epoch [137/500] Train Loss: 0.0157 | Val Loss: 0.0146


Epochs:  28%|██▊       | 138/500 [2:23:05<7:24:18, 73.64s/it]

Epoch [138/500] Train Loss: 0.0161 | Val Loss: 0.0139
Epoch [139/500] Train Loss: 0.0152 | Val Loss: 0.0132


Epochs:  28%|██▊       | 140/500 [2:25:36<7:27:37, 74.60s/it]

Epoch [140/500] Train Loss: 0.0133 | Val Loss: 0.0138


Epochs:  28%|██▊       | 141/500 [2:26:52<7:29:09, 75.07s/it]

Epoch [141/500] Train Loss: 0.0134 | Val Loss: 0.0135
Epoch [142/500] Train Loss: 0.0136 | Val Loss: 0.0131


Epochs:  29%|██▊       | 143/500 [2:29:24<7:29:04, 75.48s/it]

Epoch [143/500] Train Loss: 0.0139 | Val Loss: 0.0133


Epochs:  29%|██▉       | 144/500 [2:30:42<7:31:57, 76.17s/it]

Epoch [144/500] Train Loss: 0.0142 | Val Loss: 0.0134
Epoch [145/500] Train Loss: 0.0124 | Val Loss: 0.0128


Epochs:  29%|██▉       | 145/500 [2:31:58<7:29:24, 75.96s/it]

Epoch [146/500] Train Loss: 0.0135 | Val Loss: 0.0126


Epochs:  29%|██▉       | 147/500 [2:34:28<7:25:24, 75.71s/it]

Epoch [147/500] Train Loss: 0.0133 | Val Loss: 0.0131


Epochs:  30%|██▉       | 148/500 [2:35:44<7:23:38, 75.62s/it]

Epoch [148/500] Train Loss: 0.0144 | Val Loss: 0.0126


Epochs:  30%|██▉       | 149/500 [2:36:59<7:21:21, 75.45s/it]

Epoch [149/500] Train Loss: 0.0145 | Val Loss: 0.0130


Epochs:  30%|███       | 150/500 [2:38:14<7:19:38, 75.37s/it]

Epoch [150/500] Train Loss: 0.0139 | Val Loss: 0.0130


Epochs:  30%|███       | 151/500 [2:39:29<7:18:28, 75.38s/it]

Epoch [151/500] Train Loss: 0.0129 | Val Loss: 0.0128


Epochs:  30%|███       | 152/500 [2:40:45<7:18:07, 75.54s/it]

Epoch [152/500] Train Loss: 0.0130 | Val Loss: 0.0137
Epoch [153/500] Train Loss: 0.0135 | Val Loss: 0.0122


Epochs:  31%|███       | 154/500 [2:43:15<7:13:10, 75.12s/it]

Epoch [154/500] Train Loss: 0.0119 | Val Loss: 0.0129


Epochs:  31%|███       | 155/500 [2:44:30<7:10:58, 74.95s/it]

Epoch [155/500] Train Loss: 0.0131 | Val Loss: 0.0123
Epoch [156/500] Train Loss: 0.0125 | Val Loss: 0.0116


Epochs:  31%|███▏      | 157/500 [2:47:01<7:10:02, 75.22s/it]

Epoch [157/500] Train Loss: 0.0130 | Val Loss: 0.0120


Epochs:  32%|███▏      | 158/500 [2:48:19<7:14:30, 76.23s/it]

Epoch [158/500] Train Loss: 0.0124 | Val Loss: 0.0119


Epochs:  32%|███▏      | 159/500 [2:49:35<7:11:59, 76.01s/it]

Epoch [159/500] Train Loss: 0.0127 | Val Loss: 0.0117


Epochs:  32%|███▏      | 160/500 [2:50:50<7:09:55, 75.87s/it]

Epoch [160/500] Train Loss: 0.0130 | Val Loss: 0.0122


Epochs:  32%|███▏      | 161/500 [2:52:05<7:06:34, 75.50s/it]

Epoch [161/500] Train Loss: 0.0128 | Val Loss: 0.0116


Epochs:  32%|███▏      | 162/500 [2:53:19<7:03:26, 75.17s/it]

Epoch [162/500] Train Loss: 0.0124 | Val Loss: 0.0119


Epochs:  33%|███▎      | 163/500 [2:54:35<7:03:01, 75.32s/it]

Epoch [163/500] Train Loss: 0.0122 | Val Loss: 0.0122
Epoch [164/500] Train Loss: 0.0120 | Val Loss: 0.0113


Epochs:  33%|███▎      | 165/500 [2:57:07<7:02:40, 75.70s/it]

Epoch [165/500] Train Loss: 0.0130 | Val Loss: 0.0119


Epochs:  33%|███▎      | 166/500 [2:58:23<7:01:32, 75.73s/it]

Epoch [166/500] Train Loss: 0.0129 | Val Loss: 0.0124


Epochs:  33%|███▎      | 167/500 [2:59:38<6:59:45, 75.63s/it]

Epoch [167/500] Train Loss: 0.0130 | Val Loss: 0.0115


Epochs:  34%|███▎      | 168/500 [3:00:52<6:55:53, 75.16s/it]

Epoch [168/500] Train Loss: 0.0116 | Val Loss: 0.0113
Epoch [169/500] Train Loss: 0.0131 | Val Loss: 0.0113


Epochs:  34%|███▍      | 170/500 [3:03:23<6:53:49, 75.24s/it]

Epoch [170/500] Train Loss: 0.0125 | Val Loss: 0.0114


Epochs:  34%|███▍      | 171/500 [3:04:39<6:54:34, 75.61s/it]

Epoch [171/500] Train Loss: 0.0123 | Val Loss: 0.0119


Epochs:  34%|███▍      | 172/500 [3:05:52<6:48:27, 74.72s/it]

Epoch [172/500] Train Loss: 0.0122 | Val Loss: 0.0113


Epochs:  35%|███▍      | 173/500 [3:07:08<6:48:50, 75.02s/it]

Epoch [173/500] Train Loss: 0.0118 | Val Loss: 0.0116


Epochs:  35%|███▍      | 174/500 [3:08:28<6:55:32, 76.48s/it]

Epoch [174/500] Train Loss: 0.0118 | Val Loss: 0.0116


Epochs:  35%|███▌      | 175/500 [3:09:44<6:54:23, 76.50s/it]

Epoch [175/500] Train Loss: 0.0118 | Val Loss: 0.0116


Epochs:  35%|███▌      | 176/500 [3:11:00<6:51:58, 76.29s/it]

Epoch [176/500] Train Loss: 0.0130 | Val Loss: 0.0117
Epoch [177/500] Train Loss: 0.0123 | Val Loss: 0.0108


Epochs:  35%|███▌      | 177/500 [3:12:18<6:53:23, 76.79s/it]

Epoch [178/500] Train Loss: 0.0119 | Val Loss: 0.0104


Epochs:  36%|███▌      | 179/500 [3:14:54<6:53:22, 77.27s/it]

Epoch [179/500] Train Loss: 0.0121 | Val Loss: 0.0110


Epochs:  36%|███▌      | 180/500 [3:16:11<6:51:59, 77.25s/it]

Epoch [180/500] Train Loss: 0.0114 | Val Loss: 0.0111


Epochs:  36%|███▌      | 181/500 [3:17:27<6:48:43, 76.88s/it]

Epoch [181/500] Train Loss: 0.0124 | Val Loss: 0.0120


Epochs:  36%|███▋      | 182/500 [3:18:44<6:46:45, 76.75s/it]

Epoch [182/500] Train Loss: 0.0117 | Val Loss: 0.0110


Epochs:  37%|███▋      | 183/500 [3:19:58<6:41:20, 75.96s/it]

Epoch [183/500] Train Loss: 0.0111 | Val Loss: 0.0110


Epochs:  37%|███▋      | 184/500 [3:21:13<6:39:24, 75.84s/it]

Epoch [184/500] Train Loss: 0.0111 | Val Loss: 0.0105


Epochs:  37%|███▋      | 185/500 [3:22:30<6:38:43, 75.95s/it]

Epoch [185/500] Train Loss: 0.0112 | Val Loss: 0.0109
Epoch [186/500] Train Loss: 0.0114 | Val Loss: 0.0103


Epochs:  37%|███▋      | 187/500 [3:24:56<6:28:31, 74.48s/it]

Epoch [187/500] Train Loss: 0.0110 | Val Loss: 0.0115


Epochs:  38%|███▊      | 188/500 [3:26:12<6:30:15, 75.05s/it]

Epoch [188/500] Train Loss: 0.0108 | Val Loss: 0.0116


Epochs:  38%|███▊      | 189/500 [3:27:28<6:29:30, 75.14s/it]

Epoch [189/500] Train Loss: 0.0114 | Val Loss: 0.0105


Epochs:  38%|███▊      | 190/500 [3:28:44<6:29:35, 75.41s/it]

Epoch [190/500] Train Loss: 0.0115 | Val Loss: 0.0105
Epoch [191/500] Train Loss: 0.0112 | Val Loss: 0.0099


Epochs:  38%|███▊      | 192/500 [3:31:15<6:28:30, 75.69s/it]

Epoch [192/500] Train Loss: 0.0114 | Val Loss: 0.0106


Epochs:  39%|███▊      | 193/500 [3:32:31<6:27:06, 75.66s/it]

Epoch [193/500] Train Loss: 0.0104 | Val Loss: 0.0105


Epochs:  39%|███▉      | 194/500 [3:33:46<6:24:56, 75.48s/it]

Epoch [194/500] Train Loss: 0.0104 | Val Loss: 0.0107


Epochs:  39%|███▉      | 195/500 [3:35:01<6:22:39, 75.28s/it]

Epoch [195/500] Train Loss: 0.0105 | Val Loss: 0.0105
Epoch [196/500] Train Loss: 0.0111 | Val Loss: 0.0099


Epochs:  39%|███▉      | 197/500 [3:37:32<6:20:52, 75.42s/it]

Epoch [197/500] Train Loss: 0.0112 | Val Loss: 0.0104


Epochs:  40%|███▉      | 198/500 [3:38:48<6:19:24, 75.38s/it]

Epoch [198/500] Train Loss: 0.0111 | Val Loss: 0.0110


Epochs:  40%|███▉      | 199/500 [3:40:01<6:14:46, 74.71s/it]

Epoch [199/500] Train Loss: 0.0110 | Val Loss: 0.0103


Epochs:  40%|████      | 200/500 [3:41:16<6:14:05, 74.82s/it]

Epoch [200/500] Train Loss: 0.0099 | Val Loss: 0.0102
Epoch [201/500] Train Loss: 0.0107 | Val Loss: 0.0090


Epochs:  40%|████      | 202/500 [3:43:45<6:11:35, 74.82s/it]

Epoch [202/500] Train Loss: 0.0101 | Val Loss: 0.0101


Epochs:  41%|████      | 203/500 [3:45:01<6:10:51, 74.92s/it]

Epoch [203/500] Train Loss: 0.0103 | Val Loss: 0.0098


Epochs:  41%|████      | 204/500 [3:46:15<6:08:58, 74.79s/it]

Epoch [204/500] Train Loss: 0.0110 | Val Loss: 0.0100


Epochs:  41%|████      | 205/500 [3:47:30<6:08:08, 74.88s/it]

Epoch [205/500] Train Loss: 0.0111 | Val Loss: 0.0100


Epochs:  41%|████      | 206/500 [3:48:45<6:06:25, 74.78s/it]

Epoch [206/500] Train Loss: 0.0099 | Val Loss: 0.0100


Epochs:  41%|████▏     | 207/500 [3:50:00<6:05:56, 74.94s/it]

Epoch [207/500] Train Loss: 0.0105 | Val Loss: 0.0105


Epochs:  42%|████▏     | 208/500 [3:51:15<6:04:48, 74.96s/it]

Epoch [208/500] Train Loss: 0.0103 | Val Loss: 0.0099


Epochs:  42%|████▏     | 209/500 [3:52:31<6:04:16, 75.11s/it]

Epoch [209/500] Train Loss: 0.0098 | Val Loss: 0.0093


Epochs:  42%|████▏     | 210/500 [3:53:46<6:03:26, 75.20s/it]

Epoch [210/500] Train Loss: 0.0114 | Val Loss: 0.0098


Epochs:  42%|████▏     | 211/500 [3:55:02<6:02:48, 75.32s/it]

Epoch [211/500] Train Loss: 0.0098 | Val Loss: 0.0092


Epochs:  42%|████▏     | 212/500 [3:56:15<5:59:18, 74.86s/it]

Epoch [212/500] Train Loss: 0.0105 | Val Loss: 0.0097


Epochs:  43%|████▎     | 213/500 [3:57:31<5:58:43, 74.99s/it]

Epoch [213/500] Train Loss: 0.0105 | Val Loss: 0.0094


Epochs:  43%|████▎     | 214/500 [3:58:46<5:57:39, 75.03s/it]

Epoch [214/500] Train Loss: 0.0100 | Val Loss: 0.0095
Epoch [215/500] Train Loss: 0.0101 | Val Loss: 0.0089


Epochs:  43%|████▎     | 216/500 [4:01:16<5:55:10, 75.04s/it]

Epoch [216/500] Train Loss: 0.0094 | Val Loss: 0.0091
Epoch [217/500] Train Loss: 0.0097 | Val Loss: 0.0089


Epochs:  43%|████▎     | 217/500 [4:02:31<5:54:02, 75.06s/it]

Epoch [218/500] Train Loss: 0.0104 | Val Loss: 0.0087


Epochs:  44%|████▎     | 218/500 [4:03:47<5:54:51, 75.50s/it]

Epoch [219/500] Train Loss: 0.0094 | Val Loss: 0.0087


Epochs:  44%|████▍     | 220/500 [4:06:18<5:52:01, 75.43s/it]

Epoch [220/500] Train Loss: 0.0094 | Val Loss: 0.0090


Epochs:  44%|████▍     | 221/500 [4:07:34<5:51:44, 75.64s/it]

Epoch [221/500] Train Loss: 0.0089 | Val Loss: 0.0089
Epoch [222/500] Train Loss: 0.0099 | Val Loss: 0.0084


Epochs:  45%|████▍     | 223/500 [4:10:06<5:49:57, 75.80s/it]

Epoch [223/500] Train Loss: 0.0097 | Val Loss: 0.0088


Epochs:  45%|████▍     | 224/500 [4:11:20<5:46:20, 75.29s/it]

Epoch [224/500] Train Loss: 0.0104 | Val Loss: 0.0089


Epochs:  45%|████▌     | 225/500 [4:12:36<5:45:22, 75.35s/it]

Epoch [225/500] Train Loss: 0.0094 | Val Loss: 0.0088


Epochs:  45%|████▌     | 226/500 [4:13:50<5:42:49, 75.07s/it]

Epoch [226/500] Train Loss: 0.0098 | Val Loss: 0.0091


Epochs:  45%|████▌     | 227/500 [4:15:04<5:40:20, 74.80s/it]

Epoch [227/500] Train Loss: 0.0094 | Val Loss: 0.0084


Epochs:  46%|████▌     | 228/500 [4:16:19<5:39:18, 74.85s/it]

Epoch [228/500] Train Loss: 0.0095 | Val Loss: 0.0089


Epochs:  46%|████▌     | 229/500 [4:17:34<5:37:49, 74.79s/it]

Epoch [229/500] Train Loss: 0.0097 | Val Loss: 0.0088


Epochs:  46%|████▌     | 230/500 [4:18:50<5:38:40, 75.26s/it]

Epoch [230/500] Train Loss: 0.0100 | Val Loss: 0.0090


Epochs:  46%|████▌     | 231/500 [4:20:05<5:36:33, 75.07s/it]

Epoch [231/500] Train Loss: 0.0098 | Val Loss: 0.0086


Epochs:  46%|████▋     | 232/500 [4:21:21<5:36:13, 75.28s/it]

Epoch [232/500] Train Loss: 0.0096 | Val Loss: 0.0086


Epochs:  47%|████▋     | 233/500 [4:22:35<5:34:13, 75.11s/it]

Epoch [233/500] Train Loss: 0.0094 | Val Loss: 0.0090


Epochs:  47%|████▋     | 234/500 [4:23:50<5:32:50, 75.08s/it]

Epoch [234/500] Train Loss: 0.0095 | Val Loss: 0.0097
Epoch [235/500] Train Loss: 0.0101 | Val Loss: 0.0083


Epochs:  47%|████▋     | 236/500 [4:26:20<5:29:39, 74.92s/it]

Epoch [236/500] Train Loss: 0.0097 | Val Loss: 0.0089


Epochs:  47%|████▋     | 237/500 [4:27:35<5:28:35, 74.96s/it]

Epoch [237/500] Train Loss: 0.0098 | Val Loss: 0.0085


Epochs:  48%|████▊     | 238/500 [4:28:51<5:28:55, 75.33s/it]

Epoch [238/500] Train Loss: 0.0092 | Val Loss: 0.0089


Epochs:  48%|████▊     | 239/500 [4:30:07<5:29:09, 75.67s/it]

Epoch [239/500] Train Loss: 0.0095 | Val Loss: 0.0090


Epochs:  48%|████▊     | 240/500 [4:31:23<5:27:31, 75.58s/it]

Epoch [240/500] Train Loss: 0.0090 | Val Loss: 0.0086


Epochs:  48%|████▊     | 241/500 [4:32:37<5:24:20, 75.14s/it]

Epoch [241/500] Train Loss: 0.0091 | Val Loss: 0.0086


Epochs:  48%|████▊     | 242/500 [4:33:51<5:22:06, 74.91s/it]

Epoch [242/500] Train Loss: 0.0098 | Val Loss: 0.0085
Epoch [243/500] Train Loss: 0.0095 | Val Loss: 0.0082


Epochs:  49%|████▉     | 244/500 [4:36:22<5:20:17, 75.07s/it]

Epoch [244/500] Train Loss: 0.0091 | Val Loss: 0.0088


Epochs:  49%|████▉     | 245/500 [4:37:37<5:19:11, 75.10s/it]

Epoch [245/500] Train Loss: 0.0095 | Val Loss: 0.0086


Epochs:  49%|████▉     | 246/500 [4:38:52<5:17:58, 75.11s/it]

Epoch [246/500] Train Loss: 0.0096 | Val Loss: 0.0091


Epochs:  49%|████▉     | 247/500 [4:40:09<5:18:36, 75.56s/it]

Epoch [247/500] Train Loss: 0.0093 | Val Loss: 0.0085


Epochs:  50%|████▉     | 248/500 [4:41:23<5:16:12, 75.29s/it]

Epoch [248/500] Train Loss: 0.0094 | Val Loss: 0.0086


Epochs:  50%|████▉     | 249/500 [4:42:38<5:13:57, 75.05s/it]

Epoch [249/500] Train Loss: 0.0093 | Val Loss: 0.0085


Epochs:  50%|█████     | 250/500 [4:43:53<5:12:50, 75.08s/it]

Epoch [250/500] Train Loss: 0.0100 | Val Loss: 0.0084


Epochs:  50%|█████     | 251/500 [4:45:08<5:11:59, 75.18s/it]

Epoch [251/500] Train Loss: 0.0091 | Val Loss: 0.0090
Epoch [252/500] Train Loss: 0.0091 | Val Loss: 0.0082


Epochs:  51%|█████     | 253/500 [4:47:41<5:11:28, 75.66s/it]

Epoch [253/500] Train Loss: 0.0095 | Val Loss: 0.0086


Epochs:  51%|█████     | 254/500 [4:48:57<5:11:17, 75.93s/it]

Epoch [254/500] Train Loss: 0.0096 | Val Loss: 0.0088


Epochs:  51%|█████     | 255/500 [4:50:13<5:09:39, 75.84s/it]

Epoch [255/500] Train Loss: 0.0099 | Val Loss: 0.0085
Epoch [256/500] Train Loss: 0.0096 | Val Loss: 0.0079


Epochs:  51%|█████▏    | 257/500 [4:52:45<5:06:52, 75.77s/it]

Epoch [257/500] Train Loss: 0.0098 | Val Loss: 0.0084


Epochs:  52%|█████▏    | 258/500 [4:54:01<5:06:12, 75.92s/it]

Epoch [258/500] Train Loss: 0.0092 | Val Loss: 0.0091


Epochs:  52%|█████▏    | 259/500 [4:55:16<5:04:14, 75.74s/it]

Epoch [259/500] Train Loss: 0.0094 | Val Loss: 0.0088


Epochs:  52%|█████▏    | 260/500 [4:56:31<5:01:32, 75.39s/it]

Epoch [260/500] Train Loss: 0.0101 | Val Loss: 0.0086


Epochs:  52%|█████▏    | 261/500 [4:57:46<4:59:40, 75.23s/it]

Epoch [261/500] Train Loss: 0.0097 | Val Loss: 0.0087


Epochs:  52%|█████▏    | 262/500 [4:59:01<4:58:29, 75.25s/it]

Epoch [262/500] Train Loss: 0.0090 | Val Loss: 0.0087


Epochs:  53%|█████▎    | 263/500 [5:00:16<4:57:05, 75.21s/it]

Epoch [263/500] Train Loss: 0.0093 | Val Loss: 0.0086


Epochs:  53%|█████▎    | 264/500 [5:01:31<4:55:15, 75.06s/it]

Epoch [264/500] Train Loss: 0.0090 | Val Loss: 0.0086


Epochs:  53%|█████▎    | 265/500 [5:02:46<4:53:51, 75.03s/it]

Epoch [265/500] Train Loss: 0.0091 | Val Loss: 0.0081


Epochs:  53%|█████▎    | 266/500 [5:04:01<4:52:38, 75.04s/it]

Epoch [266/500] Train Loss: 0.0092 | Val Loss: 0.0088


Epochs:  53%|█████▎    | 267/500 [5:05:16<4:51:46, 75.13s/it]

Epoch [267/500] Train Loss: 0.0096 | Val Loss: 0.0084


Epochs:  54%|█████▎    | 268/500 [5:06:33<4:52:01, 75.52s/it]

Epoch [268/500] Train Loss: 0.0085 | Val Loss: 0.0088


Epochs:  54%|█████▍    | 269/500 [5:07:47<4:49:49, 75.28s/it]

Epoch [269/500] Train Loss: 0.0091 | Val Loss: 0.0089


Epochs:  54%|█████▍    | 270/500 [5:09:02<4:48:23, 75.23s/it]

Epoch [270/500] Train Loss: 0.0095 | Val Loss: 0.0087


Epochs:  54%|█████▍    | 271/500 [5:10:18<4:47:47, 75.40s/it]

Epoch [271/500] Train Loss: 0.0097 | Val Loss: 0.0083


Epochs:  54%|█████▍    | 272/500 [5:11:34<4:46:34, 75.42s/it]

Epoch [272/500] Train Loss: 0.0091 | Val Loss: 0.0090


Epochs:  55%|█████▍    | 273/500 [5:12:49<4:45:13, 75.39s/it]

Epoch [273/500] Train Loss: 0.0093 | Val Loss: 0.0088


Epochs:  55%|█████▍    | 274/500 [5:14:04<4:43:49, 75.35s/it]

Epoch [274/500] Train Loss: 0.0098 | Val Loss: 0.0086


Epochs:  55%|█████▌    | 275/500 [5:15:20<4:42:33, 75.35s/it]

Epoch [275/500] Train Loss: 0.0096 | Val Loss: 0.0084


Epochs:  55%|█████▌    | 275/500 [5:16:35<4:19:01, 69.07s/it]

Epoch [276/500] Train Loss: 0.0095 | Val Loss: 0.0085
Early stopping at epoch 276





### Initializing the entire model

In [15]:
model = UNetPredictor(unet= unet, predictor= predictor)
model = model.to('cuda')

### Defining the optimizer and scheduler

In [16]:
optimizer_unet_predictor = torch.optim.AdamW(model.parameters(), lr = 0.001, weight_decay= 0.001)

scheduler_predictor = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_unet,
    mode='min',       # 'min' for loss, 'max' for accuracy/metric
    factor=0.1,       # multiply LR by this factor
    patience=10,       # wait for 5 epochs before reducing LR
)

In [17]:
from tqdm import trange

train_loss_unetpredictor, test_loss_unetpredictor = [], []

num_epochs = 500
patience = 20
best_val_loss = float('inf')
epochs_no_improve = 0

for epoch in trange(num_epochs, desc="Epochs"):
    model.train()
    running_loss = 0.0
    for images, masks, labels in trian_loader:
        images = images.to('cuda')
        masks = masks.to('cuda').view(-1, 1, 256, 256).float()
        labels = labels.to('cuda')

        # --- UNet mask loss ---
        # optimizer_unet.zero_grad()
        # mask_pred = unet(images)
        # loss_mask = criterion_unet(mask_pred, masks)
        # loss_mask.backward()
        # optimizer_unet.step()

        # --- UNetPredictor classification loss ---
        optimizer_unet_predictor.zero_grad()
        class_pred = model(images)
        loss_class = criterion_prediction(class_pred, labels)
        loss_class.backward()
        optimizer_unet_predictor.step()

        # running_loss += loss_mask.item()
        running_loss += loss_class.item()

    avg_loss = running_loss / len(trian_loader)

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, masks, labels in test_loader:
            images = images.to('cuda')
            masks = masks.to('cuda').view(-1, 1, 256, 256).float()
            labels = labels.to('cuda')
            mask_pred = unet(images)
            class_pred = model(images)
            # loss_mask = criterion_unet(mask_pred, masks)
            loss_class = criterion_prediction(class_pred, labels)
            # val_loss += loss_mask.item()
            val_loss += loss_class.item()
            _, predicted = torch.max(class_pred, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    val_loss /= len(test_loader)

    scheduler_predictor.step(val_loss)

    val_acc = 100 * correct / total

    print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {avg_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
    # print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {avg_loss:.4f} | Val Loss: {val_loss:.4f}")
    train_loss_unetpredictor.append(avg_loss)
    test_loss_unetpredictor.append(val_loss)

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), 'model1.pt')
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

import pandas as pd
train_loss_unetpredictor = pd.DataFrame(train_loss_unetpredictor)
train_loss_unetpredictor.to_csv('train_loss_predictor1.csv', index= False, header= False)
test_loss_unetpredictor = pd.DataFrame(test_loss_unetpredictor)
test_loss_unetpredictor.to_csv('test_loss_predictor1.csv', index= False, header= False)

Epochs:   0%|          | 0/500 [00:00<?, ?it/s]

Epoch [1/500] Train Loss: 18.3328 | Val Loss: 1.0652 | Val Acc: 70.90%


Epochs:   0%|          | 1/500 [01:18<10:49:38, 78.11s/it]

Epoch [2/500] Train Loss: 0.7551 | Val Loss: 0.6705 | Val Acc: 74.87%


Epochs:   0%|          | 2/500 [02:22<9:42:50, 70.22s/it] 

Epoch [3/500] Train Loss: 0.7059 | Val Loss: 0.6665 | Val Acc: 72.05%


Epochs:   1%|          | 3/500 [03:17<8:43:00, 63.14s/it]

Epoch [4/500] Train Loss: 0.6603 | Val Loss: 0.6617 | Val Acc: 69.74%


Epochs:   1%|          | 4/500 [04:10<8:08:03, 59.04s/it]

Epoch [5/500] Train Loss: 0.6019 | Val Loss: 0.6386 | Val Acc: 69.23%


Epochs:   1%|          | 5/500 [05:02<7:47:46, 56.70s/it]

Epoch [6/500] Train Loss: 0.5751 | Val Loss: 0.5322 | Val Acc: 71.92%


Epochs:   1%|          | 6/500 [05:55<7:35:28, 55.32s/it]

Epoch [7/500] Train Loss: 0.5259 | Val Loss: 0.5255 | Val Acc: 74.74%


Epochs:   2%|▏         | 8/500 [07:39<7:18:38, 53.49s/it]

Epoch [8/500] Train Loss: 0.6658 | Val Loss: 0.5328 | Val Acc: 76.67%


Epochs:   2%|▏         | 9/500 [08:31<7:13:41, 53.00s/it]

Epoch [9/500] Train Loss: 0.5953 | Val Loss: 0.5651 | Val Acc: 70.51%


Epochs:   2%|▏         | 10/500 [09:23<7:10:26, 52.71s/it]

Epoch [10/500] Train Loss: 0.5017 | Val Loss: 0.5632 | Val Acc: 75.26%
Epoch [11/500] Train Loss: 0.4760 | Val Loss: 0.4699 | Val Acc: 77.82%


Epochs:   2%|▏         | 12/500 [11:07<7:05:22, 52.30s/it]

Epoch [12/500] Train Loss: 0.5104 | Val Loss: 0.4918 | Val Acc: 77.44%


Epochs:   3%|▎         | 13/500 [11:57<7:00:30, 51.81s/it]

Epoch [13/500] Train Loss: 7.8164 | Val Loss: 7.2822 | Val Acc: 36.54%


Epochs:   3%|▎         | 14/500 [12:48<6:57:05, 51.49s/it]

Epoch [14/500] Train Loss: 1.3490 | Val Loss: 1.0853 | Val Acc: 37.05%


Epochs:   3%|▎         | 15/500 [13:39<6:54:34, 51.29s/it]

Epoch [15/500] Train Loss: 1.0727 | Val Loss: 1.0600 | Val Acc: 39.49%


Epochs:   3%|▎         | 16/500 [14:29<6:51:24, 51.00s/it]

Epoch [16/500] Train Loss: 0.9768 | Val Loss: 1.1261 | Val Acc: 30.38%


Epochs:   3%|▎         | 17/500 [15:20<6:49:40, 50.89s/it]

Epoch [17/500] Train Loss: 1.0975 | Val Loss: 1.1038 | Val Acc: 31.28%


Epochs:   4%|▎         | 18/500 [16:11<6:48:57, 50.91s/it]

Epoch [18/500] Train Loss: 1.0997 | Val Loss: 1.0972 | Val Acc: 35.00%


Epochs:   4%|▍         | 19/500 [17:01<6:47:00, 50.77s/it]

Epoch [19/500] Train Loss: 1.1006 | Val Loss: 1.1003 | Val Acc: 28.72%


Epochs:   4%|▍         | 20/500 [17:52<6:45:52, 50.73s/it]

Epoch [20/500] Train Loss: 1.0999 | Val Loss: 1.1021 | Val Acc: 33.72%


Epochs:   4%|▍         | 21/500 [18:42<6:44:21, 50.65s/it]

Epoch [21/500] Train Loss: 1.1008 | Val Loss: 1.1004 | Val Acc: 33.08%


Epochs:   4%|▍         | 22/500 [19:33<6:43:24, 50.64s/it]

Epoch [22/500] Train Loss: 1.1015 | Val Loss: 1.0995 | Val Acc: 33.21%


Epochs:   5%|▍         | 23/500 [20:24<6:42:40, 50.65s/it]

Epoch [23/500] Train Loss: 1.1014 | Val Loss: 1.0992 | Val Acc: 33.33%


Epochs:   5%|▍         | 24/500 [21:14<6:41:28, 50.61s/it]

Epoch [24/500] Train Loss: 1.0966 | Val Loss: 1.0967 | Val Acc: 36.92%


Epochs:   5%|▌         | 25/500 [22:05<6:40:16, 50.56s/it]

Epoch [25/500] Train Loss: 1.0987 | Val Loss: 1.1009 | Val Acc: 31.67%


Epochs:   5%|▌         | 26/500 [22:55<6:38:45, 50.48s/it]

Epoch [26/500] Train Loss: 1.0981 | Val Loss: 1.0992 | Val Acc: 31.79%


Epochs:   5%|▌         | 27/500 [23:45<6:37:05, 50.37s/it]

Epoch [27/500] Train Loss: 1.0986 | Val Loss: 1.0992 | Val Acc: 33.72%


Epochs:   6%|▌         | 28/500 [24:35<6:35:57, 50.33s/it]

Epoch [28/500] Train Loss: 1.0963 | Val Loss: 1.0996 | Val Acc: 33.46%


Epochs:   6%|▌         | 29/500 [25:26<6:34:42, 50.28s/it]

Epoch [29/500] Train Loss: 1.1006 | Val Loss: 1.0987 | Val Acc: 34.74%


Epochs:   6%|▌         | 30/500 [26:15<6:33:04, 50.18s/it]

Epoch [30/500] Train Loss: 1.0992 | Val Loss: 1.0988 | Val Acc: 34.23%


Epochs:   6%|▌         | 30/500 [27:06<7:04:34, 54.20s/it]

Epoch [31/500] Train Loss: 1.1006 | Val Loss: 1.0992 | Val Acc: 31.28%
Early stopping at epoch 31





In [None]:
from tqdm import trange

for param in model.layer.parameters():
    param.requires_grad = True

train_loss_unetpredictor, test_loss_unetpredictor = [], []

num_epochs = 500
patience = 20
best_val_loss = float('inf')
epochs_no_improve = 0

for epoch in trange(num_epochs, desc="Epochs"):
    model.train()
    running_loss = 0.0
    for images, masks, labels in trian_loader:
        images = images.to('cuda')
        masks = masks.to('cuda').view(-1, 1, 256, 256).float()
        labels = labels.to('cuda')

        # --- UNetPredictor classification loss ---
        optimizer_unet_predictor.zero_grad()
        class_pred = model(images)
        loss_class = criterion_prediction(class_pred, labels)
        loss_class.backward()
        optimizer_unet_predictor.step()

        # running_loss += loss_mask.item()
        running_loss += loss_class.item()

    avg_loss = running_loss / len(trian_loader)

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, masks, labels in test_loader:
            images = images.to('cuda')
            masks = masks.to('cuda').view(-1, 1, 256, 256).float()
            labels = labels.to('cuda')
            mask_pred = unet(images)
            class_pred = model(images)
            # loss_mask = criterion_unet(mask_pred, masks)
            loss_class = criterion_prediction(class_pred, labels)
            # val_loss += loss_mask.item()
            val_loss += loss_class.item()
            _, predicted = torch.max(class_pred, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    val_loss /= len(test_loader)

    scheduler_predictor.step(val_loss)

    val_acc = 100 * correct / total

    print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {avg_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
    # print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {avg_loss:.4f} | Val Loss: {val_loss:.4f}")
    train_loss_unetpredictor.append(avg_loss)
    test_loss_unetpredictor.append(val_loss)

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), 'model1.pt')
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

import pandas as pd
train_loss_unetpredictor = pd.DataFrame(train_loss_unetpredictor)
train_loss_unetpredictor.to_csv('train_loss_predictor1.csv', index= False, header= False)
test_loss_unetpredictor = pd.DataFrame(test_loss_unetpredictor)
test_loss_unetpredictor.to_csv('test_loss_predictor1.csv', index= False, header= False)