# U-NET TRAINING

### Imports

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm 
from torchmetrics import Accuracy

import os

from PIL import Image
import numpy as np

### Dataset class and Dataloader

In [2]:
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, image_transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = sorted(os.listdir(image_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))
        self.image_transform = image_transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

        image = Image.open(image_path).convert("RGB")  
        mask = Image.open(mask_path).convert("L")  

        if self.image_transform:
            image = self.image_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask

image_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.875, 1.0)), 
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

mask_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.875, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

train_dataset = SegmentationDataset(
    image_dir="data/seg/train/images", mask_dir="data/seg/train/masks", image_transform=image_transform, mask_transform=mask_transform
)
val_dataset = SegmentationDataset(
    image_dir="data/seg/val/images", mask_dir="data/seg/val/masks", image_transform=image_transform, mask_transform=mask_transform
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

### U-NET class

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()
        
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        self.enc5 = self.conv_block(512, 1024)

        self.dec4 = self.conv_block(1024, 512)
        self.dec3 = self.conv_block(512, 256)
        self.dec2 = self.conv_block(256, 128)
        self.dec1 = self.conv_block(128, 64)
        
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

        self.pool = nn.MaxPool2d(2)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        enc5 = self.enc5(self.pool(enc4))

        dec4 = self.dec4(F.interpolate(enc5, scale_factor=2, mode='bilinear', align_corners=True))
        dec3 = self.dec3(F.interpolate(dec4, scale_factor=2, mode='bilinear', align_corners=True))
        dec2 = self.dec2(F.interpolate(dec3, scale_factor=2, mode='bilinear', align_corners=True))
        dec1 = self.dec1(F.interpolate(dec2, scale_factor=2, mode='bilinear', align_corners=True))

        out = self.final(dec1)

        return out

model = UNet(in_channels=3, out_channels=1)  

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


### Load the weights (skip if you haven't train the model yet)

In [4]:
model.load_state_dict(torch.load("model/segmentation_weights.pth", map_location=device))
print("Pretrained weights loaded successfully!")

Pretrained weights loaded successfully!


### Training loop

In [5]:
criterion = nn.BCEWithLogitsLoss()  
optimizer = optim.Adam(model.parameters(), lr=1e-4)

save_path = "model/segmentation_weights.pth"
num_epochs = 54

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    running_corrects = 0
    total = 0

    with tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training", unit="batch") as tepoch:
        for inputs, masks in tepoch:
            inputs, masks = inputs.to(device), masks.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs.squeeze(1), masks.float().squeeze(1))

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            preds = torch.round(torch.sigmoid(outputs))  
            running_corrects += torch.sum(preds == masks).item()
            total += inputs.size(0) * inputs.size(2) * inputs.size(3)

            tepoch.set_postfix(loss=running_loss / (tepoch.n + 1), accuracy=running_corrects / total)

    epoch_loss = running_loss / len(train_loader)
    epoch_accuracy = running_corrects / total
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}")

    model.eval()
    val_loss = 0.0
    val_corrects = 0
    val_total = 0
    with torch.no_grad():
        with tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation", unit="batch") as vepoch:
            for inputs, masks in vepoch:
                inputs, masks = inputs.cuda(), masks.cuda()

                outputs = model(inputs)
                loss = criterion(outputs.squeeze(1), masks.float().squeeze(1))

                val_loss += loss.item()

                preds = torch.round(torch.sigmoid(outputs)) 
                val_corrects += torch.sum(preds == masks).item()
                val_total += inputs.size(0) * inputs.size(2) * inputs.size(3)

                vepoch.set_postfix(loss=val_loss / (vepoch.n + 1), accuracy=val_corrects / val_total)

    val_loss /= len(val_loader)
    val_accuracy = val_corrects / val_total
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

    if (epoch + 1) % 3 == 0:
        torch.save(model.state_dict(), save_path)
        print(f"Model weights saved at epoch {epoch+1}")

Epoch 1/54 - Training: 100%|██████████| 71/71 [20:34<00:00, 17.39s/batch, accuracy=0.887, loss=0.224]


Epoch 1/54, Loss: 0.2237, Accuracy: 0.8872


Epoch 1/54 - Validation: 100%|██████████| 18/18 [04:59<00:00, 16.64s/batch, accuracy=0.891, loss=0.215]


