In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install segmentation-models-pytorch
!pip install torchmetrics
!pip install wandb

Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.3.4-py3-none-any.whl.metadata (30 kB)
Collecting efficientnet-pytorch==0.7.1 (from segmentation-models-pytorch)
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pretrainedmodels==0.7.4 (from segmentation-models-pytorch)
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting timm==0.9.7 (from segmentation-models-pytorch)
  Downloading timm-0.9.7-py3-none-any.whl.metadata (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
Collecting munch (from pretrainedmodels==0.7.4->segmentation-models-pytorch)
  Downloading munch-4.0.0-py2.py3-none-any.whl.metadata (5.9 kB)
Downloading segm

In [3]:
import numpy as np
from tqdm.notebook import tqdm
import scipy
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from skimage.transform import resize
import os
from torchmetrics import JaccardIndex
from torchmetrics.detection import IntersectionOverUnion
from segmentation_models_pytorch.losses import DiceLoss
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import os.path as osp
from torchvision import models


import wandb

!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [4]:
import os
import os.path as osp

def get_matching_files(data_dir, phase):
    # List all image and mask files
    img_dir = osp.join(data_dir, phase, 'img')
    mask_dir = osp.join(data_dir, phase, 'mask')

    # Get list of all files in the img and mask directories
    img_files = sorted(os.listdir(img_dir))
    mask_files = sorted(os.listdir(mask_dir))

    # Initialize lists to store matched image-mask pairs
    image_list = []
    mask_list = []

    # Iterate through image files and find corresponding mask
    for img_file in img_files:
        # Base name without extension
        base_name = img_file.split('.')[0]

        # Try to find a matching mask with the same base name
        matching_mask = next((m for m in mask_files if m.startswith(base_name)), None)

        if matching_mask:
            # Add the full path for both image and mask
            image_list.append(osp.join(img_dir, img_file))
            mask_list.append(osp.join(mask_dir, matching_mask))
        else:
            print(f"Warning: No matching mask found for image {img_file}")

    return image_list, mask_list


In [5]:
class OCTDataset(Dataset):
    def __init__(self, data_dir, phase, transforms):
        self.data_dir = data_dir
        self.phase = phase
        self.transforms = transforms
        self.image_list = None
        self.label_list = None
        self.read_lists()

    def __getitem__(self, index):
        image_path = self.image_list[index]
        mask_path = self.label_list[index]

        # Debugging: Ensure the pairing is correct
        # print(f"Image: {image_path}, Mask: {mask_path}")

        image = Image.open(image_path).convert('RGB')
        label = Image.open(mask_path)

        # Apply transformations
        data = list(self.transforms(image, label))
        image = data[0]
        label = data[1]

        return image, label.long()

    def __len__(self):
        return len(self.image_list)

    def read_lists(self):
        # Get matching image and mask files
        self.image_list, self.label_list = get_matching_files(self.data_dir, self.phase)
        print(f"Total number of {self.phase} images: {len(self.image_list)}")


In [6]:
import numpy as np
import torch
from PIL import Image

class Label_Transform(object):
    def __init__(self, label_pixel=(26, 51, 77, 102, 128, 153, 179, 204, 230, 255)):
        self.label_pixel = label_pixel

    def __call__(self, image, label, *args):
        label = np.array(label)
        for i in range(len(self.label_pixel)):
            label[label == self.label_pixel[i]] = i+1

        # Ensure label array is of type uint8 and then convert to tensor
        label = label.astype(np.uint8)
        return image, torch.tensor(label, dtype=torch.long)


class Normalize(object):
    """Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std
    """

    def __init__(self, mean, std):
        self.mean = torch.FloatTensor(mean)
        self.std = torch.FloatTensor(std)

    def __call__(self, image, label=None):
        for t, m, s in zip(image, self.mean, self.std):
            t.sub_(m).div_(s)
        if label is None:
            return image,
        else:
            return image, label

class ToTensor(object):
    """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic, label=None):
        if isinstance(pic, np.ndarray):
            # handle numpy array
            img = torch.from_numpy(pic)
        else:
            # handle PIL Image
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            nchannel = len(pic.mode)
            img = img.view(pic.size[1], pic.size[0], nchannel)
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        img = img.float().div(255)
        if label is None:
            return img,
        else:
            return img, label

class Resize(object):
    def __init__(self, size):
        self.size = size  # Tuple (width, height) for the new size

    def __call__(self, image, label=None):
        image = image.resize(self.size, Image.BILINEAR)  # Resize image using bilinear interpolation
        if label is not None:
            label = label.resize(self.size, Image.NEAREST)  # Resize mask using nearest-neighbor to preserve labels
        if label is None:
            return image,
        else:
            return image, label



class Compose(object):
    """Composes several transforms together.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *args):
        for t in self.transforms:
            args = t(*args)
        return args


