In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import tqdm.notebook
from torch.utils.tensorboard import SummaryWriter
import datetime
from models import ResNet50, FCN32s

In [None]:
def send_image_mask_to_tensorboard(image, mask, writer, epoch):
    """Log the first sample of a batch to TensorBoard as one side-by-side panel.

    The panel is ``[de-normalised image | image * mask | viridis-coloured mask]``,
    concatenated along the width axis and written under the tag ``'Epoch {epoch}'``.

    Args:
        image: batch of images normalised with the ImageNet mean/std used in the
            data cell; assumed shape (B, 3, H, W) -- TODO confirm.
        mask: batch of generator masks; assumed shape (B, 1, H, W) with values
            in [0, 1] -- TODO confirm.
        writer: a ``SummaryWriter`` (anything exposing ``add_image``).
        epoch: epoch number, only used to build the image tag.
    """
    mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    # De-normalise out of place: when `image` is already on the CPU,
    # `.cpu().detach().numpy()` shares memory with the caller's tensor, so the
    # old in-place update silently modified the caller's batch.
    image_np = image[0].detach().cpu().numpy() * std + mean
    mask_np = mask[0].detach().cpu().numpy().squeeze()

    # Colour the mask, drop the alpha channel and go HWC -> CHW to match the image.
    mask_rgb = plt.get_cmap('viridis')(mask_np)[:, :, :3].transpose((2, 0, 1))
    segmented = image_np * mask_np

    panel = np.concatenate((image_np, segmented, mask_rgb), 2)
    writer.add_image('Epoch {}'.format(epoch), panel)
    
def select_class(dataset, class_name):
    """Return the ``(sample, target)`` pairs of ``dataset`` that belong to ``class_name``.

    Args:
        dataset: an ``ImageFolder``-style dataset exposing ``class_to_idx`` and
            yielding ``(sample, target)`` pairs.
        class_name: class (folder) name to keep, e.g. ``'Horse100'``.

    Returns:
        A list of ``(sample, target)`` tuples, in dataset order.

    Raises:
        KeyError: if ``class_name`` is not a key of ``dataset.class_to_idx``.
    """
    if class_name not in dataset.class_to_idx:
        raise KeyError('Unknown class {!r}; available: {}'.format(
            class_name, sorted(dataset.class_to_idx)))
    class_idx = dataset.class_to_idx[class_name]

    targets = getattr(dataset, 'targets', None)
    if targets is None or getattr(dataset, 'target_transform', None) is not None:
        # Generic path: load every element just to read its (possibly transformed) label.
        return [element for element in dataset if element[1] == class_idx]

    # Fast path: the raw labels are available up front, so only the matching
    # samples are loaded/transformed instead of every image of every class.
    return [dataset[index] for index, target in enumerate(targets) if target == class_idx]

