<a href="https://colab.research.google.com/github/AlenaAntipina/PytorchLearning/blob/main/segmenattion_learn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
from pathlib import Path

import cv2
import matplotlib as plt
# Rebind `plt` to pyplot: the plotting cells call plt.subplots()/plt.show(),
# which exist on matplotlib.pyplot, not on the top-level matplotlib package.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split


In [None]:
pip install torchmetrics

In [6]:
import torchmetrics

In [None]:
pip install pytorch-lightning lightning-bolts


In [8]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


In [9]:
import albumentations as A

In [None]:
# Spatial size the augmentation pipeline and predict() resize every image to.
IMG_HEIGHT, IMG_WIDTH = 128, 128

# Batch size used by both DataLoaders, and the Trainer's max_epochs.
BATCH_SIZE, EPOCHS = 100, 25

# NOTE(review): SAMPLES is not referenced by any visible cell; the subset
# cell slices with a literal 2500 instead -- confirm which one is intended.
SAMPLES = 7000

In [None]:
# Dataset root; the `matting/` and `clip_img/` trees are globbed relative to it.
root = Path("../input/aisegmentcom-matting-human-datasets/")

In [None]:
# Every entry three directory levels below each tree, in sorted order so that
# cutouts and images can be paired by position.
cutout_paths = sorted(root.glob("matting/*/*/*"))
image_paths = sorted(root.glob("clip_img/*/*/*"))

In [None]:
# Work on the first 2500 image/cutout pairs only.
im_pths, cut_pths = image_paths[:2500], cutout_paths[:2500]

In [None]:
def _stems_differ(image_path, cutout_path):
    """True when the two paths do not share the same file stem."""
    return image_path.stem != cutout_path.stem


# Element-wise comparison of the two path lists; the sum counts mismatched pairs.
f = np.frompyfunc(_stems_differ, 2, 1)
mismatch_count = f(im_pths, cut_pths).sum()
print(f"Total # of mismatches: {mismatch_count}")

In [None]:
# Read the first image and its cutout, converting OpenCV's BGR order to RGB.
im = cv2.imread(im_pths[0].as_posix(), cv2.IMREAD_UNCHANGED)
cut = cv2.imread(cut_pths[0].as_posix(), cv2.IMREAD_UNCHANGED)
img = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
cutg = cv2.cvtColor(cut, cv2.COLOR_BGR2RGB)

# Show both side by side, one titled panel each, without axis ticks.
fig, axs = plt.subplots(1, 2)
for ax, picture, title in zip(axs, (img, cutg), ("Image", "Cutout")):
    ax.imshow(picture)
    ax.set_title(title)
    ax.axis("off")
plt.show()


In [None]:
# One row per sample: the photo and the cutout it is paired with by position.
df = pd.DataFrame(
    data={"image_paths": im_pths, "cutout_paths": cut_pths},
)
df.head()