Validation Loss: 0.2153, Validation Accuracy: 0.8909


Epoch 2/54 - Training: 100%|██████████| 71/71 [03:20<00:00,  2.82s/batch, accuracy=0.89, loss=0.215] 


Epoch 2/54, Loss: 0.2155, Accuracy: 0.8898


Epoch 2/54 - Validation: 100%|██████████| 18/18 [00:44<00:00,  2.46s/batch, accuracy=0.889, loss=0.219]


Validation Loss: 0.2194, Validation Accuracy: 0.8886


Epoch 3/54 - Training: 100%|██████████| 71/71 [03:16<00:00,  2.77s/batch, accuracy=0.886, loss=0.225]


Epoch 3/54, Loss: 0.2250, Accuracy: 0.8861


Epoch 3/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.39s/batch, accuracy=0.885, loss=0.227]


Validation Loss: 0.2271, Validation Accuracy: 0.8854
Model weights saved at epoch 3


Epoch 4/54 - Training: 100%|██████████| 71/71 [03:16<00:00,  2.77s/batch, accuracy=0.888, loss=0.22] 


Epoch 4/54, Loss: 0.2201, Accuracy: 0.8885


Epoch 4/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.37s/batch, accuracy=0.888, loss=0.218]


Validation Loss: 0.2182, Validation Accuracy: 0.8880


Epoch 5/54 - Training: 100%|██████████| 71/71 [03:21<00:00,  2.83s/batch, accuracy=0.891, loss=0.213]


Epoch 5/54, Loss: 0.2134, Accuracy: 0.8907


Epoch 5/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.42s/batch, accuracy=0.884, loss=0.225]


Validation Loss: 0.2250, Validation Accuracy: 0.8844


Epoch 6/54 - Training: 100%|██████████| 71/71 [03:18<00:00,  2.79s/batch, accuracy=0.889, loss=0.216]


Epoch 6/54, Loss: 0.2163, Accuracy: 0.8892


Epoch 6/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.40s/batch, accuracy=0.89, loss=0.215] 


Validation Loss: 0.2153, Validation Accuracy: 0.8903
Model weights saved at epoch 6


Epoch 7/54 - Training: 100%|██████████| 71/71 [03:16<00:00,  2.77s/batch, accuracy=0.892, loss=0.211]


Epoch 7/54, Loss: 0.2115, Accuracy: 0.8922


Epoch 7/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.37s/batch, accuracy=0.892, loss=0.219]


Validation Loss: 0.2188, Validation Accuracy: 0.8915


Epoch 8/54 - Training: 100%|██████████| 71/71 [03:15<00:00,  2.76s/batch, accuracy=0.892, loss=0.21] 


Epoch 8/54, Loss: 0.2104, Accuracy: 0.8923


Epoch 8/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.40s/batch, accuracy=0.891, loss=0.212]


Validation Loss: 0.2118, Validation Accuracy: 0.8909


Epoch 9/54 - Training: 100%|██████████| 71/71 [03:17<00:00,  2.78s/batch, accuracy=0.891, loss=0.212]


Epoch 9/54, Loss: 0.2116, Accuracy: 0.8912


Epoch 9/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.38s/batch, accuracy=0.89, loss=0.214] 


Validation Loss: 0.2144, Validation Accuracy: 0.8903
Model weights saved at epoch 9


Epoch 10/54 - Training: 100%|██████████| 71/71 [03:16<00:00,  2.77s/batch, accuracy=0.892, loss=0.211]


Epoch 10/54, Loss: 0.2108, Accuracy: 0.8922


Epoch 10/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.38s/batch, accuracy=0.885, loss=0.227]


Validation Loss: 0.2270, Validation Accuracy: 0.8845


Epoch 11/54 - Training: 100%|██████████| 71/71 [03:15<00:00,  2.76s/batch, accuracy=0.89, loss=0.215] 


Epoch 11/54, Loss: 0.2150, Accuracy: 0.8896


Epoch 11/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.37s/batch, accuracy=0.882, loss=0.229]


Validation Loss: 0.2288, Validation Accuracy: 0.8823


Epoch 12/54 - Training: 100%|██████████| 71/71 [03:14<00:00,  2.74s/batch, accuracy=0.89, loss=0.212] 


