In [340]:
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [341]:
import torch
from torchvision import datasets, utils, transforms, models
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler

In [342]:
import pathlib as pl
import glob
import numpy as np
import matplotlib.pyplot as plt
from skimage import io
from skimage.transform import resize
from sklearn.externals._pilutil import bytescale
from skimage.color import rgb2gray
from PIL import Image
import time

In [343]:
class BacteriaDataset(Dataset):
    """Bacteria dataset."""

    def __init__(self, base_path, transform=None):
        self.images_list = sorted(glob.glob("/".join([base_path, "images"]) + "/*.png"))
        self.masks_list = sorted(glob.glob("/".join([base_path, "masks"]) + "/*.png"))
        self.transform = transform

        self.class_to_id = {"background": 0, "erythrocytes": 1, "spirochaete": 2}
        self.class_to_color = {"background": [0, 0, 0], "erythrocytes": [255, 0, 0], "spirochaete": [255, 255, 0]}
        self.id_to_class = {v: k for k, v in class_to_id.items()}

        self.num_classes = len(class_to_id)

    def __len__(self):
        return len(self.images_list)

    def __getitem__(self, idx):
        image = io.imread(self.images_list[idx])
        mask = io.imread(self.masks_list[idx])

        mask_cp = mask.copy()
        for label, color in self.class_to_color.items():
            mask_cp[np.any(mask == self.class_to_id[label], axis=-1)] = color

        if self.transform is not None:
            image = Image.fromarray(np.uint8(image)).convert('RGB')
            mask_cp = Image.fromarray(np.uint8(mask_cp)).convert('RGB')
            image = self.transform(image)
            mask_cp = self.transform(mask_cp)
        
        sample = {'image': image, 'mask': mask_cp}
        return sample

In [344]:
working_path = "./drive/My Drive/Colab Notebooks/UFRGS/CV/TF"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 8
num_workers = 0

print(f"We are training on {device}")

We are training on cpu


In [345]:
print('==> Preparing data..')
data_transform = {
    'train': transforms.Compose(
        [transforms.Resize((256, 256)),
         transforms.ToTensor(),
         ]),
    'test': transforms.Compose(
        [transforms.Resize((256, 256)),
         transforms.ToTensor(),
         ])
    }

==> Preparing data..


In [346]:
train_dataset = BacteriaDataset(base_path="/".join([working_path, "data/train"]), transform=data_transform['train'])
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=num_workers)

val_dataset = BacteriaDataset(base_path="/".join([working_path, "data/val"]), transform=data_transform['test'])
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

test_dataset = BacteriaDataset(base_path="/".join([working_path, "data/test"]), transform=None)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          num_workers=num_workers)

In [347]:
def plot_data(batch_dict):
    images_batch = batch_dict['image']
    masks_batch = batch_dict['mask']
    
    for i in range(images_batch.shape[0]):
        curr_img = images_batch[i, ...].numpy().transpose(1, 2, 0)
        curr_msk = masks_batch[i, ...].numpy().transpose(1, 2, 0)

        fig = plt.figure(figsize=(20, 20))
        ax1 = fig.add_subplot(1, 2, 1)
        ax1.imshow(curr_img)
        ax2 = fig.add_subplot(1, 2, 2)
        ax2.imshow(curr_msk)

sample_batch = next(iter(train_loader))
plot_data(sample_batch)

Output hidden; open in https://colab.research.google.com to view.

In [348]:
def cross_entropy2d(input, target):
    # input: (n, c, h, w), target: (n, h, w)
    print(input.shape)  # torch.Size([8, 1000])
    print(target.shape) # torch.Size([8, 3, 256, 256])

    n, c, h, w = input.size()

    # input: (n*h*w, c)
    input = input.transpose(1, 2).transpose(2, 3).contiguous()
    input = input[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0]
    input = input.view(-1, c)

    # target: (n*h*w,)
    mask = target >= 0.0
    target = target[mask]

    func_loss = torch.nn.CrossEntropyLoss()
    loss = func_loss(input, target)

    return loss

