# Import

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from torchvision.models.segmentation import deeplabv3_resnet50
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision.transforms import functional as TF
import os
from PIL import Image
import numpy as np
from PIL import Image
import os
from torch.utils.data import Dataset
import torch
from torchvision.transforms import v2

from google.colab import drive

drive.mount('/content/drive')

BASE_PATH = 'Path to dataset'

# Dataset

## Tiles dataset for model training (reconstruction)

In [None]:


class RandomApplyTransforms:
    """Apply each transform, with probability 0.5, to an image/mask pair.

    One random draw is made per transform and the outcome is shared by both
    inputs, so the pair always receives the same sequence of transforms.
    """

    def __init__(self, transforms):
        # Iterable of single-argument callables.
        self.transforms = transforms

    def __call__(self, image, mask):
        for op in self.transforms:
            # Skip this transform for both inputs when the draw is >= 0.5.
            if not (torch.rand(1) < 0.5):
                continue
            image, mask = op(image), op(mask)
        return image, mask

# Paired augmentation used by the datasets: an independent coin flip decides
# whether to apply the horizontal flip, then another for the vertical flip.
custom_transform = RandomApplyTransforms([TF.hflip, TF.vflip])

class TailsDatasetReconstruct(Dataset):
    """Unlabelled tile dataset for the reconstruction task.

    Every ``*.png`` file directly inside ``raw_data_dir`` is one sample. Items
    are returned as ``(image, image)`` so that reconstruction batches have the
    same ``(input, target)`` layout as segmentation batches.

    Args:
        raw_data_dir: directory holding the PNG tiles.
        image_transform: optional callable applied to the loaded RGB image.
    """

    def __init__(self, raw_data_dir, image_transform=None):
        self.raw_data_dir = raw_data_dir
        self.image_transform = image_transform
        # splitext keeps dots that belong to the stem ("tile.01.png" ->
        # "tile.01"); sorting makes the index -> file mapping independent of
        # the arbitrary order returned by os.listdir.
        self.filenames = sorted(
            os.path.splitext(f)[0]
            for f in os.listdir(raw_data_dir)
            if f.endswith('.png')
        )

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        raw_image_path = os.path.join(self.raw_data_dir, self.filenames[idx] + '.png')
        raw_image = Image.open(raw_image_path).convert('RGB')

        if self.image_transform is not None:
            raw_image = self.image_transform(raw_image)

        # Input and target are the same image, so both outputs of the paired
        # random flips are identical; keep one of them.
        raw_image, _ = custom_transform(raw_image, raw_image)

        return raw_image, raw_image

    def set_transform(self, raw_transform):
        """Replace the image transform used by ``__getitem__``."""
        self.image_transform = raw_transform


## Tiles dataset for model training (segmentation)

In [None]:
from PIL import Image
import os
from torch.utils.data import Dataset
from torchvision.transforms import functional as TF
import torch
from torchvision.transforms import v2

