In [27]:
import os
import glob
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt
from unet import UNet
from tqdm import tqdm

In [96]:
class BagDataset(Dataset):

    def __init__(self, filename, images_dir, masks_dir, transform=None):
        assert os.path.exists(filename)
        with open(filename, "r") as f:
            self.files = f.readlines()
        
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):

        img_name = os.path.join(self.images_dir, self.files[idx].strip()+".jpg")
        mask_name = os.path.join(self.masks_dir, self.files[idx].strip()+".png")
        
        image = Image.open(img_name)
        mask = np.array(Image.open(mask_name))
        mask[mask>0] = 1
        one_hot = np.zeros((mask.shape[0], mask.shape[1], 2)).astype("uint8")
        one_hot[:, :, 0][mask == 0] = 1
        one_hot[:, :, 1][mask == 0] =1
        if self.transform:
            image = self.transform(image) 
            mask = self.transform(Image.fromarray(one_hot)) 
        mask[mask>0] = 1
        
        return image, mask

In [97]:
transform = transforms.Compose([
        transforms.Resize((224,224)), 
        transforms.ToTensor()])
    
    
dataset = BagDataset("./bags_data/imagesets/train.txt", "./bags_data/JPEGimages", "./bags_data/segmentation_mask", transform=transform)
dataloader = DataLoader(dataset, batch_size=1,
                        shuffle=True, num_workers=0)


   
        

In [98]:
class Trainer:

    def __init__(self, model, optimizer, cuda=False, experiment_name="", val_step=50):
        self.model = model
        self.optimizer = optimizer
        self.cuda = cuda
        self.best_val_accuracy = 0
        self.best_epoch = -1
        self.val_step = val_step
        self.best_model = None
        self.criterion = nn.BCEWithLogitsLoss()

    def predict(self, images, labels):
        if self.cuda:
            images = images.cuda().float()
            labels = labels.cuda().float()
        else: 
            images = images.float()
            labels = labels.float()
        self.optimizer.zero_grad()
        output = self.model(images) 
        loss = self.criterion(output, labels)

        return output, loss, labels

    def train(self, epoch, data_iterator):
        self.model.train()
        train_loss, train_accuracy = list(), list()
        best_model = None
        best_val_accuracy = 0
        best_epoch = 0
        for i, (images, labels)in enumerate(tqdm(data_iterator)):
            output, loss, labels = self.predict(
                images, labels)

            train_loss.append(loss.cpu().item())
            if i % 5 == 0:
                avg_loss = np.mean(train_loss[-10:])

                print(f'Epoch" {epoch}, Iter: {i},TRAINING__   loss :{loss}, , smooth_loss: {avg_loss}')
                self.best_model = self.model

            loss.backward()
            self.optimizer.step()

    def test(self,  val_iterator, save_images=False):
        print ("\nTesting ...")
        self.model.eval()
        summ = 0
        total_images = 0
        for i, (images, labels, im)in enumerate(tqdm(val_iterator)):
            total_images += 1
            with torch.no_grad():
                output, loss, labels = self.predict(
                    images, labels)
                output = (output > 0.5).float()
                score = dice.dice_coeff(output, labels)
                summ += score
                output = output.numpy().squeeze()

                im = cv2.cvtColor(im.numpy().squeeze(), cv2.COLOR_RGB2GRAY)
                output[output > 0.5] = 255
                output[output <= 0.5] = 0

                if save_images:
                    output = np.hstack(
                        (np.hstack((im, output)), labels.numpy().squeeze()*255)).astype("uint8")
                    cv2.imwrite("results/"+str(i) +
                                str(score)+".jpg", output)
        print ("Test DICE Coefficeint = ", float(summ)/total_images)
 

In [99]:
model = UNet(n_channels=3, n_classes=2)