In [7]:
from torch.utils.data import Subset
transforms = Compose([
    Resize((512, 512)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    Label_Transform(),

])


train_dataset = OCTDataset(data_dir = 'drive/My Drive/oct_dataset', phase = 'train',transforms = transforms)
val_dataset = OCTDataset(data_dir = 'drive/My Drive/oct_dataset', phase = 'eval', transforms = transforms)
test_dataset = OCTDataset(data_dir = 'drive/My Drive/oct_dataset', phase = 'test', transforms = transforms)

#Choosing different sizes of dataset
train_size = len(train_dataset)


indices = list(range(train_size))
np.random.seed(42)  # Ensure reproducibility
# np.random.shuffle(indices)
split = int(np.floor(1 * train_size))  # Here is where u change the amount of data
train_indices = indices[:split]

# Use Subset to create a dataset with only the selected indices
train_dataset = Subset(train_dataset, train_indices)

print(len(train_dataset))


batch_size = 16

train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size = batch_size, shuffle = False)
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)

#Example to see what the data shape is
for images, labels in train_loader:
  print(images.shape)
  print(labels.shape)
  break



Total number of train images: 148
Total number of eval images: 48
Total number of test images: 48
148


  img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))


torch.Size([16, 3, 512, 512])
torch.Size([16, 512, 512])


In [8]:
import segmentation_models_pytorch as smp
# # Model
# class ResNetDeepLabV3(nn.Module):
#     def __init__(self, classes):
#         super(ResNetDeepLabV3, self).__init__()
#         self.classes = classes  # Store the number of classes
#         self.model = smp.DeepLabV3(
#             encoder_name="resnet50",
#             encoder_weights= "imagenet",
#             in_channels = 3,
#             classes=self.classes,  # Set the number of classes
#             activation=None
#         )

#     def forward(self, x):
#         return self.model(x)

# model = ResNetDeepLabV3(classes = 11)

In [9]:
import torch
import segmentation_models_pytorch as smp

def load_custom_weights_unet(model, custom_weights_path):
    # Load the custom weights
    custom_weights = torch.load(custom_weights_path)

    # Get the state dict of the encoder
    encoder_state_dict = model.encoder.state_dict()

    # Create a new state dict for the mapped weights
    new_state_dict = {}

    # Define a mapping between custom weight keys and encoder keys
    key_mapping = {
        '0': 'firstconv',
        '1': 'firstbn',
        '4.0': 'layer1.0',
        '4.1': 'layer1.1',
        '4.2': 'layer1.2',
        '5.0': 'layer2.0',
        '5.1': 'layer2.1',
        '5.2': 'layer2.2',
        '5.3': 'layer2.3',
        '6.0': 'layer3.0',
        '6.1': 'layer3.1',
        '6.2': 'layer3.2',
        '6.3': 'layer3.3',
        '6.4': 'layer3.4',
        '6.5': 'layer3.5',
        '7.0': 'layer4.0',
        '7.1': 'layer4.1',
        '7.2': 'layer4.2',
    }

    # Map the custom weights to the encoder structure
    for k, v in custom_weights['model'].items():
        for custom_key, encoder_key in key_mapping.items():
            if k.startswith(custom_key):
                new_key = k.replace(custom_key, encoder_key, 1)
                if new_key in encoder_state_dict:
                    new_state_dict[new_key] = v

    # Load the mapped weights into the encoder
    model.encoder.load_state_dict(new_state_dict, strict=False)

    return model

