In [None]:
from google.colab import drive
# Mount Google Drive so the dataset archive and the checkpoint directory under
# /content/drive/MyDrive (used by later cells) are reachable from this runtime.
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
!pip install torchinfo torch

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [None]:
!pip install monai

Collecting monai
  Downloading monai-1.4.0-py3-none-any.whl.metadata (11 kB)
Downloading monai-1.4.0-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.4.0


In [None]:
!pip install torch nibabel torchinfo



In [None]:
from typing import Tuple, Union

import torch
import torch.nn as nn

from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.nets import ViT


## Model: UNETR

In [None]:
class UNETR_Reconstruction(nn.Module):
    """
    Modified UNETR for 3D medical image reconstruction tasks.

    A MONAI ViT encodes the input as non-overlapping 16x16x16 patches. Hidden
    states from transformer layers 3, 6 and 9 (0-based indices) plus the final
    ViT output feed a convolutional decoder with UNETR-style skip connections.
    The network returns ``in_channels`` channels at the input resolution,
    squashed into (0, 1) by a sigmoid.
    """

    def __init__(
        self,
        in_channels: int,
        img_size: Tuple[int, int, int],
        feature_size: int = 16,
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_heads: int = 12,
        pos_embed: str = "conv",
        norm_name: Union[Tuple, str] = "instance",
        conv_block: bool = False,
        res_block: bool = True,
        dropout_rate: float = 0.0,
    ) -> None:
        """
        Args:
            in_channels: number of input channels (e.g., 1 for grayscale MRI);
                the reconstruction has the same number of channels.
            img_size: (D, H, W) of the input. Every dimension must be a multiple
                of the 16-voxel patch size so the decoder restores the input shape.
            feature_size: base channel width of the convolutional encoder/decoder.
            hidden_size: ViT embedding dimension.
            mlp_dim: dimension of the ViT feedforward layer.
            num_heads: number of attention heads.
            pos_embed: patch projection type ("conv" or "perceptron"), forwarded
                to the ViT as ``proj_type``. Earlier versions of this class
                validated this value but never used it, so the ViT always used
                "conv"; the default is "conv" to keep that architecture.
            norm_name: feature normalization type and arguments.
            conv_block: bool argument to determine if convolutional block is used.
            res_block: bool argument to determine if residual block is used.
            dropout_rate: fraction of the input units to drop.

        Raises:
            AssertionError: ``dropout_rate`` outside [0, 1] or ``hidden_size``
                not divisible by ``num_heads``.
            KeyError: unsupported ``pos_embed``.
            ValueError: ``img_size`` is not 3D or not divisible by the patch size.
        """
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise AssertionError("hidden size should be divisible by num_heads.")

        if pos_embed not in ["conv", "perceptron"]:
            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

        self.patch_size = (16, 16, 16)  # UNETR typically uses 16x16x16 patches

        # The decoder upsamples the token grid by exactly 16 per axis, so any
        # remainder would make the skip connections (and the output) mismatch
        # the input size.
        if len(img_size) != 3 or any(dim % patch != 0 for dim, patch in zip(img_size, self.patch_size)):
            raise ValueError(
                f"img_size must be 3D and divisible by the patch size {self.patch_size}, got {tuple(img_size)}."
            )

        self.num_layers = 12
        # Token grid produced by the ViT, e.g. (16, 128, 128) -> (1, 8, 8).
        self.feat_size = (
            img_size[0] // self.patch_size[0],
            img_size[1] // self.patch_size[1],
            img_size[2] // self.patch_size[2],
        )
        self.hidden_size = hidden_size
        self.classification = False

        # Vision Transformer (ViT) backbone
        self.vit = ViT(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=self.patch_size,
            hidden_size=hidden_size,
            mlp_dim=mlp_dim,
            num_layers=self.num_layers,
            num_heads=num_heads,
            proj_type=pos_embed,
            classification=self.classification,
            dropout_rate=dropout_rate,
        )

        # Encoder blocks (extract features at different resolutions)
        # encoder1 works on the raw input at full resolution.
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        # encoder2..4 upsample projected ViT hidden states back towards image
        # resolution; num_layer controls how many extra x2 upsampling stages run.
        self.encoder2 = UnetrPrUpBlock(
            spatial_dims=3,
            in_channels=hidden_size,
            out_channels=feature_size * 2,
            num_layer=2,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder3 = UnetrPrUpBlock(
            spatial_dims=3,
            in_channels=hidden_size,
            out_channels=feature_size * 4,
            num_layer=1,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder4 = UnetrPrUpBlock(
            spatial_dims=3,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )

        # Decoder blocks (upsample and reconstruct the input); each doubles the
        # spatial size and fuses the matching encoder skip connection.
        self.decoder5 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )

        # Output block (reconstruction output, same size as input)
        self.out = nn.Conv3d(feature_size, in_channels, kernel_size=1)  # Reconstruction task requires in_channels output

        # Sigmoid keeps the output in (0, 1), matching min-max normalised targets.
        self.sigmoid = nn.Sigmoid()

    def proj_feat(self, x, hidden_size, feat_size):
        """Reshape a token sequence (B, N, hidden_size) into a channels-first grid (B, hidden_size, *feat_size)."""
        x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

    def forward(self, x_in):
        """
        Reconstruct ``x_in`` of shape (B, in_channels, D, H, W).

        Returns:
            Tensor of the same shape as ``x_in`` with values in (0, 1).
        """
        x, hidden_states_out = self.vit(x_in)
        enc1 = self.encoder1(x_in)
        x2 = hidden_states_out[3]
        enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size))
        x3 = hidden_states_out[6]
        enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size))
        x4 = hidden_states_out[9]
        enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size))
        dec4 = self.proj_feat(x, self.hidden_size, self.feat_size)
        dec3 = self.decoder5(dec4, enc4)
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        out = self.decoder2(dec1, enc1)
        logits = self.out(out)
        return self.sigmoid(logits)