Epoch 12/54, Loss: 0.2124, Accuracy: 0.8904


Epoch 12/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.34s/batch, accuracy=0.886, loss=0.226]


Validation Loss: 0.2260, Validation Accuracy: 0.8858
Model weights saved at epoch 12


Epoch 13/54 - Training: 100%|██████████| 71/71 [03:14<00:00,  2.74s/batch, accuracy=0.893, loss=0.207]


Epoch 13/54, Loss: 0.2075, Accuracy: 0.8932


Epoch 13/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.36s/batch, accuracy=0.889, loss=0.217]


Validation Loss: 0.2168, Validation Accuracy: 0.8887


Epoch 14/54 - Training: 100%|██████████| 71/71 [03:14<00:00,  2.73s/batch, accuracy=0.89, loss=0.214] 


Epoch 14/54, Loss: 0.2139, Accuracy: 0.8897


Epoch 14/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.35s/batch, accuracy=0.89, loss=0.215] 


Validation Loss: 0.2150, Validation Accuracy: 0.8896


Epoch 15/54 - Training: 100%|██████████| 71/71 [03:14<00:00,  2.74s/batch, accuracy=0.892, loss=0.21] 


Epoch 15/54, Loss: 0.2101, Accuracy: 0.8923


Epoch 15/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.38s/batch, accuracy=0.889, loss=0.216]


Validation Loss: 0.2157, Validation Accuracy: 0.8886
Model weights saved at epoch 15


Epoch 16/54 - Training: 100%|██████████| 71/71 [03:15<00:00,  2.75s/batch, accuracy=0.89, loss=0.213] 


Epoch 16/54, Loss: 0.2133, Accuracy: 0.8902


Epoch 16/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.39s/batch, accuracy=0.887, loss=0.22] 


Validation Loss: 0.2200, Validation Accuracy: 0.8871


Epoch 17/54 - Training: 100%|██████████| 71/71 [03:15<00:00,  2.75s/batch, accuracy=0.893, loss=0.208]


Epoch 17/54, Loss: 0.2083, Accuracy: 0.8931


Epoch 17/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.36s/batch, accuracy=0.894, loss=0.205]


Validation Loss: 0.2049, Validation Accuracy: 0.8941


Epoch 18/54 - Training: 100%|██████████| 71/71 [03:15<00:00,  2.76s/batch, accuracy=0.891, loss=0.211]


Epoch 18/54, Loss: 0.2113, Accuracy: 0.8911


Epoch 18/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.39s/batch, accuracy=0.891, loss=0.214]


Validation Loss: 0.2139, Validation Accuracy: 0.8915
Model weights saved at epoch 18


Epoch 19/54 - Training: 100%|██████████| 71/71 [03:16<00:00,  2.76s/batch, accuracy=0.892, loss=0.207]


Epoch 19/54, Loss: 0.2072, Accuracy: 0.8924


Epoch 19/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.37s/batch, accuracy=0.891, loss=0.213]


Validation Loss: 0.2128, Validation Accuracy: 0.8914


Epoch 20/54 - Training: 100%|██████████| 71/71 [03:16<00:00,  2.76s/batch, accuracy=0.893, loss=0.209]


Epoch 20/54, Loss: 0.2087, Accuracy: 0.8926


Epoch 20/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.35s/batch, accuracy=0.891, loss=0.213]


Validation Loss: 0.2127, Validation Accuracy: 0.8910


Epoch 21/54 - Training: 100%|██████████| 71/71 [03:19<00:00,  2.80s/batch, accuracy=0.891, loss=0.21] 


Epoch 21/54, Loss: 0.2102, Accuracy: 0.8913


Epoch 21/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.36s/batch, accuracy=0.89, loss=0.213] 


Validation Loss: 0.2125, Validation Accuracy: 0.8896
Model weights saved at epoch 21


Epoch 22/54 - Training: 100%|██████████| 71/71 [03:17<00:00,  2.79s/batch, accuracy=0.892, loss=0.21] 


Epoch 22/54, Loss: 0.2101, Accuracy: 0.8923


Epoch 22/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.39s/batch, accuracy=0.89, loss=0.216] 


Validation Loss: 0.2159, Validation Accuracy: 0.8902


Epoch 23/54 - Training: 100%|██████████| 71/71 [03:17<00:00,  2.79s/batch, accuracy=0.893, loss=0.208]


