<a href="https://colab.research.google.com/github/Motahareh-Mostafavi/Unet/blob/main/Unet_tensorboard(1).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Fail fast when no CUDA device is present: the training code in this notebook
# moves the model and every batch to the GPU.
if not torch.cuda.is_available():
    raise RuntimeError("GPU not available. CPU training will be too slow.")

print("device name", torch.cuda.get_device_name(0))

device name GeForce RTX 2080 Ti


In [None]:
# Writer will output to ./runs/ directory by default
# `writer` is a module-level object reused by later cells for TensorBoard logging.
writer = SummaryWriter()
# IPython magic (not plain Python): loads the TensorBoard notebook extension.
%load_ext tensorboard

In [None]:

class CloudDataset(Dataset):
    """Dataset of red/green/blue/NIR image patches with binary ground-truth masks.

    Every regular file found directly inside the red-band directory defines one
    sample. The matching green, blue, NIR and mask files are located by
    replacing the substring ``'red'`` in the red file's name.
    """

    def __init__(self, r_dir, g_dir, b_dir, nir_dir, gt_dir, pytorch=True):
        """Index the sample files.

        Args:
            r_dir: Directory (``pathlib.Path``) holding the red-band files.
            g_dir: Directory holding the green-band files.
            b_dir: Directory holding the blue-band files.
            nir_dir: Directory holding the near-infrared files.
            gt_dir: Directory holding the ground-truth mask files.
            pytorch: When true, ``__getitem__`` returns images channel-first.
        """
        super().__init__()

        # iterdir() yields entries in arbitrary, OS-dependent order; sorting
        # keeps the index -> sample mapping stable from one run to the next.
        self.files = [self.combine_files(f, g_dir, b_dir, nir_dir, gt_dir)
                      for f in sorted(r_dir.iterdir()) if not f.is_dir()]
        self.pytorch = pytorch
        # NOTE(review): ColorJitter is built with its default arguments, which
        # per the torchvision docs apply no jitter, and it is later handed a
        # NumPy array rather than a PIL image/tensor -- confirm before
        # configuring real augmentation here.
        self.augJitter = torchvision.transforms.ColorJitter()

    def combine_files(self, r_file: Path, g_dir, b_dir, nir_dir, gt_dir):
        """Return a dict mapping each band name (and ``'gt'``) to its file path.

        Sibling paths are derived from ``r_file.name`` by swapping ``'red'``
        for the other band's name.
        """
        files = {'red': r_file,
                 'green': g_dir/r_file.name.replace('red', 'green'),
                 'blue': b_dir/r_file.name.replace('red', 'blue'),
                 'nir': nir_dir/r_file.name.replace('red', 'nir'),
                 'gt': gt_dir/r_file.name.replace('red', 'gt')}

        return files

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.files)

    @staticmethod
    def _read_band(path):
        """Load one image file into a NumPy array and close the file handle."""
        with Image.open(path) as img:
            return np.array(img)

    def open_as_array(self, idx, invert=False, include_nir=False):
        """Load sample ``idx`` as a float array scaled by the dtype's maximum.

        Args:
            idx: Index into ``self.files``.
            invert: When true, return channel-first (C, H, W) instead of
                channel-last (H, W, C).
            include_nir: When true, append the NIR band as a fourth channel.

        Returns:
            The stacked bands divided by ``np.iinfo(dtype).max``.

        Raises:
            ValueError: From ``np.iinfo`` if the image dtype is not an integer
                type.
        """
        entry = self.files[idx]

        raw_rgb = np.stack(
            [self._read_band(entry[band]) for band in ('red', 'green', 'blue')],
            axis=2,
        )

        raw_rgb = self.augJitter(raw_rgb)

        if include_nir:
            nir = np.expand_dims(self._read_band(entry['nir']), 2)
            raw_rgb = np.concatenate([raw_rgb, nir], axis=2)

        if invert:
            raw_rgb = raw_rgb.transpose((2, 0, 1))

        # normalize by the largest value the integer dtype can hold
        return (raw_rgb / np.iinfo(raw_rgb.dtype).max)

    def open_mask(self, idx, add_dims=False):
        """Load the mask for sample ``idx`` as 0/1 integers.

        Pixels equal to 255 become 1; every other value becomes 0. With
        ``add_dims`` a leading axis of size one is added.
        """
        raw_mask = self._read_band(self.files[idx]['gt'])
        raw_mask = np.where(raw_mask == 255, 1, 0)

        return np.expand_dims(raw_mask, 0) if add_dims else raw_mask

    def __getitem__(self, idx):
        """Return ``(x, y)``: float32 image with NIR channel and int64 mask."""
        x = torch.tensor(self.open_as_array(idx, invert=self.pytorch, include_nir=True), dtype=torch.float32)
        y = torch.tensor(self.open_mask(idx, add_dims=False), dtype=torch.int64)

        return x, y