# Create a U-Net model with a ResNet50 encoder
model = smp.DeepLabV3(
    encoder_name="resnet50",
    encoder_weights=None,  # We'll load custom weights
    in_channels=3,
    classes=11
)

# Load custom pretrained weights for the ResNet50 encoder
custom_weights_path = 'drive/My Drive/Colab Notebooks/SimCLR/No-Crop/best.pth'
model = load_custom_weights_unet(model, custom_weights_path)

# Optionally move the model to GPU if needed
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print("Custom weights loaded successfully!")

  custom_weights = torch.load(custom_weights_path)


Custom weights loaded successfully!


In [10]:
from torchmetrics.classification import Dice, JaccardIndex
import torch.nn.functional as F
class DiceLoss(nn.Module):
    def __init__(self, eps=1e-7, activation='softmax'):
        super().__init__()
        self.activation = activation
        self.eps = eps

    def forward(self, y_pr, y_gt):
        if self.activation == 'softmax':
            y_pr = F.softmax(y_pr, dim=1)
        elif self.activation == 'sigmoid':
            y_pr = torch.sigmoid(y_pr)

        num_classes = y_pr.shape[1]
        y_gt = F.one_hot(y_gt, num_classes=num_classes).permute(0, 3, 1, 2).float()

        intersection = torch.sum(y_pr * y_gt, dim=[0, 2, 3])
        union = torch.sum(y_pr, dim=[0, 2, 3]) + torch.sum(y_gt, dim=[0, 2, 3])

        dice = (2.0 * intersection + self.eps) / (union + self.eps)
        dice_loss = 1.0 - torch.mean(dice)

        return dice_loss


class CEDiceLoss(nn.Module):
    def __init__(self, eps=1e-7, activation='softmax', lambda_dice=1.0, lambda_ce=1.0):
        super().__init__()
        self.dice_loss = DiceLoss(eps, activation)
        self.ce_loss = nn.CrossEntropyLoss(reduction='mean')
        self.lambda_dice = lambda_dice
        self.lambda_ce = lambda_ce

    def forward(self, y_pr, y_gt):
        dice = self.dice_loss(y_pr, y_gt)
        ce = self.ce_loss(y_pr, y_gt)
        return self.lambda_dice * dice + self.lambda_ce * ce

# Metrics
dice_coef = Dice(average='micro')
iou = JaccardIndex(task="multiclass", num_classes=11)

In [11]:
# Initialize W&B
# Initialize W&B
wandb.init(
    project="Final_Research",
    config={
        "learning_rate": 1e-3,
        "data size" : 148
    },
    name="DeepLabv3 SimCLR 100"
    )


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33m2455744[0m ([33m2455744-university-of-witwatersrand[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
import wandb
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torchmetrics import Accuracy


def train(model, train_loader, val_loader, epochs, loss_fn, optimizer, device):


    best_val_loss = float('inf')
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
    patience = 20 # Early stopping patience
    patience_counter = 0

    # Initialize metrics
    dice_metric = Dice(average='micro').to(device)
    iou_metric = JaccardIndex(task="multiclass", num_classes=11).to(device)
    accuracy_metric = Accuracy(task="multiclass", num_classes=11).to(device)

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0.0
        train_dice = 0.0
        train_iou = 0.0
        train_acc = 0.0

        for data, target in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Training"):
            data, target = data.float().to(device), target.to(device)
            optimizer.zero_grad()
            # output = model(data)
            output = model(data)

            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            with torch.no_grad():
                pred_labels = torch.argmax(output, dim=1)
                train_dice += dice_metric(pred_labels, target)
                train_iou += iou_metric(pred_labels, target)
                train_acc += accuracy_metric(pred_labels, target)
        #Averages
        train_loss /= len(train_loader)
        train_dice /= len(train_loader)
        train_iou /= len(train_loader)
        train_acc /= len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0
        val_dice = 0.0
        val_iou = 0.0
        val_acc =0.0
        with torch.no_grad():
            for data, target, *_ in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} - Validation"):
                data, target = data.float().to(device), target.to(device)
                # output = model(data)
                output = model(data)

                loss = loss_fn(output, target)
                val_loss += loss.item()

                pred_labels = torch.argmax(output, dim=1)
                val_dice += dice_metric(pred_labels, target)
                val_iou += iou_metric(pred_labels, target)
                val_acc += accuracy_metric(pred_labels, target)

        val_loss /= len(val_loader)
        val_dice /= len(val_loader)
        val_iou /= len(val_loader)
        val_acc /= len(val_loader)

        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}, Train IoU: {train_iou:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}, Val IoU: {val_iou:.4f}")

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Model checkpointing
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'drive/MyDrive/Colab Notebooks/DeepLabV3 SimCLR/Best_model_100.pth')
            patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch + 1} epochs")
            break

        # Log metrics to W&B
        wandb.log({
            "Epoch": epoch + 1,
            "Train Loss": train_loss,
            "Train Dice": train_dice,
            "Train IoU": train_iou,
            "Train Accuracy": train_acc,
            "Val Loss": val_loss,
            "Val Dice": val_dice,
            "Val IoU": val_iou,
            "Val Accuracy": val_acc,
            "Learning Rate": optimizer.param_groups[0]['lr']
        })

    # wandb.finish()
    return model

