In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from  torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.models import resnet50

import numpy as np
np.random.seed(42)

import pickle
import pandas as pd
import os
from skimage.io import imread

import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from tqdm.notebook import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
!pip3 install -U sentence-transformers

In [2]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from sentence_transformers import SentenceTransformer

In [3]:
# Side length, in pixels, of the square every image is resized to in COCODataset.__getitem__.
image_dim = 224

def show_sample(sample):
    """
    Plot one sample: the full image on the left, the masked image on the right,
    with the caption as the figure title.

    sample: dict with 'caption', 'full_image' and 'masked_image'. Both images
        must be 3-D tensors, since they are re-ordered with permute(1, 2, 0)
        before plotting (presumably channel-first -> channel-last).
    """
    figure, axes = plt.subplots(1, 2)
    figure.suptitle(sample['caption'], size=20)

    for axis, key in zip(axes, ('full_image', 'masked_image')):
        axis.imshow(sample[key].permute(1, 2, 0))

    plt.show()

class COCODataset(Dataset):
    """
    Dataset of COCO images with captions, each with a rectangle blacked out.

    Every file in ``datadir`` yields 5 consecutive samples: sample ``idx`` uses
    file number ``idx // 5`` and annotation number ``idx % 5`` of that image.
    """

    def __init__(self, annotations, datadir, transform=None):
        """
        annotations: loaded from pickle (akshay's processed annotations).
            Indexed as annotations[image_id][k], where each entry provides
            'caption', 'coord_start' and 'coord_end'.
        datadir: directory of image files. The part of each file name before
            the first "." is parsed as the integer image id.
        transform: optional function run on each sample dict before it is returned.
        """

        self.datadir = datadir
        self.transform = transform
        self.annotations = annotations
        # os.listdir returns entries in arbitrary order; sort them so that the
        # idx -> image mapping is the same on every run and every machine.
        self.filenames = sorted(os.listdir(datadir))

        # Since every 5 consecutive samples share the same image, keep a one-image cache.
        # TODO: the cache only helps for sequential access; shuffling defeats it.
        self.last_image = None
        self.last_index = None

    def __len__(self):
        """Number of samples: 5 per file found in ``datadir``."""
        return len(self.filenames) * 5

    def __getitem__(self, idx):
        """
        Build the sample at position ``idx``.

        Returns: dictionary (passed through ``transform`` when one was given) with
            'masked_image': resized RGB image with the annotated rectangle filled black
            'full_image': the same resized RGB image, unmasked
            'caption': the caption of the selected annotation
            'image_id': int parsed from the file name
        Without a transform both images are PIL images.
        """

        # Five consecutive indices map onto one file.
        image_slot = idx // 5
        image_filename = self.filenames[image_slot]
        image_id = int(image_filename.split(".")[0])

        # Load image or retrieve it from the one-image cache.
        if self.last_index is not None and image_slot == self.last_index // 5:
            full_image = self.last_image
        else:
            image_filepath = os.path.join(self.datadir, image_filename)
            full_image = Image.open(image_filepath)
            self.last_image = full_image

        self.last_index = idx
        full_image = full_image.convert("RGB") # The occasional 1 channel grayscale image is in there.
        full_image = full_image.resize((image_dim, image_dim))

        # Fetch annotation, mask out area.
        # NOTE(review): the rectangle is drawn on the already-resized image, so the
        # coordinates are assumed to be in image_dim x image_dim space -- confirm.
        anno = self.annotations[image_id][idx % 5]

        # Draw on a copy so 'full_image' stays unmasked.
        masked_image = full_image.copy()
        draw = ImageDraw.Draw(masked_image)
        top_left = (anno['coord_start'][0], anno['coord_start'][1])
        bottom_right = (anno['coord_end'][0], anno['coord_end'][1])
        draw.rectangle([top_left, bottom_right], fill="black")

        sample = {'masked_image': masked_image, 'caption': anno['caption'], 'full_image': full_image, 'image_id':image_id}

        if self.transform:
            sample = self.transform(sample)

        return sample

In [4]:
%%time
# NOTE: unpickling can execute arbitrary code -- only load annotation files from a trusted source.
annos = pd.read_pickle("../annotations_train2017.pickle")

# Recommended resnet transforms: per-channel mean / std used for normalization.
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=_RESNET_MEAN, std=_RESNET_STD)

# TODO: change the masking logic to accommodate a Resize(256) + CenterCrop(224)
# pipeline. Until then resizing is left to the dataset and this transform only
# converts to a tensor and normalizes.
_to_tensor = transforms.ToTensor()
resnet_transform = transforms.Compose([_to_tensor, normalize])

def basic_transform_sample(sample):
    """
    The "default" sample transformer.

    Runs resnet_transform on both images of the sample. The dict is updated
    in place and the same dict is returned.
    """
    for key in ('masked_image', 'full_image'):
        sample[key] = resnet_transform(sample[key])
    return sample

# Training split: annotations loaded above, images read from ../data/train2017.
dataset_train = COCODataset(
    annotations=annos,
    datadir="../data/train2017",
    transform=basic_transform_sample,
)

CPU times: user 1.84 s, sys: 188 ms, total: 2.03 s
Wall time: 2.03 s


