In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from torchvision.models.segmentation import deeplabv3_resnet50
from PIL import Image
import os
import numpy as np
from tqdm import tqdm

In [None]:
class MedicalImageDataset(Dataset):
    """Pairs each image in ``image_dir`` with its ``<stem>_mask.png`` mask.

    Masks are looked up in the directory obtained by replacing ``'imgs'`` with
    ``'masks'`` in ``image_dir`` (e.g. ``.../train/imgs`` -> ``.../train/masks``).

    Args:
        image_dir: Directory containing the input images.
        transform: Optional callable applied to the image, and also to the mask
            when ``mask_transform`` is not given (original behaviour).
        mask_transform: Optional callable applied to the mask only. Defaults to
            ``transform``. Useful when the mask needs different preprocessing
            (e.g. nearest-neighbour resizing to keep it binary).
    """

    def __init__(self, image_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.mask_transform = mask_transform
        # os.listdir order is arbitrary; sorting makes index -> file stable, so a
        # seeded random_split selects the same files on every run.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def _mask_path(self, idx):
        """Return the path of the mask belonging to image ``idx``."""
        mask_dir = self.image_dir.replace('imgs', 'masks')
        # splitext strips only the final extension, so "case.01.png" maps to
        # "case.01_mask.png" (split('.')[0] would give "case_mask.png").
        stem = os.path.splitext(self.images[idx])[0]
        return os.path.join(mask_dir, stem + '_mask.png')

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and its mask as single-channel ("L"), then transform."""
        img_path = os.path.join(self.image_dir, self.images[idx])
        # Context managers guarantee the underlying file handles are closed.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(self._mask_path(idx)) as mask_file:
            mask = mask_file.convert("L")

        if self.transform:
            image = self.transform(image)
        mask_tf = self.mask_transform if self.mask_transform is not None else self.transform
        if mask_tf:
            mask = mask_tf(mask)

        return image, mask

# Define the transform: resize to the network input size and convert to a float tensor.
# NOTE(review): the same Resize (bilinear by default in torchvision) is applied to the
# masks, so mask edges become fractional rather than strictly 0/1 — confirm intended.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

TRAIN_IMG_DIR = '/media/rohit/mirlproject2/fetal head circumference/Breast_Ultrasound/train/imgs'
SPLIT_SEED = 42

# Create dataset and dataloaders
image_dataset = MedicalImageDataset(TRAIN_IMG_DIR, transform=transform)

# 95% train / 5% validation; val takes the remainder so the sizes always sum to len.
train_size = int(0.95 * len(image_dataset))
val_size = len(image_dataset) - train_size

# A seeded generator makes the split reproducible across kernel restarts (given a
# stable file ordering), so validation losses from different runs are comparable.
train_dataset, val_dataset = random_split(
    image_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

In [None]:
# Define Dice Loss
def dice_loss(inputs, targets, smooth=1):
    """Soft Dice loss between predicted probabilities and a target mask.

    Both tensors are flattened, so the Dice coefficient is computed over the
    whole batch at once (not averaged per sample).

    Args:
        inputs: Predicted foreground probabilities, expected in [0, 1]. Pass
            ``torch.sigmoid(logits)`` rather than raw logits; negative logits
            make the ratio meaningless.
        targets: Ground-truth mask with the same number of elements as ``inputs``.
        smooth: Additive smoothing that keeps the ratio defined when both
            tensors sum to zero.

    Returns:
        Scalar tensor ``1 - dice``.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. after permute/transpose.
    inputs = inputs.reshape(-1)
    targets = targets.reshape(-1)
    intersection = (inputs * targets).sum()
    dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
    return 1 - dice

In [None]:
# Define the DeepLab v3 model (pretrained weights requested via `pretrained=True`).
model = deeplabv3_resnet50(pretrained=True)
# Replace the final 1x1 conv so the head emits a single output channel
# (raw logits, since no activation follows) for single-class segmentation.
model.classifier[4] = nn.Conv2d(256, 1, kernel_size=1)
# The pretrained model also carries an auxiliary head. Only ['out'] feeds the loss,
# so computing the aux output on every forward pass is wasted work; disable it.
model.aux_classifier = None

In [None]:
# Use the GPU when PyTorch can see one; otherwise train on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = model.to(device)

In [None]:
# Training hyper-parameters.
num_epochs = 100

# Adam with L2 weight decay. The scheduler is stepped with the validation loss
# and divides the LR by 10 once 5 epochs pass without improvement.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-3,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=5,
    verbose=True,  # NOTE(review): `verbose` is deprecated in recent PyTorch releases.
)