In [13]:
import torch
import numpy as np
import matplotlib.pyplot as plt

def visualize_results(model, data_loader, device, num_samples=5, class_colors=None):
    model.eval()
    images, targets, predictions = [], [], []

    with torch.no_grad():
        for data, target, *_ in data_loader:
            if len(images) >= num_samples:
                break
            data, target = data.float().to(device), target.long().to(device)
            output = model(data)
            pred = torch.argmax(output, dim=1)

            images.extend(data.cpu())
            targets.extend(target.cpu())
            predictions.extend(pred.cpu())

    images = images[:num_samples]
    targets = targets[:num_samples]
    predictions = predictions[:num_samples]

    fig, axs = plt.subplots(num_samples, 3, figsize=(15, 5*num_samples))

    for i in range(num_samples):
        # Original Image
        img = images[i].permute(1, 2, 0).numpy()
        if img.min() < 0 or img.max() > 1:  # Check if image needs normalization adjustment
            img = (img - img.min()) / (img.max() - img.min())  # Normalize to [0, 1]
        axs[i, 0].imshow(img)
        axs[i, 0].set_title('Original Image')
        axs[i, 0].axis('off')

        # Ground Truth
        gt_mask = targets[i].squeeze().numpy()
        if class_colors:
            gt_colored = np.zeros((gt_mask.shape[0], gt_mask.shape[1], 3), dtype=np.uint8)
            for class_idx, color in enumerate(class_colors):
                gt_colored[gt_mask == class_idx] = color
            axs[i, 1].imshow(gt_colored)
        else:
            axs[i, 1].imshow(gt_mask, cmap='gray')  # Use grayscale if no colors provided
        axs[i, 1].set_title('Ground Truth')
        axs[i, 1].axis('off')

        # Prediction
        pred_mask = predictions[i].squeeze().numpy()
        if class_colors:
            pred_colored = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
            for class_idx, color in enumerate(class_colors):
                pred_colored[pred_mask == class_idx] = color
            axs[i, 2].imshow(pred_colored)
        else:
            axs[i, 2].imshow(pred_mask, cmap='gray')  # Use grayscale if no colors provided
        axs[i, 2].set_title('Prediction')
        axs[i, 2].axis('off')

    plt.tight_layout()
    return fig