class RandomApplyTransforms:
    """Randomly apply a list of transforms jointly to an image and its mask."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, mask):
        pair = (image, mask)
        for fn in self.transforms:
            # One coin flip per transform; both elements get the same outcome.
            if torch.rand(1) < 0.5:
                pair = (fn(pair[0]), fn(pair[1]))
        return pair

# Paired flips (horizontal, then vertical), each applied with probability 0.5.
_FLIPS = [TF.hflip, TF.vflip]
custom_transform = RandomApplyTransforms(_FLIPS)

class TailsDatasetMask(Dataset):
    """Tile dataset pairing each raw image with its segmentation mask.

    Samples are enumerated from the ``*.png`` files in ``masks_dir``; the raw
    image with the same file name is read from ``raw_data_dir``.

    Args:
        raw_data_dir: directory holding the raw PNG tiles.
        masks_dir: directory holding the mask PNG tiles.
        image_transform: optional callable applied to the RGB image.
        mask_transform: optional callable applied to the single-channel mask.
    """

    def __init__(self, raw_data_dir, masks_dir, image_transform=None, mask_transform=None):
        self.raw_data_dir = raw_data_dir
        self.masks_dir = masks_dir
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        # splitext keeps dots that belong to the stem ("tile.01.png" ->
        # "tile.01"); sorting makes the index -> file mapping independent of
        # the arbitrary order returned by os.listdir.
        self.filenames = sorted(
            os.path.splitext(f)[0]
            for f in os.listdir(masks_dir)
            if f.endswith('.png')
        )

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        stem = self.filenames[idx]
        raw_image_path = os.path.join(self.raw_data_dir, stem + '.png')
        mask_path = os.path.join(self.masks_dir, stem + '.png')

        raw_image = Image.open(raw_image_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')

        if self.image_transform is not None:
            raw_image = self.image_transform(raw_image)

        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        # Random flips are applied jointly so image and mask stay aligned.
        raw_image, mask = custom_transform(raw_image, mask)

        return raw_image, mask

    def set_transform(self, raw_transform, mask_transform):
        """Replace the image and mask transforms used by ``__getitem__``."""
        self.image_transform = raw_transform
        self.mask_transform = mask_transform

    def reverse_labels_in_mask(self, mask):
        """Return ``1 - mask`` (swaps the two labels of a 0/1 mask)."""
        mask = 1 - mask
        return mask


In [None]:
from torch.utils.data import DataLoader, random_split

import copy

from torch.utils.data import Subset

# Training images: brightness/contrast jitter, then conversion to a float
# tensor scaled to [0, 1].
raw_transform = v2.Compose([
    v2.ColorJitter(brightness=0.5, contrast=0.5),
    v2.ToImage(), v2.ToDtype(torch.float32, scale=True)
])

# Evaluation images: conversion only, no photometric augmentation.
test_raw_transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])

mask_transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])


# Dataset used for training (with the augmenting image transform)
dataset_m = TailsDatasetMask(raw_data_dir=BASE_PATH+'Tiles/raw',
                             masks_dir=BASE_PATH+'Tiles/mask',
                             image_transform=raw_transform, mask_transform=mask_transform)

# A shallow copy shares the file list (identical index -> file mapping) but
# carries its own transforms, so held-out samples are not colour-jittered.
# NOTE: the random flips applied inside __getitem__ still affect every split.
eval_dataset_m = copy.copy(dataset_m)
eval_dataset_m.set_transform(test_raw_transform, mask_transform)

# Split the dataset into training, validation, and test sets.
# NOTE(review): the test split only receives the rounding remainder of the
# 90/10 split, i.e. it is empty or nearly empty.
train_size = int(0.9 * len(dataset_m))
val_size = int(0.1 * len(dataset_m))
test_size = len(dataset_m) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset_m, [train_size, val_size, test_size])

# Re-point the held-out index sets at the non-augmenting copy.
val_dataset = Subset(eval_dataset_m, val_dataset.indices)
test_dataset = Subset(eval_dataset_m, test_dataset.indices)


# Create DataLoader objects for each dataset
train_loader_mask = DataLoader(train_dataset, batch_size=30, shuffle=True, num_workers=6)
val_loader_mask = DataLoader(val_dataset, batch_size=10, shuffle=False)
test_loader_mask = DataLoader(test_dataset, batch_size=15, shuffle=False)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import v2



# Function to calculate class weights
def calculate_class_weights(loader, num_classes=2):
    """Return inverse-frequency class weights over all masks in ``loader``.

    Args:
        loader: iterable of ``(inputs, masks)`` batches; only ``masks`` is
            read. Mask elements are compared for equality with the class
            indices ``0 .. num_classes - 1``; other values are not counted.
        num_classes: number of classes to count.

    Returns:
        1-D float tensor of length ``num_classes`` with
        ``total / (num_classes * count[c])`` for each class ``c``. A class
        that never occurs yields ``inf`` (``nan`` if nothing was counted).
    """
    class_counts = torch.zeros(num_classes)

    for _, masks in loader:
        flat = masks.reshape(-1)
        for cls in range(num_classes):
            class_counts[cls] += (flat == cls).sum().item()

    total_count = class_counts.sum()
    return total_count / (num_classes * class_counts)

# Inverse-frequency class weights, measured by iterating the training loader once.
class_weights = calculate_class_weights(train_loader_mask)
print(f'Class Weights: {class_weights}')

# Segmentation loss on raw logits; only the class-1 weight is used, as pos_weight.
criterion_seg = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])


In [None]:
# Reconstruction images: brightness/contrast jitter, then conversion to a
# float tensor scaled to [0, 1].
raw_transform_recon = v2.Compose([
    v2.ColorJitter(brightness=0.5, contrast=0.5),
    v2.ToImage(), v2.ToDtype(torch.float32, scale=True)
])

# Unlabelled tiles used for the reconstruction task
dataset_rec = TailsDatasetReconstruct(raw_data_dir=BASE_PATH+'Tiles/Semi-supervised/raw',
                                      image_transform=raw_transform_recon)

# Split the dataset into training and validation sets.
# Validation takes the remainder: int(0.9 * n) + int(0.1 * n) can be n - 1,
# and random_split raises when the lengths do not sum to the dataset size.
train_size = int(0.9 * len(dataset_rec))
val_size = len(dataset_rec) - train_size

train_dataset_rec, val_dataset_rec = random_split(dataset_rec, [train_size, val_size])

# Create DataLoader objects for each dataset
train_loader_rec = DataLoader(train_dataset_rec, batch_size=30, shuffle=True, num_workers=4)
val_loader_rec = DataLoader(val_dataset_rec, batch_size=10, shuffle=False)

# Models

## Original Unet

In [None]:
""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by BatchNorm and ReLU.

    ``mid_channels`` is the width between the two convolutions and defaults
    to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        layers = []
        # Stage 1: in -> hidden, stage 2: hidden -> out.
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        out = self.maxpool_conv(x)
        return out


class Up(nn.Module):
    """Decoder stage: upsample ``x1``, pad it to ``x2``'s size, concat, double conv."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2

        if not bilinear:
            # Learned upsampling that also halves the channel count.
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Interpolation has no weights; the conv block narrows the channels.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # Tensors are indexed as (N, C, H, W): dim 2 is height, dim 3 is width.
        gap_h = x2.shape[2] - x1.shape[2]
        gap_w = x2.shape[3] - x1.shape[3]
        left, top = gap_w // 2, gap_h // 2

        # F.pad order for the last two dims: (left, right, top, bottom).
        x1 = F.pad(x1, [left, gap_w - left, top, gap_h - top])

        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping features to the output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected


In [None]:
""" Full assembly of the parts to form the complete network """

# from unet_parts import *


