In [393]:
import xml.etree.ElementTree as ET
import os
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from torchvision import models
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
from sklearn.model_selection import train_test_split

In [394]:
# Annotation tag -> channel index.  Several tags are aliases of the same
# category (plane/aeroplane, bike/bicycle, table/diningtable,
# plant/pottedplant, monitor/tvmonitor) and share one index.
# NOTE(review): 'void' maps to 255, which is not a valid channel of the
# 22-channel tensor built in read_scribble_xml -- confirm how it should be used.
class_mapping = dict(
    # living beings
    person=0,
    bird=1,
    cat=2,
    cow=3,
    dog=4,
    horse=5,
    sheep=6,
    # vehicles
    aeroplane=7,
    plane=7,
    bike=8,
    bicycle=8,
    boat=9,
    bus=10,
    car=11,
    motorbike=12,
    train=13,
    # indoor objects
    bottle=14,
    chair=15,
    diningtable=16,
    table=16,
    pottedplant=17,
    plant=17,
    sofa=18,
    tvmonitor=19,
    monitor=19,
    # special labels
    background=20,
    void=255,
)
print(class_mapping['bicycle'])

8


In [395]:
def read_scribble_xml(xml_file):
    """Rasterise the scribbles of one annotation XML into a multi-hot tensor.

    Every ``<polygon>`` element carries a ``<tag>`` (category name) and a list
    of ``<point>`` elements with ``X``/``Y`` pixel coordinates in the original
    image.  Coordinates are rescaled to a 224x224 grid and the matching cell
    of the category's channel is set to 1.

    Args:
        xml_file: Path or file object accepted by ``ET.parse``.

    Returns:
        ``(filename, tensor_categories)`` where ``filename`` is the text of the
        ``<filename>`` element and ``tensor_categories`` is a float tensor of
        shape ``(22, 224, 224)`` indexed as ``[channel, row, column]``.

    Raises:
        KeyError: if a polygon tag is not a key of ``class_mapping``.
    """
    #  Parse XML into Element Tree
    root = ET.parse(xml_file).getroot()

    #  Read meta data
    filename = root.find('filename').text
    width = int(root.find('size/width').text)
    height = int(root.find('size/height').text)

    #  Read all points and assign them to the tensor
    tensor_categories = torch.zeros((22, 224, 224))
    num_channels, out_height, out_width = tensor_categories.shape

    for polygon in root.findall('polygon'):
        channel = class_mapping[polygon.find('tag').text]
        # Labels whose index has no channel (e.g. 'void' -> 255) cannot be
        # stored in the tensor; skip them instead of raising IndexError.
        if not 0 <= channel < num_channels:
            continue

        point_nodes = polygon.findall('point')
        # A polygon without points would produce an empty 1-D array that
        # cannot be column-indexed.
        if not point_nodes:
            continue

        xs = np.array([int(node.find('X').text) for node in point_nodes])
        ys = np.array([int(node.find('Y').text) for node in point_nodes])

        # Rescale to the output grid (truncating like int()) and clamp on both
        # sides: a coordinate equal to width/height must not land on index 224,
        # and a negative one must not wrap around via negative indexing.
        cols = np.clip((xs / width * out_width).astype(np.int64), 0, out_width - 1)
        rows = np.clip((ys / height * out_height).astype(np.int64), 0, out_height - 1)

        tensor_categories[channel, torch.from_numpy(rows), torch.from_numpy(cols)] = 1

    return filename, tensor_categories


In [396]:
# Define dataset
class ScribbleDataset(Dataset):
    """Dataset pairing each scribble XML file with the JPEG of the same stem.

    Args:
        image_dir: Directory holding the ``.jpg`` images.
        xml_dir: Directory holding the scribble ``.xml`` annotations.
        xml_files: File names (relative to ``xml_dir``) that make up this split.
        transform: Optional callable applied to the sample dict
            ``{'image': ..., 'tensor_category': ...}`` before it is returned.
    """

    def __init__(self, image_dir, xml_dir, xml_files, transform=None):
        self.image_dir = image_dir
        self.xml_files = xml_files
        self.transform = transform
        self.xml_dir = xml_dir

    def __len__(self):
        """Return the number of annotation files in this split."""
        return len(self.xml_files)

    def __getitem__(self, idx):
        """Load the image and scribble tensor for the ``idx``-th annotation."""
        xml_file = self.xml_files[idx]
        xml_path = os.path.join(self.xml_dir, xml_file)
        # Swap only the extension; str.replace would rewrite every ".xml"
        # occurrence inside the file name.
        stem, _ = os.path.splitext(xml_file)
        image_path = os.path.join(self.image_dir, stem + ".jpg")

        _, tensor_categories = read_scribble_xml(xml_path)

        # Force three channels so grayscale/CMYK JPEGs collate with the rest of
        # the batch, and close the file handle as soon as the pixels are loaded.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        sample = {'image': image, 'tensor_category': tensor_categories}

        if self.transform:
            sample = self.transform(sample)

        return sample


