<a href="https://colab.research.google.com/github/SolutionLr/DL/blob/main/Unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# import necessary modules 
from google.colab import drive
drive.mount('/content/drive')
from torchvision import datasets
import torchvision.transforms as transforms
from torch import nn
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from unet_dataset import Data
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import torchvision
import segmentation_models_pytorch as smp
import albumentations as A
import pandas as pd
import cv2
torch.manual_seed(0)


# Load the class colour map from class_dict.csv (columns: name, r, g, b).
mask_data = pd.read_csv('/content/drive/MyDrive/DL_1/U-Net/unet_dataset/class_dict.csv')
# name -> [r, g, b]. Rows are kept in CSV order; a repeated name would keep only
# its last row, which would shift every later class index in mask_colors.
mask_data = {
    name: [r, g, b]
    for name, r, g, b in zip(mask_data['name'], mask_data['r'], mask_data['g'], mask_data['b'])
}
# Array with one [r, g, b] row per class, so mask_colors[class_ids] maps an
# array of predicted class indices to RGB colours.
# NOTE(review): assumes the CSV row order matches the class indices used in the
# label masks -- confirm against how the dataset encodes labels.
mask_colors = np.array(list(mask_data.values()))


# Create the model
class UNET(nn.Module):
    """U-Net with an ImageNet-pretrained EfficientNet-B0 encoder, partly frozen.

    Wraps ``smp.Unet`` configured for 3 input channels and 32 output channels
    with no final activation, so ``forward`` returns raw per-class logits.
    """

    # Every parameter registered before this one (in named_parameters() order)
    # is frozen; this parameter and everything after it stay trainable.
    _FREEZE_UNTIL = 'unet.encoder._blocks.15._project_conv.weight'

    def __init__(self):
        super(UNET, self).__init__()
        self.unet = smp.Unet(encoder_name="efficientnet-b0", encoder_weights="imagenet", in_channels=3, classes=32, activation=None)
        named = list(self.named_parameters())
        if self._FREEZE_UNTIL not in {name for name, _ in named}:
            # Without this check, a missing/renamed boundary would freeze *every*
            # parameter and training would only fail later inside backward().
            raise ValueError(
                f"UNET: freeze boundary {self._FREEZE_UNTIL!r} not found among model parameters")
        for name, param in named:
            if name == self._FREEZE_UNTIL:
                break
            param.requires_grad = False

    def forward(self, x):
        """Return the raw logits produced by the wrapped ``smp.Unet``."""
        return self.unet(x)


# Calculate mean and std
def std_mean(loader):
    """Return the per-channel ``(mean, std)`` over every pixel a loader yields.

    Note the order: unlike ``torch.std_mean``, the mean comes first.

    Args:
        loader: iterable of ``(images, labels)`` batches. ``images`` is reduced
            over dims ``[0, 2, 3]``, i.e. it is treated as
            ``(batch, channel, height, width)``. Labels are ignored.

    Returns:
        Two 1-D tensors (one entry per channel): the mean and the population
        standard deviation of that channel over all pixels of all images.

    Raises:
        ValueError: if the loader yields no pixels.
    """
    total = None      # per-channel sum of pixel values
    total_sq = None   # per-channel sum of squared pixel values
    n_pixels = 0      # pixels per channel seen so far
    out_dtype = torch.get_default_dtype()
    for img, _ in loader:
        # Accumulate in float64: dataset-wide float32 sums lose precision, and
        # E[x^2] - E[x]^2 is sensitive to that.
        x = img.double()
        batch_sum = x.sum(dim=[0, 2, 3])
        batch_sq = (x * x).sum(dim=[0, 2, 3])
        total = batch_sum if total is None else total + batch_sum
        total_sq = batch_sq if total_sq is None else total_sq + batch_sq
        # Count pixels rather than batches so a short final batch is weighted correctly.
        n_pixels += x.shape[0] * x.shape[2] * x.shape[3]
        if img.is_floating_point():
            out_dtype = img.dtype
    if n_pixels == 0:
        raise ValueError("std_mean: loader yielded no pixels")
    mean = total / n_pixels
    # Var = E[x^2] - E[x]^2; clamp tiny negative round-off before the sqrt.
    var = (total_sq / n_pixels - mean * mean).clamp_min(0.0)
    return mean.to(out_dtype), var.sqrt().to(out_dtype)