In [None]:
class SegDataset(Dataset):
    """Image / binary-mask pairs for foreground segmentation.

    Each row of ``df`` supplies a photo (``image_paths``) and its cutout
    (``cutout_paths``). The mask is 1 wherever channel 3 (alpha) of the cutout
    is non-zero and 0 elsewhere.

    Args:
        df: DataFrame with ``image_paths`` and ``cutout_paths`` columns whose
            values expose ``as_posix()`` (e.g. ``pathlib.Path``).
        transforms: Optional callable invoked as
            ``transforms(image=<HWC RGB array>, mask=<HW array>)`` that returns
            a mapping with ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, df, transforms=None):
        self.image_paths = df.image_paths
        self.cutout_paths = df.cutout_paths
        # Attribute name (sic) kept as in the original so existing code that
        # reads or sets `transfroms` keeps working.
        self.transfroms = transforms

    def _load_pair(self, idx):
        """Read sample ``idx`` from disk and return ``(rgb_image, cutout)``.

        Raises:
            FileNotFoundError: if OpenCV cannot read either file.
            ValueError: if the cutout has no fourth (alpha) channel.
        """
        # Positional lookup: DataLoader / random_split hand out positions,
        # which differ from labels when `df` has a non-default index.
        image_path = self.image_paths.iloc[idx]
        cutout_path = self.cutout_paths.iloc[idx]

        bgr = cv2.imread(image_path.as_posix(), cv2.IMREAD_UNCHANGED)
        if bgr is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        cutout = cv2.imread(cutout_path.as_posix(), cv2.IMREAD_UNCHANGED)
        if cutout is None:
            raise FileNotFoundError(f"Could not read cutout: {cutout_path}")
        if cutout.ndim != 3 or cutout.shape[2] < 4:
            raise ValueError(f"Cutout has no alpha channel: {cutout_path}")

        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), cutout

    def __getitem__(self, idx):
        """Return ``{"image": float32 (C, H, W), "mask": float32 (1, H, W)}``."""
        image, cutout = self._load_pair(idx)
        # uint8 keeps the mask in a dtype image libraries can resize.
        mask = (cutout[:, :, 3] > 0).astype(np.uint8)

        if self.transfroms is not None:
            # Use this instance's pipeline and feed it the RGB image.
            transformed = self.transfroms(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]

        # Channels first. (2, 0, 1) keeps height before width, so the image
        # stays pixel-aligned with the (1, H, W) mask.
        image = np.transpose(image, (2, 0, 1))
        mask = np.expand_dims(mask, 0)

        return {
            "image": torch.tensor(image, dtype=torch.float32),
            "mask": torch.tensor(mask, dtype=torch.float32),
        }

    def __len__(self):
        return len(self.image_paths)

    def display_sample(self, idx):
        """Plot the image, its cutout and the raw alpha mask of sample ``idx``."""
        img, cut = self._load_pair(idx)
        cutg = cv2.cvtColor(cut, cv2.COLOR_BGR2RGB)
        mask = cut[:, :, 3]

        fig, axs = plt.subplots(1, 3)
        panels = ((img, "Image"), (cutg, "Cutout"), (mask, "Mask"))
        for ax, (picture, title) in zip(axs, panels):
            ax.imshow(picture)
            ax.set_title(title)
            ax.axis("off")
        plt.show()

    # Original (misspelled) public name, kept as an alias for existing callers.
    diplay_sample = display_sample

In [None]:
# Augmentation pipeline applied to every (image, mask) pair.
# Normalize comes last: the brightness/contrast jitter is meant to act on raw
# pixel intensities, not on values that are already mean/std-standardised.
transform = A.Compose([
    A.Resize(width=IMG_WIDTH, height=IMG_HEIGHT),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Normalize(),
])

In [None]:
# Dataset over the selected pairs, using the augmentation pipeline defined above.
ds = SegDataset(df=df, transforms=transform)

In [None]:
# Fetch the sample once: every ds[...] access re-reads both files from disk
# and re-runs the augmentation pipeline.
sample = ds[0]
sample["image"].shape, sample["mask"].shape

In [None]:
# Sanity check: the dataset returns the mask as a torch tensor.
type(ds[0]["mask"])

In [None]:
# Visual check of the first sample: image, cutout and alpha mask.
ds.diplay_sample(0)

In [None]:
class DoubleConvSame(nn.Module):
    """Two 3x3 convolutions with padding 1, each followed by an in-place ReLU.

    Args:
        c_in: Number of input channels.
        c_out: Number of channels produced by both convolutions.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        # Built in the same order as they run, so parameter names stay conv.0 / conv.2.
        first = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
        second = nn.Conv2d(c_out, c_out, kernel_size=3, padding=1)
        self.conv = nn.Sequential(
            first,
            nn.ReLU(inplace=True),
            second,
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        out = self.conv(x)
        return out
    
class Encoder(nn.Module):
    """One contracting stage: a double conv that doubles the channels, then 2x2 max-pool.

    ``forward`` returns ``(features, pooled)``: the pre-pool feature map and
    its max-pooled version.
    """

    def __init__(self, in_channels):
        super().__init__()
        doubled = in_channels * 2
        self.conv = DoubleConvSame(c_in=in_channels, c_out=doubled)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.conv(x)
        return features, self.pool(features)

In [None]:
class AttentionDecoder(nn.Module):
    """One expanding stage: upsample, apply attention weights, concatenate, double conv.

    Args:
        in_channels: Channels of the tensor to upsample; the stage outputs
            ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        half = in_channels // 2
        self.up_conv = DoubleConvSame(c_in=in_channels, c_out=half)
        self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)

    def forward(self, conv1, conv2, attn):
        """Upsample ``conv1``, scale it by ``attn``, join with ``conv2`` and convolve."""
        upsampled = self.up(conv1)
        # NOTE(review): the weights scale the upsampled tensor, not the skip
        # tensor `conv2` -- confirm this is the intended gating target.
        gated = attn * upsampled
        merged = torch.cat((gated, conv2), dim=1)
        return self.up_conv(merged)


class AttentionBlock(nn.Module):
    """Additive attention gate producing a single-channel weight map.

    Args:
        g_chl: Channels of the gating signal ``g``.
        x_chl: Channels of the feature map ``x``; the intermediate width is
            ``x_chl // 4``.
    """

    def __init__(self, g_chl, x_chl):
        super().__init__()
        hidden = x_chl // 4

        # 1x1 conv with stride 2 on `x`, 1x1 conv with stride 1 on `g`.
        self.conv_x = nn.Conv2d(x_chl, hidden, kernel_size=1, stride=2)
        self.conv_g = nn.Conv2d(g_chl, hidden, kernel_size=1, stride=1)
        # 1x1 conv collapsing the summed features to one channel.
        self.psi = nn.Conv2d(hidden, 1, kernel_size=1, stride=1)
        # Doubles height and width of the weight map.
        self.upsample = nn.Upsample(scale_factor=2)

    def forward(self, g, x):
        """Return weights in [0, 1] with one channel and twice the size of ``conv_x(x)``."""
        combined = torch.relu(self.conv_x(x) + self.conv_g(g))
        weights = torch.sigmoid(self.psi(combined))
        return self.upsample(weights)


class AttentionUNet(nn.Module):
    """U-Net with attention gates on the expanding path.

    Encoder widths are 64, 128, 256 and 512, with a 1024-channel bottleneck.
    The input is max-pooled four times, so its height and width must be
    divisible by 16 for the skip connections to line up.

    Args:
        c_in: Number of input channels.
        c_out: Number of output channels (raw scores, no activation applied).
    """

    def __init__(self, c_in, c_out):
        super().__init__()

        self.conv1 = DoubleConvSame(c_in=c_in, c_out=64)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.enc1 = Encoder(64)
        self.enc2 = Encoder(128)
        self.enc3 = Encoder(256)
        # NOTE(review): enc4 is never called in forward(). It is kept so that
        # parameter names and existing checkpoints stay compatible.
        self.enc4 = Encoder(512)

        self.conv5 = DoubleConvSame(c_in=512, c_out=1024)

        self.attn1 = AttentionBlock(1024, 512)
        self.attn2 = AttentionBlock(512, 256)
        self.attn3 = AttentionBlock(256, 128)
        self.attn4 = AttentionBlock(128, 64)

        self.attndeco1 = AttentionDecoder(1024)
        self.attndeco2 = AttentionDecoder(512)
        self.attndeco3 = AttentionDecoder(256)
        self.attndeco4 = AttentionDecoder(128)

        self.conv_1x1 = nn.Conv2d(in_channels=64, out_channels=c_out, kernel_size=1)

    def forward(self, x):
        """Return a ``c_out``-channel map with the same height/width as ``x``."""
        # --- Encoder ---
        c1 = self.conv1(x)
        p1 = self.pool(c1)

        c2, p2 = self.enc1(p1)
        c3, p3 = self.enc2(p2)
        c4, p4 = self.enc3(p3)

        # --- Bottleneck ---
        c5 = self.conv5(p4)

        # --- Decoder with attention ---
        # Each stage upsamples the previous decoder output, so information
        # from the bottleneck flows through the whole expanding path instead
        # of only reaching it via the attention gates.
        att1 = self.attn1(c5, c4)
        uc1 = self.attndeco1(c5, c4, att1)

        att2 = self.attn2(uc1, c3)
        uc2 = self.attndeco2(uc1, c3, att2)

        att3 = self.attn3(uc2, c2)
        uc3 = self.attndeco3(uc2, c2, att3)

        att4 = self.attn4(uc3, c1)
        uc4 = self.attndeco4(uc3, c1, att4)

        return self.conv_1x1(uc4)

In [None]:
# Instantiate the network with 3 input channels and 1 output channel.
attn_unet = AttentionUNet(3, 1)

In [None]:
# Display the module tree to inspect the layers.
attn_unet

In [None]:
class SegDataModule(pl.LightningDataModule):
    """Lightning data module that wraps ``SegDataset`` in an 80/20 train/val split.

    Args:
        df: DataFrame handed to ``SegDataset``; the dataset is built with the
            notebook-level ``transform`` pipeline.
    """

    def __init__(self, df):
        super().__init__()
        self.dataset = SegDataset(df, transforms=transform)

    def setup(self, stage=None) -> None:
        """Create the random train/validation subsets for the ``fit`` stage."""
        if stage == "fit" or stage is None:
            total = len(self.dataset)
            n_train = int(total * 0.8)
            # Give the remainder to validation so the two lengths always sum
            # to `total`; two independent int() truncations can fall short,
            # which makes random_split raise.
            n_val = total - n_train
            self.train_data, self.val_data = random_split(self.dataset, [n_train, n_val])

    def train_dataloader(self):
        # Reshuffle the training subset every epoch.
        return DataLoader(
            self.train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2
        )

    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=BATCH_SIZE, num_workers=2)

In [None]:
class LitModel(pl.LightningModule):
    """Lightning wrapper training ``AttentionUNet(3, 1)`` with BCE-with-logits loss.

    Logs ``train_loss`` / ``train_acc`` and ``val_loss`` / ``val_acc``.
    """

    def __init__(self):
        super().__init__()
        self.model = AttentionUNet(3, 1)
        self.loss = nn.BCEWithLogitsLoss()
        # Current torchmetrics releases require the task to be named explicitly.
        self.train_acc = torchmetrics.Accuracy(task="binary")
        self.val_acc = torchmetrics.Accuracy(task="binary")

    def configure_optimizers(self):
        return optim.Adam(self.model.parameters())

    def forward(self, images):
        """Return the network's raw scores (logits) for a batch of images."""
        return self.model(images)

    def _shared_step(self, batch, metric, prefix):
        """Compute loss and accuracy for one batch and log both under ``prefix``."""
        images = batch["image"]
        masks = batch["mask"]

        logits = self.forward(images)
        loss = self.loss(input=logits, target=masks)
        # The loss consumes logits; the metric gets probabilities so that its
        # 0.5 decision threshold is applied on the right scale.
        acc = metric(torch.sigmoid(logits), masks.int())

        self.log(f"{prefix}_loss", loss)
        self.log(f"{prefix}_acc", acc, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, self.train_acc, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, self.val_acc, "val")

In [None]:
model = LitModel()
dm = SegDataModule(df=df)

# Keep the three checkpoints with the highest validation accuracy.
checkpoint_callback = ModelCheckpoint(
    dirpath="../working/models", 
    monitor="val_acc", 
    mode="max", 
    verbose=True,
    save_top_k=3,
    filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}'
)