Epoch 23/54, Loss: 0.2078, Accuracy: 0.8933


Epoch 23/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.38s/batch, accuracy=0.892, loss=0.21] 


Validation Loss: 0.2104, Validation Accuracy: 0.8917


Epoch 24/54 - Training: 100%|██████████| 71/71 [03:18<00:00,  2.79s/batch, accuracy=0.894, loss=0.206]


Epoch 24/54, Loss: 0.2060, Accuracy: 0.8939


Epoch 24/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.43s/batch, accuracy=0.89, loss=0.213] 


Validation Loss: 0.2130, Validation Accuracy: 0.8903
Model weights saved at epoch 24


Epoch 25/54 - Training: 100%|██████████| 71/71 [03:19<00:00,  2.81s/batch, accuracy=0.893, loss=0.208]


Epoch 25/54, Loss: 0.2077, Accuracy: 0.8926


Epoch 25/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.42s/batch, accuracy=0.889, loss=0.215]


Validation Loss: 0.2152, Validation Accuracy: 0.8891


Epoch 26/54 - Training: 100%|██████████| 71/71 [03:18<00:00,  2.80s/batch, accuracy=0.893, loss=0.21] 


Epoch 26/54, Loss: 0.2102, Accuracy: 0.8926


Epoch 26/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.43s/batch, accuracy=0.892, loss=0.211]


Validation Loss: 0.2108, Validation Accuracy: 0.8916


Epoch 27/54 - Training: 100%|██████████| 71/71 [03:16<00:00,  2.77s/batch, accuracy=0.891, loss=0.21] 


Epoch 27/54, Loss: 0.2101, Accuracy: 0.8913


Epoch 27/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.40s/batch, accuracy=0.882, loss=0.227]


Validation Loss: 0.2272, Validation Accuracy: 0.8823
Model weights saved at epoch 27


Epoch 28/54 - Training: 100%|██████████| 71/71 [03:17<00:00,  2.78s/batch, accuracy=0.892, loss=0.21] 


Epoch 28/54, Loss: 0.2102, Accuracy: 0.8921


Epoch 28/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.39s/batch, accuracy=0.889, loss=0.215]


Validation Loss: 0.2149, Validation Accuracy: 0.8885


Epoch 29/54 - Training: 100%|██████████| 71/71 [03:17<00:00,  2.78s/batch, accuracy=0.893, loss=0.205]


Epoch 29/54, Loss: 0.2052, Accuracy: 0.8926


Epoch 29/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.41s/batch, accuracy=0.893, loss=0.206]


Validation Loss: 0.2059, Validation Accuracy: 0.8934


Epoch 30/54 - Training: 100%|██████████| 71/71 [03:17<00:00,  2.78s/batch, accuracy=0.894, loss=0.204]


Epoch 30/54, Loss: 0.2042, Accuracy: 0.8944


Epoch 30/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.39s/batch, accuracy=0.887, loss=0.219]


Validation Loss: 0.2192, Validation Accuracy: 0.8875
Model weights saved at epoch 30


Epoch 31/54 - Training: 100%|██████████| 71/71 [03:18<00:00,  2.79s/batch, accuracy=0.894, loss=0.203]


Epoch 31/54, Loss: 0.2033, Accuracy: 0.8942


Epoch 31/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.38s/batch, accuracy=0.895, loss=0.203]


Validation Loss: 0.2033, Validation Accuracy: 0.8950


Epoch 32/54 - Training: 100%|██████████| 71/71 [03:17<00:00,  2.78s/batch, accuracy=0.893, loss=0.208]


Epoch 32/54, Loss: 0.2084, Accuracy: 0.8927


Epoch 32/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.40s/batch, accuracy=0.887, loss=0.22] 


Validation Loss: 0.2203, Validation Accuracy: 0.8872


Epoch 33/54 - Training: 100%|██████████| 71/71 [03:17<00:00,  2.79s/batch, accuracy=0.892, loss=0.209]


Epoch 33/54, Loss: 0.2089, Accuracy: 0.8915


Epoch 33/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.38s/batch, accuracy=0.885, loss=0.224]


Validation Loss: 0.2245, Validation Accuracy: 0.8848
Model weights saved at epoch 33


