In [23]:
import torch
import random
from monai.networks.nets import UNet
from monai.losses import DiceCELoss, DiceLoss
from monai.data import Dataset
from monai.transforms import Compose,ToTensord, LoadImaged, EnsureChannelFirstD, Rotate90d, Flipd, ToDeviced

In [24]:
class FloodAreaSegmentation(Dataset):
    """Index-addressable collection of image/label file-path pairs.

    ``data`` is indexed with an integer and must yield a mapping with
    ``"image"`` and ``"label"`` keys (here: a dict keyed 0..N-1 of PNG paths).
    Items are returned untransformed; ``transform`` is only stored so that
    callers can apply it themselves.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform
        # Misspelled attribute name is kept on purpose: other code may read `lenght`.
        self.lenght = len(data)

    def __len__(self):
        return self.lenght

    def __getitem__(self, index):
        entry = self.data[index]
        return dict(image=entry["image"], label=entry["label"])

In [25]:
class DataLoader:
    """Minimal batching iterator over a ``FloodAreaSegmentation``-style dataset.

    Every ``dataset[k]`` must be a dict with ``"image"`` and ``"label"`` entries.
    ``transform`` turns that dict into tensors, which are copied into batch
    tensors of shape ``(B, 3, sides_size, sides_size)`` (images) and
    ``(B, 1, sides_size, sides_size)`` (labels) allocated on ``device``.
    Label pixels equal to 255 become 1; every other value becomes 0.

    NOTE(review): assumes the transform yields 3-channel images and 1-channel
    labels already at ``sides_size`` x ``sides_size`` — TODO confirm for all files.
    """

    def __init__(self, dataset, batch_size=5, shuffle=True, transform=None, sides_size=512, device=torch.device("cpu")):
        self.dataset = dataset
        self.batch_size = batch_size
        # Fall back to the dataset's own transform so callers need not pass it twice.
        self.transform = transform if transform is not None else getattr(dataset, "transform", None)
        self.size = sides_size
        self.shuffle = shuffle
        self.device = device

    def __len__(self):
        """Number of batches per epoch (the last one may be smaller)."""
        if self.batch_size <= 0:
            return 0
        return -(-len(self.dataset) // self.batch_size)  # ceil division

    def __iter__(self):
        # Start a new epoch: every dataset index is available again.
        self.aviable_indexes = list(range(len(self.dataset)))
        return self

    def __next__(self):
        actual_batch_size = min(self.batch_size, len(self.aviable_indexes))
        if actual_batch_size <= 0:
            raise StopIteration()
        if self.transform is None:
            raise TypeError("DataLoader needs a transform that loads the image/label files")

        if self.shuffle:
            sampled_elements = random.sample(self.aviable_indexes, actual_batch_size)
        else:
            sampled_elements = self.aviable_indexes[:actual_batch_size]

        # Size the batch to the samples actually drawn so a final partial batch is
        # not padded with all-zero images/labels that would be trained on.
        batch = {
            "image": torch.zeros((actual_batch_size, 3, self.size, self.size), device=self.device),
            "label": torch.zeros((actual_batch_size, 1, self.size, self.size), device=self.device),
        }
        for i, element in enumerate(sampled_elements):
            # `element` is the dataset index itself (not its position in the pool).
            sample = self.dataset[element]
            transformed = self.transform({"image": sample["image"], "label": sample["label"]})
            batch["image"][i] = transformed["image"]
            batch["label"][i] = transformed["label"]
            batch["label"][i] = torch.where(batch["label"][i] == 255, 1, 0)

        # Drop the served indices in one pass, preserving the order of the rest.
        served = set(sampled_elements)
        self.aviable_indexes = [idx for idx in self.aviable_indexes if idx not in served]
        return batch

In [26]:
import os

# Dataset root in one place; override with the FLOOD_DATASET_ROOT environment
# variable instead of editing a user-specific absolute path in the notebook.
DATASET_ROOT = os.environ.get(
    "FLOOD_DATASET_ROOT", "C:/Users/Admin/Desktop/Flood Area Segmentation/Dataset"
)


def build_split(split, first, stop, root=DATASET_ROOT):
    """Map 0-based sample ids to image/label PNG paths for one split.

    Paths are ``{root}/{split}/Images/{num}.png`` and ``{root}/{split}/Labels/{num}.png``
    for ``num`` in ``range(first, stop)``. Keys are re-based to start at 0 so the
    dict can be indexed like a list by ``FloodAreaSegmentation``.
    """
    images_dir = f"{root}/{split}/Images"
    labels_dir = f"{root}/{split}/Labels"
    return {
        num - first: {"image": f"{images_dir}/{num}.png", "label": f"{labels_dir}/{num}.png"}
        for num in range(first, stop)
    }


# NOTE(review): train uses files 0..698 and test 700..1044, so file 699 lands in
# neither split — confirm whether range(0, 700) was intended.
train_set = build_split("Train", 0, 699)
test_set = build_split("Test", 700, 1045)


In [27]:
# Run configuration: seed the RNGs this notebook uses (Python `random` drives the
# DataLoader's shuffling, torch drives weight initialisation), then pick the device.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"using {device}")

using cuda


In [28]:
# 2-D residual U-Net: 3 input channels, one output channel of logits
# (trained with a sigmoid DiceCE loss later in the notebook).
model = UNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    channels=(16, 32, 64, 128, 256, 512),
    strides=(2,) * 5,
    num_res_units=4,
).to(device)


In [29]:
# Deterministic preprocessing (no random augmentation): load both PNGs, convert to
# tensors on `device`, move channels first, then rotate 270 degrees and flip.
# NOTE(review): the rotate/flip pair presumably undoes the reader's axis order — TODO confirm.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),
        ToDeviced(keys=["image", "label"], device=device),
        EnsureChannelFirstD(keys=["image", "label"]),
        Rotate90d(keys=["image", "label"], k=3),
        Flipd(keys=["image", "label"], spatial_axis=1),
    ]
)


In [30]:
dataset_train = FloodAreaSegmentation(train_set, transform=train_transforms)
# Allocate batches directly on `device`: the transforms already place tensors there
# (ToDeviced), so a CPU batch buffer would force a device->host->device copy per batch.
dataloader_train = DataLoader(
    dataset_train, batch_size=75, shuffle=True, transform=train_transforms, device=device
)

In [31]:
# Adam over all U-Net parameters, learning rate 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
def trainloop(model, criterion, dataloader, quiet=False, flush_memory=False, optim=None, target_device=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: network to train; it is switched to train mode.
        criterion: callable ``criterion(outputs, targets)`` returning a scalar tensor.
        dataloader: iterable of dicts with ``"image"`` and ``"label"`` tensors.
        quiet: if False, print every batch loss.
        flush_memory: if True, call ``torch.cuda.empty_cache()`` after the epoch.
        optim: optimizer to step; defaults to the notebook-global ``optimizer``.
        target_device: device for inputs/targets; defaults to the notebook-global ``device``.

    Returns:
        Mean loss over the batches, or ``nan`` if ``dataloader`` yielded nothing.
    """
    # Explicit arguments win; otherwise keep the original reliance on notebook globals.
    optim = optim if optim is not None else optimizer
    target_device = target_device if target_device is not None else device

    model.train()
    epoch_loss = 0.0
    step = 0
    for batch in dataloader:
        step += 1

        inputs = batch['image'].to(target_device)
        targets = batch['label'].to(target_device)
        outputs = model(inputs)

        loss = criterion(outputs, targets)
        loss_value = loss.item()  # one host sync per batch instead of two
        epoch_loss += loss_value
        if not quiet:
            print("CE:{}".format(loss_value))

        optim.zero_grad()
        loss.backward()
        optim.step()

    if flush_memory:
        torch.cuda.empty_cache()

    return epoch_loss / step if step else float("nan")