In [349]:
def train(epoch, model, dataloader, batch_size):
    model.train()

    time_train = []
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, sample_batch in enumerate(dataloader):
        start_time = time.time()
        
        inputs, targets = sample_batch['image'], sample_batch['mask']
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = cross_entropy2d(outputs, targets)

        running_loss += loss.item()

        loss.backward()
        optimizer.step()

        _, predicted = outputs.max(1)

        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        time_train.append(time.time() - start_time)

    train_acc = correct / total
    average_epoch_loss_train = running_loss / len(dataloader)

    # print statistics
    print(f"Epoch: {epoch} => "
          f"Train Acc: {train_acc:.2f} | "
          f"Train Loss: {average_epoch_loss_train:.2f} | "
          f"Avg time/img: {(sum(time_train) / len(time_train)) / batch_size:.2f} s")
    
    return train_acc, average_epoch_loss_train

In [350]:
def validation(epoch, model, dataloader, batch_size, lr_scheduler=None):
    model.eval()

    time_val = []
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, sample_batch in enumerate(dataloader):
        start_time = time.time()

        inputs, targets = sample_batch['image'], sample_batch['mask']
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        with torch.no_grad():
            outputs = model(inputs)
            loss = cross_entropy2d(outputs, targets)
            # loss = criterion(outputs, targets)

        running_loss += loss.item()
        _, predicted = outputs.max(1)

        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        time_val.append(time.time() - start_time)

    val_acc = correct / total
    average_epoch_loss_val = running_loss / len(dataloader)

    if lr_scheduler is not None:
        lr_scheduler.step(val_acc)

    # print statistics
    print(f"Epoch: {epoch} => "
          f"Val   Acc: {val_acc:.2f} | "
          f"Val   Loss: {average_epoch_loss_val:.2f} | "
          f"Avg time/img: {(sum(time_val) / len(time_val)) / batch_size:.2f} s")
    
    return val_acc, average_epoch_loss_val

In [351]:
def test(epoch, model, dataloader, batch_size):
    model.eval()

    time_test = []
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, sample_batch in enumerate(dataloader):
        start_time = time.time()

        inputs, targets = sample_batch['image'], sample_batch['mask']
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        with torch.no_grad():
            outputs = model(inputs)
            loss = cross_entropy2d(outputs, targets)
            # loss = criterion(outputs, targets)

        running_loss += loss.item()
        _, predicted = outputs.max(1)

        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        time_test.append(time.time() - start_time)

    test_acc = correct / total
    average_epoch_loss_test = running_loss / len(dataloader)

    # print statistics
    print(f"Epoch: {epoch} => "
          f"Test  Acc: {test_acc:.2f} | "
          f"Test  Loss: {average_epoch_loss_test:.2f} | "
          f"Avg time/img: {(sum(time_test) / len(time_test)) / batch_size:.2f} s")
    
    return test_acc, average_epoch_loss_test

In [352]:
# Model
print('==> Building model..')
num_classes = train_dataset.num_classes

model = models.vgg16(pretrained=True)
model.classifier[6].out_features = num_classes

print(model.classifier)

# freeze convolution weights
for param in model.features.parameters():
    param.requires_grad = False

model = model.to(device)

print('==> Done!')

==> Building model..
Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): ReLU(inplace=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=4096, out_features=3, bias=True)
)
==> Done!


In [353]:
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                                                    mode='max', 
                                                    patience=20, 
                                                    verbose=True)

In [354]:
info = None
best_val_acc = -1
start_epoch = 0
max_epochs = 10
best_epoch = 0
hist_train_acc = []
hist_train_loss = []
hist_val_acc = []
hist_val_loss = []
hist_test_acc = []
hist_test_loss = []

In [355]:
string = "# ================================================================== # \n" \
         "#                         Starting Training!                         # \n" \
         "# ================================================================== #"
print(string)

for epoch in range(start_epoch, max_epochs):
    train_acc, train_loss = train(epoch=epoch,
                            model=model,
                            dataloader=train_loader,
                            batch_size=batch_size)
    
    val_acc, val_loss = validation(epoch=epoch,
                                   model=model,
                                   dataloader=val_loader,
                                   batch_size=batch_size,
                                   lr_scheduler=lr_scheduler)
    
    test_acc, test_loss = test(epoch=epoch,
                               model=model,
                               dataloader=test_loader,
                               batch_size=batch_size)
    
    hist_train_acc.append(train_acc)
    hist_train_loss.append(train_loss)
    hist_val_acc.append(val_acc)
    hist_val_loss.append(val_loss)
    hist_test_acc.append(test_acc)
    hist_test_loss.append(test_loss)

#                         Starting Training!                         # 
torch.Size([8, 1000])
torch.Size([8, 3, 256, 256])


ValueError: ignored