In [100]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
trainer = Trainer(model, optimizer)
trainer.train(1, dataloader)





  0%|          | 0/427 [00:00<?, ?it/s][A[A[A[A

Epoch" 1, Iter: 0,TRAINING__   loss :0.6519176363945007, , smooth_loss: 0.6519176363945007






  0%|          | 1/427 [00:02<19:06,  2.69s/it][A[A[A[A



  0%|          | 2/427 [00:05<19:10,  2.71s/it][A[A[A[A



  1%|          | 3/427 [00:08<19:09,  2.71s/it][A[A[A[A



  1%|          | 4/427 [00:10<19:00,  2.70s/it][A[A[A[A



  1%|          | 5/427 [00:13<19:00,  2.70s/it][A[A[A[A

Epoch" 1, Iter: 5,TRAINING__   loss :0.6260465383529663, , smooth_loss: 0.618398129940033






  1%|▏         | 6/427 [00:16<19:14,  2.74s/it][A[A[A[A



  2%|▏         | 7/427 [00:19<19:06,  2.73s/it][A[A[A[A



  2%|▏         | 8/427 [00:21<18:58,  2.72s/it][A[A[A[A



  2%|▏         | 9/427 [00:24<18:52,  2.71s/it][A[A[A[A



  2%|▏         | 10/427 [00:27<18:51,  2.71s/it][A[A[A[A

Epoch" 1, Iter: 10,TRAINING__   loss :0.5027704238891602, , smooth_loss: 0.5842915654182435






  3%|▎         | 11/427 [00:29<18:46,  2.71s/it][A[A[A[A



  3%|▎         | 12/427 [00:32<18:40,  2.70s/it][A[A[A[A



  3%|▎         | 13/427 [00:35<18:37,  2.70s/it][A[A[A[A



  3%|▎         | 14/427 [00:37<18:35,  2.70s/it][A[A[A[A



  4%|▎         | 15/427 [00:40<18:32,  2.70s/it][A[A[A[A

Epoch" 1, Iter: 15,TRAINING__   loss :0.43836545944213867, , smooth_loss: 0.5222687125205994






  4%|▎         | 16/427 [00:43<18:30,  2.70s/it][A[A[A[A



  4%|▍         | 17/427 [00:45<18:28,  2.70s/it][A[A[A[A



  4%|▍         | 18/427 [00:48<18:26,  2.71s/it][A[A[A[A



  4%|▍         | 19/427 [00:51<18:24,  2.71s/it][A[A[A[A



  5%|▍         | 20/427 [00:54<18:25,  2.72s/it][A[A[A[A

Epoch" 1, Iter: 20,TRAINING__   loss :0.43471208214759827, , smooth_loss: 0.4922600984573364






  5%|▍         | 21/427 [00:57<18:24,  2.72s/it][A[A[A[A



  5%|▌         | 22/427 [00:59<18:24,  2.73s/it][A[A[A[A



  5%|▌         | 23/427 [01:02<18:22,  2.73s/it][A[A[A[A



  6%|▌         | 24/427 [01:05<18:17,  2.72s/it][A[A[A[A



  6%|▌         | 25/427 [01:08<18:13,  2.72s/it][A[A[A[A

Epoch" 1, Iter: 25,TRAINING__   loss :0.6649870872497559, , smooth_loss: 0.5511973381042481






  6%|▌         | 26/427 [01:10<18:10,  2.72s/it][A[A[A[A



  6%|▋         | 27/427 [01:13<18:07,  2.72s/it][A[A[A[A



  7%|▋         | 28/427 [01:16<18:04,  2.72s/it][A[A[A[A



  7%|▋         | 29/427 [01:18<18:00,  2.72s/it][A[A[A[A



  7%|▋         | 30/427 [01:21<17:57,  2.71s/it][A[A[A[A

Epoch" 1, Iter: 30,TRAINING__   loss :0.5364115834236145, , smooth_loss: 0.5484588265419006






  7%|▋         | 31/427 [01:24<17:55,  2.72s/it][A[A[A[A



  7%|▋         | 32/427 [01:27<17:54,  2.72s/it][A[A[A[A



  8%|▊         | 33/427 [01:29<17:51,  2.72s/it][A[A[A[A



  8%|▊         | 34/427 [01:32<17:49,  2.72s/it][A[A[A[A



  8%|▊         | 35/427 [01:35<17:47,  2.72s/it][A[A[A[A

Epoch" 1, Iter: 35,TRAINING__   loss :0.6179039478302002, , smooth_loss: 0.5054433315992355






  8%|▊         | 36/427 [01:38<17:45,  2.72s/it][A[A[A[A



  9%|▊         | 37/427 [01:40<17:43,  2.73s/it][A[A[A[A



  9%|▉         | 38/427 [01:43<17:41,  2.73s/it][A[A[A[A



  9%|▉         | 39/427 [01:46<17:39,  2.73s/it][A[A[A[A



  9%|▉         | 40/427 [01:49<17:35,  2.73s/it][A[A[A[A

Epoch" 1, Iter: 40,TRAINING__   loss :0.4699844419956207, , smooth_loss: 0.5056480377912521






 10%|▉         | 41/427 [01:51<17:32,  2.73s/it][A[A[A[A



 10%|▉         | 42/427 [01:54<17:29,  2.73s/it][A[A[A[A



 10%|█         | 43/427 [01:57<17:27,  2.73s/it][A[A[A[A



 10%|█         | 44/427 [02:00<17:27,  2.74s/it][A[A[A[A



 11%|█         | 45/427 [02:03<17:26,  2.74s/it][A[A[A[A

Epoch" 1, Iter: 45,TRAINING__   loss :0.564396321773529, , smooth_loss: 0.48683176934719086






 11%|█         | 46/427 [02:06<17:28,  2.75s/it][A[A[A[A



 11%|█         | 47/427 [02:11<17:39,  2.79s/it][A[A[A[A



[A[A[A[A

KeyboardInterrupt: 