In [14]:
# Define your model, loss function, optimizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)
loss_fn = CEDiceLoss(lambda_dice=0, lambda_ce=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Train the model
trained_model = train(model, train_loader, val_loader, epochs=100, loss_fn=loss_fn, optimizer=optimizer, device=device)

# Define class colors (adjust according to your classes)
class_colors = [
    [0, 0, 0],        # Class 0 (black)
    [255, 0, 0],      # Class 1 (red)
    [0, 255, 0],      # Class 2 (green)
    [0, 0, 255],      # Class 3 (blue)
    [255, 255, 0],    # Class 4 (yellow)
    [255, 0, 255],    # Class 5 (magenta)
    [0, 255, 255],    # Class 6 (cyan)
    [192, 192, 192],  # Class 7 (silver)
    [128, 0, 0],      # Class 8 (maroon)
    [128, 128, 0],    # Class 9 (olive)
    [0, 128, 0]       # Class 10 (dark green)
]

# Visualize results after training
# fig = visualize_results(trained_model, val_loader, device, num_samples=5, class_colors=class_colors)
# plt.show()

# # Log the visualization to W&B
# wandb.init(project="Final_Research", name="Post_Training_Visualisation_DeepLabv3_ImageNet_100")
# wandb.log({"Validation Predictions DeepLabv3 ImageNet 100": wandb.Image(fig)})


Epoch 1/100 - Training: 100%|██████████| 10/10 [04:49<00:00, 28.92s/it]
Epoch 1/100 - Validation: 100%|██████████| 3/3 [01:09<00:00, 23.10s/it]


Epoch 1: Train Loss: 1.3405, Train Dice: 0.7644, Train IoU: 0.1434, Val Loss: 0.9948, Val Dice: 0.8721, Val IoU: 0.0793


Epoch 2/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]
Epoch 2/100 - Validation: 100%|██████████| 3/3 [00:03<00:00,  1.01s/it]


Epoch 2: Train Loss: 0.4701, Train Dice: 0.9175, Train IoU: 0.2297, Val Loss: 0.6540, Val Dice: 0.8721, Val IoU: 0.0793


Epoch 3/100 - Training: 100%|██████████| 10/10 [00:20<00:00,  2.02s/it]
Epoch 3/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 3: Train Loss: 0.2791, Train Dice: 0.9269, Train IoU: 0.2704, Val Loss: 0.4133, Val Dice: 0.8846, Val IoU: 0.1237


Epoch 4/100 - Training: 100%|██████████| 10/10 [00:20<00:00,  2.02s/it]
Epoch 4/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 4: Train Loss: 0.2050, Train Dice: 0.9368, Train IoU: 0.3256, Val Loss: 0.2617, Val Dice: 0.9193, Val IoU: 0.2642


Epoch 5/100 - Training: 100%|██████████| 10/10 [00:20<00:00,  2.02s/it]
Epoch 5/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.04it/s]


Epoch 5: Train Loss: 0.1726, Train Dice: 0.9433, Train IoU: 0.3941, Val Loss: 0.2276, Val Dice: 0.9234, Val IoU: 0.2919


Epoch 6/100 - Training: 100%|██████████| 10/10 [00:20<00:00,  2.01s/it]
Epoch 6/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.03it/s]


Epoch 6: Train Loss: 0.1512, Train Dice: 0.9486, Train IoU: 0.4491, Val Loss: 0.1734, Val Dice: 0.9458, Val IoU: 0.4645


Epoch 7/100 - Training: 100%|██████████| 10/10 [00:20<00:00,  2.01s/it]
Epoch 7/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 7: Train Loss: 0.1354, Train Dice: 0.9536, Train IoU: 0.4937, Val Loss: 0.1600, Val Dice: 0.9474, Val IoU: 0.4721


Epoch 8/100 - Training: 100%|██████████| 10/10 [00:20<00:00,  2.02s/it]
Epoch 8/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.01it/s]


Epoch 8: Train Loss: 0.1257, Train Dice: 0.9559, Train IoU: 0.5174, Val Loss: 0.1596, Val Dice: 0.9451, Val IoU: 0.4658


Epoch 9/100 - Training: 100%|██████████| 10/10 [00:20<00:00,  2.01s/it]
Epoch 9/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.01it/s]


Epoch 9: Train Loss: 0.1193, Train Dice: 0.9563, Train IoU: 0.5169, Val Loss: 0.1608, Val Dice: 0.9451, Val IoU: 0.4757


Epoch 10/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 10/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 10: Train Loss: 0.1141, Train Dice: 0.9586, Train IoU: 0.5509, Val Loss: 0.1492, Val Dice: 0.9499, Val IoU: 0.5147