Epoch 34/54 - Training: 100%|██████████| 71/71 [03:17<00:00,  2.78s/batch, accuracy=0.894, loss=0.205]


Epoch 34/54, Loss: 0.2045, Accuracy: 0.8937


Epoch 34/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.42s/batch, accuracy=0.887, loss=0.222]


Validation Loss: 0.2217, Validation Accuracy: 0.8868


Epoch 35/54 - Training: 100%|██████████| 71/71 [03:16<00:00,  2.76s/batch, accuracy=0.892, loss=0.208]


Epoch 35/54, Loss: 0.2085, Accuracy: 0.8917


Epoch 35/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.40s/batch, accuracy=0.892, loss=0.207]


Validation Loss: 0.2073, Validation Accuracy: 0.8917


Epoch 36/54 - Training: 100%|██████████| 71/71 [03:18<00:00,  2.80s/batch, accuracy=0.894, loss=0.203]


Epoch 36/54, Loss: 0.2028, Accuracy: 0.8936


Epoch 36/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.37s/batch, accuracy=0.89, loss=0.212] 


Validation Loss: 0.2121, Validation Accuracy: 0.8903
Model weights saved at epoch 36


Epoch 37/54 - Training: 100%|██████████| 71/71 [03:16<00:00,  2.77s/batch, accuracy=0.894, loss=0.203]


Epoch 37/54, Loss: 0.2031, Accuracy: 0.8942


Epoch 37/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.39s/batch, accuracy=0.896, loss=0.2]  


Validation Loss: 0.2003, Validation Accuracy: 0.8964


Epoch 38/54 - Training: 100%|██████████| 71/71 [03:17<00:00,  2.78s/batch, accuracy=0.891, loss=0.209]


Epoch 38/54, Loss: 0.2094, Accuracy: 0.8912


Epoch 38/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.37s/batch, accuracy=0.887, loss=0.218]


Validation Loss: 0.2184, Validation Accuracy: 0.8870


Epoch 39/54 - Training: 100%|██████████| 71/71 [03:16<00:00,  2.77s/batch, accuracy=0.895, loss=0.201]


Epoch 39/54, Loss: 0.2009, Accuracy: 0.8946


Epoch 39/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.37s/batch, accuracy=0.892, loss=0.213]


Validation Loss: 0.2126, Validation Accuracy: 0.8916
Model weights saved at epoch 39


Epoch 40/54 - Training: 100%|██████████| 71/71 [03:16<00:00,  2.77s/batch, accuracy=0.893, loss=0.204]


Epoch 40/54, Loss: 0.2041, Accuracy: 0.8935


Epoch 40/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.41s/batch, accuracy=0.887, loss=0.22] 


Validation Loss: 0.2205, Validation Accuracy: 0.8867


Epoch 41/54 - Training: 100%|██████████| 71/71 [03:17<00:00,  2.79s/batch, accuracy=0.895, loss=0.202]


Epoch 41/54, Loss: 0.2016, Accuracy: 0.8953


Epoch 41/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.42s/batch, accuracy=0.89, loss=0.212] 


Validation Loss: 0.2122, Validation Accuracy: 0.8904


Epoch 42/54 - Training: 100%|██████████| 71/71 [03:18<00:00,  2.80s/batch, accuracy=0.895, loss=0.202]


Epoch 42/54, Loss: 0.2017, Accuracy: 0.8950


Epoch 42/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.38s/batch, accuracy=0.891, loss=0.209]


Validation Loss: 0.2088, Validation Accuracy: 0.8907
Model weights saved at epoch 42


Epoch 43/54 - Training: 100%|██████████| 71/71 [03:16<00:00,  2.77s/batch, accuracy=0.893, loss=0.204]


Epoch 43/54, Loss: 0.2045, Accuracy: 0.8930


Epoch 43/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.43s/batch, accuracy=0.891, loss=0.21] 


Validation Loss: 0.2105, Validation Accuracy: 0.8914


Epoch 44/54 - Training: 100%|██████████| 71/71 [03:19<00:00,  2.80s/batch, accuracy=0.894, loss=0.204]


Epoch 44/54, Loss: 0.2043, Accuracy: 0.8938


Epoch 44/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.40s/batch, accuracy=0.89, loss=0.212] 


Validation Loss: 0.2117, Validation Accuracy: 0.8904


