In [12]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

from unet import UNet

In [3]:
class BagDataset(Dataset):
    """Segmentation dataset pairing ``<name>.jpg`` images with ``<name>.png`` masks.

    Sample names are read from a text file holding one name per line (without
    extension). Blank lines in that file are ignored.

    Args:
        filename: Path to the text file listing the sample names.
        images_dir: Directory containing the ``.jpg`` images.
        masks_dir: Directory containing the ``.png`` masks.
        transform: Optional callable applied to the PIL image. When no
            ``mask_transform`` is given it is applied to the mask as well.
        mask_transform: Optional callable applied to the PIL mask instead of
            ``transform`` (e.g. a resize that uses nearest-neighbour
            interpolation). Defaults to ``None``, meaning "reuse ``transform``".

    NOTE(review): the image is returned in whatever mode PIL opens it in; no
    conversion to RGB is performed here - confirm all JPEGs are 3-channel.
    """

    def __init__(self, filename, images_dir, masks_dir, transform=None, mask_transform=None):
        assert os.path.exists(filename), f"image-set file not found: {filename}"
        with open(filename, "r") as f:
            # Strip once and drop empty entries: a trailing blank line would
            # otherwise count as a sample and resolve to "<dir>/.jpg".
            self.files = [name for name in (line.strip() for line in f) if name]

        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        """Return the number of listed samples."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        The mask is binarised in place: every non-zero value becomes 1.
        """
        name = self.files[idx].strip()
        img_name = os.path.join(self.images_dir, name + ".jpg")
        mask_name = os.path.join(self.masks_dir, name + ".png")

        image = Image.open(img_name)
        mask = np.array(Image.open(mask_name))

        if self.transform:
            image = self.transform(image)

        # Resolve at call time so that reassigning ``self.transform`` later
        # still affects the mask when no dedicated mask transform was set.
        mask_transform = self.mask_transform if self.mask_transform is not None else self.transform
        if mask_transform:
            mask = mask_transform(Image.fromarray(mask))

        mask[mask > 0] = 1
        return image, mask

In [4]:
# Preprocessing pipeline: resize to 224x224, then convert to a tensor.
# BagDataset is given this single pipeline for its samples.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Training split of the bag dataset, served one shuffled sample at a time
# from the main process (num_workers=0).
dataset = BagDataset(
    "./bags_data/imagesets/train.txt",
    "./bags_data/JPEGimages",
    "./bags_data/segmentation_mask",
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)


   
        