Epoch 11/100 - Training: 100%|██████████| 10/10 [00:20<00:00,  2.01s/it]
Epoch 11/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.03it/s]


Epoch 11: Train Loss: 0.1054, Train Dice: 0.9616, Train IoU: 0.5683, Val Loss: 0.1528, Val Dice: 0.9466, Val IoU: 0.4799


Epoch 12/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]
Epoch 12/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 12: Train Loss: 0.1005, Train Dice: 0.9630, Train IoU: 0.5752, Val Loss: 0.1413, Val Dice: 0.9528, Val IoU: 0.5447


Epoch 13/100 - Training: 100%|██████████| 10/10 [00:20<00:00,  2.01s/it]
Epoch 13/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 13: Train Loss: 0.0969, Train Dice: 0.9642, Train IoU: 0.5872, Val Loss: 0.1525, Val Dice: 0.9485, Val IoU: 0.5110


Epoch 14/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 14/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 14: Train Loss: 0.0970, Train Dice: 0.9635, Train IoU: 0.5897, Val Loss: 0.1491, Val Dice: 0.9481, Val IoU: 0.4756


Epoch 15/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 15/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 15: Train Loss: 0.0918, Train Dice: 0.9652, Train IoU: 0.5995, Val Loss: 0.1548, Val Dice: 0.9467, Val IoU: 0.5046


Epoch 16/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 16/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.03it/s]


Epoch 16: Train Loss: 0.0889, Train Dice: 0.9661, Train IoU: 0.6045, Val Loss: 0.1550, Val Dice: 0.9463, Val IoU: 0.4900


Epoch 17/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 17/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.03it/s]


Epoch 17: Train Loss: 0.0851, Train Dice: 0.9677, Train IoU: 0.6226, Val Loss: 0.1582, Val Dice: 0.9421, Val IoU: 0.4361


Epoch 18/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.98s/it]
Epoch 18/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 18: Train Loss: 0.0821, Train Dice: 0.9688, Train IoU: 0.6298, Val Loss: 0.1379, Val Dice: 0.9534, Val IoU: 0.5407


Epoch 19/100 - Training: 100%|██████████| 10/10 [00:20<00:00,  2.02s/it]
Epoch 19/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 19: Train Loss: 0.0796, Train Dice: 0.9694, Train IoU: 0.6307, Val Loss: 0.1399, Val Dice: 0.9521, Val IoU: 0.5416


Epoch 20/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 20/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.01it/s]


Epoch 20: Train Loss: 0.0804, Train Dice: 0.9683, Train IoU: 0.6225, Val Loss: 0.1493, Val Dice: 0.9483, Val IoU: 0.4942


Epoch 21/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]
Epoch 21/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 21: Train Loss: 0.0789, Train Dice: 0.9688, Train IoU: 0.6235, Val Loss: 0.1564, Val Dice: 0.9502, Val IoU: 0.5180


Epoch 22/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 22/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 22: Train Loss: 0.0737, Train Dice: 0.9713, Train IoU: 0.6444, Val Loss: 0.1426, Val Dice: 0.9511, Val IoU: 0.5177


Epoch 23/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 23/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 23: Train Loss: 0.0738, Train Dice: 0.9707, Train IoU: 0.6387, Val Loss: 0.1433, Val Dice: 0.9542, Val IoU: 0.5560


Epoch 24/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 24/100 - Validation: 100%|██████████| 3/3 [00:03<00:00,  1.01s/it]


Epoch 24: Train Loss: 0.0699, Train Dice: 0.9726, Train IoU: 0.6623, Val Loss: 0.1471, Val Dice: 0.9525, Val IoU: 0.5591


Epoch 25/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]
Epoch 25/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 25: Train Loss: 0.0642, Train Dice: 0.9755, Train IoU: 0.6903, Val Loss: 0.1309, Val Dice: 0.9588, Val IoU: 0.5981


Epoch 26/100 - Training: 100%|██████████| 10/10 [00:20<00:00,  2.02s/it]
Epoch 26/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.00it/s]