# Show the colour-coded predicted mask for one test image
def show_predict(loader, model, transform):
    """Plot the colour-coded predicted mask for the first image of one batch.

    Args:
        loader: iterable of ``(images, labels)`` batches; only the first batch
            is used and its labels are ignored.
        model: segmentation model whose output has per-class scores on dim 1
            (``argmax(1)`` picks the class per pixel).
        transform: unused; kept so existing calls keep working.
    """
    model.eval()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    images, _ = next(iter(loader))
    model.to(device)
    # Inference only: no_grad avoids building an autograd graph for the batch.
    with torch.no_grad():
        logits = model(images.to(device))
    pred = logits.argmax(1).cpu().numpy()
    # Map each predicted class index to its RGB colour via the module-level table.
    image = mask_colors[pred]
    # figsize is in inches; the previous (100, 100) produced an enormous raster.
    plt.figure(figsize=(10, 10))
    plt.imshow(image[0])
    plt.show()


# Training, validation and test loops
def train_loop(model, train_loader, loss, optimizer, device=None):
    """Train ``model`` for one pass over ``train_loader``.

    On every 10th batch the loss and smp segmentation metrics for that batch
    are printed; at the end the mean batch loss is printed and returned.

    Args:
        model: module whose output has per-class logits on dim 1.
        train_loader: iterable of batches where ``data[0]`` is the input and
            ``data[1]`` the target; dim 1 of the target is squeezed
            (presumably a singleton channel dim -- confirm against Data).
        loss: criterion called as ``loss(logits, target)``.
        optimizer: optimizer stepping the model's parameters.
        device: device batches are moved to; defaults to the device of the
            model's first parameter.

    Returns:
        Mean per-batch loss as a float, or ``nan`` if the loader is empty.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    loss_sum = 0.0
    n_batches = 0
    for i, data in enumerate(train_loader):
        X = data[0].to(device)
        y = data[1].to(device).squeeze(1)
        pred = model(X)
        L = loss(pred, y)
        optimizer.zero_grad()
        L.backward()
        optimizer.step()
        # detach() keeps the running total out of the graph; deferring .item()
        # avoids a host/device sync on every batch.
        loss_sum = loss_sum + L.detach()
        n_batches += 1
        if i % 10 == 0:
            tp, fp, fn, tn = smp.metrics.get_stats(pred.argmax(1), y, mode='multiclass', num_classes=32)
            iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
            f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
            f2_score = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro")
            # NOTE(review): accuracy ("macro") and recall ("micro-imagewise") use
            # different reductions from the other metrics; kept as originally written.
            accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")
            recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
            print(f"train_loss: {L.item()}, iou_score: {iou_score}, f1_score: {f1_score}, "
                  f"f2_score: {f2_score}, accuracy: {accuracy}, recall: {recall}")
    mean_loss = float(loss_sum / n_batches) if n_batches else float("nan")
    print(f"train_mean_loss: {mean_loss}")
    return mean_loss


def val_loop(model, val_loader, loss, device=None):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    On every 10th batch the loss and smp segmentation metrics for that batch
    are printed; at the end the mean batch loss is printed and returned.

    Args:
        model: module whose output has per-class logits on dim 1.
        val_loader: iterable of batches where ``data[0]`` is the input and
            ``data[1]`` the target; dim 1 of the target is squeezed
            (presumably a singleton channel dim -- confirm against Data).
        loss: criterion called as ``loss(logits, target)``.
        device: device batches are moved to; defaults to the device of the
            model's first parameter.

    Returns:
        Mean per-batch loss as a float, or ``nan`` if the loader is empty.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    loss_sum = 0.0
    n_batches = 0
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            X = data[0].to(device)
            y = data[1].to(device).squeeze(1)
            pred = model(X)
            L = loss(pred, y)
            loss_sum = loss_sum + L
            n_batches += 1
            if i % 10 == 0:
                tp, fp, fn, tn = smp.metrics.get_stats(pred.argmax(1), y, mode='multiclass', num_classes=32)
                iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
                f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
                f2_score = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro")
                # NOTE(review): accuracy ("macro") and recall ("micro-imagewise") use
                # different reductions from the other metrics; kept as originally written.
                accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")
                recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
                # Labelled "val_loss": this loop previously printed "test_loss".
                print(f"val_loss: {L.item()}, iou_score: {iou_score}, f1_score: {f1_score}, "
                      f"f2_score: {f2_score}, accuracy: {accuracy}, recall: {recall}")
    mean_loss = float(loss_sum / n_batches) if n_batches else float("nan")
    print(f"val_mean_loss: {mean_loss}")
    return mean_loss


def test_loop(model, test_loader, loss, device=None):
    """Evaluate ``model`` on ``test_loader`` without updating it.

    On every 10th batch the loss and smp segmentation metrics for that batch
    are printed; at the end the mean batch loss is printed and returned.

    Args:
        model: module whose output has per-class logits on dim 1.
        test_loader: iterable of batches where ``data[0]`` is the input and
            ``data[1]`` the target; dim 1 of the target is squeezed
            (presumably a singleton channel dim -- confirm against Data).
        loss: criterion called as ``loss(logits, target)``.
        device: device batches are moved to; defaults to the device of the
            model's first parameter.

    Returns:
        Mean per-batch loss as a float, or ``nan`` if the loader is empty.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    loss_sum = 0.0
    n_batches = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            X = data[0].to(device)
            y = data[1].to(device).squeeze(1)
            pred = model(X)
            L = loss(pred, y)
            loss_sum = loss_sum + L
            n_batches += 1
            if i % 10 == 0:
                tp, fp, fn, tn = smp.metrics.get_stats(pred.argmax(1), y, mode='multiclass', num_classes=32)
                iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
                f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
                f2_score = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro")
                # NOTE(review): accuracy ("macro") and recall ("micro-imagewise") use
                # different reductions from the other metrics; kept as originally written.
                accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")
                recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
                print(f"test_loss: {L.item()}, iou_score: {iou_score}, f1_score: {f1_score}, "
                      f"f2_score: {f2_score}, accuracy: {accuracy}, recall: {recall}")
    mean_loss = float(loss_sum / n_batches) if n_batches else float("nan")
    print(f"test_mean_loss: {mean_loss}")
    return mean_loss