In [7]:
class Trainer:
    """Small training / evaluation harness for a binary segmentation model.

    Args:
        model: Network invoked as ``model(images)``; its raw output is passed
            straight to ``nn.BCELoss``.
        optimizer: Optimizer whose ``zero_grad`` / ``step`` drive the updates.
        cuda: When true, inputs and labels are moved to the GPU in ``predict``.
        experiment_name: Accepted for call-site compatibility; not used.
        val_step: Stored on the instance; not read by any method of this class.

    NOTE(review): ``nn.BCELoss`` expects probabilities in [0, 1] with the same
    shape as the target - confirm the model ends in a sigmoid and emits one
    channel per label channel.
    NOTE(review): ``tqdm``, ``cv2`` and ``dice`` are used below but are not
    imported in the notebook's visible import cell - confirm they are in scope
    before this cell runs.
    """

    def __init__(self, model, optimizer, cuda=False, experiment_name="", val_step=50):
        self.model = model
        self.optimizer = optimizer
        self.cuda = cuda
        self.best_val_accuracy = 0
        self.best_epoch = -1
        self.val_step = val_step
        self.best_model = None
        self.criterion = nn.BCELoss()

    def predict(self, images, labels):
        """Run a forward pass and compute the loss.

        Inputs are cast to float (and moved to the GPU when ``self.cuda`` is
        set). Optimizer gradients are cleared as a side effect, so a caller
        can go straight to ``loss.backward()`` / ``optimizer.step()``.

        Returns:
            Tuple ``(output, loss, labels)`` where ``labels`` is the cast
            (and possibly device-moved) version of the argument.
        """
        images = images.float()
        labels = labels.float()
        if self.cuda:
            images = images.cuda()
            labels = labels.cuda()

        self.optimizer.zero_grad()
        output = self.model(images)
        loss = self.criterion(output, labels)
        return output, loss, labels

    def train(self, epoch, data_iterator, val_iterator, test=False):
        """Train the model for one pass over ``data_iterator``.

        Args:
            epoch: Epoch number, used only in the progress message.
            data_iterator: Iterable of batches whose first two items are
                ``(images, labels)``; extra items are ignored.
            val_iterator: Unused; kept for call-site compatibility.
            test: Unused; kept for call-site compatibility.

        Returns:
            Mean training loss over the pass (``nan`` if there were no batches).
        """
        self.model.train()
        train_loss = []
        for i, batch in enumerate(tqdm(data_iterator)):
            images, labels = batch[0], batch[1]
            _, loss, _ = self.predict(images, labels)
            loss.backward()
            self.optimizer.step()

            # Log the scalar value, not the tensor repr (which drags grad_fn along).
            loss_value = loss.item()
            train_loss.append(loss_value)
            if i % 5 == 0:
                avg_loss = np.mean(train_loss[-10:])
                print(f'Epoch: {epoch}, Iter: {i}, TRAINING__ loss: {loss_value}, smooth_loss: {avg_loss}')
                # NOTE: this stores a reference to the live model, not a
                # snapshot of its weights.
                self.best_model = self.model

        return float(np.mean(train_loss)) if train_loss else float("nan")

    def test(self, val_iterator, save_images=False):
        """Evaluate the model and report the mean Dice score per batch.

        Args:
            val_iterator: Iterable of batches ``(images, labels)`` or
                ``(images, labels, im)``. ``im`` is the original RGB image and
                is only used as the left panel of the saved visualisation.
            save_images: When true, write a side-by-side image (original if
                available, prediction, ground truth) per batch into
                ``results/``.

        Returns:
            Mean Dice coefficient over the batches (``nan`` if there were none).
        """
        print ("\nTesting ...")
        self.model.eval()
        if save_images:
            # cv2.imwrite does not create missing directories.
            os.makedirs("results", exist_ok=True)

        summ = 0.0
        num_batches = 0
        for i, batch in enumerate(tqdm(val_iterator)):
            images, labels = batch[0], batch[1]
            im = batch[2] if len(batch) > 2 else None
            num_batches += 1

            with torch.no_grad():
                output, _, labels = self.predict(images, labels)
            output = (output > 0.5).float()
            score = float(dice.dice_coeff(output, labels))
            summ += score

            if save_images:
                # Tensors may live on the GPU after predict(); bring them to
                # host memory before converting to numpy.
                panels = [
                    output.cpu().numpy().squeeze() * 255,
                    labels.cpu().numpy().squeeze() * 255,
                ]
                if im is not None:
                    gray = cv2.cvtColor(im.cpu().numpy().squeeze(), cv2.COLOR_RGB2GRAY)
                    panels.insert(0, gray)
                canvas = np.hstack(panels).astype("uint8")
                cv2.imwrite("results/" + str(i) + str(score) + ".jpg", canvas)

        if num_batches == 0:
            print ("Test DICE Coefficeint = ", "n/a (no batches)")
            return float("nan")

        mean_dice = summ / num_batches
        print ("Test DICE Coefficeint = ", mean_dice)
        return mean_dice
 

In [14]:
# Segmentation network: 3 input channels, 2 output classes.
# NOTE(review): the Trainer in this notebook scores the raw model output with
# nn.BCELoss, which requires output and target shapes to match - confirm that
# a 2-class output is what the binary masks are meant to be compared against.
model = UNet(
    n_classes=2,
    n_channels=3,
)

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, moment