# Build the dataset from the five per-band sub-folders of the training archive
# (order matches CloudDataset's red, green, blue, nir, gt parameters).
base_path = Path("/home/terra/Downloads/archive/38-Cloud_training")
data = CloudDataset(
    *(base_path/sub_dir
      for sub_dir in ('train_red', 'train_green', 'train_blue', 'train_nir', 'train_gt'))
)
# number of samples in the dataset
len(data)
# one sample: input tensor x and target mask y
x, y = data[1000]
x.shape, y.shape

# Show one raw image next to its ground-truth mask
image_index = 500
fig, ax = plt.subplots(1, 2, figsize=(10, 9))
raw_panel, mask_panel = ax[0], ax[1]
raw_panel.imshow(data.open_as_array(image_index))
mask_panel.imshow(data.open_mask(image_index))

NameError: name 'Path' is not defined

In [None]:
# Randomly partition the dataset into 5000 training and 3400 validation samples.
# NOTE(review): the two sizes are hard-coded and presumably must add up to
# len(data) -- confirm if the dataset on disk changes.
train_dataset, valid_dataset = torch.utils.data.random_split(data, (5000, 3400))
# Create data loaders that draw shuffled batches of 12 samples from each split

train_dataload = DataLoader(train_dataset, batch_size=12, shuffle=True)
valid_dataload = DataLoader(valid_dataset, batch_size=12, shuffle=True)

In [None]:
# Sanity-check the loader: pull one batch and inspect the image/mask tensor shapes
xb, yb = next(iter(train_dataload))
xb.shape, yb.shape


(torch.Size([12, 4, 384, 384]), torch.Size([12, 384, 384]))

In [None]:


class UNET(nn.Module):
    """Three-level U-Net style encoder/decoder with skip connections.

    Each contracting block ends in a stride-2 max-pool and each expanding
    block ends in a stride-2 transposed convolution; decoder inputs are
    concatenated channel-wise with the encoder output of the same level.

    NOTE(review): the skip concatenations assume the spatial size survives
    three halvings and doublings unchanged (e.g. the 384x384 patches used in
    this notebook) -- confirm for other input sizes.
    """

    def __init__(self, in_channels, out_channels):
        """Build the network.

        Args:
            in_channels: Number of channels in the input image.
            out_channels: Number of channels (classes) in the output map.
        """
        super().__init__()

        self.conv1 = self.contract_block(in_channels, 32, 7, 3)
        self.conv2 = self.contract_block(32, 64, 3, 1)
        self.conv3 = self.contract_block(64, 128, 3, 1)

        self.upconv3 = self.expand_block(128, 64, 3, 1)
        # Decoder inputs are doubled because the skip tensor is concatenated.
        self.upconv2 = self.expand_block(64*2, 32, 3, 1)
        self.upconv1 = self.expand_block(32*2, out_channels, 3, 1)

    def forward(self, x):
        """Run the network on a batch ``x`` and return the per-class map.

        Defined as ``forward`` (rather than overriding ``__call__``) so that
        ``nn.Module.__call__`` keeps dispatching hooks; ``model(x)`` still
        works exactly as before.
        """
        # downsampling part
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        # upsampling part, concatenating the matching encoder output on dim 1
        upconv3 = self.upconv3(conv3)

        upconv2 = self.upconv2(torch.cat([upconv3, conv2], 1))
        upconv1 = self.upconv1(torch.cat([upconv2, conv1], 1))

        return upconv1

    def contract_block(self, in_channels, out_channels, kernel_size, padding):
        """Return two Conv-BatchNorm-ReLU stages followed by a stride-2 max-pool."""
        contract = nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        return contract

    def expand_block(self, in_channels, out_channels, kernel_size, padding):
        """Return two Conv-BatchNorm-ReLU stages followed by a stride-2 transposed conv."""
        expand = nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
        )
        return expand

In [None]:
# Instantiate the model with 4 input channels and 2 output channels
unet = UNET(4,2)
# fetch one training batch to use for a trial forward pass
xb, yb = next(iter(train_dataload))
xb.shape, yb.shape

(torch.Size([12, 4, 384, 384]), torch.Size([12, 384, 384]))

In [None]:
# Trial forward pass on the fetched batch; inspect the prediction shape
pred = unet(xb)
pred.shape

torch.Size([12, 2, 384, 384])

In [None]:
# Log the batch as an image grid and the model graph to TensorBoard, then close the writer
grid = torchvision.utils.make_grid(xb)
writer.add_image('images', grid, 0)
writer.add_graph(unet, xb)
writer.close()

In [None]:
import time
from IPython.display import clear_output

