<a href="https://colab.research.google.com/github/VanessaABC123/SeqNet/blob/master/%E3%80%8CRITE_UNET_using_PyTorch%E3%80%8D%E7%9A%84%E5%89%AF%E6%9C%AC2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:

# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES
# TO THE CORRECT LOCATION (/kaggle/input) IN YOUR NOTEBOOK,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

import os
import sys
from tempfile import NamedTemporaryFile
from urllib.request import urlopen
from urllib.parse import unquote, urlparse
from urllib.error import HTTPError
from zipfile import ZipFile
import tarfile
import shutil

CHUNK_SIZE = 40960
DATA_SOURCE_MAPPING = 'ritedataset:https%3A%2F%2Fstorage.googleapis.com%2Fkaggle-data-sets%2F2137503%2F3556298%2Fbundle%2Farchive.zip%3FX-Goog-Algorithm%3DGOOG4-RSA-SHA256%26X-Goog-Credential%3Dgcp-kaggle-com%2540kaggle-161607.iam.gserviceaccount.com%252F20240911%252Fauto%252Fstorage%252Fgoog4_request%26X-Goog-Date%3D20240911T081823Z%26X-Goog-Expires%3D259200%26X-Goog-SignedHeaders%3Dhost%26X-Goog-Signature%3D7079c24e1c58f9c944f1095623cd9d98a7d197dfb9ce364158aeca3247fec57d34cecd03d61af4494f4395500c249f0edb089399f49488817024ee92b56b6181effe09d223495e5379185ab0ea1b49311e63e403dbc8096c044e41f680537034878cad26b33f248ff2e417b7af4a1a1b89b8e97806836ed1880ca60cf5110aa724762838a344a6a2ea20cab85225f8a61e2610f56a4995a6154177404541eaeeead6e31afd9c9c63147a0f4e63c36752b801e0b0369b70f8559edeb4ecbcdb202f7d4ab37cf2dde58e88252686086816e15fc63fb2fab3ff998a69e8e29ac2588e8abab29b2607e373954af182cdf7cfc959046b0bdb1fa5561e14787312fa68'

KAGGLE_INPUT_PATH='/kaggle/input'
KAGGLE_WORKING_PATH='/kaggle/working'
KAGGLE_SYMLINK='kaggle'

!umount /kaggle/input/ 2> /dev/null
shutil.rmtree('/kaggle/input', ignore_errors=True)
os.makedirs(KAGGLE_INPUT_PATH, 0o777, exist_ok=True)
os.makedirs(KAGGLE_WORKING_PATH, 0o777, exist_ok=True)

try:
  os.symlink(KAGGLE_INPUT_PATH, os.path.join("..", 'input'), target_is_directory=True)
except FileExistsError:
  pass
try:
  os.symlink(KAGGLE_WORKING_PATH, os.path.join("..", 'working'), target_is_directory=True)
except FileExistsError:
  pass

# Download every "<directory>:<url-encoded URL>" entry of DATA_SOURCE_MAPPING
# and unpack the archive into KAGGLE_INPUT_PATH/<directory>.
for data_source_mapping in DATA_SOURCE_MAPPING.split(','):
    # Split on the first ':' only, so a ':' left unescaped in the URL part
    # cannot break the two-value unpacking.
    directory, download_url_encoded = data_source_mapping.split(':', 1)
    download_url = unquote(download_url_encoded)
    filename = urlparse(download_url).path
    destination_path = os.path.join(KAGGLE_INPUT_PATH, directory)
    try:
        with urlopen(download_url) as fileres, NamedTemporaryFile() as tfile:
            # Content-Length is optional; without a usable value the
            # progress bar simply stays empty instead of crashing.
            content_length = fileres.headers['content-length']
            if content_length and content_length.isdigit():
                total_length = int(content_length)
            else:
                total_length = 0
            print(f'Downloading {directory}, {content_length} bytes compressed')
            dl = 0
            data = fileres.read(CHUNK_SIZE)
            while len(data) > 0:
                dl += len(data)
                tfile.write(data)
                done = min(50, int(50 * dl / total_length)) if total_length > 0 else 0
                sys.stdout.write(f"\r[{'=' * done}{' ' * (50-done)}] {dl} bytes downloaded")
                sys.stdout.flush()
                data = fileres.read(CHUNK_SIZE)
            # The tar branch reopens the temp file by name, so buffered bytes
            # must reach the disk first or the archive would look truncated.
            tfile.flush()
            if filename.endswith('.zip'):
                with ZipFile(tfile) as zfile:
                    zfile.extractall(destination_path)
            else:
                # Bind to a name other than `tarfile` so the module is not
                # shadowed for the next data source.
                with tarfile.open(tfile.name) as tar_archive:
                    tar_archive.extractall(destination_path)
            print(f'\nDownloaded and uncompressed: {directory}')
    except HTTPError:
        print(f'Failed to load (likely expired) {download_url} to path {destination_path}')
        continue
    except OSError:
        print(f'Failed to load {download_url} to path {destination_path}')
        continue