Epoch 26: Train Loss: 0.0629, Train Dice: 0.9761, Train IoU: 0.6975, Val Loss: 0.1317, Val Dice: 0.9588, Val IoU: 0.5962


Epoch 27/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]
Epoch 27/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.01it/s]


Epoch 27: Train Loss: 0.0607, Train Dice: 0.9771, Train IoU: 0.6994, Val Loss: 0.1284, Val Dice: 0.9593, Val IoU: 0.6028


Epoch 28/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.98s/it]
Epoch 28/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.01it/s]


Epoch 28: Train Loss: 0.0594, Train Dice: 0.9779, Train IoU: 0.7095, Val Loss: 0.1306, Val Dice: 0.9590, Val IoU: 0.6006


Epoch 29/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]
Epoch 29/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.03it/s]


Epoch 29: Train Loss: 0.0587, Train Dice: 0.9781, Train IoU: 0.7122, Val Loss: 0.1311, Val Dice: 0.9591, Val IoU: 0.6002


Epoch 30/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 30/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.01it/s]


Epoch 30: Train Loss: 0.0580, Train Dice: 0.9783, Train IoU: 0.7083, Val Loss: 0.1308, Val Dice: 0.9593, Val IoU: 0.6041


Epoch 31/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 31/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.01it/s]


Epoch 31: Train Loss: 0.0570, Train Dice: 0.9789, Train IoU: 0.7160, Val Loss: 0.1331, Val Dice: 0.9591, Val IoU: 0.6009


Epoch 32/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]
Epoch 32/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.00it/s]


Epoch 32: Train Loss: 0.0562, Train Dice: 0.9791, Train IoU: 0.7163, Val Loss: 0.1317, Val Dice: 0.9595, Val IoU: 0.6062


Epoch 33/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 33/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.01it/s]


Epoch 33: Train Loss: 0.0558, Train Dice: 0.9794, Train IoU: 0.7205, Val Loss: 0.1329, Val Dice: 0.9592, Val IoU: 0.6020


Epoch 34/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 34/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 34: Train Loss: 0.0556, Train Dice: 0.9794, Train IoU: 0.7192, Val Loss: 0.1333, Val Dice: 0.9593, Val IoU: 0.6034


Epoch 35/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]
Epoch 35/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.03it/s]


Epoch 35: Train Loss: 0.0551, Train Dice: 0.9796, Train IoU: 0.7213, Val Loss: 0.1333, Val Dice: 0.9594, Val IoU: 0.6045


Epoch 36/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 36/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 36: Train Loss: 0.0555, Train Dice: 0.9795, Train IoU: 0.7206, Val Loss: 0.1336, Val Dice: 0.9595, Val IoU: 0.6052


Epoch 37/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]
Epoch 37/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.01it/s]


Epoch 37: Train Loss: 0.0558, Train Dice: 0.9793, Train IoU: 0.7200, Val Loss: 0.1334, Val Dice: 0.9594, Val IoU: 0.6052


Epoch 38/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]
Epoch 38/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 38: Train Loss: 0.0554, Train Dice: 0.9795, Train IoU: 0.7225, Val Loss: 0.1335, Val Dice: 0.9594, Val IoU: 0.6047


Epoch 39/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]
Epoch 39/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 39: Train Loss: 0.0557, Train Dice: 0.9792, Train IoU: 0.7178, Val Loss: 0.1337, Val Dice: 0.9593, Val IoU: 0.6050


Epoch 40/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 40/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.01it/s]


Epoch 40: Train Loss: 0.0553, Train Dice: 0.9796, Train IoU: 0.7234, Val Loss: 0.1337, Val Dice: 0.9594, Val IoU: 0.6049


Epoch 41/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 41/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.01it/s]


Epoch 41: Train Loss: 0.0557, Train Dice: 0.9793, Train IoU: 0.7160, Val Loss: 0.1340, Val Dice: 0.9594, Val IoU: 0.6050


Epoch 42/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 42/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 42: Train Loss: 0.0552, Train Dice: 0.9795, Train IoU: 0.7206, Val Loss: 0.1340, Val Dice: 0.9594, Val IoU: 0.6051