In [None]:
# Define EarlyStopping class
class EarlyStopping:
    """Stop training once the validation loss stops improving, checkpointing the best model.

    Call the instance once per epoch with the validation loss. The first call,
    and every call whose loss is at least ``delta`` below the best seen so far,
    saves the model's ``state_dict`` to ``path`` and resets the counter. After
    ``patience`` consecutive calls without such an improvement, ``early_stop``
    becomes True.

    Args:
        patience: Number of non-improving calls tolerated before stopping.
        verbose: If True, report counter increments and checkpoint saves.
        delta: Minimum decrease in loss that counts as an improvement.
        path: File the best ``state_dict`` is written to.
        trace_func: Function used for reporting (``print`` by default).
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pth', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Lowercase np.inf: the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        # Negate so that "higher score is better" throughout the comparisons below.
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path`` and record ``val_loss`` as the best loss."""
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

BEST_MODEL_PATH = '/media/rohit/mirlproject2/fetal head circumference/Breast_Ultrasound/weights/deeplabv3_best_model.pth'
# Stop after 10 epochs without validation improvement; best weights go to BEST_MODEL_PATH.
early_stopping = EarlyStopping(
    patience=10,
    verbose=True,
    path=BEST_MODEL_PATH,
)

In [None]:
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_train_loss = 0.0
    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

    for batch_idx, (images, masks) in enumerate(train_loader_tqdm, start=1):
        images, masks = images.to(device), masks.to(device)

        # The head is a bare Conv2d, so 'out' holds logits; Dice needs probabilities.
        outputs = model(images)['out']
        loss = dice_loss(torch.sigmoid(outputs), masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_train_loss += loss.item()
        # Use our own batch counter: tqdm only refreshes `.n` when it redraws,
        # so dividing by it gives a wrong running average.
        train_loader_tqdm.set_postfix({"Train Loss": running_train_loss / batch_idx})

    avg_train_loss = running_train_loss / len(train_loader)
    print(f'Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}')

    # ---- validation pass ----
    model.eval()
    running_val_loss = 0.0
    val_loader_tqdm = tqdm(val_loader, desc=f"Validation {epoch+1}/{num_epochs}", unit="batch")

    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(val_loader_tqdm, start=1):
            images, masks = images.to(device), masks.to(device)

            outputs = model(images)['out']
            val_loss = dice_loss(torch.sigmoid(outputs), masks)

            running_val_loss += val_loss.item()
            val_loader_tqdm.set_postfix({"Val Loss": running_val_loss / batch_idx})

    avg_val_loss = running_val_loss / len(val_loader)
    print(f'Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}')

    # Both the LR schedule and early stopping are driven by the epoch's mean validation loss.
    scheduler.step(avg_val_loss)

    early_stopping(avg_val_loss, model)

    if early_stopping.early_stop:
        print("Early stopping")
        break

## Test: evaluate the best checkpoint on the held-out test set

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from torchvision.models.segmentation import deeplabv3_resnet50
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
def dice_loss(inputs, targets, smooth=1):
    """Soft Dice loss between predicted probabilities and a target mask.

    Both tensors are flattened, so the Dice coefficient is computed over the
    whole batch at once (not averaged per sample).

    Args:
        inputs: Predicted foreground probabilities, expected in [0, 1]. Pass
            ``torch.sigmoid(logits)`` rather than raw logits; negative logits
            make the ratio meaningless.
        targets: Ground-truth mask with the same number of elements as ``inputs``.
        smooth: Additive smoothing that keeps the ratio defined when both
            tensors sum to zero.

    Returns:
        Scalar tensor ``1 - dice``.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. after permute/transpose.
    inputs = inputs.reshape(-1)
    targets = targets.reshape(-1)
    intersection = (inputs * targets).sum()
    dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
    return 1 - dice

In [None]:
class MedicalImageDataset(Dataset):
    """Pairs each image in ``image_dir`` with its ``<stem>_mask.png`` mask.

    Masks are looked up in the directory obtained by replacing ``'imgs'`` with
    ``'masks'`` in ``image_dir`` (e.g. ``.../test/imgs`` -> ``.../test/masks``).

    Args:
        image_dir: Directory containing the input images.
        transform: Optional callable applied to the image, and also to the mask
            when ``mask_transform`` is not given (original behaviour).
        mask_transform: Optional callable applied to the mask only. Defaults to
            ``transform``. Useful when the mask needs different preprocessing
            (e.g. nearest-neighbour resizing to keep it binary).
    """

    def __init__(self, image_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.mask_transform = mask_transform
        # os.listdir order is arbitrary; sorting makes index -> file stable so
        # "the 8th test image" refers to the same file on every run.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def _mask_path(self, idx):
        """Return the path of the mask belonging to image ``idx``."""
        mask_dir = self.image_dir.replace('imgs', 'masks')
        # splitext strips only the final extension, so "case.01.png" maps to
        # "case.01_mask.png" (split('.')[0] would give "case_mask.png").
        stem = os.path.splitext(self.images[idx])[0]
        return os.path.join(mask_dir, stem + '_mask.png')

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and its mask as single-channel ("L"), then transform."""
        img_path = os.path.join(self.image_dir, self.images[idx])
        # Context managers guarantee the underlying file handles are closed.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(self._mask_path(idx)) as mask_file:
            mask = mask_file.convert("L")

        if self.transform:
            image = self.transform(image)
        mask_tf = self.mask_transform if self.mask_transform is not None else self.transform
        if mask_tf:
            mask = mask_tf(mask)

        return image, mask

# Preprocessing: fixed 512x512 resize followed by tensor conversion.
transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
    ]
)