def get_random_images_and_masks(generator, loader=None):
    """Run ``generator`` on one batch and return ``(images, masks)`` as numpy arrays.

    Args:
        generator: the mask-generating network.
        loader: iterable of ``(images, labels)`` batches. Defaults to the
            notebook-level ``data_loader`` (the previous implicit behaviour).

    Returns:
        ``(images, masks)`` numpy arrays for the first batch yielded by
        ``loader`` (a random batch when the loader shuffles).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if loader is None:
        loader = data_loader  # notebook-level global, kept for backward compatibility
    batch = next(iter(loader), None)
    if batch is None:
        raise ValueError('data loader produced no batches')
    images = batch[0]

    # Inference runs on the CPU in eval mode. The original device and
    # train/eval mode are restored afterwards so that re-running the training
    # cell (which expects the generator on `device`) keeps working.
    first_param = next(generator.parameters(), None)
    original_device = first_param.device if first_param is not None else torch.device('cpu')
    was_training = generator.training
    generator.cpu().eval()
    try:
        with torch.no_grad():
            masks = generator(images)
    finally:
        generator.to(original_device)
        generator.train(was_training)
    return images.numpy(), masks.numpy()

def save_and_show_segmented_images(images, masks, folder, show=True):
    """Save (and optionally display) one figure per image: image | mask | image * mask.

    Args:
        images: normalised image batch as a numpy array; assumed (B, 3, H, W)
            with the ImageNet mean/std normalisation -- TODO confirm.
        masks: matching mask batch; assumed (B, 1, H, W) with values in [0, 1].
        folder: output directory, created if missing. Files are written as
            ``{folder}/example{i}.png``.
        show: display each figure before it is closed.
    """
    import os  # local import keeps this cell self-contained

    os.makedirs(folder, exist_ok=True)
    mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    for i in range(images.shape[0]):
        # Undo the normalisation without mutating the caller's array, and clip
        # to the [0, 1] range imshow expects for float RGB data.
        image = np.clip(images[i] * std + mean, 0.0, 1.0)
        mask = masks[i]

        fig = plt.figure(figsize=(50, 50))
        plt.subplot(131)
        plt.imshow(image.transpose((1, 2, 0)))
        plt.subplot(132)
        plt.imshow(mask.squeeze())
        plt.subplot(133)
        plt.imshow((image * mask).transpose((1, 2, 0)))
        plt.savefig('{}/example{}.png'.format(folder, i))
        if show:
            plt.show()
        # Release the figure; otherwise every call leaks a 50x50-inch figure.
        plt.close(fig)

In [None]:
# Use the GPU when available; fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing on the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class_from_dataset = 'Horse100'  # available options: Horse100, Airplane100, Car100
dataset_path = './data'  # root folder passed to ImageFolder
nb_epochs = 7  # more epochs leads to overfitting
seed = 0
torch.manual_seed(seed)  # reproducible weight init and DataLoader shuffling

In [None]:
# Fixed-size 384x384 inputs so images can be batched, normalised with the
# standard ImageNet mean/std (presumably what the feature extractor expects -- TODO confirm).
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((384, 384)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

dataset = torchvision.datasets.ImageFolder(dataset_path, transform=transform)
dataset = select_class(dataset, class_from_dataset)
# drop_last=True: the contrastive loss compares pairs of images within a batch,
# and a trailing batch holding a single image contains no pairs.
data_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=5, drop_last=True)

In [None]:
# The feature extractor is a fixed embedding: eval mode and frozen weights, so
# autograd no longer computes (and endlessly accumulates) gradients for
# parameters the optimizer never updates. Gradients still flow *through* it
# back to the generator's masks.
feature_extractor = ResNet50().eval().requires_grad_(False)
# The mask generator is the only trained network.
generator = FCN32s().train()

In [None]:
# Move the networks to the target device before building the optimizer, as the
# torch.optim documentation recommends.
feature_extractor.to(device)
generator.to(device)
# Plain SGD over the generator's parameters only.
optimizer = torch.optim.SGD(generator.parameters(), lr=1e-2)
# Used as a distance between feature vectors in the training loop.
criterion = torch.nn.MSELoss()

In [None]:
def _contrastive_loss(object_features, background_features, distance):
    """Pairwise contrastive loss over one batch of feature vectors.

    For every ordered pair (q, j) with q != j it adds ``softplus(d_plus - d_minus)``:

    * ``d_plus = distance(obj_q, obj_j) ** 2`` (object vs. object), and
    * ``d_minus`` = mean of ``distance(obj_k, bg_k) ** 2`` over k in {q, j}
      (each object vs. its own background).

    ``softplus(a - b)`` equals ``-log(exp(-a) / (exp(-a) + exp(-b)))`` exactly,
    but does not underflow to ``0 / 0 = NaN`` when the distances are large.

    Args:
        object_features: (B, D) features of the masked objects.
        background_features: (B, D) features of the masked backgrounds.
        distance: callable returning a scalar tensor, e.g. ``torch.nn.MSELoss()``.

    Returns:
        Scalar tensor (zero when B < 2).
    """
    batch_size = object_features.shape[0]
    # The object-vs-own-background term depends on one image only: compute it
    # once per image rather than once per pair.
    own_background = [distance(object_features[k], background_features[k]) ** 2
                      for k in range(batch_size)]
    loss = object_features.new_zeros(())
    for q in range(batch_size):
        for j in range(batch_size):
            if q == j:
                continue
            d_plus = distance(object_features[q], object_features[j]) ** 2
            d_minus = (own_background[q] + own_background[j]) / 2
            loss = loss + torch.nn.functional.softplus(d_plus - d_minus)
    return loss


warmup_epochs = 2  # epochs that only penalise the mean mask value before the contrastive loss

# Re-assert device and mode: CPU inference elsewhere in the notebook may have
# moved the generator off `device`, which would break a re-run of this cell.
generator.to(device).train()
feature_extractor.to(device)

now = datetime.datetime.now()
# Colon-free, date-qualified run name: ':' is invalid in Windows paths and a
# time-only name collides with runs started at the same time on another day.
writer = SummaryWriter('./logs/{}'.format(now.strftime('%Y-%m-%d_%H-%M-%S')))
for epoch in range(nb_epochs):
    running_loss = 0.0
    # Iterating through tqdm updates and closes the progress bar automatically.
    for batch in tqdm.notebook.tqdm(data_loader, desc='Epoch {}'.format(epoch)):
        optimizer.zero_grad()
        images = batch[0].to(device)
        masks = generator(images)
        if epoch < warmup_epochs:
            loss = masks.mean() * 2
        elif images.shape[0] < 2:
            continue  # a single-image batch has no pair to contrast against
        else:
            # flatten(1) rather than squeeze(): keeps the batch axis even when B == 1.
            object_features = feature_extractor(masks * images).flatten(1)
            background_features = feature_extractor((1 - masks) * images).flatten(1)
            loss = _contrastive_loss(object_features, background_features, criterion)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    send_image_mask_to_tensorboard(images, masks, writer, epoch)
    epoch_loss = running_loss / len(dataset)
    if epoch >= warmup_epochs:
        writer.add_scalar('training_loss', epoch_loss, epoch)
    print('Epoch {} loss {}'.format(epoch, epoch_loss))
writer.close()

In [None]:
# Visualise one shuffled batch with the trained generator and write the panels
# to the 'images' folder. Copies are handed over so the arrays returned here
# stay untouched by the plotting helper.
output_folder = 'images'
images, masks = get_random_images_and_masks(generator)
images_for_plot, masks_for_plot = images.copy(), masks.copy()
save_and_show_segmented_images(images_for_plot, masks_for_plot, output_folder)