Epoch 45/54 - Training: 100%|██████████| 71/71 [03:18<00:00,  2.79s/batch, accuracy=0.895, loss=0.2]  


Epoch 45/54, Loss: 0.2004, Accuracy: 0.8954


Epoch 45/54 - Validation: 100%|██████████| 18/18 [00:42<00:00,  2.38s/batch, accuracy=0.894, loss=0.198]


Validation Loss: 0.1981, Validation Accuracy: 0.8944
Model weights saved at epoch 45


Epoch 46/54 - Training: 100%|██████████| 71/71 [03:21<00:00,  2.84s/batch, accuracy=0.894, loss=0.203]


Epoch 46/54, Loss: 0.2028, Accuracy: 0.8945


Epoch 46/54 - Validation: 100%|██████████| 18/18 [00:47<00:00,  2.62s/batch, accuracy=0.893, loss=0.207]


Validation Loss: 0.2065, Validation Accuracy: 0.8933


Epoch 47/54 - Training: 100%|██████████| 71/71 [03:27<00:00,  2.92s/batch, accuracy=0.897, loss=0.198]


Epoch 47/54, Loss: 0.1975, Accuracy: 0.8970


Epoch 47/54 - Validation: 100%|██████████| 18/18 [00:46<00:00,  2.59s/batch, accuracy=0.893, loss=0.206]


Validation Loss: 0.2060, Validation Accuracy: 0.8929


Epoch 48/54 - Training: 100%|██████████| 71/71 [03:25<00:00,  2.90s/batch, accuracy=0.893, loss=0.203]


Epoch 48/54, Loss: 0.2034, Accuracy: 0.8934


Epoch 48/54 - Validation: 100%|██████████| 18/18 [00:44<00:00,  2.46s/batch, accuracy=0.89, loss=0.212] 


Validation Loss: 0.2115, Validation Accuracy: 0.8898
Model weights saved at epoch 48


Epoch 49/54 - Training: 100%|██████████| 71/71 [03:20<00:00,  2.82s/batch, accuracy=0.895, loss=0.204]


Epoch 49/54, Loss: 0.2042, Accuracy: 0.8952


Epoch 49/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.44s/batch, accuracy=0.892, loss=0.216]


Validation Loss: 0.2165, Validation Accuracy: 0.8918


Epoch 50/54 - Training: 100%|██████████| 71/71 [03:20<00:00,  2.82s/batch, accuracy=0.895, loss=0.202]


Epoch 50/54, Loss: 0.2025, Accuracy: 0.8947


Epoch 50/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.41s/batch, accuracy=0.887, loss=0.217]


Validation Loss: 0.2175, Validation Accuracy: 0.8870


Epoch 51/54 - Training: 100%|██████████| 71/71 [03:18<00:00,  2.79s/batch, accuracy=0.896, loss=0.198]


Epoch 51/54, Loss: 0.1977, Accuracy: 0.8961


Epoch 51/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.40s/batch, accuracy=0.895, loss=0.202]


Validation Loss: 0.2018, Validation Accuracy: 0.8946
Model weights saved at epoch 51


Epoch 52/54 - Training: 100%|██████████| 71/71 [03:18<00:00,  2.79s/batch, accuracy=0.893, loss=0.204]


Epoch 52/54, Loss: 0.2039, Accuracy: 0.8930


Epoch 52/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.40s/batch, accuracy=0.893, loss=0.207]


Validation Loss: 0.2074, Validation Accuracy: 0.8926


Epoch 53/54 - Training: 100%|██████████| 71/71 [03:18<00:00,  2.80s/batch, accuracy=0.895, loss=0.199]


Epoch 53/54, Loss: 0.1991, Accuracy: 0.8949


Epoch 53/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.41s/batch, accuracy=0.887, loss=0.22] 


Validation Loss: 0.2198, Validation Accuracy: 0.8869


Epoch 54/54 - Training: 100%|██████████| 71/71 [03:17<00:00,  2.78s/batch, accuracy=0.897, loss=0.196]


Epoch 54/54, Loss: 0.1961, Accuracy: 0.8968


Epoch 54/54 - Validation: 100%|██████████| 18/18 [00:43<00:00,  2.43s/batch, accuracy=0.891, loss=0.213]


Validation Loss: 0.2131, Validation Accuracy: 0.8907
Model weights saved at epoch 54
