In [31]:
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import VOCDetection
from torchvision import models
import matplotlib.pyplot as plt
from PIL import Image

In [74]:
class_mapping = {
    'person': 0,
    'bird': 1,
    'cat': 2,
    'cow': 3,
    'dog': 4,
    'horse': 5,
    'sheep': 6,
    'aeroplane': 7,
    'bicycle': 8,
    'boat': 9,
    'bus': 10,
    'car': 11,
    'motorbike': 12,
    'train': 13,
    'bottle': 14,
    'chair': 15,
    'diningtable': 16,
    'pottedplant': 17,
    'sofa': 18,
    'tvmonitor': 19
}

In [86]:
def to_target_tensor(num_classes, annotation_dict):
    # Extract image size information
    width = int(annotation_dict['annotation']['size']['width'])
    height = int(annotation_dict['annotation']['size']['height'])

    # Extract bounding box information
    tensor_categories = torch.zeros((20, 224, 224))
    for obj in annotation_dict['annotation']['object']:
        xmin = int((int(obj['bndbox']['xmin']) / width) * 224)
        ymin = int((int(obj['bndbox']['ymin']) / height) * 224)
        xmax = int((int(obj['bndbox']['xmax']) / width) * 224)
        ymax = int((int(obj['bndbox']['ymax']) / height) * 224)
        tensor_categories[class_mapping[obj['name']], xmin:xmax+1, ymin:ymax+1] = 1
        
    return tensor_categories

In [87]:
# Instantiate your weakly supervised dataset
# YourWeaklySupervisedDataset should provide images and their weak annotations
# You need to implement this dataset class
num_classes = 20
# Create a DataLoader for training
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
target_transform = transforms.Compose([
    transforms.Lambda(lambda x: to_target_tensor(num_classes, x))
])
train_dataset = VOCDetection(root='./data', year='2012', image_set='train', download=False, transform=transform, target_transform=target_transform)
print(train_dataset[0][1].shape)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

torch.Size([20, 224, 224])


In [88]:
# Define the U-Net model
class UNet(nn.Module):
    def __init__(self, num_classes):
        super(UNet, self).__init__()
        
        resnet18 = models.resnet18(pretrained=True)
        self.encoder = nn.Sequential(*list(resnet18.children())[:-2])
        #self.encoder = nn.Sequential(
            #nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            #nn.MaxPool2d(kernel_size=2, stride=2),
            #models.resnet18(pretrained=True),
            #nn.ReLU(inplace=True),
        #)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            
            nn.ConvTranspose2d(64, num_classes, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [89]:
# Set your device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the model
model = UNet(num_classes=20).to(device)

# Define your loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [90]:
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    for images, annotations in train_loader:
        images, annotations = images.to(device), annotations.to(device)

        # Forward pass
        outputs = model(images)
        # Compute your loss based on the weak annotations
        loss = criterion(outputs, annotations)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

# Save or use the trained model for inference
torch.save(model.state_dict(), 'weakly_supervised_segmentation_model.pth')

Epoch [1/10], Loss: 1.0638
Epoch [2/10], Loss: 1.7075
Epoch [3/10], Loss: 1.0358
Epoch [4/10], Loss: 1.0432
Epoch [5/10], Loss: 2.1106
Epoch [6/10], Loss: 0.9387
Epoch [7/10], Loss: 1.6952
Epoch [8/10], Loss: 1.1847
Epoch [9/10], Loss: 2.0589
Epoch [10/10], Loss: 1.5336


In [95]:
from torchvision.transforms import ToPILImage
from PIL import Image

tensor = model(train_dataset[0][0].unsqueeze(0)).squeeze(0)
normalized_tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())

print(normalized_tensor.shape)
# Convert to a list of PIL Images
pil_images = [ToPILImage()(normalized_tensor[i]) for i in range(20)]

# Merge the list of images into a single grayscale image
grayscale_image = Image.merge("L", pil_images)

# Display or save the resulting grayscale image
grayscale_image.show()

torch.Size([1, 20, 224, 224])


ValueError: pic should not have > 4 channels. Got 20 channels.