In [1]:
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import VOCDetection
from torchvision import models
import matplotlib.pyplot as plt
from PIL import Image
from torch.nn.functional import relu

In [2]:
class_mapping = {
    'person': 1,
    'bird': 2,
    'cat': 3,
    'cow': 4,
    'dog': 5,
    'horse': 6,
    'sheep': 7,
    'aeroplane': 8,
    'bicycle': 9,
    'boat': 10,
    'bus': 11,
    'car': 12,
    'motorbike': 13,
    'train': 14,
    'bottle': 15,
    'chair': 16,
    'diningtable': 17,
    'pottedplant': 18,
    'sofa': 19,
    'tvmonitor': 20,
    'background': 0
}

In [3]:
def to_target_tensor(num_classes, annotation_dict):
    # Extract image size information
    width = int(annotation_dict['annotation']['size']['width'])
    height = int(annotation_dict['annotation']['size']['height'])

    # Extract bounding box information
    tensor_categories = torch.zeros((1, 224, 224))
    for obj in annotation_dict['annotation']['object']:
        xmin = int((int(obj['bndbox']['xmin']) / width) * 224)
        ymin = int((int(obj['bndbox']['ymin']) / height) * 224)
        xmax = int((int(obj['bndbox']['xmax']) / width) * 224)
        ymax = int((int(obj['bndbox']['ymax']) / height) * 224)
        tensor_categories[0, ymin:ymax+1, xmin:xmax+1] = class_mapping[obj['name']]

    return tensor_categories

In [4]:
# Instantiate your weakly supervised dataset
# YourWeaklySupervisedDataset should provide images and their weak annotations
# You need to implement this dataset class
num_classes = 20
# Create a DataLoader for training
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
])
target_transform = transforms.Compose([
    transforms.Lambda(lambda x: to_target_tensor(num_classes, x))
])
train_dataset = VOCDetection(root='./data', year='2012', image_set='train', download=True, transform=transform, target_transform=target_transform)
print(train_dataset[0][1].shape)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

Using downloaded and verified file: ./data/VOCtrainval_11-May-2012.tar
Extracting ./data/VOCtrainval_11-May-2012.tar to ./data
torch.Size([1, 224, 224])


In [5]:
# Define the U-Net model
class UNet(nn.Module):
    def __init__(self, num_classes):
        super(UNet, self).__init__()

        resnet18 = models.resnet18(pretrained=True)
        self.encoder = nn.Sequential(*list(resnet18.children())[:-2])

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),

            nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [6]:
def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, activation='sigmoid'):
    """
    Args:
        pr (torch.Tensor): A list of predicted elements
        gt (torch.Tensor):  A list of elements that are to be predicted
        eps (float): epsilon to avoid zero division
        threshold: threshold for outputs binarization
    Returns:
        float: IoU (Jaccard) score
    """

    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = torch.nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = torch.nn.Softmax2d()
    else:
        raise NotImplementedError(
            "Activation implemented for sigmoid and softmax2d"
        )

    pr = activation_fn(pr)

    if threshold is not None:
        pr = (pr > threshold).float()


    tp = torch.sum(gt * pr)
    fp = torch.sum(pr) - tp
    fn = torch.sum(gt) - tp

    score = ((1 + beta ** 2) * tp + eps) \
            / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + eps)

    return score


class DiceLoss(nn.Module):
    __name__ = 'dice_loss'

    def __init__(self, eps=1e-7, activation='sigmoid'):
        super().__init__()
        self.activation = activation
        self.eps = eps

    def forward(self, y_pr, y_gt):
        return 1 - f_score(y_pr, y_gt, beta=1.,
                           eps=self.eps, threshold=None,
                           activation=self.activation)


class BCEDiceLoss(DiceLoss):
    __name__ = 'bce_dice_loss'

    def __init__(self, eps=1e-7, activation='sigmoid', lambda_dice=1.0, lambda_bce=1.0):
        super().__init__(eps, activation)
        if activation == None:
            self.bce = nn.BCELoss(reduction='mean')
        else:
            self.bce = nn.BCEWithLogitsLoss(reduction='mean')
        self.lambda_dice=lambda_dice
        self.lambda_bce=lambda_bce

    def forward(self, y_pr, y_gt):
        dice = super().forward(y_pr, y_gt)
        bce = self.bce(y_pr, y_gt)
        return (self.lambda_dice*dice) + (self.lambda_bce* bce)

In [7]:
# Set your device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the model
model = UNet(num_classes=21).to(device)

# Define your loss function and optimizer
criterion = nn.CrossEntropyLoss()
criterion2 = BCEDiceLoss(eps=1.0, activation=None)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=2, cooldown=2)



In [8]:
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [9]:
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    for images, annotations in train_loader:
        images, annotations = images.to(device), annotations.to(device)

        # Forward pass
        outputs = model(images)
        # Compute your loss based on the weak annotations
        loss = criterion2(outputs, annotations)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

# Save or use the trained model for inference
torch.save(model.state_dict(), 'weakly_supervised_segmentation_model.pth')

RuntimeError: ignored

In [None]:
from torchvision.transforms import ToPILImage
from PIL import Image

tensor = model(train_dataset[0][0].unsqueeze(0).to('cuda')).squeeze(0)
normalized_tensor = ((tensor - tensor.min()) / (tensor.max() - tensor.min())).cpu().detach().numpy()

output = np.argmax(normalized_tensor, axis=0)
target = torch.argmax(train_dataset[0][1], dim=0).numpy()
print(normalized_tensor.shape)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes[0].imshow(output)
axes[0].set_title('Image 1')

# Display the second image on the second subplot
axes[1].imshow(target)
axes[1].set_title('Image 2')


plt.colorbar()
plt.show()

In [None]:
from torchvision.transforms import ToPILImage
from PIL import Image

tensor = model(train_dataset[0][0].unsqueeze(0).cuda()).squeeze(0)
normalized_tensor = ((tensor - tensor.min()) / (tensor.max() - tensor.min())).cpu().detach().numpy()

print(normalized_tensor.shape)

fig, axes = plt.subplots(20, 2, figsize=(10, 100))

# Plot each normalized grayscale image
for i in range(20):
    axes[i, 0].imshow(normalized_tensor[i], cmap='gray', aspect='auto')
    axes[i, 0].set_title(f'Image {i + 1}')
    axes[i, 0].axis('off')  # Turn off axis labels
    axes[i, 1].imshow(train_dataset[0][1][i], cmap='gray', aspect='auto')
    axes[i, 1].set_title(f'Image {i + 1}')
    axes[i, 1].axis('off')  # Turn off axis labels

plt.show()

In [None]:
# Load the model
model_path = "weakly_supervised_segmentation_model.pth"
checkpoint = torch.load(model_path)
# Load the model's state dictionary
model.load_state_dict(checkpoint)

# Ensure the model is in evaluation mode
model.eval()

# Optionally, you may also load other things such as optimizer state, epoch, etc.
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# epoch = checkpoint['epoch']

# Now, you can use the loaded model for inference