In [32]:
# Train for `num_epochs` epochs with sigmoid Dice + cross-entropy, then checkpoint.
num_epochs = 1
criterion = DiceCELoss(sigmoid=True).to(device)
losses = []

for epoch in range(num_epochs):
    loss = trainloop(model, criterion, dataloader_train)
    losses.append(loss)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss}')

torch.save(model.state_dict(), 'unet_model_monai_dice.pth')

CE:1.4308297634124756
CE:1.3843330144882202
CE:1.4438517093658447
CE:1.4201898574829102
CE:1.4085817337036133
CE:1.3915789127349854
CE:1.3384172916412354
CE:1.3107273578643799
CE:1.387979507446289
CE:1.7345930337905884
Epoch 1/1, Loss: 1.425108218193054


In [35]:
from matplotlib import pyplot as plt
def display_segmentation(model, dataset : FloodAreaSegmentation, id, threshold=0.0):
    """Plot ground-truth label, input image and thresholded prediction side by side.

    Args:
        model: segmentation network; its output is assumed to be (B, 1, H, W)
            logits, as for the single-output-channel UNet in this notebook.
        dataset: dataset whose ``transform`` loads and prepares a raw ``dataset[id]`` dict.
        id: integer index of the sample to show.
        threshold: logit cutoff for a positive pixel; ``0.0`` is probability 0.5
            under a sigmoid (the original hard-coded cutoff).
    """
    sample = dataset.transform(dataset[id])

    image = sample["image"]
    if image.dim() == 3:
        # (C, H, W) -> (1, C, H, W): the network expects a batch axis.
        image = image.unsqueeze(0)

    model.eval()
    with torch.no_grad():
        # Feed the image tensor (not the whole sample dict) on the model's device.
        logits = model(image.to(next(model.parameters()).device))
    prediction = (logits > threshold).to(torch.int32)[0, 0].cpu()

    # First three channels, moved channels-last for imshow.
    img_rgb = image[0, :3].permute(1, 2, 0).cpu().to(torch.int32)

    label = sample["label"]
    while label.dim() > 2:  # drop leading batch/channel axes down to (H, W)
        label = label[0]
    lab = label.cpu().to(torch.int32)

    fig, axes = plt.subplots(1, 3)
    panels = ((lab, "label"), (img_rgb, "image"), (prediction, "prediction"))
    for ax, (data, title) in zip(axes, panels):
        ax.imshow(data)
        ax.set_title(title)
        ax.axis("off")

In [None]:
# The same deterministic preprocessing is used for evaluation (no augmentation in it).
dataset_test = FloodAreaSegmentation(test_set, transform=train_transforms)
# `sides_size` is the constructor's keyword (`size=` raised TypeError); batches are
# allocated on `device` because the transforms already place tensors there.
dataloader_test = DataLoader(
    dataset_test, batch_size=25, transform=train_transforms, sides_size=512, device=device
)

In [None]:
from matplotlib import pyplot as plt
# Evaluate the trained model on the test split: mean per-batch loss.
model.eval()
epoch_loss = 0
step = 0
# Evaluation only: no_grad keeps autograd from retaining every batch's activations.
with torch.no_grad():
    for batch in dataloader_test:
        step += 1
        inputs, targets = batch['image'].to(device), batch['label'].to(device)
        # Forward pass
        outputs = model(inputs)

        # Compute the loss (no backward pass or optimizer step during evaluation)
        loss = criterion(outputs, targets)
        loss_value = loss.item()
        print(f"{loss_value}")
        epoch_loss += loss_value

print(f"Final-{epoch_loss / step if step else float('nan')}")


1.255903959274292


KeyboardInterrupt: 