print('Data source import complete.')


Downloading ritedataset, 35212627 bytes compressed
Downloaded and uncompressed: ritedataset
Data source import complete.


In [4]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import os
import glob
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from albumentations import HorizontalFlip, VerticalFlip, Rotate
import tqdm
import torch.nn.functional as F
import matplotlib.image as mpimg

In [5]:
# Fix the NumPy and PyTorch random generators to the same seed so runs
# are repeatable.
np.random.seed(42)
torch.cuda.manual_seed_all(42)
torch.manual_seed(42)

In [6]:
class conv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    The first convolution maps ``in_channels`` to ``out_channels``; the
    second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # First conv -> batch-norm pair.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Second conv -> batch-norm pair.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # One activation module is shared by both stages.
        self.relu = nn.ReLU()

    def forward(self, images):
        """Apply conv -> bn -> relu twice and return the resulting feature map."""
        first_stage = self.relu(self.bn1(self.conv1(images)))
        return self.relu(self.bn2(self.conv2(first_stage)))


In [7]:
class encoder(nn.Module):
    """Encoder stage: a ``conv`` block followed by 2x2 max pooling."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv = conv(in_channels, out_channels)
        self.pool = nn.MaxPool2d((2,2))

    def forward(self, images):
        """Return ``(features, pooled)``: the conv output and its pooled version."""
        features = self.conv(images)
        return features, self.pool(features)

In [8]:
class decoder(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, ``conv`` block."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
        # The conv block takes out_channels * 2 inputs: the upsampled map
        # concatenated with the skip tensor.
        self.conv = conv(out_channels * 2, out_channels)

    def forward(self, images, prev):
        """Upsample ``images``, concatenate ``prev`` on dim 1 and convolve."""
        upsampled = self.upconv(images)
        merged = torch.cat([upsampled, prev], dim=1)
        return self.conv(merged)


In [9]:
class UNet(nn.Module):
    """U-Net built from four encoder stages, a bottleneck and four decoder stages.

    The network takes 3-channel input and produces a 1-channel map that is
    passed through a sigmoid before being returned.
    """

    def __init__(self):
        super().__init__()

        # Contracting path.
        self.e1 = encoder(3, 64)
        self.e2 = encoder(64, 128)
        self.e3 = encoder(128, 256)
        self.e4 = encoder(256, 512)

        # Bottleneck.
        self.b = conv(512, 1024)

        # Expanding path.
        self.d1 = decoder(1024, 512)
        self.d2 = decoder(512, 256)
        self.d3 = decoder(256, 128)
        self.d4 = decoder(128, 64)

        # 1x1 convolution down to a single output channel.
        self.output = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, images):
        """Return the sigmoid-activated single-channel output for ``images``."""
        skip1, down1 = self.e1(images)
        skip2, down2 = self.e2(down1)
        skip3, down3 = self.e3(down2)
        skip4, down4 = self.e4(down3)

        bottleneck = self.b(down4)

        # Each decoder stage consumes the matching encoder skip tensor,
        # deepest first.
        up = self.d1(bottleneck, skip4)
        up = self.d2(up, skip3)
        up = self.d3(up, skip2)
        up = self.d4(up, skip1)

        return torch.sigmoid(self.output(up))

In [10]:
class LoadData(Dataset):
    """Dataset of (image, mask) pairs read from two parallel lists of file paths.

    Args:
        images_path: sequence of image file paths.
        masks_path: sequence of mask file paths; ``masks_path[i]`` is used
            as the mask for ``images_path[i]``.

    Each item is a tuple ``(img, mask)`` of float64 tensors scaled to
    [0, 1]: ``img`` is channel-first with 3 channels, ``mask`` has a single
    leading channel.
    """

    def __init__(self, images_path, masks_path):
        super().__init__()

        self.images_path = images_path
        self.masks_path = masks_path
        self.len = len(images_path)

    def __getitem__(self, idx):
        # Force RGB so grayscale, palette or RGBA files still yield the
        # 3-channel HWC array the transpose below expects. The context
        # manager closes the file handle PIL would otherwise keep open.
        with Image.open(self.images_path[idx]) as image_file:
            img = np.asarray(image_file.convert('RGB'))
        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
        img = img/255.0
        img = torch.tensor(img)

        with Image.open(self.masks_path[idx]) as mask_file:
            mask = np.asarray(mask_file.convert('L'))
        mask = np.expand_dims(mask, axis=0)  # add the channel dimension
        mask = mask/255.0
        mask = torch.tensor(mask)

        return img, mask

    def __len__(self):
        return self.len

In [11]:
# Gather the image and mask file lists of every split, each sorted in place
# (presumably so the i-th image pairs with the i-th mask -- verify that the
# file names sort consistently).
train_X = glob.glob('../input/ritedataset/train/images/*')
train_X.sort()
train_y = glob.glob('../input/ritedataset/train/masks/*')
train_y.sort()

test_X = glob.glob('../input/ritedataset/test/images/*')
test_X.sort()
test_y = glob.glob('../input/ritedataset/test/masks/*')
test_y.sort()

valid_X = glob.glob('../input/ritedataset/validation/images/*')
valid_X.sort()
valid_y = glob.glob('../input/ritedataset/validation/masks/*')
valid_y.sort()

In [12]:
len(test_X)

10

In [13]:
# Presumably the image height/width; none of these three names is
# referenced elsewhere in the visible code.
H = W = 512
size = (H, W)

# Training settings.
batch_size = 2
num_epochs = 50
lr = 1e-4
checkpoint_path = "./checkpoint.pth"

# Datasets for the training and validation splits.
train_dataset, valid_dataset = (
    LoadData(train_X, train_y),
    LoadData(valid_X, valid_y),
)


In [14]:
# Training batches are reshuffled each epoch; validation keeps file order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [15]:

# Use the GPU when PyTorch can see one and fall back to the CPU otherwise,
# instead of switching by commenting code in and out.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet()
model = model.to(device)

In [16]:
class DiceBCELoss(nn.Module):
    """Binary cross-entropy plus soft Dice loss, computed over all elements.

    ``inputs`` must already be probabilities in [0, 1], not logits, because
    they are passed to binary cross-entropy unchanged.
    """

    def __init__(self, weight=None, size_average=True):
        # `weight` and `size_average` are accepted but not used.
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``BCE(inputs, targets) + (1 - Dice(inputs, targets))``.

        Args:
            inputs: predicted probabilities, any shape.
            targets: ground-truth tensor with the same number of elements.
            smooth: constant added to the Dice numerator and denominator so
                the ratio stays defined when both tensors sum to zero.
        """
        # reshape (unlike view) also accepts non-contiguous tensors.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
        # The functional form avoids building a new BCELoss module per call.
        BCE = F.binary_cross_entropy(inputs, targets)
        Dice_BCE = BCE + dice_loss

        return Dice_BCE

