My UNET returns a blank image after traning 😔

In [1]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

In [2]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    def __init__(
            self, in_channels=3, out_channels=1, features=[64, 128, 256, 512],
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)

def test():
    x = torch.randn((3, 1, 161, 161))
    model = UNET(in_channels=1, out_channels=1)
    preds = model(x)
    assert preds.shape == x.shape

if __name__ == "__main__":
    test()

In [3]:
import cv2
from torchvision.datasets import VOCSegmentation
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)


VOC_CLASSES = [
    "background",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "potted plant",
    "sheep",
    "sofa",
    "train",
    "tv/monitor",
]


VOC_COLORMAP = [
    [0, 0, 0],
    [128, 0, 0],
    [0, 128, 0],
    [128, 128, 0],
    [0, 0, 128],
    [128, 0, 128],
    [0, 128, 128],
    [128, 128, 128],
    [64, 0, 0],
    [192, 0, 0],
    [64, 128, 0],
    [192, 128, 0],
    [64, 0, 128],
    [192, 0, 128],
    [64, 128, 128],
    [192, 128, 128],
    [0, 64, 0],
    [128, 64, 0],
    [0, 192, 0],
    [128, 192, 0],
    [0, 64, 128],
]


class PascalVOCSearchDataset(VOCSegmentation):
    def __init__(self, image_set, root="~/data/pascal_voc", download=True, transform=None):
        super().__init__(root=root, image_set=image_set, download=download, transform=transform)

    @staticmethod
    def _convert_to_segmentation_mask(mask):
        # This function converts a mask from the Pascal VOC format to the format required by AutoAlbument.
        #
        # Pascal VOC uses an RGB image to encode the segmentation mask for that image. RGB values of a pixel
        # encode the pixel's class.
        #
        # AutoAlbument requires a segmentation mask to be a NumPy array with the shape [height, width, num_classes].
        # Each channel in this mask should encode values for a single class. Pixel in a mask channel should have
        # a value of 1.0 if the pixel of the image belongs to this class and 0.0 otherwise.
        height, width = mask.shape[:2]
        segmentation_mask = np.zeros((height, width, len(VOC_COLORMAP)), dtype=np.float32)
        for label_index, label in enumerate(VOC_COLORMAP):
            segmentation_mask[:, :, label_index] = np.all(mask == label, axis=-1).astype(float)
        return segmentation_mask

    def __getitem__(self, index):
        image = cv2.imread(self.images[index])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.masks[index])
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        mask = self._convert_to_segmentation_mask(mask)
        if self.transform is not None:
            image = self.transform(image)
            mask = self.transform(mask)
        return image, mask

transform = transforms.Compose([
    transforms.ToTensor(),   # Convert to tensor
    transforms.Resize(size=(64,64), antialias=True)
])

train_set = PascalVOCSearchDataset(image_set="train", transform=transform)
val_set = PascalVOCSearchDataset(image_set="val", transform=transform)

Using downloaded and verified file: C:\Users\XPS/data/pascal_voc\VOCtrainval_11-May-2012.tar
Extracting C:\Users\XPS/data/pascal_voc\VOCtrainval_11-May-2012.tar to C:\Users\XPS/data/pascal_voc
Using downloaded and verified file: C:\Users\XPS/data/pascal_voc\VOCtrainval_11-May-2012.tar
Extracting C:\Users\XPS/data/pascal_voc\VOCtrainval_11-May-2012.tar to C:\Users\XPS/data/pascal_voc


In [4]:
dataloaders = {
    'train': DataLoader(train_set, batch_size=32, shuffle=True),
    'val': DataLoader(val_set, batch_size=32, shuffle=True)
}

In [5]:
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

In [6]:
print(torch.cuda.is_available())

True


In [7]:
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

LEARNING_RATE = 1e-4
NUM_EPOCHS = 5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNET(in_channels=3, out_channels=21).to(device)

initialize_weights(model)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

train_writer = SummaryWriter('alaruns/train')
val_writer = SummaryWriter('alaruns/val')

In [8]:
# Train the model
train_batch_idx = 0
val_batch_idx = 0
for epoch in range(NUM_EPOCHS):
    # Train the model for one epoch
    for i, data in enumerate(dataloaders['train']):
        # Get inputs and labels
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # Reset gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)

        # Calculate loss
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()

        # Update weights
        optimizer.step()

        # Update training loss
        train_writer.add_scalar('train_loss', loss.item(), train_batch_idx)
        train_batch_idx += 1
        train_writer.close()
    # print('Done with training')

    # Evaluate the model on the validation set
    for i, data in enumerate(dataloaders['val']):
        # Get inputs and labels
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        with torch.no_grad():
            # Forward pass
            outputs = model(inputs)

            # Calculate loss
            loss = criterion(outputs, labels)

            # Update validation loss
            val_writer.add_scalar('val_loss', loss.item(), val_batch_idx)
            val_writer.close()


        val_batch_idx += 1
    
      
    print(f'Done with epoch {epoch}')

Done with epoch 0
Done with epoch 1
Done with epoch 2
Done with epoch 3
Done with epoch 4


In [11]:
# Saving model
torch.save(model, 'models/almodel_two.pth')

In [12]:
# Saving model weights
torch.save(model.state_dict(), 'models/almodel_weights_two.pth')