In [5]:
class rotfoNETv1(nn.Module):
    """
    Caption-conditioned inpainting network.

    A frozen sentence encoder embeds the caption and a ResNet-50 (frozen
    backbone, newly created 768-d output layer) embeds the masked image. The
    two vectors are concatenated, projected to 256 values, reshaped to a
    1x16x16 map and upsampled to a 3x224x224 image by transposed convolutions.
    """

    def __init__(self):
        super().__init__()

        # Caption branch: pretrained sentence encoder, kept frozen.
        self.caption_encoder = SentenceTransformer('distilbert-base-nli-stsb-mean-tokens')
        for p in self.caption_encoder.parameters():
            p.requires_grad = False

        # Image branch: pretrained backbone is frozen; the replacement `fc` is
        # created afterwards, so it is the only trainable part of this branch.
        self.image_encoder = resnet50(pretrained=True)
        for p in self.image_encoder.parameters():
            p.requires_grad = False
        self.image_encoder.fc = nn.Linear(2048, 768)

        self.merge_fc = nn.Linear(768*2, 256)
        self.dropout = nn.Dropout(.35)

        # Spatial size after each layer is (in - 1) * stride + kernel.
        # LeakyReLU's first positional argument is `negative_slope`, not `inplace`:
        # `nn.LeakyReLU(True)` means slope 1.0, i.e. the identity. Pass inplace by keyword.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1, 8, 2, stride=2),      # 16  -> 32
            nn.LeakyReLU(inplace=True),
            nn.ConvTranspose2d(8, 16, 5, stride=2),     # 32  -> 67
            nn.LeakyReLU(inplace=True),
            nn.ConvTranspose2d(16, 32, 5, stride=3),    # 67  -> 203
            nn.LeakyReLU(inplace=True),
            nn.ConvTranspose2d(32, 16, 10, stride=1),   # 203 -> 212
            nn.LeakyReLU(inplace=True),
            nn.ConvTranspose2d(16, 8, 7, stride=1),     # 212 -> 218
            nn.LeakyReLU(inplace=True),
            nn.ConvTranspose2d(8, 3, 7, stride=1),      # 218 -> 224
            # ReLU followed by Tanh keeps the output in [0, 1).
            nn.ReLU(True),
            nn.Tanh()
        )

    def forward(self, img, caption):
        """
        img: image tensor, batched (N, 3, H, W) or a single (3, H, W) image.
        caption: list of N strings, or a single string for a single image.

        Returns: tensor of shape (N, 3, 224, 224); N is 1 for un-batched input.
        """
        # Accept one un-batched sample by turning it into a batch of one.
        if isinstance(caption, str):
            caption = [caption]
        if img.dim() == 3:
            img = img.unsqueeze(0)

        # The sentence encoder output is converted to a float tensor and placed on
        # the same device as the image instead of relying on a notebook global.
        caption_vectors = np.asarray(self.caption_encoder.encode(caption))
        encoded_caption = torch.as_tensor(caption_vectors, dtype=torch.float32).to(img.device)
        encoded_img = self.image_encoder(img)

        x = torch.cat((encoded_caption, encoded_img), 1)
        x = self.merge_fc(x)
        x = self.dropout(x)

        # 256 values per sample become a single-channel 16x16 map for the decoder.
        x = x.view(-1, 1, 16, 16)
        x = self.decoder(x)

        return x


In [6]:
# Training hyper-parameters.
n_epoch = 10
batch_size = 32
lr = 1e-4

# Use the batch_size constant so the loader and the settings above cannot drift apart.
# shuffle=False: presumably kept off so the 5 samples of an image stay adjacent
# for the dataset's one-image cache -- confirm before enabling shuffling.
trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=False, num_workers=4) # VERY important to make sure num_workers > 0.

model = rotfoNETv1().to(device)
# NOTE(review): BCELoss expects targets in [0, 1], but 'full_image' has been through
# resnet_transform's mean/std normalization, so target values can fall outside that
# range. Revisit the loss or the target scaling.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
# Per-epoch sample images are written here; create the directory up front so
# save_image does not fail on a missing folder after a full epoch of training.
os.makedirs('./samples', exist_ok=True)

for epoch in tqdm(range(n_epoch)):
    # The end-of-epoch sampling below switches to eval mode, so re-enable
    # training mode once per epoch instead of once per batch.
    model.train()
    for batch in tqdm(trainloader, total=len(trainloader)):
        masked_image = batch['masked_image'].to(device)
        full_image = batch['full_image'].to(device)
        captions = batch['caption']

        optimizer.zero_grad()
        inpainted_image = model(masked_image, captions)
        loss = criterion(inpainted_image, full_image)
        loss.backward()
        optimizer.step()

    # End of epoch: checkpoint the weights and save one qualitative example.
    with torch.no_grad():
        model.eval()
        # np.random has no `randrange`; randint(n) draws an index in [0, n).
        sample_index = np.random.randint(len(dataset_train))
        sample_input = dataset_train[sample_index]

        torch.save(model.state_dict(), 'ckpt_rotfoNETv1.pth')

        # NOTE(review): these tensors come out of the dataset transform (mean/std
        # normalized) and are saved without being de-normalized.
        save_image(sample_input['full_image'], './samples/original_{}_{}.png'.format(sample_input['image_id'], epoch))
        save_image(sample_input['masked_image'], './samples/original_masked_{}_{}.png'.format(sample_input['image_id'], epoch))

        # A dataset item is a single un-batched sample: add the batch dimension and
        # wrap the caption in a list so the model sees a batch of one.
        sample_masked = sample_input['masked_image'].unsqueeze(0).to(device)
        inpainted_sample = model(sample_masked, [sample_input['caption']])
        save_image(inpainted_sample, './samples/inpainted_masked_{}_{}.png'.format(sample_input['image_id'], epoch))
        

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=18482.0), HTML(value='')))