Epoch 43/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.95s/it]
Epoch 43/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 43: Train Loss: 0.0552, Train Dice: 0.9797, Train IoU: 0.7248, Val Loss: 0.1338, Val Dice: 0.9594, Val IoU: 0.6055


Epoch 44/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 44/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.02it/s]


Epoch 44: Train Loss: 0.0553, Train Dice: 0.9796, Train IoU: 0.7243, Val Loss: 0.1337, Val Dice: 0.9594, Val IoU: 0.6052


Epoch 45/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 45/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.03it/s]


Epoch 45: Train Loss: 0.0549, Train Dice: 0.9797, Train IoU: 0.7227, Val Loss: 0.1337, Val Dice: 0.9594, Val IoU: 0.6050


Epoch 46/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 46/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.03it/s]


Epoch 46: Train Loss: 0.0556, Train Dice: 0.9795, Train IoU: 0.7214, Val Loss: 0.1338, Val Dice: 0.9594, Val IoU: 0.6050


Epoch 47/100 - Training: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]
Epoch 47/100 - Validation: 100%|██████████| 3/3 [00:02<00:00,  1.00it/s]

Epoch 47: Train Loss: 0.0555, Train Dice: 0.9795, Train IoU: 0.7208, Val Loss: 0.1338, Val Dice: 0.9594, Val IoU: 0.6047
Early stopping triggered after 47 epochs





In [None]:
from tqdm import tqdm

def test(model, test_loader, loss_fn, device):
    model.eval()  # Set model to evaluation mode
    test_loss = 0.0
    test_dice = 0.0
    test_iou = 0.0
    test_acc =0.0

    dice_metric = Dice(average='micro').to(device)
    iou_metric = JaccardIndex(task="multiclass", num_classes=11).to(device)  # Adjust num_classes as needed
    accuracy_metric = Accuracy(task="multiclass", num_classes=11).to(device)
    with torch.no_grad():  # Disable gradient computation
        for data, target, *_ in tqdm(test_loader, desc="Testing"):
            data, target = data.float().to(device), target.to(device)

            # Forward pass
            output = model(data)
            loss = loss_fn(output, target)
            test_loss += loss.item()

            # Get predictions
            pred_labels = torch.argmax(output, dim=1)

            # Calculate metrics
            test_dice += dice_metric(pred_labels, target)
            test_iou += iou_metric(pred_labels, target)
            test_acc += accuracy_metric(pred_labels, target)

    # Average metrics over all batches
    test_loss /= len(test_loader)
    test_dice /= len(test_loader)
    test_iou /= len(test_loader)
    test_acc /= len(test_loader)

    print(f"Test Loss: {test_loss:.4f}, Test Dice: {test_dice:.4f}, Test IoU: {test_iou:.4f}")

    return test_loss, test_dice, test_iou, test_acc

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = DeepLabV3(classes = 11).to(device)
model.load_state_dict(torch.load('drive/MyDrive/Colab Notebooks/DeepLabV3 SimCLR/Best_model_100.pth'))
loss_fn = CEDiceLoss(lambda_dice=0, lambda_ce=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Run the test function
test_loss, test_dice, test_iou, test_acc = test(model, test_loader, loss_fn, device)

# Optionally, log test results to W&B
wandb.init(project="Final_Research", name="Test_Results_DeepLabv3_SimCLR_100")
wandb.log({
    "Test Loss": test_loss,
    "Test Dice": test_dice,
    "Test IoU": test_iou,
    "Test Accuracy": test_acc
})
wandb.finish()

# fig = visualize_results(trained_model, test_loader, device, num_samples=5, class_colors=class_colors)
# plt.show()

# # Log the visualization to W&B
# wandb.init(project="Final_Research", name="Test_Visualization_DeepLabv3_ImageNet_10")
# wandb.log({"Final Predictions: DeepLabv3 ImageNet 100": wandb.Image(fig)})
# wandb.finish()

  model.load_state_dict(torch.load('drive/MyDrive/Colab Notebooks/DeepLabV3 SimCLR/Best_model_100.pth'))
Testing:   0%|          | 0/3 [00:00<?, ?it/s]