In [None]:
import torch
import torchinfo

# One input slab: (depth, height, width) = 16 x 128 x 128 with a single channel.
slab_shape = (16, 128, 128)
model = UNETR_Reconstruction(
    in_channels=1,
    img_size=slab_shape,
    feature_size=32,
    norm_name='instance',
)

# Layer-by-layer summary for a dummy batch shaped (batch=2, channels=1, *slab_shape).
torchinfo.summary(model, input_size=(2, 1, *slab_shape))


Layer (type:depth-idx)                        Output Shape              Param #
UNETR_Reconstruction                          [2, 1, 16, 128, 128]      --
├─ViT: 1-1                                    [2, 64, 768]              --
│    └─PatchEmbeddingBlock: 2-1               [2, 64, 768]              49,152
│    │    └─Conv3d: 3-1                       [2, 768, 1, 8, 8]         3,146,496
│    │    └─Dropout: 3-2                      [2, 64, 768]              --
│    └─ModuleList: 2-2                        --                        --
│    │    └─TransformerBlock: 3-3             [2, 64, 768]              9,447,168
│    │    └─TransformerBlock: 3-4             [2, 64, 768]              9,447,168
│    │    └─TransformerBlock: 3-5             [2, 64, 768]              9,447,168
│    │    └─TransformerBlock: 3-6             [2, 64, 768]              9,447,168
│    │    └─TransformerBlock: 3-7             [2, 64, 768]              9,447,168
│    │    └─TransformerBlock: 3-8             [2,

## Dataset: OpenBHB

In [None]:
import os
import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset

In [None]:
import zipfile
import os

# Path to the zip file in your Google Drive
zip_file_path = '/content/drive/MyDrive/archive (9).zip'

# Destination path to extract files
destination_dir = '/content/dataset/'

# Extraction can take a long time, so it is done only once per runtime.
# The marker file is written only after extractall() returns, so an
# interrupted extraction (no marker) is simply redone on the next run.
extraction_marker = os.path.join(destination_dir, '.extracted')
if not os.path.exists(extraction_marker):
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(destination_dir)
    with open(extraction_marker, 'w'):
        pass

# List the extracted files (the marker itself is hidden)
sorted(name for name in os.listdir(destination_dir) if name != '.extracted')


In [None]:
import os
import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset

class BrainMRIDataset(Dataset):
    """
    One item per subject: a quasi-raw T1w volume cut into depth-wise slabs.

    ``__getitem__`` loads ``sub-<id>_preproc-quasiraw_T1w.npy``, squeezes
    singleton dimensions, moves the last axis to the front, crops, resizes the
    two in-plane axes, min-max normalises to [0, 1] and returns a list of
    tensors shaped ``(1, slice_depth, H, W)``.
    """

    def __init__(self, base_path, target_size=(88, 128, 128), slice_depth=16, transform=None, max_subjects=500):
        """
        Args:
            base_path: directory containing ``sub-<id>_preproc-quasiraw_T1w.npy`` files.
            target_size: (D, H, W). Only H and W are used for the in-plane resize;
                the cropped depth is preserved so it can be cut into slabs.
            slice_depth: number of consecutive slices per slab.
            transform: optional callable applied to every slab (shape
                ``(1, slice_depth, H, W)``) before it is returned.
            max_subjects: keep at most this many subjects, taken in sorted ID
                order; ``None`` keeps all of them.
        """
        self.base_path = base_path
        self.target_size = target_size
        self.slice_depth = slice_depth
        self.transform = transform

        subject_ids = self._get_subject_ids()
        self.subject_ids = subject_ids if max_subjects is None else subject_ids[:max_subjects]

    def _get_subject_ids(self):
        """Return subject IDs parsed from matching file names, sorted so subsetting is reproducible."""
        # os.listdir order is arbitrary; without sorting, max_subjects would
        # select a different subset on different machines.
        return sorted(
            file.split('_')[0].replace('sub-', '')
            for file in os.listdir(self.base_path)
            if file.endswith('_preproc-quasiraw_T1w.npy')
        )

    def __len__(self):
        return len(self.subject_ids)

    def crop(self, volume, start_y=20, end_y=160, start_x=20, end_x=196, start_z=50, end_z=130):
        """Crop a 3D tensor: axis 0 to [start_z, end_z), axis 1 to [start_y, end_y), axis 2 to [start_x, end_x)."""
        return volume[start_z:end_z, start_y:end_y, start_x:end_x]

    def get_slices(self, mri_volume):
        """
        Cut a 3D volume into consecutive ``slice_depth``-thick slabs along axis 0.

        If the depth is not a multiple of ``slice_depth``, one extra slab made of
        the last ``slice_depth`` slices is appended (it overlaps the previous
        slab), so every slice is covered. A volume thinner than ``slice_depth``
        yields a single slab that is thinner than ``slice_depth``.
        """
        if len(mri_volume.shape) != 3:
            mri_volume = mri_volume.squeeze()

        patches = []
        num_slices = mri_volume.shape[0] // self.slice_depth

        for i in range(num_slices):
            start = i * self.slice_depth
            end = start + self.slice_depth
            patches.append(mri_volume[start:end, :, :])

        remainder = mri_volume.shape[0] % self.slice_depth
        if remainder > 0:
            patches.append(mri_volume[-self.slice_depth:, :, :])

        return patches

    def __getitem__(self, index):
        # An out-of-range index raises IndexError from the list lookup; that is
        # also what lets callers iterate over the dataset with a plain for-loop.
        subject_id = self.subject_ids[index]
        file_path = os.path.join(self.base_path, f"sub-{subject_id}_preproc-quasiraw_T1w.npy")

        # Load and drop all singleton dimensions (e.g. leading batch/channel axes).
        mri_volume = torch.tensor(np.load(file_path)).float().squeeze()
        if mri_volume.dim() != 3:
            raise ValueError(f"Expected a 3D volume in {file_path}, got shape {tuple(mri_volume.shape)}")

        # (A, B, C) -> (C, A, B): the former last axis becomes the slab axis.
        mri_volume = mri_volume.permute(2, 0, 1)

        mri_crop = self.crop(mri_volume)

        # Resize the two in-plane axes to target_size[1:], keeping the depth.
        # Indexing [0, 0] (instead of squeeze()) never drops a size-1 axis.
        height, width = self.target_size[1], self.target_size[2]
        mri_resize = F.interpolate(
            mri_crop.unsqueeze(0).unsqueeze(0),
            size=(mri_crop.shape[0], height, width),
            mode='trilinear',
            align_corners=False,
        )[0, 0]

        # Per-volume min-max normalisation to [0, 1]; epsilon guards a constant volume.
        mri_volume = (mri_resize - mri_resize.min()) / (mri_resize.max() - mri_resize.min() + 1e-8)

        # Add the channel dimension: each slab becomes (1, slice_depth, H, W).
        slabs = [patch.unsqueeze(0) for patch in self.get_slices(mri_volume)]

        if self.transform is not None:
            slabs = [self.transform(slab) for slab in slabs]

        return slabs


In [None]:
dataset = BrainMRIDataset(base_path='/content/dataset/val_quasiraw')

# Sanity check: load the first subject and report how it was cut into slabs.
volume_slices = dataset[0]
num_slabs, first_slab_shape = len(volume_slices), volume_slices[0].shape
print(f"Number of slices: {num_slabs}, Slice shape: {first_slab_shape}")

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split

class SliceDatasetFromList(Dataset):
    """
    Flat dataset of slabs for a self-supervised reconstruction task.

    ``patch_list`` is an iterable of per-subject slab lists (e.g. a
    ``BrainMRIDataset``). It is flattened once at construction, so every slab
    is kept in memory. Each item is ``(slab, slab)``: the input doubles as the
    reconstruction target.
    """

    def __init__(self, patch_list):
        # Flatten the list of lists into a single list
        self.patch_list = [patch for sublist in patch_list for patch in sublist]

    def __len__(self):
        return len(self.patch_list)

    def __getitem__(self, index):
        # torch.as_tensor accepts tensors and NumPy arrays alike without the
        # "copy construct from a tensor" UserWarning that torch.tensor(tensor)
        # raises; detach().clone() keeps the stored slab safe from in-place edits.
        patch_tensor = torch.as_tensor(self.patch_list[index], dtype=torch.float32).detach().clone()

        # Autoencoder-style task: the same tensor is both input and target.
        return patch_tensor, patch_tensor


# Load every subject exactly once and keep its slabs grouped by subject.
subject_patches = [dataset[i] for i in range(len(dataset))]

# Flattened view over all slabs of all subjects.
Slices = SliceDatasetFromList(subject_patches)

# Define train, validation, and test split ratios (applied to subjects)
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15

# Split by subject rather than by slab: slabs cut from the same brain share
# anatomy, so spreading them across splits would leak test data into training.
num_subjects = len(subject_patches)
num_train_subjects = int(train_ratio * num_subjects)
num_val_subjects = int(val_ratio * num_subjects)

# Seeded permutation so the split is reproducible across runs.
split_generator = torch.Generator().manual_seed(42)
subject_order = torch.randperm(num_subjects, generator=split_generator).tolist()
train_subjects = subject_order[:num_train_subjects]
val_subjects = subject_order[num_train_subjects:num_train_subjects + num_val_subjects]
test_subjects = subject_order[num_train_subjects + num_val_subjects:]

train_dataset = SliceDatasetFromList([subject_patches[i] for i in train_subjects])
val_dataset = SliceDatasetFromList([subject_patches[i] for i in val_subjects])
test_dataset = SliceDatasetFromList([subject_patches[i] for i in test_subjects])

# Initialize DataLoaders for each split
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

# Print the sizes of each split (subjects, then slabs)
print(f"Subjects (train/val/test): {len(train_subjects)}/{len(val_subjects)}/{len(test_subjects)}")
print("Train set size:", len(train_dataset))
print("Validation set size:", len(val_dataset))
print("Test set size:", len(test_dataset))


## Training

In [None]:
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output

In [None]:
# Reconstruction objective: voxel-wise mean squared error between output and input slab.
criterion = torch.nn.MSELoss()

# Adam with a small weight-decay (L2) penalty.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Report which device training will run on.
print(str(device))

In [None]:
def save_checkpoint(state, is_best, checkpoint_dir="/content/drive/MyDrive/checkpoint_unetr_2", filename="checkpoint.pth"):
    """Write ``state`` with ``torch.save`` and, if flagged, also as ``best_model.pth``.

    Args:
        state: object passed to ``torch.save`` (the training loop passes a dict
            holding the epoch, model/optimizer state dicts and losses).
        is_best: when truthy, the same state is additionally written to
            ``best_model.pth`` and the path is printed.
        checkpoint_dir: target directory; created if it does not exist.
        filename: file name of the regular checkpoint inside ``checkpoint_dir``.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(state, os.path.join(checkpoint_dir, filename))

    if not is_best:
        return

    best_path = os.path.join(checkpoint_dir, "best_model.pth")
    torch.save(state, best_path)
    print(f"Best model saved to {best_path}")


In [None]:
train_loss_values = []
val_loss_values = []

num_epochs = 200
best_val_loss = float('inf')
patience = 10  # stop after this many consecutive epochs without a val-loss improvement
early_stop_counter = 0

# Batches are moved to `device` below, so the weights must live there too;
# otherwise a CUDA run fails with a device-mismatch error on the first forward.
# nn.Module.to() modifies the module in place, so the optimizer created in an
# earlier cell keeps tracking the same parameters.
model.to(device)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0

    # Training loop
    for imgs, targets in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} (Training)'):
        imgs, targets = imgs.to(device), targets.to(device)
        optimizer.zero_grad()  # Clear gradients
        outputs = model(imgs)  # Forward pass
        loss = criterion(outputs, targets)  # Compute loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights

        train_loss += loss.item()

    # Average training loss for this epoch (mean of per-batch losses)
    train_loss /= len(train_loader)
    train_loss_values.append(train_loss)

    # Validation loop
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for imgs, targets in tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} (Validation)'):
            imgs, targets = imgs.to(device), targets.to(device)
            outputs = model(imgs)  # Forward pass
            loss = criterion(outputs, targets)  # Compute validation loss
            val_loss += loss.item()  # Accumulate validation loss

    # Average validation loss for this epoch
    val_loss /= len(val_loader)
    val_loss_values.append(val_loss)

    # Print the losses for this epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    is_best = val_loss < best_val_loss
    if is_best:
        best_val_loss = val_loss
        early_stop_counter = 0

        # Save the best model checkpoint
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss
        }
        save_checkpoint(checkpoint, is_best)
    else:
        early_stop_counter += 1

    # Early stopping condition
    if early_stop_counter >= patience:
        print(f"Early stopping at epoch {epoch+1}. No improvement in validation loss for {patience} consecutive epochs.")
        break

# Replace the per-epoch progress output with the summary and the loss curves.
# clear_output(wait=True) clears as soon as new output arrives, so it must run
# before the summary is printed, not between the print and the plot.
clear_output(wait=True)

# Print the best validation loss
print(f"Best Validation Loss: {best_val_loss:.4f}")

# Plotting the loss
fig, ax = plt.subplots()
ax.plot(train_loss_values, label='Training Loss')
ax.plot(val_loss_values, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss Over Epochs')
ax.legend()
ax.grid()
plt.show()


  patch_tensor = torch.tensor(patch).float()
Epoch 1/50 (Training): 100%|██████████| 287/287 [03:38<00:00,  1.31it/s]
Epoch 1/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.37it/s]


Epoch [1/50], Train Loss: 0.5256, Val Loss: 0.4519
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 2/50 (Training): 100%|██████████| 287/287 [03:41<00:00,  1.30it/s]
Epoch 2/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.36it/s]


Epoch [2/50], Train Loss: 0.3557, Val Loss: 0.2615
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 3/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 3/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.39it/s]


Epoch [3/50], Train Loss: 0.1733, Val Loss: 0.1502
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 4/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 4/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.37it/s]


Epoch [4/50], Train Loss: 0.1142, Val Loss: 0.0865
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 5/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 5/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.33it/s]


Epoch [5/50], Train Loss: 0.1021, Val Loss: 0.1027


Epoch 6/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 6/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.44it/s]


Epoch [6/50], Train Loss: 0.0903, Val Loss: 0.0660
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 7/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 7/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.22it/s]


Epoch [7/50], Train Loss: 0.0725, Val Loss: 0.0540
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 8/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 8/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.42it/s]


Epoch [8/50], Train Loss: 0.0619, Val Loss: 0.0501
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 9/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 9/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.45it/s]


Epoch [9/50], Train Loss: 0.0530, Val Loss: 0.0387
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 10/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 10/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.31it/s]


Epoch [10/50], Train Loss: 0.0482, Val Loss: 0.0372
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 11/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 11/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.45it/s]


Epoch [11/50], Train Loss: 0.0642, Val Loss: 0.0435


Epoch 12/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 12/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.44it/s]


Epoch [12/50], Train Loss: 0.0490, Val Loss: 0.0342
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 13/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 13/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.39it/s]


Epoch [13/50], Train Loss: 0.0425, Val Loss: 0.0298
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 14/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 14/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.28it/s]


Epoch [14/50], Train Loss: 0.0403, Val Loss: 0.0293
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 15/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 15/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.37it/s]


Epoch [15/50], Train Loss: 0.0357, Val Loss: 0.0263
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 16/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 16/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.43it/s]


Epoch [16/50], Train Loss: 0.0325, Val Loss: 0.0238
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 17/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 17/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.43it/s]


Epoch [17/50], Train Loss: 0.0310, Val Loss: 0.0241


Epoch 18/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 18/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.30it/s]


Epoch [18/50], Train Loss: 0.0290, Val Loss: 0.0216
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 19/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 19/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.46it/s]


Epoch [19/50], Train Loss: 0.0268, Val Loss: 0.0211
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 20/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 20/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.44it/s]


Epoch [20/50], Train Loss: 0.0261, Val Loss: 0.0191
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 21/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 21/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.35it/s]


Epoch [21/50], Train Loss: 0.0266, Val Loss: 0.0176
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 22/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 22/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.40it/s]


Epoch [22/50], Train Loss: 0.0234, Val Loss: 0.0203


Epoch 23/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 23/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.43it/s]


Epoch [23/50], Train Loss: 0.0222, Val Loss: 0.0154
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 24/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 24/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.45it/s]


Epoch [24/50], Train Loss: 0.0227, Val Loss: 0.0167


Epoch 25/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 25/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.41it/s]


Epoch [25/50], Train Loss: 0.0200, Val Loss: 0.0149
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 26/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 26/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.25it/s]


Epoch [26/50], Train Loss: 0.0217, Val Loss: 0.0165


Epoch 27/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 27/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.42it/s]


Epoch [27/50], Train Loss: 0.0187, Val Loss: 0.0157


Epoch 28/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 28/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.35it/s]


Epoch [28/50], Train Loss: 0.0176, Val Loss: 0.0128
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 29/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 29/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.44it/s]


Epoch [29/50], Train Loss: 0.0262, Val Loss: 0.0158


Epoch 30/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 30/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.41it/s]


Epoch [30/50], Train Loss: 0.0183, Val Loss: 0.0127
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 31/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 31/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.38it/s]


Epoch [31/50], Train Loss: 0.0191, Val Loss: 0.0127


Epoch 32/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 32/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.41it/s]


Epoch [32/50], Train Loss: 0.0167, Val Loss: 0.0118
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 33/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 33/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.45it/s]


Epoch [33/50], Train Loss: 0.0228, Val Loss: 0.0159


Epoch 34/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 34/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.34it/s]


Epoch [34/50], Train Loss: 0.0182, Val Loss: 0.0118


Epoch 35/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 35/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.33it/s]


Epoch [35/50], Train Loss: 0.0170, Val Loss: 0.0112
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 36/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 36/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.39it/s]


Epoch [36/50], Train Loss: 0.0150, Val Loss: 0.0112
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 37/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 37/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.33it/s]


Epoch [37/50], Train Loss: 0.0145, Val Loss: 0.0105
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 38/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 38/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.31it/s]


Epoch [38/50], Train Loss: 0.0149, Val Loss: 0.0123


Epoch 39/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 39/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.42it/s]


Epoch [39/50], Train Loss: 0.0144, Val Loss: 0.0105


Epoch 40/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 40/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.42it/s]


Epoch [40/50], Train Loss: 0.0139, Val Loss: 0.0106


Epoch 41/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 41/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.44it/s]


Epoch [41/50], Train Loss: 0.0131, Val Loss: 0.0099
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 42/50 (Training): 100%|██████████| 287/287 [03:41<00:00,  1.30it/s]
Epoch 42/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.40it/s]


Epoch [42/50], Train Loss: 0.0134, Val Loss: 0.0098
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 43/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 43/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.44it/s]


Epoch [43/50], Train Loss: 0.0134, Val Loss: 0.0091
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 44/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 44/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.26it/s]


Epoch [44/50], Train Loss: 0.0151, Val Loss: 0.0114


Epoch 45/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 45/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.44it/s]


Epoch [45/50], Train Loss: 0.0130, Val Loss: 0.0098


Epoch 46/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 46/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.43it/s]


Epoch [46/50], Train Loss: 0.0130, Val Loss: 0.0084
Best model saved to /content/drive/MyDrive/checkpoint_UNET/best_model_MSE_SSIM.pth


Epoch 47/50 (Training): 100%|██████████| 287/287 [03:40<00:00,  1.30it/s]
Epoch 47/50 (Validation): 100%|██████████| 51/51 [00:04<00:00, 12.36it/s]


Epoch [47/50], Train Loss: 0.0169, Val Loss: 0.0125


Epoch 48/50 (Training):  87%|████████▋ | 249/287 [03:11<00:29,  1.30it/s]