In [397]:
class ToTensor(object):
    """Sample transform: resize the image to 224x224 and convert it to a tensor.

    Operates on the sample dict ``{'image': ..., 'tensor_category': ...}``; the
    ``tensor_category`` entry is passed through unchanged.
    """

    def __init__(self):
        # Build the image pipeline once instead of once per sample.
        self.image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()])

    def __call__(self, sample):
        image, tensor_categories = sample['image'], sample['tensor_category']

        # Convert image to tensor
        image = self.image_transform(image)

        return {'image': image, 'tensor_category': tensor_categories}

In [398]:
# Set up dataset and dataloader
# Set up dataset and dataloader
xml_dir = "scribble"
image_dir = "train_JPEGImages"
# os.listdir returns entries in arbitrary order and may include non-annotation
# files (e.g. notebook checkpoints).  Keep only the .xml files and sort them so
# the seeded split below is the same on every machine.
xml_list = sorted(name for name in os.listdir(xml_dir) if name.endswith(".xml"))
transform = transforms.Compose([ToTensor()])
# Hold out 10% of the annotations for validation; random_state fixes the split.
train_data, val_data = train_test_split(xml_list, test_size=0.1, random_state=1)
train_dataset = ScribbleDataset(image_dir=image_dir, xml_dir=xml_dir, xml_files=train_data, transform=transform)
val_dataset = ScribbleDataset(image_dir=image_dir, xml_dir=xml_dir, xml_files=val_data, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [399]:
# Define the U-Net model
# Define the U-Net model
class UNet(nn.Module):
    """Encoder-decoder segmentation network on a ResNet-18 backbone.

    The encoder is ``models.resnet18`` with its last two child modules removed
    (in torchvision's ResNet these are the average pool and the fully
    connected head), so it outputs a 512-channel feature map.  The decoder
    doubles the spatial resolution five times (four transposed convolutions
    and one bilinear upsample) and ends with ``num_classes`` channels.

    NOTE(review): despite the name there are no encoder-to-decoder skip
    connections; this is a plain encoder-decoder.

    Args:
        num_classes: Number of output channels (one logit map per class).
    """

    def __init__(self, num_classes):
        super(UNet, self).__init__()

        resnet18 = models.resnet18(pretrained=True)
        self.encoder = nn.Sequential(*list(resnet18.children())[:-2])

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),

            # Final layer emits raw logits.  No activation here: a trailing
            # ReLU would clamp every logit to >= 0, which distorts the softmax
            # inside CrossEntropyLoss and zeroes gradients for negative logits.
            nn.ConvTranspose2d(64, num_classes, kernel_size=2, stride=2)
        )

    def forward(self, x):
        """Return per-pixel class logits of shape ``(N, num_classes, H', W')``."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [400]:
# Initialize U-Net model
# Hyperparameters
in_channels = 3
out_channels = 22  # Number of classes
num_epochs = 10

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model, loss function and optimizer
model = UNet(out_channels)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Move the parameters to the selected device (last expression, so the cell
# also displays the module summary)
model.to(device)

UNet(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  

In [401]:
# Training loop
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    # Make sure layers with train/eval behaviour (the encoder's BatchNorm) are
    # in training mode even if an earlier cell switched the model to eval().
    model.train()
    running_loss = 0.0
    num_batches = 0

    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        images, targets = batch['image'].to(device), batch['tensor_category'].to(device)

        # Forward pass
        outputs = model(images)
        # Compute your loss based on the weak annotations
        loss = criterion(outputs, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the mean over the epoch rather than only the last batch's loss;
    # max(..., 1) keeps this safe when the loader yields no batches.
    epoch_loss = running_loss / max(num_batches, 1)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# Save or use the trained model for inference
torch.save(model.state_dict(), 'seeds_weakly_supervised_segmentation_model.pth')

Epoch 1/10:   2%|▏         | 16/677 [00:23<16:15,  1.48s/it]


KeyboardInterrupt: 