def train(model, train_dl, valid_dl, loss_fn, optimizer, acc_fn, epochs=1):
    """Run the training/validation schedule and log per-epoch metrics.

    Each epoch is either a training epoch or a validation epoch: epochs whose
    index is a multiple of 10 (including epoch 0) only evaluate on
    ``valid_dl``; all other epochs only train on ``train_dl``.

    Relies on two module-level names: ``writer`` (TensorBoard SummaryWriter)
    for scalar logging and ``clear_output`` (IPython) for refreshing the cell
    output.

    Args:
        model: Network to optimise; it is moved to the GPU.
        train_dl: DataLoader yielding ``(x, y)`` training batches.
        valid_dl: DataLoader yielding ``(x, y)`` validation batches.
        loss_fn: Callable ``loss_fn(outputs, y)`` returning a scalar tensor.
        optimizer: Optimiser bound to ``model``'s parameters.
        acc_fn: Callable ``acc_fn(outputs, y)`` returning a scalar accuracy.
        epochs: Total number of epochs to run.

    Returns:
        ``(train_loss, valid_loss)``: lists of per-epoch mean losses as plain
        Python floats, one entry per epoch of the respective phase.
    """
    model.cuda()

    train_loss, valid_loss = [], []

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('**************************************')

        # Every 10th epoch is validation-only; the rest are training-only.
        phase = 'valid' if epoch % 10 == 0 else 'train'
        is_training = phase == 'train'

        model.train(is_training)
        # Use the loaders passed in rather than notebook globals.
        dataloader = train_dl if is_training else valid_dl

        running_loss = 0.0
        running_acc = 0.0
        samples_seen = 0

        for step, (x, y) in enumerate(dataloader, start=1):
            x = x.cuda()
            y = y.cuda()

            if is_training:
                optimizer.zero_grad()
                outputs = model(x)
                loss = loss_fn(outputs, y)
                loss.backward()
                optimizer.step()
            else:
                with torch.no_grad():
                    outputs = model(x)
                    loss = loss_fn(outputs, y.long())

            # Convert to plain floats so the running totals do not keep the
            # autograd graph (or GPU tensors) alive across the epoch.
            batch_loss = loss.item()
            batch_acc = float(acc_fn(outputs, y))

            # Weight by the real batch size: the last batch may be shorter
            # than dataloader.batch_size.
            batch_len = x.size(0)
            running_loss += batch_loss * batch_len
            running_acc += batch_acc * batch_len
            samples_seen += batch_len

            if step % 100 == 0:
                print(phase, ': Current batch: {}  Loss: {}  Acc: {} '.format(step, batch_loss, batch_acc))

        # Guard against an empty loader so the epoch summary never divides by zero.
        denom = max(samples_seen, 1)
        epoch_loss = running_loss / denom
        epoch_acc = running_acc / denom
        if is_training:
            writer.add_scalar("Loss/train", epoch_loss, epoch)
            writer.add_scalar("acc/train", epoch_acc, epoch)
        else:
            writer.add_scalar("Loss/valid", epoch_loss, epoch)
            writer.add_scalar("acc/valid", epoch_acc, epoch)

        clear_output(wait=True)
        print('Epoch: ', epoch, "/", epochs)
        print('**************************************')
        print(phase, ' Loss: ', epoch_loss, 'Acc: ', epoch_acc)
        print('**************************************')

        if is_training:
            train_loss.append(epoch_loss)
        else:
            valid_loss.append(epoch_loss)

    return train_loss, valid_loss

def acc_metric(predb, yb):
    """Return the fraction of positions where the argmax class equals the label.

    Args:
        predb: Score tensor with the class axis at dimension 1.
        yb: Integer label tensor matching ``predb`` with dimension 1 removed.

    Returns:
        A scalar float tensor on ``predb``'s device.
    """
    # Compare on whatever device the predictions live on instead of forcing
    # CUDA, so the metric also works for CPU tensors.
    return (predb.argmax(dim=1) == yb.to(predb.device)).float().mean()

In [None]:
loss_fn = nn.CrossEntropyLoss() # choose loss function
opt = torch.optim.Adam(unet.parameters(), lr=0.01) # choose optimizer function

# start training: 50 epochs with the loaders, loss, optimizer and accuracy metric defined above
train_loss, valid_loss = train(unet, train_dataload, valid_dataload, loss_fn, opt, acc_metric, epochs=50)

# Flush pending TensorBoard events, close the writer, then save the trained weights to disk
writer.flush()
writer.close()
torch.save(unet.state_dict(), 'model_weights.pth')

Epoch:  2 / 50
**************************************
train  Loss:  tensor(0.5563, device='cuda:0', grad_fn=<DivBackward0>) Acc:  tensor(0.7035, device='cuda:0')
**************************************
Epoch 3/49
**************************************
train : Current batch: 100  Loss: 0.38144680857658386  Acc: 0.8889702558517456 
train : Current batch: 200  Loss: 0.38086315989494324  Acc: 0.8455641269683838 