# Stop once validation accuracy has not improved for 3 consecutive checks.
early_stop_callback = EarlyStopping(
    monitor="val_acc", 
    min_delta=0.00, 
    patience=3, 
    verbose=True, 
    mode="max"
)

trainer = pl.Trainer(
    logger=True,
    max_epochs=EPOCHS,
    # "auto" uses a GPU when one is present and falls back to CPU otherwise,
    # so the notebook also runs on runtimes without an accelerator.
    accelerator="auto",
    callbacks=[checkpoint_callback, early_stop_callback],
)

trainer.fit(model, datamodule=dm)

In [None]:
def predict(image_path, model=None):
    """Segment one image with the trained model and plot the result.

    The image is resized to ``IMG_WIDTH`` x ``IMG_HEIGHT``, normalised with
    ``A.Normalize()`` (the same normalisation the training pipeline applies)
    and passed through the network. Pixels whose predicted probability is
    at least 0.5 become opaque in the RGBA "Prediction" panel.

    Args:
        image_path: Path object (exposing ``as_posix()``) of the image to segment.
        model: Optional already-loaded model. When omitted, the best checkpoint
            recorded by ``checkpoint_callback`` is loaded onto the CPU. Pass a
            model to avoid reloading the checkpoint on every call. The model is
            switched to eval mode.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    bgr = cv2.imread(image_path.as_posix(), cv2.IMREAD_UNCHANGED)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # cv2.resize expects dsize as (width, height).
    test_image_copy = cv2.resize(rgb, (IMG_WIDTH, IMG_HEIGHT))

    # Same normalisation as training, then HWC -> CHW plus a batch dimension.
    normalized = A.Compose([A.Normalize()])(image=test_image_copy)["image"]
    test_image = np.transpose(normalized, (2, 0, 1))
    test_image = torch.unsqueeze(torch.tensor(test_image, dtype=torch.float32), 0)

    if model is None:
        model = LitModel.load_from_checkpoint(
            checkpoint_callback.best_model_path, map_location="cpu"
        )
    model.eval()
    # Run on whatever device the model's weights live on.
    device = next(model.parameters()).device

    with torch.no_grad():
        logits = model(test_image.to(device))

    # The network emits logits; convert to probabilities before thresholding.
    probs = torch.sigmoid(logits)[0].cpu().numpy()
    probs = np.transpose(probs, (1, 2, 0))  # (1, H, W) -> (H, W, 1)
    preds_test_thresh = (probs >= 0.5).astype(np.uint8)
    alpha_preds = preds_test_thresh * 255
    # RGB image plus the predicted mask as alpha channel.
    predicted_mask = np.concatenate((test_image_copy, alpha_preds), axis=-1)

    fig, axs = plt.subplots(1, 2)

    axs[0].imshow(test_image_copy)
    axs[0].set_title("Image")
    axs[0].axis("off")
    axs[1].imshow(predicted_mask)
    axs[1].set_title("Prediction")
    axs[1].axis("off")
    plt.show()

In [None]:
# Run the predictor on ten randomly chosen images from index 5001 onwards.
for _ in range(10):
    random_index = np.random.randint(5001, len(image_paths))
    predict(image_paths[random_index])