TEST_IMG_DIR = '/media/rohit/mirlproject2/fetal head circumference/Breast_Ultrasound/test/imgs'
test_dataset = MedicalImageDataset(TEST_IMG_DIR, transform=transform)

In [None]:
# One image per batch, in dataset order, so each loss below belongs to exactly one image.
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)
len(test_loader)

In [None]:
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Rebuild the training architecture: DeepLabV3-ResNet50 with a one-channel output conv.
model = deeplabv3_resnet50(pretrained=True)
single_channel_head = nn.Conv2d(256, 1, kernel_size=1)
model.classifier[4] = single_channel_head
model = model.to(device)

In [None]:
modelPath = '/media/rohit/mirlproject2/fetal head circumference/Breast_Ultrasound/weights/deeplabv3_best_model.pth'
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
state_dict = torch.load(modelPath, map_location=device)
# Some checkpoints wrap the weights under a "model_weight" key.
if "model_weight" in state_dict:
    state_dict = state_dict["model_weight"]

# strict=False tolerates the auxiliary head (unused at inference, possibly absent
# from the checkpoint) but would also silently skip a completely mismatched file.
# Fail loudly if any non-auxiliary weight was not loaded.
load_result = model.load_state_dict(state_dict, strict=False)
missing = [k for k in load_result.missing_keys if not k.startswith('aux_classifier.')]
if missing:
    raise RuntimeError(f"Checkpoint is missing {len(missing)} model weights, e.g. {missing[:5]}")

model.to(device)

In [None]:
model.eval()
loss_fn = dice_loss
test_loss = []

# Inference only: no_grad avoids building an autograd graph for every test image.
with torch.no_grad():
    for images, masks in test_loader:
        images, masks = images.to(device), masks.to(device)

        # The model returns a dict; 'out' holds the single-channel logits.
        outputs = model(images)['out']
        # Dice expects probabilities, not logits.
        probs = torch.sigmoid(outputs)
        val_loss = loss_fn(probs, masks)

        test_loss.append(val_loss.item())

# Visualise the last batch seen by the loop: input (first channel), ground truth, prediction.
fig, axs = plt.subplots(1, 3, figsize=(15, 10))
axs[0].imshow(images[0][0].cpu().numpy())
axs[0].set_title("Input")
axs[0].axis('off')

axs[1].imshow(masks[0][0].cpu().numpy())
axs[1].set_title("Ground truth")
axs[1].axis('off')

axs[2].imshow(probs[0][0].cpu().numpy())
axs[2].set_title("Predicted")
axs[2].axis('off')

plt.show()

In [None]:
from statistics import mean

# Average of the per-batch Dice losses collected on the test set.
mean_test_dice_loss = mean(test_loss)
mean_test_dice_loss

## Compare prediction with ground truth on a single test image

In [None]:
# Take the 8th test batch (or the last one if the loader is shorter).
for i, (img, mask) in enumerate(test_loader, start=1):
    if i == 8:
        break

# First item of the batch, CHW -> HWC for matplotlib.
images = img[0].permute(1, 2, 0)
plt.imshow(images)

In [None]:
# Ground-truth mask of the same sample; move the channel axis last: (1, H, W) -> (H, W, 1).
gt = mask[0].permute(1, 2, 0)
plt.imshow(gt)

In [None]:
# Take the 8th test batch (or the last one if the loader is shorter).
for i, (images, masks) in enumerate(test_loader, start=1):
    if i == 8:
        break

images, masks = images.to(device), masks.to(device)

model.eval()
with torch.no_grad():
    # The model returns a dict; 'out' holds the single-channel logits.
    outputs = model(images)['out']
# Dice and the overlay both need probabilities, not raw logits.
probs = torch.sigmoid(outputs)
val_loss = dice_loss(probs, masks)

print(val_loss.item())

img = images[0][0].cpu().numpy()
gt = masks[0][0].cpu().numpy()
pred = probs[0][0].cpu().numpy()

# Predicted foreground probability overlaid on the input image.
plt.imshow(img, cmap='gray')
plt.imshow(pred, alpha=0.4)

plt.title('DeepLabV3', fontsize=18)
plt.axis('off')

plt.show()