class UNet(nn.Module):
    """U-Net: four encoder stages, four decoder stages with skip connections.

    ``forward`` returns the raw output of the final 1x1 convolution.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # With bilinear upsampling the deeper widths are halved.
        factor = 2 if bilinear else 1

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder pass; remember every stage's output for the skip connections.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        # Decoder pass; consume the stored outputs from deepest to shallowest.
        x = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, skips.pop())

        return self.outc(x)

## Modified Unet model with 2 decoders

In [None]:
""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(3x3 convolution => BatchNorm => ReLU), applied twice.

    The intermediate width is ``mid_channels`` when given, else ``out_channels``.
    """

    @staticmethod
    def _conv_bn_relu(c_in, c_out):
        # One "conv => BN => ReLU" unit as a list of layers.
        return [
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
        ]

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        width = mid_channels if mid_channels else out_channels
        first = self._conv_bn_relu(in_channels, width)
        second = self._conv_bn_relu(width, out_channels)
        self.double_conv = nn.Sequential(*first, *second)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Halve the spatial size with max-pooling, then apply a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)]
        self.maxpool_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv, with an optional skip connection.

    Args:
        in_channels: channel count of the concatenated ``[skip, upsampled]``
            tensor, i.e. twice the channels of the upsampled feature map.
        out_channels: channels produced by the stage.
        bilinear: use parameter-free bilinear upsampling (``x1`` is then
            expected to already carry ``in_channels // 2`` channels) instead
            of a transposed convolution that halves ``in_channels``.
        use_skip: concatenate the encoder feature map ``x2`` before the
            convolutions. When False, ``x2`` is ignored.
    """

    def __init__(self, in_channels, out_channels, bilinear=True, use_skip=True):
        super().__init__()
        self.use_skip = use_skip
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)

        # After upsampling, x1 has in_channels // 2 channels in both modes;
        # the skip connection contributes the other half.
        conv_in = in_channels if use_skip else in_channels // 2
        self.conv = DoubleConv(conv_in, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        if self.use_skip and x2 is not None:
            # Pad x1 so its height/width match the skip tensor before concat.
            diffY = x2.size()[2] - x1.size()[2]
            diffX = x2.size()[3] - x1.size()[3]
            x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
            x = torch.cat([x2, x1], dim=1)
        else:
            x = x1  # Without skip connection
        return self.conv(x)


class OutConv(nn.Module):
    """1x1 convolution used as the output head."""

    def __init__(self, in_channels, out_channels):
        nn.Module.__init__(self)
        head = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv = head

    def forward(self, x):
        return self.conv(x)

In [None]:
class UNetWithDualDecoders(nn.Module):
    """U-Net encoder shared by a segmentation decoder and a reconstruction decoder.

    The segmentation decoder uses skip connections and returns raw logits
    with ``n_classes`` channels. The reconstruction decoder uses no skip
    connections and returns a sigmoid-activated image with ``n_channels``
    channels.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNetWithDualDecoders, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Shared Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

        # Decoder for Segmentation
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc_seg = OutConv(64, n_classes)

        # Decoder for Reconstruction
        self.recon_up1 = Up(1024, 512 // factor, bilinear, use_skip=False)
        self.recon_up2 = Up(512, 256 // factor, bilinear, use_skip=False)
        self.recon_up3 = Up(256, 128 // factor, bilinear, use_skip=False)
        self.recon_up4 = Up(128, 64, bilinear, use_skip=False)
        self.outc_recon = OutConv(64, n_channels)

    def _encode(self, x):
        """Run the shared encoder and return the five feature maps, shallow to deep."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        return x1, x2, x3, x4, x5

    def forward(self, x, train_recon=False, freeze_encoder=False):
        """Run the encoder and one of the two decoders.

        Args:
            x: input batch with ``n_channels`` channels.
            train_recon: if True return the reconstruction, otherwise the
                segmentation logits.
            freeze_encoder: if True the encoder runs under ``torch.no_grad()``
                so no gradients reach its weights; the selected decoder is
                still differentiable. In training mode the encoder's
                BatchNorm running statistics are still updated.
        """
        if freeze_encoder:
            with torch.no_grad():
                x1, x2, x3, x4, x5 = self._encode(x)
        else:
            x1, x2, x3, x4, x5 = self._encode(x)

        if train_recon:
            # Decoder for Reconstruction (no skip connections)
            x_recon = self.recon_up1(x5, None)
            x_recon = self.recon_up2(x_recon, None)
            x_recon = self.recon_up3(x_recon, None)
            x_recon = self.recon_up4(x_recon, None)
            # Sigmoid keeps the reconstruction in [0, 1], like the scaled inputs.
            return torch.sigmoid(self.outc_recon(x_recon))

        # Decoder for Segmentation; returns raw logits (no sigmoid here).
        x_seg = self.up1(x5, x4)
        x_seg = self.up2(x_seg, x3)
        x_seg = self.up3(x_seg, x2)
        x_seg = self.up4(x_seg, x1)
        return self.outc_seg(x_seg)


# Model setup

In [None]:
# print(class_weights)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, deeplabv3_mobilenet_v3_large, deeplabv3_resnet101, FCN_ResNet50_Weights, DeepLabV3_ResNet50_Weights
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision.transforms import functional as TF
import os
from PIL import Image
import numpy as np

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Loss functions:
#  - segmentation: BCE on raw logits, positive class weighted by class_weights[1]
#    (unweighted alternative: nn.BCEWithLogitsLoss())
criterion_seg = nn.BCEWithLogitsLoss(pos_weight=class_weights[1])
#  - reconstruction: mean absolute error
criterion_recon = nn.L1Loss()

# Use the GPU when available; the dual-decoder model takes 3-channel input and emits 1 segmentation channel.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetWithDualDecoders(n_channels=3, n_classes=1).to(device)


# A single Adam optimizer over all parameters (encoder and both decoders).
optimizer = optim.Adam(model.parameters(), lr=0.0001)

## Train

In [None]:
from tqdm import tqdm
import torch
from tqdm.auto import tqdm

def train_combined(model, train_loader_mask, train_loader_rec, optimizer, device, lambda_recon=0.5):
    """Train one epoch on segmentation and reconstruction batches jointly.

    Each step draws one batch from each loader (while available), sums
    ``criterion_seg`` and ``lambda_recon * criterion_recon`` and takes one
    optimizer step. The epoch length is ``len(train_loader_rec)``; once the
    mask loader is exhausted the remaining steps train reconstruction only.

    Args:
        model: network called as ``model(x, train_recon=..., freeze_encoder=False)``.
        train_loader_mask: yields ``(image, mask)`` batches.
        train_loader_rec: yields ``(image, _)`` batches; the image is its own target.
        optimizer: optimizer over the model parameters.
        device: device the batches are moved to.
        lambda_recon: weight of the reconstruction loss.

    Returns:
        Sum of the step losses divided by ``len(train_loader_rec)``
        (0.0 when that loader is empty).
    """
    model.train()
    running_loss = 0.0

    # Create iterators for both DataLoaders
    iter_mask = iter(train_loader_mask)
    iter_rec = iter(train_loader_rec)

    # The reconstruction loader defines the epoch length (assumed to be the larger one).
    total_batches = len(train_loader_rec)

    # Context manager closes the bar even if a step raises.
    with tqdm(total=total_batches, leave=True) as pbar:
        for _ in range(total_batches):
            mask_batch = next(iter_mask, None)
            rec_batch = next(iter_rec, None)
            if mask_batch is None and rec_batch is None:
                # Nothing left to train on; avoids calling backward() on a plain number.
                break

            optimizer.zero_grad()
            loss = 0

            if mask_batch is not None:
                data_mask, target_mask = mask_batch
                data_mask, target_mask = data_mask.to(device), target_mask.to(device)
                output_mask = model(data_mask, train_recon=False, freeze_encoder=False)
                loss = loss + criterion_seg(output_mask, target_mask)

            if rec_batch is not None:
                data_rec = rec_batch[0].to(device)
                output_rec = model(data_rec, train_recon=True, freeze_encoder=False)
                loss = loss + lambda_recon * criterion_recon(output_rec, data_rec)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.update(1)
            # pbar.n already counts this step, so it is the number of steps averaged.
            pbar.set_description(f'Epoch Loss: {running_loss/pbar.n:.4f}')

    epoch_loss = running_loss / total_batches if total_batches else 0.0
    print(f'Training Loss: {epoch_loss:.4f}')
    return epoch_loss

def train(model, dataloader, optimizer, device, train_recon=False, freeze_encoder=False):
    """Train ``model`` for one epoch on a single task and return the mean batch loss.

    With ``train_recon`` the reconstruction output is scored with
    ``criterion_recon``; otherwise the segmentation output is scored with
    ``criterion_seg``. ``freeze_encoder`` is forwarded to the model.
    """
    model.train()
    n_batches = len(dataloader)
    loss_sum = 0.0
    progress = tqdm(enumerate(dataloader), total=n_batches, leave=False)
    for step, (data, target) in progress:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        prediction = model(data, train_recon=bool(train_recon), freeze_encoder=freeze_encoder)
        # Pick the loss that matches the decoder that produced the prediction.
        loss = (criterion_recon(prediction, target) if train_recon
                else criterion_seg(prediction, target))

        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        progress.set_description(f'Epoch Loss: {loss_sum/(step+1):.4f}')

    mean_loss = loss_sum / n_batches
    print(f'Training Loss: {mean_loss:.4f}')
    return mean_loss


## Test

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def accuracy(output, target):
    """Return the fraction of elements whose rounded ``output`` equals ``target``."""
    hits = torch.eq(output.round(), target).float()
    return hits.sum() / hits.numel()

def test(model, dataloader, device, test_recon=False, num_visualizations=3):
    """Evaluate ``model`` on ``dataloader`` and plot a few sample predictions.

    Args:
        model: network called as ``model(x, train_recon=..., freeze_encoder=False)``.
        dataloader: yields ``(data, target)`` batches.
        device: device the batches are moved to.
        test_recon: if True evaluate reconstruction (the input is its own
            target, loss is ``criterion_recon``); otherwise evaluate
            segmentation with ``criterion_seg`` and also report accuracy.
        num_visualizations: number of leading batches whose first sample is plotted.

    Returns:
        Mean per-batch loss.
    """
    model.eval()
    running_loss = 0.0
    total_accuracy = 0.0 if not test_recon else None
    visualized = 0
    with torch.no_grad(), tqdm(enumerate(dataloader), total=len(dataloader), leave=False) as pbar:
        for batch_idx, (data, target) in pbar:
            data = data.to(device)
            if test_recon:
                target = data
            else:
                target = target.to(device)

            output = model(data, train_recon=test_recon, freeze_encoder=False)

            loss = criterion_recon(output, target) if test_recon else criterion_seg(output, target)
            running_loss += loss.item()

            if not test_recon:
                # The segmentation head returns raw logits, while accuracy()
                # rounds its input: convert to probabilities first.
                acc = accuracy(torch.sigmoid(output), target)
                total_accuracy += acc.item()

            # Visualize the first sample of the first `num_visualizations` batches
            if visualized < num_visualizations:
                fig, ax = plt.subplots(1, 2, figsize=(12, 4))
                if test_recon:
                    ax[0].imshow(data[0].cpu().permute(1, 2, 0))
                    ax[0].set_title("Input Image")
                    ax[1].imshow(output[0].cpu().permute(1, 2, 0))
                    ax[1].set_title("Reconstructed Image")
                else:
                    # Plot a strict (0.9) probability threshold; note the
                    # accuracy above effectively thresholds at 0.5.
                    preds = torch.sigmoid(output[0].cpu().squeeze()) > 0.9
                    ax[0].imshow(data[0].cpu().numpy().transpose(1, 2, 0))
                    ax[0].set_title("Input Image")
                    ax[1].imshow(preds, cmap='gray')
                    ax[1].set_title("Predicted Mask")
                plt.show()
                visualized += 1

            pbar.set_description(f'Batch {batch_idx+1}, Loss: {loss.item():.4f}')

    epoch_loss = running_loss / len(dataloader)
    print(f'Test Loss: {epoch_loss:.4f}')
    if not test_recon:
        epoch_accuracy = total_accuracy / len(dataloader)
        print(f'Test Accuracy: {epoch_accuracy:.4f}')

    return epoch_loss



## Validate

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import torch
from skimage.morphology import skeletonize

def calculate_accuracy(output, target):
    """Element-wise accuracy of logits thresholded at a sigmoid probability of 0.9."""
    binary = torch.sigmoid(output).gt(0.9)
    # Comparing the boolean prediction with the target yields a 0/1 hit map.
    matches = binary.eq(target).float()
    return matches.sum() / matches.numel()

def calculate_precision_recall_f1(output, target, epsilon=1e-7):
    """Return ``(precision, recall, f1)`` for logits thresholded at probability 0.9.

    ``epsilon`` is added to every denominator to avoid division by zero.
    """
    preds = torch.sigmoid(output) > 0.9

    hit_count = (preds * target).sum()   # true positives
    pred_count = preds.sum()             # predicted positives
    true_count = target.sum()            # actual positives

    precision = hit_count / (pred_count + epsilon)
    recall = hit_count / (true_count + epsilon)
    f1 = 2 * (precision * recall) / (precision + recall + epsilon)
    return precision, recall, f1

def calculate_iou(output, target, epsilon=1e-7):
    """Intersection-over-union of ``target`` and logits thresholded at probability 0.9."""
    preds = torch.sigmoid(output) > 0.9
    overlap = (preds * target).sum()
    combined = preds.sum() + target.sum() - overlap
    # epsilon keeps the ratio finite when both masks are empty.
    return overlap / (combined + epsilon)

In [None]:
def validation(model, device, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` and print per-batch-averaged metrics.

    The model is called with its default arguments and its output is treated
    as logits by the metric helpers.

    Args:
        model: network to evaluate.
        device: device the batches are moved to.
        val_loader: yields ``(data, target)`` batches.
        criterion: loss called as ``criterion(output, target)``.

    Returns:
        ``(avg_loss, avg_accuracy, avg_iou, avg_f1)``, each the mean of the
        per-batch values.

    Raises:
        ValueError: if ``val_loader`` has no batches.
    """
    num_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError('val_loader is empty; cannot average validation metrics')

    model.eval()
    totals = dict.fromkeys(('loss', 'accuracy', 'precision', 'recall', 'f1', 'iou'), 0.0)
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)

            precision, recall, f1 = calculate_precision_recall_f1(output, target)
            totals['loss'] += criterion(output, target).item()
            totals['accuracy'] += calculate_accuracy(output, target).item()
            totals['precision'] += precision.item()
            totals['recall'] += recall.item()
            totals['f1'] += f1.item()
            totals['iou'] += calculate_iou(output, target).item()

    # Averages are over batches, not over individual samples/pixels.
    avg_loss = totals['loss'] / num_batches
    avg_accuracy = totals['accuracy'] / num_batches
    avg_precision = totals['precision'] / num_batches
    avg_recall = totals['recall'] / num_batches
    avg_f1 = totals['f1'] / num_batches
    avg_iou = totals['iou'] / num_batches

    print(f'Validation Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.4f}, Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f}, F1: {avg_f1:.4f}, IoU: {avg_iou:.4f}')

    return avg_loss, avg_accuracy, avg_iou, avg_f1

In [None]:
# Number of batches per epoch in the reconstruction training loader.
len(train_loader_rec)

In [None]:
# Number of batches in the reconstruction validation loader.
len(val_loader_rec)

# Train the model

In [None]:
import torch
import gc

def cleanup_gpu():
    """Run the garbage collector, then release PyTorch's cached CUDA memory."""
    # Order matters: collect Python objects first so their memory can be released.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()


cleanup_gpu()


In [None]:
from IPython.display import clear_output
# Function to evaluate and visualize predicted images
def visualize_reconstruction(model, data_loader, num_images=5):
    """Plot inputs from the first batch of ``data_loader`` above the model outputs.

    Args:
        model: network called as ``model(images)``.
        data_loader: yields ``(images, _)`` batches; only the first batch is used.
        num_images: maximum number of samples to show (clamped to the batch size).
    """
    model.eval()
    with torch.no_grad():
        images, _ = next(iter(data_loader))
        images = images.to(device)
        # NOTE(review): the model is called with its defaults. For
        # UNetWithDualDecoders that selects the segmentation head; pass
        # train_recon=True there if the reconstruction output is intended.
        outputs = model(images)

        images = images.cpu().numpy()
        outputs = outputs.cpu().numpy()

        # A batch can hold fewer samples than requested.
        shown = min(num_images, len(images))

        plt.figure(figsize=(20, 8))
        for i in range(shown):
            # Top row: original images
            plt.subplot(2, shown, i + 1)
            plt.imshow(images[i].transpose(1, 2, 0).squeeze(), cmap='gray')
            plt.title("Original")
            plt.axis('off')

            # Bottom row: model outputs
            plt.subplot(2, shown, i + 1 + shown)
            plt.imshow(outputs[i].transpose(1, 2, 0).squeeze(), cmap='gray')
            plt.title("Reconstructed")
            plt.axis('off')

        plt.show()

# Function to plot the training and validation losses in real-time
def plot_losses(epoch_losses, val_losses):
    """Clear the cell output and redraw the training/validation loss curves."""
    clear_output(wait=True)
    plt.figure(figsize=(10, 5))
    curves = (
        (epoch_losses, 'Training Loss'),
        (val_losses, 'Validation Loss'),
    )
    for series, name in curves:
        plt.plot(series, label=name)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss over Epochs')
    plt.legend()
    plt.show()

In [None]:
import torch


# Early-stopping state
patience = 5  # Number of epochs to wait for improvement before stopping
best_val_loss = float('inf')
patience_counter = 0
train_losess = []
val_losess = []

num_epochs = 100  # Adjust as needed
for epoch in range(1, num_epochs + 1):
    print(f"Epoch {epoch}/{num_epochs}")

    # Keep the returned epoch loss: it is appended to the loss history below.
    train_loss = train(model, train_loader_rec, optimizer, device, train_recon=True, freeze_encoder=False)

    # Alternative training modes:
    # train_loss = train(model, train_loader_mask, optimizer, device, train_recon=False, freeze_encoder=False)
    # train_loss = train_combined(model, train_loader_mask, train_loader_rec, optimizer, device, lambda_recon=1)

    # NOTE(review): training above optimizes reconstruction, while early
    # stopping monitors the segmentation validation loss. Use
    # `val_loss = test(model, val_loader_rec, device, test_recon=True)` to
    # monitor the task being trained.
    val_loss, avg_accuracy, avg_iou, avg_f1 = validation(model, device, val_loader_mask, criterion_seg)

    print(val_loss)
    test(model, val_loader_mask, device, test_recon=False)

    train_losess.append(train_loss)
    val_losess.append(val_loss)
    plot_losses(train_losess, val_losess)

    # Check for improvement
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0  # Reset patience counter if improvement is seen
        # Save the best model
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        patience_counter += 1  # Increment patience counter if no improvement
        print(patience_counter)

    print(f'Validation Loss: {val_loss:.4f}, Accuracy: {avg_accuracy:.4f}, IoU: {avg_iou:.4f}, F1: {avg_f1:.4f}')

    # Check if patience is exceeded
    if patience_counter >= patience:
        print(f'Early stopping triggered. No improvement in validation loss for {patience} consecutive epochs.')
        break

# Load the best model after training; map_location lets a checkpoint saved
# on GPU be restored on whatever device is in use now.
model.load_state_dict(torch.load('best_model.pt', map_location=device))


In [None]:
# Validation loss of the last epoch that ran in the loop above.
print(val_loss)

In [None]:
# Segmentation-only training: one pass over the mask loader per epoch,
# followed by validation metrics and sample plots. No early stopping here.
num_epochs = 50 # Adjust as needed
for epoch in range(1, num_epochs + 1):
    train(model, train_loader_mask, optimizer, device, train_recon=False, freeze_encoder=False)
    # Alternative: joint training on both tasks
    # train_combined(model, train_loader_mask, train_loader_rec, optimizer, device, lambda_recon=0.5)
    validation(model, device, val_loader_mask, criterion_seg)
    # Alternative: also check reconstruction quality
    # test(model, val_loader_rec, device, test_recon=True)
    test(model, val_loader_mask, device, test_recon=False)
    # Progress marker: index of the epoch that just finished
    print(epoch)

In [None]:


# Reconstruction-only training: one pass over the unlabelled loader per
# epoch, followed by a reconstruction test on its validation split.
num_epochs = 100 # Adjust as needed
for epoch in range(1, num_epochs + 1):

    # Progress marker: index of the epoch about to run
    print(epoch)
    train(model, train_loader_rec, optimizer, device, train_recon=True, freeze_encoder=False)
    # Reconstruction loss and sample plots on the validation split
    test(model, val_loader_rec, device, test_recon=True)

    # Alternative: train segmentation instead
    # train(model, train_loader_mask, optimizer, device, train_recon=False, freeze_encoder=False)

    # Alternative: test segmentation instead
    # test(model, val_loader_mask, device, test_recon=False)

# Evaluate the model

In [None]:
# Load the best model after training; map_location lets a checkpoint saved
# on GPU be restored on whatever device is in use now.
model.load_state_dict(torch.load('best_model.pt', map_location=device))

In [None]:
# Segmentation metrics (loss, accuracy, precision, recall, F1, IoU) on the validation split.
validation(model, device, val_loader_mask, criterion_seg)

In [None]:
# Segmentation check on the validation split: prints loss and accuracy and plots a few predicted masks.
test(model, val_loader_mask, device, test_recon=False)

In [None]:
# Reconstruction check on the validation split: prints the loss and plots a few input/output pairs.
test(model, val_loader_rec, device, test_recon=True)

# Save the model

In [None]:
import torch

def save_checkpoint(model, optimizer, epoch, filename):
    """Save model and optimizer state to ``<BASE_PATH>/models/<filename>``.

    Args:
        model: module whose ``state_dict`` is stored.
        optimizer: optimizer whose ``state_dict`` is stored.
        epoch: epoch number recorded in the checkpoint.
        filename: file name inside the ``models`` directory.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    models_dir = os.path.join(BASE_PATH, "models")
    # torch.save does not create missing directories.
    os.makedirs(models_dir, exist_ok=True)
    torch.save(checkpoint, os.path.join(models_dir, filename))
    print(f"Checkpoint saved to {filename}")


def load_checkpoint(model, optimizer, filename, map_location=None):
    """Restore a checkpoint written by ``save_checkpoint``.

    Args:
        model: Object whose ``load_state_dict`` receives the stored
            ``'model_state_dict'``.
        optimizer: Object whose ``load_state_dict`` receives the stored
            ``'optimizer_state_dict'``. Pass ``None`` to restore only the
            model (e.g. for inference).
        filename: File name, relative to the ``models`` folder under
            ``BASE_PATH``.
        map_location: Forwarded to ``torch.load``. The default ``None`` keeps
            the previous behaviour; pass e.g. ``'cpu'`` to open a checkpoint
            saved on a GPU from a CPU-only runtime.

    Returns:
        The value stored under ``'epoch'`` in the checkpoint.
    """
    checkpoint = torch.load(BASE_PATH + "/models/" + filename, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    print(f"Model and optimizer loaded from checkpoint at epoch {epoch}")
    return epoch

# Predict the full image

In [None]:

import torch
import numpy as np
from PIL import Image
from torchvision.transforms import functional as TF
from skimage.morphology import skeletonize

import matplotlib.pyplot as plt

def _tile_padding(size, tile_size, overlap):
    """Return the right/bottom padding that lets the tiles cover ``size`` exactly."""
    stride = tile_size - overlap
    # Tiles start every `stride` pixels and are `tile_size` long, so the padded
    # extent must be overlap + n * stride. Forcing n >= 1 guarantees at least
    # one tile even when the image is not larger than `overlap`.
    n_tiles = max(1, -(-(size - overlap) // stride))  # ceil division
    return overlap + n_tiles * stride - size


def predict_mask_full_image(model, image, tile_size=256, overlap=64, device='cuda', threshold=0.9):
    """Predict a mask for a whole image by running ``model`` on overlapping tiles.

    The image is reflect-padded on the right and bottom so that square tiles of
    ``tile_size`` pixels, placed every ``tile_size - overlap`` pixels, cover it
    exactly. Each tile is passed through the model, the sigmoid of the output
    is binarized at ``threshold`` and inverted (``1 - mask``), and the per-tile
    results are averaged where tiles overlap.

    Side effects: puts the model in eval mode, prints size/tile information,
    shows two matplotlib figures and writes the second one to ``final.png``.

    Args:
        model: Callable returning a mapping with an ``"out"`` entry.
            NOTE(review): the code assumes ``"out"`` squeezes to a 2-D
            (height, width) tensor, i.e. a single output channel - confirm.
        image: PIL image to segment.
        tile_size: Side length of the square tiles, in pixels.
        overlap: Number of pixels shared by neighbouring tiles.
        device: Device the tile tensors are moved to before inference.
        threshold: Sigmoid value above which a pixel counts as positive
            before the inversion.

    Returns:
        A float NumPy array of shape ``(height, width)`` holding, per pixel,
        the mean of the inverted binary tile predictions covering it.

    Raises:
        ValueError: If ``tile_size`` is not positive or ``overlap`` is not in
            ``[0, tile_size)``.
    """
    if tile_size <= 0 or not 0 <= overlap < tile_size:
        raise ValueError(
            f"Expected tile_size > 0 and 0 <= overlap < tile_size, "
            f"got tile_size={tile_size}, overlap={overlap}"
        )
    stride = tile_size - overlap

    model.eval()

    # Calculate necessary padding to make the image divisible into tiles
    width, height = image.size
    pad_height = _tile_padding(height, tile_size, overlap)
    pad_width = _tile_padding(width, tile_size, overlap)
    print(f"Original image size: {width}x{height}, Padding: {pad_width}x{pad_height}")

    # Pad image (left, top, right, bottom): only the right and bottom edges grow
    padded_image = TF.pad(image, padding=(0, 0, pad_width, pad_height), padding_mode='reflect')
    print(f"Padded image size: {padded_image.width}x{padded_image.height}")

    full_mask = np.zeros((padded_image.height, padded_image.width))
    count_map = np.zeros((padded_image.height, padded_image.width))

    plt.imshow(padded_image)
    plt.title("Padded Image")
    plt.show()

    tile_count = 0

    # Generate and process tiles; the padding above guarantees that every tile
    # lies fully inside the padded image.
    with torch.no_grad():
        for y in range(0, padded_image.height - overlap, stride):
            for x in range(0, padded_image.width - overlap, stride):
                tile_count += 1

                tile = padded_image.crop((x, y, x + tile_size, y + tile_size))

                # Convert to tensor and add batch dimension
                tile_tensor = TF.to_tensor(tile).unsqueeze(0).to(device)

                # Predict, binarize and invert the mask for this tile
                tile_logits = model(tile_tensor)["out"].squeeze().cpu()
                tile_mask = 1 - (torch.sigmoid(tile_logits).numpy() > threshold)

                # Update full mask and count map
                full_mask[y:y + tile_size, x:x + tile_size] += tile_mask
                count_map[y:y + tile_size, x:x + tile_size] += 1

    # Average the overlaps
    full_mask /= count_map

    # Crop out any extra padding added to the image
    final_mask = full_mask[:height, :width]

    plt.imshow(final_mask, cmap='gray')
    plt.title("Final Mask")
    plt.savefig("final.png")
    plt.show()

    print(f"Processed {tile_count} tiles.")
    return final_mask

# Run the full-image prediction on one raw TIFF and save the mask next to it.
raw_data_dir = os.path.join(BASE_PATH, 'Raw_data')
# splitext keeps file names that contain extra dots intact.
filenames = [os.path.splitext(f)[0] for f in os.listdir(raw_data_dir) if f.endswith('.tif')]
# NOTE(review): os.listdir returns entries in arbitrary order, so index 60 is
# not guaranteed to select the same file on every machine.
raw_image_path = os.path.join(raw_data_dir, filenames[60] + '.tif')
raw_image = Image.open(raw_image_path).convert('RGB')

final_mask = predict_mask_full_image(model, raw_image, device='cuda')
# The image returned by convert() has no `filename` attribute, so the output
# path is derived from raw_image_path instead.
mask_path = os.path.splitext(raw_image_path)[0] + '_mask.png'
Image.fromarray((final_mask * 255).astype(np.uint8)).save(mask_path)
print(final_mask.shape)

In [None]:
from PIL import Image, ImageEnhance, ImageChops
import numpy as np
import matplotlib.pyplot as plt

def overlay_mask_on_image(raw_image, mask, alpha=10):
    """Blend a grayscale version of ``mask`` over ``raw_image``.

    The mask is resized to the raw image size when the sizes differ, converted
    to mode ``'L'``, and composited over the RGBA version of ``raw_image``
    with a constant opacity.

    Side effect: the single-channel mask is written to
    ``high_contrast_mask.png`` in the current working directory.

    Args:
        raw_image: PIL image used as the background.
        mask: PIL image drawn on top of it.
        alpha: Opacity of the mask layer, from 0 (invisible) to 255 (opaque).

    Returns:
        The composited RGBA PIL image.

    Raises:
        ValueError: If either image is not a ``PIL.Image.Image`` or ``alpha``
            is outside 0..255.
    """
    if not isinstance(raw_image, Image.Image):
        raise ValueError("raw_image must be a PIL.Image.Image object")
    if not isinstance(mask, Image.Image):
        raise ValueError("mask must be a PIL.Image.Image object")
    if not 0 <= alpha <= 255:
        raise ValueError("alpha must be between 0 and 255")

    # Resize mask to match the raw image size if necessary
    if raw_image.size != mask.size:
        mask = mask.resize(raw_image.size, Image.BILINEAR)

    # Convert the mask to 'L' mode if it's not already
    single_channel_image = mask.convert('L')
    single_channel_image.save("high_contrast_mask.png")

    # Build the RGBA layer: the mask in all three colour bands plus a uniform
    # alpha band that controls transparency
    alpha_band = Image.new('L', single_channel_image.size, alpha)
    single_channel_rgba = Image.merge(
        'RGBA',
        (single_channel_image, single_channel_image, single_channel_image, alpha_band),
    )

    # Composite the single-channel image onto the RGB image
    return Image.alpha_composite(raw_image.convert('RGBA'), single_channel_rgba)

# Scale the averaged mask to 8 bits once, then use that image both for the
# file on disk and for the overlay.
mask_image = Image.fromarray((final_mask * 255).astype(np.uint8))
mask_image.save("final.png")
overlay_image = overlay_mask_on_image(raw_image, mask_image)

# Show the overlay on a large canvas without axes
plt.figure(figsize=(30, 30))
plt.imshow(overlay_image)
plt.axis('off')
plt.show()


In [None]:
import matplotlib.pyplot as plt
import cv2
import numpy as np
from skimage.morphology import skeletonize

# Load the uploaded image as a single-channel array
image_path = 'high_contrast_mask-2.png'
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
if image is None:
    # cv2.imread reports a missing or unreadable file by returning None rather
    # than raising, which would otherwise fail later with an obscure TypeError.
    raise FileNotFoundError(f"Could not read image: {image_path}")

# Apply skeletonization using skimage; pixels brighter than 200 are foreground
binary_image = image > 200
skeleton = skeletonize(binary_image)

# Save the skeletonized image
skeletonized_image_path = 'skeletonized_image.png'
plt.imsave(skeletonized_image_path, skeleton, cmap='gray')

# Display the original and skeletonized images side by side
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(image, cmap='gray')
ax[0].set_title('Original Image')
ax[0].axis('off')

ax[1].imshow(skeleton, cmap='gray')
ax[1].set_title('Skeletonized Image')
ax[1].axis('off')

plt.show()

# Last expression of the cell: the notebook displays the output path
skeletonized_image_path