In [17]:
# Adam optimizer plus a scheduler that lowers the learning rate once the
# monitored value has stopped decreasing for `patience` scheduler steps.
# The deprecated `verbose` flag is dropped; the current rate can be read
# from optimizer.param_groups instead.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)
loss_fn = DiceBCELoss()



In [18]:
def train_model(model, loader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Every batch is moved to ``device`` as float32, the loss is
    back-propagated and the optimizer takes one step. The model is left in
    training mode.
    """
    model.train()

    running_total = 0.0
    for inputs, targets in loader:
        inputs = inputs.to(device, dtype=torch.float32)
        targets = targets.to(device, dtype=torch.float32)

        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        running_total += batch_loss.item()

    # Average over the number of batches, not the number of samples.
    return running_total/len(loader)

In [19]:
def evaluate(model, loader, loss_fn, device):
    """Return the mean batch loss of ``model`` over ``loader`` without training.

    The model is switched to eval mode and gradients are disabled; every
    batch is moved to ``device`` as float32 first.
    """
    model.eval()

    running_total = 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device, dtype=torch.float32)
            targets = targets.to(device, dtype=torch.float32)

            running_total += loss_fn(model(inputs), targets).item()

        # Average over the number of batches, not the number of samples.
        mean_loss = running_total/len(loader)
    return mean_loss

In [None]:
# Per-epoch loss history.
train = []
valid = []

best_valid_loss = float("inf")

for epoch in range(num_epochs):
    train_loss = train_model(model, train_loader, optimizer, loss_fn, device)
    valid_loss = evaluate(model, valid_loader, loss_fn, device)

    # ReduceLROnPlateau only lowers the learning rate when it is fed the
    # monitored metric; without this call the scheduler never acts.
    scheduler.step(valid_loss)

    train.append(train_loss)
    valid.append(valid_loss)

    # Keep a checkpoint of the best model seen so far.
    if valid_loss < best_valid_loss:
        data_str = f"Valid loss improved from {best_valid_loss:2.4f} to {valid_loss:2.4f}. Saving checkpoint: {checkpoint_path}"
        print(data_str)

        best_valid_loss = valid_loss
        # NOTE(review): this saves the whole module object rather than its
        # state_dict; kept as-is because no loading code is visible here.
        torch.save(model, checkpoint_path)

    data_str = f'Epoch: {epoch+1:02}\n'
    data_str += f'\tTrain Loss: {train_loss:.3f}\n'
    data_str += f'\t Val. Loss: {valid_loss:.3f}\n'
    print(data_str)

In [None]:
# Test split, served two samples at a time in file order.
test_dataset = LoadData(images_path=test_X, masks_path=test_y)
test_loader = DataLoader(test_dataset, batch_size=2)

In [None]:
# Show the first test batch as a 2x3 grid: one row per sample, columns are
# input image / ground-truth mask / predicted mask. The batch must hold at
# least two samples.
transform = transforms.ToPILImage()

# Inference only: use eval-mode layer behaviour and skip gradient tracking.
model.eval()
with torch.no_grad():
    for x, y in test_loader:
        image0 = transform(x[0])
        image1 = transform(x[1])

        x = x.to(device, dtype=torch.float32)
        y = y.to(device, dtype=torch.float32)

        y_pred = model(x)
        img = y_pred.cpu().detach().numpy()

        # Pass figsize to subplots directly; a separate plt.figure() call
        # would only leave an empty extra figure behind.
        f, axarr = plt.subplots(2, 3, figsize=(30, 8))

        axarr[0,0].imshow(image0)
        axarr[0,1].imshow(np.squeeze(y.cpu().detach().numpy())[0], cmap='gray')
        axarr[0,2].imshow(np.squeeze(img)[0], cmap='gray')

        axarr[1,0].imshow(image1)
        axarr[1,1].imshow(np.squeeze(y.cpu().detach().numpy())[1], cmap='gray')
        axarr[1,2].imshow(np.squeeze(img)[1], cmap='gray')
        break

In [None]:
# Mean batch loss on the test split.
test_loss = evaluate(model=model, loader=test_loader, loss_fn=loss_fn, device=device)

In [None]:
test_loss

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def visualize_predictions(model, loader, device, num_samples=5):
    """Plot input image, true mask and thresholded prediction side by side.

    Args:
        model: network whose output is already a probability map; the
            ``UNet`` in this file ends its ``forward`` with a sigmoid, so
            no further activation is applied here.
        loader: iterable yielding ``(images, masks)`` batches, images in
            channel-first layout.
        device: device the inputs are moved to before the forward pass.
        num_samples: maximum number of individual samples (not batches)
            to plot.

    The model is switched to eval mode and left there.
    """
    model.eval()  # Set model to evaluation mode
    shown = 0

    with torch.no_grad():  # No need to compute gradients
        for x, y in loader:
            if shown >= num_samples:
                break

            # Move inputs to the appropriate device (CPU or GPU)
            x = x.to(device, dtype=torch.float32)
            y = y.to(device, dtype=torch.float32)

            # Threshold the probabilities directly. A second sigmoid would
            # squash them into (0.5, 0.73) and turn every pixel into 1.
            y_pred = (model(x) > 0.5).float()

            # Convert tensors to numpy arrays for visualization
            images = x.cpu().numpy()
            true_masks = y.cpu().numpy()
            pred_masks = y_pred.cpu().numpy()

            for image_chw, true_mask, pred_mask in zip(images, true_masks, pred_masks):
                if shown >= num_samples:
                    break

                panels = (
                    ('Original Image', np.transpose(image_chw, (1, 2, 0)), None),  # CHW -> HWC
                    ('True Mask', np.squeeze(true_mask), 'gray'),
                    ('Predicted Mask', np.squeeze(pred_mask), 'gray'),
                )

                plt.figure(figsize=(12, 4))
                for position, (title, picture, cmap) in enumerate(panels, start=1):
                    plt.subplot(1, 3, position)
                    plt.imshow(picture, cmap=cmap)
                    plt.title(title)
                    plt.axis('off')
                plt.show()

                shown += 1

# Apply the model to part of the test set and visualize the results
visualize_predictions(model, test_loader, device, num_samples=5)