# Hyper-parameters and training setup
loss_fn = nn.CrossEntropyLoss()
epoch = 10  # number of training epochs
# NOTE(review): 0.03 is 30x torch.optim.Adam's default lr (1e-3) and high for
# fine-tuning a pretrained encoder; confirm this value is intentional.
lr = 0.03
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = UNET()
model.to(device)
# Hand Adam only the parameters UNET left trainable; frozen ones never receive
# gradients, so this is equivalent but makes the intent explicit.
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)
# NOTE: this dict rebinds the name `transforms`, shadowing the
# `torchvision.transforms` module imported at the top of the notebook.
transforms = {
    'train': A.Compose([
        A.Resize(320, 320),
        A.Normalize((0.2855, 0.2922, 0.2972), (0.3529, 0.3608, 0.3633)),
        ToTensorV2(),
    ]),
    # NOTE(review): unlike 'train', this pipeline has no Normalize/ToTensorV2.
    # How Data chooses between the two keys is not visible here -- if 'val' is
    # applied to the val/test splits, those images reach the model un-normalised.
    'val': A.Compose([A.Resize(320, 320), ]),
}

# Load data
DATA_ROOT = '/content/drive/MyDrive/DL_1/U-Net/unet_dataset'
BATCH_SIZE = 32
train_loader = Data(f"{DATA_ROOT}/train", f"{DATA_ROOT}/train_labels", transform=transforms).load_data(BATCH_SIZE)
test_loader = Data(f"{DATA_ROOT}/test", f"{DATA_ROOT}/test_labels", transform=transforms).load_data(BATCH_SIZE)
val_loader = Data(f"{DATA_ROOT}/val", f"{DATA_ROOT}/val_labels", transform=transforms).load_data(BATCH_SIZE)

# Train (validating after every epoch), then evaluate on the test split, save and visualise.
for epoch_idx in range(epoch):
    print(f"epoch {epoch_idx + 1}/{epoch}")
    train_loop(model, train_loader, loss_fn, optimizer)
    val_loop(model, val_loader, loss_fn)
test_loop(model, test_loader, loss_fn)
torch.save(model.state_dict(), '/content/drive/MyDrive/DL_1/U-Net/Unet.pth')
show_predict(test_loader, model, transform=transforms)
