### UNET Code That will be used in LabMEMS

#### contents:

1. import data
2. create dataset
3. create UNET with SelfAttention
4. Create Train Loop

In [7]:
import os
from math import sqrt
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

In [3]:
# dataset

class LabMemsDataset(Dataset):
    def __init__(self, img_dir, mask_dir, img_transform=None, mask_transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_transform = img_transform
        self.mask_transform = mask_transform

        self.images = os.listdir(img_dir)
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.img_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name) # corresponding images and masks have same name

        image = Image.open(img_path).convert("L") # "L" is for grayscale images (1 channel)
        mask = Image.open(mask_path).convert("L")

        if self.img_transform:
            image = self.img_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        
        return image, mask

In [4]:
# initialize transforms

image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=0.442, std=0.225)
])

mask_transform = transforms.Compose([
    transforms.ToTensor()
])

In [5]:
train_image_dir = r'C:\Users\USER\Desktop\jupytervscode\LabMems\o-net\lab-data\final-data\input'
train_mask_dir = r'C:\Users\USER\Desktop\jupytervscode\LabMems\o-net\lab-data\final-data\label'

test_image_dir = r'C:\Users\USER\Desktop\jupytervscode\LabMems\o-net\lab-data\final-data\test-data\input'
test_mask_dir = r'C:\Users\USER\Desktop\jupytervscode\LabMems\o-net\lab-data\final-data\test-data\label'

train_dataset = LabMemsDataset(img_dir=train_image_dir, mask_dir=train_mask_dir, img_transform=image_transform, mask_transform=mask_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = LabMemsDataset(img_dir=test_image_dir, mask_dir=test_mask_dir, img_transform=image_transform, mask_transform=mask_transform)

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

In [6]:
print(f"train data: {len(train_dataset)}")
print(f"test data: {len(test_dataset)}")

train data: 4199
test data: 221


In [25]:
# operations

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        
        # define double convolution operation
        self.operation = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.operation(x)
    
class SelfAttention(nn.Module):
    def __init__(self, semantic):
        super(SelfAttention, self).__init__()


        # overlaping embedding
        self.query = nn.Conv2d(in_channels=semantic, out_channels=semantic, kernel_size=3, stride=1, padding=1)
        self.key = nn.Conv2d(in_channels=semantic, out_channels=semantic, kernel_size=3, stride=1, padding=1)
        self.value = nn.Conv2d(in_channels=semantic, out_channels=semantic, kernel_size=3, stride=1, padding=1)

        self.normalizer = sqrt(semantic*4)

        self.flatten = nn.Flatten(2, 3)

    def forward(self, x):
        b, c, h, w = x.size()
        q = self.flatten(self.query(x))
        k = self.flatten(self.key(x))
        v = self.flatten(self.value(x))

        scaled = torch.bmm(q, k.permute(0, 2, 1)) / self.normalizer

        return torch.bmm(F.softmax(scaled, dim=-1), v).reshape(b, c, h , w)


In [26]:
a = torch.randn([32, 64, 120, 120])

att = SelfAttention(64)

att(a).shape

torch.Size([32, 64, 120, 120])

In [32]:
# create model

class UNET(nn.Module):
    def __init__(self):
        super(UNET, self).__init__()

        # input [1, 176, 128]

        # encoder
        self.encoder1 = DoubleConv(in_channels=1, out_channels=8)
        self.selfatt1 = SelfAttention(8)
        self.encoder2 = DoubleConv(in_channels=8, out_channels=128)
        self.selfatt2 = SelfAttention(128)
        self.encoder3 = DoubleConv(in_channels=128, out_channels=1024)
        self.selfatt3 = SelfAttention(1024)
        self.encoder4 = DoubleConv(in_channels=1024, out_channels=2048)
        self.selfatt4 = SelfAttention(2048)

        # bottleneck
        self.bottom_conv = nn.Conv2d(in_channels=2048, out_channels=2048, kernel_size=3, stride=1, padding=1)
        self.bottom_normalizer = nn.BatchNorm2d(num_features=2048)
        self.unity_conv = nn.Conv2d(in_channels=2048, out_channels=2048, kernel_size=1, stride=1, padding=0)
        
        # decoder
        self.transpose3 = nn.ConvTranspose2d(in_channels=2048, out_channels=1024, kernel_size=2, stride=2, padding=0, output_padding=0)
        self.decoder3 = DoubleConv(in_channels=2048, out_channels=1024)

        self.transpose2 = nn.ConvTranspose2d(in_channels=1024, out_channels=128, kernel_size=2, stride=2, padding=0, output_padding=0)
        self.decoder2 = DoubleConv(in_channels=256, out_channels=128)

        self.transpose1 = nn.ConvTranspose2d(in_channels=128, out_channels=8, kernel_size=2, stride=2, padding=0, output_padding=0)
        self.decoder1 = DoubleConv(in_channels=16, out_channels=8)

        self.final_conv = nn.Conv2d(in_channels=8, out_channels=1, kernel_size=1)


        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        
        # encoder
        enc1 = self.selfatt1(self.encoder1(x)) # [8, 176, 128]
        enc2 = self.selfatt2(self.encoder2(self.pool(enc1))) # [128, 88, 64]
        enc3 = self.selfatt3(self.encoder3(self.pool(enc2))) # [1024, 44, 32]
        enc4 = self.selfatt4(self.encoder4(self.pool(enc3))) # [2048, 22, 16]


        # bottleneck
        bottom1 = self.relu(self.bottom_normalizer(self.bottom_conv(enc4)))
        bottom2 = self.relu(self.bottom_normalizer(self.unity_conv(bottom1)))

        # decoder
        dec3 = self.decoder3(torch.cat([self.transpose3(bottom2), enc3], dim=1))
        dec2 = self.decoder2(torch.cat([self.transpose2(dec3), enc2], dim=1))
        dec1 = self.decoder1(torch.cat([self.transpose1(dec2), enc1], dim=1))

        final = self.final_conv(dec1)

        return final

In [33]:
a = torch.randn(size=[32, 1, 176, 128])

model = UNET()

model(a).shape

torch.Size([32, 1, 176, 128])

In [34]:
model = UNET()
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Print the total number of parameters
print(f"Total number of parameters: {count_parameters(model)}")

Total number of parameters: 289059985


In [35]:
# training loop

EPOCHS = 10

print("training phase")
print("-------------------------------------------------")
for epoch in range(EPOCHS):
    print(f"epoch: {epoch + 1}/{EPOCHS}")
    model.train()
    running_loss = 0.0
    # running_accuracy = 0.0

    for images, masks in train_loader:
        
        # forward pass
        outputs = model(images)
        loss = loss_fn(outputs, masks)

        # backprop and optimizantion
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update running loss
        running_loss += loss.item() * images.size(0)

        # calculate accuracy
        # accuracy = pixel_accuracy(outputs, masks)
        # running_accuracy = accuracy * images.size(0)

    epoch_loss = running_loss / len(train_dataset)
    # epoch_accuracy = running_accuracy / len(train_dataset)
    print(f"train loss: {epoch_loss:.4f}")
    # print(f"accuracy: {epoch_accuracy:.4f}")
    print("")

training phase
-------------------------------------------------
epoch: 1/10


KeyboardInterrupt: 